def test_group_conv(self):
  """Checks that a grouped Conv1D keeps its two channel groups isolated.

  With ``feature_group_count=2``, input channels ``[0, 16)`` should only
  influence output channels ``[0, 32)``, and input channels ``[16, 32)``
  should only influence output channels ``[32, 64)``. ``kernel_shape=1`` makes
  each output time step depend only on the same input time step.
  ``with_bias=False`` makes untouched outputs exactly zero.
  """
  batch_size = 3
  seqlen = 12
  hidden = 32
  hidden_out = 64
  feature_group_count = 2
  # Boundary between the first and second output-channel groups.
  out_half = hidden_out // feature_group_count

  inputs = np.zeros((batch_size, seqlen, hidden))
  inputs[0, 0, :2] = 1.0     # Touches only the first input group.
  inputs[0, 5, 24] = 1.0     # Touches only the second input group.
  inputs[0, 7, 28:32] = 1.0  # Touches only the second input group.
  data = jnp.asarray(inputs)

  net = conv.Conv1D(
      output_channels=hidden_out,
      kernel_shape=1,
      with_bias=False,
      feature_group_count=feature_group_count)
  out = net(data)

  expected_output_shape = (batch_size, seqlen, hidden_out)
  self.assertEqual(out.shape, expected_output_shape)

  # The `!= 0` checks assume the default random weight initializer yields
  # nonzero weights; the `== 0` checks rely on with_bias=False.
  # Slices run to the end of the channel axis (not `:-1`) so that the last
  # output channel is checked as well.

  # Time step 0: the first-half input change affects exactly the first-half
  # outputs.
  self.assertTrue((out[0, 0, :out_half] != 0).all())
  self.assertTrue((out[0, 0, out_half:] == 0).all())
  # Time steps 5 and 7: the second-half input change affects exactly the
  # second-half outputs.
  for step in (5, 7):
    self.assertTrue((out[0, step, :out_half] == 0).all(), msg=f"step={step}")
    self.assertTrue((out[0, step, out_half:] != 0).all(), msg=f"step={step}")
def f():
  """Runs a VALID, stride-1, kernel-3 Conv1D over a (1, 5, 1) tensor of ones.

  Weights (and the bias, when ``with_bias`` is set by the enclosing scope)
  come from ``create_constant_initializers(1.0, 1.0, with_bias)``. Both
  names are free variables resolved outside this function.
  """
  ones_input = jnp.ones([1, 5, 1])
  init_kwargs = create_constant_initializers(1.0, 1.0, with_bias)
  layer = conv.Conv1D(
      output_channels=1,
      kernel_shape=3,
      stride=1,
      padding="VALID",
      with_bias=with_bias,
      **init_kwargs)
  return layer(ones_input)