def _communicate(self):
    """Perform all-reduce on flattened parameters.

    To be rewritten by all subclasses.
    """
    log.debug('Communicate')
    # Pre-scale by 1/world_size so the summed all-reduce yields an average
    if self.premultiplier is None:
        for flat_param in self.flat_parameters:
            flat_param.mul_(1 / self.world_size)
    else:
        for flat_param in self.flat_parameters:
            flat_param.mul_(self.premultiplier / self.world_size)
    reqs = self.all_reduce_tensors(self.flat_parameters)
    for req in reqs:
        req.wait()
    # Undo the premultiplier after communication
    if self.premultiplier is not None:
        for flat_param in self.flat_parameters:
            flat_param.mul_(1 / self.premultiplier)
    self.assign_unflattened_tensors(self.parameters(), self.flat_parameters)
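# For reference: a minimal standalone sketch of the all-reduce averaging
# pattern used by _communicate above. The helper name is hypothetical, and
# the suggestion that a premultiplier guards against fp16 underflow is an
# assumption, not stated by the original code.
def _allreduce_average_sketch(flat_params, world_size, premultiplier=None):
    import torch.distributed as dist
    scale = (premultiplier if premultiplier is not None else 1.0) / world_size
    for p in flat_params:
        p.mul_(scale)  # pre-scale so the summed result is a (scaled) average
    reqs = [dist.all_reduce(p, async_op=True) for p in flat_params]
    for req in reqs:
        req.wait()  # block until every reduction has finished
    if premultiplier is not None:
        for p in flat_params:
            p.mul_(1 / premultiplier)  # undo the premultiplier afterwards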
def _communicate(self):
    """Reduce flattened parameters along the communication graph ``self.G``."""
    log.debug('communicate')
    send = _check_reqs(self._send_reqs)
    if send:
        self._send_reqs = []
        flat_params_to_send = []
        for flat_param in self.flat_parameters:
            if self.async_op:
                # Send a detached copy so later local updates do not race
                # with the in-flight communication
                flat_param = flat_param.clone().detach()
            flat_param.div_(self.n_peers + 1)
            flat_params_to_send.append(flat_param)
    else:
        log.debug('Failed to send')
        # If we can't send now, try again at the next step
        self._iter_counter -= 1
        return
    for dst in range(self.world_size):
        if send and dst != self.rank:
            log.debug('Send')
            if self.G.has_predecessor(dst, self.rank):
                self._send_reqs += self.reduce_tensors(
                    flat_params_to_send, dst, self.G.process_group[dst])
        if dst == self.rank:
            # Recv
            if _check_reqs(self._recv_reqs):
                log.debug('Recv')
                self._recv_reqs = self.reduce_tensors(
                    flat_params_to_send, self.rank,
                    self.G.process_group[self.rank], self.flat_bufs)
    if not self.async_op:
        for req in self._recv_reqs + self._send_reqs:
            req.wait()
        self._recv_reqs = []
        self._send_reqs = []
        # Swap flat_bufs and flat_parameters to keep the names consistent
        self.flat_bufs, self.flat_parameters = self.flat_parameters, self.flat_bufs
        # Re-assign parameters
        self.assign_unflattened_tensors(self.parameters(),
                                        self.flat_parameters)
    log.debug('communicate done')
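# ``_check_reqs`` is used above but defined elsewhere. A minimal sketch of
# the likely behavior, assuming PyTorch's ``Work.is_completed()`` API: all
# outstanding async requests must have finished (an empty list means
# nothing is pending, so communication may proceed).
def _check_reqs_sketch(reqs):
    return all(req.is_completed() for req in reqs)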
def _forward_pre_hook(*args, **kwargs):
    if self.training:
        # Update iteration counter
        self._iter_counter += 1
        self._iter_counter %= self.sync_freq
        log.debug('_forward_pre_hook called on %s, _iter_counter %d',
                  self.device, self._iter_counter)
        # Sync gradients only on the step where the counter wraps to zero
        self.require_backward_grad_sync = (self._iter_counter == 0)
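# The hook above is a closure: ``self`` is captured from the enclosing
# scope rather than passed as an argument, so it is presumably defined
# inside a method of the wrapper class. A hedged sketch of how such a hook
# is typically attached, using PyTorch's standard API (the actual
# registration site is not shown here):
#
#     self.module.register_forward_pre_hook(_forward_pre_hook)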
def _process_async_recv(self):
    r"""Mix received model parameters into the current model.

    The model is updated according to:

        new_model = current * ratio + received * (1 - ratio),

    where the ratio is defined by

        ratio = 0.0001 / (#steps between update and receive + 2).

    This choice is heuristic; no theoretical justification is claimed.
    """
    if len(self._recv_reqs) > 0 and _check_reqs(self._recv_reqs):
        log.debug('Update local parameters')
        for tensor, buf in zip(self.flat_parameters, self.flat_bufs):
            ratio = 0.0001 / (self._iter_counter + 2)
            tensor.mul_(ratio)
            tensor.add_(buf, alpha=1 - ratio)
        self._recv_reqs = []
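# A self-contained illustration of the mixing rule above: with
# _iter_counter == 0 the ratio is 0.0001 / 2 = 5e-5, so the received
# parameters almost entirely replace the local ones.
def _mixing_example():
    import torch
    current = torch.zeros(4)
    received = torch.ones(4)
    ratio = 0.0001 / (0 + 2)  # _iter_counter == 0
    current.mul_(ratio).add_(received, alpha=1 - ratio)
    print(current)  # ~0.99995 everywhere: dominated by the received model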
def forward(self, *args, **kwargs):
    """Forward function.

    First, update the internal iteration counter. If communication is
    needed, call self._communicate(). Finally, call the wrapped
    self.module. The update procedure is written out explicitly here to
    make the logic easier to follow.
    """
    if self.training:
        # Update iteration counter
        self._iter_counter %= self.sync_freq
        self._iter_counter += 1
        log.debug('forward called on %s, rank %d, _iter_counter %d',
                  self.device, self.rank, self._iter_counter)
        if self._iter_counter == 1:
            self._communicate()
    return self.module(*args, **kwargs)
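# A quick pure-Python trace of the counter logic above: with
# sync_freq = 3, _communicate() fires on steps 1, 4, 7, ... i.e. once
# every sync_freq forward passes.
def _sync_schedule_example(sync_freq=3, steps=7):
    counter, fired = 0, []
    for step in range(1, steps + 1):
        counter %= sync_freq
        counter += 1
        if counter == 1:
            fired.append(step)
    return fired  # [1, 4, 7] for the defaults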
def eval(self):
    # Create the validation model on first use
    with torch.no_grad():
        if self._val_model is None:
            if self.rank == 0:
                self._val_model = self.module
            else:
                self._val_model = deepcopy(self.module)
            self._val_model.eval()
            for p in self._val_model.parameters():
                p.detach_()
            log.debug('Created _val_model')
        # Receive weights from node 0
        for p in self._val_model.parameters():
            dist.broadcast(p, 0)
        log.debug('Updated _val_model')
    # Skip DistributedDataParallel's eval(), because no communication is
    # needed here.
    return super(DistributedDataParallel, self).eval()
def validate(model, val_loader, criterion, classes=None, device=None):
    log.info('Validating model')
    losses = []
    if classes is not None:
        confusion_matrix = np.zeros((len(classes), len(classes)))
    for data, target in val_loader:
        target = target.to(device=device, non_blocking=True)
        data = data.to(device=device, non_blocking=True)
        output = model(data)
        loss = criterion(output, target)
        losses.append(loss.cpu().item())
        if classes is not None:
            _, predicted = torch.max(output, 1)
            for i in range(len(target)):
                label = target[i]
                pred = predicted[i]
                confusion_matrix[label][pred] += 1
    # Average the local loss, then sum across workers so the result is the
    # global mean loss
    loss = np.mean(losses) / dist.get_world_size()
    loss = torch.Tensor([loss]).to(device)
    dist.all_reduce(loss)
    loss = loss.cpu().item()
    if classes is not None:
        confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
        dist.all_reduce(confusion_matrix)
        confusion_matrix = confusion_matrix.cpu().numpy()
    log.debug('Synchronized with other workers')
    if classes is not None:
        acc = np.diag(confusion_matrix).sum() / confusion_matrix.sum()
        # Normalize each row by the number of samples of that true class
        # (keepdims=True is required for row-wise broadcasting)
        confusion_matrix /= confusion_matrix.sum(axis=1, keepdims=True)
        # log.debug(confusion_matrix)
        max_len = str(max(len(str(c)) for c in classes))
        if len(classes) > 10:
            log.info('Accuracy of first 5 classes')
            for i in range(5):
                log.info('%-' + max_len + 's: %8.5f%%', classes[i],
                         100 * confusion_matrix[i, i])
            log.info('Accuracy of last 5 classes')
            for i in range(len(classes) - 5, len(classes)):
                log.info('%-' + max_len + 's: %8.5f%%', classes[i],
                         100 * confusion_matrix[i, i])
        else:
            log.info('Accuracy of each class')
            for i in range(len(classes)):
                log.info('%-' + max_len + 's: %8.5f%%', classes[i],
                         100 * confusion_matrix[i, i])
        log.info('Validation loss %.5f, accuracy %.5f%%', loss, acc * 100)
        return loss, acc
    else:
        log.info('Validation loss %.5f', loss)
        return [loss]
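# A small self-contained check of the normalization used above: dividing
# each row by its sum turns row i into the distribution of predictions for
# true class i, so the diagonal holds the per-class accuracy that is
# logged.
def _confusion_matrix_example():
    import numpy as np
    cm = np.array([[8., 2.],   # true class 0: 8 correct, 2 predicted as 1
                   [1., 9.]])  # true class 1: 9 correct, 1 predicted as 0
    cm /= cm.sum(axis=1, keepdims=True)
    print(np.diag(cm))  # [0.8, 0.9] -> 80% and 90% per-class accuracy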
def train(model,
          criterion,
          optimizer,
          train_loader,
          args,
          val_loader=None,
          exp_name=None,
          classes=None,
          scheduler=None):
    if args.apex:
        from apex import amp

    def _val():
        if args.val_interval is not None:
            val_start = time()
            model.eval()
            val_res.append([
                i, train_time, run_time,
                *validate(model, val_loader, criterion,
                          classes=classes, device=args.device)
            ])
            model.train()
            val_end = time()
            return val_end - val_start
        else:
            return 0

    def _save():
        if args.rank == 0:
            fname = get_fname(args, exp_name=exp_name)
            save_data(train_res, val_res, fname, output_dir=args.output_dir)
            log.debug('Data saved to %s', fname)

    def _eta():
        _time = train_time / i * (total_batches - i)
        if args.val_interval is not None:
            _time += val_time / (i // args.val_interval + 1) * (
                (total_batches - i) // args.val_interval + 1)
        h = _time / 3600
        if h > 1:
            return "%.2fh" % h
        m = _time / 60
        if m > 1:
            return "%.2fm" % m
        return "%.2fs" % _time

    total_batches = len(train_loader) * args.epochs
    train_res = []
    val_res = []
    running_loss = []
    running_acc = []
    i = 0
    val_time = run_time = train_time = 0
    train_start = time()
    printed = False
    val_time += _val()
    log.info('Training started')
    model.train()
    optimizer.zero_grad()
    if args.gradient_accumulation and args.ddp == 'pytorch':
        model.require_backward_grad_sync = False
    for epoch in range(1, args.epochs + 1):
        for data, target in train_loader:
            i += 1
            target = target.to(device=args.device, non_blocking=True)
            data = data.to(device=args.device, non_blocking=True)
            if args.ddp == 'pytorch':
                # Only sync gradients on steps where optimizer.step() runs
                if args.gradient_accumulation and i % args.sync_freq != 0:
                    model.require_backward_grad_sync = False
                else:
                    model.require_backward_grad_sync = True
            # ==== Step begin ====
            output = model(data)
            loss = criterion(output, target)
            if args.gradient_accumulation:
                loss /= args.sync_freq
            if args.apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if args.ddp == 'DistributedGradientParallel' and not printed:
                for n, p in model.named_parameters():
                    log.warning(
                        '%s.grad.dtype = %s, max difference between original '
                        'grad and half precision grad is %f', n, p.grad.dtype,
                        (p.grad - p.grad.clone().half()).abs().max())
                printed = True
            if not args.gradient_accumulation or i % args.sync_freq == 0:
                log.debug('[%d/%d, %5d/%d] optimizer step', epoch,
                          args.epochs, i, total_batches)
                optimizer.step()
                optimizer.zero_grad()
            loss = loss.item()
            running_loss.append(loss)
            if classes is not None:
                acc = accuracy(output, target).item()
                running_acc.append(acc)
            # ==== Step done ====
            current_time = time()
            run_time = current_time - train_start
            train_time = run_time - val_time
            if args.gradient_accumulation:
                # Undo the loss scaling so the logged loss is comparable
                tmp_res = [i, train_time, run_time, loss * args.sync_freq]
            else:
                tmp_res = [i, train_time, run_time, loss]
            if classes is not None:
                tmp_res += [acc]
            train_res.append(tmp_res)
            if i % args.disp_interval == 0:
                log.info(
                    '[%d/%d, %5d/%d] local running loss %.5f, local running '
                    'acc %.5f%%, average train time %.4f seconds per batch, '
                    'eta %s', epoch, args.epochs, i, total_batches,
                    np.mean(running_loss), np.mean(running_acc) * 100,
                    train_time / i, _eta())
                running_loss = []
                running_acc = []
            if args.val_interval is not None and i % args.val_interval == 0:
                val_time += _val()
                # Update saved data after every validation
                _save()
        # end for
        current_time = time()
        run_time = current_time - train_start
        train_time = run_time - val_time
        log.info(
            'Training epoch %d ends, total run time %.4f seconds, average '
            'train time %.4f seconds per batch', epoch, run_time,
            train_time / i)
        if scheduler is not None:
            log.debug('scheduler.step() called')
            scheduler.step()
    if args.val_interval is not None and i % args.val_interval != 0:
        val_time += _val()
    current_time = time()
    run_time = current_time - train_start
    train_time = run_time - val_time
    _save()
    if classes is not None:
        best_acc = max(x[-1] for x in val_res)
        log.info(
            'Training finished, %d epochs, final val loss %.5f, final val '
            'acc %.5f%%, best val acc %.5f%%', epoch, val_res[-1][-2],
            val_res[-1][-1] * 100, best_acc * 100)
    else:
        log.info('Training finished, %d epochs, final val loss %.5f', epoch,
                 val_res[-1][-1])
    return train_res, val_res
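# The gradient-accumulation path above follows the standard pattern: scale
# each loss by 1/sync_freq, let gradients accumulate in .grad over
# sync_freq batches, and only then step the optimizer. A minimal
# standalone sketch (all names here are hypothetical):
def _grad_accumulation_sketch(model, criterion, optimizer, loader, sync_freq):
    optimizer.zero_grad()
    for i, (data, target) in enumerate(loader, start=1):
        loss = criterion(model(data), target) / sync_freq
        loss.backward()  # gradients accumulate across iterations
        if i % sync_freq == 0:
            optimizer.step()  # one parameter update per sync_freq batches
            optimizer.zero_grad()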
def __init__(self,
             module,
             world_local_size=None,
             node_rank=None,
             local_rank=None,
             sync_freq=1,
             num_streams=1,
             premultiplier=None,
             **kwargs):
    r"""Init function.

    Args:
        module: The module to be wrapped.
        sync_freq: Number of steps between communications.
        num_streams: Number of CUDA streams to use for communication.
        premultiplier: The multiplier to be applied around communication.
            If not None, parameters are multiplied by the premultiplier
            before communication, then divided by it afterwards.
    """
    super().__init__()
    log.info('Using %s', self.__class__.__name__)
    self.module = module
    self.device = next(self.module.parameters()).device
    # Assume torch.distributed is already initialized
    self.rank = dist.get_rank()
    self.world_size = dist.get_world_size()
    self.local_rank = local_rank if local_rank is not None else self.rank
    self.node_rank = node_rank if node_rank is not None else 0
    self.world_local_size = (world_local_size
                             if world_local_size is not None else 1)
    # When the counter reaches sync_freq, perform communication and reset
    self.premultiplier = premultiplier
    self.sync_freq = sync_freq
    self._iter_counter = 0
    self.param_info = [{
        'numel': param.numel(),
        'shape': param.shape
    } for param in self.parameters()]
    self.flat_parameters, self.flat_indexes = self.flatten_tensors(
        list(self.parameters()))
    self.assign_unflattened_tensors(self.parameters(), self.flat_parameters)
    # Make sure all workers start from identical weights
    log.debug('Broadcasting init params')
    for param in self.flat_parameters:
        dist.broadcast(param, 0)
    log.debug('Broadcasting init params done')
    self.num_streams = num_streams
    if self.device.type == 'cuda':
        self.streams = [
            torch.cuda.Stream() for _ in range(self.num_streams)
        ]
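# ``flatten_tensors`` and ``assign_unflattened_tensors`` are defined
# elsewhere in this codebase. A hedged sketch of the usual approach using
# PyTorch's internal helpers (the real implementation may differ, e.g. it
# buckets tensors and also returns ``flat_indexes``):
def _flatten_sketch(params):
    from torch._utils import (_flatten_dense_tensors,
                              _unflatten_dense_tensors)
    flat = _flatten_dense_tensors(tuple(params))  # one contiguous buffer
    views = _unflatten_dense_tensors(flat, tuple(params))
    for p, v in zip(params, views):
        p.data = v  # each parameter now views a slice of ``flat``
    return flat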