Example #1
def enumerate_scheduler(scheduler: _LRScheduler, steps: int) -> List[float]:
    """
    Reads the current learning rate via get_last_lr, run 1 scheduler step, and repeat. Returns the LR values.
    """
    lrs = []
    for _ in range(steps):
        lr = scheduler.get_last_lr()  # type: ignore
        assert isinstance(lr, list)
        assert len(lr) == 1
        lrs.append(lr[0])
        scheduler.step()
    return lrs
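
# A minimal usage sketch for enumerate_scheduler (illustrative only; the StepLR setup
# below is an assumption, not taken from the original code). PyTorch may warn that
# optimizer.step() was never called before scheduler.step(); that is expected here,
# since we only want to read out the schedule.
import torch

_params = [torch.nn.Parameter(torch.zeros(1))]
_opt = torch.optim.SGD(_params, lr=0.1)
_sched = torch.optim.lr_scheduler.StepLR(_opt, step_size=2, gamma=0.5)
print(enumerate_scheduler(_sched, steps=6))  # [0.1, 0.1, 0.05, 0.05, 0.025, 0.025]
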
def fit(
    step         :FunctionType,
    epochs       :int,
    model        :ModuleType,
    optimizer    :OptimizerType,
    scheduler    :SchedulerType,
    data_loader  :DataLoader,
    model_path   :str,
    logger       :object,
    early_stop   :bool            =False,
    pgd_kwargs   :Optional[dict]  =None,
    verbose      :bool            =False
) -> Tuple[ModuleType, dict]:
    """Standard pytorch boilerplate for training a model.
    Allows for early stopping wrt robust accuracy rather than the usual clean accuracy.
    """
    device = next(model.parameters()).device
    prev_robust_acc = 0.
    start_train_time = time.time()
    logger.info('Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc')

    for epoch in range(epochs):
        start_epoch_time = time.time()
        train_loss = 0
        train_acc = 0
        train_n = 0

        data_generator = enumerate(data_loader)
        if verbose:
            data_generator = tqdm(data_generator, total=len(data_loader), desc=f'Epoch {epoch + 1}')

        for i, (X, y) in data_generator:
            X, y = X.to(device), y.to(device)
            if i == 0:
                first_batch = (X, y)

            loss, logits = step(X, y, 
                model    =model, 
                optimizer=optimizer, 
                scheduler=scheduler)

            train_loss += loss.item() * y.size(0)
            train_acc += (logits.argmax(dim=1) == y).sum().item()
            train_n += y.size(0)

        if early_stop:
            assert pgd_kwargs is not None
            # Check current PGD robustness of the model on the first minibatch of the epoch
            X, y = first_batch

            pgd_delta = attack_pgd(
                model=model, 
                X    =X, 
                y    =y, 
                opt  =optimizer,
                **pgd_kwargs)

            model.eval()
            with torch.no_grad():
                output = model(clamp(X + pgd_delta[:X.size(0)], pgd_kwargs['lower_limit'], pgd_kwargs['upper_limit']))
            robust_acc = (output.softmax(dim=1).argmax(dim=1) == y).sum().item() / y.size(0)
            if robust_acc - prev_robust_acc < -0.2:
                logger.info('EARLY STOPPING TRIGGERED')
                break
            prev_robust_acc = robust_acc
            best_state_dict = copy.deepcopy(model.state_dict())
            model.train()
        
        epoch_time = time.time()
        lr = scheduler.get_last_lr()[0]
        logger.info('%d \t %.1f \t \t %.4f \t %.4f \t %.4f',
            epoch, 
            epoch_time - start_epoch_time, 
            lr, 
            train_loss / train_n, 
            train_acc / train_n)

    train_time = time.time()
    if not early_stop:
        best_state_dict = model.state_dict()
    torch.save(best_state_dict, model_path)
    logger.info('Total train time: %.4f minutes', (train_time - start_train_time)/60)

    return model, best_state_dict
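
# A hedged sketch of a `step` callback with the signature fit() expects: it is called as
# step(X, y, model=..., optimizer=..., scheduler=...) and must return (loss, logits) so that
# fit() can call loss.item() and logits.argmax(dim=1). The cross-entropy objective and the
# per-batch scheduler.step() below are assumptions, not taken from the original code.
def cross_entropy_step(X, y, model, optimizer, scheduler):
    logits = model(X)
    loss = torch.nn.functional.cross_entropy(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # fit() only reads the LR via get_last_lr(); stepping per batch is one possible convention
    return loss, logits
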
Example #3
def train(model: nn.Module,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          gp_switch: bool = False,
          likelihood = None,
          bbp_switch = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data_loader: A MoleculeDataLoader.
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    
    

        
    debug = logger.debug if logger is not None else print
    
    model.train()
    if likelihood is not None:
        likelihood.train()
    
    loss_sum = 0
    if bbp_switch is not None:
        data_loss_sum = 0
        kl_loss_sum = 0
        kl_loss_depth_sum = 0

    #for batch in tqdm(data_loader, total=len(data_loader)):
    for batch in data_loader:
        # Prepare batch
        batch: MoleculeDataset

        # .batch_graph() returns BatchMolGraph
        # .features() returns None if no additional features
        # .targets() returns list of lists of floats containing the targets
        mol_batch, features_batch, target_batch = batch.batch_graph(), batch.features(), batch.targets()
        
        # mask is 1 where targets are not None
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        # where targets are None, replace with 0
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        # Move tensors to correct device
        mask = mask.to(args.device)
        targets = targets.to(args.device)
        class_weights = torch.ones(targets.shape, device=args.device)
        
        # zero gradients
        model.zero_grad()
        optimizer.zero_grad()
        
        
        ##### FORWARD PASS AND LOSS COMPUTATION #####
        
        
        if bbp_switch is None:
        
            # forward pass
            preds = model(mol_batch, features_batch)
    
            # compute loss
            if gp_switch:
                loss = -loss_func(preds, targets)
            else:
                loss = loss_func(preds, targets, torch.exp(model.log_noise))
        
        
        ### bbp non sample option
        if bbp_switch == 1:    
            preds, kl_loss = model(mol_batch, features_batch, sample = False)
            data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
            kl_loss /= args.train_data_size
            loss = data_loss + kl_loss  
            
        ### bbp sample option
        if bbp_switch == 2:

            if args.samples_bbp == 1:
                preds, kl_loss = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
                kl_loss /= args.train_data_size
        
            elif args.samples_bbp > 1:
                data_loss_cum = 0
                kl_loss_cum = 0
        
                for i in range(args.samples_bbp):
                    preds, kl_loss_i = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds, targets, torch.exp(model.log_noise))                    
                    kl_loss_i /= args.train_data_size                    
                    
                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
        
                data_loss = data_loss_cum / args.samples_bbp
                kl_loss = kl_loss_cum / args.samples_bbp
            
            loss = data_loss + kl_loss

        ### DUN non sample option
        if bbp_switch == 3:
            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))    
            _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=False)
            data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
            kl_loss /= args.train_data_size
            kl_loss_depth /= args.train_data_size
            loss = data_loss + kl_loss + kl_loss_depth

        ### DUN sample option
        if bbp_switch == 4:

            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))

            if args.samples_dun == 1:
                _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                kl_loss /= args.train_data_size
                kl_loss_depth /= args.train_data_size
        
            elif args.samples_dun > 1:
                data_loss_cum = 0
                kl_loss_cum = 0

                for i in range(args.samples_dun):
                    _, preds_list, kl_loss_i, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                    kl_loss_i /= args.train_data_size                    
                    kl_loss_depth /= args.train_data_size

                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
        
                data_loss = data_loss_cum / args.samples_dun
                kl_loss = kl_loss_cum / args.samples_dun
            
            loss = data_loss + kl_loss + kl_loss_depth

            
        #############################################
        
        
        # backward pass; update weights
        loss.backward()
        optimizer.step()
        
        
        # add to loss_sum and iter_count
        loss_sum += loss.item() * len(batch)
        if bbp_switch is not None:
            data_loss_sum += data_loss.item() * len(batch)
            kl_loss_sum += kl_loss.item() * len(batch)
            if bbp_switch > 2:
                kl_loss_depth_sum += kl_loss_depth.item() * len(batch)

        # update learning rate by taking a step
        if isinstance(scheduler, (NoamLR, OneCycleLR)):
            scheduler.step()

        # increment n_iter (total number of examples across epochs)
        n_iter += len(batch)

        ########### per epoch REPORTING
        if n_iter % args.train_data_size == 0:
            lrs = scheduler.get_last_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            
            loss_avg = loss_sum / args.train_data_size

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')
            
            if bbp_switch is not None:
                data_loss_avg = data_loss_sum / args.train_data_size
                kl_loss_avg = kl_loss_sum / args.train_data_size
                wandb.log({"Total loss": loss_avg}, commit=False)
                wandb.log({"Likelihood cost": data_loss_avg}, commit=False)
                wandb.log({"KL cost": kl_loss_avg}, commit=False)

                if bbp_switch > 2:
                    kl_loss_depth_avg = kl_loss_depth_sum / args.train_data_size
                    wandb.log({"KL cost DEPTH": kl_loss_depth_avg}, commit=False)

                    # log variational categorical distribution over depths
                    cat_np = cat.detach().cpu().numpy()
                    wandb.log({f'd_{i + 1}': cat_np[i] for i in range(5)}, commit=False)

            else:
                if gp_switch:
                    wandb.log({"Negative ELBO": loss_avg}, commit=False)
                else:
                    wandb.log({"Negative log likelihood (scaled)": loss_avg}, commit=False)
            
            # commit the accumulated wandb step only when args.pdts is set
            wandb.log({"Learning rate": lrs[0]}, commit=args.pdts)
            
    if args.pdts and args.swag:
        return loss_avg, n_iter
    else:
        return n_iter
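
# A hedged sketch of the outer loop that this epoch-level train() is designed for. The objects
# model, data_loader, loss_func, optimizer, scheduler, args and logger are assumed to be built
# elsewhere (as in a chemprop-style training script), and args.epochs is an assumption about
# TrainArgs; only the control flow is implied by train() itself.
n_iter = 0
for epoch in range(args.epochs):
    n_iter = train(model=model,
                   data_loader=data_loader,
                   loss_func=loss_func,
                   optimizer=optimizer,
                   scheduler=scheduler,
                   args=args,
                   n_iter=n_iter,
                   logger=logger)
    # NoamLR / OneCycleLR are stepped per batch inside train(); an epoch-level
    # scheduler would be stepped here instead.
    if not isinstance(scheduler, (NoamLR, OneCycleLR)):
        scheduler.step()
# Note: when args.pdts and args.swag are both set, train() returns (average loss, n_iter)
# instead of just n_iter, so the assignment above would then need to unpack a tuple.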