コード例 #1
0
ファイル: test_gcn.py プロジェクト: xjg23/mindspore
def test_gcn():
    """Train a GCN in graph mode on Ascend and assert its test accuracy.

    Runs at most ``config.epochs`` epochs. After every epoch the evaluation
    wrapper is run and its loss recorded; training stops early once the
    newest validation loss is greater than the mean of the
    ``config.early_stopping`` losses recorded just before it. The function
    finally asserts that the accuracy returned by the test wrapper is
    above 0.812.
    """
    print("test_gcn begin")
    np.random.seed(SEED)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=True)
    config = ConfigGCN()
    adj, feature, label = get_adj_features_labels(DATA_DIR)

    # get_mask receives (total, start, end); presumably it selects the
    # nodes in [start, end) -- verify against its definition.
    nodes_num = label.shape[0]
    train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
    eval_end = TRAIN_NODE_NUM + EVAL_NODE_NUM
    eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, eval_end)
    test_start = nodes_num - TEST_NODE_NUM
    test_mask = get_mask(nodes_num, test_start, nodes_num)

    class_num = label.shape[1]
    gcn_net = GCN(config, adj, feature, class_num)
    gcn_net.add_flags_recursive(fp16=True)

    # All three wrappers share the same underlying network instance.
    weight_decay = config.weight_decay
    eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, weight_decay)
    test_net = LossAccuracyWrapper(gcn_net, label, test_mask, weight_decay)
    train_net = TrainNetWrapper(gcn_net, label, train_mask, config)

    five_dp = "{:.5f}".format
    val_losses = []
    for epoch in range(config.epochs):
        started = time.time()

        train_net.set_train()
        train_out = train_net()
        train_loss = train_out[0].asnumpy()
        train_acc = train_out[1].asnumpy()

        eval_net.set_train(False)
        eval_out = eval_net()
        val_loss = eval_out[0].asnumpy()
        val_acc = eval_out[1].asnumpy()

        val_losses.append(val_loss)
        report = [
            "Epoch:", '%04d' % (epoch + 1),
            "train_loss=", five_dp(train_loss),
            "train_acc=", five_dp(train_acc),
            "val_loss=", five_dp(val_loss),
            "val_acc=", five_dp(val_acc),
            "time=", five_dp(time.time() - started),
        ]
        print(*report)

        # Early stopping is only considered after a warm-up of
        # ``early_stopping`` epochs, so the comparison window is full.
        patience = config.early_stopping
        if epoch <= patience:
            continue
        if val_losses[-1] > np.mean(val_losses[-(patience + 1):-1]):
            print("Early stopping...")
            break

    test_net.set_train(False)
    test_out = test_net()
    test_loss = test_out[0].asnumpy()
    test_accuracy = test_out[1].asnumpy()
    print("Test set results:", "loss=", five_dp(test_loss),
          "accuracy=", five_dp(test_accuracy))
    assert test_accuracy > 0.812
コード例 #2
0
ファイル: train.py プロジェクト: xjg23/mindspore
def train():
    """Train a GCN using options parsed from the command line.

    Parses the dataset directory, random seed and the train/eval/test node
    counts, trains for at most ``config.epochs`` epochs with per-epoch
    evaluation and early stopping (newest validation loss greater than the
    mean of the previous ``config.early_stopping`` losses), then prints the
    loss, accuracy and elapsed time of a final run on the test mask.
    """
    arg_parser = argparse.ArgumentParser(description='GCN')
    arg_parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
    arg_parser.add_argument('--seed', type=int, default=123, help='Random seed')
    # The three split sizes share the same option shape, so register them
    # from a table (order kept so the generated help is unchanged).
    node_count_options = (
        ('--train_nodes_num', 140, 'Nodes numbers for training'),
        ('--eval_nodes_num', 500, 'Nodes numbers for evaluation'),
        ('--test_nodes_num', 1000, 'Nodes numbers for test'),
    )
    for flag, default_value, help_text in node_count_options:
        arg_parser.add_argument(flag, type=int, default=default_value, help=help_text)
    opts = arg_parser.parse_args()

    np.random.seed(opts.seed)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=False)
    config = ConfigGCN()
    adj, feature, label = get_adj_features_labels(opts.data_dir)

    # get_mask receives (total, start, end); presumably it selects the
    # nodes in [start, end) -- verify against its definition.
    nodes_num = label.shape[0]
    train_mask = get_mask(nodes_num, 0, opts.train_nodes_num)
    eval_end = opts.train_nodes_num + opts.eval_nodes_num
    eval_mask = get_mask(nodes_num, opts.train_nodes_num, eval_end)
    test_start = nodes_num - opts.test_nodes_num
    test_mask = get_mask(nodes_num, test_start, nodes_num)

    class_num = label.shape[1]
    gcn_net = GCN(config, adj, feature, class_num)
    gcn_net.add_flags_recursive(fp16=True)

    # All three wrappers share the same underlying network instance.
    weight_decay = config.weight_decay
    eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, weight_decay)
    test_net = LossAccuracyWrapper(gcn_net, label, test_mask, weight_decay)
    train_net = TrainNetWrapper(gcn_net, label, train_mask, config)

    five_dp = "{:.5f}".format
    val_losses = []
    for epoch in range(config.epochs):
        started = time.time()

        train_net.set_train()
        train_out = train_net()
        train_loss = train_out[0].asnumpy()
        train_acc = train_out[1].asnumpy()

        eval_net.set_train(False)
        eval_out = eval_net()
        val_loss = eval_out[0].asnumpy()
        val_acc = eval_out[1].asnumpy()

        val_losses.append(val_loss)
        report = [
            "Epoch:", '%04d' % (epoch + 1),
            "train_loss=", five_dp(train_loss),
            "train_acc=", five_dp(train_acc),
            "val_loss=", five_dp(val_loss),
            "val_acc=", five_dp(val_acc),
            "time=", five_dp(time.time() - started),
        ]
        print(*report)

        # Early stopping is only considered after a warm-up of
        # ``early_stopping`` epochs, so the comparison window is full.
        patience = config.early_stopping
        if epoch <= patience:
            continue
        if val_losses[-1] > np.mean(val_losses[-(patience + 1):-1]):
            print("Early stopping...")
            break

    test_started = time.time()
    test_net.set_train(False)
    test_out = test_net()
    test_loss = test_out[0].asnumpy()
    test_accuracy = test_out[1].asnumpy()
    print("Test set results:", "loss=", five_dp(test_loss),
          "accuracy=", five_dp(test_accuracy),
          "time=", five_dp(time.time() - test_started))
コード例 #3
0
ファイル: main.py プロジェクト: shirley18411/course
def train(args_opt):
    """Train a GCN with already-parsed options and report test metrics.

    Args:
        args_opt: Options object; this function reads ``seed``,
            ``data_dir``, ``train_nodes_num``, ``eval_nodes_num`` and
            ``test_nodes_num`` from it.

    Trains for at most ``config.epochs`` epochs, logging metrics every
    tenth epoch, and stops early when the newest validation loss is greater
    than the mean of the previous ``config.early_stopping`` losses. The
    loss, accuracy and elapsed time of a final run on the test mask are
    printed at the end.
    """
    np.random.seed(args_opt.seed)
    config = ConfigGCN()
    adj, feature, label = get_adj_features_labels(args_opt.data_dir)

    # get_mask receives (total, start, end); presumably it selects the
    # nodes in [start, end) -- verify against its definition.
    nodes_num = label.shape[0]
    train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
    eval_end = args_opt.train_nodes_num + args_opt.eval_nodes_num
    eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, eval_end)
    test_start = nodes_num - args_opt.test_nodes_num
    test_mask = get_mask(nodes_num, test_start, nodes_num)

    class_num = label.shape[1]
    gcn_net = GCN(config, adj, feature, class_num)
    gcn_net.add_flags_recursive(fp16=True)

    # All three wrappers share the same underlying network instance.
    weight_decay = config.weight_decay
    eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, weight_decay)
    test_net = LossAccuracyWrapper(gcn_net, label, test_mask, weight_decay)
    train_net = TrainNetWrapper(gcn_net, label, train_mask, config)

    five_dp = "{:.5f}".format
    val_losses = []
    for epoch in range(config.epochs):
        started = time.time()

        train_net.set_train()
        train_out = train_net()
        train_loss = train_out[0].asnumpy()
        train_acc = train_out[1].asnumpy()

        eval_net.set_train(False)
        eval_out = eval_net()
        val_loss = eval_out[0].asnumpy()
        val_acc = eval_out[1].asnumpy()

        val_losses.append(val_loss)
        # Only every tenth epoch is logged; the epoch number shown is the
        # zero-based loop index.
        if epoch % 10 == 0:
            report = [
                "Epoch:", '%04d' % (epoch),
                "train_loss=", five_dp(train_loss),
                "train_acc=", five_dp(train_acc),
                "val_loss=", five_dp(val_loss),
                "val_acc=", five_dp(val_acc),
                "time=", five_dp(time.time() - started),
            ]
            print(*report)

        # Early stopping is only considered after a warm-up of
        # ``early_stopping`` epochs, so the comparison window is full.
        patience = config.early_stopping
        if epoch <= patience:
            continue
        if val_losses[-1] > np.mean(val_losses[-(patience + 1):-1]):
            print("Early stopping...")
            break

    test_started = time.time()
    test_net.set_train(False)
    test_out = test_net()
    test_loss = test_out[0].asnumpy()
    test_accuracy = test_out[1].asnumpy()
    print("Test set results:", "loss=", five_dp(test_loss),
          "accuracy=", five_dp(test_accuracy), "time=",
          five_dp(time.time() - test_started))
コード例 #4
0
ファイル: train.py プロジェクト: zhangjinrong/mindspore
def train():
    """Train a GCN, checkpoint it, and evaluate a reloaded copy.

    Command-line options select the dataset directory, random seed, the
    train/eval/test node counts and whether to record a t-SNE animation
    (``--save_TSNE``). Training runs for at most ``config.epochs`` epochs
    and stops early when the newest validation loss is greater than the
    mean of the previous ``config.early_stopping`` losses. The trained
    network is saved to ``ckpts/gcn.ckpt``, loaded into a fresh ``GCN``
    instance, and that instance is evaluated on the test mask.

    When ``--save_TSNE`` is set, a 2-D t-SNE of the network output is
    captured before training and after every epoch, and the snapshots are
    written to ``t-SNE_visualization.gif``.
    """
    parser = argparse.ArgumentParser(description='GCN')
    parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
    parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
    parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
    parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph')
    args_opt = parser.parse_args()
    # exist_ok avoids failing when the directory already exists or is
    # created by someone else between a check and the creation.
    os.makedirs("ckpts", exist_ok=True)

    set_seed(args_opt.seed)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=False)
    config = ConfigGCN()
    adj, feature, label_onehot, label = get_adj_features_labels(args_opt.data_dir)

    # get_mask receives (total, start, end); presumably it selects the
    # nodes in [start, end) -- verify against its definition.
    nodes_num = label_onehot.shape[0]
    train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
    eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num)
    test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)

    class_num = label_onehot.shape[1]
    gcn_net = GCN(config, adj, feature, class_num)
    gcn_net.add_flags_recursive(fp16=True)

    eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
    train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)

    loss_list = []

    if args_opt.save_TSNE:
        # Snapshot of the untrained network: this is frame 0 of the animation.
        out_feature = gcn_net()
        tsne_result = t_SNE(out_feature.asnumpy(), 2)
        graph_data = [tsne_result]
        fig = plt.figure()
        # Points are coloured by the (non one-hot) label values.
        scat = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], s=2, c=label, cmap='rainbow')
        plt.title('t-SNE visualization of Epoch:0', fontsize='large', fontweight='bold', verticalalignment='center')

    for epoch in range(config.epochs):
        t = time.time()

        train_net.set_train()
        train_result = train_net()
        train_loss = train_result[0].asnumpy()
        train_accuracy = train_result[1].asnumpy()

        eval_net.set_train(False)
        eval_result = eval_net()
        eval_loss = eval_result[0].asnumpy()
        eval_accuracy = eval_result[1].asnumpy()

        loss_list.append(eval_loss)
        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
              "train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
              "val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))

        if args_opt.save_TSNE:
            # One more snapshot per completed epoch.
            out_feature = gcn_net()
            tsne_result = t_SNE(out_feature.asnumpy(), 2)
            graph_data.append(tsne_result)

        # Early stopping is only considered after a warm-up of
        # ``early_stopping`` epochs, so the comparison window is full.
        if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
            print("Early stopping...")
            break

    # Round-trip through a checkpoint so the test run uses a freshly built
    # network carrying the saved parameters.
    save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
    gcn_net_test = GCN(config, adj, feature, class_num)
    load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
    gcn_net_test.add_flags_recursive(fp16=True)

    test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
    t_test = time.time()
    test_net.set_train(False)
    test_result = test_net()
    test_loss = test_result[0].asnumpy()
    test_accuracy = test_result[1].asnumpy()
    print("Test set results:", "loss=", "{:.5f}".format(test_loss),
          "accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))

    if args_opt.save_TSNE:
        # Use one frame per stored snapshot. Early stopping leaves fewer than
        # config.epochs + 1 entries in graph_data, and update_graph
        # presumably indexes graph_data by frame number (confirm against its
        # definition), so a fixed frame count could run past the end.
        ani = animation.FuncAnimation(fig, update_graph, frames=range(len(graph_data)),
                                      fargs=(graph_data, scat, plt))
        ani.save('t-SNE_visualization.gif', writer='imagemagick')