import logging
from argparse import Namespace
from typing import Callable, List

import torch
from tensorboardX import SummaryWriter
from torch import autograd, nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local imports; the exact module paths are assumed from the surrounding repo.
from .nn_utils import NoamLR, compute_gnorm, compute_pnorm
from .featurization import GlassBatchMolGraph


def train(model: nn.Module,
          data: DataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: NoamLR,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A DataLoader.
    :param loss_func: Loss function.
    :param optimizer: Optimizer.
    :param scheduler: A NoamLR learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    loss_sum, iter_count = 0, 0
    for batch in tqdm(data, total=len(data)):
        targets = batch.y.float().unsqueeze(1)
        if args.cuda:
            targets = targets.cuda()
        batch = GlassBatchMolGraph(batch)  # TODO: Apply a check for connectivity of graph

        # Run model
        model.zero_grad()
        preds = model(batch)
        loss = loss_func(preds, targets)
        loss = loss.sum() / loss.size(0)  # average the per-example losses

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.max_grad_norm is not None:
            clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lr = scheduler.get_lr()[0]
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            debug('Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, lr = {:.4e}'.format(
                loss_avg, pnorm, gnorm, lr))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                writer.add_scalar('learning_rate', lr, n_iter)

    return n_iter
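# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring `train` into an epoch loop, assuming `model`,
# `train_data` (a DataLoader of batches with a `.y` target attribute), and a
# fully populated `args` namespace are built elsewhere. The NoamLR arguments
# mirror the ones used by the GAN schedulers below; the `reduction='none'`
# loss matches the manual averaging inside `train`.
def run_training_example(model: nn.Module, train_data: DataLoader, args: Namespace) -> None:
    optimizer = Adam(model.parameters(), lr=args.init_lr[0])
    scheduler = NoamLR(optimizer,
                       warmup_epochs=args.warmup_epochs,
                       total_epochs=args.epochs,
                       steps_per_epoch=args.train_data_length // args.batch_size,
                       init_lr=args.init_lr[0],
                       max_lr=args.max_lr[0],
                       final_lr=args.final_lr[0])
    loss_func = nn.MSELoss(reduction='none')  # per-example losses; `train` averages them

    n_iter = 0
    for _ in range(args.epochs):
        n_iter = train(model, train_data, loss_func, optimizer, scheduler, args, n_iter=n_iter)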
class GAN(nn.Module):
    def __init__(self, args: Namespace, prediction_model: nn.Module, encoder: nn.Module):
        super(GAN, self).__init__()
        self.args = args
        self.prediction_model = prediction_model
        self.encoder = encoder
        self.hidden_size = args.hidden_size
        self.disc_input_size = args.hidden_size + args.output_size
        self.act_func = self.encoder.encoder.act_func

        # Discriminator; doesn't support jtnn or additional features right now.
        self.netD = nn.Sequential(
            nn.Linear(self.disc_input_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, self.hidden_size),
            self.act_func,
            nn.Linear(self.hidden_size, 1))
        self.beta = args.wgan_beta

        # The optimizers don't really belong here, but keeping them here avoids
        # cluttering the code for the other optimizers.
        self.optimizerG = Adam(self.encoder.parameters(),
                               lr=args.init_lr[0] * args.gan_lr_mult,
                               betas=(0, 0.9))
        self.optimizerD = Adam(self.netD.parameters(),
                               lr=args.init_lr[0] * args.gan_lr_mult,
                               betas=(0, 0.9))
        self.use_scheduler = args.gan_use_scheduler
        if self.use_scheduler:
            self.schedulerG = NoamLR(
                self.optimizerG,
                warmup_epochs=args.warmup_epochs,
                total_epochs=args.epochs,
                steps_per_epoch=args.train_data_length // args.batch_size,
                init_lr=args.init_lr[0] * args.gan_lr_mult,
                max_lr=args.max_lr[0] * args.gan_lr_mult,
                final_lr=args.final_lr[0] * args.gan_lr_mult)
            self.schedulerD = NoamLR(
                self.optimizerD,
                warmup_epochs=args.warmup_epochs,
                total_epochs=args.epochs,
                steps_per_epoch=(args.train_data_length // args.batch_size) * args.gan_d_per_g,
                init_lr=args.init_lr[0] * args.gan_lr_mult,
                max_lr=args.max_lr[0] * args.gan_lr_mult,
                final_lr=args.final_lr[0] * args.gan_lr_mult)

    def forward(self, smiles_batch: List[str], features=None) -> torch.Tensor:
        return self.prediction_model(smiles_batch, features)

    # TODO: maybe this isn't the best way to wrap the MOE class, but it works for now.
    def mahalanobis_metric(self, p, mu, j):
        return self.prediction_model.mahalanobis_metric(p, mu, j)

    def compute_domain_encs(self, all_train_smiles):
        return self.prediction_model.compute_domain_encs(all_train_smiles)

    def compute_minibatch_domain_encs(self, train_smiles):
        return self.prediction_model.compute_minibatch_domain_encs(train_smiles)

    def compute_loss(self, train_smiles, train_targets, test_smiles):
        return self.prediction_model.compute_loss(train_smiles, train_targets, test_smiles)

    def set_domain_encs(self, domain_encs):
        self.prediction_model.domain_encs = domain_encs

    def get_domain_encs(self):
        return self.prediction_model.domain_encs

    # The following methods are borrowed from Wengong's code and modified.
    def train_D(self, fake_smiles: List[str], real_smiles: List[str]):
        self.netD.zero_grad()

        # Detach so the discriminator update doesn't backprop into the encoder.
        real_output = self.prediction_model(real_smiles).detach()
        real_enc_output = self.encoder.saved_encoder_output.detach()
        real_vecs = torch.cat([real_enc_output, real_output], dim=1)
        fake_output = self.prediction_model(fake_smiles).detach()
        fake_enc_output = self.encoder.saved_encoder_output.detach()
        fake_vecs = torch.cat([fake_enc_output, fake_output], dim=1)
        # real_vecs = self.encoder(mol2graph(real_smiles, self.args)).detach()
        # fake_vecs = self.encoder(mol2graph(fake_smiles, self.args)).detach()

        real_score = self.netD(real_vecs)
        fake_score = self.netD(fake_vecs)
        # The critic maximizes (real - fake); minimize the negation instead.
        score = fake_score.mean() - real_score.mean()
        score.backward()

        # Gradient penalty
        inter_gp, inter_norm = self.gradient_penalty(real_vecs, fake_vecs)
        inter_gp.backward()

        self.optimizerD.step()
        if self.use_scheduler:
            self.schedulerD.step()
        return -score.item(), inter_norm
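    # --- Usage sketch (illustrative, not part of the original module) ---
    # train_D and train_G are meant to be alternated WGAN-style: the critic is
    # updated args.gan_d_per_g times per generator update, which is why
    # schedulerD above uses steps_per_epoch * gan_d_per_g. `sample_fake_batch`
    # and `sample_real_batch` are hypothetical helpers returning SMILES lists.
    #
    #     for step in range(steps_per_epoch):
    #         for _ in range(args.gan_d_per_g):
    #             d_score, gp_norm = gan.train_D(sample_fake_batch(), sample_real_batch())
    #         g_score = gan.train_G(sample_fake_batch(), sample_real_batch())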
    def train_G(self, fake_smiles: List[str], real_smiles: List[str]):
        self.encoder.zero_grad()

        # Only the prediction head's output is detached here; the saved encoder
        # output stays attached so the generator (encoder) receives gradients.
        real_output = self.prediction_model(real_smiles).detach()
        real_enc_output = self.encoder.saved_encoder_output
        real_vecs = torch.cat([real_enc_output, real_output], dim=1)
        fake_output = self.prediction_model(fake_smiles).detach()
        fake_enc_output = self.encoder.saved_encoder_output
        fake_vecs = torch.cat([fake_enc_output, fake_output], dim=1)
        # real_vecs = self.encoder(mol2graph(real_smiles, self.args))
        # fake_vecs = self.encoder(mol2graph(fake_smiles, self.args))

        real_score = self.netD(real_vecs)
        fake_score = self.netD(fake_vecs)
        score = real_score.mean() - fake_score.mean()
        score.backward()

        self.optimizerG.step()
        if self.use_scheduler:
            self.schedulerG.step()
        # Technically not necessary since it'll get zeroed in the next iteration anyway.
        self.netD.zero_grad()
        return score.item()

    def gradient_penalty(self, real_vecs, fake_vecs):
        assert real_vecs.size() == fake_vecs.size()

        # Sample a random interpolation point between each real/fake pair.
        eps = torch.rand(real_vecs.size(0), 1).cuda()
        inter_data = eps * real_vecs + (1 - eps) * fake_vecs
        # requires_grad is needed here: the vecs were detached in train_D, so the
        # interpolates would otherwise not be differentiable.
        inter_data = inter_data.requires_grad_(True)
        inter_score = self.netD(inter_data)
        inter_score = inter_score.view(-1)  # shape: (batch_size,)

        # Penalize deviation of the critic's gradient norm from 1 (WGAN-GP).
        inter_grad = autograd.grad(inter_score,
                                   inter_data,
                                   grad_outputs=torch.ones_like(inter_score),
                                   create_graph=True,
                                   retain_graph=True,
                                   only_inputs=True)[0]
        inter_norm = inter_grad.norm(2, dim=1)
        inter_gp = ((inter_norm - 1) ** 2).mean() * self.beta

        return inter_gp, inter_norm.mean().item()
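# --- Standalone sanity check (illustrative, not part of the original module) ---
# A minimal, self-contained demonstration of the WGAN-GP penalty computed by
# `gradient_penalty` above: interpolate between "real" and "fake" vectors,
# score them with a tiny critic, and penalize the critic's gradient norm for
# deviating from 1. Runs on CPU with random data; `beta` plays the role of
# args.wgan_beta.
def _gradient_penalty_demo(batch_size: int = 8, dim: int = 16, beta: float = 10.0) -> None:
    critic = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1))
    real_vecs = torch.randn(batch_size, dim)
    fake_vecs = torch.randn(batch_size, dim)

    eps = torch.rand(batch_size, 1)
    inter_data = (eps * real_vecs + (1 - eps) * fake_vecs).requires_grad_(True)
    inter_score = critic(inter_data).view(-1)

    inter_grad = autograd.grad(inter_score,
                               inter_data,
                               grad_outputs=torch.ones_like(inter_score),
                               create_graph=True,
                               retain_graph=True,
                               only_inputs=True)[0]
    inter_norm = inter_grad.norm(2, dim=1)
    inter_gp = ((inter_norm - 1) ** 2).mean() * beta
    print('mean gradient norm: {:.4f}, penalty: {:.4f}'.format(
        inter_norm.mean().item(), inter_gp.item()))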