def test_slice_infer_no_slice_point(self):
    """Caffe Slice with an empty ``slice_point`` splits ``axis`` evenly between its two outputs."""
    graph = build_graph(
        nodes_attributes,
        [('node_1', 'Slice_node'),
         ('Slice_node', 'node_2'),
         ('Slice_node', 'node_3'),
         ('node_2', 'op_output'),
         ('node_3', 'op_output_1')],
        {
            'node_1': {'shape': np.array([1, 288, 56, 56])},
            'node_2': {'shape': None},
            'node_3': {'shape': None},
            'Slice_node': {'axis': 1, 'slice_point': []},
        })
    slice_node = Node(graph, 'Slice_node')

    caffe_slice_infer(slice_node)

    # 288 channels on axis 1 split evenly over the two outputs.
    exp_shape = np.array([1, 144, 56, 56])
    # assert_array_equal also compares lengths, so a result with a wrong
    # rank fails here instead of slipping past an element-wise loop.
    np.testing.assert_array_equal(graph.node['node_2']['shape'], exp_shape)
    np.testing.assert_array_equal(graph.node['node_3']['shape'], exp_shape)
def test_slice_infer_3_outs(self):
    """Caffe Slice with ``slice_point`` [100, 150] cuts axis 1 into pieces of 100, 50 and 138."""
    graph = build_graph(
        nodes_attributes,
        [('node_1', 'Slice_node'),
         ('Slice_node', 'node_2'),
         ('Slice_node', 'node_3'),
         ('Slice_node', 'node_4')],
        {
            'node_1': {'shape': np.array([1, 288, 56, 56])},
            'node_2': {'is_output': True, 'shape': None},
            'node_3': {'is_output': True, 'shape': None},
            'node_4': {'is_output': True, 'shape': None},
            'Slice_node': {'axis': 1, 'slice_point': [100, 150]},
        })
    slice_node = Node(graph, 'Slice_node')

    caffe_slice_infer(slice_node)

    # Boundaries 0..100, 100..150 and 150..288 on axis 1.
    expected = {
        'node_2': np.array([1, 100, 56, 56]),
        'node_3': np.array([1, 50, 56, 56]),
        'node_4': np.array([1, 138, 56, 56]),
    }
    for out_name, exp_shape in expected.items():
        # assert_array_equal also compares lengths, so a wrong output rank
        # is reported instead of being ignored by an element-wise loop.
        np.testing.assert_array_equal(graph.node[out_name]['shape'], exp_shape)
def infer(node: Node):
    """Infer the output shape (and constant value, if known) of a Slice node.

    Framework variants are told apart by the number of inputs and attributes:

    * one input with ``start``/``end``/``axis`` attributes -- ONNX before opset 10;
    * one input without those attributes -- Caffe; inference is delegated
      entirely to ``caffe_slice_infer`` and this function returns;
    * three or more inputs with ``format == 'onnx'`` -- ONNX opset 10, where
      inputs 1 and 2 are constant starts/ends, and optional inputs 3 and 4
      are constant axes/steps;
    * three or more inputs otherwise -- TensorFlow, where inputs 1 and 2 are
      ``begin`` and ``size``; they are folded into the ``start``/``end``/``axis``
      attributes and their edges are removed from the graph.

    On success the node receives ``end``, ``slices`` and ``shrink_axis_mask``
    attributes. Its output data node receives ``shape``, plus ``value`` when
    the input value is known (otherwise ``value`` is set to None).
    On malformed input a warning is logged and nothing is inferred.
    """
    axis = None
    steps = None
    # Only the TF variant encodes "slice to the end" as size == -1.
    is_tf_size_encoding = False

    if len(node.in_nodes()) == 1:
        # Caffe or ONNX before 10 opset
        if node.has('start') and node.has('end') and node.has('axis'):
            # ONNX case
            if node.has_valid('start') and node.has_valid('end') and node.has('axis'):
                # Coerce to arrays so `.size` below also works for list attributes.
                start = np.array(node.start)
                end = np.array(node.end)
                axis = node.axis
            else:
                log.warning('Incorrect slice operation: no starts or end attr')
                return
        else:
            # Caffe case: the helper does the whole inference. The generic
            # single-output code below needs `start`, which is never set on
            # this path, so it must not run.
            from mo.front.common.partial_infer.slice import caffe_slice_infer
            caffe_slice_infer(node)
            return
    elif len(node.in_nodes()) >= 3:
        if node.has('format') and node['format'] == 'onnx':
            # ONNX 10 opset case
            starts_node = node.in_node(1)
            ends_node = node.in_node(2)
            if starts_node.has_valid('value') and ends_node.has_valid('value'):
                start = np.array(starts_node.value, dtype=np.int64)
                end = np.array(ends_node.value, dtype=np.int64)
                if 3 in node.in_nodes():
                    if node.in_node(3).has_valid('value'):
                        axis = np.array(node.in_node(3).value, dtype=np.int64)
                    else:
                        log.warning('Incorrect slice operation: axes should be const')
                        return
                if 4 in node.in_nodes():
                    if node.in_node(4).has_valid('value'):
                        steps = np.array(node.in_node(4).value, dtype=np.int64)
                    else:
                        log.warning('Incorrect slice operation: steps should be const')
                        return
            else:
                log.warning('Incorrect slice operation: no starts or ends attr')
                return
        else:
            # TF case: inputs 1 and 2 are `begin` and `size`
            start_node = node.in_node(1)
            size_node = node.in_node(2)
            if start_node.has_valid('value') and size_node.has_valid('value'):
                start = np.array(start_node.value, dtype=np.int64)
                size = np.array(size_node.value, dtype=np.int64)
                end = start + size
                axis = None
                is_tf_size_encoding = True
                # Delete edges to start, size nodes; their values now live in attributes.
                node.graph.remove_edge(start_node.id, node.id)
                node.graph.remove_edge(size_node.id, node.id)
                node['start'] = start
                node['end'] = end
                node['axis'] = None
            else:
                log.warning('Incorrect slice operation: no starts or end attr')
                return
    else:
        log.warning('Incorrect number of input nodes in slice operation')
        return

    input_shape = node.in_node(0).shape

    if is_tf_size_encoding:
        # TF size[i] == -1 means "up to the end of dimension i"; with
        # end = start + size it appears as end[i] < start[i]. Other variants
        # use negative ends / negative steps legitimately, so they are not
        # rewritten here.
        for i in range(start.size):
            if end[i] < start[i]:
                end[i] = input_shape[i]
    # Update end param
    node.end = end

    value = node.in_node(0).value
    if value is None:
        # Only the output shape is needed: a zero-stride broadcast view has the
        # input's shape without allocating a full dummy tensor.
        value = np.broadcast_to(np.zeros(()), input_shape)

    # Following ONNX and TF specification, in case of unknown axis, axes are taken in increasing order
    if axis is None:
        axis = [x for x in range(len(start))]
    if steps is None:
        steps = np.ones(start.size, dtype=np.int64)

    # Build one slice per input dimension.
    rank = len(input_shape)
    slice_idx = [None] * rank
    shrink_axis_mask = [False] * rank
    for i, ax in enumerate(axis):
        # Ranged slice for the explicitly specified axis
        slice_idx[ax] = slice(start[i], end[i], steps[i])
    # TODO: check whether this check is really important
    for dim, dim_slice in enumerate(slice_idx):
        if dim_slice is None:
            # Axes not mentioned in `axis` are taken whole.
            slice_idx[dim] = slice(0, input_shape[dim], 1)

    # Add new parameters to node
    node['slices'] = np.array(slice_idx)
    node['shrink_axis_mask'] = np.array(shrink_axis_mask)

    value = value[tuple(slice_idx)]
    node.out_node().value = value.copy() if node.in_node(0).value is not None else None
    node.out_node().shape = np.array(value.shape)