コード例 #1
0
    def _handle(cls, params, in_qs, _, ktype, **kwargs):
        """Pick quantization for a node whose output qtype follows its inputs.

        All forced inputs must agree (``QType.forced_equal``) and a forced
        output must pass ``can_force`` against the inputs, otherwise the node
        is rejected.

        Args:
            params: node parameters; only ``params.name`` is read (for logging).
            in_qs: list of input qtypes.
            _: unused positional argument (presumably statistics - confirm
                against the caller).
            ktype: kernel type recorded on the returned ``QRec``.
            **kwargs: ``force_out_qs`` - optional sequence whose first entry is
                the requested output qtype; ``backwards`` - direction flag
                (presumably set when propagating from outputs toward inputs).

        Returns:
            A ``QRec``, or ``None`` when the constraints cannot be satisfied.
            Backwards with a forced output: every input and the output become
            copies of that forced qtype. Forwards with a forced output that
            does not equal every input: ``None``. Otherwise the inputs are kept
            and the output is a copy of the first input.
        """
        requested_outs = kwargs.get('force_out_qs')
        out_q = requested_outs[0] if requested_outs else None
        pinned_ins = [qtype for qtype in in_qs if qtype.forced]

        # Forced inputs that disagree with one another can never be reconciled.
        if pinned_ins and not QType.forced_equal(*pinned_ins):
            LOG.debug(
                'two input qtypes of %s are forced to different qtypes - rejecting',
                params.name)
            return None

        # A forced output has to be compatible with the inputs.
        if out_q and not out_q.can_force(out_q, *in_qs):
            LOG.debug(
                'output and input of %s are forced to different qtypes - rejecting',
                params.name)
            return None

        if out_q:
            if kwargs.get('backwards'):
                # Push the forced output qtype back onto every input.
                pushed_ins = [deepcopy(out_q) for _ in in_qs]
                return QRec(in_qs=pushed_ins,
                            out_qs=[deepcopy(out_q)],
                            ktype=ktype)
            # Going forwards the forced output is only satisfiable when every
            # input already equals it.
            if not all(qtype == out_q for qtype in in_qs):
                LOG.debug(
                    "output of %s is forced and inputs don't match - rejecting",
                    params.name)
                return None

        return QRec(in_qs=in_qs, out_qs=[deepcopy(in_qs[0])], ktype=ktype)
コード例 #2
0
ファイル: split_mixin.py プロジェクト: mfkiwl/gap_sdk
    def _handle(cls, params, in_qs, stats, ktype, **kwargs):
        """Pick quantization for a split: the input and every output share one qtype.

        The shared qtype is, in order of preference: the first forced input
        qtype, the first non-``None`` forced output qtype, or ``in_qs[0]``.

        Args:
            params: split parameters; ``name`` (for logging) and
                ``num_splits`` are read.
            in_qs: list of input qtypes; entries with ``forced`` set act as
                constraints, and ``in_qs[0]`` is the fallback.
            stats: not used by this method.
            ktype: kernel type recorded on the returned ``QRec``.
            **kwargs: ``force_out_qs`` - optional requested output qtypes;
                ``None`` entries are ignored (presumably outputs that could
                not be forced).

        Returns:
            A ``QRec`` with one input and ``params.num_splits`` outputs, each
            an independent deep copy of the shared qtype, or ``None`` when the
            forced outputs disagree with each other or cannot be reconciled
            with the forced input.
        """
        out_reqs = kwargs.get('force_out_qs')
        if out_reqs:
            # Drop the outputs that could not be forced.
            out_reqs = [qtype for qtype in out_reqs if qtype is not None]
        in_reqs = [qtype for qtype in in_qs if qtype.forced]
        in_req = in_reqs[0] if in_reqs else None

        # All forced outputs must agree with each other.
        if out_reqs and not QType.forced_equal(*out_reqs):
            LOG.info(
                'two output qtypes of split %s are forced to different qtypes', params.name)
            return None

        # The forced input must be able to satisfy the forced outputs.
        if in_req and out_reqs and not in_req.can_force(*out_reqs):
            LOG.error(
                'output and input of split %s are forced to different qtypes', params.name)
            return None

        # Any remaining forcing is compatible: pick the one qtype used for the
        # input and for every output, preferring a forced input.
        if in_req:
            shared = in_req
        elif out_reqs and out_reqs[0]:
            shared = out_reqs[0]
        else:
            shared = in_qs[0]

        outs = [deepcopy(shared) for _ in range(params.num_splits)]
        return QRec(ktype=ktype, in_qs=[deepcopy(shared)], out_qs=outs)
コード例 #3
0
ファイル: concat_mixin.py プロジェクト: mfkiwl/gap_sdk
    def _handle(cls, params, in_qs, _, **kwargs):
        """Pick quantization for a concat: all inputs share the output qtype.

        Args:
            params: concat parameters; only ``params.name`` is read (for logging).
            in_qs: list of input qtypes.
            _: unused positional argument (presumably statistics - confirm
                against the caller).
            **kwargs: ``force_out_qs`` - optional sequence whose first entry is
                the requested output qtype; ``backwards`` - direction flag.

        Returns:
            A ``QRec`` built with ``cls.KTYPE``, or ``None`` when forced inputs
            disagree with each other or with the forced output. Resolution
            order:

            1. backwards with a forced output that forces zero point, scale or
               q, or forces a dtype differing from some input's dtype: every
               input and the output become copies of the forced output;
            2. all inputs equal: inputs are kept, output is a copy of the
               first input;
            3. some forced input forces scale or q: every input and the output
               become copies of the first forced input;
            4. otherwise: every input and the output become copies of
               ``cls._get_common_q(in_qs)``.
        """
        # Use .get (like the other handlers) so a caller that omits
        # force_out_qs means "nothing forced" rather than raising KeyError.
        force_out_qs = kwargs.get('force_out_qs')
        force_out_q = force_out_qs[0] if force_out_qs else None
        forced_in_qs = [in_q for in_q in in_qs if in_q.forced]
        # two inputs cannot be forced to different values
        if forced_in_qs and not QType.forced_equal(*forced_in_qs):
            LOG.info(
                'two input qtypes of concat %s are forced to different qtypes',
                params.name)
            return None
        # input cannot be forced to different value than output
        if force_out_q and not force_out_q.can_force(*forced_in_qs):
            LOG.info(
                'output and input of concat %s are forced to different qtypes',
                params.name)
            return None

        # Only when going backwards is a forced output pushed onto the inputs,
        # and only if it actually constrains something the inputs don't match.
        if kwargs.get('backwards') and force_out_q:
            dtype_mismatch = force_out_q.forced_dtype and any(
                in_q.dtype != force_out_q.dtype for in_q in in_qs)
            if (dtype_mismatch or force_out_q.forced_zero_point
                    or force_out_q.forced_scale or force_out_q.forced_q):
                in_qs = [deepcopy(force_out_q) for _ in in_qs]
                return QRec(ktype=cls.KTYPE,
                            in_qs=in_qs,
                            out_qs=[deepcopy(force_out_q)])

        # NOTE(review): from here on force_out_q is not consulted again
        # (forwards, or backwards when the check above did not require
        # forcing); it was only validated by can_force - confirm intended.

        # if all the inputs are the same qtype then we output that qtype
        if all(in_qs[0] == in_q for in_q in in_qs[1:]):
            return QRec(ktype=cls.KTYPE,
                        in_qs=in_qs,
                        out_qs=[deepcopy(in_qs[0])])

        # if an input has scale or q forced then all forced inputs must be the
        # same here (guaranteed by the forced_equal check above)
        # TODO - have a general function for this problem - should pick with force constraints respecting dtype
        if forced_in_qs and any(fin_q.forced_scale or fin_q.forced_q
                                for fin_q in forced_in_qs):
            in_qs = [deepcopy(forced_in_qs[0]) for _ in in_qs]
            return QRec(ktype=cls.KTYPE,
                        in_qs=in_qs,
                        out_qs=[deepcopy(forced_in_qs[0])])

        # the inputs differ: move all of them to a common qtype chosen by
        # _get_common_q (intended to fit the widest input with the most int bits)
        common_q = cls._get_common_q(in_qs)
        in_qs = [deepcopy(common_q) for _ in in_qs]

        return QRec(ktype=cls.KTYPE, in_qs=in_qs, out_qs=[deepcopy(common_q)])