Beispiel #1
0
    def register(
        self,
        name: str,
        optimizer: Union["OptimizationDatabase", OptimizersType],
        *tags: str,
        use_db_name_as_tag=True,
        **kwargs,
    ):
        """Store `optimizer` in this database under the unique key `name`.

        Parameters
        ----------
        name:
            Key under which the optimizer is stored. It must not already be
            a key of the database.
        optimizer:
            The object to store: an ``OptimizationDatabase``, an
            ``aesara_opt.GlobalOptimizer`` or an ``aesara_opt.LocalOptimizer``.
            Its ``name`` attribute is overwritten with `name`.
        tags:
            Extra keys through which the optimizer can be selected.
        use_db_name_as_tag:
            When true (the default) and this database has a non-``None``
            ``name``, that name is appended to `tags`, so querying with the
            database's name selects the optimizer. Pass ``False`` for
            optimizers (e.g. ``local_remove_all_assert``) that must only be
            enabled through their own name or the explicitly given tags.
        kwargs:
            Accepted but not read by this method.

        Raises
        ------
        TypeError
            If `optimizer` is not one of the accepted types.
        ValueError
            If `name` is already a key of the database.

        """
        accepted_types = (
            OptimizationDatabase,
            aesara_opt.GlobalOptimizer,
            aesara_opt.LocalOptimizer,
        )
        if not isinstance(optimizer, accepted_types):
            raise TypeError(f"{optimizer} is not a valid optimizer type.")

        if name in self.__db__:
            raise ValueError(
                f"The tag '{name}' is already present in the database.")

        if use_db_name_as_tag and self.name is not None:
            tags = (*tags, self.name)

        optimizer.name = name
        # Many places assume that an entry of the database appears only once.
        # NOTE(review): the name was assigned just above and was verified to
        # be absent from the database, so this check looks unreachable as
        # written -- confirm whether the optimizer's previous name was meant.
        if optimizer.name in self.__db__:
            raise ValueError(
                f"Tried to register {optimizer.name} again under the new name {name}. "
                "The same optimization cannot be registered multiple times in"
                " an ``OptimizationDatabase``; use ProxyDB instead.")

        # Index the optimizer by its own name, by its class name and by tags.
        self.__db__[name] = OrderedSet([optimizer])
        self._names.add(name)
        class_key = optimizer.__class__.__name__
        self.__db__[class_key].add(optimizer)
        self.add_tags(name, *tags)
Beispiel #2
0
def _build_droot_impact(destroy_handler):
    droot = {}  # destroyed view + nonview variables -> foundation
    impact = {}  # destroyed nonview variable -> it + all views of it
    root_destroyer = {}  # root -> destroyer apply

    for app in destroy_handler.destroyers:
        for output_idx, input_idx_list in app.op.destroy_map.items():
            if len(input_idx_list) != 1:
                raise NotImplementedError()
            input_idx = input_idx_list[0]
            input = app.inputs[input_idx]

            # Find non-view variable which is ultimatly viewed by input.
            view_i = destroy_handler.view_i
            _r = input
            while _r is not None:
                r = _r
                _r = view_i.get(r)
            input_root = r

            if input_root in droot:
                raise InconsistencyError(
                    f"Multiple destroyers of {input_root}")
            droot[input_root] = input_root
            root_destroyer[input_root] = app

            # The code here add all the variables that are views of r into
            # an OrderedSet input_impact
            input_impact = OrderedSet()

            q = deque()
            q.append(input_root)
            while len(q) > 0:
                v = q.popleft()
                for n in destroy_handler.view_o.get(v, []):
                    input_impact.add(n)
                    q.append(n)

            for v in input_impact:
                assert v not in droot
                droot[v] = input_root

            impact[input_root] = input_impact
            impact[input_root].add(input_root)

    return droot, impact, root_destroyer
Beispiel #3
0
    def on_attach(self, fgraph):
        """Attach this handler to `fgraph` and reset its bookkeeping.

        Before attaching, two conditions are checked:

            1) This handler is not already serving a *different* graph (its
               data structures are only set up to serve one); otherwise a
               plain ``Exception`` is raised.
            2) The handler is not already attached to `fgraph`, and `fgraph`
               does not already carry a ``destroyers`` or ``destroy_handler``
               attribute; otherwise ``AlreadyThere`` is raised so that the
               attachment is cancelled instead of validating everything
               twice.

        On success `fgraph` is passed to ``self.unpickle`` (presumably the
        place where the ``destroyers`` helper is installed -- not visible
        here), ``fgraph.destroy_handler`` is set to this handler and all
        internal maps are re-created empty.

        Raises
        ------
        Exception
            If the handler is already attached to another graph.
        AlreadyThere
            If the handler, or a conflicting feature, is already present on
            `fgraph`.

        """
        # Only a *different* graph is a hard error.  Re-attaching to the same
        # graph must fall through to ``AlreadyThere`` below; the previous code
        # raised the generic error in that case as well.
        if self.fgraph is not None and self.fgraph is not fgraph:
            raise Exception(
                "A DestroyHandler instance can only serve one"
                " FunctionGraph. (Matthew 6:24)"
            )

        already_there = self.fgraph is fgraph or any(
            hasattr(fgraph, attr) for attr in ("destroyers", "destroy_handler")
        )
        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise AlreadyThere(
                "DestroyHandler feature is already present"
                " or in conflict with another plugin."
            )

        # Annotate the FunctionGraph #
        self.unpickle(fgraph)
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        # set of Apply instances with non-null destroy_map
        self.destroyers = OrderedSet()
        # variable -> variable used in calculation
        self.view_i = {}
        # variable -> set of variables that use this one as a direct input
        self.view_o = {}
        # clients: how many times does an apply use a given variable
        self.clients = OrderedDict()  # variable -> apply -> ninputs
        self.stale_droot = True

        self.debug_all_apps = set()
        if self.do_imports_on_attach:
            Bookkeeper.on_attach(self, fgraph)
Beispiel #4
0
    def register(self, name, obj, *tags, **kwargs):
        """

        Parameters
        ----------
        name : str
            Name of the optimizer.
        obj
            The optimizer to register.
        tags
            Tag name that allow to select the optimizer.
        kwargs
            If non empty, should contain only use_db_name_as_tag=False.
            By default, all optimizations registered in EquilibriumDB
            are selected when the EquilibriumDB name is used as a
            tag. We do not want this behavior for some optimizer like
            local_remove_all_assert. use_db_name_as_tag=False remove
            that behavior. This mean only the optimizer name and the
            tags specified will enable that optimization.

        """
        # N.B. obj is not an instance of class Optimizer.
        # It is an instance of a DB.In the tests for example,
        # this is not always the case.
        if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
            raise TypeError("Object cannot be registered in OptDB", obj)
        if name in self.__db__:
            raise ValueError(
                "The name of the object cannot be an existing"
                " tag or the name of another existing object.",
                obj,
                name,
            )
        if kwargs:
            assert "use_db_name_as_tag" in kwargs
            assert kwargs["use_db_name_as_tag"] is False
        else:
            if self.name is not None:
                tags = tags + (self.name,)
        obj.name = name
        # This restriction is there because in many place we suppose that
        # something in the DB is there only once.
        if obj.name in self.__db__:
            raise ValueError(
                """You can\'t register the same optimization
multiple time in a DB. Tryed to register "%s" again under the new name "%s".
 Use aesara.gof.ProxyDB to work around that"""
                % (obj.name, name)
            )
        self.__db__[name] = OrderedSet([obj])
        self._names.add(name)
        self.__db__[obj.__class__.__name__].add(obj)
        self.add_tags(name, *tags)
Beispiel #5
0
    def __init__(
        self,
        include: Sequence[str],
        require: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
        subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
        position_cutoff: float = math.inf,
        extra_optimizations: Optional[Sequence[OptimizersType]] = None,
    ):
        """Build a query from tag collections.

        Parameters
        ==========
        include:
            A set of tags such that every optimization obtained through this
            ``OptimizationQuery`` must have **one** of the tags listed. This
            field is required and basically acts as a starting point for the
            search.
        require:
            A set of tags such that every optimization obtained through this
            ``OptimizationQuery`` must have **all** of these tags.
        exclude:
            A set of tags such that every optimization obtained through this
            ``OptimizationQuery`` must have **none** of these tags.
        subquery:
            A dictionary mapping the name of a sub-database to a special
            ``OptimizationQuery``.  If no subquery is given for a sub-database,
            the original ``OptimizationQuery`` will be used again.
        position_cutoff:
            Only optimizations with position less than the cutoff are returned.
        extra_optimizations:
            Extra optimizations to be added.  The sequence is copied into a
            new list, so any sequence type may be given and the caller's
            object is never shared with the query.

        """
        self.include = OrderedSet(include)

        # Empty/missing tag collections become a fresh empty set, lists and
        # tuples are converted, anything else truthy is stored unchanged.
        if not require:
            self.require = OrderedSet()
        elif isinstance(require, (list, tuple)):
            self.require = OrderedSet(require)
        else:
            self.require = require

        if not exclude:
            self.exclude = OrderedSet()
        elif isinstance(exclude, (list, tuple)):
            self.exclude = OrderedSet(exclude)
        else:
            self.exclude = exclude

        self.subquery = subquery or {}
        self.position_cutoff = position_cutoff
        self.name: Optional[str] = None
        # Always keep a private list: the query later concatenates this
        # attribute with a list, which fails for a tuple stored as given.
        self.extra_optimizations = (
            [] if extra_optimizations is None else list(extra_optimizations)
        )
Beispiel #6
0
    def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
        """Update the bookkeeping after ``app.inputs[i]`` went from `old_r` to `new_r`.

        The client counts and the view maps are adjusted for the changed
        input, and the destroy-root cache is always marked stale. When
        ``self.algo`` is ``"fast"``, any recorded validation failure of `app`
        is dropped and ``self.fast_destroy`` is run on it again.

        Raises
        ------
        ProtocolError
            If `app` was never imported, or if a viewed input does not match
            `new_r`.
        NotImplementedError
            If a ``view_map`` entry lists more than one input.

        """
        if app == "output":
            # The string "output" is a special key meaning the FunctionGraph
            # is redefining which variables are its outputs; there is no
            # node to update, only the cache to invalidate.
            self.stale_droot = True
            return

        if app not in self.debug_all_apps:
            raise ProtocolError("change without import")

        # UPDATE self.clients: one use less of old_r, one more of new_r.
        old_uses = self.clients[old_r]
        old_uses[app] -= 1
        if old_uses[app] == 0:
            del old_uses[app]

        new_uses = self.clients.setdefault(new_r, OrderedDict())
        new_uses.setdefault(app, 0)
        new_uses[app] += 1

        # UPDATE self.view_i, self.view_o
        view_map = getattr(app.op, "view_map", OrderedDict())
        for out_pos, in_positions in view_map.items():
            if len(in_positions) > 1:
                # an output tied to several inputs is not supported
                raise NotImplementedError()
            viewed_pos = in_positions[0]
            view_out = app.outputs[out_pos]
            if viewed_pos != i:
                continue

            if app.inputs[viewed_pos] is not new_r:
                raise ProtocolError("wrong new_r on change")

            self.view_i[view_out] = new_r

            # Move the output from the views of old_r to the views of new_r.
            old_views = self.view_o[old_r]
            old_views.remove(view_out)
            if not old_views:
                del self.view_o[old_r]

            self.view_o.setdefault(new_r, OrderedSet()).add(view_out)

        if self.algo == "fast":
            if app in self.fail_validate:
                del self.fail_validate[app]
            self.fast_destroy(fgraph, app, reason)

        self.stale_droot = True
Beispiel #7
0
 def __init__(
     self,
     include,
     require=None,
     exclude=None,
     subquery=None,
     position_cutoff=math.inf,
     extra_optimizations=None,
 ):
     """Build a query from tag collections.

     `include` is always converted to an ``OrderedSet``. For `require` and
     `exclude`, a missing or empty value becomes a new empty ``OrderedSet``,
     a list or tuple is converted to an ``OrderedSet``, and any other
     non-empty value is stored unchanged. `subquery` defaults to an empty
     dict and `extra_optimizations` to an empty list.
     """
     self.include = OrderedSet(include)

     if not require:
         self.require = OrderedSet()
     elif isinstance(require, (list, tuple)):
         self.require = OrderedSet(require)
     else:
         self.require = require

     if not exclude:
         self.exclude = OrderedSet()
     elif isinstance(exclude, (list, tuple)):
         self.exclude = OrderedSet(exclude)
     else:
         self.exclude = exclude

     self.subquery = subquery or {}
     self.position_cutoff = position_cutoff
     self.extra_optimizations = (
         [] if extra_optimizations is None else extra_optimizations
     )
Beispiel #8
0
    def on_import(self, fgraph, app, reason):
        """Record a newly imported apply node in the handler's bookkeeping.

        The node is added to the set of known nodes; if its op has a
        non-empty ``destroy_map`` it is added to ``self.destroyers`` (and
        checked with ``self.fast_destroy`` when ``self.algo`` is ``"fast"``).
        Its ``view_map`` is mirrored into ``self.view_i`` / ``self.view_o``,
        the client counts of its inputs are incremented, its outputs get an
        (initially empty) client entry, and the destroy-root cache is marked
        stale.

        Raises
        ------
        ProtocolError
            If `app` was already imported.
        NotImplementedError
            If a ``view_map`` entry lists more than one input.

        """
        if app in self.debug_all_apps:
            raise ProtocolError("double import")
        self.debug_all_apps.add(app)

        op = app.op
        destroy_map = getattr(op, "destroy_map", None)
        view_map = getattr(op, "view_map", {})

        # Destructive ops go on the watch list.
        if destroy_map:
            self.destroyers.add(app)
            if self.algo == "fast":
                self.fast_destroy(fgraph, app, reason)

        # Mirror the view relations in the forward and backward maps.
        for out_pos, in_positions in view_map.items():
            if len(in_positions) > 1:
                raise NotImplementedError(
                    "destroying this output invalidates multiple inputs",
                    (app.op))
            view_out = app.outputs[out_pos]
            viewed = app.inputs[in_positions[0]]
            self.view_i[view_out] = viewed
            self.view_o.setdefault(viewed, OrderedSet()).add(view_out)

        # One more use of every input (counted once per occurrence).
        for var in app.inputs:
            uses = self.clients.setdefault(var, OrderedDict())
            uses[app] = uses.get(app, 0) + 1

        # Make sure every output has a client entry.
        for var in app.outputs:
            self.clients.setdefault(var, OrderedDict())

        self.stale_droot = True
Beispiel #9
0
class OptimizationQuery:
    """A selection of optimizations expressed through tags/names."""

    def __init__(
        self,
        include: Iterable[str],
        require: Optional[Union[OrderedSet, Sequence[str]]] = None,
        exclude: Optional[Union[OrderedSet, Sequence[str]]] = None,
        subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
        position_cutoff: float = math.inf,
        extra_optimizations: Optional[
            Sequence[
                Tuple[Union["OptimizationQuery", OptimizersType], Union[int, float]]
            ]
        ] = None,
    ):
        """Build a query; every tag collection is copied into an ``OrderedSet``.

        Parameters
        ==========
        include:
            A set of tags such that every optimization obtained through this
            `OptimizationQuery` must have **one** of the tags listed. This
            field is required and basically acts as a starting point for the
            search.
        require:
            A set of tags such that every optimization obtained through this
            `OptimizationQuery` must have **all** of these tags.
        exclude:
            A set of tags such that every optimization obtained through this
            `OptimizationQuery` must have **none** of these tags.
        subquery:
            A dictionary mapping the name of a sub-database to a special
            `OptimizationQuery`.  If no subquery is given for a sub-database,
            the original `OptimizationQuery` will be used again.
        position_cutoff:
            Only optimizations with position less than the cutoff are returned.
        extra_optimizations:
            Extra optimizations to be added; copied into a new list.

        """
        self.include = OrderedSet(include)

        if require:
            self.require = OrderedSet(require)
        else:
            self.require = OrderedSet()

        if exclude:
            self.exclude = OrderedSet(exclude)
        else:
            self.exclude = OrderedSet()

        self.subquery = subquery or {}
        self.position_cutoff = position_cutoff
        self.name: Optional[str] = None
        self.extra_optimizations = (
            [] if extra_optimizations is None else list(extra_optimizations)
        )

    def __str__(self):
        fields = (
            f"inc={self.include}",
            f"ex={self.exclude}",
            f"require={self.require}",
            f"subquery={self.subquery}",
            f"position_cutoff={self.position_cutoff}",
            f"extra_opts={self.extra_optimizations}",
        )
        return "OptimizationQuery(" + ",".join(fields) + ")"

    def __setstate__(self, state):
        """Restore the attributes in `state`.

        ``extra_optimizations`` is set to an empty list when the restored
        object does not have it.
        """
        self.__dict__.update(state)
        if not hasattr(self, "extra_optimizations"):
            self.extra_optimizations = []

    def including(self, *tags: str) -> "OptimizationQuery":
        """Return a new query whose ``include`` set also contains `tags`."""
        return OptimizationQuery(
            include=self.include.union(tags),
            require=self.require,
            exclude=self.exclude,
            subquery=self.subquery,
            position_cutoff=self.position_cutoff,
            extra_optimizations=self.extra_optimizations,
        )

    def excluding(self, *tags: str) -> "OptimizationQuery":
        """Return a new query whose ``exclude`` set also contains `tags`."""
        return OptimizationQuery(
            include=self.include,
            require=self.require,
            exclude=self.exclude.union(tags),
            subquery=self.subquery,
            position_cutoff=self.position_cutoff,
            extra_optimizations=self.extra_optimizations,
        )

    def requiring(self, *tags: str) -> "OptimizationQuery":
        """Return a new query whose ``require`` set also contains `tags`."""
        return OptimizationQuery(
            include=self.include,
            require=self.require.union(tags),
            exclude=self.exclude,
            subquery=self.subquery,
            position_cutoff=self.position_cutoff,
            extra_optimizations=self.extra_optimizations,
        )

    def register(
        self, *optimizations: Tuple["OptimizationQuery", Union[int, float]]
    ) -> "OptimizationQuery":
        """Return a new query with `optimizations` appended to the extras."""
        combined = self.extra_optimizations + list(optimizations)
        return OptimizationQuery(
            include=self.include,
            require=self.require,
            exclude=self.exclude,
            subquery=self.subquery,
            position_cutoff=self.position_cutoff,
            extra_optimizations=combined,
        )
Beispiel #10
0
    def __query__(self, q):
        """Resolve the query `q` against this database.

        Entries are gathered as the union of everything stored under the
        ``q.include`` tags, narrowed to the entries present under every
        ``q.require`` tag, and stripped of the entries stored under any
        ``q.exclude`` tag. Every selected ``OptimizationDatabase`` is then
        replaced by the result of querying it, using the entry of
        ``q.subquery`` for its name when there is one and otherwise `q`
        itself (a copy without extra optimizations when `q` carries any).
        The result is given the name of the sub-database it replaces.

        Returns
        -------
        OrderedSet
            The selected entries; an ordered set keeps the optimization
            order deterministic.

        """
        selected = OrderedSet()
        for tag in q.include:
            selected.update(self.__db__[tag])
        for tag in q.require:
            selected.intersection_update(self.__db__[tag])
        for tag in q.exclude:
            selected.difference_update(self.__db__[tag])

        # The set cannot be modified while it is iterated, so the swaps are
        # collected first and applied afterwards.
        sub_databases = OrderedSet()
        expansions = OrderedSet()
        for entry in selected:
            if not isinstance(entry, OptimizationDatabase):
                continue

            if q.extra_optimizations:
                fallback_query = copy.copy(q)
                fallback_query.extra_optimizations = []
            else:
                fallback_query = q
            sub_query = q.subquery.get(entry.name, fallback_query)

            expanded = entry.query(sub_query)
            expanded.name = entry.name
            sub_databases.add(entry)
            expansions.add(expanded)

        selected.difference_update(sub_databases)
        selected.update(expansions)
        return selected
Beispiel #11
0
class Query:
    """A selection of optimizations expressed through tags.

    Parameters
    ----------
    include
        Tags used as the starting point of the selection; always stored as
        an ``OrderedSet``.
    require
        Tags that selected optimizations must all have.
    exclude
        Tags that selected optimizations must not have.
    subquery : dict
        Maps the name of a sub-database to the query to use for it.
    position_cutoff : float
        Used by SequenceDB to keep only optimizer that are positioned before
        the cut_off point.
    extra_optimizations
        Extra optimizations carried by the query.  The sequence is copied
        into a new list, so any sequence type may be given and the caller's
        object is never shared with the query.

    """

    def __init__(
        self,
        include,
        require=None,
        exclude=None,
        subquery=None,
        position_cutoff=math.inf,
        extra_optimizations=None,
    ):
        self.include = OrderedSet(include)

        # Empty/missing tag collections become a fresh empty set, lists and
        # tuples are converted, anything else truthy is stored unchanged.
        if not require:
            self.require = OrderedSet()
        elif isinstance(require, (list, tuple)):
            self.require = OrderedSet(require)
        else:
            self.require = require

        if not exclude:
            self.exclude = OrderedSet()
        elif isinstance(exclude, (list, tuple)):
            self.exclude = OrderedSet(exclude)
        else:
            self.exclude = exclude

        self.subquery = subquery or {}
        self.position_cutoff = position_cutoff
        # Always keep a private list: ``register`` concatenates this
        # attribute with a list, which fails for a tuple stored as given.
        self.extra_optimizations = (
            [] if extra_optimizations is None else list(extra_optimizations)
        )

    def __str__(self):
        return (
            "Query{inc=%s,ex=%s,require=%s,subquery=%s,"
            "position_cutoff=%f,extra_opts=%s}"
            % (
                self.include,
                self.exclude,
                self.require,
                self.subquery,
                self.position_cutoff,
                self.extra_optimizations,
            )
        )

    def __setstate__(self, state):
        """Restore the attributes in `state`.

        ``extra_optimizations`` is set to an empty list when the restored
        object does not have it.
        """
        self.__dict__.update(state)
        if not hasattr(self, "extra_optimizations"):
            self.extra_optimizations = []

    def including(self, *tags):
        """Return a new query whose ``include`` set also contains `tags`."""
        return Query(
            self.include.union(tags),
            self.require,
            self.exclude,
            self.subquery,
            self.position_cutoff,
            self.extra_optimizations,
        )

    def excluding(self, *tags):
        """Return a new query whose ``exclude`` set also contains `tags`."""
        return Query(
            self.include,
            self.require,
            self.exclude.union(tags),
            self.subquery,
            self.position_cutoff,
            self.extra_optimizations,
        )

    def requiring(self, *tags):
        """Return a new query whose ``require`` set also contains `tags`."""
        return Query(
            self.include,
            self.require.union(tags),
            self.exclude,
            self.subquery,
            self.position_cutoff,
            self.extra_optimizations,
        )

    def register(self, *optimizations):
        """Return a new query with `optimizations` appended to the extras."""
        return Query(
            self.include,
            self.require,
            self.exclude,
            self.subquery,
            self.position_cutoff,
            list(self.extra_optimizations) + list(optimizations),
        )
Beispiel #12
0
class OptimizationQuery:
    """An object that specifies a set of optimizations by tag/name."""

    def __init__(
        self,
        include: Sequence[str],
        require: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
        subquery: Optional[Dict[str, "OptimizationQuery"]] = None,
        position_cutoff: float = math.inf,
        extra_optimizations: Optional[Sequence[OptimizersType]] = None,
    ):
        """Build a query from tag collections.

        Parameters
        ==========
        include:
            A set of tags such that every optimization obtained through this
            ``OptimizationQuery`` must have **one** of the tags listed. This
            field is required and basically acts as a starting point for the
            search.
        require:
            A set of tags such that every optimization obtained through this
            ``OptimizationQuery`` must have **all** of these tags.
        exclude:
            A set of tags such that every optimization obtained through this
            ``OptimizationQuery`` must have **none** of these tags.
        subquery:
            A dictionary mapping the name of a sub-database to a special
            ``OptimizationQuery``.  If no subquery is given for a sub-database,
            the original ``OptimizationQuery`` will be used again.
        position_cutoff:
            Only optimizations with position less than the cutoff are returned.
        extra_optimizations:
            Extra optimizations to be added.  The sequence is copied into a
            new list, so any sequence type may be given and the caller's
            object is never shared with the query.

        """
        self.include = OrderedSet(include)

        # Empty/missing tag collections become a fresh empty set, lists and
        # tuples are converted, anything else truthy is stored unchanged.
        if not require:
            self.require = OrderedSet()
        elif isinstance(require, (list, tuple)):
            self.require = OrderedSet(require)
        else:
            self.require = require

        if not exclude:
            self.exclude = OrderedSet()
        elif isinstance(exclude, (list, tuple)):
            self.exclude = OrderedSet(exclude)
        else:
            self.exclude = exclude

        self.subquery = subquery or {}
        self.position_cutoff = position_cutoff
        self.name: Optional[str] = None
        # Always keep a private list: ``register`` concatenates this
        # attribute with a list, which fails for a tuple stored as given.
        self.extra_optimizations = (
            [] if extra_optimizations is None else list(extra_optimizations)
        )

    def __str__(self):
        return ("OptimizationQuery(" +
                f"inc={self.include},ex={self.exclude}," +
                f"require={self.require},subquery={self.subquery}," +
                f"position_cutoff={self.position_cutoff}," +
                f"extra_opts={self.extra_optimizations})")

    def __setstate__(self, state):
        """Restore the attributes in `state`.

        ``extra_optimizations`` is set to an empty list when the restored
        object does not have it.
        """
        self.__dict__.update(state)
        if not hasattr(self, "extra_optimizations"):
            self.extra_optimizations = []

    def including(self, *tags):
        """Return a new query whose ``include`` set also contains `tags`."""
        return OptimizationQuery(
            self.include.union(tags),
            self.require,
            self.exclude,
            self.subquery,
            self.position_cutoff,
            self.extra_optimizations,
        )

    def excluding(self, *tags):
        """Return a new query whose ``exclude`` set also contains `tags`."""
        return OptimizationQuery(
            self.include,
            self.require,
            self.exclude.union(tags),
            self.subquery,
            self.position_cutoff,
            self.extra_optimizations,
        )

    def requiring(self, *tags):
        """Return a new query whose ``require`` set also contains `tags`."""
        return OptimizationQuery(
            self.include,
            self.require.union(tags),
            self.exclude,
            self.subquery,
            self.position_cutoff,
            self.extra_optimizations,
        )

    def register(self, *optimizations):
        """Return a new query with `optimizations` appended to the extras."""
        return OptimizationQuery(
            self.include,
            self.require,
            self.exclude,
            self.subquery,
            self.position_cutoff,
            list(self.extra_optimizations) + list(optimizations),
        )