Example #1

import tensorflow as tf

# Vocabulary sizes: the module inventory comes from the assembler, the answer
# set from the test data reader
num_vocab_nmn = len(assembler.module_names)
num_choices = data_reader_tst.batch_loader.answer_dict.num_vocab

# Network inputs (H_feat, W_feat, D_feat and the hyperparameters used below
# are assumed to be defined earlier, e.g. in the experiment config)
input_seq_batch = tf.placeholder(tf.int32, [None, None])    # question word indices
seq_length_batch = tf.placeholder(tf.int32, [None])         # question lengths
image_feat_batch = tf.placeholder(tf.float32, [None, H_feat, W_feat, D_feat])  # convnet image features
expr_validity_batch = tf.placeholder(tf.bool, [None])       # whether the decoded layout is a valid expression

# The model for testing
nmn3_model_tst = NMN3Model(
    image_feat_batch, input_seq_batch,
    seq_length_batch, T_decoder=T_decoder,
    num_vocab_txt=num_vocab_txt, embed_dim_txt=embed_dim_txt,
    num_vocab_nmn=num_vocab_nmn, embed_dim_nmn=embed_dim_nmn,
    lstm_dim=lstm_dim, num_layers=num_layers,
    assembler=assembler,
    encoder_dropout=False,
    decoder_dropout=False,
    decoder_sampling=False,
    num_choices=num_choices)

# Restore model parameters from a saved snapshot
snapshot_saver = tf.train.Saver(max_to_keep=None)  # keep every snapshot
snapshot_saver.restore(sess, snapshot_file)
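
The restore above assumes that sess and snapshot_file were created earlier in the script. A minimal TF 1.x sketch of that setup is shown below; the GPU options and the checkpoint path are assumptions, not taken from the original code.

import tensorflow as tf

# Assumed earlier setup (not part of this snippet): a session that allocates
# GPU memory on demand, and the path of the checkpoint to restore; the path
# below is a hypothetical placeholder.
sess = tf.Session(config=tf.ConfigProto(
    gpu_options=tf.GPUOptions(allow_growth=True)))
snapshot_file = './exp_clevr/tfmodel/clevr_gt_layout/00050000'  # hypothetical path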

def run_test(dataset_tst, save_file, eval_output_file):
    if dataset_tst is None:
        return
    print('Running test...')
    answer_correct = 0
    layout_correct = 0
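
As a rough illustration of how the placeholders above are fed during evaluation, here is a minimal sketch; the dataset_tst.batches() iterator, the batch keys, and fetching nmn3_model_tst.scores directly are assumptions (in the full pipeline the predicted layout is first assembled into a module network before answer scores are computed).

import numpy as np

# Sketch only: iterate over test batches, feed the placeholders defined above,
# and count correct answers. Batch keys and model attributes are assumptions.
for batch in dataset_tst.batches():
    feed = {input_seq_batch: batch['input_seq_batch'],
            seq_length_batch: batch['seq_length_batch'],
            image_feat_batch: batch['image_feat_batch']}
    scores_val = sess.run(nmn3_model_tst.scores, feed_dict=feed)
    predictions = np.argmax(scores_val, axis=1)
    answer_correct += np.sum(predictions == batch['answer_label_batch'])
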
Example #2

import tensorflow as tf

# Network inputs (input_seq_batch and seq_length_batch are created as in
# Example #1; H_feat, W_feat, D_feat come from the experiment config)
image_feat_batch = tf.placeholder(tf.float32, [None, H_feat, W_feat, D_feat])  # convnet image features
expr_validity_batch = tf.placeholder(tf.bool, [None])       # whether the decoded layout is a valid expression
answer_label_batch = tf.placeholder(tf.int32, [None])       # ground-truth answer index
use_gt_layout = tf.constant(True, dtype=tf.bool)            # use the ground-truth layout during training
gt_layout_batch = tf.placeholder(tf.int32, [None, None])    # ground-truth module layout tokens

# The model for training
nmn3_model_trn = NMN3Model(image_feat_batch,
                           input_seq_batch,
                           seq_length_batch,
                           T_decoder=T_decoder,
                           num_vocab_txt=num_vocab_txt,
                           embed_dim_txt=embed_dim_txt,
                           num_vocab_nmn=num_vocab_nmn,
                           embed_dim_nmn=embed_dim_nmn,
                           lstm_dim=lstm_dim,
                           num_layers=num_layers,
                           assembler=assembler,
                           encoder_dropout=encoder_dropout,
                           decoder_dropout=decoder_dropout,
                           decoder_sampling=decoder_sampling,
                           num_choices=num_choices,
                           use_gt_layout=use_gt_layout,
                           gt_layout_batch=gt_layout_batch)

compiler = nmn3_model_trn.compiler          # assembles layouts into module networks
scores = nmn3_model_trn.scores              # answer logits
log_seq_prob = nmn3_model_trn.log_seq_prob  # log-probability of the layout sequence

# Loss function: softmax cross-entropy between the answer logits and the
# ground-truth answer labels
softmax_loss_per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=scores, labels=answer_label_batch)
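
A plausible continuation, not shown in the snippet: average the per-sample answer loss, add a term that maximizes the log-likelihood of the ground-truth layout sequence, and build an optimizer. The learning rate and the equal weighting of the two terms are assumptions.

# Sketch of a possible continuation (not from the original snippet)
learning_rate = 0.001  # hypothetical value; normally taken from the config
avg_answer_loss = tf.reduce_mean(softmax_loss_per_sample)
layout_loss = tf.reduce_mean(-log_seq_prob)  # push the decoder toward the ground-truth layout
total_loss = avg_answer_loss + layout_loss
train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)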