    # Assumes `import numpy as np` and the Keras backend imported as `K` at module level.
    def __init__(self, actor_model, tgt_actor_model, critic_model, tgt_critic_model, action_limits,
                 actor_lr=1e-4, critic_lr=1e-3, critic_decay=1e-2, process=None, rb_size=1e6,
                 minibatch_size=64, tau=1e-3, gamma=0.99, warmup_episodes=None, logging=True):
        super(DDPGAgent, self).__init__(warmup_episodes, logging)

        # The actor wrapper is also given the critic so the policy can be
        # trained through it; the target network starts as an exact copy.
        self.actor = Actor(actor_model, critic_model, lr=actor_lr)
        self.tgt_actor = Actor(tgt_actor_model, tgt_critic_model, lr=actor_lr)
        self.tgt_actor.set_weights(self.actor.get_weights())

        # Online and target critics; the target starts as an exact copy
        self.critic = Critic(critic_model, lr=critic_lr, decay=critic_decay)
        self.tgt_critic = Critic(tgt_critic_model, lr=critic_lr, decay=critic_decay)
        self.tgt_critic.set_weights(self.critic.get_weights())

        self.action_limits = action_limits
        self.process = process
        self.buffer = ReplayBuffer(int(rb_size))  # default 1e6 is a float; buffer sizes must be ints
        self.minibatch_size = minibatch_size
        self.tau = tau
        self.gamma = gamma

        # Infer state/action dimensionality from the critic's two input tensors
        self.state_space = K.int_shape(critic_model.inputs[0])[1]
        self.action_space = K.int_shape(critic_model.inputs[1])[1]
        # Default to an Ornstein-Uhlenbeck process for exploration noise
        if self.process is None:
            self.process = OrnsteinUhlenbeck(x0=np.zeros(self.action_space), theta=0.15, mu=0.0,
                                             sigma=0.2)
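
The constructor stores `tau` for DDPG's soft target updates. A minimal sketch of that update, assuming Keras-style `get_weights`/`set_weights` on the wrapped networks (the helper name `soft_update` is my own, not from this codebase):

def soft_update(target, source, tau):
    # theta_target <- tau * theta_source + (1 - tau) * theta_target
    mixed = [tau * s + (1.0 - tau) * t
             for s, t in zip(source.get_weights(), target.get_weights())]
    target.set_weights(mixed)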
Example 2
    def __init__(self,
                 task,
                 exploration_mu=0,
                 exploration_theta=0.15,
                 exploration_sigma=0.2,
                 tau=0.01):
        self.task = task
        self.state_size = task.state_size
        self.action_size = task.action_size
        self.action_low = task.action_low
        self.action_high = task.action_high

        # Actor (Policy) Model
        self.actor_local = Actor(self.state_size, self.action_size,
                                 self.action_low, self.action_high)
        self.actor_target = Actor(self.state_size, self.action_size,
                                  self.action_low, self.action_high)

        # Critic (Value) Model
        self.critic_local = Critic(self.state_size, self.action_size)
        self.critic_target = Critic(self.state_size, self.action_size)

        # Initialize target model parameters with local model parameters
        self.critic_target.model.set_weights(
            self.critic_local.model.get_weights())
        self.actor_target.model.set_weights(
            self.actor_local.model.get_weights())

        # Noise process
        self.exploration_mu = exploration_mu
        self.exploration_theta = exploration_theta
        self.exploration_sigma = exploration_sigma
        self.noise = OUNoise(self.action_size, self.exploration_mu,
                             self.exploration_theta, self.exploration_sigma)

        # Replay memory
        self.buffer_size = 100000
        self.batch_size = 128
        self.memory = ReplayBuffer(self.buffer_size, self.batch_size)

        # Algorithm parameters
        self.gamma = 0.85  # discount factor
        self.tau = tau  # for soft update of target parameters

        # Score tracking across episodes (requires numpy imported as np)
        self.total_reward = 0
        self.best_score = -np.inf
        self.score = 0
        self.count = 0
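
The `OUNoise` class is not shown above; a minimal sketch, assuming the standard discrete-time Ornstein-Uhlenbeck recurrence x <- x + theta * (mu - x) + sigma * N(0, 1):

import numpy as np

class OUNoise:
    """Temporally correlated exploration noise for continuous actions."""

    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.reset()

    def reset(self):
        # Restart the process at its long-run mean
        self.state = self.mu.copy()

    def sample(self):
        # Mean-reverting drift plus Gaussian diffusion
        dx = self.theta * (self.mu - self.state) \
             + self.sigma * np.random.standard_normal(len(self.state))
        self.state = self.state + dx
        return self.state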
Example 3
    # Assumes `import tensorflow as tf` at module level.
    def __init__(self, env, hparams):
        self.n_actions = len(env.action_space.high)
        self.actor_main = Actor(self.n_actions, hparams)
        self.actor_target = Actor(self.n_actions, hparams)
        self.critic_main = Critic(hparams)
        self.critic_target = Critic(hparams)
        self.batch_size = 64
        # Separate Adam optimizers for the actor and the critic
        self.a_opt = tf.keras.optimizers.Adam(hparams['lr'])
        self.c_opt = tf.keras.optimizers.Adam(hparams['lr'])
        self.memory = RBuffer(100_000, env.observation_space.shape, self.n_actions)
        self.trainstep = 0
        self.replace = 5  # hard-update the targets every `replace` training steps
        self.gamma = 0.99
        self.min_action = env.action_space.low[0]
        self.max_action = env.action_space.high[0]
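
Unlike the first two agents, this one tracks `trainstep` and `replace`, which points to periodic hard target updates rather than soft ones. A hypothetical sketch of such a method (its name and placement are assumptions, not part of the original code):

    def update_targets(self):
        # Hard-copy the online weights into the targets every `replace` steps
        if self.trainstep % self.replace == 0:
            self.actor_target.set_weights(self.actor_main.get_weights())
            self.critic_target.set_weights(self.critic_main.get_weights())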