Example #1
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # collected_rewards maps each reward signal name to a dictionary of agent_id to cumulative reward,
        # used for reporting only. We always want to report the environment reward to TensorBoard,
        # regardless of which reward signals are actually present.
        self.cumulative_returns_since_policy_update: List[float] = []
        self.collected_rewards: Dict[str, Dict[str, int]] = {
            "environment": defaultdict(lambda: 0)
        }
        self.update_buffer: AgentBuffer = AgentBuffer()
        self._stats_reporter.add_property(StatsPropertyType.HYPERPARAMETERS,
                                          self.trainer_settings.as_dict())
        self.framework = self.trainer_settings.framework
        if self.framework == FrameworkType.PYTORCH and not torch_utils.is_available():
            raise UnityTrainerException(
                "To use the experimental PyTorch backend, install the PyTorch Python package first."
            )

        logger.debug(f"Using framework {self.framework.value}")

        self._next_save_step = 0
        self._next_summary_step = 0
        self.model_saver = self.create_model_saver(self.framework,
                                                   self.trainer_settings,
                                                   self.artifact_path,
                                                   self.load)
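
The collected_rewards structure in Example #1 maps each reward signal name to a per-agent running total that is used for reporting only. A minimal standalone sketch of how such a mapping might be filled and flushed (the accumulate and flush_for_reporting helpers are illustrative assumptions, not part of the trainer above):

from collections import defaultdict
from typing import Dict

# Hypothetical stand-in for the trainer's reporting-only reward bookkeeping.
collected_rewards: Dict[str, Dict[str, float]] = {
    "environment": defaultdict(lambda: 0)
}


def accumulate(agent_id: str, env_reward: float) -> None:
    # Add this step's environment reward to the agent's running total.
    collected_rewards["environment"][agent_id] += env_reward


def flush_for_reporting() -> Dict[str, float]:
    # Snapshot the cumulative return per agent and reset the totals,
    # e.g. once an episode ends, so they can be written to TensorBoard.
    snapshot = dict(collected_rewards["environment"])
    collected_rewards["environment"].clear()
    return snapshot


accumulate("agent-0", 1.0)
accumulate("agent-0", 0.5)
accumulate("agent-1", -0.2)
print(flush_for_reporting())  # {'agent-0': 1.5, 'agent-1': -0.2}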
Example #2
    def __init__(
        self,
        trainer_factory: TrainerFactory,
        output_path: str,
        run_id: str,
        param_manager: EnvironmentParameterManager,
        train: bool,
        training_seed: int,
    ):
        """
        :param trainer_factory: Factory used to create Trainer instances for each behavior.
        :param output_path: Path to save the model.
        :param run_id: The sub-directory name for model and summary statistics.
        :param param_manager: EnvironmentParameterManager object which stores information about all
        environment parameters.
        :param train: Whether to train the model, or only run inference.
        :param training_seed: Seed to use for NumPy, TensorFlow, and PyTorch random number generation.
        """
        self.trainers: Dict[str, Trainer] = {}
        self.brain_name_to_identifier: Dict[str, Set] = defaultdict(set)
        self.trainer_factory = trainer_factory
        self.output_path = output_path
        self.logger = get_logger(__name__)
        self.run_id = run_id
        self.train_model = train
        self.param_manager = param_manager
        self.ghost_controller = self.trainer_factory.ghost_controller
        self.registered_behavior_ids: Set[str] = set()

        self.trainer_threads: List[threading.Thread] = []
        self.kill_trainers = False
        np.random.seed(training_seed)
        tf.set_random_seed(training_seed)
        if torch_utils.is_available():
            torch_utils.torch.manual_seed(training_seed)
        self.rank = get_rank()
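
The constructor in Example #2 derives every source of randomness from a single training_seed so that runs are reproducible. A standalone sketch of the same idea, assuming plain numpy/tensorflow/torch installs rather than the mlagents.tf_utils and torch_utils wrappers used above (seed_everything is a hypothetical helper):

import numpy as np


def seed_everything(training_seed: int) -> None:
    # Seed every random number source the trainers rely on.
    np.random.seed(training_seed)
    try:
        import torch
        torch.manual_seed(training_seed)  # only if PyTorch is installed
    except ImportError:
        pass
    try:
        import tensorflow as tf
        # TF1-style graph-level seed; available via compat.v1 on TF2 as well.
        tf.compat.v1.set_random_seed(training_seed)
    except ImportError:
        pass


seed_everything(42)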
Example #3
from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint

from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.policy import Policy
from mlagents.trainers.sac.optimizer_tf import SACOptimizer
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType
from mlagents.trainers.components.reward_signals import RewardSignal
from mlagents import torch_utils

if torch_utils.is_available():
    from mlagents.trainers.policy.torch_policy import TorchPolicy
    from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
else:
    TorchPolicy = None  # type: ignore
    TorchSACOptimizer = None  # type: ignore

logger = get_logger(__name__)

BUFFER_TRUNCATE_PERCENT = 0.8


class SACTrainer(RLTrainer):
    """
    The SACTrainer is an implementation of the SAC algorithm, with support
    for discrete actions and recurrent networks.