Example #1
0
 def _simplify_symargs(self, symargs, symvals):
     '''
     Simplify symargs and symvals in-place.

     Single simplification passes are repeated until one of them reports
     that nothing is left to simplify. Afterwards, the symbols still used by
     the argument values must be exactly the keys of symvals, and every
     remaining symbol must still have a non-empty value in symvals.
     '''
     simplified = False
     while not simplified:
         simplified = self._simplify_symargs_one_pass(symargs, symvals)

     # Gather the values of each argument dict into one tuple, so that the
     # free symbols of all of them can be queried at once.
     per_arg_values = [
         symtuple(*arg.values())
         for arg in itertools.chain.from_iterable(symargs)
     ]
     used_syms = symtuple(*per_arg_values).free_symbols

     assert set(used_syms) == set(symvals.keys())
     assert all(symvals.values())
Example #2
0
    def _simplify_symargs_one_pass(symargs, symvals):
        '''
        Simplify symargs and symvals in-place:
        - If fbi/ofm is False, then remove it.
        - If fbi/ofm is True, then remove topi/ofm.
        - If a symbol can take only one value, then substitute it.
        - If a symbol only occurs once, then remove its constraint.

        Return whether the symargs and symvals are already simplified.
        '''
        for a in itertools.chain.from_iterable(symargs):
            is_fbifm = a.get('fbifm')
            is_fbofm = a.get('fbofm')
            # pylint: disable=singleton-comparison
            # lhs may be symbolic, see
            # docs.sympy.org/latest/modules/logic.html#sympy.logic.boolalg.BooleanTrue
            if is_fbifm == True:
                a.pop('topifm', 0)
            if is_fbifm == False:
                a.pop('fbifm', False)
            if is_fbofm == True:
                a.pop('topofm', 0)
            if is_fbofm == False:
                a.pop('fbofm', False)

        subs_dict = {}

        # Possible values for symbols.
        subs_dict.update(
            (s, symvals[s][0]) for s in symvals if len(symvals[s]) == 1)

        # Count the occurrence of symbols in all args (values).
        symcnts = Counter(s for a in itertools.chain.from_iterable(symargs)
                          for val in a.values()
                          for s in symtuple(val).free_symbols)
        assert set(symcnts.keys()).issubset(symvals.keys())
        subs_dict.update(
            (s, None) for s in set(symvals.keys()) - set(symcnts.keys()))
        subs_dict.update((s, 0 if str(s).startswith('top') else False)
                         for s in symcnts if symcnts[s] <= 1)

        # Substitute symbols and remove from symbol dict.
        for a in itertools.chain.from_iterable(symargs):
            for k in a:
                a[k] = symtuple(a[k]).subs(subs_dict)[0]
        for s in subs_dict:
            del symvals[s]

        return not subs_dict
Example #3
0
    def _lazify_topofm_symargs(self, symargs, symvals):
        '''
        Turn qualified topofm constraints into lazily updated rules.

        If a symbol is only used as the topofm constraint by a single CONV
        layer and some local-region layers, we can turn it into a lazily
        updated rule.

        Both arguments are modified in-place:
        - symargs: nested iterable of per-layer argument dicts, walked in
          lockstep with the layers of self.seg. For each CONV layer that
          constrains at least one local-region layer, and for each such
          local-region layer, the 'topofm' entry is deleted. The local-region
          layer additionally gets an 'update_dict' entry, which maps the CONV
          layer to a PipelineSegment.TopOfmUpdateLambda instance.
        - symvals: dict keyed by symbol. All symbols that end up associated
          with exactly one CONV layer are deleted from it.

        Any layer that is not a ConvLayer is handled as a local-region layer.
        '''
        sym2conv = {}  # symbol --> the only CONV layer using it.
        sym2lrs = {}  # symbol --> list of local-region layers using it.
        unqual_syms = set()  # symbols used by two or more CONV layers.
        # Pass 1: classify the symbols by the layers that use them.
        for l, a in zip(itertools.chain.from_iterable(self.seg),
                        itertools.chain.from_iterable(symargs)):
            layer = self.network[l]
            if isinstance(layer, ConvLayer):
                # A missing constraint defaults to 0, which has no symbols.
                topofm = a.get('topofm', 0)
                topifm = a.get('topifm', 0)
                for s in symtuple(topofm, topifm).free_symbols:
                    if s not in unqual_syms:
                        if s in sym2conv:
                            # If a symbol is used in two CONV layers, it cannot
                            # be lazily updated.
                            del sym2conv[s]
                            sym2lrs.pop(s, [])
                            unqual_syms.add(s)
                        elif topofm == s:
                            # Only a topofm that is exactly the symbol
                            # qualifies. A symbol seen only in topifm, or
                            # inside a larger topofm expression, is skipped.
                            assert s not in sym2lrs
                            sym2conv[s] = l
            else:
                # Only record this layer if its topofm is a symbol already
                # associated with a CONV layer visited earlier.
                topofm = a.get('topofm', 0)
                if topofm in sym2conv:
                    sym2lrs.setdefault(topofm, []).append(l)
        # The default value 0 for a missing topofm must never be recorded as
        # a symbol.
        assert 0 not in sym2conv and 0 not in sym2lrs

        syms = sym2conv.keys()  # symbols to be lazily updated.
        lr2conv = {}  # local-region layer to the CONV layer constraining it.
        for s in syms:
            for lr in sym2lrs.get(s, []):
                lr2conv[lr] = sym2conv[s]
        lconvs = set(
            lr2conv.values())  # CONV layers whose topofm is to be removed.

        # Pass 2: rewrite the argument dicts of the affected layers.
        for l, a in zip(itertools.chain.from_iterable(self.seg),
                        itertools.chain.from_iterable(symargs)):
            if l in lconvs:
                # Remove CONV topofm.
                assert sym2conv[a['topofm']] == l
                del a['topofm']
            elif l in lr2conv:
                # Link local-region layer to the CONV layer.
                lconv = lr2conv[l]
                assert sym2conv[a['topofm']] == lconv
                del a['topofm']
                a['update_dict'] = {
                    lconv: PipelineSegment.TopOfmUpdateLambda()
                }

        # NOTE(review): this deletes every symbol in sym2conv, including those
        # whose CONV layer has no local-region layer and therefore keeps its
        # topofm entry above -- confirm this case cannot occur or is intended.
        for s in syms:
            del symvals[s]
Example #4
0
    def _subs_symargs(symargs, *subs_args):
        '''
        Substitute symbols. The additional arguments are passed to subs().

        Return a new substituted copy without modifying the original one.
        '''
        # sympify=False is necessary because there may be str in the values.
        return [[
            dict((k, symtuple(a[k], sympify=False).subs(*subs_args)[0])
                 for k in a) for a in atpl
        ] for atpl in symargs]