def create_dataloader(triple_source, config, build_label=False, dataset_type=constants.DatasetType.TRAINING):
    """Create a dataloader for *triple_source*.

    In training mode, returns a dataloader backed by a perturbation
    sampler with the Bernoulli corruption trick. In validation/test
    mode, returns a dataloader with tile support. Any other
    configuration should be assembled manually via
    create_dataloader_from_dataset.
    """
    # NOTE: the kgedata C++ extension is fast, but using it means the
    # data loader workers cannot be started with the spawn method.
    if dataset_type == constants.DatasetType.TRAINING:
        dataset = _training_triple_dataset(triple_source, config, build_label)
    else:
        dataset = _evaluation_triple_dataset(triple_source, config, dataset_type)
    return create_dataloader_from_dataset(dataset, config)


def _training_triple_dataset(triple_source, config, build_label):
    """Build the training TripleDataset: corruption flags, negative
    sampling, optional labels, then tensor conversion."""
    corruptor = kgedata.BernoulliCorruptor(
        triple_source.train_set,
        triple_source.num_relation,
        config.negative_entity,
        config.base_seed + SEED_OFFSET)
    sampler = kgedata.PerturbationSampler(
        triple_source.train_set,
        triple_source.num_entity,
        triple_source.num_relation,
        config.negative_entity,
        config.negative_relation,
        config.base_seed + 2 * SEED_OFFSET,
        kgedata.PerturbationSamplerStrategy.Hash)

    pipeline = [
        transformers.CorruptionFlagGenerator(corruptor),
        transformers.NegativeBatchGenerator(sampler),
    ]
    if build_label:
        # Negative labels come from the training set memory; positive
        # triples are all labeled True statically.
        neg_labels = kgedata.MemoryLabelGenerator(triple_source.train_set)
        pos_labels = kgedata.StaticLabelGenerator(True)
        pipeline.append(
            transformers.LabelBatchGenerator(config, neg_labels, pos_labels))
    else:
        pipeline.append(transformers.none_label_batch_generator)
    pipeline.append(transformers.tensor_transform)

    return TripleDataset(
        triple_source.train_set,
        batch_size=config.batch_size,
        transform=Compose(pipeline))


def _evaluation_triple_dataset(triple_source, config, dataset_type):
    """Build the validation/test TripleDataset with tile support."""
    # Evaluation batches are scaled by the load factor but never drop
    # below the safe minimum.
    batch_size = max(
        _SAFE_MINIMAL_BATCH_SIZE,
        int(config.batch_size * config.evaluation_load_factor))
    pipeline = [
        transformers.TripleTileGenerator(config, triple_source),
        transformers.test_batch_transform,
    ]
    if dataset_type == constants.DatasetType.VALIDATION:
        triples = triple_source.valid_set
    else:
        triples = triple_source.test_set
    return TripleDataset(
        triples,
        batch_size=batch_size,
        transform=Compose(pipeline))
def test_bernoulli_corruption_generator(source, num_corrupts, small_triple_list):
    """Checks that the Bernoulli corruptor marks both sample triples as
    tail-corruption (False) for the given relation statistics.

    Fix: ``np.bool`` was deprecated in NumPy 1.20 and removed in 1.24;
    use the builtin ``bool`` as the dtype instead.
    """
    corruptor = kgedata.BernoulliCorruptor(source.train_set,
                                           source.num_relation, num_corrupts,
                                           2000)
    np.testing.assert_equal(
        transformers.CorruptionFlagGenerator(corruptor)(small_triple_list)[0],
        np.array([False, False], dtype=bool).reshape((-1, num_corrupts)))
def create_cwa_training_dataloader(triple_source, config):
    """Creates the CWA dataloader for training."""

    def _no_positive_labels(batch):
        # CWA training generates no positive-side labels.
        return None

    # Static corruption with a full CWA sweep over entities.
    flagger = kgedata.StaticCorruptor(config.negative_entity, False)
    cwa_sampler = kgedata.CWASampler(triple_source.num_entity,
                                     triple_source.num_relation, False)
    neg_labels = kgedata.MemoryLabelGenerator(triple_source.train_set)

    pipeline = [
        transformers.CorruptionFlagGenerator(flagger),
        transformers.NegativeBatchGenerator(cwa_sampler),
        transformers.LabelBatchGenerator(config, neg_labels,
                                         _no_positive_labels),
        # Mask the head position only.
        transformers.BatchMasker((True, False, False)),
        transformers.tensor_transform,
    ]

    dataset = TripleDataset(triple_source.train_set,
                            batch_size=config.batch_size,
                            transform=Compose(pipeline))
    return create_dataloader_from_dataset(dataset, config)
def negative_sample_with_negs(small_triple_list, corruptor, negative_sampler):
    """Run the corruption-flag and negative-batch transforms in sequence
    over *small_triple_list* and return the resulting sample."""
    flagged = transformers.CorruptionFlagGenerator(corruptor)(small_triple_list)
    return transformers.NegativeBatchGenerator(negative_sampler)(flagged)
def test_uniform_corruption_flag_generator(corruptor, num_corrupts,
                                           small_triple_list):
    """Checks that the uniform corruptor flags both sample triples as
    tail-corruption (False).

    Fix: ``np.bool`` was deprecated in NumPy 1.20 and removed in 1.24;
    use the builtin ``bool`` as the dtype instead.
    """
    np.testing.assert_equal(
        transformers.CorruptionFlagGenerator(corruptor)(small_triple_list)[0],
        np.array([False, False], dtype=bool).reshape((-1, num_corrupts)))