Ejemplo n.º 1
0
    def test_permute_begin_end_ellipsis(self):
        """Permute StridedSlice begin/end masks that rely on an ellipsis.

        The sliced tensor ``data_1`` is 4-D (shape [1, 2, 3, 4]), but
        ``begin``/``end``/``stride`` and the masks have only 2 elements, with
        the ellipsis bit set on the first position (``ellipsis_mask`` [1, 0]).
        Each mask is first extended to rank 4 via
        ``extend_mask_according_ellipsis`` (padding value 0) and then permuted
        with ``perm=[0, 3, 1, 2]`` / ``inv=[0, 2, 3, 1]``.
        """
        # Testing constant path case
        graph = build_graph(nodes_attributes,
                            [('input', 'data_1'),
                             ('data_1', 'strided_slice'),
                             ('begin', 'begin_data'),
                             ('begin_data', 'strided_slice'),
                             ('end', 'end_data'),
                             ('end_data', 'strided_slice'),
                             ('stride', 'stride_data'),
                             ('stride_data', 'strided_slice'),
                             ('strided_slice', 'data_2')],
                            {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             'begin': {'value': [0, 1], 'shape': [2]},
                             'end': {'value': [1, 0], 'shape': [2]},
                             'stride': {'value': [1, 2], 'shape': [2]},
                             'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
                                               'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
                                               'ellipsis_mask': np.array([1, 0])},
                             'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             })

        slice_node = Node(graph, 'strided_slice')

        # Extend the 2-element begin_mask to rank 4 (padding value 0) before
        # permuting; presumably the padding is placed at the ellipsis position.
        # list(...) hands the helper a fresh copy rather than the node's array.
        slice_node['begin_mask'] = int64_array(extend_mask_according_ellipsis(slice_node['ellipsis_mask'],
                                                                              slice_node['shrink_axis_mask'], 4,
                                                                              list(slice_node['begin_mask']), 0))
        permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'begin_mask')
        # assert_array_equal shows the mismatching elements on failure, unlike
        # assertTrue(np.array_equal(...)), which only reports "False is not true".
        np.testing.assert_array_equal(slice_node.begin_mask, np.array([0, 0, 0, 0]))

        # Same extension + permutation for end_mask.
        slice_node['end_mask'] = int64_array(extend_mask_according_ellipsis(slice_node['ellipsis_mask'],
                                                                            slice_node['shrink_axis_mask'], 4,
                                                                            list(slice_node['end_mask']), 0))
        permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'end_mask')
        np.testing.assert_array_equal(slice_node.end_mask, np.array([1, 0, 0, 0]))
Ejemplo n.º 2
0
    def test_permute_begin_end_shrink(self):
        """Permute full-length StridedSlice begin/end masks with a shrunk axis.

        The sliced tensor ``data_1`` is 4-D (shape [1, 2, 3, 4]) and the output
        ``data_2`` is 3-D (shape [2, 3, 4]), with ``shrink_axis_mask`` [1, 0, 0].
        The 4-element ``begin_mask``/``end_mask`` are permuted directly (no
        explicit extension step) with ``perm=[0, 3, 1, 2]`` / ``inv=[0, 2, 3, 1]``.
        """
        # Testing constant path case
        graph = build_graph(
            nodes_attributes, [('data_1', 'strided_slice'),
                               ('begin', 'strided_slice'),
                               ('end', 'strided_slice'),
                               ('stride', 'strided_slice'),
                               ('strided_slice', 'data_2')], {
                                   'data_1': {
                                       'shape': np.array([1, 2, 3, 4]),
                                       'value': None
                                   },
                                   'strided_slice': {
                                       'begin_mask': np.array([1, 0, 0, 1]),
                                       'end_mask': np.array([0, 1, 0, 1]),
                                       'new_axis_mask': np.array([0, 0, 0]),
                                       'shrink_axis_mask': [1, 0, 0],
                                       'ellipsis_mask': np.array([0, 0, 0])
                                   },
                                   'data_2': {
                                       'shape': np.array([2, 3, 4]),
                                       'value': None
                                   },
                               })

        slice_node = Node(graph, 'strided_slice')
        permute_masks(
            slice_node,
            PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]),
            'begin_mask')

        # assert_array_equal shows the mismatching elements on failure, unlike
        # assertTrue(np.array_equal(...)), which only reports "False is not true".
        np.testing.assert_array_equal(slice_node.begin_mask, np.array([1, 1, 0, 0]))

        permute_masks(
            slice_node,
            PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]),
            'end_mask')
        np.testing.assert_array_equal(slice_node.end_mask, np.array([0, 1, 1, 0]))