Ejemplo n.º 1
0
def main():
    """Command-line entry point: load a config file and start training.

    Parses the positional ``config`` path, the repeatable ``-p/--param``
    option and ``--rank``, loads the configuration through
    ``ConfigFileLoader`` and passes it to ``train`` together with a
    ``SubprocessInitializer``.
    """
    setup_logging()

    # The schema's help lines are appended after the regular argparse help.
    schema_help = "\n".join(ConfigSchema.help())
    cli = argparse.ArgumentParser(
        # RawDescriptionHelpFormatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n\nConfig parameters:\n\n" + schema_help,
    )
    cli.add_argument("config", help="Path to config file")
    cli.add_argument("-p", "--param", nargs="*", action="append")
    cli.add_argument(
        "--rank",
        help="For multi-machine, this machine's rank",
        default=SINGLE_TRAINER,
        type=int,
    )
    cli_args = cli.parse_args()

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(cli_args.config, cli_args.param)
    set_logging_verbosity(config.verbose)

    # Callables registered here are handed to train() as its subprocess_init;
    # presumably they are replayed inside each worker subprocess.
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    train(config, rank=cli_args.rank, subprocess_init=initializer)
Ejemplo n.º 2
0
def main():
    """Command-line entry point: load a config file and start training.

    Parses the positional ``config`` path, the repeatable ``-p/--param``
    overrides and ``--rank`` (default 0), loads the configuration through
    ``ConfigFileLoader`` and passes it to ``train``.
    """
    setup_logging()

    # The schema's help lines are appended after the regular argparse help.
    schema_help = "\n".join(ConfigSchema.help())
    cli = argparse.ArgumentParser(
        # RawDescriptionHelpFormatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n\nConfig parameters:\n\n" + schema_help,
    )
    cli.add_argument("config", help="Path to config file")
    cli.add_argument("-p", "--param", nargs="*", action="append")
    cli.add_argument(
        "--rank",
        help="For multi-machine, this machine's rank",
        default=0,
        type=int,
    )
    cli_args = cli.parse_args()

    # Every -p occurrence contributes its own list of values; flatten them
    # into a single iterable, or pass None when -p was never given.
    overrides = (
        None if cli_args.param is None
        else chain.from_iterable(cli_args.param)
    )

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(cli_args.config, overrides)
    set_logging_verbosity(config.verbose)

    # Callables registered here are handed to train() as its subprocess_init;
    # presumably they are replayed inside each worker subprocess.
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    train(config, rank=Rank(cli_args.rank), subprocess_init=initializer)
Ejemplo n.º 3
0
def main():
    """Run the FB15k example: fetch the data, convert it, train and evaluate.

    Command-line options select the config file (``--config``), config
    overrides (``-p/--param``), the data directory (``--data_dir``) and
    whether evaluation uses ``FilteredRankingEvaluator`` (it does unless
    ``--no-filtered`` is given).
    """
    setup_logging()
    cli = argparse.ArgumentParser(description="Example on FB15k")
    cli.add_argument(
        "--config", help="Path to config file", default=DEFAULT_CONFIG)
    cli.add_argument("-p", "--param", nargs="*", action="append")
    cli.add_argument(
        "--data_dir",
        help="where to save processed data",
        default="data",
        type=Path,
    )
    cli.add_argument(
        "--no-filtered",
        help="Run unfiltered eval",
        action="store_false",
        dest="filtered",
    )
    args = cli.parse_args()

    # Every -p occurrence contributes its own list of values; flatten them
    # into a single iterable, or pass None when -p was never given.
    overrides = (
        None if args.param is None else chain.from_iterable(args.param)
    )

    # Download and unpack the dataset archive.
    data_dir = args.data_dir
    archive_path = download_url(FB15K_URL, data_dir)
    extract_tar(archive_path)
    print("Downloaded and extracted file.")

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    raw_edge_paths = [data_dir / filename for filename in FILENAMES]
    # The config must list exactly three edge paths: train, valid, test.
    train_path, valid_path, test_path = config.edge_paths

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        raw_edge_paths,
        lhs_col=0,
        rhs_col=2,
        rel_col=1,
        dynamic_relations=config.dynamic_relations,
    )

    # Train on the training split only.
    train(
        attr.evolve(config, edge_paths=[train_path]),
        subprocess_init=initializer,
    )

    # Evaluate on the test split with all_negs enabled on every relation
    # and uniform negatives switched off.
    eval_relations = [
        attr.evolve(relation, all_negs=True) for relation in config.relations
    ]
    eval_config = attr.evolve(
        config,
        edge_paths=[test_path],
        relations=eval_relations,
        num_uniform_negs=0,
    )
    if not args.filtered:
        do_eval(eval_config, subprocess_init=initializer)
        return

    evaluator = FilteredRankingEvaluator(
        eval_config, [test_path, valid_path, train_path])
    do_eval(eval_config, evaluator=evaluator, subprocess_init=initializer)
Ejemplo n.º 4
0
def main():
    """Run the Livejournal example: fetch, split, convert, train and evaluate.

    Command-line options select the config file (``--config``), config
    overrides (``-p/--param``) and the data directory (``--data_dir``).
    """
    setup_logging()
    cli = argparse.ArgumentParser(description="Example on Livejournal")
    cli.add_argument(
        "--config", help="Path to config file", default=DEFAULT_CONFIG)
    cli.add_argument("-p", "--param", nargs="*", action="append")
    cli.add_argument(
        "--data_dir",
        help="where to save processed data",
        default="data",
        type=Path,
    )
    args = cli.parse_args()

    # Every -p occurrence contributes its own list of values; flatten them
    # into a single iterable, or pass None when -p was never given.
    overrides = (
        None if args.param is None else chain.from_iterable(args.param)
    )

    # Download and decompress the dataset into the data directory.
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    downloaded_path = download_url(URL, data_dir)
    extracted_path = extract_gzip(downloaded_path)
    print("Downloaded and extracted file.")

    # Random split of the file for train and test.
    random_split_file(extracted_path)

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    split_paths = [data_dir / filename for filename in FILENAMES.values()]
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        split_paths,
        lhs_col=0,
        rhs_col=1,
        rel_col=None,
        dynamic_relations=config.dynamic_relations,
    )

    # Train on the converted training split.
    converted_train = str(convert_path(data_dir / FILENAMES["train"]))
    train(
        attr.evolve(config, edge_paths=[converted_train]),
        subprocess_init=initializer,
    )

    # Evaluate on the converted test split.
    converted_test = str(convert_path(data_dir / FILENAMES["test"]))
    do_eval(
        attr.evolve(config, edge_paths=[converted_test]),
        subprocess_init=initializer,
    )
Ejemplo n.º 5
0
def main():
    """Run the Livejournal example: fetch, split, convert, train and evaluate.

    Command-line options select the config file (``--config``), config
    overrides (``-p/--param``, forwarded to the loader as parsed) and the
    data directory (``--data_dir``).
    """
    setup_logging()
    cli = argparse.ArgumentParser(description="Example on Livejournal")
    cli.add_argument(
        "--config", help="Path to config file", default=DEFAULT_CONFIG)
    cli.add_argument("-p", "--param", nargs="*", action="append")
    cli.add_argument(
        "--data_dir",
        help="where to save processed data",
        default="data",
        type=Path,
    )
    args = cli.parse_args()

    # Download and decompress the dataset into the data directory.
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    downloaded_path = download_url(URL, data_dir)
    extracted_path = extract_gzip(downloaded_path)
    print("Downloaded and extracted file.")

    # Random split of the file for train and test.
    random_split_file(extracted_path)

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    raw_edge_paths = [data_dir / filename for filename in FILENAMES]
    # The config must list exactly two edge paths: train, then test.
    train_path, test_path = config.edge_paths

    # Columns 0 and 1 hold the two endpoints; there is no relation column.
    edgelist_reader = TSVEdgelistReader(lhs_col=0, rhs_col=1, rel_col=None)
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        raw_edge_paths,
        edgelist_reader,
        dynamic_relations=config.dynamic_relations,
    )

    train(
        attr.evolve(config, edge_paths=[train_path]),
        subprocess_init=initializer,
    )
    do_eval(
        attr.evolve(config, edge_paths=[test_path]),
        subprocess_init=initializer,
    )
Ejemplo n.º 6
0
def main():
    """Command-line entry point: load a config file and run evaluation.

    Parses the positional ``config`` path and the repeatable ``-p/--param``
    option, loads the configuration through ``ConfigFileLoader`` and passes
    it to ``do_eval`` together with a ``SubprocessInitializer``.
    """
    setup_logging()

    # The schema's help lines are appended after the regular argparse help.
    schema_help = "\n".join(ConfigSchema.help())
    cli = argparse.ArgumentParser(
        # RawDescriptionHelpFormatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n\nConfig parameters:\n\n" + schema_help,
    )
    cli.add_argument("config", help="Path to config file")
    cli.add_argument("-p", "--param", nargs="*", action="append")
    cli_args = cli.parse_args()

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(cli_args.config, cli_args.param)
    set_logging_verbosity(config.verbose)

    # Callables registered here are handed to do_eval() as its
    # subprocess_init; presumably they are replayed in each subprocess.
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    do_eval(config, subprocess_init=initializer)
Ejemplo n.º 7
0
class TestFunctional(TestCase):
    """End-to-end smoke tests for training, evaluation and checkpointing.

    Each test builds a small configuration, generates a synthetic dataset
    with ``generate_dataset``, runs training (and usually evaluation) and
    then verifies the files found under the checkpoint directory.
    """

    def setUp(self) -> None:
        """Prepare logging, a temporary checkpoint directory and a seed."""
        self.subprocess_init = SubprocessInitializer()
        self.subprocess_init.register(setup_logging, 1)
        # Run the registered initializers in the test process itself too.
        self.subprocess_init()

        self.checkpoint_path = TemporaryDirectory()
        self.addCleanup(self.checkpoint_path.cleanup)

        # Log the seed so that a failing run can be reproduced.
        seed = random.getrandbits(32)
        np.random.seed(seed)
        logger.info(f"Random seed: {seed}")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _spawn_process(self, name, func, config, rank):
        """Create (without starting) a "spawn"-context process.

        The process target first invokes ``self.subprocess_init`` and then
        ``func(config, rank=rank, subprocess_init=self.subprocess_init)``,
        chained through ``call_one_after_the_other``.
        """
        return mp.get_context("spawn").Process(
            name=name,
            target=partial(
                call_one_after_the_other,
                self.subprocess_init,
                partial(func,
                        config,
                        rank=rank,
                        subprocess_init=self.subprocess_init),
            ),
        )

    def _assert_all_exit_cleanly(self, processes) -> None:
        """Poll once per second until every process exited with code 0.

        The test fails as soon as a finished process is found to have a
        non-zero exit code.
        """
        pending = list(processes)
        while pending:
            time.sleep(1)
            still_running = []
            for process in pending:
                if process.is_alive():
                    still_running.append(process)
                else:
                    self.assertEqual(process.exitcode, 0)
            pending = still_running

    # ------------------------------------------------------------------
    # Custom assertions
    # ------------------------------------------------------------------

    def assertHasMetadata(self, hf: h5py.File, config: ConfigSchema) -> None:
        """Check the format version, stored config and iteration attributes."""
        self.assertEqual(hf.attrs["format_version"], 1)
        self.assertEqual(json.loads(hf.attrs["config/json"]), config.to_dict())
        # Exactly these keys must exist under the "iteration/" prefix.
        self.assertCountEqual(
            [
                key.partition("/")[-1]
                for key in hf.attrs.keys() if key.startswith("iteration/")
            ],
            [
                "num_epochs",
                "epoch_idx",
                "num_edge_paths",
                "edge_path_idx",
                "edge_path",
                "num_edge_chunks",
                "edge_chunk_idx",
            ],
        )

    def assertIsModelParameter(self, dataset: h5py.Dataset) -> None:
        """Check that a dataset has a state-dict key and only finite values.

        Objects that are not datasets are silently accepted.
        """
        # In fact it could also be a group...
        if not isinstance(dataset, h5py.Dataset):
            return
        self.assertIn("state_dict_key", dataset.attrs)
        self.assertTrue(np.isfinite(dataset[...]).all())

    def assertIsModelParameters(self, group: h5py.Group) -> None:
        """Check every item reachable from ``group`` as a model parameter."""
        self.assertIsInstance(group, h5py.Group)
        group.visititems(lambda _, d: self.assertIsModelParameter(d))

    def assertIsOptimStateDict(self, dataset: h5py.Dataset) -> None:
        """Check that the dataset is one-dimensional with dtype ``V1``."""
        self.assertIsInstance(dataset, h5py.Dataset)
        self.assertEqual(dataset.dtype, np.dtype("V1"))
        self.assertEqual(len(dataset.shape), 1)

    def assertIsEmbeddings(self, dataset: h5py.Dataset, entity_count: int,
                           dimension: int) -> None:
        """Check dtype, shape, finiteness and non-zero norm of embeddings."""
        self.assertIsInstance(dataset, h5py.Dataset)
        self.assertEqual(dataset.dtype, np.float32)
        self.assertEqual(dataset.shape, (entity_count, dimension))
        self.assertTrue(np.all(np.isfinite(dataset[...])))
        # No row may have a zero norm along the last axis.
        self.assertTrue(np.all(np.linalg.norm(dataset[...], axis=-1) != 0))

    def assertIsStatsDict(
            self, stats: Mapping[str, Union[int, SerializedStats]]) -> None:
        """Check the keys and value types of one training-stats record.

        Index-like keys must map to ints, stats-like keys to dicts with
        exactly ``count`` (int) and ``metrics`` (dict of floats); any other
        key fails the test.
        """
        self.assertIsInstance(stats, dict)
        self.assertIn("index", stats)
        for k, v in stats.items():
            if k in (
                    "epoch_idx",
                    "edge_path_idx",
                    "edge_chunk_idx",
                    "lhs_partition",
                    "rhs_partition",
                    "index",
            ):
                self.assertIsInstance(v, int)
            elif k in (
                    "stats",
                    "eval_stats_before",
                    "eval_stats_after",
                    "eval_stats_chunk_avg",
            ):
                self.assertIsInstance(v, dict)
                # Repeated as a plain assert so type checkers narrow ``v``.
                assert isinstance(v, dict)
                self.assertCountEqual(v.keys(), ["count", "metrics"])
                self.assertIsInstance(v["count"], int)
                metrics = v["metrics"]
                self.assertIsInstance(metrics, dict)
                assert isinstance(metrics, dict)
                for m in metrics.values():
                    self.assertIsInstance(m, float)
            else:
                self.fail(f"Unknown stats key: {k}")

    def assertCheckpointWritten(self, config: ConfigSchema, *,
                                version: int) -> None:
        """Check every file expected in a checkpoint of the given version.

        Covers the version file, the stored config, the model file, the
        training stats and one embeddings file per entity partition.
        """
        with open(
                os.path.join(config.checkpoint_path, "checkpoint_version.txt"),
                "rt") as tf:
            self.assertEqual(version, int(tf.read().strip()))

        with open(os.path.join(config.checkpoint_path, "config.json"),
                  "rt") as tf:
            self.assertEqual(json.load(tf), config.to_dict())

        with h5py.File(
                os.path.join(config.checkpoint_path, "model.v%d.h5" % version),
                "r") as hf:
            self.assertHasMetadata(hf, config)
            self.assertIsModelParameters(hf["model"])
            self.assertIsOptimStateDict(hf["optimizer/state_dict"])

        # The stats file holds one JSON document per line.
        with open(os.path.join(config.checkpoint_path, "training_stats.json"),
                  "rt") as tf:
            for line in tf:
                self.assertIsStatsDict(json.loads(line))

        for entity_name, entity in config.entities.items():
            for partition in range(entity.num_partitions):
                # The expected row count is read from the entity directory.
                with open(
                        os.path.join(
                            config.entity_path,
                            "entity_count_%s_%d.txt" %
                            (entity_name, partition),
                        ),
                        "rt",
                ) as tf:
                    entity_count = int(tf.read().strip())
                with h5py.File(
                        os.path.join(
                            config.checkpoint_path,
                            "embeddings_%s_%d.v%d.h5" %
                            (entity_name, partition, version),
                        ),
                        "r",
                ) as hf:
                    self.assertHasMetadata(hf, config)
                    self.assertIsEmbeddings(
                        hf["embeddings"],
                        entity_count,
                        config.entity_dimension(entity_name),
                    )
                    self.assertIsOptimStateDict(hf["optimizer/state_dict"])

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_default(self):
        """Train and evaluate a single unpartitioned entity type."""
        entity_name = "e"
        relation_config = RelationSchema(name="r",
                                         lhs=entity_name,
                                         rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=1)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
            relations=[attr.evolve(relation_config, all_negs=True)],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def test_resume_from_checkpoint(self):
        """Training on top of a version-7 checkpoint must write version 8."""
        entity_name = "e"
        relation_config = RelationSchema(name="r",
                                         lhs=entity_name,
                                         rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=1)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            num_epochs=2,
            num_edge_chunks=2,
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4, 0.4])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[d.name for d in dataset.relation_paths],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        init_embeddings(train_config.checkpoint_path, train_config, version=7)
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=8)
        # Check we did resume the run, not start the whole thing anew.
        self.assertFalse(
            os.path.exists(
                os.path.join(train_config.checkpoint_path, "model.v6.h5")))

    def test_with_initial_value(self):
        """Train with ``init_path`` pointing at pre-initialized embeddings."""
        entity_name = "e"
        relation_config = RelationSchema(name="r",
                                         lhs=entity_name,
                                         rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=1)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4])
        self.addCleanup(dataset.cleanup)
        init_dir = TemporaryDirectory()
        self.addCleanup(init_dir.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
            init_path=init_dir.name,
        )
        # Just make sure no exceptions are raised and nothing crashes.
        init_embeddings(train_config.init_path, train_config)
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)

    def test_featurized(self):
        """Train and evaluate with one featurized and one plain entity type."""
        e1 = EntitySchema(num_partitions=1, featurized=True)
        e2 = EntitySchema(num_partitions=1)
        r1 = RelationSchema(name="r1", lhs="e1", rhs="e2")
        r2 = RelationSchema(name="r2", lhs="e2", rhs="e1")
        base_config = ConfigSchema(
            dimension=10,
            relations=[r1, r2],
            entities={
                "e1": e1,
                "e2": e2
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def test_partitioned(self):
        """Train and evaluate entity types with 1, 2 and 3 partitions."""
        e1 = EntitySchema(num_partitions=1)
        e2 = EntitySchema(num_partitions=2)
        e3 = EntitySchema(num_partitions=3)
        r1 = RelationSchema(name="r1", lhs="e1", rhs="e3")
        r2 = RelationSchema(name="r2", lhs="e2", rhs="e3")
        r3 = RelationSchema(name="r3", lhs="e2", rhs="e1")
        base_config = ConfigSchema(
            dimension=10,
            relations=[r1, r2, r3],
            entities={
                "e1": e1,
                "e2": e2,
                "e3": e3
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    @unittest.skipIf(not torch.cuda.is_available(), "No GPU")
    def test_gpu(self):
        """GPU training with the default settings of ``_test_gpu``."""
        self._test_gpu()

    @unittest.skipIf(not torch.cuda.is_available(), "No GPU")
    def test_gpu_half(self):
        """GPU training with half precision enabled."""
        self._test_gpu(do_half_precision=True)

    @unittest.skipIf(not torch.cuda.is_available(), "No GPU")
    def test_gpu_1partition(self):
        """GPU training with a single entity partition."""
        self._test_gpu(num_partitions=1)

    def _test_gpu(self, do_half_precision=False, num_partitions=2):
        """Shared body of the GPU tests: ``train_gpu`` followed by eval."""
        entity_name = "e"
        relation_config = RelationSchema(name="r",
                                         lhs=entity_name,
                                         rhs=entity_name)
        base_config = ConfigSchema(
            dimension=16,
            batch_size=1024,
            num_batch_negs=64,
            num_uniform_negs=64,
            relations=[relation_config],
            entities={
                entity_name: EntitySchema(num_partitions=num_partitions)
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
            num_gpus=2,
            half_precision=do_half_precision,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
            relations=[attr.evolve(relation_config, all_negs=True)],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train_gpu(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def _test_distributed(self, num_partitions):
        """Shared body of the two-trainer distributed tests."""
        sync_path = TemporaryDirectory()
        self.addCleanup(sync_path.cleanup)
        e1 = "e1"
        e2 = "e2"
        relation_config = RelationSchema(
            name="r",
            lhs=e1,
            rhs=e2,
            operator="linear",  # To exercise the parameter server.
        )
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={
                e1: EntitySchema(num_partitions=num_partitions),
                e2: EntitySchema(num_partitions=4),
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            num_machines=2,
            distributed_init_method="file://%s" %
            os.path.join(sync_path.name, "sync"),
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        trainers = [
            self._spawn_process(f"Trainer-{rank}", train, train_config, rank)
            for rank in range(2)
        ]
        # FIXME In Python 3.7 use kill here.
        for trainer in trainers:
            self.addCleanup(trainer.terminate)
        for trainer in trainers:
            trainer.start()
        self._assert_all_exit_cleanly(trainers)
        self.assertCheckpointWritten(train_config, version=1)

    def test_distributed(self):
        """Distributed training with the left-hand entity in 4 partitions."""
        self._test_distributed(num_partitions=4)

    def test_distributed_unpartitioned(self):
        """Distributed training with the left-hand entity unpartitioned."""
        self._test_distributed(num_partitions=1)

    def test_distributed_with_partition_servers(self):
        """Distributed training with two trainers and two partition servers."""
        sync_path = TemporaryDirectory()
        self.addCleanup(sync_path.cleanup)
        entity_name = "e"
        relation_config = RelationSchema(name="r",
                                         lhs=entity_name,
                                         rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={entity_name: EntitySchema(num_partitions=4)},
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            num_machines=2,
            num_partition_servers=2,
            distributed_init_method="file://%s" %
            os.path.join(sync_path.name, "sync"),
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        trainers = [
            self._spawn_process(f"Trainer-{rank}", train, train_config, rank)
            for rank in range(2)
        ]
        partition_servers = [
            self._spawn_process(f"PartitionServer-{rank}",
                                run_partition_server, train_config, rank)
            for rank in range(2)
        ]
        all_processes = trainers + partition_servers
        # FIXME In Python 3.7 use kill here.
        for process in all_processes:
            self.addCleanup(process.terminate)
        for process in all_processes:
            process.start()
        self._assert_all_exit_cleanly(trainers)
        # Only the trainers' exit codes are asserted; the partition servers
        # are merely waited for and their exit codes logged.
        for server in partition_servers:
            server.join()
        for rank, server in enumerate(partition_servers):
            logger.info(
                f"Partition server {rank} died with exit code {server.exitcode}"
            )
        self.assertCheckpointWritten(train_config, version=1)

    def test_dynamic_relations(self):
        """Train and evaluate with ``dynamic_relations`` enabled."""
        relation_config = RelationSchema(name="r", lhs="el", rhs="er")
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={
                "el": EntitySchema(num_partitions=1),
                "er": EntitySchema(num_partitions=1),
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            dynamic_relations=True,
            global_emb=False,  # Must be off for dynamic relations.
            workers=2,
        )
        # The dataset is generated from a config with ten copies of the
        # relation; the train/eval configs keep the single dynamic one.
        gen_config = attr.evolve(
            base_config,
            relations=[relation_config] * 10,
            dynamic_relations=False,  # Must be off if more than 1 relation.
        )
        dataset = generate_dataset(gen_config,
                                   num_entities=100,
                                   fractions=[0.04, 0.02])
        self.addCleanup(dataset.cleanup)
        # Record the relation count next to the entity files ("x" mode fails
        # if the file already exists).
        with open(
                os.path.join(dataset.entity_path.name,
                             "dynamic_rel_count.txt"), "xt") as f:
            f.write("%d" % len(gen_config.relations))
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            relations=[attr.evolve(relation_config, all_negs=True)],
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)

    def test_entity_dimensions(self):
        """Train and evaluate an entity whose dimension overrides the global."""
        entity_name = "e"
        relation_config = RelationSchema(name="r",
                                         lhs=entity_name,
                                         rhs=entity_name)
        base_config = ConfigSchema(
            dimension=10,
            relations=[relation_config],
            entities={
                entity_name: EntitySchema(num_partitions=1, dimension=8)
            },
            entity_path=None,  # filled in later
            edge_paths=[],  # filled in later
            checkpoint_path=self.checkpoint_path.name,
            workers=2,
        )
        dataset = generate_dataset(base_config,
                                   num_entities=100,
                                   fractions=[0.4, 0.2])
        self.addCleanup(dataset.cleanup)
        train_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[0].name],
        )
        eval_config = attr.evolve(
            base_config,
            entity_path=dataset.entity_path.name,
            edge_paths=[dataset.relation_paths[1].name],
            relations=[attr.evolve(relation_config, all_negs=True)],
        )
        # Just make sure no exceptions are raised and nothing crashes.
        train(train_config, rank=0, subprocess_init=self.subprocess_init)
        self.assertCheckpointWritten(train_config, version=1)
        do_eval(eval_config, subprocess_init=self.subprocess_init)
Ejemplo n.º 8
0
def main():
    """Run the FB15k example: download, convert, train and evaluate.

    Command-line options:
      --config       Path to the config file (defaults to DEFAULT_CONFIG).
      -p / --param   Config overrides, passed through to the config loader.
      --data_dir     Directory the dataset is downloaded into.
      --no-filtered  Evaluate without the FilteredRankingEvaluator; the
                     ``filtered`` flag is True unless this is given.

    The loaded config must list exactly three edge paths, which are
    unpacked as train, valid and test, in that order.
    """
    setup_logging()

    cli = argparse.ArgumentParser(description="Example on FB15k")
    cli.add_argument(
        "--config", default=DEFAULT_CONFIG, help="Path to config file"
    )
    cli.add_argument("-p", "--param", action="append", nargs="*")
    cli.add_argument(
        "--data_dir",
        type=Path,
        default="data",
        help="where to save processed data",
    )
    cli.add_argument(
        "--no-filtered",
        dest="filtered",
        action="store_false",
        help="Run unfiltered eval",
    )
    args = cli.parse_args()

    # Fetch the dataset archive and unpack it.
    archive_path = download_url(FB15K_URL, args.data_dir)
    extract_tar(archive_path)
    print("Downloaded and extracted file.")

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)

    # Register the setup steps that subprocesses should replay.
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, config_loader.config_dir.name)

    raw_edge_files = [args.data_dir / fname for fname in FILENAMES]
    train_path, valid_path, test_path = config.edge_paths

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        raw_edge_files,
        TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1),
        dynamic_relations=config.dynamic_relations,
    )

    # Train on the train split only.
    train(
        attr.evolve(config, edge_paths=[train_path]),
        subprocess_init=subprocess_init,
    )

    # Evaluate on the test split with all_negs enabled on every relation
    # and uniform negatives switched off.
    eval_config = attr.evolve(
        config,
        edge_paths=[test_path],
        relations=[attr.evolve(rel, all_negs=True) for rel in config.relations],
        num_uniform_negs=0,
    )

    if not args.filtered:
        do_eval(eval_config, subprocess_init=subprocess_init)
        return

    # Filtered eval: the evaluator is given the test, valid and train
    # edge paths as its filter paths.
    filter_paths = [test_path, valid_path, train_path]
    do_eval(
        eval_config,
        evaluator=FilteredRankingEvaluator(eval_config, filter_paths),
        subprocess_init=subprocess_init,
    )