Example #1
def breast_cancer(x_train, y_train, x_val, y_val, params):
    print("Iteration parameters: ", params)

    def weights_init_uniform_rule(m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            n = m.in_features
            y = 1.0 / np.sqrt(n)
            m.weight.data.uniform_(-y, y)
            m.bias.data.fill_(0)

    manager = DataManager.from_numpy(train_inputs=x_train,
                                     train_labels=y_train,
                                     batch_size=params["batch_size"],
                                     validation_inputs=x_val,
                                     validation_labels=y_val)
    net = BreastCancerNet(n_feature=x_train.shape[1],
                          first_neuron=params["first_neuron"],
                          second_neuron=params["second_neuron"],
                          dropout=params["dropout"])
    net.apply(weights_init_uniform_rule)
    net.init_history()
    model = DeepLearningInterface(model=net,
                                  optimizer_name=params["optimizer_name"],
                                  learning_rate=params["learning_rate"],
                                  loss_name=params["loss_name"],
                                  metrics=["accuracy"])
    model.add_observer("after_epoch", update_talos_history)
    model.training(manager=manager,
                   nb_epochs=params["epochs"],
                   checkpointdir=None,
                   fold_index=0,
                   with_validation=True)
    return net, net.parameters()
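
# Hedged usage sketch, not part of the original example: a single parameter
# point with the keys the function above expects (the values are illustrative
# assumptions), assuming x_train, y_train, x_val, y_val are numpy arrays
# already in scope. With talos, such a function is normally driven by a scan
# over a grid of these values.
params = {"batch_size": 32, "first_neuron": 64, "second_neuron": 32,
          "dropout": 0.25, "optimizer_name": "Adam", "learning_rate": 1e-3,
          "loss_name": "BCELoss", "epochs": 10}
net, weights = breast_cancer(x_train, y_train, x_val, y_val, params)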
Example #2
    def setUp(self):
        """ Setup test.
        """
        data = fetch_cifar(datasetdir="/tmp/cifar")
        self.manager = DataManager(input_path=data.input_path,
                                   labels=["label"],
                                   metadata_path=data.metadata_path,
                                   number_of_folds=10,
                                   batch_size=10,
                                   stratify_label="category",
                                   test_size=0.1,
                                   sample_size=0.01)

        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)

            def forward(self, x):
                x = self.pool(func.relu(self.conv1(x)))
                x = self.pool(func.relu(self.conv2(x)))
                x = x.view(-1, 16 * 5 * 5)
                x = func.relu(self.fc1(x))
                x = func.relu(self.fc2(x))
                x = self.fc3(x)
                return x

        self.cl = DeepLearningInterface(model=Net(),
                                        optimizer_name="SGD",
                                        momentum=0.9,
                                        learning_rate=0.001,
                                        loss_name="CrossEntropyLoss",
                                        metrics=["accuracy"])
Example #3
# You may need to change the 'datasetdir' parameter.

import os
import numpy as np
from pynet.datasets import DataManager, fetch_echocardiography
from pynet.plotting import plot_data
from pynet.utils import setup_logging

setup_logging(level="info")

data = fetch_echocardiography(datasetdir="/tmp/echocardiography")
manager = DataManager(input_path=data.input_path,
                      metadata_path=data.metadata_path,
                      output_path=data.output_path,
                      number_of_folds=2,
                      stratify_label="label",
                      sampler="random",
                      batch_size=10,
                      test_size=0.1,
                      sample_size=0.2)
dataset = manager["test"]
print(dataset.inputs.shape, dataset.outputs.shape)
data = np.concatenate((dataset.inputs, dataset.outputs), axis=1)
plot_data(data, nb_samples=5)

#############################################################################
# Optimisation
# ------------
#
# From the available models load the UNet, and start the training.
# You may need to change the 'outdir' parameter.
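
# A hedged sketch of the training step this section leads into, mirroring the
# DeepLearningInterface usage of Examples #1 and #2. The import path and the
# one-convolution stand-in network are assumptions, not the UNet the text
# refers to.

import torch
import torch.nn as nn
from pynet.interfaces import DeepLearningInterface

class TinySegNet(nn.Module):
    """ Minimal placeholder network so the sketch stays self-contained. """
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))

model = DeepLearningInterface(model=TinySegNet(),
                              optimizer_name="Adam",
                              learning_rate=1e-4,
                              loss_name="BCELoss",
                              metrics=["accuracy"])
model.training(manager=manager,
               nb_epochs=2,
               checkpointdir="/tmp/echocardiography",
               fold_index=0,
               with_validation=True)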
Example #4
File: dataset.py Project: rlouiset/pynet
#############################################################################
# Import a pynet dataset
# ----------------------
#
# Use a fetcher to retrieve some data and use the generic interface to import
# and split this dataset into train, test and validation sets.
# You may need to change the 'datasetdir' parameter.

from pynet.datasets import DataManager, fetch_cifar

data = fetch_cifar(datasetdir="/tmp/cifar")
manager = DataManager(input_path=data.input_path,
                      labels=["label"],
                      metadata_path=data.metadata_path,
                      number_of_folds=10,
                      batch_size=50,
                      stratify_label="category",
                      test_size=0.1)

#############################################################################
# We now have a test set and multiple folds of train-validation datasets that
# can be used to train our network with cross-validation.

import numpy as np
from pynet.plotting import plot_data

print("Nb folds: ", manager.number_of_folds)
dataloader = manager.get_dataloader(train=True,
                                    validation=False,
                                    test=False,
                                    fold_index=0)
Example #5
    [np.round(np.mean(labels[tri])) for tri in ico_triangles])
plot_trisurf(fig, ax, ico_vertices, ico_triangles, tri_texture)
data = np.zeros((N_SAMPLES, N_CLASSES, len(labels)), dtype=float)
for klass in (0, 1):
    k_indices = np.argwhere(labels == klass).squeeze()
    for loc, scale in SAMPLES[klass]:
        data[:, klass, k_indices] = np.random.normal(loc=loc,
                                                     scale=scale,
                                                     size=len(k_indices))
labels = np.ones((N_SAMPLES, 1)) * labels
print("dataset: x {0} - y {1}".format(data.shape, labels.shape))

# Create data manager
manager = DataManager.from_numpy(train_inputs=data,
                                 train_labels=labels,
                                 test_inputs=data,
                                 test_labels=labels,
                                 batch_size=BATCH_SIZE)

# Create model
net_params = pynet.NetParameters(in_order=ICO_ORDER,
                                 in_channels=2,
                                 out_channels=N_CLASSES,
                                 depth=3,
                                 start_filts=32,
                                 conv_mode="1ring",
                                 up_mode="transpose",
                                 cachedir=os.path.join(OUTDIR, "cache"))
model = SphericalUNetEncoder(net_params,
                             optimizer_name="SGD",
                             learning_rate=0.1,
Example #6
# A validation step is a useful way to avoid overfitting. At each epoch, the
# neural network is evaluated on the validation set, but not trained on it.
# If the validation loss starts to grow, it means that the network is
# overfitting the training set, and that it is time to stop the training.
#
# The following cell creates stratified test, train, and validation loaders.

from pynet.datasets import fetch_orientation
from pynet.datasets import DataManager

data = fetch_orientation(
    datasetdir="/tmp/orientation",
    flatten=True)
manager = DataManager(
    input_path=data.input_path,
    labels=["label"],
    metadata_path=data.metadata_path,
    number_of_folds=10,
    batch_size=1000,
    stratify_label="label",
    test_size=0.1)


#############################################################################
# Displaying some images of the test dataset.

from pynet.plotting import plot_data

dataset = manager["test"]
sample = dataset.inputs.reshape(-1, data.height, data.width)
sample = np.expand_dims(sample, axis=1)
plot_data(sample, nb_samples=5)
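
#############################################################################
# Hedged sketch, not from the original example: the early-stopping rule the
# comments above describe, as a plain function over per-epoch validation
# losses. Training stops once the loss has not improved for `patience`
# consecutive epochs.

def should_stop(val_losses, patience=5):
    """ Return True when the last `patience` epochs brought no improvement.
    """
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return min(val_losses[-patience:]) >= best_before

assert not should_stop([1.0, 0.8, 0.7, 0.6], patience=3)
assert should_stop([1.0, 0.5, 0.6, 0.62, 0.61], patience=3)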
Example #7
# Show example noisy training data with the signatures applied. The subtle
# differences are hard to spot by eye, but the cross row and column above
# perturb the matrices below according to the y weights. The title of each
# panel shows how strongly each signature is weighted.
plt.figure(figsize=(16, 4))
for idx in range(3):
    plt.subplot(1, 3, idx + 1)
    plt.imshow(np.squeeze(x_train[idx]), interpolation="None")
    plt.colorbar()
    plt.title(y_train[idx])

manager = DataManager.from_numpy(train_inputs=x_train,
                                 train_labels=y_train,
                                 validation_inputs=x_valid,
                                 validation_labels=y_valid,
                                 test_inputs=x_test,
                                 test_labels=y_test,
                                 batch_size=128,
                                 continuous_labels=True)
interfaces = pynet.get_interfaces()["graph"]
net_params = pynet.NetParameters(input_shape=(90, 90),
                                 in_channels=1,
                                 num_classes=2,
                                 nb_e2e=32,
                                 nb_e2n=64,
                                 nb_n2g=30,
                                 dropout=0.5,
                                 leaky_alpha=0.1,
                                 twice_e2e=False,
                                 dense_sml=True)
my_loss = pynet.get_tools()["losses"]["MSELoss"]()
Example #8
the activation map of the last convolutional layer in our model.

Load the data
-------------

Load some images and apply the ImageNet transformation.
You may need to change the 'datasetdir' parameter.
"""

from pynet.datasets import DataManager, fetch_gradcam
from pynet.plotting import plot_data

data = fetch_gradcam(datasetdir="/tmp/gradcam")
manager = DataManager(input_path=data.input_path,
                      metadata_path=data.metadata_path,
                      number_of_folds=2,
                      batch_size=5,
                      test_size=1)
dataset = manager["test"]
print(dataset.inputs.shape)
plot_data(dataset.inputs, nb_samples=5, random=False, rgb=True)

#############################################################################
# Explore different architectures
# -------------------------------
#
# Let's automate this procedure for different networks.
# We need to reload the data for the inception network.
# You may need to change the 'datasetdir' parameter.

import os
Example #9
masker = MultiNiftiMasker(mask_img=mask_img, standardize=True)
masker.fit()
if not os.path.isfile(DATAFILE):
    y = np.concatenate(masker.transform(func_filenames), axis=0)
    print(y.shape)
    np.save(DATAFILE, y)
else:
    y = np.load(DATAFILE)
iterator = masker.inverse_transform(y).get_fdata()
iterator = iterator.transpose((3, 0, 1, 2))
iterator = np.expand_dims(iterator, axis=1)
print(iterator.shape)

# Data iterator
manager = DataManager.from_numpy(train_inputs=iterator,
                                 batch_size=BATCH_SIZE,
                                 add_input=True)

# Create model
name = "ResAENet"
model_weights = os.path.join(WORKDIR, "checkpoint_" + name,
                             "model_0_epoch_{0}.pth".format(EPOCH))
if os.path.isfile(model_weights):
    pretrained = model_weights
else:
    pretrained = None
params = NetParameters(input_shape=(61, 73, 61),
                       cardinality=1,
                       layers=[3, 4, 6, 3],
                       n_channels_in=1,
                       decode=True)
Example #10
data = np.concatenate(data, axis=0)
labels = np.asarray(labels)
print("dataset: x {0} - y {1}".format(data.shape, labels.shape))


# Create data manager
manager = DataManager.from_numpy(
    train_inputs=data, train_labels=np.zeros(labels.shape),
    batch_size=BATCH_SIZE)


class FKmeans(object):
    def __init__(self, n_clusters):
        self.n_clusters = n_clusters

    def fit(self, data):
        n_data, d = data.shape
        self.clus = faiss.Kmeans(d, self.n_clusters)
        self.clus.seed = np.random.randint(1234)
        self.clus.niter = 20
        self.clus.max_points_per_centroid = 10000000
        self.clus.train(data)
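
    def predict(self, data):
        # Hypothetical completion of this truncated excerpt: assign each
        # sample to its nearest learned centroid through the faiss index
        # built by train().
        _, assignments = self.clus.index.search(data, 1)
        return assignments.ravel()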
Example #11
"""

#############################################################################
# Import the dataset
# ------------------
#
# You may need to change the 'datasetdir' parameter.

import numpy as np
from pynet.datasets import DataManager, fetch_echocardiography
from pynet.plotting import plot_data

data = fetch_echocardiography(datasetdir="/tmp/echocardiography")
manager = DataManager(input_path=data.input_path,
                      metadata_path=data.metadata_path,
                      output_path=data.output_path,
                      number_of_folds=10,
                      batch_size=10,
                      test_size=0.1)
dataset = manager["test"]
data = np.concatenate((dataset.inputs, dataset.outputs), axis=1)
plot_data(data, nb_samples=5)

#############################################################################
# Optimisation
# ------------
#
# From the available models load the UNet, and start the training.
# You may need to change the 'outdir' parameter.

import os
import torch
Example #12
from pynet.losses import get_vae_loss
from pynet.models.vae.utils import (reconstruct_traverse, make_mosaic_img,
                                    add_labels)

# Global parameters
WDIR = "/tmp/beta_vae_disentangling"
BATCH_SIZE = 64
N_EPOCHS = 30
ADAM_LR = 5e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DISPLAY = False

# Load the data
dataset = DSprites(WDIR)
manager = DataManager.from_dataset(train_dataset=dataset,
                                   batch_size=BATCH_SIZE,
                                   sampler="random")

# Test different losses

loss_params = {
    "betah": {
        "beta": 4,
        "steps_anneal": 0,
        "use_mse": True
    },
    "betab": {
        "C_init": 0.5,
        "C_fin": 25,
        "gamma": 100,
        "steps_anneal": 100000,
Example #13
board = Board(port=8097, host="http://localhost", env="data-augmentation")
compose_transforms = Transformer()
compose_transforms.register(flip,
                            probability=0.5,
                            axis=0,
                            apply_to=["input", "output"])
compose_transforms.register(add_blur,
                            probability=1,
                            sigma=4,
                            apply_to=["input"])
manager = DataManager(input_path=data.input_path,
                      metadata_path=data.metadata_path,
                      output_path=data.output_path,
                      number_of_folds=2,
                      batch_size=2,
                      test_size=0.1,
                      sample_size=0.1,
                      sampler=None,
                      add_input=True,
                      data_augmentation_transforms=[compose_transforms])
loaders = manager.get_dataloader(train=True, validation=False, fold_index=0)
for dataitem in loaders.train:
    print("-" * 50)
    print(dataitem.inputs.shape, dataitem.outputs.shape, dataitem.labels)
    images = [
        dataitem.inputs[0, 0].numpy(), dataitem.inputs[0, 1].numpy(),
        dataitem.outputs[0, 0].numpy(), dataitem.outputs[0, 1].numpy(),
        dataitem.outputs[0, 4].numpy(), dataitem.outputs[0, 5].numpy()
    ]
    images = np.asarray(images)
    images = np.expand_dims(images, axis=1)
Example #14
import os
import sys
from pynet.datasets import DataManager, fetch_height_biobank
from pynet.utils import setup_logging

# This example cannot run in CI: it accesses NS intra filesystems.
if "CI_MODE" in os.environ:
    sys.exit(0)

setup_logging(level="info")

data = fetch_height_biobank(datasetdir="/neurospin/tmp/height_bb")
manager = DataManager(input_path=data.input_path,
                      labels=["Height"],
                      metadata_path=data.metadata_path,
                      number_of_folds=2,
                      batch_size=5,
                      test_size=0.2,
                      continuous_labels=True)

#############################################################################
# Basic inspection

import numpy as np
import matplotlib.pyplot as plt

train_dataset = manager["train"][0]
X_train = train_dataset.inputs[train_dataset.indices]
y_train = train_dataset.labels[train_dataset.indices]
test_dataset = manager["test"]
X_test = test_dataset.inputs[test_dataset.indices]
Example #15
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
losses = pynet.get_tools(tool_name="losses")
setup_logging(level="info")

#############################################################################
# Kang dataset
# ------------
#
# Fetch & load the Kang dataset:

data, trainset, testset, membership_mask = fetch_kang(datasetdir=datasetdir,
                                                      random_state=0)
gtpath = os.path.join(datasetdir, "kang_recons.h5ad")
manager = DataManager.from_numpy(train_inputs=trainset,
                                 validation_inputs=testset,
                                 test_inputs=data.X,
                                 batch_size=batch_size,
                                 sampler="random",
                                 add_input=True)

#############################################################################
# Training
# --------
#
# Create/train the model:

if checkpointdir is not None:
    weights_filename = os.path.join(
        checkpointdir, "model_0_epoch_{0}.pth".format(nb_epochs - 1))
params = NetParameters(membership_mask=membership_mask,
                       latent_dim=latent_dim,
                       hidden_layers=[12],
Example #16
#
# Use the fetcher of the pynet package.

from pynet.datasets import DataManager, fetch_brats
from pynet.plotting import plot_data
from pynet.transforms import RandomFlipDimensions, Offset

data = fetch_brats(datasetdir="/neurospin/nsap/datasets/brats")
manager = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    output_path=data.output_path,
    projection_labels=None,
    number_of_folds=10,
    batch_size=1,
    stratify_label="grade",
    #input_transforms=[
    #    RandomFlipDimensions(ndims=3, proba=0.5, with_channels=True),
    #    Offset(nb_channels=4, factor=0.1)],
    sampler="random",
    add_input=True,
    test_size=0.1,
    pin_memory=True)
dataset = manager["test"][:1]
print(dataset.inputs.shape, dataset.outputs.shape)
plot_data(dataset.inputs, channel=1, nb_samples=5)
plot_data(dataset.outputs, channel=1, nb_samples=5)

#############################################################################
# Training
# --------
Example #17
import matplotlib.pyplot as plt

setup_logging(level="info")
logger = logging.getLogger("pynet")

use_toy = False
dtype = "all"

data = fetch_impac(datasetdir="/neurospin/nsap/datasets/impac",
                   mode="train",
                   dtype=dtype)
nb_features = data.nb_features
manager = DataManager(input_path=data.input_path,
                      labels=["participants_asd"],
                      metadata_path=data.metadata_path,
                      number_of_folds=3,
                      batch_size=128,
                      sampler="random",
                      test_size=2,
                      sample_size=1)

if use_toy:
    toy_data = {}
    nb_features = 50
    for name, nb_samples in (("train", 1000), ("test", 2)):
        x1 = torch.randn(nb_samples, 50)
        x2 = torch.randn(nb_samples, 50) + 1.5
        x = torch.cat([x1, x2], dim=0)
        y1 = torch.zeros(nb_samples, 1)
        y2 = torch.ones(nb_samples, 1)
        y = torch.cat([y1, y2], dim=0)
        toy_data[name] = (x, y)
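
    # Hypothetical continuation of this truncated excerpt: wrap the toy
    # tensors in a data manager, mirroring the DataManager.from_numpy usage
    # of the other examples.
    manager = DataManager.from_numpy(
        train_inputs=toy_data["train"][0].numpy(),
        train_labels=toy_data["train"][1].numpy(),
        test_inputs=toy_data["test"][0].numpy(),
        test_labels=toy_data["test"][1].numpy(),
        batch_size=128)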
Example #18
        if not isinstance(imgtype, list):
            imgtype = [imgtype]
        imgtype = [typemap[key] for key in imgtype]
    transformed_data = []
    for channel_id in range(len(data)):
        if channel_id not in imgtype:
            continue
        arr = data[channel_id]
        transformed_data.append(downsample(arr, scale=3))
    return np.asarray(transformed_data)


manager = DataManager(input_path=data.input_path,
                      metadata_path=data.metadata_path,
                      stratify_label="grade",
                      number_of_folds=10,
                      batch_size=batch_size,
                      test_size=0,
                      input_transforms=[transformer],
                      sample_size=0.2)

########################
# Loss
# ----


def calc_gradient_penalty(model, x, x_gen, w=10):
    """ WGAN-GP gradient penalty.
    """
    assert (x.size() == x_gen.size()), "Real and sampled sizes do not match."
    alpha_size = tuple((len(x), *(1, ) * (x.dim() - 1)))
    alpha_t = torch.cuda.FloatTensor if x.is_cuda else torch.Tensor
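    # Hypothetical completion of this truncated excerpt, following the
    # standard WGAN-GP formulation: sample points on segments between real
    # and generated data, then penalise output gradients whose L2 norm
    # deviates from 1.
    alpha = alpha_t(*alpha_size).uniform_()
    x_hat = (alpha * x.data + (1 - alpha) * x_gen.data).requires_grad_(True)
    out = model(x_hat)
    grad_x_hat = torch.autograd.grad(
        outputs=out, inputs=x_hat, grad_outputs=torch.ones_like(out),
        create_graph=True)[0]
    return w * ((grad_x_hat.view(len(x), -1).norm(2, dim=1) - 1) ** 2).mean()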
Example #19
#
# A validation step is a useful way to avoid overfitting. At each epoch, the
# neural network is evaluated on the validation set, but not trained on it.
# If the validation loss starts to grow, it means that the network is
# overfitting the training set, and that it is time to stop the training.
#
# The following cell creates stratified test, train, and validation loaders.

from pynet.datasets import fetch_orientation
from pynet.datasets import DataManager

data = fetch_orientation(datasetdir="/tmp/orientation", flatten=True)
manager = DataManager(input_path=data.input_path,
                      labels=["label"],
                      metadata_path=data.metadata_path,
                      number_of_folds=10,
                      batch_size=1000,
                      stratify_label="label",
                      test_size=0.1,
                      sample_size=(1 if "CI_MODE" not in os.environ else 0.1))

#############################################################################
# Displaying some images of the test dataset.

from pynet.plotting import plot_data

dataset = manager["test"]
sample = dataset.inputs.reshape(-1, data.height, data.width)
sample = np.expand_dims(sample, axis=1)
plot_data(sample, nb_samples=5)

#############################################################################
Example #20
from pynet.plotting import Board, update_board


#############################################################################
# The model will be trained on MNIST, the handwritten digits dataset. The
# input is an image in R^(28x28).

def flatten(arr):
    return arr.flatten()

data = fetch_minst(datasetdir="/neurospin/nsap/datasets/minst")
manager = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    stratify_label="label",
    number_of_folds=10,
    batch_size=64,
    test_size=0,
    input_transforms=[flatten],
    add_input=True,
    sample_size=0.05)


#############################################################################
# The Model
# ---------
#
# The model is composed of two sub-networks:
#
# 1. Given x (image), encode it into a distribution over the latent space -
#    referred to as Q(z|x).
# 2. Given z in latent space (code representation of an image), decode it into
Example #21
import os
import numpy as np
from pynet.datasets import DataManager, fetch_echocardiography
from pynet.plotting import plot_data
from pynet.utils import setup_logging

setup_logging(level="info")

data = fetch_echocardiography(
    datasetdir="/tmp/echocardiography")
manager = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    output_path=data.output_path,
    number_of_folds=2,
    stratify_label="label",
    sampler="weighted_random",
    batch_size=10,
    test_size=0.1,
    sample_size=(1 if "CI_MODE" not in os.environ else 0.05))
dataset = manager["test"]
print(dataset.inputs.shape, dataset.outputs.shape)
data = np.concatenate((dataset.inputs, dataset.outputs), axis=1)
plot_data(data, nb_samples=5)


#############################################################################
# Optimisation
# ------------
#
# From the available models load the UNet, and start the training.
Example #22
Load some data.
You may need to change the 'datasetdir' parameter.
"""

import os
from pynet.datasets import DataManager, fetch_genomic_pred
from pynet.utils import setup_logging

setup_logging(level="info")

data = fetch_genomic_pred(datasetdir="/tmp/genomic_pred")
manager = DataManager(input_path=data.input_path,
                      labels=["env0"],
                      metadata_path=data.metadata_path,
                      number_of_folds=2,
                      batch_size=5,
                      test_size=0.2,
                      continuous_labels=True)

#############################################################################
# Basic inspection

import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

train_dataset = manager["train"][0]
X_train = train_dataset.inputs[train_dataset.indices]
y_train = train_dataset.labels[train_dataset.indices]
test_dataset = manager["test"]
Example #23
from pynet.plotting import plot_history
from pynet.history import History
from pynet.losses import MSELoss, NCCLoss, RCNetLoss
import matplotlib.pyplot as plt

setup_logging(level="debug")
logger = logging.getLogger("pynet")

outdir = "/neurospin/nsap/tmp/registration"
data = fetch_registration(
    datasetdir=outdir)
manager = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    number_of_folds=10,
    batch_size=1,
    sampler="random",
    #stratify_label="centers",
    test_size=0.1,
    add_input=True,
    sample_size=1)

#############################################################################
# Training
# --------
#
# From the available models load the VoxelMorphRegister, VTNetRegister or
# ADDNet and start the training.
# Note that the first two estimate a non-linear deformation and require
# the input data to be affinely registered. The ADDNet estimates an affine
# transform. We will see in the next section how to combine them in an
# efficient way.
Example #24
from pynet.history import History
from pynet.losses import MSELoss, NCCLoss, RCNetLoss, PCCLoss
from pynet.plotting import Board, update_board
import matplotlib.pyplot as plt

setup_logging(level="debug")
logger = logging.getLogger("pynet")
losses = pynet.get_tools(tool_name="losses")

outdir = "/neurospin/nsap/tmp/registration"
data = fetch_registration(datasetdir=outdir)
manager = DataManager(input_path=data.input_path,
                      metadata_path=data.metadata_path,
                      number_of_folds=2,
                      batch_size=8,
                      sampler="random",
                      stratify_label="studies",
                      projection_labels={"studies": ["abide"]},
                      test_size=0.1,
                      add_input=True,
                      sample_size=0.1)

#############################################################################
# Training
# --------
#
# From the available models load the VoxelMorphRegister, VTNetRegister or
# ADDNet and start the training.
# Note that the first two estimate a non-linear deformation and require
# the input data to be affinely registered. The ADDNet estimates an affine
# transform. We will see in the next section how to combine them in an
# efficient way.
Example #25
ds_train = SyntheticDataset(n_samples=n_samples,
                            lat_dim=true_lat_dims,
                            n_feats=n_feats,
                            n_classes=n_classes,
                            train=True,
                            snr=snr)
ds_val = SyntheticDataset(n_samples=n_samples,
                          lat_dim=true_lat_dims,
                          n_feats=n_feats,
                          n_classes=n_classes,
                          train=False,
                          snr=snr)
image_datasets = {"train": ds_train, "val": ds_val}
manager = DataManager.from_numpy(train_inputs=ds_train.data,
                                 train_outputs=None,
                                 train_labels=ds_train.labels,
                                 validation_inputs=ds_val.data,
                                 validation_outputs=None,
                                 validation_labels=ds_val.labels,
                                 batch_size=batch_size,
                                 sampler="random",
                                 add_input=True)
print("- datasets:", image_datasets)
print("- shapes:", ds_train.data.shape, ds_val.data.shape)

# Display generated data
method = manifold.TSNE(n_components=2, init="pca", random_state=0)
y_train = method.fit_transform(ds_train.data)
y_val = method.fit_transform(ds_val.data)
fig, axs = plt.subplots(nrows=3, ncols=2)
for cnt, (name, y, labels) in enumerate(
    (("train", y_train, ds_train.labels), ("val", y_val, ds_val.labels))):
    colors = labels.astype(float)