def _test_rectangular_data_resnet_vae(self):
    """Train a ResNet VAE on dummy (22, 22) contact maps.

    Hyperparameters are given as ``max_len``/``nchars`` plus ``latent_dim=11``
    and ``dec_filters=22``. The leading underscore presumably keeps pytest
    from collecting this test.
    """
    from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams

    # A (22, 22) input is paired with latent_dim=11 and dec_filters=22.
    # A larger rectangular variant that was tried previously: (24, 524).
    rows, cols = 22, 22
    shape = (rows, cols)

    def make_loader():
        # Each loader wraps its own freshly generated dummy dataset.
        return DataLoader(TestVAE.DummyContactMap(shape),
                          batch_size=self.batch_size, shuffle=True)

    train_loader = make_loader()
    test_loader = make_loader()

    hparams = ResnetVAEHyperparams(max_len=rows, nchars=cols,
                                   latent_dim=11, dec_filters=22)
    autoencoder = VAE(shape, hparams, self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, shape)
    autoencoder.train(train_loader, test_loader, self.epochs)
def _test_resnet_vae_training(self):
    """Train a ResNet VAE on the real contact maps in ./test/cvae_input.h5.

    Both splits are loaded with ``squeeze=True`` and the model is built for a
    2-D (22, 22) input with ``latent_dim=11``. The leading underscore
    presumably keeps pytest from collecting this test.
    """
    from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams

    data_file = './test/cvae_input.h5'
    # squeeze=True presumably drops singleton dims so samples match the
    # 2-D input shape used below -- confirm against ContactMap.
    loaders = {
        split: DataLoader(TestVAE.ContactMap(data_file, split=split, squeeze=True),
                          batch_size=self.batch_size, shuffle=True)
        for split in ('train', 'valid')
    }

    shape = (22, 22)
    autoencoder = VAE(shape, ResnetVAEHyperparams(shape, latent_dim=11),
                      self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, shape)
    autoencoder.train(loaders['train'], loaders['valid'], self.epochs)
def test_load_checkpoint(self):
    """Resume training of the Fs-peptide symmetric VAE from a saved checkpoint.

    A fresh VAE is built from ``self.fs_peptide_hparams`` and trained with
    ``checkpoint=self.checkpoint_file`` (set in ``setup_class``) and
    ``epochs=4``. That value is presumably the total epoch count to reach
    after resuming. A ``CheckpointCallback`` with ``interval=2`` writes into
    ``self.checkpoint_dir``.
    """
    model_hparams = SymmetricVAEHyperparams(**self.fs_peptide_hparams)
    resume_callbacks = [
        CheckpointCallback(directory=self.checkpoint_dir, interval=2),
    ]
    autoencoder = VAE(self.input_shape, model_hparams, self.optimizer_hparams)

    print('loading checkpoint to resume training')
    autoencoder.train(self.train_loader, self.test_loader,
                      epochs=4,
                      checkpoint=self.checkpoint_file,
                      callbacks=resume_callbacks)
def test_sparse_data_symmetric_vae(self):
    """Train a VAE with ``self.hparams`` on sparse (22, 22) dummy contact maps."""
    shape = (22, 22)
    # Two independent sparse datasets: the first for training, the second
    # for validation.
    loaders = [
        DataLoader(TestVAE.SparseSquareContactMap(shape),
                   batch_size=self.batch_size, shuffle=True)
        for _ in range(2)
    ]

    autoencoder = VAE(shape, self.hparams, self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, shape)
    autoencoder.train(*loaders, self.epochs)
def _test_cvae_real_data(self):
    """Train a VAE with ``self.hparams`` on the contact maps in ./test/cvae_input.h5.

    The leading underscore presumably keeps pytest from collecting this test.
    """
    source = './test/cvae_input.h5'

    def build_loader(split):
        # One shuffled loader per named split of the HDF5 file.
        return DataLoader(TestVAE.ContactMap(source, split=split),
                          batch_size=self.batch_size, shuffle=True)

    train_loader = build_loader('train')
    test_loader = build_loader('valid')

    shape = self.input_shape
    autoencoder = VAE(shape, self.hparams, self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, shape)
    autoencoder.train(train_loader, test_loader, self.epochs)
def _test_rectangular_data_symmetric_vae(self):
    """Train a VAE with ``self.hparams`` on dummy (1, 22, 22) contact maps.

    The leading underscore presumably keeps pytest from collecting this test.
    """
    # A genuinely rectangular shape that was tried previously: (1, 24, 524).
    shape = (1, 22, 22)
    train_loader, test_loader = (
        DataLoader(TestVAE.DummyContactMap(shape),
                   batch_size=self.batch_size, shuffle=True)
        for _ in ('train', 'test')
    )

    autoencoder = VAE(shape, self.hparams, self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, shape)
    autoencoder.train(train_loader, test_loader, self.epochs)
def test_resnet_big_input(self):
    """Smoke-test a ResNet VAE on a very large (3768, 3768) dummy input.

    Train batches hold 2 samples and validation batches hold 5. Both datasets
    are built with ``size=5``.
    """
    from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams

    side = 3768
    # GB: 0.113582592 -- i.e. 3768 * 3768 elements at (presumably) 8 bytes each.
    big_shape = (side, side)

    def loader(batch_size):
        # Each loader gets its own freshly generated dummy dataset.
        return DataLoader(TestVAE.DummyContactMap(big_shape, size=5),
                          batch_size=batch_size, shuffle=True)

    train_loader = loader(2)
    test_loader = loader(5)

    hparams = ResnetVAEHyperparams(max_len=side, nchars=side, latent_dim=24)
    autoencoder = VAE(big_shape, hparams, self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, big_shape)
    autoencoder.train(train_loader, test_loader, self.epochs)
def _test_resnet_vae(self):
    """Build (without training) a ResNet VAE for a (1200, 1200) input and print it.

    The leading underscore presumably keeps pytest from collecting this test.
    """
    from molecules.ml.unsupervised.vae.resnet import ResnetVAEHyperparams

    side = 1200
    shape = (side, side)
    autoencoder = VAE(shape,
                      ResnetVAEHyperparams(shape, latent_dim=150),
                      self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, shape)
def test_pytorch_cvae_real_data(self):
    """Train the Fs-peptide symmetric VAE with loss, checkpoint and embedding callbacks.

    The loss callback's train and validation histories are printed after
    training.
    """
    autoencoder = VAE(self.input_shape,
                      SymmetricVAEHyperparams(**self.fs_peptide_hparams),
                      self.optimizer_hparams)
    print(autoencoder)
    summary(autoencoder.model, self.input_shape)

    losses = LossCallback()
    callbacks = [
        losses,
        CheckpointCallback(directory=self.checkpoint_dir, interval=2),
        EmbeddingCallback(TestCallback.DummyContactMap(self.input_shape)[:]),
    ]
    autoencoder.train(self.train_loader, self.test_loader, self.epochs,
                      callbacks=callbacks)

    for history in (losses.train_losses, losses.valid_losses):
        print(history)
def setup_class(self):
    """Shared fixtures for the callback tests, plus a saved checkpoint.

    Sets training constants, the Fs-peptide hyperparameters, the optimizer
    config and train/valid loaders over ``./test/cvae_input.h5``. It then
    trains a symmetric VAE with a ``CheckpointCallback`` so that
    ``self.checkpoint_file`` names a checkpoint that ``test_load_checkpoint``
    can resume from.

    Raises
    ------
    RuntimeError
        If training left no file in ``self.checkpoint_dir``.
    """
    self.epochs = 2
    self.batch_size = 100
    self.input_shape = (1, 22, 22)  # Use FSPeptide sized contact maps
    self.checkpoint_dir = os.path.join('.', 'test', 'test_checkpoint')
    # Optimal Fs-peptide params
    self.fs_peptide_hparams = {'filters': [100, 100, 100, 100],
                               'kernels': [5, 5, 5, 5],
                               'strides': [1, 2, 1, 1],
                               'affine_widths': [64],
                               'affine_dropouts': [0],
                               'latent_dim': 10}
    self.optimizer_hparams = OptimizerHyperparams(name='RMSprop',
                                                  hparams={'lr': 0.00001})

    path = './test/cvae_input.h5'
    self.train_loader = DataLoader(TestCallback.ContactMap(path, split='train'),
                                   batch_size=self.batch_size, shuffle=True)
    self.test_loader = DataLoader(TestCallback.ContactMap(path, split='valid'),
                                  batch_size=self.batch_size, shuffle=True)

    # Save checkpoint to test loading. The callback checkpoints with
    # interval=2 and the checkpoint is meant to be the one "after 2nd epoch",
    # so train for self.epochs (= 2). With only 1 epoch, whether any file is
    # written at all depends on how CheckpointCallback counts epochs.
    checkpoint_callback = CheckpointCallback(directory=self.checkpoint_dir,
                                             interval=2)
    hparams = SymmetricVAEHyperparams(**self.fs_peptide_hparams)
    vae = VAE(self.input_shape, hparams, self.optimizer_hparams)
    vae.train(self.train_loader, self.test_loader, epochs=self.epochs,
              callbacks=[checkpoint_callback])

    # Get checkpoint after 2nd epoch: the last file in lexicographic order.
    # NOTE(review): this method does not clear checkpoint_dir, so a file left
    # by an earlier run could sort last -- confirm the checkpoint naming scheme.
    pattern = os.path.join(self.checkpoint_dir, '*')
    checkpoints = sorted(glob.glob(pattern))
    if not checkpoints:
        raise RuntimeError(
            f'No checkpoint was written to {self.checkpoint_dir!r}; '
            'test_load_checkpoint needs one to resume from.')
    self.checkpoint_file = checkpoints[-1]
def _test_save_load_weights(self):
    """Round-trip VAE weights through ``save_weights``/``load_weights``.

    Trains one VAE and saves its encoder/decoder weights to ``self.enc_path``
    and ``self.dec_path``. It loads them into a freshly built VAE and asserts
    that both models' state dicts hold identical tensors under identical
    names. Finally it checks that the saved files load into standalone
    symmetric encoder and decoder modules.
    """
    vae1 = VAE(self.input_shape, self.hparams, self.optimizer_hparams)
    vae1.train(self.train_loader, self.test_loader, self.epochs)
    vae1.save_weights(self.enc_path, self.dec_path)

    vae2 = VAE(self.input_shape, self.hparams, self.optimizer_hparams)
    vae2.load_weights(self.enc_path, self.dec_path)

    # Checks that all model weights are exactly equal: compare each of vae1's
    # tensors with vae2's tensor of the same name (not with itself).
    state1 = vae1.model.state_dict()
    state2 = vae2.model.state_dict()
    # Matching key sets first, so a missing/extra parameter cannot be hidden
    # the way zip() would silently truncate.
    assert state1.keys() == state2.keys()
    for name, param1 in state1.items():
        assert torch.equal(param1, state2[name]), f'weights differ for {name!r}'

    # Checks that weights can be loaded into encoder/decoder modules separately
    from molecules.ml.unsupervised.vae.symmetric import (
        SymmetricEncoderConv2d, SymmetricDecoderConv2d)

    encoder = SymmetricEncoderConv2d(self.input_shape, self.hparams)
    encoder.load_weights(self.enc_path)
    decoder = SymmetricDecoderConv2d(self.input_shape, self.hparams,
                                     encoder.encoder_dim)
    decoder.load_weights(self.dec_path)
def _test_encode_decode(self):
    """Train briefly, then round-trip dummy data through ``encode``/``decode``.

    Only checks that both calls run; their outputs are not inspected. The
    leading underscore presumably keeps pytest from collecting this test.
    """
    autoencoder = VAE(self.input_shape, self.hparams, self.optimizer_hparams)
    autoencoder.train(self.train_loader, self.test_loader, self.epochs)

    samples = TestVAE.DummyContactMap(self.input_shape, 100)[:]
    autoencoder.decode(autoencoder.encode(samples))
def main(input_path, out_path, model_id, gpu, epochs, batch_size, model_type,
         latent_dim):
    """Example for training Fs-peptide with either Symmetric or Resnet VAE.

    Builds the chosen VAE and trains it with loss, checkpoint and embedding
    callbacks. It then writes the loss/embedding histories, both
    hyperparameter objects and the final encoder/decoder weights into
    ``<out_path>/model-<model_id>/``.

    Parameters
    ----------
    input_path : str
        Dataset path handed to ``ContactMap``; must provide ``'train'`` and
        ``'valid'`` splits.
    out_path : str
        Parent directory for the per-model output folder (created if missing).
    model_id
        Identifier embedded in the output folder name ``model-{model_id}``.
    gpu
        Accepted for interface compatibility; not used by this function.
    epochs : int
        Number of training epochs.
    batch_size : int
        Batch size for both the train and validation loaders.
    model_type : str
        ``'symmetric'`` or ``'resnet'``.
    latent_dim : int
        Latent size for the symmetric model. The resnet branch uses a fixed
        latent size of 11 and warns if a different value was requested.

    Raises
    ------
    ValueError
        If ``model_type`` is not ``'symmetric'`` or ``'resnet'``.
    """
    import os
    import warnings

    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if model_type not in ('symmetric', 'resnet'):
        raise ValueError(
            f"model_type must be 'symmetric' or 'resnet', got {model_type!r}")

    # Note: See SymmetricVAEHyperparams, ResnetVAEHyperparams class definitions
    # for hyperparameter options.
    if model_type == 'symmetric':
        # Optimal Fs-peptide params
        fs_peptide_hparams = {'filters': [100, 100, 100, 100],
                              'kernels': [5, 5, 5, 5],
                              'strides': [1, 2, 1, 1],
                              'affine_widths': [64],
                              'affine_dropouts': [0],
                              'latent_dim': latent_dim}
        input_shape = (1, 22, 22)
        squeeze = False
        hparams = SymmetricVAEHyperparams(**fs_peptide_hparams)
    else:  # model_type == 'resnet'
        input_shape = (22, 22)
        squeeze = True  # Specify no ones in training data shape
        # NOTE(review): latent size is pinned to 11, presumably chosen to fit
        # the (22, 22) input -- confirm before making it configurable. Warn
        # rather than silently ignore a different requested value.
        resnet_latent_dim = 11
        if latent_dim != resnet_latent_dim:
            warnings.warn(
                f"latent_dim={latent_dim} is ignored for model_type='resnet'; "
                f"using latent_dim={resnet_latent_dim}.")
        hparams = ResnetVAEHyperparams(input_shape, latent_dim=resnet_latent_dim)

    optimizer_hparams = OptimizerHyperparams(name='RMSprop',
                                             hparams={'lr': 0.00001})

    vae = VAE(input_shape, hparams, optimizer_hparams)

    # Display model
    print(vae)
    summary(vae.model, input_shape)

    # Load training and validation data. The validation dataset is built once
    # and shared by its loader and the embedding callback below.
    train_dataset = ContactMap(input_path, split='train', squeeze=squeeze)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataset = ContactMap(input_path, split='valid', squeeze=squeeze)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

    # For ease of training multiple models
    model_path = join(out_path, f'model-{model_id}')
    # The save() calls below write straight into model_path, so ensure it
    # exists before spending time on training.
    os.makedirs(model_path, exist_ok=True)

    # Optional callbacks
    loss_callback = LossCallback()
    checkpoint_callback = CheckpointCallback(directory=join(model_path, 'checkpoint'))
    embedding_callback = EmbeddingCallback(valid_dataset.sample)

    # Train model with callbacks
    vae.train(train_loader, valid_loader, epochs,
              callbacks=[loss_callback, checkpoint_callback, embedding_callback])

    # Save loss history and embeddings history to disk.
    loss_callback.save(join(model_path, 'loss.pt'))
    embedding_callback.save(join(model_path, 'embed.pt'))

    # Save hparams to disk
    hparams.save(join(model_path, 'model-hparams.pkl'))
    optimizer_hparams.save(join(model_path, 'optimizer-hparams.pkl'))

    # Save final model weights to disk
    vae.save_weights(join(model_path, 'encoder-weights.pt'),
                     join(model_path, 'decoder-weights.pt'))