Example #1
0
    def test_caffe_interp_infer_wh(self):
        """Explicit width/height on an NCHW Interp determine the output H and W.

        A 1x1024x1x1 input with width=65 and height=33 (zoom/shrink factors of
        1, no padding) is expected to infer an output shape of 1x1024x33x65.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'interp'), ('interp', 'node_3'),
                               ('node_3', 'op_output')], {
                                   'node_3': {
                                       'shape': None
                                   },
                                   'node_1': {
                                       'shape': np.array([1, 1024, 1, 1])
                                   },
                                   'interp': {
                                       'width': 65,
                                       'height': 33,
                                       'zoom_factor': 1,
                                       'shrink_factor': 1,
                                       'pad_beg': 0,
                                       'pad_end': 0
                                   }
                               })
        graph.graph['layout'] = 'NCHW'

        interp_node = Node(graph, 'interp')
        InterpOp.interp_infer(interp_node)
        exp_shape = np.array([1, 1024, 33, 65])
        res_shape = graph.node['node_3']['shape']
        # assert_array_equal checks the shape as well as the values, so an
        # output with extra dimensions or an unset (None) shape is reported
        # as a test failure rather than passing or raising TypeError.
        np.testing.assert_array_equal(res_shape, exp_shape)
Example #2
0
    def test_caffe_interp_2_blobs(self):
        """Interp with a second input ('parse_2nd_input': 'shape').

        The expected output keeps N and C of node_1 (1x256x33x66) and takes
        H and W from the shape of node_2 (1x1x3x6), giving 1x256x3x6.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'interp'), ('node_2', 'interp'),
                               ('interp', 'node_3'), ('node_3', 'op_output')],
            {
                'node_3': {
                    'shape': None
                },
                'node_1': {
                    'shape': np.array([1, 256, 33, 66])
                },
                'node_2': {
                    'shape': np.array([1, 1, 3, 6])
                },
                'interp': {
                    'zoom_factor': 1,
                    'shrink_factor': 1,
                    'pad_beg': 0,
                    'pad_end': 0,
                    'parse_2nd_input': 'shape',
                }
            })
        graph.graph['layout'] = 'NCHW'

        interp_node = Node(graph, 'interp')
        InterpOp.interp_infer(interp_node)
        exp_shape = np.array([1, 256, 3, 6])
        res_shape = graph.node['node_3']['shape']
        # Compare shape and values in one call; extra output dims or a None
        # shape become a proper assertion failure.
        np.testing.assert_array_equal(res_shape, exp_shape)
Example #3
0
 def extract(node):
     """Populate Interp attributes on ``node`` from its protobuf attributes.

     Padding is fixed at zero on both sides; ``align_corners`` is read from
     the ``align_corners`` boolean attribute of ``node.pb`` and stored as int.
     Returns the enclosing class's ``enabled`` flag.
     """
     align = int(node.pb.attr['align_corners'].b)
     InterpOp.update_node_stat(node, dict(pad_end=0, pad_beg=0, align_corners=align))
     return __class__.enabled
    def replace_pattern(graph: Graph, match: dict):
        """Rewrite a matched 4D Interpolate node as Resample or Interp in place.

        'nearest', 'cubic' and 'area' modes (or a node with
        ``convert_to_resample`` set) become Resample; 'linear' becomes Interp.
        In both cases input port 1 is disconnected afterwards.

        :param graph: graph holding the matched node (not read here).
        :param match: pattern match; ``match['interpolate']`` is converted.
        """
        interp = match['interpolate']

        # ---- attributes shared by both target operations ----
        mode = interp.mode
        assert mode in ['linear', 'nearest', 'cubic', 'area']

        src_shape = interp.in_port(0).data.get_shape()
        assert src_shape is not None and len(src_shape) == 4
        dst_shape = interp.out_port(0).data.get_shape()
        assert dst_shape is not None and len(dst_shape) == 4

        src_h, src_w = src_shape[2], src_shape[3]
        dst_h, dst_w = dst_shape[2], dst_shape[3]

        known_factor = interp.factor if interp.has_valid('factor') else None
        factor = factor_update(known_factor,
                               [float(dst_h) / src_h, float(dst_w) / src_w],
                               [src_h, src_w],
                               [dst_h, dst_w],
                               interp.soft_get('name'))

        new_attrs = {'width': dst_w, 'height': dst_h, 'factor': factor}

        # A single scalar factor is dropped when zoom/shrink factors are both
        # present or when no factor could be derived.
        has_zoom_and_shrink = interp.has_valid('shrink_factor') and interp.has_valid('zoom_factor')
        if has_zoom_and_shrink or factor is None:
            new_attrs.pop('factor')
            if interp.has('factor'):
                del interp['factor']

        scaled_by_factors = ((interp.has_valid('shrink_factor') and interp.shrink_factor != 1)
                             or (interp.has_valid('zoom_factor') and interp.zoom_factor != 1)
                             or 'factor' in new_attrs)
        if scaled_by_factors:
            width_unset = not interp.has_valid('width') or interp.width == 0
            height_unset = not interp.has_valid('height') or interp.height == 0
            if width_unset and height_unset:
                # Size is driven by factors, so explicit dimensions are zeroed.
                new_attrs['width'] = 0
                new_attrs['height'] = 0

        # ---- operation-specific conversion ----
        if mode in ['nearest', 'cubic', 'area'] or interp.has_and_set('convert_to_resample'):
            assert not interp.align_corners
            assert interp.pads_begin == 0 and interp.pads_end == 0
            new_attrs['resample_type'] = InterpolateToInterpOrResample.type_map[mode]
            ResampleOp.update_node_stat(interp, new_attrs)
            interp.in_port(1).disconnect()
        elif mode == 'linear':
            new_attrs.update(pad_beg=interp.pads_begin,
                             pad_end=interp.pads_end,
                             align_corners=interp.align_corners)
            InterpOp.update_node_stat(interp, new_attrs)
            interp.in_port(1).disconnect()
        interp['force_precision_in_ports'] = None
Example #5
0
    def test_tf_interp_infer_two_inputs(self):
        """TF-style Interp whose second input value supplies the output size.

        With an NHWC 1x20x30x100 input and node_2 holding the value [2, 3],
        the expected output shape is 1x2x3x100.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'interp'), ('node_2', 'interp'),
                               ('interp', 'node_3')], {
                                   'node_1': {
                                       'shape': np.array([1, 20, 30, 100])
                                   },
                                   'node_2': {
                                       'shape': np.array([2]),
                                       'value': np.array([2, 3])
                                   }
                               })
        graph.graph['layout'] = 'NHWC'
        interp_node = Node(graph, 'interp')
        InterpOp.interp_infer(interp_node)
        exp_shape = np.array([1, 2, 3, 100])
        res_shape = graph.node['node_3']['shape']
        # Checks rank and values together; a None shape fails cleanly.
        np.testing.assert_array_equal(res_shape, exp_shape)
Example #6
0
 def test_tf_interp_infer_one_input_hw(self):
     """Single-input TF-style Interp sized by explicit height/width.

     An NHWC 1x20x30x100 input with height=4, width=6 and no zoom/shrink
     factors is expected to produce a 1x4x6x100 output.
     """
     graph = build_graph(
         nodes_attributes, [('node_1', 'interp'), ('interp', 'node_3')], {
             'node_1': {
                 'shape': np.array([1, 20, 30, 100])
             },
             'interp': {
                 'height': 4,
                 'width': 6,
                 'pad_beg': 0,
                 'pad_end': 0,
                 'zoom_factor': None,
                 'shrink_factor': None
             }
         })
     graph.graph['layout'] = 'NHWC'
     interp_node = Node(graph, 'interp')
     InterpOp.interp_infer(interp_node)
     exp_shape = np.array([1, 4, 6, 100])
     res_shape = graph.node['node_3']['shape']
     # Checks rank and values together; a None shape fails cleanly.
     np.testing.assert_array_equal(res_shape, exp_shape)
Example #7
0
    def test_caffe_interp_infer_zoom_shrink_error(self):
        """Zero zoom/shrink factors with zero width/height leave no output shape.

        After inference on an NCHW 1x256x33x65 input, node_3's shape is
        expected to remain None.
        """
        edges = [('node_1', 'interp'), ('interp', 'node_3'), ('node_3', 'op_output')]
        interp_attrs = {
            'zoom_factor': 0,
            'height': 0,
            'width': 0,
            'shrink_factor': 0,
            'pad_beg': 0,
            'pad_end': 0,
        }
        overrides = {
            'node_3': {'shape': None},
            'node_1': {'shape': np.array([1, 256, 33, 65])},
            'interp': interp_attrs,
        }
        graph = build_graph(nodes_attributes, edges, overrides)
        graph.graph['layout'] = 'NCHW'

        InterpOp.interp_infer(Node(graph, 'interp'))

        self.assertIsNone(graph.node['node_3']['shape'])
    def replace_pattern(graph: Graph, match: dict):
        """Rewrite a matched 4D/5D Interpolate node as Resample or Interp in place.

        'nearest', 'cubic' and 'area' modes (or a node with
        ``convert_to_resample`` set) become Resample; 'linear' becomes Interp,
        which only accepts 4D input.

        :param graph: graph holding the matched node; its ``cmd_params`` and
            ``fw`` entries decide whether port 1 is kept for TF Resample.
        :param match: pattern match; ``match['interpolate']`` is converted.
        """
        node = match['interpolate']

        # Port 1 must exist, be connected and carry a known value. The node
        # name is formatted into the message so the failing node is identified.
        assert 1 in node.in_ports() and not node.in_port(1).disconnected() and \
               node.in_port(1).data.get_value() is not None, \
            'Interpolate node {} is corrupted: no 1-port input found'.format(node.soft_get('name'))

        # common
        mode = node.mode
        assert mode in ['linear', 'nearest', 'cubic', 'area']
        in_shape = node.in_port(0).data.get_shape()
        assert in_shape is not None and len(in_shape) in [4, 5]
        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None and len(out_shape) in [4, 5]
        # NOTE(review): for 5D inputs, indices 2 and 3 are the first two spatial
        # dims (depth and height if the layout is NCDHW) — confirm intended.
        in_height, in_width = in_shape[2], in_shape[3]
        out_height, out_width = out_shape[2], out_shape[3]
        factor = factor_update(
            None if not node.has_valid('factor') else node.factor,
            [float(out_height) / in_height,
             float(out_width) / in_width], [in_height, in_width],
            [out_height, out_width], node.soft_get('name'))
        update_attrs = {
            'width': out_width,
            'height': out_height,
            'factor': factor,
        }

        # Drop the scalar factor when zoom/shrink factors are both present or
        # when no factor could be derived.
        if (node.has_valid('shrink_factor')
                and node.has_valid('zoom_factor')) or factor is None:
            del update_attrs['factor']
            if node.has('factor'):
                del node['factor']

        # When scaling is factor-driven and no explicit width/height is set,
        # zero the explicit dimensions.
        if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or
            (node.has_valid('zoom_factor') and node.zoom_factor != 1) or 'factor' in update_attrs) \
                and ((not node.has_valid('width') or node.width == 0) and
                     (not node.has_valid('height') or node.height == 0)):
            update_attrs['width'] = 0
            update_attrs['height'] = 0

        # specific
        if mode in ['nearest', 'cubic', 'area'
                    ] or node.has_and_set('convert_to_resample'):
            assert not node.align_corners
            assert node.pads_begin == 0 and node.pads_end == 0
            update_attrs[
                'resample_type'] = InterpolateToInterpOrResample.type_map[mode]
            ResampleOp.update_node_stat(node, update_attrs)

            if not graph.graph[
                    'cmd_params'].keep_shape_ops or graph.graph['fw'] != 'tf':
                node.in_port(1).disconnect()
            else:
                # we avoid making resample non-reshapable for tf version:
                # port 1 is re-fed with a shape assembled from a Shape node on
                # the data input (helpers presumably extract batch and
                # features) followed by the original port-1 sizes.
                shape = Shape(graph, {}).create_node()
                node.in_port(0).get_source().connect(shape.in_port(0))

                batch = node_to_get_batch_value(shape)
                features = node_to_get_features_dimension_value(shape)
                full_shape = new_shape_node_from_shape_nodes(
                    [batch, features,
                     node.in_port(1).get_source().node])
                node.in_port(1).get_connection().set_source(
                    full_shape.out_port(0))
                full_shape['override_output_shape'] = True

        elif mode == 'linear':
            assert len(in_shape) == 4, 'Interp does not support 5D input'
            update_attrs.update({
                'pad_beg': node.pads_begin,
                'pad_end': node.pads_end,
                'align_corners': node.align_corners,
            })
            InterpOp.update_node_stat(node, update_attrs)
            node.in_port(1).disconnect()