Exemple #1
0
 def test_constructable(self, example_config: Dict[str, Any],
                        example_config_injected: Dict[str, Any],
                        example_injection: Dict[str, Any]):
     """Check that a single injection pass yields the expected config.

     An ``Injector`` built from ``example_injection`` is applied once to
     ``example_config``. The outcome must serialize to the same JSON
     string as ``example_config_injected``.
     """
     injected = Injector(mapping=example_injection).inject_pass(example_config)
     # The comparison is made on the JSON serializations of both configs.
     actual_json = json.dumps(injected)
     expected_json = json.dumps(example_config_injected)
     assert actual_json == expected_json
Exemple #2
0
    def construct_components(
            config: Dict,
            component_names: List[str],
            device: torch.device,
            external_injection: Dict[str, Any] = None) -> Dict[str, Any]:
        """Build the requested components from ``config``.

        The MNIST collator is always registered for injection under
        ``"id_conv_mnist_standard_collator"``. What else happens depends on
        ``external_injection``:

        * given: ``device`` is added under ``"id_computation_device"``, the
          caller's entries are merged on top (they win on key clashes) and
          the injector is created with ``raise_mapping_not_found=True``;
        * ``None``: only the collator is injectable (``device`` is not
          added) and the injector is created with
          ``raise_mapping_not_found=False``.

        Args:
            config: Configuration handed to the component factory.
            component_names: Names of the components to build.
            device: Device object exposed to the config via injection.
            external_injection: Extra injection entries, or ``None``.

        Returns:
            The result of ``ComponentFactory.build_components_from_config``.
        """
        has_external = external_injection is not None

        mapping = {"id_conv_mnist_standard_collator": MNISTCollator}
        if has_external:
            mapping["id_computation_device"] = device
            # Merged last so that caller-supplied entries take precedence.
            mapping.update(external_injection)
        injector = Injector(mapping, raise_mapping_not_found=has_external)

        factory = ComponentFactory(injector)
        factory.register_component_type(
            "MODEL_REGISTRY", "DEFAULT", MyModelRegistryConstructable)
        return factory.build_components_from_config(config, component_names)
Exemple #3
0
 def construct(self) -> AbstractGymJob:
     """Assemble the ``GymJob`` described by this object's configuration.

     Steps, in order:

     1. Obtain the experiment info from ``self.get_experiment_info``.
     2. Create a ``ComponentFactory`` whose injector maps
        ``"id_ata_collator"`` to ``ATACollator``.
     3. Register the custom component constructables on the factory.
     4. Build the ``model``, ``trainer``, ``optimizer`` and ``evaluator``
        components from ``self.config``.
     5. Pass the built components as keyword arguments to ``GymJob``.

     Returns:
         The ``GymJob`` created from ``self.run_mode``, the experiment
         info, ``self.epochs`` and the built components.
     """
     experiment_info = self.get_experiment_info(self.dashify_logging_dir,
                                                self.grid_search_id,
                                                self.model_name,
                                                self.dataset_name,
                                                self.run_id)

     injector = Injector({"id_ata_collator": ATACollator})
     component_factory = ComponentFactory(injector)

     # (component type, variant key, constructable), registered in this
     # order. Keeping them in one table avoids the stray trailing commas
     # that previously turned two of the calls into discarded 1-tuples.
     registrations = (
         ("DATASET_REPOSITORY", "DEFAULT",
          CustomDatasetRepositoryConstructable),
         ("DATASET_FACTORY", "ATIS", AtisFactoryConstructable),
         ("DATASET_FACTORY", "KDD", KddFactoryConstructable),
         ("LOSS_FUNCTION_REGISTRY", "DEFAULT",
          LossFunctionRegistryConstructable),
         ("MODEL_REGISTRY", "DEFAULT", CustomModelRegistryConstructable),
         ("TRAIN_COMPONENT", "ATA", ATATrainComponentConstructable),
         ("EVAL_COMPONENT", "ATA", ATAEvalComponentConstructable),
     )
     for component_type, variant, constructable in registrations:
         component_factory.register_component_type(
             component_type, variant, constructable)

     component_names = ["model", "trainer", "optimizer", "evaluator"]
     components = component_factory.build_components_from_config(
         self.config, component_names)

     return GymJob(self.run_mode,
                   experiment_info=experiment_info,
                   epochs=self.epochs,
                   **components)
 def components(self, full_config) -> Dict:
     """Build one component for every top-level key of ``full_config``.

     The factory's injector maps ``"id_mlp_standard_collator"`` to
     ``MLPCollator``, and the custom dataset repository constructable is
     registered under ``("DATASET_REPOSITORY", "DEFAULT")``.

     Returns:
         The result of ``ComponentFactory.build_components_from_config``.
     """
     requested_names = list(full_config.keys())
     injector = Injector({"id_mlp_standard_collator": MLPCollator})
     factory = ComponentFactory(injector)
     factory.register_component_type(
         "DATASET_REPOSITORY", "DEFAULT",
         CustomDatasetRepositoryConstructable)
     return factory.build_components_from_config(full_config,
                                                 requested_names)
Exemple #5
0
    def create_blue_prints(self, blue_print_type: Type[BluePrint],
                           job_type: AbstractGymJob.Type,
                           gs_config: Dict[str, Any], num_epochs: int,
                           dashify_logging_path: str) -> List[Type[BluePrint]]:
        """Create one blueprint per (fold split, grid-search config) pair.

        The grid search is expanded into individual configs, numbered from
        0 in the order ``GridSearch.create_gs_from_config_dict`` yields
        them. For every split from ``CrossValidation._create_folds_splits``
        and every config, the config receives an injection pass that
        exposes the running experiment id, the config's number and the
        split's own entries. A blueprint is then created from the result.

        The experiment id starts at 0 and increases by one per blueprint,
        iterating splits in the outer loop and configs in the inner loop.
        The run mode is ``EVAL`` when ``self.re_eval`` is truthy and
        ``TRAIN`` otherwise.

        Args:
            blue_print_type: Blueprint class forwarded to
                ``create_blueprint``.
            job_type: Job type forwarded to ``create_blueprint``.
            gs_config: Grid-search configuration to expand.
            num_epochs: Number of epochs forwarded to ``create_blueprint``.
            dashify_logging_path: Logging path forwarded to
                ``create_blueprint``.

        Returns:
            The values returned by ``create_blueprint``, in creation order.
        """
        configs_by_id = dict(
            enumerate(GridSearch.create_gs_from_config_dict(gs_config)))

        folds = self._get_fold_indices()
        splits: List[Dict[str, Any]] = CrossValidation._create_folds_splits(
            folds)

        blueprints = []
        experiment_id = 0
        for split in splits:
            for config_id, raw_config in configs_by_id.items():
                # NOTE(review): the key spelling "paramater" is kept as-is;
                # presumably configs reference it by this exact name.
                injection = {
                    "id_experiment_id": experiment_id,
                    "id_hyper_paramater_combination_id": config_id,
                }
                # Split entries go last, so they override on key clashes.
                injection.update(split)

                injected_config = Injector(mapping=injection).inject_pass(
                    component_parameters=raw_config)

                if self.re_eval:
                    run_mode = AbstractGymJob.Mode.EVAL
                else:
                    run_mode = AbstractGymJob.Mode.TRAIN

                blueprints.append(
                    create_blueprint(
                        blue_print_class=blue_print_type,
                        run_mode=run_mode,
                        job_type=job_type,
                        experiment_config=injected_config,
                        dashify_logging_path=dashify_logging_path,
                        num_epochs=num_epochs,
                        grid_search_id=self.grid_search_id,
                        experiment_id=experiment_id))
                experiment_id += 1
        return blueprints