Core API

Core APIs define the data container and transform contracts used throughout XDFlow.

For extension-oriented examples, see Writing Custom Transforms and Writing Custom Cross-Validators.

Top-Level Exports

The root package exposes the main workflow primitives:

from xdflow import DataContainer, Transform, Predictor, Pipeline, CrossValidator

Tuner is conditionally exported when the tuning extra is installed.

Data Containers

DataContainer

DataContainer(data: DataArray, required_coords: list[str] | None = None)

Thin framework wrapper around an xarray.DataArray.

XDFlow's data model is xarray. DataContainer is not a parallel array abstraction; it is the object passed between transforms, predictors, and cross-validation utilities so the framework has a consistent boundary. The wrapped xarray.DataArray remains the source of truth for values, dimensions, coordinates, and attrs.

The wrapper initializes the data_history attribute used to track pipeline operations and rewraps common xarray operations so chained calls stay inside the XDFlow transform contract.

Most xarray methods can be called directly on the container. Methods that return a new xarray.DataArray are rewrapped as a new DataContainer, so calls such as container.sel(...) or container.mean(...) remain inside the XDFlow container contract.

The wrapped array is shallow-copied on construction. Transforms should still treat containers as immutable and return new containers instead of mutating their inputs.

Initialize a container from an xarray data array.

Parameters:

data (DataArray, required): Array with labeled dimensions and coordinates.
required_coords (list[str] | None, default None): Optional coordinate names to check for. Missing coordinates emit warnings rather than raising, which lets callers decide how strict to be for a given pipeline.

Notes

The constructor ensures data.attrs["data_history"] exists. It does not validate dimension names or coordinate schemas beyond required_coords.

Source code in xdflow/core/data_container.py
def __init__(self, data: xr.DataArray, required_coords: list[str] | None = None):
    """Initialize a container from an xarray data array.

    Args:
        data: Array with labeled dimensions and coordinates.
        required_coords: Optional coordinate names to check for. Missing
            coordinates emit warnings rather than raising, which lets callers
            decide how strict to be for a given pipeline.

    Notes:
        The constructor ensures `data.attrs["data_history"]` exists. It does
        not validate dimension names or coordinate schemas beyond
        `required_coords`.
    """
    if required_coords is not None:
        for coord in required_coords:
            if coord not in data.coords:
                warnings.warn(f"Missing required coordinate: '{coord}'")

    # Shallow copy: a new DataArray object that still shares the underlying values with `data`,
    # so callers must still avoid in-place writes to the values.
    self._data = data.copy(deep=False)

    # Initialize history without sharing the mutable list across containers.
    attrs = dict(self._data.attrs)
    attrs["data_history"] = list(attrs.get("data_history", []))
    self._data.attrs = attrs
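
The two constructor behaviors above (warn on a missing coordinate instead of raising, and copy the history list so containers never share it) can be sketched without xarray. The helper name `prepare_attrs` and the plain-dict inputs are assumptions made for illustration; they are not part of XDFlow.

```python
import warnings


def prepare_attrs(coords, attrs, required_coords=None):
    """Hypothetical stand-in for the constructor's checks, using plain dicts."""
    # Missing coordinates produce a warning, never an exception.
    for coord in required_coords or []:
        if coord not in coords:
            warnings.warn(f"Missing required coordinate: '{coord}'")

    # Copy the mapping and the history list so two containers never
    # append to the same list object.
    new_attrs = dict(attrs)
    new_attrs["data_history"] = list(new_attrs.get("data_history", []))
    return new_attrs


original = {"data_history": ["load"]}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    copied = prepare_attrs({"time": [0, 1]}, original, required_coords=["time", "trial"])

# Appending to the copy leaves the original history untouched.
copied["data_history"].append("filter")
```

Callers who want strict behavior can turn the warning into an error with the standard `warnings.simplefilter("error")` filter.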

data property

data: DataArray

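Public accessor for the wrapped DataArray. Reading the array through this property, rather than rebinding the private _data attribute, helps keep the container effectively immutable.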

time_units property

time_units: str | None

Return declared time units for the time coordinate if present.

Returns:

str | None: The value of data.coords['time'].attrs['units'] if available, otherwise None.

__getstate__

__getstate__()

Return the state to be pickled.

Source code in xdflow/core/data_container.py
def __getstate__(self):
    """Return the state to be pickled."""
    return self.__dict__

__setstate__

__setstate__(state)

Restore the state from the unpickled state.

Source code in xdflow/core/data_container.py
def __setstate__(self, state):
    """Restore the state from the unpickled state."""
    self.__dict__.update(state)

__getitem__

__getitem__(key)

Enable slice indexing on DataContainer.

Parameters:

key (required): Index, slice, or tuple of indices/slices.

Returns:

DataContainer: New DataContainer with indexed data.

Source code in xdflow/core/data_container.py
def __getitem__(self, key):
    """
    Enable slice indexing on DataContainer.

    Args:
        key: Index, slice, or tuple of indices/slices

    Returns:
        DataContainer: New DataContainer with indexed data
    """
    result = self._data[key]
    if isinstance(result, xr.DataArray):
        return type(self)(result)
    return result

__getattr__

__getattr__(name: str)

Delegate attribute access to the underlying xarray.DataArray.

If the attribute is a method that returns a new DataArray, it is wrapped to return a new DataContainer instance. This preserves the wrapper's validation and immutability for chained operations.

Source code in xdflow/core/data_container.py
def __getattr__(self, name: str):
    """
    Delegate attribute access to the underlying xarray.DataArray.

    If the attribute is a method that returns a new DataArray, it is
    wrapped to return a new DataContainer instance. This preserves the
    wrapper's validation and immutability for chained operations.
    """
    # Prevent infinite recursion while _data is unset (e.g. during unpickling or copying).
    # hasattr(self, "_data") would re-enter __getattr__, so inspect __dict__ directly.
    if "_data" not in self.__dict__:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    # Retrieve the attribute from the wrapped _data object
    attr = getattr(self._data, name)

    # Check if the retrieved attribute is a callable method (e.g., .sel, .mean)
    if callable(attr):
        # Create a wrapper to intercept the method call
        @functools.wraps(attr)
        def wrapper(*args, **kwargs):
            # Execute the original xarray method
            result = attr(*args, **kwargs)

            # If the result is a new DataArray, re-wrap it in a new DataContainer
            if isinstance(result, xr.DataArray):
                return type(self)(result)

            # Otherwise, return the result as-is (e.g., a NumPy scalar, a number)
            return result

        return wrapper
    else:
        # If the attribute is a property (e.g., .coords, .dims), return it directly
        return attr
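
The delegate-and-rewrap pattern can be tried in isolation with a small stand-in. In this sketch `Series` plays the role of xarray.DataArray and `Box` the role of DataContainer; both names are hypothetical and exist only for illustration.

```python
import copy
import functools


class Series:
    """Hypothetical wrapped type standing in for xarray.DataArray."""

    def __init__(self, values):
        self.values = list(values)

    def scaled(self, factor):
        return Series(v * factor for v in self.values)

    def total(self):
        return sum(self.values)


class Box:
    """Delegating wrapper that rewraps Series results in a new Box."""

    def __init__(self, data):
        self._data = data

    def __getattr__(self, name):
        # Only called when normal lookup fails. Checking __dict__ directly
        # avoids re-entering __getattr__ while _data is not set yet.
        if "_data" not in self.__dict__:
            raise AttributeError(name)
        attr = getattr(self._data, name)
        if not callable(attr):
            return attr  # plain attribute, returned as-is

        @functools.wraps(attr)
        def wrapper(*args, **kwargs):
            result = attr(*args, **kwargs)
            # Rewrap results of the wrapped type; pass everything else through.
            return type(self)(result) if isinstance(result, Series) else result

        return wrapper


box = Box(Series([1, 2, 3]))
doubled = box.scaled(2)       # Series result comes back as a Box
duplicate = copy.copy(box)    # copy builds an empty Box before restoring state
```

`copy.copy` creates the new instance before `_data` exists, so it exercises the guard at the top of `__getattr__`.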

TransformError

Bases: Exception

Error raised when a transform or pipeline step fails.

Transform Contracts

Transform

Transform(sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict | None = None, transform_drop_sel: dict | None = None, **kwargs)

Bases: ABC

Base class for XDFlow processing steps.

A transform accepts a DataContainer and returns a new DataContainer. Concrete subclasses implement _transform; stateful subclasses also implement _fit. The public fit, transform, and fit_transform methods provide common selection handling, optional timing output, history logging, and the stateless/stateful execution contract used by Pipeline and CrossValidator.

Implementations should prefer named dimensions over positional axes. For example, use data.mean(dim="time") instead of assuming the time axis is at a fixed integer position. Transforms should not mutate their input container; return a new container or an immutable view consistent with xarray behavior.

Class attributes

is_stateful: Whether the transform learns state from fit.
input_dims: Required input dimensions. An empty tuple means the transform accepts dynamic input dimensions.
output_dims: Output dimensions when known statically. An empty tuple means subclasses must infer them with get_expected_output_dims.

Authoring notes

Define constructor hyperparameters as explicit __init__ arguments and store them on public attributes with matching names. Store learned state in private attributes that are not constructor parameters, so clone creates a fresh unfitted instance. **kwargs exists only for cooperative multiple inheritance; subclasses should not silently consume new hyperparameters through it.

Initialize common transform selection options.

sel and drop_sel subset the whole input before the transform runs, so the output contains only the selected data. transform_sel and transform_drop_sel select only the portion to fit or transform, then write that transformed portion back into the original array. Partial write-back is only allowed for transforms that preserve dims, sizes, and coordinates.

Parameters:

sel (dict[str, Any] | None, default None): Label selection passed to xarray .sel before transforming.
drop_sel (dict[str, Any] | None, default None): Label selection passed to xarray .drop_sel before transforming.
transform_sel (dict | None, default None): Label selection used only for the transformed portion.
transform_drop_sel (dict | None, default None): Labels to exclude from the transformed portion.

Source code in xdflow/core/base.py
def __init__(
    self,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict | None = None,
    transform_drop_sel: dict | None = None,
    **kwargs,
):
    """Initialize common transform selection options.

    `sel` and `drop_sel` subset the whole input before the transform runs, so
    the output contains only the selected data. `transform_sel` and
    `transform_drop_sel` select only the portion to fit or transform, then
    write that transformed portion back into the original array. Partial
    write-back is only allowed for transforms that preserve dims, sizes, and
    coordinates.

    Args:
        sel: Label selection passed to xarray `.sel` before transforming.
        drop_sel: Label selection passed to xarray `.drop_sel` before
            transforming.
        transform_sel: Label selection used only for the transformed portion.
        transform_drop_sel: Labels to exclude from the transformed portion.
    """
    # kwargs are accepted for cooperative inheritance but not used by Transform itself
    self.sel = sel
    self.drop_sel = drop_sel
    self.transform_sel = transform_sel
    self.transform_drop_sel = transform_drop_sel
    if self.transform_sel and self.transform_drop_sel:
        raise ValueError("Cannot specify both 'transform_sel' and 'transform_drop_sel'.")
    if (self.transform_sel or self.transform_drop_sel) and not self.supports_transform_sel:
        selection_arg = "transform_sel" if self.transform_sel else "transform_drop_sel"
        raise TypeError(
            f"{self.__class__.__name__} does not support {selection_arg}. "
            "Selective transformation is only supported for transforms whose outputs can be safely written "
            "back into the original array without changing dims, sizes, or coordinates. Use sel/drop_sel to "
            "subset the whole output, or apply selection before this transform."
        )
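
The difference between sel and transform_sel is easiest to see on toy data. The sketch below uses a plain dict of channels instead of a DataContainer, and the function names are hypothetical; it only mirrors the semantics described above (subset the whole output versus transform a portion and write it back).

```python
def apply_with_sel(data, func, sel):
    """Subset first: the output contains only the selected channels."""
    return {name: func(values) for name, values in data.items() if name in sel}


def apply_with_transform_sel(data, func, transform_sel):
    """Transform only the selected channels and write them back."""
    out = {name: list(values) for name, values in data.items()}  # copy, never mutate input
    for name in transform_sel:
        transformed = func(data[name])
        # Write-back is only safe when the size is unchanged.
        if len(transformed) != len(data[name]):
            raise ValueError("transform changed the size; cannot write back")
        out[name] = transformed
    return out


def demean(values):
    mean = sum(values) / len(values)
    return [v - mean for v in values]


recording = {"ch1": [1.0, 3.0], "ch2": [10.0, 20.0]}
subset = apply_with_sel(recording, demean, sel={"ch1"})
partial = apply_with_transform_sel(recording, demean, transform_sel={"ch1"})
```

Here `subset` holds only the demeaned `ch1`, while `partial` keeps `ch2` untouched alongside the demeaned `ch1`.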

supports_transform_sel property

supports_transform_sel: bool

Whether this transform supports transform_sel semantics.

Defaults to the class attribute _supports_transform_sel but allows subclasses to compute support dynamically via an override.

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Determines the expected output dims from manually supplied input_dims.

Source code in xdflow/core/base.py
def get_expected_output_dims(self, input_dims: tuple[str, ...], /) -> tuple[str, ...]:
    """Determines the expected output dims from manually supplied input_dims."""
    if self.output_dims:
        return self.output_dims
    # e.g. for average, output_dims = tuple([dim for dim in input_dims if dim != self.dim_to_average])
    raise NotImplementedError("Subclasses must either specify output_dims or implement get_expected_output_dims.")
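
The averaging case mentioned in the comment above can be written out as a small sketch. `AverageSketch` is a hypothetical stand-alone class, not an XDFlow transform; it only shows how a subclass might derive output dims when output_dims is left empty.

```python
class AverageSketch:
    """Hypothetical transform-like class; not an XDFlow class."""

    output_dims = ()  # empty: output dims depend on the input

    def __init__(self, dim_to_average):
        self.dim_to_average = dim_to_average

    def get_expected_output_dims(self, input_dims, /):
        if self.output_dims:
            return self.output_dims
        # Averaging removes exactly one named dimension.
        return tuple(dim for dim in input_dims if dim != self.dim_to_average)


dims = AverageSketch("time").get_expected_output_dims(("trial", "channel", "time"))
```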

transform

transform(container: DataContainer, **kwargs) -> DataContainer

Applies the transformation.

Parameters:

container (DataContainer, required): DataContainer to transform.
**kwargs (default {}): Additional context/parameters passed through the pipeline.

Returns:

DataContainer: New DataContainer with transformation applied.

Source code in xdflow/core/base.py
def transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Applies the transformation.

    Args:
        container: DataContainer to transform
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        New DataContainer with transformation applied
    """
    verbose = kwargs.get("verbose", False)

    if verbose:
        print(f"Starting {self.__class__.__name__}.transform")
        start_time = time.time()

    container = self._apply_selection(container)

    effective_transform_sel = self._get_effective_transform_sel(container)

    if effective_transform_sel and self.supports_transform_sel:
        # Create a deep copy to ensure immutability when writing back selected data
        new_container = container.copy(deep=True)  # NOTE: This needs to be deep as sel works on slices.
        # 1. Select the part of the data to be transformed from the original container state
        data_to_transform = container.sel(**effective_transform_sel)  # TODO Takes too long

        # 2. Transform the selected data
        transformed_part = self._transform(data_to_transform, **kwargs)

        # 3. check that the transform_sel output matches the input structure
        self._check_transform_sel_output(data_to_transform, transformed_part)

        # Update the new container with the transformed part
        new_container.data.loc[effective_transform_sel] = transformed_part.data  # TODO takes a bit long too
        transformed_container = new_container
    else:
        transformed_container = self._transform(container, **kwargs)

    if verbose:
        end_time = time.time()
        duration = end_time - start_time
        print(f"Ending {self.__class__.__name__}.transform - took {duration:.3f}s")

    return self._log_history(container, transformed_container)

fit

fit(container: DataContainer, **kwargs) -> Transform

Fits the transform to the data.

Parameters:

container (DataContainer, required): DataContainer to fit on.
**kwargs (default {}): Additional context/parameters passed through the pipeline.

Returns:

Transform: Self (fitted transform).

Source code in xdflow/core/base.py
def fit(self, container: DataContainer, **kwargs) -> "Transform":
    """
    Fits the transform to the data.

    Args:
        container: DataContainer to fit on
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        Self (fitted transform)
    """
    verbose = kwargs.get("verbose", False)

    if verbose:
        print(f"Starting {self.__class__.__name__}.fit")
        start_time = time.time()

    if self.is_stateful:  # Stateful transforms require fitting to data.
        container = self._apply_selection(container)

        effective_transform_sel = self._get_effective_transform_sel(
            container
        )  # Returns None if no transform_sel or transform_drop_sel is set

        if effective_transform_sel and self.supports_transform_sel:
            container = container.sel(**effective_transform_sel)

        result = self._fit(container, **kwargs)
    else:
        result = self

    if verbose:
        end_time = time.time()
        duration = end_time - start_time
        print(f"Ending {self.__class__.__name__}.fit - took {duration:.3f}s")

    return result
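
The stateless/stateful split in fit can be illustrated with a minimal sketch. The classes below are hypothetical and operate on plain lists; they only mirror the rule that _fit runs when is_stateful is true and is skipped otherwise.

```python
class StepSketch:
    """Hypothetical base class; stateless by default."""

    is_stateful = False

    def __init__(self):
        self._fit_calls = 0  # counts how often _fit actually ran

    def _fit(self, values):
        self._fit_calls += 1
        return self

    def fit(self, values):
        # Stateless steps skip _fit entirely and return themselves unchanged.
        return self._fit(values) if self.is_stateful else self


class MeanCenterSketch(StepSketch):
    """Stateful step: learns the mean during fit."""

    is_stateful = True

    def _fit(self, values):
        super()._fit(values)
        self._mean = sum(values) / len(values)  # learned state stays private
        return self

    def transform(self, values):
        return [v - self._mean for v in values]


stateless = StepSketch().fit([1.0, 2.0, 3.0])
stateful = MeanCenterSketch().fit([1.0, 2.0, 3.0])
centered = stateful.transform([1.0, 2.0, 3.0])
```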

fit_transform

fit_transform(container: DataContainer, **kwargs) -> DataContainer

Fit then transform in a single pass. Note that predictors have their own fit_transform.

Parameters:

container (DataContainer, required): DataContainer to fit and transform.
**kwargs (default {}): Additional context/parameters passed through the pipeline.

Returns:

DataContainer: DataContainer with the fit and transform applied.

Source code in xdflow/core/base.py
def fit_transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Fit then transform in a single pass. Note that predictors have their own fit_transform.

    Args:
        container: DataContainer to fit and transform
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        DataContainer with the fit and transform applied
    """
    verbose = kwargs.get("verbose", False)

    if verbose:
        print(f"Starting {self.__class__.__name__}.fit_transform")
        start_time = time.time()

    if self.is_stateful:
        container = self._apply_selection(container)

        effective_transform_sel = self._get_effective_transform_sel(container)

        if effective_transform_sel and self.supports_transform_sel:
            container_to_fit = container.sel(**effective_transform_sel)
            self._fit(container_to_fit, **kwargs)
            data_to_transform = container.sel(**effective_transform_sel)
            transformed_part = self._transform(data_to_transform, **kwargs)

            # check that the transform_sel output matches the input structure
            self._check_transform_sel_output(data_to_transform, transformed_part)

            new_container = container.copy(deep=True)
            new_container.data.loc[effective_transform_sel] = transformed_part.data
            transformed_container = new_container
        else:
            self._fit(container, **kwargs)
            transformed_container = self._transform(container, **kwargs)
        result = self._log_history(container, transformed_container)
    else:
        # For stateless transforms, just transform
        result = self.transform(container, **kwargs)

    if verbose:
        end_time = time.time()
        duration = end_time - start_time
        print(f"Ending {self.__class__.__name__}.fit_transform - took {duration:.3f}s")

    return result

get_params

get_params(deep: bool = True) -> dict[str, Any]

Get parameters for this transform.

Parameters:

deep (bool, default True): If True, will return the parameters for this transform and contained sub-objects that are themselves transforms. The base implementation listed below does not consult deep; it returns the public (non-underscore) instance attributes.

Returns:

dict[str, Any]: Parameter names mapped to their values.

Source code in xdflow/core/base.py
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """
    Get parameters for this transform.

    Args:
        deep (bool): If True, will return the parameters for this transform and
                     contained sub-objects that are themselves transforms.

    Returns:
        dict[str, Any]: Parameter names mapped to their values.
    """
    params = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
    return params
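
Because the base get_params reads public instance attributes, the naming convention from the authoring notes decides what counts as a parameter. A minimal sketch with a hypothetical class:

```python
class ParamsSketch:
    """Hypothetical class following the public/private attribute convention."""

    def __init__(self, factor=1.0, sel=None):
        self.factor = factor         # hyperparameter: public, matches the argument name
        self.sel = sel
        self._fitted_offset = None   # learned state: private, excluded from params

    def get_params(self, deep=True):
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}


params = ParamsSketch(factor=2.0).get_params()
```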

clone

clone() -> Self

Return a fresh instance with the same constructor parameters.

Subclasses that need to preserve constructor kwargs not surfaced by get_params should override _get_clone_kwargs() instead of overriding this method.

Source code in xdflow/core/base.py
def clone(self) -> Self:
    """Return a fresh instance with the same constructor parameters.

    Subclasses that need to preserve constructor kwargs not surfaced by
    ``get_params`` should override ``_get_clone_kwargs()`` instead of
    overriding this method.
    """
    filtered_params = self._get_clone_kwargs()

    ctor = signature(type(self).__init__)
    ctor_param_names = {
        name
        for name, p in ctor.parameters.items()
        if name != "self" and p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    }
    missing = ctor_param_names - filtered_params.keys()
    assert not missing, f"Clone kwargs missing constructor parameters for {self.__class__.__name__}: {missing}"

    return type(self)(**filtered_params)
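
The clone check can be reproduced with a small sketch. `_get_clone_kwargs` is not shown on this page, so the sketch assumes it is derived from get_params; the class name is hypothetical, and a TypeError is raised in place of the assertion.

```python
from inspect import Parameter, signature


class CloneSketch:
    """Hypothetical class; clone rebuilds it from constructor parameters only."""

    def __init__(self, factor=1.0, offset=0.0):
        self.factor = factor
        self.offset = offset
        self._is_fitted = False

    def get_params(self):
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def clone(self):
        params = self.get_params()  # assumption: clone kwargs come from get_params
        ctor_names = {
            name
            for name, p in signature(type(self).__init__).parameters.items()
            if name != "self" and p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
        }
        missing = ctor_names - params.keys()
        if missing:
            raise TypeError(f"clone is missing constructor parameters: {missing}")
        return type(self)(**params)


fitted = CloneSketch(factor=3.0)
fitted._is_fitted = True   # simulate learned state
fresh = fitted.clone()     # same hyperparameters, no learned state
```

Learned state lives outside the constructor signature, so the clone starts unfitted.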

set_params

set_params(**params: Any) -> Transform

Set the parameters of this transform.

Supports nested parameter setting for dict/object attributes using '__' delimiter. For example, 'weight_map__stim_A' will set the "stim_A" key in the weight_map dict. Keys are type-inferred from existing dict keys when possible (e.g., "False" -> False).

Returns:

Transform: self, the transform instance.

Source code in xdflow/core/base.py
def set_params(self, **params: Any) -> "Transform":
    """
    Set the parameters of this transform.

    Supports nested parameter setting for dict/object attributes using '__' delimiter.
    For example, 'weight_map__stim_A' will set the "stim_A" key in the weight_map dict.
    Keys are type-inferred from existing dict keys when possible (e.g., "False" -> False).

    Returns:
        self: The transform instance.
    """
    for key, value in params.items():
        if "__" not in key:
            setattr(self, key, value)
            continue

        attr_name, nested_keys_str = key.split("__", 1)
        if not hasattr(self, attr_name):
            raise ValueError(f"'{type(self).__name__}' object has no attribute '{attr_name}'")

        attr_value = getattr(self, attr_name)
        if isinstance(attr_value, dict):
            nested_keys = nested_keys_str.split("__")
            setattr(self, attr_name, _set_nested_dict_value(attr_value, nested_keys, value))
        elif hasattr(attr_value, "set_params"):
            attr_value.set_params(**{nested_keys_str: value})
        else:
            raise ValueError(
                f"Cannot set nested parameter '{key}': "
                f"attribute '{attr_name}' is not a dict or doesn't have set_params method"
            )
    return self
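
The '__' delimiter behavior can be sketched as follows. `_set_nested_dict_value` is not shown on this page, so `set_nested` below is a hypothetical replacement, and its key matching (comparing `str(existing_key)` with the requested key) is only one assumed way to map "False" onto the key False.

```python
def set_nested(mapping, keys, value):
    """Hypothetical helper: return a copy of `mapping` with a nested key set."""
    updated = dict(mapping)
    head, rest = keys[0], keys[1:]
    # Match string keys against existing keys, so "False" can address the key False.
    for existing in mapping:
        if str(existing) == head:
            head = existing
            break
    updated[head] = set_nested(mapping.get(head, {}), rest, value) if rest else value
    return updated


class WeightedSketch:
    """Hypothetical class with a dict-valued hyperparameter."""

    def __init__(self, weight_map):
        self.weight_map = weight_map

    def set_params(self, **params):
        for key, value in params.items():
            if "__" not in key:
                setattr(self, key, value)
                continue
            attr_name, nested = key.split("__", 1)
            current = getattr(self, attr_name)
            setattr(self, attr_name, set_nested(current, nested.split("__"), value))
        return self


step = WeightedSketch({"stim_A": 1.0, False: 0.5})
step.set_params(weight_map__stim_A=2.0, weight_map__False=0.0)
```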

Predictor

Predictor(sample_dim: str, target_coord: str | list[str], is_classifier: bool, encoder: LabelEncoder | None = None, proba: bool = False, is_multilabel: bool = False, sel: dict | None = None, drop_sel: dict | None = None, transform_sel: dict | None = None, transform_drop_sel: dict | None = None, calibrated_thresholds: ndarray | list[float] | None = None, **kwargs)

Bases: Transform, ABC

Base class for transforms that learn targets and produce predictions.

Predictors are stateful transforms. During fitting, single-label classifier targets are encoded with a LabelEncoder; regressors and multilabel classifiers use their target coordinates directly. Subclasses implement the estimator-specific _predict method and optionally _predict_proba.

Public prediction methods return DataContainer objects whose sample coordinate is preserved from sample_dim. Classifier outputs are decoded back to original labels when possible, while probability outputs are aligned to the fitted global class order.

Initialize common prediction metadata.

Parameters:

sample_dim (str, required): Dimension whose entries are independent samples.
target_coord (str | list[str], required): Target coordinate name, list of target coordinate names, or a pattern resolved by subclasses during fit.
is_classifier (bool, required): Whether predictions are categorical labels instead of continuous values.
encoder (LabelEncoder | None, default None): Optional label encoder for single-label classifiers. If omitted, a new encoder is created for classifier predictors.
proba (bool, default False): Whether transform should return probabilities instead of hard predictions.
is_multilabel (bool, default False): Whether classification targets are multiple binary target coordinates. Multilabel classifiers do not use a LabelEncoder.
sel (dict | None, default None): Label selection applied before fitting or transforming.
drop_sel (dict | None, default None): Label selection dropped before fitting or transforming.
transform_sel (dict | None, default None): Label selection used only for partial transformation.
transform_drop_sel (dict | None, default None): Labels excluded from partial transformation.
calibrated_thresholds (ndarray | list[float] | None, default None): Optional multilabel decision thresholds, one per output.

Source code in xdflow/core/base.py
def __init__(
    self,
    sample_dim: str,
    target_coord: str | list[str],
    is_classifier: bool,
    encoder: LabelEncoder | None = None,
    proba: bool = False,
    is_multilabel: bool = False,
    sel: dict | None = None,
    drop_sel: dict | None = None,
    transform_sel: dict | None = None,
    transform_drop_sel: dict | None = None,
    calibrated_thresholds: np.ndarray | list[float] | None = None,
    **kwargs,
):
    """Initialize common prediction metadata.

    Args:
        sample_dim: Dimension whose entries are independent samples.
        target_coord: Target coordinate name, list of target coordinate
            names, or a pattern resolved by subclasses during fit.
        is_classifier: Whether predictions are categorical labels instead
            of continuous values.
        encoder: Optional label encoder for single-label classifiers. If
            omitted, a new encoder is created for classifier predictors.
        proba: Whether `transform` should return probabilities instead of
            hard predictions.
        is_multilabel: Whether classification targets are multiple binary
            target coordinates. Multilabel classifiers do not use a
            `LabelEncoder`.
        sel: Label selection applied before fitting or transforming.
        drop_sel: Label selection dropped before fitting or transforming.
        transform_sel: Label selection used only for partial transformation.
        transform_drop_sel: Labels excluded from partial transformation.
        calibrated_thresholds: Optional multilabel decision thresholds, one per output.
    """
    allow_unknown_targets = kwargs.pop("allow_unknown_targets", self.allow_unknown_targets)
    unknown_target_encoding = kwargs.pop("unknown_target_encoding", self.unknown_target_encoding)
    is_multilabel = kwargs.pop("is_multilabel", is_multilabel)

    if is_multilabel and not is_classifier:
        raise ValueError(
            f"{self.__class__.__name__} initialized with is_multilabel=True but is_classifier=False. "
            "Multilabel mode requires is_classifier=True."
        )

    # Check if target_coord is explicitly a list (not a pattern string)
    is_multi_target = isinstance(target_coord, list)

    # For now, create a provisional target_coord_list
    # This will be resolved at fit time if it's a pattern
    if is_multi_target:
        target_coord_list = target_coord
    else:
        # Could be single coord or pattern - will be resolved at fit time
        target_coord_list = [target_coord] if isinstance(target_coord, str) else []

    if not is_classifier:
        if proba:
            raise ValueError(
                f"{self.__class__.__name__} has been initialized with is_classifier=False and proba=True. "
                "Probabilities should not be returned for continuous target coordinates."
            )
        if encoder is not None:
            raise ValueError(
                f"{self.__class__.__name__} initialized with is_classifier=False but an encoder was provided. "
                "Encoders are only valid for classifiers."
            )
    elif is_multilabel:
        if encoder is not None:
            raise ValueError(
                f"{self.__class__.__name__} initialized with is_multilabel=True but an encoder was provided. "
                "Multilabel classifiers use binary target coordinates and do not need an encoder."
            )
    else:
        # For single-label classifiers, multi-target is not supported
        if is_multi_target:
            raise ValueError(
                f"{self.__class__.__name__} initialized with is_classifier=True and multiple target_coord. "
                "Multiple classifier targets require is_multilabel=True. Use a regressor for continuous "
                "multi-output targets."
            )
        # For classifiers, always ensure an encoder exists to standardize label handling.
        if encoder is None:
            encoder = LabelEncoder()

    # With the above defaulting, encoder is guaranteed for classifiers.

    # Initialize the parent Transform (cooperative inheritance)
    super().__init__(
        sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel, **kwargs
    )

    self.allow_unknown_targets = allow_unknown_targets
    self.unknown_target_encoding = unknown_target_encoding
    self.sample_dim = sample_dim
    self.target_coord = target_coord  # Keep original (string or list)
    self.target_coord_list = target_coord_list  # Normalized to list
    self.is_multi_target = is_multi_target
    self.is_classifier = is_classifier
    self.is_multilabel = is_multilabel
    self.encoder = encoder
    self.proba = proba
    self.calibrated_thresholds = calibrated_thresholds
    self.calibrated_thresholds_ = (
        None if calibrated_thresholds is None else np.asarray(calibrated_thresholds, dtype=float)
    )
    self._is_fitted = False
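
The argument checks in this constructor amount to a small decision table. The function below is a hypothetical, stand-alone restatement of those checks; it is not part of XDFlow and uses a boolean `has_encoder` in place of a real LabelEncoder.

```python
def check_predictor_config(target_coord, is_classifier, is_multilabel=False, proba=False, has_encoder=False):
    """Hypothetical function mirroring the constructor's argument checks."""
    if is_multilabel and not is_classifier:
        raise ValueError("is_multilabel=True requires is_classifier=True")
    if not is_classifier:
        if proba:
            raise ValueError("regressors cannot return probabilities")
        if has_encoder:
            raise ValueError("encoders are only valid for classifiers")
        return "regressor"
    if is_multilabel:
        if has_encoder:
            raise ValueError("multilabel classifiers do not use an encoder")
        return "multilabel classifier"
    # Single-label classifiers accept exactly one target coordinate.
    if isinstance(target_coord, list):
        raise ValueError("multiple classifier targets require is_multilabel=True")
    return "single-label classifier"


kinds = [
    check_predictor_config("stimulus", is_classifier=True),
    check_predictor_config(["lick", "blink"], is_classifier=True, is_multilabel=True),
    check_predictor_config(["x", "y"], is_classifier=False),
]
```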

is_regressor property

is_regressor: bool

Whether this is a regression task (inverse of is_classifier).

get_labels

get_labels() -> list[Any]

Return the learned label ordering for classifiers.

Requires the predictor to be configured as a classifier with a fitted encoder.

Source code in xdflow/core/base.py
def get_labels(self) -> list[Any]:
    """
    Return the learned label ordering for classifiers.

    Requires the predictor to be configured as a classifier with a fitted encoder.
    """
    if not self.is_classifier:
        raise TypeError(f"{self.__class__.__name__} is configured as a regressor; labels are undefined.")

    if self.is_multilabel:
        return list(self.target_coord_list)

    if self.encoder is None:
        raise RuntimeError(
            f"{self.__class__.__name__} does not have an encoder. Classifiers require an encoder to expose labels."
        )

    if not hasattr(self.encoder, "classes_"):
        raise ValueError(
            f"Encoder for {self.__class__.__name__} has not been fitted yet, so classes_ is unavailable."
        )

    return list(self.encoder.classes_)

set_encoder

set_encoder(encoder: LabelEncoder)

Sets the encoder for the predictor.

Parameters:

encoder (LabelEncoder, required): The encoder to set.

Source code in xdflow/core/base.py
def set_encoder(self, encoder: LabelEncoder):
    """
    Sets the encoder for the predictor.

    Args:
        encoder: The encoder to set
    """

    if not self.is_classifier:
        raise ValueError(
            f"{self.__class__.__name__} has been initialized with is_classifier=False. Continuous target coordinates for regression should not be encoded."
        )

    self.encoder = encoder

fit_and_set_encoder

fit_and_set_encoder(data: DataArray) -> None

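Fits the predictor's existing encoder on the values of the target coordinate and stores it on the predictor. Raises ValueError if no encoder has been configured.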

Parameters:

Name Type Description Default
data DataArray

The data to fit the encoder on

required
Source code in xdflow/core/base.py
def fit_and_set_encoder(self, data: xr.DataArray) -> None:
    """
    Fits the encoder and sets it for the predictor.

    Args:
        data: The data to fit the encoder on
    """
    if self.encoder is None:
        raise ValueError(f"{self.__class__.__name__} requires an encoder before fitting target labels.")
    self.encoder.fit(data.coords[self.target_coord].values)
    self.set_encoder(self.encoder)
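
As a rough illustration of what fitting on the target coordinate yields, the sketch below assumes the encoder follows scikit-learn's LabelEncoder conventions (sorted unique classes, integer codes). NumPy stands in for the encoder so the block has no extra dependencies, and the `stimulus` values are made up.

```python
import numpy as np

# Hypothetical target coordinate values, one label per sample.
stimulus = np.array(["tone", "click", "tone", "noise", "click"])

# Fitting: the learned classes are the sorted unique labels,
# and each sample is encoded as an index into that array.
classes, codes = np.unique(stimulus, return_inverse=True)

# Decoding integer codes back to labels is an index into the classes.
decoded = classes[codes]
```

The ordering of `classes` is what a call such as get_labels would expose after fitting, under the same assumption.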

fit

fit(container: DataContainer, **kwargs) -> Transform

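Fits the predictor on the container after applying its selection. For single-label classifiers, the target coordinate is label-encoded before fitting, and the encoder itself is fitted first if it does not yet have classes_.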

Parameters:

Name Type Description Default
container DataContainer

DataContainer to fit on

required
**kwargs

Additional context/parameters passed through the pipeline

{}

Returns:

Type Description
Transform

Self (fitted transform)

Source code in xdflow/core/base.py
def fit(self, container: DataContainer, **kwargs) -> "Transform":
    """
    Fits the transform to the data.
    Handles encoding of the target coordinate.

    Args:
        container: DataContainer to fit on
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        Self (fitted transform)
    """
    verbose = kwargs.get("verbose", False)

    if verbose:
        print(f"Starting {self.__class__.__name__}.fit")
        start_time = time.time()

    container = self._apply_selection(container)

    effective_transform_sel = self._get_effective_transform_sel(container)

    if effective_transform_sel and self.supports_transform_sel:
        container = container.sel(**effective_transform_sel)

    # Encode the target coordinate if a single-label classifier.
    if self.is_classifier and not self.is_multilabel:
        # Fit encoder if not yet fitted
        if not hasattr(self.encoder, "classes_"):
            self.fit_and_set_encoder(container.data)
        container = DataContainer(self._encode_target_coord(container.data))

    # Fit the transform
    fitted = self._fit(container, **kwargs)

    # After fitting, check if target coordinates were resolved (for pattern matching)
    if isinstance(fitted, Predictor) and hasattr(fitted, "_resolved_target_coords"):
        resolved = cast(list[str], fitted._resolved_target_coords)
        fitted.target_coord_list = resolved
        fitted.is_multi_target = len(resolved) > 1

    if verbose:
        end_time = time.time()
        duration = end_time - start_time
        print(f"Ending {self.__class__.__name__}.fit - took {duration:.3f}s")

    self._is_fitted = True

    return fitted
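
The `hasattr(self.encoder, "classes_")` guard in the listing means an encoder that already has classes is not refitted. A minimal sketch of why that matters, using a hypothetical `TinyEncoder` stand-in rather than the library's encoder:

```python
class TinyEncoder:
    """Hypothetical stand-in exposing fit and classes_ like a label encoder."""

    def fit(self, labels):
        self.classes_ = sorted(set(labels))
        return self


def ensure_fitted(encoder, labels):
    """Fit only when classes_ is missing, mirroring the guard in fit."""
    if not hasattr(encoder, "classes_"):
        encoder.fit(labels)
    return encoder


# A training fold that happens to miss class "b".
fold_labels = ["a", "c"]

# An unfitted encoder learns only the classes present in the fold.
fresh = ensure_fitted(TinyEncoder(), fold_labels)

# An encoder fitted beforehand on all labels keeps the full class set.
prefit = ensure_fitted(TinyEncoder().fit(["a", "b", "c"]), fold_labels)
```

Under this reading, supplying a pre-fitted encoder through set_encoder keeps the class ordering stable across folds that do not contain every class.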

fit_transform

fit_transform(container: DataContainer, **kwargs) -> DataContainer

Predictor-specific fit_transform that avoids applying the selection twice and ensures the model is fitted on encoded targets.

Applies selection once, fits on an encoded view (for classifiers), then transforms the unencoded selected view directly via the protected _transform path.

Parameters:

Name Type Description Default
container DataContainer

DataContainer to fit and transform

required
**kwargs

Additional context/parameters passed through the pipeline

{}

Returns:

Type Description
DataContainer

DataContainer with the fit and transform applied

Source code in xdflow/core/base.py
def fit_transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """Predictor-specific fit/transform to avoid double selection and ensure encoded y during fit.

    Applies selection once, fits on an encoded view (for classifiers), then transforms
    the unencoded selected view directly via the protected `_transform` path.

    Args:
        container: DataContainer to fit and transform
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        DataContainer with the fit and transform applied
    """
    verbose = kwargs.get("verbose", False)

    if verbose:
        print(f"Starting {self.__class__.__name__}.fit_transform")
        start_time = time.time()

    # Apply selection once
    container = self._apply_selection(container)
    effective_transform_sel = self._get_effective_transform_sel(container)
    perform_transform_sel = bool(effective_transform_sel and self._supports_transform_sel)

    # Get effective container
    selected_container = (
        container.sel(**cast(dict[str, Any], effective_transform_sel)) if perform_transform_sel else container
    )

    # Default fit container is the selected view; classifiers override with encoded targets
    container_for_fit = selected_container

    # Prepare data for fit: encode labels for single-label classifiers.
    if self.is_classifier and not self.is_multilabel:
        # Fit encoder if not fitted
        if not hasattr(self.encoder, "classes_"):
            self.fit_and_set_encoder(selected_container.data)
        encoded_da = self._encode_target_coord(selected_container.data)
        container_for_fit = DataContainer(encoded_da)

    # Fit using encoded labels
    self._fit(container_for_fit, **kwargs)

    # After fitting, check if target coordinates were resolved (for pattern matching)
    if hasattr(self, "_resolved_target_coords"):
        resolved = cast(list[str], self._resolved_target_coords)
        self.target_coord_list = resolved
        self.is_multi_target = len(resolved) > 1

    self._is_fitted = True

    # Transform directly on original-selected data to avoid double encoding/selection
    transformed_container = self._transform(selected_container, **kwargs)  # handles encoding/decoding

    if perform_transform_sel:
        self._check_transform_sel_output(selected_container, transformed_container)
        # Update the full container with the transformed part
        new_container = container.copy(deep=False)  # Shallow copy sufficient
        new_container.data.loc[cast(dict[str, Any], effective_transform_sel)] = transformed_container.data
        transformed_container = new_container

    result = self._log_history(container, transformed_container)

    if verbose:
        end_time = time.time()
        duration = end_time - start_time
        print(f"Ending {self.__class__.__name__}.fit_transform - took {duration:.3f}s")

    return result
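
The select, transform, and write-back steps at the end of the listing can be sketched with plain NumPy. This is an analogy under stated assumptions: the array, the channel selection, and the mean-centering transform are hypothetical, and the sketch uses a full copy for simplicity where the listing uses a shallow container copy.

```python
import numpy as np

# Hypothetical data: 4 trials x 3 channels.
full = np.arange(12, dtype=float).reshape(4, 3)

# Stand-in for transform_sel: only channels 0 and 2 are transformed.
selected_channels = [0, 2]
selected = full[:, selected_channels]

# Stand-in transform that preserves the selected shape (mean-centering per channel).
transformed = selected - selected.mean(axis=0)

# The transformed part must keep the selected shape so it can be written back.
assert transformed.shape == selected.shape

# Write the transformed part into a copy so the input array stays untouched.
result = full.copy()
result[:, selected_channels] = transformed
```

Channel 1 passes through unchanged, which is the point of transforming only a selection and merging it back into the full array.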

predict

predict(container: DataContainer, **kwargs) -> DataContainer

Predicts labels, handling data selection and output structuring.

Parameters:

Name Type Description Default
container DataContainer

DataContainer to predict on

required
**kwargs

Additional context/parameters passed through the pipeline

{}

Returns:

Type Description
DataContainer

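DataContainer with predictions. Data is 1D along sample_dim, e.g. shape (n_trials,), for a single target, or 2D with dims (sample_dim, "target") for multi-target and multilabel predictors.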

Source code in xdflow/core/base.py
def predict(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Predicts labels, handling data selection and output structuring.

    Args:
        container: DataContainer to predict on
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        DataContainer with predictions, data is 1D with shape (n_trials,)
    """
    verbose = kwargs.get("verbose", False)

    if verbose:
        print(f"Starting {self.__class__.__name__}.predict")
        start_time = time.time()

    container = self._apply_selection(container)
    data = container.data

    # Encode target_coord if single-label classifier.
    if self.is_classifier and not self.is_multilabel:
        data = self._encode_target_coord(data)

    # Make the prediction
    if self.is_multilabel and getattr(self, "calibrated_thresholds_", None) is not None:
        if type(self)._predict_proba is Predictor._predict_proba:
            raise ValueError(
                f"{self.__class__.__name__} has calibrated_thresholds_ but does not implement predict_proba."
            )
        probabilities, _ = self._predict_proba(data, **kwargs)
        thresholds = np.asarray(self.calibrated_thresholds_, dtype=float)
        if thresholds.shape != (probabilities.shape[1],):
            raise ValueError(
                f"calibrated_thresholds_ has shape {thresholds.shape}, expected ({probabilities.shape[1]},)."
            )
        predictions = (probabilities >= thresholds).astype(np.int8)
    else:
        predictions = self._predict(data, **kwargs)

    # Validate prediction shape: can be 1D (single target) or 2D (multi-target)
    if predictions.ndim == 1:
        output_dims = [self.sample_dim]
    elif predictions.ndim == 2:
        if not self.is_multi_target and not self.is_multilabel:
            raise ValueError(
                f"Predictor returned 2D predictions but was initialized with single target_coord. "
                f"Got prediction shape {predictions.shape}"
            )
        output_dims = [self.sample_dim, "target"]
    else:
        raise ValueError(
            f"Predictions must be 1D (single target) or 2D (multi-target), but got {predictions.ndim}D"
        )

    # Inverse transform the prediction if encoded
    if self.is_classifier and not self.is_multilabel:
        if self.encoder is None:
            raise ValueError(f"{self.__class__.__name__} requires an encoder for classifier predictions.")
        predictions = self.encoder.inverse_transform(predictions)
        data = self._reset_target_coord(data)

    # Set the coordinates
    output_coords = self._get_output_coords(data)

    # Add target coordinate for multi-target/multilabel predictions
    if predictions.ndim == 2:
        output_coords["target"] = self.target_coord_list

    # Create the output DataArray
    output_da = xr.DataArray(
        predictions, dims=output_dims, coords=output_coords, attrs=data.attrs, name="prediction"
    )
    result = DataContainer(output_da)

    if verbose:
        end_time = time.time()
        duration = end_time - start_time
        print(f"Ending {self.__class__.__name__}.predict - took {duration:.3f}s")

    return result
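
The calibrated-threshold branch in the listing reduces to a broadcast comparison. A minimal sketch with made-up probabilities and thresholds:

```python
import numpy as np

# Hypothetical probabilities: rows are samples, columns are labels.
probabilities = np.array([
    [0.90, 0.40, 0.10],
    [0.20, 0.60, 0.35],
])

# One calibrated threshold per label.
thresholds = np.array([0.5, 0.5, 0.3])

# Same shape check as in predict: one threshold per probability column.
if thresholds.shape != (probabilities.shape[1],):
    raise ValueError(
        f"thresholds has shape {thresholds.shape}, expected ({probabilities.shape[1]},)."
    )

# Broadcasting compares each column against its own threshold.
predictions = (probabilities >= thresholds).astype(np.int8)
```

Because the comparison is `>=`, a probability exactly equal to its threshold counts as a positive prediction.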

predict_proba

predict_proba(container: DataContainer, **kwargs) -> DataContainer

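Predicts probabilities, handling data selection and output structuring. For single-label classifiers, the returned DataContainer has dims (sample_dim, "class"), e.g. shape (n_trials, n_stimuli), with the encoder's classes_ as the class coordinate. For multilabel classifiers, the dims are (sample_dim, "target"). Raises AttributeError if the predictor was not instantiated as a classifier.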

Source code in xdflow/core/base.py
def predict_proba(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Predicts probabilities, handling data selection and output structuring.
    DataContainer has data with shape (sample_dim, class) (e.g. n_trials, n_stimuli)
    """
    verbose = kwargs.get("verbose", False)

    if verbose:
        print(f"Starting {self.__class__.__name__}.predict_proba")
        start_time = time.time()

    if not self.is_classifier:
        raise AttributeError(
            f"'{self.__class__.__name__}' has not been instantiated as a classifier (is_classifier=False) so should not call the 'predict_proba' method."
        )

    container = self._apply_selection(container)
    data = container.data

    if self.is_multilabel:
        probabilities, _ = self._predict_proba(data, **kwargs)
        assert probabilities.ndim == 2, f"Probabilities should have 2 dimensions, but got {probabilities.ndim}"
        output_coords = self._get_output_coords(data)
        output_coords["target"] = self.target_coord_list
        output_da = xr.DataArray(
            probabilities, dims=(self.sample_dim, "target"), coords=output_coords, attrs=data.attrs
        )
        result = DataContainer(output_da)
    else:
        # Encode target_coord, predict, and align to global classes.
        data = self._encode_target_coord(data)
        probabilities, class_labels = self._predict_proba(data, **kwargs)
        probabilities = self._align_proba_to_global(probabilities, class_labels)
        assert probabilities.ndim == 2, f"Probabilities should have 2 dimensions, but got {probabilities.ndim}"
        data = self._reset_target_coord(data)
        output_coords = self._get_output_coords(data)
        if self.encoder is None:
            raise ValueError(f"{self.__class__.__name__} requires an encoder for classifier probabilities.")
        output_coords["class"] = self.encoder.classes_

        output_da = xr.DataArray(
            probabilities, dims=(self.sample_dim, "class"), coords=output_coords, attrs=data.attrs
        )
        result = DataContainer(output_da)

    if verbose:
        end_time = time.time()
        duration = end_time - start_time
        print(f"Ending {self.__class__.__name__}.predict_proba - took {duration:.3f}s")

    return result
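
The listing calls `_align_proba_to_global` without showing its body. One plausible reading, and only an assumption here, is that probability columns for the classes a model actually saw are scattered into the full encoder class order, with unseen classes left at zero. `align_proba` below is a hypothetical illustration of that idea, not the library's implementation.

```python
import numpy as np


def align_proba(probabilities, local_classes, n_global_classes):
    """Hypothetical alignment: scatter local columns into the global class order."""
    aligned = np.zeros((probabilities.shape[0], n_global_classes), dtype=float)
    # local_classes holds the encoded class index of each local column.
    aligned[:, np.asarray(local_classes)] = probabilities
    return aligned


# A model that only saw encoded classes 0 and 2 out of 3 global classes.
local = np.array([[0.7, 0.3], [0.1, 0.9]])
aligned = align_proba(local, local_classes=[0, 2], n_global_classes=3)
```

Whatever the actual strategy, the output must have one column per entry in the encoder's classes_ for the "class" coordinate to line up.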

SampleWeightMixin

Mixin providing generic coordinate-to-array extraction for sample weights.

This mixin decouples weight reading/alignment (generic across frameworks) from signature inspection and kwargs building (framework-specific).

Any transform that wants to support sample weights can inherit this mixin to gain:

- sample_weight_coord attribute for specifying the weight coordinate name
- _extract_sample_weights() method for reading and aligning weights from a DataArray

The transform is then responsible for:

- Checking if its underlying estimator/model supports sample weights
- Passing the weights to the appropriate fit/train method
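
The first responsibility above, checking whether the underlying model accepts sample weights, could be handled with signature inspection. This is a sketch under stated assumptions: `accepts_sample_weight`, `fit_kwargs`, and the two model classes are hypothetical names, and real frameworks may expose weights under a different parameter name.

```python
import inspect


def accepts_sample_weight(fit_method) -> bool:
    """Return True if the callable takes sample_weight explicitly or via **kwargs."""
    params = inspect.signature(fit_method).parameters
    if "sample_weight" in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class WeightedModel:
    """Hypothetical model whose fit accepts sample weights."""

    def fit(self, X, y=None, sample_weight=None):
        return self


class PlainModel:
    """Hypothetical model whose fit does not accept sample weights."""

    def fit(self, X, y=None):
        return self


def fit_kwargs(model, weights):
    """Build fit kwargs, dropping weights when the model cannot take them."""
    if weights is not None and accepts_sample_weight(model.fit):
        return {"sample_weight": weights}
    return {}


weighted_kwargs = fit_kwargs(WeightedModel(), [1.0, 2.0])
plain_kwargs = fit_kwargs(PlainModel(), [1.0, 2.0])
```

Silently dropping weights is only one option. A transform could instead raise when a weight coordinate is configured but the model cannot use it.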

Example

class MyPredictor(Transform, SampleWeightMixin):
    def __init__(self, sample_weight_coord=None, **kwargs):
        super().__init__(**kwargs)
        self.sample_weight_coord = sample_weight_coord

    def _fit(self, container: DataContainer, **kwargs):
        X, sample_index = self._prepare_data(container.data)
        weights = self._extract_sample_weights(container.data, sample_index)
        if weights is not None:
            self.model.fit(X, sample_weight=weights)
        else:
            self.model.fit(X)