    def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True) -> float:
        """Run one optimizer step over ``accumulation_steps`` accumulated mini-batches and return the mean loss."""
        self.model.train()
        total_loss = 0.0

        self.optimizer.zero_grad()
        for i in range(accumulation_steps):
            inputs, labels = next(train_iter)
            inputs, labels = copy_to_device([inputs, labels], device=self.device, non_blocking=non_blocking_transfer)

            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)

            # Average the loss over the accumulation steps so the accumulated
            # gradients match those of a single larger batch
            loss /= accumulation_steps

            # Backward pass
            if self.amp and hasattr(self.optimizer, "_amp_stash"):
                # ``_amp_stash`` indicates the optimizer was prepared with NVIDIA apex, so scale the
                # loss with apex and delay gradient unscaling until the last accumulation step.
                # For this minor performance optimization, see also:
                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
                from apex import amp as apex_amp  # type: ignore  # apex is assumed to be installed when this branch is reached

                delay_unscale = ((i + 1) % accumulation_steps) != 0

                with apex_amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            total_loss += loss.item()

        # Apply the accumulated gradients in a single optimizer step
        self.optimizer.step()

        return total_loss
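
    # Illustrative only: a minimal sketch of how a learning-rate range-test loop might drive
    # ``_train_batch`` and ``_validate`` (the driver loop, ``lr_schedule``, and ``num_iter``
    # are assumptions, not shown in this section):
    #
    #     for iteration in range(num_iter):
    #         loss = self._train_batch(train_iter, accumulation_steps)
    #         if val_iter is not None:
    #             loss = self._validate(val_iter)
    #         lr_schedule.step()  # advance to the next candidate learning rate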

    def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = True) -> float:
        """Evaluate the model on ``val_iter`` and return the loss averaged over the validation samples."""
        running_loss = 0.0
        # ``eval_mode`` sets the model to evaluation mode and disables gradient computation
        with eval_mode(self.model):
            for inputs, labels in val_iter:
                # Copy data to the correct device
                inputs, labels = copy_to_device(
                    [inputs, labels], device=self.device, non_blocking=non_blocking_transfer
                )

                # Forward pass and loss computation
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                # Weight the batch loss by the batch size so the final result is a per-sample average
                running_loss += loss.item() * len(labels)

        return running_loss / len(val_iter.dataset)
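
    # The apex branch in ``_train_batch`` above scales the loss with apex's ``amp.scale_loss``.
    # Below is a minimal sketch of the same gradient-accumulation pattern using PyTorch's native
    # AMP instead; the method itself and ``self.grad_scaler`` (a ``torch.cuda.amp.GradScaler``)
    # are illustrative assumptions, not part of the original class.
    def _train_batch_native_amp_sketch(
        self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True
    ) -> float:
        self.model.train()
        total_loss = 0.0

        self.optimizer.zero_grad()
        for _ in range(accumulation_steps):
            inputs, labels = next(train_iter)
            inputs, labels = copy_to_device([inputs, labels], device=self.device, non_blocking=non_blocking_transfer)

            # Run the forward pass and loss computation in mixed precision
            with torch.cuda.amp.autocast():
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels) / accumulation_steps

            # Scale the loss before backward; gradients remain scaled until ``step``
            self.grad_scaler.scale(loss).backward()
            total_loss += loss.item()

        # Unscale gradients, skip the step if any are inf/nan, then update the scale factor
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

        return total_loss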