Skip to content

Utilities API

Utilities cover caching, sampling, target-coordinate handling, spectral helpers, sample weights, and plotting helpers.

Cache Utilities

cache_result

cache_result(prefix: str, max_size: int = DEFAULT_MAX_CACHE_SIZE, max_age_days: float = DEFAULT_MAX_CACHE_AGE, key_gen_func: Callable | None = None) -> Callable

Decorator that caches function results based on all function and class instance parameters.

Parameters:

Name Type Description Default
prefix str

Prefix for the cache directory (e.g., 'preprocess', 'featurize')

required
max_size int

Maximum cache size in bytes for this prefix

DEFAULT_MAX_CACHE_SIZE
max_age_days float

Maximum age of cache files in days

DEFAULT_MAX_CACHE_AGE
key_gen_func Callable | None

Optional function to generate the cache key dictionary. If None, a default key generation logic is used.

None

Returns:

Name Type Description
Callable Callable

Decorated function

Source code in xdflow/utils/cache_utils.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def cache_result(
    prefix: str,
    max_size: int = DEFAULT_MAX_CACHE_SIZE,
    max_age_days: float = DEFAULT_MAX_CACHE_AGE,
    key_gen_func: Callable | None = None,
) -> Callable:
    """
    Decorator that caches function results based on all function and class instance parameters.

    Results are pickled to ``_get_cache_dir(prefix) / "<hash_dict(key)>.pkl"``. If the first
    positional argument has a ``__dict__`` it is treated as the bound instance, and caching is
    skipped entirely unless that instance has a truthy ``use_cache`` attribute.

    Args:
        prefix: Prefix for the cache directory (e.g., 'preprocess', 'featurize')
        max_size: Maximum cache size in bytes for this prefix
        max_age_days: Maximum age of cache files in days
        key_gen_func: Optional function to generate the cache key dictionary. It is called as
                      ``key_gen_func(func, *args, **kwargs)`` and only for instance calls.
                      If None, a default key generation logic is used.

    Returns:
        Callable: Decorated function

    Notes:
        - Unreadable, truncated, or stale cache files are ignored and the result recomputed.
        - Cache files are written to a temporary file and moved into place with ``os.replace``,
          so a reader never sees a partially written pickle. Write failures (``OSError``, e.g. a
          read-only or full filesystem) are logged at debug level and the result is still
          returned; errors raised while pickling the result itself propagate.
        - Entries are loaded with ``pickle``; the cache directory must be trusted storage.
    """
    import contextlib
    import logging
    import os
    import tempfile

    logger = logging.getLogger(__name__)

    # Exceptions pickle.load may raise for a corrupted or stale file (e.g. a class that was
    # renamed or removed since the entry was written); any of them just means "cache miss".
    unpickle_errors = (
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ImportError,
        IndexError,
        TypeError,
        ValueError,
    )

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Get the instance if this is an instance method. Compare against None rather than
            # relying on truthiness: an instance defining __len__/__bool__ (e.g. an empty
            # pipeline) must still be recognized as an instance.
            instance = args[0] if args and hasattr(args[0], "__dict__") else None

            if instance is not None and not getattr(instance, "use_cache", False):
                return func(*args, **kwargs)

            # Clean up cache files
            _cleanup_old_cache_files(max_age_days, prefix)
            _enforce_cache_size_limit(max_size, prefix)

            if key_gen_func is not None and instance is not None:
                key_dict = key_gen_func(func, *args, **kwargs)
            else:
                # Key on every bound argument (defaults included); for instance calls `self` is
                # replaced by the instance's configuration and module hash below.
                bound_args = inspect.signature(func).bind(*args, **kwargs)
                bound_args.apply_defaults()
                key_dict = {
                    k: _get_object_metadata(v)
                    for k, v in bound_args.arguments.items()
                    if k != "self" or instance is None
                }
                if instance is not None:
                    # Make sure the class instance is the same
                    key_dict["config"] = _get_object_params(instance)
                    key_dict["module_hash"] = _get_module_hash_from_obj(instance)

            cache_key = hash_dict(key_dict)
            cache_path = _get_cache_dir(prefix) / f"{cache_key}.pkl"
            logger.debug("Checking cache at %s", cache_path)

            # Return cached result if it exists (EAFP: no separate exists() check to race with
            # concurrent eviction).
            try:
                with open(cache_path, "rb") as f:
                    return pickle.load(f)
            except FileNotFoundError:
                pass  # cache miss, or the file was evicted between lookup and read
            except OSError as exc:
                logger.debug("Could not read cache file %s: %s", cache_path, exc)
            except unpickle_errors as exc:
                logger.debug("Ignoring unreadable cache file %s: %s", cache_path, exc)

            # Compute result and cache it
            result = func(*args, **kwargs)

            tmp_path = None
            try:
                # Write next to the final file so os.replace is an atomic same-directory rename.
                fd, tmp_path = tempfile.mkstemp(dir=cache_path.parent, prefix=f".{cache_key}.", suffix=".tmp")
                with os.fdopen(fd, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_path, cache_path)
                tmp_path = None  # the temp file now *is* the cache entry
            except OSError as exc:
                # Best effort: caching is an optimization, never lose the computed result.
                logger.debug("Skipping cache write for %s: %s", cache_path, exc)
            finally:
                # Remove a leftover temp file (failed write or a pickling error propagating).
                if tmp_path is not None:
                    with contextlib.suppress(OSError):
                        os.unlink(tmp_path)

            return result

        return wrapper

    return decorator

get_pipeline_cache_key_dict

get_pipeline_cache_key_dict(func: Callable, instance: Any, *args, **kwargs) -> dict[str, Any]

Generate a cache key dictionary for a pipeline.

This function creates a detailed dictionary that includes:

- The function's arguments.
- The configuration of the pipeline instance and all its nested transforms.
- The code hashes of the modules of the pipeline and all its transforms.

Parameters:

Name Type Description Default
func Callable

The function being called (e.g., fit_transform).

required
instance Any

The pipeline instance.

required
*args

Positional arguments to the function.

()
**kwargs

Keyword arguments to the function.

{}

Returns:

Type Description
dict[str, Any]

A dictionary to be hashed for the cache key.

Source code in xdflow/utils/cache_utils.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def get_pipeline_cache_key_dict(func: Callable, instance: Any, *args, **kwargs) -> dict[str, Any]:
    """
    Generate a cache key dictionary for a pipeline.

    This function creates a detailed dictionary that includes:
    - The function's arguments.
    - The configuration of the pipeline instance and all its nested transforms.
    - The code hashes of the modules of the pipeline and all its transforms.

    Args:
        func: The function being called (e.g., fit_transform).
        instance: The pipeline instance.
        *args: Positional arguments to the function.
        **kwargs: Keyword arguments to the function.

    Returns:
        A dictionary to be hashed for the cache key, with one entry per bound argument
        (except ``self``), plus ``"config"`` and ``"module_hashes"`` (source file path ->
        module hash, one entry per distinct source file).
    """
    # 1. Get function arguments (instance is bound to the first parameter, defaults filled in)
    bound_args = inspect.signature(func).bind(instance, *args, **kwargs)
    bound_args.apply_defaults()
    key_dict = {k: _get_object_metadata(v) for k, v in bound_args.arguments.items() if k != "self"}

    # 2. Get instance configuration (recursively)
    key_dict["config"] = _get_object_params(instance)

    # 3. Get module hashes for the instance and all its children
    module_hashes: dict[str, Any] = {}
    for obj in [instance, *_get_all_children(instance)]:
        try:
            module = inspect.getmodule(obj.__class__)
            if module is None:
                continue
            module_path = inspect.getsourcefile(module)
            # Many children share a module; only hash each source file once.
            if not module_path or module_path in module_hashes:
                continue
            module_hash = _get_module_hash_from_obj(obj)
            if module_hash is not None:
                module_hashes[module_path] = module_hash
        except TypeError:
            pass  # Happens for some built-in types

    key_dict["module_hashes"] = module_hashes

    return key_dict

clear_cache

clear_cache(prefix: str | None = None) -> None

Clear the cache for a given prefix or all caches if no prefix is specified.

Parameters:

Name Type Description Default
prefix str | None

Optional prefix to clear specific cache directory

None
Source code in xdflow/utils/cache_utils.py
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
def clear_cache(prefix: str | None = None) -> None:
    """
    Delete cached results from disk.

    Args:
        prefix: Name of one cache sub-directory (under the cache root) to remove. When falsy
            (None or an empty string) the whole cache root directory is removed instead.
    """
    root = _get_cache_root()
    target = root / prefix if prefix else root
    # Nothing to do when the directory was never created or is already gone.
    if target.exists():
        shutil.rmtree(target)

Sampling Utilities

get_container_by_conditions

get_container_by_conditions(container: DataContainer, conditions: dict) -> DataContainer

Get a container by conditions.

Source code in xdflow/utils/sampling.py
12
13
14
15
16
def get_container_by_conditions(container: DataContainer, conditions: dict) -> DataContainer:
    """
    Return a new DataContainer holding the subset of ``container.data`` that matches ``conditions``.

    ``conditions`` uses the same format as ``get_da_by_conditions``.
    """
    subset = get_da_by_conditions(container.data, conditions)
    return DataContainer(subset)

get_da_by_conditions

get_da_by_conditions(da: DataArray, conditions: dict[str, Any]) -> xr.DataArray

Select a DataArray subset based on flexible coordinate conditions.

Each condition can be
  • single value → equality
  • list → membership
  • tuple of 2 → range (inclusive)
  • dict with comparison operator → inequalities e.g. {'>': 5}, {'<=': 10}

Example:

    conditions = {
        "latitude": {">": 15},   # latitude > 15
        "time": {"<=": 2},       # time <= 2
        "depth": (10, 30),       # between 10 and 30 (inclusive)
        "channel": [1, 3, 5],    # in [1, 3, 5]
        "animal": 35,            # equals 35
    }

Source code in xdflow/utils/sampling.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def get_da_by_conditions(da: xr.DataArray, conditions: dict[str, Any]) -> xr.DataArray:
    """
    Select a DataArray subset based on flexible coordinate conditions.

    Each condition can be:
      - single value → equality
      - list (or set/frozenset) → membership
      - tuple of 2 → range (inclusive)
      - dict with comparison operator → inequalities
        e.g. {'>': 5}, {'<=': 10}; '!=' also accepts a list/set of values to exclude

    Conditions on coordinates that share a dimension are combined with AND. Every coordinate
    must be 1-D. The subset is taken with positional boolean indexing (``isel``), so the
    returned array keeps the dtype of ``da``.

    Example:
    conditions = {
        "latitude": {">": 15},          # latitude > 15
        "time": {"<=": 2},              # time <= 2
        "depth": (10, 30),              # between 10 and 30 (inclusive)
        "channel": [1, 3, 5],           # in [1, 3, 5]
        "animal": 35                    # equals 35
    }

    Raises:
        ValueError: if a coordinate does not span exactly one dimension, or an operator in a
            dict condition is not one of '>', '>=', '<', '<=', '!='.
    """
    mask_dict = {}
    for key, value in conditions.items():
        coord = da[key]

        if len(coord.dims) != 1:
            raise ValueError(
                f"Coordinate {key} has {len(coord.dims)} dimensions, conditions must be applied to a single dimension"
            )
        coord_dim = coord.dims[0]

        # Continue the mask already built for this dimension (AND semantics); the plain
        # bool True means "no constraint on this dimension yet".
        mask = mask_dict.get(coord_dim, True)

        # --- Operator-based conditions ---
        if isinstance(value, dict):
            for op, val in value.items():
                if op == ">":
                    mask &= coord > val
                elif op == ">=":
                    mask &= coord >= val
                elif op == "<":
                    mask &= coord < val
                elif op == "<=":
                    mask &= coord <= val
                elif op == "!=":
                    if isinstance(val, (list, set, frozenset)):
                        mask &= ~coord.isin(list(val))
                    else:
                        mask &= coord != val
                else:
                    raise ValueError(f"Unsupported operator: {op}")

        # --- Range --- (a tuple starting with a bool is compared for equality instead)
        elif isinstance(value, tuple) and len(value) == 2 and not isinstance(value[0], bool):
            mask &= (coord >= value[0]) & (coord <= value[1])

        # --- Membership ---
        elif isinstance(value, (list, set, frozenset)):
            # np.isin does not expand sets, so always pass a list.
            mask &= coord.isin(list(value))

        # --- Equality ---
        else:
            mask &= coord == value

        mask_dict[coord_dim] = mask

    for dim, mask in mask_dict.items():
        if isinstance(mask, bool):
            # Only empty operator dicts referred to this dimension: nothing to filter.
            continue
        # Each mask is 1-D over `dim` and ordered like `da` along `dim` (it was built from
        # da's own coordinates), so it can drive positional indexing directly. Unlike
        # where(..., drop=True), isel does not fill with NaN, so the dtype is preserved.
        da = da.isel({dim: np.asarray(mask.values, dtype=bool)})

    return da

train_test_split_container

train_test_split_container(container: DataContainer, target_coord: str, test_size: float = 0.2, random_state: int | None = None) -> tuple[DataContainer, DataContainer]

Split a container into train and test sets.

Source code in xdflow/utils/sampling.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def train_test_split_container(
    container: DataContainer, target_coord: str, test_size: float = 0.2, random_state: int | None = None
) -> tuple[DataContainer, DataContainer]:
    """
    Split a container into train and test sets along the ``trial`` dimension.

    The split is stratified on ``target_coord`` and performed on trial *positions*, so
    repeated trial labels can never duplicate rows across or within the splits.

    Args:
        container: Container whose ``data`` has a ``trial`` dimension.
        target_coord: Coordinate holding the class labels used for stratification.
        test_size: Fraction (or absolute count, as accepted by ``train_test_split``) of trials
            to hold out. ``None`` or 0 disables the holdout.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(train_container, test_container)``. With no holdout, the train container holds all
        trials and the test container has zero trials.

    Raises:
        ValueError: if ``container.data`` has no ``trial`` dimension, or ``target_coord`` is
            missing when a holdout is requested.
    """
    data = container.data
    if "trial" not in data.dims:
        raise ValueError(f"Container data has no 'trial' dimension (dims: {data.dims}).")

    positions = np.arange(data.sizes["trial"])

    if test_size is None or test_size == 0:
        # No holdout set - use all data for cross-validation. positions[:0] keeps an
        # integer dtype for the empty selection.
        train_positions, test_positions = positions, positions[:0]
    else:
        # Get labels for stratification
        if target_coord not in data.coords:
            raise ValueError(f"Target coordinate '{target_coord}' not found in container coords.")
        labels = data.coords[target_coord].values

        # Perform stratified split on positions (same permutation as splitting the labels).
        train_positions, test_positions = train_test_split(
            positions, test_size=test_size, stratify=labels, random_state=random_state
        )

    train_container = DataContainer(data.isel(trial=train_positions))
    test_container = DataContainer(data.isel(trial=test_positions))

    return train_container, test_container

stratified_sample

stratified_sample(da, coord_name, max_samples_per_class=10, random_state=None) -> xr.DataArray

Perform stratified sampling on categorical coordinates.

TODO: add support for non-categorical coordinates and balanced classes.

Parameters:

- da : xr.DataArray — Input data array.
- coord_name : str — Name of the categorical coordinate to stratify on.
- max_samples_per_class : int — Maximum number of samples per category/class.
- random_state : int, optional — Random seed for reproducibility.

Returns:

xr.DataArray Stratified sample with max_samples_per_class from each category

Source code in xdflow/utils/sampling.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def stratified_sample(da, coord_name, max_samples_per_class=10, random_state=None) -> xr.DataArray:
    """
    Perform stratified sampling on categorical coordinates.
    #TODO: add support for non-categorical coordinates and balanced classes.

    Parameters:
    -----------
    da : xr.DataArray
        Input data array
    coord_name : str
        Name of categorical coordinate to stratify on
    max_samples_per_class : int
        Maximum number of samples per category/class
    random_state : int, optional
        Random seed for reproducibility. The global NumPy RNG is seeded only for the
        duration of this call and its previous state is restored afterwards, so the
        sample is reproducible without resetting randomness for unrelated code. With
        None, the global RNG is used as-is.

    Returns:
    --------
    xr.DataArray
        Stratified sample with max_samples_per_class from each category, ordered by
        class (in ``np.unique`` order)

    Raises:
    -------
    ValueError
        If ``coord_name`` spans more than one dimension.
    """
    # Resolve the dimension to index first, so invalid coordinates fail before sampling.
    if coord_name in da.dims:
        dim_name = coord_name
    else:
        dim_names = da.coords[coord_name].dims
        if len(dim_names) > 1:
            raise ValueError(f"Coordinate {coord_name} has multiple dimensions: {dim_names}")
        dim_name = dim_names[0]

    coord_values = da.coords[coord_name].values
    positions = np.arange(len(coord_values))

    if random_state is None:
        sampled_indices = sample_by_max_count(positions, coord_values, max_samples_per_class)
    else:
        # sample_by_max_count draws from the global RNG; seed it temporarily so the draw is
        # identical to seeding globally, then restore the caller's RNG state.
        saved_state = np.random.get_state()
        np.random.seed(random_state)
        try:
            sampled_indices = sample_by_max_count(positions, coord_values, max_samples_per_class)
        finally:
            np.random.set_state(saved_state)

    # Return stratified sample using xarray's isel
    return da.isel({dim_name: sampled_indices})

sample_by_max_count

sample_by_max_count(indices: ndarray, labels: ndarray, max_samples: int) -> np.ndarray

Sample up to max_samples from each class.

Source code in xdflow/utils/sampling.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def sample_by_max_count(indices: np.ndarray, labels: np.ndarray, max_samples: int) -> np.ndarray:
    """
    Sample up to ``max_samples`` entries of ``indices`` from each class in ``labels``.

    Sampling is without replacement and uses the global NumPy RNG (``np.random.choice``).
    A class with fewer than ``max_samples`` members is kept whole and triggers a warning.

    Args:
        indices: Values to sample from, aligned element-wise with ``labels``.
        labels: Class label of each entry in ``indices``.
        max_samples: Maximum number of entries drawn per class.

    Returns:
        The sampled entries, grouped by class in ``np.unique(labels)`` order. An empty array
        with the dtype of ``indices`` when ``labels`` is empty.
    """
    indices = np.asarray(indices)
    labels = np.asarray(labels)
    sampled_indices = []

    for class_label in np.unique(labels):
        class_indices = indices[labels == class_label]
        class_size = len(class_indices)

        n_samples = min(class_size, max_samples)

        if n_samples < max_samples:
            warnings.warn(f"Class {class_label} has only {class_size} samples, requested {max_samples}")

        sampled = np.random.choice(class_indices, size=n_samples, replace=False)
        sampled_indices.append(sampled)

    # np.concatenate rejects an empty list; no classes means nothing to sample.
    if not sampled_indices:
        return indices[:0]
    return np.concatenate(sampled_indices)

sample_by_fraction

sample_by_fraction(indices: ndarray, labels: ndarray, all_labels: ndarray, sample_fraction: float) -> np.ndarray

Sample a fraction of each class based on the whole dataset.

Source code in xdflow/utils/sampling.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def sample_by_fraction(
    indices: np.ndarray, labels: np.ndarray, all_labels: np.ndarray, sample_fraction: float
) -> np.ndarray:
    """
    Sample a fraction of each class based on the whole dataset.

    For every class present in ``labels``, ``max(1, int(n_total * sample_fraction))``
    entries are requested, where ``n_total`` is the class count in ``all_labels``. If the
    current subset (e.g. a CV fold) has fewer entries of that class, a warning is emitted and
    all of them are taken. Sampling is without replacement via the global NumPy RNG.

    Args:
        indices: Values to sample from, aligned element-wise with ``labels``.
        labels: Class label of each entry in ``indices``.
        all_labels: Labels of the whole dataset, used to size each class's sample.
        sample_fraction: Fraction of each class's full-dataset size to draw.

    Returns:
        The sampled entries, grouped by class in ``np.unique(labels)`` order. An empty array
        with the dtype of ``indices`` when ``labels`` is empty.
    """
    indices = np.asarray(indices)
    labels = np.asarray(labels)
    all_labels = np.asarray(all_labels)
    sampled_indices = []

    for class_label in np.unique(labels):
        class_indices = indices[labels == class_label]
        class_size = len(class_indices)

        # Base fraction on the whole dataset class size
        total_class_size = np.sum(all_labels == class_label)
        requested = max(1, int(total_class_size * sample_fraction))

        # Warn before clamping; checking after the clamp could never fire.
        if requested > class_size:
            warnings.warn(f"Class {class_label} has only {class_size} samples in this fold, requested {requested}")
        # Can't sample more than available
        n_samples = min(requested, class_size)

        sampled = np.random.choice(class_indices, size=n_samples, replace=False)
        sampled_indices.append(sampled)

    # np.concatenate rejects an empty list; no classes means nothing to sample.
    if not sampled_indices:
        return indices[:0]
    return np.concatenate(sampled_indices)

get_group_dim

get_group_dim(container: DataContainer, group_coord: str) -> str

Resolves the dimension that the group_coord indexes.

Source code in xdflow/utils/sampling.py
206
207
208
209
210
211
212
213
214
215
216
217
def get_group_dim(container: DataContainer, group_coord: str) -> str:
    """
    Return the name of the single dimension that ``group_coord`` runs along.

    Raises:
        ValueError: if ``group_coord`` is not a coordinate of ``container.data``, or if it
            spans zero or several dimensions.
    """
    coords = container.data.coords
    if group_coord not in coords:
        raise ValueError(f"Group coordinate '{group_coord}' not found in data coordinates")

    dims = coords[group_coord].dims
    if len(dims) == 1:
        return dims[0]

    raise ValueError(
        f"Group coordinate '{group_coord}' must index exactly one dimension, "
        f"but it indexes {len(dims)}: {dims}"
    )

discover_groups

discover_groups(container: DataContainer, group_coord: str) -> list[Hashable]

Discovers unique group values from the data.

Source code in xdflow/utils/sampling.py
220
221
222
223
def discover_groups(container: DataContainer, group_coord: str) -> list[Hashable]:
    """Return the distinct values of ``group_coord`` as a sorted list of Python scalars."""
    distinct = np.unique(container.data.coords[group_coord].values)
    return sorted(distinct.tolist())

select_group

select_group(container: DataContainer, group_coord: str, group_val: Hashable) -> DataContainer

Selects data for a specific group using boolean indexing.

Source code in xdflow/utils/sampling.py
226
227
228
229
230
def select_group(container: DataContainer, group_coord: str, group_val: Hashable) -> DataContainer:
    """
    Select the part of ``container.data`` whose ``group_coord`` equals ``group_val``.

    For a 1-D coordinate the selection uses positional boolean indexing (``isel``) along the
    coordinate's dimension, which keeps the data's dtype. Coordinates that are not 1-D fall
    back to ``where(..., drop=True)``, which masks with NaN before dropping and therefore
    promotes integer/boolean data to float.
    """
    coord = container.data.coords[group_coord]
    group_mask = coord == group_val

    if coord.ndim == 1:
        group_data = container.data.isel({coord.dims[0]: group_mask.values})
    else:
        group_data = container.data.where(group_mask, drop=True)
    return DataContainer(group_data)

Target Utilities

resolve_target_coords

resolve_target_coords(target_coord: str | Sequence[str], data: DataArray) -> list[str]

Accept a single coord name, an explicit list/tuple of coord names, or a glob pattern (e.g., "*_target"). Returns a validated list of coord names present in data.

Pattern matching is only activated when the string contains a * wildcard character. Explicit coordinate names are matched exactly.

Source code in xdflow/utils/target_utils.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def resolve_target_coords(target_coord: str | Sequence[str], data: xr.DataArray) -> list[str]:
    """
    Accept a single coord name, an explicit list/tuple of coord names, or a glob pattern (e.g., ``"*_target"``).
    Returns a validated list of coord names present in ``data``.

    Pattern matching is only activated when the string contains a ``*`` wildcard character.
    Explicit coordinate names are matched exactly. Patterns are matched case-sensitively on
    every platform (``fnmatch.fnmatchcase``) and only against string coordinate names; the
    matches are returned sorted.

    Raises:
        ValueError: if any listed name is missing, a pattern matches nothing, a plain name is
            missing, or ``target_coord`` is neither a string nor a list/tuple.
    """
    available = list(data.coords.keys())

    if isinstance(target_coord, (list, tuple)):
        missing = [coord for coord in target_coord if coord not in data.coords]
        if missing:
            raise ValueError(f"Target coordinates not found in DataArray: {missing}. Available: {available}")
        return list(target_coord)

    if isinstance(target_coord, str):
        # Only treat as pattern if it contains a wildcard
        if "*" in target_coord:
            # fnmatch.fnmatch would apply os.path.normcase (case-insensitive on Windows, and a
            # TypeError for non-str names); coordinate names are not paths.
            matches = [
                coord for coord in available if isinstance(coord, str) and fnmatch.fnmatchcase(coord, target_coord)
            ]
            if not matches:
                raise ValueError(f"No coordinates found matching pattern '{target_coord}'. Available: {available}")
            return sorted(matches)
        # Exact string match required (no implicit pattern expansion)
        if target_coord not in data.coords:
            raise ValueError(f"Required target coordinate '{target_coord}' not found in DataArray.")
        return [target_coord]

    raise ValueError(f"target_coord must be string or list, got {type(target_coord)}")

extract_target_array

extract_target_array(target_coord: str | Sequence[str], data: DataArray, validate: bool = True) -> np.ndarray

Resolve target coordinates and return a stacked numpy array.

Parameters:

Name Type Description Default
target_coord str | Sequence[str]

String pattern/name or iterable of coord names. If validate is False and target_coord is a list/tuple, it is assumed to be pre-resolved.

required
data DataArray

DataArray containing target coords.

required
validate bool

Whether to validate/resolve the target coordinates. Defaults to True.

True
Source code in xdflow/utils/target_utils.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def extract_target_array(target_coord: str | Sequence[str], data: xr.DataArray, validate: bool = True) -> np.ndarray:
    """
    Resolve target coordinates and return their values as a NumPy array.

    Args:
        target_coord: String pattern/name or iterable of coord names. When ``validate`` is
            False and this is a list/tuple, the names are used as given (assumed pre-resolved).
        data: DataArray containing target coords.
        validate: Whether to validate/resolve the target coordinates. Defaults to True.

    Returns:
        The values of the single target coordinate, or the column-stacked values
        (one column per coordinate) when several targets are resolved.
    """
    use_as_given = not validate and isinstance(target_coord, (list, tuple))
    targets = list(target_coord) if use_as_given else resolve_target_coords(target_coord, data)

    if len(targets) == 1:
        return data.coords[targets[0]].values

    columns = [data.coords[name].values for name in targets]
    return np.column_stack(columns)

Sample-Weight Utilities

extract_sample_weights

extract_sample_weights(data: DataArray, sample_dim: str, coord_name: str | None, sample_index: Index | DataArray | None = None) -> np.ndarray | None

Extract 1D sample weights aligned to a sample dimension.

Source code in xdflow/utils/sample_weights.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def extract_sample_weights(
    data: xr.DataArray,
    sample_dim: str,
    coord_name: str | None,
    sample_index: pd.Index | xr.DataArray | None = None,
) -> np.ndarray | None:
    """
    Return per-sample weights stored in a coordinate, aligned to ``sample_dim``.

    Args:
        data: Array that may carry the weight coordinate.
        sample_dim: Dimension the weights must run along.
        coord_name: Name of the weight coordinate. If empty/None or not present in
            ``data.coords``, no weights are used and None is returned.
        sample_index: Labels along ``sample_dim`` to align the weights to; defaults to
            ``data``'s own ``sample_dim`` coordinate.

    Returns:
        A 1-D float array with one weight per entry of ``sample_index``, or None.

    Raises:
        ValueError: if ``sample_dim`` is not a dimension of ``data``, the weight coordinate
            does not include ``sample_dim``, the aligned weights are not 1-D, contain NaN
            after alignment, or do not match the number of samples.
    """
    if not coord_name or coord_name not in data.coords:
        return None

    if sample_dim not in data.dims:
        raise ValueError(f"Sample dimension '{sample_dim}' not found in data dims {data.dims}.")

    weight_coord = data.coords[coord_name]
    if sample_dim not in weight_coord.dims:
        raise ValueError(f"Sample weight coordinate '{coord_name}' must include the sample dimension '{sample_dim}'.")

    target_index = data.coords[sample_dim] if sample_index is None else sample_index

    try:
        aligned = weight_coord.reindex({sample_dim: target_index})
    except ValueError:
        # reindex can refuse some indexes (e.g. non-unique labels); fall back to label selection.
        aligned = weight_coord.sel({sample_dim: target_index})

    aligned = aligned.transpose(sample_dim, ...)
    if aligned.ndim != 1:
        raise ValueError(
            f"Sample weight coordinate '{coord_name}' must be 1D over '{sample_dim}', "
            f"but has dimensions {aligned.dims}."
        )

    values = aligned.astype(float).values
    # Labels absent from the coordinate come back as NaN from reindex.
    if np.isnan(values).any():
        raise ValueError(f"Sample weight coordinate '{coord_name}' contains NaN values after alignment.")

    n_expected = len(target_index)
    if values.shape[0] != n_expected:
        raise ValueError(
            f"Sample weight coordinate '{coord_name}' length ({values.shape[0]}) "
            f"does not match the number of samples ({n_expected})."
        )

    return values

Spectral Utilities

bandpass_filter

bandpass_filter(data, lowcut, highcut, order=4, fs=500, causal=False, axis=-1)

Parameters:

Name Type Description Default
data
required
lowcut
required
highcut
required
order

Butterworth filter order passed to scipy.signal.butter (default 4).

4
fs

Sampling frequency, in the same units as lowcut and highcut (default 500).

500
causal

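If True, filter in a single causal forward pass; otherwise filter forward and backward for zero phase shift (default False).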

False
axis

Axis of data along which to filter (default -1).

-1

Returns:

Source code in xdflow/utils/spectral.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def bandpass_filter(data, lowcut, highcut, order=4, fs=500, causal=False, axis=-1):
    """
    Apply a Butterworth band-pass filter to ``data``.

    The filter is designed as second-order sections (SOS), as recommended by SciPy: the
    transfer-function (``b, a``) form of a band-pass filter has twice the requested order and
    its polynomial coefficients become numerically unreliable for higher orders or for bands
    that are narrow or low relative to ``fs``.

    Args:
      data: Array-like signal(s) to filter.
      lowcut: Lower pass-band edge, in the same units as ``fs``.
      highcut: Upper pass-band edge, in the same units as ``fs``.
      order: Butterworth order passed to ``scipy.signal.butter`` (Default value = 4).
      fs: Sampling frequency (Default value = 500).
      causal: If True, filter in a single forward pass (introduces a phase delay); otherwise
        filter forward and backward for zero phase shift (Default value = False).
      axis: Axis of ``data`` along which to filter (Default value = -1).

    Returns:
      The filtered array, with the same shape as ``data``.
    """
    from scipy.signal import sosfilt, sosfiltfilt

    sos = butter(order, [lowcut, highcut], btype="bandpass", fs=fs, output="sos")
    if causal:
        return sosfilt(sos, data, axis=axis)
    return sosfiltfilt(sos, data, axis=axis)

get_remove_freq_ranges

get_remove_freq_ranges(num_bands_remove, freqs, remove_high=True)

Removes a specified number of frequency bands from the frequency ranges dictionary, starting with high or low frequency.

Parameters:

Name Type Description Default
num_bands_remove

Number of frequency bands to remove

required
freqs

Dictionary of frequency ranges (e.g., {'theta': (4, 8), 'beta': (13, 30)})

required
remove_high

Whether to remove bands from the end of the dictionary (treated as the highest frequencies) rather than from the start (default True).

True

Returns:

Type Description

Modified frequency ranges dictionary with the specified number of frequency bands removed.

Source code in xdflow/utils/spectral.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def get_remove_freq_ranges(num_bands_remove, freqs, remove_high=True):
    """Return a copy of ``freqs`` with ``num_bands_remove`` bands dropped from one end.

    "High" and "low" follow the dictionary's insertion order: the last entries are treated
    as the highest-frequency bands, so ``freqs`` should be ordered from low to high.

    Args:
      num_bands_remove: Number of frequency bands to remove. Values larger than the number of
        bands remove everything; zero or negative values remove nothing.
      freqs: Dictionary of frequency ranges (e.g., {'theta': (4, 8), 'beta': (13, 30)}).
        It is not modified.
      remove_high: If True (default), remove from the end of ``freqs`` (highest bands);
        otherwise remove from the start (lowest bands).

    Returns:
      A new dictionary without the removed bands, keeping the remaining bands in order.

    Raises:
      ValueError: if ``freqs`` is None.
    """
    if freqs is None:
        raise ValueError("freqs parameter is required - provide a dictionary of frequency ranges")

    band_names = list(freqs.keys())
    # Reverse the list if removing high-frequency bands
    if remove_high:
        band_names.reverse()

    # max(..., 0): a negative count must remove nothing, not slice from the end.
    to_remove = set(band_names[: max(num_bands_remove, 0)])

    # Build a new dict instead of popping from the caller's (often shared) mapping.
    return {band: freq_range for band, freq_range in freqs.items() if band not in to_remove}

get_freq_band_indices

get_freq_band_indices(frequencies, low, high)

Returns the indices of the beginning and end of a frequency band.

Parameters:

Name Type Description Default
frequencies

Sorted array of frequencies

required
low

Lower bound of the frequency band

required
high

Upper bound of the frequency band

required

Returns:

Type Description

List with start and end indices of the frequency band.

Source code in xdflow/utils/spectral.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def get_freq_band_indices(frequencies, low, high):
    """Return ``[start, stop]`` indices bounding a frequency band.

    ``frequencies[start:stop]`` contains exactly the entries with ``low <= f <= high``:
    ``start`` is the first index whose frequency is >= ``low`` and ``stop`` is the first
    index whose frequency is > ``high``.

    Args:
      frequencies: Sorted (ascending) array of frequencies
      low: Lower bound of the frequency band (inclusive)
      high: Upper bound of the frequency band (inclusive)

    Returns:
      List ``[start, stop]`` of indices into ``frequencies``.

    """
    edges = ((low, "left"), (high, "right"))
    return [np.searchsorted(frequencies, bound, side=side) for bound, side in edges]

Visualization Utilities

plot_confusion_matrix

plot_confusion_matrix(confusion_matrix: ndarray, labels: Iterable[Any], want_plot: bool = False, want_confus: bool = False, save_as: str | None = None, title: str = 'Confusion Matrix', test_trues: Iterable[Any] | None = None, ylabels: Iterable[Any] | None = None, xlabels: Iterable[Any] | None = None, ax=None, show_plot: bool = True, show_annotations: bool = True, cmap: str = 'Blues')

Plot a confusion matrix heatmap with optional annotations.

Returns:

Type Description

The confusion matrix as a NumPy array if want_confus is True; otherwise the matplotlib pyplot module if want_plot is True; otherwise None.

Source code in xdflow/utils/visualizations.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def plot_confusion_matrix(
    confusion_matrix: np.ndarray,
    labels: Iterable[Any],
    want_plot: bool = False,
    want_confus: bool = False,
    save_as: str | None = None,
    title: str = "Confusion Matrix",
    test_trues: Iterable[Any] | None = None,
    ylabels: Iterable[Any] | None = None,
    xlabels: Iterable[Any] | None = None,
    ax=None,
    show_plot: bool = True,
    show_annotations: bool = True,
    cmap: str = "Blues",
):
    """
    Plot a confusion matrix heatmap with optional annotations.

    The colour scale is fixed to [0, 1] and cell annotations are rendered as percentages,
    so ``confusion_matrix`` is expected to hold normalized rates.

    Args:
        confusion_matrix: 2-D matrix; rows are true classes, columns predicted classes.
        labels: Default tick labels for both axes.
        want_plot: Return the ``matplotlib.pyplot`` module (ignored when ``want_confus``).
        want_confus: Return the matrix as a NumPy array.
        save_as: If given, path the figure is saved to.
        title: Axes title.
        test_trues: Accepted for API compatibility; not used.
        ylabels: Tick labels for the true (y) axis; defaults to ``labels``.
        xlabels: Tick labels for the predicted (x) axis; defaults to ``labels``.
        ax: Existing axes to draw on. If None, a new 9x7 figure is created.
        show_plot: If False, a figure created by this function is closed after saving.
            A figure supplied by the caller (via ``ax``) is never closed here.
        show_annotations: Write each cell's value as a percentage.
        cmap: Matplotlib colormap name.

    Returns:
        The matrix (``np.ndarray``) if ``want_confus`` is True; otherwise the pyplot module if
        ``want_plot`` is True; otherwise None.
    """
    plt = _require_matplotlib()
    cm = np.array(confusion_matrix)
    labels = list(labels)

    ylabels = labels if ylabels is None else list(ylabels)
    xlabels = labels if xlabels is None else list(xlabels)

    # Only figures we create are ours to close; closing a caller's figure would hide
    # every other panel they drew on it.
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=(9, 7))
    else:
        fig = ax.figure

    im = ax.imshow(cm, interpolation="nearest", cmap=cmap, vmin=0, vmax=1)
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_xticks(np.arange(len(xlabels)))
    ax.set_yticks(np.arange(len(ylabels)))
    ax.set_xticklabels(xlabels, rotation=45, ha="right")
    ax.set_yticklabels(ylabels)

    if show_annotations and cm.size:
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, f"{cm[i, j] * 100:.2f}%", ha="center", va="center", color="black")

    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    if save_as:
        fig.savefig(save_as, bbox_inches="tight")

    if not show_plot and owns_figure:
        plt.close(fig)

    if want_confus:
        return cm
    if want_plot:
        return plt
    return None

plot_combined_confusion_matrix

plot_combined_confusion_matrix(confusion_matrices: Iterable[ndarray], labels: Iterable[Any], f1_scores: Iterable[float] | None = None, sample_sizes: ndarray | None = None, test_trues: Iterable[Iterable[Any]] | None = None, want_plot: bool = False, want_confus: bool = False, title: str | None = None, save_as: str | None = None, xlabels: Iterable[Any] | None = None, ylabels: Iterable[Any] | None = None, cmap: str = 'Blues')

Plot mean confusion matrix with standard error annotations across folds.

Source code in xdflow/utils/visualizations.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def plot_combined_confusion_matrix(
    confusion_matrices: Iterable[np.ndarray],
    labels: Iterable[Any],
    f1_scores: Iterable[float] | None = None,
    sample_sizes: np.ndarray | None = None,
    test_trues: Iterable[Iterable[Any]] | None = None,
    want_plot: bool = False,
    want_confus: bool = False,
    title: str | None = None,
    save_as: str | None = None,
    xlabels: Iterable[Any] | None = None,
    ylabels: Iterable[Any] | None = None,
    cmap: str = "Blues",
):
    """
    Plot mean confusion matrix with standard error annotations across folds.

    The standard error of the mean uses the sample standard deviation (``ddof=1``),
    i.e. ``std(ddof=1) / sqrt(n_folds)``; with a single fold it is 0. The colour scale is
    fixed to [0, 1] and cells are annotated as percentages, so the matrices are expected to
    hold normalized rates.

    Args:
        confusion_matrices: One same-shaped 2-D matrix per fold (rows true, columns predicted).
        labels: Default tick labels for both axes.
        f1_scores: Accepted for API compatibility; not used.
        sample_sizes: Accepted for API compatibility; not used.
        test_trues: Accepted for API compatibility; not used.
        want_plot: Keep the figure open and return the pyplot module. If False the figure is
            closed after saving.
        want_confus: Return ``(mean_matrix, sem_matrix)`` (takes precedence over ``want_plot``).
        title: Axes title; defaults to "Confusion Matrix".
        save_as: If given, path the figure is saved to.
        xlabels: Tick labels for the predicted (x) axis; defaults to ``labels``.
        ylabels: Tick labels for the true (y) axis; defaults to ``labels``.
        cmap: Matplotlib colormap name.

    Raises:
        ValueError: if ``confusion_matrices`` is empty.
    """
    plt = _require_matplotlib()
    matrices = [np.array(cm) for cm in confusion_matrices]
    if not matrices:
        raise ValueError("confusion_matrices must contain at least one matrix.")

    stacked = np.stack(matrices, axis=2)
    n_folds = stacked.shape[2]
    mean_matrix = np.mean(stacked, axis=2)
    if n_folds > 1:
        # SEM is defined with the sample (unbiased-variance) standard deviation.
        sem_matrix = np.std(stacked, axis=2, ddof=1) / np.sqrt(n_folds)
    else:
        # A single fold carries no information about spread.
        sem_matrix = np.zeros_like(mean_matrix)

    if title is None:
        title = "Confusion Matrix"

    fig, ax = plt.subplots(figsize=(9, 7))
    im = ax.imshow(mean_matrix, interpolation="nearest", cmap=cmap, vmin=0, vmax=1)
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    labels = list(labels)
    ylabels = labels if ylabels is None else list(ylabels)
    xlabels = labels if xlabels is None else list(xlabels)
    ax.set_xticks(np.arange(len(xlabels)))
    ax.set_yticks(np.arange(len(ylabels)))
    ax.set_xticklabels(xlabels, rotation=45, ha="right")
    ax.set_yticklabels(ylabels)

    for i in range(mean_matrix.shape[0]):
        for j in range(mean_matrix.shape[1]):
            ax.text(
                j,
                i,
                f"{mean_matrix[i, j] * 100:.2f}%\n±{sem_matrix[i, j] * 100:.2f}%",
                ha="center",
                va="center",
                color="black",
            )

    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    if save_as:
        fig.savefig(save_as, bbox_inches="tight")

    if not want_plot:
        plt.close(fig)

    if want_confus:
        return mean_matrix, sem_matrix
    if want_plot:
        return plt
    return None

plot_tune_importances

plot_tune_importances(study, *, want_plot: bool = True)

Plot Optuna parameter importances for a study.

Parameters:

Name Type Description Default
study

Optuna study object.

required
want_plot bool

Whether to return the matplotlib module.

True
Source code in xdflow/utils/visualizations.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def plot_tune_importances(study, *, want_plot: bool = True):
    """
    Plot Optuna parameter importances for a study as a horizontal bar chart.

    Bars follow the order returned by ``optuna.importance.get_param_importances``; the y-axis
    is inverted so that the first entry (Optuna's highest-importance parameter) is drawn at
    the top.

    Args:
        study: Optuna study object.
        want_plot: Whether to return the matplotlib module. If False, the figure is closed
            and None is returned.

    Raises:
        ImportError: if Optuna is not installed.
        ValueError: if the study yields no parameter importances.
    """
    plt = _require_matplotlib()
    try:
        from optuna.importance import get_param_importances
    except ImportError as exc:  # pragma: no cover - dependency guard
        raise ImportError("plot_tune_importances requires Optuna. Install with: pip install xdflow[tuning]") from exc

    importances = get_param_importances(study)
    if not importances:
        raise ValueError("No parameter importances found for the provided study.")

    labels = list(importances.keys())
    values = list(importances.values())

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.barh(labels, values)
    # barh draws the first category at the bottom; flip so the ranking reads top-down.
    ax.invert_yaxis()
    ax.set_xlabel("Importance")
    ax.set_title("Optuna Parameter Importances")
    fig.tight_layout()

    if want_plot:
        return plt
    plt.close(fig)
    return None