
Transforms API

Transforms are the reusable preprocessing, feature extraction, and model-adapter units that move DataContainer objects through a pipeline. Validators and tuners use transform metadata during execution: dimension declarations define valid handoffs between steps, and is_stateful controls which steps are refit inside each fold versus reused as fold-invariant work.
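The fold handling described above can be sketched in plain Python. This is a hypothetical illustration, not XDFlow's validator: `Step`, `run_cv`, and the caching strategy are assumptions, and only the training indices of each fold are processed to keep it short. A stateless step placed after a stateful one receives fold-dependent input, so the sketch treats only the leading run of stateless steps as fold-invariant.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class Step:
    """Hypothetical pipeline step: a name, a statefulness flag, and a function."""

    name: str
    is_stateful: bool
    func: Callable[[list[float]], list[float]]
    calls: int = 0

    def apply(self, data: list[float]) -> list[float]:
        self.calls += 1
        return self.func(data)


def run_cv(steps: list[Step], data: list[float], folds: list[list[int]]) -> list[list[float]]:
    """Run the fold-invariant prefix once, then rerun the remaining steps per fold."""
    # Only the leading run of stateless steps is independent of the fold split.
    split = next((i for i, s in enumerate(steps) if s.is_stateful), len(steps))
    shared = data
    for step in steps[:split]:
        shared = step.apply(shared)
    outputs = []
    for train_idx in folds:
        fold_data = [shared[i] for i in train_idx]
        # From the first stateful step onward, every step sees fold-specific input.
        for step in steps[split:]:
            fold_data = step.apply(fold_data)
        outputs.append(fold_data)
    return outputs


steps = [
    Step("abs", False, lambda d: [abs(x) for x in d]),
    Step("demean", True, lambda d: [x - sum(d) / len(d) for x in d]),
    Step("scale", False, lambda d: [2 * x for x in d]),
]
outputs = run_cv(steps, [-1.0, 2.0, -3.0, 4.0], folds=[[0, 1], [2, 3]])
```

Here `abs` runs once for the whole dataset, while `demean` and the downstream `scale` run once per fold.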

For authoring guidance, examples, and test patterns, see Writing Custom Transforms.

Basic Transforms

TransposeDimsTransform

TransposeDimsTransform(dims: tuple[str, ...], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Transpose dimensions of the data container.

Initializes the TransposeDimsTransform.

Parameters:

Name Type Description Default
dims tuple[str, ...]

The dimensions to transpose to.

required
sel dict[str, Any] | None

Optional selection to apply before transforming.

None
drop_sel dict[str, Any] | None

Optional drop selection to apply before transforming.

None
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self, dims: tuple[str, ...], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None
):
    """
    Initializes the TransposeDimsTransform.

    Args:
        dims: The dimensions to transpose to.
        sel: Optional selection to apply before transforming.
        drop_sel: Optional drop selection to apply before transforming.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.dims = dims

RenameDimsTransform

RenameDimsTransform(rename_map: dict[str, str], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Rename xarray dimension names in the data container.

This transform applies xarray.DataArray.rename to change dimension names (and matching coordinate names) according to a provided mapping. It preserves data and coordinate values.

Example

Rename feature to channel after a union or feature step:

transform = RenameDimsTransform(rename_map={"feature": "channel"})

Parameters:

Name Type Description Default
rename_map dict[str, str]

Mapping from old dimension names to new names.

required
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self, rename_map: dict[str, str], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None
):
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.rename_map = rename_map

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Infer output dims by applying the mapping to the input dims.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

Tuple of input dim names.

required

Returns:

Type Description
tuple[str, ...]

Tuple of output dim names with mapping applied.

Source code in xdflow/transforms/basic_transforms.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """Infer output dims by applying the mapping to the input dims.

    Args:
        input_dims: Tuple of input dim names.

    Returns:
        Tuple of output dim names with mapping applied.
    """
    return tuple(self.rename_map.get(d, d) for d in input_dims)
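A minimal sketch of how a dimension declaration can be checked at a handoff, reusing the mapping rule shown above with the feature-to-channel example. `renamed_dims` and `check_handoff` are hypothetical helpers written for illustration; they are not XDFlow functions.

```python
from __future__ import annotations


def renamed_dims(input_dims: tuple[str, ...], rename_map: dict[str, str]) -> tuple[str, ...]:
    # Same rule as get_expected_output_dims above: unmapped names pass through unchanged.
    return tuple(rename_map.get(d, d) for d in input_dims)


def check_handoff(produced: tuple[str, ...], required: tuple[str, ...]) -> None:
    """Hypothetical check: fail if the next step needs a dim the previous step does not produce."""
    missing = [d for d in required if d not in produced]
    if missing:
        raise ValueError(f"Next step requires missing dims: {missing}")


dims = renamed_dims(("trial", "feature"), {"feature": "channel"})
check_handoff(dims, required=("trial", "channel"))  # passes: "feature" became "channel"
```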

IdentityTransform

IdentityTransform(sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

A no-op transform that returns the input unchanged.

This is useful as a selectable option in a SwitchTransform when you want the step to optionally do nothing while preserving dimension contracts.

Initialize an identity transform.

Parameters:

Name Type Description Default
sel dict[str, Any] | None

Optional selection to apply before transforming.

None
drop_sel dict[str, Any] | None

Optional drop selection to apply before transforming.

None
Source code in xdflow/transforms/basic_transforms.py
def __init__(self, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None):
    """Initialize an identity transform.

    Args:
        sel: Optional selection to apply before transforming.
        drop_sel: Optional drop selection to apply before transforming.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Return the same dims as input.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

Input dimension names.

required

Returns:

Type Description
tuple[str, ...]

The same dimension names.

Source code in xdflow/transforms/basic_transforms.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """Return the same dims as input.

    Args:
        input_dims: Input dimension names.

    Returns:
        The same dimension names.
    """
    return input_dims

SampleWeightTransform

SampleWeightTransform(coord_name: str, weight_map: Mapping[Any, float] | None = None, weight_func: Callable[[Any], float] | None = None, default_weight: float = 1.0, target_coord: str = 'sample_weight', dtype: dtype | type = np.float64, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict | None = None, transform_drop_sel: dict | None = None)

Bases: Transform

Attach a sample_weight coordinate derived from an existing coordinate.

The transform maps values from a source coordinate (e.g. session) to scalar weights using either a weight_map or a callable weight_func. The resulting weights are written as a new coordinate (default sample_weight) on the same dimension(s) as the source, so downstream predictors can consume them without changing the data layout.

Partial write-back with transform_sel/transform_drop_sel is not supported because this transform changes coordinates.

Parameters:

Name Type Description Default
coord_name str

Name of the coordinate to read values from.

required
weight_map Mapping[Any, float] | None

Mapping of coordinate values to weights.

None
weight_func Callable[[Any], float] | None

Callable taking a coordinate value and returning a weight.

None
default_weight float

Weight to use when neither map nor func provides a value.

1.0
target_coord str

Name of the coordinate that will store computed weights.

'sample_weight'
dtype dtype | type

Numeric dtype used for the weight array.

float64
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self,
    coord_name: str,
    weight_map: Mapping[Any, float] | None = None,
    weight_func: Callable[[Any], float] | None = None,
    default_weight: float = 1.0,
    target_coord: str = "sample_weight",
    dtype: np.dtype | type = np.float64,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict | None = None,
    transform_drop_sel: dict | None = None,
):
    """
    Args:
        coord_name: Name of the coordinate to read values from.
        weight_map: Mapping of coordinate values to weights.
        weight_func: Callable taking a coordinate value and returning a weight.
        default_weight: Weight to use when neither map nor func provides a value.
        target_coord: Name of the coordinate that will store computed weights.
        dtype: Numeric dtype used for the weight array.
    """
    if weight_map is not None and weight_func is not None:
        raise ValueError("Specify either 'weight_map' or 'weight_func', not both.")
    super().__init__(sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel)
    self.coord_name = coord_name
    self.weight_map = {key: float(value) for key, value in weight_map.items()} if weight_map is not None else None
    self.weight_func = weight_func
    self.default_weight = float(default_weight)
    self.target_coord = target_coord
    self.dtype = dtype
    self._active_for_inference = True
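The lookup described above can be sketched without xarray. `resolve_weights` is a hypothetical helper operating on a plain list of coordinate values; the real transform writes the result to the target_coord coordinate instead. The fallback to default_weight for values absent from weight_map is an assumption based on the default_weight description.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any


def resolve_weights(
    coord_values: list[Any],
    weight_map: Mapping[Any, float] | None = None,
    weight_func: Callable[[Any], float] | None = None,
    default_weight: float = 1.0,
) -> list[float]:
    """Map each source-coordinate value to a scalar weight."""
    if weight_map is not None and weight_func is not None:
        raise ValueError("Specify either 'weight_map' or 'weight_func', not both.")
    weights = []
    for value in coord_values:
        if weight_map is not None and value in weight_map:
            weights.append(float(weight_map[value]))
        elif weight_func is not None:
            weights.append(float(weight_func(value)))
        else:
            # Assumption: values missing from weight_map (or no map/func at all) get default_weight.
            weights.append(float(default_weight))
    return weights


sessions = ["s1", "s2", "s1", "s3"]
by_map = resolve_weights(sessions, weight_map={"s1": 2.0, "s2": 0.5})
by_func = resolve_weights(sessions, weight_func=lambda s: 1.0 / int(s[1:]))
```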

BalanceClassWeightTransform

BalanceClassWeightTransform(class_coord: str, balance_domains: bool = False, domain_coord: str | None = None, domain_weights: Mapping[Any, float] | None = None, normalize_domain_totals: bool = False, weight_normalize: str | None = None, target_coord: str = 'sample_weight', dtype: dtype | type = np.float64, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict | None = None, transform_drop_sel: dict | None = None)

Bases: Transform

Attach a sample-weight coordinate that balances classes, optionally by domain.

Partial write-back with transform_sel/transform_drop_sel is not supported because this transform changes coordinates.

Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self,
    class_coord: str,
    balance_domains: bool = False,
    domain_coord: str | None = None,
    domain_weights: Mapping[Any, float] | None = None,
    normalize_domain_totals: bool = False,
    weight_normalize: str | None = None,
    target_coord: str = "sample_weight",
    dtype: np.dtype | type = np.float64,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict | None = None,
    transform_drop_sel: dict | None = None,
):
    if weight_normalize is not None and weight_normalize not in {"mean", "sum"}:
        raise ValueError("weight_normalize must be one of: None, 'mean', or 'sum'.")
    if not balance_domains:
        if domain_coord is not None:
            raise ValueError("Specify 'domain_coord' only when balance_domains=True.")
        if domain_weights is not None:
            raise ValueError("Specify 'domain_weights' only when balance_domains=True.")
        if normalize_domain_totals:
            raise ValueError("Specify 'normalize_domain_totals' only when balance_domains=True.")
    elif domain_coord is None:
        raise ValueError("domain_coord is required when balance_domains=True.")

    super().__init__(sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel)
    self.class_coord = class_coord
    self.balance_domains = balance_domains
    self.domain_coord = domain_coord
    self.domain_weights = {key: float(value) for key, value in domain_weights.items()} if domain_weights else None
    self.normalize_domain_totals = normalize_domain_totals
    self.weight_normalize = weight_normalize
    self.target_coord = target_coord
    self.dtype = dtype
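This page does not document the weighting formula, so the sketch below assumes the common inverse-frequency convention n_samples / (n_classes × count[class]), the same rule scikit-learn uses for class_weight="balanced", and interprets weight_normalize="mean" and "sum" as rescaling the weights to average 1 or sum to 1. Domain balancing (balance_domains, domain_weights) is not modeled, and `balanced_class_weights` is hypothetical.

```python
from __future__ import annotations

from collections import Counter


def balanced_class_weights(labels: list[str], weight_normalize: str | None = None) -> list[float]:
    if weight_normalize not in (None, "mean", "sum"):
        raise ValueError("weight_normalize must be one of: None, 'mean', or 'sum'.")
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    # Assumed inverse-frequency rule: every class receives the same total weight, n / k.
    weights = [n / (k * counts[label]) for label in labels]
    if weight_normalize == "mean":
        mean = sum(weights) / len(weights)
        weights = [w / mean for w in weights]
    elif weight_normalize == "sum":
        total = sum(weights)
        weights = [w / total for w in weights]
    return weights


w = balanced_class_weights(["a", "a", "a", "b"])
ws = balanced_class_weights(["a", "a", "a", "b"], weight_normalize="sum")
```

Under this convention each class receives the same total weight (2.0 per class in the example), and "mean" normalization is a no-op because the balanced weights already average 1.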

AverageTransform

AverageTransform(dims: str | tuple[str, ...], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Average data over one or more dimensions.

The named dimensions are reduced with xarray.DataArray.mean, so their coordinate values are removed from the result and all other dimensions are preserved. Attributes are kept on the resulting array.

Initialize an averaging transform.

Parameters:

Name Type Description Default
dims str | tuple[str, ...]

Dimension name or dimension names to average over.

required
sel dict[str, Any] | None

Label selection applied before averaging.

None
drop_sel dict[str, Any] | None

Label selection dropped before averaging.

None
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self, dims: str | tuple[str, ...], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None
):
    """Initialize an averaging transform.

    Args:
        dims: Dimension name or dimension names to average over.
        sel: Label selection applied before averaging.
        drop_sel: Label selection dropped before averaging.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.dims = (dims,) if isinstance(dims, str) else dims

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Determines the expected output dimensions by removing the averaged dimensions.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

A tuple of dimension names of the input data.

required

Returns:

Type Description
tuple[str, ...]

A tuple of dimension names for the output data.

Source code in xdflow/transforms/basic_transforms.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Determines the expected output dimensions by removing the averaged dimensions.

    Args:
        input_dims: A tuple of dimension names of the input data.

    Returns:
        A tuple of dimension names for the output data.
    """
    return tuple(dim for dim in input_dims if dim not in self.dims)
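A standalone sketch of the dimension contract above. It also shows why __init__ wraps a single string in a tuple: membership against a bare string is a substring test, so a dimension named t would otherwise be dropped when averaging over time. `averaged_dims` is a hypothetical helper.

```python
from __future__ import annotations


def averaged_dims(input_dims: tuple[str, ...], dims: str | tuple[str, ...]) -> tuple[str, ...]:
    # Normalize a bare string to a 1-tuple, as AverageTransform.__init__ does.
    # Without this, `d not in "time"` would be a substring test on the string.
    dims = (dims,) if isinstance(dims, str) else dims
    return tuple(d for d in input_dims if d not in dims)


reduced = averaged_dims(("trial", "channel", "time"), "time")
```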

FlattenTransform

FlattenTransform(dims: tuple[str, ...], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Stack multiple dimensions into one new dimension.

Flattening uses xarray.DataArray.stack, so the new dimension receives a pandas MultiIndex coordinate containing the original coordinate labels. The new dimension is named flat_<dim1>__<dim2>.

Initialize a flattening transform.

Parameters:

Name Type Description Default
dims tuple[str, ...]

At least two dimension names to stack into one.

required
sel dict[str, Any] | None

Label selection applied before flattening.

None
drop_sel dict[str, Any] | None

Label selection dropped before flattening.

None
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self, dims: tuple[str, ...], sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None
):
    """Initialize a flattening transform.

    Args:
        dims: At least two dimension names to stack into one.
        sel: Label selection applied before flattening.
        drop_sel: Label selection dropped before flattening.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)
    if not isinstance(dims, tuple) or len(dims) < 2:
        raise ValueError("`dims` must be a tuple of at least two strings.")
    self.dims = dims
    self.new_dim_name = f"flat_{'__'.join(self.dims)}"

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Determines the expected output dimensions after flattening. The new flattened dimension is appended at the end.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

A tuple of dimension names of the input data.

required

Returns:

Type Description
tuple[str, ...]

A tuple of dimension names for the output data.

Source code in xdflow/transforms/basic_transforms.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Determines the expected output dimensions after flattening. The new
    flattened dimension is appended at the end.

    Args:
        input_dims: A tuple of dimension names of the input data.

    Returns:
        A tuple of dimension names for the output data.
    """
    # Remove the original dimensions that are being flattened
    remaining_dims = [dim for dim in input_dims if dim not in self.dims]
    # Add the new flattened dimension to the end
    return tuple(remaining_dims) + (self.new_dim_name,)

FunctionTransform

FunctionTransform(func: Callable, expected_output_dims: tuple[str, ...] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

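Applies a function to the whole xarray.DataArray. The function must work on xarray.DataArray/numpy input. It is useful for simple mathematical functions such as np.abs or np.log. For xarray functions with additional arguments, bind those arguments with functools.partial, e.g.

from functools import partial

import xarray as xr
from xdflow.transforms.basic_transforms import FunctionTransform

FunctionTransform(func=partial(xr.DataArray.max, dim="time"), expected_output_dims=("trial", "channel", "freq_band"))

Note that the constructor test-calls func on a small one-dimensional DataArray whose only dimension is dim_0, so a function that requires a specific named dimension, such as dim="time" above, may be rejected by that check.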

Initializes the FunctionTransform.

Parameters:

Name Type Description Default
func Callable

The function to apply to the data. The function must work on xarray.DataArray/numpy.

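required
expected_output_dims tuple[str, ...] | None

Output dimension names. Needed for functions that change dimensionality (e.g. np.mean).

None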
sel dict[str, Any] | None

A dictionary to select a subset of data for transformation.

None
drop_sel dict[str, Any] | None

A dictionary to drop a subset of data for transformation.

None
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self,
    func: Callable,
    expected_output_dims: tuple[str, ...]
    | None = None,  # needed for functions that change dimensionality (e.g. np.mean)
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
):
    """
    Initializes the FunctionTransform.

    Args:
        func: The function to apply to the data. The function must work on xarray.DataArray/numpy.
            It is highly recommended to use a vectorized function from NumPy (e.g., `np.abs`) or from xarray.
        expected_output_dims: Output dimension names; needed for functions that change
            dimensionality (e.g. np.mean).
        sel: A dictionary to select a subset of data for transformation.
        drop_sel: A dictionary to drop a subset of data for transformation.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.func = func

    # Perform a check to ensure the function is compatible with xarray DataArrays at initialization.
    try:
        # Use a very small, simple DataArray for the check.
        test_data = xr.DataArray(np.array([1.0, 2.0]))
        self.func(test_data)
    except Exception as e:
        raise ValueError(
            f"The provided function '{getattr(self.func, '__name__', 'unknown')}' is not compatible with xarray.DataArray. "
            "Please provide a vectorized function (like those from NumPy). "
        ) from e

    self.expected_output_dims = expected_output_dims
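The constructor above calls func once on a tiny DataArray and converts any failure into a ValueError, so extra arguments have to be bound in advance, for example with functools.partial. The sketch below reproduces that fail-fast pattern with a plain list as the probe input; `probe_callable` and `scale` are hypothetical names used only for illustration.

```python
from __future__ import annotations

from collections.abc import Callable
from functools import partial
from typing import Any


def probe_callable(func: Callable[..., Any], probe: Any) -> Callable[..., Any]:
    """Call func once on a small probe input and re-raise any failure as ValueError."""
    try:
        func(probe)
    except Exception as e:
        name = getattr(func, "__name__", "unknown")
        raise ValueError(f"The provided function '{name}' is not compatible with the probe input.") from e
    return func


def scale(values: list[float], factor: float) -> list[float]:
    return [v * factor for v in values]


probe = [1.0, 2.0]
ok = probe_callable(partial(scale, factor=3.0), probe)  # extra argument bound up front

try:
    probe_callable(scale, probe)  # rejected: `factor` is not supplied
    rejected = False
except ValueError:
    rejected = True
```

functools.partial objects have no __name__ attribute, so an error raised for a partial reports the function name as 'unknown'.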

UnflattenTransform

UnflattenTransform(dim: str, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Unstack a flattened dimension back into its source dimensions.

The input dimension must follow the naming convention produced by FlattenTransform, such as flat_trial__time. If the coordinate is not already a pandas MultiIndex, XDFlow attempts to build one from tuple-like coordinate values before unstacking.

Initializes the UnflattenTransform.

Parameters:

Name Type Description Default
dim str

Dimension to unflatten. Its name must follow the naming convention produced by FlattenTransform (e.g. 'flat_dim1__dim2').

required
Source code in xdflow/transforms/basic_transforms.py
def __init__(self, dim: str, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None):
    """
    Initializes the UnflattenTransform.

    Args:
        dim: Dimension to unflatten. Its name must follow the naming convention produced by
            FlattenTransform (e.g. 'flat_dim1__dim2').
        sel: Optional selection to apply before transforming.
        drop_sel: Optional drop selection to apply before transforming.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)
    if not dim.startswith("flat_"):
        raise ValueError("`dim` must have been flattened before, and must start with 'flat_'.")
    self.dim = dim
    # Strip only the leading prefix; str.replace would also remove "flat_" inside a source dim name.
    self.new_dim_names = dim.removeprefix("flat_").split("__")

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Determines the expected output dimensions after unflattening. The new unflattened dimensions are appended at the end.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

A tuple of dimension names of the input data.

required

Returns:

Type Description
tuple[str, ...]

A tuple of dimension names for the output data.

Source code in xdflow/transforms/basic_transforms.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Determines the expected output dimensions after unflattening. The new
    unflattened dimensions are appended at the end.

    Args:
        input_dims: A tuple of dimension names of the input data.

    Returns:
        A tuple of dimension names for the output data.
    """
    # Remove the flattened dimension (exact name match, not a substring test)
    remaining_dims = [dim for dim in input_dims if dim != self.dim]
    # Add the new unflattened dimensions to the end
    return tuple(remaining_dims) + tuple(self.new_dim_names)
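The naming convention and append-at-end rules of FlattenTransform and UnflattenTransform can be checked with a small standalone sketch. All four helpers are hypothetical and operate on dimension-name tuples only, not on data.

```python
from __future__ import annotations


def flat_name(dims: tuple[str, ...]) -> str:
    # FlattenTransform naming: flat_<dim1>__<dim2>...
    if len(dims) < 2:
        raise ValueError("`dims` must be a tuple of at least two strings.")
    return f"flat_{'__'.join(dims)}"


def unflat_names(dim: str) -> list[str]:
    if not dim.startswith("flat_"):
        raise ValueError("`dim` must have been flattened before, and must start with 'flat_'.")
    return dim[len("flat_"):].split("__")


def flatten_dims(input_dims: tuple[str, ...], dims: tuple[str, ...]) -> tuple[str, ...]:
    # Stacked dims are removed and the new flat dim is appended at the end.
    return tuple(d for d in input_dims if d not in dims) + (flat_name(dims),)


def unflatten_dims(input_dims: tuple[str, ...], dim: str) -> tuple[str, ...]:
    # The flat dim (matched by exact name) is removed and its parts are appended at the end.
    return tuple(d for d in input_dims if d != dim) + tuple(unflat_names(dim))


flat = flatten_dims(("trial", "channel", "time"), ("trial", "time"))
restored = unflatten_dims(flat, "flat_trial__time")
```

A round trip restores the dimension names but not their order, because both transforms append the new dimensions at the end; a TransposeDimsTransform can restore the original order if needed.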

TrialSampler

TrialSampler(n_trials: int, shuffle: bool = True, random_state: int = 0, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Select a fixed number of trials from the trial dimension.

The transform samples by integer position. When shuffle=True, trial order is shuffled with NumPy's default random generator before the first n_trials positions are selected.

Initialize trial subsampling.

Parameters:

Name Type Description Default
n_trials int

Number of trials to keep.

required
shuffle bool

Whether to shuffle trial indices before sampling.

True
random_state int

Seed used when shuffle=True.

0
sel dict[str, Any] | None

Label selection applied before sampling.

None
drop_sel dict[str, Any] | None

Label selection dropped before sampling.

None
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self,
    n_trials: int,
    shuffle: bool = True,
    random_state: int = 0,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
):
    """Initialize trial subsampling.

    Args:
        n_trials: Number of trials to keep.
        shuffle: Whether to shuffle trial indices before sampling.
        random_state: Seed used when `shuffle=True`.
        sel: Label selection applied before sampling.
        drop_sel: Label selection dropped before sampling.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.n_trials = n_trials
    self.shuffle = shuffle
    self.random_state = random_state

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Sampling preserves dimension names; only the size of the trial dimension may shrink.

Source code in xdflow/transforms/basic_transforms.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Determines the expected output dimensions after sampling.
    """
    return input_dims
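A sketch of the position-based selection described above. It swaps in Python's random.Random for NumPy's default random generator, so the chosen positions will not match XDFlow's for the same random_state; it only illustrates the shuffle-then-take-the-first-n rule. `sample_trial_positions` is hypothetical, and raising when n_trials exceeds the available trials is an assumption.

```python
import random


def sample_trial_positions(n_total: int, n_trials: int, shuffle: bool = True, random_state: int = 0) -> list[int]:
    """Return the integer positions of the trials to keep."""
    if n_trials > n_total:
        # Assumption: behavior when too few trials are available is not documented here.
        raise ValueError("n_trials exceeds the number of available trials.")
    positions = list(range(n_total))
    if shuffle:
        # Stand-in for NumPy's default generator; seeded so results are repeatable.
        random.Random(random_state).shuffle(positions)
    return positions[:n_trials]


kept = sample_trial_positions(10, 4, random_state=7)
```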

CropTimeTransform

CropTimeTransform(time_window_start_ms: float, time_window_end_ms: float, time_coord: str = 'time', sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Crop data to a time window using the time coordinate, inclusive.

This transform selects a subset of the data along a time-like coordinate using label-based slicing, preserving all other dimensions and coordinates.

The start and end are inclusive and are interpreted in the same units as the DataArray's time coordinate values, despite the _ms suffix in the parameter names.

Parameters:

Name Type Description Default
time_window_start_ms float

Start of the time window (inclusive).

required
time_window_end_ms float

End of the time window (inclusive).

required
time_coord str

Name of the time coordinate/dimension to slice over. Defaults to "time".

'time'
sel dict[str, Any] | None

Optional selection to apply before transforming.

None
drop_sel dict[str, Any] | None

Optional drop selection to apply before transforming.

None
Source code in xdflow/transforms/basic_transforms.py
def __init__(
    self,
    time_window_start_ms: float,
    time_window_end_ms: float,
    time_coord: str = "time",
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
):
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.time_window_start_ms = time_window_start_ms
    self.time_window_end_ms = time_window_end_ms
    self.time_coord = time_coord

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Cropping preserves dimension names; sizes may shrink.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

Input dimension names.

required

Returns:

Type Description
tuple[str, ...]

The same dimension names.

Source code in xdflow/transforms/basic_transforms.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """Cropping preserves dimension names; sizes may shrink.

    Args:
        input_dims: Input dimension names.

    Returns:
        The same dimension names.
    """
    return input_dims
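Label-based slicing in xarray (for example da.sel(time=slice(start, end)) on a sorted coordinate) includes both endpoints, which is what makes the window inclusive. The stdlib sketch below mimics that rule on a plain list of time labels; `crop_positions` is a hypothetical helper.

```python
def crop_positions(times: list[float], start: float, end: float) -> list[int]:
    """Positions whose time label lies in [start, end], both ends inclusive."""
    # Units are whatever the time coordinate uses; nothing here assumes milliseconds.
    return [i for i, t in enumerate(times) if start <= t <= end]


times = [-100.0, 0.0, 100.0, 200.0, 300.0]
kept = [times[i] for i in crop_positions(times, 0.0, 200.0)]
```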

Cleaning And Normalization

CARTransform

CARTransform(car_method: str = 'all', excluded_channels: Sequence[str] | None = None, sel: dict[str, object] | None = None, drop_sel: dict[str, object] | None = None)

Bases: Transform

Apply Common Average Referencing (CAR) to the data.

CAR can be applied to all signal channels or disabled.

Parameters:

Name Type Description Default
car_method str

'all' to apply CAR to all signal channels, or 'none' to disable it.

'all'
excluded_channels Sequence[str] | None

Channels to leave untouched (e.g., reference sensors).

None
Source code in xdflow/transforms/cleaning.py
def __init__(
    self,
    car_method: str = "all",
    excluded_channels: Sequence[str] | None = None,
    sel: dict[str, object] | None = None,
    drop_sel: dict[str, object] | None = None,
):
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.car_method = car_method
    self.excluded_channels = excluded_channels  # for clone
    self._excluded_channels = tuple(excluded_channels or ())
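A minimal CAR sketch on a channel-to-samples dictionary. The page does not say whether excluded channels contribute to the common average; this sketch assumes they neither contribute nor change, and it also assumes any car_method other than 'all' or 'none' is rejected. `common_average_reference` is hypothetical.

```python
from __future__ import annotations

from collections.abc import Sequence


def common_average_reference(
    data: dict[str, list[float]],
    car_method: str = "all",
    excluded_channels: Sequence[str] | None = None,
) -> dict[str, list[float]]:
    """Subtract the per-sample mean of the included channels from each of them."""
    if car_method == "none":
        return {ch: list(vals) for ch, vals in data.items()}
    if car_method != "all":
        # Assumption: unknown methods are rejected.
        raise ValueError("car_method must be 'all' or 'none'.")
    excluded = set(excluded_channels or ())
    included = [ch for ch in data if ch not in excluded]
    n_samples = len(next(iter(data.values())))
    # Common average at each sample, computed over the included channels only.
    avg = [sum(data[ch][t] for ch in included) / len(included) for t in range(n_samples)]
    return {
        ch: list(vals) if ch in excluded else [v - a for v, a in zip(vals, avg)]
        for ch, vals in data.items()
    }


data = {"c1": [1.0, 2.0], "c2": [3.0, 6.0], "ref": [10.0, 10.0]}
out = common_average_reference(data, excluded_channels=["ref"])
```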

RegressOutReferenceTransform

RegressOutReferenceTransform(reference_channel: str, excluded_channels: Sequence[str] | None = None, sel: dict[str, object] | None = None, drop_sel: dict[str, object] | None = None)

Bases: Transform

Regress out a reference channel from all other channels.

Uses linear regression to model the relationship between the reference signal and each target channel, then subtracts the predicted component.

Source code in xdflow/transforms/cleaning.py
def __init__(
    self,
    reference_channel: str,
    excluded_channels: Sequence[str] | None = None,
    sel: dict[str, object] | None = None,
    drop_sel: dict[str, object] | None = None,
):
    if not reference_channel:
        raise ValueError("reference_channel must be provided.")
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.reference_channel = reference_channel
    self.excluded_channels = excluded_channels  # for clone
    extras = tuple(excluded_channels or ())
    self._excluded_channels = tuple(set(extras) | {reference_channel})
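One way to regress out a reference is ordinary least squares with an intercept, fit separately for each channel. The sketch assumes that model (the page only says "linear regression") and, mirroring the constructor above, leaves the reference channel and any excluded channels untouched. Both helpers are hypothetical.

```python
from __future__ import annotations

from collections.abc import Sequence


def regress_out(reference: list[float], target: list[float]) -> list[float]:
    """Residual of target after an OLS fit target ~ intercept + slope * reference."""
    n = len(reference)
    mean_x = sum(reference) / n
    mean_y = sum(target) / n
    sxx = sum((x - mean_x) ** 2 for x in reference)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(reference, target))
    slope = sxy / sxx if sxx else 0.0  # constant reference: only the mean is removed
    intercept = mean_y - slope * mean_x
    # Subtract the component predicted from the reference.
    return [y - (intercept + slope * x) for x, y in zip(reference, target)]


def regress_out_reference(
    data: dict[str, list[float]],
    reference_channel: str,
    excluded_channels: Sequence[str] | None = None,
) -> dict[str, list[float]]:
    # As in the constructor, the reference itself is always excluded from correction.
    skip = set(excluded_channels or ()) | {reference_channel}
    ref = data[reference_channel]
    return {ch: list(v) if ch in skip else regress_out(ref, v) for ch, v in data.items()}


data = {"ref": [0.0, 1.0, 2.0, 3.0], "c1": [1.0, 3.0, 5.0, 7.0], "c2": [1.0, -1.0, 1.0, -1.0]}
out = regress_out_reference(data, "ref")
```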

RemoveOutliersTransform

RemoveOutliersTransform(per_dim: str | list[str] | tuple[str, ...], std_threshold: float = 5.0, use_fit: bool = False, sel: dict[str, object] | None = None, drop_sel: dict[str, object] | None = None, transform_sel: dict[str, object] | None = None, transform_drop_sel: dict[str, object] | None = None)

Bases: Transform

Remove outliers from the data by clipping independently per selected dimension label.

Outliers are values that lie more than std_threshold standard deviations from the mean. Each one is clipped to the nearest boundary value (mean ± std_threshold × std).

Parameters:

Name Type Description Default
per_dim str | list[str] | tuple[str, ...]

Dimension or dimensions to keep distinct while computing stats.

required
std_threshold float

The number of standard deviations to use as the threshold. Defaults to 5.0.

5.0
use_fit bool

Whether to use the fit data to compute the mean and std. Defaults to False.

False
transform_sel dict[str, object] | None

A dictionary to select a subset of data for transformation, leaving the rest untouched.

None
transform_drop_sel dict[str, object] | None

A dictionary to select a subset of data for transformation by excluding labels, leaving the rest untouched.

None

For example, if the input dims are ("trial", "channel", "time") and per_dim="channel", each channel is clipped to std_threshold standard deviations around its own mean, with the mean and standard deviation computed across trial and time. Unlike xarray's dim=, per_dim does not name dimensions to reduce. When multiple dimensions are provided, clipping bounds are computed separately for each coordinate tuple across those dimensions.

Source code in xdflow/transforms/cleaning.py
def __init__(
    self,
    per_dim: str | list[str] | tuple[str, ...],
    std_threshold: float = 5.0,
    use_fit: bool = False,
    sel: dict[str, object] | None = None,
    drop_sel: dict[str, object] | None = None,
    transform_sel: dict[str, object] | None = None,
    transform_drop_sel: dict[str, object] | None = None,
):
    super().__init__(sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel)
    self.std_threshold = std_threshold
    self.per_dim = per_dim
    self.use_fit = use_fit
    self.is_stateful = self.use_fit
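The clipping rule and the effect of use_fit can be sketched as a small fit/transform class. Keys stand for per_dim labels (here channels) and each list holds that label's values pooled across the other dimensions. The population standard deviation is an assumption, and `ClipOutliersSketch` is not the XDFlow class.

```python
from __future__ import annotations

import statistics


class ClipOutliersSketch:
    """Hypothetical per-label clipper illustrating std_threshold and use_fit."""

    def __init__(self, std_threshold: float = 5.0, use_fit: bool = False):
        self.std_threshold = std_threshold
        self.use_fit = use_fit
        self.bounds_: dict[str, tuple[float, float]] = {}

    def _bounds(self, data: dict[str, list[float]]) -> dict[str, tuple[float, float]]:
        bounds = {}
        for label, values in data.items():
            mean = statistics.fmean(values)
            std = statistics.pstdev(values)  # assumption: population standard deviation
            bounds[label] = (mean - self.std_threshold * std, mean + self.std_threshold * std)
        return bounds

    def fit(self, data: dict[str, list[float]]) -> ClipOutliersSketch:
        if self.use_fit:
            self.bounds_ = self._bounds(data)  # learned once, reused in transform
        return self

    def transform(self, data: dict[str, list[float]]) -> dict[str, list[float]]:
        # Without use_fit, bounds come from the data being transformed.
        bounds = self.bounds_ if self.use_fit else self._bounds(data)
        return {
            label: [min(max(v, bounds[label][0]), bounds[label][1]) for v in values]
            for label, values in data.items()
        }


train = {"ch0": [1.0] * 9 + [100.0]}
clipped = ClipOutliersSketch(std_threshold=2.0).fit(train).transform(train)
fitted = ClipOutliersSketch(std_threshold=2.0, use_fit=True).fit(train)
```

With use_fit=True the bounds learned in fit are reused on new data, which is why the constructor sets is_stateful from use_fit: a fitted clipper has to be refit inside each fold.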

DemeanTransform

DemeanTransform(per_dim: str | list[str] | tuple[str, ...], use_fit: bool = False, fit_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict[str, Any] | None = None, transform_drop_sel: dict[str, Any] | None = None)

Bases: Transform

Subtract a mean independently per selected dimension label.

per_dim names the dimensions whose labels remain distinct while the mean is computed over all other dimensions. For data with dimensions ("trial", "channel", "time"), per_dim="channel" subtracts a separate channel mean computed across trials and time. Unlike xarray's dim=, per_dim does not name dimensions to reduce. When multiple dimensions are provided, statistics are computed separately for each coordinate tuple across those dimensions.

By default the mean is computed from the data being transformed. With use_fit=True, the mean is learned during fit and reused during transform, which is the usual choice inside cross-validation.
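The per_dim grouping and the use_fit=True path can be sketched with plain records of (coordinates, value). Every dimension not listed in per_dim is averaged over, and each coordinate tuple across the per_dim dimensions gets its own mean. The record layout and the helper names are hypothetical.

```python
from __future__ import annotations

from collections import defaultdict

Record = tuple[dict[str, str], float]


def group_key(coords: dict[str, str], per_dim: tuple[str, ...]) -> tuple[str, ...]:
    # One group per coordinate tuple across the per_dim dimensions.
    return tuple(coords[d] for d in per_dim)


def fit_means(records: list[Record], per_dim: tuple[str, ...]) -> dict[tuple[str, ...], float]:
    sums: dict[tuple[str, ...], float] = defaultdict(float)
    counts: dict[tuple[str, ...], int] = defaultdict(int)
    for coords, value in records:
        key = group_key(coords, per_dim)
        sums[key] += value
        counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}


def demean(records: list[Record], means: dict[tuple[str, ...], float], per_dim: tuple[str, ...]) -> list[Record]:
    return [(coords, value - means[group_key(coords, per_dim)]) for coords, value in records]


per_dim = ("channel", "session")
train = [
    ({"channel": "c0", "session": "s1", "trial": "t0"}, 1.0),
    ({"channel": "c0", "session": "s1", "trial": "t1"}, 3.0),
    ({"channel": "c0", "session": "s2", "trial": "t0"}, 10.0),
    ({"channel": "c1", "session": "s1", "trial": "t0"}, 5.0),
]
means = fit_means(train, per_dim)  # learned on the training fold (use_fit=True)
test = demean([({"channel": "c0", "session": "s1", "trial": "t9"}, 4.0)], means, per_dim)
```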

Initializes the DemeanTransform.

Parameters:

Name Type Description Default
per_dim str | list[str] | tuple[str, ...]

Dimension or dimensions to keep distinct while computing the mean.

required
use_fit bool

Whether to use the fit data to compute the mean.

False
fit_sel dict[str, Any] | None

A dictionary to select a subset of data for fitting. Useful if you want to demean in reference to a subset of the data. If specified, use_fit will be set to True.

None
sel dict[str, Any] | None

A dictionary to select a subset of data for transformation.

None
drop_sel dict[str, Any] | None

A dictionary to drop a subset of data for transformation.

None
transform_sel dict[str, Any] | None

A dictionary to select a subset of data for transformation, leaving the rest untouched.

None
transform_drop_sel dict[str, Any] | None

A dictionary to select a subset of data for transformation by excluding labels, leaving the rest untouched.

None
Source code in xdflow/transforms/normalization.py
def __init__(
    self,
    per_dim: str | list[str] | tuple[str, ...],
    use_fit: bool = False,
    fit_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict[str, Any] | None = None,
    transform_drop_sel: dict[str, Any] | None = None,
):
    """
    Initializes the DemeanTransform.

    Args:
        per_dim: Dimension or dimensions to keep distinct while computing the mean.
        use_fit: Whether to use the fit data to compute the mean.
        fit_sel: A dictionary to select a subset of data for fitting.
            Useful if you want to demean in reference to a subset of the data.
            If specified, use_fit will be set to True.
        sel: A dictionary to select a subset of data for transformation.
        drop_sel: A dictionary to drop a subset of data for transformation.
        transform_sel: A dictionary to select a subset of data for transformation.
        transform_drop_sel: A dictionary to drop a subset of data for transformation.
    """

    super().__init__(sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel)
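    self.per_dim = per_dim
    self.use_fit = use_fit
    self.fit_sel = fit_sel

    if self.fit_sel is not None and not self.use_fit:
        warnings.warn("fit_sel is specified but use_fit is False. use_fit will be set to True.")
        self.use_fit = True

    # Set after the fit_sel override; otherwise fit_sel without use_fit=True would
    # learn statistics in fit while being reported as stateless.
    self.is_stateful = self.use_fit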
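With fit_sel, the reference mean comes from a labeled subset, but the subtraction still applies to all of the data. The NumPy sketch below illustrates this under stated assumptions. The per-trial labels array stands in for a coordinate that fit_sel would select on, and demean_with_fit_sel is a hypothetical helper that is not part of xdflow.

```python
import numpy as np


def demean_with_fit_sel(data, dims, labels, per_dim, fit_label):
    """Demean `data` relative to the trials whose label equals fit_label.

    Hypothetical helper: `labels` is a 1-D array of per-trial labels, standing in
    for a coordinate selected with something like fit_sel={"condition": ...}.
    """
    keep = {per_dim} if isinstance(per_dim, str) else set(per_dim)
    reduce_axes = tuple(i for i, d in enumerate(dims) if d not in keep)
    trial_axis = dims.index("trial")
    # Select the reference subset, as fit_sel would before fitting
    reference = np.compress(labels == fit_label, data, axis=trial_axis)
    mean = reference.mean(axis=reduce_axes, keepdims=True)
    # The subset mean is subtracted from *all* trials
    return data - mean


dims = ("trial", "channel", "time")
data = np.arange(4 * 2 * 3, dtype=float).reshape(4, 2, 3)
labels = np.array(["baseline", "stim", "baseline", "stim"])
out = demean_with_fit_sel(data, dims, labels, "channel", "baseline")
```

After this step, the baseline trials are centered per channel. The other trials keep their offset relative to the baseline.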

ZScoreTransform

ZScoreTransform(per_dim: str | list[str] | tuple[str, ...], use_fit: bool = False, fit_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict[str, Any] | None = None, transform_drop_sel: dict[str, Any] | None = None)

Bases: Transform

Apply z-score normalization independently per selected dimension label.

per_dim names the dimensions whose labels remain distinct while the mean and standard deviation are computed over all other dimensions. For data with dimensions ("trial", "channel", "time"), per_dim="channel" normalizes each channel using statistics computed across trials and time. Unlike xarray's dim=, per_dim does not name dimensions to reduce. When multiple dimensions are provided, statistics are computed separately for each coordinate tuple across those dimensions.
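When per_dim names several dimensions, statistics are computed separately for each coordinate tuple. The NumPy sketch below, which is not the ZScoreTransform code, shows this for per_dim=("channel", "band"). The helper zscore_per_dim, the "band" dimension, and the eps guard for constant slices are assumptions.

```python
import numpy as np


def zscore_per_dim(data, dims, per_dim, eps=1e-12):
    """Z-score each coordinate tuple of the per_dim dimensions separately."""
    keep = {per_dim} if isinstance(per_dim, str) else set(per_dim)
    reduce_axes = tuple(i for i, d in enumerate(dims) if d not in keep)
    mean = data.mean(axis=reduce_axes, keepdims=True)
    std = data.std(axis=reduce_axes, keepdims=True)
    # eps guards against division by zero for constant slices (an assumption;
    # the source does not say how ZScoreTransform handles zero variance)
    return (data - mean) / np.maximum(std, eps)


dims = ("trial", "channel", "band", "time")
rng = np.random.default_rng(1)
data = rng.normal(loc=3.0, scale=2.0, size=(10, 3, 4, 20))
# Statistics are computed for every (channel, band) pair over trial and time
z = zscore_per_dim(data, dims, ("channel", "band"))
```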

By default statistics are computed from the data being transformed. With use_fit=True, statistics are learned during fit and reused during transform, which avoids validation leakage inside cross-validation.
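With use_fit=True, the mean and standard deviation become fold-specific state. The constructor reflects this by setting is_stateful from use_fit, so the step is refit inside each cross-validation fold. The sketch below imitates that loop in NumPy. FoldZScore and the four-fold split are hypothetical and are not xdflow's validator.

```python
import numpy as np


class FoldZScore:
    """Hypothetical stateful z-score: mean and std are learned in fit only."""

    def __init__(self, reduce_axes):
        self.reduce_axes = reduce_axes

    def fit(self, train):
        self.mean_ = train.mean(axis=self.reduce_axes, keepdims=True)
        self.std_ = train.std(axis=self.reduce_axes, keepdims=True)
        return self

    def transform(self, data):
        # Reuse the training statistics; held-out data never updates them
        return (data - self.mean_) / self.std_


rng = np.random.default_rng(2)
data = rng.normal(size=(20, 3, 30))  # axes: (trial, channel, time)
all_trials = np.arange(data.shape[0])
for test_idx in np.array_split(all_trials, 4):
    train_idx = np.setdiff1d(all_trials, test_idx)
    # Refit in every fold so held-out trials never contribute to the statistics
    scaler = FoldZScore(reduce_axes=(0, 2)).fit(data[train_idx])
    test_z = scaler.transform(data[test_idx])
```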

Initialize a z-score transform.

Parameters:

Name Type Description Default
per_dim str | list[str] | tuple[str, ...]

Dimension or dimensions to keep distinct while computing the mean and std.

required
use_fit bool

Whether to use the fit data to compute the mean and std.

False
fit_sel dict[str, Any] | None

A dictionary to select a subset of data for fitting. Useful if you want to z-score in reference to a subset of the data. If specified, use_fit will be set to True.

None
sel dict[str, Any] | None

A dictionary to select a subset of data for transformation.

None
drop_sel dict[str, Any] | None

A dictionary to drop a subset of data for transformation.

None
transform_sel dict[str, Any] | None

A dictionary to select a subset of data for transformation.

None
transform_drop_sel dict[str, Any] | None

A dictionary to drop a subset of data for transformation.

None
Source code in xdflow/transforms/normalization.py
def __init__(
    self,
    per_dim: str | list[str] | tuple[str, ...],
    use_fit: bool = False,
    fit_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict[str, Any] | None = None,
    transform_drop_sel: dict[str, Any] | None = None,
):
    """Initialize a z-score transform.

    Args:
        per_dim: Dimension or dimensions to keep distinct while computing the mean and std.
        use_fit: Whether to use the fit data to compute the mean and std.
        fit_sel: A dictionary to select a subset of data for fitting.
            Useful if you want to zscore in reference to a subset of the data.
            If specified, use_fit will be set to True.
        sel: A dictionary to select a subset of data for transformation.
        drop_sel: A dictionary to drop a subset of data for transformation.
        transform_sel: A dictionary to select a subset of data for transformation.
        transform_drop_sel: A dictionary to drop a subset of data for transformation.
    """

    super().__init__(sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel)
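    self.per_dim = per_dim
    self.use_fit = use_fit
    self.fit_sel = fit_sel

    if self.fit_sel is not None and not self.use_fit:
        warnings.warn("fit_sel is specified but use_fit is False. use_fit will be set to True.")
        self.use_fit = True

    # Set after the fit_sel override; otherwise fit_sel without use_fit=True would
    # learn statistics in fit while being reported as stateless.
    self.is_stateful = self.use_fit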

Sklearn Adapters

SKLearnTransform

SKLearnTransform(estimator_cls: type[BaseEstimator], sample_dim: str, target_coord: str | None = None, _estimator_instance: BaseEstimator | None = None, sample_weight_coord: str | None = 'sample_weight', sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs: Any)

Bases: Transform, SampleWeightMixin

Adapt a scikit-learn estimator to the XDFlow transform API.

The wrapped estimator receives a two-dimensional matrix with samples along sample_dim. If target_coord is provided, the coordinate values are extracted and passed as y during fit; otherwise the estimator is treated as unsupervised. Sample weights can be forwarded from a coordinate when the estimator's fit method accepts sample_weight.

Keyword arguments not owned by the wrapper or its parent classes are passed to the estimator constructor and preserved for cloning.
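The wrapped estimator expects an (n_samples, n_features) matrix. The NumPy sketch below shows what that layout means for data with more axes. It illustrates the layout only and is not the wrapper's code. get_expected_output_dims in the adapters below rejects inputs that are not two-dimensional, so assume that any flattening like this happens in an earlier pipeline step. to_sample_matrix is hypothetical.

```python
import numpy as np


def to_sample_matrix(data, dims, sample_dim):
    """Return an (n_samples, n_features) array: sample_dim first, other axes flattened."""
    axis = dims.index(sample_dim)
    moved = np.moveaxis(data, axis, 0)
    return moved.reshape(moved.shape[0], -1)


dims = ("channel", "trial", "time")
data = np.arange(3 * 8 * 5).reshape(3, 8, 5)
X = to_sample_matrix(data, dims, "trial")  # shape (8, 15)
```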

Notes

This class participates in cooperative multiple inheritance with Predictor through SKLearnPredictor. Wrapper hyperparameters should be explicit attributes; estimator kwargs are stored separately in _estimator_kwargs.

Initialize the wrapper and the underlying estimator.

Parameters:

Name Type Description Default
estimator_cls type[BaseEstimator]

Uninitialized scikit-learn estimator class.

required
sample_dim str

Dimension whose entries are samples.

required
target_coord str | None

Optional coordinate used as supervised target y.

None
_estimator_instance BaseEstimator | None

Pre-built estimator used internally by SKLearnPredictor after task-type wrapping.

None
sample_weight_coord str | None

Coordinate containing optional sample weights. Set to None to disable sample-weight forwarding.

'sample_weight'
sel dict[str, Any] | None

Label selection applied before fitting or transforming.

None
drop_sel dict[str, Any] | None

Label selection dropped before fitting or transforming.

None
**kwargs Any

Parent-class options and estimator constructor arguments.

{}
Source code in xdflow/transforms/sklearn_transform.py
def __init__(
    self,
    estimator_cls: type[BaseEstimator],
    sample_dim: str,
    target_coord: str | None = None,
    _estimator_instance: BaseEstimator
    | None = None,  # SKLearnPredictor needs to instantiate the estimator; only used internally, not by user
    sample_weight_coord: str | None = "sample_weight",
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs: Any,
):
    """Initialize the wrapper and the underlying estimator.

    Args:
        estimator_cls: Uninitialized scikit-learn estimator class.
        sample_dim: Dimension whose entries are samples.
        target_coord: Optional coordinate used as supervised target `y`.
        _estimator_instance: Pre-built estimator used internally by
            `SKLearnPredictor` after task-type wrapping.
        sample_weight_coord: Coordinate containing optional sample weights.
            Set to None to disable sample-weight forwarding.
        sel: Label selection applied before fitting or transforming.
        drop_sel: Label selection dropped before fitting or transforming.
        **kwargs: Parent-class options and estimator constructor arguments.
    """
    # For cooperative inheritance, we pass all kwargs up the chain.
    # for SKLearnPredictor, SKLearnTransform is initialized first, then Predictor.
    super().__init__(sample_dim=sample_dim, target_coord=target_coord, sel=sel, drop_sel=drop_sel, **kwargs)

    parent_param_names = collect_super_init_param_names(type(self), SKLearnTransform)
    self.estimator_cls = estimator_cls  # needed for clone
    self.sample_weight_coord = sample_weight_coord
    self._fit_param_support_cache: dict[str, bool] = {}

    # Extract estimator-specific parameters (everything not used by Transform or Predictor)
    if _estimator_instance is not None:
        self.estimator: Any = _estimator_instance
        self._estimator_kwargs = getattr(self, "_estimator_kwargs", None)
        if self._estimator_kwargs is None:
            self._estimator_kwargs = {k: v for k, v in kwargs.items() if k not in parent_param_names}
    else:
        self._estimator_kwargs = {k: v for k, v in kwargs.items() if k not in parent_param_names}
        self.estimator: Any = estimator_cls(**self._estimator_kwargs)

    if not hasattr(self.estimator, "fit"):
        raise TypeError(
            f"The provided estimator class '{estimator_cls.__name__}' must produce an object with a 'fit' method."
        )

    self.sample_dim = sample_dim
    self.target_coord = target_coord
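The split between wrapper options and estimator constructor arguments can be approximated with inspect.signature. collect_super_init_param_names is not documented on this page, so init_param_names below is a hypothetical stand-in, and Base plays the role of the parent transform classes.

```python
import inspect


class Base:
    """Hypothetical parent class that owns the sel/drop_sel options."""

    def __init__(self, sel=None, drop_sel=None):
        self.sel = sel
        self.drop_sel = drop_sel


def init_param_names(cls):
    """Collect named __init__ parameters across cls and its bases (excluding object)."""
    names = set()
    for klass in cls.__mro__:
        init = klass.__dict__.get("__init__")
        if klass is object or init is None:
            continue
        for name, param in inspect.signature(init).parameters.items():
            # Skip self and the *args/**kwargs catch-alls
            if name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                names.add(name)
    return names


def split_kwargs(kwargs, owner_cls):
    """Route kwargs owned by owner_cls to the wrapper, the rest to the estimator."""
    owned = init_param_names(owner_cls)
    wrapper = {k: v for k, v in kwargs.items() if k in owned}
    estimator = {k: v for k, v in kwargs.items() if k not in owned}
    return wrapper, estimator


wrapper_kw, estimator_kw = split_kwargs(
    {"sel": {"trial": [0, 1]}, "alpha": 0.5, "fit_intercept": False}, Base
)
```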

get_params

get_params(deep: bool = True) -> dict[str, Any]

Get parameters for this transform, including those of the wrapped estimator, which are merged into the same flat dictionary. This is part of the scikit-learn estimator API.

Source code in xdflow/transforms/sklearn_transform.py
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """
    Get parameters for this transform, including the wrapped estimator.
    This is part of the scikit-learn estimator API.
    """
    # Get parameters from the wrapper itself
    params = super().get_params(deep=deep)
    if hasattr(self, "estimator"):
        # Get parameters from the wrapped estimator
        estimator_params = self.estimator.get_params(deep=deep)
        params.update(estimator_params)
    return params

set_params

set_params(**params: Any) -> SKLearnTransform

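Set the parameters of this transform and its wrapped estimator. Keys that appear in estimator.get_params(deep=True) are forwarded to the estimator; all remaining keys are set on the wrapper. This is part of the scikit-learn estimator API.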

Source code in xdflow/transforms/sklearn_transform.py
def set_params(self, **params: Any) -> "SKLearnTransform":
    """
    Set the parameters of this transform and its wrapped estimator.
    This is part of the scikit-learn estimator API.
    """
    estimator_params = {}
    wrapper_params = {}

    # Separate parameters for the estimator and the wrapper
    for key, value in params.items():
        if hasattr(self, "estimator") and key in self.estimator.get_params(deep=True):
            estimator_params[key] = value
        else:
            wrapper_params[key] = value

    # Set parameters on the estimator
    if hasattr(self, "estimator") and estimator_params:
        self.estimator.set_params(**estimator_params)

    # Set parameters on the wrapper
    super().set_params(**wrapper_params)

    return self
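A stdlib-only sketch can check the routing rule in set_params: estimator parameters are matched first, and everything else goes to the wrapper. ToyEstimator is a hypothetical stand-in for a scikit-learn estimator, and route_params reproduces only the splitting loop above.

```python
class ToyEstimator:
    """Hypothetical stand-in for a scikit-learn estimator."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self, deep=True):
        return {"alpha": self.alpha}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self


def route_params(estimator, params):
    """Split params like SKLearnTransform.set_params: estimator keys are matched first."""
    estimator_keys = estimator.get_params(deep=True)
    estimator_params = {k: v for k, v in params.items() if k in estimator_keys}
    wrapper_params = {k: v for k, v in params.items() if k not in estimator_keys}
    return estimator_params, wrapper_params


est = ToyEstimator()
est_params, wrap_params = route_params(est, {"alpha": 0.1, "sample_dim": "trial"})
est.set_params(**est_params)
```

The estimator check runs first. As a result, a wrapper attribute that shares its name with an estimator parameter is routed to the estimator.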

__setattr__

__setattr__(name: str, value: Any) -> None

Proxy attribute assignment to the wrapped estimator.

If name is a parameter of the wrapped estimator (a key of estimator.get_params()), the value is set on the estimator. Otherwise, it is set on the wrapper.

Source code in xdflow/transforms/sklearn_transform.py
def __setattr__(self, name: str, value: Any) -> None:
    """
    Override setattr to act as a proxy.

    If the attribute is a parameter of the underlying estimator, set it there.
    Otherwise, set it on the wrapper.
    """
    if hasattr(self, "estimator") and name in self.estimator.get_params():
        setattr(self.estimator, name, value)
    else:
        super().__setattr__(name, value)

__hasattr__

__hasattr__(name: str) -> bool

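Check whether name exists on the wrapper or on the wrapped estimator. Python has no __hasattr__ hook: the built-in hasattr() never calls this method, so the check runs only when __hasattr__ is invoked explicitly.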

Source code in xdflow/transforms/sklearn_transform.py
def __hasattr__(self, name: str) -> bool:
    """
    Override hasattr to check if the parameter exists either on the wrapper or the estimator.
    """
    if name in ["estimator", "sample_dim", "target_coord"]:
        return name in self.__dict__ or hasattr(type(self), name)
    return hasattr(self.estimator, name) or name in self.__dict__ or hasattr(type(self), name)
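The built-in hasattr() calls getattr() and returns False if AttributeError is raised; it does not consult a __hasattr__ method. The stdlib-only check below demonstrates this. The Proxy class is hypothetical.

```python
class Proxy:
    """Hypothetical class that defines __hasattr__ and counts its calls."""

    def __init__(self):
        self.calls = 0

    def __hasattr__(self, name):
        self.calls += 1
        return True


p = Proxy()
# Built-in hasattr() calls getattr() and catches AttributeError;
# it never looks up __hasattr__
builtin_result = hasattr(p, "missing")
# The method only runs when called explicitly
explicit_result = p.__hasattr__("missing")
```

To make hasattr(wrapper, name) see estimator attributes, the standard hook is __getattr__, which Python calls only when normal attribute lookup fails.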

SKLearnTransformer

SKLearnTransformer(estimator_cls: type[BaseEstimator], sample_dim: str, target_coord: str | None = None, output_dim_name: str = 'component', sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs: Any)

Bases: SKLearnTransform

Adapt a scikit-learn transformer to return a DataContainer.

The estimator must implement fit and transform. Input data is arranged as (sample_dim, features). The transformed matrix is returned with sample_dim preserved and a new feature-like dimension named by output_dim_name.
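The contract above is that a (sample_dim, features) matrix goes in, a (sample_dim, output_dim_name) matrix comes out, and the estimator decides how many output columns there are. In the sketch below, a NumPy PCA stands in for a scikit-learn transformer so that the example runs without scikit-learn; this substitution is an assumption. The feature axis is replaced rather than renamed.

```python
import numpy as np


def pca_fit_transform(X, n_components):
    """Project X (n_samples, n_features) onto its leading principal axes."""
    Xc = X - X.mean(axis=0)
    # Rows of Vt are principal axes, ordered by decreasing singular value
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T


sample_dim, output_dim_name = "trial", "component"
rng = np.random.default_rng(3)
X = rng.normal(size=(12, 6))              # axes: (trial, feature)
Z = pca_fit_transform(X, n_components=2)  # axes: (trial, component)
out_dims = (sample_dim, output_dim_name)  # what get_expected_output_dims reports
```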

Initialize a transformer wrapper.

Parameters:

Name Type Description Default
estimator_cls type[BaseEstimator]

Uninitialized scikit-learn transformer class.

required
sample_dim str

Dimension whose entries are samples.

required
target_coord str | None

Optional supervised target coordinate for estimators whose fit accepts y.

None
output_dim_name str

Name of the non-sample output dimension.

'component'
sel dict[str, Any] | None

Label selection applied before fitting or transforming.

None
drop_sel dict[str, Any] | None

Label selection dropped before fitting or transforming.

None
**kwargs Any

Parent-class options and estimator constructor arguments.

{}
Source code in xdflow/transforms/sklearn_transform.py
def __init__(
    self,
    estimator_cls: type[BaseEstimator],
    sample_dim: str,
    target_coord: str | None = None,
    output_dim_name: str = "component",
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs: Any,
):
    """Initialize a transformer wrapper.

    Args:
        estimator_cls: Uninitialized scikit-learn transformer class.
        sample_dim: Dimension whose entries are samples.
        target_coord: Optional supervised target coordinate for estimators
            whose `fit` accepts `y`.
        output_dim_name: Name of the non-sample output dimension.
        sel: Label selection applied before fitting or transforming.
        drop_sel: Label selection dropped before fitting or transforming.
        **kwargs: Parent-class options and estimator constructor arguments.
    """
    super().__init__(
        estimator_cls=estimator_cls,
        sample_dim=sample_dim,
        target_coord=target_coord,
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )
    self.output_dim_name = output_dim_name
    if not hasattr(self.estimator, "transform"):
        raise TypeError(
            f"The provided estimator class '{self.estimator.__class__.__name__}' must have a 'transform' method."
        )

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Return the expected output dimensions, (sample_dim, output_dim_name). Raises ValueError if the input does not have exactly two dimensions.

Source code in xdflow/transforms/sklearn_transform.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """Determines the expected output dimensions."""
    if len(input_dims) != 2:
        raise ValueError(f"Expected 2 input dimensions, but got {len(input_dims)}")
    return (self.sample_dim, self.output_dim_name)

SKLearnPredictor

SKLearnPredictor(estimator_cls: type[BaseEstimator], sample_dim: str, target_coord: str | list[str], encoder: LabelEncoder | None = None, proba: bool = False, is_classifier: bool | None = None, multi_output: bool = False, is_multilabel: bool = False, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, sample_weight_coord: str | None = 'sample_weight', **kwargs: Any)

Bases: SKLearnTransform, Predictor

Adapt a scikit-learn estimator to the XDFlow predictor API.

The estimator is fitted on a two-dimensional matrix plus target coordinate values, then exposed through predict, predict_proba, and transform. Classifier or regressor mode is auto-detected from the estimator when possible, or can be supplied with is_classifier.

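Multi-target regression and multilabel classification are handled by wrapping the estimator in scikit-learn's MultiOutputRegressor or MultiOutputClassifier. Set multi_output=True to request wrapping explicitly. Wrapping is also enabled automatically when a class is passed as estimator_cls, target_coord lists more than one coordinate, and the task is regression or is_multilabel=True. Combining multi_output=True with a single-label classifier raises ValueError.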
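The task-type and multi-output decisions in the constructor below follow a fixed order, sketched here in plain Python. resolve_task is a hypothetical, simplified mirror of that logic. Its detected argument stands in for sklearn.base.is_classifier and is_regressor applied to the estimator instance, and the inspect.isclass(estimator_cls) check is omitted.

```python
def resolve_task(detected, is_classifier=None, is_multilabel=False,
                 multi_output=False, target_coord="y"):
    """Simplified mirror of the task-type logic in SKLearnPredictor.__init__.

    detected is "classifier", "regressor", or None. Returns
    (is_classifier, multi_output).
    """
    if is_classifier is None:
        if is_multilabel or detected == "classifier":
            is_classifier = True
        elif detected == "regressor":
            is_classifier = False
        else:
            raise ValueError("Could not auto-detect task type; set is_classifier explicitly.")

    # Several target coordinates imply multi-output only for regression or multilabel tasks
    has_multiple_targets = isinstance(target_coord, (list, tuple)) and len(target_coord) > 1
    if not multi_output and has_multiple_targets and ((not is_classifier) or is_multilabel):
        multi_output = True

    if multi_output and is_classifier and not is_multilabel:
        raise ValueError("multi_output=True with a classifier requires is_multilabel=True.")
    return is_classifier, multi_output
```

A single-label classifier given several target coordinates is not wrapped automatically, because the inference step applies only to regression and multilabel tasks.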

Initialize a predictor wrapper.

Parameters:

Name Type Description Default
estimator_cls type[BaseEstimator]

Uninitialized scikit-learn estimator class.

required
sample_dim str

Dimension whose entries are samples.

required
target_coord str | list[str]

Target coordinate name, target coordinate list, or wildcard pattern resolved during fit.

required
encoder LabelEncoder | None

Optional label encoder for single-label classifiers.

None
proba bool

Whether transform should call predict_proba instead of predict.

False
is_classifier bool | None

Explicitly set classifier or regressor mode. If None, task type is inferred from the estimator.

None
multi_output bool

Whether to wrap the estimator for multi-target regression or multilabel classification.

False
is_multilabel bool

Whether targets are multiple binary coordinates.

False
sel dict[str, Any] | None

Label selection applied before fitting or transforming.

None
drop_sel dict[str, Any] | None

Label selection dropped before fitting or transforming.

None
sample_weight_coord str | None

Coordinate containing optional sample weights.

'sample_weight'
**kwargs Any

Estimator constructor arguments plus parent-class options.

{}
Source code in xdflow/transforms/sklearn_transform.py
def __init__(
    self,
    estimator_cls: type[BaseEstimator],
    sample_dim: str,
    target_coord: str | list[str],
    encoder: LabelEncoder | None = None,
    proba: bool = False,
    is_classifier: bool | None = None,
    multi_output: bool = False,
    is_multilabel: bool = False,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    sample_weight_coord: str | None = "sample_weight",
    **kwargs: Any,
):
    """Initialize a predictor wrapper.

    Args:
        estimator_cls: Uninitialized scikit-learn estimator class.
        sample_dim: Dimension whose entries are samples.
        target_coord: Target coordinate name, target coordinate list, or
            wildcard pattern resolved during fit.
        encoder: Optional label encoder for single-label classifiers.
        proba: Whether `transform` should call `predict_proba` instead of
            `predict`.
        is_classifier: Explicitly set classifier or regressor mode. If None,
            task type is inferred from the estimator.
        multi_output: Whether to wrap the estimator for multi-target
            regression or multilabel classification.
        is_multilabel: Whether targets are multiple binary coordinates.
        sel: Label selection applied before fitting or transforming.
        drop_sel: Label selection dropped before fitting or transforming.
        sample_weight_coord: Coordinate containing optional sample weights.
        **kwargs: Estimator constructor arguments plus parent-class options.
    """
    parent_param_names = collect_super_init_param_names(type(self), SKLearnTransform)
    # Separate kwargs for the estimator from the rest
    self._estimator_kwargs = {k: v for k, v in kwargs.items() if k not in parent_param_names}

    # Store the original estimator class before potential wrapping
    self._base_estimator_cls = estimator_cls

    _base_instance = estimator_cls(**self._estimator_kwargs)

    # Auto-detect or use manual override before deciding on multi-output wrapping.
    if is_classifier is None:
        if is_multilabel:
            is_classifier = True
        elif sklearn_is_classifier(_base_instance):
            is_classifier = True
        elif sklearn_is_regressor(_base_instance):
            is_classifier = False
        else:
            raise ValueError(
                f"Could not auto-detect task type for {self._base_estimator_cls.__name__}. "
                "Please explicitly specify is_classifier=True or is_classifier=False."
            )

    has_multiple_targets = isinstance(target_coord, (list, tuple)) and len(target_coord) > 1
    supports_multi_output_task = (not is_classifier) or is_multilabel
    if not multi_output and inspect.isclass(estimator_cls) and has_multiple_targets and supports_multi_output_task:
        multi_output = True

    self.multi_output = multi_output

    if multi_output:
        from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor

        from xdflow.transforms.multi_output_wrapper import (
            MultiOutputClassifierFactory,
            MultiOutputRegressorFactory,
        )

        _already_multioutput = estimator_cls in (MultiOutputRegressor, MultiOutputClassifier) or isinstance(
            estimator_cls, (MultiOutputRegressorFactory, MultiOutputClassifierFactory)
        )
        if not _already_multioutput:
            original_cls = estimator_cls
            factory = MultiOutputClassifierFactory if is_multilabel else MultiOutputRegressorFactory
            estimator_cls = factory(cast(type[BaseEstimator], original_cls))
            if kwargs.get("verbose", False):
                print(
                    f"[SKLearnPredictor] Wrapping {original_cls.__name__} with "
                    f"{'MultiOutputClassifier' if is_multilabel else 'MultiOutputRegressor'} "
                    "for multi-target prediction"
                )

        _estimator_instance = estimator_cls(**self._estimator_kwargs)
    else:
        _estimator_instance = _base_instance

    if multi_output and is_classifier and not is_multilabel:
        raise ValueError(
            "multi_output=True with is_classifier=True requires is_multilabel=True. "
            "Multi-output classification is only supported in multilabel mode."
        )

    if proba and not hasattr(_estimator_instance, "predict_proba"):
        raise AttributeError(
            f"Estimator '{_estimator_instance.__class__.__name__}' has no method 'predict_proba' but 'proba' was set to True."
        )

    # Correctly initialize both parent classes for cooperative multiple inheritance
    super().__init__(
        estimator_cls=estimator_cls,
        _estimator_instance=_estimator_instance,
        sample_dim=sample_dim,
        target_coord=target_coord,
        encoder=encoder,
        is_classifier=is_classifier,
        is_multilabel=is_multilabel,
        proba=proba,
        sel=sel,
        drop_sel=drop_sel,
        sample_weight_coord=sample_weight_coord,
        **kwargs,
    )

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

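Determine the expected output dimensions of the .transform() method. The input must have exactly two dimensions; otherwise ValueError is raised. The output is (sample_dim, "class") when proba=True for a single-label classifier, (sample_dim, "target") for multi-target or multilabel prediction, and (sample_dim, "prediction") otherwise. The public .predict() method will always produce a 1D output.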

Source code in xdflow/transforms/sklearn_transform.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Determines the expected output dimensions for the `.transform()` method.
    The public `.predict()` method will always produce a 1D output.
    """
    if len(input_dims) != 2:
        raise ValueError(f"Expected 2 input dimensions, but got {len(input_dims)}")

    if self.proba and not self.is_multilabel:
        return (self.sample_dim, "class")
    elif self.is_multi_target or self.is_multilabel:
        return (self.sample_dim, "target")
    else:
        return (self.sample_dim, "prediction")

MultiOutputRegressorFactory

MultiOutputRegressorFactory(base_estimator_cls: type[BaseEstimator])

Picklable factory that wraps any sklearn-compatible estimator in MultiOutputRegressor.

This is needed because lambda functions cannot be pickled, which breaks caching. Using this class allows sklearn estimators to work with multi-target regression while maintaining compatibility with pickle/caching.
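The pickling problem and the factory fix can be reproduced with the standard library alone. Wrapper and Ridgeish below are hypothetical stand-ins for MultiOutputRegressor and a base estimator, so the sketch runs without scikit-learn. WrapperFactory follows the same __call__/__reduce__ pattern as the factory documented here.

```python
import pickle


class Ridgeish:
    """Hypothetical stand-in for a base estimator class."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha


class Wrapper:
    """Hypothetical stand-in for MultiOutputRegressor."""

    def __init__(self, estimator):
        self.estimator = estimator


class WrapperFactory:
    """Picklable callable: stores the class, builds the wrapped instance on call."""

    def __init__(self, base_cls):
        self.base_cls = base_cls

    def __call__(self, **kwargs):
        return Wrapper(self.base_cls(**kwargs))

    def __reduce__(self):
        # Rebuild from the class reference alone, as the factories here do
        return (self.__class__, (self.base_cls,))


lambda_factory = lambda **kw: Wrapper(Ridgeish(**kw))  # noqa: E731
try:
    pickle.dumps(lambda_factory)
    lambda_pickles = True
except (pickle.PicklingError, AttributeError, TypeError):
    # Functions are pickled by qualified name, and "<lambda>" cannot be looked up
    lambda_pickles = False

restored = pickle.loads(pickle.dumps(WrapperFactory(Ridgeish)))
model = restored(alpha=0.5)
```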

Parameters

base_estimator_cls : Type[BaseEstimator]
    The sklearn-compatible estimator class to wrap (e.g., LGBMRegressor, Ridge).

Initialize the factory with a base estimator class.

Parameters

base_estimator_cls : Type[BaseEstimator]
    The estimator class to wrap (not an instance).

Source code in xdflow/transforms/multi_output_wrapper.py
def __init__(self, base_estimator_cls: type[BaseEstimator]):
    """
    Initialize the factory with a base estimator class.

    Parameters
    ----------
    base_estimator_cls : Type[BaseEstimator]
        The estimator class to wrap (not an instance)
    """
    self.base_estimator_cls = base_estimator_cls

__call__

__call__(**kwargs) -> MultiOutputRegressor

Create a MultiOutputRegressor wrapping the base estimator.

Parameters

**kwargs
    All keyword arguments are passed to the base estimator constructor.

Returns

MultiOutputRegressor
    The wrapped estimator, ready for multi-target regression.

Source code in xdflow/transforms/multi_output_wrapper.py
def __call__(self, **kwargs) -> MultiOutputRegressor:
    """
    Create a MultiOutputRegressor wrapping the base estimator.

    Parameters
    ----------
    **kwargs
        All keyword arguments are passed to the base estimator constructor

    Returns
    -------
    MultiOutputRegressor
        The wrapped estimator ready for multi-target regression
    """
    base_estimator = self.base_estimator_cls(**kwargs)
    return MultiOutputRegressor(base_estimator)

__repr__

__repr__() -> str

Readable representation for debugging.

Source code in xdflow/transforms/multi_output_wrapper.py
def __repr__(self) -> str:
    """Readable representation for debugging."""
    return f"MultiOutputRegressorFactory({self.base_estimator_cls.__name__})"

__reduce__

__reduce__()

Support for pickling.

Source code in xdflow/transforms/multi_output_wrapper.py
def __reduce__(self):
    """Support for pickling."""
    return (self.__class__, (self.base_estimator_cls,))

MultiOutputClassifierFactory

MultiOutputClassifierFactory(base_estimator_cls: type[BaseEstimator])

Picklable factory that wraps any sklearn-compatible estimator in MultiOutputClassifier.

Source code in xdflow/transforms/multi_output_wrapper.py
def __init__(self, base_estimator_cls: type[BaseEstimator]):
    self.base_estimator_cls = base_estimator_cls

__repr__

__repr__() -> str

Readable representation for debugging.

Source code in xdflow/transforms/multi_output_wrapper.py
def __repr__(self) -> str:
    """Readable representation for debugging."""
    return f"MultiOutputClassifierFactory({self.base_estimator_cls.__name__})"

__reduce__

__reduce__()

Support for pickling.

Source code in xdflow/transforms/multi_output_wrapper.py
def __reduce__(self):
    """Support for pickling."""
    return (self.__class__, (self.base_estimator_cls,))

make_multi_output

make_multi_output(estimator_cls: type[BaseEstimator]) -> MultiOutputRegressorFactory

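Convenience function that returns a MultiOutputRegressorFactory wrapping estimator_cls. It always produces the regressor factory; for multilabel classification, use MultiOutputClassifierFactory.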

Parameters

estimator_cls : Type[BaseEstimator]
    The sklearn-compatible estimator class to wrap.

Source code in xdflow/transforms/multi_output_wrapper.py
def make_multi_output(estimator_cls: type[BaseEstimator]) -> MultiOutputRegressorFactory:
    """
    Convenience function to create a multi-output factory.

    Parameters
    ----------
    estimator_cls : Type[BaseEstimator]
        The sklearn-compatible estimator class to wrap
    """
    return MultiOutputRegressorFactory(estimator_cls)

Estimators And Predictors

NearestCentroidTransform

NearestCentroidTransform(n_components=None, use_priors=True)

Bases: BaseEstimator, TransformerMixin

Fisher-style discriminant transform under spherical within-class covariance (Σ = I). This yields the same subspace as LDA's .transform when shrinkage α=1 (nearest-centroid case).

Parameters

n_components : int or None (<= C-1)
    Number of components to keep. If None, uses C-1.
use_priors : bool
    If True, weight the overall mean by class priors (n_k / n). If False, unweighted.

Attributes (after fit)

classes_ : (C,)
means_ : (C, p)
priors_ : (C,)
scalings_ : (p, r)  # projection matrix
explained_variance_ratio_ : (r,)
mean_ : (p,)  # overall mean used to center before projecting
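When Σ = I, the Fisher criterion wᵀ S_B w / wᵀ w is maximized by the leading eigenvectors of the between-class scatter S_B = Σ_k w_k (μ_k − m)(μ_k − m)ᵀ, where m = Σ_k w_k μ_k. The NumPy sketch below computes these eigenvectors; it is not the nearestcentroid.py code. Three assumptions are made: S_B uses the same weights w_k as the overall mean (priors or uniform), explained_variance_ratio_ is each kept eigenvalue divided by the sum of all eigenvalues, and projection centers by the overall mean. nearest_centroid_subspace is hypothetical.

```python
import numpy as np


def nearest_centroid_subspace(X, y, n_components=None, use_priors=True):
    """Fisher discriminant directions with the within-class covariance set to I.

    Hypothetical sketch: returns (overall_mean, scalings, explained_variance_ratio).
    """
    classes, counts = np.unique(y, return_counts=True)
    C = len(classes)
    means = np.stack([X[y == c].mean(axis=0) for c in classes])      # (C, p)
    weights = counts / counts.sum() if use_priors else np.full(C, 1.0 / C)
    overall = weights @ means                                         # (p,)
    D = means - overall
    # Between-class scatter S_B = sum_k w_k (mu_k - m)(mu_k - m)^T, rank <= C - 1
    S_B = (D * weights[:, None]).T @ D                                # (p, p)
    # eigh returns eigenvalues in ascending order; sort them descending
    evals, evecs = np.linalg.eigh(S_B)
    order = np.argsort(evals)[::-1]
    evals = np.clip(evals[order], 0.0, None)
    r = C - 1 if n_components is None else n_components
    scalings = evecs[:, order[:r]]                                    # (p, r)
    ratio = evals[:r] / evals.sum()
    return overall, scalings, ratio


rng = np.random.default_rng(4)
centers = np.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [0.0, 4.0, 0.0]])
y = np.repeat([0, 1, 2], 30)
X = centers[y] + rng.normal(scale=0.5, size=(90, 3))
overall, scalings, ratio = nearest_centroid_subspace(X, y)
Z = (X - overall) @ scalings  # project onto the C - 1 = 2 discriminant axes
```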

Source code in xdflow/transforms/nearestcentroid.py
def __init__(self, n_components=None, use_priors=True):
    self.n_components = n_components
    self.use_priors = use_priors

NearestCentroid

NearestCentroid(sample_dim: str, target_coord: str, n_components=None, use_priors=True, output_dim_name: str = 'component', sel: dict | None = None, drop_sel: dict | None = None, **kwargs)

Bases: SKLearnTransformer

Transform wrapper for NearestCentroidTransform that can be used in pipelines.

Fisher-style discriminant transform under spherical within-class covariance (Σ = I). This yields the same subspace as LDA's .transform when shrinkage α=1 (nearest-centroid case).

Parameters

sample_dim : str
    The name of the dimension that corresponds to samples.
target_coord : str
    The name of the coordinate containing the target variable for supervised fitting.
n_components : int or None
    Number of components to keep. If None, uses C-1.
use_priors : bool
    If True, weight the overall mean by class priors (n_k / n). If False, unweighted.
output_dim_name : str
    The name for the new dimension created by the transformer.
sel : dict, optional
    Selection dictionary passed to parent.
drop_sel : dict, optional
    Drop selection dictionary passed to parent.

Source code in xdflow/transforms/nearestcentroid.py
def __init__(
    self,
    sample_dim: str,
    target_coord: str,
    n_components=None,
    use_priors=True,
    output_dim_name: str = "component",
    sel: dict | None = None,
    drop_sel: dict | None = None,
    **kwargs,
):
    super().__init__(
        estimator_cls=NearestCentroidTransform,
        sample_dim=sample_dim,
        target_coord=target_coord,
        output_dim_name=output_dim_name,
        sel=sel,
        drop_sel=drop_sel,
        n_components=n_components,
        use_priors=use_priors,
        **kwargs,
    )

CholeskyLDA

CholeskyLDA(shrinkage: float | None = None, covariance_estimator: Any | None = None, cov_estimator_on_within: bool = True, cov_estimator_per_class: bool = False, priors: Any | None = None, dtype: str = 'float32', store_covariance: bool = False)

Bases: BaseEstimator, ClassifierMixin

Fast LDA that uses a Cholesky factorization of the (shrunk) pooled within-class covariance in feature space (the 'primal' form), plus a small C×C eigendecomposition for .transform.

Behavior for covariance shrinkage
  • If covariance_estimator is provided (e.g., sklearn.covariance.OAS()):
    • If cov_estimator_on_within=True (default): clone and fit it on within-class residuals Xc.
    • If cov_estimator_on_within=False: clone and fit it on raw X. Its covariance_ is used directly as Σ.
  • Else if shrinkage is a float in [0,1], use Σ = (1-α) S + α μ I with α = shrinkage, where S is the pooled within-class covariance and μ = tr(S)/p.
  • Else if shrinkage is None, Σ = S (no shrinkage).
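The float and None shrinkage paths, together with the coef_ and intercept_ formulas listed under Attributes, can be sketched in NumPy. The following are assumptions: S is normalized by n (the source does not say whether n or n − C is used), the covariance_estimator and per-class paths are omitted, and shrunk_lda is a hypothetical helper rather than the CholeskyLDA code.

```python
import numpy as np


def shrunk_lda(X, y, shrinkage=None):
    """Hypothetical sketch of the float/None shrinkage path described above.

    Returns (classes, coef, intercept, Sigma) with coef rows = Sigma^{-1} mu_k.
    """
    classes, counts = np.unique(y, return_counts=True)
    priors = counts / counts.sum()
    means = np.stack([X[y == c].mean(axis=0) for c in classes])       # (C, p)
    # Within-class residuals Xc: subtract each sample's own class mean
    Xc = X - means[np.searchsorted(classes, y)]
    n, p = X.shape
    S = Xc.T @ Xc / n  # pooled within-class covariance (1/n normalization is an assumption)
    if shrinkage is None:
        Sigma = S
    else:
        mu = np.trace(S) / p
        Sigma = (1.0 - shrinkage) * S + shrinkage * mu * np.eye(p)
    L = np.linalg.cholesky(Sigma)  # Sigma = L L^T
    # Solve Sigma W = means^T through L and L^T without forming Sigma^{-1}
    # (np.linalg.solve ignores triangularity; scipy.linalg.solve_triangular would exploit it)
    W = np.linalg.solve(L.T, np.linalg.solve(L, means.T))             # (p, C)
    coef = W.T
    # intercept_k = -0.5 mu_k^T Sigma^{-1} mu_k + log pi_k
    intercept = -0.5 * np.sum(means * coef, axis=1) + np.log(priors)
    return classes, coef, intercept, Sigma


rng = np.random.default_rng(5)
y = np.repeat([0, 1], 40)
X = np.array([[0.0, 0.0], [3.0, 1.0]])[y] + rng.normal(size=(80, 2))
classes, coef, intercept, Sigma = shrunk_lda(X, y, shrinkage=0.2)
# Linear decision scores; the predicted class has the largest score
pred = classes[np.argmax(X @ coef.T + intercept, axis=1)]
```

At shrinkage=1.0, Σ reduces to μ I. This is the spherical case that the nearest-centroid transform above corresponds to.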

Parameters

shrinkage : float in [0, 1] or None, default=None

If float, use Σ = (1-α) S + α μ I with α = shrinkage and μ = tr(S)/p. If None, use Σ = S (no shrinkage). Ignored if covariance_estimator is set.

covariance_estimator : estimator or None, default=None

An sklearn-style covariance estimator (e.g., sklearn.covariance.OAS()). It is cloned and fit on either within-class residuals or raw X, depending on cov_estimator_on_within.

cov_estimator_on_within : bool, default=True

If True, fit the covariance estimator on within-class residuals Xc and set assume_centered=True when supported. If False, fit on raw X and do not modify assume_centered.

cov_estimator_per_class : bool, default=False

If True, fit the covariance estimator separately for each class and mix with priors (sklearn solver="lsqr" behavior). If False, use the original single-fit approach.

priors : array-like of shape (n_classes,), default=None

Class prior probabilities. If None, inferred from data.

dtype : {"float32", "float64"}, default="float32"

Internal compute dtype for heavy operations. Outputs are float64.

store_covariance : bool, default=False

If True, stores diag(Σ) in the private _covariance_diag_.
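A sketch of the covariance_estimator path with cov_estimator_on_within=True, assuming the estimator follows the sklearn convention of exposing assume_centered through get_params(). `fit_within_class_estimator` is a hypothetical helper, not an xdflow function, and the per-class mixing used when cov_estimator_per_class=True is not shown.

```python
import numpy as np
from sklearn.base import clone
from sklearn.covariance import OAS


def fit_within_class_estimator(estimator, X, y, on_within=True):
    """Hypothetical helper: fit a cloned covariance estimator as described above."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    est = clone(estimator)  # never mutate the caller's instance
    if not on_within:
        # Raw-X path: fit as-is and leave assume_centered untouched.
        return est.fit(X)
    # Within-class residuals Xc: each class re-centered on its own mean.
    Xc = X.copy()
    for label in np.unique(y):
        mask = y == label
        Xc[mask] -= X[mask].mean(axis=0)
    # Residuals are already centered, so tell the estimator when it supports it.
    if "assume_centered" in est.get_params():
        est.set_params(assume_centered=True)
    return est.fit(Xc)
```

Fitting on residuals keeps the between-class spread of the class means out of Σ, which is what the discriminant needs.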

Attributes (after fit)

classes_ : (C,)
priors_ : (C,)
means_ : (C, p)
coef_ : (C, p), rows are Σ^{-1} μ_k
intercept_ : (C,), -0.5 μ_k^T Σ^{-1} μ_k + log π_k
scalings_ : (p, r), projection used by transform (r = C-1)
explained_variance_ratio_ : (r,)
xbar_ : (p,), prior-weighted global mean used for transform centering

Diagnostics

shrinkage_ : float or None
    Shrinkage α actually used, when available (from the estimator or the float path).
mu_ : float or None
    Average variance μ = tr(S)/p (if computed on the float/None paths).
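A sketch of how coef_ and intercept_ follow from the formulas listed above, using a single Cholesky factorization of Σ through SciPy instead of an explicit inverse. The helper names are hypothetical; CholeskyLDA's own code may organize this differently (for example, computing in dtype and casting outputs to float64).

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve


def lda_linear_terms(means, sigma, priors):
    """Hypothetical helper: coef_ and intercept_ from class means, Σ and priors.

    means: (C, p) class means μ_k; sigma: (p, p) SPD covariance; priors: (C,) π_k.
    """
    means = np.asarray(means, dtype=float)
    priors = np.asarray(priors, dtype=float)
    # One Cholesky factorization Σ = L Lᵀ, reused for every class.
    factor = cho_factor(sigma, lower=True)
    # Solve Σ W = Mᵀ for all classes at once; rows of coef are Σ^{-1} μ_k.
    coef = cho_solve(factor, means.T).T
    # intercept_k = -0.5 μ_kᵀ Σ^{-1} μ_k + log π_k
    intercept = -0.5 * np.einsum("kp,kp->k", means, coef) + np.log(priors)
    return coef, intercept


def decision_function(X, coef, intercept):
    # Linear discriminant scores; argmax over classes gives the prediction.
    return np.asarray(X, dtype=float) @ coef.T + intercept
```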

Source code in xdflow/transforms/lda.py
def __init__(
    self,
    shrinkage: float | None = None,
    covariance_estimator: Any | None = None,
    cov_estimator_on_within: bool = True,
    cov_estimator_per_class: bool = False,
    priors: Any | None = None,
    dtype: str = "float32",
    store_covariance: bool = False,
):
    self.shrinkage = shrinkage
    self.covariance_estimator = covariance_estimator
    self.cov_estimator_on_within = cov_estimator_on_within
    self.cov_estimator_per_class = cov_estimator_per_class
    self.priors = priors
    self.dtype = dtype
    self.store_covariance = store_covariance

    # private diagnostics
    self._alpha = None
    self._mu_scalar = None
    self._N_eff = None
    self._covariance_diag_ = None

CholeskyLDATransformer

CholeskyLDATransformer(sample_dim: str, target_coord: str = 'stimulus', output_dim_name: str = 'component', shrinkage: float | None = None, covariance_estimator: object | None = None, cov_estimator_on_within: bool = True, cov_estimator_per_class: bool = False, priors: Iterable[tuple[Any, Any]] | None = None, dtype: str = 'float32', store_covariance: bool = False, sel=None, drop_sel=None)

Bases: SKLearnTransformer

SKLearnTransformer wrapper for CholeskyLDA.

Parameters

sample_dim : str
    The name of the dimension that corresponds to samples.
target_coord : str, default="stimulus"
    Coordinate containing the target variable for supervised fitting.
output_dim_name : str, default="component"
    Name for the new dimension created by the transformer.
shrinkage : float in [0, 1] or None, default=None
    Shrinkage to use for LDA. Ignored if covariance_estimator is provided.
covariance_estimator : object, optional
    Any sklearn-compatible covariance estimator instance (e.g., sklearn.covariance.OAS()).
cov_estimator_on_within : bool, default=True
    Whether to fit the covariance estimator on within-class residuals (True) or raw X (False).
cov_estimator_per_class : bool, default=False
    If True, fit the covariance estimator separately for each class and mix with priors (sklearn solver="lsqr" behavior). If False, use the original single-fit approach.
priors : array-like of shape (n_classes,), default=None
    Class prior probabilities. If None, inferred from data.
dtype : {"float32", "float64"}, default="float32"
    Internal compute dtype for heavy operations. Outputs are float64.
store_covariance : bool, default=False
    If True, stores diag(Σ) in the private _covariance_diag_.
sel : dict, optional
    Selection criteria for input data.
drop_sel : dict, optional
    Drop selection criteria for input data.

Source code in xdflow/transforms/lda.py
def __init__(
    self,
    sample_dim: str,
    target_coord: str = "stimulus",
    output_dim_name: str = "component",
    shrinkage: float | None = None,
    covariance_estimator: object | None = None,
    cov_estimator_on_within: bool = True,
    cov_estimator_per_class: bool = False,
    priors: Iterable[tuple[Any, Any]] | None = None,
    dtype: str = "float32",
    store_covariance: bool = False,
    sel=None,
    drop_sel=None,
):
    super().__init__(
        estimator_cls=CholeskyLDA,
        sample_dim=sample_dim,
        target_coord=target_coord,
        output_dim_name=output_dim_name,
        shrinkage=shrinkage,
        covariance_estimator=covariance_estimator,
        cov_estimator_on_within=cov_estimator_on_within,
        cov_estimator_per_class=cov_estimator_per_class,
        priors=priors,
        dtype=dtype,
        store_covariance=store_covariance,
        sel=sel,
        drop_sel=drop_sel,
    )

LGBMPredictor

LGBMPredictor(estimator_cls: type[BaseEstimator], sample_dim: str, target_coord: str | list[str], early_stopping_rounds: int | None = None, validation_size: float = 0.2, validation_seed: int | None = None, eval_metric: str | None = None, verbose_eval: int | bool = False, encoder: LabelEncoder | None = None, proba: bool = False, is_classifier: bool | None = None, multi_output: bool = False, is_multilabel: bool = False, sel: dict | None = None, drop_sel: dict | None = None, sample_weight_coord: str | None = 'sample_weight', **kwargs: Any)

Bases: SKLearnPredictor

LightGBM predictor with built-in early stopping support.

Extends SKLearnPredictor to accept early stopping parameters in the constructor and to create a validation split automatically during fitting. Works with both LGBMClassifier and LGBMRegressor.

Parameters

estimator_cls : type[BaseEstimator]
    LightGBM estimator class (LGBMClassifier or LGBMRegressor).
sample_dim : str
    Name of the sample dimension.
target_coord : str | list[str]
    Target coordinate name (or list/pattern for multi-target).
early_stopping_rounds : int | None, default=None
    Number of rounds with no improvement before stopping.
      • None: disabled (no validation split is created).
      • Positive int: enabled (a validation split is created automatically).
      • 0 raises ValueError; use None to disable.
validation_size : float, default=0.2
    Fraction of the training data held out for validation; must be strictly between 0.0 and 1.0. Only used when early_stopping_rounds is set.
validation_seed : int | None, default=None
    Random seed for reproducible validation splits.
eval_metric : str | None, default=None
    Metric for early stopping. If None, LightGBM selects one based on the objective:
      • 'binary' → 'binary_logloss'
      • 'multiclass' → 'multi_logloss'
      • 'regression' → 'l2' (mean squared error)
    Common overrides: 'auc', 'rmse', 'mae'.
verbose_eval : int | bool, default=False
    Logging frequency (int > 0), or False/0 to disable.
encoder, proba, is_classifier, multi_output, is_multilabel, sel, drop_sel, sample_weight_coord
    Standard SKLearnPredictor parameters, passed through to the parent class.
**kwargs
    LightGBM hyperparameters (n_estimators, learning_rate, max_depth, etc.), forwarded to SKLearnPredictor.

Examples

>>> from lightgbm import LGBMClassifier
>>> predictor = LGBMPredictor(
...     LGBMClassifier,
...     sample_dim='trial',
...     target_coord='stimulus',
...     early_stopping_rounds=50,
...     n_estimators=1000,
... )
>>> predictor.fit(train_data)
>>> print(f"Stopped at iteration: {predictor.best_iteration_}")

Notes

  • Validation split uses stratification for classifiers when possible
  • best_iteration_ attribute set after fitting with early stopping
  • Sample weights automatically split along with data if provided
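The validation-split behavior in these notes can be sketched with scikit-learn's train_test_split: stratify on the target for classifiers when possible, and pass sample weights through the same call so they stay row-aligned. `make_validation_split` is a hypothetical helper, and LGBMPredictor's actual fallback rules may differ. The held-out part would then typically be passed to the LightGBM estimator's fit(eval_set=...).

```python
import numpy as np
from sklearn.model_selection import train_test_split


def make_validation_split(X, y, sample_weight=None, validation_size=0.2,
                          seed=None, is_classifier=True):
    """Hypothetical helper: hold out a validation set for early stopping."""
    arrays = [X, y] if sample_weight is None else [X, y, sample_weight]
    stratify = y if is_classifier else None
    try:
        parts = train_test_split(*arrays, test_size=validation_size,
                                 random_state=seed, stratify=stratify)
    except ValueError:
        # Stratification fails when a class has too few members; use a plain split.
        parts = train_test_split(*arrays, test_size=validation_size, random_state=seed)
    # parts alternates (train, valid) for each input array.
    train, valid = parts[0::2], parts[1::2]
    if sample_weight is None:
        train.append(None)
        valid.append(None)
    return tuple(train), tuple(valid)
```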

Initialize LGBMPredictor with early stopping parameters.

Source code in xdflow/transforms/lgbm_predictor.py
def __init__(
    self,
    estimator_cls: type[BaseEstimator],
    sample_dim: str,
    target_coord: str | list[str],
    early_stopping_rounds: int | None = None,
    validation_size: float = 0.2,
    validation_seed: int | None = None,
    eval_metric: str | None = None,
    verbose_eval: int | bool = False,
    encoder: LabelEncoder | None = None,
    proba: bool = False,
    is_classifier: bool | None = None,
    multi_output: bool = False,
    is_multilabel: bool = False,
    sel: dict | None = None,
    drop_sel: dict | None = None,
    sample_weight_coord: str | None = "sample_weight",
    **kwargs: Any,
):
    """Initialize LGBMPredictor with early stopping parameters."""
    if lgb is None:
        raise ImportError(
            "LightGBM is required for LGBMPredictor. "
            "Install with: pip install xdflow[lightgbm] or pip install lightgbm"
        )
    # Validate early stopping parameters
    if early_stopping_rounds is not None:
        if not isinstance(early_stopping_rounds, int) or early_stopping_rounds <= 0:
            raise ValueError(
                f"early_stopping_rounds must be a positive integer or None, got {early_stopping_rounds}"
            )
        if not 0.0 < validation_size < 1.0:
            raise ValueError(f"validation_size must be between 0.0 and 1.0, got {validation_size}")

    # Store early stopping parameters as public attributes (required for cloning)
    self.early_stopping_rounds = early_stopping_rounds
    self.validation_size = validation_size
    self.validation_seed = validation_seed
    self.eval_metric = eval_metric
    self.verbose_eval = verbose_eval

    # Initialize parent class
    super().__init__(
        estimator_cls=estimator_cls,
        sample_dim=sample_dim,
        target_coord=target_coord,
        encoder=encoder,
        proba=proba,
        is_classifier=is_classifier,
        multi_output=multi_output,
        is_multilabel=is_multilabel,
        sel=sel,
        drop_sel=drop_sel,
        sample_weight_coord=sample_weight_coord,
        **kwargs,
    )

get_params

get_params(deep: bool = True) -> dict[str, Any]

Get parameters including early stopping parameters.

Returns parameters from the wrapper, early stopping config, and wrapped estimator.

Source code in xdflow/transforms/lgbm_predictor.py
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """
    Get parameters including early stopping parameters.

    Returns parameters from the wrapper, early stopping config, and wrapped estimator.
    """
    params = super().get_params(deep=deep)
    return params

clone

clone()

Return a fresh instance with the same constructor parameters.

Ensures early stopping parameters are preserved in the cloned instance.

Source code in xdflow/transforms/lgbm_predictor.py
def clone(self):
    """
    Return a fresh instance with the same constructor parameters.

    Ensures early stopping parameters are preserved in the cloned instance.
    """
    return super().clone()
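The constructor stores early_stopping_rounds and related settings as public attributes because sklearn-style get_params/clone usually rebuild an object by reading each __init__ argument back from an attribute of the same name. The classes below are a hypothetical illustration of that convention, not xdflow's SKLearnPredictor implementation.

```python
import inspect


class ParamsMixin:
    """Hypothetical sklearn-style get_params/clone based on __init__ introspection."""

    def get_params(self):
        sig = inspect.signature(type(self).__init__)
        names = [
            name for name, p in sig.parameters.items()
            if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        ]
        # Works only if every constructor argument is stored under the same public name.
        return {name: getattr(self, name) for name in names}

    def clone(self):
        # A fresh, unfitted instance built from the same constructor arguments.
        return type(self)(**self.get_params())


class EarlyStoppingConfig(ParamsMixin):
    def __init__(self, early_stopping_rounds=None, validation_size=0.2, validation_seed=None):
        self.early_stopping_rounds = early_stopping_rounds
        self.validation_size = validation_size
        self.validation_seed = validation_seed
```

If a constructor argument were stored only under a private name, getattr would fail (or silently pick up a stale value), and the clone would lose that setting.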

Time-Series And Spatial Transforms

HilbertPhaseTransform

HilbertPhaseTransform(fs: int, mode: str = 'timepoints', timepoints_step_ms: int = 100, timepoints_start_ms: int | None = None, timepoints_end_ms: int | None = None, use_time_coord: bool = True, num_lf_bands_remove: int = 0, num_hf_bands_remove: int = 1, lfp_pad_ms_at_ends: int = 0, n_jobs: int | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict | None = None, transform_drop_sel: dict | None = None)

Bases: Transform

Compute instantaneous phase via Hilbert transform per frequency band and extract features.

Two modes are supported:
  • 'timepoints': Extract the phase at selected timepoints for each channel and band. By default, uses the 'time' coordinate from the input data (assumed to be in milliseconds); regularly spaced synthetic timepoints can be used instead.
  • 'relative_average': Compute each channel's phase relative to the average phase across channels, then average over a specified time window (in ms).

Input dims must include ('trial', 'channel', 'time'). The transform removes the 'time' dimension and adds a 'freq_band' dimension; in 'timepoints' mode it also adds a 'timepoint' dimension. When using 'timepoints' mode with use_time_coord=True (default), the input data must have a 'time' coordinate expressed in milliseconds.
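A SciPy sketch of per-band phase extraction and the 'relative_average' idea. It assumes a zero-phase Butterworth band-pass followed by scipy.signal.hilbert, and circular (complex-exponential) averaging across channels and over time. The band edges, filter order and averaging actually used by HilbertPhaseTransform (including its DEFAULT_FREQ_RANGES) are not documented here, so these choices are illustrative only.

```python
import numpy as np
from scipy.signal import butter, hilbert, sosfiltfilt


def band_phase(x, fs, band, order=4):
    """Instantaneous phase of x (..., time) within one frequency band (assumed filter)."""
    sos = butter(order, band, btype="bandpass", fs=fs, output="sos")
    filtered = sosfiltfilt(sos, x, axis=-1)  # zero-phase band-pass
    return np.angle(hilbert(filtered, axis=-1))  # radians in [-pi, pi]


def relative_average_phase(phase, t_slice):
    """Hypothetical 'relative_average' sketch for phase of shape (channel, time)."""
    # Circular mean across channels at each time sample.
    ref = np.angle(np.exp(1j * phase).mean(axis=0, keepdims=True))
    # Each channel's phase relative to that reference, as unit complex numbers.
    rel = np.exp(1j * (phase - ref))
    # Circular average over the requested time window.
    return np.angle(rel[:, t_slice].mean(axis=-1))
```

Averaging unit complex numbers rather than raw angles avoids the wrap-around problem at ±π.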

Parameters:

Name Type Description Default
fs int

Sampling frequency in Hz.

required
mode str

'timepoints' or 'relative_average'.

'timepoints'
timepoints_step_ms int

Step between timepoints to extract (ms), used when mode='timepoints' and use_time_coord=False.

100
timepoints_start_ms int | None

Optional start bound (ms) at which to begin extracting/averaging within the effective region. If use_time_coord=True (or a 'time' coordinate exists), it is interpreted in the same absolute units as the 'time' coordinate (assumed ms), after padding is excluded. If there is no 'time' coordinate, it is interpreted relative to the start of the trimmed region (0 ms after padding).

None
timepoints_end_ms int | None

Optional end bound (ms) to stop extracting/averaging within the effective region. Same interpretation rules as timepoints_start_ms.

None
use_time_coord bool

If True (default), use the 'time' coordinate from input data for timepoint selection. If False, generate timepoints synthetically based on fs and timepoints_step_ms.

True
num_lf_bands_remove int

Remove this many low-frequency bands from default ranges.

0
num_hf_bands_remove int

Remove this many high-frequency bands from default ranges.

1
lfp_pad_ms_at_ends int

Assumed pre-existing padding present at both start and end of each trial segment. We do NOT add any new samples; instead we ignore the first/last lfp_pad_ms_at_ends when extracting timepoints or computing relative-average features. If total duration is T ms and this value is P ms, the effective region is [P, T - P). When use_time_coord=True, padding is applied to the time coordinate values.

0
n_jobs int | None

Number of parallel jobs for trial processing. If None, runs sequentially. Ignored when the input is dask-backed, in which case Dask controls parallelism.

None
sel dict[str, Any] | None

Optional selection to apply before transforming.

None
drop_sel dict[str, Any] | None

Optional drop selection to apply before transforming.

None
transform_sel dict | None

Optional transform selection to apply before transforming.

None
transform_drop_sel dict | None

Optional transform drop selection to apply before transforming.

None
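A sketch of how the effective region [P, T - P) and synthetic timepoints fit together when use_time_coord=False, assuming timepoints start at the beginning of the effective region, start/end bounds are relative to the trimmed region start, and the end bound is exclusive. `synthetic_timepoints` is hypothetical; the transform's exact rounding and inclusivity rules are not stated in these docs.

```python
import numpy as np


def synthetic_timepoints(n_samples, fs, step_ms=100, start_ms=None, end_ms=None, pad_ms=0):
    """Hypothetical sketch of timepoint selection when use_time_coord=False.

    Returns (times_ms, sample_idx), with times measured from the untrimmed trial start.
    """
    total_ms = n_samples * 1000.0 / fs
    # Effective region [P, T - P): existing padding is ignored, not added.
    lo, hi = float(pad_ms), total_ms - pad_ms
    if start_ms is not None:
        lo = max(lo, pad_ms + start_ms)
    if end_ms is not None:
        hi = min(hi, pad_ms + end_ms)
    times_ms = np.arange(lo, hi, step_ms)
    # Convert milliseconds to sample indices at sampling rate fs (Hz).
    sample_idx = np.round(times_ms * fs / 1000.0).astype(int)
    return times_ms, sample_idx
```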
Source code in xdflow/transforms/phase.py
def __init__(
    self,
    fs: int,
    mode: str = "timepoints",
    timepoints_step_ms: int = 100,
    timepoints_start_ms: int | None = None,
    timepoints_end_ms: int | None = None,
    use_time_coord: bool = True,
    num_lf_bands_remove: int = 0,
    num_hf_bands_remove: int = 1,
    lfp_pad_ms_at_ends: int = 0,
    n_jobs: int | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict | None = None,
    transform_drop_sel: dict | None = None,
):
    """
    Args:
        fs: Sampling frequency in Hz.
        mode: 'timepoints' or 'relative_average'.
        timepoints_step_ms: Step between timepoints to extract (ms), used when mode='timepoints' and use_time_coord=False.
        timepoints_start_ms: Optional start bound (ms) to begin extracting/averaging within the effective region.
            - If use_time_coord=True (or a 'time' coordinate exists), interpreted in the same absolute units as the
              'time' coordinate (assumed ms) after padding is excluded.
            - If no 'time' coordinate, interpreted relative to the trimmed region start (0 ms after padding).
        timepoints_end_ms: Optional end bound (ms) to stop extracting/averaging within the effective region.
            Same interpretation rules as timepoints_start_ms.
        use_time_coord: If True (default), use the 'time' coordinate from input data for timepoint selection.
            If False, generate timepoints synthetically based on fs and timepoints_step_ms.
        num_lf_bands_remove: Remove this many low-frequency bands from default ranges.
        num_hf_bands_remove: Remove this many high-frequency bands from default ranges.
        lfp_pad_ms_at_ends: Assumed pre-existing padding present at both start and end of each trial segment.
            We do NOT add any new samples; instead we ignore the first/last lfp_pad_ms_at_ends when extracting
            timepoints or computing relative-average features. If total duration is T ms and this value is P ms,
            the effective region is [P, T - P). When use_time_coord=True, padding is applied to the time coordinate values.
        n_jobs: Number of parallel jobs for trial processing. If None, runs sequentially.
            Ignored when the input is dask-backed, in which case Dask controls parallelism.
        sel: Optional selection to apply before transforming.
        drop_sel: Optional drop selection to apply before transforming.
        transform_sel: Optional transform selection to apply before transforming.
        transform_drop_sel: Optional transform drop selection to apply before transforming.
    """
    super().__init__(sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel)
    if mode not in ("timepoints", "relative_average"):
        raise ValueError("mode must be either 'timepoints' or 'relative_average'")

    self.fs = fs
    self.mode = mode
    self.timepoints_step_ms = int(timepoints_step_ms)
    self.timepoints_start_ms = None if timepoints_start_ms is None else int(timepoints_start_ms)
    self.timepoints_end_ms = None if timepoints_end_ms is None else int(timepoints_end_ms)
    self.use_time_coord = use_time_coord
    self.num_lf_bands_remove = int(num_lf_bands_remove)
    self.num_hf_bands_remove = int(num_hf_bands_remove)
    self.lfp_pad_ms_at_ends = int(lfp_pad_ms_at_ends)
    self.n_jobs = n_jobs

    # Frequency bands to compute
    freq_ranges = get_remove_freq_ranges(num_hf_bands_remove, self.DEFAULT_FREQ_RANGES.copy())
    if num_lf_bands_remove > 0:
        freq_ranges = get_remove_freq_ranges(num_lf_bands_remove, freq_ranges, remove_high=False)
    if not freq_ranges:
        raise ValueError("No frequency ranges available after filtering configuration")
    self.freq_ranges: dict[str, tuple[float, float]] = freq_ranges

GlobalFeaturePCA

GlobalFeaturePCA(n_components: float | int | None = None, pca_frac_to_keep: float | None = None, center_data: bool | str = True, whiten: bool = False, time_dim: str = 'time', feature_dim: str = 'channel', sel: dict | None = None, drop_sel: dict | None = None)

Bases: Transform

Performs global PCA across features while preserving time structure.

The transform treats every (trial, time_dim) pair as one sample and fits PCA on the feature covariance matrix across those samples. It then projects the data onto the top principal components and reshapes back to (trial, feature_dim, time_dim), where the feature_dim axis now indexes components rather than original features. The output dimension therefore keeps the name given by feature_dim, for downstream compatibility.
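A NumPy sketch of the reshaping described above, using an SVD of the centered sample matrix instead of an explicit covariance eigendecomposition (both give the same principal axes). It assumes an integer n_components and omits whitening, the fractional options and the 'false_oldstyle' centering mode; `global_feature_pca` is hypothetical.

```python
import numpy as np


def global_feature_pca(data, n_components, center=True):
    """Hypothetical sketch: data (trial, feature, time) -> (trial, component, time)."""
    n_trial, n_feat, n_time = data.shape
    # Every (trial, time) pair becomes one sample row of length n_feat.
    samples = data.transpose(0, 2, 1).reshape(-1, n_feat)
    mean = samples.mean(axis=0) if center else np.zeros(n_feat)
    # Principal axes from the SVD of the (centered) sample matrix.
    _, _, vt = np.linalg.svd(samples - mean, full_matrices=False)
    components = vt[:n_components]  # (k, n_feat), orthonormal rows
    projected = (samples - mean) @ components.T  # (trial*time, k)
    # Restore (trial, component, time); the component axis reuses the feature slot.
    out = projected.reshape(n_trial, n_time, n_components).transpose(0, 2, 1)
    return out, components, mean
```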

Initialize the transform.

Parameters:

Name Type Description Default
n_components float | int | None

Optional fraction (0, 1] of variance to keep or integer >= 1 specifying exact number of components (sklearn standard behavior). Cannot be specified together with pca_frac_to_keep.

None
pca_frac_to_keep float | None

Optional fraction of principal components to keep. Cannot be specified together with n_components.

None
center_data bool | str

Whether to center features across (trial, time) before PCA. Accepts True, False, or 'false_oldstyle'.

True
whiten bool

Whether to whiten the data before fitting/projection.

False
time_dim str

The dimension name for the time dimension.

'time'
feature_dim str

The dimension name for the feature dimension.

'channel'
sel dict | None

Optional selection to apply before transforming.

None
drop_sel dict | None

Optional drop selection to apply before transforming.

None
Source code in xdflow/transforms/pca.py
def __init__(
    self,
    n_components: float | int | None = None,
    pca_frac_to_keep: float | None = None,
    center_data: bool | str = True,
    whiten: bool = False,
    time_dim: str = "time",
    feature_dim: str = "channel",
    sel: dict | None = None,
    drop_sel: dict | None = None,
) -> None:
    """Initialize the transform.

    Args:
        n_components: Optional fraction (0, 1] of variance to keep or integer >= 1
            specifying exact number of components (sklearn standard behavior).
            Cannot be specified together with pca_frac_to_keep.
        pca_frac_to_keep: Optional fraction of principal components to keep.
            Cannot be specified together with n_components.
        center_data: Whether to center features across (trial, time) before PCA.
            Accepts True, False, or 'false_oldstyle'.
        whiten: Whether to whiten the data before fitting/projection.
        time_dim: The dimension name for the time dimension.
        feature_dim: The dimension name for the feature dimension.
        sel: Optional selection to apply before transforming.
        drop_sel: Optional drop selection to apply before transforming.
    """
    # Use drop_sel to exclude reference channels globally; selective transform is not supported
    super().__init__(
        sel=sel,
        drop_sel=drop_sel,
    )

    # Validate n_components
    if n_components is not None:
        if isinstance(n_components, float):
            if not (0.0 < n_components <= 1.0):
                raise ValueError("n_components float must be in (0, 1].")
        elif isinstance(n_components, int):
            if n_components < 1:
                raise ValueError("n_components int must be >= 1.")
        else:
            raise TypeError("n_components must be a float in (0,1] or an int >= 1.")

    # Validate pca_frac_to_keep
    if pca_frac_to_keep is not None:
        if not isinstance(pca_frac_to_keep, (int, float)) or not (0.0 < pca_frac_to_keep <= 1.0):
            raise ValueError("pca_frac_to_keep must be a number greater than 0.0 and less than or equal to 1.0.")

    # Mutual exclusion validation
    if n_components is not None and pca_frac_to_keep is not None:
        raise ValueError("Cannot specify both n_components and pca_frac_to_keep. Use one or the other.")

    # Ensure at least one is specified
    if n_components is None and pca_frac_to_keep is None:
        raise ValueError("Must specify either n_components or pca_frac_to_keep.")

    self.n_components = n_components
    self.pca_frac_to_keep = float(pca_frac_to_keep) if pca_frac_to_keep is not None else None
    # Normalize centering mode
    if isinstance(center_data, bool):
        self._center_mode = "true" if center_data else "false"
    elif isinstance(center_data, str) and center_data.lower() == "false_oldstyle":
        self._center_mode = "false_oldstyle"
    else:
        raise ValueError("center_data must be True, False, or 'false_oldstyle'.")
    self.center_data = center_data  # keep original for introspection
    self.whiten = bool(whiten)

    # Learned parameters
    self.components_: np.ndarray | None = None  # shape: (k, n_features)
    self.feature_mean = None  # xarray DataArray of shape (feature,)

    # set up dims
    self.input_dims = ("trial", feature_dim, time_dim)
    self.output_dims = ("trial", feature_dim, time_dim)
    self.time_dim = time_dim
    self.feature_dim = feature_dim

LaplacianCSDTransform

LaplacianCSDTransform(grid_layout: ndarray, radius: float = 1.0, weighted: bool = False, scaling: float = 1.0, handle_nans: bool = True, verbosity: int = 0, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict[str, Any] | None = None, transform_drop_sel: dict[str, Any] | None = None)

Bases: Transform

Computes the Current Source Density (CSD) using a Laplacian filter.

This transform applies spatial filtering to neural electrode data arranged on a 2D grid. The CSD is computed as the negative spatial Laplacian, which estimates local current sources and sinks in neural tissue.

Channel handling:
  • Only numeric channel names (e.g., '0', '1', '127') that correspond to electrode positions in the spatial grid are processed.
  • Non-numeric channels (e.g., reference or auxiliary sensors) are automatically preserved unchanged in the output.
  • Channel order in the output matches the input exactly.

Spatial processing:
  • Maps numeric channels to a 2D electrode grid using the positions given by grid_layout.
  • Applies a Laplacian kernel convolution for spatial filtering.
  • Maps the results back to the original channel structure.
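A sketch of the channel handling and negative-Laplacian filtering described above, using scipy.ndimage.convolve. The kernel construction (uniform neighbor weights within radius, optional inverse-distance weighting, a center weight that balances the neighbors), the mode="nearest" edge handling, and the assumption that every grid position has a numeric channel are illustrative choices, not necessarily what LaplacianCSDTransform does; NaN filling is omitted. The helper names are hypothetical.

```python
import numpy as np
from scipy.ndimage import convolve


def laplacian_kernel(radius=1.0, weighted=False, scaling=1.0):
    """Hypothetical kernel: neighbors within `radius` get positive weight, center balances them."""
    r = int(np.floor(radius))
    dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
    dist = np.hypot(dy, dx)
    neighbors = (dist > 0) & (dist <= radius)
    kernel = np.zeros(dist.shape)
    kernel[neighbors] = scaling / dist[neighbors] if weighted else scaling
    kernel[r, r] = -kernel.sum()  # weights sum to zero
    return kernel


def laplacian_csd(values, channels, grid_layout, **kernel_kwargs):
    """values: (channel, ...) array; channels: names such as '0', '1', 'ref'."""
    values = np.asarray(values, dtype=float)
    out = values.copy()  # non-numeric channels pass through unchanged
    kernel = laplacian_kernel(**kernel_kwargs)
    pos = {int(ch): np.argwhere(grid_layout == int(ch))[0]
           for ch in channels if ch.isdigit()}
    # Scatter numeric channels onto the 2D grid (trailing axes such as time are kept).
    grid = np.zeros(grid_layout.shape + values.shape[1:])
    for i, ch in enumerate(channels):
        if ch.isdigit():
            grid[tuple(pos[int(ch)])] = values[i]
    # CSD is the negative Laplacian; filter only over the two grid axes.
    k = kernel.reshape(kernel.shape + (1,) * (values.ndim - 1))
    csd = -convolve(grid, k, mode="nearest")
    # Gather results back into the original channel order.
    for i, ch in enumerate(channels):
        if ch.isdigit():
            out[i] = csd[tuple(pos[int(ch)])]
    return out
```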

Parameters:

Name Type Description Default
grid_layout ndarray

2D numpy array defining the spatial electrode layout, where values correspond to channel IDs. Required for spatial mapping.

required
radius float

Spatial radius for the Laplacian kernel (in grid units).

1.0
weighted bool

If True, weights kernel by inverse distance from center.

False
scaling float

Scaling factor applied to neighbor weights.

1.0
handle_nans bool

If True, fills NaN values using spatial neighbors before processing.

True
verbosity int

Verbosity level for timing output (0=silent, 1=basic, 2=detailed).

0
Example

# Data with both LFP channels ('0'-'127') and a reference channel
data.coords['channel'] = ['0', '1', '2', ..., '127', 'ref']
grid = np.array([[0, 1, 2], [3, 4, 5]])  # example grid layout
transform = LaplacianCSDTransform(grid_layout=grid, radius=1.5)
result = transform.transform(data)
# result: LFP channels contain CSD values; the reference channel is unchanged

Source code in xdflow/transforms/spatial.py
def __init__(
    self,
    grid_layout: np.ndarray,
    radius: float = 1.0,
    weighted: bool = False,
    scaling: float = 1.0,
    handle_nans: bool = True,
    verbosity: int = 0,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict[str, Any] | None = None,
    transform_drop_sel: dict[str, Any] | None = None,
):
    super().__init__(sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel)
    self.grid_layout = grid_layout
    self.radius = radius
    self.weighted = weighted
    self.scaling = scaling
    self.handle_nans = handle_nans
    self.verbosity = verbosity

WindowMeanPyramidTransform

WindowMeanPyramidTransform(grid_layout: ndarray, levels: int = 1, keep_only_top_pyramid: bool = True, handle_nans: bool = True, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Creates multiple levels of spatial averaging using a 2x2 sliding window. Requires a "channel" dimension, but is otherwise dimension-agnostic.
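A NumPy sketch of 2x2 sliding-window averaging on the electrode grid, assuming a stride of 1 (so each level shrinks the grid by one row and one column) and that channels have already been mapped onto grid positions. `window_mean_pyramid` is hypothetical; the transform's stride, output channel naming and NaN handling are not specified here.

```python
import numpy as np


def window_mean_level(grid):
    """One level: mean over every 2x2 window with stride 1, (H, W, ...) -> (H-1, W-1, ...)."""
    return (grid[:-1, :-1] + grid[1:, :-1] + grid[:-1, 1:] + grid[1:, 1:]) / 4.0


def window_mean_pyramid(grid, levels=1, keep_only_top_pyramid=True):
    """Hypothetical sketch returning the top level, or every level in order."""
    out = []
    current = np.asarray(grid, dtype=float)
    for _ in range(levels):
        current = window_mean_level(current)
        out.append(current)
    return out[-1] if keep_only_top_pyramid else out
```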

TODO: transform_sel is not yet supported here. Supporting it, even though the output dimensions differ from the input, would allow reference channels to be kept.

Parameters:

Name Type Description Default
grid_layout ndarray

2D numpy array defining the spatial electrode layout, where values correspond to channel IDs. Required for spatial mapping.

required
levels int

Number of pyramid levels to create.

1
keep_only_top_pyramid bool

If True, only keeps the top level of the pyramid.

True
handle_nans bool

If True, fills NaN values using spatial neighbors before processing.

True
Source code in xdflow/transforms/spatial.py
def __init__(
    self,
    grid_layout: np.ndarray,
    levels: int = 1,
    keep_only_top_pyramid: bool = True,
    handle_nans: bool = True,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
):
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.grid_layout = grid_layout
    self.levels = levels
    self.keep_only_top_pyramid = keep_only_top_pyramid
    self.handle_nans = handle_nans

GaussianPyramidTransform

GaussianPyramidTransform(grid_layout: ndarray, levels: int = 1, sigma: float = 1.0, keep_only_top_pyramid: bool = True, handle_nans: bool = True, gaussian_kwargs: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None)

Bases: Transform

Creates multiple levels of a Gaussian pyramid for spatial feature extraction. Requires a "channel" dimension, but is otherwise dimension-agnostic. levels=1 applies a single level of Gaussian filtering (unlike the previous implementation).
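A SciPy sketch of the levels semantics (levels=1 means one Gaussian filtering pass), assuming each level blurs the previous one without downsampling and that gaussian_kwargs are forwarded to scipy.ndimage.gaussian_filter. `gaussian_pyramid` is hypothetical and works directly on a grid-shaped array rather than on channel-indexed data.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def gaussian_pyramid(grid, levels=1, sigma=1.0, keep_only_top_pyramid=True, **gaussian_kwargs):
    """Hypothetical sketch: `levels` successive Gaussian blurs over the two grid axes."""
    current = np.asarray(grid, dtype=float)
    # Smooth only the grid axes (rows, columns); a sigma of 0 leaves trailing axes untouched.
    sig = (sigma, sigma) + (0,) * (current.ndim - 2)
    out = []
    for _ in range(levels):
        current = gaussian_filter(current, sigma=sig, **gaussian_kwargs)
        out.append(current)
    return out[-1] if keep_only_top_pyramid else out
```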

TODO: transform_sel is not yet supported here. Supporting it, even though the output dimensions differ from the input, would allow reference channels to be kept.

Parameters:

Name Type Description Default
grid_layout ndarray

2D numpy array defining the spatial electrode layout, where values correspond to channel IDs. Required for spatial mapping.

required
levels int

Number of pyramid levels to create.

1
sigma float

Standard deviation for Gaussian kernel.

1.0
keep_only_top_pyramid bool

If True, only keeps the top level of the pyramid.

True
handle_nans bool

If True, fills NaN values using spatial neighbors before processing.

True
gaussian_kwargs dict[str, Any] | None

Additional keyword arguments for gaussian_filter.

None
Source code in xdflow/transforms/spatial.py
def __init__(
    self,
    grid_layout: np.ndarray,
    levels: int = 1,
    sigma: float = 1.0,
    keep_only_top_pyramid: bool = True,
    handle_nans: bool = True,
    gaussian_kwargs: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
):
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.grid_layout = grid_layout
    self.levels = levels
    self.sigma = sigma
    self.keep_only_top_pyramid = keep_only_top_pyramid
    self.handle_nans = handle_nans
    self.gaussian_kwargs = gaussian_kwargs if gaussian_kwargs is not None else {}

LocalZCAWhitening

LocalZCAWhitening(grid_layout: ndarray, radius: float = 1.0, epsilon: float = 1e-06, whitening_strength: float = 1.0, n_components: float | int | None = None, pca_frac_to_keep: float | None = None, center_data: bool | str = True, sel: dict[str, object] | None = None, drop_sel: dict[str, object] | None = None)

Bases: Transform

Performs local ZCA whitening on neural data using spatial electrode neighborhoods.

This is a stateful transform that uses the electrode array layout to determine spatial neighborhoods. The 'fit' method computes a whitening matrix based on the covariance of local spatial neighborhoods. The 'transform' method applies this pre-computed matrix to the data.
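A NumPy sketch of the fit step, assuming each channel's output row comes from a ZCA matrix fitted on that channel's neighborhood covariance, and that whitening_strength acts as an exponent on the eigenvalues (1.0 is full ZCA, 0.0 leaves the data unchanged). These are illustrative assumptions; PCA truncation via n_components/pca_frac_to_keep is omitted, and `fit_local_zca` and `zca_matrix` are hypothetical.

```python
import numpy as np


def zca_matrix(cov, epsilon=1e-6, strength=1.0):
    """W = U diag((λ + ε)^(-strength/2)) Uᵀ; strength=1 is full ZCA, 0 is identity."""
    eigvals, eigvecs = np.linalg.eigh(cov)
    scale = (eigvals + epsilon) ** (-0.5 * strength)
    return (eigvecs * scale) @ eigvecs.T


def fit_local_zca(X, positions, radius=1.0, epsilon=1e-6, strength=1.0):
    """Hypothetical sketch. X: (samples, channels); positions: (channels, 2) grid coords."""
    X = np.asarray(X, dtype=float)
    positions = np.asarray(positions, dtype=float)
    mean = X.mean(axis=0)
    Xc = X - mean
    n_ch = X.shape[1]
    W = np.zeros((n_ch, n_ch))
    for c in range(n_ch):
        dist = np.linalg.norm(positions - positions[c], axis=1)
        nbr = np.flatnonzero(dist <= radius)  # always includes c itself
        cov = np.atleast_2d(np.cov(Xc[:, nbr], rowvar=False))
        local_w = zca_matrix(cov, epsilon, strength)
        row = np.flatnonzero(nbr == c)[0]  # position of c inside its neighborhood
        # Keep only the row that produces the center channel's whitened output.
        W[c, nbr] = local_w[row]
    return W, mean  # apply as (X - mean) @ W.T
```

When the radius covers every channel, each row comes from the same global covariance and W reduces to ordinary (global) ZCA whitening.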

Attributes:

Name Type Description
grid_layout ndarray

2D numpy array defining the spatial electrode layout, where values correspond to channel IDs. Required for spatial mapping and neighborhood determination.

radius float

The radius in grid units for defining a channel's neighborhood. Must be non-negative.

epsilon float

Regularization parameter for numerical stability. Must be positive.

whitening_strength float

Degree of whitening applied, in [0.0, 1.0]: 0.0 = no whitening, 1.0 = full whitening.

n_components float | int | None

Fraction (0, 1] of variance to keep, or an integer >= 1. Exactly one of n_components and pca_frac_to_keep must be given; the constructor raises ValueError if both or neither are set.

pca_frac_to_keep float | None

Fraction (0, 1] of principal components to keep. Mutually exclusive with n_components; one of the two is required.

center_data bool | str

Whether to center data before whitening. Accepts True, False, or "false_oldstyle".
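The neighborhood-based fit described above can be sketched with NumPy alone. This is a hypothetical illustration, not the xdflow implementation: it assumes the columns of `x` are indexed by the channel IDs stored in `grid_layout`, that negative IDs mark empty grid sites, that a neighborhood is every channel within `radius` grid units (Euclidean distance), and that `whitening_strength` acts as an exponent on the eigenvalues. The helper names `neighborhoods` and `local_zca_matrix` are illustrative only.

```python
import numpy as np


def neighborhoods(grid_layout: np.ndarray, radius: float) -> dict[int, list[int]]:
    """Map each channel ID to the sorted IDs of channels within `radius` grid units."""
    positions = {
        int(ch): (r, c)
        for (r, c), ch in np.ndenumerate(grid_layout)
        if ch >= 0  # assumption: negative IDs mark empty grid sites
    }
    return {
        ch: sorted(
            other
            for other, (r1, c1) in positions.items()
            if np.hypot(r1 - r0, c1 - c0) <= radius
        )
        for ch, (r0, c0) in positions.items()
    }


def local_zca_matrix(
    x: np.ndarray,
    grid_layout: np.ndarray,
    radius: float = 1.0,
    epsilon: float = 1e-6,
    whitening_strength: float = 1.0,
) -> np.ndarray:
    """Build a (channel x channel) local whitening matrix from x of shape (n_samples, n_channels)."""
    xc = x - x.mean(axis=0)
    w = np.zeros((x.shape[1], x.shape[1]))
    for ch, hood in neighborhoods(grid_layout, radius).items():
        # Covariance of this channel's neighborhood only.
        cov = np.atleast_2d(np.cov(xc[:, hood], rowvar=False))
        evals, evecs = np.linalg.eigh(cov)
        # Partial ZCA: eigenvalues raised to -strength/2 (0 -> identity, 1 -> full whitening).
        w_hood = (evecs * (evals + epsilon) ** (-whitening_strength / 2.0)) @ evecs.T
        # Keep only the row that produces this channel's output.
        w[ch, hood] = w_hood[hood.index(ch)]
    return w


# Usage: a 2 x 3 electrode grid whose entries are channel IDs 0..5.
rng = np.random.default_rng(0)
grid = np.arange(6).reshape(2, 3)
x = rng.normal(size=(2000, 6)) @ (np.eye(6) + 0.3 * rng.normal(size=(6, 6)))
w = local_zca_matrix(x, grid, radius=1.0)
x_white = (x - x.mean(axis=0)) @ w.T  # one whitened output column per channel
```

Each output channel is whitened only against its spatial neighbors, so the resulting matrix is sparse in the grid sense. With a radius large enough to cover the whole array, the sketch reduces to ordinary global ZCA.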

Source code in xdflow/transforms/zca.py
def __init__(
    self,
    grid_layout: np.ndarray,
    radius: float = 1.0,
    epsilon: float = 1e-6,
    whitening_strength: float = 1.0,
    n_components: float | int | None = None,
    pca_frac_to_keep: float | None = None,
    center_data: bool | str = True,
    sel: dict[str, object] | None = None,
    drop_sel: dict[str, object] | None = None,
):
    super().__init__(sel=sel, drop_sel=drop_sel)
    self.grid_layout = grid_layout

    # Validate inputs
    if not isinstance(radius, (int, float)) or radius < 0:
        raise ValueError("Radius must be a non-negative number.")

    if not isinstance(epsilon, (int, float)) or epsilon <= 0:
        raise ValueError("Epsilon must be a positive number.")

    if not isinstance(whitening_strength, (int, float)) or not (0.0 <= whitening_strength <= 1.0):
        raise ValueError("whitening_strength must be a number between 0.0 and 1.0.")

    # Validate n_components
    if n_components is not None:
        if isinstance(n_components, float):
            if not (0.0 < n_components <= 1.0):
                raise ValueError("n_components float must be in (0, 1].")
        elif isinstance(n_components, int):
            if n_components < 1:
                raise ValueError("n_components int must be >= 1.")
        else:
            raise TypeError("n_components must be a float in (0,1] or an int >= 1.")

    # Validate pca_frac_to_keep
    if pca_frac_to_keep is not None:
        if not isinstance(pca_frac_to_keep, (int, float)) or not (0.0 < pca_frac_to_keep <= 1.0):
            raise ValueError("pca_frac_to_keep must be a number greater than 0.0 and less than or equal to 1.0.")

    # Mutual exclusion validation
    if n_components is not None and pca_frac_to_keep is not None:
        raise ValueError("Cannot specify both n_components and pca_frac_to_keep. Use one or the other.")

    # Ensure at least one is specified
    if n_components is None and pca_frac_to_keep is None:
        raise ValueError("Must specify either n_components or pca_frac_to_keep.")

    self.radius = radius
    self.epsilon = epsilon
    self.whitening_strength = whitening_strength
    self.n_components = n_components
    self.pca_frac_to_keep = float(pca_frac_to_keep) if pca_frac_to_keep is not None else None
    # Normalize centering mode
    if isinstance(center_data, bool):
        self._center_mode = "true" if center_data else "false"
    elif isinstance(center_data, str) and center_data.lower() == "false_oldstyle":
        self._center_mode = "false_oldstyle"
    else:
        raise ValueError("center_data must be True, False, or 'false_oldstyle'.")
    self.center_data = center_data
    self.W_local = None  # The whitening matrix, learned during fit
    self.data_mean = None  # Mean for centering, computed during fit

GlobalZCAWhitening

GlobalZCAWhitening(epsilon: float = 1e-06, whitening_strength: float = 1.0, n_components: float | int | None = None, pca_frac_to_keep: float | None = None, center_data: bool | str = True, keep_in_pc_space: bool = False, shrinkage: str | None = None, sel: dict[str, object] | None = None, drop_sel: dict[str, object] | None = None)

Bases: Transform

Performs global ZCA whitening across channels.

This transform computes a single ZCA whitening matrix over all channels, flattening samples across (trial, time). It supports partial whitening via whitening_strength, optional dimensionality reduction via n_components or pca_frac_to_keep, and optional centering of the data before whitening.

Attributes:

Name Type Description
epsilon float

Regularization added to eigenvalues for stability. Must be positive.

whitening_strength float

Degree of whitening, in [0.0, 1.0]: 0.0 = no whitening, 1.0 = full whitening.

n_components float | int | None

Fraction (0, 1] of variance to keep, or an integer >= 1. Exactly one of n_components and pca_frac_to_keep must be given; the constructor raises ValueError if both or neither are set.

pca_frac_to_keep float | None

Fraction (0, 1] of principal components to retain. Mutually exclusive with n_components; one of the two is required.

center_data bool | str

Whether to center data across (trial, time). Accepts True, False, or "false_oldstyle".

shrinkage str | None

Reserved for a covariance shrinkage strategy (e.g., "ledoit-wolf"). Not implemented: any non-None value emits a warning and fitting proceeds without shrinkage.
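A minimal NumPy sketch of global ZCA with partial whitening follows, assuming data laid out as (trial, channel, time). It is not the xdflow code: `fit_global_zca` and `apply_global_zca` are hypothetical names, `whitening_strength` is assumed to act as an exponent on the eigenvalues, `pca_frac_to_keep` is assumed to be a fraction of the number of components, and `n_components` and `keep_in_pc_space` are not modeled.

```python
import numpy as np


def fit_global_zca(
    data: np.ndarray,
    epsilon: float = 1e-6,
    whitening_strength: float = 1.0,
    pca_frac_to_keep: float = 1.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Fit a (channel x channel) ZCA matrix on data of shape (trial, channel, time)."""
    n_channel = data.shape[1]
    # Flatten (trial, time) into one sample axis: (trial * time, channel).
    samples = data.transpose(0, 2, 1).reshape(-1, n_channel)
    channel_mean = samples.mean(axis=0)
    evals, evecs = np.linalg.eigh(np.cov(samples - channel_mean, rowvar=False))
    order = np.argsort(evals)[::-1]  # largest-variance components first
    evals, evecs = evals[order], evecs[:, order]
    # Assumption: pca_frac_to_keep is a fraction of the number of components.
    k = max(1, int(round(pca_frac_to_keep * n_channel)))
    u, lam = evecs[:, :k], evals[:k]
    # Partial whitening: strength 0 leaves the kept subspace unscaled, 1 whitens it fully.
    w = (u * (lam + epsilon) ** (-whitening_strength / 2.0)) @ u.T
    return w, channel_mean


def apply_global_zca(data: np.ndarray, w: np.ndarray, channel_mean: np.ndarray) -> np.ndarray:
    """Center each channel and mix channels with w; output keeps the (trial, channel, time) layout."""
    centered = data - channel_mean[None, :, None]
    return np.einsum("ij,tjs->tis", w, centered)
```

Note the interaction between the two knobs in this sketch: with `whitening_strength=0.0` and every component kept, `w` is the identity, but dropping components turns `w` into a projection onto the retained subspace even at zero strength.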

Source code in xdflow/transforms/zca.py
def __init__(
    self,
    epsilon: float = 1e-6,
    whitening_strength: float = 1.0,
    n_components: float | int | None = None,
    pca_frac_to_keep: float | None = None,
    center_data: bool | str = True,
    keep_in_pc_space: bool = False,
    shrinkage: str | None = None,
    sel: dict[str, object] | None = None,
    drop_sel: dict[str, object] | None = None,
):
    super().__init__(
        sel=sel,
        drop_sel=drop_sel,
    )

    if not isinstance(epsilon, (int, float)) or epsilon <= 0:
        raise ValueError("epsilon must be a positive number.")
    if not isinstance(whitening_strength, (int, float)) or not (0.0 <= whitening_strength <= 1.0):
        raise ValueError("whitening_strength must be in [0.0, 1.0].")

    # Validate n_components
    if n_components is not None:
        if isinstance(n_components, float):
            if not (0.0 < n_components <= 1.0):
                raise ValueError("n_components float must be in (0, 1].")
        elif isinstance(n_components, int):
            if n_components < 1:
                raise ValueError("n_components int must be >= 1.")
        else:
            raise TypeError("n_components must be a float in (0,1] or an int >= 1.")

    # Validate pca_frac_to_keep
    if pca_frac_to_keep is not None:
        if not isinstance(pca_frac_to_keep, (int, float)) or not (0.0 < pca_frac_to_keep <= 1.0):
            raise ValueError("pca_frac_to_keep must be in (0.0, 1.0].")

    # Mutual exclusion validation
    if n_components is not None and pca_frac_to_keep is not None:
        raise ValueError("Cannot specify both n_components and pca_frac_to_keep. Use one or the other.")

    # Ensure at least one is specified
    if n_components is None and pca_frac_to_keep is None:
        raise ValueError("Must specify either n_components or pca_frac_to_keep.")

    if shrinkage is not None:
        warnings.warn(
            "Covariance shrinkage is accepted as a parameter but not implemented yet; proceeding without shrinkage.",
            stacklevel=2,
        )

    self.epsilon = float(epsilon)
    self.whitening_strength = float(whitening_strength)
    self.n_components = n_components
    self.pca_frac_to_keep = float(pca_frac_to_keep) if pca_frac_to_keep is not None else None
    if isinstance(center_data, bool):
        self._center_mode = "true" if center_data else "false"
    elif isinstance(center_data, str) and center_data.lower() == "false_oldstyle":
        self._center_mode = "false_oldstyle"
    else:
        raise ValueError("center_data must be True, False, or 'false_oldstyle'.")
    self.center_data = center_data
    self.keep_in_pc_space = bool(keep_in_pc_space)
    self.shrinkage = shrinkage

    self.W_global: np.ndarray | None = None
    self.channel_mean = None  # xarray DataArray of shape (channel,)

GlobalColoringProjection

GlobalColoringProjection(epsilon: float = 1e-06, n_components: float | int | None = None, pca_frac_to_keep: float | None = None, center_data: bool | str = True, sel: dict[str, object] | None = None, drop_sel: dict[str, object] | None = None)

Bases: Transform

Performs a global "coloring" projection across channels.

This transform learns the covariance structure of the channels and creates a "coloring" matrix, which is the inverse of a ZCA whitening matrix. Instead of removing correlations, this matrix embodies them.

The transform then projects the data onto the basis vectors of this coloring matrix. The resulting features represent how strongly the neural activity at each time point aligns with the dominant, learned patterns of spatial covariance.

It can serve as an alternative to GlobalFeaturePCA: PCA finds orthogonal axes of maximum variance, whereas this transform projects onto axes that carry the natural, correlated modes of the system.
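The relationship between the coloring matrix and ZCA whitening can be checked numerically. In this hypothetical NumPy sketch (the names `fit_coloring` and `coloring_features` are illustrative, not xdflow's API), the coloring matrix is U diag(sqrt(λ + ε)) Uᵀ, the ZCA matrix is U diag(1/sqrt(λ + ε)) Uᵀ, and "projecting onto the basis vectors" is assumed to mean taking the dot product of each centered sample with every column of the coloring matrix. Truncation via `n_components` or `pca_frac_to_keep` is not shown.

```python
import numpy as np


def fit_coloring(samples: np.ndarray, epsilon: float = 1e-6) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return (coloring C, whitening W, channel mean) from samples of shape (n_samples, n_channels)."""
    channel_mean = samples.mean(axis=0)
    evals, evecs = np.linalg.eigh(np.cov(samples - channel_mean, rowvar=False))
    root = np.sqrt(evals + epsilon)
    coloring = (evecs * root) @ evecs.T   # U diag(sqrt(lambda + eps)) U^T
    whitening = (evecs / root) @ evecs.T  # U diag(1 / sqrt(lambda + eps)) U^T
    return coloring, whitening, channel_mean


def coloring_features(samples: np.ndarray, coloring: np.ndarray, channel_mean: np.ndarray) -> np.ndarray:
    """Project centered samples onto each column (basis vector) of the coloring matrix."""
    return (samples - channel_mean) @ coloring
```

Because C · W is the identity and C · C equals the regularized covariance, the coloring matrix "embodies" the learned correlations in exactly the sense that W removes them.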

Attributes:

Name Type Description
epsilon float

Regularization added to eigenvalues for stability.

n_components float | int | None

Fraction (0, 1] of variance to keep, or an integer >= 1. Exactly one of n_components and pca_frac_to_keep must be given.

pca_frac_to_keep float | None

Fraction (0, 1] of principal components to retain, analogous to PCA, for dimensionality reduction. Mutually exclusive with n_components.

center_data bool | str

Whether to center data across (trial, time). Accepts True, False, or "false_oldstyle".

Initialize the transform.

Parameters:

Name Type Description Default
epsilon float

Regularization for numerical stability.

1e-06
n_components float | int | None

Fraction (0, 1] of variance to keep, or an integer >= 1 giving the exact number of components (as in sklearn). Exactly one of n_components and pca_frac_to_keep must be given; the constructor raises ValueError otherwise.

None
pca_frac_to_keep float | None

Fraction (0, 1] of components to keep, controlling output dimensionality. Mutually exclusive with n_components.

None
center_data bool | str

Whether to center channels before fitting/projection. Accepts True, False, or "false_oldstyle".

True
sel dict[str, object] | None

Optional selection to apply before transforming.

None
drop_sel dict[str, object] | None

Optional drop selection to apply before transforming.

None
Source code in xdflow/transforms/zca.py
def __init__(
    self,
    epsilon: float = 1e-6,
    n_components: float | int | None = None,
    pca_frac_to_keep: float | None = None,
    center_data: bool | str = True,
    sel: dict[str, object] | None = None,
    drop_sel: dict[str, object] | None = None,
):
    """Initialize the transform.

    Args:
        epsilon: Regularization for numerical stability.
        n_components: Optional fraction (0, 1] of variance to keep or integer >= 1
            specifying exact number of components (sklearn standard behavior).
            Cannot be specified together with pca_frac_to_keep.
        pca_frac_to_keep: Optional fraction of components to keep, controlling output
            dimensionality. Cannot be specified together with n_components.
        center_data: Whether to center channels before fitting/projection.
        sel: Optional selection to apply before transforming.
        drop_sel: Optional drop selection to apply before transforming.
    """
    super().__init__(sel=sel, drop_sel=drop_sel)

    if not isinstance(epsilon, (int, float)) or epsilon <= 0:
        raise ValueError("epsilon must be a positive number.")

    # Validate n_components
    if n_components is not None:
        if isinstance(n_components, float):
            if not (0.0 < n_components <= 1.0):
                raise ValueError("n_components float must be in (0, 1].")
        elif isinstance(n_components, int):
            if n_components < 1:
                raise ValueError("n_components int must be >= 1.")
        else:
            raise TypeError("n_components must be a float in (0,1] or an int >= 1.")

    # Validate pca_frac_to_keep
    if pca_frac_to_keep is not None:
        if not isinstance(pca_frac_to_keep, (int, float)) or not (0.0 < pca_frac_to_keep <= 1.0):
            raise ValueError("pca_frac_to_keep must be in (0.0, 1.0].")

    # Mutual exclusion validation
    if n_components is not None and pca_frac_to_keep is not None:
        raise ValueError("Cannot specify both n_components and pca_frac_to_keep. Use one or the other.")

    # Ensure at least one is specified
    if n_components is None and pca_frac_to_keep is None:
        raise ValueError("Must specify either n_components or pca_frac_to_keep.")

    self.epsilon = float(epsilon)
    self.n_components = n_components
    self.pca_frac_to_keep = float(pca_frac_to_keep) if pca_frac_to_keep is not None else None
    if isinstance(center_data, bool):
        self._center_mode = "true" if center_data else "false"
    elif isinstance(center_data, str) and center_data.lower() == "false_oldstyle":
        self._center_mode = "false_oldstyle"
    else:
        raise ValueError("center_data must be True, False, or 'false_oldstyle'.")
    self.center_data = center_data

    # Learned parameters
    self.W_color: np.ndarray | None = None  # The coloring matrix
    self.channel_mean: xr.DataArray | None = None  # Mean for centering

ZCAWhitening

ZCAWhitening(n_components=None, random_state=None)

Bases: BaseEstimator, TransformerMixin

ZCA (zero-phase component analysis) whitening using sklearn's PCA.

This estimator fits PCA with whiten=True and maps the whitened component scores back into the original feature space, so the output keeps the input's feature axes.

Parameters:

Name Type Description Default
n_components

int, float, or None (default None). Number of components to keep. If None, all components are kept. If a float between 0 and 1, it is the fraction of variance to keep, as in sklearn's PCA.

None
random_state

int, RandomState instance, or None (default None). Random state for reproducibility.

None
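The difference between whitening and reconstruction is easiest to see from the formulas sklearn's PCA applies when whiten=True. The NumPy sketch below reimplements those formulas (it is not sklearn itself, and the helper names are hypothetical): transform divides the PCA scores by sqrt(explained_variance_), while inverse_transform multiplies by it again and adds mean_ back, so chaining the two reconstructs the input rather than whitening it. Multiplying the whitened scores by components_ alone gives the ZCA output, whose covariance is the identity.

```python
import numpy as np


def fit_whitened_pca(x: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return (mean_, components_, explained_variance_) as PCA(whiten=True) would, keeping all components."""
    mean = x.mean(axis=0)
    evals, evecs = np.linalg.eigh(np.cov(x - mean, rowvar=False))  # ddof=1, like sklearn
    order = np.argsort(evals)[::-1]
    return mean, evecs[:, order].T, evals[order]


def pca_transform(x, mean, components, explained_variance):
    """Whitened scores: (x - mean_) @ components_.T / sqrt(explained_variance_)."""
    return (x - mean) @ components.T / np.sqrt(explained_variance)


def pca_inverse_transform(z, mean, components, explained_variance):
    """With whiten=True, inverse_transform re-applies sqrt(explained_variance_) and adds mean_ back."""
    return (z * np.sqrt(explained_variance)) @ components + mean


def zca(x, mean, components, explained_variance):
    """ZCA: whitened scores rotated back with components_ alone, so each direction keeps unit variance."""
    return pca_transform(x, mean, components, explained_variance) @ components
```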
Source code in xdflow/transforms/zca.py
def __init__(self, n_components=None, random_state=None):
    self.n_components = n_components
    self.random_state = random_state

fit

fit(x, y=None)

Fit the ZCA whitening transform.

Parameters:

Name Type Description Default
x

array-like of shape (n_samples, n_features)

required
y

ignored

None

Returns:

Type Description

self

Source code in xdflow/transforms/zca.py
def fit(self, x, y=None):
    """
    Fit the ZCA whitening transform.

    Args:
        x: array-like of shape (n_samples, n_features)
        y: ignored

    Returns:
        self
    """
    # Use PCA with whitening
    self.pca_ = PCA(n_components=self.n_components, whiten=True, random_state=self.random_state)
    self.pca_.fit(x)
    return self

transform

transform(x)

Apply ZCA whitening transform.

Parameters:

Name Type Description Default
x

array-like of shape (n_samples, n_features)

required

Returns:

Type Description

array-like of shape (n_samples, n_features)

Source code in xdflow/transforms/zca.py
def transform(self, x):
    """
    Apply ZCA whitening transform.

    Args:
        x: array-like of shape (n_samples, n_features)

    Returns:
        array-like of shape (n_samples, n_features)
    """
    if not hasattr(self, "pca_"):
        raise ValueError("ZCAWhitening must be fitted before transform")

    # Whitened PCA scores: (x - mean_) @ components_.T / sqrt(explained_variance_)
    x_whitened = self.pca_.transform(x)
    # Rotate back into feature space with components_ only; pca_.inverse_transform
    # would re-multiply by sqrt(explained_variance_) and undo the whitening.
    x_zca = x_whitened @ self.pca_.components_
    return x_zca

inverse_transform

inverse_transform(x)

Apply inverse ZCA transform.

Parameters:

Name Type Description Default
x

array-like of shape (n_samples, n_features)

required

Returns:

Type Description

array-like of shape (n_samples, n_features)

Source code in xdflow/transforms/zca.py
def inverse_transform(self, x):
    """
    Apply inverse ZCA transform.

    Args:
        x: array-like of shape (n_samples, n_features)

    Returns:
        array-like of shape (n_samples, n_features)
    """
    if not hasattr(self, "pca_"):
        raise ValueError("ZCAWhitening must be fitted before inverse_transform")

    # ZCA is symmetric but not self-inverse: rotate back to whitened PCA scores,
    # then let pca_.inverse_transform restore the variance scaling and the mean.
    x_whitened = np.asarray(x) @ self.pca_.components_.T
    return self.pca_.inverse_transform(x_whitened)

ZCATransform

ZCATransform(n_components=None, random_state=None, sample_dim='trial', output_dim_name='channel', sel=None, drop_sel=None)

Bases: SKLearnTransformer

ZCA (zero-phase component analysis) whitening transform.

This transform wraps ZCAWhitening in SKLearnTransformer: it fits PCA with whitening and maps the whitened scores back into the original channel space, so the output dimension matches the input.

Parameters:

Name Type Description Default
n_components

int or float (required). Number of components to keep; a float between 0 and 1 is the fraction of variance to keep. Although the signature defaults to None, passing None raises ValueError.

None
random_state

int, RandomState instance, or None (default None). Random state for reproducibility.

None
sample_dim

str (default "trial"). The dimension that corresponds to samples.

'trial'
output_dim_name

str (default "channel"). Name of the output dimension (the same as the input for ZCA).

'channel'
sel

dict, optional. Selection to apply before transforming.

None
drop_sel

dict, optional. Drop selection to apply before transforming.

None
Source code in xdflow/transforms/zca.py
def __init__(
    self,
    n_components=None,
    random_state=None,
    sample_dim="trial",
    output_dim_name="channel",
    sel=None,
    drop_sel=None,
):
    # Validate n_components - require it to be specified
    if n_components is None:
        raise ValueError("Must specify n_components.")

    super().__init__(
        estimator_cls=ZCAWhitening,
        sample_dim=sample_dim,
        output_dim_name=output_dim_name,
        sel=sel,
        drop_sel=drop_sel,
        n_components=n_components,
        random_state=random_state,
    )

Domain Adaptation

AdaptiveStrategy

AdaptiveStrategy(group_coord: str, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None)

Bases: ABC

Abstract base class for domain adaptation strategies.

A strategy encapsulates a high-level procedure for domain adaptation, such as a single-target or joint domain approach. It orchestrates the process by managing data grouping and calling back to the aligner for specific mathematical computations.

Initialize the AdaptiveStrategy.

Parameters:

Name Type Description Default
group_coord str

The coordinate to group by.

required
n_jobs int

The number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for data used for adaptation calculations. None means all data is used.

None
Source code in xdflow/transforms/domain_adaptation.py
def __init__(self, group_coord: str, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None):
    """
    Initialize the AdaptiveStrategy.

    Args:
        group_coord: The coordinate to group by.
        n_jobs: The number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for data used for adaptation calculations. None means all data is used.
    """
    self.group_coord = group_coord
    self.n_jobs = n_jobs
    self.adapt_sel = adapt_sel

adapt

adapt(aligner: AdaptiveTransform, container: DataContainer, **kwargs)

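Fit the strategy's adaptation parameters on container. If adapt_sel is set, the container is first filtered with get_container_by_conditions; the subclass hook _adapt then performs the fitting.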

Source code in xdflow/transforms/domain_adaptation.py
def adapt(self, aligner: "AdaptiveTransform", container: DataContainer, **kwargs):
    """
    Adapt the data using the adaptation strategy.
    """
    if self.adapt_sel:
        container = get_container_by_conditions(container, self.adapt_sel)
        print(f"Adapted container shape: {container.data.shape}")
    self._adapt(aligner, container, **kwargs)

transform abstractmethod

transform(aligner: AdaptiveTransform, container: DataContainer, **kwargs) -> DataContainer

Abstract method for the transformation logic of the adaptation strategy.

Parameters:

Name Type Description Default
aligner AdaptiveTransform

The AdaptiveTransform instance using this strategy.

required
container DataContainer

The DataContainer to be transformed.

required
**kwargs

Additional arguments.

{}

Returns:

Type Description
DataContainer

The transformed DataContainer.

Source code in xdflow/transforms/domain_adaptation.py
@abstractmethod
def transform(self, aligner: "AdaptiveTransform", container: DataContainer, **kwargs) -> DataContainer:
    """
    Abstract method for the transformation logic of the adaptation strategy.

    Args:
        aligner: The AdaptiveTransform instance using this strategy.
        container: The DataContainer to be transformed.
        **kwargs: Additional arguments.

    Returns:
        The transformed DataContainer.
    """
    raise NotImplementedError("Subclasses must implement this method.")

SingleTargetStrategy

SingleTargetStrategy(group_coord: str, target_group: str | int | float, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None)

Bases: AdaptiveStrategy

An adaptation strategy for a single target group and multiple source groups. During adaptation, it fits the target domain and then adapts each source domain to the target domain independently. During transformation, it leaves the target domain unchanged and applies the adapted models to the source domains. Domains are determined by the group_coord.

SingleTargetStrategy can be used by any aligner that inherits from SingleTargetAligner.
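The adapt/transform split of SingleTargetStrategy can be illustrated with a toy aligner. This sketch is hypothetical: `MeanShiftSingleTarget` is not part of xdflow, it works on a plain NumPy array with one group label per row instead of a DataContainer, and mean alignment stands in for a real aligner's math. Unlike the xr.concat reassembly in the source, it writes each group back into its original rows, so sample order is preserved.

```python
import numpy as np


class MeanShiftSingleTarget:
    """Hypothetical aligner: moves each source group's mean onto the target group's mean."""

    def __init__(self, target_group):
        self.target_group = target_group
        self.target_params: dict = {}
        self.adapted_params: dict = {}

    def adapt(self, x: np.ndarray, groups: np.ndarray) -> "MeanShiftSingleTarget":
        # Fit the target domain once, then each source domain independently.
        self.target_params = {"mean": x[groups == self.target_group].mean(axis=0)}
        self.adapted_params = {
            g: {"mean": x[groups == g].mean(axis=0)}
            for g in np.unique(groups)
            if g != self.target_group
        }
        return self

    def transform(self, x: np.ndarray, groups: np.ndarray) -> np.ndarray:
        out = np.empty_like(x, dtype=float)
        for g in np.unique(groups):
            rows = groups == g
            if g == self.target_group:
                out[rows] = x[rows]  # target domain passes through unchanged
            elif g in self.adapted_params:
                out[rows] = x[rows] - self.adapted_params[g]["mean"] + self.target_params["mean"]
            else:
                raise KeyError(f"Source group {g!r} was not seen during adapt")
        return out
```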

Initialize the SingleTargetStrategy.

Parameters:

Name Type Description Default
group_coord str

The coordinate to group by. Determines different domains.

required
target_group str | int | float

The target group to adapt to.

required
n_jobs int

The number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for data used for adaptation calculations. None means all data is used.

None
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    group_coord: str,
    target_group: str | int | float,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
):
    """
    Initialize the SingleTargetStrategy.

    Args:
        group_coord: The coordinate to group by. Determines different domains.
        target_group: The target group to adapt to.
        n_jobs: The number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for data used for adaptation calculations. None means all data is used.
    """
    super().__init__(group_coord=group_coord, n_jobs=n_jobs, adapt_sel=adapt_sel)
    self.target_group = target_group
    self.target_params = {}
    self.adapted_params = {}
    self.group_dim = None
    self.seen_source_groups_ = []
    self.seen_groups_ = []

transform

transform(aligner: SingleTargetAligner, container: DataContainer, **kwargs) -> DataContainer

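Transforms each group defined by group_coord: the target group is returned unchanged, each source group seen during adapt is passed to the aligner's _adapted_transform with its adapted parameters, and a source group not seen during adapt raises TransformError. The per-group results are concatenated along the group dimension in discovery order, which need not match the input order.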

Parameters:

Name Type Description Default
aligner SingleTargetAligner

The SingleTargetAligner instance using this strategy.

required
container DataContainer

The DataContainer to be transformed.

required
**kwargs

Additional arguments.

{}

Returns:

Type Description
DataContainer

The transformed DataContainer.

Source code in xdflow/transforms/domain_adaptation.py
def transform(self, aligner: "SingleTargetAligner", container: DataContainer, **kwargs) -> DataContainer:
    """
    Transforms data by applying the appropriate target or adapted model to each group.

    Args:
        aligner: The SingleTargetAligner instance using this strategy.
        container: The DataContainer to be transformed.
        **kwargs: Additional arguments.

    Returns:
        The transformed DataContainer.
    """
    current_groups = self._discover_groups(container)

    def transform_group(group_val):
        group_container = self._select_group(container, group_val)
        if group_val == self.target_group:
            return group_container
        elif group_val in self.adapted_params:
            return aligner._adapted_transform(
                group_container, self.adapted_params[group_val], self.target_params, **kwargs
            )
        else:  # Unseen source group
            raise TransformError(
                f"Source group '{group_val}' was not seen during 'adapt'. "
                f"Seen source groups: {self.seen_source_groups_}"
            )

    group_outputs = []
    if self.n_jobs != 1:
        transformed_containers = Parallel(n_jobs=self.n_jobs)(
            delayed(transform_group)(group_val) for group_val in current_groups
        )
        group_outputs.extend([output.data for output in transformed_containers])
    else:
        for group_val in current_groups:
            transformed_container = transform_group(group_val)
            group_outputs.append(transformed_container.data)

    # Reassemble outputs #TODO: do we need to reassemble in the same order?
    reassembled = xr.concat(group_outputs, dim=self.group_dim)

    # Note: Reordering to match original is complex and may not be necessary.
    # If order is critical, it should be handled carefully.

    return DataContainer(reassembled)

JointGroupStrategy

JointGroupStrategy(group_coord: str, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None)

Bases: AdaptiveStrategy

An adaptation strategy for jointly fitting all domains. Domains are determined by the group_coord. During adaptation, it fits all domains jointly. During transformation, it applies the appropriate adapted model to each domain.

JointGroupStrategy can be used by any aligner that inherits from JointGroupAligner.
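For contrast with the single-target case, here is a hypothetical joint aligner in the same toy setting (NumPy arrays with one group label per row; `JointMeanAlign` is not part of xdflow). The statistic is learned from all groups together, here the sample-weighted grand mean, and each group then gets its own adapted parameters. No group is treated as a fixed reference, and any group missing from adapt is rejected at transform time.

```python
import numpy as np


class JointMeanAlign:
    """Hypothetical aligner: every group is shifted onto a grand mean learned jointly."""

    def __init__(self):
        self.adapted_params: dict = {}

    def adapt(self, x: np.ndarray, groups: np.ndarray) -> "JointMeanAlign":
        grand_mean = x.mean(axis=0)  # one statistic computed over all groups together
        self.adapted_params = {
            g: grand_mean - x[groups == g].mean(axis=0) for g in np.unique(groups)
        }
        return self

    def transform(self, x: np.ndarray, groups: np.ndarray) -> np.ndarray:
        out = np.empty_like(x, dtype=float)
        for g in np.unique(groups):
            if g not in self.adapted_params:
                raise KeyError(f"Group {g!r} was not seen during adapt")
            rows = groups == g
            out[rows] = x[rows] + self.adapted_params[g]  # each group uses its own parameters
        return out
```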

Initialize the JointGroupStrategy.

Parameters:

Name Type Description Default
group_coord str

The coordinate to group by. Determines different domains.

required
n_jobs int

The number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for data used for adaptation calculations. None means all data is used.

None
Source code in xdflow/transforms/domain_adaptation.py
def __init__(self, group_coord: str, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None):
    """
    Initialize the JointGroupStrategy.

    Args:
        group_coord: The coordinate to group by. Determines different domains.
        n_jobs: The number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for data used for adaptation calculations. None means all data is used.
    """
    super().__init__(group_coord=group_coord, n_jobs=n_jobs, adapt_sel=adapt_sel)
    self.adapted_params = {}
    self.group_dim = None
    self.seen_groups_ = []

transform

transform(aligner: JointGroupAligner, container: DataContainer, **kwargs) -> DataContainer

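Transforms each group defined by group_coord with the adapted parameters learned for it during adapt; a group not seen during adapt raises TransformError. The per-group results are concatenated along the group dimension.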

Parameters:

Name Type Description Default
aligner JointGroupAligner

The JointGroupAligner instance using this strategy.

required
container DataContainer

The DataContainer to be transformed.

required
**kwargs

Additional arguments.

{}

Returns:

Type Description
DataContainer

The transformed DataContainer.

Source code in xdflow/transforms/domain_adaptation.py
def transform(self, aligner: "JointGroupAligner", container: DataContainer, **kwargs) -> DataContainer:
    """
    Transforms data by applying the appropriate adapted model to each group.

    Args:
        aligner: The JointGroupAligner instance using this strategy.
        container: The DataContainer to be transformed.
        **kwargs: Additional arguments.

    Returns:
        The transformed DataContainer.
    """
    current_groups = self._discover_groups(container)

    def transform_group(group_val):
        group_container = self._select_group(container, group_val)
        if group_val in self.adapted_params:
            return aligner._adapted_transform(group_container, self.adapted_params[group_val], **kwargs)
        else:
            raise TransformError(
                f"Group '{group_val}' was not seen during 'adapt'. Seen groups: {self.seen_groups_}"
            )

    group_outputs = []
    if self.n_jobs != 1:
        transformed_containers = Parallel(n_jobs=self.n_jobs)(
            delayed(transform_group)(group_val) for group_val in current_groups
        )
        group_outputs.extend([output.data for output in transformed_containers])
    else:
        for group_val in current_groups:
            transformed_container = transform_group(group_val)
            group_outputs.append(transformed_container.data)

    # Reassemble outputs
    reassembled = xr.concat(group_outputs, dim=self.group_dim)

    return DataContainer(reassembled)

AdaptiveTransform

AdaptiveTransform(strategy_name: str, sample_dim: str, **kwargs)

Bases: Transform, ABC

Abstract base class for domain adaptation aligners.

This class serves as the context in the Strategy pattern. It is initialized with a strategy name, which is looked up in the class's _STRATEGY_MAP, and it delegates the fitting and transforming logic to the resulting strategy instance.

Initialize domain adaptation transform.

Parameters:

Name Type Description Default
strategy_name str

The adaptation strategy to use.

required
sample_dim str

The dimension along which samples are indexed.

required
**kwargs

Additional arguments passed to Transform base class or strategy's constructor.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    strategy_name: str,
    sample_dim: str,
    **kwargs,
):
    """
    Initialize domain adaptation transform.

    Args:
        strategy_name: The adaptation strategy to use.
        sample_dim: The dimension to use for the sample.
        **kwargs: Additional arguments passed to Transform base class or strategy's constructor.
    """
    super().__init__(**kwargs)

    self.sample_dim = sample_dim

    self.strategy_name = strategy_name  # save for cloning
    self.strategy, self._strategy_kwargs = self.get_strategy_and_kwargs(strategy_name, kwargs)

    self._is_fitted = False

get_strategy_and_kwargs

get_strategy_and_kwargs(strategy_name: str, kwargs: dict) -> tuple[AdaptiveStrategy, dict]

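Look up strategy_name in the class's _STRATEGY_MAP, check that every required constructor argument of that strategy is present and not None in kwargs (raising ValueError otherwise), and return the strategy instance together with the subset of kwargs its constructor accepts.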
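The signature-driven kwarg routing described above can be reproduced with the standard-library inspect module. `ToyStrategy` and `split_strategy_kwargs` are hypothetical names for this sketch, not xdflow API. One deliberate difference from the source: positional-only parameters are skipped here, because they cannot be passed through `**kwargs`.

```python
from inspect import Parameter, signature
from typing import Any

# Parameter kinds that can be supplied by keyword.
_KEYWORD_KINDS = (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)


class ToyStrategy:
    """Hypothetical strategy used only for this illustration."""

    def __init__(self, group_coord: str, target_group: Any, n_jobs: int = 1):
        self.group_coord = group_coord
        self.target_group = target_group
        self.n_jobs = n_jobs


def split_strategy_kwargs(strategy_class: type, kwargs: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
    """Build strategy_class from the subset of kwargs its constructor accepts."""
    params = [
        p
        for name, p in signature(strategy_class.__init__).parameters.items()
        if name != "self" and p.kind in _KEYWORD_KINDS
    ]
    for p in params:
        # Required means "no default"; a value that is present but None is rejected too.
        if p.default is Parameter.empty and kwargs.get(p.name) is None:
            raise ValueError(
                f"Required argument '{p.name}' for strategy {strategy_class.__name__} is missing or None."
            )
    accepted = {p.name for p in params}
    strategy_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
    return strategy_class(**strategy_kwargs), strategy_kwargs
```

Keys the strategy does not accept (such as a transform's sel) are simply left out of the returned dict.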

Source code in xdflow/transforms/domain_adaptation.py
def get_strategy_and_kwargs(self, strategy_name: str, kwargs: dict) -> tuple[AdaptiveStrategy, dict]:
    """
    Create a strategy instance.
    """
    # make sure strategy_name is in the strategy map
    if strategy_name not in self.__class__._STRATEGY_MAP:
        raise ValueError(
            f"Unknown or unsupported strategy: '{strategy_name}' for {self.__class__.__name__}. "
            f"Available strategies: {list(self.__class__._STRATEGY_MAP.keys())}"
        )

    strategy_class = self.__class__._STRATEGY_MAP[strategy_name]

    # make sure required strategyargs are in kwargs
    sig = signature(strategy_class.__init__)
    required_strategy_args = [
        name
        for name, p in sig.parameters.items()
        if p.default is _empty
        and p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
        and name != "self"  # skip 'self' in methods
    ]

    for arg in required_strategy_args:
        if arg not in kwargs:
            raise ValueError(f"Required argument '{arg}' for strategy {strategy_class.__name__} is missing.")
        if kwargs[arg] is None:
            raise ValueError(f"Required argument '{arg}' for strategy {strategy_class.__name__} is None.")

    # all strategy constructor args
    all_strategy_args = [
        name
        for name, p in sig.parameters.items()
        if name != "self"
        and p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    ]

    strategy_kwargs = {k: v for k, v in kwargs.items() if k in all_strategy_args}

    return strategy_class(**strategy_kwargs), strategy_kwargs
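The argument filtering above can be exercised on its own. The sketch below is a standalone approximation, not xdflow code: `ToyStrategy` and `split_strategy_kwargs` are hypothetical names, and a missing or `None` required argument is reported with one combined message instead of two.

```python
from inspect import Parameter, signature

# Parameter kinds that can be named explicitly (excludes *args and **kwargs).
_NAMED_KINDS = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)


class ToyStrategy:
    """Hypothetical strategy used only for illustration."""

    def __init__(self, group_coord, n_jobs=1):
        self.group_coord = group_coord
        self.n_jobs = n_jobs


def split_strategy_kwargs(strategy_class, kwargs):
    """Return (strategy_instance, strategy_kwargs), mimicking get_strategy_and_kwargs."""
    params = {
        name: p
        for name, p in signature(strategy_class.__init__).parameters.items()
        if name != "self" and p.kind in _NAMED_KINDS
    }
    # Required means "no default value"; such arguments must be present and not None.
    for name, p in params.items():
        if p.default is Parameter.empty and kwargs.get(name) is None:
            raise ValueError(f"Required argument '{name}' for {strategy_class.__name__} is missing or None.")
    # Keep only what the strategy constructor accepts; the rest stays with the transform.
    strategy_kwargs = {k: v for k, v in kwargs.items() if k in params}
    return strategy_class(**strategy_kwargs), strategy_kwargs


strategy, used = split_strategy_kwargs(ToyStrategy, {"group_coord": "session", "sel": {"channel": [0, 1]}})
print(used)  # {'group_coord': 'session'}
```

Transform-level keys such as `sel` are silently left out of the strategy kwargs, which is what lets one `**kwargs` dictionary feed both the transform and its strategy.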

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Domain adaptation transforms preserve input dimensions.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

Input dimension names

required

Returns:

Type Description
tuple[str, ...]

Same dimensions as input

Source code in xdflow/transforms/domain_adaptation.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Domain adaptation transforms preserve input dimensions.

    Args:
        input_dims: Input dimension names

    Returns:
        Same dimensions as input
    """
    return input_dims

SingleTargetAligner

SingleTargetAligner(strategy_name: str, **kwargs)

Bases: AdaptiveTransform, ABC

Defines the contract for aligners compatible with SingleTargetStrategy.

This ABC acts as an interface. Any aligner that inherits from it promises to implement the methods required by the SingleTargetStrategy, ensuring a safe and explicit connection between the two components.

Source code in xdflow/transforms/domain_adaptation.py
def __init__(self, strategy_name: str, **kwargs):
    super().__init__(strategy_name, **kwargs)

    self._strategy_class: type[SingleTargetStrategy] = self.__class__._STRATEGY_MAP[strategy_name]

JointGroupAligner

JointGroupAligner(strategy_name: str, output_dim_name: str = 'component', **kwargs)

Bases: AdaptiveTransform, ABC

Defines the contract for aligners compatible with JointGroupStrategy. This ABC acts as an interface. Any aligner that inherits from it promises to implement the methods required by the JointGroupStrategy.

Initialize the JointGroupAligner.

Parameters:

Name Type Description Default
strategy_name str

The strategy to use for adaptation.

required
output_dim_name str

The name of the output dimension/space that all groups are aligned to.

'component'
**kwargs

Additional arguments.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(self, strategy_name: str, output_dim_name: str = "component", **kwargs):
    """
    Initialize the JointGroupAligner.

    Args:
        strategy_name: The strategy to use for adaptation.
        output_dim_name: The name of the output dimension/space that all groups are aligned to.
        **kwargs: Additional arguments.
    """
    super().__init__(strategy_name, **kwargs)

    self._strategy_class: type[JointGroupStrategy] = self.__class__._STRATEGY_MAP[strategy_name]
    self.output_dim_name = output_dim_name

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Determines the expected output dimensions. JointGroupAligner aligns all groups to a common space, so it does not preserve the original feature dimension name: the input must be 2D, and the output dimensions are (sample_dim, output_dim_name).

Source code in xdflow/transforms/domain_adaptation.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Determines the expected output dimensions.
    JointGroupAligner aligns data to a common space, so it does not preserve the original dimension names.
    """
    if len(input_dims) != 2:
        raise ValueError(
            f"AdaptiveTransforms currently require 2D data, but got data with {len(input_dims)} dimensions: {input_dims}"
        )

    return (self.sample_dim, self.output_dim_name)

ProcrustesAligner

ProcrustesAligner(target_coord: str, sample_dim: str, scaling: bool = True, strategy_name: str = 'single_target', sampling_method: str = 'mean', random_state: int = 0, group_coord: str | None = None, target_group: str | int | float | None = None, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: SingleTargetAligner, SamplingMixin

Concrete aligner that performs Procrustes analysis to align datasets.

This method aligns target class means and finds an optimal rotation and scaling transformation to match the distributions between source and target domains.
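The Procrustes step can be sketched in NumPy, assuming the per-class means of the source and target domains are already stacked as `(n_classes, n_features)` arrays with matching row order. It uses the standard SVD solution of the orthogonal Procrustes problem (reflections are allowed) plus an optional isotropic scale; `procrustes_fit` and `procrustes_apply` are hypothetical helpers, and xdflow's centering and sampling details may differ.

```python
import numpy as np


def procrustes_fit(source_means, target_means, scaling=True):
    """Fit orthogonal R and scalar s so that s * (source - mu_s) @ R approximates target - mu_t."""
    mu_s = source_means.mean(axis=0)
    mu_t = target_means.mean(axis=0)
    A = source_means - mu_s
    B = target_means - mu_t
    # Orthogonal Procrustes: with A.T @ B = U S Vt, the optimal orthogonal map is U @ Vt.
    U, S, Vt = np.linalg.svd(A.T @ B)
    R = U @ Vt
    # Least-squares isotropic scale for the fitted rotation.
    s = S.sum() / np.sum(A**2) if scaling else 1.0
    return R, s, mu_s, mu_t


def procrustes_apply(X, R, s, mu_s, mu_t):
    """Map source-domain rows of X into the target domain."""
    return s * (X - mu_s) @ R + mu_t


# Toy check: the source means are a rotated, scaled, shifted copy of the target means.
rng = np.random.default_rng(0)
target = rng.normal(size=(4, 3))  # 4 class means, 3 features
c, s_ = np.cos(0.3), np.sin(0.3)
rot = np.array([[c, -s_, 0.0], [s_, c, 0.0], [0.0, 0.0, 1.0]])
source = 2.0 * target @ rot + 5.0
R, scale, mu_s, mu_t = procrustes_fit(source, target)
aligned = procrustes_apply(source, R, scale, mu_s, mu_t)
```

In the toy check the fit recovers the inverse rotation `rot.T` and a scale of 0.5, so the aligned source means coincide with the target means.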

Initialize ProcrustesAligner.

Parameters:

Name Type Description Default
target_coord str

Target coordinate to adapt to.

required
sample_dim str

The dimension to average over when calculating class means.

required
scaling bool

Whether to scale the data.

True
strategy_name str

The adaptation strategy to use. Currently only 'single_target' is supported.

'single_target'
sampling_method str

The method to use for sampling the data for alignment.

'mean'
random_state int

The random state to use for sampling the data.

0
group_coord str | None

Group coordinate to adapt to.

None
target_group str | int | float | None

Target group to adapt to.

None
n_jobs int

Number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for adaptation. None means all data is used for adaptation calculations.

None
sel dict[str, Any] | None

Dictionary of coordinates to select.

None
drop_sel dict[str, Any] | None

Dictionary of coordinates to drop.

None
**kwargs

Arguments passed to Transform base class.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    target_coord: str,
    sample_dim: str,
    scaling: bool = True,
    strategy_name: str = "single_target",
    sampling_method: str = "mean",
    random_state: int = 0,
    group_coord: str | None = None,
    target_group: str | int | float | None = None,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize ProcrustesAligner.

    Args:
        target_coord: Target coordinate to adapt to.
        sample_dim: The dimension to average over when calculating class means.
        scaling: Whether to scale the data.
        strategy_name: The adaptation strategy to use. Currently only 'single_target' is supported.
        sampling_method: The method to use for sampling the data for alignment.
        random_state: The random state to use for sampling the data.
        group_coord: Group coordinate to adapt to.
        target_group: Target group to adapt to.
        n_jobs: Number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for adaptation. None means all data is used for adaptation calculations.
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
        **kwargs: Arguments passed to Transform base class.
    """

    super().__init__(
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        # strategy_kwargs
        group_coord=group_coord,
        target_group=target_group,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        # transform kwargs
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )
    self.target_coord = target_coord
    self.scaling = scaling
    self.sampling_method = sampling_method
    self.random_state = random_state

    assert self.sampling_method in [
        "mean",
        "min_count",
    ], "Sampling method must be either 'mean' or 'min_count'."

CoralAligner

CoralAligner(sample_dim: str, reg: float = 1e-05, strategy_name: str = 'single_target', group_coord: str | None = None, target_group: str | int | float | None = None, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: SingleTargetAligner

Correlation Alignment (CORAL) for domain adaptation.

This method aligns the second-order statistics (covariances) of the source and target distributions by learning a linear transformation matrix. The transformation preserves the original data dimensions.

The method learns a transformation matrix that aligns covariance structures: transform_matrix_ = source_cov^(-1/2) @ target_cov^(1/2)
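The CORAL formula above can be sketched in NumPy for row-sample matrices `(n_samples, n_features)`. Matrix square roots are taken through an eigendecomposition of the regularized covariances; the helper names are hypothetical and this is not xdflow's implementation. Note that the formula aligns covariances only, not means.

```python
import numpy as np


def sym_power(C, power):
    """Power of a symmetric positive-definite matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(C)
    return (V * w**power) @ V.T


def coral_fit(X_source, X_target, reg=1e-5):
    """Return transform_matrix_ = Cs^(-1/2) @ Ct^(1/2) from regularized covariances."""
    d = X_source.shape[1]
    Cs = np.cov(X_source, rowvar=False) + reg * np.eye(d)
    Ct = np.cov(X_target, rowvar=False) + reg * np.eye(d)
    return sym_power(Cs, -0.5) @ sym_power(Ct, 0.5)


rng = np.random.default_rng(0)
X_source = rng.normal(size=(500, 3)) * np.array([1.0, 2.0, 3.0])
mixing = np.array([[1.0, 0.5, 0.0], [0.0, 1.0, 0.3], [0.0, 0.0, 2.0]])
X_target = rng.normal(size=(400, 3)) @ mixing
A = coral_fit(X_source, X_target)
# Whiten with the source covariance, then re-color with the target covariance.
X_aligned = X_source @ A
```

Because the output keeps `n_features` columns, this matches the statement that CORAL preserves the original data dimensions.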

Initialize CoralAligner.

Parameters:

Name Type Description Default
sample_dim str

The dimension to average over when calculating class means.

required
reg float

Regularization parameter for covariance matrix stability

1e-05
strategy_name str

The adaptation strategy to use. Currently only 'single_target' is supported.

'single_target'
group_coord str | None

Group coordinate to adapt to.

None
target_group str | int | float | None

Target group to adapt to.

None
n_jobs int

Number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for adaptation. None means all data is used for adaptation calculations.

None
sel dict[str, Any] | None

Dictionary of coordinates to select.

None
drop_sel dict[str, Any] | None

Dictionary of coordinates to drop.

None
**kwargs

Arguments passed to Transform base class.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    sample_dim: str,
    reg: float = 1e-5,
    strategy_name: str = "single_target",
    group_coord: str | None = None,
    target_group: str | int | float | None = None,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize CoralAligner.

    Args:
        sample_dim: The dimension to average over when calculating class means.
        reg: Regularization parameter for covariance matrix stability
        strategy_name: The adaptation strategy to use. Currently only 'single_target' is supported.
        group_coord: Group coordinate to adapt to.
        target_group: Target group to adapt to.
        n_jobs: Number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for adaptation. None means all data is used for adaptation calculations.
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
        **kwargs: Arguments passed to Transform base class.
    """

    super().__init__(
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        # strategy_kwargs
        group_coord=group_coord,
        target_group=target_group,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        # transform kwargs
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self.sample_dim = sample_dim
    self.reg = reg  # Regularization parameter

SAAligner

SAAligner(sample_dim: str, n_components: int = 10, strategy_name: str = 'single_target', group_coord: str | None = None, target_group: str | int | float | None = None, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: SingleTargetAligner

Subspace Alignment (SA) for domain adaptation.

This method aligns the basis vectors (subspaces) of the source and target domains, learned via PCA. The transformation projects data onto the source subspace and then aligns it to the target subspace.

The method learns a transformation matrix that aligns PCA subspaces: transform_matrix_ = (source_basis @ target_basis)^T
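A sketch of the textbook subspace-alignment recipe, assuming PCA bases stored as `(n_features, n_components)` matrices with orthonormal columns; under that layout the alignment matrix is written `Ps.T @ Pt`. The library may store bases in a transposed layout, which changes how the product in the formula above is written, and how xdflow maps results back to the original feature dimension is not shown. The helper names are hypothetical.

```python
import numpy as np


def pca_basis(X, n_components):
    """Top principal directions of X as columns of an (n_features, n_components) matrix."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Vt[:n_components].T


def subspace_alignment(X_source, X_target, n_components):
    """Return source and target data in the aligned k-dimensional subspace, plus M."""
    Ps = pca_basis(X_source, n_components)
    Pt = pca_basis(X_target, n_components)
    M = Ps.T @ Pt  # (k, k) matrix aligning the source basis to the target basis
    # Project the source onto its own subspace, then align it to the target subspace.
    Zs = (X_source - X_source.mean(axis=0)) @ Ps @ M
    # Project the target onto its own subspace.
    Zt = (X_target - X_target.mean(axis=0)) @ Pt
    return Zs, Zt, M


rng = np.random.default_rng(0)
X_source = rng.normal(size=(200, 6))
X_target = rng.normal(size=(150, 6)) @ rng.normal(size=(6, 6))
Zs, Zt, M = subspace_alignment(X_source, X_target, n_components=3)
```

A quick sanity check: when source and target are the same data, both bases coincide and `M` reduces to the identity.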

Initialize SAAligner.

Parameters:

Name Type Description Default
sample_dim str

The dimension to average over when calculating class means.

required
n_components int

Number of principal components to use for alignment

10
strategy_name str

The adaptation strategy to use. Currently only 'single_target' is supported.

'single_target'
group_coord str | None

Group coordinate to adapt to.

None
target_group str | int | float | None

Target group to adapt to.

None
n_jobs int

Number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for adaptation. None means all data is used for adaptation calculations.

None
sel dict[str, Any] | None

Dictionary of coordinates to select.

None
drop_sel dict[str, Any] | None

Dictionary of coordinates to drop.

None
**kwargs

Arguments passed to Transform base class.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    sample_dim: str,
    n_components: int = 10,
    strategy_name: str = "single_target",
    group_coord: str | None = None,
    target_group: str | int | float | None = None,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize SAAligner.

    Args:
        sample_dim: The dimension to average over when calculating class means.
        n_components: Number of principal components to use for alignment
        strategy_name: The adaptation strategy to use. Currently only 'single_target' is supported.
        group_coord: Group coordinate to adapt to.
        target_group: Target group to adapt to.
        n_jobs: Number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for adaptation. None means all data is used for adaptation calculations.
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
        **kwargs: Arguments passed to Transform base class.
    """

    super().__init__(
        sample_dim=sample_dim,
        strategy_name=strategy_name,
        # strategy_kwargs
        group_coord=group_coord,
        target_group=target_group,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        # transform kwargs
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )
    self.n_components = n_components

CCAAligner

CCAAligner(target_coord: str, sample_dim: str, output_dim_name: str = 'component', n_components: int = 10, strategy_name: str = 'joint', group_coord: str | None = None, n_jobs: int = 1, sampling_method: str = 'mean', random_state: int = 0, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: JointGroupAligner, SamplingMixin

Canonical Correlation Analysis (CCA) for domain adaptation.

This method finds a common latent space for two domains by learning a linear transformation matrix that maximizes the correlation between the two domains.
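A minimal NumPy sketch of linear CCA via whitening and an SVD, assuming the rows of the two arrays are already paired sample by sample (for example after class-mean or min-count sampling, which is presumably the role of `sampling_method`). The function names and the small ridge `reg` are illustrative assumptions, not xdflow's API.

```python
import numpy as np


def inv_sqrt(C):
    """Inverse square root of a symmetric positive-definite matrix."""
    w, V = np.linalg.eigh(C)
    return (V / np.sqrt(w)) @ V.T


def linear_cca(X, Y, n_components, reg=1e-6):
    """Return projection weights for X and Y and the canonical correlations."""
    n = X.shape[0]
    Xc, Yc = X - X.mean(axis=0), Y - Y.mean(axis=0)
    Cxx = Xc.T @ Xc / (n - 1) + reg * np.eye(X.shape[1])
    Cyy = Yc.T @ Yc / (n - 1) + reg * np.eye(Y.shape[1])
    Cxy = Xc.T @ Yc / (n - 1)
    Kx, Ky = inv_sqrt(Cxx), inv_sqrt(Cyy)
    # Singular values of the whitened cross-covariance are the canonical correlations.
    U, S, Vt = np.linalg.svd(Kx @ Cxy @ Ky)
    Wx = Kx @ U[:, :n_components]
    Wy = Ky @ Vt[:n_components].T
    return Wx, Wy, S[:n_components]


rng = np.random.default_rng(0)
z = rng.normal(size=(300, 1))  # latent signal shared by both domains
X = z @ rng.normal(size=(1, 4)) + 0.1 * rng.normal(size=(300, 4))
Y = z @ rng.normal(size=(1, 3)) + 0.1 * rng.normal(size=(300, 3))
Wx, Wy, corrs = linear_cca(X, Y, n_components=2)
```

Projecting each centered domain with its own weights (`Xc @ Wx`, `Yc @ Wy`) places both in the shared `n_components`-dimensional space.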

Initialize the CCAAligner.

Parameters:

Name Type Description Default
target_coord str

The coordinate to use for the target domain.

required
sample_dim str

The dimension to use for the sample.

required
output_dim_name str

The name of the output dimension/space that all groups are aligned to.

'component'
n_components int

The number of components to learn.

10
strategy_name str

The strategy to use for adaptation.

'joint'
group_coord str | None

The coordinate to group by. Determines different domains.

None
n_jobs int

The number of jobs to use for parallel processing.

1
sampling_method str

The method to use for sampling the data for alignment.

'mean'
random_state int

The random state to use for sampling the data.

0
adapt_sel dict[str, Any] | None

The selection criteria for adaptation. None means all data is used for adaptation calculations.

None
sel dict[str, Any] | None

Dictionary of coordinates to select.

None
drop_sel dict[str, Any] | None

Dictionary of coordinates to drop.

None
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    target_coord: str,
    sample_dim: str,
    output_dim_name: str = "component",
    n_components: int = 10,
    strategy_name: str = "joint",
    group_coord: str | None = None,
    n_jobs: int = 1,
    sampling_method: str = "mean",
    random_state: int = 0,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize the CCAAligner.

    Args:
        target_coord: The coordinate to use for the target domain.
        sample_dim: The dimension to use for the sample.
        output_dim_name: The name of the output dimension/space that all groups are aligned to.
        n_components: The number of components to learn.
        strategy_name: The strategy to use for adaptation.
        group_coord: The coordinate to group by. Determines different domains.
        n_jobs: The number of jobs to use for parallel processing.
        sampling_method: The method to use for sampling the data for alignment.
        random_state: The random state to use for sampling the data.
        adapt_sel: The selection criteria for adaptation. None means all data is used for adaptation calculations.
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
    """

    super().__init__(
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        output_dim_name=output_dim_name,
        group_coord=group_coord,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self.n_components = n_components if n_components is not None else np.inf
    self.target_coord = target_coord
    self.sampling_method = sampling_method
    self.random_state = random_state

    assert self.sampling_method in ["mean", "min_count"], "Sampling method must be either 'mean' or 'min_count'."

MCCAAligner

MCCAAligner(target_coord: str, sample_dim: str, output_dim_name: str = 'component', n_components: int = 10, reg: float = 1e-05, strategy_name: str = 'joint', sampling_method: str = 'mean', random_state: int = 0, group_coord: str | None = None, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: JointGroupAligner, SamplingMixin

Multiset Canonical Correlation Analysis (MCCA) for domain adaptation.

This method finds a common latent space for multiple domains (groups) by finding projections that maximize the total correlation between the domains. It solves a generalized eigenvalue problem to find the canonical components.
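One common way to pose MCCA as a generalized eigenvalue problem is `A w = lambda D w`, where `A` is the covariance of all groups stacked side by side and `D` keeps only its within-group diagonal blocks. The sketch below solves it in NumPy by symmetric whitening with `D^(-1/2)`; it assumes sample-paired groups and is not xdflow's implementation (the function name and normalization are assumptions).

```python
import numpy as np


def inv_sqrt(C):
    """Inverse square root of a symmetric positive-definite matrix."""
    w, V = np.linalg.eigh(C)
    return (V / np.sqrt(w)) @ V.T


def mcca(datasets, n_components, reg=1e-5):
    """Solve A w = lam D w; return per-group weights and the top eigenvalues."""
    centered = [X - X.mean(axis=0) for X in datasets]
    sizes = [X.shape[1] for X in centered]
    Z = np.hstack(centered)  # (n_samples, total_features)
    A = Z.T @ Z / (Z.shape[0] - 1) + reg * np.eye(Z.shape[1])
    D = np.zeros_like(A)  # block-diagonal: within-group covariances only
    start = 0
    for size in sizes:
        block = slice(start, start + size)
        D[block, block] = A[block, block]
        start += size
    D_inv_sqrt = inv_sqrt(D)
    lam, V = np.linalg.eigh(D_inv_sqrt @ A @ D_inv_sqrt)
    order = np.argsort(lam)[::-1][:n_components]
    W = D_inv_sqrt @ V[:, order]
    # Split the stacked weights into one projection matrix per group.
    return np.split(W, np.cumsum(sizes)[:-1], axis=0), lam[order]


rng = np.random.default_rng(0)
z = rng.normal(size=(300, 1))  # shared latent signal
groups = [z @ rng.normal(size=(1, d)) + 0.1 * rng.normal(size=(300, d)) for d in (4, 3, 5)]
weights, eigvals = mcca(groups, n_components=2)
```

With K groups the eigenvalues are bounded above by K, and the top one approaches K when all groups share a strong common signal, which gives a simple diagnostic for how much correlation the groups share.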

Initialize MCCAAligner.

Parameters:

Name Type Description Default
target_coord str

Target coordinate to adapt to.

required
sample_dim str

The dimension to average over when calculating class means.

required
output_dim_name str

The name of the output dimension/space that all groups are aligned to.

'component'
n_components int

Number of canonical components to keep.

10
reg float

Regularization parameter for covariance matrices.

1e-05
strategy_name str

The adaptation strategy to use. Currently only 'joint' is supported.

'joint'
group_coord str | None

Group coordinate to adapt to.

None
n_jobs int

Number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for data used for adaptation calculations. None means all data is used.

None
sel dict[str, Any] | None

Dictionary of coordinates to select.

None
drop_sel dict[str, Any] | None

Dictionary of coordinates to drop.

None
**kwargs

Arguments passed to Transform base class.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    target_coord: str,
    sample_dim: str,
    output_dim_name: str = "component",
    n_components: int = 10,
    reg: float = 1e-5,
    strategy_name: str = "joint",
    sampling_method: str = "mean",
    random_state: int = 0,
    group_coord: str | None = None,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize MCCAAligner.

    Args:
        target_coord: Target coordinate to adapt to.
        sample_dim: The dimension to average over when calculating class means.
        output_dim_name: The name of the output dimension/space that all groups are aligned to.
        n_components: Number of canonical components to keep.
        reg: Regularization parameter for covariance matrices.
        strategy_name: The adaptation strategy to use. Currently only 'joint' is supported.
        group_coord: Group coordinate to adapt to.
        n_jobs: Number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for data used for adaptation calculations. None means all data is used.
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
        **kwargs: Arguments passed to Transform base class.
    """

    super().__init__(
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        output_dim_name=output_dim_name,
        group_coord=group_coord,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self.n_components = n_components
    self.target_coord = target_coord
    self.reg = reg
    self.sampling_method = sampling_method
    self.random_state = random_state

    assert self.sampling_method in ["mean", "min_count"], "Sampling method must be either 'mean' or 'min_count'."

GCCAAligner

GCCAAligner(target_coord: str, sample_dim: str, output_dim_name: str = 'component', n_components: int = 10, reg: float = 1e-05, strategy_name: str = 'joint', sampling_method: str = 'mean', random_state: int = 0, group_coord: str | None = None, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: JointGroupAligner, SamplingMixin

Generalized Canonical Correlation Analysis (GCCA) for domain adaptation.

This method finds a common latent space for multiple domains (groups) by finding projections that maximize the agreement between the domains. It is based on finding a common subspace that is predictable from all domains. This implementation solves the GCCA problem by finding the leading eigenvectors of the sum of projection matrices onto each domain's space.
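The eigenvector recipe described above (often called MAXVAR GCCA) can be sketched in NumPy, assuming sample-paired groups and a small ridge `reg` inside each projection. The function name and return layout are hypothetical; xdflow's implementation may center, regularize, or normalize differently.

```python
import numpy as np


def gcca_maxvar(datasets, n_components, reg=1e-5):
    """Shared representation G from the top eigenvectors of sum_k P_k."""
    centered = [X - X.mean(axis=0) for X in datasets]
    n = centered[0].shape[0]
    M = np.zeros((n, n))
    solves = []
    for X in centered:
        # (X^T X + reg I)^{-1} X^T, reused for the projection and for the weights.
        S = np.linalg.solve(X.T @ X + reg * np.eye(X.shape[1]), X.T)
        M += X @ S  # regularized projection onto the column space of X
        solves.append(S)
    lam, V = np.linalg.eigh(M)
    order = np.argsort(lam)[::-1][:n_components]
    G = V[:, order]  # (n_samples, n_components), orthonormal columns
    # Per-group ridge least-squares weights W_k that best reconstruct G from X_k.
    weights = [S @ G for S in solves]
    return G, weights, lam[order]


rng = np.random.default_rng(0)
z = rng.normal(size=(200, 1))  # shared latent signal
groups = [z @ rng.normal(size=(1, d)) + 0.1 * rng.normal(size=(200, d)) for d in (4, 3, 5)]
G, weights, eigvals = gcca_maxvar(groups, n_components=2)
```

Each regularized projection has eigenvalues below 1, so the top eigenvalue of the sum is at most the number of groups; new data from group k would be mapped with `X_k_centered @ weights[k]`.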

Initialize GCCAAligner.

Parameters:

Name Type Description Default
target_coord str

Target coordinate to adapt to.

required
sample_dim str

The dimension to average over when calculating class means.

required
output_dim_name str

The name of the output dimension/space that all groups are aligned to.

'component'
n_components int

Number of canonical components to keep.

10
reg float

Regularization parameter for covariance matrices.

1e-05
strategy_name str

The adaptation strategy to use. Currently only 'joint' is supported.

'joint'
group_coord str | None

Group coordinate to adapt to.

None
n_jobs int

Number of jobs to use for parallel processing.

1
adapt_sel dict[str, Any] | None

The selection criteria for data used for adaptation calculations. None means all data is used.

None
sel dict[str, Any] | None

Dictionary of coordinates to select.

None
drop_sel dict[str, Any] | None

Dictionary of coordinates to drop.

None
**kwargs

Arguments passed to Transform base class.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    target_coord: str,
    sample_dim: str,
    output_dim_name: str = "component",
    n_components: int = 10,
    reg: float = 1e-5,
    strategy_name: str = "joint",
    sampling_method: str = "mean",
    random_state: int = 0,
    group_coord: str | None = None,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize GCCAAligner.

    Args:
        target_coord: Target coordinate to adapt to.
        sample_dim: The dimension to average over when calculating class means.
        output_dim_name: The name of the output dimension/space that all groups are aligned to.
        n_components: Number of canonical components to keep.
        reg: Regularization parameter for covariance matrices.
        strategy_name: The adaptation strategy to use. Currently only 'joint' is supported.
        group_coord: Group coordinate to adapt to.
        n_jobs: Number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for data used for adaptation calculations. None means all data is used.
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
        **kwargs: Arguments passed to Transform base class.
    """

    super().__init__(
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        output_dim_name=output_dim_name,
        group_coord=group_coord,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self.n_components = n_components if n_components is not None else np.inf
    self.target_coord = target_coord
    self.reg = reg
    self.sampling_method = sampling_method
    self.random_state = random_state

    assert self.sampling_method in ["mean", "min_count"], "Sampling method must be either 'mean' or 'min_count'."

JDAAligner

JDAAligner(target_coord: str, sample_dim: str, output_dim_name: str = 'component', n_components: int = 50, kernel: str = 'linear', gamma: float = 1.0, mu: float = 0.5, reg: float = 0.001, strategy_name: str = 'joint', sampling_method: str = 'mean', random_state: int = 0, group_coord: str | None = None, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: JointGroupAligner, SamplingMixin

Joint Distribution Adaptation (JDA) for domain adaptation.

This method simultaneously adapts both marginal and conditional distributions between domains using Maximum Mean Discrepancy (MMD). JDA learns a transformation that projects all groups to a common subspace while minimizing distribution discrepancy and preserving discriminative information.

The method optimizes both:

1. Marginal distribution alignment: P(X_1) ≈ P(X_2) ≈ ... ≈ P(X_k)
2. Conditional distribution alignment: P(X_1|Y_1) ≈ P(X_2|Y_2) ≈ ... ≈ P(X_k|Y_k)
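JDA's two alignment terms are usually expressed through MMD coefficient matrices. The sketch below builds them for the two-group case with a linear kernel, combining them as `(1 - mu) * M0 + mu * sum_c Mc` so that `mu = 0` is purely marginal and `mu = 1` purely conditional, in line with the `mu` parameter description. That weighting form, the function names, and the use of true labels in both groups are assumptions (the original JDA iteratively estimates target labels with pseudo-labels); the eigenproblem that JDA then solves with these matrices is not shown.

```python
import numpy as np


def mmd_coefficients(mask_a, mask_b):
    """Vector e with 1/n_a on group a and -1/n_b on group b, so X.T @ e = mean_a - mean_b."""
    e = np.zeros(mask_a.shape[0])
    e[mask_a] = 1.0 / mask_a.sum()
    e[mask_b] = -1.0 / mask_b.sum()
    return e


def jda_mmd_matrix(groups, labels, mu):
    """Combined MMD matrix for two groups coded 0 (source) and 1 (target)."""
    src, tgt = groups == 0, groups == 1
    e0 = mmd_coefficients(src, tgt)
    M = (1 - mu) * np.outer(e0, e0)  # marginal term
    for c in np.unique(labels):
        src_c, tgt_c = src & (labels == c), tgt & (labels == c)
        if src_c.any() and tgt_c.any():  # skip classes absent from either group
            e_c = mmd_coefficients(src_c, tgt_c)
            M += mu * np.outer(e_c, e_c)  # conditional term for class c
    return M


rng = np.random.default_rng(0)
X = rng.normal(size=(10, 3))
groups = np.array([0] * 6 + [1] * 4)
labels = np.array([0, 1, 0, 1, 0, 1, 0, 1, 1, 0])
M = jda_mmd_matrix(groups, labels, mu=0.0)
# With a linear kernel, trace(X.T @ M0 @ X) is the squared distance between group means.
mmd = np.trace(X.T @ M @ X)
```

Setting `mu=1.0` instead gives the sum over classes of squared distances between per-class group means.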

Initialize JDAAligner.

Parameters:

Name Type Description Default
target_coord str

Coordinate name for class labels (e.g., 'stimulus').

required
sample_dim str

Dimension name for samples (e.g., 'sample').

required
output_dim_name str

Name for the output component dimension.

'component'
n_components int

Number of components for dimensionality reduction.

50
kernel str

Kernel type for MMD computation ('linear', 'rbf').

'linear'
gamma float

Kernel bandwidth parameter for RBF kernel.

1.0
mu float

Trade-off parameter between marginal (0) and conditional (1) adaptation.

0.5
reg float

Regularization parameter for numerical stability.

0.001
strategy_name str

Strategy to use ('joint' for JointGroupStrategy).

'joint'
sampling_method str

Method for sampling data ('mean' or 'min_count').

'mean'
random_state int

Random seed for reproducibility.

0
group_coord str | None

Coordinate name for groups (e.g., 'session').

None
n_jobs int

Number of parallel jobs.

1
adapt_sel dict[str, Any] | None

Selection criteria for adaptation data.

None
sel dict[str, Any] | None

General selection criteria.

None
drop_sel dict[str, Any] | None

Coordinates to drop.

None
**kwargs

Additional arguments for base class.

{}
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    target_coord: str,
    sample_dim: str,
    output_dim_name: str = "component",
    n_components: int = 50,
    kernel: str = "linear",
    gamma: float = 1.0,
    mu: float = 0.5,
    reg: float = 1e-3,
    strategy_name: str = "joint",
    sampling_method: str = "mean",
    random_state: int = 0,
    group_coord: str | None = None,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize JDAAligner.

    Args:
        target_coord: Coordinate name for class labels (e.g., 'stimulus').
        sample_dim: Dimension name for samples (e.g., 'sample').
        output_dim_name: Name for the output component dimension.
        n_components: Number of components for dimensionality reduction.
        kernel: Kernel type for MMD computation ('linear', 'rbf').
        gamma: Kernel bandwidth parameter for RBF kernel.
        mu: Trade-off parameter between marginal (0) and conditional (1) adaptation.
        reg: Regularization parameter for numerical stability.
        strategy_name: Strategy to use ('joint' for JointGroupStrategy).
        sampling_method: Method for sampling data ('mean' or 'min_count').
        random_state: Random seed for reproducibility.
        group_coord: Coordinate name for groups (e.g., 'session').
        n_jobs: Number of parallel jobs.
        adapt_sel: Selection criteria for adaptation data.
        sel: General selection criteria.
        drop_sel: Coordinates to drop.
        **kwargs: Additional arguments for base class.
    """
    super().__init__(
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        output_dim_name=output_dim_name,
        group_coord=group_coord,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self.target_coord = target_coord
    self.n_components = n_components
    self.kernel = kernel
    self.gamma = gamma
    self.mu = mu
    self.reg = reg
    self.sampling_method = sampling_method
    self.random_state = random_state

    # Validation
    assert 0 <= self.mu <= 1, "mu must be between 0 and 1"
    assert self.kernel in ["linear", "rbf"], "kernel must be 'linear' or 'rbf'"
    assert self.sampling_method in ["mean", "min_count"], "sampling_method must be 'mean' or 'min_count'"

KCCAAligner

KCCAAligner(target_coord: str, sample_dim: str, output_dim_name: str = 'component', n_components: int = 10, strategy_name: str = 'joint', group_coord: str | None = None, n_jobs: int = 1, sampling_method: str = 'mean', random_state: int = 0, kernel: str = 'linear', gamma: float | None = None, degree: int = 3, coef0: float = 1.0, reg: float = 1e-05, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

Bases: JointGroupAligner, SamplingMixin

Kernel Canonical Correlation Analysis (KCCA) for domain adaptation.

This method extends CCA to nonlinear relationships using the kernel trick. It finds a common latent space for two domains by learning nonlinear projections that maximize the correlation between the projected domains. This is a custom implementation of KCCA.
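The idea can be made concrete with a minimal regularized kernel CCA. This is a generic formulation that regularizes with (K + reg·I), not KCCAAligner's custom implementation; the names `center_gram` and `kcca_from_gram` are hypothetical. It takes precomputed Gram matrices so any kernel can be plugged in, and it assumes the two domains are already paired row by row (how KCCAAligner pairs samples, e.g. via `sampling_method`, is not shown here).

```python
import numpy as np


def center_gram(K: np.ndarray) -> np.ndarray:
    """Double-center a Gram matrix, i.e. center the data in feature space."""
    n = K.shape[0]
    H = np.eye(n) - np.full((n, n), 1.0 / n)
    return H @ K @ H


def kcca_from_gram(Kx, Ky, n_components=1, reg=1e-5):
    """Regularized kernel CCA on two paired Gram matrices.

    Maximizes a' Kx Ky b subject to ||(Kx + reg*I) a|| = ||(Ky + reg*I) b|| = 1.
    Substituting u = (Kx + reg*I) a and v = (Ky + reg*I) b turns this into the
    top singular vectors of (Kx + reg*I)^-1 Kx Ky (Ky + reg*I)^-1.
    """
    Kx, Ky = center_gram(Kx), center_gram(Ky)
    I = np.eye(Kx.shape[0])
    Ax = np.linalg.solve(Kx + reg * I, Kx)  # (Kx + reg*I)^-1 Kx
    Ay = np.linalg.solve(Ky + reg * I, Ky)  # equals Ky (Ky + reg*I)^-1 because the factors commute
    U, s, Vt = np.linalg.svd(Ax @ Ay)
    alpha = np.linalg.solve(Kx + reg * I, U[:, :n_components])  # dual weights for domain X
    beta = np.linalg.solve(Ky + reg * I, Vt[:n_components].T)   # dual weights for domain Y
    # Projected training samples in the shared latent space, plus regularized correlations.
    return Kx @ alpha, Ky @ beta, s[:n_components]


# Toy paired data: one latent signal z shared by both domains, plus unrelated noise features.
rng = np.random.default_rng(0)
z = rng.normal(size=(200, 1))
X = np.hstack([z, rng.normal(size=(200, 3))])
Y = np.hstack([rng.normal(size=(200, 2)), -2 * z + 0.01 * rng.normal(size=(200, 1))])
zx, zy, corr = kcca_from_gram(X @ X.T, Y @ Y.T)  # linear-kernel Gram matrices
print(corr, np.corrcoef(zx[:, 0], zy[:, 0])[0, 1])
```

With a nonlinear kernel and a very small `reg`, kernel CCA can reach correlations near 1 that reflect overfitting rather than shared structure, which is why the regularization term matters.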

Initialize the KCCAAligner.

Parameters:

target_coord (str, required): The coordinate to use for the target domain.
sample_dim (str, required): The sample dimension.
output_dim_name (str, default 'component'): The name of the output dimension, i.e. the shared space that all groups are aligned to.
n_components (int, default 10): The number of components to learn.
strategy_name (str, default 'joint'): The adaptation strategy to use.
group_coord (str | None, default None): The coordinate to group by; its values define the domains.
n_jobs (int, default 1): The number of parallel jobs.
sampling_method (str, default 'mean'): The method used to sample data for alignment; one of 'mean' or 'min_count'.
random_state (int, default 0): The random seed used when sampling data.
kernel (str, default 'linear'): The kernel used internally; one of 'linear', 'poly', 'rbf', 'sigmoid', 'cosine'.
gamma (float | None, default None): Kernel coefficient for the 'rbf', 'poly', and 'sigmoid' kernels.
degree (int, default 3): Degree of the 'poly' kernel.
coef0 (float, default 1.0): Independent term in the 'poly' and 'sigmoid' kernels.
reg (float, default 1e-05): Regularization parameter.
adapt_sel (dict[str, Any] | None, default None): Selection criteria for the data used in adaptation calculations; None uses all data.
sel (dict[str, Any] | None, default None): Dictionary of coordinates to select.
drop_sel (dict[str, Any] | None, default None): Dictionary of coordinates to drop.
Source code in xdflow/transforms/domain_adaptation.py
def __init__(
    self,
    target_coord: str,
    sample_dim: str,
    output_dim_name: str = "component",
    n_components: int = 10,
    strategy_name: str = "joint",
    group_coord: str | None = None,
    n_jobs: int = 1,
    sampling_method: str = "mean",
    random_state: int = 0,
    kernel: str = "linear",
    gamma: float | None = None,
    degree: int = 3,
    coef0: float = 1.0,
    reg: float = 1e-5,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize the KCCAAligner.

    Args:
        target_coord: The coordinate to use for the target domain.
        sample_dim: The dimension to use for the sample.
        output_dim_name: The name of the output dimension/space that all groups are aligned to.
        n_components: The number of components to learn.
        strategy_name: The strategy to use for adaptation.
        group_coord: The coordinate to group by. Determines different domains.
        n_jobs: The number of jobs to use for parallel processing.
        sampling_method: The method to use for sampling the data for alignment.
        random_state: The random state to use for sampling the data.
        kernel: Kernel mapping used internally. One of 'linear', 'poly', 'rbf', 'sigmoid', 'cosine'.
        gamma: Kernel coefficient for rbf, poly and sigmoid.
        degree: Degree for poly kernels.
        coef0: Independent term in poly and sigmoid kernels.
        reg: Regularization parameter.
        adapt_sel: The selection criteria for data used for adaptation calculations. None means all data is used.
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
    """

    super().__init__(
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        output_dim_name=output_dim_name,
        group_coord=group_coord,
        n_jobs=n_jobs,
        adapt_sel=adapt_sel,
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self.n_components = n_components if n_components is not None else np.inf
    self.target_coord = target_coord
    self.sampling_method = sampling_method
    self.random_state = random_state
    self.kernel = kernel
    self.gamma = gamma
    self.degree = degree
    self.coef0 = coef0
    self.reg = reg

    assert self.sampling_method in ["mean", "min_count"], "Sampling method must be either 'mean' or 'min_count'."
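The `kernel`, `gamma`, `degree`, and `coef0` parameters match the names used by scikit-learn's pairwise kernels. The NumPy sketch below writes out those formulas, including scikit-learn's default of `gamma = 1 / n_features`. Whether KCCAAligner computes its kernels exactly this way is an assumption, and `pairwise_kernel` is a hypothetical helper.

```python
import numpy as np


def pairwise_kernel(A, B, kernel="linear", gamma=None, degree=3, coef0=1.0):
    """Gram matrix between the rows of A and the rows of B."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    g = 1.0 / A.shape[1] if gamma is None else gamma  # scikit-learn-style default
    if kernel == "linear":
        return A @ B.T
    if kernel == "poly":
        return (g * (A @ B.T) + coef0) ** degree
    if kernel == "rbf":
        # Squared Euclidean distances, clipped at 0 to absorb rounding error.
        d2 = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2 * (A @ B.T)
        return np.exp(-g * np.maximum(d2, 0.0))
    if kernel == "sigmoid":
        return np.tanh(g * (A @ B.T) + coef0)
    if kernel == "cosine":
        An = A / np.linalg.norm(A, axis=1, keepdims=True)
        Bn = B / np.linalg.norm(B, axis=1, keepdims=True)
        return An @ Bn.T
    raise ValueError(f"unsupported kernel: {kernel!r}")


A = np.array([[1.0, 0.0], [0.0, 1.0]])
print(pairwise_kernel(A, A, "poly"))  # gamma defaults to 1/2 here
```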

AdaptWrapperStrategy

AdaptWrapperStrategy(group_coord: str, target_group: str | int | float, n_jobs: int = 1, adapt_sel: dict[str, Any] | None = None)

Bases: AdaptiveStrategy

An adaptation strategy for estimator classes from the adapt package, used by AdaptWrapperTransform. During adaptation, it fits on exactly one source domain and one target domain. During transformation, it applies the adapted model's source-domain or target-domain transform to each group. Domains are determined by group_coord.

Initialize the AdaptWrapperStrategy.

Parameters:

group_coord (str, required): The coordinate to group by. Determines different domains.
target_group (str | int | float, required): The target group to adapt to.
n_jobs (int, default 1): The number of jobs to use for parallel processing.
adapt_sel (dict[str, Any] | None, default None): The selection criteria for data used for adaptation calculations. None means all data is used.
Source code in xdflow/transforms/adapt_wrapper.py
def __init__(
    self,
    group_coord: str,
    target_group: str | int | float,
    n_jobs: int = 1,
    adapt_sel: dict[str, Any] | None = None,
):
    """
    Initialize the AdaptWrapperStrategy.

    Args:
        group_coord: The coordinate to group by. Determines different domains.
        target_group: The target group to adapt to.
        n_jobs: The number of jobs to use for parallel processing.
        adapt_sel: The selection criteria for data used for adaptation calculations. None means all data is used.
    """
    super().__init__(group_coord=group_coord, n_jobs=n_jobs, adapt_sel=adapt_sel)
    self.target_group = target_group
    self.target_params = {}
    self.adapted_params = {}
    self.group_dim = None
    self.seen_target_groups_ = []
    self.seen_groups_ = []
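Because the strategy fits one source domain and one target domain identified by `group_coord`, adaptation needs to split samples into exactly two groups. The pure-Python sketch below shows that bookkeeping. The function name is hypothetical, and the rule that the single non-target group becomes the source (the `source_group` that `transform` later reads) is an assumption, not something the excerpt confirms.

```python
def split_domains(groups, target_group):
    """Return (source_group, source_idx, target_idx) for one source and one target domain.

    groups: per-sample group labels, i.e. the values of group_coord.
    """
    unique = list(dict.fromkeys(groups))  # distinct labels in first-seen order
    if target_group not in unique:
        raise ValueError(f"target_group {target_group!r} not found in {unique}")
    others = [g for g in unique if g != target_group]
    if len(others) != 1:
        raise ValueError(f"expected exactly one source group, found {others}")
    source_group = others[0]
    source_idx = [i for i, g in enumerate(groups) if g == source_group]
    target_idx = [i for i, g in enumerate(groups) if g == target_group]
    return source_group, source_idx, target_idx


print(split_domains(["s1", "s2", "s1", "s2"], target_group="s2"))  # prints ('s1', [0, 2], [1, 3])
```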

transform

transform(aligner: AdaptWrapperTransform, container: DataContainer, **kwargs) -> DataContainer

Transforms data by applying the appropriate source or adapted model to each group.

Parameters:

aligner (AdaptWrapperTransform, required): The AdaptWrapperTransform instance using this strategy.
container (DataContainer, required): The DataContainer to be transformed.
**kwargs (default {}): Additional arguments.

Returns:

DataContainer: The transformed DataContainer.

Source code in xdflow/transforms/adapt_wrapper.py
def transform(self, aligner: "AdaptWrapperTransform", container: DataContainer, **kwargs) -> DataContainer:
    """
    Transforms data by applying the appropriate source or adapted model to each group.

    Args:
        aligner: The AdaptWrapperTransform instance using this strategy.
        container: The DataContainer to be transformed.
        **kwargs: Additional arguments.

    Returns:
        The transformed DataContainer.
    """
    current_groups = self._discover_groups(container)

    def transform_group(group_val):
        group_container = self._select_group(container, group_val)
        if group_val == self.source_group:
            return aligner._adapted_transform(group_container, domain="source", **kwargs)
        elif group_val == self.target_group:
            return aligner._adapted_transform(group_container, domain="target", **kwargs)
        else:  # Unseen target group
            raise TransformError(
                f"Group '{group_val}' was not seen during 'adapt'. Seen groups: {self.seen_groups_}"
            )

    group_outputs = []
    if self.n_jobs != 1:
        transformed_containers = Parallel(n_jobs=self.n_jobs)(
            delayed(transform_group)(group_val) for group_val in current_groups
        )
        group_outputs.extend([output.data for output in transformed_containers])
    else:
        for group_val in current_groups:
            transformed_container = transform_group(group_val)
            group_outputs.append(transformed_container.data)

    # Reassemble outputs #TODO: do we need to reassemble in the same order?
    reassembled = xr.concat(group_outputs, dim=self.group_dim)

    # Note: Reordering to match original is complex and may not be necessary.
    # If order is critical, it should be handled carefully.

    return DataContainer(reassembled)
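The source notes that group outputs are concatenated group by group and leaves restoring the original sample order open. One way to do it, sketched on plain NumPy arrays rather than DataContainer/xarray objects, is to record each group's row indices and scatter the results back. The function name and the dict-of-callables interface are hypothetical; each callable stands in for `aligner._adapted_transform` with `domain="source"` or `domain="target"`.

```python
import numpy as np


def route_by_group(X, groups, transforms):
    """Apply transforms[g] to the rows of X labeled g, preserving the original row order.

    X: (n_samples, n_features) array; groups: length-n sequence of labels;
    transforms: dict mapping each group label to a callable on a 2-D array.
    Assumes every callable returns the same number of columns.
    """
    groups = np.asarray(groups)
    unseen = set(groups.tolist()) - set(transforms)
    if unseen:
        raise KeyError(f"groups not seen during adapt: {unseen}")
    pieces, order = [], []
    for g in dict.fromkeys(groups.tolist()):  # each group once, in first-seen order
        idx = np.flatnonzero(groups == g)
        pieces.append(np.asarray(transforms[g](X[idx])))
        order.append(idx)
    stacked = np.concatenate(pieces, axis=0)  # group-by-group order, like xr.concat
    perm = np.concatenate(order)              # original row index of each stacked row
    out = np.empty_like(stacked)
    out[perm] = stacked                       # scatter rows back to their original positions
    return out


X = np.arange(8.0).reshape(4, 2)
print(route_by_group(X, ["a", "b", "a", "b"], {"a": lambda A: 10 * A, "b": lambda A: -A}))
```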

AdaptWrapperTransform

AdaptWrapperTransform(adapt_estimator_cls, sample_dim: str, target_coord: str, group_coord: str, target_group: str | int | float, random_state: int = 0, adapt_sel: dict[str, Any] | None = None, sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, **kwargs)

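Bases: AdaptiveTransform

Wraps an estimator class from the adapt package as an adaptive transform. It always uses the "adapt_wrapper" strategy, and keyword arguments that are not consumed by the transform or strategy are passed to adapt_estimator_cls when the estimator is constructed.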

Source code in xdflow/transforms/adapt_wrapper.py
def __init__(
    self,
    adapt_estimator_cls,
    sample_dim: str,
    target_coord: str,
    group_coord: str,
    target_group: str | int | float,
    random_state: int = 0,
    adapt_sel: dict[str, Any] | None = None,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    **kwargs,
):
    strategy_name = "adapt_wrapper"
    super().__init__(
        # base adaptive transform kwargs
        strategy_name=strategy_name,
        sample_dim=sample_dim,
        # strategy kwargs
        group_coord=group_coord,
        target_group=target_group,
        adapt_sel=adapt_sel,
        # transform kwargs
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self.target_coord = target_coord
    self.adapt_estimator_cls = adapt_estimator_cls  # needed for clone
    self.random_state = random_state

    # Extract estimator-specific parameters (everything not used by Transform or Strategy)
    self._adapt_estimator_kwargs = {k: v for k, v in kwargs.items() if k not in NON_ADAPT_ESTIMATOR_KWARGS}
    self.adapt_estimator = adapt_estimator_cls(**self._adapt_estimator_kwargs)

    # check if adapt estimator has "domain" parameter in fit method
    if "domain" in inspect.signature(self.adapt_estimator.fit).parameters:
        self._specifies_transform_domain = True
    else:
        self._specifies_transform_domain = False
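The constructor forwards every keyword that the transform and strategy do not consume to the adapt estimator, then inspects the estimator's `fit` signature for a `domain` parameter. The stdlib-only sketch below reproduces those two steps. `ToyEstimator`, `ToyEstimatorNoDomain`, `build_estimator`, and the excluded-key set are hypothetical stand-ins, since the contents of `NON_ADAPT_ESTIMATOR_KWARGS` are not listed in this excerpt.

```python
import inspect

# Hypothetical stand-in for NON_ADAPT_ESTIMATOR_KWARGS (the real set is not shown here).
NON_ESTIMATOR_KWARGS = {"strategy_name", "sample_dim", "group_coord", "target_group", "adapt_sel", "sel", "drop_sel"}


class ToyEstimator:
    """Hypothetical estimator whose fit accepts a `domain` argument."""

    def __init__(self, alpha: float = 1.0):
        self.alpha = alpha

    def fit(self, X, y=None, domain=None):
        return self


class ToyEstimatorNoDomain:
    """Hypothetical estimator whose fit has no `domain` argument."""

    def fit(self, X, y=None):
        return self


def build_estimator(estimator_cls, **kwargs):
    """Forward unclaimed kwargs to the estimator and report whether fit takes `domain`."""
    est_kwargs = {k: v for k, v in kwargs.items() if k not in NON_ESTIMATOR_KWARGS}
    estimator = estimator_cls(**est_kwargs)
    # inspect.signature on a bound method omits `self`, so only real fit arguments are checked.
    accepts_domain = "domain" in inspect.signature(estimator.fit).parameters
    return estimator, accepts_domain


est, takes_domain = build_estimator(ToyEstimator, alpha=0.5, sel={"session": 1})
print(est.alpha, takes_domain)  # prints 0.5 True
```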

Optional Spectral Module

xdflow.transforms.spectral depends on the optional spectral-connectivity package. Install xdflow[spectral] or xdflow[all] before importing these transforms:

  • MultiTaperTransform
  • BandpassFilterTransform

The published docs do not import this module during the standard docs build, so Read the Docs does not need the optional spectral dependencies.
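Code that should run with or without the spectral extras can guard the import. A minimal sketch, assuming a missing xdflow[spectral] install surfaces as an `ImportError` (which includes `ModuleNotFoundError`); the `HAS_SPECTRAL` flag is a hypothetical name, not part of xdflow.

```python
try:
    # Requires the optional spectral-connectivity dependency (pip install "xdflow[spectral]").
    from xdflow.transforms.spectral import BandpassFilterTransform, MultiTaperTransform

    HAS_SPECTRAL = True
except ImportError:
    # Fall back to placeholders so callers can check HAS_SPECTRAL before building a pipeline.
    BandpassFilterTransform = MultiTaperTransform = None
    HAS_SPECTRAL = False

print("spectral transforms available:", HAS_SPECTRAL)
```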