def enumerate_scheduler(scheduler: _LRScheduler, steps: int) -> List[float]:
    """Reads the current learning rate via get_last_lr, runs one scheduler step,
    and repeats for `steps` iterations. Returns the recorded LR values.
    """
    lrs = []
    for _ in range(steps):
        lr = scheduler.get_last_lr()  # type: ignore
        assert isinstance(lr, list)
        assert len(lr) == 1
        lrs.append(lr[0])
        scheduler.step()
    return lrs
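# Illustrative usage sketch (not part of the original module): with a hypothetical toy
# optimizer and a StepLR schedule, enumerate_scheduler traces the full LR trajectory,
# which is handy for sanity-checking a schedule before a long run. Assumes `torch` is
# importable; the parameter and optimizer below are made up for the example.
def _example_enumerate_scheduler() -> List[float]:
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.1)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.5)
    # Expected trace: [0.1, 0.1, 0.1, 0.05, 0.05, 0.05]
    return enumerate_scheduler(sched, steps=6)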
def fit(
    step: FunctionType,
    epochs: int,
    model: ModuleType,
    optimizer: OptimizerType,
    scheduler: SchedulerType,
    data_loader: DataLoader,
    model_path: str,
    logger: object,
    early_stop: bool = False,
    pgd_kwargs: Optional[dict] = None,
    verbose: bool = False,
) -> Tuple[ModuleType, dict]:
    """Standard PyTorch boilerplate for training a model.

    Allows for early stopping w.r.t. robust accuracy rather than the usual clean accuracy.
    """
    device = next(model.parameters()).device
    prev_robust_acc = 0.
    start_train_time = time.time()
    logger.info('Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc')
    for epoch in range(epochs):
        start_epoch_time = time.time()
        train_loss = 0
        train_acc = 0
        train_n = 0
        data_generator = enumerate(data_loader)
        if verbose:
            data_generator = tqdm(data_generator, total=len(data_loader), desc=f'Epoch {epoch + 1}')
        for i, (X, y) in data_generator:
            X, y = X.to(device), y.to(device)
            if i == 0:
                first_batch = (X, y)
            loss, logits = step(X, y, model=model, optimizer=optimizer, scheduler=scheduler)
            train_loss += loss.item() * y.size(0)
            train_acc += (logits.argmax(dim=1) == y).sum().item()
            train_n += y.size(0)
        if early_stop:
            assert pgd_kwargs is not None
            # Check current PGD robustness of model using a random minibatch
            X, y = first_batch
            pgd_delta = attack_pgd(model=model, X=X, y=y, opt=optimizer, **pgd_kwargs)
            model.eval()
            with torch.no_grad():
                output = model(clamp(X + pgd_delta[:X.size(0)],
                                     pgd_kwargs['lower_limit'], pgd_kwargs['upper_limit']))
            robust_acc = (output.softmax(dim=1).argmax(dim=1) == y).sum().item() / y.size(0)
            if robust_acc - prev_robust_acc < -0.2:
                logger.info('EARLY STOPPING TRIGGERED')
                break
            prev_robust_acc = robust_acc
            best_state_dict = copy.deepcopy(model.state_dict())
            model.train()
        epoch_time = time.time()
        lr = scheduler.get_last_lr()[0]
        logger.info('%d \t %.1f \t \t %.4f \t %.4f \t %.4f',
                    epoch, epoch_time - start_epoch_time, lr,
                    train_loss / train_n, train_acc / train_n)
    train_time = time.time()
    if not early_stop:
        best_state_dict = model.state_dict()
    torch.save(best_state_dict, model_path)
    logger.info('Total train time: %.4f minutes', (train_time - start_train_time) / 60)
    return model, best_state_dict
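# Illustrative sketch (not part of the original module): the minimal shape of the `step`
# callable that `fit` expects. `fit` invokes step(X, y, model=..., optimizer=..., scheduler=...)
# and consumes a `(loss, logits)` pair, so a plain clean-training step could look like the
# function below. The name `standard_step` and the per-batch scheduler.step() call (i.e. a
# cyclic schedule) are assumptions made for this example, not definitions from the original code.
def standard_step(X, y, model, optimizer, scheduler):
    optimizer.zero_grad()
    logits = model(X)                                     # forward pass on the clean batch
    loss = torch.nn.functional.cross_entropy(logits, y)   # assumes a classification objective
    loss.backward()                                       # backward pass
    optimizer.step()                                      # parameter update
    scheduler.step()                                      # assumes a per-batch LR schedule
    return loss.detach(), logits.detach()                 # fit() calls .item() / argmax on these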
def train(model: nn.Module,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          gp_switch: bool = False,
          likelihood=None,
          bbp_switch=None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data_loader: A MoleculeDataLoader.
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param gp_switch: If True, loss_func returns a marginal log likelihood which is negated to form the loss.
    :param likelihood: GP likelihood module; set to train mode alongside the model if provided.
    :param bbp_switch: Selects the Bayes-by-backprop / DUN training branch (None, 1, 2, 3 or 4).
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    if likelihood is not None:
        likelihood.train()

    loss_sum = 0
    if bbp_switch is not None:
        data_loss_sum = 0
        kl_loss_sum = 0
        kl_loss_depth_sum = 0

    # for batch in tqdm(data_loader, total=len(data_loader)):
    for batch in data_loader:
        # Prepare batch
        batch: MoleculeDataset
        # .batch_graph() returns BatchMolGraph
        # .features() returns None if no additional features
        # .targets() returns list of lists of floats containing the targets
        mol_batch, features_batch, target_batch = batch.batch_graph(), batch.features(), batch.targets()

        # mask is 1 where targets are not None
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        # where targets are None, replace with 0
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        # Move tensors to correct device
        mask = mask.to(args.device)
        targets = targets.to(args.device)
        class_weights = torch.ones(targets.shape, device=args.device)

        # zero gradients
        model.zero_grad()
        optimizer.zero_grad()

        ##### FORWARD PASS AND LOSS COMPUTATION #####

        if bbp_switch is None:
            # forward pass
            preds = model(mol_batch, features_batch)

            # compute loss
            if gp_switch:
                loss = -loss_func(preds, targets)
            else:
                loss = loss_func(preds, targets, torch.exp(model.log_noise))

        ### bbp non sample option
        if bbp_switch == 1:
            preds, kl_loss = model(mol_batch, features_batch, sample=False)
            data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
            kl_loss /= args.train_data_size
            loss = data_loss + kl_loss

        ### bbp sample option
        if bbp_switch == 2:
            if args.samples_bbp == 1:
                preds, kl_loss = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
                kl_loss /= args.train_data_size
            elif args.samples_bbp > 1:
                data_loss_cum = 0
                kl_loss_cum = 0
                for i in range(args.samples_bbp):
                    preds, kl_loss_i = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds, targets, torch.exp(model.log_noise))
                    kl_loss_i /= args.train_data_size
                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
                data_loss = data_loss_cum / args.samples_bbp
                kl_loss = kl_loss_cum / args.samples_bbp
            loss = data_loss + kl_loss

        ### DUN non sample option
        if bbp_switch == 3:
            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))
            _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=False)
            data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
            kl_loss /= args.train_data_size
            kl_loss_depth /= args.train_data_size
            loss = data_loss + kl_loss + kl_loss_depth

        ### DUN sample option
        if bbp_switch == 4:
            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))
            if args.samples_dun == 1:
                _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                kl_loss /= args.train_data_size
                kl_loss_depth /= args.train_data_size
            elif args.samples_dun > 1:
                data_loss_cum = 0
                kl_loss_cum = 0
                for i in range(args.samples_dun):
                    _, preds_list, kl_loss_i, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                    kl_loss_i /= args.train_data_size
                    kl_loss_depth /= args.train_data_size
                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
                data_loss = data_loss_cum / args.samples_dun
                kl_loss = kl_loss_cum / args.samples_dun
            loss = data_loss + kl_loss + kl_loss_depth

        #############################################

        # backward pass; update weights
        loss.backward()
        optimizer.step()

        # add to loss_sum and iter_count
        loss_sum += loss.item() * len(batch)
        if bbp_switch is not None:
            data_loss_sum += data_loss.item() * len(batch)
            kl_loss_sum += kl_loss.item() * len(batch)
            if bbp_switch > 2:
                kl_loss_depth_sum += kl_loss_depth * len(batch)

        # update learning rate by taking a step
        if isinstance(scheduler, (NoamLR, OneCycleLR)):
            scheduler.step()

        # increment n_iter (total number of examples across epochs)
        n_iter += len(batch)

        ########### per epoch REPORTING
        if n_iter % args.train_data_size == 0:
            lrs = scheduler.get_last_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / args.train_data_size

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if bbp_switch is not None:
                data_loss_avg = data_loss_sum / args.train_data_size
                kl_loss_avg = kl_loss_sum / args.train_data_size
                wandb.log({"Total loss": loss_avg}, commit=False)
                wandb.log({"Likelihood cost": data_loss_avg}, commit=False)
                wandb.log({"KL cost": kl_loss_avg}, commit=False)
                if bbp_switch > 2:
                    kl_loss_depth_avg = kl_loss_depth_sum / args.train_data_size
                    wandb.log({"KL cost DEPTH": kl_loss_depth_avg}, commit=False)
                    # log variational categorical distribution
                    wandb.log({"d_1": cat.detach().cpu().numpy()[0]}, commit=False)
                    wandb.log({"d_2": cat.detach().cpu().numpy()[1]}, commit=False)
                    wandb.log({"d_3": cat.detach().cpu().numpy()[2]}, commit=False)
                    wandb.log({"d_4": cat.detach().cpu().numpy()[3]}, commit=False)
                    wandb.log({"d_5": cat.detach().cpu().numpy()[4]}, commit=False)
            else:
                if gp_switch:
                    wandb.log({"Negative ELBO": loss_avg}, commit=False)
                else:
                    wandb.log({"Negative log likelihood (scaled)": loss_avg}, commit=False)

            if args.pdts:
                wandb.log({"Learning rate": lrs[0]}, commit=True)
            else:
                wandb.log({"Learning rate": lrs[0]}, commit=False)

    if args.pdts and args.swag:
        return loss_avg, n_iter
    else:
        return n_iter
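# Illustrative driver sketch (not part of the original module): `train` performs one epoch,
# so a typical caller threads `n_iter` through successive calls; the per-epoch reporting
# branch above fires whenever n_iter is a multiple of args.train_data_size. The attribute
# `args.epochs` and the plain (non-pdts/swag) return handling are assumptions for this sketch.
def _example_training_loop(model, data_loader, loss_func, optimizer, scheduler, args, logger=None):
    n_iter = 0
    for _epoch in range(args.epochs):
        n_iter = train(
            model=model,
            data_loader=data_loader,
            loss_func=loss_func,
            optimizer=optimizer,
            scheduler=scheduler,
            args=args,
            n_iter=n_iter,
            logger=logger,
        )
    return model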