Example no. 1
0
    elif isinstance(value, str):
        return VersionedPlugin(cattr.structure(value, Version))
    elif not isinstance(value, dict):
        raise ValueError("Unsupported plugin type {value!r}")
    elif "version" in value:
        value = value.pop("version")
        if value is None:
            return VersionedPlugin(AnyVersion())
        elif isinstance(value, str):
            return VersionedPlugin(cattr.structure(value, Version))
    elif "git" in value:
        return GitPlugin(**value)
    elif "link" in value:
        return RawPlugin(**value)
    raise ValueError("Unsupported plugin type {value!r}")


# Route every cattr.structure(..., Plugin) call through parse_plugin.
cattr.register_structure_hook(Plugin, parse_plugin)


def unstringify_distribution_file(
    base_dir: pathlib.Path, file_content: str
) -> DistributionFile:
    """Parse TOML text into a DistributionFile and normalise its license key.

    :param base_dir: directory used to resolve the license key.
    :param file_content: raw TOML source of the distribution file.
    """
    parsed = toml.loads(file_content)
    dist_file: DistributionFile = cattr.structure(parsed, DistributionFile)
    dist_file.normalise_license_key(base_dir)
    return dist_file


def stringify_distribution_file(dist_file: DistributionFile) -> str:
    """Serialise a DistributionFile back to TOML text (inverse of parsing)."""
    raw = cattr.unstructure(dist_file)
    return toml.dumps(raw)
Example no. 2
0
class TrainerSettings(ExportableSettings):
    """Settings for a single trainer/behavior.

    Instances are normally built from YAML via the `structure` static method
    below, which is meant to be registered as a cattr structure hook.
    """

    # Selected algorithm; also determines which hyperparameter class applies.
    trainer_type: TrainerType = TrainerType.PPO
    hyperparameters: HyperparamSettings = attr.ib()

    @hyperparameters.default
    def _set_default_hyperparameters(self):
        # to_settings() returns the hyperparameter settings *class* for the
        # chosen trainer type; the trailing () instantiates it with defaults.
        return self.trainer_type.to_settings()()

    network_settings: NetworkSettings = attr.ib(factory=NetworkSettings)
    # Default to a single extrinsic reward signal.
    reward_signals: Dict[RewardSignalType, RewardSignalSettings] = attr.ib(
        factory=lambda: {RewardSignalType.EXTRINSIC: RewardSignalSettings()})
    init_path: Optional[str] = None
    keep_checkpoints: int = 5
    checkpoint_interval: int = 500000
    max_steps: int = 500000
    time_horizon: int = 64
    summary_freq: int = 50000
    threaded: bool = True
    self_play: Optional[SelfPlaySettings] = None
    behavioral_cloning: Optional[BehavioralCloningSettings] = None

    # NOTE: this registration executes at class-body time, i.e. when the
    # module is imported.
    cattr.register_structure_hook(Dict[RewardSignalType, RewardSignalSettings],
                                  RewardSignalSettings.structure)

    @network_settings.validator
    def _check_batch_size_seq_length(self, attribute, value):
        """Reject configs where the memory sequence length exceeds batch size."""
        if self.network_settings.memory is not None:
            if (self.network_settings.memory.sequence_length >
                    self.hyperparameters.batch_size):
                raise TrainerConfigError(
                    "When using memory, sequence length must be less than or equal to batch size. "
                )

    @staticmethod
    def dict_to_defaultdict(d: Dict, t: type) -> DefaultDict:
        """Structure d into a defaultdict whose missing keys yield TrainerSettings()."""
        return collections.defaultdict(
            TrainerSettings, cattr.structure(d, Dict[str, TrainerSettings]))

    @staticmethod
    def structure(d: Mapping, t: type) -> Any:
        """
        Helper method to structure a TrainerSettings class. Meant to be registered with
        cattr.register_structure_hook() and called with cattr.structure().
        """
        if not isinstance(d, Mapping):
            raise TrainerConfigError(
                f"Unsupported config {d} for {t.__name__}.")
        d_copy: Dict[str, Any] = {}
        d_copy.update(d)

        for key, val in d_copy.items():
            if attr.has(type(val)):
                # Don't convert already-converted attrs classes.
                continue
            if key == "hyperparameters":
                # Hyperparameters are only meaningful relative to a trainer
                # type, so require trainer_type in the same dict.
                if "trainer_type" not in d_copy:
                    raise TrainerConfigError(
                        "Hyperparameters were specified but no trainer_type was given."
                    )
                else:
                    d_copy[key] = strict_to_cls(
                        d_copy[key],
                        TrainerType(d_copy["trainer_type"]).to_settings())
            elif key == "max_steps":
                d_copy[key] = int(float(val))
                # In some legacy configs, max steps was specified as a float
            else:
                d_copy[key] = check_and_structure(key, val, t)
        return t(**d_copy)
Example no. 3
0
            tags=[item.name],
            title=sub_item.name,
            description=sub_item.name,
            http_url=self.__internal_variables(
                sub_item.request.url.get("raw")),
            http_method=sub_item.request.method,
            http_request_body=self.__internal_variables(
                sub_item.request.body.get("raw")),
            http_headers={
                k: DynamicStringData(value=self.__internal_variables(v))
                for k, v in kv_list_to_dict(sub_item.request.header).items()
            },
            sequence_number=self.app_state_interactor.update_sequence_number(),
        )
        return api_call

    def __internal_variables(self, input_str):
        """Rewrite Postman-style variable references to ${name} form.

        Empty or None input produces the empty string.
        """
        if not input_str:
            return ""
        # var_selector is expected to capture the variable name in group 2 —
        # TODO(review): confirm against where var_selector is defined.
        return re.sub(self.var_selector, r"${\2}", input_str, count=0)

    def __validate_file(self, file_path):
        """Raise FileNotFoundError unless file_path is an existing regular file."""
        candidate = Path(file_path)
        if candidate.exists() and candidate.is_file():
            return
        raise FileNotFoundError(
            f"Path {file_path} should be a valid file path")


# Structure raw Postman JSON dicts directly into the attrs model classes.
for cls in [PostmanDataModel, PostmanItem, PostmanSubItem, PostmanRequest]:
    cattr.register_structure_hook(cls, structure_attrs_from_dict)

# Module-level singleton used by callers of this importer.
importer = PostmanCollectionImporter()
Example no. 4
0
def unstructure_hook(api_model):
    """cattr unstructure hook

    Map reserved_ words in models to correct json field names.
    Also handle stripping None fields from dict while setting
    EXPLICIT_NULL fields to None so that we only send null
    in the json for fields the caller set EXPLICIT_NULL on.
    """
    raw = cattr.global_converter.unstructure_attrs_asdict(api_model)
    # Drop unset (None) fields; translate the EXPLICIT_NULL sentinel to a
    # real None so it serialises as JSON null.
    data = {}
    for key, value in raw.items():
        if value is None:
            continue
        data[key] = None if value == model.EXPLICIT_NULL else value
    # Fields named after Python keywords carry a trailing underscore in the
    # model; strip it for the wire format.
    for reserved in keyword.kwlist:
        suffixed = f"{reserved}_"
        if suffixed in data:
            data[reserved] = data.pop(suffixed)
    return data


# Bind this module's globals into structure_hook — presumably so it can
# resolve model class names by string; confirm against structure_hook's
# definition.
structure_hook_func = functools.partial(structure_hook,
                                        globals())  # type: ignore
cattr.register_structure_hook(model.Model, structure_hook_func)  # type: ignore
# Timestamps arrive as ISO-8601 strings with fractional seconds and a
# numeric UTC offset (%z).
cattr.register_structure_hook(
    datetime.datetime,
    lambda d, _: datetime.datetime.strptime(  # type: ignore
        d, "%Y-%m-%dT%H:%M:%S.%f%z"),
)
cattr.register_unstructure_hook(model.Model, unstructure_hook)  # type: ignore
Example no. 5
0
class RunOptions(ExportableSettings):
    """All settings for a single mlagents-learn run.

    Built from a YAML config file, CLI arguments, or both (CLI values win);
    see from_argparse / from_dict.
    """

    default_settings: Optional[TrainerSettings] = None
    behaviors: DefaultDict[str, TrainerSettings] = attr.ib(
        factory=TrainerSettings.DefaultTrainerDict)
    env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings)
    engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
    environment_parameters: Optional[Dict[str,
                                          EnvironmentParameterSettings]] = None
    checkpoint_settings: CheckpointSettings = attr.ib(
        factory=CheckpointSettings)

    # These are options that are relevant to the run itself, and not the engine or environment.
    # They will be left here.
    debug: bool = parser.get_default("debug")
    # Strict conversion
    # NOTE: the hook registrations below execute at class-body time, i.e.
    # when this module is imported.
    cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
    cattr.register_structure_hook(EngineSettings, strict_to_cls)
    cattr.register_structure_hook(CheckpointSettings, strict_to_cls)
    cattr.register_structure_hook(Dict[str, EnvironmentParameterSettings],
                                  EnvironmentParameterSettings.structure)
    cattr.register_structure_hook(Lesson, strict_to_cls)
    cattr.register_structure_hook(ParameterRandomizationSettings,
                                  ParameterRandomizationSettings.structure)
    cattr.register_unstructure_hook(ParameterRandomizationSettings,
                                    ParameterRandomizationSettings.unstructure)
    cattr.register_structure_hook(TrainerSettings, TrainerSettings.structure)
    cattr.register_structure_hook(DefaultDict[str, TrainerSettings],
                                  TrainerSettings.dict_to_defaultdict)
    cattr.register_unstructure_hook(collections.defaultdict,
                                    defaultdict_to_dict)

    @staticmethod
    def from_argparse(args: argparse.Namespace) -> "RunOptions":
        """
        Takes an argparse.Namespace as specified in `parse_command_line`, loads input configuration files
        from file paths, and converts to a RunOptions instance.
        :param args: collection of command-line parameters passed to mlagents-learn
        :return: RunOptions representing the passed in arguments, with trainer config, curriculum and sampler
          configs loaded from files.
        """
        argparse_args = vars(args)
        config_path = StoreConfigFile.trainer_config_path

        # Load YAML
        configured_dict: Dict[str, Any] = {
            "checkpoint_settings": {},
            "env_settings": {},
            "engine_settings": {},
        }
        if config_path is not None:
            configured_dict.update(load_config(config_path))

        # Use the YAML file values for all values not specified in the CLI.
        for key in configured_dict.keys():
            # Detect bad config options
            if key not in attr.fields_dict(RunOptions):
                raise TrainerConfigError(
                    "The option {} was specified in your YAML file, but is invalid."
                    .format(key))
        # Override with CLI args
        # Keep deprecated --load working, TODO: remove
        argparse_args[
            "resume"] = argparse_args["resume"] or argparse_args["load_model"]
        for key, val in argparse_args.items():
            # Only CLI values the user explicitly set override the YAML file.
            if key in DetectDefault.non_default_args:
                if key in attr.fields_dict(CheckpointSettings):
                    configured_dict["checkpoint_settings"][key] = val
                elif key in attr.fields_dict(EnvironmentSettings):
                    configured_dict["env_settings"][key] = val
                elif key in attr.fields_dict(EngineSettings):
                    configured_dict["engine_settings"][key] = val
                else:  # Base options
                    configured_dict[key] = val

        final_runoptions = RunOptions.from_dict(configured_dict)
        return final_runoptions

    @staticmethod
    def from_dict(options_dict: Dict[str, Any]) -> "RunOptions":
        """Structure a plain dict into RunOptions, applying default_settings first."""
        # If a default settings was specified, set the TrainerSettings class override
        if ("default_settings" in options_dict.keys()
                and options_dict["default_settings"] is not None):
            TrainerSettings.default_override = cattr.structure(
                options_dict["default_settings"], TrainerSettings)
        return cattr.structure(options_dict, RunOptions)
Example no. 6
0
            dotfiles_file_group=FileGroup.dotfile(base_dir),
            binfiles_file_group=FileGroup.binfile(base_dir),
        )

    @property
    def vpaths(self) -> List[Path]:
        """Virtual paths of both file groups: binfiles first, then dotfiles."""
        combined: List[Path] = []
        combined.extend(self.binfiles_file_group.vpaths)
        combined.extend(self.dotfiles_file_group.vpaths)
        return combined

    @property
    def link_data(self) -> List[LinkData]:
        """Link data of both file groups: binfiles first, then dotfiles."""
        combined: List[LinkData] = []
        combined.extend(self.binfiles_file_group.link_data)
        combined.extend(self.dotfiles_file_group.link_data)
        return combined


# Register the necessary Path <-> str serde hooks with cattr.


def _unstructure_path(posx: Path) -> str:
    return str(posx)


def _structure_path(pstr: str, typ: Type[Path]) -> Path:
    return Path(pstr)


# Wire the Path serde helpers into cattr's global converter.
cattr.register_structure_hook(Path, _structure_path)
cattr.register_unstructure_hook(Path, _unstructure_path)
Example no. 7
0
class TrainerSettings(ExportableSettings):
    """Settings for a single trainer/behavior.

    Instances are normally built from YAML via the `structure` static method
    below, which is meant to be registered as a cattr structure hook.
    """

    # When a run-level default_settings section is supplied, it is stored
    # here and used as the base for every behavior's settings.
    default_override: ClassVar[Optional["TrainerSettings"]] = None
    trainer_type: TrainerType = TrainerType.PPO
    hyperparameters: HyperparamSettings = attr.ib()

    @hyperparameters.default
    def _set_default_hyperparameters(self):
        # to_settings() returns the hyperparameter settings *class* for the
        # chosen trainer type; the trailing () instantiates it with defaults.
        return self.trainer_type.to_settings()()

    network_settings: NetworkSettings = attr.ib(factory=NetworkSettings)
    # Default to a single extrinsic reward signal.
    reward_signals: Dict[RewardSignalType, RewardSignalSettings] = attr.ib(
        factory=lambda: {RewardSignalType.EXTRINSIC: RewardSignalSettings()})
    init_path: Optional[str] = None
    keep_checkpoints: int = 5
    checkpoint_interval: int = 500000
    max_steps: int = 500000
    time_horizon: int = 64
    summary_freq: int = 50000
    threaded: bool = True
    self_play: Optional[SelfPlaySettings] = None
    behavioral_cloning: Optional[BehavioralCloningSettings] = None

    # NOTE: this registration executes at class-body time (module import).
    cattr.register_structure_hook(Dict[RewardSignalType, RewardSignalSettings],
                                  RewardSignalSettings.structure)

    @network_settings.validator
    def _check_batch_size_seq_length(self, attribute, value):
        """Reject configs where the memory sequence length exceeds batch size."""
        if self.network_settings.memory is not None:
            if (self.network_settings.memory.sequence_length >
                    self.hyperparameters.batch_size):
                raise TrainerConfigError(
                    "When using memory, sequence length must be less than or equal to batch size. "
                )

    @staticmethod
    def dict_to_defaultdict(d: Dict, t: type) -> DefaultDict:
        """Structure d into a DefaultTrainerDict keyed by behavior name."""
        return TrainerSettings.DefaultTrainerDict(
            cattr.structure(d, Dict[str, TrainerSettings]))

    @staticmethod
    def structure(d: Mapping, t: type) -> Any:
        """
        Helper method to structure a TrainerSettings class. Meant to be registered with
        cattr.register_structure_hook() and called with cattr.structure().
        """

        if not isinstance(d, Mapping):
            raise TrainerConfigError(
                f"Unsupported config {d} for {t.__name__}.")

        d_copy: Dict[str, Any] = {}

        # Check if a default_settings was specified. If so, used those as the default
        # rather than an empty dict.
        if TrainerSettings.default_override is not None:
            d_copy.update(cattr.unstructure(TrainerSettings.default_override))

        # Overlay the per-behavior config on top of the defaults, recursively.
        deep_update_dict(d_copy, d)

        if "framework" in d_copy:
            logger.warning("Framework option was deprecated but was specified")
            d_copy.pop("framework", None)

        for key, val in d_copy.items():
            if attr.has(type(val)):
                # Don't convert already-converted attrs classes.
                continue
            if key == "hyperparameters":
                # Hyperparameters are only meaningful relative to a trainer
                # type, so require trainer_type in the same dict.
                if "trainer_type" not in d_copy:
                    raise TrainerConfigError(
                        "Hyperparameters were specified but no trainer_type was given."
                    )
                else:
                    d_copy[key] = strict_to_cls(
                        d_copy[key],
                        TrainerType(d_copy["trainer_type"]).to_settings())
            elif key == "max_steps":
                d_copy[key] = int(float(val))
                # In some legacy configs, max steps was specified as a float
            else:
                d_copy[key] = check_and_structure(key, val, t)
        return t(**d_copy)

    class DefaultTrainerDict(collections.defaultdict):
        """defaultdict whose missing entries come from default_override,
        falling back to a plain TrainerSettings()."""

        def __init__(self, *args):
            # Depending on how this is called, args may have the defaultdict
            # callable at the start of the list or not. In particular, unpickling
            # will pass [TrainerSettings].
            if args and args[0] == TrainerSettings:
                super().__init__(*args)
            else:
                super().__init__(TrainerSettings, *args)

        def __missing__(self, key: Any) -> "TrainerSettings":
            if TrainerSettings.default_override is not None:
                # Deep-copy so each behavior can mutate its entry independently.
                return copy.deepcopy(TrainerSettings.default_override)
            else:
                return TrainerSettings()
Example no. 8
0
    # NOTE(review): `desc=` is not a standard attrs argument — presumably
    # `attrib` here is a project wrapper; confirm where it is defined.
    # Show progress bars from long-running methods when enabled.
    enable_progressbar: bool = attrib(
        default=False,
        converter=to_bool,
        desc="If True, certain methods will print a progress bar to the screen",
    )
    # Converter ensures the value is always a pathlib.Path.
    cache_dir: Path = attrib(
        default=".quantized/cache",
        converter=Path,
        desc="The directory where cached objects are stored",
    )
    joblib_verbosity: int = attrib(default=0,
                                   converter=int,
                                   desc="Verbosity level for joblib's cache")


# (De)serialise pathlib.Path values as plain strings in cattr conversions.
cattr.register_structure_hook(Path, lambda s, t: Path(s))
cattr.register_unstructure_hook(Path, lambda p: str(p))
# Baseline configuration built entirely from field defaults.
default_conf = Config()


def load(p: Path = default_conf_path) -> Config:
    """Load a configuration from a json file, overlaying it on the defaults.

    :param p: path to the JSON config file.
    :raises FileNotFoundError: if p does not exist (propagated from read_text).
    """
    overrides = json.loads(p.read_text())
    return attr.evolve(default_conf, **overrides)


# Eagerly load the user's config at import time; fall back to pure defaults
# when no config file exists yet.
try:
    conf: Config = load()
except FileNotFoundError:
    conf: Config = Config()
Example no. 9
0
import datetime
from attr import attrs, attrib
import cattr

# ISO-8601-like timestamp format; the trailing 'Z' is a literal character,
# not a timezone directive (%z is not used).
TIME_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'


@attrs
class Event(object):
    """A single event, serialised via the datetime hooks registered below."""

    # When the event occurred; (de)serialised with TIME_FORMAT.
    happened_at = attrib(type=datetime.datetime)


# Symmetric hooks: datetimes serialise to TIME_FORMAT strings and parse back.
cattr.register_unstructure_hook(datetime.datetime,
                                lambda dt: dt.strftime(TIME_FORMAT))
cattr.register_structure_hook(
    datetime.datetime,
    lambda string, _: datetime.datetime.strptime(string, TIME_FORMAT))

# Round-trip demo: unstructure an Event to plain data, then structure it back.
event = Event(happened_at=datetime.datetime(2019, 6, 1))
print('event:', event)
# Renamed from `json`: the previous name shadowed the stdlib json module.
serialized = cattr.unstructure(event)
print('json:', serialized)
event = cattr.structure(serialized, Event)
print('Event:', event)