Example #1
def create_dataloader(triple_source,
                      config,
                      build_label=False,
                      dataset_type=constants.DatasetType.TRAINING):
    """Creates dataloader.
    When training, returns dataloader using perturbation sampler with bernoulli trick for training.
    When validation/testing, returns dataloader with tile support.
    If other configurations are used, such dataloader needs to be created manually.
    See create_dataloader_from_dataset.
    """

    # Using the C++ extensions is fast, but then we can't use the spawn
    # method to start the data loader.
    if dataset_type == constants.DatasetType.TRAINING:
        corruptor = kgedata.BernoulliCorruptor(triple_source.train_set,
                                               triple_source.num_relation,
                                               config.negative_entity,
                                               config.base_seed + SEED_OFFSET)
        negative_sampler = kgedata.PerturbationSampler(
            triple_source.train_set, triple_source.num_entity,
            triple_source.num_relation, config.negative_entity,
            config.negative_relation, config.base_seed + 2 * SEED_OFFSET,
            kgedata.PerturbationSamplerStrategy.Hash)

        transforms = [
            transformers.CorruptionFlagGenerator(corruptor),
            transformers.NegativeBatchGenerator(negative_sampler),
        ]
        if build_label:
            negative_label_generator = kgedata.MemoryLabelGenerator(
                triple_source.train_set)
            positive_label_generator = kgedata.StaticLabelGenerator(True)
            transforms.append(
                transformers.LabelBatchGenerator(config,
                                                 negative_label_generator,
                                                 positive_label_generator))
        else:
            transforms.append(transformers.none_label_batch_generator)
        transforms.append(transformers.tensor_transform)

        dataset = TripleDataset(triple_source.train_set,
                                batch_size=config.batch_size,
                                transform=Compose(transforms))
    else:  # Validation and Test
        batch_size = max(
            _SAFE_MINIMAL_BATCH_SIZE,
            int(config.batch_size * config.evaluation_load_factor))
        transforms = [
            transformers.TripleTileGenerator(config, triple_source),
            transformers.test_batch_transform
        ]
        if dataset_type == constants.DatasetType.VALIDATION:
            triple_set = triple_source.valid_set
        else:
            triple_set = triple_source.test_set

        dataset = TripleDataset(triple_set,
                                batch_size=batch_size,
                                transform=Compose(transforms))

    return create_dataloader_from_dataset(dataset, config)
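
For reference, here is a minimal usage sketch. It assumes the functions above are in scope, that triple_source has already been loaded elsewhere, and that the config object only needs the fields referenced in the code (batch_size, negative_entity, negative_relation, base_seed, evaluation_load_factor); the project's real config class may require more, e.g. for LabelBatchGenerator and create_dataloader_from_dataset.

from types import SimpleNamespace

# Hypothetical config; the field names are taken from the function body above.
config = SimpleNamespace(
    batch_size=256,
    negative_entity=1,
    negative_relation=0,
    base_seed=42,
    evaluation_load_factor=0.1,
)

# triple_source is assumed to be loaded elsewhere from the dataset files.
train_loader = create_dataloader(triple_source, config, build_label=True)
valid_loader = create_dataloader(
    triple_source, config, dataset_type=constants.DatasetType.VALIDATION)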
Example #2
def test_bernoulli_corruption_generator(source, num_corrupts,
                                        small_triple_list):
    corruptor = kgedata.BernoulliCorruptor(source.train_set,
                                           source.num_relation, num_corrupts,
                                           2000)
    np.testing.assert_equal(
        transformers.CorruptionFlagGenerator(corruptor)(small_triple_list)[0],
        np.array([False, False], dtype=bool).reshape((-1, num_corrupts)))
Example #3
def create_cwa_training_dataloader(triple_source, config):
    """Creates the CWA dataloader for training."""
    corruptor = kgedata.StaticCorruptor(config.negative_entity, False)
    sampler = kgedata.CWASampler(triple_source.num_entity,
                                 triple_source.num_relation, False)

    transforms = [
        transformers.CorruptionFlagGenerator(corruptor),
        transformers.NegativeBatchGenerator(sampler),
    ]

    negative_label_generator = kgedata.MemoryLabelGenerator(
        triple_source.train_set)
    # The positive batch gets no labels; this generator always returns None.
    positive_label_generator = lambda x: None
    transforms.append(
        transformers.LabelBatchGenerator(config, negative_label_generator,
                                         positive_label_generator))
    transforms.append(transformers.BatchMasker((True, False, False)))
    transforms.append(transformers.tensor_transform)

    dataset = TripleDataset(triple_source.train_set,
                            batch_size=config.batch_size,
                            transform=Compose(transforms))
    return create_dataloader_from_dataset(dataset, config)
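
For intuition, under the closed-world assumption every candidate triple that is not observed in the training set is treated as a negative. The snippet below is an illustrative NumPy sketch of that labeling idea only; it is not the output format of kgedata.MemoryLabelGenerator or kgedata.CWASampler.

import numpy as np

# Toy training set of (head, relation, tail) index triples.
train_set = {(0, 0, 1), (0, 0, 2), (1, 0, 0)}
num_entity = 4

def cwa_tail_labels(head, relation):
    """Label every candidate tail: 1.0 if the triple is observed, else 0.0."""
    return np.array(
        [1.0 if (head, relation, tail) in train_set else 0.0
         for tail in range(num_entity)],
        dtype=np.float32)

print(cwa_tail_labels(0, 0))  # [0. 1. 1. 0.] -- only the observed tails are positive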
Example #4
def negative_sample_with_negs(small_triple_list, corruptor, negative_sampler):
    sample = transformers.CorruptionFlagGenerator(corruptor)(small_triple_list)
    sample = transformers.NegativeBatchGenerator(negative_sampler)(sample)
    return sample
Example #5
def test_uniform_corruption_flag_generator(corruptor, num_corrupts,
                                           small_triple_list):
    np.testing.assert_equal(
        transformers.CorruptionFlagGenerator(corruptor)(small_triple_list)[0],
        np.array([False, False], dtype=bool).reshape((-1, num_corrupts)))
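
All of the examples above assemble their batch pipeline the same way: a list of callables is chained with Compose, and each transformer consumes the previous transformer's output. Below is a minimal, dependency-free sketch of that pattern; it mirrors torchvision-style Compose semantics and is not the library's actual implementation.

class Compose:
    """Apply a sequence of callables, feeding each output into the next one."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, batch):
        for transform in self.transforms:
            batch = transform(batch)
        return batch


def add_negatives(batch):
    # Stand-in for NegativeBatchGenerator: pair each item with a fake negative.
    return [(item, -item) for item in batch]


def to_tuple(batch):
    # Stand-in for tensor_transform: freeze the batch into a tuple.
    return tuple(batch)


pipeline = Compose([add_negatives, to_tuple])
print(pipeline([1, 2, 3]))  # ((1, -1), (2, -2), (3, -3))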