Ejemplo n.º 1
0
    def process_node(self, env, node):
        """Reduce the memory (and possibly the number of steps) of a scan node.

        Every non-shared output of the scan ``node`` is inspected to see how
        its clients consume it. When all clients only index into an output
        through a ``Subtensor``, the optimization may:

        * make scan iterate for fewer steps than requested (``nw_steps``);
        * keep only the last few intermediate values of an output instead
          of its whole history (``store_steps``);
        * drop outputs that are never used (orphan outputs).

        A new ``Scan`` node is then built and every affected client is
        rewritten against it with ``env.replace_all_validate``.

        :param env: graph being optimized; receives the replacement pairs.
        :param node: the scan apply node to process.
        :returns: None. The graph is modified in place when something can
            be improved; otherwise nothing happens.
        """

        # helpful functions
        def select_min(x, y):
            # Symbolic minimum where ``None`` means "no constraint".
            if x is None:
                return y
            if y is None:
                return x
            return tensor.minimum(x, y)

        def select_max(x, y):
            # Symbolic maximum where ``None`` means "no constraint".
            if x is None:
                return y
            if y is None:
                return x
            return tensor.maximum(x, y)

        def sanitize(x):
            # Convert python ints back into symbolic constants; None passes.
            if x is None:
                return None
            else:
                return tensor.as_tensor_variable(x)

        # The shape feature is optional. When the graph does not carry one,
        # use an empty dict: every lookup below is guarded by
        # ``except KeyError`` and falls back to a default length expression.
        shape_feature = getattr(node.env, "shape_feature", None)
        if shape_feature is not None:
            shape_of = shape_feature.shape_of
        else:
            shape_of = {}

        # 1. Initialization of variables
        # Note 1) We do not actually care about outputs representing shared
        # variables (those have no intermediate values) so it is safer to
        # ignore them and not change them in any way. To simplify the
        # optimizations I construct the variable ``c_outs`` ( that counts
        # outputs up to those we care) and the list ``init_l`` which for any
        # output we care says the length of its initial state. Note that
        # defining ``init_l`` for mit_mot sequences is a bit trickier but
        # it is safe to set it to 0
        op = node.op
        c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

        init_l = [0 for x in xrange(op.n_mit_mot)]
        init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot :]]
        init_l += [0 for x in xrange(op.n_nit_sot)]

        # 2. Check the clients of each output and see for how many steps
        # does scan need to run

        # 2.1 Initialize
        # global_nsteps is a dictionary having two fields ( 'real' deals
        # with int values, 'sym' with symbolic ones) or None.
        # It is None if any client is not a usable Subtensor; otherwise it
        # collects max(c_i_j.idx_list[0].stop), i.e. the largest step any
        # client actually needs, which bounds the number of steps to run.
        # Note: if there are outputs beyond ``c_outs`` (shared variables
        # updated at every step), reducing the number of steps would change
        # their final value, so global_nsteps is set to None in that case,
        # which is the flag meaning "do not change the number of steps".
        if len(node.outputs) <= c_outs:
            global_nsteps = {"real": -1, "sym": []}
        else:
            global_nsteps = None

        # Keeps track of the original slices that each client represent
        slices = [None for o in node.outputs]

        # A list for each output indicating how many intermediate values
        # should be stored. If negative it means none of the intermediate
        # values (i.e. the output can be removed since it is not used
        # afterwards in the computations), if 0 it means that all
        # intermediate values are required, otherwise is up to that number
        # of intermediate values
        # Note that for mit_mot outputs and shared outputs we can not change
        # the number of intermediate steps stored without affecting the
        # result of the op
        store_steps = [0 for o in xrange(op.n_mit_mot)]
        store_steps += [-1 for o in node.outputs[op.n_mit_mot : c_outs]]
        # Flag that says if an input has changed and we need to do something
        # or not
        flag_store = False

        # 2.2 Loop over the clients
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            slices[i] = []
            for cl, _ in out.clients:

                # 2.1 outputs of the function
                # => output needs all its intermediate values
                if isinstance(cl, str):
                    # if the node is actually an output, then
                    # we need to store the entire thing
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.2 non-subtensor nodes
                # => output needs all its intermediate values
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.3 subtensor nodes
                # => output might need to store just a subset of its values
                else:
                    # 2.3.1 extract idx list of subtensor
                    this_slice = tensor.basic.get_idx_list(cl.inputs, cl.op.idx_list)
                    if this_slice is None:
                        # if unable to extract idx_list
                        # => outputs needs all its intermediate values
                        global_nsteps = None
                        slices[i] = None
                        break

                    # 2.3.2 extract the begin/end of the first dimension.
                    # Outputs from index ``n_mit_mot`` onwards are the
                    # mit_sot / sit_sot / nit_sot ones; their length is
                    # ``n_steps + init_l[i]``. mit_mot outputs fall back to
                    # their symbolic shape instead.
                    if i >= op.n_mit_mot:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(this_slice[0], length)
                    slices[i] += [(cf_slice, this_slice)]

                    # An open-ended stop (e.g. ``y[k:]``) needs every step, so
                    # the number of steps can not be reduced. Do NOT break
                    # here: the slice of every remaining client must still be
                    # recorded, because step 3.8 indexes ``slices[pos]`` once
                    # per client of the output.
                    if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
                        global_nsteps = None
                    if isinstance(cf_slice[0], slice):
                        stop = tensor.basic.extract_constant(cf_slice[0].stop)
                    else:
                        stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                    if stop == sys.maxint or stop == length:
                        stop = None
                    else:
                        # there is a **gotcha** here ! Namely, scan returns an
                        # array that contains the initial state of the output
                        # as well. Which means that if have a initial state of
                        # length 3, and you look for 5 steps you get an output
                        # y of length 8. If you only use y[:5], this does not
                        # mean that you only need to loop for 5 steps but
                        # actually only for 2 steps ( the first 3 are the
                        # initial state)
                        stop = stop - init_l[i]

                    # 2.3.3 we might get away with less number of steps
                    if stop is not None and global_nsteps is not None:
                        # yes if it is a tensor
                        if isinstance(stop, tensor.Variable):
                            global_nsteps["sym"] += [stop]
                        # not if it is maxint
                        elif type(stop) is int and stop == sys.maxint:
                            global_nsteps = None
                        # yes if it is a int k, 0 < k < maxint
                        elif type(stop) is int and global_nsteps["real"] < stop:
                            global_nsteps["real"] = stop
                        # yes if it is a int k, 0 < k < maxint
                        elif type(stop) is int and stop > 0:
                            pass
                        # not otherwise
                        else:
                            global_nsteps = None

        # 2.3. Analyze global_nsteps to figure out for how many steps scan
        # needs to iterate
        if global_nsteps is not None:
            # there are some symbolic tensors that limit the number of
            # steps
            if len(global_nsteps["sym"]) == 0:
                sym_steps = None
            else:
                sym_steps = global_nsteps["sym"][0]
                for c in global_nsteps["sym"][1:]:
                    sym_steps = tensor.maximum(sym_steps, c)

            if global_nsteps["real"] >= 0:
                real_steps = global_nsteps["real"]
            else:
                real_steps = None
            # Never run more steps than originally requested.
            nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])
        else:
            nw_steps = node.inputs[0]

        # 2.4 Loop over the clients again now looking just to see how many
        # intermediate steps to store
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            for cl, _ in out.clients:
                if isinstance(cl, str):
                    store_steps[i] = 0
                    break
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    store_steps[i] = 0
                    break
                else:
                    this_slice = tensor.basic.get_idx_list(cl.inputs, cl.op.idx_list)
                    if this_slice is None:
                        store_steps[i] = 0
                        break

                    if isinstance(this_slice[0], slice) and this_slice[0].start is None:
                        store_steps[i] = 0
                        break

                    # Same split as in 2.3.2: non-mit_mot outputs start at
                    # index ``n_mit_mot``.
                    if i >= op.n_mit_mot:
                        length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(this_slice[0], length)

                    if isinstance(cf_slice[0], slice):
                        start = tensor.basic.extract_constant(cf_slice[0].start)
                    else:
                        start = tensor.basic.extract_constant(cf_slice[0])
                    if start == 0 or store_steps[i] == 0:
                        store_steps[i] = 0
                    else:
                        pval = select_max(nw_steps - start + init_l[i], init_l[i])
                        if store_steps[i] != -1:
                            pval = select_max(pval, store_steps[i])

                        store_steps[i] = pval
                        flag_store = True

        orphane_outs = [i for i, x in enumerate(store_steps) if (type(x) is int) and (x < 0)]
        flag_store = flag_store or (len(orphane_outs) > 0)

        # 3. is there anything to change ?
        if not (flag_store or global_nsteps is not None):
            return

        # 3.1 initialize inputs for the new scan
        old_outputs = []
        nw_inputs = list(node.inputs)
        nw_inputs[0] = nw_steps

        # 3.2 check orphane outputs to see if we can eliminate any
        required, not_required = scan_utils.scan_can_remove_outs(node.op, orphane_outs)
        # 3.3. compose replace pairs for those nodes that need not
        # to store everything in memory ( or are orphane and required
        # by the inner function .. )
        replaced_outs = []
        offset = 1 + op.n_seqs + op.n_mit_mot
        for idx, _val in enumerate(store_steps[op.n_mit_mot :]):
            i = idx + op.n_mit_mot
            if type(_val) is int and _val <= 0 and i not in required:
                # Either every step is kept (0) or the output is an orphan
                # that the inner function does not need.
                continue

            if i in required:
                val = 1
            else:
                val = _val
            # If the memory for this output has been pre-allocated
            # before going into the scan op (by an alloc node)
            if idx < op.n_mit_sot + op.n_sit_sot:
                # In case the input is still an alloc node, we
                # actually have two options:
                #   a) the input is a set_subtensor, in that case we
                #      can replace the initial tensor by a slice,
                #   b) it is not, and we simply take a slice of it.
                if nw_inputs[offset + idx].owner and isinstance(
                    nw_inputs[offset + idx].owner.op, tensor.IncSubtensor
                ):
                    _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                    tmp = pre_greedy_local_optimizer(list_opt_slice, tensor.as_tensor_variable(val - init_l[i]))
                    tmp = pre_constant_merge([tmp])[0]
                    nw_input = scan_utils.expand(_nw_input, tmp)
                else:
                    tmp = pre_greedy_local_optimizer(list_opt_slice, tensor.as_tensor_variable(val))
                    tmp = pre_constant_merge([tmp])[0]
                    nw_input = nw_inputs[offset + idx][:tmp]

                nw_inputs[offset + idx] = nw_input
                replaced_outs.append(i)
                old_outputs += [(i, [x[0].outputs[0] for x in node.outputs[i].clients])]
            # If there is no memory pre-allocated for this output
            elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
                # nit_sot buffer sizes follow the shared outputs' inputs.
                pos = offset + idx + op.n_shared_outs
                if nw_inputs[pos] == node.inputs[0]:
                    nw_inputs[pos] = val
                replaced_outs.append(i)
                old_outputs += [(i, [x[0].outputs[0] for x in node.outputs[i].clients])]

        # 3.4. Recompute inputs for everything else based on the new
        # number of steps
        if global_nsteps is not None:
            for idx, val in enumerate(store_steps[op.n_mit_mot :]):
                if val == 0:
                    if idx < op.n_mit_sot + op.n_sit_sot:
                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                        nw_input = scan_utils.expand(_nw_input, nw_steps)
                        nw_inputs[offset + idx] = nw_input
                    elif idx < (op.n_mit_sot + op.n_sit_sot + op.n_nit_sot):
                        in_idx = offset + idx + op.n_shared_outs
                        if nw_inputs[in_idx] == node.inputs[0]:
                            nw_inputs[in_idx] = nw_steps

        # 3.5 Remove unwanted orphane outputs
        (inps, outs, info, node_ins, compress_map) = scan_utils.compress_outs(op, not_required, nw_inputs)
        inv_compress_map = {}
        for k, v in compress_map.items():
            inv_compress_map[v] = k

        node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in node_ins]
        node_ins = pre_constant_merge(node_ins)
        # 3.6 Compose the new scan
        # I need to make sure I'm not reapplying the same optimization
        # twice since bad things usually happen if I do that
        info["_scan_merge_visited"] = True
        new_outs = scan_op.Scan(inps, outs, info).make_node(*node_ins).outputs

        old_new = []
        # 3.7 Get replace pairs for those outputs that do not change
        # the number of intermediate steps stored
        for idx, sl in enumerate(slices):
            # ``sl is not None`` is checked before indexing ``store_steps``,
            # which only covers the first ``c_outs`` outputs.
            if global_nsteps and sl is not None and store_steps[idx] == 0:
                for hdx, cl in enumerate(node.outputs[idx].clients):
                    cnf_slice, old_slices = sl[hdx]
                    # Sanitize the nw_slice by converting ints back into
                    # constants :) I only need to do this for the first
                    # slice since that is the only slice

                    if isinstance(cnf_slice[0], slice):
                        fslice = slice(
                            sanitize(cnf_slice[0].start), sanitize(cnf_slice[0].stop), sanitize(cnf_slice[0].step)
                        )
                    else:
                        fslice = sanitize(cnf_slice[0])

                    nw_slice = (fslice,) + tuple(old_slices[1:])
                    nw_out = new_outs[inv_compress_map[idx]]

                    subtens = tensor.basic.Subtensor(nw_slice)
                    # slice inputs
                    sl_ins = tensor.basic.Subtensor.collapse(
                        nw_slice, lambda entry: isinstance(entry, tensor.Variable)
                    )
                    new_o = subtens.make_node(nw_out, *sl_ins).outputs[0]
                    if new_o.ndim > 0:
                        new_o = new_o[:: cnf_slice[1]]
                    replaced_outs.append(idx)
                    old_new += [(cl[0].outputs[0], new_o)]

        # 3.8. Get replace pairs for those outputs that change
        # the number of stored intermediate steps
        for pos, old_outs in old_outputs:
            if len(old_outs) > 0:
                nw_out = new_outs[compress_map[pos]]
                for k, old in enumerate(old_outs):
                    # Get the correct slice (one entry per client, in the
                    # same order as ``node.outputs[pos].clients``)
                    cnf_slice, old_slices = slices[pos][k]
                    if type(cnf_slice[0]) is slice:
                        start = cnf_slice[0].start - nw_steps - init_l[pos] + store_steps[pos]
                        if cnf_slice[0].stop is not None and cnf_slice[0].stop != sys.maxint:
                            stop = cnf_slice[0].stop - nw_steps - init_l[pos] + store_steps[pos]
                        else:
                            stop = None
                        nw_slice = (slice(sanitize(start), sanitize(stop), sanitize(cnf_slice[0].step)),) + tuple(
                            old_slices[1:]
                        )

                    else:
                        position = cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]

                        nw_slice = (sanitize(position),) + tuple(old_slices[1:])

                    subtens = tensor.basic.Subtensor(nw_slice)
                    sl_ins = tensor.basic.Subtensor.collapse(
                        nw_slice, lambda entry: isinstance(entry, tensor.Variable)
                    )
                    new_o = subtens.make_node(nw_out, *sl_ins).outputs[0]
                    if new_o.ndim > 0:
                        new_o = new_o[:: cnf_slice[1]]
                    old_new += [(old, new_o)]

        # 3.9. Get replace pairs for all other nodes
        for idx, o in enumerate(node.outputs):
            if idx not in replaced_outs and idx not in not_required:
                nw_pos = compress_map[idx]
                old_new += [(o, new_outs[nw_pos])]

        env.replace_all_validate(old_new, reason="scan_save_mem")
Ejemplo n.º 2
0
                        # No need to print a warning or raise an error now,
                        # it will be done when fn will be called.
                        _logger.info(('Cannot compute test value for the '
                            'inner function of scan, input value missing %s'),
                                     e)

            if getattr(init_out['initial'], 'name', None) is not None:
                arg.name = init_out['initial'].name + '[t-1]'

            # We need now to allocate space for storing the output and copy
            # the initial state over. We do this using the expand function
            # defined in scan utils
            sit_sot_scan_inputs.append(
                scan_utils.expand(
                    tensor.unbroadcast(
                        tensor.shape_padleft(actual_arg), 0),
                    actual_n_steps
                ))

            sit_sot_inner_slices.append(actual_arg)
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1

        elif init_out.get('taps', None):

            if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence we can not provide such values
Ejemplo n.º 3
0
    def process_node(self, fgraph, node):

        # helpful functions
        def select_min(x, y):
            if x is None:
                return y
            if y is None:
                return x
            return tensor.minimum(x, y)

        def select_max(x, y):
            if x is None:
                return y
            if y is None:
                return x
            return tensor.maximum(x, y)

        def sanitize(x):
            if x is None:
                return None
            else:
                return tensor.as_tensor_variable(x)

        if hasattr(fgraph, 'shape_feature'):
            shape_of = node.fgraph.shape_feature.shape_of
        else:
            # Each access to shape_of is in a try..except block in order to
            # use a default version when the variable is not in the shape_of
            # dictionary.
            shape_of = {}
        # 1. Initialization of variables
        # Note 1) We do not actually care about outputs representing shared
        # variables (those have no intermediate values) so it is safer to
        # ignore them and not change them in any way. To simplify the
        # optimizations I construct the variable ``c_outs`` ( that counts
        # outputs up to those we care) and the list ``init_l`` which for any
        # output we care says the length of its initial state. Note that
        # defining ``init_l`` for mit_mot sequences is a bit trickier but
        # it is safe to set it to 0
        op = node.op
        c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

        init_l = [0 for x in xrange(op.n_mit_mot)]
        init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
        init_l += [0 for x in xrange(op.n_nit_sot)]
        # 2. Check the clients of each output and see for how many steps
        # does scan need to run

        # This comparison checks if there is any uncounted output, which
        # can only be an output corresponding to a shared variable

        # 2.1 Initialize
        # global_nsteps is a dictionary having two fields ( 'real' deals
        # with int values, 'sym' with symbolic ones) or None
        # given that a scan op has k outputs o_1, .. o_k and each
        # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
        # global_nsteps is None if any of the clients is different
        # from a subtensor or its real and sym field equal to
        # max(c_i_j.idx_list[0].stop), meaning store up to which maximal
        # index(step) for any output scan actually needs to compute
        # In other words n_steps should be equal to this maximal !
        # Note: if we have a shared variable that gets updated at every step
        # of the loop, reducing the number of steps will affect the the
        # value of the shared variable after the loop so we need not to
        # change the number of steps in that case. To do this we set
        # global_nsteps to None which is seen as a flag that nothing needs
        # to be done
        assert len(node.outputs) >= c_outs
        if len(node.outputs) == c_outs:
            global_nsteps = {'real': -1, 'sym': []}
        else:
            global_nsteps = None

        # Keeps track of the original slices that each client represent
        slices = [None for o in node.outputs]

        # A list for each output indicating how many intermediate values
        # should be stored. If negative it means none of the intermediate
        # values (i.e. the output can be removed since it is not used
        # afterwards in the computations), if 0 it means that all
        # intermediate values are required, otherwise is up to that number
        # of intermediate values
        # Note that for mit_mot outputs and shared outputs we can not change
        # the number of intermediate steps stored without affecting the
        # result of the op
        store_steps = [0 for o in xrange(op.n_mit_mot)]
        store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
        # Flag that says if an input has changed and we need to do something
        # or not
        flag_store = False

        # 2.2 Loop over the clients
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            slices[i] = []
            for cl, _ in out.clients:

                # 2.1 outputs of the function
                #=> output needs all its intermediate values
                if type(cl) == str:
                    # if the node is actually an output, then
                    # we need to store the entire thing
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.2 non-subtensor nodes
                #=> output needs all its intermediate values
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.3 subtensor nodes
                #=> output might need to store just a subset of its values
                else:
                    # 2.3.1 extract idx list of subtensor
                    this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                     cl.op.idx_list)
                    if this_slice is None:
                        # if unable to extract idx_list
                        #=> outputs needs all its intermediate values
                        global_nsteps = None
                        slices[i] = None
                        break

                    # 2.3.2 extract the begin/end of the first dimension
                    if i >= op.n_mit_mot:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(
                                                    this_slice[0], length)
                    slices[i] += [(cf_slice, this_slice)]

                    if (isinstance(this_slice[0], slice) and
                        this_slice[0].stop is None):
                        global_nsteps = None
                    if isinstance(cf_slice[0], slice):
                        stop = tensor.basic.extract_constant(cf_slice[0].stop)
                    else:
                        stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                    if stop == maxsize or stop == length:
                        stop = None
                    else:
                        # there is a **gotcha** here ! Namely, scan returns an
                        # array that contains the initial state of the output
                        # as well. Which means that if have a initial state of
                        # length 3, and you look for 5 steps you get an output
                        # y of length 8. If you only use y[:5], this does not
                        # mean that you only need to loop for 5 steps but
                        # actually only for 2 steps ( the first 3 are the
                        # initial state)
                        stop = stop - init_l[i]

                    # 2.3.3 we might get away with less number of steps
                    if stop is not None and global_nsteps is not None:
                        # yes if it is a tensor
                        if isinstance(stop, tensor.Variable):
                            global_nsteps['sym'] += [stop]
                        # not if it is maxsize
                        elif (type(stop) in (int, long) and
                              stop == maxsize):
                            global_nsteps = None
                        # yes if it is a int k, 0 < k < maxsize
                        elif (type(stop) in (int, long) and
                              global_nsteps['real'] < stop):
                            global_nsteps['real'] = stop
                        # yes if it is a int k, 0 < k < maxsize
                        elif (type(stop) in (int, long) and stop > 0):
                            pass
                        # not otherwise
                        else:
                            global_nsteps = None

        # 2.3. Analyze global_nsteps to figure out for how many steps scan
        # needs to iterate
        if global_nsteps is not None:
            nw_steps = node.inputs[0]

            # there are some symbolic tensors that limit the number of
            # steps
            if len(global_nsteps['sym']) == 0:
                sym_steps = None
            else:
                sym_steps = global_nsteps['sym'][0]
                for c in global_nsteps['sym'][1:]:
                    sym_steps = tensor.maximum(sym_steps, c)

            if global_nsteps['real'] >= 0:
                real_steps = global_nsteps['real']
            else:
                real_steps = None
            nw_steps = select_min(select_max(sym_steps, real_steps),
                                  node.inputs[0])
        else:
            nw_steps = node.inputs[0]
            global_nsteps = None

        # 2.4 Loop over the clients again now looking just to see how many
        # intermediate steps to store
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            for cl, _ in out.clients:
                if type(cl) == str:
                    store_steps[i] = 0
                    break
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    store_steps[i] = 0
                    break
                else:
                    this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                         cl.op.idx_list)
                    if this_slice is None:
                        store_steps[i] = 0
                        break

                    if (isinstance(this_slice[0], slice) and
                        this_slice[0].start is None):
                        store_steps[i] = 0
                        break

                    if i > op.n_mit_mot:
                        length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(
                                                    this_slice[0], length)

                    if isinstance(cf_slice[0], slice):
                        start = tensor.basic.extract_constant(
                            cf_slice[0].start)
                    else:
                        start = tensor.basic.extract_constant(cf_slice[0])
                    if start == 0 or store_steps[i] == 0:
                        store_steps[i] = 0
                    else:
                        pval = select_max(nw_steps - start + init_l[i],
                                          init_l[i])
                        if store_steps[i] != -1:
                            pval = select_max(pval, store_steps[i])

                        store_steps[i] = pval
                        flag_store = True

        orphane_outs = [i for i, x in enumerate(store_steps)
                        if (type(x) is int) and (x < 0)]
        flag_store = flag_store or (len(orphane_outs) > 0)
        # 3. is there anything to change ?
        if (flag_store or global_nsteps is not None):
            # 3.1 initialize inputs for the new scan
            old_outputs = []
            nw_inputs = list(node.inputs)
            nw_inputs[0] = nw_steps

            # 3.2 check orphane outputs to see if we can eliminate any
            required, not_required = \
                    scan_utils.scan_can_remove_outs(node.op,
                                                    orphane_outs)
            # 3.3. compose replace pairs for those nodes that need not
            # to store everything in memory ( or ar orphane and required
            # by the inner function .. )
            replaced_outs = []
            offset = 1 + op.n_seqs + op.n_mit_mot
            for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
                i = idx + op.n_mit_mot
                if not(type(_val) is int and _val <= 0 and i not in required):

                    if idx + op.n_mit_mot in required:
                        val = 1
                    else:
                        val = _val
                    # If the memory for this output has been pre-allocated
                    # before going into the scan op (by an alloc node)
                    if idx < op.n_mit_sot + op.n_sit_sot:
                        # In case the input is still an alloc node, we
                        # actually have two options:
                        #   a) the input is a set_subtensor, in that case we
                        #      can replace the initial tensor by a slice,
                        #   b) it is not, and we simply take a slice of it.

                        #TODO: commit change below with Razvan
                        if (nw_inputs[offset + idx].owner and
                            isinstance(nw_inputs[offset + idx].owner.op,
                                       tensor.IncSubtensor) and
                            isinstance(nw_inputs[offset+idx].owner.op.idx_list[0], slice)):

                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            val = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                    tensor.maximum(val - initl, 0))
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = scan_utils.expand(_nw_input, tmp)
                        else:
                            tmp = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp = tensor.maximum(tmp, initl)
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                                             tmp)
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = nw_inputs[offset + idx][:tmp]

                        nw_inputs[offset + idx] = nw_input
                        replaced_outs.append(op.n_mit_mot + idx)
                        odx = op.n_mit_mot + idx
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
                    # If there is no memory pre-allocated for this output
                    elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:

                        pos = (op.n_mit_mot + idx + op.n_seqs +
                               1 + op.n_shared_outs)
                        if nw_inputs[pos] == node.inputs[0]:
                            nw_inputs[pos] = val
                        odx = op.n_mit_mot + idx
                        replaced_outs.append(odx)
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
            # 3.4. Recompute inputs for everything else based on the new
            # number of steps
            if global_nsteps is not None:
                for idx, val in enumerate(store_steps[op.n_mit_mot:]):
                    if val == 0:
                        if idx < op.n_mit_sot + op.n_sit_sot:
                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            odx = op.n_mit_mot + idx
                            nw_input = scan_utils.expand(_nw_input, nw_steps)
                            nw_inputs[offset + idx] = nw_input
                        elif idx < (op.n_mit_sot + op.n_sit_sot +
                                    op.n_nit_sot):
                            in_idx = offset + idx + op.n_shared_outs
                            if nw_inputs[in_idx] == node.inputs[0]:
                                nw_inputs[in_idx] = nw_steps
                            odx = op.n_mit_mot + idx

            # 3.5 Remove unwanted orphane outputs
            (inps, outs, info, node_ins, compress_map) = \
                    scan_utils.compress_outs(op, not_required, nw_inputs)
            inv_compress_map = {}
            for k, v in compress_map.items():
                inv_compress_map[v] = k

            node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in
                        node_ins]
            node_ins = pre_constant_merge(node_ins)
            # 3.6 Compose the new scan
            # I need to make sure I'm not reapplying the same optimization
            # twice since bad things usually happen if I do that
            info['_scan_savemem_visited'] = True
            new_outs = scan_op.Scan(inps,
                                    outs,
                                    info).make_node(*node_ins).outputs

            old_new = []
            # 3.7 Get replace pairs for those outputs that do not change
            # the number of intermediate steps stored
            for idx, sl in enumerate(slices):
                if global_nsteps and sl is not None and store_steps[idx] == 0:
                    for hdx, cl in enumerate(node.outputs[idx].clients):
                        cnf_slice, old_slices = sl[hdx]
                        # Sanitize the nw_slice by converting ints back into
                        # constants :) I only need to do this for the first
                        # slice since that is the only slice

                        if isinstance(cnf_slice[0], slice):
                            fslice = slice(
                                sanitize(cnf_slice[0].start),
                                sanitize(cnf_slice[0].stop),
                                sanitize(cnf_slice[0].step))
                        else:
                            fslice = sanitize(cnf_slice[0])

                        nw_slice = (fslice,) + tuple(old_slices[1:])
                        nw_pos = inv_compress_map[idx]

                        subtens = tensor.basic.Subtensor(nw_slice)
                        # slice inputs
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                    tensor.Variable))
                        new_o = subtens.make_node(new_outs[nw_pos],
                                                  *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        replaced_outs.append(idx)
                        old_new += [(cl[0].outputs[0], new_o)]
            # 3.8. Get replace pairs for those outputs that change
            # the number of stored intermediate steps
            for pos, old_outs in old_outputs:
                if len(old_outs) > 0:
                    nw_pos = compress_map[pos]
                    for k, old in enumerate(old_outs):
                        # Get the correct slice
                        cnf_slice, old_slices = slices[pos][k]
                        if type(cnf_slice[0]) is slice:
                            start = (cnf_slice[0].start - nw_steps -
                                     init_l[pos] + store_steps[pos])
                            if (cnf_slice[0].stop is not None and
                                cnf_slice[0].stop != maxsize):
                                stop = (cnf_slice[0].stop - nw_steps -
                                        init_l[pos] + store_steps[pos])
                            else:
                                stop = None
                            nw_slice = ((slice(sanitize(start),
                                               sanitize(stop),
                                               sanitize(cnf_slice[0].step)),)
                                        + tuple(old_slices[1:]))

                        else:
                            position = (cnf_slice[0] - nw_steps -
                                         init_l[pos] + store_steps[pos])

                            nw_slice = (sanitize(position),) + \
                                    tuple(old_slices[1:])

                        subtens = tensor.basic.Subtensor(nw_slice)
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                     tensor.Variable))
                        new_o = subtens.make_node(new_outs[nw_pos],
                                                  *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        old_new += [(old, new_o)]

            # 3.9. Get replace pairs for all other nodes
            if flag_store or global_nsteps is not None:
                for idx, o in enumerate(node.outputs):
                    if not (idx in replaced_outs) and not idx in not_required:
                        nw_pos = compress_map[idx]
                        old_new += [(o, new_outs[nw_pos])]

                fgraph.replace_all_validate_remove(old_new,
                                                   remove=[node],
                                                   reason='scan_save_mem')
Ejemplo n.º 4
0
def scan(fn,
         sequences=None,
         outputs_info=None,
         non_sequences=None,
         n_steps=None,
         truncate_gradient=-1,
         go_backwards=False,
         mode=None,
         name=None,
         options=None,
         profile=False):
    """
    This function constructs and applies a Scan op to the provided
    arguments.

    :param fn:
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. ``fn`` should construct variables describing the
        output of one iteration step. It should expect as input theano
        variables representing all the slices of the input sequences
        and previous values of the outputs, as well as all other arguments
        given to scan as ``non_sequences``. The order in which scan passes
        these variables to ``fn`` is the following:

        * all time slices of the first sequence
        * all time slices of the second sequence
        * ...
        * all time slices of the last sequence
        * all past slices of the first output
        * all past slices of the second output
        * ...
        * all past slices of the last output
        * all other arguments (the list given as `non_sequences` to
            scan)

        The order of the sequences is the same as the one in the list
        `sequences` given to scan. The order of the outputs is the same
        as the order of ``outputs_info``. For any sequence or output the
        order of the time slices is the same as the one in which they have
        been given as taps. For example if one writes the following:

        .. code-block:: python

            scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
                                 , Sequence2
                                 , dict(input =  Sequence3, taps = 3) ]
                   , outputs_info = [ dict(initial =  Output1, taps = [-3,-5])
                                    , dict(initial = Output2, taps = None)
                                    , Output3 ]
                   , non_sequences = [ Argument1, Argument2])

        ``fn`` should expect the following arguments in this given order:

        #. ``Sequence1[t-3]``
        #. ``Sequence1[t+2]``
        #. ``Sequence1[t-1]``
        #. ``Sequence2[t]``
        #. ``Sequence3[t+3]``
        #. ``Output1[t-3]``
        #. ``Output1[t-5]``
        #. ``Output3[t-1]``
        #. ``Argument1``
        #. ``Argument2``

        The list of ``non_sequences`` can also contain shared variables
        used in the function, though ``scan`` is able to figure those
        out on its own so they can be skipped. For the clarity of the
        code we recommend though to provide them to scan. To some extent
        ``scan`` can also figure out other ``non sequences`` (not shared)
        even if not passed to scan (but used by `fn`). A simple example of
        this would be:

        .. code-block:: python

            import theano.tensor as TT
            W   = TT.matrix()
            W_2 = W**2
            def f(x):
                return TT.dot(x,W_2)

        The function is expected to return two things. One is a list of
        outputs ordered in the same order as ``outputs_info``, with the
        difference that there should be only one output variable per
        output initial state (even if no tap value is used). Secondly
        `fn` should return an update dictionary (that tells how to
        update any shared variable after each iteration step). The
        dictionary can optionally be given as a list of tuples. There is
        no constraint on the order of these two lists, ``fn`` can return
        either ``(outputs_list, update_dictionary)`` or
        ``(update_dictionary, outputs_list)`` or just one of the two (in
        case the other is empty).

        To use ``scan`` as a while loop, the user needs to change the
        function ``fn`` such that also a stopping condition is returned.
        To do so, he/she needs to wrap the condition in an ``until`` class.
        The condition should be returned as a third element, for example:

        .. code-block:: python

            ...
            return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50)

        Note that a number of steps (considered in here as the maximum
        number of steps) is still required even though a condition is
        passed (and it is used to allocate memory if needed).

    :param sequences:
        ``sequences`` is the list of Theano variables or dictionaries
        describing the sequences ``scan`` has to iterate over. If a
        sequence is given as wrapped in a dictionary, then a set of optional
        information can be provided about the sequence. The dictionary
        should have the following keys:

        * ``input`` (*mandatory*) -- Theano variable representing the
          sequence.

        * ``taps`` -- Temporal taps of the sequence required by ``fn``.
          They are provided as a list of integers, where a value ``k``
          implies that at iteration step ``t`` scan will pass to ``fn``
          the slice ``t+k``. Default value is ``[0]``

        Any Theano variable in the list ``sequences`` is automatically
        wrapped into a dictionary where ``taps`` is set to ``[0]``


    :param outputs_info:
        ``outputs_info`` is the list of Theano variables or dictionaries
        describing the initial state of the outputs computed
        recurrently. When these initial states are given as dictionaries,
        optional information can be provided about the output corresponding
        to these initial states. The dictionary should have the following
        keys:

        * ``initial`` -- Theano variable that represents the initial
          state of a given output. In case the output is not computed
          recursively (think of a map) and does not require an initial
          state this field can be skipped. Given that only the previous
          time step of the output is used by ``fn`` the initial state
          should have the same shape as the output. If multiple time
          taps are used, the initial state should have one extra
          dimension that should cover all the possible taps. For example
          if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0,
          ``fn`` will require (by an abuse of notation) ``output[-5]``,
          ``output[-2]`` and ``output[-1]``. This will be given by
          the initial state, which in this case should have the shape
          (5,)+output.shape. If this variable containing the initial
          state is called ``init_y`` then ``init_y[0]`` *corresponds to*
          ``output[-5]``. ``init_y[1]`` *corresponds to* ``output[-4]``,
          ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]``
          corresponds to ``output[-2]``, ``init_y[4]`` corresponds to
          ``output[-1]``. While this order might seem strange, it comes
          naturally from splitting an array at a given point. Assume that
          we have an array ``x``, and we choose ``k`` to be time step
          ``0``. Then our initial state would be ``x[:k]``, while the
          output will be ``x[k:]``. Looking at this split, elements in
          ``x[:k]`` are ordered exactly like those in ``init_y``.
        * ``taps`` -- Temporal taps of the output that will be passed to
          ``fn``. They are provided as a list of *negative* integers,
          where a value ``k`` implies that at iteration step ``t`` scan
          will pass to ``fn`` the slice ``t+k``.

        ``scan`` will follow this logic if partial information is given:

        * If an output is not wrapped in a dictionary, ``scan`` will wrap
          it in one assuming that you use only the last step of the output
          (i.e. it makes your tap value list equal to [-1]).
        * If you wrap an output in a dictionary and you do not provide any
          taps but you provide an initial state it will assume that you are
          using only a tap value of -1.
        * If you wrap an output in a dictionary but you do not provide any
          initial state, it assumes that you are not using any form of
          taps.
        * If you provide a ``None`` instead of a variable or an empty
          dictionary ``scan`` assumes that you will not use any taps for
          this output (like for example in case of a map)

        If ``outputs_info`` is an empty list or None, ``scan`` assumes
        that no tap is used for any of the outputs. If information is
        provided just for a subset of the outputs an exception is
        raised (because there is no convention on how scan should map
        the provided information to the outputs of ``fn``)


    :param non_sequences:
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each step. One can opt to exclude variables
        used in ``fn`` from this list as long as they are part of the
        computational graph, though for clarity we encourage not to do so.


    :param n_steps:
        ``n_steps`` is the number of steps to iterate given as an int
        or Theano scalar. If any of the input sequences do not have
        enough elements, scan will raise an error. If the *value is 0* the
        outputs will have *0 rows*. If the value is negative, ``scan``
        will run backwards in time. If the ``go_backwards`` flag is already
        set and also ``n_steps`` is negative, ``scan`` will run forward
        in time. If ``n_steps`` is not provided, ``scan`` will figure
        out the number of steps it should run given its input sequences.
        If ``n_steps`` is known to be 1 or -1, the Scan op is skipped and
        the inner function is applied only once.


    :param truncate_gradient:
        ``truncate_gradient`` is the number of steps to use in truncated
        BPTT.  If you compute gradients through a scan op, they are
        computed using backpropagation through time. By providing a
        value different from -1, you choose to use truncated BPTT instead
        of classical BPTT, where you go for only ``truncate_gradient``
        number of steps back in time.


    :param go_backwards:
        ``go_backwards`` is a flag indicating if ``scan`` should go
        backwards through the sequences. If you think of each sequence
        as indexed by time, making this flag True would mean that
        ``scan`` goes back in time, namely that for any sequence it
        starts from the end and goes towards 0.


    :param name:
        When profiling ``scan``, it is crucial to provide a name for any
        instance of ``scan``. The profiler will produce an overall
        profile of your code as well as profiles for the computation of
        one step of each instance of ``scan``. The ``name`` of the instance
        appears in those profiles and can greatly help to disambiguate
        information.

    :param mode:
        It is recommended to leave this argument to None, especially
        when profiling ``scan`` (otherwise the results are not going to
        be accurate). If you prefer the computations of one step of
        ``scan`` to be done differently than the entire function, you
        can use this parameter to describe how the computations in this
        loop are done (see ``theano.function`` for details about
        possible values and their meaning).

    :param options:
        Optional dictionary of extra options forwarded to the Scan op.
        The entries ``name``, ``profile``, ``mode``, ``inplace``, ``gpu``,
        ``truncate_gradient`` and ``hash_inner_graph`` are always
        overwritten by this function. The dictionary passed in is copied
        and never modified.

    :param profile:
        Flag or string. If true, or different from the empty string, a
        profile object will be created and attached to the inner graph of
        scan. In case ``profile`` is True, the profile object will have the
        name of the scan instance, otherwise it will have the passed string.
        The profile object collects (and prints) information only when
        running the inner graph with the new cvm linker (with default modes;
        with other linkers this argument is useless)

    :rtype: tuple
    :return: tuple of the form (outputs, updates); ``outputs`` is a list
             of Theano variables representing the outputs of ``scan`` (in
             the same order as in ``outputs_info``). ``updates`` is a
             dictionary mapping each shared variable updated inside scan
             to its final updated value. This dictionary should be passed
             to ``theano.function`` when you compile your function.
    """
    # Note : see the internal documentation of the scan op for naming
    # conventions and all other details

    # Copy the caller's options: step 5.3 writes several keys into this
    # dictionary, and the caller may reuse the object for other calls.
    if options is None:
        options = {}
    else:
        options = dict(options)
    rvals = scan_utils.canonical_arguments(sequences,
                                           outputs_info,
                                           non_sequences,
                                           go_backwards,
                                           n_steps)
    inputs, states_and_outputs_info, parameters, T = rvals
    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    T_value = None
    if isinstance(n_steps, (float, int)):
        T_value = int(n_steps)
    else:
        try:
            T_value = opt.get_constant_value(n_steps)
        except (TypeError, AttributeError):
            T_value = None

    if T_value in (1, -1):
        return one_step_scan(fn,
                             inputs,
                             states_and_outputs_info,
                             parameters,
                             truncate_gradient)

    # 1. Variable representing the current time step
    t = scalar_shared(numpy.int64(0), name='t')

    # 2. Allocate memory for the states of scan. Every state/output gets a
    # ``mintap`` (how many past steps it needs, 0 for plain outputs) and a
    # shared ``length`` variable used for circular indexing of its buffer.
    mintaps = []
    lengths = []
    for pos, arg_info in enumerate(states_and_outputs_info):
        if arg_info.get('taps', None) == [-1]:
            mintaps.append(1)
            lengths.append(scalar_shared(numpy.int64(0),
                                         name='l%d' % pos))
            arg_info['initial'] = scan_utils.expand(tensor.unbroadcast(
                    tensor.shape_padleft(arg_info['initial']), 0), T)
        elif arg_info.get('taps', None):
            if numpy.any(numpy.array(arg_info.get('taps', [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence we can not provide such values
                raise ValueError('Can not use future taps of outputs',
                                 arg_info)
            mintap = abs(numpy.min(arg_info['taps']))
            lengths.append(scalar_shared(numpy.int64(0),
                                         name='l%d' % pos))
            mintaps.append(mintap)
            arg_info['initial'] = scan_utils.expand(
                arg_info['initial'][:mintap], T)
        else:
            mintaps.append(0)
            lengths.append(scalar_shared(numpy.int64(0),
                                         name='l%d' % pos))

    # 3. Generate arguments for the function passed to scan. That
    # function will return the outputs that need to be computed at every
    # time step
    inputs_slices = [seq[t] for seq in inputs]
    states_slices = []
    for n, state in enumerate(states_and_outputs_info):
        # Check if it is actually a state and not an output
        if mintaps[n] != 0:
            for k in state['taps']:
                states_slices.append(
                    state['initial'][(t + mintaps[n] + k) % lengths[n]])

    # 4. Construct outputs that are to be computed by the inner
    # function of scan
    args = inputs_slices + states_slices + parameters
    cond, states_and_outputs, updates = \
            scan_utils.get_updates_and_outputs(fn(*args))

    # User is allowed to provide no information if it only behaves like a
    # map
    if (len(states_and_outputs) != len(states_and_outputs_info) and
        len(states_and_outputs_info) == 0):
        mintaps = [0] * len(states_and_outputs)
        # Step 5.2 indexes ``states_and_outputs_info`` and ``lengths`` once
        # per output, so every output needs an entry in both lists.
        # NOTE(review): assumes scan_utils.allocate_memory accepts an empty
        # info dict for an output without taps -- confirm.
        states_and_outputs_info = [{} for _ in states_and_outputs]
        lengths = [scalar_shared(numpy.int64(0), name='l%d' % pos)
                   for pos in xrange(len(states_and_outputs))]

    # 5. Construct the scan op
    # 5.1 Construct list of shared variables with updates (those that
    # can be treated as states (i.e. of TensorType) and those that can not
    # (like Random States)

    if cond is not None:
        _cond = [cond]
    else:
        _cond = []
    rvals = rebuild_collect_shared(
        states_and_outputs + _cond,
        updates=updates,
        rebuild_strict=True,
        copy_inputs_over=True,
        no_default_updates=False)

    # extracting the arguments
    input_variables, cloned_outputs, other_rval = rvals
    clone_d, update_d, update_expr, shared_inputs = other_rval
    additional_input_states = []
    additional_output_states = []
    additional_lengths = []
    additional_mintaps = []
    original_numeric_shared_variables = []

    non_numeric_input_states = []
    non_numeric_output_states = []
    original_non_numeric_shared_variables = []
    pos = len(lengths)
    for sv in shared_inputs:
        if sv in update_d:
            if isinstance(sv, (TensorVariable, TensorSharedVariable)):
                # We can treat it as a sit sot
                nw_state = scan_utils.expand(
                    tensor.unbroadcast(tensor.shape_padleft(sv), 0), T)
                additional_lengths.append(scalar_shared(numpy.int64(0),
                                                       name='l%d' % pos))
                pos = pos + 1
                additional_mintaps.append(1)
                additional_input_states.append(nw_state)
                additional_output_states.append(
                    scan_utils.clone(tensor.set_subtensor(
                        nw_state[(t + 1) % additional_lengths[-1]],
                        update_d[sv])))
                original_numeric_shared_variables.append(sv)
            else:
                non_numeric_input_states.append(sv)
                non_numeric_output_states.append(update_d[sv])
                original_non_numeric_shared_variables.append(sv)

    # Replace shared variables in the update by the value read from their
    # state buffer at step t. The buffer is written at (t + 1) % length
    # above, so it is read back with the same circular indexing that step 3
    # uses for the other states.
    _additional_output_states = []
    replace = {}
    for sv, buf, length in izip(original_numeric_shared_variables,
                                additional_input_states,
                                additional_lengths):
        replace[sv] = buf[t % length]
    for out in additional_output_states:
        _additional_output_states.append(
            scan_utils.clone(out, replace=replace))
    additional_output_states = _additional_output_states

    # 5.2 Collect inputs/outputs of the inner function
    inputs = []
    outputs = []
    for n, mintap in enumerate(mintaps):
        if mintap != 0:
            input_state = states_and_outputs_info[n]['initial']
            inputs.append(input_state)
            outputs.append(
                tensor.set_subtensor(
                    input_state[(t + mintap) % lengths[n]],
                    states_and_outputs[n]))
        else:
            # Plain output (no taps): allocate a fresh buffer and write the
            # value computed at step t into it.
            mem_buffer = scan_utils.allocate_memory(
                T, states_and_outputs_info[n], states_and_outputs[n])
            inputs.append(mem_buffer)
            outputs.append(
                tensor.set_subtensor(mem_buffer[t % lengths[n]],
                                     states_and_outputs[n]))
    inputs.extend(additional_input_states)
    outputs.extend(additional_output_states)
    lengths.extend(additional_lengths)
    mintaps.extend(additional_mintaps)
    inputs.extend(non_numeric_input_states)
    outputs.extend(non_numeric_output_states)
    all_other_inputs = gof.graph.inputs(outputs)
    parameters = [x for x in all_other_inputs
                  if (x not in inputs and x not in lengths and x is not t
                      and isinstance(x, gof.Variable) and
                      not isinstance(x, gof.Constant))]
    inputs.extend(parameters)
    # 5.3 Construct the options dictionary
    options['name'] = name
    options['profile'] = profile
    options['mode'] = mode
    options['inplace'] = False
    options['gpu'] = False
    options['truncate_gradient'] = truncate_gradient
    options['hash_inner_graph'] = 0
    # 5.4 Construct the ScanOp instance
    local_op = scan_op.ScanOp(inputs=inputs,
                              outputs=outputs,
                              lengths=lengths,
                              switches=[],
                              mintaps=mintaps,
                              index=t,
                              options=options,
                              as_repeatUntil=cond)
    # Note that we get here all the outputs followed by the update rules to
    # the shared variables we had in our scan
    # we know that we have (in this given order):
    #   * len(states_and_outputs) real outputs
    #   * len(additional_input_states) updates for numeric shared variable
    #   * len(non_numeric_input_states) updates for non numeric shared
    #   variables
    scan_inputs = [T] + inputs
    scan_outputs_update_rules = scan_utils.to_list(local_op(*scan_inputs))
    # 5.5 Collect outputs and add permutation object
    scan_outputs = []
    for pos in xrange(len(states_and_outputs)):
        out = scan_utils.ScanPermutation(mintaps[pos])(
            scan_outputs_update_rules[pos], t)
        # Drop the leading ``mintap`` entries, which hold the initial state
        scan_outputs.append(out[mintaps[pos]:])
    # 5.6 Construct updates dictionary
    update_rules = scan_outputs_update_rules[len(states_and_outputs):]
    updates = {}
    for v, u in izip(original_numeric_shared_variables,
                     update_rules[:len(additional_input_states)]):
        # Numeric shared variables were stored as a buffer over time; the
        # final update is its last entry.
        updates[v] = u[-1]
    for v, u in izip(original_non_numeric_shared_variables,
                     update_rules[len(additional_input_states):]):
        updates[v] = u
    # Step 5.7 We are done and can return everything back to the user
    return scan_outputs, updates
Ejemplo n.º 5
0
def scan(fn,
         sequences=None,
         outputs_info=None,
         non_sequences=None,
         n_steps=None,
         truncate_gradient=-1,
         go_backwards=False,
         mode=None,
         name=None,
         options=None,
         profile=False):
    """
    This function constructs and applies a Scan op to the provided
    arguments.

    :param fn:
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. ``fn`` should construct variables describing the
        output of one iteration step. It should expect as input theano
        variables representing all the slices of the input sequences
        and previous values of the outputs, as well as all other arguments
        given to scan as ``non_sequences``. The order in which scan passes
        these variables to ``fn``  is the following :

        * all time slices of the first sequence
        * all time slices of the second sequence
        * ...
        * all time slices of the last sequence
        * all past slices of the first output
        * all past slices of the second output
        * ...
        * all past slices of the last output
        * all other arguments (the list given as `non_sequences` to
            scan)

        The order of the sequences is the same as the one in the list
        `sequences` given to scan. The order of the outputs is the same
        as the order of ``outputs_info``. For any sequence or output the
        order of the time slices is the same as the one in which they have
        been given as taps. For example if one writes the following :

        .. code-block:: python

            scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
                                 , Sequence2
                                 , dict(input =  Sequence3, taps = 3) ]
                   , outputs_info = [ dict(initial =  Output1, taps = [-3,-5])
                                    , dict(initial = Output2, taps = None)
                                    , Output3 ]
                   , non_sequences = [ Argument1, Argument2])

        ``fn`` should expect the following arguments in this given order:

        #. ``Sequence1[t-3]``
        #. ``Sequence1[t+2]``
        #. ``Sequence1[t-1]``
        #. ``Sequence2[t]``
        #. ``Sequence3[t+3]``
        #. ``Output1[t-3]``
        #. ``Output1[t-5]``
        #. ``Output3[t-1]``
        #. ``Argument1``
        #. ``Argument2``

        The list of ``non_sequences`` can also contain shared variables
        used in the function, though ``scan`` is able to figure those
        out on its own so they can be skipped. For the clarity of the
        code we recommend though to provide them to scan. To some extent
        ``scan`` can also figure out other ``non sequences`` (not shared)
        even if not passed to scan (but used by `fn`). A simple example of
        this would be :

        .. code-block:: python

            import theano.tensor as TT
            W   = TT.matrix()
            W_2 = W**2
            def f(x):
                return TT.dot(x,W_2)

        The function is expected to return two things. One is a list of
        outputs ordered in the same order as ``outputs_info``, with the
        difference that there should be only one output variable per
        output initial state (even if no tap value is used). Secondly
        `fn` should return an update dictionary (that tells how to
        update any shared variable after each iteration step). The
        dictionary can optionally be given as a list of tuples. There is
        no constraint on the order of these two lists, ``fn`` can return
        either ``(outputs_list, update_dictionary)`` or
        ``(update_dictionary, outputs_list)`` or just one of the two (in
        case the other is empty).

        To use ``scan`` as a while loop, the user needs to change the
        function ``fn`` such that also a stopping condition is returned.
        To do so, he/she needs to wrap the condition in an ``until`` class.
        The condition should be returned as a third element, for example:

        .. code-block:: python

            ...
            return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x < 50)

        Note that a number of steps (considered in here as the maximum
        number of steps) is still required even though a condition is
        passed (and it is used to allocate memory if needed).

    :param sequences:
        ``sequences`` is the list of Theano variables or dictionaries
        describing the sequences ``scan`` has to iterate over. If a
        sequence is given as wrapped in a dictionary, then a set of optional
        information can be provided about the sequence. The dictionary
        should have the following keys:

        * ``input`` (*mandatory*) -- Theano variable representing the
          sequence.

        * ``taps`` -- Temporal taps of the sequence required by ``fn``.
          They are provided as a list of integers, where a value ``k``
          implies that at iteration step ``t`` scan will pass to ``fn``
          the slice ``t+k``. Default value is ``[0]``

        Any Theano variable in the list ``sequences`` is automatically
        wrapped into a dictionary where ``taps`` is set to ``[0]``


    :param outputs_info:
        ``outputs_info`` is the list of Theano variables or dictionaries
        describing the initial state of the outputs computed
        recurrently. When these initial states are given as dictionaries,
        optional information can be provided about the output corresponding
        to these initial states. The dictionary should have the following
        keys:

        * ``initial`` -- Theano variable that represents the initial
          state of a given output. In case the output is not computed
          recursively (think of a map) and does not require an initial
          state this field can be skipped. Given that only the previous
          time step of the output is used by ``fn`` the initial state
          should have the same shape as the output. If multiple time
          taps are used, the initial state should have one extra
          dimension that should cover all the possible taps. For example
          if we use ``-5``, ``-2`` and ``-1`` as past taps, at step 0,
          ``fn`` will require (by an abuse of notation) ``output[-5]``,
          ``output[-2]`` and ``output[-1]``. This will be given by
          the initial state, which in this case should have the shape
          (5,)+output.shape. If this variable containing the initial
          state is called ``init_y`` then ``init_y[0]`` *corresponds to*
          ``output[-5]``. ``init_y[1]`` *corresponds to* ``output[-4]``,
          ``init_y[2]`` corresponds to ``output[-3]``, ``init_y[3]``
          corresponds to ``output[-2]``, ``init_y[4]`` corresponds to
          ``output[-1]``. While this order might seem strange, it comes
          naturally from splitting an array at a given point. Assume that
          we have an array ``x``, and we choose ``k`` to be time step
          ``0``. Then our initial state would be ``x[:k]``, while the
          output will be ``x[k:]``. Looking at this split, elements in
          ``x[:k]`` are ordered exactly like those in ``init_y``.
        * ``taps`` -- Temporal taps of the output that will be passed to
          ``fn``. They are provided as a list of *negative* integers,
          where a value ``k`` implies that at iteration step ``t`` scan
          will pass to ``fn`` the slice ``t+k``. Positive taps raise a
          ``ValueError``.

        ``scan`` will follow this logic if partial information is given:

        * If an output is not wrapped in a dictionary, ``scan`` will wrap
          it in one assuming that you use only the last step of the output
          (i.e. it makes your tap value list equal to [-1]).
        * If you wrap an output in a dictionary and you do not provide any
          taps but you provide an initial state it will assume that you are
          using only a tap value of -1.
        * If you wrap an output in a dictionary but you do not provide any
          initial state, it assumes that you are not using any form of
          taps.
        * If you provide a ``None`` instead of a variable or an empty
          dictionary ``scan`` assumes that you will not use any taps for
          this output (like for example in case of a map)

        If ``outputs_info`` is an empty list or None, ``scan`` assumes
        that no tap is used for any of the outputs. If information is
        provided just for a subset of the outputs an exception is
        raised (because there is no convention on how scan should map
        the provided information to the outputs of ``fn``)


    :param non_sequences:
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each step. One can opt to exclude variables
        used in ``fn`` from this list as long as they are part of the
        computational graph, though for clarity we encourage not to do so.


    :param n_steps:
        ``n_steps`` is the number of steps to iterate given as an int
        or Theano scalar. If any of the input sequences do not have
        enough elements, scan will raise an error. If the *value is 0* the
        outputs will have *0 rows*. If the value is negative, ``scan``
        will run backwards in time. If the ``go_backwards`` flag is already
        set and also ``n_steps`` is negative, ``scan`` will run forward
        in time. If ``n_steps`` is not provided, ``scan`` will figure
        out the amount of steps it should run given its input sequences.
        When ``n_steps`` is known before compilation to be ``1`` or ``-1``
        no Scan op is built: the inner function is applied once through
        ``one_step_scan``.


    :param truncate_gradient:
        ``truncate_gradient`` is the number of steps to use in truncated
        BPTT.  If you compute gradients through a scan op, they are
        computed using backpropagation through time. By providing a
        value different from -1, you choose to use truncated BPTT instead
        of classical BPTT, where you go for only ``truncate_gradient``
        number of steps back in time.


    :param go_backwards:
        ``go_backwards`` is a flag indicating if ``scan`` should go
        backwards through the sequences. If you think of each sequence
        as indexed by time, making this flag True would mean that
        ``scan`` goes back in time, namely that for any sequence it
        starts from the end and goes towards 0.


    :param name:
        When profiling ``scan``, it is crucial to provide a name for any
        instance of ``scan``. The profiler will produce an overall
        profile of your code as well as profiles for the computation of
        one step of each instance of ``scan``. The ``name`` of the instance
        appears in those profiles and can greatly help to disambiguate
        information.

    :param mode:
        It is recommended to leave this argument to None, especially
        when profiling ``scan`` (otherwise the results are not going to
        be accurate). If you prefer the computations of one step of
        ``scan`` to be done differently than the entire function, you
        can use this parameter to describe how the computations in this
        loop are done (see ``theano.function`` for details about
        possible values and their meaning).

    :param options:
        Optional dictionary of extra options forwarded to the Scan op.
        The dictionary is copied, so the caller's object is never
        modified. The keys ``name``, ``profile``, ``mode``, ``inplace``,
        ``gpu``, ``truncate_gradient`` and ``hash_inner_graph`` are always
        overwritten by ``scan`` itself.

    :param profile:
        Flag or string. If true, or different from the empty string, a
        profile object will be created and attached to the inner graph of
        scan. In case ``profile`` is True, the profile object will have the
        name of the scan instance, otherwise it will have the passed string.
        The profile object collects (and prints) information only when the
        inner graph is run with the new cvm linker (the one used by the
        default modes); with other linkers this argument has no effect.

    :rtype: tuple
    :return: tuple of the form (outputs, updates); ``outputs`` is a list of
             Theano variables representing the outputs of ``scan`` (in the
             same order as in ``outputs_info``). ``updates`` is a plain
             dictionary mapping every shared variable updated inside scan
             to the expression giving its value after the loop. This
             dictionary should be passed to ``theano.function`` when you
             compile your function. When the one-step shortcut is taken
             (``n_steps`` known to be 1 or -1), the pair returned by
             ``one_step_scan`` is returned unchanged.
    """
    # Note : see the internal documentation of the scan op for naming
    # conventions and all other details

    # Work on a private copy of ``options``: several of its keys are
    # overwritten in step 5.3 and the dict is handed to the ScanOp built in
    # step 5.4 (which presumably keeps a reference to it). Mutating the
    # caller's dict would leak those values back to the caller and, if the
    # same dict were reused for another call, could rewrite the options of
    # an already-built ScanOp.
    if options is None:
        options = {}
    else:
        options = dict(options)
    rvals = scan_utils.canonical_arguments(sequences,
                                           outputs_info,
                                           non_sequences,
                                           go_backwards,
                                           n_steps)
    inputs, states_and_outputs_info, parameters, T = rvals
    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    T_value = None
    if isinstance(n_steps, (float, int)):
        T_value = int(n_steps)
    else:
        try:
            T_value = opt.get_scalar_constant_value(n_steps)
        except (TypeError, AttributeError):
            T_value = None

    if T_value in (1, -1):
        return one_step_scan(fn,
                             inputs,
                             states_and_outputs_info,
                             parameters,
                             truncate_gradient)

    # 1. Variable representing the current time step
    t = scalar_shared(numpy.int64(0), name='t')

    # 2. Allocate memory for the states of scan.
    # ``mintaps[i]`` is how many initial-state rows output ``i`` needs
    # (0 for outputs that have no taps), ``lengths[i]`` is a shared scalar
    # used as the modulus when indexing the state buffer of output ``i``.
    mintaps = []
    lengths = []
    for pos, arg_info in enumerate(states_and_outputs_info):
        if arg_info.get('taps', None) == [-1]:
            mintaps.append(1)
            lengths.append(scalar_shared(numpy.int64(0),
                                         name='l%d' % pos))
            arg_info['initial'] = scan_utils.expand(tensor.unbroadcast(
                    tensor.shape_padleft(arg_info['initial']), 0), T)
        elif arg_info.get('taps', None):
            if numpy.any(numpy.array(arg_info.get('taps', [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence we can not provide such values
                raise ValueError('Can not use future taps of outputs',
                                 arg_info)
            mintap = abs(numpy.min(arg_info['taps']))
            lengths.append(scalar_shared(numpy.int64(0),
                                         name='l%d' % pos))
            mintaps.append(mintap)
            arg_info['initial'] = scan_utils.expand(
                arg_info['initial'][:mintap], T)
        else:
            mintaps.append(0)
            lengths.append(scalar_shared(numpy.int64(0),
                                         name='l%d' % pos))

    # 3. Generate arguments for the function passed to scan. This
    # function will return the outputs that need to be computed at every
    # timestep
    inputs_slices = [input[t] for input in inputs]
    states_slices = []
    for n, state in enumerate(states_and_outputs_info):
        # Check if it is actually a state and not an output
        if mintaps[n] != 0:
            for k in state['taps']:
                states_slices.append(
                    state['initial'][(t + mintaps[n] + k) % lengths[n]])

    # 4. Construct outputs that are to be computed by the inner
    # function of scan
    args = inputs_slices + states_slices + parameters
    cond, states_and_outputs, updates = \
            scan_utils.get_updates_and_outputs(fn(*args))

    # User is allowed to provide no information if it only behaves like a
    # map
    # NOTE(review): on this path ``lengths`` and ``states_and_outputs_info``
    # stay empty, yet step 5.2 indexes both with every position of
    # ``mintaps``; confirm whether this branch is reachable and working.
    if (len(states_and_outputs) != len(states_and_outputs_info) and
        len(states_and_outputs_info) == 0):
        mintaps = [0] * len(states_and_outputs)

    # 5. Construct the scan op
    # 5.1 Construct list of shared variables with updates (those that
    # can be treated as states (i.e. of TensorType) and those that can not
    # (like Random States)

    if cond is not None:
        _cond = [cond]
    else:
        _cond = []
    rvals = rebuild_collect_shared(
        states_and_outputs + _cond,
        updates=updates,
        rebuild_strict=True,
        copy_inputs_over=True,
        no_default_updates=False)

    # extracting the arguments
    input_variables, cloned_outputs, other_rval = rvals
    clone_d, update_d, update_expr, shared_inputs = other_rval
    additional_input_states = []
    additional_output_states = []
    additional_lengths = []
    additional_mintaps = []
    original_numeric_shared_variables = []

    non_numeric_input_states = []
    non_numeric_output_states = []
    original_non_numeric_shared_variables = []
    pos = len(lengths)
    for sv in shared_inputs:
        if sv in update_d:
            if isinstance(sv, (TensorVariable, TensorSharedVariable)):
                # We can treat it as a sit sot
                nw_state = scan_utils.expand(
                    tensor.unbroadcast(tensor.shape_padleft(sv), 0), T)
                additional_lengths.append(scalar_shared(numpy.int64(0),
                                                       name='l%d' % pos))
                pos = pos + 1
                additional_mintaps.append(1)
                additional_input_states.append(nw_state)
                additional_output_states.append(
                    scan_utils.clone(tensor.set_subtensor(
                        nw_state[(t + 1) % additional_lengths[-1]],
                        update_d[sv])))
                original_numeric_shared_variables.append(sv)
            else:
                non_numeric_input_states.append(sv)
                non_numeric_output_states.append(update_d[sv])
                original_non_numeric_shared_variables.append(sv)

    # Replace shared variables in the update
    _additional_output_states = []
    replace = {}
    for sv, buf in zip(original_numeric_shared_variables,
                       additional_input_states):
        replace[sv] = buf[t]
    for out in additional_output_states:
        _additional_output_states.append(
            scan_utils.clone(out, replace=replace))
    additional_output_states = _additional_output_states

    # 5.2 Collect inputs/outputs of the inner function
    inputs = []
    outputs = []
    for n, mintap in enumerate(mintaps):
        if mintap != 0:
            input_state = states_and_outputs_info[n]['initial']
            inputs.append(input_state)
            outputs.append(
                tensor.set_subtensor(
                    input_state[(t + mintap) % lengths[n]],
                    states_and_outputs[n]))
        else:
            # Output without taps: allocate a fresh buffer and write the
            # value computed at step ``t`` into it. (Previously this branch
            # referenced an undefined name ``output`` instead of the buffer
            # just allocated, raising NameError.)
            mem_buffer = scan_utils.allocate_memory(
                T, states_and_outputs_info[n], states_and_outputs[n])
            inputs.append(mem_buffer)
            outputs.append(
                tensor.set_subtensor(mem_buffer[t % lengths[n]],
                                     states_and_outputs[n]))
    inputs.extend(additional_input_states)
    outputs.extend(additional_output_states)
    lengths.extend(additional_lengths)
    mintaps.extend(additional_mintaps)
    inputs.extend(non_numeric_input_states)
    outputs.extend(non_numeric_output_states)
    # Any remaining free (non-constant) variable of the inner graph that is
    # not already an input, a length or the time index becomes a parameter.
    all_other_inputs = gof.graph.inputs(outputs)
    parameters = [x for x in all_other_inputs
                  if (x not in inputs and x not in lengths and x is not t
                      and isinstance(x, gof.Variable) and
                      not isinstance(x, gof.Constant))]
    inputs.extend(parameters)
    # 5.3 Construct the options dictionary
    options['name'] = name
    options['profile'] = profile
    options['mode'] = mode
    options['inplace'] = False
    options['gpu'] = False
    options['truncate_gradient'] = truncate_gradient
    options['hash_inner_graph'] = 0
    # 5.4 Construct the ScanOp instance
    local_op = scan_op.ScanOp(inputs=inputs,
                              outputs=outputs,
                              lengths=lengths,
                              switches=[],
                              mintaps=mintaps,
                              index=t,
                              options=options,
                              as_repeatUntil=cond)
    # Note that we get here all the outputs followed by the update rules to
    # the shared variables we had in our scan
    # we know that we have (in this given order):
    #   * len(states_and_outputs) real outputs
    #   * len(additional_input_states) updates for numeric shared variable
    #   * len(non_numeric_input_states) updates for non numeric shared
    #   variables
    scan_inputs = [T] + inputs
    scan_outputs_update_rules = scan_utils.to_list(local_op(*scan_inputs))
    # 5.5 Collect outputs and add permutation object
    scan_outputs = []
    for pos in xrange(len(states_and_outputs)):
        out = scan_utils.ScanPermutation(mintaps[pos])(
            scan_outputs_update_rules[pos], t)
        # Drop the leading initial-state rows so only computed steps remain
        scan_outputs.append(out[mintaps[pos]:])
    # 5.6 Construct updates dictionary
    update_rules = scan_outputs_update_rules[len(states_and_outputs):]
    updates = {}
    for v, u in izip(original_numeric_shared_variables,
                     update_rules[:len(additional_input_states)]):
        # Numeric shared variables were stored as a buffer over time; the
        # final value is the last row.
        updates[v] = u[-1]
    for v, u in izip(original_non_numeric_shared_variables,
                     update_rules[len(additional_input_states):]):
        updates[v] = u
    # Step 5.7 We are done and can return everything back to the user
    return scan_outputs, updates
Ejemplo n.º 6
0
                        # No need to print a warning or raise an error now,
                        # it will be done when fn will be called.
                        _logger.info(
                            ('Cannot compute test value for the '
                             'inner function of scan, input value missing %s'),
                            e)

            if getattr(init_out['initial'], 'name', None) is not None:
                arg.name = init_out['initial'].name + '[t-1]'

            # We need now to allocate space for storing the output and copy
            # the initial state over. We do this using the expand function
            # defined in scan utils
            sit_sot_scan_inputs.append(
                scan_utils.expand(
                    tensor.unbroadcast(tensor.shape_padleft(actual_arg), 0),
                    actual_n_steps))

            sit_sot_inner_slices.append(actual_arg)
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1

        elif init_out.get('taps', None):

            if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence we can not provide such values
                raise ValueError('Can not use future taps of outputs',
Ejemplo n.º 7
0
                except AttributeError, e:
                    if config.compute_test_value != "ignore":
                        # No need to print a warning or raise an error now,
                        # it will be done when fn will be called.
                        _logger.info(
                            ("Cannot compute test value for the " "inner function of scan, input value missing %s"), e
                        )

            if getattr(init_out["initial"], "name", None) is not None:
                arg.name = init_out["initial"].name + "[t-1]"

            # We need now to allocate space for storing the output and copy
            # the initial state over. We do this using the expand function
            # defined in scan utils
            sit_sot_scan_inputs.append(
                scan_utils.expand(tensor.unbroadcast(tensor.shape_padleft(actual_arg), 0), actual_n_steps)
            )

            sit_sot_inner_slices.append(actual_arg)
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1

        elif init_out.get("taps", None):

            if numpy.any(numpy.array(init_out.get("taps", [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence we can not provide such values
                raise ValueError("Can not use future taps of outputs", init_out)