def local_transform(node):
    """Map ``wrapped_replacer`` over the outputs of ``node``.

    For ``Scan`` and ``OpFromGraph`` nodes the inner graph is rewritten via
    ``_map_variables_inner``, a fresh op of the same kind is built around the
    rewritten inner graph, and the outputs of a new apply node of that op are
    returned. For any other node, ``wrapped_replacer`` is applied to each
    output and the results are returned as a list.

    Returns ``False`` for nodes already recorded in ``nodes_seen``. Otherwise
    the handled node (or, for Scan/OpFromGraph, its replacement node) is
    added to ``nodes_seen``.

    NOTE(review): ``nodes_seen``, ``wrapped_replacer`` and
    ``_map_variables_inner`` are not defined in this function; presumably
    they come from the enclosing scope.
    """
    if node in nodes_seen:
        return False

    # Scan cannot be imported at module scope: that import would be circular.
    from theano.scan_module.scan_op import Scan
    from theano.compile import OpFromGraph

    op = node.op
    if not isinstance(op, (Scan, OpFromGraph)):
        nodes_seen.add(node)
        return [wrapped_replacer(out) for out in node.outputs]

    # Rewrite the op's inner graph first.
    inner_ins, outer_ins, inner_outs = _map_variables_inner(
        wrapped_replacer,
        inner_inputs=op.inputs,
        outer_inputs=node.inputs,
        inner_outputs=op.outputs,
        containing_op=op)

    # Rebuild an op of the same kind around the rewritten inner graph.
    if isinstance(op, Scan):
        # FIXME: infer typeConstructor someday?
        rebuilt_op = Scan(inner_ins, inner_outs, op.info,
                          typeConstructor=None)
    else:  # must be OpFromGraph
        rebuilt_op = OpFromGraph(inner_ins, inner_outs, **op.kwargs)

    # Apply the rebuilt op to the outer inputs returned by the inner rewrite.
    rebuilt_node = rebuilt_op.make_node(*outer_ins)
    nodes_seen.add(rebuilt_node)
    return rebuilt_node.outputs
def local_transform(node):
    """Map ``wrapped_replacer`` over the outputs of ``node``.

    For ``Scan`` and ``OpFromGraph`` nodes the inner graph is rewritten via
    ``_map_variables_inner``, a fresh op of the same kind is built around the
    rewritten inner graph, and the outputs of a new apply node of that op are
    returned. For any other node, ``wrapped_replacer`` is applied to each
    output and the results are returned as a list.

    Returns ``False`` for nodes already recorded in ``nodes_seen``. Otherwise
    the handled node (or, for Scan/OpFromGraph, its replacement node) is
    added to ``nodes_seen``.

    NOTE(review): ``nodes_seen``, ``wrapped_replacer`` and
    ``_map_variables_inner`` are not defined in this function; presumably
    they come from the enclosing scope.
    """
    if node in nodes_seen:
        return False

    # Scan cannot be imported at module scope: that import would be circular.
    from theano.scan_module.scan_op import Scan
    from theano.compile import OpFromGraph

    op = node.op
    if not isinstance(op, (Scan, OpFromGraph)):
        nodes_seen.add(node)
        return [wrapped_replacer(out) for out in node.outputs]

    # Rewrite the op's inner graph first.
    inner_ins, outer_ins, inner_outs = _map_variables_inner(
        wrapped_replacer,
        inner_inputs=op.inputs,
        outer_inputs=node.inputs,
        inner_outputs=op.outputs,
        containing_op=op,
    )

    # Rebuild an op of the same kind around the rewritten inner graph.
    if isinstance(op, Scan):
        rebuilt_op = Scan(
            inner_ins,
            inner_outs,
            op.info,
            # FIXME: infer this someday?
            typeConstructor=None,
        )
    else:  # must be OpFromGraph
        rebuilt_op = OpFromGraph(inner_ins, inner_outs, **op.kwargs)

    # Apply the rebuilt op to the outer inputs returned by the inner rewrite.
    rebuilt_node = rebuilt_op.make_node(*outer_ins)
    nodes_seen.add(rebuilt_node)
    return rebuilt_node.outputs
def construct_scan(scan_args):
    """Instantiate a Scan op from ``scan_args`` and apply it.

    ``scan_args`` must provide ``inner_inputs``, ``inner_outputs``, ``info``
    and ``outer_inputs``. The result of applying the op is always returned
    as a list; a non-list result is wrapped in a one-element list.
    """
    outputs = Scan(
        scan_args.inner_inputs,
        scan_args.inner_outputs,
        scan_args.info,
    )(*scan_args.outer_inputs)
    return outputs if isinstance(outputs, list) else [outputs]
def export(node, extra_inner_outputs):
    """Expose extra inner-graph variables of a Scan node as outer outputs.

    Builds a new Scan op whose inner graph is that of ``node`` with
    ``extra_inner_outputs`` appended as additional nit-sot outputs. It
    applies that op to ``node``'s outer inputs (plus one n-steps input per
    new nit-sot) and returns the outer outputs corresponding to
    ``extra_inner_outputs``, in the same order.

    Parameters
    ----------
    node : Apply
        An apply node whose op is a ``Scan``.
    extra_inner_outputs : iterable of Variable
        Variables of ``node.op``'s inner graph to export.

    Returns
    -------
    list
        One outer output per element of ``extra_inner_outputs``; empty if
        there is nothing to export.
    """
    assert isinstance(node.op, Scan)
    extra_inner_outputs = list(extra_inner_outputs)
    if not extra_inner_outputs:
        # nothing to export; don't build a redundant Scan op
        return []
    n_extra = len(extra_inner_outputs)
    op = node.op

    # this is ugly but we can't use scan_utils.scan_args because that
    # clones the inner graph and then extra_inner_outputs aren't in
    # there anymore
    old_inner_outputs = op.outputs
    new_inner_inputs = list(op.inputs)
    new_inner_outputs = list(old_inner_outputs)
    new_outer_inputs = list(node.inputs)
    new_info = copy.deepcopy(op.info)
    n_shared_outs = new_info["n_shared_outs"]
    old_n_nit_sot = new_info["n_nit_sot"]

    # Inner outputs are laid out as [mit_mot taps..., mit_sot..., sit_sot...,
    # nit_sot..., shared..., (condition, if as_while)]. Append the new
    # outputs at the end of the nit_sot group and update info.
    n_cond = 1 if new_info.get("as_while", False) else 0
    inner_pos = len(old_inner_outputs) - n_shared_outs - n_cond
    new_inner_outputs[inner_pos:inner_pos] = extra_inner_outputs
    new_info["n_nit_sot"] = old_n_nit_sot + n_extra

    # in step 8, theano.scan() adds an outer input (being the actual
    # number of steps) for each nitsot. we need to do the same thing.
    # note these don't come with corresponding inner inputs.
    # Outer inputs are laid out as [n_steps, seqs..., mit_mot..., mit_sot...,
    # sit_sot..., shared..., nit_sot sizes..., non_seqs...]; insert after
    # the existing nit_sot sizes so they line up with the new inner outputs.
    outer_pos = (1 + op.n_seqs + op.n_mit_mot + op.n_mit_sot +
                 op.n_sit_sot + op.n_shared_outs + old_n_nit_sot)
    # the size is just the actual number of steps, which is always
    # available as the first outer input.
    new_outer_inputs[outer_pos:outer_pos] = [new_outer_inputs[0]] * n_extra

    new_op = Scan(new_inner_inputs, new_inner_outputs, new_info)
    outer_outputs = new_op(*new_outer_inputs)
    # a single-output op may return a bare variable rather than a list
    if not isinstance(outer_outputs, list):
        outer_outputs = [outer_outputs]

    # Outer outputs are laid out as [mit_mot..., mit_sot..., sit_sot...,
    # nit_sot..., shared...] with ONE entry per mit_mot (inner outputs have
    # one per tap), so the position must come from the outer layout.
    out_pos = len(node.outputs) - n_shared_outs
    return outer_outputs[out_pos:out_pos + n_extra]
def get_population_outputs(batch_outputs, popstats):
    """Rewrite ``batch_outputs`` to use population statistics.

    Every ancestor of ``batch_outputs`` whose tag has ``bn_statistic`` is
    replaced by the ``popstats`` value whose key has the same
    ``tag.original_id``. The replacement is broadcast-patterned like the
    original. Scan nodes are descended into: their inner graphs are
    rewritten recursively. Each non-scalar popstat is fed in as a
    non-sequence and indexed by the current step.

    Parameters
    ----------
    batch_outputs : list of Variable
        Outputs of a graph that uses batch statistics.
    popstats : dict
        Maps batch-statistic variables to population-statistic variables.

    Returns
    -------
    Clone of ``batch_outputs`` (via ``scan_utils.clone``) with the
    replacements applied.

    Raises
    ------
    KeyError
        If a tagged batch statistic has no matching entry in ``popstats``.
    """
    import logging
    logger = logging.getLogger(__name__)

    replacements = []
    # Track Apply nodes rather than ops: Scan ops compare by value, so two
    # distinct Scan nodes can share an "equal" op, and both need rewriting.
    # Each node is still processed once even though all its outputs show up.
    visited_scan_nodes = set()

    for var in theano.gof.graph.ancestors(batch_outputs):
        if hasattr(var.tag, "bn_statistic"):
            # can't rely on object identity because scan_args clones; use original_id
            matches = (popstat for batchstat, popstat in popstats.items()
                       if batchstat.tag.original_id == var.tag.original_id)
            try:
                popstat = next(matches)
            except StopIteration:
                raise KeyError("no population statistic for batch statistic %s"
                               % (var,))
            replacements.append(
                (var, T.patternbroadcast(popstat, var.broadcastable)))

        # descend into Scan; root variables have no owner node
        node = var.owner
        if node is None or not isinstance(node.op, Scan):
            continue

        # this would cause multiple replacements for this variable
        assert not hasattr(var.tag, "bn_statistic")

        if node in visited_scan_nodes:
            continue
        visited_scan_nodes.add(node)
        logger.debug("descending into %s", var)

        new_outer_outputs = _population_scan_outputs(node, popstats)
        # there is a one-to-one correspondence between old outer outputs
        # and new outer outputs (only inputs were added); replace one-to-one
        replacements.extend(equizip(node.outputs, new_outer_outputs))

    logger.debug("replacements %s", replacements)
    return scan_utils.clone(batch_outputs, replace=replacements)


def _population_scan_outputs(node, popstats):
    """Rebuild Scan ``node`` with population statistics; return its outer outputs."""
    sa = scan_utils.scan_args(outer_inputs=node.inputs,
                              outer_outputs=node.outputs,
                              _inner_inputs=node.op.inputs,
                              _inner_outputs=node.op.outputs,
                              info=node.op.info)

    # add subscript as sequence
    # TODO check if this integer input drops the scan to cpu, if so use
    # float and cast back in subtensor expression
    indices = T.arange(sa.n_steps)
    index = scan_utils.safe_new(indices[0])
    sa.outer_in_seqs.append(indices)
    sa.inner_in_seqs.append(index)

    # add popstats as nonsequences (because they may be shorter than len(indices))
    inner_popstats = {}
    for batchstat, outer_popstat in popstats.items():
        # this can't be subscripted hence won't appear in the inner graph
        if outer_popstat.ndim == 0:
            continue

        inner_popstat = scan_utils.safe_new(outer_popstat)
        sa.outer_in_non_seqs.append(outer_popstat)
        sa.inner_in_non_seqs.append(inner_popstat)

        # past the end of the popstat, fall back to its last entry
        inner_popstats[batchstat] = theano.ifelse.ifelse(
            index < inner_popstat.shape[0],
            inner_popstat[index],
            inner_popstat[-1])

    # recurse on inner graph
    new_inner_outputs = get_population_outputs(sa.inner_outputs, inner_popstats)

    # construct new scan node
    new_op = Scan(sa.inner_inputs, new_inner_outputs, sa.info)
    new_outer_outputs = new_op(*sa.outer_inputs)
    # a single-output op may return a bare variable rather than a list
    if not isinstance(new_outer_outputs, list):
        new_outer_outputs = [new_outer_outputs]
    return new_outer_outputs