def test_dataset(config):
    # the train/val splits are instantiated only to exercise their construction;
    # the test split is the one rendered below
    traindb = layout_coco(config, 'train')
    valdb = layout_coco(config, 'val')
    db = layout_coco(config, 'test')
    plt.switch_backend('agg')

    output_dir = osp.join(config.model_dir, 'test_dataset')
    maybe_create(output_dir)

    indices = np.random.permutation(range(len(db)))
    indices = indices[:config.n_samples]
    for i in indices:
        entry = db[i]
        layouts = db.render_indices_as_output(entry)
        image_idx = entry['image_idx']
        name = '%03d_' % i + str(image_idx).zfill(12)
        out_path = osp.join(output_dir, name + '.png')
        color = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
        color, _, _ = create_squared_image(color)

        fig = plt.figure(figsize=(32, 16))
        plt.suptitle(entry['sentence'], fontsize=30)
        for j in range(len(layouts)):
            plt.subplot(3, 5, j + 1)
            plt.imshow(layouts[j])
            plt.axis('off')
        plt.subplot(3, 5, 15)
        plt.imshow(color[:, :, ::-1])
        plt.axis('off')
        fig.savefig(out_path, bbox_inches='tight')
        plt.close(fig)
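# Hedged sketch: `create_squared_image` is a repo utility used throughout these
# tests. Judging from its call sites (it returns an image plus two extra values),
# it presumably pads a frame to a square canvas; a minimal version, with the
# return signature assumed to be (padded_image, x_offset, y_offset), could be:
def _create_squared_image_sketch(img):
    h, w = img.shape[:2]
    side = max(h, w)
    # zero-pad to a side x side canvas, centering the original content
    canvas = np.zeros((side, side, img.shape[2]), dtype=img.dtype)
    ox, oy = (side - w) // 2, (side - h) // 2
    canvas[oy:oy + h, ox:ox + w] = img
    return canvas, ox, oy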
def train_model(config):
    transformer = volume_normalize('background')
    train_db = layout_coco(config, split='train', transform=transformer)
    val_db = layout_coco(config, split='val', transform=transformer)
    test_db = layout_coco(config, split='test', transform=transformer)
    trainer = SupervisedTrainer(train_db)
    trainer.train(train_db, val_db, test_db)
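# Hedged usage sketch: get_config() (used by main() at the bottom of this file)
# appears to return a (config, unparsed) pair, so training could be launched as
# below; the names are assumed from this file's own usage, not verified elsewhere.
def _train_model_demo():
    cfg, _ = get_config()
    train_model(cfg)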
def test_loader(config):
    from torch.utils.data import DataLoader
    transformer = volume_normalize('background')
    db = layout_coco(config, 'test', transform=transformer)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers)
    for cnt, batched in enumerate(loader):
        # print(batched['background'].size())
        print(batched['word_inds'].size())
        print(batched['word_lens'].size())
        # print(batched['word_inds'][1])
        # print(batched['word_lens'][1])
        print(batched['out_inds'].size())
        print(batched['out_msks'].size())
        print(batched['out_inds'][0])
        print(batched['out_msks'][0])
        # print(batched['trans_inds'].size())
        # print(batched['cls_mask'].size())
        # print(batched['pos_mask'].size())
        # cls_inds = batched['cls_inds']
        # fg_onehots = batched['foreground_onehots']
        # foo = np.argmax(fg_onehots, axis=-1)
        # assert((cls_inds == foo).all())
        # print(cls_inds, foo)
        # print(batched['word_vecs'].shape)
        # A = batched['output_clip_indices']
        # B = batched['output_clip_onehots']
        # C = np.argmax(B, axis=-1)
        # assert((A==C).all())
        # print(A[0], C[0])
        break
def layout_demo(config, input_app):
    transformer = volume_normalize('background')
    train_db = layout_coco(config, split='train', transform=transformer)
    trainer = SupervisedTrainer(train_db)
    # input_sentences = json_load('examples/layout_samples.json')
    # print(type(input_app))
    trainer.sample_demo([input_app])
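def _layout_demo_example():
    # Hedged usage sketch: judging from the commented-out layout_samples.json
    # loading above, `input_app` is presumably a single caption string; the
    # example sentence here is illustrative, not taken from the repo.
    cfg, _ = get_config()
    layout_demo(cfg, 'a man riding a horse on a beach')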
def test_txt_encoder_coco(config):
    transformer = volume_normalize('background')
    db = layout_coco(config, 'train', transform=transformer)
    net = TextEncoder(db)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)
    for cnt, batched in enumerate(loader):
        input_inds = batched['word_inds'].long()
        input_lens = batched['word_lens'].long()

        print('Checking the output shapes')
        out = net(input_inds, input_lens)
        out_rfts, out_embs, out_msks, out_hids = out
        print(out_rfts.size(), out_embs.size(), out_msks.size())
        if isinstance(out_hids, tuple):
            print(out_hids[0].size())
        else:
            print(out_hids.size())
        print('m: ', out_msks[-1])

        print('Checking the embedding')
        # the encoder's embedding of a word should match the vocab vector exactly
        embedded = net.embedding(input_inds)
        v1 = embedded[0, 0]
        idx = input_inds[0, 0].data.item()
        v2 = db.lang_vocab.vectors[idx]
        diff = v2 - v1
        print('Diff: (should be zero)', torch.sum(diff.abs_()))
        break
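# Hedged sketch: the embedding check above assumes TextEncoder copies
# db.lang_vocab.vectors verbatim into its nn.Embedding weights. One standard way
# to wire that up (layer construction assumed, not taken from this repo):
def _build_frozen_embedding_sketch(vocab_vectors):
    import torch.nn as nn
    # vocab_vectors: FloatTensor of shape (vocab_size, emb_dim)
    return nn.Embedding.from_pretrained(vocab_vectors, freeze=True)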
def test_evaluator(config):
    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)
    output_dir = osp.join(db.cfg.model_dir, 'test_evaluator_coco')
    maybe_create(output_dir)

    ev = evaluator(db)
    # compare consecutive scene pairs: entry_1 plays the "prediction" role,
    # entry_2 the "ground truth"
    for i in range(0, len(db), 2):
        entry_1 = db[i]
        entry_2 = db[i + 1]
        scene_1 = db.scenedb[i]
        scene_2 = db.scenedb[i + 1]
        name_1 = osp.splitext(osp.basename(entry_1['color_path']))[0]
        name_2 = osp.splitext(osp.basename(entry_2['color_path']))[0]

        graph_1 = scene_graph(db, scene_1, entry_1['out_inds'], True)
        graph_2 = scene_graph(db, scene_2, entry_2['out_inds'], False)

        color_1 = cv2.imread(entry_1['color_path'], cv2.IMREAD_COLOR)
        color_2 = cv2.imread(entry_2['color_path'], cv2.IMREAD_COLOR)
        color_1, _, _ = create_squared_image(color_1)
        color_2, _, _ = create_squared_image(color_2)
        color_1 = cv2.resize(color_1, (config.draw_size[0], config.draw_size[1]))
        color_2 = cv2.resize(color_2, (config.draw_size[0], config.draw_size[1]))

        color_1 = visualize_unigram(config, color_1, graph_1.unigrams, (225, 0, 0))
        color_2 = visualize_unigram(config, color_2, graph_2.unigrams, (225, 0, 0))
        color_1 = visualize_bigram(config, color_1, graph_1.bigrams, (0, 255, 255))
        color_2 = visualize_bigram(config, color_2, graph_2.bigrams, (0, 255, 255))

        scores = ev.evaluate_graph(graph_1, graph_2)

        color_1 = visualize_unigram(config, color_1, ev.common_pred_unigrams, (0, 225, 0))
        color_2 = visualize_unigram(config, color_2, ev.common_gt_unigrams, (0, 225, 0))
        color_1 = visualize_bigram(config, color_1, ev.common_pred_bigrams, (255, 255, 0))
        color_2 = visualize_bigram(config, color_2, ev.common_gt_bigrams, (255, 255, 0))

        info = eval_info(config, scores[None, ...])

        plt.switch_backend('agg')
        fig = plt.figure(figsize=(16, 10))
        title = entry_1['sentence'] + '\n' + entry_2['sentence'] + '\n'
        title += 'unigram f3: %f, bigram f3: %f, bigram sim: %f\n' % (
            info.unigram_F3()[0], info.bigram_F3()[0], info.bigram_coord()[0])
        title += 'scale: %f, ratio: %f, coord: %f \n' % (
            info.scale()[0], info.ratio()[0], info.unigram_coord()[0])
        plt.suptitle(title)
        plt.subplot(1, 2, 1)
        plt.imshow(color_1[:, :, ::-1])
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(color_2[:, :, ::-1])
        plt.axis('off')

        out_path = osp.join(output_dir, name_1 + '_' + name_2 + '.png')
        fig.savefig(out_path, bbox_inches='tight')
        plt.close(fig)
        if i > 40:
            break
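# Hedged note: the "F3" scores printed above presumably denote the F-beta measure
# with beta = 3, i.e. recall weighted nine times as heavily as precision:
#   F3 = (1 + 3^2) * P * R / (3^2 * P + R)
# A minimal reference implementation under that assumption:
def _f3_sketch(precision, recall, eps=1e-8):
    return 10.0 * precision * recall / (9.0 * precision + recall + eps)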
def test_simulator(config):
    plt.switch_backend('agg')
    output_dir = osp.join(config.model_dir, 'simulator')
    maybe_create(output_dir)

    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)
    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)
    env = simulator(db, config.batch_size)
    env.reset()

    for cnt, batched in enumerate(loader):
        out_inds = batched['out_inds'].numpy()
        gt_paths = batched['color_path']
        img_inds = batched['image_idx']
        sents = batched['sentence']

        # step the simulator one object at a time and collect the frames
        sequences = []
        for i in range(out_inds.shape[1]):
            frames = env.batch_render_to_pytorch(out_inds[:, i, :])
            sequences.append(frames)
        seqs1 = torch.stack(sequences, dim=1)
        print('seqs1', seqs1.size())
        seqs2 = env.batch_redraw(return_sequence=True)
        seqs = seqs2

        for i in range(len(seqs)):
            imgs = seqs[i]
            image_idx = img_inds[i]
            name = '%03d_' % i + str(image_idx.item()).zfill(9)
            out_path = osp.join(output_dir, name + '.png')
            color = cv2.imread(gt_paths[i], cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)

            fig = plt.figure(figsize=(32, 16))
            plt.suptitle(sents[i])
            for j in range(len(imgs)):
                plt.subplot(3, 5, j + 1)
                plt.imshow(imgs[j])
                plt.axis('off')
            plt.subplot(3, 5, 15)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
        break
def test_coco_decoder(config):
    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)

    text_encoder = TextEncoder(db)
    img_encoder = VolumeEncoder(config)
    what_decoder = WhatDecoder(config)
    where_decoder = WhereDecoder(config)
    print('txt_encoder', get_n_params(text_encoder))
    print('vol_encoder', get_n_params(img_encoder))
    print('what_decoder', get_n_params(what_decoder))
    print('where_decoder', get_n_params(where_decoder))

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)
    for cnt, batched in enumerate(loader):
        word_inds = batched['word_inds'].long()
        word_lens = batched['word_lens'].long()
        bg_imgs = batched['background'].float()

        encoder_states = text_encoder(word_inds, word_lens)
        bg_feats = img_encoder(bg_imgs)
        prev_bgfs = bg_feats[:, 0].unsqueeze(1)
        what_outs = what_decoder((prev_bgfs, None, None), encoder_states)
        obj_logits, rnn_outs, nxt_hids, prev_bgfs, att_ctx, att_wei = what_outs
        print('------------------------------------------')
        print('obj_logits', obj_logits.size())
        print('rnn_outs', rnn_outs.size())
        print('nxt_hids', nxt_hids.size())
        print('prev_bgfs', prev_bgfs.size())
        # print('att_ctx', att_ctx.size())
        # print('att_wei', att_wei.size())

        # pick the argmax class (the +1.0 offset leaves the argmax unchanged)
        _, obj_inds = torch.max(obj_logits + 1.0, dim=-1)
        curr_fgfs = indices2onehots(obj_inds.cpu().data, config.output_cls_size).float()
        # curr_fgfs = curr_fgfs.unsqueeze(1)
        if config.cuda:
            curr_fgfs = curr_fgfs.cuda()

        where_outs = where_decoder((rnn_outs, curr_fgfs, prev_bgfs, att_ctx), encoder_states)
        coord_logits, attri_logits, where_ctx, where_wei = where_outs
        print('coord_logits ', coord_logits.size())
        print('attri_logits ', attri_logits.size())
        # print('att_ctx', where_ctx.size())
        # print('att_wei', where_wei.size())
        break
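# Hedged sketch: `indices2onehots` is a repo utility used above and in
# test_model(); assuming it turns a LongTensor of class indices into float
# one-hot vectors over `vocab_size` classes, a minimal equivalent would be:
def _indices2onehots_sketch(inds, vocab_size):
    # inds: LongTensor of arbitrary shape; result has one extra trailing dim
    onehots = torch.zeros(*inds.size(), vocab_size)
    return onehots.scatter_(-1, inds.unsqueeze(-1), 1.0)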
def test_vol_encoder(config):
    transformer = volume_normalize('background')
    db = layout_coco(config, 'test', transform=transformer)

    vol_encoder = VolumeEncoder(config)
    print(get_n_params(vol_encoder))

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)
    for cnt, batched in enumerate(loader):
        x = batched['background'].float()
        y = vol_encoder(x)
        print(y.size())
        break
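# Hedged sketch: `get_n_params` is a repo utility; if it counts trainable
# parameters, the usual one-liner is:
def _get_n_params_sketch(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)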
def test_model(config):
    transformer = volume_normalize('background')
    db = layout_coco(config, 'val', transform=transformer)
    net = DrawModel(db)

    plt.switch_backend('agg')
    output_dir = osp.join(config.model_dir, 'test_model_coco')
    maybe_create(output_dir)

    pretrained_path = '../data/caches/layout_ckpts/supervised_coco_top1.pkl'
    assert osp.exists(pretrained_path)
    if config.cuda:
        states = torch.load(pretrained_path)
    else:
        # remap GPU-trained weights onto the CPU
        states = torch.load(pretrained_path,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(states)

    loader = DataLoader(db,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers)
    net.eval()
    for cnt, batched in enumerate(loader):
        word_inds = batched['word_inds'].long()
        word_lens = batched['word_lens'].long()
        bg_images = batched['background'].float()
        fg_inds = batched['fg_inds'].long()
        gt_inds = batched['out_inds'].long()
        gt_msks = batched['out_msks'].float()
        fg_onehots = indices2onehots(fg_inds, config.output_cls_size)

        # inf_outs, _ = net((word_inds, word_lens, bg_images, fg_onehots))
        # obj_logits, coord_logits, attri_logits, enc_msks, what_wei, where_wei = inf_outs
        # print('teacher forcing')
        # print('obj_logits ', obj_logits.size())
        # print('coord_logits ', coord_logits.size())
        # print('attri_logits ', attri_logits.size())
        # if config.what_attn:
        #     print('what_att_logits ', what_wei.size())
        # if config.where_attn > 0:
        #     print('where_att_logits ', where_wei.size())
        # print('----------------------')

        # inf_outs, env = net.inference(word_inds, word_lens, -1, 0, 0, gt_inds)
        inf_outs, env = net.inference(word_inds, word_lens, -1, 2.0, 0, None)
        obj_logits, coord_logits, attri_logits, enc_msks, what_wei, where_wei = inf_outs
        print('scheduled sampling')
        print('obj_logits ', obj_logits.size())
        print('coord_logits ', coord_logits.size())
        print('attri_logits ', attri_logits.size())
        if config.what_attn:
            print('what_att_logits ', what_wei.size())
        if config.where_attn > 0:
            print('where_att_logits ', where_wei.size())
        print('----------------------')

        pred_out_inds, pred_out_msks = env.get_batch_inds_and_masks()
        print('pred_out_inds', pred_out_inds[0, 0], pred_out_inds.shape)
        print('gt_inds', gt_inds[0, 0], gt_inds.size())
        print('pred_out_msks', pred_out_msks[0, 0], pred_out_msks.shape)
        print('gt_msks', gt_msks[0, 0], gt_msks.size())

        batch_frames = env.batch_redraw(True)
        scene_inds = batched['scene_idx']
        for i in range(len(scene_inds)):
            sid = scene_inds[i]
            entry = db[sid]
            name = osp.splitext(osp.basename(entry['color_path']))[0]
            imgs = batch_frames[i]
            out_path = osp.join(output_dir, name + '.png')

            fig = plt.figure(figsize=(60, 30))
            plt.suptitle(entry['sentence'], fontsize=50)
            for j in range(len(imgs)):
                plt.subplot(4, 3, j + 1)
                plt.imshow(imgs[j, :, :, ::-1].astype(np.uint8))
                plt.axis('off')
            target = cv2.imread(entry['color_path'], cv2.IMREAD_COLOR)
            plt.subplot(4, 3, 12)
            plt.imshow(target[:, :, ::-1])
            plt.axis('off')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
        break
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--codah_dir",
                        type=str,
                        required=True,
                        help="The input data dir. Should contain train.tsv and dev.tsv files for the task.")
    parser.add_argument("--bert_model",
                        default=None,
                        type=str,
                        required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. "
                        "Sequences longer than this will be truncated, and sequences shorter "
                        "than this will be padded.")
    parser.add_argument("--train_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--train_size",
                        default=0.8,
                        type=float,
                        help="Percentage of the data used for training.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=6,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--categories",
                        type=str,
                        default="all",
                        help='String with the categories to be included or excluded, separated by "-", e.g. "o-i-q-n".')
    parser.add_argument("--exclude_categories",
                        default=False,
                        action="store_true",
                        help="The categories listed in `--categories` will be excluded; must not be used if all categories are listed.")
    parser.add_argument("--local_model",
                        default=False,
                        action="store_true",
                        help="Load the BERT model from a local TensorFlow checkpoint index instead of downloading it.")
    parser.add_argument("--use_pooled",
                        default=False,
                        action="store_true",
                        help="Use the pooler output instead of the normal CLS output.")
    parser.add_argument("--bert_dir",
                        type=str,
                        help="The directory of the BERT model checkpoint index.")
    parser.add_argument("--use_bert_adam",
                        default=False,
                        action="store_true",
                        help="Use the built-in BertAdam class instead of Adam.")
    args = parser.parse_args()
    print(args)

    if args.categories == "all":
        categories = CodahProcessor.get_all_categories()
    else:
        categories = set(args.categories.split('-'))

    cfg, _ = get_config()
    cfg.cuda = True
    transformer = volume_normalize('background')
    db = layout_coco(cfg, split='train', transform=transformer)

    processor = CodahProcessor(path=args.codah_dir,
                               categories=categories,
                               exclude=args.exclude_categories)
    print(" Initializing tokenizer ")
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)

    print(" Creating train and dev datasets ")
    train_examples, eval_examples = processor.get_train_dev_examples(args.train_size)
    num_train_examples = len(train_examples)
    train_data = convert_examples_to_features(train_examples,
                                              processor.get_labels(),
                                              args.max_seq_length, tokenizer, db)
    train_sampler = RandomSampler(train_data)
    train_loader = DataLoader(train_data,
                              sampler=train_sampler,
                              batch_size=args.train_batch_size)
    eval_data = convert_examples_to_features(eval_examples,
                                             processor.get_labels(),
                                             args.max_seq_length, tokenizer, db)
    eval_sampler = RandomSampler(eval_data)
    eval_loader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)

    print(" Initializing bert model ")
    # model = CodahClasifier(model_type=args.bert_model,
    #                        from_tf=args.local_model,
    #                        tf_dir=args.model_dir,
    #                        use_pooled_output=args.use_pooled,
    #                        freeze_bert=False).cuda()
    # model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=1)
    model = CodahClassifier(args.bert_model, db)
    model.cuda()

    train_and_validate(model, train_loader, eval_loader, tokenizer, processor,
                       args.num_train_epochs, args.learning_rate,
                       args.train_batch_size, num_train_examples, args.warmup,
                       print_every=10, use_bert_adam=args.use_bert_adam)
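if __name__ == '__main__':
    # Assumed CLI entry point for the CODAH fine-tuning path; the example flags
    # below are illustrative, e.g.:
    #   python <this_script>.py --codah_dir data/codah --bert_model bert-base-uncased
    main()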