# Source code for mvlearn.plotting.plot

# License: MIT

import matplotlib.pyplot as plt
import seaborn as sns
from ..utils.utils import check_Xs
from ..embed import MVMDS
import numpy as np


def crossviews_plot(
    Xs,
    labels=None,
    dimensions=None,
    figsize=(10, 10),
    title=None,
    cmap=None,
    show=True,
    context="notebook",
    equal_axes=False,
    ax_ticks=True,
    ax_labels=True,
    scatter_kwargs={},
    fig_kwargs={},
):
    r"""
    Plots each dimension from one view against each dimension from a second
    view. If both views are the same, this reduces to a pairplot.

    Parameters
    ----------
    Xs : list of array-likes or numpy.ndarray
        - Xs length: n_views
        - Xs[i] shape: (n_samples, n_features_i)

        The two views to plot against one another. If one view has fewer
        dimensions than the other, only that many will be plotted.

    labels : array-like, default=None
        Per-sample labels. When given, they are passed to the scatter call
        as ``c`` so that samples are colored by label.

    dimensions : array-like of ints, default=None
        The dimensions of the views to plot. If `None`, all dimensions up
        to the minimum between the views will be plotted. Every entry must
        be a valid column index in *both* views.

    figsize : tuple, default=(10,10)
        Sets the grid figure size.

    title : string, default=None
        Sets the title of the grid.

    cmap : String, default=None
        Colormap argument for matplotlib.pyplot.scatter.

    show : boolean, default=True
        Shows the plots if true. Returns the objects otherwise.

    context : one of {'paper', 'notebook', 'talk', 'poster', None},
        default='notebook'
        Sets the seaborn plotting context.

    equal_axes : boolean, default=False
        Equalizes the axes of the plots on the diagonals if true.

    ax_ticks : boolean, default=True
        Whether to have tick marks on the axes.

    ax_labels : boolean, default=True
        Whether to label the axes with the view and dimension numbers.
        The bottom row carries the x labels and the left column carries
        the y labels.

    scatter_kwargs : dict, default={}
        Additional matplotlib.pyplot.scatter arguments.

    fig_kwargs : dict, default={}
        Additional matplotlib.pyplot.subplots arguments.

    Returns
    -------
    (fig, axes) : tuple of the figure and its axes.
        Only returned if `show=False`.

    Raises
    ------
    ValueError
        If `dimensions` is not a list or numpy.ndarray, is empty, or
        contains an index that is negative or not present in both views.

    Notes
    -----
    Below is an example figure generated from 2 views with 2 features each.

    .. figure:: /figures/crossviews_plot_example.png
        :width: 250px
        :alt: Quick Visualization of Multi-view Data
        :align: center
    """
    Xs = check_Xs(Xs)

    # A dimension can only be plotted if it exists in both views, so the
    # narrower view sets the upper bound.
    n_shared = min(Xs[0].shape[1], Xs[1].shape[1])
    if dimensions is None:
        dimensions = list(range(n_shared))
    else:
        if not isinstance(dimensions, (np.ndarray, list)):
            msg = "`dimensions` must be of type list or np.ndarray"
            raise ValueError(msg)
        if len(dimensions) == 0:
            msg = "`dimensions` must contain at least one dimension."
            raise ValueError(msg)
        if min(dimensions) < 0 or max(dimensions) >= n_shared:
            msg = "max or min of `dimensions` is too extreme."
            raise ValueError(msg)
    n = len(dimensions)

    fig, axes = plt.subplots(n, n, figsize=figsize, **fig_kwargs)
    # NOTE(review): the context is applied after the axes are created, as in
    # the original ordering; confirm whether it should precede `subplots`.
    sns.set_context(context)
    if n == 1:
        # A 1x1 grid yields a bare Axes; wrap it so it can be flattened.
        axes = np.asarray([axes])

    # Only pass `c` when labels were supplied (the defaults dicts above are
    # never mutated, they are only unpacked).
    color_kwargs = {} if labels is None else {"c": labels}

    for i, ax in enumerate(axes.flatten()):
        row, col = divmod(i, n)
        dim2 = dimensions[row]
        dim1 = dimensions[col]
        ax.scatter(
            Xs[0][:, dim1],
            Xs[1][:, dim2],
            cmap=cmap,
            **color_kwargs,
            **scatter_kwargs,
        )
        if ax_labels:
            # Key the labels on grid position, not on the dimension values,
            # so custom `dimensions` still get labelled outer axes.
            if row == n - 1:
                ax.set_xlabel(f"View 1 Dim {dim1+1}")
            if col == 0:
                ax.set_ylabel(f"View 2 Dim {dim2+1}")
        if dim1 == dim2 and equal_axes:
            ax.axis("equal")
        if not ax_ticks:
            ax.set_xticks([])
            ax.set_yticks([])

    if title is not None:
        # Leave headroom at the top of the layout for the figure title.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.suptitle(title)
    else:
        fig.tight_layout()

    if show:
        plt.show()
        return None
    return (fig, axes)
def quick_visualize(
    Xs,
    labels=None,
    figsize=(5, 5),
    title=None,
    cmap=None,
    show=True,
    context="notebook",
    ax_ticks=True,
    ax_labels=True,
    scatter_kwargs={},
    fig_kwargs={},
):
    r"""
    Reduces multi-view data to two dimensions with MVMDS and draws the
    result as a single 2D scatter plot for easy visualization. This can be
    thought of as the multi-view analog of using PCA to decompose data and
    plot on principal components.

    See Also
    --------
    mvlearn.embed.MVMDS

    Parameters
    ----------
    Xs : list of array-likes or numpy.ndarray
        - Xs length: n_views
        - Xs[i] shape: (n_samples, n_features_i)

        The multi-view data to reduce to a single plot.

    labels : array-like, default=None
        Per-sample labels. When given, they are passed to the scatter call
        as ``c`` so that samples are colored by label.

    figsize : tuple, default=(5,5)
        Sets the figure size.

    title : string, default=None
        Sets the title of the figure.

    cmap : String, default=None
        Colormap argument for matplotlib.pyplot.scatter.

    show : boolean, default=True
        Shows the plots if true. Returns the objects otherwise.

    context : one of {'paper', 'notebook', 'talk', 'poster', None},
        default='notebook'
        Sets the seaborn plotting context.

    ax_ticks : boolean, default=True
        Whether to have tick marks on the axes.

    ax_labels : boolean, default=True
        Whether to label the axes ("Component 1" and "Component 2").

    scatter_kwargs : dict, default={}
        Additional matplotlib.pyplot.scatter arguments.

    fig_kwargs : dict, default={}
        Additional matplotlib.pyplot.figure arguments.

    Returns
    -------
    fig : figure object
        Only returned if `show=False`.

    Notes
    -----
    This function simply uses ``MVMDS`` with ``n_components=2`` to reduce
    arbitrarily many views of input data to 2-dimensions, then makes a
    scatter plot.

    .. figure:: /figures/quick_visualize.png
        :width: 250px
        :alt: Quick Visualization of Multi-view Data
        :align: center
    """
    views = check_Xs(Xs)
    embedding = MVMDS(n_components=2).fit_transform(views)

    fig = plt.figure(figsize=figsize, **fig_kwargs)
    sns.set_context(context)

    # `c` is only forwarded when labels were supplied; everything else is
    # shared between the labelled and unlabelled cases.
    color_args = {} if labels is None else {"c": labels}
    plt.scatter(
        embedding[:, 0],
        embedding[:, 1],
        cmap=cmap,
        **color_args,
        **scatter_kwargs,
    )

    if ax_labels:
        plt.xlabel("Component 1")
        plt.ylabel("Component 2")
    if not ax_ticks:
        plt.xticks([])
        plt.yticks([])

    if title is None:
        plt.tight_layout()
    else:
        # Reserve space at the top of the layout before adding the title.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.title(title)

    if not show:
        return fig
    plt.show()