Example #1
0
def partitions(pipeline_style):
    """Check how a two-layer model is partitioned by a multi-process ``Pipe``.

    Wraps ``nn.Sequential(Linear(1, 1), Linear(1, 1))`` in ``Pipe`` with a
    ``[1, 1]`` balance, the given ``pipeline_style`` and the worker map from
    ``get_worker_map()``. It then asserts the following:

    * ``mp_partitions`` is a plain ``list``.
    * The pipe reports a length of one.
    * The first partition's ``.module`` is an ``nn.Sequential``.
    * The state dict contains ``"0.0.weight"`` when this process's group rank
      is 0, and ``"0.1.weight"`` otherwise.
    """
    first_layer = nn.Linear(1, 1)
    second_layer = nn.Linear(1, 1)

    pipe = Pipe(
        nn.Sequential(first_layer, second_layer),
        [1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
    )

    assert isinstance(pipe.mp_partitions, list)
    assert len(pipe) == 1
    assert isinstance(pipe.mp_partitions[0].module, nn.Sequential)

    # The expected parameter key depends on which rank this process is.
    expected_key = "0.0.weight" if pipe.group.rank() == 0 else "0.1.weight"
    assert expected_key in pipe.state_dict()
Example #2
0
def partitions():
    """Check the partition layout of a two-layer model under ``Pipe.MultiProcess``.

    Wraps ``nn.Sequential(Linear(1, 1), Linear(1, 1))`` in ``Pipe`` with a
    ``[1, 1]`` balance, ``style=Pipe.MultiProcess`` and the worker map from
    ``get_worker_map()``. It then asserts the following:

    * ``partitions`` is an ``nn.ModuleList``.
    * The pipe reports a length of one.
    * The first partition is an ``nn.Sequential``.
    * ``"partitions.0.0.weight"`` is present in the state dict.
    """
    layers = [nn.Linear(1, 1), nn.Linear(1, 1)]
    pipe = Pipe(
        nn.Sequential(*layers),
        [1, 1],
        style=Pipe.MultiProcess,
        worker_map=get_worker_map(),
    )

    assert isinstance(pipe.partitions, nn.ModuleList)
    assert len(pipe) == 1
    assert isinstance(pipe.partitions[0], nn.Sequential)

    state = pipe.state_dict()
    assert "partitions.0.0.weight" in state
Example #3
0
def test_partitions():
    """Check that a CPU-only ``Pipe`` exposes two ``nn.Sequential`` partitions.

    Wraps ``nn.Sequential(Linear(1, 1), Linear(1, 1))`` in ``Pipe`` with a
    ``[1, 1]`` balance and ``devices=["cpu", "cpu"]``. It then asserts the
    following:

    * ``partitions`` is an ``nn.ModuleList``.
    * Entries 0 and 1 are each an ``nn.Sequential``.
    * ``"partitions.0.0.weight"`` appears in the state dict.
    """
    stacked = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    pipe = Pipe(stacked, [1, 1], devices=["cpu", "cpu"])

    assert isinstance(pipe.partitions, nn.ModuleList)
    for index in (0, 1):
        assert isinstance(pipe.partitions[index], nn.Sequential)

    assert "partitions.0.0.weight" in pipe.state_dict()