def setUp(self) -> None:
    """Per-test fixture: logging initializer, scratch checkpoint dir, NumPy seed.

    The initializer registers ``setup_logging`` with argument 1 and is invoked
    once in this process. A fresh temporary directory is exposed as
    ``self.checkpoint_path`` and removed at cleanup. NumPy is seeded with a
    random 32-bit value that is logged so failing runs can be reproduced.
    """
    initializer = SubprocessInitializer()
    self.subprocess_init = initializer
    initializer.register(setup_logging, 1)
    initializer()

    scratch_dir = TemporaryDirectory()
    self.checkpoint_path = scratch_dir
    self.addCleanup(scratch_dir.cleanup)

    rng_seed = random.getrandbits(32)
    np.random.seed(rng_seed)
    logger.info(f"Random seed: {rng_seed}")
def main():
    """Command-line entry point: train a PBG model described by a config file.

    ``config`` is the path to the config file. Every ``-p``/``--param``
    occurrence contributes a list of override strings; the lists are
    flattened into one lazy sequence before being handed to the loader.
    ``--rank`` selects this machine's rank for multi-machine training.
    """
    setup_logging()

    schema_help = "\n".join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog="\n\nConfig parameters:\n\n" + schema_help,
        # RawDescriptionHelpFormatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument("config", help="Path to config file")
    arg_parser.add_argument("-p", "--param", action="append", nargs="*")
    arg_parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="For multi-machine, this machine's rank",
    )
    cli_args = arg_parser.parse_args()

    # append + nargs="*" yields a list of lists (or None if the flag is absent).
    overrides = (
        None if cli_args.param is None else chain.from_iterable(cli_args.param)
    )

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(cli_args.config, overrides)
    set_logging_verbosity(config.verbose)

    worker_init = SubprocessInitializer()
    worker_init.register(setup_logging, config.verbose)
    worker_init.register(add_to_sys_path, config_loader.config_dir.name)

    train(config, rank=Rank(cli_args.rank), subprocess_init=worker_init)
def main():
    """Command-line entry point: train a PBG model described by a config file.

    ``config`` is the path to the config file, ``-p``/``--param`` may be
    repeated to supply overrides (passed to the loader as parsed), and
    ``--rank`` (default ``SINGLE_TRAINER``) selects this machine's rank.
    """
    setup_logging()

    schema_help = "\n".join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog="\n\nConfig parameters:\n\n" + schema_help,
        # RawDescriptionHelpFormatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument("config", help="Path to config file")
    arg_parser.add_argument("-p", "--param", action="append", nargs="*")
    arg_parser.add_argument(
        "--rank",
        type=int,
        default=SINGLE_TRAINER,
        help="For multi-machine, this machine's rank",
    )
    cli_args = arg_parser.parse_args()

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(cli_args.config, cli_args.param)
    set_logging_verbosity(config.verbose)

    worker_init = SubprocessInitializer()
    worker_init.register(setup_logging, config.verbose)
    worker_init.register(add_to_sys_path, config_loader.config_dir.name)

    train(config, rank=cli_args.rank, subprocess_init=worker_init)
def main():
    """FB15k example: download the dataset, convert it, train, then evaluate.

    Evaluation runs on the test split with every relation switched to
    ``all_negs=True`` and ``num_uniform_negs=0``. Unless ``--no-filtered``
    is given, ranking is filtered against the test, valid and train edges.
    """
    setup_logging()
    cli = argparse.ArgumentParser(description='Example on FB15k')
    cli.add_argument('--config', default=DEFAULT_CONFIG,
                     help='Path to config file')
    cli.add_argument('-p', '--param', action='append', nargs='*')
    cli.add_argument('--data_dir', type=Path, default='data',
                     help='where to save processed data')
    cli.add_argument('--no-filtered', dest='filtered', action='store_false',
                     help='Run unfiltered eval')
    args = cli.parse_args()

    # Each -p occurrence adds one list; flatten them lazily.
    overrides = None if args.param is None else chain.from_iterable(args.param)

    # Fetch the archive into data_dir and unpack it.
    data_dir = args.data_dir
    archive_path = download_url(FB15K_URL, data_dir)
    extract_tar(archive_path)
    print('Downloaded and extracted file.')

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, loader.config_dir.name)

    raw_edge_files = [data_dir / name for name in FILENAMES]
    train_out, valid_out, test_out = config.edge_paths

    # Input TSV columns: lhs=0, rel=1, rhs=2.
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        raw_edge_files,
        lhs_col=0,
        rhs_col=2,
        rel_col=1,
        dynamic_relations=config.dynamic_relations,
    )

    train(attr.evolve(config, edge_paths=[train_out]),
          subprocess_init=initializer)

    all_neg_relations = [attr.evolve(rel, all_negs=True)
                         for rel in config.relations]
    eval_config = attr.evolve(
        config,
        edge_paths=[test_out],
        relations=all_neg_relations,
        num_uniform_negs=0,
    )

    extra_eval_args = {}
    if args.filtered:
        extra_eval_args['evaluator'] = FilteredRankingEvaluator(
            eval_config, [test_out, valid_out, train_out])
    do_eval(eval_config, subprocess_init=initializer, **extra_eval_args)
def main():
    """LiveJournal example: download, split, convert, train and evaluate.

    The gzipped edge list is downloaded into ``--data_dir``, decompressed and
    randomly split into train/test files. These are converted for PBG; the
    model is trained on the converted train split and evaluated on the
    converted test split.
    """
    setup_logging()
    cli = argparse.ArgumentParser(description='Example on Livejournal')
    cli.add_argument('--config', default=DEFAULT_CONFIG,
                     help='Path to config file')
    cli.add_argument('-p', '--param', action='append', nargs='*')
    cli.add_argument('--data_dir', type=Path, default='data',
                     help='where to save processed data')
    args = cli.parse_args()

    overrides = None
    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten per-flag lists

    # Fetch and decompress the raw edge list.
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    raw_file = extract_gzip(download_url(URL, data_dir))
    print('Downloaded and extracted file.')

    # Randomly split the edge file into train and test parts.
    random_split_file(raw_file)

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, loader.config_dir.name)

    split_files = [data_dir / name for name in FILENAMES.values()]
    # Two-column input (lhs=0, rhs=1) with no relation column.
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        split_files,
        lhs_col=0,
        rhs_col=1,
        rel_col=None,
        dynamic_relations=config.dynamic_relations,
    )

    def converted(split):
        # Converted-edge location for one FILENAMES entry, as a one-item list.
        return [str(convert_path(data_dir / FILENAMES[split]))]

    train(attr.evolve(config, edge_paths=converted('train')),
          subprocess_init=initializer)
    do_eval(attr.evolve(config, edge_paths=converted('test')),
            subprocess_init=initializer)
def main():
    """LiveJournal example: download, split, convert, train and evaluate.

    ``config.edge_paths`` must hold exactly two entries: the converted train
    output and the converted test output. Training uses the first and
    evaluation the second.
    """
    setup_logging()
    cli = argparse.ArgumentParser(description='Example on Livejournal')
    cli.add_argument('--config', default=DEFAULT_CONFIG,
                     help='Path to config file')
    cli.add_argument('-p', '--param', action='append', nargs='*')
    cli.add_argument('--data_dir', type=Path, default='data',
                     help='where to save processed data')
    args = cli.parse_args()

    # Fetch and decompress the raw edge list.
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    raw_file = extract_gzip(download_url(URL, data_dir))
    print('Downloaded and extracted file.')

    # Randomly split the edge file into train and test parts.
    random_split_file(raw_file)

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, loader.config_dir.name)

    raw_splits = [data_dir / name for name in FILENAMES]
    train_out, test_out = config.edge_paths

    # Two-column input (lhs=0, rhs=1) with no relation column.
    edge_reader = TSVEdgelistReader(lhs_col=0, rhs_col=1, rel_col=None)
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        raw_splits,
        edge_reader,
        dynamic_relations=config.dynamic_relations,
    )

    train(attr.evolve(config, edge_paths=[train_out]),
          subprocess_init=initializer)
    do_eval(attr.evolve(config, edge_paths=[test_out]),
            subprocess_init=initializer)
def main():
    """Command-line entry point: evaluate a PBG model described by a config file.

    ``config`` is the path to the config file; ``-p``/``--param`` may be
    repeated to supply overrides, which are passed to the loader as parsed.
    """
    setup_logging()

    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # RawDescriptionHelpFormatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help="Path to config file")
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    cli_args = arg_parser.parse_args()

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(cli_args.config, cli_args.param)
    set_logging_verbosity(config.verbose)

    worker_init = SubprocessInitializer()
    worker_init.register(setup_logging, config.verbose)
    worker_init.register(add_to_sys_path, config_loader.config_dir.name)

    do_eval(config, subprocess_init=worker_init)
# 2. TRANSFORM GRAPH TO A BIGGRAPH-FRIENDLY FORMAT # This step generates the following metadata files: # # data/example_2/entity_count_item_0.txt # data/example_2/entity_count_merchant_0.txt # data/example_2/entity_count_user_0.txt # data/example_2/entity_names_item_0.json # data/example_2/entity_names_merchant_0.json # data/example_2/entity_names_user_0.json # # and this file with data: # data/example_2/edges_partitioned/edges_0_0.h5 # ================================================= setup_logging() config = parse_config(raw_config) subprocess_init = SubprocessInitializer() input_edge_paths = [Path(GRAPH_PATH)] convert_input_data( config.entities, config.relations, config.entity_path, config.edge_paths, input_edge_paths, TSVEdgelistReader(lhs_col=0, rel_col=1, rhs_col=2), dynamic_relations=config.dynamic_relations, ) # =============================================== # 3. TRAIN THE EMBEDDINGS # files generated in this step:
class TestFunctional(TestCase):
    """End-to-end tests: train/evaluate on small generated datasets.

    Each test builds a ConfigSchema, generates a synthetic dataset for it,
    runs training (and usually evaluation) and then checks the files written
    to the checkpoint directory.
    """

    def setUp(self) -> None:
        self.subprocess_init = SubprocessInitializer()
        self.subprocess_init.register(setup_logging, 1)
        self.subprocess_init()
        self.checkpoint_path = TemporaryDirectory()
        self.addCleanup(self.checkpoint_path.cleanup)

        # Log the seed so a failing run can be reproduced.
        seed = random.getrandbits(32)
        np.random.seed(seed)
        logger.info(f"Random seed: {seed}")

    def assertHasMetadata(self, hf: h5py.File, config: ConfigSchema) -> None:
        """Check the format version, stored config and iteration attributes of ``hf``."""
        self.assertEqual(hf.attrs["format_version"], 1)
        self.assertEqual(json.loads(hf.attrs["config/json"]), config.to_dict())
        self.assertCountEqual(
            [
                key.partition("/")[-1]
                for key in hf.attrs.keys()
                if key.startswith("iteration/")
            ],
            [
                "num_epochs",
                "epoch_idx",
                "num_edge_paths",
                "edge_path_idx",
                "edge_path",
                "num_edge_chunks",
                "edge_chunk_idx",
            ],
        )

    def assertIsModelParameter(self, dataset: h5py.Dataset) -> None:
        """Check a dataset carries a ``state_dict_key`` and only finite values."""
        # In fact it could also be a group...
        if not isinstance(dataset, h5py.Dataset):
            return
        self.assertIn("state_dict_key", dataset.attrs)
        self.assertTrue(np.isfinite(dataset[...]).all())

    def assertIsModelParameters(self, group: h5py.Group) -> None:
        """Apply assertIsModelParameter to every item reachable from ``group``."""
        self.assertIsInstance(group, h5py.Group)
        group.visititems(lambda _, d: self.assertIsModelParameter(d))

    def assertIsOptimStateDict(self, dataset: h5py.Dataset) -> None:
        """Check the optimizer state is stored as a 1-D dataset of dtype ``V1``."""
        self.assertIsInstance(dataset, h5py.Dataset)
        self.assertEqual(dataset.dtype, np.dtype("V1"))
        self.assertEqual(len(dataset.shape), 1)

    def assertIsEmbeddings(
        self, dataset: h5py.Dataset, entity_count: int, dimension: int
    ) -> None:
        """Check an ``(entity_count, dimension)`` float32 table, finite, no zero rows."""
        self.assertIsInstance(dataset, h5py.Dataset)
        self.assertEqual(dataset.dtype, np.float32)
        self.assertEqual(dataset.shape, (entity_count, dimension))
        self.assertTrue(np.all(np.isfinite(dataset[...])))
        self.assertTrue(np.all(np.linalg.norm(dataset[...], axis=-1) != 0))

    def assertIsStatsDict(
        self, stats: Mapping[str, Union[int, SerializedStats]]
    ) -> None:
        """Check one line of training_stats.json has only known keys and types."""
        self.assertIsInstance(stats, dict)
        self.assertIn("index", stats)
        for k, v in stats.items():
            if k in (
                "epoch_idx",
                "edge_path_idx",
                "edge_chunk_idx",
                "lhs_partition",
                "rhs_partition",
                "index",
            ):
                self.assertIsInstance(v, int)
            elif k in (
                "stats",
                "eval_stats_before",
                "eval_stats_after",
                "eval_stats_chunk_avg",
            ):
                self.assertIsInstance(v, dict)
                assert isinstance(v, dict)  # narrows the type for checkers
                self.assertCountEqual(v.keys(), ["count", "metrics"])
                self.assertIsInstance(v["count"], int)
                metrics = v["metrics"]
                self.assertIsInstance(metrics, dict)
                assert isinstance(metrics, dict)  # narrows the type for checkers
                for m in metrics.values():
                    self.assertIsInstance(m, float)
            else:
                self.fail(f"Unknown stats key: {k}")

    def assertCheckpointWritten(self, config: ConfigSchema, *, version: int) -> None:
        """Check every file of checkpoint ``version`` exists and is well formed."""
        with open(
            os.path.join(config.checkpoint_path, "checkpoint_version.txt"), "rt"
        ) as tf:
            self.assertEqual(version, int(tf.read().strip()))

        with open(os.path.join(config.checkpoint_path, "config.json"), "rt") as tf:
            self.assertEqual(json.load(tf), config.to_dict())

        with h5py.File(
            os.path.join(config.checkpoint_path, "model.v%d.h5" % version), "r"
        ) as hf:
            self.assertHasMetadata(hf, config)
            self.assertIsModelParameters(hf["model"])
            self.assertIsOptimStateDict(hf["optimizer/state_dict"])

        with open(
            os.path.join(config.checkpoint_path, "training_stats.json"), "rt"
        ) as tf:
            for line in tf:
                self.assertIsStatsDict(json.loads(line))

        for entity_name, entity in config.entities.items():
            for partition in range(entity.num_partitions):
                with open(
                    os.path.join(
                        config.entity_path,
                        "entity_count_%s_%d.txt" % (entity_name, partition),
                    ),
                    "rt",
                ) as tf:
                    entity_count = int(tf.read().strip())
                with h5py.File(
                    os.path.join(
                        config.checkpoint_path,
                        "embeddings_%s_%d.v%d.h5" % (entity_name, partition, version),
                    ),
                    "r",
                ) as hf:
                    self.assertHasMetadata(hf, config)
                    self.assertIsEmbeddings(
                        hf["embeddings"],
                        entity_count,
                        config.entity_dimension(entity_name),
                    )
                    self.assertIsOptimStateDict(hf["optimizer/state_dict"])

    def _make_process(self, name, target):
        """Create (not start) a spawn-context process that runs
        ``self.subprocess_init`` and then ``target``.

        The process's ``terminate`` is registered as a test cleanup.
        """
        process = mp.get_context("spawn").Process(
            name=name,
            target=partial(call_one_after_the_other, self.subprocess_init, target),
        )
        # FIXME In Python 3.7 use kill here.
        self.addCleanup(process.terminate)
        return process

    def _wait_for_success(self, *processes) -> None:
        """Poll once per second until all ``processes`` exit; each must exit 0."""
        done = [False] * len(processes)
        while not all(done):
            time.sleep(1)
            for idx, process in enumerate(processes):
                if not process.is_alive() and not done[idx]:
                    self.assertEqual(process.exitcode, 0)
                    done[idx] = True

    def test_default(self):
        entity_name = "e"
        relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=1)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
            relations=[attr.evolve(relation_config, all_negs=True)],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def test_resume_from_checkpoint(self):
        entity_name = "e"
        relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=1)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            num_epochs=2,
            num_edge_chunks=2,
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.4])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[d.name for d in dataset.relation_paths],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        init_embeddings(train_config.checkpoint_path, train_config, version=7)
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=8)
        # Check we did resume the run, not start the whole thing anew.
        self.assertFalse(
            os.path.exists(os.path.join(train_config.checkpoint_path, "model.v6.h5"))
        )

    def test_with_initial_value(self):
        entity_name = "e"
        relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=1)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
        self.addCleanup(dataset.cleanup)
        init_dir = TemporaryDirectory()
        self.addCleanup(init_dir.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
            init_path=init_dir.name,
        )
        # Just make sure no exceptions are raised and nothing crashes.
        init_embeddings(train_config.init_path, train_config)
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)

    def test_featurized(self):
        e1 = EntitySchema(num_partitions=1, featurized=True)
        e2 = EntitySchema(num_partitions=1)
        r1 = RelationSchema(name="r1", lhs="e1", rhs="e2")
        r2 = RelationSchema(name="r2", lhs="e2", rhs="e1")
        base_config = ConfigSchema(
            dimension=10,
            relations=[r1, r2],
            entities={"e1": e1, "e2": e2},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def test_partitioned(self):
        e1 = EntitySchema(num_partitions=1)
        e2 = EntitySchema(num_partitions=2)
        e3 = EntitySchema(num_partitions=3)
        r1 = RelationSchema(name="r1", lhs="e1", rhs="e3")
        r2 = RelationSchema(name="r2", lhs="e2", rhs="e3")
        r3 = RelationSchema(name="r3", lhs="e2", rhs="e1")
        base_config = ConfigSchema(
            dimension=10,
            relations=[r1, r2, r3],
            entities={"e1": e1, "e2": e2, "e3": e3},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    @unittest.skipIf(not torch.cuda.is_available(), "No GPU")
    def test_gpu(self):
        self._test_gpu()

    @unittest.skipIf(not torch.cuda.is_available(), "No GPU")
    def test_gpu_half(self):
        self._test_gpu(do_half_precision=True)

    @unittest.skipIf(not torch.cuda.is_available(), "No GPU")
    def test_gpu_1partition(self):
        self._test_gpu(num_partitions=1)

    def _test_gpu(self, do_half_precision=False, num_partitions=2):
        entity_name = "e"
        relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
        base_config = ConfigSchema(
            dimension=16,
            batch_size=1024,
            num_batch_negs=64,
            num_uniform_negs=64,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=num_partitions)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
            num_gpus=2,
            half_precision=do_half_precision,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
            relations=[attr.evolve(relation_config, all_negs=True)],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train_gpu(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def _test_distributed(self, num_partitions):
        sync_path = TemporaryDirectory()
        self.addCleanup(sync_path.cleanup)
        e1 = "e1"
        e2 = "e2"
        relation_config = RelationSchema(
            name="r",
            lhs=e1,
            rhs=e2,
            operator="linear",  # To exercise the parameter server.
        )
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={
                e1: EntitySchema(num_partitions=num_partitions),
                e2: EntitySchema(num_partitions=4),
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            num_machines=2,
            distributed_init_method="file://%s" % os.path.join(sync_path.name, "sync"),
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        trainer0 = self._make_process(
            "Trainer-0",
            partial(train, train_config, rank=0, subprocess_init=self.subprocess_init),
        )
        trainer1 = self._make_process(
            "Trainer-1",
            partial(train, train_config, rank=1, subprocess_init=self.subprocess_init),
        )
        trainer0.start()
        trainer1.start()
        self._wait_for_success(trainer0, trainer1)
        self.assertCheckpointWritten(train_config, version=1)

    def test_distributed(self):
        self._test_distributed(num_partitions=4)

    def test_distributed_unpartitioned(self):
        self._test_distributed(num_partitions=1)

    def test_distributed_with_partition_servers(self):
        sync_path = TemporaryDirectory()
        self.addCleanup(sync_path.cleanup)
        entity_name = "e"
        relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=4)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            num_machines=2,
            num_partition_servers=2,
            distributed_init_method="file://%s" % os.path.join(sync_path.name, "sync"),
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        trainer0 = self._make_process(
            "Trainer-0",
            partial(train, train_config, rank=0, subprocess_init=self.subprocess_init),
        )
        trainer1 = self._make_process(
            "Trainer-1",
            partial(train, train_config, rank=1, subprocess_init=self.subprocess_init),
        )
        partition_server0 = self._make_process(
            "PartitionServer-0",
            partial(
                run_partition_server,
                train_config,
                rank=0,
                subprocess_init=self.subprocess_init,
            ),
        )
        partition_server1 = self._make_process(
            "PartitionServer-1",
            partial(
                run_partition_server,
                train_config,
                rank=1,
                subprocess_init=self.subprocess_init,
            ),
        )
        trainer0.start()
        trainer1.start()
        partition_server0.start()
        partition_server1.start()
        self._wait_for_success(trainer0, trainer1)
        partition_server0.join()
        partition_server1.join()
        logger.info(
            f"Partition server 0 died with exit code {partition_server0.exitcode}"
        )
        # Previously mislabelled as "Partition server 0".
        logger.info(
            f"Partition server 1 died with exit code {partition_server1.exitcode}"
        )
        self.assertCheckpointWritten(train_config, version=1)

    def test_dynamic_relations(self):
        relation_config = RelationSchema(name="r", lhs="el", rhs="er")
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={
                "el": EntitySchema(num_partitions=1),
                "er": EntitySchema(num_partitions=1),
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            dynamic_relations=True,
            global_emb=False,  # Must be off for dynamic relations.
            workers=2,
        )
        gen_config = attr.evolve(
            base_config,
            relations=[relation_config] * 10,
            dynamic_relations=False,  # Must be off if more than 1 relation.
        )
        dataset = generate_dataset(gen_config, num_entities=100, fractions=[0.04, 0.02])
        self.addCleanup(dataset.cleanup)
        with open(
            os.path.join(dataset.entity_path.name, "dynamic_rel_count.txt"), "xt"
        ) as f:
            f.write("%d" % len(gen_config.relations))
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            relations=[attr.evolve(relation_config, all_negs=True)],
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def test_entity_dimensions(self):
        entity_name = "e"
        relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=1, dimension=8)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
            relations=[attr.evolve(relation_config, all_negs=True)],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)
def main():
    """FB15k example: download the dataset, convert it, train, then evaluate.

    ``config.edge_paths`` must hold exactly the train, valid and test outputs
    (in that order). Evaluation runs on the test split with every relation
    switched to ``all_negs=True`` and ``num_uniform_negs=0``. Unless
    ``--no-filtered`` is passed, ranking is filtered against the test, valid
    and train edges.
    """
    setup_logging()
    cli = argparse.ArgumentParser(description="Example on FB15k")
    cli.add_argument("--config", default=DEFAULT_CONFIG, help="Path to config file")
    cli.add_argument("-p", "--param", action="append", nargs="*")
    cli.add_argument(
        "--data_dir", type=Path, default="data", help="where to save processed data"
    )
    cli.add_argument(
        "--no-filtered",
        dest="filtered",
        action="store_false",
        help="Run unfiltered eval",
    )
    args = cli.parse_args()

    # Fetch the archive into data_dir and unpack it.
    data_dir = args.data_dir
    archive_path = download_url(FB15K_URL, data_dir)
    extract_tar(archive_path)
    print("Downloaded and extracted file.")

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, loader.config_dir.name)

    raw_edge_files = [data_dir / name for name in FILENAMES]
    train_out, valid_out, test_out = config.edge_paths

    # Input TSV columns: lhs=0, rel=1, rhs=2.
    edge_reader = TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1)
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        raw_edge_files,
        edge_reader,
        dynamic_relations=config.dynamic_relations,
    )

    train(attr.evolve(config, edge_paths=[train_out]), subprocess_init=initializer)

    all_neg_relations = [attr.evolve(rel, all_negs=True) for rel in config.relations]
    eval_config = attr.evolve(
        config,
        edge_paths=[test_out],
        relations=all_neg_relations,
        num_uniform_negs=0,
    )

    extra_eval_args = {}
    if args.filtered:
        extra_eval_args["evaluator"] = FilteredRankingEvaluator(
            eval_config, [test_out, valid_out, train_out]
        )
    do_eval(eval_config, subprocess_init=initializer, **extra_eval_args)
def run(input_file: KGTKFiles,
        output_file: KGTKFiles,
        verbose: bool = False,
        very_verbose: bool = False,
        **kwargs):
    """Compute graph embeddings for a KGTK edge file with PyTorch-BigGraph.

    The pipeline has these steps:
      0. Write the input as a temporary TSV.
      1. Build the PBG config.
      2. Convert the TSV to PBG's on-disk format.
      3. Train.
      4. Export the entity embeddings and write them to ``output_file`` in
         the requested format.
      5. Optionally delete the temporary data.

    Args:
        input_file: KGTK input file specification.
        output_file: KGTK output file specification.
        verbose: forwarded to the temporary-TSV writer and the KGTK output writer.
        very_verbose: forwarded to the temporary-TSV writer and the KGTK output
            writer.
        **kwargs: every other user-supplied option. Read directly here:
            'log_file_path', 'temporary_directory', 'output_format',
            'retain_temporary_data', and optionally 'errors_to_stdout' and
            'output_no_header'. The whole dict is also passed to
            KgtkReaderOptions/KgtkValueOptions.from_dict and get_config.

    Raises:
        KGTKException: wrapping (and chained to) any exception raised by the
            pipeline.
    """
    # Import modules locally.
    import sys
    import typing
    import os
    import logging
    from pathlib import Path
    import shutil
    from torchbiggraph.config import parse_config

    from kgtk.exceptions import KGTKException
    # copy missing file under kgtk/graph_embeddings
    from kgtk.templates.kgtkcopytemplate import KgtkCopyTemplate
    from kgtk.graph_embeddings.importers import TSVEdgelistReader, convert_input_data
    from torchbiggraph.train import train
    from torchbiggraph.util import SubprocessInitializer, setup_logging
    from kgtk.graph_embeddings.export_to_tsv import make_tsv

    try:
        input_kgtk_file: Path = KGTKArgumentParser.get_input_file(input_file)
        output_kgtk_file: Path = KGTKArgumentParser.get_output_file(output_file)

        # When a log file is requested, route logging there (DEBUG level,
        # file truncated on open) and tell the user where to look.
        log_file_path = kwargs['log_file_path']
        if log_file_path is not None:
            logging.basicConfig(
                # One literal: the old backslash-newline continuation inside the
                # string embedded the next line's indentation in every record.
                format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
                level=logging.DEBUG,
                filename=str(log_file_path),
                filemode='w')
            print(
                f'In Processing, Please go to {log_file_path} to check details',
                file=sys.stderr,
                flush=True)

        tmp_folder = kwargs['temporary_directory']
        tmp_tsv_path: Path = tmp_folder / f'tmp_{input_kgtk_file.name}'
        # Make sure the tmp folder exists before anything is written into it.
        os.makedirs(tmp_folder, exist_ok=True)

        # Best-effort removal of a previous output file; it is rewritten below.
        try:
            output_kgtk_file.unlink()
        except OSError:
            pass  # e.g. FileNotFoundError: nothing to delete

        # *********************************************
        # 0. PREPARE PBG TSV FILE
        # *********************************************
        reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(kwargs)
        value_options: KgtkValueOptions = KgtkValueOptions.from_dict(kwargs)
        error_file: typing.TextIO = (
            sys.stdout if kwargs.get("errors_to_stdout") else sys.stderr)
        kct: KgtkCopyTemplate = KgtkCreateTmpTsv(
            input_file_path=input_kgtk_file,
            output_file_path=tmp_tsv_path,
            reader_options=reader_options,
            value_options=value_options,
            error_file=error_file,
            verbose=verbose,
            very_verbose=very_verbose,
        )
        # Write the temporary TSV that PBG will import.
        logging.info('Generate the valid tsv format for embedding ...')
        kct.process()
        logging.info('Embedding file is ready...')

        # *********************************************
        # 1. DEFINE CONFIG
        # *********************************************
        raw_config = get_config(**kwargs)
        # Set the learning rate and loss function for the chosen algorithm.
        processed_config = config_preprocess(raw_config)

        # PBG's entity_path is also where the exported TSVs are written (step 4).
        tmp_output_folder = Path(processed_config['entity_path'])
        # Clear leftovers from a previous run; a missing folder is not an error.
        shutil.rmtree(tmp_output_folder, ignore_errors=True)

        # **************************************************
        # 2. TRANSFORM GRAPH TO A BIGGRAPH-FRIENDLY FORMAT
        # **************************************************
        setup_logging()
        config = parse_config(processed_config)
        subprocess_init = SubprocessInitializer()
        input_edge_paths = [tmp_tsv_path]

        # The temporary TSV has columns lhs=0, rel=1, rhs=2.
        convert_input_data(
            config.entities,
            config.relations,
            config.entity_path,
            config.edge_paths,
            input_edge_paths,
            TSVEdgelistReader(lhs_col=0, rel_col=1, rhs_col=2),
            dynamic_relations=config.dynamic_relations,
        )

        # ************************************************
        # 3. TRAIN THE EMBEDDINGS
        # ************************************************
        train(config, subprocess_init=subprocess_init)

        # ************************************************
        # 4. GENERATE THE OUTPUT
        # ************************************************
        entities_output = tmp_output_folder / 'entities_output.tsv'
        relation_types_output = tmp_output_folder / 'relation_types_tf.tsv'

        # "xt": exclusive create; the folder was cleared in step 1.
        with open(entities_output, "xt") as entities_tf, \
                open(relation_types_output, "xt") as relation_types_tf:
            make_tsv(config, entities_tf, relation_types_tf)

        output_format = kwargs['output_format']
        if output_format == 'glove':
            # The exported TSV is already the glove layout; copy it verbatim.
            shutil.copyfile(entities_output, output_kgtk_file)
        elif output_format == 'w2v':
            generate_w2v_output(entities_output, output_kgtk_file, kwargs)
        else:
            # Default: KGTK edge-file format.
            generate_kgtk_output(entities_output, output_kgtk_file,
                                 kwargs.get('output_no_header', False),
                                 verbose, very_verbose)
        logging.info(f'Embeddings has been generated in {output_kgtk_file}.')

        # ************************************************
        # 5. Garbage collection
        # ************************************************
        if not kwargs['retain_temporary_data']:
            shutil.rmtree(kwargs['temporary_directory'])

        if log_file_path is not None:
            print('Processed Finished.', file=sys.stderr, flush=True)
            logging.info(
                f"Process Finished.\nOutput has been saved in {repr(str(output_kgtk_file))}"
            )
        else:
            print(
                f"Process Finished.\nOutput has been saved in {repr(str(output_kgtk_file))}",
                file=sys.stderr,
                flush=True)
    except Exception as e:
        # Keep the original exception as __cause__ for debugging.
        raise KGTKException(str(e)) from e