Example #1
    def get_dependent_nodes(self,
                            i: Variable,
                            seen: Optional[Set[int]] = None) -> Set[Variable]:
        if seen is None:
            seen = {i}
        else:
            seen.add(i)

        var_mappings = self.var_mappings

        field_info = self.find_among_fields(i)

        if field_info is None:
            raise ValueError("{} not found among fields.".format(i))

        # Find the `var_mappings` key suffix that matches the field/set of
        # arguments containing our source node
        if field_info.name[:8].endswith("_in"):
            map_key_suffix = "{}p".format(field_info.name[:8])
        else:
            map_key_suffix = field_info.name[:9]

        dependent_nodes = set()
        for k, v in var_mappings.items():

            if not k.endswith(map_key_suffix):
                continue

            dependent_idx = v[field_info.agg_index]
            dependent_idx = (dependent_idx if isinstance(dependent_idx, list)
                             else [dependent_idx])

            # Get the `ScanArgs` field name for the aggregate list property
            # corresponding to these dependent argument types (i.e. either
            # "outer_inputs", "inner_inputs", "inner_outputs", or
            # "outer_outputs").
            # To do this, we need to parse the "shared" prefix of the
            # current `var_mappings` key and append the missing parts so that
            # it either forms `"*_inputs"` or `"*_outputs"`.
            to_agg_field_prefix = k[:9]
            if to_agg_field_prefix.endswith("p"):
                to_agg_field_name = "{}uts".format(to_agg_field_prefix)
            else:
                to_agg_field_name = "{}puts".format(to_agg_field_prefix)

            to_agg_field = getattr(self, to_agg_field_name)

            for d_id in dependent_idx:
                if d_id < 0:
                    continue

                dependent_var = to_agg_field[d_id]

                if dependent_var not in seen:
                    dependent_nodes.add(dependent_var)

        if field_info.name.startswith("inner_in"):
            # If starting from an inner-input, then we need to find any
            # inner-outputs that depend on it.
            for out_n in self.inner_outputs:
                if i in graph_inputs([out_n]):
                    if out_n not in seen:
                        dependent_nodes.add(out_n)

        for n in tuple(dependent_nodes):
            if n in seen:
                continue
            sub_dependent_nodes = self.get_dependent_nodes(n, seen=seen)
            dependent_nodes |= sub_dependent_nodes
            seen |= sub_dependent_nodes

        return dependent_nodes
Example #2
    def __init__(
        self,
        inputs: Optional[List[Variable]] = None,
        outputs: Optional[List[Variable]] = None,
        features: Optional[List[Feature]] = None,
        clone: bool = True,
        update_mapping: Optional[Dict[Variable, Variable]] = None,
        memo: Optional[Dict[Variable, Variable]] = None,
        copy_inputs: bool = True,
        copy_orphans: bool = True,
    ):
        """
        Create a `FunctionGraph` which operates on the subgraph between the
        `inputs` and `outputs`.

        Parameters
        ----------
        inputs
            Input variables of the graph.
        outputs
            Output variables of the graph.
        clone
            If ``True``, the graph will be cloned.
        features
            A list of features to be added to the `FunctionGraph`.
        update_mapping
            Mapping between the `inputs` with updates and the `outputs`
            corresponding to their updates.
        memo
            See ``clone_get_equiv``.
        copy_inputs
            See ``clone_get_equiv``.
        copy_orphans
            See ``clone_get_equiv``.
        """
        if outputs is None:
            raise ValueError("No outputs specified")

        if inputs is None:
            inputs = [
                i for i in graph_inputs(outputs)
                if not isinstance(i, Constant)
            ]

        if clone:
            memo = clone_get_equiv(
                inputs,
                outputs,
                copy_inputs=copy_inputs,
                copy_orphans=copy_orphans,
                memo=memo,
            )
            outputs = [memo[o] for o in outputs]
            inputs = [memo[i] for i in inputs]

        self.execute_callbacks_time = 0
        self.execute_callbacks_times = {}

        if features is None:
            features = []

        self._features = []

        # All apply nodes in the subgraph defined by inputs and
        # outputs are cached in this field
        self.apply_nodes = set()

        # Ditto for variable nodes.
        # It must contain all fgraph.inputs and all apply_nodes
        # outputs even if they aren't used in the graph.
        self.variables = set()

        self.inputs = []
        self.outputs = list(outputs)
        self.clients = {}

        for f in features:
            self.attach_feature(f)

        self.attach_feature(ReplaceValidate())

        for in_var in inputs:
            if in_var.owner is not None:
                raise ValueError("One of the provided inputs is the output of "
                                 "an already existing node. "
                                 "If that is okay, either discard that "
                                 "input's owner or use graph.clone.")

            self.add_input(in_var, check=False)

        for output in outputs:
            self.import_var(output, reason="init")
        for i, output in enumerate(outputs):
            self.clients[output].append(("output", i))

        self.profile = None
        self.update_mapping = update_mapping
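
A minimal usage sketch for the constructor above (an illustration, not part of the library source; it assumes the usual Aesara import locations ``aesara.tensor`` and ``aesara.graph.fg``):

import aesara.tensor as at
from aesara.graph.fg import FunctionGraph

x = at.vector("x")
y = at.exp(x).sum()

# Clone the subgraph between `x` and `y` into a standalone FunctionGraph
fg = FunctionGraph(inputs=[x], outputs=[y], clone=True)

print(fg.inputs, fg.outputs)  # the (cloned) graph inputs and outputs
print(fg.apply_nodes)         # all Apply nodes cached for the subgraph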
Example #3
File: aesaraf.py Project: bwengals/pymc3
def rvs_to_value_vars(
    graphs: Iterable[TensorVariable],
    apply_transforms: bool = False,
    initial_replacements: Optional[Dict[TensorVariable,
                                        TensorVariable]] = None,
    **kwargs,
) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]:
    """Clone and replace random variables in graphs with their value variables.

    This will *not* recompute test values in the resulting graphs.

    Parameters
    ==========
    graphs
        The graphs in which to perform the replacements.
    apply_transforms
        If ``True``, apply each value variable's transform.
    initial_replacements
        A ``dict`` containing the initial replacements to be made.

    """

    # Avoid circular dependency
    from pymc.distributions import NoDistribution

    def transform_replacements(var, replacements):
        rv_var, rv_value_var = extract_rv_and_value_vars(var)

        if rv_value_var is None:
            # If the RandomVariable does not have a value variable and
            # corresponds to a NoDistribution, we allow further replacements
            # in the upstream graph
            if isinstance(rv_var.owner.op, NoDistribution):
                return rv_var.owner.inputs

            else:
                warnings.warn(f"No value variable found for {rv_var}; "
                              "the random variable will not be replaced.")
                return []

        transform = getattr(rv_value_var.tag, "transform", None)

        if transform is None or not apply_transforms:
            replacements[var] = rv_value_var
            # In case the value variable is itself a graph, we walk it for
            # potential replacements
            return [rv_value_var]

        trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
        replacements[var] = trans_rv_value

        # Walk the transformed variable and make replacements
        return [trans_rv_value]

    # Clone original graphs
    inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
    equiv = clone_get_equiv(inputs, graphs, False, False, {})
    graphs = [equiv[n] for n in graphs]

    if initial_replacements:
        initial_replacements = {
            equiv.get(k, k): equiv.get(v, v)
            for k, v in initial_replacements.items()
        }

    return replace_rvs_in_graphs(graphs, transform_replacements,
                                 initial_replacements, **kwargs)
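
A hedged usage sketch for the helper above; the import path ``pymc.aesaraf`` is inferred from the file/project line and may differ between PyMC versions:

import pymc as pm
from pymc.aesaraf import rvs_to_value_vars  # assumed module path

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")
    x = pm.Normal("x", 0.0, sigma)

# Replace the random variables in x's graph with their (transformed) value variables
value_graphs, replacements = rvs_to_value_vars([x], apply_transforms=True)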
Example #4
def scan(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    n_steps=None,
    truncate_gradient=-1,
    go_backwards=False,
    mode=None,
    name=None,
    profile=False,
    allow_gc=None,
    strict=False,
    return_list=False,
):
    r"""This function constructs and applies a `Scan` `Op` to the provided arguments.

    Parameters
    ----------
    fn
        `fn` is a function that describes the operations involved in one
        step of `scan`. `fn` should construct variables describing the
        output of one iteration step. It should expect as input
        `Variable`\s representing all the slices of the input sequences
        and previous values of the outputs, as well as all other arguments
        given to scan as `non_sequences`. The order in which scan passes
        these variables to `fn`  is the following :

        * all time slices of the first sequence
        * all time slices of the second sequence
        * ...
        * all time slices of the last sequence
        * all past slices of the first output
        * all past slices of the second output
        * ...
        * all past slices of the last output
        * all other arguments (the list given as `non_sequences` to
            `scan`)

        The order of the sequences is the same as the one in the list
        `sequences` given to `scan`. The order of the outputs is the same
        as the order of `outputs_info`. For any sequence or output the
        order of the time slices is the same as the one in which they have
        been given as taps. For example if one writes the following :

        .. code-block:: python

            scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
                                 , Sequence2
                                 , dict(input =  Sequence3, taps = 3) ]
                   , outputs_info = [ dict(initial =  Output1, taps = [-3,-5])
                                    , dict(initial = Output2, taps = None)
                                    , Output3 ]
                   , non_sequences = [ Argument1, Argument2])

        `fn` should expect the following arguments in this given order:

        #. ``sequence1[t-3]``
        #. ``sequence1[t+2]``
        #. ``sequence1[t-1]``
        #. ``sequence2[t]``
        #. ``sequence3[t+3]``
        #. ``output1[t-3]``
        #. ``output1[t-5]``
        #. ``output3[t-1]``
        #. ``argument1``
        #. ``argument2``

        The list of `non_sequences` can also contain shared variables
        used in the function, though `scan` is able to figure those
        out on its own, so they can be skipped. For the clarity of the
        code we nevertheless recommend providing them to `scan`. To some
        extent, `scan` can also figure out other non-sequences (not shared)
        even if they are not passed to `scan` (but are used by `fn`). A
        simple example of this would be:

        .. code-block:: python

            import aesara.tensor as at

            W   = at.matrix()
            W_2 = W**2

            def f(x):
                return at.dot(x,W_2)

        The function `fn` is expected to return two things. One is a list of
        outputs ordered in the same order as `outputs_info`, with the
        difference that there should be only one output variable per
        output initial state (even if no tap value is used). Secondly
        `fn` should return an update dictionary (that tells how to
        update any shared variable after each iteration step). The
        dictionary can optionally be given as a list of tuples. There is
        no constraint on the order of these two lists: `fn` can return
        either ``(outputs_list, update_dictionary)`` or
        ``(update_dictionary, outputs_list)``, or just one of the two (in
        case the other is empty).

        To use `scan` as a ``while`` loop, the user needs to change the
        function `fn` such that it also returns a stopping condition.
        To do so, one needs to wrap the condition in an `until` instance.
        The condition should be returned as a third element, for example:

        .. code-block:: python

            ...
            return [y1_t, y2_t], {x:x+1}, until(x < 50)

        Note that the number of steps--considered here as the maximum
        number of steps--is still required even though a condition is
        passed.  It is used to allocate memory if needed.
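
        As a fuller sketch (the import location of `until` is assumed here to
        be ``aesara.scan.utils``; adjust it to your installed version), a
        while-style loop that doubles a value until it exceeds a bound might
        be written as:

        .. code-block:: python

            import aesara
            import aesara.tensor as at
            from aesara.scan.utils import until  # assumed import location

            def power_of_2(previous_power, max_value):
                next_power = previous_power * 2
                return next_power, until(next_power > max_value)

            max_value = at.scalar("max_value")
            values, _ = aesara.scan(
                power_of_2,
                outputs_info=at.constant(1.0),
                non_sequences=max_value,
                n_steps=1024,  # upper bound on the number of steps
            )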

    sequences
        `sequences` is the list of `Variable`\s or ``dict``\s
        describing the sequences `scan` has to iterate over. If a
        sequence is given as wrapped in a ``dict``, then a set of optional
        information can be provided about the sequence. The ``dict``
        should have the following keys:

        * ``input`` (*mandatory*) -- `Variable` representing the
          sequence.

        * ``taps`` -- Temporal taps of the sequence required by `fn`.
          They are provided as a list of integers, where a value ``k``
          implies that at iteration step ``t`` scan will pass to `fn`
          the slice ``t+k``. The default value is ``[0]``.

        All `Variable`\s in the list `sequences` are automatically
        wrapped into a ``dict`` where ``taps`` is set to ``[0]``.

    outputs_info
        `outputs_info` is the list of `Variable`\s or ``dict``\s
        describing the initial state of the outputs computed
        recurrently. When the initial states are given as ``dict``\s,
        optional information can be provided about the output corresponding
        to those initial states. The ``dict`` should have the following
        keys:

        * ``initial`` -- A `Variable` that represents the initial
          state of a given output. In case the output is not computed
          recursively (e.g. a ``map``-like function) and does not require an initial
          state, this field can be skipped. Given that only the previous
          time step of the output is used by `fn`, the initial state
          **should have the same shape** as the output and **should not
          involve a downcast** of the data type of the output. If multiple
          time taps are used, the initial state should have one extra
          dimension that covers all the possible taps. For example
          if we use ``-5``, ``-2`` and ``-1`` as past taps, at step ``0``,
          `fn` will require (by an abuse of notation) ``output[-5]``,
          ``output[-2]`` and ``output[-1]``. This will be given by
          the initial state, which in this case should have the shape
          ``(5,) + output.shape``. If this `Variable` containing the initial
          state is called ``init_y`` then ``init_y[0]`` corresponds to
          ``output[-5]``. ``init_y[1]`` corresponds to ``output[-4]``,
          ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]``
          corresponds to ``output[-2]``, ``init_y[4]`` corresponds to
          ``output[-1]``.
          While this order might seem strange, it comes naturally from splitting
          an array at a given point. Assume that we have an array ``x``, and we
          choose ``k`` to be time step ``0``. Then our initial state would be
          ``x[:k]``, while the output will be ``x[k:]``. Looking at this split,
          elements in ``x[:k]`` are ordered exactly like those in ``init_y``.
        * ``taps`` -- Temporal taps of the output that will be passed to
          `fn`. They are provided as a list of *negative* integers,
          where a value ``k`` implies that at iteration step ``t`` scan
          will pass to `fn` the slice ``t+k``.

        `scan` will follow this logic if partial information is given:

        * If an output is not wrapped in a ``dict``, `scan` will wrap
          it in one assuming that you use only the last step of the output
          (i.e. it makes your tap value list equal to ``[-1]``).
        * If you wrap an output in a ``dict`` and you do not provide any
          taps but you provide an initial state, it will assume that you are
          using only a tap value of ``-1``.
        * If you wrap an output in a ``dict`` but you do not provide any
          initial state, it assumes that you are not using any form of
          taps.
        * If you provide ``None`` instead of a `Variable` or an empty
          ``dict``, `scan` assumes that you will not use any taps for
          this output (as, for example, in the case of a ``map``).

        If `outputs_info` is an empty ``list`` or ``None``, `scan` assumes
        that no tap is used for any of the outputs. If information is
        provided just for a subset of the outputs, an exception is
        raised, because there is no convention on how scan should map
        the provided information to the outputs of `fn`.
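
        As an illustrative sketch of multiple output taps (standard Aesara
        imports assumed), a Fibonacci recurrence using taps ``[-2, -1]``
        could look like:

        .. code-block:: python

            import numpy as np
            import aesara
            import aesara.tensor as at

            # The initial state covers both taps: shape (2,) + output.shape
            init = at.as_tensor_variable(np.asarray([0, 1], dtype="int64"))

            fib, _ = aesara.scan(
                fn=lambda f_m2, f_m1: f_m2 + f_m1,  # receives output[t-2], output[t-1]
                outputs_info=[dict(initial=init, taps=[-2, -1])],
                n_steps=10,
            )
            fib_fn = aesara.function([], fib)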

    non_sequences
        `non_sequences` is the list of arguments that are passed to
        `fn` at each step. One can choose to exclude variables
        used in `fn` from this list, as long as they are part of the
        computational graph, although--for clarity--this is *not* encouraged.

    n_steps
        `n_steps` is the number of steps to iterate given as an ``int``
        or a scalar `Variable`. If any of the input sequences do not have
        enough elements, `scan` will raise an error. If the value is ``0``, the
        outputs will have ``0`` rows. If `n_steps` is not provided, `scan` will
        figure out the number of steps it should run given its input
        sequences. ``n_steps < 0`` is not supported anymore.

    truncate_gradient
        `truncate_gradient` is the number of steps to use in truncated
        back-propagation through time (BPTT).  If you compute gradients through
        a `Scan` `Op`, they are computed using BPTT. By providing a value
        other than ``-1``, you choose to use truncated BPTT instead of
        classical BPTT, where you go back only `truncate_gradient` steps in
        time.

    go_backwards
        `go_backwards` is a flag indicating if `scan` should go
        backwards through the sequences. If you think of each sequence
        as indexed by time, making this flag ``True`` would mean that
        `scan` goes back in time, namely that for any sequence it
        starts from the end and goes towards ``0``.

    name
        When profiling `scan`, it is helpful to provide a name for any
        instance of `scan`.
        For example, the profiler will produce an overall profile of your code
        as well as profiles for the computation of one step of each instance of
        `Scan`. The `name` of the instance appears in those profiles and can
        greatly help to disambiguate information.

    mode
        The mode used to compile the inner-graph.
        If you prefer the computations of one step of `scan` to be done
        differently than the rest of the function, you can use this parameter to
        describe how the computations in this loop are done (see
        `aesara.function` for details about possible values and their meaning).

    profile
        If ``True`` or a non-empty string, a profile object will be created and
        attached to the inner graph of `Scan`. When `profile` is ``True``, the
        profiler results will use the name of the `Scan` instance, otherwise it
        will use the passed string.  The profiler only collects and prints
        information when running the inner graph with the `CVM` `Linker`.

    allow_gc
        Set the value of `allow_gc` for the internal graph of the `Scan`.  If
        set to ``None``, this will use the value of
        `aesara.config.scan__allow_gc`.

        The full `Scan` behavior related to allocation is determined by this
        value and the flag `aesara.config.allow_gc`. If the flag
        `allow_gc` is ``True`` (default) and this `allow_gc` is ``False``
        (default), then we let `Scan` allocate all intermediate memory
        on the first iteration, and it is not garbage collected
        after that first iteration; this is determined by `allow_gc`. This can
        speed up allocation in subsequent iterations. All those temporary
        allocations are freed at the end of all iterations; this is what the
        flag `aesara.config.allow_gc` means.

        If you use pre-allocation and this `Scan` is on the GPU, the speed up
        from `allow_gc` is small. If you are running out of memory, disabling
        `allow_gc` could help you run graphs that require a lot of memory.

    strict
        If ``True``, all the shared variables used in `fn` must be provided as a
        part of `non_sequences` or `sequences`.

    return_list
        If ``True``, will always return a ``list``, even if there is only one output.

    Returns
    -------
    tuple
        ``tuple`` of the form ``(outputs, updates)``.
        ``outputs`` is either a `Variable` or a ``list`` of `Variable`\s
        representing the outputs in the same order as in `outputs_info`.
        ``updates`` is a subclass of ``dict`` specifying the update rules for
        all shared variables used in `Scan`.
        This ``dict`` should be passed to `aesara.function` when you compile
        your function.

    """

    # General observation : this code is executed only once, at creation
    # of the computational graph, so we don't yet need to be smart about
    # anything (to speed things up)

    ##
    # Step 1. Wrap all inputs in dictionaries and add default values
    ##

    # check if inputs are just single variables instead of lists
    def wrap_into_list(x):
        """
        Wrap the input into a list if it is not already a list.

        """
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(outputs_info)

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(non_sequences):
        if not isinstance(elem, Variable):
            non_seqs.append(at.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps (before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op
    # and just apply the inner function once.
    # To do that we check here to see the nature of n_steps.
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = at.get_scalar_constant_value(n_steps)
        except NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in integer_dtypes:
        raise ValueError(
            f"n_steps must be an int. dtype provided is {n_steps.dtype}")

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap sequences in a dictionary if they are not already dictionaries
    for i in range(n_seqs):
        if not isinstance(seqs[i], dict):
            seqs[i] = OrderedDict([("input", seqs[i]), ("taps", [0])])
        elif seqs[i].get("taps", None) is not None:
            seqs[i]["taps"] = wrap_into_list(seqs[i]["taps"])
        elif seqs[i].get("taps", None) is None:
            # seqs dictionary does not have the ``taps`` key
            seqs[i]["taps"] = [0]

    # wrap outputs info in a dictionary if they are not already in one
    for i in range(n_outs):
        if outs_info[i] is not None:
            if isinstance(outs_info[i], dict):
                if outs_info[i].get("return_steps", None) is not None:
                    raise DeprecationWarning(
                        "Using `return_steps` has been deprecated. "
                        "Simply select the entries you need using a "
                        "subtensor. Scan will optimize memory "
                        "consumption, so do not worry about that.")
                # END

            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = OrderedDict([("initial", outs_info[i]),
                                            ("taps", [-1])])
            elif (outs_info[i].get("initial", None) is None
                  and outs_info[i].get("taps", None) is not None):
                # ^ no initial state but taps provided
                raise ValueError("If you are using slices of an output "
                                 "you need to provide an initial state "
                                 f"for it: {outs_info[i]}")
            elif (outs_info[i].get("initial", None) is not None
                  and outs_info[i].get("taps", None) is None):
                # ^ initial state but taps not provided
                if "taps" in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        f"Output {getattr(outs_info[i]['initial'], 'name', 'None')} (index {i}) has an initial "
                        "state but taps is explicitly set to None",
                    )
                outs_info[i]["taps"] = [-1]
            elif outs_info[i].get("taps", None) is not None:
                # Check that taps are valid (< 0 and all different)
                taps = outs_info[i]["taps"]
                if len(taps) > len(set(taps)):
                    raise ValueError(
                        ("All the taps must be different in "
                         " `outputs_info`"),
                        outs_info[i],
                    )
                for t in taps:
                    if t >= 0:
                        raise ValueError(
                            ("All the tap values must be "
                             "smaller than 0."),
                            outs_info[i],
                        )
        else:
            # if a None is provided as the output info we replace it
            # with an empty OrderedDict() to simplify handling
            outs_info[i] = OrderedDict()

    ##
    # Step 2. Generate inputs and outputs of the inner functions
    # for compiling a dummy function (Iteration #1)
    ##

    # create aesara inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we use this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []  # Variables passed as inputs to the scan op
    inner_seqs = []  # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        # Note that you can have something like no taps for
        # a sequence, though this is highly unlikely in practice
        if "taps" in seq:
            # go through the indicated slice
            mintap = np.min(seq["taps"])
            maxtap = np.max(seq["taps"])
            # We cut the sequence such that seq[i] corresponds to
            # seq[i-k]. For the purposes of cutting the sequences, we
            # need to pretend tap 0 is used to avoid cutting the sequences
            # too long if the taps are all lower or all higher than 0.
            maxtap_proxy = max(maxtap, 0)
            mintap_proxy = min(mintap, 0)
            for k in seq["taps"]:
                # create one slice of the input
                # Later on, if we decide not to use scan because we are
                # going for just one step, it makes things easier if we
                # compute the correct outputs here. This way we can use
                # the output of the lambda expression directly to replace
                # the output of scan.

                # If not, we need to use copies that will be replaced at
                # each frame by the corresponding slice
                actual_slice = seq["input"][k - mintap_proxy]
                _seq_val = at.as_tensor_variable(seq["input"])
                _seq_val_slice = _seq_val[k - mintap_proxy]
                nw_slice = _seq_val_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != "off":
                    try:
                        nw_slice.tag.test_value = get_test_value(
                            _seq_val_slice)
                    except TestValueError:
                        if config.compute_test_value != "ignore":
                            # No need to print a warning or raise an error now;
                            # it will be done when `fn` is called.
                            _logger.warning(
                                ("Cannot compute test value for "
                                 "the inner function of scan, input value "
                                 "missing {}").format(_seq_val_slice))

                # Add names to slices for debugging and pretty printing,
                # that is, if the input already has a name
                if getattr(seq["input"], "name", None) is not None:
                    if k > 0:
                        nw_name = seq["input"].name + f"[t+{int(k)}]"
                    elif k == 0:
                        nw_name = seq["input"].name + "[t]"
                    else:
                        nw_name = seq["input"].name + f"[t{int(k)}]"
                    nw_slice.name = nw_name

                start = k - mintap_proxy
                nw_name = None
                if k == maxtap_proxy:
                    nw_seq = seq["input"][start:]
                    if getattr(seq["input"], "name", None) is not None:
                        nw_name = seq["input"].name + f"[{int(start)}:]"
                else:
                    end = -(maxtap_proxy - k)
                    nw_seq = seq["input"][start:end]
                    if getattr(seq["input"], "name", None) is not None:
                        nw_name = seq[
                            "input"].name + f"[{int(start)}:{int(end)}]"

                if go_backwards:
                    nw_seq = nw_seq[::-1]

                scan_seqs.append(nw_seq)
                inner_seqs.append(nw_slice)
                inner_slices.append(actual_slice)
                n_seqs += 1
                # Add names -- it helps a lot when debugging
                if nw_name is not None:
                    nw_seq.name = nw_name

    # Since we've added all sequences, we now need to level them up based on
    # n_steps or their different shapes
    lengths_vec = []
    for seq in scan_seqs:
        lengths_vec.append(seq.shape[0])

    if not utils.isNaN_or_Inf_or_None(n_steps):
        # ^ N_steps should also be considered
        lengths_vec.append(at.as_tensor(n_steps))

    if len(lengths_vec) == 0:
        # ^ No information about the number of steps
        raise ValueError("No information about the number of steps "
                         "provided. Either provide a value for "
                         "n_steps argument of scan or provide an input "
                         "sequence")

    # If the user has provided the number of steps, do that regardless (and
    # raise an error if the sequences are not long enough)
    if utils.isNaN_or_Inf_or_None(n_steps):
        actual_n_steps = lengths_vec[0]
        for contestant in lengths_vec[1:]:
            actual_n_steps = minimum(actual_n_steps, contestant)
    else:
        actual_n_steps = at.as_tensor(n_steps)

    scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
    # Conventions :
    #   mit_mot = multiple input taps, multiple output taps ( only provided
    #             by the gradient function )
    #   mit_sot = multiple input taps, single output tap (t + 0)
    #   sit_sot = single input tap, single output tap (t + 0)
    #   nit_sot = no input tap, single output tap (t + 0)

    # MIT_MOT -- not provided by the user, only by the grad function
    n_mit_mot = 0
    n_mit_mot_outs = 0
    mit_mot_scan_inputs = []
    mit_mot_inner_inputs = []
    mit_mot_inner_outputs = []
    mit_mot_out_slices = []

    # MIT_SOT -- provided by the user
    n_mit_sot = 0
    mit_sot_scan_inputs = []
    mit_sot_inner_inputs = []
    mit_sot_inner_slices = []
    mit_sot_inner_outputs = []
    mit_sot_return_steps = OrderedDict()
    mit_sot_tap_array = []
    mit_sot_rightOrder = []

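    # SIT_SOT -- provided by the user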
    n_sit_sot = 0
    sit_sot_scan_inputs = []
    sit_sot_inner_inputs = []
    sit_sot_inner_slices = []
    sit_sot_inner_outputs = []
    sit_sot_return_steps = OrderedDict()
    sit_sot_rightOrder = []

    # go through outputs picking up time slices as needed
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as an initial state we will only
        # provide a tensor of the same dimension as one time step; this
        # makes the code much cleaner for those who do not use taps. Otherwise
        # they would always have to shape_padleft the initial state,
        # which is ugly.
        if init_out.get("taps", None) == [-1]:

            actual_arg = init_out["initial"]
            if not isinstance(actual_arg, Variable):
                actual_arg = at.as_tensor_variable(actual_arg)
            arg = safe_new(actual_arg)
            if isinstance(arg, Constant):
                # `safe_new` returns a clone of the constants, but that is not
                # what we need for initial states
                arg = arg.type()

            # Try to transfer test_value to the new variable
            if config.compute_test_value != "off":
                try:
                    arg.tag.test_value = get_test_value(actual_arg)
                except TestValueError:
                    if config.compute_test_value != "ignore":
                        _logger.warning(
                            ("Cannot compute test value for the "
                             "inner function of scan, test value missing: {}"
                             ).format(actual_arg))

            if getattr(init_out["initial"], "name", None) is not None:
                arg.name = init_out["initial"].name + "[t-1]"

            # We now need to allocate space for storing the output and copy
            # the initial state over. We do this using the `expand_empty`
            # function defined in scan utils
            sit_sot_scan_inputs.append(
                utils.expand_empty(
                    at.unbroadcast(shape_padleft(actual_arg), 0),
                    actual_n_steps,
                ))

            sit_sot_inner_slices.append(actual_arg)
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1

        elif init_out.get("taps", None):

            if np.any(np.array(init_out.get("taps", [])) > 0):
                # Make sure we do not have requests for future values of an
                # output; we cannot provide such values
                raise ValueError("Can not use future taps of outputs",
                                 init_out)
            # go through the taps
            mintap = abs(np.min(init_out["taps"]))
            mit_sot_tap_array.append(init_out["taps"])
            # Sequence
            mit_sot_scan_inputs.append(
                utils.expand_empty(init_out["initial"][:mintap],
                                   actual_n_steps))

            if i in return_steps:
                mit_sot_return_steps[n_mit_sot] = return_steps[i]
            mit_sot_rightOrder.append(i)
            n_mit_sot += 1
            for k in init_out["taps"]:
                # create a new slice
                actual_nw_slice = init_out["initial"][k + mintap]
                _init_out_var = at.as_tensor_variable(init_out["initial"])
                _init_out_var_slice = _init_out_var[k + mintap]
                nw_slice = _init_out_var_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != "off":
                    try:
                        nw_slice.tag.test_value = get_test_value(
                            _init_out_var_slice)
                    except TestValueError:
                        if config.compute_test_value != "ignore":
                            _logger.warning(
                                ("Cannot compute test value for "
                                 "the inner function of scan, test value "
                                 "missing: {}").format(_init_out_var_slice))

                # give it a name for debugging and pretty printing
                if getattr(init_out["initial"], "name", None) is not None:
                    if k > 0:
                        nw_slice.name = init_out[
                            "initial"].name + f"[t+{int(k)}]"
                    elif k == 0:
                        nw_slice.name = init_out["initial"].name + "[t]"
                    else:
                        nw_slice.name = init_out[
                            "initial"].name + f"[t{int(k)}]"
                mit_sot_inner_inputs.append(nw_slice)
                mit_sot_inner_slices.append(actual_nw_slice)
        # NOTE: there is another case, in which we do not want to provide
        #      any previous value of the output to the inner function (i.e.
        #      a map); in that case we do not have to do anything

    # Re-order args
    max_mit_sot = np.max([-1] + mit_sot_rightOrder) + 1
    max_sit_sot = np.max([-1] + sit_sot_rightOrder) + 1
    n_elems = np.max([max_mit_sot, max_sit_sot])
    _ordered_args = [[] for x in range(n_elems)]
    offset = 0
    for idx in range(n_mit_sot):
        n_inputs = len(mit_sot_tap_array[idx])
        if n_fixed_steps in (1, -1):
            _ordered_args[
                mit_sot_rightOrder[idx]] = mit_sot_inner_slices[offset:offset +
                                                                n_inputs]
        else:
            _ordered_args[
                mit_sot_rightOrder[idx]] = mit_sot_inner_inputs[offset:offset +
                                                                n_inputs]
        offset += n_inputs

    for idx in range(n_sit_sot):
        if n_fixed_steps in (1, -1):
            _ordered_args[sit_sot_rightOrder[idx]] = [
                sit_sot_inner_slices[idx]
            ]
        else:
            _ordered_args[sit_sot_rightOrder[idx]] = [
                sit_sot_inner_inputs[idx]
            ]

    ordered_args = []
    for ls in _ordered_args:
        ordered_args += ls
    if n_fixed_steps in (1, -1):
        args = inner_slices + ordered_args + non_seqs

    else:
        args = inner_seqs + ordered_args + non_seqs

    # add only the non-shared variables and non-constants to the arguments of
    # the dummy function [ a function should not get shared variables or
    # constants as input ]
    dummy_args = [
        arg for arg in args if (not isinstance(arg, SharedVariable)
                                and not isinstance(arg, Constant))
    ]
    # when we apply the lambda expression we get a mixture of update rules
    # and outputs that need to be separated

    condition, outputs, updates = utils.get_updates_and_outputs(fn(*args))
    if condition is not None:
        as_while = True
    else:
        as_while = False
    ##
    # Step 3. Check if we actually need scan and remove it if we don't
    ##

    if n_fixed_steps in (1, -1):
        for pos, inner_out in enumerate(outputs):
            # we need to see if we need to pad our sequences with an
            # unbroadcastable dimension; case example: we return an
            # output for which we want all intermediate values. If n_steps is
            # 1, then, if we return the output as given by the inner function,
            # this will represent only a slice and it will have one
            # dimension less.
            if isinstance(inner_out.type,
                          TensorType) and return_steps.get(pos, 0) != 1:
                outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0)

        if not return_list and len(outputs) == 1:
            outputs = outputs[0]

        return (outputs, updates)

    ##
    # Step 4. Compile the dummy function
    ##

    # We can now compile a dummy function just to see what shared variables
    # we have and what their update rules are (note that the user has
    # the option not to pass the shared variables to scan, so we need to
    # pick them up manually and add them to scan)

    # extract still missing inputs (there still might be some) and add them
    # as non-sequences at the end of our args
    if condition is not None:
        outputs.append(condition)
    fake_nonseqs = [x.type() for x in non_seqs]
    fake_outputs = clone_replace(outputs,
                                 replace=OrderedDict(
                                     zip(non_seqs, fake_nonseqs)))
    all_inputs = filter(
        lambda x: (isinstance(x, Variable) and not isinstance(
            x, SharedVariable) and not isinstance(x, Constant)),
        graph_inputs(fake_outputs),
    )
    extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
    non_seqs += extra_inputs
    # Note we do not use all_inputs directly since the order of variables
    # in args is quite important
    dummy_args += extra_inputs

    dummy_outs = outputs
    # Perform a try-except to provide a meaningful error message to the
    # user if inputs of the inner function are missing.
    try:
        dummy_inputs, dummy_outputs = construct_pfunc_ins_and_outs(
            dummy_args, dummy_outs, updates=updates)
    except MissingInputError as err:
        msg = ("\nPlease pass this variable to the scan's inner function. Do "
               "not forget to also pass it to the `non_sequences` attribute "
               "of scan.")
        raise MissingInputError(err.args[0] + msg)
    ##
    # Step 5. Re-arrange inputs of scan into a more strict order
    ##

    # Step 5.0 Check the outputs of the dummy function to see if they
    # match with user provided data

    # if the number of outputs to the function does not match the number of
    # assumed outputs until now (provided by the user) there can be
    # only one explanation: No information is provided for any of the
    # outputs (i.e. we are dealing with a map)
    tmp_dummy_f_outs = len(dummy_outputs)
    if as_while:
        tmp_dummy_f_outs -= 1
    if not (tmp_dummy_f_outs == n_outs or outs_info == []):
        raise ValueError("Please provide None as outputs_info for "
                         "any output that does not feed back into "
                         "scan (i.e. it behaves like a map) ")

    if outs_info == []:
        n_outs = len(dummy_outputs)
        if as_while:
            n_outs = n_outs - 1
        outs_info = [OrderedDict() for x in range(n_outs)]

    # Step 5.1 Outputs with taps different than -1

    for i, out in enumerate(outs_info):
        if "taps" in out and out["taps"] != [-1]:
            mit_sot_inner_outputs.append(outputs[i])

    # Step 5.2 Outputs with tap equal to -1
    for i, out in enumerate(outs_info):
        if "taps" in out and out["taps"] == [-1]:
            sit_sot_inner_outputs.append(outputs[i])

    # Step 5.3 Outputs that correspond to update rules of shared variables
    givens = OrderedDict()
    n_shared_outs = 0
    shared_scan_inputs = []
    shared_inner_inputs = []
    shared_inner_outputs = []
    sit_sot_shared = []
    for input in dummy_inputs:
        if isinstance(input.variable, SharedVariable) and input.update:
            new_var = safe_new(input.variable)
            if getattr(input.variable, "name", None) is not None:
                new_var.name = input.variable.name + "_copy"
            if isinstance(new_var.type, TensorType):
                sit_sot_inner_inputs.append(new_var)
                sit_sot_scan_inputs.append(
                    utils.expand_empty(
                        at.unbroadcast(shape_padleft(input.variable), 0),
                        actual_n_steps,
                    ))
                tensor_update = at.as_tensor_variable(input.update)
                sit_sot_inner_outputs.append(tensor_update)
                # Note that a negative `pos` is not a negative index: the sign
                # of `pos` is used as a flag to indicate whether this output
                # should be part of the update rules or part of the standard
                # outputs of `scan`. If `pos` is positive, then it corresponds
                # to the standard outputs of `scan` and refers to the output
                # with index `pos`. If `pos` is negative, then it corresponds
                # to the update rules of `scan` and refers to the update rule
                # with index `-1 - pos`.
                sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
                sit_sot_shared.append(input.variable)
                givens[input.variable] = new_var

            else:
                shared_inner_inputs.append(new_var)
                shared_scan_inputs.append(input.variable)
                shared_inner_outputs.append(input.update)
                givens[input.variable] = new_var
                n_shared_outs += 1

    n_sit_sot = len(sit_sot_inner_inputs)

    # Step 5.4 Outputs with no taps used in the input
    n_nit_sot = 0
    nit_sot_inner_outputs = []
    nit_sot_return_steps = OrderedDict()
    nit_sot_rightOrder = []
    for i, out in enumerate(outs_info):
        if "taps" not in out:
            nit_sot_inner_outputs.append(outputs[i])
            if i in return_steps:
                nit_sot_return_steps[n_nit_sot] = return_steps[i]
            nit_sot_rightOrder.append(i)
            n_nit_sot += 1

    # Step 5.5 all other arguments including extra inputs
    other_scan_args = []
    other_inner_args = []

    other_scan_args += [
        arg for arg in non_seqs if (not isinstance(arg, SharedVariable)
                                    and not isinstance(arg, Constant))
    ]

    # Step 5.6 all shared variables with no update rules
    other_inner_args += [
        safe_new(arg, "_copy") for arg in non_seqs if
        (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
    ]

    givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))

    if strict:
        non_seqs_set = set(non_sequences if non_sequences is not None else [])

        other_shared_scan_args = [
            arg.variable for arg in dummy_inputs
            if (isinstance(arg.variable, SharedVariable) and not arg.update
                and arg.variable in non_seqs_set)
        ]
        other_shared_inner_args = [
            safe_new(arg.variable, "_copy") for arg in dummy_inputs
            if (isinstance(arg.variable, SharedVariable) and not arg.update
                and arg.variable in non_seqs_set)
        ]
    else:
        other_shared_scan_args = [
            arg.variable for arg in dummy_inputs
            if (isinstance(arg.variable, SharedVariable) and not arg.update)
        ]
        other_shared_inner_args = [
            safe_new(arg.variable, "_copy") for arg in dummy_inputs
            if (isinstance(arg.variable, SharedVariable) and not arg.update)
        ]
    givens.update(
        OrderedDict(zip(other_shared_scan_args, other_shared_inner_args)))

    ##
    # Step 6. Re-order the outputs and clone them replacing things
    # using the givens
    ##
    inner_inputs = (inner_seqs + mit_mot_inner_inputs + mit_sot_inner_inputs +
                    sit_sot_inner_inputs + shared_inner_inputs +
                    other_shared_inner_args + other_inner_args)

    inner_outs = (mit_mot_inner_outputs + mit_sot_inner_outputs +
                  sit_sot_inner_outputs + nit_sot_inner_outputs +
                  shared_inner_outputs)
    if condition is not None:
        inner_outs.append(condition)
    # gpuarray is imported here, instead of at the top of the file, because
    # that would force on the user some dependencies that we might not want
    # to impose. Currently we are working on removing the dependencies on
    # sandbox code completely.
    from aesara import gpuarray

    if gpuarray.pygpu_activated:
        # very often we end up in this situation when we want to
        # replace w with w_copy, where w is a GPU variable
        # and w_copy is a TensorType. This is caused by shared
        # variables being put on the GPU right away.
        new_givens = OrderedDict()

        for w, w_copy in givens.items():
            if isinstance(w.type, gpuarray.GpuArrayType) and isinstance(
                    w_copy.type, TensorType):
                for o in inner_outs:
                    new_givens = traverse(o, w, w_copy, new_givens)
            else:
                new_givens[w] = w_copy
    else:
        new_givens = givens

    new_outs = clone_replace(inner_outs, replace=new_givens)

    ##
    # Step 7. Create the Scan Op
    ##

    tap_array = tuple(tuple(v) for v in mit_sot_tap_array) + tuple(
        (-1, ) for x in range(n_sit_sot))
    if allow_gc is None:
        allow_gc = config.scan__allow_gc

    info = ScanInfo(
        tap_array=tap_array,
        n_seqs=n_seqs,
        n_mit_mot=n_mit_mot,
        n_mit_mot_outs=n_mit_mot_outs,
        mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
        n_mit_sot=n_mit_sot,
        n_sit_sot=n_sit_sot,
        n_shared_outs=n_shared_outs,
        n_nit_sot=n_nit_sot,
    )

    local_op = Scan(
        inner_inputs,
        new_outs,
        info,
        mode=mode,
        truncate_gradient=truncate_gradient,
        name=name,
        gpua=False,
        as_while=as_while,
        profile=profile,
        allow_gc=allow_gc,
        strict=strict,
    )

    ##
    # Step 8. Compute the outputs using the scan op
    ##
    _scan_inputs = (scan_seqs + mit_mot_scan_inputs + mit_sot_scan_inputs +
                    sit_sot_scan_inputs + shared_scan_inputs +
                    [actual_n_steps for x in range(n_nit_sot)] +
                    other_shared_scan_args + other_scan_args)

    scan_inputs = []
    for arg in [actual_n_steps] + _scan_inputs:
        try:
            arg = at.as_tensor_variable(arg)
        except TypeError:
            # This happens, e.g., for random states, but it is a good way
            # to make sure all inputs are tensors.
            pass
        scan_inputs += [arg]
    scan_outs = local_op(*scan_inputs)
    if not isinstance(scan_outs, (list, tuple)):
        scan_outs = [scan_outs]
    ##
    # Step 9. Figure out which outs are update rules for shared variables
    # and so on ...
    ##

    update_map = OrderedUpdates()

    def remove_dimensions(outs, steps_return, offsets=None):
        out_ls = []
        for idx, out in enumerate(outs):
            if idx in steps_return:
                if steps_return[idx] > 1:
                    out_ls.append(out[-steps_return[idx]:])
                else:
                    out_ls.append(out[-1])
            else:
                if offsets is None:
                    out_ls.append(out)
                else:
                    out_ls.append(out[offsets[idx]:])
        return out_ls

    offset = n_mit_mot
    offsets = [abs(np.min(x)) for x in mit_sot_tap_array]
    mit_sot_outs = remove_dimensions(scan_outs[offset:offset + n_mit_sot],
                                     mit_sot_return_steps, offsets)

    offset += n_mit_sot
    offsets = [1 for x in range(n_sit_sot)]
    sit_sot_outs = remove_dimensions(scan_outs[offset:offset + n_sit_sot],
                                     sit_sot_return_steps, offsets)

    offset += n_sit_sot
    nit_sot_outs = remove_dimensions(scan_outs[offset:offset + n_nit_sot],
                                     nit_sot_return_steps)

    offset += n_nit_sot
    for idx, update_rule in enumerate(scan_outs[offset:offset +
                                                n_shared_outs]):
        update_map[shared_scan_inputs[idx]] = update_rule

    _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
    # Step 10. We need to reorder the outputs to be in the order expected by
    # the user
    rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
    scan_out_list = [None] * len(rightOrder)
    for idx, pos in enumerate(rightOrder):
        if pos >= 0:
            scan_out_list[pos] = _scan_out_list[idx]
        else:
            # Note that a negative `pos` is not a negative index: the sign of
            # `pos` is used as a flag to indicate whether this output should be
            # part of the update rules or part of the standard outputs of scan.
            # If `pos` is positive, then it corresponds to the standard
            # outputs of scan and refers to the output with index `pos`. If
            # `pos` is negative, then it corresponds to the update rules of
            # scan and refers to the update rule with index `-1 - pos`.
            update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
    scan_out_list = [x for x in scan_out_list if x is not None]
    if not return_list and len(scan_out_list) == 1:
        scan_out_list = scan_out_list[0]
    elif len(scan_out_list) == 0:
        scan_out_list = None

    return (scan_out_list, update_map)
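
A minimal usage sketch of the `scan` function above (an illustration with standard Aesara imports; a plain cumulative sum over a vector):

import aesara
import aesara.tensor as at

xs = at.vector("xs")

# One sequence and one sit-sot output: acc[t] = acc[t-1] + xs[t]
outputs, updates = aesara.scan(
    fn=lambda x_t, acc: acc + x_t,
    sequences=[xs],
    outputs_info=[at.zeros(())],
)

cumsum = aesara.function([xs], outputs, updates=updates)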
Example #5
    def __init__(
        self,
        inputs: List[Variable],
        outputs: List[Variable],
        inline: bool = False,
        lop_overrides: str = "default",
        grad_overrides: str = "default",
        rop_overrides: str = "default",
        connection_pattern: Optional[List[List[bool]]] = None,
        name: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        inputs
            The inputs to the graph.
        outputs
            The outputs to the graph.
        inline
            Defaults to ``False``.

            ``True``: Use the :class:`Op`'s original graph during
            compilation; the :class:`Op` itself will not be visible in the
            compiled graph, only its internal graph.

            ``False``: Use a pre-compiled function inside.
        grad_overrides
            Defaults to ``'default'``.
            This argument is mutually exclusive with ``lop_overrides``.

            ``'default'``: Do not override; use the default :meth:`Op.grad` result.

            `OpFromGraph`: Override with another `OpFromGraph`; it should
            accept inputs in the same order and with the same types as the
            ``inputs`` and ``output_grads`` arguments one would receive in
            the :meth:`Op.grad` method.

            `callable`: Should take two args: ``inputs`` and ``output_grads``.
            Each argument is expected to be a list of :class:`Variable`.
            Must return a list of :class:`Variable`.
        lop_overrides
            Defaults to ``'default'``.

            This argument is mutually exclusive with ``grad_overrides``.

            These options are similar to the ``grad_overrides`` above, but for
            the :meth:`Op.L_op` method.

            ``'default'``: Do not override, use the default :meth:`Op.L_op` result

            `OpFromGraph`: Override with another `OpFromGraph`; it should
            accept inputs in the same order and with the same types as the
            ``inputs``, ``outputs`` and ``output_grads`` arguments one would
            receive in the :meth:`Op.grad` method.

            `callable`: Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
            Each argument is expected to be a list of :class:`Variable`.
            Must return list of :class:`Variable`.

            `NullType` instance: Treat as non-differentiable.

            `DisconnectedType` instance: Treat as disconnected gradient;
            numerically gives zero.

            ``list``: Each `OpFromGraph`/callable must return a single
            :class:`Variable`. Each list element corresponds to the gradient of
            a specific input; the length of the list must equal the number of
            inputs.

        rop_overrides
            One of ``{'default', OpFromGraph, callable, Variable}``.

            Defaults to ``'default'``.

            ``'default'``: Do not override, use the default :meth:`Op.R_op` result

            `OpFromGraph`: Override with another `OpFromGraph`; it should
            accept inputs in the same order and with the same types as the
            ``inputs`` and ``eval_points`` arguments one would receive in the
            :meth:`Op.R_op` method.

            `callable`: Should take two args: ``inputs`` and ``eval_points``.
            Each argument is expected to be a list of :class:`Variable`.  Must
            return list of :class:`Variable`.

            `NullType` instance: Treat as non-differentiable.

            `DisconnectedType` instance: Treat as zero, since `DisconnectedType`
            is not yet supported in :meth:`Op.R_op`.

            ``list``: Each :class:`OpFromGraph`/callable must return a single
            :class:`Variable <aesara.graph.basic.Variable>`.  Each list element
            corresponds to a specific output of :meth:`Op.R_op`; the length of
            the list must be equal to the number of outputs.
        connection_pattern
            If not ``None``, this will be used as the connection_pattern for
            this :class:`Op`.
        name
            A name for debugging purposes.
        kwargs
            See :func:`orig_function` for more arguments.  These only apply
            when the `Op` is not inlined.
        """

        if not (isinstance(inputs, list) and isinstance(outputs, list)):
            raise TypeError("Inputs and outputs must be lists")

        for i in inputs + outputs:
            if not isinstance(i, Variable):
                raise TypeError(
                    f"Inputs and outputs must be Variable instances; got {i}")
            if i in inputs and isinstance(i, Constant):
                raise TypeError(f"Constants not allowed as inputs; {i}")

        if "updates" in kwargs or "givens" in kwargs:
            raise NotImplementedError(
                "Updates and givens are not allowed here")

        self.is_inline = inline
        # To correctly support shared variables, the inner function should
        # not see them; otherwise there is a problem with the gradient.
        self.shared_inputs = [
            var for var in graph_inputs(outputs)
            if isinstance(var, SharedVariable)
        ]
        shared_vars = [var.type() for var in self.shared_inputs]

        new = rebuild_collect_shared(
            outputs,
            inputs=inputs + shared_vars,
            replace=dict(zip(self.shared_inputs, shared_vars)),
            copy_inputs_over=False,
        )
        (
            local_inputs,
            local_outputs,
            [clone_d, update_d, update_expr, shared_inputs],
        ) = new
        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self._inner_inputs = local_inputs
        self._inner_outputs = local_outputs
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]
        if lop_overrides != "default":
            if grad_overrides != "default":
                raise ValueError(
                    "lop_overrides and grad_overrides are mutually exclusive")
            else:
                self.set_lop_overrides(lop_overrides)
                self._lop_type = "lop"
        elif grad_overrides != "default":
            self.set_lop_overrides(grad_overrides)
            self._lop_type = "grad"
        else:
            self.set_lop_overrides("default")
            self._lop_type = "lop"
        self.set_rop_overrides(rop_overrides)

        self._connection_pattern = connection_pattern

        if name is not None:
            assert isinstance(name, str), "name must be None or string object"
        self.name = name
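A minimal usage sketch for the constructor above, showing the callable form of ``grad_overrides``; the names ``x``, ``y``, ``custom_grad`` and the surrounding graph are illustrative only.

import aesara
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x, y = at.scalars("x", "y")
z = x * y

def custom_grad(inputs, output_grads):
    # Receives the op inputs and one gradient per output; must return one
    # gradient expression per input.
    x_, y_ = inputs
    (g_z,) = output_grads
    return [g_z * y_, g_z * x_]

mul_op = OpFromGraph([x, y], [z], grad_overrides=custom_grad)

a, b = at.scalars("a", "b")
da, db = aesara.grad(mul_op(a, b), [a, b])
f = aesara.function([a, b], [da, db])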
示例#6
0
    def _set_row_mappings(self, Gamma, dir_priors, model):
        """Create maps from Dirichlet priors parameters to rows and slices in the transition matrix.

        These maps are needed when a transition matrix isn't simply comprised
        of Dirichlet prior rows, but instead contains slices of Dirichlet priors.

        Consider the following:

        .. code-block:: python

            with pm.Model():
                d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
                d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

                p_0_rv = tt.as_tensor([0, 0, 1])
                p_1_rv = tt.zeros(3)
                p_1_rv = tt.set_subtensor(p_1_rv[[0, 2]], d_0_rv)
                p_2_rv = tt.zeros(3)
                p_2_rv = tt.set_subtensor(p_2_rv[[1, 2]], d_1_rv)

                P_tt = tt.stack([p_0_rv, p_1_rv, p_2_rv])

        The transition matrix `P_tt` has Dirichlet priors in only two of its
        three rows, and--even then--they're only present in parts of two rows.

        In this example, we need to know that Dirichlet prior 0, i.e. `d_0_rv`,
        is mapped to row 1, and prior 1 is mapped to row 2.  Furthermore, we
        need to know that prior 0 fills columns 0 and 2 in row 1, and prior 1
        fills columns 1 and 2 in row 2.

        These mappings allow one to embed Dirichlet priors in larger transition
        matrices with--for instance--fixed transition behavior.

        """  # noqa: E501

        # Remove unimportant `Op`s from the transition matrix graph
        Gamma = pre_greedy_local_optimizer(
            FunctionGraph([], []),
            [
                OpRemove(Elemwise(aes.Cast(aes.float32))),
                OpRemove(Elemwise(aes.Cast(aes.float64))),
                OpRemove(Elemwise(aes.identity)),
            ],
            Gamma,
        )

        # Canonicalize the transition matrix graph
        fg = FunctionGraph(
            list(graph_inputs([Gamma] + self.dir_priors_untrans)),
            [Gamma] + self.dir_priors_untrans,
            clone=True,
        )
        canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
        canonicalize_opt.optimize(fg)
        Gamma = fg.outputs[0]
        dir_priors_untrans = fg.outputs[1:]
        fg.disown()

        Gamma_DimShuffle = Gamma.owner

        if not isinstance(Gamma_DimShuffle.op, DimShuffle):
            raise TypeError("The transition matrix should be non-time-varying")

        Gamma_Join = Gamma_DimShuffle.inputs[0].owner

        if not isinstance(Gamma_Join.op, at.basic.Join):
            raise TypeError(
                "The transition matrix should be comprised of stacked row vectors"
            )

        Gamma_rows = Gamma_Join.inputs[1:]

        self.n_rows = len(Gamma_rows)

        # Loop through the rows in the transition matrix's graph and determine
        # how our transformed Dirichlet RVs map to this transition matrix.
        self.row_remaps = {}
        self.row_slices = {}
        for i, dim_row in enumerate(Gamma_rows):
            if not dim_row.owner:
                continue

            # By-pass the `DimShuffle`s applied to the `AdvancedIncSubtensor1`
            # `Op`s in which we're actually interested
            gamma_row = dim_row.owner.inputs[0]

            if gamma_row in dir_priors_untrans:
                # This is a row that's simply a `Dirichlet`
                j = dir_priors_untrans.index(gamma_row)
                self.row_remaps[j] = i
                self.row_slices[j] = slice(None)

            if gamma_row.owner.inputs[1] not in dir_priors_untrans:
                continue

            # Parts of a row set by a `*Subtensor*` `Op` using a full
            # `Dirichlet` e.g. `P_row[idx] = dir_rv`
            j = dir_priors_untrans.index(gamma_row.owner.inputs[1])
            untrans_dirich = dir_priors_untrans[j]

            if (gamma_row.owner
                    and isinstance(gamma_row.owner.op, AdvancedIncSubtensor1)
                    and gamma_row.owner.inputs[1] == untrans_dirich):
                self.row_remaps[j] = i

                rhand_val = gamma_row.owner.inputs[2]
                if not isinstance(rhand_val, TensorConstant):
                    # TODO: We could allow more types of `idx` (e.g. slices)
                    # Currently, `idx` can't be something like `2:5`
                    raise TypeError("Only array indexing allowed for mixed"
                                    " Dirichlet/non-Dirichlet rows")
                self.row_slices[j] = rhand_val.data
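Going by the docstring's example, and assuming the loop above runs as intended, the resulting mappings would look roughly like the following (concrete values are shown for illustration only):

import numpy as np

# Dirichlet prior 0 fills columns [0, 2] of transition-matrix row 1;
# prior 1 fills columns [1, 2] of row 2.  Row 0 carries no prior at all.
row_remaps = {0: 1, 1: 2}
row_slices = {0: np.r_[0, 2], 1: np.r_[1, 2]}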
示例#7
0
    def __call__(self, fct, graph=None):
        """Create pydot graph from function.

        Parameters
        ----------
        fct : aesara.compile.function.types.Function
            A compiled Aesara function, variable, apply or a list of variables.
        graph: pydot.Dot
            `pydot` graph to which nodes are added. Creates new one if
            undefined.

        Returns
        -------
        pydot.Dot
            Pydot graph of `fct`
        """
        if graph is None:
            graph = pd.Dot()

        self.__nodes = {}

        profile = None

        if isinstance(fct, Function):
            profile = getattr(fct, "profile", None)
            fgraph = fct.maker.fgraph
        elif isinstance(fct, FunctionGraph):
            fgraph = fct
        else:
            if isinstance(fct, Variable):
                fct = [fct]
            elif isinstance(fct, Apply):
                fct = fct.outputs
            assert isinstance(fct, (list, tuple))
            assert all(isinstance(v, Variable) for v in fct)
            fgraph = FunctionGraph(inputs=graph_inputs(fct), outputs=fct)

        outputs = fgraph.outputs
        topo = fgraph.toposort()
        outputs = list(outputs)

        # Loop over apply nodes
        for node in topo:
            nparams = {}
            __node_id = self.__node_id(node)
            nparams["name"] = __node_id
            nparams["label"] = apply_label(node)
            nparams["profile"] = apply_profile(fgraph, node, profile)
            nparams["node_type"] = "apply"
            nparams["apply_op"] = nparams["label"]
            nparams["shape"] = self.shapes["apply"]

            use_color = None
            for opName, color in self.apply_colors.items():
                if opName in node.op.__class__.__name__:
                    use_color = color
            if use_color:
                nparams["style"] = "filled"
                nparams["fillcolor"] = use_color
                nparams["type"] = "colored"

            pd_node = dict_to_pdnode(nparams)
            graph.add_node(pd_node)

            # Loop over input nodes
            for id, var in enumerate(node.inputs):
                var_id = self.__node_id(var.owner if var.owner else var)
                if var.owner is None:
                    vparams = {
                        "name": var_id,
                        "label": var_label(var),
                        "node_type": "input",
                    }
                    if isinstance(var, Constant):
                        vparams["node_type"] = "constant_input"
                    elif isinstance(
                            var, aesara.tensor.sharedvar.TensorSharedVariable):
                        vparams["node_type"] = "shared_input"
                    vparams["dtype"] = type_to_str(var.type)
                    vparams["tag"] = var_tag(var)
                    vparams["style"] = "filled"
                    vparams["fillcolor"] = self.node_colors[
                        vparams["node_type"]]
                    vparams["shape"] = self.shapes["input"]
                    pd_var = dict_to_pdnode(vparams)
                    graph.add_node(pd_var)

                edge_params = {}
                if node.op.view_map and id in reduce(
                        list.__add__, node.op.view_map.values(), []):
                    edge_params["color"] = self.node_colors["output"]
                elif node.op.destroy_map and id in reduce(
                        list.__add__, node.op.destroy_map.values(), []):
                    edge_params["color"] = "red"

                edge_label = vparams["dtype"]
                if len(node.inputs) > 1:
                    edge_label = str(id) + " " + edge_label
                pdedge = pd.Edge(var_id,
                                 __node_id,
                                 label=edge_label,
                                 **edge_params)
                graph.add_edge(pdedge)

            # Loop over output nodes
            for id, var in enumerate(node.outputs):
                var_id = self.__node_id(var)

                if var in outputs or len(fgraph.clients[var]) == 0:
                    vparams = {
                        "name": var_id,
                        "label": var_label(var),
                        "node_type": "output",
                        "dtype": type_to_str(var.type),
                        "tag": var_tag(var),
                        "style": "filled",
                    }
                    if len(fgraph.clients[var]) == 0:
                        vparams["fillcolor"] = self.node_colors["unused"]
                    else:
                        vparams["fillcolor"] = self.node_colors["output"]
                    vparams["shape"] = self.shapes["output"]
                    pd_var = dict_to_pdnode(vparams)
                    graph.add_node(pd_var)

                    graph.add_edge(
                        pd.Edge(__node_id, var_id, label=vparams["dtype"]))
                elif var.name or not self.compact:
                    graph.add_edge(
                        pd.Edge(__node_id, var_id, label=vparams["dtype"]))

            # Create sub-graph for OpFromGraph nodes
            if isinstance(node.op, builders.OpFromGraph):
                subgraph = pd.Cluster(__node_id)
                gf = PyDotFormatter()
                # Use different node prefix for sub-graphs
                gf.__node_prefix = __node_id
                gf(node.op.fn, subgraph)
                graph.add_subgraph(subgraph)
                pd_node.get_attributes()["subg"] = subgraph.get_name()

                def format_map(m):
                    return str([list(x) for x in m])

                # Inputs mapping
                ext_inputs = [self.__node_id(x) for x in node.inputs]
                int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
                assert len(ext_inputs) == len(int_inputs)
                h = format_map(zip(ext_inputs, int_inputs))
                pd_node.get_attributes()["subg_map_inputs"] = h

                # Outputs mapping
                ext_outputs = [self.__node_id(x) for x in node.outputs]
                int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
                assert len(ext_outputs) == len(int_outputs)
                h = format_map(zip(int_outputs, ext_outputs))
                pd_node.get_attributes()["subg_map_outputs"] = h

        return graph
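A minimal usage sketch for the formatter above, assuming this ``__call__`` belongs to the `PyDotFormatter` class from `aesara.d3viz.formatting` and that graphviz/pydot are installed:

import aesara
import aesara.tensor as at
from aesara.d3viz.formatting import PyDotFormatter  # assumed location

x = at.vector("x")
f = aesara.function([x], at.exp(x).sum())

# Build the pydot graph for the compiled function and render it
graph = PyDotFormatter()(f)
graph.write_png("fct_graph.png")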
示例#8
0
File: gibbs.py  Project: YRApril/LiJia
def elemwise_logp(model, var):
    terms = [
        v.logp_elemwiset for v in model.basic_RVs
        if var in graph_inputs([v.logpt])
    ]
    return model.fn(add(*terms))
示例#9
0
def test_logpt_incsubtensor(indices, size):
    """Make sure we can compute a log-likelihood for ``Y[idx] = data`` where ``Y`` is univariate."""

    mu = floatX(np.power(10, np.arange(np.prod(size)))).reshape(size)
    data = mu[indices]
    sigma = 0.001
    rng = aesara.shared(np.random.RandomState(232), borrow=True)

    a = Normal.dist(mu, sigma, size=size, rng=rng)
    a.name = "a"

    a_idx = at.set_subtensor(a[indices], data)

    assert isinstance(
        a_idx.owner.op,
        (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))

    a_idx_value_var = a_idx.type()
    a_idx_value_var.name = "a_idx_value"

    a_idx_logp = logpt(a_idx, a_idx_value_var)

    logp_vals = a_idx_logp.eval()

    # The indices that were set should all have the same log-likelihood values,
    # because the values they were set to correspond to the unique means along
    # that dimension.  This helps us confirm that the log-likelihood is
    # associating the assigned values with their correct parameters.
    exp_obs_logps = sp.norm.logpdf(mu, mu, sigma)[indices]
    np.testing.assert_almost_equal(logp_vals[indices], exp_obs_logps)

    # Next, we need to confirm that the unset indices are being sampled
    # from the original random variable in the correct locations.
    # rng.get_value(borrow=True).seed(232)

    res_ancestors = list(walk_model((a_idx_logp, ), walk_past_rvs=True))
    res_rv_ancestors = tuple(
        v for v in res_ancestors
        if v.owner and isinstance(v.owner.op, RandomVariable))

    # The imputed missing values are drawn from the original distribution
    (a_new, ) = res_rv_ancestors
    assert a_new is not a
    assert a_new.owner.op == a.owner.op

    fg = FunctionGraph(
        [
            v for v in graph_inputs((a_idx_logp, ))
            if not isinstance(v, Constant)
        ],
        [a_idx_logp],
        clone=False,
    )

    ((a_client, _), ) = fg.clients[a_new]
    # The imputed values should be treated as constants when gradients are
    # taken
    assert isinstance(a_client.op, DisconnectedGrad)

    ((a_client, _), ) = fg.clients[a_client.outputs[0]]
    assert isinstance(
        a_client.op,
        (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))
    # Compare the indices recovered from the graph with the original `indices`
    rec_indices = tuple(i.eval() for i in a_client.inputs[2:])
    np.testing.assert_almost_equal(rec_indices, indices)
示例#10
0
def logpt(
    var: Union[TensorVariable, List[TensorVariable]],
    rv_values: Optional[Union[TensorVariable, Dict[TensorVariable,
                                                   TensorVariable]]] = None,
    *,
    jacobian: bool = True,
    scaling: bool = True,
    transformed: bool = True,
    sum: bool = True,
    **kwargs,
) -> Union[TensorVariable, List[TensorVariable]]:
    """Create a measure-space (i.e. log-likelihood) graph for a random variable
    or a list of random variables at a given point.

    The input `var` determines which log-likelihood graph is used and
    `rv_value` is that graph's input parameter.  For example, if `var` is
    the output of a ``NormalRV`` ``Op``, then the output is a graph of the
    density function for `var` set to the value `rv_value`.

    Parameters
    ==========
    var
        The `RandomVariable` output that determines the log-likelihood graph.
        Can also be a list of variables. The final log-likelihood graph will
        be the sum total of all individual log-likelihood graphs of variables
        in the list.
    rv_values
        A variable, or ``dict`` of variables, that represents the value of
        `var` in its log-likelihood.  If no `rv_value` is provided,
        ``var.tag.value_var`` will be checked and, when available, used.
    jacobian
        Whether or not to include the Jacobian term.
    scaling
        A scaling term to apply to the generated log-likelihood graph.
    transformed
        Apply transforms.
    sum
        Sum the log-likelihood or return each term as a separate list item.

    """
    # TODO: In future when we drop support for tag.value_var most of the following
    # logic can be removed and logpt can just be a wrapper function that calls aeppl's
    # joint_logprob directly.

    # If var is not a list make it one.
    if not isinstance(var, (list, tuple)):
        var = [var]

    # If no values are provided to `logpt`, it is assumed that the tagged value
    # variable or observation is the value variable for that particular RV.
    if rv_values is None:
        rv_values = {}
        for rv in var:
            value_var = getattr(rv.tag, "observations",
                                getattr(rv.tag, "value_var", None))
            if value_var is None:
                raise ValueError(f"No value variable found for var {rv}")
            rv_values[rv] = value_var
    # Else we assume we were given a single rv and respective value
    elif not isinstance(rv_values, Mapping):
        if len(var) == 1:
            rv_values = {
                var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)
            }
        else:
            raise ValueError(
                "rv_values must be a dict if more than one var is requested")

    if scaling:
        rv_scalings = {}
        for rv, value_var in rv_values.items():
            rv_scalings[value_var] = _get_scaling(
                getattr(rv.tag, "total_size", None), value_var.shape,
                value_var.ndim)

    # Aeppl needs all rv-values pairs, not just that of the requested var.
    # Hence we iterate through the graph to collect them.
    tmp_rvs_to_values = rv_values.copy()
    for node in io_toposort(graph_inputs(var), var):
        try:
            curr_vars = [node.default_output()]
        except ValueError:
            curr_vars = node.outputs
        for curr_var in curr_vars:
            if curr_var in tmp_rvs_to_values:
                continue
            # Check if variable has a value variable
            value_var = getattr(curr_var.tag, "observations",
                                getattr(curr_var.tag, "value_var", None))
            if value_var is not None:
                tmp_rvs_to_values[curr_var] = value_var

    # After collecting all necessary rvs and values, we check for any value transforms
    transform_map = {}
    if transformed:
        for rv, value_var in tmp_rvs_to_values.items():
            if hasattr(value_var.tag, "transform"):
                transform_map[value_var] = value_var.tag.transform
            # If the provided value_variable does not have transform information, we
            # check if the original `rv.tag.value_var` does.
            # TODO: This logic should be replaced by an explicit dict of
            #  `{value_var: transform}` similar to `rv_values`.
            else:
                original_value_var = getattr(rv.tag, "value_var", None)
                if original_value_var is not None and hasattr(
                        original_value_var.tag, "transform"):
                    transform_map[value_var] = original_value_var.tag.transform

    transform_opt = TransformValuesOpt(transform_map)
    temp_logp_var_dict = factorized_joint_logprob(tmp_rvs_to_values,
                                                  extra_rewrites=transform_opt,
                                                  use_jacobian=jacobian,
                                                  **kwargs)

    # aeppl returns the logpt for every single value term we provided to it. This includes
    # the extra values we plugged in above, so we filter those we actually wanted in the
    # same order they were given in.
    logp_var_dict = {}
    for value_var in rv_values.values():
        logp_var_dict[value_var] = temp_logp_var_dict[value_var]

    if scaling:
        for value_var in logp_var_dict.keys():
            if value_var in rv_scalings:
                logp_var_dict[value_var] *= rv_scalings[value_var]

    if sum:
        logp_var = at.sum(
            [at.sum(factor) for factor in logp_var_dict.values()])
    else:
        logp_var = list(logp_var_dict.values())
        # TODO: deprecate special behavior when only one variable is requested and
        #  always return a list. This is here for backwards compatibility as logpt
        #  started as a replacement to factor.logpt, but it should now be considered an
        #  internal function reached only via model.logp* methods.
        if len(logp_var) == 1:
            logp_var = logp_var[0]

    return logp_var
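A small usage sketch for `logpt`; the `Normal.dist` import path is an assumption about the PyMC branch this function comes from:

import numpy as np
from pymc.distributions.continuous import Normal  # assumed import path

a = Normal.dist(0.0, 1.0)
a_value = a.type()  # a value variable standing in for `a`
a_value.name = "a_value"

a_logp = logpt(a, a_value)

# Evaluate the log-density at a concrete point
print(a_logp.eval({a_value: np.array(0.5, dtype=a_value.dtype)}))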
示例#11
0
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    progress_bar=True,
    keep_untransformed=False,
):
    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]

    nuts_inputs = sorted(
        [
            v for v in graph_inputs([model.logpt])
            if not isinstance(v, Constant)
        ],
        key=lambda x: isinstance(x, SharedVariable),
    )
    map_seed = jax.random.split(seed, chains)
    numpyro_samples = NumPyroNUTS(
        nuts_inputs,
        [model.logpt],
        target_accept=target_accept,
        draws=draws,
        tune=tune,
        chains=chains,
        seed=map_seed,
        progress_bar=progress_bar,
    )(*init_state_batched_at)

    # Un-transform the transformed variables in JAX
    sample_outputs = []
    for i, (value_var, rv_samples) in enumerate(
            zip(model.value_vars, numpyro_samples[:-1])):
        rv = model.values_to_rvs[value_var]
        transform = getattr(value_var.tag, "transform", None)
        if transform is not None:
            untrans_value_var = transform.backward(rv, rv_samples)
            untrans_value_var.name = rv.name
            sample_outputs.append(untrans_value_var)

            if keep_untransformed:
                rv_samples.name = value_var.name
                sample_outputs.append(rv_samples)
        else:
            rv_samples.name = rv.name
            sample_outputs.append(rv_samples)

    print("Compiling...")

    tic1 = pd.Timestamp.now()
    _sample = compile_rv_inplace(
        [],
        sample_outputs + [numpyro_samples[-1]],
        allow_input_downcast=True,
        on_unused_input="ignore",
        accept_inplace=True,
        mode="JAX",
    )
    tic2 = pd.Timestamp.now()

    print("Compilation time = ", tic2 - tic1)

    print("Sampling...")

    *mcmc_samples, leapfrogs_taken = _sample()
    tic3 = pd.Timestamp.now()

    print("Sampling time = ", tic3 - tic2)

    posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}

    az_trace = az.from_dict(posterior=posterior)

    return az_trace
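A usage sketch for the sampler above, with a hypothetical model; it assumes the surrounding PyMC development branch with JAX and NumPyro installed:

import numpy as np
import pymc as pm  # assumed: the PyMC branch that ships `sample_numpyro_nuts`

y_obs = np.random.normal(loc=1.0, scale=0.5, size=100)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu, sigma, observed=y_obs)

    # `model=None` makes the function pick up the model from the context
    idata = sample_numpyro_nuts(draws=1000, tune=1000, chains=2)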
示例#12
0
def logcdfpt(
    var: TensorVariable,
    rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
    *,
    scaling: bool = True,
    sum: bool = True,
    **kwargs,
) -> TensorVariable:
    """Create a measure-space (i.e. log-cdf) graph for a random variable at a given point.

    Parameters
    ==========
    var
        The `RandomVariable` output that determines the log-likelihood graph.
    rv_values
        A variable, or ``dict`` of variables, that represents the value of
        `var` in its log-likelihood.  If no `rv_value` is provided,
        ``var.tag.value_var`` will be checked and, when available, used.
    jacobian
        Whether or not to include the Jacobian term.
    scaling
        A scaling term to apply to the generated log-likelihood graph.
    transformed
        Apply transforms.
    sum
        Sum the log-likelihood.

    """
    if not isinstance(rv_values, Mapping):
        rv_values = {var: rv_values} if rv_values is not None else {}

    rv_var, rv_value_var = extract_rv_and_value_vars(var)

    rv_value = rv_values.get(rv_var, rv_value_var)

    if rv_var is not None and rv_value is None:
        raise ValueError(f"No value variable specified or associated with {rv_var}")

    if rv_value is not None:
        rv_value = at.as_tensor(rv_value)

        if rv_var is not None:
            # Make sure that the value is compatible with the random variable
            rv_value = rv_var.type.filter_variable(rv_value.astype(rv_var.dtype))

        if rv_value_var is None:
            rv_value_var = rv_value

    rv_node = rv_var.owner

    rng, size, dtype, *dist_params = rv_node.inputs

    # Here, we plug the actual random variable into the log-likelihood graph,
    # because we want a log-likelihood graph that only contains
    # random variables.  This is important, because a random variable's
    # parameters can contain random variables themselves.
    # Ultimately, with a graph containing only random variables and
    # "deterministics", we can simply replace all the random variables with
    # their value variables and be done.
    tmp_rv_values = rv_values.copy()
    tmp_rv_values[rv_var] = rv_var

    logp_var = _logcdf(rv_node.op, rv_var, tmp_rv_values, *dist_params, **kwargs)

    transform = getattr(rv_value_var.tag, "transform", None) if rv_value_var else None

    # Replace random variables with their value variables
    replacements = rv_values.copy()
    replacements.update({rv_var: rv_value, rv_value_var: rv_value})

    (logp_var,), _ = rvs_to_value_vars(
        (logp_var,),
        apply_transforms=False,
        initial_replacements=replacements,
    )

    if sum:
        logp_var = at.sum(logp_var)

    if scaling:
        logp_var *= _get_scaling(
            getattr(rv_var.tag, "total_size", None), rv_value.shape, rv_value.ndim
        )

    # Recompute test values for the changes introduced by the replacements
    # above.
    if config.compute_test_value != "off":
        for node in io_toposort(graph_inputs((logp_var,)), (logp_var,)):
            compute_test_value(node)

    if rv_var.name is not None:
        logp_var.name = f"__logp_{rv_var.name}"

    return logp_var
示例#13
0
def logpt(
    var: TensorVariable,
    rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
    *,
    jacobian: bool = True,
    scaling: bool = True,
    transformed: bool = True,
    sum: bool = True,
    **kwargs,
) -> TensorVariable:
    """Create a measure-space (i.e. log-likelihood) graph for a random variable
    or a list of random variables at a given point.

    The input `var` determines which log-likelihood graph is used and
    `rv_value` is that graph's input parameter.  For example, if `var` is
    the output of a ``NormalRV`` ``Op``, then the output is a graph of the
    density function for `var` set to the value `rv_value`.

    Parameters
    ==========
    var
        The `RandomVariable` output that determines the log-likelihood graph.
        Can also be a list of variables. The final log-likelihood graph will
        be the sum total of all individual log-likelihood graphs of variables
        in the list.
    rv_values
        A variable, or ``dict`` of variables, that represents the value of
        `var` in its log-likelihood.  If no `rv_value` is provided,
        ``var.tag.value_var`` will be checked and, when available, used.
    jacobian
        Whether or not to include the Jacobian term.
    scaling
        A scaling term to apply to the generated log-likelihood graph.
    transformed
        Apply transforms.
    sum
        Sum the log-likelihood.

    """
    # TODO: In future when we drop support for tag.value_var most of the following
    # logic can be removed and logpt can just be a wrapper function that calls aeppl's
    # joint_logprob directly.

    # If var is not a list make it one.
    if not isinstance(var, list):
        var = [var]

    # If logpt isn't provided values and the variable (provided in var)
    # is an RV, it is assumed that the tagged value var or observation is
    # the value variable for that particular RV.
    if rv_values is None:
        rv_values = {}
        for _var in var:
            if isinstance(_var.owner.op, RandomVariable):
                rv_value_var = getattr(
                    _var.tag, "observations", getattr(_var.tag, "value_var", _var)
                )
                rv_values = {_var: rv_value_var}
    elif not isinstance(rv_values, Mapping):
        # Else if we're given a single value and a single variable we assume a mapping among them.
        rv_values = (
            {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)} if len(var) == 1 else {}
        )

    # The filtering of the logp graph below is based on the value variables
    # provided to this function.
    if not rv_values:
        warnings.warn("No value variables provided; the logp will be an empty graph")

    if scaling:
        rv_scalings = {}
        for _var in var:
            rv_value_var = getattr(_var.tag, "observations", getattr(_var.tag, "value_var", _var))
            rv_scalings[rv_value_var] = _get_scaling(
                getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
            )

    # Unlike aeppl, PyMC's logpt is expected to plug the value variables into their
    # corresponding RVs automatically, unless the values are explicitly set to None.
    # Hence we iterate through the graph to find RVs and construct a new RV-to-value
    # dictionary.
    tmp_rvs_to_values = rv_values.copy()
    transform_map = {}
    for node in io_toposort(graph_inputs(var), var):
        if isinstance(node.op, RandomVariable):
            curr_var = node.out
            rv_value_var = getattr(
                curr_var.tag, "observations", getattr(curr_var.tag, "value_var", curr_var)
            )
            rv_value = rv_values.get(curr_var, rv_value_var)
            tmp_rvs_to_values[curr_var] = rv_value
            # Along with value variables we also check for transforms if any.
            if hasattr(rv_value_var.tag, "transform") and transformed:
                transform_map[rv_value] = rv_value_var.tag.transform
        # The condition below is a hackish way of excluding the value variable for the
        # RV being indexed in case of Advanced Indexing of RVs. It gets added by the
        # logic above but aeppl does not expect us to include it in the dictionary of
        # {RV:values} given to it.
        if isinstance(node.op, subtensor_types):
            curr_var = node.out
            if (
                curr_var in tmp_rvs_to_values.keys()
                and curr_var.owner.inputs[0] in tmp_rvs_to_values.keys()
            ):
                tmp_rvs_to_values.pop(curr_var.owner.inputs[0])

    transform_opt = TransformValuesOpt(transform_map)
    temp_logp_var_dict = factorized_joint_logprob(
        tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
    )

    # aeppl returns the logpt for every single value term we provided to it. This includes
    # the extra values we plugged in above so we need to filter those out.
    logp_var_dict = {}
    for value_var, _logp in temp_logp_var_dict.items():
        if value_var in rv_values.values():
            logp_var_dict[value_var] = _logp

    # If it's an empty dictionary the logp is None
    if not logp_var_dict:
        logp_var = None
    else:
        # Otherwise apply appropriate scalings and at.add and/or at.sum the
        # graphs accordingly.
        if scaling:
            for _value in logp_var_dict.keys():
                if _value in rv_scalings:
                    logp_var_dict[_value] *= rv_scalings[_value]

        if len(logp_var_dict) == 1:
            logp_var_dict = tuple(logp_var_dict.values())[0]
            if sum:
                logp_var = at.sum(logp_var_dict)
            else:
                logp_var = logp_var_dict
        else:
            if sum:
                logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()])
            else:
                logp_var = at.add(*logp_var_dict.values())

        # Recompute test values for the changes introduced by the replacements
        # above.
        if config.compute_test_value != "off":
            for node in io_toposort(graph_inputs((logp_var,)), (logp_var,)):
                compute_test_value(node)

    return logp_var
示例#14
0
def pydotprint(
    fct,
    outfile=None,
    compact=True,
    format="png",
    with_ids=False,
    high_contrast=True,
    cond_highlight=None,
    colorCodes=None,
    max_label_size=70,
    scan_graphs=False,
    var_with_name_simple=False,
    print_output_file=True,
    return_image=False,
):
    """Print to a file the graph of a compiled aesara function's ops. Supports
    all pydot output formats, including png and svg.

    :param fct: a compiled Aesara function, a Variable, an Apply or
                a list of Variable.
    :param outfile: the output file where to put the graph.
    :param compact: if True, will remove intermediate variables that don't have a name.
    :param format: the file format of the output.
    :param with_ids: Print the toposort index of the node in the node name,
                     and an index number in the variable ellipse.
    :param high_contrast: if true, each node is filled with its corresponding
            color, instead of only coloring the border
    :param colorCodes: dictionary with names of ops as keys and colors as
            values
    :param cond_highlight: Highlights a lazy if by surrounding each of the 3
                possible categories of ops with a border. The categories
                are: ops that are on the left branch, ops that are on the
                right branch, ops that are on both branches
                As an alternative you can provide the node that represents
                the lazy if
    :param scan_graphs: if true it will plot the inner graph of each scan op
                in files with the same name as the name given for the main
                file to which the name of the scan op is concatenated and
                the index in the toposort of the scan.
                This index can be printed with the option with_ids.
    :param var_with_name_simple: If true and a variable has a name,
                we will print only the variable name.
                Otherwise, we concatenate the type to the variable name.
    :param return_image: If True, it will create the image and return it.
        Useful to display the image in ipython notebook.

        .. code-block:: python

            import aesara
            v = aesara.tensor.vector()
            from IPython.display import SVG
            SVG(aesara.printing.pydotprint(v*2, return_image=True,
                                           format='svg'))

    In the graph, ellipses are Apply Nodes (the execution of an op)
    and boxes are variables.  If variables have names they are used as
    text (if multiple vars have the same name, they will be merged in
    the graph).  Otherwise, if the variable is constant, we print its
    value and finally we print the type + a unique number to prevent
    multiple vars from being merged.  We print the op of the apply in
    the Apply box with a number that represents the toposort order of
    application of those Apply.  If an Apply has more than 1 input, we
    label each edge between an input and the Apply node with the
    input's index.

    Variable color code::
        - Cyan boxes are SharedVariable (inputs and/or outputs) of the graph,
        - Green boxes are input variables to the graph,
        - Blue boxes are output variables of the graph,
        - Grey boxes are variables that are not outputs and are not used,

    Default apply node code::
        - Red ellipses are transfers from/to the gpu
        - Yellow are scan nodes
        - Brown are shape nodes
        - Magenta are IfElse nodes
        - Dark pink are elemwise nodes
        - Purple are subtensor nodes
        - Orange are alloc nodes

    For edges, they are black by default. If a node returns a view
    of an input, we put the corresponding input edge in blue. If it
    returns a destroyed input, we put the corresponding edge in red.

    .. note::

        Since October 20th, 2014, this prints the inner function of all
        scans separately after the top level debugprint output.

    """
    from aesara.scan.op import Scan

    if colorCodes is None:
        colorCodes = default_colorCodes

    if outfile is None:
        outfile = os.path.join(
            config.compiledir,
            "aesara.pydotprint." + config.device + "." + format)

    if isinstance(fct, Function):
        profile = getattr(fct, "profile", None)
        fgraph = fct.maker.fgraph
        outputs = fgraph.outputs
        topo = fgraph.toposort()
    elif isinstance(fct, FunctionGraph):
        profile = None
        outputs = fct.outputs
        topo = fct.toposort()
        fgraph = fct
    else:
        if isinstance(fct, Variable):
            fct = [fct]
        elif isinstance(fct, Apply):
            fct = fct.outputs
        assert isinstance(fct, (list, tuple))
        assert all(isinstance(v, Variable) for v in fct)
        fct = FunctionGraph(inputs=list(graph_inputs(fct)), outputs=fct)
        profile = None
        outputs = fct.outputs
        topo = fct.toposort()
        fgraph = fct
    if not pydot_imported:
        raise RuntimeError(
            "Failed to import pydot. You must install graphviz"
            " and either pydot or pydot-ng for "
            "`pydotprint` to work.",
            pydot_imported_msg,
        )

    g = pd.Dot()

    if cond_highlight is not None:
        c1 = pd.Cluster("Left")
        c2 = pd.Cluster("Right")
        c3 = pd.Cluster("Middle")
        cond = None
        for node in topo:
            if (node.op.__class__.__name__ == "IfElse"
                    and node.op.name == cond_highlight):
                cond = node
        if cond is None:
            _logger.warning("pydotprint: cond_highlight is set but there is no"
                            " IfElse node in the graph")
            cond_highlight = None

    if cond_highlight is not None:

        def recursive_pass(x, ls):
            if not x.owner:
                return ls
            else:
                ls += [x.owner]
                for inp in x.inputs:
                    ls += recursive_pass(inp, ls)
                return ls

        left = set(recursive_pass(cond.inputs[1], []))
        right = set(recursive_pass(cond.inputs[2], []))
        middle = left.intersection(right)
        left = left.difference(middle)
        right = right.difference(middle)
        middle = list(middle)
        left = list(left)
        right = list(right)

    var_str = {}
    var_id = {}
    all_strings = set()

    def var_name(var):
        if var in var_str:
            return var_str[var], var_id[var]

        if var.name is not None:
            if var_with_name_simple:
                varstr = var.name
            else:
                varstr = "name=" + var.name + " " + str(var.type)
        elif isinstance(var, Constant):
            dstr = "val=" + str(np.asarray(var.data))
            if "\n" in dstr:
                dstr = dstr[:dstr.index("\n")]
            varstr = f"{dstr} {var.type}"
        elif var in input_update and input_update[var].name is not None:
            varstr = input_update[var].name
            if not var_with_name_simple:
                varstr += str(var.type)
        else:
            # a var id is needed as otherwise var with the same type will be
            # merged in the graph.
            varstr = str(var.type)
        if len(varstr) > max_label_size:
            varstr = varstr[:max_label_size - 3] + "..."
        var_str[var] = varstr
        var_id[var] = str(id(var))

        all_strings.add(varstr)

        return varstr, var_id[var]

    apply_name_cache = {}
    apply_name_id = {}

    def apply_name(node):
        if node in apply_name_cache:
            return apply_name_cache[node], apply_name_id[node]
        prof_str = ""
        if profile:
            time = profile.apply_time.get((fgraph, node), 0)
            # second, %fct time in profiler
            if profile.fct_callcount == 0 or profile.fct_call_time == 0:
                pf = 0
            else:
                pf = time * 100 / profile.fct_call_time
            prof_str = f"   ({time:.3f}s,{pf:.3f}%)"
        applystr = str(node.op).replace(":", "_")
        applystr += prof_str
        if (applystr in all_strings) or with_ids:
            idx = " id=" + str(topo.index(node))
            if len(applystr) + len(idx) > max_label_size:
                applystr = applystr[:max_label_size - 3 -
                                    len(idx)] + idx + "..."
            else:
                applystr = applystr + idx
        elif len(applystr) > max_label_size:
            applystr = applystr[:max_label_size - 3] + "..."
            idx = 1
            while applystr in all_strings:
                idx += 1
                suffix = " id=" + str(idx)
                applystr = applystr[:max_label_size - 3 -
                                    len(suffix)] + "..." + suffix

        all_strings.add(applystr)
        apply_name_cache[node] = applystr
        apply_name_id[node] = str(id(node))

        return applystr, apply_name_id[node]

    # Update the inputs that have an update function
    input_update = {}
    reverse_input_update = {}
    # Here `outputs` can be the original list; since we should not change
    # it, we make a copy.
    outputs = list(outputs)
    if isinstance(fct, Function):
        function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
        for i, fg_ii in reversed(list(function_inputs)):
            if i.update is not None:
                k = outputs.pop()
                # Use fgraph.inputs, as it isn't the same as maker.inputs
                input_update[k] = fg_ii
                reverse_input_update[fg_ii] = k

    apply_shape = "ellipse"
    var_shape = "box"
    for node_idx, node in enumerate(topo):
        astr, aid = apply_name(node)

        use_color = None
        for opName, color in colorCodes.items():
            if opName in node.op.__class__.__name__:
                use_color = color

        if use_color is None:
            nw_node = Node(aid, label=astr, shape=apply_shape)
        elif high_contrast:
            nw_node = Node(aid,
                           label=astr,
                           style="filled",
                           fillcolor=use_color,
                           shape=apply_shape)
        else:
            nw_node = Node(aid, label=astr, color=use_color, shape=apply_shape)
        g.add_node(nw_node)
        if cond_highlight:
            if node in middle:
                c3.add_node(nw_node)
            elif node in left:
                c1.add_node(nw_node)
            elif node in right:
                c2.add_node(nw_node)

        for idx, var in enumerate(node.inputs):
            varstr, varid = var_name(var)
            label = ""
            if len(node.inputs) > 1:
                label = str(idx)
            param = {}
            if label:
                param["label"] = label
            if hasattr(node.op, "view_map") and idx in reduce(
                    list.__add__, node.op.view_map.values(), []):
                param["color"] = colorCodes["Output"]
            elif hasattr(node.op, "destroy_map") and idx in reduce(
                    list.__add__, node.op.destroy_map.values(), []):
                param["color"] = "red"
            if var.owner is None:
                color = "green"
                if isinstance(var, SharedVariable):
                    # Inputs are green, outputs blue.
                    # Mixing blue and green gives cyan! (input and output var)
                    color = "cyan"
                if high_contrast:
                    g.add_node(
                        Node(
                            varid,
                            style="filled",
                            fillcolor=color,
                            label=varstr,
                            shape=var_shape,
                        ))
                else:
                    g.add_node(
                        Node(varid, color=color, label=varstr,
                             shape=var_shape))
                g.add_edge(pd.Edge(varid, aid, **param))
            elif var.name or not compact or var in outputs:
                g.add_edge(pd.Edge(varid, aid, **param))
            else:
                # no name, so we don't make a var ellipse
                if label:
                    label += " "
                label += str(var.type)
                if len(label) > max_label_size:
                    label = label[:max_label_size - 3] + "..."
                param["label"] = label
                g.add_edge(pd.Edge(apply_name(var.owner)[1], aid, **param))

        for idx, var in enumerate(node.outputs):
            varstr, varid = var_name(var)
            out = var in outputs
            label = ""
            if len(node.outputs) > 1:
                label = str(idx)
            if len(label) > max_label_size:
                label = label[:max_label_size - 3] + "..."
            param = {}
            if label:
                param["label"] = label
            if out or var in input_update:
                g.add_edge(pd.Edge(aid, varid, **param))
                if high_contrast:
                    g.add_node(
                        Node(
                            varid,
                            style="filled",
                            label=varstr,
                            fillcolor=colorCodes["Output"],
                            shape=var_shape,
                        ))
                else:
                    g.add_node(
                        Node(
                            varid,
                            color=colorCodes["Output"],
                            label=varstr,
                            shape=var_shape,
                        ))
            elif len(fgraph.clients[var]) == 0:
                g.add_edge(pd.Edge(aid, varid, **param))
                # grey means that the output var isn't used
                if high_contrast:
                    g.add_node(
                        Node(
                            varid,
                            style="filled",
                            label=varstr,
                            fillcolor="grey",
                            shape=var_shape,
                        ))
                else:
                    g.add_node(
                        Node(varid,
                             label=varstr,
                             color="grey",
                             shape=var_shape))
            elif var.name or not compact:
                if compact:
                    if label:
                        label += " "
                    label += str(var.type)
                    if len(label) > max_label_size:
                        label = label[:max_label_size - 3] + "..."
                    param["label"] = label
                g.add_edge(pd.Edge(aid, varid, **param))
                g.add_node(Node(varid, shape=var_shape, label=varstr))
            # else: don't add an edge here, as it is already added from the inputs.

    # The var that represent updates, must be linked to the input var.
    for sha, up in input_update.items():
        _, shaid = var_name(sha)
        _, upid = var_name(up)
        g.add_edge(
            pd.Edge(shaid, upid, label="UPDATE", color=colorCodes["Output"]))

    if cond_highlight:
        g.add_subgraph(c1)
        g.add_subgraph(c2)
        g.add_subgraph(c3)

    if not outfile.endswith("." + format):
        outfile += "." + format

    if scan_graphs:
        scan_ops = [(idx, x) for idx, x in enumerate(topo)
                    if isinstance(x.op, Scan)]
        path, fn = os.path.split(outfile)
        basename = ".".join(fn.split(".")[:-1])
        # Safe way of doing things .. a file name may contain multiple .
        ext = fn[len(basename):]

        for idx, scan_op in scan_ops:
            # is there a chance that name is not defined?
            if hasattr(scan_op.op, "name"):
                new_name = basename + "_" + scan_op.op.name + "_" + str(idx)
            else:
                new_name = basename + "_" + str(idx)
            new_name = os.path.join(path, new_name + ext)
            if hasattr(scan_op.op, "fn"):
                to_print = scan_op.op.fn
            else:
                to_print = scan_op.op.outputs
            pydotprint(
                to_print,
                new_name,
                compact,
                format,
                with_ids,
                high_contrast,
                cond_highlight,
                colorCodes,
                max_label_size,
                scan_graphs,
            )

    if return_image:
        return g.create(prog="dot", format=format)
    else:
        try:
            g.write(outfile, prog="dot", format=format)
        except pd.InvocationException:
            # based on https://github.com/Theano/Theano/issues/2988
            version = getattr(pd, "__version__", "")
            if version and [int(n) for n in version.split(".")] < [1, 0, 28]:
                raise Exception("Old version of pydot detected, which can "
                                "cause issues with pydot printing. Try "
                                "upgrading pydot version to a newer one")
            raise

        if print_output_file:
            print("The output file is available at", outfile)
示例#15
0
def map_variables(replacer, graphs, additional_inputs=None):
    """Construct new graphs based on 'graphs' with some variables replaced
    according to 'replacer'.

    :param replacer: function that takes a variable and returns its
         replacement.
    :param graphs: an iterable of graphs in which to replace variables
    :param additional_inputs: an iterable of graph inputs not used in any
         of 'graphs' but possibly used in the graphs returned by `replacer`
    :return: the new graphs, in the same order as 'graphs'

    Example:

    .. code-block:: python

        tag = "replaceme"

        a = aesara.tensor.type.scalar("a")
        b = aesara.tensor.type.scalar("b")
        c = aesara.tensor.type.scalar("c")

        ab = a + b
        ab.tag.replacement = a * b

        u = ab + c
        v, = map_variables(
            lambda graph: getattr(graph.tag, "replacement", graph),
            [u])

        # v is now equal to a * b + c
    """
    if additional_inputs is None:
        additional_inputs = []

    # wrap replacer to avoid replacing things we just put there.
    graphs_seen = set()

    def wrapped_replacer(graph):
        if graph in graphs_seen:
            return graph
        else:
            new_graph = replacer(graph)
            graphs_seen.add(new_graph)
            return new_graph

    graphs = list(graphs)
    inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))

    # perform any desired replacement of input variables.  these
    # aren't replaced by the local optimizer approach because they are
    # not outputs of any Apply node.
    new_inputs = [wrapped_replacer(i) for i in inputs_]
    replacements = [(input_, new_input)
                    for input_, new_input in zip(inputs_, new_inputs)
                    if new_input is not input_]
    graphs = clone_replace(graphs, share_inputs=True, replace=replacements)
    inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))

    fg = FunctionGraph(inputs_, graphs, clone=False)

    nodes_seen = set()

    @local_optimizer(None)
    def local_transform(fgraph, node):
        if node in nodes_seen:
            return False

        # importing Scan into module scope would be circular
        from aesara.compile.builders import OpFromGraph
        from aesara.scan.op import Scan

        if isinstance(node.op, (Scan, OpFromGraph)):
            # recurse on the inner graph
            (
                new_inner_inputs,
                new_outer_inputs,
                new_inner_outputs,
            ) = _map_variables_inner(
                wrapped_replacer,
                inner_inputs=node.op.inputs,
                outer_inputs=node.inputs,
                inner_outputs=node.op.outputs,
                containing_op=node.op,
            )
            # reinstantiate the op
            if isinstance(node.op, Scan):
                new_op = Scan(
                    new_inner_inputs,
                    new_inner_outputs,
                    node.op.info,
                    node.op.mode,
                    # FIXME: infer this someday?
                    typeConstructor=None,
                )
            elif isinstance(node.op, OpFromGraph):
                new_op = OpFromGraph(new_inner_inputs, new_inner_outputs,
                                     **node.op.kwargs)
            # make a new node to replace the old one
            new_node = new_op.make_node(*new_outer_inputs)
            nodes_seen.add(new_node)
            return new_node.outputs
        else:
            nodes_seen.add(node)
            replacements = [wrapped_replacer(o) for o in node.outputs]

            # Add inputs to replacement graphs as inputs to this `fgraph`
            for i in graph_inputs(replacements):
                fgraph.add_input(i)

            return replacements

    topo_transform = TopoOptimizer(local_transform, "out_to_in")
    topo_transform.optimize(fg)

    new_graphs = fg.outputs
    fg.disown()
    return new_graphs
示例#16
0
File: fg.py  Project: mgorny/aesara
    def __init__(
        self,
        inputs: Optional[Sequence[Variable]] = None,
        outputs: Optional[Sequence[Variable]] = None,
        features: Optional[Sequence[Feature]] = None,
        clone: bool = True,
        update_mapping: Optional[Dict[Variable, Variable]] = None,
        **clone_kwds,
    ):
        """
        Create a `FunctionGraph` which operates on the subgraph between the
        `inputs` and `outputs`.

        Parameters
        ----------
        inputs
            Input variables of the graph.
        outputs
            Output variables of the graph.
        features
            A list of features to be added to the `FunctionGraph`.
        clone
            If ``True``, the graph will be cloned.
        update_mapping
            Mapping between the `inputs` with updates and the `outputs`
            corresponding to their updates.
        clone_kwds
            Keywords passed to `clone_get_equiv` when `clone` is ``True``.
        """
        if outputs is None:
            raise ValueError("No outputs specified")

        if inputs is None:
            inputs = [
                i for i in graph_inputs(outputs)
                if not isinstance(i, AtomicVariable)
            ]

        if clone:
            _memo = clone_get_equiv(
                inputs,
                outputs,
                **clone_kwds,
            )
            outputs = [cast(Variable, _memo[o]) for o in outputs]
            inputs = [cast(Variable, _memo[i]) for i in inputs]

        self.execute_callbacks_time: float = 0.0
        self.execute_callbacks_times: Dict[Feature, float] = {}

        if features is None:
            features = []

        self._features: List[Feature] = []

        # All apply nodes in the subgraph defined by inputs and
        # outputs are cached in this field
        self.apply_nodes: Set[Apply] = set()

        # This set includes the inputs, outputs, and all intermediate
        # variables connecting the inputs and outputs.  It also contains
        # irrelevant outputs of the nodes in `self.apply_nodes`.
        self.variables: Set[Variable] = set()

        self.inputs: List[Variable] = []
        self.outputs: List[Variable] = []
        self.clients: Dict[Variable, List[ClientType]] = {}

        for f in features:
            self.attach_feature(f)

        self.attach_feature(ReplaceValidate())

        for in_var in inputs:
            if in_var.owner is not None:
                raise ValueError("One of the provided inputs is the output of "
                                 "an already existing node. "
                                 "If that is okay, either discard that "
                                 "input's owner or use graph.clone.")

            self.add_input(in_var, check=False)

        for output in outputs:
            self.add_output(output, reason="init")

        self.profile = None
        self.update_mapping = update_mapping
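
For context, a minimal construction sketch for this `FunctionGraph` constructor (assuming standard Aesara imports; `FunctionGraph` lives in ``aesara.graph.fg`` per the file name above):

.. code-block:: python

    import aesara.tensor as at
    from aesara.graph.fg import FunctionGraph

    x = at.vector("x")
    y = at.vector("y")
    out = (x + y).sum()

    # `clone=True` copies the subgraph between `inputs` and `outputs`, so
    # rewrites applied to the FunctionGraph do not mutate the originals.
    fg = FunctionGraph(inputs=[x, y], outputs=[out], clone=True)

    print(fg.apply_nodes)            # Apply nodes between inputs and outputs
    print(fg.clients[fg.inputs[0]])  # (client node, input index) pairs for `x`
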
Example #17
0
    def __init__(
        self,
        inputs,
        outputs,
        inline=False,
        lop_overrides="default",
        grad_overrides="default",
        rop_overrides="default",
        connection_pattern=None,
        name=None,
        **kwargs,
    ):
        if not isinstance(outputs, list):
            raise TypeError(f"outputs must be list, got {type(outputs)}")
        for i in inputs + outputs:
            if not isinstance(i, Variable):
                raise TypeError(
                    "inputs and outputs must be Variable instances", i)
        if "updates" in kwargs or "givens" in kwargs:
            raise TypeError("updates and givens are not allowed here")
        self.is_inline = inline
        # To correctly support shared variables the inner fct should
        # not see them. Otherwise there is a problem with the gradient.
        self.shared_inputs = [
            var for var in graph_inputs(outputs)
            if isinstance(var, SharedVariable)
        ]
        shared_vars = [var.type() for var in self.shared_inputs]

        new = rebuild_collect_shared(
            outputs,
            inputs=inputs + shared_vars,
            replace=dict(zip(self.shared_inputs, shared_vars)),
            copy_inputs_over=False,
        )
        (
            local_inputs,
            local_outputs,
            [clone_d, update_d, update_expr, shared_inputs],
        ) = new
        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self.local_inputs = local_inputs
        self.local_outputs = local_outputs
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]
        if lop_overrides != "default":
            if grad_overrides != "default":
                raise ValueError(
                    "lop_overrides and grad_overrides are mutually exclusive")
            else:
                self.set_lop_overrides(lop_overrides)
                self._lop_type = "lop"
        elif grad_overrides != "default":
            self.set_lop_overrides(grad_overrides)
            self._lop_type = "grad"
        else:
            self.set_lop_overrides("default")
            self._lop_type = "lop"
        self.set_rop_overrides(rop_overrides)

        self._connection_pattern = connection_pattern

        if name is not None:
            assert isinstance(name, str), "name must be None or string object"
        self.name = name
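
A small usage sketch for `OpFromGraph`: build the op from placeholder inputs and outputs, then apply it to new variables (the concrete values below are only illustrative):

.. code-block:: python

    import aesara
    import aesara.tensor as at
    from aesara.compile.builders import OpFromGraph

    x, y = at.scalars("x", "y")
    # Encapsulate the subgraph `x * y + x` as a single reusable Op.
    mul_add = OpFromGraph([x, y], [x * y + x])

    a, b = at.scalars("a", "b")
    z = mul_add(a, b)  # a single output, so a Variable is returned

    f = aesara.function([a, b], z)
    assert f(2.0, 3.0) == 8.0  # 2 * 3 + 2
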
Example #18
0
def logpt(
    var: TensorVariable,
    rv_values: Optional[Union[TensorVariable, Dict[TensorVariable,
                                                   TensorVariable]]] = None,
    *,
    jacobian: bool = True,
    scaling: bool = True,
    transformed: bool = True,
    cdf: bool = False,
    sum: bool = False,
    **kwargs,
) -> TensorVariable:
    """Create a measure-space (i.e. log-likelihood) graph for a random variable at a given point.

    The input `var` determines which log-likelihood graph is built, and the
    value given in `rv_values` is that graph's input parameter.  For example,
    if `var` is the output of a ``NormalRV`` ``Op``, then the output is a
    graph of the density function for `var` evaluated at that value.

    Parameters
    ==========
    var
        The `RandomVariable` output that determines the log-likelihood graph.
    rv_values
        A variable, or ``dict`` of variables, that represents the value of
        `var` in its log-likelihood.  If no value is provided,
        ``var.tag.value_var`` will be checked and, when available, used.
    jacobian
        Whether or not to include the Jacobian term.
    scaling
        Whether to apply the ``total_size`` scaling term to the generated
        log-likelihood graph.
    transformed
        Whether to apply transforms.
    cdf
        Whether to return the log of the cumulative distribution function
        instead of the log-density.
    sum
        Whether to sum the log-likelihood.

    """
    if not isinstance(rv_values, Mapping):
        rv_values = {var: rv_values} if rv_values is not None else {}

    rv_var, rv_value_var = extract_rv_and_value_vars(var)

    rv_value = rv_values.get(rv_var, rv_value_var)

    if rv_var is not None and rv_value is None:
        raise ValueError(
            f"No value variable specified or associated with {rv_var}")

    if rv_value is not None:
        rv_value = at.as_tensor(rv_value)

        if rv_var is not None:
            # Make sure that the value is compatible with the random variable
            rv_value = rv_var.type.filter_variable(
                rv_value.astype(rv_var.dtype))

        if rv_value_var is None:
            rv_value_var = rv_value

    if rv_var is None:
        if var.owner is not None:
            return _logp(
                var.owner.op,
                var,
                rv_values,
                *var.owner.inputs,
                jacobian=jacobian,
                scaling=scaling,
                transformed=transformed,
                cdf=cdf,
                sum=sum,
            )

        return at.zeros_like(var)

    rv_node = rv_var.owner

    rng, size, dtype, *dist_params = rv_node.inputs

    # Here, we plug the actual random variable into the log-likelihood graph,
    # because we want a log-likelihood graph that only contains
    # random variables.  This is important, because a random variable's
    # parameters can contain random variables themselves.
    # Ultimately, with a graph containing only random variables and
    # "deterministics", we can simply replace all the random variables with
    # their value variables and be done.
    tmp_rv_values = rv_values.copy()
    tmp_rv_values[rv_var] = rv_var

    if not cdf:
        logp_var = _logp(rv_node.op, rv_var, tmp_rv_values, *dist_params,
                         **kwargs)
    else:
        logp_var = _logcdf(rv_node.op, rv_var, tmp_rv_values, *dist_params,
                           **kwargs)

    transform = getattr(rv_value_var.tag, "transform",
                        None) if rv_value_var else None

    if transform and transformed and not cdf and jacobian:
        transformed_jacobian = transform.jacobian_det(rv_var, rv_value)
        if transformed_jacobian:
            if logp_var.ndim > transformed_jacobian.ndim:
                logp_var = logp_var.sum(axis=-1)
            logp_var += transformed_jacobian

    # Replace random variables with their value variables
    replacements = rv_values.copy()
    replacements.update({rv_var: rv_value, rv_value_var: rv_value})

    (logp_var, ), _ = rvs_to_value_vars(
        (logp_var, ),
        apply_transforms=transformed and not cdf,
        initial_replacements=replacements,
    )

    if sum:
        logp_var = at.sum(logp_var)

    if scaling:
        logp_var *= _get_scaling(getattr(rv_var.tag, "total_size", None),
                                 rv_value.shape, rv_value.ndim)

    # Recompute test values for the changes introduced by the replacements
    # above.
    if config.compute_test_value != "off":
        for node in io_toposort(graph_inputs((logp_var, )), (logp_var, )):
            compute_test_value(node)

    if rv_var.name is not None:
        logp_var.name = f"__logp_{rv_var.name}"

    return logp_var
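
A hedged usage sketch following the docstring above; the v4-era PyMC model API and the availability of `logpt` in scope are assumptions, and `x.tag.value_var` is the value variable the docstring says is used when no `rv_values` are passed:

.. code-block:: python

    import aesara
    import pymc as pm  # assumed: a v4-era PyMC whose `Normal` builds a NormalRV graph

    with pm.Model():
        x = pm.Normal("x", mu=0.0, sigma=1.0)

    # With no `rv_values`, the value variable at `x.tag.value_var` is used.
    logp_graph = logpt(x, sum=True)

    logp_fn = aesara.function([x.tag.value_var], logp_graph)
    print(logp_fn(0.0))  # log N(0 | 0, 1) ≈ -0.9189
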