
Composition API

Composition APIs define how transforms are combined while keeping the pipeline visible to validators and tuners. Use them for sequential pipelines, branching, per-group fitting, optional steps, and ensembles when those choices should participate in split, refit, and cache planning.

Base Composition Types

TransformStep dataclass

TransformStep(name: str, transform: Transform)

Represents a named step in a processing pipeline.

Attributes:

name (str): The name of the step.
transform (Transform): The transform object to be executed in this step.
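A minimal sketch of how `(name, transform)` pairs can be promoted to step objects, mirroring the tuple conversion that `Pipeline.__init__` performs. The dataclass below is a stand-in with the same two fields, not the xdflow import, and `normalize_steps` is a hypothetical helper name.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class TransformStep:
    """Stand-in for xdflow's TransformStep (same two fields)."""

    name: str
    transform: Any  # an xdflow Transform in the real library


def normalize_steps(steps):
    """Hypothetical helper: accept (name, transform) pairs or TransformStep objects."""
    # Like the Pipeline source, only the first element is inspected.
    if steps and not isinstance(steps[0], TransformStep):
        steps = [TransformStep(name, transform) for name, transform in steps]
    return list(steps)


steps = normalize_steps([("scale", "scaler-object"), ("model", "model-object")])
print([s.name for s in steps])  # ['scale', 'model']
```

Because only `steps[0]` is checked, a list that mixes tuples and `TransformStep` objects is not supported; pass one form consistently.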

CompositeTransform

CompositeTransform(sel: dict[str, Any] | None = None, drop_sel: dict[str, Any] | None = None, transform_sel: dict | None = None, transform_drop_sel: dict | None = None, **kwargs)

Bases: Transform, ABC

Abstract base class for transforms that are compositions of other transforms.

This class provides common functionality for orchestrators like Pipeline and PipelineUnion, such as dynamically determining statefulness based on its constituent children.

Cloning semantics

  • CompositeTransform.clone() performs a constructor-filtered recursive clone: it reconstructs a new instance using only the parameters present in the subclass's __init__ signature and, for any child Transform(s), calls child.clone().
  • "Recursive" means the clone walks the transform hierarchy but does not copy fitted state. Each child must keep fitted state out of its __init__ parameters so that the cloned composite is unfitted.
  • Subclasses must set child collections (e.g., self.steps) before calling super().__init__() so that is_stateful can be computed from the children.

Initializes the CompositeTransform.

The is_stateful attribute is automatically determined by inspecting the children defined in the concrete subclass. This requires that child collections (e.g., self.steps) are initialized in the subclass's __init__ method before calling super().__init__().

Parameters:

sel (dict[str, Any] | None): Dictionary of coordinates to select. Default: None.
drop_sel (dict[str, Any] | None): Dictionary of coordinates to drop. Default: None.
transform_sel (dict | None): Dictionary of coordinates to select for fitting/transforming. Default: None.
transform_drop_sel (dict | None): Dictionary of coordinates to drop for fitting/transforming. Default: None.
**kwargs: Additional keyword arguments. Default: {}.
Source code in xdflow/composite/base.py
def __init__(
    self,
    sel: dict[str, Any] | None = None,
    drop_sel: dict[str, Any] | None = None,
    transform_sel: dict | None = None,
    transform_drop_sel: dict | None = None,
    **kwargs,
):
    """
    Initializes the CompositeTransform.

    The `is_stateful` attribute is automatically determined by inspecting
    the children defined in the concrete subclass. This requires that child
    collections (e.g., `self.steps`) are initialized in the subclass's
    `__init__` method *before* calling `super().__init__()`.

    Args:
        sel: Dictionary of coordinates to select.
        drop_sel: Dictionary of coordinates to drop.
        transform_sel: Dictionary of coordinates to select for fitting/transforming.
        transform_drop_sel: Dictionary of coordinates to drop for fitting/transforming.
        **kwargs: Additional keyword arguments.
    """
    # Note: This super().__init__() call must come *after* the subclass
    # has defined the collection that self.children will point to.
    super().__init__(
        sel=sel, drop_sel=drop_sel, transform_sel=transform_sel, transform_drop_sel=transform_drop_sel, **kwargs
    )
    self.is_stateful = builtins.any(child.is_stateful for child in self.children)
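The ordering requirement can be seen in a stripped-down sketch. The classes below are hypothetical stand-ins (not xdflow types) that keep only the `is_stateful` computation; `BrokenSeq` calls `super().__init__()` before its collection exists and fails because `children` cannot be resolved yet.

```python
from abc import ABC, abstractmethod


class Leaf:
    """Hypothetical stand-in for a Transform with an is_stateful flag."""

    def __init__(self, is_stateful: bool):
        self.is_stateful = is_stateful


class Composite(ABC):
    """Simplified stand-in for CompositeTransform's initialization order."""

    def __init__(self):
        # Reads self.children, so subclasses must set their collections first.
        self.is_stateful = any(child.is_stateful for child in self.children)

    @property
    @abstractmethod
    def children(self):
        ...


class Seq(Composite):
    def __init__(self, steps):
        self.steps = steps  # must exist before super().__init__()
        super().__init__()

    @property
    def children(self):
        return list(self.steps)


class BrokenSeq(Composite):
    def __init__(self, steps):
        super().__init__()  # too early: self.steps is not set yet
        self.steps = steps

    @property
    def children(self):
        return list(self.steps)


print(Seq([Leaf(False), Leaf(True)]).is_stateful)   # True
print(Seq([Leaf(False), Leaf(False)]).is_stateful)  # False
try:
    BrokenSeq([Leaf(True)])
except AttributeError as exc:
    print("ordering error:", exc)
```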

children abstractmethod property

children: Iterable[Transform]

An abstract property that must be implemented by subclasses.

Returns:

Iterable[Transform]: An iterable collection of the child Transform objects contained within this composite.

is_predictor abstractmethod property

is_predictor: bool

Returns True if the transform performs prediction.

predictive_transform property

predictive_transform: Transform | None

Returns the predictive transform if it exists, otherwise None. Must be implemented by subclasses if the subclass is a predictor.

predict

predict(container: DataContainer, **kwargs) -> DataContainer

Predicts on data. Must be implemented by subclasses if the subclass is a predictor but does not inherit from Predictor.

Source code in xdflow/composite/base.py
def predict(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Predicts on data.
    Must be implemented by subclasses if the subclass is a predictor but does not inherit from Predictor.
    """
    if isinstance(self, Predictor):
        return super().predict(container, **kwargs)
    if self.is_predictor:
        raise NotImplementedError("Child transform must implement predict.")
    else:
        raise ValueError("CompositeTransform is not a predictor.")

predict_proba

predict_proba(container: DataContainer, **kwargs) -> DataContainer

Predicts the probabilities on data. Must be implemented by subclasses if the subclass is a predictor but does not inherit from Predictor.

Source code in xdflow/composite/base.py
def predict_proba(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Predicts the probabilities on data.
    Must be implemented by subclasses if the subclass is a predictor but does not inherit from Predictor.
    """
    if isinstance(self, Predictor):
        return super().predict_proba(container, **kwargs)
    if self.is_predictor:
        raise NotImplementedError("Child transform must implement predict_proba.")
    else:
        raise ValueError("CompositeTransform is not a predictor.")

set_params

set_params(**params: Any) -> CompositeTransform

Set the parameters of this transform and its children.

This method supports nested parameter setting using the __ separator, similar to scikit-learn's Pipeline.

Source code in xdflow/composite/base.py
def set_params(self, **params: Any) -> "CompositeTransform":
    """
    Set the parameters of this transform and its children.

    This method supports nested parameter setting using the `__` separator,
    similar to scikit-learn's Pipeline.
    """
    # Separate parameters for this transform and its children
    self_params = {}
    nested_params = {}
    for key, value in params.items():
        if "__" in key:
            # Key is for a nested transform
            step_name, param_name = key.split("__", 1)
            if step_name not in nested_params:
                nested_params[step_name] = {}
            nested_params[step_name][param_name] = value
        else:
            # Key is for this transform itself
            self_params[key] = value

    # Set parameters on this transform
    super().set_params(**self_params)

    # Set parameters on the children
    for step_name, params_to_set in nested_params.items():
        child = self.get_transform_from_name(step_name)
        child.set_params(**params_to_set)

    return self
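A minimal sketch of the `__` routing, using hypothetical classes rather than xdflow transforms. A key such as `inner__ridge__alpha` is split only at the first `__`, so each composite peels off one level and forwards the remainder to the named child. Here the leaf `set_params` is a plain `setattr`, and the composite sets its own parameters with `setattr` where the real code calls `super().set_params`.

```python
class Step:
    """Hypothetical leaf transform with a sklearn-style set_params."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self


class Composite:
    """Hypothetical composite that routes nested keys to named children."""

    def __init__(self, name, steps):
        self.name = name
        self.transform_from_name = dict(steps)

    def set_params(self, **params):
        self_params, nested = {}, {}
        for key, value in params.items():
            if "__" in key:
                # Split once: the remainder may itself contain "__".
                step_name, param_name = key.split("__", 1)
                nested.setdefault(step_name, {})[param_name] = value
            else:
                self_params[key] = value
        for key, value in self_params.items():
            setattr(self, key, value)
        for step_name, child_params in nested.items():
            self.transform_from_name[step_name].set_params(**child_params)
        return self


inner = Composite("inner", [("ridge", Step())])
outer = Composite("outer", [("pre", Step()), ("inner", inner)])
outer.set_params(name="renamed", pre__alpha=0.5, inner__ridge__alpha=2.0)
print(outer.name, outer.transform_from_name["pre"].alpha, inner.transform_from_name["ridge"].alpha)
# renamed 0.5 2.0
```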

clone

clone() -> Self

Return a fresh unfitted instance by recursively cloning constructor-filtered params.

This default implementation mirrors Transform.clone but recursively clones any values that are Transforms (or collections of them), ensuring children are cloned without copying fitted state. Only public constructor parameters are passed to the new instance.

Returns:

Name Type Description
Self Self

A new, unfitted instance with cloned child transforms.

Source code in xdflow/composite/base.py
def clone(self) -> Self:
    """Return a fresh unfitted instance by recursively cloning constructor-filtered params.

    This default implementation mirrors `Transform.clone` but recursively clones any values
    that are `Transform`s (or collections of them), ensuring children are cloned
    without copying fitted state. Only public constructor parameters are passed
    to the new instance.

    Returns:
        Self: A new, unfitted instance with cloned child transforms.
    """

    def _recursive_clone_constructor_value(value):
        """Recursively clone values that are Transforms or collections thereof.

        - Transform: use `.clone()`
        - TransformStep-like: duck-type objects with `.name` and `.transform`
          where `.transform` is a Transform; reconstruct with the same type
          and a cloned child
        - list/tuple/dict: recurse elementwise, preserving container type
        - everything else: return as-is

        Args:
            value: Any parameter value.

        Returns:
            Any: Cloned value preserving structure where applicable.
        """
        if isinstance(value, Transform):
            return value.clone()

        # Duck-typed TransformStep-like object
        has_transform = hasattr(value, "transform") and hasattr(value, "name")
        if has_transform:
            child = value.transform
            name = value.name
            if isinstance(child, Transform):
                return type(value)(name, child.clone())

        if isinstance(value, tuple):
            # Special-case (name, Transform) tuples while preserving arbitrary metadata in name
            if len(value) == 2 and isinstance(value[1], Transform):
                return (value[0], value[1].clone())
            return tuple(_recursive_clone_constructor_value(v) for v in value)

        if isinstance(value, list):
            return [_recursive_clone_constructor_value(v) for v in value]

        if isinstance(value, dict):
            return {k: _recursive_clone_constructor_value(v) for k, v in value.items()}

        return value

    ctor = signature(type(self).__init__)
    ctor_param_names = {
        name
        for name, p in ctor.parameters.items()
        if name != "self" and p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    }
    raw_params = self.get_params(deep=False) or {}

    # Ensure all constructor parameters are present in raw params
    for ctor_param_name in ctor_param_names:
        assert ctor_param_name in raw_params, (
            f"Constructor parameter {ctor_param_name} not found as parameter in {self.__class__.__name__}."
        )

    filtered_params = {
        k: _recursive_clone_constructor_value(v) for k, v in raw_params.items() if k in ctor_param_names
    }

    return type(self)(**filtered_params)
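The cloning strategy can be sketched with hypothetical stand-in classes: constructor parameters are read with `inspect.signature`, child transforms (bare or inside `(name, transform)` pairs, lists, and dicts) are cloned recursively, and anything not accepted by `__init__`, such as fitted state, is left behind. The trailing-underscore convention for fitted attributes and the `BaseTransform` name are assumptions of this sketch, and the duck-typed `TransformStep` branch of the real helper is omitted for brevity.

```python
from inspect import Parameter, signature


class BaseTransform:
    """Hypothetical stand-in: constructor params are stored as same-named
    attributes; fitted state uses a trailing underscore."""

    def get_params(self):
        return {k: v for k, v in vars(self).items() if not k.endswith("_")}

    def clone(self):
        ctor = signature(type(self).__init__)
        names = {
            n
            for n, p in ctor.parameters.items()
            if n != "self" and p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
        }
        params = self.get_params()
        missing = names - set(params)
        assert not missing, f"Constructor parameters not stored: {missing}"
        return type(self)(**{k: _clone_value(v) for k, v in params.items() if k in names})


def _clone_value(value):
    # Recurse through transforms, (name, transform) pairs, lists and dicts.
    if isinstance(value, BaseTransform):
        return value.clone()
    if isinstance(value, tuple):
        if len(value) == 2 and isinstance(value[1], BaseTransform):
            return (value[0], value[1].clone())
        return tuple(_clone_value(v) for v in value)
    if isinstance(value, list):
        return [_clone_value(v) for v in value]
    if isinstance(value, dict):
        return {k: _clone_value(v) for k, v in value.items()}
    return value


class Center(BaseTransform):
    def __init__(self, scale=1.0):
        self.scale = scale
        self.mean_ = None  # fitted state, not a constructor parameter

    def fit(self, values):
        self.mean_ = sum(values) / len(values)
        return self


class Seq(BaseTransform):
    def __init__(self, name, steps):
        self.name = name
        self.steps = steps


pipe = Seq("demo", [("center", Center(scale=2.0).fit([1.0, 3.0]))])
copy = pipe.clone()
child, copied_child = pipe.steps[0][1], copy.steps[0][1]
print(copied_child is child, copied_child.scale, copied_child.mean_, child.mean_)
# False 2.0 None 2.0
```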

get_transform_from_name

get_transform_from_name(name: str) -> Transform

Returns the step with the given name.

Source code in xdflow/composite/base.py
def get_transform_from_name(self, name: str) -> Transform:
    """Returns the step with the given name."""
    return self.transform_from_name[name]

Pipelines

Pipeline

Pipeline(name: str, steps: list[tuple[str, Transform]] | list[TransformStep], expected_input_dims: dict[str, tuple[str, ...]] = None, use_cache: bool = False)

Bases: CompositeTransform

Run named transforms in sequence.

A pipeline is itself a transform, so it can be nested inside other composites or passed directly to CrossValidator. Each step receives the DataContainer produced by the previous step. fit_transform fits stateful steps as the data flows forward; transform assumes stateful steps have already been fitted.

If the final step is a Predictor, the pipeline also exposes predict, predict_proba, and get_labels. In that case all steps before the final predictor are applied first, then prediction is delegated to the predictor.

Step names must be unique. Optional expected_input_dims can be used to validate the dimensions seen by each step at runtime.
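A minimal sketch of the sequential semantics described above, with hypothetical toy steps operating on plain lists instead of `DataContainer`s. `MiniPipeline` is not the xdflow class; it only shows that `fit_transform` fits each step on the previous step's output, that `transform` reuses fitted state, that duplicate step names are rejected, and that a pipeline can be nested as a step inside another.

```python
class AddConst:
    """Hypothetical stateless step: adds a constant."""

    def __init__(self, c):
        self.c = c

    def fit_transform(self, xs):
        return self.transform(xs)

    def transform(self, xs):
        return [x + self.c for x in xs]


class CenterStep:
    """Hypothetical stateful step: learns the mean during fit_transform."""

    def __init__(self):
        self.mean = None

    def fit_transform(self, xs):
        self.mean = sum(xs) / len(xs)
        return self.transform(xs)

    def transform(self, xs):
        if self.mean is None:
            raise RuntimeError("CenterStep is not fitted")
        return [x - self.mean for x in xs]


class MiniPipeline:
    """Sequential composition; itself exposes fit_transform/transform."""

    def __init__(self, name, steps):
        names = [n for n, _ in steps]
        if len(names) != len(set(names)):
            raise ValueError(f"Step names must be unique: {names}")
        self.name = name
        self.steps = list(steps)

    def fit_transform(self, xs):
        for _, step in self.steps:
            xs = step.fit_transform(xs)  # fit on the previous step's output
        return xs

    def transform(self, xs):
        for _, step in self.steps:
            xs = step.transform(xs)  # stateful steps must already be fitted
        return xs


inner = MiniPipeline("inner", [("shift", AddConst(10)), ("center", CenterStep())])
outer = MiniPipeline("outer", [("pre", AddConst(1)), ("inner", inner)])
print(outer.fit_transform([0.0, 2.0]))  # [-1.0, 1.0]
print(outer.transform([4.0]))           # [3.0]
```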

Create a named pipeline from transform steps.

Parameters:

name (str): Human-readable pipeline name. Required.
steps (list[tuple[str, Transform]] | list[TransformStep]): Ordered (step_name, transform) pairs or TransformStep objects. Required.
expected_input_dims (dict[str, tuple[str, ...]]): Optional mapping from each step name to the dimensions expected immediately before that step runs. When the mapping is non-empty, fit_transform looks up every step name in it, so every step needs an entry. Default: None.
use_cache (bool): Whether to cache fit_transform output for reuse. Default: False.
Source code in xdflow/composite/pipeline.py
def __init__(
    self,
    name: str,
    steps: list[tuple[str, Transform]] | list[TransformStep],
    expected_input_dims: dict[str, tuple[str, ...]] = None,
    use_cache: bool = False,
):
    """Create a named pipeline from transform steps.

    Args:
        name: Human-readable pipeline name.
        steps: Ordered `(step_name, transform)` pairs or `TransformStep`
            objects.
        expected_input_dims: Optional mapping from each step name to the
            dimensions expected immediately before that step runs.
        use_cache: Whether to cache `fit_transform` output for reuse.
    """
    self.name = name

    if steps and not isinstance(steps[0], TransformStep):
        steps = [TransformStep(name, transform) for name, transform in steps]
    self.steps: list[TransformStep] = steps
    self.transform_from_name = {step.name: step.transform for step in steps}
    self.expected_input_dims = expected_input_dims or {}
    self.use_cache = use_cache

    # The super().__init__() call will now handle is_stateful
    super().__init__()

    # Compute input/output dimensions
    self.input_dims = self.steps[0].transform.input_dims if self.steps else ()
    self.output_dims = ()  # Dynamic based on step composition

    # Validate pipeline structure at initialization
    self._validate_composition()

children property

children: list[Transform]

Returns the transform objects from the steps.

is_predictor property

is_predictor: bool

Returns True if the last step is a Predictor.

predictive_transform property

predictive_transform: Transform | None

Returns the predictive transform if it exists, otherwise None.

final_target_coord property

final_target_coord: str | None

Convenience: expose the final predictor's target coordinate, if any.

Returns:

str | None: The target_coord of the last step when it is a Predictor, else None.

fit

fit(container: DataContainer, **kwargs) -> Pipeline

Fits all stateful transforms in the pipeline using recursive delegation.

This method fits all transforms in sequence: each step's output is passed to the next step, and the final transformed data is discarded. Its purpose is to prepare the pipeline for later transform() calls.

Parameters:

container (DataContainer): DataContainer to fit on. Required.
**kwargs: Additional context/parameters passed through the pipeline. Default: {}.

Returns:

Pipeline: Self (the fitted pipeline).

Source code in xdflow/composite/pipeline.py
def fit(self, container: DataContainer, **kwargs) -> "Pipeline":
    """
    Fits all stateful transforms in the pipeline using recursive delegation.

    This method fits all the transformers in sequence. The data is transformed
    by each step and passed to the next. The final transformed data is discarded.
    The primary purpose is to prepare the pipeline for future transform() calls.

    Args:
        container: DataContainer to fit on
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        Self (fitted pipeline)
    """
    self.fit_transform(container, **kwargs)
    # Discard the transformed data and return self for method chaining
    return self

fit_transform

fit_transform(container: DataContainer, **kwargs) -> DataContainer

Fits and transforms the data in a single, efficient pass.

This is the preferred method when you need to both fit the pipeline and get the transformed training data back. It performs the exact same fitting logic as fit() but returns the final transformed result instead of discarding it.

Parameters:

container (DataContainer): DataContainer to fit and transform. Required.
**kwargs: Additional context/parameters passed through the pipeline. Default: {}.

Returns:

DataContainer: The transformed DataContainer.

Source code in xdflow/composite/pipeline.py
@cache_result(prefix="fit_transform", key_gen_func=get_pipeline_cache_key_dict)
def fit_transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Fits and transforms the data in a single, efficient pass.

    This is the preferred method when you need to both fit the pipeline
    and get the transformed training data back. It performs the exact same
    fitting logic as fit() but returns the final transformed result instead
    of discarding it.

    Args:
        container: DataContainer to fit and transform.
        **kwargs: Additional context/parameters passed through the pipeline.

    Returns:
        The transformed DataContainer.
    """

    temp_container = container
    for step in self.steps:
        try:
            # Runtime validation using the internal expected_input_dims map
            if self.expected_input_dims:
                expected = self.expected_input_dims[step.name]
                actual = temp_container.dims
                if actual != expected:
                    raise RuntimeError(
                        f"Dimension mismatch before step '{step.name}': Expected {expected}, got {actual}"
                    )
            # Fit and transform each step to provide correct input for next step
            temp_container = step.transform.fit_transform(temp_container, **kwargs)

        except Exception as e:
            # Report the failing transform's class, not the TransformStep wrapper
            raise TransformError(f"Error in step '{step.name}' ({step.transform.__class__.__name__}): {e}") from e

    return temp_container
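The runtime dimension check and error wrapping in `fit_transform` can be sketched with hypothetical stand-ins: `Box` plays the role of a container that exposes a `dims` tuple, `Reduce` drops one dimension, and `TransformError` is a local stand-in for xdflow's exception. `run_steps` mirrors the loop above but is not the library function.

```python
class TransformError(Exception):
    """Stand-in for xdflow's TransformError."""


class Box:
    """Hypothetical container exposing only a `dims` tuple."""

    def __init__(self, dims):
        self.dims = dims


class Reduce:
    """Hypothetical step that removes one dimension."""

    def __init__(self, dim):
        self.dim = dim

    def fit_transform(self, box):
        return Box(tuple(d for d in box.dims if d != self.dim))


def run_steps(steps, box, expected_input_dims=None):
    expected_input_dims = expected_input_dims or {}
    for name, transform in steps:
        try:
            if expected_input_dims:
                expected = expected_input_dims[name]  # every step needs an entry
                if box.dims != expected:
                    raise RuntimeError(
                        f"Dimension mismatch before step '{name}': Expected {expected}, got {box.dims}"
                    )
            box = transform.fit_transform(box)
        except Exception as e:
            raise TransformError(f"Error in step '{name}' ({type(transform).__name__}): {e}") from e
    return box


steps = [("mean_time", Reduce("time")), ("mean_channel", Reduce("channel"))]
dims_map = {"mean_time": ("trial", "channel", "time"), "mean_channel": ("trial", "channel")}
out = run_steps(steps, Box(("trial", "channel", "time")), dims_map)
print(out.dims)  # ('trial',)

try:
    run_steps(steps, Box(("trial", "time")), dims_map)
except TransformError as err:
    print(err)
```

As in the source, once `expected_input_dims` is non-empty a step name missing from it raises `KeyError`, which is wrapped in `TransformError` like any other step failure.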

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...], print_steps: bool = False) -> tuple[str, ...]

Returns the expected output dimensions for the pipeline.

Source code in xdflow/composite/pipeline.py
def get_expected_output_dims(self, input_dims: tuple[str, ...], print_steps: bool = False) -> tuple[str, ...]:
    """Returns the expected output dimensions for the pipeline."""

    # run through steps and get expected output dims
    curr_input_dims = input_dims
    for step in self.steps:
        if print_steps:
            print(f"Step {step.name} \ninput dims: {curr_input_dims}")
        curr_input_dims = step.transform.get_expected_output_dims(curr_input_dims)
        if print_steps:
            print(f"output dims: {curr_input_dims}")

    expected_output_dims = curr_input_dims
    return expected_output_dims
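Because `get_expected_output_dims` only chains each step's own `get_expected_output_dims`, output dimensions can be checked statically, before any data is loaded. A minimal sketch with hypothetical steps (`DropDim`, `AddDim`) and a free function that follows the same loop:

```python
class DropDim:
    """Hypothetical step that removes a dimension."""

    def __init__(self, dim):
        self.dim = dim

    def get_expected_output_dims(self, input_dims):
        return tuple(d for d in input_dims if d != self.dim)


class AddDim:
    """Hypothetical step that appends a new dimension."""

    def __init__(self, dim):
        self.dim = dim

    def get_expected_output_dims(self, input_dims):
        return input_dims + (self.dim,)


def expected_output_dims(steps, input_dims, print_steps=False):
    dims = input_dims
    for name, transform in steps:
        if print_steps:
            print(f"Step {name} \ninput dims: {dims}")
        dims = transform.get_expected_output_dims(dims)
        if print_steps:
            print(f"output dims: {dims}")
    return dims


steps = [("avg_time", DropDim("time")), ("embed", AddDim("feature"))]
print(expected_output_dims(steps, ("trial", "channel", "time")))
# ('trial', 'channel', 'feature')
```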

predict

predict(container: DataContainer, **kwargs) -> DataContainer

Generates predictions using the final predictor in the pipeline.

Parameters:

container (DataContainer): DataContainer to make predictions on. Required.
**kwargs: Additional context/parameters passed through the pipeline. Default: {}.

Returns:

DataContainer: DataContainer with predictions as the primary data.

Source code in xdflow/composite/pipeline.py
def predict(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Generates predictions using the final predictor in the pipeline.

    Args:
        container: DataContainer to make predictions on
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        DataContainer with predictions as the primary data
    """
    if not self.steps:
        raise ValueError("Cannot predict with an empty pipeline.")

    final_step = self.steps[-1]
    if not self.is_predictor:
        raise TypeError("The last step of the pipeline must be a Predictor for prediction.")

    transformed_container = self._transform_to_final_step(container, **kwargs)
    return final_step.transform.predict(transformed_container)
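The prediction path can be sketched as follows. The private helper `_transform_to_final_step` is not shown in this excerpt, so the sketch assumes it calls `transform` on every step except the last; the classes are hypothetical, and a `hasattr(..., "predict")` check stands in for the real `is_predictor` test.

```python
class Scale:
    """Hypothetical fitted preprocessing step."""

    def __init__(self, k):
        self.k = k

    def transform(self, xs):
        return [x * self.k for x in xs]


class Threshold:
    """Hypothetical 'predictor': labels values above a cutoff."""

    def __init__(self, cutoff):
        self.cutoff = cutoff

    def predict(self, xs):
        return ["high" if x > self.cutoff else "low" for x in xs]


def transform_to_final_step(steps, xs):
    # Assumed behavior: apply every step except the final predictor.
    for _, step in steps[:-1]:
        xs = step.transform(xs)
    return xs


def pipeline_predict(steps, xs):
    if not steps:
        raise ValueError("Cannot predict with an empty pipeline.")
    final = steps[-1][1]
    if not hasattr(final, "predict"):
        raise TypeError("The last step of the pipeline must be a Predictor for prediction.")
    return final.predict(transform_to_final_step(steps, xs))


steps = [("scale", Scale(2)), ("clf", Threshold(5))]
print(pipeline_predict(steps, [1, 3, 4]))  # ['low', 'high', 'high']
```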

predict_proba

predict_proba(container: DataContainer, **kwargs) -> DataContainer

Generates prediction probabilities using the final Predictor in the pipeline.

Parameters:

container (DataContainer): DataContainer to make predictions on. Required.
**kwargs: Additional context/parameters passed through the pipeline. Default: {}.

Returns:

DataContainer: DataContainer with prediction probabilities.

Source code in xdflow/composite/pipeline.py
def predict_proba(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Generates prediction probabilities using the final Predictor in the pipeline.

    Args:
        container: DataContainer to make predictions on
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        DataContainer with prediction probabilities
    """
    if not self.steps:
        raise ValueError("Cannot predict with an empty pipeline.")

    final_step = self.steps[-1]
    if not self.is_predictor:
        raise TypeError("The last step of the pipeline must be a Predictor for prediction.")

    transformed_container = self._transform_to_final_step(container, **kwargs)
    return final_step.transform.predict_proba(transformed_container)

prepare_for_inference

prepare_for_inference(*, set_n_jobs_single: bool = True) -> None

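Disable training-time optimizations that are undesirable at inference: fit_transform caching is turned off and, optionally, transforms that expose n_jobs are made single-threaded.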

Parameters:

set_n_jobs_single (bool): When True, force transforms that expose an n_jobs attribute to run single-threaded for request/response latency. Default: True.
Source code in xdflow/composite/pipeline.py
def prepare_for_inference(self, *, set_n_jobs_single: bool = True) -> None:
    """
    Disable training-time optimizations that are undesirable at inference.

    Args:
        set_n_jobs_single: When True, force transforms that expose an `n_jobs`
            attribute to run single-threaded for request/response latency.
    """
    self.use_cache = False

    if set_n_jobs_single and hasattr(self, "n_jobs"):
        try:
            self.n_jobs = 1
        except AttributeError:
            pass

    visited: set[int] = {id(self)}
    for step in self.steps:
        _configure_transform_for_inference(
            step.transform,
            set_n_jobs_single=set_n_jobs_single,
            visited=visited,
        )
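The traversal with a `visited` set of object ids can be sketched as below. The helper `_configure_transform_for_inference` is not shown in this excerpt, so `configure_for_inference` is a hypothetical reconstruction: it assumes the helper resets `n_jobs` and recurses into `children`, and it deliberately does nothing else. Seeding `visited` with the pipeline's own id, as the source does, means shared or cyclic references are handled once and the walk always terminates.

```python
class Node:
    """Hypothetical transform that may expose n_jobs and child transforms."""

    def __init__(self, n_jobs=None, children=()):
        if n_jobs is not None:
            self.n_jobs = n_jobs
        self.children = list(children)
        self.use_cache = True


def configure_for_inference(transform, *, set_n_jobs_single=True, visited):
    if id(transform) in visited:
        return  # shared or cyclic reference already handled
    visited.add(id(transform))
    if set_n_jobs_single and hasattr(transform, "n_jobs"):
        try:
            transform.n_jobs = 1
        except AttributeError:
            pass  # e.g. a read-only property
    for child in getattr(transform, "children", ()):
        configure_for_inference(child, set_n_jobs_single=set_n_jobs_single, visited=visited)


def prepare_for_inference(pipeline, *, set_n_jobs_single=True):
    pipeline.use_cache = False
    if set_n_jobs_single and hasattr(pipeline, "n_jobs"):
        pipeline.n_jobs = 1
    visited = {id(pipeline)}
    for child in pipeline.children:
        configure_for_inference(child, set_n_jobs_single=set_n_jobs_single, visited=visited)


shared = Node(n_jobs=8)
root = Node(children=[Node(n_jobs=4, children=[shared]), shared])
shared.children.append(root)  # a cycle back to the root
prepare_for_inference(root)
print(root.use_cache, root.children[0].n_jobs, shared.n_jobs)  # False 1 1
```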

get_labels

get_labels() -> list[Any]

Return the label ordering from the final predictor.

Relies on the predictor implementing get_labels; raises when the pipeline cannot provide labels unambiguously.

Source code in xdflow/composite/pipeline.py
def get_labels(self) -> list[Any]:
    """
    Return the label ordering from the final predictor.

    Relies on the predictor implementing `get_labels`; raises when the pipeline
    cannot provide labels unambiguously.
    """
    if not self.steps:
        raise ValueError("Cannot get labels from an empty pipeline.")

    if not self.is_predictor:
        raise TypeError("Pipeline does not terminate in a Predictor, so labels are undefined.")

    predictive_transform = self.predictive_transform
    if predictive_transform is None:
        raise RuntimeError("Pipeline.predictive_transform resolved to None; cannot determine labels.")

    labels = predictive_transform.get_labels()
    if labels is None:
        raise ValueError(
            f"Predictive transform {predictive_transform.__class__.__name__} returned no labels. "
            "Ensure it is fitted and exposes label metadata."
        )

    return labels

Grouped Application

GroupApplyTransform

GroupApplyTransform(group_coord: str | list[str], transform_template: Transform, unseen_policy: Literal['error', 'average', 'weighted_average'] = 'error', unequal_output_dims_strategy: Literal['error', 'cut_to_min'] = 'error', n_jobs: int = 1)

Bases: CompositeTransform

Applies a transform individually to each group defined by a metadata coordinate.

This transform discovers groups from the data at fit time, creates independent transform instances per group by cloning the template, and applies transformations per group. The outputs are reassembled along the original grouped axis.

Use cases:

  • Apply per-animal preprocessing where each animal needs independent fitting.
  • Train separate models per session or experimental condition.
  • Any scenario where groups should be processed independently.

Parameters:

group_coord (str | list[str]): Coordinate name(s) to use for grouping (e.g., "animal", "session"). When a list is given, the coordinates are combined into a single grouping coordinate whose name joins the names with "_". Required.
transform_template (Transform): Template transform to clone per group (unfitted). Required.
unseen_policy (Literal['error', 'average', 'weighted_average']): How to handle groups not seen during fit. Default: 'error'.
  • "error": raise TransformError (default).
  • "average": uniform average across all fitted group transforms.
  • "weighted_average": average weighted by the training count of each group.
unequal_output_dims_strategy (Literal['error', 'cut_to_min']): How to handle unequal (non-group) output dimensions across groups, which would otherwise produce NaNs during concatenation. Default: 'error'.
  • "error": raise TransformError (default).
  • "cut_to_min": use the minimum size per dimension across groups.
n_jobs (int): Number of parallel jobs for per-group processing. Default: 1.
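A worked sketch of the two averaging policies for an unseen group. The implementation is not part of this excerpt, so the sketch assumes averaging is applied to the outputs each fitted group transform produces for the unseen group's data; `blend_unseen` is a hypothetical helper and `TransformError` a local stand-in. With training counts of 30 and 10, the weighted average gives the larger group three times the weight.

```python
class TransformError(Exception):
    """Stand-in for xdflow's TransformError."""


def blend_unseen(per_group_outputs, train_counts, policy="error"):
    """Combine every fitted group's output for data from an unseen group."""
    if policy == "error":
        raise TransformError("Group was not seen during fit.")
    groups = list(per_group_outputs)
    if policy == "average":
        weights = {g: 1.0 for g in groups}
    elif policy == "weighted_average":
        weights = {g: float(train_counts[g]) for g in groups}
    else:
        raise ValueError(f"Unknown policy: {policy}")
    total = sum(weights.values())
    n = len(per_group_outputs[groups[0]])
    return [sum(weights[g] * per_group_outputs[g][i] for g in groups) / total for i in range(n)]


outputs = {"mouse_a": [0.0, 1.0], "mouse_b": [4.0, 3.0]}
counts = {"mouse_a": 30, "mouse_b": 10}
print(blend_unseen(outputs, counts, "average"))           # [2.0, 2.0]
print(blend_unseen(outputs, counts, "weighted_average"))  # [1.0, 1.5]
```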

Initialize GroupApplyTransform with grouping parameters.

Source code in xdflow/composite/group_apply.py
def __init__(
    self,
    group_coord: str | list[str],
    transform_template: Transform,  # needs same name as param for cloning
    unseen_policy: Literal["error", "average", "weighted_average"] = "error",
    unequal_output_dims_strategy: Literal["error", "cut_to_min"] = "error",
    n_jobs: int = 1,
):
    """Initialize GroupApplyTransform with grouping parameters."""

    self.group_coord = group_coord if isinstance(group_coord, list) else [group_coord]
    self.transform_template = transform_template
    self.unseen_policy = unseen_policy
    self.unequal_output_dims_strategy = unequal_output_dims_strategy
    self.n_jobs = n_jobs

    # State set during fitting
    self.seen_groups: list[Hashable] = []
    self.per_group_fitted: dict[Hashable, Transform] = {}
    self.train_counts: dict[Hashable, int] = {}
    self.group_dim: str | None = None

    # Compute input/output dimensions from template
    self.input_dims = self.transform_template.input_dims
    self.output_dims = self.transform_template.output_dims

    # Compute max size per dimension to avoid nans during concatenation
    self.max_size_per_dim = {}  # used for equalizing output dimensions

    # Compute combined group coord name
    self.combined_group_coord = "_".join(self.group_coord)

    # Call parent after setting up children collections, but override is_stateful manually
    # since children won't exist until after fitting
    super().__init__()

    # Override is_stateful based on template or overrides since children don't exist yet
    self.is_stateful = self.transform_template.is_stateful

    # only keep the template transform
    self.transform_from_name = {"transform_template": self.transform_template}
    self._validate_composition()
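When per-group outputs have different sizes along a non-group dimension, concatenating them would pad with NaNs. The sketch below shows the "error" versus "cut_to_min" decision on plain nested lists (rows along the grouped axis, row length standing in for a non-group dimension such as "feature"). The helper `equalize` is hypothetical, and trimming by keeping the leading entries is an assumption about how the real strategy cuts.

```python
class TransformError(Exception):
    """Stand-in for xdflow's TransformError."""


def equalize(group_rows, strategy="error"):
    """group_rows maps each group to its output rows; row length varies by group."""
    lengths = {g: {len(r) for r in rows} for g, rows in group_rows.items()}
    all_lengths = set().union(*lengths.values())
    if len(all_lengths) <= 1:
        return group_rows  # already consistent
    if strategy == "error":
        raise TransformError(f"Unequal output sizes across groups: {lengths}")
    if strategy == "cut_to_min":
        n = min(all_lengths)  # minimum size of the dimension across groups
        return {g: [r[:n] for r in rows] for g, rows in group_rows.items()}
    raise ValueError(f"Unknown strategy: {strategy}")


outputs = {"s1": [[1, 2, 3], [4, 5, 6]], "s2": [[7, 8]]}
trimmed = equalize(outputs, "cut_to_min")
stacked = [row for g in ("s1", "s2") for row in trimmed[g]]  # reassemble along the grouped axis
print(stacked)  # [[1, 2], [4, 5], [7, 8]]
```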

children property

children: list[Transform]

Returns fitted per-group transforms after fitting.

is_predictor property

is_predictor: bool

Returns True if the template transform performs prediction.

predictive_transform property

predictive_transform: Transform | None

Returns the predictive transform if it exists, otherwise None.

fit

fit(container: DataContainer, **kwargs) -> GroupApplyTransform

Fits per-group transforms after discovering groups from the data.

Parameters:

container (DataContainer): DataContainer to fit on. Required.
**kwargs: Additional context/parameters passed through to each per-group fit call. Default: {}.

Returns:

GroupApplyTransform: Self (the fitted GroupApplyTransform).

Source code in xdflow/composite/group_apply.py
def fit(self, container: DataContainer, **kwargs) -> "GroupApplyTransform":
    """
    Fits per-group transforms after discovering groups from the data.

    Args:
        container: DataContainer to fit on
        **kwargs: Additional context/parameters passed through

    Returns:
        Self (fitted GroupApplyTransform)
    """

    # Discover grouping structure
    self.group_dim = self._get_group_dim(container)
    container = self._set_combined_group_coord_values(container)
    self.seen_groups = self._discover_groups(container)

    # Reset state
    self.per_group_fitted = {}
    self.train_counts = {}

    # Fit each group
    if self.n_jobs != 1:
        # Parallel fitting
        def fit_group(group_val):
            group_container = self._select_group(container, group_val)
            transform = self.transform_template.clone()
            fitted_transform = transform.fit(group_container, **kwargs)
            train_count = group_container.data.sizes[self.group_dim]
            return group_val, fitted_transform, train_count

        results = Parallel(n_jobs=self.n_jobs)(delayed(fit_group)(group_val) for group_val in self.seen_groups)

        for group_val, fitted_transform, train_count in results:
            self.per_group_fitted[group_val] = fitted_transform
            self.train_counts[group_val] = train_count
    else:
        # Sequential fitting
        for group_val in self.seen_groups:
            group_container = self._select_group(container, group_val)
            transform = self.transform_template.clone()
            self.per_group_fitted[group_val] = transform.fit(group_container, **kwargs)
            self.train_counts[group_val] = group_container.data.sizes[self.group_dim]

    return self
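The per-group fitting loop reduces to: discover groups, fit an independent clone of the unfitted template on each group, and record how many samples each group contributed (the counts used by "weighted_average"). A sequential sketch with a hypothetical `MeanCenter` transform on `(group, value)` records; groups are discovered in first-seen order here, which may differ from the real `_discover_groups`, and the joblib-parallel branch is omitted.

```python
from collections import defaultdict


class MeanCenter:
    """Hypothetical stateful transform with a clone() that drops fitted state."""

    def __init__(self):
        self.mean = None

    def clone(self):
        return MeanCenter()

    def fit(self, values):
        self.mean = sum(values) / len(values)
        return self

    def transform(self, values):
        return [v - self.mean for v in values]


def fit_per_group(records, template):
    by_group = defaultdict(list)
    for group, value in records:
        by_group[group].append(value)
    # Each group gets its own clone, so the template itself stays unfitted.
    fitted = {g: template.clone().fit(vals) for g, vals in by_group.items()}
    train_counts = {g: len(vals) for g, vals in by_group.items()}
    return fitted, train_counts


template = MeanCenter()
records = [("a", 1.0), ("b", 10.0), ("a", 3.0), ("b", 30.0), ("b", 20.0)]
fitted, counts = fit_per_group(records, template)
print({g: t.mean for g, t in fitted.items()}, counts)
# {'a': 2.0, 'b': 20.0} {'a': 2, 'b': 3}
print(fitted["b"].transform([25.0]))  # [5.0]
```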

fit_transform

fit_transform(container: DataContainer, **kwargs) -> DataContainer

Fits and transforms in a single pass for efficiency.

Parameters:

container (DataContainer): DataContainer to fit and transform. Required.
**kwargs: Additional context/parameters passed through to each per-group fit_transform call. Default: {}.

Returns:

DataContainer: Transformed DataContainer with the per-group results reassembled along the grouped axis.

Source code in xdflow/composite/group_apply.py
def fit_transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Fits and transforms in a single pass for efficiency.

    Args:
        container: DataContainer to fit and transform
        **kwargs: Additional context/parameters passed through

    Returns:
        Transformed DataContainer with results reassembled
    """

    # Discover grouping structure
    self.group_dim = self._get_group_dim(container)
    container = self._set_combined_group_coord_values(container)
    self.seen_groups = self._discover_groups(container)

    # Reset state
    self.per_group_fitted = {}
    self.train_counts = {}

    # Fit and transform each group
    if self.n_jobs != 1:
        # Parallel fit_transform
        def fit_transform_group(group_val):
            group_container = self._select_group(container, group_val)
            transform = self.transform_template.clone()
            transformed_container = transform.fit_transform(group_container, **kwargs)
            train_count = group_container.data.sizes[self.group_dim]
            return group_val, transform, transformed_container, train_count

        results = Parallel(n_jobs=self.n_jobs)(
            delayed(fit_transform_group)(group_val) for group_val in self.seen_groups
        )

        group_outputs = []
        for group_val, fitted_transform, transformed_container, train_count in results:
            self.per_group_fitted[group_val] = fitted_transform
            self.train_counts[group_val] = train_count

            # Validate output preserves grouped axis
            group_input = self._select_group(container, group_val)
            self._validate_group_output_preserves_axis(group_val, group_input, transformed_container)

            group_outputs.append(transformed_container.data)
    else:
        # Sequential fit_transform
        group_outputs = []
        for group_val in self.seen_groups:
            group_container = self._select_group(container, group_val)
            transform = self.transform_template.clone()
            transformed_container = transform.fit_transform(group_container, **kwargs)

            # Store fitted transform and count
            self.per_group_fitted[group_val] = transform
            self.train_counts[group_val] = group_container.data.sizes[self.group_dim]

            # Validate output preserves grouped axis
            self._validate_group_output_preserves_axis(group_val, group_container, transformed_container)

            group_outputs.append(transformed_container.data)

    # Reassemble outputs along the grouped dimension
    reassembled = xr.concat(group_outputs, dim=self.group_dim, join="outer")
    reassembled = self._handle_unequal_output_dims(reassembled, group_outputs, fitted=False)

    # remove the combined group coord
    if len(self.group_coord) > 1:
        reassembled = reassembled.drop_vars(self.combined_group_coord)

    return DataContainer(reassembled)
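
The fit_transform flow above is a split-apply-combine pattern: each group is fitted and transformed in one call, and the group outputs are then concatenated. The sketch below illustrates that pattern with lists standing in for DataArrays. `Demean` and `group_fit_transform` are hypothetical names, and the sketch omits the grouped-axis validation and unequal-dimension handling that the source performs.

```python
class Demean:
    """Hypothetical stateful transform with a combined fit_transform."""

    def __init__(self):
        self.mean = None

    def clone(self):
        return Demean()

    def fit_transform(self, values):
        # One pass: learn the mean, then apply it to the same data
        self.mean = sum(values) / len(values)
        return [v - self.mean for v in values]


def group_fit_transform(template, groups, values):
    seen = list(dict.fromkeys(groups))
    fitted, outputs = {}, []
    for g in seen:
        group_values = [v for gg, v in zip(groups, values) if gg == g]
        transform = template.clone()
        outputs.extend(transform.fit_transform(group_values))
        fitted[g] = transform
    # Outputs are concatenated group by group, so rows come back grouped
    return fitted, outputs


groups = ["a", "b", "a", "b"]
values = [1.0, 10.0, 3.0, 30.0]
fitted, out = group_fit_transform(Demean(), groups, values)
print(out)  # [-1.0, 1.0, -10.0, 10.0]
```

In this sketch the rows come back ordered as a, a, b, b rather than in their input order. That is a general property of concatenating per-group outputs. The sketch does not model how coordinates carried through `xr.concat` might affect ordering in the real implementation.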

predict

predict(container: DataContainer, **kwargs) -> DataContainer

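Generates predictions using the per-group fitted predictors. A group seen during fit uses its own predictor. For a group not seen during fit, unseen_policy="error" raises a TransformError. With "average", every fitted predictor is applied and the results are averaged. With "weighted_average", each predictor's result is weighted by its group's training count.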

Parameters:

Name Type Description Default
container DataContainer

DataContainer to make predictions on

required
**kwargs

Additional context/parameters

{}

Returns:

Type Description
DataContainer

DataContainer with predictions

Source code in xdflow/composite/group_apply.py
def predict(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Generates predictions using per-group fitted predictors.

    Args:
        container: DataContainer to make predictions on
        **kwargs: Additional context/parameters

    Returns:
        DataContainer with predictions
    """
    if not self.per_group_fitted:
        raise TransformError("GroupApplyTransform must be fitted before predict")

    # Check that all fitted transforms are predictors
    for group_val, fitted_transform in self.per_group_fitted.items():
        if not isinstance(fitted_transform, Predictor):
            raise TypeError(
                f"Transform for group '{group_val}' is not a Predictor. "
                f"predict() requires all fitted transforms to be Predictors."
            )
    fitted_predictors = {
        group_val: cast(Predictor, fitted_transform)
        for group_val, fitted_transform in self.per_group_fitted.items()
    }

    # set the combined group coord values and discover groups
    container = self._set_combined_group_coord_values(container)
    current_groups = self._discover_groups(container)

    group_outputs = []

    # Process each group
    for group_val in current_groups:
        group_container = self._select_group(container, group_val)

        if group_val in self.per_group_fitted:
            # Use fitted predictor for seen group
            fitted_predictor = fitted_predictors[group_val]
            prediction = fitted_predictor.predict(group_container, **kwargs)
            group_outputs.append(prediction.data)
        else:
            # Apply unseen group policy
            if self.unseen_policy == "error":
                raise TransformError(
                    f"Group '{group_val}' was not seen during fit. Seen groups: {self.seen_groups}"
                )

            # Apply all fitted predictors and average
            if self.n_jobs != 1:
                # Parallel prediction
                def predict_with_fitted(fitted_predictor, group_data):
                    return fitted_predictor.predict(group_data, **kwargs)

                predictions = Parallel(n_jobs=self.n_jobs)(
                    delayed(predict_with_fitted)(fitted_predictor, group_container)
                    for fitted_predictor in fitted_predictors.values()
                )
            else:
                # Sequential prediction
                predictions = []
                for fitted_predictor in fitted_predictors.values():
                    predictions.append(fitted_predictor.predict(group_container, **kwargs))

            # Average predictions
            if self.unseen_policy == "average":
                averaged_data = sum(pred.data for pred in predictions) / len(predictions)
            elif self.unseen_policy == "weighted_average":
                total_count = sum(self.train_counts.values())
                if total_count == 0:
                    averaged_data = sum(pred.data for pred in predictions) / len(predictions)
                else:
                    weights = [count / total_count for count in self.train_counts.values()]
                    averaged_data = sum(w * pred.data for w, pred in zip(weights, predictions, strict=True))

            group_outputs.append(averaged_data)

    # Reassemble predictions
    reassembled = xr.concat(group_outputs, dim=self.group_dim)

    # remove the combined group coord
    if len(self.group_coord) > 1:
        reassembled = reassembled.drop_vars(self.combined_group_coord)

    return DataContainer(reassembled)

predict_proba

predict_proba(container: DataContainer, **kwargs) -> DataContainer

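Generates prediction probabilities using the per-group fitted predictors. Every fitted transform must be a Predictor that is also a classifier. Unseen groups are handled by unseen_policy in the same way as in predict, with the policy applied to probabilities instead of predictions.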

Parameters:

Name Type Description Default
container DataContainer

DataContainer to make predictions on

required
**kwargs

Additional context/parameters

{}

Returns:

Type Description
DataContainer

DataContainer with prediction probabilities

Source code in xdflow/composite/group_apply.py
def predict_proba(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Generates prediction probabilities using per-group fitted predictors.

    Args:
        container: DataContainer to make predictions on
        **kwargs: Additional context/parameters

    Returns:
        DataContainer with prediction probabilities
    """
    if not self.per_group_fitted:
        raise TransformError("GroupApplyTransform must be fitted before predict_proba")

    # Check that all fitted transforms are classifiers
    for group_val, fitted_transform in self.per_group_fitted.items():
        if not isinstance(fitted_transform, Predictor):
            raise TypeError(
                f"Transform for group '{group_val}' is not a Predictor. "
                f"predict_proba() requires all fitted transforms to be Predictors."
            )
        if not fitted_transform.is_classifier:
            raise TypeError(
                f"Transform for group '{group_val}' is not a classifier. "
                f"predict_proba() requires all fitted transforms to be classifiers."
            )
    fitted_predictors = {
        group_val: cast(Predictor, fitted_transform)
        for group_val, fitted_transform in self.per_group_fitted.items()
    }

    # set the combined group coord values and discover groups
    container = self._set_combined_group_coord_values(container)
    current_groups = self._discover_groups(container)

    group_outputs = []

    # Process each group
    for group_val in current_groups:
        group_container = self._select_group(container, group_val)

        if group_val in self.per_group_fitted:
            # Use fitted predictor for seen group
            fitted_predictor = fitted_predictors[group_val]
            probabilities = fitted_predictor.predict_proba(group_container, **kwargs)
            group_outputs.append(probabilities.data)
        else:
            # Apply unseen group policy
            if self.unseen_policy == "error":
                raise TransformError(
                    f"Group '{group_val}' was not seen during fit. Seen groups: {self.seen_groups}"
                )

            # Apply all fitted predictors and average probabilities
            if self.n_jobs != 1:
                # Parallel prediction
                def predict_proba_with_fitted(fitted_predictor, group_data):
                    return fitted_predictor.predict_proba(group_data, **kwargs)

                probabilities = Parallel(n_jobs=self.n_jobs)(
                    delayed(predict_proba_with_fitted)(fitted_predictor, group_container)
                    for fitted_predictor in fitted_predictors.values()
                )
            else:
                # Sequential prediction
                probabilities = []
                for fitted_predictor in fitted_predictors.values():
                    probabilities.append(fitted_predictor.predict_proba(group_container, **kwargs))

            # Average probabilities
            if self.unseen_policy == "average":
                averaged_data = sum(prob.data for prob in probabilities) / len(probabilities)
            elif self.unseen_policy == "weighted_average":
                total_count = sum(self.train_counts.values())
                if total_count == 0:
                    averaged_data = sum(prob.data for prob in probabilities) / len(probabilities)
                else:
                    weights = [count / total_count for count in self.train_counts.values()]
                    averaged_data = sum(w * prob.data for w, prob in zip(weights, probabilities, strict=True))

            group_outputs.append(averaged_data)

    # Reassemble probabilities
    reassembled = xr.concat(group_outputs, dim=self.group_dim)

    # remove the combined group coord
    if len(self.group_coord) > 1:
        reassembled = reassembled.drop_vars(self.combined_group_coord)

    return DataContainer(reassembled)

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Returns the expected output dimensions for the GroupApplyTransform.

Parameters:

Name Type Description Default
input_dims tuple[str, ...]

Expected input dimensions

required

Returns:

Type Description
tuple[str, ...]

Expected output dimensions, taken from the transform template

Source code in xdflow/composite/group_apply.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Returns the expected output dimensions for the GroupApplyTransform.

    Args:
        input_dims: Expected input dimensions

    Returns:
        Expected output dimensions from the reference transform
    """
    # Use template as reference
    return self.transform_template.get_expected_output_dims(input_dims)

Parallel Branches

TransformUnion

TransformUnion(transforms_list: list[tuple[str, Transform] | Pipeline | TransformStep], from_dims: list[str] | None = None, to_dim: str | None = 'feature', n_jobs: int = 1)

Bases: CompositeTransform

Applies a set of transforms in parallel and concatenates their outputs.

Every child transform receives the same input, and their outputs are concatenated along a single join dimension (to_dim, "feature" by default). This is useful for combining different types of features (e.g., spectral and temporal) into a single feature set.

Note: This class computes is_stateful dynamically based on constituent transforms, so it overrides the class attribute with an instance attribute.

Usage: TransformUnion(transforms_list=[Pipeline(name="time_average", steps=[("average_time", AverageTransform(dims="time"))]), ("average_channel", AverageTransform(dims="channel"))], from_dims=["channel", "time"], to_dim="feature")

Initialize TransformUnion with multiple transforms.

Parameters:

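Name Type Description Default
transforms_list list[tuple[str, Transform] | Pipeline | TransformStep]

List of (step_name, transform) tuples, Pipeline objects, or TransformStep objects. A Pipeline passed inside a tuple must have the same name as its step.

required
from_dims list[str] | None

Dimension along which each branch's output is joined, one entry per branch. Although the default is None, omitting it raises a ValueError.

None
to_dim str | None

Name of the final joined dimension. If None, the first entry of from_dims is kept.

'feature'
n_jobs int

Number of parallel jobs to run: n_jobs=1 (default) runs the branches sequentially; n_jobs=-1 uses all available CPU cores; n_jobs>1 uses that many worker processes.

1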
Source code in xdflow/composite/transform_union.py
def __init__(
    self,
    transforms_list: list[tuple[str, Transform] | Pipeline | TransformStep],
    from_dims: list[str] | None = None,
    to_dim: str | None = "feature",
    n_jobs: int = 1,
):
    """
    Initialize TransformUnion with multiple transforms.

    Args:
        transforms_list: List of (step_name, transform) tuples, Pipeline objects, or TransformStep objects
        n_jobs: Number of parallel jobs to run.
               - n_jobs=1 (default): Sequential execution, maintains current behavior
               - n_jobs=-1: Use all available CPU cores
               - n_jobs>1: Use the specified number of worker processes
    """

    # store transforms as TransformSteps for easier handling
    self.transforms_list = []
    for transform in transforms_list:
        if isinstance(transform, Pipeline):
            self.transforms_list.append(TransformStep(transform.name, transform))
        elif isinstance(transform, TransformStep):
            self.transforms_list.append(transform)
        elif isinstance(transform, tuple):
            if isinstance(transform[1], Pipeline) and (transform[0] != transform[1].name):
                raise ValueError(f"Pipeline name ({transform[1].name}) must match the step name ({transform[0]})")
            self.transforms_list.append(TransformStep(transform[0], transform[1]))
        else:
            raise ValueError(f"Invalid transform type: {type(transform)}")
    self.transform_from_name = {step.name: step.transform for step in self.transforms_list}

    # store dims to join
    if from_dims is None:
        raise ValueError("from_dims must be provided")
    self.from_dims = from_dims

    # store parallel processing parameter
    self.n_jobs = n_jobs
    # Name of the final joined dimension
    self.to_dim = to_dim

    # The super().__init__() call will now handle is_stateful
    super().__init__(sel=None, drop_sel=None)  # sel/drop_sel should be given to each transform

    # Compute input/output dimensions
    self.input_dims = ()  # Will be computed in validation if needed
    self.output_dims = ()  # Dynamic based on constituent transforms

    self._validate_composition()

children property

children: list[Transform]

Returns the child transforms of the union.

is_predictor property

is_predictor: bool

Returns False for TransformUnion because it concatenates outputs and does not perform prediction.

fit_transform

fit_transform(container: DataContainer, **kwargs) -> DataContainer

Fits and transforms the data in a single pass. The branches run in parallel when the union is stateful and n_jobs != 1.

This method avoids double computation by running fit_transform on each child transform and collecting both the fitted transformer and the transformed data from each worker process.

Parameters:

Name Type Description Default
container DataContainer

DataContainer to fit and transform

required
**kwargs

Additional context/parameters passed through the pipeline

{}

Returns:

Type Description
DataContainer

DataContainer with the branch outputs concatenated along the join dimension (to_dim, 'feature' by default)

Source code in xdflow/composite/transform_union.py
def fit_transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Fits and transforms the data in a single pass, in parallel when the union is stateful and n_jobs != 1.

    This method avoids double computation by running `fit_transform` on each
    child transform and collecting both the fitted transformer and the
    transformed data from each worker process.

    Args:
        container: DataContainer to fit and transform
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        DataContainer with concatenated results along the join dimension (to_dim, 'feature' by default)
    """
    if self.is_stateful and self.n_jobs != 1:
        # Parallel execution
        # The helper returns a list of (fitted_transform, transformed_container) tuples
        results = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_transform_one)(step.transform, container.copy(deep=False), **kwargs)
            for step in self.transforms_list
        )

        # Unzip the results
        fitted_transforms, output_containers = zip(*results)

        # Update the original transform objects with the fitted versions from the workers
        for i, fitted_transform in enumerate(fitted_transforms):
            self.transforms_list[i].transform = fitted_transform

        # Update the name-to-transform mapping as well
        self.transform_from_name = {step.name: step.transform for step in self.transforms_list}

        outputs = [c.data for c in output_containers]

    elif self.is_stateful:
        # Sequential execution for n_jobs=1
        outputs = []
        for step in self.transforms_list:
            # In the sequential case, the transform is modified in-place, and we just collect the data
            transformed_container = step.transform.fit_transform(container.copy(deep=False), **kwargs)
            outputs.append(transformed_container.data)
    else:
        # If not stateful, just transform (no fitting needed)
        outputs = [
            step.transform.transform(container.copy(deep=False), **kwargs).data for step in self.transforms_list
        ]

    # Concatenate the results along a new 'feature' dimension
    concat_outputs = self._concatenate_outputs(outputs)
    return DataContainer(concat_outputs)

fit

fit(container: DataContainer, **kwargs) -> TransformUnion

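Fits every child transform of a stateful union, in parallel when n_jobs != 1 and sequentially otherwise. A stateless union is returned without fitting anything.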

For parallel execution, fitted transforms are returned from worker processes and used to update the original transform objects.

Parameters:

Name Type Description Default
container DataContainer

DataContainer to fit on

required
**kwargs

Additional context/parameters passed through the pipeline

{}

Returns:

Type Description
TransformUnion

Self (fitted TransformUnion)

Source code in xdflow/composite/transform_union.py
def fit(self, container: DataContainer, **kwargs) -> "TransformUnion":
    """
    Fits all stateful steps in parallel or sequentially.

    For parallel execution, fitted transforms are returned from worker processes
    and used to update the original transform objects.

    Args:
        container: DataContainer to fit on
        **kwargs: Additional context/parameters passed through the pipeline

    Returns:
        Self (fitted TransformUnion)
    """
    if self.is_stateful and self.n_jobs != 1:
        # Parallel execution for stateful transforms
        fitted_transforms = Parallel(n_jobs=self.n_jobs)(
            delayed(step.transform.fit)(container.copy(deep=False), **kwargs) for step in self.transforms_list
        )
        # Update the original transform objects with fitted versions
        for i, fitted_transform in enumerate(fitted_transforms):
            self.transforms_list[i].transform = fitted_transform
        # Update the name-to-transform mapping
        self.transform_from_name = {step.name: step.transform for step in self.transforms_list}
    elif self.is_stateful:
        # Sequential execution (original behavior)
        for step in self.transforms_list:
            step.transform.fit(container.copy(deep=False), **kwargs)
    # If not stateful, no fitting is needed
    return self
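
The write-back step in fit exists because a process-based worker fits a copy of the transform, not the parent's object. The sketch below simulates this, with `copy.deepcopy` standing in for the serialization round trip that a process-based joblib backend performs. `FlagFit` is a hypothetical transform.

```python
import copy


class FlagFit:
    """Hypothetical stateful transform that records whether it was fitted."""

    def __init__(self):
        self.fitted = False

    def fit(self, data):
        self.fitted = True
        return self


original = FlagFit()
# A process-based worker receives a serialized copy; deepcopy stands in for that here
worker_copy = copy.deepcopy(original)
returned = worker_copy.fit([1, 2, 3])

print(original.fitted)  # False: the parent's object was never fitted
print(returned.fitted)  # True: the fitted state lives on the returned copy
```

This is why fit stores the returned transforms in transforms_list and rebuilds transform_from_name. With a thread-based backend no copy is made, and because fit returns self, the write-back is harmless in that case.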

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

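Returns the expected output dimensions for the TransformUnion. It first validates that the branches produce compatible dimensions. It then returns the first branch's output dimensions, with that branch's join dimension renamed to to_dim (or left as from_dims[0] when to_dim is None).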

Source code in xdflow/composite/transform_union.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Returns the expected output dimensions for the TransformUnion.
    """

    expected_output_dims = []
    for step in self.transforms_list:
        step_output_dims = step.transform.get_expected_output_dims(input_dims)
        expected_output_dims.append(step_output_dims)

    # Validate and recover the ordered shared dims (value unused; validates only)
    self._validate_and_get_shared_output_dims(expected_output_dims)

    # Determine final join dimension expected name
    join_dim_name = self.to_dim if self.to_dim is not None else self.from_dims[0]

    # Preserve the full order of the first child's output dims, but map its join dim
    # to the resolved name (the first join dim).
    first_output = expected_output_dims[0]
    first_join_dim = self.from_dims[0]
    ref_order = tuple(join_dim_name if d == first_join_dim else d for d in first_output)
    return ref_order
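
The final renaming step above can be checked in isolation. The sketch below mirrors only the last part of get_expected_output_dims and skips the shared-dimension validation. The function name and the dimension names are illustrative.

```python
def union_output_dims(first_output, from_dims, to_dim):
    """Mirror the final renaming step of TransformUnion.get_expected_output_dims."""
    join_dim_name = to_dim if to_dim is not None else from_dims[0]
    first_join_dim = from_dims[0]
    # Keep the first branch's dim order, renaming only its join dim
    return tuple(join_dim_name if d == first_join_dim else d for d in first_output)


print(union_output_dims(("trial", "channel"), ["channel", "time"], "feature"))
# ('trial', 'feature')
print(union_output_dims(("trial", "channel"), ["channel", "time"], None))
# ('trial', 'channel')
```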

UnionWithInput

UnionWithInput(transform_template: Transform, join_dim: str, to_dim: str | None = None, n_jobs: int = 1, name: str | None = None)

Bases: TransformUnion

Concatenates a transform's output with the original input along a join dimension.

This is a convenience wrapper around TransformUnion that forms a two-branch union consisting of the provided transform and an identity branch. It is equivalent to:

TransformUnion(
    transforms_list=[(name or type(transform_template).__name__.lower(), transform_template), ("identity", IdentityTransform())],
    from_dims=[join_dim, join_dim],
    to_dim=to_dim or join_dim,
    n_jobs=n_jobs,
)

Typical usage is to augment feature channels by concatenating the transform's output with the original input along the channel dimension.

Parameters:

Name Type Description Default
transform_template Transform

The transform or pipeline to apply in the non-identity branch.

required
join_dim str

The dimension name along which to concatenate both branches.

required
to_dim str | None

Optional name for the resulting join dimension. Defaults to join_dim.

None
n_jobs int

Parallelism parameter passed through to TransformUnion.

1
name str | None

Optional explicit name for the transform branch. Defaults to the lowercased class name of transform_template.

None
Source code in xdflow/composite/transform_union.py
def __init__(
    self,
    transform_template: Transform,
    join_dim: str,
    to_dim: str | None = None,
    n_jobs: int = 1,
    name: str | None = None,
):
    # Store constructor args for get_params/clone
    self.transform_template = transform_template
    self.join_dim = join_dim
    self.name = name

    step_name = name or transform_template.__class__.__name__.lower()
    transforms_list = [
        (step_name, transform_template),
        ("identity", IdentityTransform()),
    ]
    super().__init__(
        transforms_list=transforms_list,
        from_dims=[join_dim, join_dim],
        to_dim=(to_dim if to_dim is not None else join_dim),
        n_jobs=n_jobs,
    )
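
UnionWithInput picks the branch name and builds a two-branch union with the transform first and the identity second. The sketch below mirrors the naming rule from the constructor and shows the resulting concatenation on a plain list of channel values. `union_with_input` and `Squares` are hypothetical. The sketch also assumes outputs are joined in branch order.

```python
def default_step_name(transform_template, name=None):
    # Same rule as UnionWithInput.__init__: explicit name, else lowercased class name
    return name or type(transform_template).__name__.lower()


def union_with_input(transform, channels):
    """Concatenate transform(channels) with the untouched input channels."""
    # Transform branch first, identity branch second
    return transform(channels) + list(channels)


class Squares:
    def __call__(self, channels):
        return [c * c for c in channels]


print(default_step_name(Squares()))            # squares
print(default_step_name(Squares(), "sq"))      # sq
print(union_with_input(Squares(), [1, 2, 3]))  # [1, 4, 9, 1, 2, 3]
```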

Conditional Branches

SwitchTransform

SwitchTransform(choices: list[tuple[str, Transform] | TransformStep | Pipeline] | dict[str, Transform], choose: str | None = None, from_dim: str | None = None, to_dim: str | None = None)

Bases: CompositeTransform

A conditional transform that selects one of several child transforms to execute.

This acts as a placeholder in a pipeline for a step that has multiple possible implementations. The transform to run is selected by name, either through the choose constructor argument or through the choose keyword argument passed to fit or transform.

Parameters:

Name Type Description Default
choices list[tuple[str, Transform] | TransformStep | Pipeline] | dict[str, Transform]

Preferred style is a list of (name, transform) tuples, TransformSteps, or Pipelines, mirroring how Pipeline is declared. For backward compatibility, a dict[str, Transform] is also accepted.

required
choose str | None

Optional explicit selection for the switch. If provided, it must match one of the choice names. If not provided, the user must supply choose at fit/transform time.

None

Initialize SwitchTransform with multiple choice transforms.

Source code in xdflow/composite/switch_transform.py
def __init__(
    self,
    choices: list[tuple[str, Transform] | TransformStep | Pipeline] | dict[str, Transform],
    choose: str | None = None,
    from_dim: str | None = None,
    to_dim: str | None = None,
):
    """Initialize SwitchTransform with multiple choice transforms."""

    # Handle dict input for backward compatibility
    if isinstance(choices, dict):
        choices = list(choices.items())

    # Normalize inputs to a list of TransformStep objects
    normalized_steps: list[TransformStep] = []
    for item in choices:
        if isinstance(item, Pipeline):
            normalized_steps.append(TransformStep(item.name, item))
        elif isinstance(item, TransformStep):
            normalized_steps.append(item)
        elif isinstance(item, tuple):
            name, transform = item
            if isinstance(transform, Pipeline) and (name != transform.name):
                raise ValueError(f"Pipeline name ({transform.name}) must match the step name ({name})")
            normalized_steps.append(TransformStep(name, transform))
        else:
            raise ValueError(f"Invalid choice type: {type(item)}")

    if not normalized_steps:
        raise ValueError("At least one choice must be provided")

    # Ensure unique choice names
    if len(normalized_steps) != len({step.name for step in normalized_steps}):
        raise ValueError("Choice names must be unique")

    self.choices: list[TransformStep] = normalized_steps
    self.transform_from_name = {step.name: step.transform for step in self.choices}
    self.from_dim = from_dim
    self.to_dim = to_dim

    # Validate provided 'choose' if given
    if choose is not None:
        if choose not in self.transform_from_name:
            raise ValueError(
                f"Invalid choose='{choose}'. Available choices: {list(self.transform_from_name.keys())}"
            )
    self.choose = choose

    # Call parent after children are established
    super().__init__()

    # Compute input/output dims from the first choice; assert consistency later
    first_transform = self.choices[0].transform
    self.input_dims = first_transform.input_dims
    self.output_dims = first_transform.output_dims

    self._validate_composition()
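
The constructor above enforces three rules: at least one choice, unique choice names, and a valid choose value if one is given. The sketch below applies these rules to (name, transform) pairs and to the legacy dict form. `validate_choices` is a hypothetical helper. It skips the Pipeline and TransformStep inputs that the real constructor also accepts.

```python
def validate_choices(choices, choose=None):
    """Sketch of SwitchTransform's constructor checks on (name, transform) pairs."""
    if isinstance(choices, dict):
        # Legacy dict form: convert to a list of (name, transform) pairs
        choices = list(choices.items())
    if not choices:
        raise ValueError("At least one choice must be provided")
    names = [name for name, _ in choices]
    if len(names) != len(set(names)):
        raise ValueError("Choice names must be unique")
    if choose is not None and choose not in names:
        raise ValueError(f"Invalid choose={choose!r}. Available choices: {names}")
    return dict(choices), choose


table, chosen = validate_choices({"pca": "pca-transform", "ica": "ica-transform"}, choose="ica")
print(list(table), chosen)  # ['pca', 'ica'] ica
```

A dict cannot hold duplicate keys, so the uniqueness check only matters for the list form.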

children property

children: list[Transform]

Returns the transform objects from the choices list.

is_predictor property

is_predictor: bool

Returns True if the selected transform performs prediction.

predictive_transform property

predictive_transform: Transform | None

Returns the predictive transform if it exists, otherwise None.

predict

predict(container: DataContainer, **kwargs) -> DataContainer

Predicts the data using the selected child transform.

Source code in xdflow/composite/switch_transform.py
def predict(self, container: DataContainer, **kwargs) -> DataContainer:
    """Predicts the data using the selected child transform."""
    if not self.is_predictor:
        raise ValueError("SwitchTransform is not a predictor.")
    selected_transform = self._get_selected_transform(**kwargs)
    if isinstance(selected_transform, CompositeTransform):
        return self._rename_output_dim(selected_transform.predict(container, **kwargs))
    if isinstance(selected_transform, Predictor):
        return self._rename_output_dim(selected_transform.predict(container, **kwargs))
    raise TypeError(f"Selected transform '{type(selected_transform).__name__}' is not a predictor.")

predict_proba

predict_proba(container: DataContainer, **kwargs) -> DataContainer

Predicts the probabilities using the selected child transform.

Source code in xdflow/composite/switch_transform.py
def predict_proba(self, container: DataContainer, **kwargs) -> DataContainer:
    """Predicts the probabilities using the selected child transform."""
    if not self.is_predictor:
        raise ValueError("SwitchTransform is not a predictor.")
    selected_transform = self._get_selected_transform(**kwargs)
    if isinstance(selected_transform, CompositeTransform):
        return self._rename_output_dim(selected_transform.predict_proba(container, **kwargs))
    if isinstance(selected_transform, Predictor):
        return self._rename_output_dim(selected_transform.predict_proba(container, **kwargs))
    raise TypeError(f"Selected transform '{type(selected_transform).__name__}' is not a predictor.")

fit

fit(container: DataContainer, **kwargs) -> SwitchTransform

Fits the selected child transform.

Source code in xdflow/composite/switch_transform.py
def fit(self, container: DataContainer, **kwargs) -> "SwitchTransform":
    """Fits the selected child transform."""
    selected_transform = self._get_selected_transform(**kwargs)
    selected_transform.fit(container, **kwargs)
    return self

fit_transform

fit_transform(container: DataContainer, **kwargs) -> DataContainer

Fit/transform by delegating to the selected child.

If the selected child is stateful, call its fit_transform; otherwise, call its transform. This allows mixing stateful and stateless choices without requiring the switch wrapper itself to implement _fit.

Source code in xdflow/composite/switch_transform.py
def fit_transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """Fit/transform by delegating to the selected child.

    If the selected child is stateful, call its fit_transform; otherwise,
    call its transform. This allows mixing stateful and stateless choices
    without requiring the switch wrapper itself to implement _fit.
    """
    selected_transform = self._get_selected_transform(**kwargs)
    if getattr(selected_transform, "is_stateful", False):
        result = selected_transform.fit_transform(container, **kwargs)
    else:
        result = selected_transform.transform(container, **kwargs)
    return self._rename_output_dim(result)
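
The dispatch rule can be illustrated with two toy transforms. This is a minimal sketch, not xdflow code: StatelessDouble, StatefulCenter, and switch_fit_transform are hypothetical names, plain lists stand in for DataContainer, and the output-dimension renaming step is omitted.

```python
# Hypothetical toy transforms; plain lists stand in for DataContainer.
class StatelessDouble:
    """Stateless: doubles every value and learns nothing."""

    is_stateful = False

    def transform(self, data):
        return [2 * x for x in data]


class StatefulCenter:
    """Stateful: learns the mean in fit and subtracts it in transform."""

    is_stateful = True

    def __init__(self):
        self.mean_ = None

    def fit(self, data):
        self.mean_ = sum(data) / len(data)
        return self

    def transform(self, data):
        return [x - self.mean_ for x in data]

    def fit_transform(self, data):
        return self.fit(data).transform(data)


def switch_fit_transform(choices, choose, data):
    """Hypothetical helper: fit_transform for stateful choices, transform otherwise."""
    selected = choices[choose]
    if getattr(selected, "is_stateful", False):
        return selected.fit_transform(data)
    return selected.transform(data)


choices = {"double": StatelessDouble(), "center": StatefulCenter()}
print(switch_fit_transform(choices, "double", [1, 2, 3]))  # [2, 4, 6]
print(switch_fit_transform(choices, "center", [1, 2, 3]))  # [-1.0, 0.0, 1.0]
```

Because the stateless branch never calls fit, the wrapper does not need a fitting step of its own.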

get_expected_output_dims

get_expected_output_dims(input_dims: tuple[str, ...]) -> tuple[str, ...]

Determines the expected output dimensions.

For consistency, all possible choices must produce the same output dimensions for a given input. The first choice's output dimensions serve as the reference, and TransformError is raised if any other choice differs. With no choices, input_dims is returned unchanged.

Source code in xdflow/composite/switch_transform.py
def get_expected_output_dims(self, input_dims: tuple[str, ...]) -> tuple[str, ...]:
    """
    Determines the expected output dimensions.

    For consistency, this implementation requires that all possible choices
    produce the same output dimensions for a given input. It uses the first
    choice's output dimensions as a reference and raises TransformError if
    any other choice differs. With no choices, input_dims is returned as-is.
    """
    if not self.choices:
        return input_dims

    # Get the expected output from the first choice as a reference
    ref_name = self.choices[0].name
    ref_transform = self.choices[0].transform
    expected_dims = self._rename_dims_tuple(ref_transform.get_expected_output_dims(input_dims))

    # Verify that all other choices produce the same output dimensions
    for step in self.choices[1:]:
        name = step.name
        current_dims = self._rename_dims_tuple(step.transform.get_expected_output_dims(input_dims))
        if current_dims != expected_dims:
            raise TransformError(
                f"Inconsistent output dimensions in SwitchTransform. "
                f"Choice '{ref_name}' produces {expected_dims}, but "
                f"choice '{name}' produces {current_dims}. All choices must have "
                "the same output dimension signature."
            )

    return expected_dims
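
The consistency rule can be sketched without xdflow by modeling each choice as a function from input dims to output dims. This is a minimal illustration: TransformError here is a local stand-in for the library's exception, and expected_switch_dims, keep_dims, and flatten_dims are hypothetical helpers.

```python
class TransformError(Exception):
    """Stand-in for xdflow's TransformError in this sketch."""


def expected_switch_dims(input_dims, choices):
    """Return the output dims shared by every choice, or raise on a mismatch.

    `choices` is a list of (name, dims_fn) pairs, where dims_fn plays the role
    of a child's get_expected_output_dims.
    """
    if not choices:
        return input_dims  # no choices: dims pass through unchanged
    ref_name, ref_fn = choices[0]
    expected = ref_fn(input_dims)
    for name, dims_fn in choices[1:]:
        current = dims_fn(input_dims)
        if current != expected:
            raise TransformError(
                f"Choice '{ref_name}' produces {expected}, but choice '{name}' produces {current}."
            )
    return expected


def keep_dims(dims):
    return dims


def flatten_dims(dims):
    # Hypothetical flattener: keeps the first (sample) dim, merges the rest into "feature".
    return (dims[0], "feature")


print(expected_switch_dims(("trial", "channel"), [("a", keep_dims), ("b", keep_dims)]))
# ('trial', 'channel')
try:
    expected_switch_dims(("trial", "channel", "time"), [("keep", keep_dims), ("flat", flatten_dims)])
except TransformError as err:
    print(err)
```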

OptionalTransform

OptionalTransform(transform_template: Transform, choose: str | None = None, use: bool | None = None, name: str | None = None, skip_name: str = 'identity', identity_rename: dict[str, str] | None = None)

Bases: SwitchTransform

Optionally apply a transform or skip it entirely (identity behavior).

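This is a convenience wrapper over SwitchTransform that defines two choices:

  • the transform branch, which applies transform_template;
  • the identity branch, labeled skip_name ("identity" by default), which applies IdentityTransform (a no-op), or RenameDimsTransform when identity_rename is given.

The transform branch is labeled name if given, otherwise the wrapped transform's own name attribute, otherwise its lowercased class name.

You can control selection with either:

  • use=True or use=False, or
  • choose set to the transform branch label or to skip_name.

If both are given, choose takes precedence; if neither is given, the transform branch is applied.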

Note

For this to be valid within a statically validated pipeline, the wrapped transform should preserve the dimension signature. Otherwise, the two choices would yield different output dims and violate the validation requirement that choices share the same output dims.

Parameters:

Name Type Description Default
transform_template Transform

The transform or pipeline to optionally apply.

required
choose str | None

Optional explicit selection: the transform branch label or skip_name.

None
use bool | None

Optional boolean shorthand for choose.

None

Initialize an OptionalTransform.

Parameters:

Name Type Description Default
transform_template Transform

The wrapped transform/pipeline to optionally apply.

required
choose str | None

Explicit choice label to select the branch. If provided, must be either the transform branch name or skip_name.

None
use bool | None

Boolean shorthand; if provided, maps to choose with transform branch name when True and skip_name when False.

None
name str | None

Optional label for the transform branch. Defaults to the wrapped transform's name attribute if it has one, otherwise to the lowercased class name of transform_template.

None
skip_name str

Label for the identity branch. Defaults to "identity".

'identity'
identity_rename dict[str, str] | None

Optional mapping of coordinate names to rename ONLY when the identity branch is selected, e.g., {"old_coord": "new_coord"}. This renames coordinates without altering dimension names.

None
Source code in xdflow/composite/switch_transform.py
def __init__(
    self,
    transform_template: Transform,
    choose: str | None = None,
    use: bool | None = None,
    name: str | None = None,
    skip_name: str = "identity",
    identity_rename: dict[str, str] | None = None,
):
    """
    Initialize an OptionalTransform.

    Args:
        transform_template: The wrapped transform/pipeline to optionally apply.
        choose: Explicit choice label to select the branch. If provided, must
            be either the transform branch name or `skip_name`.
        use: Boolean shorthand; if provided, maps to `choose` with
            transform branch name when True and `skip_name` when False.
        name: Optional label for the transform branch. Defaults to the
            wrapped transform's `name` attribute, then to the lowercased
            class name of `transform_template`.
        skip_name: Label for the identity branch. Defaults to "identity".
        identity_rename: Optional mapping of coordinate names to rename ONLY when
            the identity branch is selected, e.g., {"old_coord": "new_coord"}.
            This renames coordinates without altering dimension names.
    """
    # Store original constructor args for easy cloning
    self.transform_template = transform_template
    self.choose = choose
    self.use = use
    self.name = name
    self.skip_name = skip_name
    self.identity_rename = identity_rename

    # Prefer an explicit name; otherwise prefer the wrapped transform's own name
    # (e.g., Pipeline(name=...)); finally fall back to the class name.
    explicit_name = name
    auto_name_from_transform = getattr(transform_template, "name", None)
    transform_choice_name = (
        explicit_name or auto_name_from_transform or transform_template.__class__.__name__.lower()
    )

    if choose is None and use is not None:
        resolved_choose = transform_choice_name if use else self.skip_name
    elif choose is None and use is None:
        # Default to applying the transform when not specified explicitly
        resolved_choose = transform_choice_name
    else:
        resolved_choose = choose

    # Create identity branch - either plain IdentityTransform or with renaming
    if identity_rename:
        identity_transform = RenameDimsTransform(rename_map=identity_rename)
    else:
        identity_transform = IdentityTransform()

    super().__init__(
        choices=[(transform_choice_name, transform_template), (self.skip_name, identity_transform)],
        choose=resolved_choose,
    )
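
The label resolution in the constructor above can be reproduced as a small pure function. This is a sketch that mirrors the logic shown; resolve_optional_choice, Scaler, and Preproc are hypothetical names used only for illustration.

```python
def resolve_optional_choice(transform, choose=None, use=None, name=None, skip_name="identity"):
    """Return (transform_branch_label, selected_label) as the constructor resolves them."""
    branch = name or getattr(transform, "name", None) or type(transform).__name__.lower()
    if choose is None and use is not None:
        selected = branch if use else skip_name
    elif choose is None:
        selected = branch  # neither given: apply the wrapped transform
    else:
        selected = choose  # an explicit choose wins, even if use is also given
    return branch, selected


class Scaler:  # hypothetical transform with no `name` attribute
    pass


class Preproc:  # hypothetical pipeline that exposes a `name`
    name = "preproc"


print(resolve_optional_choice(Scaler()))                                   # ('scaler', 'scaler')
print(resolve_optional_choice(Scaler(), use=False))                        # ('scaler', 'identity')
print(resolve_optional_choice(Preproc(), use=True))                        # ('preproc', 'preproc')
print(resolve_optional_choice(Scaler(), name="zscore", choose="identity"))  # ('zscore', 'identity')
```

Sweeping use over True and False selects the two branches without the caller needing to know the transform branch label.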

Ensembles

EnsembleMember dataclass

EnsembleMember(name: str, transform: Transform, weight: float = 1.0)

Represents a member of an ensemble with its name, transform, and weight.

Attributes:

Name Type Description
name str

The name of the ensemble member.

transform Transform

The member's transform (EnsemblePredictor requires a Predictor or a CompositeTransform that acts as a predictor).

weight float

The weight for this member in the ensemble.

predictor property

predictor: Predictor

Return the predictive component for this member's transform.

EnsemblePredictor

EnsemblePredictor(members: list[tuple[str, Predictor] | EnsembleMember | TransformStep | Transform], sample_dim: str, target_coord: str, encoder: LabelEncoder | None = None, weights: list[float] | None = None, weighting_strategy: Literal['uniform', 'score_based', 'custom'] = 'uniform', scoring_func: Callable = accuracy_score, scoring_transform_func: Callable[[float], float] | None = None, normalize_weights: bool = True, normalize_outputs: bool = True, n_jobs: int = 1, calibration_container: DataContainer | None = None, proba: bool = False, sel: dict | None = None, drop_sel: dict | None = None, **kwargs)

Bases: CompositeTransform, Predictor

An ensemble predictor that combines multiple predictors using weighted averaging.

This predictor applies multiple child predictors to the same input and combines their outputs by weighted averaging.

Features:

  • Multiple weighting strategies (uniform, score-based, custom)
  • Parallel execution support (n_jobs)
  • Score-based weighting with customizable scoring functions
  • Input validation and error handling
  • Ensembling of both predict and predict_proba outputs

Parameters:

Name Type Description Default
members list[tuple[str, Predictor] | EnsembleMember | TransformStep | Transform]

List of (name, predictor) tuples, EnsembleMember objects, TransformStep objects, or bare Transform objects (named member_{i} by position)

required
sample_dim str

Name of the sample dimension

required
target_coord str

Name of the target coordinate

required
encoder LabelEncoder | None

Optional label encoder for the predictor

None
weights list[float] | None

Optional explicit weights for the members (overrides weighting_strategy)

None
weighting_strategy Literal['uniform', 'score_based', 'custom']

Strategy for determining weights ('uniform', 'score_based', 'custom')

'uniform'
scoring_func Callable

Function to use for score-based weighting (default: accuracy_score)

accuracy_score
scoring_transform_func Callable[[float], float] | None

Function applied to scores before they are used as weights (defaults to the identity function)

None
normalize_weights bool

Whether to normalize weights to sum to 1

True
normalize_outputs bool

Whether to normalize final ensemble outputs

True
n_jobs int

Number of parallel jobs for execution

1
calibration_container DataContainer | None

Optional container for score-based weighting calibration

None
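
The score-based weighting code is not shown on this page. As a minimal sketch, assuming each member is scored once on calibration data with scoring_func, the score is passed through scoring_transform_func, and normalize_weights rescales the results to sum to 1, the parameters interact roughly like this (score_based_weights and sharpen are hypothetical helpers, and the scores are made up):

```python
def _identity(score):
    return score  # identity, matching the documented default for scoring_transform_func


def score_based_weights(scores, scoring_transform_func=None, normalize_weights=True):
    """Turn per-member calibration scores into ensemble weights (illustrative only)."""
    if scoring_transform_func is None:
        scoring_transform_func = _identity
    raw = [scoring_transform_func(s) for s in scores]
    if not normalize_weights:
        return raw
    total = sum(raw)
    if total <= 0:
        raise ValueError("Cannot normalize weights whose sum is not positive")
    return [w / total for w in raw]


def sharpen(score):
    return score ** 2  # squaring shifts weight toward the stronger members


# Hypothetical accuracies of three members on the calibration data
scores = [0.9, 0.8, 0.5]
print(score_based_weights(scores))                                   # ≈ [0.409, 0.364, 0.227]
print(score_based_weights(scores, scoring_transform_func=sharpen))
```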

Initialize EnsemblePredictor with ensemble members and configuration.

Parameters:

Name Type Description Default
members list[tuple[str, Predictor] | EnsembleMember | TransformStep | Transform]

List of ensemble members in various formats

required
sample_dim str

Name of the sample dimension

required
target_coord str

Name of the target coordinate

required
encoder LabelEncoder | None

Optional label encoder for the predictor

None
weights list[float] | None

Optional explicit weights (overrides weighting_strategy)

None
weighting_strategy Literal['uniform', 'score_based', 'custom']

How to determine member weights

'uniform'
scoring_func Callable

Function for score-based weighting evaluation

accuracy_score
scoring_transform_func Callable[[float], float] | None

Transform function applied to scores (defaults to identity function)

None
normalize_weights bool

Whether to normalize weights to sum to 1

True
normalize_outputs bool

Whether to normalize final outputs

True
n_jobs int

Number of parallel jobs to use

1
calibration_container DataContainer | None

Data for score-based weight calibration

None
proba bool

Whether to return probabilities by default

False
sel dict | None

Optional selection to apply before predicting

None
drop_sel dict | None

Optional drop selection to apply before predicting

None
Source code in xdflow/composite/ensemble.py
def __init__(
    self,
    members: list[tuple[str, Predictor] | EnsembleMember | TransformStep | Transform],
    sample_dim: str,
    target_coord: str,
    encoder: LabelEncoder | None = None,
    weights: list[float] | None = None,
    weighting_strategy: Literal["uniform", "score_based", "custom"] = "uniform",
    scoring_func: Callable = accuracy_score,
    scoring_transform_func: Callable[[float], float] | None = None,
    normalize_weights: bool = True,
    normalize_outputs: bool = True,
    n_jobs: int = 1,
    calibration_container: DataContainer | None = None,
    proba: bool = False,
    sel: dict | None = None,
    drop_sel: dict | None = None,
    **kwargs,
):
    """
    Initialize EnsemblePredictor with ensemble members and configuration.

    Args:
        members: List of ensemble members in various formats
        sample_dim: Name of the sample dimension
        target_coord: Name of the target coordinate
        encoder: Optional label encoder for the predictor
        weights: Optional explicit weights (overrides weighting_strategy)
        weighting_strategy: How to determine member weights
        scoring_func: Function for score-based weighting evaluation
        scoring_transform_func: Transform function applied to scores (defaults to identity function)
        normalize_weights: Whether to normalize weights to sum to 1
        normalize_outputs: Whether to normalize final outputs
        n_jobs: Number of parallel jobs to use
        calibration_container: Data for score-based weight calibration
        proba: Whether to return probabilities by default
        sel: Optional selection to apply before predicting
        drop_sel: Optional drop selection to apply before predicting
    """

    # Set default for scoring_transform_func
    if scoring_transform_func is None:
        scoring_transform_func = _identity_transform

    # Normalize inputs to EnsembleMember objects
    self.members: list[EnsembleMember] = []
    initial_weights = []
    for i, member in enumerate(members):
        if isinstance(member, EnsembleMember):
            self.members.append(member)
            initial_weights.append(member.weight)
        elif isinstance(member, (TransformStep, tuple, Transform)):
            if isinstance(member, tuple):
                name, transform = member
            elif isinstance(member, TransformStep):
                name = member.name
                transform = member.transform
            else:
                name = f"member_{i}"
                transform = member

            if isinstance(transform, CompositeTransform):
                if not transform.is_predictor:
                    raise ValueError(f"Member '{name}' is/contains a CompositeTransform that is not a predictor")
                self.members.append(EnsembleMember(name=name, transform=transform))
            else:
                if not isinstance(transform, Predictor):
                    raise ValueError(f"Member '{name}' must be/contain a Predictor, got {type(transform)}")
                self.members.append(EnsembleMember(name=name, transform=transform))  # predictor same as transform
            initial_weights.append(1.0)
        else:
            raise ValueError(f"Invalid member type: {type(member)}")

    if not self.members:
        raise ValueError("At least one ensemble member must be provided")

    # Store configuration
    self.weighting_strategy = weighting_strategy
    self.scoring_func = scoring_func
    self.scoring_transform_func = scoring_transform_func
    self.normalize_weights = normalize_weights
    self.normalize_outputs = normalize_outputs
    self.n_jobs = n_jobs
    self.calibration_container = calibration_container

    # Set weights
    if weights is not None:
        if len(weights) != len(self.members):
            raise ValueError(
                f"Number of weights ({len(weights)}) must match number of members ({len(self.members)})"
            )
        self._set_weights(weights)
    else:
        # Use initial weights from members or uniform weights
        self._set_weights(initial_weights)

    # Determine if this is a classifier based on first member
    first_predictor = self.members[0].predictor
    is_classifier = first_predictor.is_classifier

    # Initialize parent Predictor
    super().__init__(
        sample_dim=sample_dim,
        target_coord=target_coord,
        is_classifier=is_classifier,
        encoder=encoder,
        proba=proba,
        sel=sel,
        drop_sel=drop_sel,
        **kwargs,
    )

    self._validate_composition()
    self._ensure_shared_encoders()  # Handle pre-fitted models at initialization

    # set is_fitted
    self._is_fitted = self.check_is_fitted()
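
The member-normalization loop above can be sketched with simplified stand-ins. This sketch omits the predictor checks and encoder handling; Member, Step, and normalize_members are hypothetical names, not the library's EnsembleMember or TransformStep.

```python
from dataclasses import dataclass


@dataclass
class Member:
    """Simplified stand-in for EnsembleMember."""

    name: str
    transform: object
    weight: float = 1.0


@dataclass
class Step:
    """Simplified stand-in for TransformStep."""

    name: str
    transform: object


def normalize_members(members, weights=None):
    """Coerce mixed member formats into Member objects and resolve their weights."""
    out = []
    for i, member in enumerate(members):
        if isinstance(member, Member):
            out.append(member)  # keeps its own weight
        elif isinstance(member, tuple):
            name, transform = member
            out.append(Member(name, transform))
        elif isinstance(member, Step):
            out.append(Member(member.name, member.transform))
        else:
            out.append(Member(f"member_{i}", member))  # bare transform: positional name
    if not out:
        raise ValueError("At least one ensemble member must be provided")
    if weights is not None:
        if len(weights) != len(out):
            raise ValueError("Number of weights must match number of members")
        resolved = list(weights)  # explicit weights override per-member weights
    else:
        resolved = [m.weight for m in out]
    return out, resolved


members, weights = normalize_members(
    [("lda", object()), Step("svm", object()), Member("rf", object(), weight=2.0), object()]
)
print([m.name for m in members])  # ['lda', 'svm', 'rf', 'member_3']
print(weights)                    # [1.0, 1.0, 2.0, 1.0]
```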

children property

children: list[Transform]

Returns the transform objects from the ensemble members.

check_is_fitted

check_is_fitted() -> bool

Returns True only if every member's predictor is fitted.

Source code in xdflow/composite/ensemble.py
def check_is_fitted(self) -> bool:
    """Return True only if every member's predictor is fitted."""
    for member in self.members:
        if not getattr(member.predictor, "_is_fitted", False):
            return False
    return True

prepare_for_inference

prepare_for_inference(*, set_n_jobs_single: bool = True) -> None

Disable training-time options that slow down per-request inference.

Parameters:

Name Type Description Default
set_n_jobs_single bool

When True, set n_jobs to 1 on the ensemble itself and request single-threaded execution from every member.

True
Source code in xdflow/composite/ensemble.py
def prepare_for_inference(self, *, set_n_jobs_single: bool = True) -> None:
    """
    Disable training-time options that slow down per-request inference.

    Args:
        set_n_jobs_single: When True, force single-threaded execution for members.
    """
    if set_n_jobs_single and hasattr(self, "n_jobs"):
        try:
            self.n_jobs = 1
        except AttributeError:
            pass

    visited: set[int] = {id(self)}
    for member in self.members:
        _configure_transform_for_inference(
            member.transform,
            set_n_jobs_single=set_n_jobs_single,
            visited=visited,
        )
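
The body of _configure_transform_for_inference is not shown on this page. A minimal sketch of the visited-set idea, assuming it descends recursively into child transforms, looks like this; Node and configure_for_inference are hypothetical names. Tracking id() values means a child shared by several parents, or a reference cycle, is configured exactly once.

```python
class Node:
    """Hypothetical transform with an n_jobs setting and optional children."""

    def __init__(self, n_jobs=4, children=None):
        self.n_jobs = n_jobs
        self.children = list(children or [])


def configure_for_inference(node, visited):
    """Force n_jobs=1 on node and everything below it, visiting each object once."""
    if id(node) in visited:
        return  # already handled (shared child or cycle)
    visited.add(id(node))
    if hasattr(node, "n_jobs"):
        node.n_jobs = 1
    for child in getattr(node, "children", []):
        configure_for_inference(child, visited)


shared = Node(n_jobs=8)
root = Node(children=[Node(children=[shared]), shared])  # `shared` is reachable twice
configure_for_inference(root, visited=set())
print(root.n_jobs, root.children[0].n_jobs, shared.n_jobs)  # 1 1 1
```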

fit_transform

fit_transform(container: DataContainer, **kwargs) -> DataContainer

Not implemented for EnsemblePredictor; this method always raises NotImplementedError.

Parameters:

Name Type Description Default
container DataContainer

DataContainer to fit on (unused)

required
**kwargs

Additional context/parameters (unused)

{}

Raises:

Type Description
NotImplementedError

Always raised.

Source code in xdflow/composite/ensemble.py
def fit_transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Not implemented for EnsemblePredictor.

    Args:
        container: DataContainer to fit on (unused).
        **kwargs: Additional context/parameters (unused).

    Raises:
        NotImplementedError: Always.
    """

    raise NotImplementedError("fit_transform is not implemented for EnsemblePredictor, should not be needed")

transform

transform(container: DataContainer, **kwargs) -> DataContainer

Emits a UserWarning and delegates to predict; call predict directly instead.

Source code in xdflow/composite/ensemble.py
def transform(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Emit a UserWarning and delegate to predict; call predict directly instead.
    """
    warnings.warn(
        "using predict instead of transform. transform should not be needed/used for EnsemblePredictor",
        UserWarning,
        stacklevel=2,
    )
    return self.predict(container, **kwargs)

predict

predict(container: DataContainer, **kwargs) -> DataContainer

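Predict labels with the ensemble. For classifiers, the target coordinate is encoded once with the ensemble's fitted encoder before member predictions are combined; a ValueError is raised if no fitted encoder is available.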

Parameters:

Name Type Description Default
container DataContainer

DataContainer to predict on

required
**kwargs

Additional context/parameters passed through

{}

Returns:

Type Description
DataContainer

DataContainer with predictions

Source code in xdflow/composite/ensemble.py
def predict(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Predict labels using ensemble, leveraging shared encoding optimization.

    Args:
        container: DataContainer to predict on
        **kwargs: Additional context/parameters passed through

    Returns:
        DataContainer with predictions
    """
    # Apply selection using base helper
    container = self._apply_selection(container)

    # Encode target coordinate if classifier
    if self.is_classifier:
        if self.encoder is None or not hasattr(self.encoder, "classes_"):
            raise ValueError(f"{self.__class__.__name__} requires a fitted encoder before calling predict.")
        encoded_da = self._encode_target_coord(container.data)
        encoded_container = DataContainer(encoded_da)
    else:
        encoded_container = container

    # Delegate to encoded entry point for ensemble logic
    return self._predict_from_encoded(encoded_container, original_container=container, **kwargs)

predict_proba

predict_proba(container: DataContainer, **kwargs) -> DataContainer

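Predict class probabilities with the ensemble. Only valid for classifiers (AttributeError otherwise). The target coordinate is encoded once with the ensemble's fitted encoder (ValueError if none is available) before member probabilities are combined.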

Parameters:

Name Type Description Default
container DataContainer

DataContainer to predict on

required
**kwargs

Additional context/parameters passed through

{}

Returns:

Type Description
DataContainer

DataContainer with class probabilities

Source code in xdflow/composite/ensemble.py
def predict_proba(self, container: DataContainer, **kwargs) -> DataContainer:
    """
    Predict class probabilities using ensemble, leveraging shared encoding optimization.

    Args:
        container: DataContainer to predict on
        **kwargs: Additional context/parameters passed through

    Returns:
        DataContainer with class probabilities
    """
    if not self.is_classifier:
        raise AttributeError(
            f"'{self.__class__.__name__}' has not been instantiated as a classifier "
            "(is_classifier=False) so should not call the 'predict_proba' method."
        )

    # Apply selection using base helper
    container = self._apply_selection(container)

    # Encode target coordinate
    if self.encoder is None or not hasattr(self.encoder, "classes_"):
        raise ValueError(f"{self.__class__.__name__} requires a fitted encoder before calling predict_proba.")
    encoded_da = self._encode_target_coord(container.data)
    encoded_container = DataContainer(encoded_da)

    # Delegate to encoded entry point for ensemble logic
    return self._predict_proba_from_encoded(encoded_container, original_container=container, **kwargs)
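
The helpers behind _predict_proba_from_encoded (_ensemble_proba, _align_proba_to_global) are not shown on this page. A minimal NumPy sketch of the general idea, assuming a weighted mean over members after mapping each member's class columns onto the encoder's global class order, is below; weighted_ensemble_proba is a hypothetical helper, and it ignores normalize_outputs.

```python
import numpy as np


def weighted_ensemble_proba(member_probas, member_classes, global_classes, weights):
    """Weighted average of member probabilities after aligning class columns.

    member_probas:  list of (n_samples, n_member_classes) arrays
    member_classes: per-member class labels, in each member's column order
    global_classes: full ordered label list (e.g. a fitted encoder's classes_)
    weights:        one non-negative weight per member
    """
    index = {label: j for j, label in enumerate(global_classes)}
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    n_samples = member_probas[0].shape[0]
    out = np.zeros((n_samples, len(global_classes)))
    for wi, proba, classes in zip(w, member_probas, member_classes):
        cols = [index[c] for c in classes]
        out[:, cols] += wi * proba  # classes a member never saw contribute 0
    return out


p1 = np.array([[0.7, 0.3]])       # member 1 columns: ["a", "b"]
p2 = np.array([[0.2, 0.5, 0.3]])  # member 2 columns: ["b", "a", "c"]
avg = weighted_ensemble_proba([p1, p2], [["a", "b"], ["b", "a", "c"]], ["a", "b", "c"], [1.0, 1.0])
print(avg)  # ≈ [[0.6, 0.25, 0.15]]
```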

predict_proba_with_uncertainty_components

predict_proba_with_uncertainty_components(container: DataContainer, **kwargs) -> tuple[DataContainer, DataContainer, DataContainer]

Predict class probabilities and entropy-based aleatoric/epistemic uncertainty components.

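The two components are:

A = E_w[H(p_i)]                     (aleatoric)
B = H(E_w[p_i]) - E_w[H(p_i)]       (epistemic)

where p_i is member i's class-probability vector for a sample, E_w is the average over members weighted by the ensemble weights (normalized to sum to 1), and H is the Shannon entropy. A is the average entropy of the individual members; B is the additional entropy that comes from members disagreeing, so A + B = H(E_w[p_i]).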

Parameters:

Name Type Description Default
container DataContainer

DataContainer to predict on.

required
**kwargs

Additional context/parameters passed through to member predictors.

{}

Returns:

Type Description
tuple[DataContainer, DataContainer, DataContainer]

Tuple of:

  • DataContainer with class probabilities (same as predict_proba)
  • DataContainer with aleatoric uncertainty (A), one score per sample
  • DataContainer with epistemic uncertainty (B), one score per sample

Source code in xdflow/composite/ensemble.py
def predict_proba_with_uncertainty_components(
    self, container: DataContainer, **kwargs
) -> tuple[DataContainer, DataContainer, DataContainer]:
    """
    Predict class probabilities and entropy-based aleatoric/epistemic uncertainty components.

    The two components are:

        A = E_w[H(p_i)]                     (aleatoric)
        B = H(E_w[p_i]) - E_w[H(p_i)]       (epistemic)

    Args:
        container: DataContainer to predict on.
        **kwargs: Additional context/parameters passed through to member predictors.

    Returns:
        Tuple of:
            - DataContainer with class probabilities (same as predict_proba)
            - DataContainer with aleatoric uncertainty (A), one score per sample
            - DataContainer with epistemic uncertainty (B), one score per sample
    """
    if not self.is_classifier:
        raise AttributeError(
            f"'{self.__class__.__name__}' has not been instantiated as a classifier "
            "(is_classifier=False) so should not call 'predict_proba_with_uncertainty_components'."
        )

    # Apply selection using base helper
    container = self._apply_selection(container)

    # Encode target coordinate
    if self.encoder is None or not hasattr(self.encoder, "classes_"):
        raise ValueError(
            f"{self.__class__.__name__} requires a fitted encoder before calling "
            "predict_proba_with_uncertainty_components."
        )
    encoded_da = self._encode_target_coord(container.data)
    encoded_container = DataContainer(encoded_da)

    # Collect per-member probability predictions (aligned via shared encoder)
    prob_predictions = self._collect_member_predictions(
        encoded_container=encoded_container,
        original_container=container,
        proba=True,
        **kwargs,
    )
    prob_data_arrays = [pred.data for pred in prob_predictions]
    if not prob_data_arrays:
        raise ValueError("No probability outputs to compute uncertainty from")

    # Aggregate probability outputs (same as predict_proba)
    ensemble_proba = self._ensemble_proba(prob_data_arrays)

    # Extract class labels and encode them for alignment
    class_coord = ensemble_proba.coords.get("class")
    if class_coord is not None:
        class_labels = np.asarray(class_coord.values)
        if self.encoder is None:
            raise ValueError(f"{self.__class__.__name__} requires an encoder for classifier probabilities.")
        encoded_classes = self.encoder.transform(class_labels)
    else:
        encoded_classes = np.arange(ensemble_proba.shape[-1])

    aligned_proba = self._align_proba_to_global(ensemble_proba.values, encoded_classes)

    # Build probability DataContainer
    output_coords = self._get_output_coords(container.data)
    if self.encoder is None:
        raise ValueError(f"{self.__class__.__name__} requires an encoder for classifier probabilities.")
    output_coords["class"] = self.encoder.classes_

    proba_da = xr.DataArray(
        aligned_proba,
        dims=(self.sample_dim, "class"),
        coords=output_coords,
        attrs=container.data.attrs,
    )
    proba_container = DataContainer(proba_da)

    # Compute entropy-based aleatoric and epistemic uncertainty components
    sample_dim = self.sample_dim
    class_dim = "class"
    probs = np.stack(
        [da.transpose(sample_dim, class_dim).values for da in prob_data_arrays],
        axis=0,
    )

    weights = np.asarray(self.weights, dtype=float)
    if weights.shape[0] != probs.shape[0]:
        raise ValueError(
            f"Number of weights ({weights.shape[0]}) must match number of members ({probs.shape[0]}) "
            "when computing uncertainty."
        )
    weights = weights / weights.sum()

    aleatoric_scores, epistemic_scores = _entropy_uncertainty_components(probs, weights)

    # Build uncertainty DataContainers with one score per sample
    uncertainty_coords = self._get_output_coords(container.data)
    aleatoric_da = xr.DataArray(
        aleatoric_scores,
        dims=[self.sample_dim],
        coords=uncertainty_coords,
        attrs=container.data.attrs,
        name="aleatoric_uncertainty",
    )
    epistemic_da = xr.DataArray(
        epistemic_scores,
        dims=[self.sample_dim],
        coords=uncertainty_coords,
        attrs=container.data.attrs,
        name="epistemic_uncertainty",
    )

    return proba_container, DataContainer(aleatoric_da), DataContainer(epistemic_da)
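
_entropy_uncertainty_components is not shown on this page. The decomposition it is documented to compute can be sketched in NumPy, assuming natural-log entropy (the library's log base is not stated here); entropy and entropy_uncertainty are hypothetical helpers. Sample 0 below is pure aleatoric uncertainty (both members unsure but agreeing), sample 1 is mostly epistemic (both members confident but disagreeing).

```python
import numpy as np


def entropy(p, axis=-1, eps=1e-12):
    """Shannon entropy in nats along `axis`; eps keeps log() finite when p == 0."""
    return -np.sum(p * np.log(p + eps), axis=axis)


def entropy_uncertainty(probs, weights):
    """Split ensemble uncertainty into aleatoric (A) and epistemic (B) parts.

    probs:   array of shape (n_members, n_samples, n_classes)
    weights: one weight per member; normalized here to sum to 1
    """
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    member_entropy = entropy(probs)              # (n_members, n_samples)
    aleatoric = w @ member_entropy               # A = E_w[H(p_i)]
    mean_proba = np.tensordot(w, probs, axes=1)  # E_w[p_i], (n_samples, n_classes)
    epistemic = entropy(mean_proba) - aleatoric  # B = H(E_w[p_i]) - A
    return aleatoric, epistemic


probs = np.array([
    [[0.5, 0.5], [0.9, 0.1]],  # member 1: sample 0, sample 1
    [[0.5, 0.5], [0.1, 0.9]],  # member 2: sample 0, sample 1
])
A, B = entropy_uncertainty(probs, weights=[1.0, 1.0])
print(A)  # sample 0: both members unsure -> high A; sample 1: both confident -> lower A
print(B)  # sample 0: members agree -> B ≈ 0; sample 1: members disagree -> B > 0
```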

predict_proba_with_std

predict_proba_with_std(container: DataContainer, *, return_stderr: bool = False, **kwargs) -> tuple[DataContainer, DataContainer]

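Predict class probabilities along with the standard deviation or standard error across ensemble members. The spread is the unweighted population standard deviation (ddof=0) of the member probabilities; with return_stderr=True it is divided by the square root of the number of members.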

Parameters:

Name Type Description Default
container DataContainer

DataContainer to predict on.

required
return_stderr bool

When True, return standard error instead of standard deviation.

False
**kwargs

Additional context/parameters passed through to member predictors.

{}

Returns:

Type Description
tuple[DataContainer, DataContainer]

Tuple of:

  • DataContainer with class probabilities (same as predict_proba)
  • DataContainer with the standard deviation (or standard error) per sample and class

Source code in xdflow/composite/ensemble.py
def predict_proba_with_std(
    self, container: DataContainer, *, return_stderr: bool = False, **kwargs
) -> tuple[DataContainer, DataContainer]:
    """
    Predict class probabilities along with the standard deviation or standard error across ensemble members.

    Args:
        container: DataContainer to predict on.
        return_stderr: When True, return standard error instead of standard deviation.
        **kwargs: Additional context/parameters passed through to member predictors.

    Returns:
        Tuple of:
            - DataContainer with class probabilities (same as predict_proba)
            - DataContainer with standard deviation (or standard error) per sample/class
    """
    if not self.is_classifier:
        raise AttributeError(
            f"'{self.__class__.__name__}' has not been instantiated as a classifier "
            "(is_classifier=False) so should not call 'predict_proba_with_std'."
        )

    # Apply selection using base helper
    container = self._apply_selection(container)

    # Encode target coordinate
    if self.encoder is None or not hasattr(self.encoder, "classes_"):
        raise ValueError(
            f"{self.__class__.__name__} requires a fitted encoder before calling predict_proba_with_std."
        )
    encoded_da = self._encode_target_coord(container.data)
    encoded_container = DataContainer(encoded_da)

    # Collect per-member probability predictions (aligned via shared encoder)
    prob_predictions = self._collect_member_predictions(
        encoded_container=encoded_container,
        original_container=container,
        proba=True,
        **kwargs,
    )
    prob_data_arrays = [pred.data for pred in prob_predictions]
    if not prob_data_arrays:
        raise ValueError("No probability outputs to compute uncertainty from")

    # Aggregate probability outputs (same as predict_proba)
    ensemble_proba = self._ensemble_proba(prob_data_arrays)

    # Extract class labels and encode them for alignment
    class_coord = ensemble_proba.coords.get("class")
    if class_coord is not None:
        class_labels = np.asarray(class_coord.values)
        if self.encoder is None:
            raise ValueError(f"{self.__class__.__name__} requires an encoder for classifier probabilities.")
        encoded_classes = self.encoder.transform(class_labels)
    else:
        encoded_classes = np.arange(ensemble_proba.shape[-1])

    aligned_proba = self._align_proba_to_global(ensemble_proba.values, encoded_classes)

    # Build probability DataContainer
    output_coords = self._get_output_coords(container.data)
    if self.encoder is None:
        raise ValueError(f"{self.__class__.__name__} requires an encoder for classifier probabilities.")
    output_coords["class"] = self.encoder.classes_

    proba_da = xr.DataArray(
        aligned_proba,
        dims=(self.sample_dim, "class"),
        coords=output_coords,
        attrs=container.data.attrs,
    )
    proba_container = DataContainer(proba_da)

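    # Compute standard deviation (or standard error) across members,
    # with each member's output ordered as (sample_dim, "class") to match std_da
    probs = np.stack(
        [da.transpose(self.sample_dim, "class").values for da in prob_data_arrays],
        axis=0,
    )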
    std = probs.std(axis=0)
    if return_stderr:
        std = std / np.sqrt(probs.shape[0])

    std_da = xr.DataArray(
        std,
        dims=(self.sample_dim, "class"),
        coords=output_coords,
        attrs=container.data.attrs,
        name="proba_std_error" if return_stderr else "proba_std",
    )
    return proba_container, DataContainer(std_da)
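
The spread computation itself reduces to two NumPy calls. A minimal sketch with made-up member outputs that already share one class order (the method above obtains them from _collect_member_predictions):

```python
import numpy as np

# Hypothetical member probabilities, shape (n_members, n_samples, n_classes)
probs = np.array([
    [[0.8, 0.2]],
    [[0.6, 0.4]],
    [[0.7, 0.3]],
])
n_members = probs.shape[0]

std = probs.std(axis=0)            # unweighted, ddof=0: spread around the plain member mean
stderr = std / np.sqrt(n_members)  # what return_stderr=True reports
print(std)     # ≈ [[0.0816 0.0816]]
print(stderr)  # ≈ [[0.0471 0.0471]]
```

Note that the spread ignores the ensemble weights: a member with a small weight still counts fully toward std, even though it barely moves the averaged probabilities.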