Example #1
    def _check_args_and_kwargs(self, *args, **kwargs):
        for k in kwargs.keys():
            if k not in GraphV2.allowed_kwargs:
                raise ValueError(f"unknown argument {k}")

        consumed_args = 0
        if kwargs.get("inputs", False):
            inputs = jax.tree_flatten(kwargs["inputs"])[0]
        else:
            if len(args) == 0:
                raise ValueError("Input is not passed correctly")
            else:
                inputs = jax.tree_flatten(args[consumed_args])[0]
                consumed_args += 1

        if kwargs.get("outputs", False):
            outputs = jax.tree_flatten(kwargs["outputs"])[0]
        else:
            if len(args) == 0:
                raise ValueError("Output is not passed correctly")
            else:
                outputs = jax.tree_flatten(args[consumed_args])[0]

        self._output_names = [o.name for o in outputs]
        return inputs, outputs
Example #2
 def vectors_to_blocks(
     self,
     parameter_structured_vector: Any,
 ) -> Sequence[BlockVector]:
     """Splits the parameters to values for the corresponding blocks."""
     in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars)
     params_vars = in_vars[self.params_index]
     params_vars_flat = jax.tree_flatten(params_vars)[0]
     params_values_flat = jax.tree_flatten(parameter_structured_vector)[0]
     assert len(params_vars_flat) == len(params_values_flat)
     params_dict = dict(zip(params_vars_flat, params_values_flat))
     per_block_vectors = []
     for eqn in self._layer_tags:
         if eqn.primitive.name == "generic_tag":
             block_vars = eqn.invars
         else:
             block_vars = eqn.primitive.split_all_inputs(eqn.invars)[2]
         per_block_vectors.append(
             tuple(params_dict.pop(v) for v in block_vars))
     if params_dict:
         raise ValueError(
             f"From the parameters the following structure is not "
             f"assigned to any block: {params_dict}. Most likely "
             f"this part of the parameters is not part of the graph "
             f"reaching the losses.")
     return tuple(per_block_vectors)
Example #3
 def assertStructureAllClose(self, x, y, **kwargs):
     x_v, x_tree = jax.tree_flatten(x)
     y_v, y_tree = jax.tree_flatten(y)
     self.assertEqual(x_tree, y_tree)
     for xi, yi in zip(x_v, y_v):
         self.assertEqual(xi.shape, yi.shape)
         self.assertAllClose(xi, yi, check_dtypes=True, **kwargs)
Example #4
def tree_allclose(*trees):
    """
    Determines if all elements of `trees`

    a) have the same tree structure
    and
    b) the corresponding leaves all fulfill ``np.allclose(leaf1, leaf2)``

    such that the trees are, up to numerical tolerances of `np.allclose`, equal
    """
    if len(trees) > 2:
        return tree_allclose(trees[0], trees[1]) and tree_allclose(*trees[1:])

    if len(trees) < 2:
        return True

    tree1, tree2 = trees
    _, tree_def1 = tree_flatten(tree1)
    _, tree_def2 = tree_flatten(tree2)

    if tree_def1 != tree_def2:
        return False

    return np.all(
        tree_flatten(
            tree_multimap(lambda arr1, arr2: np.allclose(arr1, arr2), tree1,
                          tree2))[0])
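A hypothetical usage sketch for tree_allclose (not part of the original source; it assumes numpy is imported as np and that tree_flatten / tree_multimap are imported from JAX, as the function itself expects):

import numpy as np

a = {"w": np.ones((2, 2)), "b": np.zeros(3)}
b = {"w": np.ones((2, 2)) + 1e-12, "b": np.zeros(3)}
assert tree_allclose(a, b)                           # same structure, leaves equal within tolerance
assert not tree_allclose(a, {"w": np.ones((2, 2))})  # tree structures differ
assert tree_allclose(a)                              # fewer than two trees is trivially True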
Example #5
    def update(self, model_gradient, old_optimizer, new_optimizer):
        """Computes a number of statistics from the model params and update.

        Statistics computed:
          Per layer update variances and norms.
          Per layer gradient variance and norms.
          Per layer param norms.
          Ratio of parameter norm to update and update variance.

        Args:
          model_gradient: A pytree of the same shape as the model_params pytree that
            was used when the metrics_grabber object was created.
          old_optimizer: The optimizer before the param update.
          new_optimizer: The optimizer after the param update.

        Returns:
          An updated class object.
        """
        grads_flat, treedef = jax.tree_flatten(model_gradient)
        new_params_flat, _ = jax.tree_flatten(new_optimizer.target.params)
        old_params_flat, _ = jax.tree_flatten(old_optimizer.target.params)

        # flatten_up_to here is needed to avoid flattening the _MetricsLeafState
        # nodes.
        state_flat = treedef.flatten_up_to(self.state)
        new_states_flat = [
            _update_param_stats(state, grad, new_param - old_param, new_param,
                                self.config)
            for state, grad, old_param, new_param in zip(
                state_flat, grads_flat, old_params_flat, new_params_flat)
        ]

        return self.replace(state=jax.tree_unflatten(treedef, new_states_flat))
Example #6
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            accuracy = jnp.mean(
                jnp.equal(jnp.argmax(logits, axis=-1),
                          jnp.argmax(labels, axis=-1)))
            return getattr(train_utils,
                           config.get('loss',
                                      'sigmoid_xent'))(logits=logits,
                                                       labels=labels), accuracy

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (l, train_accuracy), g = grad_fn(opt.target, images, labels)
        l, g = jax.lax.pmean((l, g), axis_name='batch')
        measurements['accuracy'] = train_accuracy

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        decay_rules = config.get('weight_decay', []) or []
        if isinstance(decay_rules, numbers.Number):
            decay_rules = [('.*kernel.*', decay_rules)]
        sched_m = lr / config.lr.base if config.get(
            'weight_decay_decouple') else lr

        def decay_fn(v, wd):
            return (1.0 - sched_m * wd) * v

        opt = opt.replace(target=train_utils.tree_map_with_regex(
            decay_fn, opt.target, decay_rules))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements
Example #7
    def _check_output(self, actual, expected):
        expected_leaves, expected_structure = jax.tree_flatten(expected)
        actual_leaves, actual_structure = jax.tree_flatten(actual)
        assert all(isinstance(x, jnp.ndarray)
                   for x in expected_leaves), "bad example_data"

        if actual_structure != expected_structure:
            raise TypeError(
                f"func has bad return tree_structure, expected: {expected_structure}, "
                f"got: {actual_structure}")

        if not all(isinstance(x, jnp.ndarray) for x in actual_leaves):
            bad_types = tuple(
                type(x) for x in actual_leaves
                if not isinstance(x, jnp.ndarray))
            raise TypeError(
                "all leaves of dist_params must be of type: jax.numpy.ndarray, "
                f"found leaves of type: {bad_types}")

        if not all(a.shape == b.shape
                   for a, b in zip(actual_leaves, expected_leaves)):
            shapes_tree = jax.tree_multimap(
                lambda a, b:
                f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
                actual, expected)
            raise TypeError(
                f"found leaves with unexpected shapes: {shapes_tree}")
Example #8
    def test_jit_pytree_return(self):
        @iree.jax.jit
        def apply_sqrt(pytree):
            return jax.tree_map(jnp.sqrt, pytree)

        np.random.seed(0)
        input_tree = {
            "a": [
                normal((2, 3)),
                {
                    "b": normal(3)
                },
            ],
            "c": (
                {
                    "d": [normal(2), normal(3)]
                },
                (normal(1), normal(4)),
            )
        }

        expected = jax.tree_map(jnp.sqrt, input_tree)
        expected_arrays, expected_tree = jax.tree_flatten(expected)
        result = apply_sqrt(input_tree)
        result_arrays, result_tree = jax.tree_flatten(result)

        self.assertEqual(expected_tree, result_tree)
        for expected_array, result_array in zip(expected_arrays,
                                                result_arrays):
            np.testing.assert_allclose(expected_array, result_array,
                                       **TOLERANCE)
Example #9
def __local_cost_and_grad_function(local_cost_fun, dtype, logpsi, pars, *args):
    costfun_outdtype = _outdtype[local_cost_fun]
    lcfun_u = _unjitted_fun[local_cost_fun]

    if dtype is complex:
        der_local_cost_fun = jax.value_and_grad(lcfun_u, argnums=1, holomorphic=True)

        return der_local_cost_fun(logpsi, pars, *args)
    else:
        if costfun_outdtype is complex:
            _costfun_re = lambda w: lcfun_u(logpsi, w, *args).real
            _costfun_im = lambda w: lcfun_u(logpsi, w, *args).imag

            # Pullbacks
            cost_val_re, cost_vjp_re = jax.vjp(_costfun_re, pars)
            cost_val_im, cost_vjp_im = jax.vjp(_costfun_im, pars)

            cost_val = cost_val_re + 1.0j * cost_val_im

            primal = jax.numpy.ones(cost_val.shape)

            # Apply pullbacks to primals
            cost_grad_re, tree_fun = jax.tree_flatten(cost_vjp_re(primal)[0])
            cost_grad_im, _ = jax.tree_flatten(cost_vjp_im(primal)[0])

            out_flat = [re + 1.0j * im for re, im in zip(cost_grad_re, cost_grad_im)]

            grad_c = jax.tree_unflatten(tree_fun, out_flat)
            return (cost_val, grad_c)
        else:
            der_local_cost_fun = jax.value_and_grad(lcfun_u, argnums=1)
            return der_local_cost_fun(logpsi, pars, *args)
Example #10
def _local_value_and_grad_notcentered_kernel(
    logpsi, pars, vp, mel, v, real_to_complex=False
):
    # A Python-level `if` is fine here even under jit, because this argument is exposed to jit statically.
    if real_to_complex:
        logpsi_vp_r, f_vjp_r = jax.vjp(lambda w: (logpsi(w, vp).real), pars)
        logpsi_vp_j, f_vjp_j = jax.vjp(lambda w: (logpsi(w, vp).imag), pars)

        logpsi_vp = logpsi_vp_r + 1.0j * logpsi_vp_j

        vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))
        vec_r = vec.real
        vec_j = vec.imag

        loc_val = vec.sum()

        vr_grad_r, tree_fun = jax.tree_flatten(f_vjp_r(vec_r)[0])
        vj_grad_r, _ = jax.tree_flatten(f_vjp_r(vec_j)[0])
        vr_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_r)[0])
        vj_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_j)[0])

        r_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_r, vj_grad_r)]
        j_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_j, vj_grad_j)]
        out_flat = [re + 1.0j * im for re, im in zip(r_flat, j_flat)]

        grad_c = jax.tree_unflatten(tree_fun, out_flat)
    else:
        logpsi_vp, f_vjp = jax.vjp(lambda w: logpsi(w, vp), pars)

        vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))

        loc_val = vec.sum()
        grad_c = f_vjp(vec)[0]

    return loc_val, grad_c
Example #11
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))
        rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split(
            rng_model_local, num=3)

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={
                                        'dropout':
                                        rng_model_local,
                                        'diag_noise_samples':
                                        diag_noise_rng,
                                        'standard_norm_noise_samples':
                                        standard_noise_rng
                                    })
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                               opt.target, images, labels,
                                               config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements
Example #12
    def update_fn(opt, lr, images, labels, rng):
        """Update step. Copy to deterministic_utils.py whenever changes are made!"""
        measurements = {}

        # Split rng and return next_rng for the following step.
        rng, next_rng = jax.random.split(rng, 2)
        rng_local = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_local})
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            loss = getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)
            return loss, logits

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        (l, logits), g = train_utils.accumulate_gradient(
            jax.value_and_grad(loss_fn, has_aux=True), opt.target, images,
            labels, config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')
        measurements['training_loss'] = l

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        top1_idx = jnp.argmax(logits, axis=1)
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        prec1 = jax.lax.psum(jnp.sum(top1_correct),
                             axis_name='batch') / batch_size
        measurements['training_prec@1'] = prec1
        measurements['learning_rate'] = lr
        return opt, next_rng, measurements
Example #13
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""
        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                               opt.target, images, labels,
                                               config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or grad_clip_norm is not None:
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if grad_clip_norm is not None:
            g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        decay_rules = weight_decay or []
        if isinstance(decay_rules, numbers.Number):
            decay_rules = [('.*kernel.*', decay_rules)]
        sched_m = lr / config.lr.base if config.get(
            'weight_decay_decouple') else lr

        def decay_fn(v, wd):
            return (1.0 - sched_m * wd) * v

        opt = opt.replace(target=train_utils.tree_map_with_regex(
            decay_fn, opt.target, decay_rules))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements
Example #14
def tree_vmap(f, lst):
    stacked = jax.tree_map(lambda args: jnp.stack(args), lst)
    out_stacked = jax.vmap(f)(stacked)
    _, outer_treedef = jax.tree_flatten([None] * len(lst))
    _, inner_treedef = jax.tree_flatten(out_stacked)
    out_unstacked_transposed = jax.tree_map(list, out_stacked)
    out_unstacked = jax.tree_transpose(
        outer_treedef, inner_treedef, out_unstacked_transposed
    )
    return out_unstacked
Example #15
def check_eq(xs, ys, atol=None, rtol=None):
    xs_leaves, xs_tree = jax.tree_flatten(xs)
    ys_leaves, ys_tree = jax.tree_flatten(ys)
    assert xs_tree == ys_tree, "Tree shapes don't match."
    assert jax.tree_util.tree_all(
        jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape,
                          xs_leaves, ys_leaves)), "Leaves' shapes don't match."
    assert jax.tree_multimap(
        partial(_assert_numpy_allclose, atol=atol, rtol=rtol), xs_leaves,
        ys_leaves)
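A hedged usage sketch for check_eq (assuming numpy is imported as np; the partial and _assert_numpy_allclose helpers it relies on come from the original module):

import numpy as np

# Passes silently: identical tree structure, leaf shapes and values.
check_eq({"w": np.ones((2, 3)), "b": np.zeros(3)},
         {"w": np.ones((2, 3)), "b": np.zeros(3)})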
Example #16
    def __call__(self, params: Params) -> float:
        """Evaluates the regularizer."""
        params = self._preprocess_fn(params)
        leaves, _ = jax.tree_flatten(params)

        if self._param_weights:
            param_weight_leaves, _ = jax.tree_flatten(self._param_weights)
            return sum(
                jnp.vdot(pw, jnp.square(x))
                for pw, x in zip(param_weight_leaves, leaves)) * self._weight

        return sum(jnp.vdot(x, x) for x in leaves) * self._weight
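The flatten-and-accumulate pattern above can also be written as a standalone helper; a minimal sketch with hypothetical names, not part of the regularizer class:

import jax
import jax.numpy as jnp

def l2_penalty(params, weight=1e-4):
    # Flatten the parameter pytree and sum the weighted squared L2 norms of its leaves.
    leaves, _ = jax.tree_flatten(params)
    return weight * sum(jnp.vdot(x, x) for x in leaves)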
Example #17
  def test_load_state(self):
    temp_dir = self.get_temp_dir()
    path = os.path.join(temp_dir, 'state')
    init_state = test_util.create_mock_state()
    serialization.save_state(init_state, path)

    state = serialization.load_state(path)

    expected_flat, expected_tree_def = jax.tree_flatten(init_state)
    actual_flat, actual_tree_def = jax.tree_flatten(state)
    for expected_array, actual_array in zip(expected_flat, actual_flat):
      self.assertAllEqual(expected_array, actual_array)
    self.assertEqual(expected_tree_def, actual_tree_def)
Example #18
    def assertAgentParametersEqual(self, agent1: sac_agent.SACAgent,
                                   agent2: sac_agent.SACAgent):
        agent1_params = get_agent_params(agent1)
        agent2_params = get_agent_params(agent2)

        agent1_params, agent1_structure = jax.tree_flatten(agent1_params)
        agent2_params, agent2_structure = jax.tree_flatten(agent2_params)

        self.assertEqual(agent1_structure, agent2_structure,
                         'Parameter structures do not match.')

        for param1, param2 in zip(agent1_params, agent2_params):
            if (param1 != param2).any():
                self.fail(f'Parameters are not equal: {param1}, {param2}')
Example #19
    def test_apply_if_finite(self, opt_builder):
        one = jnp.ones([])
        nan = jnp.array(jnp.nan)

        def fn(x):
            return x * hk.get_parameter('p', [],
                                        init=hk.initializers.Constant(0.))

        fn = hk.without_apply_rng(hk.transform(fn))
        params = fn.init(jax.random.PRNGKey(1905), one)
        opt = wrappers.apply_if_finite(opt_builder(), 2)
        state = opt.init(params)
        grads_fn = jax.grad(self.variant(fn.apply))
        # Do one successful param update
        grads = grads_fn(params, one)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        # We know exactly what should be the value of params since we are
        # effectively using sgd in all cases.
        self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
        self.assertTrue(bool(state.last_finite))
        # Check 2 rejected param updates
        for step in range(2):
            grads = grads_fn(params, nan)
            updates, state = opt.update(grads, state, params)
            params = update.apply_updates(params, updates)
            self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
            self.assertFalse(bool(state.last_finite))
            self.assertEqual(step + 1, int(state.notfinite_count))
        # Next successful param update
        grads = grads_fn(params, one)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
        self.assertTrue(bool(state.last_finite))
        # Again 2 rejected param updates
        for step in range(2):
            grads = grads_fn(params, nan)
            updates, state = opt.update(grads, state, params)
            params = update.apply_updates(params, updates)
            self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
            self.assertFalse(bool(state.last_finite))
            self.assertEqual(step + 1, int(state.notfinite_count))
        # Next param update with NaN is accepted since we reached maximum
        grads = grads_fn(params, nan)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertTrue(bool(jnp.isnan(jax.tree_flatten(params)[0][0])))
        self.assertEqual(5, int(state.total_notfinite))
Example #20
    def train_benchmark(state):

        train_iter = iter(train_data)
        example = next(train_iter)

        params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
            params, net_state, next(rng), opt_state, metrics_state, *example)

        [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]
        while state:
            params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
                params_, net_state_, next(rng), opt_state_, metrics_state_,
                *example)
            example = next(train_iter)
            [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]
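As an aside, the per-leaf blocking above can be collapsed into one call on recent JAX releases; a sketch, assuming a version that provides jax.block_until_ready:

params_ = jax.block_until_ready(params_)  # waits on every leaf of the pytree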
Example #21
 def backward(ctx, *tangents):
     tangents = jax.tree_map(
         map_if(is_torch_tensor)(to_jax_ndarray), tangents)
     tangents = jax.tree_unflatten(ctx.result_tree, tangents[:-1])
     grads = ctx.fun_vjp(tangents)
     return (None, *jax.tree_flatten(
         jax.tree_map(map_if(is_jax_ndarray)(to_torch_tensor), grads))[0])
Example #22
 def forward(ctx, fn, *args):
     args = jax.tree_map(map_if(is_torch_tensor)(to_jax_ndarray), args)
     result, ctx.fun_vjp = jax.vjp(fn, *args)
     result_flat, result_tree = jax.tree_flatten(result)
     ctx.result_tree = result_tree
     return (*jax.tree_map(
         map_if(is_jax_ndarray)(to_torch_tensor), result_flat), result_tree)
Example #23
  def update_fn(grads, state, params):
    """Transform the input gradient and update all statistics.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
    params_flat, treedef = jax.tree_flatten(params)
    stats_flat = treedef.flatten_up_to(state.stats)
    grads_flat = treedef.flatten_up_to(grads)

    new_stats_flat = jax.tree_multimap(
        lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
        stats_flat, params_flat)
    new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
                                              state.count)

    outputs = jax.tree_multimap(
        lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
        new_stats_flat, params_flat)
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_unflatten(treedef, updates_flat)
    new_stats = jax.tree_unflatten(treedef, new_stats_flat)

    new_state = ShampooState(
        count=state.count+1, stats=new_stats)
    return updates, new_state
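A small aside on treedef.flatten_up_to, which the update_fn above uses to flatten state.stats and grads only down to the structure of params (a sketch with toy data, not from the original source):

import jax

params = {"a": 1.0, "b": 2.0}
_, treedef = jax.tree_flatten(params)
# Flattening stops at the level of params' structure, so each value stays a whole subtree.
assert treedef.flatten_up_to({"a": (1, 2), "b": (3, 4)}) == [(1, 2), (3, 4)]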
Example #24
    def test_simple_mask_two_layer(self):
        """Tests generation of a simple mask."""
        mask = {
            'MaskedModule_0': {
                'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0']
                          ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
            'MaskedModule_1': {
                'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_1']
                          ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
        }

        gen_mask = masked.simple_mask(self._masked_model_twolayer, jnp.zeros,
                                      ['kernel'])

        result, _ = jax.tree_flatten(
            jax.tree_util.tree_multimap(lambda x, *xs: (x == xs[0]).all(),
                                        mask, gen_mask))

        self.assertTrue(all(result))
Example #25
    def apply_gradient(self, hyper_params, params, state, grads):
        p_leaves, treedef = jax.tree_flatten(params)
        s_leaves = treedef.flatten_up_to(state.param_states)
        g_leaves = treedef.flatten_up_to(grads)
        split_grads = zip(*(self._split_grad(p, s, g, hyper_params.wn_decay)
                            for p, s, g in zip(p_leaves, s_leaves, g_leaves)))
        d_p, d_s, d_g, s_p, s_s, s_g = [
            jax.tree_unflatten(treedef, x) for x in split_grads
        ]
        wn_params = {'direction': d_p, 'scale': s_p}
        wn_state = {'direction': d_s, 'scale': s_s}
        wn_grads = {'direction': d_g, 'scale': s_g}
        new_wn_params, new_state = self.wrapped_optimizer.apply_gradient(
            hyper_params.inner, wn_params,
            state.replace(param_states=wn_state), wn_grads)

        directions = treedef.flatten_up_to(new_wn_params['direction'])
        scales = treedef.flatten_up_to(new_wn_params['scale'])
        new_params, mults = zip(*(self._merge_param(d, s, hyper_params.wn_eps)
                                  for d, s in zip(directions, scales)))
        new_params = jax.tree_unflatten(treedef, new_params)
        mults = jax.tree_unflatten(treedef, mults)

        direction_state = new_state.param_states['direction']
        scale_state = new_state.param_states['scale']
        param_states = jax.tree_multimap(
            lambda _, *args: _WeightNormParamState(*args), params,
            direction_state, scale_state, mults)
        return new_params, new_state.replace(param_states=param_states)
Example #26
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False):
    if length is None:
        length = tree_flatten(xs)[0][0].shape[0]

    if enum:
        return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack)

    def body_fn(wrapped_carry, x):
        i, rng_key, carry = wrapped_carry
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        with handlers.block():
            seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == 'condition':
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == 'substitute':
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse)
Example #27
    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i == 0) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry at a global variable
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
Example #28
    def apply_gradient(self, hyper_params, params, state, grads, hessians):
        """Applies a gradient for a set of parameters.
        Args:
            hyper_params: a named tuple of hyper parameters.
            params: the parameters that should be updated.
            state: a named tuple containing the state of the optimizer
            grads: the gradient tensors for the parameters.
            hessians: the hessian tensors for the parameters.
        Returns:
            A tuple containing the new parameters and the new optimizer state.
        """
        step = state.step
        params_flat, treedef = jax.tree_flatten(params)
        states_flat = treedef.flatten_up_to(state.param_states)
        grads_flat = treedef.flatten_up_to(grads)
        hessians_flat = treedef.flatten_up_to(hessians)
        out = [
            self.apply_param_gradient(step, hyper_params, param, state, grad,
                                      hessian) for param, state, grad, hessian
            in zip(params_flat, states_flat, grads_flat, hessians_flat)
        ]

        new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
        new_params = jax.tree_unflatten(treedef, new_params_flat)
        new_param_states = jax.tree_unflatten(treedef, new_states_flat)
        new_state = OptimizerState(step + 1, new_param_states)
        return new_params, new_state
Example #29
    def forward_compute_losses(
        params_primals: Any, ) -> Sequence[Sequence[jnp.ndarray]]:
        primals_[params_index] = params_primals
        flat_args = jax.tree_flatten(primals_)[0]
        # Mapping from variable -> value
        env = dict()
        read = functools.partial(tgm.read_env, env)
        write = functools.partial(tgm.write_env, env)

        # Bind args and consts to environment
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, jaxpr.invars, flat_args)
        jax_util.safe_map(write, jaxpr.constvars, consts)

        # Loop through equations and evaluate primitives using `bind`
        losses_so_far = 0
        loss_tags = []
        for eqn in jaxpr.eqns:
            tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
            if isinstance(eqn.primitive, tags.LossTag):
                loss_tags.append(eqn)
                losses_so_far += 1
            if num_losses is not None and losses_so_far == num_losses:
                break
        return tuple(tuple(read(v) for v in tag.invars) for tag in loss_tags)
Example #30
        def vjp_func(tangents):
            all_tangents = aux_vjp(tangents)
            tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
            inputs_tangents = jax.tree_flatten(inputs_tangents)[0]
            tangents_dict.update(zip(jaxpr.invars, inputs_tangents))

            read_primals = functools.partial(tgm.read_env, primals_dict)
            read_tangents = functools.partial(tgm.read_env, tangents_dict)
            layers_info = []
            for jaxpr_eqn in layer_tags:
                layer_tag = _unbox_layer_tag(jaxpr_eqn)
                info = dict()
                primals = jax_util.safe_map(read_primals,
                                            tuple(jaxpr_eqn.invars))
                (
                    info["outputs"],
                    info["inputs"],
                    info["params"],
                ) = layer_tag.split_all_inputs(primals)
                tangents = jax_util.safe_map(read_tangents,
                                             tuple(jaxpr_eqn.invars))
                (
                    info["outputs_tangent"],
                    info["inputs_tangent"],
                    info["params_tangent"],
                ) = layer_tag.split_all_inputs(tangents)
                layers_info.append(info)
            return tuple(layers_info)