def main():
    """CLI entry point: parse options, load the raw config, and convert the
    given parquet edge lists into the partitioned on-disk format."""
    epilog = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=epilog,
        # Raw formatter so the newlines in the config help are preserved.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("config", help="Path to config file")
    parser.add_argument("-p", "--param", action="append", nargs="*")
    parser.add_argument("edge_paths", type=Path, nargs="*", help="Input file paths")
    parser.add_argument(
        "-l", "--lhs-col", type=str, required=True,
        help="Column index for source entity",
    )
    parser.add_argument(
        "-r", "--rhs-col", type=str, required=True,
        help="Column index for target entity",
    )
    parser.add_argument("--rel-col", type=str, help="Column index for relation entity")
    parser.add_argument(
        "--relation-type-min-count", type=int, default=1,
        help="Min count for relation types",
    )
    parser.add_argument(
        "--entity-min-count", type=int, default=1,
        help="Min count for entities",
    )
    args = parser.parse_args()

    # Load the raw (unvalidated) config dict, applying any -p overrides.
    raw_config = ConfigFileLoader().load_raw_config(args.config, args.param)
    (
        entity_configs,
        relation_configs,
        entity_path,
        edge_paths,
        dynamic_relations,
    ) = parse_config_partial(raw_config)  # noqa

    reader = ParquetEdgelistReader(args.lhs_col, args.rhs_col, args.rel_col)
    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        edge_paths,
        args.edge_paths,
        reader,
        args.entity_min_count,
        args.relation_type_min_count,
        dynamic_relations,
    )
def main():
    """Parse CLI arguments, load the raw config, and run the parquet
    edge-list conversion."""
    config_help = '\n\nConfig parameters:\n\n' + '\n'.join(ConfigSchema.help())
    # RawDescriptionHelpFormatter keeps the line wraps in the epilog intact.
    parser = argparse.ArgumentParser(
        epilog=config_help,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('config', help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('edge_paths', type=Path, nargs='*',
                        help='Input file paths')
    parser.add_argument('-l', '--lhs-col', type=str, required=True,
                        help='Column index for source entity')
    parser.add_argument('-r', '--rhs-col', type=str, required=True,
                        help='Column index for target entity')
    parser.add_argument('--rel-col', type=str,
                        help='Column index for relation entity')
    parser.add_argument('--relation-type-min-count', type=int, default=1,
                        help='Min count for relation types')
    parser.add_argument('--entity-min-count', type=int, default=1,
                        help='Min count for entities')
    options = parser.parse_args()

    config_loader = ConfigFileLoader()
    raw_config = config_loader.load_raw_config(options.config, options.param)
    partial = parse_config_partial(raw_config)
    entity_configs, relation_configs, entity_path, edge_paths, dynamic_relations = partial

    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        edge_paths,
        options.edge_paths,
        ParquetEdgelistReader(options.lhs_col, options.rhs_col, options.rel_col),
        options.entity_min_count,
        options.relation_type_min_count,
        dynamic_relations,
    )
def main():
    """End-to-end FB15k example: download the dataset, convert it, train
    embeddings, and evaluate (filtered or unfiltered ranking).

    Fix: create ``--data_dir`` before downloading, matching the sibling
    Livejournal example — previously a non-existent directory could make
    the download step fail.
    """
    setup_logging()
    parser = argparse.ArgumentParser(description='Example on FB15k')
    parser.add_argument('--config', default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir', type=Path, default='data',
                        help='where to save processed data')
    parser.add_argument('--no-filtered', dest='filtered', action='store_false',
                        help='Run unfiltered eval')
    args = parser.parse_args()

    # Download and extract the dataset. Ensure the target directory exists
    # first (consistent with the Livejournal example).
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    fpath = download_url(FB15K_URL, data_dir)
    extract_tar(fpath)
    print('Downloaded and extracted file.')

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)
    # Re-run logging setup and sys.path registration in worker subprocesses.
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)
    input_edge_paths = [data_dir / name for name in FILENAMES]
    output_train_path, output_valid_path, output_test_path = config.edge_paths

    # FB15k TSVs are (lhs, rel, rhs) triples: columns 0, 1, 2.
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        input_edge_paths,
        TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1),
        dynamic_relations=config.dynamic_relations,
    )

    train_config = attr.evolve(config, edge_paths=[output_train_path])
    train(train_config, subprocess_init=subprocess_init)

    # Evaluate against all negatives (no uniform sampling) on the test split.
    relations = [attr.evolve(r, all_negs=True) for r in config.relations]
    eval_config = attr.evolve(
        config,
        edge_paths=[output_test_path],
        relations=relations,
        num_uniform_negs=0,
    )
    if args.filtered:
        # Filtered ranking: known true edges from all splits are excluded
        # from the candidate negatives.
        filter_paths = [output_test_path, output_valid_path, output_train_path]
        do_eval(
            eval_config,
            evaluator=FilteredRankingEvaluator(eval_config, filter_paths),
            subprocess_init=subprocess_init,
        )
    else:
        do_eval(eval_config, subprocess_init=subprocess_init)
def main():
    """End-to-end Livejournal example: download the edge list, split it into
    train/test, convert, train embeddings, and evaluate."""
    setup_logging()
    parser = argparse.ArgumentParser(description='Example on Livejournal')
    parser.add_argument('--config', default=DEFAULT_CONFIG,
                        help='Path to config file')
    parser.add_argument('-p', '--param', action='append', nargs='*')
    parser.add_argument('--data_dir', type=Path, default='data',
                        help='where to save processed data')
    args = parser.parse_args()

    # Fetch and decompress the raw edge list.
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)
    archive_path = download_url(URL, data_dir)
    extracted_path = extract_gzip(archive_path)
    print('Downloaded and extracted file.')
    # Randomly partition the edges into train and test files.
    random_split_file(extracted_path)

    loader = ConfigFileLoader()
    config = loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)
    # Workers must redo logging setup and sys.path registration themselves.
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    input_edge_paths = [data_dir / name for name in FILENAMES]
    output_train_path, output_test_path = config.edge_paths

    # Livejournal edges are plain (lhs, rhs) pairs with no relation column.
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        input_edge_paths,
        TSVEdgelistReader(lhs_col=0, rhs_col=1, rel_col=None),
        dynamic_relations=config.dynamic_relations,
    )

    train_config = attr.evolve(config, edge_paths=[output_train_path])
    train(train_config, subprocess_init=subprocess_init)

    eval_config = attr.evolve(config, edge_paths=[output_test_path])
    do_eval(eval_config, subprocess_init=subprocess_init)
# data/example_2/entity_names_merchant_0.json # data/example_2/entity_names_user_0.json # # and this file with data: # data/example_2/edges_partitioned/edges_0_0.h5 # ================================================= setup_logging() config = parse_config(raw_config) subprocess_init = SubprocessInitializer() input_edge_paths = [Path(GRAPH_PATH)] convert_input_data( config.entities, config.relations, config.entity_path, config.edge_paths, input_edge_paths, TSVEdgelistReader(lhs_col=0, rel_col=1, rhs_col=2), dynamic_relations=config.dynamic_relations, ) # =============================================== # 3. TRAIN THE EMBEDDINGS # files generated in this step: # # checkpoint_version.txt # config.json # embeddings_item_0.v7.h5 # embeddings_merchant_0.v7.h5 # embeddings_user_0.v7.h5 # model.v7.h5