Esempio n. 1
0
def get_task_model_args(task, model=None):
    """Build the default argument namespace for ``task`` (and optionally ``model``).

    A synthetic command line (``-t <task> -m gcn -dt cora``) is installed in
    ``sys.argv`` while the training parser is built and parsed. That way
    ``parser.parse_args()`` yields the defaults of every option registered by
    the task and, when given, the model. ``gcn`` / ``cora`` are placeholders;
    NOTE(review): presumably ``-m``/``-dt`` are required by the training
    parser, so some value must be supplied — confirm.

    The caller's original ``sys.argv`` is restored before returning, so the
    real command line is not clobbered for later parses in the process.

    Args:
        task: key into ``TASK_REGISTRY``; its ``add_args`` is always called.
        model: optional key into ``MODEL_REGISTRY``. Its ``add_args`` is called
            only when ``try_import_model(model)`` is truthy.

    Returns:
        The parsed namespace, with ``args.task`` set to ``task`` and, when
        ``model`` is given, ``args.model`` overwritten with ``model``.
    """
    saved_argv = sys.argv
    sys.argv = [sys.argv[0], "-t", task, "-m", "gcn", "-dt", "cora"]
    try:
        parser = get_training_parser()
        TASK_REGISTRY[task].add_args(parser)
        if model is not None and try_import_model(model):
            MODEL_REGISTRY[model].add_args(parser)
        args = parser.parse_args()
    finally:
        # Undo the global mutation even if building/parsing fails.
        sys.argv = saved_argv
    args.task = task
    if model is not None:
        # Replace the "gcn" placeholder with the model actually requested.
        args.model = model
    return args
Esempio n. 2
0
def parse_args_and_arch(parser, args):
    """Register component-specific options on ``parser``, then parse again.

    The first parse only knows the generic options. Once the task, models,
    datasets and trainer named in ``args`` are known, each may contribute its
    own options through ``add_args``. The command line is then parsed a second
    time so those options are recognised.

    Args:
        parser: the parser used for the first parse; options are added in place.
        args: namespace from the first parse. ``args.task`` is read, and
            ``args.model`` and ``args.dataset`` are iterated. ``args.trainer``
            is read only if present.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    TASK_REGISTRY[args.task].add_args(parser)

    for model_name in args.model:
        if not try_import_model(model_name):
            continue
        MODEL_REGISTRY[model_name].add_args(parser)

    for dataset_name in args.dataset:
        if not try_import_dataset(dataset_name):
            continue
        dataset_cls = DATASET_REGISTRY[dataset_name]
        # Datasets are not required to expose extra options.
        if hasattr(dataset_cls, "add_args"):
            dataset_cls.add_args(parser)

    trainer_name = args.trainer if "trainer" in args else None
    if trainer_name is not None and try_import_trainer(trainer_name):
        trainer_cls = TRAINER_REGISTRY[trainer_name]
        if hasattr(trainer_cls, "add_args"):
            trainer_cls.add_args(parser)

    return parser.parse_args()
Esempio n. 3
0
def parse_args_and_arch(parser, args):
    """Register component-specific options on ``parser``, then parse again.

    The first parse only knows the generic options. Once the task, models,
    datasets and trainer named in ``args`` are known, each may contribute its
    own options through ``add_args``. The command line is then parsed a second
    time so those options are recognised.

    When no trainer is given explicitly, each successfully imported model is
    asked for its trainer via ``get_trainer(None, None)``. Each distinct trainer
    returned has its options registered once.

    Args:
        parser: the parser used for the first parse; options are added in place.
        args: namespace from the first parse. ``args.task`` is read, and
            ``args.model`` and ``args.dataset`` are iterated. ``args.trainer``
            is read only if present.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    try_import_task(args.task)
    TASK_REGISTRY[args.task].add_args(parser)

    # Track models whose import succeeded. The trainer fallback below must
    # apply the same guard as this loop instead of indexing MODEL_REGISTRY
    # with a name that failed to import (which could raise KeyError).
    imported_models = []
    for model in args.model:
        if try_import_model(model):
            MODEL_REGISTRY[model].add_args(parser)
            imported_models.append(model)

    for dataset in args.dataset:
        if try_import_dataset(dataset):
            if hasattr(DATASET_REGISTRY[dataset], "add_args"):
                DATASET_REGISTRY[dataset].add_args(parser)

    if "trainer" in args and args.trainer is not None:
        if try_import_trainer(args.trainer):
            if hasattr(TRAINER_REGISTRY[args.trainer], "add_args"):
                TRAINER_REGISTRY[args.trainer].add_args(parser)
    else:
        # Several models may share one trainer. Adding its options twice would
        # make argparse raise on conflicting option strings, so register each
        # trainer only once.
        registered_trainers = set()
        for model in imported_models:
            tr = MODEL_REGISTRY[model].get_trainer(None, None)
            if tr is None or tr in registered_trainers:
                continue
            registered_trainers.add(tr)
            # Same optional-add_args guard as the explicit-trainer branch.
            if hasattr(tr, "add_args"):
                tr.add_args(parser)

    # Parse a second time, now that every component's options are registered.
    return parser.parse_args()