コード例 #1
0
    def test_permute_begin_end_ellipsis(self):
        """Extend the begin/end masks around the ellipsis to rank 4, then permute them.

        Constant-path case: begin, end and stride reach the StridedSlice node
        through constant data nodes. Each mask is extended, written back onto
        the node, permuted with ``permute_masks`` and compared to its expected value.
        """
        edges = [
            ('input', 'data_1'),
            ('data_1', 'strided_slice'),
            ('begin', 'begin_data'),
            ('begin_data', 'strided_slice'),
            ('end', 'end_data'),
            ('end_data', 'strided_slice'),
            ('stride', 'stride_data'),
            ('stride_data', 'strided_slice'),
            ('strided_slice', 'data_2'),
        ]
        update_attrs = {
            'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
            'begin': {'value': [0, 1], 'shape': [2]},
            'end': {'value': [1, 0], 'shape': [2]},
            'stride': {'value': [1, 2], 'shape': [2]},
            'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
                              'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
                              'ellipsis_mask': np.array([1, 0])},
            'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
        }
        graph = build_graph(nodes_attributes, edges, update_attrs)
        node = Node(graph, 'strided_slice')

        # perm/inv below are mutual inverses; the same permutation is applied to both masks.
        permutation = PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1])
        expectations = (
            ('begin_mask', [0, 0, 0, 0]),
            ('end_mask', [1, 0, 0, 0]),
        )
        for mask_name, expected in expectations:
            # Grow the 2-element mask to rank 4 around the ellipsis, inserting 0 (ins_value).
            extended = extend_mask_according_ellipsis(node['ellipsis_mask'],
                                                      node['shrink_axis_mask'], 4,
                                                      list(node[mask_name]), 0)
            node[mask_name] = int64_array(extended)
            permute_masks(node, permutation, mask_name)
            self.assertTrue(np.array_equal(node[mask_name], np.array(expected)))
コード例 #2
0
 def test_extend_mask_twice(self):
     """Extending an already full-length mask a second time must not change it.

     The first call grows the 2-element mask to length 4 around the ellipsis at
     index 0. The second call receives that length-4 list, and the result is
     expected to stay ``[0, 0, 0, 1]``.
     """
     ellipsis_mask = int64_array([1, 0])
     shrink_mask = int64_array([0, 0])
     length_shape = 4
     mask = int64_array([0, 1])
     ins_value = 0
     mask = extend_mask_according_ellipsis(ellipsis_mask, shrink_mask, length_shape, list(mask), ins_value)
     mask = extend_mask_according_ellipsis(ellipsis_mask, shrink_mask, length_shape, list(mask), ins_value)
     self.assertEqual(mask, [0, 0, 0, 1])
コード例 #3
0
 def test_extend_mask_shrinked_shrink_mask(self):
     """Extend a shrink mask that is also passed as the shrink_axis_mask argument.

     The ellipsis is at index 1, one axis is shrunk, and the target shape has
     rank 4. Under these conditions the length-3 mask is expected to grow to
     length 5, with ``ins_value`` (2) filling the two slots after the ellipsis.
     """
     ellipsis_mask = int64_array([0, 1, 0])
     shrink_mask = int64_array([0, 0, 1])
     length_shape = 4
     ins_value = 2
     extended_shrink_mask = extend_mask_according_ellipsis(ellipsis_mask, shrink_mask, length_shape,
                                                           list(shrink_mask), ins_value)
     self.assertEqual(extended_shrink_mask, [0, 0, 2, 2, 1])
コード例 #4
0
    def test_permute_begin_end_ellipsis_new_inputs(self):
        """Extend the begin/end input values around the ellipsis to rank 5, then permute them.

        Constant-path case in which begin, end and stride feed the StridedSlice
        through explicit input ports 1-3, and the output data_2 is rank 5.
        The values on ports 1 (begin) and 2 (end) are extended, stored back on
        the data node, permuted with ``permute_array`` and checked.
        """
        edges = [
            ('input', 'data_1'),
            ('data_1', 'strided_slice', {'in': 0}),
            ('begin', 'begin_data'),
            ('begin_data', 'strided_slice', {'in': 1}),
            ('end', 'end_data'),
            ('end_data', 'strided_slice', {'in': 2}),
            ('stride', 'stride_data'),
            ('stride_data', 'strided_slice', {'in': 3}),
            ('strided_slice', 'data_2'),
        ]
        update_attrs = {
            'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
            'strided_slice': {'begin_mask': np.array([0, 0, 0]), 'end_mask': np.array([1, 0, 0]),
                              'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0],
                              'ellipsis_mask': np.array([0, 1, 0])},
            'begin': {'value': np.array([0, 1, 2])},
            'end': {'value': np.array([1, 2, 3])},
            'stride': {'value': np.array([1, 1, 1])},
            'begin_data': {'value': np.array([0, 1, 2])},
            'end_data': {'value': np.array([1, 2, 3])},
            'stride_data': {'value': np.array([1, 1, 1])},
            'data_2': {'shape': np.array([1, 1, 2, 3, 4]), 'value': None},
        }
        graph = build_graph(nodes_attributes, edges, update_attrs)
        node = Node(graph, 'strided_slice')

        for port, expected in ((1, [0, 2, 1, 0, 0]), (2, [1, 3, 2, 0, 0])):
            data_node = node.in_node(port)
            # Grow the 3-element input value to rank 5 around the ellipsis, inserting 0 (ins_value).
            extended = extend_mask_according_ellipsis(node['ellipsis_mask'],
                                                      node['shrink_axis_mask'], 5,
                                                      list(data_node.value), 0)
            data_node.value = int64_array(extended)
            data_node.value = permute_array(node, data_node.value)
            self.assertTrue(np.array_equal(data_node.value, np.array(expected)))