Code example #1
import logging

import torch
from omegaconf import OmegaConf

# ContrastiveModel, NT_Xent, LARC, exclude_from_wt_decay, calculate_initial_lr
# and calculate_lr are project-local and not included in this excerpt.


def train(
    cfg: OmegaConf,
    training_data_loader: torch.utils.data.DataLoader,
    model: ContrastiveModel,
) -> None:
    """
    Training function
    :param cfg: Hydra's config instance
    :param training_data_loader: Training data loader for contrastive learning
    :param model: Contrastive model based on resnet
    :return: None
    """
    local_rank = cfg["distributed"]["local_rank"]
    num_gpus = cfg["distributed"]["world_size"]
    epochs = cfg["parameter"]["epochs"]
    num_training_samples = len(training_data_loader.dataset.data)
    steps_per_epoch = int(
        num_training_samples /
        (cfg["experiment"]["batches"] * num_gpus))  # drop_last=True, so partial batches are discarded
    total_steps = epochs * steps_per_epoch
    warmup_steps = cfg["parameter"]["warmup_epochs"] * steps_per_epoch
    current_step = 0

    model.train()
    nt_cross_entropy_loss = NT_Xent(
        temperature=cfg["parameter"]["temperature"], device=local_rank)

    optimizer = torch.optim.SGD(params=exclude_from_wt_decay(
        model.named_parameters(), weight_decay=cfg["experiment"]["decay"]),
                                lr=calculate_initial_lr(cfg),
                                momentum=cfg["parameter"]["momentum"],
                                nesterov=False,
                                weight_decay=0.)

    # https://github.com/google-research/simclr/blob/master/lars_optimizer.py#L26
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    cos_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer.optim,
        T_max=total_steps - warmup_steps,
    )

    for epoch in range(1, epochs + 1):
        training_data_loader.sampler.set_epoch(epoch)

        for (view0, view1), _ in training_data_loader:
            # adjust the learning rate with linear warmup
            if current_step <= warmup_steps:
                lr = calculate_lr(cfg, warmup_steps, current_step)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr

            optimizer.zero_grad()
            z0 = model(view0.to(local_rank))
            z1 = model(view1.to(local_rank))
            loss = nt_cross_entropy_loss(z0, z1)
            loss.backward()
            optimizer.step()

            # adjust the learning rate with cosine annealing after warmup
            if current_step > warmup_steps:
                cos_lr_scheduler.step()

            current_step += 1

        if local_rank == 0:
            logging.info(
                "Epoch:{}/{} progress:{:.3f} loss:{:.3f}, lr:{:.7f}".format(
                    epoch, epochs, epoch / epochs, loss.item(),
                    optimizer.param_groups[0]["lr"]))

            if epoch % cfg["experiment"]["save_model_epoch"] == 0:
                save_fname = "epoch={}-{}".format(
                    epoch, cfg["experiment"]["output_model_name"])
                torch.save(model.state_dict(), save_fname)
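
The helpers calculate_initial_lr, calculate_lr and exclude_from_wt_decay are called above but are not part of the excerpt. The following is a minimal sketch of what they are assumed to do, following SimCLR's linear-scaling rule, a linear warmup, and the usual bias/BN weight-decay exclusion; the cfg key "base_lr" and the skip_list patterns are illustrative assumptions, not taken from the original code.

def calculate_initial_lr(cfg) -> float:
    # linear scaling rule (assumed): lr = base_lr * global_batch_size / 256
    global_batch_size = (cfg["experiment"]["batches"] *
                         cfg["distributed"]["world_size"])
    return cfg["parameter"]["base_lr"] * global_batch_size / 256.


def calculate_lr(cfg, warmup_steps: int, current_step: int) -> float:
    # linear warmup from 0 up to the initial learning rate
    return calculate_initial_lr(cfg) * current_step / max(1, warmup_steps)


def exclude_from_wt_decay(named_params, weight_decay, skip_list=("bias", "bn")):
    # split parameters into a weight-decayed group and an excluded group
    params, excluded_params = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue
        if any(pattern in name for pattern in skip_list):
            excluded_params.append(param)
        else:
            params.append(param)
    return [
        {"params": params, "weight_decay": weight_decay},
        {"params": excluded_params, "weight_decay": 0.},
    ]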
Code example #2
        # backpropagation, optimizer update, and per-step LR schedule step
        loss.backward()
        optimizer.step()
        scheduler.step()
        data_dict.update({'lr': scheduler.get_last_lr()})
        local_progress.set_postfix(data_dict)
        logger.update_scalers(data_dict)

    current_loss = data_dict['loss']

    # periodic kNN evaluation on the frozen backbone
    if epoch % knn_interval == 0:
        accuracy = accuracy_monitor(model.backbone,
                                    memory_loader,
                                    test_loader,
                                    'cpu',
                                    hide_progress=True)
        data_dict['accuracy'] = accuracy

    global_progress.set_postfix(data_dict)
    logger.update_scalers(data_dict)

    model_path = os.path.join(
        ckpt_dir, f"{uid}_{datetime.now().strftime('%m%d%H%M%S')}.pth")

    # checkpoint whenever the epoch loss reaches a new minimum
    if min_loss > current_loss:
        min_loss = current_loss

        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict()
        }, model_path)
        # print(f'Model saved at: {model_path}')
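
accuracy_monitor is also not defined in this excerpt. Assuming it is a standard weighted kNN monitor in the style of the SimSiam reference code, a minimal sketch could look like the following; the name knn_monitor and the defaults k=200, t=0.1 are illustrative, and hide_progress is kept only for signature compatibility.

import torch
import torch.nn.functional as F


def knn_monitor(backbone, memory_loader, test_loader, device,
                k=200, t=0.1, hide_progress=True):
    # evaluate the frozen backbone with a weighted kNN classifier
    backbone.eval()
    feature_bank, feature_labels = [], []
    with torch.no_grad():
        # build the feature bank from the labelled memory set
        for data, target in memory_loader:
            feature = F.normalize(backbone(data.to(device)), dim=1)
            feature_bank.append(feature)
            feature_labels.append(target.to(device))
        feature_bank = torch.cat(feature_bank).t().contiguous()  # [D, N]
        feature_labels = torch.cat(feature_labels)                # [N]
        classes = int(feature_labels.max().item()) + 1

        total_num, total_top1 = 0, 0.0
        for data, target in test_loader:
            feature = F.normalize(backbone(data.to(device)), dim=1)  # [B, D]
            sim_matrix = torch.mm(feature, feature_bank)             # [B, N]
            sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
            sim_labels = torch.gather(
                feature_labels.expand(feature.size(0), -1), dim=-1,
                index=sim_indices)
            sim_weight = (sim_weight / t).exp()
            # weighted vote over the k nearest neighbours
            one_hot = torch.zeros(feature.size(0) * k, classes, device=device)
            one_hot.scatter_(-1, sim_labels.view(-1, 1), 1.0)
            pred_scores = torch.sum(
                one_hot.view(feature.size(0), -1, classes) *
                sim_weight.unsqueeze(-1), dim=1)
            pred_labels = pred_scores.argmax(dim=-1)
            total_num += data.size(0)
            total_top1 += (pred_labels == target.to(device)).float().sum().item()
    backbone.train()
    return total_top1 / total_num * 100.0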