Example #1
 def get_preds_dist(self, dataset='valid', with_target=False):
     self.model.eval()
     if dataset == 'train':
         train_sampler = OrderedDistributedSampler(self.train_dl.dataset,
                                                   get_world_size(),
                                                   rank=get_rank())
         ordered_dist_train_dl = DataLoader(
             self.train_dl.dataset,
             self.train_dl.batch_size,
             shuffle=False,
             sampler=train_sampler,
             num_workers=self.train_dl.num_workers,
             collate_fn=self.train_dl.collate_fn,
             pin_memory=self.train_dl.pin_memory,
             timeout=self.train_dl.timeout,
             worker_init_fn=self.train_dl.worker_init_fn)
         bar = (tqdm(ordered_dist_train_dl)
                if is_main_process() else ordered_dist_train_dl)
     else:
         valid_sampler = OrderedDistributedSampler(self.valid_dl.dataset,
                                                   get_world_size(),
                                                   rank=get_rank())
         ordered_dist_valid_dl = DataLoader(
             self.valid_dl.dataset,
             self.valid_dl.batch_size,
             shuffle=False,
             sampler=valid_sampler,
             num_workers=self.valid_dl.num_workers,
             collate_fn=self.valid_dl.collate_fn,
             pin_memory=self.valid_dl.pin_memory,
             timeout=self.valid_dl.timeout,
             worker_init_fn=self.valid_dl.worker_init_fn)
         bar = (tqdm(ordered_dist_valid_dl)
                if is_main_process() else ordered_dist_valid_dl)
     outputs = []
     targets = []
     for batch in bar:
         x, y = batch_gpu(batch)
         output = self.model(x)
         output = to_cpu(output)
         outputs.append(output)
         if with_target:
             targets.append(to_cpu(y))
     outputs = torch.cat(outputs)
     all_outputs = all_gather(outputs)
     if with_target:
         targets = torch.cat(targets)
         all_targets = all_gather(targets)
     if not is_main_process():
         return
     # drop the padding added by the distributed sampler so the gathered
     # predictions line up with the dataset that was actually used
     used_ds = (self.train_dl.dataset
                if dataset == 'train' else self.valid_dl.dataset)
     all_outputs = torch.cat(all_outputs, dim=0).cpu()[:len(used_ds)]
     if with_target:
         all_targets = torch.cat(all_targets, dim=0).cpu()[:len(used_ds)]
         return all_outputs, all_targets
     else:
         return all_outputs
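Note: the snippets on this page rely on a few small distributed helpers (get_world_size, get_rank, is_main_process, all_gather, batch_gpu, to_cpu) that are defined elsewhere in the project. A minimal sketch of the rank/world-size helpers, assuming the conventional torch.distributed fallbacks for single-process runs (the project's actual implementations may differ):

import torch.distributed as dist

def get_world_size():
    # fall back to 1 when torch.distributed is unavailable or not initialized,
    # so the same code path works for single-GPU and CPU runs
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()

def get_rank():
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()

def is_main_process():
    # rank 0 is treated as the main process for logging and returning results
    return get_rank() == 0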
Example #2
 def on_epoch_end(self, epoch: int, **kwargs) -> None:
     "Compare the value monitored to its best score and maybe save the model."
     if self.every == "epoch":
         self.learn.save(f'{self.name}_{epoch}')
     else:  # every="improvement"
         c = self.get_monitor_value()
         world_size = get_world_size()
         if world_size == 1:
             current = c
             if current is not None and self.operator(current, self.best):
                 print(
                     f'Better model found at epoch {epoch} with {self.monitor} value: {current}.'
                 )
                 self.best = current
                 self.learn.save(f'{self.name}')
         else:
             with torch.no_grad():
                 c = torch.tensor(c).cuda()
                 dist.reduce(c, dst=0)
                 if get_rank() == 0:
                     current = c / world_size
                     current = current.data
                     if self.operator(current, self.best):
                         print(
                             f'Better model found at epoch {epoch} with {self.monitor} value: {current}.'
                         )
                         self.best = current
                         self.learn.save(f'{self.name}')
Example #3
 def to_parallel(self):
     assert self.state == TrainerState.BASE
     devices = os.environ['CUDA_VISIBLE_DEVICES']
     print('visible devices', devices)
     self.model = DataParallel(self.model)
     if isinstance(self.scheduler, OneCycleScheduler):
         world_size = get_world_size()
         self.scheduler.total_steps //= world_size
         self.scheduler.step_size_up //= world_size
         self.scheduler.step_size_down //= world_size
Example #4
 def to_base(self):
     if self.state == TrainerState.BASE:
         return
     elif self.state == TrainerState.PARALLEL:
         self.model = self.model.module
         if isinstance(self.scheduler, OneCycleScheduler):
             world_size = get_world_size()
             self.scheduler.total_steps *= world_size
             self.scheduler.step_size_up *= world_size
             self.scheduler.step_size_down *= world_size
     else:  # TrainerState.DISTRIBUTED
         self.model = self.model.module
         self.train_dl = self.old_train_dl
         self.valid_dl = self.old_valid_dl
         if isinstance(self.scheduler, OneCycleScheduler):
             world_size = get_world_size()
             self.scheduler.total_steps *= world_size
             self.scheduler.step_size_up *= world_size
             self.scheduler.step_size_down *= world_size
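The OneCycleScheduler bookkeeping in to_parallel / to_base (and in to_distributed below) is just a rescale of the planned step counts: with the data split across world_size processes, each process performs roughly 1/world_size of the optimizer steps per epoch, and to_base multiplies the counts back once the wrapper is removed. A small numeric sanity check (the numbers are illustrative only):

world_size = 4
total_steps = 1000                        # steps planned for single-process training
per_process = total_steps // world_size   # 250 steps seen by each process
restored = per_process * world_size       # 1000 again after to_base
assert restored == total_steps            # exact only when divisible by world_size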
Example #5
def reduce_loss(loss):
    """
    Reduce the loss from all processes so that process with rank
    0 has the averaged results.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss
    with torch.no_grad():
        dist.reduce(loss, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            loss /= world_size
    return loss
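A hedged usage sketch for reduce_loss inside a training loop; model, criterion, optimizer and train_loader are placeholders, not part of the snippets on this page:

for x, y in train_loader:
    loss = criterion(model(x), y)          # per-process loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # detach before reducing so the collective op does not touch the autograd graph
    reduced = reduce_loss(loss.detach())
    if is_main_process():
        # dist.reduce only guarantees the averaged result on rank 0, so log there
        print(f'loss: {reduced.item():.4f}')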
Example #6
 def get_preds(self, dataset='valid', with_target=False):
     if get_world_size() > 1:
         return self.get_preds_dist(dataset, with_target)
     self.model.eval()
     assert dataset in ['train', 'valid']
     if dataset == 'train':
         ordered_train_dl = DataLoader(
             self.train_dl.dataset,
             self.train_dl.batch_size,
             shuffle=False,
             sampler=None,
             num_workers=self.train_dl.num_workers,
             collate_fn=self.train_dl.collate_fn,
             pin_memory=self.train_dl.pin_memory,
             timeout=self.train_dl.timeout,
             worker_init_fn=self.train_dl.worker_init_fn)
         bar = tqdm(ordered_train_dl)
     else:
         ordered_valid_dl = DataLoader(
             self.valid_dl.dataset,
             self.valid_dl.batch_size,
             shuffle=False,
             sampler=None,
             num_workers=self.valid_dl.num_workers,
             collate_fn=self.valid_dl.collate_fn,
             pin_memory=self.valid_dl.pin_memory,
             timeout=self.valid_dl.timeout,
             worker_init_fn=self.valid_dl.worker_init_fn)
         bar = tqdm(ordered_valid_dl)
     outputs = []
     targets = []
     for batch in bar:
         x, y = batch_gpu(batch)
         output = self.model(x)
         output = to_cpu(output)
         outputs.append(output)
         if with_target:
             targets.append(to_cpu(y))
     outputs = torch.cat(outputs)
     if with_target:
         targets = torch.cat(targets)
         return outputs, targets
     else:
         return outputs
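A usage sketch for get_preds; trainer stands for an instance of the surrounding Trainer class, which is not shown on this page. In a distributed run every process must make the call (it dispatches to get_preds_dist, which performs collective ops), but only rank 0 receives tensors:

result = trainer.get_preds('valid', with_target=True)
if is_main_process():
    preds, targets = result
    # assumes a classification model whose raw outputs are class logits
    acc = (preds.argmax(dim=1) == targets).float().mean().item()
    print(f'valid accuracy: {acc:.4f}')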
Example #7
def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
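reduce_loss_dict pairs with models that return a dict of partial losses (detection-style heads, for example). A minimal sketch, with model, images, targets and optimizer as hypothetical placeholders:

loss_dict = model(images, targets)      # e.g. {'loss_cls': ..., 'loss_box_reg': ...}
losses = sum(loss for loss in loss_dict.values())
optimizer.zero_grad()
losses.backward()
optimizer.step()
# average the individual terms across processes; meaningful on rank 0 only
loss_dict_reduced = reduce_loss_dict({k: v.detach() for k, v in loss_dict.items()})
if is_main_process():
    total = sum(v.item() for v in loss_dict_reduced.values())
    print({k: round(v.item(), 4) for k, v in loss_dict_reduced.items()},
          'total:', round(total, 4))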
Example #8
 def to_distributed(self):
     assert dist.is_available() and dist.is_initialized()
     # dist.get_rank() doubles as the GPU index here, which assumes
     # single-node training with one process per visible GPU
     local_rank = dist.get_rank()
     self.model = DistributedDataParallel(self.model, [local_rank],
                                          output_device=local_rank,
                                          broadcast_buffers=False)
     self.old_train_dl = self.train_dl
     train_sampler = DistributedSampler(self.train_dl.dataset, shuffle=True)
     new_train_dl = DataLoader(self.train_dl.dataset,
                               self.train_dl.batch_size,
                               shuffle=False,
                               sampler=train_sampler,
                               num_workers=self.train_dl.num_workers,
                               collate_fn=self.train_dl.collate_fn,
                               pin_memory=self.train_dl.pin_memory,
                               timeout=self.train_dl.timeout,
                               worker_init_fn=self.train_dl.worker_init_fn)
     self.train_dl = new_train_dl
     self.old_valid_dl = self.valid_dl
     valid_sampler = DistributedSampler(self.valid_dl.dataset,
                                        shuffle=False)
     new_valid_dl = DataLoader(self.valid_dl.dataset,
                               self.valid_dl.batch_size,
                               shuffle=False,
                               sampler=valid_sampler,
                               num_workers=self.valid_dl.num_workers,
                               collate_fn=self.valid_dl.collate_fn,
                               pin_memory=self.valid_dl.pin_memory,
                               timeout=self.valid_dl.timeout,
                               worker_init_fn=self.valid_dl.worker_init_fn)
     self.valid_dl = new_valid_dl
     if isinstance(self.scheduler, OneCycleScheduler):
         world_size = get_world_size()
         self.scheduler.total_steps //= world_size
         self.scheduler.step_size_up //= world_size
         self.scheduler.step_size_down //= world_size
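to_distributed assumes the default process group is already initialized and that each process drives exactly one GPU (it reuses dist.get_rank() as the device index). A hedged launch sketch for a single node started with torchrun; the Trainer constructor and fit call are placeholders for whatever the surrounding project provides:

import os
import torch
import torch.distributed as dist

local_rank = int(os.environ['LOCAL_RANK'])     # set per process by torchrun
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl', init_method='env://')

trainer = Trainer(model, train_dl, valid_dl)   # hypothetical constructor
trainer.to_distributed()   # wrap in DDP, swap in DistributedSampler loaders, rescale scheduler
trainer.fit(num_epochs)    # hypothetical training entry point
trainer.to_base()          # restore the original loaders and scheduler counts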