def tf_sim(
    pos_dial_embed: "tf.Tensor",
    pos_bot_embed: "tf.Tensor",
    neg_dial_embed: "tf.Tensor",
    neg_bot_embed: "tf.Tensor",
    dial_bad_negs: "tf.Tensor",
    bot_bad_negs: "tf.Tensor",
    mask: Optional["tf.Tensor"],
) -> Tuple["tf.Tensor", "tf.Tensor", "tf.Tensor", "tf.Tensor", "tf.Tensor"]:
    """Define similarity.

    Computes the raw similarity of the positive dialogue/bot embedding pair
    plus three similarities against negative samples, used by the loss.
    Entries flagged in `*_bad_negs` are pushed towards a large negative
    value so they cannot act as valid negatives.

    Returns:
        A 5-tuple of tensors: similarity between user input and bot actions
        (positive and negative), between bot actions, and between user inputs.
    """
    # Large negative value compatible with the embeddings' dtype.
    penalty = large_compatible_negative(pos_dial_embed.dtype)

    def _sim_with_penalty(
        left: "tf.Tensor", right: "tf.Tensor", bad_negs: "tf.Tensor"
    ) -> "tf.Tensor":
        # Raw similarity, with flagged "bad" negatives masked out.
        return tf_raw_sim(left, right, mask) + penalty * bad_negs

    sim_pos = tf_raw_sim(pos_dial_embed, pos_bot_embed, mask)
    sim_neg = _sim_with_penalty(pos_dial_embed, neg_bot_embed, bot_bad_negs)
    sim_neg_bot_bot = _sim_with_penalty(pos_bot_embed, neg_bot_embed, bot_bad_negs)
    sim_neg_dial_dial = _sim_with_penalty(pos_dial_embed, neg_dial_embed, dial_bad_negs)
    sim_neg_bot_dial = _sim_with_penalty(pos_bot_embed, neg_dial_embed, dial_bad_negs)

    return sim_pos, sim_neg, sim_neg_bot_bot, sim_neg_dial_dial, sim_neg_bot_dial
def prepare_encoder_input(features, hparams, embed_scope=None,
                          embed_token_fn=common_embed.embed_tokens):
  """Prepares the input for the screen encoder.

  Embeds the text, type, clickability and screen-position attributes of each
  screen object and combines them (per `hparams.screen_embedding_feature`)
  into a single per-object embedding, plus a padding mask and attention bias.

  Args:
    features: the feature dict.
    hparams: the hyperparameter.
    embed_scope: the embedding variable scope.
    embed_token_fn: the function for embedding tokens.
  Returns:
    object_embedding: a Tensor of shape
        [batch_size, num_steps, max_object_count, embed_depth]
    object_mask: a binary tensor of shape
        [batch_size, num_steps, max_object_count]
    nonpadding_bias: a Tensor of shape
        [batch_size, num_steps, max_object_count]
  """
  # Runtime check: obj_text must be rank 4 (batch, step, object, token).
  with tf.control_dependencies(
      [tf.assert_equal(tf.rank(features["obj_text"]), 4)]):
    # Optional data augmentation: corrupt a random subset of object tokens,
    # but never the target objects and never the paddings.
    if hparams.get("synthetic_screen_noise", 0.) > 0.:
      num_objects = tf.shape(features["obj_text"])[2]
      # [batch, length, num_objects] — one-hot mask of the target objects.
      target_obj_mask = tf.cast(
          tf.one_hot(features["objects"], depth=num_objects), tf.bool)
      num_tokens = tf.shape(features["obj_text"])[-1]
      # Broadcast the object-level mask down to the token level.
      target_obj_mask = tf.tile(
          tf.expand_dims(target_obj_mask, 3), [1, 1, 1, num_tokens])
      # Randomly keep tokens with prob (1 - synthetic_screen_noise).
      keep_mask = tf.greater_equal(
          tf.random_uniform(shape=tf.shape(features["obj_text"])),
          hparams.synthetic_screen_noise)
      # Keep paddings (token id 0) untouched.
      keep_mask = tf.logical_or(tf.equal(features["obj_text"], 0), keep_mask)
      # Keep targets untouched.
      target_obj_mask = tf.logical_or(target_obj_mask, keep_mask)
      # Everything not kept is replaced by a uniformly random token id.
      # NOTE(review): maxval=50000 is presumably an upper bound on
      # task_vocab_size — confirm against the vocab used by callers.
      features["obj_text"] = tf.where(
          target_obj_mask, features["obj_text"],
          tf.random_uniform(
              shape=tf.shape(features["obj_text"]), maxval=50000,
              dtype=tf.int32))
    # [batch, step, object, token, hidden] token embeddings.
    text_embeddings, _ = embed_token_fn(
        features["obj_text"], hparams.task_vocab_size, hparams.hidden_size,
        hparams, embed_scope=embed_scope)
    # Aggregate the token embeddings of each object into one vector.
    with tf.variable_scope("obj_text_embed", reuse=tf.AUTO_REUSE):
      if hparams.obj_text_aggregation == "max":
        # Token ids < 2 (padding/special) get a -1e7 bias so they never win
        # the max.
        embed_bias = tf.cast(tf.less(features["obj_text"], 2),
                             tf.float32) * -1e7
        with tf.control_dependencies(
            [tf.assert_equal(tf.rank(embed_bias), 4)]):
          text_embeddings = tf.reduce_max(
              text_embeddings + tf.expand_dims(embed_bias, 4), -2)
          # Learned floor for objects with no real text: max() against it
          # replaces the ~-1e7 values left by all-padding objects.
          no_txt_embed = tf.get_variable(
              name="no_txt_embed", shape=[hparams.hidden_size])
          shape = common_layers.shape_list(text_embeddings)
          no_txt_embed = tf.tile(
              tf.reshape(no_txt_embed, [1, 1, 1, hparams.hidden_size]),
              [shape[0], shape[1], shape[2], 1])
          text_embeddings = tf.maximum(text_embeddings, no_txt_embed)
      elif hparams.obj_text_aggregation == "sum":
        # [batch, step, #max_obj, #max_token] 0 for padded tokens
        real_objects = tf.cast(
            tf.greater_equal(features["obj_text"], 2), tf.float32)
        # [batch, step, #max_obj, hidden] 0s for padded objects
        text_embeddings = tf.reduce_sum(
            text_embeddings * tf.expand_dims(real_objects, 4), -2)
      elif hparams.obj_text_aggregation == "mean":
        # Flatten to [batch*step*#max_obj, #max_token, hidden] and average
        # the non-padding token embeddings of each object.
        shape_list = common_layers.shape_list(text_embeddings)
        embeddings = tf.reshape(text_embeddings, [-1] + shape_list[3:])
        emb_sum = tf.reduce_sum(tf.abs(embeddings), axis=-1)
        # All-zero embedding vectors are treated as padding.
        non_paddings = tf.not_equal(emb_sum, 0.0)
        embeddings = common_embed.average_bag_of_embeds(
            embeddings, non_paddings, use_bigrams=True,
            bigram_embed_scope=embed_scope, append_start_end=True)
        text_embeddings = tf.reshape(
            embeddings, shape_list[:3] + [hparams.hidden_size])
      else:
        raise ValueError("Unrecognized token aggregation %s" %
                         (hparams.obj_text_aggregation))
  # Runtime check: obj_type / obj_clickable are rank 3 (batch, step, object).
  with tf.control_dependencies([
      tf.assert_equal(tf.rank(features["obj_type"]), 3),
      tf.assert_equal(tf.rank(features["obj_clickable"]), 3)
  ]):
    with tf.variable_scope("encode_object_attr", reuse=tf.AUTO_REUSE):
      # Object type embedding; tf.maximum clamps the -1 padding type to 0.
      type_embedding = tf.nn.embedding_lookup(
          params=tf.get_variable(
              name="embed_type_w",
              shape=[hparams.get("num_types", 100), hparams.hidden_size]),
          ids=tf.maximum(features["obj_type"], 0))
      # Binary clickable flag embedding.
      clickable_embedding = tf.nn.embedding_lookup(
          params=tf.get_variable(
              name="embed_clickable_w", shape=[2, hparams.hidden_size]),
          ids=features["obj_clickable"])
  # Runtime check: obj_screen_pos is rank 4 (batch, step, object, coord).
  with tf.control_dependencies(
      [tf.assert_equal(tf.rank(features["obj_screen_pos"]), 4)]):

    def _create_embed(feature_name, vocab_size, depth):
      """Embed a position feature (sum of one embedding per coordinate)."""
      pos_embedding_list = []
      with tf.variable_scope("encode_object_" + feature_name,
                             reuse=tf.AUTO_REUSE):
        num_featues = common_layers.shape_list(features[feature_name])[-1]
        # One separate embedding table per coordinate dimension.
        for i in range(num_featues):
          pos_embedding_list.append(
              tf.nn.embedding_lookup(
                  params=tf.get_variable(
                      name=feature_name + "_embed_w_%d" % i,
                      shape=[vocab_size, depth]),
                  ids=features[feature_name][:, :, :, i]))
        pos_embedding = tf.add_n(pos_embedding_list)
        return pos_embedding

    pos_embedding = _create_embed("obj_screen_pos", hparams.max_pixel_pos,
                                  hparams.hidden_size)
  # DOM-position embedding is only built when requested.
  if "all" == hparams.screen_embedding_feature or (
      "dom" in hparams.screen_embedding_feature):
    dom_embedding = _create_embed("obj_dom_pos", hparams.max_dom_pos,
                                  hparams.hidden_size)
  object_embed = tf.zeros_like(text_embeddings, dtype=tf.float32)
  # NOTE(review): this is an elif-chain, so a compound setting such as
  # "text_pos" only adds the FIRST matching feature; also "all" omits
  # clickable_embedding. Confirm both are intended.
  if hparams.screen_embedding_feature == "all":
    object_embed = (
        text_embeddings + type_embedding + pos_embedding + dom_embedding)
  elif "text" in hparams.screen_embedding_feature:
    object_embed += text_embeddings
  elif "type" in hparams.screen_embedding_feature:
    object_embed += type_embedding
  elif "pos" in hparams.screen_embedding_feature:
    object_embed += pos_embedding
  elif "dom" in hparams.screen_embedding_feature:
    object_embed += dom_embedding
  elif "click" in hparams.screen_embedding_feature:
    object_embed += clickable_embedding
  # obj_type == -1 marks padded object slots; zero out their embeddings and
  # give them a large negative attention bias.
  object_mask = tf.cast(tf.not_equal(features["obj_type"], -1), tf.float32)
  object_embed = object_embed * tf.expand_dims(object_mask, 3)
  att_bias = (1. - object_mask) * common_attention.large_compatible_negative(
      object_embed.dtype)
  return object_embed, object_mask, att_bias