Example #1
0
 def validate_one_epoch(self, data_loader):
     """Run one full validation pass over ``data_loader``.

     Puts the model in eval mode, iterates the loader under
     ``torch.no_grad()``, folds the per-step loss and every metric returned
     by ``validate_one_step`` into running averages, and reports them via
     ``update_metrics``.

     Args:
         data_loader: iterable of batches; assumed to expose a
             ``batch_size`` attribute used as the averaging weight
             (NOTE(review): iterable-style loaders may have
             ``batch_size is None`` — confirm against callers).

     Returns:
         float: average validation loss over the epoch (0.0-ish meter
         default if the loader is empty).
     """
     self.eval()
     self.model_state = enums.ModelState.VALID
     losses = AverageMeter()
     # tqdm progress bars don't play well with TPU worker processes, so the
     # raw loader is iterated directly in that case.
     if self.using_tpu:
         tk0 = data_loader
     else:
         tk0 = tqdm(data_loader, total=len(data_loader))
     # Bind these before the loop: the original only assigned them inside
     # the loop body, so an empty loader made the update_metrics call below
     # raise NameError on `monitor`.
     metrics_meter = {}
     monitor = {}
     for b_idx, data in enumerate(tk0):
         self.train_state = enums.TrainingState.VALID_STEP_START
         with torch.no_grad():
             loss, metrics = self.validate_one_step(data)
         self.train_state = enums.TrainingState.VALID_STEP_END
         losses.update(loss.item(), data_loader.batch_size)
         if b_idx == 0:
             # Metric keys are only known after the first step returns.
             metrics_meter = {k: AverageMeter() for k in metrics}
         monitor = {}
         for m_m in metrics_meter:
             metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
             monitor[m_m] = metrics_meter[m_m].avg
         if not self.using_tpu:
             tk0.set_postfix(loss=losses.avg, stage="valid", **monitor)
         self.current_valid_step += 1
     if not self.using_tpu:
         tk0.close()
     self.update_metrics(losses=losses, monitor=monitor)
     return losses.avg
Example #2
0
    def train_one_epoch(self, data_loader):
        """Run one full training pass over ``data_loader``.

        Puts the model in train mode, iterates the loader, folds the
        per-step loss (rescaled by ``accumulation_steps``, which
        ``train_one_step`` presumably divided out — TODO confirm) and every
        metric returned by ``train_one_step`` into running averages, and
        reports them via ``update_metrics``.

        Args:
            data_loader: iterable of batches; assumed to expose a
                ``batch_size`` attribute used as the averaging weight
                (NOTE(review): iterable-style loaders may have
                ``batch_size is None`` — confirm against callers).

        Returns:
            float: average training loss over the epoch.
        """
        self.train()
        self.model_state = enums.ModelState.TRAIN
        losses = AverageMeter()
        # With gradient accumulation the optimizer is zeroed once up front;
        # train_one_step is responsible for stepping/zeroing thereafter.
        if self.accumulation_steps > 1:
            self.optimizer.zero_grad()
        # tqdm progress bars don't play well with TPU worker processes, so
        # the raw loader is iterated directly in that case.
        if self.using_tpu:
            tk0 = data_loader
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))
        # Bind these before the loop: the original only assigned them inside
        # the loop body, so an empty loader made the update_metrics call
        # below raise NameError on `monitor`.
        metrics_meter = {}
        monitor = {}
        for b_idx, data in enumerate(tk0):
            self.batch_index = b_idx
            self.train_state = enums.TrainingState.TRAIN_STEP_START
            loss, metrics = self.train_one_step(data)
            self.train_state = enums.TrainingState.TRAIN_STEP_END
            # Undo the per-step loss scaling so the reported average is in
            # un-accumulated units.
            losses.update(loss.item() * self.accumulation_steps,
                          data_loader.batch_size)
            if b_idx == 0:
                # Metric keys are only known after the first step returns.
                metrics_meter = {k: AverageMeter() for k in metrics}
            monitor = {}
            for m_m in metrics_meter:
                metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
                monitor[m_m] = metrics_meter[m_m].avg
            self.current_train_step += 1
            if not self.using_tpu:
                tk0.set_postfix(loss=losses.avg, stage="train", **monitor)
            if self.using_tpu:
                print(
                    f"train step: {self.current_train_step} loss: {losses.avg}"
                )
        if not self.using_tpu:
            tk0.close()
        self.update_metrics(losses=losses, monitor=monitor)

        return losses.avg
Example #3
0
 def train_one_epoch(self, data_loader):
     """Run one full training pass over ``data_loader``.

     Puts the model in train mode, iterates the loader behind a tqdm
     progress bar, folds the per-step loss and every metric returned by
     ``train_one_step`` into running averages, and reports them via
     ``update_metrics``.

     Args:
         data_loader: iterable of batches; assumed to expose a
             ``batch_size`` attribute used as the averaging weight
             (NOTE(review): iterable-style loaders may have
             ``batch_size is None`` — confirm against callers).

     Returns:
         float: average training loss over the epoch.
     """
     self.train()
     self.model_state = enums.ModelState.TRAIN
     losses = AverageMeter()
     tk0 = tqdm(data_loader, total=len(data_loader))
     # Bind these before the loop: the original only assigned them inside
     # the loop body, so an empty loader made the update_metrics call below
     # raise NameError on `monitor`.
     metrics_meter = {}
     monitor = {}
     for b_idx, data in enumerate(tk0):
         self.train_state = enums.TrainingState.TRAIN_STEP_START
         loss, metrics = self.train_one_step(data)
         self.train_state = enums.TrainingState.TRAIN_STEP_END
         losses.update(loss.item(), data_loader.batch_size)
         if b_idx == 0:
             # Metric keys are only known after the first step returns.
             metrics_meter = {k: AverageMeter() for k in metrics}
         monitor = {}
         for m_m in metrics_meter:
             metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
             monitor[m_m] = metrics_meter[m_m].avg
         self.current_train_step += 1
         tk0.set_postfix(loss=losses.avg, stage="train", **monitor)
     tk0.close()
     self.update_metrics(losses=losses, monitor=monitor)
     return losses.avg