示例#1
0
def is_same_graph(var1, var2, givens=None, debug=False):
    """
    Tell whether `var1` and `var2` compute exactly the same graph.

    Two Variables are considered the same only when their computational graphs
    are structurally identical: e.g. (x * (y * z)) and ((x * y) * z) compare
    as different even though they are mathematically equivalent.

    The answer always comes from `is_same_graph_with_merge`. Whenever the
    substitutions in `givens` can be expressed in the form expected by
    `equal_computations`, that second implementation is run as well and its
    result is asserted to match. This double computation is deliberate: it
    checks that both implementations agree, so that one of them can
    eventually be removed.

    Parameters
    ----------
    var1
        First Variable of the comparison.
    var2
        Second Variable of the comparison.
    givens
        Optional substitutions, in the spirit of the `givens` argument of
        `theano.function`: a dict, or anything `dict()` accepts, mapping
        variables to their replacements. The substitutions are not tied to
        either graph and apply to both wherever the replaced variable occurs.
        `None` (the default) means no substitution.
    debug : bool
        Test-only switch. When True, an AssertionError is raised instead of
        silently skipping `equal_computations` when `givens` cannot be
        translated for it.

    Returns
    -------
    The result of `is_same_graph_with_merge`.

    Raises
    ------
    AssertionError
        If `debug` is True and some substitution does not link a variable
        found only in one graph to a variable found only in the other; or if
        the two implementations disagree (that check is a plain ``assert``,
        so it is skipped under ``python -O``).

    Examples
    --------

        ======  ======  ======  ======
        var1    var2    givens  output
        ======  ======  ======  ======
        x + 1   x + 1   {}      True
        x + 1   y + 1   {}      False
        x + 1   y + 1   {x: y}  True
        ======  ======  ======  ======

    """
    global equal_computations, is_same_graph_with_merge

    # Deferred import of both comparison helpers on first use (presumably to
    # avoid an import cycle at module load time).
    if equal_computations is None:
        from theano.gof.opt import is_same_graph_with_merge
        from theano.scan_module.scan_utils import equal_computations

    # Normalise `givens` into a (possibly empty) dict.
    if givens is None:
        givens = {}
    elif not isinstance(givens, dict):
        givens = dict(givens)

    # Reference answer; this is what the function ultimately returns.
    merge_result = is_same_graph_with_merge(var1=var1, var2=var2,
                                            givens=givens)

    if not givens:
        cross_xs = cross_ys = None
    else:
        # `equal_computations` needs the substitutions split into two aligned
        # lists: `cross_xs` holds variables of `var1`'s graph and `cross_ys`
        # their counterparts from `var2`'s graph. That is only possible when
        # each substitution links a variable living solely in one graph to a
        # variable living solely in the other.
        graph_inputs = [inputs([var1]), inputs([var2])]
        graph1_vars = set(variables(graph_inputs[0], [var1]))
        graph2_vars = set(variables(graph_inputs[1], [var2]))
        cross_xs, cross_ys = [], []
        for to_replace, replace_by in givens.items():
            old_in_1 = to_replace in graph1_vars
            old_in_2 = to_replace in graph2_vars
            new_in_1 = replace_by in graph1_vars
            new_in_2 = replace_by in graph2_vars
            if old_in_1 and not old_in_2 and new_in_2 and not new_in_1:
                # A `var1` variable replaced by a `var2` variable.
                cross_xs.append(to_replace)
                cross_ys.append(replace_by)
            elif old_in_2 and not old_in_1 and new_in_1 and not new_in_2:
                # A `var2` variable replaced by a `var1` variable.
                cross_xs.append(replace_by)
                cross_ys.append(to_replace)
            else:
                # Unsupported substitution shape: skip the cross-check.
                if debug:
                    raise AssertionError(
                        'When `debug` is True we want to make sure we are also '
                        'using the `equal_computations` implementation')
                return merge_result

    # Second opinion from scan_utils; it must agree with the merge-based one.
    # The call is kept outside the `assert` so it still runs under `-O`.
    scan_result = equal_computations(xs=[var1], ys=[var2],
                                     in_xs=cross_xs, in_ys=cross_ys)
    assert scan_result == merge_result
    return merge_result
示例#2
0
文件: graph.py 项目: gyenney/Tools
def is_same_graph(var1, var2, givens=None, debug=False):
    """
    Return True iff Variables `var1` and `var2` perform the same computation.

    By 'performing the same computation', we mean that they must share the same
    graph, so that for instance this function will return False when comparing
    (x * (y * z)) with ((x * y) * z).

    The current implementation is not efficient since, when possible, it
    verifies equality by calling two different functions that are expected to
    return the same output. The goal is to verify this assumption, to
    eventually get rid of one of them in the future.

    :param var1: The first Variable to compare.

    :param var2: The second Variable to compare.

    :param givens: Similar to the `givens` argument of `theano.function`, it
        can be used to perform substitutions in the computational graph of
        `var1` and `var2`. It may be a dict or anything accepted by `dict()`
        (e.g. a list of `(to_replace, replace_by)` pairs). This argument is
        associated to neither `var1` nor `var2`: substitutions may affect both
        graphs if the substituted variable is present in both.

    :param debug: If True, then an exception is raised when we are in a
        situation where the `equal_computations` implementation cannot be
        called. This parameter is intended to be used in tests only, to make
        sure we properly test both implementations.

    :returns: The result of `is_same_graph_with_merge`.

    :raises AssertionError: If `debug` is True and some substitution in
        `givens` does not link a variable found only in one graph to a
        variable found only in the other; or if the two implementations
        disagree (that check is a plain ``assert``, skipped under
        ``python -O``).

    Examples:

        ======  ======  ======  ======
        var1    var2    givens  output
        ======  ======  ======  ======
        x + 1   x + 1   {}      True
        x + 1   y + 1   {}      False
        x + 1   y + 1   {x: y}  True
        ======  ======  ======  ======
    """
    global equal_computations, is_same_graph_with_merge
    if givens is None:
        givens = {}
    # Lazy import: the module-level names are presumably initialised to None
    # and filled in on the first call.
    if equal_computations is None:
        from theano.gof.opt import is_same_graph_with_merge
        from theano.scan_module.scan_utils import equal_computations
    # Convert `givens` to dictionary.
    if not isinstance(givens, dict):
        givens = dict(givens)
    # Get result from the merge-based function.
    rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)

    # Get result from the function `equal_computations` from scan_utils.
    use_equal_computations = True
    if givens:
        # We need to build the `in_xs` and `in_ys` lists. To do this, we need
        # to be able to tell whether a variable belongs to the computational
        # graph of `var1` or `var2`.
        # The typical case we want to handle is when `to_replace` belongs to
        # one of these graphs, and `replace_by` belongs to the other one. In
        # other situations, the current implementation of `equal_computations`
        # is probably not appropriate, so we do not call it.
        ok = True
        in_xs = []
        in_ys = []
        # Compute the sets of all variables found in each computational graph.
        # `list(...)` is required: on Python 3 `map` returns an iterator,
        # which cannot be indexed below.
        inputs_var = list(map(inputs, ([var1], [var2])))
        all_vars = [set(variables(v_i, v_o))
                    for v_i, v_o in ((inputs_var[0], [var1]),
                                     (inputs_var[1], [var2]))]

        def in_var(x, k):
            # Return True iff `x` is in the computation graph of `var{k}`.
            return x in all_vars[k - 1]

        # `items()` works on both Python 2 and 3; `iteritems()` is Python 2
        # only.
        for to_replace, replace_by in givens.items():
            # Map a substitution variable to the computational graphs it
            # belongs to: inside[v] == [v in graph of var1, v in graph of var2]
            inside = dict((v, [in_var(v, k) for k in (1, 2)])
                          for v in (to_replace, replace_by))
            if (inside[to_replace][0] and not inside[to_replace][1] and
                    inside[replace_by][1] and not inside[replace_by][0]):
                # Substitute variable in `var1` by one from `var2`.
                in_xs.append(to_replace)
                in_ys.append(replace_by)
            elif (inside[to_replace][1] and not inside[to_replace][0] and
                  inside[replace_by][0] and not inside[replace_by][1]):
                # Substitute variable in `var2` by one from `var1`.
                in_xs.append(replace_by)
                in_ys.append(to_replace)
            else:
                ok = False
                break
        if not ok:
            # We cannot directly use `equal_computations`.
            if debug:
                raise AssertionError(
                    'When `debug` is True we want to make sure we are also '
                    'using the `equal_computations` implementation')
            use_equal_computations = False
    else:
        in_xs = None
        in_ys = None
    if use_equal_computations:
        # Kept outside the `assert` so the call still happens under `-O`.
        rval2 = equal_computations(xs=[var1], ys=[var2],
                                   in_xs=in_xs, in_ys=in_ys)
        assert rval2 == rval1
    return rval1