    def fit_critic(self, states, actions, next_states, rewards, discounts):
        """Updates critic parameters.

        Args:
          states: Batch of states.
          actions: Batch of actions.
          next_states: Batch of next states.
          rewards: Batch of rewards.
          discounts: Batch of discount masks (zero at episode termination).

        Returns:
          Dictionary with information to track.
        """
        next_actions = self.actor(next_states, sample=True)
        policy_actions = self.actor(states, sample=True)

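        # Clipped double-Q TD target from the target critic heads; `discounts`
        # zeroes out the bootstrap term at terminal transitions.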
        next_target_q1, next_target_q2 = self.dist_critic(next_states,
                                                          next_actions,
                                                          target=True)
        target_q = rewards + self.discount * discounts * tf.minimum(
            next_target_q1, next_target_q2)

        critic_variables = self.critic.trainable_variables

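        # Outer tape: gradients of the critic loss w.r.t. the critic weights.
        # Inner persistent tape: gradients of each critic head (q1_reg, q2_reg)
        # w.r.t. the sampled policy actions; their squared L2 norms give the
        # regularizer weighted by f_reg. persistent=True allows two gradient
        # calls from the same tape, which is then deleted explicitly.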
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(critic_variables)
            q1, q2 = self.dist_critic(states, actions, stop_gradient=True)
            with tf.GradientTape(watch_accessed_variables=False,
                                 persistent=True) as tape2:
                tape2.watch([policy_actions])

                q1_reg, q2_reg = self.critic(states, policy_actions)

            q1_grads = tape2.gradient(q1_reg, policy_actions)
            q2_grads = tape2.gradient(q2_reg, policy_actions)

            q1_grad_norm = tf.reduce_sum(tf.square(q1_grads), axis=-1)
            q2_grad_norm = tf.reduce_sum(tf.square(q2_grads), axis=-1)

            del tape2

            q_reg = tf.reduce_mean(q1_grad_norm + q2_grad_norm)

            critic_loss = (tf.losses.mean_squared_error(target_q, q1) +
                           tf.losses.mean_squared_error(target_q, q2) +
                           self.f_reg * q_reg)

        critic_grads = tape.gradient(critic_loss, critic_variables)

        self.critic_optimizer.apply_gradients(
            zip(critic_grads, critic_variables))

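        # Soft (Polyak) update of the target critic; note that `critic` here
        # refers to the critic module, not the `self.critic` instance.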
        critic.soft_update(self.critic, self.critic_target, tau=self.tau)

        return {
            'q1': tf.reduce_mean(q1),
            'q2': tf.reduce_mean(q2),
            'critic_loss': critic_loss,
            'q1_grad': tf.reduce_mean(q1_grad_norm),
            'q2_grad': tf.reduce_mean(q2_grad_norm)
        }

    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0,
                 f_reg=1.0,
                 reward_bonus=5.0):
        """Creates networks.

        Args:
          observation_spec: Environment observation spec.
          action_spec: Action spec.
          actor_lr: Actor learning rate.
          critic_lr: Critic learning rate.
          alpha_lr: Temperature learning rate.
          discount: MDP discount.
          tau: Soft target update parameter.
          target_entropy: Target entropy.
          f_reg: Critic regularization weight.
          reward_bonus: Bonus added to the rewards.
        """
        assert len(observation_spec.shape) == 1
        state_dim = observation_spec.shape[0]

        hidden_dims = (256, 256, 256)
        self.actor = policies.DiagGuassianPolicy(state_dim,
                                                 action_spec,
                                                 hidden_dims=hidden_dims)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

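        # Learnable entropy temperature kept in log space (initialized to
        # log(1.0) = 0); presumably exponentiated to recover alpha in the
        # actor/temperature updates defined elsewhere in the class.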
        self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

        self.target_entropy = target_entropy
        self.discount = discount
        self.tau = tau

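        # Mixture behavioral-cloning model of the dataset actions; presumably
        # consumed by dist_critic (called in fit_critic, defined elsewhere).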
        self.bc = behavioral_cloning.BehavioralCloning(observation_spec,
                                                       action_spec,
                                                       mixture=True)

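        # Twin critic networks plus a target copy; the tau=1.0 soft update
        # below simply hard-copies the online weights into the target.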
        action_dim = action_spec.shape[0]
        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims)
        self.critic_target = critic.Critic(state_dim,
                                           action_dim,
                                           hidden_dims=hidden_dims)
        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)

        self.f_reg = f_reg
        self.reward_bonus = reward_bonus
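
    # Usage sketch (assumptions: the class is instantiated as `agent`, batches
    # come from an offline dataset iterator `dataset_iter` with the tuple layout
    # shown, and the reward bonus is added by the caller, as the `reward_bonus`
    # attribute suggests; names are illustrative only):
    #
    #   states, actions, rewards, discounts, next_states = next(dataset_iter)
    #   info = agent.fit_critic(states, actions, next_states,
    #                           rewards + agent.reward_bonus, discounts)
    #   tf.print('critic loss:', info['critic_loss'])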