コード例 #1
0
def main():
    """Command-line entry point: parse arguments, load the config, run training.

    Reads the config path, optional ``-p/--param`` overrides and ``--rank``
    from the command line, loads the config through ``ConfigFileLoader`` and
    hands it to ``train`` together with a ``SubprocessInitializer``.
    """
    setup_logging()

    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help="Path to config file")
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    arg_parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help="For multi-machine, this machine's rank",
    )
    args = arg_parser.parse_args()

    # Every -p occurrence contributes its own list; merge them into one stream.
    overrides = None if args.param is None else chain.from_iterable(args.param)

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(args.config, overrides)
    set_logging_verbosity(config.verbose)

    # Functions registered here are passed on to train() via subprocess_init.
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    train(config, rank=Rank(args.rank), subprocess_init=initializer)
コード例 #2
0
ファイル: train.py プロジェクト: zxin1023/PyTorch-BigGraph
def main():
    """Command-line entry point: parse arguments, load the config, run training.

    The ``-p/--param`` values are passed to the loader exactly as argparse
    collected them; ``--rank`` defaults to ``SINGLE_TRAINER``.
    """
    setup_logging()

    schema_help = "\n".join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog="\n\nConfig parameters:\n\n" + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument("config", help="Path to config file")
    arg_parser.add_argument("-p", "--param", action="append", nargs="*")
    arg_parser.add_argument(
        "--rank",
        type=int,
        default=SINGLE_TRAINER,
        help="For multi-machine, this machine's rank",
    )
    args = arg_parser.parse_args()

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)

    # Functions registered here are passed on to train() via subprocess_init.
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    train(config, rank=args.rank, subprocess_init=initializer)
コード例 #3
0
def main():
    """Command-line entry point: convert input edge files using a Parquet reader.

    Parses the config path, ``-p/--param`` overrides, the input edge paths,
    the column selectors and the minimum-count thresholds, then forwards
    everything to ``convert_input_data``.
    """
    schema_help = "\n".join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog="\n\nConfig parameters:\n\n" + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument("config", help="Path to config file")
    arg_parser.add_argument("-p", "--param", action="append", nargs="*")
    arg_parser.add_argument(
        "edge_paths", type=Path, nargs="*", help="Input file paths"
    )
    arg_parser.add_argument(
        "-l", "--lhs-col", type=str, required=True,
        help="Column index for source entity",
    )
    arg_parser.add_argument(
        "-r", "--rhs-col", type=str, required=True,
        help="Column index for target entity",
    )
    arg_parser.add_argument(
        "--rel-col", type=str, help="Column index for relation entity"
    )
    arg_parser.add_argument(
        "--relation-type-min-count", type=int, default=1,
        help="Min count for relation types",
    )
    arg_parser.add_argument(
        "--entity-min-count", type=int, default=1,
        help="Min count for entities",
    )
    args = arg_parser.parse_args()

    raw_config = ConfigFileLoader().load_raw_config(args.config, args.param)

    (
        entity_configs,
        relation_configs,
        entity_path,
        edge_paths,
        dynamic_relations,
    ) = parse_config_partial(raw_config)

    # The reader receives the three column selectors given on the command line.
    edgelist_reader = ParquetEdgelistReader(
        args.lhs_col, args.rhs_col, args.rel_col
    )

    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        edge_paths,
        args.edge_paths,
        edgelist_reader,
        args.entity_min_count,
        args.relation_type_min_count,
        dynamic_relations,
    )
コード例 #4
0
ファイル: import_from_tsv.py プロジェクト: delldu/BigGraph
def main():
    """Command-line entry point: convert input edge files into the training format.

    Loads the raw config, applies any ``-p/--param`` overrides to it, extracts
    the entity/relation settings with ``parse_config_partial`` and forwards
    them, with the command-line options, to ``convert_input_data``.
    """
    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help='Path to config file')
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    arg_parser.add_argument(
        'edge_paths', type=Path, nargs='*', help='Input file paths'
    )
    arg_parser.add_argument(
        '-l', '--lhs-col', type=int, required=True,
        help='Column index for source entity',
    )
    arg_parser.add_argument(
        '-r', '--rhs-col', type=int, required=True,
        help='Column index for target entity',
    )
    arg_parser.add_argument(
        '--rel-col', type=int, help='Column index for relation entity'
    )
    arg_parser.add_argument(
        '--relation-type-min-count', type=int, default=1,
        help='Min count for relation types',
    )
    arg_parser.add_argument(
        '--entity-min-count', type=int, default=1,
        help='Min count for entities',
    )
    args = arg_parser.parse_args()

    raw_config = ConfigFileLoader().load_raw_config(args.config)

    if args.param is not None:
        # Every -p occurrence contributes its own list; merge them first.
        flat_overrides = chain.from_iterable(args.param)
        raw_config = override_config_dict(raw_config, flat_overrides)

    (
        entity_configs,
        relation_configs,
        entity_path,
        dynamic_relations,
    ) = parse_config_partial(raw_config)

    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        args.edge_paths,
        args.lhs_col,
        args.rhs_col,
        args.rel_col,
        args.entity_min_count,
        args.relation_type_min_count,
        dynamic_relations,
    )
コード例 #5
0
def main():
    """Command-line entry point: convert input edge files into the training format.

    The config path is handed to ``validate_config``; the values it returns
    are forwarded, with the command-line options, to ``convert_input_data``.
    """
    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help='Path to config file')
    arg_parser.add_argument('edge_paths', nargs='*', help='Input file paths')
    arg_parser.add_argument(
        '-l', '--lhs-col', type=int, required=True,
        help='Column index for source entity',
    )
    arg_parser.add_argument(
        '-r', '--rhs-col', type=int, required=True,
        help='Column index for target entity',
    )
    arg_parser.add_argument(
        '--rel-col', type=int, help='Column index for relation entity'
    )
    arg_parser.add_argument(
        '--relation-type-min-count', type=int, default=1,
        help='Min count for relation types',
    )
    arg_parser.add_argument(
        '--entity-min-count', type=int, default=1,
        help='Min count for entities',
    )
    args = arg_parser.parse_args()

    (
        entity_configs,
        relation_configs,
        entity_path,
        dynamic_relations,
    ) = validate_config(args.config)

    convert_input_data(
        entity_configs,
        relation_configs,
        entity_path,
        args.edge_paths,
        args.lhs_col,
        args.rhs_col,
        args.rel_col,
        args.entity_min_count,
        args.relation_type_min_count,
        dynamic_relations,
    )
コード例 #6
0
def main():
    """Command-line entry point: parse the config (with overrides) and run ``do_eval``."""
    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help="Path to config file")
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    args = arg_parser.parse_args()

    # Every -p occurrence contributes its own list; merge them into one stream.
    overrides = None if args.param is None else chain.from_iterable(args.param)

    do_eval(parse_config(args.config, overrides))
コード例 #7
0
def main():
    """Command-line entry point: load the config and write TSV output via ``make_tsv``.

    Both output paths are required. They are opened in exclusive-creation
    text mode, so an already existing file makes ``open`` fail rather than
    being overwritten.
    """
    schema_help = "\n".join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog="\n\nConfig parameters:\n\n" + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument("config", help="Path to config file")
    arg_parser.add_argument("-p", "--param", action="append", nargs="*")
    arg_parser.add_argument("--entities-output", required=True)
    arg_parser.add_argument("--relation-types-output", required=True)
    args = arg_parser.parse_args()

    config = ConfigFileLoader().load_config(args.config, args.param)

    with open(args.entities_output, "xt") as entities_tf:
        with open(args.relation_types_output, "xt") as relation_types_tf:
            make_tsv(config, entities_tf, relation_types_tf)
コード例 #8
0
ファイル: train.py プロジェクト: yueyedeai/PyTorch-BigGraph
def main():
    """Command-line entry point: parse the config (with overrides) and run ``train``.

    ``--rank`` is wrapped in ``Rank`` before being passed on.
    """
    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help="Path to config file")
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    arg_parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help="For multi-machine, this machine's rank",
    )
    args = arg_parser.parse_args()

    # Every -p occurrence contributes its own list; merge them into one stream.
    overrides = None if args.param is None else chain.from_iterable(args.param)
    config = parse_config(args.config, overrides)

    train(config, rank=Rank(args.rank))
コード例 #9
0
ファイル: eval.py プロジェクト: zgsxwsdxg/PyTorch-BigGraph
def main():
    """Command-line entry point: load the config and run ``do_eval``.

    The ``-p/--param`` values are passed to the loader exactly as argparse
    collected them.
    """
    setup_logging()

    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help="Path to config file")
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    args = arg_parser.parse_args()

    config_loader = ConfigFileLoader()
    config = config_loader.load_config(args.config, args.param)
    set_logging_verbosity(config.verbose)

    # Functions registered here are passed on to do_eval() via subprocess_init.
    initializer = SubprocessInitializer()
    initializer.register(setup_logging, config.verbose)
    initializer.register(add_to_sys_path, config_loader.config_dir.name)

    do_eval(config, subprocess_init=initializer)
コード例 #10
0
def main():
    """Command-line entry point: load config and dictionary, write TSV via ``make_tsv``.

    ``--dict`` names a JSON file from which the ``"entities"`` and
    ``"relations"`` entries are read. Both output paths are opened in
    exclusive-creation text mode, so an already existing file makes ``open``
    fail rather than being overwritten.
    """
    schema_help = '\n'.join(ConfigSchema.help())
    arg_parser = argparse.ArgumentParser(
        epilog='\n\nConfig parameters:\n\n' + schema_help,
        # The raw formatter keeps the epilog's line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument('config', help="Path to config file")
    arg_parser.add_argument('-p', '--param', action='append', nargs='*')
    arg_parser.add_argument('--checkpoint')
    arg_parser.add_argument('--dict', required=True)
    arg_parser.add_argument('--entities-output', required=True)
    arg_parser.add_argument('--relation-types-output', required=True)
    args = arg_parser.parse_args()

    # Every -p occurrence contributes its own list; merge them into one stream.
    overrides = None if args.param is None else chain.from_iterable(args.param)
    config = ConfigFileLoader().load_config(args.config, overrides)

    print("Loading relation types and entities...")
    with open(args.dict, "rt") as dict_tf:
        dictionary = json.load(dict_tf)

    with open(args.entities_output, "xt") as entities_tf:
        with open(args.relation_types_output, "xt") as relation_types_tf:
            make_tsv(
                config,
                args.checkpoint,
                dictionary["entities"],
                dictionary["relations"],
                entities_tf,
                relation_types_tf,
            )