예제 #1
0
파일: pfunc.py 프로젝트: wjbianjason/Theano
def rebuild_collect_shared(outputs,
                           inputs=None,
                           replace=None,
                           updates=None,
                           rebuild_strict=True,
                           copy_inputs_over=True,
                           no_default_updates=False,
                           ):
    """
    Function that allows replacing subgraphs of a computational graph.

    It clones the graph reachable from `outputs` (and from the update
    expressions), applying the substitutions given in `replace`, and collects
    information about the shared variables met along the way. This info is
    required by `pfunc`.

    Parameters
    ----------
    outputs : Variable, Out, None, or list/tuple of Variable/Out
        Theano variables (optionally wrapped in `Out`) representing the
        outputs of the computational graph. A tuple is treated as a list.
    inputs : list of Theano Variables, or None
        Explicit inputs of the computational graph. Shared variables are not
        allowed here.
    replace : dict, iterable of (orig_value, new_value) pairs, or None
        Describes which subgraphs should be replaced by what
        (orig_value => new_value). Keys must be Variables; values that are
        not Variables are wrapped with `shared`.
    updates : dict, iterable of (shared_var, update_expr) pairs, or None
        Update expressions for shared variables.
    rebuild_strict : bool
        Passed as `strict` to `clone_with_new_inputs` when nodes are rebuilt;
        if true the type of all inputs should be the same as the one for the
        current node.
    copy_inputs_over : bool
        If False, explicit inputs and owner-less graph leaves are cloned
        instead of being reused as-is. Constants that have an `fgraph`
        attribute are always cloned.
    no_default_updates : bool or list of Variables
        If True, do not perform any automatic update on Variables.
        If False (default), perform them all.
        If a list, perform automatic updates on all Variables that are
        neither in `updates` nor in `no_default_updates`.

    Returns
    -------
    tuple
        ``(input_variables, cloned_outputs,
        [clone_d, update_d, update_expr, shared_inputs])`` where

        - input_variables: the (possibly cloned) `inputs`;
        - cloned_outputs: the cloned `outputs`, with the same structure
          (list, single Variable/Out, or ``[]`` when `outputs` is None);
        - clone_d: maps each original Variable/node to its clone;
        - update_d: maps each updated shared variable to its *cloned*
          update expression;
        - update_expr: ``(shared_var, update_expr)`` pairs in the order they
          were collected; these hold the *uncloned* expressions;
        - shared_inputs: shared variables used by the graph or updated.

    Raises
    ------
    TypeError
        For non-Variable `replace` keys, shared variables given in `inputs`,
        non-shared update targets, updates whose type does not match their
        target, or outputs that are not Variable/Out instances.
    ValueError
        If a shared variable is given more than one update expression.
    AssertionError
        If a `replace` key was already reached while cloning an earlier
        replacement value (e.g. ``givens={a: b, b: a + 1}``).

    """

    if isinstance(outputs, tuple):
        outputs = list(outputs)

    # This function implements similar functionality as graph.clone
    # and it should be merged with that.
    clone_d = {}
    update_d = {}
    update_expr = []
    # List of shared inputs that are used as inputs of the graph.
    shared_inputs = []

    def clone_v_get_shared_updates(v, copy_inputs_over):
        """
        Clones a variable and its inputs recursively until all are in clone_d.
        Also appends all shared variables met along the way to shared inputs,
        and their default_update (if applicable) to update_d and update_expr.

        v can have an fgraph attached to it, case in which we want to clone
        constants (to avoid having a constant belonging to two fgraphs).

        """
        # This co-recurses with clone_a.
        assert v is not None
        if v in clone_d:
            return clone_d[v]
        if v.owner:
            # clone_a registers clones for every output of v.owner
            # (including v itself) in clone_d.
            clone_a(v.owner, copy_inputs_over)
            return clone_d.setdefault(v, v)
        elif isinstance(v, SharedVariable):
            if v not in shared_inputs:
                shared_inputs.append(v)
            if hasattr(v, 'default_update'):
                # Check that v should not be excluded from the default
                # updates list.
                if (no_default_updates is False or
                    (isinstance(no_default_updates, list) and
                     v not in no_default_updates)):
                    # Do not use default_update if a "real" update was
                    # provided.
                    if v not in update_d:
                        v_update = v.type.filter_variable(v.default_update,
                                                          allow_convert=False)
                        if v_update.type != v.type:
                            raise TypeError(
                                'an update must have the same type as '
                                'the original shared variable',
                                (v, v.type, v_update, v_update.type))
                        update_d[v] = v_update
                        update_expr.append((v, v_update))
        if not copy_inputs_over or (isinstance(v, Constant) and
                                    hasattr(v, 'fgraph')):
            # NOTE: the original authors state that cloning a shared
            # variable here does not copy its underlying memory buffer.
            return clone_d.setdefault(v, v.clone())
        else:
            return clone_d.setdefault(v, v)

    def clone_a(a, copy_inputs_over):
        """
        Clones the node `a` (a variable's `owner`) after recursively cloning
        its inputs, and records the node and its outputs in clone_d.
        It co-recurses with clone_v_get_shared_updates.

        """
        if a is None:
            return None
        if a not in clone_d:
            for i in a.inputs:
                clone_v_get_shared_updates(i, copy_inputs_over)

            clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in
                                                  a.inputs],
                                                 strict=rebuild_strict)
            for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
                clone_d.setdefault(old_o, new_o)
        return clone_d[a]

    # Initialize the clone_d mapping with the replace dictionary.
    if replace is None:
        replace = []
    try:
        replace_pairs = list(replace.items())
    except AttributeError:
        # Not a mapping: treat `replace` as an iterable of (orig, new) pairs.
        replace_pairs = replace

    for v_orig, v_repl in replace_pairs:
        if not isinstance(v_orig, Variable):
            raise TypeError('given keys must be Variable', v_orig)
        if not isinstance(v_repl, Variable):
            v_repl = shared(v_repl)

        if v_orig in clone_d:
            raise AssertionError(
                "When using 'givens' or 'replace' with several "
                "(old_v, new_v) replacement pairs, you can not have a "
                "new_v variable depend on an old_v one. For instance, "
                "givens = {a:b, b:(a+1)} is not allowed. Here, the old_v "
                "%s is used to compute other new_v's, but it is scheduled "
                "to be replaced by %s." % (v_orig, v_repl))

        clone_d[v_orig] = clone_v_get_shared_updates(v_repl,
                                                     copy_inputs_over)

    if inputs is None:
        inputs = []

    def clone_inputs(i):
        if not copy_inputs_over:
            return clone_d.setdefault(i, i.clone())
        else:
            return clone_d.setdefault(i, i)

    input_variables = [clone_inputs(i) for i in inputs]

    # It was decided, as a first step, to prevent shared variables from
    # being used as function inputs. Although it is technically possible,
    # it is also not clear when/how to use the value of that shared
    # variable (is it a default? ignored?, if the shared variable changes,
    # does that function default also change?).
    for v in input_variables:
        if isinstance(v, SharedVariable):
            raise TypeError(('Cannot use a shared variable (%s) as explicit '
                             'input. Consider substituting a non-shared'
                             ' variable via the `givens` parameter') % v)

    # Fill update_d and update_expr with provided updates.
    if updates is None:
        updates = []
    for (store_into, update_val) in iter_over_pairs(updates):
        if not isinstance(store_into, SharedVariable):
            raise TypeError('update target must be a SharedVariable',
                            store_into)
        if store_into in update_d:
            raise ValueError('this shared variable already has an update '
                             'expression',
                             (store_into, update_d[store_into]))

        # filter_variable ensure smooth conversion of cpu/gpu Types
        try:
            update_val = store_into.type.filter_variable(update_val,
                                                         allow_convert=False)
        except TypeError:
            # update_val is user-supplied and need not be a Variable, so it
            # may have no `.type`; do not let building the message raise an
            # AttributeError that would hide this TypeError.
            err_msg = ('An update must have the same type as the'
                       ' original shared variable (shared_var=%s,'
                       ' shared_var.type=%s,'
                       ' update_val=%s, update_val.type=%s).' % (
                           store_into,
                           store_into.type,
                           update_val,
                           getattr(update_val, 'type', None)))
            err_sug = ('If the difference is related to the broadcast pattern,'
                       ' you can call the'
                       ' tensor.unbroadcast(var, axis_to_unbroadcast[, ...])'
                       ' function to remove broadcastable dimensions.')

            raise TypeError(err_msg, err_sug)
        # Internal invariant: filter_variable succeeded, so the types match.
        assert update_val.type == store_into.type

        update_d[store_into] = update_val
        update_expr.append((store_into, update_val))

    # Elements of "outputs" are here cloned to "cloned_outputs".
    if isinstance(outputs, list):
        cloned_outputs = []
        for v in outputs:
            if isinstance(v, Variable):
                cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
                cloned_outputs.append(cloned_v)
            elif isinstance(v, Out):
                cloned_v = clone_v_get_shared_updates(v.variable,
                                                      copy_inputs_over)
                cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
            else:
                raise TypeError('Outputs must be theano Variable or '
                                'Out instances. Received ' + str(v) +
                                ' of type ' + str(type(v)))
    else:
        if isinstance(outputs, Variable):
            cloned_v = clone_v_get_shared_updates(outputs, copy_inputs_over)
            cloned_outputs = cloned_v
        elif isinstance(outputs, Out):
            cloned_v = clone_v_get_shared_updates(outputs.variable,
                                                  copy_inputs_over)
            cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
        elif outputs is None:
            cloned_outputs = []  # TODO: get Function.__call__ to return None
        else:
            raise TypeError('output must be a theano Variable or Out '
                            'instance (or list of them)',
                            outputs)

    # Iterate over update_expr, cloning its elements, and updating
    # shared_inputs, update_d and update_expr from the SharedVariables
    # we discover.
    # If the variable to be updated is a shared variable not already
    # in shared_inputs, add it.
    # Note: update_expr is extended while we iterate over it (cloning an
    # update can discover new shared variables with default updates), hence
    # the explicit index loop that re-reads len() every step.
    i = 0
    while i < len(update_expr):
        v, v_update = update_expr[i]
        cloned_v_update = clone_v_get_shared_updates(v_update,
                                                     copy_inputs_over)
        update_d[v] = cloned_v_update
        if isinstance(v, SharedVariable) and v not in shared_inputs:
            shared_inputs.append(v)
        i += 1

    return (input_variables, cloned_outputs,
            [clone_d, update_d, update_expr, shared_inputs])
예제 #2
0
        assert update_val.type == store_into.type

        update_d[store_into] = update_val
        update_expr.append((store_into, update_val))

    # Elements of "outputs" are here cloned to "cloned_outputs"
    if isinstance(outputs, list):
        cloned_outputs = []
        for v in outputs:
            if isinstance(v, Variable):
                cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
                cloned_outputs.append(cloned_v)
            elif isinstance(v, Out):
                cloned_v = clone_v_get_shared_updates(v.variable,
                                                      copy_inputs_over)
                cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
            else:
                raise TypeError('Outputs must be theano Variable or '
                                'Out instances. Received ' + str(v) +
                                ' of type ' + str(type(v)))
            #computed_list.append(cloned_v)
    else:
        if isinstance(outputs, Variable):
            cloned_v = clone_v_get_shared_updates(outputs, copy_inputs_over)
            cloned_outputs = cloned_v
            #computed_list.append(cloned_v)
        elif isinstance(outputs, Out):
            cloned_v = clone_v_get_shared_updates(outputs.variable,
                                                  copy_inputs_over)
            cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
            #computed_list.append(cloned_v)
예제 #3
0
파일: pfunc.py 프로젝트: thiboeri/Theano
def rebuild_collect_shared(outputs,
                           inputs=None,
                           replace=None,
                           updates=None,
                           rebuild_strict=True,
                           copy_inputs_over=True,
                           no_default_updates=False,
                           ):
    """
    Function that allows replacing subgraphs of a computational
    graph.

    It returns a set of dictionaries and lists which collect (partial?)
    different information about shared variables. This info is required by
    `pfunc`.

    :type outputs: Variable, Out, None, or list/tuple of Variable/Out
    :param outputs: Theano variables or expressions representing the
                    outputs of the computational graph. A tuple is treated
                    as a list.

    :type inputs: list of Theano Variables, or None
    :param inputs: Theano variables representing the inputs of the
                   computational graph. Shared variables are not allowed.

    :type replace: dict, iterable of (orig, new) pairs, or None
    :param replace: describes which subgraphs should be replaced by what.
                    Keys must be Variables; values that are not Variables
                    are wrapped with `shared`.

    :type updates: dict, iterable of (shared_var, expr) pairs, or None
    :param updates: update expressions for shared variables

    :type rebuild_strict: bool
    :param rebuild_strict: passed as `strict` to `clone_with_new_inputs`;
                           if true the type of all inputs should be the
                           same as the one for the current node

    :type copy_inputs_over: bool
    :param copy_inputs_over: flag; if False it will clone inputs (and
                             owner-less graph leaves)

    :type no_default_updates: either bool or list of Variables
    :param no_default_updates: if True, do not perform any automatic update
                               on Variables. If False (default), perform
                               them all. Else, perform automatic updates
                               on all Variables that are neither in
                               "updates" nor in "no_default_updates".

    :rtype: tuple
    :returns: ``(input_variables, cloned_outputs,
              [clone_d, update_d, update_expr, shared_inputs])``.
              update_d maps shared variables to their cloned update
              expressions; update_expr keeps the uncloned
              ``(shared_var, expr)`` pairs in collection order.

    :raises TypeError: for non-Variable replace keys, shared variables in
                       `inputs`, non-shared update targets, type-mismatched
                       updates, or outputs that are not Variable/Out.
    :raises ValueError: if a shared variable gets two update expressions.
    :raises AssertionError: if a replace key was already reached while
                            cloning an earlier replacement value.
    """

    if isinstance(outputs, tuple):
        outputs = list(outputs)

    # This function implements similar functionality as graph.clone
    # and it should be merged with that.
    clone_d = {}
    update_d = {}
    update_expr = []
    # List of shared inputs that are used as inputs of the graph.
    shared_inputs = []

    def clone_v_get_shared_updates(v, copy_inputs_over):
        '''
        Clones a variable and its inputs recursively until all are in
        clone_d. Also appends all shared variables met along the way to
        shared inputs, and their default_update (if applicable) to update_d
        and update_expr.

        v can have an env attached to it, case in which we want to clone
        constants (to avoid having a constant belonging to two envs).
        '''
        # This co-recurses with clone_a.
        assert v is not None
        if v in clone_d:
            return clone_d[v]
        if v.owner:
            # clone_a registers clones for every output of v.owner
            # (including v itself) in clone_d.
            clone_a(v.owner, copy_inputs_over)
            return clone_d.setdefault(v, v)
        elif isinstance(v, SharedVariable):
            if v not in shared_inputs:
                shared_inputs.append(v)
            if hasattr(v, 'default_update'):
                # Check that v should not be excluded from the default
                # updates list.
                if (no_default_updates is False or
                        (isinstance(no_default_updates, list) and
                         v not in no_default_updates)):
                    # Do not use default_update if a "real" update was
                    # provided.
                    if v not in update_d:
                        v_update = v.filter_update(v.default_update)
                        if v_update.type != v.type:
                            raise TypeError(
                                ('an update must have the same type as '
                                 'the original shared variable'),
                                (v, v.type, v_update, v_update.type))
                        update_d[v] = v_update
                        update_expr.append((v, v_update))
        if not copy_inputs_over or (isinstance(v, Constant) and
                                    hasattr(v, 'env')):
            # NOTE: the original authors state that cloning a shared
            # variable here does not copy its underlying memory buffer.
            return clone_d.setdefault(v, v.clone())
        else:
            return clone_d.setdefault(v, v)

    def clone_a(a, copy_inputs_over):
        '''
        Clones the node `a` (a variable's `owner`) after recursively cloning
        its inputs, and records the node and its outputs in clone_d.
        It co-recurses with clone_v_get_shared_updates.
        '''
        if a is None:
            return None
        if a not in clone_d:
            for i in a.inputs:
                clone_v_get_shared_updates(i, copy_inputs_over)

            clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in
                                                  a.inputs],
                                                 strict=rebuild_strict)
            for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
                clone_d.setdefault(old_o, new_o)
        return clone_d[a]

    # Initialize the clone_d mapping with the replace dictionary.
    if replace is None:
        replace = []
    try:
        replace_pairs = replace.items()
    except AttributeError:
        # Not a mapping: treat `replace` as an iterable of (orig, new) pairs.
        replace_pairs = replace

    for v_orig, v_repl in replace_pairs:
        if not isinstance(v_orig, Variable):
            raise TypeError('given keys must be Variable', v_orig)
        if not isinstance(v_repl, Variable):
            v_repl = shared(v_repl)
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # the exception type is unchanged.
        if v_orig in clone_d:
            raise AssertionError(
                "When using 'givens' or 'replace' with several "
                "(old_v, new_v) replacement pairs, you can not have a "
                "new_v variable depend on an old_v one. For instance, "
                "givens = {a:b, b:(a+1)} is not allowed. Here, the old_v "
                "%s is used to compute other new_v's, but it is scheduled "
                "to be replaced by %s." % (v_orig, v_repl))
        clone_d[v_orig] = clone_v_get_shared_updates(v_repl,
                                                     copy_inputs_over)

    if inputs is None:
        inputs = []

    def clone_inputs(i):
        if not copy_inputs_over:
            return clone_d.setdefault(i, i.clone())
        else:
            return clone_d.setdefault(i, i)

    input_variables = [clone_inputs(i) for i in inputs]

    # It was decided, as a first step, to prevent shared variables from
    # being used as function inputs. Although it is technically possible,
    # it is also not clear when/how to use the value of that shared
    # variable (is it a default? ignored?, if the shared variable changes,
    # does that function default also change?).
    # An explicit loop is used so the error names the offending variable
    # (a list-comprehension variable is not visible after the comprehension
    # in Python 3, and in Python 2 it would be the *last* input instead).
    for v in input_variables:
        if isinstance(v, SharedVariable):
            raise TypeError(('Cannot use a shared variable (%s) as explicit '
                             'input. Consider substituting a non-shared'
                             ' variable via the `givens` parameter') % v)

    # Fill update_d and update_expr with provided updates.
    if updates is None:
        updates = []
    for (store_into, update_val) in iter_over_pairs(updates):
        if not isinstance(store_into, SharedVariable):
            raise TypeError('update target must be a SharedVariable',
                            store_into)
        if store_into in update_d:
            raise ValueError(('this shared variable already has an update '
                              'expression'),
                             (store_into, update_d[store_into]))

        # Typically this might be a cast().
        update_val = store_into.filter_update(update_val)
        if update_val.type != store_into.type:
            err_msg = ('an update must have the same type as the '
                       'original shared variable(dest, dest.type, '
                       'update_val, update_val.type)')
            err_arg = (store_into,
                       store_into.type,
                       update_val,
                       update_val.type)

            raise TypeError(err_msg, err_arg)
        update_d[store_into] = update_val
        update_expr.append((store_into, update_val))

    # Elements of "outputs" are here cloned to "cloned_outputs".
    if isinstance(outputs, list):
        cloned_outputs = []
        for v in outputs:
            if isinstance(v, Variable):
                cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
                cloned_outputs.append(cloned_v)
            elif isinstance(v, Out):
                cloned_v = clone_v_get_shared_updates(v.variable,
                                                      copy_inputs_over)
                cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
            else:
                raise TypeError(('outputs must be theano Variable or '
                                 'Out instances'), v)
    else:
        if isinstance(outputs, Variable):
            cloned_v = clone_v_get_shared_updates(outputs, copy_inputs_over)
            cloned_outputs = cloned_v
        elif isinstance(outputs, Out):
            cloned_v = clone_v_get_shared_updates(outputs.variable,
                                                  copy_inputs_over)
            cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
        elif outputs is None:
            cloned_outputs = []  # TODO: get Function.__call__ to return None
        else:
            raise TypeError(('output must be a theano Variable or Out '
                             'instance (or list of them)'),
                            outputs)

    # Iterate over update_expr, cloning its elements, and updating
    # shared_inputs, update_d and update_expr from the SharedVariables
    # we discover.
    # If the variable to be updated is a shared variable not already
    # in shared_inputs, add it.
    # Note: update_expr is extended while we iterate over it, hence the
    # explicit index loop that re-reads len() every step.
    i = 0
    while i < len(update_expr):
        v, v_update = update_expr[i]
        cloned_v_update = clone_v_get_shared_updates(v_update,
                                                     copy_inputs_over)
        update_d[v] = cloned_v_update
        if isinstance(v, SharedVariable) and v not in shared_inputs:
            shared_inputs.append(v)
        i += 1

    return (input_variables, cloned_outputs,
            [clone_d, update_d, update_expr, shared_inputs])