# Module-level imports assumed by both methods below (module paths follow the
# OpenVINO Model Optimizer layout; adjust them to the local package structure):
import logging as log
from typing import Dict

from extensions.middle.MulFakeQuantizeFuse import resolve_shared_inputs
from mo.graph.graph import Graph, Node
from mo.middle.passes.fusing.helpers import get_tensor_in_port, get_value_in_port


def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    quantize = match['quantize']
    preop = match['preop']

    # Skip fusing if the Add is fed by a Convolution/Deconvolution/MatMul:
    # there the Add acts as a bias and is handled by other fusing passes.
    for i in [0, 1]:
        if preop.in_port(i).get_source().node.soft_get('type') in ['Convolution', 'Deconvolution', 'MatMul']:
            return

    tensor_port, value_port = get_tensor_in_port(preop), get_value_in_port(preop)
    if value_port is None or value_port.data.get_value() is None:
        log.debug('AddQuantizeFuse: cannot fuse because Add op has dynamic inputs')
        return

    # The values feeding quantize input ports 1 and 2 (input_low / input_high)
    # are modified in place, so the data nodes at those ports must have no
    # consumers other than this quantize op (the same data node may feed both
    # ports 1 and 2). Duplicate the FakeQuantize in_port 1 and 2 data if shared.
    resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[1, 2])

    # Shift the quantization range by the added constant c:
    # FakeQuantize(x + c, in_low, in_high) == FakeQuantize(x, in_low - c, in_high - c).
    quantize.in_port(1).data.set_value(quantize.in_port(1).data.get_value() - value_port.data.get_value())
    if quantize.in_node(1).id != quantize.in_node(2).id:
        quantize.in_port(2).data.set_value(quantize.in_port(2).data.get_value() - value_port.data.get_value())

    # Bypass the Add: reconnect FakeQuantize input 0 directly to the Add's input.
    in_add_connection = quantize.in_port(0).get_source().node.in_port(0).get_connection()
    quantize.in_port(0).disconnect()
    in_add_connection.add_destination(quantize.in_port(0))
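# ---------------------------------------------------------------------------
# Illustration, not part of the transform: a minimal numpy sketch of the range
# shift identity the Add fusion above relies on. `reference_fake_quantize` is
# a hypothetical reference implementation written only for this check; it is
# not the Model Optimizer API.
import numpy as np


def reference_fake_quantize(x, in_low, in_high, out_low, out_high, levels):
    # Clamp to the input range, snap onto a grid of `levels` points,
    # then rescale the grid into the output range.
    x = np.clip(x, in_low, in_high)
    quants = np.round((x - in_low) / (in_high - in_low) * (levels - 1))
    return quants / (levels - 1) * (out_high - out_low) + out_low


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    x = rng.normal(size=1000)
    c = 0.3  # the constant contributed by the Add node
    lo, hi = -1.0, 1.0
    # FakeQuantize(x + c, lo, hi) == FakeQuantize(x, lo - c, hi - c)
    before = reference_fake_quantize(x + c, lo, hi, -1.0, 1.0, 256)
    after = reference_fake_quantize(x, lo - c, hi - c, -1.0, 1.0, 256)
    assert np.allclose(before, after)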
def replace_pattern(self, graph: Graph, match: dict):
    quantize = match['quantize']

    if quantize.levels == 2:
        # Extra logic for the binary case: here the port 1 (and 2) input is
        # not a range but a threshold separating the two quants.
        threshold = quantize.in_port(1).data.get_value()

        # The threshold value is modified in place, so the data node at input
        # port 1 must have no consumers other than this quantize op (the same
        # data node may feed both ports 1 and 2). Duplicate the FakeQuantize
        # in_port 1 data if needed.
        resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[1])

        # Restricted to the binarization case, detect on which side of 0 the
        # FakeQuantize threshold resides:
        #   threshold > 0  -- it remains the same;
        #   threshold == 0 -- it also remains the same;
        #   threshold < 0  -- replace it with -inf, meaning all inputs map to output_high.
        modification_mask = threshold < 0
        threshold[modification_mask] = float('-inf')

    # Reconnect: the ReLU is no longer needed in front of this FakeQuantize.
    in_relu_connection = quantize.in_port(0).get_source().node.in_port(0).get_connection()
    quantize.in_port(0).disconnect()
    in_relu_connection.add_destination(quantize.in_port(0))
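# ---------------------------------------------------------------------------
# Illustration, not part of the transform: a minimal numpy sketch of why a
# negative threshold turns into -inf once the ReLU is fused away. With the
# ReLU in place the binary quantize sees relu(x), so, assuming output_high is
# produced for inputs strictly above the threshold t:
#   t >= 0: relu(x) > t  <=>  x > t        (threshold unchanged)
#   t <  0: relu(x) > t  is always true    (threshold becomes -inf)
# `binarize` is a hypothetical 2-level FakeQuantize reference written only for
# this check; it is not the Model Optimizer API.
import numpy as np


def binarize(x, threshold, out_low=-1.0, out_high=1.0):
    # 2-level FakeQuantize: the threshold splits inputs between the two quants.
    return np.where(x > threshold, out_high, out_low)


if __name__ == '__main__':
    rng = np.random.default_rng(1)
    x = rng.normal(size=1000)
    for t in (0.5, 0.0, -0.5):
        with_relu = binarize(np.maximum(x, 0.0), t)
        fused_threshold = t if t >= 0 else float('-inf')
        without_relu = binarize(x, fused_threshold)
        assert np.array_equal(with_relu, without_relu), t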