def testStackingOverTimeFPropReduceMaxPadding(self):
  """Checks stride-2 stacking with 'reduce_max' padding reduction.

  With left_context=2, right_context=0 and stride=2 (window_size 3), output
  frame t is expected to stack input frames 2t-2, 2t-1 and 2t. Positions
  before the start of the sequence are expected to be zero-filled and
  treated as padded. Under 'reduce_max', an output frame is padded if any
  frame in its window is padded.
  """
  params = linears.StackingOverTime.Params()
  params.name = 'stackingOverTime'
  params.left_context = 2
  params.right_context = 0
  params.stride = 2
  params.padding_reduce_option = 'reduce_max'
  stacker = linears.StackingOverTime(params)
  stacker_vars = None
  self.assertEqual(stacker.window_size, 3)

  # Input shape [2, 6, 2]; the second sequence has only 2 valid frames.
  inputs = jnp.array(
      [[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]],
       [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0], [0, 0]]],
      dtype=jnp.float32)
  paddings = jnp.array(
      [[[0], [0], [0], [0], [0], [0]],
       [[0], [0], [1], [1], [1], [1]]],
      dtype=jnp.float32)

  outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                              stacker.fprop, inputs, paddings)

  # Output shape [2, 3, 6]: each output frame concatenates 3 input frames.
  expected_outputs = jnp.array([
      [[0, 0, 0, 0, 1, 1], [1, 1, 2, 2, 3, 3], [3, 3, 4, 4, 5, 5]],
      [[0, 0, 0, 0, 7, 7], [7, 7, 8, 8, 0, 0], [0, 0, 0, 0, 0, 0]],
  ], dtype=jnp.float32)
  self.assertAllClose(expected_outputs, outputs)

  # The first window of each sequence overlaps the left-context padding.
  # The later windows of the second sequence overlap its padded tail.
  expected_output_paddings = jnp.array(
      [[[1], [0], [0]], [[1], [1], [1]]], dtype=jnp.float32)
  self.assertAllClose(expected_output_paddings, output_paddings)
def testStackingOverTimeFProp2(self):
  """Checks output lengths and sum preservation on random inputs.

  With left_context=0, right_context=1 and stride=2 (window_size 2), output
  frame t stacks input frames 2t and 2t+1. The windows therefore do not
  overlap, and each input frame is expected to land in exactly one output
  frame.
  """
  params = linears.StackingOverTime.Params()
  params.name = 'stackingOverTime'
  params.left_context = 0
  params.right_context = 1
  params.stride = 2
  stacker = linears.StackingOverTime(params)
  stacker_vars = None
  self.assertEqual(stacker.window_size, 2)

  batch_size, max_len, dim = 2, 21, 16
  # A fixed seed keeps failures reproducible. Casting to float32 means the
  # reference sum below is computed from the same values the layer receives
  # (JAX downcasts float64 inputs unless x64 mode is enabled).
  rng = np.random.default_rng(seed=1234)
  inputs = rng.normal(size=[batch_size, max_len, dim]).astype(np.float32)
  # NumPy equivalent of tf.sequence_mask(lengths, max_len).
  lengths = np.array([9, 14])
  mask = (np.arange(max_len)[np.newaxis, :] <
          lengths[:, np.newaxis]).astype(np.float32)
  paddings = jnp.expand_dims(1.0 - mask, -1)

  outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                              stacker.fprop, inputs, paddings)

  # Output lengths are ceil(length / stride): ceil(9/2) = 5, ceil(14/2) = 7.
  self.assertAllClose(
      np.array([5, 7], dtype=np.float32),
      np.sum(1.0 - output_paddings, (1, 2)))
  # Every input frame is stacked exactly once, so per-sequence sums match.
  self.assertAllClose(np.sum(inputs, (1, 2)), np.sum(outputs, (1, 2)))
def testStackingOverTimePadWithRightFrameFProp(self, pad_with_right_frame):
  """Checks how the missing right-context frame at the end is filled.

  With left_context=0, right_context=1 and stride=2 (window_size 2), the
  last output window of a 5-frame input needs a 6th frame that does not
  exist. It is expected to be a copy of the last input frame when
  `pad_with_right_frame` is set, and zeros otherwise.

  Args:
    pad_with_right_frame: Value assigned to `params.pad_with_right_frame`.
      Presumably supplied by a test parameterization decorator; confirm.
  """
  params = linears.StackingOverTime.Params()
  params.name = 'stackingOverTime'
  params.left_context = 0
  params.right_context = 1
  params.stride = 2
  params.pad_with_right_frame = pad_with_right_frame
  stacker = linears.StackingOverTime(params)
  stacker_vars = None
  self.assertEqual(stacker.window_size, 2)

  # Input shape [2, 5, 2]; the second sequence has only 2 valid frames.
  inputs = jnp.array(
      [[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
       [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0]]],
      dtype=jnp.float32)
  paddings = jnp.array(
      [[[0], [0], [0], [0], [0]],
       [[0], [0], [1], [1], [1]]],
      dtype=jnp.float32)

  outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                              stacker.fprop, inputs, paddings)

  # Output shape [2, 3, 4]. Only the final frame of the first sequence
  # depends on the option: [5, 5] duplicates the last input frame, and [0, 0]
  # is zero fill. The second sequence's last input frame is already [0, 0],
  # so both options give the same result there.
  fill = 5 if pad_with_right_frame else 0
  expected_outputs = jnp.array([
      [[1, 1, 2, 2], [3, 3, 4, 4], [5, 5, fill, fill]],
      [[7, 7, 8, 8], [0, 0, 0, 0], [0, 0, 0, 0]],
  ], dtype=jnp.float32)
  self.assertAllClose(expected_outputs, outputs)

  expected_output_paddings = jnp.array(
      [[[0], [0], [0]], [[0], [1], [1]]], dtype=jnp.float32)
  self.assertAllClose(expected_output_paddings, output_paddings)
def testStackingOverTimeIdentityFProp(self):
  """Checks that a one-frame window with stride 1 leaves inputs unchanged."""
  params = linears.StackingOverTime.Params()
  params.name = 'stackingOverTime'
  params.left_context = 0
  params.right_context = 0
  params.stride = 1
  stacker = linears.StackingOverTime(params)
  stacker_vars = None
  self.assertEqual(stacker.window_size, 1)

  # Input shape [1, 5, 1] with no padded frames.
  inputs = jnp.array([[[1], [2], [3], [4], [5]]], dtype=jnp.float32)
  paddings = jnp.zeros([1, 5, 1], dtype=jnp.float32)

  outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                              stacker.fprop, inputs, paddings)

  # Both outputs and paddings are expected to pass through unchanged.
  expected_outputs = jnp.array([[[1], [2], [3], [4], [5]]], dtype=jnp.float32)
  self.assertAllClose(expected_outputs, outputs)
  expected_output_paddings = jnp.array([[[0], [0], [0], [0], [0]]],
                                       dtype=jnp.float32)
  self.assertAllClose(expected_output_paddings, output_paddings)