Example #1: multiprocess entry point that pins each worker to a GPU via `args.pid_to_cuda`
def main(args):
    if torch.cuda.is_available() and not args.cpu:
        pid = mp.current_process().pid
        torch.cuda.set_device(args.pid_to_cuda[pid])

    set_random_seed(args.seed)

    task = build_task(args)
    result = task.train()
    return result
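
Example #1 assumes the caller has already populated `args.pid_to_cuda` with a mapping from worker process ids to GPU indices. A minimal launcher sketch follows; `launch`, `_worker`, and the managed `Event` synchronization are illustrative additions, not part of the original snippet, and it assumes `main` from Example #1 is in scope.

import torch.multiprocessing as mp

def _worker(args, ready):
    ready.wait()  # block until the parent has filled args.pid_to_cuda
    main(args)    # `main` from Example #1

def launch(args, device_ids):
    manager = mp.Manager()
    args.pid_to_cuda = manager.dict()  # shared dict, visible to the children
    ready = manager.Event()
    procs = [mp.Process(target=_worker, args=(args, ready)) for _ in device_ids]
    for p in procs:
        p.start()
    for p, dev in zip(procs, device_ids):
        args.pid_to_cuda[p.pid] = dev  # pin each child pid to one GPU
    ready.set()  # children may now read their device assignment
    for p in procs:
        p.join()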
Example #2: single-process training entry with explicit device selection
def train(args):
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id[0])

    set_random_seed(args.seed)

    task = build_task(args)
    result = task.train()

    return result
Example #3: multi-seed runs for unsupervised node classification
def run(dataset_name):
    args = build_default_args_for_unsupervised_node_classification(dataset_name)
    args = DATASET_REGISTRY[dataset_name](args)
    dataset = build_dataset(args)
    results = []
    for seed in args.seed:
        set_random_seed(seed)
        task = build_task(args, dataset=dataset)
        result = task.train()
        results.append(result)
    return results
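
The per-seed `run` functions in Examples #3 and #4 return a list of result objects; Example #9 below treats such results as dicts keyed by metric names like "Acc" and "ValAcc". Assuming that shape, averaging over seeds is a short dict comprehension; "cora" is only an illustrative dataset name:

results = run("cora")
# Average each metric across seeds; assumes all result dicts share the same keys.
avg = {key: sum(r[key] for r in results) / len(results) for key in results[0]}
print(avg)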
Example #4: multi-seed runs for multiplex link prediction
def run(dataset_name):
    args = build_default_args_for_multiplex_link_prediction(dataset_name)
    args = DATASET_REGISTRY[dataset_name](args)
    dataset = build_dataset(args)
    results = []
    for seed in args.seed:
        set_random_seed(seed)
        task = build_task(args, dataset=dataset)
        result = task.train()
        results.append(result)
    return results
Example #5: repeating training over `N_SEED` seeds from inside a class
    def run_n_seed(self, args):
        result_list = []
        for seed in range(N_SEED):
            set_random_seed(seed)

            model = build_model(args)
            task = build_task(args, model=model, dataset=self.dataset)

            result = task.train()
            result_list.append(result)
        return result_list
Example #6: training entry that can optionally load a tuned configuration
def train(args):
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id[0])

    set_random_seed(args.seed)

    if getattr(args, "use_best_config", False):
        args = set_best_config(args)

    print(args)
    task = build_task(args)
    result = task.train()

    return result
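
`set_best_config` is not defined in these snippets. One plausible shape, sketched with a hypothetical `BEST_CONFIGS` table, is a per-dataset dictionary of tuned hyperparameters copied onto `args`:

# Hypothetical table: tuned hyperparameters keyed by dataset name.
BEST_CONFIGS = {
    "cora": {"lr": 0.01, "hidden_size": 64, "dropout": 0.5},
    "citeseer": {"lr": 0.01, "hidden_size": 64, "dropout": 0.5},
}

def set_best_config(args):
    # Overwrite args in place with the tuned values, if the dataset is known.
    for key, value in BEST_CONFIGS.get(args.dataset, {}).items():
        setattr(args, key, value)
    return args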
Example #7: node classification with configurable `missing_rate` and `num_layers`
def run(dataset_name, missing_rate=0, num_layers=40):
    args = build_default_args_for_node_classification(
        dataset_name, missing_rate=missing_rate, num_layers=num_layers)
    args = DATASET_REGISTRY[dataset_name](args)
    dataset, args = build_dataset(args)
    args.num_features = dataset.num_features
    args.num_classes = dataset.num_classes
    results = []
    for seed in args.seed:
        set_random_seed(seed)
        task = build_task(args, dataset=dataset)
        result = task.train()
        results.append(result)
    return results
Example #8: heterogeneous node classification, copying dataset statistics onto `args`
def run(dataset_name):
    args = build_default_args_for_heterogeneous_node_classification(dataset_name)
    args = DATASET_REGISTRY[dataset_name](args)
    dataset = build_dataset(args)
    args.num_features = dataset.num_features
    args.num_classes = dataset.num_classes
    args.num_edge = dataset.num_edge
    args.num_nodes = dataset.num_nodes
    results = []
    for seed in args.seed:
        set_random_seed(seed)
        task = build_task(args, dataset=dataset)
        result = task.train()
        results.append(result)
    return results
Example #9: grid search over hyperparameter combinations, averaged across random seeds
def train():
    args = build_default_args()
    dataset, args = get_dataset(args)

    combinations = get_parameters()
    best_parameters = None
    best_result = None
    best_val_acc = 0

    print(f"===== Start At: {get_time()} ===========")
    start = time.time()

    random_seeds = list(range(5))
    for item in combinations:
        for key, val in item.items():
            setattr(args, key, val)

        print(f"### -- Parameters: {args.__dict__}")
        result_list = []
        for seed in random_seeds:
            set_random_seed(seed)

            task = build_task(args, dataset=dataset)
            res = task.train()
            result_list.append(res)

        val_acc = [x["ValAcc"] for x in result_list]
        test_acc = [x["Acc"] for x in result_list]
        val_acc = sum(val_acc) / len(val_acc)
        print(f"###    Result: {val_acc}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc  # track the best score, or every positive val_acc overwrites the "best"
            best_parameters = copy.deepcopy(args)
            best_result = dict(Acc=sum(test_acc) / len(test_acc),
                               ValAcc=val_acc)
    print(f"Best Parameters: {best_parameters}")
    print(f"Best result: {best_result}")

    end = time.time()
    print(f"===== End At: {get_time()} ===========")
    print("Time cost:", end - start)
Example #10: full pipeline wiring dataset, data/model wrappers, and a `Trainer`
def train(args):  # noqa: C901
    if isinstance(args.dataset, list):
        args.dataset = args.dataset[0]
    if isinstance(args.model, list):
        args.model = args.model[0]
    if isinstance(args.seed, list):
        args.seed = args.seed[0]
    if isinstance(args.split, list):
        args.split = args.split[0]
    set_random_seed(args.seed)

    model_name = args.model if isinstance(args.model, str) else args.model.model_name
    dw_name = args.dw if isinstance(args.dw, str) else args.dw.__name__
    mw_name = args.mw if isinstance(args.mw, str) else args.mw.__name__

    print(f""" 
|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|
    *** Running (`{args.dataset}`, `{model_name}`, `{dw_name}`, `{mw_name}`)
|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|"""
          )

    if getattr(args, "use_best_config", False):
        args = set_best_config(args)

    # setup dataset and specify `num_features` and `num_classes` for model
    if isinstance(args.dataset, Dataset):
        dataset = args.dataset
    else:
        dataset = build_dataset(args)

    mw_class = fetch_model_wrapper(args.mw)
    dw_class = fetch_data_wrapper(args.dw)

    if mw_class is None:
        raise NotImplementedError("`model wrapper(--mw)` must be specified.")

    if dw_class is None:
        raise NotImplementedError("`data wrapper(--dw)` must be specified.")

    data_wrapper_args = dict()
    model_wrapper_args = dict()
    # workaround: forward any matching fields from `args` to the data wrapper
    for key in inspect.signature(dw_class).parameters.keys():
        if hasattr(args, key) and key != "dataset":
            data_wrapper_args[key] = getattr(args, key)
    # workaround: forward any matching fields from `args` to the model wrapper
    for key in inspect.signature(mw_class).parameters.keys():
        if hasattr(args, key) and key != "model":
            model_wrapper_args[key] = getattr(args, key)

    # setup data_wrapper
    dataset_wrapper = dw_class(dataset, **data_wrapper_args)

    args.num_features = dataset.num_features
    if hasattr(dataset, "num_nodes"):
        args.num_nodes = dataset.num_nodes
    if hasattr(dataset, "num_edges"):
        args.num_edges = dataset.num_edges
    if hasattr(dataset, "num_edge"):
        args.num_edge = dataset.num_edge
    if hasattr(dataset, "max_graph_size"):
        args.max_graph_size = dataset.max_graph_size
    if hasattr(dataset, "edge_attr_size"):
        args.edge_attr_size = dataset.edge_attr_size
    else:
        args.edge_attr_size = [0]
    if hasattr(args, "unsup") and args.unsup:
        args.num_classes = args.hidden_size
    else:
        args.num_classes = dataset.num_classes
    if hasattr(dataset.data, "edge_attr") and dataset.data.edge_attr is not None:
        args.num_entities = len(torch.unique(torch.stack(dataset.data.edge_index)))
        args.num_rels = len(torch.unique(dataset.data.edge_attr))

    # setup model
    if isinstance(args.model, nn.Module):
        model = args.model
    else:
        model = build_model(args)
    # specify configs for optimizer
    optimizer_cfg = dict(
        lr=args.lr,
        weight_decay=args.weight_decay,
        n_warmup_steps=args.n_warmup_steps,
        epochs=args.epochs,
        batch_size=args.batch_size if hasattr(args, "batch_size") else 0,
    )

    if hasattr(args, "hidden_size"):
        optimizer_cfg["hidden_size"] = args.hidden_size

    # setup model_wrapper
    if isinstance(args.mw, str) and "embedding" in args.mw:
        model_wrapper = mw_class(model, **model_wrapper_args)
    else:
        model_wrapper = mw_class(model, optimizer_cfg, **model_wrapper_args)

    os.makedirs("./checkpoints", exist_ok=True)

    # setup controller
    trainer = Trainer(
        epochs=args.epochs,
        device_ids=args.devices,
        cpu=args.cpu,
        save_emb_path=args.save_emb_path,
        load_emb_path=args.load_emb_path,
        cpu_inference=args.cpu_inference,
        progress_bar=args.progress_bar,
        distributed_training=args.distributed,
        checkpoint_path=args.checkpoint_path,
        resume_training=args.resume_training,
        patience=args.patience,
        eval_step=args.eval_step,
        logger=args.logger,
        log_path=args.log_path,
        project=args.project,
        no_test=args.no_test,
        nstage=args.nstage,
        actnn=args.actnn,
        fp16=args.fp16,
    )

    # Go!!!
    result = trainer.run(model_wrapper, dataset_wrapper)

    return result
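
Example #10 reads a large set of attributes off `args`. The sketch below collects every attribute the function touches into one `argparse.Namespace`, so the expected interface is visible at a glance; all values are illustrative defaults, and the `dw`/`mw` wrapper names are assumptions rather than guaranteed registry entries:

from argparse import Namespace

args = Namespace(
    dataset="cora", model="gcn",
    dw="node_classification_dw", mw="node_classification_mw",  # assumed names
    seed=1, split=0, use_best_config=False, unsup=False,
    lr=0.01, weight_decay=5e-4, n_warmup_steps=0, epochs=500, batch_size=0,
    hidden_size=64, patience=100, eval_step=1,
    devices=[0], cpu=False, cpu_inference=False, distributed=False,
    fp16=False, actnn=False, nstage=1,
    save_emb_path=None, load_emb_path=None,
    checkpoint_path="./checkpoints/model.pt", resume_training=False,
    progress_bar="epoch", logger=None, log_path="./logs",
    project="run", no_test=False,
)
result = train(args)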