import pytest
import torch
import torch.distributed as dist

# Import path assumes the fairscale MoE implementation; adjust if the package
# layout differs.
from fairscale.nn import MOELayer, Top2Gate


def test_expert_params(device):
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    moe = MOELayer(gate, expert).to(device)
    # MOELayer tags every expert parameter so optimizers/wrappers can treat them
    # specially (expert weights are not replicated across data-parallel ranks).
    for p in expert.parameters():
        assert p.expert is True


def test_forward(device):
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    input = torch.randn(1, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use an identity matrix so the expert is a no-op.
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(input)
    assert output.shape == input.shape
    # Re-assembled output should match the input because the expert is the identity.
    assert torch.allclose(input, output)


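# ---------------------------------------------------------------------------
# The tests in this module assume a `device` fixture and an initialized default
# process group (for dist.get_world_size). The sketch below shows minimal
# versions of such fixtures, assuming a single-process gloo group; the fixture
# names and backend here are illustrative assumptions, not the project's actual
# test configuration, which would normally live in conftest.py.
# ---------------------------------------------------------------------------
import os


@pytest.fixture(params=["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
def device(request):
    # Run each test on every available device.
    return request.param


@pytest.fixture(autouse=True)
def _process_group():
    # dist.group.WORLD is only valid after init_process_group; spin up a trivial
    # single-process gloo group if none exists yet.
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29501")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
    yield

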
def test_backward(device):
    loss = torch.nn.MSELoss()
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    input = torch.randn(1, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use an identity matrix so the expert is a no-op.
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(input)
    assert output.shape == input.shape
    output = loss(output, input)
    output.backward()
    # With an identity expert the reconstruction loss is zero, so the expert
    # weights receive no gradient.
    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))


def do_test_forward(device):
    torch.manual_seed(3)
    input = torch.randn(12, 4).to(device)
    gate = Top2Gate(4, 6).to(device)
    # Top-2 gating: each of the 12 tokens is routed to at most 2 of the 6
    # experts, so each expert sees at most 2 * 12 // 6 = 4 tokens.
    capacity = 2 * 12 // 6
    l_aux, combine_weights, dispatch_mask = gate(input)
    assert l_aux.item() == pytest.approx(0.0283, rel=1e-2)
    assert combine_weights.shape == (12, 6, 4)
    assert dispatch_mask.shape == (12, 6, 4)
    assert torch.equal(combine_weights.bool(), dispatch_mask)
    assert torch.all(torch.sum(dispatch_mask, dim=(0, 2)) <= capacity)
    assert torch.all(combine_weights >= 0.0)
    assert torch.all(combine_weights <= 1.0)
    # Each token's combine weights sum to 1 (or 0 if the token was dropped), so
    # the total should be close to an integer.
    weights_sum = torch.sum(combine_weights).item()
    assert round(weights_sum) == pytest.approx(weights_sum)
    # For this random seed, we get 12 slots filled.
    assert weights_sum == pytest.approx(12.0)


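# For reference, a schematic sketch (not MOELayer's actual implementation) of how
# the gate outputs asserted above could be used to dispatch tokens to experts and
# re-combine their outputs, following the GShard-style einsum formulation.
# Shapes: s = tokens, e = experts, c = capacity, m = model_dim.
def _moe_dispatch_combine_sketch(gate, experts, tokens):
    # tokens: (s, m); combine_weights / dispatch_mask: (s, e, c), as checked above.
    l_aux, combine_weights, dispatch_mask = gate(tokens)
    # Scatter each token into its (expert, capacity-slot) buffer.
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.float(), tokens)
    # Run each expert on its own slice of the dispatched buffer.
    expert_out = torch.stack([expert(dispatched[i]) for i, expert in enumerate(experts)])
    # Weight each expert's output by its gate probability and sum back per token.
    combined = torch.einsum("sec,ecm->sm", combine_weights, expert_out)
    return combined, l_aux

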
def test_forward_multi(device):
    torch.set_printoptions(threshold=5000)
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    experts = []
    for _ in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use an identity matrix so each expert is a no-op.
        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
        experts += [expert]
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(input)
    assert output.shape == input.shape
    # At least 90% of the input should have gone to an expert; the rest may be
    # dropped (zeroed) due to expert capacity limits.
    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
    # Except for the zeroed positions, the re-assembled output should match the
    # input because every expert is the identity.
    assert torch.allclose(input, torch.where(output > 0, output, input))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_create_cuda():
    gate = Top2Gate(4, 8).cuda()


def test_create():
    gate = Top2Gate(4, 8)


def test_create_moe_layer(device):
    # Renamed from test_create to avoid clashing with the gate-only test above.
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    moe = MOELayer(gate, expert).to(device)


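# In real use the gate's load-balancing loss must be added to the task loss.
# A minimal training-step sketch, assuming MOELayer exposes that auxiliary loss
# as `moe.l_aux` after the forward pass (an assumption about this particular
# API) and with an illustrative aux_weight scaling factor:
def _train_step_sketch(moe, optimizer, batch, target, aux_weight=0.01):
    output = moe(batch)
    # Combine the task loss with the gate's load-balancing (auxiliary) loss.
    loss = torch.nn.functional.mse_loss(output, target) + aux_weight * moe.l_aux
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()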