def get_preds_dist(self, dataset='valid', with_target=False):
    self.model.eval()
    if dataset == 'train':
        train_sampler = OrderedDistributedSampler(self.train_dl.dataset,
                                                  get_world_size(),
                                                  rank=get_rank())
        ordered_dist_train_dl = DataLoader(
            self.train_dl.dataset, self.train_dl.batch_size, shuffle=False,
            sampler=train_sampler, num_workers=self.train_dl.num_workers,
            collate_fn=self.train_dl.collate_fn,
            pin_memory=self.train_dl.pin_memory,
            timeout=self.train_dl.timeout,
            worker_init_fn=self.train_dl.worker_init_fn)
        bar = tqdm(ordered_dist_train_dl) if is_main_process() else ordered_dist_train_dl
        num_samples = len(self.train_dl.dataset)
    else:
        valid_sampler = OrderedDistributedSampler(self.valid_dl.dataset,
                                                  get_world_size(),
                                                  rank=get_rank())
        ordered_dist_valid_dl = DataLoader(
            self.valid_dl.dataset, self.valid_dl.batch_size, shuffle=False,
            sampler=valid_sampler, num_workers=self.valid_dl.num_workers,
            collate_fn=self.valid_dl.collate_fn,
            pin_memory=self.valid_dl.pin_memory,
            timeout=self.valid_dl.timeout,
            worker_init_fn=self.valid_dl.worker_init_fn)
        bar = tqdm(ordered_dist_valid_dl) if is_main_process() else ordered_dist_valid_dl
        num_samples = len(self.valid_dl.dataset)
    outputs = []
    targets = []
    for batch in bar:
        x, y = batch_gpu(batch)
        output = self.model(x)
        output = to_cpu(output)
        outputs.append(output)
        if with_target:
            targets.append(to_cpu(y))
    outputs = torch.cat(outputs)
    all_outputs = all_gather(outputs)
    if with_target:
        targets = torch.cat(targets)
        all_targets = all_gather(targets)
    # only the main process assembles and returns the gathered predictions
    if not is_main_process():
        return
    # trim the padding added by OrderedDistributedSampler so the result
    # matches the length of the selected dataset
    all_outputs = torch.cat(all_outputs, dim=0).cpu()[:num_samples]
    if with_target:
        all_targets = torch.cat(all_targets, dim=0).cpu()[:num_samples]
        return all_outputs, all_targets
    else:
        return all_outputs
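# `all_gather` is used above but not defined in this section. Below is a
# minimal sketch of what it is assumed to do: collect a tensor from every
# process and return the list of per-rank tensors, padding the first
# dimension so ranks with different sample counts can still be gathered.
# The project's real helper may be implemented differently.
import torch
import torch.distributed as dist


def all_gather_tensor_sketch(tensor):
    """Hypothetical helper: return a list with one tensor per process."""
    world_size = dist.get_world_size()
    if world_size == 1:
        return [tensor]
    tensor = tensor.cuda()
    # exchange per-rank lengths so every process can pad to the same size
    local_size = torch.tensor([tensor.shape[0]], device=tensor.device)
    size_list = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    max_size = int(max(s.item() for s in size_list))
    # pad the first dimension, gather, then trim each result back
    padded = torch.zeros((max_size, *tensor.shape[1:]),
                         dtype=tensor.dtype, device=tensor.device)
    padded[:tensor.shape[0]] = tensor
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)
    return [g[:int(s.item())] for g, s in zip(gathered, size_list)]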
def on_epoch_end(self, epoch: int, **kwargs) -> None:
    "Compare the value monitored to its best score and maybe save the model."
    if self.every == "epoch":
        self.learn.save(f'{self.name}_{epoch}')
    else:  # every="improvement"
        c = self.get_monitor_value()
        world_size = get_world_size()
        if world_size == 1:
            current = c
            if current is not None and self.operator(current, self.best):
                print(f'Better model found at epoch {epoch} '
                      f'with {self.monitor} value: {current}.')
                self.best = current
                self.learn.save(f'{self.name}')
        else:
            with torch.no_grad():
                c = torch.tensor(c).cuda()
                dist.reduce(c, dst=0)
                if get_rank() == 0:
                    current = c / world_size
                    current = current.data
                    # compare with the same operator as the single-process path
                    if current is not None and self.operator(current, self.best):
                        print(f'Better model found at epoch {epoch} '
                              f'with {self.monitor} value: {current}.')
                        self.best = current
                        self.learn.save(f'{self.name}')
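# The callback above averages the monitored metric across processes with
# dist.reduce before comparing it to the best value. A standalone sketch of
# that pattern (the function name is illustrative, not part of the API):
import torch
import torch.distributed as dist


def average_metric_on_rank0(value):
    """Hypothetical helper: every rank passes its local metric, rank 0 gets the mean."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size == 1:
        return value
    t = torch.tensor(float(value)).cuda()
    dist.reduce(t, dst=0)                # sum onto rank 0
    if dist.get_rank() == 0:
        return (t / world_size).item()   # only rank 0 holds the true average
    return None                          # other ranks get no meaningful value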
def to_parallel(self):
    assert self.state == TrainerState.BASE
    devices = os.environ['CUDA_VISIBLE_DEVICES']
    print('visible devices', devices)
    self.model = DataParallel(self.model)
    if isinstance(self.scheduler, OneCycleScheduler):
        world_size = get_world_size()
        self.scheduler.total_steps //= world_size
        self.scheduler.step_size_up //= world_size
        self.scheduler.step_size_down //= world_size
def to_base(self):
    if self.state == TrainerState.BASE:
        return
    elif self.state == TrainerState.PARALLEL:
        # undo DataParallel wrapping
        self.model = self.model.module
        if isinstance(self.scheduler, OneCycleScheduler):
            world_size = get_world_size()
            self.scheduler.total_steps *= world_size
            self.scheduler.step_size_up *= world_size
            self.scheduler.step_size_down *= world_size
    else:
        # undo DistributedDataParallel wrapping and restore the original loaders
        self.model = self.model.module
        self.train_dl = self.old_train_dl
        self.valid_dl = self.old_valid_dl
        if isinstance(self.scheduler, OneCycleScheduler):
            world_size = get_world_size()
            self.scheduler.total_steps *= world_size
            self.scheduler.step_size_up *= world_size
            self.scheduler.step_size_down *= world_size
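# Worked example of the step-count scaling used by to_parallel()/to_base()
# (the numbers are illustrative): splitting the work across 4 processes
# leaves each process a quarter of the optimizer steps, and to_base()
# multiplies back (exact here because 1000 is divisible by 4).
world_size = 4
total_steps = 1000
total_steps //= world_size   # to_parallel / to_distributed: 250 steps per process
total_steps *= world_size    # to_base: back to 1000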
def reduce_loss(loss):
    """
    Reduce the loss from all processes so that process with rank 0
    has the averaged results.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss
    with torch.no_grad():
        dist.reduce(loss, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            loss /= world_size
    return loss
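# get_world_size(), get_rank() and is_main_process() are used throughout this
# section but not defined here. A minimal sketch of the usual implementations,
# guarding against the non-distributed case; the project's own helpers may differ:
import torch.distributed as dist


def get_world_size():
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0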
def get_preds(self, dataset='valid', with_target=False):
    if get_world_size() > 1:
        return self.get_preds_dist(dataset, with_target)
    self.model.eval()
    assert dataset in ['train', 'valid']
    if dataset == 'train':
        ordered_train_dl = DataLoader(
            self.train_dl.dataset, self.train_dl.batch_size, shuffle=False,
            sampler=None, num_workers=self.train_dl.num_workers,
            collate_fn=self.train_dl.collate_fn,
            pin_memory=self.train_dl.pin_memory,
            timeout=self.train_dl.timeout,
            worker_init_fn=self.train_dl.worker_init_fn)
        bar = tqdm(ordered_train_dl)
    else:
        ordered_valid_dl = DataLoader(
            self.valid_dl.dataset, self.valid_dl.batch_size, shuffle=False,
            sampler=None, num_workers=self.valid_dl.num_workers,
            collate_fn=self.valid_dl.collate_fn,
            pin_memory=self.valid_dl.pin_memory,
            timeout=self.valid_dl.timeout,
            worker_init_fn=self.valid_dl.worker_init_fn)
        bar = tqdm(ordered_valid_dl)
    outputs = []
    targets = []
    for batch in bar:
        x, y = batch_gpu(batch)
        output = self.model(x)
        output = to_cpu(output)
        outputs.append(output)
        if with_target:
            targets.append(to_cpu(y))
    outputs = torch.cat(outputs)
    if with_target:
        targets = torch.cat(targets)
        return outputs, targets
    else:
        return outputs
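# Usage sketch for get_preds()/get_preds_dist(); the trainer instance and the
# classification-style accuracy below are assumptions. When world_size > 1
# only rank 0 gets tensors back; the other ranks receive None and should
# skip post-processing.
result = trainer.get_preds(dataset='valid', with_target=True)
if result is not None:                  # main process, or a single-process run
    preds, targets = result
    acc = (preds.argmax(dim=1) == targets).float().mean()
    print('validation accuracy:', acc.item())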
def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank 0
    has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
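# Typical use of reduce_loss_dict() for logging; the dummy losses below stand
# in for real per-step losses. Every process calls it, but only rank 0
# receives the averaged values and prints them.
loss_dict = {
    'loss_cls': torch.tensor(0.7, device='cuda'),   # placeholder values
    'loss_box': torch.tensor(0.3, device='cuda'),
}
losses = sum(loss for loss in loss_dict.values())   # what each process backprops locally
reduced = reduce_loss_dict(loss_dict)               # averaged copy, for logging only
if get_rank() == 0:
    total = sum(loss.item() for loss in reduced.values())
    print(f'total loss (averaged over processes): {total:.4f}')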
def to_distributed(self):
    assert dist.is_available() and dist.is_initialized()
    local_rank = dist.get_rank()  # global rank; equal to the local rank for single-node training
    self.model = DistributedDataParallel(self.model, [local_rank],
                                         output_device=local_rank,
                                         broadcast_buffers=False)
    self.old_train_dl = self.train_dl
    train_sampler = DistributedSampler(self.train_dl.dataset, shuffle=True)
    new_train_dl = DataLoader(
        self.train_dl.dataset, self.train_dl.batch_size, shuffle=False,
        sampler=train_sampler, num_workers=self.train_dl.num_workers,
        collate_fn=self.train_dl.collate_fn,
        pin_memory=self.train_dl.pin_memory,
        timeout=self.train_dl.timeout,
        worker_init_fn=self.train_dl.worker_init_fn)
    self.train_dl = new_train_dl
    self.old_valid_dl = self.valid_dl
    valid_sampler = DistributedSampler(self.valid_dl.dataset, shuffle=False)
    new_valid_dl = DataLoader(
        self.valid_dl.dataset, self.valid_dl.batch_size, shuffle=False,
        sampler=valid_sampler, num_workers=self.valid_dl.num_workers,
        collate_fn=self.valid_dl.collate_fn,
        pin_memory=self.valid_dl.pin_memory,
        timeout=self.valid_dl.timeout,
        worker_init_fn=self.valid_dl.worker_init_fn)
    self.valid_dl = new_valid_dl
    if isinstance(self.scheduler, OneCycleScheduler):
        world_size = get_world_size()
        # integer division keeps the step counts integral, matching to_parallel()/to_base()
        self.scheduler.total_steps //= world_size
        self.scheduler.step_size_up //= world_size
        self.scheduler.step_size_down //= world_size
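# to_distributed() asserts that torch.distributed is already initialized.
# A common way to do that when launching with
#   python -m torch.distributed.launch --nproc_per_node=NUM_GPUS train.py
# is sketched below; the argument parsing is illustrative, not part of this
# project, and newer PyTorch versions can use torchrun with env vars instead.
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)   # filled in by the launcher
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')
# ... build the trainer as usual, then:
# trainer.to_distributed()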