def __init__(self, args, model, criterion, device_ids=None,
             multiprocessing_method='spawn'):
    if device_ids is None:
        device_ids = tuple(range(torch.cuda.device_count()))
    super().__init__(device_ids, multiprocessing_method)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    model = model.share_memory()
    nccl_uid = nccl.get_unique_id()
    self.criterion = criterion

    Future.gen_list([
        self.call_async(rank, '_async_init', args=args, model=model,
                        criterion=criterion, nccl_uid=nccl_uid)
        for rank in range(self.num_replicas)
    ])

    self._grads_initialized = False
def __init__(self, args, model, device_ids=None,
             multiprocessing_method='spawn', src_dict=None, dst_dict=None):
    if device_ids is None:
        device_ids = tuple(range(torch.cuda.device_count()))
    super().__init__(device_ids, multiprocessing_method)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    model = model.share_memory()
    nccl_uid = nccl.get_unique_id()

    Future.gen_list([
        self.call_async(rank, '_async_init', args=args, model=model,
                        src_dict=src_dict, dst_dict=dst_dict, nccl_uid=nccl_uid)
        for rank in range(self.num_replicas)
    ])

    self.enable_rl = args.enable_rl
    self.args = args
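# A minimal construction sketch, assuming a fairseq-style setup in which
# `args`, `model`, and `criterion` have already been built; the class name
# `MultiprocessingTrainer` is an assumption, not taken from the code above.
# Both constructor variants share the model's parameters across processes via
# share_memory() and hand every replica the same NCCL unique id so the
# replicas can later all-reduce their gradients over NCCL.
#
#     trainer = MultiprocessingTrainer(args, model, criterion)
#     trainer.set_seed(args.seed)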
def _scatter_samples(self, samples, volatile=False):
    """Split and distribute a sample across GPUs."""
    # pad with None until its size is equal to the number of replicas
    samples = samples + [None] * (self.num_replicas - len(samples))

    Future.gen_list([
        self.call_async(rank, '_async_prepare_sample',
                        sample=samples[rank], volatile=volatile)
        for rank in range(self.num_replicas)
    ])
def _scatter_samples(self, samples, volatile=False, replace_empty_samples=False):
    """Split and distribute a sample across GPUs."""
    if not replace_empty_samples:
        # pad with None until its size is equal to the number of replicas
        samples = samples + [None] * (self.num_replicas - len(samples))
    else:
        # pad by cycling through the given samples
        samples = list(islice(cycle(samples), self.num_replicas))

    Future.gen_list([
        self.call_async(rank, '_async_prepare_sample',
                        sample=samples[rank], volatile=volatile)
        for rank in range(self.num_replicas)
    ])
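# Standalone illustration of the two padding strategies used above: when a
# batch yields fewer sub-batches than there are replicas, either pad with
# None or cycle through the given samples (the replica count of 4 and the
# string samples are arbitrary placeholders).
from itertools import cycle, islice

example_samples = ['a', 'b']
example_num_replicas = 4
assert example_samples + [None] * (example_num_replicas - len(example_samples)) \
    == ['a', 'b', None, None]
assert list(islice(cycle(example_samples), example_num_replicas)) \
    == ['a', 'b', 'a', 'b']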
def lr_step(self, val_loss=None, epoch=None):
    """Adjust the learning rate depending on the validation loss."""
    lr = Future.gen_list([
        self.call_async(rank, '_async_lr_step', val_loss=val_loss, epoch=epoch)
        for rank in range(self.num_replicas)
    ])
    return lr[0]
def train_step(self, samples):
    """Do forward, backward and gradient step in parallel."""
    # PyTorch initializes gradient buffers lazily, so the first
    # train step needs to send non-empty samples to all replicas
    replace_empty_samples = False
    if not self._grads_initialized:
        replace_empty_samples = True
        self._grads_initialized = True

    # scatter sample across GPUs
    self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)

    # forward pass
    sample_sizes, logging_outputs = Future.gen_tuple_list([
        self.call_async(rank, '_async_forward')
        for rank in range(self.num_replicas)
    ])

    # backward pass, all-reduce gradients and take an optimization step
    grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
    grad_norms = Future.gen_list([
        self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
        for rank in range(self.num_replicas)
    ])

    # aggregate logging output
    logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
    logging_output['gnorm'] = grad_norms[0]  # log the gradient norm

    return logging_output
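# A minimal sketch of the gradient-denominator step in train_step: each
# replica reports the size of its sub-batch from the forward pass, and the
# criterion class collapses those sizes into a single denominator that every
# replica uses when normalizing gradients in '_async_backward_and_opt'. The
# function below is an assumed example, not the actual criterion code; real
# criteria define their own rule (e.g. total number of target tokens).
def example_grad_denom(sample_sizes):
    # ignore replicas that received an empty (None) sample
    return sum(size for size in sample_sizes if size)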
def load_checkpoint(self, filename):
    """Load a checkpoint into the model replicas in each process."""
    results = Future.gen_list([
        self.call_async(rank, '_async_load_checkpoint', filename=filename)
        for rank in range(self.num_replicas)
    ])
    extra_state = results[0]
    return extra_state
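# Hedged usage sketch for resuming training: every replica loads the same
# checkpoint file, and the extra training state is read from replica 0 since
# all replicas return identical results. The path and the 'epoch' key below
# are illustrative assumptions only.
#
#     extra_state = trainer.load_checkpoint('checkpoints/checkpoint_last.pt')
#     if extra_state is not None:
#         print('resuming from epoch', extra_state.get('epoch'))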
def valid_step(self, samples, criterion):
    """Do forward pass in parallel."""
    # scatter sample across GPUs
    self._scatter_samples(samples, volatile=True)
    criterion.prepare(samples)

    # forward pass
    losses = [
        self.call_async(rank, '_async_valid_step', criterion=criterion)
        for rank in range(self.num_replicas)
    ]

    # aggregate losses
    loss = criterion.aggregate(Future.gen_list(losses))
    return loss
def valid_step(self, samples, criterion):
    """Do forward pass in parallel."""
    # scatter sample across GPUs
    samples, data_events = self._scatter_samples(samples, volatile=True)
    criterion.prepare(samples)

    # forward pass
    losses = [
        self.call_async(rank, '_async_valid_step', sample=samples[rank],
                        criterion=criterion, data_event=event)
        for rank, event in enumerate(data_events)
    ]

    # aggregate losses
    loss = criterion.aggregate(Future.gen_list(losses))
    return loss
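# Usage sketch for validation (assumed names: `valid_samples` is a list of
# per-GPU sub-batches and `criterion` exposes the prepare()/aggregate()
# interface used above):
#
#     valid_loss = trainer.valid_step(valid_samples, criterion)
#
# Note that the second variant expects _scatter_samples to return the padded
# samples together with per-replica data events, i.e. a scatter helper that
# differs from the synchronous one shown earlier.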
def set_seed(self, seed):
    """Set the random seed on every model replica."""
    Future.gen_list([
        self.call_async(rank, '_async_set_seed', seed=seed)
        for rank in range(self.num_replicas)
    ])