def _simplify_symargs(self, symargs, symvals):
    '''
    Simplify symargs and symvals in-place iteratively.

    Repeats `_simplify_symargs_one_pass` until a pass reports that nothing
    was simplified. Afterwards, asserts that the symbols left in `symvals`
    are exactly the free symbols still referenced by the values of
    `symargs`, and that every remaining entry of `symvals` is truthy
    (i.e., has at least one candidate value).
    '''
    converged = False
    while not converged:
        converged = self._simplify_symargs_one_pass(symargs, symvals)

    # Gather every free symbol still appearing in any arg value.
    per_arg_tuples = [symtuple(*args.values())
                      for args in itertools.chain.from_iterable(symargs)]
    remaining_syms = symtuple(*per_arg_tuples).free_symbols

    assert set(remaining_syms) == set(symvals.keys())
    assert all(symvals.values())
def _simplify_symargs_one_pass(symargs, symvals):
    '''
    Simplify symargs and symvals in-place:
    - If fbi/ofm is False, then remove it.
    - If fbi/ofm is True, then remove topi/ofm.
    - If a symbol can take only one value, then substitute it.
    - If a symbol only occurs once, then remove its constraint.
    - If a symbol no longer occurs at all, then drop it from symvals.

    `symargs` is a nested iterable of dicts whose values may be plain
    values or sympy expressions. `symvals` maps each symbol to an indexable
    sequence of its candidate values.

    Return whether the symargs and symvals are already simplified, i.e.,
    True iff this pass removed no symbol from `symvals`.
    '''
    for a in itertools.chain.from_iterable(symargs):
        is_fbifm = a.get('fbifm')
        is_fbofm = a.get('fbofm')
        # pylint: disable=singleton-comparison
        # lhs may be symbolic, see
        # docs.sympy.org/latest/modules/logic.html#sympy.logic.boolalg.BooleanTrue
        if is_fbifm == True:
            a.pop('topifm', 0)
        if is_fbifm == False:
            a.pop('fbifm', False)
        if is_fbofm == True:
            a.pop('topofm', 0)
        if is_fbofm == False:
            a.pop('fbofm', False)

    # Count the occurrence of symbols in all args (values).
    symcnts = Counter(s for a in itertools.chain.from_iterable(symargs)
                      for val in a.values()
                      for s in symtuple(val).free_symbols)
    assert set(symcnts.keys()).issubset(symvals.keys())

    # Symbols that no longer occur in any value are simply dropped from
    # symvals. They are deliberately kept out of the subs() map: they match
    # nothing, and there is no meaningful replacement value for them (a
    # placeholder such as None is not sympifiable).
    unused_syms = set(symvals.keys()) - set(symcnts.keys())

    subs_dict = {}

    # A used symbol that can take only one value is replaced by that value.
    subs_dict.update((s, symvals[s][0]) for s in symvals
                     if s in symcnts and len(symvals[s]) == 1)

    # A symbol occurring only once couples nothing, so its constraint is
    # removed: 0 for 'top*' symbols, False otherwise. This intentionally
    # overrides a single-value substitution made above for the same symbol.
    subs_dict.update((s, 0 if str(s).startswith('top') else False)
                     for s in symcnts if symcnts[s] <= 1)

    # Substitute symbols. Every value is re-wrapped through symtuple even
    # when subs_dict is empty, which also sympifies plain Python values.
    for a in itertools.chain.from_iterable(symargs):
        for k in a:
            a[k] = symtuple(a[k]).subs(subs_dict)[0]

    # Remove substituted and unused symbols (the two sets are disjoint).
    for s in itertools.chain(subs_dict, unused_syms):
        del symvals[s]

    return not subs_dict and not unused_syms
def _lazify_topofm_symargs(self, symargs, symvals):
    '''
    Turn qualified topofm constraints into lazily updated rules.

    If a symbol is only used as the topofm constraint by a single CONV
    layer and some local-region layers, we can turn it into a lazily update
    rule.

    `symargs` holds one arg dict per layer, walked in lockstep with the
    layers of `self.seg`; `symvals` maps symbols to candidate values. Both
    are modified in-place:
    - a CONV layer linked to at least one local-region layer has its
      'topofm' entry deleted;
    - each linked local-region layer has its 'topofm' entry deleted and
      gains 'update_dict' = {conv_layer: PipelineSegment.TopOfmUpdateLambda()};
    - every symbol recorded in `sym2conv` is deleted from `symvals`.
    '''
    sym2conv = {}  # symbol --> the only CONV layer using it.
    sym2lrs = {}  # symbol --> list of local-region layer using it.
    unqual_syms = set()  # symbols used by two or more CONV layers.

    # Pass 1: classify each symbol by the CONV / local-region layers that
    # reference it.
    for l, a in zip(itertools.chain.from_iterable(self.seg),
                    itertools.chain.from_iterable(symargs)):
        layer = self.network[l]
        if isinstance(layer, ConvLayer):
            topofm = a.get('topofm', 0)
            topifm = a.get('topifm', 0)
            for s in symtuple(topofm, topifm).free_symbols:
                if s not in unqual_syms:
                    if s in sym2conv:
                        # If a symbol is used in two CONV layers, it cannot
                        # be lazily updated.
                        del sym2conv[s]
                        sym2lrs.pop(s, [])
                        unqual_syms.add(s)
                    elif topofm == s:
                        # Only a CONV whose topofm is exactly `s` claims it.
                        # NOTE(review): a CONV that references `s` only via
                        # topifm is not recorded here, so if it comes before
                        # the CONV that claims `s` as topofm, `s` is never
                        # marked unqualified, and that topifm keeps `s` after
                        # `s` is deleted from symvals below. Confirm this
                        # cannot happen.
                        assert s not in sym2lrs
                        sym2conv[s] = l
        else:
            # Local-region layers are linked only when their topofm is
            # exactly a symbol already claimed by an earlier CONV layer.
            topofm = a.get('topofm', 0)
            if topofm in sym2conv:
                sym2lrs.setdefault(topofm, []).append(l)
    # Sanity check: the constant 0 ("no constraint") is never a key.
    assert 0 not in sym2conv and 0 not in sym2lrs

    syms = sym2conv.keys()  # symbols to be lazily updated.
    lr2conv = {}  # local-region layer to the CONV layer constraining it.
    for s in syms:
        for lr in sym2lrs.get(s, []):
            lr2conv[lr] = sym2conv[s]
    # NOTE(review): a symbol in sym2conv with no linked local-region layer
    # leaves its CONV out of lconvs, so that CONV keeps 'topofm' == s while
    # s is still deleted from symvals at the end. Verify that upstream
    # simplification guarantees this case does not occur.
    lconvs = set(
        lr2conv.values())  # CONV layer whose topofm to be removed.

    # Pass 2: rewrite the arg dicts of the affected layers.
    for l, a in zip(itertools.chain.from_iterable(self.seg),
                    itertools.chain.from_iterable(symargs)):
        if l in lconvs:
            # Remove CONV topofm.
            assert sym2conv[a['topofm']] == l
            del a['topofm']
        elif l in lr2conv:
            # Link local-region layer to the CONV layer.
            lconv = lr2conv[l]
            assert sym2conv[a['topofm']] == lconv
            del a['topofm']
            a['update_dict'] = {
                lconv: PipelineSegment.TopOfmUpdateLambda()
            }

    # The lazily updated symbols are no longer enumerated as free choices.
    for s in syms:
        del symvals[s]
def _subs_symargs(symargs, *subs_args):
    '''
    Substitute symbols. The additional arguments are passed to subs().

    Return a new substituted copy without modifying the original one: a
    list of lists of new dicts with the same keys (in the same order) as
    the dicts in `symargs`, where each value is replaced by
    symtuple(value, sympify=False).subs(*subs_args)[0].
    '''
    def _subs_value(value):
        # sympify=False is necessary because there may be str in the values.
        return symtuple(value, sympify=False).subs(*subs_args)[0]

    return [[{key: _subs_value(value) for key, value in args.items()}
             for args in args_group]
            for args_group in symargs]