Example #1
def gat(*args, **kwargs):
    return GAT(*args, **kwargs)
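The wrapper above simply forwards its arguments to the GAT constructor. As a rough usage sketch (the constructor signature follows examples #3 and #4 on this page, and the hid_units/n_heads values are assumed GatConfig defaults, not part of the original snippet), it could be called with Cora-sized dummy inputs like this:

import numpy as np

# Dummy Cora-sized inputs; hid_units=[8] and n_heads=[8, 1] are assumptions,
# not taken from the original example.
feature = np.random.uniform(0.0, 1.0, size=(1, 2708, 1433)).astype(np.float32)
biases = np.random.uniform(0.0, 1.0, size=(1, 2708, 2708)).astype(np.float64)

net = gat(feature,            # node features
          biases,             # adjacency-derived bias matrix
          feature.shape[2],   # per-node feature size (1433)
          7,                  # number of Cora classes
          feature.shape[1],   # number of nodes (2708)
          [8],                # hid_units
          [8, 1],             # n_heads
          attn_drop=0.6,
          ftr_drop=0.6)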
Example #2
    # Cora-sized input dimensions: 2708 nodes, 1433-dim features, 7 classes.
    feature_size = [1, 2708, 1433]
    biases_size = [1, 2708, 2708]
    num_classes = 7

    hid_units = GatConfig.hid_units
    n_heads = GatConfig.n_heads

    # Placeholder inputs; only their shapes and dtypes matter for export.
    feature = np.random.uniform(0.0, 1.0, size=feature_size).astype(np.float32)
    biases = np.random.uniform(0.0, 1.0, size=biases_size).astype(np.float64)

    feature_size = feature.shape[2]
    num_nodes = feature.shape[1]

    gat_net = GAT(feature_size,
                  num_classes,
                  num_nodes,
                  hid_units,
                  n_heads,
                  attn_drop=0.0,
                  ftr_drop=0.0)

    # Inference mode: load the trained weights and run the network in fp16.
    gat_net.set_train(False)
    load_checkpoint(args.ckpt_file, net=gat_net)
    gat_net.add_flags_recursive(fp16=True)

    export(gat_net,
           Tensor(feature),
           Tensor(biases),
           file_name=args.file_name,
           file_format=args.file_format)
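The fragment above refers to an args object that is not shown. A plausible argument-parsing block, consistent with the attributes the snippet uses (ckpt_file, file_name, file_format), would look like the following; the flag names and defaults are assumptions, not taken from the original script:

import argparse

parser = argparse.ArgumentParser(description='GAT export')
parser.add_argument('--ckpt_file', type=str, required=True,
                    help='Path to the trained checkpoint')
parser.add_argument('--file_name', type=str, default='gat',
                    help='Output file name')
parser.add_argument('--file_format', type=str, default='MINDIR',
                    choices=['AIR', 'MINDIR'],
                    help='Export file format')
args = parser.parse_args()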
Example #3
File: train.py  Project: xyg320/mindspore
def train():
    """Train GAT model."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir',
                        type=str,
                        default='./data/cora/cora_mr',
                        help='Data dir')
    parser.add_argument('--train_nodes_num',
                        type=int,
                        default=140,
                        help='Nodes numbers for training')
    parser.add_argument('--eval_nodes_num',
                        type=int,
                        default=500,
                        help='Nodes numbers for evaluation')
    parser.add_argument('--test_nodes_num',
                        type=int,
                        default=1000,
                        help='Nodes numbers for test')
    args = parser.parse_args()
    if not os.path.exists("ckpts"):
        os.mkdir("ckpts")
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False)
    # train parameters
    hid_units = GatConfig.hid_units
    n_heads = GatConfig.n_heads
    early_stopping = GatConfig.early_stopping
    lr = GatConfig.lr
    l2_coeff = GatConfig.l2_coeff
    num_epochs = GatConfig.num_epochs
    feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(
        args.data_dir, args.train_nodes_num, args.eval_nodes_num,
        args.test_nodes_num)
    feature_size = feature.shape[2]
    num_nodes = feature.shape[1]
    num_class = y_train.shape[2]

    gat_net = GAT(feature,
                  biases,
                  feature_size,
                  num_class,
                  num_nodes,
                  hid_units,
                  n_heads,
                  attn_drop=GatConfig.attn_dropout,
                  ftr_drop=GatConfig.feature_dropout)
    gat_net.add_flags_recursive(fp16=True)

    eval_net = LossAccuracyWrapper(gat_net, num_class, y_val, eval_mask,
                                   l2_coeff)

    train_net = TrainGAT(gat_net, num_class, y_train, train_mask, lr, l2_coeff)

    train_net.set_train(True)
    val_acc_max = 0.0
    val_loss_min = np.inf
    for _epoch in range(num_epochs):
        train_result = train_net()
        train_loss = train_result[0].asnumpy()
        train_acc = train_result[1].asnumpy()

        eval_result = eval_net()
        eval_loss = eval_result[0].asnumpy()
        eval_acc = eval_result[1].asnumpy()

        print(
            "Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}"
            .format(_epoch, train_loss, train_acc, eval_loss, eval_acc))
        # Save a checkpoint only when both validation accuracy and loss improve;
        # any single-metric improvement still resets the early-stopping counter.
        if eval_acc >= val_acc_max or eval_loss < val_loss_min:
            if eval_acc >= val_acc_max and eval_loss < val_loss_min:
                val_acc_model = eval_acc
                val_loss_model = eval_loss
                _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
            val_acc_max = np.max((val_acc_max, eval_acc))
            val_loss_min = np.min((val_loss_min, eval_loss))
            curr_step = 0
        else:
            curr_step += 1
            if curr_step == early_stopping:
                print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".
                      format(val_loss_min, val_acc_max))
                print(
                    "Early stop model validation loss: {}, accuracy{}".format(
                        val_loss_model, val_acc_model))
                break
    gat_net_test = GAT(feature,
                       biases,
                       feature_size,
                       num_class,
                       num_nodes,
                       hid_units,
                       n_heads,
                       attn_drop=0.0,
                       ftr_drop=0.0)
    load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
    gat_net_test.add_flags_recursive(fp16=True)

    test_net = LossAccuracyWrapper(gat_net_test, num_class, y_test, test_mask,
                                   l2_coeff)
    test_result = test_net()
    print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))
Example #4
File: main.py  Project: mindspore-ai/course
def train(args_opt):
    """Train GAT model."""

    if not os.path.exists("ckpts"):
        os.mkdir("ckpts")

    # train parameters
    hid_units = GatConfig.hid_units
    n_heads = GatConfig.n_heads
    early_stopping = GatConfig.early_stopping
    lr = GatConfig.lr
    l2_coeff = GatConfig.l2_coeff
    num_epochs = GatConfig.num_epochs
    feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(
        args_opt.data_dir, args_opt.train_nodes_num, args_opt.eval_nodes_num,
        args_opt.test_nodes_num)
    feature_size = feature.shape[2]
    num_nodes = feature.shape[1]
    num_class = y_train.shape[2]

    gat_net = GAT(feature,
                  biases,
                  feature_size,
                  num_class,
                  num_nodes,
                  hid_units,
                  n_heads,
                  attn_drop=GatConfig.attn_dropout,
                  ftr_drop=GatConfig.feature_dropout)
    gat_net.add_flags_recursive(fp16=True)

    eval_net = LossAccuracyWrapper(gat_net, num_class, y_val, eval_mask,
                                   l2_coeff)

    train_net = TrainGAT(gat_net, num_class, y_train, train_mask, lr, l2_coeff)

    train_net.set_train(True)
    val_acc_max = 0.0
    val_loss_min = np.inf
    for _epoch in range(num_epochs):
        train_result = train_net()
        train_loss = train_result[0].asnumpy()
        train_acc = train_result[1].asnumpy()

        eval_result = eval_net()
        eval_loss = eval_result[0].asnumpy()
        eval_acc = eval_result[1].asnumpy()

        print(
            "Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}"
            .format(_epoch, train_loss, train_acc, eval_loss, eval_acc))
        # Save a checkpoint only when both validation accuracy and loss improve;
        # any single-metric improvement still resets the early-stopping counter.
        if eval_acc >= val_acc_max or eval_loss < val_loss_min:
            if eval_acc >= val_acc_max and eval_loss < val_loss_min:
                val_acc_model = eval_acc
                val_loss_model = eval_loss
                save_checkpoint(train_net.network, "ckpts/gat.ckpt")
            val_acc_max = np.max((val_acc_max, eval_acc))
            val_loss_min = np.min((val_loss_min, eval_loss))
            curr_step = 0
        else:
            curr_step += 1
            if curr_step == early_stopping:
                print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".
                      format(val_loss_min, val_acc_max))
                print(
                    "Early stop model validation loss: {}, accuracy{}".format(
                        val_loss_model, val_acc_model))
                break
    gat_net_test = GAT(feature,
                       biases,
                       feature_size,
                       num_class,
                       num_nodes,
                       hid_units,
                       n_heads,
                       attn_drop=0.0,
                       ftr_drop=0.0)
    load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
    gat_net_test.add_flags_recursive(fp16=True)

    test_net = LossAccuracyWrapper(gat_net_test, num_class, y_test, test_mask,
                                   l2_coeff)
    test_result = test_net()
    print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))