def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
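    # MultiCropDataset yields, for each image, nmb_crops[i] random crops at
    # resolution size_crops[i], with areas sampled between min_scale_crops[i]
    # and max_scale_crops[i]; pil_blur enables PIL-based Gaussian blur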
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        pil_blur=args.use_pil_blur,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=sampler,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
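    # per-iteration LR schedule: linear warmup from start_warmup to base_lr
    # over the first warmup_epochs, then cosine decay from base_lr to final_lr
    # over the remaining epochs; train() indexes it with the global iteration
    # (epoch * len(train_loader) + step)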
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model,
                                               optimizer,
                                               opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path,
                              "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue length needs to be divisible by the global batch size
    # (left disabled here; the queue is instead built by init_queue below)
    # args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    # initialize the queue with a pass over the training data
    # (note: this overwrites any queue restored from disk above)
    print('Initializing queue')
    queue = init_queue(train_loader, model, args)
    print('Queue initialization done')

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        # queue shape : (Ncrops, Lqueue, feat) --> (NClass, NCrops, Lqueue, feat)
        # if queue is None:
        #     queue = torch.randn(1000, args.feat_dim).cuda()
        #     queue = nn.functional.normalize(queue, dim=1, p=2)
        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch,
                              lr_schedule, queue, args)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
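
# A minimal sketch of how a SwAV-style feature queue like the one above is
# rolled each step inside train(); `update_queue` and its shapes are
# illustrative assumptions, not the repository's exact code.
import torch

def update_queue(queue: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
    """queue: (queue_length, feat_dim), oldest entries at the back;
    embeddings: (batch_size, feat_dim), L2-normalized projections."""
    bs = embeddings.size(0)
    queue[bs:] = queue[:-bs].clone()    # shift everything back by one batch
    queue[:bs] = embeddings.detach()    # enqueue the newest batch at the front
    return queue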
Example #2
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        return_index=True,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=sampler,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    # base_lr=4.8 wd=1e-6
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    # Using Dist LARC Optimizer
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    # LR Scheduling
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))

    logger.info("Building optimizer done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
    )
    start_epoch = to_restore["epoch"]

    # build the memory bank
    mb_path = os.path.join(args.dump_path, "mb" + str(args.rank) + ".pth")
    if os.path.isfile(mb_path):
        mb_ckp = torch.load(mb_path)
        local_memory_index = mb_ckp["local_memory_index"]
        local_memory_embeddings = mb_ckp["local_memory_embeddings"]
    else:
        local_memory_index, local_memory_embeddings = init_memory(
            train_loader, model)

    cudnn.benchmark = True
    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # train the network
        scores, local_memory_index, local_memory_embeddings = train(
            train_loader,
            model,
            optimizer,
            epoch,
            lr_schedule,
            local_memory_index,
            local_memory_embeddings,
        )
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        torch.save(
            {
                "local_memory_embeddings": local_memory_embeddings,
                "local_memory_index": local_memory_index
            }, mb_path)
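
# `init_memory` above is defined elsewhere in the repository; the sketch below
# shows, with assumed names and shapes, what such a memory-bank initialization
# typically does: one forward pass over the loader, caching each sample's index
# and its embedding for every assignment crop.
import torch

@torch.no_grad()
def init_memory_sketch(loader, model, crops_for_assign=(0, 1), feat_dim=128):
    size = len(loader) * loader.batch_size  # samples per process (drop_last=True)
    local_index = torch.zeros(size, dtype=torch.long).cuda()
    local_embeddings = torch.zeros(len(crops_for_assign), size, feat_dim).cuda()
    start = 0
    for index, inputs in loader:
        bs = index.size(0)
        local_index[start:start + bs] = index.cuda(non_blocking=True)
        for i, crop_id in enumerate(crops_for_assign):
            # model returns (projection, prototype_scores) when prototypes are on
            emb = model(inputs[crop_id].cuda(non_blocking=True))[0]
            local_embeddings[i, start:start + bs] = emb
        start += bs
    return local_index, local_embeddings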
Example #3
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    print(torch.cuda.memory_allocated(), flush=True)
    
    train_paths, train_labs, dev_paths, dev_labs, test_paths, test_labs = get_patches_labels('./sc/arion/work/millej37/ML-project/patches',
                                                                                             './sc/arion/work/millej37/ML-project/swav')
    color_transform = [get_color_distortion(),
                       transforms.GaussianBlur(kernel_size=int(.1 * 224) + 1, sigma=(0.1, 2.0))]
    # ImageNet channel statistics
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    swav_transform = transforms.Compose([
        transforms.ToTensor(),
        # RandomResizedCrop requires a target size; 224 matches the blur
        # kernel sizing above
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Compose(color_transform),
        transforms.Normalize(mean=mean, std=std),
    ])

    # build data
    train_dataset = PatchDataset(train_paths, transform=swav_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    print(torch.cuda.memory_allocated(), flush=True)
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        # with apex syncbn we sync bn per group because it speeds up computation
        # compared to global syncbn
        process_group = apex.parallel.create_syncbn_process_group(args.syncbn_process_group_size)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue length needs to be divisible by the global batch size
    args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
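        # the queue caches projected features from past batches, one slot per
        # assignment crop: (len(crops_for_assign), queue_length // world_size,
        # feat_dim); it enlarges the effective batch for the Sinkhorn step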
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
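
# Inside train(), SwAV turns prototype scores (optionally concatenated with the
# queue) into soft cluster assignments via Sinkhorn-Knopp. A single-process
# sketch of that normalization (the distributed version also all-reduces the
# per-row sums across GPUs):
import torch

@torch.no_grad()
def sinkhorn(scores, epsilon=0.05, n_iters=3):
    """scores: (batch, n_prototypes) logits; returns soft assignments."""
    Q = torch.exp(scores / epsilon).t()   # (n_prototypes, batch)
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True)   # rows: uniform prototype usage
        Q /= K
        Q /= Q.sum(dim=0, keepdim=True)   # columns: one unit of mass per sample
        Q /= B
    return (Q * B).t()                    # each row of the result sums to 1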
Example #4
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    writer = SummaryWriter()

    # build data
    if args.dataset == 'imagenet':
        train_dataset = MultiCropDataset(
            args.data_path,
            args.size_crops,
            args.nmb_crops,
            args.min_scale_crops,
            args.max_scale_crops,
        )
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=sampler,
                                                   batch_size=args.batch_size,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)
    elif args.dataset == 'stl10':
        swav_train_transform = SwAVTrainDataTransform(
            normalize=stl10_normalization(),
            size_crops=args.size_crops,
            nmb_crops=args.nmb_crops,
            min_scale_crops=args.min_scale_crops,
            max_scale_crops=args.max_scale_crops,
            gaussian_blur=args.gaussian_blur,
            jitter_strength=args.jitter_strength)

        datamodule = STL10DataModule(data_dir=args.data_path,
                                     train_dist_sampler=True,
                                     num_workers=args.workers,
                                     batch_size=args.batch_size)

        datamodule.prepare_data()
        datamodule.setup()

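        # pre-train on the mixed split (unlabeled + labeled train images),
        # with the SwAV multi-crop transform attached to the datamodule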
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.train_transforms = swav_train_transform
        train_loader = datamodule.train_dataloader_mixed()

    if args.dataset == 'imagenet':
        logger.info("Building data done with {} images loaded.".format(
            len(train_dataset)))
    elif args.dataset == 'stl10':
        logger.info("Building data done with {} images loaded.".format(
            datamodule.num_unlabeled_samples + datamodule.num_labeled_samples))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )

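    # STL10 images (96x96) are much smaller than ImageNet crops, so replace
    # the initial stride-2 max-pool with a no-op to keep spatial resolution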
    if args.dataset == 'stl10':
        model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)

    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

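    # optionally exclude batch-norm and bias parameters from weight decay,
    # a common practice with LARS/LARC-style large-batch optimizers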
    params = None
    if args.exclude_bn_bias:
        params = exclude_from_wt_decay(model.named_parameters(),
                                       weight_decay=args.wd)

    # build optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(
            params if args.exclude_bn_bias else model.parameters(),
            lr=args.base_lr,
            momentum=0.9,
            weight_decay=args.wd,
        )
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(
            params if args.exclude_bn_bias else model.parameters(),
            lr=args.base_lr,
            weight_decay=args.wd)

    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model,
                                               optimizer,
                                               opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path,
                              "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue length needs to be divisible by the global batch size
    args.queue_length -= args.queue_length % (args.batch_size *
                                              args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch,
                              lr_schedule, queue)
        training_stats.update(scores)
        writer.add_scalar("Loss/train", scores[1], scores[0])

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)

    writer.flush()
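
# Every example above wraps its optimizer in LARC, which rescales each layer's
# learning rate by a trust ratio before the inner SGD/Adam step. A rough sketch
# of that per-parameter ratio (apex's LARC implementation also handles
# clipping and zero norms):
import torch

def larc_trust_ratio(param, trust_coefficient=0.001, weight_decay=0.0, eps=1e-8):
    """eta * ||w|| / (||grad w|| + wd * ||w||), computed after backward()."""
    p_norm = param.detach().norm()
    g_norm = param.grad.detach().norm()
    if p_norm == 0 or g_norm == 0:
        return 1.0
    return (trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)).item()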
Example #5
def main():
    global args
    args = parser.parse_args()
    if args.distributed:
        args.rank, args.world_size, args.gpu_to_work_on = init_distributed_mode()
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss", "acc",
                                            "acc_val")
    writer = SummaryWriter(args.dump_path)

    dataloaders = {}
    num_classes = 10 if args.dataset_type == 'STL10' else 100
    for split in ['train', 'test']:
        dataset = get_custom_dataset(args.dataset_type,
                                     root=args.data_path,
                                     split=split,
                                     download=args.download_dataset,
                                     return_target_word=True)
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset) if args.distributed else None
        dataloaders[split] = DataLoader(dataset,
                                        sampler=sampler,
                                        batch_size=args.batch_size,
                                        num_workers=args.num_workers,
                                        pin_memory=True,
                                        drop_last=True)

    word_embeddings = None
    if args.sim_loss:
        word_embeddings = ViCoWordEmbeddings(root=args.embed_path,
                                             num_classes=num_classes,
                                             vico_mode=args.vico_mode,
                                             one_hot=args.one_hot,
                                             linear_dim=args.linear_dim,
                                             no_hypernym=args.no_hypernym,
                                             no_glove=args.no_glove,
                                             pool_size=None)
        if args.distributed:
            word_embeddings = nn.SyncBatchNorm.convert_sync_batchnorm(
                word_embeddings)
        word_embeddings = word_embeddings.cuda()

    model = resnet_models.__dict__['resnet{}'.format(args.num_layers)](
        small_image=True,
        hidden_mlp=0,
        output_dim=num_classes,
        returned_featmaps=[3, 4, 5],
        multi_cropped_input=False)
    # synchronize batch norm layers
    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda()

    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    lr = args.lr
    params = model.parameters()
    if args.optimizer == 'SGD':
        optimizer = optim.SGD(params,
                              lr=lr,
                              momentum=args.momentum,
                              weight_decay=1e-4)
    elif args.optimizer == 'Adam':
        # assign to `optimizer` (not `opt`) so the LARC wrapper below sees it
        optimizer = optim.Adam(params, lr=lr, weight_decay=1e-4)
    else:
        raise NotImplementedError('optimizer not implemented')

    # objective
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    # optimizer and schedulers
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    # warm up
    warmup_lr_schedule = np.linspace(
        args.start_warmup, args.lr,
        len(dataloaders['train']) * args.warmup_epochs)
    # cosine/step
    iters = np.arange(
        len(dataloaders['train']) * (args.num_epochs - args.warmup_epochs))
    if args.cosine:
        final_lr = args.lr * (args.lr_decay_rate)**3
        cosine_lr_schedule = np.array([final_lr + 0.5 * (args.lr - final_lr) * (1 + \
                            math.cos(math.pi * t / (len(dataloaders['train']) * (args.num_epochs - args.warmup_epochs)))) for t in iters])
        lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    else:
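        # step decay: scale the base LR by lr_decay_rate once per milestone in
        # lr_decay_epochs; (t >= steps).sum() counts milestones already passed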
        steps = np.array([
            int(item.strip()) * len(dataloaders['train'])
            for item in args.lr_decay_epochs.split(',')
        ])
        step_lr_schedule = np.array(
            [args.lr * args.lr_decay_rate**(t >= steps).sum() for t in iters])
        lr_schedule = np.concatenate((warmup_lr_schedule, step_lr_schedule))

    logger.info("Building optimizer done.")

    # wrap models
    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu_to_work_on],
            find_unused_parameters=True,
        )
        if args.sim_loss:
            word_embeddings = nn.parallel.DistributedDataParallel(
                word_embeddings,
                device_ids=[args.gpu_to_work_on],
                find_unused_parameters=True,
            )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0, "val_acc": 0, "best_val_acc": 0}
    restart_from_checkpoint(os.path.join(args.dump_path, "checkpoint.pth.tar"),
                            run_variables=to_restore,
                            state_dict=model,
                            optimizer=optimizer,
                            distributed=args.distributed)

    eval_score = to_restore["val_acc"]
    start_epoch = to_restore["epoch"]
    best_val_acc = to_restore["best_val_acc"]
    for epoch in range(start_epoch, args.num_epochs):

        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        if args.distributed:
            dataloaders['train'].sampler.set_epoch(epoch)

        # train for one epoch
        scores = train_model(model, word_embeddings, dataloaders['train'],
                             optimizer, criterion, epoch, lr_schedule, writer)

        # evaluate if needed
        if epoch % args.val_freq == 0 and args.rank == 0:
            if args.distributed:
                dataloaders['test'].sampler.set_epoch(epoch)
            eval_score = eval_model(model, word_embeddings,
                                    dataloaders['test'], epoch, writer)
            if eval_score > best_val_acc:
                best_val_acc = eval_score

        training_stats.update(scores + (eval_score, ))

        if args.rank == 0:
            # after epoch: save checkpoints
            save_dict = {
                "epoch": epoch + 1,
                "val_acc": eval_score,
                "best_val_acc": best_val_acc,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.num_epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )

    writer.close()