def test_basic_invocation_train(self):
        """Checks the output keys when the layer is called with `training=True`."""
        batch, length, width = 2, 8, 4

        # NumPy inputs: float32 for the features and the mask, uint8 for the
        # class index and the start positions.
        features = np.random.uniform(
            size=(batch, length, width)).astype('float32')
        mask = np.random.uniform(size=(batch, length)).astype('float32')
        cls_idx = np.random.uniform(size=batch).astype('uint8')
        starts = np.zeros(shape=batch).astype('uint8')

        layer_kwargs = {
            'input_width': width,
            'start_n_top': 2,
            'end_n_top': 2,
            'activation': 'tanh',
            'dropout_rate': 0.,
            'initializer': 'glorot_uniform',
        }
        span_layer = span_labeling.XLNetSpanLabeling(**layer_kwargs)

        result = span_layer(
            sequence_data=features,
            class_index=cls_idx,
            paragraph_mask=mask,
            start_positions=starts,
            training=True,
        )

        # The call above must return exactly these entries.
        wanted = {
            'start_logits',
            'end_logits',
            'class_logits',
            'start_predictions',
            'end_predictions',
        }
        self.assertSetEqual(wanted, set(result.keys()))
Пример #2
0
    def test_basic_invocation_beam_search(self, top_n):
        """Checks the output keys when the layer is called with `training=False`.

        Args:
            top_n: Value passed to the layer as both `start_n_top` and
                `end_n_top`.
        """
        batch, length, width = 2, 8, 4

        features = np.random.uniform(
            size=(batch, length, width)).astype('float32')
        mask = np.random.uniform(size=(batch, length)).astype('float32')
        cls_idx = np.random.uniform(size=batch).astype('uint8')

        layer_kwargs = {
            'input_width': width,
            'start_n_top': top_n,
            'end_n_top': top_n,
            'activation': 'tanh',
            'dropout_rate': 0.,
            'initializer': 'glorot_uniform',
        }
        span_layer = span_labeling.XLNetSpanLabeling(**layer_kwargs)

        result = span_layer(
            sequence_data=features,
            class_index=cls_idx,
            position_mask=mask,
            training=False,
        )

        # The call above must return exactly these entries.
        wanted = {
            'start_top_log_probs',
            'end_top_log_probs',
            'class_logits',
            'start_top_index',
            'end_top_index',
        }
        self.assertSetEqual(wanted, set(result.keys()))
Пример #3
0
    def test_subclass_invocation(self):
        """Tests invoking the layer through a `tf.keras.Model` built on Keras inputs.

        The model is called without a training flag and with
        `training=False`; both calls must return a dict. Calling it with
        `training=True` is expected to raise an `AssertionError`.

        NOTE(review): the test name says "subclass", but the model below is
        built with `tf.keras.Model(inputs=..., outputs=...)` -- confirm the
        intended coverage.
        """
        seq_length = 8
        hidden_size = 4
        batch_size = 2

        # Symbolic inputs used to wire the layer into a model.
        sequence_data = tf.keras.Input(shape=(seq_length, hidden_size),
                                       dtype=tf.float32)
        class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
        # One-element tuple: `(seq_length)` without the comma is just an int.
        position_mask = tf.keras.Input(shape=(seq_length,), dtype=tf.float32)
        start_positions = tf.keras.Input(shape=(), dtype=tf.int32)

        layer = span_labeling.XLNetSpanLabeling(input_width=hidden_size,
                                                start_n_top=5,
                                                end_n_top=5,
                                                activation='tanh',
                                                dropout_rate=0.,
                                                initializer='glorot_uniform')

        output = layer(sequence_data=sequence_data,
                       class_index=class_index,
                       position_mask=position_mask,
                       start_positions=start_positions)
        model = tf.keras.Model(inputs={
            'sequence_data': sequence_data,
            'class_index': class_index,
            'position_mask': position_mask,
            'start_positions': start_positions,
        },
                               outputs=output)

        # Concrete tensors fed to the model; the names are rebound on purpose
        # so the dict below mirrors the model's input keys.
        sequence_data = tf.random.uniform(shape=(batch_size, seq_length,
                                                 hidden_size),
                                          dtype=tf.float32)
        position_mask = tf.random.uniform(shape=(batch_size, seq_length),
                                          dtype=tf.float32)
        class_index = tf.ones(shape=(batch_size, ), dtype=tf.uint8)
        start_positions = tf.random.uniform(shape=(batch_size, ),
                                            maxval=5,
                                            dtype=tf.int32)

        inputs = dict(sequence_data=sequence_data,
                      position_mask=position_mask,
                      class_index=class_index,
                      start_positions=start_positions)

        # Test `call` with the default training argument.
        output = model(inputs)
        self.assertIsInstance(output, dict)

        # Test `call` without training flag.
        output = model(inputs, training=False)
        self.assertIsInstance(output, dict)

        # Test `call` with training flag.
        # Note: this fails due to incompatibility with the functional API.
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which was removed from `unittest` in Python 3.12.
        with self.assertRaisesRegex(AssertionError,
                                    'Could not compute output KerasTensor'):
            model(inputs, training=True)
Пример #4
0
  def test_functional_model_invocation(self):
    """Checks that a Functional model wrapping this layer returns a dict.

    The model is invoked once through `__call__` and then through `call`
    with `training=True` and `training=False`.
    """
    seq_length, hidden_size, batch_size = 8, 4, 2

    # Symbolic inputs, keyed by the layer's keyword-argument names.
    symbolic_inputs = {
        'sequence_data': tf.keras.Input(
            shape=(seq_length, hidden_size), dtype=tf.float32),
        'class_index': tf.keras.Input(shape=(), dtype=tf.uint8),
        'position_mask': tf.keras.Input(
            shape=(seq_length), dtype=tf.float32),
        'start_positions': tf.keras.Input(shape=(), dtype=tf.float32),
    }

    span_layer = span_labeling.XLNetSpanLabeling(
        input_width=hidden_size,
        start_n_top=5,
        end_n_top=5,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')

    symbolic_output = span_layer(
        sequence_data=symbolic_inputs['sequence_data'],
        class_index=symbolic_inputs['class_index'],
        position_mask=symbolic_inputs['position_mask'],
        start_positions=symbolic_inputs['start_positions'])
    model = tf.keras.Model(
        inputs={
            'sequence_data': symbolic_inputs['sequence_data'],
            'class_index': symbolic_inputs['class_index'],
            'position_mask': symbolic_inputs['position_mask'],
            'start_positions': symbolic_inputs['start_positions'],
        },
        outputs=symbolic_output)

    # Concrete tensors fed to the model.
    inputs = {
        'sequence_data': tf.random.uniform(
            shape=(batch_size, seq_length, hidden_size), dtype=tf.float32),
        'position_mask': tf.random.uniform(
            shape=(batch_size, seq_length), dtype=tf.float32),
        'class_index': tf.ones(shape=(batch_size,), dtype=tf.uint8),
        'start_positions': tf.random.uniform(
            shape=(batch_size,), dtype=tf.float32),
    }

    # Plain invocation with the default training argument.
    self.assertIsInstance(model(inputs), dict)

    # Direct `call` invocations: first with training on, then off.
    for training_flag in (True, False):
      result = model.call(inputs, training=training_flag)
      self.assertIsInstance(result, dict)
Пример #5
0
    def test_serialize_deserialize(self):
        """Checks that a layer rebuilt via `from_config` reports the same config."""
        # Every constructor option is set explicitly.
        layer_kwargs = {
            'input_width': 128,
            'start_n_top': 5,
            'end_n_top': 1,
            'activation': 'tanh',
            'dropout_rate': 0.34,
            'initializer': 'zeros',
        }
        original = span_labeling.XLNetSpanLabeling(**layer_kwargs)

        # Rebuild a second layer from the first one's config.
        layer_cls = span_labeling.XLNetSpanLabeling
        restored = layer_cls.from_config(original.get_config())

        # The two configs must be equal.
        self.assertAllEqual(original.get_config(), restored.get_config())