def test_expert_params(device):
    """Every parameter of the wrapped expert must be tagged with ``expert is True``."""
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    # Constructing the MOELayer is what marks the expert's parameters.
    moe = MOELayer(gate, expert).to(device)
    for param in expert.parameters():
        assert param.expert is True
def test_forward(device):
    """With an identity expert, the MOE layer must reproduce its input exactly.

    Fix: renamed the local ``input`` to ``inp`` — it shadowed the builtin.
    """
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    inp = torch.randn(1, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Identity weights: the expert passes tokens through unchanged.
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(inp)
    assert output.shape == inp.shape
    # Re-assembled output should match input due to the identity expert.
    assert torch.allclose(inp, output)
def test_backward(device):
    """Backprop through an identity expert with target == input yields zero grads.

    Fixes: ``input`` shadowed the builtin; the scalar loss was confusingly
    stored back into ``output`` — it now gets its own name.
    """
    loss_fn = torch.nn.MSELoss()
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    inp = torch.randn(1, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Identity weights make the layer output equal its input, so the loss is zero.
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(inp)
    assert output.shape == inp.shape
    loss = loss_fn(output, inp)
    loss.backward()
    # Zero loss implies a zero gradient on the expert weights.
    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))
def test_forward_routing(device):
    """Round-robin gating must send token ``i`` to expert ``i % num_experts``.

    Fixes: ``input`` shadowed the builtin; the verification loop rebound
    ``expert`` (the ``nn.Linear`` module) to an ``int`` — use ``expert_idx``.
    """
    model_dim = 8
    num_experts = dist.get_world_size()
    inp = torch.randn(1, 4, 16, model_dim).to(device)
    gate = RoundRobinGate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Scaling matrix: each rank scales tokens by a distinct factor (rank + 1).
    scale = dist.get_rank() + 1
    expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
    moe = MOELayer(gate, expert).to(device)
    output = moe(inp)
    assert output.shape == inp.shape
    # Verify each token was sent to the correct expert by checking its scale.
    num_tokens = inp.shape[2]
    for i in range(num_tokens):
        expert_idx = i % num_experts
        assert torch.allclose(inp[:, :, i] * (expert_idx + 1), output[:, :, i])
def test_forward_multi(device):
    """Top-2 gating over multiple local identity experts preserves routed tokens.

    Fixes: removed stray ``torch.set_printoptions(threshold=5000)`` debug
    leftover (it mutates global state for the whole test session);
    ``input`` shadowed the builtin; unused loop variable renamed to ``_``.
    """
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    inp = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    experts = []
    for _ in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Identity weights so routed tokens come back unchanged.
        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
        experts.append(expert)
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(inp)
    assert output.shape == inp.shape
    # At least 90% of the input should have been dispatched to an expert
    # (top-2 gating may drop tokens when expert capacity is exceeded).
    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
    # Except for dropped (zero) positions, output should match input exactly.
    assert torch.allclose(inp, torch.where(output > 0, output, inp))
def test_forward_routing_multi(device):
    """Round-robin routing across several local experts scales each token correctly.

    Fixes: ``input`` shadowed the builtin; the verification loop rebound
    ``expert`` (the ``nn.Linear`` module) to an ``int`` — use ``expert_idx``.
    """
    model_dim = 8
    num_local_experts = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    inp = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
    gate = RoundRobinGate(model_dim, num_experts)
    experts = []
    for i in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Each global expert scales by a distinct factor (global index + 1).
        scale = dist.get_rank() * num_local_experts + i + 1
        expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
        experts.append(expert)
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(inp)
    assert output.shape == inp.shape
    # Verify each token was sent to the correct expert by checking its scale.
    num_tokens = inp.shape[1]
    for i in range(num_tokens):
        expert_idx = i % num_experts
        assert torch.allclose(inp[:, i] * (expert_idx + 1), output[:, i])
def test_create(device):
    """Smoke test: a MOELayer can be built and moved to ``device`` without error."""
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    MOELayer(gate, expert).to(device)