Example #1
def prepare_data(ngraph):
    rank = ad.get_worker_communicate().rank()
    nrank = ad.get_worker_communicate().nrank()
    graphs = []
    graphsage_sample_depth = 2
    graphsage_sample_width = 2
    node_upper_bound = args.batch_size * (
        (graphsage_sample_width**(graphsage_sample_depth + 1)) - 1)
    print("Start Sampling {} graphs".format(ngraph))

    def transform(result):
        [graph, sample_mask] = result
        train_mask = np.zeros(node_upper_bound)
        train_mask[0:graph.num_nodes] = sample_mask * graph.x[:, -1]
        test_mask = np.zeros(node_upper_bound)
        test_mask[0:graph.num_nodes] = (sample_mask -
                                        graph.x[:, -1]) * sample_mask
        graph = padding(graph, node_upper_bound)
        mp_val = mp_matrix(graph, 0, system="tensorflow")
        return graph, mp_val, train_mask, test_mask

    with DistributedGraphSageSampler(args.path,
                                     args.batch_size,
                                     graphsage_sample_depth,
                                     graphsage_sample_width,
                                     rank=rank,
                                     nrank=nrank,
                                     transformer=transform,
                                     cache_size_factor=1,
                                     reduce_nonlocal_factor=0,
                                     num_sample_thread=4) as sampler:
        for i in tqdm(range(ngraph)):
            g_sample, mp_val, train_mask, test_mask = sampler.sample()
            graphs.append([g_sample, mp_val, train_mask, test_mask])
    return graphs
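A minimal standalone sketch (plain NumPy, hypothetical sizes) of the mask logic in `transform` above: the last feature column flags training nodes, so the train mask keeps sampled training nodes and the test mask keeps the remaining sampled nodes.

import numpy as np

node_upper_bound = 8                                           # padded graph size (assumed)
num_nodes = 5                                                  # nodes actually sampled
sample_mask = np.array([1, 1, 0, 1, 1], dtype=np.float32)      # nodes kept by the sampler
is_train = np.array([1, 0, 1, 1, 0], dtype=np.float32)         # last feature column: 1 = train node

train_mask = np.zeros(node_upper_bound, dtype=np.float32)
test_mask = np.zeros(node_upper_bound, dtype=np.float32)
train_mask[:num_nodes] = sample_mask * is_train                # sampled AND train
test_mask[:num_nodes] = (sample_mask - is_train) * sample_mask  # sampled AND NOT train
print(train_mask)  # [1. 0. 0. 1. 0. 0. 0. 0.]
print(test_mask)   # [0. 1. 0. 0. 1. 0. 0. 0.]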
Example #2
def train_main(args):
    with open(os.path.join(args.path, "meta.yml"), 'rb') as f:
        meta = yaml.load(f.read(), Loader=yaml.FullLoader)
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    rank = ad.get_worker_communicate().rank()
    device_id = rank % args.num_local_worker
    nrank = ad.get_worker_communicate().nrank()
    distributed.ps_init(rank, nrank)
    ngraph = meta["partition"]["nodes"][rank] // args.batch_size
    graphs = prepare_data(ngraph)
    idx, epoch, nnodes = 0, 0, 0
    worker_device = "gpu:0"
    graph_len = graphs[0][0].y.shape[0]
    with tf.device(worker_device):
        norm_adj = tf.compat.v1.sparse.placeholder(tf.float32, name="norm_adj")
        sparse_feature = tf.placeholder(tf.int32, [graph_len, meta["feature"] - 1])
        y_ = tf.placeholder(tf.int32, [graph_len], name="y_")
        train_mask = tf.placeholder(tf.float32, [graph_len], name="train_mask")
    loss, y, train_op = model(norm_adj, sparse_feature, y_, train_mask)
    init = tf.global_variables_initializer()
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    sess.run(init)
    acc_cnt, total_cnt = 0, 0
    train_acc, train_cnt = 0, 0
    start = time.time()
    while True:
        g_sample, mp_val, mask, mask_eval = graphs[idx]
        idx = (idx + 1) % ngraph
        feed_dict = {
            norm_adj : mp_val,
            sparse_feature : g_sample.x[:, 0:-1],
            y_ : g_sample.y,
            train_mask : mask
        }
        loss_val = sess.run([loss, y, train_op], feed_dict=feed_dict)
        pred_val = loss_val[1]
        acc_val = np.equal(np.argmax(pred_val, 1), g_sample.y).astype(np.float)
        acc_cnt += (acc_val * mask_eval).sum()
        total_cnt += mask_eval.sum()
        nnodes += mask.sum() + mask_eval.sum()
        train_acc += (acc_val * mask).sum()
        train_cnt += mask.sum()
        if nnodes > meta["partition"]["nodes"][rank] // 10:
            nnodes = 0
            epoch += 1
            print("Acc : ", acc_cnt / total_cnt, train_acc / train_cnt ,"Time : ", time.time() - start)
            print(pred_val)
            start = time.time()
            acc_cnt, total_cnt = 0, 0
            train_acc, train_cnt = 0, 0
            if epoch >= num_epoch:
                break
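The running accuracy above weights every node by its mask before averaging; a self-contained sketch of that computation with made-up predictions:

import numpy as np

pred_val = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # model outputs for 3 nodes
labels = np.array([1, 1, 1])                               # ground-truth classes
mask_eval = np.array([1.0, 1.0, 0.0])                      # only the first two nodes are evaluated

acc_val = np.equal(np.argmax(pred_val, 1), labels).astype(float)
print((acc_val * mask_eval).sum() / mask_eval.sum())       # 0.5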
Example #3
def train_main(args):
    autodist = AutoDist(resource_spec_file, Parallaxx())
    with open(os.path.join(args.path, "meta.yml"), 'rb') as f:
        meta = yaml.load(f.read(), Loader=yaml.FullLoader)
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    rank = ad.get_worker_communicate().rank()
    device_id = rank % args.num_local_worker
    nrank = ad.get_worker_communicate().nrank()
    distributed.ps_init(rank, nrank)
    ngraph = meta["partition"]["nodes"][rank] // args.batch_size
    graphs = prepare_data(ngraph)
    idx, epoch, nnodes = 0, 0, 0
    graph_len = graphs[0][0].y.shape[0]
    with tf.Graph().as_default() as g, autodist.scope():
        norm_adj = tf.compat.v1.sparse.placeholder(tf.float32, name="norm_adj")
        sparse_feature = tf.placeholder(tf.int32,
                                        [graph_len, meta["feature"] - 1])
        y_ = tf.placeholder(tf.int32, [graph_len], name="y_")
        train_mask = tf.placeholder(tf.float32, [graph_len], name="train_mask")
        loss, y, train_op = model(norm_adj, sparse_feature, y_, train_mask)
        sess = autodist.create_distributed_session()

        acc_stat = []
        start = time.time()
        while True:
            g_sample, mp_val, mask, mask_eval = graphs[idx]
            idx = (idx + 1) % ngraph
            feed_dict = {
                norm_adj: mp_val,
                sparse_feature: g_sample.x[:, 0:-1],
                y_: g_sample.y,
                train_mask: mask
            }
            print("Before training")
            loss_val = sess.run([loss, y, y_, train_op], feed_dict=feed_dict)
            print(loss_val)
            pred_val = loss_val[1]
            true_val = loss_val[2]
            acc_val = np.equal(np.argmax(pred_val, 1),
                               true_val).astype(np.float)
            acc_stat.append(acc_val)
            nnodes += mask.sum() + mask_eval.sum()
            if nnodes > meta["partition"]["nodes"][rank]:
                nnodes = 0
                epoch += 1
                print("Acc : ", np.mean(acc_stat), "Time : ",
                      time.time() - start)
                start = time.time()
                acc_stat = []
                if epoch >= num_epoch:
                    break
Example #4
def test():
    ctx = ndarray.cpu(0)
    rank = int(os.environ["WORKER_ID"])
    nrank = int(os.environ["DMLC_NUM_WORKER"])
    nitem = 2000
    item_len = 1000
    arr = ndarray.array(np.random.rand(nitem, item_len),
                        ctx=ctx)  # generate a long buffer

    push_indices = np.arange(nitem) * nrank + rank
    print(push_indices)
    push_length = np.repeat(item_len, repeats=nitem)
    worker_communicate = ad.get_worker_communicate()
    worker_communicate.PushData(pointer(push_indices), nitem, arr.handle,
                                pointer(push_length))
    print("Waiting")
    worker_communicate.WaitPushData(pointer(push_indices), nitem)
    worker_communicate.BarrierWorker()
    print("OK")
    arr2 = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)
    worker_communicate.PullData(pointer(push_indices), nitem, arr2.handle,
                                pointer(push_length))
    worker_communicate.WaitPullData(pointer(push_indices), nitem)
    assert np.all(arr.asnumpy() == arr2.asnumpy())
    print("Check Complete")
Example #5
    def __init__(self, limit, length, width, node_id, policy="LRU", bound=100):
        # make sure we open libps.so first
        comm = ad.get_worker_communicate()
        sys.path.append(os.path.dirname(__file__) + "/../../build/lib")
        import hetu_cache
        policy = policy.lower()
        if policy == "lru":
            self.cache = hetu_cache.LRUCache(limit, length, width, node_id)
        elif policy == "lfu":
            self.cache = hetu_cache.LFUCache(limit, length, width, node_id)
        elif policy == "lfuopt":
            self.cache = hetu_cache.LFUOptCache(limit, length, width, node_id)
        else:
            raise NotImplementedError(policy)
        self.cache.pull_bound = bound
        self.cache.push_bound = bound
        comm.BarrierWorker()
Example #6
def test(args):
    comm = ad.get_worker_communicate()
    node_id = 0
    limit = 10000
    length = 10000
    width = 128
    comm.InitTensor(ctypes.c_int(node_id), ctypes.c_int(2), ctypes.c_int(length), ctypes.c_int(width), ctypes.c_int(2), ctypes.c_double(0), ctypes.c_double(0.1), ctypes.c_ulonglong(123),\
        ctypes.c_int(0), (ctypes.c_float * 1)(0.1), ctypes.c_int(1))
    cache = CacheSparseTable(limit, length, width, node_id, "LFUOpt")
    for i in tqdm(range(10000)):
        key = np.random.randint(10000, size=1000).astype(np.uint64)
        value = np.empty((key.size, width), np.float32)
        ts = cache.embedding_lookup(key, value)
        ts.wait()
        grad = np.random.rand(key.size, width).astype(np.float32)
        ts = cache.embedding_update(key, grad)
        ts.wait()
Example #7
    def debug_keys(self):
        comm = ad.get_worker_communicate()
        nrank = comm.nrank()
        form = "w" if comm.rank() == 0 else "a"
        for i in range(nrank):
            if i == comm.rank():
                with open("_keys.log".format(comm.rank()), form) as f:
                    print(*self.keys(), file=f, flush=True)
            comm.BarrierWorker()

        if comm.rank() != 0:
            return
        keys = []
        with open("_keys.log".format(comm.rank()), "r") as f:
            for i in range(nrank):
                keys.append(set(map(int, f.readline().split())))
        rt = np.zeros([nrank, nrank])
        for i in range(nrank):
            for j in range(nrank):
                if not keys[i]:
                    continue
                rt[i][j] = len(keys[i].intersection(keys[j])) / len(keys[i])
        return rt
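The overlap matrix returned by `debug_keys` can be reproduced without the parameter server; a minimal sketch with hypothetical per-worker key sets standing in for the contents of `_keys.log`:

import numpy as np

keys = [{1, 2, 3, 4}, {3, 4, 5}, {6, 7}]  # hypothetical key sets, one per worker
nrank = len(keys)

rt = np.zeros([nrank, nrank])
for i in range(nrank):
    for j in range(nrank):
        if not keys[i]:
            continue
        # fraction of worker i's keys that also appear on worker j
        rt[i][j] = len(keys[i] & keys[j]) / len(keys[i])
print(rt)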
Example #8
def test():
    ctx = ndarray.cpu(0)
    rank = int(os.environ["WORKER_ID"])
    nrank = int(os.environ["DMLC_NUM_WORKER"])
    if rank > 0:
        return
    arr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)  # generate a long buffer

    push_indices = np.arange(nitem)
    print(push_indices)
    push_length = np.repeat(item_len, repeats=nitem)
    worker_communicate = ad.get_worker_communicate()
    query = worker_communicate.PushData(pointer(push_indices), nitem, arr.handle, pointer(push_length))
    worker_communicate.WaitData(query)
    print("data_pushed")
    t = ThreadPoolExecutor(max_workers=max_thread)
    byte_count = 0
    arr2 = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)
    def pull_data():
        query = worker_communicate.PullData(pointer(push_indices), nitem, arr2.handle, pointer(push_length))
        worker_communicate.WaitData(query)
        # print( np.all(arr.asnumpy() == arr2.asnumpy()) )
        nonlocal byte_count
        byte_count += nitem * item_len * 4
    def watch():
        nonlocal byte_count
        start = time.time()
        while True:
            time.sleep(1)
            speed = byte_count / (time.time() - start)
            print("speed : {} MB/s".format(speed / 2**20))
    task_list = [None for i in range(max_thread)]
    threading.Thread(target=watch).start()
    while True:
        for i in range(max_thread):
            if task_list[i] is None or task_list[i].done():
                task_list[i] = t.submit(pull_data)
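Note that the `watch` thread above reports average throughput since start, not an instantaneous rate. The same accounting in a self-contained form (a simulated byte counter instead of real pulls):

import threading
import time

byte_count = 0  # bytes transferred so far, updated by the worker loop

def watch():
    start = time.time()
    while True:
        time.sleep(1)
        speed = byte_count / (time.time() - start)  # average since start
        print("speed : {:.1f} MB/s".format(speed / 2**20))

threading.Thread(target=watch, daemon=True).start()
for _ in range(5):
    time.sleep(0.5)
    byte_count += 50 * 2**20  # pretend 50 MB were transferred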
Example #9
def test():
    ctx = ndarray.cpu(0)
    rank = int(os.environ["WORKER_ID"])
    nrank = int(os.environ["DMLC_NUM_WORKER"])
    arr = ndarray.array(np.random.rand(2, rank + 100), ctx=ctx)
    print(arr.asnumpy())

    push_indices = np.array([2 * rank + 1, 2 * rank + 2])

    if rank == 0:
        pull_indices = np.array([3])
    elif rank == 1:
        pull_indices = np.array([1])

    push_length = np.array([rank + 100, rank + 100])


    if rank == 0:
        pull_length = np.array([101])
        out_arr = ndarray.array(np.zeros(101), ctx=ctx)
    elif rank == 1:
        pull_length = np.array([100])
        out_arr = ndarray.array(np.zeros(100), ctx=ctx)

    print(out_arr.asnumpy())

    worker_communicate = ad.get_worker_communicate()
    query = worker_communicate.PushData(pointer(push_indices), 2, arr.handle, pointer(push_length))
    worker_communicate.WaitData(query)

    worker_communicate.BarrierWorker()
    # wait on the pull query, not the stale push query
    query = worker_communicate.PullData(pointer(pull_indices), 1, out_arr.handle, pointer(pull_length))
    worker_communicate.WaitData(query)

    print(out_arr.asnumpy())
Example #10
def train_main(args):
    with open(os.path.join(args.path, "meta.yml"), 'rb') as f:
        meta = yaml.load(f.read(), Loader=yaml.FullLoader)
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    rank = ad.get_worker_communicate().rank()
    nrank = int(os.environ["DMLC_NUM_WORKER"])
    ctx = ndarray.gpu(rank % args.num_local_worker)
    embedding_width = args.hidden_size
    extract_width = embedding_width * (meta["feature"] - 1)

    y_ = dl.GNNDataLoaderOp(lambda g: ndarray.array(
        convert_to_one_hot(g.y, max_val=g.num_classes), ctx=ndarray.cpu()))
    mask_ = ad.Variable(name="mask_")
    gcn1 = GCN(extract_width, hidden_layer_size, activation="relu")
    gcn2 = GCN(hidden_layer_size, meta["class"])
    index = dl.GNNDataLoaderOp(
        lambda g: ndarray.array(g.x[:, 0:-1], ctx=ndarray.cpu()),
        ctx=ndarray.cpu())
    embedding = initializers.random_normal([meta["idx_max"], embedding_width],
                                           stddev=0.1)
    embed = ad.embedding_lookup_op(embedding, index)
    embed = ad.array_reshape_op(embed, (-1, extract_width))
    # embed = ad.reduce_mean_op(embed, axes=1)
    # x = ad.concat_op(x_, embed, axis=1)
    x = gcn1(embed)
    y = gcn2(x)
    loss = ad.softmaxcrossentropy_op(y, y_)
    train_loss = loss * mask_
    train_loss = ad.reduce_mean_op(train_loss, [0])
    opt = optimizer.SGDOptimizer(args.learning_rate)
    train_op = opt.minimize(train_loss)
    ad.worker_init()
    distributed.ps_init(rank, nrank)

    ngraph = meta["partition"]["nodes"][rank] // args.batch_size
    graphs = prepare_data(ngraph)
    idx = 0
    g_sample, mp_val, mask, mask_eval = graphs[idx]
    idx = (idx + 1) % ngraph
    dl.GNNDataLoaderOp.step(g_sample)
    dl.GNNDataLoaderOp.step(g_sample)
    epoch = 0
    nnodes = 0
    executor = ad.Executor([loss, y, train_op],
                           ctx=ctx,
                           comm_mode='PS',
                           use_sparse_pull=False,
                           cstable_policy=args.cache)
    while True:
        g_sample_nxt, mp_val_nxt, mask_nxt, mask_eval_nxt = graphs[idx]
        idx = (idx + 1) % ngraph
        dl.GNNDataLoaderOp.step(g_sample_nxt)
        feed_dict = {gcn1.mp: mp_val, gcn2.mp: mp_val, mask_: mask}
        loss_val, y_predicted, _ = executor.run(feed_dict=feed_dict)
        y_predicted = y_predicted.asnumpy().argmax(axis=1)

        acc = np.sum((y_predicted == g_sample.y) * mask_eval)
        train_acc = np.sum((y_predicted == g_sample.y) * mask)
        stat.update(acc, mask_eval.sum(),
                    np.sum(loss_val.asnumpy() * mask_eval) / mask_eval.sum())
        stat.update_train(train_acc, mask.sum(),
                          np.sum(loss_val.asnumpy() * mask) / mask.sum())

        # distributed.ps_get_worker_communicator().BarrierWorker()
        nnodes += mask.sum() + mask_eval.sum()
        if nnodes > meta["partition"]["nodes"][rank]:
            nnodes = 0
            epoch += 1
            if rank == 0:
                stat.print(epoch)
            if epoch >= num_epoch:
                break
        g_sample, mp_val, mask, mask_eval = g_sample_nxt, mp_val_nxt, mask_nxt, mask_eval_nxt
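`convert_to_one_hot` is not shown in this snippet; a plausible NumPy implementation of the call `convert_to_one_hot(g.y, max_val=g.num_classes)` (an assumption, not necessarily the repository's code) would be:

import numpy as np

def convert_to_one_hot(labels, max_val):
    # labels: integer class ids with shape (n,); max_val: number of classes
    one_hot = np.zeros((labels.size, max_val), dtype=np.float32)
    one_hot[np.arange(labels.size), labels] = 1.0
    return one_hot

print(convert_to_one_hot(np.array([0, 2, 1]), max_val=3))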
Example #11
def test(func_name,
         nitem=2000,
         item_len=10000,
         ind_len=500,
         max_thread=10,
         ret_ans=False):
    func_name = func_name.lower()
    ctx = ndarray.cpu(0)
    rank = int(os.environ["WORKER_ID"])
    nrank = int(os.environ["DMLC_NUM_WORKER"])

    comm = ad.get_worker_communicate()
    byte_count = 0
    if func_name == 'pushnpull':
        inarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)
        outarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)

        def func(name):
            comm.Push(name, inarr.handle, None)
            comm.Pull(name, outarr.handle)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += nitem * item_len * 4 * 2
    elif func_name == 'pushpull':
        inarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)
        outarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)

        def func(name):
            comm.DDPushPull(name, inarr.handle, outarr.handle, None)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += nitem * item_len * 4 * 2
    elif func_name == 'sparsepushnpull':
        inarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)
        outarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)

        def func(name):
            np_ind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            inind = ndarray.array(np_ind.astype(np.float32), ctx=ctx)
            uni_ind_len = np.unique(np_ind).size
            comm.SparsePush(name, inind.handle, inarr.handle, None)
            comm.Pull(name, outarr.handle)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += (nitem + uni_ind_len) * item_len * 4
    elif func_name == 'sparsepushnsparsepull':
        inarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)
        outarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)

        def func(name):
            np_inind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            np_outind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            inind = ndarray.array(np_inind.astype(np.float32), ctx=ctx)
            outind = ndarray.array(np_outind.astype(np.float32), ctx=ctx)
            uni_inind_len = np.unique(np_inind).size
            uni_outind_len = np.unique(np_outind).size
            comm.SparsePush(name, inind.handle, inarr.handle, None)
            comm.SparsePull(name, outind.handle, outarr.handle)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += (uni_inind_len + uni_outind_len) * item_len * 4
    elif func_name == 'push':
        inarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)

        def func(name):
            comm.Push(name, inarr.handle, None)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += nitem * item_len * 4
    elif func_name == 'pull':
        outarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)

        def func(name):
            comm.Pull(name, outarr.handle)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += nitem * item_len * 4
    elif func_name == 'sparsepush':
        inarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)

        def func(name):
            np_inind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            inind = ndarray.array(np_inind.astype(np.float32), ctx=ctx)
            uni_inind_len = np.unique(np_inind).size
            comm.SparsePush(name, inind.handle, inarr.handle, None)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += uni_inind_len * item_len * 4
    elif func_name == 'sparsepull':
        outarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)

        def func(name):
            np_outind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            outind = ndarray.array(np_outind.astype(np.float32), ctx=ctx)
            uni_outind_len = np.unique(np_outind).size
            comm.SparsePull(name, outind.handle, outarr.handle)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += uni_outind_len * item_len * 4
    elif func_name == 'sdpushpull':
        inarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)
        outarr = ndarray.array(np.random.rand(nitem, item_len), ctx=ctx)

        def func(name):
            np_inind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            inind = ndarray.array(np_inind.astype(np.float32), ctx=ctx)
            uni_inind_len = np.unique(np_inind).size
            comm.SDPushPull(name, inind.handle, inarr.handle, outarr.handle,
                            None)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += (uni_inind_len + nitem) * item_len * 4
    elif func_name == 'sspushpull':
        inarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)
        outarr = ndarray.array(np.random.rand(ind_len, item_len), ctx=ctx)

        def func(name):
            np_inind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            np_outind = np.random.randint(low=0, high=nitem, size=(ind_len, ))
            inind = ndarray.array(np_inind.astype(np.float32), ctx=ctx)
            uni_inind_len = np.unique(np_inind).size
            outind = ndarray.array(np_outind.astype(np.float32), ctx=ctx)
            uni_outind_len = np.unique(np_outind).size
            comm.SSPushPull(name, inind.handle, inarr.handle, outind.handle,
                            outarr.handle, None)
            comm.Wait(name)
            nonlocal byte_count
            byte_count += (uni_inind_len + uni_outind_len) * item_len * 4
    else:
        assert False
    if 'sparse' in func_name or func_name in ('sdpushpull', 'sspushpull'):
        arr_len = ctypes.c_int(nitem)
        arr_wid = ctypes.c_int(item_len)
        sparse_init = ctypes.c_int(1)
    else:
        arr_len = ctypes.c_int(nitem * item_len)
        arr_wid = ctypes.c_int(1)
        sparse_init = ctypes.c_int(0)
    for i in range(max_thread):
        comm.InitTensor(i, sparse_init, arr_len, arr_wid, ctypes.c_int(0), ctypes.c_double(0), ctypes.c_double(1), ctypes.c_ulonglong(123),\
            ctypes.c_int(0), (ctypes.c_float * 1)(0.1), ctypes.c_int(1))
    # print("data init")
    t = ThreadPoolExecutor(max_workers=max_thread)
    if ret_ans:
        task_list = [None for i in range(max_thread)]
        for i in range(max_thread):
            task_list[i] = t.submit(func, i)
        curByte = byte_count
        start = time.time()
        cnt = 0
        while cnt < 30:
            for i in range(max_thread):
                if task_list[i].done():
                    cnt += 1
                    task_list[i] = t.submit(func, i)
        speed = (byte_count - curByte) / (time.time() - start) / 2**20
        t.shutdown()
        for i in range(max_thread):
            comm.ClearOnServer(i)
            comm.Clear(i)
        return speed
    else:

        def watch():
            start = time.time()
            while True:
                time.sleep(1)
                speed = byte_count / (time.time() - start)
                print("speed : {} MB/s".format(speed / 2**20))

        task_list = [None for i in range(max_thread)]
        threading.Thread(target=watch).start()
        while True:
            for i in range(max_thread):
                if task_list[i] is None or task_list[i].done():
                    task_list[i] = t.submit(func, i)
Example #12
def worker(args):
    model = args.model
    rank = ad.get_worker_communicate().rank()
    def train(iterations, auc_enabled=True):
        train_loss, train_acc, train_auc = [], [], []
        for it in range(iterations):
            loss_val, predict_y, y_val, _ = executor.run(convert_to_numpy_ret_vals=True)
            if y_val.shape[1] == 1: # for criteo case
                acc_val = np.equal(
                    y_val,
                    predict_y > 0.5).astype(np.float)
            else:
                acc_val = np.equal(
                    np.argmax(y_val, 1),
                    np.argmax(predict_y, 1)).astype(np.float)
            train_loss.append(loss_val[0])
            train_acc.append(acc_val)
            if auc_enabled:
                train_auc.append(metrics.roc_auc_score(y_val, predict_y))
            executor.ps_comm.BarrierWorker()
        if auc_enabled:
            return np.mean(train_loss), np.mean(train_acc), np.mean(train_auc)
        else:
            return np.mean(train_loss), np.mean(train_acc)
    def validate(iterations):
        test_loss, test_acc, test_auc = [], [], []
        for it in range(iterations):
            loss_val, test_y_predicted, y_test_val = val_executor.run(convert_to_numpy_ret_vals=True)
            if y_test_val.shape[1] == 1: # for criteo case
                correct_prediction = np.equal(
                    y_test_val,
                    test_y_predicted > 0.5).astype(np.float)
            else:
                correct_prediction = np.equal(
                    np.argmax(y_test_val, 1),
                    np.argmax(test_y_predicted, 1)).astype(np.float)
            test_loss.append(loss_val[0])
            test_acc.append(correct_prediction)
            test_auc.append(metrics.roc_auc_score(y_test_val, test_y_predicted))
        return np.mean(test_loss), np.mean(test_acc), np.mean(test_auc)

    if args.all:
        from models.load_data import process_all_criteo_data
        dense, sparse, labels = process_all_criteo_data(return_val=args.val)
    elif args.val:
        from models.load_data import process_head_criteo_data
        dense, sparse, labels = process_head_criteo_data(return_val=True)
    else:
        from models.load_data import process_sampled_criteo_data
        dense, sparse, labels = process_sampled_criteo_data()
    loss, prediction, y_, train_op = model(dense, sparse, labels)

    executor = ad.Executor([loss, prediction, y_, train_op], ctx=ndarray.gpu(rank),\
        dataloader_name='train', stream_mode='AllStreams', comm_mode='PS', use_sparse_pull=True, cstable_policy=args.cache, bsp=args.bsp, cache_bound=args.bound)
    if args.val:
        print('Validation enabled...')
        val_executor = ad.Executor([loss, prediction, y_], ctx=ndarray.gpu(rank),\
            dataloader_name='validate', stream_mode='AllStreams', comm_mode='PS', use_sparse_pull=True, inference=True)

    if args.all:
        raw_log_file = './logs/localps_%s' % (args.model)
        if args.bsp:
            raw_log_file += '_bsp'
        else:
            raw_log_file += '_asp'
        if args.cache:
            raw_log_file += '_%s' % (args.cache)
        raw_log_file += '_%d.log' % (rank)
        print('Processing all data, log to', raw_log_file)
        log_file = open(raw_log_file, 'w')
        total_loop = 20 * (executor.batch_num // 1000)
        for lp in range(total_loop):
            print("iters: %d" % (lp * 1000))
            train_loss, train_acc, train_auc = train(1000)
            if args.val:
                val_loss, val_acc, val_auc = validate(100)
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f"\
                        % (train_loss, train_acc, train_auc, val_loss, val_acc, val_auc)
            else:
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f"\
                        % (train_loss, train_acc, train_auc)
            print(printstr)
            log_file.write(printstr + '\n')
            if lp % 5 == 0:
                log_file.flush()
    else:
        total_epoch = 50
        for ep in range(total_epoch):
            if ep == 5:
                start = time.time()
            print("epoch %d" % ep)
            ep_st = time.time()
            train_loss, train_acc = train(executor.batch_num, auc_enabled=False)
            ep_en = time.time()
            if args.val:
                val_loss, val_acc, val_auc = validate(val_executor.batch_num)
                print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f"
                        % (train_loss, train_acc, ep_en - ep_st, val_loss, val_acc, val_auc))
            else:
                print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                        % (train_loss, train_acc, ep_en - ep_st))
        print('all time:', time.time() - start)
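For the single-output criteo case the accuracy and AUC computations in `train`/`validate` reduce to the following self-contained sketch (synthetic labels and sigmoid outputs, scikit-learn assumed available):

import numpy as np
from sklearn import metrics

y_val = np.array([[1], [0], [1], [0]], dtype=np.float32)              # true labels, shape (batch, 1)
predict_y = np.array([[0.9], [0.4], [0.3], [0.2]], dtype=np.float32)  # sigmoid outputs

acc_val = np.equal(y_val, predict_y > 0.5).astype(float)              # threshold at 0.5
auc = metrics.roc_auc_score(y_val.ravel(), predict_y.ravel())
print(acc_val.mean(), auc)  # 0.75 0.75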
Example #13
def train_main(args):
    with open(os.path.join(args.path, "meta.yml"), 'rb') as f:
        meta = yaml.load(f.read(), Loader=yaml.FullLoader)
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    rank = ad.get_worker_communicate().rank()
    device_id = rank % args.num_local_worker
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    nrank = ad.get_worker_communicate().nrank()
    distributed.ps_init(rank, nrank)
    ngraph = meta["partition"]["nodes"][rank] // args.batch_size
    graphs = prepare_data(ngraph)
    idx, epoch, nnodes = 0, 0, 0
    worker_device = "/job:worker/task:{}/gpu:0".format(rank)
    graph_len = graphs[0][0].y.shape[0]
    with tf.device(worker_device):
        norm_adj = tf.compat.v1.sparse.placeholder(tf.float32, name="norm_adj")
        sparse_feature = tf.placeholder(tf.int32,
                                        [graph_len, meta["feature"] - 1])
        y_ = tf.placeholder(tf.int32, [graph_len], name="y_")
        train_mask = tf.placeholder(tf.float32, [graph_len], name="train_mask")
    loss, y, train_op, global_step = model(norm_adj, sparse_feature, y_,
                                           train_mask, cluster, rank)

    with tf.device(
            tf.train.replica_device_setter(worker_device=worker_device,
                                           cluster=cluster)):
        server = tf.train.Server(cluster, job_name="worker", task_index=rank)
        init = tf.global_variables_initializer()
        sv = tf.train.Supervisor(is_chief=(rank == 0),
                                 init_op=init,
                                 recovery_wait_secs=1,
                                 global_step=global_step)
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps", "/job:worker/task:%d" % rank])
        sess = sv.prepare_or_wait_for_session(server.target,
                                              config=sess_config)
        sess.run(init)

    acc_stat = []
    start = time.time()
    while True:
        g_sample, mp_val, mask, mask_eval = graphs[idx]
        idx = (idx + 1) % ngraph
        feed_dict = {
            norm_adj: mp_val,
            sparse_feature: g_sample.x[:, 0:-1],
            y_: g_sample.y,
            train_mask: mask
        }
        loss_val = sess.run([loss, y, y_, train_op], feed_dict=feed_dict)
        pred_val = loss_val[1]
        true_val = loss_val[2]
        acc_val = np.equal(np.argmax(pred_val, 1), true_val).astype(np.float)
        acc_stat.append(acc_val)
        nnodes += mask.sum() + mask_eval.sum()
        if nnodes > meta["partition"]["nodes"][rank]:
            nnodes = 0
            epoch += 1
            print("Acc : ", np.mean(acc_stat), "Time : ", time.time() - start)
            start = time.time()
            acc_stat = []
            if epoch >= num_epoch:
                break
Example #14
def worker(args):
    def validate():
        hits, ndcgs = [], []
        for idx in range(testData.shape[0]):
            start_index = idx * 100
            predictions = val_executor.run(convert_to_numpy_ret_vals=True)
            map_item_score = {testItemInput[start_index + i]: predictions[0][i] for i in range(100)}
            gtItem = testItemInput[start_index]
            # Evaluate top rank list
            ranklist = heapq.nlargest(topK, map_item_score, key=map_item_score.get)
            hr = getHitRatio(ranklist, gtItem)
            ndcg = getNDCG(ranklist, gtItem)
            hits.append(hr)
            ndcgs.append(ndcg)
        hr, ndcg = np.array(hits).mean(), np.array(ndcgs).mean()
        return hr, ndcg
    def get_current_shard(data):
        if args.comm is not None:
            part_size = data.shape[0] // nrank
            start = part_size * rank
            end = start + part_size if rank != nrank - 1 else data.shape[0]
            return data[start:end]
        else:
            return data

    device_id = 0
    if args.comm == 'PS':
        rank = ad.get_worker_communicate().rank()
        nrank = int(os.environ['DMLC_NUM_WORKER'])
        device_id = rank % 8
    elif args.comm == 'Hybrid':
        comm, rank = ad.mpi_nccl_init()
        nrank = int(os.environ['DMLC_NUM_WORKER'])
        device_id = rank % 8

    from movielens import getdata
    if args.all:
        trainData, testData = getdata('ml-25m', 'datasets')
        trainUsers = get_current_shard(trainData['user_input'])
        trainItems = get_current_shard(trainData['item_input'])
        trainLabels = get_current_shard(trainData['labels'])
        testData = get_current_shard(testData)
        testUserInput = np.repeat(np.arange(testData.shape[0], dtype=np.int32), 100)
        testItemInput = testData.reshape((-1,))
    else:
        trainData, testData = getdata('ml-25m', 'datasets')
        trainUsers = get_current_shard(trainData['user_input'][:1024000])
        trainItems = get_current_shard(trainData['item_input'][:1024000])
        trainLabels = get_current_shard(trainData['labels'][:1024000])
        testData = get_current_shard(testData[:1470])
        testUserInput = np.repeat(np.arange(testData.shape[0], dtype=np.int32), 100)
        testItemInput = testData.reshape((-1,))

    num_users, num_items = {
        'ml-1m': (6040, 3706),
        'ml-20m': (138493, 26744),
        'ml-25m': (162541, 59047),
    }['ml-25m']
    # assert not args.all or num_users == testData.shape[0]
    batch_size = 1024
    num_negatives = 4
    topK = 10
    user_input = dl.dataloader_op([
        dl.Dataloader(trainUsers, batch_size, 'train'),
        dl.Dataloader(testUserInput, 100, 'validate'),
    ])
    item_input = dl.dataloader_op([
        dl.Dataloader(trainItems, batch_size, 'train'),
        dl.Dataloader(testItemInput, 100, 'validate'),
    ])
    y_ = dl.dataloader_op([
        dl.Dataloader(trainLabels.reshape((-1, 1)), batch_size, 'train'),
    ])

    loss, y, train_op = neural_mf(user_input, item_input, y_, num_users, num_items)

    executor = ad.Executor([loss, train_op], ctx=ndarray.gpu(device_id), dataloader_name='train', \
        comm_mode=args.comm, cstable_policy=args.cache, bsp=args.bsp, cache_bound=args.bound, seed=123)
    val_executor = ad.Executor([y], ctx=ndarray.gpu(device_id), inference=True, dataloader_name='validate', comm_mode=args.comm, bsp=args.bsp)

    path = 'logs/hetulog_%s' % ({None: 'local', 'PS': 'ps', 'Hybrid': 'hybrid'}[args.comm])
    path += '_%d.txt' % rank if args.comm else '.txt'
    log = Logging(path=path)
    epoch = 7
    start = time.time()
    for ep in range(epoch):
        ep_st = time.time()
        log.write('epoch %d' % ep)
        train_loss = []
        for idx in tqdm(range(executor.batch_num)):
            loss_val = executor.run(convert_to_numpy_ret_vals=True)
            train_loss.append(loss_val[0])

            # if idx % 10000 == 0:
            #     hr, ndcg = validate()
            #     printstr = "HR: %.4f, NDCF: %.4f" % (hr, ndcg)
            #     log.write(printstr)

        tra_loss = np.mean(train_loss)
        ep_en = time.time()

        # validate phase
        if args.val:
            hr, ndcg = validate()
            printstr = "train_loss: %.4f, HR: %.4f, NDCF: %.4f, train_time: %.4f" % (tra_loss, hr, ndcg, ep_en - ep_st)
        else:
            printstr = "train_loss: %.4f, train_time: %.4f" % (tra_loss, ep_en - ep_st)
        log.write(printstr)
    log.write('all time: %f' % (time.time() - start))
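`getHitRatio` and `getNDCG` are not defined in this snippet; the standard leave-one-out NCF metrics they presumably compute look like this (a sketch, not necessarily the repository's implementation):

import math

def getHitRatio(ranklist, gtItem):
    # 1 if the held-out item appears in the top-K list, else 0
    return 1 if gtItem in ranklist else 0

def getNDCG(ranklist, gtItem):
    # discounted gain of the single relevant item at its (0-based) rank i
    for i, item in enumerate(ranklist):
        if item == gtItem:
            return math.log(2) / math.log(i + 2)
    return 0

print(getHitRatio([5, 3, 9], 3), getNDCG([5, 3, 9], 3))  # 1 0.6309...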
Example #15
def test_init_ps(rarr, init_type, init_a, init_b=1.0, sparse=False):
    assert init_type in ('constant', 'uniform', 'normal', 'truncated_normal')
    init_type_map = {
        'constant': 0,
        'uniform': 1,
        'normal': 2,
        'truncated_normal': 3
    }
    ctx = ndarray.cpu(0)
    rank = int(os.environ["WORKER_ID"])
    nrank = int(os.environ["DMLC_NUM_WORKER"])
    local_arr = np.frombuffer(rarr, dtype=np.float32).reshape(nitem, item_len)
    if rank == 0:
        arr = ndarray.array(local_arr, ctx=ctx)
    else:
        arr = ndarray.empty((nitem, item_len), ctx=ctx)
    comm = ad.get_worker_communicate()
    if sparse:
        arr_len = ctypes.c_int(nitem)
        arr_wid = ctypes.c_int(item_len)
    else:
        arr_len = ctypes.c_int(nitem * item_len)
        arr_wid = ctypes.c_int(1)
    itype = ctypes.c_int(init_type_map[init_type])
    comm.InitTensor(ctypes.c_int(0), ctypes.c_int(sparse), arr_len,
                    arr_wid, itype, ctypes.c_double(init_a),
                    ctypes.c_double(init_b), ctypes.c_ulonglong(123),
                    ctypes.c_int(0), (ctypes.c_float * 1)(0.1),
                    ctypes.c_int(1))

    comm.Pull(ctypes.c_int(0), arr.handle)
    comm.Wait(ctypes.c_int(0))
    if rank == 0:
        local_arr[:] = arr.asnumpy()
    comm.BarrierWorker()
    if rank != 0:
        np.testing.assert_allclose(local_arr, arr.asnumpy(), rtol=5e-7)
    else:
        if init_type == 'constant':
            np.testing.assert_allclose(np.full((nitem, item_len), init_a),
                                       arr.asnumpy(),
                                       rtol=5e-7)
        else:
            if init_type == 'uniform':
                numpy_samples = np.random.uniform(
                    low=init_a, high=init_b,
                    size=(nitem, item_len)).astype(np.float32)
            elif init_type == 'normal':
                numpy_samples = np.random.normal(
                    loc=init_a, scale=init_b,
                    size=(nitem, item_len)).astype(np.float32)
            else:
                numpy_samples = truncnorm.rvs(-2.0,
                                              2.0,
                                              loc=init_a,
                                              scale=init_b,
                                              size=(nitem, item_len)).astype(
                                                  np.float32)
            fig, ax = plt.subplots(1, 1)
            ax.hist(numpy_samples.flatten(),
                    histtype='stepfilled',
                    alpha=0.2,
                    bins=50,
                    label='numpy')
            ax.hist(local_arr.flatten(),
                    histtype='step',
                    alpha=0.2,
                    bins=50,
                    label='ps')
            ax.legend(loc='best', frameon=False)
            # ax2.legend(loc='best', frameon=False)
            file_name = '%s_%.1f_%.1f_%d.png' % (init_type, init_a, init_b,
                                                 int(sparse))
            plt.savefig(file_name)
            print('Check file %s.' % file_name)
    print('Init parameters %d/%d passed.' % (rank, nrank))
    if rank == 0:
        comm.ClearOnServer(0)
    comm.Clear(0)
    comm.BarrierWorker()
Example #16
def test_api(rarr, rpush, rpull, sparse=False, lr=0.5):
    ctx = ndarray.cpu(0)
    rank = int(os.environ["WORKER_ID"])
    nrank = int(os.environ["DMLC_NUM_WORKER"])
    local_arr = np.frombuffer(rarr, dtype=np.float32).reshape(nitem,
                                                              item_len).copy()
    local_push = np.frombuffer(rpush, dtype=np.float32).copy()
    local_pull = np.frombuffer(rpull, dtype=np.float32).copy()
    if rank == 0:
        arr = ndarray.array(local_arr, ctx=ctx)
    else:
        arr = ndarray.empty((nitem, item_len), ctx=ctx)
    comm = ad.get_worker_communicate()
    if sparse:
        arr_len = ctypes.c_int(nitem)
        arr_wid = ctypes.c_int(item_len)
    else:
        arr_len = ctypes.c_int(nitem * item_len)
        arr_wid = ctypes.c_int(1)
    comm.InitTensor(ctypes.c_int(0), ctypes.c_int(sparse), arr_len, arr_wid, ctypes.c_int(0), ctypes.c_double(0.0), ctypes.c_double(1.0), ctypes.c_ulonglong(123),\
        ctypes.c_int(0), (ctypes.c_float * 1)(lr), ctypes.c_int(1))
    if sparse:
        local_arr[:] = 0
        for j in local_push:
            local_arr[int(j)] += 1
        if rank == 0:
            push_ind = ndarray.array(local_push.reshape(indx1, indx2), ctx=ctx)
            push_val = ndarray.array(np.ones(
                (indx1, indx2, item_len)).astype(np.float32),
                                     ctx=ctx)
            comm.SparsePush(0, push_ind.handle, push_val.handle, None)
            comm.Wait(0)
        comm.BarrierWorker()
        comm.Pull(0, arr.handle)
        comm.Wait(0)
        np.testing.assert_allclose(local_arr, arr.asnumpy(), rtol=5e-7)
        print('SparsePush DensePull %d/%d passed.' % (rank, nrank))
        comm.BarrierWorker()

        for j in local_push:
            local_arr[int(j)] += 1
        if rank == 0:
            push_ind = ndarray.array(local_push.reshape(indx1, indx2), ctx=ctx)
            push_val = ndarray.array(np.ones(
                (indx1, indx2, item_len)).astype(np.float32),
                                     ctx=ctx)
            comm.SDPushPull(0, push_ind.handle, push_val.handle, arr.handle,
                            None)
            comm.Wait(0)
        comm.BarrierWorker()
        if rank != 0:
            comm.Pull(0, arr.handle)
            comm.Wait(0)
        np.testing.assert_allclose(local_arr, arr.asnumpy(), rtol=5e-7)
        print('SDPushPull %d/%d passed.' % (rank, nrank))
        comm.BarrierWorker()

        for j in local_push:
            local_arr[int(j)] += 1
        pull_ind = ndarray.array(local_pull.reshape(indx1, indx2), ctx=ctx)
        pull_val = ndarray.empty((indx1, indx2, item_len), ctx=ctx)
        if rank == 0:
            push_ind = ndarray.array(local_push.reshape(indx1, indx2), ctx=ctx)
            push_val = ndarray.array(np.ones(
                (indx1, indx2, item_len)).astype(np.float32),
                                     ctx=ctx)
            comm.SSPushPull(0, push_ind.handle, push_val.handle, \
                        pull_ind.handle, pull_val.handle, None)
            comm.Wait(0)
        comm.BarrierWorker()
        if rank != 0:
            comm.SparsePull(0, pull_ind.handle, pull_val.handle)
            comm.Wait(0)
        np.testing.assert_allclose(local_arr[local_pull.astype(int)].reshape(
            indx1, indx2, item_len),
                                   pull_val.asnumpy(),
                                   rtol=5e-7)
        print('SSPushPull and SparsePull %d/%d passed.' % (rank, nrank))
        comm.BarrierWorker()

    else:
        if rank == 0:
            comm.Push(0, arr.handle, None)
            comm.Wait(0)
        comm.BarrierWorker()
        comm.Pull(0, arr.handle)
        comm.Wait(0)
        np.testing.assert_allclose(local_arr, arr.asnumpy(), rtol=5e-7)
        print('DensePush DensePull %d/%d passed.' % (rank, nrank))
        comm.BarrierWorker()
        if rank == 0:
            temp_push_val = ndarray.array(np.ones(
                (nitem, item_len)).astype(np.float32),
                                          ctx=ctx)
            comm.DDPushPull(0, temp_push_val.handle, arr.handle, None)
            comm.Wait(0)
        comm.BarrierWorker()
        if rank != 0:
            comm.Pull(0, arr.handle)
            comm.Wait(0)
        np.testing.assert_allclose(local_arr + 1, arr.asnumpy())
        print('DenseDensePushPull %d/%d passed.' % (rank, nrank))
        comm.BarrierWorker()
    if rank == 0:
        comm.ClearOnServer(0)
    comm.Clear(0)
    comm.BarrierWorker()
Example #17
def worker(args):
    def train(iterations, auc_enabled=True, tqdm_enabled=False):
        localiter = tqdm(
            range(iterations)) if tqdm_enabled else range(iterations)
        train_loss = []
        train_acc = []
        if auc_enabled:
            train_auc = []
        for it in localiter:
            loss_val, predict_y, y_val, _ = executor.run(
                convert_to_numpy_ret_vals=True)
            if y_val.shape[1] == 1:  # for criteo case
                acc_val = np.equal(y_val, predict_y > 0.5).astype(np.float)
            else:
                acc_val = np.equal(np.argmax(y_val, 1),
                                   np.argmax(predict_y, 1)).astype(np.float)
            train_loss.append(loss_val[0])
            train_acc.append(acc_val)
            if auc_enabled:
                train_auc.append(metrics.roc_auc_score(y_val, predict_y))
        if auc_enabled:
            return np.mean(train_loss), np.mean(train_acc), np.mean(train_auc)
        else:
            return np.mean(train_loss), np.mean(train_acc)

    def validate(iterations, tqdm_enabled=False):
        localiter = tqdm(
            range(iterations)) if tqdm_enabled else range(iterations)
        test_loss = []
        test_acc = []
        test_auc = []
        for it in localiter:
            loss_val, test_y_predicted, y_test_val = val_executor.run(
                convert_to_numpy_ret_vals=True)
            if y_test_val.shape[1] == 1:  # for criteo case
                correct_prediction = np.equal(
                    y_test_val, test_y_predicted > 0.5).astype(np.float)
            else:
                correct_prediction = np.equal(np.argmax(y_test_val, 1),
                                              np.argmax(test_y_predicted,
                                                        1)).astype(np.float)
            test_loss.append(loss_val[0])
            test_acc.append(correct_prediction)
            test_auc.append(metrics.roc_auc_score(y_test_val,
                                                  test_y_predicted))
        return np.mean(test_loss), np.mean(test_acc), np.mean(test_auc)

    def get_current_shard(data):
        if args.comm is not None:
            part_size = data.shape[0] // nrank
            start = part_size * rank
            end = start + part_size if rank != nrank - 1 else data.shape[0]
            return data[start:end]
        else:
            return data

    batch_size = 128
    dataset = args.dataset
    model = args.model
    device_id = 0

    if args.comm == 'PS':
        rank = ad.get_worker_communicate().rank()
        nrank = int(os.environ['DMLC_NUM_WORKER'])
        device_id = rank % 8
    elif args.comm == 'Hybrid':
        comm, rank = ad.mpi_nccl_init()
        nrank = int(os.environ['DMLC_NUM_WORKER'])
        device_id = rank % 8

    if dataset == 'criteo':
        # define models for criteo
        if args.all:
            from models.load_data import process_all_criteo_data
            dense, sparse, labels = process_all_criteo_data(
                return_val=args.val)
        elif args.val:
            from models.load_data import process_head_criteo_data
            dense, sparse, labels = process_head_criteo_data(return_val=True)
        else:
            from models.load_data import process_sampled_criteo_data
            dense, sparse, labels = process_sampled_criteo_data()
        if isinstance(dense, tuple):
            dense_input = dl.dataloader_op(
                [[get_current_shard(dense[0]), batch_size, 'train'],
                 [get_current_shard(dense[1]), batch_size, 'validate']])
            sparse_input = dl.dataloader_op(
                [[get_current_shard(sparse[0]), batch_size, 'train'],
                 [get_current_shard(sparse[1]), batch_size, 'validate']])
            y_ = dl.dataloader_op(
                [[get_current_shard(labels[0]), batch_size, 'train'],
                 [get_current_shard(labels[1]), batch_size, 'validate']])
        else:
            dense_input = dl.dataloader_op(
                [[get_current_shard(dense), batch_size, 'train']])
            sparse_input = dl.dataloader_op(
                [[get_current_shard(sparse), batch_size, 'train']])
            y_ = dl.dataloader_op(
                [[get_current_shard(labels), batch_size, 'train']])
    elif dataset == 'adult':
        from models.load_data import load_adult_data
        x_train_deep, x_train_wide, y_train, x_test_deep, x_test_wide, y_test = load_adult_data(
        )
        dense_input = [
            dl.dataloader_op([
                [get_current_shard(x_train_deep[:, i]), batch_size, 'train'],
                [get_current_shard(x_test_deep[:, i]), batch_size, 'validate'],
            ]) for i in range(12)
        ]
        sparse_input = dl.dataloader_op([
            [get_current_shard(x_train_wide), batch_size, 'train'],
            [get_current_shard(x_test_wide), batch_size, 'validate'],
        ])
        y_ = dl.dataloader_op([
            [get_current_shard(y_train), batch_size, 'train'],
            [get_current_shard(y_test), batch_size, 'validate'],
        ])
    else:
        raise NotImplementedError
    print("Data loaded.")

    loss, prediction, y_, train_op = model(dense_input, sparse_input, y_)

    executor = ad.Executor([loss, prediction, y_, train_op], ctx=ndarray.gpu(device_id),\
        dataloader_name='train', stream_mode='AllStreams', comm_mode=args.comm, cstable_policy=args.cache, bsp=args.bsp, cache_bound=args.bound, seed=123, log_path='./logs/')
    if args.val:
        print('Validation enabled...')
        val_executor = ad.Executor([loss, prediction, y_], ctx=ndarray.gpu(device_id),\
            dataloader_name='validate', stream_mode='AllStreams', inference=True, comm_mode=args.comm)

    if args.all and dataset == 'criteo':
        print('Processing all data...')
        file_path = '%s_%s' % ({
            None: 'local',
            'PS': 'ps',
            'Hybrid': 'hybrid'
        }[args.comm], args.raw_model)
        file_path += '%d.log' % rank if args.comm else '.log'
        file_path = os.path.join(
            os.path.split(os.path.abspath(__file__))[0], 'logs', file_path)
        log_file = open(file_path, 'w')
        total_epoch = 11
        for ep in range(total_epoch):
            print("ep: %d" % ep)
            ep_st = time.time()
            train_loss, train_acc, train_auc = train(executor.batch_num // 10 +
                                                     (ep % 10 == 9) *
                                                     (executor.batch_num % 10),
                                                     tqdm_enabled=True)
            ep_en = time.time()
            if args.val:
                val_loss, val_acc, val_auc = validate(val_executor.batch_num)
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f, train_time: %.4f"\
                        % (train_loss, train_acc, train_auc, val_loss, val_acc, val_auc, ep_en - ep_st)
            else:
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
                        % (train_loss, train_acc, train_auc, ep_en - ep_st)
            print(printstr)
            log_file.write(printstr + '\n')
            log_file.flush()
    else:
        total_epoch = 50
        for ep in range(total_epoch):
            if ep == 5:
                start = time.time()
            print("epoch %d" % ep)
            ep_st = time.time()
            train_loss, train_acc = train(executor.batch_num,
                                          auc_enabled=False)
            ep_en = time.time()
            if args.val:
                val_loss, val_acc, val_auc = validate(val_executor.batch_num)
                print(
                    "train_loss: %.4f, train_acc: %.4f, train_time: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f"
                    % (train_loss, train_acc, ep_en - ep_st, val_loss, val_acc,
                       val_auc))
            else:
                print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f" %
                      (train_loss, train_acc, ep_en - ep_st))
        print('all time:', time.time() - start)
    if args.comm == 'Hybrid':
        ad.mpi_nccl_finish(comm)