
Cross-Validation API

XDFlow validators run the evaluation loop. They build folds, apply split policies from named coordinates, reuse fold-invariant work, clone and refit stateful steps, score predictions, and keep outputs aligned with the source data.

Use these classes instead of handwritten sklearn split loops when validation depends on metadata, pipeline state, or reusable preprocessing.

For custom split policies, see Writing Custom Cross-Validators.

Base Validator

CrossValidator

CrossValidator(pooling_score_weight: float = 0.0, use_stateful_fit_cache: bool = True, release_fold_memory: bool = False, scoring: str | Callable | None = None, scoring_needs_proba: bool = False, stratify_coord: str | None = None, exclude_intertrial_from_scoring: bool = False, exclude_offsets_from_scoring: bool = False, verbose: bool = True)

Bases: ABC

Base class for evaluating a predictive pipeline with held-out data.

A cross-validator owns the train/validation/holdout splitting strategy and evaluates a complete Pipeline. Stateless pipeline steps can be run once before fold splitting, while stateful steps are cloned and fitted on each fold's training data. This keeps expensive deterministic preprocessing out of the per-fold loop when possible.
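As a rough illustration of that division of labor, the sketch below uses plain NumPy with hypothetical stand-ins: `stateless_scale`, `MeanCenterer`, and the hand-written fold indices are not XDFlow names. The stateless step runs once on all trials, while a fresh stateful step is fitted on each fold's training rows only.

```python
import numpy as np

calls = {"stateless": 0, "stateful_fit": 0}


def stateless_scale(x):
    """Deterministic step with no learned state: safe to run once on all trials."""
    calls["stateless"] += 1
    return x * 2.0


class MeanCenterer:
    """Stateful step: learns a per-feature mean from the rows it is fitted on."""

    def fit(self, x):
        calls["stateful_fit"] += 1
        self.mean_ = x.mean(axis=0)
        return self

    def transform(self, x):
        return x - self.mean_


x = np.arange(12, dtype=float).reshape(6, 2)  # 6 trials, 2 features
pre = stateless_scale(x)  # run once, before any fold is built

# Hand-written (train, validation) index pairs standing in for a split policy.
folds = [
    (np.array([2, 3, 4, 5]), np.array([0, 1])),
    (np.array([0, 1, 4, 5]), np.array([2, 3])),
    (np.array([0, 1, 2, 3]), np.array([4, 5])),
]

train_means = []
for train_idx, val_idx in folds:
    step = MeanCenterer().fit(pre[train_idx])  # fresh instance per fold
    _ = step.transform(pre[val_idx])  # validation rows never influence the fit
    train_means.append(step.mean_.copy())
```

The counters make the cost difference visible: the stateless step is called once, the stateful fit once per fold.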

Results are stored on the instance after evaluation, including fold scores, out-of-fold predictions, optional probabilities, and holdout predictions. Scorers may accept either (y_true, y_pred) or (y_true, y_pred, container) when coordinate-aware scoring is needed.
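The two scorer shapes can be sketched as follows. This is only an illustration under stated assumptions: the dict used as `container`, the `session` coordinate, and the `inspect`-based dispatch in `call_scorer` are hypothetical, and XDFlow's own detection of container-aware scorers may work differently.

```python
import inspect

import numpy as np


def accuracy(y_true, y_pred):
    """Two-argument scorer: plain accuracy."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


def per_session_accuracy(y_true, y_pred, container):
    """Three-argument scorer: average the accuracy computed within each session."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    sessions = np.asarray(container["session"])  # hypothetical coordinate lookup
    per_session = [
        np.mean(y_true[sessions == s] == y_pred[sessions == s])
        for s in np.unique(sessions)
    ]
    return float(np.mean(per_session))


def call_scorer(scorer, y_true, y_pred, container):
    """Pass the container only to scorers that declare a third parameter."""
    n_params = len(inspect.signature(scorer).parameters)
    if n_params >= 3:
        return scorer(y_true, y_pred, container)
    return scorer(y_true, y_pred)


y_true = [0, 1, 1, 0]
y_pred = [0, 1, 0, 0]
container = {"session": ["a", "a", "a", "b"]}  # stand-in, not a real DataContainer

plain = call_scorer(accuracy, y_true, y_pred, container)
grouped = call_scorer(per_session_accuracy, y_true, y_pred, container)
```

A container-aware scorer is useful when the metric depends on trial metadata, as here, where each session contributes equally regardless of how many trials it holds.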

Subclasses define only the split policy by implementing _split_holdout and _get_splits.

Initialize scoring, caching, and split-independent CV options.

Parameters:

Name Type Description Default
pooling_score_weight float

Interpolation factor between the average fold score (0.0) and the pooled OOF score (1.0). Defaults to 0.0. Must be between 0.0 and 1.0. Higher values give more weight to folds with more trials.

0.0
use_stateful_fit_cache bool

Whether to cache stateful transforms during CV.

True
release_fold_memory bool

Whether to aggressively release per-fold objects and clear PyTorch CUDA caches after each fold.

False
scoring str | Callable | None

Metric name or callable. If None, classification defaults to "f1_weighted" and regression defaults to "r2". Callable scorers may accept (y_true, y_pred) or (y_true, y_pred, container).

None
scoring_needs_proba bool

Whether a custom scorer expects probabilities from predict_proba instead of hard predictions.

False
exclude_intertrial_from_scoring bool

If True, automatically remove any trials whose event_type coordinate is "intertrial" from CV/holdout scoring.

False
exclude_offsets_from_scoring bool

If True, remove trials whose time_offset_ms coordinate is not 0 from CV/holdout scoring.

False
stratify_coord str | None

Optional coordinate name to use for stratified splits. If set, holdout and CV splits will stratify on this coordinate (must be present in the data). For multi-target/regression tasks, this allows stratifying on a categorical coord such as stimulus.

None
verbose bool

Whether to print output from the cross-validator itself. Transform verbosity is controlled separately, through the verbose argument of methods such as cross_validate().

True

Raises:

Type Description
ValueError

If pooling_score_weight is not between 0.0 and 1.0
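The two exclude_*_from_scoring options can be thought of as a boolean mask applied to trials before the metric is computed. The sketch below shows that idea with plain NumPy arrays standing in for the event_type and time_offset_ms coordinates; the data values are invented and the accuracy metric is only a stand-in for the configured scorer.

```python
import numpy as np

# Invented per-trial coordinates and predictions.
event_type = np.array(["stimulus", "intertrial", "stimulus", "stimulus"])
time_offset_ms = np.array([0, 0, 50, 0])
y_true = np.array([1, 0, 1, 0])
y_pred = np.array([1, 1, 0, 0])

exclude_intertrial_from_scoring = True
exclude_offsets_from_scoring = True

# Start by keeping every trial, then narrow the mask per enabled option.
keep = np.ones(len(y_true), dtype=bool)
if exclude_intertrial_from_scoring:
    keep &= event_type != "intertrial"
if exclude_offsets_from_scoring:
    keep &= time_offset_ms == 0

accuracy_all = float(np.mean(y_true == y_pred))
accuracy_kept = float(np.mean(y_true[keep] == y_pred[keep]))
```

Only the scored trials change; the sketch makes no assumption about which trials are used for fitting.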

Source code in xdflow/cv/base.py
def __init__(
    self,
    pooling_score_weight: float = 0.0,
    use_stateful_fit_cache: bool = True,
    release_fold_memory: bool = False,
    scoring: str | Callable | None = None,
    scoring_needs_proba: bool = False,
    stratify_coord: str | None = None,
    exclude_intertrial_from_scoring: bool = False,
    exclude_offsets_from_scoring: bool = False,
    verbose: bool = True,
):
    """Initialize scoring, caching, and split-independent CV options.

    Args:
        pooling_score_weight: Interpolation factor between the average fold
            score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.
            Must be between 0.0 and 1.0.
            Higher values give more weight to folds with more trials.
        use_stateful_fit_cache: Whether to cache stateful transforms during CV.
        release_fold_memory: Whether to aggressively release per-fold objects and
            clear PyTorch CUDA caches after each fold.
        scoring: Metric name or callable. If None, classification defaults
            to `"f1_weighted"` and regression defaults to `"r2"`. Callable
            scorers may accept `(y_true, y_pred)` or
            `(y_true, y_pred, container)`.
        scoring_needs_proba: Whether a custom scorer expects probabilities from
            predict_proba instead of hard predictions.
        exclude_intertrial_from_scoring: If True, automatically remove any trials whose
            event_type coordinate is "intertrial" from CV/holdout scoring.
        exclude_offsets_from_scoring: If True, remove trials whose time_offset_ms
            coordinate is not 0 from CV/holdout scoring.
        stratify_coord: Optional coordinate name to use for stratified splits. If set,
            holdout and CV splits will stratify on this coordinate (must be present
            in the data). For multi-target/regression tasks, this allows stratifying
            on a categorical coord such as stimulus.
        verbose: Whether to print verbose output specific to the cross-validator.
            Verbosity of transforms is separately controlled by class-level function arguments.

    Raises:
        ValueError: If pooling_score_weight is not between 0.0 and 1.0
    """
    if not 0.0 <= pooling_score_weight <= 1.0:
        raise ValueError("pooling_score_weight must be between 0.0 and 1.0")

    self.pooling_score_weight = pooling_score_weight
    self.scoring = scoring
    self.scoring_needs_proba = bool(scoring_needs_proba)

    # Results from cross-validation
    self.cv_scores_ = []
    self.oof_predictions_ = []  # Out-of-fold predictions
    self.oof_probabilities_ = []  # Out-of-fold probabilities/scores
    self.true_labels_ = []

    # Holdout test set management
    # NOTE: `holdout_trial_labels_` stores trial labels in the stateless-preprocessed space.
    #       Trials are assumed to retain their original labels, so these map directly to the raw container.
    self.holdout_trial_labels_: np.ndarray | None = None
    self.holdout_score_: float | None = None
    self.holdout_pred_labels_: np.ndarray | None = None  # Holdout test predictions
    self.holdout_probabilities_: np.ndarray | None = None  # Holdout test probabilities/scores
    self.holdout_true_labels_: np.ndarray | None = None  # Holdout test true labels
    self.holdout_container_: DataContainer | None = None  # Holdout test container (for container-aware scorers)
    self.holdout_scoring_mask_: np.ndarray | None = None  # Mask used by container-aware scorer

    # Set by the user before evaluation (avoid property setter recursion)
    self._pipeline: Pipeline | None = None
    self.final_target_coord_: str | list[str] | None = None
    self.use_stateful_fit_cache = use_stateful_fit_cache
    self.release_fold_memory = release_fold_memory
    self.stratify_coord = stratify_coord
    self.exclude_intertrial_from_scoring = exclude_intertrial_from_scoring
    self.exclude_offsets_from_scoring = exclude_offsets_from_scoring

    # Resolved scoring function (set after pipeline is known)
    self._scoring_func: Callable | None = None
    self._metric_name: str | None = None
    self._scoring_needs_proba = False
    self._scoring_accepts_container = False

    self.verbose = verbose

holdout_confusion_matrix_ property

holdout_confusion_matrix_: ndarray

Calculate confusion matrix from holdout test predictions.

If a scoring mask has been set via compute_holdout_scoring_mask(), the confusion matrix will be computed only on the filtered samples, matching the scorer's logic.

Returns:

Type Description
ndarray

Confusion matrix as numpy array

Raises:

Type Description
ValueError

If no holdout predictions available or task is not classification

holdout_confusion_matrix_normalized_ property

holdout_confusion_matrix_normalized_: ndarray

Calculate normalized confusion matrix from holdout test predictions.

If a scoring mask has been set via compute_holdout_scoring_mask(), the confusion matrix will be computed only on the filtered samples, matching the scorer's logic.

Returns:

Type Description
ndarray

Normalized confusion matrix as numpy array (rows sum to 1)

Raises:

Type Description
ValueError

If no holdout predictions available or task is not classification
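Row normalization, where each row sums to 1, can be sketched with NumPy as below. The counts are invented, and the handling of an all-zero row (left as zeros here) is an assumption; the library may treat classes with no true samples differently.

```python
import numpy as np

# Rows are true classes, columns are predicted classes (invented counts).
counts = np.array([[3, 1],
                   [2, 6]])

row_totals = counts.sum(axis=1, keepdims=True)

# Divide each row by its total; rows with no samples stay at zero.
normalized = np.divide(
    counts,
    row_totals,
    out=np.zeros(counts.shape, dtype=float),
    where=row_totals != 0,
)
```

After normalization, each diagonal entry reads as the recall of its class.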

metric_name_ property

metric_name_: str

Get the name of the scoring metric used for evaluation.

Returns:

Type Description
str

Name of the metric (e.g., 'r2', 'mse', 'f1_weighted', 'custom')

oof_score_ property

oof_score_: float

Calculate the selected metric score from out-of-fold predictions.

Note: For scorers that require a container argument, OOF scoring is not possible since predictions come from multiple folds. In this case, returns the mean CV score as a fallback.

Returns:

Type Description
float

Score calculated using the selected scoring function

Raises:

Type Description
ValueError

If no out-of-fold predictions available
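The difference between a pooled out-of-fold score and the mean of per-fold scores shows up when folds have unequal sizes. The sketch below uses accuracy as a stand-in metric and invented labels; it does not call XDFlow.

```python
import numpy as np


def accuracy(y_true, y_pred):
    return float(np.mean(y_true == y_pred))


# Fold 1 has 10 validation trials (9 correct); fold 2 has 2 (1 correct).
fold_true = [np.ones(10, dtype=int), np.array([1, 1])]
fold_pred = [np.array([1] * 9 + [0]), np.array([1, 0])]

fold_scores = [accuracy(t, p) for t, p in zip(fold_true, fold_pred)]
mean_fold_score = float(np.mean(fold_scores))  # every fold counts equally

# Pooling concatenates all out-of-fold predictions before scoring,
# so the larger fold contributes more trials to the result.
pooled_score = accuracy(np.concatenate(fold_true), np.concatenate(fold_pred))
```

Here the small, poorly scored fold pulls the mean down more than it pulls the pooled score down.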

mean_cv_score_ property

mean_cv_score_: float

Get the mean cross-validation score across all folds.

Returns:

Type Description
float

Mean of scores from all cross-validation folds

Raises:

Type Description
ValueError

If no cross-validation scores available

oof_f1_score_ property

oof_f1_score_: float

Calculate F1 score from out-of-fold predictions.

Convenience property for classification tasks that always returns weighted F1 score, regardless of the configured scoring metric.

Returns:

Type Description
float

Weighted F1 score across all out-of-fold predictions

Raises:

Type Description
ValueError

If no out-of-fold predictions available

holdout_f1_score_ property

holdout_f1_score_: float

Calculate F1 score from holdout predictions.

Convenience property for classification tasks that always returns weighted F1 score, regardless of the configured scoring metric.

Returns:

Type Description
float

Weighted F1 score from holdout predictions

Raises:

Type Description
ValueError

If no holdout predictions available

mean_cv_f1_score_ property

mean_cv_f1_score_: float

Get the mean cross-validation F1 score across all folds.

Returns:

Type Description
float

Mean of F1 scores from all cross-validation folds

Raises:

Type Description
ValueError

If no cross-validation scores available

oof_confusion_matrix_ property

oof_confusion_matrix_: ndarray

Calculate confusion matrix from out-of-fold predictions.

Returns:

Type Description
ndarray

Confusion matrix as numpy array

Raises:

Type Description
ValueError

If no out-of-fold predictions available or task is not classification

oof_confusion_matrix_normalized_ property

oof_confusion_matrix_normalized_: ndarray

Calculate normalized confusion matrix from out-of-fold predictions.

Returns:

Type Description
ndarray

Normalized confusion matrix as numpy array (rows sum to 1)

Raises:

Type Description
ValueError

If no out-of-fold predictions available or task is not classification

score_ property

score_: float

Calculate the final CV score based on the pooling_score_weight.

Blends the average fold score and pooled out-of-fold score using:

score = (1 - pooling_score_weight) * mean_cv_f1_score_ + pooling_score_weight * oof_f1_score_

When pooling_score_weight = 0.0: Returns average fold score (standard behavior)

When pooling_score_weight = 1.0: Returns pooled OOF score

When pooling_score_weight = 0.5: Returns equal blend of both

Returns:

Type Description
float

Final blended cross-validation score

Raises:

Type Description
ValueError

If no cross-validation scores are available
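The linear blend behind score_ can be written as a small standalone function. This is a sketch, not the library's code: `blended_score` is a hypothetical helper that takes generic fold and pooled scores as plain floats, and the input values are invented.

```python
def blended_score(mean_fold_score, pooled_oof_score, pooling_score_weight=0.0):
    """Interpolate linearly between the mean fold score and the pooled OOF score."""
    if not 0.0 <= pooling_score_weight <= 1.0:
        raise ValueError("pooling_score_weight must be between 0.0 and 1.0")
    w = pooling_score_weight
    return (1.0 - w) * mean_fold_score + w * pooled_oof_score


mean_fold, pooled = 0.70, 0.80
scores = {w: blended_score(mean_fold, pooled, w) for w in (0.0, 0.5, 1.0)}
```

With these inputs the three weights give the mean fold score, the midpoint, and the pooled score respectively.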

set_pipeline

set_pipeline(pipeline: Pipeline)

Set the pipeline to be used for cross-validation.

Parameters:

Name Type Description Default
pipeline Pipeline

Pipeline to be used for cross-validation

required
Source code in xdflow/cv/base.py
def set_pipeline(self, pipeline: Pipeline):
    """
    Set the pipeline to be used for cross-validation.

    Args:
        pipeline: Pipeline to be used for cross-validation
    """
    if pipeline is None:
        raise ValueError("Pipeline cannot be None")

    # Assign internal attribute to avoid triggering setter recursion
    self._pipeline = pipeline

    # Last step of pipeline must be a predictor
    if not pipeline.is_predictor:
        raise ValueError("The last pipeline step must be a Predictor")
    self.final_target_coord_ = pipeline.final_target_coord
    self.holdout_trial_labels_ = None

    # Early validation: disallow multi-target classification
    final_predictor = self._get_final_predictor()
    if final_predictor is None:
        raise ValueError("Pipeline must end with a Predictor to use cross-validation.")
    if (
        final_predictor.is_classifier
        and isinstance(self.final_target_coord_, list)
        and not getattr(final_predictor, "is_multilabel", False)
    ):
        raise ValueError("Multi-target classification is not supported; use a single target_coord for classifiers.")

cross_validate

cross_validate(initial_container: DataContainer, verbose: bool = False, pruning_callback: Callable[[int, float], None] | None = None, **kwargs) -> float

Runs the full cross-validation process on the train and validation sets. The held-out test set is not used here.

This method automatically detects stateless and stateful pipeline components and executes them optimally: stateless parts run once, stateful parts per fold.

Parameters:

Name Type Description Default
initial_container DataContainer

Input DataContainer to cross-validate on

required
verbose bool

Whether to enable verbose logging in transforms

False
**kwargs

Additional arguments passed to splitting methods

{}

Returns:

Type Description
float

Final cross-validation score (score_). This equals the mean fold score when pooling_score_weight is 0.0.

Raises:

Type Description
ValueError

If no pipeline is assigned

Source code in xdflow/cv/base.py
def cross_validate(
    self,
    initial_container: DataContainer,
    verbose: bool = False,
    pruning_callback: Callable[[int, float], None] | None = None,
    **kwargs,
) -> float:
    """
    Runs the full cross-validation process on the train and validation sets. Held out test set is not used here.

    This method automatically detects stateless and stateful pipeline components and
    executes them optimally: stateless parts run once, stateful parts per fold.

    Args:
        initial_container: Input DataContainer to cross-validate on
        verbose: Whether to enable verbose logging in transforms
        **kwargs: Additional arguments passed to splitting methods

    Returns:
        Mean cross-validation score

    Raises:
        ValueError: If no pipeline is assigned
    """
    # Fit encoders globally on initial data
    self._find_and_fit_encoders(self.pipeline, initial_container)

    # Step 1: Auto-detect and split pipeline into stateless and stateful parts
    if self.verbose:
        print("Auto-detecting pipeline structure...")
    stateless_pipeline, stateful_pipeline = self._auto_detect_pipeline_parts(self.pipeline)

    # Step 2: Log pipeline structure information
    self._log_pipeline_structure(self.pipeline, stateless_pipeline, stateful_pipeline)

    # Step 3: Run stateless preprocessing once before the CV loop
    preprocessed_data = self._run_stateless_preprocessing(stateless_pipeline, initial_container, verbose)

    # Step 4: Split into train/validation and holdout sets using preprocessed data
    train_val_indices, holdout_indices = self._split_holdout(preprocessed_data)
    self.holdout_trial_labels_ = holdout_indices

    # Step 5: Generate cross-validation splits on train/validation data only
    splits = self._get_splits(preprocessed_data, train_val_indices)

    # Step 6: Reset evaluation metrics
    self.cv_scores_ = []
    self.oof_predictions_ = []
    self.oof_probabilities_ = []
    self.true_labels_ = []

    # Step 7: Run cross-validation loop with stateful pipeline
    if stateful_pipeline is not None:
        if self.verbose:
            print("Running cross-validation with stateful pipeline...")
        for fold_idx, (train_indices, validation_indices) in enumerate(splits):
            self._process_cv_fold(
                fold_idx,
                train_indices,
                validation_indices,
                preprocessed_data,
                stateful_pipeline,
                verbose,
                pruning_callback=pruning_callback,
            )

        # Print summary of fold scores
        if self.verbose and self.cv_scores_:
            _, metric_name, _ = self._get_scoring_func()
            print("\nCross-validation summary:")
            print(f"  Individual fold {metric_name} scores: {[f'{score:.4f}' for score in self.cv_scores_]}")
            print(f"  Mean {metric_name}: {self.mean_cv_score_:.4f}")
            print(f"  Std {metric_name}: {np.std(self.cv_scores_):.4f}")
    else:
        # Edge case: no stateful steps (shouldn't happen with predictors, but handle gracefully)
        raise ValueError("Pipeline must contain at least one stateful step (typically a Predictor)")

    return self.score_

finalize_pipeline

finalize_pipeline(container: DataContainer, verbose: bool = False) -> Pipeline

Finalizes a model for production by fitting on the entire provided container.

Parameters:

Name Type Description Default
container DataContainer

DataContainer to fit the final model on

required
verbose bool

Whether to enable verbose logging in transforms

False

Returns:

Type Description
Pipeline

The fitted pipeline object, ready for inference.

Source code in xdflow/cv/base.py
def finalize_pipeline(self, container: DataContainer, verbose: bool = False) -> "Pipeline":
    """
    Finalizes a model for production by fitting on the entire provided container.

    Args:
        container: DataContainer to fit the final model on
        verbose: Whether to enable verbose logging in transforms

    Returns:
        The fitted pipeline object, ready for inference.
    """
    fitted_pipeline = self.pipeline.clone()

    # Fit encoders globally
    self._find_and_fit_encoders(self.pipeline, container)

    fitted_pipeline.fit(container, verbose=verbose)
    prepare_for_inference = getattr(fitted_pipeline, "prepare_for_inference", None)
    if callable(prepare_for_inference):
        prepare_for_inference()

    return fitted_pipeline

score_on_holdout

score_on_holdout(initial_container: DataContainer, verbose: bool = False) -> float

Performs the final evaluation on the held-out test set.

Parameters:

Name Type Description Default
initial_container DataContainer

The original DataContainer used in cross_validate()

required
verbose bool

Whether to enable verbose logging in transforms

False

Returns:

Type Description
float

Final holdout test score

Raises:

Type Description
ValueError

If no holdout data is available. When cross_validate() has not been called first, a warning is issued and the holdout split is computed before scoring.
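Holdout scoring separates trials by label rather than by position: everything not listed in holdout_trial_labels_ becomes the training set for the final fit. A minimal NumPy sketch of that partition, with invented trial labels, looks like this:

```python
import numpy as np

all_trials = np.array([10, 11, 12, 13, 14, 15])  # invented trial labels
holdout_trial_labels = np.array([12, 15])

# Mark every trial that is NOT in the holdout set.
train_val_mask = ~np.isin(all_trials, holdout_trial_labels)

train_val_trials = all_trials[train_val_mask]
holdout_trials = all_trials[~train_val_mask]
```

Selecting by label keeps the split valid even if preprocessing reorders trials, provided trials retain their original labels.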

Source code in xdflow/cv/base.py
def score_on_holdout(self, initial_container: DataContainer, verbose: bool = False) -> float:
    """
    Performs the final evaluation on the held-out test set.

    Args:
        initial_container: The original DataContainer used in cross_validate()
        verbose: Whether to enable verbose logging in transforms

    Returns:
        Final holdout test score

    Raises:
        ValueError: If holdout indices don't exist (cross_validate() not called first)
    """

    # Ensure encoders are fitted in-place on predictors
    self._find_and_fit_encoders(self.pipeline, initial_container)

    # Run the same preprocessing as in cross_validate()
    stateless_pipeline, stateful_pipeline = self._auto_detect_pipeline_parts(self.pipeline)
    assert stateful_pipeline is not None, "There must be at least one stateful step in the pipeline, for fitting."

    if stateless_pipeline is not None:
        preprocessed_data = stateless_pipeline.fit_transform(initial_container, verbose=verbose)
    else:
        preprocessed_data = initial_container

    if self.holdout_trial_labels_ is None:
        warnings.warn(
            "cross_validate() not called first so no holdout indices available. Calculating holdout indices now.",
            stacklevel=2,
        )  # necessary since Tuner makes a deepcopy of validators
        train_val_indices, holdout_indices = self._split_holdout(preprocessed_data)
        self.holdout_trial_labels_ = holdout_indices

    if len(self.holdout_trial_labels_) == 0:
        raise ValueError("No holdout data available for testing.")

    # Create train/validation container (all data except holdout)
    all_trials = preprocessed_data.data.trial.values
    train_val_mask = ~np.isin(all_trials, self.holdout_trial_labels_)
    train_val_indices = all_trials[train_val_mask]

    train_val_container = DataContainer(preprocessed_data.data.sel(trial=train_val_indices))
    test_container = DataContainer(preprocessed_data.data.sel(trial=self.holdout_trial_labels_))

    # Fit the stateful part of the pipeline on the train/validation data (no caching for holdout)
    stateful_pipeline_fitted = stateful_pipeline.clone()
    stateful_pipeline_fitted.fit(train_val_container, verbose=verbose)

    # Predict on the holdout test set
    test_results_container = stateful_pipeline_fitted.predict(test_container, verbose=verbose)
    scoring_func, _, needs_proba = self._get_scoring_func()
    holdout_probabilities = None
    if needs_proba:
        holdout_probabilities = stateful_pipeline_fitted.predict_proba(test_container, verbose=verbose).data.values

    final_predictor = stateful_pipeline_fitted.predictive_transform
    if final_predictor is None:
        raise ValueError("Stateful pipeline must expose a predictive transform.")
    pred_labels = test_results_container.data.values
    true_labels = self._extract_targets(cast(Predictor, final_predictor), test_results_container)

    pred_labels, true_labels, scoring_container, scoring_mask = self._filter_scoring_inputs(
        pred_labels,
        true_labels,
        test_results_container,
        context="holdout",
    )
    scoring_values = pred_labels
    if needs_proba:
        if holdout_probabilities is None:
            raise RuntimeError("Scoring requires probabilities, but none were produced.")
        scoring_values = holdout_probabilities if scoring_mask is None else holdout_probabilities[scoring_mask]

    # Store the holdout container and filtered predictions for container-aware scorers
    self.holdout_container_ = scoring_container
    self.holdout_pred_labels_ = pred_labels
    self.holdout_probabilities_ = scoring_values if needs_proba else holdout_probabilities
    self.holdout_true_labels_ = true_labels
    self.holdout_scoring_mask_ = None

    # Calculate and store final score using the selected scoring function
    if self._scoring_accepts_container:
        self.holdout_score_ = scoring_func(self.holdout_true_labels_, scoring_values, scoring_container)
    else:
        self.holdout_score_ = scoring_func(self.holdout_true_labels_, scoring_values)

    return self.holdout_score_

compute_holdout_scoring_mask

compute_holdout_scoring_mask(mask_func: Callable[[DataContainer], ndarray]) -> np.ndarray

Compute and store the mask used by a container-aware scorer.

This method should be called after score_on_holdout() when using a custom container-aware scorer that filters samples. The mask will be used to generate filtered confusion matrices that match the scoring logic.

Parameters:

Name Type Description Default
mask_func Callable[[DataContainer], ndarray]

Function that takes a DataContainer and returns a boolean mask array. Should implement the same filtering logic as the custom scorer. Example: lambda c: (c.coords['concentration_bin'] == 'conc_2p4')

required

Returns:

Type Description
ndarray

The computed boolean mask array

Raises:

Type Description
ValueError

If no holdout container is available, or if mask shape/dtype is invalid

Example

cv.score_on_holdout(data_container)
cv.compute_holdout_scoring_mask(lambda c: c.coords['concentration_bin'] == 'target')
cm = cv.holdout_confusion_matrix_  # Now filtered to match scorer
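To see what the stored mask changes, the sketch below validates a mask and compares confusion counts with and without it. It uses plain NumPy only: `validate_mask` and `confusion_counts` are hypothetical helpers written for this illustration, and the labels and coordinate values are invented.

```python
import numpy as np


def validate_mask(mask, expected_shape):
    """Require one value per sample and coerce it to boolean."""
    mask = np.asarray(mask)
    if mask.shape != expected_shape:
        raise ValueError(f"Mask shape {mask.shape} does not match {expected_shape}")
    return mask.astype(np.bool_)


def confusion_counts(y_true, y_pred, n_classes):
    """Count (true, predicted) pairs; rows are true classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


y_true = np.array([0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 0])
concentration_bin = np.array(["target", "target", "target", "other", "other"])

mask = validate_mask(concentration_bin == "target", y_pred.shape)

cm_all = confusion_counts(y_true, y_pred, n_classes=2)
cm_filtered = confusion_counts(y_true[mask], y_pred[mask], n_classes=2)
```

The filtered matrix counts only the samples the scorer would have seen, which is why the mask function should repeat the scorer's filtering logic.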

Source code in xdflow/cv/base.py
def compute_holdout_scoring_mask(self, mask_func: Callable[[DataContainer], np.ndarray]) -> np.ndarray:
    """
    Compute and store the mask used by a container-aware scorer.

    This method should be called after score_on_holdout() when using a custom
    container-aware scorer that filters samples. The mask will be used to generate
    filtered confusion matrices that match the scoring logic.

    Args:
        mask_func: Function that takes a DataContainer and returns a boolean mask array.
                  Should implement the same filtering logic as the custom scorer.
                  Example: lambda c: (c.coords['concentration_bin'] == 'conc_2p4')

    Returns:
        The computed boolean mask array

    Raises:
        ValueError: If no holdout container is available, or if mask shape/dtype is invalid

    Example:
        >>> cv.score_on_holdout(data_container)
        >>> cv.compute_holdout_scoring_mask(lambda c: c.coords['concentration_bin'] == 'target')
        >>> cm = cv.holdout_confusion_matrix_  # Now filtered to match scorer
    """
    if self.holdout_container_ is None:
        raise ValueError("No holdout container available. Run score_on_holdout() first.")
    if self.holdout_pred_labels_ is None:
        raise ValueError("No holdout predictions available. Run score_on_holdout() first.")

    mask = mask_func(self.holdout_container_)

    # Validate mask
    if not isinstance(mask, np.ndarray):
        mask = np.asarray(mask)

    if mask.shape != self.holdout_pred_labels_.shape:
        raise ValueError(
            f"Mask shape {mask.shape} does not match holdout predictions shape "
            f"{self.holdout_pred_labels_.shape}. The mask must have one boolean value per sample."
        )

    if mask.dtype != np.bool_:
        # Try to convert to boolean
        try:
            mask = mask.astype(np.bool_)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Mask must be boolean or convertible to boolean, got dtype {mask.dtype}") from e

    self.holdout_scoring_mask_ = mask
    return self.holdout_scoring_mask_

get_fold_scores

get_fold_scores() -> list

Get individual fold scores.

Returns:

Type Description
list

List of scores for each cross-validation fold

Raises:

Type Description
ValueError

If no cross-validation scores available

Source code in xdflow/cv/base.py
def get_fold_scores(self) -> list:
    """
    Get individual fold scores.

    Returns:
        List of scores for each cross-validation fold

    Raises:
        ValueError: If no cross-validation scores available
    """
    if not self.cv_scores_:
        raise ValueError("No cross-validation scores available. Run cross_validate() first.")

    return self.cv_scores_.copy()

get_holdout_container

get_holdout_container(initial_container: DataContainer, *, verbose: bool = False) -> DataContainer

Return the holdout trials from the original data container.

This helper returns the raw-space slice referenced by holdout_trial_labels_.

Parameters:

Name Type Description Default
initial_container DataContainer

The original DataContainer provided to cross_validate()/score_on_holdout().

required
verbose bool

Whether to enable verbose logging in transforms.

False

Returns:

Type Description
DataContainer

DataContainer containing only the holdout trials in raw space.

Raises:

Type Description
ValueError

If no holdout data is available (e.g., test_size not set).

Source code in xdflow/cv/base.py
def get_holdout_container(self, initial_container: DataContainer, *, verbose: bool = False) -> DataContainer:
    """
    Return the holdout trials from the original data container.

    This helper returns the raw-space slice referenced by `holdout_trial_labels_`.

    Args:
        initial_container: The original DataContainer provided to cross_validate()/score_on_holdout().
        verbose: Whether to enable verbose logging in transforms.

    Returns:
        DataContainer containing only the holdout trials in raw space.

    Raises:
        ValueError: If no holdout data is available (e.g., test_size not set).
    """
    if self.holdout_trial_labels_ is None:
        self._find_and_fit_encoders(self.pipeline, initial_container)
        stateless_pipeline, _ = self._auto_detect_pipeline_parts(self.pipeline)
        if stateless_pipeline is not None:
            preprocessed_data = stateless_pipeline.fit_transform(initial_container, verbose=verbose)
        else:
            preprocessed_data = initial_container
        _, holdout_indices = self._split_holdout(preprocessed_data)
        self.holdout_trial_labels_ = holdout_indices

    if len(self.holdout_trial_labels_) == 0:
        raise ValueError("No holdout data available for testing.")

    # Select directly from the original container using trial labels
    holdout_da = initial_container.data.sel(trial=self.holdout_trial_labels_)
    return DataContainer(holdout_da)

plot_confusion_matrix

plot_confusion_matrix(use_holdout: bool = True, normalize: bool = True, title_info: str = '', **kwargs)

Plot the confusion matrix.

Note: Only works for classification tasks.

Raises:

Type Description
ValueError

If the pipeline is not a classifier

Source code in xdflow/cv/base.py
def plot_confusion_matrix(self, use_holdout: bool = True, normalize: bool = True, title_info: str = "", **kwargs):
    """
    Plot the confusion matrix.

    Note: Only works for classification tasks.

    Raises:
        ValueError: If the pipeline is not a classifier
    """
    # Check if this is a classification task
    final_predictor = self.pipeline.predictive_transform
    if final_predictor and (not final_predictor.is_classifier or getattr(final_predictor, "is_multilabel", False)):
        raise ValueError("Confusion matrix plotting is only available for classification tasks.")

    if use_holdout:
        conf_matrix = self.holdout_confusion_matrix_normalized_ if normalize else self.holdout_confusion_matrix_
        f1_score = self.holdout_f1_score_
    else:
        conf_matrix = self.oof_confusion_matrix_normalized_ if normalize else self.oof_confusion_matrix_
        f1_score = self.oof_f1_score_

    # Use classes from the final predictor's encoder
    predictive_transform = self.pipeline.predictive_transform
    if predictive_transform is None or predictive_transform.encoder is None:
        raise ValueError("Pipeline predictor must have a fitted encoder before plotting confusion matrix.")
    classes = predictive_transform.encoder.classes_

    if title_info:
        title_info = f"{title_info},"

    plot_confusion_matrix(conf_matrix, classes, title=f"{title_info} F1 score: {f1_score:.4f}", **kwargs)

K-Fold Validators

KFoldValidator

KFoldValidator(n_splits: int = 5, shuffle: bool = True, random_state: int = 0, test_size: float | None = None, pooling_score_weight: float = 0.0, scoring: str | Callable | None = None, stratify_coord: str | None = None, exclude_intertrial_from_scoring: bool = False, exclude_offsets_from_scoring: bool = False, use_stateful_fit_cache: bool = True, release_fold_memory: bool = False, scoring_needs_proba: bool = False, verbose: bool = True)

Bases: CrossValidator

Implements cross-validation using a stratified K-Fold strategy with optional holdout set.

This provides a concrete implementation of CrossValidator using scikit-learn's StratifiedKFold for balanced splits across classes.
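
As a rough illustration of what stratification buys, the sketch below deals the samples of each class round-robin across folds so that every fold sees a similar class mix. It is a simplified pure-Python stand-in, not scikit-learn's StratifiedKFold algorithm and not XDFlow code; the function name is hypothetical.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_splits=5, shuffle=True, random_state=0):
    """Return n_splits lists of validation indices with classes spread evenly."""
    rng = random.Random(random_state)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    for label in sorted(by_class):
        indices = by_class[label]
        if shuffle:
            rng.shuffle(indices)
        # Deal each class round-robin so every fold gets a similar share.
        for position, idx in enumerate(indices):
            folds[position % n_splits].append(idx)
    return [sorted(fold) for fold in folds]


labels = ["a"] * 10 + ["b"] * 5
folds = stratified_folds(labels, n_splits=5, shuffle=True, random_state=0)

for val_idx in folds:
    held_out = set(val_idx)
    train_idx = [i for i in range(len(labels)) if i not in held_out]
    val_classes = sorted(labels[i] for i in val_idx)
    print(len(train_idx), val_classes)
```

Fixing random_state makes the fold assignment reproducible, which matters when fold scores are compared across pipeline variants.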

Initialize KFold cross-validator.

Parameters:

Name Type Description Default
n_splits int

Number of folds for cross-validation

5
shuffle bool

Whether to shuffle data before splitting

True
random_state int

Random seed for reproducibility

0
test_size float | None

Proportion of data to use as holdout test set (0.0-1.0). If None or 0, no holdout set is created.

None
pooling_score_weight float

Interpolation factor between the average fold score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.

0.0
scoring str | Callable | None

Scoring metric to use. If None, auto-selects based on task type.

None
stratify_coord str | None

Optional coordinate name to use for stratified splits (train/val/holdout).

None
exclude_intertrial_from_scoring bool

Whether to drop intertrial segments when evaluating folds/holdout.

False
use_stateful_fit_cache bool

Whether to cache stateful transforms during CV.

True
verbose bool

Whether to print verbose output specific to cross-validation.

True
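
The effect of pooling_score_weight can be sketched as a blend of two summaries of the same folds. The example assumes a linear interpolation, which is the natural reading of "interpolation factor" but is not confirmed by the listing below; the function names and the accuracy metric are illustrative.

```python
from statistics import mean


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def blended_score(folds, pooling_score_weight=0.0):
    """Blend the mean per-fold score with the score over pooled OOF predictions."""
    w = pooling_score_weight
    if not 0.0 <= w <= 1.0:
        raise ValueError("pooling_score_weight must be between 0.0 and 1.0")
    fold_mean = mean(accuracy(t, p) for t, p in folds)
    pooled_true = [t for fold_true, _ in folds for t in fold_true]
    pooled_pred = [p for _, fold_pred in folds for p in fold_pred]
    pooled = accuracy(pooled_true, pooled_pred)
    # Assumed linear interpolation: w=0 -> fold average, w=1 -> pooled score.
    return (1.0 - w) * fold_mean + w * pooled


# Each fold is a (true labels, predicted labels) pair.
folds = [
    ([1, 0], [1, 1]),                          # small fold, accuracy 0.5
    ([1, 1, 0, 0, 1, 0], [1, 1, 0, 0, 1, 0]),  # large fold, accuracy 1.0
]

average_only = blended_score(folds, 0.0)
pooled_only = blended_score(folds, 1.0)
halfway = blended_score(folds, 0.5)
```

The fold average treats both folds equally, while the pooled score weights each trial equally, so the larger fold dominates as the weight approaches 1.0.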
Source code in xdflow/cv/kfold.py
def __init__(
    self,
    n_splits: int = 5,
    shuffle: bool = True,
    random_state: int = 0,
    test_size: float | None = None,
    pooling_score_weight: float = 0.0,
    scoring: str | Callable | None = None,
    stratify_coord: str | None = None,
    exclude_intertrial_from_scoring: bool = False,
    exclude_offsets_from_scoring: bool = False,
    use_stateful_fit_cache: bool = True,
    release_fold_memory: bool = False,
    scoring_needs_proba: bool = False,
    verbose: bool = True,
):
    """
    Initialize KFold cross-validator.

    Args:
        n_splits: Number of folds for cross-validation
        shuffle: Whether to shuffle data before splitting
        random_state: Random seed for reproducibility
        test_size: Proportion of data to use as holdout test set (0.0-1.0).
                  If None or 0, no holdout set is created.
        pooling_score_weight: Interpolation factor between the average fold
            score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.
        scoring: Scoring metric to use. If None, auto-selects based on task type.
        stratify_coord: Optional coordinate name to use for stratified splits (train/val/holdout).
        exclude_intertrial_from_scoring: Whether to drop intertrial segments when evaluating folds/holdout.
        use_stateful_fit_cache: Whether to cache stateful transforms during CV.
        verbose: Whether to print verbose output specific to cross-validation.
    """
    super().__init__(
        pooling_score_weight=pooling_score_weight,
        scoring=scoring,
        scoring_needs_proba=scoring_needs_proba,
        stratify_coord=stratify_coord,
        verbose=verbose,
        use_stateful_fit_cache=use_stateful_fit_cache,
        release_fold_memory=release_fold_memory,
        exclude_intertrial_from_scoring=exclude_intertrial_from_scoring,
        exclude_offsets_from_scoring=exclude_offsets_from_scoring,
    )
    self.n_splits = n_splits
    self.shuffle = shuffle
    self.random_state = random_state
    self.test_size = test_size

GroupedKFoldValidator

GroupedKFoldValidator(n_splits: int = 5, shuffle: bool = True, random_state: int = 0, test_size: float | None = None, pooling_score_weight: float = 0.0, group_coord: str | None = None, train_groups: list[Hashable] | Hashable | None = None, val_groups: list[Hashable] | Hashable | None = None, test_groups: list[Hashable] | Hashable | None = None, scoring: str | Callable | None = None, stratify_coord: str | None = None, stratify_by_group: bool = True, exclude_intertrial_from_scoring: bool = False, exclude_offsets_from_scoring: bool = False, use_stateful_fit_cache: bool = True, release_fold_memory: bool = False, scoring_needs_proba: bool = False, verbose: bool = True)

Bases: CrossValidator

Implements stratified K-fold cross-validation over groups named by the group_coord coordinate. By default (stratify_by_group=True), folds are stratified by the combination of group and target; with stratify_by_group=False, they are stratified by target only (or by stratify_coord if set). Use train_groups, val_groups, and test_groups to restrict training, validation, and testing to specific values of group_coord. Any of these left as None uses all groups.

For example, with group_coord='animal', train_groups=None, val_groups=[35], and test_groups=[35], all animals contribute training data, but only animal 35 is used for validation and testing.

Useful for testing the performance of a model across different groups, especially for domain adaptation.
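
The group restriction and the combined stratification key can be sketched in plain Python. This is an illustration under assumed data and hypothetical helper names; it does not reproduce XDFlow's internal split logic.

```python
def restrict_to_groups(indices, groups, allowed):
    """Keep only indices whose group is in `allowed`; None keeps everything."""
    if allowed is None:
        return list(indices)
    allowed = set(allowed)
    return [i for i in indices if groups[i] in allowed]


def stratification_keys(groups, targets, stratify_by_group=True):
    """Build the per-trial key that folds are balanced on."""
    if stratify_by_group:
        # Balance on the (group, target) combination.
        return list(zip(groups, targets))
    return list(targets)


animals = [35, 35, 36, 36, 37, 37]
targets = ["a", "b", "a", "b", "a", "b"]

# One illustrative fold, before any group restriction is applied.
fold_train = [0, 2, 3, 4]
fold_val = [1, 5]

train = restrict_to_groups(fold_train, animals, None)  # train_groups=None
val = restrict_to_groups(fold_val, animals, [35])      # val_groups=[35]
keys = stratification_keys(animals, targets)
```

Here the fold still trains on animals 35, 36, and 37, while the validation score reflects animal 35 only.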

Initialize GroupedKFoldValidator.

Parameters:

Name Type Description Default
n_splits int

Number of folds for cross-validation

5
shuffle bool

Whether to shuffle data before splitting

True
random_state int

Random seed for reproducibility

0
test_size float | None

Proportion of data to use as holdout test set (0.0-1.0). If None or 0, no holdout set is created.

None
pooling_score_weight float

Interpolation factor between the average fold score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.

0.0
group_coord str | None

Coordinate to group by.

None
train_groups list[Hashable] | Hashable | None

Groups to use for training. If None, all groups are used.

None
val_groups list[Hashable] | Hashable | None

Groups to use for validation. If None, all groups are used.

None
test_groups list[Hashable] | Hashable | None

Groups to use for testing. If None, all groups are used.

None
scoring str | Callable | None

Scoring metric to use. If None, auto-selects based on task type.

None
stratify_coord str | None

Optional coordinate name to use for stratified splits.

None
stratify_by_group bool

Whether to stratify splits by group coordinate in addition to target. If True (default), stratifies by group+target combination. If False, only stratifies by target (or stratify_coord if set).

True
exclude_intertrial_from_scoring bool

Whether to drop intertrial segments during evaluation.

False
use_stateful_fit_cache bool

Whether to cache stateful transforms during CV.

True
verbose bool

Whether to print verbose output specific to cross-validation.

True
Source code in xdflow/cv/kfold.py
def __init__(
    self,
    n_splits: int = 5,
    shuffle: bool = True,
    random_state: int = 0,
    test_size: float | None = None,
    pooling_score_weight: float = 0.0,
    group_coord: str | None = None,
    train_groups: list[Hashable] | Hashable | None = None,
    val_groups: list[Hashable] | Hashable | None = None,
    test_groups: list[Hashable] | Hashable | None = None,
    scoring: str | Callable | None = None,
    stratify_coord: str | None = None,
    stratify_by_group: bool = True,
    exclude_intertrial_from_scoring: bool = False,
    exclude_offsets_from_scoring: bool = False,
    use_stateful_fit_cache: bool = True,
    release_fold_memory: bool = False,
    scoring_needs_proba: bool = False,
    verbose: bool = True,
):
    """
    Initialize GroupedKFoldValidator.

    Args:
        n_splits: Number of folds for cross-validation
        shuffle: Whether to shuffle data before splitting
        random_state: Random seed for reproducibility
        test_size: Proportion of data to use as holdout test set (0.0-1.0).
                  If None or 0, no holdout set is created.
        pooling_score_weight: Interpolation factor between the average fold
            score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.
        group_coord: Coordinate to group by.
        train_groups: Groups to use for training. If None, all groups are used.
        val_groups: Groups to use for validation. If None, all groups are used.
        test_groups: Groups to use for testing. If None, all groups are used.
        scoring: Scoring metric to use. If None, auto-selects based on task type.
        stratify_coord: Optional coordinate name to use for stratified splits.
        stratify_by_group: Whether to stratify splits by group coordinate in addition to target.
                          If True (default), stratifies by group+target combination.
                          If False, only stratifies by target (or stratify_coord if set).
        exclude_intertrial_from_scoring: Whether to drop intertrial segments during evaluation.
        use_stateful_fit_cache: Whether to cache stateful transforms during CV.
        verbose: Whether to print verbose output specific to cross-validation.
    """
    super().__init__(
        pooling_score_weight=pooling_score_weight,
        scoring=scoring,
        scoring_needs_proba=scoring_needs_proba,
        use_stateful_fit_cache=use_stateful_fit_cache,
        release_fold_memory=release_fold_memory,
        exclude_intertrial_from_scoring=exclude_intertrial_from_scoring,
        exclude_offsets_from_scoring=exclude_offsets_from_scoring,
        stratify_coord=stratify_coord,
        verbose=verbose,
    )
    self.n_splits = n_splits
    self.shuffle = shuffle
    self.random_state = random_state
    self.test_size = test_size
    self.group_coord = group_coord
    self.stratify_by_group = stratify_by_group

    self.train_groups = _as_group_list(train_groups)
    self.val_groups = _as_group_list(val_groups)
    self.test_groups = _as_group_list(test_groups)

Domain Sampling

SampledDomainKFoldValidator

SampledDomainKFoldValidator(*, domain_coord: str, target_domains: list[Hashable] | Hashable, source_domains: list[Hashable] | Hashable | None = None, label_coord: str | None = None, label_sample_counts: Mapping[Hashable, int | None] | None = None, default_samples_per_label: int | None = None, n_splits: int = 5, shuffle: bool = True, random_state: int = 0, test_size: float | None = None, pooling_score_weight: float = 0.0, scoring: str | Callable | None = None, scoring_needs_proba: bool = False, stratify_coord: str | None = None, exclude_intertrial_from_scoring: bool = False, exclude_offsets_from_scoring: bool = False, use_stateful_fit_cache: bool = True, release_fold_memory: bool = False, verbose: bool = True)

Bases: CrossValidator

K-fold validation on target domains with sampled target-domain training trials.

Splits are created on target-domain trials only. For each fold:

- validation contains one fold of target-domain trials
- training contains all source-domain trials plus a sampled subset of the remaining target-domain trials

Target-domain sampling is label-conditional. Use label_sample_counts for per-label overrides and default_samples_per_label for all other labels. A count of 0 means zero-shot for that label; None means use all available target training samples for that label.
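
The sampling rule can be sketched as follows. The function is a hypothetical stand-in for the validator's internal sampler: the seeding scheme and the choice to cap a count at the number of available trials are assumptions.

```python
import random
from collections import defaultdict


def sample_target_trials(trials, labels, label_sample_counts=None,
                         default_samples_per_label=None, random_state=0):
    """Pick target-domain training trials separately for each label."""
    rng = random.Random(random_state)
    counts = dict(label_sample_counts or {})
    by_label = defaultdict(list)
    for trial, label in zip(trials, labels):
        by_label[label].append(trial)

    sampled = []
    for label, pool in by_label.items():
        # A per-label override wins; otherwise fall back to the default.
        count = counts.get(label, default_samples_per_label)
        if count is None:
            sampled.extend(pool)  # None: use every available trial
        else:
            # 0 gives an empty sample (zero-shot); cap at the pool size (assumption).
            sampled.extend(rng.sample(pool, min(count, len(pool))))
    return sorted(sampled)


target_trials = list(range(8))
target_labels = ["go"] * 4 + ["stop"] * 3 + ["rest"]
source_trials = [100, 101, 102]

sampled = sample_target_trials(
    target_trials,
    target_labels,
    label_sample_counts={"stop": 0, "rest": None},
    default_samples_per_label=2,
)
# Training set for the fold: all source trials plus the sampled target trials.
train_trials = source_trials + sampled
```

In this configuration "go" is few-shot (two trials), "stop" is zero-shot, and "rest" keeps its single available trial.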

Source code in xdflow/cv/domain.py
def __init__(
    self,
    *,
    domain_coord: str,
    target_domains: list[Hashable] | Hashable,
    source_domains: list[Hashable] | Hashable | None = None,
    label_coord: str | None = None,
    label_sample_counts: Mapping[Hashable, int | None] | None = None,
    default_samples_per_label: int | None = None,
    n_splits: int = 5,
    shuffle: bool = True,
    random_state: int = 0,
    test_size: float | None = None,
    pooling_score_weight: float = 0.0,
    scoring: str | Callable | None = None,
    scoring_needs_proba: bool = False,
    stratify_coord: str | None = None,
    exclude_intertrial_from_scoring: bool = False,
    exclude_offsets_from_scoring: bool = False,
    use_stateful_fit_cache: bool = True,
    release_fold_memory: bool = False,
    verbose: bool = True,
):
    super().__init__(
        pooling_score_weight=pooling_score_weight,
        scoring=scoring,
        scoring_needs_proba=scoring_needs_proba,
        stratify_coord=stratify_coord,
        verbose=verbose,
        use_stateful_fit_cache=use_stateful_fit_cache,
        release_fold_memory=release_fold_memory,
        exclude_intertrial_from_scoring=exclude_intertrial_from_scoring,
        exclude_offsets_from_scoring=exclude_offsets_from_scoring,
    )
    self.domain_coord = domain_coord
    self.target_domains = target_domains if isinstance(target_domains, list) else [target_domains]
    self.source_domains = (
        None
        if source_domains is None
        else (source_domains if isinstance(source_domains, list) else [source_domains])
    )
    self.label_coord = label_coord
    self.label_sample_counts = dict(label_sample_counts or {})
    self.default_samples_per_label = default_samples_per_label
    self.n_splits = n_splits
    self.shuffle = shuffle
    self.random_state = random_state
    self.test_size = test_size

    if not self.target_domains:
        raise ValueError("target_domains must be provided and non-empty.")
    for label, count in self.label_sample_counts.items():
        if count is not None and count < 0:
            raise ValueError(f"label_sample_counts[{label!r}] must be >= 0 or None.")
    if self.default_samples_per_label is not None and self.default_samples_per_label < 0:
        raise ValueError("default_samples_per_label must be >= 0 or None.")

score_on_holdout

score_on_holdout(initial_container: DataContainer, verbose: bool = False) -> float

Fit and score on the target-domain holdout using the validator sampling policy.

Unlike the base KFoldValidator holdout path, the final training set is not all non-holdout trials. It is all source-domain trials plus the same label-conditional sampled subset of non-holdout target-domain trials used during cross-validation. This keeps holdout scoring aligned with the few-shot/zero-shot transfer regime configured for the validator.

Source code in xdflow/cv/domain.py
def score_on_holdout(self, initial_container: DataContainer, verbose: bool = False) -> float:
    """Fit and score on the target-domain holdout using the validator sampling policy.

    Unlike the base ``KFoldValidator`` holdout path, the final training set is not
    all non-holdout trials. It is all source-domain trials plus the same
    label-conditional sampled subset of non-holdout target-domain trials used
    during cross-validation. This keeps holdout scoring aligned with the
    few-shot/zero-shot transfer regime configured for the validator.
    """
    self._find_and_fit_encoders(self.pipeline, initial_container)

    stateless_pipeline, stateful_pipeline = self._auto_detect_pipeline_parts(self.pipeline)
    assert stateful_pipeline is not None, "There must be at least one stateful step in the pipeline, for fitting."

    if stateless_pipeline is not None:
        preprocessed_data = stateless_pipeline.fit_transform(initial_container, verbose=verbose)
    else:
        preprocessed_data = initial_container

    if self.holdout_trial_labels_ is None:
        warnings.warn(
            "cross_validate() not called first so no holdout indices available. Calculating holdout indices now.",
            stacklevel=2,
        )
        _, holdout_indices = self._split_holdout(preprocessed_data)
        self.holdout_trial_labels_ = holdout_indices

    if len(self.holdout_trial_labels_) == 0:
        raise ValueError("No holdout data available for testing.")

    label_coord = self._resolve_label_coord(preprocessed_data)
    labels_all = preprocessed_data.data.coords[label_coord].values
    _, source_mask, target_mask = self._resolve_domain_masks(preprocessed_data)
    target_trials = preprocessed_data.data.trial.values[target_mask]
    source_trials = preprocessed_data.data.trial.values[source_mask]
    target_labels = labels_all[target_mask]

    train_target_mask = ~np.isin(target_trials, self.holdout_trial_labels_)
    sampled_target = self._sample_target_train_indices(
        target_trials[train_target_mask], target_labels[train_target_mask], fold_idx=0
    )
    train_val_indices = np.concatenate([source_trials, sampled_target])
    if train_val_indices.size == 0:
        raise ValueError("Training set is empty after sampling. Check source_domains and sampling settings.")

    train_val_container = DataContainer(preprocessed_data.data.sel(trial=train_val_indices))
    test_container = DataContainer(preprocessed_data.data.sel(trial=self.holdout_trial_labels_))

    stateful_pipeline_fitted = stateful_pipeline.clone()
    stateful_pipeline_fitted.fit(train_val_container, verbose=verbose)

    test_results_container = stateful_pipeline_fitted.predict(test_container, verbose=verbose)
    scoring_func, _, needs_proba = self._get_scoring_func()
    holdout_probabilities = None
    if needs_proba:
        holdout_probabilities = stateful_pipeline_fitted.predict_proba(test_container, verbose=verbose).data.values

    final_predictor = stateful_pipeline_fitted.predictive_transform
    if final_predictor is None:
        raise ValueError("Stateful pipeline must expose a predictive transform.")
    pred_labels = test_results_container.data.values
    true_labels = self._extract_targets(cast("Predictor", final_predictor), test_results_container)

    pred_labels, true_labels, scoring_container, scoring_mask = self._filter_scoring_inputs(
        pred_labels,
        true_labels,
        test_results_container,
        context="holdout",
    )
    scoring_values = pred_labels
    if needs_proba:
        if holdout_probabilities is None:
            raise RuntimeError("Scoring requires probabilities, but none were produced.")
        scoring_values = holdout_probabilities if scoring_mask is None else holdout_probabilities[scoring_mask]

    self.holdout_container_ = scoring_container
    self.holdout_pred_labels_ = pred_labels
    self.holdout_probabilities_ = scoring_values if needs_proba else holdout_probabilities
    self.holdout_true_labels_ = true_labels
    self.holdout_scoring_mask_ = None

    if self._scoring_accepts_container:
        self.holdout_score_ = scoring_func(self.holdout_true_labels_, scoring_values, scoring_container)
    else:
        self.holdout_score_ = scoring_func(self.holdout_true_labels_, scoring_values)

    return self.holdout_score_
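
The final branch above calls the scorer with either two or three arguments. How XDFlow decides this (the _scoring_accepts_container flag) is not shown on this page; one plausible approach, sketched here as an assumption, is to inspect the scorer's signature. The dict used as a container is a stand-in for a DataContainer, and all function names are illustrative.

```python
import inspect
from collections import defaultdict


def accepts_container(scorer):
    """Return True if the scorer can take a third positional argument."""
    params = list(inspect.signature(scorer).parameters.values())
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    n_positional = sum(p.kind in positional_kinds for p in params)
    has_varargs = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params)
    return has_varargs or n_positional >= 3


def call_scorer(scorer, y_true, y_pred, container):
    """Pass the container only to scorers that can receive it."""
    if accepts_container(scorer):
        return scorer(y_true, y_pred, container)
    return scorer(y_true, y_pred)


def plain_accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_session_accuracy(y_true, y_pred, container):
    """Average the accuracy computed within each session."""
    hits = defaultdict(list)
    for t, p, session in zip(y_true, y_pred, container["session"]):
        hits[session].append(t == p)
    return sum(sum(h) / len(h) for h in hits.values()) / len(hits)


container = {"session": ["s1", "s2", "s2", "s2"]}  # stand-in for coordinates
y_true = [1, 0, 1, 1]
y_pred = [1, 0, 0, 1]

plain = call_scorer(plain_accuracy, y_true, y_pred, container)
by_session = call_scorer(per_session_accuracy, y_true, y_pred, container)
```

A signature check like this would misfire on scorers whose third positional parameter means something else, so an explicit flag on the validator may be the safer design.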

Leave-Group-Out Validators

LeaveGroupOutValidator

LeaveGroupOutValidator(group_coord: str, test_group_ids: list[Hashable] | None = None, validation_group_ids: list[Hashable] | None = None, pooling_score_weight: float = 0.0, scoring: str | Callable | None = None, n_splits: int | None = None, random_state: int = 0, exclude_intertrial_from_scoring: bool = False, exclude_offsets_from_scoring: bool = False, use_stateful_fit_cache: bool = True, release_fold_memory: bool = False, scoring_needs_proba: bool = False, verbose: bool = True)

Bases: CrossValidator

Implements cross-validation by leaving one or more groups out at a time with optional holdout groups.

This validator holds out each group (or each fold of groups) as the validation set once, while all other groups are used for training. This is critical for assessing how well a model generalizes to new, unseen groups.

When n_splits is not set, one group is used for validation at a time. When n_splits is set, the groups are split into n_splits folds.
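
The two modes can be sketched as follows. How XDFlow partitions groups when n_splits is set is not shown here; the shuffled round-robin partition below is an assumption, and the function names are hypothetical.

```python
import random


def leave_group_out_folds(groups, n_splits=None, random_state=0):
    """Return one list of validation groups per fold."""
    unique = sorted(set(groups))
    if n_splits is None:
        return [[g] for g in unique]  # one group held out per fold
    if n_splits < 2:
        raise ValueError("n_splits must be >= 2 if set.")
    if n_splits > len(unique):
        raise ValueError("n_splits cannot exceed the number of groups.")
    rng = random.Random(random_state)
    rng.shuffle(unique)
    # Deal shuffled groups round-robin so fold sizes differ by at most one.
    return [unique[k::n_splits] for k in range(n_splits)]


def split_indices(groups, val_groups):
    """Split trial indices into (train, validation) by group membership."""
    val_groups = set(val_groups)
    train = [i for i, g in enumerate(groups) if g not in val_groups]
    val = [i for i, g in enumerate(groups) if g in val_groups]
    return train, val


sessions = ["s1", "s1", "s2", "s3", "s3", "s4"]

one_per_fold = leave_group_out_folds(sessions)             # 4 folds
two_folds = leave_group_out_folds(sessions, n_splits=2)    # 2 folds of 2 groups
train_idx, val_idx = split_indices(sessions, one_per_fold[1])
```

Because whole groups move between training and validation, no trial from a validation group can leak into the training set.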

Initialize Leave-One-Group-Out cross-validator.

Parameters:

Name Type Description Default
group_coord str

Coordinate to group by.

required
test_group_ids list[Hashable] | None

List of group IDs to use as final holdout test set. If None or empty, no holdout set is created.

None
validation_group_ids list[Hashable] | None

List of group IDs to use as validation set. If None or empty, no validation set is created.

None
pooling_score_weight float

Interpolation factor between the average fold score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.

0.0
scoring str | Callable | None

Scoring metric to use. If None, auto-selects based on task type.

None
n_splits int | None

Number of folds to split the groups into; must be at least 2 if set. If None, each group is held out in its own fold.

None
random_state int

Random state for reproducibility. Used for shuffling groups if n_splits is set.

0
exclude_intertrial_from_scoring bool

Whether to drop intertrial segments during evaluation.

False
use_stateful_fit_cache bool

Whether to cache stateful transforms during CV.

True
verbose bool

Whether to print verbose output specific to cross-validation.

True
Source code in xdflow/cv/leave_group_out.py
def __init__(
    self,
    group_coord: str,
    test_group_ids: list[Hashable] | None = None,
    validation_group_ids: list[Hashable] | None = None,
    pooling_score_weight: float = 0.0,
    scoring: str | Callable | None = None,
    n_splits: int | None = None,
    random_state: int = 0,
    exclude_intertrial_from_scoring: bool = False,
    exclude_offsets_from_scoring: bool = False,
    use_stateful_fit_cache: bool = True,
    release_fold_memory: bool = False,
    scoring_needs_proba: bool = False,
    verbose: bool = True,
):
    """
    Initialize Leave-One-Group-Out cross-validator.

    Args:
        group_coord: Coordinate to group by.
        test_group_ids: List of group IDs to use as final holdout test set.
                         If None or empty, no holdout set is created.
        validation_group_ids: List of group IDs to use as validation set.
                         If None or empty, no validation set is created.
        pooling_score_weight: Interpolation factor between the average fold
            score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.
        scoring: Scoring metric to use. If None, auto-selects based on task type.
        n_splits: Number of folds to split the groups into; must be at least 2 if set. If None, each group is held out in its own fold.
        random_state: Random state for reproducibility. Used for shuffling groups if n_splits is set.
        exclude_intertrial_from_scoring: Whether to drop intertrial segments during evaluation.
        use_stateful_fit_cache: Whether to cache stateful transforms during CV.
        verbose: Whether to print verbose output specific to cross-validation.
    """
    super().__init__(
        pooling_score_weight=pooling_score_weight,
        scoring=scoring,
        scoring_needs_proba=scoring_needs_proba,
        verbose=verbose,
        use_stateful_fit_cache=use_stateful_fit_cache,
        release_fold_memory=release_fold_memory,
        exclude_intertrial_from_scoring=exclude_intertrial_from_scoring,
        exclude_offsets_from_scoring=exclude_offsets_from_scoring,
    )
    self.group_coord = group_coord
    self.test_group_ids = test_group_ids or []
    self.validation_group_ids = validation_group_ids or []
    self.n_splits = n_splits
    if self.n_splits is not None and self.n_splits < 2:
        raise ValueError("n_splits must be >= 2 if set.")
    self.random_state = random_state

    # validation_group_ids should not be in test_group_ids
    if set(self.validation_group_ids) & set(self.test_group_ids):
        raise ValueError("Validation group IDs and test group IDs must not overlap.")

LeaveSessionOutValidator

LeaveSessionOutValidator(test_session_ids: list[Hashable] | None = None, validation_session_ids: list[Hashable] | None = None, pooling_score_weight: float = 0.0, scoring: str | Callable | None = None, n_splits: int | None = None, random_state: int = 0, exclude_intertrial_from_scoring: bool = False, exclude_offsets_from_scoring: bool = False, use_stateful_fit_cache: bool = True, release_fold_memory: bool = False, scoring_needs_proba: bool = False, verbose: bool = True)

Bases: LeaveGroupOutValidator

Implements cross-validation by leaving one or more sessions out at a time with optional holdout sessions.

This validator holds out each session (or each fold of sessions) as the validation set once, while all other sessions are used for training. This is critical for assessing how well a model generalizes to new, unseen sessions.

Note: This is a convenience wrapper around LeaveGroupOutValidator with group_coord="session".
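
The wrapper pattern can be illustrated with two toy classes. They mirror only the delegation and attribute aliasing described here; the class names are hypothetical and most constructor options are omitted.

```python
class GroupOutSketch:
    """Toy stand-in for a validator that splits by an arbitrary coordinate."""

    def __init__(self, group_coord, test_group_ids=None, validation_group_ids=None):
        self.group_coord = group_coord
        self.test_group_ids = test_group_ids or []
        self.validation_group_ids = validation_group_ids or []
        if set(self.validation_group_ids) & set(self.test_group_ids):
            raise ValueError("Validation group IDs and test group IDs must not overlap.")


class SessionOutSketch(GroupOutSketch):
    """Fixes group_coord to 'session' and exposes session-named attributes."""

    def __init__(self, test_session_ids=None, validation_session_ids=None):
        super().__init__(
            group_coord="session",
            test_group_ids=test_session_ids,
            validation_group_ids=validation_session_ids,
        )
        # Aliases refer to the same list objects as the generic attributes.
        self.test_session_ids = self.test_group_ids
        self.validation_session_ids = self.validation_group_ids


validator = SessionOutSketch(test_session_ids=["s9"])
```

The aliases are bound once at construction, so reassigning test_group_ids later would not update test_session_ids; in-place changes to the shared list are visible through both names.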

Initialize Leave-Session-Out cross-validator.

Parameters:

Name Type Description Default
test_session_ids list[Hashable] | None

List of session IDs to use as final holdout test set. If None or empty, no holdout set is created.

None
validation_session_ids list[Hashable] | None

List of session IDs to use as validation set. If None or empty, no validation set is created.

None
pooling_score_weight float

Interpolation factor between the average fold score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.

0.0
scoring str | Callable | None

Scoring metric to use. If None, auto-selects based on task type.

None
n_splits int | None

Number of folds to split the sessions into; must be at least 2 if set. If None, each session is held out in its own fold.

None
random_state int

Random state for reproducibility. Used for shuffling sessions if n_splits is set.

0
exclude_intertrial_from_scoring bool

Whether to drop intertrial segments during evaluation.

False
use_stateful_fit_cache bool

Whether to cache stateful transforms during CV.

True
verbose bool

Whether to print verbose output specific to cross-validation.

True
Source code in xdflow/cv/leave_group_out.py
def __init__(
    self,
    test_session_ids: list[Hashable] | None = None,
    validation_session_ids: list[Hashable] | None = None,
    pooling_score_weight: float = 0.0,
    scoring: str | Callable | None = None,
    n_splits: int | None = None,
    random_state: int = 0,
    exclude_intertrial_from_scoring: bool = False,
    exclude_offsets_from_scoring: bool = False,
    use_stateful_fit_cache: bool = True,
    release_fold_memory: bool = False,
    scoring_needs_proba: bool = False,
    verbose: bool = True,
):
    """
    Initialize Leave-Session-Out cross-validator.

    Args:
        test_session_ids: List of session IDs to use as final holdout test set.
                         If None or empty, no holdout set is created.
        validation_session_ids: List of session IDs to use as validation set.
                         If None or empty, no validation set is created.
        pooling_score_weight: Interpolation factor between the average fold
            score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.
        scoring: Scoring metric to use. If None, auto-selects based on task type.
        n_splits: Number of folds to split the sessions into; must be at least 2 if set. If None, each session is held out in its own fold.
        random_state: Random state for reproducibility. Used for shuffling sessions if n_splits is set.
        exclude_intertrial_from_scoring: Whether to drop intertrial segments during evaluation.
        use_stateful_fit_cache: Whether to cache stateful transforms during CV.
        verbose: Whether to print verbose output specific to cross-validation.
    """
    # Delegate to LeaveGroupOutValidator with group_coord="session"
    super().__init__(
        group_coord="session",
        test_group_ids=test_session_ids,
        validation_group_ids=validation_session_ids,
        pooling_score_weight=pooling_score_weight,
        scoring=scoring,
        n_splits=n_splits,
        random_state=random_state,
        exclude_intertrial_from_scoring=exclude_intertrial_from_scoring,
        exclude_offsets_from_scoring=exclude_offsets_from_scoring,
        use_stateful_fit_cache=use_stateful_fit_cache,
        release_fold_memory=release_fold_memory,
        scoring_needs_proba=scoring_needs_proba,
        verbose=verbose,
    )
    # Maintain backward-compatible attribute names for introspection
    self.test_session_ids = self.test_group_ids
    self.validation_session_ids = self.validation_group_ids

LeaveAnimalOutValidator

LeaveAnimalOutValidator(test_animal_ids: list[Hashable] | None = None, validation_animal_ids: list[Hashable] | None = None, pooling_score_weight: float = 0.0, scoring: str | Callable | None = None, n_splits: int | None = None, random_state: int = 0, exclude_intertrial_from_scoring: bool = False, exclude_offsets_from_scoring: bool = False, use_stateful_fit_cache: bool = True, release_fold_memory: bool = False, scoring_needs_proba: bool = False, verbose: bool = True)

Bases: LeaveGroupOutValidator

Implements cross-validation by leaving one or more animals out at a time with optional holdout animals.

This validator holds out each animal (or each fold of animals) as the validation set once, while all other animals are used for training. This is critical for assessing how well a model generalizes to new, unseen animals.

Note: This is a convenience wrapper around LeaveGroupOutValidator with group_coord="animal".

Initialize Leave-Animal-Out cross-validator.

Parameters:

Name Type Description Default
test_animal_ids list[Hashable] | None

List of animal IDs to use as final holdout test set. If None or empty, no holdout set is created.

None
validation_animal_ids list[Hashable] | None

List of animal IDs to use as validation set. If None or empty, no validation set is created.

None
pooling_score_weight float

Interpolation factor between the average fold score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.

0.0
scoring str | Callable | None

Scoring metric to use. If None, auto-selects based on task type.

None
n_splits int | None

Number of folds to split the animals into; must be at least 2 if set. If None, each animal is held out in its own fold.

None
random_state int

Random state for reproducibility. Used for shuffling animals if n_splits is set.

0
exclude_intertrial_from_scoring bool

Whether to drop intertrial segments during evaluation.

False
use_stateful_fit_cache bool

Whether to cache stateful transforms during CV.

True
verbose bool

Whether to print verbose output specific to cross-validation.

True
Source code in xdflow/cv/leave_group_out.py
def __init__(
    self,
    test_animal_ids: list[Hashable] | None = None,
    validation_animal_ids: list[Hashable] | None = None,
    pooling_score_weight: float = 0.0,
    scoring: str | Callable | None = None,
    n_splits: int | None = None,
    random_state: int = 0,
    exclude_intertrial_from_scoring: bool = False,
    exclude_offsets_from_scoring: bool = False,
    use_stateful_fit_cache: bool = True,
    release_fold_memory: bool = False,
    scoring_needs_proba: bool = False,
    verbose: bool = True,
):
    """
    Initialize Leave-Animal-Out cross-validator.

    Args:
        test_animal_ids: List of animal IDs to use as final holdout test set.
                         If None or empty, no holdout set is created.
        validation_animal_ids: List of animal IDs to use as validation set.
                         If None or empty, no validation set is created.
        pooling_score_weight: Interpolation factor between the average fold
            score (0.0) and the pooled OOF score (1.0). Defaults to 0.0.
        scoring: Scoring metric to use. If None, auto-selects based on task type.
        n_splits: Total number of splits to perform. If None, all animals are used.
        random_state: Random state for reproducibility. Used for shuffling sessions if n_splits is set.
        exclude_intertrial_from_scoring: Whether to drop intertrial segments during evaluation.
        use_stateful_fit_cache: Whether to cache stateful transforms during CV.
        verbose: Whether to print verbose output specific to cross-validation.
    """
    # Delegate to LeaveGroupOutValidator with group_coord="animal"
    super().__init__(
        group_coord="animal",
        test_group_ids=test_animal_ids,
        validation_group_ids=validation_animal_ids,
        pooling_score_weight=pooling_score_weight,
        scoring=scoring,
        n_splits=n_splits,
        random_state=random_state,
        exclude_intertrial_from_scoring=exclude_intertrial_from_scoring,
        exclude_offsets_from_scoring=exclude_offsets_from_scoring,
        use_stateful_fit_cache=use_stateful_fit_cache,
        release_fold_memory=release_fold_memory,
        scoring_needs_proba=scoring_needs_proba,
        verbose=verbose,
    )

Sklearn Adapter

SklearnCVAdapter

SklearnCVAdapter(cross_validator)

Bases: BaseCrossValidator

Adapter that converts a CrossValidator to an sklearn-compatible CV splitter.

This allows using custom CrossValidator classes (LeaveGroupOutValidator, etc.) with sklearn models that accept a cv parameter (LogisticRegressionCV, RidgeCV, etc.).

The adapter uses a context variable to receive the DataContainer during fit(), allowing it to work in nested CV scenarios where the container may change. Normal pipeline usage (including tuning) needs no extra setup, because SKLearnTransform automatically wraps estimator.fit with set_cv_container when it detects a SklearnCVAdapter. For standalone sklearn usage, enter the set_cv_container context manager explicitly before calling fit().

Initialize the adapter.

Parameters

cross_validator : CrossValidator
    The custom cross-validator to adapt.

Source code in xdflow/cv/sklearn_adapter.py
def __init__(self, cross_validator):
    """
    Initialize the adapter.

    Parameters
    ----------
    cross_validator : CrossValidator
        The custom cross validator to adapt
    """
    self.cross_validator = cross_validator
    self._n_splits = None

split

split(X, y=None, groups=None)

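Generate 0-based index arrays that split the data into training and validation sets, one pair per fold. The DataContainer is read from the context variable rather than from X, y, or groups; a ValueError is raised if no container has been set.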

Source code in xdflow/cv/sklearn_adapter.py
def split(self, X, y=None, groups=None):
    """
    Generate indices to split data into training and test set.
    """
    # Get container from context variable
    container = _current_container.get()
    if container is None:
        raise ValueError(
            "No container found in context. Use set_cv_container(container) context manager before calling fit()."
        )

    # Get all trial indices (for non-holdout splits)
    all_trials = container.data.trial.values

    # Get train/val split (excluding holdout)
    train_val_indices, _ = self.cross_validator._split_holdout(container)

    # Create mapping from trial IDs to 0-based indices
    trial_to_idx = {trial: idx for idx, trial in enumerate(all_trials)}

    # Get splits from the cross validator
    splits = self.cross_validator._get_splits(container, train_val_indices)

    n_splits = 0
    for train_trials, val_trials in splits:
        # Convert trial IDs to 0-based indices
        train_indices = np.array([trial_to_idx[t] for t in train_trials])
        val_indices = np.array([trial_to_idx[t] for t in val_trials])

        n_splits += 1
        yield train_indices, val_indices

    self._n_splits = n_splits
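The ID-to-position conversion in split can be seen in isolation in the sketch below. It is a simplified stand-in: `trial_folds_to_positions` is a hypothetical helper, the trial IDs and folds are made up, and plain lists are used where the adapter builds NumPy arrays.

```python
def trial_folds_to_positions(all_trials, trial_folds):
    """Convert folds expressed as trial IDs into folds of 0-based positions.

    Hypothetical helper mirroring the mapping step; not part of XDFlow.
    """
    # Position of each trial ID in the order sklearn sees the rows.
    trial_to_idx = {trial: idx for idx, trial in enumerate(all_trials)}
    for train_trials, val_trials in trial_folds:
        yield (
            [trial_to_idx[t] for t in train_trials],
            [trial_to_idx[t] for t in val_trials],
        )


# Trial IDs need not be contiguous or start at zero.
all_trials = [101, 102, 205, 206, 310]
trial_folds = [
    ([205, 206, 310], [101, 102]),
    ([101, 102, 310], [205, 206]),
]
positions = list(trial_folds_to_positions(all_trials, trial_folds))
```

sklearn estimators index rows by position, so the mapping is what lets a split policy defined on trial coordinates drive an estimator that only sees arrays.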

get_n_splits

get_n_splits(X=None, y=None, groups=None)

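Returns the number of splitting iterations in the cross-validator. The count is cached after the first full pass through split() or the first call to get_n_splits(); until then a DataContainer must be available in the context variable.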

Source code in xdflow/cv/sklearn_adapter.py
def get_n_splits(self, X=None, y=None, groups=None):
    """
    Returns the number of splitting iterations in the cross-validator.
    """
    # If we've already done a split, return cached value
    if self._n_splits is not None:
        return self._n_splits

    # Get container from context variable
    container = _current_container.get()
    if container is None:
        raise ValueError(
            "No container found in context. Use set_cv_container(container) "
            "context manager before calling get_n_splits()."
        )

    # Count the splits
    train_val_indices, _ = self.cross_validator._split_holdout(container)
    splits = list(self.cross_validator._get_splits(container, train_val_indices))
    self._n_splits = len(splits)

    return self._n_splits
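The caching behaviour in the listings above has one subtlety worth seeing in isolation: split stores the count only after its generator has been fully consumed. The sketch below reproduces that bookkeeping with a hypothetical `CountingSplitter` stand-in that takes a fixed list of folds instead of a container.

```python
class CountingSplitter:
    """Stand-in mimicking the adapter's split-count bookkeeping (not XDFlow API)."""

    def __init__(self, folds):
        self._folds = folds
        self._n_splits = None

    def split(self):
        n_splits = 0
        for fold in self._folds:
            n_splits += 1
            yield fold
        # Reached only once the caller has exhausted the generator.
        self._n_splits = n_splits

    def get_n_splits(self):
        # Return the cached count when available, otherwise count once.
        if self._n_splits is None:
            self._n_splits = len(list(self._folds))
        return self._n_splits


splitter = CountingSplitter([("a", "b"), ("c", "d"), ("e", "f")])
gen = splitter.split()
first_fold = next(gen)
cached_after_partial = splitter._n_splits  # still None: generator not exhausted
remaining = list(gen)
cached_after_full = splitter._n_splits     # 3
```

Because the cached value is returned whenever it is set, an adapter instance that is reused with a different container may report the earlier count from get_n_splits until split is run to completion again.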

set_cv_container

set_cv_container(container: DataContainer)

Context manager to set the DataContainer for SklearnCVAdapter.

This must be used when fitting sklearn models that use SklearnCVAdapter for cross-validation, unless the estimator is wrapped in SKLearnTransform (which will set the context automatically).

Source code in xdflow/cv/sklearn_adapter.py
@contextmanager
def set_cv_container(container: DataContainer):
    """
    Context manager to set the DataContainer for SklearnCVAdapter.

    This must be used when fitting sklearn models that use SklearnCVAdapter
    for cross-validation, unless the estimator is wrapped in SKLearnTransform
    (which will set the context automatically).
    """
    token = _current_container.set(container)
    try:
        yield
    finally:
        _current_container.reset(token)
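The token-based set and reset pattern is what makes nesting safe: leaving an inner context restores the outer container instead of clearing it. The sketch below shows this with the standard library only. The names `_current` and `set_container` are stand-ins, strings replace DataContainer objects, and a default of None is an assumption based on the `is None` checks in the adapter.

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Stand-in for the module-level context variable; default of None is assumed.
_current = ContextVar("current_container", default=None)


@contextmanager
def set_container(container):
    """Set the active container for the duration of the with-block."""
    token = _current.set(container)
    try:
        yield
    finally:
        # Restore whatever value was active before this block.
        _current.reset(token)


seen = []
with set_container("outer"):
    seen.append(_current.get())
    with set_container("inner"):  # e.g. an inner CV loop on a data subset
        seen.append(_current.get())
    seen.append(_current.get())   # outer container is active again
seen.append(_current.get())       # back to the default outside all blocks
```

The `finally` clause ensures the previous value is restored even if fitting raises inside the block.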