Source code for pyriemann_qiskit.classification.algorithms.cp_mdm
from warnings import warn
import numpy as np
from joblib import Parallel, delayed
from pyriemann.classification import MDM
from pyriemann.utils.utils import check_metric
from ...optimization.distance import distance_functions
from ...optimization.docplex import ClassicalOptimizer
from ...optimization.mean import mean_functions
from ...utils.utils import is_qfunction
[docs]
class CpMDM(MDM):
    """Quantum-enhanced MDM classifier

    This class is a constraint programming (CP) implementation of the
    Minimum Distance to Mean (MDM) [1]_, which can run with quantum
    optimization.
    Only log-Euclidean distance between trial and class prototypes is
    supported at the moment, but any type of metric can be used for centroid
    estimation.

    Parameters
    ----------
    optimizer : pyQiskitOptimizer, default=ClassicalOptimizer()
        An instance of
        :class:`pyriemann_qiskit.optimization.docplex.pyQiskitOptimizer`.
    **params : dict
        Remaining keyword arguments, forwarded unchanged to the
        :class:`pyriemann.classification.MDM` constructor.

    Attributes
    ----------
    classes_ : ndarray
        Unique labels found in ``y`` (set by :meth:`fit` when the mean metric
        is a q-function; otherwise whatever ``MDM.fit`` sets).
    covmeans_ : ndarray
        Class centroids stacked along the first axis (set by :meth:`fit` when
        the mean metric is a q-function; otherwise whatever ``MDM.fit``
        sets).

    Notes
    -----
    .. versionadded:: 0.4.2
    .. versionchanged:: 0.6.0
        Moved to algorithms sub-package

    See Also
    --------
    pyriemann.classification.MDM

    References
    ----------
    .. [1] `Multiclass Brain-Computer Interface Classification by Riemannian
        Geometry
        <https://hal.archives-ouvertes.fr/hal-00681328>`_
        A. Barachant, S. Bonnet, M. Congedo, and C. Jutten. IEEE Transactions
        on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012.
    """

    # NOTE: the default ``ClassicalOptimizer()`` is evaluated once, when the
    # class body is executed, so every instance built without an explicit
    # ``optimizer`` shares that same object. The default is kept as-is so the
    # constructor signature stays unchanged for existing callers.
    def __init__(self, optimizer=ClassicalOptimizer(), **params):
        self.optimizer = optimizer
        super().__init__(**params)

    def fit(self, X, y, sample_weight=None):
        """Fit (estimates) the centroids.

        When the mean metric is a q-function, one centroid per class is
        estimated with the matching entry of ``mean_functions`` and
        ``self.optimizer``. Otherwise the call is delegated to ``MDM.fit``.

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_channels)
            ndarray of SPD matrices.
        y : ndarray, shape (n_trials,)
            labels corresponding to each trial.
        sample_weight : None | ndarray shape (n_trials,), default=None
            weights for each trial. In the q-function branch, None means
            that every trial gets a weight of 1; otherwise the value is
            forwarded unchanged to ``MDM.fit``.

        Returns
        -------
        self : CpMDM instance
            The CpMDM instance.
        """
        self._metric_mean, self._metric_dist = check_metric(self.metric)

        # Classical mean: let the parent class estimate the centroids.
        if not is_qfunction(self._metric_mean):
            return super().fit(X, y, sample_weight)

        # Coerce to arrays so that ``y == c`` is an element-wise mask even
        # when labels (or weights) are passed as plain Python sequences.
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        if sample_weight is None:
            sample_weight = np.ones(X.shape[0])
        else:
            sample_weight = np.asarray(sample_weight)

        mean_func = mean_functions[self._metric_mean]
        # One independent centroid estimation per class.
        class_means = Parallel(n_jobs=self.n_jobs)(
            delayed(mean_func)(
                X[y == c],
                sample_weight=sample_weight[y == c],
                optimizer=self.optimizer,
            )
            for c in self.classes_
        )
        self.covmeans_ = np.stack(class_means, axis=0)
        return self

    def _predict_distances(self, X):
        """Compute, for each trial, its distance to the class centroids.

        When the distance metric is a q-function, the matching entry of
        ``distance_functions`` is called once per trial with the stacked
        centroids and ``self.optimizer``; what it returns is treated as
        weights and the method returns ``1 - weights``. A warning is emitted
        in that case. Otherwise the call is delegated to
        ``MDM._predict_distances``.

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_channels)
            ndarray of SPD matrices.

        Returns
        -------
        dist : ndarray
            One row per trial (presumably one column per class -- depends on
            what the selected distance function returns).
        """
        # Classical distance: nothing quantum-specific to do.
        if not is_qfunction(self._metric_dist):
            return super()._predict_distances(X)

        # Look the function up first so that an unknown metric fails before
        # any warning is emitted, as in the previous implementation.
        distance = distance_functions[self._metric_dist]

        if "hull" in self._metric_dist:
            warn("qdistances to hull should not be use inside MDM")
        else:
            # Implicit literal concatenation instead of a backslash inside
            # the string, which leaked source indentation into the message.
            warn(
                "q-distances for MDM are toy functions. "
                "Use pyRiemann distances instead."
            )

        weights = [distance(self.covmeans_, x, self.optimizer) for x in X]
        return 1 - np.array(weights)