Пример #1
0
    def test_transpose_infer_1(self, order):
        """Check that the inferred output shape is the input shape permuted by ``order``.

        After ``Permute.infer`` runs on the 'transpose' node, output dim ``k``
        must equal input dim ``order[k]``.
        """
        graph = self._create_graph_with_transpose(order)
        node = Node(graph, 'transpose')

        Permute.infer(node)

        in_shape = node.in_node().shape
        # Gather input dims in the requested axis order.
        expected = np.array([in_shape[axis] for axis in order])
        actual = node.out_node().shape
        self.assertTrue(np.array_equal(actual, expected))
Пример #2
0
    def test_transpose_infer_neg_2(self):
        """Check that no output shape is produced when no order is available.

        The graph is built with ``order=None`` and ``reverse_order`` is
        disabled, so ``Permute.infer`` has no permutation to apply. The test
        expects the output shape to remain ``None``.
        """
        order = None
        graph = self._create_graph_with_transpose(order)
        transpose_node = Node(graph, 'transpose')
        transpose_node['reverse_order'] = False

        Permute.infer(transpose_node)

        # assertIsNone reports the offending value on failure, unlike
        # assertTrue(... is None), which only says "False is not true".
        self.assertIsNone(transpose_node.out_node().shape,
                          "Output shape should be None")
Пример #3
0
    def test_transpose_infer_2(self):
        """Check that ``reverse_order`` reverses the input shape.

        The graph is built with ``order=None`` and ``reverse_order`` is
        enabled. After ``Permute.infer`` the output shape must equal the
        input shape read back-to-front.
        """
        graph = self._create_graph_with_transpose(None)
        node = Node(graph, 'transpose')
        node['reverse_order'] = True

        Permute.infer(node)

        expected = np.array(list(reversed(node.in_node().shape)))
        actual = node.out_node().shape
        self.assertTrue(
            np.array_equal(actual, expected),
            "Shapes are not the same: {} and {}".format(actual, expected))
Пример #4
0
 def infer(node: Node):
     """Swap two axes of the input by delegating to ``Permute.infer``.

     Stores in ``node.order`` an identity permutation over the input's
     dimensions, with positions ``node.dim1`` and ``node.dim2`` exchanged.
     Shape inference is then handed to ``Permute.infer``.

     NOTE(review): the input shape is assumed to be a numpy array, since
     ``.size`` is used as the rank. Confirm this against the producing op.
     """
     node.order = list(range(node.in_node().shape.size))
     # Mutate in place through the stored list, exactly as the original did.
     perm = node.order
     first, second = node.dim1, node.dim2
     perm[first], perm[second] = perm[second], perm[first]
     Permute.infer(node)