Example #1: FID evaluation of a pretrained model on the test set
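
The excerpts below omit their import blocks. The following sketch is reconstructed from the names used in the code (forge, fet, fprint, AttrDict, osp, plt, tqdm); the exact module paths are an assumption based on a forge-style experiment layout.

# Imports assumed by the excerpts below (module paths are a best guess).
import os
import os.path as osp
import random

import numpy as np
import torch
import matplotlib.pyplot as plt
from attrdict import AttrDict
from tqdm import tqdm

import forge
import forge.experiment_tools as fet
from forge.experiment_tools import fprint
# plot, fid_from_model, average_ari and average_segcover are project-local
# helpers that are not shown in these excerpts.
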
def main():
    # Parse flags
    config = forge.config()
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'fid_evaluation.txt'
    config.shuffle_test = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Using GPU?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fet.print_flags()

    # Load data
    _, _, test_loader = fet.load(config.data_config, config)

    #  Load model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
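    # Drop the fixed pixel-coordinate buffers from the checkpoint if present;
    # judging by the key names they are coordinate grids that the model rebuilds
    # itself, so a stale copy could clash with the current input size (assumption).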
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)
    # Put model on GPU
    if config.gpu:
        model = model.cuda()

    # Compute FID
    fid_from_model(model, test_loader, config.batch_size,
                   config.num_fid_images, config.feat_dim, config.img_dir)
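
fid_from_model is a project helper that is not shown here; it presumably extracts feature statistics for config.num_fid_images generated and real images and reports the Frechet Inception Distance, FID = ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2 (C_r C_g)^(1/2)), where mu and C are the feature means and covariances of the real (r) and generated (g) sets. A minimal sketch of that final step, with frechet_distance as a hypothetical helper operating on precomputed statistics:

import numpy as np
from scipy.linalg import sqrtm

def frechet_distance(mu_r, cov_r, mu_g, cov_g):
    # Matrix square root of the covariance product; tiny imaginary parts caused
    # by numerical error are discarded.
    covmean = sqrtm(cov_r @ cov_g)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(cov_r + cov_g - 2.0 * covmean))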

Example #2: visualising scenes sampled from a pretrained model
def main():
    # Parse flags
    config = forge.config()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.batch_size = 1
    pretrained_flags.gpu = False
    pretrained_flags.debug = True
    fet.print_flags()

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    #  Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for _ in range(100):
        y, stats = model.sample(1, pretrained_flags.K_steps)
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Generated
        plot(axes, 0, 0, y, title='Generated scene', fontsize=12)
        # Empty plots
        plot(axes, 1, 0, fontsize=12)
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K generation steps in separate subfigures
        for step in range(pretrained_flags.K_steps):
            x_step = stats['x_k'][step]
            m_step = stats['log_m_k'][step].exp()
            mx_step = stats['mx_k'][step]
            if 'log_s_k' in stats:
                s_step = stats['log_s_k'][step].exp()
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if 'log_s_k' in stats:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
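
The plot helper used above (and again in Example #4) is a project utility that is not included in these excerpts. A minimal sketch that matches the call pattern plot(axes, row, col, img, title, grey, axis=..., fontsize=...) is given below; treating the boolean positional argument as a greyscale flag, and the implementation in general, are assumptions rather than the original code.

def plot(axes, row, col, img=None, title='', grey=False, axis=False, fontsize=10):
    # Draw the first element of a [B, C, H, W] tensor into one subplot.
    ax = axes[row, col]
    if img is not None:
        img = img.detach().cpu()[0].clamp(0, 1)   # [C, H, W], clipped for imshow
        img = img.permute(1, 2, 0).squeeze()      # [H, W, C] or [H, W]
        ax.imshow(img, cmap='gray' if grey else None)
    ax.set_title(title, fontsize=fontsize)
    if not axis:
        ax.axis('off')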

Example #3: segmentation metrics (foreground ARI and segmentation covering)
def main():
    # Parse flags
    config = forge.config()
    config.batch_size = 1
    config.load_instances = True
    fet.print_flags()

    # Restore original model flags
    pretrained_flags = AttrDict(
        fet.json_load(os.path.join(config.model_dir, 'flags.json')))

    # Get validation loader
    train_loader, val_loader, test_loader = fet.load(config.data_config,
                                                     config)
    fprint(f"Split: {config.split}")
    if config.split == 'train':
        batch_loader = train_loader
    elif config.split == 'val':
        batch_loader = val_loader
    elif config.split == 'test':
        batch_loader = test_loader
    else:
        raise ValueError(f"Unknown split: {config.split}")
    # Shuffle and prefetch to get same data for different models
    if 'gqn' not in config.data_config:
        batch_loader = torch.utils.data.DataLoader(batch_loader.dataset,
                                                   batch_size=1,
                                                   num_workers=0,
                                                   shuffle=True)
    # Prefetch batches
    prefetched_batches = []
    for i, x in enumerate(batch_loader):
        if i == config.num_images:
            break
        prefetched_batches.append(x)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    fprint(model)
    model_path = os.path.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)

    # Set experiment folder and fprint file for logging
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'segmentation_metrics.txt'

    # Compute metrics
    model.eval()
    ari_fg_list, sc_fg_list, msc_fg_list = [], [], []
    with torch.no_grad():
        for i, x in enumerate(tqdm(prefetched_batches)):
            _, _, stats, _, _ = model(x['input'])
            # ARI
            ari_fg, _ = average_ari(stats.log_m_k,
                                    x['instances'],
                                    foreground_only=True)
            # Segmentation covering - foreground only
            gt_instances = x['instances'].clone()
            gt_instances[gt_instances == 0] = -100
            ins_preds = torch.argmax(torch.stack(stats.log_m_k, dim=1), dim=1)
            sc_fg = average_segcover(gt_instances, ins_preds)
            msc_fg = average_segcover(gt_instances, ins_preds, False)
            # Recording
            ari_fg_list.append(ari_fg)
            sc_fg_list.append(sc_fg)
            msc_fg_list.append(msc_fg)

    # Print average metrics
    fprint(f"Average FG ARI: {sum(ari_fg_list)/len(ari_fg_list)}")
    fprint(f"Average FG SegCover: {sum(sc_fg_list)/len(sc_fg_list)}")
    fprint(f"Average FG MeanSegCover: {sum(msc_fg_list)/len(msc_fg_list)}")

Example #4: visualising reconstructions and per-step components on test images
def main():
    # Parse flags
    config = forge.config()
    fet.print_flags()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.debug = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load data
    config.batch_size = 1
    _, _, test_loader = fet.load(config.data_config, config)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for count, batch in enumerate(test_loader):
        if count >= config.num_images:
            break

        # Forward pass
        output, _, stats, _, _ = model(batch['input'])
        # Set up figure
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Input and reconstruction
        plot(axes, 0, 0, batch['input'], title='Input image', fontsize=12)
        plot(axes, 1, 0, output, title='Reconstruction', fontsize=12)
        # Empty plots
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K reconstruction steps into separate subfigures
        x_k = stats['x_r_k']
        log_m_k = stats['log_m_k']
        mx_k = [x * m.exp() for x, m in zip(x_k, log_m_k)]
        log_s_k = stats['log_s_k'] if 'log_s_k' in stats else None
        for step in range(pretrained_flags.K_steps):
            mx_step = mx_k[step]
            x_step = x_k[step]
            m_step = log_m_k[step].exp()
            if log_s_k:
                s_step = log_s_k[step].exp()

            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if log_s_k:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.15)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
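
The window-maximising calls above (here and in Example #2) rely on an interactive backend such as TkAgg, which exposes manager.window. On a headless machine the figures could be written to disk instead; a hypothetical replacement for the last three lines of the loop (the filename pattern is made up for this sketch):

        # Save the figure instead of showing it interactively.
        fig.set_size_inches(16, 8)
        fig.savefig(f'recon_{count:03d}.png', dpi=150, bbox_inches='tight')
        plt.close(fig)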