Пример #1
0
def propagate_fused_batch_norm(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Derive the output shapes and dtypes of a fused batch norm operation.

    The layout is taken from the ``data_format`` attribute: when it is present
    and its second character (case-insensitively) is ``"C"`` the layout is
    ``infer.Format.NCHW``, otherwise ``infer.Format.NHWC``.

    Args:
        op: The operation; ``op.inputs[0].shape`` and ``op.attribs`` are read.
        const_value_by_tensor: Not used by this function.

    Returns:
        A pair ``(shapes, dtypes)``. ``shapes`` holds five entries: a copy of
        the first input's shape followed by four copies of a one-element shape
        containing the input's channel extent. ``dtypes`` holds ``op.attribs['T']``
        repeated five times.
    """
    if "data_format" in op.attribs and op.attribs["data_format"][1].upper() == "C":
        layout = infer.Format.NCHW
    else:
        layout = infer.Format.NHWC

    data_shape = op.inputs[0].shape
    channel_extent = data_shape[infer.channel_axis(layout)]
    per_channel_shape = [channel_extent]

    # First output mirrors the input; the remaining four are per-channel vectors.
    shapes = [infer.copy(data_shape)]
    shapes.extend(infer.copy(per_channel_shape) for _ in range(4))

    dtypes = [op.attribs['T']] * 5
    return shapes, dtypes
Пример #2
0
    def test_conv(self):
        """Check the shapes returned by ``infer.conv`` for several 2-D cases.

        Every case uses a batch of 10 and a 3x3 filter and compares the result
        with a hand-written expected shape. The layout is supplied either via
        explicit ``spatial_begin``/``channel_axis`` values or via ``format``.
        """
        nhwc = infer.Format.NHWC
        nchw = infer.Format.NCHW

        # VALID padding, NHWC layout given through explicit axes: 32x32 -> 30x30.
        shape = infer.conv(input=[10, 32, 32, 3],
                           filter=[3, 3],
                           padding=infer.Padding.VALID,
                           stride=[1, 1],
                           dilation=[1, 1],
                           groups=1,
                           spatial_begin=infer.spatial_begin(nhwc),
                           channel_axis=infer.channel_axis(nhwc),
                           output_channels=16)
        self.assertEqual([10, 30, 30, 16], shape)

        # Explicit zero padding, NCHW layout given through explicit axes.
        shape = infer.conv(input=[10, 3, 32, 32],
                           filter=[3, 3],
                           padding=[(0, 0), (0, 0)],
                           stride=[1, 1],
                           dilation=[1, 1],
                           groups=1,
                           spatial_begin=infer.spatial_begin(nchw),
                           channel_axis=infer.channel_axis(nchw),
                           output_channels=16)
        self.assertEqual([10, 16, 30, 30], shape)

        # deconv=True with explicit zero padding: 30x30 -> 32x32.
        shape = infer.conv(input=[10, 16, 30, 30],
                           filter=[3, 3],
                           padding=[(0, 0), (0, 0)],
                           stride=[1, 1],
                           dilation=[1, 1],
                           groups=1,
                           format=nchw,
                           output_channels=3,
                           deconv=True)
        self.assertEqual([10, 3, 32, 32], shape)

        # SAME_UPPER padding with unit stride keeps the 32x32 spatial size.
        shape = infer.conv(input=[10, 16, 32, 32],
                           filter=[3, 3],
                           padding=infer.Padding.SAME_UPPER,
                           stride=[1, 1],
                           dilation=[1, 1],
                           groups=1,
                           format=nchw,
                           output_channels=3)
        self.assertEqual([10, 3, 32, 32], shape)

        # SAME_LOWER padding with groups=0.
        # NOTE(review): groups=0 is presumably a special value understood by
        # infer.conv (e.g. "derive from input") - confirm against its definition.
        shape = infer.conv(input=[10, 3, 32, 32],
                           filter=[3, 3],
                           padding=infer.Padding.SAME_LOWER,
                           stride=[1, 1],
                           dilation=[1, 1],
                           groups=0,
                           format=nchw,
                           output_channels=6)
        self.assertEqual([10, 6, 32, 32], shape)

        # deconv=True, SAME_UPPER padding, unit stride: spatial size unchanged.
        shape = infer.conv(input=[10, 3, 32, 32],
                           filter=[3, 3],
                           padding=infer.Padding.SAME_UPPER,
                           stride=[1, 1],
                           dilation=[1, 1],
                           groups=1,
                           format=nchw,
                           output_channels=16,
                           deconv=True)
        self.assertEqual([10, 16, 32, 32], shape)

        # deconv=True, SAME_UPPER padding, stride 2: 32x32 -> 64x64.
        shape = infer.conv(input=[10, 3, 32, 32],
                           filter=[3, 3],
                           padding=infer.Padding.SAME_UPPER,
                           stride=[2, 2],
                           dilation=[1, 1],
                           groups=1,
                           format=nchw,
                           output_channels=16,
                           deconv=True)
        self.assertEqual([10, 16, 64, 64], shape)

        # Same as the previous case plus output_padding of (0, 1) per spatial
        # dimension: 32x32 -> 65x65.
        shape = infer.conv(input=[10, 3, 32, 32],
                           filter=[3, 3],
                           padding=infer.Padding.SAME_UPPER,
                           stride=[2, 2],
                           dilation=[1, 1],
                           groups=1,
                           format=nchw,
                           output_channels=16,
                           output_padding=[(0, 1), (0, 1)],
                           deconv=True)
        self.assertEqual([10, 16, 65, 65], shape)