Ejemplo n.º 1
0
def test_call_many(hidden_dim=16):
    """Call a batch of remote experts through _RemoteCallMany, including a dead
    endpoint and an empty row, and verify the alive-mask, the stacked outputs
    and the input gradients against independent per-expert reference calls."""
    atol = 1e-5
    # _RemoteCallMany configuration: minimal k, no timeouts, strict outputs
    k_min = backward_k_min = 1
    timeout_after_k_min = forward_timeout = backward_timeout = None
    detect_anomalies = allow_zero_outputs = False

    with background_server(num_experts=5,
                           device='cpu',
                           expert_cls='ffn',
                           num_handlers=1,
                           hidden_dim=hidden_dim,
                           optim_cls=None,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        inputs = torch.randn(4, hidden_dim, requires_grad=True)
        # independent leaf copy for the reference computation
        inputs_ref = inputs.clone().detach().requires_grad_(True)
        e0, e1, e2, e3, e4 = (hivemind.RemoteExpert(f'expert.{i}', server_endpoint)
                              for i in range(5))
        # expert pointing at a closed port: every call to it is expected to fail
        e5 = hivemind.RemoteExpert('thisshouldnotexist', '127.0.0.1:80')

        mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min,
            backward_k_min, timeout_after_k_min, forward_timeout,
            backward_timeout, detect_anomalies, allow_zero_outputs, e1.info,
            inputs)
        assert mask.shape == (4, 3)
        assert expert_outputs.shape == (4, 3, hidden_dim)

        # row 3 is empty and e5 is unreachable, so those slots must be dead
        expected_mask = np.array([[True, True, True],
                                  [True, True, False],
                                  [True, False, True],
                                  [False, False, False]])
        assert np.all(mask.data.numpy() == expected_mask), f"Incorrect mask, {mask}"

        # reference: call each successful (row, slot, expert) combination directly
        reference_outputs = torch.zeros_like(expert_outputs)
        successful_calls = [(0, 0, e0), (0, 1, e1), (0, 2, e2),
                            (1, 0, e2), (1, 1, e4),
                            (2, 0, e1), (2, 2, e3)]
        for row, slot, expert in successful_calls:
            reference_outputs[row, slot] = expert(inputs_ref[row:row + 1])

        assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)

        # backprop a random projection of a subset of live slots through both graphs
        proj = torch.randn(4, hidden_dim)
        rows, slots = (0, 1, 1, 2), (0, 2, 1, 0)
        (expert_outputs[rows, slots] * proj).sum().backward()
        our_grad = inputs.grad.data.cpu().clone()

        (reference_outputs[rows, slots] * proj).sum().backward()
        reference_grad = inputs_ref.grad.data.cpu().clone()
        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
Ejemplo n.º 2
0
def test_remote_module_call(hidden_dim=16):
    """Forward and backward through a single RemoteExpert; verify output shapes,
    consistency across overlapping batches, gradient flow, and that malformed
    requests raise grpc.RpcError."""
    with background_server(num_experts=1,
                           device='cpu',
                           expert_cls='ffn',
                           num_handlers=1,
                           hidden_dim=hidden_dim,
                           optim_cls=None,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
        fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)

        single_out = real_expert(torch.randn(1, hidden_dim))
        assert single_out.shape == (1, hidden_dim)

        dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
        batch_out = real_expert(dummy_x)
        assert batch_out.shape == (3, hidden_dim)

        # the last two rows, sent on their own, must reproduce the batch result
        tail_out = real_expert(dummy_x[1:])
        assert torch.allclose(tail_out, batch_out[1:], atol=1e-5, rtol=0)
        tail_out.norm().backward()
        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0

        # wrong feature dimension and a nonexistent expert uid must both fail
        with pytest.raises(grpc.RpcError):
            real_expert(torch.randn(3, 11))
        with pytest.raises(grpc.RpcError):
            fake_expert(dummy_x)
Ejemplo n.º 3
0
def test_compute_expert_scores():
    """Check that RemoteMixtureOfExperts.compute_expert_scores returns, for each
    expert, the sum of its per-dimension grid scores, and that gradients flow
    back to both grid-score tensors."""
    # create the DHT before `try` so `finally` never references an unbound name
    # if the constructor itself raises
    dht = hivemind.DHT(start=True)
    try:
        moe = hivemind.client.moe.RemoteMixtureOfExperts(dht=dht,
                                                         in_features=16,
                                                         grid_size=(40, ),
                                                         k_best=4,
                                                         k_min=1,
                                                         timeout_after_k_min=1,
                                                         uid_prefix='expert.')
        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(
            4, 3, requires_grad=True)
        # per-sample expert coordinates along each grid dimension
        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
        batch_experts = [[
            hivemind.RemoteExpert(
                uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}',
                endpoint="[::]:1337") for expert_i in range(len(ii[batch_i]))
        ] for batch_i in range(
            len(ii)
        )]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
        logits = moe.compute_expert_scores([gx, gy], batch_experts)
        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
        # fix: the second operand previously lacked `> 0` and passed by float
        # truthiness alone; make both checks explicit and symmetric
        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, \
            "compute_expert_scores didn't backprop"

        for batch_i in range(len(ii)):
            for expert_i in range(len(ii[batch_i])):
                expected_score = (gx[batch_i, ii[batch_i][expert_i]] +
                                  gy[batch_i, jj[batch_i][expert_i]])
                assert torch.allclose(logits[batch_i, expert_i], expected_score), \
                    "compute_expert_scores returned incorrect score"
    finally:
        dht.shutdown()
Ejemplo n.º 4
0
def test_determinism(hidden_dim=16):
    """Run the same (input, mask) pair twice through a deterministic-dropout
    expert and require identical outputs and identical input gradients."""
    atol = 1e-5

    xx = torch.randn(32, hidden_dim, requires_grad=True)
    # fix: torch.randint's upper bound is exclusive, so randint(0, 1, ...)
    # always produced an all-zero mask; use 2 to get an actual random binary
    # mask (presumably the original intent — the expert_cls is 'det_dropout')
    mask = torch.randint(0, 2, (32, hidden_dim))

    with background_server(num_experts=1,
                           device='cpu',
                           expert_cls='det_dropout',
                           num_handlers=1,
                           hidden_dim=hidden_dim,
                           optim_cls=None,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        expert = hivemind.RemoteExpert(uid=f'expert.0',
                                       endpoint=server_endpoint)

        out = expert(xx, mask)
        out_rerun = expert(xx, mask)

        # retain_graph so both gradient computations can reuse the same graphs
        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
        grad_rerun, = torch.autograd.grad(out_rerun.sum(),
                                          xx,
                                          retain_graph=True)

    assert torch.allclose(
        out, out_rerun, atol=atol,
        rtol=0), "Dropout layer outputs are non-deterministic."
    assert torch.allclose(grad, grad_rerun, atol=atol,
                          rtol=0), "Gradients are non-deterministic."
Ejemplo n.º 5
0
def client_process(can_start,
                   benchmarking_failed,
                   port,
                   num_experts,
                   batch_size,
                   hid_dim,
                   num_batches,
                   backprop=True):
    """Benchmark worker: wait on `can_start`, then run `num_batches` forward
    (and, if `backprop`, backward) passes against randomly chosen remote
    experts at localhost:`port`. Sets the `benchmarking_failed` event before
    propagating any error so the parent can abort cleanly."""
    torch.set_num_threads(1)  # one compute thread per client for a fair benchmark
    can_start.wait()
    experts = [
        hivemind.RemoteExpert(f"expert{i}",
                              endpoint=f"{hivemind.LOCALHOST}:{port}")
        for i in range(num_experts)
    ]

    try:
        dummy_batch = torch.randn(batch_size, hid_dim)
        for batch_i in range(num_batches):
            expert = random.choice(experts)
            out = expert(dummy_batch)
            if backprop:
                out.sum().backward()
    except BaseException:
        # deliberately broad: even KeyboardInterrupt/SystemExit must flag the
        # run as failed. Fix: bare `raise` re-raises with the original
        # traceback instead of `raise e`, which appends an extra frame.
        benchmarking_failed.set()
        raise
Ejemplo n.º 6
0
def test_moe_beam_search():
    """Compare RemoteMixtureOfExperts.beam_search against exhaustive scoring of
    every declared expert on a 3-dimensional uid grid."""
    all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}'
                       for i in range(10)
                       for j in range(10)
                       for k in range(10)]
    dht = hivemind.DHT(start=True, expiration=999)
    assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))

    # NOTE(review): uid_prefix is 'ffn' here but 'ffn.' in
    # test_beam_search_correctness — confirm which form this hivemind
    # version expects
    dmoe = hivemind.RemoteMixtureOfExperts(in_features=32,
                                           grid_size=(32, 32, 32),
                                           dht=dht,
                                           k_best=4,
                                           uid_prefix='ffn')

    for _ in range(25):
        sample = torch.randn(32)
        grid_scores = dmoe.proj(sample).split_with_sizes(dmoe.grid_size, dim=-1)

        chosen_experts = dmoe.loop.run_until_complete(
            dmoe.beam_search(grid_scores, k_best=dmoe.k_best))

        chosen_scores = dmoe.compute_expert_scores(
            [dim_scores[None] for dim_scores in grid_scores],
            [chosen_experts])[0]

        # reference: score every declared expert and take the global top-k
        every_expert = [hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]
        all_scores = dmoe.compute_expert_scores(
            [dim_scores[None] for dim_scores in grid_scores],
            [every_expert])[0]
        true_best_scores = sorted(all_scores.cpu().detach().numpy(),
                                  reverse=True)[:len(chosen_experts)]
        our_best_scores = list(chosen_scores.cpu().detach().numpy())
        assert np.allclose(true_best_scores, our_best_scores)
Ejemplo n.º 7
0
def test_beam_search_correctness():
    """Check that DHT.find_best_experts agrees with an exhaustive search over
    all declared experts on a 3-dimensional uid grid."""
    all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}'
                       for i in range(10)
                       for j in range(10)
                       for k in range(10)]
    dht = hivemind.DHT(start=True, expiration=999)
    assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))

    dmoe = hivemind.RemoteMixtureOfExperts(in_features=32,
                                           grid_size=(32, 32, 32),
                                           dht=dht,
                                           k_best=4,
                                           uid_prefix='ffn.')

    for _ in range(25):
        sample = torch.randn(32)
        grid_scores = dmoe.proj(sample).split_with_sizes(dmoe.grid_size, dim=-1)

        chosen_experts = dht.find_best_experts(
            dmoe.uid_prefix,
            [tensor.detach().numpy() for tensor in grid_scores],
            beam_size=dmoe.k_best)
        chosen_scores = dmoe.compute_expert_scores(
            [dim_scores[None] for dim_scores in grid_scores],
            [chosen_experts])[0]
        our_best_scores = list(chosen_scores.cpu().detach().numpy())

        # reference: independently find :beam_size: best experts with exhaustive search
        every_expert = [hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]
        all_scores = dmoe.compute_expert_scores(
            [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
            [every_expert])[0]
        true_best_scores = sorted(all_scores.cpu().detach().numpy(),
                                  reverse=True)[:len(chosen_experts)]

        assert np.allclose(true_best_scores, our_best_scores)
Ejemplo n.º 8
0
def test_remote_module_call():
    """ Check that remote_module_call returns correct outputs and gradients if called directly """
    # NOTE(review): this function shares its name with an earlier example in
    # this collection; if both were placed in one module, one would shadow
    # the other under pytest collection
    num_experts = 8
    # _RemoteMoECall configuration: minimal k, no timeouts
    k_min = backward_k_min = 1
    timeout_after_k_min = timeout_total = backward_timeout = None
    rtol, atol = 1e-3, 1e-6

    xx = torch.randn(32, 1024, requires_grad=True)
    logits = torch.randn(3, requires_grad=True)
    random_proj = torch.randn_like(xx)

    with background_server(num_experts=num_experts,
                           device='cpu',
                           no_optimizer=True,
                           no_dht=True) as (localhost, server_port, dht_port):
        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', port=server_port)
                   for i in range(num_experts)]
        moe_output, = hivemind.client.moe._RemoteMoECall.apply(
            logits, experts[:len(logits)], k_min, timeout_after_k_min,
            backward_k_min, timeout_total, backward_timeout, [(None, ), {}],
            xx)

        def _proj_grad(output, wrt):
            # gradient of sum(random_proj * output) w.r.t. `wrt`, keeping the
            # graph alive so it can be differentiated again
            grad_wrt, = torch.autograd.grad(torch.sum(random_proj * output),
                                            wrt, retain_graph=True)
            return grad_wrt

        grad_xx_moe = _proj_grad(moe_output, xx)
        grad_logits_moe = _proj_grad(moe_output, logits)

        # reference outputs: call all experts manually and average their outputs with softmax probabilities
        probs = torch.softmax(logits, 0)
        outs = [expert(xx) for expert in experts[:3]]
        manual_output = sum(p * x for p, x in zip(probs, outs))
        grad_xx_manual = _proj_grad(manual_output, xx)
        grad_xx_manual_rerun = _proj_grad(manual_output, xx)
        grad_logits_manual = _proj_grad(manual_output, logits)

    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun, rtol, atol), \
        "Experts are non-deterministic. The test is only valid for deterministic experts"
    assert torch.allclose(moe_output, manual_output, rtol,
                          atol), "_RemoteMoECall returned incorrect output"
    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol,
                          atol), "incorrect gradient w.r.t. input"
    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol,
                          atol), "incorrect gradient w.r.t. logits"