def _handle(cls, params, in_qs, _, ktype, **kwargs):
    """Build a QRec whose single output reuses an input qtype.

    Args:
        params: node parameters; only ``params.name`` is read, for log
            messages.
        in_qs: input qtypes. Each must expose a ``forced`` attribute.
        _: unused.
        ktype: passed straight through to ``QRec``.
        **kwargs: ``force_out_qs`` (optional; only its first entry is used)
            and ``backwards`` (optional flag) are read.

    Returns:
        A ``QRec``, or ``None`` when the forced constraints cannot be
        satisfied.
    """
    requested_out_qs = kwargs.get('force_out_qs')
    pinned_out = requested_out_qs[0] if requested_out_qs else None
    pinned_ins = [qtype for qtype in in_qs if qtype.forced]

    # forced inputs have to agree with each other
    if pinned_ins and not QType.forced_equal(*pinned_ins):
        LOG.debug(
            'two input qtypes of %s are forced to different qtypes - rejecting',
            params.name)
        return None

    # a forced output has to be able to impose itself on the inputs
    # NOTE(review): the forced output is passed to its own can_force() and
    # every input (not only the forced ones) is checked - confirm against
    # QType.can_force that this is intended
    if pinned_out and not pinned_out.can_force(pinned_out, *in_qs):
        LOG.debug(
            'output and input of %s are forced to different qtypes - rejecting',
            params.name)
        return None

    if kwargs.get('backwards'):
        if pinned_out:
            # going backwards: push the forced output qtype onto every input
            propagated = [deepcopy(pinned_out) for _ in in_qs]
            return QRec(in_qs=propagated,
                        out_qs=[deepcopy(pinned_out)],
                        ktype=ktype)
    elif pinned_out and any(not (in_q == pinned_out) for in_q in in_qs):
        # going forwards the inputs are fixed, so a forced output that
        # differs from any of them cannot be honoured
        LOG.debug(
            "output of %s is forced and inputs don't match - rejecting",
            params.name)
        return None

    # nothing to impose: the output takes the first input's qtype
    return QRec(in_qs=in_qs, out_qs=[deepcopy(in_qs[0])], ktype=ktype)
def _handle(cls, params, in_qs, stats, ktype, **kwargs):
    """Choose one qtype and apply it to the input and every split output.

    The chosen qtype is, in order of precedence: the first forced input,
    the first forced output, or the first input.

    Args:
        params: node parameters; ``params.name`` (log messages) and
            ``params.num_splits`` (number of outputs) are read.
        in_qs: input qtypes. Each must expose a ``forced`` attribute.
        stats: unused.
        ktype: passed straight through to ``QRec``.
        **kwargs: ``force_out_qs`` (optional) is read; ``None`` entries in
            it are ignored.

    Returns:
        A ``QRec`` with one input qtype and ``params.num_splits`` output
        qtypes, or ``None`` when the forced constraints conflict.
    """
    out_constraints = kwargs.get('force_out_qs')
    if out_constraints:
        # entries are None where an output could not be forced - drop them
        out_constraints = [q for q in out_constraints if q is not None]

    in_constraints = [q for q in in_qs if q.forced]
    pinned_in = in_constraints[0] if in_constraints else None

    # all forced outputs have to agree with each other
    if out_constraints and not QType.forced_equal(*out_constraints):
        LOG.info(
            'two output qtypes of split %s are forced to different qtypes',
            params.name)
        return None

    # a forced input has to be compatible with the forced outputs
    if (pinned_in and out_constraints
            and not pinned_in.can_force(*out_constraints)):
        LOG.error(
            'output and input of split %s are forced to different qtypes',
            params.name)
        return None

    # whatever is forced at this point is compatible with the split
    pinned_out = out_constraints[0] if out_constraints else None
    chosen = pinned_in or pinned_out or in_qs[0]

    split_qs = [deepcopy(chosen) for _ in range(params.num_splits)]
    return QRec(ktype=ktype, in_qs=[deepcopy(chosen)], out_qs=split_qs)
def _handle(cls, params, in_qs, _, **kwargs):
    """Select a single qtype shared by all concat inputs and the output.

    Args:
        params: node parameters; only ``params.name`` is read, for log
            messages.
        in_qs: input qtypes. Each must expose ``forced`` and ``dtype``.
        _: unused.
        **kwargs: ``force_out_qs`` (optional; only its first entry is used)
            and ``backwards`` (optional flag) are read.

    Returns:
        A ``QRec`` built with ``cls.KTYPE``, or ``None`` when the forced
        constraints conflict.
    """
    # .get() rather than a subscript: a call without output constraints
    # means "nothing forced" instead of raising KeyError
    force_out_qs = kwargs.get('force_out_qs')
    force_out_q = force_out_qs[0] if force_out_qs else None
    forced_in_qs = [in_q for in_q in in_qs if in_q.forced]

    # two inputs cannot be forced to different values
    if forced_in_qs and not QType.forced_equal(*forced_in_qs):
        LOG.info(
            'two input qtypes of concat %s are forced to different qtypes',
            params.name)
        return None

    # input cannot be forced to different value than output
    if force_out_q and not force_out_q.can_force(*forced_in_qs):
        LOG.info(
            'output and input of concat %s are forced to different qtypes',
            params.name)
        return None

    # going backwards with a forced output: decide whether that output has
    # to be imposed on every input
    if kwargs.get('backwards') and force_out_q:
        dtype_conflict = force_out_q.forced_dtype and any(
            in_q.dtype != force_out_q.dtype for in_q in in_qs)
        value_forced = (force_out_q.forced_zero_point
                        or force_out_q.forced_scale
                        or force_out_q.forced_q)
        if dtype_conflict or value_forced:
            in_qs = [deepcopy(force_out_q) for _ in in_qs]
            return QRec(ktype=cls.KTYPE, in_qs=in_qs,
                        out_qs=[deepcopy(force_out_q)])

    # if all the inputs are the same qtype then we output that qtype
    if all(in_qs[0] == in_q for in_q in in_qs[1:]):
        return QRec(ktype=cls.KTYPE, in_qs=in_qs,
                    out_qs=[deepcopy(in_qs[0])])

    # if an input has scale or q forced then every input and the output take
    # the first forced input's qtype
    # TODO - have a general function for this problem - should pick with
    # force constraints respecting dtype
    if forced_in_qs and any(
            fin_q.forced_scale or fin_q.forced_q for fin_q in forced_in_qs):
        in_qs = [deepcopy(forced_in_qs[0]) for _ in in_qs]
        return QRec(ktype=cls.KTYPE, in_qs=in_qs,
                    out_qs=[deepcopy(forced_in_qs[0])])

    # inputs differ and none pins scale or q: move everything to the qtype
    # returned by cls._get_common_q (not visible here; the original comment
    # describes it as the maximum input size with a Q that fits the most
    # int bits)
    common_q = cls._get_common_q(in_qs)
    in_qs = [deepcopy(common_q) for _ in in_qs]
    return QRec(ktype=cls.KTYPE, in_qs=in_qs, out_qs=[deepcopy(common_q)])