def _build_model_op(self):
    """Build the single-description scoring graph and store it in ``self.logits``.

    Steps:
      1. Look up ``self.sl``, ``self.sr`` and ``self.desc`` in one shared
         word-embedding table built from ``self.config.emb``.
      2. Encode ``self.desc_emb`` with a bidirectional RNN (last state) into
         ``self.de``.
      3. Splice ``self.de`` as a single extra time step between the left and
         right embeddings.
      4. Encode that sequence with a second bidirectional RNN into
         ``self.hi``.
      5. Score each example with the bilinear form ``hi_i W de_i + b_i`` and
         apply a sigmoid.

    Note: ``self.logits`` holds sigmoid outputs, not pre-activation logits.
    """
    with tf.variable_scope('embeddings'):
        _word_emb = tf.Variable(self.config.emb, name='_word_emb', dtype=tf.float32,
                                trainable=self.config.finetune_emb)
        self.sl_emb = tf.nn.embedding_lookup(_word_emb, self.sl, name='sent_left_emb')
        self.sr_emb = tf.nn.embedding_lookup(_word_emb, self.sr, name='sent_right_emb')
        self.desc_emb = tf.nn.embedding_lookup(_word_emb, self.desc, name='desc_emb')

    with tf.variable_scope('lexical_encoder'):
        lexical_rnn = bidirectional_dynamic_rnn(self.config.num_units, use_peepholes=True,
                                                scope='lexical_rnn')
        self.de = lexical_rnn(self.desc_emb, self.desc_seq_len, return_last_state=True,
                              keep_prob=self.keep_prob,
                              is_train=self.is_train)  # (batch_size, 2 * num_units)

    with tf.variable_scope('concat_sentence'):
        # The description encoding becomes one extra time step, so its last
        # dimension (2 * num_units, per the shape comment above) must match
        # the word-embedding size for this concat to succeed.
        de = tf.expand_dims(self.de, axis=1)
        s_emb = tf.concat([self.sl_emb, de, self.sr_emb], axis=1)  # (batch_size, seq_len, word_dim)

    with tf.variable_scope('context_encoder'):
        context_rnn = bidirectional_dynamic_rnn(self.config.num_units, use_peepholes=True,
                                                scope='context_rnn')
        self.hi = context_rnn(s_emb, self.sent_seq_len, return_last_state=True,
                              keep_prob=self.keep_prob,
                              is_train=self.is_train)  # (batch_size, 2 * num_units)

    with tf.variable_scope('project'):
        w = tf.get_variable(name='W',
                            shape=[2 * self.config.num_units, 2 * self.config.num_units],
                            dtype=tf.float32)
        # NOTE(review): this is one bias per batch *slot*, so an example's
        # score depends on its position in the batch, and the batch must be
        # exactly config.batch_size. The shape is kept as-is so existing
        # checkpoints still load — consider a scalar bias instead.
        b = tf.get_variable("b", shape=[self.config.batch_size, ], dtype=tf.float32,
                            initializer=tf.constant_initializer(0.))
        hi_w = tf.matmul(self.hi, w)  # (batch_size, 2 * num_units)
        # Row-wise dot product <hi_w[i], de[i]>. This equals the diagonal of
        # hi_w @ de^T without building the (batch_size, batch_size) matrix.
        scores = tf.reduce_sum(hi_w * self.de, axis=-1)  # (batch_size,)
        # Adding b element-wise matches diag_part(bias_add(hi_w @ de^T, b)),
        # because entry (i, i) of the bias-added matrix receives b[i].
        self.logits = tf.sigmoid(scores + b)
def _build_model_op(self):
    """Build a graph that scores several candidate descriptions per example.

    Assumes ``self.desc`` embeds to a tensor of static shape
    (batch_size, num_cands, desc_len, word_dim) — TODO confirm against the
    placeholder definitions. For every (example, candidate) pair the graph:
      1. encodes the candidate description with a bidirectional RNN
         (``self.de``: (batch_size, num_cands, 2 * desc_units));
      2. splices that encoding between the example's left and right
         embeddings and encodes the result with a second bidirectional RNN
         (``self.hi``: (batch_size, num_cands, 2 * num_units));
      3. scores the pair with a shared bilinear form ``hi W de``.

    ``self.logits`` is set to the raw scores, shape (batch_size, num_cands).
    """
    with tf.variable_scope('embeddings'):
        _word_emb = tf.Variable(self.config.emb, name='_word_emb', dtype=tf.float32,
                                trainable=self.config.finetune_emb)
        self.sl_emb = tf.nn.embedding_lookup(_word_emb, self.sl, name='sent_left_emb')
        self.sr_emb = tf.nn.embedding_lookup(_word_emb, self.sr, name='sent_right_emb')
        self.desc_emb = tf.nn.embedding_lookup(_word_emb, self.desc, name='desc_emb')

    with tf.variable_scope('lexical_encoder'):
        # Fold the candidate axis into the batch axis, so every description
        # is encoded independently by the same RNN.
        s = self.desc_emb.get_shape()
        desc_emb = tf.reshape(self.desc_emb, shape=[s[0] * s[1], s[2], s[-1]])
        desc_seq_len = tf.reshape(self.desc_seq_len, shape=[s[0] * s[1]])
        lexical_rnn = bidirectional_dynamic_rnn(self.config.desc_units, use_peepholes=True,
                                                scope='lexical_rnn')
        # NOTE(review): no keep_prob / is_train is passed, so any dropout the
        # helper supports is not applied here — confirm this is intended.
        de = lexical_rnn(desc_emb, desc_seq_len,
                         return_last_state=True)  # (batch_size * num_cands, 2 * desc_units)
        self.de = tf.reshape(
            de, shape=[s[0], s[1], 2 * self.config.desc_units])  # (batch_size, num_cands, 2 * desc_units)

    with tf.variable_scope('concat_sentence'):
        num_cands = self.config.num_cands
        de = tf.expand_dims(self.de, axis=2)  # (batch_size, num_cands, 1, 2 * desc_units)
        # Copy the left and right context once per candidate.
        sl_emb = tf.tile(tf.expand_dims(self.sl_emb, axis=1), [1, num_cands, 1, 1])
        sr_emb = tf.tile(tf.expand_dims(self.sr_emb, axis=1), [1, num_cands, 1, 1])
        s_emb = tf.concat([sl_emb, de, sr_emb], axis=-2)  # (batch_size, num_cands, seq_len, word_dim)

    with tf.variable_scope('context_encoder'):
        s = s_emb.get_shape().as_list()
        s_emb = tf.reshape(s_emb, shape=[s[0] * s[1], s[2], s[3]])
        # After the row-major reshape, row b * num_cands + c belongs to example
        # b. Each sentence length must therefore be repeated num_cands times
        # consecutively: [l0, l0, ..., l1, l1, ...]. Concatenating whole copies
        # of the vector would give [l0, l1, ..., l0, l1, ...] and pair rows
        # with the wrong lengths.
        sent_seq_len = tf.reshape(
            tf.tile(tf.expand_dims(self.sent_seq_len, axis=1), [1, s[1]]), [-1])
        context_rnn = bidirectional_dynamic_rnn(self.config.num_units, use_peepholes=True,
                                                scope='context_rnn')
        hi = context_rnn(s_emb, sent_seq_len, return_last_state=True)
        self.hi = tf.reshape(
            hi, shape=[s[0], s[1], 2 * self.config.num_units])  # (batch_size, num_cands, 2 * num_units)

    with tf.variable_scope('project'):
        # One bilinear weight shared by all examples. A per-batch-slot weight
        # would make the score depend on an example's position in the batch.
        w = tf.get_variable(name='W',
                            shape=[2 * self.config.num_units, 2 * self.config.desc_units],
                            dtype=tf.float32)
        hi_w = tf.tensordot(self.hi, w, axes=[[2], [0]])  # (batch_size, num_cands, 2 * desc_units)
        # Score each candidate only against its own description. This is the
        # diagonal of the (num_cands, num_cands) product per example, without
        # building that product. No bias is applied.
        self.logits = tf.reduce_sum(hi_w * self.de, axis=-1)  # (batch_size, num_cands)
def _build_model_op(self):
    """Build a graph that scores three candidate descriptions per example.

    For each candidate k in {1, 2, 3}:
      1. ``self.desc_ck`` is embedded (``self.desc_emb_ck``) and encoded with
         the lexical bidirectional RNN into ``self.dek`` (batch_size, 2 * num_units).
      2. That encoding is spliced as one time step between the left and right
         embeddings and encoded with the context RNN into ``self.hi_ck``.
      3. The pair is scored as ``(hi_ck W + b) . de_k``.

    ``self.logits`` holds the sigmoid of the three scores, shape
    (batch_size, 3) — probabilities, not pre-activation logits.
    """
    with tf.variable_scope('embeddings'):
        _word_emb = tf.Variable(self.config.emb, name='_word_emb', dtype=tf.float32,
                                trainable=self.config.finetune_emb)
        self.sl_emb = tf.nn.embedding_lookup(_word_emb, self.sl, name='sent_left_emb')
        self.sr_emb = tf.nn.embedding_lookup(_word_emb, self.sr, name='sent_right_emb')
        self.desc_emb_c1 = tf.nn.embedding_lookup(_word_emb, self.desc_c1, name='desc_emb')
        self.desc_emb_c2 = tf.nn.embedding_lookup(_word_emb, self.desc_c2, name='desc_emb')
        self.desc_emb_c3 = tf.nn.embedding_lookup(_word_emb, self.desc_c3, name='desc_emb')

    with tf.variable_scope('lexical_encoder'):
        # NOTE(review): the same helper object is called three times; this
        # assumes bidirectional_dynamic_rnn shares/reuses its variables across
        # calls — verify against its implementation.
        lexical_rnn = bidirectional_dynamic_rnn(self.config.num_units, use_peepholes=True,
                                                scope='lexical_rnn')
        self.de1 = lexical_rnn(self.desc_emb_c1, self.desc_seq_len_c1, return_last_state=True,
                               keep_prob=self.keep_prob,
                               is_train=self.is_train)  # (batch_size, 2 * num_units)
        self.de2 = lexical_rnn(self.desc_emb_c2, self.desc_seq_len_c2, return_last_state=True,
                               keep_prob=self.keep_prob,
                               is_train=self.is_train)  # (batch_size, 2 * num_units)
        self.de3 = lexical_rnn(self.desc_emb_c3, self.desc_seq_len_c3, return_last_state=True,
                               keep_prob=self.keep_prob,
                               is_train=self.is_train)  # (batch_size, 2 * num_units)

    with tf.variable_scope('concat_sentence'):
        # Each description encoding becomes one extra time step between the
        # left and right context embeddings.
        de1 = tf.expand_dims(self.de1, axis=1)
        de2 = tf.expand_dims(self.de2, axis=1)
        de3 = tf.expand_dims(self.de3, axis=1)
        s_emb1 = tf.concat([self.sl_emb, de1, self.sr_emb], axis=1)  # (batch_size, seq_len, word_dim)
        s_emb2 = tf.concat([self.sl_emb, de2, self.sr_emb], axis=1)  # (batch_size, seq_len, word_dim)
        s_emb3 = tf.concat([self.sl_emb, de3, self.sr_emb], axis=1)  # (batch_size, seq_len, word_dim)

    with tf.variable_scope('context_encoder'):
        context_rnn = bidirectional_dynamic_rnn(self.config.num_units, use_peepholes=True,
                                                scope='context_rnn')
        self.hi_c1 = context_rnn(s_emb1, self.sent_seq_len, return_last_state=True,
                                 keep_prob=self.keep_prob,
                                 is_train=self.is_train)  # (batch_size, 2 * num_units)
        self.hi_c2 = context_rnn(s_emb2, self.sent_seq_len, return_last_state=True,
                                 keep_prob=self.keep_prob,
                                 is_train=self.is_train)  # (batch_size, 2 * num_units)
        self.hi_c3 = context_rnn(s_emb3, self.sent_seq_len, return_last_state=True,
                                 keep_prob=self.keep_prob,
                                 is_train=self.is_train)  # (batch_size, 2 * num_units)

    with tf.variable_scope('project'):
        W = tf.get_variable(
            name='W', shape=[2 * self.config.num_units, 2 * self.config.num_units],
            dtype=tf.float32)
        b = tf.get_variable("b", shape=[2 * self.config.num_units], dtype=tf.float32,
                            initializer=tf.constant_initializer(0.))

        def _score(hi, de):
            """Return the bilinear score (hi W + b) . de per example, shape (batch_size, 1)."""
            projected = tf.nn.bias_add(tf.matmul(hi, W), b)  # (batch_size, 2 * num_units)
            # Row-wise dot product: example i is paired only with its own
            # description de[i]. A matmul against transpose(de) followed by a
            # reduce_sum would instead sum example i's scores against every
            # description in the batch.
            return tf.reshape(tf.reduce_sum(projected * de, axis=-1), shape=[-1, 1])

        self.logits = tf.concat(
            [tf.sigmoid(_score(self.hi_c1, self.de1)),
             tf.sigmoid(_score(self.hi_c2, self.de2)),
             tf.sigmoid(_score(self.hi_c3, self.de3))],
            axis=1)  # (batch_size, 3)