Example #1
0
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
    """Transpose rule for ``cond_p`` with respect to its linear operands.

    Args:
        reduce_axes: Forwarded unchanged to ``_transpose_cond_jaxpr``.
        cts: Cotangents for the cond outputs; each is passed through
            ``ad.instantiate_zeros_aval`` together with the matching output
            aval of the first branch.
        *args: The branch index followed by the cond operands.
        branches: The branch jaxprs of the cond being transposed.
        linear: One flag per operand; truthy entries mark linear operands.

    Returns:
        A list with one entry per element of ``args``: ``None`` for the
        index and for every non-linear operand, and the corresponding
        output of the transposed cond for every linear operand.
    """
    index, *operands = args
    in_avals = _map(raise_to_shaped, branches[0].in_avals)
    # Operands that are not flagged linear count as residuals; the first
    # ``num_res`` operands are the ones forwarded to the transposed cond.
    num_res = len(operands) - sum(linear)

    branches_trans = tuple(
        _transpose_cond_jaxpr(branch, num_res, reduce_axes)
        for branch in branches)

    # Avals of the linear inputs, re-raised with weak_type cleared.
    lin_in_avals = []
    for aval, is_linear in zip(in_avals, linear):
        if is_linear:
            lin_in_avals.append(raise_to_shaped(aval, weak_type=False))

    # Every transposed branch must produce outputs typed like the linear
    # inputs of the original cond.
    assert all(
        core.typematch(out_aval, lin_aval)
        for transposed in branches_trans
        for out_aval, lin_aval in zip(transposed.out_avals, lin_in_avals))

    residuals = operands[:num_res]
    ct_values = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
    linear_trans = (False,) * num_res + (True,) * len(ct_values)

    transposed_outs = cond_p.bind(
        index,
        *residuals,
        *ct_values,
        branches=branches_trans,
        linear=linear_trans,
    )
    assert all(_map(core.typecheck, lin_in_avals, transposed_outs))

    # Scatter the outputs back onto the operand positions; the leading None
    # is the slot for ``index``.
    remaining = iter(transposed_outs)
    result = [None]
    for is_linear in linear:
        result.append(next(remaining) if is_linear else None)
    # All transposed outputs must have been consumed.
    assert next(remaining, None) is None
    return result
Example #2
0
def check_arg_avals_for_call(ref_avals, arg_avals):
    """Validate call-time argument avals against the compiled-for avals.

    Args:
        ref_avals: Avals the computation was compiled for.
        arg_avals: Avals of the arguments supplied at call time.

    Raises:
        TypeError: If the two sequences differ in length, or if any pair at
            the same position fails ``core.typematch``.
    """
    expected_count = len(ref_avals)
    actual_count = len(arg_avals)
    if expected_count != actual_count:
        raise TypeError(f"Computation compiled for {expected_count} inputs "
                        f"but called with {actual_count}")

    # Stops at the first mismatching pair, like an explicit loop would.
    if all(core.typematch(ref, arg) for ref, arg in zip(ref_avals, arg_avals)):
        return

    # Report the full signatures on both sides, not just the offending pair.
    expected_fmt = ', '.join(map(str, ref_avals))
    actual_fmt = ', '.join(map(str, arg_avals))
    raise TypeError(
        f"Computation compiled for input types:\n  {expected_fmt}\n"
        f"called with:\n  {actual_fmt}")
Example #3
0
def _show_diff(array1, array2):
    """Format ``array1``, flagging a mismatch against ``array2``.

    Returns the formatted ``array1`` alone when ``core.typematch`` accepts
    the pair; otherwise returns a ``"DIFFERENT ... vs. ..."`` string that
    shows both values.
    """
    if not core.typematch(array1, array2):
        return f"DIFFERENT {array1} vs. {array2}"
    return f"{array1}"