def test_basic_invocation_train(self):
    """Checks the output keys of a direct layer call with `training=True`.

    Random numpy inputs plus `start_positions` are fed to an
    `XLNetSpanLabeling` layer, and the keys of the returned mapping are
    compared against the expected set.

    NOTE(review): the mask is passed as `paragraph_mask` here, while the
    other tests in this file pass `position_mask` -- confirm which keyword
    the layer's `call` actually accepts.
    """
    num_examples, num_tokens, width = 2, 8, 4

    features = np.random.uniform(
        size=(num_examples, num_tokens, width)).astype('float32')
    mask = np.random.uniform(
        size=(num_examples, num_tokens)).astype('float32')
    # Samples are drawn from [0, 1), so the uint8 cast produces zeros.
    cls_index = np.random.uniform(size=num_examples).astype('uint8')
    starts = np.zeros(shape=num_examples).astype('uint8')

    span_layer = span_labeling.XLNetSpanLabeling(
        input_width=width,
        start_n_top=2,
        end_n_top=2,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')
    outputs = span_layer(
        sequence_data=features,
        class_index=cls_index,
        paragraph_mask=mask,
        start_positions=starts,
        training=True)

    wanted_keys = {
        'start_logits',
        'end_logits',
        'class_logits',
        'start_predictions',
        'end_predictions',
    }
    self.assertSetEqual(wanted_keys, set(outputs.keys()))
def test_basic_invocation_beam_search(self, top_n):
    """Checks the output keys of a direct layer call with `training=False`.

    Args:
        top_n: Value passed to the layer as both `start_n_top` and
            `end_n_top`.

    NOTE(review): no parameterization decorator is visible on this method
    in the reviewed excerpt -- confirm that something supplies `top_n`.
    """
    num_examples, num_tokens, width = 2, 8, 4

    features = np.random.uniform(
        size=(num_examples, num_tokens, width)).astype('float32')
    mask = np.random.uniform(
        size=(num_examples, num_tokens)).astype('float32')
    # Samples are drawn from [0, 1), so the uint8 cast produces zeros.
    cls_index = np.random.uniform(size=num_examples).astype('uint8')

    span_layer = span_labeling.XLNetSpanLabeling(
        input_width=width,
        start_n_top=top_n,
        end_n_top=top_n,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')
    outputs = span_layer(
        sequence_data=features,
        class_index=cls_index,
        position_mask=mask,
        training=False)

    wanted_keys = {
        'start_top_log_probs',
        'end_top_log_probs',
        'class_logits',
        'start_top_index',
        'end_top_index',
    }
    self.assertSetEqual(wanted_keys, set(outputs.keys()))
def test_subclass_invocation(self):
    """Tests invocation of this layer wrapped in a `tf.keras.Model`.

    The layer is applied to symbolic `tf.keras.Input` tensors and a model is
    built from those inputs and the layer output. The model is then called
    on concrete tensors three ways: without a `training` argument, with
    `training=False`, and with `training=True`. The last call is expected
    to raise.
    """
    seq_length = 8
    hidden_size = 4
    batch_size = 2

    sequence_data = tf.keras.Input(
        shape=(seq_length, hidden_size), dtype=tf.float32)
    class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
    # `shape` should be a tuple; `(seq_length)` is only the int `seq_length`.
    position_mask = tf.keras.Input(shape=(seq_length,), dtype=tf.float32)
    start_positions = tf.keras.Input(shape=(), dtype=tf.int32)

    layer = span_labeling.XLNetSpanLabeling(
        input_width=hidden_size,
        start_n_top=5,
        end_n_top=5,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')
    output = layer(
        sequence_data=sequence_data,
        class_index=class_index,
        position_mask=position_mask,
        start_positions=start_positions)
    model = tf.keras.Model(
        inputs={
            'sequence_data': sequence_data,
            'class_index': class_index,
            'position_mask': position_mask,
            'start_positions': start_positions,
        },
        outputs=output)

    # Replace the symbolic inputs with concrete tensors of matching shapes.
    sequence_data = tf.random.uniform(
        shape=(batch_size, seq_length, hidden_size), dtype=tf.float32)
    position_mask = tf.random.uniform(
        shape=(batch_size, seq_length), dtype=tf.float32)
    class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
    start_positions = tf.random.uniform(
        shape=(batch_size,), maxval=5, dtype=tf.int32)

    inputs = dict(
        sequence_data=sequence_data,
        position_mask=position_mask,
        class_index=class_index,
        start_positions=start_positions)

    # Call without passing the training flag at all.
    output = model(inputs)
    self.assertIsInstance(output, dict)

    # Call with the training flag explicitly disabled.
    output = model(inputs, training=False)
    self.assertIsInstance(output, dict)

    # Call with the training flag enabled.
    # Note: this fails due to incompatibility with the functional API.
    # `assertRaisesRegex` is used because the `assertRaisesRegexp` alias is
    # deprecated and was removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(
            AssertionError, 'Could not compute output KerasTensor'):
        model(inputs, training=True)
def test_functional_model_invocation(self):
    """Tests basic invocation of this layer wrapped by a Functional model.

    The layer is applied to symbolic `tf.keras.Input` tensors, a
    `tf.keras.Model` is built from them, and the model is run on concrete
    tensors via `model(...)` and via `model.call(...)` with `training` set
    to both `True` and `False`. Each result must be a `dict`.
    """
    seq_length = 8
    hidden_size = 4
    batch_size = 2

    sequence_data = tf.keras.Input(
        shape=(seq_length, hidden_size), dtype=tf.float32)
    class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
    # `shape` should be a tuple; `(seq_length)` is only the int `seq_length`.
    position_mask = tf.keras.Input(shape=(seq_length,), dtype=tf.float32)
    # NOTE(review): `start_positions` is float32 here -- confirm the layer
    # accepts non-integer start positions.
    start_positions = tf.keras.Input(shape=(), dtype=tf.float32)

    layer = span_labeling.XLNetSpanLabeling(
        input_width=hidden_size,
        start_n_top=5,
        end_n_top=5,
        activation='tanh',
        dropout_rate=0.,
        initializer='glorot_uniform')
    output = layer(
        sequence_data=sequence_data,
        class_index=class_index,
        position_mask=position_mask,
        start_positions=start_positions)
    model = tf.keras.Model(
        inputs={
            'sequence_data': sequence_data,
            'class_index': class_index,
            'position_mask': position_mask,
            'start_positions': start_positions,
        },
        outputs=output)

    # Replace the symbolic inputs with concrete tensors of matching shapes.
    sequence_data = tf.random.uniform(
        shape=(batch_size, seq_length, hidden_size), dtype=tf.float32)
    position_mask = tf.random.uniform(
        shape=(batch_size, seq_length), dtype=tf.float32)
    class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
    start_positions = tf.random.uniform(shape=(batch_size,), dtype=tf.float32)

    inputs = dict(
        sequence_data=sequence_data,
        position_mask=position_mask,
        class_index=class_index,
        start_positions=start_positions)

    # Regular invocation through `__call__`, no training flag.
    output = model(inputs)
    self.assertIsInstance(output, dict)

    # Test `call` with training flag.
    output = model.call(inputs, training=True)
    self.assertIsInstance(output, dict)

    # Test `call` without training flag.
    output = model.call(inputs, training=False)
    self.assertIsInstance(output, dict)
def test_serialize_deserialize(self):
    """Checks that a layer rebuilt from `get_config()` has the same config."""
    layer_cls = span_labeling.XLNetSpanLabeling

    # Set every constructor option explicitly so each one is round-tripped.
    constructor_kwargs = dict(
        input_width=128,
        start_n_top=5,
        end_n_top=1,
        activation='tanh',
        dropout_rate=0.34,
        initializer='zeros')
    original = layer_cls(**constructor_kwargs)

    # Build a second layer purely from the first layer's config.
    restored = layer_cls.from_config(original.get_config())

    # A successful round trip leaves both configs equal.
    self.assertAllEqual(original.get_config(), restored.get_config())