def scan_make_inplace(node):
    """Return the outputs of an in-place version of a Scan node.

    The rewrite only applies when ``node.op`` is a ``scan_op.Scan`` whose
    ``info`` has both ``'inplace'`` and ``'gpu'`` falsy.  In that case a
    copy of ``op.info`` with ``'inplace'`` set to ``True`` is used to build
    a new ``Scan`` op over the same inner inputs/outputs, and the outputs
    of the node it creates are returned.

    :param node: the apply node to inspect.
    :return: the list of outputs of the new node, or ``False`` when the
        node does not qualify.
    """
    scan = node.op
    if not isinstance(scan, scan_op.Scan):
        return False
    if scan.info['inplace'] or scan.info['gpu']:
        return False

    inplace_info = scan.info.copy()
    inplace_info['inplace'] = True

    outer_inputs = node.inputs
    # n_steps followed by the sequences
    head = outer_inputs[:1 + scan.n_seqs]

    # mit_mot, mit_sot and sit_sot outer inputs, in that order
    states = scan.outer_mitmot(outer_inputs)
    states += scan.outer_mitsot(outer_inputs)
    states += scan.outer_sitsot(outer_inputs)

    # shared, nit_sot and non-sequence outer inputs, in that order
    tail = scan.outer_shared(outer_inputs)
    tail += scan.outer_nitsot(outer_inputs)
    tail += scan.outer_non_seqs(outer_inputs)

    # A state input that already occurs earlier in the list is replaced by
    # a deep copy of itself (presumably so that no two in-place states
    # share one variable -- confirm against Scan's in-place contract).
    for position, variable in enumerate(states):
        if variable in states[:position]:
            states[position] = deep_copy_op(variable)

    inplace_op = scan_op.Scan(scan.inputs, scan.outputs, inplace_info)
    return inplace_op.make_node(*(head + states + tail)).outputs
def apply(self, fgraph):
    """Try to make the recurrent outputs of matching Scan nodes in-place.

    Every node of ``fgraph.toposort()`` whose op is a ``scan_op.Scan`` with
    ``info['gpu'] == self.gpu_flag`` is visited.  For each of its
    mit_mot / mit_sot / sit_sot outputs (by position ``pos``) a new Scan op
    is built whose ``info['destroy_map']`` additionally maps ``pos`` to the
    outer input ``pos + 1 + n_seqs``, and the graph is asked to swap the
    current node for the new one.  When the swap is accepted, the new
    op/node become the basis for the following positions, so accepted
    ``destroy_map`` entries accumulate; when it is rejected with
    ``InconsistencyError`` that position is simply left not in-place.

    :param fgraph: the function graph to rewrite in place.
    :return: ``None``.
    """
    nodes = fgraph.toposort()
    scan_nodes = [x for x in nodes
                  if (isinstance(x.op, scan_op.Scan) and
                      x.op.info['gpu'] == self.gpu_flag)]
    for node in scan_nodes:
        op = node.op
        n_outs = (op.info['n_mit_mot'] +
                  op.info['n_mit_sot'] +
                  op.info['n_sit_sot'])
        for pos in range(n_outs):
            # Work on a private copy so a rejected attempt leaves the
            # current op's info untouched.
            info = copy.deepcopy(op.info)
            if 'destroy_map' not in info:
                info['destroy_map'] = {}
            info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]

            # These are recomputed on every pass because ``node`` and
            # ``op`` are rebound after each accepted replacement.
            # inputs corresponding to sequences and n_steps
            ls_begin = node.inputs[:1 + op.n_seqs]
            ls = op.outer_mitmot(node.inputs)
            ls += op.outer_mitsot(node.inputs)
            ls += op.outer_sitsot(node.inputs)
            ls_end = op.outer_shared(node.inputs)
            ls_end += op.outer_nitsot(node.inputs)
            ls_end += op.outer_non_seqs(node.inputs)

            # A state input that already occurs earlier in the list is
            # replaced by a deep copy of itself.
            for idx, var in enumerate(ls):
                if var in ls[:idx]:
                    ls[idx] = deep_copy_op(var)

            inputs = ls_begin + ls + ls_end
            new_op = scan_op.Scan(op.inputs,
                                  op.outputs,
                                  info,
                                  typeConstructor=self.typeConstructor)
            new_outs = new_op.make_node(*inputs).outputs
            try:
                fgraph.replace_all_validate_remove(
                    list(zip(node.outputs, new_outs)),
                    remove=[node],
                    reason=self.__class__.__name__)
            except InconsistencyError:
                # Failed moving this output to be computed in-place;
                # keep the current node and try the next position.
                continue
            else:
                op = new_op
                node = new_outs[0].owner