def local_transform(node):
        """Apply ``wrapped_replacer`` to ``node``, descending into inner graphs.

        Returns ``False`` if ``node`` is already in ``nodes_seen``.  For a
        node whose op is a ``Scan`` or ``OpFromGraph``, the inner graph is
        mapped through ``_map_variables_inner``, an op of the same kind is
        rebuilt from the mapped inner graph, and the outputs of a new node
        made from that op are returned (the new node is recorded in
        ``nodes_seen``).  For every other node, the node itself is recorded
        and a list with ``wrapped_replacer`` applied to each output is
        returned.
        """
        if node in nodes_seen:
            return False

        # importing Scan into module scope would be circular
        from theano.scan_module.scan_op import Scan
        from theano.compile import OpFromGraph

        op = node.op

        # Plain node: no inner graph to recurse into.
        if not isinstance(op, (Scan, OpFromGraph)):
            nodes_seen.add(node)
            return [wrapped_replacer(output) for output in node.outputs]

        # Recurse on the inner graph of the Scan / OpFromGraph.
        mapped = _map_variables_inner(
            wrapped_replacer,
            inner_inputs=op.inputs,
            outer_inputs=node.inputs,
            inner_outputs=op.outputs,
            containing_op=op)
        inner_ins, outer_ins, inner_outs = mapped

        # Reinstantiate an op of the same kind around the mapped inner graph.
        # The op is known to be a Scan or an OpFromGraph at this point.
        if isinstance(op, Scan):
            # FIXME: infer typeConstructor someday?
            rebuilt_op = Scan(inner_ins, inner_outs, op.info,
                              typeConstructor=None)
        else:
            rebuilt_op = OpFromGraph(inner_ins, inner_outs, **op.kwargs)

        # Make a new node to replace the old one.
        replacement_node = rebuilt_op.make_node(*outer_ins)
        nodes_seen.add(replacement_node)
        return replacement_node.outputs
Example #2
0
    def local_transform(node):
        """Transform one apply node, recursing into Scan/OpFromGraph bodies.

        A node already present in ``nodes_seen`` yields ``False``.  When the
        node's op is a ``Scan`` or ``OpFromGraph``, its inner graph is passed
        through ``_map_variables_inner`` and a fresh op and node are built
        from the result; the fresh node is added to ``nodes_seen`` and its
        outputs are returned.  Otherwise the node is added to ``nodes_seen``
        and the list of ``wrapped_replacer(output)`` for each of its outputs
        is returned.
        """
        if node in nodes_seen:
            return False

        # importing Scan into module scope would be circular
        from theano.scan_module.scan_op import Scan
        from theano.compile import OpFromGraph

        current_op = node.op
        if isinstance(current_op, (Scan, OpFromGraph)):
            # Map the variables of the inner graph first.
            inner_ins, outer_ins, inner_outs = _map_variables_inner(
                wrapped_replacer,
                inner_inputs=current_op.inputs,
                outer_inputs=node.inputs,
                inner_outputs=current_op.outputs,
                containing_op=current_op,
            )

            # Build a new op of the same kind; only Scan and OpFromGraph can
            # reach this point, so the non-Scan case is an OpFromGraph.
            if isinstance(current_op, Scan):
                # FIXME: infer typeConstructor someday?
                fresh_op = Scan(inner_ins, inner_outs, current_op.info,
                                typeConstructor=None)
            else:
                fresh_op = OpFromGraph(inner_ins, inner_outs,
                                       **current_op.kwargs)

            # The fresh node stands in for the old one.
            fresh_node = fresh_op.make_node(*outer_ins)
            nodes_seen.add(fresh_node)
            return fresh_node.outputs

        # Ordinary op: mark as seen and map each output individually.
        nodes_seen.add(node)
        return [wrapped_replacer(out) for out in node.outputs]
Example #3
0
def construct_scan(scan_args):
    """Build a ``Scan`` op from ``scan_args``, apply it, and return a list.

    The op is constructed from ``scan_args.inner_inputs``,
    ``scan_args.inner_outputs`` and ``scan_args.info`` and then called on
    ``scan_args.outer_inputs``.  If the call does not return a list, the
    result is wrapped in a one-element list.
    """
    op = Scan(scan_args.inner_inputs,
              scan_args.inner_outputs,
              scan_args.info)
    result = op(*scan_args.outer_inputs)
    # Always hand back a list, even when the call returned something else.
    return result if isinstance(result, list) else [result]
Example #4
0
def export(node, extra_inner_outputs):
    """Rebuild a Scan node so that extra inner variables become outer outputs.

    A new ``Scan`` op is created from ``node``'s inner inputs/outputs with
    ``extra_inner_outputs`` inserted as additional nit-sot outputs (placed
    just before the shared outputs), and is applied to ``node``'s outer
    inputs plus one extra "number of steps" input per added output.

    Parameters
    ----------
    node
        Apply node whose ``op`` must be a ``Scan`` (checked with ``assert``).
    extra_inner_outputs
        Sized sequence of inner-graph variables to expose.

    Returns
    -------
    list
        The outer outputs of the new scan at the positions where
        ``extra_inner_outputs`` were inserted, in the same order.
    """
    assert isinstance(node.op, Scan)

    n_extra = len(extra_inner_outputs)

    # this is ugly but we can't use scan_utils.scan_args because that
    # clones the inner graph and then extra_inner_outputs aren't in
    # there anymore
    new_inner_inputs = list(node.op.inputs)
    new_inner_outputs = list(node.op.outputs)
    new_outer_inputs = list(node.inputs)
    # deep copy so the original op's info is not mutated below
    new_info = copy.deepcopy(node.op.info)

    # put the new inner outputs in the right place in the output list and
    # update info
    new_info["n_nit_sot"] += n_extra
    # index of the first shared output in the original inner output list;
    # the extra outputs are spliced in right before it
    insert_at = len(node.op.outputs) - new_info["n_shared_outs"]
    new_inner_outputs[insert_at:insert_at] = extra_inner_outputs

    # in step 8, theano.scan() adds an outer input (being the actual
    # number of steps) for each nitsot. we need to do the same thing.
    # note these don't come with corresponding inner inputs.
    offset = (1 + node.op.n_seqs + node.op.n_mit_mot + node.op.n_mit_sot +
              node.op.n_sit_sot + node.op.n_shared_outs)
    # the outer input is just the actual number of steps, which is
    # always available as the first outer input.
    new_outer_inputs[offset:offset] = [new_outer_inputs[0]] * n_extra

    new_op = Scan(new_inner_inputs, new_inner_outputs, new_info)
    outer_outputs = new_op(*new_outer_inputs)

    # Calling the op may hand back a bare variable instead of a list (as
    # construct_scan also guards against); slicing that would build a
    # symbolic subtensor rather than select outputs, so normalise first.
    if not isinstance(outer_outputs, (list, tuple)):
        outer_outputs = [outer_outputs]

    # grab the outputs we actually care about
    return outer_outputs[insert_at:insert_at + n_extra]
Example #5
0
def get_population_outputs(batch_outputs, popstats):
    """Clone ``batch_outputs`` with batch statistics swapped for population ones.

    Every ancestor of ``batch_outputs`` whose tag has a ``bn_statistic``
    attribute is replaced by the population statistic whose batch statistic
    has the same ``tag.original_id`` (re-broadcast to the ancestor's
    broadcast pattern).  Scan nodes among the ancestors are rebuilt: a step
    index sequence and the non-scalar population statistics are added as
    inputs, the function recurses on the inner outputs, and the old outer
    outputs are replaced by those of the rebuilt scan.

    Parameters
    ----------
    batch_outputs
        Output variables of the graph to transform.
    popstats
        Mapping from batch-statistic variable to population-statistic
        variable.

    Returns
    -------
    The result of ``scan_utils.clone(batch_outputs, replace=...)`` with all
    collected replacements applied.

    Raises
    ------
    StopIteration
        If a tagged variable has no entry in ``popstats`` with a matching
        ``original_id``.
    """
    replacements = []
    # Keyed on the apply node, not the op: several distinct nodes can share
    # the same op, and each of them still needs its own replacement.
    visited_scan_nodes = set()

    for var in theano.gof.graph.ancestors(batch_outputs):
        if hasattr(var.tag, "bn_statistic"):
            # can't rely on object identity because scan_args clones; use original_id
            popstat = next(popstat for batchstat, popstat in popstats.items()
                           if batchstat.tag.original_id == var.tag.original_id)
            replacements.append(
                (var, T.patternbroadcast(popstat, var.broadcastable)))

        # descend into Scan
        try:
            op = var.owner.op
        except AttributeError:
            # variable has no owner node to inspect (e.g. owner is None)
            continue
        if not isinstance(op, Scan):
            continue

        # this would cause multiple replacements for this variable
        assert not hasattr(var.tag, "bn_statistic")

        node = var.owner
        if node in visited_scan_nodes:
            continue
        visited_scan_nodes.add(node)
        # single-argument form prints the same text under Python 2 and 3
        print("descending into %s" % (var,))

        sa = scan_utils.scan_args(outer_inputs=node.inputs,
                                  outer_outputs=node.outputs,
                                  _inner_inputs=node.op.inputs,
                                  _inner_outputs=node.op.outputs,
                                  info=node.op.info)

        # add subscript as sequence
        # TODO check if this integer input drops the scan to cpu, if so use float and cast back in subtensor expression
        indices = T.arange(sa.n_steps)
        index = scan_utils.safe_new(indices[0])
        sa.outer_in_seqs.append(indices)
        sa.inner_in_seqs.append(index)

        # add popstats as nonsequences (because they may be shorter than len(indices))
        inner_popstats = {}
        for batchstat, outer_popstat in popstats.items():
            # this can't be subscripted hence won't appear in the inner graph
            if outer_popstat.ndim == 0:
                continue

            inner_popstat = scan_utils.safe_new(outer_popstat)
            sa.outer_in_non_seqs.append(outer_popstat)
            sa.inner_in_non_seqs.append(inner_popstat)

            # use the statistic for the current step, falling back to the
            # last available one once the index runs past its length
            inner_popstats[batchstat] = theano.ifelse.ifelse(
                index < inner_popstat.shape[0], inner_popstat[index],
                inner_popstat[-1])

        # recurse on inner graph
        new_inner_outputs = get_population_outputs(sa.inner_outputs,
                                                   inner_popstats)

        # construct new scan node
        new_op = Scan(sa.inner_inputs, new_inner_outputs, sa.info)
        new_outer_outputs = new_op(*sa.outer_inputs)
        # Calling the op may hand back a bare variable instead of a list
        # (construct_scan guards against the same thing); wrap it so the
        # pairing below lines up outputs instead of iterating a variable.
        if not isinstance(new_outer_outputs, (list, tuple)):
            new_outer_outputs = [new_outer_outputs]

        # there is one-to-one correspondence between the old outer
        # outputs and the new outer outputs; replace one-to-one
        replacements.extend(equizip(node.outputs, new_outer_outputs))

    print("replacements %s" % (replacements,))
    population_outputs = scan_utils.clone(batch_outputs, replace=replacements)
    return population_outputs