Example #1
0
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: tuple = None,
        start_scratch: bool = False,
    ):
        """
        Class which implements network training, validation and testing as well as writing checkpoints, logs, summaries, and saving the final model.

        :param Union[str, None] checkpoint_dir: the type is either str or None (default: None)
        :param int n_summaries: number of images as samples at different phases to visualize on tensorboard
        """
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=os.uname()[1],
            )
            summary_dir = os.path.join(summary_dir, run_name)
        # keep the per-run directory and point the summary writer at it
        self.summary_dir = summary_dir

        self.n_summaries = n_summaries
        self.writer = SummaryWriter(self.summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(checkpoint_dir,
                                        prefix=prefix,
                                        logger=self.logger)
        else:
            self.cp = None

        self.start_scratch = start_scratch
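
# Illustrative only: a minimal sketch of how the constructor above might be
# called (the surrounding class is named Trainer in the later examples). The
# module, logger, and directory values below are hypothetical placeholders;
# the Logger is the project's own wrapper (it also provides close(), used in
# the fit() variants further down).
#
#   model = MyNetwork()                      # any nn.Module
#   logger = Logger("train.log")             # project-specific logger (assumed)
#   trainer = Trainer(
#       model=model,
#       logger=logger,
#       prefix="experiment1",
#       checkpoint_dir="checkpoints",
#       summary_dir="summaries",
#       n_summaries=4,
#       input_shape=(1, 1, 128, 256),        # triggers writing the model graph to TensorBoard
#       start_scratch=False,
#   )
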
Example #2
0
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: tuple = None,
        start_scratch: bool = False,
    ):
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=platform.uname()[1],
            )
            summary_dir = os.path.join(summary_dir, run_name)

        self.n_summaries = n_summaries
        self.writer = SummaryWriter(summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(checkpoint_dir,
                                        prefix=prefix,
                                        logger=self.logger)
        else:
            self.cp = None

        self.start_scratch = start_scratch
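
# For illustration (hypothetical values): with summary_dir="summaries" and
# prefix="orca", the run directory built above resolves to something like
#   summaries/orca_24-05-17-13-42-myhost
# i.e. prefix, timestamp (%y-%m-%d-%H-%M), and hostname joined under summary_dir.
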
Example #3
0
class Trainer:
    """
    Initializes the summary writer and the checkpoint handler and sets the variables required for training.
    """
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: tuple = None,
        start_scratch: bool = False,
    ):
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=platform.uname()[1],
            )
            summary_dir = os.path.join(summary_dir, run_name)

        self.n_summaries = n_summaries
        self.writer = SummaryWriter(summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(checkpoint_dir,
                                        prefix=prefix,
                                        logger=self.logger)
        else:
            self.cp = None

        self.start_scratch = start_scratch

    """
    Starts network training from scratch or resumes from an existing checkpoint. The model is trained and validated for a given
    number of epochs, and all relevant information (metrics, summaries, logs, checkpoints) is stored after each epoch. Once training
    stops (either because the chosen validation metric has not improved for a given number of epochs or because the maximum number
    of epochs is reached), the model is evaluated on the independent test set and saved to the selected model target directory.
    """

    def fit(
        self,
        train_loader,
        val_loader,
        test_loader,
        loss_fn,
        optimizer,
        scheduler,
        n_epochs,
        val_interval,
        patience_early_stopping,
        device,
        metrics: Union[list, dict] = [],
        val_metric: Union[int, str] = "loss",
        val_metric_mode: str = "min",
        start_epoch=0,
    ):
        self.logger.info("Init model on device '{}'".format(device))
        self.model = self.model.to(device)

        best_model = copy.deepcopy(self.model.state_dict())
        best_metric = 0.0 if val_metric_mode == "max" else float("inf")

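        # validation only runs every val_interval epochs, so the epoch-based
        # patience is converted into a number of validation rounds here
        # (e.g. patience_early_stopping=10 with val_interval=2 -> patience_stopping=5)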
        patience_stopping = math.ceil(patience_early_stopping / val_interval)
        patience_stopping = int(max(1, patience_stopping))
        early_stopping = EarlyStoppingCriterion(mode=val_metric_mode,
                                                patience=patience_stopping)

        if not self.start_scratch and self.cp is not None:
            checkpoint = self.cp.read_latest()
            if checkpoint is not None:
                try:
                    try:
                        self.model.load_state_dict(checkpoint["modelState"])
                    except RuntimeError as e:
                        self.logger.error(
                            "Failed to restore checkpoint: "
                            "Checkpoint has different parameters")
                        self.logger.error(e)
                        raise SystemExit

                    optimizer.load_state_dict(
                        checkpoint["trainState"]["optState"])
                    start_epoch = checkpoint["trainState"]["epoch"] + 1
                    best_metric = checkpoint["trainState"]["best_metric"]
                    best_model = checkpoint["trainState"]["best_model"]
                    early_stopping.load_state_dict(
                        checkpoint["trainState"]["earlyStopping"])
                    scheduler.load_state_dict(
                        checkpoint["trainState"]["scheduler"])
                    self.logger.info(
                        "Resuming with epoch {}".format(start_epoch))
                except KeyError:
                    self.logger.error("Failed to restore checkpoint")
                    raise

        since = time.time()

        self.logger.info("Start training model " + self.prefix)

        try:
            if val_metric_mode == "min":
                val_comp = operator.lt
            else:
                val_comp = operator.gt
            for epoch in range(start_epoch, n_epochs):
                self.train_epoch(epoch, train_loader, loss_fn, optimizer,
                                 metrics, device)
                if epoch % val_interval == 0 or epoch == n_epochs - 1:
                    val_loss = self.test_epoch(epoch,
                                               val_loader,
                                               loss_fn,
                                               metrics,
                                               device,
                                               phase="val")
                    if val_metric == "loss":
                        val_result = val_loss
                    else:
                        val_result = metrics[val_metric].get()
                    if val_comp(val_result, best_metric):
                        best_metric = val_result
                        best_model = copy.deepcopy(self.model.state_dict())
                    self.cp.write({
                        "modelState": self.model.state_dict(),
                        "trainState": {
                            "epoch": epoch,
                            "best_metric": best_metric,
                            "best_model": best_model,
                            "optState": optimizer.state_dict(),
                            "earlyStopping": early_stopping.state_dict(),
                            "scheduler": scheduler.state_dict(),
                        },
                    })
                    scheduler.step(val_result)
                    if early_stopping.step(val_result):
                        self.logger.info(
                            "No improvment over the last {} epochs. Stopping.".
                            format(patience_early_stopping))
                        break
        except Exception:
            import traceback
            self.logger.warning(traceback.format_exc())
            self.logger.warning("Aborting...")
            self.logger.close()
            raise SystemExit

        self.model.load_state_dict(best_model)
        final_loss = self.test_epoch(0,
                                     test_loader,
                                     loss_fn,
                                     metrics,
                                     device,
                                     phase="test")
        if val_metric == "loss":
            final_metric = final_loss
        else:
            final_metric = metrics[val_metric].get()

        time_elapsed = time.time() - since
        self.logger.info("Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60))
        self.logger.info("Best val metric: {:4f}".format(best_metric))
        self.logger.info("Final test metric: {:4f}".format(final_metric))

        return self.model

    """
    Trains one epoch on the pre-extracted training data using the given loss function, optimizer, and metrics.
    """

    def train_epoch(self, epoch, train_loader, loss_fn, optimizer, metrics,
                    device):
        self.logger.debug("train|{}|start".format(epoch))
        if isinstance(metrics, list):
            for metric in metrics:
                metric.reset(device)
        else:
            for metric in metrics.values():
                metric.reset(device)

        self.model.train()

        epoch_start = time.time()
        start_data_loading = epoch_start
        data_loading_time = m.Sum(torch.device("cpu"))
        epoch_loss = m.Mean(device)

        for i, (features, label) in enumerate(train_loader):
            features = features.to(device)
            call_label = None

            if "call" in label:
                call_label = label["call"].to(device,
                                              non_blocking=True,
                                              dtype=torch.int64)

            data_loading_time.update(
                torch.Tensor([(time.time() - start_data_loading)]))
            optimizer.zero_grad()

            output = self.model(features)

            loss = loss_fn(output, call_label)
            loss.backward()

            optimizer.step()

            epoch_loss.update(loss)

            prediction = None

            if call_label is not None:
                prediction = torch.argmax(output.data, dim=1)
                if isinstance(metrics, list):
                    for metric in metrics:
                        metric.update(call_label, prediction)
                else:
                    for metric in metrics.values():
                        metric.update(call_label, prediction)

            if i == 0:
                self.write_summaries(
                    features=features,
                    labels=call_label,
                    prediction=prediction,
                    file_names=label["file_name"],
                    epoch=epoch,
                    phase="train",
                )
            start_data_loading = time.time()

        self.write_scalar_summaries_logs(
            loss=epoch_loss.get(),
            metrics=metrics,
            lr=optimizer.param_groups[0]["lr"],
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase="train",
        )

        self.writer.flush()

        return epoch_loss.get()

    """ 
    Validation/Testing using pre-extracted validation/test data, given loss function and respective metrics.
    The parameter 'phase' is used to switch between validation and test
    """

    def test_epoch(self,
                   epoch,
                   test_loader,
                   loss_fn,
                   metrics,
                   device,
                   phase="val"):
        self.logger.debug("{}|{}|start".format(phase, epoch))
        self.model.eval()

        with torch.no_grad():
            if isinstance(metrics, list):
                for metric in metrics:
                    metric.reset(device)
            else:
                for metric in metrics.values():
                    metric.reset(device)
            epoch_start = time.time()
            start_data_loading = epoch_start
            data_loading_time = m.Sum(torch.device("cpu"))

            epoch_loss = m.Mean(device)

            auc = AUCMeter()

            for i, (features, label) in enumerate(test_loader):
                features = features.to(device)
                call_label = None

                if "call" in label:
                    call_label = label["call"].to(device,
                                                  non_blocking=True,
                                                  dtype=torch.int64)

                data_loading_time.update(
                    torch.Tensor([(time.time() - start_data_loading)]))

                output = self.model(features)

                loss = loss_fn(output, call_label)
                epoch_loss.update(loss)

                prediction = None

                if call_label is not None:
                    prediction = torch.argmax(output.data, dim=1)
                    if isinstance(metrics, list):
                        for metric in metrics:
                            metric.update(call_label, prediction)
                    else:
                        for metric in metrics.values():
                            metric.update(call_label, prediction)

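                    # the second softmax column is the positive-class probability,
                    # which the AUC meter accumulates batch by batch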
                    score = nn.functional.softmax(output, dim=1)[:, 1]

                    if auc is not None:
                        auc.add(score, call_label)

                if i == 0:
                    self.write_summaries(
                        features=features,
                        labels=call_label,
                        prediction=prediction,
                        file_names=label["file_name"],
                        epoch=epoch,
                        phase=phase,
                    )
                start_data_loading = time.time()

        self.write_scalar_summaries_logs(
            loss=epoch_loss.get(),
            metrics=metrics,
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase=phase,
        )

        if call_label is not None and auc is not None:
            self.write_roc_curve_summary(*auc.value(), epoch, phase=phase)

        self.writer.flush()

        return epoch_loss.get()

    """
    Writes image summary per partition (spectrograms and the corresponding predictions)
    """

    def write_summaries(
        self,
        features,
        labels=None,
        prediction=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        with torch.no_grad():
            self.write_img_summaries(
                features,
                labels=labels,
                prediction=prediction,
                file_names=file_names,
                epoch=epoch,
                phase=phase,
            )

    """
    Writes image summary per partition with respect to the prediction output (true predictions - true positive/negative, false 
    predictions - false positive/negative)
    """

    def write_img_summaries(
        self,
        features,
        labels=None,
        prediction=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        with torch.no_grad():
            if file_names is not None:
                if isinstance(file_names, torch.Tensor):
                    file_names = file_names.cpu().numpy()
                elif isinstance(file_names, list):
                    file_names = np.asarray(file_names)
            if labels is not None and prediction is not None:
                features = features.cpu()
                labels = labels.cpu()
                prediction = prediction.cpu()
                for label in torch.unique(labels):
                    label = label.item()
                    l_i = torch.eq(labels, label)

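                    # t_i selects samples of this label that were classified
                    # correctly, f_i (below) those that were misclassified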
                    t_i = torch.eq(prediction, label) * l_i
                    name_t = "true_{}".format(
                        "positive" if label else "negative")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_t,
                            img_tensor=prepare_img(
                                features[t_i],
                                num_images=self.n_summaries,
                                file_names=file_names[t_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass

                    f_i = torch.ne(prediction, label) * l_i
                    name_f = "false_{}".format(
                        "negative" if label else "positive")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_f,
                            img_tensor=prepare_img(
                                features[f_i],
                                num_images=self.n_summaries,
                                file_names=file_names[f_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass
            else:
                self.writer.add_image(
                    tag=phase + "/input",
                    img_tensor=prepare_img(features,
                                           num_images=self.n_summaries,
                                           file_names=file_names),
                    global_step=epoch,
                )

    """
    Writes scalar summary per partition including loss, confusion matrix, accuracy, recall, f1-score, true positive rate,
    false positive rate, precision, data_loading_time, epoch time
    """

    def write_scalar_summaries_logs(
        self,
        loss: float,
        metrics: Union[list, dict] = [],
        lr: float = None,
        epoch_time: float = None,
        data_loading_time: float = None,
        epoch=None,
        phase="train",
    ):
        with torch.no_grad():
            log_str = phase
            if epoch is not None:
                log_str += "|{}".format(epoch)
            self.writer.add_scalar(phase + "/epoch_loss", loss, epoch)
            log_str += "|loss:{:0.3f}".format(loss)
            if isinstance(metrics, dict):
                for name, metric in metrics.items():
                    self.writer.add_scalar(phase + "/" + name, metric.get(),
                                           epoch)
                    log_str += "|{}:{:0.3f}".format(name, metric.get())
            else:
                for i, metric in enumerate(metrics):
                    self.writer.add_scalar(phase + "/metric_" + str(i),
                                           metric.get(), epoch)
                    log_str += "|m_{}:{:0.3f}".format(i, metric.get())
            if lr is not None:
                self.writer.add_scalar("lr", lr, epoch)
                log_str += "|lr:{:0.2e}".format(lr)
            if epoch_time is not None:
                self.writer.add_scalar(phase + "/time", epoch_time, epoch)
                log_str += "|t:{:0.1f}".format(epoch_time)
            if data_loading_time is not None:
                self.writer.add_scalar(phase + "/data_loading_time",
                                       data_loading_time, epoch)
            self.logger.info(log_str)

    """
    Writes the ROC curve summary for the validation and test sets.
    """

    def write_roc_curve_summary(self, auc, tpr, fpr, epoch=None, phase=""):
        with torch.no_grad():
            if phase != "":
                phase += "_"
            fig = roc_fig(tpr, fpr, auc)
            self.writer.add_figure(phase + "roc/roc", fig, epoch)
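
# Illustrative only: a sketch of how fit() might be driven for this variant.
# The loaders, metric object, and hyperparameter values are hypothetical
# placeholders; only the keyword names come from the signature above.
#
#   optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-3)
#   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
#   trainer.fit(
#       train_loader, val_loader, test_loader,
#       loss_fn=nn.CrossEntropyLoss(),
#       optimizer=optimizer,
#       scheduler=scheduler,              # stepped with the validation result
#       n_epochs=500,
#       val_interval=2,
#       patience_early_stopping=10,
#       device="cuda",
#       metrics={"accuracy": accuracy},   # objects with reset()/update()/get() (assumed API)
#       val_metric="loss",
#       val_metric_mode="min",
#   )
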
Example #4
0
class Trainer:
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: tuple = None,
        start_scratch: bool = False,
    ):
        """
        Class which implements network training, validation and testing as well as writing checkpoints, logs, summaries, and saving the final model.

        :param Union[str, None] checkpoint_dir: the type is either str or None (default: None)
        :param int n_summaries: number of images as samples at different phases to visualize on tensorboard
        """
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=os.uname()[1],
            )
            summary_dir = os.path.join(summary_dir, run_name)
        # keep the per-run directory (also used by the delve tracker in fit())
        # and point the summary writer at it
        self.summary_dir = summary_dir

        self.n_summaries = n_summaries
        self.writer = SummaryWriter(self.summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(checkpoint_dir,
                                        prefix=prefix,
                                        logger=self.logger)
        else:
            self.cp = None

        self.start_scratch = start_scratch

    def fit(
        self,
        train_dataloader,
        val_dataloader,
        train_ds,
        val_ds,
        loss_fn,
        optimizer,
        n_epochs,
        val_interval,
        patience_early_stopping,
        device,
        metrics: Union[list, dict] = [],
        val_metric: Union[int, str] = "loss",
        val_metric_mode: str = "min",
        start_epoch=0,
    ):
        """
        Trains and validates the network.

        :param int n_epochs: maximum number of training epochs (default: 500)
        :param int val_interval: run validation every val_interval epochs
        :param int patience_early_stopping: stop training after this many epochs without improvement (i.e. after patience_early_stopping / val_interval validation rounds)
        """

        self.logger.info("Init model on device '{}'".format(device))
        self.model = self.model.to(device)

        # initialize delve
        self.tracker = CheckLayerSat(self.summary_dir,
                                     save_to="plotcsv",
                                     modules=self.model,
                                     device=device)

        best_model = copy.deepcopy(self.model.state_dict())
        best_metric = 0.0 if val_metric_mode == "max" else float("inf")

        # validation only runs every val_interval epochs, so convert the
        # epoch-based patience into the corresponding number of validation rounds
        patience_stopping = math.ceil(patience_early_stopping / val_interval)
        patience_stopping = int(max(1, patience_stopping))
        early_stopping = EarlyStoppingCriterion(mode=val_metric_mode,
                                                patience=patience_stopping)

        if not self.start_scratch and self.cp is not None:
            checkpoint = self.cp.read_latest()
            if checkpoint is not None:
                try:
                    try:
                        self.model.load_state_dict(checkpoint["modelState"])
                    except RuntimeError as e:
                        self.logger.error(
                            "Failed to restore checkpoint: "
                            "Checkpoint has different parameters")
                        self.logger.error(e)
                        raise SystemExit

                    optimizer.load_state_dict(
                        checkpoint["trainState"]["optState"])
                    start_epoch = checkpoint["trainState"]["epoch"] + 1
                    best_metric = checkpoint["trainState"]["best_metric"]
                    best_model = checkpoint["trainState"]["best_model"]
                    early_stopping.load_state_dict(
                        checkpoint["trainState"]["earlyStopping"])
                    #scheduler.load_state_dict(checkpoint["trainState"]["scheduler"])
                    self.logger.info(
                        "Resuming with epoch {}".format(start_epoch))
                except KeyError:
                    self.logger.error("Failed to restore checkpoint")
                    raise

        since = time.time()

        self.logger.info("Start training model " + self.prefix)

        try:
            if val_metric_mode == "min":
                val_comp = operator.lt  # use the comparison operator as a function
            else:
                val_comp = operator.gt
            for epoch in range(start_epoch, n_epochs):
                self.train(epoch, train_dataloader, train_ds, loss_fn,
                           optimizer, device)

                if epoch % val_interval == 0 or epoch == n_epochs - 1:
                    # first, get val_loss for further comparison
                    val_loss = self.validate(epoch,
                                             val_dataloader,
                                             val_ds,
                                             loss_fn,
                                             device,
                                             phase="val")
                    if val_metric == "loss":
                        val_result = val_loss
                        # add metrics for delve to keep track of
                        self.tracker.add_scalar("loss", val_loss)
                        # add saturation to the mix
                        self.tracker.add_saturations()
                    else:
                        val_result = metrics[val_metric].get()

                    # compare to see if improvement occurs
                    if val_comp(val_result, best_metric):
                        best_metric = val_result  # update best_metric with the loss (smaller than previous)
                        best_model = copy.deepcopy(self.model.state_dict())
                        """previously, deadlock occurred, which seemed to be related to cp. comment self.cp.write() to see if freezing goes away."""
                        # write checkpoint
                        self.cp.write({
                            "modelState": self.model.state_dict(),
                            "trainState": {
                                "epoch": epoch,
                                "best_metric": best_metric,
                                "best_model": best_model,
                                "optState": optimizer.state_dict(),
                                "earlyStopping": early_stopping.state_dict(),
                            },
                        })

                    # stop once the number of validation rounds without improvement reaches the patience
                    if early_stopping.step(val_result):
                        self.logger.info(
                            "No improvement over the last {} epochs. Training is stopped."
                            .format(patience_early_stopping))
                        break
        except Exception:
            import traceback
            self.logger.warning(traceback.format_exc())
            self.logger.warning("Aborting...")
            self.logger.close()
            raise SystemExit

        # possible extension: load the best model, run it on a test set, and log the final metric alongside the best metric;
        # for the autoencoder the data is only split into train and validation sets, so there is no test set here

        time_elapsed = time.time() - since
        self.logger.info("Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60))

        self.logger.info("Best val metric: {:4f}".format(best_metric))

        # close delve tracker
        self.tracker.close()

        return self.model

    def train(self, epoch, train_dataloader, train_ds, loss_fn, optimizer,
              device):
        """
        Trains one epoch on the training data using the given loss function and optimizer.
        """
        self.logger.debug("train|{}|start".format(epoch))

        self.model.train()

        epoch_start = time.time()
        start_data_loading = epoch_start
        data_loading_time = m.Sum(torch.device("cpu"))

        train_running_loss = 0.0
        for i, (train_specs, label) in enumerate(train_dataloader):
            train_specs = train_specs.to(device)
            call_label = None

            if "call" in label:
                call_label = label["call"].to(
                    device, non_blocking=True, dtype=torch.int64
                )  #  e.g. tensor([True, True, True, True, True, True])

            if "ground_truth" in label:
                ground_truth = label["ground_truth"].to(device,
                                                        non_blocking=True)

            data_loading_time.update(
                torch.Tensor([(time.time() - start_data_loading)]))
            optimizer.zero_grad()

            # compute reconstructions
            outputs = self.model(train_specs)

            # compute training reconstruction loss, when augmentation is used
            # loss = loss_fn(outputs, ground_truth)

            # compute training reconstruction loss, when no augmentation is used
            loss = loss_fn(outputs, train_specs)

            # compute accumulated gradients
            loss.backward()

            # perform parameter update based on current gradients
            optimizer.step()

            # accumulate the mini-batch loss into the epoch loss:
            # loss.item() is the loss averaged over the current batch, so
            # multiplying by the batch size gives the total (un-averaged) batch loss
            train_running_loss += loss.item() * train_specs.size(0)

            prediction = None
            #print("label is ", label, "call_label is ", call_label)

            if i % 2 == 0:
                self.write_summaries(
                    features=train_specs,
                    #labels=call_label,
                    #prediction=prediction,
                    reconstructed=outputs,
                    file_names=label["file_name"],
                    epoch=epoch,
                    phase="train",
                )
            start_data_loading = time.time()

        # compute the epoch training loss
        train_epoch_loss = train_running_loss / len(train_ds)

        self.write_scalar_summaries_logs(
            loss=train_epoch_loss,
            #metrics=metrics,
            lr=optimizer.param_groups[0]["lr"],
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase="train",
        )

        self.writer.flush()

        return train_epoch_loss

    def validate(self,
                 epoch,
                 val_dataloader,
                 val_ds,
                 loss_fn,
                 device,
                 phase="val"):
        self.logger.debug("{}|{}|start".format(phase, epoch))
        self.model.eval()

        val_running_loss = 0.0
        with torch.no_grad():
            epoch_start = time.time()
            start_data_loading = epoch_start
            data_loading_time = m.Sum(torch.device("cpu"))

            for i, (val_specs, label) in enumerate(val_dataloader):
                val_specs = val_specs.to(device)
                if "call" in label:
                    call_label = label["call"].to(device,
                                                  non_blocking=True,
                                                  dtype=torch.int64)  # bool

                data_loading_time.update(
                    torch.Tensor([(time.time() - start_data_loading)]))

                # instead of converting the spectrograms to color images, log the
                # single-channel outputs produced by the network directly
                if i % 2 == 0:
                    #grid = make_grid(val_specs)
                    self.writer.add_images("Original", val_specs,
                                           epoch)  #val_specs

                outputs = self.model(val_specs)

                if i % 2 == 0:
                    # tb = SummaryWriter()
                    #grid = make_grid(outputs)
                    self.writer.add_images("Reconstructed", outputs,
                                           epoch)  #outputs

                loss = loss_fn(outputs, val_specs)

                val_running_loss += loss.item() * val_specs.size(0)

                prediction = None

                if i % 2 == 0:
                    self.write_summaries(
                        features=val_specs,  # original
                        #labels=call_label,
                        #prediction=prediction,
                        reconstructed=outputs,
                        file_names=label["file_name"],
                        epoch=epoch,
                        phase=phase,
                    )
                start_data_loading = time.time()

            val_epoch_loss = val_running_loss / len(val_ds)

            self.write_scalar_summaries_logs(
                loss=val_epoch_loss,
                #metrics=metrics,
                epoch_time=time.time() - epoch_start,
                data_loading_time=data_loading_time.get(),
                epoch=epoch,
                phase=phase,
            )

            self.writer.flush()

            return val_epoch_loss

    def write_summaries(
        self,
        features,
        #labels=None, #  tensor([True, True, True, True, True, True])
        #prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        #"""Writes image summary per partition (spectrograms and the corresponding predictions)"""
        """Writes image summary per partition (spectrograms and reconstructed)"""

        with torch.no_grad():
            self.write_img_summaries(
                features,
                #labels=labels,
                #prediction=prediction,
                reconstructed=reconstructed,
                file_names=file_names,
                epoch=epoch + 1,
                phase=phase,
            )

    def write_img_summaries(
        self,
        features,
        #labels=None,
        #prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        """
        Writes the image summary per partition. For the autoencoder this logs the input spectrograms and their
        reconstructions; the original true/false positive/negative visualization is kept below as commented-out code.
        """

        with torch.no_grad():
            if file_names is not None:
                if isinstance(file_names, torch.Tensor):
                    file_names = file_names.cpu().numpy()
                elif isinstance(file_names, list):
                    file_names = np.asarray(file_names)
            #if labels is not None and prediction is not None:
            if reconstructed is not None:
                features = features.cpu()
                #labels = labels.cpu()
                #prediction = prediction.cpu()
                reconstructed = reconstructed.cpu()

                self.writer.add_images(
                    tag=phase + "/input",
                    img_tensor=features[:self.n_summaries],
                    #img_tensor=prepare_img(
                    #    features, num_images=self.n_summaries, file_names=file_names
                    #),
                    global_step=epoch,
                )

                self.writer.add_images(
                    tag=phase + "/reconstructed",
                    img_tensor=reconstructed[:self.n_summaries],
                    # img_tensor=prepare_img(
                    #    features, num_images=self.n_summaries, file_names=file_names
                    # ),
                    global_step=epoch,
                )
                """ below are needed to visualize true positive/negative examples"""
                """for label in torch.unique(labels): #  tensor(1, device='cuda:0')
                    label = label.item() # Returns the value of this tensor as a standard Python number: 1
                    l_i = torch.eq(labels, label)

                    t_i = torch.eq(prediction, label) * l_i
                    name_t = "true_{}".format("positive" if label else "negative")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_t,
                            img_tensor=prepare_img(
                                features[t_i],
                                num_images=self.n_summaries,
                                file_names=file_names[t_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass

                    f_i = torch.ne(prediction, label) * l_i
                    name_f = "false_{}".format("negative" if label else "positive")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_f,
                            img_tensor=prepare_img(
                                features[f_i],
                                num_images=self.n_summaries,
                                file_names=file_names[f_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass
            else:
                self.writer.add_image(
                    tag=phase + "/input",
                    img_tensor=prepare_img(
                        features, num_images=self.n_summaries, file_names=file_names
                    ),
                    global_step=epoch,
                )"""

    """
    Writes scalar summary per partition including loss, confusion matrix, accuracy, recall, f1-score, true positive rate,
    false positive rate, precision, data_loading_time, epoch time
    """

    def write_scalar_summaries_logs(
        self,
        loss: float,
        metrics: Union[list, dict] = [],
        lr: float = None,
        epoch_time: float = None,
        data_loading_time: float = None,
        epoch=None,
        phase="train",
    ):
        with torch.no_grad():
            log_str = phase
            if epoch is not None:
                log_str += "|{}".format(epoch)
            self.writer.add_scalar(phase + "/epoch_loss", loss, epoch)
            log_str += "|loss:{:0.3f}".format(loss)
            if isinstance(metrics, dict):
                for name, metric in metrics.items():
                    self.writer.add_scalar(phase + "/" + name, metric.get(),
                                           epoch)
                    log_str += "|{}:{:0.3f}".format(name, metric.get())
            else:
                for i, metric in enumerate(metrics):
                    self.writer.add_scalar(phase + "/metric_" + str(i),
                                           metric.get(), epoch)
                    log_str += "|m_{}:{:0.3f}".format(i, metric.get())
            if lr is not None:
                self.writer.add_scalar("lr", lr, epoch)
                log_str += "|lr:{:0.2e}".format(lr)
            if epoch_time is not None:
                self.writer.add_scalar(phase + "/time", epoch_time, epoch)
                log_str += "|t:{:0.1f}".format(epoch_time)
            if data_loading_time is not None:
                self.writer.add_scalar(phase + "/data_loading_time",
                                       data_loading_time, epoch)
            self.logger.info(log_str)
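
# Illustrative only: a sketch of driving the autoencoder variant above
# (all objects and values below are hypothetical placeholders):
#
#   trainer.fit(
#       train_dataloader, val_dataloader,
#       train_ds, val_ds,                 # datasets are only used for len() when averaging the loss
#       loss_fn=nn.MSELoss(),             # reconstruction loss against the input spectrograms
#       optimizer=torch.optim.Adam(trainer.model.parameters(), lr=1e-3),
#       n_epochs=500,
#       val_interval=2,
#       patience_early_stopping=10,
#       device="cuda",
#   )
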
Example #5
0
class Trainer:

    """
    Initializes the summary writer and the checkpoint handler and sets the variables required for training.
    """
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: tuple = None,
        start_scratch: bool = False,
    ):
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=platform.uname()[1],
            )
            summary_dir = os.path.join(summary_dir, run_name)

        self.n_summaries = n_summaries
        self.writer = SummaryWriter(summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(
                checkpoint_dir, prefix=prefix, logger=self.logger
            )
        else:
            self.cp = None

        self.start_scratch = start_scratch

    """
    Starts network training from scratch or resumes from an existing checkpoint. The model is trained and validated for a given
    number of epochs, and all relevant information (metrics, summaries, logs, checkpoints) is stored after each epoch. Once training
    stops (either because the chosen validation metric has not improved for a given number of epochs or because the maximum number
    of epochs is reached), the model is evaluated on the independent test set and saved to the selected model target directory.
    """
    def fit(
        self,
        train_loader,
        val_loader,
        test_loader,
        loss_fn,
        optimizer,
        scheduler,
        n_epochs,
        val_interval,
        patience_early_stopping,
        device,
        val_metric: Union[int, str] = "loss",
        val_metric_mode: str = "min",
        start_epoch=0,
    ):
        self.logger.info("Init model on device '{}'".format(device))
        self.model = self.model.to(device)

        best_model = copy.deepcopy(self.model.state_dict())
        best_metric = 0.0 if val_metric_mode == "max" else float("inf")

        patience_stopping = math.ceil(patience_early_stopping / val_interval)
        patience_stopping = int(max(1, patience_stopping))
        early_stopping = EarlyStoppingCriterion(
            mode=val_metric_mode, patience=patience_stopping
        )

        if not self.start_scratch and self.cp is not None:
            checkpoint = self.cp.read_latest()
            if checkpoint is not None:
                try:
                    try:
                        self.model.load_state_dict(checkpoint["modelState"])
                    except RuntimeError as e:
                        self.logger.error(
                            "Failed to restore checkpoint: "
                            "Checkpoint has different parameters"
                        )
                        self.logger.error(e)
                        raise SystemExit

                    optimizer.load_state_dict(checkpoint["trainState"]["optState"])
                    start_epoch = checkpoint["trainState"]["epoch"] + 1
                    best_metric = checkpoint["trainState"]["best_metric"]
                    best_model = checkpoint["trainState"]["best_model"]
                    early_stopping.load_state_dict(
                        checkpoint["trainState"]["earlyStopping"]
                    )
                    scheduler.load_state_dict(checkpoint["trainState"]["scheduler"])
                    self.logger.info("Resuming with epoch {}".format(start_epoch))
                except KeyError:
                    self.logger.error("Failed to restore checkpoint")
                    raise

        since = time.time()

        self.logger.info("Start training model " + self.prefix)

        try:
            if val_metric_mode == "min":
                val_comp = operator.lt
            else:
                raise Exception("validation metric mode has to be set to \"min\"")
            for epoch in range(start_epoch, n_epochs):
                self.train_epoch(
                    epoch, train_loader, loss_fn, optimizer, device
                )
                if epoch % val_interval == 0 or epoch == n_epochs - 1:
                    val_loss = self.test_epoch(
                        epoch, val_loader, loss_fn, device, phase="val"
                    )
                    if val_metric == "loss":
                        val_result = val_loss
                    else:
                        raise Exception("validation metric has to be set to \"loss\"")
                    if val_comp(val_result, best_metric):
                        best_metric = val_result
                        best_model = copy.deepcopy(self.model.state_dict())
                    self.cp.write(
                        {
                            "modelState": self.model.state_dict(),
                            "trainState": {
                                "epoch": epoch,
                                "best_metric": best_metric,
                                "best_model": best_model,
                                "optState": optimizer.state_dict(),
                                "earlyStopping": early_stopping.state_dict(),
                                "scheduler": scheduler.state_dict(),
                            },
                        }
                    )
                    scheduler.step(val_result)
                    if early_stopping.step(val_result):
                        self.logger.info(
                            "No improvment over the last {} epochs. Stopping.".format(
                                patience_early_stopping
                            )
                        )
                        break
        except Exception:
            import traceback
            self.logger.warning(traceback.format_exc())
            self.logger.warning("Aborting...")
            self.logger.close()
            raise SystemExit

        self.model.load_state_dict(best_model)
        final_loss = self.test_epoch(0, test_loader, loss_fn, device, phase="test")
        if val_metric == "loss":
            final_metric = final_loss
        else:
            raise Exception("validation metric has to be set to \"loss\"")

        time_elapsed = time.time() - since
        self.logger.info(
            "Training complete in {:.0f}m {:.0f}s".format(
                time_elapsed // 60, time_elapsed % 60
            )
        )
        self.logger.info("Best val metric: {:4f}".format(best_metric))
        self.logger.info("Final test metric: {:4f}".format(final_metric))

        return self.model

    """
    Trains one epoch on the pre-extracted training data using the given loss function and optimizer.
    """
    def train_epoch(self, epoch, train_loader, loss_fn, optimizer, device):
        self.logger.debug("train|{}|start".format(epoch))

        self.model.train()

        epoch_start = time.time()
        start_data_loading = epoch_start
        data_loading_time = m.Sum(torch.device("cpu"))
        epoch_loss = m.Mean(device)

        for i, (features, label) in enumerate(train_loader):

            features = features.to(device)

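            # the denoising target is a separate ground-truth spectrogram,
            # not the (noisy) input features themselves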
            ground_truth = label["ground_truth"].to(device, non_blocking=True)

            data_loading_time.update(torch.Tensor([(time.time() - start_data_loading)]))

            optimizer.zero_grad()

            denoised_output = self.model(features)

            loss = loss_fn(denoised_output, ground_truth)

            loss.backward()

            optimizer.step()

            epoch_loss.update(loss)

            start_data_loading = time.time()

            if i % 5 == 0:
                self.writer.add_image(
                    tag="train" + "/ground_truth",
                    img_tensor=prepare_img(
                        ground_truth.transpose(0, 1).squeeze(dim=0),
                        num_images=self.n_summaries,
                        file_names=label["file_name"],
                    ),
                    global_step=epoch,
                )
                self.writer.add_image(
                    tag="train" + "/input",
                    img_tensor=prepare_img(
                        features.transpose(0, 1).squeeze(dim=0), num_images=self.n_summaries,
                        file_names=label["file_name"],
                    ),
                    global_step=epoch,
                )
                self.writer.add_image(
                    tag="train" + "/masks_pred",
                    img_tensor=prepare_img(
                        denoised_output.transpose(0, 1).squeeze(dim=0), num_images=self.n_summaries,
                        file_names=label["file_name"],
                    ),
                    global_step=epoch,
                )

        self.write_scalar_summaries_logs(
            loss=epoch_loss.get(),
            lr=optimizer.param_groups[0]["lr"],
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase="train",
        )

        self.writer.flush()

        return epoch_loss.get()

    """ 
    Validation/Testing using pre-extracted validation/test data, given loss function and respective metrics.
    The parameter 'phase' is used to switch between validation and test
    """
    def test_epoch(self, epoch, test_loader, loss_fn, device, phase="val"):
        self.logger.debug("{}|{}|start".format(phase, epoch))

        self.model.eval()

        with torch.no_grad():
            epoch_start = time.time()
            start_data_loading = epoch_start
            data_loading_time = m.Sum(torch.device("cpu"))
            epoch_loss = m.Mean(device)

            for i, (features, label) in enumerate(test_loader):

                features = features.to(device)

                ground_truth = label["ground_truth"].to(device, non_blocking=True)

                data_loading_time.update(torch.Tensor([(time.time() - start_data_loading)]))

                denoised_output = self.model(features)

                loss = loss_fn(denoised_output, ground_truth)

                epoch_loss.update(loss)

                if i % 5 == 0:
                    self.writer.add_image(
                        tag=phase + "/ground_truth",
                        img_tensor=prepare_img(
                            ground_truth.transpose(0, 1).squeeze(dim=0), num_images=self.n_summaries,
                            file_names=label["file_name"],
                        ),
                        global_step=epoch,
                    )
                    self.writer.add_image(
                        tag=phase + "/input",
                        img_tensor=prepare_img(
                            features.transpose(0, 1).squeeze(dim=0), num_images=self.n_summaries,
                            file_names=label["file_name"],
                        ),
                        global_step=epoch,
                    )
                    self.writer.add_image(
                        tag=phase + "/masks_pred",
                        img_tensor=prepare_img(
                            denoised_output.transpose(0, 1).squeeze(dim=0), num_images=self.n_summaries,
                            file_names=label["file_name"],
                        ),
                        global_step=epoch,
                    )

                start_data_loading = time.time()

        self.write_scalar_summaries_logs(
            loss=epoch_loss.get(),
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase=phase,
        )

        self.writer.flush()

        return epoch_loss.get()

    """
    Writes the scalar summaries per partition, including loss, learning rate, data loading time, and epoch time.
    """
    def write_scalar_summaries_logs(
        self,
        loss: float,
        lr: float = None,
        epoch_time: float = None,
        data_loading_time: float = None,
        epoch=None,
        phase="train",
    ):
        with torch.no_grad():
            log_str = phase
            if epoch is not None:
                log_str += "|{}".format(epoch)
            self.writer.add_scalar(phase + "/epoch_loss", loss, epoch)
            log_str += "|loss:{:0.3f}".format(loss)
            if lr is not None:
                self.writer.add_scalar("lr", lr, epoch)
                log_str += "|lr:{:0.2e}".format(lr)
            if epoch_time is not None:
                self.writer.add_scalar(phase + "/time", epoch_time, epoch)
                log_str += "|t:{:0.1f}".format(epoch_time)
            if data_loading_time is not None:
                self.writer.add_scalar(
                    phase + "/data_loading_time", data_loading_time, epoch
                )
            self.logger.info(log_str)
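
# All variants above depend on an EarlyStoppingCriterion that is not shown in
# this file. Below is a minimal sketch of the interface they rely on, purely
# for illustration (the real implementation may differ): step() returns True
# once `patience` consecutive evaluations pass without improvement, and
# state_dict()/load_state_dict() make the criterion checkpointable.
#
#   class EarlyStoppingCriterion:
#       def __init__(self, mode="min", patience=5):
#           self.mode = mode
#           self.patience = patience
#           self.best = None
#           self.num_bad_steps = 0
#
#       def step(self, value):
#           improved = self.best is None or (
#               value < self.best if self.mode == "min" else value > self.best
#           )
#           if improved:
#               self.best = value
#               self.num_bad_steps = 0
#           else:
#               self.num_bad_steps += 1
#           return self.num_bad_steps >= self.patience
#
#       def state_dict(self):
#           return {"mode": self.mode, "patience": self.patience,
#                   "best": self.best, "num_bad_steps": self.num_bad_steps}
#
#       def load_state_dict(self, state):
#           self.__dict__.update(state)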