Exemplo n.º 1
0
def scan_make_inplace(node):
    """Rebuild a CPU Scan node as its inplace variant.

    Returns the outputs of a new Scan whose ``info['inplace']`` flag is set
    to True, or False when ``node`` is not a Scan, is already inplace, or
    is flagged as a GPU Scan.

    Initial-state inputs (mit-mot, then mit-sot, then sit-sot) that repeat
    an earlier state input are wrapped in ``deep_copy_op``, presumably so
    the inplace op never writes twice into the same storage.
    """
    op = node.op
    if not isinstance(op, scan_op.Scan):
        return False
    if op.info['inplace'] or op.info['gpu']:
        return False

    new_info = op.info.copy()
    new_info['inplace'] = True

    outer_inputs = node.inputs
    # n_steps followed by the outer sequences
    prefix = outer_inputs[:1 + op.n_seqs]
    # initial states of the mit-mot, mit-sot and sit-sot outputs
    states = op.outer_mitmot(outer_inputs)
    states += op.outer_mitsot(outer_inputs)
    states += op.outer_sitsot(outer_inputs)
    # everything after the states: shared, nit-sot and non-sequence inputs
    suffix = op.outer_shared(outer_inputs)
    suffix += op.outer_nitsot(outer_inputs)
    suffix += op.outer_non_seqs(outer_inputs)

    # A state equal to an earlier state gets its own copy.
    for position, state in enumerate(states):
        if state in states[:position]:
            states[position] = deep_copy_op(state)

    rebuilt = scan_op.Scan(op.inputs, op.outputs, new_info)
    return rebuilt.make_node(*(prefix + states + suffix)).outputs
Exemplo n.º 2
0
    def apply(self, fgraph):
        """Try to make the state outputs of matching Scan nodes inplace.

        Every Scan node of ``fgraph`` whose ``info['gpu']`` equals
        ``self.gpu_flag`` is visited.  Each of its mit-mot, mit-sot and
        sit-sot outputs is tried in turn:

        * a deep copy of the current op's info gets a ``destroy_map`` entry
          mapping output ``pos`` to outer input ``pos + 1 + n_seqs``;
        * a new Scan is built with ``self.typeConstructor``;
        * the old node is replaced via ``replace_all_validate_remove``.

        If that replacement raises ``InconsistencyError``, the output is
        left as it was.  Successful replacements accumulate, because the
        next attempt starts from the freshly inserted node and op.
        """
        def rebuilt_inputs(scan, apply_node):
            # n_steps and sequences, then the initial states, then the
            # shared, nit-sot and non-sequence inputs.
            outer = apply_node.inputs
            head = outer[:1 + scan.n_seqs]
            states = scan.outer_mitmot(outer)
            states += scan.outer_mitsot(outer)
            states += scan.outer_sitsot(outer)
            tail = scan.outer_shared(outer)
            tail += scan.outer_nitsot(outer)
            tail += scan.outer_non_seqs(outer)
            # A state repeating an earlier one gets its own copy.
            for position, state in enumerate(states):
                if state in states[:position]:
                    states[position] = deep_copy_op(state)
            return head + states + tail

        scan_nodes = [candidate for candidate in fgraph.toposort()
                      if (isinstance(candidate.op, scan_op.Scan) and
                          candidate.op.info['gpu'] == self.gpu_flag)]
        for node in scan_nodes:
            op = node.op
            n_states = (op.info['n_mit_mot'] +
                        op.info['n_mit_sot'] +
                        op.info['n_sit_sot'])
            for pos in xrange(n_states):
                info = copy.deepcopy(op.info)
                destroy_map = info.setdefault('destroy_map', {})
                destroy_map[pos] = [pos + 1 + op.info['n_seqs']]

                inputs = rebuilt_inputs(op, node)
                new_op = scan_op.Scan(op.inputs,
                                      op.outputs,
                                      info,
                                      typeConstructor=self.typeConstructor)
                new_outs = new_op.make_node(*inputs).outputs
                try:
                    fgraph.replace_all_validate_remove(
                        zip(node.outputs, new_outs),
                        remove=[node],
                        reason=self.__class__.__name__)
                except InconsistencyError:
                    # Computing this output inplace was rejected; keep the
                    # current node and move on to the next output.
                    pass
                else:
                    op = new_op
                    node = new_outs[0].owner
Exemplo n.º 3
0
    def merge(self, nodes):
        """Fuse the Scan nodes in ``nodes`` into a single Scan op.

        The inner/outer inputs and outputs of every node are regrouped by
        category and concatenated in node order within each category:

        * outer inputs: n_steps (taken from ``nodes[0]``), sequences,
          mit-mot, mit-sot, sit-sot, shared, nit-sot, non-sequences;
        * inner inputs: sequences, mit-mot, mit-sot, sit-sot, shared,
          non-sequences;
        * inner/outer outputs: mit-mot, mit-sot, sit-sot, nit-sot, shared
          (plus the stopping condition last for a "while" scan).

        The ``info`` counters are the sums over all nodes.  Only
        ``nodes[0]`` is consulted for ``as_while``, the stopping condition,
        ``truncate_gradient``, ``mode``, ``profile`` and the number of
        steps.  The caller is presumably responsible for passing only nodes
        that agree on these; this method does not check it.

        Returns ``zip(outer_outs, new_outs)``: pairs of (old outer output,
        corresponding output of the merged Scan).

        NOTE(review): ``rename`` appends the node index to the ``name`` of
        the variables it is given *in place*.  That includes the original
        ops' inner inputs and the outer-graph inputs ``nd.inputs``, so the
        caller's variables are renamed as a side effect.
        """

        # Only the first node decides whether the merged scan is a "while"
        # loop; its last inner output is taken as the stopping condition.
        if nodes[0].op.as_while:
            as_while = True
            condition = nodes[0].op.outputs[-1]
        else:
            as_while = False

        info = {}
        info['tap_array'] = []
        info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
        info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
        info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
        info['mit_mot_out_slices'] = []
        info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes])
        info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes])
        info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes])
        info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes])
        info['truncate_gradient'] = nodes[0].op.truncate_gradient
        info['name'] = '&'.join([nd.op.name for nd in nodes])
        info['mode'] = nodes[0].op.mode
        # The merged op is always built as a non-GPU scan.
        info['gpu'] = False
        info['as_while'] = as_while
        info['profile'] = nodes[0].op.profile

        inner_ins = []
        outer_ins = []
        inner_outs = []
        outer_outs = []

        # Append ``suffix`` to the name of every named variable in ``ls``
        # (mutating the variables) and return ``ls`` unchanged otherwise.
        def rename(ls, suffix):
            for k in ls:
                if k.name:
                    k.name += str(suffix)
            return ls

        # Each category is collected over all nodes before moving on to the
        # next one, so the merged scan sees e.g. all sequences first, then
        # all mit-mot states, and so on.
        for idx, nd in enumerate(nodes):
            # Seq
            inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx)
            outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)

        for idx, nd in enumerate(nodes):
            # MitMot
            inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx)
            inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs)
            info['tap_array'] += nd.op.mitmot_taps()
            info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
            outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
            outer_outs += nd.op.outer_mitmot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # MitSot
            inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx)
            inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs)
            info['tap_array'] += nd.op.mitsot_taps()
            outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
            outer_outs += nd.op.outer_mitsot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # SitSot: each has the single tap [-1]
            inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx)
            info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
            inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs)
            outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
            outer_outs += nd.op.outer_sitsot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # Shared (inputs only; shared outputs are added after nit-sot)
            inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx)
            outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)

        for idx, nd in enumerate(nodes):
            # NitSot: no inner input, only an outer input and outputs
            inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs)
            outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
            outer_outs += nd.op.outer_nitsot_outs(nd.outputs)

        for idx, nd in enumerate(nodes):
            # Shared outputs
            outer_outs += nd.op.outer_shared_outs(nd.outputs)
            inner_outs += nd.op.inner_shared_outs(nd.op.outputs)

        for idx, nd in enumerate(nodes):
            # Non Seqs
            inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx)
            outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)

        # Add back the number of steps (the first node's n_steps is used
        # for the merged scan)
        outer_ins = [nodes[0].inputs[0]] + outer_ins

        if as_while:
            # add the condition
            inner_outs.append(condition)
        # Rebuild the inner graph from the collected inner inputs/outputs
        # (presumably a clone, so the original ops' inner graphs are not
        # shared with the merged op -- confirm in scan_utils).
        inner_ins, inner_outs = scan_utils.reconstruct_graph(inner_ins,
                                                             inner_outs)

        new_op = scan_op.Scan(inner_ins, inner_outs, info)
        new_outs = new_op(*outer_ins)

        # A single-output op returns a bare variable; normalize to a list.
        if not isinstance(new_outs, (list, tuple)):
            new_outs = [new_outs]

        return zip(outer_outs, new_outs)
Exemplo n.º 4
0
def remove_constants_and_unused_inputs_scan(node):
    '''
    Move constants into the inner graph, and remove unused inputs.

    Constants that are in the outer graph are represented by a free symbolic
    variable in the inner graph. If we move them into the inner graph,
    constant-folding can happen in the inner graph.
    This is applied only on sequences and non-sequences,
    not on initial states.

    Returns the outputs of a new Scan node when at least one sequence or
    non-sequence input could be folded in or dropped, and False otherwise
    (including when ``node`` is not a Scan node).
    '''
    if not isinstance(node.op, scan_op.Scan):
        return False
    op = node.op
    # Index, among the inner inputs, of the first inner non-sequence: skip
    # the sequences, one inner input per mit-mot/mit-sot tap, one per sit-sot
    # state and one per shared output.
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in
                         op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)

    # Corresponds to the initial states, which should stay untouched.
    # We put those variables aside, and put them back at the end.
    out_stuff_inner = op_ins[op.n_seqs:st]

    non_seqs = op_ins[st:]
    # Index, among the outer inputs, of the first outer non-sequence; the
    # trailing +1 accounts for n_steps at position 0.
    st = (op.n_seqs +
          op.n_mit_mot +
          op.n_mit_sot +
          op.n_sit_sot +
          op.n_nit_sot +
          op.n_shared_outs + 1)
    outer_non_seqs = node.inputs[st:]
    out_stuff_outer = node.inputs[1 + op.n_seqs:st]

    # To replace constants in the outer graph by clones in the inner graph
    givens = {}
    # All the inputs of the inner graph of the new scan
    nw_inner = []
    # Same for the outer graph, initialized w/ number of steps
    nw_outer = [node.inputs[0]]

    all_ins = gof.graph.inputs(op_outs)
    for idx in xrange(op.n_seqs):
        inner_seq = op_ins[idx]
        outer_seq = node.inputs[idx + 1]
        if (isinstance(outer_seq, tensor.TensorConstant) and
                outer_seq.tag.unique_value is not None):
            try:
                # This works if input is a constant that has all entries
                # equal: every step sees the same value, so the first one
                # can stand in for the whole sequence.
                givens[inner_seq] = outer_seq.clone()[0]
                continue
            except TypeError:
                # Folding failed.  Fall through and treat it like any other
                # sequence; dropping it here would leave the inner graph
                # referring to an input the new scan no longer has.
                pass
        if inner_seq in all_ins:
            nw_inner.append(inner_seq)
            nw_outer.append(outer_seq)

    nw_n_seqs = len(nw_inner)
    # Add outputs stuff
    nw_inner += out_stuff_inner
    nw_outer += out_stuff_outer
    # Look through non sequences
    for nw_in, nw_out in zip(non_seqs, outer_non_seqs):
        if isinstance(nw_out, tensor.Constant):
            givens[nw_in] = nw_out.clone()
        elif nw_in in all_ins:
            nw_inner.append(nw_in)
            nw_outer.append(nw_out)

    # Nothing was folded in or removed: leave the node alone.
    if len(nw_inner) == len(op_ins):
        return False

    op_outs = scan_utils.clone(op_outs, replace=givens)
    nw_info = copy.deepcopy(op.info)
    nw_info['n_seqs'] = nw_n_seqs
    nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
    return nwScan.make_node(*nw_outer).outputs
Exemplo n.º 5
0
    def process_node(self, fgraph, node):
        """Reduce the memory and/or number of steps used by a Scan node.

        Every client of the node's non-shared outputs (mit-mot, mit-sot,
        sit-sot, nit-sot) is inspected.  When the clients are ``Subtensor``
        ops, the indices they read are used to work out:

        * ``global_nsteps``: how many iterations are actually needed.  It
          stays usable only when the node has no shared outputs and no
          client forces the full length.
        * ``store_steps``: for each output, how many steps must be kept in
          memory (0 = everything, negative = output unused).

        If anything can be reduced, a new Scan is built with a smaller
        ``n_steps`` and/or smaller pre-allocated buffers.  Outputs that are
        no longer needed are dropped through ``scan_utils.compress_outs``.
        The old clients' slices are rewritten to index into the new, shorter
        outputs, and ``fgraph.replace_all_validate_remove`` swaps the node
        out.  The new op's info is tagged with ``'_scan_savemem_visited'`` so
        the optimization is not reapplied.

        Returns None; ``fgraph`` is modified in place.
        """

        # helpful functions
        # min/max of two optional symbolic values (None means "no bound")
        def select_min(x, y):
            if x is None:
                return y
            if y is None:
                return x
            return tensor.minimum(x, y)

        def select_max(x, y):
            if x is None:
                return y
            if y is None:
                return x
            return tensor.maximum(x, y)

        # Turn python ints back into tensor variables, keeping None as None
        def sanitize(x):
            if x is None:
                return None
            else:
                return tensor.as_tensor_variable(x)

        # NOTE(review): the check is on the ``fgraph`` argument but the
        # lookup goes through ``node.fgraph``; presumably the same object --
        # confirm.
        if hasattr(fgraph, 'shape_feature'):
            shape_of = node.fgraph.shape_feature.shape_of
        else:
            # Each access to shape_of is in a try..except block in order to
            # use a default version when the variable is not in the shape_of
            # dictionary.
            shape_of = {}
        # 1. Initialization of variables
        # Note 1) We do not actually care about outputs representing shared
        # variables (those have no intermediate values) so it is safer to
        # ignore them and not change them in any way. To simplify the
        # optimizations I construct the variable ``c_outs`` ( that counts
        # outputs up to those we care) and the list ``init_l`` which for any
        # output we care says the length of its initial state. Note that
        # defining ``init_l`` for mit_mot sequences is a bit trickier but
        # it is safe to set it to 0
        op = node.op
        c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot

        # mit-mot: 0; mit-sot/sit-sot: magnitude of the most negative tap;
        # nit-sot: 0 (no initial state)
        init_l = [0 for x in xrange(op.n_mit_mot)]
        init_l += [abs(numpy.min(v)) for v in op.tap_array[op.n_mit_mot:]]
        init_l += [0 for x in xrange(op.n_nit_sot)]
        # 2. Check the clients of each output and see for how many steps
        # does scan need to run

        # This comparison checks if there is any uncounted output, which
        # can only be an output corresponding to a shared variable

        # 2.1 Initialize
        # global_nsteps is a dictionary having two fields ( 'real' deals
        # with int values, 'sym' with symbolic ones) or None
        # given that a scan op has k outputs o_1, .. o_k and each
        # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, ..,
        # global_nsteps is None if any of the clients is different
        # from a subtensor or its real and sym field equal to
        # max(c_i_j.idx_list[0].stop), meaning store up to which maximal
        # index(step) for any output scan actually needs to compute
        # In other words n_steps should be equal to this maximal !
        # Note: if we have a shared variable that gets updated at every step
        # of the loop, reducing the number of steps will affect the
        # value of the shared variable after the loop so we need not to
        # change the number of steps in that case. To do this we set
        # global_nsteps to None which is seen as a flag that nothing needs
        # to be done
        assert len(node.outputs) >= c_outs
        if len(node.outputs) == c_outs:
            global_nsteps = {'real': -1, 'sym': []}
        else:
            global_nsteps = None

        # Keeps track of the original slices that each client represent
        slices = [None for o in node.outputs]

        # A list for each output indicating how many intermediate values
        # should be stored. If negative it means none of the intermediate
        # values (i.e. the output can be removed since it is not used
        # afterwards in the computations), if 0 it means that all
        # intermediate values are required, otherwise is up to that number
        # of intermediate values
        # Note that for mit_mot outputs and shared outputs we can not change
        # the number of intermediate steps stored without affecting the
        # result of the op
        store_steps = [0 for o in xrange(op.n_mit_mot)]
        store_steps += [-1 for o in node.outputs[op.n_mit_mot:c_outs]]
        # Flag that says if an input has changed and we need to do something
        # or not
        flag_store = False

        # 2.2 Loop over the clients
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            slices[i] = []
            for cl, _ in out.clients:

                # 2.1 outputs of the function
                #=> output needs all its intermediate values
                if type(cl) == str:
                    # if the node is actually an output, then
                    # we need to store the entire thing
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.2 non-subtensor nodes
                #=> output needs all its intermediate values
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    global_nsteps = None
                    slices[i] = None
                    break
                # 2.3 subtensor nodes
                #=> output might need to store just a subset of its values
                else:
                    # 2.3.1 extract idx list of subtensor
                    this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                     cl.op.idx_list)
                    if this_slice is None:
                        # if unable to extract idx_list
                        #=> outputs needs all its intermediate values
                        global_nsteps = None
                        slices[i] = None
                        break

                    # 2.3.2 extract the begin/end of the first dimension
                    if i >= op.n_mit_mot:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            # n_steps plus the length of the initial state
                            length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(
                                                    this_slice[0], length)
                    slices[i] += [(cf_slice, this_slice)]

                    if (isinstance(this_slice[0], slice) and
                        this_slice[0].stop is None):
                        global_nsteps = None
                    if isinstance(cf_slice[0], slice):
                        stop = tensor.basic.extract_constant(cf_slice[0].stop)
                    else:
                        stop = tensor.basic.extract_constant(cf_slice[0]) + 1
                    # A stop of maxsize or of the full length means "read up
                    # to the end", which gives no bound on n_steps.
                    if stop == maxsize or stop == length:
                        stop = None
                    else:
                        # there is a **gotcha** here ! Namely, scan returns an
                        # array that contains the initial state of the output
                        # as well. Which means that if have a initial state of
                        # length 3, and you look for 5 steps you get an output
                        # y of length 8. If you only use y[:5], this does not
                        # mean that you only need to loop for 5 steps but
                        # actually only for 2 steps ( the first 3 are the
                        # initial state)
                        stop = stop - init_l[i]

                    # 2.3.3 we might get away with less number of steps
                    if stop is not None and global_nsteps is not None:
                        # yes if it is a tensor
                        if isinstance(stop, tensor.Variable):
                            global_nsteps['sym'] += [stop]
                        # not if it is maxsize
                        elif (type(stop) in (int, long) and
                              stop == maxsize):
                            global_nsteps = None
                        # yes if it is a int k, 0 < k < maxsize
                        elif (type(stop) in (int, long) and
                              global_nsteps['real'] < stop):
                            global_nsteps['real'] = stop
                        # positive int not above the current bound: nothing
                        # to update
                        elif (type(stop) in (int, long) and stop > 0):
                            pass
                        # not otherwise
                        else:
                            global_nsteps = None

        # 2.3. Analyze global_nsteps to figure out for how many steps scan
        # needs to iterate
        if global_nsteps is not None:
            nw_steps = node.inputs[0]

            # there are some symbolic tensors that limit the number of
            # steps
            if len(global_nsteps['sym']) == 0:
                sym_steps = None
            else:
                sym_steps = global_nsteps['sym'][0]
                for c in global_nsteps['sym'][1:]:
                    sym_steps = tensor.maximum(sym_steps, c)

            if global_nsteps['real'] >= 0:
                real_steps = global_nsteps['real']
            else:
                real_steps = None
            # Largest step any client needs, never more than the original
            # n_steps
            nw_steps = select_min(select_max(sym_steps, real_steps),
                                  node.inputs[0])
        else:
            nw_steps = node.inputs[0]
            global_nsteps = None

        # 2.4 Loop over the clients again now looking just to see how many
        # intermediate steps to store
        for i, out in enumerate(node.outputs[:c_outs]):
            # look at all its clients
            for cl, _ in out.clients:
                if type(cl) == str:
                    store_steps[i] = 0
                    break
                elif not isinstance(cl.op, tensor.basic.Subtensor):
                    store_steps[i] = 0
                    break
                else:
                    this_slice = tensor.basic.get_idx_list(cl.inputs,
                                                         cl.op.idx_list)
                    if this_slice is None:
                        store_steps[i] = 0
                        break

                    if (isinstance(this_slice[0], slice) and
                        this_slice[0].start is None):
                        store_steps[i] = 0
                        break

                    # NOTE(review): the first client loop uses
                    # ``i >= op.n_mit_mot`` here; with ``>`` the first
                    # non-mit-mot output takes the shape-based branch
                    # instead.  Possibly an off-by-one -- confirm.
                    if i > op.n_mit_mot:
                        length = node.inputs[0] + init_l[i]
                    else:
                        try:
                            length = shape_of[out][0]
                        except KeyError:
                            length = out.shape[0]
                    cf_slice = tensor.basic.get_canonical_form_slice(
                                                    this_slice[0], length)

                    if isinstance(cf_slice[0], slice):
                        start = tensor.basic.extract_constant(
                            cf_slice[0].start)
                    else:
                        start = tensor.basic.extract_constant(cf_slice[0])
                    # A client reading from the very beginning, or an earlier
                    # client already requiring everything, means all steps
                    # must be kept.
                    if start == 0 or store_steps[i] == 0:
                        store_steps[i] = 0
                    else:
                        # Keep enough trailing steps to cover this client, and
                        # at least the initial state; combine with what other
                        # clients already required.
                        pval = select_max(nw_steps - start + init_l[i],
                                          init_l[i])
                        if store_steps[i] != -1:
                            pval = select_max(pval, store_steps[i])

                        store_steps[i] = pval
                        flag_store = True

        # Outputs nobody reads (store_steps still the python int -1)
        orphane_outs = [i for i, x in enumerate(store_steps)
                        if (type(x) is int) and (x < 0)]
        flag_store = flag_store or (len(orphane_outs) > 0)
        # 3. is there anything to change ?
        if (flag_store or global_nsteps is not None):
            # 3.1 initialize inputs for the new scan
            old_outputs = []
            nw_inputs = list(node.inputs)
            nw_inputs[0] = nw_steps

            # 3.2 check orphane outputs to see if we can eliminate any
            required, not_required = \
                    scan_utils.scan_can_remove_outs(node.op,
                                                    orphane_outs)
            # 3.3. compose replace pairs for those nodes that need not
            # to store everything in memory ( or are orphaned and required
            # by the inner function .. )
            replaced_outs = []
            # outer input index of the first mit-sot initial state
            offset = 1 + op.n_seqs + op.n_mit_mot
            for idx, _val in enumerate(store_steps[op.n_mit_mot:]):
                i = idx + op.n_mit_mot
                if not(type(_val) is int and _val <= 0 and i not in required):

                    # An orphaned output still needed by the inner function
                    # only has to keep a single step.
                    if idx + op.n_mit_mot in required:
                        val = 1
                    else:
                        val = _val
                    # If the memory for this output has been pre-allocated
                    # before going into the scan op (by an alloc node)
                    if idx < op.n_mit_sot + op.n_sit_sot:
                        # In case the input is still an alloc node, we
                        # actually have two options:
                        #   a) the input is a set_subtensor, in that case we
                        #      can replace the initial tensor by a slice,
                        #   b) it is not, and we simply take a slice of it.

                        #TODO: commit change below with Razvan
                        if (nw_inputs[offset + idx].owner and
                            isinstance(nw_inputs[offset + idx].owner.op,
                                       tensor.IncSubtensor) and
                            isinstance(nw_inputs[offset+idx].owner.op.idx_list[0], slice)):

                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            val = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                    tensor.maximum(val - initl, 0))
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = scan_utils.expand(_nw_input, tmp)
                        else:
                            tmp = tensor.as_tensor_variable(val)
                            initl = tensor.as_tensor_variable(init_l[i])
                            tmp = tensor.maximum(tmp, initl)
                            tmp = pre_greedy_local_optimizer(list_opt_slice,
                                                             tmp)
                            tmp = pre_constant_merge([tmp])[0]
                            nw_input = nw_inputs[offset + idx][:tmp]

                        nw_inputs[offset + idx] = nw_input
                        replaced_outs.append(op.n_mit_mot + idx)
                        odx = op.n_mit_mot + idx
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
                    # If there is no memory pre-allocated for this output
                    elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:

                        # nit-sot outer inputs come after the shared inputs
                        pos = (op.n_mit_mot + idx + op.n_seqs +
                               1 + op.n_shared_outs)
                        if nw_inputs[pos] == node.inputs[0]:
                            nw_inputs[pos] = val
                        odx = op.n_mit_mot + idx
                        replaced_outs.append(odx)
                        old_outputs += [(odx, [x[0].outputs[0] for x in
                                        node.outputs[odx].clients])]
            # 3.4. Recompute inputs for everything else based on the new
            # number of steps
            if global_nsteps is not None:
                for idx, val in enumerate(store_steps[op.n_mit_mot:]):
                    if val == 0:
                        if idx < op.n_mit_sot + op.n_sit_sot:
                            _nw_input = nw_inputs[offset + idx].owner.inputs[1]
                            odx = op.n_mit_mot + idx
                            nw_input = scan_utils.expand(_nw_input, nw_steps)
                            nw_inputs[offset + idx] = nw_input
                        elif idx < (op.n_mit_sot + op.n_sit_sot +
                                    op.n_nit_sot):
                            in_idx = offset + idx + op.n_shared_outs
                            if nw_inputs[in_idx] == node.inputs[0]:
                                nw_inputs[in_idx] = nw_steps
                            odx = op.n_mit_mot + idx

            # 3.5 Remove unwanted orphane outputs
            # compress_map: old output index -> new output index
            (inps, outs, info, node_ins, compress_map) = \
                    scan_utils.compress_outs(op, not_required, nw_inputs)
            inv_compress_map = {}
            for k, v in compress_map.items():
                inv_compress_map[v] = k

            node_ins = [pre_greedy_local_optimizer(list_opt_slice, x) for x in
                        node_ins]
            node_ins = pre_constant_merge(node_ins)
            # 3.6 Compose the new scan
            # I need to make sure I'm not reapplying the same optimization
            # twice since bad things usually happen if I do that
            info['_scan_savemem_visited'] = True
            new_outs = scan_op.Scan(inps,
                                    outs,
                                    info).make_node(*node_ins).outputs

            old_new = []
            # 3.7 Get replace pairs for those outputs that do not change
            # the number of intermediate steps stored
            for idx, sl in enumerate(slices):
                if global_nsteps and sl is not None and store_steps[idx] == 0:
                    for hdx, cl in enumerate(node.outputs[idx].clients):
                        cnf_slice, old_slices = sl[hdx]
                        # Sanitize the nw_slice by converting ints back into
                        # constants :) I only need to do this for the first
                        # slice since that is the only slice

                        if isinstance(cnf_slice[0], slice):
                            fslice = slice(
                                sanitize(cnf_slice[0].start),
                                sanitize(cnf_slice[0].stop),
                                sanitize(cnf_slice[0].step))
                        else:
                            fslice = sanitize(cnf_slice[0])

                        nw_slice = (fslice,) + tuple(old_slices[1:])
                        # NOTE(review): uses inv_compress_map here but
                        # compress_map in 3.8/3.9; confirm the intended
                        # direction of the mapping.
                        nw_pos = inv_compress_map[idx]

                        subtens = tensor.basic.Subtensor(nw_slice)
                        # slice inputs
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                    tensor.Variable))
                        new_o = subtens.make_node(new_outs[nw_pos],
                                                  *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        replaced_outs.append(idx)
                        old_new += [(cl[0].outputs[0], new_o)]
            # 3.8. Get replace pairs for those outputs that change
            # the number of stored intermediate steps
            for pos, old_outs in old_outputs:
                if len(old_outs) > 0:
                    nw_pos = compress_map[pos]
                    for k, old in enumerate(old_outs):
                        # Get the correct slice
                        cnf_slice, old_slices = slices[pos][k]
                        # Shift indices: the new output only holds the last
                        # store_steps[pos] entries of the original one.
                        if type(cnf_slice[0]) is slice:
                            start = (cnf_slice[0].start - nw_steps -
                                     init_l[pos] + store_steps[pos])
                            if (cnf_slice[0].stop is not None and
                                cnf_slice[0].stop != maxsize):
                                stop = (cnf_slice[0].stop - nw_steps -
                                        init_l[pos] + store_steps[pos])
                            else:
                                stop = None
                            nw_slice = ((slice(sanitize(start),
                                               sanitize(stop),
                                               sanitize(cnf_slice[0].step)),)
                                        + tuple(old_slices[1:]))

                        else:
                            position = (cnf_slice[0] - nw_steps -
                                         init_l[pos] + store_steps[pos])

                            nw_slice = (sanitize(position),) + \
                                    tuple(old_slices[1:])

                        subtens = tensor.basic.Subtensor(nw_slice)
                        sl_ins = tensor.basic.Subtensor.collapse(
                            nw_slice,
                            lambda entry: isinstance(entry,
                                                     tensor.Variable))
                        new_o = subtens.make_node(new_outs[nw_pos],
                                                  *sl_ins).outputs[0]
                        if new_o.ndim > 0:
                            new_o = new_o[::cnf_slice[1]]
                        old_new += [(old, new_o)]

            # 3.9. Get replace pairs for all other nodes
            if flag_store or global_nsteps is not None:
                for idx, o in enumerate(node.outputs):
                    if not (idx in replaced_outs) and not idx in not_required:
                        nw_pos = compress_map[idx]
                        old_new += [(o, new_outs[nw_pos])]

                fgraph.replace_all_validate_remove(old_new,
                                                   remove=[node],
                                                   reason='scan_save_mem')
Exemplo n.º 6
0
    def process_node(self, fgraph, node):
        """Push computations that only depend on non-sequences out of a Scan.

        The inner graph of the Scan `node` is searched to a fixed point for
        apply nodes whose inputs are all non-sequences, constants, or outputs
        of nodes already selected. ViewOp and DeepCopyOp nodes are never
        selected. Each selected node is rebuilt on the outer graph, with its
        inputs mapped to the matching outer inputs. Each of its inner outputs
        gets a placeholder variable.

        Then one of two rewrites is applied to `fgraph`:

        * If some remaining inner node still consumes a pushed-out value, a new
          Scan is built. In it, those values become extra (outer) inputs, and
          it replaces `node`.
        * If no remaining inner node consumes anything, each Scan output whose
          inner output was pushed out is replaced by the outer computation
          broadcast with `tensor.alloc` along a new leading axis of length
          `node.inputs[0]` (n_steps). `node` is removed from the graph only
          when every one of its outputs received a replacement.

        :param fgraph: FunctionGraph containing `node`; modified in place.
        :param node: apply node whose op is a Scan.
        :returns: True if `fgraph` was modified, False otherwise.
        :raises Exception: if a selected node has an input that cannot be
            mapped to the outer graph, or if the fixed-point search does not
            settle within its iteration budget.
        """
        # this flag tells if there was any change during the last iterations
        changed = True
        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
            node.op.inputs, node.op.outputs)

        local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs)
        # Every productive pass selects at least one more node, so this cap
        # only guards against a search that never settles.
        max_iterations = 2 * len(local_fgraph.toposort()) + 3
        counts = 0
        to_remove = []
        to_replace = []
        replace_with_in = []
        replace_with_out = []
        op = node.op
        # Construct the list of non_sequences to simplify a few things.
        # Inner non-sequences come after the sequences, one input per
        # mit_mot/mit_sot tap, the sit_sot inputs and the shared inputs.
        st = op.n_seqs
        st += int(numpy.sum([len(x) for x in
                             op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
        st += op.n_sit_sot
        st += op.n_shared_outs
        non_seqs = clean_inputs[st:]
        # Outer non-sequences come after n_steps (the "+ 1") and one outer
        # input per sequence, mit_mot, mit_sot, sit_sot, nit_sot and shared.
        st = (op.n_seqs +
              op.n_mit_mot +
              op.n_mit_sot +
              op.n_sit_sot +
              op.n_nit_sot +
              op.n_shared_outs + 1)
        outer_non_seqs = node.inputs[st:]
        assert len(non_seqs) == len(outer_non_seqs)
        while changed and counts < max_iterations:
            counts += 1
            changed = False

            for nd in local_fgraph.toposort():
                if (numpy.all([(x in non_seqs) or
                               (x.owner in to_remove) or
                               isinstance(x, tensor.Constant)
                               for x in nd.inputs]) and
                        # we can do this because the assumption is that a
                        # viewOp or deepCopyOp will be just at the end of the
                        # function and not somewhere in the middle ..
                        not isinstance(nd.op, theano.compile.ViewOp) and
                        not isinstance(nd.op, theano.compile.DeepCopyOp) and
                        # and we didn't already looked at this node
                        not nd in to_remove):

                    # We have a candidate node to removable
                    # Step 1. Reconstruct it on outside
                    to_remove.append(nd)
                    outside_ins = []
                    for x in nd.inputs:
                        if x in non_seqs:
                            outside_ins += [outer_non_seqs[non_seqs.index(x)]]
                        elif x in to_replace:
                            outside_ins += [
                                replace_with_out[to_replace.index(x)]]
                        elif isinstance(x, theano.Constant):
                            outside_ins += [x.clone()]
                        else:
                            raise Exception(
                                ('Error in the `scan_pushout_non_seq_'
                                 'operations`. The optimization tries '
                                 'to move some computation fron scan '
                                 'which is not allowed to move. Report '
                                 'this on theano-users list'), x)
                    # Coerce each outer input to the type of the inner input
                    # it stands for before rebuilding the node outside.
                    outside_ins = [x.type.filter_variable(y) for x, y in
                                   zip(nd.inputs, outside_ins)]
                    nw_outer_node = nd.op.make_node(*outside_ins)
                    # Step 2. Create variables for replacements
                    for idx, y in enumerate(nd.outputs):

                        y_place_holder = scan_utils.safe_new(y, '_replace')
                        to_replace += [y]
                        replace_with_in += [y_place_holder]
                        assert type(y) == type(nw_outer_node.outputs[idx])
                        replace_with_out += [nw_outer_node.outputs[idx]]
                    changed = True
        if changed:
            # The loop stopped because of the iteration cap, not because it
            # reached a fixed point.
            raise Exception('Error in the `scan_pushout_non_seq_operations`.'
                            ' The optimization exhausted the maximal number '
                            'of iterations allowed!')
        # We need to check all candidate replacements and choose those that
        # make sense for us

        # Step 1. which elements of `to_replace` are used by remaining
        # components of the inner function
        clean_to_replace = []
        clean_replace_with_in = []
        clean_replace_with_out = []
        existent_nodes = [nd for nd in local_fgraph.toposort()
                          if nd not in to_remove]
        to_keep = []
        for nd in existent_nodes:
            to_keep += nd.inputs
        for idx, out in enumerate(to_replace):
            if out in to_keep and out.owner not in existent_nodes:
                clean_to_replace += [out]
                clean_replace_with_in += [replace_with_in[idx]]
                clean_replace_with_out += [replace_with_out[idx]]

        if len(clean_to_replace) > 0:
            # We can finally put an end to all this madness
            givens = {}
            nw_outer = []
            nw_inner = []
            for to_repl, repl_in, repl_out in zip(clean_to_replace,
                                                  clean_replace_with_in,
                                                  clean_replace_with_out):
                if isinstance(repl_out, theano.Constant):
                    # A constant is inlined directly instead of becoming a
                    # new input of the Scan.
                    repl_in = repl_out.clone()
                else:
                    nw_inner += [repl_in]
                    nw_outer += [repl_out]
                givens[to_repl] = repl_in

            _op_outs = scan_utils.clone(clean_outputs,
                                        replace=givens)
            _op_ins = clean_inputs + nw_inner
            op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
            # Reconstruct node. The new inner inputs were appended after the
            # existing ones, so their outer values go after node.inputs.
            nwScan = scan_op.Scan(op_ins, op_outs, op.info)
            nw_node = nwScan.make_node(*(node.inputs + nw_outer))
            fgraph.replace_all_validate_remove(
                zip(node.outputs, nw_node.outputs),
                remove=[node],
                reason='scan_push_computation_out')
            return True
        elif not to_keep:
            # Nothing in the inner graph should be kept. Pairs are collected
            # in a list (not a dict) so the replacement order is
            # deterministic.
            replace_with = []
            for idx, out in enumerate(to_replace):
                if out in local_fgraph.outputs:
                    x = node.outputs[local_fgraph.outputs.index(out)]
                    y = replace_with_out[idx]
                    shape = [y.shape[i] for i in xrange(y.ndim)]
                    # We need to add one extra (leading, n_steps-long)
                    # dimension to the outputs.
                    replace_with.append(
                        (x, tensor.alloc(y, node.inputs[0], *shape)))

            if not replace_with:
                return False
            if len(replace_with) == len(node.outputs):
                # Every output of the Scan has a replacement, so the Scan
                # node disappears from the graph and can be removed.
                fgraph.replace_all_validate_remove(
                    replace_with,
                    remove=[node],
                    reason='scan_push_computation_out')
            else:
                # Some outputs (e.g. ones not produced by an inner apply
                # node) have no replacement. The node still feeds the graph,
                # so requesting its removal would fail validation.
                fgraph.replace_all_validate(
                    replace_with,
                    reason='scan_push_computation_out')
            return True

        else:
            return False
Exemplo n.º 7
0
def scan_merge_inouts(node):
    """Merge duplicated inputs and equivalent outputs of a Scan node.

    Duplicate outer sequences and outer non-sequences are collapsed into one
    input each. When that happens, a new Scan op is built with the inner
    graph rewritten to use the surviving inner inputs. Duplicate shared,
    sit_sot, mit_mot and mit_sot outer inputs are not removed here. They only
    make their inner inputs count as equivalent when comparing inner
    computations.

    Outputs of the same kind (nit_sot, sit_sot, mit_sot, mit_mot) are then
    mapped to a single outer output when both of these hold:

    * their inner computations are equal under those equivalences;
    * they are fed by the very same outer input.

    For mit_mot, the output slices must also match.

    :param node: an apply node.
    :returns: False if `node` is not a Scan; otherwise the outer outputs
        (`scan_args.outer_outputs`) to use in place of `node.outputs`.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False

    a = scan_args(node.inputs, node.outputs,
                  node.op.inputs, node.op.outputs, node.op.info)

    inp_equiv = {}

    if has_duplicates(a.outer_in_seqs):
        new_outer_seqs = []
        new_inner_seqs = []
        for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs):
            if out_seq in new_outer_seqs:
                i = new_outer_seqs.index(out_seq)
                inp_equiv[in_seq] = new_inner_seqs[i]
            else:
                new_outer_seqs.append(out_seq)
                new_inner_seqs.append(in_seq)
        a.outer_in_seqs = new_outer_seqs
        a.inner_in_seqs = new_inner_seqs

    if has_duplicates(a.outer_in_non_seqs):
        new_outer_nseqs = []
        new_inner_nseqs = []
        for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs):
            if out_nseq in new_outer_nseqs:
                i = new_outer_nseqs.index(out_nseq)
                inp_equiv[in_nseq] = new_inner_nseqs[i]
            else:
                new_outer_nseqs.append(out_nseq)
                new_inner_nseqs.append(in_nseq)
        a.outer_in_non_seqs = new_outer_nseqs
        a.inner_in_non_seqs = new_inner_nseqs

    if len(inp_equiv) > 0:
        # do the replacement now. The rest will be left to ScanSaveMem
        inner_inputs = a.inner_inputs
        outer_inputs = a.outer_inputs
        info = a.info
        if info['as_while']:
            a_inner_outs = a.inner_outputs + a.cond
        else:
            a_inner_outs = a.inner_outputs
        inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)

        op = scan_op.Scan(inner_inputs, inner_outputs, info)
        outputs = op(*outer_inputs)

        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]

        na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
    else:
        na = a

    # start again
    # `left`/`right` hold pairs of inner inputs that are fed by the same outer
    # input; equal_computations treats them as interchangeable.
    left = []
    right = []

    if has_duplicates(na.outer_in_shared):
        _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared)
        left += _left
        right += _right
    if has_duplicates(na.outer_in_sit_sot):
        _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot)
        left += _left
        right += _right
    if has_duplicates(na.outer_in_mit_mot):
        seen = {}
        for omm, imm, _sl in zip(na.outer_in_mit_mot,
                                 na.inner_in_mit_mot, na.mit_mot_in_slices):
            sl = tuple(_sl)
            if (omm, sl) in seen:
                simm = seen[(omm, sl)]
                left += imm
                right += simm
            else:
                seen[(omm, sl)] = imm

    if has_duplicates(na.outer_in_mit_sot):
        seen = {}
        for oms, ims, _sl in zip(na.outer_in_mit_sot,
                                 na.inner_in_mit_sot,
                                 na.mit_sot_in_slices):
            sl = tuple(_sl)
            if (oms, sl) in seen:
                sims = seen[(oms, sl)]
                left += ims
                right += sims
            else:
                seen[(oms, sl)] = ims

    def map_out(outer_i, inner_o, outer_o, seen):
        # Return the outer output to use for this (outer input, inner output)
        # pair. Reuse an earlier outer output only if its inner computation
        # is equal AND it comes from the same outer input. Outer inputs can
        # differ in initial values or size, in which case the outer outputs
        # differ even though the inner graphs are equal. The cheap identity
        # test runs first.
        for s_outer_i, s_inner_o, s_outer_o in seen:
            if (outer_i is s_outer_i and
                    equal_computations([inner_o], [s_inner_o], left, right)):
                return s_outer_o
        seen.append((outer_i, inner_o, outer_o))
        return outer_o

    seen = []
    na.outer_out_nit_sot = [map_out(outer_i, inner_o, outer_o, seen)
                            for outer_i, inner_o, outer_o in
                            zip(na.outer_in_nit_sot,
                                na.inner_out_nit_sot,
                                na.outer_out_nit_sot)]

    seen = []
    na.outer_out_sit_sot = [map_out(outer_i, inner_o, outer_o, seen)
                            for outer_i, inner_o, outer_o in
                            zip(na.outer_in_sit_sot,
                                na.inner_out_sit_sot,
                                na.outer_out_sit_sot)]

    seen = []
    na.outer_out_mit_sot = [map_out(outer_i, inner_o, outer_o, seen)
                            for outer_i, inner_o, outer_o in
                            zip(na.outer_in_mit_sot,
                                na.inner_out_mit_sot,
                                na.outer_out_mit_sot)]

    seen = []
    new_outer_out_mit_mot = []
    for outer_imm, imm, omm, osl in zip(na.outer_in_mit_mot,
                                        na.inner_out_mit_mot,
                                        na.outer_out_mit_mot,
                                        na.mit_mot_out_slices):
        for s_outer_imm, simm, somm, sosl in seen:
            # Same outer input, same output slices, equal inner graphs.
            if (outer_imm is s_outer_imm and osl == sosl and
                    equal_computations(imm, simm, left, right)):
                new_outer_out_mit_mot.append(somm)
                break
        else:
            seen.append((outer_imm, imm, omm, osl))
            new_outer_out_mit_mot.append(omm)
    na.outer_out_mit_mot = new_outer_out_mit_mot

    return na.outer_outputs
Exemplo n.º 8
0
    info['n_mit_mot'] = n_mit_mot
    info['n_mit_mot_outs'] = n_mit_mot_outs
    info['mit_mot_out_slices'] = mit_mot_out_slices
    info['n_mit_sot'] = n_mit_sot
    info['n_sit_sot'] = n_sit_sot
    info['n_shared_outs'] = n_shared_outs
    info['n_nit_sot'] = n_nit_sot
    info['truncate_gradient'] = truncate_gradient
    info['name'] = name
    info['mode'] = mode
    info['destroy_map'] = {}
    info['gpu'] = False
    info['as_while'] = as_while
    info['profile'] = profile

    local_op = scan_op.Scan(inner_inputs, new_outs, info)

    ##
    ### Step 8. Compute the outputs using the scan op
    ##
    _scan_inputs = (scan_seqs + mit_mot_scan_inputs + mit_sot_scan_inputs +
                    sit_sot_scan_inputs + shared_scan_inputs +
                    [actual_n_steps for x in xrange(n_nit_sot)] +
                    other_shared_scan_args + other_scan_args)

    scan_inputs = []
    for arg in [actual_n_steps] + _scan_inputs:
        try:
            arg = tensor.as_tensor_variable(arg)
        except TypeError:
            # This happens for Random States for e.g. but it is a good way