Source code for pyriemann_qiskit.visualization.manifold
"""
Visualization of the covariance matrices on the SPD manifold.
"""
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import ConvexHull
from ..utils.math import to_xyz
[docs]
def plot_cvx_hull(X, ax):
"""Plot the convex hull of a set of points.
Parameters
----------
X : ndarray, shape (n_matrices, 3)
A set of 3d points in Cartesian coordinates.
ax : Matplotlib.Axes
A figure where to plot the points of the convex hull.
Notes
-----
.. versionadded:: 0.2.0
"""
hull = ConvexHull(X)
for simplex in hull.simplices:
ax.plot(X[simplex, 0], X[simplex, 1], X[simplex, 2], "k--", alpha=0.2)
[docs]
def plot_manifold(X, y, plot_hull=False):
"""Plot spd matrices in 3d (cartesian coordinate system).
Parameters
----------
X : ndarray, shape (n_matrices, 2, 2)
A set of SPD matrices of size 2 x 2.
y : ndarray, shape (n_matrices,)
Labels for each matrix.
plot_hull : boolean, default=False
If True, plot the convex hull of X.
Notes
-----
.. versionadded:: 0.2.0
"""
if X.ndim != 3:
raise ValueError("Input `covs` has not 3 dimensions")
if X.shape[1] != 2 and X.shape[2] != 2:
raise ValueError("SPD matrices must have size 2 x 2")
classes = np.unique(y)
points = to_xyz(X)
points0 = points[y == classes[0]]
points1 = points[y == classes[1]]
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(
points0[:, 0], points0[:, 1], points0[:, 2], color="red", label=classes[0]
)
ax.scatter(
points1[:, 0],
points1[:, 1],
points1[:, 2],
alpha=0.5,
color="blue",
label=classes[1],
)
ax.legend(title="Classes", loc="upper center")
if plot_hull:
plot_cvx_hull(points, ax)
return fig