예제 #1
0
def scan_make_inplace(node):
    """Rebuild a non-inplace, non-GPU Scan node with ``info['inplace']`` set.

    Returns the outputs of the newly built Scan node when ``node.op`` is a
    ``scan_op.Scan`` whose ``info['inplace']`` and ``info['gpu']`` are both
    false. Otherwise returns False and leaves the node alone.

    The outer inputs are regrouped in their original order:

    * n_steps plus the sequences;
    * the state inputs (mit-mot, mit-sot, sit-sot);
    * the shared, nit-sot and non-sequence inputs.

    Any state variable that appears more than once among the state inputs
    is replaced, from its second occurrence on, by ``deep_copy_op`` of it.
    As a result, no two state slots share the same variable.
    """
    scan = node.op
    if not isinstance(scan, scan_op.Scan):
        return False
    if scan.info['inplace'] or scan.info['gpu']:
        return False

    # Shallow copy: only the top-level 'inplace' key is changed.
    new_info = scan.info.copy()
    new_info['inplace'] = True

    outer = node.inputs
    # n_steps followed by the sequence inputs, passed through unchanged.
    head = outer[:1 + scan.n_seqs]
    # Outer inputs holding the recurrent states.
    states = scan.outer_mitmot(outer)
    states.extend(scan.outer_mitsot(outer))
    states.extend(scan.outer_sitsot(outer))
    # Everything after the states: shared, nit-sot and non-sequence inputs.
    tail = scan.outer_shared(outer)
    tail.extend(scan.outer_nitsot(outer))
    tail.extend(scan.outer_non_seqs(outer))

    # Every repeat of a state variable gets its own deep copy.
    for pos, var in enumerate(states):
        if var in states[:pos]:
            states[pos] = deep_copy_op(var)

    rebuilt = scan_op.Scan(scan.inputs, scan.outputs, new_info)
    return rebuilt.make_node(*(head + states + tail)).outputs
예제 #2
0
def scan_make_inplace(node):
    """Rebuild a non-inplace, non-GPU Scan node with ``info['inplace']`` set.

    Returns the outputs of the newly built Scan node when ``node.op`` is a
    ``scan_op.Scan`` whose ``info['inplace']`` and ``info['gpu']`` are both
    false. Otherwise returns False and leaves the node alone.

    The outer inputs are regrouped in their original order:

    * n_steps plus the sequences;
    * the state inputs (mit-mot, mit-sot, sit-sot);
    * the shared, nit-sot and non-sequence inputs.

    Any state variable that appears more than once among the state inputs
    is replaced, from its second occurrence on, by ``deep_copy_op`` of it.
    As a result, no two state slots share the same variable.
    """
    scan = node.op
    if not isinstance(scan, scan_op.Scan):
        return False
    if scan.info['inplace'] or scan.info['gpu']:
        return False

    # Shallow copy: only the top-level 'inplace' key is changed.
    new_info = scan.info.copy()
    new_info['inplace'] = True

    outer = node.inputs
    # n_steps followed by the sequence inputs, passed through unchanged.
    head = outer[:1 + scan.n_seqs]
    # Outer inputs holding the recurrent states.
    states = scan.outer_mitmot(outer)
    states.extend(scan.outer_mitsot(outer))
    states.extend(scan.outer_sitsot(outer))
    # Everything after the states: shared, nit-sot and non-sequence inputs.
    tail = scan.outer_shared(outer)
    tail.extend(scan.outer_nitsot(outer))
    tail.extend(scan.outer_non_seqs(outer))

    # Every repeat of a state variable gets its own deep copy.
    for pos, var in enumerate(states):
        if var in states[:pos]:
            states[pos] = deep_copy_op(var)

    rebuilt = scan_op.Scan(scan.inputs, scan.outputs, new_info)
    return rebuilt.make_node(*(head + states + tail)).outputs
예제 #3
0
    def apply(self, fgraph):
        """Try to make the recurrent-state outputs of Scan nodes inplace.

        Selects every node in ``fgraph`` whose op is a ``scan_op.Scan`` with
        ``info['gpu'] == self.gpu_flag``. For each such node, every
        mit-mot, mit-sot and sit-sot output position ``pos`` is tried in
        turn:

        * a new Scan op is built whose ``destroy_map`` additionally maps
          output ``pos`` to the outer input at index ``1 + n_seqs + pos``;
        * the old node is replaced through
          ``fgraph.replace_all_validate_remove``.

        If the replacement raises ``InconsistencyError``, that output is
        skipped and the previous node is kept. On success, the new node is
        the starting point for the remaining outputs, so the destroy maps
        accumulate.

        :param fgraph: function graph that is rewritten in place.
        """
        nodes = fgraph.toposort()
        scan_nodes = [x for x in nodes
                      if (isinstance(x.op, scan_op.Scan) and
                          x.op.info['gpu'] == self.gpu_flag)]
        for node in scan_nodes:
            op = node.op
            n_state_outs = (op.info['n_mit_mot'] +
                            op.info['n_mit_sot'] +
                            op.info['n_sit_sot'])
            for pos in xrange(n_state_outs):
                info = copy.deepcopy(op.info)
                if 'destroy_map' not in info:
                    info['destroy_map'] = {}
                # The outer inputs start with n_steps and the sequences, so
                # index 1 + n_seqs + pos is the state input ls[pos] below.
                info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
                # inputs corresponding to n_steps and the sequences
                ls_begin = node.inputs[:1 + op.n_seqs]
                ls = op.outer_mitmot(node.inputs)
                ls += op.outer_mitsot(node.inputs)
                ls += op.outer_sitsot(node.inputs)
                ls_end = op.outer_shared(node.inputs)
                ls_end += op.outer_nitsot(node.inputs)
                ls_end += op.outer_non_seqs(node.inputs)
                # Every repeat of a state variable gets its own deep copy,
                # so no two state slots share one variable (presumably so a
                # destroyed buffer is never read through another slot).
                for idx in xrange(len(ls)):
                    if ls[idx] in ls[:idx]:
                        ls[idx] = deep_copy_op(ls[idx])

                inputs = ls_begin + ls + ls_end
                new_op = scan_op.Scan(op.inputs,
                                      op.outputs,
                                      info,
                                      typeConstructor=self.typeConstructor)

                new_outs = new_op.make_node(*inputs).outputs
                try:
                    # Materialize the pairs so they form a real list on
                    # Python 3 as well, where zip() is a one-shot iterator.
                    fgraph.replace_all_validate_remove(
                        list(zip(node.outputs, new_outs)),
                        remove=[node],
                        reason=self.__class__.__name__)
                except InconsistencyError:
                    # Best effort: this output cannot be computed inplace
                    # in the current graph. Keep the previous node and try
                    # the next output.
                    continue
                # The new node is now part of fgraph; later outputs build
                # on it, so their destroy maps accumulate.
                op = new_op
                node = new_outs[0].owner
예제 #4
0
파일: scan_opt.py 프로젝트: srifai/Theano
    def apply(self, fgraph):
        """Try to make the recurrent-state outputs of Scan nodes inplace.

        Selects every node in ``fgraph`` whose op is a ``scan_op.Scan`` with
        ``info['gpu'] == self.gpu_flag``. For each such node, every
        mit-mot, mit-sot and sit-sot output position ``pos`` is tried in
        turn:

        * a new Scan op is built whose ``destroy_map`` additionally maps
          output ``pos`` to the outer input at index ``1 + n_seqs + pos``;
        * the old node is replaced through
          ``fgraph.replace_all_validate_remove``.

        If the replacement raises ``InconsistencyError``, that output is
        skipped and the previous node is kept. On success, the new node is
        the starting point for the remaining outputs, so the destroy maps
        accumulate.

        :param fgraph: function graph that is rewritten in place.
        """
        nodes = fgraph.toposort()
        scan_nodes = [x for x in nodes
                      if (isinstance(x.op, scan_op.Scan) and
                          x.op.info['gpu'] == self.gpu_flag)]
        for node in scan_nodes:
            op = node.op
            n_state_outs = (op.info['n_mit_mot'] +
                            op.info['n_mit_sot'] +
                            op.info['n_sit_sot'])
            for pos in xrange(n_state_outs):
                info = copy.deepcopy(op.info)
                if 'destroy_map' not in info:
                    info['destroy_map'] = {}
                # The outer inputs start with n_steps and the sequences, so
                # index 1 + n_seqs + pos is the state input ls[pos] below.
                info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
                # inputs corresponding to n_steps and the sequences
                ls_begin = node.inputs[:1 + op.n_seqs]
                ls = op.outer_mitmot(node.inputs)
                ls += op.outer_mitsot(node.inputs)
                ls += op.outer_sitsot(node.inputs)
                ls_end = op.outer_shared(node.inputs)
                ls_end += op.outer_nitsot(node.inputs)
                ls_end += op.outer_non_seqs(node.inputs)
                # Every repeat of a state variable gets its own deep copy,
                # so no two state slots share one variable (presumably so a
                # destroyed buffer is never read through another slot).
                for idx in xrange(len(ls)):
                    if ls[idx] in ls[:idx]:
                        ls[idx] = deep_copy_op(ls[idx])

                inputs = ls_begin + ls + ls_end
                new_op = scan_op.Scan(op.inputs,
                                      op.outputs,
                                      info,
                                      typeConstructor=self.typeConstructor)

                new_outs = new_op.make_node(*inputs).outputs
                try:
                    # Materialize the pairs so they form a real list on
                    # Python 3 as well, where zip() is a one-shot iterator.
                    fgraph.replace_all_validate_remove(
                        list(zip(node.outputs, new_outs)),
                        remove=[node],
                        reason=self.__class__.__name__)
                except InconsistencyError:
                    # Best effort: this output cannot be computed inplace
                    # in the current graph. Keep the previous node and try
                    # the next output.
                    continue
                # The new node is now part of fgraph; later outputs build
                # on it, so their destroy maps accumulate.
                op = new_op
                node = new_outs[0].owner