Exemple #1
0
                      init_val), init_val_bd


def _jaxtupletree_select(pred, on_true, on_false):
    """Select between two tuple-trees, recursing through tuple structure.

    The abstract value of `on_true` decides the handling: an exact
    `core.AbstractTuple` is processed element by element and repacked, an
    `UnshapedArray` instance is passed to `lax.select`, and anything else
    raises `TypeError` carrying the abstract value.
    """
    aval = core.get_aval(on_true)
    if type(aval) is core.AbstractTuple:
        select_elt = partial(_jaxtupletree_select, pred)
        return core.pack(_map(select_elt, on_true, on_false))
    if isinstance(aval, UnshapedArray):
        return lax.select(pred, on_true, on_false)
    raise TypeError(aval)


# Create the 'while' primitive and attach its rules: the implementation
# (xla.apply_primitive specialized to this primitive), the abstract
# evaluation rule, the entry in xla.initial_style_translations, and the
# batching rule.
while_p = lax.Primitive('while')
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
xla.initial_style_translations[while_p] = _while_loop_translation_rule
batching.primitive_batchers[while_p] = _while_loop_batching_rule

### cond


def cond(pred, true_operand, true_fun, false_operand, false_fun):
    def trace_jaxpr(fun, operand):
        op_flat, in_tree = pytree_to_flatjaxtuple(operand)
        fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(
            lu.wrap_init(fun), (in_tree, ))
        jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat,
                                                 (_abstractify(op_flat), ))
        return op_flat, jaxpr, consts, pvout, out_tree
Exemple #2
0
# psum translation rule has special handling for complex dtypes
def _psum_translation_rule(c, val, replica_groups):
    """Translate psum as an all-reduce with `lax.add_p`.

    Complex operands are reduced as separate real and imaginary parts and
    recombined with `c.Complex`; every other dtype is reduced directly.
    """

    def psum(operand):
        return _allreduce_translation_rule(lax.add_p,
                                           c,
                                           operand,
                                           replica_groups=replica_groups)

    dtype = c.GetShape(val).numpy_dtype()
    if not dtypes.issubdtype(dtype, onp.complexfloating):
        return psum(val)
    return c.Complex(psum(c.Real(val)), psum(c.Imag(val)))


# Create the 'psum' primitive and register its rules.
psum_p = standard_pmap_primitive('psum')
# Split-axis rule: _allreduce_split_axis_rule specialized with psum_p and
# lax._reduce_sum.
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure rule: multiply the operand by the product of the entries of `shape`.
pxla.parallel_pure_rules[psum_p] = lambda x, shape: x * prod(shape)
# Registered through ad.deflinear with a rule that applies psum to `t` over
# the same axis name.
ad.deflinear(psum_p, lambda t, axis_name: [psum(t, axis_name)])
pxla.multi_host_supported_collectives.add(psum_p)

# Create the 'pmax' primitive; its parallel translation is the all-reduce
# translation rule specialized with lax.max_p, and its split-axis rule is
# _allreduce_split_axis_rule specialized with pmax_p and lax._reduce_max.
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
Exemple #3
0
def _defbroadcasting(prim):
    """Register `_broadcasting_papply`, bound to `prim`, as its papply rule."""
    papply_rule = partial(_broadcasting_papply, prim)
    parallel.papply_primitive_rules[prim] = papply_rule
Exemple #4
0
    CallSpec(fun_call_jitted, (R(1, ), )),
    CallSpec(fun_with_nested_calls, (R(), )),
    CallSpec(fun_with_nested_calls, (R(3, 2), )),
    CallSpec(fun_with_nested_calls_2, (R(1, 2), )),
]


def jvp_unlinearized(f, primals, tangents):
    """Compute a JVP by linearizing `f` first and then applying the result.

    Returns a pair: the output of `f` at `primals`, and the linearized
    function evaluated at `tangents`.
    """
    primal_out, linear_fn = linearize(f, *primals)
    tangent_out = linear_fn(*tangents)
    return primal_out, tangent_out


# Expand every base spec into five variants: the spec itself, its jvp, a
# jitted version, a doubly jitted version, and the linearize-based jvp.
test_specs = []
for ts in test_specs_base:
    test_specs.extend([
        ts,
        CallSpec(partial(jvp, ts.fun), (ts.args, ts.args)),
        CallSpec(jit(ts.fun), ts.args),
        CallSpec(jit(jit(ts.fun)), ts.args),
        CallSpec(partial(jvp_unlinearized, ts.fun), (ts.args, ts.args)),
    ])


def fwd_deriv(f):
    """Return a function computing the derivative of `f` via forward mode.

    The returned function evaluates `jvp` of `f` at `x` with tangent 1.0 and
    yields the tangent component of the result.
    """

    def df(x):
        _, tangent_out = jvp(f, (x, ), (1.0, ))
        return tangent_out

    return df


class CoreTest(jtu.JaxTestCase):
    def test_tree_multimap(self):
Exemple #5
0
 def test_jvp_linearized(self, f, args):
     """Check the linearize-based JVP of `f` with a float32 rtol of 3e-2."""
     linearized_jvp = partial(jvp_unlinearized, f)
     tolerances = {np.float32: 3e-2}
     jtu.check_jvp(f, linearized_jvp, args, rtol=tolerances)
Exemple #6
0
 def test_jvp_linearized(self, f, args):
     """Check the linearize-based JVP of `f` using `jtu.check_jvp`."""
     linearized_jvp = partial(jvp_unlinearized, f)
     jtu.check_jvp(f, linearized_jvp, args)
Exemple #7
0
def _defreducer(prim, collective_prim):
    """Register `_reducer_papply`, bound to `prim` and `collective_prim`,
    as the papply rule for `prim`."""
    papply_rule = partial(_reducer_papply, prim, collective_prim)
    parallel.papply_primitive_rules[prim] = papply_rule
Exemple #8
0
 def testDot5(self):
     """The jaxpr of a vmapped 'ij,j->i' einsum must not mention "broadcast"."""
     matvec = partial(np.einsum, 'ij,j->i')
     f = vmap(matvec, (None, 0))
     lhs = np.zeros((1000, 1000))
     rhs = np.zeros((1000, 1000))
     jaxpr = make_jaxpr(f)(lhs, rhs)
     assert "broadcast" not in str(jaxpr)
Exemple #9
0
import pytest
import scipy.special as osp_special
import scipy.stats as osp_stats
from numpy.testing import assert_allclose

import jax.numpy as np
from jax import grad, jit, lax, random
from jax.scipy.special import expit
from jax.util import partial

from numpyro.distributions.util import binary_cross_entropy_with_logits, standard_gamma, xlog1py, xlogy

# `lax.full_like` with the fill value fixed to 0.
_zeros = partial(lax.full_like, fill_value=0)


@pytest.mark.parametrize('x, y', [
    (np.array([1]), np.array([1, 2, 3])),
    (np.array([0]), np.array([0, 0])),
    (np.array([[0.], [0.]]), np.array([1., 2.])),
])
@pytest.mark.parametrize('jit_fn', [False, True])
def test_xlogy(x, y, jit_fn):
    """Compare `xlogy` (optionally jitted) against `scipy.special.xlogy`."""
    fn = jit(xlogy) if jit_fn else xlogy
    actual = fn(x, y)
    expected = osp_special.xlogy(x, y)
    assert_allclose(actual, expected)


@pytest.mark.parametrize('x, y, grad1, grad2', [
    (np.array([1., 1., 1.]), np.array([1., 2., 3.]), np.log(np.array(
        [1, 2, 3])), np.array([1., 0.5, 1. / 3])),
    (np.array([1.]), np.array([1., 2., 3.]), np.sum(np.log(np.array(
        [1, 2, 3]))), np.array([1., 0.5, 1. / 3])),
Exemple #10
0
                      init_val), init_val_bd


def _jaxtupletree_select(pred, on_true, on_false):
    """Select between two tuple-trees, recursing through tuple structure.

    The abstract value of `on_true` decides the handling: an exact
    `core.AbstractTuple` is processed element by element and repacked, an
    `UnshapedArray` instance is passed to `lax.select`, and anything else
    raises `TypeError` carrying the abstract value.
    """
    aval = core.get_aval(on_true)
    if type(aval) is core.AbstractTuple:
        select_elt = partial(_jaxtupletree_select, pred)
        return core.pack(map(select_elt, on_true, on_false))
    if isinstance(aval, UnshapedArray):
        return lax.select(pred, on_true, on_false)
    raise TypeError(aval)


# Create the 'while' primitive and attach its rules: the implementation
# (xla.apply_primitive specialized to this primitive), the abstract
# evaluation rule, the entry in xla.translations, and the batching rule.
while_p = lax.Primitive('while')
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
xla.translations[while_p] = _while_loop_translation_rule
batching.primitive_batchers[while_p] = _while_loop_batching_rule

### cond


def cond(pred, true_operand, true_fun, false_operand, false_fun):
    def trace_jaxpr(fun, operand):
        op_flat, in_tree = pytree_to_flatjaxtuple(operand)
        fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(
            lu.wrap_init(fun), (in_tree, ))
        jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat,
                                                 (_abstractify(op_flat), ))
        return op_flat, jaxpr, consts, pvout, out_tree
Exemple #11
0
 def __init__(self, func, prim):
     """Store the wrapped function and primitive, then register a default
     inverse for the primitive in `ildj_registry`."""
     self.func, self.prim = func, prim
     # Default inverse: `core.call_ildj` bound to this primitive, which
     # inverts the wrapped function.
     default_ildj = jax_util.partial(core.call_ildj, self.prim)
     ildj_registry[self.prim] = default_ildj
Exemple #12
0
def _scan_partial_eval(trace, *tracers, **kwargs):
    """Split a scan over partially known inputs into two scans.

    The first scan (`jaxpr_1`) is bound immediately on the known inputs; the
    second (`jaxpr_2`) is recorded as an equation over the unknown inputs
    plus the residuals produced by the first.

    Args:
      trace: the trace used to instantiate tracers for the staged equation.
      *tracers: input tracers; an input counts as unknown when
        `t.pval[0] is not None`.
      **kwargs: scan parameters; "forward", "length", "num_consts",
        "num_carry", "jaxpr" and "linear" are read.

    Returns:
      The list of output tracers, each with its recipe set to the staged
      equation.

    Raises:
      FixedPointError: if the set of unknown carry elements has not
        stabilized after 1000 iterations.
    """
    forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
        kwargs,
        ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
    num_xs = len(jaxpr.in_avals) - num_carry - num_consts
    num_ys = len(jaxpr.out_avals) - num_carry

    unknowns = original_unknowns = [t.pval[0] is not None for t in tracers]
    const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

    # Fixed-point iteration: re-run partial evaluation until the unknown-ness
    # of the carry outputs matches the unknown-ness assumed for the carry
    # inputs.
    carry_uk = init_uk
    for _ in range(1000):
        unknowns = const_uk + carry_uk + xs_uk
        jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(jaxpr,
                                                         unknowns,
                                                         instantiate=carry_uk +
                                                         [False] * num_ys)
        carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
        if carry_uk_out == carry_uk:
            break
        else:
            carry_uk = carry_uk_out
    else:
        # The loop ran all 1000 iterations without hitting `break`.
        raise FixedPointError

    # Known inputs contribute their constant value; unknown inputs are
    # replaced by core.unit in the first scan.
    in_consts = [
        core.unit if uk else t.pval[1] for uk, t in zip(unknowns, tracers)
    ]
    # Conversely, the staged equation receives the real tracer for unknown
    # inputs and a unit literal for known ones.
    new_tracers = [
        trace.instantiate_const(t)
        if uk else trace.new_instantiated_literal(core.unit)
        for uk, t in zip(unknowns, tracers)
    ]

    carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
    ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
    out_avals = carry_avals + ys_avals
    out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

    linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
    out_flat = scan_p.bind(*in_consts,
                           forward=forward,
                           length=length,
                           jaxpr=jaxpr_1,
                           num_consts=num_consts,
                           num_carry=num_carry,
                           linear=linear_1)
    # Outputs of the first scan: carry, stacked ys, then residuals.
    out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
    out_consts = out_carry + ys
    residual_tracers = _map(trace.new_instantiated_const, residuals)
    out_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
        for pv, const in zip(out_pvs, out_consts)
    ]
    linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)] +
                [False] * len(residual_tracers))
    eqn = pe.new_jaxpr_eqn(
        new_tracers + residual_tracers, out_tracers, scan_p, (),
        dict(forward=forward,
             length=length,
             jaxpr=jaxpr_2,
             num_consts=num_consts,
             num_carry=num_carry,
             linear=linear_2))
    for t in out_tracers:
        t.recipe = eqn
    return out_tracers
Exemple #13
0
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
    """Build an XLA While operation from the cond and body jaxprs.

    Args:
      c: the computation builder the While op is added to.
      axis_env: forwarded to `xla.jaxpr_computation`.
      *args: flat operands, ordered as cond constants, body constants, then
        initial loop values.
      **kwargs: reads "cond_jaxpr", "body_jaxpr", "cond_nconsts" and
        "body_nconsts"; an optional "backend" entry is popped first.

    Returns:
      A tuple built with `c.Tuple` holding only the loop-value part of the
      While result (the constants are dropped).
    """
    backend = kwargs.pop('backend', None)
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
        kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
    cond_consts, body_consts, init_vals = split_list(
        args, [cond_nconsts, body_nconsts])
    # The loop is treated as batched when the predicate has a non-empty shape.
    batched = bool(cond_jaxpr.out_avals[0].shape)

    # Since jaxprs don't have tuples and have multiple return values, but we need
    # the HLO While loop to take a single tuple input and output a single boolean
    # (for the cond computation) or a single tuple output (for the body
    # computation), we build XLA computations that handle the tuple munging before
    # generating a Call into the computations formed from the jaxprs.

    init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

    cond_c = xb.make_computation_builder("cond_computation")
    cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
    cond_carry_elts = [
        cond_c.GetTupleElement(cond_carry, i) for i in range(len(args))
    ]
    # x: cond constants, z: loop values (body constants are not needed here).
    x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
    cond_outs = cond_c.Call(
        xla.jaxpr_computation(cond_jaxpr.jaxpr, backend, axis_env,
                              cond_jaxpr.literals, (),
                              *_map(cond_c.GetShape, x + z)), x + z)
    pred = cond_c.GetTupleElement(cond_outs, 0)
    if batched:
        # Reduce the predicate over all of its dimensions with logical OR,
        # starting from False, to obtain a single scalar.
        scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
        or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
        pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                             list(range(cond_jaxpr.out_avals[0].ndim)))

    body_c = xb.make_computation_builder("body_computation")
    body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
    body_carry_elts = [
        body_c.GetTupleElement(body_carry, i) for i in range(len(args))
    ]
    # x: cond constants, y: body constants, z: loop values.
    x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
    body_out = body_c.Call(
        xla.jaxpr_computation(body_jaxpr.jaxpr, backend, axis_env,
                              body_jaxpr.literals, (),
                              *_map(body_c.GetShape, y + z)), y + z)
    new_z = [
        body_c.GetTupleElement(body_out, i) for i in range(len(init_vals))
    ]
    if batched:
        # Re-evaluate the predicate inside the body and use it to select
        # between the updated and the previous loop values.
        body_cond_outs = body_c.Call(
            xla.jaxpr_computation(cond_jaxpr.jaxpr, backend, axis_env,
                                  cond_jaxpr.literals, (),
                                  *_map(body_c.GetShape, x + z)), x + z)
        body_pred = body_c.GetTupleElement(body_cond_outs, 0)
        new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
        assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape,
                                                    z)  # no broadcast
    # Constants are passed through the carry unchanged.
    new_carry = body_c.Tuple(*(x + y + new_z))

    ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
    ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
    _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
    return c.Tuple(*z)
Exemple #14
0
        batching.moveaxis(x, d, 0) if now_bat else x
        for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)
    ]

    outs = while_p.bind(*(new_consts + new_init),
                        cond_nconsts=cond_nconsts,
                        cond_jaxpr=cond_jaxpr_batched,
                        body_nconsts=body_nconsts,
                        body_jaxpr=body_jaxpr_batched)
    out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
    return outs, out_bdims


# Create the 'while' primitive, mark it as returning multiple results, and
# attach its rules: the implementation (xla.apply_primitive specialized to
# this primitive), the abstract evaluation rule, the entry in
# xla.initial_style_translations, and the batching rule.
while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
xla.initial_style_translations[while_p] = _while_loop_translation_rule
batching.primitive_batchers[while_p] = _while_loop_batching_rule

### cond


def cond(pred, true_operand, true_fun, false_operand, false_fun):
    true_ops, true_tree = tree_flatten((true_operand, ))
    true_avals = tuple(_map(_abstractify, true_ops))
    true_jaxpr, true_consts, out_tree = _initial_style_jaxpr(
        true_fun, true_tree, true_avals)
    false_ops, false_tree = tree_flatten((false_operand, ))
    false_avals = tuple(_map(_abstractify, false_ops))
    false_jaxpr, false_consts, out_tree2 = _initial_style_jaxpr(
def main(argv):
    """Run the ResNet training and evaluation driver.

    Resolves hyperparameters from `FLAGS` (falling back to defaults that
    depend on the global batch size), creates the model and a two-group
    optimizer, optionally precompiles the pmapped step functions, and then
    alternates training and evaluation loops. Input batches are pushed to
    the devices from a background thread pool, and events are reported
    through `mllogger`.

    Args:
      argv: unused; deleted immediately.
    """
    del argv
    # BEGIN GOOGLE-INTERNAL
    xm.setup_work_unit()
    # END GOOGLE-INTERNAL

    tf.enable_v2_behavior()
    init_mllogger()

    mllogger.event('cache_clear')
    mllogger.start('init_start')
    mllogger.event('submission_org', 'Google')
    mllogger.event('submission_platform',
                   'TPUv3-{}'.format(jax.device_count()))
    mllogger.event('submission_division', 'closed')
    mllogger.event('submission_status', 'research')
    mllogger.event('submission_benchmark', 'resnet')
    mllogger.event('train_samples', input_pipeline.TRAIN_IMAGES)
    mllogger.event('eval_samples', input_pipeline.EVAL_IMAGES)

    # The summary writer and its thread exist only on host 0; every later use
    # is guarded by the same host_id check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.output_dir)
        # Write summaries in background thread to avoid blocking on device sync
        summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

    if FLAGS.seed is not None:
        seed = FLAGS.seed
    else:
        # Only host 0 contributes a time-based value; the others contribute 0
        # before the per-host sum.
        seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
        seed = per_host_sum_pmap(seed)

    mllogger.event('seed', int(seed))
    key = random.PRNGKey(seed)

    # A batch size of -1 selects a default based on the device count.
    batch_size = FLAGS.batch_size
    if batch_size == -1:
        if jax.device_count() > 4096:
            batch_size = 65536
        else:
            batch_size = min(128 * jax.device_count(), 32768)
    mllogger.event('global_batch_size', batch_size)
    eval_batch_size = min(input_pipeline.EVAL_IMAGES, 256 * jax.device_count())
    device_batch_size = batch_size // jax.device_count()
    device_eval_batch_size = int(
        math.ceil(eval_batch_size / jax.device_count()))

    model_dtype = jnp.bfloat16 if FLAGS.bfloat16 else jnp.float32
    input_dtype = tf.bfloat16 if FLAGS.bfloat16 else tf.float32

    num_epochs = FLAGS.num_epochs
    if num_epochs is None:
        if batch_size < 32768:
            num_epochs = 56
        elif batch_size < 65536:
            num_epochs = 64
        else:
            num_epochs = 92

    steps_per_epoch = input_pipeline.TRAIN_IMAGES / batch_size
    # match TF submission behavior (round steps per loop up)
    steps_per_loop = int(math.ceil(steps_per_epoch * FLAGS.epochs_per_loop))
    # also apply rounding loop up to next step to "epochs" in LR schedule
    steps_per_epoch *= steps_per_loop / (steps_per_epoch *
                                         FLAGS.epochs_per_loop)

    steps_per_eval = int(
        math.ceil(input_pipeline.EVAL_IMAGES / eval_batch_size))

    base_learning_rate = FLAGS.learning_rate * batch_size / 256.
    beta = FLAGS.momentum
    if beta is None:
        if batch_size < 32768:
            beta = 0.9
        elif batch_size < 65536:
            beta = 0.929
        else:
            beta = 0.9537213777059405
    weight_decay = FLAGS.weight_decay
    if weight_decay is None:
        weight_decay = 2e-4 if batch_size < 32768 else 1e-4

    space_to_depth = FLAGS.space_to_depth
    if space_to_depth is None:
        space_to_depth = device_batch_size <= 8

    image_format = FLAGS.image_format
    if image_format is None:
        if space_to_depth and device_batch_size <= 8:
            image_format = 'HWNC'
        else:
            image_format = 'HWCN'

    # Input shapes start in (batch, height, width, channels) order and are
    # permuted below according to image_format.
    image_size = input_pipeline.IMAGE_SIZE
    if space_to_depth:
        train_input_shape = (device_batch_size, image_size // 2,
                             image_size // 2, 12)
        eval_input_shape = (device_eval_batch_size, image_size // 2,
                            image_size // 2, 12)
    else:
        train_input_shape = (device_batch_size, image_size, image_size, 3)
        eval_input_shape = (device_eval_batch_size, image_size, image_size, 3)
    if image_format == 'HWCN':
        train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 3, 0])
        eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 3, 0])
    elif image_format == 'HWNC':
        train_input_shape = tuple(train_input_shape[i] for i in [1, 2, 0, 3])
        eval_input_shape = tuple(eval_input_shape[i] for i in [1, 2, 0, 3])

    model, state = create_model(key, device_batch_size, image_size,
                                model_dtype, space_to_depth)

    # Two optimizer definitions are built: one for weights and one (with
    # weight_decay=0) for the remaining parameters.
    if FLAGS.lars:
        mllogger.event('opt_name', 'lars')
        mllogger.event('lars_opt_weight_decay', weight_decay)
        mllogger.event('lars_opt_momentum', beta)
        mllogger.event('lars_epsilon', 0)
        weight_opt_def = optim.LARS(base_learning_rate,
                                    beta,
                                    weight_decay=weight_decay)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=False)
        learning_rate_fn = polynomial_learning_rate_fn(batch_size,
                                                       steps_per_epoch,
                                                       num_epochs)
    else:
        mllogger.event('opt_name', 'sgd')
        mllogger.event('sgd_opt_momentum', beta)
        weight_opt_def = optim.Momentum(base_learning_rate,
                                        beta,
                                        weight_decay=weight_decay,
                                        nesterov=True)
        other_opt_def = optim.Momentum(base_learning_rate,
                                       beta,
                                       weight_decay=0,
                                       nesterov=True)
        learning_rate_fn = piecewise_learning_rate_fn(base_learning_rate,
                                                      steps_per_epoch,
                                                      num_epochs)

    # Parameters whose key contains 'bias' or 'scale' go to the "other"
    # optimizer; everything else goes to the weight optimizer.
    def filter_weights(key, _):
        return 'bias' not in key and 'scale' not in key

    def filter_other(key, _):
        return 'bias' in key or 'scale' in key

    weight_traversal = optim.ModelParamTraversal(filter_weights)
    other_traversal = optim.ModelParamTraversal(filter_other)
    optimizer_def = optim.MultiOptimizer((weight_traversal, weight_opt_def),
                                         (other_traversal, other_opt_def))
    optimizer = optimizer_def.create(model)
    del model  # do not keep a copy of the initial model

    optimizer = broadcast(optimizer)
    state = broadcast(state)
    empty_metrics = broadcast({'samples': 0, 'loss': 0., 'accuracy': 0})

    p_allreduce_metrics = jax.pmap(allreduce_metrics, axis_name='batch')

    p_sync_batchnorm_stats = jax.pmap(sync_batchnorm_stats, axis_name='batch')

    # One training step driven from the host: read a batch from the infeed
    # and run train_step on it.
    def host_loop_train_step(optimizer, state, metrics):
        token = lax.create_token(optimizer.state[0].step)
        batch, token = lax.infeed(token,
                                  shape=(jax.ShapedArray(
                                      train_input_shape, model_dtype),
                                         jax.ShapedArray((device_batch_size, ),
                                                         jnp.int32)))
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn,
                                               image_format, space_to_depth)
        return optimizer, state, metrics

    p_host_loop_train_step = jax.pmap(host_loop_train_step,
                                      axis_name='batch',
                                      in_axes=(None, 0, 0))

    # One evaluation step: read a batch from the infeed and run eval_step.
    def host_loop_eval_step(model, state, metrics):
        token = lax.create_token(metrics['samples'])
        batch, token = lax.infeed(
            token,
            shape=(jax.ShapedArray(eval_input_shape, model_dtype),
                   jax.ShapedArray((device_eval_batch_size, ), jnp.int32)))
        metrics = eval_step(model, state, batch, metrics, image_format,
                            space_to_depth)
        return metrics

    p_host_loop_eval_step = jax.pmap(host_loop_eval_step,
                                     axis_name='batch',
                                     in_axes=(None, None, 0))

    # Device-side training loop: lax.while_loop keeps stepping while the
    # current step still belongs to the given loop index.
    def device_train_loop_cond(args):
        _, _, _, _, step, loop = args
        return step // steps_per_loop == loop

    def device_train_loop_body(args):
        optimizer, state, metrics, token, step, loop = args
        batch, token = lax.infeed(token,
                                  shape=(jax.ShapedArray(
                                      train_input_shape, model_dtype),
                                         jax.ShapedArray((device_batch_size, ),
                                                         jnp.int32)))
        optimizer, state, metrics = train_step(optimizer, state, batch,
                                               metrics, learning_rate_fn,
                                               image_format, space_to_depth)
        step += 1
        return optimizer, state, metrics, token, step, loop

    def device_train_loop(optimizer, state, metrics, step, loop):
        token = lax.create_token(step)
        optimizer, state, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, state, metrics, token, step, loop))
        state = sync_batchnorm_stats(state)
        metrics = allreduce_metrics(metrics)
        return optimizer, state, metrics, step

    p_train_loop = jax.pmap(device_train_loop,
                            axis_name='batch',
                            in_axes=(None, None, 0, None, None))

    # BEGIN GOOGLE-INTERNAL
    def maybe_start_xprof(seconds):
        if jax.host_id() == 0 and FLAGS.xprof:
            xprof = xprof_session.XprofSession()
            xprof.start_session('REDACTED', True, 2)

            def sleep_and_end_xprof():
                time.sleep(seconds)
                logging.info(
                    'Xprof URL: %s',
                    xprof.end_session_and_get_url(
                        tag='flax resnet, {} devices, batch {} per device'.
                        format(jax.device_count(), device_batch_size)))

            thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof)

    # END GOOGLE-INTERNAL

    # Optional warm-up: call each pmapped function once, feeding zero-filled
    # batches through the infeed where a step expects one.
    if FLAGS.precompile:
        logging.info('precompiling step/loop functions')
        if FLAGS.device_loop:
            # the device training loop condition will immediately be false
            p_train_loop(unbroadcast(optimizer), unbroadcast(state),
                         empty_metrics, jnp.array(0, dtype=jnp.int32), 1)
        else:
            for device in jax.local_devices():
                images = np.zeros(train_input_shape, model_dtype)
                labels = np.zeros((device_batch_size, ), np.int32)
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
            p_host_loop_train_step(unbroadcast(optimizer), state,
                                   empty_metrics)
            p_sync_batchnorm_stats(state)
        for device in jax.local_devices():
            images = np.zeros(eval_input_shape, model_dtype)
            labels = np.zeros((device_eval_batch_size, ), np.int32)
            infeed_pool.submit(
                partial(device.transfer_to_infeed, (images, labels)))
        p_host_loop_eval_step(unbroadcast(optimizer.target),
                              unbroadcast(state), empty_metrics)
        p_allreduce_metrics(empty_metrics)['accuracy'].block_until_ready()
        logging.info('finished precompiling')

    # BEGIN GOOGLE-INTERNAL
    maybe_start_xprof(20)
    # END GOOGLE-INTERNAL
    # Dataset iterators are created only when real data is used; the loops
    # below read them only in the non-fake-data branches.
    if not FLAGS.fake_data:
        logging.info('constructing datasets')
        # pylint: disable=g-complex-comprehension
        train_ds, eval_ds = [
            input_pipeline.load_split(
                device_batch_size if train else device_eval_batch_size,
                dtype=input_dtype,
                train=train,
                image_format=image_format,
                space_to_depth=space_to_depth,
                cache_uncompressed=jax.device_count() > 64)
            for train in (True, False)
        ]
        logging.info('constructing dataset iterators')
        train_iter = iter(train_ds)
        eval_iter = iter(eval_ds)

    local_devices = jax.local_devices()
    host_step, device_step = 0, broadcast(0)
    mllogger.end('init_stop')
    mllogger.start('run_start')
    mllogger.start('block_start',
                   metadata={
                       'first_epoch_num': 1,
                       'epoch_count': FLAGS.epochs_per_loop
                   })
    for loop in range(int(math.ceil(num_epochs / FLAGS.epochs_per_loop)) + 2):
        # BEGIN GOOGLE-INTERNAL
        if loop == 10: maybe_start_xprof(1)
        # END GOOGLE-INTERNAL
        metrics = empty_metrics
        if FLAGS.device_loop:
            optimizer, state, metrics, device_step = p_train_loop(
                unbroadcast(optimizer), unbroadcast(state), metrics,
                unbroadcast(device_step), loop)
        # The host always enqueues one batch per local device per step; in
        # host-loop mode it also launches the train step itself.
        while int(host_step // steps_per_loop) == loop:
            if not FLAGS.device_loop:
                optimizer, state, metrics = p_host_loop_train_step(
                    unbroadcast(optimizer), state, metrics)
            # Throttle: wait while more than 100 transfers are queued.
            # pylint: disable=protected-access
            while infeed_pool._work_queue.qsize() > 100:
                time.sleep(0.01)
            for device in local_devices:
                if FLAGS.fake_data:
                    images = np.zeros(train_input_shape, model_dtype)
                    labels = np.zeros((device_batch_size, ), np.int32)
                else:
                    # pylint: disable=protected-access
                    images, labels = jax.tree_map(lambda x: x._numpy(),
                                                  next(train_iter))
                assert images.shape == train_input_shape and labels.dtype == jnp.int32
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
            host_step += 1
        epoch = (loop + 1) * FLAGS.epochs_per_loop
        if FLAGS.train_metrics:
            if not FLAGS.device_loop:
                metrics = p_allreduce_metrics(metrics)
            if jax.host_id() == 0:
                summary_thread.submit(
                    partial(write_summary, summary_writer, metrics, 'train',
                            epoch))
        if not FLAGS.device_loop:
            state = p_sync_batchnorm_stats(state)
        # Evaluation pass: the train metrics are discarded and accumulation
        # restarts from empty_metrics.
        metrics = empty_metrics
        for _ in range(steps_per_eval):
            metrics = p_host_loop_eval_step(unbroadcast(optimizer.target),
                                            unbroadcast(state), metrics)
            for device in local_devices:
                if FLAGS.fake_data:
                    images = np.zeros(eval_input_shape, model_dtype)
                    labels = np.zeros((device_eval_batch_size, ), np.int32)
                else:
                    # pylint: disable=protected-access
                    images, labels = jax.tree_map(lambda x: x._numpy(),
                                                  next(eval_iter))
                assert images.shape == eval_input_shape and labels.dtype == jnp.int32, \
                    'images.shape={}'.format(images.shape)
                infeed_pool.submit(
                    partial(device.transfer_to_infeed, (images, labels)))
        metrics = p_allreduce_metrics(metrics)
        if jax.host_id() == 0:
            summary_thread.submit(
                partial(write_summary, summary_writer, metrics, 'eval', epoch))
    # Wait until computations are done before exiting
    p_allreduce_metrics(metrics)['accuracy'].block_until_ready()
    if jax.host_id() == 0:
        summary_thread.shutdown()
        # NOTE(review): DONE is a module-level name not visible here;
        # presumably set once the run completes successfully - confirm.
        if not DONE:
            mllogger.end('run_stop', metadata={'status': 'aborted'})
Exemple #16
0
def create_token(x):
    """Create an XLA token value with no preconditions for sequencing effects.

    mpi4jax-specific variant: it behaves like the jax version, and it is
    additionally possible to compute its gradient.

    Experimental.

    Args:
      x: a dummy argument whose only role is to tie the CreateToken operator
         into a trace; its value is ignored.
    """
    # Binding on the dummy argument is what ties the operator into the trace.
    token = create_token_p.bind(x)
    return token


# Define the "create_token_mpi4jax" primitive: its implementation is
# xla.apply_primitive specialized to this primitive, abstract evaluation
# always yields abstract_token, and the XLA translation ignores its operand
# and emits a CreateToken op.
create_token_p = Primitive("create_token_mpi4jax")
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
create_token_p.def_abstract_eval(lambda _: abstract_token)
xla.translations[create_token_p] = lambda c, _: xla_client.ops.CreateToken(c)


def create_token_value_and_jvp(in_args, tan_args):
    """JVP rule for `create_token_p`.

    Expects exactly one primal in `in_args`; `tan_args` is not used. Returns
    the token created from the primal, paired with `zeros_like_array` of the
    primal as the tangent.
    """
    [x] = in_args
    return create_token(x), zeros_like_array(x)


# Register the JVP rule for the create_token primitive.
ad.primitive_jvps[create_token_p] = create_token_value_and_jvp
Exemple #17
0
 def test_jvp(self, f, args):
     """Check `jvp` of `f` at `args` using `jtu.check_jvp`."""
     f_jvp = partial(jvp, f)
     jtu.check_jvp(f, f_jvp, args)
Exemple #18
0
 def test_vjp(self, f, args):
   """Check `vjp` of `f` at `args` with per-dtype tolerances."""
   f_vjp = partial(vjp, f)
   rtol = {np.float32: 3e-1, np.float64: 1e-5}
   atol = {np.float32: 1e-2, np.float64: 1e-5}
   jtu.check_vjp(f, f_vjp, args, rtol=rtol, atol=atol)
Exemple #19
0
 def test_vjp(self, f, args):
     """Check `vjp` of `f` at `args` using `jtu.check_vjp`."""
     f_vjp = partial(vjp, f)
     jtu.check_vjp(f, f_vjp, args)
Exemple #20
0
                      init_val), init_val_bd


def _jaxtupletree_select(pred, on_true, on_false):
    """Select between two tuple-trees, recursing through tuple structure.

    The abstract value of `on_true` decides the handling: an exact
    `core.AbstractTuple` is processed element by element and repacked, an
    `UnshapedArray` instance is passed to `lax.select`, and anything else
    raises `TypeError` carrying the abstract value.
    """
    aval = core.get_aval(on_true)
    if type(aval) is core.AbstractTuple:
        select_elt = partial(_jaxtupletree_select, pred)
        return core.pack(map(select_elt, on_true, on_false))
    if isinstance(aval, UnshapedArray):
        return lax.select(pred, on_true, on_false)
    raise TypeError(aval)


# Create the 'while' primitive and attach its rules: the implementation
# (xla.apply_primitive specialized to this primitive), the abstract
# evaluation rule, the entry in xla.translations, and the batching rule.
while_p = lax.Primitive('while')
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
xla.translations[while_p] = _while_loop_translation_rule
batching.primitive_batchers[while_p] = _while_loop_batching_rule

### cond


def cond(pred, true_operand, true_fun, false_operand, false_fun):
    def trace_jaxpr(fun, operand):
        """Flatten ``operand`` and trace ``fun`` on it to a jaxpr.

        Returns, in order: the flattened operand, the traced jaxpr, its
        constants, the ``pvout`` value produced by ``pe.trace_to_jaxpr``, and
        the output tree of the flattened function.
        """
        op_flat, in_tree = pytree_to_flatjaxtuple(operand)
        # Wrap `fun` so it takes/returns the flat form matching `in_tree`.
        fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(
            lu.wrap_init(fun), (in_tree, ))
        # Trace with the abstractified flat operand as the single input.
        jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat,
                                                 (_abstractify(op_flat), ))
        return op_flat, jaxpr, consts, pvout, out_tree
Exemple #21
0
def _defidentity(prim, argnum=0):
    """Register ``_identity_papply`` bound to ``prim`` and ``argnum`` as the papply rule of ``prim``."""
    rule = partial(_identity_papply, prim, argnum)
    parallel.papply_primitive_rules[prim] = rule
Exemple #22
0
def _promote_aval_rank(n, xs):
    """Return ``xs`` with a new leading dimension of size ``n``.

    ``xs`` must be a ``core.AbstractValue``. An ``AbstractTuple`` is handled
    recursively, element by element; anything else is turned into a
    ``ShapedArray`` of shape ``(n,) + xs.shape`` with the dtype of ``xs``.
    """
    assert isinstance(xs, core.AbstractValue)
    if not isinstance(xs, core.AbstractTuple):
        promoted_shape = (n, ) + xs.shape
        return ShapedArray(promoted_shape, xs.dtype)
    promote_child = partial(_promote_aval_rank, n)
    return core.AbstractTuple(map(promote_child, xs))
Exemple #23
0
 def test_jvp(self, f, args):
     """Run ``jtu.check_jvp`` for ``f`` at ``args`` with a float32 relative tolerance."""
     jvp_of_f = partial(jvp, f)
     rel_tol = {np.float32: 3e-2}
     jtu.check_jvp(f, jvp_of_f, args, rtol=rel_tol)
Exemple #24
0
def _index_arrays(i, aval, xs):
    """Index ``xs`` at position ``i``, following the structure of ``aval``.

    ``aval`` must be a ``core.AbstractValue``. For an ``AbstractTuple`` the
    function recurses over matching elements of ``aval`` and ``xs`` and
    re-packs the results; otherwise it calls ``lax.dynamic_index_in_dim``
    with ``keepdims=False``.
    """
    assert isinstance(aval, core.AbstractValue)
    if not isinstance(aval, core.AbstractTuple):
        return lax.dynamic_index_in_dim(xs, i, keepdims=False)
    index_child = partial(_index_arrays, i)
    return core.pack(map(index_child, aval, xs))
Exemple #25
0
    out = lapack.jax_syevd(c, operand, lower=lower)
    return c.Tuple(c.GetTupleElement(out, 0), c.GetTupleElement(out, 1))
  else:
    raise NotImplementedError(
        "Only unbatched eigendecomposition is implemented on CPU")

# Primitive named 'eigh' and its rule registrations.
eigh_p = Primitive('eigh')
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
# Default XLA translation, plus a separate rule registered under the 'Host'
# backend key.
xla.translations[eigh_p] = eigh_translation_rule
xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule



# dtype rule for 'triangular_solve': binop_dtype_rule pre-bound with
# _input_dtype, one `_float | _complex` entry per operand, and the name
# 'triangular_solve'.
# NOTE(review): presumably the name is used in error messages and the
# `_float | _complex` entries are the accepted dtype sets -- confirm against
# binop_dtype_rule.
triangular_solve_dtype_rule = partial(
    binop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
    'triangular_solve')

def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
  """Validate the operand shapes of triangular_solve.

  Raises:
    TypeError: if ``a`` has fewer than two dimensions, if the last two
      dimensions of ``a`` differ in size, or if the leading (batch)
      dimensions of ``a`` and ``b`` are not equal.
  """
  if a.ndim < 2:
    msg = "triangular_solve requires a.ndim to be at least 2, got {}."
    raise TypeError(msg.format(a.ndim))
  # `a` must be square in its last two dimensions.
  if a.shape[-1] != a.shape[-2]:
    msg = ("triangular_solve requires the last two dimensions of a to be equal "
           "in size, got a.shape of {}.")
    raise TypeError(msg.format(a.shape))
  # Everything except the last two dimensions must match between a and b.
  if a.shape[:-2] != b.shape[:-2]:
    msg = ("triangular_solve requires both arguments to have the same number "
           "of dimensions and equal batch dimensions, got {} and {}.")
    raise TypeError(msg.format(a.shape, b.shape))
  # Axis index chosen by `left_side`: -2 when solving from the left, else -1.
  common_dim = -2 if left_side else -1
Exemple #26
0
def _update_arrays(i, aval, xs, x):
    """Write ``x`` into ``xs`` at position ``i`` along axis 0, following ``aval``.

    ``aval`` must be a ``core.AbstractValue``. For an ``AbstractTuple`` the
    function recurses over matching elements of ``aval``, ``xs`` and ``x``
    and re-packs the results; otherwise ``x`` gets a new leading axis and is
    passed to ``lax.dynamic_update_index_in_dim``.
    """
    assert isinstance(aval, core.AbstractValue)
    if not isinstance(aval, core.AbstractTuple):
        # Add a leading axis so the update has the same rank as `xs`.
        update = x[None, ...]
        return lax.dynamic_update_index_in_dim(xs, update, i, axis=0)
    update_child = partial(_update_arrays, i)
    return core.pack(map(update_child, aval, xs, x))
Exemple #27
0
def standard_pmap_primitive(name):
    """Create a ``core.Primitive`` called ``name`` with the standard pmap rules.

    The implementation is ``pxla.apply_parallel_primitive`` bound to the new
    primitive, and abstract evaluation returns the first operand unchanged.
    """

    def _first_operand(x, *args, **params):
        # The abstract output is the abstract value of the first operand.
        return x

    prim = core.Primitive(name)
    impl = partial(pxla.apply_parallel_primitive, prim)
    prim.def_impl(impl)
    prim.def_abstract_eval(_first_operand)
    return prim
Exemple #28
0
        nan = xb.constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
    else:
        nan = xb.constant(c, np.array(np.nan, dtype=dtype))
    return xops.Broadcast(nan, shape.dimensions())


def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
    """Translation rule for Cholesky shared by the CPU and GPU backends.

    Args:
      potrf_impl: callable invoked as ``potrf_impl(c, operand, lower=True)``
        and returning a ``(result, info)`` pair.
      c: the XLA builder.
      operand: the input operand.

    Returns:
      The output of ``_broadcasting_select`` choosing between ``result`` and
      ``_nan_like(c, result)``, keyed on whether ``info`` equals zero.
    """
    operand_shape = c.get_shape(operand)
    # Everything but the trailing two (matrix) dimensions.
    batch_dims = operand_shape.dimensions()[:-2]
    result, info = potrf_impl(c, operand, lower=True)

    zero = xops.ConstantLiteral(c, np.array(0, np.int32))
    ok = xops.Eq(info, zero)
    # Append two unit dimensions so the flag lines up with the matrix axes.
    ok_mask = xops.Reshape(ok, batch_dims + (1, 1))
    fallback = _nan_like(c, result)
    return _broadcasting_select(c, ok_mask, result, fallback)


# Both backends share _cholesky_cpu_gpu_translation_rule; only the potrf
# implementation bound as its first argument differs (lapack.potrf for 'cpu',
# cusolver.potrf for 'gpu').
xla.backend_specific_translations['cpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, lapack.potrf)

xla.backend_specific_translations['gpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, cusolver.potrf)

# Asymmetric eigendecomposition


def eig_impl(operand, *, compute_left_eigenvectors,
             compute_right_eigenvectors):
    """Evaluate ``eig_p`` on ``operand`` through ``xla.apply_primitive``.

    Both keyword-only flags are forwarded unchanged as primitive parameters.
    """
    params = dict(
        compute_left_eigenvectors=compute_left_eigenvectors,
        compute_right_eigenvectors=compute_right_eigenvectors,
    )
    return xla.apply_primitive(eig_p, operand, **params)
Exemple #29
0
def _defvectorized(prim):
    """Register ``_vectorized_papply`` bound to ``prim`` as the papply rule of ``prim``."""
    rule = partial(_vectorized_papply, prim)
    parallel.papply_primitive_rules[prim] = rule
Exemple #30
0
def flat_propagate(tree, *flat_invals):
  """Two-step generator that unflattens inputs and flattens propagated cells.

  First it rebuilds ``(invals, outvals)`` from ``flat_invals`` using ``tree``
  and yields ``((invals, outvals), {})``. It then expects an ``(env, state)``
  pair to be sent back, reads the cells of ``env.jaxpr``'s input and output
  variables from ``env``, and yields the flattened
  ``(new_incells, new_outcells, state)`` together with its tree definition.
  """
  unflattened = tree_util.tree_unflatten(tree, flat_invals)
  invals, outvals = unflattened
  # First yield: positional args and (empty) keyword args for the wrapped call.
  env, state = yield (invals, outvals), {}
  in_cells = [env.read(v) for v in env.jaxpr.invars]
  out_cells = [env.read(v) for v in env.jaxpr.outvars]
  # Second yield: the (leaves, treedef) pair returned by tree_flatten.
  yield tree_util.tree_flatten((in_cells, out_cells, state))


def call_rule(prim, incells, outcells, **params):
  """Propagate rule for call primitives.

  The first entry of ``incells`` is the function being called; the remaining
  entries and ``outcells`` are flattened and passed through ``prim.bind``
  with the function wrapped by ``flat_propagate``. The flat outputs are
  unflattened with the tree produced by that wrapper.
  """
  fun = incells[0]
  arg_cells = incells[1:]
  flat_vals, in_tree = tree_util.tree_flatten((arg_cells, outcells))

  bind_params = {**params}
  if 'donated_invars' in bind_params:
    # One flag per flattened value; nothing is marked as donated.
    bind_params['donated_invars'] = (False,) * len(flat_vals)

  fun, out_tree_thunk = flat_propagate(fun, in_tree)
  flat_out = prim.bind(fun, *flat_vals, **bind_params)
  # The output tree is only available after the bind has run.
  return tree_util.tree_unflatten(out_tree_thunk(), flat_out)


# Maps each supported call primitive to `call_rule` pre-bound to that
# primitive.
default_call_rules = {
    call_prim: jax_util.partial(call_rule, call_prim)
    for call_prim in (
        xla.xla_call_p,
        jax_core.call_p,
        pe.remat_call_p,
        harvest.nest_p,
    )
}