Example #1
def test_dataset(config):
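    """Render the layout sequences of a few random test entries and save them in a grid next to the ground-truth images."""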
    traindb = layout_coco(config, 'train')
    valdb = layout_coco(config, 'val')
    db = layout_coco(config, 'test')
    plt.switch_backend('agg')
    output_dir = osp.join(config.model_dir, 'test_dataset')
    maybe_create(output_dir)

    indices = np.random.permutation(range(len(db)))
    indices = indices[:config.n_samples]

    for i in indices:
        entry = db[i]
        layouts = db.render_indices_as_output(entry)
        image_idx = entry['image_idx']
        name = '%03d_' % i + str(image_idx).zfill(12)
        out_path = osp.join(output_dir, name + '.png')
        color = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
        color, _, _ = create_squared_image(color)

        fig = plt.figure(figsize=(32, 16))
        plt.suptitle(entry['sentence'], fontsize=30)

        for j in range(len(layouts)):
            plt.subplot(3, 5, j + 1)
            plt.imshow(layouts[j])
            plt.axis('off')

        plt.subplot(3, 5, 15)
        plt.imshow(color[:, :, ::-1])
        plt.axis('off')

        fig.savefig(out_path, bbox_inches='tight')
        plt.close(fig)
Example #2
def train_model(config):
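    """Build the normalized train/val/test layout datasets and run supervised training."""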
    transformer = volume_normalize('background')
    train_db = layout_coco(config, split='train', transform=transformer)
    val_db = layout_coco(config, split='val', transform=transformer)
    test_db = layout_coco(config, split='test', transform=transformer)

    trainer = SupervisedTrainer(train_db)
    trainer.train(train_db, val_db, test_db)
Example #3
def test_loader(config):
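    """Wrap the test split in a DataLoader and print the tensor shapes of one batch."""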
    from torch.utils.data import DataLoader
    transformer = volume_normalize('background')
    db = layout_coco(config, 'test', transform=transformer)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers)
    for cnt, batched in enumerate(loader):
        # print(batched['background'].size())
        print(batched['word_inds'].size())
        print(batched['word_lens'].size())
        # print(batched['word_inds'][1])
        # print(batched['word_lens'][1])
        print(batched['out_inds'].size())
        print(batched['out_msks'].size())
        print(batched['out_inds'][0])
        print(batched['out_msks'][0])
        # print(batched['trans_inds'].size())
        # print(batched['cls_mask'].size())
        # print(batched['pos_mask'].size())
        # cls_inds = batched['cls_inds']
        # fg_onehots = batched['foreground_onehots']
        # foo = np.argmax(fg_onehots, axis=-1)
        # assert((cls_inds == foo).all())
        # print(cls_inds, foo)
        # print(batched['word_vecs'].shape)
        # A = batched['output_clip_indices']
        # B = batched['output_clip_onehots']
        # C = np.argmax(B, axis=-1)
        # assert((A==C).all())
        # print(A[0], C[0])
        break
Example #4
def layout_demo(config, input_app):
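    """Run the demo sampler on a single input using a trainer built on the train split."""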
    transformer = volume_normalize('background')
    train_db = layout_coco(config, split='train', transform=transformer)
    trainer = SupervisedTrainer(train_db)
    #input_sentences = json_load('examples/layout_samples.json')
    #print(type(input_app))
    trainer.sample_demo([input_app])
Example #5
def test_txt_encoder_coco(config):
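    """Run one batch through the text encoder, print the output shapes, and check that the embedding layer reproduces the vocabulary vectors."""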
    transformer = volume_normalize('background')
    db = layout_coco(config, 'train', transform=transformer)
    net = TextEncoder(db)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    for cnt, batched in enumerate(loader):
        input_inds = batched['word_inds'].long()
        input_lens = batched['word_lens'].long()

        print('Checking the output shapes')
        out = net(input_inds, input_lens)
        out_rfts, out_embs, out_msks, out_hids = out
        print(out_rfts.size(), out_embs.size(), out_msks.size())
        if isinstance(out_hids, tuple):
            print(out_hids[0].size())
        else:
            print(out_hids.size())
        print('m: ', out_msks[-1])

        print('Checking the embedding')
        embedded = net.embedding(input_inds)
        v1 = embedded[0, 0]
        idx = input_inds[0, 0].data.item()
        v2 = db.lang_vocab.vectors[idx]
        diff = v2 - v1
        print('Diff: (should be zero)', torch.sum(diff.abs_()))

        break
Example #6
def test_evaluator(config):
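    """Compare pairs of validation entries as scene graphs and save annotated side-by-side visualizations of their unigram/bigram matches."""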
    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)
    output_dir = osp.join(db.cfg.model_dir, 'test_evaluator_coco')
    maybe_create(output_dir)

    ev = evaluator(db)
    for i in range(0, len(db), 2):
        # print('--------------------------------------')
        entry_1 = db[i]
        entry_2 = db[i+1]
        scene_1 = db.scenedb[i]
        scene_2 = db.scenedb[i+1]
        name_1 = osp.splitext(osp.basename(entry_1['color_path']))[0]
        name_2 = osp.splitext(osp.basename(entry_2['color_path']))[0]

        graph_1 = scene_graph(db, scene_1, entry_1['out_inds'], True)
        graph_2 = scene_graph(db, scene_2, entry_2['out_inds'], False)

        color_1 = cv2.imread(entry_1['color_path'], cv2.IMREAD_COLOR)
        color_2 = cv2.imread(entry_2['color_path'], cv2.IMREAD_COLOR)
        color_1, _, _ = create_squared_image(color_1)
        color_2, _, _ = create_squared_image(color_2)
        color_1 = cv2.resize(color_1, (config.draw_size[0], config.draw_size[1]))
        color_2 = cv2.resize(color_2, (config.draw_size[0], config.draw_size[1]))

        color_1 = visualize_unigram(config, color_1, graph_1.unigrams, (225, 0, 0))
        color_2 = visualize_unigram(config, color_2, graph_2.unigrams, (225, 0, 0))
        color_1 = visualize_bigram(config, color_1, graph_1.bigrams, (0, 255, 255))
        color_2 = visualize_bigram(config, color_2, graph_2.bigrams, (0, 255, 255))

        scores = ev.evaluate_graph(graph_1, graph_2)

        color_1 = visualize_unigram(config, color_1, ev.common_pred_unigrams, (0, 225, 0))
        color_2 = visualize_unigram(config, color_2, ev.common_gt_unigrams,   (0, 225, 0))
        color_1 = visualize_bigram(config, color_1, ev.common_pred_bigrams, (255, 255, 0))
        color_2 = visualize_bigram(config, color_2, ev.common_gt_bigrams, (255, 255, 0))

        info = eval_info(config, scores[None, ...])

        plt.switch_backend('agg')
        fig = plt.figure(figsize=(16, 10))
        title = entry_1['sentence'] + '\n' + entry_2['sentence'] + '\n'
        title += 'unigram f3: %f, bigram f3: %f, bigram sim: %f\n'%(info.unigram_F3()[0], info.bigram_F3()[0], info.bigram_coord()[0])
        title += 'scale: %f, ratio: %f, coord: %f \n'%(info.scale()[0], info.ratio()[0], info.unigram_coord()[0])

        
        plt.suptitle(title)
        plt.subplot(1, 2, 1); plt.imshow(color_1[:,:,::-1]); plt.axis('off')
        plt.subplot(1, 2, 2); plt.imshow(color_2[:,:,::-1]); plt.axis('off')

        out_path = osp.join(output_dir, name_1+'_'+name_2+'.png')
        fig.savefig(out_path, bbox_inches='tight')
        plt.close(fig)

        if i > 40:
            break
Example #7
def test_simulator(config):
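    """Replay ground-truth layout indices through the simulator and save the rendered frame sequences next to the ground-truth images."""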
    plt.switch_backend('agg')
    output_dir = osp.join(config.model_dir, 'simulator')
    maybe_create(output_dir)

    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    env = simulator(db, config.batch_size)
    env.reset()

    for cnt, batched in enumerate(loader):
        out_inds = batched['out_inds'].numpy()
        gt_paths = batched['color_path']
        img_inds = batched['image_idx']
        sents = batched['sentence']

        sequences = []
        for i in range(out_inds.shape[1]):
            frames = env.batch_render_to_pytorch(out_inds[:, i, :])
            sequences.append(frames)
        seqs1 = torch.stack(sequences, dim=1)
        print('seqs1', seqs1.size())
        seqs2 = env.batch_redraw(return_sequence=True)

        seqs = seqs2
        for i in range(len(seqs)):
            imgs = seqs[i]
            image_idx = img_inds[i]
            name = '%03d_' % i + str(image_idx.item()).zfill(9)
            out_path = osp.join(output_dir, name + '.png')
            color = cv2.imread(gt_paths[i], cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)

            fig = plt.figure(figsize=(32, 16))
            plt.suptitle(sents[i])

            for j in range(len(imgs)):
                plt.subplot(3, 5, j + 1)
                plt.imshow(imgs[j])
                plt.axis('off')

            plt.subplot(3, 5, 15)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)

        break
Example #8
def test_coco_decoder(config):
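    """Run one batch through the text/volume encoders and the what/where decoders, printing parameter counts and output shapes."""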
    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)

    text_encoder = TextEncoder(db)
    img_encoder = VolumeEncoder(config)
    what_decoder = WhatDecoder(config)
    where_decoder = WhereDecoder(config)

    print('txt_encoder', get_n_params(text_encoder))
    print('vol_encoder', get_n_params(img_encoder))
    print('what_decoder', get_n_params(what_decoder))
    print('where_decoder', get_n_params(where_decoder))

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    for cnt, batched in enumerate(loader):
        word_inds = batched['word_inds'].long()
        word_lens = batched['word_lens'].long()
        bg_imgs = batched['background'].float()

        encoder_states = text_encoder(word_inds, word_lens)
        bg_feats = img_encoder(bg_imgs)

        prev_bgfs = bg_feats[:, 0].unsqueeze(1)
        what_outs = what_decoder((prev_bgfs, None, None), encoder_states)
        obj_logits, rnn_outs, nxt_hids, prev_bgfs, att_ctx, att_wei = what_outs
        print('------------------------------------------')
        print('obj_logits', obj_logits.size())
        print('rnn_outs', rnn_outs.size())
        print('nxt_hids', nxt_hids.size())
        print('prev_bgfs', prev_bgfs.size())
        # print('att_ctx', att_ctx.size())
        # print('att_wei', att_wei.size())

        _, obj_inds = torch.max(obj_logits + 1.0, dim=-1)
        curr_fgfs = indices2onehots(obj_inds.cpu().data,
                                    config.output_cls_size).float()
        # curr_fgfs = curr_fgfs.unsqueeze(1)
        if config.cuda:
            curr_fgfs = curr_fgfs.cuda()

        where_outs = where_decoder((rnn_outs, curr_fgfs, prev_bgfs, att_ctx),
                                   encoder_states)
        coord_logits, attri_logits, where_ctx, where_wei = where_outs
        print('coord_logits ', coord_logits.size())
        print('attri_logits ', attri_logits.size())
        # print('att_ctx', where_ctx.size())
        # print('att_wei', where_wei.size())
        break
Example #9
def test_vol_encoder(config):
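    """Run one batch of background volumes through the volume encoder and print the output shape."""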
    transformer = volume_normalize('background')
    db = layout_coco(config, 'test', transform=transformer)

    vol_encoder = VolumeEncoder(config)
    print(get_n_params(vol_encoder))

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    for cnt, batched in enumerate(loader):
        x = batched['background'].float()
        y = vol_encoder(x)
        print(y.size())
        break
Example #10
def test_model(config):
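    """Load a pretrained DrawModel checkpoint, run inference on one validation batch, and save the predicted layouts next to the ground-truth images."""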
    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)
    net = DrawModel(db)

    plt.switch_backend('agg')
    output_dir = osp.join(config.model_dir, 'test_model_coco')
    maybe_create(output_dir)

    pretrained_path = '../data/caches/layout_ckpts/supervised_coco_top1.pkl'
    assert osp.exists(pretrained_path)
    if config.cuda:
        states = torch.load(pretrained_path)
    else:
        states = torch.load(pretrained_path,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(states)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)

    net.eval()
    for cnt, batched in enumerate(loader):
        word_inds = batched['word_inds'].long()
        word_lens = batched['word_lens'].long()
        bg_images = batched['background'].float()

        fg_inds = batched['fg_inds'].long()
        gt_inds = batched['out_inds'].long()
        gt_msks = batched['out_msks'].float()

        fg_onehots = indices2onehots(fg_inds, config.output_cls_size)

        # inf_outs, _ = net((word_inds, word_lens, bg_images, fg_onehots))
        # obj_logits, coord_logits, attri_logits, enc_msks, what_wei, where_wei = inf_outs
        # print('teacher forcing')
        # print('obj_logits ', obj_logits.size())
        # print('coord_logits ', coord_logits.size())
        # print('attri_logits ', attri_logits.size())
        # if config.what_attn:
        #     print('what_att_logits ', what_wei.size())
        # if config.where_attn > 0:
        #     print('where_att_logits ', where_wei.size())
        # print('----------------------')
        # inf_outs, env = net.inference(word_inds, word_lens, -1, 0, 0, gt_inds)
        inf_outs, env = net.inference(word_inds, word_lens, -1, 2.0, 0, None)
        obj_logits, coord_logits, attri_logits, enc_msks, what_wei, where_wei = inf_outs
        print('scheduled sampling')
        print('obj_logits ', obj_logits.size())
        print('coord_logits ', coord_logits.size())
        print('attri_logits ', attri_logits.size())
        if config.what_attn:
            print('what_att_logits ', what_wei.size())
        if config.where_attn > 0:
            print('where_att_logits ', where_wei.size())
        print('----------------------')

        pred_out_inds, pred_out_msks = env.get_batch_inds_and_masks()
        print('pred_out_inds', pred_out_inds[0, 0], pred_out_inds.shape)
        print('gt_inds', gt_inds[0, 0], gt_inds.size())
        print('pred_out_msks', pred_out_msks[0, 0], pred_out_msks.shape)
        print('gt_msks', gt_msks[0, 0], gt_msks.size())

        batch_frames = env.batch_redraw(True)
        scene_inds = batched['scene_idx']
        for i in range(len(scene_inds)):
            sid = scene_inds[i]
            entry = db[sid]
            name = osp.splitext(osp.basename(entry['color_path']))[0]
            imgs = batch_frames[i]
            out_path = osp.join(output_dir, name + '.png')

            fig = plt.figure(figsize=(60, 30))
            plt.suptitle(entry['sentence'], fontsize=50)
            for j in range(len(imgs)):
                plt.subplot(4, 3, j + 1)
                plt.imshow(imgs[j, :, :, ::-1].astype(np.uint8))
                plt.axis('off')

            target = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
            plt.subplot(4, 3, 12)
            plt.imshow(target[:, :, ::-1])
            plt.axis('off')

            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
        break
def main():
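    """Parse command-line arguments and fine-tune a BERT-based classifier on CODAH, using the COCO layout database for feature conversion and the classifier."""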
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--codah_dir",
        type=str,
        required=True,
        help=
        "The input data dir. Should contain train.tsv and dev.tsv files for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--train_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--train_size",
                        default=0.8,
                        type=float,
                        help="Percentage of the data use for training.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=6,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument(
        "--categories",
        type=str,
        default="all",
        help=
        'String with the categories to be included or excluded, separated by "-", e.g. "o-i-q-n".'
    )
    parser.add_argument(
        "--exclude_categories",
        default=False,
        action="store_true",
        help=
        "the categories listed in `--categories` will be excluded,  must not be use if all categories are listed."
    )
    parser.add_argument(
        "--local_model",
        default=False,
        action="store_true",
        help=
        "This is to load the bert model from a local ckpt tensorflow index instead of downloading it."
    )
    parser.add_argument(
        "--use_pooled",
        default=False,
        action="store_true",
        help="Use the pooler output instead of the normal cls.")
    parser.add_argument("--bert_dir",
                        type=str,
                        help="The directori to load ckpt index of bert model.")
    parser.add_argument("--use_bert_adam",
                        default=False,
                        action="store_true",
                        help="Use build in BertAdam class instead of Adam.")

    args = parser.parse_args()

    print(args)

    if args.categories == "all":
        categories = CodahProcessor.get_all_categories()
    else:
        categories = set(args.categories.split('-'))

    cfg, _ = get_config()
    cfg.cuda = True
    transformer = volume_normalize('background')
    db = layout_coco(cfg, split='train', transform=transformer)

    processor = CodahProcessor(path=args.codah_dir,
                               categories=categories,
                               exclude=args.exclude_categories)
    print(" Initializing tokenizer ")
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)

    print(" Creating train and dev datasets ")
    train_examples, eval_examples = processor.get_train_dev_examples(
        args.train_size)
    num_train_examples = len(train_examples)
    train_data = convert_examples_to_features(train_examples,
                                              processor.get_labels(),
                                              args.max_seq_length, tokenizer,
                                              db)
    train_sampler = RandomSampler(train_data)
    train_loader = DataLoader(train_data,
                              sampler=train_sampler,
                              batch_size=args.train_batch_size)

    eval_data = convert_examples_to_features(eval_examples,
                                             processor.get_labels(),
                                             args.max_seq_length, tokenizer,
                                             db)
    eval_sampler = RandomSampler(eval_data)
    eval_loader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)

    print(" Initializing bert model ")

    # model = CodahClasifier(model_type=args.bert_model,
    #                          from_tf=args.local_model,
    #                          tf_dir=args.model_dir,
    #                          use_pooled_output=args.use_pooled,
    #                          freeze_bert=False).cuda()

    # model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=1)
    model = CodahClassifier(args.bert_model, db)
    model.cuda()
    train_and_validate(model,
                       train_loader,
                       eval_loader,
                       tokenizer,
                       processor,
                       args.num_train_epochs,
                       args.learning_rate,
                       args.train_batch_size,
                       num_train_examples,
                       args.warmup,
                       print_every=10,
                       use_bert_adam=args.use_bert_adam)