Source code for mvlearn.model_selection.split

# Authors: Alexander Chang, Gavin Mischler
#
# License: MIT

import numpy as np
from sklearn import model_selection as ms


def train_test_split(*inputs, **options):
    r'''
    Splits multi-view data into random train and test subsets.

    This utility wraps the train_test_split function from
    sklearn.model_selection for ease of use. Every array contained in
    ``inputs`` is forwarded to sklearn in a single call, so all views and
    all accompanying arrays (e.g. labels) are split on the same sample
    indices, even when ``random_state`` is None or a RandomState instance.

    Parameters
    ----------
    inputs : sequence of indexables
        Allowed inputs are lists of numpy arrays, numpy arrays, lists,
        scipy-sparse matrices or pandas dataframes. A Python list or a
        3D numpy array is treated as multi-view data: each element along
        its first axis is split as a separate view.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the test split. If int,
        represents the absolute number of test samples. If None, the value
        is set to the complement of the train size. If train_size is also
        None, it will be set to 0.25.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the train split. If int,
        represents the absolute number of train samples. If None, the
        value is automatically set to the complement of the test size.

    random_state : int or RandomState instance, default=None
        Controls the shuffling applied to the data before applying the
        split. Pass an int for reproducible output across multiple
        function calls.

    shuffle : bool, default=True
        Whether or not to shuffle the data before splitting. If
        shuffle=False then stratify must be None.

    stratify : array-like, default=None
        If not None, data is split in a stratified fashion, using this as
        the class labels.

    Returns
    -------
    splitting : list, length=2*len(inputs)
        List containing the train-test splits of each of the inputs, in
        the order ``[in0_train, in0_test, in1_train, in1_test, ...]``.
        If a list of arrays or a 3D array is one of the inputs, each of
        its views is split and the results are gathered into one list of
        per-view arrays for training and one for testing.

    Raises
    ------
    ValueError
        If a multi-view input (list or 3D array) contains no views.
        Errors raised by sklearn's train_test_split for invalid options
        or inconsistent sample counts propagate unchanged.

    Notes
    -----
    All keyword arguments are passed through to
    sklearn.model_selection.train_test_split unchanged. Because the split
    is performed jointly, every array (including each view of a
    multi-view input) is expected to have the same number of samples.

    Examples
    --------
    >>> import numpy as np
    >>> from mvlearn.model_selection import train_test_split
    >>> Xs = np.arange(18).reshape((3, 3, 2))
    >>> y = np.arange(3)
    >>> # Print the data
    >>> for i in range(len(Xs)):
    ...     print('Xs[%d]' % i, Xs[i], sep='\n')
    Xs[0]
    [[0 1]
     [2 3]
     [4 5]]
    Xs[1]
    [[ 6  7]
     [ 8  9]
     [10 11]]
    Xs[2]
    [[12 13]
     [14 15]
     [16 17]]
    >>> print('y', y, sep='\n')
    y
    [0 1 2]
    >>> Xs_train, Xs_test, y_train, y_test = train_test_split(Xs, y,
    ...                                                       test_size=0.33,
    ...                                                       random_state=10)
    >>> # Print train set
    >>> for i in range(len(Xs_train)):
    ...     print('Xs_train[%d]' % i, Xs_train[i], sep='\n')
    Xs_train[0]
    [[4 5]
     [2 3]]
    Xs_train[1]
    [[10 11]
     [ 8  9]]
    Xs_train[2]
    [[16 17]
     [14 15]]
    >>> # Print test set
    >>> for i in range(len(Xs_test)):
    ...     print('Xs_test[%d]' % i, Xs_test[i], sep='\n')
    Xs_test[0]
    [[0 1]]
    Xs_test[1]
    [[6 7]]
    Xs_test[2]
    [[12 13]]
    >>> print(y_train)
    [2 1]
    >>> print(y_test)
    [0]
    '''
    # Every array that will be handed to sklearn, in call order.
    flat_arrays = []
    # One entry per input: None for a single array, otherwise the number
    # of views contributed by that multi-view input.
    layout = []

    for data in inputs:
        is_multiview = isinstance(data, list) or (
            isinstance(data, np.ndarray) and data.ndim == 3
        )
        if is_multiview:
            if len(data) == 0:
                # Previously surfaced as sklearn's own ValueError when the
                # empty input was unpacked into a call with no arrays.
                raise ValueError("At least one array required as input")
            flat_arrays.extend(data)
            layout.append(len(data))
        else:
            flat_arrays.append(data)
            layout.append(None)

    # No positional inputs at all: nothing to split.
    if not flat_arrays:
        return []

    # A single joint call guarantees that every array is split on the same
    # indices; separate calls would each draw a fresh permutation whenever
    # random_state is None or a stateful RandomState instance.
    parts = ms.train_test_split(*flat_arrays, **options)

    # sklearn returns [a0_train, a0_test, a1_train, a1_test, ...]; walk that
    # flat list and regroup it to mirror the structure of ``inputs``.
    splitting = []
    pos = 0
    for n_views in layout:
        if n_views is None:
            splitting.extend(parts[pos:pos + 2])
            pos += 2
        else:
            chunk = parts[pos:pos + 2 * n_views]
            splitting.append(chunk[::2])   # train split of every view
            splitting.append(chunk[1::2])  # test split of every view
            pos += 2 * n_views
    return splitting