Ejemplo n.º 1
0
  def __init__(self,
               observation_spec,
               action_spec,
               actor_lr = 3e-4,
               critic_lr = 3e-4,
               discount = 0.99,
               tau = 0.005,
               num_augmentations = 1):
    """Creates the pixel CVAE actor, the critic learner and their encoders.

    Observations shaped (64, 64, 3) use an IMPALA conv stack with 256-dim
    features; any other shape uses `ConvStack` with 50-dim features. The
    actor, the critic and the target critic each get their own conv stack.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec. An empty (rank-0) shape selects the discrete
        policy with `action_spec.maximum + 1` actions.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      num_augmentations: Number of DrQ-style augmentations to perform on pixels
    """
    self.num_augmentations = num_augmentations
    # `len` (not truthiness) on purpose: a TensorShape's bool is not its rank.
    self.discrete_actions = len(action_spec.shape) == 0

    is_procgen = observation_spec.shape == (64, 64, 3)
    if is_procgen:
      # IMPALA for Procgen
      state_dim = 256

      def conv_stack():
        return make_impala_cnn_network(
            depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)
    else:
      # Reduced architecture for DMC
      state_dim = 50

      def conv_stack():
        return ConvStack(observation_spec.shape)

    # Built in this order: actor, critic, target critic.
    actor_stack, critic_stack, target_stack = (
        conv_stack(), conv_stack(), conv_stack())
    if is_procgen:
      for stack in (actor_stack, critic_stack, target_stack):
        stack.output_size = state_dim

    def make_encoder(stack):
      return ImageEncoder(stack, feature_dim=state_dim, bprop_conv_stack=True)

    actor_kwargs = {'encoder': make_encoder(actor_stack)}
    # Note: the target critic does not share any weights.
    critic_kwargs = {
        'encoder': make_encoder(critic_stack),
        'encoder_target': make_encoder(target_stack),
    }

    if self.num_augmentations == 0:
      dummy_shape = [1] + list(observation_spec.shape)
    else:
      # Account for padding of +4 everywhere and then cropping out 68.
      dummy_shape = [1, 68, 68, 3]
    dummy_state = tf.constant(np.zeros(shape=dummy_shape))

    @tf.function
    def init_models():
      # One dummy forward pass per encoder creates its weights.
      for encoder in (actor_kwargs['encoder'], critic_kwargs['encoder'],
                      critic_kwargs['encoder_target']):
        encoder(dummy_state)

    init_models()

    if self.discrete_actions:
      action_dim = action_spec.maximum.item() + 1
      policy_cls = policies.CVAEPolicyPixelsDiscrete
    else:
      action_dim = action_spec.shape[0]
      policy_cls = policies.CVAEPolicyPixels
    self.actor = policy_cls(
        state_dim, action_spec, action_dim * 2,
        encoder=actor_kwargs['encoder'])

    self.action_dim = action_dim
    self.state_dim = state_dim

    self.action_encoder = None
    if self.discrete_actions:
      self.action_encoder = tf.keras.Sequential(
          [
              tf.keras.layers.Dense(state_dim, use_bias=True),
              tf.keras.layers.ReLU(),
              tf.keras.layers.Dense(action_dim),
          ],
          name='action_encoder')
      # Build the MLP's weights with a dummy (1, state_dim) input.
      self.action_encoder(tf.constant(np.zeros(shape=[1, state_dim])))

    self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

    self.critic_learner = critic.CriticLearner(
        state_dim,
        action_dim,
        critic_lr,
        discount,
        tau,
        encoder=critic_kwargs['encoder'],
        encoder_target=critic_kwargs['encoder_target'])

    self.bc = None
    self.threshold = 0.3

    self.model_dict = dict(
        critic_learner=self.critic_learner,
        action_encoder=self.action_encoder,
        actor=self.actor,
        actor_optimizer=self.actor_optimizer)
        else:
            # Reduced architecture for DMC
            def conv_stack():
                return ConvStack(observation_spec.shape)

            state_dim = 50

        conv_stack_critic = conv_stack()
        conv_target_stack_critic = conv_stack()

        if observation_spec.shape == (64, 64, 3):
            conv_stack_critic.output_size = state_dim
            conv_target_stack_critic.output_size = state_dim
        # Combine and stop_grad some of the above conv stacks
        critic_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                                feature_dim=state_dim,
                                                bprop_conv_stack=True)
        # Note: the target critic does not share any weights.
        critic_kwargs['encoder_target'] = ImageEncoder(
            conv_target_stack_critic,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        if self.num_augmentations == 0:
            dummy_state = tf.constant(
                np.zeros(shape=[1] + list(observation_spec.shape)))
        else:  # account for padding of +4 everywhere and then cropping out 68
            dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

        @tf.function
        def init_models():
Ejemplo n.º 3
0
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0,
                 f_reg=1.0,
                 reward_bonus=5.0,
                 num_augmentations=1,
                 rep_learn_keywords='outer',
                 env_name='',
                 batch_size=256,
                 n_quantiles=5,
                 temp=0.1,
                 num_training_levels=200,
                 latent_dim=256,
                 n_levels_nce=5,
                 popart_norm_beta=0.1):
        """Creates networks.

        Args:
          observation_spec: environment observation spec.
          action_spec: Action spec. `action_spec.maximum + 1` is used as the
            number of discrete actions.
          actor_lr: Ignored (the optimizers below use a fixed 3e-4).
          critic_lr: Ignored (the optimizers below use a fixed 3e-4).
          alpha_lr: Ignored.
          discount: MDP discount.
          tau: Soft target update parameter.
          target_entropy: Ignored.
          f_reg: Critic regularization weight.
          reward_bonus: Bonus added to the rewards.
          num_augmentations: Number of DrQ augmentations (crops)
          rep_learn_keywords: Representation learning loss to add, as a
            '__'-separated string (see below)
          env_name: Env name
          batch_size: Batch size
          n_quantiles: Number of GVF quantiles
          temp: Temperature of NCE softmax
          num_training_levels: Number of training MDPs (Procgen=200); sets the
            output width of both per-level task critics.
          latent_dim: Latent dimensions of auxiliary MLPs
          n_levels_nce: Number of MDPs to use contrastive loss on
          popart_norm_beta: PopArt normalization constant

        For `rep_learn_keywords`, pick from:
          stop_grad_FQI: whether to stop_grad TD/FQI critic updates?
          linear_Q: use a linear critic?

          successor_features: uses ||SF|| as cumulant
          gvf_termination: uses +1 if done else 0 as cumulant
          gvf_action_count: uses state-cond. action counts as cumulant

          nce: uses the multi-class dot-product InfoNCE objective
          cce: uses MoCo Categorical CrossEntropy objective
          energy: uses SimCLR + pairwise GVF distance (not fully tested)

          CLIP / clip: additionally creates a trainable `loss_temp` variable.

        If no cumulant is specified, the reward will be taken as default one.
        """
        del actor_lr, critic_lr, alpha_lr, target_entropy
        self.action_spec = action_spec
        self.num_augmentations = num_augmentations
        self.rep_learn_keywords = rep_learn_keywords.split('__')
        self.batch_size = batch_size
        self.env_name = env_name
        self.stop_grad_fqi = 'stop_grad_FQI' in self.rep_learn_keywords
        critic_kwargs = {'hidden_dims': (1024, 1024)}
        self.latent_dim = latent_dim
        self.n_levels_nce = n_levels_nce
        hidden_dims = hidden_dims_per_level = (self.latent_dim,
                                               self.latent_dim)
        self.num_training_levels = int(num_training_levels)
        self.n_quantiles = n_quantiles
        self.temp = temp

        # Make 2 sets of weights:
        # - Critic
        # - Critic (target)
        # and a 3rd pair of per-level encoders (online + target).

        if observation_spec.shape == (64, 64, 3):
            # IMPALA for Procgen
            def conv_stack():
                return make_impala_cnn_network(depths=[16, 32, 32],
                                               use_batch_norm=False,
                                               dropout_rate=0.)

            state_dim = 256
        else:
            # Reduced architecture for DMC
            def conv_stack():
                return ConvStack(observation_spec.shape)

            state_dim = 50

        conv_stack_critic = conv_stack()
        conv_target_stack_critic = conv_stack()

        if observation_spec.shape == (64, 64, 3):
            conv_stack_critic.output_size = state_dim
            conv_target_stack_critic.output_size = state_dim
        critic_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                                feature_dim=state_dim,
                                                bprop_conv_stack=True)
        # Note: the target critic does not share any weights.
        critic_kwargs['encoder_target'] = ImageEncoder(
            conv_target_stack_critic,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        conv_stack_critic_per_level = conv_stack()
        conv_target_stack_critic_per_level = conv_stack()
        if observation_spec.shape == (64, 64, 3):
            conv_stack_critic_per_level.output_size = state_dim
            conv_target_stack_critic_per_level.output_size = state_dim

        self.encoder_per_level = ImageEncoder(conv_stack_critic_per_level,
                                              feature_dim=state_dim,
                                              bprop_conv_stack=True)
        self.encoder_per_level_target = ImageEncoder(
            conv_target_stack_critic_per_level,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        criticCL.soft_update(self.encoder_per_level,
                             self.encoder_per_level_target,
                             tau=1.0)

        if self.num_augmentations == 0:
            dummy_state = tf.constant(
                np.zeros([1] + list(observation_spec.shape)))
        else:  # account for padding of +4 everywhere and then cropping out 68
            dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
        dummy_enc = critic_kwargs['encoder'](dummy_state)

        @tf.function
        def init_models():
            """Builds the critic and per-level encoders on `dummy_state`."""
            critic_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder_target'](dummy_state)
            self.encoder_per_level(dummy_state)
            self.encoder_per_level_target(dummy_state)

        init_models()

        action_dim = action_spec.maximum.item() + 1

        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.reg = f_reg
        self.reward_bonus = reward_bonus

        self.critic = criticCL.Critic(state_dim,
                                      action_dim,
                                      hidden_dims=hidden_dims,
                                      encoder=critic_kwargs['encoder'],
                                      discrete_actions=True,
                                      linear='linear_Q'
                                      in self.rep_learn_keywords)
        self.critic_target = criticCL.Critic(
            state_dim,
            action_dim,
            hidden_dims=hidden_dims,
            encoder=critic_kwargs['encoder_target'],
            discrete_actions=True,
            linear='linear_Q' in self.rep_learn_keywords)

        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
        self.task_critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=3e-4)
        self.br_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)

        if 'cce' in self.rep_learn_keywords:
            self.classifier = tf.keras.Sequential([
                tf.keras.layers.Dense(self.latent_dim, use_bias=True),
                tf.keras.layers.ReLU(),
                tf.keras.layers.Dense(self.n_quantiles, use_bias=True)
            ], name='classifier')
        elif 'nce' in self.rep_learn_keywords:
            self.embedding = tf.keras.Sequential([
                tf.keras.layers.Dense(self.latent_dim, use_bias=True),
                tf.keras.layers.ReLU(),
                tf.keras.layers.Dense(self.latent_dim, use_bias=True)
            ], name='embedding')

        # Initialize the auxiliary networks (state and action encoders). The
        # dummy observation reuses the shape the encoders were built with above
        # (68x68x3 when augmenting), and `phi_a` holds one one-hot row per
        # action.
        aux_dummy_state = tf.zeros(dummy_state.shape, dtype=tf.float32)
        phi_s = self.critic.encoder(aux_dummy_state)
        phi_a = tf.eye(action_dim, dtype=tf.float32)
        if 'linear_Q' in self.rep_learn_keywords:
            for net in (self.critic, self.critic_target):
                _ = net.critic1.state_encoder(phi_s)
                _ = net.critic2.state_encoder(phi_s)
                _ = net.critic1.action_encoder(phi_a)
                _ = net.critic2.action_encoder(phi_a)
        if 'cce' in self.rep_learn_keywords:
            self.classifier(phi_s)
        elif 'nce' in self.rep_learn_keywords:
            self.embedding(phi_s)

        self.target_critic_to_use = self.critic_target
        self.critic_to_use = self.critic

        criticCL.soft_update(self.critic, self.critic_target, tau=1.0)

        self.cce = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

        self.bc = None

        if 'successor_features' in self.rep_learn_keywords:
            self.output_dim_level = self.latent_dim
        elif 'gvf_termination' in self.rep_learn_keywords:
            self.output_dim_level = 1
        elif 'gvf_action_count' in self.rep_learn_keywords:
            self.output_dim_level = action_dim
        else:
            self.output_dim_level = action_dim

        # Both per-level critics must have the same output width so that the
        # hard (tau=1.0) soft_update below copies matching variable shapes.
        task_output_dim = self.output_dim_level * self.num_training_levels
        self.task_critic_one = criticCL.Critic(
            state_dim,
            task_output_dim,
            hidden_dims=hidden_dims_per_level,
            encoder=None,
            discrete_actions=True,
            cross_norm=False)
        self.task_critic_target_one = criticCL.Critic(
            state_dim,
            task_output_dim,
            hidden_dims=hidden_dims_per_level,
            encoder=None,
            discrete_actions=True,
            cross_norm=False)
        self.task_critic_one(dummy_enc,
                             actions=None,
                             training=False,
                             return_features=False,
                             stop_grad_features=False)
        self.task_critic_target_one(dummy_enc,
                                    actions=None,
                                    training=False,
                                    return_features=False,
                                    stop_grad_features=False)
        criticCL.soft_update(self.task_critic_one,
                             self.task_critic_target_one,
                             tau=1.0)

        # Normalization constant beta, set to best default value as per PopArt paper
        self.reward_normalizer = popart.PopArt(
            running_statistics.EMAMeanStd(popart_norm_beta))
        self.reward_normalizer.init()

        if 'CLIP' in self.rep_learn_keywords or 'clip' in self.rep_learn_keywords:
            self.loss_temp = tf.Variable(tf.constant(0.0, dtype=tf.float32),
                                         name='loss_temp',
                                         trainable=True)

        self.model_dict = {
            'critic': self.critic,
            'critic_target': self.critic_target,
            'critic_optimizer': self.critic_optimizer,
            'br_optimizer': self.br_optimizer
        }

        self.model_dict['encoder_perLevel'] = self.encoder_per_level
        self.model_dict[
            'encoder_perLevel_target'] = self.encoder_per_level_target
        self.model_dict['task_critic'] = self.task_critic_one
        self.model_dict['task_critic_target'] = self.task_critic_target_one
Ejemplo n.º 4
0
    def __init__(self,
                 observation_spec,
                 action_spec,
                 embedding_dim=256,
                 hidden_dims=(256, 256),
                 sequence_length=2,
                 learning_rate=None,
                 discount=0.95,
                 target_update_period=1000,
                 num_augmentations=0,
                 rep_learn_keywords='outer',
                 batch_size=256):
        """Creates networks.

        Builds an online/target critic pair with image encoders, an embedding
        network, value/transition/output MLPs and an action encoder. The
        action space is treated as discrete with `action_spec.maximum + 1`
        actions.

        Args:
          observation_spec: State spec.
          action_spec: Action spec.
          embedding_dim: Embedding size.
          hidden_dims: List of hidden dimensions.
          sequence_length: Expected length of sequences provided as input.
          learning_rate: Learning rate. NOTE(review): currently unused; the
            optimizers below use fixed rates (3e-4 and 1e-3).
          discount: discount factor. NOTE(review): currently unused;
            `self.discount` is always set to 0.99.
          target_update_period: How frequently update target?
          num_augmentations: Number of DrQ random crops.
          rep_learn_keywords: Representation learning loss to add, as a
            '__'-separated string; 'linear_Q' is forwarded to the critics as
            `linear=True`.
          batch_size: batch size.
        """
        super().__init__()
        action_dim = action_spec.maximum.item() + 1
        self.observation_spec = observation_spec
        self.action_dim = action_dim
        self.action_spec = action_spec
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.tau = 0.005
        # NOTE(review): the `discount` argument is ignored and 0.99 is always
        # used; confirm which value is intended before relying on `discount`.
        self.discount = 0.99
        self.target_update_period = target_update_period
        self.num_augmentations = num_augmentations
        self.rep_learn_keywords = rep_learn_keywords.split('__')
        self.batch_size = batch_size

        critic_kwargs = {}

        if observation_spec.shape == (64, 64, 3):
            # IMPALA for Procgen
            def conv_stack():
                return make_impala_cnn_network(depths=[16, 32, 32],
                                               use_batch_norm=False,
                                               dropout_rate=0.)

            state_dim = 256
        else:
            # Reduced architecture for DMC
            def conv_stack():
                return ConvStack(observation_spec.shape)

            state_dim = 50

        conv_stack_critic = conv_stack()
        conv_target_stack_critic = conv_stack()

        if observation_spec.shape == (64, 64, 3):
            conv_stack_critic.output_size = state_dim
            conv_target_stack_critic.output_size = state_dim

        critic_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                                feature_dim=state_dim,
                                                bprop_conv_stack=True)
        critic_kwargs['encoder_target'] = ImageEncoder(
            conv_target_stack_critic,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        self.embedder = tf_utils.EmbedNet(state_dim,
                                          embedding_dim=self.embedding_dim,
                                          hidden_dims=hidden_dims)
        self.f_value = tf_utils.create_mlp(self.embedding_dim,
                                           1,
                                           hidden_dims=hidden_dims,
                                           activation=tf.nn.swish)
        self.f_value_target = tf_utils.create_mlp(self.embedding_dim,
                                                  1,
                                                  hidden_dims=hidden_dims,
                                                  activation=tf.nn.swish)
        self.f_trans = tf_utils.create_mlp(self.embedding_dim +
                                           self.embedding_dim,
                                           self.embedding_dim,
                                           hidden_dims=hidden_dims,
                                           activation=tf.nn.swish)
        self.f_out = tf_utils.create_mlp(self.embedding_dim +
                                         self.embedding_dim,
                                         2,
                                         hidden_dims=hidden_dims,
                                         activation=tf.nn.swish)

        self.action_encoder = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(self.embedding_dim, use_bias=True),
                tf.keras.layers.ReLU(),
                tf.keras.layers.Dense(self.embedding_dim)
            ],
            name='action_encoder')

        if self.num_augmentations == 0:
            dummy_state = tf.constant(
                np.zeros(shape=[1] + list(observation_spec.shape)))
            self.obs_spec = list(observation_spec.shape)
        else:  # account for padding of +4 everywhere and then cropping out 68
            dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
            self.obs_spec = [68, 68, 3]

        @tf.function
        def init_models():
            critic_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder_target'](dummy_state)
            self.action_encoder(
                tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

        init_models()

        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims,
                                    encoder=critic_kwargs['encoder'],
                                    discrete_actions=True,
                                    linear='linear_Q'
                                    in self.rep_learn_keywords)
        self.critic_target = critic.Critic(
            state_dim,
            action_dim,
            hidden_dims=hidden_dims,
            encoder=critic_kwargs['encoder_target'],
            discrete_actions=True,
            linear='linear_Q' in self.rep_learn_keywords)

        @tf.function
        def init_models2():
            # Use the observation shape the encoders were built with above and
            # one one-hot row per action (previously hard-coded to 15 actions,
            # which mis-built the action encoders for other action specs).
            aux_state = tf.zeros(dummy_state.shape, dtype=tf.float32)
            phi_s = self.critic.encoder(aux_state)
            phi_a = tf.eye(action_dim, dtype=tf.float32)
            if 'linear_Q' in self.rep_learn_keywords:
                for net in (self.critic, self.critic_target):
                    _ = net.critic1.state_encoder(phi_s)
                    _ = net.critic2.state_encoder(phi_s)
                    _ = net.critic1.action_encoder(phi_a)
                    _ = net.critic2.action_encoder(phi_a)

        init_models2()

        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        critic.soft_update(self.f_value, self.f_value_target, tau=1.0)

        # NOTE(review): `learning_rate` is not used; both optimizers have
        # fixed learning rates.
        del learning_rate
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

        self.all_variables = (self.embedder.trainable_variables +
                              self.f_value.trainable_variables +
                              self.f_value_target.trainable_variables +
                              self.f_trans.trainable_variables +
                              self.f_out.trainable_variables +
                              self.critic.trainable_variables +
                              self.critic_target.trainable_variables)

        self.model_dict = {
            'action_encoder': self.action_encoder,
            'f_out': self.f_out,
            'f_trans': self.f_trans,
            'f_value_target': self.f_value_target,
            'f_value': self.f_value,
            'embedder': self.embedder,
            'critic': self.critic,
            'critic_target': self.critic_target,
            'critic_optimizer': self.critic_optimizer,
            'optimizer': self.optimizer
        }
Ejemplo n.º 5
0
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_update_period=1,
                 target_entropy=0.0,
                 cross_norm=False,
                 pcl_actor_update=False):
        """Creates the Gaussian actor, the temperature and the critic learner.

        Rank-3 (image) observations get conv encoders with 50-dim features;
        rank-1 observations use their own size as the state dimension.

        Args:
          observation_spec: environment observation spec.
          action_spec: Action spec.
          actor_lr: Actor learning rate.
          critic_lr: Critic learning rate.
          alpha_lr: Temperature learning rate.
          discount: MDP discount.
          tau: Soft target update parameter.
          target_update_period: Target network update period.
          target_entropy: Target entropy.
          cross_norm: Whether to fit cross norm critic. Also sets Adam's
            beta_1 to 0.0 (0.9 otherwise) and skips creating a separate
            target image encoder.
          pcl_actor_update: Whether to use PCL actor update.
        """
        actor_kwargs = {}
        critic_kwargs = {}

        obs_rank = len(observation_spec.shape)
        if obs_rank == 3:  # Image observations.
            # DRQ encoder params.
            # https://github.com/denisyarats/drq/blob/master/config.yaml#L73
            state_dim = 50

            # Actor and critic encoders share conv weights only.
            shared_conv = ConvStack(observation_spec.shape)
            actor_kwargs.update(
                encoder=ImageEncoder(shared_conv, state_dim,
                                     bprop_conv_stack=False),
                hidden_dims=(1024, 1024))
            critic_kwargs.update(
                encoder=ImageEncoder(shared_conv, state_dim,
                                     bprop_conv_stack=True),
                hidden_dims=(1024, 1024))

            if not cross_norm:
                # Note: the target critic does not share any weights.
                critic_kwargs['encoder_target'] = ImageEncoder(
                    ConvStack(observation_spec.shape),
                    state_dim,
                    bprop_conv_stack=True)
        else:  # 1D state observations.
            assert obs_rank == 1
            state_dim = observation_spec.shape[0]

        beta_1 = 0.0 if cross_norm else 0.9

        self.actor = policies.DiagGuassianPolicy(state_dim, action_spec,
                                                 **actor_kwargs)
        self.actor_optimizer = tf.keras.optimizers.Adam(
            learning_rate=actor_lr, beta_1=beta_1)

        self.log_alpha = tf.Variable(tf.math.log(0.1), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(
            learning_rate=alpha_lr, beta_1=beta_1)

        # Positional arguments shared by both critic learner variants.
        learner_args = (state_dim, action_spec.shape[0], critic_lr, discount,
                        tau)
        if cross_norm:
            assert 'encoder_target' not in critic_kwargs
            self.critic_learner = critic.CrossNormCriticLearner(
                *learner_args, **critic_kwargs)
        else:
            self.critic_learner = critic.CriticLearner(
                *learner_args, target_update_period, **critic_kwargs)

        self.target_entropy = target_entropy
        self.discount = discount

        self.pcl_actor_update = pcl_actor_update
Ejemplo n.º 6
0
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0,
                 f_reg=1.0,
                 reward_bonus=5.0,
                 num_augmentations=1,
                 env_name='',
                 batch_size=256):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_lr: Temperature learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_entropy: Target entropy.
      f_reg: Critic regularization weight.
      reward_bonus: Bonus added to the rewards.
      num_augmentations: Number of DrQ augmentations (crops)
      env_name: Env name
      batch_size: Batch size
    """
        self.num_augmentations = num_augmentations
        self.discrete_actions = False if len(action_spec.shape) else True
        self.batch_size = batch_size

        actor_kwargs = {'hidden_dims': (1024, 1024)}
        critic_kwargs = {'hidden_dims': (1024, 1024)}

        # DRQ encoder params.
        # https://github.com/denisyarats/drq/blob/master/config.yaml#L73

        # Make 4 sets of weights:
        # - BC
        # - Actor
        # - Critic
        # - Critic (target)

        if observation_spec.shape == (64, 64, 3):
            # IMPALA for Procgen
            def conv_stack():
                return make_impala_cnn_network(depths=[16, 32, 32],
                                               use_batch_norm=False,
                                               dropout_rate=0.)

            state_dim = 256
        else:
            # Reduced architecture for DMC
            def conv_stack():
                return ConvStack(observation_spec.shape)

            state_dim = 50

        conv_stack_bc = conv_stack()
        conv_stack_actor = conv_stack()
        conv_stack_critic = conv_stack()
        conv_target_stack_critic = conv_stack()

        if observation_spec.shape == (64, 64, 3):
            conv_stack_bc.output_size = state_dim
            conv_stack_actor.output_size = state_dim
            conv_stack_critic.output_size = state_dim
            conv_target_stack_critic.output_size = state_dim
        # Combine and stop_grad some of the above conv stacks
        actor_kwargs['encoder_bc'] = ImageEncoder(conv_stack_bc,
                                                  feature_dim=state_dim,
                                                  bprop_conv_stack=True)
        actor_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                               feature_dim=state_dim,
                                               bprop_conv_stack=False)
        critic_kwargs['encoder'] = ImageEncoder(conv_stack_critic,
                                                feature_dim=state_dim,
                                                bprop_conv_stack=True)
        # Note: the target critic does not share any weights.
        critic_kwargs['encoder_target'] = ImageEncoder(
            conv_target_stack_critic,
            feature_dim=state_dim,
            bprop_conv_stack=True)

        if self.num_augmentations == 0:
            dummy_state = tf.constant(
                np.zeros(shape=[1] + list(observation_spec.shape)))
        else:  # account for padding of +4 everywhere and then cropping out 68
            dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

        @tf.function
        def init_models():
            actor_kwargs['encoder_bc'](dummy_state)
            actor_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder'](dummy_state)
            critic_kwargs['encoder_target'](dummy_state)

        init_models()

        if self.discrete_actions:
            hidden_dims = ()
            self.actor = policies.CategoricalPolicy(
                state_dim,
                action_spec,
                hidden_dims=hidden_dims,
                encoder=actor_kwargs['encoder'])
            action_dim = action_spec.maximum.item() + 1
        else:
            hidden_dims = (256, 256, 256)
            self.actor = policies.DiagGuassianPolicy(
                state_dim,
                action_spec,
                hidden_dims=hidden_dims,
                encoder=actor_kwargs['encoder'])
            action_dim = action_spec.shape[0]

        self.action_dim = action_dim

        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

        self.target_entropy = target_entropy
        self.discount = discount
        self.tau = tau

        self.bc = behavioral_cloning.BehavioralCloning(
            observation_spec,
            action_spec,
            mixture=True,
            encoder=actor_kwargs['encoder_bc'],
            num_augmentations=self.num_augmentations,
            env_name=env_name,
            batch_size=batch_size)

        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims,
                                    encoder=critic_kwargs['encoder'])
        self.critic_target = critic.Critic(
            state_dim,
            action_dim,
            hidden_dims=hidden_dims,
            encoder=critic_kwargs['encoder_target'])

        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)

        self.f_reg = f_reg
        self.reward_bonus = reward_bonus

        self.model_dict = {
            'critic': self.critic,
            'critic_target': self.critic_target,
            'actor': self.actor,
            'bc': self.bc,
            'critic_optimizer': self.critic_optimizer,
            'alpha_optimizer': self.alpha_optimizer,
            'actor_optimizer': self.actor_optimizer
        }
  def __init__(self,
               observation_spec,
               action_spec,
               mixture = False,
               encoder=None,
               num_augmentations = 1,
               env_name = '',
               rep_learn_keywords = 'outer',
               batch_size = 256):
    """Builds the behavioral-cloning policy together with its optimizers.

    Args:
      observation_spec: Observation spec. A (64, 64, 3) shape selects the
        IMPALA (Procgen) conv stack with 256-d features; any other shape uses
        ConvStack with 50-d features.
      action_spec: Action spec. An empty (rank-0) shape is treated as a
        discrete action space.
      mixture: Continuous actions only; use MixtureGuassianPolicy instead of
        DiagGuassianPolicy.
      encoder: Optional pre-built image encoder. When None, a new ImageEncoder
        is created and built once on an all-zeros observation.
      num_augmentations: Number of augmentations. 0 builds the new encoder on
        un-augmented observations, otherwise on 68x68x3 crops.
      env_name: Environment name. Names starting with 'procgen' store the
        PROCGEN_ACTION_MAT entry keyed by the second '-'-separated field.
      rep_learn_keywords: '__'-separated keywords; 'linear_Q' drops the hidden
        layers of the discrete policy.
      batch_size: Batch size stored on the instance.
    """
    is_procgen = observation_spec.shape == (64, 64, 3)
    state_dim = 256 if is_procgen else 50

    self.batch_size = batch_size
    self.num_augmentations = num_augmentations
    self.rep_learn_keywords = rep_learn_keywords.split('__')

    # An empty action shape (scalar spec) means discrete actions.
    self.discrete_actions = not len(action_spec.shape)

    self.action_spec = action_spec

    if encoder is None:
      # Procgen frames use the IMPALA network, everything else ConvStack.
      if is_procgen:
        bc_conv = make_impala_cnn_network(
            depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)
        bc_conv.output_size = state_dim
      else:
        bc_conv = ConvStack(observation_spec.shape)
      encoder = ImageEncoder(
          bc_conv, feature_dim=state_dim, bprop_conv_stack=True)

      # Augmented observations are padded by 4 on each side and cropped to
      # 68x68, so the encoder is built on that size in that case.
      probe_shape = ([1] + list(observation_spec.shape)
                     if self.num_augmentations == 0 else [1, 68, 68, 3])
      probe_state = tf.constant(np.zeros(shape=probe_shape))

      @tf.function
      def build_encoder():
        encoder(probe_state)

      build_encoder()

    if self.discrete_actions:
      if 'linear_Q' in self.rep_learn_keywords:
        policy_hidden = ()
      else:
        policy_hidden = (256, 256)
      self.policy = policies.CategoricalPolicy(
          state_dim, action_spec, hidden_dims=policy_hidden, encoder=encoder)
      action_dim = action_spec.maximum.item() + 1
    else:
      action_dim = action_spec.shape[0]
      policy_cls = (policies.MixtureGuassianPolicy if mixture
                    else policies.DiagGuassianPolicy)
      self.policy = policy_cls(state_dim, action_spec, encoder=encoder)

    self.optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

    self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
    self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

    # Target entropy is the negated action dimension.
    self.target_entropy = -action_dim

    # Only set for Procgen environments; absent otherwise.
    if env_name and env_name.startswith('procgen'):
      game_name = env_name.split('-')[1]
      self.procgen_action_mat = PROCGEN_ACTION_MAT[game_name]

    self.bc = None

    self.model_dict = {
        'policy': self.policy,
        'optimizer': self.optimizer,
    }
  def __init__(
      self,
      observation_spec,
      action_spec,
      embedding_dim = 256,
      num_distributions=None,
      hidden_dims = (256, 256),
      sequence_length = 2,
      learning_rate=None,
      latent_dim = 256,
      reward_weight = 1.0,
      forward_weight = 1.0,  # Predict last state given prev actions/states.
      inverse_weight = 1.0,  # Predict last action given states.
      state_prediction_mode = 'energy',
      num_augmentations = 0,
      rep_learn_keywords = 'outer',
      batch_size = 256):
    """Creates networks.

    Builds an online/target pair of discrete-action critics on separate image
    encoders, plus embedding, reward-decoder and forward-decoder heads.

    Args:
      observation_spec: State spec. A (64, 64, 3) shape selects the IMPALA
        (Procgen) conv stack with 256-d features; any other shape uses
        ConvStack with 50-d features.
      action_spec: Discrete action spec; the number of actions is taken as
        `action_spec.maximum + 1`.
      embedding_dim: Embedding size.
      num_distributions: Number of categorical distributions for discrete
        embedding.
      hidden_dims: Hidden dimensions of the embedder and decoder heads. The
        critics always use (256, 256).
      sequence_length: Expected length of sequences provided as input. A
        latent RNN embedder is only created when this is greater than 2.
      learning_rate: Learning rate for both optimizers; None means 1e-4.
      latent_dim: Dimension of the latent variable.
      reward_weight: Weight on the reward loss.
      forward_weight: Weight on the forward loss.
      inverse_weight: Weight on the inverse loss.
      state_prediction_mode: One of ['latent', 'energy'].
      num_augmentations: Num of random crops. 0 builds the encoders on
        un-augmented observations, otherwise on 68x68x3 crops.
      rep_learn_keywords: '__'-separated representation learning keywords;
        'linear_Q' selects linear critics.
      batch_size: Batch size
    """
    super().__init__()
    action_dim = action_spec.maximum.item() + 1
    self.observation_spec = observation_spec
    self.action_dim = action_dim
    self.action_spec = action_spec
    self.embedding_dim = embedding_dim
    self.num_distributions = num_distributions
    self.sequence_length = sequence_length
    self.latent_dim = latent_dim
    self.reward_weight = reward_weight
    self.forward_weight = forward_weight
    self.inverse_weight = inverse_weight
    self.state_prediction_mode = state_prediction_mode
    self.num_augmentations = num_augmentations
    self.rep_learn_keywords = rep_learn_keywords.split('__')
    self.batch_size = batch_size
    self.tau = 0.005
    self.discount = 0.99

    is_procgen = observation_spec.shape == (64, 64, 3)
    if is_procgen:
      # IMPALA for Procgen
      def conv_stack():
        return make_impala_cnn_network(
            depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

      state_dim = 256
    else:
      # Reduced architecture for DMC
      def conv_stack():
        return ConvStack(observation_spec.shape)

      state_dim = 50

    conv_stack_critic = conv_stack()
    conv_target_stack_critic = conv_stack()

    if is_procgen:
      conv_stack_critic.output_size = state_dim
      conv_target_stack_critic.output_size = state_dim

    # Online and target critics get separate encoders (no shared weights).
    critic_encoder = ImageEncoder(
        conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
    critic_encoder_target = ImageEncoder(
        conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

    self.embedder = tf_utils.EmbedNet(
        state_dim,
        embedding_dim=self.embedding_dim,
        num_distributions=self.num_distributions,
        hidden_dims=hidden_dims)

    if self.sequence_length > 2:
      self.latent_embedder = tf_utils.RNNEmbedNet(
          [self.sequence_length - 2, self.embedding_dim + self.embedding_dim],
          embedding_dim=self.latent_dim)

    self.reward_decoder = tf_utils.EmbedNet(
        self.latent_dim + self.embedding_dim + self.embedding_dim,
        embedding_dim=1,
        hidden_dims=hidden_dims)

    # NOTE(review): `self.input_dim` is never assigned in this constructor, so
    # a state_prediction_mode other than 'latent'/'energy' raises
    # AttributeError here unless something else provides it — confirm.
    predicts_embeddings = self.state_prediction_mode in ['latent', 'energy']
    forward_decoder_out = (
        self.embedding_dim if predicts_embeddings else self.input_dim)
    forward_decoder_dists = (
        self.num_distributions if predicts_embeddings else None)
    self.forward_decoder = tf_utils.StochasticEmbedNet(
        self.latent_dim + self.embedding_dim + self.embedding_dim,
        embedding_dim=forward_decoder_out,
        num_distributions=forward_decoder_dists,
        hidden_dims=hidden_dims)

    self.weight = tf.Variable(tf.eye(self.embedding_dim))

    self.action_encoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(self.embedding_dim, use_bias=True),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(self.embedding_dim)
        ],
        name='action_encoder')

    # Observation shape (without batch) the encoders are built on.
    if self.num_augmentations == 0:
      obs_shape = list(observation_spec.shape)
    else:  # account for padding of +4 everywhere and then cropping out 68
      obs_shape = [68, 68, 3]
    self.obs_spec = obs_shape
    dummy_state = tf.constant(np.zeros(shape=[1] + obs_shape))

    @tf.function
    def init_models():
      critic_encoder(dummy_state)
      critic_encoder_target(dummy_state)
      self.action_encoder(
          tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

    init_models()

    # The critic MLP widths are fixed, independent of `hidden_dims` above.
    critic_hidden_dims = (256, 256)
    linear_q = 'linear_Q' in self.rep_learn_keywords

    self.critic = critic.Critic(
        state_dim,
        action_dim,
        hidden_dims=critic_hidden_dims,
        encoder=critic_encoder,
        discrete_actions=True,
        linear=linear_q)
    self.critic_target = critic.Critic(
        state_dim,
        action_dim,
        hidden_dims=critic_hidden_dims,
        encoder=critic_encoder_target,
        discrete_actions=True,
        linear=linear_q)

    @tf.function
    def init_models2():
      # Call the linear-Q state/action encoders once (creating their
      # variables) before the target is synced by soft_update below. Probe
      # with the same observation shape the encoders were built on, and with
      # one one-hot row per action instead of a hard-coded 15.
      probe_state = tf.zeros([1] + obs_shape, dtype=tf.float32)
      phi_s = self.critic.encoder(probe_state)
      phi_a = tf.eye(action_dim, dtype=tf.float32)
      if linear_q:
        for net in (self.critic, self.critic_target):
          _ = net.critic1.state_encoder(phi_s)
          _ = net.critic2.state_encoder(phi_s)
          _ = net.critic1.action_encoder(phi_a)
          _ = net.critic2.action_encoder(phi_a)

    init_models2()

    critic.soft_update(self.critic, self.critic_target, tau=1.0)

    # Only fall back to the default when no learning rate was given, so an
    # explicit value (including 0.0) is respected.
    if learning_rate is None:
      learning_rate = 1e-4
    self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    self.critic_optimizer = tf.keras.optimizers.Adam(
        learning_rate=learning_rate)

    # NOTE(review): this includes the target critic's variables and omits
    # `self.weight`; confirm that is intended wherever it is used for updates.
    self.all_variables = (
        self.embedder.trainable_variables +
        self.reward_decoder.trainable_variables +
        self.forward_decoder.trainable_variables +
        self.action_encoder.trainable_variables +
        self.critic.trainable_variables +
        self.critic_target.trainable_variables)

    self.model_dict = {
        'action_encoder': self.action_encoder,
        'weight': self.weight,
        'forward_decoder': self.forward_decoder,
        'reward_decoder': self.reward_decoder,
        'embedder': self.embedder,
        'critic': self.critic,
        'critic_target': self.critic_target,
        'critic_optimizer': self.critic_optimizer,
        'optimizer': self.optimizer
    }