def test_conv1d_update_state(self):
  """Checks Conv1D.update_state against a left-shift-and-append reference.

  The expected result is the old state with its first filter slot dropped and
  the new input `x` written into the last filter slot.
  """
  batch_size, model_dim, width = 2, 6, 3
  batch_dim = mtf.Dimension("batch", batch_size)
  filter_dim = mtf.Dimension("filter", width)

  new_input = np.random.randn(batch_size, model_dim)
  new_input_mtf = self.converter.convert_np_array_to_mtf_tensor(
      new_input, dtype=tf.float32, dim_names=["batch", "d_model"])

  prev_state = np.random.randn(batch_size, width, model_dim)
  prev_state_mtf = self.converter.convert_np_array_to_mtf_tensor(
      prev_state, dtype=tf.float32, dim_names=["batch", "filter", "d_model"])

  # The same position value (width - 1) is used for every batch element.
  last_slot_mtf = mtf.constant(
      self.converter.mesh,
      width - 1,
      shape=mtf.Shape([batch_dim]),
      dtype=tf.int32)

  layer = transformer_layers.Conv1D()
  updated_mtf = layer.update_state(
      prev_state_mtf, new_input_mtf, last_slot_mtf, filter_dim,
      dtype=tf.float32)
  actual = self.converter.convert_mtf_tensor_to_np_array(updated_mtf)

  # Reference: drop the oldest slot and append the new input as the newest.
  expected = np.concatenate(
      [prev_state[:, 1:, :], new_input[:, np.newaxis, :]], axis=1)
  self.assertAllClose(actual, expected)
def test_conv1d_record_states_first_part_mode_with_partial_sequence(self):
  """Checks the state recorded in "first_part" mode for a partial prefix.

  The dummy decoder context is created with initial_position=2. The test
  expects the recorded state to contain the activations at positions 0 and 1
  in its last two filter slots, with zeros in the remaining slot(s).
  """
  batch_size, model_dim, seq_len, width = 2, 6, 6, 3
  token_ids = np.random.randint(1, 10, size=[batch_size, seq_len])
  context = get_dummy_decoder_context(
      self.converter,
      batch=batch_size,
      d_model=model_dim,
      initial_position=2,  # indices 0 and 1 form the partial sequence
      inputs=token_ids,
      mode="first_part")

  activations = np.random.randn(batch_size, seq_len, model_dim)
  activations_mtf = self.converter.convert_np_array_to_mtf_tensor(
      activations, dtype=tf.float32, dim_names=["batch", "length", "d_model"])

  transformer_layers.Conv1D().record_states_first_part_mode(
      context, activations_mtf, width)
  actual = self.converter.convert_mtf_tensor_to_np_array(
      context.new_states[0])

  # Reference: zero padding followed by the two prefix positions.
  prefix = activations[:, :2, :]
  padding = np.zeros([batch_size, width - 2, model_dim])
  expected = np.concatenate([padding, prefix], axis=1)
  self.assertAllClose(actual, expected)
def test_conv1d_record_states_incremental_mode(self):
  """Checks that incremental mode shifts the state and appends the new input.

  The recorded state is compared against the previous state with its first
  filter slot dropped and the current step's input placed in the last slot.
  """
  batch_size, model_dim, width = 2, 6, 3
  prev_state = np.random.randn(batch_size, width, model_dim)
  context = get_dummy_decoder_context(
      self.converter, batch=batch_size, d_model=model_dim, state=prev_state)

  step_input = np.random.randn(batch_size, model_dim)
  step_input_mtf = self.converter.convert_np_array_to_mtf_tensor(
      step_input, dtype=tf.float32, dim_names=["batch", "d_model"])

  layer = transformer_layers.Conv1D()
  layer.record_states_incremental_mode(context, step_input_mtf, width)
  actual = self.converter.convert_mtf_tensor_to_np_array(
      context.new_states[0])

  # Reference: slots [1:] of the old state move down one, new input goes last.
  expected = np.empty_like(prev_state)
  expected[:, :-1, :] = prev_state[:, 1:, :]
  expected[:, -1, :] = step_input
  self.assertAllClose(actual, expected)