Writing Custom Cross-Validators
Most projects should use the built-in validators:
- KFoldValidator for trial-level folds with optional holdout
- GroupedKFoldValidator for grouped trial splits
- LeaveGroupOutValidator, LeaveSessionOutValidator, and LeaveAnimalOutValidator for group generalization
- SampledDomainKFoldValidator for source/target domain sampling
Write a custom cross-validator only when the split policy itself is new. The
base CrossValidator already owns scoring, holdout evaluation, out-of-fold
prediction collection, stateless preprocessing reuse, stateful refitting,
caching, and final pipeline fitting.
The Validator Contract
A custom validator usually implements only two methods:
def _split_holdout(self, container: DataContainer) -> tuple[np.ndarray, np.ndarray]:
    ...

def _get_splits(
    self,
    container: DataContainer,
    indices_to_split: np.ndarray,
) -> Iterator[tuple[np.ndarray, np.ndarray]]:
    ...
Both methods should return trial coordinate labels, not positional indices.
This matters because trial labels may be non-contiguous IDs.
cross_validate() calls these methods after fold-invariant stateless pipeline
steps have run. Your split logic should therefore assume that the trial
dimension and trial labels have been preserved by preprocessing.
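The difference between labels and positions is easy to see with plain NumPy. The sketch below uses made-up trial IDs (the values are hypothetical) and shows the mapping a split method needs to apply when a generic splitter hands back positions:

```python
import numpy as np

# Hypothetical trial labels: non-contiguous IDs, e.g. after bad trials were dropped.
trial_labels = np.array([101, 104, 105, 230, 231, 400])

# A generic splitter works with 0-based positions into the array it was given.
val_positions = np.array([1, 4])
train_positions = np.setdiff1d(np.arange(len(trial_labels)), val_positions)

# Map positions back to labels before returning them from a split method.
train_labels = trial_labels[train_positions]
val_labels = trial_labels[val_positions]

print(train_labels)  # [101 105 230 400]
print(val_labels)    # [104 231]
```

Returning `val_positions` directly would ask the container for trials labeled 1 and 4, which do not exist here.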
Minimal Grouped Example
This validator creates folds by session and optionally reserves named sessions as a final holdout set.
from collections.abc import Hashable, Iterator

import numpy as np
from sklearn.model_selection import KFold

from xdflow.core import DataContainer
from xdflow.cv import CrossValidator


class SessionBlockedKFoldValidator(CrossValidator):
    def __init__(
        self,
        session_coord: str = "session",
        n_splits: int = 5,
        holdout_sessions: list[Hashable] | None = None,
        shuffle: bool = True,
        random_state: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.session_coord = session_coord
        self.n_splits = n_splits
        self.holdout_sessions = holdout_sessions or []
        self.shuffle = shuffle
        self.random_state = random_state

    def _validate_session_coord(self, container: DataContainer) -> None:
        # Fail early, with a clear message, if the container cannot be split by session.
        if "trial" not in container.data.dims:
            raise ValueError("SessionBlockedKFoldValidator requires a 'trial' dimension.")
        if self.session_coord not in container.data.coords:
            raise ValueError(f"Coordinate '{self.session_coord}' not found.")
        if container.data.coords[self.session_coord].dims != ("trial",):
            raise ValueError(f"Coordinate '{self.session_coord}' must index the 'trial' dimension.")

    def _split_holdout(self, container: DataContainer) -> tuple[np.ndarray, np.ndarray]:
        self._validate_session_coord(container)
        trial_labels = container.data.trial.values
        sessions = container.data.coords[self.session_coord].values

        # No holdout requested: every trial is available for cross-validation.
        if not self.holdout_sessions:
            return trial_labels, np.array([], dtype=trial_labels.dtype)

        holdout_mask = np.isin(sessions, np.asarray(self.holdout_sessions))
        return trial_labels[~holdout_mask], trial_labels[holdout_mask]

    def _get_splits(
        self,
        container: DataContainer,
        indices_to_split: np.ndarray,
    ) -> Iterator[tuple[np.ndarray, np.ndarray]]:
        self._validate_session_coord(container)
        # Restrict to the non-holdout trials, selecting by label rather than position.
        cv_data = container.data.sel(trial=indices_to_split)
        trial_labels = cv_data.trial.values
        sessions = cv_data.coords[self.session_coord].values
        unique_sessions = np.unique(sessions)

        if self.n_splits > len(unique_sessions):
            raise ValueError("n_splits cannot exceed the number of sessions.")

        # KFold raises ValueError if random_state is set while shuffle=False,
        # so pass the seed only when shuffling.
        splitter = KFold(
            n_splits=self.n_splits,
            shuffle=self.shuffle,
            random_state=self.random_state if self.shuffle else None,
        )
        # Fold over sessions, then expand each fold back to trial labels.
        for train_pos, val_pos in splitter.split(unique_sessions):
            train_sessions = unique_sessions[train_pos]
            val_sessions = unique_sessions[val_pos]
            train_mask = np.isin(sessions, train_sessions)
            val_mask = np.isin(sessions, val_sessions)
            yield trial_labels[train_mask], trial_labels[val_mask]
Then use it like any other validator:
cv = SessionBlockedKFoldValidator(
    n_splits=4,
    holdout_sessions=["session_d"],
    scoring="f1_weighted",
    verbose=False,
)
cv.set_pipeline(pipeline)
score = cv.cross_validate(container, verbose=False)
What The Base Class Handles
Do not reimplement the evaluation loop unless the base class cannot express the
workflow. CrossValidator.cross_validate() already handles:
- splitting the pipeline into fold-invariant stateless steps and stateful steps
- running stateless preprocessing once before folds
- cloning and fitting stateful pipeline steps per fold
- calling predict or predict_proba as needed by the scorer
- collecting fold scores, out-of-fold predictions, and probabilities
- optional holdout scoring
- final fitting through finalize_pipeline
Custom validators should focus on split policy only.
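To make the division of labor concrete, here is a toy evaluation loop. It is not XDFlow's implementation. Every name in it (`toy_cross_validate`, `ThresholdClassifier`, the sample data) is hypothetical and only illustrates the pattern described above: stateless preprocessing runs once, and a fresh stateful step is fitted on each fold's training trials.

```python
import numpy as np


class ThresholdClassifier:
    """Hypothetical stateful step: learns a threshold from training data only."""

    def fit(self, x, y):
        self.threshold_ = (x[y == 0].mean() + x[y == 1].mean()) / 2
        return self

    def score(self, x, y):
        return float(np.mean((x > self.threshold_).astype(int) == y))


def toy_cross_validate(x, y, labels, splits, stateless, make_stateful):
    """Toy loop: stateless step once, stateful step refitted per fold."""
    x = stateless(x)  # fold-invariant, so it is safe to run before splitting
    position_of = {label: pos for pos, label in enumerate(labels)}
    scores = []
    for train_labels, val_labels in splits:
        train = [position_of[label] for label in train_labels]
        val = [position_of[label] for label in val_labels]
        model = make_stateful()  # fresh, unfitted instance for every fold
        model.fit(x[train], y[train])
        scores.append(model.score(x[val], y[val]))
    return float(np.mean(scores)), scores


x = np.array([0.0, 1.0, 2.0, 10.0, 11.0, 12.0])
y = np.array([0, 0, 0, 1, 1, 1])
labels = [10, 20, 30, 40, 50, 60]
splits = [([10, 30, 40, 60], [20, 50]), ([10, 20, 40, 50], [30, 60])]

mean_score, fold_scores = toy_cross_validate(
    x, y, labels, splits, stateless=lambda a: a * 2, make_stateful=ThresholdClassifier
)
print(mean_score, fold_scores)  # 1.0 [1.0, 1.0]
```

The only inputs the loop takes from a validator are the label pairs in `splits`, which is why a custom validator rarely needs more than the two split methods.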
Split Invariants
Good split methods should maintain these invariants:
- returned arrays contain trial labels present in the container
- train, validation, and holdout labels do not overlap
- holdout labels are excluded from _get_splits
- group-level split policies do not split a group across train and validation
- required coordinates are validated early with clear errors
- randomization is controlled by constructor parameters
For grouped validators, validate that the group coordinate is one-dimensional
and attached to trial.
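The label-level invariants above can be checked mechanically. The helper below is a sketch that works on plain arrays rather than on a DataContainer. `check_split_invariants` and the sample data are hypothetical names introduced for illustration, not part of XDFlow.

```python
import numpy as np


def check_split_invariants(all_labels, groups, train_val, holdout, splits):
    """Raise AssertionError if the splits violate the label-level invariants.

    all_labels: every trial label in the container
    groups:     group value per trial, aligned with all_labels
    train_val:  labels left for cross-validation after the holdout split
    holdout:    labels reserved for final evaluation
    splits:     list of (train_labels, val_labels) pairs
    """
    as_list = lambda a: np.asarray(a).tolist()
    known = set(as_list(all_labels))
    group_of = dict(zip(as_list(all_labels), as_list(groups)))
    holdout_set = set(as_list(holdout))

    assert holdout_set <= known, "unknown holdout label"
    assert holdout_set.isdisjoint(as_list(train_val)), "holdout overlaps train/val pool"
    for train, val in splits:
        train_set, val_set = set(as_list(train)), set(as_list(val))
        assert train_set <= known and val_set <= known, "unknown trial label"
        assert train_set.isdisjoint(val_set), "train/validation overlap"
        assert holdout_set.isdisjoint(train_set | val_set), "holdout leaked into folds"
        train_groups = {group_of[t] for t in train_set}
        val_groups = {group_of[t] for t in val_set}
        assert train_groups.isdisjoint(val_groups), "group split across train/validation"


labels = np.array([101, 104, 105, 230, 231, 400])
sessions = np.array(["a", "a", "b", "b", "c", "c"])
holdout = np.array([231, 400])             # all of session "c"
train_val = np.array([101, 104, 105, 230])
good_splits = [([101, 104], [105, 230]), ([105, 230], [101, 104])]

check_split_invariants(labels, sessions, train_val, holdout, good_splits)  # passes silently
```

A split such as `([101, 105], [104, 230])` would fail the last assertion, because sessions "a" and "b" each appear on both sides.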
Testing Checklist
Add focused tests for split behavior before testing scoring:
def test_session_blocked_splits_do_not_split_sessions(multi_session_container):
    cv = SessionBlockedKFoldValidator(n_splits=2, session_coord="session", verbose=False)
    train_val, holdout = cv._split_holdout(multi_session_container)
    splits = list(cv._get_splits(multi_session_container, train_val))

    assert len(holdout) == 0
    assert len(splits) == 2

    data = multi_session_container.data
    for train_trials, val_trials in splits:
        train_sessions = set(data.sel(trial=train_trials).coords["session"].values)
        val_sessions = set(data.sel(trial=val_trials).coords["session"].values)
        assert train_sessions.isdisjoint(val_sessions)
        assert set(train_trials).isdisjoint(set(val_trials))
Then add one integration test with a real pipeline and a container that has both the target coordinate and the session coordinate:
def test_session_blocked_cross_validate_runs(session_labeled_container, pipeline):
    cv = SessionBlockedKFoldValidator(n_splits=2, session_coord="session", verbose=False)
    cv.set_pipeline(pipeline)
    score = cv.cross_validate(session_labeled_container, verbose=False)

    assert np.isfinite(score)
    assert len(cv.cv_scores_) == 2
Sklearn Compatibility
If an sklearn estimator needs a cv= splitter, wrap an XDFlow validator with
SklearnCVAdapter and provide the active DataContainer through
set_cv_container.
from sklearn.linear_model import LogisticRegressionCV
from xdflow.cv import SklearnCVAdapter, set_cv_container

adapter = SklearnCVAdapter(cv)
estimator = LogisticRegressionCV(cv=adapter)

with set_cv_container(container):
    estimator.fit(X, y)
Use this only for estimators that own their own internal CV. For XDFlow
pipelines, prefer cv.set_pipeline(pipeline) followed by cv.cross_validate.
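An adapter is needed because scikit-learn's cv= parameter works with positional row indices into X, while XDFlow validators return trial labels. How SklearnCVAdapter bridges the two is not described here. The sketch below only illustrates the kind of translation involved, assuming the rows of X are in the same order as the trial labels. `label_splits_to_positions` is a hypothetical helper, not an XDFlow function.

```python
import numpy as np


def label_splits_to_positions(row_labels, label_splits):
    """Convert (train_labels, val_labels) pairs into positional index pairs.

    Assumes row i of X corresponds to row_labels[i].
    """
    position_of = {label: pos for pos, label in enumerate(np.asarray(row_labels).tolist())}
    for train_labels, val_labels in label_splits:
        train_pos = np.array([position_of[l] for l in np.asarray(train_labels).tolist()], dtype=int)
        val_pos = np.array([position_of[l] for l in np.asarray(val_labels).tolist()], dtype=int)
        yield train_pos, val_pos


row_labels = [101, 104, 105, 230]
label_splits = [([101, 105], [104, 230]), ([104, 230], [101, 105])]

positional = list(label_splits_to_positions(row_labels, label_splits))
for train_pos, val_pos in positional:
    print(train_pos, val_pos)
# [0 2] [1 3]
# [1 3] [0 2]
```

scikit-learn estimators that take cv= also accept an iterable of (train, test) index arrays, so a list like `positional` could be passed directly in simple cases.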