コード例 #1
0
ファイル: sac.py プロジェクト: jimfleming/SLM-Lab
 def update_nets(self):
     '''Update target networks'''
     if util.frame_mod(self.body.env.clock.frame, self.q1_net.update_frequency, self.body.env.num_envs):
         if self.q1_net.update_type == 'replace':
             net_util.copy(self.q1_net, self.target_q1_net)
             net_util.copy(self.q2_net, self.target_q2_net)
         elif self.q1_net.update_type == 'polyak':
             net_util.polyak_update(self.q1_net, self.target_q1_net, self.q1_net.polyak_coef)
             net_util.polyak_update(self.q2_net, self.target_q2_net, self.q2_net.polyak_coef)
         else:
             raise ValueError('Unknown q1_net.update_type. Should be "replace" or "polyak". Exiting.')
コード例 #2
0
ファイル: dqn.py プロジェクト: vmuthuk2/SLM-Lab
 def update_nets(self):
     total_t = self.body.env.clock.total_t
     if total_t % self.net.update_frequency == 0:
         if self.net.update_type == 'replace':
             logger.debug('Updating target_net by replacing')
             net_util.copy(self.net, self.target_net)
         elif self.net.update_type == 'polyak':
             logger.debug('Updating net by averaging')
             net_util.polyak_update(self.net, self.target_net,
                                    self.net.polyak_coef)
         else:
             raise ValueError(
                 'Unknown net.update_type. Should be "replace" or "polyak". Exiting.'
             )
コード例 #3
0
 def update_nets(self):
     total_t = util.s_get(self, 'aeb_space.clock').get('total_t')
     if self.net.update_type == 'replace':
         if total_t % self.net.update_frequency == 0:
             logger.debug('Updating target_net by replacing')
             self.target_net.load_state_dict(self.net.state_dict())
             self.online_net = self.target_net
             self.eval_net = self.target_net
     elif self.net.update_type == 'polyak':
         logger.debug('Updating net by averaging')
         net_util.polyak_update(self.net, self.target_net, self.net.polyak_coef)
         self.online_net = self.target_net
         self.eval_net = self.target_net
     else:
         raise ValueError('Unknown net.update_type. Should be "replace" or "polyak". Exiting.')