def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Expand a node carrying ``scale``/``shift``/``power`` attributes into element-wise ops.

    The node's input is routed through Mul(scale), then Add(shift), then Pow(power).
    A stage is emitted only when its attribute differs from the neutral value
    (1 for scale and power, 0 for shift; a missing attribute counts as neutral).
    The original output connection is then re-sourced from the last emitted stage,
    or from the original input when no stage was needed.

    :param graph: graph being transformed
    :param match: matched sub-graph; ``match['op']`` is the node to expand
    """
    node = match['op']
    tail = node.in_port(0).get_source()

    # (attribute, neutral value, element-wise op class, name suffix) -- applied in this order
    stages = (
        ('scale', 1, Mul, '/mul_'),
        ('shift', 0, Add, '/add_'),
        ('power', 1, Pow, '/pow_'),
    )
    for attr_name, neutral, op_class, suffix in stages:
        if node.soft_get(attr_name, neutral) != neutral:
            operand = Const(graph, {'value': np.array(getattr(node, attr_name))}).create_node()
            stage = op_class(graph, {'name': node.name + suffix}).create_node()
            # the attribute constant always feeds port 1, the running value feeds port 0
            operand.out_port(0).connect(stage.in_port(1))
            tail.connect(stage.in_port(0))
            tail = stage.out_port(0)

    node.out_port(0).get_connection().set_source(tail)
def replace_op(self, graph: Graph, node: Node):
    """
    Replace ``node`` by sqrt(ReduceSum(x ** 2, axis=-1)) built from Pow/ReduceSum/Pow.

    ``x`` is the node's first input. New node names are derived from the input
    node's name ('/sq', '/sum', '/sqrt').

    :return: single-element list with the id of the final Pow(0.5) node
    """
    exponent_two = Const(graph, {'value': np.float32(2.0)}).create_node()
    last_axis = Const(graph, {'value': np.int32(-1)}).create_node()
    exponent_half = Const(graph, {'value': np.float32(0.5)}).create_node()

    data = node.in_node(0)
    squared = Pow(graph, dict(name=data.name + '/sq', power=2.0)).create_node([data, exponent_two])
    summed = ReduceSum(graph, dict(name=squared.name + '/sum')).create_node([squared, last_axis])
    root = Pow(graph, dict(name=summed.name + '/sqrt', power=0.5)).create_node([summed, exponent_half])
    return [root.id]
def replace_pattern(self, graph: Graph, match: dict):
    """
    Decompose an NHWC batch-normalization node whose mean/variance are not constant.

    Inputs (per the guard below): 0 = x, 1 = scale, 2 = offset, 3 = mean, 4 = variance.
    The node is rewritten as the element-wise sub-graph

        (x + mean * -1) * (variance + eps) ** -0.5 * scale + offset

    whose result is written into the node's original output data node; the node
    itself is then removed. Nothing is changed unless the data format is NHWC,
    there are exactly 5 inputs, x/mean/variance have no constant value and
    scale/offset are constant 1-D tensors.

    :param graph: graph being transformed
    :param match: matched pattern; ``match['op']`` is the batch-norm node
    """
    node = match['op']
    if (node.data_format != b'NHWC' or
            len(node.in_nodes()) != 5 or
            node.in_node(0).value is not None or  # input
            node.in_node(1).value is None or  # scale
            node.in_node(2).value is None or  # offset
            node.in_node(3).value is not None or  # mean
            node.in_node(4).value is not None or  # variance
            node.in_node(1).value.ndim != 1 or
            node.in_node(2).value.ndim != 1):
        return

    scale_mul = Mul(graph, dict(name=node.name + '/scale_mul_'))
    shift_add = Add(graph, dict(name=node.name + '/shift_add_'))
    mean_add = Add(graph, dict(name=node.name + '/mean_add_'))
    variance_mul = Mul(graph, dict(name=node.name + '/variance_mul_'))

    # The -1 constant gets its own '_const_' name (like the other constants in this pass)
    # so it no longer shares the name '/mean_negate_' with the Mul node consuming it.
    neg_const = Const(graph, dict(value=np.array(-1), name=node.name + '/mean_negate_const_'))
    mean_negate = Mul(graph, dict(name=node.name + '/mean_negate_'))

    # x - mean, computed as x + (mean * -1)
    negated_mean = mean_negate.create_node_with_data(
        [node.in_node(3), neg_const.create_node_with_data()])
    mean_arg = mean_add.create_node_with_data([node.in_node(0), negated_mean])

    shift_const = Const(graph, dict(value=node.eps, name=node.name + '/variance_denom_shift_const_'))
    power_const = Const(graph, dict(value=-0.5, name=node.name + '/variance_denom_power_const_'))
    variance_denom_shift = Add(graph, dict(name=node.name + '/variance_denom_shift_'))
    variance_denom_power = Pow(graph, dict(name=node.name + '/variance_denom_power_'))

    # (variance + eps) ** -0.5
    shifted_variance = variance_denom_shift.create_node_with_data(
        [node.in_node(4), shift_const.create_node_with_data()])
    inv_std = variance_denom_power.create_node_with_data(
        [shifted_variance, power_const.create_node_with_data()])

    # (x - mean) * (variance + eps) ** -0.5
    variance_arg = variance_mul.create_node_with_data([mean_arg, inv_std])

    # ... * scale + offset, written into the original output data node
    scaled = scale_mul.create_node_with_data([variance_arg, node.in_node(1)])
    shift_add.create_node_with_data([scaled, node.in_node(2)], data_nodes=node.out_node())

    node.graph.remove_node(node.id)
def replace_op(self, graph: Graph, node: Node):
    """
    Replace ``node`` by ``Pow(x, -1)`` (presumably a Reciprocal op, judging by the node names).

    The node's input connection is moved to port 0 of a new Pow node whose port 1
    is fed by a constant -1.0.

    :return: single-element list with the id of the new Pow node
    """
    minus_one = Const(
        graph,
        dict(value=mo_array(-1.), name=node.name + '/reciprocal_pow_const_'),
    ).create_node()
    power = Pow(graph, {'name': node.name + '/reciprocal_pow_'}).create_node()

    input_connection = node.in_port(0).get_connection()
    input_connection.set_destination(power.in_port(0))
    minus_one.out_port(0).connect(power.in_port(1))

    return [power.id]
def placeholder_scales(self, placeholder: Node):
    """
    Build a sub-graph producing prior-box scales from the input image size:
        [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]

    The placeholder's runtime shape is sliced with begin=1, end=3, stride=1.
    The slice is cast to float32 and inverted with Pow(-1). It is then reordered
    with Gather([1, 0]) and concatenated with itself. Mapping the two sliced
    dimensions to (H, W) assumes an [N, H, W, C] layout — TODO confirm against callers.

    :param placeholder: 4D placeholder node whose ``shape`` attribute is an ``np.ndarray``
    :return: Concat node producing the 4-element scale vector
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)
    shape_value = placeholder.soft_get('shape', None)

    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    # Runtime shape of the placeholder.
    shape_of = Shape(graph, {'name': 'input_image_shape'}).create_node()
    shape_of.in_port(0).connect(placeholder.out_port(0))

    # Take shape[1:3] (begin=1, end=3, stride=1).
    slice_begin = Const(graph, {'value': int64_array([1])}).create_node()
    slice_end = Const(graph, {'value': int64_array([3])}).create_node()
    slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
    slice_attrs = {
        'name': name + '/get_h_w',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    hw = StridedSlice(graph, slice_attrs).create_node()
    for port_idx, producer in enumerate((shape_of, slice_begin, slice_end, slice_stride)):
        hw.in_port(port_idx).connect(producer.out_port(0))

    # Element-wise 1 / value.
    minus_one = Const(graph, {'value': float32_array([-1.])}).create_node()
    inverted = Pow(graph, {}).create_node()
    inverted.in_port(0).connect(hw.out_port(0))
    inverted.in_port(1).connect(minus_one.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    inverted.in_port(0).get_connection().insert_node(to_fp32)

    # Swap the two elements along axis 0.
    swap_order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
    swapped = Gather(graph, {}).create_node()
    swapped.in_port(0).connect(inverted.out_port(0))
    swapped.in_port(1).connect(swap_order.out_port(0))
    gather_axis.out_port(0).connect(swapped.in_port(2))

    # Duplicate: [a, b] -> [a, b, a, b].
    scales = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    for port_idx in range(2):
        scales.add_input_port(port_idx, skip_if_exist=True)
    for port_idx in range(2):
        scales.in_port(port_idx).connect(swapped.out_port(0))
    return scales
def extract(cls, node: Node):
    """
    Populate ``node`` with the attributes of the Pow operation.

    :return: ``cls.enabled``
    """
    Pow.update_node_stat(node)
    enabled = cls.enabled
    return enabled
def replace_pattern(self, graph: Graph, match: dict):
    """
    Factor the output scale out of a binarizing FakeQuantize that feeds ``match['operator']``.

    The transformation applies only when all of the following hold:
      * ``match['quantized']`` enters the operator through exactly one port, and that
        port is listed in the operator's ``multiplication_transparent_ports``;
      * ``match['quantize']`` has ``levels == 2`` and constant inputs 3 and 4
        (output_low / output_high) of the form 0/+1 or -1/+1;
      * ``match['quantized']`` has a single consumer.

    ``mult_term`` is input 4 (output_high), or input 3 (output_low) when output_high is
    all zeros. Inputs 3 and 4 are multiplied by ``mult_term ** -1``. The operator's
    output is multiplied by ``mult_term`` reshaped to [-1, 1, 1]. The operator is then
    marked ``can_be_fused = False``.

    :param graph: graph being transformed
    :param match: matched nodes 'operator', 'quantize' and 'quantized'
    """
    operator = match['operator']
    quantize = match['quantize']
    quantized = match['quantized']

    assert operator.has('multiplication_transparent_ports')

    port = operator.input_ports_with(quantized)
    assert len(port) >= 1
    if len(port) > 1:
        log.debug(
            'BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
            ' than once'.format(quantized.name))
        return

    port = port[0]
    applicable = [pair for pair in operator.multiplication_transparent_ports if pair[0] == port]
    if not applicable:
        return

    # Look at 3-rd and 4-th inputs of FakeQuantize -- they have constants that should be passed through.
    # Assume that the constant that should be passed through is a scalar.
    output_low = quantize.in_node(3)
    output_high = quantize.in_node(4)
    quantize_name = quantize.soft_get('name', quantize.id)

    # Both bounds must be constant: their values are compared and negated below,
    # so a single missing value would otherwise fail on `-None`.
    if not output_low.has_valid('value') or not output_high.has_valid('value'):
        return

    output_low = output_low.value
    output_high = output_high.value

    # This pass is applicable for binarization only. Other intX variants are not relevant.
    if quantize.levels != 2:
        return

    # Recognize two cases: 0/+1 and -1/+1.
    zp1 = np.all(output_low == 0) or np.all(output_high == 0)
    m1p1 = np.all(-output_low == output_high)
    if (not zp1 and not m1p1) or (zp1 and m1p1):
        log.debug(
            'BinarizeWeightsM1P1 cannot apply transformation for data {} because it does\'t has one of'
            ' 0/+1 or -1/+1 forms.'.format(quantized.name))
        return

    # Checked before any node is created, so bailing out here leaves no dangling
    # Const/Reshape nodes (nor an extra consumer of mult_term) in the graph.
    if len(quantized.out_nodes()) > 1:
        log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
        return

    # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
    #       it may have incompatible shape.
    mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)

    new_shape = Const(graph, {
        'name': quantize_name + '/Reshape/Shape',
        'value': int64_array([-1, 1, 1])
    }).create_node_with_data()
    reshape = Reshape(graph, {'name': quantize_name + '/Reshape'}).create_node_with_data([mult_term, new_shape])

    # Patch inflow path (by dividing by mult_term)
    # Put a new Pow/Mul combination here:
    #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator
    power_of_exponent = Const(graph, {
        'name': quantize_name + '/DivNormalize/Power',
        'value': mo_array(-1.0)
    }).create_node_with_data()
    div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
    div_output = div_op.create_node_with_data([mult_term, power_of_exponent])

    for i in [3, 4]:
        quantize.insert_node_with_data_before(
            quantize.in_node(i),
            Mul,
            dict(name=quantize_name + '/MulNormalize'),
            additional_inputs=[div_output],
        )

    quantized.value = None  # reset value because it will be recomputed
    quantize.infer(quantize)

    # Put a complimentary new Mul node here:   operator -->---(here)-----> operator.out_node()
    operator.insert_node_with_data_after(
        operator.out_node(),
        Mul,
        dict(name=operator.name + '/MulNormalize'),
        [reshape],
    )

    # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
    operator['can_be_fused'] = False
def extract(cls, node):
    """
    Populate ``node`` as a Pow operation and set its ``data_type`` attribute.

    ``data_type`` is the ``T`` attribute of the node's protobuf (``node.pb``),
    converted with ``tf_dtype_extractor``.

    :return: ``cls.enabled``
    """
    raw_type = node.pb.attr["T"].type
    attrs = {'data_type': tf_dtype_extractor(raw_type)}
    Pow.update_node_stat(node, attrs)
    return cls.enabled