예제 #1
0
def run(args):
    """Compute class activation maps (CAMs) for the images in ``args.train_list``.

    The network class ``args.cam_network + 'CAM'`` is looked up in the module
    ``args.cam_network_module``, constructed with the class count for
    ``args.category_name`` and loaded with ``strict=True`` from
    ``args.cam_weights_name + '.pth'``.  The multi-scale dataset is then
    handed to worker processes: ``_work_gpu`` with one process per visible
    CUDA device when the module-level ``use_gpu`` flag is true, otherwise
    ``_work_cpu`` with ``args.num_workers`` processes.  ``'[ '`` is printed
    before the workers start and ``']'`` after they have all joined.

    Args:
        args: Run-configuration namespace.  The attributes asserted below must
            be non-``None``; ``category_name`` must be one of ``'make_id'``,
            ``'model_id'`` or ``'year'``; ``cam_scales`` and (CPU path)
            ``num_workers`` are also read.

    Raises:
        AssertionError: if a required option is ``None``.
        ValueError: if ``args.category_name`` is not a known label set.
    """
    assert args.voc12_root is not None
    assert args.train_list is not None
    assert args.cam_weights_name is not None
    assert args.cam_network is not None
    assert args.cam_out_dir is not None
    assert args.cam_network_module is not None

    # Class count per supported label set.  It is passed to the network
    # constructor; presumably it sizes the classifier, which must agree with
    # the checkpoint because the weights are loaded with strict=True.
    num_classes_by_category = {
        'make_id': 75,
        'model_id': 431,
        'year': 16,
    }
    num_classes = num_classes_by_category.get(args.category_name)
    if num_classes is None:
        # Previously an unknown category silently produced a 0-class model.
        raise ValueError(
            f'unknown category_name {args.category_name!r}; expected one of '
            f'{sorted(num_classes_by_category)}')

    model = getattr(importlib.import_module(args.cam_network_module),
                    args.cam_network + 'CAM')(num_classes=num_classes)

    # On the CPU path, remap checkpoint tensors (which may have been saved from
    # CUDA) onto the CPU; map_location=None keeps torch.load's default.
    map_location = None if use_gpu else torch.device('cpu')
    model.load_state_dict(torch.load(args.cam_weights_name + '.pth',
                                     map_location=map_location),
                          strict=True)
    model.eval()

    dataset = dataloader.VOC12ClassificationDatasetMSF(
        args.train_list,
        voc12_root=args.voc12_root,
        scales=args.cam_scales,
        category_name=args.category_name)

    print('[ ', end='')
    if use_gpu:
        # One worker process per visible CUDA device.
        n_gpus = torch.cuda.device_count()

        # NOTE(review): split_dataset presumably returns one shard per process.
        dataset = torchutils.split_dataset(dataset, n_gpus)
        multiprocessing.spawn(_work_gpu,
                              nprocs=n_gpus,
                              args=(model, dataset, args),
                              join=True)
    else:
        dataset = torchutils.split_dataset(dataset, args.num_workers)
        multiprocessing.spawn(_work_cpu,
                              nprocs=args.num_workers,
                              args=(model, dataset, args),
                              join=True)
    print(']')

    torch.cuda.empty_cache()
def run(args):
    """Run the IRN ``EdgeDisplacement`` model over ``args.infer_list``.

    The class ``args.irn_network + 'EdgeDisplacement'`` is looked up in the
    module ``args.irn_network_module`` and loaded from
    ``args.irn_weights_name`` with ``strict=False``.  The single-scale
    (``scales=(1.0,)``) classification dataset is then handed to worker
    processes: ``_work_gpu`` with one process per visible CUDA device when the
    module-level ``use_gpu`` flag is true, otherwise ``_work_cpu`` with
    ``args.num_workers`` processes.  ``"["`` is printed before the workers
    start and ``"]"`` after they have all joined.

    Args:
        args: Run-configuration namespace.  The attributes asserted below must
            be non-``None``; ``num_workers`` is also read on the CPU path.

    Raises:
        AssertionError: if a required option is ``None``.
    """
    assert args.voc12_root is not None
    assert args.class_label_dict_path is not None
    assert args.infer_list is not None
    assert args.sem_seg_out_dir is not None
    assert args.irn_weights_name is not None
    assert args.cam_out_dir is not None
    assert args.irn_network is not None
    assert args.irn_network_module is not None

    model = getattr(importlib.import_module(args.irn_network_module),
                    args.irn_network + 'EdgeDisplacement')()

    # A checkpoint saved from CUDA tensors cannot be deserialized on a host
    # without CUDA unless it is remapped; on the CPU path load it onto the CPU.
    # map_location=None keeps torch.load's default placement on the GPU path.
    map_location = None if use_gpu else torch.device('cpu')
    # strict=False tolerates missing/unexpected keys.
    # NOTE(review): presumably the EdgeDisplacement wrapper's state dict
    # differs from the trained checkpoint's; confirm before tightening.
    model.load_state_dict(torch.load(args.irn_weights_name,
                                     map_location=map_location),
                          strict=False)
    model.eval()

    dataset = dataloader.VOC12ClassificationDatasetMSF(
        args.infer_list,
        voc12_root=args.voc12_root,
        scales=(1.0,),
        class_label_dict_path=args.class_label_dict_path)
    print("[", end='')
    if use_gpu:
        # One worker process per visible CUDA device.
        n_gpus = torch.cuda.device_count()

        # NOTE(review): split_dataset presumably returns one shard per process.
        dataset = torchutils.split_dataset(dataset, n_gpus)

        multiprocessing.spawn(_work_gpu,
                              nprocs=n_gpus,
                              args=(model, dataset, args),
                              join=True)
    else:
        dataset = torchutils.split_dataset(dataset, args.num_workers)
        multiprocessing.spawn(_work_cpu,
                              nprocs=args.num_workers,
                              args=(model, dataset, args),
                              join=True)
    print("]")

    torch.cuda.empty_cache()
예제 #3
0
def run(args):
    """Generate inter-pixel relation labels for the images in ``args.train_list``.

    Images are loaded raw (``img_normal=None``, ``to_torch=False``), split
    into ``args.num_workers`` parts and processed by ``_work`` in that many
    spawned processes.  ``'[ '`` is printed before the workers start and
    ``']'`` once they have all joined.

    Args:
        args: Run-configuration namespace; ``voc12_root``, ``train_list``,
            ``ir_label_out_dir`` and ``cam_out_dir`` must be non-``None``, and
            ``num_workers`` gives the process count.

    Raises:
        AssertionError: if one of the required options is ``None``.
    """
    # Checked in this fixed order so the first missing option is the one reported.
    for option in ('voc12_root', 'train_list', 'ir_label_out_dir',
                   'cam_out_dir'):
        assert getattr(args, option) is not None

    images = dataloader.VOC12ImageDataset(
        args.train_list,
        voc12_root=args.voc12_root,
        img_normal=None,
        to_torch=False,
    )
    # NOTE(review): split_dataset presumably returns one shard per worker.
    shards = torchutils.split_dataset(images, args.num_workers)

    print('[ ', end='')
    multiprocessing.spawn(
        _work,
        nprocs=args.num_workers,
        args=(shards, args),
        join=True,
    )
    print(']')