def test_featurized(self):
    e1 = EntitySchema(num_partitions=1, featurized=True)
    e2 = EntitySchema(num_partitions=1)
    r1 = RelationSchema(name="r1", lhs="e1", rhs="e2")
    r2 = RelationSchema(name="r2", lhs="e2", rhs="e1")
    base_config = ConfigSchema(
        dimension=10,
        relations=[r1, r2],
        entities={"e1": e1, "e2": e2},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        workers=2,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    eval_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[1].name],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    train(train_config, rank=0, subprocess_init=self.subprocess_init)
    self.assertCheckpointWritten(train_config, version=1)
    do_eval(eval_config, subprocess_init=self.subprocess_init)
def test_entity_dimensions(self):
    entity_name = "e"
    relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=1, dimension=8)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        workers=2,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    eval_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[1].name],
        relations=[attr.evolve(relation_config, all_negs=True)],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    train(train_config, rank=0, subprocess_init=self.subprocess_init)
    self.assertCheckpointWritten(train_config, version=1)
    do_eval(eval_config, subprocess_init=self.subprocess_init)
def test_resume_from_checkpoint(self):
    entity_name = "e"
    relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=1)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        num_epochs=2,
        num_edge_chunks=2,
        workers=2,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.4])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[d.name for d in dataset.relation_paths],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    init_embeddings(train_config.checkpoint_path, train_config, version=7)
    train(train_config, rank=0, subprocess_init=self.subprocess_init)
    self.assertCheckpointWritten(train_config, version=8)
    # Check we did resume the run, not start the whole thing anew.
    self.assertFalse(
        os.path.exists(os.path.join(train_config.checkpoint_path, "model.v6.h5"))
    )
def test_with_initial_value(self):
    entity_name = "e"
    relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=1)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        workers=2,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
    self.addCleanup(dataset.cleanup)
    init_dir = TemporaryDirectory()
    self.addCleanup(init_dir.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
        init_path=init_dir.name,
    )
    # Just make sure no exceptions are raised and nothing crashes.
    init_embeddings(train_config.init_path, train_config)
    train(train_config, rank=0, subprocess_init=self.subprocess_init)
    self.assertCheckpointWritten(train_config, version=1)
def validate_config(
    config: str,
) -> Tuple[Dict[str, EntitySchema], List[RelationSchema], str, bool]:
    user_config = get_config_dict_from_module(config)

    # Validate the entities and relations config.
    entities_config = user_config.get("entities")
    relations_config = user_config.get("relations")
    entity_path = user_config.get("entity_path")
    dynamic_relations = user_config.get("dynamic_relations", False)
    if not isinstance(entities_config, dict):
        raise TypeError("Config entities is not of type dict")
    if not isinstance(relations_config, list):
        raise TypeError("Config relations is not of type list")
    if not isinstance(entity_path, str):
        raise TypeError("Config entity_path is not of type str")
    if not isinstance(dynamic_relations, bool):
        raise TypeError("Config dynamic_relations is not of type bool")

    entities = {}
    relations = []
    for entity, entity_config in entities_config.items():
        entities[entity] = EntitySchema.from_dict(entity_config)
    for relation in relations_config:
        relations.append(RelationSchema.from_dict(relation))
    return entities, relations, entity_path, dynamic_relations
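# A hedged sketch of how validate_config might be driven, assuming that
# get_config_dict_from_module() imports the named Python module and returns
# its module-level configuration as a dict; the file name and layout below
# are hypothetical, not taken from the library.
#
#     # my_config.py
#     entities = {"e": {"num_partitions": 1}}
#     relations = [{"name": "r", "lhs": "e", "rhs": "e"}]
#     entity_path = "data/example"
#
#     entities, relations, entity_path, dynamic_relations = validate_config(
#         "my_config.py"
#     )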
def _test_gpu(self, do_half_precision=False, num_partitions=2):
    entity_name = "e"
    relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
    base_config = ConfigSchema(
        dimension=16,
        batch_size=1024,
        num_batch_negs=64,
        num_uniform_negs=64,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=num_partitions)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        workers=2,
        num_gpus=2,
        regularization_coef=1e-4,
        half_precision=do_half_precision,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    eval_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[1].name],
        relations=[attr.evolve(relation_config, all_negs=True)],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    train(train_config, rank=0, subprocess_init=self.subprocess_init)
    self.assertCheckpointWritten(train_config, version=1)
    do_eval(eval_config, subprocess_init=self.subprocess_init)
def parse_config_partial(
    config_dict: Any,
) -> Tuple[Dict[str, EntitySchema], List[RelationSchema], str, List[str], bool]:
    entities_config = config_dict.get("entities")
    relations_config = config_dict.get("relations")
    entity_path = config_dict.get("entity_path")
    edge_paths = config_dict.get("edge_paths")
    dynamic_relations = config_dict.get("dynamic_relations", False)
    if not isinstance(entities_config, dict):
        raise TypeError("Config entities is not of type dict")
    if any(not isinstance(k, str) for k in entities_config.keys()):
        raise TypeError("Config entities has some keys that are not of type str")
    if not isinstance(relations_config, list):
        raise TypeError("Config relations is not of type list")
    if not isinstance(entity_path, str):
        raise TypeError("Config entity_path is not of type str")
    if not isinstance(edge_paths, list):
        raise TypeError("Config edge_paths is not of type list")
    if any(not isinstance(p, str) for p in edge_paths):
        raise TypeError("Config edge_paths has some items that are not of type str")
    if not isinstance(dynamic_relations, bool):
        raise TypeError("Config dynamic_relations is not of type bool")

    entities: Dict[str, EntitySchema] = {}
    relations: List[RelationSchema] = []
    for entity, entity_config in entities_config.items():
        entities[entity] = EntitySchema.from_dict(entity_config)
    for relation in relations_config:
        relations.append(RelationSchema.from_dict(relation))
    return entities, relations, entity_path, edge_paths, dynamic_relations
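# A minimal usage sketch for parse_config_partial with an in-memory config
# dict. The dict-shaped entity and relation values are assumed to match what
# EntitySchema.from_dict and RelationSchema.from_dict accept (as in the tests
# above); the paths are placeholders.
example_config = {
    "entities": {"e": {"num_partitions": 1}},
    "relations": [{"name": "r", "lhs": "e", "rhs": "e"}],
    "entity_path": "data/example",
    "edge_paths": ["data/example/edges"],
    # "dynamic_relations" is optional and defaults to False.
}
entities, relations, entity_path, edge_paths, dynamic_relations = (
    parse_config_partial(example_config)
)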
def test_dynamic_relations(self):
    relation_config = RelationSchema(name="r", lhs="el", rhs="er")
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={
            "el": EntitySchema(num_partitions=1),
            "er": EntitySchema(num_partitions=1),
        },
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        dynamic_relations=True,
        global_emb=False,  # Must be off for dynamic relations.
        workers=2,
    )
    gen_config = attr.evolve(
        base_config,
        relations=[relation_config] * 10,
        dynamic_relations=False,  # Must be off if more than 1 relation.
    )
    dataset = generate_dataset(gen_config, num_entities=100, fractions=[0.04, 0.02])
    self.addCleanup(dataset.cleanup)
    with open(
        os.path.join(dataset.entity_path.name, "dynamic_rel_count.txt"), "xt"
    ) as f:
        f.write("%d" % len(gen_config.relations))
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    eval_config = attr.evolve(
        base_config,
        relations=[attr.evolve(relation_config, all_negs=True)],
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[1].name],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    train(train_config, rank=0, subprocess_init=self.subprocess_init)
    self.assertCheckpointWritten(train_config, version=1)
    do_eval(eval_config, subprocess_init=self.subprocess_init)
def test_distributed(self):
    sync_path = TemporaryDirectory()
    self.addCleanup(sync_path.cleanup)
    entity_name = "e"
    relation_config = RelationSchema(
        name="r",
        lhs=entity_name,
        rhs=entity_name,
        operator="linear",  # To exercise the parameter server.
    )
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=4)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        num_machines=2,
        distributed_init_method="file://%s" % os.path.join(sync_path.name, "sync"),
        workers=2,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    trainer0 = mp.get_context("spawn").Process(
        name="trainer#0", target=train, args=(train_config,), kwargs={"rank": 0}
    )
    trainer1 = mp.get_context("spawn").Process(
        name="trainer#1", target=train, args=(train_config,), kwargs={"rank": 1}
    )
    # FIXME In Python 3.7 use kill here.
    self.addCleanup(trainer0.terminate)
    self.addCleanup(trainer1.terminate)
    trainer0.start()
    trainer1.start()
    done = [False, False]
    while not all(done):
        time.sleep(1)
        if not trainer0.is_alive() and not done[0]:
            self.assertEqual(trainer0.exitcode, 0)
            done[0] = True
        if not trainer1.is_alive() and not done[1]:
            self.assertEqual(trainer1.exitcode, 0)
            done[1] = True
    self.assertCheckpointWritten(train_config, version=1)
def test_basic(self):
    config = ConfigSchema(
        entities={"e": EntitySchema(num_partitions=1)},
        relations=[RelationSchema(name="r", lhs="e", rhs="e")],
        dimension=1,
        entity_path="foo",
        edge_paths=["bar"],
        checkpoint_path="baz",
    )
    metadata = ConfigMetadataProvider(config).get_checkpoint_metadata()
    self.assertIsInstance(metadata, dict)
    self.assertCountEqual(metadata.keys(), ["config/json"])
    self.assertEqual(
        config, ConfigSchema.from_dict(json.loads(metadata["config/json"]))
    )
def test_distributed_with_partition_servers(self):
    sync_path = TemporaryDirectory()
    self.addCleanup(sync_path.cleanup)
    entity_name = "e"
    relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=4)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        num_machines=2,
        num_partition_servers=2,
        distributed_init_method="file://%s" % os.path.join(sync_path.name, "sync"),
        workers=2,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    trainer0 = mp.get_context("spawn").Process(
        name="Trainer-0",
        target=partial(
            call_one_after_the_other,
            self.subprocess_init,
            partial(train, train_config, rank=0, subprocess_init=self.subprocess_init),
        ),
    )
    trainer1 = mp.get_context("spawn").Process(
        name="Trainer-1",
        target=partial(
            call_one_after_the_other,
            self.subprocess_init,
            partial(train, train_config, rank=1, subprocess_init=self.subprocess_init),
        ),
    )
    partition_server0 = mp.get_context("spawn").Process(
        name="PartitionServer-0",
        target=partial(
            call_one_after_the_other,
            self.subprocess_init,
            partial(
                run_partition_server,
                train_config,
                rank=0,
                subprocess_init=self.subprocess_init,
            ),
        ),
    )
    partition_server1 = mp.get_context("spawn").Process(
        name="PartitionServer-1",
        target=partial(
            call_one_after_the_other,
            self.subprocess_init,
            partial(
                run_partition_server,
                train_config,
                rank=1,
                subprocess_init=self.subprocess_init,
            ),
        ),
    )
    # FIXME In Python 3.7 use kill here.
    self.addCleanup(trainer0.terminate)
    self.addCleanup(trainer1.terminate)
    self.addCleanup(partition_server0.terminate)
    self.addCleanup(partition_server1.terminate)
    trainer0.start()
    trainer1.start()
    partition_server0.start()
    partition_server1.start()
    done = [False, False]
    while not all(done):
        time.sleep(1)
        if not trainer0.is_alive() and not done[0]:
            self.assertEqual(trainer0.exitcode, 0)
            done[0] = True
        if not trainer1.is_alive() and not done[1]:
            self.assertEqual(trainer1.exitcode, 0)
            done[1] = True
    partition_server0.join()
    partition_server1.join()
    logger.info(f"Partition server 0 died with exit code {partition_server0.exitcode}")
    logger.info(f"Partition server 1 died with exit code {partition_server1.exitcode}")
    self.assertCheckpointWritten(train_config, version=1)
def test_distributed_with_partition_servers(self):
    sync_path = TemporaryDirectory()
    self.addCleanup(sync_path.cleanup)
    entity_name = "e"
    relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=4)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        num_machines=2,
        num_partition_servers=1,
        distributed_init_method="file://%s" % os.path.join(sync_path.name, "sync"),
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    trainer0 = Process(
        name="trainer#0", target=train, args=(train_config,), kwargs={"rank": 0}
    )
    trainer1 = Process(
        name="trainer#1", target=train, args=(train_config,), kwargs={"rank": 1}
    )
    partition_server = Process(
        name="partition server#0",
        target=run_partition_server,
        args=(train_config,),
        kwargs={"rank": 0},
    )
    # FIXME In Python 3.7 use kill here.
    self.addCleanup(trainer0.terminate)
    self.addCleanup(trainer1.terminate)
    self.addCleanup(partition_server.terminate)
    trainer0.start()
    trainer1.start()
    partition_server.start()
    done = [False, False]
    while not all(done):
        time.sleep(1)
        if not trainer0.is_alive() and not done[0]:
            self.assertEqual(trainer0.exitcode, 0)
            done[0] = True
        if not trainer1.is_alive() and not done[1]:
            self.assertEqual(trainer1.exitcode, 0)
            done[1] = True
        if not partition_server.is_alive():
            self.fail(
                "Partition server died with exit code %d" % partition_server.exitcode
            )
    partition_server.terminate()  # Cannot be shut down gracefully.
    partition_server.join()
    logging.info(
        "Partition server died with exit code %d", partition_server.exitcode
    )
    self.assertCheckpointWritten(train_config, version=1)
def _test_distributed_with_partition_servers(
    self,
    num_trainers,
    num_partition_servers,
    num_groups_per_sharded_partition_server=1,
):
    sync_path = TemporaryDirectory()
    self.addCleanup(sync_path.cleanup)
    entity_name = "e"
    relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
    base_config = ConfigSchema(
        dimension=10,
        relations=[relation_config],
        entities={entity_name: EntitySchema(num_partitions=2 * num_trainers)},
        entity_path=None,  # filled in later
        edge_paths=[],  # filled in later
        checkpoint_path=self.checkpoint_path.name,
        num_machines=num_trainers,
        num_partition_servers=num_partition_servers,
        distributed_init_method="file://%s" % os.path.join(sync_path.name, "sync"),
        workers=2,
        num_groups_per_sharded_partition_server=num_groups_per_sharded_partition_server,
    )
    dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4])
    self.addCleanup(dataset.cleanup)
    train_config = attr.evolve(
        base_config,
        entity_path=dataset.entity_path.name,
        edge_paths=[dataset.relation_paths[0].name],
    )
    # Just make sure no exceptions are raised and nothing crashes.
    trainers = []
    for rank in range(num_trainers):
        trainers.append(
            mp.get_context("spawn").Process(
                name=f"Trainer-{rank}",
                target=partial(
                    call_one_after_the_other,
                    self.subprocess_init,
                    partial(
                        train,
                        train_config,
                        rank=rank,
                        subprocess_init=self.subprocess_init,
                    ),
                ),
            )
        )
    partition_servers = []
    for rank in range(num_partition_servers):
        partition_servers.append(
            mp.get_context("spawn").Process(
                name=f"PartitionServer-{rank}",
                target=partial(
                    call_one_after_the_other,
                    self.subprocess_init,
                    partial(
                        run_partition_server,
                        train_config,
                        rank=rank,
                        subprocess_init=self.subprocess_init,
                    ),
                ),
            )
        )
    # FIXME In Python 3.7 use kill here.
    for proc in trainers + partition_servers:
        self.addCleanup(proc.terminate)
    for proc in trainers + partition_servers:
        proc.start()
    done = [False] * num_trainers
    while not all(done):
        time.sleep(1)
        for rank, trainer in enumerate(trainers):
            if not trainer.is_alive() and not done[rank]:
                self.assertEqual(trainer.exitcode, 0)
                done[rank] = True
    for partition_server in partition_servers:
        partition_server.join()
    for rank, partition_server in enumerate(partition_servers):
        logger.info(
            f"Partition server {rank} died with exit code {partition_server.exitcode}"
        )
    self.assertCheckpointWritten(train_config, version=1)
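# A hypothetical pair of concrete tests built on the helper above, sketching
# how it might be parameterized; the method names and argument combinations
# are illustrative assumptions, not taken from the original suite.
def test_distributed_with_one_partition_server(self):
    self._test_distributed_with_partition_servers(
        num_trainers=2, num_partition_servers=1
    )

def test_distributed_with_sharded_partition_servers(self):
    self._test_distributed_with_partition_servers(
        num_trainers=2,
        num_partition_servers=2,
        num_groups_per_sharded_partition_server=2,
    )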