Code Example #1
File: misc.py  Project: sjjdd/hybrid-ds
    def _make_test_transform(self, crop_type, crop_size_img, crop_size_label,
                             pad_size):
        test_transform_ops = self.basic_transform_ops.copy()
        if pad_size is not None:
            test_transform_ops.append(transforms.Pad(pad_size, 0))
        if crop_type == 'center':
            test_transform_ops.append(
                transforms.CenterCrop(crop_size_img, crop_size_label))
        elif crop_type is None:
            pass
        else:
            raise RuntimeError('Unknown test crop type.')

        return transforms.Compose(test_transform_ops)
Code Example #2
File: misc.py  Project: sjjdd/hybrid-ds
    def _make_train_transform(self, crop_type, crop_size_img, crop_size_label,
                              rand_flip, mod_drop_rate, balance_rate, pad_size,
                              rand_rot90, random_black_patch_size,
                              mini_positive):
        train_transform_ops = self.basic_transform_ops.copy()

        train_transform_ops += [
            transforms.RandomBlack(random_black_patch_size),
            transforms.RandomDropout(mod_drop_rate),
            transforms.RandomFlip(rand_flip)
        ]
        if pad_size is not None:
            train_transform_ops.append(transforms.Pad(pad_size, 0))

        if rand_rot90:
            train_transform_ops.append(transforms.RandomRotate2d())

        if crop_type == 'random':
            if mini_positive:
                train_transform_ops.append(
                    transforms.RandomCropMinSize(crop_size_img, mini_positive))
            else:
                train_transform_ops.append(
                    transforms.RandomCrop(crop_size_img))
        elif crop_type == 'balance':
            train_transform_ops.append(
                transforms.BalanceCrop(balance_rate, crop_size_img,
                                       crop_size_label))
        elif crop_type == 'center':
            train_transform_ops.append(
                transforms.CenterCrop(crop_size_img, crop_size_label))
        elif crop_type is None:
            pass
        else:
            raise RuntimeError('Unknown train crop type.')

        return transforms.Compose(train_transform_ops)
Code Example #3
def main():
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    use_gpu = torch.cuda.is_available()
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    cudnn.benchmark = True

    print("Initializing train dataset {}".format(args.train_dataset))
    train_dataset = data_manager.init_dataset(name=args.train_dataset)
    print("Initializing test dataset {}".format(args.test_dataset))
    test_dataset = data_manager.init_dataset(name=args.test_dataset)

    # print("Initializing train dataset {}".format(args.train_dataset, split_id=6))
    # train_dataset = data_manager.init_dataset(name=args.train_dataset)
    # print("Initializing test dataset {}".format(args.test_dataset, split_id=6))
    # test_dataset = data_manager.init_dataset(name=args.test_dataset)

    transform_train = T.Compose([
        T.Resize([args.height, args.width]),
        T.RandomHorizontalFlip(),
        T.Pad(10),
        T.RandomCrop([args.height, args.width]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406])
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    # available sampling modes: random_snip, first_snip, constrain_random, evenly
    trainloader = DataLoader(
        VideoDataset(train_dataset.train,
                     seq_len=args.seq_len,
                     sample='constrain_random',
                     transform=transform_train),
        sampler=RandomIdentitySampler(train_dataset.train,
                                      num_instances=args.num_instances),
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        VideoDataset(test_dataset.query,
                     seq_len=args.seq_len,
                     sample='evenly',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(test_dataset.gallery,
                     seq_len=args.seq_len,
                     sample='evenly',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=train_dataset.num_train_pids,
                              loss={'xent', 'htri'})
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    print("load model {0} from {1}".format(args.arch, args.load_model))
    if args.load_model != '':
        pretrained_model = torch.load(args.load_model)
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_model['state_dict'].items()
            if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        start_epoch = pretrained_model['epoch'] + 1
        best_rank1 = pretrained_model['rank1']
    else:
        start_epoch = args.start_epoch
        best_rank1 = -np.inf

    criterion = dict()
    criterion['triplet'] = WeightedRegularizedTriplet()
    criterion['xent'] = CrossEntropyLabelSmooth(
        num_classes=train_dataset.num_train_pids)
    criterion['center'] = CenterLoss(num_classes=train_dataset.num_train_pids,
                                     feat_dim=512,
                                     use_gpu=True)
    print(criterion)

    optimizer = dict()
    optimizer['model'] = model.get_optimizer(args)
    optimizer['center'] = torch.optim.SGD(criterion['center'].parameters(),
                                          lr=0.5)

    scheduler = lr_scheduler.MultiStepLR(optimizer['model'],
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    print(model)
    model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       args.pool,
                       use_gpu,
                       return_distmat=True)
        return

    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")
    for epoch in range(start_epoch, args.max_epoch):

        scheduler.step()
        print('Epoch', epoch, 'lr', scheduler.get_lr()[0])

        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
Code Example #4
File: REA_debug.py  Project: tomFoxxxx/LR_reid_osnet
    print("image_dtype: ", image.dtype)
    print("image_type: ", type(image))
    plt.imshow(image)
    plt.show()

'''
torchvision.transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)

value – erasing value. Default is 0. If a single int, it is used to erase all pixels. 
If a tuple of length 3, it is used to erase R, G, B channels respectively. 
If a str of ‘random’, erasing each pixel with random values.
'''
transform_img = T.Compose([
        T.Resize((256,128)),
        T.RandomHorizontalFlip(p=0.5),
        T.Pad(10),
        T.RandomCrop([256,128]),
        T.ToTensor(),
        #T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        #torchvision.transforms.RandomErasing(p=1, scale=(0.02, 0.4), ratio=(0.3, 3.33))
        #torchvision.transforms.RandomErasing(p=1, scale=(0.02, 0.4), ratio=(0.3, 3.33), value=(0.4914, 0.4822, 0.4465))
        #torchvision.transforms.RandomErasing(p=1, scale=(0.02, 0.4), ratio=(0.3, 3.33), value=1)
        #torchvision.transforms.RandomErasing(p=1, scale=(0.02, 0.4), ratio=(0.3, 3.33), value=2)
        #torchvision.transforms.RandomErasing(p=1, scale=(0.02, 0.4), ratio=(0.3, 3.33), value=12)
        #torchvision.transforms.RandomErasing(p=1, scale=(0.02, 0.4), ratio=(0.3, 3.33), value='random')
        T.RandomErasing(probability=0.5, sh=0.4, mean=(0.4914, 0.4822, 0.4465)),
    ])

if __name__ == '__main__':
    pth = r'D:\pycharm\LR_reid\osnet\deep-person-reid-master\data\market1501\bounding_box_train\0002_c1s1_000451_03.jpg'
    img, img_path = read_image(pth)
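
The quoted docstring in this example describes the `value` argument of torchvision's own RandomErasing, as opposed to the custom `T.RandomErasing(probability=..., sh=..., mean=...)` used in the Compose above. The following is a minimal sketch of those three `value` modes only, not code from any of the projects listed here; the dummy tensor shape and p=1.0 are assumptions chosen so the erased patch always appears, and `tvt` is used as the alias to avoid clashing with the custom `T` transforms module in the example.

import torch
import torchvision.transforms as tvt

# dummy CHW tensor image in [0, 1]; torchvision's RandomErasing operates on tensors, not PIL images
img = torch.rand(3, 256, 128)

erase_const = tvt.RandomErasing(p=1.0, value=0)                        # single int: every erased pixel set to 0
erase_rgb = tvt.RandomErasing(p=1.0, value=(0.4914, 0.4822, 0.4465))   # 3-tuple: per-channel fill values
erase_random = tvt.RandomErasing(p=1.0, value='random')                # 'random': erased pixels filled with random values

out_const = erase_const(img)
out_rgb = erase_rgb(img)
out_random = erase_random(img)
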
Code Example #5
File: settings.py  Project: prismformore/expAT
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

############################# Hyper-parameters ################################
alpha = 1.0
beta = 1.0
at_margin = 1

pixel_mean = [0.485, 0.456, 0.406]
pixel_std = [0.229, 0.224, 0.225]
inp_size = [384, 128]

# transforms

transforms_list = transforms.Compose([
    transforms.RectScale(*inp_size),
    transforms.RandomHorizontalFlip(),
    transforms.Pad(10),
    transforms.RandomCrop(inp_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=pixel_mean, std=pixel_std),
    transforms.RandomErasing(probability=0.5, mean=pixel_mean)
])

test_transforms_list = transforms.Compose([
    transforms.RectScale(*inp_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=pixel_mean, std=pixel_std)
])
Code Example #6
def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_img_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )

    transform_train = T.Compose([
        T.Resize((args.height, args.width)),
        T.RandomHorizontalFlip(p=0.5),
        T.Pad(10),
        T.RandomCrop([args.height, args.width]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        torchvision.transforms.RandomErasing(p=0.5,
                                             scale=(0.02, 0.4),
                                             ratio=(0.3, 3.33),
                                             value=(0.4914, 0.4822, 0.4465))
        # T.RandomErasing(probability=0.5, sh=0.4, mean=(0.4914, 0.4822, 0.4465)),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              loss={'cent'})
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion_xent = CrossEntropyLabelSmooth(
        num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    criterion_cent = CenterLoss(num_classes=dataset.num_train_pids,
                                feat_dim=model.feat_dim,
                                use_gpu=use_gpu)

    optimizer_model = torch.optim.Adam(model.parameters(),
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
    optimizer_cent = torch.optim.SGD(criterion_cent.parameters(),
                                     lr=args.lr_cent)

    # only optimizer_model uses a learning rate schedule
    # if args.stepsize > 0:
    #     scheduler = lr_scheduler.StepLR(optimizer_model, step_size=args.stepsize, gamma=args.gamma)
    '''------Modify lr_schedule here------'''
    current_schedule = init_lr_schedule(schedule=args.schedule,
                                        warm_up_epoch=args.warm_up_epoch,
                                        half_cos_period=args.half_cos_period,
                                        lr_milestone=args.lr_milestone,
                                        gamma=args.gamma,
                                        stepsize=args.stepsize)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer_model,
                                                  lr_lambda=current_schedule)
    '''------Please refer to the args.xxx for details of hyperparams------'''
    # embed()
    start_epoch = args.start_epoch

    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test(model, queryloader, galleryloader, use_gpu)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_cent, optimizer_model,
              optimizer_cent, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        if args.schedule: scheduler.step()

        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
Code Example #7
def main():
    runId = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, runId)
    if not os.path.exists(cfg.OUTPUT_DIR):
        os.mkdir(cfg.OUTPUT_DIR)
    print(cfg.OUTPUT_DIR)
    torch.manual_seed(cfg.RANDOM_SEED)
    random.seed(cfg.RANDOM_SEED)
    np.random.seed(cfg.RANDOM_SEED)
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    use_gpu = torch.cuda.is_available() and cfg.MODEL.DEVICE == "cuda"
    if not cfg.EVALUATE_ONLY:
        sys.stdout = Logger(osp.join(cfg.OUTPUT_DIR, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(cfg.OUTPUT_DIR, 'log_test.txt'))

    print("==========\nConfigs:{}\n==========".format(cfg))

    if use_gpu:
        print("Currently using GPU {}".format(cfg.MODEL.DEVICE_ID))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(cfg.RANDOM_SEED)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(cfg.DATASETS.NAME))

    dataset = data_manager.init_dataset(root=cfg.DATASETS.ROOT_DIR,
                                        name=cfg.DATASETS.NAME)
    print("Initializing model: {}".format(cfg.MODEL.NAME))

    if cfg.MODEL.ARCH == 'video_baseline':
        torch.backends.cudnn.benchmark = False
        model = models.init_model(name=cfg.MODEL.ARCH,
                                  num_classes=625,
                                  pretrain_choice=cfg.MODEL.PRETRAIN_CHOICE,
                                  last_stride=cfg.MODEL.LAST_STRIDE,
                                  neck=cfg.MODEL.NECK,
                                  model_name=cfg.MODEL.NAME,
                                  neck_feat=cfg.TEST.NECK_FEAT,
                                  model_path=cfg.MODEL.PRETRAIN_PATH)

    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    transform_train = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(cfg.INPUT.PADDING),
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.RandomErasing(probability=cfg.INPUT.RE_PROB,
                        mean=cfg.INPUT.PIXEL_MEAN)
    ])
    transform_test = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    pin_memory = True if use_gpu else False

    cfg.DATALOADER.NUM_WORKERS = 0

    trainloader = DataLoader(VideoDataset(
        dataset.train,
        seq_len=cfg.DATASETS.SEQ_LEN,
        sample=cfg.DATASETS.TRAIN_SAMPLE_METHOD,
        transform=transform_train,
        dataset_name=cfg.DATASETS.NAME),
                             sampler=RandomIdentitySampler(
                                 dataset.train,
                                 num_instances=cfg.DATALOADER.NUM_INSTANCE),
                             batch_size=cfg.SOLVER.SEQS_PER_BATCH,
                             num_workers=cfg.DATALOADER.NUM_WORKERS,
                             pin_memory=pin_memory,
                             drop_last=True)

    queryloader = DataLoader(VideoDataset(
        dataset.query,
        seq_len=cfg.DATASETS.SEQ_LEN,
        sample=cfg.DATASETS.TEST_SAMPLE_METHOD,
        transform=transform_test,
        max_seq_len=cfg.DATASETS.TEST_MAX_SEQ_NUM,
        dataset_name=cfg.DATASETS.NAME),
                             batch_size=cfg.TEST.SEQS_PER_BATCH,
                             shuffle=False,
                             num_workers=cfg.DATALOADER.NUM_WORKERS,
                             pin_memory=pin_memory,
                             drop_last=False)

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     seq_len=cfg.DATASETS.SEQ_LEN,
                     sample=cfg.DATASETS.TEST_SAMPLE_METHOD,
                     transform=transform_test,
                     max_seq_len=cfg.DATASETS.TEST_MAX_SEQ_NUM,
                     dataset_name=cfg.DATASETS.NAME),
        batch_size=cfg.TEST.SEQS_PER_BATCH,
        shuffle=False,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        pin_memory=pin_memory,
        drop_last=False,
    )

    if cfg.MODEL.SYN_BN:
        if use_gpu:
            model = nn.DataParallel(model)
        if cfg.SOLVER.FP_16:
            model = apex.parallel.convert_syncbn_model(model)
        model.cuda()

    start_time = time.time()
    xent = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids)
    tent = TripletLoss(cfg.SOLVER.MARGIN)

    optimizer = make_optimizer(cfg, model)

    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS,
                                  cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_ITERS,
                                  cfg.SOLVER.WARMUP_METHOD)
    # metrics = test(model, queryloader, galleryloader, cfg.TEST.TEMPORAL_POOL_METHOD, use_gpu)
    no_rise = 0
    best_rank1 = 0
    start_epoch = 0
    for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
        # if no_rise == 10:
        #     break
        scheduler.step()
        print("no_rise:", no_rise)
        print("==> Epoch {}/{}".format(epoch + 1, cfg.SOLVER.MAX_EPOCHS))
        print("current lr:", scheduler.get_lr()[0])

        train(model, trainloader, xent, tent, optimizer, use_gpu)
        if cfg.SOLVER.EVAL_PERIOD > 0 and (
            (epoch + 1) % cfg.SOLVER.EVAL_PERIOD == 0 or
            (epoch + 1) == cfg.SOLVER.MAX_EPOCHS):
            print("==> Test")

            metrics = test(model, queryloader, galleryloader,
                           cfg.TEST.TEMPORAL_POOL_METHOD, use_gpu)
            rank1 = metrics[0]
            if rank1 > best_rank1:
                best_rank1 = rank1
                no_rise = 0
            else:
                no_rise += 1
                continue

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            torch.save(
                state_dict,
                osp.join(
                    cfg.OUTPUT_DIR, "rank1_" + str(rank1) + '_checkpoint_ep' +
                    str(epoch + 1) + '.pth'))
            # best_p = osp.join(cfg.OUTPUT_DIR, "rank1_" + str(rank1) + '_checkpoint_ep' + str(epoch + 1) + '.pth')

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))