def test_permute_begin_end_ellipsis(self):
    # Testing constant path case: begin/end/stride reach 'strided_slice'
    # through constant producer nodes, and the masks live on the node itself.
    # Each mask is first extended for the ellipsis to a 4D output and then
    # permuted with perm=[0, 3, 1, 2]; the result is compared to the expected
    # 4-element mask.
    edges = [
        ('input', 'data_1'),
        ('data_1', 'strided_slice'),
        ('begin', 'begin_data'),
        ('begin_data', 'strided_slice'),
        ('end', 'end_data'),
        ('end_data', 'strided_slice'),
        ('stride', 'stride_data'),
        ('stride_data', 'strided_slice'),
        ('strided_slice', 'data_2'),
    ]
    attrs_update = {
        'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
        'begin': {'value': [0, 1], 'shape': [2]},
        'end': {'value': [1, 0], 'shape': [2]},
        'stride': {'value': [1, 2], 'shape': [2]},
        'strided_slice': {
            'begin_mask': np.array([0, 0]),
            'end_mask': np.array([1, 0]),
            'new_axis_mask': np.array([0]),
            'shrink_axis_mask': [0],
            'ellipsis_mask': np.array([1, 0]),
        },
        'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
    }
    graph = build_graph(nodes_attributes, edges, attrs_update)
    slice_node = Node(graph, 'strided_slice')

    # (mask attribute, expected value after extension + permutation);
    # begin_mask is processed before end_mask, as in the sequential version.
    expectations = (
        ('begin_mask', [0, 0, 0, 0]),
        ('end_mask', [1, 0, 0, 0]),
    )
    for mask_name, expected in expectations:
        extended = extend_mask_according_ellipsis(
            slice_node['ellipsis_mask'],
            slice_node['shrink_axis_mask'],
            4,
            list(slice_node[mask_name]),
            0,
        )
        slice_node[mask_name] = int64_array(extended)

        # A fresh Permutation object is built for every mask on purpose, so
        # no state is shared between the two permute_masks calls.
        permutation = PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1])
        permute_masks(slice_node, permutation, mask_name)

        actual = getattr(slice_node, mask_name)
        self.assertTrue(np.array_equal(actual, np.array(expected)))
def test_extend_mask_twice(self):
    # Calls extend_mask_according_ellipsis twice in a row with the same
    # ellipsis/shrink masks and target rank, feeding the first result back in
    # as the mask, and checks that the final mask is [0, 0, 0, 1].
    ellipsis_mask = int64_array([1, 0])
    shrink_mask = int64_array([0, 0])
    length_shape = 4
    mask = int64_array([0, 1])
    ins_value = 0

    # The helper receives a plain list copy of the mask on every call.
    mask = extend_mask_according_ellipsis(ellipsis_mask, shrink_mask, length_shape, list(mask), ins_value)
    mask = extend_mask_according_ellipsis(ellipsis_mask, shrink_mask, length_shape, list(mask), ins_value)

    # assertEqual, not the deprecated assertEquals alias: the alias was removed
    # from unittest.TestCase in Python 3.12 and raises AttributeError there.
    self.assertEqual(mask, [0, 0, 0, 1])
def test_extend_mask_shrinked_shrink_mask(self):
    # Extends the shrink mask itself: the same array is passed both as the
    # shrink_axis_mask argument and (as a list copy) as the mask to extend.
    # With ins_value=2 and a target rank of 4 the expected result is
    # [0, 0, 2, 2, 1].
    ellipsis_mask = int64_array([0, 1, 0])
    shrink_mask = int64_array([0, 0, 1])
    length_shape = 4
    ins_value = 2

    shrink_mask = extend_mask_according_ellipsis(ellipsis_mask, shrink_mask, length_shape, list(shrink_mask),
                                                 ins_value)

    # assertEqual, not the deprecated assertEquals alias: the alias was removed
    # from unittest.TestCase in Python 3.12 and raises AttributeError there.
    self.assertEqual(shrink_mask, [0, 0, 2, 2, 1])
def test_permute_begin_end_ellipsis_new_inputs(self):
    # Testing constant path case where begin/end/stride are wired to explicit
    # input ports 1/2/3 of 'strided_slice'. The value on input port 1 and
    # then on port 2 is extended for the ellipsis to length 5, permuted with
    # permute_array, and compared to the expected array.
    edges = [
        ('input', 'data_1'),
        ('data_1', 'strided_slice', {'in': 0}),
        ('begin', 'begin_data'),
        ('begin_data', 'strided_slice', {'in': 1}),
        ('end', 'end_data'),
        ('end_data', 'strided_slice', {'in': 2}),
        ('stride', 'stride_data'),
        ('stride_data', 'strided_slice', {'in': 3}),
        ('strided_slice', 'data_2'),
    ]
    attrs_update = {
        'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
        'strided_slice': {
            'begin_mask': np.array([0, 0, 0]),
            'end_mask': np.array([1, 0, 0]),
            'new_axis_mask': np.array([1, 0, 0]),
            'shrink_axis_mask': [0],
            'ellipsis_mask': np.array([0, 1, 0]),
        },
        'begin': {'value': np.array([0, 1, 2])},
        'end': {'value': np.array([1, 2, 3])},
        'stride': {'value': np.array([1, 1, 1])},
        'begin_data': {'value': np.array([0, 1, 2])},
        'end_data': {'value': np.array([1, 2, 3])},
        'stride_data': {'value': np.array([1, 1, 1])},
        'data_2': {'shape': np.array([1, 1, 2, 3, 4]), 'value': None},
    }
    graph = build_graph(nodes_attributes, edges, attrs_update)
    slice_node = Node(graph, 'strided_slice')

    # (input port, expected value after extension + permutation); port 1 is
    # handled before port 2, as in the sequential version.
    expectations = (
        (1, [0, 2, 1, 0, 0]),
        (2, [1, 3, 2, 0, 0]),
    )
    for port, expected in expectations:
        # The input node is looked up again at every use rather than cached,
        # so each read and write goes through slice_node.in_node(port).
        extended = extend_mask_according_ellipsis(
            slice_node['ellipsis_mask'],
            slice_node['shrink_axis_mask'],
            5,
            list(slice_node.in_node(port).value),
            0,
        )
        slice_node.in_node(port).value = int64_array(extended)

        permuted = permute_array(slice_node, slice_node.in_node(port).value)
        slice_node.in_node(port).value = permuted

        actual = slice_node.in_node(port).value
        self.assertTrue(np.array_equal(actual, np.array(expected)))