# License: MIT
import itertools
import numpy as np
# XXX I would use a nested import for matplotlib to make it a soft dep
import matplotlib.pyplot as plt
import tqdm
try:
import torch
except ModuleNotFoundError as error:
msg = (f"Error: {error}. torch dependencies required for this function. " +
"Please consult the mvlearn installation instructions at " +
"https://github.com/mvlearn/mvlearn to correctly install " +
"torch dependency.")
raise ModuleNotFoundError(msg)
from .base import BaseEmbed
from ..utils.utils import check_Xs
class _FullyConnectedNet(torch.nn.Module):
r"""
General torch module for a fully connected neural network.
- input_size: number of nodes in the first layer
- num_hidden_layers: number of hidden layers
- hidden_size: number of nodes in each hidden layer
- embedding_size: number of nodes in the output layer.
All are ints. Each hidden layer has the same number of nodes.
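
    For instance (illustrative shapes only): ``_FullyConnectedNet(10, 64, 2, 5)``
    maps a ``(batch, 10)`` input through two 64-node hidden layers to a
    ``(batch, 5)`` output.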
"""
def __init__(
self, input_size, hidden_size, num_hidden_layers, embedding_size
):
super().__init__()
assert num_hidden_layers >= 0, "can't have negative hidden layer count"
assert hidden_size >= 1, "hidden size must involve >= 1 node"
assert embedding_size >= 1, "embedding size must involve >= 1 node"
self.layers = torch.nn.ModuleList()
if num_hidden_layers == 0:
self.layers.append(torch.nn.Linear(input_size, embedding_size))
else:
self.layers.append(torch.nn.Linear(input_size, hidden_size))
for i in range(num_hidden_layers - 1):
self.layers.append(torch.nn.Linear(hidden_size, hidden_size))
self.layers.append(torch.nn.Linear(hidden_size, embedding_size))
    def forward(self, x):
        # Forward pass for the network. PyTorch automatically derives the
        # backward pass via autograd.
        for layer in self.layers[:-1]:
            x = torch.sigmoid(layer(x))
        x = self.layers[-1](x)  # no activation on the last layer
        return x
def param_count(self):
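        # Total number of learnable scalars (weights and biases).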
return np.sum([np.prod(s.shape) for s in self.parameters()])
class SplitAE(BaseEmbed):
r"""
    Implements an autoencoder that creates an embedding of one view,
    View1, and from that embedding reconstructs both View1 and a second
    view, View2, as described in [#1Split]_.
Parameters
----------
hidden_size : int (default=64)
number of nodes in the hidden layers
num_hidden_layers : int (default=2)
number of hidden layers in each encoder or decoder net
embed_size : int (default=20)
size of the bottleneck vector in the autoencoder
training_epochs : int (default=10)
how many times the network trains on the full dataset
    batch_size : int (default=16)
batch size while training the network
learning_rate : float (default=0.001)
learning rate of the Adam optimizer
print_info : bool (default=False)
whether or not to print errors as the network trains.
print_graph : bool (default=True)
whether or not to graph training loss
Attributes
----------
view1_encoder_ : torch.nn.Module
the View1 embedding network as a PyTorch module
view1_decoder_ : torch.nn.Module
the View1 decoding network as a PyTorch module
view2_decoder_ : torch.nn.Module
the View2 decoding network as a PyTorch module
Raises
------
ModuleNotFoundError
        In order to run SplitAE, PyTorch and certain other optional
        dependencies must be installed. See the installation page for
        details.
Notes
-----
.. figure:: /figures/splitAE.png
:width: 250px
:alt: SplitAE diagram
:align: center
In this figure :math:`\textbf{x}` is View1 and :math:`\textbf{y}`
is View2
    Each encoder / decoder network is a fully connected neural net with
    parameter count (excluding bias terms) equal to:

    .. math::
        \left(\text{input_size} + \text{embed_size}\right) \cdot
        \text{hidden_size} +
        \left(\text{num_hidden_layers}-1\right) \cdot \text{hidden_size}^2
Where :math:`\text{input_size}` is the number of features in View1
or View2.
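
    For example, with :math:`\text{input_size} = 20`,
    :math:`\text{embed_size} = 5`, :math:`\text{hidden_size} = 64`, and
    :math:`\text{num_hidden_layers} = 2`, each net has
    :math:`(20 + 5) \cdot 64 + 64^2 = 5696` weights, plus biases.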
The loss that is reduced via gradient descent is:
.. math::
J = \left(p(f(\textbf{x})) - \textbf{x}\right)^2 +
\left(q(f(\textbf{x})) - \textbf{y}\right)^2
Where :math:`f` is the encoder, :math:`p` and :math:`q` are
the decoders, :math:`\textbf{x}` is View1,
and :math:`\textbf{y}` is View2.
References
----------
.. [#1Split] Wang, Weiran, et al. "On Deep Multi-View Representation
Learning." In Proceedings of the 32nd International Conference on
Machine Learning, 37:1083-1092, 2015.
For more extensive examples, see the ``tutorials`` for SplitAE in this
documentation.
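
    Examples
    --------
    A minimal usage sketch with random data (assuming SplitAE is exposed
    via ``mvlearn.embed``; shapes and hyperparameters are illustrative
    only):

    >>> import numpy as np
    >>> from mvlearn.embed import SplitAE
    >>> view1 = np.random.rand(100, 20)
    >>> view2 = np.random.rand(100, 8)
    >>> splitae = SplitAE(embed_size=5, training_epochs=2,
    ...                   print_info=False, print_graph=False)
    >>> splitae = splitae.fit([view1, view2])
    >>> embedding, view1_hat, view2_pred = splitae.transform([view1])
    >>> embedding.shape
    (100, 5)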
"""
def __init__(
self,
hidden_size=64,
num_hidden_layers=2,
embed_size=20,
training_epochs=10,
batch_size=16,
learning_rate=0.001,
print_info=False,
print_graph=True,
):
self.hidden_size = hidden_size
self.embed_size = embed_size
self.num_hidden_layers = num_hidden_layers
self.training_epochs = training_epochs
self.batch_size = batch_size
self.learning_rate = learning_rate
self.print_info = print_info
self.print_graph = print_graph
def fit(self, Xs, validation_Xs=None, y=None):
r"""
Given two views, create and train the autoencoder.
Parameters
----------
Xs : list of array-likes or numpy.ndarray.
- Xs[0] is View1 and Xs[1] is View2
            - Xs length: n_views, only 2 is currently supported for SplitAE.
- Xs[i] shape: (n_samples, n_features_i)
validation_Xs : list of array-likes or numpy.ndarray
            optional validation data in the same shape as Xs. If
:code:`print_info=True`, then validation error, calculated with
this data, will be printed as the network trains.
y : ignored
Included for API compliance.
"""
Xs = check_Xs(Xs, multiview=True, enforce_views=2)
        assert Xs[0].shape[0] >= self.batch_size, \
            "batch size must be <= the number of samples"
        assert self.batch_size > 0, "batch size must be positive"
        assert self.training_epochs >= 0, \
            "can't train for a negative number of epochs"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
view1 = torch.FloatTensor(Xs[0])
view2 = torch.FloatTensor(Xs[1])
self.view1_encoder_ = _FullyConnectedNet(
view1.shape[1], self.hidden_size,
self.num_hidden_layers, self.embed_size
).to(device)
self.view1_decoder_ = _FullyConnectedNet(
self.embed_size, self.hidden_size,
self.num_hidden_layers, view1.shape[1]
).to(device)
self.view2_decoder_ = _FullyConnectedNet(
self.embed_size, self.hidden_size,
self.num_hidden_layers, view2.shape[1]
).to(device)
        if self.print_info:
print(
"Parameter counts: \nview1_encoder: {:,}\nview1_decoder: {:,}"
"\nview2_decoder: {:,}".format(
self.view1_encoder_.param_count(),
self.view1_decoder_.param_count(),
self.view2_decoder_.param_count(),
)
)
parameters = [
self.view1_encoder_.parameters(),
self.view1_decoder_.parameters(),
self.view2_decoder_.parameters(),
]
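        # One Adam optimizer updates the encoder and both decoders jointly.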
optim = torch.optim.Adam(
itertools.chain(*parameters), lr=self.learning_rate
)
n_samples = view1.shape[0]
epoch_train_errors = []
epoch_test_errors = []
for epoch in tqdm.tqdm(
range(self.training_epochs), disable=(not self.print_info)
):
batch_errors = []
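            # Integer division drops any trailing samples that do not
            # fill a complete batch.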
for batch_num in range(n_samples // self.batch_size):
optim.zero_grad()
view1_batch = view1[
batch_num * self.batch_size:
(batch_num + 1) * self.batch_size
]
view2_batch = view2[
batch_num * self.batch_size:
(batch_num + 1) * self.batch_size
]
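                # Encode View1, then decode the shared embedding into
                # reconstructions of both views.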
embedding = self.view1_encoder_(view1_batch.to(device))
view1_reconstruction = self.view1_decoder_(embedding)
view2_reconstruction = self.view2_decoder_(embedding)
view1_error = torch.nn.MSELoss()(
view1_reconstruction, view1_batch.to(device)
)
view2_error = torch.nn.MSELoss()(
view2_reconstruction, view2_batch.to(device)
)
total_error = view1_error + view2_error
total_error.backward()
optim.step()
batch_errors.append(total_error.item())
if self.print_info:
print(
"Average train error during epoch {} was {}".format(
epoch, np.mean(batch_errors)
)
)
epoch_train_errors.append(np.mean(batch_errors))
if validation_Xs is not None:
test_error = self._test_error(validation_Xs)
if self.print_info:
print(
"Average test error during epoch {} was {}\n".format(
epoch, test_error
)
)
epoch_test_errors.append(test_error)
if self.print_graph:
plt.plot(epoch_train_errors, label="train error")
if validation_Xs is not None:
                plt.plot(epoch_test_errors, label="validation error")
plt.title("Errors during training")
plt.xlabel("Epoch")
plt.ylabel("Error")
plt.legend()
plt.show()
return self
def _test_error(self, Xs):
        # Estimates the network's error on Xs by evaluating one randomly
        # drawn batch of size batch_size (Xs must contain at least
        # batch_size samples).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_samples = Xs[0].shape[0]
validation_batch_size = self.batch_size
test_indices = np.random.choice(
n_samples, validation_batch_size, replace=False
)
view1_batch = torch.FloatTensor(Xs[0][test_indices])
view2_batch = torch.FloatTensor(Xs[1][test_indices])
with torch.no_grad():
embedding = self.view1_encoder_(view1_batch.to(device))
view1_reconstruction = self.view1_decoder_(embedding)
view2_reconstruction = self.view2_decoder_(embedding)
view1_error = torch.nn.MSELoss()(
view1_reconstruction, view1_batch.to(device)
)
view2_error = torch.nn.MSELoss()(
view2_reconstruction, view2_batch.to(device)
)
total_error = view1_error + view2_error
return total_error.item()
def transform(self, Xs):
r"""
Transform the given view with the trained autoencoder. Provide
a single view within a list.
Parameters
----------
Xs : a list of exactly one array-like, or an np.ndarray
Represents the View1 of some data. The array must have the same
number of columns (features) as the View1 presented
in the :code:`fit(...)` step.
- Xs length: 1
- Xs[0] shape: (n_samples, n_features_0)
Returns
        -------
embedding : np.ndarray of shape (n_samples, embedding_size)
the embedding of the View1 data
view1_reconstructions : np.ndarray of shape (n_samples, n_features_0)
the reconstructed View1
view2_prediction : np.ndarray of shape (n_samples, n_features_1)
the predicted View2
"""
Xs = check_Xs(Xs, enforce_views=1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
view1 = torch.FloatTensor(Xs[0])
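        # Inference only: disable gradient tracking to save memory.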
with torch.no_grad():
embedding = self.view1_encoder_(view1.to(device))
view1_reconstruction = self.view1_decoder_(embedding)
view2_prediction = self.view2_decoder_(embedding)
return (
embedding.cpu().numpy(),
view1_reconstruction.cpu().numpy(),
view2_prediction.cpu().numpy(),
)
def fit_transform(self, Xs, y=None):
r"""
        :code:`fit(Xs)` and then :code:`transform(Xs[:1])`.
        Note that this method embeds the same data that the autoencoder
        was trained on.
Parameters
----------
Xs : see :code:`fit(...)` Xs parameters
y : ignored
Included for API compliance.
Returns
        -------
See :code:`transform(...)` return values.
"""
self.fit(Xs)
return self.transform(Xs[:1])