def test_transpose_infer_1(self, order):
    """Check that Permute.infer reorders the input dimensions by ``order``.

    Each output dimension ``k`` is expected to equal input dimension
    ``order[k]``.
    """
    graph = self._create_graph_with_transpose(order)
    node = Node(graph, 'transpose')
    Permute.infer(node)

    in_shape = node.in_node().shape
    expected = [in_shape[axis] for axis in order]
    actual = node.out_node().shape
    self.assertTrue(np.array_equal(actual, np.array(expected)))
def test_transpose_infer_neg_2(self):
    """Check that no order and reverse_order=False leave the output shape unset.

    The transpose node is built without an explicit ``order``. With
    ``reverse_order`` disabled, Permute.infer is expected to leave the
    output shape as None.
    """
    order = None
    graph = self._create_graph_with_transpose(order)
    transpose_node = Node(graph, 'transpose')
    transpose_node['reverse_order'] = False
    Permute.infer(transpose_node)
    # assertIsNone reports the actual value on failure, unlike assertTrue(x is None).
    self.assertIsNone(transpose_node.out_node().shape, "Output shape should be None")
def test_transpose_infer_2(self):
    """Check that no order and reverse_order=True reverse the input shape.

    The transpose node is built without an explicit ``order``. With
    ``reverse_order`` enabled, the output shape is expected to be the
    input shape with its dimensions in reverse order.
    """
    order = None
    graph = self._create_graph_with_transpose(order)
    transpose_node = Node(graph, 'transpose')
    transpose_node['reverse_order'] = True
    Permute.infer(transpose_node)
    ref = np.array(list(reversed(transpose_node.in_node().shape)))
    self.assertTrue(
        np.array_equal(transpose_node.out_node().shape, ref),
        "Shapes are not the same: {} and {}".format(
            transpose_node.out_node().shape, ref))
def infer(node: Node):
    """Infer the output shape of an op that swaps two axes.

    The identity permutation over the input's rank is built, and the
    entries at positions ``node.dim1`` and ``node.dim2`` are exchanged.
    The result is stored as ``node.order``, and shape computation is then
    delegated to ``Permute.infer``.
    """
    # NOTE(review): `.size` implies the input shape is a NumPy array; confirm with callers.
    rank = node.in_node().shape.size
    permutation = list(range(rank))
    first, second = node.dim1, node.dim2
    permutation[first], permutation[second] = permutation[second], permutation[first]
    node.order = permutation
    Permute.infer(node)