def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    # 1/x expressed as a Power node: (1 * x + 0) ** -1
    reciprocal = Power(graph, dict(scale=1, power=-1, shift=0, name=node.name + '/power_'))
    out_node = reciprocal.create_node([node.in_node(0)])
    return [out_node.id]

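# A hedged reference for the Power semantics assumed throughout these
# snippets (an illustrative sketch, not part of the codebase):
# y = (scale * x + shift) ** power. With scale=1, shift=0, power=-1 the
# replacement above computes the reciprocal 1/x.
import numpy as np

def power_ref(x, scale=1.0, shift=0.0, power=1.0):
    # assumed Power layer semantics
    return (scale * x + shift) ** power

x = np.array([0.5, 2.0, 4.0])
assert np.allclose(power_ref(x, scale=1, shift=0, power=-1), 1.0 / x)
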
def replace_pattern(graph: Graph, match: dict):
    consumers = [n for n in match
                 if n not in ['mul_add', 'pow_d'] and not check_node_usages_out_of_match(match, n)]
    if consumers:
        log.warning('Power(mul_add,pow) pattern was detected. Non-pattern consumers of nodes: "{}" were found.'
                    ' Won\'t replace'.format(', '.join([match[n].id for n in consumers])))
        return
    mul_add = match['mul_add']
    pow_node = match['pow']
    new_power = Power(graph, {'name': mul_add.name + '/fused_power',
                              'shift': mul_add.shift,
                              'scale': mul_add.scale,
                              'power': pow_node.power}).create_node()

    source = mul_add.in_port(0).get_connection().get_source()
    mul_add.in_port(0).disconnect()
    new_power.in_port(0).connect(source)
    pow_node.out_port(0).get_connection().set_source(new_power.out_port(0))

    log.debug('Power nodes {} and {} were fused to single Power node {}'.format(
        mul_add.name, pow_node.name, new_power.name))

def replace_op(self, graph: Graph, node: Node):
    # scale=0 makes the Power node produce a constant zero of the input's shape
    power = Power(graph, dict(scale=0, name=node.name + '/Power/')).create_node()

    # Reconnect inputs to this new node
    node.in_port(0).get_connection().set_destination(power.in_port(0))
    node.out_port(0).get_connection().set_source(power.out_port(0))

    return [power.id]

def _create_sub(graph: Graph, input_1: Node, port_1: int, input_2: Node, port_2: int):
    # a - b is built as a + (-1 * b): negate the subtrahend, then sum
    negate = Power(graph, dict(scale=-1, name=input_2.name + '/negate_'))
    add = Eltwise(graph, dict(operation='sum', name=input_1.name + '/add_'))
    out_node = add.create_node([(input_1, port_1),
                                negate.create_node([(input_2, port_2)])])
    return out_node

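# Hedged check of the identity _create_sub relies on: a - b == a + (-1 * b),
# where the negation is a Power node with scale=-1 (illustrative NumPy only).
import numpy as np

a = np.array([3.0, 1.0, -2.0])
b = np.array([1.0, 4.0, -5.0])
assert np.allclose(a - b, a + (-1.0 * b))
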
def test_power_two_input_infer3(self):
    graph = self.create_graph(single_input=False)
    power_node = Node(graph, 'power')
    input2 = Node(graph, 'input2')
    input2.value = None
    Power.infer(power_node)
    self.assertIsNone(power_node.out_node().shape)

def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    negate = Power(graph, dict(scale=-1, name=node.name + '/negate_'))
    add = Eltwise(graph, dict(operation='sum', name=node.name + '/add_'))
    out_node = add.create_node([
        (node.in_node(0), node.in_edge(0)['out']),
        negate.create_node([(node.in_node(1), node.in_edge(1)['out'])])
    ])
    # Replace the edge from out port 0 of the matched node with an edge from node out_node.id with port 0.
    # The "explicit" version of the return value is: [(out_node.id, 0)]
    return [out_node.id]

def test_power_two_input_infer1(self):
    graph = self.create_graph(single_input=False)
    graph.graph['layout'] = 'NCHW'
    power_node = Node(graph, 'power')
    Power.infer(power_node)
    self.assertTrue(np.array_equal(power_node.out_node().shape, power_node.in_node(0).shape))

def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    shape = node.in_port(0).data.get_shape().copy()

    assert shape[1] % node.group == 0

    # Lp group pooling: x^p -> reshape to [N, C // group, group] -> sum over
    # the group axis -> raise to 1/p
    power_node = Power(graph, attrs={'name': node.id + '_power', 'power': node.p}).create_node()

    reshape_node = create_op_node_with_second_input(graph, Reshape,
                                                    int64_array([shape[0], shape[1] // node.group, node.group]),
                                                    {'name': node.id + '_reshape'})
    reshape_node.in_port(0).connect(power_node.out_port(0))

    reducesum_node = create_op_node_with_second_input(graph, ReduceSum, int64_array([2]),
                                                      {'name': node.id + '_sum', 'keep_dims': False})
    reducesum_node.in_port(0).connect(reshape_node.out_port(0))

    invpower_node = Power(graph, attrs={'name': node.id + '_invpower', 'power': 1.0 / node.p}).create_node()
    invpower_node.in_port(0).connect(reducesum_node.out_port(0))

    node.in_port(0).get_connection().set_destination(power_node.in_port(0))
    node.out_port(0).get_connection().set_source(invpower_node.out_port(0))

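# Hedged NumPy sketch of the decomposition above, assuming a 2D [N, C] input
# (lp_group_ref is an illustrative name): sum(x^p)^(1/p) over groups of
# channels becomes Power(p) -> Reshape([N, C // group, group]) ->
# ReduceSum(axis=2) -> Power(1/p).
import numpy as np

def lp_group_ref(x, group, p):
    n, c = x.shape
    y = (x ** p).reshape(n, c // group, group).sum(axis=2)
    return y ** (1.0 / p)

x = np.abs(np.random.rand(2, 6))
print(lp_group_ref(x, group=3, p=2.0).shape)  # (2, 2): each group of 3 channels collapses to one
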
def replace_op(self, graph: Graph, node: Node):
    # Div(a, b) is rewritten as a * b^-1
    reciprocal = Power(graph, {'scale': 1, 'power': np.float64(-1), 'shift': 0,
                               'name': node.name + '/reciprocal_'}).create_node()
    mul = Eltwise(graph, {'operation': 'mul', 'name': node.name + '/mul_'}).create_node()

    # Connect nodes: the divisor goes through the reciprocal, then into the multiply
    node.in_port(1).get_connection().set_destination(reciprocal.in_port(0))
    node.in_port(0).get_connection().set_destination(mul.in_port(1))
    reciprocal.out_port(0).connect(mul.in_port(0))

    # The "explicit" version of the return value is: [(mul.id, 0)]
    return [mul.id]

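# Hedged check of the identity behind this Div replacement: a / b == a * b**-1,
# with the reciprocal expressed through a Power node (illustrative NumPy only).
import numpy as np

a = np.array([1.0, 9.0, -4.0])
b = np.array([2.0, 3.0, 8.0])
assert np.allclose(a / b, a * b ** -1.0)
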
def extract(node: Node):
    pb = node.pb
    assert pb, 'Protobuf layer can not be empty'
    param = pb.power_param
    attrs = {
        'output_spatial_shape': None,
        'power': param.power,
        'scale': param.scale,
        'shift': param.shift,
    }
    Power.update_node_stat(node, attrs)
    return __class__.enabled

def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    reciprocal = Power(graph, dict(scale=1, power=np.float64(-1), shift=0,
                                   name=node.name + '/reciprocal_'))
    mul = Eltwise(graph, dict(operation='mul', name=node.name + '/mul_'))
    out_node = mul.create_node([
        (node.in_node(0), node.in_edge(0)['out']),
        reciprocal.create_node([(node.in_node(1), node.in_edge(1)['out'])])
    ])
    # Replace the edge from out port 0 of the matched node with an edge from node out_node.id with port 0.
    # The "explicit" version of the return value is: [(out_node.id, 0)]
    return [out_node.id]

def replace_pattern(self, graph: Graph, match: dict):
    node = match['minimum']
    # Constant propagation case
    if node.in_node(0).value is not None and node.in_node(1).value is not None:
        return

    # min(a, b) is rewritten as -max(-a, -b); each negation is a Power node with scale=-1
    negate_1 = Power(graph, dict(scale=-1, name=node.name + '/negate1_'))
    negate_2 = Power(graph, dict(scale=-1, name=node.name + '/negate2_'))
    maximum = Eltwise(graph, dict(operation='max', name=node.name + '/Max_'))
    negate_output = Power(graph, dict(scale=-1, name=node.name + '/negate_out_'))

    negate_output.create_node_with_data(
        inputs=[maximum.create_node_with_data([negate_1.create_node_with_data([node.in_node(0)]),
                                               negate_2.create_node_with_data([node.in_node(1)])])],
        data_nodes=node.out_node())

    # Delete the Minimum vertex
    node.graph.remove_node(node.id)

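# Hedged check of the identity used above: min(a, b) == -max(-a, -b), with
# each negation realized as a Power node with scale=-1 (illustrative NumPy only).
import numpy as np

a = np.array([3.0, -1.0, 2.0])
b = np.array([1.0, 4.0, 2.0])
assert np.allclose(np.minimum(a, b), -np.maximum(-a, -b))
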
def replace_pattern(graph: Graph, match: dict):
    op = match['op']
    op_type = op.type

    const_port, tensor_port = get_value_in_port(op), get_tensor_in_port(op)
    if const_port is None or tensor_port is None:
        return

    value = const_port.data.get_value()
    assert value is not None
    if value.size != 1:
        return
    value = value.item(0)

    # Map the scalar eltwise onto a single Power node:
    # Add -> shift, Multiply -> scale, Pow -> power
    assert op_type in EltwisesWithScalarInputToPower.eltw_types
    if op_type == 'Add':
        delete_node = value == 0
        Power.update_node_stat(op, {'shift': value})
    elif op_type == 'Multiply':
        delete_node = value == 1
        Power.update_node_stat(op, {'scale': value})
    elif op_type == 'Pow':
        delete_node = value == 1
        Power.update_node_stat(op, {'power': value})

    const_port.disconnect()
    if tensor_port.idx != 0:
        tensor_port.get_connection().set_destination(op.in_port(0))

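# Hedged sketch of the rewrites above under the assumed Power semantics
# y = (scale * x + shift) ** power (power_ref is an illustrative helper):
# Add(x, c) maps to shift=c, Multiply(x, c) to scale=c, Pow(x, c) to power=c.
import numpy as np

def power_ref(x, scale=1.0, shift=0.0, power=1.0):
    return (scale * x + shift) ** power

x, c = np.array([1.0, 2.0, 3.0]), 2.0
assert np.allclose(power_ref(x, shift=c), x + c)   # Add
assert np.allclose(power_ref(x, scale=c), x * c)   # Multiply
assert np.allclose(power_ref(x, power=c), x ** c)  # Pow
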
def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    if (node.data_format != b'NHWC' or
            len(node.in_nodes()) != 5 or
            node.in_node(0).value is not None or  # input
            node.in_node(1).value is None or  # scale
            node.in_node(2).value is None or  # offset
            node.in_node(3).value is not None or  # mean
            node.in_node(4).value is not None or  # variance
            node.in_node(1).value.ndim != 1 or
            node.in_node(2).value.ndim != 1):
        return

    # Decompose the fused batch norm into
    # (x - mean) * (variance + eps)^-0.5 * scale + offset
    scale_mul = Eltwise(graph, dict(operation='mul', name=node.name + '/scale_mul_'))
    shift_add = Eltwise(graph, dict(operation='sum', name=node.name + '/shift_add_'))
    mean_add = Eltwise(graph, dict(operation='sum', name=node.name + '/mean_add_'))
    variance_mul = Eltwise(graph, dict(operation='mul', name=node.name + '/variance_mul_'))

    mean_negate = Power(graph, dict(scale=-1, name=node.name + '/mean_negate_'))
    mean_arg = mean_add.create_node_with_data([
        node.in_node(0),
        mean_negate.create_node_with_data([node.in_node(3)])
    ])

    variance_square = Power(graph, dict(power=2, name=node.name + '/variance_square_'))
    variance_denom = Power(graph, dict(shift=node.eps, power=-0.5, name=node.name + '/variance_denom_'))
    variance_arg = variance_mul.create_node_with_data([
        mean_arg,
        variance_denom.create_node_with_data([node.in_node(4)])
    ])

    shift_add.create_node_with_data([
        scale_mul.create_node_with_data([variance_arg, node.in_node(1)]),
        node.in_node(2)
    ], data_nodes=node.out_node())

    node.graph.remove_node(node.id)

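# Hedged NumPy reference for the decomposition above (batch_norm_ref is an
# illustrative name): y = (x - mean) * (variance + eps)**-0.5 * scale + offset,
# built from Eltwise mul/sum plus Power nodes for the negate and rsqrt steps.
import numpy as np

def batch_norm_ref(x, scale, offset, mean, variance, eps):
    return (x - mean) * (variance + eps) ** -0.5 * scale + offset

x = np.random.rand(2, 4).astype(np.float32)
mean, variance = x.mean(axis=0), x.var(axis=0)
y = batch_norm_ref(x, np.ones(4), np.zeros(4), mean, variance, eps=1e-3)
print(y.shape)  # (2, 4)
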
def extract(cls, node):
    Power.update_node_stat(node, {'power': -0.5})
    return cls.enabled

def extract(node):
    # update the attributes of the node
    Power.update_node_stat(node, {'power': 1 / 2, 'op': SqrtExtractor.op})
    return __class__.enabled

def extract(node):
    Power.update_node_stat(node, {'power': -0.5})
    return __class__.enabled

def extract(node):
    # update the attributes of the node
    Power.update_node_stat(node, {'power': 2})
    return __class__.enabled

def extract(node):
    Power.update_node_stat(node, {'scale': 0})
    return __class__.enabled

def extract(node: Node):
    Power.update_node_stat(node)
    return __class__.enabled

def replace_pattern(self, graph: Graph, match: dict):
    const = 0.99
    merge = match['merge']
    digits = significant_digits()
    pnorm = Power(graph, {'name': merge.name + '/reciprocal_',
                          'type': 'PNORM',
                          'significant': digits[0],
                          'to_significant': digits[1],
                          'scale': 1,
                          'shift': 0,
                          'power': get_power_attr()}).create_node()
    merge.in_port(0).get_connection().set_destination(pnorm.in_port(0))

    in_shape = pnorm.in_port(0).data.get_shape()
    in_shape = list(in_shape)
    in_shape.insert(0, 1)

    reshape1 = Reshape(graph, {'name': merge.name + '/Reshape_Node1'}).create_node()
    reshape_dim1 = Const(graph, {'value': np.array(in_shape),
                                 'name': merge.id + '/Reshape_Dim1'}).create_node()
    pnorm.out_port(0).connect(reshape1.in_port(0))
    reshape1.in_port(1).connect(reshape_dim1.out_port(0))

    concat_node = Concat(graph, {'axis': 0,
                                 'name': merge.name + '/Concat_',
                                 'override_output_shape': True}).create_node()
    const3 = Const(graph, {'name': merge.name + '/const_reduce', 'value': 0}).create_node()

    for ii, idx in enumerate(range(pnorm.significant, pnorm.to_significant + 1, 1)):
        const_node = Const(graph, {'value': float_array(math.pow(const, idx)),
                                   'name': merge.name + '/Const_' + str(ii)}).create_node()

        mul_node = Mul(graph, {'name': merge.name + '/Mul_' + str(ii)}).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))
        reshape1.out_port(0).connect(mul_node.in_port(1))  # connect to the graph node

        mul_node2 = Mul(graph, {'name': merge.name + '/Mul_Div_' + str(ii)}).create_node()
        const_node2 = Const(graph, {'value': float_array(math.pow(const, -1 * idx)),
                                    'name': merge.name + '/Const_Pow_' + str(ii)}).create_node()

        cast_node = ExpOp(graph, {'name': merge.name + '/Exp_' + str(idx)}).create_node()
        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))

        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(mul_node2.out_port(0))

    in_shape = pnorm.in_port(0).data.get_shape()
    in_shape = list(in_shape)

    reducesum_node = ReduceMean(graph, {'name': merge.id + '/_pnorm_reduced_sum',
                                        'keep_dims': True,
                                        'in_ports_count': 2,
                                        'shape': in_shape,
                                        'axis': 0,
                                        'need_shape_inference': None,
                                        'infer': reduce_infer}).create_node()
    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(concat_node.out_port(0))

    reshape = Reshape(graph, {'name': merge.name + '/Reshape_Node'}).create_node()
    reshape_dim = Const(graph, {'value': np.array(in_shape),
                                'name': merge.id + '/Reshape_Dim'}).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))

def extract(cls, node):
    Power.update_node_stat(node, {'scale': 0})
    return cls.enabled

def extract(node: Node):
    scale = onnx_attr(node, 'scale', 'f', default=np.array(1.0), dst_type=lambda x: np.array(x))
    Power.update_node_stat(node, {'scale': scale})
    return __class__.enabled

def extract(cls, node):
    # update the attributes of the node
    Power.update_node_stat(node, {'power': 2})
    return cls.enabled

def replace_pattern(self, graph: Graph, match: dict):
    assert match['operator'].has('multiplication_transparent_ports')

    port = match['operator'].input_ports_with(match['quantized'])
    assert len(port) >= 1
    if len(port) > 1:
        log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it is consumed more'
                  ' than once'.format(match['quantized'].name))
        return
    assert len(port) == 1
    port = port[0]

    applicable = [pair for pair in match['operator'].multiplication_transparent_ports if pair[0] == port]
    if len(applicable) == 0:
        return

    # Look at the 3rd and 4th inputs of Quantize -- they have constants that should be passed through.
    # Assume that the constant that should be passed through is a scalar.
    quantize = match['quantize']
    output_low = quantize.in_node(3)
    output_high = quantize.in_node(4)

    if not output_low.has_valid('value') and not output_high.has_valid('value'):
        return

    output_low = output_low.value
    output_high = output_high.value

    # This pass is applicable for binarization only. Other intX variants are not relevant.
    if quantize.levels != 2:
        return

    # Recognize two cases: 0/+1 and -1/+1.
    zp1 = np.all(output_low == 0) or np.all(output_high == 0)
    m1p1 = np.all(-output_low == output_high)
    if (not zp1 and not m1p1) or (zp1 and m1p1):
        log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it doesn\'t have one of'
                  ' the 0/+1 or -1/+1 forms.'.format(match['quantized'].name))
        return

    # Recognize scalar
    if len(np.unique(output_low)) != 1 or len(np.unique(output_high)) != 1:
        log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because output_low or output_high '
                  'cannot be interpreted as scalars.'.format(match['quantized'].name))
        return

    # TODO: Extract a real scalar from the 3rd and 4th inputs; reusing the original tensors is dangerous because
    #       they may have incompatible shapes.
    mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)

    # Patch the inflow path (by dividing by mult_term).
    # Put a new Power/Mul combination here:
    #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator

    if len(match['quantized'].out_nodes()) > 1:
        log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
        return

    div_op = Power(graph, {'name': quantize.name + '/DivNormalize', 'power': -1.0})
    div_output = div_op.create_node_with_data([mult_term])

    for i in [3, 4]:
        match['quantize'].insert_node_with_data_before(
            match['quantize'].in_node(i),
            Mul,
            dict(name=quantize.name + '/MulNormalize'),
            additional_inputs=[div_output],
        )

    match['quantized'].value = None  # reset the value because it will be recomputed
    match['quantize'].infer(match['quantize'])

    # Put a complementary new Mul node here:   operator -->---(here)-----> operator.out_node()
    match['operator'].insert_node_with_data_after(
        match['operator'].out_node(),
        Mul,
        dict(name=match['operator'].name + '/MulNormalize'),
        [mult_term],
    )

    # Disable 'operator' fusion with linear ops; otherwise it would annihilate the changes we just made
    match['operator']['can_be_fused'] = False

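# Hedged sketch of the normalization this pass performs for the -1/+1 case
# (illustrative NumPy only): if a 2-level quantize emits {-m, +m}, dividing
# the output range by m (the Power node with power=-1 feeding the Mul on the
# 3rd/4th inputs) turns the weights into {-1, +1}, and the complementary Mul
# by m after the multiplication-transparent operator restores the scale.
import numpy as np

m = 0.25
w = np.array([-m, m, m, -m])  # binarized weights in {-m, +m}
x = np.array([1.0, 2.0, 3.0, 4.0])
assert np.allclose(np.dot(w, x), np.dot(w / m, x) * m)
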