
core.model_io

Model bundle persistence: save, load, and metadata for trained model artifacts.

ModelMetadata Fields

| Field | Type | Description |
|---|---|---|
| schema_version | str | Feature schema version ("v1", "v2", or "v3") |
| schema_hash | str | Deterministic hash of the feature schema |
| label_set | list[str] | Sorted list of core labels used in training |
| train_date_from | str | First date of the training range (ISO-8601) |
| train_date_to | str | Last date (inclusive) of the training range (ISO-8601) |
| params | dict | Model hyperparameters |
| git_commit | str | Git commit SHA at training time |
| dataset_hash | str | SHA-256 hash of the training dataset, for reproducibility |
| reject_threshold | float or None | Reject threshold used during evaluation. Deprecated and advisory only; the canonical runtime threshold lives in InferencePolicy. |
| data_provenance | str | Origin of the training data: "real" (default), "synthetic", or "mixed" |
| created_at | str | ISO-8601 UTC timestamp of bundle creation |
| unknown_category_freq_threshold | int or None | Minimum category frequency used during training; rarer categories become `__unknown__` |
| unknown_category_mask_rate | float or None | Fraction of known categories randomly masked to `__unknown__` during training |

Schema Version Support

load_model_bundle validates bundles against a registry of known schema versions. v1, v2, and v3 bundles are all accepted; the bundle's schema_version field selects the expected hash. An unknown schema version, or a hash mismatch (e.g. a v1 bundle whose hash has been tampered to v2's value), raises ValueError.

build_metadata accepts a schema_version parameter (default LATEST_FEATURE_SCHEMA_VERSION) and fills in the matching version string and hash automatically; an unrecognised version raises ValueError.
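The version-to-hash lookup shared by build_metadata (build time) and load_model_bundle (load time) can be sketched with the standard library alone. This is an illustration, not the module's code: SCHEMA_HASHES, LATEST_VERSION, and the placeholder hash strings are hypothetical stand-ins for the private _SCHEMA_HASHES registry and LATEST_FEATURE_SCHEMA_VERSION, and the assumption that the latest version is "v3" is ours.

```python
from __future__ import annotations

# Hypothetical stand-in for the private _SCHEMA_HASHES registry; the real
# hashes are derived from the feature schemas and are not reproduced here.
SCHEMA_HASHES: dict[str, str] = {
    "v1": "hash-of-v1-schema",
    "v2": "hash-of-v2-schema",
    "v3": "hash-of-v3-schema",
}
LATEST_VERSION = "v3"  # assumption: stands in for LATEST_FEATURE_SCHEMA_VERSION


def resolve_schema(schema_version: str = LATEST_VERSION) -> tuple[str, str]:
    """Build time: return (version, hash), rejecting unknown versions."""
    schema_hash = SCHEMA_HASHES.get(schema_version)
    if schema_hash is None:
        raise ValueError(f"Unknown schema version: {schema_version!r}")
    return schema_version, schema_hash


def check_bundle_schema(schema_version: str, schema_hash: str) -> None:
    """Load time: the declared version selects which hash to expect."""
    expected = SCHEMA_HASHES.get(schema_version)
    if expected is None:
        raise ValueError(f"Unknown schema version in bundle: {schema_version!r}")
    if schema_hash != expected:
        raise ValueError(
            f"Schema hash mismatch: bundle has {schema_hash!r}, "
            f"expected {expected!r} for schema {schema_version!r}"
        )


# A v1 bundle carrying v1's hash passes; one whose hash was swapped to v2's fails.
check_bundle_schema("v1", SCHEMA_HASHES["v1"])
try:
    check_bundle_schema("v1", SCHEMA_HASHES["v2"])
except ValueError as exc:
    print(exc)
```

Keying the registry by version, rather than storing a single "current" hash, is what lets older bundles keep loading after a newer schema ships.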

taskclf.core.model_io

ModelMetadata

Bases: BaseModel

Immutable record stored alongside a trained model as metadata.json.

Captures the feature schema version/hash, label vocabulary, training date range, hyperparameters, and the git commit at training time so that inference can verify compatibility before predicting.

Source code in src/taskclf/core/model_io.py
class ModelMetadata(BaseModel, frozen=True):
    """Immutable record stored alongside a trained model as ``metadata.json``.

    Captures the feature schema version/hash, label vocabulary, training
    date range, hyperparameters, and the git commit at training time so
    that inference can verify compatibility before predicting.
    """

    schema_version: str
    schema_hash: str
    label_set: list[str]
    train_date_from: str
    train_date_to: str
    params: dict[str, Any]
    git_commit: str
    dataset_hash: str
    reject_threshold: float | None = None
    """.. deprecated::
        Advisory only.  The canonical runtime reject threshold now
        lives in :class:`~taskclf.core.inference_policy.InferencePolicy`.
    """
    data_provenance: Literal["real", "synthetic", "mixed"] = "real"
    created_at: str = Field(default_factory=lambda: datetime.now(UTC).isoformat())
    unknown_category_freq_threshold: int | None = None
    unknown_category_mask_rate: float | None = None

reject_threshold = None (class attribute, instance attribute)

Deprecated: advisory only. The canonical runtime reject threshold now lives in taskclf.core.inference_policy.InferencePolicy.
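To show what an immutable record stored as metadata.json means in practice, here is a standard-library mirror built on a frozen dataclass. It is a sketch, not the real class: ModelMetadata is a Pydantic model (which also validates types and the data_provenance literal), whereas ModelMetadataSketch, to_json, from_json, and every field value below are hypothetical placeholders.

```python
from __future__ import annotations

import dataclasses
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any


@dataclass(frozen=True)
class ModelMetadataSketch:
    # Required provenance fields, mirroring ModelMetadata.
    schema_version: str
    schema_hash: str
    label_set: list[str]
    train_date_from: str
    train_date_to: str
    params: dict[str, Any]
    git_commit: str
    dataset_hash: str
    # Optional fields with the same defaults as ModelMetadata.
    reject_threshold: float | None = None
    data_provenance: str = "real"
    created_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    unknown_category_freq_threshold: int | None = None
    unknown_category_mask_rate: float | None = None


def to_json(meta: ModelMetadataSketch) -> str:
    """Serialize as a flat, indented JSON object, like metadata.json."""
    return json.dumps(dataclasses.asdict(meta), indent=2)


def from_json(text: str) -> ModelMetadataSketch:
    """Rebuild the record; unlike Pydantic, no type validation happens here."""
    return ModelMetadataSketch(**json.loads(text))


meta = ModelMetadataSketch(
    schema_version="v3",
    schema_hash="placeholder-hash",
    label_set=["label_a", "label_b"],
    train_date_from="2026-01-01",
    train_date_to="2026-01-31",
    params={"num_leaves": 31},
    git_commit="placeholder-sha",
    dataset_hash="placeholder-dataset-hash",
)
print(from_json(to_json(meta)) == meta)  # lossless round trip

try:
    meta.git_commit = "other"  # reassignment is blocked on a frozen record
except dataclasses.FrozenInstanceError as exc:
    print("frozen:", exc)
```

As with a frozen Pydantic model, freezing blocks attribute reassignment but does not deep-freeze mutable values such as the label_set list.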

generate_run_id()

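Produce a run directory name of the form YYYY-MM-DD_HHMMSS_run-XXXX from the current UTC time and a random four-digit suffix. The name is not guaranteed to be unique: calls made within the same second draw from only 10,000 suffixes, and save_model_bundle raises FileExistsError if the resulting directory already exists.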

Returns:

| Type | Description |
|---|---|
| str | A string like 2026-02-19_013000_run-0042. |

Source code in src/taskclf/core/model_io.py
def generate_run_id() -> str:
    """Produce a unique run directory name: ``YYYY-MM-DD_HHMMSS_run-XXXX``.

    Returns:
        A string like ``2026-02-19_013000_run-0042``.
    """
    now = datetime.now(UTC)
    suffix = f"{random.randint(0, 9999):04d}"
    return f"{now.strftime('%Y-%m-%d_%H%M%S')}_run-{suffix}"
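The run-ID format can be reproduced and parsed back with the standard library. In this sketch, make_run_id, parse_run_id, and RUN_ID_RE are hypothetical helpers, not part of taskclf.core.model_io: make_run_id takes the clock value and random generator as arguments so its output is deterministic under test.

```python
import random
import re
from datetime import datetime, timezone

# Illustrative pattern for names like 2026-02-19_013000_run-0042.
RUN_ID_RE = re.compile(r"^(\d{4}-\d{2}-\d{2})_(\d{6})_run-(\d{4})$")


def make_run_id(now: datetime, rng: random.Random) -> str:
    """Same format as generate_run_id, with an injectable clock and RNG."""
    suffix = f"{rng.randint(0, 9999):04d}"
    return f"{now.strftime('%Y-%m-%d_%H%M%S')}_run-{suffix}"


def parse_run_id(run_id: str) -> tuple[datetime, int]:
    """Recover the UTC timestamp and the numeric suffix from a run ID."""
    match = RUN_ID_RE.match(run_id)
    if match is None:
        raise ValueError(f"Not a run ID: {run_id!r}")
    day, hms, suffix = match.groups()
    ts = datetime.strptime(f"{day}_{hms}", "%Y-%m-%d_%H%M%S").replace(
        tzinfo=timezone.utc
    )
    return ts, int(suffix)


now = datetime(2026, 2, 19, 1, 30, 0, tzinfo=timezone.utc)
run_id = make_run_id(now, random.Random(0))
print(run_id, parse_run_id(run_id))
```

Because every field is zero-padded, sorting run directory names lexicographically also sorts them chronologically (to one-second resolution).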

save_model_bundle(model, metadata, metrics, confusion_df, base_dir, cat_encoders=None)

Persist a complete model bundle into base_dir/<run_id>/.

Writes model.txt, metadata.json, metrics.json, and confusion_matrix.csv per the Model Bundle Contract, plus an optional categorical_encoders.json that maps each categorical column to its sorted vocabulary list.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Booster | Trained LightGBM booster. | required |
| metadata | ModelMetadata | Provenance record (schema hash, label set, params, etc.). | required |
| metrics | dict | Evaluation dict (as returned by taskclf.core.metrics.compute_metrics). | required |
| confusion_df | DataFrame | Labelled confusion matrix for CSV export. | required |
| base_dir | Path | Parent directory (e.g. Path("models")). A new `<run_id>/` subdirectory is created inside it. | required |
| cat_encoders | dict or None | Optional dict mapping categorical column names to fitted LabelEncoder instances, persisted as JSON vocabulary lists so inference can reconstruct them. No encoder file is written when this is None or empty. | None |

Returns:

| Type | Description |
|---|---|
| Path | Path to the newly created run directory. |

Raises:

| Type | Description |
|---|---|
| FileExistsError | If the generated run directory already exists. |

Source code in src/taskclf/core/model_io.py
def save_model_bundle(
    model: lgb.Booster,
    metadata: ModelMetadata,
    metrics: dict,
    confusion_df: pd.DataFrame,
    base_dir: Path,
    cat_encoders: dict | None = None,
) -> Path:
    """Persist a complete model bundle into ``base_dir/<run_id>/``.

    Writes the core files per the Model Bundle Contract plus an optional
    ``categorical_encoders.json`` mapping each categorical column to its
    sorted vocabulary list.

    Args:
        model: Trained LightGBM booster.
        metadata: Provenance record (schema hash, label set, params, etc.).
        metrics: Evaluation dict (as returned by
            :func:`~taskclf.core.metrics.compute_metrics`).
        confusion_df: Labelled confusion matrix for CSV export.
        base_dir: Parent directory (e.g. ``Path("models")``).
            A new ``<run_id>/`` subdirectory is created inside it.
        cat_encoders: Optional dict mapping categorical column names to
            fitted ``LabelEncoder`` instances.  Persisted as JSON
            vocabulary lists so inference can reconstruct them.

    Returns:
        Path to the newly created run directory.

    Raises:
        FileExistsError: If the generated run directory already exists.
    """
    run_id = generate_run_id()
    run_dir = base_dir / run_id
    if run_dir.exists():
        raise FileExistsError(f"Run directory already exists: {run_dir}")
    run_dir.mkdir(parents=True)

    model.save_model(str(run_dir / "model.txt"))

    (run_dir / "metadata.json").write_text(json.dumps(metadata.model_dump(), indent=2))

    (run_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))

    confusion_df.to_csv(run_dir / "confusion_matrix.csv")

    if cat_encoders:
        vocab = {col: list(le.classes_) for col, le in cat_encoders.items()}
        (run_dir / "categorical_encoders.json").write_text(json.dumps(vocab, indent=2))

    return run_dir
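The on-disk layout written by save_model_bundle can be illustrated without LightGBM or pandas. In this sketch, write_bundle_layout is a hypothetical helper that takes already-serialized inputs (model text, plain dicts, and CSV rows standing in for the Booster, metadata.model_dump(), and the confusion DataFrame) and writes the same file names; run-ID generation is left out, and all sample values are placeholders.

```python
from __future__ import annotations

import csv
import json
import tempfile
from pathlib import Path
from typing import Any


def write_bundle_layout(
    run_dir: Path,
    model_text: str,
    metadata: dict[str, Any],
    metrics: dict[str, Any],
    confusion_rows: list[list[Any]],
    vocab: dict[str, list[str]] | None = None,
) -> list[str]:
    """Write the bundle files and return their names, sorted."""
    # mkdir without exist_ok raises FileExistsError on a name collision,
    # which matches the documented behaviour of save_model_bundle.
    run_dir.mkdir(parents=True)
    (run_dir / "model.txt").write_text(model_text)
    (run_dir / "metadata.json").write_text(json.dumps(metadata, indent=2))
    (run_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
    with open(run_dir / "confusion_matrix.csv", "w", newline="") as fh:
        csv.writer(fh).writerows(confusion_rows)
    # Optional file: one sorted vocabulary list per categorical column.
    if vocab:
        (run_dir / "categorical_encoders.json").write_text(json.dumps(vocab, indent=2))
    return sorted(p.name for p in run_dir.iterdir())


with tempfile.TemporaryDirectory() as tmp:
    names = write_bundle_layout(
        Path(tmp) / "2026-02-19_013000_run-0042",
        model_text="placeholder, not a real LightGBM model dump\n",
        metadata={"schema_version": "v3"},
        metrics={"example_metric": 0.0},
        confusion_rows=[["", "label_a", "label_b"], ["label_a", 1, 0], ["label_b", 0, 1]],
        vocab={"example_column": ["value_a", "value_b"]},
    )
    print(names)
```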

load_model_bundle(run_dir, *, validate_schema=True, validate_labels=True)

Load a model bundle and optionally validate schema hash and label set.

Schema validation accepts v1, v2, and v3 bundles: the bundle's schema_version is looked up in the known schema registry and its schema_hash is checked against the corresponding expected hash.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| run_dir | Path | Path to an existing run directory (e.g. models/2026-02-19_013000_run-0042/). | required |
| validate_schema | bool | When True (the default), raise if the bundle's schema version is unknown or its schema hash does not match the expected hash for that version. | True |
| validate_labels | bool | When True (the default), raise if the bundle's label set differs from the current LABEL_SET_V1 (order is ignored). | True |

Returns:

| Type | Description |
|---|---|
| tuple[Booster, ModelMetadata, dict[str, Any]] | A (model, metadata, cat_encoders) tuple, where cat_encoders is a dict mapping column names to fitted LabelEncoder instances. It is an empty dict when no categorical_encoders.json file exists. |

Raises:

| Type | Description |
|---|---|
| ValueError | If validation is enabled and the bundle's schema version is unknown, or the schema hash or label set recorded in the bundle does not match the running code. |

Source code in src/taskclf/core/model_io.py
def load_model_bundle(
    run_dir: Path,
    *,
    validate_schema: bool = True,
    validate_labels: bool = True,
) -> tuple[lgb.Booster, ModelMetadata, dict[str, Any]]:
    """Load a model bundle and optionally validate schema hash and label set.

    Schema validation accepts v1, v2, and v3 bundles: the bundle's
    ``schema_version`` is looked up in the known schema registry and its
    ``schema_hash`` is checked against the corresponding expected hash.

    Args:
        run_dir: Path to an existing run directory (e.g.
            ``models/2026-02-19_013000_run-0042/``).
        validate_schema: When ``True`` (the default), raise if the
            bundle's schema hash does not match the expected hash for
            its declared schema version.
        validate_labels: When ``True`` (the default), raise if the
            bundle's label set differs from the current ``LABEL_SET_V1``.

    Returns:
        A ``(model, metadata, cat_encoders)`` tuple where *cat_encoders*
        is a dict mapping column names to fitted ``LabelEncoder``
        instances. Returns an empty dict when no encoder file exists.

    Raises:
        ValueError: If validation is enabled and the schema hash or label
            set recorded in the bundle does not match the running code.
    """
    from sklearn.preprocessing import LabelEncoder

    model = lgb.Booster(model_file=str(run_dir / "model.txt"))

    raw = json.loads((run_dir / "metadata.json").read_text())
    metadata = ModelMetadata.model_validate(raw)

    if validate_schema:
        expected_hash = _SCHEMA_HASHES.get(metadata.schema_version)
        if expected_hash is None:
            raise ValueError(
                f"Unknown schema version in bundle: {metadata.schema_version!r}"
            )
        if metadata.schema_hash != expected_hash:
            raise ValueError(
                f"Schema hash mismatch: bundle has {metadata.schema_hash!r}, "
                f"expected {expected_hash!r} for schema {metadata.schema_version!r}"
            )

    if validate_labels and sorted(metadata.label_set) != sorted(LABEL_SET_V1):
        raise ValueError(
            f"Label set mismatch: bundle has {sorted(metadata.label_set)!r}, "
            f"current label set is {sorted(LABEL_SET_V1)!r}"
        )

    cat_encoders: dict[str, LabelEncoder] = {}
    enc_path = run_dir / "categorical_encoders.json"
    if enc_path.exists():
        vocab = json.loads(enc_path.read_text())
        for col, classes in vocab.items():
            le = LabelEncoder()
            le.fit(classes)
            cat_encoders[col] = le

    return model, metadata, cat_encoders
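Persisting only a sorted vocabulary is enough because LabelEncoder.fit stores the sorted unique values in classes_, and transform maps each value to its index in that list. The standard-library sketch below mirrors that mapping to show why the round trip through categorical_encoders.json is lossless. VocabEncoder is a hypothetical class, not a scikit-learn estimator, and the column and category names are placeholders.

```python
from __future__ import annotations

import json


class VocabEncoder:
    """Mimics LabelEncoder's fit/transform mapping for string categories."""

    def __init__(self, classes: list[str]) -> None:
        # LabelEncoder.fit keeps the sorted unique values as classes_.
        self.classes_ = sorted(set(classes))
        self._index = {c: i for i, c in enumerate(self.classes_)}

    def transform(self, values: list[str]) -> list[int]:
        # Like LabelEncoder, reject values that were not seen during fitting.
        unseen = [v for v in values if v not in self._index]
        if unseen:
            raise ValueError(f"y contains previously unseen labels: {unseen}")
        return [self._index[v] for v in values]

    def inverse_transform(self, codes: list[int]) -> list[str]:
        return [self.classes_[i] for i in codes]


# Save side: persist each encoder's classes_ as a JSON list.
original = VocabEncoder(["terminal", "browser", "editor", "browser"])
payload = json.dumps({"example_column": original.classes_})

# Load side: refit from the stored vocabulary, as load_model_bundle does.
restored = {col: VocabEncoder(v) for col, v in json.loads(payload).items()}
print(restored["example_column"].transform(["editor", "terminal"]))  # [1, 2]
```

Since refitting on an already sorted, deduplicated list reproduces the same classes_, the restored encoder assigns every known category the same integer code it had at training time.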

build_metadata(label_set, train_date_from, train_date_to, params, *, dataset_hash, reject_threshold=None, data_provenance='real', unknown_category_freq_threshold=None, unknown_category_mask_rate=None, schema_version=LATEST_FEATURE_SCHEMA_VERSION)

Convenience builder that fills in schema info and git commit.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| label_set | list[str] | Task-type labels used during training (stored sorted). | required |
| train_date_from | date | First date of the training range. | required |
| train_date_to | date | Last date (inclusive) of the training range. | required |
| params | dict[str, Any] | LightGBM (or other model) hyperparameters dict. | required |
| dataset_hash | str | Deterministic SHA-256 hash of the training dataset, used for reproducibility auditing. | required |
| reject_threshold | float or None | Reject threshold used during evaluation. Advisory only; the canonical runtime threshold lives in InferencePolicy. | None |
| data_provenance | Literal['real', 'synthetic', 'mixed'] | Origin of the training data ("real", "synthetic", or "mixed"). | 'real' |
| unknown_category_freq_threshold | int or None | Minimum category frequency used during training (rarer categories become `__unknown__`). | None |
| unknown_category_mask_rate | float or None | Fraction of known categories randomly masked to `__unknown__` during training. | None |
| schema_version | str | Feature schema version: "v1", "v2", or "v3". | LATEST_FEATURE_SCHEMA_VERSION |

Returns:

| Type | Description |
|---|---|
| ModelMetadata | A populated ModelMetadata instance. |

Raises:

| Type | Description |
|---|---|
| ValueError | If schema_version is not recognised. |

Source code in src/taskclf/core/model_io.py
def build_metadata(
    label_set: list[str],
    train_date_from: date,
    train_date_to: date,
    params: dict[str, Any],
    *,
    dataset_hash: str,
    reject_threshold: float | None = None,
    data_provenance: Literal["real", "synthetic", "mixed"] = "real",
    unknown_category_freq_threshold: int | None = None,
    unknown_category_mask_rate: float | None = None,
    schema_version: str = LATEST_FEATURE_SCHEMA_VERSION,
) -> ModelMetadata:
    """Convenience builder that fills in schema info and git commit.

    Args:
        label_set: Task-type labels used during training.
        train_date_from: First date of the training range.
        train_date_to: Last date (inclusive) of the training range.
        params: LightGBM (or other model) hyperparameters dict.
        dataset_hash: Deterministic SHA-256 hash of the training dataset
            used for reproducibility auditing.
        reject_threshold: Reject threshold used during evaluation.
        data_provenance: Origin of the training data
            (``"real"``, ``"synthetic"``, or ``"mixed"``).
        unknown_category_freq_threshold: Minimum category frequency
            used during training (categories below this are ``__unknown__``).
        unknown_category_mask_rate: Fraction of known categories randomly
            masked to ``__unknown__`` during training.
        schema_version: ``"v1"``, ``"v2"``, or ``"v3"``.

    Returns:
        A populated ``ModelMetadata`` instance.

    Raises:
        ValueError: If *schema_version* is not recognised.
    """
    schema_hash = _SCHEMA_HASHES.get(schema_version)
    if schema_hash is None:
        raise ValueError(f"Unknown schema version: {schema_version!r}")

    return ModelMetadata(
        schema_version=schema_version,
        schema_hash=schema_hash,
        label_set=sorted(label_set),
        train_date_from=train_date_from.isoformat(),
        train_date_to=train_date_to.isoformat(),
        params=params,
        git_commit=_current_git_commit(),
        dataset_hash=dataset_hash,
        reject_threshold=reject_threshold,
        data_provenance=data_provenance,
        unknown_category_freq_threshold=unknown_category_freq_threshold,
        unknown_category_mask_rate=unknown_category_mask_rate,
    )
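Apart from the schema lookup, build_metadata mostly normalizes its inputs: it sorts the labels, converts dates to ISO-8601 strings, and records the current git commit. The sketch below isolates those steps. normalized_fields and current_git_commit are hypothetical; the real private helper _current_git_commit is not shown on this page, so calling `git rev-parse HEAD` and falling back to "unknown" outside a git checkout are assumptions, as are the placeholder label names.

```python
from __future__ import annotations

import subprocess
from datetime import date
from typing import Any


def current_git_commit() -> str:
    """Hypothetical stand-in for the private _current_git_commit helper."""
    try:
        out = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
            timeout=5,
        )
        return out.stdout.strip()
    except (OSError, subprocess.SubprocessError):
        # Assumption: git missing, no repository, or no commits -> sentinel.
        return "unknown"


def normalized_fields(
    label_set: list[str],
    train_date_from: date,
    train_date_to: date,
    params: dict[str, Any],
) -> dict[str, Any]:
    """Field transformations build_metadata applies before constructing the model."""
    return {
        "label_set": sorted(label_set),  # order-independent vocabulary
        "train_date_from": train_date_from.isoformat(),  # "YYYY-MM-DD"
        "train_date_to": train_date_to.isoformat(),
        "params": params,
        "git_commit": current_git_commit(),
    }


fields = normalized_fields(
    ["label_c", "label_a", "label_b"], date(2026, 1, 1), date(2026, 1, 31), {"num_leaves": 31}
)
print(fields["label_set"], fields["train_date_from"], fields["git_commit"])
```

Sorting label_set at build time is what makes the stored vocabulary deterministic, although load_model_bundle still compares sorted copies, so label order never causes a mismatch.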