# Assumed imports, following the seq2act / tensor2tensor project layout.
import copy

import tensorflow.compat.v1 as tf

from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer

from seq2act.layers import area_utils
from seq2act.layers import common_embed
from seq2act.models import seq2act_reference


def _compute_query_embedding(features, references, hparams, embed_scope=None):
  """Computes language embeddings for the verb and object from predictions.

  Args:
    features: a dictionary that contains "task", a tensor in shape of
      [batch_size, num_tokens], "verb_id_seq" in shape of
      [batch_size, num_actions], and "object_spans" and "param_span" tensors
      in shape of [batch_size, num_actions, 2]. 0 is used as padding or for
      non-existent values.
    references: the dict that keeps the reference results.
    hparams: the general hyperparameters for the model.
    embed_scope: the embedding variable scope.
  Returns:
    verb_embeds: a Tensor of shape [batch_size, num_steps, depth].
    object_embeds: a Tensor of shape [batch_size, num_steps, depth].
  """
  # Turn the span logits into concrete start/end predictions for the verb
  # and object references.
  pred_verb_refs = seq2act_reference.predict_refs(
      references["verb_area_logits"],
      references["areas"]["starts"], references["areas"]["ends"])
  pred_obj_refs = seq2act_reference.predict_refs(
      references["obj_area_logits"],
      references["areas"]["starts"], references["areas"]["ends"])
  input_embeddings, _ = common_embed.embed_tokens(
      features["task"],
      hparams.task_vocab_size,
      hparams.hidden_size, hparams,
      embed_scope=references["embed_scope"])
  if hparams.obj_text_aggregation == "sum":
    # Sum the token embeddings over each candidate span (see the sketch
    # after this function for a toy illustration).
    area_encodings, _, _ = area_utils.compute_sum_image(
        input_embeddings, max_area_width=hparams.max_span)
    shape = common_layers.shape_list(features["task"])
    encoder_input_length = shape[1]
    verb_embeds = seq2act_reference.span_embedding(
        encoder_input_length, area_encodings, pred_verb_refs, hparams)
    object_embeds = seq2act_reference.span_embedding(
        encoder_input_length, area_encodings, pred_obj_refs, hparams)
  elif hparams.obj_text_aggregation == "mean":
    # Average the token embeddings over each predicted span instead.
    verb_embeds = seq2act_reference.span_average_embed(
        input_embeddings, pred_verb_refs, embed_scope, hparams)
    object_embeds = seq2act_reference.span_average_embed(
        input_embeddings, pred_obj_refs, embed_scope, hparams)
  else:
    raise ValueError("Unrecognized query aggregation %s" % (
        hparams.obj_text_aggregation))
  return verb_embeds, object_embeds
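
# Illustrative sketch (an assumption for clarity, not part of the seq2act
# model): the "sum" and "mean" settings of `obj_text_aggregation` above
# differ only in how the token embeddings covered by a predicted span are
# pooled into a single query embedding. `_span_aggregation_sketch` is a
# hypothetical helper with toy shapes; the exclusive span end is likewise an
# assumption made only for this illustration.
def _span_aggregation_sketch():
  """Toy comparison of the "sum" and "mean" span aggregation modes."""
  import numpy as np  # Local import: only needed for this sketch.
  token_embeds = np.arange(12, dtype=np.float32).reshape(4, 3)  # [tokens, depth]
  start, end = 1, 3  # A predicted span covering tokens 1 and 2.
  sum_embed = token_embeds[start:end].sum(axis=0)    # "sum" aggregation.
  mean_embed = token_embeds[start:end].mean(axis=0)  # "mean" aggregation.
  assert np.allclose(sum_embed, 2.0 * mean_embed)    # Width-2 span.
  return sum_embed, mean_embed
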

def encode_decode_task(features, hparams, train, attention_weights=None):
  """Model core graph for the one-shot action.

  Args:
    features: a dictionary that contains "task", a tensor in shape of
      [batch_size, num_tokens], "verb_refs" in shape of
      [batch_size, num_actions, 2], and "object_spans" and "param_span"
      tensors in shape of [batch_size, num_actions, 2]. 0 is used as padding
      or for non-existent values.
    hparams: the general hyperparameters for the model.
    train: the train mode.
    attention_weights: the dict to keep attention weights for analysis.
  Returns:
    decoder_output: the decoder output for predicting the action tuples.
    decoder_nonpadding: the nonpadding mask over the decoding steps.
    areas: the area encodings of the task.
    scope: the embedding scope.
  """
  del train
  input_embeddings, scope = common_embed.embed_tokens(
      features["task"],
      hparams.task_vocab_size,
      hparams.hidden_size, hparams)
  with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE):
    # Zero out the embeddings of padding tokens.
    encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0)
    input_embeddings = tf.multiply(
        tf.expand_dims(encoder_nonpadding, 2),
        input_embeddings)
    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(
            input_embeddings, None, hparams, features=None))
    encoder_input = tf.nn.dropout(
        encoder_input,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout)
    if hparams.instruction_encoder == "transformer":
      encoder_output = transformer.transformer_encoder(
          encoder_input, self_attention_bias,
          hparams,
          save_weights_to=attention_weights,
          make_image_summary=not common_layers.is_xla_compiled())
    else:
      raise ValueError("Unsupported instruction encoder %s" % (
          hparams.instruction_encoder))
    # Encode every candidate span ("area") of the instruction, up to
    # max_span tokens wide (a toy illustration follows this function).
    span_rep = hparams.get("span_rep", "area")
    area_encodings, area_starts, area_ends = area_utils.compute_sum_image(
        encoder_output, max_area_width=hparams.max_span)
    current_shape = tf.shape(area_encodings)
    if span_rep == "area":
      area_encodings, _, _ = area_utils.compute_sum_image(
          encoder_output, max_area_width=hparams.max_span)
    elif span_rep == "basic":
      area_encodings = area_utils.compute_alternative_span_rep(
          encoder_output, input_embeddings,
          max_area_width=hparams.max_span,
          hidden_size=hparams.hidden_size,
          advanced=False)
    elif span_rep == "coref":
      area_encodings = area_utils.compute_alternative_span_rep(
          encoder_output, input_embeddings,
          max_area_width=hparams.max_span,
          hidden_size=hparams.hidden_size,
          advanced=True)
    else:
      raise ValueError("Unsupported span representation %s" % (span_rep))
    areas = {}
    areas["encodings"] = area_encodings
    areas["starts"] = area_starts
    areas["ends"] = area_ends
    with tf.control_dependencies([
        tf.print("encoder_output", tf.shape(encoder_output)),
        tf.assert_equal(current_shape, tf.shape(area_encodings),
                        summarize=100)]):
      # Mark areas that cover padding tokens so the decoder can ignore them.
      paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32)
      padding_sum, _, _ = area_utils.compute_sum_image(
          tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2),
          max_area_width=hparams.max_span)
      num_areas = common_layers.shape_list(area_encodings)[1]
      area_paddings = tf.reshape(
          tf.minimum(tf.to_float(padding_sum), 1.0), [-1, num_areas])
      areas["bias"] = area_paddings
    decoder_nonpadding = tf.to_float(
        tf.greater(features["verb_refs"][:, :, 1],
                   features["verb_refs"][:, :, 0]))
    if hparams.instruction_encoder == "lstm":
      hparams_decoder = copy.copy(hparams)
      hparams_decoder.set_hparam("pos", "none")
    else:
      hparams_decoder = hparams
    decoder_input, decoder_self_attention_bias = _prepare_decoder_input(
        area_encodings, decoder_nonpadding, features, hparams_decoder,
        embed_scope=scope)
    decoder_input = tf.nn.dropout(
        decoder_input,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout)
    if hparams.instruction_decoder == "transformer":
      decoder_output = transformer.transformer_decoder(
          decoder_input=decoder_input,
          encoder_output=encoder_output,
          decoder_self_attention_bias=decoder_self_attention_bias,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias,
          hparams=hparams_decoder)
    else:
      raise ValueError("Unsupported instruction decoder %s" % (
          hparams.instruction_decoder))
    return decoder_output, decoder_nonpadding, areas, scope
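
# Conceptual sketch (an assumption for clarity, not the actual area_utils
# implementation): the "area"/"sum" span representation used above encodes
# every contiguous token span of width <= max_span by summing the encodings
# of the tokens it covers, which is what `compute_sum_image` provides along
# with the span start/end indices. `_area_sum_sketch` is a hypothetical numpy
# illustration; the span ordering and end-index convention may differ from
# the real area_utils.
def _area_sum_sketch(token_encodings, max_span):
  """Returns toy [batch, num_spans, depth] span sums plus start/end indices."""
  import numpy as np  # Local import: only needed for this sketch.
  _, length, _ = token_encodings.shape
  sums, starts, ends = [], [], []
  for width in range(1, max_span + 1):
    for start in range(0, length - width + 1):
      sums.append(token_encodings[:, start:start + width, :].sum(axis=1))
      starts.append(start)
      ends.append(start + width)  # Exclusive end.
  # E.g. a 4-token sequence with max_span=2 yields 4 + 3 = 7 spans.
  return np.stack(sums, axis=1), np.array(starts), np.array(ends)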