def merge(self, nodes):
    """Fuse the Scan nodes in ``nodes`` into one single Scan node.

    Each Scan argument category (sequences, mit-mot, mit-sot, sit-sot,
    shared, nit-sot, non-sequences) is gathered across *all* nodes before
    moving on to the next category, so the fused op keeps the usual Scan
    argument layout. Named inner and outer variables are renamed in place
    by appending the index of the node they come from.

    Settings that cannot be summed (truncate_gradient, mode, profile and
    the while-condition) are taken from ``nodes[0]``.

    :param nodes: Scan apply nodes to fuse.
    :returns: ``zip`` of the original outer outputs with the fused ones.
    """
    head = nodes[0].op
    as_while = bool(head.as_while)
    # A `while` scan carries its stopping condition as the last inner
    # output; only the first node's condition is carried over.
    condition = head.outputs[-1] if as_while else None

    def total(attr):
        # Sum one integer Scan attribute over every node being merged.
        return sum(getattr(nd.op, attr) for nd in nodes)

    tap_array = []
    out_slices = []
    info = {
        "tap_array": tap_array,
        "n_seqs": total("n_seqs"),
        "n_mit_mot": total("n_mit_mot"),
        "n_mit_mot_outs": total("n_mit_mot_outs"),
        "mit_mot_out_slices": out_slices,
        "n_mit_sot": total("n_mit_sot"),
        "n_sit_sot": total("n_sit_sot"),
        "n_shared_outs": total("n_shared_outs"),
        "n_nit_sot": total("n_nit_sot"),
        "truncate_gradient": head.truncate_gradient,
        "name": "&".join(nd.op.name for nd in nodes),
        "mode": head.mode,
        "inplace": False,
        "gpu": False,
        "as_while": as_while,
        "profile": head.profile,
    }

    inner_ins = []
    outer_ins = []
    inner_outs = []
    outer_outs = []

    def suffixed(variables, tag):
        # In place: append `tag` to the name of every named variable.
        for var in variables:
            if var.name:
                var.name += str(tag)
        return variables

    # Sequences
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_seqs(), pos))
        outer_ins.extend(suffixed(nd.op.outer_seqs(nd), pos))

    # Mit-mot states
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_mitmot(), pos))
        inner_outs.extend(nd.op.inner_mitmot_outs())
        tap_array.extend(nd.op.mitmot_taps())
        out_slices.extend(nd.op.mitmot_out_taps())
        outer_ins.extend(suffixed(nd.op.outer_mitmot(nd), pos))
        outer_outs.extend(nd.op.outer_mitmot_outs(nd))

    # Mit-sot states
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_mitsot(), pos))
        inner_outs.extend(nd.op.inner_mitsot_outs())
        tap_array.extend(nd.op.mitsot_taps())
        outer_ins.extend(suffixed(nd.op.outer_mitsot(nd), pos))
        outer_outs.extend(nd.op.outer_mitsot_outs(nd))

    # Sit-sot states: each one has the single tap -1 (fresh list per state).
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_sitsot(), pos))
        tap_array.extend([[-1] for _ in xrange(nd.op.n_sit_sot)])
        inner_outs.extend(nd.op.inner_sitsot_outs())
        outer_ins.extend(suffixed(nd.op.outer_sitsot(nd), pos))
        outer_outs.extend(nd.op.outer_sitsot_outs(nd))

    # Shared variables (inputs)
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_shared(), pos))
        outer_ins.extend(suffixed(nd.op.outer_shared(nd), pos))

    # Nit-sot outputs
    for pos, nd in enumerate(nodes):
        inner_outs.extend(nd.op.inner_nitsot_outs())
        outer_ins.extend(suffixed(nd.op.outer_nitsot(nd), pos))
        outer_outs.extend(nd.op.outer_nitsot_outs(nd))

    # Shared variables (outputs)
    for nd in nodes:
        outer_outs.extend(nd.op.outer_shared_outs(nd))
        inner_outs.extend(nd.op.inner_shared_outs())

    # Non-sequences
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_non_seqs(), pos))
        outer_ins.extend(suffixed(nd.op.outer_non_seqs(nd), pos))

    # The number of steps always comes first among the outer inputs.
    outer_ins.insert(0, nodes[0].inputs[0])

    if as_while:
        inner_outs.append(condition)

    inner_ins, inner_outs = scan_utils.reconstruct_graph(inner_ins, inner_outs)
    fused_op = scan_op.Scan(inner_ins, inner_outs, info)
    fused_outs = fused_op(*outer_ins)
    if not isinstance(fused_outs, (list, tuple)):
        fused_outs = [fused_outs]
    return zip(outer_outs, fused_outs)
def remove_constants_and_unused_inputs_scan(node):
    """
    Move constants into the inner graph, and remove unused inputs.

    Constants that are in the outer graph are represented by a free symbolic
    variable in the inner graph. If we move them into the inner graph,
    constant-folding can happen in the inner graph.
    This is applied only on sequences and non-sequences,
    not on initial states.

    A constant sequence is folded only when its ``tag.unique_value`` is set
    (all entries equal). Its first entry then replaces the per-step inner
    variable. If that extraction fails, the sequence is kept as an ordinary
    input (when the inner graph uses it) instead of being silently dropped.

    :param node: an Apply node; nodes whose op is not a ``scan_op.Scan`` are
        ignored.
    :returns: the outputs of a new, simplified Scan node, or ``False`` when
        ``node`` is not a Scan or no input could be removed.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False
    op = node.op
    # We only need to take care of sequences and other arguments.
    # `st` becomes the index of the first non-sequence among the inner
    # inputs: sequences, every mit-mot / mit-sot tap, sit-sot, shared.
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in op.tap_array[: (op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)

    # Corresponds to the initial states, which should stay untouched.
    # We put those variables aside, and put them back at the end.
    out_stuff_inner = op_ins[op.n_seqs : st]

    non_seqs = op_ins[st:]
    # Index of the first non-sequence among the outer inputs; the extra 1
    # accounts for the number of steps, which is the first outer input.
    st = op.n_seqs + op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot + op.n_shared_outs + 1
    outer_non_seqs = node.inputs[st:]
    out_stuff_outer = node.inputs[1 + op.n_seqs : st]

    # To replace constants in the outer graph by clones in the inner graph
    givens = {}
    # All the inputs of the inner graph of the new scan
    nw_inner = []
    # Same for the outer graph, initialized w/ number of steps
    nw_outer = [node.inputs[0]]

    all_ins = gof.graph.inputs(op_outs)
    for idx in xrange(op.n_seqs):
        outer_seq = node.inputs[idx + 1]
        inner_seq = op_ins[idx]
        if isinstance(outer_seq, tensor.TensorConstant) and outer_seq.tag.unique_value is not None:
            try:
                # This works if input is a constant that has all entries
                # equal; a TypeError means it could not be folded.
                tensor.get_constant_value(outer_seq)
                givens[inner_seq] = outer_seq.clone()[0]
                continue
            except TypeError:
                # Folding failed: fall through and treat the constant as a
                # regular sequence. Skipping it would leave `inner_seq` as a
                # free variable of the new inner graph.
                pass
        if inner_seq in all_ins:
            nw_inner += [inner_seq]
            nw_outer += [outer_seq]
    nw_n_seqs = len(nw_inner)

    # Add outputs stuff
    nw_inner += out_stuff_inner
    nw_outer += out_stuff_outer

    # Look through non sequences
    for nw_in, nw_out in zip(non_seqs, outer_non_seqs):
        if isinstance(nw_out, tensor.Constant):
            givens[nw_in] = nw_out.clone()
        elif nw_in in all_ins:
            nw_inner += [nw_in]
            nw_outer += [nw_out]

    if len(nw_inner) != len(op_ins):
        op_outs = scan_utils.clone(op_outs, replace=givens)
        # Shallow copy is enough: only the top-level "n_seqs" entry changes.
        nw_info = op.info.copy()
        nw_info["n_seqs"] = nw_n_seqs
        nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
        nw_outs = nwScan.make_node(*nw_outer).outputs
        return nw_outs
    else:
        return False
def process_node(self, env, node):
    """Push computations that depend only on non-sequences out of a Scan.

    Repeatedly walks the inner graph of ``node`` and marks every inner node
    whose inputs are all non-sequences, constants or outputs of already
    marked nodes (ViewOp / DeepCopyOp nodes are never marked). Each marked
    node is rebuilt in the outer graph on the matching outer inputs, and a
    placeholder (``scan_utils.safe_new``) is created for each of its outputs.

    Afterwards exactly one of three things happens:

    * some hoisted values are still used inside the loop: ``node`` is
      replaced in ``env`` by a new Scan that receives them as extra inputs
      appended at the end of the inner and outer inputs, and ``True`` is
      returned;
    * nothing remains in the inner graph (``to_keep`` is empty): each
      hoisted inner output is replaced by ``tensor.alloc`` of the hoisted
      value along a new leading axis of length ``node.inputs[0]`` (the
      number of steps). This branch returns ``None``;
    * otherwise ``False`` is returned and nothing is changed.

    :param env: the graph owning ``node``; its ``replace_all_validate``
        performs the substitution.
    :param node: an apply node whose op is a Scan (not checked here).
    :raises Exception: if a marked node has an input that cannot be mapped
        to the outer graph, or if the loop reaches ``max_iterations``.
    """
    # this flag tells if there was any change during the last iterations
    changed = True
    clean_inputs, clean_outputs = scan_utils.reconstruct_graph(node.op.inputs, node.op.outputs)

    local_env = gof.Env(clean_inputs, clean_outputs)
    # Upper bound on fixed-point passes over the inner graph.
    max_iterations = 2 * len(local_env.toposort()) + 3
    counts = 0

    # Inner nodes already hoisted out of the loop.
    to_remove = []
    # Parallel lists: inner variable to replace, its inner placeholder, and
    # the equivalent variable rebuilt in the outer graph.
    to_replace = []
    replace_with_in = []
    replace_with_out = []

    op = node.op
    # Construct the list of non_sequences to simplify a few things
    # (index of the first inner non-sequence: sequences, mit-mot / mit-sot
    # taps, sit-sot and shared states come before it).
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in op.tap_array[: (op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    non_seqs = clean_inputs[st:]
    # Index of the first outer non-sequence (+1 for the number of steps).
    st = op.n_seqs + op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot + op.n_shared_outs + 1
    outer_non_seqs = node.inputs[st:]
    assert len(non_seqs) == len(outer_non_seqs)
    while changed and counts < max_iterations:
        counts += 1
        changed = False
        for nd in local_env.toposort():
            if (
                numpy.all(
                    [(x in non_seqs) or (x.owner in to_remove) or isinstance(x, tensor.Constant) for x in nd.inputs]
                )
                and
                # we can do this because the assumption is that a
                # viewOp or deepCopyOp will be just at the end of the
                # function and not somewhere in the middle ..
                not isinstance(nd.op, theano.compile.ViewOp)
                and not isinstance(nd.op, theano.compile.DeepCopyOp)
                and
                # and we haven't already looked at this node
                not nd in to_remove
            ):

                # We have a candidate node that can be removed
                # Step 1. Reconstruct it on outside
                to_remove.append(nd)
                outside_ins = []
                for x in nd.inputs:
                    if x in non_seqs:
                        outside_ins += [outer_non_seqs[non_seqs.index(x)]]
                    elif x in to_replace:
                        # Output of a node hoisted earlier: use its outer copy.
                        outside_ins += [replace_with_out[to_replace.index(x)]]
                    elif isinstance(x, theano.Constant):
                        outside_ins += [x.clone()]
                    else:
                        raise Exception(
                            (
                                "Error in the `scan_pushout_non_seq_"
                                "operations`. The optimization tries "
                                "to move some computation fron scan "
                                "which is not allowed to move. Report "
                                "this on theano-users list"
                            ),
                            x,
                        )
                nw_outer_node = nd.op.make_node(*outside_ins)
                # Step 2. Create variables for replacements
                for idx, y in enumerate(nd.outputs):

                    y_place_holder = scan_utils.safe_new(y, "_replace")
                    to_replace += [y]
                    replace_with_in += [y_place_holder]
                    assert type(y) == type(nw_outer_node.outputs[idx])
                    replace_with_out += [nw_outer_node.outputs[idx]]
                changed = True

    # NOTE(review): this only looks at `counts`, so it also raises if the
    # final allowed pass happened to make no change — confirm intended.
    if counts >= max_iterations:
        raise Exception(
            "Error in the `scan_pushout_non_seq_operations`."
            " The optimization exhausted the maximal number "
            "of iterations allowed!"
        )
    # We need to check all candidate replacements and choose those that
    # make sense for us

    # Step 1. which elements of `to_replace` are used by remaining
    # components of the inner function
    clean_to_replace = []
    clean_replace_with_in = []
    clean_replace_with_out = []
    existent_nodes = [nd for nd in local_env.toposort() if nd not in to_remove]
    # Every variable still consumed by a node that stays inside the loop.
    to_keep = []
    for nd in existent_nodes:
        to_keep += nd.inputs
    for idx, out in enumerate(to_replace):
        if out in to_keep and out.owner not in existent_nodes:
            clean_to_replace += [out]
            clean_replace_with_in += [replace_with_in[idx]]
            clean_replace_with_out += [replace_with_out[idx]]

    if len(clean_to_replace) > 0:
        # We can finally put an end to all this madness
        givens = {}
        nw_outer = []
        nw_inner = []
        for to_repl, repl_in, repl_out in zip(clean_to_replace, clean_replace_with_in, clean_replace_with_out):
            if isinstance(repl_out, theano.Constant):
                # Constants are inlined in the inner graph instead of
                # being passed as a new input.
                repl_in = repl_out.clone()
            else:
                nw_inner += [repl_in]
                nw_outer += [repl_out]
            givens[to_repl] = repl_in

        _op_outs = scan_utils.clone(clean_outputs, replace=givens)
        _op_ins = clean_inputs + nw_inner
        op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
        # Reconstruct node
        nwScan = scan_op.Scan(op_ins, op_outs, op.info)
        nw_node = nwScan.make_node(*(node.inputs + nw_outer))
        env.replace_all_validate(zip(node.outputs, nw_node.outputs), reason="scan_push_computation_out")
        return True
    elif to_keep == []:
        # Nothing in the inner graph should be kept
        replace_with = {}
        for idx, out in enumerate(to_replace):
            if out in local_env.outputs:
                # NOTE(review): inner and outer outputs are matched by
                # position; this may not hold for mit-mot (n_mit_mot_outs
                # can differ from n_mit_mot) or for state outputs whose
                # outer buffer shape differs — confirm.
                x = node.outputs[local_env.outputs.index(out)]
                y = replace_with_out[idx]
                # The comprehension re-uses `idx`; under Python 2 this
                # rebinds it, which is harmless since it is not read again
                # before the next iteration.
                shape = [y.shape[idx] for idx in xrange(y.ndim)]
                replace_with[x] = tensor.alloc(y, node.inputs[0], *shape)

        # We need to add one extra dimension to the outputs
        env.replace_all_validate(replace_with.items(), reason="scan_push_computation_out")
    else:
        return False
def merge(self, nodes):
    """Fuse the Scan nodes in ``nodes`` into one single Scan node.

    Each Scan argument category (sequences, mit-mot, mit-sot, sit-sot,
    shared, nit-sot, non-sequences) is gathered across *all* nodes before
    moving on to the next category, so the fused op keeps the usual Scan
    argument layout. Named inner and outer variables are renamed in place
    by appending the index of the node they come from.

    Settings that cannot be summed (truncate_gradient, mode, profile and
    the while-condition) are taken from ``nodes[0]``.

    :param nodes: Scan apply nodes to fuse.
    :returns: ``zip`` of the original outer outputs with the fused ones.
    """
    head = nodes[0].op
    as_while = bool(head.as_while)
    # A `while` scan carries its stopping condition as the last inner
    # output; only the first node's condition is carried over.
    condition = head.outputs[-1] if as_while else None

    def total(attr):
        # Sum one integer Scan attribute over every node being merged.
        return sum(getattr(nd.op, attr) for nd in nodes)

    tap_array = []
    out_slices = []
    info = {
        'tap_array': tap_array,
        'n_seqs': total('n_seqs'),
        'n_mit_mot': total('n_mit_mot'),
        'n_mit_mot_outs': total('n_mit_mot_outs'),
        'mit_mot_out_slices': out_slices,
        'n_mit_sot': total('n_mit_sot'),
        'n_sit_sot': total('n_sit_sot'),
        'n_shared_outs': total('n_shared_outs'),
        'n_nit_sot': total('n_nit_sot'),
        'truncate_gradient': head.truncate_gradient,
        'name': '&'.join(nd.op.name for nd in nodes),
        'mode': head.mode,
        'gpu': False,
        'as_while': as_while,
        'profile': head.profile,
    }

    inner_ins = []
    outer_ins = []
    inner_outs = []
    outer_outs = []

    def suffixed(variables, tag):
        # In place: append `tag` to the name of every named variable.
        for var in variables:
            if var.name:
                var.name += str(tag)
        return variables

    # Sequences
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_seqs(nd.op.inputs), pos))
        outer_ins.extend(suffixed(nd.op.outer_seqs(nd.inputs), pos))

    # Mit-mot states
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_mitmot(nd.op.inputs), pos))
        inner_outs.extend(nd.op.inner_mitmot_outs(nd.op.outputs))
        tap_array.extend(nd.op.mitmot_taps())
        out_slices.extend(nd.op.mitmot_out_taps())
        outer_ins.extend(suffixed(nd.op.outer_mitmot(nd.inputs), pos))
        outer_outs.extend(nd.op.outer_mitmot_outs(nd.outputs))

    # Mit-sot states
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_mitsot(nd.op.inputs), pos))
        inner_outs.extend(nd.op.inner_mitsot_outs(nd.op.outputs))
        tap_array.extend(nd.op.mitsot_taps())
        outer_ins.extend(suffixed(nd.op.outer_mitsot(nd.inputs), pos))
        outer_outs.extend(nd.op.outer_mitsot_outs(nd.outputs))

    # Sit-sot states: each one has the single tap -1 (fresh list per state).
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_sitsot(nd.op.inputs), pos))
        tap_array.extend([[-1] for _ in xrange(nd.op.n_sit_sot)])
        inner_outs.extend(nd.op.inner_sitsot_outs(nd.op.outputs))
        outer_ins.extend(suffixed(nd.op.outer_sitsot(nd.inputs), pos))
        outer_outs.extend(nd.op.outer_sitsot_outs(nd.outputs))

    # Shared variables (inputs)
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_shared(nd.op.inputs), pos))
        outer_ins.extend(suffixed(nd.op.outer_shared(nd.inputs), pos))

    # Nit-sot outputs
    for pos, nd in enumerate(nodes):
        inner_outs.extend(nd.op.inner_nitsot_outs(nd.op.outputs))
        outer_ins.extend(suffixed(nd.op.outer_nitsot(nd.inputs), pos))
        outer_outs.extend(nd.op.outer_nitsot_outs(nd.outputs))

    # Shared variables (outputs)
    for nd in nodes:
        outer_outs.extend(nd.op.outer_shared_outs(nd.outputs))
        inner_outs.extend(nd.op.inner_shared_outs(nd.op.outputs))

    # Non-sequences
    for pos, nd in enumerate(nodes):
        inner_ins.extend(suffixed(nd.op.inner_non_seqs(nd.op.inputs), pos))
        outer_ins.extend(suffixed(nd.op.outer_non_seqs(nd.inputs), pos))

    # The number of steps always comes first among the outer inputs.
    outer_ins.insert(0, nodes[0].inputs[0])

    if as_while:
        inner_outs.append(condition)

    inner_ins, inner_outs = scan_utils.reconstruct_graph(inner_ins, inner_outs)
    fused_op = scan_op.Scan(inner_ins, inner_outs, info)
    fused_outs = fused_op(*outer_ins)
    if not isinstance(fused_outs, (list, tuple)):
        fused_outs = [fused_outs]
    return zip(outer_outs, fused_outs)
def remove_constants_and_unused_inputs_scan(node):
    '''
    Move constants into the inner graph, and remove unused inputs.

    Constants that are in the outer graph are represented by a free symbolic
    variable in the inner graph. If we move them into the inner graph,
    constant-folding can happen in the inner graph.
    This is applied only on sequences and non-sequences,
    not on initial states.

    A constant sequence is folded only when its ``tag.unique_value`` is set
    (all entries equal). Its first entry then replaces the per-step inner
    variable. If that extraction fails, the sequence is kept as an ordinary
    input (when the inner graph uses it) instead of being silently dropped.

    :param node: an Apply node; nodes whose op is not a ``scan_op.Scan`` are
        ignored.
    :returns: the outputs of a new, simplified Scan node, or ``False`` when
        ``node`` is not a Scan or no input could be removed.
    '''
    if not isinstance(node.op, scan_op.Scan):
        return False
    op = node.op
    # We only need to take care of sequences and other arguments.
    # `st` becomes the index of the first non-sequence among the inner
    # inputs: sequences, every mit-mot / mit-sot tap, sit-sot, shared.
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in
                         op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)

    # Corresponds to the initial states, which should stay untouched.
    # We put those variables aside, and put them back at the end.
    out_stuff_inner = op_ins[op.n_seqs:st]

    non_seqs = op_ins[st:]
    # Index of the first non-sequence among the outer inputs; the extra 1
    # accounts for the number of steps, which is the first outer input.
    st = (op.n_seqs +
          op.n_mit_mot +
          op.n_mit_sot +
          op.n_sit_sot +
          op.n_nit_sot +
          op.n_shared_outs + 1)
    outer_non_seqs = node.inputs[st:]
    out_stuff_outer = node.inputs[1 + op.n_seqs:st]

    # To replace constants in the outer graph by clones in the inner graph
    givens = {}
    # All the inputs of the inner graph of the new scan
    nw_inner = []
    # Same for the outer graph, initialized w/ number of steps
    nw_outer = [node.inputs[0]]

    all_ins = gof.graph.inputs(op_outs)
    for idx in xrange(op.n_seqs):
        outer_seq = node.inputs[idx + 1]
        inner_seq = op_ins[idx]
        if (isinstance(outer_seq, tensor.TensorConstant) and
                outer_seq.tag.unique_value is not None):
            try:
                # This works if input is a constant that has all entries
                # equal; a TypeError means it could not be folded.
                givens[inner_seq] = outer_seq.clone()[0]
                continue
            except TypeError:
                # Folding failed: fall through and treat the constant as a
                # regular sequence. Skipping it would leave `inner_seq` as a
                # free variable of the new inner graph.
                pass
        if inner_seq in all_ins:
            nw_inner += [inner_seq]
            nw_outer += [outer_seq]
    nw_n_seqs = len(nw_inner)

    # Add outputs stuff
    nw_inner += out_stuff_inner
    nw_outer += out_stuff_outer

    # Look through non sequences
    for nw_in, nw_out in zip(non_seqs, outer_non_seqs):
        if isinstance(nw_out, tensor.Constant):
            givens[nw_in] = nw_out.clone()
        elif nw_in in all_ins:
            nw_inner += [nw_in]
            nw_outer += [nw_out]

    if len(nw_inner) != len(op_ins):
        op_outs = scan_utils.clone(op_outs, replace=givens)
        nw_info = copy.deepcopy(op.info)
        nw_info['n_seqs'] = nw_n_seqs
        nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
        nw_outs = nwScan.make_node(*nw_outer).outputs
        return nw_outs
    else:
        return False
def process_node(self, fgraph, node):
    """Push computations that depend only on non-sequences out of a Scan.

    Repeatedly walks the inner graph of ``node`` and marks every inner node
    whose inputs are all non-sequences, constants or outputs of already
    marked nodes (ViewOp / DeepCopyOp nodes are never marked). Each marked
    node is rebuilt in the outer graph on the matching outer inputs, with
    each outer input first passed through the inner input's
    ``type.filter_variable``. A placeholder (``scan_utils.safe_new``) is
    created for each of its outputs.

    Afterwards exactly one of three things happens:

    * some hoisted values are still used inside the loop: ``node`` is
      replaced in ``fgraph`` (and removed) by a new Scan that receives them
      as extra inputs appended at the end of the inner and outer inputs,
      and ``True`` is returned;
    * nothing remains in the inner graph (``to_keep`` is empty): each
      hoisted inner output is replaced by ``tensor.alloc`` of the hoisted
      value along a new leading axis of length ``node.inputs[0]`` (the
      number of steps). This branch returns ``None``;
    * otherwise ``False`` is returned and nothing is changed.

    :param fgraph: the FunctionGraph owning ``node``; its
        ``replace_all_validate_remove`` performs the substitution.
    :param node: an apply node whose op is a Scan (not checked here).
    :raises Exception: if a marked node has an input that cannot be mapped
        to the outer graph, or if the loop reaches ``max_iterations``.
    """
    # this flag tells if there was any change during the last iterations
    changed = True
    clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
        node.op.inputs, node.op.outputs)

    local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs)
    # Upper bound on fixed-point passes over the inner graph.
    max_iterations = 2 * len(local_fgraph.toposort()) + 3
    counts = 0

    # Inner nodes already hoisted out of the loop.
    to_remove = []
    # Parallel lists: inner variable to replace, its inner placeholder, and
    # the equivalent variable rebuilt in the outer graph.
    to_replace = []
    replace_with_in = []
    replace_with_out = []

    op = node.op
    # Construct the list of non_sequences to simplify a few things
    # (index of the first inner non-sequence: sequences, mit-mot / mit-sot
    # taps, sit-sot and shared states come before it).
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in
                         op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    non_seqs = clean_inputs[st:]
    # Index of the first outer non-sequence (+1 for the number of steps).
    st = (op.n_seqs +
          op.n_mit_mot +
          op.n_mit_sot +
          op.n_sit_sot +
          op.n_nit_sot +
          op.n_shared_outs + 1)
    outer_non_seqs = node.inputs[st:]
    assert len(non_seqs) == len(outer_non_seqs)
    while changed and counts < max_iterations:
        counts += 1
        changed = False
        for nd in local_fgraph.toposort():
            if (numpy.all([(x in non_seqs) or
                           (x.owner in to_remove) or
                           isinstance(x, tensor.Constant)
                           for x in nd.inputs]) and
                # we can do this because the assumption is that a
                # viewOp or deepCopyOp will be just at the end of the
                # function and not somewhere in the middle ..
                not isinstance(nd.op, theano.compile.ViewOp) and
                not isinstance(nd.op, theano.compile.DeepCopyOp) and
                # and we haven't already looked at this node
                not nd in to_remove):

                # We have a candidate node that can be removed
                # Step 1. Reconstruct it on outside
                to_remove.append(nd)
                outside_ins = []
                for x in nd.inputs:
                    if x in non_seqs:
                        outside_ins += [outer_non_seqs[non_seqs.index(x)]]
                    elif x in to_replace:
                        # Output of a node hoisted earlier: use its outer copy.
                        outside_ins += [
                            replace_with_out[to_replace.index(x)]]
                    elif isinstance(x, theano.Constant):
                        outside_ins += [x.clone()]
                    else:
                        raise Exception(
                            ('Error in the `scan_pushout_non_seq_'
                             'operations`. The optimization tries '
                             'to move some computation fron scan '
                             'which is not allowed to move. Report '
                             'this on theano-users list'), x)
                # Make each outer input acceptable to the type of the inner
                # input it replaces before rebuilding the node outside.
                outside_ins = [x.type.filter_variable(y) for x,y in
                               zip(nd.inputs, outside_ins)]
                nw_outer_node = nd.op.make_node(*outside_ins)
                # Step 2. Create variables for replacements
                for idx, y in enumerate(nd.outputs):

                    y_place_holder = scan_utils.safe_new(y, '_replace')
                    to_replace += [y]
                    replace_with_in += [y_place_holder]
                    assert type(y) == type(nw_outer_node.outputs[idx])
                    replace_with_out += [nw_outer_node.outputs[idx]]
                changed = True

    # NOTE(review): this only looks at `counts`, so it also raises if the
    # final allowed pass happened to make no change — confirm intended.
    if counts >= max_iterations:
        raise Exception('Error in the `scan_pushout_non_seq_operations`.'
                        ' The optimization exhausted the maximal number '
                        'of iterations allowed!')
    # We need to check all candidate replacements and choose those that
    # make sense for us

    # Step 1. which elements of `to_replace` are used by remaining
    # components of the inner function
    clean_to_replace = []
    clean_replace_with_in = []
    clean_replace_with_out = []
    existent_nodes = [nd for nd in local_fgraph.toposort()
                      if nd not in to_remove]
    # Every variable still consumed by a node that stays inside the loop.
    to_keep = []
    for nd in existent_nodes:
        to_keep += nd.inputs
    for idx, out in enumerate(to_replace):
        if out in to_keep and out.owner not in existent_nodes:
            clean_to_replace += [out]
            clean_replace_with_in += [replace_with_in[idx]]
            clean_replace_with_out += [replace_with_out[idx]]

    if len(clean_to_replace) > 0:
        # We can finally put an end to all this madness
        givens = {}
        nw_outer = []
        nw_inner = []
        for to_repl, repl_in, repl_out in zip(clean_to_replace,
                                              clean_replace_with_in,
                                              clean_replace_with_out):
            if isinstance(repl_out, theano.Constant):
                # Constants are inlined in the inner graph instead of
                # being passed as a new input.
                repl_in = repl_out.clone()
            else:
                nw_inner += [repl_in]
                nw_outer += [repl_out]
            givens[to_repl] = repl_in

        _op_outs = scan_utils.clone(clean_outputs,
                                    replace=givens)
        _op_ins = clean_inputs + nw_inner
        op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
        # Reconstruct node
        nwScan = scan_op.Scan(op_ins, op_outs, op.info)
        nw_node = nwScan.make_node(* (node.inputs + nw_outer))
        fgraph.replace_all_validate_remove(
            zip(node.outputs, nw_node.outputs),
            remove=[node],
            reason='scan_push_computation_out')
        return True
    elif to_keep == []:
        # Nothing in the inner graph should be kept
        replace_with = {}
        for idx, out in enumerate(to_replace):
            if out in local_fgraph.outputs:
                # NOTE(review): inner and outer outputs are matched by
                # position; this may not hold for mit-mot (n_mit_mot_outs
                # can differ from n_mit_mot) or for state outputs whose
                # outer buffer shape differs — confirm.
                x = node.outputs[local_fgraph.outputs.index(out)]
                y = replace_with_out[idx]
                # The comprehension re-uses `idx`; under Python 2 this
                # rebinds it, which is harmless since it is not read again
                # before the next iteration.
                shape = [y.shape[idx] for idx in xrange(y.ndim)]
                replace_with[x] = tensor.alloc(y,
                                               node.inputs[0],
                                               *shape)

        # We need to add one extra dimension to the outputs
        if replace_with:
            fgraph.replace_all_validate_remove(
                replace_with.items(),
                remove=[node],
                reason='scan_push_computation_out')

    else:
        return False