def test_slice_infer_4(self):
    # Axes 1 and 2 get a non-zero begin (10) but begin_mask 0 there, and the
    # expected shape equals the default one (see test_slice_infer_1), so the
    # masked-out begin values must be ignored.
    # The trailing "6" is presumably the equivalent TF-style integer begin_mask
    # (0b0110), i.e. the list form here looks inverted relative to TF — confirm.
    graph = self.build_test_graph()
    sslice = Node(graph, 'sslice_1')
    sslice.in_node(1).value = np.array([0, 10, 10, 0])
    sslice.begin_mask = [1, 0, 0, 1]  # 6
    tf_strided_slice_infer(sslice)
    expected_shape = np.array([1, 34, 30, 2])
    inferred_shape = sslice.out_node().shape
    self.assertTrue(np.array_equal(inferred_shape, expected_shape), 'Wrong output shape detected')
def test_slice_infer_1(self):
    # Baseline: infer the strided slice of build_test_graph with its default
    # begin/end/stride inputs and masks, and check the output shape.
    graph = self.build_test_graph()
    sslice = Node(graph, 'sslice_1')
    tf_strided_slice_infer(sslice)
    inferred_shape = sslice.out_node().shape
    expected_shape = np.array([1, 34, 30, 2])
    self.assertTrue(np.array_equal(inferred_shape, expected_shape), 'Wrong output shape detected')
def test_slice_infer_dim_beg(self):
    # Shrink the first axis of the slice built by build_test_graph_dim_beg and
    # check both the resulting shape and the constant-folded output value.
    graph = self.build_test_graph_dim_beg()
    node = Node(graph, 'sslice_1')
    node.shrink_axis_mask = [1]
    tf_strided_slice_infer(node)
    self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
    # This assertion compares values, so the failure message must say "value"
    # (it previously duplicated the shape-check message).
    self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output value detected')
def infer(node: Node):
    # Infer StridedSlice output via tf_strided_slice_infer; then, for NHWC graphs
    # whose output was not constant-folded, prepare masks and constant inputs
    # for the NHWC -> NCHW layout change.
    # NOTE(review): source indentation was lost; the nesting below (everything
    # after tf_strided_slice_infer lives in the NHWC branch) is reconstructed — confirm.
    tf_strided_slice_infer(node)
    # Only needed when the graph is NHWC and the output has no constant value.
    if node.graph.graph['layout'] == 'NHWC' and node.out_port(
            0).data.get_value() is None:
        # Register every mask attribute to be transformed by permute_masks
        # relative to input 0 (presumably applied by a later permutation pass).
        PermuteAttrs.create_permute_attrs(
            node,
            attrs=[
                ('shrink_axis_mask', 'input:0', permute_masks),
                ('new_axis_mask', 'input:0', permute_masks),
                ('ellipsis_mask', 'input:0', permute_masks),
                ('begin_mask', 'input:0', permute_masks),
                ('end_mask', 'input:0', permute_masks),
            ])
        # Permute the constant non-data inputs (ports 1..N-1).
        # NOTE(review): the condition tests the length of input i itself,
        # not the rank of input 0 — confirm this is the intended check.
        for i in range(1, len(node.in_nodes())):
            if node.in_node(
                    i).value is not None and node.in_node(i).shape[0] > 3:
                # NHWC -> NCHW permutation for the rank of the data input.
                perm = PermuteAttrs.get_nhwc_to_nchw_permutation(
                    len(node.in_node(0).shape))
                node.in_node(i).value = permute_array_with_ellipsis(
                    node, perm, node.in_node(i).value, 0)
        # Clear every set bit of ellipsis_mask. This block itself does not
        # extend the masks or inputs; it only drops the ellipsis markers.
        idx = np.nonzero(node.ellipsis_mask)
        node.ellipsis_mask[idx] = 0
def test_slice_infer_2(self):
    # end_mask is 0 on axes 1 and 2, so their end values are ignored and those
    # axes come out larger than in the default case (35 vs 34/30).
    # The trailing "6" is presumably the equivalent TF-style integer end_mask.
    graph = self.build_test_graph()
    sslice = Node(graph, 'sslice_1')
    sslice.end_mask = [1, 0, 0, 1]  # 6
    tf_strided_slice_infer(sslice)
    expected_shape = np.array([1, 35, 35, 2])
    inferred_shape = sslice.out_node().shape
    self.assertTrue(np.array_equal(inferred_shape, expected_shape), 'Wrong output shape detected')
def test_slice_infer_neg_end(self):
    # Negative end values must be handled during inference without rewriting
    # the end input's stored value.
    graph = self.build_test_graph()
    sslice = Node(graph, 'sslice_1')
    end_input = Node(graph, 'sslice_end_1')
    end_input.value = np.array([1, -1, -5, -1])
    tf_strided_slice_infer(sslice)
    expected_shape = np.array([1, 34, 30, 2])
    # A fresh array (not the one assigned above) so an in-place rewrite of
    # end_input.value is still detected.
    expected_end = np.array([1, -1, -5, -1])
    self.assertTrue(np.array_equal(sslice.out_node().shape, expected_shape), 'Wrong output shape detected')
    self.assertTrue(np.array_equal(end_input.value, expected_end), 'Negative values in end were converted to positive')
def test_slice_infer_13(self):
    # Begin at index 1 and shrink the only sliced axis: the result is a scalar
    # (empty shape) holding the constant-folded element.
    graph = self.build_test_graph2()
    node = Node(graph, 'sslice_1')
    node.in_node(1).value = np.array([1])
    node.shrink_axis_mask = [1]
    tf_strided_slice_infer(node)
    self.assertTrue(np.array_equal(node.out_node().shape, np.array([])), 'Wrong output shape detected')
    # Value comparison: report a value mismatch (was a copy of the shape message).
    self.assertTrue(np.array_equal(node.out_node().value, np.array(34)), 'Wrong output value detected')
def test_slice_infer_12(self):
    # All begin/end values ignored and axes 0..2 shrunk: only the last axis
    # survives. The trailing numbers are presumably the equivalent TF-style
    # integer masks (begin/end lists look inverted relative to TF — confirm).
    graph = self.build_test_graph()
    sslice = Node(graph, 'sslice_1')
    sslice.begin_mask = [0, 0, 0, 0]  # 15
    sslice.end_mask = [0, 0, 0, 0]  # 15
    sslice.shrink_axis_mask = [1, 1, 1, 0]  # 7
    tf_strided_slice_infer(sslice)
    inferred_shape = sslice.out_node().shape
    self.assertTrue(np.array_equal(inferred_shape, np.array([3])), 'Wrong output shape detected')
def test_slice_infer_7(self):
    # Slice [1:3] of the constant from build_test_graph2 and check that both the
    # shape and the constant-folded value are produced.
    graph = self.build_test_graph2()
    sslice = Node(graph, 'sslice_1')
    begin_input = sslice.in_node(1)
    end_input = sslice.in_node(2)
    begin_input.value = np.array([1])
    end_input.value = np.array([3])
    tf_strided_slice_infer(sslice)
    result = sslice.out_node()
    self.assertTrue(np.array_equal(result.shape, np.array([2])), 'Wrong output shape detected')
    self.assertTrue(np.array_equal(result.value, np.array([34, 34])), 'Wrong output value detected')
def test_slice_infer_11(self):
    # Integer bit-field masks (bit i presumably maps to axis i, as in TF):
    # begin/end masks set on all four axes, shrink on axes 0 and 2, leaving a
    # 2-D result.
    graph = self.build_test_graph()
    sslice = Node(graph, 'sslice_1')
    sslice.begin_mask = 0b1111
    sslice.end_mask = 0b1111
    sslice.shrink_axis_mask = 0b0101
    tf_strided_slice_infer(sslice)
    inferred_shape = sslice.out_node().shape
    self.assertTrue(np.array_equal(inferred_shape, np.array([35, 3])), 'Wrong output shape detected')
def test_slice_infer_8(self):
    # new_axis_mask bit 0 inserts a leading axis of size 1; the constant value
    # must be folded with that extra dimension.
    graph = self.build_test_graph2()
    sslice = Node(graph, 'sslice_1')
    sslice.new_axis_mask = 0b1
    tf_strided_slice_infer(sslice)
    result = sslice.out_node()
    self.assertTrue(np.array_equal(result.shape, np.array([1, 4])), 'Wrong output shape detected')
    self.assertTrue(np.array_equal(result.value, np.array([[1, 34, 34, 62]])), 'Wrong output value detected')
def test_slice_infer_14(self):
    # Stride -1 with begin/end masks cleared (values ignored) over a length-4
    # input: the constant-folded output is the input reversed.
    graph = self.build_test_graph2()
    node = Node(graph, 'sslice_1')
    node.in_node(3).value = np.array([-1])
    node.end_mask = [0]
    node.begin_mask = [0]
    node.in_node(0).shape = [4]
    tf_strided_slice_infer(node)
    self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
    # Leftover debug print removed; value check now reports a value mismatch.
    self.assertTrue(np.array_equal(node.out_node().value, np.array([62, 34, 34, 1])), 'Wrong output value detected')
def infer(node: Node):
    """Infer the StridedSlice output, then normalize its inputs and masks.

    Steps, in order:
      1. run ``tf_strided_slice_infer`` to compute the output;
      2. extend every connected non-data input (begin/end/stride) according to
         the ellipsis mask and write it back if it changed;
      3. extend every mask attribute the same way, then clear the ellipsis bits;
      4. for NHWC graphs whose output was not constant-folded, register the masks
         for permutation and, when the data input rank is > 3, permute the
         non-data input values.

    :param node: StridedSlice node; input 0 is the data input and every other
        connected input must carry a constant value.
    :raises AssertionError: if the output shape was not inferred, a non-data
        input is not constant, or (NHWC path) the data input shape is unknown.
    """
    tf_strided_slice_infer(node)

    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None, \
        'Output shape was not calculated for node {}'.format(node.name)

    # extend inputs according to ellipsis mask and/or input_shape
    for i_port in node.in_ports().values():
        if i_port.idx == 0 or i_port.disconnected():
            continue
        old_value = i_port.data.get_value()
        # Defensive check for a non-constant input: shape inference is expected
        # to have rejected it already; this guards against that changing.
        # The message is parenthesized so .format() applies to the whole string;
        # previously it bound only to the last literal and the {} placeholders
        # were never filled.
        assert old_value is not None, \
            ('{} input of {} node is not constant: \'value\' attribute for edge '
             'contains None').format(i_port.idx, node.name)
        # insert 0 for begin and end and 1 for stride (port 3)
        new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                               len(out_shape), list(old_value),
                                                               int(i_port.idx == 3)))
        # set_value additionally set_shape and propagate value to Const node
        if not np.array_equal(new_value, old_value):
            i_port.data.set_value(new_value)

    # extend masks before removing ellipsis
    # NOTE: node.shrink_axis_mask is rewritten inside this loop, so begin_mask,
    # end_mask and ellipsis_mask are extended using the already-extended shrink
    # mask; ellipsis_mask itself is processed last, after all other masks.
    for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
        node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                len(out_shape), list(node[attr]), 0))

    # we will extend all masks and inputs to simplify future transformations
    idx = np.nonzero(node.ellipsis_mask)
    node.ellipsis_mask[idx] = 0

    if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
        PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                       ('new_axis_mask', 'input:0', permute_masks),
                                                       ('ellipsis_mask', 'input:0', permute_masks),
                                                       ('begin_mask', 'input:0', permute_masks),
                                                       ('end_mask', 'input:0', permute_masks),
                                                       ])
        # permute inputs
        in_shape = node.in_port(0).get_source().data.get_shape()
        assert in_shape is not None, \
            'Input shape is unknown for 0 input of node {}'.format(node.name)
        input_rank = len(in_shape)
        if input_rank > 3:
            for i_port in node.in_ports().values():
                if i_port.idx == 0 or i_port.disconnected():
                    continue
                new_value = permute_array(node, i_port.data.get_value())
                # set_value additionally set_shape and propagate value to Const node
                i_port.data.set_value(new_value)
def infer(node: Node):
    # Infer StridedSlice output via tf_strided_slice_infer; then, for NHWC graphs
    # whose output was not constant-folded, prepare masks and constant inputs
    # for the NHWC -> NCHW layout change and remove the ellipsis.
    # NOTE(review): source indentation was lost; the nesting below (everything
    # after tf_strided_slice_infer lives in the NHWC branch) is reconstructed — confirm.
    tf_strided_slice_infer(node)
    # Only needed when the graph is NHWC and the output has no constant value.
    if node.graph.graph['layout'] == 'NHWC' and node.out_port(
            0).data.get_value() is None:
        # Register every mask attribute to be transformed by permute_masks
        # relative to input 0 (presumably applied by a later permutation pass).
        PermuteAttrs.create_permute_attrs(
            node,
            attrs=[
                ('shrink_axis_mask', 'input:0', permute_masks),
                ('new_axis_mask', 'input:0', permute_masks),
                ('ellipsis_mask', 'input:0', permute_masks),
                ('begin_mask', 'input:0', permute_masks),
                ('end_mask', 'input:0', permute_masks),
            ])
        # Permute constant non-data inputs when the data input has rank > 3.
        for i in range(1, len(node.in_nodes())):
            if node.in_node(i).value is not None and len(
                    node.in_node(0).shape) > 3:
                node.in_node(i).value = permute_array_with_ellipsis(
                    node, node.in_node(i).value, 0)
        # extend masks before removing ellipsis
        # (only when some ellipsis bit is set; begin_mask/end_mask are extended
        # with True as the fill flag, the other masks with False)
        if np.any(node.ellipsis_mask):
            for attr in [
                    "new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask"
            ]:
                node[attr] = int64_array(
                    extend_mask_according_ellipsis(
                        node.ellipsis_mask, node.shrink_axis_mask,
                        len(node.out_port(0).data.get_shape()),
                        list(node[attr]), attr in ["begin_mask", "end_mask"]))
        # Masks have been extended above, so the ellipsis markers can be
        # cleared: zero every set bit of ellipsis_mask.
        idx = np.nonzero(node.ellipsis_mask)
        node.ellipsis_mask[idx] = 0