Example #1
0
    def build_model(self):
        """Co-attention between two sentence encodings, compare, and pool.

        Builds a (B, L1, L2) similarity matrix between ``self.sent1`` and a
        dense projection of ``self.sent2``, computes attention-weighted
        summaries of each sentence from the other's point of view, compares
        them with subtract/multiply features through a shared dense layer,
        and max-pools each side into a single vector.

        Returns:
            A (B, dim) tensor produced by a gelu-activated dense layer,
            where dim is the last dimension of ``self.sent1``.
        """
        from tensorflow.python.keras.layers import Dot

        # Project sent2 into sent1's encoding size before the similarity dot.
        dim = self.sent1.get_shape().as_list()[-1]
        temp_W = tf.layers.dense(self.sent2, dim, name="dense")  # (B, L2, dim)
        temp_W = Dot(axes=[2, 2])([self.sent1, temp_W])  # (B, L1, L2)

        if self.sent1_mask is not None:
            # Push padded positions toward -inf before each softmax so they
            # receive ~zero attention weight.
            s1_mask_exp = tf.expand_dims(self.sent1_mask, axis=2)  # (B, L1, 1)
            s2_mask_exp = tf.expand_dims(self.sent2_mask, axis=1)  # (B, 1, L2)
            temp_W1 = temp_W - (1 - s1_mask_exp) * 1e20
            temp_W2 = temp_W - (1 - s2_mask_exp) * 1e20
        else:
            temp_W1 = temp_W
            temp_W2 = temp_W

        W1 = tf.nn.softmax(temp_W1, axis=1)  # attention over sent1 positions
        W2 = tf.nn.softmax(temp_W2, axis=2)  # attention over sent2 positions

        # M1: sent2 summarized for each sent1 position -> (B, L1, dim).
        M1 = Dot(axes=[2, 1])([W2, self.sent2])
        # M2: sent1 summarized for each sent2 position -> (B, L2, dim).
        # BUGFIX: contract over the shared L1 axis (axes=[1, 1]); the previous
        # axes=[2, 1] contracted W1's L2 axis against sent1's L1 axis, which
        # only type-checked when L1 == L2 and mismatched the pairing of M2
        # with sent2 below.
        M2 = Dot(axes=[1, 1])([W1, self.sent1])

        s1_cat = tf.concat([M2 - self.sent2, M2 * self.sent2], axis=-1)
        s2_cat = tf.concat([M1 - self.sent1, M1 * self.sent1], axis=-1)

        # One shared comparison projection for both directions (reuse=True).
        S1 = tf.layers.dense(s1_cat,
                             dim,
                             activation=tf.nn.relu,
                             name="cat_dense")
        S2 = tf.layers.dense(s2_cat,
                             dim,
                             activation=tf.nn.relu,
                             name="cat_dense",
                             reuse=True)

        if self.is_training:
            # BUGFIX: dropout was applied to S1 twice and to S2 never.
            S1 = dropout(S1, dropout_prob=0.1)
            S2 = dropout(S2, dropout_prob=0.1)

        if self.sent1_mask is not None:
            # Zero padded rows so max-pooling cannot pick them up.
            # S1 has L2 rows (built from M2/sent2), hence the sent2 mask.
            S2 = S2 * tf.expand_dims(self.sent1_mask, axis=2)
            S1 = S1 * tf.expand_dims(self.sent2_mask, axis=2)

        C1 = tf.reduce_max(S1, axis=1)
        C2 = tf.reduce_max(S2, axis=1)

        C_cat = tf.concat([C1, C2], axis=1)

        return gelu(tf.layers.dense(C_cat, dim))
Example #2
0
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings):
  """Builds a span-prediction head on top of BERT.

  Runs BERT over the inputs, then feeds every token's final hidden state
  through a three-layer MLP head (widths 768 -> 384 -> 2) with gelu
  activations and dropout between the hidden layers, and splits the two
  output channels into start/end logits.

  Returns:
    A (start_logits, end_logits) tuple, each of shape
    (batch_size, seq_length).
  """
  bert = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  sequence_output = bert.get_sequence_output()
  batch_size, seq_length, hidden_size = modeling.get_shape_list(
      sequence_output, expected_rank=3)

  # Dropout is active only during training.
  keep_prob = 0.7 if is_training else 1.0

  # Head layer specs: (weight name, bias name, out_dim, in_dim). The
  # variable names (and creation order) are kept exactly as before so
  # existing checkpoints continue to load.
  layer_specs = (
      ("cls/squad/output_weights1", "cls/squad/output_bias1", 768,
       hidden_size),
      ("cls/squad/output_weights2", "cls/squad/output_bias2", 384, 768),
      ("cls/squad/output_weights3", "cls/squad/output_bias3", 2, 384),
  )
  layer_params = []
  for w_name, b_name, out_dim, in_dim in layer_specs:
    weights = tf.get_variable(
        w_name, [out_dim, in_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    bias = tf.get_variable(
        b_name, [out_dim], initializer=tf.zeros_initializer())
    layer_params.append((weights, bias))

  # Flatten tokens so each head layer is a single matmul.
  activations = tf.reshape(sequence_output,
                           [batch_size * seq_length, hidden_size])

  # Hidden layers: linear -> gelu -> dropout.
  for weights, bias in layer_params[:-1]:
    activations = tf.nn.bias_add(
        tf.matmul(activations, weights, transpose_b=True), bias)
    activations = tf.nn.dropout(modeling.gelu(activations), keep_prob)

  # Output layer: plain linear projection to the 2 span channels.
  out_w, out_b = layer_params[-1]
  activations = tf.nn.bias_add(
      tf.matmul(activations, out_w, transpose_b=True), out_b)

  # (batch*seq, 2) -> (2, batch, seq), then split the channel axis.
  span_logits = tf.transpose(
      tf.reshape(activations, [batch_size, seq_length, 2]), [2, 0, 1])
  start_logits, end_logits = tf.unstack(span_logits, axis=0)

  return (start_logits, end_logits)