def is_neg(var):
    """
    Detect whether a variable is a negation, i.e. has the form `-x`.

    Two shapes are recognised: a direct application of `tensor.neg`, and
    a `tensor.mul` in which one of the inputs is a constant close to -1.

    :param var: The Variable to analyze.

    :return: `x` if `var` is of the form `-x`, or None otherwise.
    """
    node = var.owner
    if not node:
        # Variables without an owner are not the result of any op.
        return None

    # Case 1: explicit negation.
    if node.op == tensor.neg:
        return node.inputs[0]

    # Case 2: product with a constant -1 among its factors.
    if not (node.op == tensor.mul and len(node.inputs) >= 2):
        return None

    factors = node.inputs
    for position, factor in enumerate(factors):
        try:
            value = opt.get_constant_value(factor)
            minus_one = numpy.allclose(value, -1)
        except TypeError:
            minus_one = False
        if not minus_one:
            continue
        if len(factors) == 2:
            # Exactly one factor remains: hand it back untouched.
            return factors[1 - position]
        # Several factors remain: rebuild the product without the -1.
        return tensor.mul(*(factors[0:position] + factors[position + 1:]))

    # No factor was a constant -1.
    return None
def _is_1(expr):
    """rtype bool. True iff expr is a constant close to 1

    A `TypeError` raised while extracting or comparing the constant is
    treated as "not a constant 1".
    """
    try:
        close_to_one = numpy.allclose(opt.get_constant_value(expr), 1)
    except TypeError:
        close_to_one = False
    return close_to_one
def local_1msigmoid(node):
    """
    1-sigm(x) -> sigm(-x)

    :param node: The Apply node to inspect. Only nodes whose op is
        `tensor.sub` are considered.

    :return: A one-element list holding `sigmoid(-x)` when the node
        computes `c - sigmoid(x)` with `c` a constant whose sum is close
        to 1 and the sigmoid output has at most one client; None in
        every other case.
    """
    if node.op == tensor.sub:
        sub_l, sub_r = node.inputs
        if len(sub_r.clients) > 1:
            return  # graph is using both sigm and 1-sigm
        if sub_r.owner and sub_r.owner.op == sigmoid:
            try:
                val_l = opt.get_constant_value(sub_l)
            except Exception:
                # Deliberately best-effort: if no constant can be
                # extracted from the left operand, the pattern simply
                # does not apply.
                return
            if numpy.allclose(numpy.sum(val_l), 1):
                return [sigmoid(-sub_r.owner.inputs[0])]
    return None
def scan(fn,
         sequences=None,
         outputs_info=None,
         non_sequences=None,
         n_steps=None,
         truncate_gradient=-1,
         go_backwards=False,
         mode=None,
         name=None,
         profile=False):
    """
    This function constructs and applies a Scan op to the provided
    arguments.

    :param fn:
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. ``fn`` should construct variables describing the
        output of one iteration step. It should expect as input theano
        variables representing all the slices of the input sequences
        and previous values of the outputs, as well as all other arguments
        given to scan as ``non_sequences``. The order in which scan passes
        these variables to ``fn`` is the following :

        * all time slices of the first sequence
        * all time slices of the second sequence
        * ...
        * all time slices of the last sequence
        * all past slices of the first output
        * all past slices of the second output
        * ...
        * all past slices of the last output
        * all other arguments (the list given as `non_sequences` to
          scan)

        The order of the sequences is the same as the one in the list
        `sequences` given to scan. The order of the outputs is the same
        as the order of ``outputs_info``. For any sequence or output the
        order of the time slices is the same as the one in which they have
        been given as taps. For example if one writes the following :

        .. code-block:: python

            scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
                                 , Sequence2
                                 , dict(input = Sequence3, taps = 3) ]
                   , outputs_info = [ dict(initial = Output1, taps = [-3,-5])
                                    , dict(initial = Output2, taps = None)
                                    , Output3 ]
                   , non_sequences = [ Argument1, Argument2])

        ``fn`` should expect the following arguments in this given order:

        #. ``Sequence1[t-3]``
        #. ``Sequence1[t+2]``
        #. ``Sequence1[t-1]``
        #. ``Sequence2[t]``
        #. ``Sequence3[t+3]``
        #. ``Output1[t-3]``
        #. ``Output1[t-5]``
        #. ``Output3[t-1]``
        #. ``Argument1``
        #. ``Argument2``

        The list of ``non_sequences`` can also contain shared variables
        used in the function, though ``scan`` is able to figure those
        out on its own so they can be skipped. For the clarity of the
        code we recommend though to provide them to scan. To some extent
        ``scan`` can also figure out other ``non sequences`` (not shared)
        even if not passed to scan (but used by `fn`). A simple example of
        this would be :

        .. code-block:: python

            import theano.tensor as TT
            W   = TT.matrix()
            W_2 = W**2
            def f(x):
                return TT.dot(x,W_2)

        The function is expected to return two things. One is a list of
        outputs ordered in the same order as ``outputs_info``, with the
        difference that there should be only one output variable per
        output initial state (even if no tap value is used). Secondly
        `fn` should return an update dictionary (that tells how to
        update any shared variable after each iteration step). The
        dictionary can optionally be given as a list of tuples. There is
        no constraint on the order of these two lists, ``fn`` can return
        either ``(outputs_list, update_dictionary)`` or
        ``(update_dictionary, outputs_list)`` or just one of the two (in
        case the other is empty).

        To use ``scan`` as a while loop, the user needs to change the
        function ``fn`` such that also a stopping condition is returned.
        To do so, he/she needs to wrap the condition in an ``until`` class.
        The condition should be returned as a third element, for example:

        .. code-block:: python

            ...
            return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50)

        Note that a number of steps (considered in here as the maximum
        number of steps) is still required even though a condition is
        passed (and it is used to allocate memory if needed).

    :param sequences:
        ``sequences`` is the list of Theano variables or dictionaries
        describing the sequences ``scan`` has to iterate over. If a
        sequence is given as wrapped in a dictionary, then a set of optional
        information can be provided about the sequence. The dictionary
        should have the following keys:

        * ``input`` (*mandatory*) -- Theano variable representing the
          sequence.

        * ``taps`` -- Temporal taps of the sequence required by ``fn``.
          They are provided as a list of integers, where a value ``k``
          implies that at iteration step ``t`` scan will pass to ``fn``
          the slice ``t+k``. Default value is ``[0]``

        Any Theano variable in the list ``sequences`` is automatically
        wrapped into a dictionary where ``taps`` is set to ``[0]``

    :param outputs_info:
        ``outputs_info`` is the list of Theano variables or dictionaries
        describing the initial state of the outputs computed
        recurrently. When these initial states are given as dictionaries,
        optional information can be provided about the output corresponding
        to these initial states. The dictionary should have the following
        keys:

        * ``initial`` -- Theano variable that represents the initial
          state of a given output. In case the output is not computed
          recursively (think of a map) and does not require an initial
          state this field can be skipped. Given that only the previous
          time step of the output is used by ``fn`` the initial state
          should have the same shape as the output. If multiple time
          taps are used, the initial state should have one extra
          dimension that should cover all the possible taps. For example
          if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0,
          ``fn`` will require (by an abuse of notation) ``output[-5]``,
          ``output[-2]`` and ``output[-1]``. This will be given by
          the initial state, which in this case should have the shape
          (5,)+output.shape. If this variable containing the initial
          state is called ``init_y`` then ``init_y[0]`` *corresponds to*
          ``output[-5]``. ``init_y[1]`` *corresponds to* ``output[-4]``,
          ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]``
          corresponds to ``output[-2]``, ``init_y[4]`` corresponds to
          ``output[-1]``. While this order might seem strange, it comes
          natural from splitting an array at a given point. Assume that
          we have an array ``x``, and we choose ``k`` to be time step
          ``0``. Then our initial state would be ``x[:k]``, while the
          output will be ``x[k:]``. Looking at this split, elements in
          ``x[:k]`` are ordered exactly like those in ``init_y``.

        * ``taps`` -- Temporal taps of the output that will be passed to
          ``fn``. They are provided as a list of *negative* integers,
          where a value ``k`` implies that at iteration step ``t`` scan
          will pass to ``fn`` the slice ``t+k``.

        ``scan`` will follow this logic if partial information is given:

        * If an output is not wrapped in a dictionary, ``scan`` will wrap
          it in one assuming that you use only the last step of the output
          (i.e. it makes your tap value list equal to [-1]).
        * If you wrap an output in a dictionary and you do not provide any
          taps but you provide an initial state it will assume that you are
          using only a tap value of -1.
        * If you wrap an output in a dictionary but you do not provide any
          initial state, it assumes that you are not using any form of
          taps.
        * If you provide a ``None`` instead of a variable or an empty
          dictionary ``scan`` assumes that you will not use any taps for
          this output (like for example in case of a map)

        If ``outputs_info`` is an empty list or None, ``scan`` assumes
        that no tap is used for any of the outputs. If information is
        provided just for a subset of the outputs an exception is
        raised (because there is no convention on how scan should map
        the provided information to the outputs of ``fn``)

    :param non_sequences:
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each step. One can opt to exclude variables
        used in ``fn`` from this list as long as they are part of the
        computational graph, though for clarity we encourage not to do so.

    :param n_steps:
        ``n_steps`` is the number of steps to iterate given as an int
        or Theano scalar. If any of the input sequences do not have
        enough elements, scan will raise an error. If the *value is 0* the
        outputs will have *0 rows*. If the value is negative, ``scan``
        will run backwards in time. If the ``go_backwards`` flag is already
        set and also ``n_steps`` is negative, ``scan`` will run forward
        in time. If ``n_steps`` is not provided, ``scan`` will figure
        out the amount of steps it should run given its input sequences.

    :param truncate_gradient:
        ``truncate_gradient`` is the number of steps to use in truncated
        BPTT.  If you compute gradients through a scan op, they are
        computed using backpropagation through time. By providing a
        different value than -1, you choose to use truncated BPTT instead
        of classical BPTT, where you go for only ``truncate_gradient``
        number of steps back in time.

    :param go_backwards:
        ``go_backwards`` is a flag indicating if ``scan`` should go
        backwards through the sequences. If you think of each sequence
        as indexed by time, making this flag True would mean that
        ``scan`` goes back in time, namely that for any sequence it
        starts from the end and goes towards 0.

    :param name:
        When profiling ``scan``, it is crucial to provide a name for any
        instance of ``scan``. The profiler will produce an overall
        profile of your code as well as profiles for the computation of
        one step of each instance of ``scan``. The ``name`` of the instance
        appears in those profiles and can greatly help to disambiguate
        information.

    :param mode:
        It is recommended to leave this argument to None, especially
        when profiling ``scan`` (otherwise the results are not going to
        be accurate). If you prefer the computations of one step of
        ``scan`` to be done differently than the entire function, you
        can use this parameter to describe how the computations in this
        loop are done (see ``theano.function`` for details about
        possible values and their meaning).

    :param profile:
        Flag or string. If true, or different from the empty string, a
        profile object will be created and attached to the inner graph of
        scan. In case ``profile`` is True, the profile object will have the
        name of the scan instance, otherwise it will have the passed string.
        Profile objects collect (and print) information only when running
        the inner graph with the new cvm linker (with default modes,
        other linkers this argument is useless)

    :rtype: tuple
    :return: tuple of the form (outputs, updates); ``outputs`` is either a
             Theano variable or a list of Theano variables representing the
             outputs of ``scan`` (in the same order as in
             ``outputs_info``). ``updates`` is a subclass of dictionary
             specifying the
             update rules for all shared variables used in scan
             This dictionary should be passed to ``theano.function`` when
             you compile your function. The change compared to a normal
             dictionary is that we validate that keys are SharedVariable
             and additions of those dictionaries are validated to be
             consistent.
    """
    # General observation : this code is executed only once, at creation
    # of the computational graph, so we don't yet need to be smart about
    # anything (to speed things up)

    ##
    ### Step 1. Wrap all inputs in dictionaries and add default values
    ##

    # check if inputs are just single variables instead of lists
    def wrap_into_list(x):
        '''
        Wrap the input into a list if it is not already a list
        '''
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(outputs_info)

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(non_sequences):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_constant_value(n_steps)
        except (TypeError, AttributeError):
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
            str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = {}
    # wrap sequences in a dictionary if they are not already dictionaries
    for i, seq_info in enumerate(seqs):
        if not isinstance(seq_info, dict):
            seqs[i] = dict(input=seq_info, taps=[0])
        elif seq_info.get('taps', None) is not None:
            # Compare against None rather than relying on truthiness so
            # that a lone ``taps=0`` is still wrapped into ``[0]``.
            seq_info['taps'] = wrap_into_list(seq_info['taps'])
        else:
            # The dictionary has no ``taps`` key (or it is None): fall
            # back to the documented default of ``[0]``.
            seq_info['taps'] = [0]

    # wrap outputs info in a dictionary if they are not already in one
    for i, out_info in enumerate(outs_info):
        if out_info is not None:
            if isinstance(out_info, dict):
                # DEPRECATED :
                if out_info.get('return_steps', None):
                    raise ValueError(
                        "Using `return_steps` has been deprecated. "
                        "Simply select the entries you need using a "
                        "subtensor. Scan will optimize memory "
                        "consumption, so do not worry about that.")
                # END

            # The presence of an initial state is tested with ``is None``
            # instead of truthiness: a zero-valued scalar is a valid
            # initial state and an ndarray has no unambiguous truth value.
            if not isinstance(out_info, dict):
                # by default any output has a tap value of -1
                outs_info[i] = dict(initial=out_info, taps=[-1])
            elif (out_info.get('initial', None) is None and
                    out_info.get('taps', None)):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide a initial state '
                                  'for it'), out_info)
            elif (out_info.get('initial', None) is not None and
                    not out_info.get('taps', None)):
                # ^ initial state but taps not provided
                if 'taps' in out_info:
                    # ^ explicitly provided a None for taps
                    _logger.warning('Output %s ( index %d) has a initial '
                                    'state but taps is explicitly set to None ',
                                    getattr(out_info['initial'], 'name', 'None'),
                                    i)
                out_info['taps'] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an empty dict() to simplify handling
            outs_info[i] = dict()

    ##
    ### Step 2. Generate inputs and outputs of the inner functions
    ###         for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we used this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []     # Variables passed as inputs to the scan op
    inner_seqs = []    # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        # Note that you can have something like no taps for
        # a sequence, though is highly unlikely in practice
        if 'taps' in seq:
            # go through the indicated slice
            mintap = numpy.min(seq['taps'])
            maxtap = numpy.max(seq['taps'])
            for k in seq['taps']:
                # create one slice of the input
                # Later on, if we decide not to use scan because we are
                # going for just one step, it makes things easier if we
                # compute the correct outputs here. This way we can use
                # the output of the lambda expression directly to replace
                # the output of scan.

                # If not we need to use copies, that will be replaced at
                # each frame by the corresponding slice
                actual_slice = seq['input'][k - mintap]
                _seq_val = tensor.as_tensor_variable(seq['input'])
                _seq_val_slice = _seq_val[k - mintap]
                nw_slice = _seq_val_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != 'off':
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _seq_val_slice)
                    except AttributeError as e:
                        if config.compute_test_value != 'ignore':
                            # No need to print a warning or raise an error now,
                            # it will be done when fn will be called.
                            _logger.info(('Cannot compute test value for '
                                          'the inner function of scan, input value '
                                          'missing %s'), e)

                # Add names to slices for debugging and pretty printing ..
                # that is if the input already has a name
                if getattr(seq['input'], 'name', None) is not None:
                    if k > 0:
                        nw_name = seq['input'].name + '[t+%d]' % k
                    elif k == 0:
                        nw_name = seq['input'].name + '[t]'
                    else:
                        nw_name = seq['input'].name + '[t%d]' % k
                    nw_slice.name = nw_name

                # We cut the sequence such that seq[i] to correspond to
                # seq[i-k]
                if maxtap < 0:
                    offset = abs(maxtap)
                else:
                    offset = 0
                if maxtap == mintap and maxtap != 0:
                    # Single non-zero tap: drop |tap| elements from the
                    # appropriate end instead of keeping only the first
                    # |tap| elements of the sequence.
                    if maxtap < 0:
                        nw_seq = seq['input'][:maxtap]
                    else:
                        nw_seq = seq['input'][maxtap:]
                elif maxtap - k != 0:
                    nw_seq = seq['input'][offset + k - mintap:
                                          -(maxtap - k)]
                else:
                    nw_seq = seq['input'][offset + k - mintap:]
                if go_backwards:
                    nw_seq = nw_seq[::-1]

                scan_seqs.append(nw_seq)
                inner_seqs.append(nw_slice)
                inner_slices.append(actual_slice)
                n_seqs += 1
def scan(fn,
         sequences=None,
         states=None,
         params=None,
         n_steps=None,
         mode=None,
         name=None,
         profile=False):
    """
    Similar to Theano's official scan, this function gives the user more
    control over the scan op, avoiding certain difficulties that arose from
    missing optimizations.

    :param fn: lambda function that describes one step of scan (see the
        official Theano scan function)
    :param sequences: similar to the official Theano's scan. This version
        of scan does not support taps for the sequences (it can only be a
        list of tensors). Scan assumes that sequences have the right length
        and it does not check for this.
    :param states: similar to outputs_info of the official scan function.
        There is one crucial difference though, namely that the `initial`
        key in the dictionary has been replaced by the 'membuf' key. This
        reflects the change of meaning. Instead of passing to scan just
        the missing initial steps, one has now to pass a memory buffer in
        which scan will try to store its output. In this memory buffer the
        first entries should be set to the initial states of the
        corresponding states.
        Providing a memory buffer that has fewer entries than the number of
        steps means scan will only use that amount of memory. The user has
        to match the memory buffer size with the number of steps, otherwise
        scan will produce wrong results. Also if gradients are to be
        computed through the scan, the memory buffer should have the same
        length as the number of steps.
        For states that do not require an initial state, one has to provide
        a dictionary with a single key 'steps' that says how many
        intermediate results to store. See examples below for more insight.
    :param params: values handed to this function besides the sequences
        and the states; entries that are not Theano variables are
        converted with ``tensor.as_tensor_variable``.
    :param n_steps: This parameter is mandatory and it will represent the
        number of steps scan will do (scan will not check sequences or any
        other source of information to figure out how many steps it needs
        to do).
    :param mode: Same as for the official scan
    :param name: Same as for the official scan
    :param profile: Same as for the official scan

    Note:
     - there is no truncate / go_backwards anymore !
     - the outputs returned by scan contain the initial states as well
       (i.e. if I loop over k steps, with my smallest tap for an output
       -3 and keep all intermediate results, my output will be of length
       k+3)

    Examples:
        (a) if you do not want to store any intermediate results (just the
        last one)

        # The memory buffer can be the initial state, just that we need to
        # add one extra dimension in front of it
        state = TT.unbroadcast(TT.shape_padleft(x0),0)
        out,_ = scan(lambda x:x+1, states = state, n_steps = 5)
        # Once we got our result we need to remove the extra dimension
        out = out[0]

        (b) if you want to keep every intermediate results

        state = TT.alloc(TT.constant(0), 6, x0.shape[0])
        state = TT.set_subtensor(state[0], x0)
        out,_ = scan(lambda x:x+1, states = state, n_steps = 5)
        out = out[1:]

    """
    def wrap_into_list(x):
        '''
        Wrap the input into a list if it is not already a list
        '''
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(states)

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(params):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_constant_value(n_steps)
        except (TypeError, AttributeError):
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
            str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()

    # wrap outputs info in a dictionary if they are not already in one
    for i, state_info in enumerate(outs_info):
        if state_info is not None:
            # The presence of a memory buffer is tested with ``is None``
            # instead of truthiness: an ndarray buffer has no unambiguous
            # truth value.
            if not isinstance(state_info, dict):
                # by default any output has a tap value of -1
                outs_info[i] = dict(membuf=state_info, taps=[-1])
            elif (state_info.get('membuf', None) is None and
                    state_info.get('taps', None)):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide a memory buffer for '
                                  'the state '), state_info)
            elif (state_info.get('membuf', None) is not None and
                    not state_info.get('taps', None)):
                # ^ initial state but taps not provided
                if 'taps' in state_info:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        'Output %s (index %d) has a memory '
                        'buffer but taps is explicitly set to None ',
                        getattr(state_info['membuf'], 'name', 'None'),
                        i)
                state_info['taps'] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an dict(steps=n_steps) to simplify handling
            outs_info[i] = dict(steps=n_steps)

    ##
    ### Step 2. Generate inputs and outputs of the inner functions
    ###         for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we used this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []     # Variables passed as inputs to the scan op
    inner_seqs = []    # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        actual_slice = seq[0]
        _seq_val = tensor.as_tensor_variable(seq)
        _seq_val_slice = _seq_val[0]
        # The placeholder for the slice has to exist before a test value
        # can be attached to it.
        nw_slice = _seq_val_slice.type()

        # Try to transfer test_value to the new variable
        if config.compute_test_value != 'off':
            try:
                nw_slice.tag.test_value = gof.Op._get_test_value(
                    _seq_val_slice)
            except AttributeError as e:
                if config.compute_test_value != 'ignore':
                    # No need to print a warning or raise an error now,
                    # it will be done when fn will be called.
                    _logger.info(('Cannot compute test value for '
                                  'the inner function of scan, input value '
                                  'missing %s'), e)

        # Only name the slice when the sequence carries a name itself;
        # plain arrays or lists have no ``name`` attribute.
        seq_name = getattr(seq, 'name', None)
        if seq_name:
            nw_slice.name = seq_name + '[t]'

        scan_seqs.append(_seq_val)
        inner_seqs.append(nw_slice)
        inner_slices.append(actual_slice)

        n_seqs += 1
def scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, options=None, profile=False): """ This function constructs and applies a Scan op to the provided arguments. :param fn: ``fn`` is a function that describes the operations involved in one step of ``scan``. ``fn`` should construct variables describing the output of one iteration step. It should expect as input theano variables representing all the slices of the input sequences and previous values of the outputs, as well as all other arguments given to scan as ``non_sequences``. The order in which scan passes these variables to ``fn`` is the following : * all time slices of the first sequence * all time slices of the second sequence * ... * all time slices of the last sequence * all past slices of the first output * all past slices of the second otuput * ... * all past slices of the last output * all other arguments (the list given as `non_sequences` to scan) The order of the sequences is the same as the one in the list `sequences` given to scan. The order of the outputs is the same as the order of ``output_info``. For any sequence or output the order of the time slices is the same as the one in which they have been given as taps. For example if one writes the following : .. code-block:: python scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1]) , Sequence2 , dict(input = Sequence3, taps = 3) ] , outputs_info = [ dict(initial = Output1, taps = [-3,-5]) , dict(initial = Output2, taps = None) , Output3 ] , non_sequences = [ Argument1, Argument 2]) ``fn`` should expect the following arguments in this given order: #. ``Sequence1[t-3]`` #. ``Sequence1[t+2]`` #. ``Sequence1[t-1]`` #. ``Sequence2[t]`` #. ``Sequence3[t+3]`` #. ``Output1[t-3]`` #. ``Output1[t-5]`` #. ``Output3[t-1]`` #. ``Argument1`` #. 
``Argument2`` The list of ``non_sequences`` can also contain shared variables used in the function, though ``scan`` is able to figure those out on its own so they can be skipped. For the clarity of the code we recommand though to provide them to scan. To some extend ``scan`` can also figure out other ``non sequences`` (not shared) even if not passed to scan (but used by `fn`). A simple example of this would be : .. code-block:: python import theano.tensor as TT W = TT.matrix() W_2 = W**2 def f(x): return TT.dot(x,W_2) The function is expected to return two things. One is a list of outputs ordered in the same order as ``outputs_info``, with the difference that there should be only one output variable per output initial state (even if no tap value is used). Secondly `fn` should return an update dictionary (that tells how to update any shared variable after each iteration step). The dictionary can optionally be given as a list of tuples. There is no constraint on the order of these two list, ``fn`` can return either ``(outputs_list, update_dictionary)`` or ``(update_dictionary, outputs_list)`` or just one of the two (in case the other is empty). To use ``scan`` as a while loop, the user needs to change the function ``fn`` such that also a stopping condition is returned. To do so, he/she needs to wrap the condition in an ``until`` class. The condition should be returned as a third element, for example: .. code-block:: python ... return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50) Note that a number of steps (considered in here as the maximum number of steps ) is still required even though a condition is passed (and it is used to allocate memory if needed). = {}): :param sequences: ``sequences`` is the list of Theano variables or dictionaries describing the sequences ``scan`` has to iterate over. If a sequence is given as wrapped in a dictionary, then a set of optional information can be provided about the sequence. 
The dictionary should have the following keys: * ``input`` (*mandatory*) -- Theano variable representing the sequence. * ``taps`` -- Temporal taps of the sequence required by ``fn``. They are provided as a list of integers, where a value ``k`` impiles that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. Default value is ``[0]`` Any Theano variable in the list ``sequences`` is automatically wrapped into a dictionary where ``taps`` is set to ``[0]`` :param outputs_info: ``outputs_info`` is the list of Theano variables or dictionaries describing the initial state of the outputs computed recurrently. When this initial states are given as dictionary optional information can be provided about the output corresponding to these initial states. The dictionary should have the following keys: * ``initial`` -- Theano variable that represents the initial state of a given output. In case the output is not computed recursively (think of a map) and does not require a initial state this field can be skiped. Given that only the previous time step of the output is used by ``fn`` the initial state should have the same shape as the output. If multiple time taps are used, the initial state should have one extra dimension that should cover all the possible taps. For example if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0, ``fn`` will require (by an abuse of notation) ``output[-5]``, ``output[-2]`` and ``output[-1]``. This will be given by the initial state, which in this case should have the shape (5,)+output.shape. If this variable containing the initial state is called ``init_y`` then ``init_y[0]`` *corresponds to* ``output[-5]``. ``init_y[1]`` *correponds to* ``output[-4]``, ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]`` coresponds to ``output[-2]``, ``init_y[4]`` corresponds to ``output[-1]``. While this order might seem strange, it comes natural from splitting an array at a given point. 
Assume that we have a array ``x``, and we choose ``k`` to be time step ``0``. Then our initial state would be ``x[:k]``, while the output will be ``x[k:]``. Looking at this split, elements in ``x[:k]`` are ordered exactly like those in ``init_y``. * ``taps`` -- Temporal taps of the output that will be pass to ``fn``. They are provided as a list of *negative* integers, where a value ``k`` implies that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. ``scan`` will follow this logic if partial information is given: * If an output is not wrapped in a dictionary, ``scan`` will wrap it in one assuming that you use only the last step of the output (i.e. it makes your tap value list equal to [-1]). * If you wrap an output in a dictionary and you do not provide any taps but you provide an initial state it will assume that you are using only a tap value of -1. * If you wrap an output in a dictionary but you do not provide any initial state, it assumes that you are not using any form of taps. * If you provide a ``None`` instead of a variable or a empty dictionary ``scan`` assumes that you will not use any taps for this output (like for example in case of a map) If ``outputs_info`` is an empty list or None, ``scan`` assumes that no tap is used for any of the outputs. If information is provided just for a subset of the outputs an exception is raised (because there is no convention on how scan should map the provided information to the outputs of ``fn``) :param non_sequences: ``non_sequences`` is the list of arguments that are passed to ``fn`` at each steps. One can opt to exclude variable used in ``fn`` from this list as long as they are part of the computational graph, though for clarity we encourage not to do so. :param n_steps: ``n_steps`` is the number of steps to iterate given as an int or Theano scalar. If any of the input sequences do not have enough elements, scan will raise an error. If the *value is 0* the outputs will have *0 rows*. 
If the value is negative, ``scan`` will run backwards in time. If the ``go_backwards`` flag is already set and also ``n_steps`` is negative, ``scan`` will run forward in time. If n stpes is not provided, ``scan`` will figure out the amount of steps it should run given its input sequences. :param truncate_gradient: ``truncate_gradient`` is the number of steps to use in truncated BPTT. If you compute gradients through a scan op, they are computed using backpropagation through time. By providing a different value then -1, you choose to use truncated BPTT instead of classical BPTT, where you go for only ``truncate_gradient`` number of steps back in time. :param go_backwards: ``go_backwards`` is a flag indicating if ``scan`` should go backwards through the sequences. If you think of each sequence as indexed by time, making this flag True would mean that ``scan`` goes back in time, namely that for any sequence it starts from the end and goes towards 0. :param name: When profiling ``scan``, it is crucial to provide a name for any instance of ``scan``. The profiler will produce an overall profile of your code as well as profiles for the computation of one step of each instance of ``scan``. The ``name`` of the instance appears in those profiles and can greatly help to disambiguate information. :param mode: It is recommended to leave this argument to None, especially when profiling ``scan`` (otherwise the results are not going to be accurate). If you prefer the computations of one step of ``scan`` to be done differently then the entire function, you can use this parameter to describe how the computations in this loop are done (see ``theano.function`` for details about possible values and their meaning). :param profile: Flag or string. If true, or different from the empty string, a profile object will be created and attached to the inner graph of scan. 
In case ``profile`` is True, the profile object will have the name of the scan instance, otherwise it will have the passed string. Profile object collect (and print) information only when running the inner graph with the new cvm linker ( with default modes, other linkers this argument is useless) :rtype: tuple :return: tuple of the form (outputs, updates); ``outputs`` is either a Theano variable or a list of Theano variables representing the outputs of ``scan`` (in the same order as in ``outputs_info``). ``updates`` is a subclass of dictionary specifying the update rules for all shared variables used in scan This dictionary should be passed to ``theano.function`` when you compile your function. The change compared to a normal dictionary is that we validate that keys are SharedVariable and addition of those dictionary are validated to be consistent. """ # Note : see the internal documentation of the scan op for naming # conventions and all other details if options is None: options = {} rvals = scan_utils.canonical_arguments(sequences, outputs_info, non_sequences, go_backwards, n_steps) inputs, states_and_outputs_info, parameters, T = rvals # If we provided a known number of steps ( before compilation) # and if that number is 1 or -1, then we can skip the Scan Op, # and just apply the inner function once # To do that we check here to see the nature of n_steps T_value = None if isinstance(n_steps, (float, int)): T_value = int(n_steps) else: try: T_value = opt.get_constant_value(n_steps) except (TypeError, AttributeError): T_value = None if T_value in (1, -1): return one_step_scan(fn, inputs, states_and_outputs_info, parameters, truncate_gradient) # 1. Variable representing the current time step t = scalar_shared(numpy.int64(0), name='t') # 2. Allocate memory for the states of scan. 
mintaps = [] lengths = [] for pos, arg_info in enumerate(states_and_outputs_info): if arg_info.get('taps', None) == [-1]: mintaps.append(1) lengths.append(scalar_shared(numpy.int64(0), name='l%d' % pos)) arg_info['initial'] = scan_utils.expand(tensor.unbroadcast( tensor.shape_padleft(arg_info['initial']), 0), T) elif arg_info.get('taps', None): if numpy.any(numpy.array(arg_info.get('taps', [])) > 0): # Make sure we do not have requests for future values of a # sequence we can not provide such values raise ValueError('Can not use future taps of outputs', arg_info) mintap = abs(numpy.min(arg_info['taps'])) lengths.append(scalar_shared(numpy.int64(0), name='l%d' % pos)) mintaps.append(mintap) arg_info['initial'] = scan_utils.expand( arg_info['initial'][:mintap], T) else: mintaps.append(0) lengths.append(scalar_shared(numpy.int64(0), name='l%d' % pos)) # 3. Generate arguments for the function passed to scan. This will # function will return the outputs that need to be computed at every # timesteps inputs_slices = [input[t] for input in inputs] states_slices = [] for n, state in enumerate(states_and_outputs_info): # Check if it is actually a state and not an output if mintaps[n] != 0: for k in state['taps']: states_slices.append( state['initial'][(t + mintaps[n] + k) % lengths[n]]) # 4. Construct outputs that are to be computed by the inner # function of scan args = inputs_slices + states_slices + parameters cond, states_and_outputs, updates = \ scan_utils.get_updates_and_outputs(fn(*args)) # User is allowed to provide no information if it only behaves like a # map if (len(states_and_outputs) != len(states_and_outputs_info) and len(states_and_outputs_info) == 0): mintaps = [0] * len(states_and_outputs) # 5. Construct the scan op # 5.1 Construct list of shared variables with updates (those that # can be treated as states (i.e. 
of TensorType) and those that can not # (like Random States) if cond is not None: _cond = [cond] else: _cond = [] rvals = rebuild_collect_shared( states_and_outputs + _cond, updates=updates, rebuild_strict=True, copy_inputs_over=True, no_default_updates=False) # extracting the arguments input_variables, cloned_outputs, other_rval = rvals clone_d, update_d, update_expr, shared_inputs = other_rval additional_input_states = [] additional_output_states = [] additional_lengths = [] additional_mintaps = [] original_numeric_shared_variables = [] non_numeric_input_states = [] non_numeric_output_states = [] original_non_numeric_shared_variables = [] pos = len(lengths) for sv in shared_inputs: if sv in update_d: if isinstance(sv, (TensorVariable, TensorSharedVariable)): # We can treat it as a sit sot nw_state = scan_utils.expand( tensor.unbroadcast(tensor.shape_padleft(sv), 0), T) additional_lengths.append(scalar_shared(numpy.int64(0), name='l%d' % pos)) pos = pos + 1 additional_mintaps.append(1) additional_input_states.append(nw_state) additional_output_states.append( scan_utils.clone(tensor.set_subtensor( nw_state[(t + 1) % additional_lengths[-1]], update_d[sv]))) original_numeric_shared_variables.append(sv) else: non_numeric_input_states.append(sv) non_numeric_output_states.append(update_d[sv]) original_non_numeric_shared_variables.append(sv) # Replace shared variables in the update _additional_output_states = [] replace = {} for sv, buf in zip(original_numeric_shared_variables, additional_input_states): replace[sv] = buf[t] for out in additional_output_states: _additional_output_states.append( scan_utils.clone(out, replace=replace)) additional_output_states = _additional_output_states # 5.2 Collect inputs/outputs of the inner function inputs = [] outputs = [] for n, mintap in enumerate(mintaps): if mintap != 0: input_state = states_and_outputs_info[n]['initial'] inputs.append(input_state) outputs.append( tensor.set_subtensor( input_state[(t + mintap) % lengths[n]], 
states_and_outputs[n])) else: mem_buffer = scan_utils.allocate_memory( T, states_and_outputs_info[n], states_and_outputs[n]) inputs.append(output) outputs.append( tensor.set_subtensor(output[t % lengths[n]], states_and_outputs[n])) inputs.extend(additional_input_states) outputs.extend(additional_output_states) lengths.extend(additional_lengths) mintaps.extend(additional_mintaps) inputs.extend(non_numeric_input_states) outputs.extend(non_numeric_output_states) all_other_inputs = gof.graph.inputs(outputs) parameters = [x for x in all_other_inputs if (x not in inputs and x not in lengths and x is not t and isinstance(x, gof.Variable) and not isinstance(x, gof.Constant))] inputs.extend(parameters) # 5.3 Construct the the options dictionary options['name'] = name options['profile'] = profile options['mode'] = mode options['inplace'] = False options['gpu'] = False options['truncate_gradient'] = truncate_gradient options['hash_inner_graph'] = 0 # 5.4 Construct the ScanOp instance local_op = scan_op.ScanOp(inputs=inputs, outputs=outputs, lengths=lengths, switches=[], mintaps=mintaps, index=t, options=options, as_repeatUntil=cond) # Note that we get here all the outputs followed by the update rules to # the shared variables we had in our scan # we know that we have (in this given order): # * len(states_and_outputs) real outputs # * len(additional_input_states) updates for numeric shared variable # * len(non_numeric_input_states) updates for non numeric shared # variables scan_inputs = [T] + inputs scan_outputs_update_rules = scan_utils.to_list(local_op(*scan_inputs)) # 5.5 Collect outputs and add permutation object scan_outputs = [] for pos in xrange(len(states_and_outputs)): out = scan_utils.ScanPermutation(mintaps[pos])( scan_outputs_update_rules[pos], t) scan_outputs.append(out[mintaps[pos]:]) # 5.6 Construct updates dictionary update_rules = scan_outputs_update_rules[len(states_and_outputs):] updates = {} for v, u in izip(original_numeric_shared_variables, 
update_rules[:len(additional_input_states)]): updates[v] = u[-1] for v, u in izip(original_non_numeric_shared_variables, update_rules[len(additional_input_states):]): updates[v] = u # Step 5.7 We are done and can return everything back to the user return scan_outputs, updates
def scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False): """ This function constructs and applies a Scan op to the provided arguments. :param fn: ``fn`` is a function that describes the operations involved in one step of ``scan``. ``fn`` should construct variables describing the output of one iteration step. It should expect as input theano variables representing all the slices of the input sequences and previous values of the outputs, as well as all other arguments given to scan as ``non_sequences``. The order in which scan passes these variables to ``fn`` is the following : * all time slices of the first sequence * all time slices of the second sequence * ... * all time slices of the last sequence * all past slices of the first output * all past slices of the second otuput * ... * all past slices of the last output * all other arguments (the list given as `non_sequences` to scan) The order of the sequences is the same as the one in the list `sequences` given to scan. The order of the outputs is the same as the order of ``output_info``. For any sequence or output the order of the time slices is the same as the one in which they have been given as taps. For example if one writes the following : .. code-block:: python scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1]) , Sequence2 , dict(input = Sequence3, taps = 3) ] , outputs_info = [ dict(initial = Output1, taps = [-3,-5]) , dict(initial = Output2, taps = None) , Output3 ] , non_sequences = [ Argument1, Argument 2]) ``fn`` should expect the following arguments in this given order: #. ``Sequence1[t-3]`` #. ``Sequence1[t+2]`` #. ``Sequence1[t-1]`` #. ``Sequence2[t]`` #. ``Sequence3[t+3]`` #. ``Output1[t-3]`` #. ``Output1[t-5]`` #. ``Output3[t-1]`` #. ``Argument1`` #. 
``Argument2`` The list of ``non_sequences`` can also contain shared variables used in the function, though ``scan`` is able to figure those out on its own so they can be skipped. For the clarity of the code we recommand though to provide them to scan. To some extend ``scan`` can also figure out other ``non sequences`` (not shared) even if not passed to scan (but used by `fn`). A simple example of this would be : .. code-block:: python import theano.tensor as TT W = TT.matrix() W_2 = W**2 def f(x): return TT.dot(x,W_2) The function is expected to return two things. One is a list of outputs ordered in the same order as ``outputs_info``, with the difference that there should be only one output variable per output initial state (even if no tap value is used). Secondly `fn` should return an update dictionary (that tells how to update any shared variable after each iteration step). The dictionary can optionally be given as a list of tuples. There is no constraint on the order of these two list, ``fn`` can return either ``(outputs_list, update_dictionary)`` or ``(update_dictionary, outputs_list)`` or just one of the two (in case the other is empty). To use ``scan`` as a while loop, the user needs to change the function ``fn`` such that also a stopping condition is returned. To do so, he/she needs to wrap the condition in an ``until`` class. The condition should be returned as a third element, for example: .. code-block:: python ... return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50) Note that a number of steps (considered in here as the maximum number of steps ) is still required even though a condition is passed (and it is used to allocate memory if needed). = {}): :param sequences: ``sequences`` is the list of Theano variables or dictionaries describing the sequences ``scan`` has to iterate over. If a sequence is given as wrapped in a dictionary, then a set of optional information can be provided about the sequence. 
The dictionary should have the following keys: * ``input`` (*mandatory*) -- Theano variable representing the sequence. * ``taps`` -- Temporal taps of the sequence required by ``fn``. They are provided as a list of integers, where a value ``k`` impiles that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. Default value is ``[0]`` Any Theano variable in the list ``sequences`` is automatically wrapped into a dictionary where ``taps`` is set to ``[0]`` :param outputs_info: ``outputs_info`` is the list of Theano variables or dictionaries describing the initial state of the outputs computed recurrently. When this initial states are given as dictionary optional information can be provided about the output corresponding to these initial states. The dictionary should have the following keys: * ``initial`` -- Theano variable that represents the initial state of a given output. In case the output is not computed recursively (think of a map) and does not require a initial state this field can be skiped. Given that only the previous time step of the output is used by ``fn`` the initial state should have the same shape as the output. If multiple time taps are used, the initial state should have one extra dimension that should cover all the possible taps. For example if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0, ``fn`` will require (by an abuse of notation) ``output[-5]``, ``output[-2]`` and ``output[-1]``. This will be given by the initial state, which in this case should have the shape (5,)+output.shape. If this variable containing the initial state is called ``init_y`` then ``init_y[0]`` *corresponds to* ``output[-5]``. ``init_y[1]`` *correponds to* ``output[-4]``, ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]`` coresponds to ``output[-2]``, ``init_y[4]`` corresponds to ``output[-1]``. While this order might seem strange, it comes natural from splitting an array at a given point. 
Assume that we have a array ``x``, and we choose ``k`` to be time step ``0``. Then our initial state would be ``x[:k]``, while the output will be ``x[k:]``. Looking at this split, elements in ``x[:k]`` are ordered exactly like those in ``init_y``. * ``taps`` -- Temporal taps of the output that will be pass to ``fn``. They are provided as a list of *negative* integers, where a value ``k`` implies that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. ``scan`` will follow this logic if partial information is given: * If an output is not wrapped in a dictionary, ``scan`` will wrap it in one assuming that you use only the last step of the output (i.e. it makes your tap value list equal to [-1]). * If you wrap an output in a dictionary and you do not provide any taps but you provide an initial state it will assume that you are using only a tap value of -1. * If you wrap an output in a dictionary but you do not provide any initial state, it assumes that you are not using any form of taps. * If you provide a ``None`` instead of a variable or a empty dictionary ``scan`` assumes that you will not use any taps for this output (like for example in case of a map) If ``outputs_info`` is an empty list or None, ``scan`` assumes that no tap is used for any of the outputs. If information is provided just for a subset of the outputs an exception is raised (because there is no convention on how scan should map the provided information to the outputs of ``fn``) :param non_sequences: ``non_sequences`` is the list of arguments that are passed to ``fn`` at each steps. One can opt to exclude variable used in ``fn`` from this list as long as they are part of the computational graph, though for clarity we encourage not to do so. :param n_steps: ``n_steps`` is the number of steps to iterate given as an int or Theano scalar. If any of the input sequences do not have enough elements, scan will raise an error. If the *value is 0* the outputs will have *0 rows*. 
If the value is negative, ``scan`` will run backwards in time. If the ``go_backwards`` flag is already set and also ``n_steps`` is negative, ``scan`` will run forward in time. If n stpes is not provided, ``scan`` will figure out the amount of steps it should run given its input sequences. :param truncate_gradient: ``truncate_gradient`` is the number of steps to use in truncated BPTT. If you compute gradients through a scan op, they are computed using backpropagation through time. By providing a different value then -1, you choose to use truncated BPTT instead of classical BPTT, where you go for only ``truncate_gradient`` number of steps back in time. :param go_backwards: ``go_backwards`` is a flag indicating if ``scan`` should go backwards through the sequences. If you think of each sequence as indexed by time, making this flag True would mean that ``scan`` goes back in time, namely that for any sequence it starts from the end and goes towards 0. :param name: When profiling ``scan``, it is crucial to provide a name for any instance of ``scan``. The profiler will produce an overall profile of your code as well as profiles for the computation of one step of each instance of ``scan``. The ``name`` of the instance appears in those profiles and can greatly help to disambiguate information. :param mode: It is recommended to leave this argument to None, especially when profiling ``scan`` (otherwise the results are not going to be accurate). If you prefer the computations of one step of ``scan`` to be done differently then the entire function, you can use this parameter to describe how the computations in this loop are done (see ``theano.function`` for details about possible values and their meaning). :param profile: Flag or string. If true, or different from the empty string, a profile object will be created and attached to the inner graph of scan. 
In case ``profile`` is True, the profile object will have the name of the scan instance, otherwise it will have the passed string. Profile object collect (and print) information only when running the inner graph with the new cvm linker ( with default modes, other linkers this argument is useless) :rtype: tuple :return: tuple of the form (outputs, updates); ``outputs`` is either a Theano variable or a list of Theano variables representing the outputs of ``scan`` (in the same order as in ``outputs_info``). ``updates`` is a subclass of dictionary specifying the update rules for all shared variables used in scan This dictionary should be passed to ``theano.function`` when you compile your function. The change compared to a normal dictionary is that we validate that keys are SharedVariable and addition of those dictionary are validated to be consistent. """ # General observation : this code is executed only once, at creation # of the computational graph, so we don't yet need to be smart about # anything (to speed things up) ## ### Step 1. 
Wrap all inputs in dictionaries and add default values ## # check if inputs are just single variables instead of lists def wrap_into_list(x): ''' Wrap the input into a list if it is not already a list ''' if x is None: return [] elif not isinstance(x, (list, tuple)): return [x] else: return list(x) seqs = wrap_into_list(sequences) outs_info = wrap_into_list(outputs_info) # Make sure we get rid of numpy arrays or ints or anything like that # passed as inputs to scan non_seqs = [] for elem in wrap_into_list(non_sequences): if not isinstance(elem, gof.Variable): non_seqs.append(tensor.as_tensor_variable(elem)) else: non_seqs.append(elem) # If we provided a known number of steps ( before compilation) # and if that number is 1 or -1, then we can skip the Scan Op, # and just apply the inner function once # To do that we check here to see the nature of n_steps n_fixed_steps = None if isinstance(n_steps, (float, int)): n_fixed_steps = int(n_steps) else: try: n_fixed_steps = opt.get_constant_value(n_steps) except (TypeError, AttributeError): n_fixed_steps = None # Check n_steps is an int if (hasattr(n_steps, 'dtype') and str(n_steps.dtype)[:3] not in ('uin', 'int')): raise ValueError(' n_steps must be an int. 
dtype provided ' 'is %s' % n_steps.dtype) # compute number of sequences and number of outputs n_seqs = len(seqs) n_outs = len(outs_info) return_steps = {} # wrap sequences in a dictionary if they are not already dictionaries for i in xrange(n_seqs): if not isinstance(seqs[i], dict): seqs[i] = dict(input=seqs[i], taps=[0]) elif seqs[i].get('taps', None): seqs[i]['taps'] = wrap_into_list(seqs[i]['taps']) elif seqs[i].get('taps', True) is None: # seqs dictionary does not have the ``taps`` key seqs[i]['taps'] = [0] # wrap outputs info in a dictionary if they are not already in one for i in xrange(n_outs): if outs_info[i] is not None: if isinstance(outs_info[i], dict): # DEPRECATED : if outs_info[i].get('return_steps', None): raise ValueError( "Using `return_steps` has been deprecated. " "Simply select the entries you need using a " "subtensor. Scan will optimize memory " "consumption, so do not worry about that.") # END if not isinstance(outs_info[i], dict): # by default any output has a tap value of -1 outs_info[i] = dict(initial=outs_info[i], taps=[-1]) elif (not outs_info[i].get('initial', None) and outs_info[i].get('taps', None)): # ^ no initial state but taps provided raise ValueError(('If you are using slices of an output ' 'you need to provide a initial state ' 'for it'), outs_info[i]) elif (outs_info[i].get('initial', None) and not outs_info[i].get('taps', None)): # ^ initial state but taps not provided if 'taps' in outs_info[i]: # ^ explicitly provided a None for taps _logger.warning( 'Output %s ( index %d) has a initial ' 'state but taps is explicitly set to None ', getattr(outs_info[i]['initial'], 'name', 'None'), i) outs_info[i]['taps'] = [-1] else: # if a None is provided as the output info we replace it # with an empty dict() to simplify handling outs_info[i] = dict() ## ### Step 2. 
Generate inputs and outputs of the inner functions ### for compiling a dummy function (Iteration #1) ## # create theano inputs for the recursive function # note : this is a first batch of possible inputs that will # be compiled in a dummy function; we used this dummy # function to detect shared variables and their updates # and to construct a new and complete list of inputs and # outputs n_seqs = 0 scan_seqs = [] # Variables passed as inputs to the scan op inner_seqs = [] # Variables passed as inputs to the inner function inner_slices = [] # Actual slices if scan is removed from the picture # go through sequences picking up time slices as needed for i, seq in enumerate(seqs): # Note that you can have something like no taps for # a sequence, though is highly unlikely in practice if 'taps' in seq: # go through the indicated slice mintap = numpy.min(seq['taps']) maxtap = numpy.max(seq['taps']) for k in seq['taps']: # create one slice of the input # Later on, if we decide not to use scan because we are # going for just one step, it makes things easier if we # compute the correct outputs here. This way we can use # the output of the lambda expression directly to replace # the output of scan. # If not we need to use copies, that will be replaced at # each frame by the corresponding slice actual_slice = seq['input'][k - mintap] _seq_val = tensor.as_tensor_variable(seq['input']) _seq_val_slice = _seq_val[k - mintap] nw_slice = _seq_val_slice.type() # Try to transfer test_value to the new variable if config.compute_test_value != 'off': try: nw_slice.tag.test_value = gof.Op._get_test_value( _seq_val_slice) except AttributeError, e: if config.compute_test_value != 'ignore': # No need to print a warning or raise an error now, # it will be done when fn will be called. _logger.info( ('Cannot compute test value for ' 'the inner function of scan, input value ' 'missing %s'), e) # Add names to slices for debugging and pretty printing .. 
# that is if the input already has a name if getattr(seq['input'], 'name', None) is not None: if k > 0: nw_name = seq['input'].name + '[t+%d]' % k elif k == 0: nw_name = seq['input'].name + '[t]' else: nw_name = seq['input'].name + '[t%d]' % k nw_slice.name = nw_name # We cut the sequence such that seq[i] to correspond to # seq[i-k] if maxtap < 0: offset = abs(maxtap) else: offset = 0 if maxtap == mintap and maxtap != 0: nw_seq = seq['input'][:abs(maxtap)] elif maxtap - k != 0: nw_seq = seq['input'][offset + k - mintap:-(maxtap - k)] else: nw_seq = seq['input'][offset + k - mintap:] if go_backwards: nw_seq = nw_seq[::-1] scan_seqs.append(nw_seq) inner_seqs.append(nw_slice) inner_slices.append(actual_slice) n_seqs += 1
def scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, options=None, profile=False): """ This function constructs and applies a Scan op to the provided arguments. :param fn: ``fn`` is a function that describes the operations involved in one step of ``scan``. ``fn`` should construct variables describing the output of one iteration step. It should expect as input theano variables representing all the slices of the input sequences and previous values of the outputs, as well as all other arguments given to scan as ``non_sequences``. The order in which scan passes these variables to ``fn`` is the following : * all time slices of the first sequence * all time slices of the second sequence * ... * all time slices of the last sequence * all past slices of the first output * all past slices of the second otuput * ... * all past slices of the last output * all other arguments (the list given as `non_sequences` to scan) The order of the sequences is the same as the one in the list `sequences` given to scan. The order of the outputs is the same as the order of ``output_info``. For any sequence or output the order of the time slices is the same as the one in which they have been given as taps. For example if one writes the following : .. code-block:: python scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1]) , Sequence2 , dict(input = Sequence3, taps = 3) ] , outputs_info = [ dict(initial = Output1, taps = [-3,-5]) , dict(initial = Output2, taps = None) , Output3 ] , non_sequences = [ Argument1, Argument 2]) ``fn`` should expect the following arguments in this given order: #. ``Sequence1[t-3]`` #. ``Sequence1[t+2]`` #. ``Sequence1[t-1]`` #. ``Sequence2[t]`` #. ``Sequence3[t+3]`` #. ``Output1[t-3]`` #. ``Output1[t-5]`` #. ``Output3[t-1]`` #. ``Argument1`` #. 
``Argument2`` The list of ``non_sequences`` can also contain shared variables used in the function, though ``scan`` is able to figure those out on its own so they can be skipped. For the clarity of the code we recommand though to provide them to scan. To some extend ``scan`` can also figure out other ``non sequences`` (not shared) even if not passed to scan (but used by `fn`). A simple example of this would be : .. code-block:: python import theano.tensor as TT W = TT.matrix() W_2 = W**2 def f(x): return TT.dot(x,W_2) The function is expected to return two things. One is a list of outputs ordered in the same order as ``outputs_info``, with the difference that there should be only one output variable per output initial state (even if no tap value is used). Secondly `fn` should return an update dictionary (that tells how to update any shared variable after each iteration step). The dictionary can optionally be given as a list of tuples. There is no constraint on the order of these two list, ``fn`` can return either ``(outputs_list, update_dictionary)`` or ``(update_dictionary, outputs_list)`` or just one of the two (in case the other is empty). To use ``scan`` as a while loop, the user needs to change the function ``fn`` such that also a stopping condition is returned. To do so, he/she needs to wrap the condition in an ``until`` class. The condition should be returned as a third element, for example: .. code-block:: python ... return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50) Note that a number of steps (considered in here as the maximum number of steps ) is still required even though a condition is passed (and it is used to allocate memory if needed). = {}): :param sequences: ``sequences`` is the list of Theano variables or dictionaries describing the sequences ``scan`` has to iterate over. If a sequence is given as wrapped in a dictionary, then a set of optional information can be provided about the sequence. 
The dictionary should have the following keys: * ``input`` (*mandatory*) -- Theano variable representing the sequence. * ``taps`` -- Temporal taps of the sequence required by ``fn``. They are provided as a list of integers, where a value ``k`` impiles that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. Default value is ``[0]`` Any Theano variable in the list ``sequences`` is automatically wrapped into a dictionary where ``taps`` is set to ``[0]`` :param outputs_info: ``outputs_info`` is the list of Theano variables or dictionaries describing the initial state of the outputs computed recurrently. When this initial states are given as dictionary optional information can be provided about the output corresponding to these initial states. The dictionary should have the following keys: * ``initial`` -- Theano variable that represents the initial state of a given output. In case the output is not computed recursively (think of a map) and does not require a initial state this field can be skiped. Given that only the previous time step of the output is used by ``fn`` the initial state should have the same shape as the output. If multiple time taps are used, the initial state should have one extra dimension that should cover all the possible taps. For example if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0, ``fn`` will require (by an abuse of notation) ``output[-5]``, ``output[-2]`` and ``output[-1]``. This will be given by the initial state, which in this case should have the shape (5,)+output.shape. If this variable containing the initial state is called ``init_y`` then ``init_y[0]`` *corresponds to* ``output[-5]``. ``init_y[1]`` *correponds to* ``output[-4]``, ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]`` coresponds to ``output[-2]``, ``init_y[4]`` corresponds to ``output[-1]``. While this order might seem strange, it comes natural from splitting an array at a given point. 
Assume that we have a array ``x``, and we choose ``k`` to be time step ``0``. Then our initial state would be ``x[:k]``, while the output will be ``x[k:]``. Looking at this split, elements in ``x[:k]`` are ordered exactly like those in ``init_y``. * ``taps`` -- Temporal taps of the output that will be pass to ``fn``. They are provided as a list of *negative* integers, where a value ``k`` implies that at iteration step ``t`` scan will pass to ``fn`` the slice ``t+k``. ``scan`` will follow this logic if partial information is given: * If an output is not wrapped in a dictionary, ``scan`` will wrap it in one assuming that you use only the last step of the output (i.e. it makes your tap value list equal to [-1]). * If you wrap an output in a dictionary and you do not provide any taps but you provide an initial state it will assume that you are using only a tap value of -1. * If you wrap an output in a dictionary but you do not provide any initial state, it assumes that you are not using any form of taps. * If you provide a ``None`` instead of a variable or a empty dictionary ``scan`` assumes that you will not use any taps for this output (like for example in case of a map) If ``outputs_info`` is an empty list or None, ``scan`` assumes that no tap is used for any of the outputs. If information is provided just for a subset of the outputs an exception is raised (because there is no convention on how scan should map the provided information to the outputs of ``fn``) :param non_sequences: ``non_sequences`` is the list of arguments that are passed to ``fn`` at each steps. One can opt to exclude variable used in ``fn`` from this list as long as they are part of the computational graph, though for clarity we encourage not to do so. :param n_steps: ``n_steps`` is the number of steps to iterate given as an int or Theano scalar. If any of the input sequences do not have enough elements, scan will raise an error. If the *value is 0* the outputs will have *0 rows*. 
If the value is negative, ``scan`` will run backwards in time. If the ``go_backwards`` flag is already set and also ``n_steps`` is negative, ``scan`` will run forward in time. If ``n_steps`` is not provided, ``scan`` will figure out the number of steps it should run given its input sequences. :param truncate_gradient: ``truncate_gradient`` is the number of steps to use in truncated BPTT. If you compute gradients through a scan op, they are computed using backpropagation through time. By providing a different value than -1, you choose to use truncated BPTT instead of classical BPTT, where you go for only ``truncate_gradient`` number of steps back in time. :param go_backwards: ``go_backwards`` is a flag indicating if ``scan`` should go backwards through the sequences. If you think of each sequence as indexed by time, making this flag True would mean that ``scan`` goes back in time, namely that for any sequence it starts from the end and goes towards 0. :param name: When profiling ``scan``, it is crucial to provide a name for any instance of ``scan``. The profiler will produce an overall profile of your code as well as profiles for the computation of one step of each instance of ``scan``. The ``name`` of the instance appears in those profiles and can greatly help to disambiguate information. :param mode: It is recommended to leave this argument to None, especially when profiling ``scan`` (otherwise the results are not going to be accurate). If you prefer the computations of one step of ``scan`` to be done differently than the entire function, you can use this parameter to describe how the computations in this loop are done (see ``theano.function`` for details about possible values and their meaning). :param profile: Flag or string. If true, or different from the empty string, a profile object will be created and attached to the inner graph of scan. 
In case ``profile`` is True, the profile object will have the name of the scan instance, otherwise it will have the passed string. Profile object collect (and print) information only when running the inner graph with the new cvm linker ( with default modes, other linkers this argument is useless) :rtype: tuple :return: tuple of the form (outputs, updates); ``outputs`` is either a Theano variable or a list of Theano variables representing the outputs of ``scan`` (in the same order as in ``outputs_info``). ``updates`` is a subclass of dictionary specifying the update rules for all shared variables used in scan This dictionary should be passed to ``theano.function`` when you compile your function. The change compared to a normal dictionary is that we validate that keys are SharedVariable and addition of those dictionary are validated to be consistent. """ # Note : see the internal documentation of the scan op for naming # conventions and all other details if options is None: options = {} rvals = scan_utils.canonical_arguments(sequences, outputs_info, non_sequences, go_backwards, n_steps) inputs, states_and_outputs_info, parameters, T = rvals # If we provided a known number of steps ( before compilation) # and if that number is 1 or -1, then we can skip the Scan Op, # and just apply the inner function once # To do that we check here to see the nature of n_steps T_value = None if isinstance(n_steps, (float, int)): T_value = int(n_steps) else: try: T_value = opt.get_constant_value(n_steps) except (TypeError, AttributeError): T_value = None if T_value in (1, -1): return one_step_scan(fn, inputs, states_and_outputs_info, parameters, truncate_gradient) # 1. Variable representing the current time step t = scalar_shared(numpy.int64(0), name='t') # 2. Allocate memory for the states of scan. 
mintaps = [] lengths = [] for pos, arg_info in enumerate(states_and_outputs_info): if arg_info.get('taps', None) == [-1]: mintaps.append(1) lengths.append(scalar_shared(numpy.int64(0), name='l%d' % pos)) arg_info['initial'] = scan_utils.expand( tensor.unbroadcast(tensor.shape_padleft(arg_info['initial']), 0), T) elif arg_info.get('taps', None): if numpy.any(numpy.array(arg_info.get('taps', [])) > 0): # Make sure we do not have requests for future values of a # sequence we can not provide such values raise ValueError('Can not use future taps of outputs', arg_info) mintap = abs(numpy.min(arg_info['taps'])) lengths.append(scalar_shared(numpy.int64(0), name='l%d' % pos)) mintaps.append(mintap) arg_info['initial'] = scan_utils.expand( arg_info['initial'][:mintap], T) else: mintaps.append(0) lengths.append(scalar_shared(numpy.int64(0), name='l%d' % pos)) # 3. Generate arguments for the function passed to scan. This will # function will return the outputs that need to be computed at every # timesteps inputs_slices = [input[t] for input in inputs] states_slices = [] for n, state in enumerate(states_and_outputs_info): # Check if it is actually a state and not an output if mintaps[n] != 0: for k in state['taps']: states_slices.append(state['initial'][(t + mintaps[n] + k) % lengths[n]]) # 4. Construct outputs that are to be computed by the inner # function of scan args = inputs_slices + states_slices + parameters cond, states_and_outputs, updates = \ scan_utils.get_updates_and_outputs(fn(*args)) # User is allowed to provide no information if it only behaves like a # map if (len(states_and_outputs) != len(states_and_outputs_info) and len(states_and_outputs_info) == 0): mintaps = [0] * len(states_and_outputs) # 5. Construct the scan op # 5.1 Construct list of shared variables with updates (those that # can be treated as states (i.e. 
of TensorType) and those that can not # (like Random States) if cond is not None: _cond = [cond] else: _cond = [] rvals = rebuild_collect_shared(states_and_outputs + _cond, updates=updates, rebuild_strict=True, copy_inputs_over=True, no_default_updates=False) # extracting the arguments input_variables, cloned_outputs, other_rval = rvals clone_d, update_d, update_expr, shared_inputs = other_rval additional_input_states = [] additional_output_states = [] additional_lengths = [] additional_mintaps = [] original_numeric_shared_variables = [] non_numeric_input_states = [] non_numeric_output_states = [] original_non_numeric_shared_variables = [] pos = len(lengths) for sv in shared_inputs: if sv in update_d: if isinstance(sv, (TensorVariable, TensorSharedVariable)): # We can treat it as a sit sot nw_state = scan_utils.expand( tensor.unbroadcast(tensor.shape_padleft(sv), 0), T) additional_lengths.append( scalar_shared(numpy.int64(0), name='l%d' % pos)) pos = pos + 1 additional_mintaps.append(1) additional_input_states.append(nw_state) additional_output_states.append( scan_utils.clone( tensor.set_subtensor( nw_state[(t + 1) % additional_lengths[-1]], update_d[sv]))) original_numeric_shared_variables.append(sv) else: non_numeric_input_states.append(sv) non_numeric_output_states.append(update_d[sv]) original_non_numeric_shared_variables.append(sv) # Replace shared variables in the update _additional_output_states = [] replace = {} for sv, buf in zip(original_numeric_shared_variables, additional_input_states): replace[sv] = buf[t] for out in additional_output_states: _additional_output_states.append(scan_utils.clone(out, replace=replace)) additional_output_states = _additional_output_states # 5.2 Collect inputs/outputs of the inner function inputs = [] outputs = [] for n, mintap in enumerate(mintaps): if mintap != 0: input_state = states_and_outputs_info[n]['initial'] inputs.append(input_state) outputs.append( tensor.set_subtensor(input_state[(t + mintap) % lengths[n]], 
states_and_outputs[n])) else: mem_buffer = scan_utils.allocate_memory(T, states_and_outputs_info[n], states_and_outputs[n]) inputs.append(output) outputs.append( tensor.set_subtensor(output[t % lengths[n]], states_and_outputs[n])) inputs.extend(additional_input_states) outputs.extend(additional_output_states) lengths.extend(additional_lengths) mintaps.extend(additional_mintaps) inputs.extend(non_numeric_input_states) outputs.extend(non_numeric_output_states) all_other_inputs = gof.graph.inputs(outputs) parameters = [ x for x in all_other_inputs if (x not in inputs and x not in lengths and x is not t and isinstance( x, gof.Variable) and not isinstance(x, gof.Constant)) ] inputs.extend(parameters) # 5.3 Construct the the options dictionary options['name'] = name options['profile'] = profile options['mode'] = mode options['inplace'] = False options['gpu'] = False options['truncate_gradient'] = truncate_gradient options['hash_inner_graph'] = 0 # 5.4 Construct the ScanOp instance local_op = scan_op.ScanOp(inputs=inputs, outputs=outputs, lengths=lengths, switches=[], mintaps=mintaps, index=t, options=options, as_repeatUntil=cond) # Note that we get here all the outputs followed by the update rules to # the shared variables we had in our scan # we know that we have (in this given order): # * len(states_and_outputs) real outputs # * len(additional_input_states) updates for numeric shared variable # * len(non_numeric_input_states) updates for non numeric shared # variables scan_inputs = [T] + inputs scan_outputs_update_rules = scan_utils.to_list(local_op(*scan_inputs)) # 5.5 Collect outputs and add permutation object scan_outputs = [] for pos in xrange(len(states_and_outputs)): out = scan_utils.ScanPermutation(mintaps[pos])( scan_outputs_update_rules[pos], t) scan_outputs.append(out[mintaps[pos]:]) # 5.6 Construct updates dictionary update_rules = scan_outputs_update_rules[len(states_and_outputs):] updates = {} for v, u in izip(original_numeric_shared_variables, 
update_rules[:len(additional_input_states)]): updates[v] = u[-1] for v, u in izip(original_non_numeric_shared_variables, update_rules[len(additional_input_states):]): updates[v] = u # Step 5.7 We are done and can return everything back to the user return scan_outputs, updates
def scan(fn, sequences=None, states=None, params=None, n_steps=None, mode=None, name=None, profile=False): """ Similar to Theano's official scan, this function gives the user more control over the scan op, avoiding certain difficulties that arose from missing optimizations. :param fn: lambda function that describes one step of scan (see the official Theano scan function) :param sequences: similar to the official Theano's scan. This version of scan does not support taps for the sequences (it can only be a list of tensor). Scan assumes that sequences have the right length and it does not check for this. :param states: similar to outputs_info of the official scan function. There is one crucial difference though, namely that the `initial` key in the dictionary has been replace by 'membuf' key. This reflects the change of meaning. Instead of passing to scan just the initial steps misisng, one has now to pass a memory buffer in which scan will try to store its output. In this memory buffer the first entries should be set to the initial states of the corresponding states. Providing a memory buffer that has less entries then the number of steps, mneans scan will only use that amount of memory. The user has to match the memory buffer size with the number of steps, otherwise scan will produce wrong results. Also if gradients are to be computed through the scan, the memory buffer should have the same length as the number of steps. For states that do not require a initial state, one has to provide a dictionary with a single key 'steps' that says how many intermediate results to store. See examples below for more insight. :param n_steps: This parameter is mandatory and it will represent the number of steps scan will do (scan will not check sequences or any other source of information to figure out how many steps it needs to do). 
:param mode: Same as for the official scan :param name: Same as for the official scan :param profile: Same as for the official scan Note: - there is no truncate / go_backwards anymore ! - the outputs returned by scan contain the initial states as well (i.e. if I loop over k steps, with my smallest tap for an output -3 and keep al intermediate results, my output will be of length k+3 Examples: (a) if you do not want to store any intermediate results (just the last one) # The memory buffer can be the initial state, just that we need to # add one extra dimension in front of it state = TT.unbroadcast(TT.shape_padleft(x0),0) out,_ = scan(lambda x:x+1, states = state, n_steps = 5) # Once we got our result we need to remove the extra dimension out = out[0] (b) if you want to keep every intermediate results state = TT.alloc(TT.constant(0), 6, x0.shape[0]) state = TT.set_subtensor(state[0], x0) out,_ = scan(lambda x:x+1, states = state, n_steps = 5) out = out[1:] """ def wrap_into_list(x): ''' Wrap the input into a list if it is not already a list ''' if x is None: return [] elif not isinstance(x, (list, tuple)): return [x] else: return list(x) seqs = wrap_into_list(sequences) outs_info = wrap_into_list(states) # Make sure we get rid of numpy arrays or ints or anything like that # passed as inputs to scan non_seqs = [] for elem in wrap_into_list(params): if not isinstance(elem, gof.Variable): non_seqs.append(tensor.as_tensor_variable(elem)) else: non_seqs.append(elem) # If we provided a known number of steps ( before compilation) # and if that number is 1 or -1, then we can skip the Scan Op, # and just apply the inner function once # To do that we check here to see the nature of n_steps n_fixed_steps = None if isinstance(n_steps, (float, int)): n_fixed_steps = int(n_steps) else: try: n_fixed_steps = opt.get_constant_value(n_steps) except (TypeError, AttributeError): n_fixed_steps = None # Check n_steps is an int if (hasattr(n_steps, 'dtype') and str(n_steps.dtype)[:3] not 
in ('uin', 'int')): raise ValueError(' n_steps must be an int. dtype provided ' 'is %s' % n_steps.dtype) # compute number of sequences and number of outputs n_seqs = len(seqs) n_outs = len(outs_info) return_steps = {} # wrap outputs info in a dictionary if they are not already in one for i in xrange(n_outs): if outs_info[i] is not None: if not isinstance(outs_info[i], dict): # by default any output has a tap value of -1 outs_info[i] = dict(membuf=outs_info[i], taps=[-1]) elif (not outs_info[i].get('membuf', None) and outs_info[i].get('taps', None)): # ^ no initial state but taps provided raise ValueError(('If you are using slices of an output ' 'you need to provide a memory buffer for ' 'the state '), outs_info[i]) elif (outs_info[i].get('membuf', None) and not outs_info[i].get('taps', None)): # ^ initial state but taps not provided if 'taps' in outs_info[i]: # ^ explicitly provided a None for taps _logger.warning( 'Output %s ( index %d) has a memory ' 'buffer but taps is explicitly set to None ', getattr(outs_info[i]['membuf'], 'name', 'None'), i) outs_info[i]['taps'] = [-1] else: # if a None is provided as the output info we replace it # with an dict(steps=n_steps) to simplify handling outs_info[i] = dict(steps=n_steps) ## ### Step 2. 
Generate inputs and outputs of the inner functions ### for compiling a dummy function (Iteration #1) ## # create theano inputs for the recursive function # note : this is a first batch of possible inputs that will # be compiled in a dummy function; we used this dummy # function to detect shared variables and their updates # and to construct a new and complete list of inputs and # outputs n_seqs = 0 scan_seqs = [] # Variables passed as inputs to the scan op inner_seqs = [] # Variables passed as inputs to the inner function inner_slices = [] # Actual slices if scan is removed from the picture # go through sequences picking up time slices as needed for i, seq in enumerate(seqs): actual_slice = seq[0] _seq_val = tensor.as_tensor_variable(seq) _seq_val_slice = _seq_val[0] # Try to transfer test_value to the new variable if config.compute_test_value != 'off': try: nw_slice.tag.test_value = gof.Op._get_test_value( _seq_val_slice) except AttributeError, e: if config.compute_test_value != 'ignore': # No need to print a warning or raise an error now, # it will be done when fn will be called. _logger.info(('Cannot compute test value for ' 'the inner function of scan, input value ' 'missing %s'), e) nw_slice = _seq_val_slice.type() if seq.name: nw_slice.name = seq.name + '[t]' scan_seqs.append(_seq_val) inner_seqs.append(nw_slice) inner_slices.append(actual_slice) n_seqs += 1