def main():
    parser = argparse.ArgumentParser(description='Example on FB15k')
    parser.add_argument('--config', default='./fb15k_config.py',
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir', default='../../../data',
                        help='where to save processed data')
    parser.add_argument('--no-filtered', dest='filtered', action='store_false',
                        help='Run unfiltered eval')
    args = parser.parse_args()

    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten
    else:
        overrides = None

    # download data
    data_dir = args.data_dir
    # fpath = utils.download_url(FB15K_URL, data_dir)
    # utils.extract_tar(fpath)
    # print('Downloaded and extracted file.')

    edge_paths = [os.path.join(data_dir, name) for name in FILENAMES.values()]
    print('edge_paths', edge_paths)
    convert_input_data(
        args.config,
        edge_paths,
        lhs_col=0,
        rhs_col=2,
        rel_col=1,
    )

    config = parse_config(args.config, overrides)

    train_path = [convert_path(os.path.join(data_dir, FILENAMES['train']))]
    train_config = attr.evolve(config, edge_paths=train_path)
    train(train_config)

    eval_path = [convert_path(os.path.join(data_dir, FILENAMES['test']))]
    relations = [attr.evolve(r, all_negs=True) for r in config.relations]
    eval_config = attr.evolve(config, edge_paths=eval_path, relations=relations)
    if args.filtered:
        filter_paths = [
            convert_path(os.path.join(data_dir, FILENAMES['test'])),
            convert_path(os.path.join(data_dir, FILENAMES['valid'])),
            convert_path(os.path.join(data_dir, FILENAMES['train'])),
        ]
        do_eval(eval_config, FilteredRankingEvaluator(eval_config, filter_paths))
    else:
        do_eval(eval_config)
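# For illustration: a minimal sketch of what the -p/--param flattening above
# does. argparse collects each "-p key=value" into a nested list, and
# chain.from_iterable flattens it into the iterable of override strings that
# parse_config consumes. The override values here are made-up examples.
from itertools import chain

params = [['num_epochs=10'], ['dimension=400']]  # what argparse would produce
overrides = list(chain.from_iterable(params))
assert overrides == ['num_epochs=10', 'dimension=400']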
def main():
    parser = argparse.ArgumentParser(description='Example on Livejournal')
    parser.add_argument('--config', default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir', default='data',
                        help='where to save processed data')
    args = parser.parse_args()

    if args.param is not None:
        overrides = chain.from_iterable(args.param)  # flatten
    else:
        overrides = None

    # download data
    data_dir = args.data_dir
    os.makedirs(data_dir, exist_ok=True)
    fpath = utils.download_url(URL, data_dir)
    fpath = utils.extract_gzip(fpath)
    print('Downloaded and extracted file.')

    # randomly split the edges into train and test files
    random_split_file(fpath)

    entity_configs, relation_configs, entity_path, dynamic_relations = \
        validate_config(args.config)

    edge_paths = [os.path.join(data_dir, name) for name in FILENAMES.values()]
    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        edge_paths,
        lhs_col=0,
        rhs_col=1,
        rel_col=None,
        dynamic_relations=dynamic_relations,
    )

    config = parse_config(args.config, overrides)

    train_path = [convert_path(os.path.join(data_dir, FILENAMES['train']))]
    train_config = attr.evolve(config, edge_paths=train_path)
    train(train_config)

    eval_path = [convert_path(os.path.join(data_dir, FILENAMES['test']))]
    eval_config = attr.evolve(config, edge_paths=eval_path)
    do_eval(eval_config)
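# The definition of random_split_file is not shown above; a plausible sketch,
# assuming a shuffled 90/10 split into train/test files whose names (here
# 'train.txt' and 'test.txt') would in practice come from FILENAMES:
import os
import random


def random_split_file_sketch(fpath, train_frac=0.9):
    with open(fpath, 'rt') as f:
        lines = f.readlines()
    random.shuffle(lines)
    split = int(len(lines) * train_frac)
    data_dir = os.path.dirname(fpath)
    with open(os.path.join(data_dir, 'train.txt'), 'wt') as f:
        f.writelines(lines[:split])
    with open(os.path.join(data_dir, 'test.txt'), 'wt') as f:
        f.writelines(lines[split:])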
def run_train_eval():
    # Convert the data into PBG-readable partitioned files
    convert_input_data(CONFIG_PATH, edge_paths, lhs_col=0, rhs_col=1, rel_col=None)
    # Parse the config
    config = parse_config(CONFIG_PATH)
    # Training config: replace edge_paths from the config file with the
    # partitioned train_paths
    train_config = attr.evolve(config, edge_paths=train_paths)
    # Start training with the training config
    train(train_config)
    # Eval config: replace edge_paths from the config file with the
    # partitioned eval_paths
    eval_config = attr.evolve(config, edge_paths=eval_paths)
    # Run evaluation
    do_eval(eval_config)
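# For context, a minimal sketch of what the config file at CONFIG_PATH could
# contain. PBG config files are Python modules exposing
# get_torchbiggraph_config(); the entity/relation names, paths, and
# hyperparameters below are illustrative only.
def get_torchbiggraph_config():
    return dict(
        entity_path='data/example',     # where entity metadata is written
        edge_paths=[],                  # overridden at runtime via attr.evolve
        checkpoint_path='model/example',
        entities={'all': {'num_partitions': 1}},
        relations=[{
            'name': 'all_edges',
            'lhs': 'all',
            'rhs': 'all',
            'operator': 'none',
        }],
        dimension=128,
        num_epochs=10,
    )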
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None

    config = parse_config(opt.config, overrides)

    do_eval(config)
def main():
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help="Path to config file")
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--rank', type=int, default=0,
                        help="For multi-machine, this machine's rank")
    opt = parser.parse_args()

    if opt.param is not None:
        overrides = chain.from_iterable(opt.param)  # flatten
    else:
        overrides = None

    config = parse_config(opt.config, overrides)

    train(config, rank=Rank(opt.rank))
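# A hypothetical two-machine launch of the entry point above. The config file
# name is made up, and the config itself is assumed to set num_machines=2 and
# a shared distributed_init_method. Each machine runs the same command with
# its own rank:
#
#   machine 0:  python train.py my_config.py --rank 0
#   machine 1:  python train.py my_config.py --rank 1
#
# which, on machine 0, is equivalent to calling:
config = parse_config('my_config.py', None)
train(config, rank=Rank(0))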
def run_train_eval():
    random_split_file(DATA_PATH)

    convert_input_data(
        CONFIG_PATH,
        edge_paths,
        lhs_col=0,
        rhs_col=1,
        rel_col=None,
    )

    train_config = parse_config(CONFIG_PATH)
    train_config = attr.evolve(train_config, edge_paths=train_path)
    train(train_config)

    eval_config = attr.evolve(train_config, edge_paths=eval_path)
    do_eval(eval_config)
import os

import attr

from torchbiggraph.config import parse_config

train_config = parse_config(CONFIG_PATH)

train_path = [convert_path(os.path.join(DATA_DIR, FILENAMES['train']))]
train_config = attr.evolve(train_config, edge_paths=train_path)

from torchbiggraph.train import train

train(train_config)

# Time to run on liveJournal data: 17:43 - ???

### SNIPPET 3 ###
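# Once train() has written a checkpoint, the embeddings can be read back from
# the HDF5 files it produces ('embeddings' is the dataset key PBG uses). The
# file paths below are hypothetical: they depend on the entity_path,
# checkpoint_path, entity type names, and checkpoint version in the config.
import json

import h5py

with open('data/example/entity_names_user_id_0.json', 'rt') as f:
    entity_names = json.load(f)
with h5py.File('model/example/embeddings_user_id_0.v10.h5', 'r') as hf:
    embeddings = hf['embeddings'][...]
print(entity_names[0], embeddings[0][:5])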
# =================================================
# 2. TRANSFORM GRAPH TO A BIGGRAPH-FRIENDLY FORMAT
#
# This step generates the following metadata files:
#
#   data/example_2/entity_count_item_0.txt
#   data/example_2/entity_count_merchant_0.txt
#   data/example_2/entity_count_user_0.txt
#   data/example_2/entity_names_item_0.json
#   data/example_2/entity_names_merchant_0.json
#   data/example_2/entity_names_user_0.json
#
# and this file with the data:
#
#   data/example_2/edges_partitioned/edges_0_0.h5
# =================================================

setup_logging()
config = parse_config(raw_config)
subprocess_init = SubprocessInitializer()
input_edge_paths = [Path(GRAPH_PATH)]

convert_input_data(
    config.entities,
    config.relations,
    config.entity_path,
    config.edge_paths,
    input_edge_paths,
    TSVEdgelistReader(lhs_col=0, rel_col=1, rhs_col=2),
    dynamic_relations=config.dynamic_relations,
)

# ===============================================
# 3. TRAIN THE EMBEDDINGS
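# (For reference before step 3: a small sketch of producing an edge list in
# the column layout that TSVEdgelistReader(lhs_col=0, rel_col=1, rhs_col=2)
# above expects. The entities and relation names are made up.)
edges = [
    ('user_1', 'bought', 'item_42'),
    ('user_1', 'visited', 'merchant_7'),
]
with open(GRAPH_PATH, 'wt') as f:  # GRAPH_PATH as used above
    for lhs, rel, rhs in edges:
        f.write(f'{lhs}\t{rel}\t{rhs}\n')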
def run(input_file: KGTKFiles,
        output_file: KGTKFiles,
        verbose: bool = False,
        very_verbose: bool = False,
        **kwargs):
    """
    **kwargs stores all parameters provided by the user
    """
    # print(kwargs)

    # import modules locally
    import sys
    import typing
    import os
    import logging
    from pathlib import Path
    import json, h5py, gzip, torch, shutil
    from torchbiggraph.config import parse_config
    from kgtk.exceptions import KGTKException
    # files copied under kgtk/graph_embeddings to fill in missing pieces
    from kgtk.templates.kgtkcopytemplate import KgtkCopyTemplate
    from kgtk.graph_embeddings.importers import TSVEdgelistReader, convert_input_data
    from torchbiggraph.train import train
    from torchbiggraph.util import SubprocessInitializer, setup_logging
    from kgtk.graph_embeddings.export_to_tsv import make_tsv
    # from torchbiggraph.converters.export_to_tsv import make_tsv

    try:
        input_kgtk_file: Path = KGTKArgumentParser.get_input_file(input_file)
        output_kgtk_file: Path = KGTKArgumentParser.get_output_file(output_file)

        # write everything to the log file so the console outputs nothing
        if kwargs['log_file_path'] is not None:
            log_file_path = kwargs['log_file_path']
            logging.basicConfig(
                format='%(asctime)s - %(filename)s[line:%(lineno)d] '
                       '- %(levelname)s: %(message)s',
                level=logging.DEBUG,
                filename=str(log_file_path),
                filemode='w')
            print(f'In processing, please go to {kwargs["log_file_path"]} to check details',
                  file=sys.stderr, flush=True)

        tmp_folder = kwargs['temporary_directory']
        tmp_tsv_path: Path = tmp_folder / f'tmp_{input_kgtk_file.name}'
        # tmp_tsv_path: Path = input_kgtk_file.parent / f'tmp_{input_kgtk_file.name}'

        # make sure the tmp folder exists, otherwise an exception is raised
        if not os.path.exists(tmp_folder):
            os.makedirs(tmp_folder)

        try:  # if output_kgtk_file already exists, delete it
            output_kgtk_file.unlink()
        except:
            pass  # didn't find it, let it go

        # *********************************************
        # 0. PREPARE PBG TSV FILE
        # *********************************************
        reader_options: KgtkReaderOptions = KgtkReaderOptions.from_dict(kwargs)
        value_options: KgtkValueOptions = KgtkValueOptions.from_dict(kwargs)
        error_file: typing.TextIO = \
            sys.stdout if kwargs.get("errors_to_stdout") else sys.stderr
        kct: KgtkCopyTemplate = KgtkCreateTmpTsv(
            input_file_path=input_kgtk_file,
            output_file_path=tmp_tsv_path,
            reader_options=reader_options,
            value_options=value_options,
            error_file=error_file,
            verbose=verbose,
            very_verbose=very_verbose,
        )
        # prepare the graph file: create a tmp tsv file for PBG embedding
        logging.info('Generating the valid TSV format for embedding ...')
        kct.process()
        logging.info('Embedding file is ready ...')

        # *********************************************
        # 1. DEFINE CONFIG
        # *********************************************
        raw_config = get_config(**kwargs)

        # set the learning rate and loss function for the chosen algorithm
        processed_config = config_preprocess(raw_config)

        # temporary output folder
        tmp_output_folder = Path(processed_config['entity_path'])

        try:  # if the temporary output folder already exists, delete it
            shutil.rmtree(tmp_output_folder)
        except:
            pass  # didn't find it, let it go

        # **************************************************
        # 2. TRANSFORM GRAPH TO A BIGGRAPH-FRIENDLY FORMAT
        # **************************************************
        setup_logging()
        config = parse_config(processed_config)
        subprocess_init = SubprocessInitializer()
        input_edge_paths = [tmp_tsv_path]

        convert_input_data(
            config.entities,
            config.relations,
            config.entity_path,
            config.edge_paths,
            input_edge_paths,
            TSVEdgelistReader(lhs_col=0, rel_col=1, rhs_col=2),
            dynamic_relations=config.dynamic_relations,
        )

        # ************************************************
        # 3. TRAIN THE EMBEDDINGS
        # ************************************************
        train(config, subprocess_init=subprocess_init)

        # ************************************************
        # 4. GENERATE THE OUTPUT
        # ************************************************
        # entities_output = output_kgtk_file
        entities_output = tmp_output_folder / 'entities_output.tsv'
        relation_types_output = tmp_output_folder / 'relation_types_tf.tsv'

        with open(entities_output, "xt") as entities_tf, \
                open(relation_types_output, "xt") as relation_types_tf:
            make_tsv(config, entities_tf, relation_types_tf)

        # write the embeddings in the requested output format
        if kwargs['output_format'] == 'glove':  # GloVe format output
            shutil.copyfile(entities_output, output_kgtk_file)
        elif kwargs['output_format'] == 'w2v':  # w2v format output
            generate_w2v_output(entities_output, output_kgtk_file, kwargs)
        else:  # write the KGTK output format tsv
            generate_kgtk_output(entities_output,
                                 output_kgtk_file,
                                 kwargs.get('output_no_header', False),
                                 verbose,
                                 very_verbose)

        logging.info(f'Embeddings have been generated in {output_kgtk_file}.')

        # ************************************************
        # 5. GARBAGE COLLECTION
        # ************************************************
        if not kwargs['retain_temporary_data']:
            shutil.rmtree(kwargs['temporary_directory'])
            # tmp_tsv_path.unlink()  # delete the temporary tsv file
            # shutil.rmtree(tmp_output_folder)  # delete the temporary output folder

        if kwargs["log_file_path"] is not None:
            print('Process finished.', file=sys.stderr, flush=True)
            logging.info(
                f"Process finished.\nOutput has been saved in {repr(str(output_kgtk_file))}")
        else:
            print(f"Process finished.\nOutput has been saved in {repr(str(output_kgtk_file))}",
                  file=sys.stderr, flush=True)

    except Exception as e:
        raise KGTKException(str(e))
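# generate_w2v_output is referenced above but its definition is not shown.
# A plausible sketch, assuming the GloVe-style entity file has one
# "name v1 v2 ..." row per entity and the word2vec text format differs only
# by a leading "<count> <dimension>" header line:
def generate_w2v_output_sketch(glove_path, out_path, dimension):
    with open(glove_path, 'rt') as fin:
        lines = fin.readlines()
    with open(out_path, 'wt') as fout:
        fout.write(f'{len(lines)} {dimension}\n')
        fout.writelines(lines)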