Example #1
    def train(self, HP):

        if HP.USE_VISLOGGER:
            try:
                from trixi.logger.visdom import PytorchVisdomLogger
            except ImportError as e:
                # the logger below cannot be created without trixi, so fail loudly
                raise ImportError("USE_VISLOGGER is set but the 'trixi' package is not installed") from e
            trixi = PytorchVisdomLogger(port=8080, auto_start=True)

        ExpUtils.print_and_save(HP, socket.gethostname())

        epoch_times = []
        nr_of_updates = 0

        metrics = {}
        for type in ["train", "test", "validate"]:
            metrics["loss_" + type] = [0]
            metrics["f1_macro_" + type] = [0]

        for epoch_nr in range(HP.NUM_EPOCHS):
            start_time = time.time()
            # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
            # current_lr = HP.LEARNING_RATE

            batch_gen_time = 0
            data_preparation_time = 0
            network_time = 0
            metrics_time = 0
            saving_time = 0
            plotting_time = 0

            batch_nr = {
                "train": 0,
                "test": 0,
                "validate": 0
            }

            if HP.LOSS_WEIGHT_LEN == -1:
                weight_factor = float(HP.LOSS_WEIGHT)
            else:
                if epoch_nr < HP.LOSS_WEIGHT_LEN:
                    # weight_factor = -(9./100.) * epoch_nr + 10.   #ep0: 10 -> linear decrease -> ep100: 1
                    weight_factor = -((HP.LOSS_WEIGHT-1)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                    # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                else:
                    weight_factor = 1.
                    # weight_factor = 5.

            for type in ["train", "test", "validate"]:
                print_loss = []
                start_time_batch_gen = time.time()

                batch_generator = self.dataManager.get_batches(batch_size=HP.BATCH_SIZE,
                                                               type=type, subjects=getattr(HP, type.upper() + "_SUBJECTS"))
                batch_gen_time = time.time() - start_time_batch_gen
                # print("batch_gen_time: {}s".format(batch_gen_time))

                print("Start looping batches...")
                start_time_batch_part = time.time()
                for batch in batch_generator:  # getting the next batch takes around 0.14s -> second largest time cost after the model itself

                    start_time_data_preparation = time.time()
                    batch_nr[type] += 1

                    x = batch["data"] # (bs, nr_of_channels, x, y)
                    y = batch["seg"]  # (bs, nr_of_classes, x, y)
                    # with the new BatchGenerator, y is float instead of int -> fine for Pytorch but not for Lasagne
                    # y = y.astype(HP.LABELS_TYPE)  # for bundle_peaks regression y is already float -> saves 0.2s/batch if left out

                    data_preparation_time += time.time() - start_time_data_preparation
                    # self.model.learning_rate.set_value(np.float32(current_lr))
                    start_time_network = time.time()
                    if type == "train":
                        nr_of_updates += 1
                        loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)    # probs: # (bs, x, y, nrClasses)
                        # loss, probs, f1, intermediate = self.model.train(x, y)
                    elif type == "validate":
                        loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                    elif type == "test":
                        loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                    network_time += time.time() - start_time_network

                    start_time_metrics = time.time()

                    if HP.CALC_F1:
                        if HP.EXPERIMENT_TYPE == "peak_regression":
                            # The following two lines increase metrics_time to ~30s (vs. < 1s without them); time per batch grows by ~1.5s
                            # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                            # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                            #Numpy
                            # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                            # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                            #Pytorch
                            peak_f1_mean = np.array([s for s in list(f1.values())]).mean()  #if f1 for multiple bundles
                            metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean, type=type, threshold=HP.THRESHOLD)

                            #Pytorch 2 F1
                            # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                            # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                            #Single Bundle
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                        else:
                            metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD)

                    else:
                        metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=type)

                    metrics_time += time.time() - start_time_metrics

                    print_loss.append(loss)
                    if batch_nr[type] % HP.PRINT_FREQ == 0:
                        time_batch_part = time.time() - start_time_batch_part
                        start_time_batch_part = time.time()
                        ExpUtils.print_and_save(HP, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(type, epoch_nr,
                                                                batch_nr[type] * HP.BATCH_SIZE,
                                                                round(np.array(print_loss).mean(), 6), round(time_batch_part, 3),
                                                                round(time_batch_part / HP.PRINT_FREQ, 3)))
                        print_loss = []

                    if HP.USE_VISLOGGER:
                        ExpUtils.plot_result_trixi(trixi, x, y, probs, loss, f1, epoch_nr)


            ###################################
            # Post Training tasks (each epoch)
            ###################################

            #Adapt LR
            if HP.LR_SCHEDULE:
                self.model.scheduler.step()
                # self.model.scheduler.step(np.mean(f1))
                self.model.print_current_lr()

            # Average loss per batch over entire epoch
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

            print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
            print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

            # Save Weights
            start_time_saving = time.time()
            if HP.SAVE_WEIGHTS:
                self.model.save_model(metrics, epoch_nr)
            saving_time += time.time() - start_time_saving

            # Create Plots
            start_time_plotting = time.time()
            pickle.dump(metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb")) # wb -> write (override) and binary (binary only needed on windows, on unix also works without) # for loading: pickle.load(open("metrics.pkl", "rb"))
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)
            plotting_time += time.time() - start_time_plotting

            epoch_time = time.time() - start_time
            epoch_times.append(epoch_time)

            ExpUtils.print_and_save(HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
            ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

            # Adding next Epoch
            if epoch_nr < HP.NUM_EPOCHS-1:
                metrics = MetricUtils.add_empty_element(metrics)


        ####################################
        # After all epochs
        ###################################
        with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

        return metrics
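
The inline comment next to pickle.dump above notes how the metrics file can be re-loaded. A minimal sketch of inspecting a finished run this way (the experiment path is hypothetical; nothing below is part of the original module):

# Minimal sketch: re-load the metrics.pkl written each epoch above and print
# the final per-epoch training loss and macro F1.
import pickle
from os.path import join

def load_metrics(exp_path):
    # "rb" because the file was written in binary mode ("wb") above
    with open(join(exp_path, "metrics.pkl"), "rb") as f:
        return pickle.load(f)

# Example usage (path is hypothetical):
# metrics = load_metrics("/experiments/my_run")
# print(metrics["loss_train"][-1], metrics["f1_macro_train"][-1])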
Example #2
def train_model(Config, model, data_loader):

    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError as e:
            # the logger below cannot be created without trixi, so fail loudly
            raise ImportError("USE_VISLOGGER is set but the 'trixi' package is not installed") from e
        trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for type in ["train", "test", "validate"]:
        for metric in Config.METRIC_TYPES:
            metrics[metric + "_" + type] = [0]

    batch_gen_train = data_loader.get_batch_generator(batch_size=Config.BATCH_SIZE, type="train",
                                                      subjects=getattr(Config, "TRAIN_SUBJECTS"))
    batch_gen_val = data_loader.get_batch_generator(batch_size=Config.BATCH_SIZE, type="validate",
                                                    subjects=getattr(Config, "VALIDATE_SUBJECTS"))

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()

        timings = defaultdict(lambda: 0)
        batch_nr = defaultdict(lambda: 0)
        weight_factor = _get_weights_for_this_epoch(Config, epoch_nr)
        types = ["validate"] if Config.ONLY_VAL else ["train", "validate"]

        for type in types:
            print_loss = []

            if Config.DIM == "2D":
                nr_of_samples = len(getattr(Config, type.upper() + "_SUBJECTS")) * Config.INPUT_DIM[0]
            else:
                nr_of_samples = len(getattr(Config, type.upper() + "_SUBJECTS"))

            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of updates/batches as with 2D U-Net
            nr_batches = int(int(nr_of_samples / Config.BATCH_SIZE) * Config.EPOCH_MULTIPLIER)

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for i in range(nr_batches):

                batch = next(batch_gen_train) if type == "train" else next(batch_gen_val)

                start_time_data_preparation = time.time()
                batch_nr[type] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, nr_of_classes, x, y)

                timings["data_preparation_time"] += time.time() - start_time_data_preparation
                start_time_network = time.time()
                if type == "train":
                    nr_of_updates += 1
                    probs, metr_batch = model.train(x, y, weight_factor=weight_factor)
                elif type == "validate":
                    probs, metr_batch = model.test(x, y, weight_factor=weight_factor)
                elif type == "test":
                    probs, metr_batch = model.test(x, y, weight_factor=weight_factor)
                timings["network_time"] += time.time() - start_time_network

                start_time_metrics = time.time()
                metrics = _update_metrics(Config, metrics, metr_batch, type)
                timings["metrics_time"] += time.time() - start_time_metrics

                print_loss.append(metr_batch["loss"])
                if batch_nr[type] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(Config, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(
                        type, epoch_nr, batch_nr[type] * Config.BATCH_SIZE, round(np.array(print_loss).mean(), 6),
                        round(time_batch_part, 3), round( time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if Config.USE_VISLOGGER:
                    plot_utils.plot_result_trixi(trixi, x, y, probs, metr_batch["loss"], metr_batch["f1_macro"], epoch_nr)


        ################################### Post Training tasks (each epoch) ###################################

        if Config.ONLY_VAL:
            metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
            print("f1 macro validate: {}".format(round(metrics["f1_macro_validate"][0], 4)))
            return model

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"], type="validate")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr, mode=Config.BEST_EPOCH_SELECTION)
        timings["saving_time"] += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        pickle.dump(metrics, open(join(Config.EXP_PATH, "metrics.pkl"), "wb"))
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics_all.png")
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME, without_first_epochs=True,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics.png")
        if "angle_err" in Config.METRIC_TYPES:
            plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME, without_first_epochs=True,
                                       keys=["loss", "angle_err"],
                                       types=["train", "validate"],
                                       selected_ax=["loss", "f1"],
                                       fig_name="metrics_angle.png")

        timings["plotting_time"] += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time UNet: {}s".format(epoch_nr, timings["network_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time metrics: {}s".format(epoch_nr, timings["metrics_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time saving files: {}s".format(epoch_nr, timings["saving_time"]))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS-1:
            metrics = metric_utils.add_empty_element(metrics)

    with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "a") as f:
        f.write("\n\nAverage Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return model
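
Example #2 delegates the loss weighting to a _get_weights_for_this_epoch helper that is not part of this listing. Judging from the inline weight_factor logic in Examples #1, #3 and #6, it presumably performs the same linear decay from LOSS_WEIGHT down to 1 over LOSS_WEIGHT_LEN epochs; a sketch under that assumption:

# Sketch of the helper referenced in Example #2, assuming it mirrors the
# inline weight_factor logic of the other examples (linear decay from
# LOSS_WEIGHT to 1 over LOSS_WEIGHT_LEN epochs, constant afterwards).
def _get_weights_for_this_epoch(Config, epoch_nr):
    if Config.LOSS_WEIGHT is None:
        return None                          # unweighted loss
    if Config.LOSS_WEIGHT_LEN == -1:
        return float(Config.LOSS_WEIGHT)     # constant weighting
    if epoch_nr < Config.LOSS_WEIGHT_LEN:
        # linear interpolation: LOSS_WEIGHT at epoch 0 -> 1 at epoch LOSS_WEIGHT_LEN
        slope = (Config.LOSS_WEIGHT - 1) / float(Config.LOSS_WEIGHT_LEN)
        return float(Config.LOSS_WEIGHT) - slope * epoch_nr
    return 1.0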
Example #3
def train_model(Config, model, data_loader):

    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError as e:
            # the logger below cannot be created without trixi, so fail loudly
            raise ImportError("USE_VISLOGGER is set but the 'trixi' package is not installed") from e
        trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for type in ["train", "test", "validate"]:
        metrics["loss_" + type] = [0]
        metrics["f1_macro_" + type] = [0]

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()
        # current_lr = Config.LEARNING_RATE * (Config.LR_DECAY ** epoch_nr)
        # current_lr = Config.LEARNING_RATE

        batch_gen_time = 0
        data_preparation_time = 0
        network_time = 0
        metrics_time = 0
        saving_time = 0
        plotting_time = 0

        batch_nr = {"train": 0, "test": 0, "validate": 0}

        if Config.LOSS_WEIGHT_LEN == -1:
            weight_factor = float(Config.LOSS_WEIGHT)
        else:
            if epoch_nr < Config.LOSS_WEIGHT_LEN:
                weight_factor = -(
                    (Config.LOSS_WEIGHT - 1) /
                    float(Config.LOSS_WEIGHT_LEN)) * epoch_nr + float(
                        Config.LOSS_WEIGHT)
            else:
                weight_factor = 1.

        for type in ["train", "test", "validate"]:
            print_loss = []
            start_time_batch_gen = time.time()

            batch_gen = data_loader.get_batch_generator(
                batch_size=Config.BATCH_SIZE,
                type=type,
                subjects=getattr(Config,
                                 type.upper() + "_SUBJECTS"))
            batch_gen_time = time.time() - start_time_batch_gen
            # print("batch_gen_time: {}s".format(batch_gen_time))

            if Config.DIM == "2D":
                nr_of_samples = len(getattr(
                    Config,
                    type.upper() + "_SUBJECTS")) * Config.INPUT_DIM[0]
            else:
                nr_of_samples = len(getattr(Config,
                                            type.upper() + "_SUBJECTS"))

            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of updates/batches as with 2D U-Net
            nr_batches = int(
                int(nr_of_samples / Config.BATCH_SIZE) *
                Config.EPOCH_MULTIPLIER)

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for i in range(nr_batches):
                batch = next(batch_gen)

                start_time_data_preparation = time.time()
                batch_nr[type] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, nr_of_classes, x, y)

                data_preparation_time += time.time() - start_time_data_preparation
                start_time_network = time.time()
                if type == "train":
                    nr_of_updates += 1
                    loss, probs, f1 = model.train(x,
                                                  y,
                                                  weight_factor=weight_factor)
                    # loss, probs, f1, intermediate = model.train(x, y)
                elif type == "validate":
                    loss, probs, f1 = model.test(x,
                                                 y,
                                                 weight_factor=weight_factor)
                elif type == "test":
                    loss, probs, f1 = model.test(x,
                                                 y,
                                                 weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()

                if Config.CALC_F1:
                    if Config.EXPERIMENT_TYPE == "peak_regression":
                        # The following two lines increase metrics_time to ~30s (vs. < 1s without them);
                        # time per batch grows by ~1.5s
                        # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                        # metrics = metric_utils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1),
                        #                                          type=type, threshold=Config.THRESHOLD,
                        #                                          f1_per_bundle={"CA": f1[5], "FX_left": f1[23],
                        #                                                         "FX_right": f1[24]})

                        #Numpy
                        # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # peak_f1 = metric_utils.calc_peak_dice(Config, probs, y_right_order)
                        # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                        # import IPython
                        # IPython.embed()

                        #Pytorch
                        peak_f1_mean = np.array([
                            s.to('cpu') for s in list(f1.values())
                        ]).mean()  #if f1 for multiple bundles
                        metrics = metric_utils.calculate_metrics(
                            metrics,
                            None,
                            None,
                            loss,
                            f1=peak_f1_mean,
                            type=type,
                            threshold=Config.THRESHOLD)

                        #Pytorch 2 F1
                        # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                        # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                        # metrics = metric_utils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a,
                        #                                         type=type, threshold=Config.THRESHOLD,
                        #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                        #Single Bundle
                        # metrics = metric_utils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0],
                        #                                          type=type, threshold=Config.THRESHOLD,
                        #                                          f1_per_bundle={"Thr1": f1["CST_right"][1],
                        #                                                         "Thr2": f1["CST_right"][2]})
                        # metrics = metric_utils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"],
                        #                                          type=type, threshold=Config.THRESHOLD)
                    else:
                        metrics = metric_utils.calculate_metrics(
                            metrics,
                            None,
                            None,
                            loss,
                            f1=np.mean(f1),
                            type=type,
                            threshold=Config.THRESHOLD)

                else:
                    metrics = metric_utils.calculate_metrics_onlyLoss(
                        metrics, loss, type=type)

                metrics_time += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[type] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config, "{} Ep {}, Sp {}, loss {}, t print {}s, "
                        "t batch {}s".format(
                            type, epoch_nr, batch_nr[type] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if Config.USE_VISLOGGER:
                    plot_utils.plot_result_trixi(trixi, x, y, probs, loss, f1,
                                                 epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics,
                                                      batch_nr["train"],
                                                      type="train")
        metrics = metric_utils.normalize_last_element(metrics,
                                                      batch_nr["validate"],
                                                      type="validate")
        metrics = metric_utils.normalize_last_element(metrics,
                                                      batch_nr["test"],
                                                      type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(
            epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr)
        saving_time += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        pickle.dump(metrics, open(join(Config.EXP_PATH, "metrics.pkl"), "wb"))
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME)
        plot_utils.create_exp_plot(metrics,
                                   Config.EXP_PATH,
                                   Config.EXP_NAME,
                                   without_first_epochs=True)
        plotting_time += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(
            Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(
            Config,
            "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        exp_utils.print_and_save(
            Config,
            "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        exp_utils.print_and_save(
            Config,
            "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    with open(join(Config.EXP_PATH, "Hyperparameters.txt"),
              "a") as f:  # a for append
        f.write("\n\n")
        f.write("Average Epoch time: {}s".format(
            sum(epoch_times) / float(len(epoch_times))))

    return model
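
The peak_regression branch above averages a dict of per-bundle F1 values (scalar torch tensors) by moving them to the CPU and round-tripping through NumPy. A small self-contained illustration of that reduction; the bundle names are placeholders, not real tract names:

# Illustration of the per-bundle F1 reduction used in the peak_regression
# branch above: average a dict of scalar torch tensors into a single float.
import numpy as np
import torch

f1 = {"bundle_a": torch.tensor(0.81), "bundle_b": torch.tensor(0.77)}

# same pattern as above: move to CPU, collect into a NumPy array, average
peak_f1_mean = np.array([s.to('cpu') for s in f1.values()]).mean()

# equivalent reduction staying in torch:
peak_f1_mean_torch = torch.stack(list(f1.values())).mean().item()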
class TestPytorchVisdomLogger(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super(TestPytorchVisdomLogger, cls).setUpClass()
        try:
            start_visdom()
        except Exception:
            print("Could not start visdom, it might already be running.")

    def setUp(self):
        self.visdomLogger = PytorchVisdomLogger()

    def test_show_image(self):
        image = np.random.random_sample((3, 128, 128))
        tensor = torch.from_numpy(image)
        self.visdomLogger.show_image(tensor.numpy(), title='image')

    def test_show_images(self):
        images = np.random.random_sample((4, 3, 128, 128))
        tensors = torch.from_numpy(images)
        self.visdomLogger.show_images(tensors.numpy(), title='images')

    def test_show_image_grid(self):
        images = np.random.random_sample((4, 3, 128, 128))
        tensor = torch.from_numpy(images)
        self.visdomLogger.show_image_grid(tensor, title="image_grid")

    def test_show_image_grid_heatmap(self):
        images = np.random.random_sample((4, 3, 128, 128))
        self.visdomLogger.show_image_grid_heatmap(images,
                                                  title="image_grid_heatmap")

    def test_show_barplot(self):
        tensor = torch.from_numpy(np.random.random_sample(5))
        self.visdomLogger.show_barplot(tensor, title="barplot")

    def test_show_lineplot(self):
        x = [0, 1, 2, 3, 4, 5]
        y = np.random.random_sample(6)
        self.visdomLogger.show_lineplot(y, x, title="lineplot1")

    def test_show_piechart(self):
        array = torch.from_numpy(np.random.random_sample(5))
        self.visdomLogger.show_piechart(array, title="piechart")

    def test_show_scatterplot(self):
        array = torch.from_numpy(np.random.random_sample((5, 2)))
        self.visdomLogger.show_scatterplot(array, title="scatterplot")

    def test_show_value(self):
        val = torch.from_numpy(np.random.random_sample(1))
        self.visdomLogger.show_value(val, title="value")

        val = torch.from_numpy(np.random.random_sample(1))
        self.visdomLogger.show_value(val, title="value")

        val = torch.from_numpy(np.random.random_sample(1))
        self.visdomLogger.show_value(val, title="value", counter=4)

    def test_show_text(self):
        text = "\nTest 4 fun: zD ;-D 0o"
        self.visdomLogger.show_text(text, title='text')

    def test_get_roc_curve(self):
        array = np.random.random_sample(100)
        labels = np.random.choice((0, 1), 100)

        self.visdomLogger.show_roc_curve(array, labels, name="roc")

    def test_get_pr_curve(self):
        array = np.random.random_sample(100)
        labels = np.random.choice((0, 1), 100)

        self.visdomLogger.show_roc_curve(array, labels, name="pr")

    def test_get_classification_metric(self):
        array = np.random.random_sample(100)
        labels = np.random.choice((0, 1), 100)

        self.visdomLogger.show_classification_metrics(
            array,
            labels,
            metric=("roc-auc", "pr-score"),
            name="classification-metrics")

    def test_show_image_gradient(self):
        net = Net()
        random_input = torch.from_numpy(
            np.random.randn(28 * 28).reshape((1, 1, 28, 28))).float()
        fake_labels = torch.from_numpy(np.array([2])).long()
        criterion = torch.nn.CrossEntropyLoss()

        def err_fn(x):
            x = net(x)
            return criterion(x, fake_labels)

        self.visdomLogger.show_image_gradient(name="grads-vanilla",
                                              model=net,
                                              inpt=random_input,
                                              err_fn=err_fn,
                                              grad_type="vanilla")
        time.sleep(1)

        self.visdomLogger.show_image_gradient(name="grads-svanilla",
                                              model=net,
                                              inpt=random_input,
                                              err_fn=err_fn,
                                              grad_type="smooth-vanilla")
        time.sleep(1)

        self.visdomLogger.show_image_gradient(name="grads-guided",
                                              model=net,
                                              inpt=random_input,
                                              err_fn=err_fn,
                                              grad_type="guided")
        time.sleep(1)

        self.visdomLogger.show_image_gradient(name="grads-sguided",
                                              model=net,
                                              inpt=random_input,
                                              err_fn=err_fn,
                                              grad_type="smooth-guided")
        time.sleep(1)

    def test_plot_model_structure(self):
        net = Net()
        self.visdomLogger.plot_model_structure(net, [(1, 1, 28, 28)])

    def test_plot_model_statistics(self):
        net = Net()
        self.visdomLogger.plot_model_statistics(net, plot_grad=False)
        self.visdomLogger.plot_model_statistics(net, plot_grad=True)

    def test_show_embedding(self):
        array = torch.from_numpy(np.random.random_sample((100, 100)))
        self.visdomLogger.show_embedding(array, method="tsne")
        self.visdomLogger.show_embedding(array, method="umap")
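
The gradient and model-structure tests above instantiate a Net whose definition is not included in this listing. Any small CNN that accepts a (N, 1, 28, 28) input and returns class logits usable with CrossEntropyLoss would satisfy the tests; a hypothetical stand-in:

# Hypothetical stand-in for the Net used in the tests above: a tiny CNN that
# maps a (N, 1, 28, 28) input to 10 class logits, so it is compatible with
# the CrossEntropyLoss-based err_fn in test_show_image_gradient.
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (N, 8, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                            # (N, 8, 14, 14)
        )
        self.classifier = nn.Linear(8 * 14 * 14, 10)    # 10 class logits

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)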
Example #6
def train_model(Config, model, data_loader):

    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError as e:
            # the logger below cannot be created without trixi, so fail loudly
            raise ImportError("USE_VISLOGGER is set but the 'trixi' package is not installed") from e
        trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for type in ["train", "test", "validate"]:
        for metric in Config.METRIC_TYPES:
            metrics[metric + "_" + type] = [0]

    batch_gen_train = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE,
        type="train",
        subjects=getattr(Config, "TRAIN_SUBJECTS"))
    batch_gen_val = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE,
        type="validate",
        subjects=getattr(Config, "VALIDATE_SUBJECTS"))

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()
        # current_lr = Config.LEARNING_RATE * (Config.LR_DECAY ** epoch_nr)
        # current_lr = Config.LEARNING_RATE

        data_preparation_time = 0
        network_time = 0
        metrics_time = 0
        saving_time = 0
        plotting_time = 0

        batch_nr = {"train": 0, "test": 0, "validate": 0}

        if Config.LOSS_WEIGHT is None:
            weight_factor = None
        elif Config.LOSS_WEIGHT_LEN == -1:
            weight_factor = float(Config.LOSS_WEIGHT)
        else:
            # Linearly decrease from LOSS_WEIGHT to 1 over LOSS_WEIGHT_LEN epochs
            if epoch_nr < Config.LOSS_WEIGHT_LEN:
                weight_factor = -(
                    (Config.LOSS_WEIGHT - 1) /
                    float(Config.LOSS_WEIGHT_LEN)) * epoch_nr + float(
                        Config.LOSS_WEIGHT)
            else:
                weight_factor = 1.
            exp_utils.print_and_save(
                Config, "Current weight_factor: {}".format(weight_factor))

        if Config.ONLY_VAL:
            types = ["validate"]
        else:
            types = ["train", "validate"]

        for type in types:
            print_loss = []

            if Config.DIM == "2D":
                nr_of_samples = len(getattr(
                    Config,
                    type.upper() + "_SUBJECTS")) * Config.INPUT_DIM[0]
            else:
                nr_of_samples = len(getattr(Config,
                                            type.upper() + "_SUBJECTS"))

            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of updates/batches as with 2D U-Net
            nr_batches = int(
                int(nr_of_samples / Config.BATCH_SIZE) *
                Config.EPOCH_MULTIPLIER)

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for i in range(nr_batches):

                if type == "train":
                    batch = next(batch_gen_train)
                else:
                    batch = next(batch_gen_val)

                start_time_data_preparation = time.time()
                batch_nr[type] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, nr_of_classes, x, y)

                # print("x.shape: {}".format(x.shape))
                # print("y.shape: {}".format(y.shape))

                data_preparation_time += time.time() - start_time_data_preparation
                start_time_network = time.time()
                if type == "train":
                    nr_of_updates += 1
                    probs, metr_batch = model.train(
                        x, y, weight_factor=weight_factor)
                elif type == "validate":
                    probs, metr_batch = model.test(x,
                                                   y,
                                                   weight_factor=weight_factor)
                elif type == "test":
                    probs, metr_batch = model.test(x,
                                                   y,
                                                   weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()

                if Config.CALC_F1:
                    if Config.EXPERIMENT_TYPE == "peak_regression":
                        peak_f1_mean = np.array([
                            s.to('cpu')
                            for s in list(metr_batch["f1_macro"].values())
                        ]).mean()
                        metr_batch["f1_macro"] = peak_f1_mean

                        metrics = metric_utils.add_to_metrics(
                            metrics, metr_batch, type, Config.METRIC_TYPES)

                    else:
                        metr_batch["f1_macro"] = np.mean(
                            metr_batch["f1_macro"])
                        metrics = metric_utils.add_to_metrics(
                            metrics, metr_batch, type, Config.METRIC_TYPES)

                else:
                    metrics = metric_utils.calculate_metrics_onlyLoss(
                        metrics, metr_batch["loss"], type=type)

                metrics_time += time.time() - start_time_metrics

                print_loss.append(metr_batch["loss"])
                if batch_nr[type] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config, "{} Ep {}, Sp {}, loss {}, t print {}s, "
                        "t batch {}s".format(
                            type, epoch_nr, batch_nr[type] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if Config.USE_VISLOGGER:
                    plot_utils.plot_result_trixi(trixi, x, y, probs,
                                                 metr_batch["loss"],
                                                 metr_batch["f1_macro"],
                                                 epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        if Config.ONLY_VAL:
            metrics = metric_utils.normalize_last_element(metrics,
                                                          batch_nr["validate"],
                                                          type="validate")
            print("f1 macro validate: {}".format(
                round(metrics["f1_macro_validate"][0], 4)))
            return model

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics,
                                                      batch_nr["train"],
                                                      type="train")
        metrics = metric_utils.normalize_last_element(metrics,
                                                      batch_nr["validate"],
                                                      type="validate")
        # metrics = metric_utils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(
            epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics,
                             epoch_nr,
                             mode=Config.BEST_EPOCH_SELECTION)
        saving_time += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        pickle.dump(metrics, open(join(Config.EXP_PATH, "metrics.pkl"), "wb"))
        plot_utils.create_exp_plot(metrics,
                                   Config.EXP_PATH,
                                   Config.EXP_NAME,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics_all.png")
        plot_utils.create_exp_plot(metrics,
                                   Config.EXP_PATH,
                                   Config.EXP_NAME,
                                   without_first_epochs=True,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics.png")
        if "angle_err" in Config.METRIC_TYPES:
            plot_utils.create_exp_plot(metrics,
                                       Config.EXP_PATH,
                                       Config.EXP_NAME,
                                       without_first_epochs=True,
                                       keys=["loss", "angle_err"],
                                       types=["train", "validate"],
                                       selected_ax=["loss", "f1"],
                                       fig_name="metrics_angle.png")

        plotting_time += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(
            Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(
            Config,
            "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        exp_utils.print_and_save(
            Config,
            "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        exp_utils.print_and_save(
            Config,
            "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    with open(join(Config.EXP_PATH, "Hyperparameters.txt"),
              "a") as f:  # a for append
        f.write("\n\n")
        f.write("Average Epoch time: {}s".format(
            sum(epoch_times) / float(len(epoch_times))))

    return model
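
The "Average loss per batch over entire epoch" step above relies on two small metric_utils helpers that are not shown in this listing. Assuming each metrics["<metric>_<type>"] list holds one running sum per epoch, they presumably look roughly like this sketch:

# Sketch of the two helpers used above, under the assumption that each
# metrics list holds one accumulated value per epoch.
def normalize_last_element(metrics, n_batches, type):
    # divide the current epoch's accumulated value by the number of batches
    for key in metrics:
        if key.endswith("_" + type) and n_batches > 0:
            metrics[key][-1] /= float(n_batches)
    return metrics

def add_empty_element(metrics):
    # append a fresh accumulator for the next epoch
    for key in metrics:
        metrics[key].append(0)
    return metrics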