def load_config(config_path: str) -> Dict[str, Any]:
    """
    Open the configuration file at ``config_path`` and parse it with
    ``_load_config``.

    :param config_path: Path of the configuration file to read.
    :return: The value returned by ``_load_config`` for the opened file.
    :raises TrainerConfigError: If an ``OSError`` occurs while opening or
        reading the file, or if its contents cannot be decoded as UTF-8.
    """
    try:
        # Decode explicitly as UTF-8 rather than with the platform's locale
        # default, so the advice in the decoding error below actually holds.
        with open(config_path, encoding="utf-8") as data_file:
            return _load_config(data_file)
    except OSError as e:
        abs_path = os.path.abspath(config_path)
        raise TrainerConfigError(
            f"Config file could not be found at {abs_path}."
        ) from e
    except UnicodeDecodeError as e:
        raise TrainerConfigError(
            f"There was an error decoding Config file from {config_path}. "
            f"Make sure your file is saved using UTF-8"
        ) from e
def _check_threshold_value(self, attribute, value):
    """
    Ensure the threshold lies within [0, 1] when the measure is PROGRESS.

    Nothing is checked for any other measure. The ``attribute`` and ``value``
    arguments are not used; the check reads ``self.threshold`` directly.

    :raises TrainerConfigError: If the measure is PROGRESS and the threshold
        is greater than 1 or negative.
    """
    if self.measure != self.MeasureType.PROGRESS:
        return
    threshold = self.threshold
    if threshold > 1.0:
        raise TrainerConfigError(
            "Threshold for next lesson cannot be greater than 1 when the measure is progress."
        )
    if threshold < 0.0:
        raise TrainerConfigError(
            "Threshold for next lesson cannot be negative when the measure is progress."
        )
def convert_behaviors(old_trainer_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert an old-style trainer config into a mapping of behavior name to
    structured TrainerSettings.

    Every entry other than "default" is layered on top of a shallow copy of
    the "default" entry, then split into hyperparameters, network settings
    and the remaining base TrainerSettings fields.

    :param old_trainer_config: Old-format trainer configuration, keyed by
        behavior name with an optional "default" entry.
    :return: Dict mapping each non-default behavior name to the result of
        ``cattr.structure(new_config, TrainerSettings)``.
    :raises TrainerConfigError: If a behavior specifies no "trainer" or no
        "use_recurrent", or enables use_recurrent without "sequence_length"
        or "memory_size".
    """
    all_behavior_config_dict = {}
    default_config = old_trainer_config.get("default", {})
    # The TrainerSettings field set is the same for every behavior.
    trainer_settings_fields = attr.fields_dict(TrainerSettings)

    for behavior_name, behavior_config in old_trainer_config.items():
        if behavior_name == "default":
            continue
        config = default_config.copy()
        config.update(behavior_config)

        # Convert to split TrainerSettings, Hyperparameters, NetworkSettings
        # Set trainer_type and get appropriate hyperparameter settings
        try:
            trainer_type = config["trainer"]
        except KeyError as e:
            raise TrainerConfigError(
                "Config doesn't specify a trainer type. "
                "Please specify trainer: in your config."
            ) from e
        new_config = {"trainer_type": trainer_type}
        hyperparam_cls = TrainerType(trainer_type).to_settings()
        # Try to absorb as much as possible into the hyperparam_cls
        new_config["hyperparameters"] = cattr.structure(config, hyperparam_cls)
        # Try to absorb as much as possible into the network settings
        new_config["network_settings"] = cattr.structure(config, NetworkSettings)

        # Deal with recurrent. The lookups are separated so that a missing
        # memory option is not misreported as a missing use_recurrent.
        try:
            use_recurrent = config["use_recurrent"]
        except KeyError as e:
            raise TrainerConfigError(
                "Config doesn't specify use_recurrent. "
                "Please specify true or false for use_recurrent in your config."
            ) from e
        if use_recurrent:
            try:
                sequence_length = config["sequence_length"]
                memory_size = config["memory_size"]
            except KeyError as e:
                raise TrainerConfigError(
                    f"Config enables use_recurrent but doesn't specify {e.args[0]}. "
                    "Please specify sequence_length and memory_size in your config."
                ) from e
            new_config["network_settings"].memory = NetworkSettings.MemorySettings(
                sequence_length=sequence_length, memory_size=memory_size
            )

        # Absorb the rest into the base TrainerSettings
        for key, val in config.items():
            if key in trainer_settings_fields:
                new_config[key] = val

        # Structure the whole thing
        all_behavior_config_dict[behavior_name] = cattr.structure(
            new_config, TrainerSettings
        )
    return all_behavior_config_dict
def structure(d: Mapping, t: type) -> Any:
    """
    Structure a mapping of reward-signal names into a dict of
    RewardSignalSettings keyed by RewardSignalType.

    Meant to be registered with cattr.register_structure_hook() and called
    with cattr.structure(). The settings class for each entry is chosen from
    its key, via ``RewardSignalType(key).to_settings()``; the ``t`` argument
    is not consulted.

    :raises TrainerConfigError: If ``d`` is not a Mapping.
    """
    if not isinstance(d, Mapping):
        raise TrainerConfigError(f"Unsupported reward signal configuration {d}.")

    structured: Dict[RewardSignalType, RewardSignalSettings] = {}
    for signal_name, signal_config in d.items():
        signal_type = RewardSignalType(signal_name)
        settings = strict_to_cls(signal_config, signal_type.to_settings())
        structured[signal_type] = settings

        # Deprecated 'encoding_size' handling: warn, and only fall back to it
        # for hidden_units when no network_settings were given explicitly.
        if "encoding_size" not in signal_config:
            continue
        logger.warning(
            "'encoding_size' was deprecated for RewardSignals. Please use network_settings."
        )
        if "network_settings" not in signal_config:
            settings.network_settings.hidden_units = signal_config["encoding_size"]
    return structured
def _check_batch_size_seq_length(self, attribute, value):
    """
    When memory settings are present, require the memory sequence length to
    be no larger than the hyperparameters' batch size.

    The ``attribute`` and ``value`` arguments are not used.

    :raises TrainerConfigError: If the sequence length exceeds the batch size.
    """
    memory = self.network_settings.memory
    if memory is None:
        return
    if memory.sequence_length > self.hyperparameters.batch_size:
        raise TrainerConfigError(
            "When using memory, sequence length must be less than or equal to batch size. "
        )
def structure(d: Mapping, t: type) -> Dict[str, "EnvironmentParameterSettings"]:
    """
    Structure a mapping of environment parameters into a dict of
    EnvironmentParameterSettings.

    Meant to be registered with cattr.register_structure_hook() and called
    with cattr.structure(). An entry that is a Mapping containing a
    "curriculum" key is structured strictly and its lesson chain is checked;
    any other entry is treated as a sampler and wrapped in a single Lesson
    with no completion criteria.

    :raises TrainerConfigError: If ``d`` is not a Mapping.
    """
    if not isinstance(d, Mapping):
        raise TrainerConfigError(
            f"Unsupported parameter environment parameter settings {d}."
        )

    structured: Dict[str, EnvironmentParameterSettings] = {}
    for param_name, param_config in d.items():
        has_curriculum = (
            isinstance(param_config, Mapping) and "curriculum" in param_config
        )
        if not has_curriculum:
            sampler = ParameterRandomizationSettings.structure(
                param_config, ParameterRandomizationSettings
            )
            only_lesson = Lesson(
                completion_criteria=None, value=sampler, name=param_name
            )
            structured[param_name] = EnvironmentParameterSettings(
                curriculum=[only_lesson]
            )
            continue

        settings = strict_to_cls(param_config, EnvironmentParameterSettings)
        structured[param_name] = settings
        EnvironmentParameterSettings._check_lesson_chain(
            settings.curriculum, param_name
        )
    return structured
def check_and_structure(key: str, value: Any, class_type: type) -> Any:
    """
    Structure ``value`` according to the declared type of field ``key`` on
    the attrs class ``class_type``.

    :param key: Name of the option being set.
    :param value: Raw value to pass to ``cattr.structure``.
    :param class_type: attrs class whose fields define the valid options.
    :return: The result of ``cattr.structure`` for the field's type.
    :raises TrainerConfigError: If ``key`` is not a field of ``class_type``.
    """
    known_fields = attr.fields_dict(class_type)
    if key in known_fields:
        # Apply cattr structure to the values
        return cattr.structure(value, known_fields[key].type)
    raise TrainerConfigError(
        f"The option {key} was specified in your YAML file for {class_type.__name__}, but is invalid."
    )
def strict_to_cls(d: Mapping, t: type) -> Any:
    """
    Build an instance of ``t`` from mapping ``d``, passing every entry
    through ``check_and_structure`` first.

    :param d: Mapping of option names to raw values.
    :param t: Class to instantiate with the structured values as keyword
        arguments.
    :raises TrainerConfigError: If ``d`` is not a Mapping.
    """
    if not isinstance(d, Mapping):
        raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.")
    structured_kwargs: Dict[str, Any] = {
        name: check_and_structure(name, raw, t) for name, raw in dict(d).items()
    }
    return t(**structured_kwargs)
def from_argparse(args: argparse.Namespace) -> "RunOptions":
    """
    Build a RunOptions instance from parsed command-line arguments, layered
    on top of the YAML file at ``StoreConfigFile.trainer_config_path`` when
    one is set.

    :param args: collection of command-line parameters passed to mlagents-learn
    :return: RunOptions built from the merged YAML and CLI values.
    :raises TrainerConfigError: If a top-level configuration key is not a
        field of RunOptions.
    """
    cli_values = vars(args)
    yaml_path = StoreConfigFile.trainer_config_path

    # Start with empty sub-sections so CLI overrides always have a target.
    configured_dict: Dict[str, Any] = {
        "checkpoint_settings": {},
        "env_settings": {},
        "engine_settings": {},
        "torch_settings": {},
    }
    # Without a config file, behavior names are not all required to be listed.
    require_all_behaviors = yaml_path is not None
    if require_all_behaviors:
        configured_dict.update(load_config(yaml_path))

    # Detect bad config options
    run_option_fields = attr.fields_dict(RunOptions)
    for option in configured_dict:
        if option not in run_option_fields:
            raise TrainerConfigError(
                "The option {} was specified in your YAML file, but is invalid.".format(
                    option
                )
            )

    # Keep deprecated --load working, TODO: remove
    cli_values["resume"] = cli_values["resume"] or cli_values["load_model"]

    # Explicitly-given CLI arguments override the YAML values; each one goes
    # to the first settings section declaring a field of that name, or to
    # the base options when none does.
    section_fields = (
        ("checkpoint_settings", attr.fields_dict(CheckpointSettings)),
        ("env_settings", attr.fields_dict(EnvironmentSettings)),
        ("engine_settings", attr.fields_dict(EngineSettings)),
        ("torch_settings", attr.fields_dict(TorchSettings)),
    )
    for key, val in cli_values.items():
        if key not in DetectDefault.non_default_args:
            continue
        for section, fields in section_fields:
            if key in fields:
                configured_dict[section][key] = val
                break
        else:
            configured_dict[key] = val

    final_runoptions = RunOptions.from_dict(configured_dict)
    final_runoptions.checkpoint_settings.prioritize_resume_init()
    # Need check to bypass type checking but keep structure on dict working
    behaviors = final_runoptions.behaviors
    if isinstance(behaviors, TrainerSettings.DefaultTrainerDict):
        # configure whether or not we should require all behavior names to be found in the config YAML
        behaviors.set_config_specified(require_all_behaviors)
    return final_runoptions
def structure(d: Mapping, t: type) -> Any:
    """
    Structure a mapping into an instance of ``t`` (a TrainerSettings class).

    Meant to be registered with cattr.register_structure_hook() and called
    with cattr.structure(). When ``TrainerSettings.default_override`` is set,
    its unstructured form is the base that ``d`` is deep-merged onto.

    :raises TrainerConfigError: If ``d`` is not a Mapping, or if
        hyperparameters are given without a trainer_type.
    """
    if not isinstance(d, Mapping):
        raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.")

    merged: Dict[str, Any] = {}
    # Use default_settings, when specified, as the base instead of an empty dict.
    override = TrainerSettings.default_override
    if override is not None:
        merged.update(cattr.unstructure(override))
    deep_update_dict(merged, d)

    if "framework" in merged:
        logger.warning("Framework option was deprecated but was specified")
        merged.pop("framework", None)

    for key, val in merged.items():
        # Don't convert already-converted attrs classes.
        if attr.has(type(val)):
            continue
        if key == "hyperparameters":
            if "trainer_type" not in merged:
                raise TrainerConfigError(
                    "Hyperparameters were specified but no trainer_type was given."
                )
            checked = check_hyperparam_schedules(val, merged["trainer_type"])
            merged[key] = checked
            merged[key] = strict_to_cls(
                checked, TrainerType(merged["trainer_type"]).to_settings()
            )
        elif key == "max_steps":
            # In some legacy configs, max steps was specified as a float
            merged[key] = int(float(val))
        else:
            merged[key] = check_and_structure(key, val, t)
    return t(**merged)
def _check_lesson_chain(lessons, parameter_name):
    """
    Verify that every lesson except the last one has a completion_criteria.

    :param lessons: Sized, ordered collection of lessons; each non-terminal
        lesson's ``completion_criteria`` attribute is read.
    :param parameter_name: Name used in the error message.
    :raises TrainerConfigError: If a non-terminal lesson has a
        completion_criteria of None.
    """
    final_index = len(lessons) - 1
    for position, lesson in enumerate(lessons):
        is_terminal = position >= final_index
        if is_terminal or lesson.completion_criteria is not None:
            continue
        raise TrainerConfigError(
            f"A non-terminal lesson does not have a completion_criteria for {parameter_name}."
        )
def _load_config(fp: TextIO) -> Dict[str, Any]:
    """
    Load the yaml config from the file-like object.

    :param fp: Open text stream containing the YAML document.
    :return: The parsed document, or an empty dict when the stream holds no
        document at all.
    :raises TrainerConfigError: If the stream is not valid YAML.
    """
    try:
        config = yaml.safe_load(fp)
    except yaml.YAMLError as e:
        # YAMLError is the base of parser, scanner and other YAML errors, so
        # malformed files of every kind get the same friendly message.
        raise TrainerConfigError(
            "Error parsing yaml file. Please check for formatting errors. "
            "A tool such as http://www.yamllint.com/ can be helpful with this."
        ) from e
    # safe_load returns None for an empty document; honour the declared
    # Dict return type so callers can treat the result as a mapping.
    return {} if config is None else config
def structure(d: Union[Mapping, float], t: type) -> "ParameterRandomizationSettings":
    """
    Structure a sampler description into a ParameterRandomizationSettings.

    Meant to be registered with cattr.register_structure_hook() and called
    with cattr.structure(). A bare float or int becomes a ConstantSettings;
    a Mapping must carry "sampler_type" and "sampler_parameters", and the
    settings class is chosen from the sampler type. The ``t`` argument is
    not consulted.

    :raises TrainerConfigError: If ``d`` is neither a number nor a Mapping,
        or if a required key is missing.
    """
    if isinstance(d, (float, int)):
        return ConstantSettings(value=d)
    if not isinstance(d, Mapping):
        raise TrainerConfigError(
            f"Unsupported parameter randomization configuration {d}."
        )
    for required_key in ("sampler_type", "sampler_parameters"):
        if required_key not in d:
            raise TrainerConfigError(
                f"Sampler configuration does not contain {required_key} : {d}."
            )
    sampler_cls = ParameterRandomizationType(d["sampler_type"]).to_settings()
    return strict_to_cls(d["sampler_parameters"], sampler_cls)
def from_argparse(args: argparse.Namespace) -> "RunOptions":
    """
    Build a RunOptions instance from parsed command-line arguments and the
    YAML file at ``StoreConfigFile.trainer_config_path``.

    YAML values are used for every option not explicitly given on the
    command line.

    :param args: collection of command-line parameters passed to mlagents-learn
    :return: RunOptions constructed from the merged values.
    :raises TrainerConfigError: If the YAML has no "behaviors" section or
        contains a key that is not an attribute of RunOptions.
    """
    run_options_dict = dict(vars(args))

    # Load YAML
    yaml_config = load_config(StoreConfigFile.trainer_config_path)

    # This is the only option that is not optional and has no defaults.
    if "behaviors" not in yaml_config:
        raise TrainerConfigError(
            "Trainer configurations not found. Make sure your YAML file has a section for behaviors."
        )

    for key, val in yaml_config.items():
        # Detect bad config options
        if not hasattr(RunOptions, key):
            raise TrainerConfigError(
                "The option {} was specified in your YAML file, but is invalid.".format(
                    key
                )
            )
        # Values given explicitly on the CLI win over the YAML file.
        if key in DetectDefault.non_default_args:
            continue
        run_options_dict[key] = val

    # Keep deprecated --load working, TODO: remove
    resume = run_options_dict["resume"] or run_options_dict["load_model"]
    run_options_dict["resume"] = resume
    return RunOptions(**run_options_dict)
def structure(d: Mapping, t: type) -> Any:
    """
    Structure a mapping of reward-signal names into a dict of
    RewardSignalSettings keyed by RewardSignalType.

    Meant to be registered with cattr.register_structure_hook() and called
    with cattr.structure(). The settings class for each entry is chosen from
    its key, via ``RewardSignalType(key).to_settings()``; the ``t`` argument
    is not consulted.

    :raises TrainerConfigError: If ``d`` is not a Mapping.
    """
    if not isinstance(d, Mapping):
        raise TrainerConfigError(f"Unsupported reward signal configuration {d}.")
    structured: Dict[RewardSignalType, RewardSignalSettings] = {}
    for signal_name, signal_config in d.items():
        signal_type = RewardSignalType(signal_name)
        structured[signal_type] = strict_to_cls(
            signal_config, signal_type.to_settings()
        )
    return structured
def __missing__(self, key: Any) -> "TrainerSettings":
    """
    Supply settings for a behavior name that has no entry.

    A deep copy of ``TrainerSettings.default_override`` is stored when one
    is set. Otherwise the lookup fails if ``self._config_specified`` is
    truthy, and a warning plus a default ``TrainerSettings()`` is used when
    it is not.

    :param key: The behavior name that was looked up.
    :return: The settings stored under ``key``.
    :raises TrainerConfigError: If no default override exists and
        ``self._config_specified`` is truthy.
    """
    override = TrainerSettings.default_override
    if override is not None:
        self[key] = copy.deepcopy(override)
        return self[key]

    if self._config_specified:
        raise TrainerConfigError(
            f"The behavior name {key} has not been specified in the trainer configuration. "
            f"Please add an entry in the configuration file for {key}, or set default_settings."
        )

    logger.warning(
        f"Behavior name {key} does not match any behaviors specified "
        f"in the trainer configuration file. A default configuration will be used."
    )
    self[key] = TrainerSettings()
    return self[key]
def from_argparse(args: argparse.Namespace) -> "RunOptions":
    """
    Build a RunOptions instance from parsed command-line arguments, layered
    on top of the YAML file at ``StoreConfigFile.trainer_config_path`` when
    one is set.

    :param args: collection of command-line parameters passed to mlagents-learn
    :return: RunOptions built from the merged YAML and CLI values.
    :raises TrainerConfigError: If a top-level configuration key is not a
        field of RunOptions.
    """
    cli_values = vars(args)
    yaml_path = StoreConfigFile.trainer_config_path

    # Start with empty sub-sections so CLI overrides always have a target.
    configured_dict: Dict[str, Any] = {
        "checkpoint_settings": {},
        "env_settings": {},
        "engine_settings": {},
        "torch_settings": {},
    }
    if yaml_path is not None:
        configured_dict.update(load_config(yaml_path))

    # Detect bad config options
    valid_options = attr.fields_dict(RunOptions)
    for option in configured_dict:
        if option in valid_options:
            continue
        raise TrainerConfigError(
            "The option {} was specified in your YAML file, but is invalid.".format(
                option
            )
        )

    # Keep deprecated --load working, TODO: remove
    cli_values["resume"] = cli_values["resume"] or cli_values["load_model"]

    # Explicitly-given CLI arguments override the YAML values; each one goes
    # to the first settings section declaring a field of that name, or to
    # the base options when none does.
    sections = (
        ("checkpoint_settings", CheckpointSettings),
        ("env_settings", EnvironmentSettings),
        ("engine_settings", EngineSettings),
        ("torch_settings", TorchSettings),
    )
    explicit = (
        (key, val)
        for key, val in cli_values.items()
        if key in DetectDefault.non_default_args
    )
    for key, val in explicit:
        target = configured_dict
        for section_name, settings_cls in sections:
            if key in attr.fields_dict(settings_cls):
                target = configured_dict[section_name]
                break
        target[key] = val

    return RunOptions.from_dict(configured_dict)
def _check_lesson_chain(lessons, parameter_name):
    """
    Ensures that when using curriculum, all non-terminal lessons have a valid
    CompletionCriteria, and warns when the terminal lesson contains one.

    :param lessons: Sized, ordered collection of lessons; each lesson's
        ``completion_criteria`` attribute is read.
    :param parameter_name: Name used in the error and warning messages.
    :raises TrainerConfigError: If a non-terminal lesson has a
        completion_criteria of None.
    """
    last_index = len(lessons) - 1
    for index, lesson in enumerate(lessons):
        if index < last_index and lesson.completion_criteria is None:
            raise TrainerConfigError(
                f"A non-terminal lesson does not have a completion_criteria for {parameter_name}."
            )
        if index == last_index and lesson.completion_criteria is not None:
            # The two message parts are concatenated, so the first one needs
            # a trailing space to keep the sentences apart.
            warnings.warn(
                f"Your final lesson definition contains completion_criteria for {parameter_name}. "
                f"It will be ignored.",
                TrainerConfigWarning,
            )
def get_run_options(config_path: str, run_id: str) -> RunOptions:
    """
    Build RunOptions from an optional YAML config file and a run id.

    :param config_path: Path of the YAML config to load, or None to use
        only the empty default sections.
    :param run_id: Stored as ``checkpoint_settings["run_id"]``.
    :return: The result of ``RunOptions.from_dict`` on the merged dict.
    :raises TrainerConfigError: If a top-level key is not a field of
        RunOptions.
    """
    configured_dict: Dict[str, Any] = {
        "checkpoint_settings": {},
        "env_settings": {},
        "engine_settings": {},
    }
    if config_path is not None:
        configured_dict.update(mlagents.trainers.cli_utils.load_config(config_path))

    # Detect bad config options
    valid_options = attr.fields_dict(RunOptions)
    for option in configured_dict:
        if option in valid_options:
            continue
        raise TrainerConfigError(
            "The option {} was specified in your YAML file, but is invalid.".format(
                option
            )
        )

    configured_dict["checkpoint_settings"]["run_id"] = run_id
    return RunOptions.from_dict(configured_dict)
def initialize_trainer(
    trainer_settings: TrainerSettings,
    brain_name: str,
    run_id: str,
    output_path: str,
    train_model: bool,
    load_model: bool,
    ghost_controller: GhostController,
    seed: int,
    init_path: str = None,
    meta_curriculum: MetaCurriculum = None,
    multi_gpu: bool = False,
) -> Trainer:
    """
    Create the trainer for one brain from its TrainerSettings.

    ``trainer_settings.output_path`` (and ``init_path`` when given) are set
    on the passed-in settings object before the trainer is built.

    :param trainer_settings: Trainer configuration for this brain
    :param brain_name: Name of the brain to be associated with trainer
    :param run_id: Run ID passed to the trainer
    :param output_path: Base path; the brain name is appended to it
    :param train_model: Whether to train the model (vs. run inference)
    :param load_model: Whether to load the model or randomly initialize
    :param ghost_controller: The object that coordinates ghost trainers
    :param seed: The random seed to use
    :param init_path: Base path from which to load the model, if given
    :param meta_curriculum: Optional meta_curriculum, used to determine the
        minimum lesson length passed to the trainer
    :param multi_gpu: Not used by this function
    :return: A PPOTrainer or SACTrainer, wrapped in a GhostTrainer when
        ``trainer_settings.self_play`` is set
    :raises TrainerConfigError: If the trainer type is neither PPO nor SAC
    """
    trainer_settings.output_path = os.path.join(output_path, brain_name)
    if init_path is not None:
        trainer_settings.init_path = os.path.join(init_path, brain_name)

    min_lesson_length = 1
    if meta_curriculum:
        curricula = meta_curriculum.brains_to_curricula
        if brain_name in curricula:
            min_lesson_length = curricula[brain_name].min_lesson_length
        else:
            logger.warning(
                f"Metacurriculum enabled, but no curriculum for brain {brain_name}. "
                f"Brains with curricula: {curricula.keys()}. "
            )

    trainer_type = trainer_settings.trainer_type
    if trainer_type == TrainerType.PPO:
        trainer_cls = PPOTrainer
    elif trainer_type == TrainerType.SAC:
        trainer_cls = SACTrainer
    else:
        raise TrainerConfigError(
            f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
        )
    trainer: Trainer = trainer_cls(
        brain_name,
        min_lesson_length,
        trainer_settings,
        train_model,
        load_model,
        seed,
        run_id,
    )

    if trainer_settings.self_play is None:
        return trainer
    return GhostTrainer(
        trainer,
        brain_name,
        ghost_controller,
        min_lesson_length,
        trainer_settings,
        train_model,
        run_id,
    )
def _initialize_trainer(
    trainer_settings: TrainerSettings,
    brain_name: str,
    output_path: str,
    train_model: bool,
    load_model: bool,
    ghost_controller: GhostController,
    seed: int,
    param_manager: EnvironmentParameterManager,
    init_path: str = None,
    multi_gpu: bool = False,
) -> Trainer:
    """
    Create the trainer for one brain from its TrainerSettings.

    ``trainer_settings.init_path`` is set on the passed-in settings object
    when ``init_path`` is given.

    :param trainer_settings: Trainer configuration for this brain
    :param brain_name: Name of the brain to be associated with trainer
    :param output_path: Base path; the brain name is appended to form the
        artifact path handed to the trainer
    :param train_model: Whether to train the model (vs. run inference)
    :param load_model: Whether to load the model or randomly initialize
    :param ghost_controller: The object that coordinates ghost trainers
    :param seed: The random seed to use
    :param param_manager: EnvironmentParameterManager, queried for the
        minimum reward buffer size of this brain
    :param init_path: Base path from which to load the model, if given
    :param multi_gpu: Not used by this function
    :return: A PPOTrainer or SACTrainer, wrapped in a GhostTrainer when
        ``trainer_settings.self_play`` is set
    :raises TrainerConfigError: If the trainer type is neither PPO nor SAC
    """
    trainer_artifact_path = os.path.join(output_path, brain_name)
    if init_path is not None:
        trainer_settings.init_path = os.path.join(init_path, brain_name)

    min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)

    trainer_type = trainer_settings.trainer_type
    if trainer_type == TrainerType.PPO:
        trainer_cls = PPOTrainer
    elif trainer_type == TrainerType.SAC:
        trainer_cls = SACTrainer
    else:
        raise TrainerConfigError(
            f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
        )
    trainer: Trainer = trainer_cls(
        brain_name,
        min_lesson_length,
        trainer_settings,
        train_model,
        load_model,
        seed,
        trainer_artifact_path,
    )

    if trainer_settings.self_play is None:
        return trainer
    return GhostTrainer(
        trainer,
        brain_name,
        ghost_controller,
        min_lesson_length,
        trainer_settings,
        train_model,
        trainer_artifact_path,
    )
def initialize_trainer(
    trainer_config: Any,
    brain_name: str,
    run_id: str,
    output_path: str,
    keep_checkpoints: int,
    train_model: bool,
    load_model: bool,
    ghost_controller: GhostController,
    seed: int,
    init_path: str = None,
    meta_curriculum: MetaCurriculum = None,
    multi_gpu: bool = False,
) -> Trainer:
    """
    Initializes a trainer given a provided trainer configuration and brain
    name, as well as some general training session options.

    :param trainer_config: Original trainer configuration loaded from YAML
    :param brain_name: Name of the brain to be associated with trainer
    :param run_id: Run ID to associate with this training run
    :param output_path: Base path; the brain name is appended to it
    :param keep_checkpoints: How many model checkpoints to keep
    :param train_model: Whether to train the model (vs. run inference)
    :param load_model: Whether to load the model or randomly initialize
    :param ghost_controller: The object that coordinates ghost trainers
    :param seed: The random seed to use
    :param init_path: Base path from which to load the model, if given
    :param meta_curriculum: Optional meta_curriculum, used to determine the
        minimum lesson length passed to the trainer
    :param multi_gpu: Not used by this function
    :return: A PPOTrainer or SACTrainer, wrapped in a GhostTrainer when the
        merged parameters contain "self_play"
    :raises TrainerConfigError: If the config has neither a "default" nor a
        brain section, if brain aliases form a cycle, if no "trainer" key is
        set, or if the trainer type is unknown
    :raises UnityTrainerException: If the trainer type is "offline_bc"
    """
    if "default" not in trainer_config and brain_name not in trainer_config:
        raise TrainerConfigError(
            f'Trainer config must have either a "default" section, or a section for the brain name {brain_name}. '
            "See the config/ directory for examples."
        )

    trainer_parameters = trainer_config.get("default", {}).copy()
    trainer_parameters["output_path"] = os.path.join(output_path, brain_name)
    if init_path is not None:
        trainer_parameters["init_path"] = os.path.join(init_path, brain_name)
    trainer_parameters["keep_checkpoints"] = keep_checkpoints

    if brain_name in trainer_config:
        # A brain entry may be an alias (a key naming another entry); follow
        # the chain until a dict is reached. Track visited keys so a cyclic
        # alias chain raises instead of looping forever.
        _brain_key: Any = brain_name
        visited = set()
        while not isinstance(trainer_config[_brain_key], dict):
            if _brain_key in visited:
                raise TrainerConfigError(
                    f"The trainer config for brain {brain_name} contains a cyclic alias at {_brain_key}."
                )
            visited.add(_brain_key)
            _brain_key = trainer_config[_brain_key]
        trainer_parameters.update(trainer_config[_brain_key])

    # This runs after the brain section is merged, so it overrides any
    # init_path coming from the config or from the assignment above.
    if init_path is not None:
        trainer_parameters["init_path"] = "{basedir}/{name}".format(
            basedir=init_path, name=brain_name
        )

    min_lesson_length = 1
    if meta_curriculum:
        if brain_name in meta_curriculum.brains_to_curricula:
            min_lesson_length = meta_curriculum.brains_to_curricula[
                brain_name
            ].min_lesson_length
        else:
            logger.warning(
                f"Metacurriculum enabled, but no curriculum for brain {brain_name}. "
                f"Brains with curricula: {meta_curriculum.brains_to_curricula.keys()}. "
            )

    trainer: Trainer = None  # type: ignore  # will be set to one of these, or raise
    if "trainer" not in trainer_parameters:
        raise TrainerConfigError(
            f'The "trainer" key must be set in your trainer config for brain {brain_name} (or the default brain).'
        )
    trainer_type = trainer_parameters["trainer"]

    if trainer_type == "offline_bc":
        raise UnityTrainerException(
            "The offline_bc trainer has been removed. To train with demonstrations, "
            "please use a PPO or SAC trainer with the GAIL Reward Signal and/or the "
            "Behavioral Cloning feature enabled."
        )
    elif trainer_type == "ppo":
        trainer = PPOTrainer(
            brain_name,
            min_lesson_length,
            trainer_parameters,
            train_model,
            load_model,
            seed,
            run_id,
        )
    elif trainer_type == "sac":
        trainer = SACTrainer(
            brain_name,
            min_lesson_length,
            trainer_parameters,
            train_model,
            load_model,
            seed,
            run_id,
        )
    else:
        raise TrainerConfigError(
            f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
        )

    if "self_play" in trainer_parameters:
        trainer = GhostTrainer(
            trainer,
            brain_name,
            ghost_controller,
            min_lesson_length,
            trainer_parameters,
            train_model,
            run_id,
        )
    return trainer
def initialize_trainer(
    trainer_config: Any,
    brain_parameters: BrainParameters,
    summaries_dir: str,
    run_id: str,
    model_path: str,
    keep_checkpoints: int,
    train_model: bool,
    load_model: bool,
    seed: int,
    meta_curriculum: MetaCurriculum = None,
    multi_gpu: bool = False,
) -> Trainer:
    """
    Initializes a trainer given a provided trainer configuration and brain
    parameters, as well as some general training session options.

    :param trainer_config: Original trainer configuration loaded from YAML
    :param brain_parameters: BrainParameters; its ``brain_name`` selects the
        config section
    :param summaries_dir: Not used by this function
    :param run_id: Run ID to associate with this training run
    :param model_path: Base path; the brain name is appended to it
    :param keep_checkpoints: How many model checkpoints to keep
    :param train_model: Whether to train the model (vs. run inference)
    :param load_model: Whether to load the model or randomly initialize
    :param seed: The random seed to use
    :param meta_curriculum: Optional meta_curriculum, used to determine the
        minimum lesson length passed to the trainer
    :param multi_gpu: Passed through to PPOTrainer only
    :return: A PPOTrainer or SACTrainer
    :raises TrainerConfigError: If the config has neither a "default" nor a
        brain section, if brain aliases form a cycle, if no "trainer" key is
        set, or if the trainer type is unknown
    :raises UnityTrainerException: If the trainer type is "offline_bc"
    """
    brain_name = brain_parameters.brain_name
    if "default" not in trainer_config and brain_name not in trainer_config:
        raise TrainerConfigError(
            f'Trainer config must have either a "default" section, or a section for the brain name ({brain_name}). '
            "See config/trainer_config.yaml for an example."
        )

    trainer_parameters = trainer_config.get("default", {}).copy()
    trainer_parameters["summary_path"] = str(run_id) + "_" + brain_name
    trainer_parameters["model_path"] = "{basedir}/{name}".format(
        basedir=model_path, name=brain_name
    )
    trainer_parameters["keep_checkpoints"] = keep_checkpoints

    if brain_name in trainer_config:
        # A brain entry may be an alias (a key naming another entry); follow
        # the chain until a dict is reached. Track visited keys so a cyclic
        # alias chain raises instead of looping forever.
        _brain_key: Any = brain_name
        visited = set()
        while not isinstance(trainer_config[_brain_key], dict):
            if _brain_key in visited:
                raise TrainerConfigError(
                    f"The trainer config for brain {brain_name} contains a cyclic alias at {_brain_key}."
                )
            visited.add(_brain_key)
            _brain_key = trainer_config[_brain_key]
        trainer_parameters.update(trainer_config[_brain_key])

    min_lesson_length = 1
    if meta_curriculum:
        if brain_name in meta_curriculum.brains_to_curriculums:
            min_lesson_length = meta_curriculum.brains_to_curriculums[
                brain_name
            ].min_lesson_length
        else:
            logger.warning(
                f"Metacurriculum enabled, but no curriculum for brain {brain_name}. "
                f"Brains with curricula: {meta_curriculum.brains_to_curriculums.keys()}. "
            )

    trainer: Trainer = None  # type: ignore  # will be set to one of these, or raise
    if "trainer" not in trainer_parameters:
        raise TrainerConfigError(
            f'The "trainer" key must be set in your trainer config for brain {brain_name} (or the default brain).'
        )
    trainer_type = trainer_parameters["trainer"]

    if trainer_type == "offline_bc":
        raise UnityTrainerException(
            "The offline_bc trainer has been removed. To train with demonstrations, "
            "please use a PPO or SAC trainer with the GAIL Reward Signal and/or the "
            "Behavioral Cloning feature enabled."
        )
    elif trainer_type == "ppo":
        trainer = PPOTrainer(
            brain_parameters,
            min_lesson_length,
            trainer_parameters,
            train_model,
            load_model,
            seed,
            run_id,
            multi_gpu,
        )
    elif trainer_type == "sac":
        trainer = SACTrainer(
            brain_parameters,
            min_lesson_length,
            trainer_parameters,
            train_model,
            load_model,
            seed,
            run_id,
        )
    else:
        raise TrainerConfigError(
            f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
        )
    return trainer
def __str__(self) -> str:
    """
    Placeholder string conversion: always raises, naming the concrete class.

    :raises TrainerConfigError: Always.
    """
    message = f"__str__ not implemented for type {self.__class__}."
    raise TrainerConfigError(message)
def _check_min_value(self, attribute, value):
    """
    Require ``self.min_value`` to be no greater than ``self.max_value``.

    The ``attribute`` and ``value`` arguments are not used.

    :raises TrainerConfigError: If the minimum exceeds the maximum.
    """
    lower, upper = self.min_value, self.max_value
    if lower > upper:
        raise TrainerConfigError(
            "Minimum value is greater than maximum value in uniform sampler."
        )