def grad_server_helper(
    model_creators: List[Callable], optimizer: Any = Adam, learning_rate: float = 1e-3
):
    """
    Helper function for creating a tuple of grad servers, used by A3C, IMPALA, etc.

    This function requires all processes in the world to enter.

    Warning:
        You should never run this function twice!

    Args:
        model_creators: A list of model creator functions, each one corresponds
            to one gradient reduction server.
        optimizer: Optimizer type, default is Adam.
        learning_rate: Learning rate of the optimizer.

    Returns:
        A tuple of accessors to gradient servers, the tuple has the same size
        as ``model_creators``.
    """
    # Note:
    # Passing a list of creator functions instead of a list of models directly
    # removes the unnecessary model creation cost on processes that are not
    # the primary reducer.
    DEFAULT_GROUP_NAME = "server_group"

    # create groups first
    world = get_world()
    server_group = world.create_rpc_group(DEFAULT_GROUP_NAME, world.get_members())

    # create servers
    primary_reducer = world.get_members()[0]
    servers = [
        PushPullGradServerImpl(
            "grad_server_" + str(i), server_group, primary_reducer=primary_reducer
        )
        for i in range(len(model_creators))
    ]
    if get_cur_name() == primary_reducer:
        for model_creator, server in zip(model_creators, servers):
            model = model_creator()
            server.manage_model(
                model, optimizer(model.parameters(), lr=learning_rate)
            )
            server.start()

    server_group.barrier()
    servers = tuple(
        server_group.get_paired("grad_server_" + str(i)).to_here()
        for i in range(len(model_creators))
    )

    # accessors instead of the actual implementation instance will be
    # returned because of __reduce__
    return servers
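# Usage sketch for ``grad_server_helper`` (an illustration, not part of the
# library API; the ``nn.Linear`` model stands in for a real policy/value
# network). Every process in an already-initialized world calls the helper
# together and receives the same accessor tuple.
def _example_grad_server_usage():
    import torch.nn as nn

    # All world members must enter this call; only the first member (the
    # primary reducer) instantiates the model and its Adam optimizer.
    (policy_server,) = grad_server_helper([lambda: nn.Linear(4, 2)], learning_rate=1e-3)
    # Workers later push local gradients to this accessor and pull the reduced
    # parameters back from it.
    return policy_server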
def init_from_config(cls, config: Union[Dict[str, Any], Config]):
    world = get_world()
    f_config = deepcopy(config["frame_config"])
    apex_group = world.create_rpc_group(
        group_name=f_config["apex_group_name"],
        members=(
            world.get_members()
            if f_config["apex_members"] == "all"
            else f_config["apex_members"]
        ),
    )

    models = assert_and_get_valid_models(f_config["models"])
    model_args = f_config["model_args"]
    model_kwargs = f_config["model_kwargs"]
    models = [
        m(*arg, **kwarg) for m, arg, kwarg in zip(models, model_args, model_kwargs)
    ]

    # wrap models in DistributedDataParallel when running in learner mode
    max_learner_id = f_config["learner_process_number"]
    learner_group = world.create_collective_group(ranks=list(range(max_learner_id)))
    if world.rank < max_learner_id:
        models = [
            DistributedDataParallel(module=m, process_group=learner_group.group)
            for m in models
        ]

    optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
    criterion = assert_and_get_valid_criterion(f_config["criterion"])(
        *f_config["criterion_args"], **f_config["criterion_kwargs"]
    )
    criterion.reduction = "none"
    lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
        f_config["lr_scheduler"]
    )
    servers = model_server_helper(
        model_num=1,
        group_name=f_config["model_server_group_name"],
        members=f_config["model_server_members"],
    )
    del f_config["optimizer"]
    del f_config["criterion"]
    del f_config["lr_scheduler"]
    frame = cls(
        *models,
        optimizer,
        criterion,
        apex_group,
        servers,
        lr_scheduler=lr_scheduler,
        **f_config,
    )
    if world.rank >= max_learner_id:
        frame.update = lambda *_, **__: (None, None)
    return frame
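# A hedged sketch of the ``frame_config`` dictionary read by the classmethod
# above. The key names come directly from the lookups in the code; the values
# are illustrative placeholders, not a definitive schema.
_example_apex_frame_config = {
    "apex_group_name": "apex",
    "apex_members": "all",
    "models": ["QNet", "QNet"],          # resolved by assert_and_get_valid_models
    "model_args": ((), ()),
    "model_kwargs": ({}, {}),
    "learner_process_number": 1,         # ranks below this are wrapped in DDP
    "optimizer": "Adam",
    "criterion": "MSELoss",
    "criterion_args": (),
    "criterion_kwargs": {},
    "lr_scheduler": None,
    "model_server_group_name": "apex_model_server",
    "model_server_members": "all",
}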
def training_step(self, batch, _batch_idx):
    world_inited = get_world() is not None
    model_inited = isinstance(self.nn_model, NNModule)

    if world_inited and get_cur_rank() == 0:
        with open(os.environ["TEST_SAVE_PATH"], "wb") as f:
            pickle.dump([model_inited], f)
    if not world_inited:
        raise RuntimeError("World not initialized.")
    return None
def init_frame(self):
    # Called by overloaded pytorch lightning DDP plugins
    if self.frame is None:
        # initialize framework
        self.frame = init_algorithm_from_config(self.config, model_device=self.device)
        self._batch_num = self.config["batch_num"].get(self.frame.role, 1)

        # forward models to the launcher module, so that parameters are handled.
        for name, model in zip(
            self.frame.get_top_model_names(), self.frame.top_models
        ):
            self.add_module(name, model)

        # create group for custom synchronization with a large timeout,
        # otherwise the default barrier in pytorch_lightning after the
        # training step will throw an exception.
        world = get_world()
        self.group = world.create_collective_group(world.get_ranks(), timeout=86400)
def model_server_helper(
    model_num: int,
    group_name: str = "model_server",
    members: Union[str, List[str]] = "all",
):
    """
    Helper function for creating a tuple of model servers, used by APEX, etc.

    This function requires all processes in the world to enter.

    Args:
        model_num: The number of models, corresponds to the number of model
            servers, since each server manages 1 model.
        group_name: Name of the RPC group where model servers should be
            registered on, the group name should be unique.
        members: Name of the involved RPC processes, ``"all"`` for all
            processes. Only the first process will serve as the server in the
            current implementation.

    Returns:
        A tuple of accessors to model servers, the size of tuple is
        ``model_num``.
    """
    # create groups first
    world = get_world()
    members = world.get_members() if members == "all" else members
    server_group = world.create_rpc_group(group_name, members)

    # create servers
    # In the current implementation, only one process will initialize the server
    if get_cur_name() == members[0]:
        for i in range(model_num):
            _server = PushPullModelServerImpl("model_server_" + str(i), server_group)

    server_group.barrier()
    servers = tuple(
        server_group.get_paired("model_server_" + str(i)).to_here()
        for i in range(model_num)
    )

    # accessors instead of the actual implementation instance will be
    # returned because of __reduce__
    return servers
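# Usage sketch for ``model_server_helper`` (an illustration, not part of the
# library API). Every listed member must enter the call; only the first member
# hosts the server, the rest receive accessors through the RPC group.
def _example_model_server_usage():
    # One server per model; here a single server shared by all world members.
    (model_server,) = model_server_helper(model_num=1, members="all")
    # The returned accessor is later used by learner processes to publish the
    # newest parameters and by sampler processes to fetch them (the push/pull
    # pattern implied by ``PushPullModelServerImpl``).
    return model_server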
def init_from_config(cls, config: Union[Dict[str, Any], Config]):
    world = get_world()
    f_config = copy.deepcopy(config["frame_config"])
    ars_group = world.create_rpc_group(
        group_name=f_config["ars_group_name"],
        members=(
            world.get_members()
            if f_config["ars_members"] == "all"
            else f_config["ars_members"]
        ),
    )

    models = assert_and_get_valid_models(f_config["models"])
    model_args = f_config["model_args"]
    model_kwargs = f_config["model_kwargs"]
    models = [
        m(*arg, **kwarg) for m, arg, kwarg in zip(models, model_args, model_kwargs)
    ]

    optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
    lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
        f_config["lr_scheduler"]
    )
    servers = model_server_helper(
        model_num=1,
        group_name=f_config["model_server_group_name"],
        members=f_config["model_server_members"],
    )
    del f_config["optimizer"]
    del f_config["lr_scheduler"]
    frame = cls(
        *models,
        optimizer,
        ars_group,
        servers,
        lr_scheduler=lr_scheduler,
        **f_config,
    )
    return frame
def model_server_helper(model_num):
    """
    Helper function for creating a tuple of model servers, used by APEX, etc.

    This function requires all processes in the world to enter.

    Warning:
        You should never run this function twice!

    Args:
        model_num: The number of models, corresponds to the number of model
            servers, since each server manages 1 model.

    Returns:
        A tuple of accessors to model servers, the size of tuple is
        ``model_num``.
    """
    DEFAULT_GROUP_NAME = "server_group"

    # create groups first
    world = get_world()
    server_group = world.create_rpc_group(DEFAULT_GROUP_NAME, world.get_members())

    # create servers
    # In the current implementation, only one process will initialize the server
    if get_cur_name() == world.get_members()[0]:
        for i in range(model_num):
            _server = PushPullModelServerImpl("model_server_" + str(i), server_group)

    server_group.barrier()
    servers = tuple(
        server_group.get_paired("model_server_" + str(i)).to_here()
        for i in range(model_num)
    )

    # accessors instead of the actual implementation instance will be
    # returned because of __reduce__
    return servers
def init_from_config(
    cls,
    config: Union[Dict[str, Any], Config],
    model_device: Union[str, t.device] = "cpu",
):
    world = get_world()
    f_config = deepcopy(config["frame_config"])
    impala_group = world.create_rpc_group(
        group_name=f_config["impala_group_name"],
        members=(
            world.get_members()
            if f_config["impala_members"] == "all"
            else f_config["impala_members"]
        ),
    )

    models = assert_and_get_valid_models(f_config["models"])
    model_args = f_config["model_args"]
    model_kwargs = f_config["model_kwargs"]
    models = [
        m(*arg, **kwarg).to(model_device)
        for m, arg, kwarg in zip(models, model_args, model_kwargs)
    ]

    # wrap models in DistributedDataParallel when running in learner mode
    max_learner_id = f_config["learner_process_number"]
    learner_group = world.create_collective_group(ranks=list(range(max_learner_id)))
    if world.rank < max_learner_id:
        models = [
            DistributedDataParallel(module=m, process_group=learner_group.group)
            for m in models
        ]

    optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
    criterion = assert_and_get_valid_criterion(f_config["criterion"])(
        *f_config["criterion_args"], **f_config["criterion_kwargs"]
    )
    lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
        f_config["lr_scheduler"]
    )
    servers = model_server_helper(
        model_num=1,
        group_name=f_config["model_server_group_name"],
        members=f_config["model_server_members"],
    )
    del f_config["optimizer"]
    del f_config["criterion"]
    del f_config["lr_scheduler"]
    frame = cls(
        *models,
        optimizer,
        criterion,
        impala_group,
        servers,
        lr_scheduler=lr_scheduler,
        **f_config,
    )
    if world.rank >= max_learner_id:
        frame.role = "sampler"
        frame.update = _disable_update
    else:
        frame.role = "learner"
    return frame
def grad_server_helper(
    model_creators: List[Callable],
    group_name: str = "grad_server",
    members: Union[str, List[str]] = "all",
    optimizer: Any = Adam,
    learning_rate: Union[float, List[float]] = 1e-3,
    optimizer_kwargs: List[Dict[str, Any]] = None,
    lr_scheduler: Any = None,
    lr_scheduler_args: List[Tuple] = None,
    lr_scheduler_kwargs: List[Dict[str, Any]] = None,
):
    """
    Helper function for creating a tuple of grad servers, used by A3C, IMPALA, etc.

    This function requires all processes in the world to enter.

    Args:
        model_creators: A list of model creator functions, each one corresponds
            to one gradient reduction server.
        group_name: Name of the RPC group where gradient servers should be
            registered on, the group name should be unique.
        members: Name of the involved RPC processes, ``"all"`` for all
            processes. They will be used as secondary reducers; the first
            process will be the primary reducer.
        optimizer: Optimizer class, default is Adam.
        learning_rate: Learning rate of each optimizer, or a single float
            value shared by all of them.
        optimizer_kwargs: Optimizer keyword arguments for each optimizer of
            each model.
        lr_scheduler: Learning rate scheduler class.
        lr_scheduler_args: Learning rate scheduler arguments for each
            lr_scheduler corresponding to each optimizer.
        lr_scheduler_kwargs: Learning rate scheduler keyword arguments for
            each lr_scheduler corresponding to each optimizer.

    Returns:
        A tuple of accessors to gradient servers, the tuple has the same size
        as ``model_creators``.
    """
    # Note:
    # Passing a list of creator functions instead of a list of models directly
    # removes the unnecessary model creation cost on processes that are not
    # the primary reducer.

    # create groups first
    world = get_world()
    members = world.get_members() if members == "all" else members
    server_group = world.create_rpc_group(group_name, members)

    if isinstance(learning_rate, float):
        learning_rate = [learning_rate] * len(model_creators)
    optimizer_kwargs = optimizer_kwargs or [{}] * len(model_creators)
    lr_scheduler_args = lr_scheduler_args or [()] * len(model_creators)
    lr_scheduler_kwargs = lr_scheduler_kwargs or [{}] * len(model_creators)

    # create servers
    primary_reducer = members[0]
    servers = [
        PushPullGradServerImpl(
            "grad_server_" + str(i), server_group, primary_reducer=primary_reducer
        )
        for i in range(len(model_creators))
    ]
    if get_cur_name() == primary_reducer:
        for (
            model_creator,
            server,
            optim_kwargs,
            lr,
            lr_sch_args,
            lr_sch_kwargs,
        ) in zip(
            model_creators,
            servers,
            optimizer_kwargs,
            learning_rate,
            lr_scheduler_args,
            lr_scheduler_kwargs,
        ):
            model = model_creator()
            if lr_scheduler is None:
                server.manage_model(
                    model, optimizer(model.parameters(), lr=lr, **optim_kwargs)
                )
            else:
                # keep the optimizer class intact so it can be reused for the
                # next model instead of being shadowed by the created instance
                optimizer_instance = optimizer(
                    model.parameters(), lr=lr, **optim_kwargs
                )
                server.manage_model(
                    model,
                    optimizer_instance,
                    lr_scheduler(optimizer_instance, *lr_sch_args, **lr_sch_kwargs),
                )
            server.start()

    server_group.barrier()
    servers = tuple(
        server_group.get_paired("grad_server_" + str(i)).to_here()
        for i in range(len(model_creators))
    )

    # accessors instead of the actual implementation instance will be
    # returned because of __reduce__
    return servers
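# Usage sketch for the configurable ``grad_server_helper`` variant above
# (illustrative only: the tiny linear models and the StepLR choice are
# assumptions, not defaults of this module).
def _example_grad_server_with_scheduler():
    import torch.nn as nn
    from torch.optim.lr_scheduler import StepLR

    # Two model creators -> two gradient reduction servers, each with its own
    # learning rate and its own StepLR scheduler; only the primary reducer
    # actually builds the models, optimizers, and schedulers.
    actor_server, critic_server = grad_server_helper(
        [lambda: nn.Linear(4, 2), lambda: nn.Linear(4, 1)],
        learning_rate=[1e-3, 1e-4],
        lr_scheduler=StepLR,
        lr_scheduler_kwargs=[{"step_size": 1000}, {"step_size": 1000}],
    )
    return actor_server, critic_server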