def load_checkpoints(config, checkpoint, blend_scale=0.125, first_order_motion_model=False, cpu=False):
    # Read the model hyper-parameters from the yaml config.
    with open(config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Generator that blends the swapped part into the target frame.
    reconstruction_module = PartSwapGenerator(blend_scale=blend_scale,
                                              first_order_motion_model=first_order_motion_model,
                                              **config['model_params']['reconstruction_module_params'],
                                              **config['model_params']['common_params'])
    if not cpu:
        reconstruction_module.cuda()

    # Module that predicts the part segmentation maps.
    segmentation_module = SegmentationModule(**config['model_params']['segmentation_module_params'],
                                             **config['model_params']['common_params'])
    if not cpu:
        segmentation_module.cuda()

    # Restore weights, mapping them to CPU when no GPU is requested.
    if cpu:
        checkpoint = torch.load(checkpoint, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint)

    load_reconstruction_module(reconstruction_module, checkpoint)
    load_segmentation_module(segmentation_module, checkpoint)

    if not cpu:
        reconstruction_module = DataParallelWithCallback(reconstruction_module)
        segmentation_module = DataParallelWithCallback(segmentation_module)

    reconstruction_module.eval()
    segmentation_module.eval()

    return reconstruction_module, segmentation_module
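# Usage sketch (not part of the original file): how load_checkpoints is typically
# called before part swapping. The config and checkpoint paths below are placeholders
# for a real yaml from the repository's config folder and a downloaded checkpoint.
def _example_load_for_part_swap():
    reconstruction_module, segmentation_module = load_checkpoints(
        config='config/vox-256-sem-10segments.yaml',  # placeholder config path
        checkpoint='vox-10segments.pth.tar',          # placeholder checkpoint file
        blend_scale=0.125,
        first_order_motion_model=False,
        cpu=not torch.cuda.is_available())
    return reconstruction_module, segmentation_module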
# Script fragment: evaluate a SegmentationModule on the train/test splits of FramesDataset.
# The excerpt begins inside the argument-parser setup; the "--root_dir" flag name below is
# inferred from the later use of opt.root_dir.
parser.add_argument("--root_dir", required=True,
                    help="path to root folder of the train and test images")
parser.add_argument("--checkpoint_path", default=None, help="path to checkpoint to restore")
parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
                    help="Names of the devices comma separated.")
opt = parser.parse_args()

with open(opt.config) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

segmentation_module = SegmentationModule(**config['model_params']['segmentation_module_params'],
                                         **config['model_params']['common_params'])
if torch.cuda.is_available():
    segmentation_module.to(opt.device_ids[0])

# Optionally restore pretrained segmentation weights.
if opt.checkpoint_path is not None:
    checkpoint = torch.load(opt.checkpoint_path)
    load_segmentation_module(segmentation_module, checkpoint)

dataset = {}
dataset['train'] = FramesDataset(root_dir=opt.root_dir, is_train=True)
dataset['test'] = FramesDataset(root_dir=opt.root_dir, is_train=False)

evaluate(config, segmentation_module, dataset, opt)
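# Invocation sketch (assumed script name; the --config and --root_dir flags are inferred
# from opt.config and opt.root_dir above):
#   python evaluate_segmentation.py --config config/vox-256-sem-10segments.yaml \
#          --root_dir data/video-frames --checkpoint_path seg-checkpoint.pth.tar --device_ids 0,1
# Note that "--device_ids 0,1" is parsed by the lambda above into the list [0, 1].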