def create_sequence_pipeline(
    layers: list[RemoteModuleParams],
    balance: list[int],
    devices: list[str],
    **kwargs,
) -> DistributedPipeline:
    """Create a pipeline from a list of pipeline-modules that run sequentially.

    Args:
        layers: list of modules. They should not be already assigned a
            remote-device.
        balance: a list of integers of how layers should be partitioned. Sum of
            the numbers in 'balance' should be equal to the number of layers.
        devices: specification of remote device for each partition. Should be
            of the same length as 'balance'.
        **kwargs: forwarded unchanged to the DistributedPipeline constructor.

    Returns:
        A DistributedPipeline over the partitioned sequence of RemoteModules.

    Raises:
        ValueError: if 'balance' does not sum to the number of layers, or if
            'balance' and 'devices' differ in length.
    """
    # Enforce the contract stated in the docstring; a silent mismatch would
    # drop trailing layers or partitions without any error.
    if sum(balance) != len(layers):
        raise ValueError(
            f"sum(balance)={sum(balance)} must equal the number of layers ({len(layers)})"
        )
    if len(balance) != len(devices):
        raise ValueError(
            f"'balance' ({len(balance)}) and 'devices' ({len(devices)}) must have the same length"
        )

    remote_modules: list[RemoteModule] = []
    index = 0
    for num_layers, remote_device in zip(balance, devices):
        # Instantiate every layer of this partition on its assigned device.
        for layer in layers[index : index + num_layers]:
            remote_modules.append(RemoteModule(remote_device, **layer._asdict()))
        index += num_layers

    graph = PipelineModulesGraph()
    graph.add_sequence(remote_modules, [0])
    return DistributedPipeline(graph, **kwargs)
def multi_input_multi_output_layers(devices):
    """Exercise a fan-out/fan-in module graph and check training reduces loss.

    Topology:
                              / -> linear_layer_2_1 \
        input -> linear_layer_1 -> split             -> concatenate
                              \ -> linear_layer_2_2 /

    Args:
        devices: remote-device specs of the form "<worker>/<device>"; the first
            two entries are used to place the partitions.
    """
    # Each entry looks like "<worker>/<device>"; keep only the device part
    # for placing the local input tensor.
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    linear_layer_1 = RemoteModule(devices[0], nn.Linear, (4, 4), {})
    split = RemoteModule(devices[0], SplitTensors, (), {})
    linear_layers_2 = [
        RemoteModule(devices[0], nn.Linear, (2, 2), {}),
        RemoteModule(devices[1], nn.Linear, (2, 2), {}),
    ]
    # Pass an explicit empty kwargs dict for consistency with the other
    # RemoteModule constructions above.
    concatenate = RemoteModule(devices[1], ConcatenateTensors, (), {})

    graph = PipelineModulesGraph()
    graph.add_sequence([linear_layer_1, split])
    graph.set_model_input(linear_layer_1)
    graph.fan_out(split, linear_layers_2)
    graph.add_multi_input_layer(concatenate, linear_layers_2)

    pipe = DistributedPipeline(graph, chunks=4)
    assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)

    # Fetch the parameter RRefs once and reuse them; the original fetched them
    # twice (the first result was assigned to an unused local), doubling the
    # RPC round-trips.
    params = pipe.parameter_rrefs()
    opt = DistributedOptimizer(torch.optim.SGD, params, lr=0.05)

    losses = []
    for _ in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)

    losses = [l.to_here() for l in losses]
    # Two SGD steps on the same input should strictly reduce the loss.
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"