def main():
    """Compute the FID of samples drawn from a pretrained model.

    Reads the run configuration via ``forge.config()``, seeds all RNGs,
    chooses the default tensor type (CUDA only when available *and*
    requested), restores the pretrained model's flags and weights from
    ``config.model_dir`` and hands the model and the test loader to
    ``fid_from_model``. ``fprint`` output is directed to
    ``fid_evaluation.txt`` via ``fet.EXPERIMENT_FOLDER`` / ``fet.FPRINT_FILE``.
    """
    config = forge.config()

    # Direct fprint logging into the model directory before anything is logged.
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'fid_evaluation.txt'
    config.shuffle_test = True

    # Seed every RNG right after the config is parsed.
    seed = config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Fall back to CPU when CUDA is missing or GPU use was not requested;
    # config.gpu is rewritten so later checks see the effective choice.
    if not (torch.cuda.is_available() and config.gpu):
        config.gpu = False
    torch.set_default_tensor_type(
        'torch.cuda.FloatTensor' if config.gpu else 'torch.FloatTensor')

    fet.print_flags()

    # Only the test split is used for FID.
    _, _, test_loader = fet.load(config.data_config, config)

    # Rebuild the model from the flags it was trained with.
    flags_file = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flags_file}")
    restored_flags = AttrDict(fet.json_load(flags_file))
    model = fet.load(config.model_config, restored_flags)

    # Load the weights on CPU first; the model is moved to GPU afterwards.
    weights_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {weights_path}")
    ckpt = torch.load(weights_path, map_location='cpu')
    state_dict = ckpt['model_state_dict']
    # These pixel-coordinate entries are dropped before loading (presumably
    # the model rebuilds them at construction time — confirm).
    for stale_key in ('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                      'comp_vae.decoder_module.seq.0.pixel_coords.g_2'):
        state_dict.pop(stale_key, None)
    model.load_state_dict(state_dict)
    fprint(model)

    if config.gpu:
        model = model.cuda()

    fid_from_model(model, test_loader, config.batch_size,
                   config.num_fid_images, config.feat_dim, config.img_dir)
def main():
    """Visualise samples drawn from a pretrained generative model.

    Restores the pretrained model's flags from ``config.model_dir`` and
    forces ``batch_size=1``, CPU execution and debug mode. It then seeds all
    RNGs, loads the checkpoint and, 100 times, draws one sample with
    ``model.sample(1, K_steps)`` and shows a 4 x (1 + K_steps) figure:

    * column 0: the generated scene (row 0); rows 1-3 are left empty;
    * row 0: mask-weighted RGB of each step (``stats['mx_k']``);
    * row 1: RGB of each step (``stats['x_k']``);
    * row 2: mask of each step (``exp(stats['log_m_k'])``);
    * row 3: scope of each step (``exp(stats['log_s_k'])``), only when
      ``stats`` contains ``'log_s_k'``.

    Each figure is shown with a blocking ``plt.show()``.
    """
    # Parse flags
    config = forge.config()

    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.batch_size = 1
    pretrained_flags.gpu = False
    pretrained_flags.debug = True
    fet.print_flags()

    # Fix seeds before the model is built or sampled.
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    # Dropped before loading (presumably regenerated by the model — confirm).
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    k_steps = pretrained_flags.K_steps
    for _ in range(100):
        # Sampling is inference only: don't build an autograd graph.
        with torch.no_grad():
            y, stats = model.sample(1, k_steps)
        has_scope = 'log_s_k' in stats

        _, axes = plt.subplots(nrows=4, ncols=1 + k_steps)
        # Generated
        plot(axes, 0, 0, y, title='Generated scene', fontsize=12)
        # Empty plots
        plot(axes, 1, 0, fontsize=12)
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K generation steps in separate subfigures; only the first
        # column of each row carries the row label prefix.
        for step in range(k_steps):
            x_step = stats['x_k'][step]
            m_step = stats['log_m_k'][step].exp()
            mx_step = stats['mx_k'][step]
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes, 2, 1 + step, m_step, pre + f'k={step+1}', True,
                 fontsize=12)
            if has_scope:
                s_step = stats['log_s_k'][step].exp()
                pre = 'Scope ' if step == 0 else ''
                plot(axes, 3, 1 + step, s_step, pre + f'k={step+1}', True,
                     axis=step == 0, fontsize=12)

        # Beautify and show figure.
        # NOTE(review): manager.window.maxsize() assumes a Tk-based
        # matplotlib backend — confirm.
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
def main():
    """Compute foreground segmentation metrics of a pretrained model.

    Loads up to ``config.num_images`` single-image batches from the split
    named by ``config.split`` and restores the pretrained model. It then
    logs, to ``segmentation_metrics.txt`` in ``config.model_dir``, the
    average foreground ARI, segmentation covering and mean segmentation
    covering of the model's masks against the ground-truth instances.

    Raises:
        ValueError: if ``config.split`` is not ``'train'``, ``'val'`` or
            ``'test'``, or if no batches could be prefetched.
    """
    # Parse flags
    config = forge.config()
    config.batch_size = 1
    config.load_instances = True
    fet.print_flags()

    # Restore original model flags
    pretrained_flags = AttrDict(
        fet.json_load(os.path.join(config.model_dir, 'flags.json')))

    # Get the data loaders and select the requested split
    train_loader, val_loader, test_loader = fet.load(config.data_config, config)
    fprint(f"Split: {config.split}")
    loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}
    try:
        batch_loader = loaders[config.split]
    except KeyError:
        # Previously an unknown split left batch_loader unbound (NameError).
        raise ValueError(
            f"Unknown split {config.split!r}; expected one of "
            f"{list(loaders)}") from None

    # Shuffle and prefetch to get same data for different models.
    # NOTE(review): no RNG is seeded in this function, so the shuffle order is
    # only reproducible if seeding happens elsewhere — confirm.
    if 'gqn' not in config.data_config:
        batch_loader = torch.utils.data.DataLoader(batch_loader.dataset,
                                                   batch_size=1,
                                                   num_workers=0,
                                                   shuffle=True)
    # Prefetch batches
    prefetched_batches = []
    for i, x in enumerate(batch_loader):
        if i == config.num_images:
            break
        prefetched_batches.append(x)
    if not prefetched_batches:
        # Fail before loading the model instead of dividing by zero at the end.
        raise ValueError(
            f"No batches prefetched for split {config.split!r} "
            f"(num_images={config.num_images})")

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    fprint(model)
    model_path = os.path.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    # Dropped before loading (presumably regenerated by the model — confirm).
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
    model.load_state_dict(model_state_dict)

    # Set experiment folder and fprint file for logging
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'segmentation_metrics.txt'

    # Compute metrics
    model.eval()
    ari_fg_list, sc_fg_list, msc_fg_list = [], [], []
    with torch.no_grad():
        for x in tqdm(prefetched_batches):
            _, _, stats, _, _ = model(x['input'])
            # ARI
            ari_fg, _ = average_ari(stats.log_m_k, x['instances'],
                                    foreground_only=True)
            # Segmentation covering - foreground only. Background (instance
            # id 0) is relabelled -100, presumably an ignore value inside
            # average_segcover — confirm.
            gt_instances = x['instances'].clone()
            gt_instances[gt_instances == 0] = -100
            # Predicted instance = index of the component with the largest
            # log-mask, after stacking the K masks along dim 1.
            ins_preds = torch.argmax(torch.stack(stats.log_m_k, dim=1), dim=1)
            sc_fg = average_segcover(gt_instances, ins_preds)
            msc_fg = average_segcover(gt_instances, ins_preds, False)
            # Recording
            ari_fg_list.append(ari_fg)
            sc_fg_list.append(sc_fg)
            msc_fg_list.append(msc_fg)

    # Print average metrics (lists are non-empty: checked after prefetching)
    fprint(f"Average FG ARI: {sum(ari_fg_list)/len(ari_fg_list)}")
    fprint(f"Average FG SegCover: {sum(sc_fg_list)/len(sc_fg_list)}")
    fprint(f"Average FG MeanSegCover: {sum(msc_fg_list)/len(msc_fg_list)}")
def main():
    """Visualise reconstructions of a pretrained model on test images.

    Restores the pretrained model's flags (with ``debug=True``) and weights
    from ``config.model_dir`` and seeds all RNGs. For each of the first
    ``config.num_images`` test images (batch size 1), it shows a
    4 x (1 + K_steps) figure:

    * column 0: the input image (row 0) and its reconstruction (row 1);
      rows 2-3 are left empty;
    * row 0: mask-weighted RGB of each step (``x_r_k * exp(log_m_k)``);
    * row 1: RGB of each step (``stats['x_r_k']``);
    * row 2: mask of each step (``exp(stats['log_m_k'])``);
    * row 3: scope of each step (``exp(stats['log_s_k'])``), only when
      ``stats`` provides a non-empty ``'log_s_k'``.

    Each figure is shown with a blocking ``plt.show()``.
    """
    # Parse flags
    config = forge.config()
    fet.print_flags()

    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.debug = True

    # Fix seeds before data loading and model construction.
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load data
    config.batch_size = 1
    _, _, test_loader = fet.load(config.data_config, config)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    # Dropped before loading (presumably regenerated by the model — confirm).
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    k_steps = pretrained_flags.K_steps
    for count, batch in enumerate(test_loader):
        if count >= config.num_images:
            break

        # Forward pass: inference only, so don't build an autograd graph.
        with torch.no_grad():
            output, _, stats, _, _ = model(batch['input'])

        # Set up figure
        _, axes = plt.subplots(nrows=4, ncols=1 + k_steps)
        # Input and reconstruction
        plot(axes, 0, 0, batch['input'], title='Input image', fontsize=12)
        plot(axes, 1, 0, output, title='Reconstruction', fontsize=12)
        # Empty plots
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K reconstruction steps into separate subfigures
        x_k = stats['x_r_k']
        log_m_k = stats['log_m_k']
        mx_k = [x * m.exp() for x, m in zip(x_k, log_m_k)]
        log_s_k = stats['log_s_k'] if 'log_s_k' in stats else None
        # Explicit None/length check: plain truthiness would raise if
        # log_s_k were a multi-element tensor rather than a list.
        has_scope = log_s_k is not None and len(log_s_k) > 0
        for step in range(k_steps):
            mx_step = mx_k[step]
            x_step = x_k[step]
            m_step = log_m_k[step].exp()
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes, 2, 1 + step, m_step, pre + f'k={step+1}', True,
                 fontsize=12)
            if has_scope:
                s_step = log_s_k[step].exp()
                pre = 'Scope ' if step == 0 else ''
                plot(axes, 3, 1 + step, s_step, pre + f'k={step+1}', True,
                     axis=step == 0, fontsize=12)

        # Beautify and show figure.
        # NOTE(review): manager.window.maxsize() assumes a Tk-based
        # matplotlib backend — confirm.
        plt.subplots_adjust(wspace=0.05, hspace=0.15)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()