def test_conv2d_nchw(self):
    """Extract a Conv2D node declared with the NCHW data format.

    Builds a stub protobuf node (NCHW layout, strides taken from
    ``self.strides``, VALID padding, unit dilations), runs
    ``Conv2DFrontExtractor.extract`` on it, and then delegates the
    verification to ``self.compare()``, which is defined outside this
    block.
    """
    # Attribute messages are created in the same order as they appear in
    # the attribute mapping below.
    data_format_attr = PB({'s': b'NCHW'})
    strides_attr = PB({'list': PB({'i': self.strides})})
    padding_attr = PB({'s': b'VALID'})
    dilations_attr = PB({'list': PB({'i': [1, 1, 1, 1]})})

    attrs = {
        'data_format': data_format_attr,
        'strides': strides_attr,
        'padding': padding_attr,
        'dilations': dilations_attr,
    }
    node = PB({'pb': PB({'attr': attrs})})

    # spatial_dims = [2, 3] will be detected in infer function, so it is
    # deliberately absent from the expected attributes.
    expected_attrs = {
        'channel_dims': [1],
        'batch_dims': [0],
        'input_feature_channel': 2,
        'output_feature_channel': 3,
        'dilation': np.array([1, 1, 1, 1], dtype=np.int8),
        'stride': np.array(self.strides, dtype=np.int8),
    }
    self.expected = expected_attrs

    # The extractor is called for its effect on ``node``; its return value
    # is not used, and the node itself is what gets compared.
    Conv2DFrontExtractor.extract(node)
    self.res = node

    # NOTE(review): the meaning of (None, False) is determined by
    # ``self.compare`` / the fixture setup, which are not visible here.
    self.expected_call_args = (None, False)
    self.compare()
def test_conv_2d_defaults(self):
    """Check the default attributes set when extracting an NHWC Conv2D node.

    Builds a stub protobuf node (NHWC layout, strides taken from
    ``self.strides``, VALID padding, unit dilations), runs
    ``Conv2DFrontExtractor.extract`` on it, and then delegates the
    verification to ``self.compare()``, which is defined outside this
    block.
    """
    # Attribute messages are created in the same order as they appear in
    # the attribute mapping below.
    data_format_attr = PB({'s': b'NHWC'})
    strides_attr = PB({'list': PB({'i': self.strides})})
    padding_attr = PB({'s': b'VALID'})
    dilations_attr = PB({'list': PB({'i': [1, 1, 1, 1]})})

    attrs = {
        'data_format': data_format_attr,
        'strides': strides_attr,
        'padding': padding_attr,
        'dilations': dilations_attr,
    }
    node = PB({'pb': PB({'attr': attrs})})

    # Only the attributes listed here are placed in the expectation for
    # this test.
    expected_attrs = {
        'bias_addable': True,
        'dilation': np.array([1, 1, 1, 1], dtype=np.int8),
        'type': 'Convolution',
        'layout': 'NHWC',
    }
    self.expected = expected_attrs

    # The extractor is called for its effect on ``node``; its return value
    # is not used, and the node itself is what gets compared.
    Conv2DFrontExtractor.extract(node)
    self.res = node

    # NOTE(review): the meaning of (None, False) is determined by
    # ``self.compare`` / the fixture setup, which are not visible here.
    self.expected_call_args = (None, False)
    self.compare()