def prepare_train_data_set(**data_cofig):
    """Construct the training-set reader and companion vocab metadata.

    Expected keyword keys: data_root_dir, vocab_layout_file, imdb_file_trn,
    preprocess_model, vocab_question_file, vocab_answer_file,
    prune_filter_module, N, T_encoder, T_decoder.

    Returns a 5-tuple:
        (data_reader_trn, num_vocab_txt, num_choices, num_vocab_nmn, assembler)
    """
    root = data_cofig['data_root_dir']

    # The assembler is built from the layout vocabulary and shared with the reader.
    assembler = Assembler(os.path.join(root, data_cofig['vocab_layout_file']))

    data_reader_trn = DataReader(
        os.path.join(root, 'imdb', data_cofig['imdb_file_trn']),
        os.path.join(root, data_cofig['preprocess_model'], 'train'),
        shuffle=False,
        one_pass=True,
        batch_size=data_cofig['N'],
        T_encoder=data_cofig['T_encoder'],
        T_decoder=data_cofig['T_decoder'],
        assembler=assembler,
        vocab_question_file=os.path.join(root, data_cofig['vocab_question_file']),
        vocab_answer_file=os.path.join(root, data_cofig['vocab_answer_file']),
        prune_filter_module=data_cofig['prune_filter_module'])

    # Vocabulary sizes come from the reader's batch loader; module count from
    # the assembler.
    num_vocab_txt = data_reader_trn.batch_loader.vocab_dict.num_vocab
    num_choices = data_reader_trn.batch_loader.answer_dict.num_vocab
    num_vocab_nmn = len(assembler.module_names)

    return data_reader_trn, num_vocab_txt, num_choices, num_vocab_nmn, assembler
def prepare_train_data_set(**data_cofig):
    """Build a shuffled DataLoader over the VQA training dataset.

    NOTE(review): this redefines prepare_train_data_set from earlier in the
    file; at import time this definition wins.

    Expected keyword keys: data_root_dir, vocab_layout_file, imdb_file_trn,
    preprocess_model, vocab_question_file, vocab_answer_file,
    prune_filter_module, N, T_encoder, T_decoder, image_depth_first.

    Returns a 5-tuple:
        (data_reader_trn, num_vocab_txt, num_choices, num_vocab_nmn, assembler)
    """
    root = data_cofig['data_root_dir']

    assembler = Assembler(os.path.join(root, data_cofig['vocab_layout_file']))

    # Dataset wraps the imdb file plus preprocessed train image features.
    vqa_train_dataset = vqa_dataset(
        imdb_file=os.path.join(root, 'imdb', data_cofig['imdb_file_trn']),
        image_feat_directory=os.path.join(
            root, data_cofig['preprocess_model'], 'train'),
        T_encoder=data_cofig['T_encoder'],
        T_decoder=data_cofig['T_decoder'],
        assembler=assembler,
        vocab_question_file=os.path.join(root, data_cofig['vocab_question_file']),
        vocab_answer_file=os.path.join(root, data_cofig['vocab_answer_file']),
        prune_filter_module=data_cofig['prune_filter_module'],
        image_depth_first=data_cofig['image_depth_first'])

    # Training loader shuffles every epoch.
    data_reader_trn = DataLoader(dataset=vqa_train_dataset,
                                 batch_size=data_cofig['N'],
                                 shuffle=True)

    num_vocab_txt = vqa_train_dataset.vocab_dict.num_vocab
    num_choices = vqa_train_dataset.answer_dict.num_vocab
    num_vocab_nmn = len(assembler.module_names)

    return data_reader_trn, num_vocab_txt, num_choices, num_vocab_nmn, assembler
# --- Test-time configuration and reader setup (script-level statements) ---
# NOTE(review): this fragment relies on names defined elsewhere in the file
# (tst_image_set, exp_name, snapshot_name, vocab_layout_file, T_encoder,
# vocab_question_file, vocab_answer_file) — confirm they are bound before
# this point.
T_decoder = 20            # max layout-decoder steps
N = 64                    # evaluation batch size
prune_filter_module = True

# Paths are keyed by the experiment name, snapshot, and image split.
imdb_file_tst = './exp_clevr/data/imdb/imdb_%s.npy' % tst_image_set
snapshot_file = './exp_clevr/tfmodel/%s/%s' % (exp_name, snapshot_name)
save_file = './exp_clevr/results/%s/%s.%s.txt' % (exp_name, snapshot_name, tst_image_set)
os.makedirs(os.path.dirname(save_file), exist_ok=True)
eval_output_file = './exp_clevr/eval_outputs/%s/%s.%s.txt' % (
    exp_name, snapshot_name, tst_image_set)
os.makedirs(os.path.dirname(eval_output_file), exist_ok=True)

assembler = Assembler(vocab_layout_file)

# Single-pass, unshuffled reader for deterministic evaluation.
# NOTE(review): unlike the other DataReader call sites in this file, no
# image-feature directory is passed here — verify against DataReader's
# signature.
data_reader_tst = DataReader(imdb_file_tst,
                             shuffle=False,
                             one_pass=True,
                             batch_size=N,
                             T_encoder=T_encoder,
                             T_decoder=T_decoder,
                             assembler=assembler,
                             vocab_question_file=vocab_question_file,
                             vocab_answer_file=vocab_answer_file,
                             prune_filter_module=prune_filter_module)

print('Running test ...')
# Running totals accumulated over the evaluation loop.
answer_correct_total = 0
layout_correct_total = 0
layout_valid_total = 0
def run_eval(exp_name, snapshot_name, tst_image_set, data_dir, image_feat_dir,
             tf_model_dir, print_log=False):
    """Evaluate a saved model snapshot on one image split.

    Loads the model with torch.load, iterates the test reader once, and
    accumulates layout accuracy, answer accuracy, and layout validity.

    Returns:
        (layout_accuracy, layout_correct_total, num_questions_total,
         answer_accuracy)

    NOTE(review): relies on module-level globals N, T_encoder, T_decoder,
    prune_filter_module, and use_cuda — confirm they are defined. If the
    reader yields no batches (or data_reader_tst is None), the return
    references unbound locals / implicitly returns None — TODO confirm
    callers never hit that path.
    """
    vocab_question_file = os.path.join(data_dir, "vocabulary_clevr.txt")
    vocab_layout_file = os.path.join(data_dir, "vocabulary_layout.txt")
    vocab_answer_file = os.path.join(data_dir, "answers_clevr.txt")
    imdb_file_tst_base_name = 'imdb_%s.npy' % tst_image_set
    imdb_file_tst = os.path.join(data_dir, "imdb", imdb_file_tst_base_name)
    image_feat_dir_tst = os.path.join(image_feat_dir, tst_image_set)
    #module_snapshot_file = './exp_clevr/tfmodel/%s/%s' % (exp_name, "model_"+snapshot_name)
    # Snapshot path: <tf_model_dir>/<exp_name>/model_<snapshot_name>
    module_snapshot_file = os.path.join(tf_model_dir, exp_name,
                                        "model_" + snapshot_name)
    assembler = Assembler(vocab_layout_file)
    # Deterministic single pass over the test split.
    data_reader_tst = DataReader(imdb_file_tst, image_feat_dir_tst,
                                 shuffle=False, one_pass=True, batch_size=N,
                                 T_encoder=T_encoder, T_decoder=T_decoder,
                                 assembler=assembler,
                                 vocab_question_file=vocab_question_file,
                                 vocab_answer_file=vocab_answer_file,
                                 prune_filter_module=prune_filter_module)
    if data_reader_tst is not None:
        print('Running test ...')
        answer_correct_total = 0
        layout_correct_total = 0
        layout_valid_total = 0
        num_questions_total = 0
        ##load my model
        myModel = torch.load(module_snapshot_file)
        for i, batch in enumerate(data_reader_tst.batches()):
            # Batch size is the second dim of the (time, batch) question tensor.
            _, batch_size = batch['input_seq_batch'].shape
            input_text_seq_lens = batch['seq_length_batch']
            input_text_seqs = batch['input_seq_batch']
            input_layouts = batch['gt_layout_batch']
            input_images = batch['image_feat_batch']
            input_answers = batch['answer_label_batch']
            num_questions_total += batch_size
            input_txt_variable = Variable(torch.LongTensor(input_text_seqs))
            input_txt_variable = input_txt_variable.cuda(
            ) if use_cuda else input_txt_variable
            # No ground-truth layout is fed at eval time; the model predicts one.
            input_layout_variable = None
            _, _, myAnswer, predicted_layouts, expr_validity_array, _ = myModel(
                input_txt_variable=input_txt_variable,
                input_text_seq_lens=input_text_seq_lens,
                input_layout_variable=input_layout_variable,
                input_answers=None,
                input_images=input_images,
                sample_token=False)
            # A layout counts as correct only if every token matches (axis 0
            # is the time dimension).
            layout_correct_total += np.sum(
                np.all(predicted_layouts == input_layouts, axis=0))
            # Answers only count when the assembled expression was valid.
            answer_correct_total += np.sum(
                np.logical_and(expr_validity_array, myAnswer == input_answers))
            layout_valid_total += np.sum(expr_validity_array)
            ##current accuracy
            # Recomputed each batch so the periodic log below stays current.
            layout_accuracy = layout_correct_total / num_questions_total
            answer_accuracy = answer_correct_total / num_questions_total
            layout_validity = layout_valid_total / num_questions_total
            if (i + 1) % 100 == 0 and print_log:
                print(
                    "iter:", i + 1,
                    " layout_accuracy=%.4f" % layout_accuracy,
                    " answer_accuracy=%.4f" % answer_accuracy,
                    " layout_validity=%.4f" % layout_validity,
                )
        # NOTE(review): layout_validity is computed but not returned.
        return layout_accuracy, layout_correct_total, num_questions_total, answer_accuracy