Exemplo n.º 1
0
def decode_sequence(features, areas, hparams, decode_length,
                    post_processing=True):
  """Decodes the entire sequence in an auto-regressive way.

  Runs `decode_utils.decode_n_step` with `seq2act_model.compute_logits` for
  `decode_length` steps at beam size 1, then packs the decoded fields read
  from `features` into a single tensor.

  Args:
    features: the feature dict. After decoding, the keys "verb_refs",
      "obj_refs", "input_refs", "verbs" and "objects" are read from it (plus
      "task" when `post_processing` is set). When `post_processing` is set,
      `features["input_refs"]` is overwritten in place.
    areas: passed through unchanged to `decode_utils.decode_n_step`.
    hparams: passed through unchanged to `decode_utils.decode_n_step`.
    decode_length: the number of decoding steps (`n` of `decode_n_step`).
    post_processing: whether to unify the input refs with the verbs and to
      zero out every step at or beyond the predicted sequence length.

  Returns:
    The concatenation along the last axis of the verb refs, object refs,
    input refs, verbs and objects, in that order.
  """
  # The return value is discarded and the decoded fields are read back from
  # `features` below, so decode_n_step presumably fills them in place --
  # TODO confirm against decode_utils.
  decode_utils.decode_n_step(seq2act_model.compute_logits,
                             features, areas,
                             hparams, n=decode_length, beam_size=1)
  if post_processing:
    # Must run before the concat below, which reads features["input_refs"].
    features["input_refs"] = decode_utils.unify_input_ref(
        features["verbs"], features["input_refs"])
  # tf.cast replaces the deprecated tf.to_int32. A trailing axis is added so
  # that verbs/objects can be concatenated with the ref tensors.
  verbs = tf.cast(tf.expand_dims(features["verbs"], 2), tf.int32)
  objects = tf.cast(tf.expand_dims(features["objects"], 2), tf.int32)
  predicted_actions = tf.concat([
      features["verb_refs"],
      features["obj_refs"],
      features["input_refs"],
      verbs,
      objects], axis=-1)
  if not post_processing:
    return predicted_actions
  pred_lengths = decode_utils.verb_refs_to_lengths(features["task"],
                                                   features["verb_refs"],
                                                   include_eos=False)
  action_shape = tf.shape(predicted_actions)
  step_mask = tf.sequence_mask(pred_lengths, maxlen=action_shape[1])
  # Repeat the per-step mask across the last axis so that it matches
  # predicted_actions element for element in tf.where.
  full_mask = tf.tile(tf.expand_dims(step_mask, 2),
                      [1, 1, action_shape[-1]])
  return tf.where(full_mask, predicted_actions,
                  tf.zeros_like(predicted_actions))
Exemplo n.º 2
0
def _ref_accuracy(features,
                  pred_dict,
                  nonpadding,
                  name,
                  metrics,
                  decode_refs=None,
                  measure_beginning_eos=False,
                  debug=False):
    """Registers accuracy metrics for one reference feature.

    Nothing is returned; every result is stored in `metrics` under a key
    derived from `name`.

    Args:
      features: the feature dict holding the ground truth; `features[name]`
        is always the label, and "task" / "verb_refs" are read when
        `decode_refs` is given.
      pred_dict: the dictionary holding the prediction results; only
        `pred_dict[name]` is read, and only when `measure_beginning_eos`.
      nonpadding: a 2D boolean tensor for masking out paddings.
      name: the name of the feature to be predicted.
      metrics: the eval metrics dict, updated in place.
      decode_refs: decoded references. When not None, the sequence-level
        `<name>_full_accuracy` and `<name>_partial_accuracy` are added.
      measure_beginning_eos: whether to add `<name>_start` and `<name>_end`,
        the accuracies of the first and second element of each reference.
      debug: forwarded to `decode_utils.sequence_accuracy`.
    """
    if decode_refs is not None:
        # Sequence lengths are derived from the verb refs on both sides.
        target_lengths = decode_utils.verb_refs_to_lengths(
            features["task"], features["verb_refs"])
        decoded_lengths = decode_utils.verb_refs_to_lengths(
            decode_refs["task"], decode_refs["verb_refs"])
        full_acc, partial_acc = decode_utils.sequence_accuracy(
            features[name], decode_refs[name],
            target_lengths, decoded_lengths,
            debug=debug, name=name)
        metrics[name + "_full_accuracy"] = tf.metrics.mean(full_acc)
        metrics[name + "_partial_accuracy"] = tf.metrics.mean(partial_acc)
    if measure_beginning_eos:
        # Flatten so every row is one (start, end) pair with one mask bit.
        flat_mask = tf.reshape(nonpadding, [-1])
        target_pairs = tf.reshape(features[name], [-1, 2])
        predicted_pairs = tf.reshape(pred_dict[name], [-1, 2])
        endpoints = (
            (0, "_start", "_start_accuracy"),
            (1, "_end", "_end_accuracy"),
        )
        for column, key_suffix, op_suffix in endpoints:
            metrics[name + key_suffix] = tf.metrics.accuracy(
                labels=tf.boolean_mask(target_pairs[:, column], flat_mask),
                predictions=tf.boolean_mask(
                    predicted_pairs[:, column], flat_mask),
                name=name + op_suffix)