コード例 #1
0
 def conv_stack():
   """Returns a new `ConvStack` built from `observation_spec.shape`.

   `observation_spec` is not a parameter; it is resolved from the surrounding
   scope at call time.
   """
   obs_shape = observation_spec.shape
   return ConvStack(obs_shape)
コード例 #2
0
ファイル: sac.py プロジェクト: tallamjr/google-research
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_update_period=1,
                 target_entropy=0.0,
                 cross_norm=False,
                 pcl_actor_update=False):
        """Creates the actor, the temperature variable and the critic learner.

        Args:
          observation_spec: Environment observation spec. `shape` must have
            rank 3 (image observations) or rank 1 (flat state observations).
          action_spec: Action spec. It is passed to the policy as-is, and
            `action_spec.shape[0]` is passed to the critic learner.
          actor_lr: Actor learning rate.
          critic_lr: Critic learning rate.
          alpha_lr: Temperature learning rate.
          discount: MDP discount.
          tau: Soft target update parameter.
          target_update_period: Target network update period. Only forwarded
            to the critic learner when `cross_norm` is False.
          target_entropy: Target entropy.
          cross_norm: Whether to fit cross norm critic.
          pcl_actor_update: Whether to use PCL actor update.

        Raises:
          ValueError: If the observation shape has a rank other than 1 or 3.
        """
        obs_shape = observation_spec.shape
        obs_rank = len(obs_shape)
        # Validate explicitly instead of with `assert`: asserts are removed
        # under `python -O`, which would let an unsupported rank silently be
        # treated as a flat state vector.
        if obs_rank not in (1, 3):
            raise ValueError(
                'observation_spec.shape must have rank 1 (state) or rank 3 '
                '(image), got shape {!r}.'.format(obs_shape))

        actor_kwargs = {}
        critic_kwargs = {}

        if obs_rank == 3:  # Image observations.
            # DRQ encoder params.
            # https://github.com/denisyarats/drq/blob/master/config.yaml#L73
            state_dim = 50
            hidden_dims = (1024, 1024)

            # Actor and critic encoders share conv weights only.
            # NOTE(review): `bprop_conv_stack=False` presumably keeps actor
            # gradients out of the shared conv stack -- confirm against
            # ImageEncoder.
            shared_conv_stack = ConvStack(obs_shape)

            actor_kwargs['encoder'] = ImageEncoder(
                shared_conv_stack, state_dim, bprop_conv_stack=False)
            actor_kwargs['hidden_dims'] = hidden_dims

            critic_kwargs['encoder'] = ImageEncoder(
                shared_conv_stack, state_dim, bprop_conv_stack=True)
            critic_kwargs['hidden_dims'] = hidden_dims

            if not cross_norm:
                # Note: the target critic does not share any weights, so it
                # gets its own freshly built conv stack.
                critic_kwargs['encoder_target'] = ImageEncoder(
                    ConvStack(obs_shape), state_dim, bprop_conv_stack=True)
        else:  # 1D state observations.
            state_dim = obs_shape[0]

        # Both the actor and the temperature optimizers use the same Adam
        # beta_1, which is set to zero when fitting the cross norm critic.
        beta_1 = 0.0 if cross_norm else 0.9

        self.actor = policies.DiagGuassianPolicy(state_dim, action_spec,
                                                 **actor_kwargs)
        self.actor_optimizer = tf.keras.optimizers.Adam(
            learning_rate=actor_lr, beta_1=beta_1)

        # The temperature is stored as a trainable log-value, initialised to
        # log(0.1).
        self.log_alpha = tf.Variable(tf.math.log(0.1), trainable=True)
        self.alpha_optimizer = tf.keras.optimizers.Adam(
            learning_rate=alpha_lr, beta_1=beta_1)

        action_dim = action_spec.shape[0]
        if cross_norm:
            # Internal invariant: a target encoder is only ever added above
            # when cross_norm is False.
            assert 'encoder_target' not in critic_kwargs
            self.critic_learner = critic.CrossNormCriticLearner(
                state_dim, action_dim, critic_lr, discount, tau,
                **critic_kwargs)
        else:
            self.critic_learner = critic.CriticLearner(
                state_dim, action_dim, critic_lr, discount, tau,
                target_update_period, **critic_kwargs)

        self.target_entropy = target_entropy
        self.discount = discount

        self.pcl_actor_update = pcl_actor_update