Ejemplo n.º 1
0
    def belongs_to_set(self, node, set_nodes):
        """
        Tell whether `node` can be merged with the nodes in `set_nodes`.

        `set_nodes[0]` is used as the representative that `node` is compared
        against. `node` belongs to the set only when all of these hold:

        * both ops agree on `as_while`;
        * both ops have equal `truncate_gradient` and equal `mode`;
        * both nodes have the same number of steps (`inputs[0]`, converted
          to an int when `get_constant_value` does not raise `TypeError`);
        * `find_up` finds no link, in either direction, between `node` and
          any node of `set_nodes`;
        * for `as_while` ops, the conditions (last inner output of each op)
          are equal computations.

        Questionable, we should also consider profile ?

        `set_nodes` must be non-empty. Returns True when `node` can join the
        set and False otherwise.
        """
        rep = set_nodes[0]
        # The check has to be symmetric: a non-while node must not join a
        # set of while-loops either, otherwise the condition of the
        # representative is never compared against anything.
        if bool(rep.op.as_while) != bool(node.op.as_while):
            return False

        # Cheap attribute comparisons go first so that mismatching nodes
        # never trigger the graph traversals below. `not (a == b)` is used
        # instead of `!=` so that only `__eq__` is relied upon.
        if not (node.op.truncate_gradient == rep.op.truncate_gradient):
            return False
        if not (node.op.mode == rep.op.mode):
            return False

        nsteps = node.inputs[0]
        try:
            nsteps = int(get_constant_value(nsteps))
        except TypeError:
            # Not a constant: keep the variable itself for the comparison.
            pass

        rep_nsteps = rep.inputs[0]
        try:
            rep_nsteps = int(get_constant_value(rep_nsteps))
        except TypeError:
            pass

        if not (nsteps == rep_nsteps):
            return False

        # Check to see if it is an input of a different node
        for nd in set_nodes:
            if find_up(node, nd) or find_up(nd, node):
                return False

        if not node.op.as_while:
            return True

        # Both ops are while-loops: their conditions must match as well.
        cond = node.op.outputs[-1]
        rep_cond = rep.op.outputs[-1]
        same_cond = scan_utils.equal_computations([cond], [rep_cond],
                                                  node.op.inputs,
                                                  rep.op.inputs)
        return bool(same_cond)
Ejemplo n.º 2
0
    def belongs_to_set(self, node, set_nodes):
        """
        Check whether `node` is mergeable with every node of `set_nodes`.

        The comparison is made against `set_nodes[0]`, the representative of
        the set (so `set_nodes` must not be empty). Two nodes are mergeable
        when they go over the same number of steps, have the same condition
        (if any), have the same value for `truncate_gradient`, have the same
        mode, agree on `as_while`, and `find_up` reports no link between
        `node` and any member of the set in either direction.

        Questionable, we should also consider profile ?

        Returns True if `node` can be added to the set, False otherwise.
        """
        rep = set_nodes[0]
        node_op = node.op
        rep_op = rep.op

        # A while-loop and a fixed-length loop can never be merged,
        # whichever of the two is the representative. Testing only one
        # direction lets a plain node slip into a set of while-loops.
        if bool(rep_op.as_while) != bool(node_op.as_while):
            return False

        # Attribute mismatches are decided before any graph traversal.
        # `==` is negated rather than using `!=` so only `__eq__` matters.
        same_settings = (node_op.truncate_gradient ==
                         rep_op.truncate_gradient)
        same_settings = same_settings and (node_op.mode == rep_op.mode)
        if not same_settings:
            return False

        def steps_of(scan_node):
            # Number of steps as an int when it is a constant, otherwise
            # the original variable.
            steps = scan_node.inputs[0]
            try:
                return int(get_constant_value(steps))
            except TypeError:
                return steps

        nsteps = steps_of(node)
        rep_nsteps = steps_of(rep)
        if not (nsteps == rep_nsteps):
            return False

        # Check to see if it is an input of a different node
        if any(find_up(node, nd) or find_up(nd, node) for nd in set_nodes):
            return False

        if not node_op.as_while:
            return True

        # While-loops additionally need equivalent stopping conditions,
        # which are the last inner output of each op.
        same_cond = scan_utils.equal_computations([node_op.outputs[-1]],
                                                  [rep_op.outputs[-1]],
                                                  node_op.inputs,
                                                  rep_op.inputs)
        return bool(same_cond)
Ejemplo n.º 3
0
def scan_merge_inouts(node):
    """
    Merge duplicated inputs and equivalent outputs of a Scan node.

    The work is done in three stages:

    1. Outer sequences and outer non-sequences that appear more than once
       are reduced to their first occurrence; the inner inputs of the
       dropped occurrences are recorded in `inp_equiv` as replaceable by
       the inner input that was kept.
    2. If anything was recorded, the inner outputs are cloned with those
       replacements and a new `scan_op.Scan` is built and applied to the
       reduced outer inputs.
    3. For shared, sit_sot, mit_mot and mit_sot inputs that contain
       duplicates, pairs of equivalent inner inputs are collected in
       `left` / `right`. Outer outputs whose inner outputs are equal
       computations under those equivalences are then replaced by the
       first matching outer output.

    Returns False when `node.op` is not a `scan_op.Scan`; otherwise returns
    `outer_outputs` of the (possibly rebuilt) scan arguments.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False

    a = scan_args(node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info)

    # Maps the inner input of a dropped duplicate to the inner input of the
    # occurrence that is kept.
    inp_equiv = {}

    if has_duplicates(a.outer_in_seqs):
        new_outer_seqs = []
        new_inner_seqs = []
        for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs):
            if out_seq in new_outer_seqs:
                i = new_outer_seqs.index(out_seq)
                inp_equiv[in_seq] = new_inner_seqs[i]
            else:
                new_outer_seqs.append(out_seq)
                new_inner_seqs.append(in_seq)
        a.outer_in_seqs = new_outer_seqs
        a.inner_in_seqs = new_inner_seqs

    if has_duplicates(a.outer_in_non_seqs):
        new_outer_nseqs = []
        new_inner_nseqs = []
        for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs):
            if out_nseq in new_outer_nseqs:
                i = new_outer_nseqs.index(out_nseq)
                inp_equiv[in_nseq] = new_inner_nseqs[i]
            else:
                new_outer_nseqs.append(out_nseq)
                new_inner_nseqs.append(in_nseq)
        a.outer_in_non_seqs = new_outer_nseqs
        a.inner_in_non_seqs = new_inner_nseqs

    if inp_equiv:
        # do the replacement now. The rest will be left to ScanSaveMem
        inner_inputs = a.inner_inputs
        outer_inputs = a.outer_inputs
        info = a.info
        if info["as_while"]:
            # The condition is part of the inner graph that gets cloned.
            a_inner_outs = a.inner_outputs + a.cond
        else:
            a_inner_outs = a.inner_outputs
        inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)

        op = scan_op.Scan(inner_inputs, inner_outputs, info)
        outputs = op(*outer_inputs)

        # Applying the op may give back a single output instead of a list.
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]

        na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
    else:
        na = a

    # start again
    # `left[k]` and `right[k]` are inner inputs treated as equivalent when
    # inner outputs are compared below.
    left = []
    right = []

    if has_duplicates(na.outer_in_shared):
        _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared)
        left += _left
        right += _right
    if has_duplicates(na.outer_in_sit_sot):
        _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot)
        left += _left
        right += _right
    if has_duplicates(na.outer_in_mit_mot):
        # First inner inputs seen for each (outer input, slices) pair.
        seen = {}
        for omm, imm, _sl in zip(na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices):
            sl = tuple(_sl)
            if (omm, sl) in seen:
                simm = seen[(omm, sl)]
                left += imm
                right += simm
            else:
                seen[(omm, sl)] = imm

    if has_duplicates(na.outer_in_mit_sot):
        seen = {}
        for oms, ims, _sl in zip(na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices):
            sl = tuple(_sl)
            if (oms, sl) in seen:
                sims = seen[(oms, sl)]
                left += ims
                right += sims
            else:
                seen[(oms, sl)] = ims

    def map_out(i, o, seen):
        # Return the outer output already registered for an inner output
        # equivalent to `i`, or register (i, o) and return `o`.
        for si, so in seen:
            if equal_computations([i], [si], left, right):
                return so
        seen.append((i, o))
        return o

    # Each output category is merged independently, hence a fresh `seen`
    # list before every pass.
    seen = []
    na.outer_out_nit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_nit_sot, na.outer_out_nit_sot)]

    seen = []
    na.outer_out_sit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_sit_sot, na.outer_out_sit_sot)]

    seen = []
    na.outer_out_mit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_mit_sot, na.outer_out_mit_sot)]

    # mit_mot outputs are merged only when their output slices match too;
    # here `imm` is passed to `equal_computations` as a whole.
    seen = []
    new_outer_out_mit_mot = []
    for imm, omm, osl in zip(na.inner_out_mit_mot, na.outer_out_mit_mot, na.mit_mot_out_slices):
        for simm, somm, sosl in seen:
            if osl == sosl and equal_computations(imm, simm, left, right):
                new_outer_out_mit_mot.append(somm)
                break
        else:
            seen.append((imm, omm, osl))
            new_outer_out_mit_mot.append(omm)
    na.outer_out_mit_mot = new_outer_out_mit_mot

    return na.outer_outputs
Ejemplo n.º 4
0
 def map_out(i, o, seen):
     """
     Return the outer output to use for inner output `i`.

     `seen` is a list of (inner output, outer output) pairs. If it already
     holds an inner output for which `equal_computations` succeeds against
     `i` (given `left` and `right` from the enclosing scope), the outer
     output of the first such pair is returned. Otherwise `(i, o)` is
     appended to `seen` and `o` is returned.
     """
     not_found = object()
     match = next((so for si, so in seen
                   if equal_computations([i], [si], left, right)),
                  not_found)
     if match is not_found:
         seen.append((i, o))
         return o
     return match
Ejemplo n.º 5
0
 def map_out(i, o, seen):
     """
     Pick the outer output for inner output `i`, reusing an earlier one.

     Scans `seen`, a list of (inner output, outer output) pairs, for the
     first inner output that `equal_computations` accepts as equal to `i`
     under the `left` / `right` lists of the enclosing scope. When one is
     found its outer output is returned; when none is, `(i, o)` is
     recorded in `seen` and `o` itself is returned.
     """
     chosen = o
     for prev_inner, prev_outer in seen:
         if equal_computations([i], [prev_inner], left, right):
             chosen = prev_outer
             break
     else:
         # No equivalent inner output yet: remember this pair.
         seen.append((i, o))
     return chosen
Ejemplo n.º 6
0
def scan_merge_inouts(node):
    """
    Remove repeated inputs of a Scan node and merge its equivalent outputs.

    Non-Scan nodes are rejected with False. For a Scan node, repeated outer
    sequences and non-sequences are collapsed onto their first occurrence
    and, if any were found, the op is rebuilt with the redundant inner
    inputs replaced. Equivalences between inner inputs of repeated shared,
    sit_sot, mit_mot and mit_sot outer inputs are then gathered, and outer
    outputs whose inner outputs are equal computations under those
    equivalences are replaced by the first matching outer output. The
    resulting `outer_outputs` are returned.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False

    a = scan_args(node.inputs, node.outputs,
                  node.op.inputs, node.op.outputs, node.op.info)

    # inner input of a dropped repeat -> inner input that is kept
    inp_equiv = {}

    def drop_repeats(outer_vars, inner_vars):
        # Keep the first occurrence of every outer variable and record the
        # inner inputs of later occurrences in `inp_equiv`.
        kept_outer = []
        kept_inner = []
        for outer_var, inner_var in zip(outer_vars, inner_vars):
            if outer_var in kept_outer:
                pos = kept_outer.index(outer_var)
                inp_equiv[inner_var] = kept_inner[pos]
            else:
                kept_outer.append(outer_var)
                kept_inner.append(inner_var)
        return kept_outer, kept_inner

    if has_duplicates(a.outer_in_seqs):
        a.outer_in_seqs, a.inner_in_seqs = drop_repeats(a.outer_in_seqs,
                                                        a.inner_in_seqs)

    if has_duplicates(a.outer_in_non_seqs):
        a.outer_in_non_seqs, a.inner_in_non_seqs = drop_repeats(
            a.outer_in_non_seqs, a.inner_in_non_seqs)

    if inp_equiv:
        # do the replacement now. The rest will be left to ScanSaveMem
        inner_inputs = a.inner_inputs
        outer_inputs = a.outer_inputs
        info = a.info
        a_inner_outs = a.inner_outputs
        if info['as_while']:
            a_inner_outs = a_inner_outs + a.cond
        inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)

        rebuilt_op = scan_op.Scan(inner_inputs, inner_outputs, info)
        outputs = rebuilt_op(*outer_inputs)
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]

        na = scan_args(outer_inputs, outputs, rebuilt_op.inputs,
                       rebuilt_op.outputs, rebuilt_op.info)
    else:
        na = a

    # start again
    # Parallel lists of inner inputs considered equivalent.
    left = []
    right = []

    if has_duplicates(na.outer_in_shared):
        eq_left, eq_right = make_equiv(na.outer_in_shared,
                                       na.inner_in_shared)
        left.extend(eq_left)
        right.extend(eq_right)
    if has_duplicates(na.outer_in_sit_sot):
        eq_left, eq_right = make_equiv(na.outer_in_sit_sot,
                                       na.inner_in_sit_sot)
        left.extend(eq_left)
        right.extend(eq_right)

    def pair_repeated_taps(outer_vars, inner_groups, tap_slices):
        # For every (outer input, slices) pair met again, pair its inner
        # inputs with those of the first occurrence.
        first_seen = {}
        for outer_var, inner_group, taps in zip(outer_vars, inner_groups,
                                                tap_slices):
            key = (outer_var, tuple(taps))
            if key in first_seen:
                left.extend(inner_group)
                right.extend(first_seen[key])
            else:
                first_seen[key] = inner_group

    if has_duplicates(na.outer_in_mit_mot):
        pair_repeated_taps(na.outer_in_mit_mot, na.inner_in_mit_mot,
                           na.mit_mot_in_slices)

    if has_duplicates(na.outer_in_mit_sot):
        pair_repeated_taps(na.outer_in_mit_sot, na.inner_in_mit_sot,
                           na.mit_sot_in_slices)

    def map_out(i, o, seen):
        # Outer output of the first registered inner output equivalent to
        # `i`; registers (i, o) and yields `o` when there is none.
        chosen = o
        for known_inner, known_outer in seen:
            if equal_computations([i], [known_inner], left, right):
                chosen = known_outer
                break
        else:
            seen.append((i, o))
        return chosen

    def merge_outs(inner_outs, outer_outs):
        # One merging pass over a single category of outputs.
        registry = []
        return [map_out(inner_out, outer_out, registry)
                for inner_out, outer_out in zip(inner_outs, outer_outs)]

    na.outer_out_nit_sot = merge_outs(na.inner_out_nit_sot,
                                      na.outer_out_nit_sot)
    na.outer_out_sit_sot = merge_outs(na.inner_out_sit_sot,
                                      na.outer_out_sit_sot)
    na.outer_out_mit_sot = merge_outs(na.inner_out_mit_sot,
                                      na.outer_out_mit_sot)

    # mit_mot outputs additionally require identical output slices.
    known_mit_mot = []
    merged_mit_mot = []
    for imm, omm, osl in zip(na.inner_out_mit_mot,
                             na.outer_out_mit_mot, na.mit_mot_out_slices):
        picked = omm
        for prev_imm, prev_omm, prev_osl in known_mit_mot:
            if osl == prev_osl and equal_computations(imm, prev_imm,
                                                      left, right):
                picked = prev_omm
                break
        else:
            known_mit_mot.append((imm, omm, osl))
        merged_mit_mot.append(picked)
    na.outer_out_mit_mot = merged_mit_mot

    return na.outer_outputs