Example #1
    def refresh_droot_impact(self):
        """
        Makes sure self.droot, self.impact, and self.root_destroyer are
        up to date, and returns them.
        (see docstrings for these properties above)
        """
        if self.stale_droot:
            droot = OrderedDict()   # destroyed view + nonview variables -> foundation
            impact = OrderedDict()  # destroyed nonview variable -> it + all views of it
            root_destroyer = OrderedDict()  # root -> destroyer apply

            for app in self.destroyers:
                for output_idx, input_idx_list in app.op.destroy_map.items():
                    if len(input_idx_list) != 1:
                        raise NotImplementedError()
                    input_idx = input_idx_list[0]
                    input = app.inputs[input_idx]
                    input_root = getroot(input, self.view_i)
                    if input_root in droot:
                        raise InconsistencyError("Multiple destroyers of %s" %
                                                 input_root)
                    droot[input_root] = input_root
                    root_destroyer[input_root] = app
                    input_impact = get_impact(input_root, self.view_o)
                    for v in input_impact:
                        assert v not in droot
                        droot[v] = input_root

                    impact[input_root] = input_impact
                    impact[input_root].add(input_root)
            self.droot, self.impact, self.root_destroyer = droot, impact, root_destroyer
            self.stale_droot = False
        return self.droot, self.impact, self.root_destroyer
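
refresh_droot_impact relies on two helpers, getroot and get_impact, that this example does not show. A plausible sketch, consistent with how they are called above (the exact implementations may differ):

    def getroot(r, view_i):
        # Follow the view chain in view_i until a variable with no view
        # ancestor (the "foundation") is reached.
        try:
            return getroot(view_i[r], view_i)
        except KeyError:
            return r

    def get_impact(root, view_o):
        # Collect every direct and indirect view of `root` by walking view_o.
        # Note: the result does not include `root` itself; the caller above
        # adds it explicitly.
        impact = OrderedSet()
        frontier = [root]
        while frontier:
            r = frontier.pop()
            for v in view_o.get(r, []):
                if v not in impact:
                    impact.add(v)
                    frontier.append(v)
        return impact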
Example #2
    def __init__(self, valid=None, invalid=None, valid_equivalent=None):
        '''
        Check if variables can be expressed without using variables in invalid.

        valid_equivalent provides a dictionary mapping some invalid
        variables to valid ones that can be used instead.
        '''

        if valid is None:
            valid = []
        if invalid is None:
            invalid = []
        if valid_equivalent is None:
            valid_equivalent = OrderedDict()

        # Nodes that are valid to have in the graph computing outputs
        self.valid = set(valid)

        # Nodes that are NOT valid to have in the graph computing outputs
        self.invalid = set(invalid)

        # Mapping from invalid variables to equivalent valid ones.
        self.valid_equivalent = valid_equivalent.copy()
        self.valid.update(valid_equivalent.values())
        self.invalid.update(valid_equivalent.keys())
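
A usage sketch of the constructor above. The enclosing class name (Validator here) does not appear in the snippet and is assumed for illustration, as are the variables a, b, and c:

    # a, b, c are assumed to be Theano variables.
    v = Validator(valid=[a], invalid=[b],
                  valid_equivalent=OrderedDict([(c, a)]))
    assert a in v.valid                 # explicitly valid
    assert c in v.invalid               # invalid, but...
    assert v.valid_equivalent[c] is a   # ...replaceable by a valid one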
Example #3
    def on_prune(self, fgraph, app, reason):
        """Remove Apply instance from set which must be computed"""
        if app not in self.debug_all_apps:
            raise ProtocolError("prune without import")
        self.debug_all_apps.remove(app)

        # UPDATE self.clients
        for input in OrderedSet(app.inputs):
            del self.clients[input][app]

        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.remove(app)

        # Note: we leave empty client dictionaries in the struct.
        # Why? It's a pain to remove them, and they do no harm: they will be
        # deleted by on_detach().

        # UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in getattr(app.op, 'view_map',
                                         OrderedDict()).items():
            if len(i_idx_list) > 1:
                # an output that is a view of multiple inputs is not supported
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]

            del self.view_i[o]

            self.view_o[i].remove(o)
            if not self.view_o[i]:
                del self.view_o[i]

        self.stale_droot = True
Example #4
    def on_import(self, fgraph, app, reason):
        """Add Apply instance to set which must be computed"""

        if app in self.debug_all_apps:
            raise ProtocolError("double import")
        self.debug_all_apps.add(app)

        # If it's a destructive op, add it to our watch list
        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.add(app)

        # add this symbol to the forward and backward maps
        for o_idx, i_idx_list in getattr(app.op, 'view_map',
                                         OrderedDict()).items():
            if len(i_idx_list) > 1:
                raise NotImplementedError(
                    'destroying this output invalidates multiple inputs',
                    (app.op))
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            self.view_i[o] = i
            self.view_o.setdefault(i, OrderedSet()).add(o)

        # update self.clients
        for input in app.inputs:
            self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
            self.clients[input][app] += 1

        for output in app.outputs:
            self.clients.setdefault(output, OrderedDict())

        self.stale_droot = True
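
Both branches above key off the same map shape: destroy_map and view_map each send an output index to a list of input indices. An illustration of the expected shapes, not tied to any real Op:

    destroy_map = {0: [0]}   # output 0 destroys (overwrites) input 0
    view_map = {0: [1]}      # output 0 is a view (alias) of input 1
    # The handler raises NotImplementedError whenever an index list
    # holds more than one input index.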
Example #5
    def __init__(self, valid=None, invalid=None, valid_equivalent=None):
        '''
        Check if variables can be expressed without using variables in invalid.

        valid_equivalent provides a dictionary mapping some invalid
        variables to valid ones that can be used instead.
        '''

        if valid is None:
            valid = []
        if invalid is None:
            invalid = []
        if valid_equivalent is None:
            valid_equivalent = OrderedDict()

        # Nodes that are valid to have in the graph computing outputs
        self.valid = set(valid)

        # Nodes that are NOT valid to have in the graph computing outputs
        self.invalid = set(invalid)

        # Mapping from invalid variables to equivalent valid ones.
        self.valid_equivalent = valid_equivalent.copy()
        self.valid.update(valid_equivalent.values())
        self.invalid.update(valid_equivalent.keys())
    def orderings(self):
        """
        Return dict d s.t. d[node] is a list of nodes that must be evaluated
        before node itself can be evaluated.

        This is used primarily by the destroy_handler feature to ensure that
        all clients of any destroyed inputs have already computed their
        outputs.

        :note: This only calls the orderings() fct on all features. It does not
               take care of computing dependencies by itself.

        """
        ords = OrderedDict()
        assert isinstance(self._features, list)
        for feature in self._features:
            if hasattr(feature, 'orderings'):
                orderings = feature.orderings(self)
                if not isinstance(orderings, OrderedDict):
                    raise TypeError("Non-deterministic return value from " +
                                    str(feature.orderings) +
                                    ". Nondeterministic object is " +
                                    str(orderings))
                for node, prereqs in orderings.items():
                    if not isinstance(prereqs, (list, OrderedSet)):
                        raise TypeError(
                            "prereqs must be a type with a "
                            "deterministic iteration order, or toposort "
                            "will be non-deterministic.")
                    ords.setdefault(node, []).extend(prereqs)
        # eliminate duplicate prereqs
        for (node, prereqs) in ords.items():
            ords[node] = list(OrderedSet(prereqs))
        return ords
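
For reference, a minimal sketch of a feature that satisfies the determinism contract enforced above; the attribute names are illustrative:

    class PrereqFeature(object):
        """Toy feature: records explicit prerequisites per node."""
        def __init__(self):
            self.prereqs = OrderedDict()  # node -> list of prerequisite nodes

        def orderings(self, fgraph):
            # Must return an OrderedDict of node -> list (or OrderedSet);
            # a plain dict or set would be rejected by the checks above.
            return OrderedDict((node, list(pre))
                               for node, pre in self.prereqs.items())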
    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that
            1) This DestroyHandler wasn't already attached to some fgraph
               (its data structures are only set up to serve one)
            2) The FunctionGraph doesn't already have a DestroyHandler.
               This would result in it validating everything twice, causing
               compilation to be slower.

        Give the FunctionGraph instance:
            1) A new method "destroyers(var)"
                TODO: what does this do exactly?
            2) A new attribute, "destroy_handler"
        TODO: WRITEME: what does this do besides the checks?
        """

        ####### Do the checking ###########
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception(
                "A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)"
            )
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present or in conflict with another plugin."
            )

        ####### Annotate the FunctionGraph ############

        def get_destroyers_of(r):
            droot, impact, root_destroyer = self.refresh_droot_impact()
            try:
                return [root_destroyer[droot[r]]]
            except Exception:
                return []

        fgraph.destroyers = get_destroyers_of
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        self.destroyers = OrderedSet()  # set of Apply instances with non-null destroy_map
        self.view_i = OrderedDict()  # variable -> variable used in calculation
        self.view_o = OrderedDict()  # variable -> set of variables that use this one as a direct input
        # clients: how many times does an apply use a given variable
        self.clients = OrderedDict()  # variable -> apply -> ninputs
        self.stale_droot = True

        self.debug_all_apps = OrderedSet()
        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)
Example #8
    class OrderedSet(object):
        """
        An implementation of OrderedSet based on the keys of
        an OrderedDict.
        """
        def __init__(self, iterable=None):
            self.data = OrderedDict()
            if iterable is not None:
                self.update(iterable)

        def update(self, container):
            check_deterministic(container)
            for elem in container:
                self.add(elem)

        def add(self, key):
            self.data[key] = None

        def __len__(self):
            return len(self.data)

        def __contains__(self, key):
            return key in self.data

        def discard(self, key):
            if key in self.data:
                del self.data[key]

        def remove(self, key):
            if key in self.data:
                del self.data[key]
            else:
                raise KeyError(key)

        def __iter__(self):
            return self.data.__iter__()

        def __reversed__(self):
            return self.data.__reversed__()

        def pop(self, last=True):
            raise NotImplementedError()

        def __eq__(self, other):
            # Note that we implement only the comparison to another
            # `OrderedSet`, and not to a regular `set`, because otherwise we
            # could have a non-symmetric equality relation like:
            #       my_ordered_set == my_set and my_set != my_ordered_set
            if isinstance(other, OrderedSet):
                return len(self) == len(other) and list(self) == list(other)
            elif isinstance(other, set):
                # Raise exception to avoid confusion.
                raise TypeError(
                        'Cannot compare an `OrderedSet` to a `set` because '
                        'this comparison cannot be made symmetric: please '
                        'manually cast your `OrderedSet` into `set` before '
                        'performing this comparison.')
            else:
                return NotImplemented
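
A short usage sketch of the class above, including the deliberate asymmetry guard in __eq__:

    s = OrderedSet([3, 1, 2])
    s.add(1)                       # re-adding keeps the original position
    assert list(s) == [3, 1, 2]    # insertion order is preserved
    s.discard(7)                   # silently ignores missing keys
    assert s == OrderedSet([3, 1, 2])
    try:
        s == set([1, 2, 3])        # comparing against a plain set raises
    except TypeError:
        pass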
Example #10
    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that
            1) This DestroyHandler wasn't already attached to some fgraph
               (its data structures are only set up to serve one)
            2) The FunctionGraph doesn't already have a DestroyHandler.
               This would result in it validating everything twice, causing
               compilation to be slower.

        Give the FunctionGraph instance:
            1) A new method "destroyers(var)"
                TODO: what does this do exactly?
            2) A new attribute, "destroy_handler"
        TODO: WRITEME: what does this do besides the checks?
        """

        ####### Do the checking ###########
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception("A DestroyHandler instance can only serve one" " FunctionGraph. (Matthew 6:24)")
        for attr in ("destroyers", "destroy_handler"):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present "
                "or in conflict with another plugin."
            )

        ####### Annotate the FunctionGraph ############

        def get_destroyers_of(r):
            droot, impact, root_destroyer = self.refresh_droot_impact()
            try:
                return [root_destroyer[droot[r]]]
            except Exception:
                return []

        fgraph.destroyers = get_destroyers_of
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        self.destroyers = OrderedSet()  # set of Apply instances with non-null destroy_map
        self.view_i = OrderedDict()  # variable -> variable used in calculation
        self.view_o = OrderedDict()  # variable -> set of variables that use this one as a direct input
        # clients: how many times does an apply use a given variable
        self.clients = OrderedDict()  # variable -> apply -> ninputs
        self.stale_droot = True

        self.debug_all_apps = OrderedSet()
        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)
Example #11
def forced_replace(out, x, y):
    """
    :param out: Theano Variable
    :param x: Theano Variable
    :param y: Theano Variable

    This function checks all internal values of the graph that computes the
    variable ``out`` for occurrences of values identical with ``x``. If such
    occurrences are encountered then they are replaced with variable ``y``.
    For example:
        out := sigmoid(wu)*(1-sigmoid(wu))
        x := sigmoid(wu)
        forced_replace(out, x, y) := y*(1-y)
    """
    if out is None:
        return None

    # ``visited`` is a set of nodes that are already known and don't need to be
    # checked again, speeding up the traversal of multiply-connected graphs.
    visited = set()
    def local_traverse(graph, x):
        if graph in visited:
            return []
        visited.add(graph)
        if equal_computations([graph], [x]):
            return [graph]
        elif not graph.owner:
            return []
        else:
            rval = []
            for inp in graph.owner.inputs:
                rval += local_traverse(inp, x)
            return rval
    to_replace = local_traverse(out, x)
    return clone(out, replace=OrderedDict((v, y) for v in to_replace))
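
A sketch of the docstring example; wu and y are assumed to be freshly created tensor variables:

    import theano.tensor as TT

    wu = TT.vector('wu')
    y = TT.vector('y')
    out = TT.nnet.sigmoid(wu) * (1 - TT.nnet.sigmoid(wu))
    # Replace every subgraph structurally equal to sigmoid(wu) with y:
    new_out = forced_replace(out, TT.nnet.sigmoid(wu), y)  # computes y * (1 - y)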
Example #12
def forced_replace(out, x, y):
    """
    :param out: Theano Variable
    :param x: Theano Variable
    :param y: Theano Variable

    This function checks all internal values of the graph that computes the
    variable ``out`` for occurrences of values identical with ``x``. If such
    occurrences are encountered then they are replaced with variable ``y``.
    For example:
        out := sigmoid(wu)*(1-sigmoid(wu))
        x := sigmoid(wu)
        forced_replace(out, x, y) := y*(1-y)
    """
    if out is None:
        return None

    def traverse(graph, x):
        if equal_computations([graph], [x]):
            return [graph]
        elif not graph.owner:
            return []
        else:
            rval = []
            for inp in graph.owner.inputs:
                rval += traverse(inp, x)
            return rval

    to_replace = traverse(out, x)
    return clone(out, replace=OrderedDict((v, y) for v in to_replace))
Example #13
    def __init__(self, do_imports_on_attach=True):
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach

        """maps every variable in the graph to its "foundation" (deepest
        ancestor in view chain)
        TODO: change name to var_to_vroot"""
        self.droot = OrderedDict()

        """maps a variable to all variables that are indirect or direct views of it
        (including itself)
        essentially the inverse of droot
        TODO: do all variables appear in this dict, or only those that are foundations?
        TODO: do only destroyed variables go in here? one old docstring said so
        TODO: rename to x_to_views after reverse engineering what x is"""
        self.impact = OrderedDict()

        """if a var is destroyed, then this dict will map
        droot[var] to the apply node that destroyed var
        TODO: rename to vroot_to_destroyer"""
        self.root_destroyer = OrderedDict()
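
To make the three attribute docstrings concrete, here is how the maps would look for a small view chain (illustrative only, not runnable against a real graph):

    # Suppose a --view--> b --view--> c, and apply node `app` destroys c:
    #   droot          == {a: a, b: a, c: a}   # every variable -> its foundation
    #   impact[a]      == {a, b, c}            # the foundation plus all its views
    #   root_destroyer == {a: app}             # foundation -> destroying apply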
Example #14
    def orderings(self, function_graph):
        """
        Called by toposort. It should return a dictionary of
        {node: predecessors} where predecessors is a list of
        nodes that should be computed before the key node.

        If you raise an exception in this function, the state of the graph
        might be broken for all intents and purposes.
        """
        return OrderedDict()
Example #15
    def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
        """app.inputs[i] changed from old_r to new_r """
        if app == 'output':
            # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
            # considered 'outputs' of the graph.
            pass
        else:
            if app not in self.debug_all_apps:
                raise ProtocolError("change without import")

            # UPDATE self.clients
            self.clients[old_r][app] -= 1
            if self.clients[old_r][app] == 0:
                del self.clients[old_r][app]

            self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0)
            self.clients[new_r][app] += 1

            # UPDATE self.view_i, self.view_o
            for o_idx, i_idx_list in getattr(app.op, 'view_map',
                                             OrderedDict()).items():
                if len(i_idx_list) > 1:
                    # an output that is a view of multiple inputs is not supported
                    raise NotImplementedError()
                i_idx = i_idx_list[0]
                output = app.outputs[o_idx]
                if i_idx == i:
                    if app.inputs[i_idx] is not new_r:
                        raise ProtocolError("wrong new_r on change")

                    self.view_i[output] = new_r

                    self.view_o[old_r].remove(output)
                    if not self.view_o[old_r]:
                        del self.view_o[old_r]

                    self.view_o.setdefault(new_r, OrderedSet()).add(output)

        self.stale_droot = True
Example #16
    def run(replay, log=None):

        if not replay:
            log = StringIO()
        else:
            log = StringIO(log)
        record = Record(replay=replay, file_object=log)

        disturb_mem.disturb_mem()

        mode = RecordMode(record=record)

        b = sharedX(np.zeros((2,)), name='b')
        channels = OrderedDict()

        disturb_mem.disturb_mem()

        v_max = b.max(axis=0)
        v_min = b.min(axis=0)
        v_range = v_max - v_min

        updates = []
        for i, val in enumerate([
                v_max.max(),
                v_max.min(),
                v_range.max(),
                ]):
            disturb_mem.disturb_mem()
            s = sharedX(0., name='s_' + str(i))
            updates.append((s, val))

        for var in theano.gof.graph.ancestors(update for _, update in updates):
            if var.name is not None and var.name != 'b':
                if var.name[0] != 's' or len(var.name) != 2:
                    var.name = None

        for key in channels:
            updates.append((s, channels[key]))
        f = theano.function([], mode=mode, updates=updates,
                            on_unused_input='ignore', name='f')
        for output in f.maker.fgraph.outputs:
            mode.record.handle_line(var_descriptor(output) + '\n')
        disturb_mem.disturb_mem()
        f()

        mode.record.f.flush()

        if not replay:
            return log.getvalue()
Example #17
    def update(self, other=None):
        if other is None:
            return
        if (isinstance(other, dict) and len(other) > 1
                and not isinstance(other, OrderedDict)):
            # Warn about non-determinism.
            warnings.warn(
                'Updating an `OrderedUpdates` with a '
                'non-ordered dictionary with 2+ elements could '
                'make your code non-deterministic',
                stacklevel=2)
        for key, val in OrderedDict(other).iteritems():
            if key in self:
                if self[key] == val:
                    continue
                raise KeyError('Collision', key)
            self[key] = val  # __setitem__ does type-checking
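
Assuming the enclosing class is the dict-like OrderedUpdates, a sketch of the semantics above; s1, s2, v1, and v2 are assumed shared variables and update expressions:

    u = OrderedUpdates()
    u.update(OrderedDict([(s1, v1), (s2, v2)]))  # ordered: no warning
    u.update({s1: v1})      # same value for an existing key: silently skipped
    try:
        u.update({s1: v2})  # conflicting value: KeyError('Collision', s1)
    except KeyError:
        pass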
Example #18
def test_subgraph_grad():

    # Tests that the grad method with no known_grads
    # matches what happens if you use successive subgraph_grads

    x = theano.tensor.fvector('x')
    t = theano.tensor.fvector('t')
    w1 = theano.shared(np.random.randn(3, 4))
    w2 = theano.shared(np.random.randn(4, 2))
    a1 = theano.tensor.tanh(theano.tensor.dot(x, w1))
    a2 = theano.tensor.tanh(theano.tensor.dot(a1, w2))
    cost2 = theano.tensor.sqr(a2 - t).sum()
    cost2 += theano.tensor.sqr(w2.sum())
    cost1 = theano.tensor.sqr(w1.sum())

    params = [[w2], [w1]]
    costs = [cost2, cost1]
    grad_ends = [[a1], [x]]

    inputs = [t, x]
    rng = np.random.RandomState([2012, 11, 15])
    values = [rng.randn(2), rng.randn(3)]
    values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)]

    wrt = [w2, w1]
    cost = cost2 + cost1
    true_grads = theano.grad(cost, wrt)
    true_grads = theano.function(inputs, true_grads)
    true_grads = true_grads(*values)
    from theano.gof.python25 import OrderedDict
    next_grad = None
    param_grads = []
    for i in xrange(2):
        param_grad, next_grad = theano.subgraph_grad(wrt=params[i],
                                                     end=grad_ends[i],
                                                     start=next_grad,
                                                     cost=costs[i])
        next_grad = OrderedDict(zip(grad_ends[i], next_grad))
        param_grads.extend(param_grad)

    pgrads = theano.function(inputs, param_grads)
    pgrads = pgrads(*values)

    for true_grad, pgrad in zip(true_grads, pgrads):
        assert (np.sum(np.abs(true_grad - pgrad)) < 0.00001)
Example #19
def reconstruct_graph(inputs, outputs, tag=None):
    """
    A different interface to clone that allows you to pass inputs.
    Compared to clone, this method always replaces the inputs with
    new variables of the same type, and returns those (in the same
    order as the original inputs).
    """
    if tag is None:
        tag = ''
    nw_inputs = [safe_new(x, tag) for x in inputs]
    givens = OrderedDict()
    for nw_x, x in izip(nw_inputs, inputs):
        givens[x] = nw_x
    allinputs = theano.gof.graph.inputs(outputs)
    for inp in allinputs:
        if isinstance(inp, theano.Constant):
            givens[inp] = inp.clone()

    nw_outputs = clone(outputs, replace=givens)
    return (nw_inputs, nw_outputs)
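
A hedged usage sketch: cloning a one-op graph with fresh inputs.

    import theano
    import theano.tensor as TT

    x = TT.scalar('x')
    nw_inputs, nw_outputs = reconstruct_graph([x], [x * 2], tag='_copy')
    # nw_inputs[0] is a fresh scalar standing in for x; nw_outputs
    # recompute the original outputs in terms of it.
    f = theano.function(nw_inputs, nw_outputs)
    assert f(3.0)[0] == 6.0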
Example #20
    def __init__(self, outer_inputs, outer_outputs, _inner_inputs,
                 _inner_outputs, info):
        self.n_steps = outer_inputs[0]
        rval = reconstruct_graph(_inner_inputs, _inner_outputs, '_merge')
        if info['as_while']:
            self.cond = [rval[1][-1]]
            inner_outputs = rval[1][:-1]
        else:
            inner_outputs = rval[1]
        inner_inputs = rval[0]

        p = 1
        q = 0

        n_seqs = info['n_seqs']
        self.outer_in_seqs = outer_inputs[p:p + n_seqs]
        self.inner_in_seqs = inner_inputs[q:q + n_seqs]
        p += n_seqs
        q += n_seqs

        n_mit_mot = info['n_mit_mot']
        n_mit_sot = info['n_mit_sot']

        self.mit_mot_in_slices = info['tap_array'][:n_mit_mot]
        self.mit_sot_in_slices = info['tap_array'][n_mit_mot:n_mit_mot +
                                                   n_mit_sot]

        n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
        n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)

        iimm = inner_inputs[q:q + n_mit_mot_ins]
        self.inner_in_mit_mot = []
        qq = 0
        for sl in self.mit_mot_in_slices:
            self.inner_in_mit_mot.append(iimm[qq:qq + len(sl)])
            qq += len(sl)
        q += n_mit_mot_ins

        iims = inner_inputs[q:q + n_mit_sot_ins]
        self.inner_in_mit_sot = []
        qq = 0
        for sl in self.mit_sot_in_slices:
            self.inner_in_mit_sot.append(iims[qq:qq + len(sl)])
            qq += len(sl)
        q += n_mit_sot_ins

        self.outer_in_mit_mot = outer_inputs[p:p + n_mit_mot]
        p += n_mit_mot
        self.outer_in_mit_sot = outer_inputs[p:p + n_mit_sot]
        p += n_mit_sot

        n_sit_sot = info['n_sit_sot']
        self.outer_in_sit_sot = outer_inputs[p:p + n_sit_sot]
        self.inner_in_sit_sot = inner_inputs[q:q + n_sit_sot]
        p += n_sit_sot
        q += n_sit_sot

        n_shared_outs = info['n_shared_outs']
        self.outer_in_shared = outer_inputs[p:p + n_shared_outs]
        self.inner_in_shared = inner_inputs[q:q + n_shared_outs]
        p += n_shared_outs
        q += n_shared_outs

        n_nit_sot = info['n_nit_sot']
        self.outer_in_nit_sot = outer_inputs[p:p + n_nit_sot]
        p += n_nit_sot

        self.outer_in_non_seqs = outer_inputs[p:]
        self.inner_in_non_seqs = inner_inputs[q:]

        # now for the outputs
        p = 0
        q = 0

        self.mit_mot_out_slices = info['mit_mot_out_slices']
        n_mit_mot_outs = info['n_mit_mot_outs']
        self.outer_out_mit_mot = outer_outputs[p:p + n_mit_mot]
        iomm = inner_outputs[q:q + n_mit_mot_outs]
        self.inner_out_mit_mot = []
        qq = 0
        for sl in self.mit_mot_out_slices:
            self.inner_out_mit_mot.append(iomm[qq:qq + len(sl)])
            qq += len(sl)
        p += n_mit_mot
        q += n_mit_mot_outs

        self.outer_out_mit_sot = outer_outputs[p:p + n_mit_sot]
        self.inner_out_mit_sot = inner_outputs[q:q + n_mit_sot]
        p += n_mit_sot
        q += n_mit_sot

        self.outer_out_sit_sot = outer_outputs[p:p + n_sit_sot]
        self.inner_out_sit_sot = inner_outputs[q:q + n_sit_sot]
        p += n_sit_sot
        q += n_sit_sot

        self.outer_out_nit_sot = outer_outputs[p:p + n_nit_sot]
        self.inner_out_nit_sot = inner_outputs[q:q + n_nit_sot]
        p += n_nit_sot
        q += n_nit_sot

        self.outer_out_shared = outer_outputs[p:p + n_shared_outs]
        self.inner_out_shared = inner_outputs[q:q + n_shared_outs]
        p += n_shared_outs
        q += n_shared_outs

        self.other_info = OrderedDict()
        for k in ('truncate_gradient', 'name', 'mode', 'destroy_map', 'gpu',
                  'as_while', 'profile'):
            if k in info:
                self.other_info[k] = info[k]
Example #21
class scan_args(object):
    """Parses the inputs and outputs of scan in an easy to manipulate format"""
    def __init__(self, outer_inputs, outer_outputs, _inner_inputs,
                 _inner_outputs, info):
        self.n_steps = outer_inputs[0]
        rval = reconstruct_graph(_inner_inputs, _inner_outputs, '_merge')
        if info['as_while']:
            self.cond = [rval[1][-1]]
            inner_outputs = rval[1][:-1]
        else:
            inner_outputs = rval[1]
        inner_inputs = rval[0]

        p = 1
        q = 0

        n_seqs = info['n_seqs']
        self.outer_in_seqs = outer_inputs[p:p + n_seqs]
        self.inner_in_seqs = inner_inputs[q:q + n_seqs]
        p += n_seqs
        q += n_seqs

        n_mit_mot = info['n_mit_mot']
        n_mit_sot = info['n_mit_sot']

        self.mit_mot_in_slices = info['tap_array'][:n_mit_mot]
        self.mit_sot_in_slices = info['tap_array'][n_mit_mot:n_mit_mot +
                                                   n_mit_sot]

        n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
        n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)

        iimm = inner_inputs[q:q + n_mit_mot_ins]
        self.inner_in_mit_mot = []
        qq = 0
        for sl in self.mit_mot_in_slices:
            self.inner_in_mit_mot.append(iimm[qq:qq + len(sl)])
            qq += len(sl)
        q += n_mit_mot_ins

        iims = inner_inputs[q:q + n_mit_sot_ins]
        self.inner_in_mit_sot = []
        qq = 0
        for sl in self.mit_sot_in_slices:
            self.inner_in_mit_sot.append(iims[qq:qq + len(sl)])
            qq += len(sl)
        q += n_mit_sot_ins

        self.outer_in_mit_mot = outer_inputs[p:p + n_mit_mot]
        p += n_mit_mot
        self.outer_in_mit_sot = outer_inputs[p:p + n_mit_sot]
        p += n_mit_sot

        n_sit_sot = info['n_sit_sot']
        self.outer_in_sit_sot = outer_inputs[p:p + n_sit_sot]
        self.inner_in_sit_sot = inner_inputs[q:q + n_sit_sot]
        p += n_sit_sot
        q += n_sit_sot

        n_shared_outs = info['n_shared_outs']
        self.outer_in_shared = outer_inputs[p:p + n_shared_outs]
        self.inner_in_shared = inner_inputs[q:q + n_shared_outs]
        p += n_shared_outs
        q += n_shared_outs

        n_nit_sot = info['n_nit_sot']
        self.outer_in_nit_sot = outer_inputs[p:p + n_nit_sot]
        p += n_nit_sot

        self.outer_in_non_seqs = outer_inputs[p:]
        self.inner_in_non_seqs = inner_inputs[q:]

        # now for the outputs
        p = 0
        q = 0

        self.mit_mot_out_slices = info['mit_mot_out_slices']
        n_mit_mot_outs = info['n_mit_mot_outs']
        self.outer_out_mit_mot = outer_outputs[p:p + n_mit_mot]
        iomm = inner_outputs[q:q + n_mit_mot_outs]
        self.inner_out_mit_mot = []
        qq = 0
        for sl in self.mit_mot_out_slices:
            self.inner_out_mit_mot.append(iomm[qq:qq + len(sl)])
            qq += len(sl)
        p += n_mit_mot
        q += n_mit_mot_outs

        self.outer_out_mit_sot = outer_outputs[p:p + n_mit_sot]
        self.inner_out_mit_sot = inner_outputs[q:q + n_mit_sot]
        p += n_mit_sot
        q += n_mit_sot

        self.outer_out_sit_sot = outer_outputs[p:p + n_sit_sot]
        self.inner_out_sit_sot = inner_outputs[q:q + n_sit_sot]
        p += n_sit_sot
        q += n_sit_sot

        self.outer_out_nit_sot = outer_outputs[p:p + n_nit_sot]
        self.inner_out_nit_sot = inner_outputs[q:q + n_nit_sot]
        p += n_nit_sot
        q += n_nit_sot

        self.outer_out_shared = outer_outputs[p:p + n_shared_outs]
        self.inner_out_shared = inner_outputs[q:q + n_shared_outs]
        p += n_shared_outs
        q += n_shared_outs

        self.other_info = OrderedDict()
        for k in ('truncate_gradient', 'name', 'mode', 'destroy_map', 'gpu',
                  'as_while', 'profile'):
            if k in info:
                self.other_info[k] = info[k]

    inner_inputs = property(lambda self: (
        self.inner_in_seqs +
        sum(self.inner_in_mit_mot, []) +
        sum(self.inner_in_mit_sot, []) +
        self.inner_in_sit_sot +
        self.inner_in_shared +
        self.inner_in_non_seqs))

    outer_inputs = property(lambda self: (
        [self.n_steps] +
        self.outer_in_seqs +
        self.outer_in_mit_mot +
        self.outer_in_mit_sot +
        self.outer_in_sit_sot +
        self.outer_in_shared +
        self.outer_in_nit_sot +
        self.outer_in_non_seqs))

    inner_outputs = property(lambda self: (
        sum(self.inner_out_mit_mot, []) +
        self.inner_out_mit_sot +
        self.inner_out_sit_sot +
        self.inner_out_nit_sot +
        self.inner_out_shared))

    outer_outputs = property(lambda self: (
        self.outer_out_mit_mot +
        self.outer_out_mit_sot +
        self.outer_out_sit_sot +
        self.outer_out_nit_sot +
        self.outer_out_shared))

    info = property(lambda self: OrderedDict(
        n_seqs=len(self.outer_in_seqs),
        n_mit_mot=len(self.outer_in_mit_mot),
        n_mit_sot=len(self.outer_in_mit_sot),
        tap_array=(self.mit_mot_in_slices + self.mit_sot_in_slices + [[-1]] *
                   len(self.inner_in_sit_sot)),
        n_sit_sot=len(self.outer_in_sit_sot),
        n_nit_sot=len(self.outer_in_nit_sot),
        n_shared_outs=len(self.outer_in_shared),
        n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
        mit_mot_out_slices=self.mit_mot_out_slices,
        **self.other_info))

    def __copy__(self):
        res = object.__new__(type(self))
        res.__dict__.update(self.__dict__)
        # also copy mutable attrs
        for attr in self.__dict__:
            if (attr.startswith('inner_in') or attr.startswith('inner_out')
                    or attr.startswith('outer_in')
                    or attr.startswith('outer_out')
                    or attr in ('mit_mot_out_slices', 'mit_mot_in_slices',
                                'mit_sot_in_slices', 'other_info')):
                setattr(res, attr, copy.copy(getattr(self, attr)))
        return res

    def merge(self, other):
        res = copy.copy(self)
        for attr in self.__dict__:
            if (attr.startswith('inner_in') or attr.startswith('inner_out')
                    or attr.startswith('outer_in')
                    or attr.startswith('outer_out')
                    or attr in ('mit_mot_out_slices', 'mit_mot_in_slices',
                                'mit_sot_in_slices')):
                getattr(res, attr).extend(getattr(other, attr))
        return res
Example #22
def compress_outs(op, not_required, inputs):
    '''
    Helper function that takes a Scan op, a list of indices indicating
    which outputs are no longer required and should be removed, and a
    list of inputs to the apply node corresponding to the scan op. It
    produces the list of inputs and outputs and the info dictionary with
    the indicated outputs eliminated. Note that eliminating an output
    means removing its inputs from the inner function and from the
    node inputs, and updating the dictionary accordingly.
    '''
    info = OrderedDict()
    info['tap_array'] = []
    info['n_seqs'] = op.info['n_seqs']
    info['n_mit_mot'] = 0
    info['n_mit_mot_outs'] = 0
    info['mit_mot_out_slices'] = []
    info['n_mit_sot'] = 0
    info['n_sit_sot'] = 0
    info['n_shared_outs'] = 0
    info['n_nit_sot'] = 0
    info['truncate_gradient'] = op.info['truncate_gradient']
    info['name'] = op.info['name']
    info['gpu'] = op.info['gpu']
    info['mode'] = op.info['mode']
    info['as_while'] = op.info['as_while']
    info['profile'] = op.info['profile']

    op_inputs = op.inputs[:op.n_seqs]
    op_outputs = []
    node_inputs = inputs[:op.n_seqs + 1]
    map_old_new = OrderedDict()

    offset = 0
    ni_offset = op.n_seqs + 1
    i_offset = op.n_seqs
    o_offset = 0
    curr_pos = 0
    for idx in xrange(op.info['n_mit_mot']):
        if offset + idx not in not_required:
            map_old_new[offset + idx] = curr_pos
            curr_pos += 1
            info['n_mit_mot'] += 1
            info['tap_array'] += [op.tap_array[offset + idx]]
            info['mit_mot_out_slices'] += [op.mit_mot_out_slices[offset + idx]]
            # input taps
            for jdx in op.tap_array[offset + idx]:
                op_inputs += [op.inputs[i_offset]]
                i_offset += 1
            # output taps
            for jdx in op.mit_mot_out_slices[offset + idx]:
                op_outputs += [op.outputs[o_offset]]
                o_offset += 1
            # node inputs
            node_inputs += [inputs[ni_offset + idx]]
        else:
            o_offset += len(op.mit_mot_out_slices[offset + idx])
            i_offset += len(op.tap_array[offset + idx])
    info['n_mit_mot_outs'] = len(op_outputs)
    offset += op.n_mit_mot
    ni_offset += op.n_mit_mot

    for idx in xrange(op.info['n_mit_sot']):
        if offset + idx not in not_required:
            map_old_new[offset + idx] = curr_pos
            curr_pos += 1
            info['n_mit_sot'] += 1
            info['tap_array'] += [op.tap_array[offset + idx]]
            # input taps
            for jdx in op.tap_array[offset + idx]:
                op_inputs += [op.inputs[i_offset]]
                i_offset += 1
            # output taps
            op_outputs += [op.outputs[o_offset]]
            o_offset += 1
            # node inputs
            node_inputs += [inputs[ni_offset + idx]]
        else:
            o_offset += 1
            i_offset += len(op.tap_array[offset + idx])

    offset += op.n_mit_sot
    ni_offset += op.n_mit_sot
    for idx in xrange(op.info['n_sit_sot']):
        if offset + idx not in not_required:
            map_old_new[offset + idx] = curr_pos
            curr_pos += 1
            info['n_sit_sot'] += 1
            info['tap_array'] += [op.tap_array[offset + idx]]
            # input taps
            op_inputs += [op.inputs[i_offset]]
            i_offset += 1
            # output taps
            op_outputs += [op.outputs[o_offset]]
            o_offset += 1
            # node inputs
            node_inputs += [inputs[ni_offset + idx]]
        else:
            o_offset += 1
            i_offset += 1

    offset += op.n_sit_sot
    ni_offset += op.n_sit_sot
    nit_sot_ins = []
    for idx in xrange(op.info['n_nit_sot']):
        if offset + idx not in not_required:
            map_old_new[offset + idx] = curr_pos
            curr_pos += 1
            info['n_nit_sot'] += 1
            op_outputs += [op.outputs[o_offset]]
            o_offset += 1
            nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
        else:
            o_offset += 1

    offset += op.n_nit_sot
    shared_ins = []
    for idx in xrange(op.info['n_shared_outs']):
        if offset + idx not in not_required:
            map_old_new[offset + idx] = curr_pos
            curr_pos += 1
            info['n_shared_outs'] += 1
            op_outputs += [op.outputs[o_offset]]
            o_offset += 1
            op_inputs += [op.inputs[i_offset]]
            i_offset += 1
            shared_ins += [inputs[ni_offset + idx]]
        else:
            o_offset += 1
            i_offset += 1
    node_inputs += shared_ins
    node_inputs += nit_sot_ins
    # other stuff
    op_inputs += op.inputs[i_offset:]
    node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot:]
    if op.as_while:
        op_outputs += [op.outputs[o_offset]]
        map_old_new[o_offset] = len(op_outputs) - 1

    return (op_inputs, op_outputs, info, node_inputs, map_old_new)
Example #23
def scan(fn,
         sequences=None,
         states=None,
         params=None,
         n_steps=None,
         mode=None,
         name=None,
         profile=False,
         allow_gc=None):
    """
    Similar to Theano's official scan, this function gives the user more
    control over the scan op, avoiding certain difficulties that arose from
    missing optimizations.

    :param fn: lambda function that describes one step of scan (see the
        official Theano scan function)
    :param sequences: similar to the official Theano's scan. This version
        of scan does not support taps for the sequences (it can only be a
        list of tensor). Scan assumes that sequences have the right length
        and it does not check for this.
    :param states: similar to outputs_info of the official scan function.
        There is one crucial difference though, namely that the `initial`
        key in the dictionary has been replaced by the 'membuf' key. This
        reflects the change of meaning. Instead of passing to scan just
        the missing initial steps, one now has to pass a memory buffer in
        which scan will try to store its output. In this memory buffer the
        first entries should be set to the initial states of the
        corresponding states.
        Providing a memory buffer that has fewer entries than the number of
        steps means scan will only use that amount of memory. The user has
        to match the memory buffer size with the number of steps, otherwise
        scan will produce wrong results. Also if gradients are to be
        computed through the scan, the memory buffer should have the same
        length as the number of steps.
        For states that do not require an initial state, one has to provide a
        dictionary with a single key 'steps' that says how many intermediate
        results to store. See examples below for more insight.
    :param n_steps: This parameter is mandatory and it will represent the
        number of steps scan will do (scan will not check sequences or any
        other source of information to figure out how many steps it needs
        to do).
    :param mode: Same as for the official scan
    :param name: Same as for the official scan
    :param profile: Same as for the official scan

    Note:
     - there is no truncate / go_backwards anymore!
     - the outputs returned by scan contain the initial states as well (i.e.
     if I loop over k steps, with my smallest tap for an output -3 and keep
     all intermediate results, my output will be of length k+3)

     Examples:
         (a) if you do not want to store any intermediate results (just the
         last one)

         # The memory buffer can be the initial state, just that we need to
         # add one extra dimension in front of it
         state = TT.unbroadcast(TT.shape_padleft(x0),0)
         out,_ = scan(lambda x:x+1, states = state, n_steps = 5)
         # Once we got our result we need to remove the extra dimension
         out = out[0]

        (b) if you want to keep every intermediate results

        state = TT.alloc(TT.constant(0), 6, x0.shape[0])
        state = TT.set_subtensor(state[0], x0)
        out,_ = scan(lambda x:x+1, states = state, n_steps = 5)
        out = out[1:]

    """
    def wrap_into_list(x):
        '''
        Wrap the input into a list if it is not already a list
        '''
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(states)
    if allow_gc is None:
        allow_gc = config.scan.allow_gc

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(params):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps (before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once.
    # To do that we check here to see the nature of n_steps.
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype')
            and str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError('n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = dict(membuf=outs_info[i], taps=[-1])
            elif (not outs_info[i].get('membuf', None)
                  and outs_info[i].get('taps', None)):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide a memory buffer for '
                                  'the state '), outs_info[i])
            elif (outs_info[i].get('membuf', None)
                  and not outs_info[i].get('taps', None)):
                # ^ initial state but taps not provided
                if 'taps' in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        'Output %s (index %d) has a memory '
                        'buffer but taps is explicitly set to None ',
                        getattr(outs_info[i]['membuf'], 'name', 'None'), i)
                outs_info[i]['taps'] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with a dict(steps=n_steps) to simplify handling
            outs_info[i] = dict(steps=n_steps)

    ##
    ###   Step 2. Generate inputs and outputs of the inner functions
    ###           for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we used this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []  # Variables passed as inputs to the scan op
    inner_seqs = []  # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        if isinstance(seq, dict):
            seq = seq['input']
        actual_slice = seq[0]
        _seq_val = tensor.as_tensor_variable(seq)
        _seq_val_slice = _seq_val[0]

        nw_slice = _seq_val_slice.type()
        # Try to transfer test_value to the new variable
        if config.compute_test_value != 'off':
            try:
                nw_slice.tag.test_value = gof.Op._get_test_value(
                    _seq_val_slice)
            except AttributeError, e:
                if config.compute_test_value != 'ignore':
                    # No need to print a warning or raise an error now,
                    # it will be done when fn will be called.
                    _logger.info(('Cannot compute test value for '
                                  'the inner function of scan, input value '
                                  'missing %s'), e)

        if seq.name:
            nw_slice.name = seq.name + '[t]'
        scan_seqs.append(_seq_val)
        inner_seqs.append(nw_slice)
        inner_slices.append(actual_slice)

        n_seqs += 1
Beispiel #24
0
def get_updates_and_outputs(ls):
    """
    This function tries to recognize the updates OrderedDict, the
    list of outputs and the stopping condition returned by the
    lambda expression and arrange them in a predefined order

    WRITEME: what is the type of ls? how is it formatted?
            if it's not in the predefined order already, how does
            this function know how to put it in that order?

    """
    def is_outputs(elem):
        if (isinstance(elem, (list, tuple))
                and all([isinstance(x, theano.Variable) for x in elem])):
            return True
        if isinstance(elem, theano.Variable):
            return True
        return False

    def is_updates(elem):
        if isinstance(elem, dict):
            # Make sure the updates will be applied in a deterministic order
            if (not isinstance(elem, gof.python25.OrderedDict)
                    and len(elem) > 1):
                warnings.warn("Expected OrderedDict or OrderedUpdates, got "\
                        + str(type(elem)) + ". This can make your script non-"
                        "deterministic.")
            return True
        # Dictionaries can be given as lists of tuples
        if (isinstance(elem, (list, tuple)) and all(
            [isinstance(x, (list, tuple)) and len(x) == 2 for x in elem])):
            return True
        return False

    def is_condition(elem):
        return isinstance(elem, theano.scan_module.until)

    def _list(x):
        if isinstance(x, (list, tuple)):
            return list(x)
        else:
            return [x]

    def filter(x):
        """
        Ensure `x` is made only of allowed data types.

        Return True iff `x` is made only of lists, tuples, dictionaries, Theano
        variables or `theano.scan_module.until` objects.
        """
        # Is `x` a container we can iterate on?
        iter_on = None
        if isinstance(x, list) or isinstance(x, tuple):
            iter_on = x
        elif isinstance(x, dict):
            iter_on = x.iteritems()
        if iter_on is not None:
            return all(filter(y) for y in iter_on)
        else:
            return (isinstance(x, theano.Variable)
                    or isinstance(x, theano.scan_module.until))

    if not filter(ls):
        raise ValueError(
            'The return value of your scan lambda expression may only be '
            'made of lists, tuples, or dictionaries containing Theano '
            'variables (or `theano.scan_module.until` objects for '
            'conditions). In particular if you need to use constant '
            'values, you can use `tensor.constant` to turn them into '
            'Theano variables.')

    if is_outputs(ls):
        return None, _list(ls), OrderedDict()
    if is_updates(ls):
        return None, [], OrderedDict(ls)
    error_msg = ('Scan cannot parse the return value of your lambda '
                 'expression, which is: %s' % (ls, ))
    if not isinstance(ls, (list, tuple)):
        raise ValueError(error_msg)
    ls = list(ls)
    deprecation_msg = (
        'The return value of the lambda function'
        ' has been restricted. you have to always return first the'
        ' outputs (if any), afterwards the updates (if any) and'
        ' at the end the conclusion')
    if len(ls) == 2:
        if is_outputs(ls[0]):
            if is_updates(ls[1]):
                return (None, _list(ls[0]), OrderedDict(ls[1]))
            elif is_condition(ls[1]):
                return (ls[1].condition, _list(ls[0]), OrderedDict())
            else:
                raise ValueError(error_msg)
        elif is_updates(ls[0]):
            if is_outputs(ls[1]):
                raise ValueError(deprecation_msg)
            elif is_condition(ls[1]):
                return (ls[1].condition, [], OrderedDict(ls[0]))
            else:
                raise ValueError(error_msg)
        else:
            raise ValueError(error_msg)
    elif len(ls) == 3:
        if is_outputs(ls[0]):
            if is_updates(ls[1]):
                if is_condition(ls[2]):
                    return (ls[2].condition, _list(ls[0]), OrderedDict(ls[1]))
                else:
                    raise ValueError(error_msg)
            else:
                raise ValueError(error_msg)
        else:
            raise ValueError(error_msg)
    else:
        raise ValueError(error_msg)
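
The accepted return shapes, as a sketch; x is assumed to be a Theano variable, s a shared variable, and upd an update expression for it:

    cond, outs, updates = get_updates_and_outputs(x)              # outputs only
    cond, outs, updates = get_updates_and_outputs([x, {s: upd}])  # outputs + updates
    cond, outs, updates = get_updates_and_outputs(
        [x, {s: upd}, theano.scan_module.until(x.sum() > 0)])     # + stop condition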
Example #25
class DestroyHandler(toolbox.Bookkeeper):
    """
    The DestroyHandler class detects when a graph is impossible to evaluate
    because of aliasing and destructive operations.

    Several data structures are used to do this.

    An Op can use its view_map property to declare that an output may be
    aliased to an input. If that output is destroyed, the input is also
    considered to be destroyed. The view_maps of several Ops can feed into
    one another and form a directed graph. The consequence of destroying any
    variable in such a graph is that all variables in the graph must be
    considered to be destroyed, because they could all be referring to the
    same underlying storage.

    In the current implementation, that graph is a tree, and the root of that
    tree is called the foundation.

    TODO: why "in the current implementation" ? is there another implementation
          planned?
    TODO: why is the graph a tree? isn't it possible that one variable could
          be aliased to many variables? for example, don't switch and ifelse
          have to do this?

    The original DestroyHandler (`if 0`'d out above) computed several data
    structures from scratch each time it was asked to validate the graph.
    Because this happens potentially thousands of times and each graph to
    validate is extremely similar to the previous one, computing the
    data structures from scratch repeatedly was wasteful and resulted in
    high compile times for large graphs.

    This implementation computes the data structures once at initialization
    and then incrementally updates them.

    It is a work in progress. The following data structures have been
    converted to use the incremental strategy:
        <none>

    The following data structures remain to be converted:
        <unknown>
    """


    def __init__(self, do_imports_on_attach=True):
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach

        """maps every variable in the graph to its "foundation" (deepest
        ancestor in view chain)
        TODO: change name to var_to_vroot"""
        self.droot = OrderedDict()

        """maps a variable to all variables that are indirect or direct views of it
         (including itself)
         essentially the inverse of droot
        TODO: do all variables appear in this dict, or only those that are foundations?
        TODO: do only destroyed variables go in here? one old docstring said so
        TODO: rename to x_to_views after reverse engineering what x is"""
        self.impact = OrderedDict()

        """if a var is destroyed, then this dict will map
        droot[var] to the apply node that destroyed var
        TODO: rename to vroot_to_destroyer"""
        self.root_destroyer = OrderedDict()

    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that
            1) This DestroyHandler wasn't already attached to some fgraph
               (its data structures are only set up to serve one)
            2) The FunctionGraph doesn't already have a DestroyHandler.
               This would result in it validating everything twice, causing
               compilation to be slower.

        Give the FunctionGraph instance:
            1) A new method "destroyers(var)"
                TODO: what does this do exactly?
            2) A new attribute, "destroy_handler"
        TODO: WRITEME: what does this do besides the checks?
        """

        ####### Do the checking ###########
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")

        ####### Annotate the FunctionGraph ############

        def get_destroyers_of(r):
            droot, impact, root_destroyer = self.refresh_droot_impact()
            try:
                return [root_destroyer[droot[r]]]
            except Exception:
                return []

        fgraph.destroyers = get_destroyers_of
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        self.destroyers = OrderedSet() #set of Apply instances with non-null destroy_map
        self.view_i = OrderedDict()  # variable -> variable used in calculation
        self.view_o = OrderedDict()  # variable -> set of variables that use this one as a direct input
        #clients: how many times does an apply use a given variable
        self.clients = OrderedDict() # variable -> apply -> ninputs
        self.stale_droot = True

        self.debug_all_apps = OrderedSet()
        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)
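    # Hedged usage sketch (assuming the usual FunctionGraph feature API):
    #
    #   fg = theano.gof.FunctionGraph(inputs, outputs)
    #   fg.attach_feature(DestroyHandler())
    #   fg.destroyers(v)  # -> [destroying Apply] if v's foundation is
    #                     #    destroyed somewhere in the graph, else []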

    def refresh_droot_impact(self):
        """
        Makes sure self.droot, self.impact, and self.root_destroyer are
        up to date, and returns them.
        (see docstrings for these properties above)
        """
        if self.stale_droot:
            droot = OrderedDict()   # destroyed view + nonview variables -> foundation
            impact = OrderedDict()  # destroyed nonview variable -> it + all views of it
            root_destroyer = OrderedDict() # root -> destroyer apply

            for app in self.destroyers:
                for output_idx, input_idx_list in app.op.destroy_map.items():
                    if len(input_idx_list) != 1:
                        raise NotImplementedError()
                    input_idx = input_idx_list[0]
                    input = app.inputs[input_idx]
                    input_root = getroot(input, self.view_i)
                    if input_root in droot:
                        raise InconsistencyError("Multiple destroyers of %s" % input_root)
                    droot[input_root] = input_root
                    root_destroyer[input_root] = app
                    input_impact = get_impact(input_root, self.view_o)
                    for v in input_impact:
                        assert v not in droot
                        droot[v] = input_root

                    impact[input_root] = input_impact
                    impact[input_root].add(input_root)
            self.droot, self.impact, self.root_destroyer = droot, impact, root_destroyer
            self.stale_droot = False
        return self.droot, self.impact, self.root_destroyer

    def on_detach(self, fgraph):
        if fgraph is not self.fgraph:
            raise Exception("detaching wrong fgraph", fgraph)
        del self.destroyers
        del self.view_i
        del self.view_o
        del self.clients
        del self.stale_droot
        assert self.fgraph.destroy_handler is self
        delattr(self.fgraph, 'destroyers')
        delattr(self.fgraph, 'destroy_handler')
        self.fgraph = None

    def on_import(self, fgraph, app):
        """Add Apply instance to set which must be computed"""

        if app in self.debug_all_apps:
            raise ProtocolError("double import")
        self.debug_all_apps.add(app)
        #print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)

        # If it's a destructive op, add it to our watch list
        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.add(app)

        # add this symbol to the forward and backward maps
        for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
            if len(i_idx_list) > 1:
                raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            self.view_i[o] = i
            self.view_o.setdefault(i, OrderedSet()).add(o)

        # update self.clients
        for i, input in enumerate(app.inputs):
            self.clients.setdefault(input, OrderedDict()).setdefault(app,0)
            self.clients[input][app] += 1

        for i, output in enumerate(app.outputs):
            self.clients.setdefault(output, OrderedDict())

        self.stale_droot = True

    def on_prune(self, fgraph, app):
        """Remove Apply instance from set which must be computed"""
        if app not in self.debug_all_apps:
            raise ProtocolError("prune without import")
        self.debug_all_apps.remove(app)

        #UPDATE self.clients
        for i, input in enumerate(OrderedSet(app.inputs)):
            del self.clients[input][app]

        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.remove(app)

        # Note: leaving empty client dictionaries in the struct.
        # Why? It's a pain to remove them. I think they aren't doing any harm, they will be
        # deleted on_detach().

        #UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
            if len(i_idx_list) > 1:
                #destroying this output invalidates multiple inputs
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]

            del self.view_i[o]

            self.view_o[i].remove(o)
            if not self.view_o[i]:
                del self.view_o[i]

        self.stale_droot = True

    def on_change_input(self, fgraph, app, i, old_r, new_r):
        """app.inputs[i] changed from old_r to new_r """
        if app == 'output':
            # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
            # considered 'outputs' of the graph.
            pass
        else:
            if app not in self.debug_all_apps:
                raise ProtocolError("change without import")

            #UPDATE self.clients
            self.clients[old_r][app] -= 1
            if self.clients[old_r][app] == 0:
                del self.clients[old_r][app]

            self.clients.setdefault(new_r, OrderedDict()).setdefault(app,0)
            self.clients[new_r][app] += 1

            #UPDATE self.view_i, self.view_o
            for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
                if len(i_idx_list) > 1:
                    #destroying this output invalidates multiple inputs
                    raise NotImplementedError()
                i_idx = i_idx_list[0]
                output = app.outputs[o_idx]
                if i_idx == i:
                    if app.inputs[i_idx] is not new_r:
                        raise ProtocolError("wrong new_r on change")

                    self.view_i[output] = new_r

                    self.view_o[old_r].remove(output)
                    if not self.view_o[old_r]:
                        del self.view_o[old_r]

                    self.view_o.setdefault(new_r, OrderedSet()).add(output)

        self.stale_droot = True

    def validate(self, fgraph):
        """Return None

        Raise InconsistencyError when
        a) orderings() raises an error
        b) orderings cannot be topologically sorted.

        """

        if self.destroyers:
            ords = self.orderings(fgraph)

            if _contains_cycle(fgraph, ords):
                raise InconsistencyError("Dependency graph contains cycles")
        else:
            #James's Conjecture:
            #If there are no destructive ops, then there can be no cycles.
            pass
        return True

    def orderings(self, fgraph):
        """Return orderings induced by destructive operations.

        Raise InconsistencyError when
        a) attempting to destroy an indestructible variable, or
        b) attempting to destroy a value multiple times, or
        c) an Apply destroys (illegally) one of its own inputs by aliasing

        """
        rval = OrderedDict()

        if self.destroyers:
            # BUILD DATA STRUCTURES
            # CHECK for multiple destructions during construction of variables

            droot, impact, __ignore = self.refresh_droot_impact()

            # check for destruction of constants
            illegal_destroy = [r for r in droot if \
                    getattr(r.tag,'indestructible', False) or \
                    isinstance(r, graph.Constant)]
            if illegal_destroy:
                raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
                        illegal_destroy)

            # add destroyed variable clients as computational dependencies
            for app in self.destroyers:
                # for each destroyed input...
                for output_idx, input_idx_list in app.op.destroy_map.items():
                    destroyed_idx = input_idx_list[0]
                    destroyed_variable = app.inputs[destroyed_idx]
                    root = droot[destroyed_variable]
                    root_impact = impact[root]
                    # we generally want to put all clients of things which depend on root
                    # as pre-requisites of app.
                    # But, app is itself one such client!
                    # App will always be a client of the node we're destroying
                    # (destroyed_variable), but the tricky thing is when it is also a client of
                    # *another variable* viewing the root.  Generally this is illegal (e.g.,
                    # add_inplace(x, x.T)).  In some special cases though, the in-place op will
                    # actually be able to work properly with multiple destroyed inputs (e.g.,
                    # add_inplace(x, x)).  An Op that can still work in this case should declare
                    # so via the 'destroyhandler_tolerate_same' attribute or
                    # 'destroyhandler_tolerate_aliased' attribute.
                    #
                    # destroyhandler_tolerate_same should be a list of pairs of the form
                    # [(idx0, idx1), (idx0, idx2), ...]
                    # The first element of each pair is the input index of a destroyed
                    # variable.
                    # The second element of each pair is the index of a different input where
                    # we will permit exactly the same variable to appear.
                    # For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed
                    # input is also allowed to appear as the second argument.
                    #
                    # destroyhandler_tolerate_aliased is the same sort of list of
                    # pairs.
                    # op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the
                    # destroyhandler to IGNORE an aliasing between a destroyed
                    # input idx0 and another input idx1.
                    # This is generally a bad idea, but it is safe in some
                    # cases, such as
                    # - the op reads from the aliased idx1 before modifying idx0
                    # - the idx0 and idx1 are guaranteed not to overlap (e.g.
                    #   they are pointed at different rows of a matrix).
                    #

                    #CHECK FOR INPUT ALIASING
                    # OPT: pre-compute this on import
                    tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
                    assert isinstance(tolerate_same, list)
                    tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
                            if idx0 == destroyed_idx)
                    tolerated.add(destroyed_idx)
                    tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
                    assert isinstance(tolerate_aliased, list)
                    ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
                            if idx0 == destroyed_idx)
                    #print 'tolerated', tolerated
                    #print 'ignored', ignored
                    for i, input in enumerate(app.inputs):
                        if i in ignored:
                            continue
                        if input in root_impact \
                                and (i not in tolerated or input is not destroyed_variable):
                            raise InconsistencyError("Input aliasing: %s (%i, %i)"
                                    % (app, destroyed_idx, i))

                    # add the rule: app must be preceded by all other Apply instances that
                    # depend on destroyed_input
                    root_clients = OrderedSet()
                    for r in root_impact:
                        assert not [a for a, c in self.clients[r].items() if not c]
                        root_clients.update([a for a, c in self.clients[r].items() if c])
                    root_clients.remove(app)
                    if root_clients:
                        rval[app] = root_clients

        return rval
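# Illustrative sketch of the three structures above (hypothetical variables
# a, b, c, where b views a, c views b, and an apply node ``app`` destroys c):
#
#   droot          == {a: a, b: a, c: a}  # every view maps to its foundation
#   impact         == {a: {a, b, c}}      # foundation -> itself + all views
#   root_destroyer == {a: app}            # foundation -> destroying apply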
Example #26
0
    def __init__(self, iterable=None):
        self.data = OrderedDict()
        if iterable is not None:
            self.update(iterable)
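# Hedged note: this constructor suggests an ordered-set-style container backed
# by an OrderedDict (the rest of the class, including ``update``, is not
# shown). Assumed semantics:
#
#   s = ThisSetType([3, 1, 2])    # hypothetical class name
#   list(s.data)  ->  [3, 1, 2]   # OrderedDict keys keep insertion order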
Example #27
0
    # MIT_MOT -- not provided by the user, only by the grad function
    n_mit_mot = 0
    n_mit_mot_outs = 0
    mit_mot_scan_inputs = []
    mit_mot_inner_inputs = []
    mit_mot_inner_outputs = []
    mit_mot_out_slices = []
    mit_mot_rightOrder = []

    # MIT_SOT -- provided by the user
    n_mit_sot = 0
    mit_sot_scan_inputs = []
    mit_sot_inner_inputs = []
    mit_sot_inner_slices = []
    mit_sot_inner_outputs = []
    mit_sot_return_steps = OrderedDict()
    mit_sot_tap_array = []
    mit_sot_rightOrder = []

    n_sit_sot = 0
    sit_sot_scan_inputs = []
    sit_sot_inner_inputs = []
    sit_sot_inner_slices = []
    sit_sot_inner_outputs = []
    sit_sot_return_steps = OrderedDict()
    sit_sot_rightOrder = []
    nit_sot_steps = []
    # go through outputs picking up time slices as needed
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as a initial state we will only
def test_mlp():
    """
    Demonstrate stochastic gradient descent optimization for a multilayer
    perceptron

    This is demonstrated on MNIST.

    :type learning_rate: float
    :param learning_rate: learning rate used (factor for the stochastic
    gradient)

    :type n_epochs: int
    :param n_epochs: maximal number of epochs to run the optimizer

    :type dataset: string
    :param dataset: the path of the MNIST dataset file from
                         http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz


   """
    datasets = gen_data()

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]

    batch_size = 100  # size of the minibatch

    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size
    n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] // batch_size
    n_test_batches = test_set_x.get_value(borrow=True).shape[0] // batch_size

    ######################
    # BUILD ACTUAL MODEL #
    ######################
    #print '... building the model'

    # allocate symbolic variables for the data
    index = T.lscalar()  # index to a [mini]batch
    x = T.matrix('x')  # the data is presented as rasterized images
    y = T.ivector('y')  # the labels are presented as 1D vector of
    # [int] labels

    rng = numpy.random.RandomState(1234)

    # construct the MLP class
    classifier = MLP(rng=rng, input=x, n_in=28 * 28, n_hidden=500, n_out=10)

    # the cost we minimize during training is the negative log likelihood of
    # the model.
    # We take the mean of the cost over each minibatch.
    cost = classifier.negative_log_likelihood(y).mean()

    # compute the gradient of cost with respect to theta (stored in params)
    # the resulting gradients will be stored in a list gparams
    gparams = []
    for param in classifier.params:
        gparam = T.grad(cost, param)
        gparams.append(gparam)

    # Some optimizations needed are tagged with 'fast_run'
    # TODO: refine that and include only those
    mode = theano.compile.get_default_mode().including('fast_run')

    updates2 = OrderedDict()

    updates2[classifier.hiddenLayer.params[0]] = T.grad(
        cost, classifier.hiddenLayer.params[0])
    train_model = theano.function(
        inputs=[index],
        updates=updates2,
        givens={
            x: train_set_x[index * batch_size:(index + 1) * batch_size],
            y: train_set_y[index * batch_size:(index + 1) * batch_size]
        },
        mode=mode)
    #print 'MODEL 1'
    #theano.printing.debugprint(train_model, print_type=True)
    assert any([
        isinstance(i.op, T.nnet.CrossentropySoftmax1HotWithBiasDx)
        for i in train_model.maker.fgraph.toposort()
    ])

    # Even without FeatureShape
    train_model = theano.function(
        inputs=[index],
        updates=updates2,
        mode=mode.excluding('ShapeOpt'),
        givens={
            x: train_set_x[index * batch_size:(index + 1) * batch_size],
            y: train_set_y[index * batch_size:(index + 1) * batch_size]
        })
    #print
    #print 'MODEL 2'
    #theano.printing.debugprint(train_model, print_type=True)
    assert any([
        isinstance(i.op, T.nnet.CrossentropySoftmax1HotWithBiasDx)
        for i in train_model.maker.fgraph.toposort()
    ])
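    # Hedged variant (reusing the symbols above; the learning-rate value is
    # illustrative, not from this test): plain SGD updates for *all*
    # parameters instead of only the hidden layer's weight matrix.
    learning_rate = 0.01
    updates_all = OrderedDict(
        (param, param - learning_rate * gparam)
        for param, gparam in zip(classifier.params, gparams))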
Example #29
0
def scan(fn,
         sequences=None,
         outputs_info=None,
         non_sequences=None,
         n_steps=None,
         truncate_gradient=-1,
         go_backwards=False,
         mode=None,
         name=None,
         profile=False):
    """
    This function constructs and applies a Scan op to the provided
    arguments.

    :param fn:
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. ``fn`` should construct variables describing the
        output of one iteration step. It should expect as input theano
        variables representing all the slices of the input sequences
        and previous values of the outputs, as well as all other arguments
        given to scan as ``non_sequences``. The order in which scan passes
        these variables to ``fn``  is the following :

        * all time slices of the first sequence
        * all time slices of the second sequence
        * ...
        * all time slices of the last sequence
        * all past slices of the first output
        * all past slices of the second output
        * ...
        * all past slices of the last output
        * all other arguments (the list given as `non_sequences` to
            scan)

        The order of the sequences is the same as the one in the list
        `sequences` given to scan. The order of the outputs is the same
        as the order of ``outputs_info``. For any sequence or output the
        order of the time slices is the same as the one in which they have
        been given as taps. For example if one writes the following :

        .. code-block:: python

            scan(fn, sequences=[dict(input=Sequence1, taps=[-3, 2, -1]),
                                Sequence2,
                                dict(input=Sequence3, taps=3)],
                 outputs_info=[dict(initial=Output1, taps=[-3, -5]),
                               dict(initial=Output2, taps=None),
                               Output3],
                 non_sequences=[Argument1, Argument2])

        ``fn`` should expect the following arguments in this given order:

        #. ``Sequence1[t-3]``
        #. ``Sequence1[t+2]``
        #. ``Sequence1[t-1]``
        #. ``Sequence2[t]``
        #. ``Sequence3[t+3]``
        #. ``Output1[t-3]``
        #. ``Output1[t-5]``
        #. ``Output3[t-1]``
        #. ``Argument1``
        #. ``Argument2``

        The list of ``non_sequences`` can also contain shared variables
        used in the function, though ``scan`` is able to figure those
        out on its own, so they can be skipped. For clarity, though, we
        recommend providing them to scan. To some extent ``scan`` can also
        figure out other ``non_sequences`` (not shared) even if they are
        not passed to scan (but are used by ``fn``). A simple example of
        this would be:

        .. code-block:: python

            import theano.tensor as TT
            W   = TT.matrix()
            W_2 = W**2
            def f(x):
                return TT.dot(x, W_2)

        The function is expected to return two things. One is a list of
        outputs ordered in the same order as ``outputs_info``, with the
        difference that there should be only one output variable per
        output initial state (even if no tap value is used). Secondly
        `fn` should return an update dictionary (that tells how to
        update any shared variable after each iteration step). The
        dictionary can optionally be given as a list of tuples. There is
        no constraint on the order of these two lists; ``fn`` can return
        either ``(outputs_list, update_dictionary)`` or
        ``(update_dictionary, outputs_list)`` or just one of the two (in
        case the other is empty).

        To use ``scan`` as a while loop, the user needs to change the
        function ``fn`` such that it also returns a stopping condition.
        To do so, the condition needs to be wrapped in an ``until`` class.
        The condition should be returned as a third element, for example:

        .. code-block:: python

            ...
            return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50)

        Note that a number of steps (considered here as the maximum
        number of steps) is still required even though a condition is
        passed (it is used to allocate memory if needed).

    :param sequences:
        ``sequences`` is the list of Theano variables or dictionaries
        describing the sequences ``scan`` has to iterate over. If a
        sequence is given as wrapped in a dictionary, then a set of optional
        information can be provided about the sequence. The dictionary
        should have the following keys:

        * ``input`` (*mandatory*) -- Theano variable representing the
          sequence.

        * ``taps`` -- Temporal taps of the sequence required by ``fn``.
          They are provided as a list of integers, where a value ``k``
          implies that at iteration step ``t`` scan will pass to ``fn``
          the slice ``t+k``. Default value is ``[0]``.

        Any Theano variable in the list ``sequences`` is automatically
        wrapped into a dictionary where ``taps`` is set to ``[0]``.


    :param outputs_info:
        ``outputs_info`` is the list of Theano variables or dictionaries
        describing the initial state of the outputs computed
        recurrently. When these initial states are given as dictionaries,
        optional information can be provided about the output corresponding
        to each initial state. The dictionary should have the following
        keys:

        * ``initial`` -- Theano variable that represents the initial
          state of a given output. In case the output is not computed
          recursively (think of a map) and does not require an initial
          state this field can be skipped. Given that (only) the previous
          time step of the output is used by ``fn``, the initial state
          **should have the same shape** as the output and **should not
          involve a downcast** of the data type of the output. If multiple
          time taps are used, the initial state should have one extra
          dimension that should cover all the possible taps. For example
          if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0,
          ``fn`` will require (by an abuse of notation) ``output[-5]``,
          ``output[-2]`` and ``output[-1]``. This will be given by
          the initial state, which in this case should have the shape
          (5,)+output.shape. If this variable containing the initial
          state is called ``init_y`` then ``init_y[0]`` *corresponds to*
          ``output[-5]``. ``init_y[1]`` *corresponds to* ``output[-4]``,
          ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]``
          corresponds to ``output[-2]``, ``init_y[4]`` corresponds to
          ``output[-1]``. While this order might seem strange, it arises
          naturally from splitting an array at a given point. Assume that
          we have an array ``x``, and we choose ``k`` to be time step
          ``0``. Then our initial state would be ``x[:k]``, while the
          output will be ``x[k:]``. Looking at this split, elements in
          ``x[:k]`` are ordered exactly like those in ``init_y``.
        * ``taps`` -- Temporal taps of the output that will be passed to
          ``fn``. They are provided as a list of *negative* integers,
          where a value ``k`` implies that at iteration step ``t`` scan
          will pass to ``fn`` the slice ``t+k``.

        ``scan`` will follow this logic if partial information is given:

        * If an output is not wrapped in a dictionary, ``scan`` will wrap
          it in one assuming that you use only the last step of the output
          (i.e. it makes your tap value list equal to [-1]).
        * If you wrap an output in a dictionary and you do not provide any
          taps but you provide an initial state it will assume that you are
          using only a tap value of -1.
        * If you wrap an output in a dictionary but you do not provide any
          initial state, it assumes that you are not using any form of
          taps.
        * If you provide a ``None`` instead of a variable or an empty
          dictionary, ``scan`` assumes that you will not use any taps for
          this output (like for example in case of a map).

        If ``outputs_info`` is an empty list or None, ``scan`` assumes
        that no tap is used for any of the outputs. If information is
        provided just for a subset of the outputs an exception is
        raised (because there is no convention on how scan should map
        the provided information to the outputs of ``fn``).


    :param non_sequences:
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each step. One can opt to exclude variables
        used in ``fn`` from this list, as long as they are part of the
        computational graph, though for clarity we encourage including them.


    :param n_steps:
        ``n_steps`` is the number of steps to iterate given as an int
        or Theano scalar. If any of the input sequences do not have
        enough elements, scan will raise an error. If the *value is 0* the
        outputs will have *0 rows*. If the value is negative, ``scan``
        will run backwards in time. If the ``go_backwards`` flag is already
        set and also ``n_steps`` is negative, ``scan`` will run forward
        in time. If ``n_steps`` is not provided, ``scan`` will figure
        out the number of steps it should run given its input sequences.


    :param truncate_gradient:
        ``truncate_gradient`` is the number of steps to use in truncated
        BPTT.  If you compute gradients through a scan op, they are
        computed using backpropagation through time. By providing a
        value other than -1, you choose to use truncated BPTT instead
        of classical BPTT, where you go back only ``truncate_gradient``
        steps in time.


    :param go_backwards:
        ``go_backwards`` is a flag indicating if ``scan`` should go
        backwards through the sequences. If you think of each sequence
        as indexed by time, making this flag True would mean that
        ``scan`` goes back in time, namely that for any sequence it
        starts from the end and goes towards 0.


    :param name:
        When profiling ``scan``, it is crucial to provide a name for any
        instance of ``scan``. The profiler will produce an overall
        profile of your code as well as profiles for the computation of
        one step of each instance of ``scan``. The ``name`` of the instance
        appears in those profiles and can greatly help to disambiguate
        information.

    :param mode:
        It is recommended to leave this argument to None, especially
        when profiling ``scan`` (otherwise the results are not going to
        be accurate). If you prefer the computations of one step of
        ``scan`` to be done differently than the rest of the function, you
        can use this parameter to describe how the computations in this
        loop are done (see ``theano.function`` for details about
        possible values and their meaning).

    :param profile:
        Flag or string. If true, or different from the empty string, a
        profile object will be created and attached to the inner graph of
        scan. In case ``profile`` is True, the profile object will have the
        name of the scan instance, otherwise it will have the passed string.
        The profile object collects (and prints) information only when the
        inner graph is run with the new CVM linker (with default modes;
        with other linkers this argument is useless).

    :rtype: tuple
    :return: tuple of the form (outputs, updates); ``outputs`` is either a
             Theano variable or a list of Theano variables representing the
             outputs of ``scan`` (in the same order as in
             ``outputs_info``). ``updates`` is a subclass of dictionary
             specifying the update rules for all shared variables used in
             scan. This dictionary should be passed to ``theano.function``
             when you compile your function. The change compared to a
             normal dictionary is that the keys are validated to be
             SharedVariables and that merging two such dictionaries is
             checked for consistency.
    """
    # General observation : this code is executed only once, at creation
    # of the computational graph, so we don't yet need to be smart about
    # anything (to speed things up)

    ##
    ###   Step 1. Wrap all inputs in dictionaries and add default values
    ##

    # check if inputs are just single variables instead of lists
    def wrap_into_list(x):
        '''
        Wrap the input into a list if it is not already a list
        '''
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(outputs_info)

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(non_sequences):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
        str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap sequences in a dictionary if they are not already dictionaries
    for i in xrange(n_seqs):
        if not isinstance(seqs[i], dict):
            seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
        elif seqs[i].get('taps', None):
            seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
        elif seqs[i].get('taps', True) is None:
            # the ``taps`` key was explicitly set to None; use the default
            seqs[i]['taps'] = [0]
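    # Net effect (worked example): a bare variable s becomes
    # OrderedDict([('input', s), ('taps', [0])]); dict(input=s, taps=2)
    # gets its taps wrapped into a list, giving taps == [2]; dict(input=s,
    # taps=None) falls back to taps == [0]; a dict that omits ``taps``
    # entirely keeps no ``taps`` key and is skipped by the slicing code below.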

    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if isinstance(outs_info[i], dict):
                # DEPRECATED :
                if outs_info[i].get('return_steps', None):
                    raise ValueError(
                            "Using `return_steps` has been deprecated. "
                            "Simply select the entries you need using a "
                            "subtensor. Scan will optimize memory "
                            "consumption, so do not worry about that.")
                # END

            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = OrderedDict([('initial', outs_info[i]), ('taps', [-1])])
            elif (not outs_info[i].get('initial', None) and
                    outs_info[i].get('taps', None)):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide a initial state '
                                  'for it'), outs_info[i])
            elif (outs_info[i].get('initial', None) and
                  not outs_info[i].get('taps', None)):
                # ^ initial state but taps not provided
                if 'taps' in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning('Output %s (index %d) has an initial '
                            'state but taps is explicitly set to None ',
                             getattr(outs_info[i]['initial'], 'name', 'None'),
                             i)
                outs_info[i]['taps'] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an empty OrderedDict() to simplify handling
            outs_info[i] = OrderedDict()
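    # Net effect (worked example): a bare initial state y0 becomes
    # OrderedDict([('initial', y0), ('taps', [-1])]); a dict with an
    # ``initial`` entry but no taps also defaults to taps == [-1]; a ``None``
    # entry becomes an empty OrderedDict() (map-style output, no taps).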

    ##
    ###   Step 2. Generate inputs and outputs of the inner functions
    ###           for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we use this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []     # Variables passed as inputs to the scan op
    inner_seqs = []    # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        # Note that you can have something like no taps for
        # a sequence, though it is highly unlikely in practice
        if 'taps' in seq:
            # go through the indicated slice
            mintap = numpy.min(seq['taps'])
            maxtap = numpy.max(seq['taps'])
            for k in seq['taps']:
                # create one slice of the input
                # Later on, if we decide not to use scan because we are
                # going for just one step, it makes things easier if we
                # compute the correct outputs here. This way we can use
                # the output of the lambda expression directly to replace
                # the output of scan.

                # If not, we need to use copies that will be replaced at
                # each frame by the corresponding slice
                actual_slice = seq['input'][k - mintap]
                _seq_val = tensor.as_tensor_variable(seq['input'])
                _seq_val_slice = _seq_val[k - mintap]
                nw_slice = _seq_val_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != 'off':
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _seq_val_slice)
                    except AttributeError, e:
                        if config.compute_test_value != 'ignore':
                            # No need to print a warning or raise an error now,
                            # it will be done when fn will be called.
                            _logger.info(('Cannot compute test value for '
                                'the inner function of scan, input value '
                                'missing %s'), e)

                # Add names to slices for debugging and pretty printing ..
                # that is if the input already has a name
                if getattr(seq['input'], 'name', None) is not None:
                    if k > 0:
                        nw_name = seq['input'].name + '[t+%d]' % k
                    elif k == 0:
                        nw_name = seq['input'].name + '[t]'
                    else:
                        nw_name = seq['input'].name + '[t%d]' % k
                    nw_slice.name = nw_name

                # We cut the sequence so that seq[i] corresponds to
                # seq[i-k]
                if maxtap < 0:
                    offset = abs(maxtap)
                else:
                    offset = 0
                if maxtap == mintap and maxtap != 0:
                    if maxtap < 0:
                        nw_seq = seq['input'][:maxtap]
                    else:
                        nw_seq = seq['input'][maxtap:]
                elif maxtap - k != 0:
                    nw_seq = seq['input'][offset + k - mintap: -(maxtap - k)]
                else:
                    nw_seq = seq['input'][offset + k - mintap:]
                if go_backwards:
                    nw_seq = nw_seq[::-1]

                scan_seqs.append(nw_seq)
                inner_seqs.append(nw_slice)
                inner_slices.append(actual_slice)
                n_seqs += 1
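    # Worked example of the slicing above (illustrative): taps = [-3, 2, -1]
    # gives mintap = -3, maxtap = 2, offset = 0, so
    #   k = -3 : nw_seq = seq['input'][0:-5]  # window holding x[t-3]
    #   k =  2 : nw_seq = seq['input'][5:]    # window holding x[t+2]
    #   k = -1 : nw_seq = seq['input'][2:-3]  # window holding x[t-1]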
Example #30
0
    # MIT_MOT -- not provided by the user, only by the grad function
    n_mit_mot = 0
    n_mit_mot_outs = 0
    mit_mot_scan_inputs = []
    mit_mot_inner_inputs = []
    mit_mot_inner_outputs = []
    mit_mot_out_slices = []
    mit_mot_rightOrder = []

    # MIT_SOT -- provided by the user
    n_mit_sot = 0
    mit_sot_scan_inputs = []
    mit_sot_inner_inputs = []
    mit_sot_inner_slices = []
    mit_sot_inner_outputs = []
    mit_sot_return_steps = OrderedDict()
    mit_sot_tap_array = []
    mit_sot_rightOrder = []

    n_sit_sot = 0
    sit_sot_scan_inputs = []
    sit_sot_inner_inputs = []
    sit_sot_inner_slices = []
    sit_sot_inner_outputs = []
    sit_sot_return_steps = OrderedDict()
    sit_sot_rightOrder = []

    # go through outputs picking up time slices as needed
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as a initial state we will only
Example #31
0
class DestroyHandler(toolbox.Bookkeeper):
    """
    The DestroyHandler class detects when a graph is impossible to evaluate
    because of aliasing and destructive operations.

    Several data structures are used to do this.

    An Op can use its view_map property to declare that an output may be
    aliased to an input. If that output is destroyed, the input is also
    considered to be destroyed. The view_maps of several Ops can feed into
    one another and form a directed graph. The consequence of destroying any
    variable in such a graph is that all variables in the graph must be
    considered to be destroyed, because they could all be referring to the
    same underlying storage.

    In the current implementation, that graph is a tree, and the root of that
    tree is called the foundation.

    TODO: why "in the current implementation" ? is there another implementation
          planned?
    TODO: why is the graph a tree? isn't it possible that one variable could
          be aliased to many variables? for example, don't switch and ifelse
          have to do this?

    The original DestroyHandler (if 0'ed out above) computed several data
    structures from scratch each time it was asked to validate the graph.
    Because this happens potentially thousands of times and each graph to
    validate is extremely similar to the previous one, computing the
    data structures from scratch repeatedly was wasteful and resulted in
    high compile times for large graphs.

    This implementation computes the data structures once at initialization
    and then incrementally updates them.

    It is a work in progress. The following data structures have been
    converted to use the incremental strategy:
        <none>

    The following data structures remain to be converted:
        <unknown>
    """
    pickle_rm_attr = ["destroyers"]

    def __init__(self, do_imports_on_attach=True):
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach
        """maps every variable in the graph to its "foundation" (deepest
        ancestor in view chain)
        TODO: change name to var_to_vroot"""
        self.droot = OrderedDict()
        """maps a variable to all variables that are indirect or direct views of it
         (including itself)
         essentially the inverse of droot
        TODO: do all variables appear in this dict, or only those that are foundations?
        TODO: do only destroyed variables go in here? one old docstring said so
        TODO: rename to x_to_views after reverse engineering what x is"""
        self.impact = OrderedDict()
        """if a var is destroyed, then this dict will map
        droot[var] to the apply node that destroyed var
        TODO: rename to vroot_to_destroyer"""
        self.root_destroyer = OrderedDict()

    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that
            1) This DestroyHandler wasn't already attached to some fgraph
               (its data structures are only set up to serve one)
            2) The FunctionGraph doesn't already have a DestroyHandler.
               This would result in it validating everything twice, causing
               compilation to be slower.

        Give the FunctionGraph instance:
            1) A new method "destroyers(var)"
                TODO: what does this do exactly?
            2) A new attribute, "destroy_handler"
        TODO: WRITEME: what does this do besides the checks?
        """

        ####### Do the checking ###########
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception("A DestroyHandler instance can only serve one"
                            " FunctionGraph. (Matthew 6:24)")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present"
                " or in conflict with another plugin.")

        ####### Annotate the FunctionGraph ############
        self.unpickle(fgraph)
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        self.destroyers = OrderedSet()  # set of Apply instances with non-null destroy_map
        self.view_i = OrderedDict()  # variable -> variable used in calculation
        self.view_o = OrderedDict()  # variable -> set of variables that use this one as a direct input
        # clients: how many times does an apply use a given variable
        self.clients = OrderedDict()  # variable -> apply -> ninputs
        self.stale_droot = True

        self.debug_all_apps = OrderedSet()
        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)

    def unpickle(self, fgraph):
        def get_destroyers_of(r):
            droot, impact, root_destroyer = self.refresh_droot_impact()
            try:
                return [root_destroyer[droot[r]]]
            except Exception:
                return []

        fgraph.destroyers = get_destroyers_of

    def refresh_droot_impact(self):
        """
        Makes sure self.droot, self.impact, and self.root_destroyer are
        up to date, and returns them.
        (see docstrings for these properties above)
        """
        if self.stale_droot:
            droot = OrderedDict()   # destroyed view + nonview variables -> foundation
            impact = OrderedDict()  # destroyed nonview variable -> it + all views of it
            root_destroyer = OrderedDict()  # root -> destroyer apply

            for app in self.destroyers:
                for output_idx, input_idx_list in app.op.destroy_map.items():
                    if len(input_idx_list) != 1:
                        raise NotImplementedError()
                    input_idx = input_idx_list[0]
                    input = app.inputs[input_idx]
                    input_root = getroot(input, self.view_i)
                    if input_root in droot:
                        raise InconsistencyError("Multiple destroyers of %s" %
                                                 input_root)
                    droot[input_root] = input_root
                    root_destroyer[input_root] = app
                    input_impact = get_impact(input_root, self.view_o)
                    for v in input_impact:
                        assert v not in droot
                        droot[v] = input_root

                    impact[input_root] = input_impact
                    impact[input_root].add(input_root)
            self.droot, self.impact, self.root_destroyer = droot, impact, root_destroyer
            self.stale_droot = False
        return self.droot, self.impact, self.root_destroyer

    def on_detach(self, fgraph):
        if fgraph is not self.fgraph:
            raise Exception("detaching wrong fgraph", fgraph)
        del self.destroyers
        del self.view_i
        del self.view_o
        del self.clients
        del self.stale_droot
        assert self.fgraph.destroy_handler is self
        delattr(self.fgraph, 'destroyers')
        delattr(self.fgraph, 'destroy_handler')
        self.fgraph = None

    def on_import(self, fgraph, app, reason):
        """Add Apply instance to set which must be computed"""

        if app in self.debug_all_apps:
            raise ProtocolError("double import")
        self.debug_all_apps.add(app)
        #print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)

        # If it's a destructive op, add it to our watch list
        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.add(app)

        # add this symbol to the forward and backward maps
        for o_idx, i_idx_list in getattr(app.op, 'view_map',
                                         OrderedDict()).items():
            if len(i_idx_list) > 1:
                raise NotImplementedError(
                    'destroying this output invalidates multiple inputs',
                    (app.op))
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            self.view_i[o] = i
            self.view_o.setdefault(i, OrderedSet()).add(o)

        # update self.clients
        for i, input in enumerate(app.inputs):
            self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
            self.clients[input][app] += 1

        for i, output in enumerate(app.outputs):
            self.clients.setdefault(output, OrderedDict())

        self.stale_droot = True

    def on_prune(self, fgraph, app, reason):
        """Remove Apply instance from set which must be computed"""
        if app not in self.debug_all_apps:
            raise ProtocolError("prune without import")
        self.debug_all_apps.remove(app)

        #UPDATE self.clients
        for i, input in enumerate(OrderedSet(app.inputs)):
            del self.clients[input][app]

        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.remove(app)

        # Note: leaving empty client dictionaries in the struct.
        # Why? It's a pain to remove them. I think they aren't doing any harm, they will be
        # deleted on_detach().

        #UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in getattr(app.op, 'view_map',
                                         OrderedDict()).items():
            if len(i_idx_list) > 1:
                #destroying this output invalidates multiple inputs
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]

            del self.view_i[o]

            self.view_o[i].remove(o)
            if not self.view_o[i]:
                del self.view_o[i]

        self.stale_droot = True

    def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
        """app.inputs[i] changed from old_r to new_r """
        if app == 'output':
            # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
            # considered 'outputs' of the graph.
            pass
        else:
            if app not in self.debug_all_apps:
                raise ProtocolError("change without import")

            #UPDATE self.clients
            self.clients[old_r][app] -= 1
            if self.clients[old_r][app] == 0:
                del self.clients[old_r][app]

            self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0)
            self.clients[new_r][app] += 1

            #UPDATE self.view_i, self.view_o
            for o_idx, i_idx_list in getattr(app.op, 'view_map',
                                             OrderedDict()).items():
                if len(i_idx_list) > 1:
                    #destroying this output invalidates multiple inputs
                    raise NotImplementedError()
                i_idx = i_idx_list[0]
                output = app.outputs[o_idx]
                if i_idx == i:
                    if app.inputs[i_idx] is not new_r:
                        raise ProtocolError("wrong new_r on change")

                    self.view_i[output] = new_r

                    self.view_o[old_r].remove(output)
                    if not self.view_o[old_r]:
                        del self.view_o[old_r]

                    self.view_o.setdefault(new_r, OrderedSet()).add(output)

        self.stale_droot = True

    def validate(self, fgraph):
        """Return None

        Raise InconsistencyError when
        a) orderings() raises an error
        b) orderings cannot be topologically sorted.

        """

        if self.destroyers:
            ords = self.orderings(fgraph)

            if _contains_cycle(fgraph, ords):
                raise InconsistencyError("Dependency graph contains cycles")
        else:
            #James's Conjecture:
            #If there are no destructive ops, then there can be no cycles.

            #FB: This isn't always true. It can happen that an
            #optimization introduces a node that depends on itself. This
            #is very rare and should not happen in general. It will be
            #caught later. The error will be far from the source. But
            #doing this conjecture should speed up compilation most of
            #the time. The user should not create such dependencies
            #except if they mess too much with the internals.
            pass
        return True

    def orderings(self, fgraph):
        """Return orderings induced by destructive operations.

        Raise InconsistencyError when
        a) attempting to destroy an indestructible variable, or
        b) attempting to destroy a value multiple times, or
        c) an Apply destroys (illegally) one of its own inputs by aliasing

        """
        rval = OrderedDict()

        if self.destroyers:
            # BUILD DATA STRUCTURES
            # CHECK for multiple destructions during construction of variables

            droot, impact, __ignore = self.refresh_droot_impact()

            # check for destruction of constants
            illegal_destroy = [r for r in droot if \
                    getattr(r.tag,'indestructible', False) or \
                    isinstance(r, graph.Constant)]
            if illegal_destroy:
                raise InconsistencyError(
                    "Attempting to destroy indestructible variables: %s" %
                    illegal_destroy)

            # add destroyed variable clients as computational dependencies
            for app in self.destroyers:
                # for each destroyed input...
                for output_idx, input_idx_list in app.op.destroy_map.items():
                    destroyed_idx = input_idx_list[0]
                    destroyed_variable = app.inputs[destroyed_idx]
                    root = droot[destroyed_variable]
                    root_impact = impact[root]
                    # we generally want to put all clients of things which depend on root
                    # as pre-requisites of app.
                    # But, app is itself one such client!
                    # App will always be a client of the node we're destroying
                    # (destroyed_variable), but the tricky thing is when it is also a client of
                    # *another variable* viewing the root.  Generally this is illegal (e.g.,
                    # add_inplace(x, x.T)).  In some special cases though, the in-place op will
                    # actually be able to work properly with multiple destroyed inputs (e.g.,
                    # add_inplace(x, x)).  An Op that can still work in this case should declare
                    # so via the 'destroyhandler_tolerate_same' attribute or
                    # 'destroyhandler_tolerate_aliased' attribute.
                    #
                    # destroyhandler_tolerate_same should be a list of pairs of the form
                    # [(idx0, idx1), (idx0, idx2), ...]
                    # The first element of each pair is the input index of a destroyed
                    # variable.
                    # The second element of each pair is the index of a different input where
                    # we will permit exactly the same variable to appear.
                    # For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed
                    # input is also allowed to appear as the second argument.
                    #
                    # destroyhandler_tolerate_aliased is the same sort of list of
                    # pairs.
                    # op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the
                    # destroyhandler to IGNORE an aliasing between a destroyed
                    # input idx0 and another input idx1.
                    # This is generally a bad idea, but it is safe in some
                    # cases, such as
                    # - the op reads from the aliased idx1 before modifying idx0
                    # - the idx0 and idx1 are guaranteed not to overlap (e.g.
                    #   they are pointed at different rows of a matrix).
                    #

                    #CHECK FOR INPUT ALIASING
                    # OPT: pre-compute this on import
                    tolerate_same = getattr(app.op,
                                            'destroyhandler_tolerate_same', [])
                    assert isinstance(tolerate_same, list)
                    tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
                                           if idx0 == destroyed_idx)
                    tolerated.add(destroyed_idx)
                    tolerate_aliased = getattr(
                        app.op, 'destroyhandler_tolerate_aliased', [])
                    assert isinstance(tolerate_aliased, list)
                    ignored = OrderedSet(idx1
                                         for idx0, idx1 in tolerate_aliased
                                         if idx0 == destroyed_idx)
                    for i, input in enumerate(app.inputs):
                        if i in ignored:
                            continue
                        # Any other input that reaches the destroyed root
                        # is aliasing, unless this index was declared
                        # tolerated and carries exactly the same variable.
                        if input in root_impact \
                                and (i not in tolerated or input is not destroyed_variable):
                            raise InconsistencyError(
                                "Input aliasing: %s (%i, %i)" %
                                (app, destroyed_idx, i))

                    # add the rule: app must be preceded by all other
                    # Apply instances that depend on the destroyed root or
                    # on any view of it
                    root_clients = OrderedSet()
                    for r in root_impact:
                        # any client entry that remains should have a
                        # nonzero use count
                        assert not [a for a, c in self.clients[r].items()
                                    if not c]
                        root_clients.update(
                            [a for a, c in self.clients[r].items() if c])
                    root_clients.remove(app)
                    if root_clients:
                        rval[app] = root_clients

        return rval
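
The tolerate_same / tolerate_aliased protocol described in the comments above is easiest to see in isolation. Below is a minimal, self-contained sketch of the aliasing check; FakeOp, check_aliasing, and the stand-in variables are invented for illustration and are not part of Theano's API, though the conditional mirrors the check in the example.

    from collections import OrderedDict

    class FakeOp(object):
        # An in-place "add": destroys input 0, and declares that the *same*
        # variable may legally appear again as input 1 (add_inplace(x, x)).
        destroy_map = OrderedDict([(0, [0])])
        destroyhandler_tolerate_same = [(0, 1)]

    def check_aliasing(op, inputs, destroyed_idx, root_impact):
        """Raise if a non-tolerated input aliases the destroyed variable.

        `root_impact` is the set of variables viewing the destroyed root.
        """
        tolerate_same = getattr(op, 'destroyhandler_tolerate_same', [])
        tolerated = set(idx1 for idx0, idx1 in tolerate_same
                        if idx0 == destroyed_idx)
        tolerated.add(destroyed_idx)
        tolerate_aliased = getattr(op, 'destroyhandler_tolerate_aliased', [])
        ignored = set(idx1 for idx0, idx1 in tolerate_aliased
                      if idx0 == destroyed_idx)
        destroyed_variable = inputs[destroyed_idx]
        for i, input in enumerate(inputs):
            if i in ignored:
                continue
            if input in root_impact \
                    and (i not in tolerated or input is not destroyed_variable):
                raise ValueError("Input aliasing: (%i, %i)"
                                 % (destroyed_idx, i))

    x, x_view = object(), object()     # x_view stands in for a view of x
    check_aliasing(FakeOp(), [x, x], 0, set([x]))   # ok: (0, 1) tolerated
    try:
        check_aliasing(FakeOp(), [x, x_view], 0, set([x, x_view]))
    except ValueError:
        pass                           # add_inplace(x, x.T) is rejected

As for the return value of orderings: each entry rval[app] is a set of Apply nodes that must run before app, i.e. extra precedence edges layered on top of the ordinary data dependencies when the graph is scheduled.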
Example #32
0
    def __init__(self, iterable=None):
        """Initialize the backing OrderedDict and, if an iterable was
        given, insert its elements."""
        self.data = OrderedDict()
        if iterable is not None:
            self.update(iterable)
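
Example #32 shows only the constructor of an OrderedDict-backed set. For context, here is one plausible completion of such a class; every method other than __init__ is an assumption about how a class like this is typically rounded out, not the original source.

    from collections import OrderedDict

    class OrderedSetSketch(object):
        """An insertion-ordered set backed by an OrderedDict
        (the dict values are unused; key order records insertion order)."""

        def __init__(self, iterable=None):
            self.data = OrderedDict()
            if iterable is not None:
                self.update(iterable)

        def add(self, item):
            self.data[item] = None

        def update(self, iterable):
            for item in iterable:
                self.add(item)

        def __contains__(self, item):
            return item in self.data

        def __iter__(self):
            return iter(self.data)

        def __len__(self):
            return len(self.data)

    s = OrderedSetSketch([3, 1, 3, 2])
    assert list(s) == [3, 1, 2]        # duplicates dropped, order kept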