def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    # Only applicable to NHWC FusedBatchNorm with non-constant input/mean/variance
    # and constant 1-D scale/offset tensors.
    if (node.data_format != b'NHWC' or
            len(node.in_nodes()) != 5 or
            node.in_node(0).value is not None or  # input
            node.in_node(1).value is None or      # scale
            node.in_node(2).value is None or      # offset
            node.in_node(3).value is not None or  # mean
            node.in_node(4).value is not None or  # variance
            node.in_node(1).value.ndim != 1 or
            node.in_node(2).value.ndim != 1):
        return

    # Decompose into elementary ops:
    #   out = (x - mean) * (variance + eps) ** -0.5 * scale + offset
    scale_mul = Mul(graph, dict(name=node.name + '/scale_mul_'))
    shift_add = Add(graph, dict(name=node.name + '/shift_add_'))
    mean_add = Add(graph, dict(name=node.name + '/mean_add_'))
    variance_mul = Mul(graph, dict(name=node.name + '/variance_mul_'))

    # x - mean
    neg_const = Const(graph, dict(value=np.array(-1), name=node.name + '/mean_negate_'))
    mean_negate = Mul(graph, dict(name=node.name + '/mean_negate_'))
    mean_arg = mean_add.create_node_with_data([
        node.in_node(0),
        mean_negate.create_node_with_data([node.in_node(3), neg_const.create_node_with_data()])
    ])

    # (variance + eps) ** -0.5
    shift_const = Const(graph, dict(value=node.eps, name=node.name + '/variance_denom_shift_const_'))
    power_const = Const(graph, dict(value=-0.5, name=node.name + '/variance_denom_power_const_'))
    variance_denom_shift = Add(graph, dict(name=node.name + '/variance_denom_shift_'))
    variance_denom_power = Pow(graph, dict(name=node.name + '/variance_denom_power_'))
    variance_arg = variance_mul.create_node_with_data([
        mean_arg,
        variance_denom_power.create_node_with_data([
            variance_denom_shift.create_node_with_data([node.in_node(4), shift_const.create_node_with_data()]),
            power_const.create_node_with_data()
        ])
    ])

    # ... * scale + offset, writing the result into the original output data node
    shift_add.create_node_with_data([
        scale_mul.create_node_with_data([variance_arg, node.in_node(1)]),
        node.in_node(2)
    ], data_nodes=node.out_node())

    node.graph.remove_node(node.id)
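

# Illustrative reference only (assumption: not part of the original pass). The sub-graph
# built by replace_pattern above computes the standard batch-norm formula with
# non-constant input/mean/variance. This is a minimal NumPy sketch of the same
# arithmetic, handy for sanity-checking the decomposition; the helper name and its
# signature are hypothetical, and it relies on this module's existing `numpy as np` import.
def _fused_batch_norm_reference_sketch(x, scale, offset, mean, variance, eps):
    # mean_add / mean_negate:          x - mean
    # variance_denom_shift / _power:   (variance + eps) ** -0.5
    # scale_mul / shift_add:           ... * scale + offset (broadcast over the channel axis)
    return (x - mean) * np.power(variance + eps, -0.5) * scale + offset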
def replace_pattern(self, graph: Graph, match: dict):
    assert match['operator'].has('multiplication_transparent_ports')

    port = match['operator'].input_ports_with(match['quantized'])
    assert len(port) >= 1
    if len(port) > 1:
        log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it is consumed more'
                  ' than once'.format(match['quantized'].name))
        return

    assert len(port) == 1
    port = port[0]
    applicable = [pair for pair in match['operator'].multiplication_transparent_ports if pair[0] == port]
    if len(applicable) == 0:
        return

    # Look at the 3rd and 4th inputs of FakeQuantize -- they hold constants that should be passed through.
    # Assume that the constant that should be passed through is a scalar.
    quantize = match['quantize']
    output_low = quantize.in_node(3)
    output_high = quantize.in_node(4)

    quantize_name = quantize.soft_get('name', quantize.id)

    if not output_low.has_valid('value') and not output_high.has_valid('value'):
        return

    output_low = output_low.value
    output_high = output_high.value

    # This pass is applicable for binarization only. Other intX variants are not relevant.
    if quantize.levels != 2:
        return

    # Recognize two cases: 0/+1 and -1/+1.
    zp1 = np.all(output_low == 0) or np.all(output_high == 0)
    m1p1 = np.all(-output_low == output_high)
    if (not zp1 and not m1p1) or (zp1 and m1p1):
        log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it does not have one of'
                  ' the 0/+1 or -1/+1 forms.'.format(match['quantized'].name))
        return

    # TODO: Extract a real scalar from the 3rd and 4th inputs; reusing the original tensors is dangerous
    #       because they may have an incompatible shape.

    mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)

    new_shape = Const(graph, {'name': quantize_name + '/Reshape/Shape',
                              'value': int64_array([-1, 1, 1])}).create_node_with_data()
    reshape = Reshape(graph, {'name': quantize_name + '/Reshape'}).create_node_with_data([mult_term, new_shape])

    # Patch the inflow path (by dividing by mult_term).
    # Put a new Pow/Mul combination here:
    #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator
    if len(match['quantized'].out_nodes()) > 1:
        log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
        return

    power_of_exponent = Const(graph, {'name': quantize_name + '/DivNormalize/Power',
                                      'value': mo_array(-1.0)}).create_node_with_data()
    div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
    div_output = div_op.create_node_with_data([mult_term, power_of_exponent])

    for i in [3, 4]:
        match['quantize'].insert_node_with_data_before(
            match['quantize'].in_node(i),
            Mul,
            dict(name=quantize_name + '/MulNormalize'),
            additional_inputs=[div_output],
        )

    match['quantized'].value = None  # reset the value because it will be recomputed
    match['quantize'].infer(match['quantize'])

    # Put a complementary new Mul node here:   operator -->---(here)-----> operator.out_node()
    match['operator'].insert_node_with_data_after(
        match['operator'].out_node(),
        Mul,
        dict(name=match['operator'].name + '/MulNormalize'),
        [reshape],
    )

    # Disable fusion of 'operator' with linear ops, otherwise it would annihilate the changes we just made.
    match['operator']['can_be_fused'] = False
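

# Illustrative reference only (assumption: not part of the original pass). The rewrite above
# relies on the operator being linear in its quantized input, so a scalar factor can be moved
# from the FakeQuantize output range to the operator output:
#     op(x, w * s) == op(x, w) * s
# Below is a minimal NumPy sketch of the two pieces: normalizing output_low/output_high to
# the -1/+1 (or 0/+1) range, and restoring the scale after the operator. The helper names
# are hypothetical, and the sketch relies on this module's existing `numpy as np` import.
def _normalize_quant_range_sketch(output_low, output_high, mult_term_value):
    # mirrors the Pow(-1)/Mul pair inserted before ports 3 and 4 of FakeQuantize
    return output_low / mult_term_value, output_high / mult_term_value


def _restore_scale_sketch(operator_output, mult_term_value):
    # mirrors the Reshape to [-1, 1, 1] plus the complementary Mul after the operator:
    # a per-output-channel scalar broadcast over the spatial dims of an NCHW tensor
    return operator_output * np.reshape(mult_term_value, (-1, 1, 1))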