Example #1
0
    def _test_rectangular_data_resnet_vae(self):
        """Train a ResNet VAE end-to-end on dummy (22, 22) contact maps.

        Uses the keyword form of ``ResnetVAEHyperparams`` (``max_len`` /
        ``nchars``). Per the original note, a (22, 22) input is paired with
        ``latent_dim=11`` and ``dec_filters=22``.
        """
        rows, cols = 22, 22
        shape = (rows, cols)

        def make_loader():
            # A fresh dummy dataset per loader, served in shuffled batches.
            return DataLoader(TestVAE.DummyContactMap(shape),
                              batch_size=self.batch_size, shuffle=True)

        train_loader = make_loader()
        test_loader = make_loader()

        from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams
        hparams = ResnetVAEHyperparams(max_len=rows, nchars=cols,
                                       latent_dim=11, dec_filters=22)

        model = VAE(shape, hparams, self.optimizer_hparams)
        print(model)
        summary(model.model, shape)
        model.train(train_loader, test_loader, self.epochs)
Example #2
0
    def _test_resnet_vae_training(self):
        """Train a ResNet VAE on the real contact maps in ./test/cvae_input.h5.

        Maps are loaded with ``squeeze=True`` so samples match the 2-D
        ``(22, 22)`` input shape the ResNet hyperparameters are built for
        (presumably by dropping a singleton dimension).
        """
        from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams

        h5_path = './test/cvae_input.h5'

        loaders = {}
        for split in ('train', 'valid'):
            dataset = TestVAE.ContactMap(h5_path, split=split, squeeze=True)
            loaders[split] = DataLoader(dataset,
                                        batch_size=self.batch_size,
                                        shuffle=True)

        shape = (22, 22)
        hparams = ResnetVAEHyperparams(shape, latent_dim=11)

        model = VAE(shape, hparams, self.optimizer_hparams)
        print(model)
        summary(model.model, shape)
        model.train(loaders['train'], loaders['valid'], self.epochs)
    def test_load_checkpoint(self):
        """Resume symmetric-VAE training from ``self.checkpoint_file``.

        Rebuilds the model from the Fs-peptide hyperparameters and trains up
        to a total of 4 epochs from the saved checkpoint. New checkpoints are
        written to ``self.checkpoint_dir`` every 2 epochs.
        """
        symmetric_hparams = SymmetricVAEHyperparams(**self.fs_peptide_hparams)
        saver = CheckpointCallback(directory=self.checkpoint_dir, interval=2)

        model = VAE(self.input_shape, symmetric_hparams, self.optimizer_hparams)

        print('loading checkpoint to resume training')
        # ``epochs`` is the total target: two more epochs on top of an
        # epoch-2 checkpoint gives epochs=4.
        model.train(self.train_loader, self.test_loader, epochs=4,
                    checkpoint=self.checkpoint_file, callbacks=[saver])
Example #4
0
    def test_sparse_data_symmetric_vae(self):
        """Train the VAE (``self.hparams``) on sparse dummy 22x22 square maps."""
        shape = (22, 22)

        # Two independent sparse datasets: first for training, then for testing.
        train_loader, test_loader = [
            DataLoader(TestVAE.SparseSquareContactMap(shape),
                       batch_size=self.batch_size, shuffle=True)
            for _ in range(2)
        ]

        model = VAE(shape, self.hparams, self.optimizer_hparams)
        print(model)
        summary(model.model, shape)
        model.train(train_loader, test_loader, self.epochs)
Example #5
0
    def _test_cvae_real_data(self):
        """Train the VAE on the train/valid splits of ./test/cvae_input.h5."""
        h5_path = './test/cvae_input.h5'

        def loader_for(split):
            # Shuffled batches over one split of the real contact-map file.
            data = TestVAE.ContactMap(h5_path, split=split)
            return DataLoader(data, batch_size=self.batch_size, shuffle=True)

        train_loader = loader_for('train')
        valid_loader = loader_for('valid')

        model = VAE(self.input_shape, self.hparams, self.optimizer_hparams)
        print(model)
        summary(model.model, self.input_shape)
        model.train(train_loader, valid_loader, self.epochs)
Example #6
0
    def _test_rectangular_data_symmetric_vae(self):
        """Train the VAE (``self.hparams``) on dummy maps of shape (1, 22, 22).

        The leading 1 is presumably a channel axis, matching the 3-D input
        shape used for the symmetric model elsewhere in these tests.
        """
        shape = (1, 22, 22)

        def make_loader():
            # Fresh dummy dataset per loader, served in shuffled batches.
            return DataLoader(TestVAE.DummyContactMap(shape),
                              batch_size=self.batch_size, shuffle=True)

        train_loader = make_loader()
        test_loader = make_loader()

        model = VAE(shape, self.hparams, self.optimizer_hparams)
        print(model)
        summary(model.model, shape)
        model.train(train_loader, test_loader, self.epochs)
Example #7
0
    def test_resnet_big_input(self):
        """Smoke-test the ResNet VAE on large (3768, 3768) dummy inputs.

        Each split holds 5 dummy maps. Training uses batches of 2, and
        validation puts all 5 maps in a single batch.
        """
        big_shape = (3768, 3768)  # ~0.113582592 GB per map, assuming 8-byte elements
        n_samples = 5

        train_loader = DataLoader(
            TestVAE.DummyContactMap(big_shape, size=n_samples),
            batch_size=2, shuffle=True)
        test_loader = DataLoader(
            TestVAE.DummyContactMap(big_shape, size=n_samples),
            batch_size=n_samples, shuffle=True)

        from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams
        max_len, nchars = big_shape
        hparams = ResnetVAEHyperparams(nchars=nchars, max_len=max_len,
                                       latent_dim=24)

        model = VAE(big_shape, hparams, self.optimizer_hparams)
        print(model)
        summary(model.model, big_shape)
        model.train(train_loader, test_loader, self.epochs)
Example #8
0
    def _test_resnet_vae(self):
        """Build a ResNet VAE for (1200, 1200) inputs and print its summary.

        This is a construction-only check: no data is loaded and no training is run.
        """
        from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams

        shape = (1200, 1200)
        model = VAE(shape, ResnetVAEHyperparams(shape, latent_dim=150),
                    self.optimizer_hparams)
        print(model)
        summary(model.model, shape)
    def test_pytorch_cvae_real_data(self):
        """Train a symmetric VAE with loss, checkpoint and embedding callbacks.

        Uses the loaders prepared in ``setup_class``. Embeddings are tracked on
        a dummy contact-map set, and the recorded train/validation losses are
        printed afterwards.
        """
        symmetric_hparams = SymmetricVAEHyperparams(**self.fs_peptide_hparams)

        model = VAE(self.input_shape, symmetric_hparams, self.optimizer_hparams)
        print(model)
        summary(model.model, self.input_shape)

        # Elements are built left-to-right, in the same order as before.
        callbacks = [
            LossCallback(),
            CheckpointCallback(directory=self.checkpoint_dir, interval=2),
            EmbeddingCallback(TestCallback.DummyContactMap(self.input_shape)[:]),
        ]

        model.train(self.train_loader, self.test_loader, self.epochs,
                    callbacks=callbacks)

        loss_tracker = callbacks[0]  # the LossCallback
        print(loss_tracker.train_losses)
        print(loss_tracker.valid_losses)
Example #10
0
    def setup_class(self):
        """Prepare shared fixtures and write a checkpoint for the load test.

        Sets epoch/batch settings, the optimal Fs-peptide symmetric-VAE
        hyperparameters and RMSprop optimizer settings. It then builds
        train/valid loaders from ./test/cvae_input.h5 and trains a symmetric
        VAE for ``self.epochs`` epochs with a CheckpointCallback, so that
        ``self.checkpoint_file`` names the newest checkpoint in
        ``self.checkpoint_dir``.

        Raises:
            FileNotFoundError: if no checkpoint file exists after training.
        """
        self.epochs = 2
        self.batch_size = 100
        self.input_shape = (1, 22, 22)  # Use FSPeptide sized contact maps
        self.checkpoint_dir = os.path.join('.', 'test', 'test_checkpoint')

        # Optimal Fs-peptide params
        self.fs_peptide_hparams = {'filters': [100, 100, 100, 100],
                                   'kernels': [5, 5, 5, 5],
                                   'strides': [1, 2, 1, 1],
                                   'affine_widths': [64],
                                   'affine_dropouts': [0],
                                   'latent_dim': 10}

        self.optimizer_hparams = OptimizerHyperparams(name='RMSprop',
                                                      hparams={'lr': 0.00001})

        path = './test/cvae_input.h5'

        self.train_loader = DataLoader(TestCallback.ContactMap(path, split='train'),
                                       batch_size=self.batch_size, shuffle=True)
        self.test_loader = DataLoader(TestCallback.ContactMap(path, split='valid'),
                                      batch_size=self.batch_size, shuffle=True)

        # Save checkpoint to test loading
        checkpoint_callback = CheckpointCallback(directory=self.checkpoint_dir,
                                                 interval=2)

        hparams = SymmetricVAEHyperparams(**self.fs_peptide_hparams)

        vae = VAE(self.input_shape, hparams, self.optimizer_hparams)

        # Train for self.epochs (2) epochs so the run reaches the 2nd epoch that
        # the checkpoint is meant to capture; the previous hard-coded epochs=1
        # stopped after the first epoch.
        vae.train(self.train_loader, self.test_loader, epochs=self.epochs,
                  callbacks=[checkpoint_callback])

        # Get the most recently written checkpoint. Choosing by modification
        # time rather than lexicographic name order avoids preferring an older
        # file left by a previous run, or mis-ordering unpadded epoch numbers.
        pattern = os.path.join(self.checkpoint_dir, '*')
        checkpoints = glob.glob(pattern)
        if not checkpoints:
            raise FileNotFoundError(
                f'No checkpoint was written to {self.checkpoint_dir!r}')
        self.checkpoint_file = max(checkpoints, key=os.path.getmtime)
Example #11
0
    def _test_save_load_weights(self):
        """Check that saved encoder/decoder weights round-trip exactly.

        Trains ``vae1`` and saves its weights to ``self.enc_path`` /
        ``self.dec_path``. It then loads them into a freshly built ``vae2`` and
        asserts that both models have the same parameter names and identical
        tensors. Finally it checks that the same files load into standalone
        symmetric encoder/decoder modules.
        """
        vae1 = VAE(self.input_shape, self.hparams, self.optimizer_hparams)
        vae1.train(self.train_loader, self.test_loader, self.epochs)
        vae1.save_weights(self.enc_path, self.dec_path)

        vae2 = VAE(self.input_shape, self.hparams, self.optimizer_hparams)
        vae2.load_weights(self.enc_path, self.dec_path)

        # Checks that all model weights are exactly equal. Compare vae1 against
        # vae2 (not against itself) and match entries by name, so a missing or
        # extra parameter fails instead of being silently dropped by zip().
        vae1_state = vae1.model.state_dict()
        vae2_state = vae2.model.state_dict()
        assert vae1_state.keys() == vae2_state.keys()
        for name, vae1_params in vae1_state.items():
            assert torch.equal(vae1_params, vae2_state[name]), name

        # Checks that weights can be loaded into encoder/decoder modules separately
        from molecules.ml.unsupervised.vae.symmetric import (
            SymmetricEncoderConv2d, SymmetricDecoderConv2d)

        encoder = SymmetricEncoderConv2d(self.input_shape, self.hparams)
        encoder.load_weights(self.enc_path)

        decoder = SymmetricDecoderConv2d(self.input_shape, self.hparams,
                                         encoder.encoder_dim)
        decoder.load_weights(self.dec_path)
Example #12
0
    def _test_encode_decode(self):
        """Train briefly, then encode 100 dummy maps and decode the embeddings.

        This is a smoke test: the reconstruction is computed but not inspected.
        """
        model = VAE(self.input_shape, self.hparams, self.optimizer_hparams)
        model.train(self.train_loader, self.test_loader, self.epochs)

        batch = TestVAE.DummyContactMap(self.input_shape, 100)[:]
        latent = model.encode(batch)

        # NOTE(review): nothing is asserted about the reconstruction; consider
        # checking at least that its batch size matches ``batch``.
        model.decode(latent)
Example #13
0
def main(input_path, out_path, model_id, gpu, epochs, batch_size, model_type, latent_dim):
    """Example for training Fs-peptide with either Symmetric or Resnet VAE.

    Builds the requested VAE and prints its architecture. It then trains on
    the 'train'/'valid' splits of ``input_path`` with loss, checkpoint and
    embedding callbacks. Loss/embedding history, hyperparameters and the
    final encoder/decoder weights are written under
    ``<out_path>/model-<model_id>``.

    Args:
        input_path: Path to the dataset read by ``ContactMap``.
        out_path: Parent directory of the ``model-<model_id>`` output folder.
        model_id: Identifier used to name the output folder.
        gpu: Not used by this function.
        epochs: Number of training epochs.
        batch_size: Batch size for both the train and valid loaders.
        model_type: Either ``'symmetric'`` or ``'resnet'``.
        latent_dim: Latent size for the symmetric model. NOTE(review): the
            resnet branch ignores this value and always uses ``latent_dim=11``.

    Raises:
        ValueError: if ``model_type`` is not ``'symmetric'`` or ``'resnet'``.
    """
    # Explicit check rather than ``assert`` so it is not stripped under
    # ``python -O``. Otherwise an unknown type would fail later with an
    # unbound ``input_shape``.
    if model_type not in ('symmetric', 'resnet'):
        raise ValueError(f"model_type must be 'symmetric' or 'resnet', "
                         f"got {model_type!r}")

    # Note: See SymmetricVAEHyperparams, ResnetVAEHyperparams class definitions
    #       for hyperparameter options.

    if model_type == 'symmetric':
        # Optimal Fs-peptide params
        fs_peptide_hparams = {'filters': [100, 100, 100, 100],
                              'kernels': [5, 5, 5, 5],
                              'strides': [1, 2, 1, 1],
                              'affine_widths': [64],
                              'affine_dropouts': [0],
                              'latent_dim': latent_dim}

        input_shape = (1, 22, 22)
        squeeze = False
        hparams = SymmetricVAEHyperparams(**fs_peptide_hparams)

    else:  # model_type == 'resnet'
        input_shape = (22, 22)
        squeeze = True  # Specify no ones in training data shape
        hparams = ResnetVAEHyperparams(input_shape, latent_dim=11)

    optimizer_hparams = OptimizerHyperparams(name='RMSprop', hparams={'lr': 0.00001})

    vae = VAE(input_shape, hparams, optimizer_hparams)

    # Display model
    print(vae)
    summary(vae.model, input_shape)

    # Load training and validation data
    train_loader = DataLoader(ContactMap(input_path, split='train', squeeze=squeeze),
                              batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(ContactMap(input_path, split='valid', squeeze=squeeze),
                              batch_size=batch_size, shuffle=True)

    # For ease of training multiple models
    model_path = join(out_path, f'model-{model_id}')

    # Optional callbacks
    loss_callback = LossCallback()
    checkpoint_callback = CheckpointCallback(directory=join(model_path, 'checkpoint'))
    embedding_callback = EmbeddingCallback(ContactMap(input_path, split='valid', squeeze=squeeze).sample)

    # Train model with callbacks
    vae.train(train_loader, valid_loader, epochs,
              callbacks=[loss_callback, checkpoint_callback, embedding_callback])

    # Save loss history and embeddings history to disk.
    loss_callback.save(join(model_path, 'loss.pt'))
    embedding_callback.save(join(model_path, 'embed.pt'))

    # Save hparams to disk
    hparams.save(join(model_path, 'model-hparams.pkl'))
    optimizer_hparams.save(join(model_path, 'optimizer-hparams.pkl'))

    # Save final model weights to disk
    vae.save_weights(join(model_path, 'encoder-weights.pt'),
                     join(model_path, 'decoder-weights.pt'))