def test_rebatch(self):
  """Check Conv output shapes bare and under Rebatch with 1 or 2 batch dims.

  Each case is (layer factory, input shape, expected output shape). Layers
  are built lazily so each one is constructed immediately before its shape
  check, in the same order as the cases are listed.
  """

  def _bare_conv():
    return convolution.Conv(30, (3, 3))

  def _rebatched_conv(n_batch_dims):
    return combinators.Rebatch(convolution.Conv(30, (3, 3)),
                               n_batch_dims=n_batch_dims)

  cases = (
      (_bare_conv, (29, 5, 5, 20), (29, 3, 3, 30)),
      (lambda: _rebatched_conv(1), (29, 5, 5, 20), (29, 3, 3, 30)),
      # Two leading batch dims: both must be preserved in the output shape.
      (lambda: _rebatched_conv(2), (19, 29, 5, 5, 20), (19, 29, 3, 3, 30)),
  )
  for make_layer, in_shape, want_shape in cases:
    got_shape = base.check_shape_agreement(make_layer(), in_shape)
    self.assertEqual(got_shape, want_shape)
def BuildConv():
  """Return a `convolution.Conv` layer configured with 'SAME' padding.

  `units` and `kernel_size` are not defined in this function. They are
  looked up from an outer scope that is not visible here (presumably an
  enclosing function's locals; verify against the caller).
  """
  conv_kwargs = dict(filters=units, kernel_size=kernel_size, padding='SAME')
  return convolution.Conv(**conv_kwargs)
def test_conv_rebatch(self):
  """Bare Conv (no Rebatch wrapper) should accept two leading batch dims.

  The expected output keeps both leading dims (3, 29) and shrinks the
  spatial dims from 5 to 3. The trailing dim changes from 20 to 30.
  """
  conv_layer = convolution.Conv(30, (3, 3))
  two_batch_dim_shape = (3, 29, 5, 5, 20)
  self.assertEqual(
      base.check_shape_agreement(conv_layer, two_batch_dim_shape),
      (3, 29, 3, 3, 30))