def pdts(args: TrainArgs, model_idx: int) -> None:
    """
    Preliminary experiment with PDTS (approximate Bayesian optimisation).

    Runs one active-learning experiment for the model with index
    ``model_idx``. The steps, in the order the code performs them:

    1. Create ``<args.save_dir>/model_<model_idx>`` and ``args.results_dir``
       and start a wandb run named ``<args.wandb_name>_<model_idx>``.
    2. Load the data and record the SMILES of the best molecules by the first
       target (the "top 1%", ``int(args.max_data_size * 0.01)`` molecules).
    3. Shuffle with ``args.data_seeds[model_idx]`` and split into an initial
       train pool (fraction ``args.split_sizes[0]``) and a candidate/test
       pool. Feature and target scalers are fit on the initial train pool
       only and reused in every later round; test targets are not scaled.
    4. Train a MAP ``MoleculeModel``. If any of ``args.bbp``, ``args.gp``,
       ``args.swag`` or ``args.sgld`` is set, the MAP model is replaced by a
       checkpoint from ``args.checkpoint_path`` and the corresponding
       approximate posterior is built and trained.
    5. Run ``args.pdts_batches + 1`` acquisition rounds (round 0, then rounds
       1..``args.pdts_batches``). Each round records its score, trains,
       predicts the candidate pool ``args.samples`` times and moves the
       selected molecules from the candidate pool into the train pool:

       * ``args.thompson`` set: each posterior draw contributes its
         highest-predicted (first target) molecule not already chosen --
         five per draw for SGLD, one per draw otherwise.
       * otherwise (greedy): predictions are averaged over all draws and the
         top 50 are taken.

    6. Save the per-round scores to ``<args.results_dir>/ptds_<model_idx>.npz``
       (each score is also logged to wandb as "Proportion of top 1%").

    A round's score is the fraction of the top-1% SMILES present in the train
    pool at the START of that round, so the molecules acquired in the final
    round are never scored.

    Original design notes (apart from the greedy batch of 50 these numbers
    are not hard-coded here; presumably they are set via ``args`` -- TODO
    confirm):
      - data set size of 50k, run until 15k data points have been trained on
      - batch size of 50
      - initialise with 1000 data points

    :param args: Training arguments. Mutated in place: ``task_names``,
        ``num_tasks``, ``features_size``, ``seed`` and ``train_data_size`` are
        overwritten.
    :param model_idx: Index selecting the data seed, the PyTorch seed, the
        wandb run name and the save/checkpoint sub-directories.
    """

    ######## set up all logging ########
    # No Python logger is used here; None is passed to the helpers below.
    logger = None

    # make save_dir (per-model sub-directory, also passed to SWAG/SGLD helpers)
    save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
    makedirs(save_dir)

    # make results_dir (the final scores array is written here)
    results_dir = args.results_dir
    makedirs(results_dir)

    # initialise wandb (one run per model index)
    #os.environ['WANDB_MODE'] = 'dryrun'
    wandb.init(name=args.wandb_name + '_' + str(model_idx),
               project=args.wandb_proj,
               reinit=True)
    #print('WANDB directory is:')
    #print(wandb.run.dir)
    ####################################

    ########## get data
    # args is mutated: task names, task count and feature size are filled in
    # from the loaded data.
    args.task_names = args.target_columns or get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()

    ########## SMILES of top 1%
    # Rank every molecule by its first target (descending) and keep the first
    # int(args.max_data_size * 0.01). NOTE(review): this is exactly 1% only
    # when len(data) == args.max_data_size -- confirm.
    top1p = np.array(MoleculeDataset(data).targets())
    top1p_idx = np.argsort(-top1p[:, 0])[:int(args.max_data_size * 0.01)]
    SMILES = np.array(MoleculeDataset(data).smiles())[top1p_idx]

    ########## initial data splits
    # The data seed depends on model_idx, so each model gets its own initial
    # split; args.seed is overwritten with it.
    args.seed = args.data_seeds[model_idx]
    data.shuffle(seed=args.seed)
    sizes = args.split_sizes
    train_size = int(sizes[0] * len(data))
    # train_orig / test_orig are the pools that molecules are later moved
    # between (via pop/append). The datasets actually trained on are deep
    # copies, presumably so that normalisation does not touch the originals.
    train_orig = data[:train_size]
    test_orig = data[train_size:]
    train_data, test_data = copy.deepcopy(
        MoleculeDataset(train_orig)), copy.deepcopy(MoleculeDataset(test_orig))
    args.train_data_size = len(train_data)

    ########## standardising
    # features (train and test); the scaler is fit on the initial train pool
    # only and reused in every later round
    features_scaler = train_data.normalize_features(replace_nan_token=0)
    test_data.normalize_features(features_scaler)
    # targets (train only; test targets stay in original units)
    train_targets = train_data.targets()
    test_targets = test_data.targets()  # NOTE(review): unused below
    scaler = StandardScaler().fit(train_targets)
    scaled_targets = scaler.transform(train_targets).tolist()
    train_data.set_targets(scaled_targets)

    ########## loss, metric functions
    # NOTE(review): metric_func is not used anywhere in this function.
    loss_func = neg_log_like
    metric_func = get_metric_func(metric=args.metric)

    ########## data loaders
    # Datasets no larger than args.cache_cutoff use cache=True with no worker
    # processes; larger ones use args.num_workers. The same choice is reused
    # for every loader rebuilt below.
    if len(data) <= args.cache_cutoff:
        cache = True
        num_workers = 0
    else:
        cache = False
        num_workers = args.num_workers
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           cache=cache,
                                           class_balance=args.class_balance,
                                           shuffle=True,
                                           seed=args.seed)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers,
                                          cache=cache)

    ########## instantiating model, optimiser, scheduler (MAP)
    # set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seeds[model_idx])
    # build model
    print(f'Building model {model_idx}')
    model = MoleculeModel(args)
    print(model)
    print(f'Number of parameters = {param_count(model):,}')
    if args.cuda:
        print('Moving model to cuda')
    model = model.to(args.device)
    # optimizer: encoder and FFN get args.weight_decay; the learned noise
    # parameter model.log_noise is excluded from weight decay
    optimizer = Adam([{
        'params': model.encoder.parameters()
    }, {
        'params': model.ffn.parameters()
    }, {
        'params': model.log_noise,
        'weight_decay': 0
    }],
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    # learning rate scheduler (presumably constant at args.lr, going by the
    # name scheduler_const -- confirm)
    scheduler = scheduler_const([args.lr])

    ####################################################################
    ####################################################################
    # FIRST THOMPSON ITERATION
    # NOTE(review): this section is repeated almost verbatim inside the
    # thompson sampling loop below; keep the two copies in sync.

    ### scores array: one entry per round (round 0 + args.pdts_batches)
    ptds_scores = np.ones(args.pdts_batches + 1)
    batch_no = 0

    ### fill for batch 0
    # overlap = number of duplicate entries after stacking the top-1% SMILES
    # with the train SMILES. NOTE(review): this equals the number of top-1%
    # molecules in the train pool only if neither array has internal
    # duplicates.
    SMILES_train = np.array(train_data.smiles())
    SMILES_stack = np.hstack((SMILES, SMILES_train))
    overlap = len(SMILES_stack) - len(np.unique(SMILES_stack))
    prop = overlap / len(SMILES)
    ptds_scores[batch_no] = prop
    wandb.log({
        "Proportion of top 1%": prop,
        "batch_no": batch_no
    },
              commit=False)

    ### train MAP posterior
    # gp_switch / likelihood / bbp_switch are forwarded to train() and are
    # updated below if a GP or BBP model replaces the MAP model.
    # NOTE(review): the train() definitions visible later in this file do not
    # accept bbp_switch/gp_switch/likelihood; this relies on a different
    # train() being in scope.
    gp_switch = False
    likelihood = None
    bbp_switch = None
    n_iter = 0
    for epoch in range(args.epochs_init_map):
        n_iter = train(model=model,
                       data_loader=train_data_loader,
                       loss_func=loss_func,
                       optimizer=optimizer,
                       scheduler=scheduler,
                       args=args,
                       n_iter=n_iter,
                       bbp_switch=bbp_switch)
        # save to save_dir
        #if epoch == args.epochs_init_map - 1:
        #save_checkpoint(os.path.join(save_dir, f'model_{batch_no}.pt'), model, scaler, features_scaler, args)
    # if X load from checkpoint path
    # NOTE(review): for BBP/GP/SWAG/SGLD the MAP weights trained above are
    # discarded and replaced with
    # <args.checkpoint_path>/model_<model_idx>/model_0.pt (batch_no is 0
    # here). The save_checkpoint call that would create it is commented out,
    # so the checkpoint must come from an earlier run.
    if args.bbp or args.gp or args.swag or args.sgld:
        model = load_checkpoint(args.checkpoint_path +
                                f'/model_{model_idx}/model_{batch_no}.pt',
                                device=args.device,
                                logger=None)

    ########## BBP
    if args.bbp:
        model_bbp = MoleculeModelBBP(
            args)  # instantiate with bayesian linear layers
        # Copy the pretrained weights across, transposed, pairing parameters by
        # their position in named_parameters(). NOTE(review): relies on both
        # models listing parameters in the same order; zip silently stops at
        # the shorter list.
        for (_, param_bbp), (_, param_pre) in zip(model_bbp.named_parameters(),
                                                  model.named_parameters()):
            param_bbp.data = copy.deepcopy(
                param_pre.data.T)  # copy over parameters
        # instantiate rhos for BayesLinear layers that are direct children of
        # the model and of model_bbp.encoder.encoder
        for layer in model_bbp.children():
            if isinstance(layer, BayesLinear):
                layer.init_rho(args.rho_min_bbp, args.rho_max_bbp)
        for layer in model_bbp.encoder.encoder.children():
            if isinstance(layer, BayesLinear):
                layer.init_rho(args.rho_min_bbp, args.rho_max_bbp)
        model = model_bbp  # name back
        # move to cuda
        if args.cuda:
            print('Moving bbp model to cuda')
            model = model.to(args.device)
        # optimiser and scheduler
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        scheduler = scheduler_const([args.lr])

        bbp_switch = 2
        n_iter = 0
        for epoch in range(args.epochs_init):
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           bbp_switch=bbp_switch)

    ########## GP
    if args.gp:
        # feature_extractor: the checkpointed network with featurizer enabled
        model.featurizer = True
        feature_extractor = model
        # inducing points
        inducing_points = initial_inducing_points(train_data_loader,
                                                  feature_extractor, args)
        # GP layer
        gp_layer = GPLayer(inducing_points, args.num_tasks)
        # full DKL model
        model = copy.deepcopy(DKLMoleculeModel(feature_extractor, gp_layer))
        # likelihood (rank 0 restricts to diagonal matrix)
        # NOTE(review): num_tasks is hard-coded to 12 here while GPLayer above
        # uses args.num_tasks -- confirm these always agree.
        likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
            num_tasks=12, rank=0)
        # model and likelihood to CUDA
        if args.cuda:
            model.cuda()
            likelihood.cuda()
        # loss object; num_data must track the train-pool size, so it is
        # rebuilt after every acquisition below
        loss_func = gpytorch.mlls.VariationalELBO(
            likelihood, model.gp_layer, num_data=args.train_data_size)
        # optimiser and scheduler (weight decay only on the feature extractor)
        params_list = [
            {
                'params': model.feature_extractor.parameters(),
                'weight_decay': args.weight_decay_gp
            },
            {
                'params': model.gp_layer.hyperparameters()
            },
            {
                'params': model.gp_layer.variational_parameters()
            },
            {
                'params': likelihood.parameters()
            },
        ]
        optimizer = torch.optim.Adam(params_list, lr=args.lr)
        scheduler = scheduler_const([args.lr])

        gp_switch = True
        n_iter = 0
        for epoch in range(args.epochs_init):
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           gp_switch=gp_switch,
                           likelihood=likelihood)

    ########## SWAG
    # train_swag_pdts returns the SWAG model that model.sample() draws from below
    if args.swag:
        model_core = copy.deepcopy(model)
        model = train_swag_pdts(model_core, train_data_loader, loss_func,
                                scaler, features_scaler, args, save_dir,
                                batch_no)

    ########## SGLD
    # presumably writes the per-sample checkpoints that are loaded in the
    # sampling loop below -- confirm against train_sgld_pdts
    if args.sgld:
        model = train_sgld_pdts(model, train_data_loader, loss_func, scaler,
                                features_scaler, args, save_dir, batch_no)

    ### find top_idx
    top_idx = []  # need for thom
    sum_test_preds = np.zeros(
        (len(test_orig), args.num_tasks))  # need for greedy
    for sample in range(args.samples):

        # draw model from SWAG posterior
        if args.swag:
            model.sample(scale=1.0, cov=args.cov_mat, block=args.block)

        # retrieve sgld sample
        # NOTE(review): this rebinds `model`, so after the loop `model` is the
        # last SGLD sample and later rounds train from it.
        if args.sgld:
            model = load_checkpoint(
                args.save_dir +
                f'/model_{model_idx}/model_{batch_no}/model_{sample}.pt',
                device=args.device,
                logger=logger)

        test_preds = predict(model=model,
                             data_loader=test_data_loader,
                             args=args,
                             scaler=scaler,
                             test_data=True,
                             gp_sample=args.thompson,
                             bbp_sample=True)
        test_preds = np.array(test_preds)
        # thompson bit: walk down this draw's ranking (first target,
        # descending) until enough new molecules have been collected
        rank = 0

        # base length: the loop below stops once len(top_idx) exceeds it,
        # i.e. after 5 * (sample + 1) picks for SGLD and sample + 1 otherwise
        if args.sgld:
            base_length = 5 * sample + 4
        else:
            base_length = sample

        # NOTE(review): raises IndexError if the candidate pool runs out
        # before enough unique molecules are found.
        while args.thompson and (len(top_idx) <= base_length):
            top_unique_molecule = np.argsort(-test_preds[:, 0])[rank]
            rank += 1
            if top_unique_molecule not in top_idx:
                top_idx.append(top_unique_molecule)
        # add to sum_test_preds
        sum_test_preds += test_preds
        # print
        print('done sample ' + str(sample))
    # final top_idx: Thompson picks, or greedy top 50 by mean prediction
    if args.thompson:
        top_idx = np.array(top_idx)
    else:
        sum_test_preds /= args.samples
        top_idx = np.argsort(-sum_test_preds[:, 0])[:50]

    ### transfer from test to train
    # Sort indices in descending order so that popping one index does not
    # shift the positions of the indices still to be popped.
    top_idx = -np.sort(-top_idx)
    for idx in top_idx:
        train_orig.append(test_orig.pop(idx))
    train_data, test_data = copy.deepcopy(
        MoleculeDataset(train_orig)), copy.deepcopy(MoleculeDataset(test_orig))
    args.train_data_size = len(train_data)
    # rebuild the ELBO so num_data matches the enlarged train pool
    if args.gp:
        loss_func = gpytorch.mlls.VariationalELBO(
            likelihood, model.gp_layer, num_data=args.train_data_size)
    print(args.train_data_size)

    ### standardise features (train and test; using original features_scaler)
    train_data.normalize_features(features_scaler)
    test_data.normalize_features(features_scaler)

    ### standardise targets (train only; using original scaler)
    train_targets = train_data.targets()
    scaled_targets_tr = scaler.transform(train_targets).tolist()
    train_data.set_targets(scaled_targets_tr)

    ### create data loaders
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           cache=cache,
                                           class_balance=args.class_balance,
                                           shuffle=True,
                                           seed=args.seed)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers,
                                          cache=cache)

    ####################################################################
    ####################################################################

    ##################################
    ########## thompson sampling loop
    ##################################
    # Rounds 1..args.pdts_batches: same steps as round 0, but the existing
    # model/optimizer/scheduler continue training for args.epochs epochs.

    for batch_no in range(1, args.pdts_batches + 1):

        ### fill in ptds_scores (score of the train pool at round start)
        SMILES_train = np.array(train_data.smiles())
        SMILES_stack = np.hstack((SMILES, SMILES_train))
        overlap = len(SMILES_stack) - len(np.unique(SMILES_stack))
        prop = overlap / len(SMILES)
        ptds_scores[batch_no] = prop
        wandb.log({
            "Proportion of top 1%": prop,
            "batch_no": batch_no
        },
                  commit=False)

        ### train posterior
        n_iter = 0
        for epoch in range(args.epochs):
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           gp_switch=gp_switch,
                           likelihood=likelihood,
                           bbp_switch=bbp_switch)
            # save to save_dir
            #if epoch == args.epochs - 1:
            #save_checkpoint(os.path.join(save_dir, f'model_{batch_no}.pt'), model, scaler, features_scaler, args)
        # if swag, load checkpoint
        # NOTE(review): the SWAG base model for this round comes from
        # <args.checkpoint_path>/model_<model_idx>/model_<batch_no>.pt, not
        # from the training just performed.
        if args.swag:
            model_core = load_checkpoint(
                args.checkpoint_path +
                f'/model_{model_idx}/model_{batch_no}.pt',
                device=args.device,
                logger=None)

        ########## SWAG
        if args.swag:
            model = train_swag_pdts(model_core, train_data_loader, loss_func,
                                    scaler, features_scaler, args, save_dir,
                                    batch_no)

        ########## SGLD
        if args.sgld:
            model = train_sgld_pdts(model, train_data_loader, loss_func,
                                    scaler, features_scaler, args, save_dir,
                                    batch_no)

        ### find top_idx
        top_idx = []  # need for thom
        sum_test_preds = np.zeros(
            (len(test_orig), args.num_tasks))  # need for greedy
        for sample in range(args.samples):

            # draw model from SWAG posterior
            if args.swag:
                model.sample(scale=1.0, cov=args.cov_mat, block=args.block)

            # retrieve sgld sample
            if args.sgld:
                model = load_checkpoint(
                    args.save_dir +
                    f'/model_{model_idx}/model_{batch_no}/model_{sample}.pt',
                    device=args.device,
                    logger=logger)

            test_preds = predict(model=model,
                                 data_loader=test_data_loader,
                                 args=args,
                                 scaler=scaler,
                                 test_data=True,
                                 gp_sample=args.thompson,
                                 bbp_sample=True)
            test_preds = np.array(test_preds)
            # thompson bit
            rank = 0

            # base length: 5 picks per draw for SGLD, 1 otherwise
            if args.sgld:
                base_length = 5 * sample + 4
            else:
                base_length = sample

            while args.thompson and (len(top_idx) <= base_length):
                top_unique_molecule = np.argsort(-test_preds[:, 0])[rank]
                rank += 1
                if top_unique_molecule not in top_idx:
                    top_idx.append(top_unique_molecule)
            # add to sum_test_preds
            sum_test_preds += test_preds
            # print
            print('done sample ' + str(sample))
        # final top_idx
        if args.thompson:
            top_idx = np.array(top_idx)
        else:
            sum_test_preds /= args.samples
            top_idx = np.argsort(-sum_test_preds[:, 0])[:50]

        ### transfer from test to train (descending order keeps pops valid)
        top_idx = -np.sort(-top_idx)
        for idx in top_idx:
            train_orig.append(test_orig.pop(idx))
        train_data, test_data = copy.deepcopy(
            MoleculeDataset(train_orig)), copy.deepcopy(
                MoleculeDataset(test_orig))
        args.train_data_size = len(train_data)
        if args.gp:
            loss_func = gpytorch.mlls.VariationalELBO(
                likelihood, model.gp_layer, num_data=args.train_data_size)
        print(args.train_data_size)

        ### standardise features (train and test; using original features_scaler)
        train_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)

        ### standardise targets (train only; using original scaler)
        train_targets = train_data.targets()
        scaled_targets_tr = scaler.transform(train_targets).tolist()
        train_data.set_targets(scaled_targets_tr)

        ### create data loaders
        train_data_loader = MoleculeDataLoader(
            dataset=train_data,
            batch_size=args.batch_size,
            num_workers=num_workers,
            cache=cache,
            class_balance=args.class_balance,
            shuffle=True,
            seed=args.seed)
        test_data_loader = MoleculeDataLoader(dataset=test_data,
                                              batch_size=args.batch_size,
                                              num_workers=num_workers,
                                              cache=cache)

    # save scores (np.savez appends the .npz extension).
    # NOTE(review): the molecules acquired in the final round are moved into
    # train_orig but never scored, because scoring happens at round start.
    np.savez(os.path.join(results_dir, f'ptds_{model_idx}'), ptds_scores)
# Code example #2
# 0
# File: train.py  Project: jasonzdeng/chemprop
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    Entries whose target is ``None`` are masked out of the loss, and each
    batch's loss is normalised by its number of observed targets. Whenever
    ``(n_iter // args.batch_size) % args.log_frequency == 0`` the following
    are reported:

    * the mean per-batch loss since the previous report;
    * the parameter and gradient norms;
    * the current learning rates.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function, called as ``loss_func(preds, targets)``;
        its result is multiplied element-wise by the target mask.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler. It is stepped per batch only
        if it is a ``NoamLR``.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output. ``print`` is used when None.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    # loss_sum accumulates per-batch mean losses and batch_count counts those
    # batches, so loss_sum / batch_count is the mean batch loss since the last
    # report. Counting molecules instead would shrink the reported loss by
    # roughly the batch size.
    loss_sum, batch_count = 0, 0

    for batch in tqdm(data_loader, total=len(data_loader)):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch = batch.batch_graph()
        features_batch = batch.features()
        target_batch = batch.targets()
        # mask is 1.0 where a target exists and 0.0 where it is None; missing
        # targets are filled with 0 so the targets tensor is rectangular.
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to the device the predictions live on
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)

        if args.dataset_type == 'multiclass':
            # preds is indexed as preds[:, task, class]: compute the loss task
            # by task and stack the results column-wise.
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * mask
        else:
            loss = loss_func(preds, targets) * mask
        # Average over the observed (molecule, task) entries only
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        batch_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            # batch_count >= 1 here: it is incremented before this check
            loss_avg = loss_sum / batch_count
            loss_sum, batch_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
# Code example #3
# 0
# File: train.py  Project: bp-kelley/chemprop
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    For each batch the loss is chosen by ``args.loss_function`` and
    ``args.dataset_type``. It is weighted per task by ``args.target_weights``
    (all ones when unset) and per molecule by the batch's data weights. It is
    masked with ``batch.mask()``. It is then reduced to a scalar:

    * MCC losses are already one value per task, so they are averaged over
      tasks.
    * All other losses are per (molecule, task) entry. They are summed and
      divided by the number of unmasked entries.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler. It is stepped per batch only
        if it is a ``NoamLR``.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output. ``print`` is used when None.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    # loss_sum accumulates per-batch scalar losses; iter_count counts batches
    loss_sum = iter_count = 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch, mask_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch, data_weights_batch = \
            batch.batch_graph(), batch.features(), batch.targets(), batch.mask(), batch.atom_descriptors(), \
            batch.atom_features(), batch.bond_features(), batch.data_weights()

        mask = torch.tensor(mask_batch, dtype=torch.bool) # shape(batch, tasks)
        targets = torch.tensor([[0 if x is None else x for x in tb] for tb in target_batch]) # shape(batch, tasks)

        if args.target_weights is not None:
            target_weights = torch.tensor(args.target_weights).unsqueeze(0) # shape(1,tasks)
        else:
            target_weights = torch.ones(targets.shape[1]).unsqueeze(0)
        data_weights = torch.tensor(data_weights_batch).unsqueeze(1) # shape(batch,1)

        if args.loss_function == 'bounded_mse':
            lt_target_batch = batch.lt_targets() # shape(batch, tasks)
            gt_target_batch = batch.gt_targets() # shape(batch, tasks)
            lt_target_batch = torch.tensor(lt_target_batch)
            gt_target_batch = torch.tensor(gt_target_batch)

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch)

        # Move tensors to correct device
        torch_device = preds.device
        mask = mask.to(torch_device)
        targets = targets.to(torch_device)
        target_weights = target_weights.to(torch_device)
        data_weights = data_weights.to(torch_device)
        if args.loss_function == 'bounded_mse':
            lt_target_batch = lt_target_batch.to(torch_device)
            gt_target_batch = gt_target_batch.to(torch_device)

        # Calculate losses
        if args.loss_function == 'mcc' and args.dataset_type == 'classification':
            # One MCC loss per task, shape (tasks,)
            loss = loss_func(preds, targets, data_weights, mask) * target_weights.squeeze(0)
        elif args.loss_function == 'mcc': # multiclass dataset type
            # One MCC loss per task, computed over that task's classes
            targets = targets.long()
            target_losses = []
            for target_index in range(preds.size(1)):
                target_loss = loss_func(preds[:, target_index, :], targets[:, target_index], data_weights, mask[:, target_index]).unsqueeze(0)
                target_losses.append(target_loss)
            loss = torch.cat(target_losses).to(torch_device) * target_weights.squeeze(0)
        elif args.dataset_type == 'multiclass':
            targets = targets.long()
            if args.loss_function == 'dirichlet':
                loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
            else:
                target_losses = []
                for target_index in range(preds.size(1)):
                    target_loss = loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                    target_losses.append(target_loss)
                loss = torch.cat(target_losses, dim=1).to(torch_device) * target_weights * data_weights * mask
        elif args.dataset_type == 'spectra':
            loss = loss_func(preds, targets, mask) * target_weights * data_weights * mask
        elif args.loss_function == 'bounded_mse':
            loss = loss_func(preds, targets, lt_target_batch, gt_target_batch) * target_weights * data_weights * mask
        elif args.loss_function in ('evidential', 'dirichlet'): # dirichlet here is classification
            loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
        else:
            loss = loss_func(preds, targets) * target_weights * data_weights * mask

        if args.loss_function == 'mcc':
            # MCC losses are per task, not per (molecule, task) entry, so
            # normalising by the number of unmasked entries would shrink them
            # by roughly the batch size; average over tasks instead.
            loss = loss.mean()
        else:
            loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            # iter_count >= 1 here: it is incremented before this check
            loss_avg = loss_sum / iter_count
            loss_sum = iter_count = 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter