コード例 #1
0
    var_names = {v.op.name: v for v in tf.global_variables()}
# Restore the variables listed in var_names from the snapshot file selected
# by the experiment name and test iteration.
snapshot_file = cfg.TEST.SNAPSHOT_FILE % (cfg.EXP_NAME, cfg.TEST.ITER)
snapshot_saver = tf.train.Saver(var_names)
snapshot_saver.restore(sess, snapshot_file)

# Write results
# The visualization directory is a sub-directory of the result directory;
# both are created up front (no error if they already exist).
result_dir = cfg.TEST.RESULT_DIR % (cfg.EXP_NAME, cfg.TEST.ITER)
vis_subdir = 'loc_%s_%s' % (cfg.TEST.VIS_DIR_PREFIX, cfg.TEST.SPLIT_LOC)
vis_dir = os.path.join(result_dir, vis_subdir)
for out_dir in (result_dir, vis_dir):
    os.makedirs(out_dir, exist_ok=True)

# Run test
# Running counters start at zero; the threshold is read from the config.
bbox_correct = 0
num_questions = 0
iou_th = cfg.TEST.BBOX_IOU_THRESH
# Iterate over the batches produced by the data reader; n_batch is the
# zero-based batch index.
for n_batch, batch in enumerate(data_reader.batches()):
    # These two model outputs are fetched for every batch.
    fetch_list = [model.loc_scores, model.bbox_offset]
    # Questions counted so far that were not counted as correct.
    # NOTE(review): the counters are presumably updated later in the loop
    # body, which is not visible here -- confirm.
    bbox_incorrect = num_questions - bbox_correct
    # Decide whether this batch should also fetch the visualization outputs.
    if cfg.TEST.VIS_SEPARATE_CORRECTNESS:
        # Separate quotas: keep visualizing while either the correct or the
        # incorrect count is still below its configured limit.
        run_vis = (
            bbox_correct < cfg.TEST.NUM_VIS_CORRECT or
            bbox_incorrect < cfg.TEST.NUM_VIS_INCORRECT)
    else:
        # Single quota on the total number of questions counted so far.
        run_vis = num_questions < cfg.TEST.NUM_VIS
    if run_vis:
        fetch_list.append(model.vis_outputs)
    # A single session call evaluates every requested tensor; the returned
    # values are in the same order as fetch_list, so the visualization
    # output (when requested) is the last element.
    fetch_list_val = sess.run(fetch_list, feed_dict={
            input_seq_batch: batch['input_seq_batch'],
            seq_length_batch: batch['seq_length_batch'],
            image_feat_batch: batch['image_feat_batch']})
コード例 #2
0
ファイル: train_net_joint.py プロジェクト: ronghanghu/snmn
# Register one scalar summary per (tag, value tensor) pair, in the order
# listed, then merge everything collected in summary_trn into a single op.
summary_trn.extend(
    tf.summary.scalar(tag, value_ph)
    for tag, value_ph in (
        ("loss/bbox_ind", loss_bbox_ind_ph),
        ("loss/bbox_offset", loss_bbox_offset_ph),
        ("loss/layout", loss_layout_ph),
        ("loss/rec", loss_rec_ph),
        ("loss/sharpen", loss_sharpen_ph),
        ("loss/sharpen_scale", sharpen_scale_ph),
        ("eval/vqa/accuracy", vqa_accuracy_ph),
        ("eval/loc/P1", loc_accuracy_ph),
    ))
log_step_trn = tf.summary.merge(summary_trn)

# Run training
# Both running accuracy values start at zero; accuracy_decay is the constant
# defined alongside them (its use lies outside this block).
vqa_avg_accuracy = 0.
loc_avg_accuracy = 0.
accuracy_decay = 0.99
# Threshold read from the training config.
iou_th = cfg.TRAIN.BBOX_IOU_THRESH
# Callable built from the config; it is invoked with the iteration number
# inside the training loop.
sharpen_loss_scaler = SharpenLossScaler(cfg)
for n_batch, (batch_vqa, batch_loc) in enumerate(
        zip(data_reader_vqa.batches(), data_reader_loc.batches())):
    n_iter = n_batch + cfg.TRAIN.START_ITER
    if n_iter >= cfg.TRAIN.MAX_ITER:
        break

    sharpen_scale = sharpen_loss_scaler(n_iter)
    feed_dict = {
        input_seq_batch:
        batch_vqa['input_seq_batch'],
        seq_length_batch:
        batch_vqa['seq_length_batch'],
        image_feat_batch:
        batch_vqa['image_feat_batch'],
        answer_label_batch:
        batch_vqa['answer_label_batch'],
        bbox_ind_batch: