Code Example #1
 def start_learning(self, env_manager: EnvManager) -> None:
     self._create_output_path(self.output_path)
     if tf_utils.is_available():
         tf.reset_default_graph()
     try:
         # Initial reset
         self._reset_env(env_manager)
         while self._not_done_training():
             n_steps = self.advance(env_manager)
             for _ in range(n_steps):
                 self.reset_env_if_ready(env_manager)
         # Stop advancing trainers
         self.join_threads()
     except (
             KeyboardInterrupt,
             UnityCommunicationException,
             UnityEnvironmentException,
             UnityCommunicatorStoppedException,
     ) as ex:
         self.join_threads()
         self.logger.info(
             "Learning was interrupted. Please wait while the graph is generated."
         )
         if isinstance(ex, (KeyboardInterrupt, UnityCommunicatorStoppedException)):
             pass
         else:
             # If the environment failed, we want to make sure to raise
             # the exception so we exit the process with a return code of 1.
             raise ex
     finally:
         if self.train_model:
             self._save_models()
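For orientation, here is a minimal sketch of how start_learning is typically driven from a training entry point. The env_factory, engine_config, and already-constructed controller tc are assumptions, and SubprocessEnvManager's constructor arguments vary across ml-agents releases:

from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager

# Hypothetical wiring: env_factory, engine_config, and tc (a TrainerController)
# are assumed to exist already; the constructor signature varies by release.
env_manager = SubprocessEnvManager(env_factory, engine_config, 1)
try:
    # Runs the reset/advance loop above until training finishes or is interrupted.
    tc.start_learning(env_manager)
finally:
    env_manager.close()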
Code Example #2
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # collected_rewards maps each reward signal name to a dictionary of
        # agent_id -> cumulative reward; it is used for reporting only. We always
        # report the environment reward to TensorBoard, regardless of which
        # reward signals are actually present.
        self.cumulative_returns_since_policy_update: List[float] = []
        self.collected_rewards: Dict[str, Dict[str, float]] = {
            "environment": defaultdict(lambda: 0.0)
        }
        self.update_buffer: AgentBuffer = AgentBuffer()
        self._stats_reporter.add_property(
            StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
        )
        self.framework = self.trainer_settings.framework
        if self.framework == FrameworkType.TENSORFLOW and not tf_utils.is_available():
            raise UnityTrainerException(
                "To use the TensorFlow backend, install the TensorFlow Python package first."
            )

        logger.debug(f"Using framework {self.framework.value}")

        self._next_save_step = 0
        self._next_summary_step = 0
        self.model_saver = self.create_model_saver(
            self.framework, self.trainer_settings, self.artifact_path, self.load
        )
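The nested collected_rewards structure is easy to misread, so here is a standalone sketch of the same defaultdict pattern, with made-up agent ids and reward values:

from collections import defaultdict
from typing import Dict

# Same shape as collected_rewards above: signal name -> agent_id -> running total.
collected: Dict[str, Dict[str, float]] = {"environment": defaultdict(lambda: 0.0)}

# Missing agent ids are created on first access with a total of 0.
collected["environment"]["agent-0"] += 1.0
collected["environment"]["agent-0"] += 0.5
collected["environment"]["agent-1"] += 2.0

# At summary-writing time, report the mean and reset the per-agent totals.
totals = collected["environment"]
print(sum(totals.values()) / len(totals))  # (1.5 + 2.0) / 2 = 1.75
totals.clear()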
Code Example #3
    def __init__(
        self,
        trainer_factory: TrainerFactory,
        output_path: str,
        run_id: str,
        param_manager: EnvironmentParameterManager,
        train: bool,
        training_seed: int,
    ):
        """
        :param output_path: Path to save the model.
        :param summaries_dir: Folder to save training summaries.
        :param run_id: The sub-directory name for model and summary statistics
        :param param_manager: EnvironmentParameterManager object which stores information about all
        environment parameters.
        :param train: Whether to train model, or only run inference.
        :param training_seed: Seed to use for Numpy and Tensorflow random number generation.
        :param threaded: Whether or not to run trainers in a separate thread. Disable for testing/debugging.
        """
        self.trainers: Dict[str, Trainer] = {}
        self.brain_name_to_identifier: Dict[str, Set] = defaultdict(set)
        self.trainer_factory = trainer_factory
        self.output_path = output_path
        self.logger = get_logger(__name__)
        self.run_id = run_id
        self.train_model = train
        self.param_manager = param_manager
        self.ghost_controller = self.trainer_factory.ghost_controller
        self.registered_behavior_ids: Set[str] = set()

        self.trainer_threads: List[threading.Thread] = []
        self.kill_trainers = False
        np.random.seed(training_seed)
        if tf_utils.is_available():
            tf.set_random_seed(training_seed)
        torch_utils.torch.manual_seed(training_seed)
        self.rank = get_rank()
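A hypothetical construction sketch, mirroring what the mlagents-learn entry point does internally. The make_trainer_factory and make_param_manager helpers are placeholders, not part of the ml-agents API, because those constructors differ across releases:

# Placeholder helpers stand in for TrainerFactory / EnvironmentParameterManager
# construction; their signatures vary between ml-agents releases.
trainer_factory = make_trainer_factory()  # hypothetical helper
param_manager = make_param_manager()      # hypothetical helper

tc = TrainerController(
    trainer_factory=trainer_factory,
    output_path="results/my_run",
    run_id="my_run",
    param_manager=param_manager,
    train=True,          # False would run inference only
    training_seed=42,    # seeds NumPy, TensorFlow, and PyTorch, as above
)
tc.start_learning(env_manager)  # env_manager as in the sketch after example #1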
Code Example #4
File: trainer.py Project: ssshammi/ml-agents
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.policy import Policy
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, PPOSettings, FrameworkType
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents import tf_utils

if tf_utils.is_available():
    from mlagents.trainers.policy.tf_policy import TFPolicy
    from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
else:
    TFPolicy = None  # type: ignore
    PPOOptimizer = None  # type: ignore


logger = get_logger(__name__)


class PPOTrainer(RLTrainer):
    """The PPOTrainer is an implementation of the PPO algorithm."""

    def __init__(
        self,
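The guarded TFPolicy/PPOOptimizer import at the top of this file is a general pattern for optional backends: probe availability once, bind the real symbols when possible, and fall back to None placeholders otherwise. A self-contained sketch of the same technique, using NumPy as a stand-in optional dependency:

import importlib.util

# Probe for the optional dependency without importing it yet.
NUMPY_AVAILABLE = importlib.util.find_spec("numpy") is not None

if NUMPY_AVAILABLE:
    import numpy as np
else:
    # A None fallback keeps module-level references valid; the type: ignore
    # mirrors the TFPolicy/PPOOptimizer fallbacks above.
    np = None  # type: ignore

def mean(values):
    # Branch on the single availability flag, much as the trainers branch
    # on tf_utils.is_available().
    if NUMPY_AVAILABLE:
        return float(np.mean(values))
    return sum(values) / len(values)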