コード例 #1
0
 def test_downsample(self):
     """Check downsample() output-shape inference.

     Each spatial dimension is divided by its factor; the other dimensions are
     kept. The spatial axes are located either through an explicit
     spatial_begin or through a format argument.
     """
     # NHWC input, spatial axes given explicitly: 32/2 x 32/2.
     nhwc_begin = infer.spatial_begin(infer.Format.NHWC)
     shape = infer.downsample([10, 32, 32, 3], [2, 2], spatial_begin=nhwc_begin)
     self.assertEqual([10, 16, 16, 3], shape)

     # Rank-3 input with NHWC format: both trailing dims are scaled here
     # (32/2 and 32/4).
     shape = infer.downsample([10, 32, 32], [2, 4], format=infer.Format.NHWC)
     self.assertEqual([10, 16, 8], shape)

     # NCHW input: the channel axis (3) stays put, H/4 and W/2.
     shape = infer.downsample([10, 3, 32, 32], [4, 2], format=infer.Format.NCHW)
     self.assertEqual([10, 3, 8, 16], shape)
コード例 #2
0
 def test_upsample(self):
     """Check upsample() output-shape inference.

     Each spatial dimension is multiplied by its factor; the other dimensions
     are kept. The spatial axes are located either through an explicit
     spatial_begin or through a format argument.
     """
     # NHWC input, spatial axes given explicitly: 32*2 x 32*2.
     nhwc_begin = infer.spatial_begin(infer.Format.NHWC)
     shape = infer.upsample([10, 32, 32, 3], [2, 2], spatial_begin=nhwc_begin)
     self.assertEqual([10, 64, 64, 3], shape)

     # Rank-3 input with NHWC format: both trailing dims are scaled here
     # (32*2 and 32*4).
     shape = infer.upsample([10, 32, 32], [2, 4], format=infer.Format.NHWC)
     self.assertEqual([10, 64, 128], shape)

     # NCHW input: the channel axis (3) stays put, H*4 and W*2.
     shape = infer.upsample([10, 3, 32, 32], [4, 2], format=infer.Format.NCHW)
     self.assertEqual([10, 3, 128, 64], shape)
コード例 #3
0
 def test_resize(self):
     """Check resize() output-shape inference.

     The spatial dimensions are replaced by the requested target size; the
     other dimensions are kept. The spatial axes are located either through an
     explicit spatial_begin or through a format argument.
     """
     target = [16, 64]

     # NHWC input, spatial axes given explicitly.
     nhwc_begin = infer.spatial_begin(infer.Format.NHWC)
     shape = infer.resize([10, 32, 32, 3], target, spatial_begin=nhwc_begin)
     self.assertEqual([10, 16, 64, 3], shape)

     # Rank-3 input with NHWC format: both trailing dims become the target.
     shape = infer.resize([10, 32, 32], target, format=infer.Format.NHWC)
     self.assertEqual([10, 16, 64], shape)

     # NCHW input: the channel axis (3) stays put.
     shape = infer.resize([10, 3, 32, 32], target, format=infer.Format.NCHW)
     self.assertEqual([10, 3, 16, 64], shape)
コード例 #4
0
    def test_conv(self):
        """Check conv() output-shape inference for convolution and deconvolution.

        Covers explicit spatial_begin/channel_axis versus a format argument,
        VALID / explicit-zero / SAME_UPPER / SAME_LOWER padding, strides of 1
        and 2, groups=0, and output_padding on the deconv path.
        """
        def conv_shape(**overrides):
            # Every case uses a 3x3 kernel with unit dilation; stride and
            # groups default to 1 and are overridden where a case needs it.
            # The lists are rebuilt on every call so no case shares an
            # argument object with another.
            kwargs = dict(filter=[3, 3], stride=[1, 1], dilation=[1, 1], groups=1)
            kwargs.update(overrides)
            return infer.conv(**kwargs)

        nhwc = infer.Format.NHWC
        nchw = infer.Format.NCHW

        # VALID padding, axes given explicitly for NHWC: 32 - 3 + 1 = 30.
        self.assertEqual([10, 30, 30, 16],
                         conv_shape(input=[10, 32, 32, 3],
                                    padding=infer.Padding.VALID,
                                    spatial_begin=infer.spatial_begin(nhwc),
                                    channel_axis=infer.channel_axis(nhwc),
                                    output_channels=16))

        # Explicit zero padding, axes given explicitly for NCHW: same as VALID.
        self.assertEqual([10, 16, 30, 30],
                         conv_shape(input=[10, 3, 32, 32],
                                    padding=[(0, 0), (0, 0)],
                                    spatial_begin=infer.spatial_begin(nchw),
                                    channel_axis=infer.channel_axis(nchw),
                                    output_channels=16))

        # Deconv with zero padding undoes the case above: 30 -> 32.
        self.assertEqual([10, 3, 32, 32],
                         conv_shape(input=[10, 16, 30, 30],
                                    padding=[(0, 0), (0, 0)],
                                    format=nchw,
                                    output_channels=3,
                                    deconv=True))

        # SAME_UPPER at stride 1 keeps the spatial size.
        self.assertEqual([10, 3, 32, 32],
                         conv_shape(input=[10, 16, 32, 32],
                                    padding=infer.Padding.SAME_UPPER,
                                    format=nchw,
                                    output_channels=3))

        # SAME_LOWER at stride 1 keeps the spatial size. groups=0 presumably
        # means one group per input channel (NNEF convention) - confirm.
        self.assertEqual([10, 6, 32, 32],
                         conv_shape(input=[10, 3, 32, 32],
                                    padding=infer.Padding.SAME_LOWER,
                                    groups=0,
                                    format=nchw,
                                    output_channels=6))

        # Deconv, SAME_UPPER, stride 1: spatial size unchanged.
        self.assertEqual([10, 16, 32, 32],
                         conv_shape(input=[10, 3, 32, 32],
                                    padding=infer.Padding.SAME_UPPER,
                                    format=nchw,
                                    output_channels=16,
                                    deconv=True))

        # Deconv, SAME_UPPER, stride 2: spatial size doubles (32 -> 64).
        self.assertEqual([10, 16, 64, 64],
                         conv_shape(input=[10, 3, 32, 32],
                                    padding=infer.Padding.SAME_UPPER,
                                    stride=[2, 2],
                                    format=nchw,
                                    output_channels=16,
                                    deconv=True))

        # As above, plus one extra row/column of output padding: 64 + 1 = 65.
        self.assertEqual([10, 16, 65, 65],
                         conv_shape(input=[10, 3, 32, 32],
                                    padding=infer.Padding.SAME_UPPER,
                                    stride=[2, 2],
                                    format=nchw,
                                    output_channels=16,
                                    output_padding=[(0, 1), (0, 1)],
                                    deconv=True))