예제 #1
0
    Q15ScaledQuantization
from expressions.symbolic.symbol import SymbolStats
from graph.types import ExpressionFusionParameters
from quantization.new_qrec import QRec
from quantization.qtype import QType
from quantization.qtype_constraint import MatchAll
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       out_qs_constraint,
                                                       params_type)

from ..mult_quantization_handler import MultQuantizionHandler

LOG = logging.getLogger('nntool.' + __name__)

@params_type(ExpressionFusionParameters)
@in_qs_constraint(MatchAll({'dtype': np.int8}))
@out_qs_constraint(MatchAll({'dtype': np.int8}))
class ExpressionFusionMult(MultQuantizionHandler):
    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Check expression range statistics and force the inputs symmetric.

        NOTE(review): this docstring only covers the steps documented below;
        extend it with the return contract once the rest of the method has
        been confirmed.

        Args:
            params: node parameters; ``params.name`` is used in messages.
            in_qs: input quantization types, passed to ``force_symmetric``.
            stats: statistics mapping; must contain an ``'expression'`` entry.
            **kwargs: forwarded to ``cls.get_mult_opts``.

        Raises:
            ValueError: if ``stats`` is None or has no ``'expression'`` key.
        """
        force_out_qs, _ = cls.get_mult_opts(**kwargs)

        if stats is None or 'expression' not in stats:
            raise ValueError(
                f'no valid range information is present for {params.name}')

        # expressions need a symmetric input
        in_qs = cls.force_symmetric(in_qs)

        if in_qs is None:
            # The original message was a plain string containing
            # '{params.name}', so the node name was never interpolated.
            # Lazy %-style arguments let logging format it on demand.
            LOG.info('expression quantizer for %s was not able to force input symmetric',
                     params.name)
예제 #2
0
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import numpy as np
from bfloat16 import bfloat16
from graph.types import OutputParameters
from quantization.float.float_quantization_handler import \
    FloatQuantizionHandler
from quantization.new_qrec import QRec
from quantization.qtype import QType
from quantization.qtype_constraint import MatchAll
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       out_qs_constraint,
                                                       params_type)


@params_type(OutputParameters)
@in_qs_constraint(MatchAll({'dtype': {np.float32, np.float16, bfloat16}}))
@out_qs_constraint(MatchAll({'dtype': {np.float32, np.float16, bfloat16}}))
class FloatOutput(FloatQuantizionHandler):
    """Float handler declared for ``OutputParameters``.

    Both the input and output constraints accept float32, float16 and
    bfloat16 dtypes.
    """

    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Build a float quantization record for the selected dtype.

        The dtype and any forced output qtypes come from
        ``cls.get_float_opts(**kwargs)``.

        Returns:
            None when a forced output qtype (ignoring ``None`` entries) has a
            dtype different from the selected dtype; otherwise the result of
            ``QRec.float`` with a single input and a single output ``QType``
            of that dtype.
        """
        force_out_qs, dtype = cls.get_float_opts(**kwargs)

        # A forced output type that disagrees with the requested float dtype
        # cannot be honoured by this handler.
        if force_out_qs:
            for forced in force_out_qs:
                if forced is not None and forced.dtype != dtype:
                    return None

        input_qtypes = [QType(dtype=dtype)]
        output_qtypes = [QType(dtype=dtype)]
        return QRec.float(in_qs=input_qtypes,
                          out_qs=output_qtypes,
                          float_dtype=dtype)
예제 #3
0
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import logging

import numpy as np
from graph.types import SplitParameters
from quantization.qtype_constraint import MatchAll
from quantization.quantizers.split_mixin import SplitMixin
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       needs_stats,
                                                       out_qs_constraint,
                                                       params_type)

from ..mult_quantization_handler import MultQuantizionHandler

LOG = logging.getLogger('nntool.' + __name__)


@params_type(SplitParameters)
@in_qs_constraint({'dtype': {np.int8, np.uint8}})
@out_qs_constraint(MatchAll({'dtype': {np.int8, np.uint8}}))
@needs_stats(False)
class SplitMult(MultQuantizionHandler, SplitMixin):
    """Handler declared for ``SplitParameters`` with int8/uint8 constraints.

    The class is marked ``needs_stats(False)`` and adds no logic of its own;
    ``_quantize`` hands everything to ``cls._handle``, which is not defined
    in this class (presumably supplied by ``SplitMixin`` - confirm).
    """

    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Delegate to ``cls._handle`` with ``'scaled'`` as the fourth argument."""
        kind = 'scaled'
        return cls._handle(params, in_qs, stats, kind, **kwargs)
예제 #4
0
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import numpy as np
from quantization.qtype_constraint import MatchAll
from quantization.quantizers.no_change_mixin import NoChangeMixin
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       needs_stats,
                                                       out_qs_constraint,
                                                       params_type)

from ..pow2_quantization_handler import Pow2QuantizionHandler


@params_type('__default__')
@in_qs_constraint(MatchAll({'dtype': {np.int8, np.int16}}))
@out_qs_constraint(MatchAll({'dtype': {np.int8, np.int16}}))
@needs_stats(False)
class NoChangePow2(Pow2QuantizionHandler, NoChangeMixin):
    """Handler declared for the ``'__default__'`` params type.

    Inputs and outputs are constrained to int8/int16 dtypes and the class is
    marked ``needs_stats(False)``. ``_quantize`` hands everything to
    ``cls._handle``, which is not defined in this class (presumably supplied
    by ``NoChangeMixin`` - confirm).
    """

    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Delegate to ``cls._handle`` with ``'symmetric'`` as the fourth argument."""
        kind = 'symmetric'
        return cls._handle(params, in_qs, stats, kind, **kwargs)
예제 #5
0
from expressions.symbolic.symbol import SymbolStats
from graph.types import ExpressionFusionParameters
from quantization.new_qrec import QRec
from quantization.qtype import QType
from quantization.qtype_constraint import MatchAll
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       out_qs_constraint,
                                                       params_type)

from ..mult_quantization_handler import MultQuantizionHandler

LOG = logging.getLogger('nntool.' + __name__)


@params_type(ExpressionFusionParameters)
@in_qs_constraint(MatchAll({'dtype': {np.int8, np.uint8, np.int16,
                                      np.uint16}}))
@out_qs_constraint(
    MatchAll({'dtype': {np.int8, np.uint8, np.int16, np.uint16}}))
class ExpressionFusionMult(MultQuantizionHandler):
    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        force_out_qs, _ = cls.get_mult_opts(**kwargs)

        if stats is None or 'expression' not in stats:
            raise ValueError(
                f'no valid range information is present for {params.name}')

        # # expressions need a symmetric input
        # in_qs = cls.force_symmetric(in_qs)

        # if in_qs is None: