Example #1
0
 def sync_target_network(self):
     """Bring the target network in line with the online network.

     Delegates to ``synchronize_parameters`` with ``self.model`` as the
     source and ``self.target_model`` as the destination. The update rule
     is ``self.target_update_method``. ``self.soft_update_tau`` is always
     forwarded as ``tau``; presumably it only matters for the soft rule,
     so confirm this in ``synchronize_parameters``.
     """
     source, target = self.model, self.target_model
     update_kwargs = dict(method=self.target_update_method,
                          tau=self.soft_update_tau)
     synchronize_parameters(src=source, dst=target, **update_kwargs)
Example #2
0
 def sync_target_network(self):
     """Create the target network on first use, otherwise synchronize it.

     If ``self.target_model`` is still ``None``, it becomes an independent
     deep copy of ``self.model``. Otherwise, both networks are handed to
     ``synchronize_parameters`` using ``self.target_update_method`` and
     ``self.soft_update_tau``.
     """
     if self.target_model is not None:
         synchronize_parameters(src=self.model,
                                dst=self.target_model,
                                method=self.target_update_method,
                                tau=self.soft_update_tau)
         return
     # First call: build the target as a fully independent copy of the model.
     self.target_model = copy.deepcopy(self.model)
 def sync_target_network(self):
     """Soft-update both target Q-functions toward their online counterparts.

     Each ``(q_funcN, target_q_funcN)`` pair is passed to
     ``synchronize_parameters`` with ``method='soft'`` and
     ``tau=self.soft_update_tau``. Pair 1 is processed before pair 2.
     """
     pairs = (('q_func1', 'target_q_func1'),
              ('q_func2', 'target_q_func2'))
     for online_name, target_name in pairs:
         # Attributes are read per iteration, matching the original's order
         # of reads and calls.
         synchronize_parameters(
             src=getattr(self, online_name),
             dst=getattr(self, target_name),
             method='soft',
             tau=self.soft_update_tau,
         )
Example #4
0
    def sync_target_network(self):
        """Synchronize target network with current network.

        On the first call (``self.target_model is None``), the target network
        is created as a deep copy of ``self.model``. Its instance attribute
        ``__call__`` is replaced by a wrapper that invokes the original call
        with Chainer's ``train`` config flag set to ``False``. On later calls,
        parameters are synchronized via ``synchronize_parameters`` using
        ``self.target_update_method`` and ``self.soft_update_tau``.
        """
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model)
            # This is a *bound* method of the copied model. It already carries
            # its own ``self``, so the wrapper must not pass one again.
            call_orig = self.target_model.__call__

            def call_test(*args, **kwargs):
                # Stored as an instance attribute, so it is never bound.
                # Forward the caller's arguments unchanged to the bound method.
                with chainer.using_config('train', False):
                    return call_orig(*args, **kwargs)

            # NOTE(review): Python looks up special methods such as
            # ``__call__`` on the type, not the instance. An implicit call
            # ``self.target_model(x)`` therefore does NOT go through this
            # wrapper; only an explicit ``self.target_model.__call__(x)`` does.
            # Confirm which form callers use.
            self.target_model.__call__ = call_test
        else:
            synchronize_parameters(src=self.model,
                                   dst=self.target_model,
                                   method=self.target_update_method,
                                   tau=self.soft_update_tau)