def test_local_rref_creation_with_ivalue(self):
    # Create a local RRef that holds an IValue.
    rref_local_script_class = rpc.RRef(MyScriptClass())
    self.assertEqual(rref_local_script_class.to_here().a, 10)

    # Create a local RRef that holds a ScriptModule.
    rref_local_script_mod = rpc.RRef(MyScriptModule(3)._c)
    self.assertEqual(rref_local_script_mod.to_here().forward(), torch.ones(3))
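The test above boils down to a round trip through a local RRef. A minimal, self-contained sketch of that round trip, assuming a single-process group (world_size=1) and that MASTER_ADDR/MASTER_PORT are set in the environment; the worker name "worker0" is an arbitrary choice:

import torch
import torch.distributed.rpc as rpc

if __name__ == "__main__":
    rpc.init_rpc("worker0", rank=0, world_size=1)

    rref = rpc.RRef(torch.ones(3))                     # wrap a local value
    assert rref.is_owner()                             # this process owns it
    assert torch.equal(rref.to_here(), torch.ones(3))  # fetch the value back

    rpc.shutdown()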
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
    with _ScriptLocalOptimizer.compile_lock:
        script_optim = jit.script(optim)
    return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)
def test_create_local_script_module_rref_in_py(self):
    if self.rank != 0:
        return

    # Create a local RRef<MyModuleInterface>.
    rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
    ret = rref_script_module.to_here().forward()
    self.assertEqual(ret, torch.ones(self.rank))

    # Creating a local RRef<MyModuleInterface> without the type hint must fail.
    with self.assertRaisesRegex(
        RuntimeError,
        (
            "The RRef being created contains a ScriptModule, "
            "must provide its ModuleInterface type hint."
        ),
    ):
        rref_script_module = rpc.RRef(MyScriptModule(self.rank))
def auto_graph_extract(devices):
    from fairscale.experimental.nn.distributed_pipeline.trace import make_graph

    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    # create model
    model = nn.Sequential(
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
        ShardedLinearLayer(devices[0], devices, devices[1]),
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
    )
    graph = make_graph(model)
    pipe = DistributedPipeline(graph, chunks=4)
    partitions = extract_partitions(graph, pipe)
    assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 8
    opt = DistributedOptimizer(
        torch.optim.SGD,
        parameter_rrefs,
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
def update(devices):
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
    params = pipe.parameter_rrefs()
    opt = DistributedOptimizer(
        torch.optim.SGD,
        params,
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
def _create_module_with_interface(
    module_cls, args, kwargs, device, module_interface_cls
):
    module = _create_module(module_cls, args, kwargs, device)
    if module_interface_cls is not None:
        module = torch.jit.script(module)
    return rpc.RRef(module, module_interface_cls)
def update(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [
        RemoteModuleParams(nn.Linear, (4, 4), {}),
        RemoteModuleParams(nn.ReLU, (), {}),
    ]
    pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2])
    opt = DistributedOptimizer(
        torch.optim.SGD,
        pipe.parameter_rrefs(),
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
def forward(self, x: Tensor) -> rpc.RRef:  # type: ignore
    outputs = []
    for chunk in x.chunk(self.chunks):
        output = rpc.RRef(chunk)
        for rlayer in self.rmodule:
            output = rlayer.remote().forward(output)
        outputs.append(output)
    return rpc.remote(outputs[0].owner(), _rcat, args=(outputs,))
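`_rcat` is referenced above but not shown. A hypothetical sketch of such a helper, assuming it simply fetches every chunk RRef and concatenates on the owner of the first one (the real implementation may differ):

from typing import List

import torch
import torch.distributed.rpc as rpc

def _rcat_sketch(rrefs: List[rpc.RRef]) -> torch.Tensor:
    # Pull each chunk's value to this worker and stitch them back together.
    return torch.cat([rref.to_here() for rref in rrefs])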
def test_create_local_script_class_rref_in_py(self):
    if self.rank != 0:
        return

    # Create a local RRef<MyScriptClass>.
    rref_script_class = rpc.RRef(MyScriptClass(self.rank))
    ret = rref_script_class.to_here().get_value()
    self.assertEqual(ret, self.rank)
def _create_module(module_cls, args, kwargs, module_interface_cls=None):
    module = module_cls(*args, **kwargs)
    if not isinstance(module, nn.Module):
        raise ValueError(
            "Expected `module_cls(*args, **kwargs)` to return an instance of "
            f"<class nn.Module>, but it returned an instance of {type(module)}."
        )
    if module_interface_cls is not None:
        module = torch.jit.script(module)
    return rpc.RRef(module, module_interface_cls)
def run_ps(trainers):
    timed_log("Start training")
    ps_rref = rpc.RRef(BatchUpdateParameterServer())
    futs = []
    for trainer in trainers:
        futs.append(rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)))
    torch.futures.wait_all(futs)
    timed_log("Finish training")
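`run_trainer` itself is not shown here. A hedged sketch of what the trainer side of this pattern typically looks like: it receives the parameter-server RRef and calls back into the owner through the RRef's rpc_sync proxy. The method names `get_model` and `update_and_fetch_model` are assumptions, not the real BatchUpdateParameterServer API:

import torch
import torch.distributed.rpc as rpc

def run_trainer_sketch(ps_rref: rpc.RRef):
    # Fetch the current model from the parameter server (hypothetical method).
    model = ps_rref.rpc_sync().get_model()
    for _ in range(3):  # a few dummy steps on random data
        loss = model(torch.randn(4, 4)).sum()
        loss.backward()
        grads = [p.grad for p in model.parameters()]
        # Ship local gradients back and receive the updated model.
        model = ps_rref.rpc_sync().update_and_fetch_model(grads)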
def _init_rpc(self):
    self._rpc_initialized = True
    self._remote_shards = {}

    # Gather all the sharded tensor ids.
    world_size = dist.get_world_size(self._process_group)
    worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
    rank_to_name = {}
    name_to_rank = {}
    for worker_info in worker_infos:
        rank_to_name[worker_info.id] = worker_info.name
        name_to_rank[worker_info.name] = worker_info.id

    rpc_workers = set()
    for rank in range(world_size):
        if self._process_group == distributed_c10d._get_default_group():
            global_rank = rank
        else:
            global_rank = distributed_c10d._get_global_rank(
                self._process_group, rank)
        rpc_workers.add(rank_to_name[global_rank])

    all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id, rpc_workers)

    # Share the local shards with the entire world.
    futs = []
    rpc_rank = rpc.get_worker_info().id
    for rank in range(world_size):
        # Skip self.
        if rank == dist.get_rank(self._process_group):
            continue
        if self._process_group == distributed_c10d._get_default_group():
            global_rank = rank
        else:
            global_rank = distributed_c10d._get_global_rank(
                self._process_group, rank)
        if len(self.local_shards()) != 0:
            rrefs: List[rpc.RRef[Shard]] = [
                rpc.RRef(shard) for shard in self.local_shards()
            ]
            fut = rpc.rpc_async(
                global_rank,
                _register_remote_shards,
                args=(all_tensor_ids[rank_to_name[global_rank]], rrefs, rpc_rank),
            )
            futs.append(fut)

    torch.futures.wait_all(futs)

    # Barrier for all RPCs to finish on all ranks.
    rpc.api._barrier(rpc_workers)
def __init__(self, client_id_triple, num_epochs=3, config=None):
    log_rref = rpc.RRef(FLLogger())
    self.log_rref = log_rref
    self.num_epoch = num_epochs
    self.config = config
    self.tb_path = config.output_location
    self.ensure_path_exists(self.tb_path)
    self.tb_writer = SummaryWriter(
        f'{self.tb_path}/{config.experiment_prefix}_federator')
    self.create_clients(client_id_triple)
    self.config.init_logger(logging)
def run_ps(trainers):
    timed_log("Start training")
    start = perf_counter()
    ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
    futs = []
    for trainer in trainers:
        futs.append(rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)))
    torch.futures.wait_all(futs)
    stop = perf_counter()
    timed_log("Finish training")
    timed_log(f"Time spent training: {stop - start}s")
def forward(self, x: Tensor) -> rpc.RRef:  # type: ignore
    outputs = []
    for i, chunk in enumerate(x.chunk(self.chunks)):
        output = rpc.RRef(chunk)
        if i < self.checkpoint_stop:
            # Run the remote layers under activation checkpointing.
            for rlayer in self.rmodule:
                output = rpc.remote(rlayer.owner(), _rcheckpoint,
                                    args=(rlayer, output))
        else:
            for rlayer in self.rmodule:
                output = rlayer.remote().forward(output)
        outputs.append(output)
    return rpc.remote(outputs[0].owner(), _rcat, args=(outputs,))
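`_rcheckpoint` is referenced above but not shown. A hypothetical sketch of such a helper, assuming it wraps the remote layer's forward in torch.utils.checkpoint on the layer's owner (the real fairscale implementation may differ):

import torch.distributed.rpc as rpc
from torch.utils.checkpoint import checkpoint

def _rcheckpoint_sketch(rlayer_rref: rpc.RRef, input_rref: rpc.RRef) -> rpc.RRef:
    # Executed on the layer's owner (via rpc.remote above), so local_value()
    # is valid; the checkpointed output is wrapped in a fresh local RRef.
    layer = rlayer_rref.local_value()
    x = input_rref.to_here()
    return rpc.RRef(checkpoint(layer, x))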
def backward(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [
        RemoteModuleParams(nn.Linear, (4, 4), {}),
        RemoteModuleParams(nn.ReLU, (), {}),
    ]
    pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2])
    with dist_autograd.context() as context_id:
        y = pipe(x)
        loss = criterion(y, rpc.RRef(x))
        loss.backward(context_id)
        grads = dist_autograd.get_gradients(context_id)
    assert len(grads) == 2
def run(rank, num_workers, data_dir, model, batch_size, test_batch_size,
        lr, num_epochs, job_name, target_loss):
    logging.basicConfig(level=logging.INFO)
    world_size = num_workers + 2
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=0)

    if rank == 0:
        logging.info(f"PS{rank} initializing")
        rpc.init_rpc(f"PS{rank}", rank=rank, world_size=world_size,
                     rpc_backend_options=options)
        logging.info(f"PS{rank} initialized")

        workers = [f"worker{r}" for r in range(1, world_size - 1)]
        ps_rref = rpc.RRef(ParameterServer(model, num_workers, lr, job_name))

        futs = []
        futs.append(
            rpc.rpc_async(to="tester", func=get_accuracy,
                          args=(ps_rref, data_dir, test_batch_size, job_name,
                                target_loss)))
        for worker in workers:
            futs.append(
                rpc.rpc_async(to=worker, func=run_worker,
                              args=(ps_rref, data_dir, batch_size, num_epochs,
                                    worker, job_name)))
        torch.futures.wait_all(futs)
        logging.info("Finish training")
    elif rank == world_size - 1:
        logging.info("Tester initializing")
        rpc.init_rpc("tester", rank=rank, world_size=world_size,
                     rpc_backend_options=options)
        logging.info("Tester initialized")
    else:
        logging.info(f"Worker{rank} initializing")
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size,
                     rpc_backend_options=options)
        logging.info(f"Worker{rank} initialized")

    rpc.shutdown()
def backward(devices):
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
    with dist_autograd.context() as context_id:
        y = pipe(x)
        loss = criterion(y, rpc.RRef(x))
        loss.backward(context_id)
        grads = dist_autograd.get_gradients(context_id)
    assert len(grads) == 2
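Note that `loss.backward(context_id)` above goes through fairscale's DistributedLoss wrapper. With stock torch.distributed.autograd the equivalent step looks like the following sketch, assuming `model` is an ordinary local nn.Module and RPC is already initialized:

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd

def local_backward_sketch(model: nn.Module, x: torch.Tensor):
    with dist_autograd.context() as context_id:
        loss = nn.MSELoss()(model(x), x)
        # Stock API: gradients are recorded in the autograd context,
        # not accumulated into param.grad.
        dist_autograd.backward(context_id, [loss])
        return dist_autograd.get_gradients(context_id)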
def _init_rpc(self):
    # Validate PG and RPC ranks match.
    pg_rank = dist.get_rank()
    rpc_rank = rpc.get_worker_info().id
    if pg_rank != rpc_rank:
        raise ValueError(
            'Default ProcessGroup and RPC ranks must be '
            'the same for ShardedTensor, found process group rank: '
            f'{pg_rank} and RPC rank: {rpc_rank}'
        )

    self._remote_shards = {}

    # Gather all the sharded tensor ids.
    world_size = dist.get_world_size(self._process_group)
    worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
    rank_to_name = {}
    name_to_rank = {}
    for worker_info in worker_infos:
        rank_to_name[worker_info.id] = worker_info.name
        name_to_rank[worker_info.name] = worker_info.id

    all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id)

    # Share the local shards with the entire world.
    futs = []
    for rank in range(dist.get_world_size()):
        # Skip self.
        if rank == dist.get_rank():
            continue
        if len(self.local_shards()) != 0:
            rrefs: List[rpc.RRef[Shard]] = [
                rpc.RRef(shard) for shard in self.local_shards()
            ]
            fut = rpc.rpc_async(
                rank,
                _register_remote_shards,
                args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank),
            )
            futs.append(fut)

    torch.futures.wait_all(futs)

    # Barrier for all RPCs to finish on all ranks.
    rpc.api._all_gather(None)
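`_register_remote_shards` is invoked above but not shown. A hypothetical sketch of what such a callee could do: resolve the tensor id to the local ShardedTensor and record which sender rank owns which shard RRefs. The `_registry` dict here is an assumed stand-in for whatever lookup the real implementation uses:

from typing import Dict, List

import torch.distributed.rpc as rpc

_registry: Dict[int, object] = {}  # assumed: sharded-tensor-id -> local ShardedTensor

def _register_remote_shards_sketch(sharded_tensor_id: int,
                                   rrefs: List[rpc.RRef],
                                   rpc_rank: int) -> None:
    # Store the incoming shard RRefs on the local tensor, keyed by sender rank.
    sharded_tensor = _registry[sharded_tensor_id]
    sharded_tensor._remote_shards[rpc_rank] = rrefs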
def train(self, num_epochs=10):
    for ep in range(num_epochs):
        itbar = tqdm(self.dataloader, desc='iter')
        for X, Y in itbar:
            self.optim.zero_grad()
            # TODO chunk X and distribute to workers
            tile_len = X.shape[-1] // num_workers
            # this is a pretty inefficient way to do things
            chunks = [
                rpc.RRef(X[..., tile_len * i : tile_len * (i + 1)])
                for i in range(world_size)
            ]
            pred = self.model(chunks)
            loss = self.criterion(pred, Y)
            loss.backward()
            self.optim.step()
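The model that receives these chunk RRefs is not shown. A hedged sketch of how a worker-side shard might consume one of them (`shard_forward_sketch` and its `layer` argument are assumptions, not names from the source):

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc

def shard_forward_sketch(chunk_rref: rpc.RRef, layer: nn.Module) -> torch.Tensor:
    # Fetch this worker's slice of the input and run the local layer on it.
    return layer(chunk_rref.to_here())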
def multi_input_multi_output_layers(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    #                                    /-> linear_layer_2_1 -\
    # input -> linear_layer_1 -> split -+                      +-> concatenate
    #                                    \-> linear_layer_2_2 -/
    linear_layer_1 = RemoteModule(devices[0], nn.Linear, (4, 4), {})
    split = RemoteModule(devices[0], SplitTensors, (), {})
    linear_layers_2 = [
        RemoteModule(devices[0], nn.Linear, (2, 2), {}),
        RemoteModule(devices[1], nn.Linear, (2, 2), {}),
    ]
    concatenate = RemoteModule(devices[1], ConcatenateTensors, ())

    graph = PipelineModulesGraph()
    graph.add_sequence([linear_layer_1, split], [0], 2)
    for i, l in enumerate(linear_layers_2):
        graph.add_layer(l, [(split, i)])
    graph.add_layer(concatenate, linear_layers_2)

    pipe = DistributedPipeline(graph, chunks=4)
    assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 6
    opt = DistributedOptimizer(
        torch.optim.SGD,
        parameter_rrefs,
        lr=0.05,
    )
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [l.to_here() for l in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    return [rpc.RRef(p) for p in module.parameters()]
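A hedged sketch of the typical consumer of `_parameter_rrefs`: DistributedOptimizer takes the RRef list and steps inside a distributed autograd context. The function name, the choice of SGD, and the learning rate are assumptions, and RPC is assumed to be initialized:

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer

def optimize_step_sketch(module: nn.Module, x: torch.Tensor):
    # Hand the per-parameter RRefs to DistributedOptimizer, then drive one
    # forward/backward/step cycle inside a distributed autograd context.
    opt = DistributedOptimizer(torch.optim.SGD, _parameter_rrefs(module), lr=0.05)
    with dist_autograd.context() as context_id:
        loss = module(x).sum()
        dist_autograd.backward(context_id, [loss])
        opt.step(context_id)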
def genotype(self):
    return rpc.RRef(self.model.genotype())
def owner_create_rref_my_script_module(a):
    return rpc.RRef(MyScriptModule(a), MyModuleInterface)
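Helpers like this are meant to run on the owner, with a peer obtaining the RRef over RPC. A hedged caller-side sketch; the worker name "owner" and the argument value are assumptions:

import torch.distributed.rpc as rpc

# rpc_sync forwards the RRef returned by the helper back to this caller,
# which can then fetch the module and invoke it.
module_rref = rpc.rpc_sync("owner", owner_create_rref_my_script_module, args=(3,))
result = module_rref.to_here().forward()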
def named_weights(self):
    # Note: `named_parameters()` yields (name, parameter) tuples, so each RRef
    # here holds a tuple rather than a bare parameter.
    param_rrefs = [
        rpc.RRef(param) for param in self.model.named_parameters()
    ]
    return param_rrefs
def owner_create_rref_my_script_class(a):
    return rpc.RRef(MyScriptClass(a))
def _param_rrefs(module_rref, recurse):
    ret = []
    for param in module_rref.local_value().parameters(recurse):
        ret.append(rpc.RRef(param))
    return ret
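Since `_param_rrefs` calls `local_value()`, it must execute on the module's owner. A hedged sketch of the caller side, assuming RPC is initialized (the wrapper name is an assumption):

import torch.distributed.rpc as rpc

def remote_param_rrefs_sketch(module_rref: rpc.RRef, recurse: bool = True):
    # Run _param_rrefs on the module's owner and get the parameter RRefs back.
    return rpc.rpc_sync(module_rref.owner(), _param_rrefs,
                        args=(module_rref, recurse))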
def alphas(self):
    param_rrefs = [rpc.RRef(p) for _, p in self.model._alphas]
    return param_rrefs
def get_param_rrefs(self):
    param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
    return param_rrefs
def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(
        _LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))
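A hedged sketch of how a coordinator might use `_new_local_optimizer`: construct the optimizer on the parameters' owner and receive the RRef back over RPC. The worker name, the choice of SGD, and the learning rate are assumptions:

import torch
import torch.distributed.rpc as rpc

def make_remote_local_optimizer_sketch(owner: str, params_rref) -> rpc.RRef:
    # RPC forwards the RRef returned by _new_local_optimizer to this caller,
    # which can later use it to drive optimizer steps on that worker.
    return rpc.rpc_sync(owner, _new_local_optimizer,
                        args=(torch.optim.SGD, params_rref),
                        kwargs={"lr": 0.05})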