def replace_op(self, graph: Graph, node: Node):
    """Replace an ONNX-semantics Pad node with an MO ``Pad`` node.

    ONNX passes the paddings as one ``[2 * N]`` input (input 1) and an optional fill value as
    input 2. The new ``Pad`` receives the data on input 0, ``pads_begin``/``pads_end`` (two
    ``[N]`` halves of the ONNX paddings, produced by a Split) on inputs 1/2, and, in
    ``constant`` mode, the fill value on input 3.

    :param graph: graph being transformed
    :param node: the ONNX Pad node to replace
    :return: list containing the id of the newly created ``Pad`` node
    """
    # save the original node name to use it in the new Pad op instance
    original_name = node.soft_get('name', node.id)
    rename_node(node, original_name + '/TBR')

    new_pad = Pad(graph, {'mode': node.soft_get('mode', None)}).create_node()
    rename_node(new_pad, original_name)

    node.in_port(0).get_connection().set_destination(new_pad.in_port(0))

    # The input with fill value is an optional third input in ONNX. When it is missing, input 3 of
    # the new Pad is left unconnected so the Pad's default zero fill value is used. A float 0.0
    # constant here would not match the element type of integer data inputs.
    if node.soft_get('mode') == 'constant' and not node.in_port(2).disconnected():
        node.in_port(2).get_connection().set_destination(new_pad.in_port(3))

    # convert ONNX representation of the pads as [2 * N] to MO representation: [N] and [N]
    split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
    node.in_port(1).get_connection().set_destination(split_pads.in_port(0))
    split_pads.out_port(0).connect(new_pad.in_port(1))
    split_pads.out_port(1).connect(new_pad.in_port(2))

    return [new_pad.id]
def find_and_replace_pattern(self, graph: Graph):
    """Swap every AttributedPad node for an MO ``Pad`` node that takes its paddings as inputs.

    The AttributedPad ``pads`` attribute is indexed as ``pads[:, 0]`` (begin) and ``pads[:, 1]``
    (end); each column becomes a Const feeding ``Pad`` input 1 or 2 respectively. In
    ``constant`` mode the ``fill_value`` attribute is passed through a ConvertLike (like-input:
    the Pad's data source) so the fill value matches the data input's type.
    """
    for attr_pad in graph.get_op_nodes(op='AttributedPad'):
        pad_name = attr_pad.soft_get('name', attr_pad.id)
        pad_mode = attr_pad.soft_get('mode', None)

        replacement = Pad(graph, {'mode': pad_mode}).create_node()
        # the replacement inherits the original name; the old node is marked for removal
        rename_nodes([(attr_pad, pad_name + '/to_be_removed'), (replacement, pad_name)])

        attr_pad.in_port(0).get_connection().set_destination(replacement.in_port(0))

        # column 0 of the pads attribute -> pads_begin (port 1), column 1 -> pads_end (port 2)
        for column in (0, 1):
            pads_const = Const(graph, {'value': attr_pad.pads[:, column]}).create_node()
            replacement.in_port(column + 1).connect(pads_const.out_port(0))

        if pad_mode == 'constant':
            # the Constant with the fill value is converted to the data type of the Pad first input
            fill_value = create_op_with_const_inputs(graph, ConvertLike, {0: attr_pad.fill_value},
                                                     {'name': pad_name + '/pad_value_convert'})
            fill_value.in_port(1).connect(replacement.in_port(0).get_source())
            replacement.in_port(3).connect(fill_value.out_port(0))

        attr_pad.out_port(0).get_connection().set_source(replacement.out_port(0))
        graph.remove_node(attr_pad.id)
def find_and_replace_pattern(self, graph: Graph):
    """Replace every TFPad node with an MO ``Pad`` node.

    TF passes the paddings as an ``[N, 2]`` tensor on input 1 and an optional fill value on
    input 2. The new ``Pad`` receives the data on input 0 and 1D ``pads_begin``/``pads_end``
    (obtained via Transpose -> Split -> Squeeze) on inputs 1/2. In ``constant`` mode, the TF
    fill value, if present, goes to input 3.
    """
    for tfpad in graph.get_op_nodes(op='TFPad'):
        # save the original node name to use it in the new Pad op instance
        original_name = tfpad.soft_get('name', tfpad.id)
        # rename the TFPad first so its original name can be given to the replacement node
        tfpad['name'] = original_name + '/to_be_removed'

        new_pad = Pad(graph, {'mode': tfpad.soft_get('mode', None)}).create_node()
        rename_node(new_pad, original_name)

        tfpad.in_port(0).get_connection().set_destination(new_pad.in_port(0))

        # The input with fill value is an optional third input in TF. When it is missing, input 3
        # of the new Pad is left unconnected so the Pad's default zero fill value is used. A float
        # 0.0 constant here would not match the element type of integer data inputs.
        if tfpad.soft_get('mode') == 'constant' and not tfpad.in_port(2).disconnected():
            tfpad.in_port(2).get_connection().set_destination(new_pad.in_port(3))

        # convert TF representation of the pads as [N, 2] to MO representation: [N] and [N]
        transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
        tfpad.in_port(1).get_connection().set_destination(transposed_pads.in_port(0))
        split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        transposed_pads.out_port(0).connect(split_pads.in_port(0))
        for port_ind in range(2):
            split_pads.add_output_port(port_ind, skip_if_exist=True)
            new_pad.in_port(port_ind + 1).connect(split_pads.out_port(port_ind))
            # each Split output is [1, N]; squeezing axis 0 yields the 1D [N] pads tensor
            new_pad.in_port(port_ind + 1).get_connection().insert_node(
                create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}))

        tfpad.out_port(0).get_connection().set_source(new_pad.out_port(0))
        graph.remove_node(tfpad.id)
def find_and_replace_pattern(self, graph: Graph):
    """Replace every AttributedPad node with an MO ``Pad`` node that takes paddings as inputs.

    The ``pads`` attribute is read as ``pads[:, 0]`` (begin) and ``pads[:, 1]`` (end); each
    becomes a Const feeding ``Pad`` input 1 or 2. In ``constant`` mode the ``fill_value``
    attribute is routed through ConvertLike, using the Pad's data source as the "like" input.
    Previously the raw attribute value was connected directly, so its dtype did not follow
    the data input's dtype.
    """
    for attr_pad in graph.get_op_nodes(op='AttributedPad'):
        # save the original node name to use it in the new Pad op instance
        original_name = attr_pad.soft_get('name', attr_pad.id)

        new_pad = Pad(graph, {'mode': attr_pad.soft_get('mode', None)}).create_node()
        rename_nodes([(attr_pad, original_name + '/to_be_removed'), (new_pad, original_name)])

        attr_pad.in_port(0).get_connection().set_destination(new_pad.in_port(0))

        new_pad.in_port(1).connect(Const(graph, {'value': attr_pad.pads[:, 0]}).create_node().out_port(0))
        new_pad.in_port(2).connect(Const(graph, {'value': attr_pad.pads[:, 1]}).create_node().out_port(0))

        if attr_pad.soft_get('mode') == 'constant':
            # create Constant node of proper data type (equal to the data type of the Pad first input)
            convert_pad_value = create_op_with_const_inputs(graph, ConvertLike, {0: attr_pad.fill_value},
                                                            {'name': original_name + '/pad_value_convert'})
            convert_pad_value.in_port(1).connect(new_pad.in_port(0).get_source())
            new_pad.in_port(3).connect(convert_pad_value.out_port(0))

        attr_pad.out_port(0).get_connection().set_source(new_pad.out_port(0))
        graph.remove_node(attr_pad.id)
def find_and_replace_pattern(self, graph: Graph):
    """Convert SpaceToBatch/BatchToSpace nodes from TF input layout to IE input layout.

    For every SpaceToBatch and BatchToSpace node:

    * input 2 (pads/crops as an ``[N, 2]`` tensor) is transposed to ``[2, N]``, split along
      axis 0 and squeezed, producing 1D begin/end tensors on inputs 2 and 3;
    * inputs 1 (block_shape), 2 (begin) and 3 (end) are each wrapped in a ``constant`` Pad that
      adds one leading element and ``rank(data) - len(block_shape) - 1`` trailing elements, so
      their length matches the rank of data input 0. block_shape is padded with 1; the begin/end
      Pads get no fill-value input (Pad's default zero fill value).
    """
    for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(op='BatchToSpace'):
        node.add_input_port(3, skip_if_exist=True)

        # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
        transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
        node.in_port(2).get_connection().set_destination(transposed_pads.in_port(0))
        split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        transposed_pads.out_port(0).connect(split_pads.in_port(0))
        for port_ind in range(2):
            node.in_port(port_ind + 2).connect(split_pads.out_port(port_ind))
            # each Split output is [1, N]; squeezing axis 0 yields a 1D [N] tensor
            node.in_port(port_ind + 2).get_connection().insert_node(
                create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}))

        # add zeros/ones to related inputs to align it with data input
        in0_rank = Rank(graph, {'name': node.name + '/rank_0'}).create_node()
        # Shape of block_shape (assumed 1D) is already the 1D tensor [len(block_shape)]; it must NOT
        # be unsqueezed, otherwise the differences below broadcast to [1, 1] and the Pad inputs
        # become 2D instead of 1D.
        in1_shape = Shape(graph, {'name': node.name + '/rank_1'}).create_node()

        diff_size = Sub(graph, {'name': node.name + '/sub_0'}).create_node()
        diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()
        # one leading element is added to every aligned input (the batch dimension)
        const_begin = Const(graph, {'value': int64_array([1])}).create_node()
        # the Pad fill value must be a scalar (0D) tensor
        const_pad_val = Const(graph, {'value': int64_array(1)}).create_node()
        # NOTE(review): Rank/Shape/Const values here are int64; if block_shape/pads arrive as int32
        # the Pad data and fill-value types may differ — confirm with the producing frontend.

        block_shape = Pad(graph, {'name': node.name + '/aligned_block_shape', 'mode': 'constant'}).create_node()

        # in case of SpaceToBatch begin = pads_begin, end = pads_end
        # in case of BatchToSpace begin = crops_begin, end = crops_end
        new_begin_name = '/aligned_pads_begin'
        new_end_name = '/aligned_pads_end'
        if node.type == 'BatchToSpace':
            new_begin_name = '/aligned_crops_begin'
            new_end_name = '/aligned_crops_end'

        begin = Pad(graph, {'name': node.name + new_begin_name, 'mode': 'constant'}).create_node()
        end = Pad(graph, {'name': node.name + new_end_name, 'mode': 'constant'}).create_node()

        # Rank yields a scalar; make it a 1-element 1D tensor to subtract the 1D shape from it
        in0_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                       {'name': node.name + '/1d_rank_of_0'}, in0_rank)

        node.in_port(0).get_source().connect(in0_rank.in_port(0))
        node.in_port(1).get_source().connect(in1_shape.in_port(0))
        # diff_size = [rank(data) - len(block_shape)]
        in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
        in1_shape.out_port(0).connect(diff_size.in_port(1))
        # diff = [rank(data) - len(block_shape) - 1]: number of trailing elements to append
        diff_size.out_port(0).connect(diff.in_port(0))
        const_begin.out_port(0).connect(diff.in_port(1))
        const_pad_val.out_port(0).connect(block_shape.in_port(3))

        # re-route inputs 1..3 of the node through their aligning Pad nodes
        inputs_array = [block_shape, begin, end]
        for idx, input_to_node in enumerate(inputs_array):
            node.in_port(idx + 1).get_connection().set_destination(input_to_node.in_port(0))
            const_begin.out_port(0).connect(input_to_node.in_port(1))
            diff.out_port(0).connect(input_to_node.in_port(2))
            input_to_node.out_port(0).connect(node.in_port(idx + 1))