Beispiel #1
0
    def test_transpose_infer_1(self, order):
        # Shape inference with an explicit permutation: after Transpose.infer
        # runs, output dimension k must equal input dimension order[k].
        #
        # order: sequence of input-axis indices used both to build the graph
        #        and to compute the reference shape.
        graph = self._create_graph_with_transpose(order)
        transpose_node = Node(graph, 'transpose')

        Transpose.infer(transpose_node)

        in_shape = transpose_node.in_node().shape
        out_shape = transpose_node.out_node().shape
        ref = np.array([in_shape[i] for i in order])
        # Include both shapes in the failure message; a bare assertTrue would
        # only report "False is not true" on a mismatch.
        self.assertTrue(
            np.array_equal(out_shape, ref),
            "Shapes are not the same: {} and {}".format(out_shape, ref))
Beispiel #2
0
    def test_transpose_infer_2(self):
        """Check shape inference with no order and 'reverse_order' set.

        The graph is built with an order of None and the node's
        'reverse_order' attribute is set to True before inference. The
        output shape is then expected to be the input shape with its
        dimensions in reverse sequence.
        """
        graph = self._create_graph_with_transpose(None)
        node = Node(graph, 'transpose')
        node['reverse_order'] = True
        Transpose.infer(node)

        input_shape = node.in_node().shape
        ref = np.array(list(reversed(input_shape)))
        actual = node.out_node().shape
        message = "Shapes are not the same: {} and {}".format(actual, ref)
        self.assertTrue(np.array_equal(actual, ref), message)