Code example #1
    def train_epoch(self):
        """Trains RL for one epoch."""
        # When restoring from a checkpoint, compute how many evals remain.
        n_value_evals = rl_training.remaining_evals(
            self._value_trainer.step, self._epoch,
            self._value_train_steps_per_epoch, self._value_evals_per_epoch)
        for _ in range(n_value_evals):
            self._value_trainer.train_epoch(
                self._value_train_steps_per_epoch //
                self._value_evals_per_epoch,
                self._value_eval_steps,
            )
        n_policy_evals = rl_training.remaining_evals(
            self._policy_trainer.step, self._epoch,
            self._policy_train_steps_per_epoch, self._policy_evals_per_epoch)
        # Detect a restart that happened after value training finished but
        # before policy training did; in that case the policy trainer already
        # holds newer weights for the shared layers, so don't overwrite them.
        stopped_after_value = (n_value_evals == 0 and
                               n_policy_evals < self._policy_evals_per_epoch)
        should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value
        if should_copy_weights:  # Copy value weights to policy trainer.
            _copy_model_weights(0, self._n_shared_layers, self._value_trainer,
                                self._policy_trainer)

        for _ in range(n_policy_evals):
            self._policy_trainer.train_epoch(
                self._policy_train_steps_per_epoch //
                self._policy_evals_per_epoch,
                self._policy_eval_steps,
            )
        if self._n_shared_layers > 0:  # Copy policy weights to value trainer.
            _copy_model_weights(0, self._n_shared_layers, self._policy_trainer,
                                self._value_trainer)
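
The helper _copy_model_weights used above is not shown in these snippets. As an illustration only, a minimal sketch of such a slice-copy might look like the following; it assumes each trainer exposes its per-layer weights as an indexable, assignable sequence named model_weights, which is an assumption rather than the project's confirmed API.

def _copy_model_weights(start, end, from_trainer, to_trainer):
    """Sketch: copy the weights of layers [start, end) between trainers.

    Assumes `model_weights` is an indexable sequence of per-layer weight
    trees; this is an illustrative helper, not the project's implementation.
    """
    from_weights = from_trainer.model_weights
    to_weights = list(to_trainer.model_weights)
    # Overwrite the shared prefix of layers in the destination model.
    to_weights[start:end] = from_weights[start:end]
    to_trainer.model_weights = to_weights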
Code example #2
File: actor_critic.py Project: tvjoseph/trax
    def train_epoch(self):
        """Trains RL for one epoch."""
        # Copy policy state accumulated during data collection to the trainer.
        self._policy_trainer.model_state = self._policy_collect_model.state

        # Copy policy weights and state to value trainer.
        if self._n_shared_layers > 0:
            _copy_model_weights_and_state(0, self._n_shared_layers,
                                          self._policy_trainer,
                                          self._value_trainer)

        # Update the target value network.
        self._value_eval_model.weights = self._value_trainer.model_weights
        self._value_eval_model.state = self._value_trainer.model_state

        n_value_evals = rl_training.remaining_evals(
            self._value_trainer.step, self._epoch,
            self._value_train_steps_per_epoch, self._value_evals_per_epoch)
        for _ in range(n_value_evals):
            self._value_trainer.train_epoch(
                self._value_train_steps_per_epoch //
                self._value_evals_per_epoch,
                self._value_eval_steps,
            )
            # Update the target value network.
            self._value_eval_model.weights = self._value_trainer.model_weights
            self._value_eval_model.state = self._value_trainer.model_state

        n_policy_evals = rl_training.remaining_evals(
            self._policy_trainer.step, self._epoch,
            self._policy_train_steps_per_epoch, self._policy_evals_per_epoch)
        # Detect a restart that happened after value training finished but
        # before policy training did; in that case don't overwrite the policy
        # trainer's shared layers with older value weights.
        stopped_after_value = (n_value_evals == 0 and
                               n_policy_evals < self._policy_evals_per_epoch)
        should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value
        if should_copy_weights:
            # Copy value weights and state to policy trainer.
            _copy_model_weights_and_state(0, self._n_shared_layers,
                                          self._value_trainer,
                                          self._policy_trainer)

        # Update the target value network.
        self._value_eval_model.weights = self._value_trainer.model_weights
        self._value_eval_model.state = self._value_trainer.model_state

        for _ in range(n_policy_evals):
            self._policy_trainer.train_epoch(
                self._policy_train_steps_per_epoch //
                self._policy_evals_per_epoch,
                self._policy_eval_steps,
            )
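
All of these snippets make train_epoch restart-safe through rl_training.remaining_evals: after a checkpoint restore, the supervised trainer's global step reveals how many eval-sized chunks of the current epoch were already completed. The sketch below shows the arithmetic the calls above assume; it is an illustration, not the library's exact implementation.

def remaining_evals(cur_step, epoch, train_steps_per_epoch, evals_per_epoch):
    """Sketch: how many eval-sized training chunks are left in this epoch.

    `cur_step` is the supervised trainer's global step and `epoch` is the
    1-based RL epoch; assumes train_steps_per_epoch is a multiple of
    evals_per_epoch.
    """
    steps_per_eval = train_steps_per_epoch // evals_per_epoch
    done_this_epoch = cur_step - (epoch - 1) * train_steps_per_epoch
    if done_this_epoch < 0 or done_this_epoch % steps_per_eval != 0:
        raise ValueError('Trainer step does not match the epoch schedule.')
    return evals_per_epoch - done_this_epoch // steps_per_eval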
Code example #3
    def train_epoch(self):
        """Trains RL for one epoch."""
        # When restoring from a checkpoint, compute how many evals remain.
        n_evals = rl_training.remaining_evals(self._trainer.step, self._epoch,
                                              self._train_steps_per_epoch,
                                              self._supervised_evals_per_epoch)
        for _ in range(n_evals):
            self._trainer.train_epoch(
                self._train_steps_per_epoch //
                self._supervised_evals_per_epoch, self._supervised_eval_steps)
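
For context, each train_epoch above is intended to be called once per RL epoch by an outer loop that also collects trajectories and writes checkpoints; restart safety comes from restoring the trainers' step counters from those checkpoints. The driver below is purely illustrative, with hypothetical method names rather than trax API.

def run_training(agent, n_epochs):
    """Hypothetical driver loop; the agent method names are illustrative only."""
    for _ in range(n_epochs):
        agent.collect_trajectories()  # gather data with the current policy
        agent.train_epoch()           # one of the train_epoch methods above
        # Checkpointing is what makes remaining_evals meaningful on restart.
        agent.save_checkpoint()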