Code example #1
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        pil_blur=args.use_pil_blur,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=sampler,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
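    # per-iteration LR schedule: linear warmup from start_warmup to base_lr, then cosine decay to final_lr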
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model,
                                               optimizer,
                                               opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path,
                              "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
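    # note: any queue reloaded here is overwritten below by init_queue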
    # the queue needs to be divisible by the batch size
    # args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    # initialize the queue
    print('start initializing queue')
    queue = init_queue(train_loader, model, args)
    print('queue initialization finished')

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        # queue shape : (Ncrops, Lqueue, feat) --> (NClass, NCrops, Lqueue, feat)
        # if queue is None:
        #     queue = torch.randn(1000, args.feat_dim).cuda()
        #     queue = nn.functional.normalize(queue, dim=1, p=2)
        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch,
                              lr_schedule, queue, args)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
Code example #2
def main(params):

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment / load data
    logger = initialize_exp(params)

    # Seed
    torch.manual_seed(params.seed)
    torch.cuda.manual_seed_all(params.seed)

    # initialize SLURM signal handler for time limit / pre-emption
    if params.is_slurm_job:
        init_signal_handler()

    # data loaders / samplers
    populate_dataset(params)
    train_data_loader, train_sampler, _ = get_data_loader(
        img_size=params.img_size,
        crop_size=params.crop_size,
        shuffle=True,
        batch_size=params.batch_size,
        num_classes=params.num_classes,
        nb_workers=params.nb_workers,
        distributed_sampler=params.multi_gpu,
        dataset=params.dataset,
        data_path=params.train_path,
        transform=params.train_transform,
        split='valid' if params.debug_train else 'train',
        seed=params.seed)

    valid_data_loader, _, _ = get_data_loader(img_size=params.img_size,
                                              crop_size=params.crop_size,
                                              shuffle=False,
                                              batch_size=params.batch_size,
                                              num_classes=params.num_classes,
                                              nb_workers=params.nb_workers,
                                              distributed_sampler=False,
                                              dataset=params.dataset,
                                              transform='center',
                                              split='valid',
                                              seed=params.seed)

    # build model / cuda
    logger.info("Building %s model ..." % params.architecture)
    ftmodel = build_model(params)
    ftmodel.fc = nn.Sequential()
    ftmodel.eval().cuda()

    linearmodel = nn.Linear(EMBEDDING_SIZE[params.architecture],
                            params.num_classes).cuda()

    if params.from_ckpt != "":
        ckpt = torch.load(params.from_ckpt)
        state_dict = {
            k.replace("module.", ""): v
            for k, v in ckpt['model'].items()
        }

        del state_dict["fc.weight"]
        if "fc.bias" in state_dict:
            del state_dict["fc.bias"]
        missing_keys, unexpected_keys = ftmodel.load_state_dict(state_dict,
                                                                strict=False)
        print("Missing keys: ", missing_keys)
        print("Unexpected keys: ", unexpected_keys)

    # distributed  # TODO: check this https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main.py#L142
    if params.multi_gpu:
        logger.info("Using nn.parallel.DistributedDataParallel ...")
        linearmodel = nn.parallel.DistributedDataParallel(
            linearmodel,
            device_ids=[params.local_rank],
            output_device=params.local_rank,
            broadcast_buffers=True)

    # build trainer / reload potential checkpoints / build evaluator
    trainer = Trainer(model=linearmodel, params=params, ftmodel=ftmodel)
    trainer.reload_checkpoint()
    evaluator = Evaluator(trainer, params)

    # evaluation
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer,
                                         evals=['classif'],
                                         data_loader=valid_data_loader)

        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # training
    for epoch in range(trainer.epoch, params.epochs):

        # update epoch / sampler / learning rate
        trainer.epoch = epoch
        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)
        if params.multi_gpu:
            train_sampler.set_epoch(epoch)

        # update learning rate
        trainer.update_learning_rate()

        # train
        for i, (images, targets) in enumerate(train_data_loader):
            trainer.classif_step(images, targets)
            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate classification accuracy
        scores = evaluator.run_all_evals(trainer,
                                         evals=['classif'],
                                         data_loader=valid_data_loader)

        for name, val in trainer.get_scores().items():
            scores[name] = val

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Code example #3
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
        return_index=True,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=sampler,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(
                args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model,
                                                   process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    # base_lr=4.8 wd=1e-6
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    # Using Dist LARC Optimizer
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)

    # LR Scheduling
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr,
                                     len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))

    logger.info("Building optimizer done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
    )
    start_epoch = to_restore["epoch"]

    # build the memory bank
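    # each rank keeps its own memory bank file (mb<rank>.pth) holding local sample indices and embeddings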
    mb_path = os.path.join(args.dump_path, "mb" + str(args.rank) + ".pth")
    if os.path.isfile(mb_path):
        mb_ckp = torch.load(mb_path)
        local_memory_index = mb_ckp["local_memory_index"]
        local_memory_embeddings = mb_ckp["local_memory_embeddings"]
    else:
        local_memory_index, local_memory_embeddings = init_memory(
            train_loader, model)

    cudnn.benchmark = True
    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # train the network
        scores, local_memory_index, local_memory_embeddings = train(
            train_loader,
            model,
            optimizer,
            epoch,
            lr_schedule,
            local_memory_index,
            local_memory_embeddings,
        )
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints,
                                 "ckp-" + str(epoch) + ".pth"),
                )
        torch.save(
            {
                "local_memory_embeddings": local_memory_embeddings,
                "local_memory_index": local_memory_index
            }, mb_path)
Code example #4
def main(args):

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(args, make_communication_groups=False)

    # initialize the experiment
    logger, training_stats = initialize_exp(args, 'epoch', 'iter', 'prec',
                                            'loss', 'prec_val', 'loss_val')

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    if 'pascal' not in args.data_path:
        main_data_path = args.data_path
        args.data_path = os.path.join(main_data_path, 'train')
        train_dataset = load_data(args)
    else:
        train_dataset = VOC2007_dataset(args.data_path, split=args.split)

    args.test = 'val' if args.split == 'train' else 'test'
    if 'pascal' not in args.data_path:
        if args.cross_valid is None:
            args.data_path = os.path.join(main_data_path, 'val')
        val_dataset = load_data(args)
    else:
        val_dataset = VOC2007_dataset(args.data_path, split=args.test)

    if args.cross_valid is not None:
        kfold = KFold(per_target(train_dataset.imgs), args.cross_valid, args.kfold)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, sampler=kfold.train,
            num_workers=args.workers, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, sampler=kfold.val,
            num_workers=args.workers)

    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers)

    # prepare the different data transformations
    tr_val, tr_train = get_data_transformations()
    train_dataset.transform = tr_train
    val_dataset.transform = tr_val

    # build model skeleton
    fix_random_seeds(args.seed)
    model = model_factory(args.arch, args.sobel)

    load_pretrained(model, args)

    # keep only conv layers
    model.body.classifier = None
    model.conv = args.conv

    if 'places' in args.data_path:
        nmb_classes = 205
    elif 'pascal' in args.data_path:
        nmb_classes = 20
    else:
        nmb_classes = 1000

    reglog = RegLog(args.arch, nmb_classes, args.conv)

    # distributed training wrapper
    model = to_cuda(model, [args.gpu_to_work_on], apex=False)
    reglog = to_cuda(reglog, [args.gpu_to_work_on], apex=False)
    logger.info('model to cuda')


    # set optimizer
    optimizer = sgd_optimizer(reglog, args.lr, args.wd)

    # variables to re-load from the checkpoint
    to_restore = {'epoch': 0, 'start_iter': 0}

    # restart from checkpoint
    restart_from_checkpoint(
        args,
        run_variables=to_restore,
        state_dict=reglog,
        optimizer=optimizer,
    )
    args.epoch = to_restore['epoch']
    args.start_iter = to_restore['start_iter']

    model.eval()
    reglog.train()

    # Linear training
    for _ in range(args.epoch, args.nepochs):

        logger.info("============ Starting epoch %i ... ============" % args.epoch)

        # train the network for one epoch
        scores = train_network(args, model, reglog, optimizer, train_loader)

        if 'pascal' not in args.data_path:
            scores_val = validate_network(val_loader, [model, reglog], args)
        else:
            scores_val = evaluate_pascal(val_dataset, [model, reglog])

        scores = scores + scores_val

        # save training statistics
        logger.info(scores)
        training_stats.update(scores)
Code example #5
File: eval_semisup.py  Project: killsking/swav
def main():
    global args, best_acc
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss", "prec1",
                                            "prec5", "loss_val", "prec1_val",
                                            "prec5_val")

    # build data
    train_data_path = os.path.join(args.data_path, "train")
    train_dataset = datasets.ImageFolder(train_data_path)
    # take either 1% or 10% of images
    subset_file = urllib.request.urlopen(
        "https://raw.githubusercontent.com/google-research/simclr/master/imagenet_subsets/"
        + str(args.labels_perc) + "percent.txt")
    list_imgs = [li.decode("utf-8").split('\n')[0] for li in subset_file]
    train_dataset.samples = [(os.path.join(train_data_path,
                                           li.split('_')[0], li),
                              train_dataset.class_to_idx[li.split('_')[0]])
                             for li in list_imgs]
    val_dataset = datasets.ImageFolder(os.path.join(args.data_path, "val"))
    tr_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.228, 0.224, 0.225])
    train_dataset.transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        tr_normalize,
    ])
    val_dataset.transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        tr_normalize,
    ])
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](output_dim=1000)

    # convert batch norm layers
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # load weights
    if os.path.isfile(args.pretrained):
        state_dict = torch.load(args.pretrained,
                                map_location="cuda:" +
                                str(args.gpu_to_work_on))
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # remove the "module." prefix
        state_dict = {
            k.replace("module.", ""): v
            for k, v in state_dict.items()
        }
        for k, v in model.state_dict().items():
            if k not in list(state_dict):
                logger.info(
                    'key "{}" could not be found in provided state dict'.
                    format(k))
            elif state_dict[k].shape != v.shape:
                logger.info(
                    'key "{}" is of different shape in model and provided state dict'
                    .format(k))
                state_dict[k] = v
        msg = model.load_state_dict(state_dict, strict=False)
        logger.info("Load pretrained model with msg: {}".format(msg))
    else:
        logger.info(
            "No pretrained weights found => training from random weights")

    # model to gpu
    model = model.cuda()
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # set optimizer
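    # two parameter groups: trunk parameters use args.lr, the classification head uses args.lr_last_layer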
    trunk_parameters = []
    head_parameters = []
    for name, param in model.named_parameters():
        if 'head' in name:
            head_parameters.append(param)
        else:
            trunk_parameters.append(param)
    optimizer = torch.optim.SGD(
        [{
            'params': trunk_parameters
        }, {
            'params': head_parameters,
            'lr': args.lr_last_layer
        }],
        lr=args.lr,
        momentum=0.9,
        weight_decay=0,
    )
    # set scheduler
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.decay_epochs,
                                                     gamma=args.gamma)

    # Optionally resume from a checkpoint
    to_restore = {"epoch": 0, "best_acc": (0., 0.)}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]
    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set samplers
        train_loader.sampler.set_epoch(epoch)

        scores = train(model, optimizer, train_loader, epoch)
        scores_val = validate_network(val_loader, model)
        training_stats.update(scores + scores_val)

        scheduler.step()

        # save checkpoint
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_acc": best_acc,
            }
            torch.save(save_dict,
                       os.path.join(args.dump_path, "checkpoint.pth.tar"))
    logger.info("Fine-tuning with {}% of labels completed.\n"
                "Test accuracies: top-1 {acc1:.1f}, top-5 {acc5:.1f}".format(
                    args.labels_perc, acc1=best_acc[0], acc5=best_acc[1]))
Code example #6
File: eval_linear.py  Project: gustvision/swav
def main():
    global args, best_acc
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(
        args,
        "epoch",
        "loss",
        "prec1",
        "prec5",
        "loss_val",
        "prec1_val",
        "prec5_val",
    )

    # build data
    train_dataset = datasets.ImageFolder(os.path.join(args.data_path, "train"))
    val_dataset = datasets.ImageFolder(os.path.join(args.data_path, "val"))
    tr_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.228, 0.224, 0.225])
    train_dataset.transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        tr_normalize,
    ])
    val_dataset.transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        tr_normalize,
    ])
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    logger.info("Building data done")

    # build model
    model = resnet_models.__dict__[args.arch](output_dim=0, eval_mode=True)
    linear_classifier = RegLog(1000, args.arch, args.global_pooling,
                               args.use_bn)

    # convert batch norm layers (if any)
    linear_classifier = nn.SyncBatchNorm.convert_sync_batchnorm(
        linear_classifier)

    # model to gpu
    model = model.cuda()
    linear_classifier = linear_classifier.cuda()
    linear_classifier = nn.parallel.DistributedDataParallel(
        linear_classifier,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )
    model.eval()

    # load weights
    if os.path.isfile(args.pretrained):
        state_dict = torch.load(args.pretrained,
                                map_location="cuda:" +
                                str(args.gpu_to_work_on))
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # remove the "module." prefix
        state_dict = {
            k.replace("module.", ""): v
            for k, v in state_dict.items()
        }
        for k, v in model.state_dict().items():
            if k not in list(state_dict):
                logger.info(
                    'key "{}" could not be found in provided state dict'.
                    format(k))
            elif state_dict[k].shape != v.shape:
                logger.info(
                    'key "{}" is of different shape in model and provided state dict'
                    .format(k))
                state_dict[k] = v
        msg = model.load_state_dict(state_dict, strict=False)
        logger.info("Load pretrained model with msg: {}".format(msg))
    else:
        logger.info(
            "No pretrained weights found => training with random weights")

    # set optimizer
    optimizer = torch.optim.SGD(
        linear_classifier.parameters(),
        lr=args.lr,
        nesterov=args.nesterov,
        momentum=0.9,
        weight_decay=args.wd,
    )

    # set scheduler
    if args.scheduler_type == "step":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         args.decay_epochs,
                                                         gamma=args.gamma)
    elif args.scheduler_type == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs, eta_min=args.final_lr)

    # Optionally resume from a checkpoint
    to_restore = {"epoch": 0, "best_acc": 0.0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=linear_classifier,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]
    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set samplers
        train_loader.sampler.set_epoch(epoch)

        scores = train(model, linear_classifier, optimizer, train_loader,
                       epoch)
        scores_val = validate_network(val_loader, model, linear_classifier)
        training_stats.update(scores + scores_val)

        scheduler.step()

        # save checkpoint
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": linear_classifier.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_acc": best_acc,
            }
            torch.save(save_dict,
                       os.path.join(args.dump_path, "checkpoint.pth.tar"))
    logger.info(
        "Training of the supervised linear classifier on frozen features completed.\n"
        "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
Code example #7
File: main.py  Project: GG-yuki/bugs
def main(args):
    """
    This code implements the paper: https://arxiv.org/abs/1905.01278
    The method consists in alternating between a hierachical clustering of the
    features and learning the parameters of a convnet by predicting both the
    angle of the rotation applied to the input data and the cluster assignments
    in a single hierachical loss.
           """

    # initialize communication groups
    training_groups, clustering_groups = init_distributed_mode(args)

    # check parameters
    check_parameters(args)

    # initialize the experiment
    logger, training_stats = initialize_exp(args, 'epoch', 'iter', 'prec',
                                            'loss', 'prec_super_class',
                                            'loss_super_class',
                                            'prec_sub_class', 'loss_sub_class')

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    dataset = YFCC100M_dataset(r'./dataset', size=args.size_dataset)

    # prepare the different data transformations
    tr_cluster, tr_train = get_data_transformations(args.rotation * 90)

    # build model skeleton
    fix_random_seeds()
    model = model_factory(args.sobel)
    logger.info('model created')

    # load pretrained weights
    load_pretrained(model, args)

    # convert batch-norm layers to nvidia wrapper to enable batch stats reduction
    model = apex.parallel.convert_syncbn_model(model)

    # distributed training wrapper
    model = to_cuda(model, args.gpu_to_work_on, apex=False)
    logger.info('model to cuda')

    # set optimizer
    optimizer = sgd_optimizer(model, args.lr, args.wd)

    # load cluster assignments
    cluster_assignments = load_cluster_assignments(args, dataset)

    # build prediction layer on the super_class
    pred_layer, optimizer_pred_layer = build_prediction_layer(
        model.module.body.dim_output_space,
        args,
    )

    nmb_sub_classes = args.k // args.nmb_super_clusters
    sub_class_pred_layer, optimizer_sub_class_pred_layer = build_prediction_layer(
        model.module.body.dim_output_space,
        args,
        num_classes=nmb_sub_classes,
        group=training_groups[args.training_local_world_id],
    )

    # variables to fetch in checkpoint
    to_restore = {'epoch': 0, 'start_iter': 0}

    # restart from checkpoint
    restart_from_checkpoint(
        args,
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        pred_layer_state_dict=pred_layer,
        optimizer_pred_layer=optimizer_pred_layer,
    )
    pred_layer_name = str(args.training_local_world_id) + '-pred_layer.pth.tar'
    restart_from_checkpoint(
        args,
        ckp_path=os.path.join(args.dump_path, pred_layer_name),
        state_dict=sub_class_pred_layer,
        optimizer=optimizer_sub_class_pred_layer,
    )
    args.epoch = to_restore['epoch']
    args.start_iter = to_restore['start_iter']

    for _ in range(args.epoch, args.nepochs):

        logger.info("============ Starting epoch %i ... ============" %
                    args.epoch)
        fix_random_seeds(args.epoch)

        # step 1: Get the final activations for the whole dataset / Cluster them

        if cluster_assignments is None and not args.epoch % args.reassignment:

            logger.info("=> Start clustering step")
            dataset.transform = tr_cluster

            cluster_assignments = get_cluster_assignments(
                args, model, dataset, clustering_groups)

            # reset prediction layers
            if args.nmb_super_clusters > 1:
                pred_layer, optimizer_pred_layer = build_prediction_layer(
                    model.module.body.dim_output_space,
                    args,
                )
            sub_class_pred_layer, optimizer_sub_class_pred_layer = build_prediction_layer(
                model.module.body.dim_output_space,
                args,
                num_classes=nmb_sub_classes,
                group=training_groups[args.training_local_world_id],
            )

        # step 2: Train the network with the cluster assignments as labels

        # prepare dataset
        dataset.transform = tr_train
        dataset.sub_classes = cluster_assignments

        # concatenate models and their corresponding optimizers
        models = [model, pred_layer, sub_class_pred_layer]
        optimizers = [
            optimizer, optimizer_pred_layer, optimizer_sub_class_pred_layer
        ]

        # train the network for one epoch
        scores = train_network(args, models, optimizers, dataset)

        ## save training statistics
        logger.info(scores)
        training_stats.update(scores)

        # reassign clusters at the next epoch
        if not args.epoch % args.reassignment:
            cluster_assignments = None
            dataset.subset_indexes = None
            end_of_epoch(args)

        dist.barrier()
Code example #8
def main(args):

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(args, make_communication_groups=False)

    # initialize the experiment
    logger, training_stats = initialize_exp(args, 'epoch', 'iter', 'prec',
                                            'loss', 'prec_val', 'loss_val')

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    main_data_path = args.data_path
    if args.debug:
        args.data_path = os.path.join(main_data_path, 'val')
    else:
        args.data_path = os.path.join(main_data_path, 'train')
    train_dataset = load_data(args)

    args.data_path = os.path.join(main_data_path, 'val')
    val_dataset = load_data(args)

    # prepare the different data transformations
    tr_val, tr_train = get_data_transformations()
    train_dataset.transform = tr_train
    val_dataset.transform = tr_val
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )

    # build model skeleton
    fix_random_seeds(args.seed)
    nmb_classes = 205 if 'places' in args.data_path else 1000
    model = model_factory(args, relu=True, num_classes=nmb_classes)

    # load pretrained weights
    load_pretrained(model, args)

    # merge sobel layers with first convolution layer
    if args.sobel2RGB:
        sobel2RGB(model)

    # re-initialize the classifier
    if hasattr(model.body, 'classifier'):
        for m in model.body.classifier.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.fill_(0.1)

    # distributed training wrapper
    model = to_cuda(model, [args.gpu_to_work_on], apex=True)
    logger.info('model to cuda')

    # set optimizer
    optimizer = sgd_optimizer(model, args.lr, args.wd)

    # variables to re-load from the checkpoint
    to_restore = {'epoch': 0, 'start_iter': 0}

    # restart from checkpoint
    restart_from_checkpoint(
        args,
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
    )
    args.epoch = to_restore['epoch']
    args.start_iter = to_restore['start_iter']

    if args.evaluate:
        validate_network(val_loader, [model], args)
        return

    # Supervised training
    for _ in range(args.epoch, args.nepochs):

        logger.info("============ Starting epoch %i ... ============" %
                    args.epoch)

        fix_random_seeds(args.seed + args.epoch)

        # train the network for one epoch
        adjust_learning_rate(optimizer, args)
        scores = train_network(args, model, optimizer, train_dataset)

        scores_val = validate_network(val_loader, [model], args)

        # save training statistics
        logger.info(scores + scores_val)
        training_stats.update(scores + scores_val)
Code example #9
File: main_swav.py  Project: wang3702/swav
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    if args.type == 0:
        traindir = os.path.join(args.data_path, 'train')
        train_dataset = MultiCropDataset(
            traindir,
            args.size_crops,
            args.nmb_crops,
            args.min_scale_crops,
            args.max_scale_crops,
        )
    else:
        from src.inside_crop import inside_crop, TwoCropsTransform
        from src.multicropdataset import CropDataset
        traindir = os.path.join(args.data_path, 'train')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        cur_transform = inside_crop(args.size_crops,
                                    args.nmb_crops,
                                    args.min_scale_crops,
                                    args.max_scale_crops, normalize)
        train_dataset = CropDataset(traindir, cur_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    # configure the datasets used for kNN monitoring
    traindir = os.path.join(args.data_path, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    testdir = os.path.join(args.data_path, 'val')
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    # val_dataset = datasets.ImageFolder(traindir,transform_test)
    val_dataset = imagenet(traindir, 0.2, transform_test)
    test_dataset = datasets.ImageFolder(testdir, transform_test)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.knn_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.knn_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=test_sampler, drop_last=False)


    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # elif args.sync_bn == "apex":
    #     process_group = None
    #     if args.world_size // 8 > 0:
    #         process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
    #     model = apex.parallel.convert_syncbn_model(model, process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    #optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    optimizer = SGD_LARC(optimizer, trust_coefficient=0.001, clip=False, eps=1e-8)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    # if args.use_fp16:
    #     model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
    #     logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        #amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
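            # queue shape: (number of crops used for assignment, per-GPU queue length, feature dim)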
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue)
        training_stats.update(scores)
        if epoch % args.knn_freq == 0:
            print("GPU memory allocated before cache clearing (MB):",
                  torch.cuda.memory_allocated() / 1024 / 1024)
            torch.cuda.empty_cache()
            print("GPU memory allocated after cache clearing (MB):",
                  torch.cuda.memory_allocated() / 1024 / 1024)
            # a much smaller kNN batch size with a sampler should also work here
            knn_test_acc = knn_monitor(model, val_loader, test_loader,
                                       global_k=min(args.knn_neighbor, len(val_loader.dataset)))
            # alternative: knn_monitor_fast(model.module.encoder_q, val_loader, test_loader,
            #                               global_k=min(args.knn_neighbor, len(val_loader.dataset)))
            print({'*KNN monitor Accuracy': knn_test_acc})
            torch.cuda.empty_cache()
        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            #if args.use_fp16:
            #    save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
Code example #10
def main(params):
    init_distributed_mode(params)

    logger = initialize_exp(params)

    torch.cuda.manual_seed_all(params.seed)
    transform = getTransform(0)

    root_data = '/private/home/asablayrolles/data/cifar-dejalight2'
    trainset = CIFAR10(root=root_data, name=params.name, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=params.batch_size,
                                              shuffle=True,
                                              num_workers=2)

    valid_set = CIFAR10(root=root_data, name='public_0', transform=transform)
    valid_data_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=params.batch_size, shuffle=False, num_workers=2)

    model = build_model(params)
    if params.gpu:
        model = model.cuda()

    # criterion = nn.CrossEntropyLoss()
    # optimizer = optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum)

    trainer = Trainer(model=model, params=params)
    evaluator = Evaluator(trainer, params)

    for epoch in range(params.epochs):
        trainer.update_learning_rate()
        for images, targets in trainloader:
            trainer.classif_step(images, targets)

        # evaluate classification accuracy
        scores = evaluator.run_all_evals(trainer,
                                         evals=['classif'],
                                         data_loader=valid_data_loader)

        for name, val in trainer.get_scores().items():
            scores[name] = val

        accuracy, precision_train, recall_train = mast_topline(
            model, trainloader, valid_data_loader)
        print(f"Guessing accuracy: {accuracy}")

        scores["mast_accuracy"] = accuracy
        scores["mast_precision_train"] = precision_train
        scores["mast_recall_train"] = recall_train

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))

        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)

    print('Finished Training')
Code example #11
def main(params):

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment / load data
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    if params.is_slurm_job:
        init_signal_handler()

    if params.dataset == "imagenet":
        params.num_classes = 1000
        params.img_size = 256
        params.crop_size = 224
    else:
        if params.dataset == "cifar10":
            params.num_classes = 10
        elif params.dataset == "cifar100":
            params.num_classes = 100
        else:
            assert False, "Dataset unbeknownst to me"

        params.img_size = 40
        params.crop_size = 32

    # data loaders / samplers
    train_data_loader, train_sampler = get_data_loader(
        img_size=params.img_size,
        crop_size=params.crop_size,
        shuffle=True,
        batch_size=params.batch_size,
        nb_workers=params.nb_workers,
        distributed_sampler=params.multi_gpu,
        dataset=params.dataset,
        transform=params.transform,
        split='valid' if params.debug_train else params.split_train,
    )

    valid_data_loader, _ = get_data_loader(
        img_size=params.img_size,
        crop_size=params.crop_size,
        shuffle=False,
        batch_size=params.batch_size,
        nb_workers=params.nb_workers,
        distributed_sampler=False,
        dataset=params.dataset,
        transform='center',
        split='valid',
    )

    # build model / cuda
    logger.info("Building %s model ..." % params.architecture)
    model = build_model(params)
    model.cuda()

    # distributed  # TODO: check this https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main.py#L142
    if params.multi_gpu:
        logger.info("Using nn.parallel.DistributedDataParallel ...")
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[params.local_rank],
            output_device=params.local_rank,
            broadcast_buffers=True)

    # build trainer / reload potential checkpoints / build evaluator
    trainer = Trainer(model=model, params=params)
    trainer.reload_checkpoint()
    evaluator = Evaluator(trainer, params)

    # evaluation
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer,
                                         evals=['classif', 'recognition'],
                                         data_loader=valid_data_loader)

        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # training
    for epoch in range(trainer.epoch, params.epochs):

        # update epoch / sampler / learning rate
        trainer.epoch = epoch
        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)
        if params.multi_gpu:
            train_sampler.set_epoch(epoch)

        # update learning rate
        trainer.update_learning_rate()

        # train
        for i, (images, targets) in enumerate(train_data_loader):
            trainer.classif_step(images, targets)
            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate classification accuracy
        scores = evaluator.run_all_evals(trainer,
                                         evals=['classif'],
                                         data_loader=valid_data_loader)

        for name, val in trainer.get_scores().items():
            scores[name] = val

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Code example #12
def main():
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    print(torch.cuda.memory_allocated(), flush=True)
    
    train_paths, train_labs, dev_paths, dev_labs, test_paths, test_labs = get_patches_labels('./sc/arion/work/millej37/ML-project/patches',
                                                                                             './sc/arion/work/millej37/ML-project/swav')
    color_transform = [get_color_distortion(),
                       transforms.GaussianBlur(kernel_size=int(.1 * 224) + 1, sigma=(0.1, 2.0))]
    mean = [0.485, 0.456, 0.406]
    std = [0.228, 0.224, 0.225]
    swav_transform = transforms.Compose([
        transforms.ToTensor(),
        # RandomResizedCrop requires an output size; 224 matches the other examples
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Compose(color_transform),
        transforms.Normalize(mean=mean, std=std),
    ])

    # build data
    train_dataset = PatchDataset(train_paths, transform=swav_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    print(torch.cuda.memory_allocated(), flush=True)
    # synchronize batch norm layers
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        # with apex syncbn we sync bn per group because it speeds up computation
        # compared to global syncbn
        process_group = apex.parallel.create_syncbn_process_group(args.syncbn_process_group_size)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue
    queue = None
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # the queue needs to be divisible by the batch size
    args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)