def get_monitoring_channels(self, data):
    """
    Monitoring channels for the generator/discriminator pair.

    Collects channels from the generator, from the discriminator
    evaluated on real data and on generated samples, and (optionally)
    from the inference network, each under a distinguishing prefix.

    Parameters
    ----------
    data : batch of real examples, passed through to the submodels

    Returns
    -------
    OrderedDict mapping prefixed channel names to channel values
    """
    rval = OrderedDict()

    g_ch = self.generator.get_monitoring_channels(data)
    # Discriminator evaluated on the real data ...
    d_ch = self.discriminator.get_monitoring_channels((data, None))
    # ... and on freshly generated samples.
    samples, _, conditional_data, _ = self.generator.sample_and_noise(100)
    d_samp_ch = self.discriminator.get_monitoring_channels(
        ((samples, conditional_data), None))

    i_ch = OrderedDict()
    if self.inferer is not None:
        batch_size = self.inference_monitoring_batch_size
        sample, noise, conditional_data, _ = \
            self.generator.sample_and_noise(batch_size)
        i_ch.update(self.inferer.get_monitoring_channels(
            ((sample, conditional_data), noise)))

    if self.monitor_generator:
        for key in g_ch:
            rval["gen_" + key] = g_ch[key]
    if self.monitor_discriminator:
        # BUGFIX: the original swapped these two dictionaries, reporting
        # the on-samples channels under "dis_on_data_" and vice versa.
        for key in d_ch:
            rval["dis_on_data_" + key] = d_ch[key]
        for key in d_ch:
            rval["dis_on_samp_" + key] = d_samp_ch[key]
    if self.monitor_inference:
        for key in i_ch:
            rval["inf_" + key] = i_ch[key]
    return rval
def __init__(self, valid=None, invalid=None, valid_equivalent=None):
    '''
    Check if variables can be expressed without using variables in
    invalid.

    valid_equivalent provides a dictionary mapping some invalid
    variables to valid ones that can be used instead.
    '''
    valid = [] if valid is None else valid
    invalid = [] if invalid is None else invalid
    if valid_equivalent is None:
        valid_equivalent = OrderedDict()

    # Variables that may appear in the graph computing the outputs.
    self.valid = set(valid)
    # Variables that must NOT appear in that graph.
    self.invalid = set(invalid)
    # Substitution map: invalid variable -> equivalent valid variable.
    self.valid_equivalent = valid_equivalent.copy()
    # Every replacement target is valid; every replaced variable invalid.
    self.valid.update(valid_equivalent.values())
    self.invalid.update(valid_equivalent.keys())
def get_monitoring_channels(self, data):
    """Collect monitoring channels from the wrapped MLP, best-effort."""
    channels = OrderedDict()
    try:
        channels.update(self.mlp.get_monitoring_channels(data))
    except Exception:
        # Best-effort: monitoring must never abort training.
        warnings.warn("something went wrong with compressor.mlp's "
                      "monitoring channels")
    return channels
def orderings(self):
    """
    Return dict d s.t. d[node] is a list of nodes that must be evaluated
    before node itself can be evaluated.

    This is used primarily by the destroy_handler feature to ensure that
    all clients of any destroyed inputs have already computed their
    outputs.

    Notes
    -----
    This only calls the orderings() fct on all features. It does not
    take care of computing dependencies by itself.
    """
    ords = OrderedDict()
    assert isinstance(self._features, list)
    for feature in self._features:
        if hasattr(feature, 'orderings'):
            orderings = feature.orderings(self)
            # Deterministic iteration order is required so that toposort
            # is reproducible across runs.
            if not isinstance(orderings, OrderedDict):
                raise TypeError("Non-deterministic return value from " +
                                str(feature.orderings) +
                                ". Nondeterministic object is " +
                                str(orderings))
            for node, prereqs in iteritems(orderings):
                if not isinstance(prereqs, (list, OrderedSet)):
                    raise TypeError(
                        "prereqs must be a type with a "
                        "deterministic iteration order, or toposort "
                        " will be non-deterministic.")
                # Merge prerequisites contributed by each feature.
                ords.setdefault(node, []).extend(prereqs)
    # eliminate duplicate prereqs
    for (node, prereqs) in iteritems(ords):
        ords[node] = list(OrderedSet(prereqs))
    return ords
def get_layer_monitoring_channels(self, state_below=None, state=None,
                                  targets=None):
    """
    Monitoring channels for a conv-style layer: per-unit statistics of
    the 4D `state` plus kernel-norm statistics.

    NOTE(review): assumes `state` is a 4D tensor and W a 4D kernel
    tensor -- confirm the axis conventions against the layer's spaces.
    """
    W, = self.transformer.get_params()
    assert W.ndim == 4
    sq_W = T.sqr(W)
    # Norm of each kernel (reduce over all axes but the last).
    row_norms = T.sqrt(sq_W.sum(axis=(0, 1, 2)))
    P = state
    rval = OrderedDict()
    vars_and_prefixes = [(P, '')]
    for var, prefix in vars_and_prefixes:
        if not hasattr(var, 'ndim') or var.ndim != 4:
            # Debugging aid (Python 2 print statements): dump the
            # offending object before failing.
            print "expected 4D tensor, got "
            print var
            print type(var)
            if isinstance(var, tuple):
                print "tuple length: ", len(var)
            assert False
        v_max = var.max(axis=3)
        v_min = var.min(axis=3)
        v_mean = var.mean(axis=3)
        v_range = v_max - v_min
        v_max = v_max.max(axis=(1, 2))
        v_min = v_min.min(axis=(1, 2))
        # max_x.mean_u is "the mean over *u*nits of the max over
        # e*x*amples". The x and u are included in the name because
        # otherwise its hard to remember which axis is which when
        # reading the monitor. I use inner.outer rather than
        # outer_of_inner or something like that because I want mean_x.*
        # to appear next to each other in the alphabetical list, as
        # these are commonly plotted together.
        for key, val in [('max_x.max_u', v_max.max()),
                         ('max_x.mean_u', v_max.mean()),
                         ('max_x.min_u', v_max.min()),
                         ('min_x.max_u', v_min.max()),
                         ('min_x.mean_u', v_min.mean()),
                         ('min_x.min_u', v_min.min()),
                         ('range_x.max_u', v_range.max()),
                         ('range_x.mean_u', v_range.mean()),
                         ('range_x.min_u', v_range.min()),
                         ('mean_x.max_u', v_mean.max()),
                         ('mean_x.mean_u', v_mean.mean()),
                         ('mean_x.min_u', v_mean.min())]:
            rval[prefix + key] = val
    rval.update(OrderedDict([('kernel_norms_min', row_norms.min()),
                             ('kernel_norms_mean', row_norms.mean()),
                             ('kernel_norms_max', row_norms.max()),
                             ]))
    return rval
def on_attach(self, fgraph): """ When attaching to a new fgraph, check that 1) This DestroyHandler wasn't already attached to some fgraph (its data structures are only set up to serve one). 2) The FunctionGraph doesn't already have a DestroyHandler. This would result in it validating everything twice, causing compilation to be slower. Give the FunctionGraph instance: 1) A new method "destroyers(var)" TODO: what does this do exactly? 2) A new attribute, "destroy_handler" TODO: WRITEME: what does this do besides the checks? """ # Do the checking # already_there = False if self.fgraph is fgraph: already_there = True if self.fgraph is not None: raise Exception( "A DestroyHandler instance can only serve one" " FunctionGraph. (Matthew 6:24)") for attr in ('destroyers', 'destroy_handler'): if hasattr(fgraph, attr): already_there = True if already_there: # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment raise toolbox.AlreadyThere( "DestroyHandler feature is already present" " or in conflict with another plugin.") # Annotate the FunctionGraph # self.unpickle(fgraph) fgraph.destroy_handler = self self.fgraph = fgraph self.destroyers = OrderedSet() # set of Apply instances with non-null destroy_map self.view_i = OrderedDict() # variable -> variable used in calculation self.view_o = OrderedDict() # variable -> set of variables that use this one as a direct input # clients: how many times does an apply use a given variable self.clients = OrderedDict() # variable -> apply -> ninputs self.stale_droot = True self.debug_all_apps = OrderedSet() if self.do_imports_on_attach: toolbox.Bookkeeper.on_attach(self, fgraph)
def __init__(self, *axis):
    """
    Parameters
    ----------
    *axis : (int, bool) pairs
        Each pair names an axis and its new broadcastable flag.
    """
    # Sort them to make sure we merge all possible case.
    items = sorted(axis)
    self.axis = OrderedDict(items)
    for axis, broad in iteritems(self.axis):
        # Raise instead of assert so validation survives `python -O`,
        # and so callers get a TypeError, matching the other
        # Rebroadcast.__init__ definition in this file.
        if not isinstance(axis, (numpy.integer, int)):
            raise TypeError("Rebroadcast needs integer axes. "
                            "Got %s" % (axis,))
        if not isinstance(broad, bool):
            raise TypeError("Rebroadcast needs bool for new broadcast "
                            "pattern. Got %s" % (broad,))
def get_layer_monitoring_channels(self, state_below=None, state=None,
                                  targets=None):
    """Kernel-norm statistics plus the nonlinearity's own channels."""
    W, = self.transformer.get_params()
    assert W.ndim == 5
    # Per-output-kernel L2 norms (reduce over every axis but the first).
    norms = T.sqrt(T.sqr(W).sum(axis=(1, 2, 3, 4)))
    channels = OrderedDict()
    channels['kernel_norms_min'] = norms.min()
    channels['kernel_norms_mean'] = norms.mean()
    channels['kernel_norms_max'] = norms.max()
    channels.update(self.nonlin.get_monitoring_channels_from_state(
        state, targets, cost_fn=self.cost))
    return channels
def __init__(self, valid=None, invalid=None, valid_equivalent=None):
    """
    Record which variables may (valid) or may not (invalid) appear in a
    graph, with a map from invalid variables to valid stand-ins.
    """
    if valid is None:
        valid = []
    if invalid is None:
        invalid = []
    if valid_equivalent is None:
        valid_equivalent = OrderedDict()
    self.valid = set(valid)        # allowed in the output graph
    self.invalid = set(invalid)    # forbidden in the output graph
    # invalid variable -> valid replacement
    self.valid_equivalent = valid_equivalent.copy()
    self.valid.update(list(valid_equivalent.values()))
    self.invalid.update(list(valid_equivalent.keys()))
def get_gradients(self, model, data, **kwargs):
    """
    Gradients and updates for alternating compressor/discriminator
    training.

    Returns
    -------
    rval : OrderedDict
        param -> gradient expression, gated by the now_train_* signals.
    updates : OrderedDict
        Updates advancing the training clock and recomputing the
        now_train_* gating signals.
    """
    space, sources = self.get_data_specs(model)
    space.validate(data)
    assert isinstance(model, CompressAdversaryPair)
    g = model.compressor
    d = model.discriminator

    # get raw gradients for d and g objectives...
    d_obj, g_obj = self.get_objectives(model, data)
    g_params = g.get_params()
    d_params = d.get_params()
    # The two models must not share any parameters.
    for param in g_params:
        assert param not in d_params
    for param in d_params:
        assert param not in g_params
    d_grads = T.grad(d_obj, d_params)
    g_grads = T.grad(g_obj, g_params)

    # if self.scale_grads:
    #     S_grad = T.grad(g_obj, S)
    #     scale = T.maximum(1., self.target_scale /
    #                       T.sqrt(T.sqr(S_grad).sum()))
    #     g_grads = [g_grad * scale for g_grad in g_grads]

    # adjust raw gradients with control signals
    rval = OrderedDict()
    # Constant zero gradients for models that are never trained.
    zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))
    if self.ever_train_discriminator:
        # now_train_discriminator is a {0,1}-valued gating signal.
        rval.update(OrderedDict(safe_zip(d_params,
            [self.now_train_discriminator * dg for dg in d_grads])))
    else:
        rval.update(OrderedDict(zip(d_params, zeros)))
    if self.ever_train_compressor:
        rval.update(OrderedDict(safe_zip(g_params,
            [self.now_train_compressor * gg for gg in g_grads])))
    else:
        rval.update(OrderedDict(zip(g_params, zeros)))

    # update control signals using the updates return functionality
    updates = OrderedDict()
    # first, the clock: wraps back to 1 after a full cycle of
    # discriminator + joint + compressor steps.
    self.future_train_clock = T.switch(T.ge(self.train_clock,
        self.discriminator_steps + self.joint_steps + self.compressor_steps),
        1., self.train_clock + 1.)
    updates[self.train_clock] = self.future_train_clock
    # then the control signals
    updates[self.now_train_discriminator] = T.switch(T.le(
        self.future_train_clock,
        self.discriminator_steps + self.joint_steps), 1., 0.)
    updates[self.now_train_compressor] = T.switch(T.gt(
        self.future_train_clock, self.discriminator_steps), 1., 0.)
    return rval, updates
def get_monitoring_channels(self, data):
    """
    Channels from the generator's MLP fed fresh noise; optionally adds
    log-likelihood channels.
    """
    # Use the data batch size when available, else a default of 100.
    m = 100 if data is None else data.shape[0]
    n = self.mlp.get_input_space().get_total_dimension()
    noise = self.get_noise((m, n))
    rval = OrderedDict()
    try:
        rval.update(self.mlp.get_monitoring_channels((noise, None)))
    except Exception:
        # Monitoring is best-effort; never let it kill training.
        warnings.warn("something went wrong with generator.mlp's "
                      "monitoring channels")
    if self.monitor_ll:
        rval['ll'] = T.cast(self.ll(data, self.ll_n_samples,
                                    self.ll_sigma),
                            theano.config.floatX).mean()
        rval['nll'] = -rval['ll']
    return rval
def orderings(self, function_graph):
    """
    Hook called by toposort: map each node to the list of nodes that
    must be computed before it. The base implementation imposes no
    extra constraints.

    Raising from this hook may leave the graph in a broken state for
    all intents and purposes.
    """
    return OrderedDict()
def __init__(self, do_imports_on_attach=True):
    """Set up empty view-chain bookkeeping; fgraph is attached later."""
    self.fgraph = None
    self.do_imports_on_attach = do_imports_on_attach
    # var -> its "foundation", the deepest ancestor in the view chain.
    # TODO: change name to var_to_vroot.
    self.droot = OrderedDict()
    # var -> all direct/indirect views of it (itself included);
    # essentially the inverse of droot.
    # TODO: do all variables appear in this dict, or only foundations?
    # TODO: do only destroyed variables go in here? one old docstring
    # said so. TODO: rename to x_to_views after reverse engineering x.
    self.impact = OrderedDict()
    # if a var is destroyed, maps droot[var] -> the Apply node that
    # destroyed var. TODO: rename to vroot_to_destroyer.
    self.root_destroyer = OrderedDict()
def get_monitoring_channels(self, data):
    """
    Channels from the conditional generator's MLP on freshly sampled
    noise; optionally adds log-likelihood channels.
    """
    if data is None:
        # No batch given: sample 100 conditioning vectors ourselves.
        m = 100
        conditional_data = self.condition_distribution.sample(m)
    else:
        _, conditional_data = data
        m = conditional_data.shape[0]
    noise = self.get_noise((m, self.noise_dim))
    channels = OrderedDict()
    sampled_data = (noise, conditional_data)
    try:
        channels.update(self.mlp.get_monitoring_channels(
            (sampled_data, None)))
    except Exception:
        # Monitoring is best-effort; never let it kill training.
        warnings.warn("something went wrong with generator.mlp's "
                      "monitoring channels")
    if self.monitor_ll:
        channels["ll"] = T.cast(self.ll(data, self.ll_n_samples,
                                        self.ll_sigma),
                                theano.config.floatX).mean()
        channels["nll"] = -channels["ll"]
    return channels
def __init__(self, *axis):
    """
    Store the (axis, broadcastable-flag) pairs, sorted so equivalent
    ops can be merged, validating the types of both members.
    """
    self.axis = OrderedDict(sorted(axis))
    for ax, broad in iteritems(self.axis):
        if not isinstance(ax, (numpy.integer, integer_types)):
            raise TypeError("Rebroadcast needs integer axes. "
                            "Got {}".format(ax))
        if not isinstance(broad, (numpy.bool_, bool)):
            raise TypeError("Rebroadcast needs bool for new broadcast "
                            "pattern. Got {}".format(broad))
def test_known_grads():
    # Tests that the grad method with no known_grads
    # matches what happens if you put its own known_grads
    # in for each variable
    full_range = theano.tensor.arange(10)
    x = theano.tensor.scalar('x')
    t = theano.tensor.iscalar('t')
    ft = full_range[t]
    ft.name = 'ft'
    coeffs = theano.tensor.vector('c')
    ct = coeffs[t]
    ct.name = 'ct'
    p = x ** ft
    p.name = 'p'
    y = ct * p
    y.name = 'y'
    cost = theano.tensor.sqr(y)
    cost.name = 'cost'
    # Each "layer" is a cut through the graph: grads w.r.t. a layer can
    # be fed back as known_grads to recover the full input gradients.
    layers = [[cost], [y], [ct, p], [ct, x, ft],
              [coeffs, t, full_range, x]]
    inputs = [coeffs, t, x]
    rng = np.random.RandomState([2012, 11, 15])
    values = [rng.randn(10), rng.randint(10), rng.randn()]
    # Cast each value to its input's dtype.
    values = [np.cast[ipt.dtype](value)
              for ipt, value in zip(inputs, values)]
    # Reference gradients from a single ordinary grad call.
    true_grads = theano.tensor.grad(cost, inputs,
                                    disconnected_inputs='ignore')
    true_grads = theano.function(inputs, true_grads)
    true_grads = true_grads(*values)
    for layer in layers:
        first = theano.tensor.grad(cost, layer,
                                   disconnected_inputs='ignore')
        known = OrderedDict(izip(layer, first))
        full = theano.tensor.grad(cost=None, known_grads=known,
                                  wrt=inputs,
                                  disconnected_inputs='ignore')
        full = theano.function(inputs, full)
        full = full(*values)
        assert len(true_grads) == len(full)
        for a, b, var in zip(true_grads, full, inputs):
            if not np.allclose(a, b):
                # Dump diagnostics before failing.
                print('Failure')
                print(a)
                print(b)
                print(var)
                print(layer)
                for v in known:
                    print(v, ':',
                          theano.function(inputs, known[v])(*values))
                assert False
def run(replay, log=None):
    """
    Deterministic-compilation check: build a tiny function under
    RecordMode and either record (replay=False) or verify (replay=True)
    the compilation log.

    Parameters
    ----------
    replay : bool
        False records a fresh log; True replays the given `log`.
    log : str, optional
        Previously recorded log text (used when replay=True).

    Returns
    -------
    str or None
        The recorded log text when replay=False, else None.
    """
    if not replay:
        log = StringIO()
    else:
        log = StringIO(log)
    record = Record(replay=replay, file_object=log)
    disturb_mem.disturb_mem()
    mode = RecordMode(record=record)
    b = sharedX(np.zeros((2, )), name='b')
    channels = OrderedDict()
    disturb_mem.disturb_mem()
    v_max = b.max(axis=0)
    v_min = b.min(axis=0)
    v_range = v_max - v_min
    updates = []
    for i, val in enumerate([v_max.max(), v_max.min(), v_range.max()]):
        disturb_mem.disturb_mem()
        s = sharedX(0., name='s_' + str(i))
        updates.append((s, val))
    # Strip names from intermediate variables so the recorded log only
    # depends on the named leaves.
    for var in theano.gof.graph.ancestors(update for _, update in updates):
        # BUGFIX: the original used `var.name is not 'b'`, an identity
        # comparison against a string literal (implementation-dependent;
        # SyntaxWarning on modern CPython). Use equality.
        if var.name is not None and var.name != 'b':
            if var.name[0] != 's' or len(var.name) != 2:
                var.name = None
    # NOTE(review): `channels` is always empty here, so this loop is
    # dead; it also reuses `s` from the loop above. Confirm the intent
    # if channels is ever populated.
    for key in channels:
        updates.append((s, channels[key]))
    f = theano.function([], mode=mode, updates=updates,
                        on_unused_input='ignore', name='f')
    for output in f.maker.fgraph.outputs:
        mode.record.handle_line(var_descriptor(output) + '\n')
    disturb_mem.disturb_mem()
    f()
    mode.record.f.flush()
    if not replay:
        return log.getvalue()
def get_monitoring_channels(self, data):
    """
    Forward monitoring channels from the generator's MLP, evaluated on
    freshly drawn noise; optionally append log-likelihood channels.
    """
    batch = data.shape[0] if data is not None else 100
    dim = self.mlp.get_input_space().get_total_dimension()
    noise = self.get_noise((batch, dim))
    channels = OrderedDict()
    try:
        channels.update(self.mlp.get_monitoring_channels((noise, None)))
    except Exception:
        # Monitoring is best-effort; never let it kill training.
        warnings.warn(
            "something went wrong with generator.mlp's monitoring channels"
        )
    if self.monitor_ll:
        ll = T.cast(self.ll(data, self.ll_n_samples, self.ll_sigma),
                    theano.config.floatX).mean()
        channels['ll'] = ll
        channels['nll'] = -ll
    return channels
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
    """app.inputs[i] changed from old_r to new_r """
    if app == 'output':
        # app == 'output' is special key that means FunctionGraph is
        # redefining which nodes are being considered 'outputs' of the
        # graph.
        pass
    else:
        if app not in self.debug_all_apps:
            raise ProtocolError("change without import")
        # UPDATE self.clients
        # Drop one reference of old_r by app; delete the entry when the
        # count hits zero, then credit new_r with one more reference.
        self.clients[old_r][app] -= 1
        if self.clients[old_r][app] == 0:
            del self.clients[old_r][app]
        self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0)
        self.clients[new_r][app] += 1
        # UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in iteritems(
                getattr(app.op, 'view_map', OrderedDict())):
            if len(i_idx_list) > 1:
                # destroying this output invalidates multiple inputs
                raise NotImplementedError()
            i_idx = i_idx_list[0]
            output = app.outputs[o_idx]
            if i_idx == i:
                if app.inputs[i_idx] is not new_r:
                    raise ProtocolError("wrong new_r on change")
                # Rewire the view chain from old_r to new_r.
                self.view_i[output] = new_r
                self.view_o[old_r].remove(output)
                if not self.view_o[old_r]:
                    del self.view_o[old_r]
                self.view_o.setdefault(new_r, OrderedSet()).add(output)
    # Cached droot/impact data is now out of date.
    self.stale_droot = True
def get_lr_scalers(self):
    """
    Aggregate learning-rate scalers from the visible and hidden layers.

    Returns
    -------
    OrderedDict
        Maps each parameter (owned by exactly one layer) to a float
        scale factor.
    """
    scalers = OrderedDict()
    params = self.get_params()
    for layer in self.hidden_layers + [self.visible_layer]:
        contrib = layer.get_lr_scalers()
        for key in contrib:
            # No two layers can contend to scale a parameter, and only
            # actual model parameters may be scaled.
            assert key not in scalers
            assert key in params
        scalers.update(contrib)
    assert all(isinstance(val, float) for val in scalers.values())
    return scalers
def test_hash_from_dict():
    """hash_from_dict yields distinct hashes for distinct dicts and is
    insensitive to dict ordering class and list/tuple values."""
    base = [{}, {0: 0}, {0: 1}, {1: 0}, {1: 1}, {0: (0,)}, {0: [1]},
            {0: (0, 1)}, {0: [1, 0]}]
    cases = base + [OrderedDict(d) for d in base]
    seen = []
    for d in cases:
        h = hash_from_dict(d)
        assert h not in seen
        seen.append(h)
    # List are not hashable. So they are transformed into tuple.
    assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
def on_prune(self, fgraph, app, reason):
    """
    Remove Apply instance from set which must be computed.
    """
    if app not in self.debug_all_apps:
        raise ProtocolError("prune without import")
    self.debug_all_apps.remove(app)
    # UPDATE self.clients
    # OrderedSet visits each distinct input once even when repeated.
    for i, input in enumerate(OrderedSet(app.inputs)):
        del self.clients[input][app]
    if getattr(app.op, 'destroy_map', OrderedDict()):
        self.destroyers.remove(app)
    # Note: leaving empty client dictionaries in the struct.
    # Why? It's a pain to remove them. I think they aren't doing any
    # harm, they will be deleted on_detach().
    # UPDATE self.view_i, self.view_o
    for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map',
                                               OrderedDict())):
        if len(i_idx_list) > 1:
            # destroying this output invalidates multiple inputs
            raise NotImplementedError()
        o = app.outputs[o_idx]
        i = app.inputs[i_idx_list[0]]
        # Unlink the view edge between i and o in both directions.
        del self.view_i[o]
        self.view_o[i].remove(o)
        if not self.view_o[i]:
            del self.view_o[i]
    # Cached droot/impact data is now out of date.
    self.stale_droot = True
def refresh_droot_impact(self):
    """
    Makes sure self.droot, self.impact, and self.root_destroyer are up
    to date, and returns them.
    (see docstrings for these properties above)
    """
    if self.stale_droot:
        # destroyed view + nonview variables -> foundation
        droot = OrderedDict()
        # destroyed nonview variable -> it + all views of it
        impact = OrderedDict()
        # root -> destroyer apply
        root_destroyer = OrderedDict()
        for app in self.destroyers:
            for output_idx, input_idx_list in iteritems(
                    app.op.destroy_map):
                if len(input_idx_list) != 1:
                    raise NotImplementedError()
                input_idx = input_idx_list[0]
                input = app.inputs[input_idx]
                # Follow the view chain back to the storage's owner.
                input_root = getroot(input, self.view_i)
                if input_root in droot:
                    raise InconsistencyError("Multiple destroyers of %s"
                                             % input_root)
                droot[input_root] = input_root
                root_destroyer[input_root] = app
                # Every view of the root shares its storage, so it is
                # affected by the destruction too.
                input_impact = get_impact(input_root, self.view_o)
                for v in input_impact:
                    assert v not in droot
                    droot[v] = input_root
                impact[input_root] = input_impact
                impact[input_root].add(input_root)
        self.droot, self.impact, self.root_destroyer = (
            droot, impact, root_destroyer)
        self.stale_droot = False
    return self.droot, self.impact, self.root_destroyer
def make_layer_to_state(self, num_examples, rng=None):
    """
    Makes and returns a dictionary mapping layers to states.

    By states, we mean here a real assignment, not a mean field
    state. For example, for a layer containing binary random
    variables, the state will be a shared variable containing values
    in {0,1}, not [0,1]. The visible layer will be included.

    Uses a dictionary so it is easy to unambiguously index a layer
    without needing to remember rules like vis layer = 0, hiddens
    start at 1, etc.

    Parameters
    ----------
    num_examples : int
        Number of examples each state must contain.
    rng : WRITEME
        Defaults to self.rng when omitted.
    """
    # Make a list of all layers
    layers = [self.visible_layer] + self.hidden_layers
    if rng is None:
        rng = self.rng
    states = [layer.make_state(num_examples, rng) for layer in layers]
    zipped = safe_zip(layers, states)

    def recurse_check(layer, state):
        # States may be nested lists/tuples of shared variables; verify
        # every leaf carries the requested number of examples.
        if isinstance(state, (list, tuple)):
            for elem in state:
                recurse_check(layer, elem)
        else:
            val = state.get_value()
            m = val.shape[0]
            if m != num_examples:
                # BUGFIX: the original message ran two sentences
                # together ("...component.We requested...") because of
                # a missing space between adjacent string literals.
                raise ValueError(layer.layer_name + " gave state with " +
                                 str(m) + " examples in some component. "
                                 "We requested " + str(num_examples))

    for layer, state in zipped:
        recurse_check(layer, state)

    rval = OrderedDict(zipped)
    return rval
def gradient_descent(self, loss):
    """Momentum GD with gradient clipping."""
    grads = T.grad(loss, self.params)
    self.momentum_velocity_ = [0.] * len(grads)
    # Global L2 norm of the full gradient vector.
    grad_norm = T.sqrt(sum(map(lambda x: T.sqr(x).sum(), grads)))
    updates = OrderedDict()
    not_finite = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
    scaling_den = T.maximum(5.0, grad_norm)
    for n, (param, g) in enumerate(zip(self.params, grads)):
        # Replace non-finite gradients with a nudge toward zero, and
        # rescale so the global norm is clipped at 5.
        g = T.switch(not_finite, 0.1 * param, g * (5.0 / scaling_den))
        velocity = self.momentum_velocity_[n]
        step = self.momentum * velocity - self.learning_rate * g
        self.momentum_velocity_[n] = step
        updates[param] = param + step
    return updates
def get_layer_monitoring_channels(self, state_below=None, state=None,
                                  targets=None):
    """
    Per-unit statistics (max/min/mean/range over axis 1) of the layer
    output, each reduced to scalars over the remaining axis.
    """
    channels = OrderedDict()
    if state is None:
        state = self.fprop(state_below)
    for var, prefix in [(state, '')]:
        mx = var.max(axis=1)
        mn = var.min(axis=1)
        avg = var.mean(axis=1)
        spread = mx - mn
        # Channel names read inner.outer (e.g. max_x.mean_u = mean over
        # units of the per-example max) so mean_x.* channels sort
        # together in the alphabetical monitor listing.
        stats = [('max_x.max_u', mx.max()), ('max_x.mean_u', mx.mean()),
                 ('max_x.min_u', mx.min()), ('min_x.max_u', mn.max()),
                 ('min_x.mean_u', mn.mean()), ('min_x.min_u', mn.min()),
                 ('range_x.max_u', spread.max()),
                 ('range_x.mean_u', spread.mean()),
                 ('range_x.min_u', spread.min()),
                 ('mean_x.max_u', avg.max()),
                 ('mean_x.mean_u', avg.mean()),
                 ('mean_x.min_u', avg.min())]
        for key, val in stats:
            channels[prefix + key] = val
    return channels
def get_lr_scalers(self):
    """
    Per-parameter learning-rate scales for this layer.

    Returns an OrderedDict mapping the weights and/or bias to their
    configured scales; configuration attributes missing from older
    pickles default to None (no scaling).
    """
    for attr in ('W_lr_scale', 'b_lr_scale'):
        if not hasattr(self, attr):
            setattr(self, attr, None)
    scalers = OrderedDict()
    if self.W_lr_scale is not None:
        W, = self.transformer.get_params()
        scalers[W] = self.W_lr_scale
    if self.b_lr_scale is not None:
        scalers[self.b] = self.b_lr_scale
    return scalers
def update(self, other=None):
    """
    Merge `other` into these updates.

    Raises
    ------
    KeyError
        If a key is already present with a different value (collision).
    """
    if other is None:
        return
    if (isinstance(other, dict) and len(other) > 1 and
            not isinstance(other, OrderedDict)):
        # An unordered dict with several entries makes iteration order,
        # and thus the resulting update graph, non-deterministic.
        warnings.warn('Updating an `OrderedUpdates` with a '
                      'non-ordered dictionary with 2+ elements could '
                      'make your code non-deterministic',
                      stacklevel=2)
    for key, val in iteritems(OrderedDict(other)):
        if key in self:
            if not (self[key] == val):
                raise KeyError('Collision', key)
        else:
            self[key] = val  # __setitem__ does type-checking
def get_monitoring_channels(self, model, data, **kwargs):
    """
    Accuracy/conviction channels for the compressor-vs-discriminator
    setup, evaluated on clean inputs and on their reconstructions.

    NOTE(review): class index 10 appears to denote the "fake" class --
    confirm against the discriminator's output layer size.
    """
    X_pure, Y_pure = data
    # Test values let Theano's compute_test_value machinery check
    # shapes at graph-construction time.
    X_pure.tag.test_value = numpy.random.random(
        size=[5, 784]).astype('float32')
    Y_pure.tag.test_value = numpy.random.randint(
        10, size=[5, 1]).astype('int64')
    rval = OrderedDict()

    g = model.compressor
    d = model.discriminator

    # Predicted class on the raw data and on its reconstruction.
    yhat_pure = T.argmax(d.fprop(X_pure), axis=1).dimshuffle(0, 'x')
    yhat_reconstructed = T.argmax(d.fprop(g.reconstruct(X_pure)),
                                  axis=1).dimshuffle(0, 'x')

    rval['conviction_pure'] = T.cast(T.eq(yhat_pure, 10).mean(),
                                     'float32')
    rval['accuracy_pure'] = T.cast(T.eq(yhat_pure, Y_pure).mean(),
                                   'float32')
    # Remaining mass: neither flagged fake nor classified correctly.
    rval['inaccuracy_pure'] = 1 - rval['conviction_pure'] - rval[
        'accuracy_pure']

    rval['conviction_fake'] = T.cast(T.eq(yhat_reconstructed, 10).mean(),
                                     'float32')
    rval['accuracy_fake'] = T.cast(T.eq(yhat_reconstructed,
                                        Y_pure).mean(), 'float32')
    rval['inaccuracy_fake'] = 1 - rval['conviction_fake'] - rval[
        'accuracy_fake']

    rval['discernment_pure'] = rval['accuracy_pure'] + rval[
        'inaccuracy_pure']
    rval['discernment_fake'] = rval['conviction_fake']
    rval['discernment'] = 0.5 * (rval['discernment_pure'] +
                                 rval['discernment_fake'])

    # y = T.alloc(0., m, 1)
    d_obj, g_obj = self.get_objectives(model, data)
    rval['objective_d'] = d_obj
    rval['objective_g'] = g_obj

    # monitor probability of true
    # rval['now_train_compressor'] = self.now_train_compressor
    return rval
def test_subgraph_grad():
    # Tests that the grad method with no known_grads
    # matches what happens if you use successive subgraph_grads
    x = theano.tensor.fvector('x')
    t = theano.tensor.fvector('t')
    w1 = theano.shared(np.random.randn(3, 4))
    w2 = theano.shared(np.random.randn(4, 2))
    # Two-layer tanh network plus weight-penalty terms.
    a1 = theano.tensor.tanh(theano.tensor.dot(x, w1))
    a2 = theano.tensor.tanh(theano.tensor.dot(a1, w2))
    cost2 = theano.tensor.sqr(a2 - t).sum()
    cost2 += theano.tensor.sqr(w2.sum())
    cost1 = theano.tensor.sqr(w1.sum())
    params = [[w2], [w1]]
    costs = [cost2, cost1]
    grad_ends = [[a1], [x]]
    inputs = [t, x]
    rng = np.random.RandomState([2012, 11, 15])
    values = [rng.randn(2), rng.randn(3)]
    values = [np.cast[ipt.dtype](value)
              for ipt, value in zip(inputs, values)]
    wrt = [w2, w1]
    cost = cost2 + cost1
    # Reference gradients from a single ordinary grad call.
    true_grads = theano.grad(cost, wrt)
    true_grads = theano.function(inputs, true_grads)
    true_grads = true_grads(*values)
    # Walk backwards through the graph one segment at a time, feeding
    # each segment's end-grads into the next subgraph_grad call.
    next_grad = None
    param_grads = []
    for i in xrange(2):
        param_grad, next_grad = theano.subgraph_grad(
            wrt=params[i], end=grad_ends[i],
            start=next_grad, cost=costs[i])
        next_grad = OrderedDict(zip(grad_ends[i], next_grad))
        param_grads.extend(param_grad)
    pgrads = theano.function(inputs, param_grads)
    pgrads = pgrads(*values)
    for true_grad, pgrad in zip(true_grads, pgrads):
        assert (np.sum(np.abs(true_grad - pgrad)) < 0.00001)
def get_monitoring_channels(self, data):
    """
    Channels for the compressor/discriminator pair: the compressor's
    own channels plus the discriminator evaluated on the clean data and
    on its reconstruction.
    """
    channels = OrderedDict()
    X, Y = data
    Xhat = self.compressor.reconstruct(X)
    c_ch = self.compressor.get_monitoring_channels(X)
    d_ch = self.discriminator.get_monitoring_channels((X, Y))
    d_distorted_ch = self.discriminator.get_monitoring_channels((Xhat, Y))
    if self.monitor_compressor:
        for key in c_ch:
            channels['compress_' + key] = c_ch[key]
    if self.monitor_discriminator:
        for key in d_ch:
            channels['dis_on_data_' + key] = d_ch[key]
        for key in d_ch:
            channels['dis_on_distorted_' + key] = d_distorted_ch[key]
    return channels
def forced_replace(out, x, y):
    """
    Check all internal values of the graph that compute the variable
    ``out`` for occurrences of values identical with ``x``. If such
    occurrences are encountered then they are replaced with variable
    ``y``.

    Parameters
    ----------
    out : Theano Variable
    x : Theano Variable
    y : Theano Variable

    Examples
    --------
    out := sigmoid(wu)*(1-sigmoid(wu))
    x := sigmoid(wu)
    forced_replace(out, x, y) := y*(1-y)
    """
    if out is None:
        return None

    # Track visited nodes so multiply-connected graphs are walked once.
    seen = set()

    def collect_matches(node, target):
        # Depth-first search for subgraphs computing the same value as
        # `target`; returns the list of matching nodes.
        if node in seen:
            return []
        seen.add(node)
        if equal_computations([node], [target]):
            return [node]
        if not node.owner:
            return []
        matches = []
        for inp in node.owner.inputs:
            matches += collect_matches(inp, target)
        return matches

    to_replace = collect_matches(out, x)
    return clone(out, replace=OrderedDict((v, y) for v in to_replace))
def reconstruct_graph(inputs, outputs, tag=None):
    """
    Different interface to clone, that allows you to pass inputs.
    Compared to clone, this method always replaces the inputs with
    new variables of the same type, and returns those (in the same
    order as the original inputs).
    """
    if tag is None:
        tag = ''
    nw_inputs = [safe_new(x, tag) for x in inputs]
    # Map every original input to its fresh replacement.
    givens = OrderedDict(izip(inputs, nw_inputs))
    # Constants feeding the outputs are also cloned.
    for inp in theano.gof.graph.inputs(outputs):
        if isinstance(inp, theano.Constant):
            givens[inp] = inp.clone()
    nw_outputs = clone(outputs, replace=givens)
    return (nw_inputs, nw_outputs)
def get_data_subsets(self):
    """
    Partition the dataset according to cross-validation subsets and
    yield the raw data for each partition as an OrderedDict keyed by
    subset label ('train'/'valid'/'test' or 'train'/'test').
    """
    for subsets in self.subset_iterator:
        if len(subsets) == 3:
            labels = ['train', 'valid', 'test']
        elif len(subsets) == 2:
            labels = ['train', 'test']
        else:
            labels = None
        # OrderedDict keeps the label order stable for callers.
        data_subsets = OrderedDict()
        for i, subset in enumerate(subsets):
            subset_data = tuple(data[subset] for data in self._data)
            if len(subset_data) == 2:
                X, y = subset_data
            else:
                X, = subset_data
                y = None
            data_subsets[labels[i]] = (X, y)
        yield data_subsets
def make_layer_to_symbolic_state(self, num_examples, rng=None):
    """
    .. todo::

        Explain the difference with `make_layer_to_state`

    Build a dictionary mapping each layer (visible first, then the
    hidden layers) to a symbolic sampled state, so layers can be
    indexed unambiguously without positional conventions.

    Parameters
    ----------
    num_examples : int
        Number of examples in each symbolic state.
    rng : required
        Random number generator handed to each layer.
    """
    assert rng is not None
    all_layers = [self.visible_layer] + self.hidden_layers
    symbolic_states = [layer.make_symbolic_state(num_examples, rng)
                       for layer in all_layers]
    return OrderedDict(safe_zip(all_layers, symbolic_states))
def construct(self, X_shared, y_shared, A, b, params, gparams,
              learning_rate, batch_size, loss):
    """
    Compile the minibatch SGD training function into self.train_model.

    The compiled function takes (minibatch index, learning rate) and
    returns `loss`, applying plain gradient-descent updates.
    """
    self.learning_rate = learning_rate
    index = T.lscalar()
    lr = T.dscalar()
    # NOTE(review): this rebinding shadows the `learning_rate` argument
    # saved above; the symbolic scalar is fed via givens from `lr`.
    learning_rate = T.dscalar('learning_rate')

    # Create update rule
    updates = OrderedDict()
    for param, gparam in zip(params, gparams):
        updates[param] = param - learning_rate * gparam

    # Construct update function
    self.train_model = theano.function(
        inputs=[index, lr],
        updates=updates,
        outputs=loss,
        givens={
            A: X_shared[index * batch_size:(index + 1) * batch_size],
            b: y_shared[index * batch_size:(index + 1) * batch_size],
            learning_rate: lr
        })
def get_cost_updates(self, k=1, MSEWeight=1):
    """
    Implement one step of CD-k with an additional reconstruction
    (MSE) gradient term.

    Parameters
    ----------
    k : int
        Number of Gibbs steps for CD-k. Must be >= 1.
    MSEWeight : float
        Weight applied to the reconstruction (MSE) gradients when
        combining them with the CD gradients.

    Returns
    -------
    tuple
        ``(monitoring_cost, MSE_cost, updates)`` where
        ``monitoring_cost`` is a reconstruction proxy for the CD cost,
        ``MSE_cost`` is the reconstruction cost of the configured
        recovery, and ``updates`` holds the momentum/weight-decay SGD
        update rules (including the persistent increment variables).

    Raises
    ------
    ValueError
        If ``k`` is not a positive integer.
    """
    # Positive phase: sample hiddens from the data.
    pre_sigmoid_ph, ph_mean, ph_sample = self.sample_h_given_v(self.input)
    # For CD the chain starts at the newly generated hidden sample.
    chain_start = ph_sample
    if k > 1:
        # Negative phase: scan runs k-1 full Gibbs steps, then one
        # final mean-field down/up pass below.
        (
            [
                pre_sigmoid_nvs,
                nv_means,
                nv_samples,
                pre_sigmoid_nhs,
                nh_means,
                nh_samples
            ],
            updates
        ) = theano.scan(
            self.gibbs_hvh,
            # The Nones are placeholders; chain_start is the initial
            # state corresponding to the 6th output.
            outputs_info=[None, None, None, None, None, chain_start],
            n_steps=k - 1
        )
        n_visMean = self.propdown(nh_samples[-1])[1]
        n_hidMean = self.propup(n_visMean)[1]
    elif k == 1:
        # CD-1: a single mean-field down/up pass, no scan needed.
        n_visMean = self.propdown(chain_start)[1]
        n_hidMean = self.propup(n_visMean)[1]
        updates = OrderedDict()
    else:
        # Previously this branch used a Python-2 `print` statement and
        # exit(), which killed the interpreter; raise instead.
        raise ValueError("k (number of CD steps) must be >= 1, got %r"
                         % (k,))
    # CD gradients from the positive/negative phase statistics.
    gparams = self.gradient(ph_mean, n_visMean, n_hidMean)
    # Reconstruction (MSE) gradients.
    gparams_Rec = T.grad(self.MSECost, self.params)
    # Combine CD and weighted reconstruction gradients per parameter.
    grad = [g + MSEWeight * g_rec
            for g, g_rec in zip(gparams, gparams_Rec)]
    # Momentum updates on the increment variables; weight decay is
    # applied to the weight gradient only (grad[0]).
    updates[self.params_inc[0]] = (self.params_inc[0] * self.momentum -
                                   (grad[0] + self.W * self.weightCost) *
                                   self.lr)
    updates[self.params_inc[1]] = (self.params_inc[1] * self.momentum -
                                   grad[1] * self.lr)
    updates[self.params_inc[2]] = (self.params_inc[2] * self.momentum -
                                   grad[2] * self.lr)
    for inc, param in zip(self.params_inc, self.params):
        updates[param] = param + inc
    # Reconstruction cross-entropy is a better proxy for CD monitoring.
    monitoring_cost = self.get_reconstruction_cost(n_visMean)
    if self.MSEType == 'mode':
        MSE_cost = self.get_reconstruction_cost(self.mode_recover)
    else:
        MSE_cost = self.get_reconstruction_cost(self.recover)
    return monitoring_cost, MSE_cost, updates
def scan(fn, sequences=None, states=None, params=None,
         n_steps=None, mode=None, name=None, profile=False,
         allow_gc=None):
    """
    Similar to Theano's official scan, this function gives the user more
    control over the scan op, avoiding certain difficulties that arose
    from missing optimizations.

    :param fn: lambda function that describes one step of scan (see the
        official Theano scan function)
    :param sequences: similar to the official Theano's scan. This
        version of scan does not support taps for the sequences (it can
        only be a list of tensors). Scan assumes that sequences have the
        right length and it does not check for this.
    :param states: similar to outputs_info of the official scan
        function. There is one crucial difference though, namely that
        the `initial` key in the dictionary has been replaced by the
        'membuf' key. This reflects the change of meaning. Instead of
        passing to scan just the initial steps missing, one has now to
        pass a memory buffer in which scan will try to store its output.
        In this memory buffer the first entries should be set to the
        initial states of the corresponding states. Providing a memory
        buffer that has fewer entries than the number of steps means
        scan will only use that amount of memory. The user has to match
        the memory buffer size with the number of steps, otherwise scan
        will produce wrong results. Also if gradients are to be computed
        through the scan, the memory buffer should have the same length
        as the number of steps. For states that do not require an
        initial state, one has to provide a dictionary with a single key
        'steps' that says how many intermediate results to store. See
        examples below for more insight.
    :param n_steps: This parameter is mandatory and it will represent
        the number of steps scan will do (scan will not check sequences
        or any other source of information to figure out how many steps
        it needs to do).
    :param mode: Same as for the official scan
    :param name: Same as for the official scan
    :param profile: Same as for the official scan

    Note:
     - there is no truncate / go_backwards anymore !
     - the outputs returned by scan contain the initial states as well
       (i.e. if I loop over k steps, with my smallest tap for an output
       -3 and keep all intermediate results, my output will be of length
       k+3).

    Examples:
        (a) if you do not want to store any intermediate results (just
        the last one)

        # The memory buffer can be the initial state, just that we need
        # to add one extra dimension in front of it
        state = TT.unbroadcast(TT.shape_padleft(x0), 0)
        out, _ = scan(lambda x: x + 1, states=state, n_steps=5)
        # Once we got our result we need to remove the extra dimension
        out = out[0]

        (b) if you want to keep every intermediate result

        state = TT.alloc(TT.constant(0), 6, x0.shape[0])
        state = TT.set_subtensor(state[0], x0)
        out, _ = scan(lambda x: x + 1, states=state, n_steps=5)
        out = out[1:]
    """
    def wrap_into_list(x):
        '''
        Wrap the input into a list if it is not already a list
        '''
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(states)
    if allow_gc is None:
        allow_gc = config.scan.allow_gc

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(params):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps (before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once.
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None
    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
            str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = dict(membuf=outs_info[i], taps=[-1])
            elif (not outs_info[i].get('membuf', None) and
                  outs_info[i].get('taps', None)):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide a memory buffer for '
                                  'the state '), outs_info[i])
            elif (outs_info[i].get('membuf', None) and
                  not outs_info[i].get('taps', None)):
                # ^ initial state but taps not provided
                if 'taps' in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        'Output %s (index %d) has a memory '
                        'buffer but taps is explicitly set to None ',
                        getattr(outs_info[i]['membuf'], 'name', 'None'),
                        i)
                outs_info[i]['taps'] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an dict(steps=n_steps) to simplify handling
            outs_info[i] = dict(steps=n_steps)

    ##
    # Step 2. Generate inputs and outputs of the inner functions
    # for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we used this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs
    n_seqs = 0
    scan_seqs = []     # Variables passed as inputs to the scan op
    inner_seqs = []    # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        if isinstance(seq, dict):
            seq = seq['input']
        actual_slice = seq[0]
        _seq_val = tensor.as_tensor_variable(seq)
        _seq_val_slice = _seq_val[0]
        nw_slice = _seq_val_slice.type()

        # Try to transfer test_value to the new variable
        if config.compute_test_value != 'off':
            try:
                nw_slice.tag.test_value = gof.Op._get_test_value(
                    _seq_val_slice)
            except AttributeError as e:
                if config.compute_test_value != 'ignore':
                    # No need to print a warning or raise an error now,
                    # it will be done when fn will be called.
                    _logger.info(('Cannot compute test value for '
                                  'the inner function of scan, input value '
                                  'missing %s'), e)
        if seq.name:
            nw_slice.name = seq.name + '[t]'
        scan_seqs.append(_seq_val)
        inner_seqs.append(nw_slice)
        inner_slices.append(actual_slice)
        n_seqs += 1

    actual_n_steps = tensor.as_tensor(n_steps)

    # Conventions :
    #   mit_mot = multiple input taps, multiple output taps ( only
    #             provided by the gradient function )
    #   mit_sot = multiple input taps, single output tap (t + 0)
    #   sit_sot = single input tap, single output tap (t + 0)
    #   nit_sot = no input tap, single output tap (t + 0)

    # MIT_MOT -- not provided by the user only by the grad function
    n_mit_mot = 0
    n_mit_mot_outs = 0
    mit_mot_scan_inputs = []
    mit_mot_inner_inputs = []
    mit_mot_inner_outputs = []
    mit_mot_out_slices = []
    mit_mot_rightOrder = []

    # SIT_SOT -- provided by the user
    n_mit_sot = 0
    mit_sot_scan_inputs = []
    mit_sot_inner_inputs = []
    mit_sot_inner_slices = []
    mit_sot_inner_outputs = []
    mit_sot_return_steps = OrderedDict()
    mit_sot_tap_array = []
    mit_sot_rightOrder = []

    n_sit_sot = 0
    sit_sot_scan_inputs = []
    sit_sot_inner_inputs = []
    sit_sot_inner_slices = []
    sit_sot_inner_outputs = []
    sit_sot_return_steps = OrderedDict()
    sit_sot_rightOrder = []
    nit_sot_steps = []
    # go through outputs picking up time slices as needed
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as a initial state we will only
        # provide a tensor of the same dimension as one time step; This
        # makes code much cleaner for those who do not use taps.
        # Otherwise they would always had to shape_padleft the initial
        # state .. which is ugly

        # Note, 'taps' might not be in the dictionary
        if 'taps' in init_out and init_out['taps'] == [-1]:
            actual_arg = init_out['membuf']
            arg = safe_new(init_out['membuf'][0])
            if isinstance(arg, tensor.Constant):
                # safe new returns a clone of the constants, but that is
                # not what we need for initial states
                arg = arg.type()

            # Try to transfer test_value to the new variable
            if config.compute_test_value != 'off':
                try:
                    arg.tag.test_value = gof.Op._get_test_value(actual_arg)
                except AttributeError as e:
                    if config.compute_test_value != 'ignore':
                        # No need to print a warning or raise an error
                        # now, it will be done when fn will be called.
                        _logger.info(('Cannot compute test value for the '
                                      'inner function of scan, input value '
                                      'missing %s'), e)

            if getattr(init_out['membuf'], 'name', None) is not None:
                arg.name = init_out['membuf'].name + '[t-1]'

            # We need now to allocate space for storing the output and
            # copy the initial state over. We do this using the expand
            # function defined in scan utils
            sit_sot_scan_inputs.append(actual_arg)
            sit_sot_inner_slices.append(actual_arg[0])
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1
        elif init_out.get('taps', None):
            if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
                # Make sure we do not have requests for future values of
                # a sequence we can not provide such values
                raise ValueError('Can not use future taps of outputs',
                                 init_out)
            # go through the taps
            mintap = abs(numpy.min(init_out['taps']))
            mit_sot_tap_array.append(init_out['taps'])
            idx_offset = abs(numpy.min(init_out['taps']))
            # Sequence
            mit_sot_scan_inputs.append(init_out['membuf'])

            if i in return_steps:
                mit_sot_return_steps[n_mit_sot] = return_steps[i]
            mit_sot_rightOrder.append(i)
            n_mit_sot += 1
            for k in init_out['taps']:
                # create a new slice
                actual_nw_slice = init_out['membuf'][k + mintap]
                _init_out_var = tensor.as_tensor_variable(
                    init_out['membuf'])
                _init_out_var_slice = _init_out_var[k + mintap]
                nw_slice = _init_out_var_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != 'off':
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _init_out_var_slice)
                    except AttributeError as e:
                        if config.compute_test_value != 'ignore':
                            # No need to print a warning or raise an
                            # error now, it will be done when fn will be
                            # called.
                            _logger.info(('Cannot compute test value for '
                                          'the inner function of scan, '
                                          'input value missing. %s'), e)

                # give it a name or debugging and pretty printing
                if getattr(init_out['membuf'], 'name', None) is not None:
                    if k > 0:
                        nw_slice.name = (init_out['membuf'].name +
                                         '[t+%d]' % k)
                    elif k == 0:
                        nw_slice.name = init_out['membuf'].name + '[t]'
                    else:
                        nw_slice.name = (init_out['membuf'].name +
                                         '[t%d]' % k)
                mit_sot_inner_inputs.append(nw_slice)
                mit_sot_inner_slices.append(actual_nw_slice)
        else:
            pass

    # Re-order args
    max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
    max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
    n_elems = numpy.max([max_mit_sot, max_sit_sot])
    _ordered_args = [[] for x in xrange(n_elems)]
    offset = 0
    for idx in xrange(n_mit_sot):
        n_inputs = len(mit_sot_tap_array[idx])
        if n_fixed_steps == 1:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                mit_sot_inner_slices[offset:offset + n_inputs]
        else:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                mit_sot_inner_inputs[offset:offset + n_inputs]
        offset += n_inputs

    for idx in xrange(n_sit_sot):
        if n_fixed_steps == 1:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                [sit_sot_inner_slices[idx]]
        else:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                [sit_sot_inner_inputs[idx]]

    ordered_args = []
    for ls in _ordered_args:
        ordered_args += ls
    if n_fixed_steps == 1:
        args = (inner_slices + ordered_args + non_seqs)
    else:
        args = (inner_seqs + ordered_args + non_seqs)

    # add only the non-shared variables and non-constants to the
    # arguments of the dummy function [ a function should not get shared
    # variables or constants as input ]
    dummy_args = [arg for arg in args
                  if (not isinstance(arg, SharedVariable) and
                      not isinstance(arg, tensor.Constant))]

    # when we apply the lambda expression we get a mixture of update
    # rules and outputs that needs to be separated
    lambda_result = fn(*args)
    condition, outputs, updates = scan_utils.get_updates_and_outputs(
        lambda_result)
    if condition is not None:
        as_while = True
    else:
        as_while = False

    ##
    # Step 3. Check if we actually need scan and remove it if we don't
    ##
    if n_fixed_steps == 1:
        # We do not need to use the scan op anymore, so we can just
        # return the outputs and updates we have
        if condition is not None:
            _logger.warning(('When the number of steps is fixed and '
                             'equal to 1, the provided stopping '
                             'condition, ',
                             str(condition), ' is ignored'))

        for pos, inner_out in enumerate(outputs):
            # we need to see if we need to pad our sequences with an
            # unbroadcastable dimension; case example : we return an
            # output for which we want all intermediate. If n_steps is 1
            # then, if we return the output as given by the inner
            # function this will represent only a slice and it will have
            # one dimension less.
            if (isinstance(inner_out.type, tensor.TensorType) and
                    return_steps.get(pos, 0) != 1):
                outputs[pos] = tensor.unbroadcast(
                    tensor.shape_padleft(inner_out), 0)
        if len(outputs) == 1:
            outputs = outputs[0]

        return (outputs, updates)

    ##
    # Step 4. Compile the dummy function
    ##

    # We can now compile a dummy function just to see what shared
    # variables we have and what are their update rules (note that the
    # user has the option not to pass the shared variable to scan, so we
    # need to pick them manually and add them to scan)
    # make the compilation as fast as possible by not applying any
    # optimization or conversion to C [ note this region is not
    # important for performance so we can do stuff as unoptimal as we
    # wish ]

    # extract still missing inputs (there still might be so) and add
    # them as non sequences at the end of our args
    # NOTE(review): `updates.values()`, `itertools.ifilter`,
    # `givens.iteritems()` and bare `filter(...)` used as a list below
    # are Python-2-only idioms — confirm target interpreter before
    # porting.
    fake_nonseqs = [x.type() for x in non_seqs]
    fake_outputs = scan_utils.clone(outputs + updates.values(),
                                    replace=dict(zip(non_seqs,
                                                     fake_nonseqs)))
    all_inputs = itertools.ifilter(
        lambda x: (isinstance(x, gof.Variable) and
                   not isinstance(x, SharedVariable) and
                   not isinstance(x, gof.Constant)),
        gof.graph.inputs(fake_outputs))
    extra_inputs = filter(lambda x: x not in args + fake_nonseqs,
                          all_inputs)
    non_seqs += extra_inputs
    # Note we do not use all_inputs directly since the order of
    # variables in args is quite important
    dummy_args += extra_inputs

    dummy_outs = outputs
    if condition is not None:
        dummy_outs.append(condition)

    # If we use a regular dict here, the results are non-deterministic
    if not isinstance(updates, (list, tuple)):
        if isinstance(updates, dict) and \
                not isinstance(updates, OrderedDict):
            warnings.warn("Using non-deterministic dictionary.")

    dummy_f = function(dummy_args,
                       dummy_outs,
                       updates=updates,
                       mode=compile.mode.Mode(linker='py',
                                              optimizer=None),
                       on_unused_input='ignore')

    ##
    # Step 5. Re-arrange inputs of scan into a more strict order
    ##

    # Step 5.0 Check the outputs of the dummy function to see if they
    # match with user provided data

    # if the number of outputs to the function does not match the number
    # of assumed outputs until now (provided by the user) there can be
    # only one explanation: No information is provided for any of the
    # outputs (i.e. we are dealing with a map)
    tmp_dummy_f_outs = len(dummy_f.maker.outputs)
    if as_while:
        tmp_dummy_f_outs -= 1
    if not (tmp_dummy_f_outs == n_outs or outs_info == []):
        raise ValueError('Please provide None as output_info for '
                         'any output that does not feed back into '
                         'scan (i.e. it behaves like a map) ')

    if outs_info == []:
        n_outs = len(dummy_f.maker.outputs)
        if as_while:
            n_outs = n_outs - 1
        outs_info = [dict(steps=n_steps) for x in xrange(n_outs)]

    # Step 5.1 Outputs with taps different then -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] != [-1]:
            mit_sot_inner_outputs.append(outputs[i])

    # Step 5.2 Outputs with tap equal to -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] == [-1]:
            sit_sot_inner_outputs.append(outputs[i])

    # Step 5.3 Outputs that correspond to update rules of shared
    # variables
    givens = OrderedDict()
    n_shared_outs = 0
    shared_scan_inputs = []
    shared_inner_inputs = []
    shared_inner_outputs = []
    for input in dummy_f.maker.expanded_inputs:
        if isinstance(input.variable, SharedVariable) and input.update:
            new_var = safe_new(input.variable)
            if getattr(input.variable, 'name', None) is not None:
                new_var.name = input.variable.name + '_copy'
            shared_inner_inputs.append(new_var)
            shared_scan_inputs.append(input.variable)
            shared_inner_outputs.append(input.update)
            givens[input.variable] = new_var
            n_shared_outs += 1

    # Step 5.4 Outputs with no taps used in the input
    n_nit_sot = 0
    nit_sot_inner_outputs = []
    nit_sot_return_steps = OrderedDict()
    nit_sot_rightOrder = []
    for i, out in enumerate(outs_info):
        if not 'taps' in out:
            nit_sot_inner_outputs.append(outputs[i])
            if i in return_steps:
                nit_sot_return_steps[n_nit_sot] = return_steps[i]
            nit_sot_rightOrder.append(i)
            nit_sot_steps.append(out['steps'])
            n_nit_sot += 1

    # Step 5.5 all other arguments including extra inputs
    other_scan_args = []
    other_inner_args = []

    other_scan_args += [arg for arg in non_seqs
                        if (not isinstance(arg, SharedVariable) and
                            not isinstance(arg, tensor.Constant))]

    # Step 5.6 all shared variables with no update rules
    other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs
                         if (not isinstance(arg, SharedVariable) and
                             not isinstance(arg, tensor.Constant))]

    givens.update(dict(zip(other_scan_args, other_inner_args)))
    other_shared_scan_args = [
        arg.variable for arg in dummy_f.maker.expanded_inputs
        if (isinstance(arg.variable, SharedVariable) and
            not arg.update)]
    other_shared_inner_args = [
        safe_new(arg.variable, '_copy')
        for arg in dummy_f.maker.expanded_inputs
        if (isinstance(arg.variable, SharedVariable) and
            not arg.update)]
    givens.update(dict(zip(other_shared_scan_args,
                           other_shared_inner_args)))

    ##
    # Step 6. Re-order the outputs and clone them replacing things
    # using the givens
    ##
    inner_inputs = (inner_seqs +
                    mit_mot_inner_inputs +
                    mit_sot_inner_inputs +
                    sit_sot_inner_inputs +
                    shared_inner_inputs +
                    other_shared_inner_args +
                    other_inner_args)

    inner_outs = (mit_mot_inner_outputs +
                  mit_sot_inner_outputs +
                  sit_sot_inner_outputs +
                  nit_sot_inner_outputs +
                  shared_inner_outputs)
    if condition is not None:
        inner_outs.append(condition)

    new_givens = OrderedDict()
    for w, w_copy in givens.iteritems():
        new_givens[w] = w.type.filter_variable(w_copy)

    new_outs = scan_utils.clone(inner_outs, replace=new_givens)

    ##
    # Step 7. Create the Scan Op
    ##
    tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
    info = OrderedDict()
    info['tap_array'] = tap_array
    info['n_seqs'] = n_seqs
    info['n_mit_mot'] = n_mit_mot
    info['n_mit_mot_outs'] = n_mit_mot_outs
    info['mit_mot_out_slices'] = mit_mot_out_slices
    info['n_mit_sot'] = n_mit_sot
    info['n_sit_sot'] = n_sit_sot
    info['n_shared_outs'] = n_shared_outs
    info['n_nit_sot'] = n_nit_sot
    info['truncate_gradient'] = -1
    info['name'] = name
    info['mode'] = mode
    info['destroy_map'] = OrderedDict()
    info['inplace'] = False
    info['gpu'] = False
    info['as_while'] = as_while
    info['profile'] = profile
    info['_scan_savemem_visited'] = True
    info['allow_gc'] = allow_gc

    local_op = scan_op.Scan(inner_inputs, new_outs, info)

    ##
    # Step 8. Compute the outputs using the scan op
    ##
    _scan_inputs = (scan_seqs +
                    mit_mot_scan_inputs +
                    mit_sot_scan_inputs +
                    sit_sot_scan_inputs +
                    shared_scan_inputs +
                    nit_sot_steps +
                    other_shared_scan_args +
                    other_scan_args)

    scan_inputs = []
    for arg in [actual_n_steps] + _scan_inputs:
        if not isinstance(arg, gof.Variable):
            arg = tensor.as_tensor_variable(arg)
        scan_inputs += [arg]
    scan_outs = local_op(*scan_inputs)
    if type(scan_outs) not in (list, tuple):
        scan_outs = [scan_outs]

    ##
    # Step 9. Figure out which outs are update rules for shared
    # variables and so on ...
    ##
    update_map = OrderedUpdates()
    offset = n_mit_mot
    offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
    mit_sot_outs = scan_outs[offset:offset + n_mit_sot]

    offset += n_mit_sot
    offsets = [1 for x in xrange(n_sit_sot)]
    sit_sot_outs = scan_outs[offset:offset + n_sit_sot]

    offset += n_sit_sot
    nit_sot_outs = scan_outs[offset:offset + n_nit_sot]

    offset += n_nit_sot
    for idx, update_rule in enumerate(
            scan_outs[offset:offset + n_shared_outs]):
        update_map[shared_scan_inputs[idx]] = update_rule

    _scan_out_list = (mit_sot_outs +
                      sit_sot_outs +
                      nit_sot_outs)
    # Step 10. I need to reorder the outputs to be in the order expected
    # by the user
    rightOrder = (mit_sot_rightOrder +
                  sit_sot_rightOrder +
                  nit_sot_rightOrder)
    scan_out_list = [None] * len(rightOrder)
    for idx, pos in enumerate(rightOrder):
        scan_out_list[pos] = _scan_out_list[idx]
    if len(scan_out_list) == 1:
        scan_out_list = scan_out_list[0]
    elif len(scan_out_list) == 0:
        scan_out_list = None
    assert isinstance(update_map, OrderedDict)
    return (scan_out_list, update_map)
def scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False): """ This function constructs and applies a Scan op to the provided arguments. :param fn: ``fn`` is a function that describes the operations involved in one step of ``scan``. ``fn`` should construct variables describing the output of one iteration step. It should expect as input theano variables representing all the slices of the input sequences and previous values of the outputs, as well as all other arguments given to scan as ``non_sequences``. The order in which scan passes these variables to ``fn`` is the following : * all time slices of the first sequence * all time slices of the second sequence * ... * all time slices of the last sequence * all past slices of the first output * all past slices of the second otuput * ... * all past slices of the last output * all other arguments (the list given as `non_sequences` to scan) The order of the sequences is the same as the one in the list `sequences` given to scan. The order of the outputs is the same as the order of ``outputs_info``. For any sequence or output the order of the time slices is the same as the one in which they have been given as taps. For example if one writes the following : .. code-block:: python scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1]) , Sequence2 , dict(input = Sequence3, taps = 3) ] , outputs_info = [ dict(initial = Output1, taps = [-3,-5]) , dict(initial = Output2, taps = None) , Output3 ] , non_sequences = [ Argument1, Argument2]) ``fn`` should expect the following arguments in this given order: #. ``Sequence1[t-3]`` #. ``Sequence1[t+2]`` #. ``Sequence1[t-1]`` #. ``Sequence2[t]`` #. ``Sequence3[t+3]`` #. ``Output1[t-3]`` #. ``Output1[t-5]`` #. ``Output3[t-1]`` #. ``Argument1`` #. 
``Argument2`` The list of ``non_sequences`` can also contain shared variables used in the function, though ``scan`` is able to figure those out on its own so they can be skipped. For the clarity of the code we recommend though to provide them to scan. To some extend ``scan`` can also figure out other ``non sequences`` (not shared) even if not passed to scan (but used by `fn`). A simple example of this would be : .. code-block:: python import theano.tensor as TT W = TT.matrix() W_2 = W**2 def f(x): return TT.dot(x,W_2) The function is expected to return two things. One is a list of outputs ordered in the same order as ``outputs_info``, with the difference that there should be only one output variable per output initial state (even if no tap value is used). Secondly `fn` should return an update dictionary (that tells how to update any shared variable after each iteration step). The dictionary can optionally be given as a list of tuples. There is no constraint on the order of these two list, ``fn`` can return either ``(outputs_list, update_dictionary)`` or ``(update_dictionary, outputs_list)`` or just one of the two (in case the other is empty). To use ``scan`` as a while loop, the user needs to change the function ``fn`` such that also a stopping condition is returned. To do so, he/she needs to wrap the condition in an ``until`` class. The condition should be returned as a third element, for example: .. code-block:: python ... return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50) Note that a number of steps (considered in here as the maximum number of steps ) is still required even though a condition is passed (and it is used to allocate memory if needed). = {}): :param sequences: ``sequences`` is the list of Theano variables or dictionaries describing the sequences ``scan`` has to iterate over. If a sequence is given as wrapped in a dictionary, then a set of optional information can be provided about the sequence. 
The dictionary should have the following keys: * ``input`` (*mandatory*) -- Theano variable representing the sequence. * ``taps`` -- Temporal taps of the sequence required by ``fn``. They are provided as a list of integers, where a value ``k`` impiles that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. Default value is ``[0]`` Any Theano variable in the list ``sequences`` is automatically wrapped into a dictionary where ``taps`` is set to ``[0]`` :param outputs_info: ``outputs_info`` is the list of Theano variables or dictionaries describing the initial state of the outputs computed recurrently. When this initial states are given as dictionary optional information can be provided about the output corresponding to these initial states. The dictionary should have the following keys: * ``initial`` -- Theano variable that represents the initial state of a given output. In case the output is not computed recursively (think of a map) and does not require an initial state this field can be skipped. Given that (only) the previous time step of the output is used by ``fn``, the initial state **should have the same shape** as the output and **should not involve a downcast** of the data type of the output. If multiple time taps are used, the initial state should have one extra dimension that should cover all the possible taps. For example if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0, ``fn`` will require (by an abuse of notation) ``output[-5]``, ``output[-2]`` and ``output[-1]``. This will be given by the initial state, which in this case should have the shape (5,)+output.shape. If this variable containing the initial state is called ``init_y`` then ``init_y[0]`` *corresponds to* ``output[-5]``. ``init_y[1]`` *correponds to* ``output[-4]``, ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]`` coresponds to ``output[-2]``, ``init_y[4]`` corresponds to ``output[-1]``. 
While this order might seem strange, it comes natural from splitting an array at a given point. Assume that we have a array ``x``, and we choose ``k`` to be time step ``0``. Then our initial state would be ``x[:k]``, while the output will be ``x[k:]``. Looking at this split, elements in ``x[:k]`` are ordered exactly like those in ``init_y``. * ``taps`` -- Temporal taps of the output that will be pass to ``fn``. They are provided as a list of *negative* integers, where a value ``k`` implies that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. ``scan`` will follow this logic if partial information is given: * If an output is not wrapped in a dictionary, ``scan`` will wrap it in one assuming that you use only the last step of the output (i.e. it makes your tap value list equal to [-1]). * If you wrap an output in a dictionary and you do not provide any taps but you provide an initial state it will assume that you are using only a tap value of -1. * If you wrap an output in a dictionary but you do not provide any initial state, it assumes that you are not using any form of taps. * If you provide a ``None`` instead of a variable or a empty dictionary ``scan`` assumes that you will not use any taps for this output (like for example in case of a map) If ``outputs_info`` is an empty list or None, ``scan`` assumes that no tap is used for any of the outputs. If information is provided just for a subset of the outputs an exception is raised (because there is no convention on how scan should map the provided information to the outputs of ``fn``) :param non_sequences: ``non_sequences`` is the list of arguments that are passed to ``fn`` at each steps. One can opt to exclude variable used in ``fn`` from this list as long as they are part of the computational graph, though for clarity we encourage not to do so. :param n_steps: ``n_steps`` is the number of steps to iterate given as an int or Theano scalar. 
If any of the input sequences do not have enough elements, scan will raise an error. If the *value is 0* the outputs will have *0 rows*. If the value is negative, ``scan`` will run backwards in time. If the ``go_backwards`` flag is already set and also ``n_steps`` is negative, ``scan`` will run forward in time. If n_steps is not provided, ``scan`` will figure out the amount of steps it should run given its input sequences. :param truncate_gradient: ``truncate_gradient`` is the number of steps to use in truncated BPTT. If you compute gradients through a scan op, they are computed using backpropagation through time. By providing a different value then -1, you choose to use truncated BPTT instead of classical BPTT, where you go for only ``truncate_gradient`` number of steps back in time. :param go_backwards: ``go_backwards`` is a flag indicating if ``scan`` should go backwards through the sequences. If you think of each sequence as indexed by time, making this flag True would mean that ``scan`` goes back in time, namely that for any sequence it starts from the end and goes towards 0. :param name: When profiling ``scan``, it is crucial to provide a name for any instance of ``scan``. The profiler will produce an overall profile of your code as well as profiles for the computation of one step of each instance of ``scan``. The ``name`` of the instance appears in those profiles and can greatly help to disambiguate information. :param mode: It is recommended to leave this argument to None, especially when profiling ``scan`` (otherwise the results are not going to be accurate). If you prefer the computations of one step of ``scan`` to be done differently then the entire function, you can use this parameter to describe how the computations in this loop are done (see ``theano.function`` for details about possible values and their meaning). :param profile: Flag or string. 
If true, or different from the empty string, a profile object will be created and attached to the inner graph of scan. In case ``profile`` is True, the profile object will have the name of the scan instance, otherwise it will have the passed string. Profile object collect (and print) information only when running the inner graph with the new cvm linker ( with default modes, other linkers this argument is useless) :param allow_gc: Set the value of allow gc for the internal graph of scan. If set to None, this will use the value of config.scan.allow_gc. :param strict: If true, all the shared variables used in ``fn`` must be provided as a part of ``non_sequences`` or ``sequences``. :rtype: tuple :return: tuple of the form (outputs, updates); ``outputs`` is either a Theano variable or a list of Theano variables representing the outputs of ``scan`` (in the same order as in ``outputs_info``). ``updates`` is a subclass of dictionary specifying the update rules for all shared variables used in scan This dictionary should be passed to ``theano.function`` when you compile your function. The change compared to a normal dictionary is that we validate that keys are SharedVariable and addition of those dictionary are validated to be consistent. """ # General observation : this code is executed only once, at creation # of the computational graph, so we don't yet need to be smart about # anything (to speed things up) ## # Step 1. 
Wrap all inputs in dictionaries and add default values ## # check if inputs are just single variables instead of lists def wrap_into_list(x): ''' Wrap the input into a list if it is not already a list ''' if x is None: return [] elif not isinstance(x, (list, tuple)): return [x] else: return list(x) seqs = wrap_into_list(sequences) outs_info = wrap_into_list(outputs_info) # Make sure we get rid of numpy arrays or ints or anything like that # passed as inputs to scan non_seqs = [] for elem in wrap_into_list(non_sequences): if not isinstance(elem, gof.Variable): non_seqs.append(tensor.as_tensor_variable(elem)) else: non_seqs.append(elem) # If we provided a known number of steps ( before compilation) # and if that number is 1 or -1, then we can skip the Scan Op, # and just apply the inner function once # To do that we check here to see the nature of n_steps n_fixed_steps = None if isinstance(n_steps, (float, int)): n_fixed_steps = int(n_steps) else: try: n_fixed_steps = opt.get_scalar_constant_value(n_steps) except tensor.basic.NotScalarConstantError: n_fixed_steps = None # Check n_steps is an int if (hasattr(n_steps, 'dtype') and str(n_steps.dtype)[:3] not in ('uin', 'int')): raise ValueError(' n_steps must be an int. 
dtype provided ' 'is %s' % n_steps.dtype) # compute number of sequences and number of outputs n_seqs = len(seqs) n_outs = len(outs_info) return_steps = OrderedDict() # wrap sequences in a dictionary if they are not already dictionaries for i in xrange(n_seqs): if not isinstance(seqs[i], dict): seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])]) elif seqs[i].get('taps', None) is not None: seqs[i]['taps'] = wrap_into_list(seqs[i]['taps']) elif seqs[i].get('taps', None) is None: # seqs dictionary does not have the ``taps`` key seqs[i]['taps'] = [0] # wrap outputs info in a dictionary if they are not already in one for i in xrange(n_outs): if outs_info[i] is not None: if isinstance(outs_info[i], dict): # DEPRECATED : if outs_info[i].get('return_steps', None) is not None: raise ValueError( "Using `return_steps` has been deprecated. " "Simply select the entries you need using a " "subtensor. Scan will optimize memory " "consumption, so do not worry about that.") # END if not isinstance(outs_info[i], dict): # by default any output has a tap value of -1 outs_info[i] = OrderedDict([('initial', outs_info[i]), ('taps', [-1])]) elif (outs_info[i].get('initial', None) is None and outs_info[i].get('taps', None) is not None): # ^ no initial state but taps provided raise ValueError(('If you are using slices of an output ' 'you need to provide a initial state ' 'for it'), outs_info[i]) elif (outs_info[i].get('initial', None) is not None and outs_info[i].get('taps', None) is None): # ^ initial state but taps not provided if 'taps' in outs_info[i]: # ^ explicitly provided a None for taps _logger.warning('Output %s ( index %d) has a initial ' 'state but taps is explicitly set to None ', getattr(outs_info[i]['initial'], 'name', 'None'), i) outs_info[i]['taps'] = [-1] else: # if a None is provided as the output info we replace it # with an empty OrdereDict() to simplify handling outs_info[i] = OrderedDict() ## # Step 2. 
Generate inputs and outputs of the inner functions # for compiling a dummy function (Iteration #1) ## # create theano inputs for the recursive function # note : this is a first batch of possible inputs that will # be compiled in a dummy function; we used this dummy # function to detect shared variables and their updates # and to construct a new and complete list of inputs and # outputs n_seqs = 0 scan_seqs = [] # Variables passed as inputs to the scan op inner_seqs = [] # Variables passed as inputs to the inner function inner_slices = [] # Actual slices if scan is removed from the picture # go through sequences picking up time slices as needed for i, seq in enumerate(seqs): # Note that you can have something like no taps for # a sequence, though is highly unlikely in practice if 'taps' in seq: # go through the indicated slice mintap = numpy.min(seq['taps']) maxtap = numpy.max(seq['taps']) for k in seq['taps']: # create one slice of the input # Later on, if we decide not to use scan because we are # going for just one step, it makes things easier if we # compute the correct outputs here. This way we can use # the output of the lambda expression directly to replace # the output of scan. # If not we need to use copies, that will be replaced at # each frame by the corresponding slice actual_slice = seq['input'][k - mintap] _seq_val = tensor.as_tensor_variable(seq['input']) _seq_val_slice = _seq_val[k - mintap] nw_slice = _seq_val_slice.type() # Try to transfer test_value to the new variable if config.compute_test_value != 'off': try: nw_slice.tag.test_value = gof.Op._get_test_value( _seq_val_slice) except AttributeError as e: if config.compute_test_value != 'ignore': # No need to print a warning or raise an error now, # it will be done when fn will be called. _logger.info(('Cannot compute test value for ' 'the inner function of scan, input value ' 'missing %s'), e) # Add names to slices for debugging and pretty printing .. 
# that is if the input already has a name if getattr(seq['input'], 'name', None) is not None: if k > 0: nw_name = seq['input'].name + '[t+%d]' % k elif k == 0: nw_name = seq['input'].name + '[t]' else: nw_name = seq['input'].name + '[t%d]' % k nw_slice.name = nw_name # We cut the sequence such that seq[i] to correspond to # seq[i-k] if maxtap < 0: offset = abs(maxtap) else: offset = 0 if maxtap == mintap and maxtap != 0: if maxtap < 0: nw_seq = seq['input'][:maxtap] else: nw_seq = seq['input'][maxtap:] elif maxtap - k != 0: nw_seq = seq['input'][offset + k - mintap: -(maxtap - k)] else: nw_seq = seq['input'][offset + k - mintap:] if go_backwards: nw_seq = nw_seq[::-1] scan_seqs.append(nw_seq) inner_seqs.append(nw_slice) inner_slices.append(actual_slice) n_seqs += 1 # Since we've added all sequences now we need to level them up based on # n_steps or their different shapes lengths_vec = [] for seq in scan_seqs: lengths_vec.append(seq.shape[0]) if not scan_utils.isNaN_or_Inf_or_None(n_steps): # ^ N_steps should also be considered lengths_vec.append(tensor.as_tensor(n_steps)) if len(lengths_vec) == 0: # ^ No information about the number of steps raise ValueError(' No information about the number of steps ' 'provided. 
Either provide a value for ' 'n_steps argument of scan or provide an input ' 'sequence') # If the user has provided the number of steps, do that regardless ( and # raise an error if the sequences are not long enough ) if scan_utils.isNaN_or_Inf_or_None(n_steps): actual_n_steps = lengths_vec[0] for contestant in lengths_vec[1:]: actual_n_steps = tensor.minimum(actual_n_steps, contestant) else: actual_n_steps = tensor.as_tensor(n_steps) # Add names -- it helps a lot when debugging for (nw_seq, seq) in zip(scan_seqs, seqs): if getattr(seq['input'], 'name', None) is not None: nw_seq.name = seq['input'].name + '[%d:]' % k scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs] # Conventions : # mit_mot = multiple input taps, multiple output taps ( only provided # by the gradient function ) # mit_sot = multiple input taps, single output tap (t + 0) # sit_sot = single input tap, single output tap (t + 0) # nit_sot = no input tap, single output tap (t + 0) # MIT_MOT -- not provided by the user only by the grad function n_mit_mot = 0 n_mit_mot_outs = 0 mit_mot_scan_inputs = [] mit_mot_inner_inputs = [] mit_mot_inner_outputs = [] mit_mot_out_slices = [] mit_mot_rightOrder = [] # SIT_SOT -- provided by the user n_mit_sot = 0 mit_sot_scan_inputs = [] mit_sot_inner_inputs = [] mit_sot_inner_slices = [] mit_sot_inner_outputs = [] mit_sot_return_steps = OrderedDict() mit_sot_tap_array = [] mit_sot_rightOrder = [] n_sit_sot = 0 sit_sot_scan_inputs = [] sit_sot_inner_inputs = [] sit_sot_inner_slices = [] sit_sot_inner_outputs = [] sit_sot_return_steps = OrderedDict() sit_sot_rightOrder = [] # go through outputs picking up time slices as needed for i, init_out in enumerate(outs_info): # Note that our convention dictates that if an output uses # just the previous time step, as a initial state we will only # provide a tensor of the same dimension as one time step; This # makes code much cleaner for those who do not use taps. 
Otherwise # they would always had to shape_padleft the initial state .. # which is ugly if init_out.get('taps', None) == [-1]: actual_arg = init_out['initial'] if not isinstance(actual_arg, tensor.Variable): actual_arg = tensor.as_tensor_variable(actual_arg) arg = safe_new(actual_arg) if isinstance(arg, tensor.Constant): # safe new returns a clone of the constants, but that is not # what we need for initial states arg = arg.type() # Try to transfer test_value to the new variable if config.compute_test_value != 'off': try: arg.tag.test_value = gof.Op._get_test_value(actual_arg) except AttributeError as e: if config.compute_test_value != 'ignore': # No need to print a warning or raise an error now, # it will be done when fn will be called. _logger.info(('Cannot compute test value for the ' 'inner function of scan, input value missing %s'), e) if getattr(init_out['initial'], 'name', None) is not None: arg.name = init_out['initial'].name + '[t-1]' # We need now to allocate space for storing the output and copy # the initial state over. 
We do this using the expand function # defined in scan utils sit_sot_scan_inputs.append( scan_utils.expand( tensor.unbroadcast( tensor.shape_padleft(actual_arg), 0), actual_n_steps )) sit_sot_inner_slices.append(actual_arg) if i in return_steps: sit_sot_return_steps[n_sit_sot] = return_steps[i] sit_sot_inner_inputs.append(arg) sit_sot_rightOrder.append(i) n_sit_sot += 1 elif init_out.get('taps', None): if numpy.any(numpy.array(init_out.get('taps', [])) > 0): # Make sure we do not have requests for future values of a # sequence we can not provide such values raise ValueError('Can not use future taps of outputs', init_out) # go through the taps mintap = abs(numpy.min(init_out['taps'])) mit_sot_tap_array.append(init_out['taps']) idx_offset = abs(numpy.min(init_out['taps'])) # Sequence mit_sot_scan_inputs.append( scan_utils.expand(init_out['initial'][:mintap], actual_n_steps)) if i in return_steps: mit_sot_return_steps[n_mit_sot] = return_steps[i] mit_sot_rightOrder.append(i) n_mit_sot += 1 for k in init_out['taps']: # create a new slice actual_nw_slice = init_out['initial'][k + mintap] _init_out_var = tensor.as_tensor_variable(init_out['initial']) _init_out_var_slice = _init_out_var[k + mintap] nw_slice = _init_out_var_slice.type() # Try to transfer test_value to the new variable if config.compute_test_value != 'off': try: nw_slice.tag.test_value = gof.Op._get_test_value( _init_out_var_slice) except AttributeError as e: if config.compute_test_value != 'ignore': # No need to print a warning or raise an error now, # it will be done when fn will be called. _logger.info(('Cannot compute test value for ' 'the inner function of scan, input value ' 'missing. 
%s'), e) # give it a name or debugging and pretty printing if getattr(init_out['initial'], 'name', None) is not None: if k > 0: nw_slice.name = (init_out['initial'].name + '[t+%d]' % k) elif k == 0: nw_slice.name = init_out['initial'].name + '[t]' else: nw_slice.name = (init_out['initial'].name + '[t%d]' % k) mit_sot_inner_inputs.append(nw_slice) mit_sot_inner_slices.append(actual_nw_slice) # NOTE: there is another case, in which we do not want to provide # any previous value of the output to the inner function (i.e. # a map); in that case we do not have to do anything .. # Re-order args max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1 max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1 n_elems = numpy.max([max_mit_sot, max_sit_sot]) _ordered_args = [[] for x in xrange(n_elems)] offset = 0 for idx in xrange(n_mit_sot): n_inputs = len(mit_sot_tap_array[idx]) if n_fixed_steps in [1, -1]: _ordered_args[mit_sot_rightOrder[idx]] = \ mit_sot_inner_slices[offset:offset + n_inputs] else: _ordered_args[mit_sot_rightOrder[idx]] = \ mit_sot_inner_inputs[offset:offset + n_inputs] offset += n_inputs for idx in xrange(n_sit_sot): if n_fixed_steps in [1, -1]: _ordered_args[sit_sot_rightOrder[idx]] = \ [sit_sot_inner_slices[idx]] else: _ordered_args[sit_sot_rightOrder[idx]] = \ [sit_sot_inner_inputs[idx]] ordered_args = [] for ls in _ordered_args: ordered_args += ls if n_fixed_steps in [1, -1]: args = (inner_slices + ordered_args + non_seqs) else: args = (inner_seqs + ordered_args + non_seqs) # add only the non-shared variables and non-constants to the arguments of # the dummy function [ a function should not get shared variables or # constants as input ] dummy_args = [arg for arg in args if (not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant))] # when we apply the lambda expression we get a mixture of update rules # and outputs that needs to be separated condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args)) if condition is 
not None: as_while = True else: as_while = False ## # Step 3. Check if we actually need scan and remove it if we don't ## if n_fixed_steps in [1, -1]: # We do not need to use the scan op anymore, so we can just return # the outputs and updates we have if condition is not None: _logger.warning(('When the number of steps is fixed and equal ' 'to 1, the provided stopping condition, ', str(condition), ' is ignored')) for pos, inner_out in enumerate(outputs): # we need to see if we need to pad our sequences with an # unbroadcastable dimension; case example : we return an # output for which we want all intermediate. If n_steps is 1 # then, if we return the output as given by the innner function # this will represent only a slice and it will have one # dimension less. if (isinstance(inner_out.type, tensor.TensorType) and return_steps.get(pos, 0) != 1): outputs[pos] = tensor.unbroadcast( tensor.shape_padleft(inner_out), 0) if len(outputs) == 1: outputs = outputs[0] return (outputs, updates) ## # Step 4. 
Compile the dummy function ## # We can now compile a dummy function just to see what shared variable # we have and what are their update rules (note that the user has # the option not to pass the shared variable to scan, so we need to # pick them manually and add them to scan) # make the compilation as fast as possible by not applying any # optimization or conversion to C [ note this region is not important # for performance so we can do stuff as unoptimal as we wish ] # extract still missing inputs (there still might be so) and add them # as non sequences at the end of our args fake_nonseqs = [x.type() for x in non_seqs] fake_outputs = scan_utils.clone(outputs, replace=OrderedDict(zip(non_seqs, fake_nonseqs))) all_inputs = itertools.ifilter( lambda x: (isinstance(x, gof.Variable) and not isinstance(x, SharedVariable) and not isinstance(x, gof.Constant)), gof.graph.inputs(fake_outputs)) extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs] non_seqs += extra_inputs # Note we do not use all_inputs directly since the order of variables # in args is quite important dummy_args += extra_inputs dummy_outs = outputs if condition is not None: dummy_outs.append(condition) dummy_f = function(dummy_args, dummy_outs, updates=updates, mode=compile.mode.Mode(linker='py', optimizer=None), on_unused_input='ignore', profile=False) ## # Step 5. Re-arange inputs of scan into a more strict order ## # Step 5.0 Check the outputs of the dummy function to see if they # match with user provided data # if the number of outputs to the function does not match the number of # assumed outputs until now (provided by the user) there can be # only one explanation: No information is provided for any of the # outputs (i.e. 
we are dealing with a map) tmp_dummy_f_outs = len(dummy_f.maker.outputs) if as_while: tmp_dummy_f_outs -= 1 if not (tmp_dummy_f_outs == n_outs or outs_info == []): raise ValueError('Please provide None as outputs_info for ' 'any output that does not feed back into ' 'scan (i.e. it behaves like a map) ') if outs_info == []: n_outs = len(dummy_f.maker.outputs) if as_while: n_outs = n_outs - 1 outs_info = [OrderedDict() for x in xrange(n_outs)] # Step 5.1 Outputs with taps different then -1 for i, out in enumerate(outs_info): if 'taps' in out and out['taps'] != [-1]: mit_sot_inner_outputs.append(outputs[i]) # Step 5.2 Outputs with tap equal to -1 for i, out in enumerate(outs_info): if 'taps' in out and out['taps'] == [-1]: sit_sot_inner_outputs.append(outputs[i]) # Step 5.3 Outputs that correspond to update rules of shared variables givens = OrderedDict() n_shared_outs = 0 shared_scan_inputs = [] shared_inner_inputs = [] shared_inner_outputs = [] sit_sot_shared = [] for input in dummy_f.maker.expanded_inputs: if isinstance(input.variable, SharedVariable) and input.update: new_var = safe_new(input.variable) if getattr(input.variable, 'name', None) is not None: new_var.name = input.variable.name + '_copy' if isinstance(new_var.type, ops.expandable_types): sit_sot_inner_inputs.append(new_var) sit_sot_scan_inputs.append( scan_utils.expand( tensor.unbroadcast( tensor.shape_padleft(input.variable), 0), actual_n_steps)) tensor_update = tensor.as_tensor_variable(input.update) sit_sot_inner_outputs.append(tensor_update) # Not that pos is not a negative index. The sign of pos is used # as a flag to indicate if this output should be part of the # update rules or part of the standard outputs of scan. # If `pos` is positive than it corresponds to the standard # outputs of scan and it refers to output of index `pos`. If `pos` # is negative that it corresponds to update rules of scan and it # refers to update rule of index -1 - `pos`. 
sit_sot_rightOrder.append(-1 - len(sit_sot_shared)) sit_sot_shared.append(input.variable) givens[input.variable] = new_var else: shared_inner_inputs.append(new_var) shared_scan_inputs.append(input.variable) shared_inner_outputs.append(input.update) givens[input.variable] = new_var n_shared_outs += 1 n_sit_sot = len(sit_sot_inner_inputs) # Step 5.4 Outputs with no taps used in the input n_nit_sot = 0 nit_sot_inner_outputs = [] nit_sot_return_steps = OrderedDict() nit_sot_rightOrder = [] for i, out in enumerate(outs_info): if not 'taps' in out: nit_sot_inner_outputs.append(outputs[i]) if i in return_steps: nit_sot_return_steps[n_nit_sot] = return_steps[i] nit_sot_rightOrder.append(i) n_nit_sot += 1 # Step 5.5 all other arguments including extra inputs other_scan_args = [] other_inner_args = [] other_scan_args += [arg for arg in non_seqs if (not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant))] # Step 5.6 all shared variables with no update rules other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs if (not isinstance(arg, SharedVariable) and not isinstance(arg, tensor.Constant))] givens.update(OrderedDict(zip(other_scan_args, other_inner_args))) if strict: non_seqs_set = set(non_sequences if non_sequences != None else []) other_shared_scan_args = [arg.variable for arg in dummy_f.maker.expanded_inputs if (isinstance(arg.variable, SharedVariable) and not arg.update and arg.variable in non_seqs_set)] other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg in dummy_f.maker.expanded_inputs if (isinstance(arg.variable, SharedVariable) and not arg.update and arg.variable in non_seqs_set)] else: other_shared_scan_args = [arg.variable for arg in dummy_f.maker.expanded_inputs if (isinstance(arg.variable, SharedVariable) and not arg.update)] other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg in dummy_f.maker.expanded_inputs if (isinstance(arg.variable, SharedVariable) and not arg.update)] 
givens.update(OrderedDict(zip(other_shared_scan_args, other_shared_inner_args))) ## # Step 6. Re-order the outputs and clone them replacing things # using the givens ## inner_inputs = (inner_seqs + mit_mot_inner_inputs + mit_sot_inner_inputs + sit_sot_inner_inputs + shared_inner_inputs + other_shared_inner_args + other_inner_args) inner_outs = (mit_mot_inner_outputs + mit_sot_inner_outputs + sit_sot_inner_outputs + nit_sot_inner_outputs + shared_inner_outputs) if condition is not None: inner_outs.append(condition) # Cuda is imported here, instead of being imported on top of the file # because forces on the user some dependencies that we might do not want # to. Currently we are working on removing the dependencies on sandbox # code completeley. from theano.sandbox import cuda if cuda.cuda_available: # very often we end up in this situation when we want to # replace w with w_copy, where w is CudaNdarray # and w_copy is TensorType. This is caused because shared # variables are put on GPU right aways >:| , new_givens = OrderedDict() for w, w_copy in givens.iteritems(): if (isinstance(w.type, cuda.CudaNdarrayType) and isinstance(w_copy.type, tensor.TensorType)): for o in inner_outs: new_givens = traverse(o, w, w_copy, new_givens) else: new_givens[w] = w_copy else: new_givens = givens new_outs = scan_utils.clone(inner_outs, replace=new_givens) ## # Step 7. 
Create the Scan Op ## tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)] if allow_gc is None: allow_gc = config.scan.allow_gc info = OrderedDict() info['tap_array'] = tap_array info['n_seqs'] = n_seqs info['n_mit_mot'] = n_mit_mot info['n_mit_mot_outs'] = n_mit_mot_outs info['mit_mot_out_slices'] = mit_mot_out_slices info['n_mit_sot'] = n_mit_sot info['n_sit_sot'] = n_sit_sot info['n_shared_outs'] = n_shared_outs info['n_nit_sot'] = n_nit_sot info['truncate_gradient'] = truncate_gradient info['name'] = name info['mode'] = mode info['destroy_map'] = OrderedDict() info['gpu'] = False info['as_while'] = as_while info['profile'] = profile info['allow_gc'] = allow_gc info['strict'] = strict if strict: warnings.warn('In the strict mode, all neccessary shared variables ' 'must be passed as a part of non_sequences', Warning) local_op = scan_op.Scan(inner_inputs, new_outs, info) ## # Step 8. Compute the outputs using the scan op ## _scan_inputs = (scan_seqs + mit_mot_scan_inputs + mit_sot_scan_inputs + sit_sot_scan_inputs + shared_scan_inputs + [actual_n_steps for x in xrange(n_nit_sot)] + other_shared_scan_args + other_scan_args) scan_inputs = [] for arg in [actual_n_steps] + _scan_inputs: try: arg = tensor.as_tensor_variable(arg) except TypeError: # This happens for Random States for e.g. but it is a good way # to make sure no input is a cuda ndarrays pass scan_inputs += [arg] scan_outs = local_op(*scan_inputs) if type(scan_outs) not in (list, tuple): scan_outs = [scan_outs] ## # Step 9. Figure out which outs are update rules for shared variables # and so on ... 
## update_map = OrderedUpdates() def remove_dimensions(outs, steps_return, offsets=None): out_ls = [] for idx, out in enumerate(outs): if idx in steps_return: if steps_return[idx] > 1: out_ls.append(out[-steps_return[idx]:]) else: out_ls.append(out[-1]) else: if offsets is None: out_ls.append(out) else: out_ls.append(out[offsets[idx]:]) return out_ls offset = n_mit_mot offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array] mit_sot_outs = remove_dimensions( scan_outs[offset:offset + n_mit_sot], mit_sot_return_steps, offsets) offset += n_mit_sot offsets = [1 for x in xrange(n_sit_sot)] sit_sot_outs = remove_dimensions( scan_outs[offset:offset + n_sit_sot], sit_sot_return_steps, offsets) offset += n_sit_sot nit_sot_outs = remove_dimensions( scan_outs[offset:offset + n_nit_sot], nit_sot_return_steps) offset += n_nit_sot for idx, update_rule in enumerate( scan_outs[offset:offset + n_shared_outs]): update_map[shared_scan_inputs[idx]] = update_rule _scan_out_list = (mit_sot_outs + sit_sot_outs + nit_sot_outs) # Step 10. I need to reorder the outputs to be in the order expected by # the user rightOrder = (mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder) scan_out_list = [None] * len(rightOrder) for idx, pos in enumerate(rightOrder): if pos >= 0: scan_out_list[pos] = _scan_out_list[idx] else: # Not that pos is not a negative index. The sign of pos is used # as a flag to indicate if this output should be part of the # update rules or part of the standard outputs of scan. # If `pos` is positive than it corresponds to the standard # outputs of scan and it refers to output of index `pos`. If `pos` # is negative that it corresponds to update rules of scan and it # refers to update rule of index -1 - `pos`. 
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1] scan_out_list = [x for x in scan_out_list if x is not None] if len(scan_out_list) == 1: scan_out_list = scan_out_list[0] elif len(scan_out_list) == 0: scan_out_list = None return (scan_out_list, update_map)
def orderings(self, fgraph):
    """
    Return orderings induced by destructive operations.

    Returns an OrderedDict mapping each destructive Apply node to the
    set of other Apply nodes that must be computed before it (i.e. all
    other clients of any variable aliased to the input it destroys).

    Raise InconsistencyError when

    a) attempting to destroy indestructable variable, or
    b) attempting to destroy a value multiple times, or
    c) an Apply destroys (illegally) one of its own inputs by aliasing

    """
    rval = OrderedDict()
    if self.destroyers:
        # BUILD DATA STRUCTURES
        # CHECK for multiple destructions during construction of variables.
        # droot maps each variable to its view-chain root ("foundation");
        # impact maps a root to all variables viewing on it.
        droot, impact, __ignore = self.refresh_droot_impact()

        # check for destruction of constants / indestructible variables
        illegal_destroy = [r for r in droot if
                           getattr(r.tag, 'indestructible', False) or
                           isinstance(r, graph.Constant)]
        if illegal_destroy:
            raise InconsistencyError(
                "Attempting to destroy indestructible variables: %s" %
                illegal_destroy)

        # add destroyed variable clients as computational dependencies
        for app in self.destroyers:
            # for each destroyed input...
            # NOTE: only input_idx_list[0] is used; the destroy_map
            # convention here is one destroyed input per output entry.
            for output_idx, input_idx_list in iteritems(app.op.destroy_map):
                destroyed_idx = input_idx_list[0]
                destroyed_variable = app.inputs[destroyed_idx]
                root = droot[destroyed_variable]
                root_impact = impact[root]
                # we generally want to put all clients of things which depend
                # on root as pre-requisites of app.
                # But, app is itself one such client!
                # App will always be a client of the node we're destroying
                # (destroyed_variable), but the tricky thing is when it is
                # also a client of *another variable* viewing on the root.
                # Generally this is illegal, (e.g., add_inplace(x, x.T)).
                # In some special cases though, the in-place op will actually
                # be able to work properly with multiple destroyed inputs
                # (e.g, add_inplace(x, x)). An Op that can still work in this
                # case should declare so via the
                # 'destroyhandler_tolerate_same' attribute or
                # 'destroyhandler_tolerate_aliased' attribute.
                #
                # destroyhandler_tolerate_same should be a list of pairs of
                # the form [(idx0, idx1), (idx0, idx2), ...]
                # The first element of each pair is the input index of a
                # destroyed variable.
                # The second element of each pair is the index of a different
                # input where we will permit exactly the same variable to
                # appear.
                # For example, add_inplace.tolerate_same might be [(0,1)] if
                # the destroyed input is also allowed to appear as the second
                # argument.
                #
                # destroyhandler_tolerate_aliased is the same sort of list of
                # pairs.
                # op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells
                # the destroyhandler to IGNORE an aliasing between a destroyed
                # input idx0 and another input idx1.
                # This is generally a bad idea, but it is safe in some
                # cases, such as
                # - the op reads from the aliased idx1 before modifying idx0
                # - the idx0 and idx1 are guaranteed not to overlap (e.g.
                #   they are pointed at different rows of a matrix).
                #

                # CHECK FOR INPUT ALIASING
                # OPT: pre-compute this on import
                tolerate_same = getattr(app.op,
                                        'destroyhandler_tolerate_same', [])
                assert isinstance(tolerate_same, list)
                # input indices allowed to hold the *same* variable as the
                # destroyed input (the destroyed index itself is trivially
                # tolerated)
                tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
                                       if idx0 == destroyed_idx)
                tolerated.add(destroyed_idx)
                tolerate_aliased = getattr(
                    app.op, 'destroyhandler_tolerate_aliased', [])
                assert isinstance(tolerate_aliased, list)
                # input indices whose aliasing with the destroyed input is
                # ignored entirely
                ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
                                     if idx0 == destroyed_idx)
                for i, input in enumerate(app.inputs):
                    if i in ignored:
                        continue
                    if input in root_impact \
                            and (i not in tolerated or
                                 input is not destroyed_variable):
                        raise InconsistencyError("Input aliasing: %s (%i, %i)"
                                                 % (app, destroyed_idx, i))

                # add the rule: app must be preceded by all other Apply
                # instances that depend on destroyed_input
                root_clients = OrderedSet()
                for r in root_impact:
                    # every recorded client entry must have a non-empty
                    # client list
                    assert not [a for a, c in self.clients[r].items()
                                if not c]
                    root_clients.update([a for a, c in self.clients[r].items()
                                         if c])
                # app itself is always among the clients; it must not be its
                # own prerequisite
                root_clients.remove(app)
                if root_clients:
                    rval[app] = root_clients

    return rval
class DestroyHandler(toolbox.Bookkeeper):  # noqa
    """
    The DestroyHandler class detects when a graph is impossible to evaluate
    because of aliasing and destructive operations.

    Several data structures are used to do this.

    An Op can use its view_map property to declare that an output may be
    aliased to an input. If that output is destroyed, the input is also
    considered to be destroyed. The view_maps of several Ops can feed into
    one another and form a directed graph. The consequence of destroying any
    variable in such a graph is that all variables in the graph must be
    considered to be destroyed, because they could all be refering to the
    same underlying storage.

    In the current implementation, that graph is a tree, and the root of that
    tree is called the foundation.

    TODO: why "in the current implementation" ? is there another
    implementation planned?
    TODO: why is the graph a tree? isn't it possible that one variable could
    be aliased to many variables? for example, don't switch and ifelse have to
    do this?

    The original DestroyHandler (if 0'ed out above) computed several data
    structures from scratch each time it was asked to validate the graph.
    Because this happens potentially thousands of times and each graph to
    validate is extremely similar to the previous one, computing the data
    structures from scratch repeatedly was wasteful and resulted in high
    compile times for large graphs.

    This implementation computes the data structures once at initialization
    and then incrementally updates them.

    It is a work in progress. The following data structures have been
    converted to use the incremental strategy:
        <none>

    The following data structures remain to be converted:
        <unknown>
    """

    pickle_rm_attr = ["destroyers"]

    def __init__(self, do_imports_on_attach=True):
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach

        # Maps every variable in the graph to its "foundation" (deepest
        # ancestor in view chain).
        # TODO: change name to var_to_vroot.
        self.droot = OrderedDict()

        # Maps a variable to all variables that are indirect or direct views
        # of it (including itself) -- essentially the inverse of droot.
        # TODO: do all variables appear in this dict, or only those that are
        #       foundations?
        # TODO: do only destroyed variables go in here? one old docstring
        #       said so.
        # TODO: rename to x_to_views after reverse engineering what x is
        self.impact = OrderedDict()

        # If a var is destroyed, then this dict will map droot[var] to the
        # apply node that destroyed var.
        # TODO: rename to vroot_to_destroyer
        self.root_destroyer = OrderedDict()

    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that
            1) This DestroyHandler wasn't already attached to some fgraph
               (its data structures are only set up to serve one).
            2) The FunctionGraph doesn't already have a DestroyHandler.
               This would result in it validating everything twice, causing
               compilation to be slower.

        Give the FunctionGraph instance:
            1) A new method "destroyers(var)"
               TODO: what does this do exactly?
            2) A new attribute, "destroy_handler"
        TODO: WRITEME: what does this do besides the checks?
        """
        # Do the checking #
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception(
                "A DestroyHandler instance can only serve one"
                " FunctionGraph. (Matthew 6:24)")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere
            # and cancels the attachment
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present"
                " or in conflict with another plugin.")

        # Annotate the FunctionGraph #
        self.unpickle(fgraph)
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        # set of Apply instances with non-null destroy_map
        self.destroyers = OrderedSet()
        self.view_i = OrderedDict()  # variable -> variable used in calculation
        # variable -> set of variables that use this one as a direct input
        self.view_o = OrderedDict()
        # clients: how many times does an apply use a given variable
        self.clients = OrderedDict()  # variable -> apply -> ninputs
        self.stale_droot = True

        # all Apply nodes ever imported; used only for protocol checking
        self.debug_all_apps = OrderedSet()
        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)

    def unpickle(self, fgraph):
        # (Re)install the fgraph.destroyers callable; also used on attach.
        def get_destroyers_of(r):
            droot, impact, root_destroyer = self.refresh_droot_impact()
            try:
                return [root_destroyer[droot[r]]]
            except Exception:
                return []
        fgraph.destroyers = get_destroyers_of

    def refresh_droot_impact(self):
        """
        Makes sure self.droot, self.impact, and self.root_destroyer are up to
        date, and returns them (see docstrings for these properties above).
        """
        if self.stale_droot:
            self.droot, self.impact, self.root_destroyer =\
                _build_droot_impact(self)
            self.stale_droot = False
        return self.droot, self.impact, self.root_destroyer

    def on_detach(self, fgraph):
        if fgraph is not self.fgraph:
            raise Exception("detaching wrong fgraph", fgraph)
        del self.destroyers
        del self.view_i
        del self.view_o
        del self.clients
        del self.stale_droot
        # BUGFIX: this previously read 'self.fgraph.destroyer_handler',
        # an attribute that is never assigned (on_attach stores the feature
        # under 'destroy_handler'), so detaching always raised
        # AttributeError instead of performing the sanity check.
        assert self.fgraph.destroy_handler is self
        delattr(self.fgraph, 'destroyers')
        delattr(self.fgraph, 'destroy_handler')
        self.fgraph = None

    def on_import(self, fgraph, app, reason):
        """
        Add Apply instance to set which must be computed.
        """
        if app in self.debug_all_apps:
            raise ProtocolError("double import")
        self.debug_all_apps.add(app)
        # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)

        # If it's a destructive op, add it to our watch list
        if getattr(app.op, 'destroy_map', {}):
            self.destroyers.add(app)

        # add this symbol to the forward and backward maps
        for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
            if len(i_idx_list) > 1:
                raise NotImplementedError(
                    'destroying this output invalidates multiple inputs',
                    (app.op))
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            self.view_i[o] = i
            self.view_o.setdefault(i, OrderedSet()).add(o)

        # update self.clients
        for i, input in enumerate(app.inputs):
            self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
            self.clients[input][app] += 1

        # make sure every output has a (possibly empty) client dict
        for i, output in enumerate(app.outputs):
            self.clients.setdefault(output, OrderedDict())

        self.stale_droot = True

    def on_prune(self, fgraph, app, reason):
        """
        Remove Apply instance from set which must be computed.
        """
        if app not in self.debug_all_apps:
            raise ProtocolError("prune without import")
        self.debug_all_apps.remove(app)

        # UPDATE self.clients
        # (OrderedSet de-duplicates: an input used twice has a single
        # client entry, deleted once)
        for i, input in enumerate(OrderedSet(app.inputs)):
            del self.clients[input][app]

        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.remove(app)

        # Note: leaving empty client dictionaries in the struct.
        # Why? It's a pain to remove them. I think they aren't doing any
        # harm, they will be deleted on_detach().

        # UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map',
                                                   OrderedDict())):
            if len(i_idx_list) > 1:
                # destroying this output invalidates multiple inputs
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]

            del self.view_i[o]

            self.view_o[i].remove(o)
            if not self.view_o[i]:
                del self.view_o[i]

        self.stale_droot = True

    def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
        """
        app.inputs[i] changed from old_r to new_r.
        """
        if app == 'output':
            # app == 'output' is special key that means FunctionGraph is
            # redefining which nodes are being considered 'outputs' of the
            # graph.
            pass
        else:
            if app not in self.debug_all_apps:
                raise ProtocolError("change without import")

            # UPDATE self.clients
            self.clients[old_r][app] -= 1
            if self.clients[old_r][app] == 0:
                del self.clients[old_r][app]

            self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0)
            self.clients[new_r][app] += 1

            # UPDATE self.view_i, self.view_o
            for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map',
                                                       OrderedDict())):
                if len(i_idx_list) > 1:
                    # destroying this output invalidates multiple inputs
                    raise NotImplementedError()
                i_idx = i_idx_list[0]
                output = app.outputs[o_idx]
                if i_idx == i:
                    if app.inputs[i_idx] is not new_r:
                        raise ProtocolError("wrong new_r on change")

                    self.view_i[output] = new_r

                    self.view_o[old_r].remove(output)
                    if not self.view_o[old_r]:
                        del self.view_o[old_r]

                    self.view_o.setdefault(new_r, OrderedSet()).add(output)

        self.stale_droot = True

    def validate(self, fgraph):
        """
        Return None.

        Raise InconsistencyError when
        a) orderings() raises an error
        b) orderings cannot be topologically sorted.
        """
        if self.destroyers:
            ords = self.orderings(fgraph)

            if _contains_cycle(fgraph, ords):
                raise InconsistencyError("Dependency graph contains cycles")
        else:
            # James's Conjecture:
            # If there are no destructive ops, then there can be no cycles.

            # FB: This isn't always True. It can happend that
            # optimization introduce node that depend on itself. This
            # is very rare and should not happen in general. It will be
            # caught later. The error will be far from the source. But
            # doing this conjecture should speed up compilation most of
            # the time. The user should create such dependency except
            # if he mess too much with the internal.
            pass
        return True

    def orderings(self, fgraph):
        """
        Return orderings induced by destructive operations.

        Raise InconsistencyError when
        a) attempting to destroy indestructable variable, or
        b) attempting to destroy a value multiple times, or
        c) an Apply destroys (illegally) one of its own inputs by aliasing

        """
        rval = OrderedDict()

        if self.destroyers:
            # BUILD DATA STRUCTURES
            # CHECK for multiple destructions during construction of
            # variables
            droot, impact, __ignore = self.refresh_droot_impact()

            # check for destruction of constants
            illegal_destroy = [r for r in droot if
                               getattr(r.tag, 'indestructible', False) or
                               isinstance(r, graph.Constant)]
            if illegal_destroy:
                raise InconsistencyError(
                    "Attempting to destroy indestructible variables: %s" %
                    illegal_destroy)

            # add destroyed variable clients as computational dependencies
            for app in self.destroyers:
                # for each destroyed input...
                for output_idx, input_idx_list in \
                        iteritems(app.op.destroy_map):
                    destroyed_idx = input_idx_list[0]
                    destroyed_variable = app.inputs[destroyed_idx]
                    root = droot[destroyed_variable]
                    root_impact = impact[root]
                    # we generally want to put all clients of things which
                    # depend on root as pre-requisites of app.
                    # But, app is itself one such client!
                    # App will always be a client of the node we're
                    # destroying (destroyed_variable, but the tricky thing is
                    # when it is also a client of *another variable* viewing
                    # on the root.  Generally this is illegal, (e.g.,
                    # add_inplace(x, x.T).  In some special cases though, the
                    # in-place op will actually be able to work properly with
                    # multiple destroyed inputs (e.g, add_inplace(x, x).  An
                    # Op that can still work in this case should declare so
                    # via the 'destroyhandler_tolerate_same' attribute or
                    # 'destroyhandler_tolerate_aliased' attribute.
                    #
                    # destroyhandler_tolerate_same should be a list of pairs
                    # of the form [(idx0, idx1), (idx0, idx2), ...]
                    # The first element of each pair is the input index of a
                    # destroyed variable.
                    # The second element of each pair is the index of a
                    # different input where we will permit exactly the same
                    # variable to appear.
                    # For example, add_inplace.tolerate_same might be [(0,1)]
                    # if the destroyed input is also allowed to appear as the
                    # second argument.
                    #
                    # destroyhandler_tolerate_aliased is the same sort of
                    # list of pairs.
                    # op.destroyhandler_tolerate_aliased = [(idx0, idx1)]
                    # tells the destroyhandler to IGNORE an aliasing between
                    # a destroyed input idx0 and another input idx1.
                    # This is generally a bad idea, but it is safe in some
                    # cases, such as
                    # - the op reads from the aliased idx1 before modifying
                    #   idx0
                    # - the idx0 and idx1 are guaranteed not to overlap (e.g.
                    #   they are pointed at different rows of a matrix).
                    #

                    # CHECK FOR INPUT ALIASING
                    # OPT: pre-compute this on import
                    tolerate_same = getattr(
                        app.op, 'destroyhandler_tolerate_same', [])
                    assert isinstance(tolerate_same, list)
                    tolerated = OrderedSet(idx1 for idx0, idx1
                                           in tolerate_same
                                           if idx0 == destroyed_idx)
                    tolerated.add(destroyed_idx)
                    tolerate_aliased = getattr(
                        app.op, 'destroyhandler_tolerate_aliased', [])
                    assert isinstance(tolerate_aliased, list)
                    ignored = OrderedSet(idx1 for idx0, idx1
                                         in tolerate_aliased
                                         if idx0 == destroyed_idx)
                    for i, input in enumerate(app.inputs):
                        if i in ignored:
                            continue
                        if input in root_impact \
                                and (i not in tolerated or
                                     input is not destroyed_variable):
                            raise InconsistencyError(
                                "Input aliasing: %s (%i, %i)"
                                % (app, destroyed_idx, i))

                    # add the rule: app must be preceded by all other Apply
                    # instances that depend on destroyed_input
                    root_clients = OrderedSet()
                    for r in root_impact:
                        assert not [a for a, c in self.clients[r].items()
                                    if not c]
                        root_clients.update(
                            [a for a, c in self.clients[r].items() if c])

                    # app is itself a client of the destroyed variable;
                    # it must not become its own prerequisite
                    root_clients.remove(app)
                    if root_clients:
                        rval[app] = root_clients

        return rval
def scan(fn,
         sequences=None,
         states=None,
         params=None,
         n_steps=None,
         mode=None,
         name=None,
         profile=False,
         allow_gc=None):
    """
    Similar to Theano's official scan, this function gives the user more
    control over the scan op, avoiding certain difficulties that arose from
    missing optimizations.

    :param fn: lambda function that describes one step of scan (see the
        official Theano scan function)
    :param sequences: similar to the official Theano's scan. This version
        of scan does not support taps for the sequences (it can only be a
        list of tensor). Scan assumes that sequences have the right length
        and it does not check for this.
    :param states: similar to outputs_info of the official scan function.
        There is one crucial difference though, namely that the `initial`
        key in the dictionary has been replaced by the 'membuf' key. This
        reflects the change of meaning. Instead of passing to scan just
        the initial steps missing, one has now to pass a memory buffer in
        which scan will try to store its output. In this memory buffer the
        first entries should be set to the initial states of the
        corresponding states. Providing a memory buffer that has less
        entries than the number of steps means scan will only use that
        amount of memory. The user has to match the memory buffer size
        with the number of steps, otherwise scan will produce wrong
        results. Also if gradients are to be computed through the scan,
        the memory buffer should have the same length as the number of
        steps. For states that do not require an initial state, one has to
        provide a dictionary with a single key 'steps' that says how many
        intermediate results to store. See examples below for more
        insight.
    :param n_steps: This parameter is mandatory and it will represent the
        number of steps scan will do (scan will not check sequences or any
        other source of information to figure out how many steps it needs
        to do).
    :param mode: Same as for the official scan
    :param name: Same as for the official scan
    :param profile: Same as for the official scan

    Note:
     - there is no truncate / go_backwards anymore !
     - the outputs returned by scan contain the initial states as well
       (i.e. if I loop over k steps, with my smallest tap for an output
       -3 and keep al intermediate results, my output will be of length
       k+3)

    Examples:
        (a) if you do not want to store any intermediate results (just the
        last one)

        # The memory buffer can be the initial state, just that we need to
        # add one extra dimension in front of it
        state = TT.unbroadcast(TT.shape_padleft(x0), 0)
        out, _ = scan(lambda x: x + 1, states=state, n_steps=5)
        # Once we got our result we need to remove the extra dimension
        out = out[0]

        (b) if you want to keep every intermediate result

        state = TT.alloc(TT.constant(0), 6, x0.shape[0])
        state = TT.set_subtensor(state[0], x0)
        out, _ = scan(lambda x: x + 1, states=state, n_steps=5)
        out = out[1:]

    :return: a tuple ``(outputs, update_map)``; ``outputs`` follows the
        order of ``states`` (a single variable when there is exactly one
        output, None when there are none), and ``update_map`` holds the
        update rules of shared variables used by ``fn``.
    """
    def wrap_into_list(x):
        """
        Wrap the input into a list if it is not already a list.
        """
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(states)
    if allow_gc is None:
        allow_gc = config.scan.allow_gc

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(params):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, "dtype") and
            str(n_steps.dtype)[:3] not in ("uin", "int")):
        raise ValueError(" n_steps must be an int. dtype provided "
                         "is %s" % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = dict(membuf=outs_info[i], taps=[-1])
            elif (not outs_info[i].get("membuf", None) and
                    outs_info[i].get("taps", None)):
                # ^ no initial state but taps provided
                raise ValueError(
                    ("If you are using slices of an output "
                     "you need to provide a memory buffer for "
                     "the state "),
                    outs_info[i],
                )
            elif (outs_info[i].get("membuf", None) and
                    not outs_info[i].get("taps", None)):
                # ^ initial state but taps not provided
                if "taps" in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        "Output %s (index %d) has a memory "
                        "buffer but taps is explicitly set to None ",
                        getattr(outs_info[i]["membuf"], "name", "None"),
                        i,
                    )
                outs_info[i]["taps"] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an dict(steps=n_steps) to simplify handling
            outs_info[i] = dict(steps=n_steps)

    ##
    # Step 2. Generate inputs and outputs of the inner functions
    # for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we used this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []  # Variables passed as inputs to the scan op
    inner_seqs = []  # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        if isinstance(seq, dict):
            seq = seq["input"]
        actual_slice = seq[0]
        _seq_val = tensor.as_tensor_variable(seq)
        _seq_val_slice = _seq_val[0]
        nw_slice = _seq_val_slice.type()

        # Try to transfer test_value to the new variable
        if config.compute_test_value != "off":
            try:
                nw_slice.tag.test_value = gof.Op._get_test_value(
                    _seq_val_slice)
            except AttributeError as e:
                if config.compute_test_value != "ignore":
                    # No need to print a warning or raise an error now,
                    # it will be done when fn will be called.
                    _logger.info(
                        ("Cannot compute test value for "
                         "the inner function of scan, input value "
                         "missing %s"),
                        e
                    )

        if seq.name:
            nw_slice.name = seq.name + "[t]"
        scan_seqs.append(_seq_val)
        inner_seqs.append(nw_slice)
        inner_slices.append(actual_slice)

        n_seqs += 1

    actual_n_steps = tensor.as_tensor(n_steps)

    # Conventions :
    # mit_mot = multiple input taps, multiple output taps ( only provided
    #           by the gradient function )
    # mit_sot = multiple input taps, single output tap (t + 0)
    # sit_sot = single input tap, single output tap (t + 0)
    # nit_sot = no input tap, single output tap (t + 0)

    # MIT_MOT -- not provided by the user only by the grad function
    n_mit_mot = 0
    n_mit_mot_outs = 0
    mit_mot_scan_inputs = []
    mit_mot_inner_inputs = []
    mit_mot_inner_outputs = []
    mit_mot_out_slices = []
    mit_mot_rightOrder = []

    # SIT_SOT -- provided by the user
    n_mit_sot = 0
    mit_sot_scan_inputs = []
    mit_sot_inner_inputs = []
    mit_sot_inner_slices = []
    mit_sot_inner_outputs = []
    mit_sot_return_steps = OrderedDict()
    mit_sot_tap_array = []
    mit_sot_rightOrder = []

    n_sit_sot = 0
    sit_sot_scan_inputs = []
    sit_sot_inner_inputs = []
    sit_sot_inner_slices = []
    sit_sot_inner_outputs = []
    sit_sot_return_steps = OrderedDict()
    sit_sot_rightOrder = []
    nit_sot_steps = []
    # go through outputs picking up time slices as needed
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as a initial state we will only
        # provide a tensor of the same dimension as one time step; This
        # makes code much cleaner for those who do not use taps. Otherwise
        # they would always had to shape_padleft the initial state ..
        # which is ugly

        # Note, 'taps' might not be in the dictionary
        if "taps" in init_out and init_out["taps"] == [-1]:
            actual_arg = init_out["membuf"]
            arg = safe_new(init_out["membuf"][0])
            if isinstance(arg, tensor.Constant):
                # safe new returns a clone of the constants, but that is not
                # what we need for initial states
                arg = arg.type()

            # Try to transfer test_value to the new variable
            if config.compute_test_value != "off":
                try:
                    arg.tag.test_value = gof.Op._get_test_value(actual_arg)
                except AttributeError as e:
                    if config.compute_test_value != "ignore":
                        # No need to print a warning or raise an error now,
                        # it will be done when fn will be called.
                        _logger.info(
                            ("Cannot compute test value for the "
                             "inner function of scan, input value "
                             "missing %s"),
                            e
                        )

            if getattr(init_out["membuf"], "name", None) is not None:
                arg.name = init_out["membuf"].name + "[t-1]"

            # We need now to allocate space for storing the output and copy
            # the initial state over. We do this using the expand function
            # defined in scan utils
            sit_sot_scan_inputs.append(actual_arg)
            sit_sot_inner_slices.append(actual_arg[0])
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1
        elif init_out.get("taps", None):
            if numpy.any(numpy.array(init_out.get("taps", [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence we can not provide such values
                raise ValueError("Can not use future taps of outputs",
                                 init_out)
            # go through the taps
            mintap = abs(numpy.min(init_out["taps"]))
            mit_sot_tap_array.append(init_out["taps"])
            idx_offset = abs(numpy.min(init_out["taps"]))
            # Sequence
            mit_sot_scan_inputs.append(init_out["membuf"])

            if i in return_steps:
                mit_sot_return_steps[n_mit_sot] = return_steps[i]
            mit_sot_rightOrder.append(i)
            n_mit_sot += 1
            for k in init_out["taps"]:
                # create a new slice
                actual_nw_slice = init_out["membuf"][k + mintap]
                _init_out_var = tensor.as_tensor_variable(
                    init_out["membuf"])
                _init_out_var_slice = _init_out_var[k + mintap]
                nw_slice = _init_out_var_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != "off":
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _init_out_var_slice)
                    except AttributeError as e:
                        if config.compute_test_value != "ignore":
                            # No need to print a warning or raise an error
                            # now, it will be done when fn will be called.
                            _logger.info(
                                (
                                    "Cannot compute test value for "
                                    "the inner function of scan, input "
                                    "value missing. %s"
                                ),
                                e,
                            )

                # give it a name or debugging and pretty printing
                if getattr(init_out["membuf"], "name", None) is not None:
                    if k > 0:
                        nw_slice.name = (init_out["membuf"].name +
                                         "[t+%d]" % k)
                    elif k == 0:
                        nw_slice.name = init_out["membuf"].name + "[t]"
                    else:
                        nw_slice.name = (init_out["membuf"].name +
                                         "[t%d]" % k)
                mit_sot_inner_inputs.append(nw_slice)
                mit_sot_inner_slices.append(actual_nw_slice)
        else:
            # nit_sot case (no taps): handled in Step 5.4 below
            pass

    # Re-order args
    max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
    max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
    n_elems = numpy.max([max_mit_sot, max_sit_sot])
    _ordered_args = [[] for x in xrange(n_elems)]
    offset = 0
    for idx in xrange(n_mit_sot):
        n_inputs = len(mit_sot_tap_array[idx])
        if n_fixed_steps == 1:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                mit_sot_inner_slices[offset:offset + n_inputs]
        else:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                mit_sot_inner_inputs[offset:offset + n_inputs]
        offset += n_inputs

    for idx in xrange(n_sit_sot):
        if n_fixed_steps == 1:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                [sit_sot_inner_slices[idx]]
        else:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                [sit_sot_inner_inputs[idx]]

    ordered_args = []
    for ls in _ordered_args:
        ordered_args += ls
    if n_fixed_steps == 1:
        args = inner_slices + ordered_args + non_seqs
    else:
        args = inner_seqs + ordered_args + non_seqs

    # add only the non-shared variables and non-constants to the arguments
    # of the dummy function [ a function should not get shared variables or
    # constants as input ]
    dummy_args = [arg for arg in args
                  if (not isinstance(arg, SharedVariable) and
                      not isinstance(arg, tensor.Constant))]

    # when we apply the lambda expression we get a mixture of update rules
    # and outputs that needs to be separated
    lambda_result = fn(*args)
    condition, outputs, updates = scan_utils.get_updates_and_outputs(
        lambda_result)
    if condition is not None:
        as_while = True
    else:
        as_while = False
    ##
    # Step 3. Check if we actually need scan and remove it if we don't
    ##
    if n_fixed_steps == 1:
        # We do not need to use the scan op anymore, so we can just return
        # the outputs and updates we have
        if condition is not None:
            _logger.warning(
                (
                    "When the number of steps is fixed and equal "
                    "to 1, the provided stopping condition, ",
                    str(condition),
                    " is ignored",
                )
            )

        for pos, inner_out in enumerate(outputs):
            # we need to see if we need to pad our sequences with an
            # unbroadcastable dimension; case example : we return an
            # output for which we want all intermediate. If n_steps is 1
            # then, if we return the output as given by the innner function
            # this will represent only a slice and it will have one
            # dimension less.
            if (isinstance(inner_out.type, tensor.TensorType) and
                    return_steps.get(pos, 0) != 1):
                outputs[pos] = tensor.unbroadcast(
                    tensor.shape_padleft(inner_out), 0)
        if len(outputs) == 1:
            outputs = outputs[0]

        return (outputs, updates)

    ##
    # Step 4. Compile the dummy function
    ##

    # We can now compile a dummy function just to see what shared variable
    # we have and what are their update rules (note that the user has
    # the option not to pass the shared variable to scan, so we need to
    # pick them manually and add them to scan)
    # make the compilation as fast as possible by not applying any
    # optimization or conversion to C [ note this region is not important
    # for performance so we can do stuff as unoptimal as we wish ]

    # extract still missing inputs (there still might be so) and add them
    # as non sequences at the end of our args
    fake_nonseqs = [x.type() for x in non_seqs]
    fake_outputs = scan_utils.clone(
        outputs + updates.values(),
        replace=dict(zip(non_seqs, fake_nonseqs)))
    all_inputs = itertools.ifilter(
        lambda x: (
            isinstance(x, gof.Variable) and
            not isinstance(x, SharedVariable) and
            not isinstance(x, gof.Constant)
        ),
        gof.graph.inputs(fake_outputs),
    )
    extra_inputs = filter(lambda x: x not in args + fake_nonseqs,
                          all_inputs)
    non_seqs += extra_inputs
    # Note we do not use all_inputs directly since the order of variables
    # in args is quite important
    dummy_args += extra_inputs

    dummy_outs = outputs
    if condition is not None:
        dummy_outs.append(condition)

    # If we use a regular dict here, the results are non-deterministic
    if not isinstance(updates, (list, tuple)):
        if isinstance(updates, dict) and \
                not isinstance(updates, OrderedDict):
            warnings.warn("Using non-deterministic dictionary.")

    dummy_f = function(
        dummy_args,
        dummy_outs,
        updates=updates,
        mode=compile.mode.Mode(linker="py", optimizer=None),
        on_unused_input="ignore",
    )

    ##
    # Step 5. Re-arange inputs of scan into a more strict order
    ##

    # Step 5.0 Check the outputs of the dummy function to see if they
    # match with user provided data

    # if the number of outputs to the function does not match the number of
    # assumed outputs until now (provided by the user) there can be
    # only one explanation: No information is provided for any of the
    # outputs (i.e. we are dealing with a map)
    tmp_dummy_f_outs = len(dummy_f.maker.outputs)
    if as_while:
        tmp_dummy_f_outs -= 1
    if not (tmp_dummy_f_outs == n_outs or outs_info == []):
        raise ValueError(
            "Please provide None as output_info for "
            "any output that does not feed back into "
            "scan (i.e. it behaves like a map) "
        )

    if outs_info == []:
        n_outs = len(dummy_f.maker.outputs)
        if as_while:
            n_outs = n_outs - 1
        outs_info = [dict(steps=n_steps) for x in xrange(n_outs)]

    # Step 5.1 Outputs with taps different then -1
    for i, out in enumerate(outs_info):
        if "taps" in out and out["taps"] != [-1]:
            mit_sot_inner_outputs.append(outputs[i])

    # Step 5.2 Outputs with tap equal to -1
    for i, out in enumerate(outs_info):
        if "taps" in out and out["taps"] == [-1]:
            sit_sot_inner_outputs.append(outputs[i])

    # Step 5.3 Outputs that correspond to update rules of shared variables
    givens = OrderedDict()
    n_shared_outs = 0
    shared_scan_inputs = []
    shared_inner_inputs = []
    shared_inner_outputs = []
    for input in dummy_f.maker.expanded_inputs:
        if isinstance(input.variable, SharedVariable) and input.update:
            new_var = safe_new(input.variable)
            if getattr(input.variable, "name", None) is not None:
                new_var.name = input.variable.name + "_copy"
            shared_inner_inputs.append(new_var)
            shared_scan_inputs.append(input.variable)
            shared_inner_outputs.append(input.update)
            givens[input.variable] = new_var
            n_shared_outs += 1

    # Step 5.4 Outputs with no taps used in the input
    n_nit_sot = 0
    nit_sot_inner_outputs = []
    nit_sot_return_steps = OrderedDict()
    nit_sot_rightOrder = []
    for i, out in enumerate(outs_info):
        if not "taps" in out:
            nit_sot_inner_outputs.append(outputs[i])
            if i in return_steps:
                nit_sot_return_steps[n_nit_sot] = return_steps[i]
            nit_sot_rightOrder.append(i)
            nit_sot_steps.append(out["steps"])
            n_nit_sot += 1

    # Step 5.5 all other arguments including extra inputs
    other_scan_args = []
    other_inner_args = []

    other_scan_args += [
        arg
        for arg in non_seqs
        if (not isinstance(arg, SharedVariable) and
            not isinstance(arg, tensor.Constant))
    ]

    # Step 5.6 all shared variables with no update rules
    other_inner_args += [
        safe_new(arg, "_copy")
        for arg in non_seqs
        if (not isinstance(arg, SharedVariable) and
            not isinstance(arg, tensor.Constant))
    ]

    givens.update(dict(zip(other_scan_args, other_inner_args)))

    other_shared_scan_args = [
        arg.variable
        for arg in dummy_f.maker.expanded_inputs
        if (isinstance(arg.variable, SharedVariable) and
            not arg.update)
    ]
    other_shared_inner_args = [
        safe_new(arg.variable, "_copy")
        for arg in dummy_f.maker.expanded_inputs
        if (isinstance(arg.variable, SharedVariable) and
            not arg.update)
    ]
    givens.update(dict(zip(other_shared_scan_args,
                           other_shared_inner_args)))

    ##
    # Step 6. Re-order the outputs and clone them replacing things
    # using the givens
    ##
    inner_inputs = (
        inner_seqs +
        mit_mot_inner_inputs +
        mit_sot_inner_inputs +
        sit_sot_inner_inputs +
        shared_inner_inputs +
        other_shared_inner_args +
        other_inner_args
    )
    inner_outs = (
        mit_mot_inner_outputs +
        mit_sot_inner_outputs +
        sit_sot_inner_outputs +
        nit_sot_inner_outputs +
        shared_inner_outputs
    )
    if condition is not None:
        inner_outs.append(condition)
    new_givens = OrderedDict()
    for w, w_copy in givens.iteritems():
        new_givens[w] = w.type.filter_variable(w_copy)

    new_outs = scan_utils.clone(inner_outs, replace=new_givens)

    ##
    # Step 7. Create the Scan Op
    ##
    tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
    info = OrderedDict()

    info["tap_array"] = tap_array
    info["n_seqs"] = n_seqs
    info["n_mit_mot"] = n_mit_mot
    info["n_mit_mot_outs"] = n_mit_mot_outs
    info["mit_mot_out_slices"] = mit_mot_out_slices
    info["n_mit_sot"] = n_mit_sot
    info["n_sit_sot"] = n_sit_sot
    info["n_shared_outs"] = n_shared_outs
    info["n_nit_sot"] = n_nit_sot
    info["truncate_gradient"] = -1
    info["name"] = name
    info["mode"] = mode
    info["destroy_map"] = OrderedDict()
    info["inplace"] = False
    info["gpu"] = False
    info["as_while"] = as_while
    info["profile"] = profile
    info["_scan_savemem_visited"] = True
    info["allow_gc"] = allow_gc
    local_op = scan_op.Scan(inner_inputs, new_outs, info)

    ##
    # Step 8. Compute the outputs using the scan op
    ##
    _scan_inputs = (
        scan_seqs +
        mit_mot_scan_inputs +
        mit_sot_scan_inputs +
        sit_sot_scan_inputs +
        shared_scan_inputs +
        nit_sot_steps +
        other_shared_scan_args +
        other_scan_args
    )

    scan_inputs = []
    for arg in [actual_n_steps] + _scan_inputs:
        if not isinstance(arg, gof.Variable):
            arg = tensor.as_tensor_variable(arg)
        scan_inputs += [arg]
    scan_outs = local_op(*scan_inputs)
    if type(scan_outs) not in (list, tuple):
        scan_outs = [scan_outs]

    ##
    # Step 9. Figure out which outs are update rules for shared variables
    # and so on ...
    ##
    update_map = OrderedUpdates()
    offset = n_mit_mot
    offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
    mit_sot_outs = scan_outs[offset:offset + n_mit_sot]

    offset += n_mit_sot
    offsets = [1 for x in xrange(n_sit_sot)]
    sit_sot_outs = scan_outs[offset:offset + n_sit_sot]

    offset += n_sit_sot
    nit_sot_outs = scan_outs[offset:offset + n_nit_sot]

    offset += n_nit_sot
    for idx, update_rule in enumerate(
            scan_outs[offset:offset + n_shared_outs]):
        update_map[shared_scan_inputs[idx]] = update_rule

    _scan_out_list = (mit_sot_outs +
                      sit_sot_outs +
                      nit_sot_outs)

    # Step 10. I need to reorder the outputs to be in the order expected by
    # the user
    rightOrder = (mit_sot_rightOrder +
                  sit_sot_rightOrder +
                  nit_sot_rightOrder)
    scan_out_list = [None] * len(rightOrder)
    for idx, pos in enumerate(rightOrder):
        scan_out_list[pos] = _scan_out_list[idx]
    if len(scan_out_list) == 1:
        scan_out_list = scan_out_list[0]
    elif len(scan_out_list) == 0:
        scan_out_list = None

    assert isinstance(update_map, OrderedDict)
    return (scan_out_list, update_map)
class Rebroadcast(gof.Op):
    """
    Change the input's broadcastable fields in some predetermined way.

    See Also
    --------
    unbroadcast <theano.tensor.unbroadcast>
    addbroadcast <theano.tensor.addbroadcast>
    patternbroadcast <theano.tensor.patternbroadcast>

    Notes
    -----
    Works inplace and works for CudaNdarrayType.

    Example
    -------
    `Rebroadcast((0, True), (1, False))(x)` would make `x` broadcastable in
    axis 0 and not broadcastable in axis 1.

    """

    # The output is a view of the input (no data is copied).
    view_map = {0: [0]}
    _f16_ok = True
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}

    check_input = False
    __props__ = ("axis",)

    def __init__(self, *axis):
        # `axis` is a sequence of (axis_index, new_broadcastable_flag) pairs.
        # Sort them to make sure we merge all possible case.
        items = sorted(axis)
        self.axis = OrderedDict(items)
        for axis, broad in iteritems(self.axis):
            if not isinstance(axis, (numpy.integer, integer_types)):
                raise TypeError("Rebroadcast needs integer axes. "
                                "Got {}".format(axis))

            if not isinstance(broad, (numpy.bool_, bool)):
                raise TypeError("Rebroadcast needs bool for new broadcast "
                                "pattern. Got {}".format(broad))

    def __hash__(self):
        # Need special __hash__ as dict aren't hashable.
        # no ambiguity because each item key is unique
        items = sorted(iteritems(self.axis))
        return hash((type(self), tuple(items)))

    def __str__(self):
        # Render the op as e.g. "Rebroadcast{1,?,0}": one digit per axis up
        # to the largest axis mentioned, '?' for axes left unchanged.
        if len(self.axis) == 0:
            broadcast_pattern = []
        else:
            broadcast_pattern = ['?' for i
                                 in xrange(1 + max(self.axis.keys()))]
        for k, v in iteritems(self.axis):
            broadcast_pattern[k] = str(int(v))
        return '%s{%s}' % (self.__class__.__name__,
                           ','.join(broadcast_pattern))

    def make_node(self, x):
        # Every axis mentioned in self.axis must exist on the input.
        if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
            raise ValueError('Trying to rebroadcast non-existent dimension')
        # The output type keeps the input's broadcastable flags except on
        # the axes this op overrides.
        t = x.type.clone(
            broadcastable=[self.axis.get(i, b)
                           for i, b in enumerate(x.type.broadcastable)])
        return gof.Apply(self, [x], [t()])

    def perform(self, node, inp, out_):
        x, = inp
        out, = out_
        # An axis can only be declared broadcastable if its length is 1;
        # otherwise the declared type would be a lie.
        for axis, value in iteritems(self.axis):
            if value and x.shape[axis] != 1:
                raise ValueError('Dimension %s in Rebroadcast\'s input was'
                                 ' supposed to be 1 (got %s instead)' %
                                 (axis, x.shape[axis]))
        # No copy: the output aliases the input (see view_map).
        out[0] = x

    def grad(self, inp, grads):
        x, = inp
        gz, = grads
        # restore the broadcasting pattern of the input
        return Rebroadcast(*[(axis, x.type.broadcastable[axis])
                             for axis, value in iteritems(self.axis)])(gz),

    def infer_shape(self, node, ishapes):
        assert len(ishapes) == 1
        l = []
        one = theano.tensor.basic.constant(1)
        # Axes forced broadcastable are known to have length 1; other axes
        # keep the input's symbolic shape.
        for ax in xrange(len(ishapes[0])):
            if self.axis.get(ax, False):
                l.append(one)
            else:
                l.append(ishapes[0][ax])

        return [tuple(l)]

    def R_op(self, inputs, eval_points):
        if eval_points[0] is None:
            return [None]
        return self(*eval_points, **dict(return_list=True))

    def c_code(self, node, nodename, inp, out, sub):
        iname, = inp
        oname, = out
        fail = sub['fail']

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, version = self.c_code_and_version[itype]
            final_code = ""
            # Emit one shape check per axis that is being made broadcastable.
            for axis, value in iteritems(self.axis):
                if value:
                    final_code += code % locals()
            # Forward the input buffer as the output (view semantics).
            return final_code + """
            Py_XDECREF(%(oname)s);
            %(oname)s = %(iname)s;
            Py_XINCREF(%(oname)s);
            """ % locals()
        return super(Rebroadcast, self).c_code(node, nodename, inp, out, sub)

    def c_code_cache_version(self):
        version = []
        # If any of the c code is unversionned, we have to return ()
        # Else, we will return a list of (type name, version) pairs.
        for t, (c, v) in sorted(iteritems(self.c_code_and_version),
                                key=lambda pair: str(pair[0])):
            if not v:
                warnings.warn("Type %s has C code for Rebroadcast, but it "
                              "has no version. You should add a 'version' "
                              "keyword arg when calling "
                              "register_rebroadcast_c_code." % t,
                              stacklevel=2)
                return ()
            version.append((str(t), v))

        if version:
            version.append(1)
        return tuple(version)
def get_gradients(self, model, data, **kwargs):
    """
    Build the gradient dictionary for an adversarial pair.

    Parameters
    ----------
    model : AdversaryPair
        Model holding the generator, the discriminator and (optionally)
        an inference network (``model.inferer``).
    data : batch in the format given by ``self.get_data_specs(model)``.
    kwargs : ignored here; accepted for interface compatibility.

    Returns
    -------
    (rval, updates) : (OrderedDict, OrderedDict)
        ``rval`` maps every parameter of the generator, discriminator and
        (if present) inferer to its gradient expression; parameters of
        components that are never trained get a constant-zero "gradient".
        ``updates`` optionally toggles ``self.now_train_generator`` so the
        generator is only trained on alternate steps.
    """
    space, sources = self.get_data_specs(model)
    space.validate(data)
    assert isinstance(model, AdversaryPair)
    g = model.generator
    d = model.discriminator

    S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)

    g_params = g.get_params()
    d_params = d.get_params()
    # Generator and discriminator must not share parameters: each objective
    # is differentiated only w.r.t. its own player's parameters.
    for param in g_params:
        assert param not in d_params
    for param in d_params:
        assert param not in g_params
    d_grads = T.grad(d_obj, d_params)
    g_grads = T.grad(g_obj, g_params)

    if self.scale_grads:
        # Rescale the generator gradients so the gradient w.r.t. the
        # samples has at most norm `self.target_scale`.
        S_grad = T.grad(g_obj, S)
        scale = T.maximum(1., self.target_scale /
                          T.sqrt(T.sqr(S_grad).sum()))
        g_grads = [g_grad * scale for g_grad in g_grads]

    rval = OrderedDict()
    # Infinite stream of zero constants, paired with params via zip below.
    zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))
    if self.ever_train_discriminator:
        # Gate by now_train_discriminator so training can be switched
        # on/off per step without recompiling.
        rval.update(OrderedDict(safe_zip(
            d_params,
            [self.now_train_discriminator * dg for dg in d_grads])))
    else:
        rval.update(OrderedDict(zip(d_params, zeros)))
    if self.ever_train_generator:
        rval.update(OrderedDict(safe_zip(
            g_params,
            [self.now_train_generator * gg for gg in g_grads])))
    else:
        rval.update(OrderedDict(zip(g_params, zeros)))
    if self.ever_train_inference and model.inferer is not None:
        i_params = model.inferer.get_params()
        i_grads = T.grad(i_obj, i_params)
        rval.update(OrderedDict(safe_zip(
            i_params,
            [self.now_train_inference * ig for ig in i_grads])))
    elif model.inferer is not None:
        # BUG FIX: the original called OrderedDict(params, zeros), which
        # raises TypeError (dict takes at most one positional argument).
        # Pair each inferer parameter with a zero constant instead, matching
        # the never-trained branches above.
        rval.update(OrderedDict(zip(model.inferer.get_params(), zeros)))

    updates = OrderedDict()

    # Two d steps for every g step: toggling now_train_generator each call
    # trains the generator only on alternate steps.
    if self.alternate_g:
        updates[self.now_train_generator] = 1. - self.now_train_generator

    return rval, updates
class DestroyHandler(toolbox.Bookkeeper):  # noqa
    """
    The DestroyHandler class detects when a graph is impossible to evaluate
    because of aliasing and destructive operations.

    Several data structures are used to do this.

    An Op can use its view_map property to declare that an output may be
    aliased to an input. If that output is destroyed, the input is also
    considered to be destroyed. The view_maps of several Ops can feed into
    one another and form a directed graph. The consequence of destroying any
    variable in such a graph is that all variables in the graph must be
    considered to be destroyed, because they could all be refering to the
    same underlying storage.

    In the current implementation, that graph is a tree, and the root of that
    tree is called the foundation.

    TODO: why "in the current implementation" ? is there another
    implementation planned?
    TODO: why is the graph a tree? isn't it possible that one variable could
    be aliased to many variables? for example, don't switch and ifelse have to
    do this?

    The original DestroyHandler (if 0'ed out above) computed several data
    structures from scratch each time it was asked to validate the graph.
    Because this happens potentially thousands of times and each graph to
    validate is extremely similar to the previous one, computing the data
    structures from scratch repeatedly was wasteful and resulted in high
    compile times for large graphs.

    This implementation computes the data structures once at initialization
    and then incrementally updates them.

    It is a work in progress. The following data structures have been
    converted to use the incremental strategy:
        <none>

    The following data structures remain to be converted:
        <unknown>

    """

    pickle_rm_attr = ["destroyers"]

    def __init__(self, do_imports_on_attach=True):
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach

        # maps every variable in the graph to its "foundation" (deepest
        # ancestor in view chain)
        # TODO: change name to var_to_vroot
        self.droot = OrderedDict()

        # maps a variable to all variables that are indirect or direct views
        # of it (including itself); essentially the inverse of droot
        # TODO: do all variables appear in this dict, or only those that are
        #       foundations?
        # TODO: do only destroyed variables go in here? one old docstring
        #       said so
        # TODO: rename to x_to_views after reverse engineering what x is
        self.impact = OrderedDict()

        # if a var is destroyed, then this dict will map droot[var] to the
        # apply node that destroyed var
        # TODO: rename to vroot_to_destroyer
        self.root_destroyer = OrderedDict()

    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that

        1) This DestroyHandler wasn't already attached to some fgraph
           (its data structures are only set up to serve one)
        2) The FunctionGraph doesn't already have a DestroyHandler.
           This would result in it validating everything twice, causing
           compilation to be slower.

        Give the FunctionGraph instance:

        1) A new method "destroyers(var)"
           TODO: what does this do exactly?
        2) A new attribute, "destroy_handler"

        TODO: WRITEME: what does this do besides the checks?

        """
        # Do the checking #
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception(
                "A DestroyHandler instance can only serve one"
                " FunctionGraph. (Matthew 6:24)")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere
            # and cancels the attachment
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present"
                " or in conflict with another plugin.")

        # Annotate the FunctionGraph #
        self.unpickle(fgraph)
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        # set of Apply instances with non-null destroy_map
        self.destroyers = OrderedSet()
        # variable -> variable used in calculation
        self.view_i = OrderedDict()
        # variable -> set of variables that use this one as a direct input
        self.view_o = OrderedDict()
        # clients: how many times does an apply use a given variable
        self.clients = OrderedDict()  # variable -> apply -> ninputs
        self.stale_droot = True

        # Sanity-check set of every Apply ever imported (and not pruned).
        self.debug_all_apps = OrderedSet()
        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)

    def unpickle(self, fgraph):
        # Reinstall the fgraph.destroyers closure (it is removed for
        # pickling, see pickle_rm_attr).
        def get_destroyers_of(r):
            droot, impact, root_destroyer = self.refresh_droot_impact()
            try:
                return [root_destroyer[droot[r]]]
            except Exception:
                return []
        fgraph.destroyers = get_destroyers_of

    def refresh_droot_impact(self):
        """
        Makes sure self.droot, self.impact, and self.root_destroyer are up to
        date, and returns them (see docstrings for these properties above).

        """
        if self.stale_droot:
            # destroyed view + nonview variables -> foundation
            droot = OrderedDict()
            # destroyed nonview variable -> it + all views of it
            impact = OrderedDict()
            # root -> destroyer apply
            root_destroyer = OrderedDict()

            for app in self.destroyers:
                for output_idx, input_idx_list in \
                        iteritems(app.op.destroy_map):
                    if len(input_idx_list) != 1:
                        raise NotImplementedError()
                    input_idx = input_idx_list[0]
                    input = app.inputs[input_idx]
                    # Walk the view chain down to the foundation.
                    input_root = getroot(input, self.view_i)
                    if input_root in droot:
                        raise InconsistencyError(
                            "Multiple destroyers of %s" % input_root)
                    droot[input_root] = input_root
                    root_destroyer[input_root] = app
                    # All views of the root share its storage, so they are
                    # all considered destroyed along with it.
                    input_impact = get_impact(input_root, self.view_o)
                    for v in input_impact:
                        assert v not in droot
                        droot[v] = input_root

                    impact[input_root] = input_impact
                    impact[input_root].add(input_root)

            self.droot, self.impact, self.root_destroyer = \
                droot, impact, root_destroyer
            self.stale_droot = False
        return self.droot, self.impact, self.root_destroyer

    def on_detach(self, fgraph):
        if fgraph is not self.fgraph:
            raise Exception("detaching wrong fgraph", fgraph)
        del self.destroyers
        del self.view_i
        del self.view_o
        del self.clients
        del self.stale_droot
        # BUG FIX: was `self.fgraph.destroyer_handler`, an attribute that is
        # never set (on_attach sets `destroy_handler`), so every detach
        # raised AttributeError instead of performing this sanity check.
        assert self.fgraph.destroy_handler is self
        delattr(self.fgraph, 'destroyers')
        delattr(self.fgraph, 'destroy_handler')
        self.fgraph = None

    def on_import(self, fgraph, app, reason):
        """
        Add Apply instance to set which must be computed.

        """
        if app in self.debug_all_apps:
            raise ProtocolError("double import")
        self.debug_all_apps.add(app)
        # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)

        # If it's a destructive op, add it to our watch list
        if getattr(app.op, 'destroy_map', {}):
            self.destroyers.add(app)

        # add this symbol to the forward and backward maps
        for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})):
            if len(i_idx_list) > 1:
                raise NotImplementedError(
                    'destroying this output invalidates multiple inputs',
                    (app.op))
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            self.view_i[o] = i
            self.view_o.setdefault(i, OrderedSet()).add(o)

        # update self.clients
        for i, input in enumerate(app.inputs):
            self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
            self.clients[input][app] += 1

        for i, output in enumerate(app.outputs):
            self.clients.setdefault(output, OrderedDict())

        self.stale_droot = True

    def on_prune(self, fgraph, app, reason):
        """
        Remove Apply instance from set which must be computed.

        """
        if app not in self.debug_all_apps:
            raise ProtocolError("prune without import")
        self.debug_all_apps.remove(app)

        # UPDATE self.clients
        for i, input in enumerate(OrderedSet(app.inputs)):
            del self.clients[input][app]

        if getattr(app.op, 'destroy_map', OrderedDict()):
            self.destroyers.remove(app)

        # Note: leaving empty client dictionaries in the struct.
        # Why? It's a pain to remove them. I think they aren't doing any
        # harm, they will be deleted on_detach().

        # UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map',
                                                   OrderedDict())):
            if len(i_idx_list) > 1:
                # destroying this output invalidates multiple inputs
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]

            del self.view_i[o]

            self.view_o[i].remove(o)
            if not self.view_o[i]:
                del self.view_o[i]

        self.stale_droot = True

    def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
        """
        app.inputs[i] changed from old_r to new_r.

        """
        if app == 'output':
            # app == 'output' is special key that means FunctionGraph is
            # redefining which nodes are being considered 'outputs' of the
            # graph.
            pass
        else:
            if app not in self.debug_all_apps:
                raise ProtocolError("change without import")

            # UPDATE self.clients
            self.clients[old_r][app] -= 1
            if self.clients[old_r][app] == 0:
                del self.clients[old_r][app]

            self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0)
            self.clients[new_r][app] += 1

            # UPDATE self.view_i, self.view_o
            for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map',
                                                       OrderedDict())):
                if len(i_idx_list) > 1:
                    # destroying this output invalidates multiple inputs
                    raise NotImplementedError()
                i_idx = i_idx_list[0]
                output = app.outputs[o_idx]
                if i_idx == i:
                    if app.inputs[i_idx] is not new_r:
                        raise ProtocolError("wrong new_r on change")

                    self.view_i[output] = new_r

                    self.view_o[old_r].remove(output)
                    if not self.view_o[old_r]:
                        del self.view_o[old_r]

                    self.view_o.setdefault(new_r,
                                           OrderedSet()).add(output)

        self.stale_droot = True

    def validate(self, fgraph):
        """
        Return None.

        Raise InconsistencyError when
        a) orderings() raises an error
        b) orderings cannot be topologically sorted.

        """
        if self.destroyers:
            ords = self.orderings(fgraph)

            if _contains_cycle(fgraph, ords):
                raise InconsistencyError(
                    "Dependency graph contains cycles")
        else:
            # James's Conjecture:
            # If there are no destructive ops, then there can be no cycles.

            # FB: This isn't always True. It can happend that
            # optimization introduce node that depend on itself. This
            # is very rare and should not happen in general. It will be
            # caught later. The error will be far from the source. But
            # doing this conjecture should speed up compilation most of
            # the time. The user should create such dependency except
            # if he mess too much with the internal.
            pass
        return True

    def orderings(self, fgraph):
        """
        Return orderings induced by destructive operations.

        Raise InconsistencyError when
        a) attempting to destroy indestructable variable, or
        b) attempting to destroy a value multiple times, or
        c) an Apply destroys (illegally) one of its own inputs by aliasing

        """
        rval = OrderedDict()

        if self.destroyers:
            # BUILD DATA STRUCTURES
            # CHECK for multiple destructions during construction of
            # variables
            droot, impact, __ignore = self.refresh_droot_impact()

            # check for destruction of constants
            illegal_destroy = [
                r for r in droot
                if getattr(r.tag, 'indestructible', False) or
                isinstance(r, graph.Constant)]
            if illegal_destroy:
                raise InconsistencyError(
                    "Attempting to destroy indestructible variables: %s" %
                    illegal_destroy)

            # add destroyed variable clients as computational dependencies
            for app in self.destroyers:
                # for each destroyed input...
                for output_idx, input_idx_list in \
                        iteritems(app.op.destroy_map):
                    destroyed_idx = input_idx_list[0]
                    destroyed_variable = app.inputs[destroyed_idx]
                    root = droot[destroyed_variable]
                    root_impact = impact[root]
                    # we generally want to put all clients of things which
                    # depend on root as pre-requisites of app.
                    # But, app is itself one such client!
                    # App will always be a client of the node we're
                    # destroying (destroyed_variable, but the tricky thing is
                    # when it is also a client of *another variable* viewing
                    # on the root. Generally this is illegal, (e.g.,
                    # add_inplace(x, x.T). In some special cases though, the
                    # in-place op will actually be able to work properly with
                    # multiple destroyed inputs (e.g, add_inplace(x, x). An
                    # Op that can still work in this case should declare so
                    # via the 'destroyhandler_tolerate_same' attribute or
                    # 'destroyhandler_tolerate_aliased' attribute.
                    #
                    # destroyhandler_tolerate_same should be a list of pairs
                    # of the form [(idx0, idx1), (idx0, idx2), ...]
                    # The first element of each pair is the input index of a
                    # destroyed variable.
                    # The second element of each pair is the index of a
                    # different input where we will permit exactly the same
                    # variable to appear.
                    # For example, add_inplace.tolerate_same might be [(0,1)]
                    # if the destroyed input is also allowed to appear as the
                    # second argument.
                    #
                    # destroyhandler_tolerate_aliased is the same sort of
                    # list of pairs.
                    # op.destroyhandler_tolerate_aliased = [(idx0, idx1)]
                    # tells the destroyhandler to IGNORE an aliasing between
                    # a destroyed input idx0 and another input idx1.
                    # This is generally a bad idea, but it is safe in some
                    # cases, such as
                    # - the op reads from the aliased idx1 before modifying
                    #   idx0
                    # - the idx0 and idx1 are guaranteed not to overlap (e.g.
                    #   they are pointed at different rows of a matrix).
                    #

                    # CHECK FOR INPUT ALIASING
                    # OPT: pre-compute this on import
                    tolerate_same = getattr(app.op,
                                            'destroyhandler_tolerate_same',
                                            [])
                    assert isinstance(tolerate_same, list)
                    tolerated = OrderedSet(idx1 for idx0, idx1
                                           in tolerate_same
                                           if idx0 == destroyed_idx)
                    tolerated.add(destroyed_idx)
                    tolerate_aliased = getattr(
                        app.op, 'destroyhandler_tolerate_aliased', [])
                    assert isinstance(tolerate_aliased, list)
                    ignored = OrderedSet(idx1 for idx0, idx1
                                         in tolerate_aliased
                                         if idx0 == destroyed_idx)
                    # print 'tolerated', tolerated
                    # print 'ignored', ignored
                    for i, input in enumerate(app.inputs):
                        if i in ignored:
                            continue
                        if input in root_impact \
                                and (i not in tolerated or
                                     input is not destroyed_variable):
                            raise InconsistencyError(
                                "Input aliasing: %s (%i, %i)"
                                % (app, destroyed_idx, i))

                    # add the rule: app must be preceded by all other Apply
                    # instances that depend on destroyed_input
                    root_clients = OrderedSet()
                    for r in root_impact:
                        assert not [a for a, c in
                                    self.clients[r].items() if not c]
                        root_clients.update([a for a, c in
                                             self.clients[r].items() if c])
                    root_clients.remove(app)
                    if root_clients:
                        rval[app] = root_clients

        return rval
def get_dummy_args(sequences=None,
                   outputs_info=None,
                   non_sequences=None,
                   n_steps=None,
                   truncate_gradient=-1,
                   go_backwards=False,
                   mode=None,
                   name=None,
                   profile=False,
                   allow_gc=None,
                   strict=False):
    """
    First half of a scan() construction: normalize the user-facing scan
    arguments and build the inputs for compiling a dummy inner function.

    Parameters mirror theano.scan (sequences, outputs_info, non_sequences,
    n_steps, ...).

    Returns
    -------
    (dummy_args, local_vars) : (list, dict)
        ``dummy_args`` are the non-shared, non-constant inner-function
        inputs; ``local_vars`` is this function's ``locals()``, consumed by
        name by ``finish_scan`` to complete the scan construction. Because
        ``locals()`` is part of the contract, no local variable here may be
        renamed without updating ``finish_scan``.
    """
    ##################################################################  P1>
    # check if inputs are just single variables instead of lists
    def wrap_into_list(x):
        """
        Wrap the input into a list if it is not already a list.

        """
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(outputs_info)

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(non_sequences):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, integer_types)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
            str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap sequences in a dictionary if they are not already dictionaries
    for i in xrange(n_seqs):
        if not isinstance(seqs[i], dict):
            seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
        elif seqs[i].get('taps', None) is not None:
            seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
        elif seqs[i].get('taps', None) is None:
            # seqs dictionary does not have the ``taps`` key
            seqs[i]['taps'] = [0]

    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if isinstance(outs_info[i], dict):
                # DEPRECATED :
                if outs_info[i].get('return_steps', None) is not None:
                    raise ValueError(
                        "Using `return_steps` has been deprecated. "
                        "Simply select the entries you need using a "
                        "subtensor. Scan will optimize memory "
                        "consumption, so do not worry about that.")
                # END

            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = OrderedDict([('initial', outs_info[i]),
                                            ('taps', [-1])])
            elif (outs_info[i].get('initial', None) is None and
                    outs_info[i].get('taps', None) is not None):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide a initial state '
                                  'for it'), outs_info[i])
            elif (outs_info[i].get('initial', None) is not None and
                  outs_info[i].get('taps', None) is None):
                # ^ initial state but taps not provided
                if 'taps' in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        'Output %s ( index %d) has a initial '
                        'state but taps is explicitly set to None ',
                        getattr(outs_info[i]['initial'], 'name', 'None'),
                        i)
                outs_info[i]['taps'] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an empty OrdereDict() to simplify handling
            outs_info[i] = OrderedDict()

    ##
    # Step 2. Generate inputs and outputs of the inner functions
    # for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we used this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []     # Variables passed as inputs to the scan op
    inner_seqs = []    # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        # Note that you can have something like no taps for
        # a sequence, though is highly unlikely in practice
        if 'taps' in seq:
            # go through the indicated slice
            mintap = numpy.min(seq['taps'])
            maxtap = numpy.max(seq['taps'])
            for k in seq['taps']:
                # create one slice of the input
                # Later on, if we decide not to use scan because we are
                # going for just one step, it makes things easier if we
                # compute the correct outputs here. This way we can use
                # the output of the lambda expression directly to replace
                # the output of scan.

                # If not we need to use copies, that will be replaced at
                # each frame by the corresponding slice
                actual_slice = seq['input'][k - mintap]
                _seq_val = tensor.as_tensor_variable(seq['input'])
                _seq_val_slice = _seq_val[k - mintap]
                nw_slice = _seq_val_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != 'off':
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _seq_val_slice)
                    except AttributeError as e:
                        if config.compute_test_value != 'ignore':
                            # No need to print a warning or raise an error
                            # now, it will be done when fn will be called.
                            _logger.info(
                                ('Cannot compute test value for '
                                 'the inner function of scan, input value '
                                 'missing %s'), e)

                # Add names to slices for debugging and pretty printing ..
                # that is if the input already has a name
                if getattr(seq['input'], 'name', None) is not None:
                    if k > 0:
                        nw_name = seq['input'].name + '[t+%d]' % k
                    elif k == 0:
                        nw_name = seq['input'].name + '[t]'
                    else:
                        nw_name = seq['input'].name + '[t%d]' % k
                    nw_slice.name = nw_name

                # We cut the sequence such that seq[i] to correspond to
                # seq[i-k]. For the purposes of cutting the sequences, we
                # need to pretend tap 0 is used to avoid cutting the
                # sequences too long if the taps are all lower or all higher
                # than 0.
                maxtap_proxy = max(maxtap, 0)
                mintap_proxy = min(mintap, 0)
                start = (k - mintap_proxy)
                if k == maxtap_proxy:
                    nw_seq = seq['input'][start:]
                else:
                    end = -(maxtap_proxy - k)
                    nw_seq = seq['input'][start:end]

                if go_backwards:
                    nw_seq = nw_seq[::-1]

                scan_seqs.append(nw_seq)
                inner_seqs.append(nw_slice)
                inner_slices.append(actual_slice)
                n_seqs += 1

    # Since we've added all sequences now we need to level them up based on
    # n_steps or their different shapes
    lengths_vec = []
    for seq in scan_seqs:
        lengths_vec.append(seq.shape[0])

    if not scan_utils.isNaN_or_Inf_or_None(n_steps):
        # ^ N_steps should also be considered
        lengths_vec.append(tensor.as_tensor(n_steps))

    if len(lengths_vec) == 0:
        # ^ No information about the number of steps
        raise ValueError('No information about the number of steps '
                         'provided. Either provide a value for '
                         'n_steps argument of scan or provide an input '
                         'sequence')

    # If the user has provided the number of steps, do that regardless ( and
    # raise an error if the sequences are not long enough )
    if scan_utils.isNaN_or_Inf_or_None(n_steps):
        actual_n_steps = lengths_vec[0]
        for contestant in lengths_vec[1:]:
            actual_n_steps = tensor.minimum(actual_n_steps, contestant)
    else:
        actual_n_steps = tensor.as_tensor(n_steps)

    # Add names -- it helps a lot when debugging
    # NOTE(review): `k` here is the leftover loop variable from the taps
    # loop above (the last tap processed) — looks intentional in the
    # original scan code, but confirm before relying on the name text.
    for (nw_seq, seq) in zip(scan_seqs, seqs):
        if getattr(seq['input'], 'name', None) is not None:
            nw_seq.name = seq['input'].name + '[%d:]' % k

    scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
    # Conventions :
    #   mit_mot = multiple input taps, multiple output taps ( only provided
    #             by the gradient function )
    #   mit_sot = multiple input taps, single output tap (t + 0)
    #   sit_sot = single input tap, single output tap (t + 0)
    #   nit_sot = no input tap, single output tap (t + 0)

    # MIT_MOT -- not provided by the user only by the grad function
    n_mit_mot = 0
    n_mit_mot_outs = 0
    mit_mot_scan_inputs = []
    mit_mot_inner_inputs = []
    mit_mot_inner_outputs = []
    mit_mot_out_slices = []
    mit_mot_rightOrder = []

    # SIT_SOT -- provided by the user
    n_mit_sot = 0
    mit_sot_scan_inputs = []
    mit_sot_inner_inputs = []
    mit_sot_inner_slices = []
    mit_sot_inner_outputs = []
    mit_sot_return_steps = OrderedDict()
    mit_sot_tap_array = []
    mit_sot_rightOrder = []

    n_sit_sot = 0
    sit_sot_scan_inputs = []
    sit_sot_inner_inputs = []
    sit_sot_inner_slices = []
    sit_sot_inner_outputs = []
    sit_sot_return_steps = OrderedDict()
    sit_sot_rightOrder = []

    # go through outputs picking up time slices as needed
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as a initial state we will only
        # provide a tensor of the same dimension as one time step; This
        # makes code much cleaner for those who do not use taps. Otherwise
        # they would always had to shape_padleft the initial state ..
        # which is ugly
        if init_out.get('taps', None) == [-1]:
            actual_arg = init_out['initial']
            if not isinstance(actual_arg, tensor.Variable):
                actual_arg = tensor.as_tensor_variable(actual_arg)
            arg = safe_new(actual_arg)
            if isinstance(arg, tensor.Constant):
                # safe new returns a clone of the constants, but that is not
                # what we need for initial states
                arg = arg.type()

            # Try to transfer test_value to the new variable
            if config.compute_test_value != 'off':
                try:
                    arg.tag.test_value = gof.Op._get_test_value(actual_arg)
                except AttributeError as e:
                    if config.compute_test_value != 'ignore':
                        # No need to print a warning or raise an error now,
                        # it will be done when fn will be called.
                        _logger.info(
                            ('Cannot compute test value for the '
                             'inner function of scan, input value '
                             'missing %s'), e)

            if getattr(init_out['initial'], 'name', None) is not None:
                arg.name = init_out['initial'].name + '[t-1]'

            # We need now to allocate space for storing the output and copy
            # the initial state over. We do this using the expand function
            # defined in scan utils
            sit_sot_scan_inputs.append(
                scan_utils.expand_empty(
                    tensor.unbroadcast(
                        tensor.shape_padleft(actual_arg), 0),
                    actual_n_steps))

            sit_sot_inner_slices.append(actual_arg)
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1

        elif init_out.get('taps', None):

            if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence we can not provide such values
                raise ValueError('Can not use future taps of outputs',
                                 init_out)
            # go through the taps
            mintap = abs(numpy.min(init_out['taps']))
            mit_sot_tap_array.append(init_out['taps'])
            idx_offset = abs(numpy.min(init_out['taps']))
            # Sequence
            mit_sot_scan_inputs.append(
                scan_utils.expand_empty(init_out['initial'][:mintap],
                                        actual_n_steps))

            if i in return_steps:
                mit_sot_return_steps[n_mit_sot] = return_steps[i]
            mit_sot_rightOrder.append(i)
            n_mit_sot += 1
            for k in init_out['taps']:
                # create a new slice
                actual_nw_slice = init_out['initial'][k + mintap]
                _init_out_var = tensor.as_tensor_variable(
                    init_out['initial'])
                _init_out_var_slice = _init_out_var[k + mintap]
                nw_slice = _init_out_var_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != 'off':
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _init_out_var_slice)
                    except AttributeError as e:
                        if config.compute_test_value != 'ignore':
                            # No need to print a warning or raise an error
                            # now, it will be done when fn will be called.
                            _logger.info(
                                ('Cannot compute test value for '
                                 'the inner function of scan, input value '
                                 'missing. %s'), e)

                # give it a name or debugging and pretty printing
                if getattr(init_out['initial'], 'name', None) is not None:
                    if k > 0:
                        nw_slice.name = (init_out['initial'].name +
                                         '[t+%d]' % k)
                    elif k == 0:
                        nw_slice.name = init_out['initial'].name + '[t]'
                    else:
                        nw_slice.name = (init_out['initial'].name +
                                         '[t%d]' % k)
                mit_sot_inner_inputs.append(nw_slice)
                mit_sot_inner_slices.append(actual_nw_slice)
        # NOTE: there is another case, in which we do not want to provide
        #      any previous value of the output to the inner function (i.e.
        #      a map); in that case we do not have to do anything ..

    # Re-order args
    max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
    max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
    n_elems = numpy.max([max_mit_sot, max_sit_sot])
    _ordered_args = [[] for x in xrange(n_elems)]
    offset = 0
    for idx in xrange(n_mit_sot):
        n_inputs = len(mit_sot_tap_array[idx])
        if n_fixed_steps in [1, -1]:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                mit_sot_inner_slices[offset:offset + n_inputs]
        else:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                mit_sot_inner_inputs[offset:offset + n_inputs]
        offset += n_inputs

    for idx in xrange(n_sit_sot):
        if n_fixed_steps in [1, -1]:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                [sit_sot_inner_slices[idx]]
        else:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                [sit_sot_inner_inputs[idx]]

    ordered_args = []
    for ls in _ordered_args:
        ordered_args += ls
    if n_fixed_steps in [1, -1]:
        args = (inner_slices + ordered_args + non_seqs)
    else:
        args = (inner_seqs + ordered_args + non_seqs)

    # add only the non-shared variables and non-constants to the arguments
    # of the dummy function [ a function should not get shared variables or
    # constants as input ]
    dummy_args = [arg for arg in args
                  if (not isinstance(arg, SharedVariable) and
                      not isinstance(arg, tensor.Constant))]
    ##################################################################  P1<
    return dummy_args, locals()
def finish_scan(fn_outputs, local_vars):
    """Second half of the ``scan`` construction.

    Takes the outputs produced by calling the user's step function
    (``fn_outputs``) together with the ``locals()`` dictionary returned by
    the companion "begin" phase (which prepared the inner/outer scan
    argument bookkeeping), builds the actual ``Scan`` op and applies it to
    the outer inputs.

    Parameters
    ----------
    fn_outputs
        Whatever the user's inner function returned; parsed by
        ``scan_utils.get_updates_and_outputs`` into
        ``(condition, outputs, updates)``.
    local_vars : dict
        The ``locals()`` of the begin phase.  Every bookkeeping
        list/counter built there is unpacked from it below.

    Returns
    -------
    tuple
        ``(scan_out_list, update_map)`` — the scan outputs re-ordered as
        the user expects, and an ``OrderedUpdates`` mapping shared
        variables to their update rules.  When ``n_fixed_steps`` is 1 or
        -1 no ``Scan`` op is built and ``(outputs, updates)`` is returned
        directly.
    """
    # Unpack the state captured by the begin phase.  Names are kept
    # identical to the originals so the body reads like one continuous
    # function.
    n_fixed_steps = local_vars["n_fixed_steps"]
    return_steps = local_vars["return_steps"]
    non_seqs = local_vars["non_seqs"]
    dummy_args = local_vars["dummy_args"]
    args = local_vars["args"]
    outs_info = local_vars["outs_info"]
    n_outs = local_vars["n_outs"]
    mit_sot_inner_outputs = local_vars["mit_sot_inner_outputs"]
    sit_sot_inner_outputs = local_vars["sit_sot_inner_outputs"]
    sit_sot_scan_inputs = local_vars["sit_sot_scan_inputs"]
    sit_sot_inner_inputs = local_vars["sit_sot_inner_inputs"]
    actual_n_steps = local_vars["actual_n_steps"]
    sit_sot_rightOrder = local_vars["sit_sot_rightOrder"]
    strict = local_vars["strict"]
    non_sequences = local_vars["non_sequences"]
    inner_seqs = local_vars["inner_seqs"]
    mit_mot_inner_inputs = local_vars["mit_mot_inner_inputs"]
    mit_sot_inner_inputs = local_vars["mit_sot_inner_inputs"]
    mit_mot_inner_outputs = local_vars["mit_mot_inner_outputs"]
    mit_sot_tap_array = local_vars["mit_sot_tap_array"]
    allow_gc = local_vars["allow_gc"]
    n_seqs = local_vars["n_seqs"]
    n_mit_mot_outs = local_vars["n_mit_mot_outs"]
    mit_mot_out_slices = local_vars["mit_mot_out_slices"]
    truncate_gradient = local_vars["truncate_gradient"]
    name = local_vars["name"]
    mode = local_vars["mode"]
    profile = local_vars["profile"]
    scan_seqs = local_vars["scan_seqs"]
    mit_mot_scan_inputs = local_vars["mit_mot_scan_inputs"]
    mit_sot_scan_inputs = local_vars["mit_sot_scan_inputs"]
    n_mit_mot = local_vars["n_mit_mot"]
    mit_sot_return_steps = local_vars["mit_sot_return_steps"]
    n_mit_sot = local_vars["n_mit_sot"]
    sit_sot_return_steps = local_vars["sit_sot_return_steps"]
    mit_sot_rightOrder = local_vars["mit_sot_rightOrder"]

    condition, outputs, updates = scan_utils.get_updates_and_outputs(
        fn_outputs)
    ################################################################## P2>
    # A non-None condition means the user requested a do-while style loop.
    if condition is not None:
        as_while = True
    else:
        as_while = False
    ##
    # Step 3. Check if we actually need scan and remove it if we don't
    ##
    if n_fixed_steps in [1, -1]:
        # We do not need to use the scan op anymore, so we can just return
        # the outputs and updates we have
        if condition is not None:
            # BUGFIX: the message and its argument used to be passed to
            # ``warning`` as one tuple, which logging rendered as a tuple
            # repr instead of a sentence.  Use lazy %-style arguments.
            _logger.warning('When the number of steps is fixed and equal '
                            'to 1, the provided stopping condition, %s, '
                            'is ignored', condition)

        for pos, inner_out in enumerate(outputs):
            # we need to see if we need to pad our sequences with an
            # unbroadcastable dimension; case example : we return an
            # output for which we want all intermediate. If n_steps is 1
            # then, if we return the output as given by the inner function
            # this will represent only a slice and it will have one
            # dimension less.
            if (isinstance(inner_out.type, tensor.TensorType) and
                    return_steps.get(pos, 0) != 1):
                outputs[pos] = tensor.unbroadcast(
                    tensor.shape_padleft(inner_out), 0)
        if len(outputs) == 1:
            outputs = outputs[0]

        return (outputs, updates)

    ##
    # Step 4. Compile the dummy function
    ##

    # We can now compile a dummy function just to see what shared
    # variables we have and what are their update rules (note that the
    # user has the option not to pass the shared variable to scan, so we
    # need to pick them manually and add them to scan).  Make the
    # compilation as fast as possible by not applying any optimization or
    # conversion to C [ note this region is not important for performance
    # so we can do stuff as unoptimal as we wish ].

    # extract still missing inputs (there still might be some) and add
    # them as non sequences at the end of our args
    fake_nonseqs = [x.type() for x in non_seqs]
    fake_outputs = scan_utils.clone(outputs,
                                    replace=OrderedDict(izip(non_seqs,
                                                             fake_nonseqs)))
    all_inputs = ifilter(
        lambda x: (isinstance(x, gof.Variable) and
                   not isinstance(x, SharedVariable) and
                   not isinstance(x, gof.Constant)),
        gof.graph.inputs(fake_outputs))
    extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
    non_seqs += extra_inputs
    # Note we do not use all_inputs directly since the order of variables
    # in args is quite important
    dummy_args += extra_inputs

    # NOTE(review): ``dummy_outs`` aliases ``outputs``, so appending the
    # condition also appends it to ``outputs``.  Later code only indexes
    # ``outputs[i]`` for i < n_outs, so this appears harmless; kept as-is.
    dummy_outs = outputs
    if condition is not None:
        dummy_outs.append(condition)
    # Pure-python linker, no optimizer: we only inspect the compiled
    # function's expanded inputs/updates, we never run it for real.
    dummy_f = function(dummy_args,
                       dummy_outs,
                       updates=updates,
                       mode=compile.mode.Mode(linker='py',
                                              optimizer=None),
                       on_unused_input='ignore',
                       profile=False)

    ##
    # Step 5. Re-arrange inputs of scan into a more strict order
    ##

    # Step 5.0 Check the outputs of the dummy function to see if they
    # match with user provided data

    # if the number of outputs to the function does not match the number
    # of assumed outputs until now (provided by the user) there can be
    # only one explanation: No information is provided for any of the
    # outputs (i.e. we are dealing with a map)
    tmp_dummy_f_outs = len(dummy_f.maker.outputs)
    if as_while:
        tmp_dummy_f_outs -= 1
    if not (tmp_dummy_f_outs == n_outs or outs_info == []):
        raise ValueError('Please provide None as outputs_info for '
                         'any output that does not feed back into '
                         'scan (i.e. it behaves like a map) ')

    if outs_info == []:
        n_outs = len(dummy_f.maker.outputs)
        if as_while:
            n_outs = n_outs - 1
        outs_info = [OrderedDict() for x in xrange(n_outs)]

    # Step 5.1 Outputs with taps different then -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] != [-1]:
            mit_sot_inner_outputs.append(outputs[i])

    # Step 5.2 Outputs with tap equal to -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] == [-1]:
            sit_sot_inner_outputs.append(outputs[i])

    # Step 5.3 Outputs that correspond to update rules of shared variables
    givens = OrderedDict()
    n_shared_outs = 0
    shared_scan_inputs = []
    shared_inner_inputs = []
    shared_inner_outputs = []
    sit_sot_shared = []
    for input in dummy_f.maker.expanded_inputs:
        if isinstance(input.variable, SharedVariable) and input.update:
            new_var = safe_new(input.variable)
            if getattr(input.variable, 'name', None) is not None:
                new_var.name = input.variable.name + '_copy'
            if isinstance(new_var.type, ops.expandable_types):
                # Expandable shared variables with updates are treated as
                # implicit sit-sot recurrences.
                sit_sot_inner_inputs.append(new_var)
                sit_sot_scan_inputs.append(
                    scan_utils.expand_empty(
                        tensor.unbroadcast(
                            tensor.shape_padleft(input.variable), 0),
                        actual_n_steps))
                tensor_update = tensor.as_tensor_variable(input.update)
                sit_sot_inner_outputs.append(tensor_update)
                # Note that `pos` here is a negative value whose sign is
                # used as a flag to indicate if this output should be part
                # of the update rules or part of the standard outputs of
                # scan.  If `pos` is positive it corresponds to the
                # standard outputs of scan and it refers to output of
                # index `pos`.  If `pos` is negative it corresponds to
                # update rules of scan and it refers to the update rule of
                # index -1 - `pos`.
                sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
                sit_sot_shared.append(input.variable)
                givens[input.variable] = new_var
            else:
                shared_inner_inputs.append(new_var)
                shared_scan_inputs.append(input.variable)
                shared_inner_outputs.append(input.update)
                givens[input.variable] = new_var
                n_shared_outs += 1
    n_sit_sot = len(sit_sot_inner_inputs)

    # Step 5.4 Outputs with no taps used in the input
    n_nit_sot = 0
    nit_sot_inner_outputs = []
    nit_sot_return_steps = OrderedDict()
    nit_sot_rightOrder = []
    for i, out in enumerate(outs_info):
        if 'taps' not in out:
            nit_sot_inner_outputs.append(outputs[i])
            if i in return_steps:
                nit_sot_return_steps[n_nit_sot] = return_steps[i]
            nit_sot_rightOrder.append(i)
            n_nit_sot += 1

    # Step 5.5 all other arguments including extra inputs
    other_scan_args = []
    other_inner_args = []

    other_scan_args += [arg for arg in non_seqs
                        if (not isinstance(arg, SharedVariable) and
                            not isinstance(arg, tensor.Constant))]

    # Step 5.6 all shared variables with no update rules
    other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs
                         if (not isinstance(arg, SharedVariable) and
                             not isinstance(arg, tensor.Constant))]

    givens.update(OrderedDict(izip(other_scan_args, other_inner_args)))

    if strict:
        # In strict mode only shared variables explicitly listed in
        # `non_sequences` may be used by the inner function.
        non_seqs_set = set(non_sequences if non_sequences is not None
                           else [])

        other_shared_scan_args = [
            arg.variable for arg in dummy_f.maker.expanded_inputs
            if (isinstance(arg.variable, SharedVariable) and
                not arg.update and
                arg.variable in non_seqs_set)]
        other_shared_inner_args = [
            safe_new(arg.variable, '_copy')
            for arg in dummy_f.maker.expanded_inputs
            if (isinstance(arg.variable, SharedVariable) and
                not arg.update and
                arg.variable in non_seqs_set)]
    else:
        other_shared_scan_args = [
            arg.variable for arg in dummy_f.maker.expanded_inputs
            if (isinstance(arg.variable, SharedVariable) and
                not arg.update)]
        other_shared_inner_args = [
            safe_new(arg.variable, '_copy')
            for arg in dummy_f.maker.expanded_inputs
            if (isinstance(arg.variable, SharedVariable) and
                not arg.update)]
    givens.update(OrderedDict(izip(other_shared_scan_args,
                                   other_shared_inner_args)))

    ##
    # Step 6. Re-order the outputs and clone them replacing things
    # using the givens
    ##
    inner_inputs = (inner_seqs +
                    mit_mot_inner_inputs +
                    mit_sot_inner_inputs +
                    sit_sot_inner_inputs +
                    shared_inner_inputs +
                    other_shared_inner_args +
                    other_inner_args)

    inner_outs = (mit_mot_inner_outputs +
                  mit_sot_inner_outputs +
                  sit_sot_inner_outputs +
                  nit_sot_inner_outputs +
                  shared_inner_outputs)
    if condition is not None:
        inner_outs.append(condition)
    # Cuda and Gpuarray are imported here, instead of being imported on
    # top of the file because that would force on the user some
    # dependencies that we might not want to.  Currently we are working
    # on removing the dependencies on sandbox code completely.
    from theano.sandbox import cuda, gpuarray
    if cuda.cuda_available or gpuarray.pygpu_activated:
        # very often we end up in this situation when we want to replace
        # w with w_copy, where w is a GPU variable and w_copy is
        # TensorType.  This is caused because shared variables are put on
        # GPU right away >:| ,
        new_givens = OrderedDict()

        for w, w_copy in iteritems(givens):
            if ((isinstance(w.type, cuda.CudaNdarrayType) or
                 isinstance(w.type, gpuarray.GpuArrayType)) and
                    isinstance(w_copy.type, tensor.TensorType)):
                for o in inner_outs:
                    new_givens = traverse(o, w, w_copy, new_givens)
            else:
                new_givens[w] = w_copy
    else:
        new_givens = givens

    new_outs = scan_utils.clone(inner_outs, replace=new_givens)

    ##
    # Step 7. Create the Scan Op
    ##
    # Every sit-sot is represented as a mit-sot with the single tap -1.
    tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
    if allow_gc is None:
        allow_gc = config.scan.allow_gc
    info = OrderedDict()

    info['tap_array'] = tap_array
    info['n_seqs'] = n_seqs
    info['n_mit_mot'] = n_mit_mot
    info['n_mit_mot_outs'] = n_mit_mot_outs
    info['mit_mot_out_slices'] = mit_mot_out_slices
    info['n_mit_sot'] = n_mit_sot
    info['n_sit_sot'] = n_sit_sot
    info['n_shared_outs'] = n_shared_outs
    info['n_nit_sot'] = n_nit_sot
    info['truncate_gradient'] = truncate_gradient
    info['name'] = name
    info['mode'] = mode
    info['destroy_map'] = OrderedDict()
    info['gpu'] = False
    info['as_while'] = as_while
    info['profile'] = profile
    info['allow_gc'] = allow_gc
    info['strict'] = strict

    local_op = scan_op.Scan(inner_inputs, new_outs, info)

    ##
    # Step 8. Compute the outputs using the scan op
    ##
    _scan_inputs = (scan_seqs +
                    mit_mot_scan_inputs +
                    mit_sot_scan_inputs +
                    sit_sot_scan_inputs +
                    shared_scan_inputs +
                    [actual_n_steps for x in xrange(n_nit_sot)] +
                    other_shared_scan_args +
                    other_scan_args)

    scan_inputs = []
    for arg in [actual_n_steps] + _scan_inputs:
        try:
            arg = tensor.as_tensor_variable(arg)
        except TypeError:
            # This happens for Random States for e.g. but it is a good
            # way to make sure no input is a cuda ndarray
            pass
        scan_inputs += [arg]
    scan_outs = local_op(*scan_inputs)
    if type(scan_outs) not in (list, tuple):
        scan_outs = [scan_outs]

    ##
    # Step 9. Figure out which outs are update rules for shared variables
    # and so on ...
    ##
    update_map = OrderedUpdates()

    def remove_dimensions(outs, steps_return, offsets=None):
        # Strip the leading time dimension (or keep only the requested
        # trailing steps) from each raw scan output.
        out_ls = []
        for idx, out in enumerate(outs):
            if idx in steps_return:
                if steps_return[idx] > 1:
                    out_ls.append(out[-steps_return[idx]:])
                else:
                    out_ls.append(out[-1])
            else:
                if offsets is None:
                    out_ls.append(out)
                else:
                    out_ls.append(out[offsets[idx]:])
        return out_ls

    offset = n_mit_mot
    # mit-sot outputs carry `|min(taps)|` initial entries to drop.
    offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
    mit_sot_outs = remove_dimensions(
        scan_outs[offset:offset + n_mit_sot],
        mit_sot_return_steps,
        offsets)

    offset += n_mit_sot
    # sit-sot outputs always carry exactly one initial entry.
    offsets = [1 for x in xrange(n_sit_sot)]
    sit_sot_outs = remove_dimensions(
        scan_outs[offset:offset + n_sit_sot],
        sit_sot_return_steps,
        offsets)

    offset += n_sit_sot
    nit_sot_outs = remove_dimensions(
        scan_outs[offset:offset + n_nit_sot],
        nit_sot_return_steps)

    offset += n_nit_sot
    for idx, update_rule in enumerate(
            scan_outs[offset:offset + n_shared_outs]):
        update_map[shared_scan_inputs[idx]] = update_rule

    _scan_out_list = (mit_sot_outs +
                      sit_sot_outs +
                      nit_sot_outs)

    # Step 10. I need to reorder the outputs to be in the order expected
    # by the user
    rightOrder = (mit_sot_rightOrder +
                  sit_sot_rightOrder +
                  nit_sot_rightOrder)
    scan_out_list = [None] * len(rightOrder)
    for idx, pos in enumerate(rightOrder):
        if pos >= 0:
            scan_out_list[pos] = _scan_out_list[idx]
        else:
            # The sign of `pos` is used as a flag to indicate if this
            # output should be part of the update rules or part of the
            # standard outputs of scan.  If `pos` is positive it
            # corresponds to the standard outputs of scan and it refers
            # to output of index `pos`.  If `pos` is negative it
            # corresponds to update rules of scan and it refers to the
            # update rule of index -1 - `pos`.
            update_map[sit_sot_shared[abs(pos) - 1]] = \
                _scan_out_list[idx][-1]
    scan_out_list = [x for x in scan_out_list if x is not None]
    ################################################################## P2<
    return (scan_out_list, update_map)
n_outs = n_outs - 1 outs_info = [OrderedDict() for x in xrange(n_outs)] # Step 5.1 Outputs with taps different then -1 for i, out in enumerate(outs_info): if 'taps' in out and out['taps'] != [-1]: mit_sot_inner_outputs.append(outputs[i]) # Step 5.2 Outputs with tap equal to -1 for i, out in enumerate(outs_info): if 'taps' in out and out['taps'] == [-1]: sit_sot_inner_outputs.append(outputs[i]) # Step 5.3 Outputs that correspond to update rules of shared variables givens = OrderedDict() n_shared_outs = 0 shared_scan_inputs = [] shared_inner_inputs = [] shared_inner_outputs = [] sit_sot_shared = [] for input in dummy_f.maker.expanded_inputs: if isinstance(input.variable, SharedVariable) and input.update: new_var = safe_new(input.variable) if getattr(input.variable, 'name', None) is not None: new_var.name = input.variable.name + '_copy' if isinstance(new_var.type, ops.expandable_types): sit_sot_inner_inputs.append(new_var) sit_sot_scan_inputs.append( scan_utils.expand( tensor.unbroadcast(