Ejemplo n.º 1
0
  def call(self, inputs, training=None, mask=None):  # pylint: disable=too-many-locals
    """Run the hierarchical forward pass and return the class scores.

    Args:
      inputs: mapping that must contain "input_x"; "input_dense" is read as
        well when dense features are enabled.
      training: forwarded to `self.embed_d` and to the sentence/document
        encoders.
      mask: not used by this method; the sentence and document masks are
        computed from the padded input instead.

    Returns:
      The output of `self.final_dense`, annotated as [batch_size, class_num].
    """
    input_x = inputs["input_x"]
    # Bind the name up front: the read below is guarded by `use_dense_task`
    # while the consumer further down is guarded by `use_dense_input`, so the
    # variable could otherwise be unbound when the two flags disagree.
    dense_input = None
    if self.use_dense_task:
      dense_input = inputs["input_dense"]

    if self.use_true_length:
      # [batch_size, max_doc_len, max_sen_len]
      input_hx = self.pad_to_hier_input_true_len(
          input_x,
          self.max_doc_len,
          self.max_sen_len,
          self.split_token,
          padding_token=self.padding_token)
    else:
      # [batch_size, max_doc_len, max_sen_len]
      input_hx = self.pad_to_hier_input(
          input_x,
          self.max_doc_len,
          self.max_sen_len,
          padding_token=self.padding_token)

    # [batch_size, max_doc_len]
    sen_lens = compute_sen_lens(input_hx, padding_token=self.padding_token)
    # [batch_size]
    doc_lens = compute_doc_lens(sen_lens)
    # [batch_size, max_doc_len, max_sen_len, 1]
    sen_mask = tf.expand_dims(
        tf.sequence_mask(sen_lens, self.max_sen_len, dtype=tf.float32), axis=-1)

    # [batch_size, max_doc_len, 1]
    doc_mask = tf.expand_dims(
        tf.sequence_mask(doc_lens, self.max_doc_len, dtype=tf.float32), axis=-1)

    # [batch_size, max_doc_len, max_sen_len, embed_len]
    out = self.embed(input_hx)
    if self.use_pretrained_model:
      # Pretrained features are reshaped to the hierarchical layout and
      # concatenated with the embeddings along the feature axis.
      input_px = self.get_pre_train_graph(input_x)
      input_px = tf.reshape(input_px, [-1, self.max_doc_len,
                                       self.max_sen_len, self.pretrained_model_dim])
      out = tf.concat([out, input_px], axis=-1)
    out = self.embed_d(out, training=training)
    # Apply the sentence encoder to every sentence slot of the document.
    all_sen_encoder = tf.keras.layers.TimeDistributed(self.sen_encoder)
    # [batch_size, max_doc_len, features]
    out = all_sen_encoder(out, training=training, mask=sen_mask)
    # [batch_size, features]
    out = self.doc_encoder(out, training=training, mask=doc_mask)

    if self.use_dense_input:
      if dense_input is None:
        # Dense features were not read above (`use_dense_task` was off);
        # fetch them now instead of failing on an unbound local.
        dense_input = inputs["input_dense"]
      dense_out = self.dense_input_linear(dense_input)
      if self.only_dense_input:
        out = dense_out
      else:
        out = tf.keras.layers.Concatenate()([out, dense_out])

    # [batch_size, class_num]
    scores = self.final_dense(out)

    return scores
Ejemplo n.º 2
0
    def test_compute_doc_lens(self):
        """Check compute_doc_lens against a rank-1 and a rank-2 input."""
        doc_ph = tf.placeholder(dtype=tf.int32)
        doc_len_op = compute_doc_lens(doc_ph)

        single_doc = [1, 2, 0, 0]
        doc_batch = [[1, 2, 0, 0], [1, 2, 3, 4]]

        with self.cached_session(use_gpu=False, force_gpu=False) as session:
            # Rank-1 feed: expected length is 2.
            single_len = session.run(doc_len_op, feed_dict={doc_ph: single_doc})
            self.assertEqual(single_len, 2)

            # Rank-2 feed: one expected length per row.
            batch_lens = session.run(doc_len_op, feed_dict={doc_ph: doc_batch})
            self.assertAllEqual(batch_lens, [2, 4])
Ejemplo n.º 3
0
    def test_compute_doc_lens(self):
        """Run compute_doc_lens on 1-D and 2-D feeds, logging each result."""
        placeholder = tf.placeholder(dtype=tf.int32)
        lens_tensor = compute_doc_lens(placeholder)

        with self.session() as sess:
            # 1-D case: a single document.
            feed_1d = {placeholder: [1, 2, 0, 0]}
            out_1d = sess.run(lens_tensor, feed_dict=feed_1d)
            logging.info(out_1d)
            self.assertEqual(out_1d, 2)

            # 2-D case: two documents evaluated together.
            feed_2d = {placeholder: [[1, 2, 0, 0], [1, 2, 3, 4]]}
            out_2d = sess.run(lens_tensor, feed_dict=feed_2d)
            logging.info(out_2d)
            self.assertAllEqual(out_2d, [2, 4])