def run(args):
    """Run multi-scale CAM inference with a trained classifier.

    Loads the ``<cam_network>CAM`` model named by ``args``, splits the
    dataset into shards, and spawns one worker process per GPU (or per CPU
    worker) to compute CAMs; workers write results under ``args.cam_out_dir``.

    Args:
        args: parsed configuration namespace; the attributes asserted below
            are required, plus ``category_name``, ``cam_scales`` and
            ``num_workers``.
    """
    # Required configuration — fail fast if the caller omitted any of it.
    assert args.voc12_root is not None
    assert args.train_list is not None
    assert args.cam_weights_name is not None
    assert args.cam_network is not None
    assert args.cam_out_dir is not None
    assert args.cam_network_module is not None

    # Classifier head size per category attribute (replaces the repeated
    # if-chain). Unknown categories fall back to 0, matching the original
    # behaviour; model construction will then fail loudly.
    num_classes_by_category = {'make_id': 75, 'model_id': 431, 'year': 16}
    num_classes = num_classes_by_category.get(args.category_name, 0)

    model = getattr(importlib.import_module(args.cam_network_module),
                    args.cam_network + 'CAM')(num_classes=num_classes)

    # Load the checkpoint once; only the map_location differs by device.
    weights_path = args.cam_weights_name + '.pth'
    if use_gpu:
        state_dict = torch.load(weights_path)
    else:
        # Force CPU deserialization when no GPU is available.
        state_dict = torch.load(weights_path,
                                map_location=torch.device('cpu'))
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    dataset = dataloader.VOC12ClassificationDatasetMSF(
        args.train_list, voc12_root=args.voc12_root,
        scales=args.cam_scales, category_name=args.category_name)

    print('[ ', end='')
    if use_gpu:
        n_gpus = torch.cuda.device_count()
        dataset = torchutils.split_dataset(dataset, n_gpus)
        multiprocessing.spawn(_work_gpu, nprocs=n_gpus,
                              args=(model, dataset, args), join=True)
    else:
        dataset = torchutils.split_dataset(dataset, args.num_workers)
        multiprocessing.spawn(_work_cpu, nprocs=args.num_workers,
                              args=(model, dataset, args), join=True)
    print(']')

    # Releasing cached CUDA memory is only meaningful after GPU work;
    # the original called it unconditionally (a harmless no-op on CPU).
    if use_gpu:
        torch.cuda.empty_cache()
def run(args):
    """Propagate CAM seeds into semantic-segmentation labels.

    Loads the trained ``<irn_network>EdgeDisplacement`` model and spawns
    one worker per GPU (or per CPU worker) over shards of the inference
    list; workers write outputs under ``args.sem_seg_out_dir``.
    """
    # All of these must be supplied by the caller.
    assert args.voc12_root is not None
    assert args.class_label_dict_path is not None
    assert args.infer_list is not None
    assert args.sem_seg_out_dir is not None
    assert args.irn_weights_name is not None
    assert args.cam_out_dir is not None
    assert args.irn_network is not None
    assert args.irn_network_module is not None

    network_module = importlib.import_module(args.irn_network_module)
    model_cls = getattr(network_module, args.irn_network + 'EdgeDisplacement')
    model = model_cls()
    # NOTE(review): strict=False tolerates checkpoint/model key mismatches —
    # presumably intentional (extra training-only heads); confirm.
    model.load_state_dict(torch.load(args.irn_weights_name), strict=False)
    model.eval()

    # Single scale (1.0) here, unlike the multi-scale CAM stage.
    dataset = dataloader.VOC12ClassificationDatasetMSF(
        args.infer_list,
        voc12_root=args.voc12_root,
        scales=(1.0,),
        class_label_dict_path=args.class_label_dict_path)

    print("[", end='')
    if use_gpu:
        n_gpus = torch.cuda.device_count()
        shards = torchutils.split_dataset(dataset, n_gpus)
        multiprocessing.spawn(_work_gpu, nprocs=n_gpus,
                              args=(model, shards, args), join=True)
    else:
        shards = torchutils.split_dataset(dataset, args.num_workers)
        multiprocessing.spawn(_work_cpu, nprocs=args.num_workers,
                              args=(model, shards, args), join=True)
    print("]")

    torch.cuda.empty_cache()
def run(args):
    """Generate IR (inter-pixel relation) labels from saved CAMs.

    Splits the image dataset across ``args.num_workers`` CPU worker
    processes; each worker reads CAMs from ``args.cam_out_dir`` and writes
    labels under ``args.ir_label_out_dir``.
    """
    # Required configuration — fail fast if anything is missing.
    assert args.voc12_root is not None
    assert args.train_list is not None
    assert args.ir_label_out_dir is not None
    assert args.cam_out_dir is not None

    # img_normal=None, to_torch=False: workers receive un-normalized,
    # non-tensor images — presumably raw arrays for label post-processing;
    # confirm against _work.
    images = dataloader.VOC12ImageDataset(
        args.train_list, voc12_root=args.voc12_root,
        img_normal=None, to_torch=False)
    shards = torchutils.split_dataset(images, args.num_workers)

    print('[ ', end='')
    multiprocessing.spawn(_work, nprocs=args.num_workers,
                          args=(shards, args), join=True)
    print(']')