def replace_pattern(self, graph: Graph, match: dict): clamp = match['clamp'] name = clamp.soft_get('name', clamp.id) min_value = max_value = None port_1_exist = clamp.has_port( 'in', 1) and not clamp.in_port(1).disconnected() port_2_exist = clamp.has_port( 'in', 2) and not clamp.in_port(2).disconnected() if port_1_exist and clamp.in_port(1).get_source().node.soft_get( 'type') == 'Const': min_value = clamp.in_port(1).data.get_value() if port_2_exist and clamp.in_port(2).get_source().node.soft_get( 'type') == 'Const': max_value = clamp.in_port(2).data.get_value() rename_node(clamp, name + '/TBR') if min_value is None or max_value is None: max_node = min_node = None if port_1_exist: max_node = Maximum(graph, {}).create_node() clamp.in_port(0).get_connection().set_destination( max_node.in_port(0)) clamp.in_port(1).get_connection().set_destination( max_node.in_port(1)) clamp.out_port(0).get_connection().set_source( max_node.out_port(0)) if port_2_exist: min_node = Minimum(graph, {}).create_node() if max_node is not None: max_node.out_port(0).get_connection().set_source( min_node.out_port(0)) max_node.out_port(0).connect(min_node.in_port(0)) else: clamp.in_port(0).get_connection().set_destination( min_node.in_port(0)) clamp.out_port(0).get_connection().set_source( min_node.out_port(0)) clamp.in_port(2).get_connection().set_destination( min_node.in_port(1)) assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used' rename_node(min_node if min_node is not None else max_node, name) else: a_clamp = AttributedClamp(graph, { 'name': name, 'min': min_value, 'max': max_value }).create_node() rename_node(a_clamp, name) clamp.in_port(0).get_connection().set_destination( a_clamp.in_port(0)) clamp.out_port(0).get_connection().set_source(a_clamp.out_port(0))
def find_and_replace_pattern(self, graph: Graph): for cbv in graph.get_op_nodes(op='ClipByValueTF'): cbv_name = cbv.soft_get('name', cbv.id) minimum = Minimum(graph, { 'name': cbv_name + '/CLipMinimum' }).create_node() maximum = Maximum(graph, { 'name': cbv_name + '/CLipMaximum' }).create_node() minimum.in_port(0).connect(cbv.in_port(0).get_source()) minimum.in_port(1).connect(cbv.in_port(2).get_source()) maximum.in_port(0).connect(minimum.out_port(0)) maximum.in_port(1).connect(cbv.in_port(1).get_source()) cbv.out_port(0).get_connection().set_source(maximum.out_port(0)) rename_nodes([(cbv, cbv_name + '/TBR'), (maximum, cbv_name)]) graph.remove_node(cbv.id)