def process_node(self, env, node):
    """Shrink the number of steps and/or the stored memory of a Scan node.

    Looks at every client of the outputs of the Scan apply node ``node``:

    * if every client only reads a prefix of the outputs (through a
      ``Subtensor``), the loop may run for fewer steps than
      ``node.inputs[0]``;
    * for each output it computes how many intermediate values must be
      kept (``store_steps``) so that pre-allocated buffers can be shrunk,
      and which outputs have no client at all (candidates for removal).

    When something can be changed, a new Scan node is built and the old
    outputs (or their ``Subtensor`` clients) are replaced in ``env`` with
    ``env.replace_all_validate``.

    :param env: the graph being optimized.
    :param node: the Scan apply node to process.
    :returns: None; the graph is modified in place when the rewrite applies.
    """

    # helpful functions
    def select_min(x, y):
        # ``None`` stands for "no bound".
        if x is None:
            return y
        if y is None:
            return x
        return tensor.minimum(x, y)

    def select_max(x, y):
        # ``None`` stands for "no bound".
        if x is None:
            return y
        if y is None:
            return x
        return tensor.maximum(x, y)

    def sanitize(x):
        # Convert python ints (e.g. canonical slice bounds) back into
        # symbolic constants, leaving ``None`` untouched.
        if x is None:
            return None
        else:
            return tensor.as_tensor_variable(x)

    if hasattr(env, 'shape_feature'):
        shape_of = env.shape_feature.shape_of
    else:
        # Each access to shape_of is in a try..except KeyError block, so an
        # empty dict makes every lookup fall back to a default length.
        shape_of = {}

    # 1. Initialization of variables
    # Note 1) We do not actually care about outputs representing shared
    # variables (those have no intermediate values) so it is safer to
    # ignore them and not change them in any way. To simplify the
    # optimizations I construct the variable ``c_outs`` (that counts
    # outputs up to those we care) and the list ``init_l`` which for any
    # output we care says the length of its initial state. Note that
    # defining ``init_l`` for mit_mot sequences is a bit trickier but
    # it is safe to set it to 0
    op = node.op
    c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

    init_l = [0 for x in xrange(op.n_mit_mot)]
    init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
    init_l += [0 for x in xrange(op.n_nit_sot)]

    # 2. Check the clients of each output and see for how many steps
    # scan needs to run

    # 2.1 Initialize
    # global_nsteps is a dictionary having two fields ('real' deals
    # with int values, 'sym' with symbolic ones) or None.
    # Given that a scan op has k outputs o_1, .. o_k and each
    # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
    # global_nsteps is None if any of the clients is different
    # from a subtensor, otherwise its real and sym fields equal
    # max(c_i_j.idx_list[0].stop), i.e. up to which maximal
    # index (step) scan actually needs to compute any output.
    # In other words n_steps should be equal to this maximum.
    # Note: if we have a shared variable that gets updated at every step
    # of the loop, reducing the number of steps would change the
    # value of the shared variable after the loop, so we must not
    # change the number of steps in that case. This comparison detects
    # such uncounted (shared) outputs; global_nsteps = None is the flag
    # meaning "do not change the number of steps".
    if len(node.outputs) <= c_outs:
        global_nsteps = {"real": -1, "sym": []}
    else:
        global_nsteps = None

    # Keeps track of the original slices that each client represents
    slices = [None for o in node.outputs]

    # A list for each output indicating how many intermediate values
    # should be stored. If negative it means none of the intermediate
    # values (i.e. the output can be removed since it is not used
    # afterwards in the computations), if 0 it means that all
    # intermediate values are required, otherwise it is up to that number
    # of intermediate values.
    # Note that for mit_mot outputs and shared outputs we can not change
    # the number of intermediate steps stored without affecting the
    # result of the op
    store_steps = [0 for o in xrange(op.n_mit_mot)]
    store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
    # Flag that says if an input has changed and we need to do something
    # or not
    flag_store = False

    # 2.2 Loop over the clients
    for i, out in enumerate(node.outputs[:c_outs]):
        # look at all its clients
        slices[i] = []
        for cl, _ in out.clients:

            # 2.2.1 outputs of the function
            # => output needs all its intermediate values
            if isinstance(cl, str):
                # if the node is actually an output, then
                # we need to store the entire thing
                global_nsteps = None
                slices[i] = None
                break
            # 2.2.2 non-subtensor nodes
            # => output needs all its intermediate values
            elif not isinstance(cl.op, tensor.basic.Subtensor):
                global_nsteps = None
                slices[i] = None
                break
            # 2.2.3 subtensor nodes
            # => output might need to store just a subset of its values
            else:
                # 2.3.1 extract idx list of subtensor
                this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                       cl.op.idx_list)
                if this_slice is None:
                    # if unable to extract idx_list
                    # => outputs needs all its intermediate values
                    global_nsteps = None
                    slices[i] = None
                    break

                # 2.3.2 extract the begin/end of the first dimension
                if i >= op.n_mit_mot:
                    try:
                        length = shape_of[out][0]
                    except KeyError:
                        length = node.inputs[0] + init_l[i]
                else:
                    try:
                        length = shape_of[out][0]
                    except KeyError:
                        length = out.shape[0]
                cf_slice = tensor.basic.get_canonical_form_slice(
                    this_slice[0], length)
                slices[i] += [(cf_slice, this_slice)]

                if (isinstance(this_slice[0], slice) and
                        this_slice[0].stop is None):
                    # Reads up to the end: every step is needed. Do NOT
                    # break here: ``slices[i]`` must hold one entry per
                    # client, because step 3.8 indexes it per client.
                    global_nsteps = None
                if isinstance(cf_slice[0], slice):
                    stop = tensor.basic.extract_constant(cf_slice[0].stop)
                else:
                    stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                if stop == sys.maxint or stop == length:
                    stop = None
                else:
                    # there is a **gotcha** here ! Namely, scan returns an
                    # array that contains the initial state of the output
                    # as well. Which means that if we have an initial state
                    # of length 3, and you look for 5 steps you get an
                    # output y of length 8. If you only use y[:5], this
                    # does not mean that you only need to loop for 5 steps
                    # but actually only for 2 steps (the first 3 are the
                    # initial state)
                    stop = stop - init_l[i]

                # 2.3.3 we might get away with less number of steps
                if stop is not None and global_nsteps is not None:
                    # yes if it is a tensor
                    if isinstance(stop, tensor.Variable):
                        global_nsteps["sym"] += [stop]
                    # not if it is maxint
                    elif (type(stop) in (int, long) and
                          stop == sys.maxint):
                        global_nsteps = None
                    # yes if it is an int k, 0 < k < maxint
                    elif (type(stop) in (int, long) and
                          global_nsteps["real"] < stop):
                        global_nsteps["real"] = stop
                    # yes if it is an int k, 0 < k < maxint
                    elif type(stop) in (int, long) and stop > 0:
                        pass
                    # not otherwise
                    else:
                        global_nsteps = None

    # 2.3. Analyze global_nsteps to figure out for how many steps scan
    # needs to iterate
    if global_nsteps is not None:
        # there are some symbolic tensors that limit the number of
        # steps
        if len(global_nsteps["sym"]) == 0:
            sym_steps = None
        else:
            sym_steps = global_nsteps["sym"][0]
            for c in global_nsteps["sym"][1:]:
                sym_steps = tensor.maximum(sym_steps, c)

        if global_nsteps["real"] >= 0:
            real_steps = global_nsteps["real"]
        else:
            real_steps = None
        nw_steps = select_min(select_max(sym_steps, real_steps),
                              node.inputs[0])
    else:
        nw_steps = node.inputs[0]

    # 2.4 Loop over the clients again now looking just to see how many
    # intermediate steps to store
    for i, out in enumerate(node.outputs[:c_outs]):
        # look at all its clients
        for cl, _ in out.clients:
            if isinstance(cl, str):
                store_steps[i] = 0
                break
            elif not isinstance(cl.op, tensor.basic.Subtensor):
                store_steps[i] = 0
                break
            else:
                this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                       cl.op.idx_list)
                if this_slice is None:
                    store_steps[i] = 0
                    break

                if (isinstance(this_slice[0], slice) and
                        this_slice[0].start is None):
                    store_steps[i] = 0
                    break

                if i > op.n_mit_mot:
                    length = node.inputs[0] + init_l[i]
                else:
                    try:
                        length = shape_of[out][0]
                    except KeyError:
                        length = out.shape[0]
                cf_slice = tensor.basic.get_canonical_form_slice(
                    this_slice[0], length)

                if isinstance(cf_slice[0], slice):
                    start = tensor.basic.extract_constant(cf_slice[0].start)
                else:
                    start = tensor.basic.extract_constant(cf_slice[0])
                if start == 0 or store_steps[i] == 0:
                    store_steps[i] = 0
                else:
                    pval = select_max(nw_steps - start + init_l[i],
                                      init_l[i])
                    if store_steps[i] != -1:
                        pval = select_max(pval, store_steps[i])

                    store_steps[i] = pval
                    flag_store = True

    orphane_outs = [i for i, x in enumerate(store_steps)
                    if (type(x) is int) and (x < 0)]
    flag_store = flag_store or (len(orphane_outs) > 0)

    # 3. is there anything to change ?
    if not (flag_store or global_nsteps is not None):
        return

    # 3.1 initialize inputs for the new scan
    old_outputs = []
    nw_inputs = list(node.inputs)
    nw_inputs[0] = nw_steps

    # 3.2 check orphane outputs to see if we can eliminate any
    required, not_required = scan_utils.scan_can_remove_outs(node.op,
                                                             orphane_outs)

    # 3.3. compose replace pairs for those nodes that need not
    # store everything in memory (or are orphaned but required
    # by the inner function)
    replaced_outs = []
    offset = 1 + op.n_seqs + op.n_mit_mot
    for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
        i = idx + op.n_mit_mot
        if type(_val) is int and _val <= 0 and i not in required:
            # every value is needed (0), or a removable orphan (< 0)
            continue

        if i in required:
            val = 1
        else:
            val = _val

        if idx < op.n_mit_sot + op.n_sit_sot:
            # The memory for this output has been pre-allocated
            # before going into the scan op (by an alloc node).
            # Two options:
            #   a) the input is a set_subtensor over a slice, in that case
            #      we can rebuild it around the initial tensor,
            #   b) it is not, and we simply take a slice of it.
            # In both cases the buffer must keep at least ``init_l[i]``
            # entries, since it also holds the initial state (see the
            # gotcha in 2.3.2); ``val`` can be smaller (e.g. 1 for a
            # required orphan output).
            buf = nw_inputs[offset + idx]
            initl = tensor.as_tensor_variable(init_l[i])
            if (buf.owner and
                    isinstance(buf.owner.op, tensor.IncSubtensor) and
                    isinstance(buf.owner.op.idx_list[0], slice)):
                _nw_input = buf.owner.inputs[1]
                tmp = tensor.maximum(tensor.as_tensor_variable(val) - initl,
                                     0)
                tmp = pre_greedy_local_optimizer(list_opt_slice, tmp)
                tmp = pre_constant_merge([tmp])[0]
                nw_input = scan_utils.expand(_nw_input, tmp)
            else:
                tmp = tensor.maximum(tensor.as_tensor_variable(val), initl)
                tmp = pre_greedy_local_optimizer(list_opt_slice, tmp)
                tmp = pre_constant_merge([tmp])[0]
                nw_input = buf[:tmp]

            nw_inputs[offset + idx] = nw_input
            replaced_outs.append(i)
            old_outputs += [(i, [x[0].outputs[0]
                                 for x in node.outputs[i].clients])]
        # If there is no memory pre-allocated for this output
        elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
            pos = op.n_mit_mot + idx + op.n_seqs + 1 + op.n_shared_outs
            if nw_inputs[pos] == node.inputs[0]:
                nw_inputs[pos] = val
            replaced_outs.append(i)
            old_outputs += [(i, [x[0].outputs[0]
                                 for x in node.outputs[i].clients])]

    # 3.4. Recompute inputs for everything else based on the new
    # number of steps
    if global_nsteps is not None:
        for idx, val in enumerate(store_steps[op.n_mit_mot:]):
            if val == 0:
                if idx < op.n_mit_sot + op.n_sit_sot:
                    in_idx = offset + idx
                    buf = nw_inputs[in_idx]
                    if (buf.owner and
                            isinstance(buf.owner.op, tensor.IncSubtensor) and
                            isinstance(buf.owner.op.idx_list[0], slice)):
                        _nw_input = buf.owner.inputs[1]
                        nw_inputs[in_idx] = scan_utils.expand(_nw_input,
                                                              nw_steps)
                    else:
                        # Not a set_subtensor buffer: ``owner`` may be None
                        # and ``owner.inputs[1]`` would not be the initial
                        # state. Keep the initial state plus ``nw_steps``.
                        initl = init_l[op.n_mit_mot + idx]
                        nw_inputs[in_idx] = buf[:initl + nw_steps]
                elif idx < (op.n_mit_sot + op.n_sit_sot + op.n_nit_sot):
                    in_idx = offset + idx + op.n_shared_outs
                    if nw_inputs[in_idx] == node.inputs[0]:
                        nw_inputs[in_idx] = nw_steps

    # 3.5 Remove unwanted orphane outputs.
    # ``compress_map`` maps an old output index to its new index.
    (inps, outs, info, node_ins, compress_map) = \
        scan_utils.compress_outs(op, not_required, nw_inputs)

    node_ins = [pre_greedy_local_optimizer(list_opt_slice, x)
                for x in node_ins]
    node_ins = pre_constant_merge(node_ins)

    # 3.6 Compose the new scan
    # I need to make sure I'm not reapplying the same optimization
    # twice since bad things usually happen if I do that
    info["_scan_merge_visited"] = True
    new_outs = scan_op.Scan(inps, outs, info).make_node(*node_ins).outputs

    old_new = []
    # 3.7 Get replace pairs for those outputs that do not change
    # the number of intermediate steps stored
    for idx, sl in enumerate(slices):
        if (global_nsteps is not None and sl is not None and
                store_steps[idx] == 0):
            # ``idx`` is an OLD output index: map old -> new, as in 3.8/3.9.
            nw_pos = compress_map[idx]
            for hdx, cl in enumerate(node.outputs[idx].clients):
                cnf_slice, old_slices = sl[hdx]
                # Sanitize the nw_slice by converting ints back into
                # constants. Only the first slice needs this, since it is
                # the only one that was put in canonical form.
                if isinstance(cnf_slice[0], slice):
                    fslice = slice(sanitize(cnf_slice[0].start),
                                   sanitize(cnf_slice[0].stop),
                                   sanitize(cnf_slice[0].step))
                else:
                    fslice = sanitize(cnf_slice[0])

                nw_slice = (fslice,) + tuple(old_slices[1:])
                subtens = tensor.basic.Subtensor(nw_slice)
                # slice inputs
                sl_ins = tensor.basic.Subtensor.collapse(
                    nw_slice,
                    lambda entry: isinstance(entry, tensor.Variable))
                new_o = subtens.make_node(new_outs[nw_pos],
                                          *sl_ins).outputs[0]
                if new_o.ndim > 0:
                    new_o = new_o[::cnf_slice[1]]
                replaced_outs.append(idx)
                old_new += [(cl[0].outputs[0], new_o)]

    # 3.8. Get replace pairs for those outputs that change
    # the number of stored intermediate steps
    for pos, old_outs in old_outputs:
        if len(old_outs) > 0:
            nw_pos = compress_map[pos]
            for k, old in enumerate(old_outs):
                # Get the correct slice
                cnf_slice, old_slices = slices[pos][k]
                if type(cnf_slice[0]) is slice:
                    start = (cnf_slice[0].start - nw_steps -
                             init_l[pos] + store_steps[pos])
                    if (cnf_slice[0].stop is not None and
                            cnf_slice[0].stop != sys.maxint):
                        stop = (cnf_slice[0].stop - nw_steps -
                                init_l[pos] + store_steps[pos])
                    else:
                        stop = None
                    nw_slice = ((slice(sanitize(start),
                                       sanitize(stop),
                                       sanitize(cnf_slice[0].step)),) +
                                tuple(old_slices[1:]))
                else:
                    position = (cnf_slice[0] - nw_steps -
                                init_l[pos] + store_steps[pos])
                    nw_slice = (sanitize(position),) + tuple(old_slices[1:])

                subtens = tensor.basic.Subtensor(nw_slice)
                sl_ins = tensor.basic.Subtensor.collapse(
                    nw_slice,
                    lambda entry: isinstance(entry, tensor.Variable))
                new_o = subtens.make_node(new_outs[nw_pos],
                                          *sl_ins).outputs[0]
                if new_o.ndim > 0:
                    new_o = new_o[::cnf_slice[1]]
                old_new += [(old, new_o)]

    # 3.9. Get replace pairs for all other nodes
    for idx, o in enumerate(node.outputs):
        if idx not in replaced_outs and idx not in not_required:
            nw_pos = compress_map[idx]
            old_new += [(o, new_outs[nw_pos])]

    env.replace_all_validate(old_new, reason="scan_save_mem")
def process_node(self, fgraph, node):
    """Shrink the number of steps and/or the stored memory of a Scan node.

    Looks at every client of the outputs of the Scan apply node ``node``:

    * if every client only reads a prefix of the outputs (through a
      ``Subtensor``), the loop may run for fewer steps than
      ``node.inputs[0]``;
    * for each output it computes how many intermediate values must be
      kept (``store_steps``) so that pre-allocated buffers can be shrunk,
      and which outputs have no client at all (candidates for removal).

    When something can be changed, a new Scan node is built and the old
    outputs (or their ``Subtensor`` clients) are replaced in ``fgraph``
    with ``fgraph.replace_all_validate_remove``, which also removes
    ``node``.

    :param fgraph: the function graph being optimized.
    :param node: the Scan apply node to process.
    :returns: None; the graph is modified in place when the rewrite applies.
    """

    # helpful functions
    def select_min(x, y):
        # ``None`` stands for "no bound".
        if x is None:
            return y
        if y is None:
            return x
        return tensor.minimum(x, y)

    def select_max(x, y):
        # ``None`` stands for "no bound".
        if x is None:
            return y
        if y is None:
            return x
        return tensor.maximum(x, y)

    def sanitize(x):
        # Convert python ints (e.g. canonical slice bounds) back into
        # symbolic constants, leaving ``None`` untouched.
        if x is None:
            return None
        else:
            return tensor.as_tensor_variable(x)

    if hasattr(fgraph, 'shape_feature'):
        shape_of = fgraph.shape_feature.shape_of
    else:
        # Each access to shape_of is in a try..except block in order to
        # use a default version when the variable is not in the shape_of
        # dictionary.
        shape_of = {}

    # 1. Initialization of variables
    # Note 1) We do not actually care about outputs representing shared
    # variables (those have no intermediate values) so it is safer to
    # ignore them and not change them in any way. To simplify the
    # optimizations I construct the variable ``c_outs`` (that counts
    # outputs up to those we care) and the list ``init_l`` which for any
    # output we care says the length of its initial state. Note that
    # defining ``init_l`` for mit_mot sequences is a bit trickier but
    # it is safe to set it to 0
    op = node.op
    c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

    init_l = [0 for x in xrange(op.n_mit_mot)]
    init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
    init_l += [0 for x in xrange(op.n_nit_sot)]

    # 2. Check the clients of each output and see for how many steps
    # scan needs to run

    # 2.1 Initialize
    # global_nsteps is a dictionary having two fields ('real' deals
    # with int values, 'sym' with symbolic ones) or None.
    # Given that a scan op has k outputs o_1, .. o_k and each
    # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
    # global_nsteps is None if any of the clients is different
    # from a subtensor, otherwise its real and sym fields equal
    # max(c_i_j.idx_list[0].stop), i.e. up to which maximal
    # index (step) scan actually needs to compute any output.
    # In other words n_steps should be equal to this maximum.
    # Note: if we have a shared variable that gets updated at every step
    # of the loop, reducing the number of steps would change the
    # value of the shared variable after the loop, so we must not
    # change the number of steps in that case. Any output beyond
    # ``c_outs`` can only correspond to a shared variable; if there is
    # one, global_nsteps = None flags "do not change the number of steps".
    assert len(node.outputs) >= c_outs
    if len(node.outputs) == c_outs:
        global_nsteps = {'real': -1, 'sym': []}
    else:
        global_nsteps = None

    # Keeps track of the original slices that each client represents
    slices = [None for o in node.outputs]

    # A list for each output indicating how many intermediate values
    # should be stored. If negative it means none of the intermediate
    # values (i.e. the output can be removed since it is not used
    # afterwards in the computations), if 0 it means that all
    # intermediate values are required, otherwise it is up to that number
    # of intermediate values.
    # Note that for mit_mot outputs and shared outputs we can not change
    # the number of intermediate steps stored without affecting the
    # result of the op
    store_steps = [0 for o in xrange(op.n_mit_mot)]
    store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
    # Flag that says if an input has changed and we need to do something
    # or not
    flag_store = False

    # 2.2 Loop over the clients
    for i, out in enumerate(node.outputs[:c_outs]):
        # look at all its clients
        slices[i] = []
        for cl, _ in out.clients:

            # 2.2.1 outputs of the function
            # => output needs all its intermediate values
            if isinstance(cl, str):
                # if the node is actually an output, then
                # we need to store the entire thing
                global_nsteps = None
                slices[i] = None
                break
            # 2.2.2 non-subtensor nodes
            # => output needs all its intermediate values
            elif not isinstance(cl.op, tensor.basic.Subtensor):
                global_nsteps = None
                slices[i] = None
                break
            # 2.2.3 subtensor nodes
            # => output might need to store just a subset of its values
            else:
                # 2.3.1 extract idx list of subtensor
                this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                       cl.op.idx_list)
                if this_slice is None:
                    # if unable to extract idx_list
                    # => outputs needs all its intermediate values
                    global_nsteps = None
                    slices[i] = None
                    break

                # 2.3.2 extract the begin/end of the first dimension
                if i >= op.n_mit_mot:
                    try:
                        length = shape_of[out][0]
                    except KeyError:
                        length = node.inputs[0] + init_l[i]
                else:
                    try:
                        length = shape_of[out][0]
                    except KeyError:
                        length = out.shape[0]
                cf_slice = tensor.basic.get_canonical_form_slice(
                    this_slice[0], length)
                slices[i] += [(cf_slice, this_slice)]

                if (isinstance(this_slice[0], slice) and
                        this_slice[0].stop is None):
                    # Reads up to the end: every step is needed. No
                    # ``break``: ``slices[i]`` must keep one entry per
                    # client, because step 3.8 indexes it per client.
                    global_nsteps = None
                if isinstance(cf_slice[0], slice):
                    stop = tensor.basic.extract_constant(cf_slice[0].stop)
                else:
                    stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                if stop == maxsize or stop == length:
                    stop = None
                else:
                    # there is a **gotcha** here ! Namely, scan returns an
                    # array that contains the initial state of the output
                    # as well. Which means that if we have an initial state
                    # of length 3, and you look for 5 steps you get an
                    # output y of length 8. If you only use y[:5], this
                    # does not mean that you only need to loop for 5 steps
                    # but actually only for 2 steps (the first 3 are the
                    # initial state)
                    stop = stop - init_l[i]

                # 2.3.3 we might get away with less number of steps
                if stop is not None and global_nsteps is not None:
                    # yes if it is a tensor
                    if isinstance(stop, tensor.Variable):
                        global_nsteps['sym'] += [stop]
                    # not if it is maxsize
                    elif (type(stop) in (int, long) and
                          stop == maxsize):
                        global_nsteps = None
                    # yes if it is an int k, 0 < k < maxsize
                    elif (type(stop) in (int, long) and
                          global_nsteps['real'] < stop):
                        global_nsteps['real'] = stop
                    # yes if it is an int k, 0 < k < maxsize
                    elif type(stop) in (int, long) and stop > 0:
                        pass
                    # not otherwise
                    else:
                        global_nsteps = None

    # 2.3. Analyze global_nsteps to figure out for how many steps scan
    # needs to iterate
    if global_nsteps is not None:
        # there are some symbolic tensors that limit the number of
        # steps
        if len(global_nsteps['sym']) == 0:
            sym_steps = None
        else:
            sym_steps = global_nsteps['sym'][0]
            for c in global_nsteps['sym'][1:]:
                sym_steps = tensor.maximum(sym_steps, c)

        if global_nsteps['real'] >= 0:
            real_steps = global_nsteps['real']
        else:
            real_steps = None
        nw_steps = select_min(select_max(sym_steps, real_steps),
                              node.inputs[0])
    else:
        nw_steps = node.inputs[0]

    # 2.4 Loop over the clients again now looking just to see how many
    # intermediate steps to store
    for i, out in enumerate(node.outputs[:c_outs]):
        # look at all its clients
        for cl, _ in out.clients:
            if isinstance(cl, str):
                store_steps[i] = 0
                break
            elif not isinstance(cl.op, tensor.basic.Subtensor):
                store_steps[i] = 0
                break
            else:
                this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                       cl.op.idx_list)
                if this_slice is None:
                    store_steps[i] = 0
                    break

                if (isinstance(this_slice[0], slice) and
                        this_slice[0].start is None):
                    store_steps[i] = 0
                    break

                if i > op.n_mit_mot:
                    length = node.inputs[0] + init_l[i]
                else:
                    try:
                        length = shape_of[out][0]
                    except KeyError:
                        length = out.shape[0]
                cf_slice = tensor.basic.get_canonical_form_slice(
                    this_slice[0], length)

                if isinstance(cf_slice[0], slice):
                    start = tensor.basic.extract_constant(cf_slice[0].start)
                else:
                    start = tensor.basic.extract_constant(cf_slice[0])
                if start == 0 or store_steps[i] == 0:
                    store_steps[i] = 0
                else:
                    pval = select_max(nw_steps - start + init_l[i],
                                      init_l[i])
                    if store_steps[i] != -1:
                        pval = select_max(pval, store_steps[i])

                    store_steps[i] = pval
                    flag_store = True

    orphane_outs = [i for i, x in enumerate(store_steps)
                    if (type(x) is int) and (x < 0)]
    flag_store = flag_store or (len(orphane_outs) > 0)

    # 3. is there anything to change ?
    if not (flag_store or global_nsteps is not None):
        return

    # 3.1 initialize inputs for the new scan
    old_outputs = []
    nw_inputs = list(node.inputs)
    nw_inputs[0] = nw_steps

    # 3.2 check orphane outputs to see if we can eliminate any
    required, not_required = scan_utils.scan_can_remove_outs(node.op,
                                                             orphane_outs)

    # 3.3. compose replace pairs for those nodes that need not
    # store everything in memory (or are orphaned but required
    # by the inner function)
    replaced_outs = []
    offset = 1 + op.n_seqs + op.n_mit_mot
    for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
        i = idx + op.n_mit_mot
        if type(_val) is int and _val <= 0 and i not in required:
            # every value is needed (0), or a removable orphan (< 0)
            continue

        if i in required:
            val = 1
        else:
            val = _val

        if idx < op.n_mit_sot + op.n_sit_sot:
            # The memory for this output has been pre-allocated
            # before going into the scan op (by an alloc node).
            # Two options:
            #   a) the input is a set_subtensor over a slice, in that case
            #      we can rebuild it around the initial tensor,
            #   b) it is not, and we simply take a slice of it.
            # In both cases the buffer keeps at least ``init_l[i]``
            # entries, since it also holds the initial state.
            # TODO: commit change below with Razvan
            buf = nw_inputs[offset + idx]
            initl = tensor.as_tensor_variable(init_l[i])
            if (buf.owner and
                    isinstance(buf.owner.op, tensor.IncSubtensor) and
                    isinstance(buf.owner.op.idx_list[0], slice)):
                _nw_input = buf.owner.inputs[1]
                tmp = tensor.maximum(tensor.as_tensor_variable(val) - initl,
                                     0)
                tmp = pre_greedy_local_optimizer(list_opt_slice, tmp)
                tmp = pre_constant_merge([tmp])[0]
                nw_input = scan_utils.expand(_nw_input, tmp)
            else:
                tmp = tensor.maximum(tensor.as_tensor_variable(val), initl)
                tmp = pre_greedy_local_optimizer(list_opt_slice, tmp)
                tmp = pre_constant_merge([tmp])[0]
                nw_input = buf[:tmp]

            nw_inputs[offset + idx] = nw_input
            replaced_outs.append(i)
            old_outputs += [(i, [x[0].outputs[0]
                                 for x in node.outputs[i].clients])]
        # If there is no memory pre-allocated for this output
        elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
            pos = op.n_mit_mot + idx + op.n_seqs + 1 + op.n_shared_outs
            if nw_inputs[pos] == node.inputs[0]:
                nw_inputs[pos] = val
            replaced_outs.append(i)
            old_outputs += [(i, [x[0].outputs[0]
                                 for x in node.outputs[i].clients])]

    # 3.4. Recompute inputs for everything else based on the new
    # number of steps
    if global_nsteps is not None:
        for idx, val in enumerate(store_steps[op.n_mit_mot:]):
            if val == 0:
                if idx < op.n_mit_sot + op.n_sit_sot:
                    in_idx = offset + idx
                    buf = nw_inputs[in_idx]
                    if (buf.owner and
                            isinstance(buf.owner.op, tensor.IncSubtensor) and
                            isinstance(buf.owner.op.idx_list[0], slice)):
                        _nw_input = buf.owner.inputs[1]
                        nw_inputs[in_idx] = scan_utils.expand(_nw_input,
                                                              nw_steps)
                    else:
                        # Not a set_subtensor buffer: ``owner`` may be None
                        # and ``owner.inputs[1]`` would not be the initial
                        # state. Keep the initial state plus ``nw_steps``.
                        initl = init_l[op.n_mit_mot + idx]
                        nw_inputs[in_idx] = buf[:initl + nw_steps]
                elif idx < (op.n_mit_sot + op.n_sit_sot + op.n_nit_sot):
                    in_idx = offset + idx + op.n_shared_outs
                    if nw_inputs[in_idx] == node.inputs[0]:
                        nw_inputs[in_idx] = nw_steps

    # 3.5 Remove unwanted orphane outputs.
    # ``compress_map`` maps an old output index to its new index.
    (inps, outs, info, node_ins, compress_map) = \
        scan_utils.compress_outs(op, not_required, nw_inputs)

    node_ins = [pre_greedy_local_optimizer(list_opt_slice, x)
                for x in node_ins]
    node_ins = pre_constant_merge(node_ins)

    # 3.6 Compose the new scan
    # I need to make sure I'm not reapplying the same optimization
    # twice since bad things usually happen if I do that
    info['_scan_savemem_visited'] = True
    new_outs = scan_op.Scan(inps, outs, info).make_node(*node_ins).outputs

    old_new = []
    # 3.7 Get replace pairs for those outputs that do not change
    # the number of intermediate steps stored
    for idx, sl in enumerate(slices):
        if (global_nsteps is not None and sl is not None and
                store_steps[idx] == 0):
            # ``idx`` is an OLD output index: map old -> new, as in 3.8/3.9.
            nw_pos = compress_map[idx]
            for hdx, cl in enumerate(node.outputs[idx].clients):
                cnf_slice, old_slices = sl[hdx]
                # Sanitize the nw_slice by converting ints back into
                # constants. Only the first slice needs this, since it is
                # the only one that was put in canonical form.
                if isinstance(cnf_slice[0], slice):
                    fslice = slice(sanitize(cnf_slice[0].start),
                                   sanitize(cnf_slice[0].stop),
                                   sanitize(cnf_slice[0].step))
                else:
                    fslice = sanitize(cnf_slice[0])

                nw_slice = (fslice,) + tuple(old_slices[1:])
                subtens = tensor.basic.Subtensor(nw_slice)
                # slice inputs
                sl_ins = tensor.basic.Subtensor.collapse(
                    nw_slice,
                    lambda entry: isinstance(entry, tensor.Variable))
                new_o = subtens.make_node(new_outs[nw_pos],
                                          *sl_ins).outputs[0]
                if new_o.ndim > 0:
                    new_o = new_o[::cnf_slice[1]]
                replaced_outs.append(idx)
                old_new += [(cl[0].outputs[0], new_o)]

    # 3.8. Get replace pairs for those outputs that change
    # the number of stored intermediate steps
    for pos, old_outs in old_outputs:
        if len(old_outs) > 0:
            nw_pos = compress_map[pos]
            for k, old in enumerate(old_outs):
                # Get the correct slice
                cnf_slice, old_slices = slices[pos][k]
                if type(cnf_slice[0]) is slice:
                    start = (cnf_slice[0].start - nw_steps -
                             init_l[pos] + store_steps[pos])
                    if (cnf_slice[0].stop is not None and
                            cnf_slice[0].stop != maxsize):
                        stop = (cnf_slice[0].stop - nw_steps -
                                init_l[pos] + store_steps[pos])
                    else:
                        stop = None
                    nw_slice = ((slice(sanitize(start),
                                       sanitize(stop),
                                       sanitize(cnf_slice[0].step)),) +
                                tuple(old_slices[1:]))
                else:
                    position = (cnf_slice[0] - nw_steps -
                                init_l[pos] + store_steps[pos])
                    nw_slice = (sanitize(position),) + tuple(old_slices[1:])

                subtens = tensor.basic.Subtensor(nw_slice)
                sl_ins = tensor.basic.Subtensor.collapse(
                    nw_slice,
                    lambda entry: isinstance(entry, tensor.Variable))
                new_o = subtens.make_node(new_outs[nw_pos],
                                          *sl_ins).outputs[0]
                if new_o.ndim > 0:
                    new_o = new_o[::cnf_slice[1]]
                old_new += [(old, new_o)]

    # 3.9. Get replace pairs for all other nodes
    for idx, o in enumerate(node.outputs):
        if idx not in replaced_outs and idx not in not_required:
            nw_pos = compress_map[idx]
            old_new += [(o, new_outs[nw_pos])]

    fgraph.replace_all_validate_remove(old_new,
                                       remove=[node],
                                       reason='scan_save_mem')