Example #1
    def fit_value(self, states):
        """Updates critic parameters.

    Args:
      states: Batch of states.

    Returns:
      Dictionary with information to track.
    """

        actions, log_probs = self.actor(states,
                                        sample=True,
                                        with_log_probs=True)

        q1, q2 = self.critic(states, actions)
        q = tf.minimum(q1, q2) - self.alpha * log_probs

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.value.trainable_variables)

            v = self.value(states)

            value_loss = tf.losses.mean_squared_error(q, v)

        grads = tape.gradient(value_loss, self.value.trainable_variables)

        self.value_optimizer.apply_gradients(
            zip(grads, self.value.trainable_variables))

        if self.value_optimizer.iterations % self.target_update_period == 0:
            critic.soft_update(self.value, self.value_target, tau=self.tau)

        return {'v': tf.reduce_mean(v), 'value_loss': value_loss}
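
The periodic target update above goes through a module-level critic.soft_update helper. A minimal sketch of what such a Polyak-averaging update typically does (an assumption for illustration, not the source implementation):

import tensorflow as tf

def soft_update(net, target_net, tau=0.005):
    # Polyak averaging: target <- tau * online + (1 - tau) * target.
    # tau=1.0 amounts to a hard copy, which is how the target networks are
    # initialized in the constructors shown later.
    for var, target_var in zip(net.trainable_variables,
                               target_net.trainable_variables):
        target_var.assign(tau * var + (1.0 - tau) * target_var)

net = tf.keras.layers.Dense(4)
target = tf.keras.layers.Dense(4)
net.build((None, 8))
target.build((None, 8))
soft_update(net, target, tau=1.0)  # hard copy at initialization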
Example #2
    def fit_critic(self, states, actions, next_states, rewards, discounts):
        """Updates critic parameters.

    Args:
      states: Batch of states.
      actions: Batch of actions.
      next_states: Batch of next states.
      rewards: Batch of rewards.
      discounts: Batch of masks indicating the end of the episodes.

    Returns:
      Dictionary with information to track.
    """
        next_actions = self.actor(next_states, sample=True)

        next_q1, next_q2 = self.critic_target(next_states, next_actions)
        target_q = rewards + self.discount * discounts * tf.minimum(
            next_q1, next_q2)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.critic.trainable_variables)

            q1, q2 = self.critic(states, actions)
            policy_actions = self.actor(states, sample=True)

            q_pi1, q_pi2 = self.critic(states, policy_actions)

            def discriminator_loss(real_output, fake_output):
                diff = target_q - real_output

                def my_exp(x):
                    return 1.0 + x + 0.5 * tf.square(x)

                alpha = 10.0
                real_loss = tf.reduce_mean(my_exp(diff / alpha) * alpha)
                total_loss = real_loss + tf.reduce_mean(fake_output)
                return total_loss

            critic_loss1 = discriminator_loss(q1, q_pi1)
            critic_loss2 = discriminator_loss(q2, q_pi2)

            critic_loss = (critic_loss1 + critic_loss2)

        critic_grads = tape.gradient(critic_loss,
                                     self.critic.trainable_variables)

        self.critic_optimizer.apply_gradients(
            zip(critic_grads, self.critic.trainable_variables))

        critic.soft_update(self.critic, self.critic_target, tau=self.tau)

        return {
            'q1': tf.reduce_mean(q1),
            'q2': tf.reduce_mean(q2),
            'critic_loss': critic_loss,
            'q_pi1': tf.reduce_mean(q_pi1),
            'q_pi2': tf.reduce_mean(q_pi2)
        }
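
The nested my_exp above is a second-order Taylor expansion of exp(x) around zero, so the Bellman-error penalty grows only quadratically for large errors instead of exponentially. A quick standalone comparison (toy values, not from the source):

import tensorflow as tf

def my_exp(x):
    # Quadratic (second-order Taylor) approximation of exp(x) around 0.
    return 1.0 + x + 0.5 * tf.square(x)

x = tf.constant([-2.0, -0.5, 0.0, 0.5])
print(my_exp(x).numpy())  # [1.    0.625 1.    1.625]
print(tf.exp(x).numpy())  # [0.135 0.607 1.    1.649]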
Example #3
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_update_period=1,
                 target_entropy=0.0,
                 use_soft_critic=False):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_lr: Temperature learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_update_period: Target network update period.
      target_entropy: Target entropy.
      use_soft_critic: Whether to use soft critic representation.
    """
        assert len(observation_spec.shape) == 1
        state_dim = observation_spec.shape[0]

        self.actor = policies.DiagGuassianPolicy(state_dim, action_spec)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        self.log_alpha = tf.Variable(tf.math.log(0.1), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

        self.target_entropy = target_entropy
        self.discount = discount

        self.tau = tau
        self.target_update_period = target_update_period

        self.value = critic.CriticNet(state_dim)
        self.value_target = critic.CriticNet(state_dim)
        critic.soft_update(self.value, self.value_target, tau=1.0)
        self.value_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)

        if use_soft_critic:
            self.critic = critic.SoftCritic(state_dim, action_spec)
        else:
            action_dim = action_spec.shape[0]
            self.critic = critic.Critic(state_dim, action_dim)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)
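
The constructor stores self.log_alpha, while the update methods elsewhere read self.alpha, so the class presumably exposes the temperature through a property that exponentiates the log value. A self-contained sketch of that assumed accessor (the class name here is hypothetical):

import tensorflow as tf

class TemperatureHolder:
    # Hypothetical stand-in for the agent class, showing only the assumed
    # log_alpha -> alpha accessor.
    def __init__(self, initial_alpha=0.1):
        self.log_alpha = tf.Variable(tf.math.log(initial_alpha), trainable=True)

    @property
    def alpha(self):
        return tf.exp(self.log_alpha)

agent = TemperatureHolder()
print(float(agent.alpha))  # ~0.1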
Example #4
    def fit(self, states, actions, next_states, rewards, discounts):
        """Updates critic parameters.

    Args:
      states: Batch of states.
      actions: Batch of actions.
      next_states: Batch of next states.
      rewards: Batch of rewards.
      discounts: Batch of masks indicating the end of the episodes.

    Returns:
      Dictionary with information to track.
    """
        next_v = self.value_target(next_states)
        target_q = rewards + self.discount * discounts * next_v

        all_vars = (list(self.actor.trainable_variables) +
                    list(self.value.trainable_variables))
        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(all_vars)

            actor_log_probs = self.actor.log_probs(states, actions)
            q = self.value(states) + self.alpha * actor_log_probs

            adv = tf.stop_gradient(target_q - q)
            actor_loss = -tf.reduce_mean(actor_log_probs * adv)
            critic_loss = tf.losses.mean_squared_error(target_q, q)

        actor_grads = tape.gradient(actor_loss, self.actor.trainable_variables)
        critic_grads = tape.gradient(critic_loss,
                                     self.value.trainable_variables)

        self.actor_optimizer.apply_gradients(
            zip(actor_grads, self.actor.trainable_variables))
        self.critic_optimizer.apply_gradients(
            zip(critic_grads, self.value.trainable_variables))

        del tape
        if self.critic_optimizer.iterations % self.target_update_period == 0:
            critic.soft_update(self.value, self.value_target, tau=self.tau)

        return {
            'q': tf.reduce_mean(q),
            'critic_loss': critic_loss,
            'actor_log_probs': tf.reduce_mean(actor_log_probs),
            'adv': tf.reduce_mean(adv)
        }
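
The actor term above weights log-probabilities by a stop_gradient advantage, so only the policy parameters receive that part of the gradient. A self-contained numeric illustration with toy values (not from the source):

import tensorflow as tf

log_probs = tf.Variable([-1.2, -0.3, -2.0])  # stand-in for actor log-probs
target_q = tf.constant([1.0, 0.5, 2.0])
q = tf.constant([0.8, 0.9, 1.5])

with tf.GradientTape() as tape:
    adv = tf.stop_gradient(target_q - q)  # no gradient flows through the advantage
    actor_loss = -tf.reduce_mean(log_probs * adv)

# The gradient w.r.t. each log-prob is simply -adv / batch_size.
print(tape.gradient(actor_loss, log_probs).numpy())  # [-0.0667  0.1333 -0.1667]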
Example #5
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=1e-4,
                 critic_lr=3e-4,
                 alpha_lr=1e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_lr: Temperature learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_entropy: Target entropy.
    """
        assert len(observation_spec.shape) == 1
        state_dim = observation_spec.shape[0]

        beta_1 = 0.0
        self.actor = policies.DiagGuassianPolicy(state_dim, action_spec)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr,
                                                        beta_1=beta_1)

        self.log_alpha = tf.Variable(tf.math.log(0.1), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr,
                                                        beta_1=beta_1)

        self.target_entropy = target_entropy
        self.discount = discount
        self.tau = tau

        action_dim = action_spec.shape[0]
        self.critic = critic.Critic(state_dim, action_dim)
        self.critic_target = critic.Critic(state_dim, action_dim)
        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr, beta_1=beta_1)
Example #6
    def fit_critic(self, states, actions, next_states, rewards, discounts):
        """Updates critic parameters.

    Args:
      states: Batch of states.
      actions: Batch of actions.
      next_states: Batch of next states.
      rewards: Batch of rewards.
      discounts: Batch of masks indicating the end of the episodes.

    Returns:
      Dictionary with information to track.
    """
        next_actions = self.actor(next_states, sample=True)
        bc_log_probs = self.bc.policy.log_probs(next_states, next_actions)

        next_target_q1, next_target_q2 = self.critic_target(
            next_states, next_actions)
        target_q = rewards + self.discount * discounts * (tf.minimum(
            next_target_q1, next_target_q2) + self.alpha * bc_log_probs)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.critic.trainable_variables)

            q1, q2 = self.critic(states, actions)

            critic_loss = (tf.losses.mean_squared_error(target_q, q1) +
                           tf.losses.mean_squared_error(target_q, q2))
        critic_grads = tape.gradient(critic_loss,
                                     self.critic.trainable_variables)

        self.critic_optimizer.apply_gradients(
            zip(critic_grads, self.critic.trainable_variables))

        critic.soft_update(self.critic, self.critic_target, tau=self.tau)

        return {
            'q1': tf.reduce_mean(q1),
            'q2': tf.reduce_mean(q2),
            'critic_loss': critic_loss
        }
Example #7
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0,
                 f_reg=1.0,
                 reward_bonus=5.0,
                 num_augmentations=1,
                 env_name='',
                 batch_size=256):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_lr: Temperature learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_entropy: Target entropy.
      f_reg: Critic regularization weight.
      reward_bonus: Bonus added to the rewards.
      num_augmentations: Number of random crops.
      env_name: Env name.
      batch_size: Batch size.
    """
        del num_augmentations, env_name
        assert len(observation_spec.shape) == 1
        state_dim = observation_spec.shape[0]
        self.batch_size = batch_size

        hidden_dims = (256, 256, 256)
        self.actor = policies.DiagGuassianPolicy(state_dim,
                                                 action_spec,
                                                 hidden_dims=hidden_dims)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

        self.target_entropy = target_entropy
        self.discount = discount
        self.tau = tau

        self.bc = behavioral_cloning.BehavioralCloning(observation_spec,
                                                       action_spec,
                                                       mixture=True)

        action_dim = action_spec.shape[0]
        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims)
        self.critic_target = critic.Critic(state_dim,
                                           action_dim,
                                           hidden_dims=hidden_dims)
        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)

        self.f_reg = f_reg
        self.reward_bonus = reward_bonus

        self.model_dict = {
            'critic': self.critic,
            'actor': self.actor,
            'critic_target': self.critic_target,
            'actor_optimizer': self.actor_optimizer,
            'critic_optimizer': self.critic_optimizer,
            'alpha_optimizer': self.alpha_optimizer
        }
Example #8
    def fit_critic(self, states, actions, next_states, rewards, discounts):
        """Updates critic parameters.

    Args:
      states: Batch of states.
      actions: Batch of actions.
      next_states: Batch of next states.
      rewards: Batch of rewards.
      discounts: Batch of masks indicating the end of the episodes.

    Returns:
      Dictionary with information to track.
    """
        next_actions = self.actor(next_states, sample=True)
        policy_actions = self.actor(states, sample=True)

        next_target_q1, next_target_q2 = self.dist_critic(next_states,
                                                          next_actions,
                                                          target=True)
        target_q = rewards + self.discount * discounts * tf.minimum(
            next_target_q1, next_target_q2)

        critic_variables = self.critic.trainable_variables

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(critic_variables)
            q1, q2 = self.dist_critic(states, actions, stop_gradient=True)
            with tf.GradientTape(watch_accessed_variables=False,
                                 persistent=True) as tape2:
                tape2.watch([policy_actions])

                q1_reg, q2_reg = self.critic(states, policy_actions)

            q1_grads = tape2.gradient(q1_reg, policy_actions)
            q2_grads = tape2.gradient(q2_reg, policy_actions)

            q1_grad_norm = tf.reduce_sum(tf.square(q1_grads), axis=-1)
            q2_grad_norm = tf.reduce_sum(tf.square(q2_grads), axis=-1)

            del tape2

            q_reg = tf.reduce_mean(q1_grad_norm + q2_grad_norm)

            critic_loss = (tf.losses.mean_squared_error(target_q, q1) +
                           tf.losses.mean_squared_error(target_q, q2) +
                           self.f_reg * q_reg)

        critic_grads = tape.gradient(critic_loss, critic_variables)

        self.critic_optimizer.apply_gradients(
            zip(critic_grads, critic_variables))

        critic.soft_update(self.critic, self.critic_target, tau=self.tau)

        return {
            'q1': tf.reduce_mean(q1),
            'q2': tf.reduce_mean(q2),
            'critic_loss': critic_loss,
            'q1_grad': tf.reduce_mean(q1_grad_norm),
            'q2_grad': tf.reduce_mean(q2_grad_norm)
        }
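
The double GradientTape above implements a gradient penalty: the inner tape differentiates the critic output with respect to the policy actions, and the outer tape then backpropagates the squared gradient norm into the critic weights. A minimal standalone sketch of the same pattern with a toy linear critic (values are illustrative):

import tensorflow as tf

w = tf.Variable([[1.0], [2.0]])  # toy critic weights
a = tf.constant([[0.5, -1.0]])   # toy policy actions

with tf.GradientTape() as outer:
    with tf.GradientTape(watch_accessed_variables=False) as inner:
        inner.watch(a)
        q = tf.matmul(a, w)       # stand-in for Q(s, a)
    dq_da = inner.gradient(q, a)  # dQ/da, shape (1, 2)
    penalty = tf.reduce_mean(tf.reduce_sum(tf.square(dq_da), axis=-1))

print(outer.gradient(penalty, w).numpy())  # [[2.] [4.]], i.e. 2 * w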
Example #9
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0,
                 f_reg=1.0,
                 reward_bonus=5.0,
                 num_augmentations=1,
                 rep_learn_keywords='outer',
                 env_name='',
                 batch_size=256,
                 n_quantiles=5,
                 temp=0.1,
                 num_training_levels=200,
                 latent_dim=256,
                 n_levels_nce=5,
                 popart_norm_beta=0.1):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_lr: Temperature learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_entropy: Target entropy.
      f_reg: Critic regularization weight.
      reward_bonus: Bonus added to the rewards.
      num_augmentations: Number of DrQ augmentations (crops)
      rep_learn_keywords: Representation learning loss to add (see below)
      env_name: Env name
      batch_size: Batch size
      n_quantiles: Number of GVF quantiles
      temp: Temperature of NCE softmax
      num_training_levels: Number of training MDPs (Procgen=200)
      latent_dim: Latent dimensions of auxiliary MLPs
      n_levels_nce: Number of MDPs to use contrastive loss on
      popart_norm_beta: PopArt normalization constant

    For `rep_learn_keywords`, pick from:
      stop_grad_FQI: whether to stop-gradient the TD/FQI critic updates
      linear_Q: whether to use a linear critic

      successor_features: uses ||SF|| as cumulant
      gvf_termination: uses +1 if done else 0 as cumulant
      gvf_action_count: uses state-cond. action counts as cumulant

      nce: uses the multi-class dot-product InfoNCE objective
      cce: uses MoCo Categorical CrossEntropy objective
      energy: uses SimCLR + pairwise GVF distance (not fully tested)

    If no cumulant is specified, the reward is used as the default cumulant.
    """
        del actor_lr, critic_lr, alpha_lr, target_entropy
        self.action_spec = action_spec
        self.num_augmentations = num_augmentations
        self.rep_learn_keywords = rep_learn_keywords.split('__')
        self.batch_size = batch_size
        self.env_name = env_name
        self.stop_grad_fqi = 'stop_grad_FQI' in self.rep_learn_keywords
        critic_kwargs = {'hidden_dims': (1024, 1024)}
        self.latent_dim = latent_dim
        self.n_levels_nce = n_levels_nce
        hidden_dims = hidden_dims_per_level = (self.latent_dim,
                                               self.latent_dim)
        self.num_training_levels = int(num_training_levels)
        self.n_quantiles = n_quantiles
        self.temp = temp

        # Make 2 sets of weights:
        # - Critic
        # - Critic (target)
        # Optionally, make a 3rd set for per-level critics

        if observation_spec.shape == (64, 64, 3):
            # IMPALA for Procgen
            def conv_stack():
                return make_impala_cnn_network(depths=[16, 32, 32],
                                               use_batch_norm=False,
                                               dropout_rate=0.)

            state_dim = 256
        else:
            # Reduced architecture for DMC
            def conv_stack():
                return ConvStack(observation_spec.shape)

            state_dim = 50

        conv_stack_critic = conv_stack()
        conv_target_stack_critic = conv_stack()

        if observation_spec.shape == (64, 64, 3):
            conv_stack_critic.output_size = state_dim
            conv_target_stack_critic.output_size = state_dim
        critic_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                                feature_dim=state_dim,
                                                bprop_conv_stack=True)
        # Note: the target critic does not share any weights.
        critic_kwargs['encoder_target'] = ImageEncoder(
            conv_target_stack_critic,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        conv_stack_critic_per_level = conv_stack()
        conv_target_stack_critic_per_level = conv_stack()
        if observation_spec.shape == (64, 64, 3):
            conv_stack_critic_per_level.output_size = state_dim
            conv_target_stack_critic_per_level.output_size = state_dim

        self.encoder_per_level = ImageEncoder(conv_stack_critic_per_level,
                                              feature_dim=state_dim,
                                              bprop_conv_stack=True)
        self.encoder_per_level_target = ImageEncoder(
            conv_target_stack_critic_per_level,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        criticCL.soft_update(self.encoder_per_level,
                             self.encoder_per_level_target,
                             tau=1.0)

        if self.num_augmentations == 0:
            dummy_state = tf.constant(
                np.zeros([1] + list(observation_spec.shape)))
        else:  # account for padding of +4 everywhere and then cropping out 68
            dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
        dummy_enc = critic_kwargs['encoder'](dummy_state)

        @tf.function
        def init_models():
            """This function initializes all auxiliary networks (state and action encoders) with dummy input (Procgen-specific, 68x68x3, 15 actions).
      """
            critic_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder_target'](dummy_state)
            self.encoder_per_level(dummy_state)
            self.encoder_per_level_target(dummy_state)

        init_models()

        action_dim = action_spec.maximum.item() + 1

        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.reg = f_reg
        self.reward_bonus = reward_bonus

        self.critic = criticCL.Critic(state_dim,
                                      action_dim,
                                      hidden_dims=hidden_dims,
                                      encoder=critic_kwargs['encoder'],
                                      discrete_actions=True,
                                      linear='linear_Q'
                                      in self.rep_learn_keywords)
        self.critic_target = criticCL.Critic(
            state_dim,
            action_dim,
            hidden_dims=hidden_dims,
            encoder=critic_kwargs['encoder_target'],
            discrete_actions=True,
            linear='linear_Q' in self.rep_learn_keywords)

        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
        self.task_critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=3e-4)
        self.br_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)

        if 'cce' in self.rep_learn_keywords:
            self.classifier = tf.keras.Sequential([
                tf.keras.layers.Dense(self.latent_dim, use_bias=True),
                tf.keras.layers.ReLU(),
                tf.keras.layers.Dense(self.n_quantiles, use_bias=True)
            ],
                                                  name='classifier')
        elif 'nce' in self.rep_learn_keywords:
            self.embedding = tf.keras.Sequential([
                tf.keras.layers.Dense(self.latent_dim, use_bias=True),
                tf.keras.layers.ReLU(),
                tf.keras.layers.Dense(self.latent_dim, use_bias=True)
            ],
                                                 name='embedding')

        # This snippet initializes all auxiliary networks (state and action encoders)
        # with dummy input (Procgen-specific, 68x68x3, 15 actions).
        dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
        phi_s = self.critic.encoder(dummy_state)
        phi_a = tf.eye(action_dim, dtype=tf.float32)
        if 'linear_Q' in self.rep_learn_keywords:
            _ = self.critic.critic1.state_encoder(phi_s)
            _ = self.critic.critic2.state_encoder(phi_s)
            _ = self.critic.critic1.action_encoder(phi_a)
            _ = self.critic.critic2.action_encoder(phi_a)
            _ = self.critic_target.critic1.state_encoder(phi_s)
            _ = self.critic_target.critic2.state_encoder(phi_s)
            _ = self.critic_target.critic1.action_encoder(phi_a)
            _ = self.critic_target.critic2.action_encoder(phi_a)
        if 'cce' in self.rep_learn_keywords:
            self.classifier(phi_s)
        elif 'nce' in self.rep_learn_keywords:
            self.embedding(phi_s)

        self.target_critic_to_use = self.critic_target
        self.critic_to_use = self.critic

        criticCL.soft_update(self.critic, self.critic_target, tau=1.0)

        self.cce = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

        self.bc = None

        if 'successor_features' in self.rep_learn_keywords:
            self.output_dim_level = self.latent_dim
        elif 'gvf_termination' in self.rep_learn_keywords:
            self.output_dim_level = 1
        elif 'gvf_action_count' in self.rep_learn_keywords:
            self.output_dim_level = action_dim
        else:
            self.output_dim_level = action_dim

        self.task_critic_one = criticCL.Critic(
            state_dim,
            self.output_dim_level * self.num_training_levels,
            hidden_dims=hidden_dims_per_level,
            encoder=None,  # critic_kwargs['encoder'],
            discrete_actions=True,
            cross_norm=False)
        self.task_critic_target_one = criticCL.Critic(
            state_dim,
            self.output_dim_level * 200,
            hidden_dims=hidden_dims_per_level,
            encoder=None,  # critic_kwargs['encoder'],
            discrete_actions=True,
            cross_norm=False)
        self.task_critic_one(dummy_enc,
                             actions=None,
                             training=False,
                             return_features=False,
                             stop_grad_features=False)
        self.task_critic_target_one(dummy_enc,
                                    actions=None,
                                    training=False,
                                    return_features=False,
                                    stop_grad_features=False)
        criticCL.soft_update(self.task_critic_one,
                             self.task_critic_target_one,
                             tau=1.0)

        # Normalization constant beta, set to the default value recommended in the PopArt paper.
        self.reward_normalizer = popart.PopArt(
            running_statistics.EMAMeanStd(popart_norm_beta))
        self.reward_normalizer.init()

        if 'CLIP' in self.rep_learn_keywords or 'clip' in self.rep_learn_keywords:
            self.loss_temp = tf.Variable(tf.constant(0.0, dtype=tf.float32),
                                         name='loss_temp',
                                         trainable=True)

        self.model_dict = {
            'critic': self.critic,
            'critic_target': self.critic_target,
            'critic_optimizer': self.critic_optimizer,
            'br_optimizer': self.br_optimizer
        }

        self.model_dict['encoder_perLevel'] = self.encoder_per_level
        self.model_dict[
            'encoder_perLevel_target'] = self.encoder_per_level_target
        self.model_dict['task_critic'] = self.task_critic_one
        self.model_dict['task_critic_target'] = self.task_critic_target_one
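
For the nce keyword above, the docstring describes a multi-class dot-product InfoNCE objective. A self-contained sketch of that family of losses; treating the matching rows of the two embedding batches as positives is an assumption for illustration, not taken from the source:

import tensorflow as tf

def info_nce(z_a, z_b, temp=0.1):
    # Dot-product similarities between the two embedding batches; the matching
    # row of z_b is treated as the positive for each row of z_a.
    logits = tf.matmul(z_a, z_b, transpose_b=True) / temp
    labels = tf.range(tf.shape(z_a)[0])
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

z1 = tf.math.l2_normalize(tf.random.normal([8, 32]), axis=-1)
z2 = tf.math.l2_normalize(tf.random.normal([8, 32]), axis=-1)
print(float(info_nce(z1, z2)))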
Example #10
    def fit_task_critics(self, mb_states, mb_actions, mb_next_states,
                         mb_next_actions, mb_rewards, mb_discounts, level_ids):
        """Updates per-level critic parameters.

    Args:
      mb_states: Batch of states.
      mb_actions: Batch of actions.
      mb_next_states: Batch of next states.
      mb_next_actions: Batch of next actions from training policy.
      mb_rewards: Batch of rewards.
      mb_discounts: Batch of masks indicating the end of the episodes.
      level_ids: Batch of level ids

    Returns:
      Dictionary with information to track.
    """
        if 'popart' in self.rep_learn_keywords:
            # The PopArt normalization normalizes the GVF's cumulant signal so that
            # it's not affected by the difference in scales across MDPs.
            mb_rewards = self.reward_normalizer.normalize_target(mb_rewards)

        trainable_variables = self.encoder_per_level.trainable_variables + self.task_critic_one.trainable_variables

        next_action_indices = tf.stack([
            tf.range(tf.shape(mb_next_actions)[0],
                     dtype=tf.int32), level_ids * self.output_dim_level +
            tf.cast(mb_next_actions, dtype=tf.int32)
        ],
                                       axis=-1)

        action_indices = tf.stack([
            tf.range(tf.shape(mb_actions)[0],
                     dtype=tf.int32), level_ids * self.output_dim_level +
            tf.cast(mb_actions, dtype=tf.int32)
        ],
                                  axis=-1)
        level_ids = tf.stack([
            tf.range(tf.shape(mb_next_actions)[0], dtype=tf.int32),
            tf.cast(level_ids, dtype=tf.int32)
        ],
                             axis=-1)

        next_states = [self.encoder_per_level_target(mb_next_states[0])]
        next_q1, next_q2 = self.task_critic_target_one(next_states[0],
                                                       actions=None)
        # Learn d-dimensional successor features
        if 'successor_features' in self.rep_learn_keywords:
            target_q = tf.concat(
                [next_states[0]] * 200, 1) + self.discount * tf.expand_dims(
                    mb_discounts, 1) * tf.minimum(next_q1, next_q2)
        # Learn discounted episode termination
        elif 'gvf_termination' in self.rep_learn_keywords:
            target_q = tf.expand_dims(
                mb_discounts, 1) + self.discount * tf.expand_dims(
                    mb_discounts, 1) * tf.minimum(next_q1, next_q2)
        # Learn discounted future action counts
        elif 'gvf_action_count' in self.rep_learn_keywords:
            target_q = tf.concat(
                [tf.one_hot(mb_actions, depth=self.action_dim)] * 200,
                1) + self.discount * tf.expand_dims(
                    mb_discounts, 1) * tf.minimum(next_q1, next_q2)
        else:
            target_q = tf.expand_dims(
                mb_rewards, 1) + self.discount * tf.expand_dims(
                    mb_discounts, 1) * tf.minimum(next_q1, next_q2)

        if ('successor_features' in self.rep_learn_keywords
                or 'gvf_termination' in self.rep_learn_keywords
                or 'gvf_action_count' in self.rep_learn_keywords):
            target_q = tf.reshape(target_q, (-1, 200, self.output_dim_level))
            target_q = tf.gather_nd(target_q, indices=level_ids)
        else:
            target_q = tf.gather_nd(target_q, indices=next_action_indices)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(trainable_variables)

            states = [self.encoder_per_level(mb_states[0])]
            q1_all, q2_all = self.task_critic_one(states[0], actions=None)

            q = tf.minimum(q1_all, q2_all)
            if ('successor_features' in self.rep_learn_keywords
                    or 'gvf_termination' in self.rep_learn_keywords
                    or 'gvf_action_count' in self.rep_learn_keywords):
                q1_all = tf.reshape(q1_all, (-1, 200, self.output_dim_level))
                q2_all = tf.reshape(q2_all, (-1, 200, self.output_dim_level))
                critic_loss = (
                    tf.losses.mean_squared_error(
                        target_q, tf.gather_nd(q1_all, indices=level_ids)) +
                    tf.losses.mean_squared_error(
                        target_q, tf.gather_nd(q2_all, indices=level_ids)))
            else:
                critic_loss = (tf.losses.mean_squared_error(
                    target_q, tf.gather_nd(q1_all, indices=action_indices)) +
                               tf.losses.mean_squared_error(
                                   target_q,
                                   tf.gather_nd(q2_all,
                                                indices=action_indices)))

        critic_grads = tape.gradient(critic_loss, trainable_variables)

        self.task_critic_optimizer.apply_gradients(
            zip(critic_grads, trainable_variables))

        criticCL.soft_update(self.encoder_per_level,
                             self.encoder_per_level_target,
                             tau=self.tau)
        criticCL.soft_update(self.task_critic_one,
                             self.task_critic_target_one,
                             tau=self.tau)

        gn = tf.reduce_mean(
            [tf.linalg.norm(v) for v in critic_grads if v is not None])

        return {
            'avg_level_critic_loss': tf.reduce_mean(critic_loss),
            'avg_q': tf.reduce_mean(q),
            'level_critic_grad_norm': gn
        }
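
The index bookkeeping above reduces to tf.gather_nd with each row paired to a per-sample column offset (level id times the per-level output width, plus the action). A small standalone illustration of that indexing:

import tensorflow as tf

q_all = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))  # (batch=3, columns=4)
cols = tf.constant([1, 3, 0], dtype=tf.int32)               # e.g. level offset + action
indices = tf.stack([tf.range(3, dtype=tf.int32), cols], axis=-1)
print(tf.gather_nd(q_all, indices).numpy())                 # [1. 7. 8.]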
Example #11
    def fit_critic(self, states, actions, next_states, next_actions, rewards,
                   discounts):
        """Updates critic parameters.

    Args:
      states: Batch of states.
      actions: Batch of actions.
      next_states: Batch of next states.
      next_actions: Batch of next actions from training policy.
      rewards: Batch of rewards.
      discounts: Batch of masks indicating the end of the episodes.

    Returns:
      Dictionary with information to track.
    """
        action_indices = tf.stack(
            [tf.range(tf.shape(actions)[0], dtype=tf.int64), actions], axis=-1)
        next_action_indices = tf.stack([
            tf.range(tf.shape(next_actions)[0], dtype=tf.int64), next_actions
        ],
                                       axis=-1)

        if self.num_augmentations > 1:
            target_q = 0.
            for i in range(self.num_augmentations):
                next_q1_i, next_q2_i = self.critic_target(
                    next_states[i],
                    actions=None,
                    stop_grad_features=self.stop_grad_fqi)
                target_q_i = tf.expand_dims(
                    rewards, 1) + self.discount * tf.expand_dims(
                        discounts, 1) * tf.minimum(next_q1_i, next_q2_i)
                target_q += target_q_i
            target_q /= self.num_augmentations
        elif self.num_augmentations == 1:
            next_q1, next_q2 = self.critic_target(
                next_states[0],
                actions=None,
                stop_grad_features=self.stop_grad_fqi)
            target_q = tf.expand_dims(
                rewards, 1) + self.discount * tf.expand_dims(
                    discounts, 1) * tf.minimum(next_q1, next_q2)
        else:
            next_q1, next_q2 = self.target_critic_to_use(
                next_states,
                actions=None,
                stop_grad_features=self.stop_grad_fqi)
            target_q = tf.expand_dims(
                rewards, 1) + self.discount * tf.expand_dims(
                    discounts, 1) * tf.minimum(next_q1, next_q2)

        target_q = tf.gather_nd(target_q, indices=next_action_indices)
        trainable_variables = self.critic.trainable_variables

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(trainable_variables)

            if self.num_augmentations > 1:
                critic_loss = 0.
                q1 = 0.
                q2 = 0.
                for i in range(self.num_augmentations):
                    q1_i, q2_i = self.critic_to_use(
                        states[i],
                        actions=None,
                        stop_grad_features=self.stop_grad_fqi)
                    critic_loss_i = (tf.losses.mean_squared_error(
                        target_q, tf.gather_nd(q1_i, indices=action_indices)) +
                                     tf.losses.mean_squared_error(
                                         target_q,
                                         tf.gather_nd(q2_i,
                                                      indices=action_indices)))
                    q1 += q1_i
                    q2 += q2_i
                    critic_loss += critic_loss_i
                q1 /= self.num_augmentations
                q2 /= self.num_augmentations
                critic_loss /= self.num_augmentations
                q = tf.minimum(q1, q2)  # needed below for the CQL logsumexp term
            elif self.num_augmentations == 1:
                q1, q2 = self.critic_to_use(
                    states[0],
                    actions=None,
                    stop_grad_features=self.stop_grad_fqi)
                q = tf.minimum(q1, q2)
                critic_loss = (
                    tf.losses.mean_squared_error(
                        target_q, tf.gather_nd(q1, indices=action_indices)) +
                    tf.losses.mean_squared_error(
                        target_q, tf.gather_nd(q2, indices=action_indices)))
            else:
                q1, q2 = self.critic_to_use(
                    states,
                    actions=None,
                    stop_grad_features=self.stop_grad_fqi)
                q = tf.minimum(q1, q2)
                critic_loss = (
                    tf.losses.mean_squared_error(
                        target_q, tf.gather_nd(q1, indices=action_indices)) +
                    tf.losses.mean_squared_error(
                        target_q, tf.gather_nd(q2, indices=action_indices)))

            # LogSumExp (LSE) conservative penalty from CQL
            cql_logsumexp = tf.reduce_logsumexp(q, 1)
            cql_loss = tf.reduce_mean(cql_logsumexp -
                                      tf.gather_nd(q, indices=action_indices))
            # Jointly optimize both losses
            critic_loss = critic_loss + cql_loss

        critic_grads = tape.gradient(critic_loss, trainable_variables)

        self.critic_optimizer.apply_gradients(
            zip(critic_grads, trainable_variables))

        criticCL.soft_update(self.critic, self.critic_target, tau=self.tau)

        gn = tf.reduce_mean(
            [tf.linalg.norm(v) for v in critic_grads if v is not None])

        return {
            'q1': tf.reduce_mean(q1),
            'q2': tf.reduce_mean(q2),
            'critic_loss': critic_loss,
            'cql_loss': cql_loss,
            'critic_grad_norm': gn
        }
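
The cql_loss above is the discrete-action conservative penalty from CQL: a soft maximum (logsumexp) over all action values minus the value of the action actually taken in the data. A standalone numeric sketch of the same quantity (toy values):

import tensorflow as tf

q = tf.constant([[1.0, 2.0, 0.5],
                 [0.0, 0.2, 3.0]])  # (batch, num_actions)
data_actions = tf.constant([1, 2])
indices = tf.stack([tf.range(2), data_actions], axis=-1)
q_taken = tf.gather_nd(q, indices)  # [2.0, 3.0]
cql_penalty = tf.reduce_mean(tf.reduce_logsumexp(q, axis=1) - q_taken)
print(float(cql_penalty))           # ~0.28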
Example #12
        self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
        self.log_cql_alpha = self.log_alpha
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims,
                                    encoder=critic_kwargs['encoder'],
                                    discrete_actions=True)
        self.critic_target = critic.Critic(
            state_dim,
            action_dim,
            hidden_dims=hidden_dims,
            encoder=critic_kwargs['encoder_target'],
            discrete_actions=True)
        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)
        self.tau = tau

        self.reg = reg
        self.target_entropy = target_entropy
        self.discount = discount

        self.num_cql_actions = num_cql_actions
        self.bc_pretraining_steps = bc_pretraining_steps
        self.min_q_weight = min_q_weight

        self.bc = None

        self.model_dict = {
Example #13
    def __init__(self,
                 observation_spec,
                 action_spec,
                 embedding_dim=256,
                 hidden_dims=(256, 256),
                 sequence_length=2,
                 learning_rate=None,
                 discount=0.95,
                 target_update_period=1000,
                 num_augmentations=0,
                 rep_learn_keywords='outer',
                 batch_size=256):
        """Creates networks.

    Args:
      observation_spec: State spec.
      action_spec: Action spec.
      embedding_dim: Embedding size.
      hidden_dims: List of hidden dimensions.
      sequence_length: Expected length of sequences provided as input.
      learning_rate: Learning rate.
      discount: discount factor.
      target_update_period: How frequently update target?
      num_augmentations: Number of DrQ random crops.
      rep_learn_keywords: Representation learning loss to add.
      batch_size: batch size.
    """
        super().__init__()
        action_dim = action_spec.maximum.item() + 1
        self.observation_spec = observation_spec
        self.action_dim = action_dim
        self.action_spec = action_spec
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.discount = discount
        self.tau = 0.005
        self.discount = 0.99  # Note: overrides the `discount` argument assigned above.
        self.target_update_period = target_update_period
        self.num_augmentations = num_augmentations
        self.rep_learn_keywords = rep_learn_keywords.split('__')
        self.batch_size = batch_size

        critic_kwargs = {}

        if observation_spec.shape == (64, 64, 3):
            # IMPALA for Procgen
            def conv_stack():
                return make_impala_cnn_network(depths=[16, 32, 32],
                                               use_batch_norm=False,
                                               dropout_rate=0.)

            state_dim = 256
        else:
            # Reduced architecture for DMC
            def conv_stack():
                return ConvStack(observation_spec.shape)

            state_dim = 50

        conv_stack_critic = conv_stack()
        conv_target_stack_critic = conv_stack()

        if observation_spec.shape == (64, 64, 3):
            conv_stack_critic.output_size = state_dim
            conv_target_stack_critic.output_size = state_dim

        critic_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                                feature_dim=state_dim,
                                                bprop_conv_stack=True)
        critic_kwargs['encoder_target'] = ImageEncoder(
            conv_target_stack_critic,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        self.embedder = tf_utils.EmbedNet(state_dim,
                                          embedding_dim=self.embedding_dim,
                                          hidden_dims=hidden_dims)
        self.f_value = tf_utils.create_mlp(self.embedding_dim,
                                           1,
                                           hidden_dims=hidden_dims,
                                           activation=tf.nn.swish)
        self.f_value_target = tf_utils.create_mlp(self.embedding_dim,
                                                  1,
                                                  hidden_dims=hidden_dims,
                                                  activation=tf.nn.swish)
        self.f_trans = tf_utils.create_mlp(self.embedding_dim +
                                           self.embedding_dim,
                                           self.embedding_dim,
                                           hidden_dims=hidden_dims,
                                           activation=tf.nn.swish)
        self.f_out = tf_utils.create_mlp(self.embedding_dim +
                                         self.embedding_dim,
                                         2,
                                         hidden_dims=hidden_dims,
                                         activation=tf.nn.swish)

        self.action_encoder = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(
                    self.embedding_dim, use_bias=True
                ),  # , kernel_regularizer=tf.keras.regularizers.l2(WEIGHT_DECAY)
                tf.keras.layers.ReLU(),
                tf.keras.layers.Dense(self.embedding_dim)
            ],
            name='action_encoder')

        if self.num_augmentations == 0:
            dummy_state = tf.constant(
                np.zeros(shape=[1] + list(observation_spec.shape)))
            self.obs_spec = list(observation_spec.shape)
        else:  # account for padding of +4 everywhere and then cropping out 68
            dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
            self.obs_spec = [68, 68, 3]

        @tf.function
        def init_models():
            critic_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder_target'](dummy_state)
            self.action_encoder(
                tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

        init_models()

        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims,
                                    encoder=critic_kwargs['encoder'],
                                    discrete_actions=True,
                                    linear='linear_Q'
                                    in self.rep_learn_keywords)
        self.critic_target = critic.Critic(
            state_dim,
            action_dim,
            hidden_dims=hidden_dims,
            encoder=critic_kwargs['encoder_target'],
            discrete_actions=True,
            linear='linear_Q' in self.rep_learn_keywords)

        @tf.function
        def init_models2():
            dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
            phi_s = self.critic.encoder(dummy_state)
            phi_a = tf.eye(15, dtype=tf.float32)
            if 'linear_Q' in self.rep_learn_keywords:
                _ = self.critic.critic1.state_encoder(phi_s)
                _ = self.critic.critic2.state_encoder(phi_s)
                _ = self.critic.critic1.action_encoder(phi_a)
                _ = self.critic.critic2.action_encoder(phi_a)
                _ = self.critic_target.critic1.state_encoder(phi_s)
                _ = self.critic_target.critic2.state_encoder(phi_s)
                _ = self.critic_target.critic1.action_encoder(phi_a)
                _ = self.critic_target.critic2.action_encoder(phi_a)

        init_models2()

        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        critic.soft_update(self.f_value, self.f_value_target, tau=1.0)

        learning_rate = learning_rate or 1e-4
        # Note: the optimizers below use hardcoded learning rates rather than
        # the `learning_rate` computed above.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

        self.all_variables = (self.embedder.trainable_variables +
                              self.f_value.trainable_variables +
                              self.f_value_target.trainable_variables +
                              self.f_trans.trainable_variables +
                              self.f_out.trainable_variables +
                              self.critic.trainable_variables +
                              self.critic_target.trainable_variables)

        self.model_dict = {
            'action_encoder': self.action_encoder,
            'f_out': self.f_out,
            'f_trans': self.f_trans,
            'f_value_target': self.f_value_target,
            'f_value': self.f_value,
            'embedder': self.embedder,
            'critic': self.critic,
            'critic_target': self.critic_target,
            'critic_optimizer': self.critic_optimizer,
            'optimizer': self.optimizer
        }
Example #14
    def fit_embedding(self, states, actions, next_states, next_actions,
                      rewards, discounts):
        """Updates critic parameters.

    Args:
      states: Batch of states.
      actions: Batch of actions.
      next_states: Batch of next states.
      next_actions: Batch of next actions.
      rewards: Batch of rewards.
      discounts: Batch of masks indicating the end of the episodes.

    Returns:
      Dictionary with information to track.
    """

        states = tf.transpose(
            tf.stack([states, next_states])[:, 0], (1, 0, 2, 3, 4))
        batch_size = tf.shape(states)[0]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.all_variables)

            actions = tf.transpose(
                tf.one_hot(tf.stack([actions, next_actions]),
                           depth=self.action_dim), (1, 0, 2))
            actions = tf.reshape(
                actions, [batch_size * self.sequence_length, self.action_dim])
            actions = self.action_encoder(actions)
            actions = tf.reshape(
                actions,
                [batch_size, self.sequence_length, self.embedding_dim])

            all_states = tf.reshape(
                states, [batch_size * self.sequence_length] + self.obs_spec)
            all_features = self.critic.encoder(all_states)
            all_embeddings = self.embedder(all_features, stop_gradient=False)
            embeddings = tf.reshape(
                all_embeddings,
                [batch_size, self.sequence_length, self.embedding_dim])[:,
                                                                        0, :]

            all_pred_values = []
            all_pred_rewards = []
            all_pred_discounts = []
            for idx in range(self.sequence_length):
                pred_value = self.f_value(embeddings)[Ellipsis, 0]
                pred_reward, pred_discount = tf.unstack(self.f_out(
                    tf.concat([embeddings, actions[:, idx, :]], -1)),
                                                        axis=-1)
                pred_embeddings = embeddings + self.f_trans(
                    tf.concat([embeddings, actions[:, idx, :]], -1))

                all_pred_values.append(pred_value)
                all_pred_rewards.append(pred_reward)
                all_pred_discounts.append(pred_discount)

                embeddings = pred_embeddings

            last_value = tf.stop_gradient(
                self.f_value_target(embeddings)[Ellipsis,
                                                0]) / (1 - self.discount)
            all_true_values = []
            # for idx in range(self.sequence_length - 1, -1, -1):
            value = self.discount * discounts * last_value + rewards  #[:, idx]
            all_true_values.append(value)
            last_value = value
            all_true_values = all_true_values[::-1]

            reward_error = tf.stack(all_pred_rewards, -1)[:, 0] - rewards
            value_error = tf.stack(
                all_pred_values,
                -1) - (1 - self.discount) * tf.stack(all_true_values, -1)
            reward_loss = tf.reduce_sum(tf.math.square(reward_error), -1)
            value_loss = tf.reduce_sum(tf.math.square(value_error), -1)

            loss = tf.reduce_mean(reward_loss + value_loss)

        grads = tape.gradient(loss, self.all_variables)

        self.optimizer.apply_gradients(zip(grads, self.all_variables))
        if self.optimizer.iterations % self.target_update_period == 0:
            critic.soft_update(self.f_value, self.f_value_target, tau=self.tau)

        return {
            'embed_loss': loss,
            'reward_loss': tf.reduce_mean(reward_loss),
            'value_loss': tf.reduce_mean(value_loss),
        }
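
In the rollout above, f_out produces two values per step that tf.unstack splits into reward and discount predictions, while f_trans is applied as a residual update to the latent embedding. A tiny standalone illustration of the unstack step (toy numbers):

import tensorflow as tf

head_out = tf.constant([[0.3, 0.9],
                        [0.1, 0.8]])  # (batch, 2) output of a two-unit head
pred_reward, pred_discount = tf.unstack(head_out, axis=-1)
print(pred_reward.numpy())    # [0.3 0.1]
print(pred_discount.numpy())  # [0.9 0.8]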
Example #15
    def fit_critic(self, states, actions, next_states, next_actions, rewards,
                   discounts):
        """Updates critic parameters.

    Args:
      states: Batch of states.
      actions: Batch of actions.
      next_states: Batch of next states.
      next_actions: Batch of next actions from training policy.
      rewards: Batch of rewards.
      discounts: Batch of masks indicating the end of the episodes.

    Returns:
      Dictionary with information to track.
    """

        next_q1, next_q2 = self.critic_target(next_states, next_actions)
        target_q = rewards + self.discount * discounts * tf.minimum(
            next_q1, next_q2)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.critic.trainable_variables)

            q1, q2 = self.critic(states, actions)
            critic_loss = (tf.losses.mean_squared_error(target_q, q1) +
                           tf.losses.mean_squared_error(target_q, q2))

            n_states = tf.repeat(states[tf.newaxis, :, :],
                                 self.num_cql_actions, 0)
            n_states = tf.reshape(n_states, [-1, n_states.get_shape()[-1]])

            n_rand_actions = tf.random.uniform(
                [tf.shape(n_states)[0],
                 actions.get_shape()[-1]], self.actor.action_spec.minimum,
                self.actor.action_spec.maximum)

            n_actions, n_log_probs = self.actor(n_states,
                                                sample=True,
                                                with_log_probs=True)

            q1_rand, q2_rand = self.critic(n_states, n_rand_actions)
            q1_curr_actions, q2_curr_actions = self.critic(n_states, n_actions)

            log_u = -tf.reduce_mean(
                tf.repeat((tf.math.log(2.0 * self.actor.action_scale) *
                           n_rand_actions.shape[-1])[tf.newaxis, :],
                          tf.shape(n_states)[0], 0), 1)

            log_probs_all = tf.concat([n_log_probs, log_u], 0)
            q1_all = tf.concat([q1_curr_actions, q1_rand], 0)
            q2_all = tf.concat([q2_curr_actions, q2_rand], 0)

            def get_qf_loss(q, log_probs):
                q -= log_probs
                q = tf.reshape(q, [-1, tf.shape(states)[0]])
                return tf.math.reduce_logsumexp(q, axis=0)

            min_qf1_loss = get_qf_loss(q1_all, log_probs_all)
            min_qf2_loss = get_qf_loss(q2_all, log_probs_all)

            cql_loss = tf.reduce_mean((min_qf1_loss - q1) +
                                      (min_qf2_loss - q2))
            critic_loss += self.min_q_weight * cql_loss

        critic_grads = tape.gradient(critic_loss,
                                     self.critic.trainable_variables)

        self.critic_optimizer.apply_gradients(
            zip(critic_grads, self.critic.trainable_variables))

        critic.soft_update(self.critic, self.critic_target, tau=self.tau)

        return {
            'q1': tf.reduce_mean(q1),
            'q2': tf.reduce_mean(q2),
            'critic_loss': critic_loss,
            'cql_loss': cql_loss
        }
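
get_qf_loss above subtracts each sample's log-probability as an importance correction, stacks the sampled-action Q-values along a leading axis, and takes a logsumexp over that axis to form a soft maximum per state. A small standalone illustration of the reshape-then-logsumexp step (toy shapes and values, not from the source):

import tensorflow as tf

batch_size, num_sampled = 3, 4
# Q-values for num_sampled actions per state, flattened sample-major as above.
q_flat = tf.random.normal([num_sampled * batch_size])
log_probs = tf.math.log(tf.fill([num_sampled * batch_size], 1.0 / num_sampled))

q_corrected = q_flat - log_probs
q_per_state = tf.reshape(q_corrected, [num_sampled, batch_size])
soft_max_q = tf.math.reduce_logsumexp(q_per_state, axis=0)  # one value per state
print(soft_max_q.numpy().shape)                             # (3,)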
Example #16
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0,
                 f_reg=1.0,
                 reward_bonus=5.0,
                 num_augmentations=1,
                 env_name='',
                 batch_size=256):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_lr: Temperature learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_entropy: Target entropy.
      f_reg: Critic regularization weight.
      reward_bonus: Bonus added to the rewards.
      num_augmentations: Number of DrQ augmentations (crops)
      env_name: Env name
      batch_size: Batch size
    """
        self.num_augmentations = num_augmentations
        # Discrete actions iff the action spec has an empty (scalar) shape.
        self.discrete_actions = False if len(action_spec.shape) else True
        self.batch_size = batch_size

        actor_kwargs = {'hidden_dims': (1024, 1024)}
        critic_kwargs = {'hidden_dims': (1024, 1024)}

        # DRQ encoder params.
        # https://github.com/denisyarats/drq/blob/master/config.yaml#L73

        # Make 4 sets of weights:
        # - BC
        # - Actor
        # - Critic
        # - Critic (target)

        if observation_spec.shape == (64, 64, 3):
            # IMPALA for Procgen
            def conv_stack():
                return make_impala_cnn_network(depths=[16, 32, 32],
                                               use_batch_norm=False,
                                               dropout_rate=0.)

            state_dim = 256
        else:
            # Reduced architecture for DMC
            def conv_stack():
                return ConvStack(observation_spec.shape)

            state_dim = 50

        conv_stack_bc = conv_stack()
        conv_stack_actor = conv_stack()
        conv_stack_critic = conv_stack()
        conv_target_stack_critic = conv_stack()

        if observation_spec.shape == (64, 64, 3):
            conv_stack_bc.output_size = state_dim
            conv_stack_actor.output_size = state_dim
            conv_stack_critic.output_size = state_dim
            conv_target_stack_critic.output_size = state_dim
        # Combine and stop_grad some of the above conv stacks
        actor_kwargs['encoder_bc'] = ImageEncoder(conv_stack_bc,
                                                  feature_dim=state_dim,
                                                  bprop_conv_stack=True)
        actor_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                               feature_dim=state_dim,
                                               bprop_conv_stack=False)
        critic_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                                feature_dim=state_dim,
                                                bprop_conv_stack=True)
        # Note: the target critic does not share any weights.
        critic_kwargs['encoder_target'] = ImageEncoder(
            conv_target_stack_critic,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        if self.num_augmentations == 0:
            dummy_state = tf.constant(
                np.zeros(shape=[1] + list(observation_spec.shape)))
        else:  # account for the +4 padding and the subsequent 68x68 crop
            dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

        @tf.function
        def init_models():
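            # One dummy forward pass per encoder builds their variables before the
            # policy, critic, and target networks are constructed around them below.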
            actor_kwargs['encoder_bc'](dummy_state)
            actor_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder_target'](dummy_state)

        init_models()

        if self.discrete_actions:
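            # Discrete action space: categorical policy head; the number of actions
            # is the action spec's maximum index plus one.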
            hidden_dims = ()
            self.actor = policies.CategoricalPolicy(
                state_dim,
                action_spec,
                hidden_dims=hidden_dims,
                encoder=actor_kwargs['encoder'])
            action_dim = action_spec.maximum.item() + 1
        else:
            hidden_dims = (256, 256, 256)
            self.actor = policies.DiagGuassianPolicy(
                state_dim,
                action_spec,
                hidden_dims=hidden_dims,
                encoder=actor_kwargs['encoder'])
            action_dim = action_spec.shape[0]

        self.action_dim = action_dim

        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

        self.target_entropy = target_entropy
        self.discount = discount
        self.tau = tau

        self.bc = behavioral_cloning.BehavioralCloning(
            observation_spec,
            action_spec,
            mixture=True,
            encoder=actor_kwargs['encoder_bc'],
            num_augmentations=self.num_augmentations,
            env_name=env_name,
            batch_size=batch_size)

        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims,
                                    encoder=critic_kwargs['encoder'])
        self.critic_target = critic.Critic(
            state_dim,
            action_dim,
            hidden_dims=hidden_dims,
            encoder=critic_kwargs['encoder_target'])

        critic.soft_update(self.critic, self.critic_target, tau=1.0)  # tau=1.0 performs a hard copy into the target.
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)

        self.f_reg = f_reg
        self.reward_bonus = reward_bonus

        self.model_dict = {
            'critic': self.critic,
            'critic_target': self.critic_target,
            'actor': self.actor,
            'bc': self.bc,
            'critic_optimizer': self.critic_optimizer,
            'alpha_optimizer': self.alpha_optimizer,
            'actor_optimizer': self.actor_optimizer
        }
  def __init__(
      self,
      observation_spec,
      action_spec,
      embedding_dim=256,
      num_distributions=None,
      hidden_dims=(256, 256),
      sequence_length=2,
      learning_rate=None,
      latent_dim=256,
      reward_weight=1.0,
      forward_weight=1.0,  # Predict last state given prev actions/states.
      inverse_weight=1.0,  # Predict last action given states.
      state_prediction_mode='energy',
      num_augmentations=0,
      rep_learn_keywords='outer',
      batch_size=256):
    """Creates networks.

    Args:
      observation_spec: State spec.
      action_spec: Action spec.
      embedding_dim: Embedding size.
      num_distributions: Number of categorical distributions for discrete
        embedding.
      hidden_dims: List of hidden dimensions.
      sequence_length: Expected length of sequences provided as input
      learning_rate: Learning rate.
      latent_dim: Dimension of the latent variable.
      reward_weight: Weight on the reward loss.
      forward_weight: Weight on the forward loss.
      inverse_weight: Weight on the inverse loss.
      state_prediction_mode: One of ['latent', 'energy'].
      num_augmentations: Num of random crops
      rep_learn_keywords: Representation learning loss to add.
      batch_size: Batch size
    """
    super().__init__()
    action_dim = action_spec.maximum.item() + 1
    self.observation_spec = observation_spec
    self.action_dim = action_dim
    self.action_spec = action_spec
    self.embedding_dim = embedding_dim
    self.num_distributions = num_distributions
    self.sequence_length = sequence_length
    self.latent_dim = latent_dim
    self.reward_weight = reward_weight
    self.forward_weight = forward_weight
    self.inverse_weight = inverse_weight
    self.state_prediction_mode = state_prediction_mode
    self.num_augmentations = num_augmentations
    self.rep_learn_keywords = rep_learn_keywords.split('__')
    self.batch_size = batch_size
    self.tau = 0.005
    self.discount = 0.99

    critic_kwargs = {}

    if observation_spec.shape == (64, 64, 3):
      # IMPALA for Procgen
      def conv_stack():
        return make_impala_cnn_network(
            depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

      state_dim = 256
    else:
      # Reduced architecture for DMC
      def conv_stack():
        return ConvStack(observation_spec.shape)
      state_dim = 50

    conv_stack_critic = conv_stack()
    conv_target_stack_critic = conv_stack()

    if observation_spec.shape == (64, 64, 3):
      conv_stack_critic.output_size = state_dim
      conv_target_stack_critic.output_size = state_dim

    critic_kwargs['encoder'] = ImageEncoder(
        conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
    critic_kwargs['encoder_target'] = ImageEncoder(
        conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

    self.embedder = tf_utils.EmbedNet(
        state_dim,
        embedding_dim=self.embedding_dim,
        num_distributions=self.num_distributions,
        hidden_dims=hidden_dims)

    if self.sequence_length > 2:
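      # For sequences longer than two steps, an RNN condenses the sequence_length - 2
      # intermediate steps (each a 2 * embedding_dim vector) into a latent of size latent_dim.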
      self.latent_embedder = tf_utils.RNNEmbedNet(
          [self.sequence_length - 2, self.embedding_dim + self.embedding_dim],
          embedding_dim=self.latent_dim)

    # Predicts a scalar reward from the latent plus two embedding-sized inputs.
    self.reward_decoder = tf_utils.EmbedNet(
        self.latent_dim + self.embedding_dim + self.embedding_dim,
        embedding_dim=1,
        hidden_dims=hidden_dims)

    forward_decoder_out = (
        self.embedding_dim if (self.state_prediction_mode
                               in ['latent', 'energy']) else self.input_dim)
    forward_decoder_dists = (
        self.num_distributions if
        (self.state_prediction_mode in ['latent', 'energy']) else None)
    self.forward_decoder = tf_utils.StochasticEmbedNet(
        self.latent_dim + self.embedding_dim + self.embedding_dim,
        embedding_dim=forward_decoder_out,
        num_distributions=forward_decoder_dists,
        hidden_dims=hidden_dims)

    self.weight = tf.Variable(tf.eye(self.embedding_dim))

    self.action_encoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(
                self.embedding_dim, use_bias=True
            ),  # , kernel_regularizer=tf.keras.regularizers.l2(WEIGHT_DECAY)
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(self.embedding_dim)
        ],
        name='action_encoder')

    if self.num_augmentations == 0:
      dummy_state = tf.constant(
          np.zeros(shape=[1] + list(observation_spec.shape)))
      self.obs_spec = list(observation_spec.shape)
    else:  # account for the +4 padding and the subsequent 68x68 crop
      dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
      self.obs_spec = [68, 68, 3]

    @tf.function
    def init_models():
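      # Build encoder and action-encoder variables with a single dummy forward pass.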
      critic_kwargs['encoder'](dummy_state)
      critic_kwargs['encoder_target'](dummy_state)
      self.action_encoder(
          tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

    init_models()

    hidden_dims = (256, 256)
    # self.actor = policies.CategoricalPolicy(state_dim, action_spec,
    #  hidden_dims=hidden_dims, encoder=actor_kwargs['encoder'])

    self.critic = critic.Critic(
        state_dim,
        action_dim,
        hidden_dims=hidden_dims,
        encoder=critic_kwargs['encoder'],
        discrete_actions=True,
        linear='linear_Q' in self.rep_learn_keywords)
    self.critic_target = critic.Critic(
        state_dim,
        action_dim,
        hidden_dims=hidden_dims,
        encoder=critic_kwargs['encoder_target'],
        discrete_actions=True,
        linear='linear_Q' in self.rep_learn_keywords)

    @tf.function
    def init_models2():
      dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
      phi_s = self.critic.encoder(dummy_state)
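      # Identity matrix acts as a one-hot batch over Procgen's 15 discrete actions.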
      phi_a = tf.eye(15, dtype=tf.float32)
      if 'linear_Q' in self.rep_learn_keywords:
        _ = self.critic.critic1.state_encoder(phi_s)
        _ = self.critic.critic2.state_encoder(phi_s)
        _ = self.critic.critic1.action_encoder(phi_a)
        _ = self.critic.critic2.action_encoder(phi_a)
        _ = self.critic_target.critic1.state_encoder(phi_s)
        _ = self.critic_target.critic2.state_encoder(phi_s)
        _ = self.critic_target.critic1.action_encoder(phi_a)
        _ = self.critic_target.critic2.action_encoder(phi_a)

    init_models2()

    critic.soft_update(self.critic, self.critic_target, tau=1.0)

    learning_rate = learning_rate or 1e-4
    self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    self.critic_optimizer = tf.keras.optimizers.Adam(
        learning_rate=learning_rate)

    self.all_variables = (
        self.embedder.trainable_variables +
        self.reward_decoder.trainable_variables +
        self.forward_decoder.trainable_variables +
        self.action_encoder.trainable_variables +
        self.critic.trainable_variables +
        self.critic_target.trainable_variables)

    self.model_dict = {
        'action_encoder': self.action_encoder,
        'weight': self.weight,
        'forward_decoder': self.forward_decoder,
        'reward_decoder': self.reward_decoder,
        'embedder': self.embedder,
        'critic': self.critic,
        'critic_target': self.critic_target,
        'critic_optimizer': self.critic_optimizer,
        'optimizer': self.optimizer
    }
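The examples above repeatedly call critic.soft_update(...), which is defined elsewhere in the codebase. As a rough guide only, a minimal Polyak-averaging helper with the same call signature could look like the sketch below; the access via .variables and the assumption that both networks list their variables in the same order are mine, not the repository's actual implementation.

    import tensorflow as tf

    def soft_update(net, target_net, tau=0.005):
        """Blends tau * net + (1 - tau) * target_net into target_net, variable by variable."""
        for var, target_var in zip(net.variables, target_net.variables):
            target_var.assign(tau * var + (1.0 - tau) * target_var)

With tau=1.0 this reduces to a hard copy, which is how the constructors above initialize their target critics.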