def initialize_gradients(self):
    """Ask every replica to set up its gradient buffers, then record that
    initialization has happened."""
    requests = [
        self.call_async(rank, '_async_initialize_gradients')
        for rank in range(self.num_replicas)
    ]
    Future.gen_tuple_list(requests)
    self._grads_initialized = True
def update_parameters(self, grad_denom=1):
    """When we update parameters, all replicas update at the same time.

    Args:
        grad_denom: denominator used by each replica to scale its
            gradients before the optimizer step.
    """
    # Bug fix: the return value of check_global_overflow() was computed
    # and then discarded, with is_global_overflow hard-coded to False.
    # Forward the actual result instead. Today the check always returns
    # False, so runtime behavior is unchanged until it is implemented.
    is_global_overflow = self.check_global_overflow()
    Future.gen_tuple_list([
        self.call_async(rank, '_async_update',
                        grad_denom=grad_denom,
                        is_global_overflow=is_global_overflow)
        for rank in range(self.num_replicas)
    ])
    # NOTE(review): a later definition of update_parameters in this file
    # (without the overflow check) shadows this one at class-creation
    # time — confirm which definition is intended and remove the other.
def _scatter_samples(self, batches, replace_empty_samples=False):
    """Split and distribute a sample across GPUs."""
    if replace_empty_samples:
        # fill every replica slot by cycling through the given samples
        batches = list(islice(cycle(batches), self.num_replicas))
    else:
        # pad with None until there is exactly one entry per replica
        shortfall = self.num_replicas - len(batches)
        batches = batches + [None] * shortfall
    assert len(batches) == self.num_replicas
    Future.gen_list([
        self.call_async(rank, '_async_prepare_batch', batch=batches[rank])
        for rank in range(self.num_replicas)
    ])
def load_checkpoint(self, filename):
    """Load a checkpoint into the model replicas in each process."""
    futures = [
        self.call_async(rank, '_async_load_checkpoint', filename=filename)
        for rank in range(self.num_replicas)
    ]
    # every replica loads the same file; report the first reply
    first_result, *_ = Future.gen_list(futures)
    return first_result
def check_global_overflow(self): local_over_flows = Future.gen_tuple_list([ self.call_async(rank, '_async_local_overflow') for rank in range(self.num_replicas) ]) # global_flows = sum(local_over_flows) return False
def load_optim_state_dict(self, optim_state_dict):
    """Load an optimizer state dict into every model replica."""
    futures = [
        self.call_async(rank, '_async_load_optim_state_dict',
                        optim_state_dict=optim_state_dict)
        for rank in range(self.num_replicas)
    ]
    # all replicas receive the same state; surface the first reply
    return Future.gen_list(futures)[0]
def __init__(self, opt, model, loss_function, device_ids=None, multiprocessing_method='spawn'):
    """Set up one training process per CUDA device.

    Args:
        opt: training options object; must provide a ``seed`` attribute.
        model: the model to replicate; moved to shared memory so every
            worker process sees the same parameter storage.
        loss_function: criterion shared with all replicas.
        device_ids: CUDA device indices to use; defaults to all visible
            devices.
        multiprocessing_method: start method for worker processes
            ('spawn' by default).

    Raises:
        NotImplementedError: if CUDA is unavailable (CPU training is
            not supported).
    """
    if device_ids is None:
        # default to every visible CUDA device
        device_ids = tuple(range(torch.cuda.device_count()))
    super().__init__(device_ids, multiprocessing_method)
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    print("Initializing multi-gpu training with %d devices" % self.num_replicas)
    # share parameter storage across processes before workers start
    model = model.share_memory()
    # single NCCL unique id so all workers join the same communicator
    nccl_uid = nccl.get_unique_id()
    self.loss_function = loss_function
    Future.gen_list([
        self.call_async(rank, '_async_init', args=opt, model=model,
                        loss_function=loss_function, nccl_uid=nccl_uid)
        for rank in range(self.num_replicas)
    ])
    self._grads_initialized = False
    # gradient buffers must exist before the first update; seed last so
    # workers are fully initialized when they receive it
    self.initialize_gradients()
    self.set_seed(opt.seed)
def step(self, samples, eval=False):
    """Distribute *samples* across replicas, run a forward pass on each,
    and return the aggregated logging output (including 'loss' and the
    number of out-of-memory events under 'oom')."""
    self._scatter_samples(samples, replace_empty_samples=False)
    # kick off the asynchronous forward pass on every replica
    losses, logging_outputs, ooms = Future.gen_tuple_list([
        self.call_async(rank, '_async_forward', eval=eval)
        for rank in range(self.num_replicas)
    ])
    logging_output = aggregate_logging_outputs(logging_outputs)
    logging_output['oom'] = sum(ooms)
    logging_output['loss'] = aggregate_loss(losses)
    return logging_output
def zero_grad(self):
    """Clear accumulated gradients on every replica."""
    pending = [self.call_async(rank, '_async_zero_grad')
               for rank in range(self.num_replicas)]
    Future.gen_tuple_list(pending)
def update_parameters(self, grad_denom=1):
    """Apply the optimizer step on all replicas simultaneously.

    NOTE(review): this redefines update_parameters; appearing later in
    the file, it shadows the earlier variant that runs
    check_global_overflow() and passes is_global_overflow to
    _async_update. Confirm which definition is intended and delete the
    other.
    """
    Future.gen_tuple_list([
        self.call_async(rank, '_async_update', grad_denom=grad_denom)
        for rank in range(self.num_replicas)
    ])
def set_seed(self, seed):
    """Propagate *seed* to the RNGs in every worker process."""
    Future.gen_list(
        [self.call_async(rank, '_async_set_seed', seed=seed)
         for rank in range(self.num_replicas)]
    )