Beispiel #1
0
def test_linear():
    """Check that a linear DCCA trained by gradient descent matches linear CCA.

    Two single-output ``LinearEncoder`` networks are wrapped in ``DCCA``,
    trained for 50 epochs with Adam on ``loader``, and the resulting score is
    compared (to two decimal places) with the score of a ``CCA`` model fitted
    directly on ``(X, Y)``.

    ``loader``, ``X`` and ``Y`` are module-level fixtures defined outside this
    function; the feature sizes 10 and 12 presumably match the column counts
    of ``X`` and ``Y`` -- TODO confirm against the fixture definitions.
    """
    encoder_1 = architectures.LinearEncoder(latent_dims=1, feature_size=10)
    encoder_2 = architectures.LinearEncoder(latent_dims=1, feature_size=12)
    dcca = DCCA(latent_dims=1, encoders=[encoder_1, encoder_2])
    optimizer = optim.Adam(dcca.parameters(), lr=1e-1)
    dcca = CCALightning(dcca, optimizer=optimizer)
    trainer = pl.Trainer(max_epochs=50, enable_checkpointing=False)
    trainer.fit(dcca, loader)
    cca = CCA().fit((X, Y))
    # Linear encoders trained with Adam should match vanilla linear CCA.
    # The numpy helper raises AssertionError itself, so it is called directly:
    # wrapping it in ``assert ... is None`` adds nothing and would make the
    # whole comparison disappear when Python runs with -O.
    np.testing.assert_array_almost_equal(
        cca.score((X, Y)), trainer.model.score(loader), decimal=2)
from multiviewdata.torchdatasets import SplitMNISTDataset
from cca_zoo.deepmodels import DCCA, CCALightning, get_dataloaders, architectures

# Sizes of the training and validation subsets taken from the MNIST train split.
n_train = 500
n_val = 100

# Load the full split once, then slice two disjoint index ranges out of it:
# the first n_train samples for training, the next n_val for validation.
mnist_full = SplitMNISTDataset(
    root="",
    mnist_type="MNIST",
    train=True,
    download=True,
)
train_indices = np.arange(n_train)
val_indices = np.arange(n_train, n_train + n_val)
train_dataset = Subset(mnist_full, train_indices)
val_dataset = Subset(mnist_full, val_indices)
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset)

# Latent dimensionality shared by every model.
latent_dims = 2
# Number of training epochs for the deep models.
epochs = 10

# TODO: add custom architectures, schedulers, etc. to showcase them.
# One encoder per view; each view has 392 input features (presumably one half
# of a flattened 28x28 image -- confirm against SplitMNISTDataset).
encoder_1, encoder_2 = (
    architectures.Encoder(latent_dims=latent_dims, feature_size=392)
    for _ in range(2)
)

# Deep CCA: build the model, attach optimizer and LR schedule, then train.
dcca = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])
optimizer = optim.Adam(dcca.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1)
# The name is rebound to the Lightning wrapper, which is what gets trained.
dcca = CCALightning(dcca, optimizer=optimizer, lr_scheduler=scheduler)
trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
trainer.fit(dcca, train_loader, val_dataloaders=val_loader)
Beispiel #3
0
def test_DCCA_methods():
    """Check that each deep CCA variant out-scores linear CCA on training data.

    A linear ``CCA`` model is fitted on ``(X, Y)`` and its summed score is
    used as the bar. Six deep models (DCCA_NOI, DCCA_SDL, DCCA with the CCA,
    GCCA and MCCA objectives, and BarlowTwins) are then trained for 100
    epochs each on ``train_loader``; for every one the summed score on
    ``train_loader`` must be strictly greater than the linear bar.

    ``X``, ``Y``, ``train_dataset``, ``train_loader`` and ``val_loader`` are
    module-level fixtures defined outside this function; the encoder feature
    sizes 10 and 12 presumably match the two views' widths -- TODO confirm.
    """
    N = len(train_dataset)
    latent_dims = 2
    epochs = 100
    cca = CCA(latent_dims=latent_dims).fit((X, Y))
    # ``cca`` is never refit below, so its summed score is computed once and
    # reused as the threshold for every deep variant.
    baseline = cca.score((X, Y)).sum()

    def make_encoders():
        """Return a fresh, untrained pair of encoders (one per view)."""
        return [
            architectures.Encoder(latent_dims=latent_dims, feature_size=10),
            architectures.Encoder(latent_dims=latent_dims, feature_size=12),
        ]

    def fit_and_check(model, optimizer, val_dataloaders=None):
        """Train ``model`` and require its summed score to beat ``baseline``."""
        lightning_model = CCALightning(model, optimizer=optimizer)
        # Checkpointing is disabled for every variant so the test does not
        # write checkpoint files to disk.
        trainer = pl.Trainer(max_epochs=epochs,
                             log_every_n_steps=10,
                             enable_checkpointing=False)
        trainer.fit(lightning_model,
                    train_loader,
                    val_dataloaders=val_dataloaders)
        # assert_array_less raises AssertionError itself (strict ``<``), so
        # it is called directly instead of inside ``assert ... is None``,
        # which would be stripped entirely under ``python -O``.
        np.testing.assert_array_less(
            baseline, trainer.model.score(train_loader).sum())

    # DCCA_NOI
    dcca_noi = DCCA_NOI(latent_dims, N, encoders=make_encoders(), rho=0)
    fit_and_check(dcca_noi, optim.Adam(dcca_noi.parameters(), lr=1e-2))

    # Soft Decorrelation (stochastic Decorrelation Loss)
    sdl = DCCA_SDL(latent_dims, N, encoders=make_encoders(), lam=1e-3)
    fit_and_check(sdl, optim.SGD(sdl.parameters(), lr=1e-1))

    # DCCA (the only variant that is also given a validation loader)
    dcca = DCCA(
        latent_dims=latent_dims,
        encoders=make_encoders(),
        objective=objectives.CCA,
    )
    fit_and_check(dcca,
                  optim.SGD(dcca.parameters(), lr=1e-1),
                  val_dataloaders=val_loader)

    # DGCCA
    dgcca = DCCA(
        latent_dims=latent_dims,
        encoders=make_encoders(),
        objective=objectives.GCCA,
    )
    fit_and_check(dgcca, optim.SGD(dgcca.parameters(), lr=1e-2))

    # DMCCA
    dmcca = DCCA(
        latent_dims=latent_dims,
        encoders=make_encoders(),
        objective=objectives.MCCA,
    )
    fit_and_check(dmcca, optim.SGD(dmcca.parameters(), lr=1e-2))

    # Barlow Twins
    barlowtwins = BarlowTwins(
        latent_dims=latent_dims,
        encoders=make_encoders(),
    )
    fit_and_check(barlowtwins, optim.SGD(barlowtwins.parameters(), lr=1e-2))