Example #1
0
def test_txt_encoder_coco(config):
    """Smoke-test ``TextEncoder`` on the first batch of the composites train split.

    Builds the per-category tables, wraps the database with
    ``sequence_loader``, runs the encoder on one batch and prints the
    sizes of its outputs. It then prints the summed absolute difference
    between the embedding of the first token and the matching row of
    ``db.lang_vocab.vectors`` (expected to be zero).
    """
    db = composites_coco(config, 'train', '2017')
    nn_tables = AllCategoriesTables(db)
    nn_tables.build_nntables_for_all_categories(True)
    seq_data = sequence_loader(db, nn_tables)
    encoder = TextEncoder(db)

    batches = DataLoader(seq_data,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    # Only the first batch is inspected.
    for batch in batches:
        word_inds = batch['word_inds'].long()
        word_lens = batch['word_lens'].long()

        print('Checking the output shapes')
        out_embs, out_rfts, out_msks, out_hids = encoder(word_inds, word_lens)
        print(out_rfts.size(), out_embs.size(), out_msks.size())
        # The hidden state may come back as a tuple; report its first item.
        shown_hidden = out_hids[0] if isinstance(out_hids, tuple) else out_hids
        print(shown_hidden.size())
        print('m: ', out_msks[-1])

        print('Checking the embedding')
        looked_up = encoder.embedding(word_inds)[0, 0]
        vocab_idx = word_inds[0, 0].data.item()
        reference = db.lang_vocab.vectors[vocab_idx]
        delta = reference - looked_up
        print('Diff: (should be zero)', torch.sum(delta.abs_()))

        break
Example #2
0
def test_txt_encoder_coco(config):
    """Smoke-test ``TextEncoder`` on the first batch of the COCO train split.

    The database is created with the 'background' image normalisation
    transform, the PCA / feature-space tables are built, and one batch is
    pushed through the encoder. Output sizes are printed, followed by the
    summed absolute difference between the embedding of the first token
    and the matching row of ``db.lang_vocab.vectors`` (expected zero).
    """
    normalizer = image_normalize('background')
    db = coco(config, 'train', transform=normalizer)
    tables = AllCategoriesTables(db)
    tables.run_PCAs_and_build_nntables_in_feature_space()
    encoder = TextEncoder(db)

    batches = DataLoader(db,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    # Only the first batch is inspected.
    for batch in batches:
        word_inds = batch['word_inds'].long()
        word_lens = batch['word_lens'].long()

        print('Checking the output shapes')
        out_embs, out_rfts, out_msks, out_hids = encoder(word_inds, word_lens)
        print(out_rfts.size(), out_embs.size(), out_msks.size())
        # The hidden state may come back as a tuple; report its first item.
        shown_hidden = out_hids[0] if isinstance(out_hids, tuple) else out_hids
        print(shown_hidden.size())
        print('m: ', out_msks[-1])

        print('Checking the embedding')
        looked_up = encoder.embedding(word_inds)[0, 0]
        vocab_idx = word_inds[0, 0].data.item()
        reference = db.lang_vocab.vectors[vocab_idx]
        delta = reference - looked_up
        print('Diff: (should be zero)', torch.sum(delta.abs_()))

        break
Example #3
0
def puzzle_model_inference(config):
    """Sample from a ``PuzzleTrainer`` on the composites 'test' and 'aux' splits.

    Steps:
      1. Dump shape vectors for the train split if the patch-feature
         directory does not exist yet.
      2. Build the per-category tables on the train split.
      3. Either sample for visualisation (``config.for_visualization``)
         or sample for evaluation, on both the test and aux databases.

    Args:
        config: project configuration; ``use_patch_background`` and
            ``for_visualization`` are read here.
    """
    traindb = composites_coco(config, 'train', '2017')
    # The 'test' split is bound to the name used by every call below
    # (the original bound it to `valdb` and then referenced an undefined
    # `testdb`, which raised NameError).
    testdb = composites_coco(config, 'test', '2017')
    auxdb = composites_coco(config, 'aux', '2017')
    trainer = PuzzleTrainer(traindb)
    t0 = time()

    # A conditional expression binds more loosely than `+`, so the choice of
    # suffix must be parenthesised; otherwise the else-branch yields the bare
    # string 'without_bg' with no 'patch_feature_' prefix.
    # NOTE(review): presumably this matches the directory that
    # dump_shape_vectors writes to -- confirm against patch_path_from_indices.
    patch_dir_name = 'patch_feature_' + (
        'with_bg' if config.use_patch_background else 'without_bg')
    if not osp.exists(osp.join(traindb.root_dir, patch_dir_name)):
        trainer.dump_shape_vectors(traindb)

    all_tables = AllCategoriesTables(traindb)
    all_tables.build_nntables_for_all_categories(True)
    print("NN completes (time %.2fs)" % (time() - t0))
    t0 = time()
    if config.for_visualization:
        trainer.sample_for_vis(0,
                               testdb,
                               len(testdb.scenedb),
                               nn_table=all_tables)
        trainer.sample_for_vis(0,
                               auxdb,
                               len(auxdb.scenedb),
                               nn_table=all_tables)
    else:
        trainer.sample_for_eval(testdb, nn_table=all_tables)
        trainer.sample_for_eval(auxdb, nn_table=all_tables)
    print("Sampling completes (time %.2fs)" % (time() - t0))
Example #4
0
 def __init__(self, db, batch_size=None, nn_table=None):
     """Store the database, its config, the batch size and a table object.

     Args:
         db: database object exposing a ``cfg`` attribute.
         batch_size: optional batch size; falls back to ``cfg.batch_size``
             when ``None``.
         nn_table: optional pre-built table object; when ``None`` a new
             ``AllCategoriesTables`` is created and built for ``db``.
     """
     self.db = db
     self.cfg = db.cfg
     if batch_size is None:
         batch_size = self.cfg.batch_size
     self.batch_size = batch_size
     if nn_table is not None:
         # Reuse the tables supplied by the caller.
         self.nn_table = nn_table
         return
     self.nn_table = AllCategoriesTables(db)
     self.nn_table.build_nntables_for_all_categories()
Example #5
0
def test_nntable(config):
    """Build the per-category tables for the 'test' split and print the time taken.

    Also makes sure the ``test_nntable`` directory exists under
    ``config.model_dir``.
    """
    db = coco(config, 'test')
    maybe_create(osp.join(config.model_dir, 'test_nntable'))

    started = time()
    tables = AllCategoriesTables(db)
    tables.build_nntables_for_all_categories(False)
    elapsed = time() - started
    print("NN completes (time %.2fs)" % elapsed)
Example #6
0
def test_coco_decoder(config):
    """Smoke-test the what/where decoders on one batch of the composites train split.

    Prints the parameter count of each sub-network, then runs the text
    encoder, the volume encoder, ``WhatDecoder`` and ``WhereDecoder`` on
    the first batch and prints the size of every decoder output.
    """
    db = composites_coco(config, 'train', '2017')
    nn_tables = AllCategoriesTables(db)
    nn_tables.build_nntables_for_all_categories(True)
    seq_data = sequence_loader(db, nn_tables)

    text_encoder = TextEncoder(db)
    img_encoder = VolumeEncoder(config)
    what_decoder = WhatDecoder(config)
    where_decoder = WhereDecoder(config)

    for label, module in (('txt_encoder', text_encoder),
                          ('img_encoder', img_encoder),
                          ('what_decoder', what_decoder),
                          ('where_decoder', where_decoder)):
        print(label, get_n_params(module))

    batches = DataLoader(seq_data,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    # Only the first batch is inspected.
    for batch in batches:
        word_inds = batch['word_inds'].long()
        word_lens = batch['word_lens'].long()
        bg_imgs = batch['background'].float()

        encoder_states = text_encoder(word_inds, word_lens)
        bg_feats = img_encoder(bg_imgs)
        # Feed only the first background feature step to the what-decoder.
        first_bgfs = bg_feats[:, 0].unsqueeze(1)
        what_outs = what_decoder((first_bgfs, None, None), encoder_states)
        (obj_logits, rnn_feats_2d, nxt_hids_2d,
         prev_bgfs, att_ctx, att_wei) = what_outs

        rule = '------------------------------------------'
        print(rule)
        for label, tensor in (('obj_logits', obj_logits),
                              ('rnn_feats_2d', rnn_feats_2d),
                              ('nxt_hids_2d', nxt_hids_2d),
                              ('prev_bgfs', prev_bgfs),
                              ('att_ctx', att_ctx),
                              ('att_wei', att_wei)):
            print(label, tensor.size())
        print(rule)

        # Arg-max over the last dimension picks the predicted object index.
        obj_inds = torch.max(obj_logits + 1.0, dim=-1)[1]
        curr_fgfs = indices2onehots(obj_inds.cpu().data,
                                    config.output_vocab_size)
        if config.cuda:
            curr_fgfs = curr_fgfs.cuda()

        where_inputs = (rnn_feats_2d, curr_fgfs, prev_bgfs, att_ctx)
        where_outs = where_decoder(where_inputs, encoder_states)
        (coord_logits, attri_logits, patch_vectors,
         where_ctx, where_wei) = where_outs
        for label, tensor in (('coord_logits ', coord_logits),
                              ('attri_logits ', attri_logits),
                              ('patch_vectors', patch_vectors)):
            print(label, tensor.size())
        # where_ctx / where_wei are unpacked but intentionally not printed.
        break
Example #7
0
def embedding_model_inference_preparation(config):
    """Dump shape vectors for the train split and build its per-category tables.

    Both stages print their elapsed wall-clock time. The 'val' database is
    constructed as in the original code even though it is not used here.
    """
    traindb = coco(config, 'train')
    valdb   = coco(config, 'val')
    trainer = EmbeddingTrainer(traindb)

    started = time()
    trainer.dump_shape_vectors(traindb)
    elapsed = time() - started
    print("Dump shape vectors completes (time %.2fs)" % elapsed)

    started = time()
    tables = AllCategoriesTables(traindb)
    tables.build_nntables_for_all_categories(False)
    elapsed = time() - started
    print("NN completes (time %.2fs)" % elapsed)
Example #8
0
def embedding_model_inference(config):
    """Build the train-split tables, then sample from ``EmbeddingTrainer`` on 'val'.

    Both stages print their elapsed wall-clock time.
    """
    traindb = coco(config, 'train')
    valdb   = coco(config, 'val')
    trainer = EmbeddingTrainer(traindb)

    started = time()
    tables = AllCategoriesTables(traindb)
    tables.build_nntables_for_all_categories(True)
    elapsed = time() - started
    print("NN completes (time %.2fs)" % elapsed)

    started = time()
    trainer.sample(0, valdb, 50, random_or_not=False)
    elapsed = time() - started
    print("Sampling completes (time %.2fs)" % elapsed)
def train_scene_model(config):
    """Train a ``SupervisedTrainer`` on the COCO splits.

    All three splits share the 'background' normalisation transform. The
    PCA / feature-space tables are built on the train split before
    training starts.
    """
    normalizer = image_normalize('background')
    traindb, valdb, testdb = (
        coco(config, split, transform=normalizer)
        for split in ('train', 'val', 'test')
    )
    tables = AllCategoriesTables(traindb)
    tables.run_PCAs_and_build_nntables_in_feature_space()

    trainer = SupervisedTrainer(traindb)
    # The official validation set serves as the test set, hence the
    # (testdb, valdb) argument order.
    trainer.train(traindb, testdb, valdb)
Example #10
0
def puzzle_model_inference_preparation(config):
    """Dump shape vectors for train and test, then build the train-split tables.

    Both stages print their elapsed wall-clock time.
    """
    traindb = coco(config, 'train', '2017')
    testdb = coco(config, 'test', '2017')
    trainer = PuzzleTrainer(traindb)

    started = time()
    for split_db in (traindb, testdb):
        trainer.dump_shape_vectors(split_db)
    elapsed = time() - started
    print("Dump shape vectors completes (time %.2fs)" % elapsed)

    started = time()
    tables = AllCategoriesTables(traindb)
    tables.build_nntables_for_all_categories(False)
    elapsed = time() - started
    print("NN completes (time %.2fs)" % elapsed)
Example #11
0
def composites_demo(config):
    """Run the composites demo on sentences from ``examples/composites_samples.json``.

    Dumps shape vectors for the train split when the patch-feature
    directory is missing, builds the per-category tables, then calls
    ``trainer.sample_demo`` with the loaded sentences. Both stages print
    their elapsed wall-clock time.

    Args:
        config: project configuration; ``use_patch_background`` is read here.
    """
    traindb = composites_coco(config, 'train', '2017')
    trainer = PuzzleTrainer(traindb)
    t0 = time()

    # A conditional expression binds more loosely than `+`, so the choice of
    # suffix must be parenthesised; otherwise the else-branch yields the bare
    # string 'without_bg' with no 'patch_feature_' prefix.
    # NOTE(review): presumably this matches the directory that
    # dump_shape_vectors writes to -- confirm against patch_path_from_indices.
    patch_dir_name = 'patch_feature_' + (
        'with_bg' if config.use_patch_background else 'without_bg')
    if not osp.exists(osp.join(traindb.root_dir, patch_dir_name)):
        trainer.dump_shape_vectors(traindb)

    all_tables = AllCategoriesTables(traindb)
    all_tables.build_nntables_for_all_categories(True)
    print("NN completes (time %.2fs)" % (time() - t0))
    t0 = time()
    input_sentences = json_load('examples/composites_samples.json')
    trainer.sample_demo(input_sentences, all_tables)
    print("Sampling completes (time %.2fs)" % (time() - t0))
def overfit_scene_model(config):
    """Train ``SupervisedTrainer`` on databases truncated to one batch each.

    ``config.log_per_steps`` is forced to 1 and every split's ``scenedb``
    is cut down to its first ``config.batch_size`` entries before the PCA
    tables are built and training starts.
    """
    config.log_per_steps = 1
    normalizer = image_normalize('background')
    traindb = coco(config, 'train', transform=normalizer)
    valdb = coco(config, 'val', transform=normalizer)
    testdb = coco(config, 'test', transform=normalizer)

    # Keep only a single batch worth of scenes per split.
    for split_db in (traindb, valdb, testdb):
        split_db.scenedb = split_db.scenedb[:config.batch_size]

    print('build pca table')
    tables = AllCategoriesTables(traindb)
    tables.run_PCAs_and_build_nntables_in_feature_space()

    print('create trainer')
    trainer = SupervisedTrainer(traindb)

    print('start training')
    # Alternative kept from the original: trainer.train(traindb, traindb, traindb)
    trainer.train(traindb, valdb, testdb)
Example #13
0
def test_vol_encoder(config):
    """Smoke-test ``VolumeEncoder`` on the first batch of the composites train split.

    Prints the encoder's parameter count and the size of its output for
    the batch's 'background' tensor.
    """
    db = composites_coco(config, 'train', '2017')

    nn_tables = AllCategoriesTables(db)
    nn_tables.build_nntables_for_all_categories(True)
    seq_data = sequence_loader(db, nn_tables)

    encoder = VolumeEncoder(config)
    print(get_n_params(encoder))

    batches = DataLoader(seq_data,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    # Only the first batch is inspected.
    for batch in batches:
        encoded = encoder(batch['background'].float())
        print('y.size()', encoded.size())
        break
Example #14
0
def puzzle_model_inference(config):
    """Sample from a ``PuzzleTrainer`` after building the train-split tables.

    With ``config.for_visualization`` set, visual samples are produced for
    both the 'val' and 'aux' databases; otherwise only the 'val' database
    is sampled for evaluation. Both stages print their elapsed time.
    """
    traindb = coco(config, 'train', '2017')
    valdb = coco(config, 'val', '2017')
    auxdb = coco(config, 'aux', '2017')
    trainer = PuzzleTrainer(traindb)

    started = time()
    tables = AllCategoriesTables(traindb)
    tables.build_nntables_for_all_categories(True)
    elapsed = time() - started
    print("NN completes (time %.2fs)" % elapsed)

    started = time()
    if not config.for_visualization:
        trainer.sample_for_eval(valdb, nn_table=tables)
        # Evaluation on auxdb is disabled in the original code:
        # trainer.sample_for_eval(auxdb, nn_table=tables)
    else:
        for split_db in (valdb, auxdb):
            trainer.sample_for_vis(0,
                                   split_db,
                                   len(split_db.scenedb),
                                   nn_table=tables)
    elapsed = time() - started
    print("Sampling completes (time %.2fs)" % elapsed)
Example #15
0
def test_shape_encoder(config):
    """Smoke-test ``ShapeEncoder`` on the first batch of the COCO train split.

    Prints the encoder's parameter count, then the size, maximum, minimum
    and the norm (over the last dimension, element ``[0, 0]``) of its
    output for the batch's foreground inputs.
    """
    db = coco(config, 'train', '2017')
    nn_tables = AllCategoriesTables(db)
    nn_tables.build_nntables_for_all_categories(True)
    seq_data = sequence_loader(db, nn_tables)

    encoder = ShapeEncoder(config)
    print(get_n_params(encoder))

    batches = DataLoader(seq_data,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    # Only the first batch is inspected.
    for batch in batches:
        fg_vols = batch['foreground'].float()
        fg_resnets = batch['foreground_resnets'].float()
        encoded = encoder(fg_vols, fg_resnets)
        print('y.size()', encoded.size())
        print('y max', torch.max(encoded))
        print('y min', torch.min(encoded))
        print('y norm', torch.norm(encoded, dim=-1)[0, 0])
        break
Example #16
0
def test_step_by_step(config):
    """Render the ground-truth sequences of one batch step by step and save figures.

    For the first (shuffled) batch, every time step is rendered through
    the simulator. For each sample a 4x4 figure is saved under
    ``<model_dir>/test_step_by_step``: up to 14 partially completed
    images, with the squared original image in the last cell.
    """
    db = coco(config, 'train', '2017')
    output_dir = osp.join(config.model_dir, 'test_step_by_step')
    maybe_create(output_dir)

    nn_tables = AllCategoriesTables(db)
    nn_tables.build_nntables_for_all_categories(True)

    seq_db = sequence_loader(db, nn_tables)
    env = simulator(db, config.batch_size, nn_tables)
    env.reset()

    batches = DataLoader(seq_db,
                         batch_size=config.batch_size,
                         shuffle=True,
                         num_workers=config.num_workers)

    # Only the first batch is rendered.
    for batch in batches:
        out_inds = batch['out_inds'].long().numpy()
        out_vecs = batch['out_vecs'].float().numpy()

        # One rendered frame batch per time step, stacked along dim 1.
        rendered = [
            env.batch_render_to_pytorch(out_inds[:, step], out_vecs[:, step])
            for step in range(out_inds.shape[1])
        ]
        sequences = torch.stack(rendered, dim=1)

        for i, sequence in enumerate(sequences):
            image_idx = batch['image_index'][i]
            name = '%03d_' % i + str(image_idx).zfill(12)
            out_path = osp.join(output_dir, name + '.png')
            color = cv2.imread(batch['image_path'][i], cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)

            fig = plt.figure(figsize=(32, 32))
            plt.suptitle(batch['sentence'][i], fontsize=30)

            for j, frame in enumerate(sequence[:14]):
                plt.subplot(4, 4, j + 1)
                frame_np = frame.cpu().data.numpy()
                if config.use_color_volume:
                    canvas, _ = heuristic_collage(frame_np, 83)
                else:
                    # Without a colour volume the last three channels are used.
                    canvas = frame_np[:, :, -3:]
                canvas = clamp_array(canvas, 0, 255).astype(np.uint8)
                # Reverse the channel order before handing to matplotlib.
                plt.imshow(canvas[:, :, ::-1])
                plt.axis('off')

            plt.subplot(4, 4, 16)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

        break
def test_scene_model(config):
    """Run ``SceneModel.inference`` on one batch and save the redrawn sequences.

    Prints the sizes of the inference outputs, then writes one 3x5 figure
    per sample to ``<model_dir>/test_scene_model``: up to 14 redrawn
    frames, with the squared original image in the last cell.
    """
    output_dir = osp.join(config.model_dir, 'test_scene_model')
    maybe_create(output_dir)
    plt.switch_backend('agg')

    normalizer = image_normalize('background')
    db = coco(config, 'train', transform=normalizer)
    tables = AllCategoriesTables(db)
    tables.run_PCAs_and_build_nntables_in_feature_space()

    batches = DataLoader(db,
                         batch_size=config.batch_size,
                         shuffle=False,
                         num_workers=config.num_workers)

    net = SceneModel(db)
    net.eval()

    # Only the first batch is processed.
    for batch in batches:
        word_inds = batch['word_inds'].long()
        word_lens = batch['word_lens'].long()
        bg_images = batch['background'].float()

        fg_inds = batch['fg_inds'].long()
        gt_inds = batch['out_inds'].long()
        gt_vecs = batch['out_vecs'].float()
        gt_msks = batch['out_msks'].float()

        fg_onehots = indices2onehots(fg_inds, config.output_vocab_size)

        # bg_images, gt_msks and fg_onehots are not used further below; the
        # original kept a disabled teacher-forcing check of the form
        #   inf_outs, _ = net((word_inds, word_lens, bg_images, fg_onehots))
        # and a disabled free-running variant
        #   net.inference(word_inds, word_lens, -1, 2.0, 0, None, None)
        inf_outs, env = net.inference(word_inds, word_lens, -1, 0, 0,
                                      gt_inds, gt_vecs)
        (obj_logits, coord_logits, attri_logits, pca_vectors,
         enc_msks, what_wei, where_wei) = inf_outs

        print('scheduled sampling')
        for label, tensor in (('obj_logits ', obj_logits),
                              ('coord_logits ', coord_logits),
                              ('attri_logits ', attri_logits),
                              ('pca_vectors ', pca_vectors)):
            print(label, tensor.size())
        if config.what_attn:
            print('what_att_logits ', what_wei.size())
        if config.where_attn > 0:
            print('where_att_logits ', where_wei.size())
        print('----------------------')

        sequences = env.batch_redraw(True)
        n_samples = len(sequences)
        for i in range(n_samples):
            sequence = sequences[i]
            image_idx = batch['image_index'][i]
            name = '%03d_'%i + str(image_idx).zfill(12)
            out_path = osp.join(output_dir, name+'.png')
            color = cv2.imread(batch['color_path'][i], cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)

            fig = plt.figure(figsize=(32, 16))
            plt.suptitle(batch['sentence'][i], fontsize=30)

            n_frames = min(len(sequence), 14)
            for j in range(n_frames):
                plt.subplot(3, 5, j+1)
                canvas = clamp_array(sequence[j], 0, 255).astype(np.uint8)
                # Reverse the channel order before handing to matplotlib.
                plt.imshow(canvas[:,:,::-1])
                plt.axis('off')

            plt.subplot(3, 5, 15)
            plt.imshow(color[:,:,::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

        break
Example #18
0
class PuzzleTrainer(object):
    def __init__(self, db):
        """Create the puzzle model and its optimizer from ``db.cfg``.

        The model is moved to the GPU when ``cfg.cuda`` is set and wrapped
        in ``nn.DataParallel`` when ``cfg.parallel`` is set and more than
        one GPU is visible. An Adam optimizer with one parameter group per
        sub-module is created (the text embedding uses ``cfg.finetune_lr``)
        and wrapped in ``Optimizer`` with gradient-norm clipping. If
        ``cfg.pretrained`` is not ``None`` the checkpoint is loaded.

        Args:
            db: database object exposing a ``cfg`` attribute.
        """
        self.db = db
        self.cfg = db.cfg
        self.net = PuzzleModel(db)
        if self.cfg.cuda:
            if self.cfg.parallel and torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

        # The model is only wrapped when more than one GPU is available, so
        # unwrap based on the actual type rather than on the config flags
        # (cuda + parallel on a single GPU used to raise AttributeError).
        if isinstance(self.net, nn.DataParallel):
            net = self.net.module
        else:
            net = self.net

        # Only parameters that require gradients are optimised for the
        # image encoder.
        image_encoder_trainable_paras = [
            p for p in net.image_encoder.parameters() if p.requires_grad
        ]
        raw_optimizer = optim.Adam(
            [
                {'params': image_encoder_trainable_paras},
                {
                    'params': net.text_encoder.embedding.parameters(),
                    'lr': self.cfg.finetune_lr,
                },
                {'params': net.text_encoder.rnn.parameters()},
                {'params': net.what_decoder.parameters()},
                {'params': net.where_decoder.parameters()},
                {'params': net.shape_encoder.parameters()},
            ],
            lr=self.cfg.lr)
        optimizer = Optimizer(raw_optimizer,
                              max_grad_norm=self.cfg.grad_norm_clipping)
        # No learning-rate scheduler is attached; the original kept
        # ReduceLROnPlateau / StepLR variants disabled:
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer.optimizer, factor=0.8, patience=3)
        # scheduler = optim.lr_scheduler.StepLR(optimizer.optimizer, step_size=3, gamma=0.8)
        # optimizer.set_scheduler(scheduler)
        self.optimizer = optimizer
        self.epoch = 0

        if self.cfg.pretrained is not None:
            self.load_pretrained_net(self.cfg.pretrained)

    def load_pretrained_net(self, pretrained_name):
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        cache_dir = osp.join(self.cfg.data_dir, 'caches')
        pretrained_path = osp.join(cache_dir, 'puzzle_ckpts',
                                   pretrained_name + '.pkl')
        assert osp.exists(pretrained_path)
        if self.cfg.cuda:
            states = torch.load(pretrained_path)
        else:
            states = torch.load(pretrained_path,
                                map_location=lambda storage, loc: storage)
        # states = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
        net.load_state_dict(states['state_dict'])
        self.optimizer.optimizer.load_state_dict(states['optimizer'])
        self.epoch = states['epoch']

    def batch_data(self, entry):
        """Convert one loader batch into typed tensors, on the GPU if enabled.

        Args:
            entry: mapping with the keys ``word_inds``, ``word_lens``,
                ``fg_inds``, ``background``, ``foreground``, ``negative``,
                ``foreground_resnets``, ``negative_resnets``, ``out_inds``,
                ``out_msks`` and ``patch_inds``.

        Returns:
            An 11-tuple ``(input_inds, input_lens, fg_onehots, bg_imgs,
            fg_imgs, neg_imgs, fg_resnets, neg_resnets, gt_inds, gt_msks,
            patch_inds)``. ``patch_inds`` is a NumPy array and is never
            moved to the GPU; the rest are tensors.
        """
        # Inputs
        word_inds = entry['word_inds'].long()
        word_lens = entry['word_lens'].long()
        fg_inds = entry['fg_inds'].long()
        backgrounds = entry['background'].float()
        foregrounds = entry['foreground'].float()
        negatives = entry['negative'].float()
        fg_resnets = entry['foreground_resnets'].float()
        neg_resnets = entry['negative_resnets'].float()
        fg_onehots = indices2onehots(fg_inds, self.cfg.output_vocab_size)

        # Outputs (entry['out_vecs'] is deliberately not read here)
        target_inds = entry['out_inds'].long()
        target_msks = entry['out_msks'].float()
        patch_inds = entry['patch_inds'].long().numpy()

        tensors = (word_inds, word_lens, fg_onehots, backgrounds,
                   foregrounds, negatives, fg_resnets, neg_resnets,
                   target_inds, target_msks)
        if self.cfg.cuda:
            tensors = tuple(t.cuda(non_blocking=True) for t in tensors)

        return tuple(tensors) + (patch_inds, )

    def dump_shape_vectors(self, train_db):
        """Encode every patch of ``train_db`` and pickle its feature vector.

        Patches are read through ``patch_vol_loader`` in batches of 512,
        encoded with ``shape_encoder.batch_forward``, flattened, and written
        to the path returned by ``train_db.patch_path_from_indices`` (kind
        ``'patch_feature'``, extension ``'pkl'``). A progress line is printed
        per batch.

        Args:
            train_db: database whose ``patchdb`` entries are encoded; despite
                the name, the caller may pass any split.
        """
        # Unwrap based on the actual type: the model is only wrapped in
        # DataParallel when several GPUs are present, so checking the config
        # flags alone raised AttributeError on single-GPU machines.
        if isinstance(self.net, nn.DataParallel):
            net = self.net.module
        else:
            net = self.net

        db = patch_vol_loader(train_db)
        loader = DataLoader(db,
                            batch_size=512,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

        for cnt, batched in enumerate(loader):
            start = time()
            patch_inds = batched['patch_ind'].long()
            patch_vols = batched['patch_vol'].float()
            patch_resnet = batched['patch_resnet'].float()
            if self.cfg.cuda:
                patch_vols = patch_vols.cuda(non_blocking=True)
                patch_resnet = patch_resnet.cuda(non_blocking=True)
            patch_features = net.shape_encoder.batch_forward(
                patch_vols, patch_resnet)
            patch_features = patch_features.cpu().data.numpy()
            for i in range(patch_vols.size(0)):
                patch = train_db.patchdb[patch_inds[i]]
                image_index = patch['image_index']
                instance_ind = patch['instance_ind']
                patch_feature_path = train_db.patch_path_from_indices(
                    image_index, instance_ind, 'patch_feature', 'pkl',
                    self.cfg.use_patch_background)
                # Make sure the target directory exists before writing.
                maybe_create(osp.dirname(patch_feature_path))
                features = patch_features[i].flatten()
                with open(patch_feature_path, 'wb') as fid:
                    pickle.dump(features, fid, pickle.HIGHEST_PROTOCOL)
            print('%s, current_ind: %d, time consumed: %f' %
                  (train_db.split, cnt, time() - start))

    def evaluate(self,
                 inf_outs,
                 pos_vecs,
                 neg_vecs,
                 ref_inds,
                 ref_msks,
                 db=None,
                 patch_inds=None):
        """Compute the losses and accuracies for one batch of model outputs.

        Args:
            inf_outs: 7-tuple of model outputs; elements 5-6 (0-based 4-5)
                are the encoder masks and the "what" attention weights.
            pos_vecs: positive patch vectors for the triplet loss; reshaped
                to ``(bsize * slen, vsize)``.
            neg_vecs: negative patch vectors, same layout as ``pos_vecs``.
            ref_inds: reference indices passed to the model's
                ``collect_logits_and_vectors`` / ``collect_accuracies``.
            ref_msks: masks multiplied element-wise with the 4-component
                prediction loss; component 0 masks the attention loss and
                component 1 the embedding loss.
            db: optional database; together with ``patch_inds`` it enables
                dumping of the predicted vectors.
            patch_inds: optional ``(bsize, slen)`` index array into
                ``db.patchdb``; negative entries are skipped.

        Returns:
            ``(pred_loss, embed_loss, attn_loss, pred_accu)`` where
            ``attn_loss`` has shape ``(1,)`` and ``pred_accu`` holds one
            masked average per prediction component.
        """
        # Unwrap based on the actual type: the model is only wrapped in
        # DataParallel when several GPUs are present, so checking the config
        # flags alone raised AttributeError on single-GPU machines.
        if isinstance(self.net, nn.DataParallel):
            net = self.net.module
        else:
            net = self.net

        _, _, _, _, enc_msks, what_wei, _ = inf_outs
        logits, pred_vecs = net.collect_logits_and_vectors(inf_outs, ref_inds)

        ####################################################################
        # Prediction loss
        ####################################################################
        bsize, slen, _ = logits.size()
        loss_wei = [
            self.cfg.obj_loss_weight,
            self.cfg.coord_loss_weight,
            self.cfg.scale_loss_weight,
            self.cfg.ratio_loss_weight,
        ]
        loss_wei = torch.from_numpy(np.array(loss_wei)).float()
        if self.cfg.cuda:
            loss_wei = loss_wei.cuda()
        loss_wei = loss_wei.view(1, 1, 4)
        loss_wei = loss_wei.expand(bsize, slen, 4)

        # Clamp before the log so that zero probabilities do not yield inf.
        pred_loss = -torch.log(
            logits.clamp(min=self.cfg.eps)) * loss_wei * ref_msks
        pred_loss = torch.sum(pred_loss) / (torch.sum(ref_msks) + self.cfg.eps)

        ####################################################################
        # Embedding loss
        ####################################################################
        bsize, slen, vsize = pred_vecs.size()
        embed_metric = nn.TripletMarginLoss(margin=self.cfg.margin,
                                            p=2,
                                            eps=1e-06,
                                            swap=False,
                                            size_average=None,
                                            reduction='none')
        embed_loss = embed_metric(pred_vecs.view(bsize * slen, vsize),
                                  pos_vecs.view(bsize * slen, vsize),
                                  neg_vecs.view(bsize * slen,
                                                vsize)).view(bsize, slen)
        embed_mask = ref_msks[:, :, 1]
        embed_loss = torch.sum(
            embed_loss * embed_mask) / (torch.sum(embed_mask) + self.cfg.eps)
        embed_loss = embed_loss * self.cfg.embed_loss_weight

        ####################################################################
        # doubly stochastic attn loss
        ####################################################################
        attn_loss = logits.new_zeros(size=(1, ))
        if self.cfg.what_attn:
            obj_msks = ref_msks[:, :, 0].unsqueeze(-1)
            # Sum the masked attention weights over dim 1 and penalise the
            # squared deviation from the encoder masks.
            raw_obj_att_loss = torch.mul(what_wei, obj_msks)
            raw_obj_att_loss = torch.sum(raw_obj_att_loss, dim=1)
            obj_att_loss = raw_obj_att_loss - enc_msks
            obj_att_loss = torch.sum(obj_att_loss**2, dim=-1)
            obj_att_loss = torch.mean(obj_att_loss)
            attn_loss = attn_loss + obj_att_loss
        attn_loss = self.cfg.attn_loss_weight * attn_loss

        ####################################################################
        # Accuracies
        ####################################################################
        pred_accu = net.collect_accuracies(inf_outs, ref_inds)
        pred_accu = pred_accu * ref_msks
        comp_accu = torch.sum(torch.sum(pred_accu, 0), 0)
        comp_msks = torch.sum(torch.sum(ref_msks, 0), 0)
        pred_accu = comp_accu / (comp_msks + self.cfg.eps)

        ####################################################################
        # Dump predicted vectors
        ####################################################################
        if (db is not None) and (patch_inds is not None):
            # The array is only read below, so no defensive clone is needed.
            tmp_vecs = pred_vecs.detach().cpu().numpy()
            bsize, slen, _ = tmp_vecs.shape
            for i in range(bsize):
                for j in range(slen):
                    patch_index = patch_inds[i, j]
                    if patch_index < 0:
                        # Negative indices are skipped (no patch to write).
                        continue
                    patch = db.patchdb[patch_index]
                    image_index = patch['image_index']
                    instance_ind = patch['instance_ind']
                    patch_feature_path = db.patch_path_from_indices(
                        image_index, instance_ind, 'predicted_feature', 'pkl',
                        None)
                    maybe_create(osp.dirname(patch_feature_path))
                    features = tmp_vecs[i, j].flatten()
                    with open(patch_feature_path, 'wb') as fid:
                        pickle.dump(features, fid, pickle.HIGHEST_PROTOCOL)

        return pred_loss, embed_loss, attn_loss, pred_accu

    def train(self, train_db, val_db, test_db):
        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## NN table
        ##################################################################
        if self.cfg.use_hard_mining:
            self.train_tables = AllCategoriesTables(train_db)
            self.val_tables = AllCategoriesTables(val_db)
            self.train_tables.build_nntables_for_all_categories(True)
            self.val_tables.build_nntables_for_all_categories(True)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()
        min_val_loss = 100000000
        for epoch in range(self.epoch, self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            torch.cuda.empty_cache()
            train_loss, train_accu = self.train_epoch(train_db, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            torch.cuda.empty_cache()
            val_loss, val_accu = self.validate_epoch(val_db, epoch)

            ##################################################################
            ## Logging
            ##################################################################

            # update optim scheduler
            current_val_loss = np.mean(val_loss[:, 0])
            # self.optimizer.update(current_val_loss, epoch)
            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)
            logz.log_tabular("AverageLoss", np.mean(train_loss[:, 0]))
            logz.log_tabular("AveragePredLoss", np.mean(train_loss[:, 1]))
            logz.log_tabular("AverageEmbedLoss", np.mean(train_loss[:, 2]))
            logz.log_tabular("AverageAttnLoss", np.mean(train_loss[:, 3]))
            logz.log_tabular("AverageObjAccu", np.mean(train_accu[:, 0]))
            logz.log_tabular("AverageCoordAccu", np.mean(train_accu[:, 1]))
            logz.log_tabular("AverageScaleAccu", np.mean(train_accu[:, 2]))
            logz.log_tabular("AverageRatioAccu", np.mean(train_accu[:, 3]))

            logz.log_tabular("ValAverageLoss", np.mean(val_loss[:, 0]))
            logz.log_tabular("ValAveragePredLoss", np.mean(val_loss[:, 1]))
            logz.log_tabular("ValAverageEmbedLoss", np.mean(val_loss[:, 2]))
            logz.log_tabular("ValAverageAttnLoss", np.mean(val_loss[:, 3]))
            logz.log_tabular("ValAverageObjAccu", np.mean(val_accu[:, 0]))
            logz.log_tabular("ValAverageCoordAccu", np.mean(val_accu[:, 1]))
            logz.log_tabular("ValAverageScaleAccu", np.mean(val_accu[:, 2]))
            logz.log_tabular("ValAverageRatioAccu", np.mean(val_accu[:, 3]))
            logz.dump_tabular()

            ##################################################################
            ## Checkpoint
            ##################################################################
            if self.cfg.use_hard_mining:
                if (epoch + 1) % 3 == 0:
                    torch.cuda.empty_cache()
                    t0 = time()
                    self.dump_shape_vectors(train_db)
                    torch.cuda.empty_cache()
                    self.dump_shape_vectors(val_db)
                    print("Dump shape vectors completes (time %.2fs)" %
                          (time() - t0))
                    torch.cuda.empty_cache()
                    t0 = time()
                    self.train_tables.build_nntables_for_all_categories(False)
                    self.val_tables.build_nntables_for_all_categories(False)
                    print("NN completes (time %.2fs)" % (time() - t0))
            self.save_checkpoint(epoch)
            # else:
            #     if min_val_loss > current_val_loss:
            #         min_val_loss = current_val_loss
            #         self.save_checkpoint(epoch)

    def train_epoch(self, train_db, epoch):
        """Train the network for one pass over ``train_db``.

        Args:
            train_db: dataset handed to ``sequence_loader``. Its
                ``cfg.sent_group`` is overwritten with -1 as a side effect.
            epoch: epoch index; only used in the progress printout.

        Returns:
            Tuple ``(all_losses, all_accuracies)`` of arrays stacked along
            axis 0 with one row per batch. Loss columns are
            ``[total, pred, embed, attn]``.

        Raises:
            ValueError: raised by ``np.stack`` if the loader yields no batch.
        """
        # NOTE(review): -1 presumably means "no fixed sentence group" --
        # confirm against sequence_loader.
        train_db.cfg.sent_group = -1
        if self.cfg.use_hard_mining:
            seq_db = sequence_loader(train_db, self.train_tables)
        else:
            seq_db = sequence_loader(train_db)

        train_loader = DataLoader(seq_db,
                                  batch_size=self.cfg.batch_size,
                                  shuffle=True,
                                  num_workers=self.cfg.num_workers,
                                  pin_memory=True)

        all_losses, all_accuracies = [], []

        for cnt, batched in enumerate(train_loader):
            ##################################################################
            ## Batched data
            ##################################################################
            (input_inds, input_lens, fg_onehots, bg_imgs, fg_imgs, neg_imgs,
             fg_resnets, neg_resnets, gt_inds, gt_msks,
             patch_inds) = self.batch_data(batched)

            ##################################################################
            ## Train one step
            ##################################################################
            self.net.train()
            self.net.zero_grad()

            inputs = (input_inds, input_lens, bg_imgs, fg_onehots, fg_imgs,
                      neg_imgs, fg_resnets, neg_resnets)
            inf_outs, _, pos_vecs, neg_vecs = self.net(inputs)

            # With hard mining, evaluate() additionally receives the database
            # and the patch indices of the current batch.
            eval_args = [inf_outs, pos_vecs, neg_vecs, gt_inds, gt_msks]
            if self.cfg.use_hard_mining:
                eval_args += [train_db, patch_inds]
            pred_loss, embed_loss, attn_loss, pred_accu = self.evaluate(
                *eval_args)

            loss = pred_loss + embed_loss + attn_loss
            loss.backward()
            self.optimizer.step()

            ##################################################################
            ## Collect info
            ##################################################################
            all_losses.append(
                np.array([
                    loss.item(),
                    pred_loss.item(),
                    embed_loss.item(),
                    attn_loss.item()
                ]))
            all_accuracies.append(pred_accu.detach().cpu().numpy())

            ##################################################################
            ## Print info
            ##################################################################
            if cnt % self.cfg.log_per_steps == 0:
                # Running means over every batch seen so far in this epoch.
                print('Epoch %03d, iter %07d:' % (epoch, cnt))
                tmp_losses = np.stack(all_losses, 0)
                tmp_accuracies = np.stack(all_accuracies, 0)
                print('losses: ',
                      *(np.mean(tmp_losses[:, k]) for k in range(4)))
                print('accuracies: ',
                      *(np.mean(tmp_accuracies[:, k]) for k in range(4)))
                print('-------------------------')

        all_losses = np.stack(all_losses, 0)
        all_accuracies = np.stack(all_accuracies, 0)

        return all_losses, all_accuracies

    def validate_epoch(self, val_db, epoch):
        """Evaluate the network on ``val_db`` without updating weights.

        The database is traversed once for each ``sent_group`` value 0..4;
        ``val_db.cfg.sent_group`` is overwritten as a side effect and is
        left at 4 on return.

        Args:
            val_db: dataset handed to ``sequence_loader``.
            epoch: epoch index; only used in the progress printout.

        Returns:
            Tuple ``(all_losses, all_accuracies)`` of arrays stacked along
            axis 0 with one row per batch (all groups concatenated). Loss
            columns are ``[total, pred, embed, attn]``.

        Raises:
            ValueError: raised by ``np.stack`` if no batch was produced.
        """
        all_losses, all_accuracies = [], []

        # NOTE(review): 5 is presumably the number of captions per image --
        # confirm against the dataset definition.
        for G in range(5):
            val_db.cfg.sent_group = G
            if self.cfg.use_hard_mining:
                seq_db = sequence_loader(val_db, self.val_tables)
            else:
                seq_db = sequence_loader(val_db)
            val_loader = DataLoader(seq_db,
                                    batch_size=self.cfg.batch_size,
                                    shuffle=False,
                                    num_workers=self.cfg.num_workers,
                                    pin_memory=True)

            for cnt, batched in enumerate(val_loader):
                ##############################################################
                ## Batched data
                ##############################################################
                (input_inds, input_lens, fg_onehots, bg_imgs, fg_imgs,
                 neg_imgs, fg_resnets, neg_resnets, gt_inds, gt_msks,
                 patch_inds) = self.batch_data(batched)

                ##############################################################
                ## Validate one step
                ##############################################################
                self.net.eval()
                with torch.no_grad():
                    inputs = (input_inds, input_lens, bg_imgs, fg_onehots,
                              fg_imgs, neg_imgs, fg_resnets, neg_resnets)
                    inf_outs, _, pos_vecs, neg_vecs = self.net(inputs)

                    # With hard mining, evaluate() additionally receives the
                    # database and the patch indices of the current batch.
                    eval_args = [inf_outs, pos_vecs, neg_vecs, gt_inds,
                                 gt_msks]
                    if self.cfg.use_hard_mining:
                        eval_args += [val_db, patch_inds]
                    pred_loss, embed_loss, attn_loss, pred_accu = \
                        self.evaluate(*eval_args)

                loss = pred_loss + embed_loss + attn_loss
                all_losses.append(
                    np.array([
                        loss.item(),
                        pred_loss.item(),
                        embed_loss.item(),
                        attn_loss.item()
                    ]))
                all_accuracies.append(pred_accu.detach().cpu().numpy())
                print(epoch, G, cnt)

        all_losses = np.stack(all_losses, 0)
        all_accuracies = np.stack(all_accuracies, 0)

        return all_losses, all_accuracies

    def save_checkpoint(self, epoch):
        """Write model and optimizer state for ``epoch`` to disk.

        The file is ``<cfg.model_dir>/puzzle_ckpts/ckpt-<epoch:03d>.pkl`` and
        holds a dict with the keys ``'epoch'``, ``'state_dict'`` and
        ``'optimizer'``.

        Args:
            epoch: epoch index stored in the checkpoint and used in the
                file name.
        """
        print(" [*] Saving checkpoints...")
        # When cfg.cuda and cfg.parallel are both set, the state is taken
        # from self.net.module instead of self.net.
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net
        checkpoint_dir = osp.join(self.cfg.model_dir, 'puzzle_ckpts')
        # exist_ok avoids the check-then-create race when several processes
        # share the same model_dir.
        os.makedirs(checkpoint_dir, exist_ok=True)
        states = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            # self.optimizer is a wrapper; the underlying optimizer's state
            # is what gets saved.
            'optimizer': self.optimizer.optimizer.state_dict()
        }
        torch.save(states, osp.join(checkpoint_dir, "ckpt-%03d.pkl" % epoch))

    def decode_attention(self, word_inds, word_lens, att_logits, top_k=3):
        """Return the most-attended words for every row of ``att_logits``.

        Args:
            word_inds: numpy array of vocabulary indices. A 1-D array is used
                as-is. A 2-D array has each row trimmed to the matching
                ``word_lens`` entry, the rows concatenated, and the result
                right-padded with index 0 up to
                ``cfg.max_input_length * 3`` entries.
            word_lens: number of valid entries per row of ``word_inds``;
                only read when ``word_inds`` is 2-D.
            att_logits: 2-D tensor of attention scores; the last axis indexes
                positions of the flattened word sequence.
            top_k: number of words to return per row (default 3). It is
                clamped to the size of the last axis so short inputs no
                longer make ``torch.topk`` fail.

        Returns:
            A list with one entry per row of ``att_logits``; each entry is a
            list of word strings ordered by decreasing attention score.
        """
        k = min(top_k, att_logits.size(-1))
        _, att_inds = torch.topk(att_logits, k, -1)
        att_inds = att_inds.detach().cpu().numpy()

        if len(word_inds.shape) > 1:
            # Flatten the per-sentence rows into one sequence of valid words.
            lin_inds = []
            for row, n_words in zip(word_inds, word_lens):
                lin_inds.extend(row[:n_words].tolist())
            # A negative pad count yields an empty list, i.e. no padding.
            npad = self.cfg.max_input_length * 3 - len(lin_inds)
            lin_inds = np.asarray(lin_inds + [0] * npad, dtype=np.int32)
        else:
            lin_inds = word_inds.copy()

        index2word = self.db.lang_vocab.index2word
        return [[index2word[lin_inds[pos]] for pos in step_inds]
                for step_inds in att_inds]

    def sample_for_vis(self,
                       epoch,
                       test_db,
                       N,
                       random_or_not=False,
                       nn_table=None):
        """Render inference results for up to ``N`` scenes as PNG grids.

        For every selected scene the intermediate frames returned by
        ``env.batch_redraw(return_sequence=True)`` are drawn on a 4x4 grid,
        titled with the decoded attention words when enabled in the config,
        and the ground-truth image is placed in the last grid cell. Figures
        are written to ``<cfg.model_dir>/<epoch:03d>/vis/<image name>.png``.

        Args:
            epoch: epoch index; used for the output directory and printout.
            test_db: database providing ``scenedb`` and
                ``color_path_from_index``.
            N: maximum number of scenes to render.
            random_or_not: if true, scenes are drawn from a random
                permutation instead of the first ``N`` entries.
            nn_table: forwarded unchanged to ``net.inference``.
        """
        ##############################################################
        # Output prefix
        ##############################################################
        output_dir = osp.join(self.cfg.model_dir, '%03d' % epoch, 'vis')
        maybe_create(output_dir)

        seq_db = sequence_loader(test_db)

        ##############################################################
        # Main loop
        ##############################################################
        plt.switch_backend('agg')
        n_scenes = len(test_db.scenedb)
        if random_or_not:
            indices = np.random.permutation(range(n_scenes))
        else:
            indices = range(n_scenes)
        indices = indices[:min(N, n_scenes)]

        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        # The grid has 16 cells and the last one is reserved for the ground
        # truth, so at most 15 frames can be shown.
        max_frames = 4 * 4 - 1

        for i in indices:
            entry = seq_db[i]
            gt_scene = test_db.scenedb[i]
            image_index = gt_scene['image_index']
            image_path = test_db.color_path_from_index(image_index)

            gt_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
            gt_img, _, _ = create_squared_image(gt_img)
            gt_img = cv2.resize(
                gt_img,
                (self.cfg.input_image_size[0], self.cfg.input_image_size[1]))

            ##############################################################
            # Inputs
            ##############################################################
            input_inds_np = np.array(entry['word_inds'])
            input_lens_np = np.array(entry['word_lens'])
            # unsqueeze(0) adds the batch axis for a single sample.
            input_inds = torch.from_numpy(input_inds_np).long().unsqueeze(0)
            input_lens = torch.from_numpy(input_lens_np).long().unsqueeze(0)

            if self.cfg.cuda:
                input_inds = input_inds.cuda()
                input_lens = input_lens.cuda()
            ##############################################################
            # Inference
            ##############################################################
            self.net.eval()
            with torch.no_grad():
                inf_outs, env = net.inference(input_inds, input_lens, -1, 1.0,
                                              0, None, None, nn_table)
            frames, _, _, _, _ = env.batch_redraw(return_sequence=True)
            frames = frames[0]
            _, _, _, _, _, what_wei, where_wei = inf_outs

            if self.cfg.what_attn:
                what_attn_words = self.decode_attention(
                    input_inds_np, input_lens_np, what_wei.squeeze(0))
            if self.cfg.where_attn > 0:
                where_attn_words = self.decode_attention(
                    input_inds_np, input_lens_np, where_wei.squeeze(0))

            ##############################################################
            # Draw
            ##############################################################
            fig = plt.figure(figsize=(32, 32))
            plt.suptitle(entry['sentence'], fontsize=40)
            # Clamp so a long sequence can neither overwrite the
            # ground-truth cell nor request a subplot index beyond 16.
            for j in range(min(len(frames), max_frames)):
                subtitle = ''
                if self.cfg.what_attn:
                    subtitle = subtitle + ' '.join(what_attn_words[j])
                if self.cfg.where_attn > 0:
                    subtitle = subtitle + '\n' + ' '.join(where_attn_words[j])

                plt.subplot(4, 4, j + 1)
                plt.title(subtitle, fontsize=30)
                if self.cfg.use_color_volume:
                    vis_img, _ = heuristic_collage(frames[j], 83)
                else:
                    vis_img = frames[j][:, :, -3:]
                # Reverse the channel axis (OpenCV BGR -> RGB for matplotlib).
                vis_img = clamp_array(vis_img[:, :, ::-1], 0,
                                      255).astype(np.uint8)
                plt.imshow(vis_img)
                plt.axis('off')
            plt.subplot(4, 4, 16)
            plt.imshow(gt_img[:, :, ::-1])
            plt.axis('off')

            name = osp.splitext(osp.basename(image_path))[0]
            out_path = osp.join(output_dir, name + '.png')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
            print('sampling: %d, %d' % (epoch, i))

    def sample_for_eval(self, test_db, nn_table=None):
        """Run inference on one shard of ``test_db`` and dump the results.

        The shard covers indices ``[cfg.seed * 1250, (cfg.seed + 1) * 1250)``.
        For every scene the ground-truth image is copied and the generated
        frame, noise image, mask, label map and a JSON description are
        written below ``puzzle_results/`` (relative to the working
        directory).

        Args:
            test_db: database providing ``scenedb`` and
                ``color_path_from_index``.
            nn_table: forwarded unchanged to ``net.inference``.
        """
        ##############################################################
        # Output folders
        ##############################################################
        main_dir = 'puzzle_results'
        maybe_create(main_dir)

        gt_dir, frame_dir, noice_dir, label_dir, mask_dir, info_dir = [
            osp.join(main_dir, sub)
            for sub in ('gt', 'proposal_images', 'proposal_noices',
                        'proposal_labels', 'proposal_masks', 'proposal_info')
        ]
        for folder in (gt_dir, frame_dir, noice_dir, label_dir, mask_dir,
                       info_dir):
            maybe_create(folder)

        seq_db = sequence_loader(test_db)

        ##############################################################
        # Main loop
        ##############################################################
        unwrap = self.cfg.cuda and self.cfg.parallel
        net = self.net.module if unwrap else self.net

        start_ind = self.cfg.seed * 1250
        end_ind = (self.cfg.seed + 1) * 1250
        for i in range(start_ind, end_ind):
            entry = seq_db[i]
            image_index = test_db.scenedb[i]['image_index']
            image_path = test_db.color_path_from_index(image_index)
            base_name = osp.basename(image_path)
            name = osp.splitext(base_name)[0]
            # Keep a copy of the ground truth next to the predictions.
            shutil.copy2(image_path, osp.join(gt_dir, base_name))

            ##############################################################
            # Inputs (a batch of one sample)
            ##############################################################
            input_inds = torch.from_numpy(np.array(
                entry['word_inds'])).long().unsqueeze(0)
            input_lens = torch.from_numpy(np.array(
                entry['word_lens'])).long().unsqueeze(0)
            if self.cfg.cuda:
                input_inds = input_inds.cuda()
                input_lens = input_lens.cuda()

            ##############################################################
            # Inference
            ##############################################################
            self.net.eval()
            with torch.no_grad():
                inf_outs, env = net.inference(input_inds, input_lens, -1, 1.0,
                                              0, None, None, nn_table)
            frames, noices, masks, labels, env_infos = env.batch_redraw(
                return_sequence=False)
            frame = frames[0][0]
            noice = noices[0][0]
            mask = masks[0][0]
            label = labels[0][0]
            env_info = env_infos[0]

            if self.cfg.use_color_volume:
                frame, _ = heuristic_collage(frame, 83)
                noice, _ = heuristic_collage(noice, 83)
            else:
                # Only the last three channels are written out.
                frame = frame[:, :, -3:]
                noice = noice[:, :, -3:]

            ##############################################################
            # Images
            ##############################################################
            cv2.imwrite(osp.join(frame_dir, name + '.jpg'),
                        clamp_array(frame, 0, 255).astype(np.uint8))
            cv2.imwrite(osp.join(noice_dir, name + '.jpg'),
                        clamp_array(noice, 0, 255).astype(np.uint8))
            cv2.imwrite(osp.join(mask_dir, name + '.png'),
                        clamp_array(255 * mask, 0, 255))
            cv2.imwrite(osp.join(label_dir, name + '.png'), label)

            ##############################################################
            # JSON description of the composed scene
            ##############################################################
            pred_info = {
                'width': env_info['width'],
                'height': env_info['height'],
                'clses': env_info['clses'].tolist(),
                'boxes': [box.tolist() for box in env_info['boxes']],
            }
            patches = env_info['patches']
            n_objs = len(pred_info['clses'])
            pred_info['image_indices'] = [
                patches[j]['image_index'] for j in range(n_objs)
            ]
            pred_info['instance_inds'] = [
                patches[j]['instance_ind'] for j in range(n_objs)
            ]
            with open(osp.join(info_dir, name + '.json'), 'w') as fp:
                json.dump(pred_info, fp, indent=4, sort_keys=True)
            print('sampling: %d, %s' % (i, name))
Example #19
0
def test_puzzle_model(config):
    """Smoke-test ``PuzzleModel`` on the first batch of the COCO train split.

    Builds the database, the nearest-neighbour tables and a data loader,
    runs one teacher-forced forward pass and prints the sizes of the
    outputs. Nothing is asserted; the output is meant to be inspected.
    """
    output_dir = osp.join(config.model_dir, 'test_puzzle_model')
    maybe_create(output_dir)
    plt.switch_backend('agg')

    db = composites_coco(config, 'train', '2017')
    all_tables = AllCategoriesTables(db)
    all_tables.build_nntables_for_all_categories(True)
    loader = DataLoader(sequence_loader(db, all_tables),
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    net = PuzzleModel(db)
    net.eval()

    # Only the first batch is inspected.
    for batched in loader:
        word_inds = batched['word_inds'].long()
        word_lens = batched['word_lens'].long()
        bg_images = batched['background'].float()
        fg_images = batched['foreground'].float()
        neg_images = batched['negative'].float()

        fg_resnets = batched['foreground_resnets'].float()
        neg_resnets = batched['negative_resnets'].float()

        fg_inds = batched['fg_inds'].long()
        gt_inds = batched['out_inds'].long()
        gt_msks = batched['out_msks'].float()

        fg_onehots = indices2onehots(fg_inds, config.output_vocab_size)

        model_inputs = (word_inds, word_lens, bg_images, fg_onehots,
                        fg_images, neg_images, fg_resnets, neg_resnets)
        inf_outs, _, positive_feats, negative_feats = net(model_inputs)
        (obj_logits, coord_logits, attri_logits, patch_vectors, enc_msks,
         what_wei, where_wei) = inf_outs

        print('teacher forcing')
        for label, tensor in (('obj_logits ', obj_logits),
                              ('coord_logits ', coord_logits),
                              ('attri_logits ', attri_logits),
                              ('patch_vectors ', patch_vectors)):
            print(label, tensor.size())
        print('patch_vectors max:', torch.max(patch_vectors))
        print('patch_vectors min:', torch.min(patch_vectors))
        print('patch_vectors norm:',
              torch.norm(patch_vectors, dim=-2)[0, 0, 0])
        print('positive_feats ', positive_feats.size())
        print('negative_feats ', negative_feats.size())
        if config.what_attn:
            print('what_att_logits ', what_wei.size())
        if config.where_attn > 0:
            print('where_att_logits ', where_wei.size())
        print('----------------------')

        _, pred_vecs = net.collect_logits_and_vectors(inf_outs, gt_inds)
        print('pred_vecs', pred_vecs.size())
        print('*******************')

        break
Example #20
0
    def train(self, train_db, val_db, test_db):
        """Run the full training schedule with per-epoch validation.

        Each epoch trains on ``train_db``, validates on ``val_db``, logs the
        averaged losses and accuracies through ``logz`` and saves a
        checkpoint. With ``cfg.use_hard_mining`` the shape vectors are
        re-dumped and the nearest-neighbour tables rebuilt every third
        epoch.

        Args:
            train_db: training database.
            val_db: validation database.
            test_db: not used by this method.
        """
        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## NN table
        ##################################################################
        if self.cfg.use_hard_mining:
            self.train_tables = AllCategoriesTables(train_db)
            self.val_tables = AllCategoriesTables(val_db)
            for tables in (self.train_tables, self.val_tables):
                tables.build_nntables_for_all_categories(True)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()
        for epoch in range(self.epoch, self.cfg.n_epochs):
            # Training
            torch.cuda.empty_cache()
            train_loss, train_accu = self.train_epoch(train_db, epoch)

            # Validation
            torch.cuda.empty_cache()
            val_loss, val_accu = self.validate_epoch(val_db, epoch)

            # Logging: one row per epoch; every statistic is the mean of a
            # single column of the per-batch arrays.
            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)
            columns = (
                ("AverageLoss", train_loss, 0),
                ("AveragePredLoss", train_loss, 1),
                ("AverageEmbedLoss", train_loss, 2),
                ("AverageAttnLoss", train_loss, 3),
                ("AverageObjAccu", train_accu, 0),
                ("AverageCoordAccu", train_accu, 1),
                ("AverageScaleAccu", train_accu, 2),
                ("AverageRatioAccu", train_accu, 3),
                ("ValAverageLoss", val_loss, 0),
                ("ValAveragePredLoss", val_loss, 1),
                ("ValAverageEmbedLoss", val_loss, 2),
                ("ValAverageAttnLoss", val_loss, 3),
                ("ValAverageObjAccu", val_accu, 0),
                ("ValAverageCoordAccu", val_accu, 1),
                ("ValAverageScaleAccu", val_accu, 2),
                ("ValAverageRatioAccu", val_accu, 3),
            )
            for key, stats, col in columns:
                logz.log_tabular(key, np.mean(stats[:, col]))
            logz.dump_tabular()

            # Hard-mining refresh (every third epoch), then checkpoint.
            if self.cfg.use_hard_mining and (epoch + 1) % 3 == 0:
                torch.cuda.empty_cache()
                t0 = time()
                self.dump_shape_vectors(train_db)
                torch.cuda.empty_cache()
                self.dump_shape_vectors(val_db)
                print("Dump shape vectors completes (time %.2fs)" %
                      (time() - t0))
                torch.cuda.empty_cache()
                t0 = time()
                for tables in (self.train_tables, self.val_tables):
                    tables.build_nntables_for_all_categories(False)
                print("NN completes (time %.2fs)" % (time() - t0))
            self.save_checkpoint(epoch)
def generate_simulated_scenes(config, split, year):
    """Synthesize per-scene training images by re-stenciling object instances.

    For every scene in a fixed index window of ``db.scenedb`` each object
    instance is matched to a similar patch from the COCO train-2017 database,
    its mask is intersected with the (resized) mask of that patch, and the
    instance is composited into four per-scene outputs, all named after the
    12-digit zero-padded image index:

    * ``crn_images/<dir>/<name>.jpg``: instances composited with ``compose``;
    * ``crn_noices/<dir>/<name>.jpg``: same, but composited after random
      pixels around each mask boundary were set to 255;
    * ``crn_masks/<dir>/<name>.png``: element-wise maximum of all instance
      masks, multiplied by 255;
    * ``crn_labels/<dir>/<name>.png``: class index per pixel (later
      instances overwrite earlier ones).

    Args:
        config: Project configuration. This function reads ``data_dir``,
            ``seed`` and ``use_patch_background`` from it and forwards it to
            ``coco``.
        split: Dataset split name passed to ``coco``. For ``'test'`` and
            ``'aux'`` the outputs are written to the ``'train' + year``
            directories instead of ``split + year``.
        year: String appended to the split name to form directory names.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    db = coco(config, split, year)
    data_dir = osp.join(config.data_dir, 'coco')
    # 'test' and 'aux' share the output directories of the train split.
    if (split == 'test') or (split == 'aux'):
        images_dir = osp.join(data_dir, 'crn_images', 'train' + year)
        noices_dir = osp.join(data_dir, 'crn_noices', 'train' + year)
        labels_dir = osp.join(data_dir, 'crn_labels', 'train' + year)
        masks_dir = osp.join(data_dir, 'crn_masks', 'train' + year)
    else:
        images_dir = osp.join(data_dir, 'crn_images', split + year)
        noices_dir = osp.join(data_dir, 'crn_noices', split + year)
        labels_dir = osp.join(data_dir, 'crn_labels', split + year)
        masks_dir = osp.join(data_dir, 'crn_masks', split + year)
    maybe_create(images_dir)
    maybe_create(noices_dir)
    maybe_create(labels_dir)
    maybe_create(masks_dir)

    # Candidate patches always come from train 2017, whatever `split` is.
    traindb = coco(config, 'train', '2017')
    nn_tables = AllCategoriesTables(traindb)
    nn_tables.build_nntables_for_all_categories(True)

    # start_ind = 0
    # end_ind = len(db.scenedb)
    # NOTE(review): config.seed appears to double as a shard index selecting a
    # 14000-scene window starting at 25000. There is no bounds check against
    # len(db.scenedb), so an out-of-range window raises IndexError below.
    start_ind = 25000 + 14000 * config.seed
    end_ind = 25000 + 14000 * (config.seed + 1)
    patches_per_class = traindb.patches_per_class
    # Color transfer runs only when random.random() exceeds this value,
    # i.e. for roughly 20% of the instances.
    color_transfer_threshold = 0.8

    for i in range(start_ind, end_ind):
        entry = db.scenedb[i]
        width = entry['width']
        height = entry['height']
        xywhs = entry['boxes']
        masks = entry['masks']
        clses = entry['clses']
        image_index = entry['image_index']
        instance_inds = entry['instance_inds']

        # Per-scene accumulators, filled instance by instance.
        full_mask = np.zeros((height, width), dtype=np.float32)
        full_label = np.zeros((height, width), dtype=np.float32)
        full_image = np.zeros((height, width, 3), dtype=np.float32)
        full_noice = np.zeros((height, width, 3), dtype=np.float32)

        # NOTE(review): cv2.imread returns None for an unreadable path; that
        # case is not checked and would fail at the astype() call below.
        original_image = cv2.imread(db.color_path_from_index(image_index),
                                    cv2.IMREAD_COLOR)

        for j in range(len(masks)):
            # Fresh float copy per instance: src_img is modified in place by
            # the color transfer and the boundary elision below.
            src_img = original_image.astype(np.float32).copy()
            xywh = xywhs[j]
            mask = masks[j]
            cls_idx = clses[j]
            instance_ind = instance_inds[j]
            embed_path = db.patch_path_from_indices(
                image_index, instance_ind, 'patch_feature', 'pkl',
                config.use_patch_background)
            # NOTE(review): pickle.load can execute arbitrary code; this
            # presumably only reads feature files produced by the project's
            # own pipeline -- confirm they are never externally supplied.
            with open(embed_path, 'rb') as fid:
                query_vector = pickle.load(fid)
            # Ask for at most 100 neighbours, fewer if the class is smaller.
            n_samples = min(
                100, len(patches_per_class[cls_idx])
            )  #min(config.n_nntable_trees, len(patches_per_class[cls_idx]))
            candidate_patches = nn_tables.retrieve(cls_idx, query_vector,
                                                   n_samples)
            # Never match an instance with itself.
            candidate_patches = [
                x for x in candidate_patches
                if x['instance_ind'] != instance_ind
            ]
            # Requires at least two remaining candidates.
            assert (len(candidate_patches) > 1)

            # Earlier approach (random candidate), kept for reference:
            # candidate_instance_ind = instance_ind
            # candidate_patch = None
            # while (candidate_instance_ind == instance_ind):
            #     cid = np.random.randint(0, len(candidate_patches))
            #     candidate_patch = candidate_patches[cid]
            #     candidate_instance_ind = candidate_patch['instance_ind']
            candidate_patch = find_closest_patch(db, traindb, image_index,
                                                 instance_ind,
                                                 candidate_patches)

            # stenciling: inside the source box, keep only the pixels that are
            # set in both the source mask and the candidate mask (cropped to
            # its own box and resized to the source box size).
            # assumes COCOmask is pycocotools.mask -- TODO confirm
            src_mask = COCOmask.decode(mask)
            dst_mask = COCOmask.decode(candidate_patch['mask'])
            src_xyxy = xywh_to_xyxy(xywh, width, height)
            dst_xyxy = xywh_to_xyxy(candidate_patch['box'],
                                    candidate_patch['width'],
                                    candidate_patch['height'])
            # Box corners are treated as inclusive, hence the "+ 1".
            dst_mask = dst_mask[dst_xyxy[1]:(dst_xyxy[3] + 1),
                                dst_xyxy[0]:(dst_xyxy[2] + 1)]
            # cv2.resize takes the target size as (width, height).
            dst_mask = cv2.resize(
                dst_mask,
                (src_xyxy[2] - src_xyxy[0] + 1, src_xyxy[3] - src_xyxy[1] + 1),
                interpolation=cv2.INTER_NEAREST)
            src_mask[src_xyxy[1]:(src_xyxy[3]+1), src_xyxy[0]:(src_xyxy[2]+1)] = \
             np.minimum(dst_mask, src_mask[src_xyxy[1]:(src_xyxy[3]+1), src_xyxy[0]:(src_xyxy[2]+1)])
            # color transfer: occasionally recolor the source box using the
            # candidate crop as the color reference.
            if random.random() > color_transfer_threshold:
                candidate_index = candidate_patch['image_index']
                candidate_image = cv2.imread(
                    traindb.color_path_from_index(candidate_index),
                    cv2.IMREAD_COLOR).astype(np.float32)
                candidate_cropped = candidate_image[dst_xyxy[1]:(dst_xyxy[3] +
                                                                 1),
                                                    dst_xyxy[0]:(dst_xyxy[2] +
                                                                 1)]
                candidate_cropped = cv2.resize(candidate_cropped,
                                               (src_xyxy[2] - src_xyxy[0] + 1,
                                                src_xyxy[3] - src_xyxy[1] + 1),
                                               interpolation=cv2.INTER_CUBIC)
                original_cropped = src_img[src_xyxy[1]:(src_xyxy[3] + 1),
                                           src_xyxy[0]:(src_xyxy[2] + 1)]
                transfer_cropped = Monge_Kantorovitch_color_transfer(
                    original_cropped, candidate_cropped)
                src_img[src_xyxy[1]:(src_xyxy[3] + 1),
                        src_xyxy[0]:(src_xyxy[2] + 1)] = transfer_cropped

            # Debug visualisation. candidate_image / candidate_cropped are
            # only bound when the color-transfer branch above ran, so this
            # cannot be re-enabled as is.
            # im1 = cv2.resize(full_image, (128, 128))
            # im2 = cv2.resize(src_img[src_xyxy[1]:(src_xyxy[3]+1), src_xyxy[0]:(src_xyxy[2]+1), :], (128, 128))
            # # im2 = cv2.resize(np.repeat(255*src_mask[...,None], 3, -1), (128, 128))
            # im3 = cv2.resize(candidate_image, (128, 128))
            # im4 = cv2.resize(candidate_cropped, (128, 128))
            # im = np.concatenate((im1, im2, im3, im4), 1)
            # cv2.imwrite("%03d_%03d.png"%(i, j), im)

            # The clean image is composed BEFORE the boundary pixels are
            # overwritten; the noisy image is composed after.
            full_image = compose(full_image, src_img, src_mask)

            # boundary elision
            # NOTE(review): radius is 0 when min(width, height) < 20; confirm
            # that cv2.getStructuringElement accepts a (0, 0) kernel size.
            radius = int(0.05 * min(width, height))
            # A mask without any zero pixel gets a one-pixel zero frame,
            # presumably so that the Sobel filter below still finds an edge.
            if np.amin(src_mask) > 0:
                src_mask[0, :] = 0
                src_mask[-1, :] = 0
                src_mask[:, 0] = 0
                src_mask[:, -1] = 0
            sobelx = cv2.Sobel(src_mask, cv2.CV_64F, 1, 0, ksize=3)
            sobely = cv2.Sobel(src_mask, cv2.CV_64F, 0, 1, ksize=3)
            sobel = np.abs(sobelx) + np.abs(sobely)
            # Binary edge map, then thickened by an elliptical dilation.
            edge = np.zeros_like(sobel)
            edge[sobel > 0.9] = 1.0
            morp_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                                    (radius, radius))
            edge = cv2.dilate(edge, morp_kernel, iterations=1)
            # Set a random half of the dilated edge pixels to 255 in all
            # three channels.
            row, col = np.where(edge > 0)
            n_edge_pixels = len(row)
            pixel_indices = np.random.permutation(range(n_edge_pixels))
            pixel_indices = pixel_indices[:(n_edge_pixels // 2)]
            row = row[pixel_indices]
            col = col[pixel_indices]
            src_img[row, col, :] = 255

            full_mask = np.maximum(full_mask, src_mask)
            # Later instances overwrite the label of earlier ones.
            full_label[src_mask > 0] = cls_idx
            full_noice = compose(full_noice, src_img, src_mask)

            # Debug visualisation (same caveat as above).
            # im1 = cv2.resize(full_image, (128, 128))
            # im2 = cv2.resize(src_img[src_xyxy[1]:(src_xyxy[3]+1), src_xyxy[0]:(src_xyxy[2]+1), :], (128, 128))
            # im3 = cv2.resize(candidate_image, (128, 128))
            # im4 = cv2.resize(candidate_cropped, (128, 128))
            # im = np.concatenate((im1, im2, im3, im4), 1)
            # cv2.imwrite("%03d_%03d.png"%(i, j), im)

        # One file per output kind, all named by the zero-padded image index.
        output_name = str(image_index).zfill(12)
        output_path = osp.join(images_dir, output_name + '.jpg')
        cv2.imwrite(output_path,
                    clamp_array(full_image, 0, 255).astype(np.uint8))
        output_path = osp.join(noices_dir, output_name + '.jpg')
        cv2.imwrite(output_path,
                    clamp_array(full_noice, 0, 255).astype(np.uint8))
        output_path = osp.join(masks_dir, output_name + '.png')
        cv2.imwrite(output_path,
                    clamp_array(255 * full_mask, 0, 255).astype(np.uint8))
        output_path = osp.join(labels_dir, output_name + '.png')
        cv2.imwrite(output_path, full_label.astype(np.uint8))
        print(i, image_index)
Example #22
0
class simulator(object):
    """Batch of incrementally rendered scenes driven by step predictions.

    Each scene keeps the history of predicted indices, vectors and retrieved
    patches together with its current frame, label map and mask. Frames are
    laid out as (height, width, channels) numpy arrays; the channel count
    depends on ``cfg.use_color_volume`` (see ``_channel_dim``).
    """

    def __init__(self, db, batch_size=None, nn_table=None):
        """Bind the database and set up the nearest-neighbour patch tables.

        Args:
            db: Project database; must expose ``cfg``.
            batch_size: Number of scenes; defaults to ``db.cfg.batch_size``.
            nn_table: Pre-built patch tables. When omitted, a new
                ``AllCategoriesTables`` is created from ``db`` and built.
        """
        self.db = db
        self.cfg = db.cfg
        if batch_size is None:
            batch_size = self.cfg.batch_size
        self.batch_size = batch_size
        if nn_table is not None:
            self.nn_table = nn_table
        else:
            self.nn_table = AllCategoriesTables(db)
            self.nn_table.build_nntables_for_all_categories()

    def _channel_dim(self):
        """Return the number of channels of a frame for the current config."""
        vocab_size = self.cfg.output_vocab_size
        if self.cfg.use_color_volume:
            # Three color channels per vocabulary entry.
            return 3 * vocab_size
        # One channel per vocabulary entry plus four trailing channels.
        return 4 + vocab_size

    def reset(self):
        """Start ``batch_size`` empty scenes and return their blank frames.

        Returns:
            A tensor stacking the (all-zero) frames along a new first axis.
        """
        self.scenes = []
        depth = self._channel_dim()
        width = self.cfg.input_image_size[0]
        height = self.cfg.input_image_size[1]

        canvases = []
        for _ in range(self.batch_size):
            canvas = np.zeros((height, width, depth))
            self.scenes.append({
                'out_inds': [],
                'out_vecs': [],
                'out_patches': [],
                'last_frame': canvas,
                'last_label': np.zeros((height, width), dtype=np.int32),
                'last_mask': np.zeros((height, width), dtype=np.float32),
            })
            canvases.append(canvas)
        return torch.from_numpy(np.stack(canvases, axis=0))

    def batch_render_to_pytorch(self, out_inds, out_vecs):
        """Apply one prediction step to every scene and return the frames.

        Args:
            out_inds: Per-scene predicted indices; length must equal
                ``batch_size``.
            out_vecs: Per-scene predicted vectors.

        Returns:
            A tensor stacking the updated frames along a new first axis.
        """
        assert(len(out_inds) == self.batch_size)
        rendered = [
            self.update_scene(
                self.scenes[k],
                {'out_inds': out_inds[k], 'out_vec': out_vecs[k]})
            for k in range(self.batch_size)
        ]
        return torch.from_numpy(np.stack(rendered, 0))

    def batch_redraw(self, return_sequence=False):
        """Re-render every scene from its recorded predictions.

        Args:
            return_sequence: When true, keep one rendering per placed
                object; otherwise only the final rendering, with a leading
                axis of length one added.

        Returns:
            Five lists (frames, noises, masks, labels, scenes) with one
            entry per scene.
        """
        out_frames, out_noises, out_masks, out_labels, out_scenes = [], [], [], [], []
        for state in self.scenes:
            predicted = self.db.prediction_outputs_to_scene(state, self.nn_table)
            predicted['patches'] = state['out_patches']
            rendered = self.render_predictions_as_output(predicted, return_sequence)
            if not return_sequence:
                # Add a leading axis so both modes yield sequence-like arrays.
                rendered = tuple(arr[None, ...] for arr in rendered)
            frames, noises, masks, labels = rendered
            out_frames.append(frames)
            out_noises.append(noises)
            out_masks.append(masks)
            out_labels.append(labels)
            out_scenes.append(predicted)
        return out_frames, out_noises, out_masks, out_labels, out_scenes

    def render_predictions_as_output(self, scene, return_sequence):
        """Render a scene dict object by object.

        Args:
            scene: Dict with ``width``, ``height``, ``clses``, ``boxes`` and
                ``patches``.
            return_sequence: When true, return the stacked snapshots taken
                after every object; otherwise only the last snapshot.

        Returns:
            A 4-tuple (frames, noises, masks, labels).
        """
        width = scene['width']
        height = scene['height']
        clses = scene['clses']
        boxes = scene['boxes']
        patches = scene['patches']

        depth = self._channel_dim()
        frame = np.zeros((height, width, depth))
        noise = np.zeros((height, width, depth))
        label = np.zeros((height, width), dtype=np.int32)
        mask = np.zeros((height, width), dtype=np.float32)

        # Snapshots in return order: frames, noises, masks, labels.
        history = ([], [], [], [])

        for idx, cls_ind in enumerate(clses):
            patch = patches[idx]
            xyxy = xywh_to_xyxy(boxes[idx], width, height)
            if self.cfg.use_color_volume:
                # Each class owns its own three color channels.
                chan = slice(3 * cls_ind, 3 * (cls_ind + 1))
            else:
                # All classes share the last three (color) channels.
                chan = slice(-3, None)
            frame[:, :, chan], mask, _, label, noise[:, :, chan] = \
                patch_compose_and_erose(frame[:, :, chan], mask, label,
                                        xyxy, patch, self.db, noise[:, :, chan])
            if not self.cfg.use_color_volume:
                frame[:, :, -4] = np.maximum(mask * 255, frame[:, :, -4])
                frame[:, :, cls_ind] = np.maximum(mask * 255, frame[:, :, cls_ind])
            for bucket, arr in zip(history, (frame, noise, mask, label)):
                bucket.append(arr.copy())

        if not history[0]:
            # No objects: still emit one (blank) snapshot.
            for bucket, arr in zip(history, (frame, noise, mask, label)):
                bucket.append(arr.copy())

        if return_sequence:
            return tuple(np.stack(bucket, 0) for bucket in history)
        return tuple(bucket[-1] for bucket in history)

    def update_scene(self, scene, step_prediction):
        """Record one step prediction in ``scene`` and redraw its frame.

        Args:
            scene: Scene dict as created by ``reset``.
            step_prediction: Dict with ``out_inds`` and ``out_vec``; both
                are flattened before use.

        Returns:
            The scene's updated frame.
        """
        inds = step_prediction['out_inds'].flatten()
        vec = step_prediction['out_vec'].flatten()
        scene['out_inds'].append(inds)
        scene['out_vecs'].append(vec)
        updated = self.update_frame(
            scene['last_frame'], scene['last_mask'], scene['last_label'],
            inds, vec)
        scene['last_frame'], scene['last_mask'], scene['last_label'], patch = updated
        scene['out_patches'].append(patch)
        return scene['last_frame']

    def update_frame(self, input_frame, input_mask, input_label, input_inds, input_vec):
        """Paste the patch selected by one prediction into the frame.

        Args:
            input_frame: Current frame; modified in place.
            input_mask: Current mask.
            input_label: Current label map.
            input_inds: Flat indices; element 0 is the class index, the rest
                are converted to a box by ``db.index2box``.
            input_vec: Query vector for the patch lookup.

        Returns:
            (frame, mask, label, patch). When the class index is not above
            ``cfg.EOS_idx`` the inputs are returned unchanged with ``None``
            as the patch.
        """
        cls_ind = input_inds[0]
        if cls_ind <= self.cfg.EOS_idx:
            return input_frame, input_mask, input_label, None

        h, w = input_frame.shape[-3], input_frame.shape[-2]
        # Scale the box returned by index2box by the frame size.
        xywh = self.db.index2box(input_inds[1:]) * np.array([w, h, w, h])
        xyxy = xywh_to_xyxy(xywh, w, h)
        patch = self.nn_table.retrieve(cls_ind, input_vec)[0]

        if self.cfg.use_color_volume:
            chan = slice(3 * cls_ind, 3 * (cls_ind + 1))
        else:
            chan = slice(-3, None)
        input_frame[:, :, chan], input_mask, _, input_label, _ = \
            patch_compose_and_erose(input_frame[:, :, chan], input_mask,
                                    input_label, xyxy, patch, self.db)
        if not self.cfg.use_color_volume:
            input_frame[:, :, -4] = np.maximum(255 * input_mask, input_frame[:, :, -4])
            input_frame[:, :, cls_ind] = np.maximum(255 * input_mask, input_frame[:, :, cls_ind])
        return input_frame, input_mask, input_label, patch
Example #23
0
def _sequence_to_volumes(images):
    """Convert a 5-D image-sequence tensor to channel-last volumes.

    The first two axes are merged for ``tensors_to_vols`` and split again
    afterwards; the result is reshaped to (d0, d1, d3, d4, d2) of the input.
    """
    bsize, ssize, n, h, w = images.size()
    volumes = tensors_to_vols(images.view(bsize * ssize, n, h, w))
    return volumes.reshape(bsize, ssize, h, w, n)


def test_coco_dataloader(config):
    """Smoke-test the COCO sequence loader and dump visualisations.

    Builds the train-2017 database and its nearest-neighbour tables, iterates
    the sequence loader for four batches, prints the size of every batched
    field and writes one PNG per sample to
    ``<config.model_dir>/test_coco_dataloader``. Each PNG is a 4x4 grid: up
    to 15 cells show, side by side, the collaged background, its non-empty
    mask and the foreground/negative patch pair; the last cell shows the
    squared reference image.

    Args:
        config: Project configuration; ``model_dir``, ``batch_size`` and
            ``num_workers`` are read here, the rest is forwarded to ``coco``.

    Returns:
        None. Output goes to stdout and to disk.
    """
    grid_side = 4
    # The last grid cell is reserved for the reference image.
    max_steps = grid_side * grid_side - 1

    db = coco(config, 'train', '2017')

    all_tables = AllCategoriesTables(db)
    all_tables.build_nntables_for_all_categories(True)

    sequence_db = sequence_loader(db, all_tables)
    output_dir = osp.join(config.model_dir, 'test_coco_dataloader')
    maybe_create(output_dir)

    loader = DataLoader(sequence_db,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers)

    # Headless backend: figures are only saved, never shown.
    plt.switch_backend('agg')
    n_classes = len(db.classes)

    start = time()
    for cnt, batched in enumerate(loader):
        x = batched['background'].float()
        y = batched['foreground'].float()
        z = batched['negative'].float()

        # The background uses the color-volume encoding, the patches the
        # one-hot encoding.
        x = sequence_color_volumn_preprocess(x, n_classes)
        y = sequence_onehot_volumn_preprocess(y, n_classes)
        z = sequence_onehot_volumn_preprocess(z, n_classes)

        # Center the values and move the last axis in front of the two
        # spatial axes.
        x = (x - 128.0).permute(0, 1, 4, 2, 3)
        y = (y - 128.0).permute(0, 1, 4, 2, 3)
        z = (z - 128.0).permute(0, 1, 4, 2, 3)

        print('background', x.size())
        print('foreground', y.size())
        print('negative', z.size())
        print('word_inds', batched['word_inds'].size())
        print('word_lens', batched['word_lens'].size())
        print('fg_inds', batched['fg_inds'].size())
        print('patch_inds', batched['patch_inds'].size())
        print('out_inds', batched['out_inds'].size())
        print('out_msks', batched['out_msks'].size())
        print('foreground_resnets', batched['foreground_resnets'].size())
        print('negative_resnets', batched['negative_resnets'].size())

        print('foreground_resnets', batched['foreground_resnets'][0, 0])
        print('negative_resnets', batched['negative_resnets'][0, 0])
        print('out_msks', batched['out_msks'][0])
        print('patch_inds', batched['patch_inds'][0])

        bg_images = _sequence_to_volumes(x)
        fg_images = _sequence_to_volumes(y)
        neg_images = _sequence_to_volumes(z)

        for i in range(z.size(0)):
            bg_seq = bg_images[i]
            fg_seq = fg_images[i]
            neg_seq = neg_images[i]
            image_idx = batched['image_index'][i]
            name = '%03d_' % i + str(image_idx).zfill(12)
            out_path = osp.join(output_dir, name + '.png')
            color = cv2.imread(batched['image_path'][i], cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)

            fig = plt.figure(figsize=(48, 32))
            plt.suptitle(batched['sentence'][i], fontsize=30)

            for j in range(min(len(bg_seq), max_steps)):
                # NOTE(review): the meaning of 83 is defined by
                # heuristic_collage, which is not visible here.
                bg, _ = heuristic_collage(bg_seq[j], 83)
                # 255 where the collage has any non-zero channel, 0 elsewhere.
                # Deriving the mask from bg itself keeps it in bg's
                # (rows, cols) layout; the previous allocation used
                # (cols, rows) and only worked for square collages.
                empty = np.sum(np.absolute(bg), -1) == 0
                bg_mask = np.where(empty, 0.0, 255.0)
                bg_mask = np.repeat(bg_mask[..., None], 3, -1)

                # Last three channels hold color, the fourth-last the mask.
                fg_color = fg_seq[j][:, :, -3:]
                fg_mask = fg_seq[j][:, :, -4]
                neg_color = neg_seq[j][:, :, -3:]
                neg_mask = neg_seq[j][:, :, -4]

                # 2x2 tile: colors on top, masks below; foreground on the
                # left, negative on the right.
                color_pair = np.concatenate((fg_color, neg_color), 1)
                mask_pair = np.concatenate((fg_mask, neg_mask), 1)
                mask_pair = np.repeat(mask_pair[..., None], 3, -1)
                patch = np.concatenate((color_pair, mask_pair), 0)
                # cv2.resize takes the target size as (width, height).
                patch = cv2.resize(patch, (bg.shape[1], bg.shape[0]))

                partially_completed_img = np.concatenate((bg, bg_mask, patch),
                                                         1)
                partially_completed_img = clamp_array(partially_completed_img,
                                                      0, 255).astype(np.uint8)
                # Reverse the channel order (BGR -> RGB) for matplotlib.
                partially_completed_img = partially_completed_img[:, :, ::-1]
                plt.subplot(grid_side, grid_side, j + 1)
                plt.imshow(partially_completed_img)
                plt.axis('off')

            plt.subplot(grid_side, grid_side, grid_side * grid_side)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

        # Four batches are enough for a smoke test.
        if cnt == 3:
            break
    print("Time", time() - start)