def _get_variables(self):
    """Collect variables, updates and auxiliary variables.

    In addition collects all :class:`.Scan` ops and recurses in the
    respective inner Theano graphs.

    """
    updates = OrderedDict()

    shared_outputs = [o for o in self.outputs if is_shared_variable(o)]
    usual_outputs = [o for o in self.outputs if not is_shared_variable(o)]
    variables = shared_outputs

    if usual_outputs:
        # Sort apply nodes topologically, get variables and remove
        # duplicates
        inputs = graph.inputs(self.outputs)
        self.sorted_apply_nodes = graph.io_toposort(inputs, usual_outputs)
        self.scans = list(unique([node.op for node
                                  in self.sorted_apply_nodes
                                  if isinstance(node.op, Scan)]))
        self.sorted_scan_nodes = [node for node in self.sorted_apply_nodes
                                  if isinstance(node.op, Scan)]
        self._scan_graphs = [ComputationGraph(scan.outputs)
                             for scan in self.scans]

        seen = set()
        main_vars = (
            [var for var in list(chain(
                *[apply_node.inputs
                  for apply_node in self.sorted_apply_nodes]))
             if not (var in seen or seen.add(var))] +
            [var for var in self.outputs if var not in seen])

        # While preserving order add auxiliary variables, and collect
        # updates
        seen = set()
        # Intermediate variables could be auxiliary
        seen_avs = set(main_vars)
        variables = []
        for var in main_vars:
            variables.append(var)
            for annotation in getattr(var.tag, 'annotations', []):
                if annotation not in seen:
                    seen.add(annotation)
                    new_avs = [
                        av for av in annotation.auxiliary_variables
                        if not (av in seen_avs or seen_avs.add(av))]
                    variables.extend(new_avs)
                    updates = dict_union(updates, annotation.updates)

    self.variables = variables
    self.updates = updates
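# A minimal, self-contained sketch (not part of the library) of the
# order-preserving de-duplication idiom used above: `set.add` returns
# None, so `not (x in seen or seen.add(x))` is True exactly once per
# distinct element, recording it as a side effect.
seen = set()
items = ['a', 'b', 'a', 'c', 'b']
unique_in_order = [x for x in items if not (x in seen or seen.add(x))]
assert unique_in_order == ['a', 'b', 'c']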
def test_many_steps(self):
    x = tensor.tensor3('x')
    mask = tensor.matrix('mask')
    h = self.simple.apply(x, mask=mask, iterate=True)
    calc_h = theano.function(inputs=[x, mask], outputs=[h])

    x_val = 0.1 * numpy.asarray(list(itertools.permutations(range(4))),
                                dtype=theano.config.floatX)
    x_val = numpy.ones((24, 4, 3),
                       dtype=theano.config.floatX) * x_val[..., None]
    mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
    mask_val[12:24, 3] = 0
    h_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
    for i in range(1, 25):
        h_val[i] = numpy.tanh(h_val[i - 1].dot(
            2 * numpy.ones((3, 3))) + x_val[i - 1])
        h_val[i] = (mask_val[i - 1, :, None] * h_val[i] +
                    (1 - mask_val[i - 1, :, None]) * h_val[i - 1])
    h_val = h_val[1:]
    assert_allclose(h_val, calc_h(x_val, mask_val)[0], rtol=1e-04)

    # Also test that initial state is a parameter
    initial_state, = VariableFilter(roles=[INITIAL_STATE])(
        ComputationGraph(h))
    assert is_shared_variable(initial_state)
    assert initial_state.name == 'initial_state'
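# A one-step sketch (illustrative, not library code) of the masking
# convention used in all the reference loops here: where mask == 0 the
# previous state is carried over unchanged.
import numpy
old, new = numpy.zeros((2, 3)), numpy.ones((2, 3))
mask = numpy.array([1., 0.])
stepped = mask[:, None] * new + (1 - mask[:, None]) * old
assert (stepped[0] == 1).all()  # example 0 advanced to `new`
assert (stepped[1] == 0).all()  # example 1 kept `old`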
def get_snapshot(self, data):
    """Evaluate all role-carrying Theano variables on given data.

    Parameters
    ----------
    data : dict of (data source, data) pairs
        Data for input variables. The sources should match with the
        names of the input variables.

    Returns
    -------
    Dictionary of (variable, variable value on given data) pairs.

    """
    role_variables = [var for var in self.variables
                      if hasattr(var.tag, "roles") and
                      not is_shared_variable(var)]
    value_holders = [shared_like(var) for var in role_variables]
    function = self.get_theano_function(
        equizip(value_holders, role_variables))
    function(*(data[input_.name] for input_ in self.inputs))
    return OrderedDict(
        [(var, value_holder.get_value(borrow=True))
         for var, value_holder in equizip(role_variables, value_holders)])
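# Hedged usage sketch for `get_snapshot` (the brick, dimensions and the
# 'features' name are illustrative): the keys of `data` must match the
# names of the graph's input variables.
import numpy
import theano
from theano import tensor
from blocks.bricks import Linear
from blocks.initialization import Constant

x = tensor.matrix('features')
brick = Linear(input_dim=3, output_dim=2,
               weights_init=Constant(1), biases_init=Constant(0))
y = brick.apply(x)
brick.initialize()
cg = ComputationGraph(y.sum())
snapshot = cg.get_snapshot(
    {'features': numpy.ones((4, 3), dtype=theano.config.floatX)})
# Maps every role-carrying, non-shared variable (e.g. the brick's
# input and output) to its value on this batch.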
def test_many_steps(self):
    x = tensor.tensor3('x')
    mask = tensor.matrix('mask')
    h, c = self.lstm.apply(x, mask=mask, iterate=True)
    calc_h = theano.function(inputs=[x, mask], outputs=[h])

    x_val = (0.1 * numpy.asarray(
        list(itertools.islice(itertools.permutations(range(12)), 0, 24)),
        dtype=theano.config.floatX))
    x_val = numpy.ones((24, 4, 12),
                       dtype=theano.config.floatX) * x_val[:, None, :]
    mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
    mask_val[12:24, 3] = 0
    h_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
    c_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
    W_state_val = 2 * numpy.ones((3, 12), dtype=theano.config.floatX)
    W_cell_to_in = 2 * numpy.ones((3,), dtype=theano.config.floatX)
    W_cell_to_out = 2 * numpy.ones((3,), dtype=theano.config.floatX)
    W_cell_to_forget = 2 * numpy.ones((3,), dtype=theano.config.floatX)

    def sigmoid(x):
        return 1. / (1. + numpy.exp(-x))

    for i in range(1, 25):
        activation = numpy.dot(h_val[i - 1], W_state_val) + x_val[i - 1]
        i_t = sigmoid(activation[:, :3] + c_val[i - 1] * W_cell_to_in)
        f_t = sigmoid(activation[:, 3:6] +
                      c_val[i - 1] * W_cell_to_forget)
        c_val[i] = f_t * c_val[i - 1] + i_t * numpy.tanh(
            activation[:, 6:9])
        o_t = sigmoid(activation[:, 9:12] + c_val[i] * W_cell_to_out)
        h_val[i] = o_t * numpy.tanh(c_val[i])
        h_val[i] = (mask_val[i - 1, :, None] * h_val[i] +
                    (1 - mask_val[i - 1, :, None]) * h_val[i - 1])
        c_val[i] = (mask_val[i - 1, :, None] * c_val[i] +
                    (1 - mask_val[i - 1, :, None]) * c_val[i - 1])
    h_val = h_val[1:]
    assert_allclose(h_val, calc_h(x_val, mask_val)[0], rtol=1e-04)

    # Also test that initial state is a parameter
    initial1, initial2 = VariableFilter(roles=[INITIAL_STATE])(
        ComputationGraph(h))
    assert is_shared_variable(initial1)
    assert is_shared_variable(initial2)
    assert {initial1.name, initial2.name} == {
        'initial_state', 'initial_cells'}
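# The reference loop above hard-codes the fused LSTM gate layout along
# the last axis: [input, forget, cell candidate, output]. A small
# sketch of that slicing convention (dim = 3, as in the test; the
# variable names are illustrative):
import numpy
activation = numpy.arange(12.).reshape(1, 12)
in_gate = activation[:, 0:3]
forget_gate = activation[:, 3:6]
cell_candidate = activation[:, 6:9]
out_gate = activation[:, 9:12]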
def _get_variables(self):
    """Collect variables, updates and auxiliary variables."""
    updates = OrderedDict()

    shared_outputs = [o for o in self.outputs if is_shared_variable(o)]
    usual_outputs = [o for o in self.outputs if not is_shared_variable(o)]
    variables = shared_outputs

    if usual_outputs:
        # Sort apply nodes topologically, get variables and remove
        # duplicates
        inputs = graph.inputs(self.outputs)
        sorted_apply_nodes = graph.io_toposort(inputs, usual_outputs)
        seen = set()
        main_vars = [var for var in list(chain(
            *[apply_node.inputs for apply_node in sorted_apply_nodes]))
            if not (var in seen or seen.add(var))] + self.outputs

        # While preserving order add auxiliary variables, and collect
        # updates
        seen = set()
        # Intermediate variables could be auxiliary
        seen_avs = set(main_vars)
        variables = []
        for var in main_vars:
            variables.append(var)
            for annotation in getattr(var.tag, 'annotations', []):
                if annotation not in seen:
                    seen.add(annotation)
                    new_avs = [
                        av for av in annotation.auxiliary_variables
                        if not (av in seen_avs or seen_avs.add(av))]
                    variables.extend(new_avs)
                    updates = dict_union(updates, annotation.updates)

    self.variables = variables
    self.updates = updates
def simple_assertions(self, updates, num_bricks=2, num_updates=4):
    """Shared assertions for simple tests."""
    assert len(updates) == num_updates
    assert all(is_shared_variable(u[0]) for u in updates)
    # This order is somewhat arbitrary and implementation-dependent
    means = set(u[0] for u in updates
                if has_roles(u[0], [BATCH_NORM_POPULATION_MEAN]))
    stdevs = set(u[0] for u in updates
                 if has_roles(u[0], [BATCH_NORM_POPULATION_STDEV]))
    assert means.isdisjoint(stdevs)
    assert len(set(get_brick(v) for v in means)) == num_bricks
    assert len(set(get_brick(v) for v in stdevs)) == num_bricks
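# Hedged sketch of how an `updates` list checked by `simple_assertions`
# is typically produced (assuming Blocks' `apply_batch_normalization`
# and `get_batch_normalization_updates` helpers; the MLP and dims are
# illustrative):
from theano import tensor
from blocks.bricks import BatchNormalizedMLP, Tanh
from blocks.graph import (ComputationGraph, apply_batch_normalization,
                          get_batch_normalization_updates)

mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6])
x = tensor.matrix('x')
training_cg = apply_batch_normalization(ComputationGraph([mlp.apply(x)]))
updates = get_batch_normalization_updates(training_cg)
# Two batch-normalized layers give two bricks and four
# (population statistic, batch statistic) pairs, matching the
# num_bricks=2 / num_updates=4 defaults above.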
def test_many_steps(self):
    x = tensor.tensor3('x')
    gi = tensor.tensor3('gi')
    mask = tensor.matrix('mask')
    h = self.reset_only.apply(x, gi, mask=mask)
    calc_h = theano.function(inputs=[x, gi, mask], outputs=[h])

    x_val = 0.1 * numpy.asarray(list(itertools.permutations(range(4))),
                                dtype=theano.config.floatX)
    x_val = numpy.ones((24, 4, 3),
                       dtype=theano.config.floatX) * x_val[..., None]
    ri_val = 0.3 - x_val
    zi_val = 2 * ri_val
    mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
    mask_val[12:24, 3] = 0

    h_val = numpy.zeros((25, 4, 3), dtype=theano.config.floatX)
    W = self.reset_only.state_to_state.get_value()
    Wz = self.reset_only.state_to_gates.get_value()[:, :3]
    Wr = self.reset_only.state_to_gates.get_value()[:, 3:]

    for i in range(1, 25):
        z_val = numpy.tanh(h_val[i - 1].dot(Wz) + zi_val[i - 1])
        r_val = numpy.tanh(h_val[i - 1].dot(Wr) + ri_val[i - 1])
        h_val[i] = numpy.tanh((r_val * h_val[i - 1]).dot(W) +
                              x_val[i - 1])
        h_val[i] = z_val * h_val[i] + (1 - z_val) * h_val[i - 1]
        h_val[i] = (mask_val[i - 1, :, None] * h_val[i] +
                    (1 - mask_val[i - 1, :, None]) * h_val[i - 1])
    h_val = h_val[1:]
    # TODO Figure out why this tolerance needs to be so big
    assert_allclose(
        h_val,
        calc_h(x_val,
               numpy.concatenate([zi_val, ri_val], axis=2),
               mask_val)[0],
        rtol=1e-04)

    # Also test that initial state is a parameter
    initial_state, = VariableFilter(roles=[INITIAL_STATE])(
        ComputationGraph(h))
    assert is_shared_variable(initial_state)
    assert initial_state.name == 'initial_state'
def get_snapshot(self, data):
    """Evaluate all role-carrying Theano variables on given data.

    Parameters
    ----------
    data : dict of (data source, data) pairs
        Data for input variables. The sources should match with the
        names of the input variables.

    Returns
    -------
    Dictionary of (variable, variable value on given data) pairs.

    """
    role_variables = [var for var in self.variables
                      if hasattr(var.tag, "roles") and
                      not is_shared_variable(var)]
    value_holders = [shared_like(var) for var in role_variables]
    function = self.get_theano_function(zip(value_holders,
                                            role_variables))
    function(*(data[input_.name] for input_ in self.inputs))
    return OrderedDict([(var, value_holder.get_value(borrow=True))
                        for var, value_holder in zip(role_variables,
                                                     value_holders)])
def do_many_steps(self, stack, skip_connections=False, low_memory=False):
    depth = self.depth

    # 24 steps
    #  4 batch examples
    # 12 dimensions per step
    x_val = (0.1 * numpy.asarray(
        list(itertools.islice(itertools.permutations(range(12)), 0, 24)),
        dtype=theano.config.floatX))
    x_val = numpy.ones((24, 4, 12),
                       dtype=theano.config.floatX) * x_val[:, None, :]
    # mask the last third of steps
    mask_val = numpy.ones((24, 4), dtype=theano.config.floatX)
    mask_val[12:24, 3] = 0
    # unroll all states and cells for all steps and also initial value
    h_val = numpy.zeros((depth, 25, 4, 3), dtype=theano.config.floatX)
    c_val = numpy.zeros((depth, 25, 4, 3), dtype=theano.config.floatX)

    # we will use same weights on all layers
    W_state2x_val = 2 * numpy.ones((3, 12), dtype=theano.config.floatX)
    W_state_val = 2 * numpy.ones((3, 12), dtype=theano.config.floatX)
    W_cell_to_in = 2 * numpy.ones((3,), dtype=theano.config.floatX)
    W_cell_to_out = 2 * numpy.ones((3,), dtype=theano.config.floatX)
    W_cell_to_forget = 2 * numpy.ones((3,), dtype=theano.config.floatX)

    kwargs = OrderedDict()
    for d in range(depth):
        if d > 0:
            suffix = RECURRENTSTACK_SEPARATOR + str(d)
        else:
            suffix = ''
        if d == 0 or skip_connections:
            kwargs['inputs' + suffix] = tensor.tensor3('inputs' + suffix)
            kwargs['inputs' + suffix].tag.test_value = x_val
    kwargs['mask'] = tensor.matrix('mask')
    kwargs['mask'].tag.test_value = mask_val
    results = stack.apply(iterate=True, low_memory=low_memory, **kwargs)
    calc_h = theano.function(inputs=list(kwargs.values()),
                             outputs=results)

    def sigmoid(x):
        return 1. / (1. + numpy.exp(-x))

    for i in range(1, 25):
        x_v = x_val[i - 1]
        h_vs = []
        c_vs = []
        for d in range(depth):
            h_v = h_val[d][i - 1, :, :]
            c_v = c_val[d][i - 1, :, :]

            activation = numpy.dot(h_v, W_state_val) + x_v
            if skip_connections and d > 0:
                activation += x_val[i - 1]

            i_t = sigmoid(activation[:, :3] + c_v * W_cell_to_in)
            f_t = sigmoid(activation[:, 3:6] + c_v * W_cell_to_forget)
            c_v1 = f_t * c_v + i_t * numpy.tanh(activation[:, 6:9])
            o_t = sigmoid(activation[:, 9:12] + c_v1 * W_cell_to_out)
            h_v1 = o_t * numpy.tanh(c_v1)
            h_v = (mask_val[i - 1, :, None] * h_v1 +
                   (1 - mask_val[i - 1, :, None]) * h_v)
            c_v = (mask_val[i - 1, :, None] * c_v1 +
                   (1 - mask_val[i - 1, :, None]) * c_v)

            # current layer output state transformed to input of next
            x_v = numpy.dot(h_v, W_state2x_val)

            h_vs.append(h_v)
            c_vs.append(c_v)

        for d in range(depth):
            h_val[d][i, :, :] = h_vs[d]
            c_val[d][i, :, :] = c_vs[d]

    args_val = [x_val] * (depth if skip_connections else 1) + [mask_val]
    res = calc_h(*args_val)
    for d in range(depth):
        assert_allclose(h_val[d][1:], res[d * 2], rtol=1e-4)
        assert_allclose(c_val[d][1:], res[d * 2 + 1], rtol=1e-4)

    # Also test that initial state is a parameter
    for h in results:
        initial_states = VariableFilter(roles=[INITIAL_STATE])(
            ComputationGraph(h))
        assert all(is_shared_variable(initial_state)
                   for initial_state in initial_states)
def recurrent_apply(brick, application, application_call,
                    *args, **kwargs):
    """Iterates a transition function.

    Parameters
    ----------
    iterate : bool
        If ``True`` iteration is made. By default ``True``.
    reverse : bool
        If ``True``, the sequences are processed in backward
        direction. ``False`` by default.
    return_initial_states : bool
        If ``True``, initial states are included in the returned
        state tensors. ``False`` by default.

    .. todo::

        * Handle `updates` returned by the :func:`theano.scan`
          routine.
        * ``kwargs`` has a random order; check if this is a problem.

    """
    # Note: `application_function`, `arg_names`, `logger` and
    # `unknown_scan_input` are free variables supplied by the
    # enclosing decorator scope (not shown in this snippet).

    # Extract arguments related to iteration and immediately relay the
    # call to the wrapped function if `iterate=False`
    iterate = kwargs.pop('iterate', True)
    if not iterate:
        return application_function(brick, *args, **kwargs)
    reverse = kwargs.pop('reverse', False)
    return_initial_states = kwargs.pop('return_initial_states', False)

    # Push everything to kwargs
    for arg, arg_name in zip(args, arg_names):
        kwargs[arg_name] = arg

    # Make sure that all arguments for scan are tensor variables
    scan_arguments = (application.sequences + application.states +
                      application.contexts)
    for arg in scan_arguments:
        if arg in kwargs:
            if kwargs[arg] is None:
                del kwargs[arg]
            else:
                kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

    # Check which sequences and contexts were provided
    sequences_given = dict_subset(kwargs, application.sequences,
                                  must_have=False)
    contexts_given = dict_subset(kwargs, application.contexts,
                                 must_have=False)

    # Determine number of steps and batch size.
    if len(sequences_given):
        # TODO Assumes 1 time dim!
        shape = list(sequences_given.values())[0].shape
        if not iterate:
            # Unreachable: `iterate` is always True past the early
            # return above.
            batch_size = shape[0]
        else:
            n_steps = shape[0]
            batch_size = shape[1]
    else:
        # TODO Raise error if n_steps and batch_size not found?
        n_steps = kwargs.pop('n_steps')
        batch_size = kwargs.pop('batch_size')

    # Handle the rest kwargs
    rest_kwargs = {key: value for key, value in kwargs.items()
                   if key not in scan_arguments}
    for value in rest_kwargs.values():
        if (isinstance(value, Variable) and
                not is_shared_variable(value)):
            logger.warning("unknown input {}".format(value) +
                           unknown_scan_input)

    # Ensure that all initial states are available.
    for state_name in application.states:
        dim = brick.get_dim(state_name)
        if state_name in kwargs:
            if isinstance(kwargs[state_name], NdarrayInitialization):
                kwargs[state_name] = tensor.alloc(
                    kwargs[state_name].generate(brick.rng, (1, dim)),
                    batch_size, dim)
            elif isinstance(kwargs[state_name], Application):
                kwargs[state_name] = (
                    kwargs[state_name](state_name, batch_size,
                                       *args, **kwargs))
        else:
            # TODO init_func returns 2D-tensor, fails for iterate=False
            kwargs[state_name] = (
                brick.initial_state(state_name, batch_size,
                                    *args, **kwargs))
            assert kwargs[state_name]
    states_given = dict_subset(kwargs, application.states)

    # Theano issue 1772
    for name, state in states_given.items():
        states_given[name] = tensor.unbroadcast(state,
                                                *range(state.ndim))

    def scan_function(*args):
        args = list(args)
        arg_names = (list(sequences_given) +
                     [output for output in application.outputs
                      if output in application.states] +
                     list(contexts_given))
        kwargs = dict(equizip(arg_names, args))
        kwargs.update(rest_kwargs)
        outputs = application(iterate=False, **kwargs)
        # We want to save the computation graph returned by the
        # `application_function` when it is called inside the
        # `theano.scan`.
        application_call.inner_inputs = args
        application_call.inner_outputs = pack(outputs)
        return outputs

    outputs_info = [
        states_given[name] if name in application.states
        else None
        for name in application.outputs]
    result, updates = theano.scan(
        scan_function, sequences=list(sequences_given.values()),
        outputs_info=outputs_info,
        non_sequences=list(contexts_given.values()),
        n_steps=n_steps,
        go_backwards=reverse)
    result = pack(result)
    if return_initial_states:
        # Undo Subtensor
        for i in range(len(states_given)):
            assert isinstance(result[i].owner.op,
                              tensor.subtensor.Subtensor)
            result[i] = result[i].owner.inputs[0]
    if updates:
        application_call.updates = dict_union(application_call.updates,
                                              updates)
    return result
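# Hedged usage sketch of the iteration-control kwargs handled above
# (the brick and dimensions are illustrative):
from theano import tensor
from blocks.bricks import Tanh
from blocks.bricks.recurrent import SimpleRecurrent

rnn = SimpleRecurrent(dim=3, activation=Tanh())
x = tensor.tensor3('x')
# Process the sequence back-to-front and prepend the initial state
# to the returned state tensor.
h = rnn.apply(inputs=x, reverse=True, return_initial_states=True)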
def recurrent_apply(brick, application, application_call,
                    *args, **kwargs):
    """Iterates a transition function.

    Parameters
    ----------
    iterate : bool
        If ``True`` iteration is made. By default ``True``.
    reverse : bool
        If ``True``, the sequences are processed in backward
        direction. ``False`` by default.
    return_initial_states : bool
        If ``True``, initial states are included in the returned
        state tensors. ``False`` by default.

    """
    # Note: `application_function`, `arg_names`, `logger` and
    # `unknown_scan_input` are free variables supplied by the
    # enclosing decorator scope (not shown in this snippet).

    # Extract arguments related to iteration and immediately relay the
    # call to the wrapped function if `iterate=False`
    iterate = kwargs.pop('iterate', True)
    if not iterate:
        return application_function(brick, *args, **kwargs)
    reverse = kwargs.pop('reverse', False)
    return_initial_states = kwargs.pop('return_initial_states', False)

    # Push everything to kwargs
    for arg, arg_name in zip(args, arg_names):
        kwargs[arg_name] = arg

    # Make sure that all arguments for scan are tensor variables
    scan_arguments = (application.sequences + application.states +
                      application.contexts)
    for arg in scan_arguments:
        if arg in kwargs:
            if kwargs[arg] is None:
                del kwargs[arg]
            else:
                kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

    # Check which sequences and contexts were provided
    sequences_given = dict_subset(kwargs, application.sequences,
                                  must_have=False)
    contexts_given = dict_subset(kwargs, application.contexts,
                                 must_have=False)

    # Determine number of steps and batch size.
    if len(sequences_given):
        # TODO Assumes 1 time dim!
        shape = list(sequences_given.values())[0].shape
        n_steps = shape[0]
        batch_size = shape[1]
    else:
        # TODO Raise error if n_steps and batch_size not found?
        n_steps = kwargs.pop('n_steps')
        batch_size = kwargs.pop('batch_size')

    # Handle the rest kwargs
    rest_kwargs = {key: value for key, value in kwargs.items()
                   if key not in scan_arguments}
    for value in rest_kwargs.values():
        if (isinstance(value, Variable) and
                not is_shared_variable(value)):
            logger.warning("unknown input {}".format(value) +
                           unknown_scan_input)

    # Ensure that all initial states are available.
    initial_states = brick.initial_states(batch_size, as_dict=True,
                                          *args, **kwargs)
    for state_name in application.states:
        dim = brick.get_dim(state_name)
        if state_name in kwargs:
            if isinstance(kwargs[state_name], NdarrayInitialization):
                kwargs[state_name] = tensor.alloc(
                    kwargs[state_name].generate(brick.rng, (1, dim)),
                    batch_size, dim)
            elif isinstance(kwargs[state_name], Application):
                kwargs[state_name] = (
                    kwargs[state_name](state_name, batch_size,
                                       *args, **kwargs))
        else:
            try:
                kwargs[state_name] = initial_states[state_name]
            except KeyError:
                raise KeyError(
                    "no initial state for '{}' of the brick {}".format(
                        state_name, brick.name))
    states_given = dict_subset(kwargs, application.states)

    # Theano issue 1772
    for name, state in states_given.items():
        states_given[name] = tensor.unbroadcast(state,
                                                *range(state.ndim))

    def scan_function(*args):
        args = list(args)
        arg_names = (list(sequences_given) +
                     [output for output in application.outputs
                      if output in application.states] +
                     list(contexts_given))
        kwargs = dict(equizip(arg_names, args))
        kwargs.update(rest_kwargs)
        outputs = application(iterate=False, **kwargs)
        # We want to save the computation graph returned by the
        # `application_function` when it is called inside the
        # `theano.scan`.
        application_call.inner_inputs = args
        application_call.inner_outputs = pack(outputs)
        return outputs

    outputs_info = [
        states_given[name] if name in application.states
        else None
        for name in application.outputs]
    result, updates = theano.scan(
        scan_function, sequences=list(sequences_given.values()),
        outputs_info=outputs_info,
        non_sequences=list(contexts_given.values()),
        n_steps=n_steps,
        go_backwards=reverse,
        name='{}_{}_scan'.format(
            brick.name, application.application_name))
    result = pack(result)
    if return_initial_states:
        # Undo Subtensor
        for i in range(len(states_given)):
            assert isinstance(result[i].owner.op,
                              tensor.subtensor.Subtensor)
            result[i] = result[i].owner.inputs[0]
    if updates:
        application_call.updates = dict_union(application_call.updates,
                                              updates)
    return result
def shared_variables(self):
    """Shared variables in the computation graph."""
    return [var for var in self.variables if is_shared_variable(var)]
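# Minimal sketch of the `is_shared_variable` predicate used above:
# variables created with `theano.shared` qualify, symbolic inputs
# do not.
import numpy
import theano
from theano import tensor
from blocks.utils import is_shared_variable

w = theano.shared(numpy.zeros(3), name='w')
assert is_shared_variable(w)
assert not is_shared_variable(tensor.vector('v'))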
def recurrent_apply(brick, application, application_call,
                    *args, **kwargs):
    """Iterates a transition function.

    Parameters
    ----------
    iterate : bool
        If ``True`` iteration is made. By default ``True``.
    reverse : bool
        If ``True``, the sequences are processed in backward
        direction. ``False`` by default.
    return_initial_states : bool
        If ``True``, initial states are included in the returned
        state tensors. ``False`` by default.

    .. todo::

        * Handle `updates` returned by the :func:`theano.scan`
          routine.
        * ``kwargs`` has a random order; check if this is a problem.

    """
    # Note: `application_function` and `arg_names` are free variables
    # supplied by the enclosing decorator scope (not shown in this
    # snippet).

    # Extract arguments related to iteration.
    iterate = kwargs.pop('iterate', True)
    reverse = kwargs.pop('reverse', False)
    return_initial_states = kwargs.pop('return_initial_states', False)

    # Push everything to kwargs
    for arg, arg_name in zip(args, arg_names):
        kwargs[arg_name] = arg

    # Separate sequences, states and contexts
    scan_arguments = (application.sequences + application.states +
                      application.contexts)

    # Check what is given and what is not
    def only_given(arg_names):
        return OrderedDict((arg_name, kwargs[arg_name])
                           for arg_name in arg_names
                           if kwargs.get(arg_name))

    sequences_given = only_given(application.sequences)
    contexts_given = only_given(application.contexts)

    # TODO Assumes 1 time dim!
    if len(sequences_given):
        shape = list(sequences_given.values())[0].shape
        if not iterate:
            batch_size = shape[0]
        else:
            n_steps = shape[0]
            batch_size = shape[1]
    else:
        # TODO Raise error if n_steps and batch_size not found?
        n_steps = kwargs.pop('n_steps')
        batch_size = kwargs.pop('batch_size')

    # Handle the rest kwargs
    rest_kwargs = {key: value for key, value in kwargs.items()
                   if key not in scan_arguments}
    for value in rest_kwargs.values():
        if (isinstance(value, Variable) and
                not is_shared_variable(value)):
            warnings.warn(
                'Your function uses a non-shared variable other than'
                ' those given by scan explicitly. That can'
                ' significantly slow down `tensor.grad` call.'
                ' Did you forget to declare it in `contexts`?')

    # Ensure that all initial states are available.
    for state_name in application.states:
        dim = brick.get_dim(state_name)
        if state_name in kwargs:
            if isinstance(kwargs[state_name], NdarrayInitialization):
                kwargs[state_name] = tensor.alloc(
                    kwargs[state_name].generate(brick.rng, (1, dim)),
                    batch_size, dim)
            elif isinstance(kwargs[state_name], Application):
                kwargs[state_name] = \
                    kwargs[state_name](state_name, batch_size,
                                       *args, **kwargs)
        else:
            # TODO init_func returns 2D-tensor, fails for iterate=False
            kwargs[state_name] = \
                brick.initial_state(state_name, batch_size,
                                    *args, **kwargs)
            assert kwargs[state_name]
    states_given = only_given(application.states)
    assert len(states_given) == len(application.states)

    # Theano issue 1772
    for name, state in states_given.items():
        states_given[name] = tensor.unbroadcast(state,
                                                *range(state.ndim))

    # Apply methods
    if not iterate:
        return application_function(brick, **kwargs)

    def scan_function(*args):
        args = list(args)
        arg_names = (list(sequences_given) + list(states_given) +
                     list(contexts_given))
        kwargs = dict(zip(arg_names, args))
        kwargs.update(rest_kwargs)
        return application_function(brick, **kwargs)

    outputs_info = (list(states_given.values()) +
                    [None] * (len(application.outputs) -
                              len(application.states)))
    result, updates = theano.scan(
        scan_function, sequences=list(sequences_given.values()),
        outputs_info=outputs_info,
        non_sequences=list(contexts_given.values()),
        n_steps=n_steps,
        go_backwards=reverse)
    result = pack(result)
    if return_initial_states:
        # Undo Subtensor
        for i in range(len(states_given)):
            assert isinstance(result[i].owner.op,
                              tensor.subtensor.Subtensor)
            result[i] = result[i].owner.inputs[0]
    if updates:
        application_call.updates = dict_union(application_call.updates,
                                              updates)
    return result