def bi_attention(config, is_train, h, u, h_mask=None, u_mask=None, scope=None, tensor_dict=None):
    with tf.variable_scope(scope or "bi_attention"):
        JX = tf.shape(h)[2]
        M = tf.shape(h)[1]
        JQ = tf.shape(u)[1]
        h_aug = tf.tile(tf.expand_dims(h, 3), [1, 1, 1, JQ, 1])
        u_aug = tf.tile(tf.expand_dims(tf.expand_dims(u, 1), 1), [1, M, JX, 1, 1])
        if h_mask is None:
            hu_mask = None
        else:
            h_mask_aug = tf.tile(tf.expand_dims(h_mask, 3), [1, 1, 1, JQ])
            u_mask_aug = tf.tile(tf.expand_dims(tf.expand_dims(u_mask, 1), 1), [1, M, JX, 1])
            hu_mask = h_mask_aug & u_mask_aug

        u_logits = get_logits([h_aug, u_aug], None, True, wd=config.wd, mask=hu_mask,
                              is_train=is_train, func=config.logit_func, scope='u_logits')  # [N, M, JX, JQ]
        u_a = softsel(u_aug, u_logits)  # [N, M, JX, d]
        h_a = softsel(h, tf.reduce_max(u_logits, 3))  # [N, M, d]
        h_a = tf.tile(tf.expand_dims(h_a, 2), [1, 1, JX, 1])

        if tensor_dict is not None:
            a_u = tf.nn.softmax(u_logits)  # [N, M, JX, JQ]
            a_h = tf.nn.softmax(tf.reduce_max(u_logits, 3))
            tensor_dict['a_u'] = a_u
            tensor_dict['a_h'] = a_h
            variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=tf.get_variable_scope().name)
            for var in variables:
                tensor_dict[var.name] = var
        return u_a, h_a
def bi_attention(config, is_train, h, u, h_mask=None, u_mask=None, scope=None, tensor_dict=None):
    """
    h_a: all of u attending on h -- chooses the element of h that best matches u.
         First creates the confusion (similarity) matrix between h and u, then takes the max
         of the attention logits over the u axis, and finally applies a softmax.
    u_a: each h attending on u.

    :param h: [N, M, JX, d]
    :param u: [N, JQ, d]
    :param h_mask: [N, M, JX]
    :param u_mask: [N, JQ]
    :param scope:
    :return: u_a [N, M, JX, d], h_a [N, M, JX, d] (None when config.bi is False)
    """
    with tf.variable_scope(scope or "bi_attention"):
        N, M, JX, JQ, d = config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.hidden_size
        JX = tf.shape(h)[2]
        h_aug = tf.tile(tf.expand_dims(h, 3), [1, 1, 1, JQ, 1])
        u_aug = tf.tile(tf.expand_dims(tf.expand_dims(u, 1), 1), [1, M, JX, 1, 1])
        if h_mask is None:
            and_mask = None
        else:
            h_mask_aug = tf.tile(tf.expand_dims(h_mask, 3), [1, 1, 1, JQ])
            u_mask_aug = tf.tile(tf.expand_dims(tf.expand_dims(u_mask, 1), 1), [1, M, JX, 1])
            and_mask = h_mask_aug & u_mask_aug

        u_logits = get_logits([h_aug, u_aug], None, True, wd=config.wd, mask=and_mask,
                              is_train=is_train, func=config.logit_func, scope='u_logits')  # [N, M, JX, JQ]
        u_a = softsel(u_aug, u_logits)  # [N, M, JX, d]

        if tensor_dict is not None:
            # a_h = tf.nn.softmax(h_logits)  # [N, M, JX]
            a_u = tf.nn.softmax(u_logits)  # [N, M, JX, JQ]
            # tensor_dict['a_h'] = a_h
            tensor_dict['a_u'] = a_u

        if config.bi:
            h_a = softsel(h, tf.reduce_max(u_logits, 3))  # [N, M, d]
            h_a = tf.tile(tf.expand_dims(h_a, 2), [1, 1, JX, 1])
        else:
            h_a = None
        return u_a, h_a
def bi_attention(config, is_train, h, u, h_mask=None, u_mask=None, scope=None, tensor_dict=None):
    # h = [N, M, JX, 2d], u = [N, JQ, 2d]
    # h_mask = [N, M, JX], u_mask = [N, JQ]
    with tf.variable_scope(scope or "bi_attention"):
        JX = tf.shape(h)[2]
        M = tf.shape(h)[1]
        JQ = tf.shape(u)[1]
        h_aug = tf.tile(tf.expand_dims(h, 3), [1, 1, 1, JQ, 1])  # [N, M, JX, JQ*, 2d]
        u_aug = tf.tile(tf.expand_dims(tf.expand_dims(u, 1), 1), [1, M, JX, 1, 1])  # [N, M*, JX*, JQ, 2d]
        if h_mask is None:  # no masks provided
            hu_mask = None
        else:  # build the joint mask
            h_mask_aug = tf.tile(tf.expand_dims(h_mask, 3), [1, 1, 1, JQ])  # [N, M, JX, JQ]
            u_mask_aug = tf.tile(tf.expand_dims(tf.expand_dims(u_mask, 1), 1), [1, M, JX, 1])  # [N, M, JX, JQ]
            hu_mask = h_mask_aug & u_mask_aug  # keep only positions where both Q and C have real tokens

        u_logits = get_logits([h_aug, u_aug], None, True, wd=config.wd, mask=hu_mask,
                              is_train=is_train, func=config.logit_func,
                              scope='u_logits')  # u_logits = [N, M, JX, JQ]; this is the similarity matrix
        u_a = softsel(u_aug, u_logits)  # [N, M, JX, d]
        h_a = softsel(h, tf.reduce_max(u_logits, 3))  # [N, M, d]
        h_a = tf.tile(tf.expand_dims(h_a, 2), [1, 1, JX, 1])

        if tensor_dict is not None:
            a_u = tf.nn.softmax(u_logits)  # [N, M, JX, JQ]
            a_h = tf.nn.softmax(tf.reduce_max(u_logits, 3))
            tensor_dict['a_u'] = a_u  # C2Q attention
            tensor_dict['a_h'] = a_h  # Q2C attention
            variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name)
            for var in variables:
                tensor_dict[var.name] = var
        return u_a, h_a
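# All of the bi_attention variants above lean on softsel, which is not reproduced in this
# section. As a point of reference, here is a minimal sketch of what it is assumed to
# compute: a softmax over the last axis of the logits, followed by a weighted sum of
# `target` along the attended axis (masking is assumed to have been applied inside
# get_logits and is omitted here).
import tensorflow as tf

def softsel_sketch(target, logits):
    """Minimal sketch of softsel(target, logits) (an assumption, not the repo's exact code):
    target [..., J, d], logits [..., J] -> attention-weighted sum [..., d]."""
    a = tf.nn.softmax(logits)                                     # attention weights over the last axis
    return tf.reduce_sum(tf.expand_dims(a, -1) * target, axis=-2)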
def self_attention(config, is_train, p, p_mask=None, scope=None):  # p: [N, L, 2d]
    with tf.variable_scope(scope or "self_attention"):
        JX = p.get_shape().as_list()[1]
        print(p.get_shape())
        p_aug_1 = tf.tile(tf.expand_dims(p, 2), [1, 1, JX, 1])
        p_aug_2 = tf.tile(tf.expand_dims(p, 1), [1, JX, 1, 1])  # [N, PL, HL, 2d]

        if p_mask is None:
            self_mask = None
        else:
            p_mask_aug_1 = tf.reduce_any(
                tf.cast(tf.tile(tf.expand_dims(p_mask, 2), [1, 1, JX, 1]), tf.bool), axis=3)
            p_mask_aug_2 = tf.reduce_any(
                tf.cast(tf.tile(tf.expand_dims(p_mask, 1), [1, JX, 1, 1]), tf.bool), axis=3)
            self_mask = p_mask_aug_1 & p_mask_aug_2
            print(self_mask.get_shape().as_list())

        h_logits = get_logits([p_aug_1, p_aug_2], None, True, wd=config.wd, mask=self_mask,
                              is_train=is_train, func='tri_linear', scope='h_logits')  # [N, PL, HL]
        self_att = softsel(p_aug_2, h_logits)
        return self_att
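# self_attention above fixes func='tri_linear'. The trilinear similarity from the BiDAF
# paper is alpha(h, u) = w^T [h; u; h*u]; the sketch below shows what get_logits is assumed
# to compute for that choice (mask, weight decay, and input dropout omitted). Both inputs
# must already be tiled to a common shape, as in the callers above.
import tensorflow as tf

def tri_linear_logits_sketch(h_aug, u_aug, scope="tri_linear_sketch"):
    """Hedged sketch of the trilinear similarity used by get_logits(..., func='tri_linear')."""
    d = h_aug.get_shape().as_list()[-1]
    with tf.variable_scope(scope):
        w = tf.get_variable("w", shape=[3 * d, 1], dtype=tf.float32)
        feats = tf.concat([h_aug, u_aug, h_aug * u_aug], axis=-1)   # [..., 3d]
        flat = tf.reshape(feats, [-1, 3 * d])
        logits = tf.matmul(flat, w)                                 # one score per position
        return tf.reshape(logits, tf.shape(h_aug)[:-1])             # drop the feature axis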
def bi_attention(config, is_train, h, u, h_mask=None, u_mask=None, scope=None, tensor_dict=None):
    """
    h_a: all of u attending on h -- chooses the element of h that best matches u.
         First creates the confusion (similarity) matrix between h and u, then takes the max
         of the attention logits over the u axis, and finally applies a softmax.
    u_a: each h attending on u.

    :param h: [N, M, JX, d]
    :param u: [N, JQ, d]
    :param h_mask: [N, M, JX]
    :param u_mask: [N, JQ]
    :param scope:
    :return: u_a [N, M, JX, d], h_a [N, M, JX, d]
    """
    with tf.variable_scope(scope or "bi_attention"):
        JX = tf.shape(h)[2]
        M = tf.shape(h)[1]
        JQ = tf.shape(u)[1]
        h_aug = tf.tile(tf.expand_dims(h, 3), [1, 1, 1, JQ, 1])
        u_aug = tf.tile(tf.expand_dims(tf.expand_dims(u, 1), 1), [1, M, JX, 1, 1])
        if h_mask is None:
            hu_mask = None
        else:
            h_mask_aug = tf.tile(tf.expand_dims(h_mask, 3), [1, 1, 1, JQ])
            u_mask_aug = tf.tile(tf.expand_dims(tf.expand_dims(u_mask, 1), 1), [1, M, JX, 1])
            hu_mask = h_mask_aug & u_mask_aug

        u_logits = get_logits([h_aug, u_aug], None, True, wd=config.wd, mask=hu_mask,
                              is_train=is_train, func=config.logit_func, scope='u_logits')  # [N, M, JX, JQ]
        u_a = softsel(u_aug, u_logits)  # [N, M, JX, d]
        h_a = softsel(h, tf.reduce_max(u_logits, 3))  # [N, M, d]
        h_a = tf.tile(tf.expand_dims(h_a, 2), [1, 1, JX, 1])

        if tensor_dict is not None:
            a_u = tf.nn.softmax(u_logits)  # [N, M, JX, JQ]
            a_h = tf.nn.softmax(tf.reduce_max(u_logits, 3))
            tensor_dict['a_u'] = a_u
            tensor_dict['a_h'] = a_h
            variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=tf.get_variable_scope().name)
            for var in variables:
                tensor_dict[var.name] = var
        return u_a, h_a
def bibi_attention(config, is_train, h, u, h_mask=None, u_mask=None, scope=None):  # [N, L, 2d]
    with tf.variable_scope(scope or "self_attention"):
        JX = h.get_shape().as_list()[1]  # basic
        JQ = u.get_shape().as_list()[1]
        p_aug_1 = tf.tile(tf.expand_dims(h, 2), [1, 1, JQ, 1])
        p_aug_2 = tf.tile(tf.expand_dims(u, 1), [1, JX, 1, 1])  # [N, PL, HL, 2d]

        if h_mask is None:
            hu_mask = None
        else:
            print(h_mask.get_shape().as_list())
            print(u_mask.get_shape().as_list())
            h_mask_aug = tf.cast(tf.tile(tf.expand_dims(h_mask, 2), [1, 1, JQ]), 'bool')
            u_mask_aug = tf.cast(tf.tile(tf.expand_dims(u_mask, 1), [1, JX, 1]), 'bool')
            print(h_mask_aug.get_shape().as_list())
            print(u_mask_aug.get_shape().as_list())
            hu_mask = h_mask_aug & u_mask_aug

        h_logits = get_logits([p_aug_1, p_aug_2], None, True, wd=config.wd, mask=hu_mask,
                              is_train=is_train, func='tri_linear', scope='h_logits')  # [N, JX, JQ]
        u_a = softsel(p_aug_2, h_logits)
        print("u:{} ".format(u_a.get_shape()))
        return u_a
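# The _build_forward variants below call attention_layer, which is not reproduced in this
# section. Following the BiDAF paper, it is assumed to fuse the context-to-query summary u_a
# and the query-to-context summary h_a returned by bi_attention into G = [h; u_a; h*u_a; h*h_a];
# a hedged sketch (the real helper may also gate each term behind config flags):
def attention_layer_sketch(config, is_train, h, u, h_mask=None, u_mask=None,
                           scope=None, tensor_dict=None):
    """Assumed behaviour of attention_layer: combine h [N, M, JX, 2d] with both attention
    summaries into G = [h; u_a; h * u_a; h * h_a], i.e. a [N, M, JX, 8d] tensor."""
    with tf.variable_scope(scope or "attention_layer"):
        u_a, h_a = bi_attention(config, is_train, h, u,
                                h_mask=h_mask, u_mask=u_mask, tensor_dict=tensor_dict)
        return tf.concat([h, u_a, h * u_a, h * h_a], axis=3)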
def _build_forward(self):
    config = self.config
    N, M, JX, JQ, VW, VC, d, W = \
        config.batch_size, config.max_num_sents, config.max_sent_size, \
        config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \
        config.max_word_size
    JX = tf.shape(self.x)[2]
    JQ = tf.shape(self.q)[1]
    M = tf.shape(self.x)[1]
    dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size

    with tf.variable_scope("emb"):
        if config.use_char_emb:  # compute the character embeddings
            with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float')
            with tf.variable_scope("char"):
                Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx)  # [N, M, JX, W, dc]
                Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq)  # [N, JQ, W, dc]
                Acx = tf.reshape(Acx, [-1, JX, W, dc])
                Acq = tf.reshape(Acq, [-1, JQ, W, dc])

                filter_sizes = list(map(int, config.out_channel_dims.split(',')))  # TODO What?
                heights = list(map(int, config.filter_heights.split(',')))
                assert sum(filter_sizes) == dco, (filter_sizes, dco)
                with tf.variable_scope("conv"):
                    xx = multi_conv1d(Acx, filter_sizes, heights, "VALID",
                                      self.is_train, config.keep_prob, scope="xx")
                    if config.share_cnn_weights:
                        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                            qq = multi_conv1d(Acq, filter_sizes, heights, "VALID",
                                              self.is_train, config.keep_prob, scope="xx")
                    else:
                        qq = multi_conv1d(Acq, filter_sizes, heights, "VALID",
                                          self.is_train, config.keep_prob, scope="qq")
                    xx = tf.reshape(xx, [-1, M, JX, dco])
                    qq = tf.reshape(qq, [-1, JQ, dco])

        if config.use_word_emb:
            with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                if config.mode == 'train':
                    word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, dw],
                                                   initializer=get_initializer(config.emb_mat))  # emb_mat is glove
                else:
                    word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float')
                if config.use_glove_for_unk:
                    word_emb_mat = tf.concat([word_emb_mat, self.new_emb_mat], 0)

            with tf.name_scope("word"):
                Ax = tf.nn.embedding_lookup(word_emb_mat, self.x)  # [N, M, JX, d]
                Aq = tf.nn.embedding_lookup(word_emb_mat, self.q)  # [N, JQ, d]
                self.tensor_dict['x'] = Ax
                self.tensor_dict['q'] = Aq
            if config.use_char_emb:
                xx = tf.concat([xx, Ax], 3)  # [N, M, JX, di]
                qq = tf.concat([qq, Aq], 2)  # [N, JQ, di]
            else:
                xx = Ax
                qq = Aq

    # highway network
    if config.highway:
        with tf.variable_scope("highway"):
            xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train)
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train)

    self.tensor_dict['xx'] = xx
    self.tensor_dict['qq'] = qq

    cell = BasicLSTMCell(d, state_is_tuple=True)
    d_cell = SwitchableDropoutWrapper(cell, self.is_train, input_keep_prob=config.input_keep_prob)
    x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2)  # [N, M]
    q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1)  # [N]

    with tf.variable_scope("prepro"):
        (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn(
            d_cell, d_cell, qq, q_len, dtype='float', scope='u1')  # [N, J, d], [N, d]
        u = tf.concat([fw_u, bw_u], 2)
        if config.share_lstm_weights:
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                (fw_h, bw_h), _ = bidirectional_dynamic_rnn(
                    cell, cell, xx, x_len, dtype='float', scope='u1')  # [N, M, JX, 2d] TODO JX == x_len?
                h = tf.concat([fw_h, bw_h], 3)  # [N, M, JX, 2d]
        else:
            (fw_h, bw_h), _ = bidirectional_dynamic_rnn(
                cell, cell, xx, x_len, dtype='float', scope='h1')  # [N, M, JX, 2d]
            h = tf.concat([fw_h, bw_h], 3)  # [N, M, JX, 2d]
        self.tensor_dict['u'] = u
        self.tensor_dict['h'] = h

    with tf.variable_scope("main"):
        if config.dynamic_att:
            p0 = h
            u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d])
            q_mask = tf.reshape(tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ])
            first_cell = AttentionCell(cell, u, mask=q_mask, mapper='sim',
                                       input_keep_prob=self.config.input_keep_prob, is_train=self.is_train)
        else:
            p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask,
                                 scope="p0", tensor_dict=self.tensor_dict)
            first_cell = d_cell

        (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(
            first_cell, first_cell, p0, x_len, dtype='float', scope='g0')  # [N, M, JX, 2d]
        g0 = tf.concat([fw_g0, bw_g0], 3)
        (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn(
            first_cell, first_cell, g0, x_len, dtype='float', scope='g1')  # [N, M, JX, 2d]
        g1 = tf.concat([fw_g1, bw_g1], 3)

        # Output Layer
        logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob,
                            mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1')
        a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX]))
        a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1])

        (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(
            d_cell, d_cell, tf.concat([p0, g1, a1i, g1 * a1i], 3),
            x_len, dtype='float', scope='g2')  # [N, M, JX, 2d]
        g2 = tf.concat([fw_g2, bw_g2], 3)
        logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob,
                             mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2')

        flat_logits = tf.reshape(logits, [-1, M * JX])
        flat_yp = tf.nn.softmax(flat_logits)  # [-1, M*JX]
        yp = tf.reshape(flat_yp, [-1, M, JX])
        flat_logits2 = tf.reshape(logits2, [-1, M * JX])
        flat_yp2 = tf.nn.softmax(flat_logits2)
        yp2 = tf.reshape(flat_yp2, [-1, M, JX])

        self.tensor_dict['g1'] = g1
        self.tensor_dict['g2'] = g2
        self.logits = flat_logits
        self.logits2 = flat_logits2
        self.yp = yp
        self.yp2 = yp2
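# The model above only exposes the start/end distributions yp and yp2; turning them into an
# answer span happens outside the graph and is not part of this snippet. A NumPy-only sketch
# of the usual decoding rule (maximize yp[start] * yp2[end] with start <= end within one
# sentence); shapes here are for a single example, i.e. yp, yp2 are [M, JX]:
import numpy as np

def best_span_sketch(yp, yp2, max_span_size=None):
    """Pick the in-sentence span (sentence, start, end) maximizing yp[start] * yp2[end]."""
    best, best_score = (0, 0, 0), -1.0
    M, JX = yp.shape
    for m in range(M):
        for i in range(JX):
            hi = JX if max_span_size is None else min(JX, i + max_span_size)
            for j in range(i, hi):
                score = yp[m, i] * yp2[m, j]
                if score > best_score:
                    best, best_score = (m, i, j), score
    return best, best_score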
def _build_forward(self): config = self.config N, M, JX, JQ, VW, VC, d, W = \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \ config.max_word_size JX = tf.shape(self.x)[2] JQ = tf.shape(self.q)[1] M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list( map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': word_emb_mat = tf.get_variable( "word_emb_mat", dtype='float', shape=[VW, dw], initializer=tf.random_normal_initializer) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: word_emb_mat = tf.concat( axis=0, values=[word_emb_mat, self.new_emb_mat]) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat(axis=3, values=[xx, Ax]) # [N, M, JX, di] qq = tf.concat(axis=2, values=[qq, Aq]) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq cell_fw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") cell_bw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") d_cell_fw = SwitchableDropoutWrapper( cell_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell_bw = SwitchableDropoutWrapper( cell_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell2_fw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") cell2_bw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") d_cell2_fw = SwitchableDropoutWrapper( cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell2_bw = SwitchableDropoutWrapper( cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell3_fw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") cell3_bw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") d_cell3_fw = SwitchableDropoutWrapper( cell3_fw, self.is_train, 
input_keep_prob=config.input_keep_prob) d_cell3_bw = SwitchableDropoutWrapper( cell3_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell4_fw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") cell4_bw = LSTMCell(d, state_is_tuple=True, name="basic_lstm_cell") d_cell4_fw = SwitchableDropoutWrapper( cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell4_bw = SwitchableDropoutWrapper( cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn( d_cell_fw, d_cell_bw, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat(axis=2, values=[fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): if config.dynamic_att: p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape( tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell_fw = AttentionCell( cell2_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) first_cell_bw = AttentionCell( cell2_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_fw = AttentionCell( cell3_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_bw = AttentionCell( cell3_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) first_cell_fw = d_cell2_fw second_cell_fw = d_cell3_fw first_cell_bw = d_cell2_bw second_cell_bw = d_cell3_bw (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn( first_cell_fw, first_cell_bw, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(axis=3, values=[fw_g0, bw_g0]) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn( second_cell_fw, second_cell_bw, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(axis=3, values=[fw_g1, bw_g1]) logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn( d_cell4_fw, d_cell4_bw, tf.concat(axis=3, values=[p0, g1, a1i, g1 * a1i]), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat(axis=3, values=[fw_g2, bw_g2]) logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] 
flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) if config.na: na_bias = tf.get_variable("na_bias", shape=[], dtype='float') na_bias_tiled = tf.tile(tf.reshape(na_bias, [1, 1]), [N, 1]) # [N, 1] concat_flat_logits = tf.concat( axis=1, values=[na_bias_tiled, flat_logits]) concat_flat_yp = tf.nn.softmax(concat_flat_logits) na_prob = tf.squeeze(tf.slice(concat_flat_yp, [0, 0], [-1, 1]), [1]) flat_yp = tf.slice(concat_flat_yp, [0, 1], [-1, -1]) concat_flat_logits2 = tf.concat( axis=1, values=[na_bias_tiled, flat_logits2]) concat_flat_yp2 = tf.nn.softmax(concat_flat_logits2) na_prob2 = tf.squeeze( tf.slice(concat_flat_yp2, [0, 0], [-1, 1]), [1]) # [N] flat_yp2 = tf.slice(concat_flat_yp2, [0, 1], [-1, -1]) self.concat_logits = concat_flat_logits self.concat_logits2 = concat_flat_logits2 self.na_prob = na_prob * na_prob2 yp = tf.reshape(flat_yp, [-1, M, JX], name="yp") yp2 = tf.reshape(flat_yp2, [-1, M, JX], name="yp2") wyp = tf.nn.sigmoid(logits2, name="wyp") self.tensor_dict['g1'] = g1 self.tensor_dict['g2'] = g2 self.logits = flat_logits self.logits2 = flat_logits2 self.yp = yp self.yp2 = yp2 self.wyp = wyp
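# This variant additionally exposes na_prob (the product of the two no-answer probabilities).
# How it is consumed is not shown in the snippet; a hypothetical prediction-time rule for a
# single example would be to abstain above a threshold and otherwise decode a span as
# sketched after the basic model above (the threshold value is illustrative only):
def predict_with_na_sketch(na_prob, yp, yp2, threshold=0.5):
    """Hypothetical use of the no-answer head."""
    if na_prob > threshold:
        return None                      # predict "no answer"
    return best_span_sketch(yp, yp2)     # fall back to span decoding (sketched above)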
def _build_forward(self): config = self.config x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] N, M, JX, JQ, VW, VC, d, W = \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \ config.max_word_size JX = tf.shape(self.x)[2] JQ = tf.shape(self.q)[1] M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list( map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': word_emb_mat = tf.get_variable( "word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(self.emb_mat)) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: word_emb_mat = tf.concat( axis=0, values=[word_emb_mat, self.new_emb_mat]) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat(axis=3, values=[xx, Ax]) # [N, M, JX, di] qq = tf.concat(axis=2, values=[qq, Aq]) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train, input_keep_prob=config.highway_keep_prob) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train, input_keep_prob=config.highway_keep_prob) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq with tf.variable_scope("prepro"): with tf.variable_scope('u1'): u, _ = bi_cudnn_rnn_encoder('lstm', config.hidden_size, 1, 1 - config.input_keep_prob, qq, q_len, self.is_train) if config.reasoning_layer == 'snmn': u_st = zhong_selfatt(u[:, ax, :, :], config.hidden_size * 2, seq_len=q_len, transform='squeeze') if config.share_lstm_weights: with tf.variable_scope('u1', reuse=True): h, _ = bi_cudnn_rnn_encoder('lstm', config.hidden_size, 1, 1 - config.input_keep_prob, tf.squeeze(xx, axis=1), tf.squeeze(x_len, axis=1), self.is_train) h = h[:, ax, :, :] else: with tf.variable_scope('h1'): h, _ = bi_cudnn_rnn_encoder('lstm', config.hidden_size, 1, 1 - config.input_keep_prob, 
tf.squeeze(xx, axis=1), tf.squeeze(x_len, axis=1), self.is_train) h = h[:, ax, :, :] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): context_dim = config.hidden_size * 2 ### Reconstruct before bidaf because otherwise we need to build a larger query tensor. x_mask = self.x_mask x_len_squeeze = tf.squeeze(x_len, axis=1) p0 = h ### Main model if config.reasoning_layer == 'snmn': module_names = ['_Find', '_Compare', '_Relocate', '_NoOp'] self.snmn = NMN_Model(config, u, qq, u_st, self.q_mask, q_len, p0, x_mask, x_len, module_names, \ self.is_train) self.u_weights = self.snmn.cv_list # question word distribution at each step self.module_prob_list = self.snmn.module_prob_list # module probability at each step g0 = tf.squeeze(self.snmn.att_second, axis=-1) if config.supervise_bridge_entity: self.hop0_logits = self.snmn.bridge_logits if config.self_att: with tf.variable_scope('g0'): g0, _ = bi_cudnn_rnn_encoder( 'lstm', config.hidden_size, 1, 1 - config.input_keep_prob, tf.squeeze(g0, axis=1), x_len_squeeze, self.is_train) g0 = g0[:, ax, :, :] g0 = hotpot_biattention(config, self.is_train, g0, tf.squeeze(g0, axis=1), h_mask=x_mask, u_mask=tf.squeeze(x_mask, axis=1), scope="self_att", tensor_dict=self.tensor_dict) g0 = tf.layers.dense(g0, config.hidden_size * 2) with tf.variable_scope('g1'): g1, _ = bi_cudnn_rnn_encoder('lstm', config.hidden_size, 1, 1 - config.input_keep_prob, tf.squeeze(g0, axis=1), tf.squeeze(x_len, axis=1), self.is_train) g1 = g1[:, ax, :, :] logits = get_logits([g1, g0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') with tf.variable_scope('g2'): a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(a1i[:, ax, ax, :], [1, M, JX, 1]) g2, _ = bi_cudnn_rnn_encoder( 'lstm', config.hidden_size, 1, 1 - config.input_keep_prob, tf.squeeze(tf.concat(axis=3, values=[g0, g1, a1i, g0 * a1i]), axis=1), x_len_squeeze, self.is_train) g2 = g2[:, ax, :, :] logits2 = get_logits([g2, g1], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') if config.dataset == 'hotpotqa': with tf.variable_scope('g3'): if config.nmn_qtype_class == 'mem_last': g3 = tf.concat( [self.snmn.mem_last[:, ax, :], u_st[:, ax, :]], axis=-1) elif config.nmn_qtype_class == 'ctrl_st': g3 = self.snmn.c_st_list[0][:, ax, :] else: raise NotImplementedError self.predict_type = dense(g3, 2, scope='predict_type') g3_1 = self.snmn.mem_last[:, ax, :] self.predict_yesno = dense(g3_1, 2, scope='predict_yesno') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M * JX] flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) yp = tf.reshape(flat_yp, [-1, M, JX]) yp2 = tf.reshape(flat_yp2, [-1, M, JX]) wyp = tf.nn.sigmoid(logits2) self.logits = flat_logits self.logits2 = flat_logits2 self.yp = yp self.yp2 = yp2 self.wyp = wyp if config.dataset == 'hotpotqa': flat_predict_type = tf.reshape(self.predict_type, [-1, 2]) flat_yp3 = tf.nn.softmax(flat_predict_type) self.yp3 = tf.reshape(flat_yp3, [-1, 1, 2]) flat_predict_yesno = tf.reshape(self.predict_yesno, [-1, 2]) flat_yp3_yesno = tf.nn.softmax(flat_predict_yesno) self.yp3_yesno = tf.reshape(flat_yp3_yesno, [-1, 1, 2])
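# For the HotpotQA branch above, the snippet stops at producing yp3 (question-type) and
# yp3_yesno (yes/no) distributions alongside the span heads. The decoding code is not shown;
# the sketch below is a purely hypothetical combination for one example, and the class-index
# conventions (1 = yes/no question, 0 = "yes") are assumptions, not taken from the source:
import numpy as np

def hotpotqa_answer_sketch(yp3, yp3_yesno, span_answer):
    """Hypothetical: let the type head choose between a yes/no answer and the decoded span."""
    if np.argmax(yp3[0]) == 1:                                   # assumed: class 1 = yes/no question
        return "yes" if np.argmax(yp3_yesno[0]) == 0 else "no"   # assumed: class 0 = "yes"
    return span_answer                                           # otherwise keep the extracted span text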
def _build_forward(self):
    # config holds the pre-configured hyperparameters etc.
    config = self.config
    N, M, JX, JQ, VW, VC, d, W = \
        config.batch_size, config.max_num_sents, config.max_sent_size, \
        config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \
        config.max_word_size
    JX = tf.shape(self.x)[2]
    JQ = tf.shape(self.q)[1]
    M = tf.shape(self.x)[1]
    dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size

    # Embedding layer
    with tf.variable_scope("emb"):
        # Character embedding layer
        if config.use_char_emb:  # if character embeddings are enabled
            with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float')
            with tf.variable_scope("char"):
                Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx)  # [N, M, JX, W, dc]
                Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq)  # [N, JQ, W, dc]
                Acx = tf.reshape(Acx, [-1, JX, W, dc])
                Acq = tf.reshape(Acq, [-1, JQ, W, dc])

                # CNN filter parameters
                filter_sizes = list(map(int, config.out_channel_dims.split(',')))
                heights = list(map(int, config.filter_heights.split(',')))
                assert sum(filter_sizes) == dco, (filter_sizes, dco)
                with tf.variable_scope("conv"):
                    xx = multi_conv1d(Acx, filter_sizes, heights, "VALID",
                                      self.is_train, config.keep_prob, scope="xx")
                    if config.share_cnn_weights:
                        tf.get_variable_scope().reuse_variables()
                        qq = multi_conv1d(Acq, filter_sizes, heights, "VALID",
                                          self.is_train, config.keep_prob, scope="xx")
                    else:
                        qq = multi_conv1d(Acq, filter_sizes, heights, "VALID",
                                          self.is_train, config.keep_prob, scope="qq")
                    xx = tf.reshape(xx, [-1, M, JX, dco])
                    qq = tf.reshape(qq, [-1, JQ, dco])

        # Word embedding layer
        if config.use_word_emb:
            with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                if config.mode == 'train':
                    word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, dw],
                                                   initializer=get_initializer(config.emb_mat))
                else:
                    word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float')
                if config.use_glove_for_unk:
                    # if using the pre-trained word embedding file
                    word_emb_mat = tf.concat(0, [word_emb_mat, self.new_emb_mat])

            with tf.name_scope("word"):
                # Convert the context x and the query q into word vectors.
                # embedding_lookup(params, ids) gathers row id of params for each entry of ids.
                Ax = tf.nn.embedding_lookup(word_emb_mat, self.x)  # [N, M, JX, d]
                Aq = tf.nn.embedding_lookup(word_emb_mat, self.q)  # [N, JQ, d]
                self.tensor_dict['x'] = Ax
                self.tensor_dict['q'] = Aq
            if config.use_char_emb:
                # if character embeddings were computed, concatenate them with the word embeddings
                # along the feature axis
                xx = tf.concat(3, [xx, Ax])  # [N, M, JX, di]
                qq = tf.concat(2, [qq, Aq])  # [N, JQ, di]
            else:
                xx = Ax
                qq = Aq

    # Pass through a two-layer highway network to get the context vector X in R^(d x T)
    # and the query vector Q in R^(d x J)
    if config.highway:
        with tf.variable_scope("highway"):
            xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train)
            tf.get_variable_scope().reuse_variables()
            qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train)

    self.tensor_dict['xx'] = xx
    self.tensor_dict['qq'] = qq

    cell = BasicLSTMCell(d, state_is_tuple=True)
    # SwitchableDropoutWrapper is a custom DropoutWrapper class
    d_cell = SwitchableDropoutWrapper(cell, self.is_train, input_keep_prob=config.input_keep_prob)
    # reduce_sum sums over the given axis (the number of non-padding tokens in x and q);
    # cast converts the tensor to the given type (here x_mask to int32)
    x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2)  # [N, M]
    q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1)  # [N]

    # Contextual Embedding Layer: run BiLSTMs over the X and Q from the previous layer
    # to capture the local interactions among the words within each of them
    with tf.variable_scope("prepro"):
        (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn(
            d_cell, d_cell, qq, q_len, dtype='float', scope='u1')  # [N, J, d], [N, d]
        # fw_u and bw_u are the outputs of the forward and backward LSTMs
        u = tf.concat(2, [fw_u, bw_u])  # [N, J, 2d]
        if config.share_lstm_weights:
            tf.get_variable_scope().reuse_variables()
            (fw_h, bw_h), _ = bidirectional_dynamic_rnn(
                cell, cell, xx, x_len, dtype='float', scope='u1')  # [N, M, JX, 2d]
            h = tf.concat(3, [fw_h, bw_h])  # [N, M, JX, 2d]
        else:
            (fw_h, bw_h), _ = bidirectional_dynamic_rnn(
                cell, cell, xx, x_len, dtype='float', scope='h1')  # [N, M, JX, 2d]
            h = tf.concat(3, [fw_h, bw_h])  # [N, M, JX, 2d]
        self.tensor_dict['u'] = u
        self.tensor_dict['h'] = h

    # Core layer: Attention Flow Layer
    with tf.variable_scope("main"):
        if config.dynamic_att:
            p0 = h
            # expand_dims() inserts a dimension at the given position;
            # tile() replicates the tensor along the given dimensions.
            # First insert a dimension at index 1, then tile it M times (the max number of
            # sentences in the context) so that u and h have compatible dimensions.
            u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d])
            q_mask = tf.reshape(tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ])
            first_cell = AttentionCell(cell, u, mask=q_mask, mapper='sim',
                                       input_keep_prob=self.config.input_keep_prob, is_train=self.is_train)
        else:
            p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask,
                                 scope="p0", tensor_dict=self.tensor_dict)
            first_cell = d_cell

        (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(
            first_cell, first_cell, p0, x_len, dtype='float', scope='g0')  # [N, M, JX, 2d]
        g0 = tf.concat(3, [fw_g0, bw_g0])
        (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn(
            first_cell, first_cell, g0, x_len, dtype='float', scope='g1')  # [N, M, JX, 2d]
        g1 = tf.concat(3, [fw_g1, bw_g1])

        logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob,
                            mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1')
        a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX]))
        a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1])

        (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(
            d_cell, d_cell, tf.concat(3, [p0, g1, a1i, g1 * a1i]),
            x_len, dtype='float', scope='g2')  # [N, M, JX, 2d]
        g2 = tf.concat(3, [fw_g2, bw_g2])
        logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob,
                             mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2')

        flat_logits = tf.reshape(logits, [-1, M * JX])
        flat_yp = tf.nn.softmax(flat_logits)  # [-1, M*JX]
        yp = tf.reshape(flat_yp, [-1, M, JX])
        flat_logits2 = tf.reshape(logits2, [-1, M * JX])
        flat_yp2 = tf.nn.softmax(flat_logits2)
        yp2 = tf.reshape(flat_yp2, [-1, M, JX])

        self.tensor_dict['g1'] = g1
        self.tensor_dict['g2'] = g2
        self.logits = flat_logits
        self.logits2 = flat_logits2
        self.yp = yp
        self.yp2 = yp2
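# highway_network is called in every variant but never defined in this section. Below is a
# sketch of one highway layer in its usual form, y = t * g(Wx + b) + (1 - t) * x with
# transform gate t = sigmoid(W_t x + b_t); the real helper also takes wd and is_train for
# weight decay and dropout, which are omitted here.
import tensorflow as tf

def highway_layer_sketch(x, scope="highway_layer_sketch"):
    """One highway layer over the last axis (assumed building block of highway_network)."""
    d = x.get_shape().as_list()[-1]
    with tf.variable_scope(scope):
        trans = tf.layers.dense(x, d, activation=tf.nn.relu, name="trans")    # g(Wx + b)
        gate = tf.layers.dense(x, d, activation=tf.nn.sigmoid, name="gate")   # t = sigmoid(W_t x + b_t)
        return gate * trans + (1.0 - gate) * x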
def _build_forward(self): config = self.config N = config.batch_size M = config.max_num_sents JX = config.max_sent_size JQ = config.max_ques_size VW = config.word_vocab_size VC = config.char_vocab_size W = config.max_word_size d = config.hidden_size JX = tf.shape(self.x)[2] # JX max sentence size, length, JQ = tf.shape(self.q)[1] # JQ max questions size, length, is the M = tf.shape(self.x)[1] # m is the max number of sentences dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size # dc = 8, each char will be map to 8-number vector, "char-level word embedding size [100]" with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') # 330,8 a matrix for each char to its 8-number vector with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] 60,None,None,16,8, batch-size, # N is the number of batch_size # M the max number of sentences # JX is the max sentence length # W is the max length of a word # dc is the vector for each char # map each char to a vector Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] # JQ the max length of question # W the max length of words # mao each char in questiosn to vectors Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) # max questions size, length, max_word_size(16), char_emb_size(8) filter_sizes = list( map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) # so here, there are 100 filters and the size of each filter is 5 # different heights and there are different number of these filter, but here just 100 5-long filters assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape( qq, [-1, JQ, dco ]) # here, xx and qq are the output of cnn, if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': word_emb_mat = tf.get_variable( "word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat)) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: # create a new word embedding or use the glove? 
word_emb_mat = tf.concat( [word_emb_mat, self.new_emb_mat], 0) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat([xx, Ax], 3) # [N, M, JX, di] qq = tf.concat([qq, Aq], 2) # [N, JQ, di] else: xx = Ax qq = Aq # here we used cnn and word embedding represented each word with a 200-unit vector # so for, xx, (batch_size, sentence#, word#, embedding), qq (batch_size, word#, embedding) # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq # same shape with line 173 cell = BasicLSTMCell( d, state_is_tuple=True) # d = 100, hidden state number d_cell = SwitchableDropoutWrapper( cell, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M], [60,?] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] [60] # masks are true and false, here, he sums up those truths, with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn( d_cell, d_cell, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat( [fw_u, bw_u], 2) # (60, ?, 200) | 200 becahse combined 2 100 hidden states if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell, cell, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat([fw_h, bw_h], 3) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell, cell, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat([fw_h, bw_h], 3) # [N, M, JX, 2d] self.tensor_dict['u'] = u # [60, ?, 200] for question self.tensor_dict['h'] = h # [60, ?, ?, 200] for article with tf.variable_scope("main"): if config.dynamic_att: # todo what is this dynamic attention. 
p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape( tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell = AttentionCell( cell, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) cell2 = BasicLSTMCell( d, state_is_tuple=True) # d = 100, hidden state number first_cell = SwitchableDropoutWrapper( cell2, self.is_train, input_keep_prob=config.input_keep_prob) (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn( first_cell, first_cell, inputs=p0, sequence_length=x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat([fw_g0, bw_g0], 3) cell3 = BasicLSTMCell( d, state_is_tuple=True) # d = 100, hidden state number first_cell3 = SwitchableDropoutWrapper( cell3, self.is_train, input_keep_prob=config.input_keep_prob) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn( first_cell3, first_cell3, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat([fw_g1, bw_g1], 3) logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) cell4 = BasicLSTMCell( d, state_is_tuple=True) # d = 100, hidden state number first_cell4 = SwitchableDropoutWrapper( cell4, self.is_train, input_keep_prob=config.input_keep_prob) (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn( first_cell4, first_cell4, tf.concat([p0, g1, a1i, g1 * a1i], 3), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat([fw_g2, bw_g2], 3) logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] yp = tf.reshape(flat_yp, [-1, M, JX]) flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) yp2 = tf.reshape(flat_yp2, [-1, M, JX]) self.tensor_dict['g1'] = g1 self.tensor_dict['g2'] = g2 self.logits = flat_logits self.logits2 = flat_logits2 self.yp = yp self.yp2 = yp2
def bi_attention(config, is_train, h, u, h_mask=None, u_mask=None, scope=None, tensor_dict=None):
    """
    :param config:
    :param is_train:
    :param h: 2d vector for each word in the context
    :param u: 2d vector for each word in the query
    :param h_mask:
    :param u_mask:
    :param scope:
    :param tensor_dict:
    :return:
        u_a: the weighted sum of the query for each context word
        h_a: the weighted sum of the context, where the weights are softmax(max(relevance));
             tiled to T 2d vectors as well
    """
    with tf.variable_scope(scope or "bi_attention"):
        JX = tf.shape(h)[2]
        M = tf.shape(h)[1]
        JQ = tf.shape(u)[1]
        # expand_dims at axis 3 gives [60, ?, ?, 1, 200]; tile repeats it JQ times along that axis
        h_aug = tf.tile(tf.expand_dims(h, 3), [1, 1, 1, JQ, 1])
        u_aug = tf.tile(tf.expand_dims(tf.expand_dims(u, 1), 1), [1, M, JX, 1, 1])
        if h_mask is None:
            hu_mask = None
        else:
            h_mask_aug = tf.tile(tf.expand_dims(h_mask, 3), [1, 1, 1, JQ])
            u_mask_aug = tf.tile(tf.expand_dims(tf.expand_dims(u_mask, 1), 1), [1, M, JX, 1])
            hu_mask = h_mask_aug & u_mask_aug

        # equation 1 (the similarity matrix)
        u_logits = get_logits([h_aug, u_aug], None, True, wd=config.wd, mask=hu_mask,
                              is_train=is_train, func=config.logit_func,
                              scope='u_logits')  # [N, M, JX, JQ] = [60, ?, ?, ?]
        u_a = softsel(u_aug, u_logits)  # [N, M, JX, d]
        h_a = softsel(h, tf.reduce_max(u_logits, 3))  # [N, M, d]
        h_a = tf.tile(tf.expand_dims(h_a, 2), [1, 1, JX, 1])

        if tensor_dict is not None:
            a_u = tf.nn.softmax(u_logits)  # [N, M, JX, JQ]
            a_h = tf.nn.softmax(tf.reduce_max(u_logits, 3))
            tensor_dict['a_u'] = a_u
            tensor_dict['a_h'] = a_h
            variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=tf.get_variable_scope().name)
            for var in variables:
                tensor_dict[var.name] = var
        return u_a, h_a
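# For orientation, a toy shape check of bi_attention (hypothetical standalone usage; `config`
# here stands in for any object providing wd and logit_func, and get_logits/softsel must be
# importable from the same module):
import tensorflow as tf

N, M, JX, JQ, d2 = 2, 1, 5, 4, 8                 # toy sizes; d2 plays the role of 2d
h = tf.zeros([N, M, JX, d2])                      # context encoding
u = tf.zeros([N, JQ, d2])                         # query encoding
h_mask = tf.ones([N, M, JX], dtype=tf.bool)
u_mask = tf.ones([N, JQ], dtype=tf.bool)
u_a, h_a = bi_attention(config, tf.constant(False), h, u,
                        h_mask=h_mask, u_mask=u_mask, scope="toy_bi_attention")
# Expected: u_a is [N, M, JX, d2] (context-to-query) and h_a is [N, M, JX, d2]
# (query-to-context, tiled along JX).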
def _build_forward(self): config = self.config N, M, JX, JQ, VW, VC, d, W = \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \ config.max_word_size JX = tf.shape(self.x)[2] JQ = tf.shape(self.q)[1] M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list(map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat)) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: word_emb_mat = tf.concat([word_emb_mat, self.new_emb_mat], 0) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat([xx, Ax], 3) # [N, M, JX, di] qq = tf.concat([qq, Aq], 2) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq cell_fw = BasicLSTMCell(d, state_is_tuple=True) cell_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell_fw = SwitchableDropoutWrapper(cell_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell_bw = SwitchableDropoutWrapper(cell_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell2_fw = BasicLSTMCell(d, state_is_tuple=True) cell2_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell2_fw = SwitchableDropoutWrapper(cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell2_bw = SwitchableDropoutWrapper(cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell3_fw = BasicLSTMCell(d, state_is_tuple=True) cell3_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell3_fw = SwitchableDropoutWrapper(cell3_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell3_bw = SwitchableDropoutWrapper(cell3_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell4_fw = BasicLSTMCell(d, 
state_is_tuple=True) cell4_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell4_fw = SwitchableDropoutWrapper(cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell4_bw = SwitchableDropoutWrapper(cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn(d_cell_fw, d_cell_bw, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat(axis=2, values=[fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell_fw, cell_bw, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell_fw, cell_bw, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): if config.dynamic_att: p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape(tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell_fw = AttentionCell(cell2_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) first_cell_bw = AttentionCell(cell2_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_fw = AttentionCell(cell3_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_bw = AttentionCell(cell3_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) first_cell_fw = d_cell2_fw second_cell_fw = d_cell3_fw first_cell_bw = d_cell2_bw second_cell_bw = d_cell3_bw (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(first_cell_fw, first_cell_bw, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(axis=3, values=[fw_g0, bw_g0]) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn(second_cell_fw, second_cell_bw, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(axis=3, values=[fw_g1, bw_g1]) with tf.variable_scope("output"): if config.model_name == "basic": logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(d_cell4_fw, d_cell4_bw, tf.concat([p0, g1, a1i, g1 * a1i], 3), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat([fw_g2, bw_g2], 3) logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] yp = tf.reshape(flat_yp, [-1, M, JX]) flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) yp2 = tf.reshape(flat_yp2, [-1, M, JX]) 
self.tensor_dict['g1'] = g1 self.tensor_dict['g2'] = g2 self.logits = flat_logits self.logits2 = flat_logits2 self.yp = yp self.yp2 = yp2 elif config.model_name == "basic-class": C = 3 if config.data_dir.startswith('data/snli') else 2 (fw_g2, bw_g2) = (fw_g1, bw_g1) if config.classifier == 'maxpool': g2 = tf.concat([fw_g2, bw_g2], 3) # [N, M, JX, 2d] g2 = tf.reduce_max(g2, 2) # [N, M, 2d] g2_dim = 2 * d elif config.classifier == 'sumpool': g2 = tf.concat([fw_g2, bw_g2], 3) g2 = tf.reduce_sum(g2, 2) g2_dim = 2 * d else: fw_g2_ = tf.gather(tf.transpose(fw_g2, [2, 0, 1, 3]), JX - 1) bw_g2_ = tf.gather(tf.transpose(bw_g2, [2, 0, 1, 3]), 0) g2 = tf.concat([fw_g2_, bw_g2_], 2) g2_dim = 2 * d g2_ = tf.reshape(g2, [N, g2_dim]) logits0 = linear(g2_, C, True, wd=config.wd, input_keep_prob=config.input_keep_prob, is_train=self.is_train, scope='classifier') flat_yp0 = tf.nn.softmax(logits0) yp0 = tf.reshape(flat_yp0, [N, M, C]) self.tensor_dict['g1'] = g1 self.logits0 = logits0 self.yp0 = yp0 self.logits = logits0 self.yp = yp0
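# One caveat in the "basic-class" branch above: reduce_max over axis 2 also sees the zero
# vectors at padded positions, which can win when all real activations are negative. If that
# matters, a masked variant might look like the following sketch (not part of the original code):
import tensorflow as tf

def masked_maxpool_sketch(g2, x_mask):
    """Max-pool g2 [N, M, JX, 2d] over the word axis, ignoring padded positions in x_mask [N, M, JX]."""
    neg_inf = -1e30
    mask = tf.cast(tf.expand_dims(x_mask, -1), g2.dtype)                 # [N, M, JX, 1]
    return tf.reduce_max(g2 * mask + neg_inf * (1.0 - mask), axis=2)     # [N, M, 2d]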
def _build_forward(self): config = self.config N, M, JX, JQ, VW, VC, d, W ,EW, WOW= \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \ config.max_word_size,config.word_vocab_size-config.vw_wo_entity_size,config.vw_wo_entity_size JX = tf.shape(self.x)[2] # words JQ = tf.shape(self.q)[1] # words M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size #print ("dhruv is here",N, self.x.get_shape(), JX, self.q.get_shape(), VW, VC, d, W,dc, dw, dco) with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list(map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': entity_emb_mat = tf.get_variable("entity_emb_mat", dtype='float', shape=[EW, EW], initializer=get_initializer(config.onehot_encoded)) entity_emb_out = _linear(entity_emb_mat, dw, True, bias_initializer=tf.constant_initializer(0.0)) word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[WOW, dw], initializer=get_initializer(config.emb_mat)) word_emb_mat = tf.concat(axis=0,values=[word_emb_mat, entity_emb_out]) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: word_emb_mat = tf.concat(axis=0, values=[word_emb_mat, self.new_emb_mat]) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] i.e. [batch size, max sentences, max words, embedding size] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] i.e. 
[batch size, max words, embedding size] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat(axis=3, values=[xx, Ax]) # [N, M, JX, di] qq = tf.concat(axis=2, values=[qq, Aq]) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq #xx = tf.Print(xx,[tf.shape(xx),xx],message="DHRUV xx=",summarize=20) cell_fw = BasicLSTMCell(d, state_is_tuple=True) cell_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell_fw = SwitchableDropoutWrapper(cell_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell_bw = SwitchableDropoutWrapper(cell_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell2_fw = BasicLSTMCell(d, state_is_tuple=True) cell2_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell2_fw = SwitchableDropoutWrapper(cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell2_bw = SwitchableDropoutWrapper(cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell3_fw = BasicLSTMCell(d, state_is_tuple=True) cell3_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell3_fw = SwitchableDropoutWrapper(cell3_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell3_bw = SwitchableDropoutWrapper(cell3_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell4_fw = BasicLSTMCell(d, state_is_tuple=True) cell4_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell4_fw = SwitchableDropoutWrapper(cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell4_bw = SwitchableDropoutWrapper(cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn(d_cell_fw, d_cell_bw, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat(axis=2, values=[fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell_fw, cell_bw, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell_fw, cell_bw, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): if config.dynamic_att: # not true p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape(tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell_fw = AttentionCell(cell2_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) first_cell_bw = AttentionCell(cell2_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_fw = AttentionCell(cell3_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_bw = AttentionCell(cell3_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = 
attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) # p0 seems to be G in paper first_cell_fw = d_cell2_fw second_cell_fw = d_cell3_fw first_cell_bw = d_cell2_bw second_cell_bw = d_cell3_bw (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(first_cell_fw, first_cell_bw, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(axis=3, values=[fw_g0, bw_g0]) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn(second_cell_fw, second_cell_bw, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(axis=3, values=[fw_g1, bw_g1]) # g1 seems to be M in paper g1= tf.Print(g1,[tf.shape(g1)],message="g1 shape",first_n=5,summarize=200) p0 = tf.Print(p0, [tf.shape(p0)], message="p0 shape", first_n=5, summarize=200) my_cell_fw = BasicLSTMCell(d, state_is_tuple=True) my_cell_fw_d = SwitchableDropoutWrapper(my_cell_fw, self.is_train, input_keep_prob=config.input_keep_prob) my_cell_bw = BasicLSTMCell(d, state_is_tuple=True) my_cell_bw_d = SwitchableDropoutWrapper(my_cell_bw, self.is_train, input_keep_prob=config.input_keep_prob) (fw_g11,bw_g11),(my_fw_final_state, my_bw_final_state),g11_len = my_bidirectional_dynamic_rnn(my_cell_fw_d, my_cell_bw_d, g1, x_len, dtype='float', scope='my_g2') # [N, M, JX, 2d] g11 = tf.concat(axis=2, values=[fw_g11, bw_g11]) my_encoder_final_state_c = tf.concat(values = (my_fw_final_state.c, my_bw_final_state.c), axis = 1, name = "my_encoder_final_state_c") my_encoder_final_state_h = tf.concat(values = (my_fw_final_state.h, my_bw_final_state.h), axis = 1, name = "my_encoder_final_state_h") my_encoder_final_state = tf.contrib.rnn.LSTMStateTuple(c = my_encoder_final_state_c, h = my_encoder_final_state_h) #compute indices for finding span as the second task in multi task learning logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') logits = tf.Print(logits, [tf.shape(logits)], message="logits shape", first_n=5, summarize=200) a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(d_cell4_fw, d_cell4_bw, tf.concat(axis=3, values=[p0, g1, a1i, g1 * a1i]), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat(axis=3, values=[fw_g2, bw_g2]) logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_logits = tf.Print(flat_logits, [tf.shape(flat_logits),flat_logits], message="flat_logits shape and contents", first_n=5, summarize=200) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) tgt_vocab_size = config.len_new_emb_mat # hparam # FIXME: Obtain embeddings differently? 
print("length is",config.len_new_emb_mat) tgt_embedding_size = dw # hparam # Look up embedding decoder_emb_inp = tf.nn.embedding_lookup(word_emb_mat, self.decoder_inputs) # [batch_size, max words, embedding_size] def decode_with_attention(helper, scope, reuse=None,maximum_iterations=None): with tf.variable_scope(scope, reuse=reuse): attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=d, memory=g11) cell = tf.contrib.rnn.GRUCell(num_units=d) attn_cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanism,attention_layer_size=d /2) out_cell = tf.contrib.rnn.OutputProjectionWrapper(attn_cell, tgt_vocab_size, reuse=reuse) decoder = tf.contrib.seq2seq.BasicDecoder(cell=out_cell, helper=helper,initial_state=out_cell.zero_state( dtype=tf.float32, batch_size=N)) # initial_state=encoder_final_state) outputs = tf.contrib.seq2seq.dynamic_decode(decoder=decoder, output_time_major=False, impute_finished=True, maximum_iterations=maximum_iterations) return outputs[0] def decode(helper, scope, reuse=None, maximum_iterations=None): with tf.variable_scope(scope, reuse=reuse): decoder_cell = BasicLSTMCell(2 * d, state_is_tuple=True) # hparam projection_layer = layers_core.Dense(tgt_vocab_size, use_bias=False) # hparam decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, my_encoder_final_state,output_layer=projection_layer) # decoder final_outputs, _ ,_= tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=False, impute_finished=True, maximum_iterations=maximum_iterations) # dynamic decoding return final_outputs # Decoder if config.mode == 'train': #TODO:doesnt seem to be correct to use this variable for dev training_helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, self.target_sequence_length,time_major=False) #final_outputs = decode(helper=training_helper, scope="HAHA", reuse=None) final_outputs = decode_with_attention(helper=training_helper, scope="HAHA", reuse=None) else: training_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(word_emb_mat, tf.fill([N], self.tgt_sos_id),self.tgt_eos_id) #final_outputs= decode(helper=training_helper, scope="HAHA", reuse=True,maximum_iterations=100) final_outputs= decode_with_attention(helper=training_helper, scope="HAHA", reuse=True,maximum_iterations=100) self.decoder_logits_train = final_outputs.rnn_output self.index_start = flat_logits self.index_end = flat_logits2
def _build_forward(self): config = self.config N, M, JX, JQ, VW, VC, d, W = \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \ config.max_word_size JX = tf.shape(self.x)[2] JQ = tf.shape(self.q)[1] M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list( map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': word_emb_mat = tf.get_variable( "word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat)) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: word_emb_mat = tf.concat( 0, [word_emb_mat, self.new_emb_mat]) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat(3, [xx, Ax]) # [N, M, JX, di] qq = tf.concat(2, [qq, Aq]) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq cell = BasicLSTMCell(d, state_is_tuple=True) d_cell = SwitchableDropoutWrapper( cell, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn( d_cell, d_cell, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat(2, [fw_u, bw_u]) if config.two_prepro_layers: (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn( d_cell, d_cell, u, q_len, dtype='float', scope='u2') # [N, J, d], [N, d] u = tf.concat(2, [fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell, cell, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(3, [fw_h, bw_h]) # [N, M, JX, 2d] if config.two_prepro_layers: (fw_h, 
bw_h), _ = bidirectional_dynamic_rnn( cell, cell, h, x_len, dtype='float', scope='u2') # [N, M, JX, 2d] h = tf.concat(3, [fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell, cell, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(3, [fw_h, bw_h]) # [N, M, JX, 2d] if config.two_prepro_layers: (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell, cell, h, x_len, dtype='float', scope='h2') # [N, M, JX, 2d] h = tf.concat(3, [fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): if config.dynamic_att: p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape( tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell = AttentionCell( cell, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) first_cell = d_cell (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn( first_cell, first_cell, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(3, [fw_g0, bw_g0]) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn( first_cell, first_cell, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(3, [fw_g1, bw_g1]) if config.late: (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn( d_cell, d_cell, tf.concat(3, [g1, p0]), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat(3, [fw_g2, bw_g2]) # logits2 = u_logits(config, self.is_train, tf.concat(3, [g1, a1i]), u, h_mask=self.x_mask, u_mask=self.q_mask, scope="logits2") logits = get_logits([g1, g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') if config.feed_gt: logy = tf.log(tf.cast(self.y, 'float') + VERY_SMALL_NUMBER) logits = tf.cond(self.is_train, lambda: logy, lambda: logits) if config.feed_hard: hard_yp = tf.argmax(tf.reshape(logits, [N, M * JX]), 1) hard_logits = tf.reshape(tf.one_hot(hard_yp, M * JX), [N, M, JX]) # [N, M, JX] logits = tf.cond(self.is_train, lambda: logits, lambda: hard_logits) flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] yp = tf.reshape(flat_yp, [-1, M, JX]) logits2 = get_logits([g1, g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) yp2 = tf.reshape(flat_yp2, [-1, M, JX]) else: logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) if config.feed_gt: logy = tf.log(tf.cast(self.y, 'float') + VERY_SMALL_NUMBER) logits = tf.cond(self.is_train, lambda: logy, lambda: logits) if config.feed_hard: hard_yp = tf.argmax(tf.reshape(logits, [N, M * JX]), 1) hard_logits = tf.reshape(tf.one_hot(hard_yp, M * JX), [N, M, JX]) # [N, M, JX] logits = tf.cond(self.is_train, lambda: logits, lambda: hard_logits) flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] yp = tf.reshape(flat_yp, [-1, M, JX]) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) yp_aug = 
tf.expand_dims(yp, -1) g1yp = g1 * yp_aug if config.prev_mode == 'a': prev = a1i elif config.prev_mode == 'y': prev = yp_aug elif config.prev_mode == 'gy': prev = g1yp else: raise Exception() (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn( d_cell, d_cell, tf.concat(3, [p0, g1, prev, g1 * prev]), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat(3, [fw_g2, bw_g2]) # logits2 = u_logits(config, self.is_train, tf.concat(3, [g1, a1i]), u, h_mask=self.x_mask, u_mask=self.q_mask, scope="logits2") logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) yp2 = tf.reshape(flat_yp2, [-1, M, JX]) self.tensor_dict['g1'] = g1 self.tensor_dict['g2'] = g2 self.logits = flat_logits self.logits2 = flat_logits2 self.yp = yp self.yp2 = yp2
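# --- Illustrative sketch (not the author's code) ---------------------------------
# Toy NumPy sketch, with assumed toy shapes, of how the start logits above are
# collapsed by softsel into the summary a1i that conditions the end-pointer pass.
import numpy as np

def softsel(targets, logits):
    # targets: [N, L, D], logits: [N, L] -> attention-weighted sum over L
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return (w[:, :, None] * targets).sum(axis=1)                  # [N, D]

N, M, JX, D = 2, 3, 4, 8
g1 = np.random.randn(N, M, JX, D)
start_logits = np.random.randn(N, M, JX)
a1i = softsel(g1.reshape(N, M * JX, D), start_logits.reshape(N, M * JX))
# every position then sees the same start summary, as in the tf.tile above
a1i_tiled = np.broadcast_to(a1i[:, None, None, :], (N, M, JX, D))
print(a1i_tiled.shape)  # (2, 3, 4, 8)
# ----------------------------------------------------------------------------------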
def _build_forward(self): config = self.config N, M, JX, JQ, VW, d, W = \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.hidden_size, \ config.max_word_size JX = tf.shape(self.x)[2] JQ = tf.shape(self.q)[1] M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): print('word embedding') # if config.use_char_emb: # with tf.variable_scope("emb_var"), tf.device("/cpu:0"): # char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') # with tf.variable_scope("char"): # Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] # Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] # Acx = tf.reshape(Acx, [-1, JX, W, dc]) # Acq = tf.reshape(Acq, [-1, JQ, W, dc]) # filter_sizes = list(map(int, config.out_channel_dims.split(','))) # heights = list(map(int, config.filter_heights.split(','))) # assert sum(filter_sizes) == dco, (filter_sizes, dco) # with tf.variable_scope("conv"): # xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") # if config.share_cnn_weights: # tf.get_variable_scope().reuse_variables() # qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") # else: # qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") # xx = tf.reshape(xx, [-1, M, JX, dco]) # qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/gpu:7"): if config.mode == 'train': word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat)) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') # if config.use_glove_for_unk: # word_emb_mat = tf.concat(0, [word_emb_mat]) print(word_emb_mat) with tf.name_scope("word"): print('embedding lookup') Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq print('embedding lookup ready') # if config.use_char_emb: # xx = tf.concat(3, [xx, Ax]) # [N, M, JX, di] # qq = tf.concat(2, [qq, Aq]) # [N, JQ, di] # else: xx = Ax qq = Aq # highway network #if config.highway: # with tf.variable_scope("highway"): # xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) # tf.get_variable_scope().reuse_variables() # qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq print('context emmbedding') cell = BasicLSTMCell(d, state_is_tuple=True) d_cell = SwitchableDropoutWrapper(cell, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] print('prepro') with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn(d_cell, d_cell, qq, q_len, dtype='float32', scope='u1') # [N, J, d], [N, d] u = tf.concat(2, [fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell, cell, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(3, [fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell, cell, xx, 
x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(3, [fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h print('main pro') with tf.variable_scope("main"): if config.dynamic_att: p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape(tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell = AttentionCell(cell, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) first_cell = d_cell (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(first_cell, first_cell, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(3, [fw_g0, bw_g0]) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn(first_cell, first_cell, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(3, [fw_g1, bw_g1]) logits = get_logits([g1, p0], [config.batch_size,config.max_num_sents] , True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func='sigmoid', scope='logits1') # a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) # a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) # (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(d_cell, d_cell, tf.concat(3, [p0, g1, a1i, g1 * a1i]), # x_len, dtype='float', scope='g2') # [N, M, JX, 2d] # g2 = tf.concat(3, [fw_g2, bw_g2]) # logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, # mask=self.x_mask, # is_train=self.is_train, func=config.answer_func, scope='logits2') # flat_logits = tf.reshape(logits, [-1, M * JX]) # flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] # yp = tf.reshape(flat_yp, [-1, M, JX]) yp = tf.greater(0.5, logits) # flat_logits2 = tf.reshape(logits2, [-1, M * JX]) # flat_yp2 = tf.nn.softmax(flat_logits2) # yp2 = tf.reshape(flat_yp2, [-1, M, JX]) self.tensor_dict['g1'] = g1 # self.tensor_dict['g2'] = g2 self.logits = logits # self.logits2 = flat_logits2 self.yp = yp
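# --- Illustrative sketch (not the author's code) ---------------------------------
# Minimal sketch, assuming the per-sentence scores above are meant to be squashed by
# a sigmoid and thresholded at 0.5. The graph itself compares the raw logits with
# tf.greater(0.5, logits), so this is only an illustration of the thresholding idea,
# not a drop-in equivalent.
import numpy as np

def select_sentences(sent_logits, threshold=0.5):
    probs = 1.0 / (1.0 + np.exp(-sent_logits))    # [N, M] per-sentence probabilities
    return probs, probs > threshold               # boolean mask of selected sentences

logits = np.array([[2.0, -1.0, 0.3], [-0.5, 1.5, 0.0]])
probs, picked = select_sentences(logits)
print(picked)  # [[ True False  True] [False  True False]]
# ----------------------------------------------------------------------------------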
def _build_forward(self): config = self.config N, M, JX, JQ, VW, VC, d, W = \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \ config.max_word_size beam_width = config.beam_width GO_TOKEN = 0 EOS_TOKEN = 1 JX = tf.shape(self.x)[2] JQ = tf.shape(self.q)[1] M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list( map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': word_emb_mat = tf.get_variable( "word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat), trainable=True) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: word_emb_mat = tf.concat( axis=0, values=[word_emb_mat, self.new_emb_mat]) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat(axis=3, values=[xx, Ax]) # [N, M, JX, di] qq = tf.concat(axis=2, values=[qq, Aq]) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq cell_fw = BasicLSTMCell(d, state_is_tuple=True) cell_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell_fw = SwitchableDropoutWrapper( cell_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell_bw = SwitchableDropoutWrapper( cell_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell2_fw = BasicLSTMCell(d, state_is_tuple=True) cell2_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell2_fw = SwitchableDropoutWrapper( cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell2_bw = SwitchableDropoutWrapper( cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell3_fw = BasicLSTMCell(d, state_is_tuple=True) cell3_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell3_fw = SwitchableDropoutWrapper( cell3_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell3_bw = 
SwitchableDropoutWrapper( cell3_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell4_fw = BasicLSTMCell(d, state_is_tuple=True) cell4_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell4_fw = SwitchableDropoutWrapper( cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell4_bw = SwitchableDropoutWrapper( cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn( d_cell_fw, d_cell_bw, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat(axis=2, values=[fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), ((_, fw_h_f), (_, bw_h_f)) = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), ((_, fw_h_f), (_, bw_h_f)) = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): if config.dynamic_att: p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape( tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell_fw = AttentionCell( cell2_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) first_cell_bw = AttentionCell( cell2_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_fw = AttentionCell( cell3_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_bw = AttentionCell( cell3_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) first_cell_fw = d_cell2_fw second_cell_fw = d_cell3_fw first_cell_bw = d_cell2_bw second_cell_bw = d_cell3_bw (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn( first_cell_fw, first_cell_bw, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(axis=3, values=[fw_g0, bw_g0]) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn( second_cell_fw, second_cell_bw, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(axis=3, values=[fw_g1, bw_g1]) logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn( d_cell4_fw, d_cell4_bw, tf.concat(axis=3, values=[p0, g1, a1i, g1 * a1i]), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat(axis=3, values=[fw_g2, bw_g2]) logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] flat_logits2 = tf.reshape(logits2, [-1, M * 
JX]) flat_yp2 = tf.nn.softmax(flat_logits2) if config.na: na_bias = tf.get_variable("na_bias", shape=[], dtype='float') na_bias_tiled = tf.tile(tf.reshape(na_bias, [1, 1]), [N, 1]) # [N, 1] concat_flat_logits = tf.concat( axis=1, values=[na_bias_tiled, flat_logits]) concat_flat_yp = tf.nn.softmax(concat_flat_logits) na_prob = tf.squeeze(tf.slice(concat_flat_yp, [0, 0], [-1, 1]), [1]) flat_yp = tf.slice(concat_flat_yp, [0, 1], [-1, -1]) concat_flat_logits2 = tf.concat( axis=1, values=[na_bias_tiled, flat_logits2]) concat_flat_yp2 = tf.nn.softmax(concat_flat_logits2) na_prob2 = tf.squeeze( tf.slice(concat_flat_yp2, [0, 0], [-1, 1]), [1]) # [N] flat_yp2 = tf.slice(concat_flat_yp2, [0, 1], [-1, -1]) self.concat_logits = concat_flat_logits self.concat_logits2 = concat_flat_logits2 self.na_prob = na_prob * na_prob2 yp = tf.reshape(flat_yp, [-1, M, JX]) yp2 = tf.reshape(flat_yp2, [-1, M, JX]) wyp = tf.nn.sigmoid(logits2) self.tensor_dict['g1'] = g1 self.tensor_dict['g2'] = g2 self.logits = flat_logits self.logits2 = flat_logits2 self.yp = yp self.yp2 = yp2 self.wyp = wyp with tf.variable_scope("q_gen"): # Question Generation Using (Paragraph & Predicted Ans Pos) NM = config.max_num_sents * config.batch_size # Separated encoder #ss = tf.reshape(xx, (-1, JX, dw+dco)) q_worthy = tf.reduce_sum( tf.to_int32(self.y), axis=2 ) # so we get probability distribution of answer-likely. (N, M) q_worthy = tf.expand_dims(tf.to_int32(tf.argmax(q_worthy, axis=1)), axis=1) # (N) -> (N, 1) q_worthy = tf.concat([ tf.expand_dims(tf.range(0, N, dtype=tf.int32), axis=1), q_worthy ], axis=1) # example : [0, 9], [1, 11], [2, 8], [3, 5], [4, 0], [5, 1] ... ss = tf.gather_nd(xx, q_worthy) syp = tf.expand_dims(tf.gather_nd(yp, q_worthy), axis=-1) syp2 = tf.expand_dims(tf.gather_nd(yp2, q_worthy), axis=-1) ss_with_ans = tf.concat([ss, syp, syp2], axis=2) qg_dim = 600 cell_fw, cell_bw = rnn.DropoutWrapper(rnn.GRUCell(qg_dim), input_keep_prob=config.input_keep_prob), \ rnn.DropoutWrapper(rnn.GRUCell(qg_dim), input_keep_prob=config.input_keep_prob) s_outputs, s_states = tf.nn.bidirectional_dynamic_rnn( cell_fw, cell_bw, ss_with_ans, dtype=tf.float32) s_outputs = tf.concat(s_outputs, axis=2) s_states = tf.concat(s_states, axis=1) start_tokens = tf.zeros([N], dtype=tf.int32) self.inp_q_with_GO = tf.concat( [tf.expand_dims(start_tokens, axis=1), self.q], axis=1) # supervise if mode is train if config.mode == "train": emb_q = tf.nn.embedding_lookup(params=word_emb_mat, ids=self.inp_q_with_GO) #emb_q = tf.reshape(tf.tile(tf.expand_dims(emb_q, axis=1), [1, M, 1, 1]), (NM, JQ+1, dw)) train_helper = seq2seq.TrainingHelper(emb_q, [JQ] * N) else: s_outputs = seq2seq.tile_batch(s_outputs, multiplier=beam_width) s_states = seq2seq.tile_batch(s_states, multiplier=beam_width) cell = rnn.DropoutWrapper(rnn.GRUCell(num_units=qg_dim * 2), input_keep_prob=config.input_keep_prob) attention_mechanism = seq2seq.BahdanauAttention(num_units=qg_dim * 2, memory=s_outputs) attn_cell = seq2seq.AttentionWrapper(cell, attention_mechanism, attention_layer_size=qg_dim * 2, output_attention=True, alignment_history=False) total_glove_vocab_size = 78878 #72686 out_cell = rnn.OutputProjectionWrapper(attn_cell, VW + total_glove_vocab_size) if config.mode == "train": decoder_initial_states = out_cell.zero_state( batch_size=N, dtype=tf.float32).clone(cell_state=s_states) decoder = seq2seq.BasicDecoder( cell=out_cell, helper=train_helper, initial_state=decoder_initial_states) else: decoder_initial_states = out_cell.zero_state( batch_size=N * beam_width, 
dtype=tf.float32).clone(cell_state=s_states) decoder = seq2seq.BeamSearchDecoder( cell=out_cell, embedding=word_emb_mat, start_tokens=start_tokens, end_token=EOS_TOKEN, initial_state=decoder_initial_states, beam_width=beam_width, length_penalty_weight=0.0) outputs = seq2seq.dynamic_decode(decoder=decoder, maximum_iterations=JQ) if config.mode == "train": gen_q = outputs[0].sample_id gen_q_prob = outputs[0].rnn_output gen_q_states = outputs[1] else: gen_q = outputs[0].predicted_ids[:, :, 0] gen_q_prob = tf.nn.embedding_lookup( params=word_emb_mat, ids=outputs[0].predicted_ids[:, :, 0]) gen_q_states = outputs[1] self.gen_q = gen_q self.gen_q_prob = gen_q_prob self.gen_q_states = gen_q_states
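# --- Illustrative sketch (not the author's code) ---------------------------------
# Toy NumPy illustration (assumed shapes) of the tile_batch trick used above before
# seq2seq.BeamSearchDecoder: the encoder memory is repeated beam_width times along
# the batch axis so every beam hypothesis attends over its own copy of the source.
import numpy as np

def tile_batch(t, multiplier):
    # [N, T, D] -> [N * multiplier, T, D], each batch item repeated contiguously
    return np.repeat(t, multiplier, axis=0)

N, T, D, beam_width = 2, 5, 4, 3
s_outputs = np.random.randn(N, T, D)
tiled = tile_batch(s_outputs, beam_width)
print(tiled.shape)                              # (6, 5, 4)
print(np.allclose(tiled[0], s_outputs[0]),      # True: first copy is the original
      np.allclose(tiled[1], s_outputs[0]))      # True: second copy of the same item
# ----------------------------------------------------------------------------------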
def _build_forward(self): config = self.config N, M, JX, JQ, VW, VC, d, W ,EW, WOW= \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.len_new_emb_mat, config.char_vocab_size, config.hidden_size, \ config.max_word_size,config.word_vocab_size-config.vw_wo_entity_size,config.vw_wo_entity_size JX = tf.shape(self.x)[2] # words JQ = tf.shape(self.q)[1] # words M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list( map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': init_word_emb = tf.random_normal_initializer(-0.5, 0.5) #entity_emb_mat = tf.get_variable("entity_emb_mat", dtype='float', shape=[EW, EW], initializer=get_initializer(config.onehot_encoded)) #entity_emb_out = _linear(entity_emb_mat, dw, True, bias_initializer=tf.constant_initializer(0.0)) #word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat)) word_emb_mat = tf.get_variable( "word_emb_mat", dtype='float', shape=[VW, dw], initializer=init_word_emb) #word_emb_mat = tf.concat(axis=0,values=[word_emb_mat, entity_emb_out]) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') #if config.use_glove_for_unk: # word_emb_mat = tf.concat(axis=0, values=[word_emb_mat, self.new_emb_mat]) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup( word_emb_mat, self.x ) # [N, M, JX, d] i.e. [batch size, max sentences, max words, embedding size] Aq = tf.nn.embedding_lookup( word_emb_mat, self.q ) # [N, JQ, d] i.e. 
[batch size, max words, embedding size] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat(axis=3, values=[xx, Ax]) # [N, M, JX, di] qq = tf.concat(axis=2, values=[qq, Aq]) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq #xx = tf.Print(xx,[tf.shape(xx),xx],message="DHRUV xx=",summarize=20) cell_fw = BasicLSTMCell(d, state_is_tuple=True) cell_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell_fw = SwitchableDropoutWrapper( cell_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell_bw = SwitchableDropoutWrapper( cell_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell2_fw = BasicLSTMCell(d, state_is_tuple=True) cell2_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell2_fw = SwitchableDropoutWrapper( cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell2_bw = SwitchableDropoutWrapper( cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell3_fw = BasicLSTMCell(d, state_is_tuple=True) cell3_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell3_fw = SwitchableDropoutWrapper( cell3_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell3_bw = SwitchableDropoutWrapper( cell3_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell4_fw = BasicLSTMCell(d, state_is_tuple=True) cell4_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell4_fw = SwitchableDropoutWrapper( cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell4_bw = SwitchableDropoutWrapper( cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N,M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn( d_cell_fw, d_cell_bw, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat(axis=2, values=[fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), (fw_s, bw_s) = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): if config.dynamic_att: # not true p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape( tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell_fw = AttentionCell( cell2_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) first_cell_bw = AttentionCell( cell2_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_fw = AttentionCell( cell3_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_bw = AttentionCell( cell3_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) 
else: p0 = attention_layer( config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) # p0 seems to be G in paper first_cell_fw = d_cell2_fw second_cell_fw = d_cell3_fw first_cell_bw = d_cell2_bw second_cell_bw = d_cell3_bw #p1 = tf.reshape(p0,[N , M*JX, 8*d]) (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn( first_cell_fw, first_cell_bw, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(axis=3, values=[fw_g0, bw_g0]) (fw_g1, bw_g1), (my_fw_final_state, my_bw_final_state) = bidirectional_dynamic_rnn( second_cell_fw, second_cell_bw, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(axis=3, values=[fw_g1, bw_g1]) # g1 seems to be M in paper #g1= tf.reshape(g1,[N, M , JX, 2*d]) #reshaping here again, since g1 is used ahead g1 = tf.Print(g1, [tf.shape(g1)], message="g1 shape", first_n=5, summarize=200) p0 = tf.Print(p0, [tf.shape(p0)], message="p0 shape", first_n=5, summarize=200) g11 = tf.reshape(g1, [N, -1, 2 * d]) my_encoder_final_state_c = tf.concat( values=(my_fw_final_state.c, my_bw_final_state.c), axis=1, name="my_encoder_final_state_c") my_encoder_final_state_h = tf.concat( values=(my_fw_final_state.h, my_bw_final_state.h), axis=1, name="my_encoder_final_state_h") my_encoder_final_state = tf.contrib.rnn.LSTMStateTuple( c=my_encoder_final_state_c, h=my_encoder_final_state_h) #compute indices for finding span as the second task in multi task learning logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') logits = tf.Print(logits, [tf.shape(logits)], message="logits shape", first_n=5, summarize=200) a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn( d_cell4_fw, d_cell4_bw, tf.concat(axis=3, values=[p0, g1, a1i, g1 * a1i]), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat(axis=3, values=[fw_g2, bw_g2]) logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_logits = tf.Print(flat_logits, [tf.shape(flat_logits), flat_logits], message="flat_logits shape and contents", first_n=5, summarize=200) self.flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] flat_logits2 = tf.reshape(logits2, [-1, M * JX]) self.flat_yp2 = tf.nn.softmax(flat_logits2) tgt_vocab_size = config.len_new_emb_mat # hparam # FIXME: Obtain embeddings differently? 
print("length is", config.len_new_emb_mat) nodes = d # Look up embedding decoder_emb_inp = tf.nn.embedding_lookup( word_emb_mat, self.decoder_inputs) # [batch_size, max words, embedding_size] with tf.variable_scope("rnn_decoder", reuse=tf.AUTO_REUSE): init = tf.random_normal_initializer(0.0, 0.5) W_dense = tf.get_variable(name="W_dense", shape=[2 * nodes, tgt_vocab_size], dtype=tf.float32, initializer=init) b_dense = tf.get_variable(name="b_dense", shape=[tgt_vocab_size], dtype=tf.float32, initializer=tf.zeros_initializer) W_att_dec = tf.get_variable(name="W_att_dec", shape=[2 * nodes, 2 * nodes], dtype=tf.float32, initializer=init) W_att_enc = tf.get_variable(name="W_att_enc1", shape=[1, 1, 2 * nodes, 2 * nodes], dtype=tf.float32, initializer=init) v_blend = tf.get_variable(name="v_blend", shape=[1, 2 * nodes], dtype=tf.float32, initializer=init) pad_time_slice = tf.fill([N], 0, name='PAD') pad_step_embedded = tf.nn.embedding_lookup( word_emb_mat, pad_time_slice) decoder_cell = tf.contrib.rnn.BasicLSTMCell( 2 * nodes, state_is_tuple=True ) # doesnt work without the factor of 2?? '''Loop transition function is a mapping (time, previous_cell_output, previous_cell_state, previous_loop_state) -> (elements_finished, input, cell_state, output, loop_state). It is called before RNNCell to prepare its inputs and state. Everything is a Tensor except for initial call at time=0 when everything is None (except time).''' def execute_pointer_network(attn_dist): #this is to find the word in the summary, which recieved highest probability and pass it to the next step in decoder index_pos = tf.argmax(attn_dist, axis=1) index_pos = tf.expand_dims(index_pos, 1) index_pos = tf.concat([ tf.reshape(tf.range(start=0, limit=N, dtype=tf.int64), [N, 1]), tf.zeros([N, 1], tf.int64), index_pos ], axis=1) index_pos = tf.cast(tf.gather_nd(params=self.x, indices=index_pos), dtype=tf.int64) return index_pos def execute_normal_decoder(previous_output, W_dense, b_dense): output_logits = tf.add(tf.matmul(previous_output, W_dense), b_dense) return tf.argmax(output_logits, axis=1) def loop_fn_initial(): initial_elements_finished = ( 0 >= self.target_sequence_length ) # all False at the initial step #initial_input = tf.concat([decoder_emb_inp[:,0], my_encoder_final_state_h], 1) initial_input = decoder_emb_inp[:, 0] initial_cell_state = my_encoder_final_state #setting the correct shapes , as it is used to determine the emit structure initial_cell_output = tf.cond( self.pointer_gen, lambda: tf.zeros([M * JX], tf.float32), lambda: tf.zeros([2 * nodes], tf.float32)) initial_loop_state = None # we don't need to pass any additional information return (initial_elements_finished, initial_input, initial_cell_state, initial_cell_output, initial_loop_state) encoder_output = tf.expand_dims(g11, axis=2) def loop_fn_transition(time, previous_output, previous_state, previous_loop_state): def get_next_input(): # compute Badhanau style attention #performing convolution or reshaping input to (-1,2*d) and then doing matmul, is essentially the same operation #see matrix_mult.py...conv2d might be faster?? 
#https://stackoverflow.com/questions/38235555/tensorflow-matmul-of-input-matrix-with-batch-data encoder_features = tf.nn.conv2d( encoder_output, W_att_enc, [1, 1, 1, 1], "SAME" ) # shape (batch_size,max_enc_steps,1,attention_vec_size) dec_portion = tf.matmul(previous_state.h, W_att_dec) decoder_features = tf.expand_dims( tf.expand_dims(dec_portion, 1), 1 ) # reshape to (batch_size, 1, 1, attention_vec_size) #python broadcasting will alllow the two features to get added e_not_masked = tf.reduce_sum( v_blend * tf.nn.tanh(encoder_features + decoder_features), [2, 3]) # calculate e, (batch_size, max_enc_steps) #The shape of output of a softmax is the same as the input: it just normalizes the values. attn_dist = tf.nn.softmax( e_not_masked) # (batch_size, max_enc_steps) attn_dist = tf.Print(attn_dist, [tf.shape(attn_dist)], message="attn_dist", first_n=5, summarize=200) #Multiplying all the 2d vectors with same attn_dist values,and finally keeping 1 2d vector for every batch example context_vector = tf.reduce_sum( tf.reshape(attn_dist, [N, -1, 1, 1]) * encoder_output, [1, 2]) # shape (batch_size, attn_size). context_vector = tf.reshape(context_vector, [-1, 2 * nodes]) #next_input = tf.cond(self.is_train, lambda: tf.concat( # [tf.reshape(decoder_emb_inp[:, time], (N, dw)), context_vector], 1), # lambda: tf.concat([tf.nn.embedding_lookup(word_emb_mat, prediction), context_vector], 1)) #output_logits = tf.add(tf.matmul(previous_output, W_dense), b_dense) prediction = tf.cond( self.pointer_gen, lambda: execute_pointer_network(attn_dist), lambda: execute_normal_decoder( previous_output, W_dense, b_dense)) with tf.variable_scope("modified_dec_inputs", reuse=tf.AUTO_REUSE): next_input = tf.cond( self.is_train, lambda: _linear(args=[context_vector] + [ tf.reshape(decoder_emb_inp[:, time], (N, dw)) ], output_size=dw, bias=True), lambda: _linear([context_vector] + [ tf.nn.embedding_lookup( word_emb_mat, prediction) ], dw, True)) return next_input, attn_dist elements_finished = ( time >= self.target_sequence_length ) # this operation produces boolean tensor of [batch_size] # defining if corresponding sequence has ended finished = tf.reduce_all( elements_finished) # -> boolean scalar #input = tf.cond(finished, lambda: tf.concat([pad_step_embedded, my_encoder_final_state_h], 1),get_next_input) input, attn_distribution = tf.cond( finished, lambda: (pad_step_embedded, tf.zeros([N, M * JX], tf.float32)), get_next_input) attn_distribution = tf.Print(attn_distribution, [tf.shape(attn_distribution)], message="attn_distribution", first_n=5, summarize=200) state = previous_state output = tf.cond(self.pointer_gen, lambda: attn_distribution, lambda: previous_output) output = tf.Print(output, [tf.shape(output)], message="OUTPUT", first_n=5, summarize=200) loop_state = None return (elements_finished, input, state, output, loop_state) def loop_fn(time, previous_output, previous_state, previous_loop_state): if previous_state is None: # time == 0 assert previous_output is None and previous_state is None return loop_fn_initial() else: return loop_fn_transition(time, previous_output, previous_state, previous_loop_state) decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn( decoder_cell, loop_fn) decoder_outputs = decoder_outputs_ta.stack() decoder_outputs = tf.Print(decoder_outputs, [tf.shape(decoder_outputs)], message="decoder_outputs", first_n=5, summarize=200) # To do output projection, we have to temporarilly flatten decoder_outputs from [max_steps, batch_size, hidden_dim] to # [max_steps*batch_size, hidden_dim], 
as tf.matmul needs rank-2 tensors at most. decoder_max_steps, decoder_batch_size, decoder_dim = tf.unstack( tf.shape(decoder_outputs)) decoder_outputs_flat = tf.reshape(decoder_outputs, (-1, decoder_dim)) #if pointer networks, no need to pass through dense layer decoder_logits_flat = tf.cond( self.pointer_gen, lambda: decoder_outputs_flat, lambda: tf.add( tf.matmul(decoder_outputs_flat, W_dense), b_dense)) decoder_logits = tf.cond( self.pointer_gen, lambda: tf.reshape( decoder_logits_flat, (decoder_max_steps, decoder_batch_size, decoder_dim)), lambda: tf.reshape(decoder_logits_flat, (decoder_max_steps, decoder_batch_size, tgt_vocab_size))) decoder_logits = _transpose_batch_time(decoder_logits) #decoder_prediction = tf.argmax(decoder_logits, -1) #self.decoder_logits_train = final_outputs.rnn_output self.decoder_logits_train = decoder_logits self.index_start = flat_logits self.index_end = flat_logits2
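# --- Illustrative sketch (not the author's code) ---------------------------------
# Minimal NumPy sketch of the copy step in execute_pointer_network above: take the
# source position with the largest attention weight and emit that context word id
# as the next decoder prediction. Shapes and names are illustrative assumptions.
import numpy as np

def pointer_step(attn_dist, context_ids):
    # attn_dist:   [N, L] attention over flattened context positions (M*JX)
    # context_ids: [N, L] word ids at those positions (self.x flattened)
    best = attn_dist.argmax(axis=1)                    # [N] argmax position per example
    return context_ids[np.arange(len(best)), best]     # [N] copied word ids

attn = np.array([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
ids = np.array([[11, 42, 7], [3, 9, 27]])
print(pointer_step(attn, ids))  # [42  3]
# ----------------------------------------------------------------------------------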
def _build_forward(self): config = self.config N, M, JX, JQ, VW, VC, d, W = \ config.batch_size, config.max_num_sents, config.max_sent_size, \ config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \ config.max_word_size JX = tf.shape(self.x)[2] JQ = tf.shape(self.q)[1] M = tf.shape(self.x)[1] dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size with tf.variable_scope("emb"): if config.use_char_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float') with tf.variable_scope("char"): Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx) # [N, M, JX, W, dc] Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq) # [N, JQ, W, dc] Acx = tf.reshape(Acx, [-1, JX, W, dc]) Acq = tf.reshape(Acq, [-1, JQ, W, dc]) filter_sizes = list( map(int, config.out_channel_dims.split(','))) heights = list(map(int, config.filter_heights.split(','))) assert sum(filter_sizes) == dco, (filter_sizes, dco) with tf.variable_scope("conv"): xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") if config.share_cnn_weights: tf.get_variable_scope().reuse_variables() qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx") else: qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq") xx = tf.reshape(xx, [-1, M, JX, dco]) qq = tf.reshape(qq, [-1, JQ, dco]) if config.use_word_emb: with tf.variable_scope("emb_var"), tf.device("/cpu:0"): if config.mode == 'train': word_emb_mat = tf.get_variable( "word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat)) else: word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float') if config.use_glove_for_unk: word_emb_mat = tf.concat( axis=0, values=[word_emb_mat, self.new_emb_mat]) with tf.name_scope("word"): Ax = tf.nn.embedding_lookup(word_emb_mat, self.x) # [N, M, JX, d] Aq = tf.nn.embedding_lookup(word_emb_mat, self.q) # [N, JQ, d] self.tensor_dict['x'] = Ax self.tensor_dict['q'] = Aq if config.use_char_emb: xx = tf.concat(axis=3, values=[xx, Ax]) # [N, M, JX, di] qq = tf.concat(axis=2, values=[qq, Aq]) # [N, JQ, di] else: xx = Ax qq = Aq # highway network if config.highway: with tf.variable_scope("highway"): xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) tf.get_variable_scope().reuse_variables() qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train) self.tensor_dict['xx'] = xx self.tensor_dict['qq'] = qq cell_fw = BasicLSTMCell(d, state_is_tuple=True) cell_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell_fw = SwitchableDropoutWrapper( cell_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell_bw = SwitchableDropoutWrapper( cell_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell2_fw = BasicLSTMCell(d, state_is_tuple=True) cell2_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell2_fw = SwitchableDropoutWrapper( cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell2_bw = SwitchableDropoutWrapper( cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob) cell3_fw = BasicLSTMCell(d, state_is_tuple=True) cell3_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell3_fw = SwitchableDropoutWrapper( cell3_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell3_bw = SwitchableDropoutWrapper( cell3_bw, self.is_train, 
input_keep_prob=config.input_keep_prob) cell4_fw = BasicLSTMCell(d, state_is_tuple=True) cell4_bw = BasicLSTMCell(d, state_is_tuple=True) d_cell4_fw = SwitchableDropoutWrapper( cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob) d_cell4_bw = SwitchableDropoutWrapper( cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob) x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2) # [N, M] q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1) # [N] with tf.variable_scope("prepro"): (fw_u, bw_u), ((_, fw_u_f), (_, bw_u_f)) = bidirectional_dynamic_rnn( d_cell_fw, d_cell_bw, qq, q_len, dtype='float', scope='u1') # [N, J, d], [N, d] u = tf.concat(axis=2, values=[fw_u, bw_u]) if config.share_lstm_weights: tf.get_variable_scope().reuse_variables() (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='u1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] else: (fw_h, bw_h), _ = bidirectional_dynamic_rnn( cell_fw, cell_bw, xx, x_len, dtype='float', scope='h1') # [N, M, JX, 2d] h = tf.concat(axis=3, values=[fw_h, bw_h]) # [N, M, JX, 2d] self.tensor_dict['u'] = u self.tensor_dict['h'] = h with tf.variable_scope("main"): if config.dynamic_att: p0 = h u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d]) q_mask = tf.reshape( tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ]) first_cell_fw = AttentionCell( cell2_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) first_cell_bw = AttentionCell( cell2_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_fw = AttentionCell( cell3_fw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) second_cell_bw = AttentionCell( cell3_bw, u, mask=q_mask, mapper='sim', input_keep_prob=self.config.input_keep_prob, is_train=self.is_train) else: p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict) first_cell_fw = d_cell2_fw second_cell_fw = d_cell3_fw first_cell_bw = d_cell2_bw second_cell_bw = d_cell3_bw config.ruminating_layer = True if config.ruminating_layer: ''' RUMINATING LAYER ''' with tf.variable_scope('rum_layer'): print('-' * 5 + "RUMINATING LAYER" + '-' * 5) print("Context", xx) #[N,M,JX,2d] print("Question", qq) #[N,JQ,2d] print("p0", p0) #[N,M,JX,8D] sum_cell = BasicLSTMCell(d, state_is_tuple=True) (s_f, s_b), _ = bidirectional_dynamic_rnn(sum_cell, sum_cell, p0, x_len, dtype=tf.float32, scope="sum_layer") batch_lens = (tf.reshape(x_len, [N * M])) s_f = tf.reshape(s_f, [N * M, JX, d]) s_b = tf.reshape(s_b, [N * M, JX, d]) s_fout = tf.reshape(extract_axis_1(s_f, batch_lens), [N, M, d]) s_bout = tf.reshape(extract_axis_1(s_b, batch_lens), [N, M, d]) s = tf.concat(axis=2, values=[s_fout, s_bout]) # [N,M,2d] print("summarization layer", s) print('-' * 5 + "QUESTION RUMINATE LAYER" + '-' * 5) S_Q = tf.tile(tf.expand_dims(s, 2), [1, 1, JQ, 1]) # [N,M,JQ,2d] S_cell_fw = BasicLSTMCell(d, state_is_tuple=True) S_cell_bw = BasicLSTMCell(d, state_is_tuple=True) (fw_hq, bw_hq), _ = bidirectional_dynamic_rnn(S_cell_fw, S_cell_bw, S_Q, q_len, dtype=tf.float32, scope="S_Q") S_Q = tf.concat(axis=3, values=[fw_hq, bw_hq]) q_m = tf.reshape(tf.expand_dims(qq, 1), [N, M, JQ, 2 * d]) with tf.variable_scope("question_rum_layer"): Q_hat = ruminating_layer(S_Q, q_m, N, M, JQ, d) print("Q_hat", Q_hat) #[N,M,JQ,2d] 
print('-' * 5 + "CONTEXT RUMINATE LAYER" + '-' * 5) S_C = tf.tile(tf.expand_dims(s, 2), [1, 1, JX, 1]) # [N,M,JX,2d] C_cell_fw = BasicLSTMCell(d, state_is_tuple=True) C_cell_bw = BasicLSTMCell(d, state_is_tuple=True) (fw_h, bw_h), _ = bidirectional_dynamic_rnn(C_cell_fw, C_cell_bw, S_C, x_len, dtype=tf.float32, scope="S_C") S_C = tf.concat(axis=3, values=[fw_h, bw_h]) #[N,M,JX,2d] with tf.variable_scope("context_rum_layer"): C_hat = ruminating_layer(S_C, xx, N, M, JX, d) print("C_hat", C_hat) #[N,M,JX,2d] #Second Hop bi-Attention print('-' * 5 + "SECOND HOP ATTENTION" + '-' * 5) sh_aug = tf.tile(tf.expand_dims(C_hat, 3), [1, 1, 1, JQ, 1]) #[N,M,JX,2d] su_aug = tf.tile(tf.expand_dims(Q_hat, 2), [1, 1, JX, 1, 1]) #[N,M,JQ,2d] sh_mask_aug = tf.tile(tf.expand_dims(self.x_mask, -1), [1, 1, 1, JQ]) su_mask_aug = tf.tile( tf.expand_dims(tf.expand_dims(self.q_mask, 1), 1), [1, M, JX, 1]) shu_mask = sh_mask_aug & su_mask_aug su_logits = get_logits([sh_aug, su_aug], None, True, wd=config.wd, mask=shu_mask, is_train=True, func=config.logit_func, scope='su_logits') su_a = softsel(su_aug, su_logits) sh_a = softsel(C_hat, tf.reduce_max(su_logits, 3)) sh_a = tf.tile(tf.expand_dims(sh_a, 2), [1, 1, JX, 1]) p00 = tf.concat( axis=3, values=[C_hat, su_a, C_hat * su_a, C_hat * sh_a]) print("p00", p00) #[N,M,JX,8d] p0 = p00 print('-' * 5 + "END RUMINATING LAYER" + '-' * 5) (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn( first_cell_fw, first_cell_bw, p0, x_len, dtype='float', scope='g0') # [N, M, JX, 2d] g0 = tf.concat(axis=3, values=[fw_g0, bw_g0]) (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn( second_cell_fw, second_cell_bw, g0, x_len, dtype='float', scope='g1') # [N, M, JX, 2d] g1 = tf.concat(axis=3, values=[fw_g1, bw_g1]) logits = get_logits([g1, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits1') a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]), tf.reshape(logits, [N, M * JX])) a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1]) (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn( d_cell4_fw, d_cell4_bw, tf.concat(axis=3, values=[p0, g1, a1i, g1 * a1i]), x_len, dtype='float', scope='g2') # [N, M, JX, 2d] g2 = tf.concat(axis=3, values=[fw_g2, bw_g2]) logits2 = get_logits([g2, p0], d, True, wd=config.wd, input_keep_prob=config.input_keep_prob, mask=self.x_mask, is_train=self.is_train, func=config.answer_func, scope='logits2') flat_logits = tf.reshape(logits, [-1, M * JX]) flat_yp = tf.nn.softmax(flat_logits) # [-1, M*JX] flat_logits2 = tf.reshape(logits2, [-1, M * JX]) flat_yp2 = tf.nn.softmax(flat_logits2) self.tensor_dict['g1'] = g1 self.tensor_dict['g2'] = g2 if config.na: na_bias = tf.get_variable("na_bias", shape=[], dtype='float') na_bias_tiled = tf.tile(tf.reshape(na_bias, [1, 1]), [N, 1]) # [N, 1] concat_flat_logits = tf.concat( axis=1, values=[na_bias_tiled, flat_logits]) concat_flat_yp = tf.nn.softmax(concat_flat_logits) na_prob = tf.squeeze(tf.slice(concat_flat_yp, [0, 0], [-1, 1]), [1]) flat_yp = tf.slice(concat_flat_yp, [0, 1], [-1, -1]) concat_flat_logits2 = tf.concat( axis=1, values=[na_bias_tiled, flat_logits2]) concat_flat_yp2 = tf.nn.softmax(concat_flat_logits2) na_prob2 = tf.squeeze( tf.slice(concat_flat_yp2, [0, 0], [-1, 1]), [1]) # [N] flat_yp2 = tf.slice(concat_flat_yp2, [0, 1], [-1, -1]) self.concat_logits = concat_flat_logits self.concat_logits2 = concat_flat_logits2 self.na_prob = na_prob * na_prob2 yp = tf.reshape(flat_yp, [-1, M, JX]) yp2 = tf.reshape(flat_yp2, [-1, M, 
JX])
            wyp = tf.nn.sigmoid(logits2)
            self.logits = flat_logits
            self.logits2 = flat_logits2
            self.yp = yp
            self.yp2 = yp2
            self.wyp = wyp
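# --- Illustrative sketch (not the author's code) ---------------------------------
# Toy NumPy sketch of the config.na branch above: a learned scalar bias is prepended
# to the flattened span logits, one softmax is taken over [no-answer; all spans], and
# the first column becomes the no-answer probability. Shapes are illustrative only.
import numpy as np

def with_na_bias(flat_logits, na_bias):
    # flat_logits: [N, M*JX], na_bias: scalar
    concat = np.concatenate(
        [np.full((flat_logits.shape[0], 1), na_bias), flat_logits], axis=1)
    e = np.exp(concat - concat.max(axis=1, keepdims=True))
    yp = e / e.sum(axis=1, keepdims=True)
    return yp[:, 0], yp[:, 1:]          # na_prob [N], renormalized span probs [N, M*JX]

na_prob, span_yp = with_na_bias(np.random.randn(2, 6), na_bias=0.5)
print(na_prob.shape, span_yp.shape)                       # (2,) (2, 6)
print(np.allclose(na_prob + span_yp.sum(axis=1), 1.0))    # True
# ----------------------------------------------------------------------------------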