Exemplo n.º 1
0
def test_training(max_steps: int = 100, threshold: float = 0.9):
    """Train a pipeline of two remote experts on a binary digits task.

    Args:
        max_steps: maximum number of SGD steps before giving up.
        threshold: training accuracy that must be reached for the test to pass.
    """
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(
        dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    with background_server(num_experts=2,
                           device='cpu',
                           optim_cls=SGD,
                           hidden_dim=64,
                           num_handlers=1,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        # Remote experts are chained with a local ReLU and a local output head.
        expert1 = RemoteExpert('expert.0', server_endpoint)
        expert2 = RemoteExpert('expert.1', server_endpoint)
        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))

        opt = torch.optim.SGD(model.parameters(), lr=0.05)

        # Initialize so the final assert fails cleanly (not NameError)
        # when max_steps == 0.
        accuracy = 0.0
        for step in range(max_steps):
            opt.zero_grad()

            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break  # early exit once target accuracy is reached

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
Exemplo n.º 2
0
def test_remote_module_call(hidden_dim=16):
    """Forward/backward through a remote FFN expert and check RPC error paths."""
    with background_server(num_experts=1,
                           device='cpu',
                           expert_cls='ffn',
                           num_handlers=1,
                           hidden_dim=hidden_dim,
                           optim_cls=None,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
        fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)

        # Single-sample forward keeps the hidden dimension.
        single_out = real_expert(torch.randn(1, hidden_dim))
        assert single_out.shape == (1, hidden_dim)

        batch = torch.randn(3, hidden_dim, requires_grad=True)
        batch_out = real_expert(batch)
        assert batch_out.shape == (3, hidden_dim)

        # A sub-batch must reproduce the matching rows of the full-batch output.
        tail_out = real_expert(batch[1:])
        assert torch.allclose(tail_out, batch_out[1:], atol=1e-5, rtol=0)

        # Gradients must flow back through the remote call.
        tail_out.norm().backward()
        assert batch.grad is not None and batch.grad.norm() > 0

        # Wrong input width and an unknown expert uid both surface as RPC errors.
        with pytest.raises(grpc.RpcError):
            real_expert(torch.randn(3, 11))
        with pytest.raises(grpc.RpcError):
            fake_expert(batch)
Exemplo n.º 3
0
def test_call_many(hidden_dim=16):
    """Run _RemoteCallMany over a ragged grid of live and dead experts, then
    verify the survival mask, outputs, and gradients against direct calls."""
    k_min = 1
    timeout_after_k_min = None
    backward_k_min = 1
    forward_timeout = None
    backward_timeout = None
    detect_anomalies = False
    allow_zero_outputs = False
    atol = 1e-5

    with background_server(num_experts=5,
                           device='cpu',
                           expert_cls='ffn',
                           num_handlers=1,
                           hidden_dim=hidden_dim,
                           optim_cls=None,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        inputs = torch.randn(4, hidden_dim, requires_grad=True)
        inputs_clone = inputs.clone().detach().requires_grad_(True)
        e0, e1, e2, e3, e4 = [
            hivemind.RemoteExpert(f'expert.{i}', server_endpoint)
            for i in range(5)
        ]
        # This endpoint is dead: calls to e5 are expected to fail.
        e5 = hivemind.RemoteExpert('thisshouldnotexist', '127.0.0.1:80')

        mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min,
            backward_k_min, timeout_after_k_min, forward_timeout,
            backward_timeout, detect_anomalies, allow_zero_outputs, e1.info,
            inputs)
        assert mask.shape == (4, 3)
        assert expert_outputs.shape == (4, 3, hidden_dim)

        # Only the dead expert (row 2, slot 1) and the empty row 3 drop out.
        expected_mask = np.array([[True, True, True],
                                  [True, True, False],
                                  [True, False, True],
                                  [False, False, False]])
        assert np.all(mask.data.numpy() == expected_mask), f"Incorrect mask, {mask}"

        # Recompute each surviving (sample, slot) pair with a direct expert call.
        reference_outputs = torch.zeros_like(expert_outputs)
        direct_calls = [(0, 0, e0), (0, 1, e1), (0, 2, e2),
                        (1, 0, e2), (1, 1, e4),
                        (2, 0, e1), (2, 2, e3)]
        for row, slot, expert in direct_calls:
            reference_outputs[row, slot] = expert(inputs_clone[row:row + 1])

        assert torch.allclose(expert_outputs,
                              reference_outputs,
                              atol=atol,
                              rtol=0)

        # Backprop a random projection of a fixed subset of cells both ways.
        proj = torch.randn(4, hidden_dim)
        picked = (0, 1, 1, 2), (0, 2, 1, 0)
        (expert_outputs[picked] * proj).sum().backward()
        our_grad = inputs.grad.data.cpu().clone()

        (reference_outputs[picked] * proj).sum().backward()
        reference_grad = inputs_clone.grad.data.cpu().clone()
        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
Exemplo n.º 4
0
def test_determinism(hidden_dim=16):
    """Check that the 'det_dropout' expert gives identical outputs and
    gradients when called twice with the same inputs and mask."""
    atol = 1e-5

    xx = torch.randn(32, hidden_dim, requires_grad=True)
    # torch.randint's upper bound is exclusive: the original (0, 1) always
    # produced an all-zero mask, so dropout was never actually exercised.
    # Sample genuine 0/1 values instead; determinism must hold regardless.
    mask = torch.randint(0, 2, (32, hidden_dim))

    with background_server(num_experts=1,
                           device='cpu',
                           expert_cls='det_dropout',
                           num_handlers=1,
                           hidden_dim=hidden_dim,
                           optim_cls=None,
                           no_dht=True) as (server_endpoint, dht_endpoint):
        expert = hivemind.RemoteExpert(uid='expert.0',
                                       endpoint=server_endpoint)

        out = expert(xx, mask)
        out_rerun = expert(xx, mask)

        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
        grad_rerun, = torch.autograd.grad(out_rerun.sum(),
                                          xx,
                                          retain_graph=True)

    assert torch.allclose(
        out, out_rerun, atol=atol,
        rtol=0), "Dropout layer outputs are non-deterministic."
    assert torch.allclose(grad, grad_rerun, atol=atol,
                          rtol=0), "Gradients are non-deterministic."
Exemplo n.º 5
0
def test_moe_training(max_steps: int = 100,
                      threshold: float = 0.9,
                      num_experts=2):
    """Train a RemoteMixtureOfExperts head on a binary digits task until
    `threshold` accuracy is reached (or `max_steps` steps elapse)."""
    dataset = load_digits(n_class=2)
    X_train = torch.tensor(dataset['data'], dtype=torch.float)
    y_train = torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1) \
            as (server_endpoint, dht_endpoint):
        dht = DHT(start=True, initial_peers=[dht_endpoint])

        # MoE layer routes each sample to the k_best experts found via DHT.
        moe = RemoteMixtureOfExperts(in_features=64,
                                     grid_size=(num_experts, ),
                                     dht=dht,
                                     uid_prefix='expert.',
                                     k_best=2)
        model = nn.Sequential(moe, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for _ in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break  # converged early

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
Exemplo n.º 6
0
def test_switch_training(max_steps: int = 10,
                         threshold: float = 0.9,
                         num_experts=5):
    """Train a SwitchNetwork (switch-style MoE with a balancing loss) on a
    binary digits task and check both accuracy and expert utilization."""
    dataset = load_digits(n_class=2)
    X_train = torch.tensor(dataset['data'], dtype=torch.float)
    y_train = torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
    with background_server(expert_uids=all_expert_uids,
                           device='cpu',
                           optim_cls=SGD,
                           hidden_dim=64,
                           num_handlers=1) as (server_endpoint, dht_endpoint):
        dht = DHT(start=True, initial_peers=[dht_endpoint])

        model = SwitchNetwork(dht, 64, 2, num_experts)
        opt = SGD(model.parameters(), lr=0.05)

        for _ in range(max_steps):
            outputs, balancing_loss = model(X_train)
            # The balancing term nudges the router toward uniform expert use.
            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break  # converged early

        # Every expert must receive at least half of its fair share of traffic.
        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
Exemplo n.º 7
0
def test_moe():
    """Smoke-test RemoteMixtureOfExperts forward/backward over a sparse 3D grid."""
    # 20 random uids drawn from a 3x3x3 corner of the declared 32x32x32 grid.
    all_expert_uids = [
        f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
        for _ in range(20)
    ]
    with background_server(expert_uids=all_expert_uids,
                           device='cpu',
                           expert_cls='ffn',
                           num_handlers=1,
                           hidden_dim=16) as (server_endpoint, dht_endpoint):
        dht = hivemind.DHT(start=True,
                           expiration=999,
                           initial_peers=[dht_endpoint])
        # declare expert uids. Server *should* declare them by itself, but it takes time.
        assert all(
            dht.declare_experts(all_expert_uids, endpoint=server_endpoint))

        # NOTE(review): the sibling test passes uid_prefix='ffn.' (trailing
        # dot) — confirm which form this hivemind version expects.
        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16,
                                               grid_size=(32, 32, 32),
                                               dht=dht,
                                               k_best=3,
                                               uid_prefix='ffn')

        for _ in range(10):
            batch_out = dmoe(torch.randn(10, 16))
            batch_out.sum().backward()
Exemplo n.º 8
0
def test_multihead_expert(hid_dim=16):
    """Forward/backward through two custom multi-input ('multihead') experts."""
    with background_server(
            expert_cls='multihead',
            num_experts=2,
            device='cpu',
            hidden_dim=hid_dim,
            num_handlers=2,
            no_dht=True,
            custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
        experts = [RemoteExpert('expert.0', server_endpoint),
                   RemoteExpert('expert.1', server_endpoint)]

        for batch_size in (1, 4):
            # Each expert takes three inputs of widths hid_dim, 2*hid_dim, 3*hid_dim.
            inputs = tuple(
                torch.randn(batch_size, width * hid_dim) for width in (1, 2, 3))

            for expert in experts:
                expert(*inputs).sum().backward()
Exemplo n.º 9
0
def test_moe():
    """Smoke-test RemoteMixtureOfExperts over a 4x4x4 grid of 'ffn' experts."""
    # 10 random uids from a 3x3x3 corner of the 4x4x4 grid.
    all_expert_uids = [
        f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
        for _ in range(10)
    ]
    with background_server(expert_uids=all_expert_uids,
                           device='cpu',
                           expert_cls='ffn',
                           num_handlers=1,
                           hidden_dim=16) as (server_endpoint, dht_endpoint):
        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])

        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16,
                                               grid_size=(4, 4, 4),
                                               dht=dht,
                                               k_best=3,
                                               uid_prefix='ffn.')

        # A few forward/backward passes must complete without errors.
        for _ in range(3):
            batch_out = dmoe(torch.randn(10, 16))
            batch_out.sum().backward()
Exemplo n.º 10
0
def test_no_experts():
    """With slow 'nop_delay' experts and tiny timeouts, the switch-MoE must
    still run forward/backward thanks to allow_zero_outputs=True."""
    all_expert_uids = [
        f'expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
        for _ in range(10)
    ]
    with background_server(expert_uids=all_expert_uids,
                           device='cpu',
                           expert_cls='nop_delay',
                           num_handlers=1,
                           hidden_dim=16) as (server_endpoint, dht_endpoint):
        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])

        # Timeouts of 0.1s are expected to expire before the delayed experts
        # respond; zero-output fallback keeps the graph differentiable.
        dmoe = hivemind.RemoteSwitchMixtureOfExperts(in_features=16,
                                                     grid_size=(4, 4, 4),
                                                     dht=dht,
                                                     uid_prefix='expert.',
                                                     forward_timeout=0.1,
                                                     backward_timeout=0.1,
                                                     allow_zero_outputs=True)

        for _ in range(3):
            batch_out, balancing_loss = dmoe(torch.randn(10, 16))
            batch_out.sum().backward()