def decode_sequence(features, areas, hparams, decode_length,
                    post_processing=True):
  """Runs full auto-regressive decoding and packs the predicted actions.

  Args:
    features: the feature dict; mutated in place by `decode_n_step` (and,
      when `post_processing` is on, by `unify_input_ref`).
    areas: the candidate screen areas used by the decoder.
    hparams: the hyperparameter set.
    decode_length: number of auto-regressive decode steps to run.
    post_processing: when True, unify input refs with the predicted verbs
      and zero out action slots beyond each example's predicted length.
  Returns:
    An int32 tensor of shape [batch, decode_length, 8]: the three (start,
    end) reference pairs followed by the verb and object ids.
  """
  decode_utils.decode_n_step(
      seq2act_model.compute_logits, features, areas, hparams,
      n=decode_length, beam_size=1)
  if post_processing:
    features["input_refs"] = decode_utils.unify_input_ref(
        features["verbs"], features["input_refs"])
  lengths = decode_utils.verb_refs_to_lengths(
      features["task"], features["verb_refs"], include_eos=False)
  # Pack refs and ids into a single action tensor along the last axis.
  parts = [
      features["verb_refs"],
      features["obj_refs"],
      features["input_refs"],
      tf.to_int32(tf.expand_dims(features["verbs"], 2)),
      tf.to_int32(tf.expand_dims(features["objects"], 2)),
  ]
  actions = tf.concat(parts, axis=-1)
  if post_processing:
    # Zero every action slot that falls beyond the predicted length.
    keep = tf.sequence_mask(lengths, maxlen=tf.shape(actions)[1])
    keep = tf.tile(tf.expand_dims(keep, 2), [1, 1, tf.shape(actions)[-1]])
    actions = tf.where(keep, actions, tf.zeros_like(actions))
  return actions
def _decode_common(hparams):
  """Builds the shared decoding graph from the input pipeline.

  Args:
    hparams: the hyperparameter set; data is read from `FLAGS.data_files`.
  Returns:
    A (decode_features, decode_mask, features) tuple: the decoded feature
    dict, its action mask, and the original ground-truth features.
  """
  features = get_input(hparams, FLAGS.data_files)
  # Seed decoding with every non-"*_refs" feature: the reference tensors
  # are what decoding itself must produce.
  decode_features = {
      key: value for key, value in features.items()
      if not key.endswith("_refs")
  }
  # Run the model once in EVAL mode only to obtain the screen areas.
  _, _, _, references = seq2act_model.compute_logits(
      features, hparams, mode=tf.estimator.ModeKeys.EVAL)
  decode_utils.decode_n_step(
      seq2act_model.compute_logits, decode_features, references["areas"],
      hparams, n=20, beam_size=FLAGS.beam_size)
  decode_mask = generate_action_mask(decode_features)
  return decode_features, decode_mask, features
def _eval(metrics, pred_dict, loss_dict, features, areas,
          compute_seq_accuracy, hparams, metric_types, decode_length=20):
  """Internal eval function.

  Populates `metrics` (a dict of tf.metrics tuples) in place with
  sequence-, reference-, and token-level accuracies, depending on which
  names appear in `metric_types`.

  Args:
    metrics: dict mutated in place; each value is a (value, update_op) pair
      from `tf.metrics.*`.
    pred_dict: dict of predicted tensors ("verbs", "objects", "input", ...).
    loss_dict: dict of loss tensors; "total_loss" may be rewrapped with a
      control dependency below.
    features: the ground-truth feature dict.
    areas: candidate screen areas passed to the step decoder.
    compute_seq_accuracy: whether to run full auto-regressive decoding to
      measure whole-sequence accuracy.
    hparams: the hyperparameter set.
    metric_types: collection of metric-group names to compute
      ("final_accuracy", "ref_accuracy", "basic_accuracy").
    decode_length: number of decode steps for sequence accuracy.
  """
  # Assume data sources are not mixed within each batch
  if compute_seq_accuracy:
    # Re-decode from scratch: start from all non-reference features so the
    # "*_refs" outputs come from the model, not the ground truth.
    decode_features = {}
    for key in features:
      if not key.endswith("_refs"):
        decode_features[key] = features[key]
    decode_utils.decode_n_step(seq2act_model.compute_logits,
                               decode_features, areas,
                               hparams, n=decode_length, beam_size=1)
    decode_features["input_refs"] = decode_utils.unify_input_ref(
        decode_features["verbs"], decode_features["input_refs"])
    acc_metrics = decode_utils.compute_seq_metrics(
        features, decode_features)
    metrics["seq_full_acc"] = tf.metrics.mean(acc_metrics["complete_refs_acc"])
    metrics["seq_partial_acc"] = tf.metrics.mean(
        acc_metrics["partial_refs_acc"])
    if "final_accuracy" in metric_types:
      # NOTE(review): "complet_act_accuracy" looks like a typo for
      # "complete_act_accuracy", but the key may be read by downstream
      # consumers — confirm before renaming.
      metrics["complet_act_accuracy"] = tf.metrics.mean(
          acc_metrics["complete_acts_acc"])
      metrics["partial_seq_acc"] = tf.metrics.mean(
          acc_metrics["partial_acts_acc"])
      # Debug print of the raw task string, forced to run by hooking it onto
      # the total loss via a control dependency. Presumably a leftover
      # diagnostic — consider removing for quiet eval runs.
      print0 = tf.print("*** lang", features["raw_task"], summarize=100)
      with tf.control_dependencies([print0]):
        loss_dict["total_loss"] = tf.identity(loss_dict["total_loss"])
  else:
    decode_features = None
  if "ref_accuracy" in metric_types:
    # Guard: verb_refs must be [batch, seq, 2] (start/end index pairs).
    with tf.control_dependencies([
        tf.assert_equal(tf.rank(features["verb_refs"]), 3),
        tf.assert_equal(tf.shape(features["verb_refs"])[-1], 2)]):
      # A ref is "present" when start < end; that mask selects which
      # positions count toward each ref accuracy.
      _ref_accuracy(features, pred_dict,
                    tf.less(features["verb_refs"][:, :, 0],
                            features["verb_refs"][:, :, 1]),
                    "verb_refs", metrics, decode_features,
                    measure_beginning_eos=True)
      _ref_accuracy(features, pred_dict,
                    tf.less(features["obj_refs"][:, :, 0],
                            features["obj_refs"][:, :, 1]),
                    "obj_refs", metrics, decode_features,
                    measure_beginning_eos=True)
      _ref_accuracy(features, pred_dict,
                    tf.less(features["input_refs"][:, :, 0],
                            features["input_refs"][:, :, 1]),
                    "input_refs", metrics, decode_features,
                    measure_beginning_eos=True)
  if "basic_accuracy" in metric_types:
    target_verbs = tf.reshape(features["verbs"], [-1])
    # Verb ids <= 1 are treated as padding/EOS and excluded — presumably
    # ids 0 and 1 are reserved; confirm against the vocabulary.
    verb_nonpadding = tf.greater(target_verbs, 1)
    target_verbs = tf.boolean_mask(target_verbs, verb_nonpadding)
    predict_verbs = tf.boolean_mask(tf.reshape(pred_dict["verbs"], [-1]),
                                    verb_nonpadding)
    metrics["verb"] = tf.metrics.accuracy(
        labels=target_verbs,
        predictions=predict_verbs,
        name="verb_accuracy")
    # Input accuracy is only measured where a verb ref exists at all.
    input_mask = tf.reshape(
        tf.less(features["verb_refs"][:, :, 0],
                features["verb_refs"][:, :, 1]), [-1])
    metrics["input"] = tf.metrics.accuracy(
        labels=tf.boolean_mask(
            tf.reshape(tf.to_int32(
                tf.less(features["input_refs"][:, :, 0],
                        features["input_refs"][:, :, 1])), [-1]),
            input_mask),
        predictions=tf.boolean_mask(
            tf.reshape(pred_dict["input"], [-1]), input_mask),
        name="input_accuracy")
    metrics["object"] = tf.metrics.accuracy(
        labels=tf.boolean_mask(tf.reshape(features["objects"], [-1]),
                               verb_nonpadding),
        predictions=tf.boolean_mask(tf.reshape(pred_dict["objects"], [-1]),
                                    verb_nonpadding),
        name="object_accuracy")
    # Per-token eval losses, averaged over non-padding positions only.
    metrics["eval_object_loss"] = tf.metrics.mean(
        tf.reduce_mean(
            tf.boolean_mask(tf.reshape(loss_dict["object_losses"], [-1]),
                            verb_nonpadding)))
    metrics["eval_verb_loss"] = tf.metrics.mean(
        tf.reduce_mean(
            tf.boolean_mask(tf.reshape(loss_dict["verbs_losses"], [-1]),
                            verb_nonpadding)))