Example #1
    def prepare(self, param_groups):
        self.optimizer = torch.optim.SGD(
            param_groups,
            lr=self.parameters.lr,
            nesterov=self.parameters.nesterov,
            momentum=self.parameters.momentum,
            weight_decay=self.parameters.weight_decay,
        )
        if self.parameters.use_larc:
            try:
                from apex.parallel.LARC import LARC
            except ImportError:
                raise RuntimeError("Apex needed for LARC")
            self.optimizer = LARC(optimizer=self.optimizer, **self.larc_config)
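
This is the standard pattern: build a plain torch.optim optimizer first, then wrap it with apex's LARC, which rescales each layer's update by a trust ratio. A minimal self-contained sketch of the same wrapping (assuming apex is installed; the model and hyperparameters are placeholders, not taken from the original class):

import torch
from apex.parallel.LARC import LARC

# Placeholder model; any nn.Module works.
model = torch.nn.Linear(10, 2).cuda()

# Base optimizer first, then the LARC wrapper on top of it.
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = LARC(optimizer=base_optimizer, trust_coefficient=0.001, clip=False)

# The wrapper proxies the usual optimizer interface.
loss = model(torch.randn(4, 10).cuda()).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
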
Example #2
    def init_pytorch_optimizer(self, model, **kwargs):
        super().init_pytorch_optimizer(model, **kwargs)
        self.optimizer = torch.optim.SGD(
            self.param_groups_override,
            lr=self.parameters.lr,
            nesterov=self.parameters.nesterov,
            momentum=self.parameters.momentum,
            weight_decay=self.parameters.weight_decay,
        )
        if self.parameters.use_larc:
            try:
                from apex.parallel.LARC import LARC
            except ImportError:
                raise RuntimeError("Apex needed for LARC")
            self.optimizer = LARC(optimizer=self.optimizer, **self.larc_config)
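
Example #2 is the same wrapping logic as Example #1; the only difference is that the base SGD optimizer is built from self.param_groups_override, prepared by the parent class's init_pytorch_optimizer hook, instead of caller-supplied parameter groups.
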
Example #3
def getModel(loader, config, args):
    model = getDNN(loader, args).cuda(args.local_rank)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config['learning_rate'],
                                 weight_decay=config['l2_regularization'])
    optimizer = LARC(optimizer)
    if config['mixed_precision']:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    if args.world_size > 1:
        model = DDP(model)
    sys.path.insert(1, os.getcwd())
    try:
        from dnn import loss_function
        loss_fn = loss_function.cuda()
        print('Imported custom loss function')
    except ImportError:
        print('Using MSELoss from pytorch')
        loss_fn = torch.nn.MSELoss().cuda()
    return model, optimizer, loss_fn
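
Note the ordering here and in the mixed-precision examples below: the optimizer is wrapped with LARC first and amp.initialize is called afterwards, so AMP's loss scaling operates on the already-wrapped optimizer.
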
Example #4
    def test_larc_mixed_precision(self):
        for opt_level in ["O0", "O1", "O2", "O3"]:
            model = MyModel(1)

            optimizer = LARC(
                torch.optim.SGD(
                    [{"params": model.parameters(), "lr": 0.25}], momentum=0.125
                )
            )

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0
            )

            optimizer.zero_grad()
            loss = model(self.x)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
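
The test above references MyModel and self.x, which are defined elsewhere in the test module. A minimal sketch of fixtures that would make the snippet runnable (names, shapes, and initialization here are assumptions, not the original definitions):

import unittest

import torch
import torch.nn as nn
from apex import amp
from apex.parallel.LARC import LARC

class MyModel(nn.Module):
    # Assumed stand-in: a single trainable vector; forward() returns a
    # scalar so the test can use the output directly as a loss.
    def __init__(self, unique):
        super().__init__()
        self.weight0 = nn.Parameter(
            unique * torch.ones(2, dtype=torch.float32, device="cuda"))

    def forward(self, x):
        return (x * self.weight0).sum()

class TestLARC(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones(2, dtype=torch.float32, device="cuda")
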
Example #5
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        pil_blur=args.use_pil_blur,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=sampler,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([
        args.final_lr + 0.5 * (args.base_lr - args.final_lr)
        * (1 + math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs))))
        for t in iters
    ])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model,
                                               optimizer,
                                               opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path,
                              "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    # args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    ## initialize queue
    print('start initialize queue')
    queue = init_queue(train_loader, model, args)
    print('queue initialize finish')

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        # queue shape : (Ncrops, Lqueue, feat) --> (NClass, NCrops, Lqueue, feat)
        # if queue is None:
        #     queue = torch.randn(1000, args.feat_dim).cuda()
        #     queue = nn.functional.normalize(queue, dim=1, p=2)
        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch,
                              lr_schedule, queue, args)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
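
Examples #5, #6, #8, #9, and #10 all precompute a per-iteration learning-rate table: a linear warmup followed by cosine annealing down to a final value. The same computation as a standalone helper (the function name and the application loop shown in comments are illustrative, not from the original scripts):

import math

import numpy as np

def build_lr_schedule(base_lr, final_lr, start_warmup,
                      warmup_epochs, epochs, iters_per_epoch):
    # Linear warmup from start_warmup to base_lr.
    warmup = np.linspace(start_warmup, base_lr, iters_per_epoch * warmup_epochs)
    # Cosine annealing from base_lr down to final_lr over the remaining epochs.
    main_iters = np.arange(iters_per_epoch * (epochs - warmup_epochs))
    cosine = np.array([
        final_lr + 0.5 * (base_lr - final_lr)
        * (1 + math.cos(math.pi * t / len(main_iters)))
        for t in main_iters
    ])
    return np.concatenate((warmup, cosine))

# Inside train(), the table is typically indexed per iteration:
#   iteration = epoch * iters_per_epoch + it
#   for param_group in optimizer.param_groups:
#       param_group["lr"] = lr_schedule[iteration]
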
Example #6
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        return_index=True,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=sampler,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    # base_lr=4.8 wd=1e-6
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    # Using Dist LARC Optimizer
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    # LR Scheduling
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([
        args.final_lr + 0.5 * (args.base_lr - args.final_lr)
        * (1 + math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs))))
        for t in iters
    ])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))

    logger.info("Building optimizer done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
    )
    start_epoch = to_restore["epoch"]

    # build the memory bank
    mb_path = os.path.join(args.dump_path, "mb" + str(args.rank) + ".pth")
    if os.path.isfile(mb_path):
        mb_ckp = torch.load(mb_path)
        local_memory_index = mb_ckp["local_memory_index"]
        local_memory_embeddings = mb_ckp["local_memory_embeddings"]
    else:
        local_memory_index, local_memory_embeddings = init_memory(
            train_loader, model)

    cudnn.benchmark = True
    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # train the network
        scores, local_memory_index, local_memory_embeddings = train(
            train_loader,
            model,
            optimizer,
            epoch,
            lr_schedule,
            local_memory_index,
            local_memory_embeddings,
        )
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        torch.save(
            {
                "local_memory_embeddings": local_memory_embeddings,
                "local_memory_index": local_memory_index
            }, mb_path)
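
Each rank saves and restores its own memory bank file (mb<rank>.pth), so the local embeddings and indices survive a restart without being rebuilt by init_memory.
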
Example #7
def learning(
    cfg: OmegaConf,
    training_data_loader: torch.utils.data.DataLoader,
    validation_data_loader: torch.utils.data.DataLoader,
    model: SupervisedModel,
) -> None:
    """
    Learning function including evaluation

    :param cfg: Hydra's config instance
    :param training_data_loader: Training data loader
    :param validation_data_loader: Validation data loader
    :param model: Model
    :return: None
    """

    local_rank = cfg["distributed"]["local_rank"]
    num_gpus = cfg["distributed"]["world_size"]
    epochs = cfg["parameter"]["epochs"]
    num_training_samples = len(training_data_loader.dataset.data)
    steps_per_epoch = int(
        num_training_samples /
        (cfg["experiment"]["batches"] * num_gpus))  # drop_last=True discards the final partial batch
    total_steps = cfg["parameter"]["epochs"] * steps_per_epoch
    warmup_steps = cfg["parameter"]["warmup_epochs"] * steps_per_epoch
    current_step = 0

    best_metric = np.finfo(np.float64).max

    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=calculate_initial_lr(cfg),
                                momentum=cfg["parameter"]["momentum"],
                                nesterov=False,
                                weight_decay=cfg["experiment"]["decay"])

    # https://github.com/google-research/simclr/blob/master/lars_optimizer.py#L26
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    cos_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer.optim,
        T_max=total_steps - warmup_steps,
    )

    for epoch in range(1, epochs + 1):
        # training
        model.train()
        training_data_loader.sampler.set_epoch(epoch)

        for data, targets in training_data_loader:
            # adjust learning rate by applying linear warmup
            if current_step <= warmup_steps:
                lr = calculate_lr(cfg, warmup_steps, current_step)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr

            optimizer.zero_grad()
            data, targets = data.to(local_rank), targets.to(local_rank)
            unnormalized_features = model(data)
            loss = torch.nn.functional.cross_entropy(unnormalized_features,
                                                     targets)
            loss.backward()
            optimizer.step()

            # adjust learning rate by applying cosine annealing
            if current_step > warmup_steps:
                cos_lr_scheduler.step()

            current_step += 1

        if local_rank == 0:
            logger_line = "Epoch:{}/{} progress:{:.3f} loss:{:.3f}, lr:{:.7f}".format(
                epoch, epochs, epoch / epochs, loss.item(),
                optimizer.param_groups[0]["lr"])

        # run validation after each training epoch
        sum_val_loss, num_val_corrects = validation(validation_data_loader,
                                                    model, local_rank)

        torch.distributed.barrier()
        torch.distributed.reduce(sum_val_loss, dst=0)
        torch.distributed.reduce(num_val_corrects, dst=0)

        num_val_samples = len(validation_data_loader.dataset)

        # logging and save checkpoint
        if local_rank == 0:

            validation_loss = sum_val_loss.item() / num_val_samples
            validation_acc = num_val_corrects.item() / num_val_samples

            logging.info(logger_line +
                         " val loss:{:.3f}, val acc:{:.2f}%".format(
                             validation_loss, validation_acc * 100.))

            if cfg["parameter"]["metric"] == "loss":
                metric = validation_loss
            else:
                metric = 1. - validation_acc

            if metric <= best_metric:
                best_metric = metric  # keep only the checkpoint with the best validation metric
                if "save_fname" in locals():
                    if os.path.exists(save_fname):
                        os.remove(save_fname)

                save_fname = "epoch={}-{}".format(
                    epoch, cfg["experiment"]["output_model_name"])
                torch.save(model.state_dict(), save_fname)
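
Note that CosineAnnealingLR is constructed over optimizer.optim, the inner SGD instance: the LARC wrapper is not a torch.optim.Optimizer subclass, and PyTorch's LR schedulers type-check the optimizer they are given.
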
Example #8
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    writer = SummaryWriter()

    # build data
    if args.dataset == 'imagenet':
        train_dataset = MultiCropDataset(
            args.data_path,
            args.size_crops,
            args.nmb_crops,
            args.min_scale_crops,
            args.max_scale_crops,
        )
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=sampler,
                                                   batch_size=args.batch_size,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)
    elif args.dataset == 'stl10':
        swav_train_transform = SwAVTrainDataTransform(
            normalize=stl10_normalization(),
            size_crops=args.size_crops,
            nmb_crops=args.nmb_crops,
            min_scale_crops=args.min_scale_crops,
            max_scale_crops=args.max_scale_crops,
            gaussian_blur=args.gaussian_blur,
            jitter_strength=args.jitter_strength)

        datamodule = STL10DataModule(data_dir=args.data_path,
                                     train_dist_sampler=True,
                                     num_workers=args.workers,
                                     batch_size=args.batch_size)

        datamodule.prepare_data()
        datamodule.setup()

        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.train_transforms = swav_train_transform
        train_loader = datamodule.train_dataloader_mixed()

    if args.dataset == 'imagenet':
        logger.info("Building data done with {} images loaded.".format(
            len(train_dataset)))
    elif args.dataset == 'stl10':
        logger.info("Building data done with {} images loaded.".format(
            datamodule.num_unlabeled_samples + datamodule.num_labeled_samples))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )

    if args.dataset == 'stl10':
        model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)

    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    params = None
    if args.exclude_bn_bias:
        params = exclude_from_wt_decay(model.named_parameters(),
                                       weight_decay=args.wd)

    # build optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(
            params if args.exclude_bn_bias else model.parameters(),
            lr=args.base_lr,
            momentum=0.9,
            weight_decay=args.wd,
        )
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(
            params if args.exclude_bn_bias else model.parameters(),
            lr=args.base_lr,
            weight_decay=args.wd)

    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([
        args.final_lr + 0.5 * (args.base_lr - args.final_lr)
        * (1 + math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs))))
        for t in iters
    ])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model,
                                               optimizer,
                                               opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path,
                              "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    args.queue_length -= args.queue_length % (args.batch_size *
                                              args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch,
                              lr_schedule, queue)
        training_stats.update(scores)
        writer.add_scalar("Loss/train", scores[1], scores[0])

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)

    writer.flush()
Example #9
                                          drop_last=True)

# Create Model
net = get_model()
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
net = net.cuda()

# Optimizer
#optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=4.8,
    momentum=0.9,
    weight_decay=1e-6,
)
optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
warmup_lr_schedule = np.linspace(0, 4.8, len(train_loader) * 10)
iters = np.arange(len(train_loader) * (EPOCHS - 10))
cosine_lr_schedule = np.array([
    0 + 0.5 * (4.8 - 0)
    * (1 + math.cos(math.pi * t / (len(train_loader) * (EPOCHS - 10))))
    for t in iters
])
lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
net, optimizer = apex.amp.initialize(net, optimizer, opt_level="O1")

# Data Parallel Model
net = torch.nn.DataParallel(net)

# Criterion
#criterion = nn.CrossEntropyLoss()

# misc
cudnn.benchmark = True
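
One caveat with this standalone script: nn.SyncBatchNorm layers are only supported under DistributedDataParallel with an initialized process group, so combining convert_sync_batchnorm with plain torch.nn.DataParallel as above will typically fail at forward time in training mode.
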
Example #10
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    print(torch.cuda.memory_allocated(), flush=True)
    
    train_paths, train_labs, dev_paths, dev_labs, test_paths, test_labs = get_patches_labels('./sc/arion/work/millej37/ML-project/patches',
                                                                                             './sc/arion/work/millej37/ML-project/swav')
    color_transform = [get_color_distortion(),
                       transforms.GaussianBlur(kernel_size=int(.1 * 224) + 1,
                                               sigma=(0.1, 2.0))]
    mean = [0.485, 0.456, 0.406]
    std = [0.228, 0.224, 0.225]
    swav_transform = transforms.Compose([
                                         transforms.ToTensor(),
                                         # size is required but missing in the original;
                                         # 224 assumed to match the blur kernel int(.1*224)+1
                                         transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(p=0.5),
                                         transforms.Compose(color_transform),
                                         transforms.Normalize(mean=mean, std=std)
                                         ])

    # build data
    train_dataset = PatchDataset(train_paths, transform=swav_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    print(torch.cuda.memory_allocated(), flush=True)
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        # with apex syncbn we sync bn per group because it speeds up computation
        # compared to global syncbn
        process_group = apex.parallel.create_syncbn_process_group(args.syncbn_process_group_size)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([
        args.final_lr + 0.5 * (args.base_lr - args.final_lr)
        * (1 + math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs))))
        for t in iters
    ])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
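
Example #10's augmentation pipeline calls get_color_distortion, which in the SwAV reference code is a SimCLR-style color jitter plus random grayscale. A sketch from memory (verify the exact constants against the SwAV repository):

from torchvision import transforms

def get_color_distortion(s=1.0):
    # s is the strength of the color distortion.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    return transforms.Compose([rnd_color_jitter, rnd_gray])
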
Example #11
def train(
    cfg: OmegaConf,
    training_data_loader: torch.utils.data.DataLoader,
    model: ContrastiveModel,
) -> None:
    """
    Training function
    :param cfg: Hydra's config instance
    :param training_data_loader: Training data loader for contrastive learning
    :param model: Contrastive model based on resnet
    :return: None
    """
    local_rank = cfg["distributed"]["local_rank"]
    num_gpus = cfg["distributed"]["world_size"]
    epochs = cfg["parameter"]["epochs"]
    num_training_samples = len(training_data_loader.dataset.data)
    steps_per_epoch = int(
        num_training_samples /
        (cfg["experiment"]["batches"] * num_gpus))  # drop_last=True discards the final partial batch
    total_steps = cfg["parameter"]["epochs"] * steps_per_epoch
    warmup_steps = cfg["parameter"]["warmup_epochs"] * steps_per_epoch
    current_step = 0

    model.train()
    nt_cross_entropy_loss = NT_Xent(
        temperature=cfg["parameter"]["temperature"], device=local_rank)

    optimizer = torch.optim.SGD(params=exclude_from_wt_decay(
        model.named_parameters(), weight_decay=cfg["experiment"]["decay"]),
                                lr=calculate_initial_lr(cfg),
                                momentum=cfg["parameter"]["momentum"],
                                nesterov=False,
                                weight_decay=0.)

    # https://github.com/google-research/simclr/blob/master/lars_optimizer.py#L26
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    cos_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer.optim,
        T_max=total_steps - warmup_steps,
    )

    for epoch in range(1, epochs + 1):
        training_data_loader.sampler.set_epoch(epoch)

        for (view0, view1), _ in training_data_loader:
            # adjust learning rate by applying linear warmup
            if current_step <= warmup_steps:
                lr = calculate_lr(cfg, warmup_steps, current_step)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr

            optimizer.zero_grad()
            z0 = model(view0.to(local_rank))
            z1 = model(view1.to(local_rank))
            loss = nt_cross_entropy_loss(z0, z1)
            loss.backward()
            optimizer.step()

            # adjust learning rate by applying cosine annealing
            if current_step > warmup_steps:
                cos_lr_scheduler.step()

            current_step += 1

        if local_rank == 0:
            logging.info(
                "Epoch:{}/{} progress:{:.3f} loss:{:.3f}, lr:{:.7f}".format(
                    epoch, epochs, epoch / epochs, loss.item(),
                    optimizer.param_groups[0]["lr"]))

            if epoch % cfg["experiment"]["save_model_epoch"] == 0:
                save_fname = "epoch={}-{}".format(
                    epoch, cfg["experiment"]["output_model_name"])
                torch.save(model.state_dict(), save_fname)
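
This script passes exclude_from_wt_decay(...) to SGD and sets the optimizer-level weight_decay to 0, so decay is applied only through the parameter groups and normalization/bias parameters are exempt. A hedged sketch of such a helper (the exact exclusion list in the original may differ):

def exclude_from_wt_decay(named_params, weight_decay, skip_list=("bias", "bn")):
    params, excluded_params = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue
        if any(layer_name in name for layer_name in skip_list):
            excluded_params.append(param)
        else:
            params.append(param)
    # Two groups: decayed weights, and exempt (bias/normalization) parameters.
    return [
        {"params": params, "weight_decay": weight_decay},
        {"params": excluded_params, "weight_decay": 0.0},
    ]
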
Example #12
def train300_mlperf_coco(args):
    from coco import COCO

    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    local_seed = set_seeds(args)
    # start timing here
    ssd_print(key=mlperf_log.RUN_START)
    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    input_size = 300
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)
    ssd_print(key=mlperf_log.INPUT_SIZE, value=input_size)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    cocoGt = COCO(annotation_file=val_annotate)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)

    if args.distributed:
        val_sampler = GeneralDistributedSampler(val_coco, pad=False)
    else:
        val_sampler = None

    train_pipe = COCOPipeline(args.batch_size,
                              args.local_rank,
                              train_coco_root,
                              train_annotate,
                              N_gpu,
                              num_threads=args.num_workers,
                              output_fp16=args.use_fp16,
                              output_nhwc=args.nhwc,
                              pad_output=args.pad_input,
                              seed=local_seed - 2**31)
    print_message(args.local_rank,
                  "time_check a: {secs:.9f}".format(secs=time.time()))
    train_pipe.build()
    print_message(args.local_rank,
                  "time_check b: {secs:.9f}".format(secs=time.time()))
    test_run = train_pipe.run()
    train_loader = DALICOCOIterator(train_pipe, 118287 / N_gpu)

    val_dataloader = DataLoader(
        val_coco,
        batch_size=args.eval_batch_size,
        shuffle=False,  # Note: distributed sampler is shuffled :(
        sampler=val_sampler,
        num_workers=args.num_workers)
    ssd_print(key=mlperf_log.INPUT_ORDER)
    ssd_print(key=mlperf_log.INPUT_BATCH_SIZE, value=args.batch_size)
    # Build the model
    ssd300 = SSD300(val_coco.labelnum,
                    backbone=args.backbone,
                    use_nhwc=args.nhwc,
                    pad_input=args.pad_input)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)

    ssd300.train()
    ssd300.cuda()
    loss_func = Loss(dboxes)
    loss_func.cuda()

    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    if args.use_fp16:
        ssd300 = network_to_half(ssd300)

    # Parallelize.  Need to do this after network_to_half.
    if args.distributed:
        if args.delay_allreduce:
            print_message(args.local_rank,
                          "Delaying allreduces to the end of backward()")
        ssd300 = DDP(ssd300,
                     delay_allreduce=args.delay_allreduce,
                     retain_allreduce_buffers=args.use_fp16)

    # Create optimizer.  This must also be done after network_to_half.
    global_batch_size = (N_gpu * args.batch_size)

    # mlperf only allows base_lr scaled by an integer
    base_lr = 1e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(
        1, round(requested_lr_multiplier * global_batch_size / 32))

    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = 5e-4
    static_loss_scale = 128.
    if args.use_fp16:
        if args.distributed and not args.delay_allreduce:
            # We can't create the flat master params yet, because we need to
            # imitate the flattened bucket structure that DDP produces.
            optimizer_created = False
        else:
            model_buckets = [
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.HalfTensor"
                ],
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.FloatTensor"
                ]
            ]
            flat_master_buckets = create_flat_master(model_buckets)
            optim = torch.optim.SGD(flat_master_buckets,
                                    lr=current_lr,
                                    momentum=current_momentum,
                                    weight_decay=current_weight_decay)
            optimizer_created = True
    else:
        optim = torch.optim.SGD(ssd300.parameters(),
                                lr=current_lr,
                                momentum=current_momentum,
                                weight_decay=current_weight_decay)
        optimizer_created = True

    # Add LARC if desired
    if args.use_larc:
        optim = LARC(optim)

    ssd_print(key=mlperf_log.OPT_NAME, value="SGD")
    ssd_print(key=mlperf_log.OPT_LR, value=current_lr)
    ssd_print(key=mlperf_log.OPT_MOMENTUM, value=current_momentum)
    ssd_print(key=mlperf_log.OPT_WEIGHT_DECAY, value=current_weight_decay)
    if args.warmup is not None:
        ssd_print(key=mlperf_log.OPT_LR_WARMUP_STEPS, value=args.warmup)

    # Model is completely finished -- need to create separate copies, preserve parameters across
    # them, and jit
    ssd300_eval = SSD300(val_coco.labelnum,
                         backbone=args.backbone,
                         use_nhwc=args.nhwc,
                         pad_input=args.pad_input).cuda()
    if args.use_fp16:
        ssd300_eval = network_to_half(ssd300_eval)

    # Get the existing state from the train model
    # * if we use distributed, then we want .module
    train_model = ssd300.module if args.distributed else ssd300

    ssd300_eval.load_state_dict(train_model.state_dict())

    ssd300_eval.eval()

    if args.jit:
        input_c = 4 if args.pad_input else 3
        example_shape = [
            args.batch_size, 300, 300, input_c
        ] if args.nhwc else [args.batch_size, input_c, 300, 300]
        example_input = torch.randn(*example_shape).cuda()
        if args.use_fp16:
            example_input = example_input.half()
        # DDP has some Python-side control flow.  If we JIT the entire DDP-wrapped module,
        # the resulting ScriptModule will elide this control flow, resulting in allreduce
        # hooks not being called.  If we're running distributed, we need to extract and JIT
        # the wrapped .module.
        # Replacing a DDP-ed ssd300 with a script_module might also cause the AccumulateGrad hooks
        # to go out of scope, and therefore silently disappear.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit, example_input)
        else:
            ssd300 = torch.jit.trace(module_to_jit, example_input)

    print_message(args.local_rank, "epoch", "nbatch", "loss")
    eval_points = np.array(args.evaluation) * 32 / global_batch_size
    eval_points = list(map(int, list(eval_points)))

    iter_num = args.iteration
    avg_loss = 0.0
    inv_map = {v: k for k, v in val_coco.label_map.items()}

    start_elapsed_time = time.time()
    last_printed_iter = args.iteration
    num_elapsed_samples = 0

    # Generate normalization tensors
    mean, std = generate_mean_std(args)

    def step_maybe_fp16_maybe_distributed(optim):
        if args.use_fp16:
            if args.distributed:
                for flat_master, allreduce_buffer in zip(
                        flat_master_buckets, ssd300.allreduce_buffers):
                    if allreduce_buffer is None:
                        raise RuntimeError("allreduce_buffer is None")
                    flat_master.grad = allreduce_buffer.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
            else:
                for flat_master, model_bucket in zip(flat_master_buckets,
                                                     model_buckets):
                    flat_grad = apex_C.flatten(
                        [m.grad.data for m in model_bucket])
                    flat_master.grad = flat_grad.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
        optim.step()
        if args.use_fp16:
            for model_bucket, flat_master in zip(model_buckets,
                                                 flat_master_buckets):
                for model, master in zip(
                        model_bucket,
                        apex_C.unflatten(flat_master.data, model_bucket)):
                    model.data.copy_(master.data)

    ssd_print(key=mlperf_log.TRAIN_LOOP)
    for epoch in range(args.epochs):
        ssd_print(key=mlperf_log.TRAIN_EPOCH, value=epoch)
        for p in ssd300.parameters():
            p.grad = None
        for i, data in enumerate(train_loader):
            img = data[0][0][0]
            bbox = data[0][1][0]
            label = data[0][2][0]
            label = label.type(torch.cuda.LongTensor)
            bbox_offsets = data[0][3][0]

            # handle random flipping outside of DALI for now
            bbox_offsets = bbox_offsets.cuda()
            img, bbox = C.random_horiz_flip(img, bbox, bbox_offsets, 0.5,
                                            args.nhwc)
            img.sub_(mean).div_(std)

            if args.profile is not None and iter_num == args.profile:
                return
            if args.warmup is not None and optimizer_created:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr,
                          args)
            if iter_num == ((args.decay1 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #1")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr
                ssd_print(key=mlperf_log.OPT_LR, value=current_lr)

            if iter_num == ((args.decay2 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #2")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr
                ssd_print(key=mlperf_log.OPT_LR, value=current_lr)

            if use_cuda:
                img = img.cuda()
                # NHWC direct from DALI now if necessary
                bbox = bbox.cuda()
                label = label.cuda()
                bbox_offsets = bbox_offsets.cuda()

            # Now run the batched box encoder
            N = img.shape[0]
            if bbox_offsets[-1].item() == 0:
                print("No labels in batch")
                continue
            bbox, label = C.box_encoder(N, bbox, bbox_offsets, label,
                                        encoder.dboxes.cuda(), 0.5)

            # output is [N*8732, 4] and [N*8732]; need [N, 8732, 4] and [N, 8732], respectively
            M = bbox.shape[0] // N
            bbox = bbox.view(N, M, 4)
            label = label.view(N, M)
            # print(img.shape, bbox.shape, label.shape)

            ploc, plabel = ssd300(img)
            ploc, plabel = ploc.float(), plabel.float()

            trans_bbox = bbox.transpose(1, 2).contiguous().cuda()
            label = label.cuda()
            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                           Variable(label, requires_grad=False)
            loss = loss_func(ploc, plabel, gloc, glabel)

            if not np.isinf(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()

            num_elapsed_samples += N
            if args.local_rank == 0 and iter_num % args.print_interval == 0:
                end_elapsed_time = time.time()
                elapsed_time = end_elapsed_time - start_elapsed_time

                avg_samples_per_sec = num_elapsed_samples * N_gpu / elapsed_time

                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"\
                            .format(iter_num, loss.item(), avg_loss, avg_samples_per_sec), end="\n")

                last_printed_iter = iter_num
                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            # loss scaling
            if args.use_fp16:
                loss = loss * static_loss_scale
            loss.backward()

            if not optimizer_created:
                # Imitate the model bucket structure created by DDP.
                # These will already be split by type (float or half).
                model_buckets = []
                for bucket in ssd300.active_i_buckets:
                    model_buckets.append([])
                    for active_i in bucket:
                        model_buckets[-1].append(
                            ssd300.active_params[active_i])
                flat_master_buckets = create_flat_master(model_buckets)
                optim = torch.optim.SGD(flat_master_buckets,
                                        lr=current_lr,
                                        momentum=current_momentum,
                                        weight_decay=current_weight_decay)
                optimizer_created = True
                # Skip this first iteration because flattened allreduce buffers are not yet created.
                # step_maybe_fp16_maybe_distributed(optim)
            else:
                step_maybe_fp16_maybe_distributed(optim)

            # Likely a decent skew here, let's take this opportunity to set the gradients to None.
            # After DALI integration, playing with the placement of this is worth trying.
            for p in ssd300.parameters():
                p.grad = None

            if iter_num in eval_points:
                if args.local_rank == 0:
                    if not args.no_save:
                        print("saving model...")
                        torch.save(
                            {
                                "model": ssd300.state_dict(),
                                "label_map": val_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))

                # Get the existing state from the train model
                # * if we use distributed, then we want .module
                train_model = ssd300.module if args.distributed else ssd300

                ssd300_eval.load_state_dict(train_model.state_dict())
                if coco_eval(
                        ssd300_eval,
                        val_dataloader,
                        cocoGt,
                        encoder,
                        inv_map,
                        args.threshold,
                        epoch,
                        iter_num,
                        args.eval_batch_size,
                        use_fp16=args.use_fp16,
                        local_rank=args.local_rank if args.distributed else -1,
                        N_gpu=N_gpu,
                        use_nhwc=args.nhwc,
                        pad_input=args.pad_input):
                    return True

            iter_num += 1

        train_loader.reset()
    return False
def main():
    global args
    args = parser.parse_args()
    if args.distributed:
        args.rank, args.world_size, args.gpu_to_work_on = init_distributed_mode()
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss", "acc",
                                            "acc_val")
    writer = SummaryWriter(args.dump_path)

    dataloaders = {}
    num_classes = 10 if args.dataset_type == 'STL10' else 100
    for split in ['train', 'test']:
        dataset = get_custom_dataset(args.dataset_type,
                                     root=args.data_path,
                                     split=split,
                                     download=args.download_dataset,
                                     return_target_word=True)
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset) if args.distributed else None
        dataloaders[split] = DataLoader(dataset,
                                        sampler=sampler,
                                        batch_size=args.batch_size,
                                        num_workers=args.num_workers,
                                        pin_memory=True,
                                        drop_last=True)

    word_embeddings = None
    if args.sim_loss:
        word_embeddings = ViCoWordEmbeddings(root=args.embed_path,
                                             num_classes=num_classes,
                                             vico_mode=args.vico_mode,
                                             one_hot=args.one_hot,
                                             linear_dim=args.linear_dim,
                                             no_hypernym=args.no_hypernym,
                                             no_glove=args.no_glove,
                                             pool_size=None)
        if args.distributed:
            word_embeddings = nn.SyncBatchNorm.convert_sync_batchnorm(
                word_embeddings)
        word_embeddings = word_embeddings.cuda()

    model = resnet_models.__dict__['resnet{}'.format(args.num_layers)](
        small_image=True,
        hidden_mlp=0,
        output_dim=num_classes,
        returned_featmaps=[3, 4, 5],
        multi_cropped_input=False)
    # synchronize batch norm layers
    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda()

    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    lr = args.lr
    params = model.parameters()
    if args.optimizer == 'SGD':
        optimizer = optim.SGD(params,
                              lr=lr,
                              momentum=args.momentum,
                              weight_decay=1e-4)
    elif args.optimizer == 'Adam':
        optimizer = optim.Adam(params, lr=lr, weight_decay=1e-4)
    else:
        raise NotImplementedError('optimizer not implemented')

    # objective
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    # optimizer and schedulers
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    # warm up
    warmup_lr_schedule = np.linspace(
        args.start_warmup, args.lr,
        len(dataloaders['train']) * args.warmup_epochs)
    # cosine/step
    iters = np.arange(
        len(dataloaders['train']) * (args.num_epochs - args.warmup_epochs))
    if args.cosine:
        final_lr = args.lr * (args.lr_decay_rate)**3
        cosine_lr_schedule = np.array([
            final_lr + 0.5 * (args.lr - final_lr)
            * (1 + math.cos(math.pi * t / (len(dataloaders['train']) * (args.num_epochs - args.warmup_epochs))))
            for t in iters
        ])
        lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    else:
        steps = np.array([
            int(item.strip()) * len(dataloaders['train'])
            for item in args.lr_decay_epochs.split(',')
        ])
        step_lr_schedule = np.array(
            [args.lr * args.lr_decay_rate**(t >= steps).sum() for t in iters])
        lr_schedule = np.concatenate((warmup_lr_schedule, step_lr_schedule))

    logger.info("Building optimizer done.")

    # wrap models
    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu_to_work_on],
            find_unused_parameters=True,
        )
        if args.sim_loss:
            word_embeddings = nn.parallel.DistributedDataParallel(
                word_embeddings,
                device_ids=[args.gpu_to_work_on],
                find_unused_parameters=True,
            )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0, "val_acc": 0, "best_val_acc": 0}
    restart_from_checkpoint(os.path.join(args.dump_path, "checkpoint.pth.tar"),
                            run_variables=to_restore,
                            state_dict=model,
                            optimizer=optimizer,
                            distributed=args.distributed)

    eval_score = to_restore["val_acc"]
    start_epoch = to_restore["epoch"]
    best_val_acc = to_restore["best_val_acc"]
    for epoch in range(start_epoch, args.num_epochs):

        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        if args.distributed:
            dataloaders['train'].sampler.set_epoch(epoch)

        # train for one epoch
        scores = train_model(model, word_embeddings, dataloaders['train'],
                             optimizer, criterion, epoch, lr_schedule, writer)

        # evaluate if needed
        if epoch % args.val_freq == 0 and args.rank == 0:
            if args.distributed:
                dataloaders['test'].sampler.set_epoch(epoch)
            eval_score = eval_model(model, word_embeddings,
                                    dataloaders['test'], epoch, writer)
            if eval_score > best_val_acc:
                best_val_acc = eval_score

        training_stats.update(scores + (eval_score, ))

        if args.rank == 0:
            # after epoch: save checkpoints
            save_dict = {
                "epoch": epoch + 1,
                "val_acc": eval_score,
                "best_val_acc": best_val_acc,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.num_epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )

    writer.close()
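
Example #12 adjusts the learning rate through lr_warmup(optim, args.warmup, iter_num, epoch, current_lr, args) during early iterations. A minimal linear-warmup implementation consistent with that call site (illustrative only; the actual MLPerf helper may use a different rule, e.g. one involving a warmup_factor):

def lr_warmup(optim, warmup_iters, iter_num, epoch, base_lr, args):
    # Linearly ramp the LR from near zero up to base_lr over warmup_iters iterations.
    if iter_num < warmup_iters:
        new_lr = base_lr * (iter_num + 1) / warmup_iters
        for param_group in optim.param_groups:
            param_group["lr"] = new_lr
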