Пример #1
0
 def test_featurized(self):
     """Train, then evaluate, a config in which entity "e1" is featurized."""
     entity_schemas = {
         "e1": EntitySchema(num_partitions=1, featurized=True),
         "e2": EntitySchema(num_partitions=1),
     }
     relation_schemas = [
         RelationSchema(name="r1", lhs="e1", rhs="e2"),
         RelationSchema(name="r2", lhs="e2", rhs="e1"),
     ]
     # The data paths are placeholders here; the real ones only exist once
     # the dataset has been generated below.
     base_config = ConfigSchema(
         dimension=10,
         relations=relation_schemas,
         entities=entity_schemas,
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )

     dataset = generate_dataset(
         base_config, num_entities=100, fractions=[0.4, 0.2]
     )
     self.addCleanup(dataset.cleanup)
     entity_dir = dataset.entity_path.name

     # First edge set is used for training, second one for evaluation.
     train_config = attr.evolve(
         base_config,
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[1].name],
     )

     # Smoke test: both phases must complete without raising.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Пример #2
0
 def test_entity_dimensions(self):
     """Train and evaluate with an entity dimension (8) set per entity.

     The config-level ``dimension`` is 10 while the only entity type
     declares ``dimension=8``.
     """
     entity_name = "e"
     relation_config = RelationSchema(
         name="r", lhs=entity_name, rhs=entity_name
     )
     entity_schema = EntitySchema(num_partitions=1, dimension=8)
     # entity_path / edge_paths are placeholders until the dataset exists.
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: entity_schema},
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )

     dataset = generate_dataset(
         base_config, num_entities=100, fractions=[0.4, 0.2]
     )
     self.addCleanup(dataset.cleanup)
     entity_dir = dataset.entity_path.name

     train_config = attr.evolve(
         base_config,
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[0].name],
     )
     # Evaluation uses the second edge set and switches all_negs on.
     eval_relation = attr.evolve(relation_config, all_negs=True)
     eval_config = attr.evolve(
         base_config,
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[1].name],
         relations=[eval_relation],
     )

     # Smoke test: both phases must complete without raising.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Пример #3
0
 def test_resume_from_checkpoint(self):
     """Seed the checkpoint dir at version 7 and check training continues.

     After training, a version-8 checkpoint must exist and no
     ``model.v6.h5`` file may have been produced.
     """
     entity_name = "e"
     relation_config = RelationSchema(
         name="r", lhs=entity_name, rhs=entity_name
     )
     # entity_path / edge_paths are placeholders until the dataset exists.
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=1)},
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         num_epochs=2,
         num_edge_chunks=2,
         workers=2,
     )

     dataset = generate_dataset(
         base_config, num_entities=100, fractions=[0.4, 0.4]
     )
     self.addCleanup(dataset.cleanup)

     # Train on every generated edge set.
     all_edge_paths = [path.name for path in dataset.relation_paths]
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=all_edge_paths,
     )

     # Smoke test: nothing below may raise.
     init_embeddings(train_config.checkpoint_path, train_config, version=7)
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=8)

     # A fresh run would have gone through version 6 on its way to 8; its
     # absence shows the run was resumed instead of restarted.
     stale_model = os.path.join(train_config.checkpoint_path, "model.v6.h5")
     self.assertFalse(os.path.exists(stale_model))
Пример #4
0
 def test_with_initial_value(self):
     """Train with ``init_path`` pointing at pre-initialized embeddings."""
     entity_name = "e"
     relation_config = RelationSchema(
         name="r", lhs=entity_name, rhs=entity_name
     )
     # entity_path / edge_paths are placeholders until the dataset exists.
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=1)},
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )

     dataset = generate_dataset(
         base_config, num_entities=100, fractions=[0.4]
     )
     self.addCleanup(dataset.cleanup)

     # Separate temporary directory that holds the initial embeddings.
     init_dir = TemporaryDirectory()
     self.addCleanup(init_dir.cleanup)

     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
         init_path=init_dir.name,
     )

     # Smoke test: nothing below may raise.
     init_embeddings(train_config.init_path, train_config)
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
Пример #5
0
def validate_config(
    config: str,
) -> Tuple[Dict[str, EntitySchema], List[RelationSchema], str, bool]:
    """Load a config and return its type-checked entity/relation settings.

    Args:
        config: forwarded to ``get_config_dict_from_module`` to obtain the
            raw config dict.

    Returns:
        A tuple ``(entities, relations, entity_path, dynamic_relations)``;
        the entities and relations have been converted with
        ``EntitySchema.from_dict`` and ``RelationSchema.from_dict``.

    Raises:
        TypeError: if ``entities`` is not a dict, ``relations`` is not a
            list, ``entity_path`` is not a str, or ``dynamic_relations``
            (default ``False``) is not a bool.
    """
    raw_config = get_config_dict_from_module(config)

    # Pull out the entities and relations settings that need validating.
    entities_config = raw_config.get("entities")
    relations_config = raw_config.get("relations")
    entity_path = raw_config.get("entity_path")
    dynamic_relations = raw_config.get("dynamic_relations", False)

    # Checked in this order so the first offending field is the one reported.
    type_checks = (
        (entities_config, dict, "Config entities is not of type dict"),
        (relations_config, list, "Config relations is not of type list"),
        (entity_path, str, "Config entity_path is not of type str"),
        (
            dynamic_relations,
            bool,
            "Config dynamic_relations is not of type bool",
        ),
    )
    for value, expected_type, message in type_checks:
        if not isinstance(value, expected_type):
            raise TypeError(message)

    entities = {
        name: EntitySchema.from_dict(schema_dict)
        for name, schema_dict in entities_config.items()
    }
    relations = [
        RelationSchema.from_dict(schema_dict)
        for schema_dict in relations_config
    ]
    return entities, relations, entity_path, dynamic_relations
 def _test_gpu(self, do_half_precision=False, num_partitions=2):
     """Train and evaluate with ``num_gpus=2``.

     Args:
         do_half_precision: forwarded as ``half_precision`` in the config.
         num_partitions: number of partitions of the single entity type.
     """
     entity_name = "e"
     relation_config = RelationSchema(
         name="r",
         lhs=entity_name,
         rhs=entity_name,
     )
     entity_schema = EntitySchema(num_partitions=num_partitions)
     # entity_path / edge_paths are placeholders until the dataset exists.
     base_config = ConfigSchema(
         dimension=16,
         batch_size=1024,
         num_batch_negs=64,
         num_uniform_negs=64,
         relations=[relation_config],
         entities={entity_name: entity_schema},
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
         num_gpus=2,
         regularization_coef=1e-4,
         half_precision=do_half_precision,
     )

     dataset = generate_dataset(
         base_config,
         num_entities=100,
         fractions=[0.4, 0.2],
     )
     self.addCleanup(dataset.cleanup)
     entity_dir = dataset.entity_path.name

     train_config = attr.evolve(
         base_config,
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[0].name],
     )
     # Evaluation uses the second edge set and switches all_negs on.
     eval_relation = attr.evolve(relation_config, all_negs=True)
     eval_config = attr.evolve(
         base_config,
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[1].name],
         relations=[eval_relation],
     )

     # Smoke test: both phases must complete without raising.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Пример #7
0
def parse_config_partial(
    config_dict: Any,
) -> Tuple[Dict[str, EntitySchema], List[RelationSchema], str, List[str], bool]:
    """Type-check and extract the data-related fields of a raw config dict.

    Args:
        config_dict: mapping-like object; read through its ``get`` method.

    Returns:
        A tuple ``(entities, relations, entity_path, edge_paths,
        dynamic_relations)``; the entities and relations have been converted
        with ``EntitySchema.from_dict`` and ``RelationSchema.from_dict``.

    Raises:
        TypeError: if ``entities`` is not a dict with str keys,
            ``relations`` is not a list, ``entity_path`` is not a str,
            ``edge_paths`` is not a list of str, or ``dynamic_relations``
            (default ``False``) is not a bool.
    """

    def require_type(value: Any, expected_type: type, message: str) -> None:
        # Shared guard for the plain "is it an instance of X" checks.
        if not isinstance(value, expected_type):
            raise TypeError(message)

    entities_config = config_dict.get("entities")
    relations_config = config_dict.get("relations")
    entity_path = config_dict.get("entity_path")
    edge_paths = config_dict.get("edge_paths")
    dynamic_relations = config_dict.get("dynamic_relations", False)

    # The checks run in a fixed order so that the first offending field is
    # the one reported.
    require_type(entities_config, dict, "Config entities is not of type dict")
    if not all(isinstance(key, str) for key in entities_config.keys()):
        raise TypeError("Config entities has some keys that are not of type str")
    require_type(relations_config, list, "Config relations is not of type list")
    require_type(entity_path, str, "Config entity_path is not of type str")
    require_type(edge_paths, list, "Config edge_paths is not of type list")
    if not all(isinstance(path, str) for path in edge_paths):
        raise TypeError("Config edge_paths has some items that are not of type str")
    require_type(
        dynamic_relations, bool, "Config dynamic_relations is not of type bool"
    )

    entities: Dict[str, EntitySchema] = {
        name: EntitySchema.from_dict(schema_dict)
        for name, schema_dict in entities_config.items()
    }
    relations: List[RelationSchema] = [
        RelationSchema.from_dict(schema_dict)
        for schema_dict in relations_config
    ]
    return entities, relations, entity_path, edge_paths, dynamic_relations
Пример #8
0
 def test_dynamic_relations(self):
     """Train and evaluate with ``dynamic_relations=True``.

     The dataset is generated from a config with ten copies of the
     relation and dynamic relations turned off; the relation count is then
     written to ``dynamic_rel_count.txt`` in the entity directory.
     """
     relation_config = RelationSchema(name="r", lhs="el", rhs="er")
     entity_schemas = {
         "el": EntitySchema(num_partitions=1),
         "er": EntitySchema(num_partitions=1),
     }
     # entity_path / edge_paths are placeholders until the dataset exists.
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities=entity_schemas,
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         dynamic_relations=True,
         # global_emb must be off for dynamic relations.
         global_emb=False,
         workers=2,
     )
     # dynamic_relations must be off when there is more than one relation.
     gen_config = attr.evolve(
         base_config,
         relations=[relation_config] * 10,
         dynamic_relations=False,
     )

     dataset = generate_dataset(
         gen_config, num_entities=100, fractions=[0.04, 0.02]
     )
     self.addCleanup(dataset.cleanup)
     entity_dir = dataset.entity_path.name

     # "x" mode: creating the file fails if it already exists.
     count_path = os.path.join(entity_dir, "dynamic_rel_count.txt")
     with open(count_path, "xt") as count_file:
         count_file.write("%d" % len(gen_config.relations))

     train_config = attr.evolve(
         base_config,
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_relation = attr.evolve(relation_config, all_negs=True)
     eval_config = attr.evolve(
         base_config,
         relations=[eval_relation],
         entity_path=entity_dir,
         edge_paths=[dataset.relation_paths[1].name],
     )

     # Smoke test: both phases must complete without raising.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Пример #9
0
 def test_distributed(self):
     """Run two trainer processes (ranks 0 and 1) and expect exit code 0.

     The config uses ``num_machines=2`` and a ``file://``
     ``distributed_init_method`` located in a temporary directory.
     """
     sync_path = TemporaryDirectory()
     self.addCleanup(sync_path.cleanup)
     entity_name = "e"
     # The "linear" operator is there to exercise the parameter server.
     relation_config = RelationSchema(
         name="r",
         lhs=entity_name,
         rhs=entity_name,
         operator="linear",
     )
     init_method = "file://%s" % os.path.join(sync_path.name, "sync")
     # entity_path / edge_paths are placeholders until the dataset exists.
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=4)},
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         num_machines=2,
         distributed_init_method=init_method,
         workers=2,
     )

     dataset = generate_dataset(
         base_config, num_entities=100, fractions=[0.4]
     )
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )

     # Smoke test: nothing below may raise or crash.
     trainers = [
         mp.get_context("spawn").Process(
             name=process_name,
             target=train,
             args=(train_config, ),
             kwargs={"rank": rank},
         )
         for rank, process_name in enumerate(("trainer#0", "trainer#1"))
     ]
     # FIXME In Python 3.7 use kill here.
     for trainer in trainers:
         self.addCleanup(trainer.terminate)
     for trainer in trainers:
         trainer.start()

     # Poll once per second; each trainer's exit code is checked exactly
     # once, at the first poll that finds it no longer alive.
     finished = [False, False]
     while not all(finished):
         time.sleep(1)
         for rank, trainer in enumerate(trainers):
             if trainer.is_alive() or finished[rank]:
                 continue
             self.assertEqual(trainer.exitcode, 0)
             finished[rank] = True

     self.assertCheckpointWritten(train_config, version=1)
Пример #10
0
 def test_basic(self):
     """Checkpoint metadata is a dict whose only key is "config/json".

     Parsing that JSON back through ``ConfigSchema.from_dict`` must give
     a config equal to the original one.
     """
     entity_schema = EntitySchema(num_partitions=1)
     relation_schema = RelationSchema(name="r", lhs="e", rhs="e")
     config = ConfigSchema(
         entities={"e": entity_schema},
         relations=[relation_schema],
         dimension=1,
         entity_path="foo",
         edge_paths=["bar"],
         checkpoint_path="baz",
     )

     provider = ConfigMetadataProvider(config)
     metadata = provider.get_checkpoint_metadata()

     self.assertIsInstance(metadata, dict)
     self.assertCountEqual(metadata.keys(), ["config/json"])
     round_tripped = ConfigSchema.from_dict(
         json.loads(metadata["config/json"])
     )
     self.assertEqual(config, round_tripped)
Пример #11
0
 def test_distributed_with_partition_servers(self):
     """Run two trainers and two partition servers in spawned processes.

     Every process target is ``call_one_after_the_other`` bound to
     ``self.subprocess_init`` and to the real entry point (``train`` or
     ``run_partition_server``) for the given rank. The test requires both
     trainers to exit with code 0, waits for the partition servers to
     finish, logs their exit codes, and finally asserts that a version-1
     checkpoint was written.
     """
     num_trainers = 2
     num_partition_servers = 2

     sync_path = TemporaryDirectory()
     self.addCleanup(sync_path.cleanup)
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=4)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         num_machines=num_trainers,
         num_partition_servers=num_partition_servers,
         distributed_init_method="file://%s" %
         os.path.join(sync_path.name, "sync"),
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )

     def make_process(name, entry_point, rank):
         # One spawned process per (entry point, rank) pair; the four
         # originally hand-written definitions only differed in these.
         return mp.get_context("spawn").Process(
             name=name,
             target=partial(
                 call_one_after_the_other,
                 self.subprocess_init,
                 partial(
                     entry_point,
                     train_config,
                     rank=rank,
                     subprocess_init=self.subprocess_init,
                 ),
             ),
         )

     # Just make sure no exceptions are raised and nothing crashes.
     trainers = [
         make_process(f"Trainer-{rank}", train, rank)
         for rank in range(num_trainers)
     ]
     partition_servers = [
         make_process(f"PartitionServer-{rank}", run_partition_server, rank)
         for rank in range(num_partition_servers)
     ]
     all_processes = trainers + partition_servers

     # FIXME In Python 3.7 use kill here.
     for process in all_processes:
         self.addCleanup(process.terminate)
     for process in all_processes:
         process.start()

     # Poll once per second; each trainer's exit code is checked exactly
     # once, at the first poll that finds it no longer alive.
     done = [False] * num_trainers
     while not all(done):
         time.sleep(1)
         for rank, trainer in enumerate(trainers):
             if not trainer.is_alive() and not done[rank]:
                 self.assertEqual(trainer.exitcode, 0)
                 done[rank] = True

     for partition_server in partition_servers:
         partition_server.join()

     # Log each server under its own rank (the second message used to be
     # mislabelled as "Partition server 0").
     for rank, partition_server in enumerate(partition_servers):
         logger.info(
             f"Partition server {rank} died with exit code {partition_server.exitcode}"
         )

     self.assertCheckpointWritten(train_config, version=1)
Пример #12
0
 def test_distributed_with_partition_servers(self):
     """Run two trainers plus one partition server as separate processes.

     Both trainers must exit with code 0. The test fails if the partition
     server is found dead while polling; after the loop the server is
     terminated explicitly and its exit code is logged.
     """
     sync_path = TemporaryDirectory()
     self.addCleanup(sync_path.cleanup)
     entity_name = "e"
     relation_config = RelationSchema(
         name="r", lhs=entity_name, rhs=entity_name
     )
     init_method = "file://%s" % os.path.join(sync_path.name, "sync")
     # entity_path / edge_paths are placeholders until the dataset exists.
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=4)},
         entity_path=None,
         edge_paths=[],
         checkpoint_path=self.checkpoint_path.name,
         num_machines=2,
         num_partition_servers=1,
         distributed_init_method=init_method,
     )

     dataset = generate_dataset(
         base_config, num_entities=100, fractions=[0.4]
     )
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )

     # Smoke test: nothing below may raise or crash.
     trainers = [
         Process(
             name=process_name,
             target=train,
             args=(train_config, ),
             kwargs={"rank": rank},
         )
         for rank, process_name in enumerate(("trainer#0", "trainer#1"))
     ]
     partition_server = Process(
         name="partition server#0",
         target=run_partition_server,
         args=(train_config, ),
         kwargs={"rank": 0},
     )
     all_processes = trainers + [partition_server]

     # FIXME In Python 3.7 use kill here.
     for process in all_processes:
         self.addCleanup(process.terminate)
     for process in all_processes:
         process.start()

     # Poll once per second; each trainer's exit code is checked exactly
     # once, at the first poll that finds it no longer alive.
     finished = [False, False]
     while not all(finished):
         time.sleep(1)
         for rank, trainer in enumerate(trainers):
             if trainer.is_alive() or finished[rank]:
                 continue
             self.assertEqual(trainer.exitcode, 0)
             finished[rank] = True
         if not partition_server.is_alive():
             self.fail("Partition server died with exit code %d" %
                       partition_server.exitcode)

     # The server cannot be shut down gracefully, so terminate it.
     partition_server.terminate()
     partition_server.join()
     logging.info("Partition server died with exit code %d",
                  partition_server.exitcode)
     self.assertCheckpointWritten(train_config, version=1)
    def _test_distributed_with_partition_servers(
        self,
        num_trainers,
        num_partition_servers,
        num_groups_per_sharded_partition_server=1,
    ):
        """Run trainers and partition servers in spawned processes.

        Args:
            num_trainers: number of trainer processes; also used as
                ``num_machines``, and the entity type gets twice as many
                partitions.
            num_partition_servers: number of partition server processes.
            num_groups_per_sharded_partition_server: forwarded unchanged to
                the config field of the same name.

        Every trainer must exit with code 0. The partition servers are then
        joined and their exit codes logged, and a version-1 checkpoint is
        asserted at the end.
        """
        sync_path = TemporaryDirectory()
        self.addCleanup(sync_path.cleanup)
        entity_name = "e"
        relation_config = RelationSchema(
            name="r", lhs=entity_name, rhs=entity_name
        )
        entity_schema = EntitySchema(num_partitions=2 * num_trainers)
        init_method = "file://%s" % os.path.join(sync_path.name, "sync")
        # entity_path / edge_paths are placeholders until the dataset exists.
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: entity_schema},
            entity_path=None,
            edge_paths=[],
            checkpoint_path=self.checkpoint_path.name,
            num_machines=num_trainers,
            num_partition_servers=num_partition_servers,
            distributed_init_method=init_method,
            workers=2,
            num_groups_per_sharded_partition_server=(
                num_groups_per_sharded_partition_server
            ),
        )

        dataset = generate_dataset(
            base_config, num_entities=100, fractions=[0.4]
        )
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )

        def make_process(name, entry_point, rank):
            # Target is call_one_after_the_other bound to the subprocess
            # initializer and to the entry point for this rank.
            return mp.get_context("spawn").Process(
                name=name,
                target=partial(
                    call_one_after_the_other,
                    self.subprocess_init,
                    partial(
                        entry_point,
                        train_config,
                        rank=rank,
                        subprocess_init=self.subprocess_init,
                    ),
                ),
            )

        # Smoke test: nothing below may raise or crash.
        trainers = [
            make_process(f"Trainer-{rank}", train, rank)
            for rank in range(num_trainers)
        ]
        partition_servers = [
            make_process(f"PartitionServer-{rank}", run_partition_server, rank)
            for rank in range(num_partition_servers)
        ]
        all_processes = trainers + partition_servers

        # FIXME In Python 3.7 use kill here.
        for process in all_processes:
            self.addCleanup(process.terminate)
        for process in all_processes:
            process.start()

        # Poll once per second; each trainer's exit code is checked exactly
        # once, at the first poll that finds it no longer alive.
        finished = [False] * num_trainers
        while not all(finished):
            time.sleep(1)
            for rank, trainer in enumerate(trainers):
                if trainer.is_alive() or finished[rank]:
                    continue
                self.assertEqual(trainer.exitcode, 0)
                finished[rank] = True

        # Wait for every server before reporting any exit code.
        for partition_server in partition_servers:
            partition_server.join()
        for rank, partition_server in enumerate(partition_servers):
            logger.info(
                f"Partition server {rank} died with exit code {partition_server.exitcode}"
            )

        self.assertCheckpointWritten(train_config, version=1)