Example #1
def main():
    best_rank1 = -np.inf  # must be defined even when training without --resume
    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              loss={'xent'},
                              use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    if args.label_smooth:
        criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids,
                                            use_gpu=use_gpu)
    else:
        criterion = nn.CrossEntropyLoss()
    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    if args.fixbase_epoch > 0:
        if hasattr(model, 'classifier') and isinstance(model.classifier,
                                                       nn.Module):
            optimizer_tmp = init_optim(args.optim,
                                       model.classifier.parameters(),
                                       args.fixbase_lr, args.weight_decay)
        else:
            print(
                "Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0"
            )
            args.fixbase_epoch = 0

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       use_gpu,
                       return_distmat=True)
        if args.vis_ranked_res:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=20,
            )
        return

    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    if args.fixbase_epoch > 0:
        print(
            "Train classifier for {} epochs while keeping base network frozen".
            format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion,
                  optimizer_tmp,
                  trainloader,
                  use_gpu,
                  freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        if ((epoch + 1) > args.start_eval and args.eval_step > 0 and
                (epoch + 1) % args.eval_step == 0) or (epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
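
The weight-loading idiom above (keep only checkpoint tensors whose name and shape match the current model) is worth having as a standalone helper. A minimal sketch under the same assumptions as the code above (a checkpoint dict with a 'state_dict' key); the helper name load_matching_weights is illustrative, not part of the original code:

import torch

def load_matching_weights(model, weights_path):
    """Copy into model only the checkpoint tensors whose key and shape match."""
    checkpoint = torch.load(weights_path, map_location='cpu')
    pretrain_dict = checkpoint.get('state_dict', checkpoint)
    model_dict = model.state_dict()
    matched = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    print("Loaded {}/{} tensors from '{}'".format(
        len(matched), len(model_dict), weights_path))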
Example #2
def main():
    best_rank1 = -np.inf  # updated below when resuming from a checkpoint
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = DatasetManager(dataset_dir=args.dataset, root=args.root)

    transform_train = T.Compose_Keypt([
        T.Random2DTranslation_Keypt((args.width, args.height)),
        T.RandomHorizontalFlip_Keypt(),
        T.ToTensor_Keypt(),
        T.Normalize_Keypt(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose_Keypt([
        T.Resize_Keypt((args.width, args.height)),
        T.ToTensor_Keypt(),
        T.Normalize_Keypt(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu

    trainloader = DataLoader(
        ImageDataset(dataset.train,
                     keyptaware=args.keyptaware,
                     heatmapaware=args.heatmapaware,
                     segmentaware=args.segmentaware,
                     transform=transform_train,
                     imagesize=(args.width, args.height)),
        sampler=RandomIdentitySampler(dataset.train, args.train_batch,
                                      args.num_instances),
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query,
                     keyptaware=args.keyptaware,
                     heatmapaware=args.heatmapaware,
                     segmentaware=args.segmentaware,
                     transform=transform_test,
                     imagesize=(args.width, args.height)),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery,
                     keyptaware=args.keyptaware,
                     heatmapaware=args.heatmapaware,
                     segmentaware=args.segmentaware,
                     transform=transform_test,
                     imagesize=(args.width, args.height)),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_vids=dataset.num_train_vids,
                              num_vcolors=dataset.num_train_vcolors,
                              num_vtypes=dataset.num_train_vtypes,
                              keyptaware=args.keyptaware,
                              heatmapaware=args.heatmapaware,
                              segmentaware=args.segmentaware,
                              multitask=args.multitask)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    if args.label_smooth:
        criterion_xent_vid = CrossEntropyLabelSmooth(
            num_classes=dataset.num_train_vids, use_gpu=use_gpu)
        criterion_xent_vcolor = CrossEntropyLabelSmooth(
            num_classes=dataset.num_train_vcolors, use_gpu=use_gpu)
        criterion_xent_vtype = CrossEntropyLabelSmooth(
            num_classes=dataset.num_train_vtypes, use_gpu=use_gpu)
    else:
        criterion_xent_vid = nn.CrossEntropyLoss()
        criterion_xent_vcolor = nn.CrossEntropyLoss()
        criterion_xent_vtype = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin)

    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    if args.load_weights:
        # load pretrained weights but ignore layers that don't match in size
        if check_isfile(args.load_weights):
            checkpoint = torch.load(args.load_weights)
            pretrain_dict = checkpoint['state_dict']
            model_dict = model.state_dict()
            pretrain_dict = {
                k: v
                for k, v in pretrain_dict.items()
                if k in model_dict and model_dict[k].size() == v.size()
            }
            model_dict.update(pretrain_dict)
            model.load_state_dict(model_dict)
            print("Loaded pretrained weights from '{}'".format(
                args.load_weights))

    if args.resume:
        if check_isfile(args.resume):
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch'] + 1
            best_rank1 = checkpoint['rank1']
            print("Loaded checkpoint from '{}'".format(args.resume))
            print("- start_epoch: {}\n- rank1: {}".format(
                args.start_epoch, best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       args.keyptaware,
                       args.multitask,
                       queryloader,
                       galleryloader,
                       use_gpu,
                       dataset.vcolor2label,
                       dataset.vtype2label,
                       return_distmat=True)
        if args.vis_ranked_res:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=100,
            )
        return

    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, args.keyptaware, args.multitask,
              criterion_xent_vid, criterion_xent_vcolor, criterion_xent_vtype,
              criterion_htri, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        if ((epoch + 1) > args.start_eval and args.eval_step > 0 and
            (epoch + 1) % args.eval_step == 0) or (epoch +
                                                   1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, args.keyptaware, args.multitask, queryloader,
                         galleryloader, use_gpu, dataset.vcolor2label,
                         dataset.vtype2label)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()


            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.2%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
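
CrossEntropyLabelSmooth appears in every example. The standard formulation (Szegedy et al.) replaces the one-hot target with (1 - epsilon) * one_hot + epsilon / num_classes before taking the cross entropy. A minimal sketch consistent with the constructor calls above; the epsilon default of 0.1 is an assumption:

import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy against smoothed targets:
    (1 - epsilon) * one_hot + epsilon / num_classes."""

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # build one-hot targets, then mix with the uniform distribution
        targets = torch.zeros_like(log_probs).scatter_(
            1, targets.unsqueeze(1), 1)
        if self.use_gpu:
            targets = targets.cuda()
        targets = (1 - self.epsilon) * targets \
            + self.epsilon / self.num_classes
        return (-targets * log_probs).mean(0).sum()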
Example #3
def main(args):
    args = parser.parse_args(args)
    best_rank1 = -np.inf
    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        sys.stdout = Logger(osp.join(test_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
        # print("Currently using GPU {}".format(args.gpu_devices))
        # #cudnn.benchmark = False
        # cudnn.deterministic = True
        # torch.cuda.manual_seed_all(args.seed)
        # torch.set_default_tensor_type('torch.DoubleTensor')
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu

    if 'stanford' in args.dataset:
        datasetLoader = ImageDataset_stanford
    else:
        datasetLoader = ImageDataset
    if args.crop_img:
        print("Using Cropped Images")
    else:
        print("NOT using cropped Images")
    trainloader = DataLoader(
        datasetLoader(dataset.train,
                      -1,
                      crop=args.crop_img,
                      transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    testloader = DataLoader(
        datasetLoader(dataset.test,
                      -1,
                      crop=args.crop_img,
                      transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(
        name=args.arch,
        num_classes=dataset.num_train_pids,
        loss={'xent', 'angular'} if args.use_angular else {'xent'},
        use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    if not args.use_angular:
        if args.label_smooth:
            print("Using Label Smoothing")
            criterion = CrossEntropyLabelSmooth(
                num_classes=dataset.num_train_pids, use_gpu=use_gpu)
        else:
            criterion = nn.CrossEntropyLoss()
    else:
        if args.label_smooth:
            print("Using Label Smoothing")
            criterion = AngularLabelSmooth(num_classes=dataset.num_train_pids,
                                           use_gpu=use_gpu)
        else:
            criterion = AngleLoss()
    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    if args.scheduler != 0:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=args.stepsize,
                                             gamma=args.gamma)

    if args.fixbase_epoch > 0:
        if hasattr(model, 'classifier') and isinstance(model.classifier,
                                                       nn.Module):
            optimizer_tmp = init_optim(
                args.optim,
                list(model.classifier.parameters()) +
                list(model.encoder.parameters()), args.fixbase_lr,
                args.weight_decay)
        else:
            print(
                "Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0"
            )
            args.fixbase_epoch = 0

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        distmat = test(model,
                       testloader,
                       use_gpu,
                       args,
                       writer=None,
                       epoch=-1,
                       return_distmat=True,
                       draw_tsne=args.draw_tsne,
                       tsne_clusters=args.tsne_labels,
                       use_cosine=args.plot_deltaTheta)

        if args.visualize_ranks:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(test_dir, 'ranked_results'),
                topk=10,
            )
        if args.plot_deltaTheta:
            plot_deltaTheta(distmat,
                            dataset,
                            save_dir=osp.join(test_dir, 'deltaTheta_results'),
                            min_rank=1)
        return

    writer = SummaryWriter(log_dir=osp.join(args.save_dir, 'tensorboard'))
    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    if args.test_rot:
        print("Training only classifier for rotation")
        model = models.init_model(name='rot_tester',
                                  base_model=model,
                                  inplanes=2048,
                                  num_rot_classes=8)
        criterion_rot = nn.CrossEntropyLoss()
        optimizer_rot = init_optim(args.optim, model.fc_rot.parameters(),
                                   args.fixbase_lr, args.weight_decay)
        if use_gpu:
            model = nn.DataParallel(model).cuda()
        try:
            best_epoch = 0
            for epoch in range(0, args.max_epoch):
                start_train_time = time.time()
                train_rotTester(epoch, model, criterion_rot, optimizer_rot,
                                trainloader, use_gpu, writer, args)
                train_time += round(time.time() - start_train_time)

                if args.scheduler != 0:
                    scheduler.step()

                if ((epoch + 1) > args.start_eval and args.eval_step > 0 and
                        (epoch + 1) % args.eval_step == 0) or (epoch + 1) == args.max_epoch:
                    if (epoch + 1) == args.max_epoch:
                        if use_gpu:
                            state_dict = model.module.state_dict()
                        else:
                            state_dict = model.state_dict()

                        save_checkpoint(
                            {
                                'state_dict': state_dict,
                                'rank1': -1,
                                'epoch': epoch,
                            }, False,
                            osp.join(
                                args.save_dir, 'beforeTesting_checkpoint_ep' +
                                str(epoch + 1) + '.pth.tar'))
                    print("==> Test")
                    rank1 = test_rotTester(model,
                                           criterion_rot,
                                           queryloader,
                                           galleryloader,
                                           trainloader,
                                           use_gpu,
                                           args,
                                           writer=writer,
                                           epoch=epoch)
                    is_best = rank1 > best_rank1

                    if is_best:
                        best_rank1 = rank1
                        best_epoch = epoch + 1

                    if use_gpu:
                        state_dict = model.module.state_dict()
                    else:
                        state_dict = model.state_dict()

                    save_checkpoint(
                        {
                            'state_dict': state_dict,
                            'rank1': rank1,
                            'epoch': epoch,
                        }, is_best,
                        osp.join(args.save_dir, 'checkpoint_ep' +
                                 str(epoch + 1) + '.pth.tar'))

            print("==> Best Cccuracy {:.1%}, achieved at epoch {}".format(
                best_rank1, best_epoch))

            elapsed = round(time.time() - start_time)
            elapsed = str(datetime.timedelta(seconds=elapsed))
            train_time = str(datetime.timedelta(seconds=train_time))
            print(
                "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}."
                .format(elapsed, train_time))
            return best_rank1, best_epoch
        except KeyboardInterrupt:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': -1,
                    'epoch': epoch,
                }, False,
                osp.join(
                    args.save_dir, 'keyboardInterrupt_checkpoint_ep' +
                    str(epoch + 1) + '.pth.tar'))

        return None, None

    if args.fixbase_epoch > 0:
        print(
            "Train classifier for {} epochs while keeping base network frozen".
            format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion,
                  optimizer_tmp,
                  trainloader,
                  use_gpu,
                  writer,
                  args,
                  freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")
    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu, writer,
              args)
        train_time += round(time.time() - start_train_time)

        if args.scheduler != 0:
            scheduler.step()

        if ((epoch + 1) > args.start_eval and args.eval_step > 0 and
                (epoch + 1) % args.eval_step == 0) or (epoch + 1) == args.max_epoch:
            if (epoch + 1) == args.max_epoch:
                if use_gpu:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()

                save_checkpoint(
                    {
                        'state_dict': state_dict,
                        'rank1': -1,
                        'epoch': epoch,
                    }, False,
                    osp.join(
                        args.save_dir, 'beforeTesting_checkpoint_ep' +
                        str(epoch + 1) + '.pth.tar'))
            print("==> Test")

            rank1 = test(model,
                         testloader,
                         use_gpu,
                         args,
                         writer=writer,
                         epoch=epoch)

            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    return best_rank1, best_epoch
Example #4
def main(args):
    args = parser.parse_args(args)
    best_rank1 = -np.inf
    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        sys.stdout = Logger(osp.join(test_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
        split_wild=args.split_wild)

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomSizedEarser(),
        T.RandomHorizontalFlip_custom(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query,
                     transform=transform_test,
                     return_path=args.draw_tsne),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery,
                     transform=transform_test,
                     return_path=args.draw_tsne),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(
        name=args.arch,
        num_classes=dataset.num_train_pids,
        loss={'xent', 'angular'} if args.use_angular else {'xent'},
        use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    use_autoTune = False
    if not args.use_angular:
        if args.label_smooth:
            print("Using Label Smoothing with epsilon", args.label_epsilon)
            criterion = CrossEntropyLabelSmooth(
                num_classes=dataset.num_train_pids,
                epsilon=args.label_epsilon,
                use_gpu=use_gpu)
        elif args.focal_loss:
            print("Using Focal Loss with gamma=", args.focal_gamma)
            criterion = FocalLoss(gamma=args.focal_gamma)
        else:
            print("Using Normal Cross-Entropy")
            criterion = nn.CrossEntropyLoss()

        if args.jsd:
            print("Using JSD regularizer")
            criterion = (criterion, JSD_loss(dataset.num_train_pids))
            if args.auto_tune_mtl:
                print("Using AutoTune")
                use_autoTune = True
                criterion = MultiHeadLossAutoTune(
                    list(criterion),
                    [args.lambda_xent, args.confidence_beta]).cuda()
        else:
            if args.confidence_penalty:
                print("Using Confidence Penalty", args.confidence_beta)
            criterion = (criterion, ConfidencePenalty())
            if args.auto_tune_mtl and args.confidence_penalty:
                print("Using AutoTune")
                use_autoTune = True
                criterion = MultiHeadLossAutoTune(
                    list(criterion),
                    [args.lambda_xent, -args.confidence_beta]).cuda()
    else:
        if args.label_smooth:
            print("Using Angular Label Smoothing")
            criterion = AngularLabelSmooth(num_classes=dataset.num_train_pids,
                                           use_gpu=use_gpu)

        else:
            print("Using Angular Loss")
            criterion = AngleLoss()
    if use_autoTune:
        optimizer = init_optim(
            args.optim,
            list(model.parameters()) + list(criterion.parameters()), args.lr,
            args.weight_decay)
    else:
        optimizer = init_optim(args.optim, model.parameters(), args.lr,
                               args.weight_decay)
    if args.scheduler:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=args.stepsize,
                                             gamma=args.gamma)

    if args.fixbase_epoch > 0:
        if hasattr(model, 'classifier') and isinstance(model.classifier,
                                                       nn.Module):
            if use_autoTune:
                optimizer_tmp = init_optim(
                    args.optim,
                    list(model.classifier.parameters()) +
                    list(criterion.parameters()), args.fixbase_lr,
                    args.weight_decay)
            else:
                optimizer_tmp = init_optim(args.optim,
                                           model.classifier.parameters(),
                                           args.fixbase_lr, args.weight_decay)
        else:
            print(
                "Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0"
            )
            args.fixbase_epoch = 0

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.single_folder != '':
        extract_features(model,
                         use_gpu,
                         args,
                         transform_test,
                         return_distmat=False)
        return
    if args.evaluate:
        print("Evaluate only")
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       use_gpu,
                       args,
                       writer=None,
                       epoch=-1,
                       return_distmat=True,
                       tsne_clusters=args.tsne_labels)

        if args.visualize_ranks:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(test_dir, 'ranked_results'),
                topk=10,
            )
        return

    writer = SummaryWriter(log_dir=osp.join(args.save_dir, 'tensorboard'))
    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    if args.fixbase_epoch > 0:
        print(
            "Train classifier for {} epochs while keeping base network frozen".
            format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion,
                  optimizer_tmp,
                  trainloader,
                  use_gpu,
                  writer,
                  args,
                  freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")
    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu, writer,
              args)
        train_time += round(time.time() - start_train_time)

        if args.scheduler:
            scheduler.step()

        if (epoch + 1) > args.start_eval and (
            (args.save_epoch > 0 and (epoch + 1) % args.save_epoch == 0) or
            (args.eval_step > 0 and (epoch + 1) % args.eval_step == 0) or
            (epoch + 1) == args.max_epoch):
            if (epoch + 1) == args.max_epoch:
                if use_gpu:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()

                save_checkpoint(
                    {
                        'state_dict': state_dict,
                        'rank1': -1,
                        'epoch': epoch,
                    }, False,
                    osp.join(
                        args.save_dir, 'beforeTesting_checkpoint_ep' +
                        str(epoch + 1) + '.pth.tar'))
            is_best = False
            rank1 = -1
            if args.eval_step > 0:
                print("==> Test")

                rank1 = test(model,
                             queryloader,
                             galleryloader,
                             use_gpu,
                             args,
                             writer=writer,
                             epoch=epoch)

                is_best = rank1 > best_rank1

                if is_best:
                    best_rank1 = rank1
                    best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    return best_rank1, best_epoch
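
ConfidencePenalty above is combined with the negative weight -args.confidence_beta, which matches the output-entropy regularizer of Pereyra et al. (2017): total loss = xent - beta * H(p). A minimal sketch under the assumption that the module returns the mean softmax entropy:

import torch.nn as nn
import torch.nn.functional as F

class ConfidencePenalty(nn.Module):
    """Mean entropy H(p) of the softmax output. Weighted by -beta (as in
    MultiHeadLossAutoTune above), the objective becomes xent - beta * H(p),
    which penalizes over-confident predictions."""

    def forward(self, inputs, targets=None):
        log_probs = F.log_softmax(inputs, dim=1)
        # entropy H(p) = -sum_c p_c * log p_c, averaged over the batch
        return -(log_probs.exp() * log_probs).sum(dim=1).mean()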
Example #5
def main():
    global args, best_rank1

    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_vidreid_dataset(root=args.root,
                                                name=args.dataset)

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu

    # decompose tracklets into images for image-based training
    new_train = []
    for img_paths, pid, camid in dataset.train:
        for img_path in img_paths:
            new_train.append((img_path, pid, camid))

    trainloader = DataLoader(
        ImageDataset(new_train, transform=transform_train),
        sampler=RandomIdentitySampler(new_train, args.train_batch,
                                      args.num_instances),
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        VideoDataset(dataset.query,
                     seq_len=args.seq_len,
                     sample='evenly',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     seq_len=args.seq_len,
                     sample='evenly',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              loss={'xent', 'htri'})
    print("Model size: {:.3f} M".format(count_num_param(model)))

    if args.label_smooth:
        criterion_xent = CrossEntropyLabelSmooth(
            num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    else:
        criterion_xent = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin)

    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       args.pool,
                       use_gpu,
                       return_distmat=True)
        if args.visualize_ranks:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=20,
            )
        return

    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, optimizer,
              trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        if ((epoch + 1) > args.start_eval and args.eval_step > 0 and
                (epoch + 1) % args.eval_step == 0) or (epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
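
RandomIdentitySampler pairs with TripletLoss: each batch holds batch_size // num_instances identities with num_instances images each, so the loss can mine positives and negatives inside the batch. A minimal sketch of that interface, assuming (img_path, pid, camid) items as in new_train above:

import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler

class RandomIdentitySampler(Sampler):
    """Yield dataset indices as num_instances consecutive samples per
    identity, in random identity order."""

    def __init__(self, data_source, batch_size, num_instances):
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_instances = num_instances

    def __iter__(self):
        indices = []
        for pid in random.sample(self.pids, len(self.pids)):
            idxs = self.index_dic[pid]
            if len(idxs) >= self.num_instances:
                idxs = random.sample(idxs, self.num_instances)
            else:  # oversample rare identities with replacement
                idxs = [random.choice(idxs) for _ in range(self.num_instances)]
            indices.extend(idxs)
        return iter(indices)

    def __len__(self):
        return len(self.pids) * self.num_instances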
Example #6
def main():
    global best_rank1, best_mAP
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(
            osp.join(
                args.save_dir,
                'log_train{}.txt'.format(time.strftime('-%Y-%m-%d-%H-%M-%S'))))
    else:
        sys.stdout = Logger(
            osp.join(
                args.save_dir,
                'log_test{}.txt'.format(time.strftime('-%Y-%m-%d-%H-%M-%S'))))
    writer = SummaryWriter(log_dir=args.save_dir, comment=args.arch)
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = 'resnet3dt' not in args.arch
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_vidreid_dataset(root=args.root,
                                                name=args.dataset,
                                                split_id=args.split_id,
                                                use_pose=args.use_pose)

    transform_train = list()
    print('Transform:')
    if args.misalign_aug:
        print('+ Misalign Augmentation')
        transform_train.append(T.GroupMisAlignAugment())
    if args.rand_crop:
        print('+ Random Crop')
        transform_train.append(T.GroupRandomCrop(size=(240, 120)))
    print('+ Resize to ({} x {})'.format(args.height, args.width))
    transform_train.append(T.GroupResize((args.height, args.width)))
    if args.flip_aug:
        print('+ Random HorizontalFlip')
        transform_train.append(T.GroupRandomHorizontalFlip())
    print('+ ToTensor')
    transform_train.append(T.GroupToTensor())
    print(
        '+ Normalize with mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]'
    )
    transform_train.append(
        T.GroupNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224,
                                                          0.225]))
    if args.rand_erase:
        print('+ Random Erasing')
        transform_train.append(T.GroupRandomErasing())
    transform_train = T.Compose(transform_train)

    transform_test = T.Compose([
        T.GroupResize((args.height, args.width)),
        T.GroupToTensor(),
        T.GroupNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224,
                                                          0.225]),
    ])

    pin_memory = use_gpu

    trainloader = DataLoader(
        VideoDataset(dataset.train,
                     seq_len=args.seq_len,
                     sample=args.train_sample,
                     transform=transform_train,
                     training=True,
                     pose_info=dataset.process_poses,
                     num_split=args.num_split,
                     num_parts=args.num_parts,
                     num_scale=args.num_scale,
                     pyramid_part=args.pyramid_part,
                     enable_pose=args.use_pose),
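        # args.train_sampler holds a sampler class name; eval() resolves it to
        # the class object at runtime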
        sampler=eval(args.train_sampler)(dataset.train,
                                         batch_size=args.train_batch,
                                         num_instances=args.num_instances),
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        VideoDataset(dataset.query,
                     seq_len=args.seq_len,
                     sample=args.test_sample,
                     transform=transform_test,
                     pose_info=dataset.process_poses,
                     num_split=args.num_split,
                     num_parts=args.num_parts,
                     num_scale=args.num_scale,
                     pyramid_part=args.pyramid_part,
                     enable_pose=args.use_pose),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     seq_len=args.seq_len,
                     sample=args.test_sample,
                     transform=transform_test,
                     pose_info=dataset.process_poses,
                     num_split=args.num_split,
                     num_parts=args.num_parts,
                     num_scale=args.num_scale,
                     pyramid_part=args.pyramid_part,
                     enable_pose=args.use_pose),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              loss={'xent', 'htri'},
                              last_stride=args.last_stride,
                              num_parts=args.num_parts,
                              num_scale=args.num_scale,
                              num_split=args.num_split,
                              pyramid_part=args.pyramid_part,
                              num_gb=args.num_gb,
                              use_pose=args.use_pose,
                              learn_graph=args.learn_graph,
                              consistent_loss=args.consistent_loss,
                              bnneck=args.bnneck,
                              save_dir=args.save_dir)

    input_size = sum(calc_splits(
        args.num_split)) if args.pyramid_part else args.num_split
    input_size *= args.num_scale * args.seq_len
    num_params, flops = compute_model_complexity(
        model,
        input=[
            torch.randn(1, args.seq_len, 3, args.height, args.width),
            torch.ones(1, input_size, input_size)
        ],
        verbose=True,
        only_conv_linear=False)
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    if args.label_smooth:
        criterion_xent = CrossEntropyLabelSmooth(
            num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    else:
        criterion_xent = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin, soft=args.soft_margin)

    param_groups = model.parameters()
    optimizer = init_optim(args.optim, param_groups, args.lr,
                           args.weight_decay)

    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)
    if args.warmup:
        scheduler = lr_scheduler.WarmupMultiStepLR(optimizer,
                                                   milestones=args.stepsize,
                                                   gamma=args.gamma,
                                                   warmup_iters=10,
                                                   warmup_factor=0.01)

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        print("Loaded checkpoint from '{}'".format(args.resume))
        from functools import partial
        import pickle
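        # checkpoints pickled under Python 2 store strings as bytes; decoding
        # with latin1 lets Python 3 load them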
        pickle.load = partial(pickle.load, encoding="latin1")
        pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage,
                                pickle_module=pickle)

        print('Loaded model weights')
        model.load_state_dict(checkpoint['state_dict'])
        if optimizer is not None and 'optimizer' in checkpoint:
            print('Loaded optimizer')
            optimizer.load_state_dict(checkpoint['optimizer'])
            if use_gpu:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()
        start_epoch = checkpoint['epoch'] + 1
        print('- start_epoch: {}'.format(start_epoch))
        best_rank1 = checkpoint['rank1']
        print("- rank1: {}".format(best_rank1))
        if 'mAP' in checkpoint:
            best_mAP = checkpoint['mAP']
            print("- mAP: {}".format(best_mAP))
    else:
        start_epoch = 0

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       args.pool,
                       use_gpu,
                       return_distmat=True)
        if args.visualize_ranks:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=20,
            )
        return

    start_time = time.time()
    train_time = 0
    best_epoch = start_epoch
    print("==> Start training")

    for epoch in range(start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch,
              model,
              criterion_xent,
              criterion_htri,
              optimizer,
              trainloader,
              use_gpu,
              writer=writer)
        train_time += round(time.time() - start_train_time)

        if epoch >= args.zero_wd > 0:
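            # chained comparison: weight decay is switched off from epoch
            # args.zero_wd onward (args.zero_wd == 0 disables this behavior)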
            set_wd(optimizer, 0)
            for group in optimizer.param_groups:
                assert group['weight_decay'] == 0, '{} is not zero'.format(
                    group['weight_decay'])

        scheduler.step(epoch)

        if ((epoch + 1) > args.start_eval and args.eval_step > 0 and
                (epoch + 1) % args.eval_step == 0) or (epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1, mAP = test(model, queryloader, galleryloader, args.pool,
                              use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_mAP = mAP
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'optimizer': optimizer.state_dict(),
                    'rank1': rank1,
                    'mAP': mAP,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

            writer.add_scalar(tag='acc/rank1',
                              scalar_value=rank1,
                              global_step=epoch + 1)
            writer.add_scalar(tag='acc/mAP',
                              scalar_value=mAP,
                              global_step=epoch + 1)

    print("==> Best Rank-1 {:.2%}, mAP: {:.2%}, achieved at epoch {}".format(
        best_rank1, best_mAP, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
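
set_wd above is assumed to rewrite the weight_decay field of every parameter group, which is what the assertion that follows it checks. A minimal sketch:

def set_wd(optimizer, value):
    """Set the weight decay of all parameter groups of an optimizer."""
    for group in optimizer.param_groups:
        group['weight_decay'] = value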