def __init__(self, config: Config, configuration_key: str, dataset: DatasetProcessor):
    super().__init__(config, configuration_key)

    # load config
    self.num_samples = torch.zeros(4, dtype=torch.int)
    self.filter_positives = torch.zeros(4, dtype=torch.bool)
    self.vocabulary_size = torch.zeros(4, dtype=torch.int)
    self.shared = self.get_option("shared")
    self.shared_type = self.check_option("shared_type", ["naive", "default"])
    self.with_replacement = self.get_option("with_replacement")
    if not self.with_replacement and not self.shared:
        raise ValueError(
            "Without-replacement sampling is only supported when "
            "shared negative sampling is enabled."
        )
    self.filtering_split = config.get("negative_sampling.filtering.split")
    if self.filtering_split == "":
        self.filtering_split = config.get("train.split")
    for slot in SLOTS:
        slot_str = SLOT_STR[slot]
        self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}")
        self.filter_positives[slot] = self.get_option(f"filtering.{slot_str}")
        self.vocabulary_size[slot] = (
            dataset.num_relations() if slot == P
            else dataset.num_timestamps() if slot == T
            else dataset.num_entities()  # TODO edit tkge data
        )
        # create indices for filtering here already if needed and not existing;
        # otherwise every worker would create every index again and again
        if self.filter_positives[slot]:
            pair = ["po", "so", "sp"][slot]
            dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}")
    if any(self.filter_positives):
        if self.shared:
            raise ValueError(
                "Filtering is not supported when shared negative sampling is enabled."
            )
        self.filter_implementation = self.check_option(
            "filtering.implementation", ["standard", "fast", "fast_if_available"]
        )
    self.dataset = dataset

    # auto config: fill in unset (negative) sample counts, copying from the
    # opposite slot where possible
    for slot, copy_from in [(S, O), (P, None), (O, S)]:
        if self.num_samples[slot] < 0:
            if copy_from is not None and self.num_samples[copy_from] > 0:
                self.num_samples[slot] = self.num_samples[copy_from]
            else:
                self.num_samples[slot] = 0
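# The constructor above relies on slot constants and name mappings defined
# elsewhere in the repo. The sketch below is inferred from the code itself
# (the size-4 tensors, the ["po", "so", "sp"][slot] lookup, and the
# slot == P / slot == T branches); the exact values are an assumption, not
# the repo's actual definitions.
S, P, O, T = 0, 1, 2, 3            # subject, predicate, object, timestamp
SLOTS = [S, P, O, T]
SLOT_STR = ["s", "p", "o", "t"]    # builds option keys such as "num_samples.s"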
def create(config: Config):
    """Factory method for dataset creation"""
    ds_type = config.get("dataset.name")

    if ds_type in DatasetProcessor.list_available():
        kwargs = config.get("dataset.args")  # TODO: needs to be changed to a key-based format

        return DatasetProcessor.by_name(ds_type)(config)
    else:
        raise ConfigurationError(
            f"{ds_type} specified in configuration file is not supported; "
            "implement your dataset class with `DatasetProcessor.register(name)`"
        )
def create(config: Config, dataset: DatasetProcessor):
    """Factory method for negative sampler creation"""
    ns_type = config.get("negative_sampling.name")

    if ns_type in NegativeSampler.list_available():
        as_matrix = config.get("negative_sampling.as_matrix")
        # kwargs = config.get("model.args")  # TODO: needs to be changed to a key-based format

        return NegativeSampler.by_name(ns_type)(config, dataset, as_matrix)
    else:
        raise ConfigurationError(
            f"{ns_type} specified in configuration file is not supported; "
            "implement your negative sampling class with `NegativeSampler.register(name)`"
        )
def create(config: Config, name: str):
    """Factory method for inplace regularizer creation"""
    reg_type = config.get(f"train.inplace_regularizer.{name}.type")

    if reg_type in InplaceRegularizer.list_available():
        return InplaceRegularizer.by_name(reg_type)(config, name)
    else:
        raise ValueError(
            f"{reg_type} specified in configuration file is not supported; "
            "implement your inplace-regularizer class with `InplaceRegularizer.register(name)`"
        )
def create(config: Config, dataset: DatasetProcessor):
    """Factory method for model creation"""
    model_type = config.get("model.name")

    if model_type in BaseModel.list_available():
        # kwargs = config.get("model.args")  # TODO: needs to be changed to a key-based format

        return BaseModel.by_name(model_type)(config, dataset)
    else:
        raise ConfigurationError(
            f"{model_type} specified in configuration file is not supported; "
            "implement your model class with `BaseModel.register(name)`"
        )
def create(config: Config, configuration_key: str, dataset: DatasetProcessor):
    """Factory method for negative sampler creation"""
    sampling_type = config.get(configuration_key + ".sampling_type")

    if sampling_type in NegativeSampler.list_available():
        return NegativeSampler.by_name(sampling_type)(config, configuration_key, dataset)
    else:
        raise ValueError(
            f"{sampling_type} specified in configuration file is not supported; "
            "implement your negative sampler class with `NegativeSampler.register(name)`"
        )
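# Taken together, the factories above let a task assemble all components from
# a single Config. The usage sketch below assumes the create() functions are
# exposed as static factory methods on their respective classes, and
# "path/to/config" is a placeholder folder; both are assumptions, not the
# repo's confirmed API.
config = Config(folder="path/to/config")
dataset = DatasetProcessor.create(config)       # dispatches on dataset.name
model = BaseModel.create(config, dataset)       # dispatches on model.name
sampler = NegativeSampler.create(config, "negative_sampling", dataset)  # dispatches on negative_sampling.sampling_type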
import argparse

from tkge.task.trainer import TrainTask
from tkge.common.config import Config

desc = 'Temporal KG Completion methods'
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('-config', help='configuration file folder', type=str)

args = parser.parse_args()

config = Config(folder=args.config)
trainer = TrainTask(config)
def __init__(self, config: Config):
    super().__init__(config)

    self._device = config.get("task.device")
    self._train_type = config.get("train.type")

    self._loss = torch.nn.CrossEntropyLoss()
def __init__(self, config: Config):
    super().__init__(config)

    self._device = config.get("task.device")
    self._train_type = config.get("train.type")

    self._loss = torch.nn.BCEWithLogitsLoss()
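# The two trainer variants above differ only in the loss they instantiate,
# which changes the target format the training loop must supply:
# CrossEntropyLoss consumes one class index per example, while
# BCEWithLogitsLoss consumes a 0/1 label per score. A minimal sketch with
# made-up shapes, not tied to the repo's actual batch layout:
import torch

scores = torch.randn(2, 5)     # (batch, num_candidates), arbitrary example values
labels = torch.tensor([1, 3])  # index of the positive candidate per example

# CrossEntropyLoss: targets are class indices of shape (batch,).
ce = torch.nn.CrossEntropyLoss()(scores, labels)

# BCEWithLogitsLoss: targets are 0/1 labels with the same shape as the scores.
one_hot = torch.zeros_like(scores).scatter_(1, labels.unsqueeze(1), 1.0)
bce = torch.nn.BCEWithLogitsLoss()(scores, one_hot)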
version=f"work in progress" ) # parser.add_argument('train', help='task type', type=bool) # parser.add_argument('--config', help='configuration file', type=str) # parser.add_argument('--help', help='help') subparsers = parser.add_subparsers(title="task", description="valid tasks: train, evaluate, predict, search", dest="task") # subparser train parser_train = TrainTask.parse_arguments(subparsers) parser_eval = TestTask.parse_arguments(subparsers) args = parser.parse_args() task_dict = { 'train': TrainTask, 'eval': TestTask } config = Config(folder=args.config, load_default=False) # TODO load_default is false task = task_dict[args.task](config) task.main() # trainer = TrainTask(config) # tester = TestTask(config)