Example #1
    def test_optimizer_hyperparams(self):
        from molecules.ml.hyperparams import OptimizerHyperparams, get_optimizer
        from torch import nn

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.layer = nn.Linear(5, 5)

            def forward(self, x):
                return self.layer(x)

        model = Model()
        name = 'RMSprop'
        hparams = {'lr': 0.9}

        optimizer_hparams = OptimizerHyperparams(name, hparams)

        optimizer_hparams.save(self.optimizer_fname)

        del optimizer_hparams
        import gc
        gc.collect()

        loaded_hparams = OptimizerHyperparams().load(self.optimizer_fname)

        optimizer = get_optimizer(model, loaded_hparams)
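
For reference, a minimal sketch of what get_optimizer presumably does: look up the optimizer class named by the hyperparameters on torch.optim and construct it over the model's parameters. The attribute names (.name, .hparams) are assumptions inferred from the constructor call above, not taken from the library source.

# Hypothetical sketch of get_optimizer; the real implementation lives in
# molecules.ml.hyperparams, and the .name/.hparams attributes are assumed.
import torch.optim

def get_optimizer_sketch(model, optimizer_hparams):
    # e.g. torch.optim.RMSprop(model.parameters(), lr=0.9)
    optimizer_cls = getattr(torch.optim, optimizer_hparams.name)
    return optimizer_cls(model.parameters(), **optimizer_hparams.hparams)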
Example #2
    def __init__(self,
                 input_shape,
                 hparams=SymmetricVAEHyperparams(),
                 optimizer_hparams=OptimizerHyperparams(),
                 loss=None,
                 gpu=None,
                 verbose=True):
        """
        Parameters
        ----------
        input_shape : tuple
            Shape of incoming data.
            Note: For use with SymmetricVAE use (1, num_residues, num_residues)
                  For use with ResnetVAE use (num_residues, num_residues)

        hparams : molecules.ml.hyperparams.Hyperparams
            Defines the model architecture hyperparameters. Currently implemented
            are SymmetricVAEHyperparams and ResnetVAEHyperparams.

        optimizer_hparams : molecules.ml.hyperparams.OptimizerHyperparams
            Defines the optimizer type and corresponding hyperparameters.

        loss : function, optional
            Defines an optional loss function with inputs (recon_x, x, mu, logvar)
            and output a torch loss.

        gpu : int, tuple, or None
            Device(s) on which the encoder and decoder will train.
            If None, the cuda GPU device if available, otherwise the CPU.
            If int, the specified GPU.
            If tuple, the first and second GPUs respectively.

        verbose : bool
            True prints training and validation loss to stdout.
        """

        hparams.validate()
        optimizer_hparams.validate()

        self.verbose = verbose

        # Tuple of encoder, decoder device
        self.device = Device(*self._configure_device(gpu))

        self.model = VAEModel(input_shape, hparams, self.device)

        # TODO: consider making optimizer_hparams a member variable
        # RMSprop with lr=0.001, alpha=0.9, epsilon=1e-08, decay=0.0
        self.optimizer = get_optimizer(self.model, optimizer_hparams)

        self.loss_fnc = vae_loss if loss is None else loss
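
The optional loss argument must match the (recon_x, x, mu, logvar) signature documented above. Below is a minimal sketch of such a function using the standard VAE objective (reconstruction error plus KL divergence); it is an assumption, not confirmed by the source, that the built-in vae_loss has this exact form.

# Hypothetical custom loss with the documented (recon_x, x, mu, logvar) signature.
import torch
import torch.nn.functional as F

def custom_vae_loss(recon_x, x, mu, logvar):
    # Reconstruction term over contact-map entries
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the unit-Gaussian prior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

# Usage: VAE(input_shape, hparams, optimizer_hparams, loss=custom_vae_loss)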
Example #3
    def setup_class(self):
        self.epochs = 1
        self.batch_size = 100
        self.input_shape = (1, 22, 22) # Use FSPeptide sized contact maps
        self.train_loader = DataLoader(TestVAE.DummyContactMap(self.input_shape),
                                       batch_size=self.batch_size, shuffle=True)
        self.test_loader = DataLoader(TestVAE.DummyContactMap(self.input_shape),
                                      batch_size=self.batch_size, shuffle=True)


        # Optimal Fs-peptide params
        fs_peptide_hparams = {'filters': [100, 100, 100, 100],
                              'kernels': [5, 5, 5, 5],
                              'strides': [1, 2, 1, 1],
                              'affine_widths': [64],
                              'affine_dropouts': [0],
                              'latent_dim': 10}

        diff_filters_hparams = {'filters': [100, 64, 64, 100],
                                'kernels': [5, 5, 3, 7],
                                'strides': [1, 2, 1, 2],
                                'affine_widths': [64],
                                'affine_dropouts': [0],
                                'latent_dim': 10}

        strided_hparams = {'filters': [100, 100, 100, 100],
                           'kernels': [5, 5, 5, 5],
                           'strides': [1, 2, 2, 1],
                           'affine_widths': [64],
                           'affine_dropouts': [0],
                           'latent_dim': 10}

        hparams = {'filters': [64, 64, 64, 64],
                   'kernels': [3, 3, 3, 3],
                   'strides': [1, 2, 1, 1],
                   'affine_widths': [128],
                   'affine_dropouts': [0],
                   'latent_dim': 3}

        self.hparams = SymmetricVAEHyperparams(**fs_peptide_hparams)
        self.optimizer_hparams = OptimizerHyperparams(name='RMSprop', hparams={'lr':0.00001})

        # For testing saving and loading weights
        self.enc_path = os.path.join('.', 'test', 'data', 'encoder-weights.pt')
        self.dec_path = os.path.join('.', 'test', 'data', 'decoder-weights.pt')
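
TestVAE.DummyContactMap is referenced above but not shown in this excerpt. A minimal stand-in that would satisfy the DataLoader calls might look like the following; the random-binary data and sample count are assumptions.

# Hypothetical stand-in for TestVAE.DummyContactMap (not shown above).
import torch
from torch.utils.data import Dataset

class DummyContactMap(Dataset):
    def __init__(self, input_shape, num_samples=1000):
        # Random binary "contact maps" of the requested shape
        self.data = torch.randint(0, 2, (num_samples, *input_shape)).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]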
Example #4
    def setup_class(self):
        self.epochs = 2
        self.batch_size = 100
        self.input_shape = (1, 22, 22) # Use FSPeptide sized contact maps
        self.checkpoint_dir = os.path.join('.', 'test', 'test_checkpoint')

        # Optimal Fs-peptide params
        self.fs_peptide_hparams = {'filters': [100, 100, 100, 100],
                                   'kernels': [5, 5, 5, 5],
                                   'strides': [1, 2, 1, 1],
                                   'affine_widths': [64],
                                   'affine_dropouts': [0],
                                   'latent_dim': 10}

        self.optimizer_hparams = OptimizerHyperparams(name='RMSprop', hparams={'lr':0.00001})

        path = './test/cvae_input.h5'

        self.train_loader = DataLoader(TestCallback.ContactMap(path, split='train'),
                                       batch_size=self.batch_size, shuffle=True)
        self.test_loader = DataLoader(TestCallback.ContactMap(path, split='valid'),
                                      batch_size=self.batch_size, shuffle=True)

        # Save checkpoint to test loading

        checkpoint_callback = CheckpointCallback(directory=self.checkpoint_dir,
                                                 interval=2)

        hparams = SymmetricVAEHyperparams(**self.fs_peptide_hparams)

        vae = VAE(self.input_shape, hparams, self.optimizer_hparams)

        vae.train(self.train_loader, self.test_loader, epochs=1,
                  callbacks=[checkpoint_callback])

        # Select the most recent checkpoint written during training
        file = os.path.join(self.checkpoint_dir, '*')
        self.checkpoint_file = sorted(glob.glob(file))[-1]
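
The latest checkpoint is selected by sorting the glob results. A minimal follow-up sanity check might only assert that the file round-trips through torch.load, since the structure of the checkpoint contents is not shown in this excerpt.

    # Hypothetical follow-up test; assumes only that CheckpointCallback wrote a
    # torch-serializable file.
    def test_checkpoint_is_loadable(self):
        import torch
        checkpoint = torch.load(self.checkpoint_file, map_location='cpu')
        assert checkpoint is not None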
Example #5
def main(input_path, out_path, model_id, gpu, epochs, batch_size, model_type, latent_dim):
    """Example for training Fs-peptide with either Symmetric or Resnet VAE."""

    assert model_type in ['symmetric', 'resnet']

    # Note: See SymmetricVAEHyperparams, ResnetVAEHyperparams class definitions
    #       for hyperparameter options. 

    if model_type == 'symmetric':
        # Optimal Fs-peptide params
        fs_peptide_hparams = {'filters': [100, 100, 100, 100],
                              'kernels': [5, 5, 5, 5],
                              'strides': [1, 2, 1, 1],
                              'affine_widths': [64],
                              'affine_dropouts': [0],
                              'latent_dim': latent_dim}

        input_shape = (1, 22, 22)
        squeeze = False
        hparams = SymmetricVAEHyperparams(**fs_peptide_hparams)

    elif model_type == 'resnet':
        input_shape = (22, 22)
        squeeze = True  # Training data shape has no leading channel dimension of 1
        hparams = ResnetVAEHyperparams(input_shape, latent_dim=11)

    optimizer_hparams = OptimizerHyperparams(name='RMSprop', hparams={'lr':0.00001})

    vae = VAE(input_shape, hparams, optimizer_hparams)

    # Display model
    print(vae)
    summary(vae.model, input_shape)

    # Load training and validation data
    train_loader = DataLoader(ContactMap(input_path, split='train', squeeze=squeeze),
                              batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(ContactMap(input_path, split='valid', squeeze=squeeze),
                              batch_size=batch_size, shuffle=True)

    # For ease of training multiple models
    model_path = join(out_path, f'model-{model_id}')

    # Optional callbacks
    loss_callback = LossCallback()
    checkpoint_callback = CheckpointCallback(directory=join(model_path, 'checkpoint'))
    embedding_callback = EmbeddingCallback(ContactMap(input_path, split='valid', squeeze=squeeze).sample)

    # Train model with callbacks
    vae.train(train_loader, valid_loader, epochs,
              callbacks=[loss_callback, checkpoint_callback, embedding_callback])

    # Save loss history and embeddings history to disk.
    loss_callback.save(join(model_path, 'loss.pt'))
    embedding_callback.save(join(model_path, 'embed.pt'))

    # Save hparams to disk
    hparams.save(join(model_path, 'model-hparams.pkl'))
    optimizer_hparams.save(join(model_path, 'optimizer-hparams.pkl'))

    # Save final model weights to disk
    vae.save_weights(join(model_path, 'encoder-weights.pt'),
                     join(model_path, 'decoder-weights.pt'))
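
A hypothetical direct invocation of main is sketched below; the real script presumably wires these arguments through a CLI parser not shown in this excerpt, and the paths and values are placeholders only.

# Hypothetical invocation (paths and values are illustrative, not from the source).
if __name__ == '__main__':
    main(input_path='./test/cvae_input.h5', out_path='./runs', model_id='0',
         gpu=0, epochs=10, batch_size=64, model_type='resnet', latent_dim=11)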
Example #6
def main(
    cfg: AAEModelConfig,
    encoder_gpu: int,
    generator_gpu: int,
    discriminator_gpu: int,
    distributed: bool,
):

    # Do some scaffolding for DDP
    comm_rank = 0
    comm_size = 1
    comm = None
    if distributed and dist.is_available():

        import mpi4py

        mpi4py.rc.initialize = False
        from mpi4py import MPI  # noqa: E402

        MPI.Init_thread()

        # get communicator: duplicate from comm world
        comm = MPI.COMM_WORLD.Dup()

        # now match ranks between the mpi comm and the nccl comm
        os.environ["WORLD_SIZE"] = str(comm.Get_size())
        os.environ["RANK"] = str(comm.Get_rank())

        # init pytorch
        dist.init_process_group(backend="nccl", init_method="env://")
        comm_rank = dist.get_rank()
        comm_size = dist.get_world_size()

    model_hparams = AAE3dHyperparams(
        num_features=cfg.num_features,
        encoder_filters=cfg.encoder_filters,
        encoder_kernel_sizes=cfg.encoder_kernel_sizes,
        generator_filters=cfg.generator_filters,
        discriminator_filters=cfg.discriminator_filters,
        latent_dim=cfg.latent_dim,
        encoder_relu_slope=cfg.encoder_relu_slope,
        generator_relu_slope=cfg.generator_relu_slope,
        discriminator_relu_slope=cfg.discriminator_relu_slope,
        use_encoder_bias=cfg.use_encoder_bias,
        use_generator_bias=cfg.use_generator_bias,
        use_discriminator_bias=cfg.use_discriminator_bias,
        noise_mu=cfg.noise_mu,
        noise_std=cfg.noise_std,
        lambda_rec=cfg.lambda_rec,
        lambda_gp=cfg.lambda_gp,
    )

    # optimizers
    optimizer_hparams = OptimizerHyperparams(name=cfg.optimizer_name,
                                             hparams={"lr": cfg.optimizer_lr})

    # Save hparams to disk, load initial weights, and create the virtual HDF5 file
    if comm_rank == 0:
        cfg.output_path.mkdir(exist_ok=True)
        model_hparams.save(cfg.output_path.joinpath("model-hparams.json"))
        optimizer_hparams.save(
            cfg.output_path.joinpath("optimizer-hparams.json"))
        init_weights = get_init_weights(cfg)
        h5_file, h5_files = get_h5_training_file(cfg)
        with open(cfg.output_path.joinpath("virtual-h5-metadata.json"),
                  "w") as f:
            json.dump(h5_files, f)

    else:
        init_weights, h5_file = None, None

    if comm_size > 1:
        init_weights = comm.bcast(init_weights, 0)
        h5_file = comm.bcast(h5_file, 0)

    # construct model
    aae = AAE3d(
        cfg.num_points,
        cfg.num_features,
        cfg.batch_size,
        model_hparams,
        optimizer_hparams,
        gpu=(encoder_gpu, generator_gpu, discriminator_gpu),
        init_weights=init_weights,
    )

    enc_device = torch.device(f"cuda:{encoder_gpu}")
    if comm_size > 1:
        if (encoder_gpu == generator_gpu) and (encoder_gpu
                                               == discriminator_gpu):
            aae.model = DDP(aae.model,
                            device_ids=[enc_device],
                            output_device=enc_device)
        else:
            aae.model = DDP(aae.model, device_ids=None, output_device=None)

    # set global default device
    torch.cuda.set_device(enc_device.index)

    if comm_rank == 0:
        # Display model
        print(aae)

    assert isinstance(h5_file, Path)
    # set up dataloaders
    train_dataset = get_dataset(
        cfg.dataset_location,
        h5_file,
        cfg.dataset_name,
        cfg.rmsd_name,
        cfg.fnc_name,
        cfg.num_points,
        cfg.num_features,
        split="train",
        shard_id=comm_rank,
        num_shards=comm_size,
        normalize="box",
        cms_transform=False,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=cfg.num_data_workers,
    )

    valid_dataset = get_dataset(
        cfg.dataset_location,
        h5_file,
        cfg.dataset_name,
        cfg.rmsd_name,
        cfg.fnc_name,
        cfg.num_points,
        cfg.num_features,
        split="valid",
        shard_id=comm_rank,
        num_shards=comm_size,
        normalize="box",
        cms_transform=False,
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=cfg.num_data_workers,
    )

    print(
        f"Using {len(train_dataset)} training and {len(valid_dataset)} validation samples."
    )

    wandb_config = setup_wandb(cfg, aae.model, comm_rank)

    # Optional callbacks
    loss_callback = LossCallback(cfg.output_path.joinpath("loss.json"),
                                 wandb_config=wandb_config,
                                 mpi_comm=comm)

    checkpoint_callback = CheckpointCallback(
        out_dir=cfg.output_path.joinpath("checkpoint"), mpi_comm=comm)

    save_callback = SaveEmbeddingsCallback(
        out_dir=cfg.output_path.joinpath("embeddings"),
        interval=cfg.embed_interval,
        sample_interval=cfg.sample_interval,
        mpi_comm=comm,
    )

    # TSNEPlotCallback requires SaveEmbeddingsCallback to run first
    tsne_callback = TSNEPlotCallback(
        out_dir=cfg.output_path.joinpath("embeddings"),
        projection_type="3d",
        target_perplexity=100,
        interval=cfg.tsne_interval,
        tsne_is_blocking=True,
        wandb_config=wandb_config,
        mpi_comm=comm,
    )

    # Train model with callbacks
    callbacks = [
        loss_callback,
        checkpoint_callback,
        save_callback,
        tsne_callback,
    ]

    # Optionally train for a different number of
    # epochs on the first DDMD iteration
    if cfg.stage_idx == 0:
        epochs = cfg.initial_epochs
    else:
        epochs = cfg.epochs

    aae.train(train_loader, valid_loader, epochs, callbacks=callbacks)

    # Save loss history to disk.
    if comm_rank == 0:
        loss_callback.save(cfg.output_path.joinpath("loss.json"))

        # Save final model weights to disk
        aae.save_weights(
            cfg.output_path.joinpath("encoder-weights.pt"),
            cfg.output_path.joinpath("generator-weights.pt"),
            cfg.output_path.joinpath("discriminator-weights.pt"),
        )
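
A hypothetical single-process entry point for the function above; how the AAEModelConfig instance is populated (YAML file, CLI flags, etc.) lies outside this excerpt, so it is left as a parameter.

# Hypothetical single-GPU wrapper (no DDP); cfg construction is not shown here.
def run_single_gpu(cfg: AAEModelConfig):
    main(cfg, encoder_gpu=0, generator_gpu=0, discriminator_gpu=0, distributed=False)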