Example no. 1
def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
    if config.dynamic_relations:
        if len(config.relations) != 1:
            raise RuntimeError(
                "Dynamic relations are enabled, so there should only be one "
                "entry in config.relations with config for all relations.")
        try:
            relation_type_storage = RELATION_TYPE_STORAGES.make_instance(
                config.entity_path)
            num_dynamic_rels = relation_type_storage.load_count()
        except CouldNotLoadData:
            raise RuntimeError(
                "Dynamic relations are enabled, so there should be a file called "
                "dynamic_rel_count.txt in the entity path with their count.")
    else:
        num_dynamic_rels = 0

    if config.num_batch_negs > 0 and config.batch_size % config.num_batch_negs != 0:
        raise RuntimeError(
            "Batch size (%d) must be a multiple of num_batch_negs (%d)" %
            (config.batch_size, config.num_batch_negs))

    lhs_operators: List[Optional[Union[AbstractOperator,
                                       AbstractDynamicOperator]]] = []
    rhs_operators: List[Optional[Union[AbstractOperator,
                                       AbstractDynamicOperator]]] = []
    for r in config.relations:
        lhs_operators.append(
            instantiate_operator(r.operator, Side.LHS, num_dynamic_rels,
                                 config.entity_dimension(r.lhs)))
        rhs_operators.append(
            instantiate_operator(r.operator, Side.RHS, num_dynamic_rels,
                                 config.entity_dimension(r.rhs)))

    comparator_class = COMPARATORS.get_class(config.comparator)
    comparator = comparator_class()

    if config.bias:
        comparator = BiasedComparator(comparator)

    return MultiRelationEmbedder(
        config.dimension,
        config.relations,
        config.entities,
        num_uniform_negs=config.num_uniform_negs,
        num_batch_negs=config.num_batch_negs,
        disable_lhs_negs=config.disable_lhs_negs,
        disable_rhs_negs=config.disable_rhs_negs,
        lhs_operators=lhs_operators,
        rhs_operators=rhs_operators,
        comparator=comparator,
        global_emb=config.global_emb,
        max_norm=config.max_norm,
        num_dynamic_rels=num_dynamic_rels,
        half_precision=config.half_precision,
    )
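
A minimal usage sketch for the factory above, assuming the config is loaded through ConfigFileLoader as in the command-line entry points further down; the import paths and the "my_config.py" path are illustrative assumptions, not taken from the snippets.

# Assumed module layout; adjust the imports to wherever these names live.
from torchbiggraph.config import ConfigFileLoader
from torchbiggraph.model import make_model

loader = ConfigFileLoader()
config = loader.load_config("my_config.py", None)  # hypothetical path, no -p overrides
model = make_model(config)
model.share_memory()  # train_and_report_stats does the same before forking workers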
Example no. 2
 def test_basic(self):
     config = ConfigSchema(
         entities={"e": EntitySchema(num_partitions=1)},
         relations=[RelationSchema(name="r", lhs="e", rhs="e")],
         dimension=1,
         entity_path="foo", edge_paths=["bar"], checkpoint_path="baz")
     metadata = ConfigMetadataProvider(config).get_checkpoint_metadata()
     self.assertIsInstance(metadata, dict)
     self.assertCountEqual(metadata.keys(), ["config/json"])
     self.assertEqual(
         config, ConfigSchema.from_dict(json.loads(metadata["config/json"])))
Example no. 3
    def assertCheckpointWritten(self, config: ConfigSchema, *,
                                version: int) -> None:
        with open(
                os.path.join(config.checkpoint_path, "checkpoint_version.txt"),
                "rt") as tf:
            self.assertEqual(version, int(tf.read().strip()))

        with open(os.path.join(config.checkpoint_path, "config.json"),
                  "rt") as tf:
            self.assertEqual(json.load(tf), config.to_dict())

        with h5py.File(
                os.path.join(config.checkpoint_path, "model.v%d.h5" % version),
                "r") as hf:
            self.assertHasMetadata(hf, config)
            self.assertIsModelParameters(hf["model"])
            self.assertIsOptimStateDict(hf["optimizer/state_dict"])

        with open(os.path.join(config.checkpoint_path, "training_stats.json"),
                  "rt") as tf:
            for line in tf:
                self.assertIsStatsDict(json.loads(line))

        for entity_name, entity in config.entities.items():
            for partition in range(entity.num_partitions):
                with open(
                        os.path.join(
                            config.entity_path,
                            "entity_count_%s_%d.txt" %
                            (entity_name, partition),
                        ),
                        "rt",
                ) as tf:
                    entity_count = int(tf.read().strip())
                with h5py.File(
                        os.path.join(
                            config.checkpoint_path,
                            "embeddings_%s_%d.v%d.h5" %
                            (entity_name, partition, version),
                        ),
                        "r",
                ) as hf:
                    self.assertHasMetadata(hf, config)
                    self.assertIsEmbeddings(
                        hf["embeddings"],
                        entity_count,
                        config.entity_dimension(entity_name),
                    )
                    self.assertIsOptimStateDict(hf["optimizer/state_dict"])
Example no. 4
 def _test_gpu(self, do_half_precision=False, num_partitions=2):
     entity_name = "e"
     relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
     base_config = ConfigSchema(
         dimension=16,
         batch_size=1024,
         num_batch_negs=64,
         num_uniform_negs=64,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=num_partitions)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
         num_gpus=2,
         regularization_coef=1e-4,
         half_precision=do_half_precision,
     )
     dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
         relations=[attr.evolve(relation_config, all_negs=True)],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example no. 5
def init_embeddings(target: str, config: ConfigSchema, *, version: int = 0):
    with open(os.path.join(target, "checkpoint_version.txt"), "xt") as tf:
        tf.write("%d" % version)
    for entity_name, entity in config.entities.items():
        for partition in range(entity.num_partitions):
            with open(
                    os.path.join(
                        config.entity_path,
                        "entity_count_%s_%d.txt" % (entity_name, partition),
                    ),
                    "rt",
            ) as tf:
                entity_count = int(tf.read().strip())
            with h5py.File(
                    os.path.join(
                        target,
                        "embeddings_%s_%d.v%d.h5" %
                        (entity_name, partition, version),
                    ),
                    "x",
            ) as hf:
                hf.attrs["format_version"] = 1
                hf.create_dataset(
                    "embeddings",
                    data=np.random.randn(entity_count,
                                         config.entity_dimension(entity_name)),
                )
    with h5py.File(os.path.join(target, "model.v%d.h5" % version), "x") as hf:
        hf.attrs["format_version"] = 1
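
A small companion sketch, under the same assumptions, of reading back one of the files written above with h5py; the entity name "e" and the default partition/version are placeholders matching the test configs.

import os

import h5py


def read_embeddings(target: str, entity_name: str = "e",
                    partition: int = 0, version: int = 0):
    # Mirrors the layout written by init_embeddings: one HDF5 file per
    # (entity, partition) pair with an "embeddings" dataset inside.
    path = os.path.join(
        target, "embeddings_%s_%d.v%d.h5" % (entity_name, partition, version))
    with h5py.File(path, "r") as hf:
        assert hf.attrs["format_version"] == 1
        return hf["embeddings"][...]  # full (entity_count, dimension) array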
Example no. 6
def main():
    setup_logging()
    config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("config", help="Path to config file")
    parser.add_argument("-p", "--param", action="append", nargs="*")
    parser.add_argument(
        "--rank",
        type=int,
        default=SINGLE_TRAINER,
        help="For multi-machine, this machine's rank",
    )
    opt = parser.parse_args()

    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, opt.param)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    train(config, rank=opt.rank, subprocess_init=subprocess_init)
Example no. 7
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--entities-output', required=True)
    parser.add_argument('--relation-types-output', required=True)
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None
    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, overrides)

    with open(opt.entities_output, "xt") as entities_tf, \
            open(opt.relation_types_output, "xt") as relation_types_tf:
        make_tsv(
            config,
            entities_tf,
            relation_types_tf,
        )
Example no. 8
 def test_resume_from_checkpoint(self):
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=1)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         num_epochs=2,
         num_edge_chunks=2,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4, 0.4])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[d.name for d in dataset.relation_paths],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     init_embeddings(train_config.checkpoint_path, train_config, version=7)
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=8)
     # Check we did resume the run, not start the whole thing anew.
     self.assertFalse(
         os.path.exists(
             os.path.join(train_config.checkpoint_path, "model.v6.h5")))
Example no. 9
 def test_with_initial_value(self):
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=1)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4])
     self.addCleanup(dataset.cleanup)
     init_dir = TemporaryDirectory()
     self.addCleanup(init_dir.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
         init_path=init_dir.name,
     )
     # Just make sure no exceptions are raised and nothing crashes.
     init_embeddings(train_config.init_path, train_config)
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
Example no. 10
 def test_featurized(self):
     e1 = EntitySchema(num_partitions=1, featurized=True)
     e2 = EntitySchema(num_partitions=1)
     r1 = RelationSchema(name="r1", lhs="e1", rhs="e2")
     r2 = RelationSchema(name="r2", lhs="e2", rhs="e1")
     base_config = ConfigSchema(
         dimension=10,
         relations=[r1, r2],
         entities={
             "e1": e1,
             "e2": e2
         },
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4, 0.2])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example no. 11
 def test_entity_dimensions(self):
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={
             entity_name: EntitySchema(num_partitions=1, dimension=8)
         },
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4, 0.2])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
         relations=[attr.evolve(relation_config, all_negs=True)],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example no. 12
 def write_new_version(
     self,
     config: ConfigSchema,
     entity_counts: Dict[EntityName, List[int]],
     embedding_storage_freelist: Dict[EntityName, Set[torch.FloatStorage]],
 ) -> None:
     metadata = self.collect_metadata()
     new_version = self._version(True)
     if self.partition_client is not None:
         for entity, econf in config.entities.items():
             dimension = config.entity_dimension(entity)
             for part in range(self.rank, econf.num_partitions,
                               self.num_machines):
                 logger.debug(f"Getting {entity} {part}")
                 count = entity_counts[entity][part]
                 s = next(iter(embedding_storage_freelist[entity]))
                 out = torch.FloatTensor(s).view(-1, dimension)[:count]
                 embs, serialized_optim_state = self.partition_client.get(
                     entity, part, out=out)
                 logger.debug(f"Done getting {entity} {part}")
                 logger.debug(f"Saving {entity} {part} v{new_version}")
                 self.storage.save_entity_partition(
                     new_version,
                     entity,
                     part,
                     embs,
                     serialized_optim_state,
                     metadata,
                 )
                 logger.debug(f"Done saving {entity} {part} v{new_version}")
Example no. 13
def main():
    setup_logging()
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--rank',
                        type=int,
                        default=0,
                        help="For multi-machine, this machine's rank")
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None
    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, overrides)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    train(config, rank=Rank(opt.rank), subprocess_init=subprocess_init)
Example no. 14
    def __init__(
        self,
        config: ConfigSchema,
        model: Optional[MultiRelationEmbedder] = None,
        trainer: Optional[AbstractBatchProcessor] = None,
        evaluator: Optional[AbstractBatchProcessor] = None,
        rank: Rank = SINGLE_TRAINER,
        subprocess_init: Optional[Callable[[], None]] = None,
        stats_handler: StatsHandler = NOOP_STATS_HANDLER,
    ):

        super().__init__(
            config, model, trainer, evaluator, rank, subprocess_init, stats_handler
        )

        assert config.num_gpus > 0
        if not CPP_INSTALLED:
            raise RuntimeError(
                "GPU support requires C++ installation: "
                "install with C++ support by running "
                "`PBG_INSTALL_CPP=1 pip install .`"
            )

        if config.half_precision:
            for entity in config.entities:
                # need this for tensor cores to work
                assert config.entity_dimension(entity) % 8 == 0
            assert config.batch_size % 8 == 0
            assert config.num_batch_negs % 8 == 0
            assert config.num_uniform_negs % 8 == 0

        assert len(self.holder.lhs_unpartitioned_types) == 0
        assert len(self.holder.rhs_unpartitioned_types) == 0

        num_edge_chunks = self.iteration_manager.num_edge_chunks
        max_edges = 0
        for edge_path in config.edge_paths:
            edge_storage = EDGE_STORAGES.make_instance(edge_path)
            for lhs_part in range(self.holder.nparts_lhs):
                for rhs_part in range(self.holder.nparts_rhs):
                    num_edges = edge_storage.get_number_of_edges(lhs_part, rhs_part)
                    num_edges_per_chunk = div_roundup(num_edges, num_edge_chunks)
                    max_edges = max(max_edges, num_edges_per_chunk)
        self.shared_lhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
        self.shared_rhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
        self.shared_rel = allocate_shared_tensor((max_edges,), dtype=torch.long)

        # fork early for HOGWILD threads
        logger.info("Creating GPU workers...")
        torch.set_num_threads(1)
        self.gpu_pool = GPUProcessPool(
            config.num_gpus,
            subprocess_init,
            {s for ss in self.embedding_storage_freelist.values() for s in ss}
            | {
                self.shared_lhs.storage(),
                self.shared_rhs.storage(),
                self.shared_rel.storage(),
            },
        )
Example no. 15
def main():
    config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("config", help="Path to config file")
    parser.add_argument("-p", "--param", action="append", nargs="*")
    parser.add_argument("edge_paths",
                        type=Path,
                        nargs="*",
                        help="Input file paths")
    parser.add_argument(
        "-l",
        "--lhs-col",
        type=str,
        required=True,
        help="Column index for source entity",
    )
    parser.add_argument(
        "-r",
        "--rhs-col",
        type=str,
        required=True,
        help="Column index for target entity",
    )
    parser.add_argument("--rel-col",
                        type=str,
                        help="Column index for relation entity")
    parser.add_argument(
        "--relation-type-min-count",
        type=int,
        default=1,
        help="Min count for relation types",
    )
    parser.add_argument("--entity-min-count",
                        type=int,
                        default=1,
                        help="Min count for entities")
    opt = parser.parse_args()

    loader = ConfigFileLoader()
    config_dict = loader.load_raw_config(opt.config, opt.param)

    entity_configs, relation_configs, entity_path, edge_paths, dynamic_relations = parse_config_partial(  # noqa
        config_dict)

    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        edge_paths,
        opt.edge_paths,
        ParquetEdgelistReader(opt.lhs_col, opt.rhs_col, opt.rel_col),
        opt.entity_min_count,
        opt.relation_type_min_count,
        dynamic_relations,
    )
Example no. 16
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('edge_paths',
                        type=Path,
                        nargs='*',
                        help='Input file paths')
    parser.add_argument('-l',
                        '--lhs-col',
                        type=int,
                        required=True,
                        help='Column index for source entity')
    parser.add_argument('-r',
                        '--rhs-col',
                        type=int,
                        required=True,
                        help='Column index for target entity')
    parser.add_argument('--rel-col',
                        type=int,
                        help='Column index for relation entity')
    parser.add_argument('--relation-type-min-count',
                        type=int,
                        default=1,
                        help='Min count for relation types')
    parser.add_argument('--entity-min-count',
                        type=int,
                        default=1,
                        help='Min count for entities')
    opt = parser.parse_args()

    loader = ConfigFileLoader()
    config_dict = loader.load_raw_config(opt.config)

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
        config_dict = override_config_dict(config_dict, overrides)

    entity_configs, relation_configs, entity_path, dynamic_relations = \
        parse_config_partial(config_dict)

    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        opt.edge_paths,
        opt.lhs_col,
        opt.rhs_col,
        opt.rel_col,
        opt.entity_min_count,
        opt.relation_type_min_count,
        dynamic_relations,
    )
Example no. 17
 def assertHasMetadata(self, hf: h5py.File, config: ConfigSchema) -> None:
     self.assertEqual(hf.attrs["format_version"], 1)
     self.assertEqual(json.loads(hf.attrs["config/json"]), config.to_dict())
     self.assertCountEqual([
         key.partition("/")[-1]
         for key in hf.attrs.keys() if key.startswith("iteration/")
     ], [
         "num_epochs", "epoch_idx", "num_edge_paths", "edge_path_idx",
         "edge_path", "num_edge_chunks", "edge_chunk_idx"
     ])
Example no. 18
 def test_distributed(self):
     sync_path = TemporaryDirectory()
     self.addCleanup(sync_path.cleanup)
     entity_name = "e"
     relation_config = RelationSchema(
         name="r",
         lhs=entity_name,
         rhs=entity_name,
         operator="linear",  # To exercise the parameter server.
     )
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=4)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         num_machines=2,
         distributed_init_method="file://%s" %
         os.path.join(sync_path.name, "sync"),
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     trainer0 = mp.get_context("spawn").Process(name="trainer#0",
                                                target=train,
                                                args=(train_config, ),
                                                kwargs={"rank": 0})
     trainer1 = mp.get_context("spawn").Process(name="trainer#1",
                                                target=train,
                                                args=(train_config, ),
                                                kwargs={"rank": 1})
     # FIXME In Python 3.7 use kill here.
     self.addCleanup(trainer0.terminate)
     self.addCleanup(trainer1.terminate)
     trainer0.start()
     trainer1.start()
     done = [False, False]
     while not all(done):
         time.sleep(1)
         if not trainer0.is_alive() and not done[0]:
             self.assertEqual(trainer0.exitcode, 0)
             done[0] = True
         if not trainer1.is_alive() and not done[1]:
             self.assertEqual(trainer1.exitcode, 0)
             done[1] = True
     self.assertCheckpointWritten(train_config, version=1)
Example no. 19
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help='Path to config file')
    parser.add_argument('edge_paths', nargs='*', help='Input file paths')
    parser.add_argument('-l',
                        '--lhs-col',
                        type=int,
                        required=True,
                        help='Column index for source entity')
    parser.add_argument('-r',
                        '--rhs-col',
                        type=int,
                        required=True,
                        help='Column index for target entity')
    parser.add_argument('--rel-col',
                        type=int,
                        help='Column index for relation entity')
    parser.add_argument('--relation-type-min-count',
                        type=int,
                        default=1,
                        help='Min count for relation types')
    parser.add_argument('--entity-min-count',
                        type=int,
                        default=1,
                        help='Min count for entities')

    opt = parser.parse_args()

    entity_configs, relation_configs, entity_path, dynamic_relations = \
        validate_config(opt.config)

    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        opt.edge_paths,
        opt.lhs_col,
        opt.rhs_col,
        opt.rel_col,
        opt.entity_min_count,
        opt.relation_type_min_count,
        dynamic_relations,
    )
Example no. 20
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None
    config = parse_config(opt.config, overrides)

    do_eval(config)
Example no. 21
 def test_dynamic_relations(self):
     relation_config = RelationSchema(name="r", lhs="el", rhs="er")
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={
             "el": EntitySchema(num_partitions=1),
             "er": EntitySchema(num_partitions=1),
         },
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         dynamic_relations=True,
         global_emb=False,  # Must be off for dynamic relations.
         workers=2,
     )
     gen_config = attr.evolve(
         base_config,
         relations=[relation_config] * 10,
         dynamic_relations=False,  # Must be off if more than 1 relation.
     )
     dataset = generate_dataset(gen_config,
                                num_entities=100,
                                fractions=[0.04, 0.02])
     self.addCleanup(dataset.cleanup)
     with open(
             os.path.join(dataset.entity_path.name,
                          "dynamic_rel_count.txt"), "xt") as f:
         f.write("%d" % len(gen_config.relations))
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         relations=[attr.evolve(relation_config, all_negs=True)],
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example no. 22
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--rank', type=int, default=0,
                        help="For multi-machine, this machine's rank")
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None
    config = parse_config(opt.config, overrides)

    train(config, rank=Rank(opt.rank))
Example no. 23
def main():
    config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("config", help="Path to config file")
    parser.add_argument("-p", "--param", action="append", nargs="*")
    parser.add_argument("--entities-output", required=True)
    parser.add_argument("--relation-types-output", required=True)
    opt = parser.parse_args()

    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, opt.param)

    with open(opt.entities_output, "xt") as entities_tf, \
            open(opt.relation_types_output, "xt") as relation_types_tf:
        make_tsv(config, entities_tf, relation_types_tf)
Example no. 24
def main():
    setup_logging()
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    opt = parser.parse_args()

    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, opt.param)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    do_eval(config, subprocess_init=subprocess_init)
Example no. 25
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None,
               None]:
    """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
    from disk, runs HOGWILD training on them, and writes partitions back to disk.
    """
    tag_logs_with_process_name(f"Trainer-{rank}")

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    logger.info("Loading entity counts...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            entity_counts[entity].append(
                entity_storage.load_count(entity, part))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    logger.debug(f"nparts {nparts_lhs} {nparts_rhs} "
                 f"types {lhs_partitioned_types} {rhs_partitioned_types}")
    total_buckets = nparts_lhs * nparts_rhs

    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(config.num_machines,
                                                  config.num_partition_servers)

        if rank == RANK_ZERO:
            logger.info("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                process_name="LockServer",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.lock_server,
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        logger.info("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            process_name=f"ParamS-{rank}",
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            server_rank=ranks.parameter_servers[rank],
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        parameter_sharer = ParameterSharer(
            process_name=f"ParamC-{rank}",
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers),
                                log_stats=True),
                process_name=f"PartS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.partition_servers[rank],
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers,
                                               log_stats=True)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        hide_distributed_logging()

    # fork early for HOGWILD threads
    logger.info("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name=f"TWorker-{rank}",
        subprocess_init=subprocess_init,
    )

    def make_optimizer(params: Iterable[torch.nn.Parameter],
                       is_emb: bool) -> Optimizer:
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
        subprocess_name=f"BackgRW-{rank}",
        subprocess_init=subprocess_init,
    )
    checkpoint_manager.register_metadata_provider(
        ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    if config.num_edge_chunks is not None:
        num_edge_chunks = config.num_edge_chunks
    else:
        num_edge_chunks = get_num_edge_chunks(config.edge_paths, nparts_lhs,
                                              nparts_rhs,
                                              config.max_edges_per_chunk)
    iteration_manager = IterationManager(
        config.num_epochs,
        config.edge_paths,
        num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        if strict:
            embs, optim_state = checkpoint_manager.read(
                entity, part, force_dirty=force_dirty)
        else:
            # Strict is only false during the first iteration, because in that
            # case the checkpoint may not contain any data (unless a previous
            # run was resumed) so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(
                entity, part, force_dirty=force_dirty)
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity,
                                              entity_counts[entity][part],
                                              config.dimension,
                                              config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    logger.info("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size,
                                                   config.eval_num_batch_negs)

    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    logger.debug("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        logger.info(f"Swapping partitioned embeddings {old_b} {new_b}")

        types = ([(e, Side.LHS) for e in lhs_partitioned_types] +
                 [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            logger.info("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                side_name = side.pick("lhs", "rhs")
                logger.debug(f"Checkpointing ({entity} {part} {side_name})")
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(
                    trainer.entity_optimizers[optim_key].state_dict())
                # ignore optim state
                io_bytes += embs.numel() * embs.element_size()
                checkpoint_manager.write(entity, part, embs.detach(),
                                         optim_state)
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {
            x: model.get_embeddings(x[0], old_parts[x])
            for x in preserved
        }

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        bucket_logger = BucketLogger(logger, bucket=new_b)

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        bucket_logger.info("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                bucket_logger.debug(
                    f"Loading ({entity}, {part}) from preserved")
                embs, optim_state = tmp_emb[part_key], None
            else:
                bucket_logger.debug(f"Loading ({entity}, {part})")

                force_dirty = bucket_scheduler.check_and_set_dirty(
                    entity, part)
                embs, optim_state = load_embeddings(entity,
                                                    part,
                                                    strict=strict,
                                                    force_dirty=force_dirty)
                # ignore optim state
                io_bytes += embs.numel() * embs.element_size()

            model.set_embeddings(entity, embs, side)
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                bucket_logger.debug(f"Resetting optimizer {optim_key}")
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    bucket_logger.debug("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    if rank == RANK_ZERO:
        for stats_dict in checkpoint_manager.maybe_read_stats():
            index: int = stats_dict["index"]
            stats: Stats = Stats.from_dict(stats_dict["stats"])
            eval_stats_before: Optional[Stats] = None
            if "eval_stats_before" in stats_dict:
                eval_stats_before = Stats.from_dict(
                    stats_dict["eval_stats_before"])
            eval_stats_after: Optional[Stats] = None
            if "eval_stats_after" in stats_dict:
                eval_stats_after = Stats.from_dict(
                    stats_dict["eval_stats_after"])
            yield (index, eval_stats_before, stats, eval_stats_after)

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
        logger.info(
            f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        edge_storage = EDGE_STORAGES.make_instance(iteration_manager.edge_path)
        logger.info(f"Edge path: {iteration_manager.edge_path}")

        sync.barrier()
        dist_logger.info("Lock client new epoch...")
        bucket_scheduler.new_pass(
            is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            logger.info(f"still in queue: {remaining}")
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td
                continue

            bucket_logger = BucketLogger(logger, bucket=cur_b)

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            current_index = \
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                bucket_logger.debug("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            bucket_logger.debug("Loading edges")
            edges = edge_storage.load_chunk_of_edges(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx,
                iteration_manager.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += (edges.lhs.tensor.numel()
                         * edges.lhs.tensor.element_size())
            io_bytes += (edges.rhs.tensor.numel()
                         * edges.rhs.tensor.element_size())
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            bucket_logger.debug("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            g = torch.Generator()
            g.manual_seed(
                hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))

            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_before = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        ) for s in split_almost_equally(eval_edge_perm.size(0),
                                                        num_parts=num_workers)
                    ])
                all_eval_stats_before = \
                    get_async_result(future_all_eval_stats_before, pool)
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                bucket_logger.info(
                    f"Stats before training: {eval_stats_before}")

            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
            bucket_logger.debug("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            future_all_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay
                    if epoch_idx == 0 and rank > 0 else 0,
                ) for rank, s in enumerate(
                    split_almost_equally(edge_perm.size(0),
                                         num_parts=num_workers))
            ])
            all_stats = get_async_result(future_all_stats, pool)
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            bucket_logger.info(
                f"bucket {total_buckets - remaining} / {total_buckets} : "
                f"Processed {num_edges} edges in {compute_time:.2f} s "
                f"( {num_edges / compute_time / 1e6:.2g} M/sec ); "
                f"io: {io_time:.2f} s ( {io_bytes / io_time / 1e6:.2f} MB/sec )"
            )
            bucket_logger.info(f"{stats}")

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_after = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        ) for s in split_almost_equally(eval_edge_perm.size(0),
                                                        num_parts=num_workers)
                    ])
                all_eval_stats_after = \
                    get_async_result(future_all_eval_stats_after, pool)
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                bucket_logger.info(f"Stats after training: {eval_stats_after}")

            # Add train/eval metrics to queue
            stats_dict = {
                "index": current_index,
                "stats": stats.to_dict(),
            }
            if eval_stats_before is not None:
                stats_dict["eval_stats_before"] = eval_stats_before.to_dict()
            if eval_stats_after is not None:
                stats_dict["eval_stats_after"] = eval_stats_after.to_dict()
            checkpoint_manager.append_stats(stats_dict)
            yield current_index, eval_stats_before, stats, eval_stats_after

        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Preserving a checkpoint requires two steps:
        # - create a snapshot (w/ symlinks) after it's first written;
        # - don't delete it once the following one is written.
        # These two happen in two successive iterations of the main loop: the
        # one just before and the one just after the epoch boundary.
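        # (Concretely: `preserve_new_checkpoint` triggers the snapshot of the
        # version written in this iteration, and on the next iteration
        # `preserve_old_checkpoint` evaluates to the same value for that version,
        # so the cleanup step below leaves the snapshot in place.)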
        preserve_old_checkpoint = should_preserve_old_checkpoint(
            iteration_manager, config.checkpoint_preservation_interval)
        preserve_new_checkpoint = should_preserve_old_checkpoint(
            iteration_manager + 1, config.checkpoint_preservation_interval)

        # Write metadata: for multiple machines, write from rank-0
        logger.info(
            f"Finished epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity,
                                                           Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0), embs.detach(),
                        OptimizerStateDict(optimizer.state_dict()))

            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            logger.info("Writing the metadata")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        logger.info("Writing the checkpoint")
        checkpoint_manager.write_new_version(config)

        dist_logger.info(
            "Waiting for other workers to write their parts of the checkpoint")
        sync.barrier()
        dist_logger.info("All parts of the checkpoint have been written")

        logger.info("Switching to the new checkpoint version")
        checkpoint_manager.switch_to_new_version()

        dist_logger.info(
            "Waiting for other workers to switch to the new checkpoint version"
        )
        sync.barrier()
        dist_logger.info(
            "All workers have switched to the new checkpoint version")

        # After all the machines have finished committing the new checkpoint,
        # we either remove the old checkpoint or preserve it.
        if preserve_new_checkpoint:
            # Add 1 so that the index is a multiple of the interval; it looks nicer.
            checkpoint_manager.preserve_current_version(config, epoch_idx + 1)
        if not preserve_old_checkpoint:
            checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    logger.info("Exiting")
Example no. 26
 def read_config(self) -> ConfigSchema:
     config_json = self.storage.load_config()
     return ConfigSchema.from_dict(json.loads(config_json))
Example no. 27
 def write_config(self, config: ConfigSchema) -> None:
     config_json = json.dumps(config.to_dict(), indent=4)
     self.storage.save_config(config_json)
Example no. 28
 def __init__(self, config: ConfigSchema) -> None:
     self.json_config_dict = json.dumps(config.to_dict(), indent=4)
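
The three snippets above (read_config, write_config, and the metadata provider's
__init__) all rely on the same JSON round trip of the config object. Below is a
minimal sketch of that round trip, assuming the schema classes are importable from
torchbiggraph.config; the paths are placeholders.

import json

from torchbiggraph.config import ConfigSchema, EntitySchema, RelationSchema

config = ConfigSchema(
    dimension=10,
    entities={"e": EntitySchema(num_partitions=1)},
    relations=[RelationSchema(name="r", lhs="e", rhs="e")],
    entity_path="/tmp/entities",        # placeholder paths
    edge_paths=["/tmp/edges"],
    checkpoint_path="/tmp/checkpoints",
)

config_json = json.dumps(config.to_dict(), indent=4)        # what write_config saves
restored = ConfigSchema.from_dict(json.loads(config_json))  # what read_config returns
assert restored == config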
Example no. 29
 def test_distributed_with_partition_servers(self):
     sync_path = TemporaryDirectory()
     self.addCleanup(sync_path.cleanup)
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=4)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         num_machines=2,
         num_partition_servers=2,
         distributed_init_method="file://%s" %
         os.path.join(sync_path.name, "sync"),
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     trainer0 = mp.get_context("spawn").Process(
         name="Trainer-0",
         target=partial(
             call_one_after_the_other,
             self.subprocess_init,
             partial(train,
                     train_config,
                     rank=0,
                     subprocess_init=self.subprocess_init),
         ),
     )
     trainer1 = mp.get_context("spawn").Process(
         name="Trainer-1",
         target=partial(
             call_one_after_the_other,
             self.subprocess_init,
             partial(train,
                     train_config,
                     rank=1,
                     subprocess_init=self.subprocess_init),
         ),
     )
     partition_server0 = mp.get_context("spawn").Process(
         name="PartitionServer-0",
         target=partial(
             call_one_after_the_other,
             self.subprocess_init,
             partial(
                 run_partition_server,
                 train_config,
                 rank=0,
                 subprocess_init=self.subprocess_init,
             ),
         ),
     )
     partition_server1 = mp.get_context("spawn").Process(
         name="PartitionServer-1",
         target=partial(
             call_one_after_the_other,
             self.subprocess_init,
             partial(
                 run_partition_server,
                 train_config,
                 rank=1,
                 subprocess_init=self.subprocess_init,
             ),
         ),
     )
     # FIXME In Python 3.7 use kill here.
     self.addCleanup(trainer0.terminate)
     self.addCleanup(trainer1.terminate)
     self.addCleanup(partition_server0.terminate)
     self.addCleanup(partition_server1.terminate)
     trainer0.start()
     trainer1.start()
     partition_server0.start()
     partition_server1.start()
     done = [False, False]
     while not all(done):
         time.sleep(1)
         if not trainer0.is_alive() and not done[0]:
             self.assertEqual(trainer0.exitcode, 0)
             done[0] = True
         if not trainer1.is_alive() and not done[1]:
             self.assertEqual(trainer1.exitcode, 0)
             done[1] = True
     partition_server0.join()
     partition_server1.join()
     logger.info(
         f"Partition server 0 died with exit code {partition_server0.exitcode}"
     )
     logger.info(
         f"Partition server 1 died with exit code {partition_server1.exitcode}"
     )
     self.assertCheckpointWritten(train_config, version=1)
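
For comparison, a single-machine variant of the same test would not need the
partition-server processes at all. A rough sketch, under the assumption that train
keeps the signature used above and that overriding num_machines and
num_partition_servers is all that is needed:

     single_machine_config = attr.evolve(
         train_config,
         num_machines=1,
         num_partition_servers=0,
     )
     train(single_machine_config, rank=0, subprocess_init=self.subprocess_init)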
Example no. 30
    def __init__(  # noqa
        self,
        config: ConfigSchema,
        model: Optional[MultiRelationEmbedder] = None,
        trainer: Optional[AbstractBatchProcessor] = None,
        evaluator: Optional[AbstractBatchProcessor] = None,
        rank: Rank = SINGLE_TRAINER,
        subprocess_init: Optional[Callable[[], None]] = None,
        stats_handler: StatsHandler = NOOP_STATS_HANDLER,
    ):
        """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
        from disk, runs HOGWILD training on them, and writes partitions back to disk.
        """
        tag_logs_with_process_name(f"Trainer-{rank}")
        self.config = config
        if config.verbose > 0:
            import pprint

            pprint.PrettyPrinter().pprint(config.to_dict())

        logger.info("Loading entity counts...")
        entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
        entity_counts: Dict[str, List[int]] = {}
        for entity, econf in config.entities.items():
            entity_counts[entity] = []
            for part in range(econf.num_partitions):
                entity_counts[entity].append(entity_storage.load_count(entity, part))

        # Figure out how many lhs and rhs partitions we need
        holder = self.holder = EmbeddingHolder(config)

        logger.debug(
            f"nparts {holder.nparts_lhs} {holder.nparts_rhs} "
            f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}"
        )

        # We know ahead of time that we will need 1-2 storages for each embedding
        # type, as well as the max size of this storage (num_entities x D).
        # We allocate these storages in advance in `embedding_storage_freelist`.
        # When we need storage for an entity type, we pop it from this free list,
        # and then add it back when we 'delete' the embedding table.
        embedding_storage_freelist: Dict[
            EntityName, Set[torch.FloatStorage]
        ] = defaultdict(set)
        for entity_type, counts in entity_counts.items():
            max_count = max(counts)
            num_sides = (
                (1 if entity_type in holder.lhs_partitioned_types else 0)
                + (1 if entity_type in holder.rhs_partitioned_types else 0)
                + (
                    1
                    if entity_type
                    in (holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types)
                    else 0
                )
            )
            for _ in range(num_sides):
                embedding_storage_freelist[entity_type].add(
                    allocate_shared_tensor(
                        (max_count, config.entity_dimension(entity_type)),
                        dtype=torch.float,
                    ).storage()
                )
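
        # When an embedding table for an entity type is later needed, one of these
        # storages is popped off the free list and viewed as a (count, dimension)
        # float tensor; see the unpartitioned-entity loading further down.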

        # create the handlers, threads, etc. for distributed training
        if config.num_machines > 1 or config.num_partition_servers > 0:
            if not 0 <= rank < config.num_machines:
                raise RuntimeError("Invalid rank for trainer")
            if not td.is_available():
                raise RuntimeError(
                    "The installed PyTorch version doesn't provide "
                    "distributed training capabilities."
                )
            ranks = ProcessRanks.from_num_invocations(
                config.num_machines, config.num_partition_servers
            )

            num_ps_groups = config.num_groups_for_partition_server
            groups: List[List[int]] = [ranks.trainers]  # barrier group
            groups += [
                ranks.trainers + ranks.partition_servers
            ] * num_ps_groups  # ps groups
            group_idxs_for_partition_servers = range(1, len(groups))

            if rank == SINGLE_TRAINER:
                logger.info("Setup lock server...")
                start_server(
                    LockServer(
                        num_clients=len(ranks.trainers),
                        nparts_lhs=holder.nparts_lhs,
                        nparts_rhs=holder.nparts_rhs,
                        entities_lhs=holder.lhs_partitioned_types,
                        entities_rhs=holder.rhs_partitioned_types,
                        entity_counts=entity_counts,
                        init_tree=config.distributed_tree_init_order,
                        stats_handler=stats_handler,
                    ),
                    process_name="LockServer",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.lock_server,
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            self.bucket_scheduler = DistributedBucketScheduler(
                server_rank=ranks.lock_server, client_rank=ranks.trainers[rank]
            )

            logger.info("Setup param server...")
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                process_name=f"ParamS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.parameter_servers[rank],
                groups=groups,
                subprocess_init=subprocess_init,
            )

            parameter_sharer = ParameterSharer(
                process_name=f"ParamC-{rank}",
                client_rank=ranks.parameter_clients[rank],
                all_server_ranks=ranks.parameter_servers,
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                groups=groups,
                subprocess_init=subprocess_init,
            )

            if config.num_partition_servers == -1:
                start_server(
                    ParameterServer(
                        num_clients=len(ranks.trainers),
                        group_idxs=group_idxs_for_partition_servers,
                        log_stats=True,
                    ),
                    process_name=f"PartS-{rank}",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.partition_servers[rank],
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            groups = init_process_group(
                rank=ranks.trainers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=groups,
            )
            trainer_group, *groups_for_partition_servers = groups
            self.barrier_group = trainer_group

            if len(ranks.partition_servers) > 0:
                partition_client = PartitionClient(
                    ranks.partition_servers,
                    groups=groups_for_partition_servers,
                    log_stats=True,
                )
            else:
                partition_client = None
        else:
            self.barrier_group = None
            self.bucket_scheduler = SingleMachineBucketScheduler(
                holder.nparts_lhs, holder.nparts_rhs, config.bucket_order, stats_handler
            )
            parameter_sharer = None
            partition_client = None
            hide_distributed_logging()

        # fork early for HOGWILD threads
        logger.info("Creating workers...")
        self.num_workers = get_num_workers(config.workers)
        self.pool = create_pool(
            self.num_workers,
            subprocess_name=f"TWorker-{rank}",
            subprocess_init=subprocess_init,
        )

        checkpoint_manager = CheckpointManager(
            config.checkpoint_path,
            rank=rank,
            num_machines=config.num_machines,
            partition_client=partition_client,
            subprocess_name=f"BackgRW-{rank}",
            subprocess_init=subprocess_init,
        )
        self.checkpoint_manager = checkpoint_manager
        checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
        if rank == 0:
            checkpoint_manager.write_config(config)

        num_edge_chunks = get_num_edge_chunks(config)

        self.iteration_manager = IterationManager(
            config.num_epochs,
            config.edge_paths,
            num_edge_chunks,
            iteration_idx=checkpoint_manager.checkpoint_version,
        )
        checkpoint_manager.register_metadata_provider(self.iteration_manager)

        logger.info("Initializing global model...")
        if model is None:
            model = make_model(config)
        model.share_memory()
        loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin)
        relation_weights = [relation.weight for relation in config.relations]
        if trainer is None:
            trainer = Trainer(
                model_optimizer=make_optimizer(config, model.parameters(), False),
                loss_fn=loss_fn,
                relation_weights=relation_weights,
            )
        if evaluator is None:
            eval_overrides = {}
            if config.eval_num_batch_negs is not None:
                eval_overrides["num_batch_negs"] = config.eval_num_batch_negs
            if config.eval_num_uniform_negs is not None:
                eval_overrides["num_uniform_negs"] = config.eval_num_uniform_negs

            evaluator = RankingEvaluator(
                loss_fn=loss_fn,
                relation_weights=relation_weights,
                overrides=eval_overrides,
            )

        if config.init_path is not None:
            self.loadpath_manager = CheckpointManager(config.init_path)
        else:
            self.loadpath_manager = None

        # load model from checkpoint or loadpath, if available
        state_dict, optim_state = checkpoint_manager.maybe_read_model()
        if state_dict is None and self.loadpath_manager is not None:
            state_dict, optim_state = self.loadpath_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)
        if optim_state is not None:
            trainer.model_optimizer.load_state_dict(optim_state)

        logger.debug("Loading unpartitioned entities...")
        for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
            count = entity_counts[entity][0]
            s = embedding_storage_freelist[entity].pop()
            dimension = config.entity_dimension(entity)
            embs = torch.FloatTensor(s).view(-1, dimension)[:count]
            embs, optimizer = self._load_embeddings(entity, UNPARTITIONED, out=embs)
            holder.unpartitioned_embeddings[entity] = embs
            trainer.unpartitioned_optimizers[entity] = optimizer

        # start communicating shared parameters with the parameter server
        if parameter_sharer is not None:
            shared_parameters: Set[int] = set()
            for name, param in model.named_parameters():
                if id(param) in shared_parameters:
                    continue
                shared_parameters.add(id(param))
                key = f"model.{name}"
                logger.info(
                    f"Adding {key} ({param.numel()} params) to parameter server"
                )
                parameter_sharer.set_param(key, param.data)
            for entity, embs in holder.unpartitioned_embeddings.items():
                key = f"entity.{entity}"
                logger.info(f"Adding {key} ({embs.numel()} params) to parameter server")
                parameter_sharer.set_param(key, embs.data)

        # store everything in self
        self.model = model
        self.trainer = trainer
        self.evaluator = evaluator
        self.rank = rank
        self.entity_counts = entity_counts
        self.embedding_storage_freelist = embedding_storage_freelist
        self.stats_handler = stats_handler

        self.strict = False
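        # (Assumption: `strict` starts out False so that partition files that don't
        # exist yet are tolerated; the training loop is expected to flip it to True
        # once every partition has been written, compare the `strict = True` step
        # near the end of the bucket loop in the first snippet above.)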