def __query__(self, q):
    """
    Resolve a query against this database's tag index.

    The result starts as the union of all objects carrying any ``include``
    tag, is narrowed to those carrying every ``require`` tag, and then
    stripped of those carrying any ``exclude`` tag.  Any nested
    ``OptimizationDatabase`` found in the result is replaced by the result
    of querying it recursively.
    """
    # OrderedSet keeps iteration order stable, which is needed for
    # deterministic optimization.
    selected = OrderedSet()
    for tag in q.include:
        selected.update(self.__db__[tag])
    for tag in q.require:
        selected.intersection_update(self.__db__[tag])
    for tag in q.exclude:
        selected.difference_update(self.__db__[tag])

    to_drop = OrderedSet()
    to_insert = OrderedSet()
    for entry in selected:
        if not isinstance(entry, OptimizationDatabase):
            continue
        # Default sub-query: reuse ``q`` itself, but never propagate the
        # extra optimizations into sub-databases.
        default_sq = q
        if q.extra_optimizations:
            default_sq = copy.copy(q)
            default_sq.extra_optimizations = []
        sub_result = entry.query(q.subquery.get(entry.name, default_sq))
        sub_result.name = entry.name
        to_drop.add(entry)
        to_insert.add(sub_result)

    selected.difference_update(to_drop)
    selected.update(to_insert)
    return selected
def __init__(
    self,
    include: Sequence[str],
    require: Optional[Sequence[str]] = None,
    exclude: Optional[Sequence[str]] = None,
    subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
    position_cutoff: float = math.inf,
    extra_optimizations: Optional[Sequence[OptimizersType]] = None,
):
    """
    Parameters
    ==========
    include:
        A set of tags such that every optimization obtained through this
        ``OptimizationQuery`` must have **one** of the tags listed.  This
        field is required and basically acts as a starting point for the
        search.
    require:
        A set of tags such that every optimization obtained through this
        ``OptimizationQuery`` must have **all** of these tags.
    exclude:
        A set of tags such that every optimization obtained through this
        ``OptimizationQuery`` must have **none** of these tags.
    subquery:
        A dictionary mapping the name of a sub-database to a special
        ``OptimizationQuery``.  If no subquery is given for a sub-database,
        the original ``OptimizationQuery`` will be used again.
    position_cutoff:
        Only optimizations with position less than the cutoff are returned.
    extra_optimizations:
        Extra optimizations to be added.
    """

    def _normalize(tags):
        # Same normalization as before: falsy -> empty OrderedSet,
        # list/tuple -> OrderedSet, any other truthy container kept as-is.
        if not tags:
            return OrderedSet()
        if isinstance(tags, (list, tuple)):
            return OrderedSet(tags)
        return tags

    self.include = OrderedSet(include)
    self.require = _normalize(require)
    self.exclude = _normalize(exclude)
    self.subquery = subquery or {}
    self.position_cutoff = position_cutoff
    # Set when this query's result is installed into a parent database.
    self.name: Optional[str] = None
    self.extra_optimizations = (
        [] if extra_optimizations is None else extra_optimizations
    )
def register(
    self,
    name: str,
    optimizer: Union["OptimizationDatabase", OptimizersType],
    *tags: str,
    use_db_name_as_tag=True,
    **kwargs,
):
    """Register a new optimizer to the database.

    Parameters
    ----------
    name:
        Name of the optimizer.
    optimizer:
        The optimizer to register.
    tags:
        Tag name that allow to select the optimizer.
    use_db_name_as_tag:
        Add the database's name as a tag, so that its name can be used in
        a query.  By default, all optimizations registered in
        ``EquilibriumDB`` are selected when the ``"EquilibriumDB"`` name is
        used as a tag.  We do not want this behavior for some optimizers
        like ``local_remove_all_assert``.  Setting `use_db_name_as_tag` to
        ``False`` removes that behavior.  This mean only the optimizer
        name and the tags specified will enable that optimization.

    Raises
    ------
    TypeError
        If `optimizer` is not an ``OptimizationDatabase`` or a global/local
        optimizer.
    ValueError
        If `name` is already used, or if `optimizer` is already registered
        (under any name) in this database.
    """
    if not isinstance(
        optimizer,
        (
            OptimizationDatabase,
            aesara_opt.GlobalOptimizer,
            aesara_opt.LocalOptimizer,
        ),
    ):
        raise TypeError(f"{optimizer} is not a valid optimizer type.")

    if name in self.__db__:
        raise ValueError(
            f"The tag '{name}' is already present in the database.")

    # This restriction is there because in many place we suppose that
    # something in the OptimizationDatabase is there only once.
    # BUGFIX: this must be checked against the optimizer's *current* name
    # and before we overwrite it below; the previous code checked after
    # the rename, at which point the check could never fire (``name`` was
    # just verified to be absent from the database).
    if getattr(optimizer, "name", None) in self.__db__:
        raise ValueError(
            f"Tried to register {optimizer.name} again under the new name {name}. "
            "The same optimization cannot be registered multiple times in"
            " an ``OptimizationDatabase``; use ProxyDB instead.")

    if use_db_name_as_tag:
        if self.name is not None:
            tags = tags + (self.name, )

    optimizer.name = name
    self.__db__[name] = OrderedSet([optimizer])
    self._names.add(name)
    # Also index the optimizer under its class name so whole families can
    # be selected with one tag.
    self.__db__[optimizer.__class__.__name__].add(optimizer)
    self.add_tags(name, *tags)
def on_attach(self, fgraph):
    """
    Attach this ``DestroyHandler`` to `fgraph`.

    Refuses to attach when (1) this handler already serves another
    ``FunctionGraph`` (its data structures serve exactly one), or (2) the
    ``FunctionGraph`` already carries a ``DestroyHandler`` — double
    validation would only slow compilation down.

    On success the fgraph gains a ``destroy_handler`` attribute (and, via
    ``unpickle``, a ``destroyers(var)`` method), and this handler's
    bookkeeping structures are (re)initialized.
    """
    # --- sanity checks -------------------------------------------------
    conflict = fgraph is self.fgraph
    if self.fgraph is not None:
        raise Exception(
            "A DestroyHandler instance can only serve one"
            " FunctionGraph. (Matthew 6:24)"
        )
    if not conflict:
        conflict = any(
            hasattr(fgraph, attr) for attr in ("destroyers", "destroy_handler")
        )
    if conflict:
        # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
        raise AlreadyThere(
            "DestroyHandler feature is already present"
            " or in conflict with another plugin."
        )

    # --- annotate the FunctionGraph ------------------------------------
    self.unpickle(fgraph)
    fgraph.destroy_handler = self

    self.fgraph = fgraph
    # Apply instances with a non-null destroy_map.
    self.destroyers = OrderedSet()
    # variable -> variable used in calculation
    self.view_i = {}
    # variable -> set of variables that use this one as a direct input
    self.view_o = {}
    # clients: how many times does an apply use a given variable
    self.clients = OrderedDict()  # variable -> apply -> ninputs
    self.stale_droot = True
    self.debug_all_apps = set()
    if self.do_imports_on_attach:
        Bookkeeper.on_attach(self, fgraph)
def register(self, name, obj, *tags, **kwargs):
    """
    Register a new optimizer in the database.

    Parameters
    ----------
    name : str
        Name of the optimizer.
    obj
        The optimizer to register.
    tags
        Tag name that allow to select the optimizer.
    kwargs
        If non empty, should contain only use_db_name_as_tag=False.
        By default, all optimizations registered in EquilibriumDB
        are selected when the EquilibriumDB name is used as a
        tag. We do not want this behavior for some optimizer like
        local_remove_all_assert. use_db_name_as_tag=False remove that
        behavior. This mean only the optimizer name and the tags
        specified will enable that optimization.

    Raises
    ------
    TypeError
        If `obj` is not a DB or a global/local optimizer.
    ValueError
        If `name` is already used, or if `obj` is already registered
        (under any name) in this DB.
    """
    # N.B. obj is not necessarily an instance of class Optimizer: it can
    # also be a DB instance, e.g. in the tests.
    if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
        raise TypeError("Object cannot be registered in OptDB", obj)
    if name in self.__db__:
        raise ValueError(
            "The name of the object cannot be an existing"
            " tag or the name of another existing object.",
            obj,
            name,
        )

    # This restriction is there because in many place we suppose that
    # something in the DB is there only once.
    # BUGFIX: check the object's *current* name before renaming it below;
    # the previous code checked after ``obj.name = name``, at which point
    # the check could never fire (``name`` was just verified absent).
    # Also fixed the "Tryed" typo in the message.
    if getattr(obj, "name", None) in self.__db__:
        raise ValueError(
            f'You can\'t register the same optimization multiple time in '
            f'a DB. Tried to register "{obj.name}" again under the new '
            f'name "{name}". Use aesara.gof.ProxyDB to work around that'
        )

    if kwargs:
        # The only supported keyword disables tagging with the DB's name.
        assert "use_db_name_as_tag" in kwargs
        assert kwargs["use_db_name_as_tag"] is False
    else:
        if self.name is not None:
            tags = tags + (self.name,)

    obj.name = name
    self.__db__[name] = OrderedSet([obj])
    self._names.add(name)
    # Also index under the class name so families of optimizers can be
    # selected with a single tag.
    self.__db__[obj.__class__.__name__].add(obj)
    self.add_tags(name, *tags)
def __init__(
    self,
    include,
    require=None,
    exclude=None,
    subquery=None,
    position_cutoff=math.inf,
    extra_optimizations=None,
):
    # ``include`` is mandatory; ``require``/``exclude`` default to empty
    # ordered sets.  List/tuple arguments are normalized to OrderedSet;
    # other truthy containers are kept as given.
    self.include = OrderedSet(include)

    required = require or OrderedSet()
    if isinstance(required, (list, tuple)):
        required = OrderedSet(required)
    self.require = required

    excluded = exclude or OrderedSet()
    if isinstance(excluded, (list, tuple)):
        excluded = OrderedSet(excluded)
    self.exclude = excluded

    self.subquery = subquery or {}
    # Only optimizations positioned below this cutoff are returned.
    self.position_cutoff = position_cutoff
    self.extra_optimizations = (
        [] if extra_optimizations is None else extra_optimizations
    )
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
    """
    app.inputs[i] changed from old_r to new_r.

    Keeps ``self.clients``, ``self.view_i`` and ``self.view_o`` consistent
    with the rewired graph, and marks the cached droot/impact maps stale.
    """
    if app == "output":
        # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
        # considered 'outputs' of the graph.
        pass
    else:
        # Every real Apply must have been seen by on_import first.
        if app not in self.debug_all_apps:
            raise ProtocolError("change without import")

        # UPDATE self.clients: one fewer use of old_r by app, one more of
        # new_r.  Drop the (old_r, app) entry entirely when its count hits 0.
        self.clients[old_r][app] -= 1
        if self.clients[old_r][app] == 0:
            del self.clients[old_r][app]
        self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0)
        self.clients[new_r][app] += 1

        # UPDATE self.view_i, self.view_o: re-point any output that was a
        # view of input ``i`` from old_r to new_r.
        for o_idx, i_idx_list in getattr(app.op, "view_map", OrderedDict()).items():
            if len(i_idx_list) > 1:
                # destroying this output invalidates multiple inputs
                raise NotImplementedError()
            i_idx = i_idx_list[0]
            output = app.outputs[o_idx]
            if i_idx == i:
                # Sanity check: the graph must already hold the new input.
                if app.inputs[i_idx] is not new_r:
                    raise ProtocolError("wrong new_r on change")
                self.view_i[output] = new_r
                # Remove the reverse edge from old_r, pruning empty sets.
                self.view_o[old_r].remove(output)
                if not self.view_o[old_r]:
                    del self.view_o[old_r]
                self.view_o.setdefault(new_r, OrderedSet()).add(output)

        if self.algo == "fast":
            # Re-run the fast destroy check for this node; clear any stale
            # failure recorded for it first.
            if app in self.fail_validate:
                del self.fail_validate[app]
            self.fast_destroy(fgraph, app, reason)
    # Cached droot/impact maps must be rebuilt on next use.
    self.stale_droot = True
def _build_droot_impact(destroy_handler): droot = {} # destroyed view + nonview variables -> foundation impact = {} # destroyed nonview variable -> it + all views of it root_destroyer = {} # root -> destroyer apply for app in destroy_handler.destroyers: for output_idx, input_idx_list in app.op.destroy_map.items(): if len(input_idx_list) != 1: raise NotImplementedError() input_idx = input_idx_list[0] input = app.inputs[input_idx] # Find non-view variable which is ultimatly viewed by input. view_i = destroy_handler.view_i _r = input while _r is not None: r = _r _r = view_i.get(r) input_root = r if input_root in droot: raise InconsistencyError( f"Multiple destroyers of {input_root}") droot[input_root] = input_root root_destroyer[input_root] = app # The code here add all the variables that are views of r into # an OrderedSet input_impact input_impact = OrderedSet() q = deque() q.append(input_root) while len(q) > 0: v = q.popleft() for n in destroy_handler.view_o.get(v, []): input_impact.add(n) q.append(n) for v in input_impact: assert v not in droot droot[v] = input_root impact[input_root] = input_impact impact[input_root].add(input_root) return droot, impact, root_destroyer
def on_import(self, fgraph, app, reason): """ Add Apply instance to set which must be computed. """ if app in self.debug_all_apps: raise ProtocolError("double import") self.debug_all_apps.add(app) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # If it's a destructive op, add it to our watch list dmap = getattr(app.op, "destroy_map", None) vmap = getattr(app.op, "view_map", {}) if dmap: self.destroyers.add(app) if self.algo == "fast": self.fast_destroy(fgraph, app, reason) # add this symbol to the forward and backward maps for o_idx, i_idx_list in vmap.items(): if len(i_idx_list) > 1: raise NotImplementedError( "destroying this output invalidates multiple inputs", (app.op)) o = app.outputs[o_idx] i = app.inputs[i_idx_list[0]] self.view_i[o] = i self.view_o.setdefault(i, OrderedSet()).add(o) # update self.clients for i, input in enumerate(app.inputs): self.clients.setdefault(input, OrderedDict()).setdefault(app, 0) self.clients[input][app] += 1 for i, output in enumerate(app.outputs): self.clients.setdefault(output, OrderedDict()) self.stale_droot = True