def call(self, inputs, training=None, mask=None):  # pylint: disable=too-many-locals
    """Run the hierarchical forward pass and return the class scores.

    The flat token input is padded into a document/sentence layout,
    embedded, encoded sentence by sentence, then encoded at document
    level. Optionally a dense side input is projected and either replaces
    or is concatenated with the document representation before the final
    dense layer.

    Args:
        inputs: mapping that holds "input_x" and, when
            `self.use_dense_task` is set, "input_dense".
        training: forwarded to the embedding dropout and both encoders.
        mask: accepted for interface compatibility; not read here.

    Returns:
        The output of `self.final_dense`, commented in the original code
        as [batch_size, class_num].
    """
    tokens = inputs["input_x"]
    # NOTE(review): the dense input is only bound when `use_dense_task` is
    # set, but it is consumed below when `use_dense_input` is set -- confirm
    # that the two flags are always enabled together.
    if self.use_dense_task:
        dense_features = inputs["input_dense"]

    # Both branches yield [batch_size, max_doc_len, max_sen_len].
    if self.use_true_length:
        hier_tokens = self.pad_to_hier_input_true_len(
            tokens,
            self.max_doc_len,
            self.max_sen_len,
            self.split_token,
            padding_token=self.padding_token)
    else:
        hier_tokens = self.pad_to_hier_input(
            tokens,
            self.max_doc_len,
            self.max_sen_len,
            padding_token=self.padding_token)

    # [batch_size, max_doc_len]
    sentence_lengths = compute_sen_lens(
        hier_tokens, padding_token=self.padding_token)
    # [batch_size]
    document_lengths = compute_doc_lens(sentence_lengths)

    # [batch_size, max_doc_len, max_sen_len, 1]
    sentence_steps = tf.sequence_mask(
        sentence_lengths, self.max_sen_len, dtype=tf.float32)
    sentence_mask = tf.expand_dims(sentence_steps, axis=-1)

    # [batch_size, max_doc_len, 1]
    document_steps = tf.sequence_mask(
        document_lengths, self.max_doc_len, dtype=tf.float32)
    document_mask = tf.expand_dims(document_steps, axis=-1)

    # [batch_size, max_doc_len, max_sen_len, embed_len]
    features = self.embed(hier_tokens)
    if self.use_pretrained_model:
        # Pretrained features are computed from the flat input, then folded
        # into the same document/sentence layout as the embeddings.
        pretrained = self.get_pre_train_graph(tokens)
        pretrained_shape = [
            -1, self.max_doc_len, self.max_sen_len, self.pretrained_model_dim
        ]
        pretrained = tf.reshape(pretrained, pretrained_shape)
        features = tf.concat([features, pretrained], axis=-1)
    features = self.embed_d(features, training=training)

    # Apply the sentence encoder to every sentence slot of a document.
    per_sentence_encoder = tf.keras.layers.TimeDistributed(self.sen_encoder)
    # [batch_size, max_doc_len, features]
    features = per_sentence_encoder(
        features, training=training, mask=sentence_mask)
    # [batch_size, features]
    features = self.doc_encoder(
        features, training=training, mask=document_mask)

    if self.use_dense_input:
        dense_projection = self.dense_input_linear(dense_features)
        # Either use the dense projection alone or append it to the
        # document representation.
        features = (
            dense_projection
            if self.only_dense_input
            else tf.keras.layers.Concatenate()([features, dense_projection]))

    # [batch_size, class_num]
    return self.final_dense(features)
def test_compute_doc_lens(self):
    """Check compute_doc_lens on a 1-D and a 2-D input fed via a placeholder."""
    doc_ph = tf.placeholder(dtype=tf.int32)
    doc_lens_op = compute_doc_lens(doc_ph)

    with self.cached_session(use_gpu=False, force_gpu=False) as sess:
        # 1-D input: a single document, expected length 2.
        single_doc = [1, 2, 0, 0]
        single_len = sess.run(doc_lens_op, feed_dict={doc_ph: single_doc})
        self.assertEqual(single_len, 2)

        # 2-D input: a batch of two documents, expected lengths 2 and 4.
        doc_batch = [[1, 2, 0, 0], [1, 2, 3, 4]]
        batch_lens = sess.run(doc_lens_op, feed_dict={doc_ph: doc_batch})
        self.assertAllEqual(batch_lens, [2, 4])
def test_compute_doc_lens(self):
    """Run compute_doc_lens on 1-D and 2-D feeds, logging each result."""
    doc_ph = tf.placeholder(dtype=tf.int32)
    doc_lens_op = compute_doc_lens(doc_ph)

    # (fed value, expected result, assertion used for the comparison)
    cases = (
        # 1-D: a single document
        ([1, 2, 0, 0], 2, self.assertEqual),
        # 2-D: a batch of two documents
        ([[1, 2, 0, 0], [1, 2, 3, 4]], [2, 4], self.assertAllEqual),
    )

    with self.session() as sess:
        for fed_docs, expected, check in cases:
            actual = sess.run(doc_lens_op, feed_dict={doc_ph: fed_docs})
            logging.info(actual)
            check(actual, expected)