Example #1
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# =============================================================================
"""Top-level recidiviz package."""
import datetime

import cattr

from recidiviz.ingest.models.ingest_info import IngestInfo
from recidiviz.ingest.scrape import ingest_utils
from recidiviz.utils import environment

# TODO(#3820): Move these hooks out of this global file
# We want to add these globally because the serialization hooks are used in
# ingest and persistence.

cattr.register_unstructure_hook(datetime.datetime, datetime.datetime.isoformat)
cattr.register_structure_hook(
    datetime.datetime,
    lambda serialized, desired_type: datetime.datetime.fromisoformat(serialized),
)

cattr.register_unstructure_hook(IngestInfo,
                                ingest_utils.ingest_info_to_serializable)
cattr.register_structure_hook(
    IngestInfo,
    lambda serializable, desired_type: ingest_utils.ingest_info_from_serializable(
        serializable
    ),
)
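
As a standalone illustration of what these registrations enable (re-registering the same hooks so the snippet runs on its own; the sample timestamp is made up):

import datetime

import cattr

cattr.register_unstructure_hook(datetime.datetime, datetime.datetime.isoformat)
cattr.register_structure_hook(
    datetime.datetime,
    lambda serialized, desired_type: datetime.datetime.fromisoformat(serialized),
)

stamp = datetime.datetime(2020, 1, 2, 3, 4, 5)
serialized = cattr.unstructure(stamp)  # '2020-01-02T03:04:05'
restored = cattr.structure(serialized, datetime.datetime)
assert restored == stamp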
Example #2
    in the json for fields the caller set EXPLICIT_NULL on.
    """
    data = cattr.global_converter.unstructure_attrs_asdict(api_model)
    for key, value in data.copy().items():
        if value is None:
            del data[key]
        elif value == model.EXPLICIT_NULL:
            data[key] = None
    for reserved in keyword.kwlist:
        if f"{reserved}_" in data:
            data[reserved] = data.pop(f"{reserved}_")
    return data


if sys.version_info < (3, 7):
    from dateutil import parser

    def datetime_structure_hook(d: str, t: type) -> datetime.datetime:
        return parser.isoparse(d)


else:

    def datetime_structure_hook(d: str, t: type) -> datetime.datetime:
        return datetime.datetime.strptime(d, "%Y-%m-%dT%H:%M:%S.%f%z")


converter31.register_structure_hook(datetime.datetime, datetime_structure_hook)
converter40.register_structure_hook(datetime.datetime, datetime_structure_hook)
cattr.register_unstructure_hook(model.Model, unstructure_hook)  # type: ignore
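
For context on the version gate: before Python 3.7, strptime's %z directive matched neither the Z designator nor offsets written with a colon, hence the dateutil fallback. A quick standalone check on 3.7+:

import datetime

parsed = datetime.datetime.strptime(
    "2021-01-02T03:04:05.000+00:00", "%Y-%m-%dT%H:%M:%S.%f%z"
)
assert parsed.tzinfo is not None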
Example #3
class RunOptions(ExportableSettings):
    default_settings: Optional[TrainerSettings] = None
    behaviors: TrainerSettings.DefaultTrainerDict = attr.ib(
        factory=TrainerSettings.DefaultTrainerDict
    )
    env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings)
    engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
    environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None
    checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
    torch_settings: TorchSettings = attr.ib(factory=TorchSettings)

    # These are options that are relevant to the run itself, and not the engine or environment.
    # They will be left here.
    debug: bool = parser.get_default("debug")

    # Convert to settings while making sure all fields are valid
    cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
    cattr.register_structure_hook(EngineSettings, strict_to_cls)
    cattr.register_structure_hook(CheckpointSettings, strict_to_cls)
    cattr.register_structure_hook(
        Dict[str, EnvironmentParameterSettings], EnvironmentParameterSettings.structure
    )
    cattr.register_structure_hook(Lesson, strict_to_cls)
    cattr.register_structure_hook(
        ParameterRandomizationSettings, ParameterRandomizationSettings.structure
    )
    cattr.register_unstructure_hook(
        ParameterRandomizationSettings, ParameterRandomizationSettings.unstructure
    )
    cattr.register_structure_hook(TrainerSettings, TrainerSettings.structure)
    cattr.register_structure_hook(
        TrainerSettings.DefaultTrainerDict, TrainerSettings.dict_to_trainerdict
    )
    cattr.register_unstructure_hook(collections.defaultdict, defaultdict_to_dict)

    @staticmethod
    def from_argparse(args: argparse.Namespace) -> "RunOptions":
        """
        Takes an argparse.Namespace as specified in `parse_command_line`, loads input configuration files
        from file paths, and converts to a RunOptions instance.
        :param args: collection of command-line parameters passed to mlagents-learn
        :return: RunOptions representing the passed in arguments, with trainer config, curriculum and sampler
          configs loaded from files.
        """
        argparse_args = vars(args)
        config_path = StoreConfigFile.trainer_config_path

        # Load YAML
        configured_dict: Dict[str, Any] = {
            "checkpoint_settings": {},
            "env_settings": {},
            "engine_settings": {},
            "torch_settings": {},
        }
        _require_all_behaviors = True
        if config_path is not None:
            configured_dict.update(load_config(config_path))
        else:
            # If we're not loading from a file, we don't require all behavior names to be specified.
            _require_all_behaviors = False

        # Use the YAML file values for all values not specified in the CLI.
        for key in configured_dict.keys():
            # Detect bad config options
            if key not in attr.fields_dict(RunOptions):
                raise TrainerConfigError(
                    "The option {} was specified in your YAML file, but is invalid.".format(
                        key
                    )
                )

        # Override with CLI args
        # Keep deprecated --load working, TODO: remove
        argparse_args["resume"] = argparse_args["resume"] or argparse_args["load_model"]

        for key, val in argparse_args.items():
            if key in DetectDefault.non_default_args:
                if key in attr.fields_dict(CheckpointSettings):
                    configured_dict["checkpoint_settings"][key] = val
                elif key in attr.fields_dict(EnvironmentSettings):
                    configured_dict["env_settings"][key] = val
                elif key in attr.fields_dict(EngineSettings):
                    configured_dict["engine_settings"][key] = val
                elif key in attr.fields_dict(TorchSettings):
                    configured_dict["torch_settings"][key] = val
                else:  # Base options
                    configured_dict[key] = val

        final_runoptions = RunOptions.from_dict(configured_dict)
        final_runoptions.checkpoint_settings.prioritize_resume_init()
        # The isinstance check satisfies the type checker while still handling the plain dicts structure() may return
        if isinstance(final_runoptions.behaviors, TrainerSettings.DefaultTrainerDict):
            # configure whether or not we should require all behavior names to be found in the config YAML
            final_runoptions.behaviors.set_config_specified(_require_all_behaviors)

        _non_default_args = DetectDefault.non_default_args

        # Prioritize the deterministic mode from the cli for deterministic actions.
        if "deterministic" in _non_default_args:
            for behaviour in final_runoptions.behaviors.keys():
                final_runoptions.behaviors[
                    behaviour
                ].network_settings.deterministic = argparse_args["deterministic"]

        return final_runoptions

    @staticmethod
    def from_dict(options_dict: Dict[str, Any]) -> "RunOptions":
        # If a default settings was specified, set the TrainerSettings class override
        if (
            "default_settings" in options_dict.keys()
            and options_dict["default_settings"] is not None
        ):
            TrainerSettings.default_override = cattr.structure(
                options_dict["default_settings"], TrainerSettings
            )
        return cattr.structure(options_dict, RunOptions)
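
strict_to_cls is defined elsewhere in ml-agents and is not shown in this snippet; judging by how it is registered, it structures a dict into an attrs class while rejecting unknown keys. A hedged minimal sketch of that idea (strict_structure is a made-up name, not the real implementation):

import attr
import cattr

def strict_structure(d: dict, t: type):
    # Reject keys that do not correspond to attr-defined fields.
    unknown = set(d) - set(attr.fields_dict(t))
    if unknown:
        raise ValueError(f"Unknown options for {t.__name__}: {sorted(unknown)}")
    return cattr.structure_attrs_fromdict(d, t)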
Example #4
"""
dataclasses for the configuration system

Copyright (c) 2020 The Fuel Rat Mischief,
All rights reserved.

Licensed under the BSD 3-Clause License.

See LICENSE.md
"""
from ipaddress import ip_address, IPv4Address, IPv6Address

from .prometheus import IPAddress
from .root import ConfigRoot
import cattr

__all__ = ["ConfigRoot"]


def structure_ip_address(raw: str, *args) -> IPAddress:
    return ip_address(raw)


cattr.register_structure_hook(IPAddress, structure_ip_address)
cattr.register_structure_hook(IPv4Address, structure_ip_address)
cattr.register_structure_hook(IPv6Address, structure_ip_address)

cattr.register_unstructure_hook(IPv6Address, lambda data: f"{data}")
cattr.register_unstructure_hook(IPv4Address, lambda data: f"{data}")
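
A standalone round trip of the same pattern (hooks re-registered so it runs outside this module; the address is a documentation IP):

from ipaddress import IPv4Address, ip_address

import cattr

cattr.register_structure_hook(IPv4Address, lambda raw, _type: ip_address(raw))
cattr.register_unstructure_hook(IPv4Address, str)

addr = cattr.structure("192.0.2.1", IPv4Address)
assert cattr.unstructure(addr) == "192.0.2.1"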
Example #5
from dataclasses import dataclass, asdict
from typing import NewType, Optional

import cattr

JobId = NewType('JobId', str)


@dataclass
class Job:
    id: Optional[JobId]
    name: str
    description: str


# Serialization. cattr is basically Jackson for python.
# Instead of using the global cattr we'd be better off compiling
# an optimized converter that we attach to pyramid
# but I'm using the global here to save setup time.
cattr.register_structure_hook(Job, lambda d, typ: Job(**d))
cattr.register_unstructure_hook(Job, lambda e: asdict(e))
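
With these hooks in place, structuring from a plain dict and back looks like this (field values are made up for illustration):

job = cattr.structure(
    {"id": "j-1", "name": "Engineer", "description": "Builds things"}, Job
)
assert cattr.unstructure(job)["name"] == "Engineer"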
Example #6
        default=False,
        converter=to_bool,
        desc="If True, certain methods will print a progress bar to the screen",
    )
    cache_dir: Path = attrib(
        default=".quantized/cache",
        converter=Path,
        desc="The directory where cached objects are stored",
    )
    joblib_verbosity: int = attrib(
        default=0,
        converter=int,
        desc="Verbosity level for joblib's cache",
    )


cattr.register_structure_hook(Path, lambda s, t: Path(s))
cattr.register_unstructure_hook(Path, lambda p: str(p))
default_conf = Config()


def load(p: Path = default_conf_path) -> Config:
    """Load a configuration from a json file"""

    conf_file_d = json.loads(p.read_text())
    return attr.evolve(default_conf, **conf_file_d)


try:
    conf: Config = load()
except FileNotFoundError:
    conf: Config = Config()
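
A standalone round trip for the Path hooks above (re-registered so it runs on its own; the string form is shown for POSIX):

from pathlib import Path

import cattr

cattr.register_structure_hook(Path, lambda s, t: Path(s))
cattr.register_unstructure_hook(Path, str)

p = cattr.structure(".quantized/cache", Path)  # PosixPath('.quantized/cache')
s = cattr.unstructure(p)                       # '.quantized/cache'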
Example #7
from dataclasses import dataclass, asdict
from typing import NewType, Optional

import cattr

from .job import JobId
from .department import DepartmentId

EmployeeId = NewType('EmployeeId', str)


@dataclass
class Employee:
    id: Optional[EmployeeId]
    name: str
    photo_id: str
    job: JobId
    department: DepartmentId
    location: str


# Serialization. cattr is basically Jackson for python.
# Instead of using the global cattr we'd be better off compiling
# an optimized converter that we attach to pyramid
# but I'm using the global here to save setup time.
cattr.register_structure_hook(Employee, lambda d, typ: Employee(**d))
cattr.register_unstructure_hook(Employee, lambda e: asdict(e))
Example #8
from dataclasses import dataclass, asdict
from typing import NewType, Optional

import cattr

DepartmentId = NewType('DepartmentId', str)


@dataclass
class Department:
    id: Optional[DepartmentId]
    name: str


# Serialization. cattr is basically Jackson for python.
# Instead of using the global cattr we'd be better off compiling
# an optimized converter that we attach to pyramid
# but I'm using the global here to save setup time.
cattr.register_structure_hook(Department, lambda d, typ: Department(**d))
cattr.register_unstructure_hook(Department, lambda e: asdict(e))
Example #9
    return obj.isoformat()


def _structure_datetime(obj: str, cls: Type[datetime]) -> datetime:
    return cls.fromisoformat(obj.replace("Z", "+00:00"))


def _unstructure_datetime(obj: datetime) -> str:
    # If no timezone is present, assume UTC
    if not obj.tzinfo:
        obj = obj.replace(tzinfo=timezone.utc)
    return obj.isoformat(timespec="milliseconds")


cattr.register_structure_hook(date, _structure_date)
cattr.register_unstructure_hook(date, _unstructure_date)
cattr.register_structure_hook(datetime, _structure_datetime)
cattr.register_unstructure_hook(datetime, _unstructure_datetime)
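
# Round-trip behavior of the hooks above (illustrative values):
# cattr.structure("2021-01-02T03:04:05Z", datetime)
#   -> datetime(2021, 1, 2, 3, 4, 5, tzinfo=timezone.utc)
# cattr.unstructure(datetime(2021, 1, 2, 3, 4, 5))
#   -> '2021-01-02T03:04:05.000+00:00'  (naive input assumed UTC)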


@attr.s
class _HasFields:
    """A class that has fields in the API."""
    @classmethod
    def fields(cls) -> List[str]:
        """Build a list of field names needed to create the Python model.

        :return: A list of field names for the ``opt_fields`` input to the Asana API.
        """
        fields = attr.fields(cls)
        field_types = [(f.name, innermost_type(f.type)) for f in fields
Example #10
from dataclasses import dataclass, asdict
from typing import Generic, Optional, TypeVar

import cattr

from .cursor import Cursor


@dataclass
class Pagination:
    first: Optional[Cursor] = None
    last: Optional[Cursor] = None
    prev: Optional[Cursor] = None
    next: Optional[Cursor] = None


cattr.register_structure_hook(Pagination, lambda d, typ: Pagination(**d))
cattr.register_unstructure_hook(Pagination, lambda e: asdict(e))

# TODO handle undefined vs null in serialization here
T = TypeVar('T')


@dataclass
class Envelope(Generic[T]):
    data: T
    links: Optional[Pagination] = None


# Serialization. cattr is basically Jackson for python.
# Instead of using the global cattr we'd be better off compiling
# an optimized converter that we attach to pyramid
# but I'm using the global here to save setup time.
Example #11
def parse_version(version: str) -> Version:
    version = version.strip()
    if version == "*":
        return AnyVersion()
    if "," in version or version[0] in ["<", ">", "=", "~", "!"]:
        spec_set = SpecifierSet(version)
    else:
        spec_set = SpecifierSet("==" + version)
    if len(spec_set) == 1:
        spec = next(iter(spec_set))
        if spec.operator == "==" and not spec.version.endswith(".*"):
            # Equal, and not a '*' thing. Must be exact.
            return ExactVersion(spec.version)
    return DynamicVersion(spec_set)


def cattr_parse_version(value, _typ) -> Version:
    return parse_version(str(value))


cattr.register_structure_hook(Version, cattr_parse_version)
cattr.register_structure_hook(AnyVersion, cattr_parse_version)
cattr.register_structure_hook(DynamicVersion, cattr_parse_version)
cattr.register_structure_hook(ExactVersion, cattr_parse_version)

cattr.register_unstructure_hook(Version, str)
cattr.register_unstructure_hook(AnyVersion, str)
cattr.register_unstructure_hook(DynamicVersion, str)
cattr.register_unstructure_hook(ExactVersion, str)
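
# Hedged usage notes (the Version classes are defined earlier in this
# module and are not shown here):
# parse_version("*")           -> AnyVersion()
# parse_version("1.2.3")       -> ExactVersion("1.2.3")
# parse_version(">=1.0,<2.0")  -> DynamicVersion(SpecifierSet(">=1.0,<2.0"))
# parse_version("1.2.*")       -> DynamicVersion(SpecifierSet("==1.2.*"))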
Example #12
    def unstructure(self):
        cattr.register_unstructure_hook(Decimal, lambda d: str(d))
        return cattr.unstructure(self)
Example #13
      binfiles_file_group=FileGroup.binfile(base_dir),
    )

  @property
  def vpaths(self) -> List[Path]:
    return list(chain(self.binfiles_file_group.vpaths, self.dotfiles_file_group.vpaths))

  @property
  def link_data(self) -> List[LinkData]:
    return list(collapse((fg.link_data for fg in self.file_groups), base_type=LinkData))

  @property
  def file_groups(self) -> List[FileGroup]:
    return [self.binfiles_file_group, self.dotfiles_file_group]


## register necessary serde with cattr


def _unstructure_path(posx: Path) -> str:
  return str(posx)


def _structure_path(pstr: str, typ: Type[Path]) -> Path:
  return Path(pstr)


# mypy: no-disallow-untyped-calls
cattr.register_structure_hook(Path, _structure_path)
cattr.register_unstructure_hook(Path, _unstructure_path)
Example #14
            "code_red": "codeRed",
            "platform": "platform",
        }
        # translate internal datamodel names to the APIs datamodel names
        keep = {
            field_map[field]
            for field in changes if field != "mark_for_deletion"
        }

        if "mark_for_deletion" in changes:
            keep |= {"outcome"}
            self.attributes.outcome = "purge"
        # serialize API rescue object
        data = attr.asdict(self, recurse=True)
        # figure out which keys we need to keep (only send the ones modified internally)
        kept_attribs = {
            key: value
            for key, value in data["attributes"].items() if key in keep
        }
        # and patch the object.
        data["attributes"] = kept_attribs
        return data


@attr.dataclass
class RescueDocument(Document):
    data: Rescue


cattr.register_unstructure_hook(Rescue, lambda rescue: rescue.to_delta())
Example #15
        # Lets create a skeleton object, use the filename for the name since old LEAP
        # skeletons did not have names.
        skeleton = cls(name=filename)

        skel_mat = loadmat(filename)
        skel_mat["nodes"] = skel_mat["nodes"][0][0]  # convert to scalar
        skel_mat["edges"] = skel_mat["edges"] - 1  # convert to 0-based indexing

        node_names = skel_mat["nodeNames"]
        node_names = [str(n[0][0]) for n in node_names]
        skeleton.add_nodes(node_names)
        for k in range(len(skel_mat["edges"])):
            edge = skel_mat["edges"][k]
            skeleton.add_edge(source=node_names[edge[0]],
                              destination=node_names[edge[1]])

        return skeleton

    def __hash__(self):
        """
        Construct a hash from skeleton id.
        """
        return id(self)


cattr.register_unstructure_hook(Skeleton,
                                lambda skeleton: Skeleton.to_dict(skeleton))
cattr.register_structure_hook(Skeleton,
                              lambda dicts, cls: Skeleton.from_dict(dicts))
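
# Since the unstructure lambda above only forwards its argument, the direct
# reference would behave the same (assuming to_dict needs only the skeleton):
#   cattr.register_unstructure_hook(Skeleton, Skeleton.to_dict)
# The structure hook still needs its lambda, since from_dict does not accept
# the target class that cattr passes as the second argument.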
Example #16
from contextlib import contextmanager
from pathlib import Path
from datetime import datetime
from time import time

import attr
import cattr
import json
import typing

TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

# Used to avoid duplicates in the log
from pybloom_live import BloomFilter

cattr.register_unstructure_hook(datetime, lambda dt: dt.strftime(TIME_FORMAT))


def make_filter():
    return BloomFilter(capacity=settings["MAX_POSTS"], error_rate=0.001)


@attr.s
class Archive:
    archive_type = attr.ib()

    # We give the Archive class a file handle
    archive_file = attr.ib()

    _bloom_filter = attr.ib(factory=make_filter)
Example #17
class RunOptions(ExportableSettings):
    behaviors: DefaultDict[str, TrainerSettings] = attr.ib(
        factory=lambda: collections.defaultdict(TrainerSettings)
    )
    env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings)
    engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
    parameter_randomization: Optional[Dict] = None
    curriculum: Optional[Dict[str, CurriculumSettings]] = None
    checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)

    # These are options that are relevant to the run itself, and not the engine or environment.
    # They will be left here.
    debug: bool = parser.get_default("debug")
    # Strict conversion
    cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
    cattr.register_structure_hook(EngineSettings, strict_to_cls)
    cattr.register_structure_hook(CheckpointSettings, strict_to_cls)
    cattr.register_structure_hook(CurriculumSettings, strict_to_cls)
    cattr.register_structure_hook(TrainerSettings, TrainerSettings.structure)
    cattr.register_structure_hook(
        DefaultDict[str, TrainerSettings], TrainerSettings.dict_to_defaultdict
    )
    cattr.register_unstructure_hook(collections.defaultdict, defaultdict_to_dict)

    @staticmethod
    def from_argparse(args: argparse.Namespace) -> "RunOptions":
        """
        Takes an argparse.Namespace as specified in `parse_command_line`, loads input configuration files
        from file paths, and converts to a RunOptions instance.
        :param args: collection of command-line parameters passed to mlagents-learn
        :return: RunOptions representing the passed in arguments, with trainer config, curriculum and sampler
          configs loaded from files.
        """
        argparse_args = vars(args)
        config_path = StoreConfigFile.trainer_config_path

        # Load YAML
        configured_dict: Dict[str, Any] = {
            "checkpoint_settings": {},
            "env_settings": {},
            "engine_settings": {},
        }
        if config_path is not None:
            configured_dict.update(load_config(config_path))

        # Use the YAML file values for all values not specified in the CLI.
        for key in configured_dict.keys():
            # Detect bad config options
            if key not in attr.fields_dict(RunOptions):
                raise TrainerConfigError(
                    "The option {} was specified in your YAML file, but is invalid.".format(
                        key
                    )
                )
        # Override with CLI args
        # Keep deprecated --load working, TODO: remove
        argparse_args["resume"] = argparse_args["resume"] or argparse_args["load_model"]
        for key, val in argparse_args.items():
            if key in DetectDefault.non_default_args:
                if key in attr.fields_dict(CheckpointSettings):
                    configured_dict["checkpoint_settings"][key] = val
                elif key in attr.fields_dict(EnvironmentSettings):
                    configured_dict["env_settings"][key] = val
                elif key in attr.fields_dict(EngineSettings):
                    configured_dict["engine_settings"][key] = val
                else:  # Base options
                    configured_dict[key] = val
        return RunOptions.from_dict(configured_dict)

    @staticmethod
    def from_dict(options_dict: Dict[str, Any]) -> "RunOptions":
        return cattr.structure(options_dict, RunOptions)
Example #18
import datetime
from attr import attrs, attrib
import cattr

TIME_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'


@attrs
class Event(object):
    happened_at = attrib(type=datetime.datetime)


cattr.register_unstructure_hook(datetime.datetime,
                                lambda dt: dt.strftime(TIME_FORMAT))
cattr.register_structure_hook(
    datetime.datetime,
    lambda string, _: datetime.datetime.strptime(string, TIME_FORMAT))

event = Event(happened_at=datetime.datetime(2019, 6, 1))
print('event:', event)
json = cattr.unstructure(event)
print('json:', json)
event = cattr.structure(json, Event)
print('Event:', event)
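
Run as-is, this example should print the following (the Z-suffixed string comes from TIME_FORMAT):

# event: Event(happened_at=datetime.datetime(2019, 6, 1, 0, 0))
# json: {'happened_at': '2019-06-01T00:00:00.000000Z'}
# Event: Event(happened_at=datetime.datetime(2019, 6, 1, 0, 0))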