M04C02: Folds & Reductions as Safe Recursion (Monoidal Fusion, Custom Accumulators)¶
Progression Note¶
By the end of Module 4, you will master safe recursion over unpredictable tree-shaped data, monoidal folds as the universal recursion pattern, Result/Option for streaming error handling, validation aggregators, retries, and structured error reporting — all while preserving laziness, equational reasoning, and constant call-stack usage.
Here's a snippet from the progression map:
| Module | Focus | Key Outcomes |
|---|---|---|
| 3 | Lazy Iteration & Generators | Memory-efficient streaming, itertools mastery, short-circuiting, observability |
| 4 | Safe Recursion & Error Handling in Streams | Stack-safe tree recursion, folds, Result/Option, streaming validation/retries/reports |
| 5 | Advanced Type-Driven Design | ADTs, exhaustive pattern matching, total functions, refined types |
Core question:
How do you replace any structural-recursive aggregation with an iterative fold (catamorphism) that is stack-safe, fully lazy when needed, and capable of fusing arbitrary numbers of independent aggregations into a single O(N) traversal?
We now take the TreeDoc hierarchy from M04C01 and ask the most common real-world question:
“How many nodes are in this document tree, what is the total length of all text across all nodes, and what is the maximum depth?”
The naïve recursive solution is beautiful and obvious:
def recursive_stats(tree: TreeDoc) -> tuple[int, int, int]:
count = 1
length = len(tree.node.text)
max_d = 0
for child in tree.children:
c, l, md = recursive_stats(child)
count += c
length += l
max_d = max(max_d, md + 1)
return count, length, max_d
It works perfectly… until the tree is 2000 levels deep and Python's default recursion limit (1000 frames) turns it into a RecursionError.
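A minimal sketch of the failure mode, using hypothetical stand-ins for the M04C01 types (only `.node.text` and `.children` are assumed):

```python
import sys
from dataclasses import dataclass

# Hypothetical minimal stand-ins for the M04C01 Node/TreeDoc types.
@dataclass(frozen=True)
class Node:
    text: str

@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple["TreeDoc", ...] = ()

def recursive_stats(tree):
    count, length, max_d = 1, len(tree.node.text), 0
    for child in tree.children:
        c, l, md = recursive_stats(child)
        count += c
        length += l
        max_d = max(max_d, md + 1)
    return count, length, max_d

# Build a 5000-node chain iteratively — constructing it needs no recursion.
deep = TreeDoc(Node("leaf"))
for _ in range(4999):
    deep = TreeDoc(Node("x"), (deep,))

old_limit = sys.getrecursionlimit()
sys.setrecursionlimit(1000)  # Python's default
try:
    recursive_stats(deep)
    blew_up = False
except RecursionError:
    blew_up = True
finally:
    sys.setrecursionlimit(old_limit)

print(blew_up)  # → True
```

The tree is perfectly legal data; only the traversal strategy is at fault.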
The production solution must:
- be iterative (O(1) call-stack),
- compute all three values in one pass (fusion),
- optionally be streaming (scan_tree yields running totals),
- remain provably equivalent to the recursive spec.
This is exactly what a fold (catamorphism) gives us — “reduce over a tree” with the recursion made explicit and safe.
Audience: Engineers who routinely aggregate statistics over tree/document/graph structures and refuse to ship code that can RecursionError on pathological but legal inputs.
Outcome:
1. You will replace any recursive aggregation with an iterative fold that is formally terminating and stack-safe.
2. You will fuse arbitrary numbers of independent aggregations into a single traversal using immutable tuple accumulators.
3. You will ship streaming reductions (scan_tree) that are truly lazy and short-circuitable.
We formalise exactly what we want from a correct, production-ready fold: termination, stack-safety, perfect preorder, fusion, and (when scanning) bounded work.
Concrete Motivating Example¶
Same deep Markdown-derived tree from M04C01:
Root (title) → 50 chars, depth 0
├── Section 1 → 30 chars, depth 1
│ ├── Subsection 1.1 → 20 chars, depth 2
│ └── Subsection 1.2 → 25 chars, depth 2
└── Section 2 → 35 chars, depth 1
└── Subsection 2.1 → 40 chars, depth 2
└── … (2000 levels deep) → leaf node with 10 chars, depth 2002
Desired aggregates (computed in one pass):
We want all three numbers, plus optionally a running total after each node (for progress bars, early termination, etc.).
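For just the six nodes drawn above (ignoring the 2000-level continuation), the fused answer works out as:

```python
# (text length, depth) pairs for the six nodes shown in the diagram
nodes = [(50, 0), (30, 1), (20, 2), (25, 2), (35, 1), (40, 2)]

count = len(nodes)                       # 6 nodes
total_length = sum(n for n, _ in nodes)  # 50+30+20+25+35+40 = 200
max_depth = max(d for _, d in nodes)     # 2

print(count, total_length, max_depth)  # → 6 200 2
```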
1. Laws & Invariants (machine-checked where possible)¶
All laws assume finite, acyclic TreeDoc inputs (always non-empty; root node exists).
| Law | Formal Statement | Enforcement |
|---|---|---|
| Termination & Stack-Safety | Completes in O(N) steps with O(1) call-stack frames for any finite acyclic tree. | Formal proof via explicit stack + Hypothesis on 5000-node chains + CI recursion-limit guard. |
| Equivalence | fold_tree(t, seed, f) == recursive_fold(t, seed, f) for all t (identical result and preorder application). | Hypothesis test_fold_vs_recursive_equivalence. |
| Fusion | Fused tuple fold equals separate folds: (count, length, max_d) == (fold_count(t), fold_len(t), fold_max_d(t)). | Hypothesis test_fusion_equivalence. |
| Bounded-Work (scan_tree) | Consuming first k partial accumulators visits exactly k nodes. | Instrumented property test_scan_bounded_work. |
| Order Law | Combiner applied in strict preorder (identical sequence to flatten(t) from M04C01). | Property test test_fold_preorder_matches_flatten. |
These laws turn “fold” from a pattern into a verifiable contract.
2. Decision Table – Which Fold Do You Actually Use?¶
| Need | Streaming Partials? | Multiple Values? | Recommended Variant |
|---|---|---|---|
| Single aggregate (tree) | No | No | fold_tree or fold_tree_no_path |
| Multiple aggregates (tree) | No | Yes | fold_tree_buffered with tuple accumulator (fused) |
| Running totals (tree) | Yes | No/Yes | scan_tree (optionally with tuple accumulator) |
| Linear (list/iterator) | No/Yes | No | linear_reduce / linear_accumulate |
Never use recursive aggregation in library code.
Never run multiple separate folds when a single fused tuple fold gives the same result in one pass.
3. Public API Surface (end-of-Module-04 refactor note)¶
Refactor note: tree folds/scans live in funcpipe_rag.tree (src/funcpipe_rag/tree/folds.py).
funcpipe_rag.api.core re-exports the same names as a stable façade for the teaching modules.
from funcpipe_rag.api.core import (
fold_count_length_maxdepth,
fold_tree,
fold_tree_buffered,
fold_tree_no_path,
linear_accumulate,
linear_reduce,
scan_count_length_maxdepth,
scan_tree,
)
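The reference implementations below assume the M04C01 types. For readers running the snippets standalone, here is a hypothetical minimal sketch of those definitions — the real package's fields may differ; only `.node.text` and `.children` are relied on:

```python
from dataclasses import dataclass
from typing import TypeVar

Path = tuple[int, ...]  # child indices from the root down to a node
R = TypeVar("R")        # accumulator type threaded through a fold

@dataclass(frozen=True)
class Node:
    text: str

@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple["TreeDoc", ...] = ()

doc = TreeDoc(Node("root"), (TreeDoc(Node("child")),))
print(doc.node.text, len(doc.children))  # → root 1
```

Callable and Iterator in the signatures below come from `typing` (or `collections.abc`), and `deque` from `collections`.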
4. Reference Implementations¶
4.1 Recursive Specification (Didactic only)¶
def recursive_fold(
tree: TreeDoc,
seed: R,
combiner: Callable[[R, TreeDoc, int, Path], R],
*,
depth: int = 0,
path: Path = (),
) -> R:
acc = combiner(seed, tree, depth, path)
for i, child in enumerate(tree.children):
acc = recursive_fold(child, acc, combiner, depth=depth + 1, path=path + (i,))
return acc
4.2 Simple Explicit-Stack Fold (Readable reference)¶
def fold_tree(
tree: TreeDoc,
seed: R,
combiner: Callable[[R, TreeDoc, int, Path], R],
) -> R:
acc = seed
stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
while stack:
node, depth, path, child_idx = stack.pop()
if child_idx == 0:
acc = combiner(acc, node, depth, path)
if child_idx < len(node.children):
stack.append((node, depth, path, child_idx + 1))
child = node.children[child_idx]
stack.append((child, depth + 1, path + (child_idx,), 0))
return acc
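A quick stack-safety check, using hypothetical minimal stand-ins for the M04C01 types — the fold handles a chain far past the default recursion limit:

```python
from collections import deque
from dataclasses import dataclass

# Hypothetical minimal stand-ins for the M04C01 types.
@dataclass(frozen=True)
class Node:
    text: str

@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple["TreeDoc", ...] = ()

def fold_tree(tree, seed, combiner):
    # Explicit-stack preorder fold: O(1) call-stack frames, O(depth) heap.
    acc = seed
    stack = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, path + (child_idx,), 0))
    return acc

# A 5000-node chain — far beyond the default recursion limit of 1000.
deep = TreeDoc(Node("leaf"))
for _ in range(4999):
    deep = TreeDoc(Node("x"), (deep,))

count = fold_tree(deep, 0, lambda a, n, d, p: a + 1)
print(count)  # → 5000
```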
4.3 Production Winner – Buffered-Path Fold (no path tuples on the traversal stack)¶
def fold_tree_buffered(
tree: TreeDoc,
seed: R,
combiner: Callable[[R, TreeDoc, int, Path], R],
) -> R:
"""
Same semantics as fold_tree but maintains the path using a single mutable list
(no tuples on the traversal stack). One tuple per node is still created when
calling combiner (intrinsic to passing the path).
"""
acc = seed
stack: deque[tuple[TreeDoc, int, int | None]] = deque([(tree, 0, None)])
path: list[int] = []
last_depth = 0
while stack:
node, depth, sib_idx = stack.pop()
# Maintain mutable path prefix (identical logic to iter_flatten_buffered)
if depth < last_depth:
del path[depth:]
if sib_idx is not None:
if depth > len(path):
path.append(sib_idx)
else:
path[depth-1] = sib_idx
last_depth = depth
acc = combiner(acc, node, depth, tuple(path[:depth]))
for i in range(len(node.children)-1, -1, -1):
stack.append((node.children[i], depth + 1, i))
return acc
4.4 Optimised Fold Without Path (when path unused)¶
def fold_tree_no_path(
tree: TreeDoc,
seed: R,
combiner: Callable[[R, TreeDoc, int], R],
) -> R:
"""Third parameter is depth."""
acc = seed
stack: deque[tuple[TreeDoc, int, int]] = deque([(tree, 0, 0)])
while stack:
node, depth, child_idx = stack.pop()
if child_idx == 0:
acc = combiner(acc, node, depth)
if child_idx < len(node.children):
stack.append((node, depth, child_idx + 1))
stack.append((node.children[child_idx], depth + 1, 0))
return acc
4.5 Streaming Scan (Running Totals – Truly Lazy)¶
def scan_tree(
tree: TreeDoc,
seed: R,
combiner: Callable[[R, TreeDoc, int, Path], R],
) -> Iterator[R]:
"""Yield running accumulator after each node in preorder – O(k) work for first k yields."""
acc = seed
stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
while stack:
node, depth, path, child_idx = stack.pop()
if child_idx == 0:
acc = combiner(acc, node, depth, path)
yield acc
if child_idx < len(node.children):
stack.append((node, depth, path, child_idx + 1))
child = node.children[child_idx]
stack.append((child, depth + 1, path + (child_idx,), 0))
Note: linear_accumulate (via itertools.accumulate) yields the initial seed as the first value; scan_tree yields only post-node accumulators (no initial seed yield).
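The bounded-work law can be checked directly. Assuming minimal stand-in types, an instrumented combiner shows that pulling the first k running totals visits exactly k nodes:

```python
from collections import deque
from dataclasses import dataclass
from itertools import islice

# Hypothetical minimal stand-ins for the M04C01 types.
@dataclass(frozen=True)
class Node:
    text: str

@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple["TreeDoc", ...] = ()

def scan_tree(tree, seed, combiner):
    # Lazy preorder scan: yields the running accumulator after each node.
    acc = seed
    stack = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
            yield acc
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, path + (child_idx,), 0))

visits = 0

def counting_step(acc, node, depth, path):
    global visits
    visits += 1                     # instrument: count combiner calls
    return acc + len(node.node.text)

# 100-node chain of "abc" nodes; consume only the first 3 running totals.
deep = TreeDoc(Node("abc"))
for _ in range(99):
    deep = TreeDoc(Node("abc"), (deep,))

firsts = list(islice(scan_tree(deep, 0, counting_step), 3))
print(firsts, visits)  # → [3, 6, 9] 3
```

Only three nodes were ever touched — the remaining 97 cost nothing, which is what makes scan_tree safe to wire into progress bars and early-exit checks.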
4.6 Fused Multi-Value Example (Count + Length + Max Depth)¶
def fold_count_length_maxdepth(tree: TreeDoc) -> tuple[int, int, int]:
    def step(acc: tuple[int, int, int], node: TreeDoc, depth: int, path: Path) -> tuple[int, int, int]:
        count, length, max_d = acc
        return (
            count + 1,
            length + len(node.node.text),
            max(max_d, depth),
        )
    return fold_tree_buffered(tree, (0, 0, 0), step)
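Applied to the six-node motivating tree (hypothetical stand-in types; the fold and step inlined so the snippet is self-contained), the fused fold returns all three answers in one pass:

```python
from collections import deque
from dataclasses import dataclass

# Hypothetical minimal stand-ins for the M04C01 types.
@dataclass(frozen=True)
class Node:
    text: str

@dataclass(frozen=True)
class TreeDoc:
    node: Node
    children: tuple["TreeDoc", ...] = ()

def fold_tree(tree, seed, combiner):
    # Explicit-stack preorder fold (simple variant from 4.2).
    acc = seed
    stack = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, path + (child_idx,), 0))
    return acc

def step(acc, node, depth, path):
    count, length, max_d = acc
    return (count + 1, length + len(node.node.text), max(max_d, depth))

# The six visible nodes of the motivating example (text lengths 50/30/20/25/35/40).
doc = TreeDoc(Node("x" * 50), (
    TreeDoc(Node("x" * 30), (
        TreeDoc(Node("x" * 20)),
        TreeDoc(Node("x" * 25)),
    )),
    TreeDoc(Node("x" * 35), (
        TreeDoc(Node("x" * 40)),
    )),
))

print(fold_tree(doc, (0, 0, 0), step))  # → (6, 200, 2)
```

One traversal, three aggregates — the tuple accumulator is what fuses them.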
5. Property-Based Proofs (tests/test_tree_folds.py)¶
@given(tree=tree_strategy())
def test_fold_vs_recursive_equivalence(tree):
rec = recursive_fold(tree, (0, 0, 0), step_count_len_maxd)
buf = fold_tree_buffered(tree, (0, 0, 0), step_count_len_maxd)
assert rec == buf
@given(tree=tree_strategy())
def test_fusion_equivalence(tree):
fused = fold_count_length_maxdepth(tree)
count = fold_tree_no_path(tree, 0, lambda a, n, d: a + 1)
length = fold_tree_no_path(tree, 0, lambda a, n, d: a + len(n.node.text))
max_d = fold_tree_no_path(tree, 0, lambda a, n, d: max(a, d))
assert fused == (count, length, max_d)
@given(tree=tree_strategy())
def test_fold_preorder_matches_flatten(tree):
from funcpipe_rag.api.core import flatten
order_via_fold: list[Path] = []
fold_tree(tree, None, lambda _, n, d, p: order_via_fold.append(p))
order_via_flatten = [c.metadata["path"] for c in flatten(tree)]
assert order_via_fold == order_via_flatten
@given(tree=tree_strategy())
def test_fold_buffered_order_matches_simple(tree):
order_simple: list[Path] = []
order_buf: list[Path] = []
fold_tree(tree, None, lambda _, n, d, p: order_simple.append(p))
fold_tree_buffered(tree, None, lambda _, n, d, p: order_buf.append(p))
assert order_simple == order_buf
6. Big-O & Allocation Guarantees (peak auxiliary memory)¶
| Variant | Time | Call-stack | Peak auxiliary heap | Total allocations |
|---|---|---|---|---|
| fold_tree / scan_tree | O(N) | O(1) | O(depth) | O(N×depth) paths |
| fold_tree_buffered | O(N) | O(1) | O(depth) | O(N) paths (only on combine) |
| fold_tree_no_path | O(N) | O(1) | O(depth) | Zero paths |
Result metadata (paths when used) is intrinsic; auxiliary overhead is only the explicit stack + one mutable path list.
7. Anti-Patterns & Immediate Fixes¶
| Anti-Pattern | Symptom | Fix |
|---|---|---|
| Recursive aggregation in library code | RecursionError on deep trees | Replace with fold_tree_buffered |
| Separate folds for related stats | 3–10× slower on large trees | Fuse with tuple accumulator |
| Mutable accumulator | Aliasing / nondeterminism | Use immutable tuples |
| String concatenation in combiner | Quadratic time | Count lengths, join once at end |
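A sketch of the last row's fix — collect fragments during the traversal and join once at the end, rather than rebuilding a growing string at every node (the mutable list is local scratch; the fold's own accumulator can carry it or a tuple of fragments):

```python
# Quadratic anti-pattern: acc = acc + node_text copies the whole string per node.
# Linear fix: append fragments, then a single join over the total length.
fragments: list[str] = []
for node_text in ("Root", "Section 1", "Subsection 1.1"):
    fragments.append(node_text)   # O(1) amortised per node
all_text = "\n".join(fragments)   # one O(total-length) pass
print(all_text)
```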
8. Pre-Core Quiz¶
- Safe recursion over trees? → Iterative fold with explicit stack
- Compute 5 independent stats? → One fused tuple fold
- Running totals over a tree? → scan_tree
- No path tuples on the traversal stack for deep chains? → fold_tree_buffered (fold_tree_no_path when the path is unused)
- Equivalence guarantee? → Hypothesis vs recursive spec + order test
9. Post-Core Exercise¶
- Implement recursive max-depth → replace with fold → add equivalence + order property.
- Fuse count + text length + max depth + set of all doc_ids in one fold.
- Add scan_tree progress reporting to the RAG pipeline (yield running chunk count).
- Find any multi-pass aggregation in your codebase → fuse into one fold → measure speedup.
Next: M04C03 – Memoization & Caching of Pure Functions (lru_cache, Custom Disk-Backed Caches, Observational Purity).
You now own the universal pattern for safe aggregation over any algebraic data type. Everything else in Module 4 is just specialising this fold.