예제 #1
0
 def update(self, algorithm, clock):
     '''Compute the current value of the variable for the clock's step.

     In eval lab mode, or when the configured updater is 'no_decay', the
     final value ``self.end_val`` is returned as-is. Otherwise the updater
     function is evaluated with the start/end values, the start/end steps,
     and the clock's current step.

     ``algorithm`` is part of the call interface but is not used here.
     '''
     skip_schedule = util.in_eval_lab_mode() or self._updater_name == 'no_decay'
     if skip_schedule:
         return self.end_val
     current_step = clock.get()
     return self._updater(
         self.start_val,
         self.end_val,
         self.start_step,
         self.end_step,
         current_step,
     )
예제 #2
0
def try_scale_reward(cls, reward):
    '''Scale or sign-transform a reward according to the env class settings.

    Only applied during training: in eval lab mode the reward is returned
    unchanged. When ``cls.reward_scale`` is None no transform is applied.
    Otherwise, if ``cls.sign_reward`` is truthy the reward is replaced by
    ``np.sign(reward)``; else it is multiplied by ``cls.reward_scale``.

    Args:
        cls: object exposing ``reward_scale`` and ``sign_reward`` attributes.
        reward: the reward to transform (presumably a scalar or numpy
            array — confirm against callers).

    Returns:
        The transformed reward. The object passed in is never mutated.
    '''
    if util.in_eval_lab_mode():  # only trigger on training
        return reward
    if cls.reward_scale is None:
        return reward
    if cls.sign_reward:
        # NOTE(review): sign_reward is only honoured when reward_scale is not
        # None; presumably reward_scale holds a sentinel in that case — confirm.
        return np.sign(reward)
    # Out-of-place multiply: `reward *= scale` would mutate a caller's numpy
    # array in place, and raises a casting error for an integer-dtype array
    # multiplied by a float scale.
    return reward * cls.reward_scale
예제 #3
0
 def update(self, state, action, reward, next_state, done):
     '''Run the per-timestep update after an env transition.

     The transition is always forwarded to ``self.body.update``. Outside eval
     lab mode it is also stored in ``self.body.memory``, the algorithm is
     trained, a non-NaN loss is recorded on ``self.body.loss`` (used by
     log_summary()), and ``self.algorithm.update()`` is called.

     Returns:
         ``(loss, explore_var)`` when training; ``None`` in eval lab mode.
     '''
     transition = (state, action, reward, next_state, done)
     self.body.update(*transition)
     if util.in_eval_lab_mode():
         # evaluation must not train or otherwise update the agent
         return
     self.body.memory.update(*transition)
     train_loss = self.algorithm.train()
     if not np.isnan(train_loss):
         self.body.loss = train_loss  # consumed by log_summary()
     updated_var = self.algorithm.update()
     return train_loss, updated_var
예제 #4
0
파일: control.py 프로젝트: c-w-m/slm-lab
 def to_ckpt(self, env, mode='eval'):
     '''Decide whether a log/eval checkpoint should run at the current frame.

     A checkpoint is due when ``util.frame_mod`` reports the current frame is
     on the mode's frequency, or when the frame equals ``clock.max_frame``.
     In eval lab mode, eval checkpoints are always skipped to avoid
     evaluating twice.
     '''
     is_eval = mode == 'eval'
     if is_eval and util.in_eval_lab_mode():
         return False
     clock = env.clock
     frame = clock.get()
     if is_eval:
         frequency = env.eval_frequency
     else:
         frequency = env.log_frequency
     on_frequency = util.frame_mod(frame, frequency, env.num_envs)
     return on_frequency or frame == clock.max_frame
예제 #5
0
 def save(self, ckpt=None):
     '''Persist the agent by delegating to ``self.algorithm.save``.

     Does nothing in eval lab mode, so evaluation runs never write new models.

     Args:
         ckpt: checkpoint tag forwarded unchanged to ``self.algorithm.save``.
     '''
     if not util.in_eval_lab_mode():
         self.algorithm.save(ckpt=ckpt)