예제 #1
0
    def test_rebatch(self):
        """Check Conv output shapes, bare and wrapped in Rebatch.

        Three cases are verified in sequence: a plain Conv layer, the same
        layer wrapped in Rebatch with one batch dimension, and wrapped in
        Rebatch with two batch dimensions.
        """
        # Case 1: plain convolution on a 4-element input shape.
        plain_conv = convolution.Conv(30, (3, 3))
        plain_out = base.check_shape_agreement(plain_conv, (29, 5, 5, 20))
        self.assertEqual(plain_out, (29, 3, 3, 30))

        # Case 2: Rebatch with n_batch_dims=1 yields the same shape as case 1.
        single_batch = combinators.Rebatch(
            convolution.Conv(30, (3, 3)), n_batch_dims=1)
        single_out = base.check_shape_agreement(single_batch, (29, 5, 5, 20))
        self.assertEqual(single_out, (29, 3, 3, 30))

        # Case 3: Rebatch with n_batch_dims=2 on an input with an extra
        # leading dimension; both leading dimensions appear in the output.
        double_batch = combinators.Rebatch(
            convolution.Conv(30, (3, 3)), n_batch_dims=2)
        double_out = base.check_shape_agreement(
            double_batch, (19, 29, 5, 5, 20))
        self.assertEqual(double_out, (19, 29, 3, 3, 30))
예제 #2
0
파일: rnn.py 프로젝트: timxzz/tensor2tensor
 def BuildConv():
     """Return a new Conv layer configured with 'SAME' padding.

     `units` and `kernel_size` are not defined here; they are read from the
     surrounding scope (presumably the enclosing function's arguments --
     confirm against the caller).
     """
     conv_kwargs = {
         'filters': units,
         'kernel_size': kernel_size,
         'padding': 'SAME',
     }
     return convolution.Conv(**conv_kwargs)
예제 #3
0
 def test_conv_rebatch(self):
     """Check that Conv maps a (3, 29, 5, 5, 20) input to (3, 29, 3, 3, 30)."""
     conv_layer = convolution.Conv(30, (3, 3))
     observed = base.check_shape_agreement(conv_layer, (3, 29, 5, 5, 20))
     # The two leading dimensions are expected to pass through unchanged.
     self.assertEqual(observed, (3, 29, 3, 3, 30))