def train_epoch(self): """Trains RL for one epoch.""" n_value_evals = rl_training.remaining_evals( self._value_trainer.step, self._epoch, self._value_train_steps_per_epoch, self._value_evals_per_epoch) for _ in range(n_value_evals): self._value_trainer.train_epoch( self._value_train_steps_per_epoch // self._value_evals_per_epoch, self._value_eval_steps, ) if self._n_shared_layers > 0: # Copy value weights to policy trainer. _copy_model_weights(0, self._n_shared_layers, self._value_trainer, self._policy_trainer) n_policy_evals = rl_training.remaining_evals( self._policy_trainer.step, self._epoch, self._policy_train_steps_per_epoch, self._policy_evals_per_epoch) # Check if there was a restart after value training finishes and policy not. stopped_after_value = (n_value_evals == 0 and n_policy_evals < self._policy_evals_per_epoch) should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value if should_copy_weights: _copy_model_weights(0, self._n_shared_layers, self._value_trainer, self._policy_trainer) for _ in range(n_policy_evals): self._policy_trainer.train_epoch( self._policy_train_steps_per_epoch // self._policy_evals_per_epoch, self._policy_eval_steps, ) if self._n_shared_layers > 0: # Copy policy weights to value trainer. _copy_model_weights(0, self._n_shared_layers, self._policy_trainer, self._value_trainer)
def train_epoch(self): """Trains RL for one epoch.""" # Copy policy state accumulated during data collection to the trainer. self._policy_trainer.model_state = self._policy_collect_model.state # Copy policy weights and state to value trainer. if self._n_shared_layers > 0: _copy_model_weights_and_state(0, self._n_shared_layers, self._policy_trainer, self._value_trainer) # Update the target value network. self._value_eval_model.weights = self._value_trainer.model_weights self._value_eval_model.state = self._value_trainer.model_state n_value_evals = rl_training.remaining_evals( self._value_trainer.step, self._epoch, self._value_train_steps_per_epoch, self._value_evals_per_epoch) for _ in range(n_value_evals): self._value_trainer.train_epoch( self._value_train_steps_per_epoch // self._value_evals_per_epoch, self._value_eval_steps, ) # Update the target value network. self._value_eval_model.weights = self._value_trainer.model_weights self._value_eval_model.state = self._value_trainer.model_state # Copy value weights and state to policy trainer. if self._n_shared_layers > 0: _copy_model_weights_and_state(0, self._n_shared_layers, self._value_trainer, self._policy_trainer) n_policy_evals = rl_training.remaining_evals( self._policy_trainer.step, self._epoch, self._policy_train_steps_per_epoch, self._policy_evals_per_epoch) # Check if there was a restart after value training finishes and policy not. stopped_after_value = (n_value_evals == 0 and n_policy_evals < self._policy_evals_per_epoch) should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value if should_copy_weights: _copy_model_weights_and_state(0, self._n_shared_layers, self._value_trainer, self._policy_trainer) # Update the target value network. self._value_eval_model.weights = self._value_trainer.model_weights self._value_eval_model.state = self._value_trainer.model_state for _ in range(n_policy_evals): self._policy_trainer.train_epoch( self._policy_train_steps_per_epoch // self._policy_evals_per_epoch, self._policy_eval_steps, )
def train_epoch(self): """Trains RL for one epoch.""" n_evals = rl_training.remaining_evals(self._trainer.step, self._epoch, self._train_steps_per_epoch, self._supervised_evals_per_epoch) for _ in range(n_evals): self._trainer.train_epoch( self._train_steps_per_epoch // self._supervised_evals_per_epoch, self._supervised_eval_steps)