コード例 #1
0
    def __init__(self, config: Config, configuration_key: str,
                 dataset: DatasetProcessor):
        """Read the negative-sampling options and prepare filtering indices.

        Args:
            config: Global configuration. It is also queried directly for
                ``negative_sampling.filtering.split`` and ``train.split``.
            configuration_key: Key passed to the base class. Options read via
                ``get_option``/``check_option`` are presumably resolved
                relative to it.
            dataset: Dataset that provides the per-slot vocabulary sizes and
                on which the filtering indices are built.

        Raises:
            ValueError: if sampling without replacement is requested while
                shared sampling is disabled, or if filtering is enabled
                together with shared sampling.
        """
        super().__init__(config, configuration_key)

        # load config: one entry per slot (4 slots in total)
        self.num_samples = torch.zeros(4, dtype=torch.int)
        self.filter_positives = torch.zeros(4, dtype=torch.bool)
        self.vocabulary_size = torch.zeros(4, dtype=torch.int)
        self.shared = self.get_option("shared")
        self.shared_type = self.check_option("shared_type",
                                             ["naive", "default"])
        self.with_replacement = self.get_option("with_replacement")
        if not self.with_replacement and not self.shared:
            raise ValueError(
                "Without replacement sampling is only supported when "
                "shared negative sampling is enabled.")

        # An empty filtering split falls back to the configured training split.
        self.filtering_split = config.get("negative_sampling.filtering.split")
        if self.filtering_split == "":
            self.filtering_split = config.get("train.split")

        for slot in SLOTS:
            slot_str = SLOT_STR[slot]
            self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}")
            self.filter_positives[slot] = self.get_option(
                f"filtering.{slot_str}")
            # Vocabulary size per slot: relations for P, timestamps for T,
            # entities for every other slot.
            if slot == P:
                self.vocabulary_size[slot] = dataset.num_relations()
            elif slot == T:
                self.vocabulary_size[slot] = dataset.num_timestamps()
            else:
                # TODO edit tkge data
                self.vocabulary_size[slot] = dataset.num_entities()
            # create indices for filtering here already if needed and not existing
            # otherwise every worker would create every index again and again
            if self.filter_positives[slot]:
                # NOTE(review): only three pairs are listed, so enabling
                # filtering for a slot whose index is >= 3 raises IndexError.
                # Confirm whether the time slot needs its own index.
                pair = ["po", "so", "sp"][slot]
                dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}")

        if any(self.filter_positives):
            if self.shared:
                raise ValueError(
                    "Filtering is not supported when shared negative sampling is enabled."
                )

            # Validate the configured implementation against the allowed
            # values, then read it. The value is assigned exactly once.
            self.check_option(
                "filtering.implementation",
                ["standard", "fast", "fast_if_available"])
            self.filter_implementation = self.get_option(
                "filtering.implementation")

        self.dataset = dataset

        # auto config: a negative sample count is replaced by the count of the
        # paired slot if that one is positive, otherwise by 0.
        # NOTE(review): T is not listed here, so a negative count for the
        # time slot is left unchanged. Confirm this is intended.
        for slot, copy_from in [(S, O), (P, None), (O, S)]:
            if self.num_samples[slot] < 0:
                if copy_from is not None and self.num_samples[copy_from] > 0:
                    self.num_samples[slot] = self.num_samples[copy_from]
                else:
                    self.num_samples[slot] = 0
コード例 #2
0
    def create(config: Config):
        """Factory method for dataset creation.

        Looks up the dataset class registered under ``dataset.name`` and
        instantiates it with ``config``.

        Raises:
            ConfigurationError: if no dataset class is registered under the
                configured name.
        """
        ds_type = config.get("dataset.name")

        if ds_type in DatasetProcessor.list_available():
            # TODO: forward "dataset.args" once options use the key format;
            # until then the value is not read.
            return DatasetProcessor.by_name(ds_type)(config)
        else:
            raise ConfigurationError(
                f"{ds_type} specified in configuration file is not supported; "
                "implement your data class with `DatasetProcessor.register(name)`"
            )
コード例 #3
0
    def create(config: Config, dataset: DatasetProcessor):
        """Factory method for negative sampler creation.

        Looks up the sampler class registered under ``negative_sampling.name``
        and instantiates it with ``config``, ``dataset`` and the
        ``negative_sampling.as_matrix`` option.

        Raises:
            ConfigurationError: if no negative sampler is registered under the
                configured name.
        """
        ns_type = config.get("negative_sampling.name")

        if ns_type in NegativeSampler.list_available():
            as_matrix = config.get("negative_sampling.as_matrix")
            return NegativeSampler.by_name(ns_type)(config, dataset, as_matrix)
        else:
            raise ConfigurationError(
                f"{ns_type} specified in configuration file is not supported; "
                "implement your negative sampling class with "
                "`NegativeSampler.register(name)`"
            )
コード例 #4
0
    def create(config: Config, name: str):
        """Factory method for in-place regularizer creation.

        Looks up the class registered under
        ``train.inplace_regularizer.<name>.type`` and instantiates it with
        ``config`` and ``name``.

        Raises:
            ValueError: if no in-place regularizer is registered under the
                configured type.
        """
        reg_type = config.get(f"train.inplace_regularizer.{name}.type")

        if reg_type in InplaceRegularizer.list_available():
            return InplaceRegularizer.by_name(reg_type)(config, name)
        else:
            raise ValueError(
                f"{reg_type} specified in configuration file is not supported; "
                "implement your inplace-regularizer class with "
                "`InplaceRegularizer.register(name)`"
            )
コード例 #5
0
    def create(config: Config, dataset: DatasetProcessor):
        """Factory method for model creation.

        Looks up the model class registered under ``model.name`` and
        instantiates it with ``config`` and ``dataset``.

        Raises:
            ConfigurationError: if no model is registered under the
                configured name.
        """
        model_type = config.get("model.name")

        if model_type in BaseModel.list_available():
            # TODO: forward "model.args" as keyword arguments once options
            # use the key format.
            return BaseModel.by_name(model_type)(config, dataset)
        else:
            raise ConfigurationError(
                f"{model_type} specified in configuration file is not supported; "
                "implement your model class with `BaseModel.register(name)`")
コード例 #6
0
    def create(config: Config, configuration_key: str,
               dataset: DatasetProcessor):
        """Factory method for negative sampler creation.

        Looks up the sampler class registered under
        ``<configuration_key>.sampling_type`` and instantiates it with
        ``config``, ``configuration_key`` and ``dataset``.

        Raises:
            ValueError: if no negative sampler is registered under the
                configured sampling type.
        """
        sampling_type = config.get(configuration_key + ".sampling_type")

        if sampling_type in NegativeSampler.list_available():
            return NegativeSampler.by_name(sampling_type)(config,
                                                          configuration_key,
                                                          dataset)
        else:
            raise ValueError(
                f"{sampling_type} specified in configuration file is not supported; "
                "implement your negative sampler class with "
                "`NegativeSampler.register(name)`"
            )
コード例 #7
0
ファイル: main.py プロジェクト: tkg-framework/TKG-framework
import argparse

from tkge.task.trainer import TrainTask
from tkge.common.config import Config

# Command-line entry point: read the configuration folder from the command
# line and build a training task from it.
desc = "Temporal KG Completion methods"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument("-config", type=str, help="configuration file folder")

args = parser.parse_args()
config = Config(folder=args.config)

# NOTE(review): TrainTask is only constructed here; nothing in this script
# calls a method on it afterwards. Confirm whether construction alone is
# meant to start training.
trainer = TrainTask(config)
コード例 #8
0
    def __init__(self, config: Config):
        """Initialise the base class and read the device and training type.

        The loss used by this class is ``torch.nn.CrossEntropyLoss``.
        """
        super().__init__(config)

        # Settings read from the configuration, in this order.
        self._device, self._train_type = (
            config.get("task.device"),
            config.get("train.type"),
        )
        self._loss = torch.nn.CrossEntropyLoss()
コード例 #9
0
    def __init__(self, config: Config):
        """Initialise the base class and read the device and training type.

        The loss used by this class is ``torch.nn.BCEWithLogitsLoss``.
        """
        super().__init__(config)

        device_setting = config.get("task.device")
        train_type_setting = config.get("train.type")
        self._device = device_setting
        self._train_type = train_type_setting
        self._loss = torch.nn.BCEWithLogitsLoss()
コード例 #10
0
ファイル: tkge.py プロジェクト: tkg-framework/TKG-framework
    version=f"work in progress"
)


subparsers = parser.add_subparsers(title="task",
                                   description="valid tasks: train, eval",
                                   dest="task")

# Register the command-line options of each supported task.
parser_train = TrainTask.parse_arguments(subparsers)
parser_eval = TestTask.parse_arguments(subparsers)

args = parser.parse_args()

# Map each sub-command name to the task class that implements it.
task_dict = {
    'train': TrainTask,
    'eval': TestTask
}

# With no sub-command argparse leaves args.task as None; report a usage error
# instead of letting the lookup below fail with a bare KeyError.
if args.task is None:
    parser.error(f"a task is required; choose one of: {', '.join(task_dict)}")

config = Config(folder=args.config, load_default=False)  # TODO load_default is false
task = task_dict[args.task](config)

task.main()