# NOTE: imports follow the hivemind 0.9.x layout (hivemind.client / hivemind.server);
# later releases moved these modules under hivemind.moe.
import os
from functools import partial

import grpc
import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits

import hivemind
from hivemind import DHT, RemoteExpert, RemoteMixtureOfExperts, background_server
from hivemind.client.expert import DUMMY

# assumed location of the custom expert definitions used by test_multihead_expert
# (a sketch of its contents follows that test below):
CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), 'custom_networks.py')


def test_training(max_steps: int = 100, threshold: float = 0.9):
    dataset = load_digits(n_class=2)
    X_train = torch.tensor(dataset['data'], dtype=torch.float)
    y_train = torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64,
                           num_handlers=1, no_dht=True) as (server_endpoint, dht_endpoint):
        expert1 = RemoteExpert('expert.0', server_endpoint)
        expert2 = RemoteExpert('expert.1', server_endpoint)
        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))

        opt = torch.optim.SGD(model.parameters(), lr=0.05)
        for step in range(max_steps):
            opt.zero_grad()
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
def test_remote_module_call(hidden_dim=16):
    with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1,
                           hidden_dim=hidden_dim, optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
        real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
        fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)

        out1 = real_expert(torch.randn(1, hidden_dim))
        assert out1.shape == (1, hidden_dim)

        dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
        out3 = real_expert(dummy_x)
        assert out3.shape == (3, hidden_dim)

        # the expert must be deterministic: a sub-batch should reproduce the matching slice of the full output
        out3_again = real_expert(dummy_x[1:])
        assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
        out3_again.norm().backward()
        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0

        # a mismatched input shape and an unknown expert uid must both raise an RPC error
        with pytest.raises(grpc.RpcError):
            real_expert(torch.randn(3, 11))
        with pytest.raises(grpc.RpcError):
            fake_expert(dummy_x)
def test_call_many(hidden_dim=16):
    k_min = 1
    timeout_after_k_min = None
    backward_k_min = 1
    forward_timeout = None
    backward_timeout = None
    detect_anomalies = False
    allow_zero_outputs = False
    atol = 1e-5

    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1,
                           hidden_dim=hidden_dim, optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
        inputs = torch.randn(4, hidden_dim, requires_grad=True)
        inputs_clone = inputs.clone().detach().requires_grad_(True)

        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
        e5 = hivemind.RemoteExpert('thisshouldnotexist', '127.0.0.1:80')  # unreachable on purpose

        mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min,
            timeout_after_k_min, forward_timeout, backward_timeout, detect_anomalies,
            allow_zero_outputs, e1.info, inputs)
        assert mask.shape == (4, 3)
        assert expert_outputs.shape == (4, 3, hidden_dim)

        # the unreachable expert (e5) and the empty expert list must be masked out
        assert np.all(mask.data.numpy() == np.array([[True, True, True],
                                                     [True, True, False],
                                                     [True, False, True],
                                                     [False, False, False]])), f"Incorrect mask, {mask}"

        reference_outputs = torch.zeros_like(expert_outputs)
        reference_outputs[0, 0] = e0(inputs_clone[0:1])
        reference_outputs[0, 1] = e1(inputs_clone[0:1])
        reference_outputs[0, 2] = e2(inputs_clone[0:1])
        reference_outputs[1, 0] = e2(inputs_clone[1:2])
        reference_outputs[1, 1] = e4(inputs_clone[1:2])
        reference_outputs[2, 0] = e1(inputs_clone[2:3])
        reference_outputs[2, 2] = e3(inputs_clone[2:3])
        assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)

        # backprop through a random projection of the surviving outputs and
        # compare gradients against calling each expert individually
        proj = torch.randn(4, hidden_dim)
        loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        loss.backward()
        our_grad = inputs.grad.data.cpu().clone()

        reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        reference_loss.backward()
        reference_grad = inputs_clone.grad.data.cpu().clone()
        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
def test_determinism(hidden_dim=16):
    atol = 1e-5

    xx = torch.randn(32, hidden_dim, requires_grad=True)
    # random binary dropout mask; note that randint's upper bound is exclusive,
    # so it must be 2 (the original 1 would make the mask all zeros)
    mask = torch.randint(0, 2, (32, hidden_dim))

    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
                           hidden_dim=hidden_dim, optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
        expert = hivemind.RemoteExpert(uid='expert.0', endpoint=server_endpoint)

        out = expert(xx, mask)
        out_rerun = expert(xx, mask)

        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
        grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)

        assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
        assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
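# For reference: the 'det_dropout' expert class used above is assumed to wrap a dropout
# layer that takes its binary mask as an explicit input instead of drawing randomness
# at call time. A minimal sketch of that idea (illustrative names, not hivemind's
# actual implementation):
class DeterministicDropout(nn.Module):
    """Dropout whose randomness comes from a caller-provided binary mask."""

    def __init__(self, drop_prob: float = 0.5):
        super().__init__()
        self.keep_prob = 1.0 - drop_prob

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # the same (x, mask) pair always yields the same output and gradients,
        # which is exactly the property test_determinism checks end-to-end
        return x * mask.to(x.dtype) / self.keep_prob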
def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
    dataset = load_digits(n_class=2)
    X_train = torch.tensor(dataset['data'], dtype=torch.float)
    y_train = torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD,
                           hidden_dim=64, num_handlers=1) as (server_endpoint, dht_endpoint):
        dht = DHT(start=True, initial_peers=[dht_endpoint])

        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht,
                                     uid_prefix='expert.', k_best=2)
        model = nn.Sequential(moe, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)
        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
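# test_switch_training below uses SwitchNetwork, which is not defined in this section.
# A minimal sketch under the assumption that it mirrors hivemind's own test helper:
# a RemoteSwitchMixtureOfExperts followed by a linear head, with the MoE's
# load-balancing loss forwarded alongside the logits.
class SwitchNetwork(nn.Module):
    def __init__(self, dht, in_features, num_classes, num_experts):
        super().__init__()
        self.moe = hivemind.RemoteSwitchMixtureOfExperts(
            in_features=in_features, grid_size=(num_experts,), dht=dht, uid_prefix='expert.')
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        moe_output, balancing_loss = self.moe(x)
        return self.linear(moe_output), balancing_loss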
def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
    dataset = load_digits(n_class=2)
    X_train = torch.tensor(dataset['data'], dtype=torch.float)
    y_train = torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD,
                           hidden_dim=64, num_handlers=1) as (server_endpoint, dht_endpoint):
        dht = DHT(start=True, initial_peers=[dht_endpoint])

        model = SwitchNetwork(dht, 64, 2, num_experts)
        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs, balancing_loss = model(X_train)
            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        # the balancing loss should keep every expert at a non-negligible share of the load
        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
def test_moe_with_declared_experts():
    # renamed from test_moe to avoid clashing with the test of the same name below
    all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                       for _ in range(20)]
    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
        # declare expert uids explicitly. The server *should* declare them by itself, but that takes time.
        assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))

        # the grid is deliberately larger than the set of declared experts:
        # beam search must still find the k_best experts among mostly empty cells
        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(32, 32, 32),
                                               dht=dht, k_best=3, uid_prefix='ffn.')

        for i in range(10):
            out = dmoe(torch.randn(10, 16))
            out.sum().backward()
def test_multihead_expert(hid_dim=16):
    with background_server(expert_cls='multihead', num_experts=2, device='cpu',
                           hidden_dim=hid_dim, num_handlers=2, no_dht=True,
                           custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
        expert0 = RemoteExpert('expert.0', server_endpoint)
        expert1 = RemoteExpert('expert.1', server_endpoint)

        for batch_size in (1, 4):
            # the 'multihead' expert takes three inputs of different widths
            batch = (torch.randn(batch_size, hid_dim),
                     torch.randn(batch_size, 2 * hid_dim),
                     torch.randn(batch_size, 3 * hid_dim))

            output0 = expert0(*batch)
            output1 = expert1(*batch)

            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()
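# --- custom_networks.py: assumed contents of the module at CUSTOM_EXPERTS_PATH ---
# A sketch of how the 'multihead' expert class might be defined. It follows
# hivemind's custom-expert mechanism, where register_expert_class(name, sample_input)
# registers an nn.Module under a name usable as expert_cls; treat the exact import
# path and signature as version-dependent assumptions, and MultiheadNetwork's
# internals as purely illustrative.
#
# import torch
# import torch.nn as nn
# from hivemind.server.layers.custom_experts import register_expert_class
#
# # sample inputs tell the server what batch schema this expert class expects
# multihead_sample_input = lambda batch_size, hid_dim: (
#     torch.empty((batch_size, hid_dim)),
#     torch.empty((batch_size, 2 * hid_dim)),
#     torch.empty((batch_size, 3 * hid_dim)),
# )
#
# @register_expert_class('multihead', multihead_sample_input)
# class MultiheadNetwork(nn.Module):
#     def __init__(self, hid_dim):
#         super().__init__()
#         # one linear "head" per input, each projecting back to hid_dim
#         self.head1 = nn.Linear(hid_dim, hid_dim)
#         self.head2 = nn.Linear(2 * hid_dim, hid_dim)
#         self.head3 = nn.Linear(3 * hid_dim, hid_dim)
#
#     def forward(self, x1, x2, x3):
#         return self.head1(x1) + self.head2(x2) + self.head3(x3)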
def test_moe():
    all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                       for _ in range(10)]
    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])

        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4),
                                               dht=dht, k_best=3, uid_prefix='ffn.')

        for i in range(3):
            out = dmoe(torch.randn(10, 16))
            out.sum().backward()
def test_no_experts():
    all_expert_uids = [f'expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                       for _ in range(10)]
    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay',
                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])

        # the 'nop_delay' experts respond slowly, so with 0.1s timeouts most calls
        # should fail; allow_zero_outputs lets the layer emit zeros instead of raising
        dmoe = hivemind.RemoteSwitchMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht,
                                                     uid_prefix='expert.', forward_timeout=0.1,
                                                     backward_timeout=0.1, allow_zero_outputs=True)

        for i in range(3):
            out, balancing_loss = dmoe(torch.randn(10, 16))
            out.sum().backward()