def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """Fuse an Add with a constant operand into the following FakeQuantize.

        The constant addend is folded into the FakeQuantize input_low/input_high
        thresholds (ports 1 and 2), and the Add node is bypassed by rewiring the
        FakeQuantize data input directly to the Add's tensor source.
        """
        quantize = match['quantize']
        preop = match['preop']

        # Do not fuse when the Add consumes a weighted layer's output: such an
        # Add is a bias and is handled by the dedicated bias-fusing passes.
        for i in [0, 1]:
            if preop.in_port(i).get_source().node.soft_get('type') in [
                    'Convolution', 'Deconvolution', 'MatMul'
            ]:
                return

        tensor_port, value_port = get_tensor_in_port(preop), get_value_in_port(
            preop)

        # Fusing requires a statically known addend.
        if value_port is None or value_port.data.get_value() is None:
            log.debug(
                'AddQuantizeFuse: cannot fuse because Add op has dynamic inputs'
            )
            return

        # Direct modifications to quantize 1-st and 2-nd port inputs are performed.
        # So the data nodes at those inputs shouldn't have more than 1 consumer maximum 2 consumers to the same
        # quantize op (consumed by 1st and 2nd ports). So we duplicate FakeQuantize in_port 1, 2, 3, 4 data
        resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[1, 2])

        # Shift both thresholds by the addend: FQ(x + c) == FQ'(x) with
        # input_low' = input_low - c and input_high' = input_high - c.
        quantize.in_port(1).data.set_value(
            quantize.in_port(1).data.get_value() - value_port.data.get_value())
        # Ports 1 and 2 may share one data node; avoid shifting it twice.
        if quantize.in_node(1).id != quantize.in_node(2).id:
            quantize.in_port(2).data.set_value(
                quantize.in_port(2).data.get_value() -
                value_port.data.get_value())

        # BUGFIX: reconnect through the Add's *tensor* input. The previous code
        # unconditionally took the Add's in_port(0) connection, which is the
        # constant input whenever the value operand happens to be on port 0,
        # wiring FakeQuantize to the constant instead of the data tensor.
        in_add_connection = tensor_port.get_connection()
        quantize.in_port(0).disconnect()
        in_add_connection.add_destination(quantize.in_port(0))
# NOTE(review): web-scrape artifact removed here ("Ejemplo n.º 2" heading and a
# stray vote count) — the method below belongs to a second, separate
# pattern-replacement class (ReLU + FakeQuantize fusion).
    def replace_pattern(self, graph: Graph, match: dict):
        """Fuse a preceding ReLU into the matched FakeQuantize node.

        For binary quantization (levels == 2) the input threshold is adjusted
        so the ReLU becomes redundant; the ReLU is then bypassed by rewiring
        the FakeQuantize data input to the node that fed the ReLU.
        """
        fake_quantize = match['quantize']

        if fake_quantize.levels == 2:
            # Binary case: ports 1 and 2 carry a single threshold separating
            # the two quants rather than independent input_low/input_high.
            threshold = fake_quantize.in_port(1).data.get_value()

            # The threshold array is mutated in place below, so the data node
            # on port 1 must not be shared with any other consumer (at most the
            # same FakeQuantize via ports 1 and 2); duplicate it if needed.
            resolve_shared_inputs(node=fake_quantize, port_ids_to_duplicate=[1])

            # Only the sign of the threshold matters after fusing the ReLU:
            #   threshold >= 0 — unchanged;
            #   threshold <  0 — replaced by -inf, i.e. every input maps to
            #                    output_high.
            negative = threshold < 0
            threshold[negative] = float('-inf')

        # The ReLU is no longer needed in front of this FakeQuantize: hook the
        # quantize input straight to whatever fed the ReLU.
        upstream = fake_quantize.in_port(0).get_source().node.in_port(
            0).get_connection()
        fake_quantize.in_port(0).disconnect()
        upstream.add_destination(fake_quantize.in_port(0))