Ejemplo n.º 1
0
def test_expert_params(device):
    """Check that the expert's parameters are flagged once wrapped in an MOELayer.

    Builds an MOELayer from a Top2Gate and a single Linear expert, moves it to
    ``device``, and asserts that every parameter of the expert module has
    ``expert is True``.
    """
    hidden_size = 8
    expert_count = 4
    router = Top2Gate(hidden_size, expert_count)
    expert = torch.nn.Linear(hidden_size, hidden_size)
    # The layer object itself is never inspected; constructing it is presumably
    # what tags the expert parameters -- confirm against MOELayer.
    MOELayer(router, expert).to(device)
    assert all(param.expert is True for param in expert.parameters())
Ejemplo n.º 2
0
def test_forward(device):
    """Forward pass through an MOELayer whose expert is an identity map.

    The gate is sized to the world size of the default process group. The
    expert's weight is replaced with the identity matrix, and the test asserts
    that the layer output has the input's shape and is ``allclose`` to the input.
    """
    hidden_size = 8
    world_size = dist.get_world_size(dist.group.WORLD)
    tokens = torch.randn(1, 4, 16, hidden_size).to(device)
    router = Top2Gate(hidden_size, world_size)
    identity_expert = torch.nn.Linear(hidden_size, hidden_size, bias=False)
    # Replace the random initialisation with the identity matrix so the expert
    # hands back whatever it receives.
    identity_expert.weight = torch.nn.Parameter(torch.eye(hidden_size))
    layer = MOELayer(router, identity_expert).to(device)

    result = layer(tokens)

    assert result.shape == tokens.shape
    # With an identity expert, the re-assembled output is expected to equal the input.
    assert torch.allclose(tokens, result)
Ejemplo n.º 3
0
def test_backward(device):
    """Backward pass through an MOELayer whose expert is an identity map.

    Computes an MSE loss between the layer output and its input, backpropagates,
    and asserts that the gradient on the expert weight is ``allclose`` to zero.
    """
    hidden_size = 8
    world_size = dist.get_world_size(dist.group.WORLD)
    criterion = torch.nn.MSELoss()
    tokens = torch.randn(1, 4, 16, hidden_size).to(device)
    router = Top2Gate(hidden_size, world_size)
    identity_expert = torch.nn.Linear(hidden_size, hidden_size, bias=False)
    # Replace the random initialisation with the identity matrix.
    identity_expert.weight = torch.nn.Parameter(torch.eye(hidden_size))
    layer = MOELayer(router, identity_expert).to(device)

    routed = layer(tokens)
    assert routed.shape == tokens.shape

    # Loss of the output against the input itself; with an identity expert the
    # two are expected to coincide, so no gradient should reach the weight.
    criterion(routed, tokens).backward()
    weight_grad = identity_expert.weight.grad
    assert torch.allclose(weight_grad, torch.zeros_like(identity_expert.weight))
Ejemplo n.º 4
0
def test_forward_routing(device):
    """Check which expert each token reached, using a per-rank scaling expert.

    Each rank's expert multiplies its input by ``rank + 1``. The test asserts
    that position ``i`` along dimension 2 of the output equals the input at
    that position scaled by ``(i % num_experts) + 1``.
    """
    hidden_size = 8
    world_size = dist.get_world_size()
    tokens = torch.randn(1, 4, 16, hidden_size).to(device)
    router = RoundRobinGate(hidden_size, world_size)
    scaling_expert = torch.nn.Linear(hidden_size, hidden_size, bias=False)
    # Scaled identity: the factor differs per rank, so it identifies the expert.
    rank_scale = dist.get_rank() + 1
    scaling_expert.weight = torch.nn.Parameter(torch.eye(hidden_size) * rank_scale)
    layer = MOELayer(router, scaling_expert).to(device)

    result = layer(tokens)
    assert result.shape == tokens.shape

    # The scale seen at each position reveals the expert that processed it.
    for position in range(tokens.shape[2]):
        expected_scale = position % world_size + 1
        assert torch.allclose(
            tokens[:, :, position] * expected_scale, result[:, :, position]
        )
Ejemplo n.º 5
0
def test_forward_multi(device):
    """Forward pass through an MOELayer with several identity experts per rank.

    Builds ``num_local_experts`` Linear experts whose weights are identity
    matrices, wraps them in an MOELayer with a Top2Gate sized to
    ``world_size * num_local_experts``, and asserts that:

    * the output has the input's shape,
    * more than 90% of the output entries are non-zero, and
    * every non-zero output entry is ``allclose`` to the matching input entry.
    """
    # Raise the summarisation threshold so tensors of this size print in full
    # when debugging. Note that this changes process-wide print options.
    torch.set_printoptions(threshold=5000)
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    tokens = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    experts = []
    for _ in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use identity matrix
        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
        experts.append(expert)
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(tokens)
    assert output.shape == tokens.shape
    # 90% of the input should have gone to an expert
    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
    # Except for zeros, re-assembled output should match input due to identity expert.
    # Mask on `!= 0` rather than `> 0`: a `> 0` mask would substitute the input
    # for every negative output entry and leave those entries unchecked.
    assert torch.allclose(tokens, torch.where(output != 0, output, tokens))
Ejemplo n.º 6
0
def test_forward_routing_multi(device):
    """Check which expert each token reached, with several scaling experts per rank.

    Local expert ``i`` on a rank multiplies its input by
    ``rank * num_local_experts + i + 1``. The test asserts that position ``i``
    along dimension 1 of the output equals the input at that position scaled
    by ``(i % num_experts) + 1``.
    """
    hidden_size = 8
    experts_per_rank = 4
    total_experts = dist.get_world_size(dist.group.WORLD) * experts_per_rank
    tokens = torch.randn(4 * experts_per_rank, 16, hidden_size).to(device)
    router = RoundRobinGate(hidden_size, total_experts)

    local_experts = torch.nn.ModuleList()
    for local_idx in range(experts_per_rank):
        scaling_expert = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        # Scaled identity: the factor is distinct for every (rank, local index)
        # pair, so it identifies the expert.
        global_scale = dist.get_rank() * experts_per_rank + local_idx + 1
        scaling_expert.weight = torch.nn.Parameter(
            torch.eye(hidden_size) * global_scale
        )
        local_experts.append(scaling_expert)

    layer = MOELayer(router, local_experts).to(device)
    result = layer(tokens)
    assert result.shape == tokens.shape

    # The scale seen at each position reveals the expert that processed it.
    for position in range(tokens.shape[1]):
        expected_scale = position % total_experts + 1
        assert torch.allclose(
            tokens[:, position] * expected_scale, result[:, position]
        )
Ejemplo n.º 7
0
def test_create(device):
    """Smoke test: build an MOELayer and move it to ``device``.

    There are no assertions; the test passes as long as constructing the layer
    from a Top2Gate and a Linear expert, then calling ``.to(device)``, does not
    raise.
    """
    hidden_size = 8
    expert_count = 4
    router = Top2Gate(hidden_size, expert_count)
    linear_expert = torch.nn.Linear(hidden_size, hidden_size)
    layer = MOELayer(router, linear_expert).to(device)