def validate_one_epoch(self, data_loader):
    """Run one full validation pass over ``data_loader``.

    Puts the model in eval mode, evaluates each batch under
    ``torch.no_grad()``, accumulates the loss and per-step metrics in
    running averages, reports them via ``update_metrics``, and returns
    the average validation loss.

    :param data_loader: iterable of batches; must expose ``batch_size``
        (a torch ``DataLoader`` does — TODO confirm for custom loaders).
    :return: average validation loss (float).
    """
    self.eval()
    self.model_state = enums.ModelState.VALID
    losses = AverageMeter()
    # No tqdm progress bar on TPU — it interleaves badly with TPU output.
    if self.using_tpu:
        tk0 = data_loader
    else:
        tk0 = tqdm(data_loader, total=len(data_loader))
    # Initialize up front so an empty loader does not leave these
    # undefined (the original created them only at b_idx == 0, which
    # raised NameError on update_metrics for an empty loader).
    monitor = {}
    metrics_meter = {}
    for b_idx, data in enumerate(tk0):
        self.train_state = enums.TrainingState.VALID_STEP_START
        with torch.no_grad():
            loss, metrics = self.validate_one_step(data)
        self.train_state = enums.TrainingState.VALID_STEP_END
        losses.update(loss.item(), data_loader.batch_size)
        if b_idx == 0:
            # Metric keys are only known after the first step.
            metrics_meter = {k: AverageMeter() for k in metrics}
        for m_m in metrics_meter:
            metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
            monitor[m_m] = metrics_meter[m_m].avg
        if not self.using_tpu:
            tk0.set_postfix(loss=losses.avg, stage="valid", **monitor)
        self.current_valid_step += 1
    if not self.using_tpu:
        tk0.close()
    self.update_metrics(losses=losses, monitor=monitor)
    return losses.avg
def train_one_epoch(self, data_loader):
    """Run one full training pass over ``data_loader``.

    Accumulates the (accumulation-rescaled) loss and per-step metrics
    in running averages, reports them via ``update_metrics``, and
    returns the average training loss.

    :param data_loader: iterable of batches; must expose ``batch_size``.
    :return: average training loss (float).
    """
    self.train()
    self.model_state = enums.ModelState.TRAIN
    losses = AverageMeter()
    # With gradient accumulation, gradients must be cleared before the
    # first step; otherwise zero_grad is presumably handled per step in
    # train_one_step — TODO confirm.
    if self.accumulation_steps > 1:
        self.optimizer.zero_grad()
    # No tqdm progress bar on TPU — it interleaves badly with TPU output.
    if self.using_tpu:
        tk0 = data_loader
    else:
        tk0 = tqdm(data_loader, total=len(data_loader))
    # Initialize up front so an empty loader does not leave these
    # undefined (the original created them only at b_idx == 0, which
    # raised NameError on update_metrics for an empty loader).
    monitor = {}
    metrics_meter = {}
    for b_idx, data in enumerate(tk0):
        self.batch_index = b_idx
        self.train_state = enums.TrainingState.TRAIN_STEP_START
        loss, metrics = self.train_one_step(data)
        self.train_state = enums.TrainingState.TRAIN_STEP_END
        # Undo the per-step division applied for gradient accumulation
        # so the reported loss is on the original scale.
        losses.update(loss.item() * self.accumulation_steps, data_loader.batch_size)
        if b_idx == 0:
            # Metric keys are only known after the first step.
            metrics_meter = {k: AverageMeter() for k in metrics}
        for m_m in metrics_meter:
            metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
            monitor[m_m] = metrics_meter[m_m].avg
        self.current_train_step += 1
        if not self.using_tpu:
            tk0.set_postfix(loss=losses.avg, stage="train", **monitor)
        if self.using_tpu:
            print(
                f"train step: {self.current_train_step} loss: {losses.avg}"
            )
    if not self.using_tpu:
        tk0.close()
    self.update_metrics(losses=losses, monitor=monitor)
    return losses.avg
def train_one_epoch(self, data_loader):
    """Run one full training pass over ``data_loader``.

    Accumulates the loss and per-step metrics in running averages,
    shows them on a tqdm progress bar, reports them via
    ``update_metrics``, and returns the average training loss.

    :param data_loader: iterable of batches; must expose ``batch_size``.
    :return: average training loss (float).
    """
    self.train()
    self.model_state = enums.ModelState.TRAIN
    losses = AverageMeter()
    tk0 = tqdm(data_loader, total=len(data_loader))
    # Initialize up front so an empty loader does not leave these
    # undefined (the original created them only at b_idx == 0, which
    # raised NameError on update_metrics for an empty loader).
    monitor = {}
    metrics_meter = {}
    for b_idx, data in enumerate(tk0):
        self.train_state = enums.TrainingState.TRAIN_STEP_START
        loss, metrics = self.train_one_step(data)
        self.train_state = enums.TrainingState.TRAIN_STEP_END
        losses.update(loss.item(), data_loader.batch_size)
        if b_idx == 0:
            # Metric keys are only known after the first step.
            metrics_meter = {k: AverageMeter() for k in metrics}
        for m_m in metrics_meter:
            metrics_meter[m_m].update(metrics[m_m], data_loader.batch_size)
            monitor[m_m] = metrics_meter[m_m].avg
        self.current_train_step += 1
        tk0.set_postfix(loss=losses.avg, stage="train", **monitor)
    tk0.close()
    self.update_metrics(losses=losses, monitor=monitor)
    return losses.avg
def validate(self, data_loader):
    """Run one validation pass over ``data_loader``.

    Wraps each ``predict_step`` in ``torch.no_grad()``, folds losses
    and metrics into the running meters via ``_update_loss_metrics``,
    and fires the validation-epoch start/end hooks.

    :param data_loader: iterable of validation batches.
    """
    self._set_validation_epoch_start(data_loader)
    losses = AverageMeter()
    # Default so an empty loader still yields a defined `monitor` for
    # _set_validation_epoch_end (the original raised NameError there).
    monitor = {}
    for batch_index, data in enumerate(data_loader):
        self.valid_batch_index = batch_index
        self.train_state = enums.TrainingState.VALID_STEP_START
        with torch.no_grad():
            loss, metrics = self.predict_step(data)
        losses, monitor = self._update_loss_metrics(losses, loss, metrics)
        self.train_state = enums.TrainingState.VALID_STEP_END
    self._set_validation_epoch_end(losses, monitor)
def _update_loss_metrics(self, losses, loss, metrics):
    """Fold one step's loss and metrics into the running meters.

    On the first batch of an epoch the per-mode metric meters are
    (re)created. The loss meter is updated with the training loss
    rescaled for gradient accumulation, or with the validation loss
    gathered across devices and averaged. The step counter for the
    current mode is advanced and the updated meter plus the monitor
    dict from ``_update_monitor`` are returned.

    :raises ValueError: if the model state is neither TRAIN nor VALID.
    """
    state = self._model_state
    if state == enums.ModelState.TRAIN:
        if self.train_batch_index == 0:
            self.train_meter = {name: AverageMeter() for name in metrics}
        # Undo the per-step loss division used for gradient accumulation.
        scaled = loss.item() * self.config.gradient_accumulation_steps
        losses.update(scaled, self.train_loader_bs)
    elif state == enums.ModelState.VALID:
        if self.valid_batch_index == 0:
            self.valid_meter = {name: AverageMeter() for name in metrics}
        # Average the loss across all devices before recording it.
        gathered = self._driver.gather(loss).mean()
        losses.update(gathered.item(), self.valid_loader_bs)
    else:
        raise ValueError("Invalid model state")
    monitor = self._update_monitor(losses, metrics)
    if state == enums.ModelState.TRAIN:
        self._train_step += 1
    elif state == enums.ModelState.VALID:
        self._valid_step += 1
    else:
        raise ValueError("Invalid model state")
    return losses, monitor
def train(self, data_loader):
    """Run one training epoch over ``data_loader``.

    Executes ``train_step`` per batch, folds losses/metrics into the
    running meters, optionally runs batch-interval validation, and
    fires the training-epoch start/end hooks. Stops early if a
    validation callback moves the model state to END.

    :param data_loader: iterable of training batches.
    """
    self._set_training_epoch_start(data_loader)
    losses = AverageMeter()
    # Default so an empty loader still yields a defined `monitor` for
    # _set_training_epoch_end (the original raised NameError there).
    monitor = {}
    for batch_index, data in enumerate(data_loader):
        self.train_batch_index = batch_index
        self.train_state = enums.TrainingState.TRAIN_STEP_START
        loss, metrics = self.train_step(data)
        losses, monitor = self._update_loss_metrics(losses, loss, metrics)
        self.train_state = enums.TrainingState.TRAIN_STEP_END
        # Batch-interval validation: every val_steps steps, and always
        # on the final training step.
        if self.valid_loader and self.config.val_strategy == "batch":
            if (
                self._train_step % self.config.val_steps == 0
                or self._train_step == self.num_train_steps
            ):
                self.validate(self.valid_loader)
        # A callback (e.g. early stopping) may have ended the run.
        if self._model_state == enums.ModelState.END:
            break
    self._set_training_epoch_end(losses, monitor)