コード例 #1
0
ファイル: joint.py プロジェクト: tessythomas123/funsor
def eager_reduce_exp(op, arg, reduced_vars):
    """Evaluate a reduction of an exponentiated term via log-space reduction.

    Uses the identity ``x.exp().reduce(ops.add) == x.reduce(ops.logaddexp).exp()``:
    ``arg.arg`` (the ``x`` inside ``arg``) is reduced with ``ops.logaddexp``
    over ``reduced_vars``, and the result is exponentiated.

    ``op`` is not read by this function.

    Returns:
        ``exp`` of the log-space reduction, or ``None`` when that reduction
        came back identical (``is``) to its lazy ``Reduce`` normal form,
        i.e. no evaluation progress was made.
    """
    exponent = arg.arg
    reduced = exponent.reduce(ops.logaddexp, reduced_vars)
    lazy_form = normalize(Reduce, ops.logaddexp, exponent, reduced_vars)
    if reduced is lazy_form:
        # Nothing was evaluated; leave the expression alone.
        return None
    return reduced.exp()
コード例 #2
0
def normalize_contraction_commutative_canonical_order(red_op, bin_op,
                                                      reduced_vars, *terms):
    """Put the terms of a commutative contraction into a canonical order.

    When ``bin_op`` is commutative, reordering terms is safe, and a canonical
    order makes later pattern matching reliable. Each term is ranked by
    ``ORDERING`` looked up on ``type(term).__origin__``. Types absent from
    ``ORDERING`` get rank -1. Terms with equal rank keep their incoming
    relative order.

    Returns:
        A new ``Contraction`` over the reordered terms if any position
        changed; otherwise the result of ``normalize`` applied to the
        (unchanged) terms packed into a tuple.
    """
    def rank(term):
        return ORDERING.get(type(term).__origin__, -1)

    # sorted() is stable, so ties keep their original positions -- the same
    # outcome as sorting on (rank, original_index).
    canonical = tuple(sorted(terms, key=rank))
    if all(old is new for old, new in zip(terms, canonical)):
        return normalize(Contraction, red_op, bin_op, reduced_vars, canonical)
    return Contraction(red_op, bin_op, reduced_vars, *canonical)
コード例 #3
0
ファイル: cnf.py プロジェクト: lawrencechen0921/funsor
def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
    """Eagerly evaluate a generic contraction by making incremental progress.

    Two strategies are tried in order; the function returns as soon as one
    of them makes progress:

    1. Leaf reductions: every reduced variable that occurs in exactly one
       term is reduced out of that term alone via ``term.reduce(red_op, ...)``.
    2. Pairwise contraction: exploiting associativity, the first pair of
       terms whose ``Contraction`` evaluates to something other than its lazy
       normal form is replaced by that result.

    Args:
        red_op: reduction op applied over ``reduced_vars``.
        bin_op: binary op combining the terms.
        reduced_vars: variables to reduce over (converted to a ``frozenset``).
        terms: sequence of terms combined by ``bin_op``.

    Returns:
        A new ``Contraction`` over the partially evaluated terms, or ``None``
        if neither strategy made progress.
    """
    # Phase 1: push down leaf reductions.
    terms = list(terms)
    reduced_vars = frozenset(reduced_vars)
    leaf_reduced = False
    for i, term in enumerate(terms):
        # A variable may only be reduced here if it occurs at no *other
        # position*. Compare by position rather than identity: if the same
        # object appears more than once in ``terms``, excluding all of its
        # occurrences with ``is`` would wrongly treat their shared variables
        # as unique to this term.
        other_vars = frozenset().union(
            *(reduced_vars.intersection(other.inputs)
              for k, other in enumerate(terms) if k != i))
        unique_vars = reduced_vars.intersection(term.inputs) - other_vars
        if unique_vars:
            result = term.reduce(red_op, unique_vars)
            # Accept the reduction only if it evaluated to something other
            # than its lazy normal form (i.e. we made progress).
            if result is not normalize(Contraction, red_op, nullop,
                                       unique_vars, (term, )):
                terms[i] = result
                reduced_vars -= unique_vars
                leaf_reduced = True

    if leaf_reduced:
        return Contraction(red_op, bin_op, reduced_vars, *terms)

    # Phase 2: exploit associativity to recursively evaluate this contraction.
    # A bit expensive (tries every pair), but handles interpreter-imposed
    # directionality constraints.
    terms = tuple(terms)
    for i, lhs in enumerate(terms[:-1]):
        for j in range(i + 1, len(terms)):
            rhs = terms[j]
            rest = terms[:i] + terms[i + 1:j] + terms[j + 1:]
            # Reduce only variables shared by lhs and rhs that no other term uses.
            unique_vars = reduced_vars.intersection(lhs.inputs, rhs.inputs) - \
                frozenset().union(*(reduced_vars.intersection(vv.inputs) for vv in rest))
            result = Contraction(red_op, bin_op, unique_vars, lhs, rhs)
            if result is not normalize(Contraction, red_op, bin_op,
                                       unique_vars, (lhs, rhs)):  # did we make progress?
                # Pick the first evaluable pair.
                reduced_vars -= unique_vars
                new_terms = terms[:i] + (result, ) + terms[i + 1:j] + terms[j + 1:]
                return Contraction(red_op, bin_op, reduced_vars, *new_terms)

    return None
コード例 #4
0
def normalize_contraction_generic_args(red_op, bin_op, reduced_vars, *terms):
    """Re-dispatch ``normalize`` for a Contraction with its terms as one tuple.

    The variadic ``*terms`` are packed into a single tuple argument before
    delegating to ``normalize(Contraction, ...)``.
    """
    packed_terms = tuple(terms)
    return normalize(Contraction, red_op, bin_op, reduced_vars, packed_terms)