Exemplo n.º 1
    def _get_variables(self):
        """Collect variables, updates and auxiliary variables.

        In addition collects all :class:`.Scan` ops and recurses in the
        respective inner Theano graphs.

        The results are stored as ``self.variables`` and ``self.updates``.
        When at least one output is not a shared variable,
        ``self.sorted_apply_nodes``, ``self.scans``,
        ``self.sorted_scan_nodes`` and ``self._scan_graphs`` are set as
        well.

        """
        # NOTE(review): this listing had lost lines (docstring terminator,
        # closing brackets, the wrapper around the scan-op list and the
        # statements that fill `variables`). They were restored from the
        # remaining structure -- confirm against the upstream source.
        updates = OrderedDict()

        shared_outputs = [o for o in self.outputs if is_shared_variable(o)]
        usual_outputs = [o for o in self.outputs if not is_shared_variable(o)]
        # Used as-is when every output is a shared variable.
        variables = shared_outputs

        if usual_outputs:
            # Sort apply nodes topologically, get variables and remove
            # duplicates
            inputs = graph.inputs(self.outputs)
            self.sorted_apply_nodes = graph.io_toposort(inputs, usual_outputs)
            # NOTE(review): `unique` is assumed to be in scope here; the
            # line naming the wrapper was missing from the listing.
            self.scans = list(
                unique([
                    node.op for node in self.sorted_apply_nodes
                    if isinstance(node.op, Scan)
                ]))
            self.sorted_scan_nodes = [
                node for node in self.sorted_apply_nodes
                if isinstance(node.op, Scan)
            ]
            # One nested graph per collected scan op.
            self._scan_graphs = [
                ComputationGraph(scan.outputs) for scan in self.scans
            ]

            # Inputs of every apply node in topological order, first
            # occurrence only, followed by outputs not met as an input.
            seen = set()
            main_vars = []
            for apply_node in self.sorted_apply_nodes:
                for var in apply_node.inputs:
                    if var not in seen:
                        seen.add(var)
                        main_vars.append(var)
            main_vars.extend(var for var in self.outputs if var not in seen)

            # While preserving order add auxiliary variables, and collect
            # updates
            seen_annotations = set()
            # Intermediate variables could be auxiliary
            seen_avs = set(main_vars)
            variables = []
            for var in main_vars:
                variables.append(var)
                for annotation in getattr(var.tag, 'annotations', []):
                    if annotation in seen_annotations:
                        continue
                    seen_annotations.add(annotation)
                    for av in annotation.auxiliary_variables:
                        if av not in seen_avs:
                            seen_avs.add(av)
                            variables.append(av)
                    updates = dict_union(updates, annotation.updates)

        self.variables = variables
        self.updates = updates
Exemplo n.º 2
    def _get_variables(self):
        """Collect variables, updates and auxiliary variables.

        In addition collects all :class:`.Scan` ops and recurses in the
        respective inner Theano graphs.

        The results are stored as ``self.variables`` and ``self.updates``.
        When at least one output is not a shared variable,
        ``self.sorted_apply_nodes``, ``self.scans``,
        ``self.sorted_scan_nodes`` and ``self._scan_graphs`` are set as
        well.

        """
        # NOTE(review): the docstring terminator and the statements that
        # fill `variables` and record visited annotations were missing
        # from this listing; without them `self.variables` stayed empty.
        # Restored from the comments and unused names left behind --
        # confirm against the upstream source.
        updates = OrderedDict()

        shared_outputs = [o for o in self.outputs if is_shared_variable(o)]
        usual_outputs = [o for o in self.outputs if not is_shared_variable(o)]
        # Used as-is when every output is a shared variable.
        variables = shared_outputs

        if usual_outputs:
            # Sort apply nodes topologically, get variables and remove
            # duplicates
            inputs = graph.inputs(self.outputs)
            self.sorted_apply_nodes = graph.io_toposort(inputs, usual_outputs)
            self.sorted_scan_nodes = [node for node in self.sorted_apply_nodes
                                      if isinstance(node.op, Scan)]
            self.scans = list(unique([node.op
                                      for node in self.sorted_scan_nodes]))
            # One nested graph per collected scan op.
            self._scan_graphs = [ComputationGraph(scan.outputs)
                                 for scan in self.scans]

            # Inputs of every apply node in topological order, first
            # occurrence only, followed by outputs not met as an input.
            seen = set()
            all_inputs = chain.from_iterable(
                apply_node.inputs for apply_node in self.sorted_apply_nodes)
            main_vars = (
                [var for var in all_inputs
                 if not (var in seen or seen.add(var))] +
                [var for var in self.outputs if var not in seen])

            # While preserving order add auxiliary variables, and collect
            # updates
            seen_annotations = set()
            # Intermediate variables could be auxiliary
            seen_avs = set(main_vars)
            variables = []
            for var in main_vars:
                variables.append(var)
                for annotation in getattr(var.tag, 'annotations', []):
                    if annotation not in seen_annotations:
                        seen_annotations.add(annotation)
                        new_avs = [
                            av for av in annotation.auxiliary_variables
                            if not (av in seen_avs or seen_avs.add(av))]
                        variables.extend(new_avs)
                        updates = dict_union(updates, annotation.updates)

        self.variables = variables
        self.updates = updates
Exemplo n.º 3
    def test_many_steps(self):
        """Compare the iterated ``self.simple`` output with a NumPy loop.

        Also checks that the single initial state found in the graph of
        the output is a shared variable named ``'initial_state'``.
        """
        # NOTE(review): two continuation lines (the dtype argument of
        # `numpy.asarray` and the argument of the variable filter) were
        # missing from this listing and have been restored -- confirm.
        x = tensor.tensor3('x')
        mask = tensor.matrix('mask')
        h = self.simple.apply(x, mask=mask, iterate=True)
        calc_h = theano.function(inputs=[x, mask], outputs=[h])

        # The 24 permutations of range(4) give a (24, 4) array, which is
        # then broadcast along a trailing axis of length 3.
        x_val = 0.1 * numpy.asarray(list(itertools.permutations(range(4))),
                                    dtype=theano.config.floatX)
        x_val = numpy.ones((24, 4, 3),
                           dtype=theano.config.floatX) * x_val[..., None]
        mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
        mask_val[12:24, 3] = 0
        # Index 0 holds the all-zero initial state.
        h_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
        # The reference uses recurrent weights 2 * ones((3, 3)); presumably
        # `self.simple` is initialised to match where it is built.
        weights = 2 * numpy.ones((3, 3))
        for i in range(1, 25):
            keep = mask_val[i - 1, :, None]
            candidate = numpy.tanh(h_val[i - 1].dot(weights) + x_val[i - 1])
            # Masked-out entries carry the previous state forward.
            h_val[i] = keep * candidate + (1 - keep) * h_val[i - 1]
        h_val = h_val[1:]
        assert_allclose(h_val, calc_h(x_val, mask_val)[0], rtol=1e-04)

        # Also test that initial state is a parameter
        initial_state, = VariableFilter(roles=[INITIAL_STATE])(
            ComputationGraph(h))
        assert is_shared_variable(initial_state)
        assert initial_state.name == 'initial_state'
Exemplo n.º 4
    def get_snapshot(self, data):
        """Evaluate all role-carrying Theano variables on given data.

        Parameters
        ----------
        data : dict of (data source, data) pairs
            Data for input variables. The sources should match with the
            names of the input variables.

        Returns
        -------
        Dictionary of (variable, variable value on given data) pairs.

        """
        # NOTE(review): the docstring terminator and the closing brackets
        # of the list comprehension and of the returned `OrderedDict` were
        # missing from this listing and have been restored.
        role_variables = [
            var for var in self.variables
            if hasattr(var.tag, "roles") and not is_shared_variable(var)
        ]
        # One holder per role variable; the compiled function receives
        # the (holder, variable) pairs.
        value_holders = [shared_like(var) for var in role_variables]
        function = self.get_theano_function(
            equizip(value_holders, role_variables))
        # Inputs are looked up in `data` by variable name.
        function(*(data[input_.name] for input_ in self.inputs))
        return OrderedDict([
            (var, value_holder.get_value(borrow=True))
            for var, value_holder in equizip(role_variables, value_holders)
        ])
Exemplo n.º 5
    def test_many_steps(self):
        """Compare the iterated ``self.lstm`` states with a NumPy loop.

        Also checks that the two initial states found in the graph of the
        output are shared variables named ``'initial_state'`` and
        ``'initial_cells'``.
        """
        # NOTE(review): two continuation lines (the dtype argument of
        # `numpy.asarray` and the argument of the variable filter) were
        # missing from this listing and have been restored -- confirm.
        x = tensor.tensor3('x')
        mask = tensor.matrix('mask')
        h, c = self.lstm.apply(x, mask=mask, iterate=True)
        calc_h = theano.function(inputs=[x, mask], outputs=[h])

        # First 24 permutations of range(12): a (24, 12) array that is
        # repeated over a middle axis of length 4.
        x_val = (0.1 * numpy.asarray(
            list(itertools.islice(itertools.permutations(range(12)), 0, 24)),
            dtype=theano.config.floatX))
        x_val = numpy.ones((24, 4, 12),
                           dtype=theano.config.floatX) * x_val[:, None, :]
        mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
        mask_val[12:24, 3] = 0
        # Index 0 holds the all-zero initial state / cells.
        h_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
        c_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
        # Reference weights; presumably `self.lstm` is initialised to
        # match where it is built.
        W_state_val = 2 * numpy.ones((3, 12), dtype=theano.config.floatX)
        W_cell_to_in = 2 * numpy.ones((3,), dtype=theano.config.floatX)
        W_cell_to_out = 2 * numpy.ones((3,), dtype=theano.config.floatX)
        W_cell_to_forget = 2 * numpy.ones((3,), dtype=theano.config.floatX)

        def sigmoid(x):
            return 1. / (1. + numpy.exp(-x))

        for i in range(1, 25):
            prev_h = h_val[i - 1]
            prev_c = c_val[i - 1]
            keep = mask_val[i - 1, :, None]
            # The 12 activation columns are consumed in four slices of 3.
            activation = numpy.dot(prev_h, W_state_val) + x_val[i - 1]
            i_t = sigmoid(activation[:, :3] + prev_c * W_cell_to_in)
            f_t = sigmoid(activation[:, 3:6] + prev_c * W_cell_to_forget)
            next_c = f_t * prev_c + i_t * numpy.tanh(activation[:, 6:9])
            o_t = sigmoid(activation[:, 9:12] + next_c * W_cell_to_out)
            next_h = o_t * numpy.tanh(next_c)
            # Masked-out entries carry the previous values forward.
            h_val[i] = keep * next_h + (1 - keep) * prev_h
            c_val[i] = keep * next_c + (1 - keep) * prev_c

        h_val = h_val[1:]
        assert_allclose(h_val, calc_h(x_val, mask_val)[0], rtol=1e-04)

        # Also test that initial state is a parameter
        initial1, initial2 = VariableFilter(roles=[INITIAL_STATE])(
            ComputationGraph(h))
        assert is_shared_variable(initial1)
        assert is_shared_variable(initial2)
        assert {initial1.name, initial2.name} == {
            'initial_state', 'initial_cells'}
Exemplo n.º 6
    def _get_variables(self):
        """Collect variables, updates and auxiliary variables.

        The results are stored as ``self.variables`` and ``self.updates``.
        """
        # NOTE(review): this listing had lost lines (the head of the
        # flattening expression, a closing bracket and the statements that
        # fill `variables`). They were restored from the remaining
        # structure -- confirm against the upstream source.
        updates = OrderedDict()

        shared_outputs = [o for o in self.outputs if is_shared_variable(o)]
        usual_outputs = [o for o in self.outputs if not is_shared_variable(o)]
        # Used as-is when every output is a shared variable.
        variables = shared_outputs

        if usual_outputs:
            # Sort apply nodes topologically, get variables and remove
            # duplicates
            inputs = graph.inputs(self.outputs)
            sorted_apply_nodes = graph.io_toposort(inputs, usual_outputs)

            # Inputs of every apply node in topological order, first
            # occurrence only; the outputs are appended unconditionally.
            seen = set()
            ordered_inputs = []
            for apply_node in sorted_apply_nodes:
                for var in apply_node.inputs:
                    if var not in seen:
                        seen.add(var)
                        ordered_inputs.append(var)
            main_vars = ordered_inputs + self.outputs

            # While preserving order add auxiliary variables, and collect
            # updates
            seen_annotations = set()
            # Intermediate variables could be auxiliary
            seen_avs = set(main_vars)
            variables = []
            for var in main_vars:
                variables.append(var)
                for annotation in getattr(var.tag, 'annotations', []):
                    if annotation in seen_annotations:
                        continue
                    seen_annotations.add(annotation)
                    for av in annotation.auxiliary_variables:
                        if av not in seen_avs:
                            seen_avs.add(av)
                            variables.append(av)
                    updates = dict_union(updates, annotation.updates)

        self.variables = variables
        self.updates = updates
Exemplo n.º 7
 def simple_assertions(self, updates, num_bricks=2, num_updates=4):
     """Run the checks shared by the simple tests on ``updates``.

     Asserts that ``updates`` has ``num_updates`` entries whose first
     items are all shared variables, that the first items carrying the
     population-mean role and those carrying the population-stdev role
     do not overlap, and that each of the two groups maps onto
     ``num_bricks`` distinct bricks.
     """
     assert len(updates) == num_updates
     assert all(is_shared_variable(entry[0]) for entry in updates)
     # This order is somewhat arbitrary and implementation_dependent
     mean_targets = {
         entry[0] for entry in updates
         if has_roles(entry[0], [BATCH_NORM_POPULATION_MEAN])
     }
     stdev_targets = {
         entry[0] for entry in updates
         if has_roles(entry[0], [BATCH_NORM_POPULATION_STDEV])
     }
     assert mean_targets.isdisjoint(stdev_targets)
     mean_bricks = {get_brick(target) for target in mean_targets}
     assert len(mean_bricks) == num_bricks
     stdev_bricks = {get_brick(target) for target in stdev_targets}
     assert len(stdev_bricks) == num_bricks
Exemplo n.º 8
 def simple_assertions(self, updates, num_bricks=2, num_updates=4):
     """Apply the assertions common to the simple tests.

     ``updates`` must contain ``num_updates`` items, each starting with a
     shared variable. The leading variables tagged with the
     population-mean role and those tagged with the population-stdev
     role must be disjoint sets, each belonging to ``num_bricks``
     different bricks.
     """
     assert len(updates) == num_updates
     assert all(is_shared_variable(item[0]) for item in updates)
     # This order is somewhat arbitrary and implementation_dependent
     population_means = set()
     for item in updates:
         if has_roles(item[0], [BATCH_NORM_POPULATION_MEAN]):
             population_means.add(item[0])
     population_stdevs = set()
     for item in updates:
         if has_roles(item[0], [BATCH_NORM_POPULATION_STDEV]):
             population_stdevs.add(item[0])
     assert population_means.isdisjoint(population_stdevs)
     assert len({get_brick(v) for v in population_means}) == num_bricks
     assert len({get_brick(v) for v in population_stdevs}) == num_bricks
Exemplo n.º 9
    def test_many_steps(self):
        """Compare the iterated ``self.reset_only`` output with NumPy.

        Also checks that the single initial state found in the graph of
        the output is a shared variable named ``'initial_state'``.
        """
        # NOTE(review): several lines were missing from this listing (the
        # dtype argument of `numpy.asarray`, the head and the tolerance of
        # the `assert_allclose` call, the argument of the variable filter)
        # and have been restored -- confirm against the upstream test.
        x = tensor.tensor3('x')
        gi = tensor.tensor3('gi')
        mask = tensor.matrix('mask')
        h = self.reset_only.apply(x, gi, mask=mask)
        calc_h = theano.function(inputs=[x, gi, mask], outputs=[h])

        # The 24 permutations of range(4) give a (24, 4) array, which is
        # then broadcast along a trailing axis of length 3.
        x_val = 0.1 * numpy.asarray(list(itertools.permutations(range(4))),
                                    dtype=theano.config.floatX)
        x_val = numpy.ones((24, 4, 3),
                           dtype=theano.config.floatX) * x_val[..., None]
        ri_val = 0.3 - x_val
        zi_val = 2 * ri_val
        mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
        mask_val[12:24, 3] = 0
        # Index 0 holds the all-zero initial state.
        h_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
        W = self.reset_only.state_to_state.get_value()
        # The gate matrix is split column-wise: first 3 columns for the
        # update gate, the remaining ones for the reset gate.
        gate_weights = self.reset_only.state_to_gates.get_value()
        Wz = gate_weights[:, :3]
        Wr = gate_weights[:, 3:]

        for i in range(1, 25):
            prev_h = h_val[i - 1]
            keep = mask_val[i - 1, :, None]
            z_val = numpy.tanh(prev_h.dot(Wz) + zi_val[i - 1])
            r_val = numpy.tanh(prev_h.dot(Wr) + ri_val[i - 1])
            candidate = numpy.tanh((r_val * prev_h).dot(W) + x_val[i - 1])
            gated = z_val * candidate + (1 - z_val) * prev_h
            # Masked-out entries carry the previous state forward.
            h_val[i] = keep * gated + (1 - keep) * prev_h
        h_val = h_val[1:]
        # TODO Figure out why this tolerance needs to be so big
        # NOTE(review): the tolerance value itself was lost; 1e-04 is the
        # value the sibling recurrent tests use -- confirm.
        assert_allclose(
            h_val,
            calc_h(x_val, numpy.concatenate(
                [zi_val, ri_val], axis=2), mask_val)[0],
            rtol=1e-04)

        # Also test that initial state is a parameter
        initial_state, = VariableFilter(roles=[INITIAL_STATE])(
            ComputationGraph(h))
        assert is_shared_variable(initial_state)
        assert initial_state.name == 'initial_state'
Exemplo n.º 10
    def get_snapshot(self, data):
        """Evaluate all role-carrying Theano variables on given data.

        Parameters
        ----------
        data : dict of (data source, data) pairs
            Data for input variables. The sources should match with the
            names of the input variables.

        Returns
        -------
        Dictionary of (variable, variable value on given data) pairs.

        """
        # NOTE(review): the docstring terminator and the tail of the
        # return expression were missing from this listing and have been
        # restored.
        role_variables = [var for var in self.variables
                          if hasattr(var.tag, "roles") and
                          not is_shared_variable(var)]
        # One holder per role variable; the compiled function receives
        # the (holder, variable) pairs.
        value_holders = [shared_like(var) for var in role_variables]
        function = self.get_theano_function(zip(value_holders, role_variables))
        # Inputs are looked up in `data` by variable name.
        function(*(data[input_.name] for input_ in self.inputs))
        return OrderedDict([(var, value_holder.get_value(borrow=True))
                            for var, value_holder in zip(role_variables,
                                                         value_holders)])
Exemplo n.º 11
    def do_many_steps(self, stack, skip_connections=False, low_memory=False):
        """Compare the iterated output of ``stack`` with a NumPy loop.

        Parameters
        ----------
        stack
            Object whose ``apply`` method is called with ``iterate=True``,
            ``low_memory`` and the symbolic inputs built here.
        skip_connections : bool
            If ``True``, an input variable is created for every layer and
            the raw input is added to the activation of each layer above
            the first one.
        low_memory : bool
            Forwarded unchanged to ``stack.apply``.

        """
        # NOTE(review): several lines were missing from this listing (the
        # dtype argument of `numpy.asarray`, an `else:`, the `outputs`
        # argument of `theano.function`, the appends to `h_vs`/`c_vs`, the
        # argument of the variable filter) and have been restored --
        # confirm against the upstream test.
        depth = self.depth

        # 24 steps
        #  4 batch examples
        # 12 dimensions per step
        x_val = (0.1 * numpy.asarray(
            list(itertools.islice(itertools.permutations(range(12)), 0, 24)),
            dtype=theano.config.floatX))
        x_val = numpy.ones((24, 4, 12),
                           dtype=theano.config.floatX) * x_val[:, None, :]
        # mask steps 12..23 of the last batch example
        mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
        mask_val[12:24, 3] = 0
        # unroll all states and cells for all steps and also initial value
        h_val = numpy.zeros((depth, 25, 4, 3), dtype=theano.config.floatX)
        c_val = numpy.zeros((depth, 25, 4, 3), dtype=theano.config.floatX)
        # we will use same weights on all layers
        W_state2x_val = 2 * numpy.ones((3, 12), dtype=theano.config.floatX)
        W_state_val = 2 * numpy.ones((3, 12), dtype=theano.config.floatX)
        W_cell_to_in = 2 * numpy.ones((3,), dtype=theano.config.floatX)
        W_cell_to_out = 2 * numpy.ones((3,), dtype=theano.config.floatX)
        W_cell_to_forget = 2 * numpy.ones((3,), dtype=theano.config.floatX)

        # Ordered so that `kwargs.values()` matches `args_val` below.
        kwargs = OrderedDict()

        for d in range(depth):
            # Only layers above the first carry a suffix.
            suffix = RECURRENTSTACK_SEPARATOR + str(d) if d > 0 else ''
            if d == 0 or skip_connections:
                kwargs['inputs' + suffix] = tensor.tensor3('inputs' + suffix)
                kwargs['inputs' + suffix].tag.test_value = x_val

        kwargs['mask'] = tensor.matrix('mask')
        kwargs['mask'].tag.test_value = mask_val
        results = stack.apply(iterate=True, low_memory=low_memory, **kwargs)
        calc_h = theano.function(inputs=list(kwargs.values()),
                                 outputs=results)

        def sigmoid(x):
            return 1. / (1. + numpy.exp(-x))

        for i in range(1, 25):
            x_v = x_val[i - 1]
            keep = mask_val[i - 1, :, None]
            h_vs = []
            c_vs = []
            for d in range(depth):
                h_v = h_val[d][i - 1, :, :]
                c_v = c_val[d][i - 1, :, :]
                activation = numpy.dot(h_v, W_state_val) + x_v
                if skip_connections and d > 0:
                    activation += x_val[i - 1]

                i_t = sigmoid(activation[:, :3] + c_v * W_cell_to_in)
                f_t = sigmoid(activation[:, 3:6] + c_v * W_cell_to_forget)
                c_v1 = f_t * c_v + i_t * numpy.tanh(activation[:, 6:9])
                o_t = sigmoid(activation[:, 9:12] +
                              c_v1 * W_cell_to_out)
                h_v1 = o_t * numpy.tanh(c_v1)
                # Masked-out entries carry the previous values forward.
                h_v = keep * h_v1 + (1 - keep) * h_v
                c_v = keep * c_v1 + (1 - keep) * c_v
                # current layer output state transformed to input of next
                x_v = numpy.dot(h_v, W_state2x_val)

                h_vs.append(h_v)
                c_vs.append(c_v)

            # Written back only after every layer has read step i - 1.
            for d in range(depth):
                h_val[d][i, :, :] = h_vs[d]
                c_val[d][i, :, :] = c_vs[d]

        args_val = [x_val]*(depth if skip_connections else 1) + [mask_val]
        res = calc_h(*args_val)
        # Results are interleaved: states at even, cells at odd positions.
        for d in range(depth):
            assert_allclose(h_val[d][1:], res[d * 2], rtol=1e-4)
            assert_allclose(c_val[d][1:], res[d * 2 + 1], rtol=1e-4)

        # Also test that initial state is a parameter
        for h in results:
            initial_states = VariableFilter(roles=[INITIAL_STATE])(
                ComputationGraph(h))
            assert all(is_shared_variable(initial_state)
                       for initial_state in initial_states)
Exemplo n.º 12
        def recurrent_apply(brick, application, application_call,
                            *args, **kwargs):
            """Iterates a transition function.

            Parameters
            ----------
            iterate : bool
                If ``True`` iteration is made. By default ``True``.
            reverse : bool
                If ``True``, the sequences are processed in backward
                direction. ``False`` by default.
            return_initial_states : bool
                If ``True``, initial states are included in the returned
                state tensors. ``False`` by default.

            .. todo::

                * Handle `updates` returned by the :func:`theano.scan`
                * ``kwargs`` has a random order; check if this is a
                  problem.

            """
            # NOTE(review): many lines were missing from this listing
            # (docstring terminator, `else:` branches, trailing call
            # arguments). They were restored from the remaining structure;
            # every restored argument is flagged below -- confirm against
            # the upstream source.

            # Extract arguments related to iteration and immediately relay the
            # call to the wrapped function if `iterate=False`
            iterate = kwargs.pop('iterate', True)
            if not iterate:
                return application_function(brick, *args, **kwargs)
            reverse = kwargs.pop('reverse', False)
            return_initial_states = kwargs.pop('return_initial_states', False)

            # Push everything to kwargs
            for arg, arg_name in zip(args, arg_names):
                kwargs[arg_name] = arg

            # Make sure that all arguments for scan are tensor variables
            scan_arguments = (application.sequences + application.states +
                              application.contexts)
            for arg in scan_arguments:
                if arg in kwargs:
                    if kwargs[arg] is None:
                        del kwargs[arg]
                    else:
                        kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

            # Check which sequence and contexts were provided
            # NOTE(review): `must_have=False` is a restored argument.
            sequences_given = dict_subset(kwargs, application.sequences,
                                          must_have=False)
            contexts_given = dict_subset(kwargs, application.contexts,
                                         must_have=False)

            # Determine number of steps and batch size. The former
            # `if not iterate` sub-branch is dropped: it was unreachable
            # after the early return above.
            if len(sequences_given):
                # TODO Assumes 1 time dim!
                shape = list(sequences_given.values())[0].shape
                n_steps = shape[0]
                batch_size = shape[1]
            else:
                # TODO Raise error if n_steps and batch_size not found?
                n_steps = kwargs.pop('n_steps')
                batch_size = kwargs.pop('batch_size')

            # Handle the rest kwargs
            rest_kwargs = {key: value for key, value in kwargs.items()
                           if key not in scan_arguments}
            for value in rest_kwargs.values():
                if (isinstance(value, Variable) and not
                        is_shared_variable(value)):
                    # NOTE(review): the right-hand operand was lost; the
                    # text below is a stand-in -- confirm against the
                    # project's own message constant.
                    logger.warning("unknown input {}".format(value) +
                                   ' Your function uses a non-shared'
                                   ' variable other than those given by'
                                   ' scan explicitly. That can'
                                   ' significantly slow down `tensor.grad`'
                                   ' call. Did you forget to declare it in'
                                   ' `contexts`?')

            # Ensure that all initial states are available.
            for state_name in application.states:
                dim = brick.get_dim(state_name)
                if state_name in kwargs:
                    if isinstance(kwargs[state_name], NdarrayInitialization):
                        kwargs[state_name] = tensor.alloc(
                            kwargs[state_name].generate(brick.rng, (1, dim)),
                            batch_size, dim)
                    elif isinstance(kwargs[state_name], Application):
                        kwargs[state_name] = (
                            kwargs[state_name](state_name, batch_size,
                                               *args, **kwargs))
                else:
                    # TODO init_func returns 2D-tensor, fails for iterate=False
                    kwargs[state_name] = (
                        brick.initial_state(state_name, batch_size,
                                            *args, **kwargs))
                    assert kwargs[state_name]
            states_given = dict_subset(kwargs, application.states)

            # Theano issue 1772
            for name, state in states_given.items():
                states_given[name] = tensor.unbroadcast(state,
                                                        *range(state.ndim))

            def scan_function(*args):
                args = list(args)
                # Names in the order scan passes values: sequences, then
                # the outputs that are states, then contexts.
                arg_names = (list(sequences_given) +
                             [output for output in application.outputs
                              if output in application.states] +
                             list(contexts_given))
                kwargs = dict(equizip(arg_names, args))
                outputs = application(iterate=False, **kwargs)
                # We want to save the computation graph returned by the
                # `application_function` when it is called inside the
                # `theano.scan`.
                application_call.inner_inputs = args
                application_call.inner_outputs = pack(outputs)
                return outputs
            # Outputs that are states are seeded with their initial value.
            outputs_info = [
                states_given[name] if name in application.states
                else None
                for name in application.outputs]
            # NOTE(review): every argument after `sequences` is restored
            # from the names prepared above; further keyword arguments of
            # the original call, if any, cannot be recovered from here.
            result, updates = theano.scan(
                scan_function, sequences=list(sequences_given.values()),
                outputs_info=outputs_info,
                non_sequences=list(contexts_given.values()),
                n_steps=n_steps,
                go_backwards=reverse)
            result = pack(result)
            if return_initial_states:
                # Undo Subtensor
                for i in range(len(states_given)):
                    assert isinstance(result[i].owner.op,
                                      tensor.subtensor.Subtensor)
                    result[i] = result[i].owner.inputs[0]
            if updates:
                application_call.updates = dict_union(application_call.updates,
                                                      updates)

            return result
Exemplo n.º 13
        def recurrent_apply(brick, application, application_call, *args,
                            **kwargs):
            """Iterates a transition function.

            Parameters
            ----------
            iterate : bool
                If ``True`` iteration is made. By default ``True``.
            reverse : bool
                If ``True``, the sequences are processed in backward
                direction. ``False`` by default.
            return_initial_states : bool
                If ``True``, initial states are included in the returned
                state tensors. ``False`` by default.

            """
            # NOTE(review): many lines were missing from this listing (end
            # of the signature, docstring terminator, `else:`/`try:`
            # lines, trailing call arguments). They were restored from the
            # remaining structure; every restored argument is flagged
            # below -- confirm against the upstream source.

            # Extract arguments related to iteration and immediately relay the
            # call to the wrapped function if `iterate=False`
            iterate = kwargs.pop('iterate', True)
            if not iterate:
                return application_function(brick, *args, **kwargs)
            reverse = kwargs.pop('reverse', False)
            return_initial_states = kwargs.pop('return_initial_states', False)

            # Push everything to kwargs
            for arg, arg_name in zip(args, arg_names):
                kwargs[arg_name] = arg

            # Make sure that all arguments for scan are tensor variables
            scan_arguments = (application.sequences + application.states +
                              application.contexts)
            for arg in scan_arguments:
                if arg in kwargs:
                    if kwargs[arg] is None:
                        del kwargs[arg]
                    else:
                        kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

            # Check which sequence and contexts were provided
            # NOTE(review): both arguments after `kwargs` are restored.
            sequences_given = dict_subset(kwargs,
                                          application.sequences,
                                          must_have=False)
            contexts_given = dict_subset(kwargs,
                                         application.contexts,
                                         must_have=False)

            # Determine number of steps and batch size.
            if len(sequences_given):
                # TODO Assumes 1 time dim!
                shape = list(sequences_given.values())[0].shape
                n_steps = shape[0]
                batch_size = shape[1]
            else:
                # TODO Raise error if n_steps and batch_size not found?
                n_steps = kwargs.pop('n_steps')
                batch_size = kwargs.pop('batch_size')

            # Handle the rest kwargs
            rest_kwargs = {
                key: value
                for key, value in kwargs.items() if key not in scan_arguments
            }
            for value in rest_kwargs.values():
                if (isinstance(value, Variable)
                        and not is_shared_variable(value)):
                    # NOTE(review): the right-hand operand was lost; the
                    # text below is a stand-in -- confirm against the
                    # project's own message constant.
                    logger.warning("unknown input {}".format(value) +
                                   ' Your function uses a non-shared'
                                   ' variable other than those given by'
                                   ' scan explicitly. That can'
                                   ' significantly slow down `tensor.grad`'
                                   ' call. Did you forget to declare it in'
                                   ' `contexts`?')

            # Ensure that all initial states are available.
            # NOTE(review): everything after `batch_size` is restored; the
            # result is indexed by state name below, so a mapping is
            # requested.
            initial_states = brick.initial_states(batch_size,
                                                  as_dict=True,
                                                  *args, **kwargs)
            for state_name in application.states:
                dim = brick.get_dim(state_name)
                if state_name in kwargs:
                    if isinstance(kwargs[state_name], NdarrayInitialization):
                        kwargs[state_name] = tensor.alloc(
                            kwargs[state_name].generate(brick.rng, (1, dim)),
                            batch_size, dim)
                    elif isinstance(kwargs[state_name], Application):
                        kwargs[state_name] = (kwargs[state_name](state_name,
                                                                 batch_size,
                                                                 *args,
                                                                 **kwargs))
                else:
                    try:
                        kwargs[state_name] = initial_states[state_name]
                    except KeyError:
                        raise KeyError(
                            "no initial state for '{}' of the brick {}".format(
                                state_name, brick.name))
            states_given = dict_subset(kwargs, application.states)

            # Theano issue 1772
            for name, state in states_given.items():
                states_given[name] = tensor.unbroadcast(
                    state, *range(state.ndim))

            def scan_function(*args):
                args = list(args)
                # Names in the order scan passes values: sequences, then
                # the outputs that are states, then contexts.
                arg_names = (list(sequences_given) + [
                    output for output in application.outputs
                    if output in application.states
                ] + list(contexts_given))
                kwargs = dict(equizip(arg_names, args))
                outputs = application(iterate=False, **kwargs)
                # We want to save the computation graph returned by the
                # `application_function` when it is called inside the
                # `theano.scan`.
                application_call.inner_inputs = args
                application_call.inner_outputs = pack(outputs)
                return outputs

            # Outputs that are states are seeded with their initial value.
            outputs_info = [
                states_given[name] if name in application.states else None
                for name in application.outputs
            ]
            # NOTE(review): the whole argument list was lost; it is
            # restored from the names prepared above. Further keyword
            # arguments of the original call, if any, cannot be recovered
            # from here.
            result, updates = theano.scan(
                scan_function,
                sequences=list(sequences_given.values()),
                outputs_info=outputs_info,
                non_sequences=list(contexts_given.values()),
                n_steps=n_steps,
                go_backwards=reverse)
            result = pack(result)
            if return_initial_states:
                # Undo Subtensor
                for i in range(len(states_given)):
                    assert isinstance(result[i].owner.op,
                                      tensor.subtensor.Subtensor)
                    result[i] = result[i].owner.inputs[0]
            if updates:
                application_call.updates = dict_union(application_call.updates,
                                                      updates)

            return result
Exemplo n.º 14
 def shared_variables(self):
     """Return the entries of ``self.variables`` that are shared variables.

     The order of ``self.variables`` is preserved.
     """
     shared = []
     for candidate in self.variables:
         if is_shared_variable(candidate):
             shared.append(candidate)
     return shared
Exemplo n.º 15
 def shared_variables(self):
     """List, in order, the shared variables among ``self.variables``."""
     selected = []
     for item in self.variables:
         if not is_shared_variable(item):
             continue
         selected.append(item)
     return selected
Exemplo n.º 16
        def recurrent_apply(brick, application, application_call,
                            *args, **kwargs):
            """Iterates a transition function.

            Parameters
            ----------
            iterate : bool
                If ``True`` iteration is made. By default ``True``.
            reverse : bool
                If ``True``, the sequences are processed in backward
                direction. ``False`` by default.
            return_initial_states : bool
                If ``True``, initial states are included in the returned
                state tensors. ``False`` by default.

            .. todo::

                * Handle `updates` returned by the :func:`theano.scan`
                * ``kwargs`` has a random order; check if this is a
                  problem.

            """
            # NOTE(review): many lines were missing from this listing
            # (docstring terminator, `else:` branches, the head of the
            # warning call, trailing call arguments). They were restored
            # from the remaining structure; every restored piece is
            # flagged below -- confirm against the upstream source.
            import warnings

            # Extract arguments related to iteration.
            iterate = kwargs.pop('iterate', True)
            reverse = kwargs.pop('reverse', False)
            return_initial_states = kwargs.pop('return_initial_states', False)

            # Push everything to kwargs
            for arg, arg_name in zip(args, arg_names):
                kwargs[arg_name] = arg
            # Separate sequences, states and contexts
            scan_arguments = (application.sequences + application.states +
                              application.contexts)

            # Check what is given and what is not
            def only_given(arg_names):
                return OrderedDict((arg_name, kwargs[arg_name])
                                   for arg_name in arg_names
                                   if kwargs.get(arg_name))
            sequences_given = only_given(application.sequences)
            contexts_given = only_given(application.contexts)

            # TODO Assumes 1 time dim!
            if len(sequences_given):
                shape = list(sequences_given.values())[0].shape
                if not iterate:
                    # A single step has no leading time axis.
                    batch_size = shape[0]
                else:
                    n_steps = shape[0]
                    batch_size = shape[1]
            else:
                # TODO Raise error if n_steps and batch_size not found?
                n_steps = kwargs.pop('n_steps')
                batch_size = kwargs.pop('batch_size')

            # Handle the rest kwargs
            rest_kwargs = {key: value for key, value in kwargs.items()
                           if key not in scan_arguments}
            for value in rest_kwargs.values():
                if (isinstance(value, Variable) and not
                        is_shared_variable(value)):
                    # NOTE(review): the call that emits this message was
                    # lost; `warnings.warn` is assumed -- confirm.
                    warnings.warn(
                        'Your function uses a non-shared variable other than'
                        ' those given by scan explicitly. That can'
                        ' significantly slow down `tensor.grad` call.'
                        ' Did you forget to declare it in `contexts`?')

            # Ensure that all initial states are available.
            for state_name in application.states:
                dim = brick.get_dim(state_name)
                if state_name in kwargs:
                    if isinstance(kwargs[state_name], NdarrayInitialization):
                        kwargs[state_name] = tensor.alloc(
                            kwargs[state_name].generate(brick.rng, (1, dim)),
                            batch_size, dim)
                    elif isinstance(kwargs[state_name], Application):
                        kwargs[state_name] = \
                            kwargs[state_name](state_name, batch_size,
                                               *args, **kwargs)
                else:
                    # TODO init_func returns 2D-tensor, fails for iterate=False
                    kwargs[state_name] = \
                        brick.initial_state(state_name, batch_size,
                                            *args, **kwargs)
                    assert kwargs[state_name]
            states_given = only_given(application.states)
            assert len(states_given) == len(application.states)

            # Theano issue 1772
            for name, state in states_given.items():
                states_given[name] = tensor.unbroadcast(state,
                                                        *range(state.ndim))

            # Apply methods
            if not iterate:
                return application_function(brick, **kwargs)

            def scan_function(*args):
                args = list(args)
                # Names in the order scan passes values: sequences,
                # states, then contexts.
                arg_names = (list(sequences_given) + list(states_given) +
                             list(contexts_given))
                kwargs = dict(zip(arg_names, args))
                return application_function(brick, **kwargs)
            # States are seeded with their initial values; the remaining
            # outputs have no initial value.
            outputs_info = (list(states_given.values())
                            + [None] * (len(application.outputs) -
                                        len(application.states)))
            # NOTE(review): every argument after `sequences` is restored
            # from the names prepared above; further keyword arguments of
            # the original call, if any, cannot be recovered from here.
            result, updates = theano.scan(
                scan_function, sequences=list(sequences_given.values()),
                outputs_info=outputs_info,
                non_sequences=list(contexts_given.values()),
                n_steps=n_steps,
                go_backwards=reverse)
            result = pack(result)
            if return_initial_states:
                # Undo Subtensor
                for i in range(len(states_given)):
                    assert isinstance(result[i].owner.op,
                                      tensor.subtensor.Subtensor)
                    result[i] = result[i].owner.inputs[0]
            if updates:
                application_call.updates = dict_union(application_call.updates,
                                                      updates)
            return result