M05C01: Product & Sum Types via dataclasses, Enum, Union

Progression Note

By the end of Module 5, you will model every domain concept as immutable algebraic data types (products and tagged sums), eliminating whole classes of runtime errors through exhaustive pattern matching, mypy-checked totality, and pure serialization contracts.

Module | Focus | Key Outcomes
4 | Safe Recursion & Error Handling | Stack-safe tree recursion, folds, Result/Option, streaming validation/retries
5 | Advanced Type-Driven Design | ADTs, exhaustive pattern matching, total functions, refined types
6 | Monadic Flows as Composable Pipelines | bind/and_then, Reader/State-like patterns, error-typed flows

Core question
How do you replace ad-hoc dicts, mutable classes, and fragile inheritance with pure product and tagged sum types — guaranteeing exhaustive handling, immutability, and stable serialization in every pipeline stage?

We now take the Chunk and error types from Module 04 and ask the question every seasoned engineer eventually faces:

“Why does my RAG pipeline still have mysterious None crashes, silent data loss, and ‘forgot-to-handle-this-case’ bugs even after adding Result and retries?”

The answer is almost always: you’re still using primitive dicts, mutable state, or untyped “variant” classes instead of real algebraic data types.

The naïve (and extremely common) solution:

class ChunkState:
    def __init__(self, success=None, embedding=None, error=None):
        self.success = success
        self.embedding = embedding
        self.error = error

# somewhere deep in the code...
if state.success:
    index(state.embedding)   # oops, someone forgot to check error is not None

Classic boolean blindness + null soup.

The production solution: model every domain concept as either a product type (AND) or a tagged sum type (OR — exactly one variant).

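from dataclasses import dataclass
from typing import Literal

# JSON is the recursive alias defined in fp/core.py (section 3).

@dataclass(frozen=True, slots=True, kw_only=True)
class Success:
    kind: Literal["success"] = "success"
    embedding: tuple[float, ...]
    metadata: tuple[tuple[str, JSON], ...]

@dataclass(frozen=True, slots=True, kw_only=True)
class Failure:
    kind: Literal["failure"] = "failure"
    code: str
    msg: str
    attempt: int

# Define the alias after both classes so the names exist at runtime.
ChunkState = Success | Failure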

Now mypy and assert_never force every handler to cover both cases, and they keep doing so as the code evolves.

Audience: Engineers tired of “it works on my machine” bugs caused by incomplete state handling.

Outcome

  1. Every dict/class soup replaced with proper product and tagged sum types.
  2. Exhaustiveness proved via mypy + assert_never.
  3. Immutable, serialisable domain models that survive refactors without silent regressions.

Tiny Non-Domain Example – Shape ADT

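from dataclasses import dataclass
from typing import Literal, assert_never  # assert_never needs Python 3.11+

# kw_only=True is required: without it, a defaulted `kind` field followed by
# a non-default field raises TypeError when the class is created.
@dataclass(frozen=True, slots=True, kw_only=True)
class Circle:
    kind: Literal["circle"] = "circle"
    radius: float

@dataclass(frozen=True, slots=True, kw_only=True)
class Rectangle:
    kind: Literal["rectangle"] = "rectangle"
    width: float
    height: float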

Shape = Circle | Rectangle

def area(s: Shape) -> float:
    match s:
        case Circle(radius=r):
            return 3.14159265359 * r * r
        case Rectangle(width=w, height=h):
            return w * h
    assert_never(s)  # mypy errors if you add Triangle and forget to handle it

Adding a new variant makes mypy reject every handler that ends in assert_never until you update it, so there are no silent defaults.

Why ADTs? (Three bullets every engineer should internalise)

  • Exhaustiveness: Adding a variant breaks every handler until you update it — no silent missing cases.
  • Immutability + Value semantics: Frozen + structural eq/hash → safe in sets/dicts, pure functions, cache keys (provided metadata values are themselves hashable and nested structures are not mutated after construction).
  • Stable serialization: Explicit kind tag + deterministic field order (sorted tuple for metadata) → JSON round-trip without surprises.

1. Laws & Invariants (machine-checked)

Law | Formal Statement | Enforcement
Exhaustiveness | Every match over a sum type must handle all variants (proved by assert_never) | mypy --strict + tests
Immutability | Dataclass fields cannot be reassigned. Nested containers in metadata are shared references and remain mutable; do not mutate them after chunk creation, or value semantics break | test_chunk_immutability (top-level only)
Structural Equality | x == y iff all fields are equal (independent of dict key order, via sorted tuples) | test_chunk_metadata_order_independent
JSON Round-Trip | from_dict(to_dict(x)) == x for all instances | test_chunk_roundtrip, test_chunk_state_roundtrip

2. Decision Table – Which ADT Construction Do You Actually Use?

Data Shape | Has Payload? | Needs Tags? | Recommended Construction
Simple record (AND of fields) | Yes | No | @dataclass(frozen=True, slots=True, kw_only=True)
Simple enumeration (no data) | No | Yes | class Status(Enum): PENDING = "pending" ...
Tagged variants with data (OR + payload) | Yes | Yes | Union of tagged dataclasses with kind: Literal["..."]
Deeply nested tree | Yes | Yes | Recursive Union of tagged dataclasses

Never use mutable classes or bare dicts for domain data.

3. Public API (fp/core.py – mypy --strict clean)

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Literal, Mapping, Sequence, Tuple, TypeAlias

JSONPrimitive: TypeAlias = str | int | float | bool | None
JSON: TypeAlias = JSONPrimitive | Mapping[str, "JSON"] | Sequence["JSON"]
Path = Tuple[int, ...]

def _freeze_metadata(m: Mapping[str, JSON]) -> Tuple[Tuple[str, JSON], ...]:
    # Sort by key only: values may be heterogeneous JSON and need not be comparable.
    return tuple(sorted(m.items(), key=lambda kv: kv[0]))

@dataclass(frozen=True, slots=True, kw_only=True)
class Chunk:
    text: str
    path: Path
    metadata: Tuple[Tuple[str, JSON], ...]   # top-level frozen, order-independent
    version: Literal[1] = 1

def make_chunk(
    *,
    text: str,
    path: Path,
    metadata: Mapping[str, JSON],
) -> Chunk:
    return Chunk(text=text, path=path, metadata=_freeze_metadata(metadata))

def chunk_to_dict(c: Chunk) -> dict[str, JSON]:
    return {
        "version": c.version,
        "text": c.text,
        "path": list(c.path),
        "metadata": dict(c.metadata),
    }

def chunk_from_dict(d: Mapping[str, JSON]) -> Chunk:
    if d.get("version") != 1:
        raise ValueError("unsupported version")
    return make_chunk(
        text=str(d["text"]),
        path=tuple(int(i) for i in d["path"]),
        metadata=dict(d["metadata"]),
    )

# Success / Failure sum type for embedding outcomes
@dataclass(frozen=True, slots=True, kw_only=True)
class Success:
    kind: Literal["success"] = "success"
    embedding: Tuple[float, ...]
    metadata: Tuple[Tuple[str, JSON], ...]

@dataclass(frozen=True, slots=True, kw_only=True)
class Failure:
    kind: Literal["failure"] = "failure"
    code: str
    msg: str
    attempt: int

ChunkState = Success | Failure

def success(
    *,
    embedding: Iterable[float],
    metadata: Mapping[str, JSON],
) -> Success:
    return Success(
        embedding=tuple(float(x) for x in embedding),
        metadata=_freeze_metadata(metadata),
    )

def failure(*, code: str, msg: str, attempt: int) -> Failure:
    return Failure(code=code, msg=msg, attempt=attempt)

4. Reference Implementations (continued)

4.1 Tree Sum Type (recursive tagged union)

from typing import assert_never

@dataclass(frozen=True, slots=True, kw_only=True)
class TextNode:
    kind: Literal["text"] = "text"
    content: str

@dataclass(frozen=True, slots=True, kw_only=True)
class SectionNode:
    kind: Literal["section"] = "section"
    title: str
    children: Tuple["Node", ...]

@dataclass(frozen=True, slots=True, kw_only=True)
class ListNode:
    kind: Literal["list"] = "list"
    items: Tuple["Node", ...]

Node = TextNode | SectionNode | ListNode

4.2 Exhaustive Pattern Matching

def node_depth(n: Node) -> int:
    match n:
        case TextNode():
            return 0
        case SectionNode(children=children):
            return 1 + max((node_depth(c) for c in children), default=0)
        case ListNode(items=items):
            return 1 + max((node_depth(i) for i in items), default=0)
    assert_never(n)  # mypy errors if you add a new variant

4.3 JSON Round-Trip for Tagged Sum

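def chunk_state_to_dict(state: ChunkState) -> dict[str, JSON]:
    # Match on the variant and finish with assert_never, so adding a new
    # variant fails type checking here instead of falling into an else branch.
    match state:
        case Success(embedding=embedding, metadata=metadata):
            return {
                "kind": "success",
                "version": 1,
                "embedding": list(embedding),
                "metadata": dict(metadata),
            }
        case Failure(code=code, msg=msg, attempt=attempt):
            return {
                "kind": "failure",
                "version": 1,
                "code": code,
                "msg": msg,
                "attempt": attempt,
            }
    assert_never(state)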

def chunk_state_from_dict(d: Mapping[str, JSON]) -> ChunkState:
    if d.get("version") != 1:
        raise ValueError("unsupported version")
    kind = d["kind"]
    if kind == "success":
        return success(
            embedding=d["embedding"],      # type: ignore[arg-type]
            metadata=dict(d["metadata"]),
        )
    if kind == "failure":
        return failure(
            code=d["code"],                # type: ignore[arg-type]
            msg=d["msg"],                  # type: ignore[arg-type]
            attempt=d["attempt"],          # type: ignore[arg-type]
        )
    raise ValueError(f"unknown kind {kind!r}")

4.4 Big-O & Allocation Guarantees

Construction | Time | Heap | Notes
dataclass creation | O(1) | O(#fields) | slots=True → no per-instance __dict__
Tagged union match | O(1) | O(1) | Exhaustive via assert_never
JSON round-trip | O(N) | O(N) | Stable via sorted tuples

4.5 Anti-Patterns & Immediate Fixes

Anti-Pattern | Symptom | Fix
Mutable domain classes | Accidental mutation | frozen=True, slots=True, kw_only=True
Untagged Union or dict variants | Silent missing cases | kind: Literal + assert_never
Dict for metadata | Unstable equality/serialization | Sorted tuple of tuples
Inheritance for variants | Fragile, hard to exhaust | Tagged union of dataclasses

5. Property-Based Proofs (tests/test_module_05_c01.py)

import dataclasses
from typing import assert_never

import pytest
from hypothesis import given
import hypothesis.strategies as st

# ... imports of your ADT types ...

@given(text=st.text(), path=st.lists(st.integers(), max_size=10).map(tuple),
       meta=st.dictionaries(st.text(), st.integers() | st.lists(st.integers())))
def test_chunk_immutability(text, path, meta):
    chunk = make_chunk(text=text, path=path, metadata=meta)
    with pytest.raises(dataclasses.FrozenInstanceError):
        chunk.text = "mutated"

@given(meta=st.dictionaries(st.text(), st.integers()))
def test_chunk_metadata_order_independent(meta):
    c1 = make_chunk(text="t", path=(), metadata=meta)
    c2 = make_chunk(text="t", path=(), metadata=dict(reversed(list(meta.items()))))
    assert c1 == c2
    assert hash(c1) == hash(c2)

@given(chunk=st.builds(make_chunk,
                      text=st.text(),
                      path=st.lists(st.integers(), max_size=10).map(tuple),
                      metadata=st.dictionaries(st.text(), st.integers() | st.none())))
def test_chunk_roundtrip(chunk):
    j = chunk_to_dict(chunk)
    reloaded = chunk_from_dict(j)
    assert chunk == reloaded

@given(succ=st.builds(success,
                     embedding=st.lists(st.floats(allow_nan=False, allow_infinity=False), max_size=10),
                     metadata=st.dictionaries(st.text(), st.integers())),
       fail=st.builds(failure, code=st.text(min_size=1), msg=st.text(), attempt=st.integers(min_value=1)))
def test_chunk_state_roundtrip(succ, fail):
    for state in (succ, fail):
        j = chunk_state_to_dict(state)
        reloaded = chunk_state_from_dict(j)
        assert state == reloaded

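@given(node=st.recursive(
    st.builds(TextNode, content=st.text()),
    # children/items are tuples of nodes, so wrap the recursive strategy in a list
    lambda children: st.one_of(
        st.builds(SectionNode, title=st.text(),
                  children=st.lists(children, max_size=3).map(tuple)),
        st.builds(ListNode, items=st.lists(children, max_size=3).map(tuple)),
    ),
    max_leaves=20,
))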
def test_node_exhaustive_match(node):
    def dummy(n: Node) -> int:
        match n:
            case TextNode():   return 0
            case SectionNode(): return 1
            case ListNode():    return 2
        assert_never(n)
    dummy(node)

6. Pre-Core Quiz

  1. Product type → AND of fields (dataclass)
  2. Tagged sum type → OR of variants with payloads
  3. frozen=True, slots=True → Immutability + efficiency
  4. assert_never → Exhaustiveness proof
  5. Why sorted tuple for metadata? → Stable equality & hash (independent of dict key order)

7. Post-Core Exercise

  1. Model your current chunk embedding state as a tagged sum type → add assert_never to every handler.
  2. Refactor one dict-based structure to frozen dataclass → test JSON round-trip + order-independent equality.
  3. Add a new variant to an existing sum type → verify mypy errors in all match sites.
  4. Replace a mutable class in your codebase with a frozen dataclass → measure memory improvement (slots!).

Next: M05C02 – Modelling Domain States (Pending/Running/Done, Success/Failure) as ADTs.

You now model every piece of domain data as pure, immutable, exhaustively-handled values — eliminating vast classes of bugs before they happen. The rest of Module 5 is about composing these ADTs into powerful abstractions (functors, applicatives, monoids).