Example #1
def _training_epoch(
    optimizer: optim.Optimizer,
    random_key: np.ndarray,
    training_step: Callable = lambda *x: None,
    training_loader: Sequence = (),
    eval_step: Callable = lambda *x: None,
    eval_loader: Sequence = (),
    eval_frequency: int = 1,
    progress_bar: bool = True,
    desc: str = "",
) -> Tuple[optim.Optimizer, np.ndarray]:
    total = len(training_loader) + eval_frequency * len(eval_loader)
    eval_every = len(training_loader) / eval_frequency
    get_progress_bar = partial(tqdm, disable=not progress_bar, leave=False)

    prog_bar = get_progress_bar(total=total, desc=desc)
    optimizer = optimizer.replicate()
    random_key = random.split(random_key, num=jax.device_count())

    def eval_epoch(eval_loader: Iterable,
                   random_key: np.ndarray) -> np.ndarray:
        loader = get_progress_bar(eval_loader, desc="Valid", position=1)
        for batch in loader:
            random_key, subkey = _parallel_split(random_key)
            metrics = eval_step(optimizer.target, _shard(batch), subkey)

            prog_bar.update()
            if "loss" in metrics:
                loss = np.mean(metrics["loss"]).item()
                loader.set_postfix_str(f"loss={loss:.4f}", refresh=False)

        return random_key

    for i, batch in enumerate(training_loader, 1):
        random_key, subkey = _parallel_split(random_key)
        metrics, optimizer = training_step(optimizer, _shard(batch), subkey)

        prog_bar.update()
        if "loss" in metrics:
            loss = np.mean(metrics["loss"]).item()
            prog_bar.set_postfix_str(f"loss={loss:.4f}", refresh=False)

        if eval_loader and int(i % eval_every) == 0:
            eval_epoch(eval_loader, random_key=random_key)

    prog_bar.close()
    optimizer = optimizer.unreplicate()
    random_key = random_key[0]

    return optimizer, random_key
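
Example #1 above leans on two helpers, _shard and _parallel_split, that are not part of the snippet. A minimal sketch of what they presumably do in a jax.pmap setup follows; the names, shapes, and tree handling here are assumptions, not the original code.

import jax
from jax import random

def _shard(batch):
    # Reshape each array from [batch, ...] to [num_devices, batch // num_devices, ...]
    # so every device receives its own slice of the batch.
    n = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n, x.shape[0] // n) + x.shape[1:]), batch)

def _parallel_split(keys):
    # `keys` holds one PRNG key per device (shape [num_devices, 2]); split each
    # into a carry key and a subkey, mirroring jax.random.split on one device.
    split = jax.vmap(lambda key: random.split(key, 2))(keys)
    return split[:, 0], split[:, 1]
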
Example #2
def train_step(
    config: Any,
    optimizer: optim.Optimizer,
    model_state: Mapping[str, Any],
    batch: Dict[str, Array],
    rngs: Dict[str, Any],
) -> Tuple[optim.Optimizer, Dict[str, Any], Dict[str, Any]]:
    """Train for a single step."""
    # Make sure to get a new RNG at every step.
    model = model_from_config(config)
    step = optimizer.state.step
    rngs = {name: jax.random.fold_in(rng, step) for name, rng in rngs.items()}

    def loss_fn(params):
        variables = {'params': params, **model_state}
        logits, new_model_state = model.apply(variables,
                                              batch['token_ids'],
                                              batch['length'],
                                              rngs=rngs,
                                              mutable=list(model_state.keys()))

        labels = batch['label']
        if labels.ndim == 1:
            labels = jnp.expand_dims(labels, 1)
        loss = jnp.mean(
            sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
        return loss, (logits, new_model_state)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    value, grad = grad_fn(optimizer.target)
    (_, (logits, new_model_state)) = value
    optimizer = optimizer.apply_gradient(grad)

    metrics = compute_metrics(labels=batch['label'], logits=logits)
    return optimizer, metrics, new_model_state
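
A hedged usage sketch for the step above: since config is a plain Python object, it is typically closed over before jax.jit rather than passed as a traced argument. The batch iterable below is a placeholder, not part of the original example.

import functools
import jax

train_step_jit = jax.jit(functools.partial(train_step, config))

# Assumed iterable of dicts with 'token_ids', 'length' and 'label' arrays.
for batch in train_batches:
    optimizer, metrics, model_state = train_step_jit(
        optimizer, model_state, batch, rngs)
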
Example #3
def actor_step(
    rng: PRNGSequence,
    optimizer: optim.Optimizer,
    critic_params: FrozenDict,
    state: jnp.ndarray,
    log_alpha_params: FrozenDict,
    max_action: float,
    action_dim: int,
) -> Tuple[optim.Optimizer, jnp.ndarray]:
    def loss_fn(actor_params):
        actor_action, log_p = apply_gaussian_policy_model(
            actor_params, action_dim, max_action, state, rng, True, False)
        q1, q2 = apply_double_critic_model(critic_params, state, actor_action,
                                           False)
        min_q = jnp.minimum(q1, q2)
        partial_loss_fn = jax.vmap(
            partial(
                actor_loss_fn,
                jax.lax.stop_gradient(
                    apply_constant_model(log_alpha_params, -3.5, False)),
            ), )
        actor_loss = partial_loss_fn(log_p, min_q)
        return jnp.mean(actor_loss), log_p

    grad, log_p = jax.grad(loss_fn, has_aux=True)(optimizer.target)
    return optimizer.apply_gradient(grad), log_p
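
actor_loss_fn is not shown in this snippet. Given the argument order bound above (the stop-gradiented log-alpha first, then log_p and min_q), a SAC-style per-sample loss along the following lines would be consistent; treat it as an assumption, not the original definition.

import jax.numpy as jnp

def actor_loss_fn(log_alpha, log_p, min_q):
    # Minimize alpha * log_pi(a|s) - min(Q1, Q2), i.e. maximize the
    # entropy-regularized Q-value of the sampled action.
    return jnp.exp(log_alpha) * log_p - min_q
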
Example #4
def train_op(net: Module, opt: Optimizer, x, y, key,
             hyperparams: ServerHyperParams):
    index = random.randint(key,
                           shape=(hyperparams.oracle_batch_size, ),
                           minval=0,
                           maxval=x.shape[0])
    v, g = vg(net, opt.target, x[index], y[index])
    return v, opt.apply_gradient(g)
Example #5
def sig_lagrange_step(optimizer: optim.Optimizer,
                      reg: float) -> optim.Optimizer:
    def loss_fn(sig_lagrange_params):
        return jnp.sum(
            apply_constant_model(sig_lagrange_params, 100.0, True) * reg)

    grad = jax.grad(loss_fn)(optimizer.target)
    return optimizer.apply_gradient(grad)
Example #6
def training_step_(
    optimizer: optim.Optimizer,
    batch: Sequence[np.ndarray],
    random_key: np.ndarray,
) -> Tuple[Dict[str, np.ndarray], optim.Optimizer]:
    step_fn = _make_step_fn_differentiable(self.training_step)
    step_fn = partial(step_fn, batch=batch, random_key=random_key)
    (_, metrics), grad = jax.value_and_grad(step_fn, has_aux=True)(
        optimizer.target)
    optimizer = optimizer.apply_gradient(
        lax.pmean(grad, axis_name="batch"))
    return metrics, optimizer
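
This inner function closes over self.training_step, so it is defined inside a trainer method rather than at module level, and the lax.pmean(..., axis_name="batch") call only makes sense once the function is wrapped in jax.pmap with the same axis name. A hedged sketch of that wrapping (the variable name is an assumption):

# Presumed wrapping; the pmapped function can then serve as the
# `training_step` callable consumed by `_training_epoch` in Example #1,
# which already replicates the optimizer and shards each batch.
p_training_step = jax.pmap(training_step_, axis_name="batch")
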
Example #7
def critic_step(
    optimizer: optim.Optimizer,
    state: jnp.ndarray,
    action: jnp.ndarray,
    target_Q: jnp.ndarray,
) -> optim.Optimizer:
    def loss_fn(critic_params):
        current_Q1, current_Q2 = apply_double_critic_model(
            critic_params, state, action, False)
        critic_loss = double_mse(current_Q1, current_Q2, target_Q)
        return jnp.mean(critic_loss)

    grad = jax.grad(loss_fn)(optimizer.target)
    return optimizer.apply_gradient(grad)
Example #8
def alpha_step(optimizer: optim.Optimizer, log_p: jnp.ndarray,
               target_entropy: float) -> optim.Optimizer:
    log_p = jax.lax.stop_gradient(log_p)

    def loss_fn(log_alpha_params):
        partial_loss_fn = jax.vmap(
            partial(
                alpha_loss_fn,
                apply_constant_model(log_alpha_params, -3.5, False),
                target_entropy,
            ))
        return jnp.mean(partial_loss_fn(log_p))

    grad = jax.grad(loss_fn)(optimizer.target)
    return optimizer.apply_gradient(grad)
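
alpha_loss_fn is also external to the snippet. One common form of the SAC temperature objective, consistent with the argument order bound above (log_alpha, then target_entropy, then log_p), would be the following; this is a guess at the shape of the helper, not its actual source.

import jax.numpy as jnp

def alpha_loss_fn(log_alpha, target_entropy, log_p):
    # Push the policy's entropy towards `target_entropy` by adjusting alpha.
    return -jnp.exp(log_alpha) * (log_p + target_entropy)
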
Example #9
def actor_step(
    optimizer: optim.Optimizer,
    critic_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
) -> optim.Optimizer:
    def loss_fn(actor_params):
        actor_loss = -apply_td3_critic_model(
            critic_params,
            state,
            apply_td3_actor_model(actor_params, action_dim, max_action, state),
            True,
        )
        return jnp.mean(actor_loss)

    grad = jax.grad(loss_fn)(optimizer.target)
    return optimizer.apply_gradient(grad)
Example #10
def critic_step(
    optimizer: optim.Optimizer,
    state: jnp.ndarray,
    action: jnp.ndarray,
    target_Q: jnp.ndarray,
) -> optim.Optimizer:
    """
    The critic is optimized the same way as typical actor critic methods,
    minimizing the TD error.
    """
    def loss_fn(critic_params):
        current_Q1, current_Q2 = apply_double_critic_model(
            critic_params, state, action, False)
        critic_loss = double_mse(current_Q1, current_Q2, target_Q)
        return critic_loss.mean()

    grad = jax.grad(loss_fn)(optimizer.target)
    grad = clip_grads(grad, 40.0)
    return optimizer.apply_gradient(grad)
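
double_mse is not defined in this snippet (it is also used in Example #7). A definition consistent with the docstring, minimizing the TD error of both critic heads against the same target, and with the .mean() taken by the caller would be, as an assumption:

import jax.numpy as jnp

def double_mse(q1, q2, target_q):
    # Element-wise squared TD error for both critic heads; the caller averages it.
    return jnp.square(q1 - target_q) + jnp.square(q2 - target_q)
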
Example #11
def m_step(
    rngs: PRNGSequence,
    actor_optimizer: optim.Optimizer,
    actor_target_params: FrozenDict,
    eps_mu: float,
    eps_sig: float,
    mu_lagrange_optimizer: optim.Optimizer,
    sig_lagrange_optimizer: optim.Optimizer,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    weights: jnp.ndarray,
    sampled_actions: jnp.ndarray,
) -> Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]:
    """
    The 'M-step' from the MPO paper. We optimize our policy network to maximize
    the lower bound on the probability of obtaining the maximum reward, given
    that we act according to our policy (i.e. weighted according to our sampled actions).
    """

    def loss_fn(mlo, slo, actor_params):
        # get the distribution of the actor network (current policy)
        mu, log_sig = apply_gaussian_policy_model(
            actor_params, action_dim, max_action, state, None, False, True
        )
        sig = jnp.exp(log_sig)
        # get the distribution of the target network (old policy)
        target_mu, target_log_sig = apply_gaussian_policy_model(
            actor_target_params, action_dim, max_action, state, None, False, True
        )
        target_mu = jax.lax.stop_gradient(target_mu)
        target_log_sig = jax.lax.stop_gradient(target_log_sig)
        target_sig = jnp.exp(target_log_sig)

        # get the log likelihoods of the sampled actions according to the
        # decoupled distributions, as described in section 4.2.1 of
        # Relative Entropy Regularized Policy Iteration
        # this ensures that the nonparametric policy won't collapse to give
        # a probability of 1 to the best action, which is a risk when we use
        # the on-policy distribution to calculate the likelihood.
        actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, log_sig)
        actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_log_sig)
        actor_log_prob = actor_log_prob.transpose((0, 1))

        mu_kl = kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
        sig_kl = kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

        mlo = mu_lagrange_step(mlo, eps_mu - jax.lax.stop_gradient(mu_kl))
        slo = sig_lagrange_step(slo, eps_sig - jax.lax.stop_gradient(sig_kl))

        # maximize the log likelihood, regularized by the divergence between
        # the target policy and the current policy. the goal here is to fit
        # the parametric policy to have the minimum divergence with the nonparametric
        # distribution based on the sampled actions.
        actor_loss = -(actor_log_prob * weights).sum(axis=1).mean()
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(mlo.target, 1.0, True)
        ) * (eps_mu - mu_kl)
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(slo.target, 100.0, True)
        ) * (eps_sig - sig_kl)
        return actor_loss.mean(), (mlo, slo)

    grad, (mu_lagrange_optimizer, sig_lagrange_optimizer) = jax.grad(
        partial(loss_fn, mu_lagrange_optimizer, sig_lagrange_optimizer), has_aux=True
    )(actor_optimizer.target)
    grad = clip_grads(grad, 40.0)

    actor_optimizer = actor_optimizer.apply_gradient(grad)

    return mu_lagrange_optimizer, sig_lagrange_optimizer, actor_optimizer
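
mu_lagrange_step is not shown here. By symmetry with sig_lagrange_step in Example #5, and given the 1.0 base value used for mlo.target inside the loss above, it presumably looks like the following inferred sketch rather than being the original source:

def mu_lagrange_step(optimizer: optim.Optimizer,
                     reg: float) -> optim.Optimizer:
    def loss_fn(mu_lagrange_params):
        return jnp.sum(
            apply_constant_model(mu_lagrange_params, 1.0, True) * reg)

    grad = jax.grad(loss_fn)(optimizer.target)
    return optimizer.apply_gradient(grad)
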