Source code for pyriemann_qiskit.utils.transfer

"""Transfer learning utilities compatible with MOABB's evaluation framework."""

import inspect
from copy import deepcopy
from time import time

import numpy as np

try:
    from mne.epochs import BaseEpochs
    from moabb.evaluations import CrossSubjectEvaluation
    from moabb.evaluations.splitters import CrossSubjectSplitter
    from moabb.evaluations.utils import (
        _create_save_path,
        _ensure_fitted,
        _save_model_cv,
    )

    HAS_MOABB = True
except ImportError:
    HAS_MOABB = False
from pyriemann.transfer import encode_domains
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import get_scorer
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm


def _pipeline_accepts_groups(clf):
    """Return True if clf.fit() declares a groups parameter."""
    return "groups" in inspect.signature(clf.fit).parameters


def _pipeline_accepts_target_domain(clf):
    """Return True if clf.fit() declares a target_domain parameter."""
    return "target_domain" in inspect.signature(clf.fit).parameters


def _set_target_domain(estimator, target_domain):
    """Set target_domain on any step (or the estimator itself) that declares it."""
    if hasattr(estimator, "steps"):
        for _, step in estimator.steps:
            if hasattr(step, "target_domain"):
                step.set_params(target_domain=target_domain)
    elif hasattr(estimator, "target_domain"):
        estimator.set_params(target_domain=target_domain)


class Adapter(ClassifierMixin, BaseEstimator):
    """Meta-estimator bridging MOABB's evaluation interface with pyriemann TL pipelines.

    Handles preprocessing (raw → covariances), domain encoding, and
    target-domain propagation, then delegates to a plain pyriemann TL
    estimator.

    Parameters
    ----------
    preprocessing : sklearn transformer
        Transforms raw input into covariance matrices. May be a single
        transformer or a pipeline exposing ``steps``.
    estimator : sklearn estimator
        Any pyriemann TL-compatible estimator (e.g. a pipeline of
        TLCenter/TLScale/TLRotate/TLClassifier, or MDWM) constructed with
        ``target_domain=None``. The actual value is injected via
        ``set_params`` at fit time from ``TLCrossSubjectEvaluation``.

    Attributes
    ----------
    preprocessing_ : sklearn transformer
        Fitted deep copy of ``preprocessing``.
    estimator_ : sklearn estimator
        Fitted deep copy of ``estimator``.
    classes_ : ndarray
        Unique values of ``y`` seen during ``fit``.
    """

    def __init__(self, preprocessing, estimator):
        self.preprocessing = preprocessing
        self.estimator = estimator

    def fit(self, X, y, groups=None, target_domain=None):
        """Fit the preprocessing and the wrapped estimator.

        Parameters
        ----------
        X : array-like
            Raw input accepted by ``preprocessing``.
        y : array-like
            Class labels.
        groups : array-like, default=None
            Domain of each sample. It is passed unconditionally to
            ``encode_domains`` together with the labels.
        target_domain : object, default=None
            Value set as ``target_domain`` on the wrapped estimator (or on
            each of its steps that declares it) before fitting.

        Returns
        -------
        self : Adapter
            The fitted instance.
        """
        # Work on copies so the constructor arguments stay untouched.
        self.preprocessing_ = deepcopy(self.preprocessing)
        X_cov = self.preprocessing_.fit_transform(X, y)
        _, y_enc = encode_domains(X_cov, y, groups)
        self.estimator_ = deepcopy(self.estimator)
        _set_target_domain(self.estimator_, target_domain)
        self.estimator_.fit(X_cov, y_enc)
        self.classes_ = np.unique(y)
        return self

    def _apply_preprocessing(self, X):
        """Run the fitted preprocessing on ``X`` without re-fitting it."""
        steps = getattr(self.preprocessing_, "steps", None)
        if steps is None:
            # Plain transformer rather than a pipeline: nothing to iterate.
            return self.preprocessing_.transform(X)

        # Iterate steps directly to avoid Pipeline.transform's check_is_fitted
        # guard, which fails for stateless transformers (e.g. pyriemann's
        # Covariances) that set no fitted attributes in fit().
        result = X
        for _, step in steps:
            # ``None`` and "passthrough" are no-op pipeline steps.
            if step is None or (isinstance(step, str) and step == "passthrough"):
                continue
            result = step.transform(result)
        return result

    def predict(self, X):
        """Predict with the wrapped estimator on preprocessed ``X``."""
        return self.estimator_.predict(self._apply_preprocessing(X))

    def predict_proba(self, X):
        """Return ``predict_proba`` of the wrapped estimator on preprocessed ``X``."""
        return self.estimator_.predict_proba(self._apply_preprocessing(X))
if HAS_MOABB:

    class TLCrossSubjectEvaluation(CrossSubjectEvaluation):
        """CrossSubjectEvaluation with group and target-domain forwarding.

        Extends CrossSubjectEvaluation with one behavioural change:

        Pipelines whose fit() declares ``groups`` receive the training
        subject IDs. Pipelines whose fit() also declares ``target_domain``
        receive the test subject's ID as target_domain. All other pipelines
        are unaffected.

        Parameters
        ----------
        **kwargs
            Forwarded to CrossSubjectEvaluation.__init__().
        """

        def _create_splitter(self):
            # Force the legacy path so our evaluate() override is called.
            # The new parallel path calls a module-level _evaluate_fold() that
            # cannot pass groups/target_domain to Adapter.fit().
            return None

        # flake8: noqa: C901
        def evaluate(
            self,
            dataset,
            pipelines,
            param_grid,
            process_pipeline,
            postprocess_pipeline=None,
        ):
            """Run the cross-subject evaluation and yield one result per session.

            For every fold, each not-yet-computed pipeline is fitted on the
            training samples and scored on each session of the test samples.
            A pipeline whose fit or scoring raises is reported on stdout
            (message plus traceback) and skipped instead of aborting the run.

            Yields
            ------
            dict
                Keys: ``time``, ``dataset``, ``subject``, ``session``,
                ``score``, ``n_samples``, ``n_channels``, ``pipeline``.

            Raises
            ------
            AssertionError
                If ``self.is_valid(dataset)`` is false.
            """
            import traceback

            if not self.is_valid(dataset):
                raise AssertionError("Dataset is not appropriate for evaluation")

            # Collect every pipeline that still has to run for some subject.
            run_pipes = {}
            for subject in dataset.subject_list:
                run_pipes.update(
                    self.results.not_yet_computed(
                        pipelines, dataset, subject, process_pipeline
                    )
                )
            if not run_pipes:
                return

            X, y, metadata = self.paradigm.get_data(
                dataset=dataset,
                return_epochs=self.return_epochs,
                return_raws=self.return_raws,
                cache_config=self.cache_config,
                postprocess_pipeline=postprocess_pipeline,
                process_pipelines=[process_pipeline],
            )

            le = LabelEncoder()
            y = y if self.mne_labels else le.fit_transform(y)

            groups = metadata.subject.values
            sessions = metadata.session.values
            n_subjects = len(dataset.subject_list)

            scorer = get_scorer(self.paradigm.scoring)

            if self.n_splits is None:
                cv_class = LeaveOneGroupOut
                cv_kwargs = {}
            else:
                cv_class = GroupKFold
                cv_kwargs = {"n_splits": self.n_splits}
                # Progress bar total: number of folds, not of subjects.
                n_subjects = self.n_splits

            self.cv = CrossSubjectSplitter(
                cv_class=cv_class, random_state=self.random_state, **cv_kwargs
            )

            inner_cv = StratifiedKFold(3, shuffle=True, random_state=self.random_state)

            for cv_ind, (train, test) in enumerate(
                tqdm(
                    self.cv.split(y, metadata),
                    total=n_subjects,
                    desc=f"{dataset.code}-CrossSubject",
                )
            ):
                # The fold is identified by the subject of its first test sample.
                subject = groups[test[0]]

                run_pipes = self.results.not_yet_computed(
                    pipelines, dataset, subject, process_pipeline
                )

                for name, clf in run_pipes.items():
                    t_start = time()
                    clf = self._grid_search(
                        param_grid=param_grid,
                        name=name,
                        grid_clf=clf,
                        inner_cv=inner_cv,
                    )

                    try:
                        if _pipeline_accepts_target_domain(clf):
                            # The target domain is the held-out (test)
                            # subject, never one of the training subjects.
                            model = deepcopy(clf).fit(
                                X[train],
                                y[train],
                                groups=groups[train],
                                target_domain=str(subject),
                            )
                        elif _pipeline_accepts_groups(clf):
                            model = deepcopy(clf).fit(
                                X[train], y[train], groups=groups[train]
                            )
                        else:
                            model = deepcopy(clf).fit(X[train], y[train])
                        _ensure_fitted(model)
                    except Exception as e:
                        # Best effort: report and move on to the next pipeline.
                        print(
                            f"\n[TLCrossSubjectEvaluation] ERROR fitting '{name}': {e}"
                        )
                        traceback.print_exc()
                        continue

                    duration = time() - t_start

                    if self.hdf5_path is not None and self.save_model:
                        model_save_path = _create_save_path(
                            hdf5_path=self.hdf5_path,
                            code=dataset.code,
                            subject=subject,
                            session="",
                            name=name,
                            grid=self.search,
                            eval_type="CrossSubject",
                        )
                        _save_model_cv(
                            model=model,
                            save_path=model_save_path,
                            cv_index=str(cv_ind),
                        )

                    for session in np.unique(sessions[test]):
                        ix = sessions[test] == session
                        try:
                            score = scorer(model, X[test[ix]], y[test[ix]])
                        except Exception as e:
                            # Best effort: report and skip this session only.
                            print(
                                f"\n[TLCrossSubjectEvaluation] ERROR scoring '{name}': {e}"
                            )
                            traceback.print_exc()
                            continue

                        nchan = (
                            X.info["nchan"]
                            if isinstance(X, BaseEpochs)
                            else X.shape[1]
                        )
                        res = {
                            "time": duration,
                            "dataset": dataset,
                            "subject": subject,
                            "session": session,
                            "score": score,
                            "n_samples": len(train),
                            "n_channels": nchan,
                            "pipeline": name,
                        }
                        yield res