コード例 #1
0
ファイル: example.py プロジェクト: rivol/delve
    N, D_in, H, D_out = 64, 1000, h, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    model = TwoLayerNet(D_in, H, D_out)

    x, y, model = x.to(device), y.to(device), model.to(device)

    layers = [model.linear1, model.linear2]
    stats = CheckLayerSat('regression/h{}'.format(h), layers)

    loss_fn = torch.nn.MSELoss(size_average=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    steps_iter = trange(2000, desc='steps', leave=True, position=0)
    steps_iter.write("{:^80}".format(
        "Regression - TwoLayerNet - Hidden layer size {}".format(h)))
    for _ in steps_iter:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        steps_iter.set_description('loss=%g' % loss.data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        stats.saturation()
    steps_iter.write('\n')
    stats.close()
    steps_iter.close()
コード例 #2
0
def train(network, dataset, test_set, logging_dir, batch_size):
    """Train ``network`` for 21 epochs while recording layer saturation.

    Every epoch runs one pass over ``dataset`` with Adam and cross-entropy
    loss, asks the ``CheckLayerSat`` tracker for saturation, evaluates on
    ``test_set`` via ``test`` and appends the epoch metrics to the CSV log.

    Args:
        network: The model to optimise; moved to the module-level ``device``.
        dataset: Iterable of ``(inputs, labels)`` batches that supports ``len``.
        test_set: Evaluation data handed unchanged to ``test``.
        logging_dir: Target passed to ``CheckLayerSat`` and ``log_to_csv``.
        batch_size: Unused; kept so existing callers keep working.

    Returns:
        The ``nn.CrossEntropyLoss`` instance used for training.

    Raises:
        ZeroDivisionError: If ``dataset`` yields no samples in an epoch.
    """
    network.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(network.parameters())
    stats = CheckLayerSat(logging_dir,
                          network,
                          log_interval=60,
                          sat_method='cumvar99',
                          conv_method='mean')

    value_dict = None
    try:
        for epoch in range(21):
            print('Start Training Epoch', epoch, '\n')
            start = t.time()
            train_loss = 0
            total = 0
            correct = 0
            network.train()
            # Loop-invariant within an epoch, only needed for the progress line.
            n_batches = len(dataset)
            for i, (inputs, labels) in enumerate(dataset):
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                outputs = network(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                print(i, 'of', n_batches, 'acc:', correct / total)
            end = t.time()

            # Record layer saturation once per epoch, before evaluation.
            stats.saturation()
            test_loss, test_acc = test(network, test_set, criterion, stats,
                                       epoch)
            epoch_acc = correct / total
            # train_loss sums per-batch losses and is divided by sample count.
            mean_loss = train_loss / total
            print('Epoch', epoch, 'finished', 'Acc:', epoch_acc, 'Loss:',
                  mean_loss, '\n')
            stats.add_scalar('train_loss', mean_loss, epoch)  # optional
            stats.add_scalar('train_acc', epoch_acc, epoch)  # optional
            value_dict = record_metrics(value_dict, stats.logs, epoch_acc,
                                        mean_loss, test_acc, test_loss,
                                        epoch, (end - start) / total)
            log_to_csv(value_dict, logging_dir)
    finally:
        # Close the tracker even when an epoch fails part-way through.
        stats.close()

    return criterion
コード例 #3
0
class Trainer:
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: Union[tuple, None] = None,
        start_scratch: bool = False,
    ):
        """
        Class which implements network training, validation and testing as well as writing checkpoints, logs, summaries, and saving the final model.

        :param nn.Module model: the network to train
        :param Logger logger: logger used for all status messages
        :param str prefix: run prefix; prepended to the run name and passed to the checkpoint handler
        :param Union[str, None] checkpoint_dir: the type is either str or None (default: None); when None no checkpoint handler is created
        :param Union[str, None] summary_dir: base directory for summaries; when None ``self.summary_dir`` is None
        :param int n_summaries: number of images as samples at different phases to visualize on tensorboard
        :param tuple input_shape: when given, a random tensor of this shape is used to write the model graph
        :param bool start_scratch: when True, ``fit`` does not restore the latest checkpoint
        """
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        # fit() reads self.summary_dir unconditionally, so the attribute must
        # exist even when no summary directory was requested.
        self.summary_dir = None
        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=os.uname()[1],
            )
            self.summary_dir = os.path.join(summary_dir, run_name)

        self.n_summaries = n_summaries
        # NOTE(review): the writer is given the base `summary_dir`, not the
        # per-run `self.summary_dir` computed above - confirm this is intended.
        self.writer = SummaryWriter(summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(checkpoint_dir,
                                        prefix=prefix,
                                        logger=self.logger)
        else:
            self.cp = None

        self.start_scratch = start_scratch

    def fit(
        self,
        train_dataloader,
        val_dataloader,
        train_ds,
        val_ds,
        loss_fn,
        optimizer,
        n_epochs,
        val_interval,
        patience_early_stopping,
        device,
        metrics: Union[list, dict, None] = None,
        val_metric: Union[int, str] = "loss",
        val_metric_mode: str = "min",
        start_epoch=0,
    ):
        """
        train and validate the networks

        :param int n_epochs: max_train_epochs (default=500)
        :param int val_interval: run validation every val_interval number of epoch (ARGS.patience_early_stopping)
        :param int patience_early_stopping: after (patience_early_stopping/val_interval) number of epochs without improvement, terminate training
        :param device: device the model is moved to before training
        :param metrics: list or dict of metric objects exposing ``get()``; indexed with ``val_metric`` when ``val_metric`` is not "loss"
        :param val_metric: "loss" to compare on the validation loss, otherwise an index/key into ``metrics``
        :param str val_metric_mode: "min" if smaller is better, anything else means larger is better
        :param int start_epoch: first epoch to run; overwritten when a checkpoint is restored
        :return: ``self.model`` with the weights of the last trained epoch (the best weights are only kept in the checkpoint)
        :raises SystemExit: when restoring the model state fails or any exception occurs during training
        """
        if metrics is None:
            metrics = []

        self.logger.info("Init model on device '{}'".format(device))
        self.model = self.model.to(device)

        # initialize delve
        self.tracker = CheckLayerSat(self.summary_dir,
                                     save_to="plotcsv",
                                     modules=self.model,
                                     device=device)

        best_model = copy.deepcopy(self.model.state_dict())
        best_metric = 0.0 if val_metric_mode == "max" else float("inf")

        # as we don't validate after each epoch but at val_interval,
        # we update the patience_stopping accordingly to how many times of validation
        patience_stopping = math.ceil(patience_early_stopping / val_interval)
        patience_stopping = int(max(1, patience_stopping))
        early_stopping = EarlyStoppingCriterion(mode=val_metric_mode,
                                                patience=patience_stopping)

        if not self.start_scratch and self.cp is not None:
            checkpoint = self.cp.read_latest()
            if checkpoint is not None:
                try:
                    try:
                        self.model.load_state_dict(checkpoint["modelState"])
                    except RuntimeError as e:
                        self.logger.error(
                            "Failed to restore checkpoint: "
                            "Checkpoint has different parameters")
                        self.logger.error(e)
                        raise SystemExit

                    train_state = checkpoint["trainState"]
                    optimizer.load_state_dict(train_state["optState"])
                    start_epoch = train_state["epoch"] + 1
                    best_metric = train_state["best_metric"]
                    best_model = train_state["best_model"]
                    early_stopping.load_state_dict(
                        train_state["earlyStopping"])
                    self.logger.info(
                        "Resuming with epoch {}".format(start_epoch))
                except KeyError:
                    self.logger.error("Failed to restore checkpoint")
                    raise

        since = time.time()

        self.logger.info("Start training model " + self.prefix)

        try:
            # to run standard operator as function
            val_comp = operator.lt if val_metric_mode == "min" else operator.gt
            for epoch in range(start_epoch, n_epochs):
                self.train(epoch, train_dataloader, train_ds, loss_fn,
                           optimizer, device)

                if epoch % val_interval == 0 or epoch == n_epochs - 1:
                    # first, get val_loss for further comparison
                    val_loss = self.validate(epoch,
                                             val_dataloader,
                                             val_ds,
                                             loss_fn,
                                             device,
                                             phase="val")
                    if val_metric == "loss":
                        val_result = val_loss
                        # add metrics for delve to keep track of
                        self.tracker.add_scalar("loss", val_loss)
                        # add saturation to the mix
                        self.tracker.add_saturations()
                    else:
                        val_result = metrics[val_metric].get()

                    # compare to see if improvement occurs
                    if val_comp(val_result, best_metric):
                        best_metric = val_result  # update best_metric with the new best result
                        best_model = copy.deepcopy(self.model.state_dict())
                        # previously, deadlock occurred, which seemed to be
                        # related to cp. comment self.cp.write() to see if
                        # freezing goes away.
                        # Only write a checkpoint when a handler exists;
                        # without checkpoint_dir self.cp is None.
                        if self.cp is not None:
                            self.cp.write({
                                "modelState": self.model.state_dict(),
                                "trainState": {
                                    "epoch": epoch,
                                    "best_metric": best_metric,
                                    "best_model": best_model,
                                    "optState": optimizer.state_dict(),
                                    "earlyStopping":
                                    early_stopping.state_dict(),
                                },
                            })

                    # test if the number of accumulated no-improvement epochs is bigger than patience
                    if early_stopping.step(val_result):
                        self.logger.info(
                            "No improvement over the last {} epochs. Training is stopped."
                            .format(patience_early_stopping))
                        break
        except Exception:
            import traceback
            self.logger.warning(traceback.format_exc())
            self.logger.warning("Aborting...")
            self.logger.close()
            raise SystemExit

        # option here: load the best model to run test on test_dataset and log the final metric (along side best metric)
        # for ae, only split: train and validate dataset, without test_dataset

        time_elapsed = time.time() - since
        self.logger.info("Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60))

        self.logger.info("Best val metric: {:4f}".format(best_metric))

        # close delve tracker
        self.tracker.close()

        return self.model

    def train(self, epoch, train_dataloader, train_ds, loss_fn, optimizer,
              device):
        """
        Run one training epoch and return its mean reconstruction loss.

        Each batch is moved to ``device``, passed through the model and
        optimised against ``loss_fn(outputs, inputs)``. Image summaries are
        written for every second batch and scalar summaries once at the end.

        :param epoch: index of the current epoch (used for logging)
        :param train_dataloader: yields ``(spectrogram_batch, label_dict)`` pairs
        :param train_ds: training dataset; only its length is used
        :param loss_fn: callable ``loss_fn(outputs, targets)``
        :param optimizer: optimizer updating the model parameters
        :param device: device the batches are moved to
        :return: summed per-sample loss divided by ``len(train_ds)``
        """
        self.logger.debug("train|{}|start".format(epoch))

        self.model.train()

        epoch_start = time.time()
        wait_started = epoch_start
        data_loading_time = m.Sum(torch.device("cpu"))

        loss_total = 0.0
        for batch_idx, (train_specs, label) in enumerate(train_dataloader):
            train_specs = train_specs.to(device)

            # Label tensors are still transferred as before, although the
            # current reconstruction loss does not consume them.
            call_label = None
            if "call" in label:
                call_label = label["call"].to(device,
                                              non_blocking=True,
                                              dtype=torch.int64)
            if "ground_truth" in label:
                ground_truth = label["ground_truth"].to(device,
                                                        non_blocking=True)

            waited = time.time() - wait_started
            data_loading_time.update(torch.Tensor([waited]))

            optimizer.zero_grad()

            # forward pass: reconstruct the inputs
            outputs = self.model(train_specs)

            # without augmentation the input itself is the target
            # (with augmentation: loss_fn(outputs, ground_truth))
            loss = loss_fn(outputs, train_specs)

            # backward pass followed by the parameter update
            loss.backward()
            optimizer.step()

            # loss.item() is the batch mean; scale by the batch size so the
            # epoch total can be averaged over the dataset afterwards
            loss_total += loss.item() * train_specs.size(0)

            if batch_idx % 2 == 0:
                self.write_summaries(
                    features=train_specs,
                    reconstructed=outputs,
                    file_names=label["file_name"],
                    epoch=epoch,
                    phase="train",
                )
            wait_started = time.time()

        # mean loss over the whole training set
        train_epoch_loss = loss_total / len(train_ds)

        self.write_scalar_summaries_logs(
            loss=train_epoch_loss,
            lr=optimizer.param_groups[0]["lr"],
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase="train",
        )

        self.writer.flush()

        return train_epoch_loss

    def validate(self,
                 epoch,
                 val_dataloader,
                 val_ds,
                 loss_fn,
                 device,
                 phase="val"):
        """
        Evaluate the model on the validation data without gradient tracking.

        :param epoch: index of the current epoch (used for logging)
        :param val_dataloader: yields ``(spectrogram_batch, label_dict)`` pairs
        :param val_ds: validation dataset; only its length is used
        :param loss_fn: callable ``loss_fn(outputs, targets)``
        :param device: device the batches are moved to
        :param phase: tag prefix used for logs and summaries
        :return: summed per-sample loss divided by ``len(val_ds)``
        """
        self.logger.debug("{}|{}|start".format(phase, epoch))
        self.model.eval()

        loss_total = 0.0
        with torch.no_grad():
            epoch_start = time.time()
            wait_started = epoch_start
            data_loading_time = m.Sum(torch.device("cpu"))

            for batch_idx, (val_specs, label) in enumerate(val_dataloader):
                # image summaries are written for every second batch only
                log_images = batch_idx % 2 == 0

                val_specs = val_specs.to(device)
                if "call" in label:
                    # transferred as before; not consumed by the loss below
                    call_label = label["call"].to(device,
                                                  non_blocking=True,
                                                  dtype=torch.int64)

                waited = time.time() - wait_started
                data_loading_time.update(torch.Tensor([waited]))

                # the single-channel tensors are logged directly instead of
                # being converted to colour images first
                if log_images:
                    self.writer.add_images("Original", val_specs, epoch)

                outputs = self.model(val_specs)

                if log_images:
                    self.writer.add_images("Reconstructed", outputs, epoch)

                loss = loss_fn(outputs, val_specs)
                loss_total += loss.item() * val_specs.size(0)

                if log_images:
                    self.write_summaries(
                        features=val_specs,  # original
                        reconstructed=outputs,
                        file_names=label["file_name"],
                        epoch=epoch,
                        phase=phase,
                    )
                wait_started = time.time()

            val_epoch_loss = loss_total / len(val_ds)

            self.write_scalar_summaries_logs(
                loss=val_epoch_loss,
                epoch_time=time.time() - epoch_start,
                data_loading_time=data_loading_time.get(),
                epoch=epoch,
                phase=phase,
            )

            self.writer.flush()

            return val_epoch_loss

    def write_summaries(
        self,
        features,
        #labels=None, #  tensor([True, True, True, True, True, True])
        #prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        #"""Writes image summary per partition (spectrograms and the corresponding predictions)"""
        """Writes image summary per partition (spectrograms and reconstructed)"""

        with torch.no_grad():
            self.write_img_summaries(
                features,
                #labels=labels,
                #prediction=prediction,
                reconstructed=reconstructed,
                file_names=file_names,
                epoch=epoch + 1,
                phase=phase,
            )

    def write_img_summaries(
        self,
        features,
        #labels=None,
        #prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        """
        Writes image summary per partition with respect to the prediction output (true predictions - true positive/negative, false
        predictions - false positive/negative)
        """

        with torch.no_grad():
            if file_names is not None:
                if isinstance(file_names, torch.Tensor):
                    file_names = file_names.cpu().numpy()
                elif isinstance(file_names, list):
                    file_names = np.asarray(file_names)
            #if labels is not None and prediction is not None:
            if reconstructed is not None:
                features = features.cpu()
                #labels = labels.cpu()
                #prediction = prediction.cpu()
                reconstructed = reconstructed.cpu()

                self.writer.add_images(
                    tag=phase + "/input",
                    img_tensor=features[:self.n_summaries],
                    #img_tensor=prepare_img(
                    #    features, num_images=self.n_summaries, file_names=file_names
                    #),
                    global_step=epoch,
                )

                self.writer.add_images(
                    tag=phase + "/reconstructed",
                    img_tensor=reconstructed[:self.n_summaries],
                    # img_tensor=prepare_img(
                    #    features, num_images=self.n_summaries, file_names=file_names
                    # ),
                    global_step=epoch,
                )
                """ below are needed to visualize true positive/negative examples"""
                """for label in torch.unique(labels): #  tensor(1, device='cuda:0')
                    label = label.item() # Returns the value of this tensor as a standard Python number: 1
                    l_i = torch.eq(labels, label)

                    t_i = torch.eq(prediction, label) * l_i
                    name_t = "true_{}".format("positive" if label else "negative")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_t,
                            img_tensor=prepare_img(
                                features[t_i],
                                num_images=self.n_summaries,
                                file_names=file_names[t_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass

                    f_i = torch.ne(prediction, label) * l_i
                    name_f = "false_{}".format("negative" if label else "positive")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_f,
                            img_tensor=prepare_img(
                                features[f_i],
                                num_images=self.n_summaries,
                                file_names=file_names[f_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass
            else:
                self.writer.add_image(
                    tag=phase + "/input",
                    img_tensor=prepare_img(
                        features, num_images=self.n_summaries, file_names=file_names
                    ),
                    global_step=epoch,
                )"""

    """
    Writes scalar summary per partition including loss, confusion matrix, accuracy, recall, f1-score, true positive rate,
    false positive rate, precision, data_loading_time, epoch time
    """

    def write_scalar_summaries_logs(
        self,
        loss: float,
        metrics: Union[list, dict] = [],
        lr: float = None,
        epoch_time: float = None,
        data_loading_time: float = None,
        epoch=None,
        phase="train",
    ):
        with torch.no_grad():
            log_str = phase
            if epoch is not None:
                log_str += "|{}".format(epoch)
            self.writer.add_scalar(phase + "/epoch_loss", loss, epoch)
            log_str += "|loss:{:0.3f}".format(loss)
            if isinstance(metrics, dict):
                for name, metric in metrics.items():
                    self.writer.add_scalar(phase + "/" + name, metric.get(),
                                           epoch)
                    log_str += "|{}:{:0.3f}".format(name, metric.get())
            else:
                for i, metric in enumerate(metrics):
                    self.writer.add_scalar(phase + "/metric_" + str(i),
                                           metric.get(), epoch)
                    log_str += "|m_{}:{:0.3f}".format(i, metric.get())
            if lr is not None:
                self.writer.add_scalar("lr", lr, epoch)
                log_str += "|lr:{:0.2e}".format(lr)
            if epoch_time is not None:
                self.writer.add_scalar(phase + "/time", epoch_time, epoch)
                log_str += "|t:{:0.1f}".format(epoch_time)
            if data_loading_time is not None:
                self.writer.add_scalar(phase + "/data_loading_time",
                                       data_loading_time, epoch)
            self.logger.info(log_str)
コード例 #4
0
class Trainer:
    """The trainer object handles the actual training and testing of the model.

    Args:
        model:              The PyTorch-Model
        data_bundle:        The training and test data as a DataBundle
        optimizer_bundle:   Contains the optimizer and the learning rate scheduler
        run_id:             A string identifying the specific run.
        batch_size:         The batch_size to train the model on.
        epochs:             The (total) number of epochs to train the model.
        criterion:          The optimization criterion (loss), default is cross-entropy
        metrics:            A list of metric-object
        device:             The compute device or list of compute devices to place the model(s) on
        logs_dir:           The directory to store the results
        conv_method:        The strategy for handling convolutional layers for saturation computation
        device_sat:         The device to compute the saturation on. If None, the same device is used as for
                            the model.
        delta:              The delta threshold for computing saturation
        data_parallel:      Enable or Disable multi-GPU
        downsampling:       If None, downsampling is disabled, else the feature maps will be downsampled
                            to (downsampling x downsampling) resolution
    """

    # private internal variables
    _tracker: CheckLayerSat = attrib(init=False)
    _save_path: str = attrib(init=False)
    _initial_epoch: int = attrib(init=False)
    _trained_epochs: int = attrib(init=False)
    _experiment_done: bool = attrib(init=False)

    # General Training setup
    model: Module
    data_bundle: DataBundle
    optimizer_bundle: OptimizerSchedulerBundle
    run_id: str
    batch_size: int = 32
    epochs: int = 30
    criterion: nn.modules.loss._Loss = nn.modules.CrossEntropyLoss()
    metrics: List[Metric] = attrib(factory=list)

    # Technical Setup
    device: str = 'cpu'
    logs_dir: str = './logs'

    # delve setup
    conv_method = 'channelwise'
    device_sat: Optional[str] = None
    delta: float = 0.99
    data_parallel: bool = False
    downsampling: Optional[int] = None

    def _initialize_tracker(self):
        """Create the ``CheckLayerSat`` tracker and its CSV/plot writer.

        Both write to ``_save_path`` with the ``.csv`` suffix removed. The
        'nearest' interpolation strategy is only set when ``downsampling`` is
        configured.
        """
        log_target = self._save_path.replace('.csv', '')
        csv_writer = CSVandPlottingWriter(log_target,
                                          primary_metric='test_accuracy')

        if self.downsampling is None:
            strategy = None
        else:
            strategy = 'nearest'

        self._tracker = CheckLayerSat(
            self._save_path.replace('.csv', ''),
            [csv_writer],
            self.model,
            ignore_layer_names='convolution',
            stats=['lsat', 'idim'],
            sat_threshold=self.delta,
            verbose=False,
            conv_method=self.conv_method,
            log_interval=1,
            device=self.device_sat,
            reset_covariance=True,
            max_samples=None,
            initial_epoch=self._initial_epoch,
            interpolation_strategy=strategy,
            interpolation_downsampling=self.downsampling,
        )

    def _initialize_saving_structure(self):
        """Build the log directory and derive the ``.csv`` save path.

        The file name encodes model name, dataset name, output resolution,
        batch size and number of epochs.
        """
        bundle = self.data_bundle
        model_name = self.model.name

        save_dir: str = build_saving_structure(
            logs_dir=self.logs_dir,
            model_name=model_name,
            dataset_name=bundle.dataset_name,
            output_resolution=bundle.output_resolution,
            run_id=self.run_id)

        file_name = (f"{self.model.name}-{bundle.dataset_name}"
                     f"-r{bundle.output_resolution}"
                     f"-bs{self.batch_size}-e{self.epochs}.csv")
        self._save_path = os.path.join(save_dir, file_name)

    def _load_model(self):
        """Restore the model weights from the ``.pt`` checkpoint.

        The checkpoint is mapped to CPU while loading; afterwards the model is
        moved to ``self.device``.
        """
        checkpoint_path = self._save_path.replace('.csv', '.pt')
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)

    def _load_optimizer_and_scheduler(self):
        """Restore optimizer and, if present, scheduler state from the checkpoint.

        The ``.pt`` file next to ``_save_path`` is deserialised once and both
        state dicts are taken from that single load.
        """
        checkpoint = torch.load(self._save_path.replace('.csv', '.pt'))
        self.optimizer_bundle.optimizer.load_state_dict(
            checkpoint['optimizer'])
        if self.optimizer_bundle.scheduler is not None:
            self.optimizer_bundle.scheduler.load_state_dict(
                checkpoint['scheduler'])

    def _load_initial_and_trained_epoch(self):
        """Read the last saved epoch index and set the epoch to resume from."""
        checkpoint_path = self._save_path.replace('.csv', '.pt')
        checkpoint = torch.load(checkpoint_path)
        self._trained_epochs = checkpoint['epoch']
        # Resume with the epoch following the stored one.
        self._initial_epoch = self._trained_epochs + 1

    def _check_training_done(self):
        """Flag the experiment as done when no epochs are left to train."""
        if self._initial_epoch < self.epochs:
            return
        self._experiment_done = True
        print('Experiment Logs for the exact same experiment with identical '
              'run_id was detected, training will be skipped, consider using '
              'another run_id')

    def _checkpointing(self):
        """Reset the run state, place the model and resume from a prior run.

        When ``data_parallel`` is set the model is wrapped in
        ``nn.DataParallel`` over ``cuda:0`` and ``cuda:1``. If ``_save_path``
        already exists, the epoch counters, model, optimizer and scheduler are
        restored from the ``.pt`` checkpoint.
        """
        self._initial_epoch, self._trained_epochs = 0, 0
        self._experiment_done = False

        if self.data_parallel:
            print("Enabling multi gpu")
            self.model = nn.DataParallel(self.model,
                                         device_ids=["cuda:0", "cuda:1"],
                                         output_device=self.device)
        self.model = self.model.to(self.device)

        # NOTE(review): existence is checked on the .csv path, while the
        # loaders below read the .pt file - confirm both are always written.
        if not os.path.exists(self._save_path):
            return

        self._load_initial_and_trained_epoch()
        self._check_training_done()
        self._load_model()
        self._load_optimizer_and_scheduler()
        checkpoint_path = self._save_path.replace('.csv', '.pt')
        print('Resuming existing run, starting at epoch',
              self._initial_epoch + 1, 'from', checkpoint_path)

    def _enable_benchmark_mode_if_cuda(self):
        """Enable the cuDNN benchmark flag when ``device`` names a CUDA device."""
        if "cuda" not in self.device:
            return
        from torch.backends import cudnn
        cudnn.benchmark = True

    def __attrs_post_init__(self):
        """Finish construction: devices, save path, checkpoint and tracker."""
        # Saturation is computed on the model device unless set explicitly.
        if self.device_sat is None:
            self.device_sat = self.device
        self._enable_benchmark_mode_if_cuda()
        self._initialize_saving_structure()
        # Checkpointing sets _initial_epoch, which the tracker setup reads.
        self._checkpointing()
        self._initialize_tracker()

    def _reset_metrics(self):
        """Call ``reset()`` on every configured metric."""
        for tracked in self.metrics:
            tracked.reset()

    def _eval_metrics(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        """Feed one batch of targets and model outputs to every metric."""
        for tracked in self.metrics:
            tracked.update(y_true, y_pred)

    def _update_pbar_postfix(self, pbar: tqdm):
        """Show the current metric values, rounded to 3 digits, on ``pbar``."""
        postfix = {}
        for tracked in self.metrics:
            postfix[tracked.name] = round(tracked.value, 3)
        pbar.set_postfix(postfix)

    def _print_status(self, batch: int, old_time: int, dataset: DataLoader):
        """Print batch progress, elapsed time since ``old_time`` and metrics."""
        summaries = []
        for tracked in self.metrics:
            summaries.append(f"{tracked.name}:  {round(tracked.value, 3)}")
        n_batches = len(dataset)
        elapsed = round(time() - old_time, 3)
        print(batch, 'of', n_batches, 'processing time', elapsed, *summaries)

    def _print_epoch_status(self, epoch: int, old_time: int,
                            metric_dict: Dict[str, float]):
        """Print the 1-based epoch, elapsed time since ``old_time`` and metrics."""
        summaries = []
        for key, value in metric_dict.items():
            summaries.append(f"{key}:  {round(value, 3)}")
        elapsed = round(time() - old_time, 3)
        print(epoch + 1, 'of', self.epochs, 'processing time', elapsed,
              *summaries)

    def _track_results(self, prefix: str, metric_name: str,
                       metric_value: float) -> Tuple[str, float]:
        """Record a scalar as ``<prefix>_<metric_name>`` on the tracker.

        Returns:
            The composed scalar name together with the unchanged value.
        """
        scalar_name = f"{prefix}_{metric_name}"
        self._tracker.add_scalar(scalar_name, metric_value)
        return scalar_name, metric_value

    def _track_metrics(self, prefix: str, loss: float,
                       total: int) -> Dict[str, float]:
        """Record all metric values and the averaged loss on the tracker.

        Args:
            prefix: Prepended to every scalar name, e.g. 'training' or 'test'.
            loss:   The accumulated loss.
            total:  The divisor used to average ``loss``.

        Returns:
            A mapping from the recorded scalar names to their values.
        """
        tracked: Dict[str, float] = dict(
            self._track_results(prefix, metric.name, metric.value)
            for metric in self.metrics)
        loss_name, loss_value = self._track_results(prefix, "loss",
                                                    loss / total)
        tracked[loss_name] = loss_value
        return tracked

    def _save_checkpoint(self, train_metric: Dict[str, float],
                         test_metric: Dict[str, float], epoch: int):
        """Write model, optimizer, scheduler, epoch and metrics to the ``.pt`` file.

        The metric dictionaries are merged into the top level of the saved
        dictionary; the scheduler entry is None when no scheduler is used.
        """
        model_state = self.model.state_dict()
        optimizer_state = self.optimizer_bundle.optimizer.state_dict()
        scheduler = self.optimizer_bundle.scheduler
        if scheduler is None:
            scheduler_state = None
        else:
            scheduler_state = scheduler.state_dict()

        checkpoint = {
            'model_state_dict': model_state,
            'optimizer': optimizer_state,
            'scheduler': scheduler_state,
            'epoch': epoch,
            **train_metric,
            **test_metric,
        }
        checkpoint_path = self._save_path.replace('.csv', '.pt')
        torch.save(checkpoint, checkpoint_path)

    def train(self):
        """Train the model.

        The model is trained for a total number of epochs given the number of epochs provided in the constructor.
        This includes epochs this model was trained previously.

        After every epoch the model is evaluated on the test set, the
        scheduler (if any) is stepped, saturation is recorded and a checkpoint
        is written.

        Returns:
            The path to the saturation and metric logs, or None when the
            experiment was already completed and training is skipped.
        """
        if self._experiment_done:
            return
        old_time = time()
        for epoch in range(self._initial_epoch, self.epochs):
            print('Start training epoch', epoch + 1)
            train_metric = self.train_epoch()
            test_metric = self.test()
            # Merge so the status line shows training and test metrics.
            train_metric.update(test_metric)
            self._print_epoch_status(epoch=epoch,
                                     old_time=old_time,
                                     metric_dict=train_metric)
            old_time = time()

            if self.optimizer_bundle.scheduler is not None:
                self.optimizer_bundle.scheduler.step()
            self._tracker.add_saturations()
            self._save_checkpoint(train_metric=train_metric,
                                  test_metric=test_metric,
                                  epoch=epoch)
        self._tracker.close()
        # _save_path already carries the '.csv' suffix; appending it again
        # would produce a '...csv.csv' path.
        return self._save_path

    def train_epoch(self) -> Dict[str, float]:
        """Train a single epoch.

        The forward pass and loss computation run under
        ``torch.cuda.amp.autocast``. The progress bar postfix is refreshed
        every 10 batches.

        Returns:
            A dictionary containing all metrics computed incrementally during training.
        """
        self.model.train()
        self._reset_metrics()
        running_loss = 0
        total = 0
        pbar = tqdm(self.data_bundle.train_dataset)

        for batch, data in enumerate(pbar):
            if batch % 10 == 0 and batch != 0:
                self._update_pbar_postfix(pbar)

            inputs, labels = data
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer_bundle.optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                outputs = self.model(inputs)
                self._eval_metrics(labels, outputs)

                loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer_bundle.optimizer.step()

            running_loss += loss.item()
            # Count the samples actually seen (as test() does); the last
            # batch may hold fewer than batch_size samples.
            total += labels.size(0)
        return self._track_metrics('training', running_loss, total)

    def test(self):
        """Evaluate the model on the test set.

        Runs without gradient tracking. The progress bar postfix is refreshed
        every 10 batches and on the final batch.

        Returns:
            The metric computed on the test set.
        """
        self._reset_metrics()
        self.model.eval()
        total = 0
        test_loss = 0
        with torch.no_grad():
            pbar = tqdm(self.data_bundle.test_dataset)
            # Index of the final batch, needed for the last postfix refresh.
            last_batch = len(self.data_bundle.test_dataset) - 1

            for batch, data in enumerate(pbar):
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                total += labels.size(0)
                test_loss += loss.item()

                self._eval_metrics(labels, outputs)

                if batch % 10 == 0 or batch == last_batch:
                    self._update_pbar_postfix(pbar)

            test_metrics = self._track_metrics('test', test_loss, total)
        return test_metrics