def replace_pattern(self, graph: Graph, match: dict):
    """Swap the matched sub-graph for a single ``NormalizeOp`` node.

    The value found on one of the first two inputs of ``match['maximum']``
    is used as the ``eps`` attribute of the new node. If neither input
    carries a value, or the value does not have shape ``()``, a debug
    message is logged and the graph is left untouched.

    :param graph: graph in which the new nodes are created
    :param match: matched nodes; the keys ``'maximum'``, ``'square'``,
        ``'input'`` and ``'l2_normalize'`` are read here
    :return: None
    """
    maximum = match['maximum']

    # Take the value from input 0 of 'maximum', falling back to input 1.
    eps = None
    for port_idx in (0, 1):
        eps = maximum.in_port(port_idx).data.get_value()
        if eps is not None:
            break

    if eps is None or eps.shape != ():
        log.debug(
            'The value of the "maximum_y_data" is not defined or is not constant'
        )
        return

    # Node that produces the tensor consumed by 'square'; its name is the
    # base for the name of the new node.
    source_node = match['square'].in_port(0).get_source().node

    normalize_attrs = {
        'name': source_node.soft_get('name') + '/Normalize',
        'eps': eps,
        'across_spatial': 0,
        'channel_shared': 0,
    }
    normalize = NormalizeOp(graph, normalize_attrs).create_node()

    # All-ones weights: one element per entry of the last dimension of
    # match['input'].shape, using the data type recorded on match['input'].
    input_data = match['input']
    ones = np.ones(
        shape=int64_array([input_data.shape[-1]]),
        dtype=input_data.data_type,
    )
    weights = Const(graph, {'value': ones}).create_node()

    # Per the original note, the producer has 2 consumers, so its output
    # port is disconnected first and only then routed to the new node.
    source_node.out_port(0).disconnect()
    source_node.out_port(0).get_connection().set_destination(
        normalize.in_port(0))
    weights.out_port(0).get_connection().set_destination(
        normalize.in_port(1))

    # Whatever consumed the output of 'l2_normalize' now reads from the
    # new node instead.
    match['l2_normalize'].out_port(0).get_connection().set_source(
        normalize.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """Insert a ``NormalizeOp`` node in front of input 0 of ``match['merge']``.

    The new node is named after the merge node with a ``'/Normalize'``
    suffix and uses a fixed ``eps`` of ``1e-6``. Only input port 0 of the
    new node is connected by this method.

    :param graph: graph in which the new node is created
    :param match: matched nodes; the keys ``'bias_add'`` and ``'merge'``
        are read here
    :return: None
    """
    bias_add_node = match['bias_add']
    merge_node = match['merge']

    normalize_attrs = {
        'name': merge_node.name + '/Normalize',
        'eps': 1e-6,
        'across_spatial': 0,
        'channel_shared': 0,
    }
    normalize = NormalizeOp(graph, normalize_attrs).create_node()

    # Feed the output of 'bias_add' into the new node.
    bias_add_node.out_port(0).connect(normalize.in_port(0))

    # Free input 0 of 'merge' and let the new node drive it instead.
    merge_node.in_port(0).disconnect()
    normalize.out_port(0).get_connection().set_destination(
        merge_node.in_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """Swap the matched sub-graph for a ``NormalizeOp`` that keeps its name.

    The value found on one of the first two inputs of ``match['maximum']``
    is used as the ``eps`` attribute of the new node. If neither input
    carries a value, or the value does not have shape ``()``, a debug
    message is logged and the graph is left untouched.

    :param graph: graph in which the new nodes are created
    :param match: matched nodes; the keys ``'maximum'``, ``'l2_normalize'``,
        ``'input'``, ``'square'`` and ``'rsqrt'`` are read here
    :return: None
    """
    maximum = match['maximum']

    # Take the value from input 0 of 'maximum', falling back to input 1.
    eps = None
    for port_idx in (0, 1):
        eps = maximum.in_port(port_idx).data.get_value()
        if eps is not None:
            break

    if eps is None or eps.shape != ():
        log.debug(
            'The value of the "maximum_y_data" is not defined or is not constant'
        )
        return

    # The old 'l2_normalize' node gives up its name (it gets a
    # '/normalizel2' suffix) so that the new node can take the name over.
    l2_normalize = match['l2_normalize']
    output_name = l2_normalize.soft_get('name', l2_normalize.id)
    normalizel2_name = output_name + '/normalizel2'
    rename_node(l2_normalize, normalizel2_name)

    normalize_attrs = {
        'name': output_name,
        'eps': eps,
        'across_spatial': 0,
        'channel_shared': 0,
    }
    normalize = NormalizeOp(graph, normalize_attrs).create_node()
    rename_node(normalize, output_name)

    # All-ones weights: one element per entry of the last dimension of
    # match['input'].shape, using the data type recorded on match['input'].
    input_data = match['input']
    ones = np.ones(
        shape=int64_array([input_data.shape[-1]]),
        dtype=input_data.data_type,
    )
    weights = Const(graph, {'value': ones}).create_node()

    # Route the producer of the 'square' input into the new node, then
    # detach 'square' from that producer.
    match['square'].in_port(0).get_source().connect(normalize.in_port(0))
    match['square'].in_port(0).disconnect()

    # If input 0 of 'l2_normalize' is fed by 'rsqrt', detach input 1;
    # otherwise detach input 0.
    rsqrt_feeds_port_0 = (
        l2_normalize.in_port(0).get_source().node.id == match['rsqrt'].id
    )
    detach_idx = 1 if rsqrt_feeds_port_0 else 0
    l2_normalize.in_port(detach_idx).disconnect()

    weights.out_port(0).get_connection().set_destination(
        normalize.in_port(1))

    # Whatever consumed the output of 'l2_normalize' now reads from the
    # new node instead.
    l2_normalize.out_port(0).get_connection().set_source(
        normalize.out_port(0))