Exemplo n.º 1
0
 def __init__(
     self,
     config: Config,
     dataset: Dataset,
     configuration_key=None,
     init_for_load_only=False,
     create_embedders=True,
     parameter_client=None,
     max_partition_entities=0,
 ):
     """Build this model around a base model configured under this model's key.

     A base model of type ``<configuration_key>.base_model.type`` is created
     (without embedders) from a deep copy of ``config`` in which ``model`` is
     overridden with that type. This model reuses the base model's scorer and
     keeps the base model in ``self.base_model``. ``_create_embedders`` is
     called at the end when ``create_embedders`` is true.

     Args:
         config: Configuration. The ``model`` override is applied to a deep
             copy only; the copy is what the base model receives.
         dataset: Dataset handed to both the base model and the superclass.
         configuration_key: Configuration key of this model; the base model
             is configured under ``<configuration_key>.base_model``.
         init_for_load_only: Forwarded to the base model, the superclass
             constructor and ``_create_embedders``.
         create_embedders: Whether to call ``_create_embedders`` at the end.
         parameter_client: Stored as ``self.parameter_client`` and forwarded
             to ``_create_embedders``.
         max_partition_entities: Stored as ``self.max_partition_entities``
             and forwarded to ``_create_embedders``.
     """
     self._init_configuration(config, configuration_key)
     self.base_model_config_key = self.configuration_key + ".base_model"
     self.parameter_client = parameter_client
     self.max_partition_entities = max_partition_entities

     # Look up the base model type under this model's own key rather than a
     # hard-coded "distributed_model" prefix, so the type read here agrees
     # with the configuration_key handed to KgeModel.create below. This is
     # unchanged from before when configuration_key is "distributed_model".
     base_config = deepcopy(config)
     base_config.set("model", config.get(self.base_model_config_key + ".type"))
     base_model = KgeModel.create(
         config=base_config,
         dataset=dataset,
         configuration_key=self.base_model_config_key,
         init_for_load_only=init_for_load_only,
         # The base model is built without embedders; this model calls
         # _create_embedders itself (below) when requested.
         create_embedders=False,
     )

     # Initialize this model with the base model's scorer.
     # NOTE(review): configuration_key is not forwarded to the superclass
     # here; confirm the superclass constructor keeps self.configuration_key.
     super().__init__(
         config=config,
         dataset=dataset,
         create_embedders=False,
         scorer=base_model.get_scorer(),
         init_for_load_only=init_for_load_only,
     )
     self.base_model = base_model
     if create_embedders:
         self._create_embedders(
             init_for_load_only, parameter_client, max_partition_entities
         )
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        configuration_key=None,
        init_for_load_only=False,
    ):
        """Wrap a base model that is sized for twice as many relations.

        The base model is configured under ``<configuration_key>.base_model``
        and built on a shallow copy of ``dataset`` whose relation count is
        doubled. This model takes over the base model's scorer. Its entity
        embedder is the base model's subject embedder, and its relation
        embedder is the base model's relation (``p``) embedder.

        Args:
            config: Configuration shared with the base model.
            dataset: Dataset of this model; it is not modified (only its
                shallow copy gets the doubled relation count).
            configuration_key: Configuration key of this model.
            init_for_load_only: Forwarded to the base model and the
                superclass constructor.
        """
        # Has to run before anything else: the base model's key is derived
        # from self.configuration_key.
        self._init_configuration(config, configuration_key)
        base_key = self.configuration_key + ".base_model"

        # The base model sees twice the number of relations; the override is
        # set on the copy, not on `dataset` itself.
        doubled_dataset = dataset.shallow_copy()
        doubled_dataset._num_relations = dataset.num_relations() * 2
        inner_model = KgeModel.create(
            config=config,
            dataset=doubled_dataset,
            configuration_key=base_key,
            init_for_load_only=init_for_load_only,
        )

        # No embedders of our own: they are taken from the inner model below.
        super().__init__(
            config=config,
            dataset=dataset,
            scorer=inner_model.get_scorer(),
            create_embedders=False,
            init_for_load_only=init_for_load_only,
        )
        self._base_model = inner_model

        # TODO: switch to separate subject/object embedders once that is
        # supported; until then the subject embedder doubles as the entity
        # embedder.
        self._entity_embedder = inner_model.get_s_embedder()
        self._relation_embedder = inner_model.get_p_embedder()
Exemplo n.º 3
0
 def load_pretrained_model(
     pretrained_filename: str,
 ) -> Optional[KgeModel]:
     """Load a model from a checkpoint file, if a filename is given.

     NOTE(review): ``self`` and ``parameter_client`` are free variables here,
     so this function is presumably nested inside a method where both are in
     scope — confirm against the enclosing definition.

     Args:
         pretrained_filename: Path of the checkpoint to load. An empty string
             (or ``None``) means "no pretrained model".

     Returns:
         The model built by ``KgeModel.create_from`` from the loaded
         checkpoint (with ``parameter_client`` forwarded), or ``None`` when no
         filename was given.
     """
     # A falsy check covers the "" sentinel and also None, which previously
     # slipped through the `!= ""` test and reached load_checkpoint(None).
     if not pretrained_filename:
         return None
     self.config.log(
         f"Initializing with embeddings stored in "
         f"{pretrained_filename}"
     )
     checkpoint = load_checkpoint(pretrained_filename)
     return KgeModel.create_from(checkpoint, parameter_client=parameter_client)