def test_positive_matmul_infer(self, A_shape, B_shape, C_shape, transpose_a, transpose_b):
    """Check that MatMul.infer produces the expected output shape.

    The input data nodes ``A_d``/``B_d`` get the given shapes and the
    ``mat_mul`` node gets the given transpose flags; after inference the
    shape stored on ``mat_mul_d`` must equal ``C_shape``. Arguments are
    presumably supplied by a parametrization decorator outside this view.
    """
    attrs_updates = [
        ('A_d', {'shape': shape_array(A_shape)}),
        ('B_d', {'shape': shape_array(B_shape)}),
        ('mat_mul', {'transpose_a': transpose_a, 'transpose_b': transpose_b}),
    ]
    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes,
        edges_with_attrs=self.edges,
        update_nodes_attributes=attrs_updates,
    )

    MatMul.infer(Node(graph, 'mat_mul'))

    actual_shape = graph.node['mat_mul_d']['shape']
    expected_shape = shape_array(C_shape)
    template = ("MatMul infer failed for case: A_shape={}, B_shape={}, transpose_a={}, transpose_b={} "
                "expected_shape={}, actual_shape={}")
    failure_msg = template.format(A_shape, B_shape, transpose_a, transpose_b, C_shape, actual_shape)
    self.assertTrue(np.array_equal(actual_shape, expected_shape), failure_msg)
def test_value_propagation(self, a_shape, a_value, b_shape, b_value, transpose_a, transpose_b):
    """Check that MatMul.infer propagates a constant output value.

    Both inputs (``A``/``B`` and their data nodes) are given constant values,
    and the ``matmul`` output data value is cleared beforehand. After
    inference the propagated value must match ``np.matmul`` of the inputs,
    with the same optional transposes the node is configured with.
    Arguments are presumably supplied by a parametrization decorator
    outside this view.
    """
    graph = build_graph(
        nodes_attrs=graph_nodes_attrs,
        edges=graph_edges,
        update_attributes={
            'A': {'shape': int64_array(a_shape), 'value': a_value},
            'A_data': {'shape': int64_array(a_shape), 'value': a_value},
            'B': {'shape': int64_array(b_shape), 'value': b_value},
            'B_data': {'shape': int64_array(b_shape), 'value': b_value},
            'matmul': {'transpose_a': transpose_a, 'transpose_b': transpose_b},
            'matmul_data': {'value': None, 'shape': None},
        },
    )
    node = Node(graph, 'matmul')
    MatMul.infer(node)
    node_data = node.out_port(0).get_destination().data.get_value()

    # Reference result: apply the node's transposes (via the `transpose`
    # helper defined outside this view) before multiplying.
    a = transpose(a_value) if transpose_a else a_value
    b = transpose(b_value) if transpose_b else b_value
    ref_data = np.matmul(a, b)

    msg = "Value propagation for 'matmul' node is not correct."
    # Fail with a clear assertion rather than an AttributeError on `.shape`
    # when infer did not propagate any value.
    self.assertIsNotNone(node_data, msg)
    self.assertEqual(node_data.shape, ref_data.shape, msg)
    self.assertTrue(np.all(node_data == ref_data), msg)