Exemplo n.º 1
0
    def test_group_conv(self):
        """Check that a grouped ``Conv1D`` keeps its feature groups separate.

        With ``feature_group_count=2`` and ``kernel_shape=1``, the test asserts:
        - Activating only first-half input channels at a time step gives
          non-zero values in exactly the first half of that step's outputs.
        - Activating only second-half input channels gives non-zero values in
          exactly the second half.

        NOTE(review): the ``!= 0`` checks assume the default ``w_init`` of
        ``conv.Conv1D`` produces non-zero weights (e.g. a random draw). Confirm
        this if the initializer default changes.
        """
        batch_size = 3
        seqlen = 12
        hidden = 32
        hidden_out = 64
        feature_group_count = 2
        # Number of output channels produced by each feature group.
        group_out = hidden_out // feature_group_count

        inputs = np.zeros((batch_size, seqlen, hidden))
        # Time step 0: only channels 0-1, which lie in the first input group.
        inputs[0, 0, :2] = 1.0
        # Time steps 5 and 7: only channels in the second input group (16-31).
        inputs[0, 5, 24] = 1.0
        inputs[0, 7, 28:32] = 1.0

        data = jnp.asarray(inputs)
        net = conv.Conv1D(output_channels=hidden_out,
                          kernel_shape=1,
                          with_bias=False,
                          feature_group_count=feature_group_count)
        out = net(data)

        expected_output_shape = (batch_size, seqlen, hidden_out)
        self.assertEqual(out.shape, expected_output_shape)

        # Time step 0 only touched the first input group, so exactly the first
        # half of its outputs must be non-zero. Slice to the end
        # (`group_out:`), not `group_out:-1`: the latter silently skipped the
        # last output channel.
        self.assertTrue((out[0, 0, :group_out] != 0).all())
        self.assertTrue((out[0, 0, group_out:] == 0).all())

        # Time steps 5 and 7 only touched the second input group, so exactly
        # the second half of their outputs must be non-zero. Check both halves
        # for both steps, so that "exactly" is actually verified.
        for step in (5, 7):
            self.assertTrue((out[0, step, :group_out] == 0).all())
            self.assertTrue((out[0, step, group_out:] != 0).all())
Exemplo n.º 2
0
 def f():
   """Run a single-output, width-3, stride-1 VALID ``Conv1D`` on a ones input.

   The input is ``jnp.ones([1, 5, 1])``. ``with_bias`` and
   ``create_constant_initializers`` are read from the enclosing scope. The
   two ``1.0`` arguments are presumably the constant weight and bias values;
   confirm against ``create_constant_initializers``.

   Returns:
     The result of applying the convolution to the ones input.
   """
   ones_input = jnp.ones([1, 5, 1])
   layer_kwargs = dict(
       output_channels=1,
       kernel_shape=3,
       stride=1,
       padding="VALID",
       with_bias=with_bias,
   )
   layer = conv.Conv1D(
       **layer_kwargs,
       **create_constant_initializers(1.0, 1.0, with_bias))
   result = layer(ones_input)
   return result