def predict(self):
    """Run the testing process.

    Returns:
        test_log (dict): The log information of the testing.
    """
    self.net.eval()
    dataloader = self.test_dataloader
    pbar = tqdm(dataloader, desc='test', ascii=True)

    epoch_log = EpochLog()
    for i, batch in enumerate(pbar):
        with torch.no_grad():
            test_dict = self._test_step(batch)
            loss = test_dict['loss']
            losses = test_dict.get('losses')
            metrics = test_dict.get('metrics')

        if (i + 1) == len(dataloader) and not dataloader.drop_last:
            # The last batch may be smaller; fall back to a full batch
            # when the dataset size divides the batch size evenly.
            batch_size = (len(dataloader.dataset) % dataloader.batch_size
                          or dataloader.batch_size)
        else:
            batch_size = dataloader.batch_size
        epoch_log.update(batch_size, loss, losses, metrics)

        pbar.set_postfix(**epoch_log.on_step_end_log)

    test_log = epoch_log.on_epoch_end_log
    LOGGER.info(f'Test log: {test_log}.')
    return test_log
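# A minimal sketch of the EpochLog interface that the loops in this file rely
# on; the real class lives elsewhere in the repository, so the internals here
# are assumptions. It keeps batch-size-weighted running averages of the loss,
# the sub-losses, and the metrics (assumed to be dicts of scalar tensors), and
# exposes them both as a tqdm postfix and as an epoch summary.
class EpochLog:
    def __init__(self):
        self.count = 0
        self.log = {}
        self.lr = None

    def update(self, batch_size, loss, losses=None, metrics=None, lr=None):
        """Accumulate one batch of statistics, weighted by its batch size."""
        self.count += batch_size
        stats = {'loss': loss.item()}
        if losses:
            stats.update({name: value.item() for name, value in losses.items()})
        if metrics:
            stats.update({name: value.item() for name, value in metrics.items()})
        for name, value in stats.items():
            self.log[name] = self.log.get(name, 0) + value * batch_size
        if lr is not None:
            self.lr = lr

    @property
    def on_step_end_log(self):
        # Formatted running averages, suitable for tqdm's set_postfix.
        log = {name: f'{value / self.count:.4f}' for name, value in self.log.items()}
        if self.lr is not None:
            log['lr'] = f'{self.lr:.2e}'
        return log

    @property
    def on_epoch_end_log(self):
        return {name: value / self.count for name, value in self.log.items()}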
def _run_epoch(self, mode):
    """Run an epoch of training or validation.

    Args:
        mode (str): The mode of running an epoch ('train' or 'valid').

    Returns:
        log (dict): The log information.
        batch (dict or sequence): The last batch of the data.
        outputs (torch.Tensor or sequence of torch.Tensor): The corresponding model outputs.
    """
    if mode == 'train':
        self.net.train()
        dataloader = self.train_dataloader
        dataloader_iterator = iter(dataloader)
        # The outer bar advances once per optimizer update; the inner bar
        # tracks the gradient-accumulation sub-steps of the current update.
        outer_pbar = tqdm(
            total=((len(dataloader) + dataloader.grad_accumulation_steps() - 1)
                   // dataloader.grad_accumulation_steps()),
            desc='train', ascii=True)
        inner_pbar = tqdm(total=math.ceil(dataloader.grad_accumulation_steps(0)),
                          desc='grad_accumulation', leave=False, ascii=True)

        epoch_log = EpochLog()
        for i in range(len(dataloader)):
            batch = next(dataloader_iterator)
            train_dict = self._train_step(batch)
            loss = train_dict.get('loss')
            if loss is None:
                raise KeyError(
                    "The train_dict must have the key named 'loss'. "
                    'Please check the returned keys as defined in MyTrainer._train_step().')
            losses = train_dict.get('losses')
            metrics = train_dict.get('metrics')
            outputs = train_dict.get('outputs')

            # Scale the loss by the accumulation-group size so that the
            # accumulated gradient matches a single large-batch gradient.
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    (scaled_loss / dataloader.grad_accumulation_steps(i)).backward()
            else:
                (loss / dataloader.grad_accumulation_steps(i)).backward()

            is_update_step = ((i + 1) % dataloader.grad_accumulation_steps() == 0
                              or (i + 1) == len(dataloader))
            if is_update_step:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if isinstance(self.lr_scheduler, CyclicLR):
                    self.lr_scheduler.step()
                elif isinstance(self.lr_scheduler, CosineAnnealingWarmRestarts):
                    self.lr_scheduler.step((self.epoch - 1) + i / len(dataloader))

            if (i + 1) == len(dataloader) and not dataloader.drop_last:
                # The last batch may be smaller; fall back to a full batch
                # when the dataset size divides the batch size evenly.
                batch_size = (len(dataloader.dataset) % dataloader.batch_size
                              or dataloader.batch_size)
            else:
                batch_size = dataloader.batch_size
            epoch_log.update(batch_size, loss, losses, metrics)

            inner_pbar.update()
            if is_update_step:
                outer_pbar.update()
                outer_pbar.set_postfix(**epoch_log.on_step_end_log)
                inner_pbar.close()
                inner_pbar = tqdm(total=math.ceil(dataloader.grad_accumulation_steps(i + 1)),
                                  desc='grad_accumulation', leave=False, ascii=True)
        outer_pbar.close()
        inner_pbar.close()
    else:
        self.net.eval()
        dataloader = self.valid_dataloader
        pbar = tqdm(dataloader, desc='valid', ascii=True)

        epoch_log = EpochLog()
        for i, batch in enumerate(pbar):
            with torch.no_grad():
                valid_dict = self._valid_step(batch)
                loss = valid_dict.get('loss')
                if loss is None:
                    raise KeyError(
                        "The valid_dict must have the key named 'loss'. "
                        'Please check the returned keys as defined in MyTrainer._valid_step().')
                losses = valid_dict.get('losses')
                metrics = valid_dict.get('metrics')
                outputs = valid_dict.get('outputs')

            if (i + 1) == len(dataloader) and not dataloader.drop_last:
                batch_size = (len(dataloader.dataset) % dataloader.batch_size
                              or dataloader.batch_size)
            else:
                batch_size = dataloader.batch_size
            epoch_log.update(batch_size, loss, losses, metrics)

            pbar.set_postfix(**epoch_log.on_step_end_log)
    return epoch_log.on_epoch_end_log, batch, outputs
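# A hedged guess at the grad_accumulation_steps API used above; the real
# method belongs to a custom DataLoader subclass in this repository, so both
# the class and its semantics here are assumptions. With no argument it is
# read as the nominal number of accumulation sub-steps per optimizer update;
# with a batch index it is read as the size of the accumulation group that
# contains that batch, so the final (possibly shorter) group divides the loss
# by the correct count.
from torch.utils.data import DataLoader


class GradAccumulationDataLoader(DataLoader):
    def __init__(self, *args, grad_accumulation_steps=1, **kwargs):
        super().__init__(*args, **kwargs)
        self._grad_accumulation_steps = grad_accumulation_steps

    def grad_accumulation_steps(self, index=None):
        if index is None:
            return self._grad_accumulation_steps
        # Size of the accumulation group containing batch `index`; the last
        # group is clipped to the number of remaining batches.
        group_start = (index // self._grad_accumulation_steps) * self._grad_accumulation_steps
        return min(self._grad_accumulation_steps, len(self) - group_start)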
def _run_epoch(self, mode):
    """Run an epoch of training or validation.

    Args:
        mode (str): The mode of running an epoch ('train' or 'valid').

    Returns:
        log (dict): The log information.
        batch (dict or sequence): The last batch of the data.
        outputs (torch.Tensor or sequence of torch.Tensor): The corresponding model outputs.
    """
    # Reset the running counts of the metric functions before the epoch.
    for fn in self.metric_fns:
        fn.reset_count()

    if mode == 'train':
        self.net.train()
        dataloader = self.train_dataloader
        dataloader_iterator = iter(dataloader)
        outer_pbar = tqdm(total=len(dataloader), desc=mode, ascii=True)

        epoch_log = EpochLog()
        for i in range(len(dataloader)):
            batch = next(dataloader_iterator)
            train_dict = self._train_step(batch)
            loss = train_dict['loss']
            losses = train_dict.get('losses')
            metrics = train_dict.get('metrics')
            outputs = train_dict.get('outputs')
            lr = self.optimizer.param_groups[0]['lr']

            loss.backward()
            if self.grad_norm:
                # Clip the gradient norm to stabilize training.
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            self.optimizer.zero_grad()

            if (i + 1) == len(dataloader) and not dataloader.drop_last:
                # The last batch may be smaller; fall back to a full batch
                # when the dataset size divides the batch size evenly.
                batch_size = (len(dataloader.dataset) % dataloader.batch_size
                              or dataloader.batch_size)
            else:
                batch_size = dataloader.batch_size
            epoch_log.update(batch_size, loss, losses, metrics, lr)

            outer_pbar.update()
            outer_pbar.set_postfix(**epoch_log.on_step_end_log)
        outer_pbar.close()
    else:
        self.net.eval()
        dataloader = self.valid_dataloader
        pbar = tqdm(dataloader, desc=mode, ascii=True)

        epoch_log = EpochLog()
        for i, batch in enumerate(pbar):
            with torch.no_grad():
                valid_dict = self._valid_step(batch)
                loss = valid_dict['loss']
                losses = valid_dict.get('losses')
                metrics = valid_dict.get('metrics')
                outputs = valid_dict.get('outputs')

            if (i + 1) == len(dataloader) and not dataloader.drop_last:
                batch_size = (len(dataloader.dataset) % dataloader.batch_size
                              or dataloader.batch_size)
            else:
                batch_size = dataloader.batch_size
            epoch_log.update(batch_size, loss, losses, metrics)

            pbar.set_postfix(**epoch_log.on_step_end_log)
    return epoch_log.on_epoch_end_log, batch, outputs
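# A minimal example of the _train_step contract that _run_epoch relies on: it
# must return a dict with at least the key 'loss' (a scalar tensor attached to
# the graph); 'losses', 'metrics', and 'outputs' are optional. The batch keys
# and the cross-entropy objective are assumptions for illustration, not this
# repository's actual task. Requires `import torch.nn.functional as F` at
# module level.
def _train_step(self, batch):
    inputs, targets = batch['inputs'], batch['targets']
    outputs = self.net(inputs)
    loss = F.cross_entropy(outputs, targets)
    accuracy = (outputs.argmax(dim=1) == targets).float().mean()
    return {
        'loss': loss,
        'losses': {'cross_entropy': loss},
        'metrics': {'accuracy': accuracy},
        'outputs': outputs,
    }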
def _run_epoch(self, mode):
    """Run an epoch of training or validation.

    Args:
        mode (str): The mode of running an epoch ('train' or 'valid').

    Returns:
        log (dict): The log information.
        batch (dict or sequence): The last batch of the data.
        outputs (torch.Tensor or sequence of torch.Tensor): The corresponding model outputs.
    """
    if mode == 'train':
        self.net.train()
        dataloader = self.train_dataloader
        dataloader_iterator = iter(dataloader)
        outer_pbar = tqdm(total=len(dataloader), desc=mode, ascii=True)

        epoch_log = EpochLog()
        for i in range(len(dataloader)):
            batch = next(dataloader_iterator)
            train_dict = self._train_step(batch)
            loss = train_dict['loss']
            losses = train_dict.get('losses')
            metrics = train_dict.get('metrics')
            outputs = train_dict.get('outputs')
            lr = self.optimizer.param_groups[0]['lr']

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Batch-wise schedulers would step here if enabled:
            # if isinstance(self.lr_scheduler, (CyclicLR, OneCycleLR)):
            #     self.lr_scheduler.step()
            # elif isinstance(self.lr_scheduler, CosineAnnealingWarmRestarts):
            #     self.lr_scheduler.step((self.epoch - 1) + i / len(dataloader))

            if (i + 1) == len(dataloader) and not dataloader.drop_last:
                # The last batch may be smaller; fall back to a full batch
                # when the dataset size divides the batch size evenly.
                batch_size = (len(dataloader.dataset) % dataloader.batch_size
                              or dataloader.batch_size)
            else:
                batch_size = dataloader.batch_size
            epoch_log.update(batch_size, loss, losses, metrics, lr)

            outer_pbar.update()
            outer_pbar.set_postfix(**epoch_log.on_step_end_log)
        outer_pbar.close()
    else:
        self.net.eval()
        dataloader = self.valid_dataloader
        pbar = tqdm(dataloader, desc=mode, ascii=True)

        epoch_log = EpochLog()
        for i, batch in enumerate(pbar):
            with torch.no_grad():
                valid_dict = self._valid_step(batch)
                loss = valid_dict['loss']
                losses = valid_dict.get('losses')
                metrics = valid_dict.get('metrics')
                outputs = valid_dict.get('outputs')

            if (i + 1) == len(dataloader) and not dataloader.drop_last:
                batch_size = (len(dataloader.dataset) % dataloader.batch_size
                              or dataloader.batch_size)
            else:
                batch_size = dataloader.batch_size
            epoch_log.update(batch_size, loss, losses, metrics)

            pbar.set_postfix(**epoch_log.on_step_end_log)
    return epoch_log.on_epoch_end_log, batch, outputs
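# A hedged sketch of a fit loop that could drive _run_epoch; the attributes
# used here (self.num_epochs, self.lr_scheduler) and the monitored 'loss' key
# are assumptions for illustration. Epoch-wise schedulers step here, while
# batch-wise ones (e.g. CyclicLR) step inside _run_epoch itself. Requires
# `from torch.optim.lr_scheduler import ReduceLROnPlateau` at module level.
def fit(self):
    for self.epoch in range(1, self.num_epochs + 1):
        train_log, _, _ = self._run_epoch('train')
        valid_log, _, _ = self._run_epoch('valid')
        if isinstance(self.lr_scheduler, ReduceLROnPlateau):
            self.lr_scheduler.step(valid_log['loss'])
        elif self.lr_scheduler is not None:
            self.lr_scheduler.step()
        LOGGER.info(f'Epoch {self.epoch}: train {train_log}, valid {valid_log}.')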