M05C05: Monoids & Semigroups – Aggregation That Never Lies

Progression Note

By the end of Module 5, you will model every domain concept as immutable algebraic data types (products and tagged sums), eliminating whole classes of runtime errors through exhaustive pattern matching, mypy-checked totality, and pure serialization contracts.

| Module | Focus | Key Outcomes |
| --- | --- | --- |
| 4 | Safe Recursion & Error Handling | Stack-safe tree recursion, folds, Result/Option, streaming validation/retries |
| 5 | Advanced Type-Driven Design | ADTs, exhaustive pattern matching, total functions, refined types |
| 6 | Monadic Flows as Composable Pipelines | bind/and_then, Reader/State-like patterns, error-typed flows |

Core question
How do you replace order-dependent, quadratic-time, or mutable aggregation with lawful monoids and semigroups that guarantee identical results regardless of grouping — enabling safe, parallel, tree-based folds for logs, metrics, configs, and error sets?

Every production pipeline eventually hits the same three bugs:

  1. “Why does the total count change when I change chunk size?”
  2. “Why does merging logs take 45 seconds on 2M lines?”
  3. “Why did my average latency become NaN?”

The naïve pattern everyone writes first:

# BEFORE – mutable, order-dependent, quadratic
total = 0
logs = ""
for chunk in chunks:                     # left-to-right only
    total += len(chunk.text)
    logs += f"[chunk {chunk.id}] ok\n"   # O(N²) string concat

metrics = {"count": 0, "sum": 0.0}
for r in results:
    if r.is_ok():
        metrics["count"] += 1
        metrics["sum"] += r.value.latency_ms   # mutable, race-prone

Order matters, quadratic time, mutable state, NaN silently creeps in.
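Bug #1 can be reproduced in two lines. The values here are made up, but the mechanism is real: floating-point addition is not associative, so regrouping the same numbers changes the rounded total.

```python
# Floating-point addition is not associative: the same three values
# summed with different groupings produce different doubles, which is
# why naive float totals drift when the chunk size changes.
left = (0.1 + 0.2) + 0.3   # groups as chunks [0.1, 0.2] then [0.3]
right = 0.1 + (0.2 + 0.3)  # groups as chunks [0.1] then [0.2, 0.3]
print(left == right)       # False: 0.6000000000000001 vs 0.6
```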

The production pattern: every aggregatable thing is a Monoid (associative + identity) or Semigroup (just associative). Fold with tree_reduce → near-linear time, O(log N) memory, fully parallelizable, mathematically proven identical result.

# AFTER – one lawful line, parallel-safe
total_chars = fold_map(SUM_INT, lambda t: Sum(len(t)), chunk_texts).value

log_lines = tree_reduce(LIST_STR, per_chunk_log_lines)
final_log = "".join(log_lines)

pipeline_metrics = tree_reduce(METRICS, per_chunk_metrics)

Same result on 1 core or 128 cores. No quadratic blowup. No mutable state. Forever.

Audience: Engineers who have ever seen “total differs by grouping” or “logs too slow” and want mathematically guaranteed, parallelizable aggregation.

Outcome

  1. Every +=, extend, dict.update replaced with monoidal tree_reduce.
  2. All aggregations proven associative + identity (when applicable).
  3. Near-linear time (O(N) for fixed-size structs, O(N log N) for lists/dicts), O(log N) memory, fully parallelizable folds.

Tiny Non-Domain Example – Parallel Sum & Log Merge

numbers = [1, 2, 3, 4, 5, 6, 7, 8]
total = tree_reduce(SUM_INT, map(Sum, numbers)).value   # 36, any grouping

log_lines = [["a\n", "b\n"], ["c\n"], ["d\n", "e\n", "f\n"]]
merged = "".join(tree_reduce(LIST_STR, log_lines))
# "a\nb\nc\nd\ne\nf\n" — order preserved, near-linear time

Why Monoids? (Three bullets every engineer should internalise)

  • Associativity → parallelizable: (a <> b) <> c == a <> (b <> c), so tree_reduce is safe and fast.
  • Identity → empty-safe folds: empty <> x == x, so tree_reduce on an empty iterable returns the correct zero.
  • Lawful → refactor-safe: Adding a new metric field never silently breaks totals.
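A minimal sketch of the first bullet, using list concatenation (genuinely associative, unlike float addition): every grouping of the same operands yields the same result, which is exactly what licenses a tree-shaped parallel fold.

```python
from functools import reduce

xs = [["a"], ["b"], ["c"], ["d"]]
cat = lambda l, r: l + r  # list concat: associative, identity []

sequential = reduce(cat, xs, [])                  # ((a+b)+c)+d
tree = cat(cat(xs[0], xs[1]), cat(xs[2], xs[3]))  # (a+b)+(c+d)
print(sequential == tree)  # True: grouping cannot change the answer
```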

1. Laws & Invariants (machine-checked)

| Law | Formal Statement | Enforcement |
| --- | --- | --- |
| Associativity | m.combine(a, m.combine(b, c)) == m.combine(m.combine(a, b), c) | Hypothesis on core monoids (Sum, lists, metrics) |
| Left Identity | m.combine(m.empty(), x) == x | Hypothesis on core monoids (Sum, lists, metrics) |
| Right Identity | m.combine(x, m.empty()) == x | Hypothesis on core monoids (Sum, lists, metrics) |
| Finite Metrics | No NaN/inf in numeric fields | Guarded combine + tests |

2. Decision Table – Monoid vs Semigroup

| Data | Can be empty? | Recommended |
| --- | --- | --- |
| Logs, config dicts, lists | Yes | Monoid (has safe empty) |
| Validation failures | No | Semigroup → wrap in Monoid with () identity for folds |
| Metrics, counts | Yes | Monoid (SUM_INT, METRICS, etc.) |
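The "Validation failures" row in miniature: a semigroup over error tuples has no natural empty, but adjoining a fresh () identity preserves both identity laws, so general folds stay safe. A sketch with plain tuples:

```python
# Tuple concatenation is the semigroup; () is the adjoined identity.
combine = lambda a, b: a + b
empty = ()

errs = ("missing id", "bad timestamp")
print(combine(empty, errs) == errs)  # True: left identity holds
print(combine(errs, empty) == errs)  # True: right identity holds
```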

3. Public API (fp/monoid.py – mypy --strict clean)

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Generic, Iterable, TypeVar, Tuple, Dict, List
import math

__all__ = [
    "Monoid", "Semi",
    "fold", "fold_map", "tree_reduce",
    "SUM_INT", "LIST_STR", "DICT_RIGHT_WINS", "map_monoid", "product_monoid",
    "product3", "METRICS", "nonempty_tuple_semigroup", "dedup_stable_semigroup",
]

T = TypeVar("T")
U = TypeVar("U")
E = TypeVar("E")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")

@dataclass(frozen=True, slots=True)
class Monoid(Generic[T]):
    empty: Callable[[], T]
    combine: Callable[[T, T], T]

@dataclass(frozen=True, slots=True)
class Semi(Generic[T]):
    combine: Callable[[T, T], T]

# Basic folds
def fold(m: Monoid[T], xs: Iterable[T]) -> T:
    acc = m.empty()
    for x in xs:
        acc = m.combine(acc, x)
    return acc

def fold_map(m: Monoid[T], f: Callable[[U], T], xs: Iterable[U]) -> T:
    return fold(m, map(f, xs))

def tree_reduce(m: Monoid[T], xs: Iterable[T], chunk: int = 2048) -> T:
    buf: List[T] = []
    for x in xs:
        buf.append(x)
        if len(buf) >= chunk:
            buf = [_tree_combine(m, buf)]
    return _tree_combine(m, buf) if buf else m.empty()

def _tree_combine(m: Monoid[T], items: List[T]) -> T:
    while len(items) > 1:
        next_items: List[T] = []
        it = iter(items)
        for a in it:
            b = next(it, None)
            next_items.append(a if b is None else m.combine(a, b))
        items = next_items
    return items[0] if items else m.empty()

# Canonical monoids
@dataclass(frozen=True, slots=True)
class Sum:
    value: int

SUM_INT: Monoid[Sum] = Monoid(
    empty=lambda: Sum(0),
    combine=lambda a, b: Sum(a.value + b.value),
)

LIST_STR: Monoid[List[str]] = Monoid(
    empty=list,
    combine=lambda a, b: a + b,
)

# Monomorphic "right wins" config monoid. For fully typed nested configs, prefer `map_monoid`.
DICT_RIGHT_WINS: Monoid[Dict[str, object]] = Monoid(
    empty=dict,
    combine=lambda a, b: {**a, **b},
)

def map_monoid(value_m: Monoid[T]) -> Monoid[Dict[str, T]]:
    def empty() -> Dict[str, T]:
        return {}
    def combine(a: Dict[str, T], b: Dict[str, T]) -> Dict[str, T]:
        keys = a.keys() | b.keys()
        out: Dict[str, T] = {}
        for k in keys:
            if k in a and k in b:
                out[k] = value_m.combine(a[k], b[k])
            elif k in a:
                out[k] = a[k]
            else:
                out[k] = b[k]
        return out
    return Monoid(empty, combine)

def product_monoid(m1: Monoid[T], m2: Monoid[U]) -> Monoid[Tuple[T, U]]:
    return Monoid(
        empty=lambda: (m1.empty(), m2.empty()),
        combine=lambda a, b: (m1.combine(a[0], b[0]), m2.combine(a[1], b[1])),
    )

def product3(m1: Monoid[T1], m2: Monoid[T2], m3: Monoid[T3]) -> Monoid[Tuple[T1, T2, T3]]:
    return Monoid(
        empty=lambda: (m1.empty(), m2.empty(), m3.empty()),
        combine=lambda a, b: (
            m1.combine(a[0], b[0]),
            m2.combine(a[1], b[1]),
            m3.combine(a[2], b[2]),
        ),
    )

@dataclass(frozen=True, slots=True)
class Metrics:
    processed: int = 0
    succeeded: int = 0
    latency_sum_ms: float = 0.0
    latency_max_ms: float = 0.0

def _check_finite(x: float) -> float:
    if not math.isfinite(x):
        raise ValueError(f"non-finite metric value: {x}")
    return x

METRICS: Monoid[Metrics] = Monoid(
    empty=Metrics,
    combine=lambda a, b: Metrics(
        processed=a.processed + b.processed,
        succeeded=a.succeeded + b.succeeded,
        latency_sum_ms=_check_finite(a.latency_sum_ms + b.latency_sum_ms),
        latency_max_ms=max(a.latency_max_ms, b.latency_max_ms),
    ),
)

# Semigroups (for guaranteed non-empty data)
def nonempty_tuple_semigroup() -> Semi[Tuple[T, ...]]:
    return Semi(lambda a, b: a + b)

def dedup_stable_semigroup() -> Semi[Tuple[E, ...]]:
    def combine(a: Tuple[E, ...], b: Tuple[E, ...]) -> Tuple[E, ...]:
        seen: set[E] = set()
        out: List[E] = []
        for e in a + b:
            if e not in seen:
                seen.add(e)
                out.append(e)
        return tuple(out)
    return Semi(combine)
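To see the pairwise pass of _tree_combine in isolation, here is a standalone sketch specialized to integer addition (no imports from the module, so the halving-and-carry behavior on odd-length input is easy to step through):

```python
def pairwise_reduce(combine, items, empty):
    # Each pass combines neighbours, halving the list: O(log N) passes,
    # and an odd trailing element is carried forward unchanged.
    while len(items) > 1:
        it = iter(items)
        nxt = []
        for a in it:
            b = next(it, None)
            nxt.append(a if b is None else combine(a, b))
        items = nxt
    return items[0] if items else empty

print(pairwise_reduce(lambda a, b: a + b, [1, 2, 3, 4, 5], 0))  # 15
print(pairwise_reduce(lambda a, b: a + b, [], 0))               # 0
```

On [1, 2, 3, 4, 5] the passes are [3, 7, 5] → [10, 5] → [15]; associativity is what makes this regrouping safe.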

4. Reference Implementations

4.1 Before vs After – Pipeline Metrics & Logs

# BEFORE – mutable, slow, order-dependent
total = 0
logs = ""
for chunk in chunks:
    total += len(chunk.text)
    logs += f"[chunk {chunk.id}] ok\n"   # O(N²)

# AFTER – lawful, parallel, near-linear
total_chars = fold_map(SUM_INT, lambda t: Sum(len(t)), chunk_texts).value

log_lines = tree_reduce(LIST_STR, per_chunk_log_lines)
final_log = "".join(log_lines)
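One subtlety behind Metrics: a running average is not associative, which is why the struct stores latency_sum_ms and processed separately and derives the mean only at the sink. A sketch with hypothetical latencies:

```python
# Chunk A: two requests, mean 5.0 ms. Chunk B: one request, 40.0 ms.
a = {"processed": 2, "latency_sum_ms": 10.0}
b = {"processed": 1, "latency_sum_ms": 40.0}

# Associative combine: add the sums and the counts, never the means.
merged = {
    "processed": a["processed"] + b["processed"],
    "latency_sum_ms": a["latency_sum_ms"] + b["latency_sum_ms"],
}
pooled_mean = merged["latency_sum_ms"] / merged["processed"]  # 50/3 ≈ 16.67
naive_mean_of_means = (5.0 + 40.0) / 2                        # 22.5, wrong
print(pooled_mean, naive_mean_of_means)
```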

4.2 RAG Integration – Full Pipeline Aggregation

# Validation errors are a semigroup (no empty). For folds that may be empty,
# we wrap it in a monoid with () as identity (never used when data present).
_errors_semi = nonempty_tuple_semigroup()
errors_monoid = Monoid(
    empty=lambda: (),
    combine=_errors_semi.combine,
)

triple_monoid = product3(METRICS, LIST_STR, errors_monoid)

aggregated = tree_reduce(triple_monoid, per_chunk_triple)
pipeline_metrics, log_lines, validation_errors = aggregated

final_log = "".join(log_lines)

if validation_errors:
    return Err(make_errinfo(
        code="VALIDATION",
        msg="batch failed",
        meta={"errors": list(validation_errors)},
    ))

return Ok((pipeline_metrics, final_log))
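When the same validation failure fires in many chunks, dedup_stable_semigroup keeps the error set readable: first occurrence wins and relative order is preserved. A self-contained sketch of its combine:

```python
def dedup_combine(a, b):
    # Stable dedup: first occurrence wins, order of first sightings kept.
    seen, out = set(), []
    for e in a + b:
        if e not in seen:
            seen.add(e)
            out.append(e)
    return tuple(out)

chunk1 = ("missing id", "bad timestamp")
chunk2 = ("bad timestamp", "empty body")
print(dedup_combine(chunk1, chunk2))
# ('missing id', 'bad timestamp', 'empty body')
```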

5. Property-Based Proofs (tests/test_monoid_laws.py)

import pytest
from hypothesis import given, strategies as st
from funcpipe_rag.fp.monoid import *

@given(a=st.integers(), b=st.integers(), c=st.integers())
def test_sum_int_laws(a, b, c):
    m = SUM_INT
    assert m.combine(m.combine(Sum(a), Sum(b)), Sum(c)) == m.combine(Sum(a), m.combine(Sum(b), Sum(c)))
    e = m.empty()
    assert m.combine(e, Sum(a)) == Sum(a)
    assert m.combine(Sum(a), e) == Sum(a)

@given(xs=st.lists(st.integers()))
def test_tree_reduce_equals_sum(xs):
    sums = [Sum(x) for x in xs]
    assert tree_reduce(SUM_INT, sums).value == sum(xs)

@given(a=st.lists(st.text()), b=st.lists(st.text()), c=st.lists(st.text()))
def test_list_str_laws(a, b, c):
    m = LIST_STR
    # associativity
    assert m.combine(m.combine(a, b), c) == m.combine(a, m.combine(b, c))
    # identities
    e = m.empty()
    assert m.combine(e, a) == a
    assert m.combine(a, e) == a


_finite_floats = st.floats(allow_nan=False, allow_infinity=False, width=32)
metrics_strategy = st.builds(
    Metrics,
    processed=st.integers(min_value=0, max_value=10**9),
    succeeded=st.integers(min_value=0, max_value=10**9),
    latency_sum_ms=_finite_floats,
    latency_max_ms=_finite_floats,
)


@given(a=metrics_strategy, b=metrics_strategy, c=metrics_strategy)
def test_metrics_monoid_laws(a, b, c):
    m = METRICS
    # associativity
    assert m.combine(m.combine(a, b), c) == m.combine(a, m.combine(b, c))
    # identities
    e = m.empty()
    assert m.combine(e, a) == a
    assert m.combine(a, e) == a

def test_metrics_finite_guard():
    with pytest.raises(ValueError):
        METRICS.combine(Metrics(latency_sum_ms=float("nan")), Metrics())

6. Big-O & Allocation Guarantees

| Operation | Time | Memory | Notes |
| --- | --- | --- | --- |
| m.combine (list) | O(N) | O(N) | Use LIST_STR + "".join at sink |
| tree_reduce (list) | O(N log N) worst | O(log N) | Avoids left-fold O(N²) |
| map_monoid combine | O(N) | O(N) | Nested merge |
| tree_reduce total | O(N log N) worst | O(chunk size) | Fully parallelizable |
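The map_monoid combine row in miniature: the merge visits each key once, combining shared keys with the value monoid (integer addition here, for a hypothetical per-shard request counter):

```python
def combine_counts(a, b):
    # Key-wise merge: shared keys are combined, unique keys pass through.
    out = dict(a)
    for k, v in b.items():
        out[k] = out.get(k, 0) + v  # value monoid: (int, +, 0)
    return out

shard1 = {"GET /": 10, "POST /login": 2}
shard2 = {"GET /": 5, "GET /health": 1}
print(combine_counts(shard1, shard2))
# {'GET /': 15, 'POST /login': 2, 'GET /health': 1}
```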

7. Anti-Patterns & Immediate Fixes

| Anti-Pattern | Symptom | Fix |
| --- | --- | --- |
| Left-fold string concat | O(N²) time | LIST_STR + tree_reduce + join |
| Mutable dict/counter | Race conditions in parallel | Immutable map_monoid |
| Order-dependent reduce | Totals vary by chunking | Associative monoid + tree_reduce |
| NaN/inf in metrics | Silent corruption | Guarded combine |
| Manual empty handling | Off-by-one on empty input | m.empty() in tree_reduce |
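The first row's fix in miniature: accumulate lines in a list and join once at the sink. Each string += copies the whole accumulator (quadratic total work), while list append plus a single join is linear:

```python
# BEFORE (quadratic): each += copies the whole accumulated string.
slow = ""
for i in range(3):
    slow += f"[chunk {i}] ok\n"

# AFTER (linear): append cheap list cells, join once at the end.
lines = [f"[chunk {i}] ok\n" for i in range(3)]
fast = "".join(lines)

print(slow == fast)  # True: same output, very different cost at scale
```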

8. Pre-Core Quiz

  1. Monoid = semigroup + what? → identity (empty)
  2. tree_reduce avoids what? → O(N²) left folds
  3. For logs? → LIST_STR + "".join
  4. For nested config? → map_monoid(deep_monoid)
  5. For validation errors? → Semigroup, wrapped in Monoid with () identity

9. Post-Core Exercise

  1. Implement Monoid for your domain metric type → test associativity + identity.
  2. Replace one += / extend aggregation with tree_reduce.
  3. Build a deep config monoid with map_monoid(product_monoid(...)).
  4. Prove your log aggregation is near-linear and parallel-safe.

Next: M05C06 – Pydantic v2 as Smart Constructors – Runtime Enforcement Without Losing ADTs.

You now aggregate anything — logs, metrics, configs, errors — with mathematically proven correctness, near-linear time, and full parallelism. The rest of Module 5 adds Pydantic for runtime validation, pattern matching for orchestration, and serialization contracts.