示例#1
0
 def extract(cls, einsum_node):
     """Fill an Einsum node's ``equation`` from the MXNet ``subscripts`` layer attribute.

     The raw subscripts string is read from the layer attributes built out of
     ``einsum_node.symbol_dict``, passed through ``Einsum.normalize_equation``
     together with the node's name (falling back to its id), and stored on the
     node under the ``equation`` key via ``Einsum.update_node_stat``.

     :param einsum_node: graph node created for the MXNet einsum symbol
     :return: ``cls.enabled``
     """
     node_name = einsum_node.soft_get('name', einsum_node.id)
     layer_attrs = get_mxnet_layer_attrs(einsum_node.symbol_dict)
     node_attrs = {
         'equation': Einsum.normalize_equation(node_name, layer_attrs.str('subscripts')),
     }
     Einsum.update_node_stat(einsum_node, node_attrs)
     return cls.enabled
示例#2
0
 def extract(cls, einsum_node):
     """Fill an Einsum node's ``equation`` from the ONNX ``equation`` string attribute.

     The attribute is fetched as bytes (field ``'s'``), decoded as UTF-8,
     normalized by ``Einsum.normalize_equation`` using the node's name (or its
     id when unnamed), and stored on the node through ``Einsum.update_node_stat``.

     :param einsum_node: graph node created for the ONNX Einsum operator
     :return: ``cls.enabled``
     """
     node_name = einsum_node.soft_get('name', einsum_node.id)
     # NOTE(review): if onnx_attr returns None for a missing attribute, the
     # .decode below raises AttributeError — confirm 'equation' is always set.
     raw_equation = onnx_attr(einsum_node, 'equation', 's')
     decoded_equation = raw_equation.decode(encoding="utf-8")
     Einsum.update_node_stat(
         einsum_node,
         {'equation': Einsum.normalize_equation(node_name, decoded_equation)},
     )
     return cls.enabled
示例#3
0
    def test_einsum(self, input_shapes, equation, ref_output_shape):
        """Run Einsum shape inference on a small graph and compare the output shape.

        :param input_shapes: shapes of the Einsum inputs used to build the graph
        :param equation: Einsum equation under test
        :param ref_output_shape: expected shape of the ``einsum_node_d`` data node
        """
        einsum_graph = create_einsum_graph(input_shapes, equation)
        Einsum.infer(Node(einsum_graph, 'einsum_node'))

        # shape written by inference onto the Einsum output data node
        inferred_shape = einsum_graph.node['einsum_node_d']['shape']
        shapes_match = np.array_equal(ref_output_shape, inferred_shape)
        failure_message = 'shape does not match expected: {} and given: {}'.format(
            ref_output_shape, inferred_shape)
        self.assertTrue(shapes_match, failure_message)
    def find_and_replace_pattern(self, graph: Graph):
        """Adapt each Einsum equation to NCHW layout, inserting Transposes where it cannot be.

        For every ``Einsum`` op node in ``graph``:

        * per-input and output "correct data layout" flags are collected (they
          may have been set earlier, e.g. by MarkSubgraphWithCorrectLayout when
          the Einsum is located near a MatMul);
        * ``Einsum.adjust_equation_with_NCHW_layout`` returns a permuted
          equation plus flags telling which inputs and whether the output were
          adjusted to the required layout;
        * the node's ``equation`` is replaced by the permuted one;
        * a Transpose is inserted before every input that was not adjusted, and
          after the output if it was not adjusted.

        Inputs are addressed by indices ``0 .. k-1`` where ``k`` is the number of
        connected input ports. NOTE(review): this assumes the connected ports
        are contiguous starting at 0 — confirm for all producers of Einsum.

        :param graph: graph transformed in place
        :raises AssertionError: if an Einsum node lacks a valid ``equation`` or
            the adjusted-input mask length differs from the input count
        """
        import openvino.tools.mo.middle.InsertLayoutPropagationTransposes as InsertTransposes
        for node in graph.get_op_nodes(type='Einsum'):
            node_name = node.soft_get('name', node.id)
            assert node.has_valid('equation'), \
                "Equation attribute is mandatory for Einsum node {}".format(node_name)
            source_equation = node.equation

            # only connected input ports take part in the equation
            input_count = sum(1 for in_port in node.in_ports().values() if not in_port.disconnected())
            input_indices = range(input_count)

            # layout flags that an earlier transformation may already have set
            inputs_layout_ok = [is_input_data_in_correct_layout(node, idx) for idx in input_indices]
            output_layout_ok = is_output_data_in_correct_layout(node, 0)

            # ranks drive how the equation subscripts get permuted
            in_ranks = [len(node.in_port(idx).data.get_shape()) for idx in input_indices]
            out_rank = len(node.out_port(0).data.get_shape())

            new_equation, inputs_adjusted, output_adjusted = Einsum.adjust_equation_with_NCHW_layout(
                node_name, source_equation, in_ranks, out_rank, inputs_layout_ok, output_layout_ok)
            assert len(inputs_adjusted) == input_count

            node.equation = new_equation

            # an input that could not be adjusted must be fed in NHWC,
            # so a transpose in front of the Einsum converts it to NHWC
            for idx in input_indices:
                if inputs_adjusted[idx]:
                    continue
                InsertTransposes.insert_transpose(graph, node.in_port(idx), before_input=True)

            # an output that could not be adjusted is produced in NHWC,
            # so a transpose after the Einsum converts it back to NCHW
            if not output_adjusted:
                InsertTransposes.insert_transpose(graph, node.out_port(0), before_input=False)
示例#5
0
 def extract(cls, einsum_node):
     """Fill an Einsum node's ``equation`` from the TensorFlow ``equation`` attr.

     The bytes value ``pb.attr['equation'].s`` is decoded as UTF-8, normalized
     by ``Einsum.normalize_equation`` with the node's name (or id when unnamed),
     and stored on the node via ``Einsum.update_node_stat``.

     :param einsum_node: graph node wrapping the TensorFlow protobuf in ``pb``
     :return: ``cls.enabled``
     """
     node_name = einsum_node.soft_get('name', einsum_node.id)
     # NOTE(review): if pb.attr is a protobuf message map, indexing a missing key
     # yields (and inserts) an empty default entry instead of failing — confirm
     # the 'equation' attr is always present.
     equation_bytes = einsum_node.pb.attr['equation'].s
     Einsum.update_node_stat(
         einsum_node,
         {'equation': Einsum.normalize_equation(node_name, equation_bytes.decode('utf-8'))},
     )
     return cls.enabled