예제 #1
0
def tf_sim(
    pos_dial_embed: "tf.Tensor",
    pos_bot_embed: "tf.Tensor",
    neg_dial_embed: "tf.Tensor",
    neg_bot_embed: "tf.Tensor",
    dial_bad_negs: "tf.Tensor",
    bot_bad_negs: "tf.Tensor",
    mask: Optional["tf.Tensor"],
) -> Tuple["tf.Tensor", "tf.Tensor", "tf.Tensor", "tf.Tensor", "tf.Tensor"]:
    """Compute the positive similarity and four penalised negative similarities.

    Every similarity is produced by ``tf_raw_sim`` using ``mask``. Each
    negative similarity additionally gets ``bad_negs * C`` added to it, where
    ``C`` is the value returned by ``large_compatible_negative`` for the dtype
    of ``pos_dial_embed``; entries flagged as non-zero in the ``*_bad_negs``
    tensors are therefore shifted by a multiple of that constant.

    Returns:
        A 5-tuple, in this order:
        ``(dial+ vs bot+, dial+ vs bot-, bot+ vs bot-, dial+ vs dial-,
        bot+ vs dial-)``.
    """
    large_negative = large_compatible_negative(pos_dial_embed.dtype)

    def _penalised_sim(anchor, negatives, bad_negs):
        # Raw similarity first, then the penalty term for flagged negatives.
        raw = tf_raw_sim(anchor, negatives, mask)
        return raw + large_negative * bad_negs

    # Similarity between the dialogue embedding and its correct bot action.
    positive = tf_raw_sim(pos_dial_embed, pos_bot_embed, mask)

    # Negatives drawn from bot embeddings use the bot bad-negative flags.
    dial_to_neg_bot = _penalised_sim(pos_dial_embed, neg_bot_embed,
                                     bot_bad_negs)
    bot_to_neg_bot = _penalised_sim(pos_bot_embed, neg_bot_embed,
                                    bot_bad_negs)

    # Negatives drawn from dialogue embeddings use the dialogue flags.
    dial_to_neg_dial = _penalised_sim(pos_dial_embed, neg_dial_embed,
                                      dial_bad_negs)
    bot_to_neg_dial = _penalised_sim(pos_bot_embed, neg_dial_embed,
                                     dial_bad_negs)

    return (
        positive,
        dial_to_neg_bot,
        bot_to_neg_bot,
        dial_to_neg_dial,
        bot_to_neg_dial,
    )
예제 #2
0
def prepare_encoder_input(features,
                          hparams,
                          embed_scope=None,
                          embed_token_fn=common_embed.embed_tokens):
    """Prepares the input for the screen encoder.

    Embeds the text tokens of every screen object and aggregates them per
    object, embeds the object type, clickable flag and position features,
    combines the embeddings selected by ``hparams.screen_embedding_feature``
    and derives a padding mask and an attention bias from
    ``features["obj_type"]``.

    Side effect: when ``hparams.synthetic_screen_noise`` is > 0,
    ``features["obj_text"]`` is overwritten in the caller's dict with the
    noised token ids.

    Args:
      features: the feature dict. Reads "obj_text" (asserted rank 4),
        "obj_type" and "obj_clickable" (asserted rank 3) and
        "obj_screen_pos" (asserted rank 4). "objects" is read only when
        synthetic noise is enabled; "obj_dom_pos" is read only when the DOM
        embedding is requested.
      hparams: the hyperparameter. Reads task_vocab_size, hidden_size,
        obj_text_aggregation, screen_embedding_feature and max_pixel_pos;
        optionally synthetic_screen_noise (default 0.), num_types
        (default 100) and max_dom_pos (only for the DOM embedding).
      embed_scope: the embedding variable scope, forwarded to
        `embed_token_fn` and to the bigram embedding of the "mean"
        aggregation.
      embed_token_fn: the function for embedding tokens. Its first return
        value is used as the token embeddings; the second is discarded.
    Returns:
      object_embed: a Tensor of shape
          [batch_size, num_steps, max_object_count, embed_depth], zeroed
          where `object_mask` is 0.
      object_mask: a float32 tensor of shape
          [batch_size, num_steps, max_object_count] that is 1.0 where
          obj_type != -1 and 0.0 elsewhere.
      att_bias: a Tensor of shape
          [batch_size, num_steps, max_object_count] equal to
          (1 - object_mask) times a large negative constant.
    Raises:
      ValueError: if hparams.obj_text_aggregation is not one of
        "max", "sum" or "mean".
    """
    # ---- Text: optional token noise, token embedding, per-object pooling ----
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_text"]), 4)]):
        if hparams.get("synthetic_screen_noise", 0.) > 0.:
            num_objects = tf.shape(features["obj_text"])[2]
            # [batch, length, num_objects]
            target_obj_mask = tf.cast(
                tf.one_hot(features["objects"], depth=num_objects), tf.bool)
            num_tokens = tf.shape(features["obj_text"])[-1]
            # Broadcast the per-object target flag over the token axis.
            target_obj_mask = tf.tile(tf.expand_dims(target_obj_mask, 3),
                                      [1, 1, 1, num_tokens])
            # Randomly keep tokens
            # (a token is kept when its uniform sample is >= the noise rate).
            keep_mask = tf.greater_equal(
                tf.random_uniform(shape=tf.shape(features["obj_text"])),
                hparams.synthetic_screen_noise)
            # Keep paddings
            # (token id 0 is never replaced).
            keep_mask = tf.logical_or(tf.equal(features["obj_text"], 0),
                                      keep_mask)
            # Keep targets
            target_obj_mask = tf.logical_or(target_obj_mask, keep_mask)
            # Tokens that are neither kept nor part of a target object are
            # replaced by a random id; this rebinds features["obj_text"] in
            # the caller's dict.
            # NOTE(review): maxval=50000 is hard-coded and independent of
            # hparams.task_vocab_size -- confirm ids cannot exceed the
            # embedding table.
            features["obj_text"] = tf.where(
                target_obj_mask, features["obj_text"],
                tf.random_uniform(shape=tf.shape(features["obj_text"]),
                                  maxval=50000,
                                  dtype=tf.int32))
        # Presumably yields one embedding per token, i.e. a rank-5 tensor;
        # the aggregations below reduce over axis -2 (the token axis).
        text_embeddings, _ = embed_token_fn(features["obj_text"],
                                            hparams.task_vocab_size,
                                            hparams.hidden_size,
                                            hparams,
                                            embed_scope=embed_scope)
        with tf.variable_scope("obj_text_embed", reuse=tf.AUTO_REUSE):
            if hparams.obj_text_aggregation == "max":
                # Token ids < 2 get a -1e7 bias so they cannot win the max.
                embed_bias = tf.cast(tf.less(features["obj_text"], 2),
                                     tf.float32) * -1e7
                with tf.control_dependencies(
                    [tf.assert_equal(tf.rank(embed_bias), 4)]):
                    text_embeddings = tf.reduce_max(
                        text_embeddings + tf.expand_dims(embed_bias, 4), -2)
                    # Learned vector, tiled over every object, used as an
                    # element-wise lower bound on the pooled embedding (so an
                    # object whose tokens were all biased away falls back
                    # to it).
                    no_txt_embed = tf.get_variable(name="no_txt_embed",
                                                   shape=[hparams.hidden_size])
                    shape = common_layers.shape_list(text_embeddings)
                    no_txt_embed = tf.tile(
                        tf.reshape(no_txt_embed,
                                   [1, 1, 1, hparams.hidden_size]),
                        [shape[0], shape[1], shape[2], 1])
                    text_embeddings = tf.maximum(text_embeddings, no_txt_embed)
            elif hparams.obj_text_aggregation == "sum":
                # [batch, step, #max_obj, #max_token]  0 for padded tokens
                # (token ids < 2 are treated as padding here).
                real_objects = tf.cast(
                    tf.greater_equal(features["obj_text"], 2), tf.float32)
                # [batch, step, #max_obj, hidden]   0s for padded objects
                text_embeddings = tf.reduce_sum(
                    text_embeddings * tf.expand_dims(real_objects, 4), -2)
            elif hparams.obj_text_aggregation == "mean":
                shape_list = common_layers.shape_list(text_embeddings)
                # Collapse the leading three axes so each object is one row
                # of tokens.
                embeddings = tf.reshape(text_embeddings, [-1] + shape_list[3:])
                # A token counts as padding when its embedding is all zeros.
                emb_sum = tf.reduce_sum(tf.abs(embeddings), axis=-1)
                non_paddings = tf.not_equal(emb_sum, 0.0)
                embeddings = common_embed.average_bag_of_embeds(
                    embeddings,
                    non_paddings,
                    use_bigrams=True,
                    bigram_embed_scope=embed_scope,
                    append_start_end=True)
                # Restore the leading three axes.
                text_embeddings = tf.reshape(
                    embeddings, shape_list[:3] + [hparams.hidden_size])
            else:
                raise ValueError("Unrecognized token aggregation %s" %
                                 (hparams.obj_text_aggregation))
    # ---- Object attributes: type and clickable flag ----
    with tf.control_dependencies([
            tf.assert_equal(tf.rank(features["obj_type"]), 3),
            tf.assert_equal(tf.rank(features["obj_clickable"]), 3)
    ]):
        with tf.variable_scope("encode_object_attr", reuse=tf.AUTO_REUSE):
            # Negative type ids (-1 marks padding, see object_mask below) are
            # clamped to 0 so the lookup stays in range; those positions are
            # zeroed out by object_mask later.
            type_embedding = tf.nn.embedding_lookup(params=tf.get_variable(
                name="embed_type_w",
                shape=[hparams.get("num_types", 100), hparams.hidden_size]),
                                                    ids=tf.maximum(
                                                        features["obj_type"],
                                                        0))
            # Two-row table: assumes obj_clickable holds only 0/1 -- TODO
            # confirm.
            clickable_embedding = tf.nn.embedding_lookup(
                params=tf.get_variable(name="embed_clickable_w",
                                       shape=[2, hparams.hidden_size]),
                ids=features["obj_clickable"])
    # ---- Position features ----
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_screen_pos"]), 4)]):

        def _create_embed(feature_name, vocab_size, depth):
            """Embed a position feature.

            Looks up each slice along the last axis of
            `features[feature_name]` in its own [vocab_size, depth] table
            and returns the sum of the lookups.
            """
            pos_embedding_list = []
            with tf.variable_scope("encode_object_" + feature_name,
                                   reuse=tf.AUTO_REUSE):
                # Used in range() below, so the last dimension must be
                # statically known (a Python int).
                num_featues = common_layers.shape_list(
                    features[feature_name])[-1]
                for i in range(num_featues):
                    pos_embedding_list.append(
                        tf.nn.embedding_lookup(
                            params=tf.get_variable(name=feature_name +
                                                   "_embed_w_%d" % i,
                                                   shape=[vocab_size, depth]),
                            ids=features[feature_name][:, :, :, i]))
                pos_embedding = tf.add_n(pos_embedding_list)
                return pos_embedding

        pos_embedding = _create_embed("obj_screen_pos", hparams.max_pixel_pos,
                                      hparams.hidden_size)
    # The DOM embedding is only built when it can be used below. Unlike
    # pos_embedding, it is created outside the control-dependency block
    # above.
    if "all" == hparams.screen_embedding_feature or (
            "dom" in hparams.screen_embedding_feature):
        dom_embedding = _create_embed("obj_dom_pos", hparams.max_dom_pos,
                                      hparams.hidden_size)
    # ---- Combine the selected embeddings ----
    object_embed = tf.zeros_like(text_embeddings, dtype=tf.float32)
    # NOTE(review): this is an if/elif chain, so apart from "all" only the
    # FIRST matching feature family is added (e.g. a value containing both
    # "text" and "pos" yields text only). "all" also omits
    # clickable_embedding. Confirm both are intended.
    if hparams.screen_embedding_feature == "all":
        object_embed = (text_embeddings + type_embedding + pos_embedding +
                        dom_embedding)
    elif "text" in hparams.screen_embedding_feature:
        object_embed += text_embeddings
    elif "type" in hparams.screen_embedding_feature:
        object_embed += type_embedding
    elif "pos" in hparams.screen_embedding_feature:
        object_embed += pos_embedding
    elif "dom" in hparams.screen_embedding_feature:
        object_embed += dom_embedding
    elif "click" in hparams.screen_embedding_feature:
        object_embed += clickable_embedding
    # ---- Padding mask and attention bias ----
    # Objects whose type is -1 are treated as padding.
    object_mask = tf.cast(tf.not_equal(features["obj_type"], -1), tf.float32)
    # Zero the embeddings of padded objects.
    object_embed = object_embed * tf.expand_dims(object_mask, 3)
    # 0 for real objects, a large negative value for padded ones.
    att_bias = (1. - object_mask) * common_attention.large_compatible_negative(
        object_embed.dtype)
    return object_embed, object_mask, att_bias