Example No. 1
 def test_featurized(self):
     e1 = EntitySchema(num_partitions=1, featurized=True)
     e2 = EntitySchema(num_partitions=1)
     r1 = RelationSchema(name="r1", lhs="e1", rhs="e2")
     r2 = RelationSchema(name="r2", lhs="e2", rhs="e1")
     base_config = ConfigSchema(
         dimension=10,
         relations=[r1, r2],
         entities={
             "e1": e1,
             "e2": e2
         },
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4, 0.2])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example No. 2
 def test_resume_from_checkpoint(self):
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=1)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         num_epochs=2,
         num_edge_chunks=2,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4, 0.4])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[d.name for d in dataset.relation_paths],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     init_embeddings(train_config.checkpoint_path, train_config, version=7)
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=8)
     # Check we did resume the run, not start the whole thing anew.
     self.assertFalse(
         os.path.exists(
             os.path.join(train_config.checkpoint_path, "model.v6.h5")))
Example No. 3
 def test_with_initial_value(self):
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=1)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4])
     self.addCleanup(dataset.cleanup)
     init_dir = TemporaryDirectory()
     self.addCleanup(init_dir.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
         init_path=init_dir.name,
     )
     # Just make sure no exceptions are raised and nothing crashes.
     init_embeddings(train_config.init_path, train_config)
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
Example No. 4
 def test_entity_dimensions(self):
     entity_name = "e"
     relation_config = RelationSchema(name="r",
                                      lhs=entity_name,
                                      rhs=entity_name)
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={
             entity_name: EntitySchema(num_partitions=1, dimension=8)
         },
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
     )
     dataset = generate_dataset(base_config,
                                num_entities=100,
                                fractions=[0.4, 0.2])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
         relations=[attr.evolve(relation_config, all_negs=True)],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example No. 5
 def _test_gpu(self, do_half_precision=False, num_partitions=2):
     entity_name = "e"
     relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
     base_config = ConfigSchema(
         dimension=16,
         batch_size=1024,
         num_batch_negs=64,
         num_uniform_negs=64,
         relations=[relation_config],
         entities={entity_name: EntitySchema(num_partitions=num_partitions)},
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         workers=2,
         num_gpus=2,
         regularization_coef=1e-4,
         half_precision=do_half_precision,
     )
     dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
     self.addCleanup(dataset.cleanup)
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
         relations=[attr.evolve(relation_config, all_negs=True)],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example No. 6
def main():
    setup_logging()
    parser = argparse.ArgumentParser(description='Example on FB15k')
    parser.add_argument('--config', default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir', type=Path, default='data',
                        help='where to save processed data')
    parser.add_argument('--no-filtered', dest='filtered', action='store_false',
                        help='Run unfiltered eval')
    args = parser.parse_args()

    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten
    else:
        overrides = None
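    # Note: each -p/--param flag appends a list of tokens, so args.param is a list
    # of lists; flattening turns it into plain "key=value" strings (e.g. a
    # hypothetical `-p num_epochs=5 -p dimension=200`) that load_config applies on
    # top of the values from the config file.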

    # download data
    data_dir = args.data_dir
    fpath = download_url(FB15K_URL, data_dir)
    extract_tar(fpath)
    print('Downloaded and extracted file.')

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)
    input_edge_paths = [data_dir / name for name in FILENAMES]
    output_train_path, output_valid_path, output_test_path = config.edge_paths

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        input_edge_paths,
        lhs_col=0,
        rhs_col=2,
        rel_col=1,
        dynamic_relations=config.dynamic_relations,
    )

    train_config = attr.evolve(config, edge_paths=[output_train_path])
    train(train_config, subprocess_init=subprocess_init)

    relations = [attr.evolve(r, all_negs=True) for r in config.relations]
    eval_config = attr.evolve(
        config, edge_paths=[output_test_path], relations=relations, num_uniform_negs=0)
    if args.filtered:
        filter_paths = [output_test_path, output_valid_path, output_train_path]
        do_eval(
            eval_config,
            evaluator=FilteredRankingEvaluator(eval_config, filter_paths),
            subprocess_init=subprocess_init,
        )
    else:
        do_eval(eval_config, subprocess_init=subprocess_init)
Example No. 7
def main():
    parser = argparse.ArgumentParser(description='Example on FB15k')
    parser.add_argument('--config',
                        default='./fb15k_config.py',
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir',
                        default='../../../data',
                        help='where to save processed data')
    parser.add_argument('--no-filtered',
                        dest='filtered',
                        action='store_false',
                        help='Run unfiltered eval')
    args = parser.parse_args()

    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten
    else:
        overrides = None

    # download data
    data_dir = args.data_dir
    #fpath = utils.download_url(FB15K_URL, data_dir)
    #utils.extract_tar(fpath)
    #print('Downloaded and extracted file.')

    edge_paths = [os.path.join(data_dir, name) for name in FILENAMES.values()]
    print('edge_paths', edge_paths)
    convert_input_data(
        args.config,
        edge_paths,
        lhs_col=0,
        rhs_col=2,
        rel_col=1,
    )

    config = parse_config(args.config, overrides)

    train_path = [convert_path(os.path.join(data_dir, FILENAMES['train']))]
    train_config = attr.evolve(config, edge_paths=train_path)

    train(train_config)

    eval_path = [convert_path(os.path.join(data_dir, FILENAMES['test']))]
    relations = [attr.evolve(r, all_negs=True) for r in config.relations]
    eval_config = attr.evolve(config,
                              edge_paths=eval_path,
                              relations=relations)
    if args.filtered:
        filter_paths = [
            convert_path(os.path.join(data_dir, FILENAMES['test'])),
            convert_path(os.path.join(data_dir, FILENAMES['valid'])),
            convert_path(os.path.join(data_dir, FILENAMES['train'])),
        ]
        do_eval(eval_config, FilteredRankingEvaluator(eval_config,
                                                      filter_paths))
    else:
        do_eval(eval_config)
Example No. 8
def main():
    parser = argparse.ArgumentParser(description='Example on Livejournal')
    parser.add_argument('--config',
                        default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir',
                        default='data',
                        help='where to save processed data')

    args = parser.parse_args()

    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten
    else:
        overrides = None

    # download data
    data_dir = args.data_dir
    os.makedirs(data_dir, exist_ok=True)
    fpath = utils.download_url(URL, data_dir)
    fpath = utils.extract_gzip(fpath)
    print('Downloaded and extracted file.')

    # randomly split the file into train and test sets
    random_split_file(fpath)

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, overrides)
    edge_paths = [os.path.join(data_dir, name) for name in FILENAMES.values()]

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        edge_paths,
        lhs_col=0,
        rhs_col=1,
        rel_col=None,
        dynamic_relations=config.dynamic_relations,
    )

    train_path = [convert_path(os.path.join(data_dir, FILENAMES['train']))]
    train_config = attr.evolve(config, edge_paths=train_path)

    train(
        train_config,
        subprocess_init=partial(add_to_sys_path, loader.config_dir.name),
    )

    eval_path = [convert_path(os.path.join(data_dir, FILENAMES['test']))]
    eval_config = attr.evolve(config, edge_paths=eval_path)

    do_eval(
        eval_config,
        subprocess_init=partial(add_to_sys_path, loader.config_dir.name),
    )
Example No. 9
def main():
    setup_logging()
    parser = argparse.ArgumentParser(description='Example on Livejournal')
    parser.add_argument('--config',
                        default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir',
                        type=Path,
                        default='data',
                        help='where to save processed data')

    args = parser.parse_args()

    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten
    else:
        overrides = None

    # download data
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    fpath = download_url(URL, data_dir)
    fpath = extract_gzip(fpath)
    print('Downloaded and extracted file.')

    # randomly split the file into train and test sets
    random_split_file(fpath)

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)
    edge_paths = [data_dir / name for name in FILENAMES.values()]

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        edge_paths,
        lhs_col=0,
        rhs_col=1,
        rel_col=None,
        dynamic_relations=config.dynamic_relations,
    )

    train_path = [str(convert_path(data_dir / FILENAMES['train']))]
    train_config = attr.evolve(config, edge_paths=train_path)

    train(train_config, subprocess_init=subprocess_init)

    eval_path = [str(convert_path(data_dir / FILENAMES['test']))]
    eval_config = attr.evolve(config, edge_paths=eval_path)

    do_eval(eval_config, subprocess_init=subprocess_init)
Example No. 10
def main():
    setup_logging()
    parser = argparse.ArgumentParser(description='Example on Livejournal')
    parser.add_argument('--config',
                        default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir',
                        type=Path,
                        default='data',
                        help='where to save processed data')

    args = parser.parse_args()

    # download data
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    fpath = download_url(URL, data_dir)
    fpath = extract_gzip(fpath)
    print('Downloaded and extracted file.')

    # randomly split the file into train and test sets
    random_split_file(fpath)

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)
    input_edge_paths = [data_dir / name for name in FILENAMES]
    output_train_path, output_test_path = config.edge_paths

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        input_edge_paths,
        TSVEdgelistReader(lhs_col=0, rhs_col=1, rel_col=None),
        dynamic_relations=config.dynamic_relations,
    )

    train_config = attr.evolve(config, edge_paths=[output_train_path])
    train(train_config, subprocess_init=subprocess_init)

    eval_config = attr.evolve(config, edge_paths=[output_test_path])
    do_eval(eval_config, subprocess_init=subprocess_init)
Example No. 11
def run_train_eval():
    # Convert the data into the partitioned files that PBG can read
    convert_input_data(CONFIG_PATH,
                       edge_paths,
                       lhs_col=0,
                       rhs_col=1,
                       rel_col=None)
    # Parse the config
    config = parse_config(CONFIG_PATH)
    # Training config: replace edge_paths from the config file with the partitioned train_paths
    train_config = attr.evolve(config, edge_paths=train_paths)
    # Start training with the training config
    train(train_config)
    # Eval config: replace edge_paths from the config file with the partitioned eval_paths
    eval_config = attr.evolve(config, edge_paths=eval_paths)
    # Run evaluation
    do_eval(eval_config)
Example No. 12
 def test_dynamic_relations(self):
     relation_config = RelationSchema(name="r", lhs="el", rhs="er")
     base_config = ConfigSchema(
         dimension=10,
         relations=[relation_config],
         entities={
             "el": EntitySchema(num_partitions=1),
             "er": EntitySchema(num_partitions=1),
         },
         entity_path=None,  # filled in later
         edge_paths=[],  # filled in later
         checkpoint_path=self.checkpoint_path.name,
         dynamic_relations=True,
         global_emb=False,  # Must be off for dynamic relations.
         workers=2,
     )
     gen_config = attr.evolve(
         base_config,
         relations=[relation_config] * 10,
         dynamic_relations=False,  # Must be off if more than 1 relation.
     )
     dataset = generate_dataset(gen_config,
                                num_entities=100,
                                fractions=[0.04, 0.02])
     self.addCleanup(dataset.cleanup)
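     # With dynamic_relations=True, PBG expects a dynamic_rel_count.txt file in the
     # entity path giving the number of relation types; the dataset above was
     # generated with dynamic_relations off, so the test writes that file by hand.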
     with open(
             os.path.join(dataset.entity_path.name,
                          "dynamic_rel_count.txt"), "xt") as f:
         f.write("%d" % len(gen_config.relations))
     train_config = attr.evolve(
         base_config,
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[0].name],
     )
     eval_config = attr.evolve(
         base_config,
         relations=[attr.evolve(relation_config, all_negs=True)],
         entity_path=dataset.entity_path.name,
         edge_paths=[dataset.relation_paths[1].name],
     )
     # Just make sure no exceptions are raised and nothing crashes.
     train(train_config, rank=0, subprocess_init=self.subprocess_init)
     self.assertCheckpointWritten(train_config, version=1)
     do_eval(eval_config, subprocess_init=self.subprocess_init)
Example No. 13
def run_train_eval():
    random_split_file(DATA_PATH)

    convert_input_data(
        CONFIG_PATH,
        edge_paths,
        lhs_col=0,
        rhs_col=1,
        rel_col=None,
    )

    train_config = parse_config(CONFIG_PATH)

    train_config = attr.evolve(train_config, edge_paths=train_path)

    train(train_config)

    eval_config = attr.evolve(train_config, edge_paths=eval_path)

    do_eval(eval_config)
Example No. 14
from torchbiggraph.config import parse_config
import attr
train_config = parse_config(CONFIG_PATH)

train_path = [convert_path(os.path.join(DATA_DIR, FILENAMES['train']))]
train_config = attr.evolve(train_config, edge_paths=train_path)

from torchbiggraph.train import train
train(train_config)

# Time to run on liveJournal data: 17:43 - ???
### SNIPPET 3 ###
Example No. 15
    )

    # ===============================================
    # 3. TRAIN THE EMBEDDINGS
    # files generated in this step:
    #
    # checkpoint_version.txt
    # config.json
    # embeddings_item_0.v7.h5
    # embeddings_merchant_0.v7.h5
    # embeddings_user_0.v7.h5
    # model.v7.h5
    # training_stats.json
    # ===============================================

    train(config, subprocess_init=subprocess_init)

    # =======================================================================
    # 4. LOAD THE EMBEDDINGS
    # The final output of the process consists of 3 dictionaries (one each for
    # users, items and merchants), mapping each entity to its embedding.
    # =======================================================================
    users_path = DATA_DIR + '/entity_names_user_0.json'
    items_path = DATA_DIR + '/entity_names_item_0.json'
    merchants_path = DATA_DIR + '/entity_names_merchant_0.json'

    user_emb_path = MODEL_DIR + "/embeddings_user_0.v{NUMBER_OF_EPOCHS}.h5" \
        .format(NUMBER_OF_EPOCHS=raw_config['num_epochs'])
    item_emb_path = MODEL_DIR + "/embeddings_item_0.v{NUMBER_OF_EPOCHS}.h5" \
        .format(NUMBER_OF_EPOCHS=raw_config['num_epochs'])
    merchant_emb_path = MODEL_DIR + "/embeddings_merchant_0.v{NUMBER_OF_EPOCHS}.h5" \
        .format(NUMBER_OF_EPOCHS=raw_config['num_epochs'])
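    # A minimal sketch of loading one of those dictionaries, assuming h5py/json are
    # available and that, as in standard PBG checkpoints, the vectors live under the
    # "embeddings" dataset of the .h5 file.
    import json
    import h5py

    with open(users_path, "rt") as tf:
        user_names = json.load(tf)  # ordered list of user entity names
    with h5py.File(user_emb_path, "r") as hf:
        user_embeddings = hf["embeddings"][...]  # one row per entity, same order
    user_emb_dict = dict(zip(user_names, user_embeddings))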
Example No. 16
def main():
    parser = argparse.ArgumentParser(description='Example on FB15k')
    parser.add_argument('--config',
                        default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir',
                        default='data',
                        help='where to save processed data')
    parser.add_argument('--no-filtered',
                        dest='filtered',
                        action='store_false',
                        help='Run unfiltered eval')
    args = parser.parse_args()

    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten
    else:
        overrides = None

    # download data
    data_dir = args.data_dir
    fpath = utils.download_url(FB15K_URL, data_dir)
    utils.extract_tar(fpath)
    print('Downloaded and extracted file.')

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, overrides)
    edge_paths = [os.path.join(data_dir, name) for name in FILENAMES.values()]

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        edge_paths,
        lhs_col=0,
        rhs_col=2,
        rel_col=1,
        dynamic_relations=config.dynamic_relations,
    )

    train_path = [convert_path(os.path.join(data_dir, FILENAMES['train']))]
    train_config = attr.evolve(config, edge_paths=train_path)

    train(
        train_config,
        subprocess_init=partial(add_to_sys_path, loader.config_dir.name),
    )

    eval_path = [convert_path(os.path.join(data_dir, FILENAMES['test']))]
    relations = [attr.evolve(r, all_negs=True) for r in config.relations]
    eval_config = attr.evolve(config,
                              edge_paths=eval_path,
                              relations=relations,
                              num_uniform_negs=0)
    if args.filtered:
        filter_paths = [
            convert_path(os.path.join(data_dir, FILENAMES['test'])),
            convert_path(os.path.join(data_dir, FILENAMES['valid'])),
            convert_path(os.path.join(data_dir, FILENAMES['train'])),
        ]
        do_eval(
            eval_config,
            evaluator=FilteredRankingEvaluator(eval_config, filter_paths),
            subprocess_init=partial(add_to_sys_path, loader.config_dir.name),
        )
    else:
        do_eval(
            eval_config,
            subprocess_init=partial(add_to_sys_path, loader.config_dir.name),
        )
Example No. 17
def run(input_file: KGTKFiles,
        output_file: KGTKFiles,
        verbose: bool = False,
        very_verbose: bool = False,
        **kwargs):
    """
    **kwargs stores all parameters provided by the user
    """
    # print(kwargs)

    # import modules locally
    import sys
    import typing
    import os
    import logging
    from pathlib import Path
    import json, os, h5py, gzip, torch, shutil
    from torchbiggraph.config import parse_config
    from kgtk.exceptions import KGTKException
    # copy  missing file under kgtk/graph_embeddings
    from kgtk.templates.kgtkcopytemplate import KgtkCopyTemplate
    from kgtk.graph_embeddings.importers import TSVEdgelistReader, convert_input_data
    from torchbiggraph.train import train
    from torchbiggraph.util import SubprocessInitializer, setup_logging
    from kgtk.graph_embeddings.export_to_tsv import make_tsv
    # from torchbiggraph.converters.export_to_tsv import make_tsv

    try:
        input_kgtk_file: Path = KGTKArgumentParser.get_input_file(input_file)
        output_kgtk_file: Path = KGTKArgumentParser.get_output_file(
            output_file)

        # write log output to the given file so that nothing is printed to the console
        if kwargs['log_file_path'] != None:
            log_file_path = kwargs['log_file_path']
            logging.basicConfig(
                format='%(asctime)s - %(filename)s[line:%(lineno)d] \
            - %(levelname)s: %(message)s',
                level=logging.DEBUG,
                filename=str(log_file_path),
                filemode='w')
            print(
                f'In Processing, Please go to {kwargs["log_file_path"]} to check details',
                file=sys.stderr,
                flush=True)

        tmp_folder = kwargs['temporary_directory']
        tmp_tsv_path: Path = tmp_folder / f'tmp_{input_kgtk_file.name}'
        # tmp_tsv_path:Path = input_kgtk_file.parent/f'tmp_{input_kgtk_file.name}'

        #  make sure the tmp folder exists, otherwise it will raise an exception
        if not os.path.exists(tmp_folder):
            os.makedirs(tmp_folder)

        try:  # if output_kgtk_file already exists, delete it
            output_kgtk_file.unlink()
        except:
            pass  # the file did not exist; nothing to delete

        # *********************************************
        # 0. PREPARE PBG TSV FILE
        # *********************************************
        reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(kwargs)
        value_options: KgtkValueOptions = KgtkValueOptions.from_dict(kwargs)
        error_file: typing.TextIO = sys.stdout if kwargs.get(
            "errors_to_stdout") else sys.stderr
        kct: KgtkCopyTemplate = KgtkCreateTmpTsv(
            input_file_path=input_kgtk_file,
            output_file_path=tmp_tsv_path,
            reader_options=reader_options,
            value_options=value_options,
            error_file=error_file,
            verbose=verbose,
            very_verbose=very_verbose,
        )
        # prepare the graph file
        # create a tmp tsv file for PBG embedding

        logging.info('Generate the valid tsv format for embedding ...')
        kct.process()
        logging.info('Embedding file is ready...')

        # *********************************************
        # 1. DEFINE CONFIG
        # *********************************************
        raw_config = get_config(**kwargs)

        ## set the learning rate and loss function appropriate to the chosen algorithm
        processed_config = config_preprocess(raw_config)

        # temporary output folder
        tmp_output_folder = Path(processed_config['entity_path'])

        # remove any stale temporary output folder left over from a previous run
        try:  # if the temporary output folder already exists, delete it
            shutil.rmtree(tmp_output_folder)
        except:
            pass  # the folder did not exist; nothing to remove

        # **************************************************
        # 2. TRANSFORM GRAPH TO A BIGGRAPH-FRIENDLY FORMAT
        # **************************************************
        setup_logging()
        config = parse_config(processed_config)
        subprocess_init = SubprocessInitializer()
        input_edge_paths = [tmp_tsv_path]

        convert_input_data(
            config.entities,
            config.relations,
            config.entity_path,
            config.edge_paths,
            input_edge_paths,
            TSVEdgelistReader(lhs_col=0, rel_col=1, rhs_col=2),
            dynamic_relations=config.dynamic_relations,
        )

        # ************************************************
        # 3. TRAIN THE EMBEDDINGS
        #*************************************************
        train(config, subprocess_init=subprocess_init)

        # ************************************************
        # 4. GENERATE THE OUTPUT
        # ************************************************
        # entities_output = output_kgtk_file
        entities_output = tmp_output_folder / 'entities_output.tsv'
        relation_types_output = tmp_output_folder / 'relation_types_tf.tsv'

        with open(entities_output,
                  "xt") as entities_tf, open(relation_types_output,
                                             "xt") as relation_types_tf:
            make_tsv(config, entities_tf, relation_types_tf)

        # output  correct format for embeddings
        if kwargs['output_format'] == 'glove':  # glove format output
            shutil.copyfile(entities_output, output_kgtk_file)
        elif kwargs['output_format'] == 'w2v':  # w2v format output
            generate_w2v_output(entities_output, output_kgtk_file, kwargs)

        else:  # write to the kgtk output format tsv
            generate_kgtk_output(entities_output, output_kgtk_file,
                                 kwargs.get('output_no_header', False),
                                 verbose, very_verbose)
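        # Presumably the glove-style output is one "name v1 ... vd" line per entity,
        # while the w2v-style output prepends a "num_entities dimension" header line.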

        logging.info(f'Embeddings has been generated in {output_kgtk_file}.')

        # ************************************************
        # 5. Garbage collection
        # ************************************************
        if kwargs['retain_temporary_data'] == False:
            shutil.rmtree(kwargs['temporary_directory'])
            # tmp_tsv_path.unlink() # delete the temporary tsv file
            # shutil.rmtree(tmp_output_folder) # delete the temporary output folder

        if kwargs["log_file_path"] != None:
            print('Processed Finished.', file=sys.stderr, flush=True)
            logging.info(
                f"Process Finished.\nOutput has been saved in {repr(str(output_kgtk_file))}"
            )
        else:
            print(
                f"Process Finished.\nOutput has been saved in {repr(str(output_kgtk_file))}",
                file=sys.stderr,
                flush=True)

    except Exception as e:
        raise KGTKException(str(e))
Example No. 18
def main():
    setup_logging()
    parser = argparse.ArgumentParser(description="Example on FB15k")
    parser.add_argument("--config",
                        default=DEFAULT_CONFIG,
                        help="Path to config file")
    parser.add_argument("-p", "--param", action="append", nargs="*")
    parser.add_argument("--data_dir",
                        type=Path,
                        default="data",
                        help="where to save processed data")
    parser.add_argument(
        "--no-filtered",
        dest="filtered",
        action="store_false",
        help="Run unfiltered eval",
    )
    args = parser.parse_args()

    # download data
    data_dir = args.data_dir
    fpath = download_url(FB15K_URL, data_dir)
    extract_tar(fpath)
    print("Downloaded and extracted file.")

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)
    input_edge_paths = [data_dir / name for name in FILENAMES]
    output_train_path, output_valid_path, output_test_path = config.edge_paths

    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        input_edge_paths,
        TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1),
        dynamic_relations=config.dynamic_relations,
    )

    train_config = attr.evolve(config, edge_paths=[output_train_path])
    train(train_config, subprocess_init=subprocess_init)

    relations = [attr.evolve(r, all_negs=True) for r in config.relations]
    eval_config = attr.evolve(config,
                              edge_paths=[output_test_path],
                              relations=relations,
                              num_uniform_negs=0)
    if args.filtered:
        filter_paths = [output_test_path, output_valid_path, output_train_path]
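        # FilteredRankingEvaluator drops any candidate that appears in these edge
        # lists when ranking each test edge, i.e. the standard "filtered" metrics.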
        do_eval(
            eval_config,
            evaluator=FilteredRankingEvaluator(eval_config, filter_paths),
            subprocess_init=subprocess_init,
        )
    else:
        do_eval(eval_config, subprocess_init=subprocess_init)