Example #1
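The snippet below refers to an initializer helper and a test backend F without defining them. A minimal setup sketch for both examples, assuming DGL's PyTorch backend (dgl.nn.NodeEmbedding, dgl.optim.SparseAdam); the backend test shim and the body of initializer are assumptions, not part of the original examples:

import torch as th
from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdam
import backend as F  # assumed: DGL test-suite backend shim providing allclose

def initializer(emb):
    # Assumed helper: deterministic uniform init matching the reference weight below.
    th.manual_seed(0)
    emb.uniform_(-1.0, 1.0)
    return emb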
def start_sparse_adam_worker(rank,
                             device,
                             world_size,
                             weight,
                             tensor_dev='cpu',
                             has_zero_grad=False,
                             backend='gloo',
                             num_embs=128,
                             emb_dim=10):
    print('start sparse worker for adam {}'.format(rank))
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12345')

    # Pin the CUDA device (if any), then join the distributed process group.
    if device.type == 'cuda':
        th.cuda.set_device(device)

    th.distributed.init_process_group(backend=backend,
                                      init_method=dist_init_method,
                                      world_size=world_size,
                                      rank=rank)

    # Build a deterministic reference weight and load it into the shared embedding.
    init_weight = th.empty((num_embs, emb_dim))
    th.manual_seed(0)
    th.nn.init.uniform_(init_weight, -1.0, 1.0)
    dgl_emb = NodeEmbedding(num_embs,
                            emb_dim,
                            'test',
                            init_func=initializer,
                            device=tensor_dev)
    dgl_emb.all_set_embedding(init_weight)

    # Optionally register a second embedding that never appears in the forward
    # pass, so SparseAdam must tolerate a parameter with no gradient.
    if has_zero_grad:
        dgl_emb_zero = NodeEmbedding(num_embs,
                                     emb_dim,
                                     'zero',
                                     init_func=initializer,
                                     device=tensor_dev)
        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    else:
        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

    # Each rank draws node ids from its own contiguous slice of the table,
    # runs one forward/backward pass, and applies a sparse Adam step.
    start = (num_embs // world_size) * rank
    end = (num_embs // world_size) * (rank + 1)
    th.manual_seed(rank)
    idx = th.randint(start, end, size=(4, )).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4, )).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_adam.zero_grad()
    dgl_loss.backward()
    dgl_adam.step()
    th.distributed.barrier()
    dgl_weight = dgl_emb.all_get_embedding().detach()
    after_step = dgl_emb(idx, device).cpu()

    if rank == 0:
        # The optimizer step must have changed the looked-up rows; rank 0
        # publishes the updated table through the shared weight buffer.
        dgl_value = dgl_value.detach().cpu()
        assert not F.allclose(dgl_value, after_step)
        weight[:] = dgl_weight[:]
    th.distributed.barrier()
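For context, a worker like this is launched once per process. A minimal single-machine launch sketch, assuming a CPU/gloo run; the process count and the shared weight buffer are illustrative choices, not part of the original example:

import torch as th
import torch.multiprocessing as mp

if __name__ == '__main__':
    world_size = 2
    num_embs, emb_dim = 128, 10
    # Shared buffer that rank 0 fills with the updated embedding table.
    weight = th.empty((num_embs, emb_dim)).share_memory_()
    # mp.spawn passes the rank as the first argument to the worker.
    mp.spawn(start_sparse_adam_worker,
             args=(th.device('cpu'), world_size, weight),
             nprocs=world_size)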
Example #2
def check_all_set_all_get_func(device, init_emb):
    num_embs = init_emb.shape[0]
    emb_dim = init_emb.shape[1]
    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test', device=device)
    dgl_emb.all_set_embedding(init_emb)

    # Round trip: the embedding table we set must come back unchanged.
    out_emb = dgl_emb.all_get_embedding()
    assert F.allclose(init_emb, out_emb)
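A minimal call sketch for this checker, assuming a single-process run (some DGL versions expect a torch.distributed process group to be initialized before NodeEmbedding is created); the shapes are illustrative assumptions:

init_emb = th.rand(100, 10)
check_all_set_all_get_func(th.device('cpu'), init_emb)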