Example #1
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height,
                                                args.width,
                                                args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False,
                                                random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)

    # source dataset
    source_dataset = datasets.__dict__[args.source](
        root=osp.join(source_root, args.source.lower()))
    sampler = RandomMultipleGallerySampler(source_dataset.train,
                                           args.num_instances)
    train_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=train_transform),
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     sampler=sampler,
                                     pin_memory=True,
                                     drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    cluster_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    val_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(source_dataset.query) | set(source_dataset.gallery)),
        root=source_dataset.images_dir,
        transform=val_transform),
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=True)

    # target dataset
    target_dataset = datasets.__dict__[args.target](
        root=osp.join(target_root, args.target.lower()))
    cluster_target_loader = DataLoader(convert_to_pytorch_dataset(
        target_dataset.train,
        root=target_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    test_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(target_dataset.query) | set(target_dataset.gallery)),
        root=target_dataset.images_dir,
        transform=val_transform),
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    n_s_classes = source_dataset.num_train_pids
    args.n_classes = n_s_classes + len(target_dataset.train)
    args.n_s_classes = n_s_classes
    args.n_t_classes = len(target_dataset.train)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone,
                         args.n_classes,
                         finetune=args.finetune,
                         pool_layer=pool_layer)
    features_dim = model.features_dim

    idm_bn_names = filter_layers(args.stage)
    convert_dsbn_idm(model, idm_bn_names, idm=False)

    model = model.to(device)
    model = DataParallel(model)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])

    # analyze the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader,
                             target_loader=test_loader,
                             model=model,
                             filename=osp.join(logger.visualize_directory,
                                               'analysis', 'TSNE.pdf'),
                             device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader,
                                 model,
                                 target_dataset.query,
                                 target_dataset.gallery,
                                 device,
                                 visualize_dir=logger.visualize_directory,
                                 width=args.width,
                                 height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on target domain:")
        validate(test_loader,
                 model,
                 target_dataset.query,
                 target_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        return

    # create XBM
    dataset_size = len(source_dataset.train) + len(target_dataset.train)
    memory_size = int(args.ratio * dataset_size)
    xbm = XBM(memory_size, features_dim)

    # initialize source-domain class centroids
    source_feature_dict = extract_reid_feature(cluster_source_loader,
                                               model,
                                               device,
                                               normalize=True)
    source_features_per_id = {}
    for f, pid, _ in source_dataset.train:
        if pid not in source_features_per_id:
            source_features_per_id[pid] = []
        source_features_per_id[pid].append(source_feature_dict[f].unsqueeze(0))
    source_centers = [
        torch.cat(source_features_per_id[pid], 0).mean(0)
        for pid in sorted(source_features_per_id.keys())
    ]
    source_centers = torch.stack(source_centers, 0)
    source_centers = F.normalize(source_centers, dim=1)
    model.module.head.weight.data[0:n_s_classes].copy_(
        source_centers.to(device))

    # save memory
    del source_centers, cluster_source_loader, source_features_per_id

    # define optimizer and lr scheduler
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr,
                                                 rate=args.rate),
                     args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = StepLR(optimizer, step_size=args.step_size, gamma=0.1)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
        train_target_iter = run_dbscan(cluster_target_loader, model,
                                       target_dataset, train_transform, args)

        # train for one epoch
        print(lr_scheduler.get_last_lr())
        train(train_source_iter, train_target_iter, model, optimizer, xbm,
              epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch))
            print("Test on target domain...")
            _, test_mAP = validate(test_loader,
                                   model,
                                   target_dataset.query,
                                   target_dataset.gallery,
                                   device,
                                   cmc_flag=True,
                                   rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch),
                            logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)

        # update lr
        lr_scheduler.step()

    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
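Every example on this page leans on a project-specific copy_state_dict helper, and its exact signature differs from project to project (compare the two-argument utils.copy_state_dict(model, checkpoint['model']) call above with the prefix/fix_loaded variant in the next example). As a rough sketch only, and not any of these projects' actual implementation, a minimal helper that copies parameters whose names and shapes match, optionally stripping a DataParallel 'module.' prefix, could look like this:

def copy_state_dict(model, state_dict, strip_prefix='module.'):
    # Illustrative sketch, not the helper used above: copy every tensor
    # whose (possibly de-prefixed) name and shape match the destination
    # model, and report which layers were loaded.
    own_state = model.state_dict()
    copied = []
    for name, param in state_dict.items():
        if name.startswith(strip_prefix):
            name = name[len(strip_prefix):]
        if name in own_state and own_state[name].shape == param.shape:
            own_state[name].copy_(param)  # in-place copy updates the model
            copied.append(name)
    return copied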
Example #2
def load_pretrain_params(self):
    success_layer = copy_state_dict(self.state_dict(),
                                    torch.load(config.hrnet_pretrain),
                                    prefix='',
                                    fix_loaded=True)
import argparse
import glob
import os
import time

import cv2
import torch

# copy_state_dict and the segmentation helpers come from this project's
# utils module; createDeepLabv3 (used below) is likewise project-local
# and its definition is not shown here.
from utils import masks_to_segmentation, overaly_segmentation, load_segmentation_definition, copy_state_dict

# Command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default='./out/weights.pt', type=str)
parser.add_argument("--data_directory", default='./data/original/', type=str)
parser.add_argument("--result_directory", default='./result/', type=str)

args = parser.parse_args()
if not os.path.isdir(args.result_directory):
    os.mkdir(args.result_directory)

model = createDeepLabv3(outputchannels=12, pretrain=True)
# model = torch.nn.DataParallel(model)
train_model = torch.load(args.model_path)
trained_state_dict = copy_state_dict(train_model.state_dict())
model.load_state_dict(trained_state_dict)
model.eval()
model.cuda()

definitions = load_segmentation_definition(
    'data/json/annotation_definitions.json')
s0 = time.time()
with torch.no_grad():
    for step, path in enumerate(glob.glob(args.data_directory + '*.png')):
        im = cv2.imread(path)
        im = cv2.resize(im, (480, 320), interpolation=cv2.INTER_NEAREST)
        input_tensor = torch.from_numpy(im)
        input_tensor = input_tensor.type(torch.FloatTensor) / 255
        input_tensor = input_tensor.cuda()
        input_tensor = input_tensor.permute(2, 0, 1)
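The example is cut off here by the listing; at this point input_tensor is CHW. A purely hypothetical continuation (not part of the original source, and assuming createDeepLabv3 wraps a torchvision DeepLabV3 model, which returns its logits under the 'out' key) would add a batch dimension and take a per-pixel argmax:

        # Hypothetical continuation of the loop body:
        output = model(input_tensor.unsqueeze(0))['out']      # 1 x C x H x W logits
        pred = output.argmax(dim=1).squeeze(0).cpu().numpy()  # H x W class ids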
Example #4
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height,
                                                args.width,
                                                args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False,
                                                random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)

    # source dataset
    source_dataset = datasets.__dict__[args.source](
        root=osp.join(source_root, args.source.lower()))
    val_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(source_dataset.query) | set(source_dataset.gallery)),
        root=source_dataset.images_dir,
        transform=val_transform),
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=True)

    # target dataset
    target_dataset = datasets.__dict__[args.target](
        root=osp.join(target_root, args.target.lower()))
    cluster_loader = DataLoader(convert_to_pytorch_dataset(
        target_dataset.train,
        root=target_dataset.images_dir,
        transform=val_transform),
                                batch_size=args.batch_size,
                                num_workers=args.workers,
                                shuffle=False,
                                pin_memory=True)
    test_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(target_dataset.query) | set(target_dataset.gallery)),
        root=target_dataset.images_dir,
        transform=val_transform),
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    # create model
    num_classes = args.num_clusters
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone,
                         num_classes,
                         finetune=args.finetune,
                         pool_layer=pool_layer).to(device)
    model = DataParallel(model)

    # load pretrained weights
    pretrained_model = torch.load(args.pretrained_model_path)
    utils.copy_state_dict(model, pretrained_model)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])

    # analyze the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader,
                             target_loader=test_loader,
                             model=model,
                             filename=osp.join(logger.visualize_directory,
                                               'analysis', 'TSNE.pdf'),
                             device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader,
                                 model,
                                 target_dataset.query,
                                 target_dataset.gallery,
                                 device,
                                 visualize_dir=logger.visualize_directory,
                                 width=args.width,
                                 height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on Source domain:")
        validate(val_loader,
                 model,
                 source_dataset.query,
                 source_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader,
                 model,
                 target_dataset.query,
                 target_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        return

    # define loss function
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
        args.start_epoch = checkpoint['epoch'] + 1

    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
        if args.clustering_algorithm == 'kmeans':
            train_target_iter = run_kmeans(cluster_loader, model,
                                           target_dataset, train_transform,
                                           args)
        elif args.clustering_algorithm == 'dbscan':
            train_target_iter, num_classes = run_dbscan(
                cluster_loader, model, target_dataset, train_transform, args)

        # define cross entropy loss with current number of classes
        criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)

        # define optimizer
        optimizer = Adam(model.module.get_parameters(base_lr=args.lr,
                                                     rate=args.rate),
                         args.lr,
                         weight_decay=args.weight_decay)

        # train for one epoch
        train(train_target_iter, model, optimizer, criterion_ce,
              criterion_triplet, epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save({
                'model': model.state_dict(),
                'epoch': epoch
            }, logger.get_checkpoint_path(epoch))
            print("Test on target domain...")
            _, test_mAP = validate(test_loader,
                                   model,
                                   target_dataset.query,
                                   target_dataset.gallery,
                                   device,
                                   cmc_flag=True,
                                   rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch),
                            logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)

    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
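run_kmeans and run_dbscan are project helpers whose bodies are not shown in these listings. The recipe they implement, extracting embeddings with the current model, clustering them, and treating cluster ids as pseudo identities, can be sketched as follows (assuming scikit-learn and an (N, D) array of L2-normalized features; the function name and defaults are illustrative):

from sklearn.cluster import DBSCAN

def generate_pseudo_labels(features, eps=0.6, min_samples=4):
    # features: (N, D) numpy array of L2-normalized re-ID embeddings.
    labels = DBSCAN(eps=eps, min_samples=min_samples,
                    metric='euclidean').fit_predict(features)
    num_clusters = int(labels.max()) + 1  # clusters are labeled 0..K-1
    keep = labels != -1                   # -1 marks outliers, usually dropped
    return labels[keep], keep, num_clusters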
Example #5
    alphas = ign_alphas(net, args)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.resume is not None:
        checkpoint_path = args.resume
        if not os.path.isfile(checkpoint_path):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epochs = checkpoint['epoch'] + 1
        if torch.cuda.device_count() > 1:
            copy_state_dict(net, checkpoint, parallel=True)
        else:
            copy_state_dict(net, checkpoint)
        # copy_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # copy_optimizer_state_dict(alphas.optimizer, checkpoint['alpha_optimizer'])
        alphas.optimizer.load_state_dict(checkpoint['alpha_optimizer'])

    # params = net.parameters()
    # for param in params:
    #     # print(param)
    #     print(param.shape)
    if args.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[75, 125, 175], gamma=0.1)
    elif args.scheduler == 'cos':
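        # Assumed completion: the listing is truncated inside this branch.
        # A standard cosine schedule over the full epoch budget is only a
        # guess at what the original code did here.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)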