Ejemplo n.º 1
0
 def test_slice_infer_4(self):
     """Shape inference with an overridden input 1 and a custom begin_mask.

     Sets the value of input 1 of 'sslice_1' to [0, 10, 10, 0] and
     ``begin_mask`` to [1, 0, 0, 1], runs ``tf_strided_slice_infer`` and
     checks that the inferred output shape equals [1, 34, 30, 2].
     """
     slice_node = Node(self.build_test_graph(), 'sslice_1')
     slice_node.in_node(1).value = np.array([0, 10, 10, 0])
     slice_node.begin_mask = [1, 0, 0, 1]  # 6
     tf_strided_slice_infer(slice_node)

     expected_shape = np.array([1, 34, 30, 2])
     shape_matches = np.array_equal(slice_node.out_node().shape, expected_shape)
     self.assertTrue(shape_matches, 'Wrong output shape detected')
Ejemplo n.º 2
0
 def test_slice_infer_1(self):
     """Shape inference on the unmodified test graph.

     Runs ``tf_strided_slice_infer`` on 'sslice_1' exactly as built by
     ``build_test_graph`` and expects the output shape [1, 34, 30, 2].
     """
     slice_node = Node(self.build_test_graph(), 'sslice_1')
     tf_strided_slice_infer(slice_node)

     expected_shape = np.array([1, 34, 30, 2])
     inferred_shape = slice_node.out_node().shape
     self.assertTrue(np.array_equal(inferred_shape, expected_shape), 'Wrong output shape detected')
Ejemplo n.º 3
0
 def test_slice_infer_dim_beg(self):
     """Inference on the ``build_test_graph_dim_beg`` graph with a shrink mask.

     Sets ``shrink_axis_mask`` to [1] on 'sslice_1', runs
     ``tf_strided_slice_infer`` and checks both the inferred output shape
     ([4]) and the inferred output value ([1, 34, 34, 62]).
     """
     graph = self.build_test_graph_dim_beg()
     node = Node(graph, 'sslice_1')
     node.shrink_axis_mask = [1]
     tf_strided_slice_infer(node)
     self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
     # This assertion compares the value, so its failure message must say "value", not "shape".
     self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output value detected')
Ejemplo n.º 4
0
    def infer(node: Node):
        """Infer the strided-slice output and prepare it for NHWC permutation.

        Always runs ``tf_strided_slice_infer`` on ``node``. The remaining work
        happens only when the graph layout is 'NHWC' and output port 0 carries
        no value:

        * permute attributes are registered for the five mask attributes;
        * every input from index 1 onwards that has a value and whose
          ``shape[0]`` is greater than 3 gets its value replaced by the result
          of ``permute_array_with_ellipsis``;
        * all non-zero entries of ``node.ellipsis_mask`` are reset to 0.
        """
        tf_strided_slice_infer(node)

        # Nothing else to do unless we are in NHWC layout with a non-constant output.
        if node.graph.graph['layout'] != 'NHWC':
            return
        if node.out_port(0).data.get_value() is not None:
            return

        mask_names = ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask')
        PermuteAttrs.create_permute_attrs(
            node, attrs=[(mask_name, 'input:0', permute_masks) for mask_name in mask_names])

        for input_idx in range(1, len(node.in_nodes())):
            src = node.in_node(input_idx)
            if src.value is None or src.shape[0] <= 3:
                continue
            nhwc_to_nchw = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.in_node(0).shape))
            src.value = permute_array_with_ellipsis(node, nhwc_to_nchw, src.value, 0)

        # due to permutation from nhwc to nchw we will extend all masks and inputs
        node.ellipsis_mask[np.nonzero(node.ellipsis_mask)] = 0
Ejemplo n.º 5
0
 def test_slice_infer_2(self):
     """Shape inference with a custom end_mask.

     Sets ``end_mask`` to [1, 0, 0, 1] on 'sslice_1', runs
     ``tf_strided_slice_infer`` and expects the output shape [1, 35, 35, 2].
     """
     slice_node = Node(self.build_test_graph(), 'sslice_1')
     slice_node.end_mask = [1, 0, 0, 1]  # 6
     tf_strided_slice_infer(slice_node)

     expected_shape = np.array([1, 35, 35, 2])
     shape_matches = np.array_equal(slice_node.out_node().shape, expected_shape)
     self.assertTrue(shape_matches, 'Wrong output shape detected')
Ejemplo n.º 6
0
 def test_slice_infer_neg_end(self):
     """Inference with negative entries in the 'sslice_end_1' node value.

     Assigns [1, -1, -5, -1] to 'sslice_end_1', runs
     ``tf_strided_slice_infer`` on 'sslice_1' and checks two things: the
     output shape is [1, 34, 30, 2], and the value stored on
     'sslice_end_1' is still the original array after inference.
     """
     test_graph = self.build_test_graph()
     slice_node = Node(test_graph, 'sslice_1')
     end_const = Node(test_graph, 'sslice_end_1')
     negative_end = [1, -1, -5, -1]
     end_const.value = np.array(negative_end)

     tf_strided_slice_infer(slice_node)

     shape_matches = np.array_equal(slice_node.out_node().shape, np.array([1, 34, 30, 2]))
     self.assertTrue(shape_matches, 'Wrong output shape detected')
     end_untouched = np.array_equal(end_const.value, np.array(negative_end))
     self.assertTrue(end_untouched, 'Negative values in end were converted to positive')
Ejemplo n.º 7
0
 def test_slice_infer_13(self):
     """Inference on the ``build_test_graph2`` graph with a shrunk axis.

     Sets the value of input 1 to [1] and ``shrink_axis_mask`` to [1], runs
     ``tf_strided_slice_infer`` and expects an empty output shape together
     with the scalar output value 34.
     """
     graph = self.build_test_graph2()
     node = Node(graph, 'sslice_1')
     node.in_node(1).value = np.array([1])
     node.shrink_axis_mask = [1]
     tf_strided_slice_infer(node)
     self.assertTrue(np.array_equal(node.out_node().shape, np.array([])), 'Wrong output shape detected')
     # This assertion compares the value, so its failure message must say "value", not "shape".
     self.assertTrue(np.array_equal(node.out_node().value, np.array(34)), 'Wrong output value detected')
Ejemplo n.º 8
0
 def test_slice_infer_12(self):
     """Shape inference with zeroed begin/end masks and three shrunk axes.

     Sets ``begin_mask`` and ``end_mask`` to all zeros and
     ``shrink_axis_mask`` to [1, 1, 1, 0], runs ``tf_strided_slice_infer``
     and expects the output shape [3].
     """
     slice_node = Node(self.build_test_graph(), 'sslice_1')
     slice_node.begin_mask = [0, 0, 0, 0]  # 15
     slice_node.end_mask = [0, 0, 0, 0]  # 15
     slice_node.shrink_axis_mask = [1, 1, 1, 0]  # 7
     tf_strided_slice_infer(slice_node)

     inferred_shape = slice_node.out_node().shape
     self.assertTrue(np.array_equal(inferred_shape, np.array([3])), 'Wrong output shape detected')
Ejemplo n.º 9
0
 def test_slice_infer_7(self):
     """Inference on the ``build_test_graph2`` graph with overridden inputs 1 and 2.

     Sets input 1 to [1] and input 2 to [3], runs
     ``tf_strided_slice_infer`` and expects output shape [2] with output
     value [34, 34].
     """
     slice_node = Node(self.build_test_graph2(), 'sslice_1')
     slice_node.in_node(1).value = np.array([1])
     slice_node.in_node(2).value = np.array([3])
     tf_strided_slice_infer(slice_node)

     shape_matches = np.array_equal(slice_node.out_node().shape, np.array([2]))
     self.assertTrue(shape_matches, 'Wrong output shape detected')
     value_matches = np.array_equal(slice_node.out_node().value, np.array([34, 34]))
     self.assertTrue(value_matches, 'Wrong output value detected')
Ejemplo n.º 10
0
 def test_slice_infer_11(self):
     """Shape inference with masks given as plain integers.

     Assigns integer values to ``begin_mask``, ``end_mask`` and
     ``shrink_axis_mask`` on 'sslice_1', runs ``tf_strided_slice_infer``
     and expects the output shape [35, 3].
     """
     slice_node = Node(self.build_test_graph(), 'sslice_1')
     slice_node.begin_mask = 15  # 1111
     slice_node.end_mask = 15  # 1111
     slice_node.shrink_axis_mask = 5  # 0101
     tf_strided_slice_infer(slice_node)

     expected_shape = np.array([35, 3])
     shape_matches = np.array_equal(slice_node.out_node().shape, expected_shape)
     self.assertTrue(shape_matches, 'Wrong output shape detected')
Ejemplo n.º 11
0
 def test_slice_infer_8(self):
     """Inference on the ``build_test_graph2`` graph with ``new_axis_mask`` set.

     Sets ``new_axis_mask`` to 1 on 'sslice_1', runs
     ``tf_strided_slice_infer`` and expects output shape [1, 4] with output
     value [[1, 34, 34, 62]].
     """
     slice_node = Node(self.build_test_graph2(), 'sslice_1')
     slice_node.new_axis_mask = 1
     tf_strided_slice_infer(slice_node)

     shape_matches = np.array_equal(slice_node.out_node().shape, np.array([1, 4]))
     self.assertTrue(shape_matches, 'Wrong output shape detected')

     expected_value = np.array([[1, 34, 34, 62]])
     value_matches = np.array_equal(slice_node.out_node().value, expected_value)
     self.assertTrue(value_matches, 'Wrong output value detected')
Ejemplo n.º 12
0
 def test_slice_infer_14(self):
     """Inference on the ``build_test_graph2`` graph with input 3 set to [-1].

     Sets input 3 to [-1], ``begin_mask`` and ``end_mask`` to [0] and the
     shape of input 0 to [4], runs ``tf_strided_slice_infer`` and expects
     output shape [4] with output value [62, 34, 34, 1].
     """
     graph = self.build_test_graph2()
     node = Node(graph, 'sslice_1')
     node.in_node(3).value = np.array([-1])
     node.end_mask = [0]
     node.begin_mask = [0]
     node.in_node(0).shape = [4]
     tf_strided_slice_infer(node)
     self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
     # This assertion compares the value, so its failure message must say "value", not "shape".
     self.assertTrue(np.array_equal(node.out_node().value, np.array([62, 34, 34, 1])), 'Wrong output value detected')
Ejemplo n.º 13
0
    def infer(node: Node):
        """Infer the strided-slice output, then normalise its inputs and masks.

        Steps, in order:

        1. Run ``tf_strided_slice_infer`` and require that output port 0 has
           a shape afterwards.
        2. For every connected input port other than port 0, extend its value
           with ``extend_mask_according_ellipsis`` (last argument is 1 for
           port 3 and 0 for the others) and write it back if it changed.
        3. Extend the five mask attributes the same way, then reset every
           non-zero entry of ``node.ellipsis_mask`` to 0.
        4. If the graph layout is 'NHWC' and output port 0 has no value,
           register permute attributes for the masks and, when input 0 has
           rank greater than 3, replace the value of every connected input
           port other than port 0 with ``permute_array(node, value)``.

        Raises:
            AssertionError: if the output shape was not calculated, if an
                inspected input has no constant value, or if the shape of
                input 0 is unknown in the NHWC branch.
        """
        tf_strided_slice_infer(node)

        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None, \
            'Output shape was not calculated for node {}'.format(node.name)
        # extend inputs according to ellipsis mask and/or input_shape
        for i_port in node.in_ports().values():
            if i_port.idx == 0 or i_port.disconnected():
                continue
            old_value = i_port.data.get_value()
            # additional check for non-const input
            # error will be return in shape inference if non-const will be added
            # it is paranoid check for case if shape inference will be changed
            #
            # The message is built from one parenthesised literal so that
            # .format() applies to the whole text; with 'a' + 'b'.format(...)
            # only the second literal is formatted and the placeholders in the
            # first one are emitted verbatim.
            assert old_value is not None, (
                "{} input of {} node is not constant: 'value' attribute for edge "
                "contains None"
            ).format(i_port.idx, node.name)
            # insert 0 for begin and end and 1 for stride
            new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                   len(out_shape), list(old_value),
                                                                   int(i_port.idx == 3)))
            # set_value additionally set_shape and propagate value to Const node
            if not np.array_equal(new_value, old_value):
                i_port.data.set_value(new_value)

        # extend masks before removing ellipsis
        # NOTE: each iteration reads the current ellipsis/shrink masks, so the
        # attribute order below is significant and must be preserved.
        for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
            node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                    len(out_shape), list(node[attr]), 0))

        # we will extend all masks and inputs to simplify future transformations
        idx = np.nonzero(node.ellipsis_mask)
        node.ellipsis_mask[idx] = 0

        if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
            PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                           ('new_axis_mask', 'input:0', permute_masks),
                                                           ('ellipsis_mask', 'input:0', permute_masks),
                                                           ('begin_mask', 'input:0', permute_masks),
                                                           ('end_mask', 'input:0', permute_masks),
                                                           ])
            # permute inputs
            in_shape = node.in_port(0).get_source().data.get_shape()
            assert in_shape is not None, \
                'Input shape is unknown for 0 input of node {}'.format(node.name)
            input_rank = len(in_shape)
            if input_rank > 3:
                for i_port in node.in_ports().values():
                    if i_port.idx == 0 or i_port.disconnected():
                        continue
                    new_value = permute_array(node, i_port.data.get_value())
                    # set_value additionally set_shape and propagate value to Const node
                    i_port.data.set_value(new_value)
0
    def infer(node: Node):
        """Infer the strided-slice output and prepare it for NHWC permutation.

        Always runs ``tf_strided_slice_infer`` on ``node``. The remaining work
        happens only when the graph layout is 'NHWC' and output port 0 carries
        no value:

        * permute attributes are registered for the five mask attributes;
        * if input 0 has rank greater than 3, every input from index 1
          onwards that has a value gets it replaced by the result of
          ``permute_array_with_ellipsis``;
        * if any ellipsis bit is set, the new-axis, shrink-axis, begin and
          end masks are extended with ``extend_mask_according_ellipsis``;
        * all non-zero entries of ``node.ellipsis_mask`` are reset to 0.
        """
        tf_strided_slice_infer(node)

        # Nothing else to do unless we are in NHWC layout with a non-constant output.
        layout_is_nhwc = node.graph.graph['layout'] == 'NHWC'
        if not layout_is_nhwc or node.out_port(0).data.get_value() is not None:
            return

        mask_names = ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask')
        PermuteAttrs.create_permute_attrs(
            node, attrs=[(mask_name, 'input:0', permute_masks) for mask_name in mask_names])

        for input_idx in range(1, len(node.in_nodes())):
            src = node.in_node(input_idx)
            if src.value is None:
                continue
            if len(node.in_node(0).shape) <= 3:
                continue
            src.value = permute_array_with_ellipsis(node, src.value, 0)

        # extend masks before removing ellipsis
        if np.any(node.ellipsis_mask):
            # Each pass reads the current shrink_axis_mask, which the second
            # pass overwrites, so the masks are processed strictly in this order.
            for mask_name in ("new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask"):
                is_begin_or_end = mask_name in ["begin_mask", "end_mask"]
                extended = extend_mask_according_ellipsis(
                    node.ellipsis_mask,
                    node.shrink_axis_mask,
                    len(node.out_port(0).data.get_shape()),
                    list(node[mask_name]),
                    is_begin_or_end)
                node[mask_name] = int64_array(extended)

        # due to permutation from nhwc to nchw we will extend all masks and inputs
        node.ellipsis_mask[np.nonzero(node.ellipsis_mask)] = 0