def decode_fn(hparams):
    """The main function."""
    decode_dict, decode_mask, label_dict = _decode_common(hparams)
    if FLAGS.problem != "android_howto":
        decode_dict["input_refs"] = decode_utils.unify_input_ref(
            decode_dict["verbs"], decode_dict["input_refs"])
    print_ops = []
    for key in [
            "raw_task", "verbs", "objects", "verb_refs", "obj_refs",
            "input_refs"
    ]:
        print_ops.append(
            tf.print(key,
                     tf.shape(decode_dict[key]),
                     decode_dict[key],
                     label_dict[key],
                     "decode_mask",
                     decode_mask,
                     summarize=100))
    acc_metrics = decode_utils.compute_seq_metrics(label_dict,
                                                   decode_dict,
                                                   mask=None)
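    # Restore the latest checkpoint and decode the dataset example by example.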
    saver = tf.train.Saver()
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        tf.logging.info("Restoring from the latest checkpoint: %s" %
                        (latest_checkpoint))
        saver.restore(session, latest_checkpoint)
        task_seqs = []
        ref_seqs = []
        act_seqs = []
        mask_seqs = []
        try:
            i = 0
            while True:
                tf.logging.info("Example %d" % i)
                task, acc, mask, label, decode = session.run([
                    decode_dict["raw_task"], acc_metrics, decode_mask,
                    label_dict, decode_dict
                ])
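                # Ground-truth vs. predicted reference spans (verb, object, input).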
                ref_seq = {}
                ref_seq["gt_seq"] = np.concatenate([
                    label["verb_refs"], label["obj_refs"], label["input_refs"]
                ],
                                                   axis=-1)
                ref_seq["pred_seq"] = np.concatenate([
                    decode["verb_refs"], decode["obj_refs"],
                    decode["input_refs"]
                ],
                                                     axis=-1)
                ref_seq["complete_seq_acc"] = acc["complete_refs_acc"]
                ref_seq["partial_seq_acc"] = acc["partial_refs_acc"]
                act_seq = {}
                act_seq["gt_seq"] = np.concatenate([
                    np.expand_dims(label["verbs"], 2),
                    np.expand_dims(label["objects"], 2), label["input_refs"]
                ],
                                                   axis=-1)
                act_seq["pred_seq"] = np.concatenate([
                    np.expand_dims(decode["verbs"], 2),
                    np.expand_dims(decode["objects"], 2), decode["input_refs"]
                ],
                                                     axis=-1)
                act_seq["complete_seq_acc"] = acc["complete_acts_acc"]
                act_seq["partial_seq_acc"] = acc["partial_acts_acc"]
                print("task", task)
                print("ref_seq", ref_seq)
                print("act_seq", act_seq)
                print("mask", mask)
                task_seqs.append(task)
                ref_seqs.append(ref_seq)
                act_seqs.append(act_seq)
                mask_seqs.append(mask)
                i += 1
        except tf.errors.OutOfRangeError:
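            # The input pipeline signals the end of the dataset this way.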
            pass
        save(task_seqs, ref_seqs, mask_seqs, "joint_refs")
        save(task_seqs, act_seqs, mask_seqs, "joint_act")
Example #2
def _eval(metrics, pred_dict, loss_dict, features, areas, compute_seq_accuracy,
          hparams, metric_types, decode_length=20):
  """Internal eval function."""
  # Assume data sources are not mixed within each batch
  if compute_seq_accuracy:
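    # Greedy-decode (beam_size=1) up to decode_length steps and score against the labels.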
    decode_features = {}
    for key in features:
      if not key.endswith("_refs"):
        decode_features[key] = features[key]
    decode_utils.decode_n_step(seq2act_model.compute_logits,
                               decode_features, areas,
                               hparams, n=decode_length, beam_size=1)
    decode_features["input_refs"] = decode_utils.unify_input_ref(
        decode_features["verbs"], decode_features["input_refs"])
    acc_metrics = decode_utils.compute_seq_metrics(
        features, decode_features)
    metrics["seq_full_acc"] = tf.metrics.mean(acc_metrics["complete_refs_acc"])
    metrics["seq_partial_acc"] = tf.metrics.mean(
        acc_metrics["partial_refs_acc"])
    if "final_accuracy" in metric_types:
      metrics["complet_act_accuracy"] = tf.metrics.mean(
          acc_metrics["complete_acts_acc"])
      metrics["partial_seq_acc"] = tf.metrics.mean(
          acc_metrics["partial_acts_acc"])
      print0 = tf.print("*** lang", features["raw_task"], summarize=100)
      with tf.control_dependencies([print0]):
        loss_dict["total_loss"] = tf.identity(loss_dict["total_loss"])
  else:
    decode_features = None
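  # Per-step accuracy of the predicted verb/object/input reference spans.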
  if "ref_accuracy" in metric_types:
    with tf.control_dependencies([
        tf.assert_equal(tf.rank(features["verb_refs"]), 3),
        tf.assert_equal(tf.shape(features["verb_refs"])[-1], 2)]):
      _ref_accuracy(features, pred_dict,
                    tf.less(features["verb_refs"][:, :, 0],
                            features["verb_refs"][:, :, 1]),
                    "verb_refs", metrics, decode_features,
                    measure_beginning_eos=True)
    _ref_accuracy(features, pred_dict,
                  tf.less(features["obj_refs"][:, :, 0],
                          features["obj_refs"][:, :, 1]),
                  "obj_refs", metrics, decode_features,
                  measure_beginning_eos=True)
    _ref_accuracy(features, pred_dict,
                  tf.less(features["input_refs"][:, :, 0],
                          features["input_refs"][:, :, 1]),
                  "input_refs", metrics, decode_features,
                  measure_beginning_eos=True)
  if "basic_accuracy" in metric_types:
    target_verbs = tf.reshape(features["verbs"], [-1])
    verb_nonpadding = tf.greater(target_verbs, 1)
    target_verbs = tf.boolean_mask(target_verbs, verb_nonpadding)
    predict_verbs = tf.boolean_mask(tf.reshape(pred_dict["verbs"], [-1]),
                                    verb_nonpadding)
    metrics["verb"] = tf.metrics.accuracy(
        labels=target_verbs,
        predictions=predict_verbs,
        name="verb_accuracy")
    input_mask = tf.reshape(
        tf.less(features["verb_refs"][:, :, 0],
                features["verb_refs"][:, :, 1]), [-1])
    metrics["input"] = tf.metrics.accuracy(
        labels=tf.boolean_mask(
            tf.reshape(tf.to_int32(
                tf.less(features["input_refs"][:, :, 0],
                        features["input_refs"][:, :, 1])), [-1]), input_mask),
        predictions=tf.boolean_mask(
            tf.reshape(pred_dict["input"], [-1]), input_mask),
        name="input_accuracy")
    metrics["object"] = tf.metrics.accuracy(
        labels=tf.boolean_mask(tf.reshape(features["objects"], [-1]),
                               verb_nonpadding),
        predictions=tf.boolean_mask(tf.reshape(pred_dict["objects"], [-1]),
                                    verb_nonpadding),
        name="object_accuracy")
    metrics["eval_object_loss"] = tf.metrics.mean(
        tf.reduce_mean(
            tf.boolean_mask(tf.reshape(loss_dict["object_losses"], [-1]),
                            verb_nonpadding)))
    metrics["eval_verb_loss"] = tf.metrics.mean(
        tf.reduce_mean(
            tf.boolean_mask(tf.reshape(loss_dict["verbs_losses"], [-1]),
                            verb_nonpadding)))