Q15ScaledQuantization from expressions.symbolic.symbol import SymbolStats from graph.types import ExpressionFusionParameters from quantization.new_qrec import QRec from quantization.qtype import QType from quantization.qtype_constraint import MatchAll from quantization.unified_quantization_handler import (in_qs_constraint, out_qs_constraint, params_type) from ..mult_quantization_handler import MultQuantizionHandler LOG = logging.getLogger('nntool.' + __name__) @params_type(ExpressionFusionParameters) @in_qs_constraint(MatchAll({'dtype': np.int8})) @out_qs_constraint(MatchAll({'dtype': np.int8})) class ExpressionFusionMult(MultQuantizionHandler): @classmethod def _quantize(cls, params, in_qs, stats, **kwargs): force_out_qs, _ = cls.get_mult_opts(**kwargs) if stats is None or 'expression' not in stats: raise ValueError( f'no valid range information is present for {params.name}') # expressions need a symmetric input in_qs = cls.force_symmetric(in_qs) if in_qs is None: LOG.info('expression quantizer for {params.name} was not able to force input symmetric')
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import numpy as np
from bfloat16 import bfloat16
from graph.types import OutputParameters
from quantization.float.float_quantization_handler import \
    FloatQuantizionHandler
from quantization.new_qrec import QRec
from quantization.qtype import QType
from quantization.qtype_constraint import MatchAll
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       out_qs_constraint,
                                                       params_type)


@params_type(OutputParameters)
@in_qs_constraint(MatchAll({'dtype': {np.float32, np.float16, bfloat16}}))
@out_qs_constraint(MatchAll({'dtype': {np.float32, np.float16, bfloat16}}))
class FloatOutput(FloatQuantizionHandler):
    """Float quantization handler for ``OutputParameters`` nodes.

    The constraint decorators restrict this handler (via ``MatchAll``) to
    float32 / float16 / bfloat16 dtypes on its inputs and outputs.
    """

    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Build the float quantization record for an output node.

        Args:
            params: node parameters (not read here).
            in_qs: incoming qtypes (not read here; the record's input qtype
                is rebuilt from the selected float dtype).
            stats: statistics (not read here).
            **kwargs: options forwarded unchanged to ``get_float_opts``.

        Returns:
            ``None`` when any non-None forced output qtype has a dtype other
            than the selected float dtype; otherwise a ``QRec.float`` whose
            single input qtype and single output qtype both use that dtype.
        """
        forced_qtypes, float_dtype = cls.get_float_opts(**kwargs)
        # Decline when a forced output qtype requests a different dtype;
        # None entries in the forced list impose no requirement.
        has_conflict = forced_qtypes and any(
            qtype is not None and qtype.dtype != float_dtype
            for qtype in forced_qtypes)
        if has_conflict:
            return None
        in_qtype = QType(dtype=float_dtype)
        out_qtype = QType(dtype=float_dtype)
        return QRec.float(in_qs=[in_qtype], out_qs=[out_qtype],
                          float_dtype=float_dtype)
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import logging

import numpy as np
from graph.types import SplitParameters
from quantization.qtype_constraint import MatchAll
from quantization.quantizers.split_mixin import SplitMixin
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       needs_stats,
                                                       out_qs_constraint,
                                                       params_type)

from ..mult_quantization_handler import MultQuantizionHandler

# Module logger (not referenced by SplitMult in this module).
LOG = logging.getLogger('nntool.' + __name__)


@params_type(SplitParameters)
@in_qs_constraint({'dtype': {np.int8, np.uint8}})
@out_qs_constraint(MatchAll({'dtype': {np.int8, np.uint8}}))
@needs_stats(False)
class SplitMult(MultQuantizionHandler, SplitMixin):
    """Split-node handler for the ``MultQuantizionHandler`` quantizer family.

    Accepts int8/uint8 on the input and (via ``MatchAll``) on every output,
    and declares that statistics are not required (``needs_stats(False)``).
    """

    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Delegate quantization of a split node to ``cls._handle``.

        The mode string ``'scaled'`` is passed positionally after ``stats``;
        its meaning is defined by ``_handle`` (resolved through the MRO,
        presumably from ``SplitMixin`` -- confirm there).

        Returns:
            Whatever ``cls._handle`` returns.
        """
        mode = 'scaled'
        return cls._handle(params, in_qs, stats, mode, **kwargs)
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import numpy as np
from quantization.qtype_constraint import MatchAll
from quantization.quantizers.no_change_mixin import NoChangeMixin
from quantization.unified_quantization_handler import (in_qs_constraint,
                                                       needs_stats,
                                                       out_qs_constraint,
                                                       params_type)

from ..pow2_quantization_handler import Pow2QuantizionHandler


@params_type('__default__')
@in_qs_constraint(MatchAll({'dtype': {np.int8, np.int16}}))
@out_qs_constraint(MatchAll({'dtype': {np.int8, np.int16}}))
@needs_stats(False)
class NoChangePow2(Pow2QuantizionHandler, NoChangeMixin):
    """POW2-family handler registered under ``params_type('__default__')``.

    NOTE(review): the ``'__default__'`` key presumably makes this the
    fallback for parameter types without a dedicated POW2 handler -- confirm
    in ``unified_quantization_handler``. Inputs and outputs are constrained
    to int8/int16 (via ``MatchAll``); statistics are not required.
    """

    @classmethod
    def _quantize(cls, params, in_qs, stats, **kwargs):
        """Delegate quantization to ``cls._handle`` in ``'symmetric'`` mode.

        The mode string is passed positionally after ``stats``; its meaning
        is defined by ``_handle`` (resolved through the MRO, presumably from
        ``NoChangeMixin`` -- confirm there).

        Returns:
            Whatever ``cls._handle`` returns.
        """
        mode = 'symmetric'
        return cls._handle(params, in_qs, stats, mode, **kwargs)
from expressions.symbolic.symbol import SymbolStats from graph.types import ExpressionFusionParameters from quantization.new_qrec import QRec from quantization.qtype import QType from quantization.qtype_constraint import MatchAll from quantization.unified_quantization_handler import (in_qs_constraint, out_qs_constraint, params_type) from ..mult_quantization_handler import MultQuantizionHandler LOG = logging.getLogger('nntool.' + __name__) @params_type(ExpressionFusionParameters) @in_qs_constraint(MatchAll({'dtype': {np.int8, np.uint8, np.int16, np.uint16}})) @out_qs_constraint( MatchAll({'dtype': {np.int8, np.uint8, np.int16, np.uint16}})) class ExpressionFusionMult(MultQuantizionHandler): @classmethod def _quantize(cls, params, in_qs, stats, **kwargs): force_out_qs, _ = cls.get_mult_opts(**kwargs) if stats is None or 'expression' not in stats: raise ValueError( f'no valid range information is present for {params.name}') # # expressions need a symmetric input # in_qs = cls.force_symmetric(in_qs) # if in_qs is None: