def belongs_to_set(self, node, set_nodes):
    """
    Check whether `node` can be merged with every node in `set_nodes`.

    `node` is compared against the first element of `set_nodes` (the
    representative). For two nodes to be mergeable they must:

    * both be while-loops or both be plain loops (``op.as_while``);
    * have the same ``op.truncate_gradient`` and the same ``op.mode``;
    * go over the same number of steps (``inputs[0]``);
    * not depend on one another (checked against every node of the set);
    * if they are while-loops, have the same stopping condition (the
      last inner output).

    Questionable, we should also consider profile ?

    :param node: candidate node.
    :param set_nodes: non-empty sequence of nodes already grouped together.
    :return: True if `node` can join the set, False otherwise.
    """
    rep = set_nodes[0]

    # A while-loop can never be merged with a plain loop, whichever of the
    # two is the representative. Checking a single direction only would
    # let a plain loop join a set of while-loops without its (absent)
    # condition ever being compared.
    if bool(node.op.as_while) != bool(rep.op.as_while):
        return False
    # `==` is used (rather than `!=`) so that only the equality protocol of
    # these objects is relied upon.
    if not (node.op.truncate_gradient == rep.op.truncate_gradient):
        return False
    if not (node.op.mode == rep.op.mode):
        return False

    # Compare the number of steps as plain integers when they can be
    # extracted as constants; otherwise (TypeError) fall back to comparing
    # the variables themselves.
    nsteps = node.inputs[0]
    try:
        nsteps = int(get_constant_value(nsteps))
    except TypeError:
        pass

    rep_nsteps = rep.inputs[0]
    try:
        rep_nsteps = int(get_constant_value(rep_nsteps))
    except TypeError:
        pass

    if not (nsteps == rep_nsteps):
        return False

    # Check to see if it is an input of a different node
    # NOTE(review): presumably `find_up(x, y)` tells whether `y` is
    # reachable from the inputs of `x` -- confirm against its definition.
    for nd in set_nodes:
        if find_up(node, nd) or find_up(nd, node):
            return False

    if not node.op.as_while:
        return True

    # While-loops additionally need the same stopping condition.
    cond = node.op.outputs[-1]
    rep_cond = rep.op.outputs[-1]
    return scan_utils.equal_computations([cond], [rep_cond],
                                         node.op.inputs,
                                         rep.op.inputs)
def scan_merge_inouts(node):
    """
    Merge duplicated inputs and identically computed outputs of a Scan node.

    The function works in two stages:

    1. Outer sequences and outer non-sequences that appear more than once
       are reduced to their first occurrence. If anything was removed, the
       inner outputs are cloned with the redundant inner inputs replaced,
       and a new Scan op is built and applied to the remaining outer inputs.
    2. For duplicated shared, sit_sot, mit_mot and mit_sot outer inputs,
       pairs of equivalent inner inputs are collected in `left` / `right`.
       These pairs are handed to `equal_computations` to detect inner
       outputs that compute the same thing; the corresponding outer output
       is then replaced by the one of the first equivalent output seen.

    NOTE(review): when merging outputs only the inner outputs (and, for
    mit_mot, the output slices) are compared; confirm that outputs whose
    outer inputs differ cannot be merged by mistake.

    :param node: the node to inspect.
    :return: False if `node.op` is not a `scan_op.Scan`; otherwise the
        `outer_outputs` of the (possibly rebuilt) scan arguments.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False

    a = scan_args(node.inputs, node.outputs,
                  node.op.inputs, node.op.outputs, node.op.info)

    # Maps the inner input of a duplicated outer input to the inner input
    # associated with the first occurrence of that outer input.
    inp_equiv = {}

    if has_duplicates(a.outer_in_seqs):
        new_outer_seqs = []
        new_inner_seqs = []
        for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs):
            if out_seq in new_outer_seqs:
                i = new_outer_seqs.index(out_seq)
                inp_equiv[in_seq] = new_inner_seqs[i]
            else:
                new_outer_seqs.append(out_seq)
                new_inner_seqs.append(in_seq)
        a.outer_in_seqs = new_outer_seqs
        a.inner_in_seqs = new_inner_seqs

    if has_duplicates(a.outer_in_non_seqs):
        new_outer_nseqs = []
        new_inner_nseqs = []
        for out_nseq, in_nseq in zip(a.outer_in_non_seqs,
                                     a.inner_in_non_seqs):
            if out_nseq in new_outer_nseqs:
                i = new_outer_nseqs.index(out_nseq)
                inp_equiv[in_nseq] = new_inner_nseqs[i]
            else:
                new_outer_nseqs.append(out_nseq)
                new_inner_nseqs.append(in_nseq)
        a.outer_in_non_seqs = new_outer_nseqs
        a.inner_in_non_seqs = new_inner_nseqs

    if inp_equiv:
        # do the replacement now. The rest will be left to ScanSaveMem
        inner_inputs = a.inner_inputs
        outer_inputs = a.outer_inputs
        info = a.info
        if info["as_while"]:
            # The stopping condition is cloned together with the outputs.
            a_inner_outs = a.inner_outputs + a.cond
        else:
            a_inner_outs = a.inner_outputs
        inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)

        op = scan_op.Scan(inner_inputs, inner_outputs, info)
        outputs = op(*outer_inputs)

        # Applying the op may give back a single output instead of a
        # sequence of outputs.
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]

        na = scan_args(outer_inputs, outputs,
                       op.inputs, op.outputs, op.info)
    else:
        na = a

    # start again
    # `left[k]` and `right[k]` are inner inputs to be treated as equivalent
    # by `equal_computations`.
    left = []
    right = []

    if has_duplicates(na.outer_in_shared):
        _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared)
        left += _left
        right += _right
    if has_duplicates(na.outer_in_sit_sot):
        _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot)
        left += _left
        right += _right
    if has_duplicates(na.outer_in_mit_mot):
        # Two mit_mot inputs are equivalent only if they share both the
        # outer input and the input slices.
        seen = {}
        for omm, imm, _sl in zip(na.outer_in_mit_mot,
                                 na.inner_in_mit_mot,
                                 na.mit_mot_in_slices):
            sl = tuple(_sl)
            if (omm, sl) in seen:
                simm = seen[(omm, sl)]
                left += imm
                right += simm
            else:
                seen[(omm, sl)] = imm
    if has_duplicates(na.outer_in_mit_sot):
        # Same rule as for mit_mot above.
        seen = {}
        for oms, ims, _sl in zip(na.outer_in_mit_sot,
                                 na.inner_in_mit_sot,
                                 na.mit_sot_in_slices):
            sl = tuple(_sl)
            if (oms, sl) in seen:
                sims = seen[(oms, sl)]
                left += ims
                right += sims
            else:
                seen[(oms, sl)] = ims

    def map_out(i, o, seen):
        # Return the outer output of the first already-seen inner output
        # that is an equal computation to `i`; otherwise record (i, o) in
        # `seen` and return `o` unchanged.
        for si, so in seen:
            if equal_computations([i], [si], left, right):
                return so
        seen.append((i, o))
        return o

    seen = []
    na.outer_out_nit_sot = [map_out(i, o, seen)
                            for i, o in zip(na.inner_out_nit_sot,
                                            na.outer_out_nit_sot)]

    seen = []
    na.outer_out_sit_sot = [map_out(i, o, seen)
                            for i, o in zip(na.inner_out_sit_sot,
                                            na.outer_out_sit_sot)]

    seen = []
    na.outer_out_mit_sot = [map_out(i, o, seen)
                            for i, o in zip(na.inner_out_mit_sot,
                                            na.outer_out_mit_sot)]

    # mit_mot outputs additionally require identical output slices; their
    # inner outputs are passed to `equal_computations` as-is (not wrapped
    # in a list as in `map_out`).
    seen = []
    new_outer_out_mit_mot = []
    for imm, omm, osl in zip(na.inner_out_mit_mot,
                             na.outer_out_mit_mot,
                             na.mit_mot_out_slices):
        for simm, somm, sosl in seen:
            if osl == sosl and equal_computations(imm, simm, left, right):
                new_outer_out_mit_mot.append(somm)
                break
        else:
            seen.append((imm, omm, osl))
            new_outer_out_mit_mot.append(omm)
    na.outer_out_mit_mot = new_outer_out_mit_mot

    return na.outer_outputs
def map_out(i, o, seen):
    """
    Return the outer output to use for the inner output `i`.

    `seen` is a list of (inner output, outer output) pairs recorded by
    previous calls. If one of the recorded inner outputs is an equal
    computation to `i` (given the `left` / `right` equivalences of the
    enclosing scope), the outer output recorded with it is returned.
    Otherwise the pair (i, o) is appended to `seen` and `o` is returned.
    """
    for known_pair in seen:
        known_inner, known_outer = known_pair
        if not equal_computations([i], [known_inner], left, right):
            continue
        # First equivalent output wins.
        return known_outer

    # Nothing equivalent so far: remember this output for later calls.
    seen.append((i, o))
    return o
def scan_merge_inouts(node):
    """
    Remove redundant inputs and outputs from a Scan node.

    Duplicated outer sequences / non-sequences are collapsed onto their
    first occurrence (rebuilding the Scan op when that changes anything).
    Then outer outputs whose inner outputs are equal computations are
    replaced by the first such output encountered.

    Returns False when `node.op` is not a `scan_op.Scan`, otherwise the
    `outer_outputs` of the resulting scan arguments.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False

    args = scan_args(node.inputs, node.outputs,
                     node.op.inputs, node.op.outputs, node.op.info)

    # inner input of a duplicate -> inner input of the first occurrence
    replacements = {}

    def drop_duplicates(outer_vars, inner_vars):
        # Keep the first occurrence of every outer variable and register
        # the inner variables of later occurrences in `replacements`.
        kept_outer = []
        kept_inner = []
        for outer_var, inner_var in zip(outer_vars, inner_vars):
            if outer_var in kept_outer:
                first = kept_outer.index(outer_var)
                replacements[inner_var] = kept_inner[first]
            else:
                kept_outer.append(outer_var)
                kept_inner.append(inner_var)
        return kept_outer, kept_inner

    if has_duplicates(args.outer_in_seqs):
        args.outer_in_seqs, args.inner_in_seqs = drop_duplicates(
            args.outer_in_seqs, args.inner_in_seqs)

    if has_duplicates(args.outer_in_non_seqs):
        args.outer_in_non_seqs, args.inner_in_non_seqs = drop_duplicates(
            args.outer_in_non_seqs, args.inner_in_non_seqs)

    if replacements:
        # do the replacement now. The rest will be left to ScanSaveMem
        inner_inputs = args.inner_inputs
        outer_inputs = args.outer_inputs
        info = args.info
        # For a while-loop the condition is cloned along with the outputs.
        old_inner_outs = (args.inner_outputs + args.cond
                          if info['as_while']
                          else args.inner_outputs)
        new_inner_outs = scan_utils.clone(old_inner_outs,
                                          replace=replacements)
        new_op = scan_op.Scan(inner_inputs, new_inner_outs, info)
        new_outs = new_op(*outer_inputs)
        if not isinstance(new_outs, (list, tuple)):
            new_outs = [new_outs]
        merged = scan_args(outer_inputs, new_outs,
                           new_op.inputs, new_op.outputs, new_op.info)
    else:
        merged = args

    # start again
    # Parallel lists of inner inputs considered equivalent when comparing
    # inner outputs below.
    left = []
    right = []

    if has_duplicates(merged.outer_in_shared):
        eq_left, eq_right = make_equiv(merged.outer_in_shared,
                                       merged.inner_in_shared)
        left.extend(eq_left)
        right.extend(eq_right)

    if has_duplicates(merged.outer_in_sit_sot):
        eq_left, eq_right = make_equiv(merged.outer_in_sit_sot,
                                       merged.inner_in_sit_sot)
        left.extend(eq_left)
        right.extend(eq_right)

    def pair_multi_tap(outer_vars, inner_groups, tap_slices):
        # Inputs sharing both the outer variable and the slices are
        # equivalent: pair their inner inputs with those of the first one.
        first_seen = {}
        for outer_var, inner_group, taps in zip(outer_vars,
                                                inner_groups,
                                                tap_slices):
            key = (outer_var, tuple(taps))
            if key in first_seen:
                left.extend(inner_group)
                right.extend(first_seen[key])
            else:
                first_seen[key] = inner_group

    if has_duplicates(merged.outer_in_mit_mot):
        pair_multi_tap(merged.outer_in_mit_mot,
                       merged.inner_in_mit_mot,
                       merged.mit_mot_in_slices)

    if has_duplicates(merged.outer_in_mit_sot):
        pair_multi_tap(merged.outer_in_mit_sot,
                       merged.inner_in_mit_sot,
                       merged.mit_sot_in_slices)

    def merge_outs(inner_outs, outer_outs):
        # For every (inner, outer) output pair, reuse the outer output of
        # the first earlier pair whose inner output is an equal computation.
        known = []
        result = []
        for inner_out, outer_out in zip(inner_outs, outer_outs):
            for known_inner, known_outer in known:
                if equal_computations([inner_out], [known_inner],
                                      left, right):
                    result.append(known_outer)
                    break
            else:
                known.append((inner_out, outer_out))
                result.append(outer_out)
        return result

    merged.outer_out_nit_sot = merge_outs(merged.inner_out_nit_sot,
                                          merged.outer_out_nit_sot)
    merged.outer_out_sit_sot = merge_outs(merged.inner_out_sit_sot,
                                          merged.outer_out_sit_sot)
    merged.outer_out_mit_sot = merge_outs(merged.inner_out_mit_sot,
                                          merged.outer_out_mit_sot)

    # mit_mot outputs must also agree on their output slices, and their
    # inner outputs are compared as given (no extra list wrapping).
    known_mit_mot = []
    merged_mit_mot = []
    for inner_group, outer_out, out_taps in zip(merged.inner_out_mit_mot,
                                                merged.outer_out_mit_mot,
                                                merged.mit_mot_out_slices):
        chosen = outer_out
        is_duplicate = False
        for known_group, known_outer, known_taps in known_mit_mot:
            if (out_taps == known_taps and
                    equal_computations(inner_group, known_group,
                                       left, right)):
                chosen = known_outer
                is_duplicate = True
                break
        if not is_duplicate:
            known_mit_mot.append((inner_group, outer_out, out_taps))
        merged_mit_mot.append(chosen)
    merged.outer_out_mit_mot = merged_mit_mot

    return merged.outer_outputs