def test_conv_rebatch(self):
    """Check Conv's output shape when the input has an extra leading axis.

    The input signature is (3, 29, 5, 5, 20). The expected output keeps the
    two leading axes (3, 29) unchanged. The two 5-wide axes shrink to 3 and
    the last axis becomes 30.
    """
    signature = ShapeDtype((3, 29, 5, 5, 20))
    conv_layer = convolution.Conv(30, (3, 3))
    expected_shape = (3, 29, 3, 3, 30)

    actual_shape = base.check_shape_agreement(conv_layer, signature)
    self.assertEqual(actual_shape, expected_shape)
def BuildConv():
    """Construct and return a new Conv layer that uses 'SAME' padding.

    NOTE(review): `n_units` and `kernel_size` are not defined in this
    function. They are presumably captured from an enclosing scope; confirm
    against the caller.
    """
    conv_layer = convolution.Conv(
        filters=n_units,
        kernel_size=kernel_size,
        padding='SAME',
    )
    return conv_layer
def test_conv(self):
    """Check Conv's output shape on a 4-D input.

    The input signature is (29, 5, 5, 20). The expected output is
    (29, 3, 3, 30): the leading axis is unchanged, the two 5-wide axes shrink
    to 3, and the last axis becomes 30.
    """
    # Pass a ShapeDtype input signature rather than a bare shape tuple. This
    # matches how check_shape_agreement is called elsewhere in this file. A
    # raw tuple of ints risks being read as a different (multi-input)
    # signature.
    input_signature = ShapeDtype((29, 5, 5, 20))
    result_shape = base.check_shape_agreement(
        convolution.Conv(30, (3, 3)), input_signature)
    self.assertEqual(result_shape, (29, 3, 3, 30))