コード例 #1
0
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))
        rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split(
            rng_model_local, num=3)

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={
                                        'dropout':
                                        rng_model_local,
                                        'diag_noise_samples':
                                        diag_noise_rng,
                                        'standard_norm_noise_samples':
                                        standard_noise_rng
                                    })
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                               opt.target, images, labels,
                                               config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements
コード例 #2
0
    def update_fn(opt, lr, images, labels, rng):
        """Update step. Copy to deterministic_utils.py whenever changes are made!"""
        measurements = {}

        # Split rng and return next_rng for the following step.
        rng, next_rng = jax.random.split(rng, 2)
        rng_local = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_local})
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            loss = getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)
            return loss, logits

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        (l, logits), g = train_utils.accumulate_gradient(
            jax.value_and_grad(loss_fn, has_aux=True), opt.target, images,
            labels, config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')
        measurements['training_loss'] = l

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        top1_idx = jnp.argmax(logits, axis=1)
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        prec1 = jax.lax.psum(jnp.sum(top1_correct),
                             axis_name='batch') / batch_size
        measurements['training_prec@1'] = prec1
        measurements['learning_rate'] = lr
        return opt, next_rng, measurements
コード例 #3
0
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""
        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                               opt.target, images, labels,
                                               config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or grad_clip_norm is not None:
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if grad_clip_norm is not None:
            g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        decay_rules = weight_decay or []
        if isinstance(decay_rules, numbers.Number):
            decay_rules = [('.*kernel.*', decay_rules)]
        sched_m = lr / config.lr.base if config.get(
            'weight_decay_decouple') else lr

        def decay_fn(v, wd):
            return (1.0 - sched_m * wd) * v

        opt = opt.replace(target=train_utils.tree_map_with_regex(
            decay_fn, opt.target, decay_rules))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements