Example #1
0
    def merge(self, nodes):
        """Build one Scan node equivalent to all the Scan nodes in `nodes`.

        Inner and outer variables of every node are gathered category by
        category (sequences, mit_mot, mit_sot, sit_sot, shared inputs,
        nit_sot, shared outputs, non-sequences) so that the merged node keeps
        the argument layout `scan_op.Scan` expects. Every named variable
        gathered through the renaming helper gets the index of its source
        node appended to its name; this mutates the original variables.

        Settings that are not summed (number of steps, truncate_gradient,
        mode, profile and, for while-loops, the stopping condition) are taken
        from ``nodes[0]``.

        :param nodes: Apply nodes whose ops are Scan ops.
        :returns: ``zip(old_outer_outputs, new_outer_outputs)``, pairing each
            gathered output of the original nodes with its merged counterpart.
        """
        lead_op = nodes[0].op
        as_while = bool(lead_op.as_while)
        # For a while-scan the stopping condition is the last inner output.
        condition = lead_op.outputs[-1] if as_while else None

        def total(counter):
            # Sum one of the Scan size counters over all merged nodes.
            return sum(getattr(nd.op, counter) for nd in nodes)

        info = {
            "tap_array": [],
            "n_seqs": total("n_seqs"),
            "n_mit_mot": total("n_mit_mot"),
            "n_mit_mot_outs": total("n_mit_mot_outs"),
            "mit_mot_out_slices": [],
            "n_mit_sot": total("n_mit_sot"),
            "n_sit_sot": total("n_sit_sot"),
            "n_shared_outs": total("n_shared_outs"),
            "n_nit_sot": total("n_nit_sot"),
            "truncate_gradient": lead_op.truncate_gradient,
            "name": "&".join(nd.op.name for nd in nodes),
            "mode": lead_op.mode,
            "inplace": False,
            "gpu": False,
            "as_while": as_while,
            "profile": lead_op.profile,
        }

        def suffixed(variables, suffix):
            # Append `suffix` to each named variable (in place) and hand the
            # same list back so it can be fed straight into `extend`.
            for var in variables:
                if var.name:
                    var.name += str(suffix)
            return variables

        inner_in, inner_out = [], []
        outer_in, outer_out = [], []

        for pos, nd in enumerate(nodes):  # sequences
            inner_in.extend(suffixed(nd.op.inner_seqs(), pos))
            outer_in.extend(suffixed(nd.op.outer_seqs(nd), pos))

        for pos, nd in enumerate(nodes):  # mit_mot
            inner_in.extend(suffixed(nd.op.inner_mitmot(), pos))
            inner_out.extend(nd.op.inner_mitmot_outs())
            info["tap_array"].extend(nd.op.mitmot_taps())
            info["mit_mot_out_slices"].extend(nd.op.mitmot_out_taps())
            outer_in.extend(suffixed(nd.op.outer_mitmot(nd), pos))
            outer_out.extend(nd.op.outer_mitmot_outs(nd))

        for pos, nd in enumerate(nodes):  # mit_sot
            inner_in.extend(suffixed(nd.op.inner_mitsot(), pos))
            inner_out.extend(nd.op.inner_mitsot_outs())
            info["tap_array"].extend(nd.op.mitsot_taps())
            outer_in.extend(suffixed(nd.op.outer_mitsot(nd), pos))
            outer_out.extend(nd.op.outer_mitsot_outs(nd))

        for pos, nd in enumerate(nodes):  # sit_sot: always a single -1 tap
            inner_in.extend(suffixed(nd.op.inner_sitsot(), pos))
            info["tap_array"].extend([[-1] for _ in xrange(nd.op.n_sit_sot)])
            inner_out.extend(nd.op.inner_sitsot_outs())
            outer_in.extend(suffixed(nd.op.outer_sitsot(nd), pos))
            outer_out.extend(nd.op.outer_sitsot_outs(nd))

        for pos, nd in enumerate(nodes):  # shared inputs
            inner_in.extend(suffixed(nd.op.inner_shared(), pos))
            outer_in.extend(suffixed(nd.op.outer_shared(nd), pos))

        for pos, nd in enumerate(nodes):  # nit_sot
            inner_out.extend(nd.op.inner_nitsot_outs())
            outer_in.extend(suffixed(nd.op.outer_nitsot(nd), pos))
            outer_out.extend(nd.op.outer_nitsot_outs(nd))

        for nd in nodes:  # shared outputs
            outer_out.extend(nd.op.outer_shared_outs(nd))
            inner_out.extend(nd.op.inner_shared_outs())

        for pos, nd in enumerate(nodes):  # non-sequences
            inner_in.extend(suffixed(nd.op.inner_non_seqs(), pos))
            outer_in.extend(suffixed(nd.op.outer_non_seqs(nd), pos))

        # The number of steps is the leading outer input of every Scan.
        outer_in.insert(0, nodes[0].inputs[0])

        if as_while:
            inner_out.append(condition)
        inner_in, inner_out = scan_utils.reconstruct_graph(inner_in, inner_out)

        merged_op = scan_op.Scan(inner_in, inner_out, info)
        merged_outs = merged_op(*outer_in)
        if not isinstance(merged_outs, (list, tuple)):
            merged_outs = [merged_outs]

        return zip(outer_out, merged_outs)
Example #2
0
def remove_constants_and_unused_inputs_scan(node):
    """
    Move constants into the inner graph, and remove unused inputs.

    Constants that are in the outer graph are represented by a free symbolic
    variable in the inner graph. If we move them into the inner graph,
    constant-folding can happen in the inner graph.
    This is applied only on sequences and non-sequences,
    not on initial states.

    :param node: an Apply node; nodes whose op is not a ``scan_op.Scan`` are
        ignored.
    :returns: the outputs of a rebuilt Scan node when at least one inner
        input was folded into the inner graph or dropped, ``False`` otherwise.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False
    op = node.op
    # We only need to take care of sequences and other arguments.
    # Inner inputs are: sequences, one input per mit_mot/mit_sot tap, sit_sot
    # states, shared states, then non-sequences; `st` ends up at the first
    # inner non-sequence.
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in op.tap_array[: (op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)

    # Corresponds to the initial states, which should stay untouched.
    # We put those variables aside, and put them back at the end.
    out_stuff_inner = op_ins[op.n_seqs : st]

    non_seqs = op_ins[st:]
    # Outer inputs are: n_steps, sequences, one input per mit_mot, mit_sot,
    # sit_sot, nit_sot and shared output, then non-sequences.
    st = op.n_seqs + op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot + op.n_shared_outs + 1
    outer_non_seqs = node.inputs[st:]
    out_stuff_outer = node.inputs[1 + op.n_seqs : st]

    # To replace constants in the outer graph by clones in the inner graph
    givens = {}
    # All the inputs of the inner graph of the new scan
    nw_inner = []
    # Same for the outer graph, initialized w/ number of steps
    nw_outer = [node.inputs[0]]

    all_ins = gof.graph.inputs(op_outs)
    for idx in xrange(op.n_seqs):
        outer_seq = node.inputs[idx + 1]
        inner_seq = op_ins[idx]
        if (
            isinstance(outer_seq, tensor.TensorConstant)
            and getattr(outer_seq.tag, "unique_value", None) is not None
        ):
            try:
                # This works if input is a constant that has all entries
                # equal; get_constant_value raises TypeError otherwise.
                tensor.get_constant_value(outer_seq)
                givens[inner_seq] = outer_seq.clone()[0]
                continue
            except TypeError:
                # The constant could not be folded: fall through and treat it
                # like any other sequence rather than silently dropping an
                # input the inner graph may still depend on.
                pass
        if inner_seq in all_ins:
            nw_inner.append(inner_seq)
            nw_outer.append(outer_seq)

    nw_n_seqs = len(nw_inner)
    # Add outputs stuff
    nw_inner += out_stuff_inner
    nw_outer += out_stuff_outer
    # Look through non sequences
    for nw_in, nw_out in zip(non_seqs, outer_non_seqs):
        if isinstance(nw_out, tensor.Constant):
            givens[nw_in] = nw_out.clone()
        elif nw_in in all_ins:
            nw_inner.append(nw_in)
            nw_outer.append(nw_out)

    if len(nw_inner) != len(op_ins):
        op_outs = scan_utils.clone(op_outs, replace=givens)
        nw_info = op.info.copy()
        nw_info["n_seqs"] = nw_n_seqs
        # DEBUG CHECK
        nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
        nw_outs = nwScan.make_node(*nw_outer).outputs
        return nw_outs
    else:
        return False
Example #3
0
    def process_node(self, env, node):
        """Push computations that only depend on non-sequences out of a Scan.

        Inner-graph nodes whose inputs are all non-sequences, constants or
        outputs of nodes already selected are rebuilt in the outer graph.
        Their inner outputs are then replaced by new inner inputs fed from
        the rebuilt outer results, and `node` is replaced in `env`.

        :param env: the outer graph being optimized.
        :param node: the Scan Apply node to process.
        :returns: True when `node` was replaced by a new Scan node, False when
            nothing could be pushed out, None when every inner computation
            was hoisted and matching Scan outputs were replaced by
            ``tensor.alloc`` broadcasts of the hoisted results.
        :raises Exception: if a selected node has an input that cannot be
            mapped to the outer graph, or if the search does not converge
            within ``max_iterations`` passes.
        """
        # this flag tells if there was any change during the last iterations
        changed = True
        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(node.op.inputs, node.op.outputs)

        local_env = gof.Env(clean_inputs, clean_outputs)
        # Upper bound on the number of passes over the inner graph.
        max_iterations = 2 * len(local_env.toposort()) + 3
        counts = 0
        # Inner nodes selected to be rebuilt in the outer graph.
        to_remove = []
        # Parallel lists: inner variable to replace, its new inner
        # placeholder, and the equivalent variable built in the outer graph.
        to_replace = []
        replace_with_in = []
        replace_with_out = []
        op = node.op
        # Construct the list of non_sequences to simplify a few things
        st = op.n_seqs
        st += int(numpy.sum([len(x) for x in op.tap_array[: (op.n_mit_mot + op.n_mit_sot)]]))
        st += op.n_sit_sot
        st += op.n_shared_outs
        non_seqs = clean_inputs[st:]
        st = op.n_seqs + op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot + op.n_shared_outs + 1
        outer_non_seqs = node.inputs[st:]
        assert len(non_seqs) == len(outer_non_seqs)
        while changed and counts < max_iterations:
            counts += 1
            changed = False

            for nd in local_env.toposort():
                if (
                    numpy.all(
                        [(x in non_seqs) or (x.owner in to_remove) or isinstance(x, tensor.Constant) for x in nd.inputs]
                    )
                    # we can do this because the assumption is that a
                    # viewOp or deepCopyOp will be just at the end of the
                    # function and not somewhere in the middle ..
                    and not isinstance(nd.op, theano.compile.ViewOp)
                    and not isinstance(nd.op, theano.compile.DeepCopyOp)
                    # and we didn't already looked at this node
                    and nd not in to_remove
                ):

                    # We have a candidate node to removable
                    # Step 1. Reconstruct it on outside
                    to_remove.append(nd)
                    outside_ins = []
                    for x in nd.inputs:
                        if x in non_seqs:
                            outside_ins += [outer_non_seqs[non_seqs.index(x)]]
                        elif x in to_replace:
                            outside_ins += [replace_with_out[to_replace.index(x)]]
                        elif isinstance(x, theano.Constant):
                            outside_ins += [x.clone()]
                        else:
                            raise Exception(
                                (
                                    "Error in the `scan_pushout_non_seq_"
                                    "operations`. The optimization tries "
                                    "to move some computation fron scan "
                                    "which is not allowed to move. Report "
                                    "this on theano-users list"
                                ),
                                x,
                            )
                    nw_outer_node = nd.op.make_node(*outside_ins)
                    # Step 2. Create variables for replacements
                    for idx, y in enumerate(nd.outputs):

                        y_place_holder = scan_utils.safe_new(y, "_replace")
                        to_replace += [y]
                        replace_with_in += [y_place_holder]
                        assert type(y) == type(nw_outer_node.outputs[idx])
                        replace_with_out += [nw_outer_node.outputs[idx]]
                    changed = True
        # The loop can only exit with `changed` still set when it ran out of
        # iterations; converging exactly on the last allowed pass is fine.
        if changed:
            raise Exception(
                "Error in the `scan_pushout_non_seq_operations`."
                " The optimization exhausted the maximal number "
                "of iterations allowed!"
            )
        # We need to check all candidate replacements and choose those that
        # make sense for us

        # Step 1. which elements of `to_replace` are used by remaining
        # components of the inner function
        clean_to_replace = []
        clean_replace_with_in = []
        clean_replace_with_out = []
        existent_nodes = [nd for nd in local_env.toposort() if nd not in to_remove]
        to_keep = []
        for nd in existent_nodes:
            to_keep += nd.inputs
        for idx, out in enumerate(to_replace):
            if out in to_keep and out.owner not in existent_nodes:
                clean_to_replace += [out]
                clean_replace_with_in += [replace_with_in[idx]]
                clean_replace_with_out += [replace_with_out[idx]]

        if len(clean_to_replace) > 0:
            # We can finally put an end to all this madness
            givens = {}
            nw_outer = []
            nw_inner = []
            for to_repl, repl_in, repl_out in zip(clean_to_replace, clean_replace_with_in, clean_replace_with_out):
                if isinstance(repl_out, theano.Constant):
                    # Constants are cloned into the inner graph instead of
                    # becoming new inner inputs.
                    repl_in = repl_out.clone()
                else:
                    nw_inner += [repl_in]
                    nw_outer += [repl_out]
                givens[to_repl] = repl_in

            _op_outs = scan_utils.clone(clean_outputs, replace=givens)
            _op_ins = clean_inputs + nw_inner
            op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
            # Reconstruct node
            nwScan = scan_op.Scan(op_ins, op_outs, op.info)
            nw_node = nwScan.make_node(*(node.inputs + nw_outer))
            env.replace_all_validate(zip(node.outputs, nw_node.outputs), reason="scan_push_computation_out")
            return True
        elif to_keep == []:
            # Nothing in the inner graph should be kept
            replace_with = {}
            for idx, out in enumerate(to_replace):
                if out in local_env.outputs:
                    x = node.outputs[local_env.outputs.index(out)]
                    y = replace_with_out[idx]
                    # Use a distinct name: in Python 2 a comprehension
                    # variable would rebind the enclosing `idx`.
                    shape = [y.shape[dim] for dim in xrange(y.ndim)]
                    replace_with[x] = tensor.alloc(y, node.inputs[0], *shape)

            # We need to add one extra dimension to the outputs
            if replace_with:
                env.replace_all_validate(replace_with.items(), reason="scan_push_computation_out")

        else:
            return False
Example #4
0
    def merge(self, nodes):
        '''Build one Scan node equivalent to all the Scan nodes in `nodes`.

        Inner and outer variables of every node are gathered category by
        category (sequences, mit_mot, mit_sot, sit_sot, shared inputs,
        nit_sot, shared outputs, non-sequences) so that the merged node keeps
        the argument layout `scan_op.Scan` expects. Every named variable
        gathered through the renaming helper gets the index of its source
        node appended to its name; this mutates the original variables.

        Settings that are not summed (number of steps, truncate_gradient,
        mode, profile and, for while-loops, the stopping condition) are taken
        from ``nodes[0]``.

        :param nodes: Apply nodes whose ops are Scan ops.
        :returns: ``zip(old_outer_outputs, new_outer_outputs)``, pairing each
            gathered output of the original nodes with its merged counterpart.
        '''
        lead_op = nodes[0].op
        as_while = bool(lead_op.as_while)
        # For a while-scan the stopping condition is the last inner output.
        condition = lead_op.outputs[-1] if as_while else None

        def total(counter):
            # Sum one of the Scan size counters over all merged nodes.
            return sum(getattr(nd.op, counter) for nd in nodes)

        info = {
            'tap_array': [],
            'n_seqs': total('n_seqs'),
            'n_mit_mot': total('n_mit_mot'),
            'n_mit_mot_outs': total('n_mit_mot_outs'),
            'mit_mot_out_slices': [],
            'n_mit_sot': total('n_mit_sot'),
            'n_sit_sot': total('n_sit_sot'),
            'n_shared_outs': total('n_shared_outs'),
            'n_nit_sot': total('n_nit_sot'),
            'truncate_gradient': lead_op.truncate_gradient,
            'name': '&'.join(nd.op.name for nd in nodes),
            'mode': lead_op.mode,
            'gpu': False,
            'as_while': as_while,
            'profile': lead_op.profile,
        }

        def suffixed(variables, suffix):
            # Append `suffix` to each named variable (in place) and hand the
            # same list back so it can be fed straight into `extend`.
            for var in variables:
                if var.name:
                    var.name += str(suffix)
            return variables

        inner_in, inner_out = [], []
        outer_in, outer_out = [], []

        for pos, nd in enumerate(nodes):  # sequences
            inner_in.extend(suffixed(nd.op.inner_seqs(nd.op.inputs), pos))
            outer_in.extend(suffixed(nd.op.outer_seqs(nd.inputs), pos))

        for pos, nd in enumerate(nodes):  # mit_mot
            inner_in.extend(suffixed(nd.op.inner_mitmot(nd.op.inputs), pos))
            inner_out.extend(nd.op.inner_mitmot_outs(nd.op.outputs))
            info['tap_array'].extend(nd.op.mitmot_taps())
            info['mit_mot_out_slices'].extend(nd.op.mitmot_out_taps())
            outer_in.extend(suffixed(nd.op.outer_mitmot(nd.inputs), pos))
            outer_out.extend(nd.op.outer_mitmot_outs(nd.outputs))

        for pos, nd in enumerate(nodes):  # mit_sot
            inner_in.extend(suffixed(nd.op.inner_mitsot(nd.op.inputs), pos))
            inner_out.extend(nd.op.inner_mitsot_outs(nd.op.outputs))
            info['tap_array'].extend(nd.op.mitsot_taps())
            outer_in.extend(suffixed(nd.op.outer_mitsot(nd.inputs), pos))
            outer_out.extend(nd.op.outer_mitsot_outs(nd.outputs))

        for pos, nd in enumerate(nodes):  # sit_sot: always a single -1 tap
            inner_in.extend(suffixed(nd.op.inner_sitsot(nd.op.inputs), pos))
            info['tap_array'].extend([[-1] for _ in xrange(nd.op.n_sit_sot)])
            inner_out.extend(nd.op.inner_sitsot_outs(nd.op.outputs))
            outer_in.extend(suffixed(nd.op.outer_sitsot(nd.inputs), pos))
            outer_out.extend(nd.op.outer_sitsot_outs(nd.outputs))

        for pos, nd in enumerate(nodes):  # shared inputs
            inner_in.extend(suffixed(nd.op.inner_shared(nd.op.inputs), pos))
            outer_in.extend(suffixed(nd.op.outer_shared(nd.inputs), pos))

        for pos, nd in enumerate(nodes):  # nit_sot
            inner_out.extend(nd.op.inner_nitsot_outs(nd.op.outputs))
            outer_in.extend(suffixed(nd.op.outer_nitsot(nd.inputs), pos))
            outer_out.extend(nd.op.outer_nitsot_outs(nd.outputs))

        for nd in nodes:  # shared outputs
            outer_out.extend(nd.op.outer_shared_outs(nd.outputs))
            inner_out.extend(nd.op.inner_shared_outs(nd.op.outputs))

        for pos, nd in enumerate(nodes):  # non-sequences
            inner_in.extend(suffixed(nd.op.inner_non_seqs(nd.op.inputs), pos))
            outer_in.extend(suffixed(nd.op.outer_non_seqs(nd.inputs), pos))

        # The number of steps is the leading outer input of every Scan.
        outer_in.insert(0, nodes[0].inputs[0])

        if as_while:
            inner_out.append(condition)
        inner_in, inner_out = scan_utils.reconstruct_graph(inner_in,
                                                           inner_out)

        merged_op = scan_op.Scan(inner_in, inner_out, info)
        merged_outs = merged_op(*outer_in)
        if not isinstance(merged_outs, (list, tuple)):
            merged_outs = [merged_outs]

        return zip(outer_out, merged_outs)
Example #5
0
def remove_constants_and_unused_inputs_scan(node):
    '''
    Move constants into the inner graph, and remove unused inputs.

    Constants that are in the outer graph are represented by a free symbolic
    variable in the inner graph. If we move them into the inner graph,
    constant-folding can happen in the inner graph.
    This is applied only on sequences and non-sequences,
    not on initial states.

    :param node: an Apply node; nodes whose op is not a ``scan_op.Scan`` are
        ignored.
    :returns: the outputs of a rebuilt Scan node when at least one inner
        input was folded into the inner graph or dropped, ``False`` otherwise.
    '''
    if not isinstance(node.op, scan_op.Scan):
        return False
    op = node.op
    # We only need to take care of sequences and other arguments.
    # Inner inputs are: sequences, one input per mit_mot/mit_sot tap, sit_sot
    # states, shared states, then non-sequences; `st` ends up at the first
    # inner non-sequence.
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in
                     op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)

    # Corresponds to the initial states, which should stay untouched.
    # We put those variables aside, and put them back at the end.
    out_stuff_inner = op_ins[op.n_seqs:st]

    non_seqs = op_ins[st:]
    # Outer inputs are: n_steps, sequences, one input per mit_mot, mit_sot,
    # sit_sot, nit_sot and shared output, then non-sequences.
    st = (op.n_seqs +
          op.n_mit_mot +
          op.n_mit_sot +
          op.n_sit_sot +
          op.n_nit_sot +
          op.n_shared_outs + 1)
    outer_non_seqs = node.inputs[st:]
    out_stuff_outer = node.inputs[1 + op.n_seqs:st]

    # To replace constants in the outer graph by clones in the inner graph
    givens = {}
    # All the inputs of the inner graph of the new scan
    nw_inner = []
    # Same for the outer graph, initialized w/ number of steps
    nw_outer = [node.inputs[0]]

    all_ins = gof.graph.inputs(op_outs)
    for idx in xrange(op.n_seqs):
        outer_seq = node.inputs[idx + 1]
        inner_seq = op_ins[idx]
        if (isinstance(outer_seq, tensor.TensorConstant) and
                getattr(outer_seq.tag, 'unique_value', None) is not None):
            try:
                # This works if input is a constant that has all entries
                # equal
                givens[inner_seq] = outer_seq.clone()[0]
                continue
            except TypeError:
                # The constant could not be folded: fall through and treat it
                # like any other sequence rather than silently dropping an
                # input the inner graph may still depend on.
                pass
        if inner_seq in all_ins:
            nw_inner.append(inner_seq)
            nw_outer.append(outer_seq)

    nw_n_seqs = len(nw_inner)
    # Add outputs stuff
    nw_inner += out_stuff_inner
    nw_outer += out_stuff_outer
    # Look through non sequences
    for nw_in, nw_out in zip(non_seqs, outer_non_seqs):
        if isinstance(nw_out, tensor.Constant):
            givens[nw_in] = nw_out.clone()
        elif nw_in in all_ins:
            nw_inner.append(nw_in)
            nw_outer.append(nw_out)

    if len(nw_inner) != len(op_ins):
        op_outs = scan_utils.clone(op_outs, replace=givens)
        nw_info = copy.deepcopy(op.info)
        nw_info['n_seqs'] = nw_n_seqs
        # DEBUG CHECK
        nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
        nw_outs = nwScan.make_node(*nw_outer).outputs
        return nw_outs
    else:
        return False
Example #6
0
    def process_node(self, fgraph, node):
        '''Push computations that only depend on non-sequences out of a Scan.

        Inner-graph nodes whose inputs are all non-sequences, constants or
        outputs of nodes already selected are rebuilt in the outer graph.
        Their inner outputs are then replaced by new inner inputs fed from
        the rebuilt outer results, and `node` is replaced in `fgraph`.

        :param fgraph: the outer FunctionGraph being optimized.
        :param node: the Scan Apply node to process.
        :returns: True when `node` was replaced by a new Scan node, False when
            nothing could be pushed out, None when every inner computation
            was hoisted and matching Scan outputs were replaced by
            ``tensor.alloc`` broadcasts of the hoisted results.
        :raises Exception: if a selected node has an input that cannot be
            mapped to the outer graph, or if the search does not converge
            within ``max_iterations`` passes.
        '''
        # this flag tells if there was any change during the last iterations
        changed = True
        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
                        node.op.inputs, node.op.outputs)

        local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs)
        # Upper bound on the number of passes over the inner graph.
        max_iterations = 2 * len(local_fgraph.toposort()) + 3
        counts = 0
        # Inner nodes selected to be rebuilt in the outer graph.
        to_remove = []
        # Parallel lists: inner variable to replace, its new inner
        # placeholder, and the equivalent variable built in the outer graph.
        to_replace = []
        replace_with_in = []
        replace_with_out = []
        op = node.op
        # Construct the list of non_sequences to simplify a few things
        st = op.n_seqs
        st += int(numpy.sum([len(x) for x in
                             op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
        st += op.n_sit_sot
        st += op.n_shared_outs
        non_seqs = clean_inputs[st:]
        st = (op.n_seqs +
              op.n_mit_mot +
              op.n_mit_sot +
              op.n_sit_sot +
              op.n_nit_sot +
              op.n_shared_outs + 1)
        outer_non_seqs = node.inputs[st:]
        assert len(non_seqs) == len(outer_non_seqs)
        while changed and counts < max_iterations:
            counts += 1
            changed = False

            for nd in local_fgraph.toposort():
                if (numpy.all([(x in non_seqs) or
                               (x.owner in to_remove) or
                               isinstance(x, tensor.Constant)
                                 for x in nd.inputs]) and
                        # we can do this because the assumption is that a
                        # viewOp or deepCopyOp will be just at the end of the
                        # function and not somewhere in the middle ..
                        not isinstance(nd.op, theano.compile.ViewOp) and
                        not isinstance(nd.op, theano.compile.DeepCopyOp) and
                        # and we didn't already looked at this node
                        nd not in to_remove):

                    # We have a candidate node to removable
                    # Step 1. Reconstruct it on outside
                    to_remove.append(nd)
                    outside_ins = []
                    for x in nd.inputs:
                        if x in non_seqs:
                            outside_ins += [outer_non_seqs[non_seqs.index(x)]]
                        elif x in to_replace:
                            outside_ins += [
                                replace_with_out[to_replace.index(x)]]
                        elif isinstance(x, theano.Constant):
                            outside_ins += [x.clone()]
                        else:
                            raise Exception(
                                ('Error in the `scan_pushout_non_seq_'
                                 'operations`. The optimization tries '
                                 'to move some computation fron scan '
                                 'which is not allowed to move. Report '
                                 'this on theano-users list'), x)
                    # Convert each outer input to the type of the inner input
                    # it stands for, so make_node builds the same types.
                    outside_ins = [x.type.filter_variable(y) for x, y in
                                   zip(nd.inputs, outside_ins)]
                    nw_outer_node = nd.op.make_node(*outside_ins)
                    # Step 2. Create variables for replacements
                    for idx, y in enumerate(nd.outputs):

                        y_place_holder = scan_utils.safe_new(y, '_replace')
                        to_replace += [y]
                        replace_with_in += [y_place_holder]
                        assert type(y) == type(nw_outer_node.outputs[idx])
                        replace_with_out += [nw_outer_node.outputs[idx]]
                    changed = True
        # The loop can only exit with `changed` still set when it ran out of
        # iterations; converging exactly on the last allowed pass is fine.
        if changed:
            raise Exception('Error in the `scan_pushout_non_seq_operations`.'
                            ' The optimization exhausted the maximal number '
                            'of iterations allowed!')
        # We need to check all candidate replacements and choose those that
        # make sense for us

        # Step 1. which elements of `to_replace` are used by remaining
        # components of the inner function
        clean_to_replace = []
        clean_replace_with_in = []
        clean_replace_with_out = []
        existent_nodes = [nd for nd in local_fgraph.toposort()
                            if nd not in to_remove]
        to_keep = []
        for nd in existent_nodes:
            to_keep += nd.inputs
        for idx, out in enumerate(to_replace):
            if out in to_keep and out.owner not in existent_nodes:
                clean_to_replace += [out]
                clean_replace_with_in += [replace_with_in[idx]]
                clean_replace_with_out += [replace_with_out[idx]]

        if len(clean_to_replace) > 0:
            # We can finally put an end to all this madness
            givens = {}
            nw_outer = []
            nw_inner = []
            for to_repl, repl_in, repl_out in zip(clean_to_replace,
                                              clean_replace_with_in,
                                              clean_replace_with_out):
                if isinstance(repl_out, theano.Constant):
                    # Constants are cloned into the inner graph instead of
                    # becoming new inner inputs.
                    repl_in = repl_out.clone()
                else:
                    nw_inner += [repl_in]
                    nw_outer += [repl_out]
                givens[to_repl] = repl_in

            _op_outs = scan_utils.clone(clean_outputs,
                                        replace=givens)
            _op_ins = clean_inputs + nw_inner
            op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
            # Reconstruct node
            nwScan = scan_op.Scan(op_ins, op_outs, op.info)
            nw_node = nwScan.make_node(* (node.inputs + nw_outer))
            fgraph.replace_all_validate_remove(
                zip(node.outputs, nw_node.outputs),
                remove=[node],
                reason='scan_push_computation_out')
            return True
        elif to_keep == []:
            # Nothing in the inner graph should be kept
            replace_with = {}
            for idx, out in enumerate(to_replace):
                if out in local_fgraph.outputs:
                    x = node.outputs[local_fgraph.outputs.index(out)]
                    y = replace_with_out[idx]
                    # Use a distinct name: in Python 2 a comprehension
                    # variable would rebind the enclosing `idx`.
                    shape = [y.shape[dim] for dim in xrange(y.ndim)]
                    replace_with[x] = tensor.alloc(y,
                                                   node.inputs[0],
                                                   *shape)

            # We need to add one extra dimension to the outputs
            if replace_with:
                if len(replace_with) == len(node.outputs):
                    # Every output has a replacement, so the Scan node
                    # disappears from the graph and can be removed.
                    fgraph.replace_all_validate_remove(
                        replace_with.items(),
                        remove=[node],
                        reason='scan_push_computation_out')
                else:
                    # Some outputs (e.g. inner outputs that are not produced
                    # by an Apply node) got no replacement, so the node stays
                    # in the graph; requiring its removal would fail.
                    fgraph.replace_all_validate(
                        replace_with.items(),
                        reason='scan_push_computation_out')

        else:
            return False