Example #1
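Both examples below are TensorFlow 1.x graph code in the tensor2tensor style and also rely on a few project-local helper modules (seq2act_reference, common_embed, area_utils). A minimal import preamble that would make the snippets resolvable is sketched here; the module paths for the project-local helpers are assumptions rather than confirmed paths.

import copy

import tensorflow.compat.v1 as tf

from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer

# Project-local helpers referenced below; their exact module paths are assumed:
# from seq2act.layers import area_utils
# from seq2act.layers import common_embed
# from seq2act.models import seq2act_reference
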
def _compute_query_embedding(features, references, hparams, embed_scope=None):
    """Computes lang embeds for verb and object from predictions.

  Args:
    features: a dictionary contains "inputs" that is a tensor in shape of
        [batch_size, num_tokens], "verb_id_seq" that is in shape of
        [batch_size, num_actions], "object_spans" and "param_span" tensor
        in shape of [batch_size, num_actions, 2]. 0 is used as padding or
        non-existent values.
    references: the dict that keeps the reference results.
    hparams: the general hyperparameters for the model.
    embed_scope: the embedding variable scope.
  Returns:
    verb_embeds: a Tensor of shape
        [batch_size, num_steps, depth]
    object_embeds:
        [batch_size, num_steps, depth]
  """
    # Convert the predicted area logits into (start, end) span references for
    # the verb and the object.
    pred_verb_refs = seq2act_reference.predict_refs(
        references["verb_area_logits"], references["areas"]["starts"],
        references["areas"]["ends"])
    pred_obj_refs = seq2act_reference.predict_refs(
        references["obj_area_logits"], references["areas"]["starts"],
        references["areas"]["ends"])
    # Embed the instruction tokens with the shared embedding scope kept in
    # `references`.
    input_embeddings, _ = common_embed.embed_tokens(
        features["task"],
        hparams.task_vocab_size,
        hparams.hidden_size,
        hparams,
        embed_scope=references["embed_scope"])
    # "sum": sum the token embeddings over every candidate span (area), then
    # look up the encoding of each predicted span.
    if hparams.obj_text_aggregation == "sum":
        area_encodings, _, _ = area_utils.compute_sum_image(
            input_embeddings, max_area_width=hparams.max_span)
        shape = common_layers.shape_list(features["task"])
        encoder_input_length = shape[1]
        verb_embeds = seq2act_reference.span_embedding(encoder_input_length,
                                                       area_encodings,
                                                       pred_verb_refs, hparams)
        object_embeds = seq2act_reference.span_embedding(
            encoder_input_length, area_encodings, pred_obj_refs, hparams)
    # "mean": average the token embeddings inside each predicted span.
    elif hparams.obj_text_aggregation == "mean":
        verb_embeds = seq2act_reference.span_average_embed(
            input_embeddings, pred_verb_refs, embed_scope, hparams)
        object_embeds = seq2act_reference.span_average_embed(
            input_embeddings, pred_obj_refs, embed_scope, hparams)
    else:
        raise ValueError("Unrecognized query aggregation %s" %
                         (hparams.obj_text_aggregation))
    return verb_embeds, object_embeds
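
The two aggregation branches above reduce the token embeddings inside each predicted (start, end) reference to a single vector, either by summing precomputed area encodings ("sum") or by averaging the tokens directly ("mean"). The NumPy sketch below illustrates that idea only; it is not the actual area_utils.compute_sum_image / span_embedding implementation, and the helper name is hypothetical.

import numpy as np

def toy_span_aggregate(token_embeds, refs, mode="sum"):
    """Illustration of span aggregation.

    token_embeds: [batch, num_tokens, depth] float array.
    refs: [batch, num_steps, 2] int array of (start, end) indices, end exclusive.
    Returns [batch, num_steps, depth]; empty/padded spans (end <= start) stay 0.
    """
    batch, num_steps, _ = refs.shape
    depth = token_embeds.shape[-1]
    out = np.zeros((batch, num_steps, depth), dtype=token_embeds.dtype)
    for b in range(batch):
        for s in range(num_steps):
            start, end = refs[b, s]
            if end <= start:
                continue
            span = token_embeds[b, start:end]
            out[b, s] = span.sum(axis=0) if mode == "sum" else span.mean(axis=0)
    return out
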
Example #2
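Both functions read several model-specific hyperparameters on top of a standard Transformer configuration. The sketch below shows how such an hparams object could be assembled; the names come from the code in this section, while the values and the use of transformer.transformer_base() as the base config are illustrative assumptions.

# Illustrative hyperparameter setup (values are assumptions, not published settings).
hparams = transformer.transformer_base()  # supplies hidden_size, dropout, etc.
hparams.add_hparam("task_vocab_size", 10000)        # instruction vocabulary size
hparams.add_hparam("max_span", 10)                  # longest token span kept as an area
hparams.add_hparam("span_rep", "area")              # "area" | "basic" | "coref"
hparams.add_hparam("obj_text_aggregation", "sum")   # "sum" | "mean"
hparams.add_hparam("instruction_encoder", "transformer")
hparams.add_hparam("instruction_decoder", "transformer")
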
def encode_decode_task(features, hparams, train, attention_weights=None):
    """Model core graph for the one-shot action.

  Args:
    features: a dictionary contains "inputs" that is a tensor in shape of
        [batch_size, num_tokens], "verb_id_seq" that is in shape of
        [batch_size, num_actions], "object_spans" and "param_span" tensor
        in shape of [batch_size, num_actions, 2]. 0 is used as padding or
        non-existent values.
    hparams: the general hyperparameters for the model.
    train: the train mode.
    attention_weights: the dict to keep attention weights for analysis.
  Returns:
    loss_dict: the losses for training.
    prediction_dict: the predictions for action tuples.
    areas: the area encodings of the task.
    scope: the embedding scope.
  """
    del train
    # Embed the instruction tokens; `scope` is reused for the decoder input
    # embedding and returned to the caller.
    input_embeddings, scope = common_embed.embed_tokens(
        features["task"], hparams.task_vocab_size, hparams.hidden_size,
        hparams)
    with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE):
        # Zero out the embeddings of padded instruction tokens before encoding.
        encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0)
        input_embeddings = tf.multiply(tf.expand_dims(encoder_nonpadding, 2),
                                       input_embeddings)
        encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
            transformer.transformer_prepare_encoder(input_embeddings,
                                                    None,
                                                    hparams,
                                                    features=None))
        encoder_input = tf.nn.dropout(encoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_encoder == "transformer":
            encoder_output = transformer.transformer_encoder(
                encoder_input,
                self_attention_bias,
                hparams,
                save_weights_to=attention_weights,
                make_image_summary=not common_layers.is_xla_compiled())
        else:
            raise ValueError("Unsupported instruction encoder %s" %
                             (hparams.instruction_encoder))
        span_rep = hparams.get("span_rep", "area")
        # Enumerate all spans (areas) of up to max_span tokens. This call also
        # yields the span start/end indices shared by every representation
        # below; current_shape is kept for the shape assertion further down.
        area_encodings, area_starts, area_ends = area_utils.compute_sum_image(
            encoder_output, max_area_width=hparams.max_span)
        current_shape = tf.shape(area_encodings)
        if span_rep == "area":
            # "area": plain sums of the encoder outputs over each span.
            area_encodings, _, _ = area_utils.compute_sum_image(
                encoder_output, max_area_width=hparams.max_span)
        elif span_rep == "basic":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=False)
        elif span_rep == "coref":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=True)
        else:
            raise ValueError("Unrecognized span representation %s" % span_rep)
        areas = {}
        areas["encodings"] = area_encodings
        areas["starts"] = area_starts
        areas["ends"] = area_ends
        # Debug check: the chosen span representation must keep the same shape
        # as the plain area sums computed above.
        with tf.control_dependencies([
                tf.print("encoder_output", tf.shape(encoder_output)),
                tf.assert_equal(current_shape,
                                tf.shape(area_encodings),
                                summarize=100)
        ]):
            paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32)
        # A span counts as padding if any token inside it is a padded position.
        padding_sum, _, _ = area_utils.compute_sum_image(
            tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2),
            max_area_width=hparams.max_span)
        num_areas = common_layers.shape_list(area_encodings)[1]
        area_paddings = tf.reshape(tf.minimum(tf.to_float(padding_sum), 1.0),
                                   [-1, num_areas])
        areas["bias"] = area_paddings
        # A decoding step is real (non-padding) iff its verb reference spans
        # at least one token (end > start).
        decoder_nonpadding = tf.to_float(
            tf.greater(features["verb_refs"][:, :, 1],
                       features["verb_refs"][:, :, 0]))
        if hparams.instruction_encoder == "lstm":
            hparams_decoder = copy.copy(hparams)
            hparams_decoder.set_hparam("pos", "none")
        else:
            hparams_decoder = hparams
        decoder_input, decoder_self_attention_bias = _prepare_decoder_input(
            area_encodings,
            decoder_nonpadding,
            features,
            hparams_decoder,
            embed_scope=scope)
        decoder_input = tf.nn.dropout(decoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_decoder == "transformer":
            decoder_output = transformer.transformer_decoder(
                decoder_input=decoder_input,
                encoder_output=encoder_output,
                decoder_self_attention_bias=decoder_self_attention_bias,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                hparams=hparams_decoder)
        else:
            raise ValueError("Unsupported instruction decoder %s" %
                             (hparams.instruction_decoder))
        return decoder_output, decoder_nonpadding, areas, scope
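
The areas dictionary built above enumerates every span of up to max_span consecutive tokens, and areas["bias"] flags spans that contain padding. The sketch below is a self-contained illustration of that enumeration only; the real area_utils.compute_sum_image may order the spans differently, and the helper name is hypothetical.

import numpy as np

def toy_enumerate_areas(token_encodings, token_padding, max_span):
    """Illustration of area (span) enumeration.

    token_encodings: [num_tokens, depth] float array.
    token_padding: [num_tokens] array with 1 at padded positions, else 0.
    Returns (encodings, starts, ends, bias), where bias is 1.0 for any span
    that touches a padded token.
    """
    num_tokens, _ = token_encodings.shape
    encodings, starts, ends, bias = [], [], [], []
    for width in range(1, max_span + 1):
        for start in range(num_tokens - width + 1):
            end = start + width
            encodings.append(token_encodings[start:end].sum(axis=0))
            starts.append(start)
            ends.append(end)
            bias.append(float(token_padding[start:end].max()))
    return np.stack(encodings), np.array(starts), np.array(ends), np.array(bias)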