def validate_one_epoch(self, data_loader):
    self.eval()
    self.model_state = enums.ModelState.VALID
    losses = AverageMeter()
    if self.using_tpu:
        tk0 = data_loader
    else:
        tk0 = tqdm(data_loader, total=len(data_loader))
    for b_idx, data in enumerate(tk0):
        self.train_state = enums.TrainingState.VALID_STEP_START
        # no gradients are needed during validation
        with torch.no_grad():
            loss, metrics = self.validate_one_step(data)
        self.train_state = enums.TrainingState.VALID_STEP_END
        losses.update(loss.item(), data_loader.batch_size)
        # create one running meter per metric on the first batch
        if b_idx == 0:
            metrics_meter = {k: AverageMeter() for k in metrics}
        monitor = {}
        for m_m in metrics_meter:
            metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
            monitor[m_m] = metrics_meter[m_m].avg
        if not self.using_tpu:
            tk0.set_postfix(loss=losses.avg, stage="valid", **monitor)
        self.current_valid_step += 1
    if not self.using_tpu:
        tk0.close()
    self.update_metrics(losses=losses, monitor=monitor)
    return losses.avg
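Both loops lean on an AverageMeter to track running averages of the loss and of each metric, but the class itself is not part of this excerpt. A minimal sketch of such a meter, assuming the usual val/sum/count/avg layout, could look like this:

# Assumed minimal implementation of a running-average meter; the actual
# AverageMeter used above is not shown in this excerpt and may differ.
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the per-batch value, n the number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count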
def train_one_epoch(self, data_loader):
    self.train()
    self.model_state = enums.ModelState.TRAIN
    losses = AverageMeter()
    # when accumulating gradients, clear them once before the epoch starts
    if self.accumulation_steps > 1:
        self.optimizer.zero_grad()
    if self.using_tpu:
        tk0 = data_loader
    else:
        tk0 = tqdm(data_loader, total=len(data_loader))
    for b_idx, data in enumerate(tk0):
        self.batch_index = b_idx
        self.train_state = enums.TrainingState.TRAIN_STEP_START
        loss, metrics = self.train_one_step(data)
        self.train_state = enums.TrainingState.TRAIN_STEP_END
        # undo the per-step scaling by accumulation_steps so the logged
        # running average is on the original loss scale
        losses.update(loss.item() * self.accumulation_steps, data_loader.batch_size)
        if b_idx == 0:
            metrics_meter = {k: AverageMeter() for k in metrics}
        monitor = {}
        for m_m in metrics_meter:
            metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
            monitor[m_m] = metrics_meter[m_m].avg
        self.current_train_step += 1
        if not self.using_tpu:
            tk0.set_postfix(loss=losses.avg, stage="train", **monitor)
        if self.using_tpu:
            print(f"train step: {self.current_train_step} loss: {losses.avg}")
    if not self.using_tpu:
        tk0.close()
    self.update_metrics(losses=losses, monitor=monitor)
    return losses.avg
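For context, these two epoch methods would typically be driven by an outer fit loop. The sketch below is only illustrative; the method name fit, the loader arguments, n_epochs and the current_epoch attribute are assumptions, not taken from the code above:

# Hypothetical driver loop; fit, train_loader, valid_loader, n_epochs and
# current_epoch are illustrative names, not part of the excerpt above.
def fit(self, train_loader, valid_loader, n_epochs):
    for epoch in range(n_epochs):
        self.current_epoch = epoch
        train_loss = self.train_one_epoch(train_loader)
        valid_loss = self.validate_one_epoch(valid_loader)
        print(f"epoch={epoch} train_loss={train_loss:.4f} valid_loss={valid_loss:.4f}")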
# Simplified variant of train_one_epoch: no TPU branch and no gradient accumulation
def train_one_epoch(self, data_loader):
    self.train()
    self.model_state = enums.ModelState.TRAIN
    losses = AverageMeter()
    tk0 = tqdm(data_loader, total=len(data_loader))
    for b_idx, data in enumerate(tk0):
        self.train_state = enums.TrainingState.TRAIN_STEP_START
        loss, metrics = self.train_one_step(data)
        self.train_state = enums.TrainingState.TRAIN_STEP_END
        losses.update(loss.item(), data_loader.batch_size)
        if b_idx == 0:
            metrics_meter = {k: AverageMeter() for k in metrics}
        monitor = {}
        for m_m in metrics_meter:
            metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
            monitor[m_m] = metrics_meter[m_m].avg
        self.current_train_step += 1
        tk0.set_postfix(loss=losses.avg, stage="train", **monitor)
    tk0.close()
    self.update_metrics(losses=losses, monitor=monitor)
    return losses.avg
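Both variants of train_one_epoch delegate the forward and backward pass to train_one_step, which is not shown here. One plausible shape of that step, assuming a model_fn helper that returns output, loss and metrics for a batch, is sketched below; the real implementation may also handle gradient accumulation, schedulers or mixed precision:

# Assumed sketch of a single training step; train_one_step is not shown in
# this excerpt, and model_fn is a hypothetical helper returning
# (output, loss, metrics) for a batch.
def train_one_step(self, data):
    self.optimizer.zero_grad()
    _, loss, metrics = self.model_fn(data)
    loss.backward()
    self.optimizer.step()
    return loss, metrics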