Example #1
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import graphmix

max_thread = 10  # number of concurrent pull threads; tune as needed

def test(args):
    comm = graphmix.Client()
    rank = comm.rank()
    nrank = comm.num_worker()
    if rank != 0:
        return  # only worker 0 runs the benchmark
    pool = ThreadPoolExecutor(max_workers=max_thread)
    item_count = 0

    def pull_data():
        # Repeatedly pull 1000 random nodes and count items received.
        # The unsynchronized counter is only a rough throughput gauge.
        nonlocal item_count
        while True:
            indices = np.random.randint(0, comm.meta["node"], 1000)
            query = comm.pull_node(indices)
            comm.wait(query)
            item_count += len(indices)

    def watch():
        # Report pull throughput once per second.
        nonlocal item_count
        start = time.time()
        while True:
            time.sleep(1)
            speed = item_count / (time.time() - start)
            print("speed : {} item/s".format(speed))

    threading.Thread(target=watch, daemon=True).start()
    task_list = [pool.submit(pull_data) for i in range(max_thread)]
    time.sleep(1000)  # keep the main thread alive while the workers run
Example #2
import numpy as np
import yaml
import graphmix

def worker_main(args):
    from graphmix.utils import powerset
    driver = PytorchTrain(args)  # training driver; see Examples #5 and #6 for its methods
    comm = graphmix.Client()
    mapping = {
        "G": graphmix.sampler.GraphSage,
        "R": graphmix.sampler.RandomWalk,
        "L": graphmix.sampler.LocalNode,
    }
    tests = powerset("GRL")
    train_dict = {}
    if comm.rank() == 0:
        log_file = open("log.txt", "w")
    for test in tests:
        test = "".join(test)
        samplers = list(test)
        eval_accs, test_accs, epochs = [], [], []
        for i in range(args.rerun):
            eval_acc, test_acc, epoch = driver.train_once(samplers)
            eval_accs.append(eval_acc)
            test_accs.append(test_acc)
            epochs.append(epoch)
            if comm.rank() == 0:
                # Record each run's training curve, keyed by sampler combination and run index.
                train_dict["{}{}".format(test, i)] = driver.train_info
        if comm.rank() == 0:
            printstr = "\t{:.3f}+-{:.4f}\t{:.3f}+-{:.4f}\t{:.1f}+-{:.3f}".format(
                np.mean(eval_accs), np.std(eval_accs), np.mean(test_accs),
                np.std(test_accs), np.mean(epochs), np.std(epochs))
            print(test, printstr)
            print(test, printstr, file=log_file, flush=True)
    if comm.rank() == 0:
        with open("train_data.yml", "w") as f:
            f.write(yaml.dump(train_dict))
Example #3
import random

import numpy as np
import graphmix

def test(args):
    cora_dataset = graphmix.dataset.load_dataset("Cora")
    comm = graphmix.Client()
    rank = comm.rank()
    nrank = comm.num_worker()
    # Set of directed edges in the full graph, for exact membership tests
    # (a plain `in` on a 2-D numpy array does elementwise matching, not row matching).
    all_edge = set(map(tuple, np.array(cora_dataset.graph.edge_index).T))

    def check(graph, index):
        if graph.type == graphmix.sampler.GraphSage:
            assert np.all(graph.extra[:, 0] <= graph.i_feat[:, -1])
        for f, i in zip(graph.f_feat, graph.i_feat):
            idx = i[-2]
            assert np.all(f == cora_dataset.x[idx])
            assert i[0] == cora_dataset.y[idx]
        for u, v in zip(graph.edge_index[0], graph.edge_index[1]):
            assert (index[u], index[v]) in all_edge

    samplers = [0, 1, 2, 3]
    for i in range(20):
        random.shuffle(samplers)
        query = comm.pull_graph(*samplers)
        graph = comm.wait(query)
        graph.convert2coo()
        # Column -2 holds each node's original index (see the per-row check above).
        index = graph.i_feat[:, -2]
        check(graph, index)
    print("CHECK OK")
Example #4
import graphmix

def test(args):
    cora_dataset = graphmix.dataset.load_dataset("Cora")
    comm = graphmix.Client()
    repeats = 1000
    # Issue all queries asynchronously up front, then drain the replies:
    # pull_graph() returns immediately and wait() blocks until the reply arrives.
    queries = [comm.pull_graph() for i in range(repeats)]
    for query in queries:
        graph = comm.wait(query)
    print("CHECK OK")
Example #5
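Judging by the driver.train_once(...) call in Example #2, Examples #5 and #6 are methods of the PytorchTrain driver class: this __init__ joins the NCCL process group and loads the evaluation dataset on rank 0, and train_once (Example #6) runs one full training trial with a given sampler combination.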
    def __init__(self, args):
        # Assumes module-level imports: graphmix, torch,
        # torch.distributed as dist, and graphmix.dataset.load_dataset.
        comm = graphmix.Client()
        self.meta = comm.meta
        # One NCCL process per graphmix worker.
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=comm.num_worker(),
                                rank=comm.rank())
        self.device = args.local_rank
        self.eval_dataset = []
        torch.cuda.set_device(self.device)
        if dist.get_rank() == 0:
            # Only rank 0 keeps the full dataset in memory, for evaluation.
            self.dataset = load_dataset(self.meta["name"])
Example #6
    def train_once(self, samplers):
        # Assumes module-level imports (time, random, torch, torch.nn as nn,
        # torch.distributed as dist, graphmix) plus Net, torch_sync_data,
        # and args defined elsewhere in the project.
        self.train_info = []
        meta = self.meta
        device = self.device
        model = Net(meta["float_feature"], meta["class"],
                    args.hidden).cuda(device)
        DDPmodel = nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True)
        optimizer = torch.optim.Adam(DDPmodel.parameters(), args.lr)
        num_nodes, num_epoch = 0, 0
        comm = graphmix.Client()
        query = comm.pull_graph()  # prefetch the first batch
        start = time.time()
        best_result = 0
        test_acc_result = 0
        converge_epoch = 0

        while True:
            sampler = random.choice(samplers)
            graph = comm.wait(query)
            graph.add_self_loop()
            # Issue the next asynchronous query so sampling overlaps training.
            query = comm.pull_graph(sampler)
            x = torch.Tensor(graph.f_feat).to(device)
            y = torch.Tensor(graph.i_feat).to(device, torch.long)
            if graph.type == graphmix.sampler.GraphSage:
                # GraphSage batches carry the train mask in the extra field.
                train_mask = torch.Tensor(graph.extra[:, 0]).to(
                    device, torch.long)
            else:
                # Otherwise the last integer-feature column is the train mask.
                train_mask = y[:, -1] == 1
            out = DDPmodel(x, graph)
            label = y[:, :-1]
            loss = model.loss(out, label, train_mask)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Count an epoch once the workers together have seen train_node nodes.
            total = train_mask.sum()
            total = torch_sync_data(total)
            num_nodes += total
            if num_nodes >= meta["train_node"]:
                num_epoch += 1
                num_nodes = 0
                if num_epoch == args.num_epoch:
                    break
                if dist.get_rank() == 0:
                    eval_acc, test_acc = self.eval_data(model)
                    self.train_info.append((eval_acc, test_acc))
                    if eval_acc > best_result:
                        best_result = eval_acc
                        test_acc_result = test_acc
                        converge_epoch = num_epoch
        return best_result, test_acc_result, converge_epoch
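Both this example and Example #7 call torch_sync_data, which is defined elsewhere in the project. Below is a minimal sketch of such a helper, assuming it simply sum-reduces each scalar across workers (Example #7 divides the synced wait time by dist.get_world_size(), which fits a sum-reduction):

import torch
import torch.distributed as dist

def torch_sync_data(*values):
    # Hypothetical helper: sum each scalar metric across all workers.
    results = []
    for value in values:
        tensor = torch.tensor(float(value)).cuda()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        results.append(tensor.item())
    return results if len(results) > 1 else results[0]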
Example #7
import time

import numpy as np
import torch
import torch.distributed as dist
import graphmix

def worker_main(args):
    comm = graphmix.Client()
    meta = comm.meta
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=comm.num_worker(),
        rank=comm.rank()
    )
    device = args.local_rank
    torch.cuda.set_device(device)
    # Net and torch_sync_data are defined elsewhere in the project.
    model = Net(meta["float_feature"], meta["class"], args.hidden).cuda(device)
    DDPmodel = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[device], find_unused_parameters=True)
    optimizer = torch.optim.Adam(DDPmodel.parameters(), 1e-3)

    query = comm.pull_graph()  # prefetch the first batch
    batch_num = meta["node"] // (args.batch_size * dist.get_world_size())
    times = []
    for epoch in range(args.num_epoch):
        epoch_start_time = time.time()
        count, total, wait_time = 0, 0, 0
        for i in range(batch_num):
            wait_start = time.time()
            graph = comm.wait(query)
            wait_time += time.time() - wait_start
            graph.add_self_loop()
            # Issue the next asynchronous query so sampling overlaps training.
            query = comm.pull_graph()
            x = torch.Tensor(graph.f_feat).to(device)
            y = torch.Tensor(graph.i_feat).to(device, torch.long)
            if graph.type == graphmix.sampler.GraphSage:
                train_mask = torch.Tensor(graph.extra[:, 0]).to(device, torch.long)
            else:
                train_mask = y[:, -1] == 1
            label = y[:, :-1]
            out = DDPmodel(x, graph)
            loss = model.loss(out, label, train_mask)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total += train_mask.sum()
            acc = model.metrics(out, label, train_mask)
            count += int(train_mask.sum() * acc)
        epoch_end_time = time.time()
        times.append(epoch_end_time - epoch_start_time)
        # Aggregate accuracy counts and wait time across workers.
        count, total, wait_time = torch_sync_data(count, total, wait_time)
        if args.local_rank == 0:
            print("epoch {} time {:.3f} acc={:.3f}".format(
                epoch, np.array(times).mean(), count / total))
            print("wait time total : {:.3f}sec".format(wait_time / dist.get_world_size()))
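Net is also defined outside these snippets. Below is a rough, hypothetical stand-in, assuming graph.edge_index is a 2 x E COO edge list after add_self_loop(), that the class label is the first integer-feature column (Examples #3 and #9 assert i[0] == dataset.y[idx]), and that loss and metrics consider only the masked training nodes:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Hypothetical stand-in for the Net model used in Examples #6 and #7.
    def __init__(self, in_feat, num_class, hidden):
        super().__init__()
        self.lin1 = nn.Linear(in_feat, hidden)
        self.lin2 = nn.Linear(hidden, num_class)

    def forward(self, x, graph):
        # Mean-aggregate transformed neighbor features over the subgraph.
        src, dst = graph.edge_index
        src = torch.as_tensor(src, device=x.device, dtype=torch.long)
        dst = torch.as_tensor(dst, device=x.device, dtype=torch.long)
        h = F.relu(self.lin1(x))
        agg = torch.zeros_like(h).index_add_(0, dst, h[src])
        deg = torch.zeros(h.size(0), device=x.device).index_add_(
            0, dst, torch.ones_like(dst, dtype=h.dtype)).clamp(min=1)
        return self.lin2(agg / deg.unsqueeze(-1))

    def loss(self, out, label, mask):
        # Cross-entropy over training nodes; label column 0 holds the class.
        mask = mask.bool()
        return F.cross_entropy(out[mask], label[mask, 0])

    def metrics(self, out, label, mask):
        # Accuracy over training nodes.
        mask = mask.bool()
        pred = out[mask].argmax(dim=-1)
        return (pred == label[mask, 0]).float().mean().item()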
Example #8
File: comm.py, Project: nox-410/GraphMix
import time

import graphmix

def test(args):
    comm = graphmix.Client()
    rank = comm.rank()
    nrank = comm.num_worker()
    num_node = comm.meta["node"]
    num_batch = num_node // args.batch_size
    print("{} batch per epoch".format(num_batch))
    for epoch in range(20):
        for i in range(num_batch):
            query = comm.pull_graph()
            graph = comm.wait(query)
        if epoch == 10:
            # Mid-run synchronization across all clients and servers.
            time.sleep(1)
            comm.barrier_all()
    time.sleep(1)
    comm.barrier_all()
Example #9
import numpy as np
import graphmix

def test(args):
    comm = graphmix.Client()
    rank = comm.rank()
    nrank = comm.num_worker()
    query = comm.pull_graph()
    graph = comm.wait(query)
    graph.convert2coo()
    cora_dataset = graphmix.dataset.load_dataset("Cora")
    # Column -2 of i_feat holds each node's original index in the dataset.
    index = graph.i_feat[:, -2]
    for f, i in zip(graph.f_feat, graph.i_feat):
        idx = i[-2]
        assert np.all(f == cora_dataset.x[idx])
        assert i[0] == cora_dataset.y[idx]
    # Set of directed edges for exact membership tests
    # (a plain `in` on a 2-D numpy array does elementwise matching, not row matching).
    all_edge = set(map(tuple, np.array(cora_dataset.graph.edge_index).T))
    for u, v in zip(graph.edge_index[0], graph.edge_index[1]):
        assert (index[u], index[v]) in all_edge
    print("CHECK OK")
Example #10
import numpy as np
import graphmix

def test(args):
    comm = graphmix.Client()
    rank = comm.rank()
    nrank = comm.num_worker()
    dataset = graphmix.dataset.load_dataset("Cora")
    num_nodes = dataset.graph.num_nodes
    # Pull every node in the graph in a single query.
    query = comm.pull_node(np.arange(num_nodes))
    pack = comm.wait(query)
    assert len(pack) == num_nodes
    reindex = {}
    for i, node in pack.items():
        idx = node.i[-2]  # original dataset index
        reindex[idx] = i
        assert np.all(dataset.x[idx] == node.f)
        assert dataset.y[idx] == node.i[0]
    # Every dataset edge must appear in the pulled adjacency lists.
    for u, v in zip(dataset.graph.edge_index[0], dataset.graph.edge_index[1]):
        assert reindex[v] in pack[reindex[u]].e

    print("Check OK")
Example #11
import threading
import time

import graphmix

max_thread = 10  # number of concurrent pull threads; tune as needed

def test(args):
    comm = graphmix.Client()
    rank = comm.rank()
    nrank = comm.num_worker()
    item_count = 0

    def pull_graph():
        # Repeatedly pull sampled subgraphs and count nodes received.
        # The unsynchronized counter is only a rough throughput gauge.
        nonlocal item_count
        while True:
            query = comm.pull_graph()
            graph = comm.wait(query)
            item_count += graph.num_nodes

    def watch():
        # Report pull throughput once per second.
        nonlocal item_count
        start = time.time()
        while True:
            time.sleep(1)
            speed = item_count / (time.time() - start)
            print("speed : {} item/s".format(speed))

    threading.Thread(target=watch, daemon=True).start()
    for i in range(max_thread):
        threading.Thread(target=pull_graph, daemon=True).start()
    time.sleep(1000)  # keep the main thread alive while the workers run
Example #12
import graphmix

def arrive_and_leave():
    # Connect a client to each of four consecutive server ports; each
    # client is dropped as soon as it goes out of scope.
    for i in range(4):
        comm = graphmix.Client(graphmix.default_server_port + i)
Example #13
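This test pairs with Example #12: arrive_and_leave() briefly connects a client to each of four consecutive server ports and drops it, after which a fresh client on the default port should still be able to pull a graph.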
def test(args):
    arrive_and_leave()
    comm = graphmix.Client(graphmix.default_server_port)
    query = comm.pull_graph()
    graph = comm.wait(query)
    print("CHECK OK")