def set_transform():
    """Build the train/test preprocessing pipelines from the global ``args``.

    The training pipeline optionally appends a random-erasing augmentation
    when ``args.is_REA`` is set; otherwise the two branches were identical,
    so the shared steps are built once here instead of being duplicated.

    Returns:
        tuple: ``(transform_train, transform_test)`` — two ``T.Compose``
        pipelines producing ImageNet-normalized tensors.
    """
    # Normalize with the standard ImageNet statistics (stateless, safe to share).
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])

    train_ops = [
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
    ]
    if args.is_REA:
        # NOTE(review): "RandomEraising" looks like a typo of torchvision's
        # RandomErasing, but the name must match whatever the project's ``T``
        # module actually exports — confirm before renaming.
        train_ops.append(T.RandomEraising())
    train_ops += [T.ToTensor(), normalize]
    transform_train = T.Compose(train_ops)

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        normalize,
    ])

    return transform_train, transform_test
# Ejemplo n.º 2 (scrape-artifact separator — not executable code)
# 0
# Archivo: demo.py  Proyecto: lihao056/BOT
def chart_recognition(model_chart, img_dir):
    """Classify the image(s) under ``img_dir`` with ``model_chart``.

    Args:
        model_chart: trained classification model (already on its device).
        img_dir: path consumed by ``ImageDataset_demo``; iterated one image
            at a time (batch_size=1).

    Returns:
        The label for the *last* image processed, looked up through the
        module-level ``num2label`` mapping, or ``None`` when the directory
        yields no images.
    """
    # img = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
    transform = T.Compose([
        T.Resize((256, 128)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    loader = DataLoader(
        ImageDataset_demo(img_dir, transform=transform),
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    model_chart.eval()
    # BUG FIX: initialize the result so an empty directory no longer raises
    # UnboundLocalError at the final ``return kind``.
    kind = None
    with torch.no_grad():
        for batch_idx, img2 in enumerate(loader):
            if torch.cuda.is_available():
                img2 = img2.cuda()
            score = model_chart(img2)
            print(score)
            # argmax over the class dimension; batch_size is 1 so take item 0.
            chart = torch.argmax(score.data, 1)
            chart = chart[0].cpu().numpy()
            print(chart)
            kind = num2label[str(chart)]

    return kind
# Ejemplo n.º 3 (scrape-artifact separator — not executable code)
# 0
def main():
    """Extract deep features for the val/test query and gallery splits.

    Loads the best checkpoint from ``./model``, strips the classifier head so
    the network emits raw embeddings, and writes one ``.mat`` file of
    ``{names, features}`` per (dataset, subset) pair under ``log_dir``.
    """
    eval_transform = T.Compose([
        T.Resize((256, 128)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    gpu_available = torch.cuda.is_available()

    net = models.init_model(name=args.arch, num_classes=751, loss={'xent'})
    ckpt = torch.load(os.path.join('./model', 'best_model.pth.tar'))
    net.load_state_dict(ckpt['state_dict'])
    # Replace the classification head with an identity so forward() returns
    # the feature vector instead of class logits.
    net.classifier = nn.Sequential()
    if gpu_available:
        net = nn.DataParallel(net).cuda()
    net.eval()

    for split in ['val', 'test']:
        for part in ['query', 'gallery']:
            loader = DataLoader(
                Dataset(split + '/' + part, transform=eval_transform))
            names, feats = extractor(net, loader)
            payload = {'names': names, 'features': feats.numpy()}
            out_path = os.path.join('log_dir',
                                    'feature_%s_%s.mat' % (split, part))
            scipy.io.savemat(out_path, payload)
# Ejemplo n.º 4 (scrape-artifact separator — not executable code)
# 0
def main():
    """Train (or evaluate) a person re-id model with a cross-entropy loss.

    All configuration comes from the module-level ``args`` namespace: dataset
    choice and splits, augmentation sizes, optimizer/scheduler settings, an
    optional warm-up phase that trains only the classifier ("fixbase"),
    pretrained-weight loading, checkpoint resuming, and evaluate-only mode.

    Side effects: redirects ``sys.stdout`` to a log file and writes
    checkpoints under ``args.save_dir``.
    """
    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    # Log to a different file depending on train vs. test mode.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )

    # Training uses translation + flip augmentation; test only resizes.
    # Both normalize with the standard ImageNet statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,  # keep batch statistics stable for BN layers
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              loss={'xent'},
                              use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    if args.label_smooth:
        criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids,
                                            use_gpu=use_gpu)
    else:
        criterion = nn.CrossEntropyLoss()
    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    # Optional warm-up: train only the classifier while the base is frozen.
    if args.fixbase_epoch > 0:
        if hasattr(model, 'classifier') and isinstance(model.classifier,
                                                       nn.Module):
            optimizer_tmp = init_optim(args.optim,
                                       model.classifier.parameters(),
                                       args.fixbase_lr, args.weight_decay)
        else:
            print(
                "Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0"
            )
            args.fixbase_epoch = 0

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    # BUG FIX: ``best_rank1`` must exist before the training loop compares
    # against it — it was previously assigned only inside the resume branch
    # below, causing a NameError at the first evaluation when not resuming.
    best_rank1 = -np.inf

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       use_gpu,
                       return_distmat=True)
        if args.vis_ranked_res:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=20,
            )
        return

    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    if args.fixbase_epoch > 0:
        print(
            "Train classifier for {} epochs while keeping base network frozen".
            format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion,
                  optimizer_tmp,
                  trainloader,
                  use_gpu,
                  freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate periodically (every eval_step epochs after start_eval)
        # and always on the final epoch.
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # DataParallel wraps the model, so unwrap via .module when on GPU.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
def main():
    """Train/evaluate a vehicle re-id model with a coupled-clusters sampler.

    Configuration comes from the module-level ``args`` namespace plus the
    module-level ``n_cls``, ``n_samples`` and ``checkpoint_suffix`` values
    (defined elsewhere in this file — not visible here; verify before reuse).
    Datasets are plain ``ImageFolder`` directories under ``args.data_dir``.
    Side effects: redirects ``sys.stdout`` to a logger and writes checkpoints
    under ``args.save_dir``.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    logger_info = LoggerInfo()
    sys.stdout = Logger(logger_info)
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    # print("Initializing dataset {}".format(args.dataset))
    # dataset = data_manager.init_imgreid_dataset(
    #     root=args.root, name=args.dataset, split_id=args.split_id,
    #     cuhk03_labeled=args.cuhk03_labeled,
    #     cuhk03_classic_split=args.cuhk03_classic_split,
    # )

    # Training uses translation + flip augmentation; test only resizes.
    # Both normalize with the standard ImageNet statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    # NOTE(review): train_query/train_gallery use the *training* transform
    # (with random augmentation) even though they are only consumed by test()
    # below — presumably intentional for train-set retrieval; confirm.
    train_dataset = ImageFolder(os.path.join(args.data_dir, "train_all"),
                                transform=transform_train)
    train_query_dataset = ImageFolder(os.path.join(args.data_dir, "val"),
                                      transform=transform_train)
    train_gallery_dataset = ImageFolder(os.path.join(args.data_dir, "train"),
                                        transform=transform_train)
    test_dataset = ImageFolder(os.path.join(args.data_dir, "probe"),
                               transform=transform_test)
    query_dataset = ImageFolder(os.path.join(args.data_dir, "query"),
                                transform=transform_test)
    gallery_dataset = ImageFolder(os.path.join(args.data_dir, "gallery"),
                                  transform=transform_test)

    # train_batch_sampler = VehicleIdBalancedBatchSampler(train_dataset, n_classes=8, n_samples=6)
    # test_batch_sampler = VehicleIdBalancedBatchSampler(test_dataset, n_classes=8, n_samples=8)

    # Identity-balanced batches: n_cls identities x n_samples images each
    # (n_cls / n_samples are module-level globals).
    train_batch_sampler = VehicleIdCCLBatchSampler(train_dataset,
                                                   n_classes=n_cls,
                                                   n_samples=n_samples)

    trainloader = DataLoader(train_dataset,
                             batch_sampler=train_batch_sampler,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    testloader = DataLoader(test_dataset,
                            batch_size=args.test_batch,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=pin_memory,
                            drop_last=False)

    # trainloader = DataLoader(
    #     ImageDataset(dataset.train, transform=transform_train),
    #     batch_sampler=train_batch_sampler, batch_size=args.train_batch,
    #     shuffle=True, num_workers=args.workers, pin_memory=pin_memory, drop_last=True
    # )

    queryloader = DataLoader(
        query_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        gallery_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    train_query_loader = DataLoader(
        train_query_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    train_gallery_loader = DataLoader(
        train_gallery_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))

    # The classifier head size depends on which split's class count is used;
    # evaluation builds the head from the query split's classes.
    if args.evaluate:
        model = models.init_model(name=args.arch,
                                  num_classes=len(query_dataset.classes),
                                  loss_type=args.loss_type)
    else:
        model = models.init_model(name=args.arch,
                                  num_classes=len(train_dataset.classes),
                                  loss_type=args.loss_type)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    # Select the training criterion from the configured loss type.
    if args.label_smooth:
        criterion = CrossEntropyLabelSmooth(num_classes=len(
            train_dataset.classes),
                                            use_gpu=use_gpu)
    else:
        if args.loss_type == 'xent':
            criterion = nn.CrossEntropyLoss()
        elif args.loss_type == 'angle':
            criterion = AngleLoss()
        elif args.loss_type == 'triplet':
            # criterion = CoupledClustersLoss(margin=1., triplet_selector=RandomNegativeTripletSelector(margin=1.))
            # criterion = OnlineTripletLoss(margin=1., triplet_selector=RandomNegativeTripletSelector(margin=1.))
            # criterion = OnlineTripletLoss(margin=1., triplet_selector=HardestNegativeTripletSelector(margin=1.))
            criterion = CoupledClustersLoss(margin=1.,
                                            n_classes=n_cls,
                                            n_samples=n_samples)
            # criterion = OnlineTripletLoss(margin=1., triplet_selector=SemihardNegativeTripletSelector(margin=1.))
        elif args.loss_type == 'xent_htri':
            criterion = XentTripletLoss(
                margin=1.,
                triplet_selector=RandomNegativeTripletSelector(margin=1.))
        else:
            raise KeyError("Unsupported loss: {}".format(args.loss_type))
    # model_param_list = [{'params': model.base.parameters(), 'lr': args.lr},
    #                     {'params': model.classifier.parameters(), 'lr': args.lr * 10}]
    # optimizer = init_optim(args.optim, model_param_list, lr=1.0, weight_decay=args.weight_decay)
    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    # Optional warm-up: train only the classifier while the base is frozen.
    if args.fixbase_epoch > 0:
        if hasattr(model, 'classifier') and isinstance(model.classifier,
                                                       nn.Module):
            optimizer_tmp = init_optim(args.optim,
                                       model.classifier.parameters(),
                                       args.fixbase_lr, args.weight_decay)
        else:
            print(
                "Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0"
            )
            args.fixbase_epoch = 0

    if args.load_weights:
        # load pretrained weights but ignore layers that don't match in size
        if check_isfile(args.load_weights):
            checkpoint = torch.load(args.load_weights)
            pretrain_dict = checkpoint['state_dict']
            model_dict = model.state_dict()
            pretrain_dict = {
                k: v
                for k, v in pretrain_dict.items()
                if k in model_dict and model_dict[k].size() == v.size()
            }
            model_dict.update(pretrain_dict)
            model.load_state_dict(model_dict)
            print("Loaded pretrained weights from '{}'".format(
                args.load_weights))

    if args.resume:
        # HACK: globally monkey-patches pickle so Python-2 checkpoints
        # (latin1-encoded) can be loaded. This patch stays in effect for the
        # rest of the process.
        from functools import partial
        import pickle
        pickle.load = partial(pickle.load, encoding="latin1")
        pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
        if check_isfile(args.resume):
            checkpoint = torch.load(args.resume)
            # checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage, pickle_module=pickle)
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch'] + 1
            rank1 = checkpoint['rank1']
            print("Loaded checkpoint from '{}'".format(args.resume))
            print("- start_epoch: {}\n- rank1: {}".format(
                args.start_epoch, rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    # if args.evaluate:
    #     print("Evaluate only")
    #     distmat = test(model, queryloader, galleryloader, train_query_loader, train_gallery_loader,
    #                    use_gpu, return_distmat=True)
    #     if args.vis_ranked_res:
    #         visualize_ranked_results(
    #             distmat, dataset,
    #             save_dir=osp.join(args.save_dir, 'ranked_results'),
    #             topk=20,
    #         )
    #     return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    if args.fixbase_epoch > 0:
        print(
            "Train classifier for {} epochs while keeping base network frozen".
            format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion,
                  optimizer_tmp,
                  trainloader,
                  use_gpu,
                  freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate periodically (every eval_step epochs after start_eval)
        # and always on the final epoch.
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, testloader, queryloader, galleryloader,
                         train_query_loader, train_gallery_loader, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # DataParallel wraps the model, so unwrap via .module when on GPU.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            # ``checkpoint_suffix`` is a module-level global (defined
            # elsewhere in this file).
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                },
                is_best,
                use_gpu_suo=True,
                fpath=osp.join(
                    args.save_dir, 'checkpoint_ep' + str(epoch + 1) +
                    checkpoint_suffix + '.pth.tar'))

    print("==> Best Rank-1 {:.2%}, achieved at epoch {}".format(
        best_rank1, best_epoch))
    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
# Ejemplo n.º 6 (scrape-artifact separator — not executable code)
# 0
def main():
    """Train (or evaluate) a Market-1501 re-id model with xent + MSE losses.

    Configuration comes from the module-level ``args`` namespace. In
    ``args.evaluate`` mode only the test loaders and checkpoint are loaded;
    otherwise the function runs the full training loop and writes checkpoints
    under ``args.save_dir``. Side effect: redirects ``sys.stdout`` to a log
    file.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()

    # Log to a different file depending on train vs. test mode.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))

    print("==========\nArgs:{}\n==========".format(args))

    dataset = mydataset.Market1501(root=args.root, split_id=0)

    # Training uses translation + flip augmentation; test only resizes.
    # Both normalize with the standard ImageNet statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    queryloader = DataLoader(
        ImageDatasettest(dataset.query, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDatasettest(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    # BUG FIX: the MSE loss was unconditionally moved to the GPU
    # (``nn.MSELoss().cuda()``), which crashes on CPU-only machines even
    # though the rest of this function guards on ``use_gpu``.
    cri = nn.MSELoss().cuda() if use_gpu else nn.MSELoss()
    criterion = nn.CrossEntropyLoss()
    model = torchreid.resnet_person.net2(num_classes=dataset.num_train_pids)

    if args.evaluate:

        print("Evaluate only")
        checkpoint = torch.load(args.testmodel)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        rank1 = checkpoint['rank1']
        print("rank1: {}".format(rank1))
        if use_gpu:
            model = nn.DataParallel(model).cuda()
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       use_gpu,
                       return_distmat=True)
        return

    trainloader = DataLoader(
        ImageDatasettrain(dataset.train, args.height, args.width),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,  # keep batch statistics stable for BN layers
    )
    if use_gpu:
        model = nn.DataParallel(model).cuda()
    optimizer = init_optim(args.optim, model.parameters(), args.lr, 5e-04)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=0.1)

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(args.max_epoch):
        start_train_time = time.time()

        train(epoch, model, criterion, cri, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate every eval_step epochs and always on the final epoch.
        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (
                epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # DataParallel wraps the model, so unwrap via .module when on GPU.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
def main():
    """Evaluate a two-branch re-id system: a plate model plus a vehicle model.

    Loads two datasets ("plt" = license plate, "vecl" = vehicle) and two
    matching models from the module-level ``args``, optionally restores both
    from checkpoints, then runs joint evaluation via ``test``. The
    commented-out blocks are a disabled horizontal-flip test-time-augmentation
    path; loaders are wrapped in single-element lists to keep the ``test``
    interface compatible with the multi-loader TTA variant.
    Side effect: redirects ``sys.stdout`` to a logger.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    logger_info = LoggerInfo()
    sys.stdout = Logger(logger_info)
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("\nInitializing dataset {}".format(args.dataset_plt))
    dataset_plt = data_manager.init_imgreid_dataset(root=args.root,
                                                    name=args.dataset_plt)
    print("\nInitializing dataset {}".format(args.dataset_vecl))
    dataset_vecl = data_manager.init_imgreid_dataset(root=args.root,
                                                     name=args.dataset_vecl)

    # Each branch has its own input resolution; both normalize with the
    # standard ImageNet statistics.
    transform_test_plt = T.Compose([
        T.Resize((args.height_plt, args.width_plt)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # transform_flip_test_plt = T.Compose([
    #     T.Resize((args.height_plt, args.width_plt)),
    #     functional.hflip,
    #     T.ToTensor(),
    #     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # ])
    transform_test_vecl = T.Compose([
        T.Resize((args.height_vecl, args.width_vecl)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # transform_flip_test_vecl = T.Compose([
    #     T.Resize((args.height_vecl, args.width_vecl)),
    #     functional.hflip,
    #     T.ToTensor(),
    #     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # ])

    pin_memory = True if use_gpu else False

    queryloader_plt = DataLoader(
        ImageDatasetV2(dataset_plt.query, transform=transform_test_plt),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    # queryloader_flip_plt = DataLoader(
    #     ImageDatasetV2(dataset_plt.query, transform=transform_flip_test_plt),
    #     batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
    #     pin_memory=pin_memory, drop_last=False,
    # )
    # queryloader_plt = [queryloader_plt, queryloader_flip_plt]
    # Wrapped in a list: test() expects a list of loaders (TTA-capable API).
    queryloader_plt = [queryloader_plt]
    galleryloader_plt = DataLoader(
        ImageDatasetV2(dataset_plt.gallery, transform=transform_test_plt),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    # galleryloader_flip_plt = DataLoader(
    #     ImageDatasetV2(dataset_plt.gallery, transform=transform_flip_test_plt),
    #     batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
    #     pin_memory=pin_memory, drop_last=False,
    # )
    # galleryloader_plt = [galleryloader_plt, galleryloader_flip_plt]
    galleryloader_plt = [galleryloader_plt]

    # Vehicle loaders additionally return the image name (with_image_name).
    queryloader_vecl = DataLoader(
        ImageDatasetWGL(dataset_vecl.query,
                        transform=transform_test_vecl,
                        with_image_name=True),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    # queryloader_flip_vecl = DataLoader(
    #     ImageDatasetV2(dataset_vecl.query, transform=transform_flip_test_vecl),
    #     batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
    #     pin_memory=pin_memory, drop_last=False,
    # )
    # queryloader_vecl = [queryloader_vecl, queryloader_flip_vecl]
    queryloader_vecl = [queryloader_vecl]
    galleryloader_vecl = DataLoader(
        ImageDatasetWGL(dataset_vecl.gallery,
                        transform=transform_test_vecl,
                        with_image_name=True),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    # galleryloader_flip_vecl = DataLoader(
    #     ImageDatasetV2(dataset_vecl.gallery, transform=transform_flip_test_vecl),
    #     batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
    #     pin_memory=pin_memory, drop_last=False,
    # )
    # galleryloader_vecl = [galleryloader_vecl, galleryloader_flip_vecl]
    galleryloader_vecl = [galleryloader_vecl]

    print("\nInitializing model: {}".format(args.arch))
    model_plt = models.init_model(name=args.arch_plt,
                                  num_classes=dataset_plt.num_train_pids,
                                  loss_type=args.loss_type)
    model_vecl = models.init_model(name=args.arch_vecl,
                                   num_classes=dataset_vecl.num_train_pids,
                                   loss_type=args.loss_type)
    print("Plate model size: {:.3f} M".format(count_num_param(model_plt)))
    print("Vehicle model size: {:.3f} M".format(count_num_param(model_vecl)))

    if args.loss_type == 'xent':
        criterion = nn.CrossEntropyLoss()
    else:
        raise KeyError("Unsupported loss: {}".format(args.loss_type))

    # Restore both models from checkpoints, keeping only weight tensors whose
    # name and shape match the freshly built models (partial loading).
    if args.resm_plt and args.resm_vecl:
        if check_isfile(args.resm_plt) and check_isfile(args.resm_vecl):
            ckpt_plt = torch.load(args.resm_plt)
            pre_dic_plt = ckpt_plt['state_dict']

            model_dic_plt = model_plt.state_dict()
            pre_dic_plt = {
                k: v
                for k, v in pre_dic_plt.items()
                if k in model_dic_plt and model_dic_plt[k].size() == v.size()
            }
            model_dic_plt.update(pre_dic_plt)
            model_plt.load_state_dict(model_dic_plt)
            args.start_epoch_plt = ckpt_plt['epoch']
            rank1_plt = ckpt_plt['rank1']

            ckpt_vecl = torch.load(args.resm_vecl)
            pre_dic_vecl = ckpt_vecl['state_dict']

            model_dic_vecl = model_vecl.state_dict()
            pre_dic_vecl = {
                k: v
                for k, v in pre_dic_vecl.items() if k in model_dic_vecl
                and model_dic_vecl[k].size() == v.size()
            }
            model_dic_vecl.update(pre_dic_vecl)
            model_vecl.load_state_dict(model_dic_vecl)
            args.start_epoch_vecl = ckpt_vecl['epoch']
            rank1_vecl = ckpt_vecl['rank1']

            print("\nLoaded checkpoint from '{}' \nand '{}".format(
                args.resm_plt, args.resm_vecl))
            print("Plate model: start_epoch: {}, rank1: {}".format(
                args.start_epoch_plt, rank1_plt))
            print("Vehicle model: start_epoch: {}, rank1: {}".format(
                args.start_epoch_vecl, rank1_vecl))

    if use_gpu:
        model_plt = nn.DataParallel(model_plt).cuda()
        model_vecl = nn.DataParallel(model_vecl).cuda()

    if args.evaluate:
        print("\nEvaluate only")
        test(model_plt, model_vecl, queryloader_plt, queryloader_vecl,
             galleryloader_plt, galleryloader_vecl, use_gpu)
        return
# Ejemplo n.º 8 (scrape-artifact separator — not executable code)
# 0
def main():
    """Train or evaluate an image re-id model with xent + triplet + KA losses.

    Configuration comes from the module-level ``args`` namespace; logs and
    checkpoints are written under ``args.save_dir``.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    # Mirror stdout into a persistent log file (train/test logs kept apart).
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )  # cuhk03_labeled: detected,labeled

    # Training augmentation: random translation + horizontal flip;
    # both pipelines normalize with ImageNet statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu

    # Identity-balanced sampling is what the triplet-style losses rely on.
    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        sampler=RandomIdentitySampler(dataset.train, args.train_batch,
                                      args.num_instances),
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              batchsize=args.test_batch,
                              loss={'xent', 'wcont', 'htri'})
    print("Model size: {:.3f} M".format(count_num_param(model)))

    criterion_xent = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin)
    criterion_KA = KALoss(margin=args.margin,
                          same_margin=args.same_margin,
                          use_auto_samemargin=args.use_auto_samemargin)
    criterion_lifted = LiftedLoss(margin=args.margin)
    criterion_batri = BA_TripletLoss(margin=args.margin)

    if args.use_auto_samemargin:
        # The learnable same-margin parameter is optimized jointly with the
        # model, so it gets its own parameter group.
        G_params = [{
            'params': model.parameters(),
            'lr': args.lr
        }, {
            'params': criterion_KA.auto_samemargin,
            'lr': args.lr
        }]
    else:
        G_params = [para for _, para in model.named_parameters()]

    optimizer = init_optim(args.optim, G_params, args.lr, args.weight_decay)

    if args.load_weights:
        # Load pretrained weights but ignore layers that don't match in size.
        if check_isfile(args.load_weights):
            checkpoint = torch.load(args.load_weights)
            pretrain_dict = checkpoint['state_dict']
            model_dict = model.state_dict()
            pretrain_dict = {
                k: v
                for k, v in pretrain_dict.items()
                if k in model_dict and model_dict[k].size() == v.size()
            }
            model_dict.update(pretrain_dict)
            model.load_state_dict(model_dict)
            print("Loaded pretrained weights from '{}'".format(
                args.load_weights))

    if args.resume:
        if check_isfile(args.resume):
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            # Checkpoints store the 0-based index of the epoch that was just
            # completed, so resume from the following one (consistent with the
            # other training entry points in this file, which use epoch + 1).
            args.start_epoch = checkpoint['epoch'] + 1
            rank1 = checkpoint['rank1']
            print("Loaded checkpoint from '{}'".format(args.resume))
            print("- start_epoch: {}\n- rank1: {}".format(
                args.start_epoch, rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       use_gpu,
                       return_distmat=True)
        if args.vis_ranked_res:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=20,
            )
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        adjust_learning_rate(optimizer, epoch)
        train(epoch, model, criterion_batri, criterion_lifted, criterion_xent,
              criterion_htri, criterion_KA, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        # Evaluate every eval_step epochs once past start_eval, and always at
        # the final epoch.
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            print("==> Test")
            sys.stdout.flush()
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # Save the unwrapped weights when running under DataParallel.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

            print("model saved")

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))
    sys.stdout.flush()

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    sys.stdout.flush()
# Ejemplo n.º 9  (scrape-site separator, kept as a comment so the file parses)
# 0
def main(args):
    """Train or evaluate a re-id classifier (plain or angular loss).

    Parameters
    ----------
    args : list of str
        Raw command-line arguments; parsed with the module-level ``parser``.

    Returns
    -------
    tuple
        ``(best_rank1, best_epoch)`` after training, ``(None, None)`` if the
        rot-tester path is interrupted; the evaluate-only path returns ``None``.
    """
    args = parser.parse_args(args)
    #global best_rank1
    best_rank1 = -np.inf
    torch.manual_seed(args.seed)
    # np.random.seed(args.seed)
    # random.seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    # Log next to the checkpoint being tested when evaluating with the
    # default save dir; otherwise log into args.save_dir.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        sys.stdout = Logger(osp.join(test_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
        # print("Currently using GPU {}".format(args.gpu_devices))
        # #cudnn.benchmark = False
        # cudnn.deterministic = True
        # torch.cuda.manual_seed_all(args.seed)
        # torch.set_default_tensor_type('torch.DoubleTensor')
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )

    # Training augmentation: random translation + flip; ImageNet statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        #T.Resize((args.height, args.width)),
        #T.RandomSizedEarser(),
        T.RandomHorizontalFlip(),
        #T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    # Stanford datasets use a dedicated dataset class; others the generic one.
    if 'stanford' in args.dataset:
        datasetLoader = ImageDataset_stanford
    else:
        datasetLoader = ImageDataset
    if args.crop_img:
        print("Using Cropped Images")
    else:
        print("NOT using cropped Images")
    trainloader = DataLoader(
        datasetLoader(dataset.train,
                      -1,
                      crop=args.crop_img,
                      transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    testloader = DataLoader(
        datasetLoader(dataset.test,
                      -1,
                      crop=args.crop_img,
                      transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(
        name=args.arch,
        num_classes=dataset.num_train_pids,
        loss={'xent', 'angular'} if args.use_angular else {'xent'},
        use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    # Criterion selection: label smoothing is available for both the plain
    # cross-entropy and the angular variants.
    if not (args.use_angular):
        if args.label_smooth:
            print("Using Label Smoothing")
            criterion = CrossEntropyLabelSmooth(
                num_classes=dataset.num_train_pids, use_gpu=use_gpu)
        else:
            criterion = nn.CrossEntropyLoss()
    else:
        if args.label_smooth:
            print("Using Label Smoothing")
            criterion = AngularLabelSmooth(num_classes=dataset.num_train_pids,
                                           use_gpu=use_gpu)
        else:
            criterion = AngleLoss()
    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    if args.scheduler != 0:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=args.stepsize,
                                             gamma=args.gamma)

    if args.fixbase_epoch > 0:
        # Separate optimizer used while the base network stays frozen.
        # NOTE(review): assumes the model exposes both ``classifier`` and
        # ``encoder`` sub-modules -- confirm for every arch this runs with.
        if hasattr(model, 'classifier') and isinstance(model.classifier,
                                                       nn.Module):
            optimizer_tmp = init_optim(
                args.optim,
                list(model.classifier.parameters()) +
                list(model.encoder.parameters()), args.fixbase_lr,
                args.weight_decay)
        else:
            print(
                "Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0"
            )
            args.fixbase_epoch = 0

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        # Checkpoints store the epoch just completed; resume at the next one.
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        distmat = test(model,
                       testloader,
                       use_gpu,
                       args,
                       writer=None,
                       epoch=-1,
                       return_distmat=True,
                       draw_tsne=args.draw_tsne,
                       tsne_clusters=args.tsne_labels,
                       use_cosine=args.plot_deltaTheta)

        if args.visualize_ranks:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(test_dir, 'ranked_results'),
                topk=10,
            )
        if args.plot_deltaTheta:
            plot_deltaTheta(distmat,
                            dataset,
                            save_dir=osp.join(test_dir, 'deltaTheta_results'),
                            min_rank=1)
        return

    writer = SummaryWriter(log_dir=osp.join(args.save_dir, 'tensorboard'))
    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    if args.test_rot:
        # Rotation-probe path: wrap the model in a rot_tester head and train
        # only its rotation classifier (fc_rot).
        print("Training only classifier for rotation")
        model = models.init_model(name='rot_tester',
                                  base_model=model,
                                  inplanes=2048,
                                  num_rot_classes=8)
        criterion_rot = nn.CrossEntropyLoss()
        optimizer_rot = init_optim(args.optim, model.fc_rot.parameters(),
                                   args.fixbase_lr, args.weight_decay)
        if use_gpu:
            model = nn.DataParallel(model).cuda()
        try:
            best_epoch = 0
            for epoch in range(0, args.max_epoch):
                start_train_time = time.time()
                train_rotTester(epoch, model, criterion_rot, optimizer_rot,
                                trainloader, use_gpu, writer, args)
                train_time += round(time.time() - start_train_time)

                if args.scheduler != 0:
                    scheduler.step()

                if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                        epoch + 1) % args.eval_step == 0 or (
                            epoch + 1) == args.max_epoch:
                    if (epoch + 1) == args.max_epoch:
                        if use_gpu:
                            state_dict = model.module.state_dict()
                        else:
                            state_dict = model.state_dict()

                        save_checkpoint(
                            {
                                'state_dict': state_dict,
                                'rank1': -1,
                                'epoch': epoch,
                            }, False,
                            osp.join(
                                args.save_dir, 'beforeTesting_checkpoint_ep' +
                                str(epoch + 1) + '.pth.tar'))
                    print("==> Test")
                    # NOTE(review): ``queryloader`` and ``galleryloader`` are
                    # never defined in this function (only ``testloader`` is),
                    # so this call raises NameError when the test_rot path
                    # reaches an evaluation epoch -- needs fixing.
                    rank1 = test_rotTester(model,
                                           criterion_rot,
                                           queryloader,
                                           galleryloader,
                                           trainloader,
                                           use_gpu,
                                           args,
                                           writer=writer,
                                           epoch=epoch)
                    is_best = rank1 > best_rank1

                    if is_best:
                        best_rank1 = rank1
                        best_epoch = epoch + 1

                    if use_gpu:
                        state_dict = model.module.state_dict()
                    else:
                        state_dict = model.state_dict()

                    save_checkpoint(
                        {
                            'state_dict': state_dict,
                            'rank1': rank1,
                            'epoch': epoch,
                        }, is_best,
                        osp.join(args.save_dir, 'checkpoint_ep' +
                                 str(epoch + 1) + '.pth.tar'))

            # NOTE(review): "Cccuracy" is a typo for "Accuracy" in this
            # user-facing message (runtime string, left unchanged here).
            print("==> Best Cccuracy {:.1%}, achieved at epoch {}".format(
                best_rank1, best_epoch))

            elapsed = round(time.time() - start_time)
            elapsed = str(datetime.timedelta(seconds=elapsed))
            train_time = str(datetime.timedelta(seconds=train_time))
            print(
                "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}."
                .format(elapsed, train_time))
            return best_rank1, best_epoch
        except KeyboardInterrupt:
            # Persist current weights when the run is interrupted manually.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': -1,
                    'epoch': epoch,
                }, False,
                osp.join(
                    args.save_dir, 'keyboardInterrupt_checkpoint_ep' +
                    str(epoch + 1) + '.pth.tar'))

        return None, None

    if args.fixbase_epoch > 0:
        # Warm up the classifier with the base network frozen, then discard
        # the temporary optimizer.
        print(
            "Train classifier for {} epochs while keeping base network frozen".
            format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion,
                  optimizer_tmp,
                  trainloader,
                  use_gpu,
                  writer,
                  args,
                  freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")
    best_epoch = 0
    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu, writer,
              args)
        train_time += round(time.time() - start_train_time)

        if args.scheduler != 0:
            scheduler.step()

        # Evaluate every eval_step epochs (after start_eval) and at the end.
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            if (epoch + 1) == args.max_epoch:
                if use_gpu:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()

                save_checkpoint(
                    {
                        'state_dict': state_dict,
                        'rank1': -1,
                        'epoch': epoch,
                    }, False,
                    osp.join(
                        args.save_dir, 'beforeTesting_checkpoint_ep' +
                        str(epoch + 1) + '.pth.tar'))
            print("==> Test")

            rank1 = test(model,
                         testloader,
                         use_gpu,
                         args,
                         writer=writer,
                         epoch=epoch)

            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    return best_rank1, best_epoch
def main(args):
    """Train or evaluate a query/gallery re-id model with optional JSD /
    confidence-penalty regularizers and auto-tuned multi-head loss weights.

    Parameters
    ----------
    args : list of str
        Raw command-line arguments; parsed with the module-level ``parser``.

    Returns
    -------
    tuple
        ``(best_rank1, best_epoch)`` after training; the feature-extraction
        and evaluate-only paths return ``None``.
    """
    args = parser.parse_args(args)
    #global best_rank1
    best_rank1 = -np.inf
    torch.manual_seed(args.seed)
    # np.random.seed(args.seed)
    # random.seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    # Log next to the checkpoint being tested when evaluating with the
    # default save dir; otherwise log into args.save_dir.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        sys.stdout = Logger(osp.join(test_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
        split_wild=args.split_wild)

    # Training augmentation: translation + random erasing + custom flip;
    # ImageNet normalization statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        #T.Resize((args.height, args.width)),
        T.RandomSizedEarser(),
        T.RandomHorizontalFlip_custom(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    # Query/gallery loaders also return image paths when t-SNE drawing is on.
    queryloader = DataLoader(
        ImageDataset(dataset.query,
                     transform=transform_test,
                     return_path=args.draw_tsne),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery,
                     transform=transform_test,
                     return_path=args.draw_tsne),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(
        name=args.arch,
        num_classes=dataset.num_train_pids,
        loss={'xent', 'angular'} if args.use_angular else {'xent'},
        use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    # Criterion selection.  With JSD or confidence penalty the criterion
    # becomes a (classification, regularizer) tuple; with auto-tune it is
    # wrapped in MultiHeadLossAutoTune, whose weights are trainable.
    use_autoTune = False
    if not (args.use_angular):
        if args.label_smooth:
            print("Using Label Smoothing with epsilon", args.label_epsilon)
            criterion = CrossEntropyLabelSmooth(
                num_classes=dataset.num_train_pids,
                epsilon=args.label_epsilon,
                use_gpu=use_gpu)
        elif args.focal_loss:
            print("Using Focal Loss with gamma=", args.focal_gamma)
            criterion = FocalLoss(gamma=args.focal_gamma)
        else:
            print("Using Normal Cross-Entropy")
            criterion = nn.CrossEntropyLoss()

        if args.jsd:
            print("Using JSD regularizer")
            criterion = (criterion, JSD_loss(dataset.num_train_pids))
            if args.auto_tune_mtl:
                print("Using AutoTune")
                use_autoTune = True
                # NOTE(review): .cuda() is called unconditionally here even
                # when use_gpu is False -- confirm this path is GPU-only.
                criterion = MultiHeadLossAutoTune(
                    list(criterion),
                    [args.lambda_xent, args.confidence_beta]).cuda()
        else:
            # NOTE(review): the (criterion, ConfidencePenalty()) tuple is
            # built even when confidence_penalty is off -- presumably train()
            # ignores the second element in that case; verify.
            if args.confidence_penalty:
                print("Using Confidence Penalty", args.confidence_beta)
            criterion = (criterion, ConfidencePenalty())
            if args.auto_tune_mtl and args.confidence_penalty:
                print("Using AutoTune")
                use_autoTune = True
                criterion = MultiHeadLossAutoTune(
                    list(criterion),
                    [args.lambda_xent, -args.confidence_beta]).cuda()
    else:
        if args.label_smooth:
            print("Using Angular Label Smoothing")
            criterion = AngularLabelSmooth(num_classes=dataset.num_train_pids,
                                           use_gpu=use_gpu)

        else:
            print("Using Angular Loss")
            criterion = AngleLoss()
    # Auto-tuned loss weights are themselves parameters, so they join the
    # optimizer alongside the model.
    if use_autoTune:
        optimizer = init_optim(
            args.optim,
            list(model.parameters()) + list(criterion.parameters()), args.lr,
            args.weight_decay)
    else:
        optimizer = init_optim(args.optim, model.parameters(), args.lr,
                               args.weight_decay)
    if args.scheduler:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=args.stepsize,
                                             gamma=args.gamma)

    if args.fixbase_epoch > 0:
        # Separate optimizer for the warm-up phase with the base frozen.
        if hasattr(model, 'classifier') and isinstance(model.classifier,
                                                       nn.Module):
            if use_autoTune:
                optimizer_tmp = init_optim(
                    args.optim,
                    list(model.classifier.parameters()) +
                    list(criterion.parameters()), args.fixbase_lr,
                    args.weight_decay)
            else:
                optimizer_tmp = init_optim(args.optim,
                                           model.classifier.parameters(),
                                           args.fixbase_lr, args.weight_decay)
        else:
            print(
                "Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0"
            )
            args.fixbase_epoch = 0

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        # Checkpoints store the epoch just completed; resume at the next one.
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    # Feature-extraction-only path for a single image folder.
    if args.single_folder != '':
        extract_features(model,
                         use_gpu,
                         args,
                         transform_test,
                         return_distmat=False)
        return
    if args.evaluate:
        print("Evaluate only")
        test_dir = args.save_dir
        if args.save_dir == 'log':
            if args.resume:
                test_dir = os.path.dirname(args.resume)
            else:
                test_dir = os.path.dirname(args.load_weights)
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       use_gpu,
                       args,
                       writer=None,
                       epoch=-1,
                       return_distmat=True,
                       tsne_clusters=args.tsne_labels)

        if args.visualize_ranks:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(test_dir, 'ranked_results'),
                topk=10,
            )
        return

    writer = SummaryWriter(log_dir=osp.join(args.save_dir, 'tensorboard'))
    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    if args.fixbase_epoch > 0:
        # Warm up the classifier with the base network frozen, then discard
        # the temporary optimizer.
        print(
            "Train classifier for {} epochs while keeping base network frozen".
            format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion,
                  optimizer_tmp,
                  trainloader,
                  use_gpu,
                  writer,
                  args,
                  freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")
    best_epoch = 0
    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu, writer,
              args)
        train_time += round(time.time() - start_train_time)

        if args.scheduler:
            scheduler.step()

        # Checkpoint/evaluate when past start_eval and either a save_epoch or
        # eval_step boundary is hit, or at the very last epoch; evaluation
        # itself only runs when eval_step > 0.
        if (epoch + 1) > args.start_eval and (
            (args.save_epoch > 0 and (epoch + 1) % args.save_epoch == 0) or
            (args.eval_step > 0 and (epoch + 1) % args.eval_step == 0) or
            (epoch + 1) == args.max_epoch):
            if (epoch + 1) == args.max_epoch:
                if use_gpu:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()

                save_checkpoint(
                    {
                        'state_dict': state_dict,
                        'rank1': -1,
                        'epoch': epoch,
                    }, False,
                    osp.join(
                        args.save_dir, 'beforeTesting_checkpoint_ep' +
                        str(epoch + 1) + '.pth.tar'))
            is_best = False
            rank1 = -1
            if args.eval_step > 0:
                print("==> Test")

                rank1 = test(model,
                             queryloader,
                             galleryloader,
                             use_gpu,
                             args,
                             writer=writer,
                             epoch=epoch)

                is_best = rank1 > best_rank1

                if is_best:
                    best_rank1 = rank1
                    best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    return best_rank1, best_epoch
def main():
    """Train and evaluate a re-ID model on a video re-ID dataset.

    Training is image-based: each training tracklet is decomposed into its
    individual frames, and the model is optimized with a combined
    cross-entropy ('xent') + triplet ('htri') loss over those images.
    Evaluation stays video-based (query/gallery sequences sampled 'evenly').

    Relies on the module-level ``args`` namespace for all configuration and
    on the module-level ``best_rank1`` for best-result bookkeeping.
    Side effects: redirects ``sys.stdout`` to a log file, writes checkpoints
    under ``args.save_dir``.
    """
    global args, best_rank1

    # Reproducibility / device selection.
    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    # Mirror all stdout into a log file; file name depends on run mode.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_vidreid_dataset(root=args.root,
                                                name=args.dataset)

    # Training-time augmentation: random crop-translate + horizontal flip,
    # normalized with ImageNet statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Deterministic test-time preprocessing (resize only).
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    # decompose tracklets into images for image-based training
    new_train = []
    for img_paths, pid, camid in dataset.train:
        for img_path in img_paths:
            new_train.append((img_path, pid, camid))

    # Identity-balanced sampling (num_instances images per identity per
    # batch) is required by the triplet loss; drop_last keeps batch shapes
    # uniform.
    trainloader = DataLoader(
        ImageDataset(new_train, transform=transform_train),
        sampler=RandomIdentitySampler(new_train, args.train_batch,
                                      args.num_instances),
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        VideoDataset(dataset.query,
                     seq_len=args.seq_len,
                     sample='evenly',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     seq_len=args.seq_len,
                     sample='evenly',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              loss={'xent', 'htri'})
    print("Model size: {:.3f} M".format(count_num_param(model)))

    # Classification loss: optionally label-smoothed cross entropy.
    if args.label_smooth:
        criterion_xent = CrossEntropyLabelSmooth(
            num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    else:
        criterion_xent = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin)

    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    # Full resume: restores weights, the epoch counter, and the best rank-1
    # seen so far. NOTE(review): unlike --load-weights, this load is strict —
    # a size mismatch will raise.
    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    # Evaluate-only mode: one test pass, optional ranked-result images.
    if args.evaluate:
        print("Evaluate only")
        distmat = test(model,
                       queryloader,
                       galleryloader,
                       args.pool,
                       use_gpu,
                       return_distmat=True)
        if args.visualize_ranks:
            visualize_ranked_results(
                distmat,
                dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=20,
            )
        return

    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, optimizer,
              trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate periodically (every eval_step epochs once past
        # start_eval) and always on the final epoch. Note the `or` binds
        # last: the final epoch triggers a test regardless of eval_step.
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # Unwrap DataParallel so checkpoint keys have no 'module.' prefix.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
# Ejemplo n.º 12
# 0
def main():
    """Train and evaluate a multi-attribute pedestrian classifier.

    Trains a single backbone with six cross-entropy heads (gender, staff,
    customer, stand, sit, phone) and periodically evaluates on the test
    split; the model score is the sum of the six accuracies * 100.

    Relies on the module-level ``args`` namespace for configuration.
    Side effects: redirects ``sys.stdout`` to a log file, writes checkpoints
    under ``args.save_dir``.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    # Mirror all stdout into a log file; file name depends on run mode.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(root=args.root,
                                                name=args.dataset,
                                                split_id=args.split_id)

    # Training-time augmentation; ImageNet normalization statistics.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Deterministic test-time preprocessing (resize only).
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    testloader = DataLoader(
        ImageDataset(dataset.test, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch, loss={'xent'}, use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    # One cross-entropy criterion per attribute head.
    gender_criterion_xent = nn.CrossEntropyLoss()
    staff_criterion_xent = nn.CrossEntropyLoss()
    customer_criterion_xent = nn.CrossEntropyLoss()
    stand_criterion_xent = nn.CrossEntropyLoss()
    sit_criterion_xent = nn.CrossEntropyLoss()
    phone_criterion_xent = nn.CrossEntropyLoss()

    optimizer = init_optim(args.optim, model.parameters(), args.lr,
                           args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.stepsize,
                                         gamma=args.gamma)

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    start_time = time.time()
    train_time = 0
    best_score = 0
    # BUGFIX: initialize the per-attribute bests up front. Previously they
    # were only assigned inside `if is_best:`, so if no epoch ever improved
    # on best_score the final summary print raised a NameError.
    best_gender_acc = best_staff_acc = best_customer_acc = 0
    best_stand_acc = best_sit_acc = best_phone_acc = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    # NOTE(review): original TODO said "modified up to here — just adapt
    # train and test" (translated from Chinese).
    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, gender_criterion_xent, staff_criterion_xent, customer_criterion_xent, \
              stand_criterion_xent, sit_criterion_xent, phone_criterion_xent, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate periodically (every eval_step epochs once past
        # start_eval) and always on the final epoch. Note the `or` binds
        # last: the final epoch triggers a test regardless of eval_step.
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (
                epoch + 1) % args.eval_step == 0 or (epoch +
                                                     1) == args.max_epoch:
            print("==> Test")
            gender_accurary, staff_accurary, customer_accurary, stand_accurary, sit_accurary, phone_accurary = test(
                model, testloader, use_gpu)
            # Aggregate score: sum of the six accuracies, scaled by 100.
            Score = (gender_accurary + staff_accurary + customer_accurary +
                     stand_accurary + sit_accurary + phone_accurary) * 100
            is_best = Score > best_score

            if is_best:
                best_score = Score
                best_gender_acc = gender_accurary
                best_staff_acc = staff_accurary
                best_customer_acc = customer_accurary
                best_stand_acc = stand_accurary
                best_sit_acc = sit_accurary
                best_phone_acc = phone_accurary
                best_epoch = epoch + 1

            # Unwrap DataParallel so checkpoint keys have no 'module.' prefix.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': Score,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print(
        "==> Best best_score {} |Gender_acc {}\t Staff_acc {}\t Customer_acc {}\t Stand_acc {}\t Sit_acc {}\t Phone_acc {}|achieved at epoch {}"
        .format(best_score, best_gender_acc, best_staff_acc, best_customer_acc,
                best_stand_acc, best_sit_acc, best_phone_acc, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
# Ejemplo n.º 13
# 0
def main():
    """Extract re-ID features for a query/gallery split and save to .mat.

    Loads a pretrained checkpoint into the given architecture (skipping any
    size-mismatched layers), runs the model over the 'query' and 'gallery'
    sub-folders of the dataset path, and writes one scipy .mat file of image
    names + feature vectors per subset under the log directory.
    Configuration is parsed locally from the command line (this `main` does
    not use the module-level ``args``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--snap_shot',
        type=str,
        default='saved-models/densenet121_xent_market1501.pth.tar')
    parser.add_argument('--arch', type=str, default='densenet121')
    parser.add_argument('--dataset-path',
                        type=str,
                        default='data/valset/valSet')
    parser.add_argument('--height',
                        type=int,
                        default=256,
                        help="height of an image (default: 256)")
    parser.add_argument('--width',
                        type=int,
                        default=128,
                        help="width of an image (default: 128)")
    parser.add_argument('--test-batch',
                        default=100,
                        type=int,
                        help="test batch size")
    parser.add_argument('-j',
                        '--workers',
                        default=4,
                        type=int,
                        help="number of data loading workers (default: 4)")
    parser.add_argument('--log-dir', type=str, default='log/eval_625')
    parser.add_argument('--gpu', type=int, default=1)
    args = parser.parse_args()

    pin_memory = True if args.gpu else False

    # NOTE(review): .cuda() is called unconditionally — this main requires a
    # GPU regardless of the --gpu flag; the flag only controls pin_memory.
    # num_classes=751 matches the Market-1501 training-identity count the
    # checkpoint was trained with.
    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=751,
                              loss={'xent'},
                              use_gpu=args.gpu).cuda()
    print("Model size: {:.3f} M".format(count_num_param(model)))

    # Load pretrained weights, silently skipping layers whose shapes do not
    # match the freshly-built model (e.g. a different classifier head).
    checkpoint = torch.load(args.snap_shot)
    pretrain_dict = checkpoint['state_dict']
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)
    print("Loaded pretrained weights from '{}'".format(args.snap_shot))

    # Deterministic preprocessing with ImageNet normalization statistics.
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    queryloader = DataLoader(
        evalDataset(os.path.join(args.dataset_path, 'query'),
                    transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        evalDataset(os.path.join(args.dataset_path, 'gallery'),
                    transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    dataloaders = {'query': queryloader, 'gallery': galleryloader}

    # One .mat file per subset: {'names': image names, 'features': N x D}.
    for dataset in ['val']:
        for subset in ['query', 'gallery']:
            test_names, test_features = extractor(model, dataloaders[subset])
            results = {'names': test_names, 'features': test_features.numpy()}
            scipy.io.savemat(
                os.path.join(args.log_dir,
                             'feature_%s_%s.mat' % (dataset, subset)), results)