import jax
import jax.numpy as jnp

# Note: `networks`, `snr_loss_fn`, `snr_alpha`, `reward_scale`, `discount`,
# and `self._use_snr` are captured from the enclosing learner scope; they are
# not arguments of these loss functions.


def actor_loss(policy_params, q_params, target_q_params, alpha,
               transitions, snr_state, key, in_initial_bc_iters):
    dist_params = networks.policy_network.apply(
        policy_params, transitions.observation)
    if in_initial_bc_iters:
        # Behavioral-cloning warm-up: maximize the log-likelihood of the
        # dataset actions; the Q-function is not used yet.
        log_prob = networks.log_prob(dist_params, transitions.action)
        min_q = 0.
        actor_loss = -log_prob

        # No SNR regularization during the BC iterations.
        sn = 0.
        new_snr_state = snr_state
    else:
        # SAC-style actor loss: sample an action from the current policy and
        # trade off the entropy term (alpha * log_prob) against the minimum
        # Q-value over the critic ensemble.
        key, sub_key = jax.random.split(key)
        action = networks.sample(dist_params, sub_key)
        log_prob = networks.log_prob(dist_params, action)
        q_action = networks.q_network.apply(q_params,
                                            transitions.observation,
                                            action)
        min_q = jnp.min(q_action, axis=-1)
        actor_loss = alpha * log_prob - min_q

        # The SNR term is only applied after the initial BC iterations.
        if self._use_snr:
            # Unpack the location and scale of the underlying Gaussian from
            # the nested distribution object at the next observation.
            next_dist_params = networks.policy_network.apply(
                policy_params, transitions.next_observation)
            next_dist_params = [
                next_dist_params._distribution._distribution.loc,
                next_dist_params._distribution._distribution.scale,
            ]
            key, sub_key = jax.random.split(key)
            sn, (masked_s, C, new_snr_state) = snr_loss_fn(
                next_dist_params, transitions.observation,
                transitions.action, transitions.next_observation,
                transitions.discount, sub_key, snr_state, q_params,
                target_q_params)
            actor_loss = actor_loss + snr_alpha * sn
        else:
            sn = 0.
            new_snr_state = snr_state

    return jnp.mean(actor_loss), (min_q, jnp.mean(log_prob), sn,
                                  new_snr_state)
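
# A minimal sketch (not part of the original snippet) of how the actor loss
# above might be differentiated: gradients are taken w.r.t. `policy_params`
# only (argnums=0), and the auxiliary logging outputs plus the new SNR state
# are threaded through with `has_aux=True`. The helper name is hypothetical.
def _actor_grads_sketch(policy_params, q_params, target_q_params, alpha,
                        transitions, snr_state, key, in_initial_bc_iters):
    grad_fn = jax.value_and_grad(actor_loss, argnums=0, has_aux=True)
    (loss, (min_q, mean_log_prob, sn, new_snr_state)), grads = grad_fn(
        policy_params, q_params, target_q_params, alpha,
        transitions, snr_state, key, in_initial_bc_iters)
    return loss, grads, new_snr_state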

def critic_loss(q_params, policy_params, target_q_params, alpha,
                transitions, key):
    # Q-values of the dataset actions under the current critic ensemble.
    q_old_action = networks.q_network.apply(q_params,
                                            transitions.observation,
                                            transitions.action)
    # Bootstrap with an action sampled from the current policy at the next
    # observation, evaluated by the target critics.
    next_dist_params = networks.policy_network.apply(
        policy_params, transitions.next_observation)
    next_action = networks.sample(next_dist_params, key)
    next_log_prob = networks.log_prob(next_dist_params, next_action)
    next_q = networks.q_network.apply(target_q_params,
                                      transitions.next_observation,
                                      next_action)
    # The entropy bonus is not included in the target value here; the
    # commented variant keeps it for reference.
    # next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob
    next_v = jnp.min(next_q, axis=-1)
    target_q = jax.lax.stop_gradient(
        transitions.reward * reward_scale +
        transitions.discount * discount * next_v)
    # Broadcast the per-example target across the critic ensemble axis.
    q_error = q_old_action - jnp.expand_dims(target_q, -1)
    # q_loss = 0.5 * jnp.mean(jnp.square(q_error))
    q_loss = jnp.mean(jnp.square(q_error))
    # Rescale so the loss sums (rather than averages) over the ensemble axis.
    q_loss = q_loss * q_error.shape[-1]
    return q_loss
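
# A minimal sketch (not from the original code) of how the critic loss might
# be applied with optax and how the target critics could track the online
# critics via Polyak averaging; `q_optimizer`, `q_opt_state`, and `tau` are
# assumptions introduced here for illustration only.
def _critic_update_sketch(q_params, target_q_params, q_opt_state,
                          policy_params, alpha, transitions, key,
                          q_optimizer, tau=0.005):
    import optax  # assumed to be available alongside jax

    # Gradient of the TD loss w.r.t. the online critic parameters only.
    q_grads = jax.grad(critic_loss, argnums=0)(
        q_params, policy_params, target_q_params, alpha, transitions, key)
    q_updates, q_opt_state = q_optimizer.update(q_grads, q_opt_state)
    q_params = optax.apply_updates(q_params, q_updates)

    # Exponential moving average of the online critic parameters.
    target_q_params = jax.tree_util.tree_map(
        lambda t, s: (1.0 - tau) * t + tau * s, target_q_params, q_params)
    return q_params, target_q_params, q_opt_state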