예제 #1
0
 def from_dict(cls, kwargs) -> "TrainingConfig":
     """Build a TrainingConfig from a plain dict, parsing it strictly with dacite.

     Top-level ``input_variables`` / ``output_variables`` keys are deprecated:
     each one present is moved under ``hyperparameters`` (overriding a value
     already there) and a DeprecationWarning is emitted.  The hyperparameters
     are then parsed into the class returned by
     ``get_hyperparameter_class(kwargs["model_type"])``.

     Neither ``kwargs`` nor its nested ``hyperparameters`` mapping is modified.

     Raises:
         KeyError: if ``model_type`` is missing.
     """
     kwargs = {**kwargs}  # shallow copy: deprecated top-level keys are popped below
     # The nested mapping must be copied too; otherwise moving the deprecated
     # keys into it would write into the caller's own ``hyperparameters`` dict.
     # A missing ``hyperparameters`` key is treated as empty.
     hyperparameters = dict(kwargs.get("hyperparameters", {}))
     for name in ("input_variables", "output_variables"):
         if name in kwargs:
             warnings.warn(
                 f"{name} is no longer a top-level TrainingConfig "
                 "parameter, pass it under hyperparameters instead",
                 DeprecationWarning,
             )
             hyperparameters[name] = kwargs.pop(name)
     hyperparameter_class = get_hyperparameter_class(kwargs["model_type"])
     kwargs["hyperparameters"] = dacite.from_dict(
         data_class=hyperparameter_class,
         data=hyperparameters,
         config=dacite.Config(strict=True),
     )
     return dacite.from_dict(data_class=cls,
                             data=kwargs,
                             config=dacite.Config(strict=True))
예제 #2
0
 def from_json(cls, serialized, many=False):
     """Deserialize JSON text into ``cls`` instance(s).

     ``cls.schema().loads`` decodes ``serialized``; each decoded mapping is then
     turned into a ``cls`` by dacite with type checking disabled.

     Returns a single instance, or a list of instances when ``many`` is true.
     """
     schema = cls.schema()
     if not many:
         return dacite.from_dict(
             data_class=cls,
             data=schema.loads(serialized),
             config=dacite.Config(check_types=False),
         )
     decoded_items = schema.loads(serialized, many=True)
     return [
         dacite.from_dict(
             data_class=cls,
             data=decoded,
             config=dacite.Config(check_types=False),
         )
         for decoded in decoded_items
     ]
예제 #3
0
def test_config():
    """Return a ``conv_params`` built from a fixed set of hyper-parameters.

    NOTE(review): ``ckp_folder`` is an absolute path on one developer's machine
    and will not exist elsewhere; consider making it configurable.
    NOTE(review): if pytest collects this function it warns about the non-None
    return value. Confirm whether it is meant as a test or as a factory.
    """
    # Encoder layers, in order.  conv_layer's positional arguments are passed
    # through unchanged; their meaning is defined by conv_layer.
    encoder_layers = tuple(
        conv_layer("relu", width, 4, 2, 0) for width in (32, 64, 128, 256)
    )
    # Decoder layers, in order; the last one is created with "sigmoid"
    # instead of "relu".
    decoder_layers = (
        conv_layer("relu", 128, 5, 2, 0),
        conv_layer("relu", 64, 5, 2, 0),
        conv_layer("relu", 32, 6, 2, 0),
        conv_layer("sigmoid", 3, 6, 2, 0),
    )

    raw_data = {
        "image_channels": 3,
        "nb_layers": 4,
        "conv_layers": encoder_layers,
        "fc1_in": 256,
        "fc1_out": 128,
        "latent_dim": 32,
        "fc2_out": 1024,
        "deconv_layers": decoder_layers,
        "tr_epochs": 50000,
        "batch_size": 12,
        "resize": (64, 64),
        "ckp_folder": "C:/Users/vince/Documents/AI/world_models/ckp",
        "ckp_path": "last_ckp.pth"
    }
    # No custom type converters are registered.
    return dacite.from_dict(
        data_class=conv_params,
        data=raw_data,
        config=dacite.Config(type_hooks={}),
    )
예제 #4
0
    def from_dict(
        component_descriptor_dict: dict,
        validation_mode: ValidationMode = ValidationMode.NONE,
    ):
        """Parse a raw component-descriptor dict into a ``ComponentDescriptor``.

        Raw values for the types listed in ``cast`` are converted by dacite.
        Fields typed ``Union[AccessType, str]``, ``Union[ResourceType, str]``
        and ``Union[OciComponentNameMapping, str]`` go through
        ``enum_or_string`` with the matching ``enum_type``.

        If ``validation_mode`` is anything other than ``ValidationMode.NONE``,
        the *raw* dict is additionally checked with
        ``ComponentDescriptor.validate`` after parsing.

        Returns the parsed ``ComponentDescriptor``.
        """
        component_descriptor = dacite.from_dict(
            data_class=ComponentDescriptor,
            data=component_descriptor_dict,
            config=dacite.Config(
                cast=[
                    OciComponentNameMapping,
                    AccessType,
                    Provider,
                    ResourceType,
                    SchemaVersion,
                    SourceType,
                    ResourceRelation,
                ],
                type_hooks={
                    typing.Union[AccessType, str]:
                    functools.partial(enum_or_string, enum_type=AccessType),
                    typing.Union[ResourceType, str]:
                    functools.partial(enum_or_string, enum_type=ResourceType),
                    typing.Union[OciComponentNameMapping, str]:
                    functools.partial(enum_or_string,
                                      enum_type=OciComponentNameMapping),
                },
            ))
        if validation_mode is not ValidationMode.NONE:
            ComponentDescriptor.validate(
                component_descriptor_dict=component_descriptor_dict,
                validation_mode=validation_mode,
            )

        return component_descriptor
예제 #5
0
    def load(
        self, file_path: pathlib.Path, config_data: PluginConfig, root: Config
    ) -> Sequence[ComponentBase]:
        """Create the fake components enabled in this plugin's configuration.

        ``config_data.config`` (when not None) is parsed strictly into a
        ``FakePluginConfig``; otherwise the ``FakePluginConfig()`` defaults
        apply.  ``file_path`` and ``root`` are accepted but not used.

        Returns a list with ``<name>_component1`` when ``enable_c1`` is set and
        ``<name>_component2`` when ``enable_c2`` is set, in that order.
        """
        config = FakePluginConfig()
        if config_data.config is not None:
            config = dacite.from_dict(
                FakePluginConfig, config_data.config, dacite.Config(strict=True)
            )

        # One list object, handed to every component created below.
        shared_state = [0.0]
        candidates = (
            (config.enable_c1, "_component1",
             {"op1": (2, Operation.MUL), "op2": (10, Operation.ADD)}),
            (config.enable_c2, "_component2",
             {"op1": (3, Operation.MUL), "op3": (-1, Operation.MUL)}),
        )
        return [
            FakeComponent(f"{self._name}{suffix}", operations, None, None, shared_state)
            for enabled, suffix, operations in candidates
            if enabled
        ]
    def from_dict(
        component_descriptor_dict: dict,
        validation_mode: ValidationMode = ValidationMode.NONE,
    ):
        """Parse a raw component-descriptor dict into a ``ComponentDescriptor``.

        Raw values for the types listed in ``cast`` are converted by dacite.
        If ``validation_mode`` is anything other than ``ValidationMode.NONE``,
        the *raw* dict is additionally checked with
        ``ComponentDescriptor.validate`` after parsing.

        Returns the parsed ``ComponentDescriptor``.
        """
        component_descriptor = dacite.from_dict(
            data_class=ComponentDescriptor,
            data=component_descriptor_dict,
            config=dacite.Config(
                cast=[
                    AccessType,
                    Provider,
                    ResourceType,
                    SchemaVersion,
                    SourceType,
                    ResourceRelation,
                ]
            )
        )
        if validation_mode is not ValidationMode.NONE:
            ComponentDescriptor.validate(
                component_descriptor_dict=component_descriptor_dict,
                validation_mode=validation_mode,
            )

        return component_descriptor
예제 #7
0
def insights_iter(file_path: pathlib.Path) -> Iterable[Prediction]:
    """Lazily yield one ``Prediction`` per record produced by ``jsonl_iter(file_path)``.

    Values for ``PredictionType`` fields are cast by dacite.
    """
    yield from (
        dacite.from_dict(
            data_class=Prediction,
            data=record,
            config=dacite.Config(cast=[PredictionType]),
        )
        for record in jsonl_iter(file_path)
    )
예제 #8
0
 def from_dict(cls, value: dict):
     """Create a gateway config (an instance of ``cls``) from a dictionary.

     Fields typed ``ChaincodeLanguage`` are produced by calling
     ``ChaincodeLanguage`` on the raw value (registered as a dacite type hook).
     """
     language_hooks = {ChaincodeLanguage: ChaincodeLanguage}
     parse_config = dacite.Config(type_hooks=language_hooks)
     return dacite.from_dict(cls, value, config=parse_config)
예제 #9
0
    def fetch_hashes(
        self,
        *,
        page_size: int = 800,
        start_timestamp: int = DEFAULT_START_TIME,
        next_page: str = "",
    ) -> FetchHashesResponse:
        """
        Fetch a series of update records from the hash API.

        Records represent the current snapshot of all data, so if you see
        the same SignalType+Hash in a later iteration, it should completely
        replace the previously observed record.

        A non-empty ``next_page`` is sent as ``nextPageToken``.  The JSON reply
        is parsed into a ``FetchHashesResponse`` (enum and set values cast).
        """
        paging: t.Dict[str, t.Any] = {"nextPageToken": next_page} if next_page else {}
        params: t.Dict[str, t.Any] = {
            "startTimestamp": start_timestamp,
            "pageSize": page_size,
            **paging,
        }
        logging.debug("StopNCII FetchHashes called: %s", params)
        reply = self._get(StopNCIIEndpoint.FetchHashes, **params)
        logging.debug("StopNCII FetchHashes returns: %s", reply)
        response_config = dacite.Config(cast=[enum.Enum, set])
        return dacite.from_dict(
            data_class=FetchHashesResponse,
            data=reply,
            config=response_config,
        )
예제 #10
0
def get_booking(response) -> Booking:
    '''Getting and testing response from the /booking endpoint

    ``response`` is the decoded JSON body.  It is parsed strictly into a
    ``Booking`` and its barcode fields are then sanity-checked.

    Returns the parsed ``Booking``.

    Raises:
        FailedTest: if the body is not a JSON object, does not match
            ``Booking``, or has an invalid or inconsistent barcode setup.
    '''
    # isinstance also accepts dict subclasses, which are equally valid JSON objects.
    if not isinstance(response, dict):
        raise FailedTest('The response should be a JSON Object')
    try:
        booking = dacite.from_dict(data_class=Booking,
                                   data=response,
                                   config=dacite.Config(strict=True))
    except (
            dacite.exceptions.WrongTypeError,
            dacite.exceptions.MissingValueError,
            dacite.exceptions.UnexpectedDataError,
    ) as e:
        # Chain the dacite error so the offending field stays visible in the traceback.
        raise FailedTest(
            'Incorrect JSON format in response from the /booking endpoint') from e

    if booking.barcode_format not in ('QRCODE', 'CODE128', 'CODE39', 'ITF',
                                      'DATAMATRIX', 'EAN13'):
        raise FailedTest(
            f'Incorrect barcode format ({booking.barcode_format})')
    if booking.barcode_position not in ('order', 'ticket'):
        raise FailedTest(
            f'Incorrect value in the barcode_position field ({booking.barcode_position})'
        )
    # The barcode must be present wherever barcode_position says it lives.
    if booking.barcode_position == 'order' and not booking.barcode:
        raise FailedTest('Barcode for the whole order is empty')
    if booking.barcode_position == 'ticket' and not booking.tickets:
        raise FailedTest('Tickets Array is empty')
    return booking
예제 #11
0
def parse_availability_timeslots(raw_response: Request, response) -> List[Timeslot]:
    '''Getting and testing response from the /timeslots endpoint

    ``response`` is the decoded JSON body.  Every element is parsed strictly
    into a ``Timeslot`` (``date`` fields via ``date.fromisoformat``).
    ``raw_response`` is attached to any ``FailedTest`` raised.

    Raises:
        FailedTest: if the body is not a JSON array or an element does not
            match ``Timeslot``.
    '''
    # isinstance also accepts list subclasses, which are equally valid JSON arrays.
    if not isinstance(response, list):
        raise FailedTest(
            message='The response should be a JSON Array',
            response=raw_response,
        )
    # The dacite config does not depend on the element, so build it once.
    timeslot_config = dacite.Config(
        type_hooks={date: date.fromisoformat},
        strict=True,
    )
    # NOTE(review): a malformed date string makes date.fromisoformat raise
    # ValueError, which is not converted into FailedTest here. Confirm intent.
    try:
        timeslots = [
            dacite.from_dict(
                data_class=Timeslot,
                data=timeslot,
                config=timeslot_config,
            )
            for timeslot in response
        ]
    except (
        dacite.exceptions.WrongTypeError,
        dacite.exceptions.MissingValueError,
        dacite.exceptions.UnexpectedDataError,
    ) as e:
        # Chain the dacite error so the offending field stays in the traceback.
        raise FailedTest(
            message=f'Incorrect JSON format in response from the /timeslots endpoint: {format_error_message(e)}',
            response=raw_response,
        ) from e
    return timeslots
예제 #12
0
def release_manifest_set(
    s3_client: 'botocore.client.S3',
    bucket_name: str,
    manifest_key: str,
    absent_ok: bool = False,
) -> 'typing.Optional[glci.model.OnlineReleaseManifestSet]':
    '''
    Download the manifest set stored at ``manifest_key`` in ``bucket_name`` and
    deserialise it into an ``OnlineReleaseManifestSet`` (the document is read
    with ``yaml.safe_load``).

    The object's location is patched into the parsed data as ``s3_bucket`` /
    ``s3_key`` before deserialisation.

    Returns ``None`` instead of raising when the download fails with S3 error
    code ``404`` and ``absent_ok`` is true; any other client error is re-raised.
    '''
    buf = io.BytesIO()
    try:
        s3_client.download_fileobj(
            Bucket=bucket_name,
            Key=manifest_key,
            Fileobj=buf,
        )
    except botocore.exceptions.ClientError as e:
        if absent_ok and str(e.response['Error']['Code']) == '404':
            return None
        raise  # bare raise re-raises with the original traceback

    buf.seek(0)
    parsed = yaml.safe_load(buf)

    # patch-in transient attrs (where this manifest set was fetched from)
    parsed['s3_bucket'] = bucket_name
    parsed['s3_key'] = manifest_key

    return dacite.from_dict(
        data_class=glci.model.OnlineReleaseManifestSet,
        data=parsed,
        config=dacite.Config(cast=[
            glci.model.Architecture,
            typing.Tuple,
            glci.model.TestResultCode,
        ]),
    )
예제 #13
0
def release_manifest(
    s3_client: 'botocore.client.S3',
    bucket_name: str,
    key: str,
) -> glci.model.OnlineReleaseManifest:
    '''
    Download the gardenlinux release manifest stored at ``key`` in
    ``bucket_name`` and deserialise it into an ``OnlineReleaseManifest``
    (expects a YAML or JSON document; it is read with ``yaml.safe_load``).

    The object's location is added to the parsed data as ``s3_key`` /
    ``s3_bucket`` before deserialisation.
    '''
    manifest_stream = io.BytesIO()
    s3_client.download_fileobj(Bucket=bucket_name, Key=key, Fileobj=manifest_stream)
    manifest_stream.seek(0)
    manifest_data = yaml.safe_load(manifest_stream)

    # transient attributes: where this manifest was fetched from
    manifest_data['s3_key'] = key
    manifest_data['s3_bucket'] = bucket_name

    return dacite.from_dict(
        data_class=glci.model.OnlineReleaseManifest,
        data=manifest_data,
        config=dacite.Config(cast=[glci.model.Architecture, typing.Tuple]),
    )
예제 #14
0
def enumerate_build_flavours(build_yaml: str='../build.yaml'):
    '''Yield every ``GardenlinuxFlavour`` described in the ``flavours`` list of *build_yaml*.

    Each list entry is parsed into a ``GardenlinuxFlavourCombination``; every
    combination is then expanded into the cartesian product of its
    architectures, platforms, extensions and modifiers.  All entries are
    parsed before the first flavour is yielded, so a malformed entry fails
    on the first iteration.

    NOTE(review): a relative *build_yaml* is resolved against the current
    working directory, not against this module's location.
    '''
    # Decode explicitly as UTF-8 instead of the platform's locale encoding.
    with open(build_yaml, encoding='utf-8') as f:
        parsed = yaml.safe_load(f)

    model = glci.model
    # One dacite config serves every entry.
    parse_config = dacite.Config(
        cast=[model.Architecture, model.Platform, model.Extension, model.Modifier, typing.Tuple],
    )
    flavour_combinations = [
        dacite.from_dict(
            data_class=model.GardenlinuxFlavourCombination,
            data=flavour_def,
            config=parse_config,
        ) for flavour_def in parsed['flavours']
    ]
    for comb in flavour_combinations:
        for arch, platf, exts, mods in itertools.product(
            comb.architectures,
            comb.platforms,
            comb.extensions,
            comb.modifiers,
        ):
            # comb.fails is deliberately not forwarded: it is not part of a variant.
            yield model.GardenlinuxFlavour(
                architecture=arch,
                platform=platf,
                extensions=exts,
                modifiers=mods,
            )
예제 #15
0
def get_reservation(raw_response: Request, response) -> Reservation:
    '''Getting and testing response from the /reservation endpoint

    ``response`` is the decoded JSON body.  It is parsed strictly into a
    ``Reservation`` (``datetime`` fields via ``datetime.fromisoformat``).
    ``raw_response`` is attached to any ``FailedTest`` raised.

    Raises:
        FailedTest: if the body is not a JSON object or does not match
            ``Reservation``.
    '''
    # isinstance also accepts dict subclasses, which are equally valid JSON objects.
    if not isinstance(response, dict):
        raise FailedTest(
            message='The response should be a JSON Object',
            response=raw_response,
        )
    # NOTE(review): a malformed timestamp makes datetime.fromisoformat raise
    # ValueError, which is not converted into FailedTest here. Confirm intent.
    try:
        return dacite.from_dict(
            data_class=Reservation,
            data=response,
            config=dacite.Config(
                type_hooks={datetime: datetime.fromisoformat},
                strict=True,
            )
        )
    except (
        dacite.exceptions.WrongTypeError,
        dacite.exceptions.MissingValueError,
        dacite.exceptions.UnexpectedDataError,
    ) as e:
        # Chain the dacite error so the offending field stays in the traceback.
        raise FailedTest(
            message=f'Incorrect JSON format in response from the /reservation endpoint: {format_error_message(e)}',
            response=raw_response,
        ) from e
예제 #16
0
    def from_dict(cls, descriptor_dict: Dict[str, Any], config: DescriptorConfig = None):
        """Parse ``descriptor_dict`` into a ``BenchmarkDescriptor`` and optionally validate it.

        Parsing is strict: unknown keys are rejected.  Fields typed
        ``DistributedStrategy``, ``ExecutionEngine`` and ``MLFramework`` are
        built by calling that class on the raw value; ``HttpProbeScheme``
        values are lower-cased first.

        If ``config`` is given (truthy), the descriptor's hardware strategy and
        ML framework values must be listed in ``config.valid_strategies`` /
        ``config.valid_frameworks``.

        Raises:
            DescriptorError: if the data cannot be parsed, or fails the
                ``config`` checks.
        """
        try:
            descriptor = dacite.from_dict(
                data_class=BenchmarkDescriptor,
                data=descriptor_dict,
                config=dacite.Config(
                    type_hooks={
                        # Calling the class directly is equivalent to a
                        # one-argument lambda that forwards to it.
                        DistributedStrategy: DistributedStrategy,
                        ExecutionEngine: ExecutionEngine,
                        MLFramework: MLFramework,
                        HttpProbeScheme: lambda scheme_str: HttpProbeScheme(scheme_str.lower()),
                    },
                    strict=True,
                ),
            )
        except (
            dacite.MissingValueError,
            dacite.WrongTypeError,
            # strict=True reports unknown keys with this error; without it here
            # the error would escape as a raw dacite exception.
            dacite.UnexpectedDataError,
            ValueError,
        ) as err:
            raise DescriptorError(f"Error parsing descriptor: {err}") from err

        # Validation runs outside the try block so its DescriptorErrors are never
        # re-wrapped as "Error parsing descriptor" (should DescriptorError derive
        # from ValueError).
        if config:
            if descriptor.hardware.strategy.value not in config.valid_strategies:
                raise DescriptorError(
                    f"Invalid strategy: {descriptor.hardware.strategy} (must be one of {config.valid_strategies})"
                )

            if descriptor.ml.framework.value not in config.valid_frameworks:
                raise DescriptorError(
                    f"Invalid framework: {descriptor.ml.framework} (must be one of {config.valid_frameworks})"
                )
        return descriptor
예제 #17
0
def _run_status(dict_with_status: dict):
    '''
    determines the current status of a given tekton entity bearing a `status`.
    Examples of such entities are:
    - pipelineruns
    - taskruns

    the passed `dict` is expected to bear an attribute `status`, with a sub-attr `conditions`, which
    in turn is parsable into a list of `TknCondition`
    '''
    if not 'status' in dict_with_status:
        # XXX if we are too early, there is no status, yet
        return None

    status = dict_with_status['status']
    conditions = [
        dacite.from_dict(
            data=condition,
            data_class=TknCondition,
            config=dacite.Config(
                cast=[
                    StatusReason,
                ],
            ),
        ) for condition in status['conditions']
    ]

    latest_condition = sorted(
        conditions,
        key=lambda c: dateutil.parser.isoparse(c.lastTransitionTime)
    )[-1]

    return latest_condition
 def _resource(
     image_reference,
     name='abc',
     version='1.2.3',
 ):
     """Build a ``gci.componentmodel.Resource`` of type ``ociImage``.

     The resource is accessed via ``ociRegistry`` at ``image_reference``.
     ``Union[AccessType, str]`` / ``Union[ResourceType, str]`` fields go
     through ``enum_or_string`` with the matching ``enum_type``.
     """
     cm = gci.componentmodel
     raw_resource = {
         'name': name,
         'version': version,
         'type': 'ociImage',
         'access': {'type': 'ociRegistry', 'imageReference': image_reference},
     }
     hooks = {
         typing.Union[enum_type, str]: functools.partial(
             cm.enum_or_string,
             enum_type=enum_type,
         )
         for enum_type in (cm.AccessType, cm.ResourceType)
     }
     return dacite.from_dict(
         data_class=cm.Resource,
         data=raw_resource,
         config=dacite.Config(
             cast=[cm.ResourceType, cm.AccessType],
             type_hooks=hooks,
         ),
     )
예제 #19
0
 def from_dict(cls, diag_table: dict):
     """Build ``cls`` from a diag-table dictionary.

     Each entry of ``diag_table["file_configs"]`` is parsed into a
     ``DiagFileConfig`` (raw values for ``Enum``-typed fields are cast to the
     enum); the instance is created as ``cls(name, base_time, file_configs)``.
     """
     parsed_files = []
     for raw_file in diag_table["file_configs"]:
         parsed_files.append(
             dacite.from_dict(DiagFileConfig,
                              raw_file,
                              config=dacite.Config(cast=[Enum]))
         )
     name = diag_table["name"]
     base_time = diag_table["base_time"]
     return cls(name, base_time, parsed_files)
예제 #20
0
def get_source_scan_label_from_labels(labels: typing.Sequence[cm.Label]):
    '''Return the first ``SourceScanHint`` found in ``labels``, or ``None``.

    A label matches when its name equals ``ScanLabelName.SOURCE_SCAN.value``;
    its value is parsed into a ``SourceScanHint`` (``ScanPolicy`` values cast).
    '''
    for label in labels:
        # Compare against the member's value instead of calling
        # ScanLabelName(label.name): assuming ScanLabelName is a plain Enum,
        # that lookup raises ValueError for any unrelated label name, so one
        # foreign label would abort the whole search.
        if label.name == sdo.labels.ScanLabelName.SOURCE_SCAN.value:
            return dacite.from_dict(
                sdo.labels.SourceScanHint,
                data=label.value,
                config=dacite.Config(cast=[sdo.labels.ScanPolicy]))
    return None
예제 #21
0
 def load(self, file_path: pathlib.Path, config_data: PluginConfig,
          root: Config) -> Sequence[ComponentBase]:
     """Build the single ``Autoflake`` component described by ``config_data.config``.

     The config is parsed strictly into an ``AutoflakeSetting``; its include /
     exclude paths and globs form the component's ``Source``.  ``file_path``
     and ``root`` are accepted but not used.

     Raises:
         ValueError: if ``config_data.config`` is None (this plugin has no defaults).
     """
     # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
     if config_data.config is None:
         raise ValueError(f"{config_data.location}.config must be not None")
     config = dacite.from_dict(AutoflakeSetting, config_data.config,
                               dacite.Config(strict=True))
     source = Source(config.includes, config.excludes, config.include_globs,
                     config.exclude_globs)
     return [Autoflake(setting=config, source=source)]
예제 #22
0
def _get_ws_label_from_artifact(
        source: cm.ComponentSource) -> 'dso.labels.SourceIdHint | None':
    '''Return the source-id scan hint attached to ``source``, or ``None``.

    Looks up the label named ``ScanLabelName.SOURCE_ID.value`` on ``source``
    and parses its value into a ``SourceIdHint`` (``ScanPolicy`` values cast).
    '''
    label = source.find_label(dso.labels.ScanLabelName.SOURCE_ID.value)
    if not label:
        # No source-id label on this source; callers must handle None
        # (hence the optional return annotation).
        return None
    return dacite.from_dict(
        data_class=dso.labels.SourceIdHint,
        data=label.value,
        config=dacite.Config(cast=[dso.labels.ScanPolicy]),
    )
예제 #23
0
 def get_ml_model(self, ml_model_id: ObjectId) -> MlModel:
     """Fetch the ML model document whose ``_id`` is ``ml_model_id`` and parse it.

     ``SensorComponent`` and ``Enum`` field values are cast by dacite.

     Raises:
         NonExistentError: if no document has that id.
     """
     ml_model = self.collection.find_one({"_id": ml_model_id})
     if ml_model is None:
         raise NonExistentError(
             "ML model with the given id does not exist")
     return dacite.from_dict(
         data_class=MlModel,
         data=ml_model,
         config=dacite.Config(cast=[SensorComponent, Enum]))
def main():
    """Load ``base_config.yaml``, parse it into a ``Config`` and print the
    result of ``train_and_test`` (``dirpath='run'``, ``progress=True``).

    Raw values for ``Enum``-typed fields are cast by dacite.
    """
    dacite_conf = dacite.Config(cast=[Enum])
    # NOTE(review): FullLoader can construct some Python-specific YAML tags;
    # yaml.safe_load would be safer if base_config.yaml uses plain YAML only.
    # Confirm before switching.
    with open("base_config.yaml", encoding="utf-8") as f:
        raw_conf: Dict = yaml.load(f, Loader=yaml.FullLoader)

    # A separate name for the typed config: re-annotating one variable with a
    # different type is rejected by static type checkers.
    conf: Config = dacite.from_dict(data_class=Config, data=raw_conf,
                                    config=dacite_conf)

    print(train_and_test(conf, dirpath='run', progress=True))
예제 #25
0
    def from_dict(cls, data):
        """Build ``cls`` from ``data``, accepting decimal or hexadecimal ints.

        Every ``int``-typed field goes through ``_from_int_str``.  The value is
        first tried as a decimal int (``int(x)``, which also passes real ints
        through).  A string that is not decimal is then parsed as
        hexadecimal, with or without a ``0x`` prefix: "ff" and "0xff" both
        give 255.
        """
        def _from_int_str(x):
            try:
                return int(x)
            except ValueError:
                # int(s, 16) accepts an optional 0x/0X prefix itself, whereas
                # prepending "0x" unconditionally rejects values that already
                # carry one.  Non-string inputs keep the decimal error.
                if isinstance(x, str):
                    return int(x, 16)
                raise

        return dacite.from_dict(cls, data,
                                dacite.Config(type_hooks={int: _from_int_str}))
예제 #26
0
 async def get_workspace(self, workspace_id: ObjectId) -> Workspace:
     """Fetch the workspace document whose ``_id`` is ``workspace_id`` and parse it.

     ``SensorComponent`` and ``Enum`` field values are cast by dacite.

     Raises:
         NonExistentError: if no document has that id.
     """
     document = await self.collection.find_one({"_id": workspace_id})
     if document is not None:
         return dacite.from_dict(
             data_class=Workspace,
             data=document,
             config=dacite.Config(cast=[SensorComponent, Enum]),
         )
     raise NonExistentError(
         "Workspace with the given id does not exist")
예제 #27
0
def parse_datasources(*args) -> List[SourceConfig]:
    """Return the ``datasources`` section of the configuration built from *args*.

    ``Config(*args)["datasources"]`` is wrapped as ``{"datasources": ...}`` and
    parsed strictly (unknown keys rejected) into ``__DataSourceConfig``.
    """
    datasources_section = Config(*args)["datasources"]
    parsed = dacite.from_dict(
        data_class=__DataSourceConfig,
        data={"datasources": datasources_section},
        config=dacite.Config(strict=True),
    )
    return parsed.datasources
예제 #28
0
    def from_dict(image_config_dict: dict) -> 'ImageBuildConfig':
        """Parse ``image_config_dict`` into an ``ImageBuildConfig``.

        Raw values for ``SourceType`` and ``TagType`` fields are cast by dacite.
        """
        build_config = dacite.Config(cast=[SourceType, TagType])
        return dacite.from_dict(
            data_class=ImageBuildConfig,
            data=image_config_dict,
            config=build_config,
        )
예제 #29
0
    def get_config():
        """Read the JSON file at ``Config._config_full_path`` into a ``Configuration``.

        ``Config._converters`` is passed to dacite as its ``type_hooks``.
        """
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(Config._config_full_path, "r", encoding="utf-8") as file:
            raw_config = json.load(file)

        return dacite.from_dict(
            data_class=Configuration,
            data=raw_config,
            config=dacite.Config(type_hooks=Config._converters))
예제 #30
0
 def _cfg_types(self):
     return {
         cfg.cfg_type_name(): cfg
         for cfg in (dacite.from_dict(
             data_class=ConfigType,
             data=cfg_dict,
             config=dacite.Config(cast=[tuple], ),
         ) for cfg_dict in self.raw[self.CFG_TYPES].values())
     }