def train(model: MoleculeModel, data_loader: MoleculeDataLoader, loss_func: Callable,
          optimizer: Optimizer, scheduler: _LRScheduler, args: TrainArgs, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum, iter_count = 0, 0

    for batch in tqdm(data_loader, total=len(data_loader)):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch = batch.batch_graph(), batch.features(), batch.targets()
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        class_weights = torch.ones(targets.shape, device=preds.device)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
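# --- Illustrative sketch (not part of the original source): how the masked multi-task
# loss above behaves on a toy batch. Targets that are None are masked out so they
# contribute nothing to the average; all names below are hypothetical.
import torch
from torch import nn

raw_targets = [[1.0, None], [0.0, 3.0]]                       # 2 molecules x 2 tasks
mask = torch.Tensor([[x is not None for x in tb] for tb in raw_targets])
targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in raw_targets])
preds = torch.tensor([[0.5, 0.2], [0.1, 2.5]])

loss_func = nn.MSELoss(reduction='none')                      # elementwise, as in the training loop above
loss = (loss_func(preds, targets) * mask).sum() / mask.sum()  # average over observed entries only
print(loss.item())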
def train(model: nn.Module, data: Union[MoleculeDataset, List[MoleculeDataset]], loss_func: Callable,
          optimizer: Optimizer, scheduler: _LRScheduler, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    loss_sum, iter_count = 0, 0
    iter_size = args.batch_size

    if args.class_balance:
        # Reconstruct data so that each batch has an equal number of positives and negatives
        # (will leave out a different random sample of negatives each epoch)
        assert len(data[0].targets) == 1  # only works for single-task classification

        pos = [d for d in data if d.targets[0] == 1]
        neg = [d for d in data if d.targets[0] == 0]

        new_data = []
        pos_size = iter_size // 2
        pos_index = neg_index = 0
        while True:
            new_pos = pos[pos_index:pos_index + pos_size]
            new_neg = neg[neg_index:neg_index + iter_size - len(new_pos)]
            if len(new_pos) == 0 or len(new_neg) == 0:
                break
            if len(new_pos) + len(new_neg) < iter_size:
                new_pos = pos[pos_index:pos_index + iter_size - len(new_neg)]
            new_data += new_pos + new_neg
            pos_index += len(new_pos)
            neg_index += len(new_neg)

        data = new_data

    # Don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch, weight_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets(), mol_batch.weights()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])
        weights = torch.Tensor([[0 if x is None else x for x in tb] for tb in weight_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        if args.enable_weight:
            class_weights = weights
        else:
            class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
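# --- Illustrative sketch (assumption, not original source): the class_balance
# rebalancing above, run on plain integers. With iter_size = 4 each "batch" gets
# 2 positives followed by 2 negatives until either class runs out.
iter_size, pos_size = 4, 2
pos = [1] * 5          # stand-ins for positive datapoints
neg = [0] * 9          # stand-ins for negative datapoints
new_data, pos_index, neg_index = [], 0, 0
while True:
    new_pos = pos[pos_index:pos_index + pos_size]
    new_neg = neg[neg_index:neg_index + iter_size - len(new_pos)]
    if len(new_pos) == 0 or len(new_neg) == 0:
        break
    if len(new_pos) + len(new_neg) < iter_size:
        new_pos = pos[pos_index:pos_index + iter_size - len(new_neg)]
    new_data += new_pos + new_neg
    pos_index += len(new_pos)
    neg_index += len(new_neg)
print(new_data)  # [1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0] -> only one positive left, so the last group is padded with negatives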
def train(model: nn.Module, data: Union[MoleculeDataset, List[MoleculeDataset]], loss_func: Callable,
          optimizer: Optimizer, scheduler: _LRScheduler, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: Number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: Total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    loss_sum, iter_count = 0, 0

    # Don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        mask = torch.Tensor([[not np.isnan(x) for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if np.isnan(x) else x for x in tb] for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)
        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(smiles_batch, features_batch)

        # TODO: change the loss function for property prediction tasks
        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'\nLoss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for idx, learn_rate in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{idx}', learn_rate, n_iter)

    return n_iter
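# --- Illustrative sketch (assumption, not original source): this variant marks
# missing targets with NaN instead of None, so the mask is built with np.isnan.
import numpy as np
import torch

target_batch = np.array([[0.3, np.nan], [np.nan, 1.2]])       # 2 molecules x 2 tasks
mask = torch.Tensor([[not np.isnan(x) for x in tb] for tb in target_batch])
targets = torch.Tensor([[0 if np.isnan(x) else x for x in tb] for tb in target_batch])
print(mask)     # tensor([[1., 0.], [0., 1.]])
print(targets)  # tensor([[0.3000, 0.0000], [0.0000, 1.2000]])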
def train(model: nn.Module, data: Union[MoleculeDataset, List[MoleculeDataset]], loss_func: Callable,
          optimizer: Optimizer, scheduler: _LRScheduler, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    loss_sum, iter_count = 0, 0

    # Don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)
        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        ############ add L1 regularization ############
        ffn_d0_L1_reg_loss = 0
        ffn_d1_L1_reg_loss = 0
        ffn_d2_L1_reg_loss = 0
        ffn_final_L1_reg_loss = 0
        ffn_mol_L1_reg_loss = 0
        lamda_ffn_d0 = 0
        lamda_ffn_d1 = 0
        lamda_ffn_d2 = 0
        lamda_ffn_final = 0
        lamda_ffn_mol = 0
        for param in model.ffn_d0.parameters():
            ffn_d0_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_final.parameters():
            ffn_final_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L1_reg_loss += torch.sum(torch.abs(param))
        loss += (lamda_ffn_d0 * ffn_d0_L1_reg_loss + lamda_ffn_d1 * ffn_d1_L1_reg_loss
                 + lamda_ffn_d2 * ffn_d2_L1_reg_loss + lamda_ffn_final * ffn_final_L1_reg_loss
                 + lamda_ffn_mol * ffn_mol_L1_reg_loss)
        ############ add L1 regularization ############

        ############ add L2 regularization ############
        '''
        ffn_d0_L2_reg_loss = 0
        ffn_d1_L2_reg_loss = 0
        ffn_d2_L2_reg_loss = 0
        ffn_final_L2_reg_loss = 0
        ffn_mol_L2_reg_loss = 0
        lamda_ffn_d0 = 1e-6
        lamda_ffn_d1 = 1e-6
        lamda_ffn_d2 = 1e-5
        lamda_ffn_final = 1e-4
        lamda_ffn_mol = 1e-3
        for param in model.ffn_d0.parameters():
            ffn_d0_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_final.parameters():
            ffn_final_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L2_reg_loss += torch.sum(torch.square(param))
        loss += (lamda_ffn_d0 * ffn_d0_L2_reg_loss + lamda_ffn_d1 * ffn_d1_L2_reg_loss
                 + lamda_ffn_d2 * ffn_d2_L2_reg_loss + lamda_ffn_final * ffn_final_L2_reg_loss
                 + lamda_ffn_mol * ffn_mol_L2_reg_loss)
        '''
        ############ add L2 regularization ############

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        # loss.backward(retain_graph=True)
        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
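# --- Illustrative sketch (assumption, not original source): the same idea as the
# regularization block above on a generic module, with nonzero lambdas so the
# penalties actually contribute. `head` stands in for one of the ffn_* sub-networks.
import torch
from torch import nn

head = nn.Linear(8, 1)
lam_l1, lam_l2 = 1e-5, 1e-4
l1_penalty = sum(param.abs().sum() for param in head.parameters())
l2_penalty = sum(param.pow(2).sum() for param in head.parameters())
base_loss = torch.tensor(0.5)                     # placeholder for the data loss
loss = base_loss + lam_l1 * l1_penalty + lam_l2 * l2_penalty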
def train(model: nn.Module, data: Union[MoleculeDataset, List[MoleculeDataset]], loss_func: Callable,
          optimizer: Optimizer, scheduler: _LRScheduler, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None, chunk_names: bool = False,
          val_smiles: List[str] = None, test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case, data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model, data=chunk, loss_func=loss_func, optimizer=optimizer,
                           scheduler=scheduler, args=args, n_iter=n_iter, logger=logger, writer=writer,
                           chunk_names=False, val_smiles=val_smiles, test_smiles=test_smiles)
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH, f, protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch else len(data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters, args.batch_size,
                                      exit_queue, args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(train_smiles)  # want to recompute every batch
            mol_batch = [MoleculeDataset(d[i:i + args.batch_size]) for d in data]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = \
                task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
            theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
            theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
            for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
                theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop()
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = \
                mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(target_batch['features']) \
                    if target_batch['features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()
                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(args.class_weights[task_num][targets[:, task_num].long()])
                class_weights = torch.stack(class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)

            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    # for some reason cross entropy doesn't support multi target
                    loss += loss_func(preds[:, task, :], targets[:, task]) * class_weights[:, task] * mask[:, task]
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)
                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']
                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2, preds.size(1))
                    preds = model.kernel_output_layer(preds)
                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight - 1)) \
                        / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight - 1))
                else:
                    loss = loss.sum() / mask.sum()
                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = \
                task_test_data.smiles(), task_test_data.features(), \
                [t[task_idx] for t in task_test_data.targets()]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(smiles_batch)  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            if args.adjust_weight_decay:
                current_pnorm = compute_pnorm(model)
                if current_pnorm < args.pnorm_target:
                    for i in range(len(optimizer.param_groups)):
                        optimizer.param_groups[i]['weight_decay'] = \
                            max(0, optimizer.param_groups[i]['weight_decay'] - args.adjust_weight_decay_step)
                else:
                    for i in range(len(optimizer.param_groups)):
                        optimizer.param_groups[i]['weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles, args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch, test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles, args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = \
                    d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join('lr_{} = {:.4e}'.format(i, lr) for i, lr in enumerate(lrs))
            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, {}".format(loss_avg, pnorm, gnorm, lrs_str))
            if args.adversarial:
                debug("D Loss = {:.4e}, G Loss = {:.4e}, GP Norm = {:.4}".format(d_loss_avg, g_loss_avg, gp_norm_avg))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar('learning_rate_{}'.format(i), lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
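# --- Illustrative sketch (assumption, not original source): the one-step, MAML-style
# adaptation used in the maml branch above, shown on a toy linear model. Requires
# PyTorch >= 2.0 for torch.func.functional_call; the real code rebuilds a model
# from theta_prime instead.
import torch
from torch import nn
from torch.func import functional_call

toy_model = nn.Linear(4, 1)
inner_lr = 0.1
x_support, y_support = torch.randn(8, 4), torch.randn(8, 1)

support_loss = nn.functional.mse_loss(toy_model(x_support), y_support)
grads = torch.autograd.grad(support_loss, [p for p in toy_model.parameters() if p.requires_grad])
theta_prime = {name: p - inner_lr * g
               for (name, p), g in zip(toy_model.named_parameters(), grads)}

x_query, y_query = torch.randn(8, 4), torch.randn(8, 1)
query_loss = nn.functional.mse_loss(functional_call(toy_model, theta_prime, (x_query,)), y_query)
query_loss.backward()  # gradients flow back to the original (pre-adaptation) parameters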
def train(model: nn.Module, data: Union[MoleculeDataset, List[MoleculeDataset]], loss_func: Callable,
          optimizer: Optimizer, scheduler: _LRScheduler, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data = deepcopy(data)
    data.shuffle()
    if args.uncertainty == 'bootstrap':
        data.sample(int(4 * len(data) / args.ensemble_size))

    loss_sum, iter_count = 0, 0

    # Don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)
        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        if model.uncertainty:
            # Even output columns hold predicted targets, odd columns hold predicted variances
            pred_targets = preds[:, [j for j in range(len(preds[0])) if j % 2 == 0]]
            pred_var = preds[:, [j for j in range(len(preds[0])) if j % 2 == 1]]
            loss = loss_func(pred_targets, pred_var, targets)
            # sigma = ((pred_targets - targets) ** 2).detach()
            # loss = loss_func(pred_targets, targets) * class_weights * mask
            # loss += nn.MSELoss(reduction='none')(pred_sigma, sigma) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
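# --- Illustrative sketch (assumption, not original source): a loss of the kind the
# uncertainty branch above expects, taking predicted means and variances. Shown with
# torch.nn.GaussianNLLLoss; the fork's actual loss_func may differ.
import torch
from torch import nn

batch_size, tasks = 4, 2
pred_mean = torch.randn(batch_size, tasks)
pred_var = torch.rand(batch_size, tasks) + 1e-3   # variances must be positive
targets = torch.randn(batch_size, tasks)

nll = nn.GaussianNLLLoss(reduction='mean')
loss = nll(pred_mean, targets, pred_var)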
def train(model: nn.Module, data: Union[MoleculeDataset, List[MoleculeDataset]], loss_func: Callable,
          metric_func: Callable, optimizer: Optimizer, scheduler: _LRScheduler, args: Namespace,
          n_iter: int = 0, logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param metric_func: Metric function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    num_targets = len(args.atom_targets) + len(args.bond_targets)
    loss_sum, metric_sum, iter_count = [0] * num_targets, [0] * num_targets, 0
    loss_weights = args.loss_weights

    # Don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        # mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        # FIXME: assign 0 to None in target
        # targets = [[0 if x is None else x for x in tb] for tb in target_batch]
        targets = [torch.Tensor(np.concatenate(x)) for x in zip(*target_batch)]
        if next(model.parameters()).is_cuda:
            # mask, targets = mask.cuda(), targets.cuda()
            targets = [x.cuda() for x in targets]

        # FIXME
        # class_weights = torch.ones(targets.shape)
        # if args.cuda:
        #     class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        targets = [x.reshape([-1, 1]) for x in targets]

        # FIXME: multiclass
        '''
        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        '''
        loss_multi_task = []
        metric_multi_task = []
        for target, pred, lw in zip(targets, preds, loss_weights):
            loss = loss_func(pred, target)
            loss = loss.sum() / target.shape[0]
            loss_multi_task.append(loss * lw)
            if args.cuda:
                metric = metric_func(pred.data.cpu().numpy(), target.data.cpu().numpy())
            else:
                metric = metric_func(pred.data.numpy(), target.data.numpy())
            metric_multi_task.append(metric)

        loss_sum = [x + y for x, y in zip(loss_sum, loss_multi_task)]
        iter_count += 1

        sum(loss_multi_task).backward()
        optimizer.step()
        metric_sum = [x + y for x, y in zip(metric_sum, metric_multi_task)]

        if isinstance(scheduler, NoamLR) or isinstance(scheduler, SinexpLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = [x / iter_count for x in loss_sum]
            metric_avg = [x / iter_count for x in metric_sum]
            loss_sum, metric_sum, iter_count = [0] * num_targets, [0] * num_targets, 0

            loss_str = ', '.join(f'lss_{i} = {lss:.4e}' for i, lss in enumerate(loss_avg))
            metric_str = ', '.join(f'mc_{i} = {mc:.4e}' for i, mc in enumerate(metric_avg))
            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'{loss_str}, {metric_str}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                for i, lss in enumerate(loss_avg):
                    writer.add_scalar(f'train_loss_{i}', lss, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
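# --- Illustrative sketch (assumption, not original source): the per-target weighted
# loss aggregation used above, with one prediction tensor per atom/bond target.
import torch
from torch import nn

loss_func = nn.MSELoss(reduction='none')
loss_weights = [1.0, 0.5]
preds = [torch.randn(10, 1, requires_grad=True), torch.randn(6, 1, requires_grad=True)]
targets = [torch.randn(10, 1), torch.randn(6, 1)]

loss_multi_task = []
for target, pred, lw in zip(targets, preds, loss_weights):
    per_task = loss_func(pred, target).sum() / target.shape[0]
    loss_multi_task.append(lw * per_task)
sum(loss_multi_task).backward()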
def train_batch(args, fold_i, model: nn.Module, data: DataLoader, loss_func: Callable,
                optimizer: Optimizer, scheduler: _LRScheduler, logger: logging.Logger = None,
                writer: SummaryWriter = None):
    """Trains a model for one epoch over a DataLoader and returns the last batch index,
    an estimate of the number of examples seen, the most recent averaged loss, and the summed epoch loss."""
    debug = logger.debug if logger is not None else print

    loss_sum, iter_count, epoch_loss = 0, 0, 0
    for it, result_batch in enumerate(tqdm(data)):
        model.zero_grad()
        batch = result_batch['sm']
        label_batch = result_batch['labels']
        mask = torch.Tensor([[x is not None for x in batch_t] for batch_t in result_batch['labels']])
        targets = torch.Tensor([[0 if x is None else x for x in batch_t] for batch_t in result_batch['labels']])
        args.num_tasks = len(result_batch['labels'][0])

        if args.dataset_type == 'classification':
            if args.class_balance:
                class_weights = []
                for task_num in range(args.n_task):
                    class_weights.append(args.class_weights[fold_i][task_num][targets[:, task_num].long()])
                class_weights = torch.stack(class_weights).t()
            else:
                class_weights = torch.ones(targets.shape)

        if next(model.parameters()).is_cuda and args.gpuUSE:
            mask, targets = mask.cuda(), targets.cuda()

        preds = model(batch)
        if args.dataset_type == 'classification':
            loss = loss_func(preds, targets) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        epoch_loss += loss.item()
        iter_count += targets.size(0)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if it % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / (iter_count * targets.size(0))
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            if writer is not None:
                writer.add_scalar('train_loss_batch', loss_avg, it)
                writer.add_scalar('param_norm_batch', pnorm, it)
                writer.add_scalar('gradient_norm_batch', gnorm, it)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}_batch', lr, it)

    return it, it * targets.size(0), loss_avg, epoch_loss
def train(model: MoleculeModel, data_loader: MoleculeDataLoader, loss_func: Callable,
          optimizer: Optimizer, scheduler: _LRScheduler, args: TrainArgs, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum = iter_count = 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch, mask_batch, atom_descriptors_batch, \
            atom_features_batch, bond_features_batch, data_weights_batch = \
            batch.batch_graph(), batch.features(), batch.targets(), batch.mask(), \
            batch.atom_descriptors(), batch.atom_features(), batch.bond_features(), batch.data_weights()

        mask = torch.tensor(mask_batch, dtype=torch.bool)  # shape(batch, tasks)
        targets = torch.tensor([[0 if x is None else x for x in tb] for tb in target_batch])  # shape(batch, tasks)

        if args.target_weights is not None:
            target_weights = torch.tensor(args.target_weights).unsqueeze(0)  # shape(1, tasks)
        else:
            target_weights = torch.ones(targets.shape[1]).unsqueeze(0)
        data_weights = torch.tensor(data_weights_batch).unsqueeze(1)  # shape(batch, 1)

        if args.loss_function == 'bounded_mse':
            lt_target_batch = batch.lt_targets()  # shape(batch, tasks)
            gt_target_batch = batch.gt_targets()  # shape(batch, tasks)
            lt_target_batch = torch.tensor(lt_target_batch)
            gt_target_batch = torch.tensor(gt_target_batch)

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch)

        # Move tensors to correct device
        torch_device = preds.device
        mask = mask.to(torch_device)
        targets = targets.to(torch_device)
        target_weights = target_weights.to(torch_device)
        data_weights = data_weights.to(torch_device)
        if args.loss_function == 'bounded_mse':
            lt_target_batch = lt_target_batch.to(torch_device)
            gt_target_batch = gt_target_batch.to(torch_device)

        # Calculate losses
        if args.loss_function == 'mcc' and args.dataset_type == 'classification':
            loss = loss_func(preds, targets, data_weights, mask) * target_weights.squeeze(0)
        elif args.loss_function == 'mcc':  # multiclass dataset type
            targets = targets.long()
            target_losses = []
            for target_index in range(preds.size(1)):
                target_loss = loss_func(preds[:, target_index, :], targets[:, target_index],
                                        data_weights, mask[:, target_index]).unsqueeze(0)
                target_losses.append(target_loss)
            loss = torch.cat(target_losses).to(torch_device) * target_weights.squeeze(0)
        elif args.dataset_type == 'multiclass':
            targets = targets.long()
            if args.loss_function == 'dirichlet':
                loss = loss_func(preds, targets, args.evidential_regularization) \
                    * target_weights * data_weights * mask
            else:
                target_losses = []
                for target_index in range(preds.size(1)):
                    target_loss = loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                    target_losses.append(target_loss)
                loss = torch.cat(target_losses, dim=1).to(torch_device) * target_weights * data_weights * mask
        elif args.dataset_type == 'spectra':
            loss = loss_func(preds, targets, mask) * target_weights * data_weights * mask
        elif args.loss_function == 'bounded_mse':
            loss = loss_func(preds, targets, lt_target_batch, gt_target_batch) \
                * target_weights * data_weights * mask
        elif args.loss_function == 'evidential':
            loss = loss_func(preds, targets, args.evidential_regularization) \
                * target_weights * data_weights * mask
        elif args.loss_function == 'dirichlet':  # classification
            loss = loss_func(preds, targets, args.evidential_regularization) \
                * target_weights * data_weights * mask
        else:
            loss = loss_func(preds, targets) * target_weights * data_weights * mask

        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum = iter_count = 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
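# --- Illustrative sketch (assumption, not original source): how target_weights
# (shape 1 x tasks) and data_weights (shape batch x 1) broadcast against the
# per-element loss (shape batch x tasks) before the masked average above.
import torch

batch_size, tasks = 3, 2
element_loss = torch.ones(batch_size, tasks)
target_weights = torch.tensor([[1.0, 2.0]])         # shape (1, tasks)
data_weights = torch.tensor([[1.0], [0.5], [2.0]])  # shape (batch, 1)
mask = torch.tensor([[1., 1.], [1., 0.], [1., 1.]])

weighted = element_loss * target_weights * data_weights * mask  # broadcasts to (batch, tasks)
loss = weighted.sum() / mask.sum()
print(loss.item())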
def train(model: nn.Module, data: MolPairDataset, loss_func: Callable, optimizer: Optimizer,
          scheduler: _LRScheduler, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None, writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MolPairDataset (or a list of MolPairDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    # Very important this is done before conversion to maintain randomness in the contrastive dataset.
    data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.loss_func == 'contrastive':
        data = convert2contrast(data)

    # Don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MolPairDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])
        if args.loss_func == 'contrastive':
            mask = targets
        else:
            mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        if args.dataset_type == 'regression':
            class_weights = torch.ones(targets.shape)
        else:
            class_weights = targets * (args.class_weights - 1) + 1
        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += args.batch_size

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter