def test_constructable(self, example_config: Dict[str, Any],
                       example_config_injected: Dict[str, Any],
                       example_injection: Dict[str, Any]):
    """Check that one injection pass yields the expected configuration.

    An ``Injector`` built from ``example_injection`` is applied to
    ``example_config`` via ``inject_pass``; the result must serialize to the
    same JSON string as ``example_config_injected``.
    """
    injected = Injector(mapping=example_injection).inject_pass(example_config)

    # The comparison is done on the serialized strings, so the key order of
    # the dictionaries matters in addition to their content.
    actual_json = json.dumps(injected)
    expected_json = json.dumps(example_config_injected)
    assert actual_json == expected_json
def construct_components(config: Dict,
                         component_names: List[str],
                         device: torch.device,
                         external_injection: Dict[str, Any] = None) -> Dict[str, Any]:
    """Build the components listed in ``component_names`` from ``config``.

    Args:
        config: Configuration handed to
            ``ComponentFactory.build_components_from_config``.
        component_names: Names of the components to build.
        device: Injected under ``"id_computation_device"``, but only when
            ``external_injection`` is provided.
        external_injection: Optional extra injection entries. They are merged
            last, so they override the built-in entries on key collisions.
            When given, the injector is created with
            ``raise_mapping_not_found=True``; otherwise with ``False``.

    Returns:
        The result of ``build_components_from_config``.
    """
    has_external = external_injection is not None

    mapping = {"id_conv_mnist_standard_collator": MNISTCollator}
    if has_external:
        # Same insertion order as a literal with the collator first, the
        # device second and the external entries unpacked last.
        mapping["id_computation_device"] = device
        mapping.update(external_injection)

    factory = ComponentFactory(
        Injector(mapping, raise_mapping_not_found=has_external))
    factory.register_component_type(
        "MODEL_REGISTRY", "DEFAULT", MyModelRegistryConstructable)
    return factory.build_components_from_config(config, component_names)
def construct(self) -> AbstractGymJob:
    """Assemble a ``GymJob`` from this object's configuration.

    Gathers the experiment info, registers the project-specific component
    types on a ``ComponentFactory``, builds the ``model``, ``trainer``,
    ``optimizer`` and ``evaluator`` components from ``self.config`` and
    passes them as keyword arguments to ``GymJob``.

    Returns:
        The constructed ``GymJob``.
    """
    experiment_info = self.get_experiment_info(self.dashify_logging_dir,
                                               self.grid_search_id,
                                               self.model_name,
                                               self.dataset_name,
                                               self.run_id)
    component_names = ["model", "trainer", "optimizer", "evaluator"]

    injector = Injector({"id_ata_collator": ATACollator})
    component_factory = ComponentFactory(injector)

    # Positional arguments for ``register_component_type``, in the order the
    # registrations are performed. Listing them in a table avoids the stray
    # trailing commas that previously turned two of the calls into discarded
    # one-element tuple expressions.
    registrations = (
        ("DATASET_REPOSITORY", "DEFAULT", CustomDatasetRepositoryConstructable),
        ("DATASET_FACTORY", "ATIS", AtisFactoryConstructable),
        ("DATASET_FACTORY", "KDD", KddFactoryConstructable),
        ("LOSS_FUNCTION_REGISTRY", "DEFAULT", LossFunctionRegistryConstructable),
        ("MODEL_REGISTRY", "DEFAULT", CustomModelRegistryConstructable),
        ("TRAIN_COMPONENT", "ATA", ATATrainComponentConstructable),
        ("EVAL_COMPONENT", "ATA", ATAEvalComponentConstructable),
    )
    for registration in registrations:
        component_factory.register_component_type(*registration)

    components = component_factory.build_components_from_config(
        self.config, component_names)
    return GymJob(self.run_mode,
                  experiment_info=experiment_info,
                  epochs=self.epochs,
                  **components)
def components(self, full_config) -> Dict:
    """Build one component for every top-level key of ``full_config``.

    Args:
        full_config: Configuration whose keys name the components to build;
            it is also passed on to ``build_components_from_config``.

    Returns:
        The result of ``ComponentFactory.build_components_from_config``.
    """
    factory = ComponentFactory(
        Injector({"id_mlp_standard_collator": MLPCollator}))
    factory.register_component_type(
        "DATASET_REPOSITORY", "DEFAULT", CustomDatasetRepositoryConstructable)

    requested_names = list(full_config.keys())
    return factory.build_components_from_config(full_config, requested_names)
def create_blue_prints(self, blue_print_type: Type[BluePrint],
                       job_type: AbstractGymJob.Type,
                       gs_config: Dict[str, Any],
                       num_epochs: int,
                       dashify_logging_path: str) -> List[Type[BluePrint]]:
    """Create one blueprint per (fold split, grid-search configuration) pair.

    The configurations produced by ``GridSearch.create_gs_from_config_dict``
    are numbered from 0. For every split from
    ``CrossValidation._create_folds_splits`` and every configuration, the
    experiment id, the configuration id and the split entries are injected
    into the configuration, and a blueprint is created from the result.

    Args:
        blue_print_type: Blueprint class forwarded to ``create_blueprint``.
        job_type: Job type forwarded to ``create_blueprint``.
        gs_config: Grid-search configuration to expand.
        num_epochs: Number of epochs forwarded to ``create_blueprint``.
        dashify_logging_path: Logging path forwarded to ``create_blueprint``.

    Returns:
        The created blueprints, ordered by split first and configuration
        second, with consecutive experiment ids starting at 0.
    """
    configs_by_id = dict(
        enumerate(GridSearch.create_gs_from_config_dict(gs_config)))

    fold_indices = self._get_fold_indices()
    splits: List[Dict[str, Any]] = CrossValidation._create_folds_splits(
        fold_indices)

    created = []
    experiment_id = 0
    for split in splits:
        for config_id, experiment_config in configs_by_id.items():
            # Split entries are unpacked last, so they take precedence over
            # the two id entries if the keys collide.
            injector = Injector(mapping={
                "id_experiment_id": experiment_id,
                "id_hyper_paramater_combination_id": config_id,
                **split
            })
            injected_config = injector.inject_pass(
                component_parameters=experiment_config)

            if self.re_eval:
                run_mode = AbstractGymJob.Mode.EVAL
            else:
                run_mode = AbstractGymJob.Mode.TRAIN

            created.append(create_blueprint(
                blue_print_class=blue_print_type,
                run_mode=run_mode,
                job_type=job_type,
                experiment_config=injected_config,
                dashify_logging_path=dashify_logging_path,
                num_epochs=num_epochs,
                grid_search_id=self.grid_search_id,
                experiment_id=experiment_id))
            experiment_id += 1
    return created