Ejemplo n.º 1
0
    def find_and_replace_pattern(self, graph: Graph):
        """Replace every AttributedPad node with a Pad node that takes its paddings as inputs.

        AttributedPad stores its paddings in the ``pads`` attribute: column 0 is fed
        to Pad input port 1 and column 1 to port 2 (presumably pads_begin /
        pads_end). In 'constant' mode, the ``fill_value`` attribute is fed to port 3.
        The new Pad node takes over the original node name. The AttributedPad node
        is renamed with a '/to_be_removed' suffix and then removed from the graph.

        :param graph: graph that is transformed in place.
        """
        for attr_pad in graph.get_op_nodes(op='AttributedPad'):
            # save the original node name to use it in the new Pad op instance
            original_name = attr_pad.soft_get('name', attr_pad.id)

            new_pad = Pad(graph, {
                'mode': attr_pad.soft_get('mode', None),
            }).create_node()
            # free the original name for the new node before the old one is deleted
            rename_nodes([(attr_pad, original_name + '/to_be_removed'),
                          (new_pad, original_name)])

            attr_pad.in_port(0).get_connection().set_destination(
                new_pad.in_port(0))
            new_pad.in_port(1).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 0]
                }).create_node().out_port(0))
            new_pad.in_port(2).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 1]
                }).create_node().out_port(0))
            if attr_pad.soft_get('mode') == 'constant':
                # A bare Const holding fill_value keeps the attribute's own dtype,
                # which may differ from the dtype of the data entering Pad. Route
                # it through ConvertLike so the pad value gets the data type of the
                # Pad first input. Port 0 of new_pad was connected above, so its
                # source is available here.
                convert_pad_value = create_op_with_const_inputs(
                    graph, ConvertLike, {0: attr_pad.fill_value},
                    {'name': original_name + '/pad_value_convert'})
                convert_pad_value.in_port(1).connect(
                    new_pad.in_port(0).get_source())
                new_pad.in_port(3).connect(convert_pad_value.out_port(0))

            attr_pad.out_port(0).get_connection().set_source(
                new_pad.out_port(0))
            graph.remove_node(attr_pad.id)
Ejemplo n.º 2
0
    def find_and_replace_pattern(self, graph: Graph):
        """Lower every TFPad node into a Pad node plus a small paddings-reshaping subgraph.

        TF delivers the paddings as one [N, 2] tensor on input port 1. The Pad op
        expects them as two separate [N] inputs (ports 1 and 2). The paddings are
        therefore passed through a Transpose (constant order [1, 0]) and a
        two-way Split (constant axis 0). Each half then goes through a Squeeze
        (constant axes [0]).

        In 'constant' mode the optional fill-value input of TFPad (port 2) is
        re-attached to Pad port 3 when it is connected. The new Pad node takes
        over the original name, and the TFPad node is removed.

        :param graph: graph that is transformed in place.
        """
        for tf_pad in graph.get_op_nodes(op='TFPad'):
            # the replacement Pad inherits the original name; TFPad gets a suffix
            pad_name = tf_pad.soft_get('name', tf_pad.id)
            tf_pad['name'] = pad_name + '/to_be_removed'

            pad_mode = tf_pad.soft_get('mode', None)
            pad_op = Pad(graph, {'mode': pad_mode}).create_node()
            rename_node(pad_op, pad_name)

            tf_pad.in_port(0).get_connection().set_destination(pad_op.in_port(0))

            # the input with fill value is an optional third input in TF
            if pad_mode == 'constant' and not tf_pad.in_port(2).disconnected():
                tf_pad.in_port(2).get_connection().set_destination(pad_op.in_port(3))

            # TF pads [N, 2] -> transposed [2, N] -> split into two halves -> squeezed to [N] each
            transpose = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
            tf_pad.in_port(1).get_connection().set_destination(transpose.in_port(0))
            splitter = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
            transpose.out_port(0).connect(splitter.in_port(0))

            for half in (0, 1):
                target_port = pad_op.in_port(half + 1)
                splitter.add_output_port(half, skip_if_exist=True)
                target_port.connect(splitter.out_port(half))
                squeeze = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
                target_port.get_connection().insert_node(squeeze)

            tf_pad.out_port(0).get_connection().set_source(pad_op.out_port(0))
            graph.remove_node(tf_pad.id)
Ejemplo n.º 3
0
    def find_and_replace_pattern(self, graph: Graph):
        """Swap each AttributedPad node for an equivalent Pad node driven by inputs.

        The ``pads`` attribute is split column-wise into two Const inputs:
        column 0 goes to port 1 and column 1 to port 2. In 'constant' mode,
        ``fill_value`` is fed to port 3 through a ConvertLike node, so it gets
        the data type of the Pad first input. The new Pad node inherits the
        original name, and the AttributedPad node is removed.

        :param graph: graph that is transformed in place.
        """
        for pad_node in graph.get_op_nodes(op='AttributedPad'):
            # the replacement keeps the original name; the old node is renamed aside
            node_name = pad_node.soft_get('name', pad_node.id)

            pad_mode = pad_node.soft_get('mode', None)
            replacement = Pad(graph, {'mode': pad_mode}).create_node()
            rename_nodes([(pad_node, node_name + '/to_be_removed'),
                          (replacement, node_name)])

            pad_node.in_port(0).get_connection().set_destination(replacement.in_port(0))

            # pads column 0 -> port 1, pads column 1 -> port 2
            for column, port_idx in ((0, 1), (1, 2)):
                pads_const = Const(graph, {'value': pad_node.pads[:, column]}).create_node()
                replacement.in_port(port_idx).connect(pads_const.out_port(0))

            if pad_mode == 'constant':
                # cast the fill value to the data type of whatever feeds Pad port 0
                fill_cast = create_op_with_const_inputs(
                    graph, ConvertLike, {0: pad_node.fill_value},
                    {'name': node_name + '/pad_value_convert'})
                data_source = replacement.in_port(0).get_source()
                fill_cast.in_port(1).connect(data_source)
                replacement.in_port(3).connect(fill_cast.out_port(0))

            pad_node.out_port(0).get_connection().set_source(replacement.out_port(0))
            graph.remove_node(pad_node.id)