Пример #1
0
def puzzle_model_inference(config):
    """Sample from a PuzzleTrainer on the 'test' and 'aux' splits.

    Loads the 'train', 'test' and 'aux' splits, dumps the shape vectors of
    the training split when the patch-feature directory is missing, builds
    the per-category tables with ``AllCategoriesTables`` and then samples
    either for visualization or for evaluation, depending on
    ``config.for_visualization``.

    Args:
        config: project configuration object; this function reads
            ``use_patch_background`` and ``for_visualization`` from it and
            forwards it to ``composites_coco``.
    """
    traindb = composites_coco(config, 'train', '2017')
    # The test split used to be bound to ``valdb`` while the code below
    # referenced an undefined ``testdb`` (NameError); use one name throughout.
    testdb = composites_coco(config, 'test', '2017')
    auxdb = composites_coco(config, 'aux', '2017')
    trainer = PuzzleTrainer(traindb)
    t0 = time()

    # The parentheses matter: a conditional expression binds looser than
    # ``+``, so without them the False branch evaluated to just 'without_bg'.
    # NOTE(review): assumes the dump directory is named
    # 'patch_feature_without_bg' in that case -- confirm against
    # dump_shape_vectors.
    patch_dir_name = 'patch_feature_' + (
        'with_bg' if config.use_patch_background else 'without_bg')
    if not osp.exists(osp.join(traindb.root_dir, patch_dir_name)):
        trainer.dump_shape_vectors(traindb)

    all_tables = AllCategoriesTables(traindb)
    all_tables.build_nntables_for_all_categories(True)
    print("NN completes (time %.2fs)" % (time() - t0))

    t0 = time()
    if config.for_visualization:
        # Every scene of each split is requested for visualization.
        trainer.sample_for_vis(0,
                               testdb,
                               len(testdb.scenedb),
                               nn_table=all_tables)
        trainer.sample_for_vis(0,
                               auxdb,
                               len(auxdb.scenedb),
                               nn_table=all_tables)
    else:
        trainer.sample_for_eval(testdb, nn_table=all_tables)
        trainer.sample_for_eval(auxdb, nn_table=all_tables)
    print("Sampling completes (time %.2fs)" % (time() - t0))
Пример #2
0
def test_coco_dataset(config):
    """Load the 'train', 'val', 'test' and 'aux' splits and print their sizes.

    The four ``len(<split>.scenedb)`` values are printed on a single line,
    in the order train, val, test, aux.
    """
    split_names = ('train', 'val', 'test', 'aux')
    # All splits are constructed first; sizes are only read afterwards.
    databases = [composites_coco(config, split) for split in split_names]
    # with open('category_ind_to_class_ind.json', 'w') as fp:
    #     json.dump(db.category_ind_to_class_ind, fp, indent=4, sort_keys=True)
    print(*[len(database.scenedb) for database in databases])
Пример #3
0
def puzzle_model_inference_preparation(config):
    """Dump shape vectors for the train/test splits and build the tables.

    Both steps are timed and the elapsed time of each is printed.
    """
    train_split = composites_coco(config, 'train', '2017')
    test_split = composites_coco(config, 'test', '2017')
    trainer = PuzzleTrainer(train_split)

    tic = time()
    for split in (train_split, test_split):
        trainer.dump_shape_vectors(split)
    print("Dump shape vectors completes (time %.2fs)" % (time() - tic))

    tic = time()
    tables = AllCategoriesTables(train_split)
    tables.build_nntables_for_all_categories(False)
    print("NN completes (time %.2fs)" % (time() - tic))
Пример #4
0
def test_txt_encoder_coco(config):
    """Run one batch through ``TextEncoder`` and print diagnostic output.

    Prints the sizes of the encoder outputs, the last entry of the output
    mask, and the summed absolute difference between the embedding of the
    first token of the first sample and the matching row of
    ``db.lang_vocab.vectors``.
    """
    train_db = composites_coco(config, 'train', '2017')
    tables = AllCategoriesTables(train_db)
    tables.build_nntables_for_all_categories(True)
    dataset = sequence_loader(train_db, tables)
    encoder = TextEncoder(train_db)

    batches = DataLoader(dataset,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    for batch in batches:
        word_inds = batch['word_inds'].long()
        word_lens = batch['word_lens'].long()

        print('Checking the output shapes')
        embs, rfts, msks, hids = encoder(word_inds, word_lens)
        print(rfts.size(), embs.size(), msks.size())
        # The hidden state may come back as a tuple; report its first item.
        reported_hid = hids[0] if isinstance(hids, tuple) else hids
        print(reported_hid.size())
        print('m: ', msks[-1])

        print('Checking the embedding')
        looked_up = encoder.embedding(word_inds)[0, 0]
        token_index = word_inds[0, 0].data.item()
        reference = train_db.lang_vocab.vectors[token_index]
        print('Diff: (should be zero)',
              torch.sum((reference - looked_up).abs()))

        # Only the first batch is inspected.
        break
Пример #5
0
def test_coco_decoder(config):
    """Run one batch through the what/where decoders and print output sizes.

    Builds ``TextEncoder``, ``VolumeEncoder``, ``WhatDecoder`` and
    ``WhereDecoder``, prints their parameter counts, then feeds the first
    batch through them and prints the size of each decoder output.
    """
    train_db = composites_coco(config, 'train', '2017')
    tables = AllCategoriesTables(train_db)
    tables.build_nntables_for_all_categories(True)
    dataset = sequence_loader(train_db, tables)

    text_encoder = TextEncoder(train_db)
    img_encoder = VolumeEncoder(config)
    what_decoder = WhatDecoder(config)
    where_decoder = WhereDecoder(config)

    labelled_modules = (
        ('txt_encoder', text_encoder),
        ('img_encoder', img_encoder),
        ('what_decoder', what_decoder),
        ('where_decoder', where_decoder),
    )
    for label, module in labelled_modules:
        print(label, get_n_params(module))

    batches = DataLoader(dataset,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    separator = '------------------------------------------'
    for batch in batches:
        word_inds = batch['word_inds'].long()
        word_lens = batch['word_lens'].long()
        backgrounds = batch['background'].float()

        encoder_states = text_encoder(word_inds, word_lens)
        bg_feats = img_encoder(backgrounds)
        # Only index 0 along dim 1 is fed to the decoder, kept as a
        # length-1 axis.
        first_bgfs = bg_feats[:, 0].unsqueeze(1)
        (obj_logits, rnn_feats_2d, nxt_hids_2d,
         prev_bgfs, att_ctx, att_wei) = what_decoder(
             (first_bgfs, None, None), encoder_states)

        print(separator)
        what_report = (
            ('obj_logits', obj_logits),
            ('rnn_feats_2d', rnn_feats_2d),
            ('nxt_hids_2d', nxt_hids_2d),
            ('prev_bgfs', prev_bgfs),
            ('att_ctx', att_ctx),
            ('att_wei', att_wei),
        )
        for label, tensor in what_report:
            print(label, tensor.size())
        print(separator)

        obj_inds = torch.max(obj_logits + 1.0, dim=-1)[1]
        curr_fgfs = indices2onehots(obj_inds.cpu().data,
                                    config.output_vocab_size)
        # curr_fgfs = curr_fgfs.unsqueeze(1)
        if config.cuda:
            curr_fgfs = curr_fgfs.cuda()

        where_inputs = (rnn_feats_2d, curr_fgfs, prev_bgfs, att_ctx)
        (coord_logits, attri_logits, patch_vectors,
         _where_ctx, _where_wei) = where_decoder(where_inputs, encoder_states)
        print('coord_logits ', coord_logits.size())
        print('attri_logits ', attri_logits.size())
        print('patch_vectors', patch_vectors.size())

        # Only the first batch is inspected.
        break
Пример #6
0
def overfit_synthesis_model(config):
    """Train ``SynthesisTrainer`` on a truncated copy of the train split.

    Sets ``config.log_per_steps`` to 1, keeps only the first
    ``3 * config.batch_size`` scenes and passes the same database as all
    three arguments of ``trainer.train``.
    """
    config.log_per_steps = 1

    small_db = composites_coco(config, 'train', '2017')
    n_kept = config.batch_size * 3
    small_db.scenedb = small_db.scenedb[:n_kept]

    # NOTE: the PCA table build
    # (AllCategoriesTables.run_PCAs_and_build_nntables_in_feature_space)
    # is disabled here.
    trainer = SynthesisTrainer(config)
    trainer.train(small_db, small_db, small_db)
Пример #7
0
def composites_demo(config):
    """Run ``PuzzleTrainer.sample_demo`` on the bundled example sentences.

    Dumps the shape vectors of the training split when the patch-feature
    directory is missing, builds the per-category tables with
    ``AllCategoriesTables``, loads the sentences from
    'examples/composites_samples.json' and passes them to
    ``trainer.sample_demo``. The elapsed time of both phases is printed.

    Args:
        config: project configuration object; ``use_patch_background`` is
            read here and the object is forwarded to ``composites_coco``.
    """
    traindb = composites_coco(config, 'train', '2017')
    trainer = PuzzleTrainer(traindb)
    t0 = time()

    # The parentheses matter: a conditional expression binds looser than
    # ``+``, so without them the False branch evaluated to just 'without_bg'.
    # NOTE(review): assumes the dump directory is named
    # 'patch_feature_without_bg' in that case -- confirm against
    # dump_shape_vectors.
    patch_dir_name = 'patch_feature_' + (
        'with_bg' if config.use_patch_background else 'without_bg')
    if not osp.exists(osp.join(traindb.root_dir, patch_dir_name)):
        trainer.dump_shape_vectors(traindb)

    all_tables = AllCategoriesTables(traindb)
    all_tables.build_nntables_for_all_categories(True)
    print("NN completes (time %.2fs)" % (time() - t0))

    t0 = time()
    input_sentences = json_load('examples/composites_samples.json')
    trainer.sample_demo(input_sentences, all_tables)
    print("Sampling completes (time %.2fs)" % (time() - t0))
Пример #8
0
def test_syn_encoder(config):
    """Run one batch through ``SynthesisEncoder`` and print each output size.

    The encoder's parameter count is printed first.
    """
    encoder = SynthesisEncoder(config)
    print(get_n_params(encoder))

    train_db = composites_coco(config, 'train', '2017')
    dataset = synthesis_loader(train_db)
    batches = DataLoader(dataset,
                         batch_size=1,
                         shuffle=False,
                         num_workers=config.num_workers)

    started_at = time()  # timestamp is taken but never reported
    for batch in batches:
        volume = batch['input_vol'].float()
        for feature in encoder(volume):
            print(feature.size())
        # Only the first batch is inspected.
        break
Пример #9
0
def test_perceptual_loss_network(config):
    """Run one ground-truth image through ``VGG19LossNetwork`` in eval mode.

    Prints the network's parameter count and the size of every output it
    returns for the first batch.
    """
    loss_net = VGG19LossNetwork(config).eval()
    print(get_n_params(loss_net))

    train_db = composites_coco(config, 'train', '2017')
    dataset = synthesis_loader(train_db)
    batches = DataLoader(dataset,
                         batch_size=1,
                         shuffle=False,
                         num_workers=config.num_workers)

    started_at = time()  # timestamp is taken but never reported
    for batch in batches:
        gt_image = batch['gt_image'].float()
        # Move the last axis to position 1 before feeding the network.
        channels_first = gt_image.permute(0, 3, 1, 2)
        for feature in loss_net(channels_first):
            print(feature.size())
        # Only the first batch is inspected.
        break
Пример #10
0
def test_syn_decoder(config):
    """Chain ``SynthesisEncoder`` and ``SynthesisDecoder`` on one batch.

    Prints both parameter counts, then the sizes of the decoded image and
    label for the first batch.
    """
    encoder = SynthesisEncoder(config)
    decoder = SynthesisDecoder(config)
    for module in (encoder, decoder):
        print(get_n_params(module))

    train_db = composites_coco(config, 'train', '2017')
    dataset = synthesis_loader(train_db)
    batches = DataLoader(dataset,
                         batch_size=1,
                         shuffle=False,
                         num_workers=config.num_workers)

    started_at = time()  # timestamp is taken but never reported
    for batch in batches:
        volume = batch['input_vol'].float()
        # The unpacking requires exactly seven encoder outputs.
        f0, f1, f2, f3, f4, f5, f6 = encoder(volume)
        image, label = decoder((f0, f1, f2, f3, f4, f5, f6))
        print(image.size(), label.size())
        # Only the first batch is inspected.
        break
Пример #11
0
def test_vol_encoder(config):
    """Run one background batch through ``VolumeEncoder`` and print its size.

    The encoder's parameter count is printed first.
    """
    train_db = composites_coco(config, 'train', '2017')

    tables = AllCategoriesTables(train_db)
    tables.build_nntables_for_all_categories(True)
    dataset = sequence_loader(train_db, tables)

    encoder = VolumeEncoder(config)
    print(get_n_params(encoder))
    # print(encoder)

    batches = DataLoader(dataset,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    for batch in batches:
        encoded = encoder(batch['background'].float())
        print('y.size()', encoded.size())
        # Only the first batch is inspected.
        break
Пример #12
0
def test_syn_model(config):
    """Run one batch through ``SynthesisModel`` and print the output sizes.

    Prints the parameter count, the sizes of the synthesized image and
    label, then the sizes of the two feature lists the model returns,
    separated by a dashed line.
    """
    model = SynthesisModel(config)
    print(get_n_params(model))

    train_db = composites_coco(config, 'train', '2017')
    dataset = synthesis_loader(train_db)
    batches = DataLoader(dataset,
                         batch_size=1,
                         shuffle=False,
                         num_workers=config.num_workers)

    started_at = time()  # timestamp is taken but never reported
    for batch in batches:
        volume = batch['input_vol'].float()
        gt_image = batch['gt_image'].float()
        gt_label = batch['gt_label'].long()  # loaded but not passed on
        gt_image = gt_image.permute(0, 3, 1, 2)

        image, label, syn_feats, gt_feats = model(volume, True, gt_image)
        print(image.size(), label.size())
        for feature in syn_feats:
            print(feature.size())
        print('------------')
        for feature in gt_feats:
            print(feature.size())
        # Only the first batch is inspected.
        break
Пример #13
0
def test_puzzle_model(config):
    """Run one batch through ``PuzzleModel`` in eval mode and print stats.

    Creates the 'test_puzzle_model' directory under ``config.model_dir``,
    feeds the first batch through the network and prints the sizes of the
    returned logits, vectors and features, followed by the size of the
    vectors returned by ``net.collect_logits_and_vectors``.
    """
    out_dir = osp.join(config.model_dir, 'test_puzzle_model')
    maybe_create(out_dir)
    plt.switch_backend('agg')

    train_db = composites_coco(config, 'train', '2017')
    tables = AllCategoriesTables(train_db)
    tables.build_nntables_for_all_categories(True)
    dataset = sequence_loader(train_db, tables)
    batches = DataLoader(dataset,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    net = PuzzleModel(train_db)
    net.eval()

    for batch in batches:
        word_inds = batch['word_inds'].long()
        word_lens = batch['word_lens'].long()
        bg_images = batch['background'].float()
        fg_images = batch['foreground'].float()
        neg_images = batch['negative'].float()

        fg_resnets = batch['foreground_resnets'].float()
        neg_resnets = batch['negative_resnets'].float()

        fg_inds = batch['fg_inds'].long()
        gt_inds = batch['out_inds'].long()
        gt_msks = batch['out_msks'].float()  # loaded but not used below

        fg_onehots = indices2onehots(fg_inds, config.output_vocab_size)

        net_inputs = (word_inds, word_lens, bg_images, fg_onehots,
                      fg_images, neg_images, fg_resnets, neg_resnets)
        inf_outs, _, positive_feats, negative_feats = net(net_inputs)
        (obj_logits, coord_logits, attri_logits, patch_vectors,
         enc_msks, what_wei, where_wei) = inf_outs

        print('teacher forcing')
        size_report = (
            ('obj_logits ', obj_logits),
            ('coord_logits ', coord_logits),
            ('attri_logits ', attri_logits),
            ('patch_vectors ', patch_vectors),
        )
        for label, tensor in size_report:
            print(label, tensor.size())
        print('patch_vectors max:', torch.max(patch_vectors))
        print('patch_vectors min:', torch.min(patch_vectors))
        patch_norms = torch.norm(patch_vectors, dim=-2)
        print('patch_vectors norm:', patch_norms[0, 0, 0])
        print('positive_feats ', positive_feats.size())
        print('negative_feats ', negative_feats.size())
        if config.what_attn:
            print('what_att_logits ', what_wei.size())
        if config.where_attn > 0:
            print('where_att_logits ', where_wei.size())
        print('----------------------')

        _, pred_vecs = net.collect_logits_and_vectors(inf_outs, gt_inds)
        print('pred_vecs', pred_vecs.size())
        print('*******************')

        # NOTE: a disabled scheduled-sampling path (net.inference followed by
        # env.batch_redraw and a matplotlib dump of the redrawn sequences
        # into the output directory) used to follow here; recover it from
        # version control if it is needed again.

        # Only the first batch is inspected.
        break
Пример #14
0
def train_synthesis_model(config):
    """Train ``SynthesisTrainer`` with the 'train', 'val' and 'test' splits."""
    train_split, val_split, test_split = (
        composites_coco(config, split, '2017')
        for split in ('train', 'val', 'test'))
    SynthesisTrainer(config).train(train_split, val_split, test_split)
Пример #15
0
def train_puzzle_model(config):
    """Train ``PuzzleTrainer`` with the 'train', 'val' and 'test' splits.

    The trainer itself is constructed from the training split.
    """
    train_split, val_split, test_split = (
        composites_coco(config, split, '2017')
        for split in ('train', 'val', 'test'))
    PuzzleTrainer(train_split).train(train_split, val_split, test_split)
Пример #16
0
def _sequence_tensor_to_vols(seq_tensor):
    """Apply ``tensors_to_vols`` to a 5-D tensor, keeping its first two axes.

    The input is unpacked as (bsize, ssize, n, h, w); the first two axes are
    merged for ``tensors_to_vols`` and the result is reshaped to
    (bsize, ssize, h, w, n).
    """
    bsize, ssize, n, h, w = seq_tensor.size()
    vols = tensors_to_vols(seq_tensor.view(bsize * ssize, n, h, w))
    return vols.reshape(bsize, ssize, h, w, n)


def test_coco_dataloader(config):
    """Iterate four shuffled batches of ``sequence_loader`` and dump figures.

    For each batch the sizes (and a few sample values) of the batched
    entries are printed. For every sample a 4x4 matplotlib figure is saved
    under ``<config.model_dir>/test_coco_dataloader``: up to 15 panels, each
    showing the collage of one background step, its non-zero mask and the
    foreground/negative patch pair, with the image read from
    ``batched['image_path']`` in the last panel. The total elapsed time is
    printed at the end.

    Args:
        config: project configuration object; ``model_dir``, ``batch_size``
            and ``num_workers`` are read here.
    """
    db = composites_coco(config, 'train')

    all_tables = AllCategoriesTables(db)
    all_tables.build_nntables_for_all_categories(True)

    sequence_db = sequence_loader(db, all_tables)
    output_dir = osp.join(config.model_dir, 'test_coco_dataloader')
    maybe_create(output_dir)

    loader = DataLoader(sequence_db,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers)

    size_keys = ('word_inds', 'word_lens', 'fg_inds', 'patch_inds',
                 'out_inds', 'out_msks', 'foreground_resnets',
                 'negative_resnets')

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['background'].float()
        y = batched['foreground'].float()
        z = batched['negative'].float()

        x = sequence_color_volumn_preprocess(x, len(db.classes))
        y = sequence_onehot_volumn_preprocess(y, len(db.classes))
        z = sequence_onehot_volumn_preprocess(z, len(db.classes))

        # Shift by 128 and move the last axis in front of the two spatial
        # axes, which is the layout _sequence_tensor_to_vols unpacks.
        x = (x - 128.0).permute(0, 1, 4, 2, 3)
        y = (y - 128.0).permute(0, 1, 4, 2, 3)
        z = (z - 128.0).permute(0, 1, 4, 2, 3)

        print('background', x.size())
        print('foreground', y.size())
        print('negative', z.size())
        for key in size_keys:
            print(key, batched[key].size())

        print('foreground_resnets', batched['foreground_resnets'][0, 0])
        print('negative_resnets', batched['negative_resnets'][0, 0])
        print('out_msks', batched['out_msks'][0])
        print('patch_inds', batched['patch_inds'][0])

        plt.switch_backend('agg')
        bsize = x.size(0)
        bg_images = _sequence_tensor_to_vols(x)
        fg_images = _sequence_tensor_to_vols(y)
        neg_images = _sequence_tensor_to_vols(z)

        for i in range(bsize):
            bg_seq = bg_images[i]
            fg_seq = fg_images[i]
            neg_seq = neg_images[i]
            image_idx = batched['image_index'][i]
            name = '%03d_' % i + str(image_idx).zfill(12)
            out_path = osp.join(output_dir, name + '.png')
            color = cv2.imread(batched['image_path'][i], cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)

            fig = plt.figure(figsize=(48, 32))
            plt.suptitle(batched['sentence'][i], fontsize=30)

            # At most 15 steps are drawn; panel 16 holds the reference image.
            for j in range(min(len(bg_seq), 15)):
                bg, _ = heuristic_collage(bg_seq[j], 83)
                # The mask must be (rows, cols) like ``bg``: it is indexed
                # with the (row, col) pairs found in ``bg`` and concatenated
                # with ``bg`` along axis 1. The former (shape[1], shape[0])
                # allocation only worked for square collages.
                bg_mask = 255 * np.ones((bg.shape[0], bg.shape[1]))
                row, col = np.where(np.sum(np.absolute(bg), -1) == 0)
                bg_mask[row, col] = 0
                bg_mask = np.repeat(bg_mask[..., None], 3, -1)

                # Last three channels are used as colour, the fourth from
                # the end as the mask.
                fg_color = fg_seq[j][:, :, -3:]
                fg_mask = fg_seq[j][:, :, -4]
                neg_color = neg_seq[j][:, :, -3:]
                neg_mask = neg_seq[j][:, :, -4]

                color_pair = np.concatenate((fg_color, neg_color), 1)
                mask_pair = np.concatenate((fg_mask, neg_mask), 1)
                mask_pair = np.repeat(mask_pair[..., None], 3, -1)
                patch = np.concatenate((color_pair, mask_pair), 0)
                # cv2.resize takes (width, height).
                patch = cv2.resize(patch, (bg.shape[1], bg.shape[0]))

                partially_completed_img = np.concatenate((bg, bg_mask, patch),
                                                         1)
                partially_completed_img = clamp_array(partially_completed_img,
                                                      0, 255).astype(np.uint8)
                # Reverse the channel order before handing it to matplotlib.
                partially_completed_img = partially_completed_img[:, :, ::-1]
                plt.subplot(4, 4, j + 1)
                plt.imshow(partially_completed_img)
                plt.axis('off')

            plt.subplot(4, 4, 16)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

        # Stop after the fourth batch.
        if cnt == 3:
            break
    print("Time", time() - start)
Пример #17
0
def test_syn_dataloader(config):
    """Iterate two batches of ``synthesis_loader`` and dump one figure each.

    For every sample a 2x3 mosaic (proposal, mask, person / ground-truth
    colour, ground-truth label, other) is saved under
    ``<config.model_dir>/test_syn_dataloader``. Which channels feed each
    tile depends on ``config.use_color_volume``. The total elapsed time is
    printed at the end.
    """
    train_db = composites_coco(config, 'train', '2017')

    dataset = synthesis_loader(train_db)
    out_dir = osp.join(config.model_dir, 'test_syn_dataloader')
    maybe_create(out_dir)

    batches = DataLoader(dataset,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    began = time()
    for step, batch in enumerate(batches):
        vols = batch['input_vol'].float()
        gt_images = batch['gt_image'].float()
        gt_labels = batch['gt_label'].float()

        if config.use_color_volume:
            preprocess = batch_color_volumn_preprocess
        else:
            preprocess = batch_onehot_volumn_preprocess
        vols = preprocess(vols, len(train_db.classes))
        print('input_vol', vols.size())
        print('gt_image', gt_images.size())
        print('gt_label', gt_labels.size())

        # NOTE: a set of disabled cv2.imwrite debug dumps of the mask,
        # label and colour channels of the first four samples used to sit
        # here.

        vols = (vols - 128.0).permute(0, 3, 1, 2)

        plt.switch_backend('agg')
        vols = tensors_to_vols(vols)
        for i in range(vols.shape[0]):
            image_idx = batch['image_index'][i]
            stem = '%03d_' % i + str(image_idx).zfill(12)
            target = osp.join(out_dir, stem + '.png')

            vol = vols[i]
            if config.use_color_volume:
                # Three-channel slices are used directly.
                proposal = vol[:, :, 12:15]
                mask = vol[:, :, :3]
                person = vol[:, :, 9:12]
                other = vol[:, :, 15:18]
            else:
                # Single channels are repeated three times along a new
                # last axis.
                proposal = vol[:, :, -3:]
                mask = np.repeat(vol[:, :, -4][..., None], 3, -1)
                person = np.repeat(vol[:, :, 3][..., None], 3, -1)
                other = np.repeat(vol[:, :, 5][..., None], 3, -1)
            gt_color = gt_images[i]
            gt_label = np.repeat(gt_labels[i][..., None], 3, -1)

            top_row = np.concatenate((proposal, mask, person), 1)
            bottom_row = np.concatenate((gt_color, gt_label, other), 1)
            mosaic = np.concatenate((top_row, bottom_row), 0).astype(np.uint8)

            fig = plt.figure(figsize=(32, 32))
            plt.imshow(mosaic[:, :, :])
            plt.axis('off')

            fig.savefig(target, bbox_inches='tight')
            plt.close(fig)

        # Stop after the second batch.
        if step == 1:
            break
    print("Time", time() - began)
Пример #18
0
def test_synthesis_model(config):
    """Run ``SynthesisTrainer.sample_for_eval`` on the 'test' split."""
    # we use the official validation set as test set
    eval_db = composites_coco(config, 'test', '2017')
    SynthesisTrainer(config).sample_for_eval(eval_db)