def replace_pattern(self, graph: Graph, match: dict):
    """Replace a matched L2-normalization subgraph with a single Normalize op.

    The epsilon is read from input 0 of the matched ``maximum`` node, or from
    input 1 if input 0 has no known value. If neither yields a 0-d value the
    match is skipped. An all-ones constant of length ``input.shape[-1]`` (with
    the data type of ``match['input']``) is connected as the Normalize weights.

    :param graph: graph being transformed.
    :param match: matched pattern; this method reads the ``maximum``,
        ``square``, ``input`` and ``l2_normalize`` entries.
    """
    maximum = match['maximum']
    y = maximum.in_port(0).data.get_value()
    if y is None:
        y = maximum.in_port(1).data.get_value()

    if y is None or y.shape != ():
        log.debug('The value of the "maximum_y_data" is not defined or is not constant')
        return

    # Keep the exact producer port of the normalized tensor: it is not
    # necessarily output 0 of its producer node.
    source = match['square'].in_port(0).get_source()
    normalize_input_node = source.node
    normalize_node = NormalizeOp(graph, {'name': normalize_input_node.soft_get('name') + '/Normalize',
                                         'eps': y,
                                         'across_spatial': 0,
                                         'channel_shared': 0}).create_node()

    weights_node = Const(graph, {'value': np.ones(shape=int64_array([match['input'].shape[-1]]),
                                                  dtype=match['input'].data_type)}).create_node()

    l2_normalize = match['l2_normalize']

    # Add Normalize as a new consumer of the tensor instead of replacing all
    # existing consumers: consumers outside the pattern must stay connected.
    source.connect(normalize_node.in_port(0))

    # Detach only the pattern's own consumers of the tensor: ``square`` and
    # the input(s) of ``l2_normalize`` that read from the same producer port.
    match['square'].in_port(0).disconnect()
    for port in l2_normalize.in_ports().values():
        if port.disconnected():
            continue
        producer = port.get_source()
        if producer is not None and producer.node.id == source.node.id and producer.idx == source.idx:
            port.disconnect()

    weights_node.out_port(0).get_connection().set_destination(normalize_node.in_port(1))
    # Everything that consumed the l2_normalize result now reads from Normalize.
    l2_normalize.out_port(0).get_connection().set_source(normalize_node.out_port(0))
def extract(cls, node):
    """Fill a Caffe Normalize node's attributes from its layer definition.

    Attributes collected from the layer's ``norm_param`` (with nested params
    flattened) are combined with the result of ``weights_biases`` on the
    node's model proto, and stored on the node via
    ``NormalizeOp.update_node_stat``.

    :param node: node whose ``pb`` is the Caffe layer proto; ``model_pb`` is
        passed to ``weights_biases`` (presumably the layer's blob-carrying
        proto — confirm).
    :return: ``cls.enabled``.
    """
    layer_attrs = collect_attributes(node.pb.norm_param, enable_flattening_nested_params=True)
    # NOTE(review): the leading ``False`` presumably disables bias collection
    # (Normalize has weights only) — confirm against weights_biases.
    blob_attrs = weights_biases(False, node.model_pb)
    layer_attrs.update(blob_attrs)

    NormalizeOp.update_node_stat(node, layer_attrs)
    return cls.enabled
def extract(cls, node):
    """Fill an ONNX Normalize node's attributes from its ONNX attributes.

    ``across_spatial`` and ``channel_shared`` are read as integers (default 0)
    and stored as booleans; ``eps`` is read as a float (default 0). The layout
    is always recorded as ``'NCHW'``.

    :param node: node holding the ONNX operator.
    :return: ``cls.enabled``.
    """
    # Integer flags in ONNX become Python booleans on the node.
    flag_attrs = {
        flag: bool(onnx_attr(node, flag, 'i', default=0))
        for flag in ('across_spatial', 'channel_shared')
    }
    attrs = dict(flag_attrs,
                 eps=onnx_attr(node, 'eps', 'f', default=0),
                 layout='NCHW')

    NormalizeOp.update_node_stat(node, attrs)
    return cls.enabled
def replace_pattern(self, graph: Graph, match: dict):
    """Insert a Normalize op between the matched ``bias_add`` and ``merge``.

    A Normalize node named ``<merge name>/Normalize`` (eps=1e-6,
    across_spatial=0, channel_shared=0) is fed from output 0 of ``bias_add``.
    Input 0 of ``merge`` is disconnected from its previous producer and
    re-attached to the Normalize output.

    :param graph: graph being transformed.
    :param match: matched pattern; reads the ``bias_add`` and ``merge`` entries.
    """
    producer = match['bias_add']
    consumer = match['merge']

    normalize = NormalizeOp(graph, dict(name=consumer.name + '/Normalize',
                                        eps=1e-6,
                                        across_spatial=0,
                                        channel_shared=0)).create_node()
    # NOTE(review): no weights are connected to Normalize input 1 here —
    # confirm the op is valid without them in this context.

    # Add Normalize as an extra consumer of bias_add's output 0.
    producer.out_port(0).connect(normalize.in_port(0))

    # Route merge's first input through Normalize instead of its old producer.
    merge_input = consumer.in_port(0)
    merge_input.disconnect()
    normalize.out_port(0).get_connection().set_destination(merge_input)
def replace_pattern(self, graph: Graph, match: dict):
    """Fuse a matched L2-normalization subgraph into a Normalize op.

    The epsilon comes from input 0 of ``maximum``, or from input 1 if input 0
    has no known value. The match is skipped unless it is a 0-d value. The new
    Normalize node takes over the original name of ``l2_normalize``, which is
    renamed to ``<name>/normalizel2``. Its weights are an all-ones constant of
    length ``input.shape[-1]``.

    :param graph: graph being transformed.
    :param match: matched pattern; reads ``maximum``, ``l2_normalize``,
        ``square``, ``rsqrt`` and ``input``.
    """
    maximum = match['maximum']
    eps_value = maximum.in_port(0).data.get_value()
    if eps_value is None:
        eps_value = maximum.in_port(1).data.get_value()

    if eps_value is None or eps_value.shape != ():
        log.debug('The value of the "maximum_y_data" is not defined or is not constant')
        return

    l2_norm = match['l2_normalize']
    square = match['square']

    # The replaced node will no longer be an output, so it gives up its name
    # to the new Normalize node.
    original_name = l2_norm.soft_get('name', l2_norm.id)
    rename_node(l2_norm, original_name + '/normalizel2')

    normalize = NormalizeOp(graph, dict(name=original_name,
                                        eps=eps_value,
                                        across_spatial=0,
                                        channel_shared=0)).create_node()
    rename_node(normalize, original_name)

    data_node = match['input']
    ones = np.ones(shape=int64_array([data_node.shape[-1]]), dtype=data_node.data_type)
    weights = Const(graph, {'value': ones}).create_node()

    square.in_port(0).get_source().connect(normalize.in_port(0))
    square.in_port(0).disconnect()

    # l2_normalize multiplies the data by the rsqrt result; drop whichever of
    # its two inputs is the data (i.e. the one NOT coming from rsqrt).
    rsqrt_on_port0 = l2_norm.in_port(0).get_source().node.id == match['rsqrt'].id
    l2_norm.in_port(1 if rsqrt_on_port0 else 0).disconnect()

    weights.out_port(0).get_connection().set_destination(normalize.in_port(1))
    l2_norm.out_port(0).get_connection().set_source(normalize.out_port(0))
def extract(cls, node):
    """Extract a Kaldi NormalizeComponent into a Normalize node.

    Reads from ``node.parameters``:
      * the input dimension after ``<Dim>`` or, if absent, ``<InputDim>``;
      * the optional ``<TargetRms>`` float (default 1.0);
      * the optional ``<AddLogStddev>`` bool (default False).

    The Normalize weight is the single scalar ``target_rms * sqrt(in_dim)``,
    embedded as input 1; ``eps`` is 1e-8 and channels are shared.

    :raises Error: if neither ``<Dim>`` nor ``<InputDim>`` is present, or if
        ``<AddLogStddev>`` is set (unsupported).
    :return: ``cls.enabled``.
    """
    pb = node.parameters
    try:
        collect_until_token(pb, b'<Dim>')
    except Error:
        try:
            pb.seek(0)
            collect_until_token(pb, b'<InputDim>')
        except Error:
            raise Error("Neither <Dim> nor <InputDim> were found")
    in_dim = read_binary_integer32_token(pb)

    # A failed token search can leave the stream advanced (hence the seek(0)
    # before the <InputDim> retry above). Rewind after a missing <TargetRms>
    # so that a following <AddLogStddev> token is still found.
    search_start = pb.tell()
    try:
        collect_until_token(pb, b'<TargetRms>')
        target_rms = read_binary_float_token(pb)
    except Error:
        # model does not contain TargetRms
        target_rms = 1.0
        pb.seek(search_start)

    try:
        collect_until_token(pb, b'<AddLogStddev>')
        add_log = read_binary_bool_token(pb)
    except Error:
        # model does not contain AddLogStddev
        add_log = False

    if add_log is not False:
        raise Error("AddLogStddev True in Normalize component is not supported")

    scale = target_rms * np.sqrt(in_dim)

    attrs = {
        'eps': 0.00000001,
        'across_spatial': 0,
        'channel_shared': 1,
        'in_dim': in_dim,
    }
    embed_input(attrs, 1, 'weights', [scale])

    NormalizeOp.update_node_stat(node, attrs)
    return cls.enabled