Example no. 1
class GAN(nn.Module):
    def __init__(self, args: Namespace, prediction_model: nn.Module,
                 encoder: nn.Module):
        super(GAN, self).__init__()
        self.args = args
        self.prediction_model = prediction_model
        self.encoder = encoder

        self.hidden_size = args.hidden_size
        self.disc_input_size = args.hidden_size + args.output_size
        self.act_func = self.encoder.encoder.act_func

        self.netD = nn.Sequential(
            # note: doesn't support jtnn or additional features right now
            nn.Linear(self.disc_input_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, 1))
        self.beta = args.wgan_beta

        # The optimizers don't really belong here, but we keep them here
        # so that we don't clutter the code for the other optimizers.
        self.optimizerG = Adam(self.encoder.parameters(),
                               lr=args.init_lr[0] * args.gan_lr_mult,
                               betas=(0, 0.9))
        self.optimizerD = Adam(self.netD.parameters(),
                               lr=args.init_lr[0] * args.gan_lr_mult,
                               betas=(0, 0.9))

        self.use_scheduler = args.gan_use_scheduler
        if self.use_scheduler:
            self.schedulerG = NoamLR(
                self.optimizerG,
                warmup_epochs=args.warmup_epochs,
                total_epochs=args.epochs,
                steps_per_epoch=args.train_data_length // args.batch_size,
                init_lr=args.init_lr[0] * args.gan_lr_mult,
                max_lr=args.max_lr[0] * args.gan_lr_mult,
                final_lr=args.final_lr[0] * args.gan_lr_mult)
            self.schedulerD = NoamLR(
                self.optimizerD,
                warmup_epochs=args.warmup_epochs,
                total_epochs=args.epochs,
                steps_per_epoch=(args.train_data_length // args.batch_size) *
                args.gan_d_per_g,
                init_lr=args.init_lr[0] * args.gan_lr_mult,
                max_lr=args.max_lr[0] * args.gan_lr_mult,
                final_lr=args.final_lr[0] * args.gan_lr_mult)
Example no. 2
def build_lr_scheduler(optimizer: Optimizer,
                       warmup_epochs: int,
                       train_data_size: int,
                       batch_size: int,
                       init_lr: float,
                       max_lr: float,
                       final_lr: float,
                       epochs: int,
                       num_lrs: int,
                       total_epochs: List[int] = None) -> _LRScheduler:
    """
    Builds a learning rate scheduler.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param warmup_epochs: The number of epochs during which the learning rate is linearly increased from init_lr to max_lr.
    :param train_data_size: The number of training data points (used to compute the number of steps per epoch).
    :param batch_size: The batch size.
    :param init_lr: The initial learning rate.
    :param max_lr: The maximum learning rate, reached at the end of the warm-up.
    :param final_lr: The final learning rate.
    :param epochs: The number of epochs, used for total_epochs when it is not provided.
    :param num_lrs: The number of learning rates being scheduled.
    :param total_epochs: The total number of epochs for which the model will be run.
    :return: An initialized learning rate scheduler.
    """
    # Learning rate scheduler
    return NoamLR(
        optimizer=optimizer,
        warmup_epochs=[warmup_epochs],
        total_epochs=total_epochs or [epochs] * num_lrs,
        steps_per_epoch=train_data_size // batch_size,
        init_lr=[init_lr],
        max_lr=[max_lr],
        final_lr=[final_lr]
    )
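
A minimal usage sketch for the signature above. The model, data sizes, and learning-rate values are illustrative assumptions, not taken from the example:

import torch
from torch import nn
from torch.optim import Adam

# hypothetical model and hyperparameters, for illustration only
model = nn.Linear(16, 1)
optimizer = Adam(model.parameters(), lr=1e-4)
scheduler = build_lr_scheduler(
    optimizer=optimizer,
    warmup_epochs=2,
    train_data_size=10000,
    batch_size=50,
    init_lr=1e-4,
    max_lr=1e-3,
    final_lr=1e-4,
    epochs=30,
    num_lrs=1,
)

# NoamLR is stepped once per optimizer step (i.e. per batch), not once per epoch
for _ in range(10000 // 50):
    optimizer.zero_grad()
    loss = model(torch.randn(50, 16)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    scheduler.step()
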
Example no. 3
def build_lr_scheduler(optimizer: Optimizer,
                       args: Namespace,
                       total_epochs: List[int] = None) -> _LRScheduler:
    """
    Builds a learning rate scheduler.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param args: Arguments.
    :return: An initialized learning rate scheduler.
    """
    # Learning rate scheduler
    if args.scheduler == 'noam':
        return NoamLR(optimizer=optimizer,
                      warmup_epochs=args.warmup_epochs,
                      total_epochs=total_epochs
                      or [args.epochs] * args.num_lrs,
                      steps_per_epoch=args.train_data_size // args.batch_size,
                      init_lr=args.init_lr,
                      max_lr=args.max_lr,
                      final_lr=args.final_lr)

    if args.scheduler == 'none':
        return MockLR(optimizer=optimizer, lr=args.init_lr)

    if args.scheduler == 'decay':
        return ExponentialLR(optimizer, args.lr_decay_rate)

    raise ValueError('Learning rate scheduler "{}" not supported.'.format(
        args.scheduler))
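
The three branches return schedulers with different stepping conventions. A hedged sketch of a caller that handles all of them; the isinstance dispatch and the helper arguments are assumptions, not taken from the example:

def fit(model, optimizer, scheduler, train_loader, loss_func, epochs):
    # assumed convention: the Noam schedule advances once per batch,
    # exponential decay once per epoch, and MockLR never changes the rate
    for epoch in range(epochs):
        for features, targets in train_loader:
            optimizer.zero_grad()
            loss = loss_func(model(features), targets)
            loss.backward()
            optimizer.step()
            if isinstance(scheduler, NoamLR):
                scheduler.step()
        if isinstance(scheduler, ExponentialLR):
            scheduler.step()
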
Example no. 4
def build_lr_scheduler(optimizer: Optimizer,
                       args: Namespace,
                       total_epochs: List[int] = None,
                       scheduler_name: str = 'Noam') -> _LRScheduler:
    """
    Builds a learning rate scheduler.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param args: Arguments.
    :param total_epochs: The total number of epochs for which the model will be run.
    :return: An initialized learning rate scheduler.
    """
    # Learning rate scheduler
    if scheduler_name == 'Noam':
        return NoamLR(optimizer=optimizer,
                      warmup_epochs=[args.warmup_epochs],
                      total_epochs=total_epochs
                      or [args.epochs] * args.num_lrs,
                      steps_per_epoch=args.train_data_size // args.batch_size,
                      init_lr=[args.init_lr],
                      max_lr=[args.max_lr],
                      final_lr=[args.final_lr])
    elif scheduler_name == 'Sinexp':
        return SinexpLR(
            optimizer=optimizer,
            warmup_epochs=[args.warmup_epochs],
            total_epochs=total_epochs or [args.epochs] * args.num_lrs,
            steps_per_epoch=args.train_data_size // args.batch_size,
            init_lr=[args.max_lr],
            final_lr=[args.final_lr])

    raise ValueError('Learning rate scheduler "{}" not supported.'.format(scheduler_name))
Example no. 5
def build_lr_scheduler(optimizer: Optimizer,
                       args: TrainArgs,
                       total_epochs: List[int] = None) -> _LRScheduler:
    """
    Builds a PyTorch learning rate scheduler.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing learning rate arguments.
    :param total_epochs: The total number of epochs for which the model will be run.
    :return: An initialized learning rate scheduler.
    """
    # Learning rate scheduler
    return NoamLR(optimizer=optimizer,
                  warmup_epochs=[args.warmup_epochs],
                  total_epochs=total_epochs or [args.epochs] * args.num_lrs,
                  steps_per_epoch=args.train_data_size // args.batch_size,
                  init_lr=[args.init_lr],
                  max_lr=[args.max_lr],
                  final_lr=[args.final_lr])
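
For intuition, a sketch of the per-step learning rate this style of NoamLR is generally understood to produce: a linear warm-up from init_lr to max_lr over warmup_epochs * steps_per_epoch optimizer steps, followed by exponential decay towards final_lr. This is a reconstruction under that assumption, not the scheduler's actual code:

def noam_lr(step, init_lr, max_lr, final_lr, warmup_epochs, total_epochs, steps_per_epoch):
    """Assumed shape of the schedule: linear warm-up, then exponential decay."""
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    if step <= warmup_steps:
        # linear ramp from init_lr up to max_lr
        return init_lr + step * (max_lr - init_lr) / warmup_steps
    # exponential decay from max_lr towards final_lr over the remaining steps
    gamma = (final_lr / max_lr) ** (1 / (total_steps - warmup_steps))
    return max_lr * gamma ** (step - warmup_steps)
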
Example no. 6
def build_lr_scheduler(optimizer: Optimizer,
                       args: TrainArgs,
                       total_epochs: List[int] = None) -> _LRScheduler:
    """
    Builds a learning rate scheduler.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param args: Arguments.
    :param total_epochs: The total number of epochs for which the model will be run.
    :return: An initialized learning rate scheduler.
    """
    num_param_groups = len(optimizer.param_groups)
    return NoamLR(optimizer=optimizer,
                  warmup_epochs=[args.warmup_epochs] * num_param_groups,
                  total_epochs=[args.noam_epochs] * num_param_groups,
                  steps_per_epoch=args.train_data_size // args.batch_size,
                  init_lr=[args.init_lr] * num_param_groups,
                  max_lr=[args.max_lr] * num_param_groups,
                  final_lr=[args.final_lr] * num_param_groups)
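
Because every schedule argument is replicated per parameter group, this variant also works with optimizers that define several groups; each group then follows the same Noam schedule. A hypothetical two-group setup (the encoder/ffn attribute names are assumptions):

from torch.optim import Adam

# model and args come from the surrounding training script
optimizer = Adam([
    {'params': model.encoder.parameters(), 'lr': args.init_lr},
    {'params': model.ffn.parameters(), 'lr': args.init_lr},
])
scheduler = build_lr_scheduler(optimizer, args)
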
Example no. 7
def train_dun(
    model,
    train_data,
    val_data,
    num_workers,
    cache,
    loss_func,
    metric_func,
    scaler,
    features_scaler,
    args,
    save_dir
    ):

    # data loaders for dun
    train_data_loader = MoleculeDataLoader(
        dataset=train_data,
        batch_size=args.batch_size_dun,
        num_workers=num_workers,
        cache=cache,
        class_balance=args.class_balance,
        shuffle=True,
        seed=args.seed
    )
    val_data_loader = MoleculeDataLoader(
        dataset=val_data,
        batch_size=args.batch_size_dun,
        num_workers=num_workers,
        cache=cache
    )
    
    # instantiate DUN model with Bayesian linear layers (includes log noise)
    model_dun = MoleculeModelDUN(args)

    # copy parameters over from the pretrained model to the DUN model;
    # we take the transpose because the Bayes linear layers store their weights in transposed shape
    for (_, param_dun), (_, param_pre) in zip(model_dun.named_parameters(), model.named_parameters()):
        param_dun.data = copy.deepcopy(param_pre.data.T)
        
    # instantiate rho for each weight
    for layer in model_dun.children():
        if isinstance(layer, BayesLinear):
            layer.init_rho(args.rho_min_dun, args.rho_max_dun)
    for layer in model_dun.encoder.encoder.children():
        if isinstance(layer, BayesLinear):
            layer.init_rho(args.rho_min_dun, args.rho_max_dun)

    # instantiate variational categorical distribution
    model_dun.create_log_cat(args)

    # move dun model to cuda
    if args.cuda:
        print('Moving dun model to cuda')
        model_dun = model_dun.to(args.device)
    
    # loss_func
    loss_func = neg_log_likeDUN
    
    # optimiser
    optimizer = torch.optim.Adam(model_dun.parameters(), lr=args.lr_dun_min)
    
    # scheduler
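    # Noam warm-up/decay for the first (non-sampling) phase; once epoch 100 is
    # reached, the loop below swaps in a constant rate via scheduler_const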
    scheduler = NoamLR(
        optimizer=optimizer,
        warmup_epochs=[2],
        total_epochs=[100],
        steps_per_epoch=args.train_data_size // args.batch_size_dun,
        init_lr=[args.lr_dun_min],
        max_lr=[args.lr_dun_max],
        final_lr=[args.lr_dun_min]
    )

    # non-sampling mode for the first 100 epochs
    bbp_switch = 3
    
    # freeze log_cat for first 100 epochs
    for name, parameter in model_dun.named_parameters():
        if name == 'log_cat':
            parameter.requires_grad = False
        else:
            parameter.requires_grad = True

    
    print("----------DUN training----------")
    
    # training loop
    best_score = float('inf') if args.minimize_score else -float('inf')
    best_epoch, n_iter = 0, 0
    for epoch in range(args.epochs_dun):
        print(f'DUN epoch {epoch}')

        # start second phase
        if epoch == 100:
            scheduler = scheduler_const([args.lr_dun_min])
            bbp_switch = 4
            for name, parameter in model_dun.named_parameters():
                parameter.requires_grad = True


        n_iter = train(
                model=model_dun,
                data_loader=train_data_loader,
                loss_func=loss_func,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                bbp_switch=bbp_switch
            )
        
        val_scores = evaluate(
                model=model_dun,
                data_loader=val_data_loader,
                args=args,
                num_tasks=args.num_tasks,
                metric_func=metric_func,
                dataset_type=args.dataset_type,
                scaler=scaler
            )
        
        # Average validation score
        avg_val_score = np.nanmean(val_scores)
        print(f'Validation {args.metric} = {avg_val_score:.6f}')
        wandb.log({"Validation MAE": avg_val_score})
        print('variational categorical:')
        print(torch.exp(model_dun.log_cat) / torch.sum(torch.exp(model_dun.log_cat)))

        # Save model checkpoint if improved validation score
        if (args.minimize_score and avg_val_score < best_score or \
                not args.minimize_score and avg_val_score > best_score) and (epoch >= args.presave_dun):
            best_score, best_epoch = avg_val_score, epoch
            save_checkpoint(os.path.join(save_dir, 'model_dun.pt'), model_dun, scaler, features_scaler, args)
    
    # load model with best validation score
    template = MoleculeModelDUN(args)
    for layer in template.children():
        if isinstance(layer, BayesLinear):
            layer.init_rho(args.rho_min_dun, args.rho_max_dun)
    for layer in template.encoder.encoder.children():
        if isinstance(layer, BayesLinear):
            layer.init_rho(args.rho_min_dun, args.rho_max_dun)
    template.create_log_cat(args)
    print(f'Best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
    model_dun = load_checkpoint(os.path.join(save_dir, 'model_dun.pt'), device=args.device, logger=None, template = template)


    return model_dun
Example no. 8
def run_training(args: Namespace, logger: Logger = None):
    """
    Trains a model and returns test scores on the model checkpoint with the highest validation score.

    :param args: Arguments.
    :param logger: A logger for recording output.
    :return: The average test score and test accuracy of the best checkpoint (for use in hyperparameter optimization via Hyperopt).
    """

    # Set up logger
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    debug(pformat(vars(args)))

    # Load metadata
    with open(args.data_path, 'r') as f:
        metadata = json.load(f)

    # Train/val/test split
    if args.k_fold_split:
        data_splits = []
        kf = KFold(n_splits=args.num_folds, shuffle=True, random_state=args.seed)
        for train_index, test_index in kf.split(metadata):
            splits = [train_index, test_index]
            data_splits.append(splits)
        data_splits = data_splits[args.fold_index]

        if args.use_inner_test:
            train_indices, remaining_indices = train_test_split(data_splits[0], test_size=args.val_test_size,
                                                                random_state=args.seed)
            validation_indices, test_indices = train_test_split(remaining_indices, test_size=0.5,
                                                                random_state=args.seed)

        else:
            train_indices = data_splits[0]
            validation_indices, test_indices = train_test_split(data_splits[1], test_size=0.5, random_state=args.seed)

        train_metadata = list(np.asarray(metadata)[list(train_indices)])
        validation_metadata = list(np.asarray(metadata)[list(validation_indices)])
        test_metadata = list(np.asarray(metadata)[list(test_indices)])

    else:
        train_metadata, remaining_metadata = train_test_split(metadata, test_size=args.val_test_size,
                                                              random_state=args.seed)
        validation_metadata, test_metadata = train_test_split(remaining_metadata, test_size=0.5, random_state=args.seed)

    # Load datasets
    debug('Loading data')
    transform = Compose([Augmentation(args.augmentation_length), NNGraph(args.num_neighbors), Distance(False)])
    train_data = GlassDataset(train_metadata, transform=transform)
    val_data = GlassDataset(validation_metadata, transform=transform)
    test_data = GlassDataset(test_metadata, transform=transform)
    args.atom_fdim = 3
    args.bond_fdim = args.atom_fdim + 1

    # Dataset lengths
    train_data_length, val_data_length, test_data_length = len(train_data), len(val_data), len(test_data)
    debug('train size = {:,} | val size = {:,} | test size = {:,}'.format(
        train_data_length,
        val_data_length,
        test_data_length)
    )

    # Convert to iterators
    train_data = DataLoader(train_data, args.batch_size)
    val_data = DataLoader(val_data, args.batch_size)
    test_data = DataLoader(test_data, args.batch_size)

    # Get loss and metric functions
    loss_func = get_loss_func(args)
    metric_func = get_metric_func(args.metric)

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, 'model_{}'.format(model_idx))
        os.makedirs(save_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug('Loading model {} from {}'.format(model_idx, args.checkpoint_paths[model_idx]))
            model = load_checkpoint(args.checkpoint_paths[model_idx], args.save_dir, attention_viz=args.attention_viz)
        else:
            debug('Building model {}'.format(model_idx))
            model = build_model(args)

        debug(model)
        debug('Number of parameters = {:,}'.format(param_count(model)))

        if args.cuda:
            debug('Moving model to cuda')
            model = model.cuda()

        # Ensure that the model is saved in the correct location for evaluation when training for 0 epochs
        save_checkpoint(model, args, os.path.join(save_dir, 'model.pt'))

        # Optimizer and learning rate scheduler
        optimizer = Adam(model.parameters(), lr=args.init_lr[model_idx], weight_decay=args.weight_decay[model_idx])

        scheduler = NoamLR(
            optimizer,
            warmup_epochs=args.warmup_epochs,
            total_epochs=[args.epochs],
            steps_per_epoch=train_data_length // args.batch_size,
            init_lr=args.init_lr,
            max_lr=args.max_lr,
            final_lr=args.final_lr
        )

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug('Epoch {}'.format(epoch))

            n_iter = train(
                model=model,
                data=train_data,
                loss_func=loss_func,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                logger=logger,
                writer=writer
            )

            val_scores = []
            for val_runs in range(args.num_val_runs):

                val_batch_scores = evaluate(
                    model=model,
                    data=val_data,
                    metric_func=metric_func,
                    args=args,
                )

                val_scores.append(np.mean(val_batch_scores))

            # Average validation score
            avg_val_score = np.mean(val_scores)
            debug('Validation {} = {:.3f}'.format(args.metric, avg_val_score))
            writer.add_scalar('validation_{}'.format(args.metric), avg_val_score, n_iter)

            # Save model checkpoint if improved validation score
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score:
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(model, args, os.path.join(save_dir, 'model.pt'))

        # Evaluate on test set using model with best validation score
        info('Model {} best validation {} = {:.3f} on epoch {}'.format(model_idx, args.metric, best_score, best_epoch))
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'), args.save_dir, cuda=args.cuda,
                                attention_viz=args.attention_viz)

        test_scores = []
        for test_runs in range(args.num_test_runs):

            test_batch_scores = evaluate(
                model=model,
                data=test_data,
                metric_func=metric_func,
                args=args
            )

            test_scores.append(np.mean(test_batch_scores))

        # Get accuracy (assuming args.metric is set to AUC)
        metric_func_accuracy = get_metric_func('accuracy')
        test_scores_accuracy = []
        for test_runs in range(args.num_test_runs):

            test_batch_scores = evaluate(
                model=model,
                data=test_data,
                metric_func=metric_func_accuracy,
                args=args
            )

            test_scores_accuracy.append(np.mean(test_batch_scores))

        # Average test score
        avg_test_score = np.mean(test_scores)
        avg_test_accuracy = np.mean(test_scores_accuracy)
        info('Model {} test {} = {:.3f}, test {} = {:.3f}'.format(model_idx, args.metric,
                                                                  avg_test_score, 'accuracy', avg_test_accuracy))
        writer.add_scalar('test_{}'.format(args.metric), avg_test_score, n_iter)

        return avg_test_score, avg_test_accuracy  # For hyperparameter optimization or cross validation use
Example no. 9
File: train.py Project: ks8/glassML
def train(model: nn.Module,
          data: DataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: NoamLR,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.
    :param model: Model.
    :param data: A DataLoader.
    :param loss_func: Loss function.
    :param optimizer: Optimizer.
    :param scheduler: A NoamLR learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    loss_sum, iter_count = 0, 0
    for batch in tqdm(data, total=len(data)):
        if args.cuda:
            targets = batch.y.float().unsqueeze(1).cuda()
        else:
            targets = batch.y.float().unsqueeze(1)
        batch = GlassBatchMolGraph(batch)  # TODO: apply a check for graph connectivity

        # Run model
        model.zero_grad()
        preds = model(batch)
        loss = loss_func(preds, targets)
        loss = loss.sum() / loss.size(0)

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.max_grad_norm is not None:
            clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lr = scheduler.get_lr()[0]
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, lr = {:.4e}".
                  format(loss_avg, pnorm, gnorm, lr))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                writer.add_scalar('learning_rate', lr, n_iter)

    return n_iter
Example no. 10
class GAN(nn.Module):
    def __init__(self, args: Namespace, prediction_model: nn.Module,
                 encoder: nn.Module):
        super(GAN, self).__init__()
        self.args = args
        self.prediction_model = prediction_model
        self.encoder = encoder

        self.hidden_size = args.hidden_size
        self.disc_input_size = args.hidden_size + args.output_size
        self.act_func = self.encoder.encoder.act_func

        self.netD = nn.Sequential(
            # note: doesn't support jtnn or additional features right now
            nn.Linear(self.disc_input_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, 1))
        self.beta = args.wgan_beta

        # The optimizers don't really belong here, but we keep them here
        # so that we don't clutter the code for the other optimizers.
        self.optimizerG = Adam(self.encoder.parameters(),
                               lr=args.init_lr[0] * args.gan_lr_mult,
                               betas=(0, 0.9))
        self.optimizerD = Adam(self.netD.parameters(),
                               lr=args.init_lr[0] * args.gan_lr_mult,
                               betas=(0, 0.9))

        self.use_scheduler = args.gan_use_scheduler
        if self.use_scheduler:
            self.schedulerG = NoamLR(
                self.optimizerG,
                warmup_epochs=args.warmup_epochs,
                total_epochs=args.epochs,
                steps_per_epoch=args.train_data_length // args.batch_size,
                init_lr=args.init_lr[0] * args.gan_lr_mult,
                max_lr=args.max_lr[0] * args.gan_lr_mult,
                final_lr=args.final_lr[0] * args.gan_lr_mult)
            self.schedulerD = NoamLR(
                self.optimizerD,
                warmup_epochs=args.warmup_epochs,
                total_epochs=args.epochs,
                steps_per_epoch=(args.train_data_length // args.batch_size) *
                args.gan_d_per_g,
                init_lr=args.init_lr[0] * args.gan_lr_mult,
                max_lr=args.max_lr[0] * args.gan_lr_mult,
                final_lr=args.final_lr[0] * args.gan_lr_mult)

    def forward(self, smiles_batch: List[str], features=None) -> torch.Tensor:
        return self.prediction_model(smiles_batch, features)

    # TODO maybe this isn't the best way to wrap the MOE class, but it works for now
    def mahalanobis_metric(self, p, mu, j):
        return self.prediction_model.mahalanobis_metric(p, mu, j)

    def compute_domain_encs(self, all_train_smiles):
        return self.prediction_model.compute_domain_encs(all_train_smiles)

    def compute_minibatch_domain_encs(self, train_smiles):
        return self.prediction_model.compute_minibatch_domain_encs(
            train_smiles)

    def compute_loss(self, train_smiles, train_targets, test_smiles):
        return self.prediction_model.compute_loss(train_smiles, train_targets,
                                                  test_smiles)

    def set_domain_encs(self, domain_encs):
        self.prediction_model.domain_encs = domain_encs

    def get_domain_encs(self):
        return self.prediction_model.domain_encs

    # the following methods are code borrowed from Wengong and modified
    def train_D(self, fake_smiles: List[str], real_smiles: List[str]):
        self.netD.zero_grad()

        real_output = self.prediction_model(real_smiles).detach()
        real_enc_output = self.encoder.saved_encoder_output.detach()
        real_vecs = torch.cat([real_enc_output, real_output], dim=1)
        fake_output = self.prediction_model(fake_smiles).detach()
        fake_enc_output = self.encoder.saved_encoder_output.detach()
        fake_vecs = torch.cat([fake_enc_output, fake_output], dim=1)

        # real_vecs = self.encoder(mol2graph(real_smiles, self.args)).detach()
        # fake_vecs = self.encoder(mol2graph(fake_smiles, self.args)).detach()
        real_score = self.netD(real_vecs)
        fake_score = self.netD(fake_vecs)

        score = fake_score.mean() - real_score.mean()  # maximize -> minimize the negative
        score.backward()

        # gradient penalty (WGAN-GP)
        inter_gp, inter_norm = self.gradient_penalty(real_vecs, fake_vecs)
        inter_gp.backward()

        self.optimizerD.step()
        if self.use_scheduler:
            self.schedulerD.step()

        return -score.item(), inter_norm

    def train_G(self, fake_smiles: List[str], real_smiles: List[str]):
        self.encoder.zero_grad()

        real_output = self.prediction_model(real_smiles).detach()
        real_enc_output = self.encoder.saved_encoder_output
        real_vecs = torch.cat([real_enc_output, real_output], dim=1)
        fake_output = self.prediction_model(fake_smiles).detach()
        fake_enc_output = self.encoder.saved_encoder_output
        fake_vecs = torch.cat([fake_enc_output, fake_output], dim=1)

        # real_vecs = self.encoder(mol2graph(real_smiles, self.args))
        # fake_vecs = self.encoder(mol2graph(fake_smiles, self.args))
        real_score = self.netD(real_vecs)
        fake_score = self.netD(fake_vecs)

        score = real_score.mean() - fake_score.mean()
        score.backward()

        self.optimizerG.step()
        if self.use_scheduler:
            self.schedulerG.step()
        # technically not necessary since the discriminator gradients get zeroed in the next iteration anyway
        self.netD.zero_grad()

        return score.item()

    def gradient_penalty(self, real_vecs, fake_vecs):
        assert real_vecs.size() == fake_vecs.size()
        eps = torch.rand(real_vecs.size(0), 1).cuda()
        inter_data = eps * real_vecs + (1 - eps) * fake_vecs
        # TODO: check whether wrapping in a Variable is necessary (we detached earlier)
        inter_data = autograd.Variable(inter_data, requires_grad=True)
        inter_score = self.netD(inter_data)
        inter_score = inter_score.view(-1)  # bs*hidden

        inter_grad = autograd.grad(inter_score,
                                   inter_data,
                                   grad_outputs=torch.ones(
                                       inter_score.size()).cuda(),
                                   create_graph=True,
                                   retain_graph=True,
                                   only_inputs=True)[0]

        inter_norm = inter_grad.norm(2, dim=1)
        inter_gp = ((inter_norm - 1)**2).mean() * self.beta

        return inter_gp, inter_norm.mean().item()
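
train_D and train_G are meant to be interleaved: steps_per_epoch for schedulerD is scaled by gan_d_per_g because the discriminator is updated that many times per generator update. A hedged sketch of the assumed outer loop; the batch source is a placeholder:

# assumed adversarial loop; the (real_smiles, fake_smiles) batches come from elsewhere
def gan_epoch(gan, batches, d_per_g):
    for real_smiles, fake_smiles in batches:
        for _ in range(d_per_g):                         # several critic updates...
            d_score, grad_norm = gan.train_D(fake_smiles, real_smiles)
        g_score = gan.train_G(fake_smiles, real_smiles)  # ...per generator update
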
def train_gp(
        model,
        train_data,
        val_data,
        num_workers,
        cache,
        metric_func,
        scaler,
        features_scaler,
        args,
        save_dir):
    
    
    # create data loaders for gp (allows different batch size)
    train_data_loader = MoleculeDataLoader(
        dataset=train_data,
        batch_size=args.batch_size_gp,
        num_workers=num_workers,
        cache=cache,
        class_balance=args.class_balance,
        shuffle=True,
        seed=args.seed
    )
    val_data_loader = MoleculeDataLoader(
        dataset=val_data,
        batch_size=args.batch_size_gp,
        num_workers=num_workers,
        cache=cache
    )
    
    # feature_extractor
    model.featurizer = True
    feature_extractor = model
    
    # inducing points
    inducing_points = initial_inducing_points(
        train_data_loader,
        feature_extractor,
        args
        )
    
    # GP layer
    gp_layer = GPLayer(inducing_points, args.num_tasks)
    
    # full DKL model
    model = copy.deepcopy(DKLMoleculeModel(feature_extractor, gp_layer))
    
    # likelihood
    # rank 0 restricts the inter-task noise covariance to a diagonal matrix
    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=12, rank=0)

    # model and likelihood to CUDA
    if args.cuda:
        model.cuda()
        likelihood.cuda()

    # loss object
    mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=args.train_data_size)
    
    # optimizer
    params_list = [
        {'params': model.feature_extractor.parameters(), 'weight_decay': args.weight_decay_gp},
        {'params': model.gp_layer.hyperparameters()},
        {'params': model.gp_layer.variational_parameters()},
        {'params': likelihood.parameters()},
    ]    
    optimizer = torch.optim.Adam(params_list, lr=args.init_lr_gp)
    
    # scheduler
    num_params = len(params_list)
    scheduler = NoamLR(
        optimizer=optimizer,
        warmup_epochs=[args.warmup_epochs_gp]*num_params,
        total_epochs=[args.noam_epochs_gp]*num_params,
        steps_per_epoch=args.train_data_size // args.batch_size_gp,
        init_lr=[args.init_lr_gp]*num_params,
        max_lr=[args.max_lr_gp]*num_params,
        final_lr=[args.final_lr_gp]*num_params)
        
    
    print("----------GP training----------")
    
    # training loop
    best_score = float('inf') if args.minimize_score else -float('inf')
    best_epoch, n_iter = 0, 0
    for epoch in range(args.epochs_gp):
        print(f'GP epoch {epoch}')
        
        if epoch == args.noam_epochs_gp:
            scheduler = scheduler_const([args.final_lr_gp])
    
        n_iter = train(
                model=model,
                data_loader=train_data_loader,
                loss_func=mll,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                gp_switch=True,
                likelihood = likelihood
            )
    
        val_scores = evaluate(
            model=model,
            data_loader=val_data_loader,
            args=args,
            num_tasks=args.num_tasks,
            metric_func=metric_func,
            dataset_type=args.dataset_type,
            scaler=scaler
        )

        # Average validation score
        avg_val_score = np.nanmean(val_scores)
        print(f'Validation {args.metric} = {avg_val_score:.6f}')
        wandb.log({"Validation MAE": avg_val_score})

        # Save model AND LIKELIHOOD checkpoint if improved validation score
        if args.minimize_score and avg_val_score < best_score or \
                not args.minimize_score and avg_val_score > best_score:
            best_score, best_epoch = avg_val_score, epoch
            save_checkpoint(os.path.join(save_dir, 'DKN_model.pt'), model, scaler, features_scaler, args)
            best_likelihood = copy.deepcopy(likelihood)
            
            
    # load model with best validation score
    # NOTE: TEMPLATE MUST BE NEWLY INSTANTIATED MODEL
    print(f'Loading model with best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
    model = load_checkpoint(os.path.join(save_dir, 'DKN_model.pt'), device=args.device, logger=None,
                            template = DKLMoleculeModel(MoleculeModel(args, featurizer=True), gp_layer))

    
    return model, best_likelihood
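
Both train_dun and train_gp switch to scheduler_const(...) after their Noam phase. That helper is not shown in these examples; judging only from how it is called, a minimal stand-in might look like the following, which is purely an assumption about its behaviour:

# hypothetical stand-in: a "scheduler" whose reported learning rate never changes
class scheduler_const:
    def __init__(self, lr):
        self.lr = lr  # list with one constant rate per parameter group

    def step(self, current_step=None):
        pass  # nothing to advance between optimizer steps

    def get_lr(self):
        return self.lr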