def test_permute_begin_end_ellipsis(self):
    """Permute ellipsis-extended begin/end masks of a StridedSlice node.

    The strided_slice node starts with 2-element begin/end masks and an
    ellipsis at position 0. Each mask is first widened to length 4 with
    ``extend_mask_according_ellipsis`` (filler value 0). It is then reordered
    in place by ``permute_masks`` using perm=[0, 3, 1, 2] / inv=[0, 2, 3, 1].
    The resulting node attribute is compared against the expected mask.
    """
    # Testing constant path case
    graph = build_graph(nodes_attributes,
                        [('input', 'data_1'),
                         ('data_1', 'strided_slice'),
                         ('begin', 'begin_data'),
                         ('begin_data', 'strided_slice'),
                         ('end', 'end_data'),
                         ('end_data', 'strided_slice'),
                         ('stride', 'stride_data'),
                         ('stride_data', 'strided_slice'),
                         ('strided_slice', 'data_2')],
                        {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                         'begin': {'value': [0, 1], 'shape': [2]},
                         'end': {'value': [1, 0], 'shape': [2]},
                         'stride': {'value': [1, 2], 'shape': [2]},
                         'strided_slice': {'begin_mask': np.array([0, 0]),
                                           'end_mask': np.array([1, 0]),
                                           'new_axis_mask': np.array([0]),
                                           'shrink_axis_mask': [0],
                                           'ellipsis_mask': np.array([1, 0])},
                         'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                         })

    slice_node = Node(graph, 'strided_slice')
    # Both masks go through the same permutation; build it once so the two
    # checks below cannot silently diverge.
    permutation = PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1])

    # begin_mask: widen to length 4 around the ellipsis, then permute.
    slice_node['begin_mask'] = int64_array(
        extend_mask_according_ellipsis(slice_node['ellipsis_mask'], slice_node['shrink_axis_mask'], 4,
                                       list(slice_node['begin_mask']), 0))
    permute_masks(slice_node, permutation, 'begin_mask')
    # assert_array_equal reports the differing elements on failure, whereas
    # assertTrue(np.array_equal(...)) only reports "False is not true".
    np.testing.assert_array_equal(slice_node.begin_mask, np.array([0, 0, 0, 0]))

    # end_mask: same widening and permutation.
    slice_node['end_mask'] = int64_array(
        extend_mask_according_ellipsis(slice_node['ellipsis_mask'], slice_node['shrink_axis_mask'], 4,
                                       list(slice_node['end_mask']), 0))
    permute_masks(slice_node, permutation, 'end_mask')
    np.testing.assert_array_equal(slice_node.end_mask, np.array([1, 0, 0, 0]))
def test_permute_begin_end_shrink(self):
    """Permute already full-length begin/end masks of a StridedSlice node.

    The strided_slice node carries 4-element begin/end masks and a
    shrink_axis_mask of [1, 0, 0]. No ellipsis extension is applied here.
    Each mask is reordered in place by ``permute_masks`` using
    perm=[0, 3, 1, 2] / inv=[0, 2, 3, 1]. The resulting node attribute is
    compared against the expected mask.
    """
    # Testing constant path case
    graph = build_graph(
        nodes_attributes,
        [('data_1', 'strided_slice'),
         ('begin', 'strided_slice'),
         ('end', 'strided_slice'),
         ('stride', 'strided_slice'),
         ('strided_slice', 'data_2')],
        {
            'data_1': {
                'shape': np.array([1, 2, 3, 4]),
                'value': None,
            },
            'strided_slice': {
                'begin_mask': np.array([1, 0, 0, 1]),
                'end_mask': np.array([0, 1, 0, 1]),
                'new_axis_mask': np.array([0, 0, 0]),
                'shrink_axis_mask': [1, 0, 0],
                'ellipsis_mask': np.array([0, 0, 0]),
            },
            'data_2': {
                'shape': np.array([2, 3, 4]),
                'value': None,
            },
        })

    slice_node = Node(graph, 'strided_slice')
    # Both masks go through the same permutation; build it once so the two
    # checks below cannot silently diverge.
    permutation = PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1])

    permute_masks(slice_node, permutation, 'begin_mask')
    # assert_array_equal reports the differing elements on failure, whereas
    # assertTrue(np.array_equal(...)) only reports "False is not true".
    np.testing.assert_array_equal(slice_node.begin_mask, np.array([1, 1, 0, 0]))

    permute_masks(slice_node, permutation, 'end_mask')
    np.testing.assert_array_equal(slice_node.end_mask, np.array([0, 1, 1, 0]))