# Example 1
def compute_loss(predicted_positions,
                 predicted_momentums,
                 target_positions,
                 target_momentums,
                 auxiliary_predictions=None,
                 regularizations=None):
    """Computes the mean L2 loss between predictions and targets.

    Args:
      predicted_positions: Predicted positions; same shape as
        target_positions.
      predicted_momentums: Predicted momentums; same shape as
        target_momentums.
      target_positions: Ground-truth positions.
      target_momentums: Ground-truth momentums.
      auxiliary_predictions: Optional dict with 'angular_velocities' and
        'actions' arrays; their variance along axis 0 is added as a penalty.
      regularizations: Dict with 'angular_velocities' and 'actions' weights.
        Required when auxiliary_predictions is provided.

    Returns:
      Scalar loss.

    Raises:
      ValueError: If auxiliary_predictions is given without regularizations.
    """
    assert predicted_positions.shape == target_positions.shape, f'Got predicted_positions: {predicted_positions.shape}, target_positions: {target_positions.shape}'
    assert predicted_momentums.shape == target_momentums.shape, f'Got predicted_momentums: {predicted_momentums.shape}, target_momentums: {target_momentums.shape}'

    loss = optax.l2_loss(predictions=predicted_positions,
                         targets=target_positions)
    loss += optax.l2_loss(predictions=predicted_momentums,
                          targets=target_momentums)
    loss = jnp.mean(loss)

    if auxiliary_predictions is not None:
        # Fail fast with a clear message instead of the TypeError that
        # subscripting None would raise below.
        if regularizations is None:
            raise ValueError(
                'regularizations must be provided when auxiliary_predictions '
                'is given.')

        # Penalize the variance of the auxiliary predictions along axis 0.
        angular_velocities = auxiliary_predictions['angular_velocities']
        angular_velocities_variances = jnp.var(angular_velocities,
                                               axis=0).sum()
        loss += regularizations[
            'angular_velocities'] * angular_velocities_variances

        actions = auxiliary_predictions['actions']
        actions_variances = jnp.var(actions, axis=0).sum()
        loss += regularizations['actions'] * actions_variances
    return loss
# Example 2
    def loss_fn(params):
        """Returns (mean L2 loss, metrics update) for the given params."""
        model_out = state.apply_fn(params, inputs)
        # phi is pulled out separately for the implicit least-squares
        # gradient computation.
        features = model_out.phi
        mean_loss = jnp.mean(optax.l2_loss(model_out.predictions, targets))
        feature_rank = jnp.linalg.matrix_rank(features.T @ features)
        metrics_update = metrics.single_from_model_output(loss=mean_loss,
                                                          rank=feature_rank)
        return mean_loss, metrics_update
# Example 3
def _optimizer_loop(optimizer, iterations=5):
    """Runs `optimizer` for `iterations` steps on a fixed linear problem.

    Args:
      optimizer: An optax-style optimizer (has `init`/`update`).
      iterations: Number of update steps to run.

    Returns:
      List of the parameter pytrees after each update step.
    """
    params = {'w': jnp.ones((2, ))}
    opt_state = optimizer.init(params)

    # The loss, its grad transform, and the input data are loop-invariant:
    # define them once instead of rebuilding a lambda on every iteration.
    def compute_loss(params, x, y):
        return optax.l2_loss(params['w'].dot(x), y)

    grad_fn = jax.grad(compute_loss)
    xs = jnp.array([5.0, 6.0])

    results = []
    for _ in range(iterations):
        grads = grad_fn(params, xs, 4.0)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        results.append(params)
    return results
# Example 4
    def loss_fn(params):
        """Fits a least-squares head on features phi; returns (loss, metrics)."""
        features = state.apply_fn(params, inputs).phi

        # Closed-form linear head. NOTE: a conjugate-gradient solve of
        # (phi^T phi) w = phi^T targets is an alternative to lstsq here.
        weights, _, _, _ = jnp.linalg.lstsq(features, targets, rcond=rcond)
        if stop_grad:
            weights = jax.lax.stop_gradient(weights)

        head_outputs = features @ weights
        mean_loss = jnp.mean(optax.l2_loss(head_outputs, targets))

        feature_rank = jnp.linalg.matrix_rank(features.T @ features)
        metrics_update = metrics.single_from_model_output(loss=mean_loss,
                                                          rank=feature_rank)
        return mean_loss, metrics_update
# Example 5
    def test_dummy_step(self):
        """Runs one optimizer step end-to-end and checks params survive it."""
        num_weights = 100
        xs = jnp.ones((num_weights, ))
        ys = 1

        optimizer = transform_chain(['nesterov', 'polyak_hb'], [{}, {}])
        params = {'w': jnp.ones((num_weights, ))}
        opt_state = optimizer.init(flax.core.FrozenDict(params))

        # Simple linear loss so the gradient step is well-defined.
        def dot_loss(p, x, y):
            return optax.l2_loss(p['w'].dot(x), y)

        grads = jax.grad(dot_loss)(params, xs, ys)

        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        self.assertTrue(params)
# Example 6
 def loss_fn():
     """Returns the mean L2 loss of per-phi predictions against targets."""
     # NOTE(review): this closure mixes `train_state` (produces phi) and
     # `state` (maps phi to predictions) — confirm both are intended here.
     outputs = train_state.apply_fn(train_state.params, states)
     phis = outputs.phi
     # vmap over the leading axis of phis with shared params (in_axes=None).
     predictions = jax.vmap(state.apply_fn,
                            in_axes=(None, 0))(state.params, phis)
     return jnp.mean(optax.l2_loss(predictions, targets))
# Example 7
 def loss_fn(params):
     """Evaluates `params` over phis; returns (loss, eval-metrics update)."""
     # Shared params (in_axes=None), mapped over the leading axis of phis.
     apply_batched = jax.vmap(eval_state.apply_fn, in_axes=(None, 0))
     preds = apply_batched(params, phis)
     mean_loss = jnp.mean(optax.l2_loss(preds, targets))
     metrics_update = EvalMetrics.single_from_model_output(loss=mean_loss)
     return mean_loss, metrics_update
# Example 8
def forward(variables, batch, rngs=None):
    """Applies the model to batch['input']; returns (loss, (outputs, {}))."""
    del rngs  # Unused; accepted to keep the forward-fn signature uniform.
    predictions = model.apply(variables, batch['input'])
    mean_loss = optax.l2_loss(predictions, batch['target']).mean()
    return mean_loss, (predictions, {})
# Example 9
 def rms_loss(pred, tar, w):
     """Mean of optax.l2_loss over pred/tar; `w` is accepted but unused."""
     # NOTE(review): despite the name, this is the mean of optax.l2_loss
     # (half squared error), not a root-mean-square — confirm intent.
     return optax.l2_loss(pred, tar).mean()