Beispiel #1
0
    def process_node(self, env, node):
        """Rewrite one scan node so it iterates and stores no more than needed.

        The clients of every output of ``node`` are inspected.  When all
        clients of the non-shared outputs are subtensors, the largest index
        that is actually read bounds the number of steps the scan has to run,
        and the smallest index that is read bounds how many intermediate
        values have to be kept.  A new ``Scan`` node is built from the
        adjusted inputs and the old outputs (or the subtensors taken from
        them) are replaced through ``env.replace_all_validate`` with the
        reason ``"scan_save_mem"``.

        :param env: graph container the node belongs to; only its
            ``replace_all_validate`` method is used here.
        :param node: apply node whose ``op`` exposes the scan counters
            (``n_seqs``, ``n_mit_mot``, ``n_mit_sot``, ``n_sit_sot``,
            ``n_nit_sot``, ``n_shared_outs``) and ``tap_array``.
            ``node.inputs[0]`` is treated as the number of steps.

        :return: ``None``; the graph is modified in place.
        """

        # helpful functions
        def select_min(x, y):
            # minimum of two values where ``None`` means "no constraint"
            if x is None:
                return y
            if y is None:
                return x
            return tensor.minimum(x, y)

        def select_max(x, y):
            # maximum of two values where ``None`` means "no constraint"
            if x is None:
                return y
            if y is None:
                return x
            return tensor.maximum(x, y)

        def sanitize(x):
            # turn python/numpy values back into tensor variables, keep None
            if x is None:
                return None
            return tensor.as_tensor_variable(x)

        # The shape feature is optional.  Every access to ``shape_of`` below
        # is wrapped in a try/except KeyError, so an empty dictionary simply
        # selects the fallback expression everywhere.
        shape_feature = getattr(node.env, "shape_feature", None)
        if shape_feature is not None:
            shape_of = shape_feature.shape_of
        else:
            shape_of = {}

        # 1. Initialization of variables
        # Note 1) We do not actually care about outputs representing shared
        # variables (those have no intermediate values) so it is safer to
        # ignore them and not change them in any way. To simplify the
        # optimizations I construct the variable ``c_outs`` ( that counts
        # outputs up to those we care) and the list ``init_l`` which for any
        # output we care says the length of its initial state. Note that
        # defining ``init_l`` for mit_mot sequences is a bit trickier but
        # it is safe to set it to 0
        op = node.op
        c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

        init_l = [0] * op.n_mit_mot
        init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
        init_l += [0] * op.n_nit_sot

        # 2. Check the clients of each output and see for how many steps
        # does scan need to run

        # 2.1 Initialize
        # global_nsteps is a dictionary having two fields ( 'real' deals
        # with int values, 'sym' with symbolic ones) or None
        # given that a scan op has k outputs o_1, .. o_k and each
        # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
        # global_nsteps is None if any of the clients is different
        # from a subtensor or its real and sym field equal to
        # max(c_i_j.idx_list[0].stop), meaning store up to which maximal
        # index(step) for any output scan actually needs to compute
        # In other words n_steps should be equal to this maximal !
        # Note: if we have a shared variable that gets updated at every step
        # of the loop, reducing the number of steps will affect the
        # value of the shared variable after the loop so we must not
        # change the number of steps in that case. To do this we set
        # global_nsteps to None which is seen as a flag that nothing needs
        # to be done.
        # The comparison below checks if there is any uncounted output,
        # which can only be an output corresponding to a shared variable.
        if len(node.outputs) <= c_outs:
            global_nsteps = {"real": -1, "sym": []}
        else:
            global_nsteps = None

        # Keeps track of the original slices that each client represent
        slices = [None for o in node.outputs]

        # A list for each output indicating how many intermediate values
        # should be stored. If negative it means none of the intermediate
        # values (i.e. the output can be removed since it is not used
        # afterwards in the computations), if 0 it means that all
        # intermediate values are required, otherwise is up to that number
        # of intermediate values
        # Note that for mit_mot outputs and shared outputs we can not change
        # the number of intermediate steps stored without affecting the
        # result of the op
        store_steps = [0] * op.n_mit_mot
        store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
        # Flag that says if an input has changed and we need to do something
        # or not
        flag_store = False

        # 2.2 Loop over the clients
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            slices[i] = []
            for cl, _ in out.clients:

                # 2.2.1 outputs of the function
                # => output needs all its intermediate values
                if isinstance(cl, str):
                    # if the node is actually an output, then
                    # we need to store the entire thing
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.2.2 non-subtensor nodes
                # => output needs all its intermediate values
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.2.3 subtensor nodes
                # => output might need to store just a subset of its values
                else:
                    # extract idx list of subtensor
                    this_slice = tensor.basic.get_idx_list(cl.inputs, cl.op.idx_list)
                    if this_slice is None:
                        # if unable to extract idx_list
                        # => outputs needs all its intermediate values
                        global_nsteps = None
                        slices[i] = None
                        break

                    # extract the begin/end of the first dimension; the
                    # first output that is not a mit_mot has index
                    # ``op.n_mit_mot``, hence ``>=``
                    if i >= op.n_mit_mot:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(this_slice[0], length)
                    slices[i] += [(cf_slice, this_slice)]

                    if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
                        # An open-ended slice means the number of steps can
                        # not be reduced.  Do NOT leave the loop here: the
                        # remaining clients still have to be recorded in
                        # ``slices[i]`` because steps 3.7 / 3.8 index that
                        # list by client position.
                        global_nsteps = None
                    if isinstance(cf_slice[0], slice):
                        stop = tensor.basic.extract_constant(cf_slice[0].stop)
                    else:
                        stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                    if stop == sys.maxint or stop == length:
                        stop = None
                    else:
                        # there is a **gotcha** here ! Namely, scan returns an
                        # array that contains the initial state of the output
                        # as well. Which means that if have a initial state of
                        # length 3, and you look for 5 steps you get an output
                        # y of length 8. If you only use y[:5], this does not
                        # mean that you only need to loop for 5 steps but
                        # actually only for 2 steps ( the first 3 are the
                        # initial state)
                        stop = stop - init_l[i]

                    # we might get away with less number of steps
                    if stop is not None and global_nsteps is not None:
                        # yes if it is a tensor
                        if isinstance(stop, tensor.Variable):
                            global_nsteps["sym"] += [stop]
                        # not if it is maxint
                        elif type(stop) is int and stop == sys.maxint:
                            global_nsteps = None
                        # yes if it is a int k, 0 < k < maxint
                        elif type(stop) is int and global_nsteps["real"] < stop:
                            global_nsteps["real"] = stop
                        # yes if it is a int k, 0 < k < maxint
                        elif type(stop) is int and stop > 0:
                            pass
                        # not otherwise
                        else:
                            global_nsteps = None

        # 2.3. Analyze global_nsteps to figure out for how many steps scan
        # needs to iterate
        if global_nsteps is not None:
            # there are some symbolic tensors that limit the number of
            # steps
            if len(global_nsteps["sym"]) == 0:
                sym_steps = None
            else:
                sym_steps = global_nsteps["sym"][0]
                for c in global_nsteps["sym"][1:]:
                    sym_steps = tensor.maximum(sym_steps, c)

            if global_nsteps["real"] >= 0:
                real_steps = global_nsteps["real"]
            else:
                real_steps = None
            # never iterate for more steps than originally requested
            nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])
        else:
            nw_steps = node.inputs[0]

        # 2.4 Loop over the clients again now looking just to see how many
        # intermediate steps to store
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            for cl, _ in out.clients:
                if isinstance(cl, str):
                    store_steps[i] = 0
                    break
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    store_steps[i] = 0
                    break
                else:
                    this_slice = tensor.basic.get_idx_list(cl.inputs, cl.op.idx_list)
                    if this_slice is None:
                        store_steps[i] = 0
                        break

                    if isinstance(this_slice[0], slice) and this_slice[0].start is None:
                        # slice starts at the beginning => keep everything
                        store_steps[i] = 0
                        break

                    if i > op.n_mit_mot:
                        length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(this_slice[0], length)

                    if isinstance(cf_slice[0], slice):
                        start = tensor.basic.extract_constant(cf_slice[0].start)
                    else:
                        start = tensor.basic.extract_constant(cf_slice[0])
                    if start == 0 or store_steps[i] == 0:
                        store_steps[i] = 0
                    else:
                        pval = select_max(nw_steps - start + init_l[i], init_l[i])
                        if store_steps[i] != -1:
                            pval = select_max(pval, store_steps[i])

                        store_steps[i] = pval
                        flag_store = True

        # outputs still marked with a negative int have no client at all
        orphane_outs = [i for i, x in enumerate(store_steps) if (type(x) is int) and (x < 0)]
        flag_store = flag_store or (len(orphane_outs) > 0)
        # 3. is there anything to change ?
        if flag_store or global_nsteps is not None:
            # 3.1 initialize inputs for the new scan
            old_outputs = []
            nw_inputs = list(node.inputs)
            nw_inputs[0] = nw_steps

            # 3.2 check orphane outputs to see if we can eliminate any
            required, not_required = scan_utils.scan_can_remove_outs(node.op, orphane_outs)
            # 3.3. compose replace pairs for those nodes that need not
            # to store everything in memory ( or are orphane and required
            # by the inner function .. )
            replaced_outs = []
            offset = 1 + op.n_seqs + op.n_mit_mot
            for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
                i = idx + op.n_mit_mot
                if not (type(_val) is int and _val <= 0 and i not in required):

                    if i in required:
                        val = 1
                    else:
                        val = _val
                    # If the memory for this output has been pre-allocated
                    # before going into the scan op (by an alloc node)
                    if idx < op.n_mit_sot + op.n_sit_sot:
                        # In case the input is still an alloc node, we
                        # actually have two options:
                        #   a) the input is a set_subtensor, in that case we
                        #      can replace the initial tensor by a slice,
                        #   b) it is not, and we simply take a slice of it.
                        if nw_inputs[offset + idx].owner and isinstance(
                            nw_inputs[offset + idx].owner.op, tensor.IncSubtensor
                        ):
                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            tmp = pre_greedy_local_optimizer(list_opt_slice, tensor.as_tensor_variable(val - init_l[i]))
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = scan_utils.expand(_nw_input, tmp)
                        else:
                            tmp = pre_greedy_local_optimizer(list_opt_slice, tensor.as_tensor_variable(val))
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = nw_inputs[offset + idx][:tmp]

                        nw_inputs[offset + idx] = nw_input
                        replaced_outs.append(i)
                        old_outputs += [(i, [x[0].outputs[0] for x in node.outputs[i].clients])]
                    # If there is no memory pre-allocated for this output
                    elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:

                        pos = op.n_mit_mot + idx + op.n_seqs + 1 + op.n_shared_outs
                        if nw_inputs[pos] == node.inputs[0]:
                            nw_inputs[pos] = val
                        replaced_outs.append(i)
                        old_outputs += [(i, [x[0].outputs[0] for x in node.outputs[i].clients])]
            # 3.4. Recompute inputs for everything else based on the new
            # number of steps
            if global_nsteps is not None:
                for idx, val in enumerate(store_steps[op.n_mit_mot:]):
                    if val == 0:
                        if idx < op.n_mit_sot + op.n_sit_sot:
                            # NOTE(review): assumes this input still has an
                            # owner whose second input is the initial state
                            # (as in the IncSubtensor case of 3.3) -- confirm
                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            nw_input = scan_utils.expand(_nw_input, nw_steps)
                            nw_inputs[offset + idx] = nw_input
                        elif idx < (op.n_mit_sot + op.n_sit_sot + op.n_nit_sot):
                            in_idx = offset + idx + op.n_shared_outs
                            if nw_inputs[in_idx] == node.inputs[0]:
                                nw_inputs[in_idx] = nw_steps

            # 3.5 Remove unwanted orphane outputs
            (inps, outs, info, node_ins, compress_map) = scan_utils.compress_outs(op, not_required, nw_inputs)
            inv_compress_map = {}
            for k, v in compress_map.items():
                inv_compress_map[v] = k

            node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in node_ins]
            node_ins = pre_constant_merge(node_ins)
            # 3.6 Compose the new scan
            # I need to make sure I'm not reapplying the same optimization
            # twice since bad things usually happen if I do that
            info["_scan_merge_visited"] = True
            new_outs = scan_op.Scan(inps, outs, info).make_node(*node_ins).outputs

            old_new = []
            # 3.7 Get replace pairs for those outputs that do not change
            # the number of intermediate steps stored
            for idx, sl in enumerate(slices):
                if global_nsteps and sl is not None and store_steps[idx] == 0:
                    for hdx, cl in enumerate(node.outputs[idx].clients):
                        cnf_slice, old_slices = sl[hdx]
                        # Sanitize the nw_slice by converting ints back into
                        # constants :) I only need to do this for the first
                        # slice since that is the only slice

                        if isinstance(cnf_slice[0], slice):
                            fslice = slice(
                                sanitize(cnf_slice[0].start), sanitize(cnf_slice[0].stop), sanitize(cnf_slice[0].step)
                            )
                        else:
                            fslice = sanitize(cnf_slice[0])

                        nw_slice = (fslice,) + tuple(old_slices[1:])
                        nw_pos = inv_compress_map[idx]

                        subtens = tensor.basic.Subtensor(nw_slice)
                        # slice inputs
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice, lambda entry: isinstance(entry, tensor.Variable)
                        )
                        new_o = subtens.make_node(new_outs[nw_pos], *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[:: cnf_slice[1]]
                        replaced_outs.append(idx)
                        old_new += [(cl[0].outputs[0], new_o)]
            # 3.8. Get replace pairs for those outputs that change
            # the number of stored intermediate steps
            for pos, old_outs in old_outputs:
                if len(old_outs) > 0:
                    nw_pos = compress_map[pos]
                    for k, old in enumerate(old_outs):
                        # Get the correct slice; ``slices[pos]`` holds one
                        # entry per client, in client order
                        cnf_slice, old_slices = slices[pos][k]
                        if type(cnf_slice[0]) is slice:
                            start = cnf_slice[0].start - nw_steps - init_l[pos] + store_steps[pos]
                            if cnf_slice[0].stop is not None and cnf_slice[0].stop != sys.maxint:
                                stop = cnf_slice[0].stop - nw_steps - init_l[pos] + store_steps[pos]
                            else:
                                stop = None
                            nw_slice = (slice(sanitize(start), sanitize(stop), sanitize(cnf_slice[0].step)),) + tuple(
                                old_slices[1:]
                            )

                        else:
                            position = cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]

                            nw_slice = (sanitize(position),) + tuple(old_slices[1:])

                        subtens = tensor.basic.Subtensor(nw_slice)
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice, lambda entry: isinstance(entry, tensor.Variable)
                        )
                        new_o = subtens.make_node(new_outs[nw_pos], *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[:: cnf_slice[1]]
                        old_new += [(old, new_o)]

            # 3.9. Get replace pairs for all other nodes
            for idx, o in enumerate(node.outputs):
                if idx not in replaced_outs and idx not in not_required:
                    nw_pos = compress_map[idx]
                    old_new += [(o, new_outs[nw_pos])]

            env.replace_all_validate(old_new, reason="scan_save_mem")
Beispiel #2
0
    def process_node(self, fgraph, node):

        # helpful functions
        def select_min(x, y):
            if x is None:
                return y
            if y is None:
                return x
            return tensor.minimum(x, y)

        def select_max(x, y):
            if x is None:
                return y
            if y is None:
                return x
            return tensor.maximum(x, y)

        def sanitize(x):
            if x is None:
                return None
            else:
                return tensor.as_tensor_variable(x)

        if hasattr(fgraph, 'shape_feature'):
            shape_of = node.fgraph.shape_feature.shape_of
        else:
            # Each access to shape_of is in a try..except block in order to
            # use a default version when the variable is not in the shape_of
            # dictionary.
            shape_of = {}
        # 1. Initialization of variables
        # Note 1) We do not actually care about outputs representing shared
        # variables (those have no intermediate values) so it is safer to
        # ignore them and not change them in any way. To simplify the
        # optimizations I construct the variable ``c_outs`` ( that counts
        # outputs up to those we care) and the list ``init_l`` which for any
        # output we care says the length of its initial state. Note that
        # defining ``init_l`` for mit_mot sequences is a bit trickier but
        # it is safe to set it to 0
        op = node.op
        c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

        init_l = [0 for x in xrange(op.n_mit_mot)]
        init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
        init_l += [0 for x in xrange(op.n_nit_sot)]
        # 2. Check the clients of each output and see for how many steps
        # does scan need to run

        # This comparison checks if there is any uncounted output, which
        # can only be an output corresponding to a shared variable

        # 2.1 Initialize
        # global_nsteps is a dictionary having two fields ( 'real' deals
        # with int values, 'sym' with symbolic ones) or None
        # given that a scan op has k outputs o_1, .. o_k and each
        # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
        # global_nsteps is None if any of the clients is different
        # from a subtensor or its real and sym field equal to
        # max(c_i_j.idx_list[0].stop), meaning store up to which maximal
        # index(step) for any output scan actually needs to compute
        # In other words n_steps should be equal to this maximal !
        # Note: if we have a shared variable that gets updated at every step
        # of the loop, reducing the number of steps will affect the the
        # value of the shared variable after the loop so we need not to
        # change the number of steps in that case. To do this we set
        # global_nsteps to None which is seen as a flag that nothing needs
        # to be done
        assert len(node.outputs) >= c_outs
        if len(node.outputs) == c_outs:
            global_nsteps = {'real': -1, 'sym': []}
        else:
            global_nsteps = None

        # Keeps track of the original slices that each client represent
        slices = [None for o in node.outputs]

        # A list for each output indicating how many intermediate values
        # should be stored. If negative it means none of the intermediate
        # values (i.e. the output can be removed since it is not used
        # afterwards in the computations), if 0 it means that all
        # intermediate values are required, otherwise is up to that number
        # of intermediate values
        # Note that for mit_mot outputs and shared outputs we can not change
        # the number of intermediate steps stored without affecting the
        # result of the op
        store_steps = [0 for o in xrange(op.n_mit_mot)]
        store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
        # Flag that says if an input has changed and we need to do something
        # or not
        flag_store = False

        # 2.2 Loop over the clients
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            slices[i] = []
            for cl, _ in out.clients:

                # 2.1 outputs of the function
                #=> output needs all its intermediate values
                if type(cl) == str:
                    # if the node is actually an output, then
                    # we need to store the entire thing
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.2 non-subtensor nodes
                #=> output needs all its intermediate values
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.3 subtensor nodes
                #=> output might need to store just a subset of its values
                else:
                    # 2.3.1 extract idx list of subtensor
                    this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                     cl.op.idx_list)
                    if this_slice is None:
                        # if unable to extract idx_list
                        #=> outputs needs all its intermediate values
                        global_nsteps = None
                        slices[i] = None
                        break

                    # 2.3.2 extract the begin/end of the first dimension
                    if i >= op.n_mit_mot:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(
                                                    this_slice[0], length)
                    slices[i] += [(cf_slice, this_slice)]

                    if (isinstance(this_slice[0], slice) and
                        this_slice[0].stop is None):
                        global_nsteps = None
                    if isinstance(cf_slice[0], slice):
                        stop = tensor.basic.extract_constant(cf_slice[0].stop)
                    else:
                        stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                    if stop == maxsize or stop == length:
                        stop = None
                    else:
                        # there is a **gotcha** here ! Namely, scan returns an
                        # array that contains the initial state of the output
                        # as well. Which means that if have a initial state of
                        # length 3, and you look for 5 steps you get an output
                        # y of length 8. If you only use y[:5], this does not
                        # mean that you only need to loop for 5 steps but
                        # actually only for 2 steps ( the first 3 are the
                        # initial state)
                        stop = stop - init_l[i]

                    # 2.3.3 we might get away with less number of steps
                    if stop is not None and global_nsteps is not None:
                        # yes if it is a tensor
                        if isinstance(stop, tensor.Variable):
                            global_nsteps['sym'] += [stop]
                        # not if it is maxsize
                        elif (type(stop) in (int, long) and
                              stop == maxsize):
                            global_nsteps = None
                        # yes if it is a int k, 0 < k < maxsize
                        elif (type(stop) in (int, long) and
                              global_nsteps['real'] < stop):
                            global_nsteps['real'] = stop
                        # yes if it is a int k, 0 < k < maxsize
                        elif (type(stop) in (int, long) and stop > 0):
                            pass
                        # not otherwise
                        else:
                            global_nsteps = None

        # 2.3. Analyze global_nsteps to figure out for how many steps scan
        # needs to iterate
        if global_nsteps is not None:
            nw_steps = node.inputs[0]

            # there are some symbolic tensors that limit the number of
            # steps
            if len(global_nsteps['sym']) == 0:
                sym_steps = None
            else:
                sym_steps = global_nsteps['sym'][0]
                for c in global_nsteps['sym'][1:]:
                    sym_steps = tensor.maximum(sym_steps, c)

            if global_nsteps['real'] >= 0:
                real_steps = global_nsteps['real']
            else:
                real_steps = None
            nw_steps = select_min(select_max(sym_steps, real_steps),
                                  node.inputs[0])
        else:
            nw_steps = node.inputs[0]
            global_nsteps = None

        # 2.4 Loop over the clients again now looking just to see how many
        # intermediate steps to store
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            for cl, _ in out.clients:
                if type(cl) == str:
                    store_steps[i] = 0
                    break
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    store_steps[i] = 0
                    break
                else:
                    this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                         cl.op.idx_list)
                    if this_slice is None:
                        store_steps[i] = 0
                        break

                    if (isinstance(this_slice[0], slice) and
                        this_slice[0].start is None):
                        store_steps[i] = 0
                        break

                    if i > op.n_mit_mot:
                        length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(
                                                    this_slice[0], length)

                    if isinstance(cf_slice[0], slice):
                        start = tensor.basic.extract_constant(
                            cf_slice[0].start)
                    else:
                        start = tensor.basic.extract_constant(cf_slice[0])
                    if start == 0 or store_steps[i] == 0:
                        store_steps[i] = 0
                    else:
                        pval = select_max(nw_steps - start + init_l[i],
                                          init_l[i])
                        if store_steps[i] != -1:
                            pval = select_max(pval, store_steps[i])

                        store_steps[i] = pval
                        flag_store = True

        orphane_outs = [i for i, x in enumerate(store_steps)
                        if (type(x) is int) and (x < 0)]
        flag_store = flag_store or (len(orphane_outs) > 0)
        # 3. is there anything to change ?
        if (flag_store or global_nsteps is not None):
            # 3.1 initialize inputs for the new scan
            old_outputs = []
            nw_inputs = list(node.inputs)
            nw_inputs[0] = nw_steps

            # 3.2 check orphane outputs to see if we can eliminate any
            required, not_required = \
                    scan_utils.scan_can_remove_outs(node.op,
                                                    orphane_outs)
            # 3.3. compose replace pairs for those nodes that do not need
            # to store everything in memory (or that are orphans but are
            # still required by the inner function)
            replaced_outs = []
            offset = 1 + op.n_seqs + op.n_mit_mot
            for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
                i = idx + op.n_mit_mot
                if not(type(_val) is int and _val <= 0 and i not in required):

                    if idx + op.n_mit_mot in required:
                        val = 1
                    else:
                        val = _val
                    # If the memory for this output has been pre-allocated
                    # before going into the scan op (by an alloc node)
                    if idx < op.n_mit_sot + op.n_sit_sot:
                        # In case the input is still an alloc node, we
                        # actually have two options:
                        #   a) the input is a set_subtensor, in that case we
                        #      can replace the initial tensor by a slice,
                        #   b) it is not, and we simply take a slice of it.

                        #TODO: commit change below with Razvan
                        if (nw_inputs[offset + idx].owner and
                            isinstance(nw_inputs[offset + idx].owner.op,
                                       tensor.IncSubtensor) and
                            isinstance(nw_inputs[offset+idx].owner.op.idx_list[0], slice)):

                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            val = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                    tensor.maximum(val - initl, 0))
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = scan_utils.expand(_nw_input, tmp)
                        else:
                            tmp = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp = tensor.maximum(tmp, initl)
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                                             tmp)
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = nw_inputs[offset + idx][:tmp]

                        nw_inputs[offset + idx] = nw_input
                        replaced_outs.append(op.n_mit_mot + idx)
                        odx = op.n_mit_mot + idx
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
                    # If there is no memory pre-allocated for this output
                    elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:

                        pos = (op.n_mit_mot + idx + op.n_seqs +
                               1 + op.n_shared_outs)
                        if nw_inputs[pos] == node.inputs[0]:
                            nw_inputs[pos] = val
                        odx = op.n_mit_mot + idx
                        replaced_outs.append(odx)
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
            # 3.4. Recompute inputs for everything else based on the new
            # number of steps
            if global_nsteps is not None:
                for idx, val in enumerate(store_steps[op.n_mit_mot:]):
                    if val == 0:
                        if idx < op.n_mit_sot + op.n_sit_sot:
                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            odx = op.n_mit_mot + idx
                            nw_input = scan_utils.expand(_nw_input, nw_steps)
                            nw_inputs[offset + idx] = nw_input
                        elif idx < (op.n_mit_sot + op.n_sit_sot +
                                    op.n_nit_sot):
                            in_idx = offset + idx + op.n_shared_outs
                            if nw_inputs[in_idx] == node.inputs[0]:
                                nw_inputs[in_idx] = nw_steps
                            odx = op.n_mit_mot + idx

            # 3.5 Remove unwanted orphane outputs
            (inps, outs, info, node_ins, compress_map) = \
                    scan_utils.compress_outs(op, not_required, nw_inputs)
            inv_compress_map = {}
            for k, v in compress_map.items():
                inv_compress_map[v] = k

            node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in
                        node_ins]
            node_ins = pre_constant_merge(node_ins)
            # 3.6 Compose the new scan
            # I need to make sure I'm not reapplying the same optimization
            # twice since bad things usually happen if I do that
            info['_scan_savemem_visited'] = True
            new_outs = scan_op.Scan(inps,
                                    outs,
                                    info).make_node(*node_ins).outputs

            old_new = []
            # 3.7 Get replace pairs for those outputs that do not change
            # the number of intermediate steps stored
            for idx, sl in enumerate(slices):
                if global_nsteps and sl is not None and store_steps[idx] == 0:
                    for hdx, cl in enumerate(node.outputs[idx].clients):
                        cnf_slice, old_slices = sl[hdx]
                        # Sanitize the nw_slice by converting ints back into
                        # constants. This is only needed for the first slice,
                        # since it is the only one rebuilt here; the
                        # remaining slices are reused from old_slices as-is.

                        if isinstance(cnf_slice[0], slice):
                            fslice = slice(
                                sanitize(cnf_slice[0].start),
                                sanitize(cnf_slice[0].stop),
                                sanitize(cnf_slice[0].step))
                        else:
                            fslice = sanitize(cnf_slice[0])

                        nw_slice = (fslice,) + tuple(old_slices[1:])
                        nw_pos = inv_compress_map[idx]

                        subtens = tensor.basic.Subtensor(nw_slice)
                        # slice inputs
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                    tensor.Variable))
                        new_o = subtens.make_node(new_outs[nw_pos],
                                                  *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        replaced_outs.append(idx)
                        old_new += [(cl[0].outputs[0], new_o)]
            # 3.8. Get replace pairs for those outputs that change
            # the number of stored intermediate steps
            for pos, old_outs in old_outputs:
                if len(old_outs) > 0:
                    nw_pos = compress_map[pos]
                    for k, old in enumerate(old_outs):
                        # Get the correct slice
                        cnf_slice, old_slices = slices[pos][k]
                        if type(cnf_slice[0]) is slice:
                            start = (cnf_slice[0].start - nw_steps -
                                     init_l[pos] + store_steps[pos])
                            if (cnf_slice[0].stop is not None and
                                cnf_slice[0].stop != maxsize):
                                stop = (cnf_slice[0].stop - nw_steps -
                                        init_l[pos] + store_steps[pos])
                            else:
                                stop = None
                            nw_slice = ((slice(sanitize(start),
                                               sanitize(stop),
                                               sanitize(cnf_slice[0].step)),)
                                        + tuple(old_slices[1:]))

                        else:
                            position = (cnf_slice[0] - nw_steps -
                                         init_l[pos] + store_steps[pos])

                            nw_slice = (sanitize(position),) + \
                                    tuple(old_slices[1:])

                        subtens = tensor.basic.Subtensor(nw_slice)
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                     tensor.Variable))
                        new_o = subtens.make_node(new_outs[nw_pos],
                                                  *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        old_new += [(old, new_o)]

            # 3.9. Get replace pairs for all other nodes
            if flag_store or global_nsteps is not None:
                for idx, o in enumerate(node.outputs):
                    if not (idx in replaced_outs) and not idx in not_required:
                        nw_pos = compress_map[idx]
                        old_new += [(o, new_outs[nw_pos])]

                fgraph.replace_all_validate_remove(old_new,
                                                   remove=[node],
                                                   reason='scan_save_mem')