def update_nets(self):
    '''Synchronise both target Q-networks with their online counterparts.

    Nothing happens unless ``util.frame_mod`` reports that the current env
    frame falls on ``self.q1_net.update_frequency`` (given the env's
    ``num_envs``). On a due frame, the mode is taken from
    ``self.q1_net.update_type`` and applied to both (q1, target_q1) and
    (q2, target_q2):

    * ``'replace'`` -- ``net_util.copy(online, target)``
    * ``'polyak'``  -- ``net_util.polyak_update(online, target, online.polyak_coef)``

    Raises:
        ValueError: on a due frame, if the update type is neither
            ``'replace'`` nor ``'polyak'``.
    '''
    env = self.body.env
    if not util.frame_mod(env.clock.frame, self.q1_net.update_frequency, env.num_envs):
        return

    # q1_net's setting governs both Q-networks.
    mode = self.q1_net.update_type
    if mode not in ('replace', 'polyak'):
        raise ValueError('Unknown q1_net.update_type. Should be "replace" or "polyak". Exiting.')

    # Attributes are looked up lazily, pair by pair, so q1 is fully handled
    # before q2 is touched.
    for online_attr, target_attr in (('q1_net', 'target_q1_net'), ('q2_net', 'target_q2_net')):
        online = getattr(self, online_attr)
        target = getattr(self, target_attr)
        if mode == 'replace':
            net_util.copy(online, target)
        else:
            net_util.polyak_update(online, target, online.polyak_coef)
def update_nets(self):
    '''Update ``self.target_net`` from ``self.net`` on scheduled steps.

    An update happens only when ``self.body.env.clock.total_t`` is a multiple
    of ``self.net.update_frequency``; on all other steps this is a no-op.
    The kind of update is chosen by ``self.net.update_type``:

    * ``'replace'`` -- ``net_util.copy(self.net, self.target_net)``
    * ``'polyak'``  -- ``net_util.polyak_update(self.net, self.target_net,
      self.net.polyak_coef)``

    Raises:
        ValueError: if ``self.net.update_type`` is neither ``'replace'`` nor
            ``'polyak'``. This is only checked on steps where an update is due.
    '''
    total_t = self.body.env.clock.total_t
    if total_t % self.net.update_frequency != 0:
        return

    if self.net.update_type == 'replace':
        logger.debug('Updating target_net by replacing')
        net_util.copy(self.net, self.target_net)
    elif self.net.update_type == 'polyak':
        # target_net is the destination argument here, exactly as in the
        # copy() call above, so the log names target_net (previously "net").
        logger.debug('Updating target_net by averaging')
        net_util.polyak_update(self.net, self.target_net, self.net.polyak_coef)
    else:
        raise ValueError(
            'Unknown net.update_type. Should be "replace" or "polyak". Exiting.'
        )
def update_nets(self):
    '''Update ``self.target_net`` from ``self.net``.

    The behavior depends on ``self.net.update_type``:

    * ``'replace'`` -- when the clock's ``total_t`` is a multiple of
      ``self.net.update_frequency``, load ``self.net``'s state dict into
      ``self.target_net``.
    * ``'polyak'``  -- on every call, run ``net_util.polyak_update(self.net,
      self.target_net, self.net.polyak_coef)``.

    After either kind of update, ``self.online_net`` and ``self.eval_net`` are
    both re-bound to the ``self.target_net`` object.

    Raises:
        ValueError: if ``self.net.update_type`` is neither ``'replace'`` nor
            ``'polyak'``.
    '''
    total_t = util.s_get(self, 'aeb_space.clock').get('total_t')
    if self.net.update_type == 'replace':
        if total_t % self.net.update_frequency == 0:
            logger.debug('Updating target_net by replacing')
            self.target_net.load_state_dict(self.net.state_dict())
            # Both aliases now point at the target network object itself.
            self.online_net = self.target_net
            self.eval_net = self.target_net
    elif self.net.update_type == 'polyak':
        # NOTE(review): unlike 'replace', this branch ignores update_frequency
        # and runs on every call -- confirm per-step soft updates are intended.
        # target_net is the destination argument, so the log names target_net
        # (previously "net").
        logger.debug('Updating target_net by averaging')
        net_util.polyak_update(self.net, self.target_net, self.net.polyak_coef)
        self.online_net = self.target_net
        self.eval_net = self.target_net
    else:
        raise ValueError('Unknown net.update_type. Should be "replace" or "polyak". Exiting.')