def partitions(pipeline_style):
    """Check the partition layout of a two-layer model wrapped in ``Pipe``.

    Two ``nn.Linear(1, 1)`` layers are chained in an ``nn.Sequential`` and
    handed to ``Pipe`` together with ``[1, 1]``, the given style and the
    worker map returned by ``get_worker_map()``.

    The assertions require that:
      * ``mp_partitions`` is a plain ``list``;
      * the wrapped model has length 1;
      * the first entry's ``module`` is an ``nn.Sequential``;
      * the state dict contains ``"0.0.weight"`` when ``group.rank()`` is 0,
        and ``"0.1.weight"`` for any other rank.

    Args:
        pipeline_style: Value forwarded unchanged as ``Pipe``'s ``style``
            keyword argument.
    """
    first_layer = nn.Linear(1, 1)
    second_layer = nn.Linear(1, 1)
    pipe = Pipe(
        nn.Sequential(first_layer, second_layer),
        [1, 1],
        style=pipeline_style,
        worker_map=get_worker_map(),
    )

    assert isinstance(pipe.mp_partitions, list)
    assert len(pipe) == 1
    assert isinstance(pipe.mp_partitions[0].module, nn.Sequential)

    # The expected parameter key depends on which rank runs this check.
    expected_key = "0.0.weight" if pipe.group.rank() == 0 else "0.1.weight"
    assert expected_key in pipe.state_dict()
def partitions():
    """Check the partition layout of a two-layer model under ``Pipe.MultiProcess``.

    Two ``nn.Linear(1, 1)`` layers are chained in an ``nn.Sequential`` and
    handed to ``Pipe`` together with ``[1, 1]``, ``style=Pipe.MultiProcess``
    and the worker map returned by ``get_worker_map()``.

    The assertions require that:
      * ``partitions`` is an ``nn.ModuleList``;
      * the wrapped model has length 1;
      * the first partition is an ``nn.Sequential``;
      * the state dict contains the key ``"partitions.0.0.weight"``.
    """
    first_layer = nn.Linear(1, 1)
    second_layer = nn.Linear(1, 1)
    pipe = Pipe(
        nn.Sequential(first_layer, second_layer),
        [1, 1],
        style=Pipe.MultiProcess,
        worker_map=get_worker_map(),
    )

    assert isinstance(pipe.partitions, nn.ModuleList)
    assert len(pipe) == 1
    assert isinstance(pipe.partitions[0], nn.Sequential)

    state_keys = pipe.state_dict()
    assert "partitions.0.0.weight" in state_keys
def test_partitions():
    """Check the partition layout of a two-layer model placed on two CPU devices.

    Two ``nn.Linear(1, 1)`` layers are chained in an ``nn.Sequential`` and
    handed to ``Pipe`` together with ``[1, 1]`` and
    ``devices=["cpu", "cpu"]``.

    The assertions require that:
      * ``partitions`` is an ``nn.ModuleList``;
      * both partition 0 and partition 1 are ``nn.Sequential`` instances;
      * the state dict contains the key ``"partitions.0.0.weight"``.
    """
    first_layer = nn.Linear(1, 1)
    second_layer = nn.Linear(1, 1)
    pipe = Pipe(
        nn.Sequential(first_layer, second_layer),
        [1, 1],
        devices=["cpu", "cpu"],
    )

    assert isinstance(pipe.partitions, nn.ModuleList)

    # Each of the two partitions is checked in index order.
    for index in (0, 1):
        assert isinstance(pipe.partitions[index], nn.Sequential)

    assert "partitions.0.0.weight" in pipe.state_dict()