示例#1
0
    def test_slice_infer_no_slice_point(self):
        """Caffe Slice with an empty ``slice_point`` and two outputs.

        The 288 channels of the input (axis 1) are expected to be split into two
        outputs of 144 channels each; all other dimensions are unchanged.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'Slice_node'),
                               ('Slice_node', 'node_2'),
                               ('Slice_node', 'node_3'),
                               ('node_2', 'op_output'),
                               ('node_3', 'op_output_1')], {
                                   'node_1': {
                                       'shape': np.array([1, 288, 56, 56])
                                   },
                                   'node_2': {
                                       'shape': None
                                   },
                                   'node_3': {
                                       'shape': None
                                   },
                                   'Slice_node': {
                                       'axis': 1,
                                       'slice_point': []
                                   }
                               })

        slice_node = Node(graph, 'Slice_node')

        caffe_slice_infer(slice_node)

        exp_shape = np.array([1, 144, 56, 56])
        # Compare whole arrays rather than element-by-element over the expected
        # length, so a result of a different rank is reported as a failure too.
        np.testing.assert_array_equal(graph.node['node_2']['shape'], exp_shape)
        np.testing.assert_array_equal(graph.node['node_3']['shape'], exp_shape)
示例#2
0
    def test_slice_infer_3_outs(self):
        """Caffe Slice with ``slice_point`` [100, 150] over a 288-channel axis.

        The three outputs are expected to get channel ranges [0, 100), [100, 150)
        and [150, 288), i.e. 100, 50 and 138 channels respectively.
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'Slice_node'),
                               ('Slice_node', 'node_2'),
                               ('Slice_node', 'node_3'),
                               ('Slice_node', 'node_4')], {
                                   'node_1': {
                                       'shape': np.array([1, 288, 56, 56])
                                   },
                                   'node_2': {
                                       'is_output': True,
                                       'shape': None
                                   },
                                   'node_3': {
                                       'is_output': True,
                                       'shape': None
                                   },
                                   'node_4': {
                                       'is_output': True,
                                       'shape': None
                                   },
                                   'Slice_node': {
                                       'axis': 1,
                                       'slice_point': [100, 150]
                                   }
                               })

        slice_node = Node(graph, 'Slice_node')

        caffe_slice_infer(slice_node)

        # Compare whole arrays so that a rank mismatch is also caught.
        np.testing.assert_array_equal(graph.node['node_2']['shape'], np.array([1, 100, 56, 56]))
        np.testing.assert_array_equal(graph.node['node_3']['shape'], np.array([1, 50, 56, 56]))
        # Last output takes the remainder: 288 - 150 = 138 channels.
        np.testing.assert_array_equal(graph.node['node_4']['shape'], np.array([1, 138, 56, 56]))
示例#3
0
    def infer(node: Node):
        """Infer the output shape (and, for a constant input, the value) of a Slice node.

        The slice parameters are taken from different places depending on the
        number of inputs:

        * 1 input with ``start``/``end``/``axis`` attributes -- ONNX before opset 10;
        * 1 input without those attributes -- Caffe; delegated entirely to
          ``caffe_slice_infer``;
        * 3+ inputs and ``format == 'onnx'`` -- ONNX opset 10: starts (port 1),
          ends (port 2) and optional axes (port 3) / steps (port 4) must be constants;
        * 3+ inputs otherwise -- TensorFlow: begin (port 1) and size (port 2) must be
          constants; the edges to them are removed and ``start``/``end``/``axis``
          are stored as node attributes.

        On success the node receives ``end``, ``slices`` and ``shrink_axis_mask``
        attributes, and its output node receives ``shape`` and ``value`` (``None``
        unless the input value is known). On unsupported or non-constant
        parameters a warning is logged and the function returns without setting
        the output.

        :param node: the Slice node to infer.
        """
        axis = None
        steps = None
        if len(node.in_nodes()) == 1:
            # Caffe or ONNX before 10 opset
            if node.has('start') and node.has('end') and node.has('axis'):
                # ONNX case
                if node.has_valid('start') and node.has_valid('end') and node.has('axis'):
                    # Copy into int64 arrays: the attributes may be plain lists (which
                    # have no ``.size``), and ``end`` is modified in place further below.
                    start = np.array(node.start, dtype=np.int64)
                    end = np.array(node.end, dtype=np.int64)
                    axis = node.axis
                else:
                    log.warning('Incorrect slice operation: no starts or end attr')
                    return
            else:
                # Caffe case: the helper performs the whole inference. Return here;
                # falling through would reference ``start``/``end``, which are never
                # assigned on this path.
                from mo.front.common.partial_infer.slice import caffe_slice_infer
                caffe_slice_infer(node)
                return
        elif len(node.in_nodes()) >= 3:
            if node.has('format') and node['format'] == 'onnx':
                # ONNX 10 opset case
                starts_node = node.in_node(1)
                ends_node = node.in_node(2)
                if starts_node.has_valid('value') and ends_node.has_valid('value'):
                    start = np.array(starts_node.value, dtype=np.int64)
                    end = np.array(ends_node.value, dtype=np.int64)
                    if 3 in node.in_nodes():
                        if node.in_node(3).has_valid('value'):
                            axis = np.array(node.in_node(3).value, dtype=np.int64)
                        else:
                            log.warning('Incorrect slice operation: axes should be const')
                            return
                    if 4 in node.in_nodes():
                        if node.in_node(4).has_valid('value'):
                            steps = np.array(node.in_node(4).value, dtype=np.int64)
                        else:
                            log.warning('Incorrect slice operation: steps should be const')
                            return
                else:
                    log.warning('Incorrect slice operation: no starts or ends attr')
                    return
            else:
                # TF case
                start_node = node.in_node(1)
                size_node = node.in_node(2)
                if start_node.has_valid('value') and size_node.has_valid('value'):
                    start = np.array(start_node.value, dtype=np.int64)
                    size = np.array(size_node.value, dtype=np.int64)
                    end = start + size
                    axis = None

                    # Delete edges to start, size nodes
                    node.graph.remove_edge(node.in_node(1).id, node.id)
                    node.graph.remove_edge(node.in_node(2).id, node.id)

                    node['start'] = start
                    node['end'] = end
                    node['axis'] = None
                else:
                    log.warning('Incorrect slice operation: no starts or end attr')
                    return
        else:
            log.warning('Incorrect number of input nodes in slice operation')
            return

        input_shape = node.in_node(0).shape
        rank = len(input_shape)

        # Following ONNX and TF specification, in case of unknown axis, axes are
        # taken in increasing order starting from 0
        if axis is None:
            axis = list(range(len(start)))

        if steps is None:
            steps = np.ones(start.size, dtype=np.int64)

        # Check for situation when size[i] == -1 in TF (then end == start - 1):
        # slice up to the end of the dimension the i-th parameter applies to.
        # With a negative step ``end < start`` is the normal case, so leave it alone.
        for i in range(start.size):
            if steps[i] > 0 and end[i] < start[i]:
                end[i] = input_shape[axis[i]]
        # Update end param
        node.end = end

        # Build one slice per input dimension for the requested axes
        slice_idx = [None] * rank
        shrink_axis_mask = [False] * rank
        for i, ax in enumerate(axis):
            slice_idx[ax] = slice(start[i], end[i], steps[i])

        # TODO: check whether this check is really important
        # Dimensions not mentioned in ``axis`` are taken whole.
        slice_idx = [slice(0, input_shape[dim], 1) if s is None else s
                     for dim, s in enumerate(slice_idx)]

        # Add new parameters to node
        node['slices'] = np.array(slice_idx)
        node['shrink_axis_mask'] = np.array(shrink_axis_mask)

        value = node.in_node(0).value
        if value is not None:
            # Constant input: propagate the sliced value and its shape
            sliced = value[tuple(slice_idx)]
            node.out_node().value = sliced.copy()
            node.out_node().shape = np.array(sliced.shape)
        else:
            # Non-constant input: compute the output shape directly from the slices
            # (same result as NumPy basic slicing) instead of slicing a dummy
            # zero tensor of the full input shape.
            node.out_node().value = None
            node.out_node().shape = np.array(
                [len(range(*s.indices(dim_size))) for s, dim_size in zip(slice_idx, input_shape)],
                dtype=np.int64)