Exemple #1
0
  def testStackingOverTimeFPropReduceMaxPadding(self):
    """Verifies fprop with padding_reduce_option set to 'reduce_max'.

    Uses left_context=2, right_context=0 and stride=2 (window size 3) on a
    [2, 6, 2] input, and compares the outputs and output paddings against
    hand-written expectations.
    """
    p = linears.StackingOverTime.Params()
    p.name = 'stackingOverTime'
    p.left_context = 2
    p.right_context = 0
    p.stride = 2
    p.padding_reduce_option = 'reduce_max'

    layer = linears.StackingOverTime(p)
    layer_vars = None
    self.assertEqual(layer.window_size, 3)

    # In the second sequence only the first two frames are unpadded.
    input_frames = [
        [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]],
        [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0], [0, 0]],
    ]
    padding_flags = [
        [[0], [0], [0], [0], [0], [0]],
        [[0], [0], [1], [1], [1], [1]],
    ]
    inputs = jnp.array(input_frames, dtype=jnp.float32)
    paddings = jnp.array(padding_flags, dtype=jnp.float32)

    outputs, output_paddings = test_utils.apply(
        layer, layer_vars, layer.fprop, inputs, paddings)
    print(f'{outputs}')

    want_outputs = jnp.array(
        [
            [[0, 0, 0, 0, 1, 1], [1, 1, 2, 2, 3, 3], [3, 3, 4, 4, 5, 5]],
            [[0, 0, 0, 0, 7, 7], [7, 7, 8, 8, 0, 0], [0, 0, 0, 0, 0, 0]],
        ],
        dtype=jnp.float32)
    self.assertAllClose(want_outputs, outputs)

    want_paddings = jnp.array(
        [[[1], [0], [0]], [[1], [1], [1]]], dtype=jnp.float32)
    self.assertAllClose(want_paddings, output_paddings)
Exemple #2
0
  def testStackingOverTimeFProp2(self):
    """Checks output lengths and sum preservation on random [2, 21, 16] input.

    Uses left_context=0, right_context=1 and stride=2 (window size 2). The two
    sequences have 9 and 14 unpadded frames respectively.
    """
    p = linears.StackingOverTime.Params()
    p.name = 'stackingOverTime'
    p.left_context = 0
    p.right_context = 1
    p.stride = 2

    layer = linears.StackingOverTime(p)
    layer_vars = None
    self.assertEqual(layer.window_size, 2)

    inputs = np.random.normal(size=[2, 21, 16])

    # Hand-rolled sequence mask: 1. for valid frames, 0. elsewhere.
    mask = np.zeros([2, 21]).astype(np.float32)
    for row, valid_len in enumerate((9, 14)):
      mask[row, :valid_len] = 1.

    paddings = jnp.expand_dims(1.0 - mask, -1)
    outputs, output_paddings = test_utils.apply(
        layer, layer_vars, layer.fprop, inputs, paddings)

    # Number of unpadded output frames per sequence.
    want_lengths = np.array([5, 7], dtype=np.float32)
    got_lengths = np.sum(1.0 - output_paddings, (1, 2))
    self.assertAllClose(want_lengths, got_lengths)

    # Per-sequence sums of the inputs and of the outputs must match.
    input_totals = np.sum(inputs, (1, 2))
    output_totals = np.sum(outputs, (1, 2))
    self.assertAllClose(input_totals, output_totals)
Exemple #3
0
    def testStackingOverTimePadWithRightFrameFProp(self, pad_with_right_frame):
        """Checks fprop with and without `pad_with_right_frame`.

        Uses left_context=0, right_context=1 and stride=2 (window size 2) on a
        [2, 5, 2] input. The only expected difference between the two modes is
        the right half of the last output frame of the first sequence.

        Args:
          pad_with_right_frame: value assigned to the layer param of the same
            name; selects which expected output is compared.
        """
        p = linears.StackingOverTime.Params()
        p.name = 'stackingOverTime'
        p.left_context = 0
        p.right_context = 1
        p.stride = 2
        p.pad_with_right_frame = pad_with_right_frame

        layer = linears.StackingOverTime(p)
        layer_vars = None
        self.assertEqual(layer.window_size, 2)

        # input shape [2, 5, 2]
        input_frames = [
            [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
            [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0]],
        ]
        padding_flags = [
            [[0], [0], [0], [0], [0]],
            [[0], [0], [1], [1], [1]],
        ]
        inputs = jnp.array(input_frames, dtype=jnp.float32)
        paddings = jnp.array(padding_flags, dtype=jnp.float32)

        outputs, output_paddings = test_utils.apply(
            layer, layer_vars, layer.fprop, inputs, paddings)
        print(f'{outputs}')

        # output shape [2, 3, 4]
        # With pad_with_right_frame the trailing [5, 5] duplicates the last
        # input frame; otherwise that half of the window is zero.
        last_frame = [5, 5, 5, 5] if pad_with_right_frame else [5, 5, 0, 0]
        want_outputs = jnp.array(
            [
                [[1, 1, 2, 2], [3, 3, 4, 4], last_frame],
                [[7, 7, 8, 8], [0, 0, 0, 0], [0, 0, 0, 0]],
            ],
            dtype=jnp.float32)
        self.assertAllClose(want_outputs, outputs)

        want_paddings = jnp.array(
            [[[0], [0], [0]], [[0], [1], [1]]], dtype=jnp.float32)
        self.assertAllClose(want_paddings, output_paddings)
Exemple #4
0
  def testStackingOverTimeIdentityFProp(self):
    """Checks that no context and stride 1 leave inputs and paddings unchanged.

    With left_context=0, right_context=0 and stride=1 the window size is 1, and
    the expected outputs and output paddings equal the inputs and paddings fed
    to the layer.
    """
    p = linears.StackingOverTime.Params()
    p.name = 'stackingOverTime'
    p.left_context = 0
    p.right_context = 0
    p.stride = 1

    layer = linears.StackingOverTime(p)
    layer_vars = None
    self.assertEqual(layer.window_size, 1)

    # A single sequence of five one-dimensional frames, none padded.
    frames = [[[1], [2], [3], [4], [5]]]
    inputs = jnp.array(frames, dtype=jnp.float32)
    paddings = jnp.zeros([1, 5, 1], dtype=jnp.float32)

    outputs, output_paddings = test_utils.apply(
        layer, layer_vars, layer.fprop, inputs, paddings)
    print(f'{outputs}')

    want_outputs = jnp.array(frames, dtype=jnp.float32)
    self.assertAllClose(want_outputs, outputs)

    want_paddings = jnp.array(
        [[[0], [0], [0], [0], [0]]], dtype=jnp.float32)
    self.assertAllClose(want_paddings, output_paddings)