def replace_pattern(self, graph: Graph, match: dict):
    """Decompose an NHWC FusedBatchNorm whose mean/variance are not constant.

    The matched node is replaced by a chain of Eltwise/Power operations that
    presumably computes (assuming Power evaluates ``(shift + scale * x) ** power``
    with scale=1, shift=0 defaults -- TODO confirm against the Power op)::

        out = ((x - mean) * (variance + eps) ** -0.5) * scale + offset

    The final Eltwise writes into the original output data node
    (``node.out_node()``), and the matched node is then removed from the graph.

    The rewrite is skipped, leaving the graph untouched, unless all of these hold:

    * data_format is ``b'NHWC'``;
    * the node has exactly 5 inputs;
    * input (port 0), mean (port 3) and variance (port 4) carry no constant value;
    * scale (port 1) and offset (port 2) are constant 1-D arrays.

    :param graph: graph containing the matched node.
    :param match: pattern match; ``match['op']`` is the batch-norm node.
    """
    node = match['op']
    if (node.data_format != b'NHWC' or
            len(node.in_nodes()) != 5 or
            node.in_node(0).value is not None or  # input
            node.in_node(1).value is None or  # scale
            node.in_node(2).value is None or  # offset
            node.in_node(3).value is not None or  # mean
            node.in_node(4).value is not None or  # variance
            node.in_node(1).value.ndim != 1 or
            node.in_node(2).value.ndim != 1):
        return

    scale_mul = Eltwise(graph, dict(operation='mul', name=node.name + '/scale_mul_'))
    shift_add = Eltwise(graph, dict(operation='sum', name=node.name + '/shift_add_'))
    mean_add = Eltwise(graph, dict(operation='sum', name=node.name + '/mean_add_'))
    variance_mul = Eltwise(graph, dict(operation='mul', name=node.name + '/variance_mul_'))

    # x - mean, built as x + (-1 * mean).
    mean_negate = Power(graph, dict(scale=-1, name=node.name + '/mean_negate_'))
    mean_arg = mean_add.create_node_with_data([
        node.in_node(0),
        mean_negate.create_node_with_data([node.in_node(3)])])

    # (variance + eps) ** -0.5, i.e. the reciprocal standard deviation.
    variance_denom = Power(graph, dict(shift=node.eps, power=-0.5, name=node.name + '/variance_denom_'))
    variance_arg = variance_mul.create_node_with_data([
        mean_arg,
        variance_denom.create_node_with_data([node.in_node(4)])])

    # normalized * scale + offset, written into the original output data node.
    shift_add.create_node_with_data([
        scale_mul.create_node_with_data([
            variance_arg,
            node.in_node(1)]),
        node.in_node(2)],
        data_nodes=node.out_node())

    node.graph.remove_node(node.id)
def replace_pattern(self, graph: Graph, match: dict):
    """Rewrite ``Minimum(a, b)`` as ``-Maximum(-a, -b)``.

    Each negation is a Power op with ``scale=-1``, and the maximum is an Eltwise
    op with ``operation='max'``. The final negation writes into the original
    output data node, and the Minimum node is then removed.

    Nothing is done when both inputs already hold constant values; that case is
    left for constant propagation.

    :param graph: graph containing the matched node.
    :param match: pattern match; ``match['minimum']`` is the Minimum node.
    """
    minimum_node = match['minimum']
    first_input = minimum_node.in_node(0)

    # Both operands constant: let constant propagation fold this node instead.
    if first_input.value is not None and minimum_node.in_node(1).value is not None:
        return

    prefix = minimum_node.name
    negate_first_op = Power(graph, dict(scale=-1, name=prefix + '/negate1_'))
    negate_second_op = Power(graph, dict(scale=-1, name=prefix + '/negate2_'))
    max_op = Eltwise(graph, dict(operation='max', name=prefix + '/Max_'))
    negate_result_op = Power(graph, dict(scale=-1, name=prefix + '/negate_out_'))

    negated_first = negate_first_op.create_node_with_data([first_input])
    negated_second = negate_second_op.create_node_with_data([minimum_node.in_node(1)])
    max_of_negated = max_op.create_node_with_data([negated_first, negated_second])

    # Reuse the existing output data node, presumably so that consumers of the
    # Minimum output stay connected -- TODO confirm create_node_with_data semantics.
    negate_result_op.create_node_with_data(
        inputs=[max_of_negated],
        data_nodes=minimum_node.out_node(),
    )

    minimum_node.graph.remove_node(minimum_node.id)
def replace_pattern(self, graph: Graph, match: dict):
    """Normalize a binarizing Quantize that feeds a multiplication-transparent operator.

    Applies to a Quantize with ``levels == 2`` whose output range has one of
    these forms, where S is a single scalar value:

    * ``0/+S`` (or ``+S/0``);
    * ``-S/+S``.

    The pass then:

    * divides both Quantize output_low (port 3) and output_high (port 4) by S;
    * re-infers the Quantize;
    * multiplies the operator's output by S.

    Presumably this keeps the overall result unchanged while the quantized data
    becomes ``0/+1`` or ``-1/+1``. That assumes the operator is linear in its
    input on that port, which ``multiplication_transparent_ports`` is expected
    to express -- TODO confirm.

    Fusion of the operator with linear ops is disabled afterwards so the
    inserted Mul nodes are not folded away again.

    The graph is left untouched, sometimes with a debug log, when any of these
    preconditions fails:

    * the quantized data is consumed on exactly one operator port;
    * that port is listed in ``multiplication_transparent_ports``;
    * both output_low and output_high are constant;
    * ``levels == 2``;
    * the range has one of the forms above;
    * the quantized data has a single consumer.

    :param graph: graph being transformed.
    :param match: pattern match with keys ``'operator'``, ``'quantize'`` and
        ``'quantized'`` (the Quantize output data node consumed by the operator).
    """
    operator = match['operator']
    quantize = match['quantize']
    quantized = match['quantized']

    assert operator.has('multiplication_transparent_ports')

    ports = operator.input_ports_with(quantized)
    assert len(ports) >= 1
    if len(ports) > 1:
        log.debug(
            'BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
            ' than once'.format(quantized.name))
        return
    port = ports[0]

    if not any(pair[0] == port for pair in operator.multiplication_transparent_ports):
        return

    # Ports 3 and 4 of Quantize carry output_low / output_high -- they have constants that
    # should be passed through. BOTH must be constant: the arithmetic below applies numpy
    # operations to their values, and a missing value (None) would raise a TypeError.
    output_low_node = quantize.in_node(3)
    output_high_node = quantize.in_node(4)
    if not output_low_node.has_valid('value') or not output_high_node.has_valid('value'):
        return

    output_low = output_low_node.value
    output_high = output_high_node.value

    # This pass is applicable for binarization only. Other intX variants are not relevant.
    if quantize.levels != 2:
        return

    # Recognize two cases: 0/+1 and -1/+1 (exactly one of them must hold).
    zp1 = np.all(output_low == 0) or np.all(output_high == 0)
    m1p1 = np.all(-output_low == output_high)
    if (not zp1 and not m1p1) or (zp1 and m1p1):
        log.debug(
            'BinarizeWeightsM1P1 cannot apply transformation for data {} because it does\'t has one of'
            ' 0/+1 or -1/+1 forms.'.format(quantized.name))
        return

    # Recognize scalar
    if len(np.unique(output_low)) != 1 or len(np.unique(output_high)) != 1:
        log.debug(
            'BinarizeWeightsM1P1 cannot apply transformation for data {} because output_low or output_high '
            'cannot be interpreted as scalars.'.format(quantized.name))
        return

    # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
    #       it may have incompatible shape.
    # The multiplier S is whichever bound is non-zero (output_high unless it is all zeros).
    mult_term = output_low_node if np.all(output_high == 0) else output_high_node

    # Patch inflow path (by dividing by mult_term)
    # Put a new Power/Mul combination here:
    #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator
    if len(quantized.out_nodes()) > 1:
        log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
        return

    # 1 / S, shared by both inserted Mul nodes.
    div_op = Power(graph, {'name': quantize.name + '/DivNormalize', 'power': -1.0})
    div_output = div_op.create_node_with_data([mult_term])

    for i in [3, 4]:
        quantize.insert_node_with_data_before(
            quantize.in_node(i),
            Mul,
            dict(name=quantize.name + '/MulNormalize'),
            additional_inputs=[div_output],
        )

    quantized.value = None  # reset value because it will be recomputed
    quantize.infer(quantize)

    # Put a complementary new Mul node here:   operator -->---(here)-----> operator.out_node()
    operator.insert_node_with_data_after(
        operator.out_node(),
        Mul,
        dict(name=operator.name + '/MulNormalize'),
        [mult_term],
    )

    # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
    operator['can_be_fused'] = False