Example #1
0
    def create(
        config: Config,
        dataset: Dataset,
        configuration_key: str,
        vocab_size: int,
        init_for_load_only=False,
        parameter_client=None,
        lapse_offset=0,
        complete_vocab_size=None,
    ) -> "KgeEmbedder":
        """Factory method for embedder creation.

        Reads the embedder type from ``<configuration_key>.type``, resolves
        that type's ``class_name`` and instantiates it via ``init_from`` from
        the modules listed under ``modules``. If the configured ``train.type``
        contains ``"distributed"``, the class name is prefixed with
        ``"Distributed"`` and the distributed-only arguments are forwarded too.

        Args:
            config: configuration the embedder settings are read from.
            dataset: dataset passed on to the embedder constructor.
            configuration_key: config key prefix of this embedder.
            vocab_size: passed on to the embedder constructor.
            init_for_load_only: passed on to the embedder constructor.
            parameter_client: forwarded only for distributed training.
            lapse_offset: forwarded only for distributed training.
            complete_vocab_size: forwarded only for distributed training;
                defaults to ``vocab_size`` when ``None``.

        Returns:
            The created embedder.

        Raises:
            Exception: if the embedder type or its class name cannot be
                resolved from ``config`` (the original error is chained as
                ``__cause__``). Errors raised while constructing the embedder
                are logged and re-raised unchanged.
        """
        if complete_vocab_size is None:
            complete_vocab_size = vocab_size

        try:
            embedder_type = config.get_default(configuration_key + ".type")
            class_name = config.get(embedder_type + ".class_name")
        except Exception as e:
            # Chain the original error so the actual reason for the failed
            # lookup stays visible in the traceback.
            raise Exception(
                "Can't find {}.type in config".format(configuration_key)
            ) from e

        try:
            if "distributed" in config.get("train.type"):
                # Distributed training uses a dedicated embedder class that
                # additionally receives the distributed-only arguments.
                class_name = "Distributed" + class_name
                embedder = init_from(
                    class_name,
                    config.get("modules"),
                    config=config,
                    dataset=dataset,
                    configuration_key=configuration_key,
                    vocab_size=vocab_size,
                    init_for_load_only=init_for_load_only,
                    parameter_client=parameter_client,
                    lapse_offset=lapse_offset,
                    complete_vocab_size=complete_vocab_size,
                )
            else:
                embedder = init_from(
                    class_name,
                    config.get("modules"),
                    config,
                    dataset,
                    configuration_key,
                    vocab_size,
                    init_for_load_only=init_for_load_only,
                )
            return embedder
        except Exception:
            config.log(f"Failed to create embedder {embedder_type} (class {class_name}).")
            raise
Example #2
0
    def create(
        config: Config,
        dataset: Dataset,
        configuration_key: Optional[str] = None,
        init_for_load_only=False,
    ) -> "KgeModel":
        """Factory method for model creation.

        Reads the model name from ``<configuration_key>.type``, or from
        ``model`` when no key is given. It imports the model via
        ``config._import`` and resolves its ``class_name``. It then
        instantiates that class with ``init_from`` and moves the result to
        the device configured under ``job.device``.

        Args:
            config: configuration the model settings are read from.
            dataset: dataset passed on to the model constructor.
            configuration_key: config key prefix of the model, or ``None`` to
                use the top-level ``model`` entry.
            init_for_load_only: passed on to the model constructor.

        Returns:
            The created model.

        Raises:
            Exception: if the model name or its class name cannot be resolved
                from ``config`` (the original error is chained as
                ``__cause__``). Errors raised while constructing the model or
                moving it to the device are logged and re-raised unchanged.
        """
        try:
            if configuration_key is not None:
                model_name = config.get(configuration_key + ".type")
            else:
                model_name = config.get("model")
            config._import(model_name)
            class_name = config.get(model_name + ".class_name")
        except Exception as e:
            # Report the key that was actually consulted; without a
            # configuration_key the lookup uses "model", not "None.type".
            missing_key = (
                "model"
                if configuration_key is None
                else "{}.type".format(configuration_key)
            )
            raise Exception("Can't find {} in config".format(missing_key)) from e

        try:
            model = init_from(
                class_name,
                config.get("modules"),
                config=config,
                dataset=dataset,
                configuration_key=configuration_key,
                init_for_load_only=init_for_load_only,
            )
            model.to(config.get("job.device"))
            return model
        except Exception:
            config.log(f"Failed to create model {model_name} (class {class_name}).")
            raise
Example #3
0
    def create(
        config: Config,
        dataset: Dataset,
        configuration_key: str,
        vocab_size: int,
        init_for_load_only=False,
    ) -> "KgeEmbedder":
        """Factory method for embedder creation.

        Reads the embedder type from ``<configuration_key>.type``, resolves
        that type's ``class_name`` and instantiates it via ``init_from`` from
        the modules listed under ``modules``.

        Args:
            config: configuration the embedder settings are read from.
            dataset: dataset passed on to the embedder constructor.
            configuration_key: config key prefix of this embedder.
            vocab_size: passed on to the embedder constructor.
            init_for_load_only: passed on to the embedder constructor.

        Returns:
            The created embedder.

        Raises:
            Exception: if the embedder type or its class name cannot be
                resolved from ``config`` (the original error is chained as
                ``__cause__``). Errors raised while constructing the embedder
                are logged and re-raised unchanged.
        """

        try:
            embedder_type = config.get_default(configuration_key + ".type")
            class_name = config.get(embedder_type + ".class_name")
        except Exception as e:
            # Chain the original error so the actual reason for the failed
            # lookup stays visible in the traceback.
            raise Exception(
                "Can't find {}.type in config".format(configuration_key)
            ) from e

        try:
            embedder = init_from(
                class_name,
                config.get("modules"),
                config,
                dataset,
                configuration_key,
                vocab_size,
                init_for_load_only=init_for_load_only,
            )
            return embedder
        except Exception:
            config.log(
                f"Failed to create embedder {embedder_type} (class {class_name})."
            )
            raise
Example #4
0
    def create(config, dataset, parent_job=None):
        """Build the search job selected by the ``search.type`` setting.

        The job class name is read from ``<search.type>.class_name``, and the
        class is instantiated through ``init_from`` with ``config``,
        ``dataset`` and ``parent_job`` passed positionally.
        """

        job_class_name = config.get_default(
            "{}.class_name".format(config.get("search.type"))
        )
        available_modules = config.modules()
        return init_from(
            job_class_name, available_modules, config, dataset, parent_job
        )
Example #5
0
 def create(
     config: Config,
     dataset: Dataset,
     parent_job: Job = None,
     model=None,
     forward_only=False,
     parameter_client=None,
     init_for_load_only=False,
 ) -> "TrainingJob":
     """Build the training job selected by the ``train.type`` setting.

     The job class name is read from ``<train.type>.class_name`` and the class
     is instantiated through ``init_from`` using keyword arguments only.
     ``parameter_client`` and ``init_for_load_only`` are forwarded only when
     ``train.type`` contains ``"distributed"``; otherwise they are ignored.
     """
     train_type = config.get("train.type")
     init_kwargs = dict(
         class_name=config.get_default(f"{train_type}.class_name"),
         modules=config.modules(),
         config=config,
         dataset=dataset,
         parent_job=parent_job,
         model=model,
         forward_only=forward_only,
     )
     is_distributed = "distributed" in train_type
     if is_distributed:
         # These two arguments are passed only for distributed train types.
         init_kwargs["parameter_client"] = parameter_client
         init_kwargs["init_for_load_only"] = init_for_load_only
     return init_from(**init_kwargs)
Example #6
0
File: eval.py Project: uma-pi1/kge
    def create(config, dataset, parent_job=None, model=None):
        """Build the evaluation job selected by the ``eval.type`` setting.

        The job class name is read from ``<eval.type>.class_name``. The class
        is instantiated through ``init_from`` with ``config`` and ``dataset``
        passed positionally, and ``parent_job`` and ``model`` passed as
        keywords.
        """

        selected_type = config.get("eval.type")
        job_class_name = config.get_default("{}.class_name".format(selected_type))
        leading_args = (job_class_name, config.modules(), config, dataset)
        return init_from(*leading_args, parent_job=parent_job, model=model)
Example #7
0
File: train.py Project: uma-pi1/kge
 def create(
     config: Config,
     dataset: Dataset,
     parent_job: Job = None,
     model=None,
     forward_only=False,
 ) -> "TrainingJob":
     """Build the training job selected by the ``train.type`` setting.

     The job class name is read from ``<train.type>.class_name``. The class is
     instantiated through ``init_from`` with ``config``, ``dataset`` and
     ``parent_job`` passed positionally, and ``model`` and ``forward_only``
     passed as keywords.
     """
     class_name_key = "{}.class_name".format(config.get("train.type"))
     job_class_name = config.get_default(class_name_key)
     positional_args = (job_class_name, config.modules(), config, dataset, parent_job)
     return init_from(*positional_args, model=model, forward_only=forward_only)