
M04C02: Folds & Reductions as Safe Recursion (Monoidal Fusion, Custom Accumulators)

Progression Note

By the end of Module 4, you will master safe recursion over unpredictable tree-shaped data, monoidal folds as the universal recursion pattern, Result/Option for streaming error handling, validation aggregators, retries, and structured error reporting — all while preserving laziness, equational reasoning, and constant call-stack usage.

Here's a snippet from the progression map:

| Module | Focus | Key Outcomes |
|---|---|---|
| 3 | Lazy Iteration & Generators | Memory-efficient streaming, itertools mastery, short-circuiting, observability |
| 4 | Safe Recursion & Error Handling in Streams | Stack-safe tree recursion, folds, Result/Option, streaming validation/retries/reports |
| 5 | Advanced Type-Driven Design | ADTs, exhaustive pattern matching, total functions, refined types |

Core question:
How do you replace any structural-recursive aggregation with an iterative fold (catamorphism) that is stack-safe, fully lazy when needed, and capable of fusing arbitrary numbers of independent aggregations into a single O(N) traversal?

We now take the TreeDoc hierarchy from M04C01 and ask the most common real-world question:

“How many nodes are in this document tree, what is the total length of all text across all nodes, and what is the maximum depth?”

The naïve recursive solution is beautiful and obvious:

def recursive_stats(tree: TreeDoc) -> tuple[int, int, int]:
    count = 1
    length = len(tree.node.text)
    max_d = 0
    for child in tree.children:
        c, l, md = recursive_stats(child)
        count += c
        length += l
        max_d = max(max_d, md + 1)
    return count, length, max_d

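It works perfectly… until the tree is 2000 levels deep and you get RecursionError: every level adds a Python call frame, and CPython's default recursion limit is 1000 (see sys.getrecursionlimit()).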
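The following minimal sketch makes the failure concrete. It uses a hypothetical stand-in for M04C01's TreeDoc (a frozen dataclass with only node.text and children) and a copy of recursive_stats. It also adds iterative_stats, a hypothetical name, which computes the same triple with an explicit stack. On a 2000-level chain the recursive version exceeds CPython's default recursion limit. The iterative version never grows the call stack.

```python
from __future__ import annotations

import sys
from dataclasses import dataclass


# Hypothetical stand-in for M04C01's TreeDoc: only the fields used in this section.
@dataclass(frozen=True)
class Node:
    text: str


@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple[TreeDoc, ...] = ()


def recursive_stats(tree: TreeDoc) -> tuple[int, int, int]:
    # The naive version from above: one Python call frame per tree level.
    count = 1
    length = len(tree.node.text)
    max_d = 0
    for child in tree.children:
        c, ln, md = recursive_stats(child)
        count += c
        length += ln
        max_d = max(max_d, md + 1)
    return count, length, max_d


def iterative_stats(tree: TreeDoc) -> tuple[int, int, int]:
    # Hypothetical iterative twin: a list of (subtree, depth) replaces the call stack.
    count = length = max_d = 0
    stack = [(tree, 0)]
    while stack:
        node, depth = stack.pop()
        count += 1
        length += len(node.node.text)
        max_d = max(max_d, depth)
        stack.extend((child, depth + 1) for child in node.children)
    return count, length, max_d


def chain(depth: int) -> TreeDoc:
    # Illustrative helper: a single-branch tree whose leaf sits at `depth`, built without recursion.
    tree = TreeDoc(Node("x"))
    for _ in range(depth):
        tree = TreeDoc(Node("x"), (tree,))
    return tree


deep = chain(2000)
print(iterative_stats(deep))  # (2001, 2001, 2000)
print("recursion limit:", sys.getrecursionlimit())
try:
    recursive_stats(deep)
except RecursionError:
    print("recursive_stats: RecursionError")  # expected with CPython's default limit of 1000
```

Because the three statistics do not depend on visit order, iterative_stats can push children in any order. The folds below preserve strict preorder because general combiners may depend on it.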

The production solution must:
- be iterative (O(1) call-stack),
- compute all three values in one pass (fusion),
- optionally be streaming (scan_tree yields running totals),
- remain provably equivalent to the recursive spec.

This is exactly what a fold (catamorphism) gives us — “reduce over a tree” with the recursion made explicit and safe.

Audience: Engineers who routinely aggregate statistics over tree/document/graph structures and refuse to ship code that can raise RecursionError on pathological but legal inputs.

Outcome:
1. You will replace any recursive aggregation with an iterative fold that is formally terminating and stack-safe.
2. You will fuse arbitrary numbers of independent aggregations into a single traversal using immutable tuple accumulators.
3. You will ship streaming reductions (scan_tree) that are truly lazy and short-circuitable.

We formalise exactly what we want from a correct, production-ready fold: termination, stack-safety, equivalence to the recursive spec, strict preorder, fusion, and (when scanning) bounded work.


Concrete Motivating Example

Same deep Markdown-derived tree from M04C01:

Root (title)                     → 50 chars, depth 0
├── Section 1                    → 30 chars, depth 1
│   ├── Subsection 1.1           → 20 chars, depth 2
│   └── Subsection 1.2           → 25 chars, depth 2
└── Section 2                    → 35 chars, depth 1
    └── Subsection 2.1           → 40 chars, depth 2
        └── … (2000 levels deep) → leaf node with 10 chars, depth 2002

Desired aggregates (computed in one pass):

total_nodes       = 2006    (6 named nodes + 2000 chain nodes at depths 3–2002)
total_text_length = 85_050
max_depth         = 2002

We want all three numbers, plus optionally a running total after each node (for progress bars, early termination, etc.).


1. Laws & Invariants (machine-checked where possible)

All laws assume finite, acyclic TreeDoc inputs (always non-empty; root node exists).

| Law | Formal Statement | Enforcement |
|---|---|---|
| Termination & Stack-Safety | Completes in O(N) steps with O(1) call-stack frames for any finite acyclic tree. | Formal proof via explicit stack + Hypothesis on 5000-node chains + CI recursion-limit guard. |
| Equivalence | fold_tree(t, seed, f) == recursive_fold(t, seed, f) for all t (identical result and preorder application). | Hypothesis test_fold_vs_recursive_equivalence. |
| Fusion | Fused tuple fold equals separate folds: (count, length, max_d) == (fold_count(t), fold_len(t), fold_max_d(t)). | Hypothesis test_fusion_equivalence. |
| Bounded-Work (scan_tree) | Consuming the first k partial accumulators visits exactly k nodes (k combiner calls). | Instrumented property test_scan_bounded_work. |
| Order Law | Combiner applied in strict preorder (identical sequence to flatten(t) from M04C01). | Property test test_fold_preorder_matches_flatten. |

These laws turn “fold” from a pattern into a verifiable contract.
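The "CI recursion-limit guard" in the Enforcement column can be read as a test with two steps. First, temporarily lower the interpreter's recursion limit. Then check that the fold still completes on a chain far deeper than that limit. Below is a sketch of one such test. It assumes a hypothetical stand-in for TreeDoc and a local copy of fold_tree_no_path (section 4.4). The helpers recursion_limit and chain are illustrative, and a plain 5000-node chain stands in for Hypothesis-generated chains.

```python
from __future__ import annotations

import sys
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Iterator, TypeVar

R = TypeVar("R")


# Hypothetical stand-in for M04C01's TreeDoc: only node.text and children.
@dataclass(frozen=True)
class Node:
    text: str


@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple[TreeDoc, ...] = ()


def fold_tree_no_path(tree: TreeDoc, seed: R, combiner: Callable[[R, TreeDoc, int], R]) -> R:
    # Local copy of fold_tree_no_path (section 4.4): stack of (node, depth, next child index).
    acc = seed
    stack: deque[tuple[TreeDoc, int, int]] = deque([(tree, 0, 0)])
    while stack:
        node, depth, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth)
        if child_idx < len(node.children):
            stack.append((node, depth, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, 0))
    return acc


@contextmanager
def recursion_limit(limit: int) -> Iterator[None]:
    # Illustrative helper: set the recursion limit temporarily, restoring the old value on exit.
    old = sys.getrecursionlimit()
    sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(old)


def chain(n: int) -> TreeDoc:
    # Illustrative helper: an n-node single-branch tree, built bottom-up without recursion.
    tree = TreeDoc(Node("x"))
    for _ in range(n - 1):
        tree = TreeDoc(Node("x"), (tree,))
    return tree


def test_fold_is_stack_safe_under_low_limit() -> None:
    deep = chain(5000)
    with recursion_limit(300):  # far below the chain's depth of 4999
        count = fold_tree_no_path(deep, 0, lambda acc, _node, _depth: acc + 1)
        max_d = fold_tree_no_path(deep, 0, lambda acc, _node, depth: max(acc, depth))
    assert (count, max_d) == (5000, 4999)
```

Any accidental reintroduction of recursion, for example a recursive helper inside the combiner, would fail this test immediately instead of only on production data.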


2. Decision Table – Which Fold Do You Actually Use?

| Need | Streaming Partials? | Multiple Values? | Recommended Variant |
|---|---|---|---|
| Single aggregate (tree) | No | No | fold_tree or fold_tree_no_path |
| Multiple aggregates (tree) | No | Yes | fold_tree_buffered with tuple accumulator (fused) |
| Running totals (tree) | Yes | No/Yes | scan_tree (optionally with tuple accumulator) |
| Linear (list/iterator) | No/Yes | No | linear_reduce / linear_accumulate |

Never use recursive aggregation in library code.
Never run multiple separate folds when a single fused tuple fold gives the same result in one pass.


3. Public API Surface (end-of-Module-04 refactor note)

Refactor note: tree folds/scans live in funcpipe_rag.tree (src/funcpipe_rag/tree/folds.py).
funcpipe_rag.api.core re-exports the same names as a stable façade for the teaching modules.

from funcpipe_rag.api.core import (
    fold_count_length_maxdepth,
    fold_tree,
    fold_tree_buffered,
    fold_tree_no_path,
    linear_accumulate,
    linear_reduce,
    scan_count_length_maxdepth,
    scan_tree,
)

4. Reference Implementations

4.1 Recursive Specification (Didactic only)

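from collections import deque
from typing import Callable, Iterator, TypeVar

# TreeDoc (with .node.text and .children) comes from M04C01; its import is omitted here.
R = TypeVar("R")          # accumulator type
Path = tuple[int, ...]    # child indices from the root, e.g. (1, 0)

def recursive_fold(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
    *,
    depth: int = 0,
    path: Path = (),
) -> R:
    # Preorder: combine this node first, then recurse into children left to right.
    # One Python call frame per level, so this is a specification, not production code.
    acc = combiner(seed, tree, depth, path)
    for i, child in enumerate(tree.children):
        acc = recursive_fold(child, acc, combiner, depth=depth + 1, path=path + (i,))
    return acc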

4.2 Simple Explicit-Stack Fold (Readable reference)

def fold_tree(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
) -> R:
    acc = seed
    stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            child = node.children[child_idx]
            stack.append((child, depth + 1, path + (child_idx,), 0))
    return acc
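The child_idx field lets a single stack entry "resume" its parent after each child finishes. The small sketch below traces every pop, assuming a hypothetical stand-in for TreeDoc. Here fold_tree_traced is fold_tree with an added trace list, an illustrative addition rather than part of the library.

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from typing import Callable, TypeVar

R = TypeVar("R")
Path = tuple[int, ...]


# Hypothetical stand-in for M04C01's TreeDoc.
@dataclass(frozen=True)
class Node:
    text: str


@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple[TreeDoc, ...] = ()


def fold_tree_traced(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
    trace: list[tuple[Path, int]],
) -> R:
    # fold_tree from 4.2, plus one trace entry (path, child_idx) per stack pop.
    acc = seed
    stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        trace.append((path, child_idx))
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)  # first visit: combine in preorder
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))  # resume frame for the next child
            child = node.children[child_idx]
            stack.append((child, depth + 1, path + (child_idx,), 0))  # descend into one child
    return acc


#   root
#   ├── A
#   │   └── A0
#   └── B
tree = TreeDoc(Node("root"), (TreeDoc(Node("A"), (TreeDoc(Node("A0")),)), TreeDoc(Node("B"))))

order: list[Path] = []
trace: list[tuple[Path, int]] = []
fold_tree_traced(tree, None, lambda _, n, d, p: order.append(p), trace)
print(order)  # [(), (0,), (0, 0), (1,)]
print(trace)  # [((), 0), ((0,), 0), ((0, 0), 0), ((0,), 1), ((), 1), ((1,), 0), ((), 2)]
```

Entries with child_idx == 0 are first visits, where the combiner runs. Entries with child_idx > 0 only schedule the next child. Each node is popped once plus once per child, so the loop runs 2N - 1 times for N nodes (7 here). The stack never holds more than max_depth + 1 entries, since each ancestor contributes a single resume frame.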

4.3 Production Winner – Buffered-Path Fold (no path tuples on the traversal stack)

def fold_tree_buffered(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
) -> R:
    """
    Same semantics as fold_tree but maintains the path using a single mutable list
    (no tuples on the traversal stack). One tuple per node is still created when
    calling combiner (intrinsic to passing the path).
    """
    acc = seed
    stack: deque[tuple[TreeDoc, int, int | None]] = deque([(tree, 0, None)])
    path: list[int] = []
    last_depth = 0

    while stack:
        node, depth, sib_idx = stack.pop()

        # Maintain mutable path prefix (identical logic to iter_flatten_buffered)
        if depth < last_depth:
            del path[depth:]
        if sib_idx is not None:
            if depth > len(path):
                path.append(sib_idx)
            else:
                path[depth-1] = sib_idx
        last_depth = depth

        acc = combiner(acc, node, depth, tuple(path[:depth]))

        for i in range(len(node.children)-1, -1, -1):
            stack.append((node.children[i], depth + 1, i))

    return acc

4.4 Optimised Fold Without Path (when path unused)

def fold_tree_no_path(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int], R],
) -> R:
    """Third parameter is depth."""
    acc = seed
    stack: deque[tuple[TreeDoc, int, int]] = deque([(tree, 0, 0)])
    while stack:
        node, depth, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth)
        if child_idx < len(node.children):
            stack.append((node, depth, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, 0))
    return acc

4.5 Streaming Scan (Running Totals – Truly Lazy)

def scan_tree(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
) -> Iterator[R]:
    """Yield running accumulator after each node in preorder – O(k) work for first k yields."""
    acc = seed
    stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
            yield acc
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            child = node.children[child_idx]
            stack.append((child, depth + 1, path + (child_idx,), 0))

Note: linear_accumulate (via itertools.accumulate) yields the initial seed as the first value; scan_tree yields only post-node accumulators (no initial seed yield).
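linear_reduce and linear_accumulate are exported in section 3 but not shown. The minimal sketch below shows what they could look like over a flat iterable, built on functools.reduce and itertools.accumulate. The (items, seed, combiner) signature is an assumption chosen to mirror fold_tree and is not necessarily the library's.

```python
from functools import reduce
from itertools import accumulate
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def linear_reduce(items: Iterable[T], seed: R, combiner: Callable[[R, T], R]) -> R:
    # Hypothetical signature. Left fold: consumes the iterable once, returns only the final value.
    return reduce(combiner, items, seed)


def linear_accumulate(items: Iterable[T], seed: R, combiner: Callable[[R, T], R]) -> Iterator[R]:
    # Hypothetical signature. Lazy running totals; with initial=seed, accumulate yields the seed first.
    return accumulate(items, combiner, initial=seed)


print(linear_reduce([1, 2, 3], 0, lambda acc, x: acc + x))            # 6
print(list(linear_accumulate([1, 2, 3], 0, lambda acc, x: acc + x)))  # [0, 1, 3, 6]
```

The second result has four values for three items because accumulate(..., initial=seed) emits the seed first. This is the difference from scan_tree described in the note above.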

4.6 Fused Multi-Value Example (Count + Length + Max Depth)

def fold_count_length_maxdepth(tree: TreeDoc) -> tuple[int, int, int]:
    # One traversal, three aggregates: the accumulator is an immutable triple.
    def step(acc: tuple[int, int, int], node: TreeDoc, depth: int, path: Path) -> tuple[int, int, int]:
        count, length, max_d = acc
        return (
            count + 1,
            length + len(node.node.text),
            max(max_d, depth),
        )
    return fold_tree_buffered(tree, (0, 0, 0), step)
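scan_count_length_maxdepth is also exported in section 3 but not defined here. The plausible sketch below assumes it is simply scan_tree driven by the same fused step; this body is an assumption, not the library's code. It uses a hypothetical stand-in for TreeDoc and the top six nodes of the motivating example (text lengths 50, 30, 20, 25, 35, 40).

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from typing import Callable, Iterator, TypeVar

R = TypeVar("R")
Path = tuple[int, ...]


# Hypothetical stand-in for M04C01's TreeDoc.
@dataclass(frozen=True)
class Node:
    text: str


@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple[TreeDoc, ...] = ()


def scan_tree(
    tree: TreeDoc, seed: R, combiner: Callable[[R, TreeDoc, int, Path], R]
) -> Iterator[R]:
    # Local copy of scan_tree from 4.5.
    acc = seed
    stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
            yield acc
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, path + (child_idx,), 0))


def step_count_len_maxd(
    acc: tuple[int, int, int], node: TreeDoc, depth: int, path: Path
) -> tuple[int, int, int]:
    count, length, max_d = acc
    return count + 1, length + len(node.node.text), max(max_d, depth)


def scan_count_length_maxdepth(tree: TreeDoc) -> Iterator[tuple[int, int, int]]:
    # Hypothetical body: the fused step from fold_count_length_maxdepth, scanned instead of folded.
    return scan_tree(tree, (0, 0, 0), step_count_len_maxd)


def placeholder(n: int) -> Node:
    return Node("x" * n)  # stand-in text of length n


top = TreeDoc(placeholder(50), (
    TreeDoc(placeholder(30), (TreeDoc(placeholder(20)), TreeDoc(placeholder(25)))),
    TreeDoc(placeholder(35), (TreeDoc(placeholder(40)),)),
))
running = list(scan_count_length_maxdepth(top))
print(running[-1])  # (6, 200, 2)
```

The last running value equals what the fused fold returns for the same tree and seed. A progress bar and the final statistics can therefore come from one pass.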

5. Property-Based Proofs (tests/test_tree_folds.py)

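from hypothesis import given

from funcpipe_rag.api.core import (
    fold_count_length_maxdepth,
    fold_tree,
    fold_tree_buffered,
    fold_tree_no_path,
)

# Assumed to come from the test helpers: tree_strategy (a Hypothesis strategy of
# TreeDoc values) and recursive_fold (the didactic spec from 4.1).

def step_count_len_maxd(acc, node, depth, path):
    # Same fused step as in fold_count_length_maxdepth (4.6).
    count, length, max_d = acc
    return count + 1, length + len(node.node.text), max(max_d, depth)

@given(tree=tree_strategy())
def test_fold_vs_recursive_equivalence(tree):
    rec = recursive_fold(tree, (0, 0, 0), step_count_len_maxd)
    buf = fold_tree_buffered(tree, (0, 0, 0), step_count_len_maxd)
    assert rec == buf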

@given(tree=tree_strategy())
def test_fusion_equivalence(tree):
    fused = fold_count_length_maxdepth(tree)
    count = fold_tree_no_path(tree, 0, lambda a, n, d: a + 1)
    length = fold_tree_no_path(tree, 0, lambda a, n, d: a + len(n.node.text))
    max_d = fold_tree_no_path(tree, 0, lambda a, n, d: max(a, d))
    assert fused == (count, length, max_d)

@given(tree=tree_strategy())
def test_fold_preorder_matches_flatten(tree):
    from funcpipe_rag.api.core import flatten
    order_via_fold: list[Path] = []
    fold_tree(tree, None, lambda _, n, d, p: order_via_fold.append(p))
    order_via_flatten = [c.metadata["path"] for c in flatten(tree)]
    assert order_via_fold == order_via_flatten

@given(tree=tree_strategy())
def test_fold_buffered_order_matches_simple(tree):
    order_simple: list[Path] = []
    order_buf: list[Path] = []
    fold_tree(tree, None, lambda _, n, d, p: order_simple.append(p))
    fold_tree_buffered(tree, None, lambda _, n, d, p: order_buf.append(p))
    assert order_simple == order_buf
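The laws table cites test_scan_bounded_work, which is not listed above. The sketch below shows one way to write it, assuming a hypothetical stand-in for TreeDoc and a local copy of scan_tree. Instead of drawing trees from tree_strategy, it uses a fixed six-node tree and loops over k.

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from itertools import islice
from typing import Callable, Iterator, TypeVar

R = TypeVar("R")
Path = tuple[int, ...]


# Hypothetical stand-in for M04C01's TreeDoc.
@dataclass(frozen=True)
class Node:
    text: str


@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple[TreeDoc, ...] = ()


def scan_tree(
    tree: TreeDoc, seed: R, combiner: Callable[[R, TreeDoc, int, Path], R]
) -> Iterator[R]:
    # Local copy of scan_tree from 4.5.
    acc = seed
    stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
            yield acc
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, path + (child_idx,), 0))


def sample_tree() -> TreeDoc:
    # Hypothetical fixed tree with six nodes: root -> (a -> (a0, a1), b -> (b0,))
    return TreeDoc(Node("root"), (
        TreeDoc(Node("a"), (TreeDoc(Node("a0")), TreeDoc(Node("a1")))),
        TreeDoc(Node("b"), (TreeDoc(Node("b0")),)),
    ))


def test_scan_bounded_work() -> None:
    tree = sample_tree()
    for k in range(7):  # k = 0..6, up to every node of the tree
        calls = 0

        def counting(acc: int, _node: TreeDoc, _depth: int, _path: Path) -> int:
            nonlocal calls
            calls += 1  # one combiner call == one node visited
            return acc + 1

        partials = list(islice(scan_tree(tree, 0, counting), k))
        assert partials == list(range(1, k + 1))
        assert calls == k
```

itertools.islice stops after k items without requesting a (k+1)-th. Any extra combiner call would therefore show up as calls > k.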

6. Big-O & Allocation Guarantees (peak auxiliary memory)

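| Variant | Time | Call-stack | Peak auxiliary heap | Path allocations |
|---|---|---|---|---|
| fold_tree / scan_tree | O(N) loop steps + O(depth) per path tuple built | O(1) | O(depth) frames, each holding its own path tuple (O(depth²) path elements) | one path tuple per node, kept on the stack |
| fold_tree_buffered | O(N) loop steps + O(depth) per path tuple built | O(1) | one O(depth) path list + pending-sibling entries (up to O(N) on very wide trees) | one transient path tuple per combine |
| fold_tree_no_path | O(N) | O(1) | O(depth) frames | none |

Path tuples handed to the combiner are intrinsic to the API. Beyond them, the auxiliary overhead of fold_tree_buffered is the explicit stack plus one mutable path list.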


7. Anti-Patterns & Immediate Fixes

| Anti-Pattern | Symptom | Fix |
|---|---|---|
| Recursive aggregation in library code | RecursionError on deep trees | Replace with fold_tree_buffered |
| Separate folds for related stats | One full traversal per statistic (k folds = k passes over the tree) | Fuse with tuple accumulator |
| Mutable accumulator | Aliasing / nondeterminism | Use immutable tuples |
| String concatenation in combiner | Quadratic time | Count lengths, join once at end |
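The sketch below illustrates the fix for the last row. It assumes a hypothetical stand-in for TreeDoc and a hypothetical iter_texts helper (a lazy preorder walk with an explicit stack). The approach is to aggregate lengths rather than strings, and to build the full text once with str.join if it is really needed.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator


# Hypothetical stand-in for M04C01's TreeDoc.
@dataclass(frozen=True)
class Node:
    text: str


@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple[TreeDoc, ...] = ()


def iter_texts(tree: TreeDoc) -> Iterator[str]:
    # Hypothetical helper: lazy preorder walk with an explicit stack.
    stack = [tree]
    while stack:
        current = stack.pop()
        yield current.node.text
        stack.extend(reversed(current.children))  # reversed so the first child is popped first


def total_length(tree: TreeDoc) -> int:
    # Aggregate an int, not a string: O(1) work per node (sum is itself a linear fold).
    return sum(len(t) for t in iter_texts(tree))


tree = TreeDoc(Node("Title"), (TreeDoc(Node("One"), (TreeDoc(Node("1.1")),)), TreeDoc(Node("Two"))))
print(total_length(tree))            # 14
print(" / ".join(iter_texts(tree)))  # Title / One / 1.1 / Two
```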

8. Pre-Core Quiz

  1. Safe recursion over trees? → Iterative fold with explicit stack
  2. Compute 5 independent stats? → One fused tuple fold
  3. Running totals over tree? → scan_tree
  4. Bounded path memory on deep chains? → fold_tree_buffered (or fold_tree_no_path when the path is unused)
  5. Equivalence guarantee? → Hypothesis vs recursive spec + order test

9. Post-Core Exercise

  1. Implement recursive max-depth → replace with fold → add equivalence + order property.
  2. Fuse count + text length + max depth + set of all doc_ids in one fold.
  3. Add scan_tree progress reporting to the RAG pipeline (yield running chunk count).
  4. Find any multi-pass aggregation in your codebase → fuse into one fold → measure speedup.

Next: M04C03 – Memoization & Caching of Pure Functions (lru_cache, Custom Disk-Backed Caches, Observational Purity).

You now own the universal pattern for safe aggregation over any algebraic data type. Everything else in Module 4 is just specialising this fold.