def train(self, HP):
    """Run the train/test/validate loop for ``HP.NUM_EPOCHS`` epochs.

    In every epoch, batches for the phases "train", "test" and "validate"
    are requested from ``self.dataManager`` and passed to ``self.model``
    (``train`` for the training phase, ``predict`` otherwise). Metrics are
    accumulated through ``MetricUtils``, averaged per epoch, pickled to
    ``<HP.EXP_PATH>/metrics.pkl`` and plotted. After the last epoch the
    average epoch time is appended to ``<HP.EXP_PATH>/Hyperparameters.txt``.

    Args:
        HP: Hyperparameter object. Attributes read here: USE_VISLOGGER,
            NUM_EPOCHS, LOSS_WEIGHT, LOSS_WEIGHT_LEN, BATCH_SIZE,
            TRAIN_SUBJECTS / TEST_SUBJECTS / VALIDATE_SUBJECTS, CALC_F1,
            EXPERIMENT_TYPE, THRESHOLD, PRINT_FREQ, LR_SCHEDULE,
            SAVE_WEIGHTS, EXP_PATH and EXP_NAME.

    Returns:
        The metrics dict, keyed ``"loss_<phase>"`` and ``"f1_macro_<phase>"``.
    """
    # Live plotting is optional: if trixi cannot be imported we continue
    # without it instead of failing later with a NameError.
    trixi = None
    if HP.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            print("WARNING: could not import trixi; visdom logging is disabled.")
        else:
            trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    ExpUtils.print_and_save(HP, socket.gethostname())

    phases = ["train", "test", "validate"]
    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for phase in phases:
        metrics["loss_" + phase] = [0]
        metrics["f1_macro_" + phase] = [0]

    for epoch_nr in range(HP.NUM_EPOCHS):
        start_time = time.time()

        network_time = 0
        metrics_time = 0
        saving_time = 0
        batch_nr = {phase: 0 for phase in phases}

        if HP.LOSS_WEIGHT_LEN == -1:
            # Constant loss weight for the whole training.
            weight_factor = float(HP.LOSS_WEIGHT)
        elif epoch_nr < HP.LOSS_WEIGHT_LEN:
            # Linear decrease from LOSS_WEIGHT (epoch 0) towards 1
            # (reached at epoch LOSS_WEIGHT_LEN).
            weight_factor = (-((HP.LOSS_WEIGHT - 1) / float(HP.LOSS_WEIGHT_LEN)) * epoch_nr
                             + float(HP.LOSS_WEIGHT))
        else:
            weight_factor = 1.

        for phase in phases:
            print_loss = []
            batch_generator = self.dataManager.get_batches(
                batch_size=HP.BATCH_SIZE,
                type=phase,
                subjects=getattr(HP, phase.upper() + "_SUBJECTS"))

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for batch in batch_generator:
                batch_nr[phase] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]   # (bs, nr_of_classes, x, y)

                start_time_network = time.time()
                if phase == "train":
                    nr_of_updates += 1
                    loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)
                else:
                    # "validate" and "test" are both pure forward passes.
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()
                if HP.CALC_F1:
                    if HP.EXPERIMENT_TYPE == "peak_regression":
                        # f1 is a dict here; average over all of its values.
                        f1_mean = np.array(list(f1.values())).mean()
                    else:
                        f1_mean = np.mean(f1)
                    metrics = MetricUtils.calculate_metrics(
                        metrics, None, None, loss, f1=f1_mean, type=phase,
                        threshold=HP.THRESHOLD)
                else:
                    metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=phase)
                metrics_time += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[phase] % HP.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    ExpUtils.print_and_save(
                        HP,
                        "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(
                            phase, epoch_nr, batch_nr[phase] * HP.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / HP.PRINT_FREQ, 3)))
                    print_loss = []

                if trixi is not None:
                    ExpUtils.plot_result_trixi(trixi, x, y, probs, loss, f1, epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        # Adapt LR
        if HP.LR_SCHEDULE:
            self.model.scheduler.step()
            self.model.print_current_lr()

        # Average loss per batch over entire epoch
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Save Weights
        start_time_saving = time.time()
        if HP.SAVE_WEIGHTS:
            self.model.save_model(metrics, epoch_nr)
        saving_time += time.time() - start_time_saving

        # Persist metrics and create plots. The context manager guarantees
        # the pickle file is flushed and closed before plotting starts.
        # For loading: pickle.load(open("metrics.pkl", "rb"))
        with open(join(HP.EXP_PATH, "metrics.pkl"), "wb") as metrics_file:
            pickle.dump(metrics, metrics_file)

        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        ExpUtils.print_and_save(HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
        ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < HP.NUM_EPOCHS - 1:
            metrics = MetricUtils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    # With zero epochs there is nothing to average (and dividing by
    # len(epoch_times) would raise ZeroDivisionError).
    if epoch_times:
        with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return metrics
def train_model(Config, model, data_loader):
    """Train (or only validate) ``model`` for ``Config.NUM_EPOCHS`` epochs.

    Two batch generators (train / validate) are created once from
    ``data_loader`` and consumed with ``next()`` for a fixed number of
    batches per epoch. Per-batch metrics are merged via ``_update_metrics``,
    averaged per epoch, pickled to ``<Config.EXP_PATH>/metrics.pkl`` and
    plotted. If ``Config.ONLY_VAL`` is set, a single validation pass is run
    and the function returns right after it.

    Args:
        Config: Experiment configuration. Attributes read here include
            USE_VISLOGGER, METRIC_TYPES, BATCH_SIZE, TRAIN_SUBJECTS,
            VALIDATE_SUBJECTS, NUM_EPOCHS, ONLY_VAL, DIM, INPUT_DIM,
            EPOCH_MULTIPLIER, PRINT_FREQ, LR_SCHEDULE, LR_SCHEDULE_MODE,
            SAVE_WEIGHTS, BEST_EPOCH_SELECTION, EXP_PATH and EXP_NAME.
        model: Object providing ``train``, ``test``, ``scheduler``,
            ``print_current_lr`` and ``save_model``.
        data_loader: Object providing ``get_batch_generator``.

    Returns:
        ``model`` (returned on every path, not only in ONLY_VAL mode).
    """
    # Live plotting is optional: if trixi cannot be imported we continue
    # without it instead of failing later with a NameError.
    trixi = None
    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            print("WARNING: could not import trixi; visdom logging is disabled.")
        else:
            trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for phase in ["train", "test", "validate"]:
        for metric in Config.METRIC_TYPES:
            metrics[metric + "_" + phase] = [0]

    batch_gen_train = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="train",
        subjects=getattr(Config, "TRAIN_SUBJECTS"))
    batch_gen_val = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="validate",
        subjects=getattr(Config, "VALIDATE_SUBJECTS"))

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()

        timings = defaultdict(int)
        batch_nr = defaultdict(int)
        weight_factor = _get_weights_for_this_epoch(Config, epoch_nr)
        phases = ["validate"] if Config.ONLY_VAL else ["train", "validate"]

        for phase in phases:
            print_loss = []

            nr_of_samples = len(getattr(Config, phase.upper() + "_SUBJECTS"))
            if Config.DIM == "2D":
                nr_of_samples *= Config.INPUT_DIM[0]

            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of
            # updates/batches as with 2D U-Net
            nr_batches = int(int(nr_of_samples / Config.BATCH_SIZE) * Config.EPOCH_MULTIPLIER)

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for _ in range(nr_batches):
                batch = next(batch_gen_train) if phase == "train" else next(batch_gen_val)
                batch_nr[phase] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]   # (bs, nr_of_classes, x, y)

                start_time_network = time.time()
                if phase == "train":
                    nr_of_updates += 1
                    probs, metr_batch = model.train(x, y, weight_factor=weight_factor)
                else:
                    probs, metr_batch = model.test(x, y, weight_factor=weight_factor)
                timings["network_time"] += time.time() - start_time_network

                start_time_metrics = time.time()
                metrics = _update_metrics(Config, metrics, metr_batch, phase)
                timings["metrics_time"] += time.time() - start_time_metrics

                print_loss.append(metr_batch["loss"])
                if batch_nr[phase] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config,
                        "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(
                            phase, epoch_nr, batch_nr[phase] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if trixi is not None:
                    plot_utils.plot_result_trixi(trixi, x, y, probs, metr_batch["loss"],
                                                 metr_batch["f1_macro"], epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        if Config.ONLY_VAL:
            # Validation-only mode: report the score of the single pass and stop.
            metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"],
                                                          type="validate")
            print("f1 macro validate: {}".format(round(metrics["f1_macro_validate"][0], 4)))
            return model

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"],
                                                      type="validate")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr, mode=Config.BEST_EPOCH_SELECTION)
        timings["saving_time"] += time.time() - start_time_saving

        # Persist metrics and create plots. The context manager guarantees
        # the pickle file is flushed and closed before plotting starts.
        with open(join(Config.EXP_PATH, "metrics.pkl"), "wb") as metrics_file:
            pickle.dump(metrics, metrics_file)

        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics_all.png")
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   without_first_epochs=True,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics.png")
        if "angle_err" in Config.METRIC_TYPES:
            plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                       without_first_epochs=True,
                                       keys=["loss", "angle_err"],
                                       types=["train", "validate"],
                                       selected_ax=["loss", "f1"],
                                       fig_name="metrics_angle.png")

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time UNet: {}s".format(
            epoch_nr, timings["network_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time metrics: {}s".format(
            epoch_nr, timings["metrics_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time saving files: {}s".format(
            epoch_nr, timings["saving_time"]))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    # With zero epochs there is nothing to average (and dividing by
    # len(epoch_times) would raise ZeroDivisionError).
    if epoch_times:
        with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "a") as f:
            f.write("\n\nAverage Epoch time: {}s".format(
                sum(epoch_times) / float(len(epoch_times))))

    return model
def train_model(Config, model, data_loader):
    """Train ``model`` for ``Config.NUM_EPOCHS`` epochs on train/test/validate.

    In every epoch a fresh batch generator is requested from ``data_loader``
    for each of the phases "train", "test" and "validate", and a fixed
    number of batches is drawn from it. Metrics are accumulated through
    ``metric_utils``, averaged per epoch, pickled to
    ``<Config.EXP_PATH>/metrics.pkl`` and plotted. After the last epoch the
    average epoch time is appended to
    ``<Config.EXP_PATH>/Hyperparameters.txt``.

    Args:
        Config: Experiment configuration. Attributes read here include
            USE_VISLOGGER, NUM_EPOCHS, LOSS_WEIGHT, LOSS_WEIGHT_LEN,
            BATCH_SIZE, TRAIN_SUBJECTS / TEST_SUBJECTS / VALIDATE_SUBJECTS,
            DIM, INPUT_DIM, EPOCH_MULTIPLIER, CALC_F1, EXPERIMENT_TYPE,
            THRESHOLD, PRINT_FREQ, LR_SCHEDULE, LR_SCHEDULE_MODE,
            SAVE_WEIGHTS, EXP_PATH and EXP_NAME.
        model: Object providing ``train``, ``test``, ``scheduler``,
            ``print_current_lr`` and ``save_model``.
        data_loader: Object providing ``get_batch_generator``.

    Returns:
        ``model``.
    """
    # Live plotting is optional: if trixi cannot be imported we continue
    # without it instead of failing later with a NameError.
    trixi = None
    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            print("WARNING: could not import trixi; visdom logging is disabled.")
        else:
            trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())

    phases = ["train", "test", "validate"]
    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for phase in phases:
        metrics["loss_" + phase] = [0]
        metrics["f1_macro_" + phase] = [0]

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()

        network_time = 0
        metrics_time = 0
        saving_time = 0
        batch_nr = {phase: 0 for phase in phases}

        if Config.LOSS_WEIGHT_LEN == -1:
            # Constant loss weight for the whole training.
            weight_factor = float(Config.LOSS_WEIGHT)
        elif epoch_nr < Config.LOSS_WEIGHT_LEN:
            # Linear decrease from LOSS_WEIGHT (epoch 0) towards 1
            # (reached at epoch LOSS_WEIGHT_LEN).
            weight_factor = (-((Config.LOSS_WEIGHT - 1) / float(Config.LOSS_WEIGHT_LEN)) * epoch_nr
                             + float(Config.LOSS_WEIGHT))
        else:
            weight_factor = 1.

        for phase in phases:
            print_loss = []
            subjects = getattr(Config, phase.upper() + "_SUBJECTS")
            batch_gen = data_loader.get_batch_generator(
                batch_size=Config.BATCH_SIZE, type=phase, subjects=subjects)

            nr_of_samples = len(subjects)
            if Config.DIM == "2D":
                nr_of_samples *= Config.INPUT_DIM[0]

            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of
            # updates/batches as with 2D U-Net
            nr_batches = int(int(nr_of_samples / Config.BATCH_SIZE) * Config.EPOCH_MULTIPLIER)

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for _ in range(nr_batches):
                batch = next(batch_gen)
                batch_nr[phase] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]   # (bs, nr_of_classes, x, y)

                start_time_network = time.time()
                if phase == "train":
                    nr_of_updates += 1
                    loss, probs, f1 = model.train(x, y, weight_factor=weight_factor)
                else:
                    # "validate" and "test" are both pure forward passes.
                    loss, probs, f1 = model.test(x, y, weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()
                if Config.CALC_F1:
                    if Config.EXPERIMENT_TYPE == "peak_regression":
                        # f1 is a dict here; every value is moved to the CPU
                        # with .to('cpu') before averaging.
                        f1_mean = np.array([s.to('cpu') for s in f1.values()]).mean()
                    else:
                        f1_mean = np.mean(f1)
                    metrics = metric_utils.calculate_metrics(
                        metrics, None, None, loss, f1=f1_mean, type=phase,
                        threshold=Config.THRESHOLD)
                else:
                    metrics = metric_utils.calculate_metrics_onlyLoss(metrics, loss, type=phase)
                metrics_time += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[phase] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config,
                        "{} Ep {}, Sp {}, loss {}, t print {}s, "
                        "t batch {}s".format(
                            phase, epoch_nr, batch_nr[phase] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if trixi is not None:
                    plot_utils.plot_result_trixi(trixi, x, y, probs, loss, f1, epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"],
                                                      type="validate")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr)
        saving_time += time.time() - start_time_saving

        # Persist metrics and create plots. The context manager guarantees
        # the pickle file is flushed and closed before plotting starts.
        with open(join(Config.EXP_PATH, "metrics.pkl"), "wb") as metrics_file:
            pickle.dump(metrics, metrics_file)

        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME)
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   without_first_epochs=True)

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time saving files: {}s".format(
            epoch_nr, saving_time))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    # With zero epochs there is nothing to average (and dividing by
    # len(epoch_times) would raise ZeroDivisionError).
    if epoch_times:
        with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return model
def setUp(self):
    """Create a fresh ``PytorchVisdomLogger`` (default arguments) for the test."""
    logger = PytorchVisdomLogger()
    self.visdomLogger = logger
class TestPytorchVisdomLogger(unittest.TestCase):
    """Smoke tests that call each ``PytorchVisdomLogger`` method with random data.

    The tests contain no assertions on the produced plots; they only check
    that every call completes without raising.
    """

    @classmethod
    def setUpClass(cls):
        """Try to start a visdom server once for the whole test class."""
        super(TestPytorchVisdomLogger, cls).setUpClass()
        try:
            start_visdom()
        except Exception:
            # Best effort only. ``Exception`` (rather than a bare ``except``)
            # keeps KeyboardInterrupt / SystemExit propagating.
            print("Could not start visdom, it might be already running.")

    def setUp(self):
        """Create a fresh logger for every test."""
        self.visdomLogger = PytorchVisdomLogger()

    def test_show_image(self):
        image = np.random.random_sample((3, 128, 128))
        tensor = torch.from_numpy(image)
        self.visdomLogger.show_image(tensor.numpy(), title='image')

    def test_show_images(self):
        images = np.random.random_sample((4, 3, 128, 128))
        tensors = torch.from_numpy(images)
        self.visdomLogger.show_images(tensors.numpy(), title='images')

    def test_show_image_grid(self):
        images = np.random.random_sample((4, 3, 128, 128))
        tensor = torch.from_numpy(images)
        self.visdomLogger.show_image_grid(tensor, title="image_grid")

    def test_show_image_grid_heatmap(self):
        images = np.random.random_sample((4, 3, 128, 128))
        self.visdomLogger.show_image_grid_heatmap(images, title="image_grid_heatmap")

    def test_show_barplot(self):
        tensor = torch.from_numpy(np.random.random_sample(5))
        self.visdomLogger.show_barplot(tensor, title="barplot")

    def test_show_lineplot(self):
        x = [0, 1, 2, 3, 4, 5]
        y = np.random.random_sample(6)
        self.visdomLogger.show_lineplot(y, x, title="lineplot1")

    def test_show_piechart(self):
        array = torch.from_numpy(np.random.random_sample(5))
        self.visdomLogger.show_piechart(array, title="piechart")

    def test_show_scatterplot(self):
        array = torch.from_numpy(np.random.random_sample((5, 2)))
        self.visdomLogger.show_scatterplot(array, title="scatterplot")

    def test_show_value(self):
        # Two calls without a counter, then one with an explicit counter.
        val = torch.from_numpy(np.random.random_sample(1))
        self.visdomLogger.show_value(val, title="value")
        val = torch.from_numpy(np.random.random_sample(1))
        self.visdomLogger.show_value(val, title="value")
        val = torch.from_numpy(np.random.random_sample(1))
        self.visdomLogger.show_value(val, title="value", counter=4)

    def test_show_text(self):
        text = "\nTest 4 fun: zD ;-D 0o"
        self.visdomLogger.show_text(text, title='text')

    def test_get_roc_curve(self):
        array = np.random.random_sample(100)
        labels = np.random.choice((0, 1), 100)
        self.visdomLogger.show_roc_curve(array, labels, name="roc")

    def test_get_pr_curve(self):
        array = np.random.random_sample(100)
        labels = np.random.choice((0, 1), 100)
        # This test used to call show_roc_curve (copy-paste slip), so the
        # precision-recall plot was never exercised.
        self.visdomLogger.show_pr_curve(array, labels, name="pr")

    def test_get_classification_metric(self):
        array = np.random.random_sample(100)
        labels = np.random.choice((0, 1), 100)
        self.visdomLogger.show_classification_metrics(
            array, labels,
            metric=("roc-auc", "pr-score"),
            name="classification-metrics")

    def test_show_image_gradient(self):
        net = Net()
        random_input = torch.from_numpy(
            np.random.randn(28 * 28).reshape((1, 1, 28, 28))).float()
        fake_labels = torch.from_numpy(np.array([2])).long()
        criterion = torch.nn.CrossEntropyLoss()

        def err_fn(x):
            x = net(x)
            return criterion(x, fake_labels)

        # One call per gradient type, each followed by a one-second pause.
        gradient_plots = (
            ("grads-vanilla", "vanilla"),
            ("grads-svanilla", "smooth-vanilla"),
            ("grads-guided", "guided"),
            ("grads-sguided", "smooth-guided"),
        )
        for plot_name, grad_type in gradient_plots:
            self.visdomLogger.show_image_gradient(name=plot_name,
                                                  model=net,
                                                  inpt=random_input,
                                                  err_fn=err_fn,
                                                  grad_type=grad_type)
            time.sleep(1)

    def test_plot_model_structure(self):
        net = Net()
        self.visdomLogger.plot_model_structure(net, [(1, 1, 28, 28)])

    def test_plot_model_statistics(self):
        net = Net()
        self.visdomLogger.plot_model_statistics(net, plot_grad=False)
        self.visdomLogger.plot_model_statistics(net, plot_grad=True)

    def test_show_embedding(self):
        array = torch.from_numpy(np.random.random_sample((100, 100)))
        self.visdomLogger.show_embedding(array, method="tsne")
        self.visdomLogger.show_embedding(array, method="umap")
def train_model(Config, model, data_loader):
    """Train (or only validate) ``model`` for ``Config.NUM_EPOCHS`` epochs.

    Two batch generators (train / validate) are created once from
    ``data_loader`` and consumed with ``next()`` for a fixed number of
    batches per epoch. Per-batch metrics are accumulated through
    ``metric_utils``, averaged per epoch, pickled to
    ``<Config.EXP_PATH>/metrics.pkl`` and plotted. If ``Config.ONLY_VAL`` is
    set, a single validation pass is run and the function returns right
    after it. Otherwise, after the last epoch, the average epoch time is
    appended to ``<Config.EXP_PATH>/Hyperparameters.txt``.

    Args:
        Config: Experiment configuration. Attributes read here include
            USE_VISLOGGER, METRIC_TYPES, BATCH_SIZE, TRAIN_SUBJECTS,
            VALIDATE_SUBJECTS, NUM_EPOCHS, LOSS_WEIGHT, LOSS_WEIGHT_LEN,
            ONLY_VAL, DIM, INPUT_DIM, EPOCH_MULTIPLIER, CALC_F1,
            EXPERIMENT_TYPE, PRINT_FREQ, LR_SCHEDULE, LR_SCHEDULE_MODE,
            SAVE_WEIGHTS, BEST_EPOCH_SELECTION, EXP_PATH and EXP_NAME.
        model: Object providing ``train``, ``test``, ``scheduler``,
            ``print_current_lr`` and ``save_model``.
        data_loader: Object providing ``get_batch_generator``.

    Returns:
        ``model``.
    """
    # Live plotting is optional: if trixi cannot be imported we continue
    # without it instead of failing later with a NameError.
    trixi = None
    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            print("WARNING: could not import trixi; visdom logging is disabled.")
        else:
            trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for phase in ["train", "test", "validate"]:
        for metric in Config.METRIC_TYPES:
            metrics[metric + "_" + phase] = [0]

    batch_gen_train = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="train",
        subjects=getattr(Config, "TRAIN_SUBJECTS"))
    batch_gen_val = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="validate",
        subjects=getattr(Config, "VALIDATE_SUBJECTS"))

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()

        network_time = 0
        metrics_time = 0
        saving_time = 0
        batch_nr = {"train": 0, "test": 0, "validate": 0}

        if Config.LOSS_WEIGHT is None:
            weight_factor = None
        elif Config.LOSS_WEIGHT_LEN == -1:
            # Constant loss weight for the whole training.
            weight_factor = float(Config.LOSS_WEIGHT)
        else:
            # Linearly decrease from LOSS_WEIGHT to 1 over LOSS_WEIGHT_LEN epochs
            if epoch_nr < Config.LOSS_WEIGHT_LEN:
                weight_factor = (-((Config.LOSS_WEIGHT - 1) / float(Config.LOSS_WEIGHT_LEN)) * epoch_nr
                                 + float(Config.LOSS_WEIGHT))
            else:
                weight_factor = 1.
            # NOTE(review): the source this was restored from had lost its
            # indentation; confirm this log line belongs to the decay branch
            # only and not to every epoch.
            exp_utils.print_and_save(Config, "Current weight_factor: {}".format(weight_factor))

        phases = ["validate"] if Config.ONLY_VAL else ["train", "validate"]

        for phase in phases:
            print_loss = []

            nr_of_samples = len(getattr(Config, phase.upper() + "_SUBJECTS"))
            if Config.DIM == "2D":
                nr_of_samples *= Config.INPUT_DIM[0]

            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of
            # updates/batches as with 2D U-Net
            nr_batches = int(int(nr_of_samples / Config.BATCH_SIZE) * Config.EPOCH_MULTIPLIER)

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for _ in range(nr_batches):
                batch = next(batch_gen_train) if phase == "train" else next(batch_gen_val)
                batch_nr[phase] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]   # (bs, nr_of_classes, x, y)

                start_time_network = time.time()
                if phase == "train":
                    nr_of_updates += 1
                    probs, metr_batch = model.train(x, y, weight_factor=weight_factor)
                else:
                    probs, metr_batch = model.test(x, y, weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()
                if Config.CALC_F1:
                    # Collapse the batch f1 to a single number (stored back
                    # into metr_batch) before adding it to the running metrics.
                    if Config.EXPERIMENT_TYPE == "peak_regression":
                        # f1_macro is a dict here; every value is moved to
                        # the CPU with .to('cpu') before averaging.
                        metr_batch["f1_macro"] = np.array(
                            [s.to('cpu') for s in metr_batch["f1_macro"].values()]).mean()
                    else:
                        metr_batch["f1_macro"] = np.mean(metr_batch["f1_macro"])
                    metrics = metric_utils.add_to_metrics(metrics, metr_batch, phase,
                                                          Config.METRIC_TYPES)
                else:
                    metrics = metric_utils.calculate_metrics_onlyLoss(
                        metrics, metr_batch["loss"], type=phase)
                metrics_time += time.time() - start_time_metrics

                print_loss.append(metr_batch["loss"])
                if batch_nr[phase] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config,
                        "{} Ep {}, Sp {}, loss {}, t print {}s, "
                        "t batch {}s".format(
                            phase, epoch_nr, batch_nr[phase] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if trixi is not None:
                    plot_utils.plot_result_trixi(trixi, x, y, probs, metr_batch["loss"],
                                                 metr_batch["f1_macro"], epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        if Config.ONLY_VAL:
            # Validation-only mode: report the score of the single pass and stop.
            metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"],
                                                          type="validate")
            print("f1 macro validate: {}".format(round(metrics["f1_macro_validate"][0], 4)))
            return model

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"],
                                                      type="validate")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr, mode=Config.BEST_EPOCH_SELECTION)
        saving_time += time.time() - start_time_saving

        # Persist metrics and create plots. The context manager guarantees
        # the pickle file is flushed and closed before plotting starts.
        with open(join(Config.EXP_PATH, "metrics.pkl"), "wb") as metrics_file:
            pickle.dump(metrics, metrics_file)

        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics_all.png")
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   without_first_epochs=True,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics.png")
        if "angle_err" in Config.METRIC_TYPES:
            plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                       without_first_epochs=True,
                                       keys=["loss", "angle_err"],
                                       types=["train", "validate"],
                                       selected_ax=["loss", "f1"],
                                       fig_name="metrics_angle.png")

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time saving files: {}s".format(
            epoch_nr, saving_time))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    # With zero epochs there is nothing to average (and dividing by
    # len(epoch_times) would raise ZeroDivisionError).
    if epoch_times:
        with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return model