Example #1
0
    def test_positive_matmul_infer(self, A_shape, B_shape, C_shape,
                                   transpose_a, transpose_b):
        """Check that ``MatMul.infer`` produces the expected output shape.

        The fixture graph described by ``self.nodes``/``self.edges`` is built
        with the ``A_d`` and ``B_d`` nodes overridden to carry ``A_shape`` and
        ``B_shape``, and the ``mat_mul`` node overridden to carry the two
        transpose flags. After inference, the shape stored on ``mat_mul_d``
        must equal ``C_shape``.
        """
        overrides = [
            ('A_d', {'shape': shape_array(A_shape)}),
            ('B_d', {'shape': shape_array(B_shape)}),
            ('mat_mul', {'transpose_a': transpose_a,
                         'transpose_b': transpose_b}),
        ]
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                       edges_with_attrs=self.edges,
                                       update_nodes_attributes=overrides)
        MatMul.infer(Node(graph, 'mat_mul'))

        expected_shape = shape_array(C_shape)
        actual_shape = graph.node['mat_mul_d']['shape']
        failure_msg = (
            "MatMul infer failed for case: A_shape={}, B_shape={}, transpose_a={}, transpose_b={} "
            "expected_shape={}, actual_shape={}"
        ).format(A_shape, B_shape, transpose_a, transpose_b, C_shape, actual_shape)
        self.assertTrue(np.array_equal(actual_shape, expected_shape), failure_msg)
 def test_value_propagation(self, a_shape, a_value, b_shape, b_value, transpose_a, transpose_b):
     """Check that ``MatMul.infer`` propagates constant input values.

     Both inputs (``A``/``A_data`` and ``B``/``B_data``) are given concrete
     shapes and values. The output node ``matmul_data`` starts with neither.
     After inference, the value on the output data node must equal
     ``np.matmul`` of the two inputs. Each input is first passed through
     ``transpose`` when its corresponding flag is set.
     """
     graph = build_graph(
         nodes_attrs=graph_nodes_attrs,
         edges=graph_edges,
         update_attributes={
             'A': {'shape': int64_array(a_shape), 'value': a_value},
             'A_data': {'shape': int64_array(a_shape), 'value': a_value},
             'B': {'shape': int64_array(b_shape), 'value': b_value},
             'B_data': {'shape': int64_array(b_shape), 'value': b_value},
             'matmul': {'transpose_a': transpose_a, 'transpose_b': transpose_b},
             'matmul_data': {'value': None, 'shape': None},
         }
     )
     node = Node(graph, 'matmul')
     MatMul.infer(node)
     node_data = node.out_port(0).get_destination().data.get_value()

     # If no value was propagated, fail with a clear message rather than
     # erroring out with AttributeError on `None.shape`.
     self.assertTrue(node_data is not None,
                     "Value propagation for 'matmul' node did not produce a value.")

     # `transpose` is an external helper (not np.transpose); keep using it so
     # the reference matches whatever axis convention MatMul applies.
     a = transpose(a_value) if transpose_a else a_value
     b = transpose(b_value) if transpose_b else b_value
     ref_data = np.matmul(a, b)

     msg = "Value propagation for 'matmul' node is not correct. expected={}, actual={}".format(
         ref_data, node_data)
     # np.array_equal is False on a shape mismatch and otherwise compares every
     # element, matching the former `shape == shape and np.all(==)` check.
     self.assertTrue(np.array_equal(node_data, ref_data), msg)