def test_conv1d_update_state(self):
        """Checks the state tensor produced by `Conv1D.update_state`.

        The layer is given a random `[batch, filter, d_model]` state, a random
        `[batch, d_model]` input and a `[batch]` position tensor filled with
        `filter_size - 1`. The result is expected to equal the old state with
        its first filter slot dropped and the input placed in the last slot.
        """
        batch = 2
        d_model = 6
        filter_size = 3

        new_input = np.random.randn(batch, d_model)
        new_input_mtf = self.converter.convert_np_array_to_mtf_tensor(
            new_input, dtype=tf.float32, dim_names=["batch", "d_model"])

        prev_state = np.random.randn(batch, filter_size, d_model)
        prev_state_mtf = self.converter.convert_np_array_to_mtf_tensor(
            prev_state,
            dtype=tf.float32,
            dim_names=["batch", "filter", "d_model"])

        batch_dim = mtf.Dimension("batch", batch)
        filter_dim = mtf.Dimension("filter", filter_size)
        # A [batch]-shaped int32 constant holding filter_size - 1 everywhere.
        position_mtf = mtf.constant(
            self.converter.mesh,
            filter_size - 1,
            shape=mtf.Shape([batch_dim]),
            dtype=tf.int32)

        layer = transformer_layers.Conv1D()
        updated_state_mtf = layer.update_state(
            prev_state_mtf,
            new_input_mtf,
            position_mtf,
            filter_dim,
            dtype=tf.float32)
        actual = self.converter.convert_mtf_tensor_to_np_array(
            updated_state_mtf)

        # [batch, filter - 1, d_model] + [batch, 1, d_model]
        #   -> [batch, filter, d_model]
        expected = np.concatenate(
            [prev_state[:, 1:, :], new_input[:, np.newaxis, :]], axis=1)
        self.assertAllClose(actual, expected)
    def test_conv1d_record_states_first_part_mode_with_partial_sequence(self):
        """Checks the state recorded in "first_part" mode for a partial sequence.

        The dummy context is created with `initial_position=2`. After calling
        `record_states_first_part_mode`, `context.new_states[0]` is expected
        to be zero in its first filter slot and to hold `x[:, 0]` and
        `x[:, 1]` in its last two slots.
        """
        batch = 2
        d_model = 6
        length = 6
        filter_size = 3

        token_ids = np.random.randint(1, 10, size=[batch, length])
        # indices 0 and 1 correspond to partial sequences.
        initial_position = 2
        context = get_dummy_decoder_context(
            self.converter,
            batch=batch,
            d_model=d_model,
            initial_position=initial_position,
            inputs=token_ids,
            mode="first_part")

        x = np.random.randn(batch, length, d_model)
        x_mtf = self.converter.convert_np_array_to_mtf_tensor(
            x, dtype=tf.float32, dim_names=["batch", "length", "d_model"])

        layer = transformer_layers.Conv1D()
        layer.record_states_first_part_mode(context, x_mtf, filter_size)
        recorded_state = context.new_states[0]
        actual = self.converter.convert_mtf_tensor_to_np_array(recorded_state)

        # The last two filter slots hold the first two sequence positions; the
        # remaining slot stays at zero.
        expected = np.zeros(shape=[batch, filter_size, d_model])
        expected[:, -2:, :] = x[:, :2, :]

        self.assertAllClose(actual, expected)
    def test_conv1d_record_states_incremental_mode(self):
        """Checks the state recorded by `record_states_incremental_mode`.

        The dummy context is seeded with a random `[batch, filter, d_model]`
        state. After the call, `context.new_states[0]` is expected to equal
        that state with its first filter slot dropped and the `[batch, d_model]`
        input `x` placed in the last slot.
        """
        batch = 2
        d_model = 6
        filter_size = 3

        prev_state = np.random.randn(batch, filter_size, d_model)
        context = get_dummy_decoder_context(
            self.converter, batch=batch, d_model=d_model, state=prev_state)

        x = np.random.randn(batch, d_model)
        x_mtf = self.converter.convert_np_array_to_mtf_tensor(
            x, dtype=tf.float32, dim_names=["batch", "d_model"])

        layer = transformer_layers.Conv1D()
        # Only the state written to the context is checked; the return value
        # is not used.
        layer.record_states_incremental_mode(context, x_mtf, filter_size)
        actual = self.converter.convert_mtf_tensor_to_np_array(
            context.new_states[0])

        # Shift every filter slot one step towards index 0 (np.roll returns a
        # copy), then overwrite the wrapped-around last slot with x.
        expected = np.roll(prev_state, -1, axis=1)
        expected[:, -1, :] = x
        self.assertAllClose(actual, expected)