Example 1
def check_all_set_all_get_func(device, init_emb):
    # Round-trip check: write the full table with all_set_embedding() and read
    # it back with all_get_embedding(); the result should match the input.
    num_embs = init_emb.shape[0]
    emb_dim = init_emb.shape[1]
    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test', device=device)
    dgl_emb.all_set_embedding(init_emb)

    out_emb = dgl_emb.all_get_embedding()
    assert F.allclose(init_emb, out_emb)
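A minimal driver sketch for the check above, not part of the original listing. It assumes the usual test imports: `import torch as th`, `from dgl.nn import NodeEmbedding`, DGL's test backend imported as `F` (providing `F.ctx()` and `F.allclose`), and, for the later examples, `from dgl.optim import SparseAdam`; the sizes are illustrative.

if __name__ == '__main__':
    init_emb = th.rand(10, 4)                      # 10 rows, embedding dim 4
    check_all_set_all_get_func(F.ctx(), init_emb)  # run on the backend's default device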
Example 2
def test_sparse_adam():
    num_embs = 10
    emb_dim = 4
    device = F.ctx()
    # DGL NodeEmbedding under test and a torch.nn.Embedding used as reference;
    # both are initialized from the same seed so their weights start identical.
    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test')
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    # Pair each embedding with its matching sparse optimizer.
    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # first step: look up the same random indices through both embeddings
    idx = th.randint(0, num_embs, size=(4, ))
    dgl_value = dgl_emb(idx, device).to(th.device('cpu'))
    torch_value = torch_emb(idx)
    labels = th.ones((4, )).long()

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    # After one sparse update, the two weight tables should still agree.
    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)
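A follow-up sketch, not part of the original test: a hypothetical helper that performs a second update on a fresh index batch, so the per-row Adam state (first and second moments) is exercised as well. It reuses the objects built inside `test_sparse_adam`.

def _second_step_check(dgl_emb, torch_emb, dgl_adam, torch_adam, num_embs, device):
    # Hypothetical helper: one more optimizer step on new indices should keep
    # the DGL and PyTorch embedding tables in agreement.
    idx = th.randint(0, num_embs, size=(4, ))
    labels = th.ones((4, )).long()
    dgl_value = dgl_emb(idx, device).to(th.device('cpu'))
    torch_value = torch_emb(idx)
    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    th.nn.functional.cross_entropy(dgl_value, labels).backward()
    th.nn.functional.cross_entropy(torch_value, labels).backward()
    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)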
Example 3
def start_sparse_adam_worker(rank,
                             device,
                             world_size,
                             weight,
                             tensor_dev='cpu',
                             has_zero_grad=False,
                             backend='gloo',
                             num_embs=128,
                             emb_dim=10):
    print('start sparse Adam worker {}'.format(rank))
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12345')

    if device.type == 'cuda':
        th.cuda.set_device(device)

    th.distributed.init_process_group(backend=backend,
                                      init_method=dist_init_method,
                                      world_size=world_size,
                                      rank=rank)

    init_weight = th.empty((num_embs, emb_dim))
    th.manual_seed(0)
    th.nn.init.uniform_(init_weight, -1.0, 1.0)
    # `initializer` is an init_func assumed to be defined elsewhere in the
    # test module.
    dgl_emb = NodeEmbedding(num_embs,
                            emb_dim,
                            'test',
                            init_func=initializer,
                            device=tensor_dev)
    dgl_emb.all_set_embedding(init_weight)

    if has_zero_grad:
        # A second embedding that never takes part in the forward pass, so the
        # optimizer must handle a parameter that receives no gradient.
        dgl_emb_zero = NodeEmbedding(num_embs,
                                     emb_dim,
                                     'zero',
                                     init_func=initializer,
                                     device=tensor_dev)
        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    else:
        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

    # Each rank draws lookup indices only from its own contiguous slice of rows.
    start = (num_embs // world_size) * rank
    end = (num_embs // world_size) * (rank + 1)
    th.manual_seed(rank)
    idx = th.randint(start, end, size=(4, )).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4, )).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_adam.zero_grad()
    dgl_loss.backward()
    dgl_adam.step()
    # Wait until every rank has applied its update before reading the table.
    th.distributed.barrier()
    dgl_weight = dgl_emb.all_get_embedding().detach()
    after_step = dgl_emb(idx, device).cpu()

    if rank == 0:
        # The looked-up rows must have changed, and rank 0 publishes the full
        # table through the shared `weight` buffer for the caller to inspect.
        dgl_value = dgl_value.detach().cpu()
        assert not F.allclose(dgl_value, after_step)
        weight[:] = dgl_weight[:]
    th.distributed.barrier()
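A possible launcher sketch for the worker above, not from the original listing: it spawns one process per rank on CPU with the default gloo backend and collects rank 0's final table in a shared tensor. The function name and sizes are illustrative.

import torch.multiprocessing as mp

def run_sparse_adam_workers(world_size=2, num_embs=128, emb_dim=10):
    # Shared buffer that rank 0 fills with the final embedding table.
    weight = th.empty((num_embs, emb_dim)).share_memory_()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=start_sparse_adam_worker,
                       args=(rank, th.device('cpu'), world_size, weight))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
    return weight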
Example 4
def start_sparse_adam_worker(rank,
                             world_size,
                             has_zero_grad=False,
                             num_embs=128,
                             emb_dim=10):
    print('start sparse Adam worker {}'.format(rank))
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12345')
    backend = 'gloo'
    device = F.ctx()

    th.distributed.init_process_group(backend=backend,
                                      init_method=dist_init_method,
                                      world_size=world_size,
                                      rank=rank)

    # DGL NodeEmbedding under test; the reference is a torch.nn.Embedding with
    # the same seeded weights, wrapped in DistributedDataParallel so its
    # gradients are synchronized across ranks.
    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test', init_func=initializer)
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, -1.0, 1.0)
    torch_emb = th.nn.parallel.DistributedDataParallel(torch_emb)

    if has_zero_grad:
        # Extra embeddings that never enter the forward pass, checking that
        # both optimizers tolerate parameters with no gradient.
        dgl_emb_zero = NodeEmbedding(num_embs,
                                     emb_dim,
                                     'zero',
                                     init_func=initializer)
        torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)
        th.manual_seed(0)
        th.nn.init.uniform_(torch_emb_zero.weight, -1.0, 1.0)
        torch_emb_zero = th.nn.parallel.DistributedDataParallel(torch_emb_zero)

        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
        torch_adam = th.optim.SparseAdam(
            list(torch_emb.module.parameters()) +
            list(torch_emb_zero.module.parameters()),
            lr=0.01)
    else:
        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
        torch_adam = th.optim.SparseAdam(list(torch_emb.module.parameters()),
                                         lr=0.01)

    # Each rank looks up indices from its own contiguous slice of rows.
    start = (num_embs // world_size) * rank
    end = (num_embs // world_size) * (rank + 1)
    idx = th.randint(start, end, size=(4, ))
    dgl_value = dgl_emb(idx, device).to(th.device('cpu'))
    torch_value = torch_emb(idx)
    labels = th.ones((4, )).long()

    dgl_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_loss.backward()
    dgl_adam.step()
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    torch_adam.zero_grad()
    torch_loss.backward()
    torch_adam.step()
    if rank == 0:
        # The DGL table should match the DDP-synchronized torch reference, and
        # the looked-up rows should have moved away from their pre-step values.
        after_step = dgl_emb(idx, device)
        assert F.allclose(dgl_emb.weight, torch_emb.module.weight)
        assert not F.allclose(dgl_value, after_step)
    th.distributed.barrier()
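For this variant, where the worker only needs `rank` and `world_size`, a launcher could instead rely on `torch.multiprocessing.spawn`, which prepends the rank to the argument tuple. Again a sketch with an illustrative function name, not part of the original listing.

import torch.multiprocessing as mp

def launch_sparse_adam_workers(world_size=2):
    # spawn() forks `world_size` workers and calls
    # start_sparse_adam_worker(rank, world_size) in each of them.
    mp.spawn(start_sparse_adam_worker,
             args=(world_size,),
             nprocs=world_size,
             join=True)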