def build_encoder(self, input_lengths, input_mask, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb, entity_emb = self.build_emebdding(*args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)

    with tf.variable_scope(self.config.scope + "_encoder", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        word_emb = match_utils.multi_highway_layer(word_emb, input_dim,
                                                   self.config.highway_layer_num)

        # contextualize the word embeddings with a bi-LSTM
        [sent_repres_fw, sent_repres_bw, sent_repres] = layer_utils.my_lstm_layer(
            word_emb,
            self.config.context_lstm_dim,
            input_lengths=input_lengths,
            scope_name=self.config.scope,
            reuse=reuse,
            is_training=self.is_training,
            dropout_rate=dropout_rate,
            use_cudnn=self.config.use_cudnn)

        # label-embedding attention (LEAM-style): score every token against the
        # class-embedding memory and max-pool over ngram windows
        memory_tran = tf.transpose(self.memory, [1, 0])  # emb_dim x num_classes
        word_emb_ = tf.expand_dims(sent_repres, 3)
        input_mask = tf.cast(input_mask, tf.float32)
        print(word_emb_.get_shape(), "======emb shape======")

        H_enc = leam_utils.att_emb_ngram_encoder_maxout(
            word_emb_,
            input_mask,
            self.memory,
            memory_tran,
            self.config)
        print("===H_enc shape===", H_enc.get_shape())

        return H_enc
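# A minimal, self-contained sketch of the label-embedding attention that
# leam_utils.att_emb_ngram_encoder_maxout is assumed to implement (LEAM, Wang et al. 2018).
# The helper name, signature and ngram smoothing below are illustrative, not the repo's API:
# cosine-compare every token against every class embedding, smooth the scores over an
# ngram window, max-pool over classes, and use the result as a softmax attention over tokens.
import tensorflow as tf

def leam_attention_sketch(token_repres, class_emb, input_mask, ngram=3):
    """token_repres: [B, L, D], class_emb: [K, D], input_mask: [B, L] float."""
    tokens_norm = tf.nn.l2_normalize(token_repres, axis=-1)           # [B, L, D]
    classes_norm = tf.nn.l2_normalize(class_emb, axis=-1)             # [K, D]
    # cosine compatibility between every token and every class: [B, L, K]
    g = tf.einsum("bld,kd->blk", tokens_norm, classes_norm)
    # smooth the compatibility scores over an ngram window along the token axis
    g = tf.layers.average_pooling1d(g, pool_size=ngram, strides=1, padding="same")
    # per-token relevance = best class score; hide padded positions before the softmax
    scores = tf.reduce_max(g, axis=-1)                                 # [B, L]
    scores += (1.0 - input_mask) * -1e30
    beta = tf.nn.softmax(scores, axis=-1)                              # attention over tokens
    # attention-weighted sentence representation: [B, D]
    return tf.einsum("bl,bld->bd", beta, token_repres)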
def build_encoder(self, index, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb = self.build_emebdding(index, *args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    word_emb = tf.nn.dropout(word_emb, 1 - dropout_rate)

    with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        sent_repres = match_utils.multi_highway_layer(word_emb, input_dim,
                                                      self.config.highway_layer_num)
    return sent_repres
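# A minimal sketch of the highway transform that match_utils.multi_highway_layer is assumed
# to stack highway_layer_num times (Srivastava et al. 2015). The function name and scoping
# below are illustrative, not the repo's implementation.
import tensorflow as tf

def highway_layer_sketch(x, scope="highway", reuse=None):
    """x: [..., D]; returns a tensor of the same shape."""
    dim = int(x.get_shape()[-1])
    with tf.variable_scope(scope, reuse=reuse):
        transform = tf.layers.dense(x, dim, activation=tf.nn.relu, name="transform")
        gate = tf.layers.dense(x, dim, activation=tf.nn.sigmoid, name="gate")
        # the gate decides how much of the transformed signal replaces the raw input
        return gate * transform + (1.0 - gate) * x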
def _semantic_feature_layer(self, seq_input, seq_len, granularity="word", reuse=False):
    assert granularity in ["char", "word"]
    with self.graph.as_default():
        #### embed
        emb_seq = tf.nn.embedding_lookup(self.emb_mat, seq_input)
        dropout_rate = tf.cond(self.is_training,
                               lambda: self.config.dropout_rate,
                               lambda: 0.0)
        input_dim = emb_seq.get_shape()[-1]
        with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
            emb_seq = match_utils.multi_highway_layer(emb_seq, input_dim,
                                                      self.config.highway_layer_num)
        emb_seq = tf.nn.dropout(emb_seq, 1 - dropout_rate)

        #### encode
        input_dim = self.emb_size  # self.config["embedding_dim"]
        enc_seq = encode(emb_seq,
                         method=self.config["encode_method"],
                         input_dim=input_dim,
                         params=self.config,
                         sequence_length=seq_len,
                         mask_zero=self.config["embedding_mask_zero"],
                         scope_name=self.scope + "enc_seq_%s" % granularity,
                         reuse=reuse,
                         training=self.is_training)

        #### attend
        feature_dim = self.config["encode_dim"]
        print("==semantic feature dim==", feature_dim, enc_seq.get_shape())
        context = None
        att_seq = attend(enc_seq,
                         context=context,
                         encode_dim=self.config["encode_dim"],
                         feature_dim=feature_dim,
                         attention_dim=self.config["attention_dim"],
                         method=self.config["attend_method"],
                         scope_name=self.scope + "att_seq_%s" % granularity,
                         reuse=reuse,
                         num_heads=self.config["attention_num_heads"])
        print("==semantic layer attention seq shape==", att_seq.get_shape())

        #### MLP nonlinear projection
        sem_seq = mlp_layer(att_seq,
                            fc_type=self.config["fc_type"],
                            hidden_units=self.config["fc_hidden_units"],
                            dropouts=self.config["fc_dropouts"],
                            scope_name=self.scope + "sem_seq_%s" % granularity,
                            reuse=reuse,
                            training=self.is_training,
                            seed=self.config["random_seed"])
        print("==semantic layer mlp seq shape==", sem_seq.get_shape())

        return emb_seq, enc_seq, att_seq, sem_seq
def build_encoder(self, index, input_lengths, input_mask, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb = self.build_emebdding(index, *args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    word_emb = tf.nn.dropout(word_emb, 1 - dropout_rate)

    with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        sent_repres = match_utils.multi_highway_layer(word_emb, input_dim,
                                                      self.config.highway_layer_num)

        # bi-LSTM contextualization followed by a residual feed-forward layer
        [_, _, sent_repres] = layer_utils.my_lstm_layer(
            sent_repres,
            self.config.context_lstm_dim,
            input_lengths=input_lengths,
            scope_name=self.config.scope,
            reuse=reuse,
            is_training=self.is_training,
            dropout_rate=dropout_rate,
            use_cudnn=self.config.use_cudnn)
        sent_repres = tf.layers.dense(sent_repres,
                                      self.config.context_lstm_dim * 2,
                                      activation=tf.nn.relu) + sent_repres

        # self-attention block with padded positions masked out of the attention logits
        ignore_padding = (1 - input_mask)
        ignore_padding = decathlon_utils.attention_bias_ignore_padding(ignore_padding)
        encoder_self_attention_bias = ignore_padding

        output = decathlon_utils.multihead_attention_texar(
            sent_repres,
            memory=None,
            memory_attention_bias=encoder_self_attention_bias,
            num_heads=self.config.num_heads,
            num_units=None,
            dropout_rate=dropout_rate,
            scope="multihead_attention")
        output = tf.layers.dense(output,
                                 self.config.context_lstm_dim * 2,
                                 activation=tf.nn.relu) + output
        output = qanet_layers.layer_norm(output, scope="layer_norm", reuse=reuse)

        return output
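# A minimal sketch of the padding bias that decathlon_utils.attention_bias_ignore_padding is
# assumed to produce (the usual tensor2tensor convention): padded positions receive a large
# negative value that is added to the attention logits so the softmax ignores them. The
# function name and reshape below are illustrative, not the repo's API.
import tensorflow as tf

def attention_bias_ignore_padding_sketch(ignore_padding):
    """ignore_padding: [B, L] float, 1.0 where the token is padding."""
    bias = ignore_padding * -1e9
    # broadcastable against attention logits of shape [B, num_heads, L_query, L_memory]
    return tf.expand_dims(tf.expand_dims(bias, axis=1), axis=1)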
def build_encoder(self, index, input_lengths, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb = self.build_emebdding(index, *args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    word_emb = tf.nn.dropout(word_emb, 1 - dropout_rate)

    with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        sent_repres = match_utils.multi_highway_layer(word_emb, input_dim,
                                                      self.config.highway_layer_num)

        if self.config.rnn == "lstm":
            [sent_repres_fw, sent_repres_bw, sent_repres] = layer_utils.my_lstm_layer(
                sent_repres,
                self.config.context_lstm_dim,
                input_lengths=input_lengths,
                scope_name=self.config.scope,
                reuse=reuse,
                is_training=self.is_training,
                dropout_rate=dropout_rate,
                use_cudnn=self.config.use_cudnn)
        elif self.config.rnn == "slstm":
            word_emb_proj = tf.layers.dense(word_emb, self.config.slstm_hidden_size)
            initial_hidden_states = word_emb_proj
            initial_cell_states = tf.identity(initial_hidden_states)
            [new_hidden_states, new_cell_states, dummynode_hidden_states] = slstm_utils.slstm_cell(
                self.config,
                self.config.scope,
                self.config.slstm_hidden_size,
                input_lengths,
                initial_hidden_states,
                initial_cell_states,
                self.config.slstm_layer_num,
                dropout_rate,
                reuse=reuse)
            sent_repres = new_hidden_states

    return sent_repres
def build_encoder(self, input_lengths, input_mask, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb = self.build_emebdding(*args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    word_emb = tf.nn.dropout(word_emb, 1 - dropout_rate)

    with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        sent_repres = match_utils.multi_highway_layer(word_emb, input_dim,
                                                      self.config.highway_layer_num)

        # zero out padded positions before the convolutional / attentive encoders
        mask = tf.expand_dims(input_mask, -1)
        sent_repres *= tf.cast(mask, tf.float32)

        if self.config.encoder_type == "textcnn":
            output = textcnn_utils.text_cnn(
                sent_repres, [1, 3, 5, 7],
                "textcnn",
                self.emb_size,
                self.config.num_filters,
                max_pool_size=self.config.max_pool_size)
        elif self.config.encoder_type == "multi_head_attn":
            output = textcnn_utils.self_attn(
                sent_repres,
                input_mask,
                self.scope + "_" + self.config.encoder_type,
                dropout_rate,
                self.config,
                reuse=None)
            output = textcnn_utils.task_specific_attention(
                output,
                2 * self.config.context_lstm_dim,
                input_mask,
                scope=self.scope + "_self_attention")

        print("output shape====", output.get_shape())
        return output
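# A minimal sketch of the multi-width convolutional encoder that textcnn_utils.text_cnn is
# assumed to implement (Kim-style text CNN): one convolution branch per filter width,
# max-pooled over time and concatenated. Names and defaults below are illustrative.
import tensorflow as tf

def text_cnn_sketch(sent_repres, filter_sizes=(1, 3, 5, 7), num_filters=128, scope="textcnn"):
    """sent_repres: [B, L, D] with padding already zeroed; returns [B, len(filter_sizes) * num_filters]."""
    pooled = []
    with tf.variable_scope(scope):
        for width in filter_sizes:
            conv = tf.layers.conv1d(sent_repres, num_filters, width,
                                    padding="same", activation=tf.nn.relu,
                                    name="conv_width_%d" % width)
            # global max-pooling over the time axis
            pooled.append(tf.reduce_max(conv, axis=1))
    return tf.concat(pooled, axis=-1)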
def build_encoder(self, input_lengths, input_mask, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb = self.build_emebdding(*args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    word_emb = tf.nn.dropout(word_emb, 1 - dropout_rate)

    with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        sent_repres = match_utils.multi_highway_layer(word_emb, input_dim,
                                                      self.config.highway_layer_num)

        if self.config.rnn == "lstm":
            [sent_repres_fw, sent_repres_bw, sent_repres] = layer_utils.my_lstm_layer(
                sent_repres,
                self.config.context_lstm_dim,
                input_lengths=input_lengths,
                scope_name=self.config.scope,
                reuse=reuse,
                is_training=self.is_training,
                dropout_rate=dropout_rate,
                use_cudnn=self.config.use_cudnn)
            match_dim = self.config.context_lstm_dim * 6
        elif self.config.rnn == "slstm":
            word_emb_proj = tf.layers.dense(word_emb, self.config.slstm_hidden_size)
            initial_hidden_states = word_emb_proj
            initial_cell_states = tf.identity(initial_hidden_states)
            [new_hidden_states, new_cell_states, dummynode_hidden_states] = slstm_utils.slstm_cell(
                self.config,
                self.config.scope,
                self.config.slstm_hidden_size,
                input_lengths,
                initial_hidden_states,
                initial_cell_states,
                self.config.slstm_layer_num,
                dropout_rate,
                reuse=reuse)
            sent_repres = new_hidden_states
            match_dim = self.config.slstm_hidden_size * 3

        if self.config.multi_head:
            # optional transformer-style self-attention on top of the recurrent output
            mask = tf.cast(input_mask, tf.float32)
            ignore_padding = (1 - mask)
            ignore_padding = label_network_utils.attention_bias_ignore_padding(ignore_padding)
            encoder_self_attention_bias = ignore_padding
            sent_repres = label_network_utils.multihead_attention_texar(
                sent_repres,
                memory=None,
                memory_attention_bias=encoder_self_attention_bias,
                num_heads=8,
                num_units=128,
                dropout_rate=dropout_rate,
                scope="multihead_attention")

        # pool the token representations four ways: attention, mean, max, last
        v_attn = self_attn.multi_dimensional_attention(
            sent_repres, input_mask,
            'multi_dim_attn_for_%s' % self.config.scope,
            1 - dropout_rate, self.is_training,
            self.config.weight_decay, "relu")

        mask = tf.expand_dims(input_mask, -1)
        v_sum = tf.reduce_sum(sent_repres * tf.cast(mask, tf.float32), 1)
        v_ave = tf.div(v_sum,
                       tf.expand_dims(tf.cast(input_lengths, tf.float32) + EPSILON, -1))
        v_max = tf.reduce_max(qanet_layers.mask_logits(sent_repres, mask), axis=1)
        v_last = esim_utils.last_relevant_output(sent_repres, input_lengths)

        out = tf.concat([v_ave, v_max, v_last, v_attn], axis=-1)

    return out, match_dim
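# A minimal sketch of the mean / max / last pooling concatenated above (the attention pooling
# from self_attn.multi_dimensional_attention is omitted). The helper name and the small
# epsilon constant are illustrative, not the repo's definitions.
import tensorflow as tf

def masked_pooling_sketch(sent_repres, input_mask, input_lengths):
    """sent_repres: [B, L, D], input_mask: [B, L] float, input_lengths: [B] int."""
    mask = tf.expand_dims(input_mask, -1)                                   # [B, L, 1]
    v_sum = tf.reduce_sum(sent_repres * mask, axis=1)
    v_ave = v_sum / (tf.expand_dims(tf.cast(input_lengths, tf.float32), -1) + 1e-8)
    # push padded positions to a very small value before taking the max
    v_max = tf.reduce_max(sent_repres + (1.0 - mask) * -1e30, axis=1)
    # gather the last valid timestep of each sequence
    batch_idx = tf.range(tf.shape(sent_repres)[0])
    last_idx = tf.stack([batch_idx, tf.cast(input_lengths, tf.int32) - 1], axis=1)
    v_last = tf.gather_nd(sent_repres, last_idx)
    return tf.concat([v_ave, v_max, v_last], axis=-1)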
def build_interactor(self, sent1_emb, sent2_emb, sent1_len, sent2_len,
                     sent1_mask, sent2_mask, *args, **kargs):
    num_lstm_layers = kargs["num_lstm_layers"]
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    input_dim = sent1_emb.get_shape()[-1]

    with tf.variable_scope(self.config.scope + "_embed_hishway"):
        sent1_repres = match_utils.multi_highway_layer(sent1_emb, input_dim,
                                                       self.config.highway_layer_num)
        tf.get_variable_scope().reuse_variables()
        sent2_repres = match_utils.multi_highway_layer(sent2_emb, input_dim,
                                                       self.config.highway_layer_num)

    match_dim = self.emb_size
    # densely connected recurrent + co-attentive blocks (DRCN-style)
    for i in range(num_lstm_layers):
        with tf.variable_scope(self.config.scope + "_densely_co_attentive_{}".format(i),
                               reuse=None):
            sent1_repres_, match_dim_ = self.build_encoder(sent1_repres, sent1_len, reuse=None)
            sent2_repres_, match_dim_ = self.build_encoder(sent2_repres, sent2_len, reuse=True)
            match_dim += match_dim_
            print("===before=====", i, sent1_repres_.get_shape(), sent2_repres_.get_shape())

            if self.config.get("co_attention", None):
                [query_attention, context_attention] = drcn_utils.query_context_attention(
                    sent1_repres_, sent2_repres_,
                    sent1_len, sent2_len,
                    sent1_mask, sent2_mask,
                    dropout_rate,
                    self.config.scope,
                    reuse=None)
                # dense connection: concatenate the new features onto the running representation
                sent1_repres = tf.concat([sent1_repres_, query_attention, sent1_repres], axis=-1)
                sent2_repres = tf.concat([sent2_repres_, context_attention, sent2_repres], axis=-1)
                match_dim += match_dim_
            else:
                sent1_repres = tf.concat([sent1_repres_, sent1_repres], axis=-1)
                sent2_repres = tf.concat([sent2_repres_, sent2_repres], axis=-1)
            print("====i====", sent1_repres.get_shape(), sent2_repres.get_shape())

            # bottleneck the growing representation with an auto-encoder every other layer
            if np.mod(i + 1, 2) == 0 and self.config.with_auto_encoding:
                sent1_repres = self.auto_encoder(sent1_repres, reuse=None)
                sent2_repres = self.auto_encoder(sent2_repres, reuse=True)

            if self.config.recurrent_layer_norm:
                sent1_repres = tf.contrib.layers.layer_norm(sent1_repres, reuse=None,
                                                            scope="lstm_layer_norm")
                sent2_repres = tf.contrib.layers.layer_norm(sent2_repres, reuse=True,
                                                            scope="lstm_layer_norm")

    # masked max-pooling over time for both sentences, then standard matching features
    mask_q = tf.expand_dims(sent1_mask, -1)
    mask_c = tf.expand_dims(sent2_mask, -1)
    v_1_max = tf.reduce_max(qanet_layers.mask_logits(sent1_repres, mask_q), axis=1)
    v_2_max = tf.reduce_max(qanet_layers.mask_logits(sent2_repres, mask_c), axis=1)

    v = tf.concat([v_1_max, v_2_max,
                   v_1_max * v_2_max,
                   v_1_max - v_2_max,
                   tf.abs(v_1_max - v_2_max)],
                  axis=-1)
    v = tf.nn.dropout(v, 1 - dropout_rate)
    match_dim = match_dim * 5

    return v_1_max, v_2_max, v, match_dim
def _interaction_semantic_feature_layer(self, seq_input_left, seq_input_right,
                                        seq_len_left, seq_len_right,
                                        granularity="word"):
    assert granularity in ["char", "word"]
    #### embed
    with self.graph.as_default():
        emb_seq_left = tf.nn.embedding_lookup(self.emb_mat, seq_input_left)
        dropout_rate = tf.cond(self.is_training,
                               lambda: self.config.dropout_rate,
                               lambda: 0.0)
        input_dim = emb_seq_left.get_shape()[-1]
        with tf.variable_scope(self.config.scope + "_input_highway", reuse=False):
            emb_seq_left = match_utils.multi_highway_layer(emb_seq_left, input_dim,
                                                           self.config.highway_layer_num)
        emb_seq_left = tf.nn.dropout(emb_seq_left, 1 - dropout_rate)

        emb_seq_right = tf.nn.embedding_lookup(self.emb_mat, seq_input_right)
        input_dim = emb_seq_right.get_shape()[-1]
        with tf.variable_scope(self.config.scope + "_input_highway", reuse=True):
            emb_seq_right = match_utils.multi_highway_layer(emb_seq_right, input_dim,
                                                            self.config.highway_layer_num)
        emb_seq_right = tf.nn.dropout(emb_seq_right, 1 - dropout_rate)

        #### encode
        input_dim = self.emb_size  # self.config["embedding_dim"]
        enc_seq_left = encode(emb_seq_left,
                              method=self.config["encode_method"],
                              input_dim=input_dim,
                              params=self.config,
                              sequence_length=seq_len_left,
                              mask_zero=self.config["embedding_mask_zero"],
                              scope_name=self.scope + "enc_seq_%s" % granularity,
                              reuse=False,
                              training=self.is_training)
        enc_seq_right = encode(emb_seq_right,
                               method=self.config["encode_method"],
                               input_dim=input_dim,
                               params=self.config,
                               sequence_length=seq_len_right,
                               mask_zero=self.config["embedding_mask_zero"],
                               scope_name=self.scope + "enc_seq_%s" % granularity,
                               reuse=True,
                               training=self.is_training)

        #### attend
        # interaction matrix: [batchsize, s1, s2]
        att_mat = tf.einsum("abd,acd->abc", enc_seq_left, enc_seq_right)
        feature_dim = self.config["encode_dim"] + self.config["max_seq_len_%s" % granularity]
        att_seq_left = attend(enc_seq_left,
                              context=att_mat,
                              feature_dim=feature_dim,
                              method=self.config["attend_method"],
                              scope_name=self.scope + "att_seq_%s" % granularity,
                              reuse=False)
        att_seq_right = attend(enc_seq_right,
                               context=tf.transpose(att_mat, [0, 2, 1]),
                               feature_dim=feature_dim,
                               method=self.config["attend_method"],
                               scope_name=self.scope + "att_seq_%s" % granularity,
                               reuse=True)

        #### MLP nonlinear projection
        sem_seq_left = mlp_layer(att_seq_left,
                                 fc_type=self.config["fc_type"],
                                 hidden_units=self.config["fc_hidden_units"],
                                 dropouts=self.config["fc_dropouts"],
                                 scope_name=self.scope + "sem_seq_%s" % granularity,
                                 reuse=False,
                                 training=self.is_training,
                                 seed=self.config["random_seed"])
        sem_seq_right = mlp_layer(att_seq_right,
                                  fc_type=self.config["fc_type"],
                                  hidden_units=self.config["fc_hidden_units"],
                                  dropouts=self.config["fc_dropouts"],
                                  scope_name=self.scope + "sem_seq_%s" % granularity,
                                  reuse=True,
                                  training=self.is_training,
                                  seed=self.config["random_seed"])

        return emb_seq_left, enc_seq_left, att_seq_left, sem_seq_left, \
               emb_seq_right, enc_seq_right, att_seq_right, sem_seq_right
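# A minimal sketch of the dot-product co-attention behind the einsum above: the interaction
# matrix att_mat[b, i, j] scores token i of the left sequence against token j of the right one,
# and a row / column softmax turns it into soft alignments. Padding masks are omitted for
# brevity; the function name is illustrative.
import tensorflow as tf

def co_attention_sketch(enc_seq_left, enc_seq_right):
    """enc_seq_left: [B, L1, D], enc_seq_right: [B, L2, D]."""
    att_mat = tf.einsum("abd,acd->abc", enc_seq_left, enc_seq_right)          # [B, L1, L2]
    # each left token attends over the right sequence, and vice versa
    left_aligned = tf.matmul(tf.nn.softmax(att_mat, axis=2), enc_seq_right)   # [B, L1, D]
    right_aligned = tf.matmul(tf.nn.softmax(att_mat, axis=1), enc_seq_left,
                              transpose_a=True)                               # [B, L2, D]
    return att_mat, left_aligned, right_aligned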
def build_interactor(self, sent1_repres, sent2_repres, sent1_len, sent2_len,
                     sent1_mask, sent2_mask, *args, **kargs):
    reuse = kargs["reuse"]
    input_dim = sent1_repres.get_shape()[-1]
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)

    with tf.variable_scope(self.config.scope + "_interaction_module", reuse=reuse):
        # four flavours of cross attention: concat, bilinear, dot and minus
        [c2q_concat, q2c_concat] = man_utils.concat_attention(
            sent1_repres, sent2_repres, sent1_len, sent2_len,
            sent1_mask, sent2_mask, dropout_rate,
            self.config.scope, reuse=reuse)
        [c2q_bilinear, q2c_bilinear] = man_utils.bilinear_attention(
            sent1_repres, sent2_repres, sent1_len, sent2_len,
            sent1_mask, sent2_mask, dropout_rate,
            self.config.scope, reuse=reuse)
        [c2q_dot, q2c_dot] = man_utils.dot_attention(
            sent1_repres, sent2_repres, sent1_len, sent2_len,
            sent1_mask, sent2_mask, dropout_rate,
            self.config.scope, reuse=reuse)
        [c2q_minus, q2c_minus] = man_utils.minus_attention(
            sent1_repres, sent2_repres, sent1_len, sent2_len,
            sent1_mask, sent2_mask, dropout_rate,
            self.config.scope, reuse=reuse)

        sent1_agg = tf.concat([sent1_repres, c2q_concat, c2q_bilinear, c2q_dot, c2q_minus],
                              axis=-1)
        sent1_agg_dim = self.config.context_lstm_dim * 10
        sent2_agg = tf.concat([sent2_repres, q2c_concat, q2c_bilinear, q2c_dot, q2c_minus],
                              axis=-1)
        sent2_agg_dim = self.config.context_lstm_dim * 10

        with tf.variable_scope(self.config.scope + "_inner_highway", reuse=None):
            sent1_agg = match_utils.multi_highway_layer(sent1_agg, sent1_agg_dim, 1,
                                                        scope="sent_attention_highway")
            tf.get_variable_scope().reuse_variables()
            sent2_agg = match_utils.multi_highway_layer(sent2_agg, sent2_agg_dim, 1,
                                                        scope="sent_attention_highway")

        # aggregate the attended features with a shared bi-LSTM
        [_, _, sent1_agg] = layer_utils.my_lstm_layer(
            sent1_agg, self.config.context_lstm_dim,
            input_lengths=sent1_len,
            scope_name="inner_aggeration",
            reuse=False,
            is_training=self.is_training,
            dropout_rate=dropout_rate,
            use_cudnn=self.config.use_cudnn)
        [_, _, sent2_agg] = layer_utils.my_lstm_layer(
            sent2_agg, self.config.context_lstm_dim,
            input_lengths=sent2_len,
            scope_name="inner_aggeration",
            reuse=True,
            is_training=self.is_training,
            dropout_rate=dropout_rate,
            use_cudnn=self.config.use_cudnn)

        with tf.variable_scope(self.config.scope + "_predictor_self_attention", reuse=None):
            context_attn = man_utils.self_attention(
                sent1_repres, sent2_agg, sent1_len, sent2_len,
                sent1_mask, sent2_mask, dropout_rate,
                self.config.scope, reuse=None)
            tf.get_variable_scope().reuse_variables()
            query_attn = man_utils.self_attention(
                sent2_repres, sent1_agg, sent2_len, sent1_len,
                sent2_mask, sent1_mask, dropout_rate,
                self.config.scope, reuse=None)

        aggre_output = tf.concat([context_attn, query_attn,
                                  tf.abs(context_attn - query_attn),
                                  context_attn * query_attn],
                                 axis=-1)
        match_dim = self.config.context_lstm_dim * 2 * 4

    return context_attn, query_attn, aggre_output, match_dim
def build_encoder(self, index, input_lengths, input_mask, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb = self.build_emebdding(index, *args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    word_emb = tf.nn.dropout(word_emb, 1 - dropout_rate)

    with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        sent_repres = match_utils.multi_highway_layer(word_emb, input_dim,
                                                      self.config.highway_layer_num)

        if self.config.rnn == "lstm":
            [sent_repres_fw, sent_repres_bw, sent_repres] = layer_utils.my_lstm_layer(
                sent_repres,
                self.config.context_lstm_dim,
                input_lengths=input_lengths,
                scope_name=self.config.scope,
                reuse=reuse,
                is_training=self.is_training,
                dropout_rate=dropout_rate,
                use_cudnn=self.config.use_cudnn)
        elif self.config.rnn == "slstm":
            word_emb_proj = tf.layers.dense(word_emb, self.config.slstm_hidden_size)
            initial_hidden_states = word_emb_proj
            initial_cell_states = tf.identity(initial_hidden_states)
            [new_hidden_states, new_cell_states, dummynode_hidden_states] = slstm_utils.slstm_cell(
                self.config,
                self.config.scope,
                self.config.slstm_hidden_size,
                input_lengths,
                initial_hidden_states,
                initial_cell_states,
                self.config.slstm_layer_num,
                dropout_rate,
                reuse=reuse)
            sent_repres = new_hidden_states
        elif self.config.rnn == "base_transformer":
            sent_repres = base_transformer_utils.transformer_encoder(
                sent_repres,
                target_space=None,
                hparams=self.config,
                features=None,
                make_image_summary=False)
        elif self.config.rnn == "universal_transformer":
            sent_repres, act_loss = universal_transformer_utils.universal_transformer_encoder(
                sent_repres,
                target_space=None,
                hparams=self.config,
                features=None,
                make_image_summary=False)
        elif self.config.rnn == "highway":
            # pass the highway output through unchanged
            sent_repres = sent_repres

        # masked sum / mean pooling over the sequence dimension
        input_mask = tf.expand_dims(tf.cast(input_mask, tf.float32), axis=-1)
        sent_repres_sum = tf.reduce_sum(sent_repres * input_mask, axis=1)
        sent_repres_avr = tf.div(
            sent_repres_sum,
            tf.expand_dims(tf.cast(input_lengths, tf.float32) + EPSILON, -1))

        if self.config.metric == "Hyperbolic":
            # keep the summed sentence vector strictly inside the unit (Poincare) ball
            sent_repres = tf.clip_by_norm(sent_repres_sum, 1.0 - EPSILON, axes=1)
        else:
            sent_repres = sent_repres_avr

    if self.config.rnn == "universal_transformer":
        return sent_repres, act_loss
    else:
        return sent_repres
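# When config.metric == "Hyperbolic", the clipped sentence vectors are presumably compared in
# hyperbolic space; a common choice for that is the Poincare-ball distance, sketched below
# under that assumption. The repo's actual matching metric is defined elsewhere, and the
# function name and epsilon constant here are illustrative.
import tensorflow as tf

def poincare_distance_sketch(u, v):
    """u, v: [B, D] vectors with L2 norm < 1 (e.g. after tf.clip_by_norm)."""
    diff_sq = tf.reduce_sum(tf.square(u - v), axis=-1)
    u_sq = tf.reduce_sum(tf.square(u), axis=-1)
    v_sq = tf.reduce_sum(tf.square(v), axis=-1)
    x = 1.0 + 2.0 * diff_sq / ((1.0 - u_sq) * (1.0 - v_sq) + 1e-8)
    # arccosh(x) = log(x + sqrt(x^2 - 1))
    return tf.log(x + tf.sqrt(tf.maximum(tf.square(x) - 1.0, 0.0)))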
def build_encoder(self, input_lengths, input_mask, *args, **kargs):
    reuse = kargs["reuse"]
    word_emb, entity_emb = self.build_emebdding(*args, **kargs)
    dropout_rate = tf.cond(self.is_training,
                           lambda: self.config.dropout_rate,
                           lambda: 0.0)
    word_emb = tf.nn.dropout(word_emb, 1 - dropout_rate)

    with tf.variable_scope(self.config.scope + "_input_highway", reuse=reuse):
        input_dim = word_emb.get_shape()[-1]
        sent_repres = match_utils.multi_highway_layer(word_emb, input_dim,
                                                      self.config.highway_layer_num)

        # zero out padded positions
        mask = tf.expand_dims(input_mask, -1)
        sent_repres *= tf.cast(mask, tf.float32)

        # contextualize with a bi-LSTM; alternative encoders (label_network_utils.self_attn /
        # text_cnn) and an entity_emb-augmented label memory are left disabled in this version
        [sent_repres_fw, sent_repres_bw, sent_repres] = layer_utils.my_lstm_layer(
            sent_repres,
            self.config.context_lstm_dim,
            input_lengths=input_lengths,
            scope_name=self.config.scope,
            reuse=reuse,
            is_training=self.is_training,
            dropout_rate=dropout_rate,
            use_cudnn=self.config.use_cudnn)
        match_dim = self.config.context_lstm_dim * 8

    with tf.variable_scope(self.config.scope + "sent_label_attention", reuse=reuse):
        # label-embedding memory, tiled per batch example: batch x classes x dim
        memory = tf.expand_dims(self.memory, axis=0)
        memory = tf.tile(memory, [tf.shape(sent_repres)[0], 1, 1])
        print("==memory shape==", memory.get_shape())
        print(sent_repres.get_shape(), memory.get_shape())

        output = label_network_utils.memory_attention_v1(
            sent_repres,
            memory,
            input_mask,
            "memory_attention",
            memory_mask=None,
            reuse=None,
            attention_output="multi_head",
            num_heads=4,
            dropout_rate=dropout_rate,
            threshold=1 / float(self.num_classes),
            apply_hard_attn=True)
        print("==output shape==", output.get_shape())

    return sent_repres, entity_emb, output
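# A minimal sketch of what a label-memory attention like label_network_utils.memory_attention_v1
# is assumed to compute: each class embedding attends over the (masked) token states, yielding
# one attended vector per class. The multi-head projection, hard attention and thresholding of
# the real helper are omitted; all names below are illustrative.
import tensorflow as tf

def label_memory_attention_sketch(sent_repres, memory, input_mask):
    """sent_repres: [B, L, D], memory: [B, K, D] (tiled class embeddings), input_mask: [B, L] float."""
    logits = tf.matmul(memory, sent_repres, transpose_b=True)              # [B, K, L]
    logits += tf.expand_dims((1.0 - input_mask) * -1e30, axis=1)           # hide padded tokens
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, sent_repres)                                 # [B, K, D]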