import flax
import jax
import jax.numpy as jnp
import optax


def compute_loss(predicted_positions, predicted_momentums, target_positions,
                 target_momentums, auxiliary_predictions=None,
                 regularizations=None):
  """Computes the loss for the given predictions."""
  assert predicted_positions.shape == target_positions.shape, (
      f'Got predicted_positions: {predicted_positions.shape}, '
      f'target_positions: {target_positions.shape}')
  assert predicted_momentums.shape == target_momentums.shape, (
      f'Got predicted_momentums: {predicted_momentums.shape}, '
      f'target_momentums: {target_momentums.shape}')
  loss = optax.l2_loss(predictions=predicted_positions,
                       targets=target_positions)
  loss += optax.l2_loss(predictions=predicted_momentums,
                        targets=target_momentums)
  loss = jnp.mean(loss)
  if auxiliary_predictions is not None:
    angular_velocities = auxiliary_predictions['angular_velocities']
    angular_velocities_variances = jnp.var(angular_velocities, axis=0).sum()
    loss += (regularizations['angular_velocities'] *
             angular_velocities_variances)
    actions = auxiliary_predictions['actions']
    actions_variances = jnp.var(actions, axis=0).sum()
    loss += regularizations['actions'] * actions_variances
  return loss
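

# A minimal, hypothetical usage sketch for compute_loss above. The array
# shapes and regularization weights are illustrative assumptions, not values
# from the original training setup.
def _example_compute_loss_usage():
  positions = jnp.zeros((8, 3))
  momentums = jnp.zeros((8, 3))
  auxiliary_predictions = {
      'angular_velocities': jnp.zeros((8, 3)),
      'actions': jnp.zeros((8, 2)),
  }
  regularizations = {'angular_velocities': 1e-2, 'actions': 1e-2}
  return compute_loss(positions, momentums, positions, momentums,
                      auxiliary_predictions=auxiliary_predictions,
                      regularizations=regularizations)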


def loss_fn(params):
  outputs = state.apply_fn(params, inputs)
  predictions = outputs.predictions
  phis = outputs.phi  # Split out phis for implicit least squares grad computation.
  loss = jnp.mean(optax.l2_loss(predictions, targets))
  rank = jnp.linalg.matrix_rank(phis.T @ phis)
  metrics_update = metrics.single_from_model_output(loss=loss, rank=rank)
  return loss, metrics_update


def _optimizer_loop(optimizer, iterations=5):
  """Helper function for running optimizer loops."""
  params = {'w': jnp.ones((2,))}
  opt_state = optimizer.init(params)
  results = []
  for _ in range(iterations):
    compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
    grads = jax.grad(compute_loss)(params, jnp.array([5.0, 6.0]), 4.0)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    results.append(params)
  return results
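

# A minimal sketch of how _optimizer_loop might be driven by a standard optax
# optimizer; optax.sgd and the learning rate are illustrative assumptions.
def _example_optimizer_loop_usage():
  optimizer = optax.sgd(learning_rate=1e-2)
  # Returns the parameter dict recorded after each of the three update steps.
  return _optimizer_loop(optimizer, iterations=3)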


def loss_fn(params):
  outputs = state.apply_fn(params, inputs)
  phis = outputs.phi
  # ws = jax.scipy.sparse.linalg.cg(
  #     phis.T @ phis, phis.T @ targets, tol=1e-12)[0]
  ws, _, _, _ = jnp.linalg.lstsq(phis, targets, rcond=rcond)
  if stop_grad:
    ws = jax.lax.stop_gradient(ws)
  task_outputs = phis @ ws
  loss = jnp.mean(optax.l2_loss(task_outputs, targets))
  rank = jnp.linalg.matrix_rank(phis.T @ phis)
  metrics_update = metrics.single_from_model_output(loss=loss, rank=rank)
  return loss, metrics_update
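

# A self-contained sketch of the closed-form least-squares head used in
# loss_fn above, on a toy feature matrix. The shapes, PRNG keys, and rcond
# choice are illustrative assumptions, not values from the original code.
def _example_least_squares_head():
  # Toy features (batch=16, features=4) and scalar regression targets.
  phis = jax.random.normal(jax.random.PRNGKey(0), (16, 4))
  targets = jax.random.normal(jax.random.PRNGKey(1), (16, 1))
  # Solve for the linear readout in closed form; stop_gradient keeps the
  # solve itself out of backpropagation, so only the features receive grads.
  ws, _, _, _ = jnp.linalg.lstsq(phis, targets, rcond=None)
  ws = jax.lax.stop_gradient(ws)
  return jnp.mean(optax.l2_loss(phis @ ws, targets))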


def test_dummy_step(self):
  """Test dummy step."""
  num_weights = 100
  xs = jnp.ones((num_weights,))
  ys = 1
  optimizer = transform_chain(['nesterov', 'polyak_hb'], [{}, {}])
  params = {'w': jnp.ones((num_weights,))}
  opt_state = optimizer.init(flax.core.FrozenDict(params))
  compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  self.assertTrue(params)


def loss_fn():
  outputs = train_state.apply_fn(train_state.params, states)
  phis = outputs.phi
  predictions = jax.vmap(state.apply_fn, in_axes=(None, 0))(state.params, phis)
  return jnp.mean(optax.l2_loss(predictions, targets))


def loss_fn(params):
  predictions = jax.vmap(eval_state.apply_fn, in_axes=(None, 0))(params, phis)
  loss = jnp.mean(optax.l2_loss(predictions, targets))
  metrics_update = EvalMetrics.single_from_model_output(loss=loss)
  return loss, metrics_update


def forward(variables, batch, rngs=None):
  del rngs
  out = model.apply(variables, batch['input'])
  loss = optax.l2_loss(out, batch['target']).mean()
  return loss, (out, {})


def rms_loss(pred, tar, w):
  del w  # Unused.
  # Note: despite the name, this returns the mean of optax.l2_loss
  # (0.5 * squared error) rather than a root-mean-square error.
  return jnp.mean(optax.l2_loss(pred, tar))