def train_gan(dataf):
    """Build the GAN and train it on the faces stored in an HDF5 file.

    dataf: path to an HDF5 file containing a 'faces' dataset.
    """
    # Assemble generator, discriminator and the combined GAN model.
    gen, disc, gan = build_networks()

    # The CSV logger is driven by hand since there is no model.fit() loop here.
    logger = CSVLogger('loss.csv')
    logger.on_train_begin()

    # Run 5000 training iterations over the face dataset.
    with h5py.File(dataf, 'r') as f:
        faces = f.get('faces')
        run_batches(gen, disc, gan, faces, logger, range(5000))

    logger.on_train_end()
def train_gan(dataf):
    """Create the GAN networks and run the (effectively open-ended) training
    loop over the HDF5 face dataset at *dataf*."""
    gen, disc, gan = build_networks()

    # To resume from a snapshot (or start from pretrained generator weights),
    # restore them here before the loop starts:
    #load_weights(gen, Args.genw)
    #load_weights(disc, Args.discw)

    # Keras callbacks can be used standalone; the CSV logger is driven manually.
    logger = CSVLogger('loss.csv')
    logger.on_train_begin()  # creates/initialises the csv file

    with h5py.File(dataf, 'r') as f:
        faces = f.get('faces')
        # 1e6 iterations: effectively "train until interrupted".
        run_batches(gen, disc, gan, faces, logger, range(1000000))

    logger.on_train_end()
def train(self):
    """Run the full training loop: per-epoch batches, validation, CSV/loss
    logging, example-image plotting, model snapshots and early stopping.

    Keras callbacks (CSVLogger, EarlyStopping) are driven manually because
    training does not go through model.fit().
    """
    log.info('Training Model')
    self.init_train_data()
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    # Early stopping monitors the fused modality-2 validation loss and is
    # attached to the Segmentor sub-model.
    es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
    es.model = self.model.Segmentor
    es.on_train_begin()
    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}  # per-epoch mean of every loss
    progress_bar = Progbar(target=self.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}  # per-batch losses, this epoch
        epoch_loss_list = []
        for self.batch in range(self.batches):
            # train_batch appends this batch's losses into epoch_loss.
            self.train_batch(epoch_loss)
            progress_bar.update((self.batch + 1) * self.conf.batch_size)
        self.validate(epoch_loss)
        # Reduce the per-batch losses to one mean value per loss name.
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
                 % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}
        # CSVLogger requires a model exposing stop_training; D_Mask is used
        # as that host model here.
        cl.model = self.model.D_Mask
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)
        # Plot some example images
        self.img_callback.on_epoch_end(self.epoch)
        self.model.save_models()
        if self.stop_criterion(es, logs):
            log.info('Finished training from early stopping criterion')
            break
def train_gan(dataf, iters=1000000, disc_start=20, cont=False):
    """Build the GAN networks and train them on an HDF5 face dataset.

    Parameters
    ----------
    dataf : str
        Path to an HDF5 file containing a 'faces' dataset.
    iters : int
        Number of training iterations to run.
    disc_start : int
        Passed through to run_batches (discriminator warm-up parameter).
    cont : bool
        If True, resume by loading the latest generator/discriminator
        snapshots before training.
    """
    gen, disc, gan = build_networks()
    # Idiomatic truthiness test (was `if cont == True:`).
    if cont:
        # Resume training from the most recent snapshot pair.
        load_weights(gen, "snapshots/{}.gen.hdf5".format(Args.batch_len - 1))
        load_weights(disc, "snapshots/{}.disc.hdf5".format(Args.batch_len - 1))
    # Keras callbacks can be used standalone; drive the CSV logger manually.
    logger = CSVLogger('loss.csv')
    logger.on_train_begin()  # initialize csv file
    with h5py.File(dataf, 'r') as f:
        faces = f.get('faces')
        run_batches(gen, disc, gan, faces, logger, range(iters), disc_start)
    logger.on_train_end()
def train(self):
    """Training loop with stochastic weight averaging (SWA).

    Per epoch: train all batches, update SWA shadow models, validate, log
    losses (CSV + SaveLoss), plot images and save models.  When the early
    stopping criterion fires, the final models are REPLACED by their SWA
    counterparts before the last save.
    """
    log.info('Training Model')
    self.init_train_data()
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    # Early stopping watches the fused modality-2 validation loss.
    es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
    es.model = self.model.Segmentor
    es.on_train_begin()
    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}  # epoch-mean history per loss
    progress_bar = Progbar(target=self.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []
        for self.batch in range(self.batches):
            self.train_batch(epoch_loss)
            progress_bar.update((self.batch + 1) * self.conf.batch_size)
        # Push the current weights into the SWA running averages.
        self.set_swa_model_weights()
        for swa_m in self.get_swa_models():
            swa_m.on_epoch_end(self.epoch)
        self.validate(epoch_loss)
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(
            str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.5f' for l in loss_names]))
            % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}
        # CSVLogger needs a host model exposing stop_training.
        cl.model = self.model.D_Mask
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)
        # print images
        self.img_callback.on_epoch_end(self.epoch)
        self.save_models()
        if self.stop_criterion(es, logs):
            log.info('Finished training from early stopping criterion')
            es.on_train_end(logs)
            cl.on_train_end(logs)
            # Finalize every SWA shadow model before swapping them in.
            for swa_m in self.get_swa_models():
                swa_m.on_train_end()
            # Set final model parameters based on SWA
            self.model.D_Mask = self.swa_D_Mask.model
            self.model.D_Image1 = self.swa_D_Image1.model
            self.model.D_Image2 = self.swa_D_Image2.model
            self.model.Encoders_Anatomy[0] = self.swa_Enc_Anatomy1.model
            self.model.Encoders_Anatomy[1] = self.swa_Enc_Anatomy2.model
            self.model.Enc_Modality = self.swa_Enc_Modality.model
            self.model.Anatomy_Fuser = self.swa_Anatomy_Fuser.model
            self.model.Segmentor = self.swa_Segmentor.model
            self.model.Decoder = self.swa_Decoder.model
            self.model.Balancer = self.swa_Balancer.model
            self.save_models()
            break
def train(self):
    """WGAN-GP style cycle training loop.

    Two augmented generators stream (image, mask) batches from the 'p' and
    'h' domains; a third streams random p-masks.  Each batch trains the
    critic (more iterations for the first 25 epochs) and then the
    generator; losses are aggregated per epoch and logged via CSVLogger /
    SaveLoss, with image pools handed to an ImageCallback.
    """
    self.init_train_data()
    # make genetrated data
    gen_dict = self.get_datagen_params()
    p_gen = ImageDataGenerator(**gen_dict).flow(x=self.p_images, y=self.p_masks,
                                                batch_size=self.conf.batch_size)
    h_gen = ImageDataGenerator(**gen_dict).flow(x=self.h_images, y=self.h_masks,
                                                batch_size=self.conf.batch_size)
    random_p_masks = ImageDataGenerator(**gen_dict).flow(x=self.p_masks,
                                                         batch_size=self.conf.batch_size)
    # initialize training
    batches = int(np.ceil(self.conf.data_len / self.conf.batch_size))
    progress_bar = Progbar(target=batches * self.conf.batch_size)
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    img_clb = ImageCallback(self.conf, self.model, self.comet_exp)
    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}
    # start training
    for epoch in range(self.conf.epochs):
        log.info("Train epoch %d/%d" % (epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []
        # Pools of example images/masks kept only for visualisation.
        pool_to_print_p_img, pool_to_print_p_msk, pool_to_print_h_img, pool_to_print_h_msk = [], [], [], []
        for batch in range(batches):
            p_img, p_msk = next(p_gen)
            h_img, h_msk = next(h_gen)
            r_p_msk = next(random_p_masks)
            if len(pool_to_print_p_img) < 30:
                pool_to_print_p_img.append(p_img[0])
                pool_to_print_p_msk.append(p_msk[0])
            if len(pool_to_print_h_img) < 30:
                pool_to_print_h_img.append(h_img[0])
                pool_to_print_h_msk.append(h_msk[0])
            # Adversarial ground truths (WGAN convention: real=-1, fake=+1;
            # 'dummy' feeds the gradient-penalty outputs).
            real_pred = -np.ones((h_img.shape[0], 1))
            fake_pred = np.ones((h_img.shape[0], 1))
            dummy = np.zeros((h_img.shape[0], 1))
            dummy_Img = np.ones(h_img.shape)
            if self.conf.self_rec:
                h_test_sr = self.model.train_self_rec.fit([h_img, h_msk], [h_img, h_img],
                                                          epochs=1, verbose=0)
                epoch_loss["test_self_rec_loss"].append(np.mean(h_test_sr.history["loss"]))
            else:
                epoch_loss["test_self_rec_loss"].append(0)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Get a group of synthetic msks and imgs
            cy1_pse_h_img = self.model.G_d_to_h.predict(p_img)
            cy1_seg_d_msk = self.model.S_d_to_msk.predict(p_img)
            cy2_fake_h_img = self.model.G_h_to_d.predict([h_img, h_msk])
            # Critic is trained ncritic[0] times per batch for the first 25
            # epochs, ncritic[1] times afterwards (both branches are
            # otherwise identical).
            if epoch < 25:
                for _ in range(self.conf.ncritic[0]):
                    # Random interpolation points for the gradient penalty.
                    cy1_epsilon = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                    cy1_average = cy1_epsilon * h_img + (1 - cy1_epsilon) * cy1_pse_h_img
                    cy1_epsilon_msk = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                    # NOTE(review): uses (1 - cy1_epsilon), not
                    # (1 - cy1_epsilon_msk) -- the mask interpolation mixes
                    # two different epsilons; looks like a bug, confirm.
                    cy1_average_msk = cy1_epsilon_msk * r_p_msk + (1 - cy1_epsilon) * cy1_seg_d_msk
                    cy2_epsilon = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                    cy2_average = cy2_epsilon * h_img + (1 - cy2_epsilon) * cy2_fake_h_img
                    h_d = self.model.critic_model.fit([h_img, cy1_pse_h_img, cy1_average,
                                                       r_p_msk, cy1_seg_d_msk, cy1_average_msk,
                                                       h_img, cy2_fake_h_img, cy2_average],
                                                      [real_pred, fake_pred, dummy,
                                                       real_pred, fake_pred, dummy,
                                                       real_pred, fake_pred, dummy],
                                                      epochs=1, verbose=0)
            else:
                for _ in range(self.conf.ncritic[1]):
                    cy1_epsilon = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                    cy1_average = cy1_epsilon * h_img + (1 - cy1_epsilon) * cy1_pse_h_img
                    cy1_epsilon_msk = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                    # NOTE(review): same epsilon mix-up as above.
                    cy1_average_msk = cy1_epsilon_msk * r_p_msk + (1 - cy1_epsilon) * cy1_seg_d_msk
                    cy2_epsilon = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                    cy2_average = cy2_epsilon * h_img + (1 - cy2_epsilon) * cy2_fake_h_img
                    h_d = self.model.critic_model.fit([h_img, cy1_pse_h_img, cy1_average,
                                                       r_p_msk, cy1_seg_d_msk, cy1_average_msk,
                                                       h_img, cy2_fake_h_img, cy2_average],
                                                      [real_pred, fake_pred, dummy,
                                                       real_pred, fake_pred, dummy,
                                                       real_pred, fake_pred, dummy],
                                                      epochs=1, verbose=0)
            # print(h_d.history)
            # Aggregate critic losses from the last critic fit of this batch.
            d_dis_pse_image_loss = np.mean([h_d.history['dis_cy1_I_pse_h_loss'],
                                            h_d.history['dis_cy2_I_pse_h_loss']])
            d_dis_r_image_loss = np.mean([h_d.history['dis_cy1_I_h_loss'],
                                          h_d.history['dis_cy2_I_h_loss']])
            d_dis_d_mask_loss = np.mean([h_d.history['dis_cy1_M_d_loss'],
                                         h_d.history['dis_cy1_M_seg_d_loss']])
            d_gp_loss = np.mean([h_d.history['gp_cy1_I_h_loss'],
                                 h_d.history['gp_cy2_I_h_loss'],
                                 h_d.history['gp_cy1_M_d_loss']])
            epoch_loss['d_dis_pse_image_loss'].append(d_dis_pse_image_loss)
            epoch_loss['d_dis_r_image_loss'].append(d_dis_r_image_loss)
            epoch_loss['d_dis_d_mask_loss'].append(d_dis_d_mask_loss)
            epoch_loss['d_gp_loss'].append(d_gp_loss)
            # --------------------
            #  Train Generator
            # --------------------
            h_g = self.model.gan.fit([p_img, h_img, h_msk],
                                     [real_pred, real_pred, p_img, real_pred, h_img, h_msk],
                                     epochs=1, verbose=0)
            g_dis_pse_image_loss = np.mean([h_g.history['cy1_dis_I_pse_h_loss'],
                                            h_g.history['cy2_dis_I_pse_d_loss']])
            g_rec_image_loss = np.mean([h_g.history['cy2_I_rec_h_loss'],
                                        h_g.history['cy1_I_rec_d_loss']])
            g_dis_d_mask_loss = np.mean(h_g.history['cy1_dis_M_seg_d_loss'])
            epoch_loss['g_dis_pse_image_loss'].append(g_dis_pse_image_loss)
            epoch_loss['g_rec_image_loss'].append(g_rec_image_loss)
            epoch_loss['g_dis_d_mask_loss'].append(g_dis_d_mask_loss)
            # print(h_g.history)
            # Plot the progress
            progress_bar.update((batch + 1) * self.conf.batch_size)
        # NOTE(review): this aggregation block appears twice back-to-back in
        # the original, so total_loss receives two entries per epoch; only the
        # last entry is used below.  Looks like an accidental duplication --
        # confirm against the project history before removing one copy.
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
                 % ((epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}
        # CSVLogger needs a host model exposing stop_training.
        cl.model = self.model.D_pse_h
        cl.model.stop_training = False
        cl.on_epoch_end(epoch, logs)
        sl.on_epoch_end(epoch, logs)
        pool_to_print_p_img = np.asarray(pool_to_print_p_img)
        pool_to_print_p_msk = np.asarray(pool_to_print_p_msk)
        pool_to_print_h_img = np.asarray(pool_to_print_h_img)
        pool_to_print_h_msk = np.asarray(pool_to_print_h_msk)
        print("pool_to_print_p_img: ", np.shape(pool_to_print_p_img))
        img_clb.on_epoch_end(epoch, pool_to_print_p_img, pool_to_print_p_msk,
                             pool_to_print_h_img, pool_to_print_h_msk)
class ExtendedLogger(Callback):
    """Composite Keras callback bundling a CSVLogger and a TensorBoard
    callback, plus end-of-epoch pose evaluation.

    On each epoch end it builds a prediction model up to *prediction_layer*,
    predicts on the validation data, dumps predictions to .npy, saves
    error-CDF histograms and top-down trajectory plots, and injects the
    user-registered validation metrics into the logs forwarded to the
    wrapped callbacks.
    """

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):
        """
        prediction_layer: name of the model layer whose output is evaluated.
        output_dir: root directory for csv/tensorboard/predictions/plots.
        stateful: if True, states are reset every `stateful_reset_interval`
            samples during prediction (interval is then mandatory).
        starting_indicies: dict of sequence start indices per split; a
            default for 'valid' is synthesized lazily if None.
        """
        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If model is stateful, then seq-len has to be defined!')
        super(ExtendedLogger, self).__init__()

        # FIX: was a class-level attribute ({} shared by every instance);
        # each logger now owns its metric registry.
        self.val_data_metrics = {}

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies
        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    def set_params(self, params):
        # Fan params out to the wrapped callbacks.
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        # Fan the model out to the wrapped callbacks.
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on validation data, write artefacts, extend logs, then
        forward the enriched logs to TensorBoard and the CSV logger."""
        with timeit('metrics'):
            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)
            batch_size = self.params['batch_size']

            # Keras appends sample weights (and possibly a learning-phase
            # float) to validation_data; strip the trailing entries.
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]
            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            y_pred = self.prediction_model.predict(val_data[:-1],
                                                   batch_size=batch_size,
                                                   verbose=1,
                                                   callback=callback)

            print(y_true.shape, y_pred.shape)  # debug: shape sanity check

            self.write_prediction(epoch, y_true, y_pred)

            # Flatten to (N, 7) pose rows: xyz position + wxyz quaternion.
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            new_logs = {
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            }
            logs.update(new_logs)

            homo_logs = self.try_add_homoscedastic_params()
            logs.update(homo_logs)

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        """Register several named metrics, each callable as f(y_true, y_pred)."""
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        """Register one named metric callable as f(y_true, y_pred)."""
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        """Return the learned homoscedastic log-variances if the model has
        the corresponding loss layers, else an empty dict."""
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')
        if homo_pos_loss_layer:
            homo_pos_log_vars = np.array(homo_pos_loss_layer.get_weights()[0])
            homo_quat_log_vars = np.array(
                homo_quat_loss_layer.get_weights()[0])
            return {
                'pos_log_var': np.array(homo_pos_log_vars),
                'quat_log_var': np.array(homo_quat_log_vars),
            }
        else:
            return {}

    def write_prediction(self, epoch, y_true, y_pred):
        """Dump raw predictions and targets to predictions/NNNN_predictions.npy."""
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self, epoch, y_true, y_pred,
                                  max_segment=1000):
        """Plot each validation trajectory segment from a top-down view."""
        if self.starting_indicies is None:
            # FIX: wrap range() in list() -- `range + list` is Python 2 only
            # and raises TypeError on Python 3.
            self.starting_indicies = {'valid': list(range(0, 4000, 1000)) + [4000]}

        for begin, end in pairwise(self.starting_indicies['valid']):
            diff = end - begin
            if diff > max_segment:
                # Split long trajectories into sub-segments of max_segment.
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)
            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        """Save one ground-truth vs. predicted XY trajectory plot as PDF."""
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        # Reorder quaternion components from xyzw to wxyz for the
        # quaternion package, then take the second Euler angle.
        true_q = quaternion.as_quat_array(y_true[begin:end, [6, 3, 4, 5]])
        true_q = quaternion.as_euler_angles(true_q)[1]
        pred_q = quaternion.as_quat_array(y_pred[begin:end, [6, 3, 4, 5]])
        pred_q = quaternion.as_euler_angles(pred_q)[1]

        plt.clf()
        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        # Thin connector lines between corresponding true/pred points.
        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k', linestyle='-', linewidth=0.3, alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)
        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
            .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        """Save empirical CDFs of absolute position and angle errors as PDF."""
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()
        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
def fit(self):
    """ Train SDNet.

    Alternates generator and discriminator updates per batch, asserting
    that the discriminator is frozen during generator steps and that both
    networks actually change over an epoch.  Losses are logged via
    SaveLoss/CSVLogger; SDNetCallback plots pooled example images.
    """
    log.info('Training SDNet')
    # Load data
    self.init_train()
    # Initialise callbacks
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    si = SDNetCallback(self.conf.folder, self.conf.batch_size, self.sdnet)
    es = EarlyStopping('val_loss', min_delta=0.001, patience=20)
    es.on_train_begin()

    loss_names = [
        'adv_M', 'adv_X', 'rec_X', 'rec_M', 'rec_Z', 'dis_M', 'dis_X',
        'mask', 'image', 'val_loss'
    ]
    total_loss = {n: [] for n in loss_names}  # epoch-mean history

    progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        real_lb_pool, real_ul_pool = [], [
        ]  # these are used only for printing images
        epoch_loss = {n: [] for n in loss_names}

        # Mean-of-means weight summaries, used below to verify training
        # actually changes the networks.
        D_initial_weights = np.mean(
            [np.mean(w) for w in self.sdnet.D_model.get_weights()])
        G_initial_weights = np.mean(
            [np.mean(w) for w in self.sdnet.G_model.get_weights()])

        for self.batch in range(self.conf.batches):
            real_lb = next(self.gen_X_L)  # labelled (image, mask) batch
            real_ul = next(self.gen_X_U)  # unlabelled image batch

            # Add image/mask batch to the data pool
            x, m = real_lb
            real_lb_pool.extend([(x[i:i + 1], m[i:i + 1])
                                 for i in range(x.shape[0])])
            real_ul_pool.extend(real_ul)

            D_weights1 = np.mean(
                [np.mean(w) for w in self.sdnet.D_model.get_weights()])
            self.train_batch_generator(real_lb, real_ul, epoch_loss)
            D_weights2 = np.mean(
                [np.mean(w) for w in self.sdnet.D_model.get_weights()])
            # Discriminator must stay frozen while the generator trains.
            assert D_weights1 == D_weights2

            self.train_batch_discriminator(real_lb, real_ul, epoch_loss)

            progress_bar.update((self.batch + 1) * self.conf.batch_size)

        G_final_weights = np.mean(
            [np.mean(w) for w in self.sdnet.G_model.get_weights()])
        D_final_weights = np.mean(
            [np.mean(w) for w in self.sdnet.D_model.get_weights()])

        # Check training is altering weights
        assert D_initial_weights != D_final_weights
        assert G_initial_weights != G_final_weights

        # Plot some example images
        si.on_epoch_end(self.epoch, np.array(real_lb_pool),
                        np.array(real_ul_pool))

        self.validate(epoch_loss)

        # Calculate epoch losses
        for n in loss_names:
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) % \
                 ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}

        sl.on_epoch_end(self.epoch, logs)

        # log losses to csv
        cl.model = self.sdnet.D_model
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)

        # save models
        self.sdnet.save_models()

        # early stopping
        if self.stop_criterion(es, self.epoch, logs):
            log.info('Finished training from early stopping criterion')
            break
def train(self):
    """WGAN-GP training loop for the brain-aging GAN.

    Age and AD labels are concatenated so a single ImageDataGenerator can
    augment images with their labels; each batch trains the critic
    (extra iterations for the first 25 epochs), then the generator, then a
    zero-age-gap self-regularization pass.
    """
    self.init_train_data()
    # make genetrated data
    gen_dict = self.get_datagen_params()
    # Here we need to concatenate age and AD labels, in order to use Function ImageDataGenerator
    yng_labels = np.concatenate([self.train_age_yng, self.train_AD_yng], axis=1)
    old_labels = np.concatenate([self.train_age_old, self.train_AD_old], axis=1)
    old_gen = ImageDataGenerator(**gen_dict).flow(
        x=self.train_img_old, y=old_labels, batch_size=self.conf.batch_size)
    yng_gen = ImageDataGenerator(**gen_dict).flow(
        x=self.train_img_yng, y=yng_labels, batch_size=self.conf.batch_size)
    # initialize training
    batches = int(np.ceil(self.conf.data_len / self.conf.batch_size))
    progress_bar = Progbar(target=batches * self.conf.batch_size)
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    img_clb = ImageCallback(self.conf, self.model, self.comet_exp)
    # clr = CyclicLR(base_lr=self.conf.lr/5, max_lr=self.conf.lr,
    #                step_size=batches*4, mode='triangular')
    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}
    # start training
    for epoch in range(self.conf.epochs):
        log.info("Train epoch %d/%d" % (epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []
        pool_to_print_old, pool_to_print_yng = [], []  # visualisation only
        for batch in range(batches):
            old_img, old_labels = next(old_gen)
            yng_img, yng_labels = next(yng_gen)
            # Return labels to age and AD vectors (split the concatenation).
            old_age = old_labels[:, :self.conf.age_dim, :]
            old_AD = old_labels[:, self.conf.age_dim:, :]
            yng_age = yng_labels[:, :self.conf.age_dim, :]
            yng_AD = yng_labels[:, self.conf.age_dim:, :]
            if len(pool_to_print_old) < 30:
                pool_to_print_old.append(old_img)
            if len(pool_to_print_yng) < 30:
                pool_to_print_yng.append(yng_img)
            # Adversarial ground truths (WGAN convention: real=-1, fake=+1;
            # dummy targets feed the gradient-penalty output).
            real_pred = -np.ones((old_img.shape[0], 1))
            fake_pred = np.ones((old_img.shape[0], 1))
            dummy = np.zeros((old_img.shape[0], 1))
            dummy_Img = np.ones(old_img.shape)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            age_gap = calculate_age_diff(yng_age, old_age)
            diff_age = get_age_ord_vector(age_gap, expand_dim=1,
                                          con=self.conf.age_con,
                                          ord=self.conf.age_ord,
                                          age_dim=self.conf.age_dim)
            # Get a group of synthetic msks and imgs
            gen_masks = self.model.generator.predict(
                [yng_img, diff_age, old_AD])
            # Generator outputs an additive "aging map"; optionally squashed.
            gen_old_img = np.tanh(
                gen_masks + yng_img) if self.conf.use_tanh else gen_masks + yng_img
            # Need to train discriminators more iterations:
            # ncritic[0] steps before epoch 25, ncritic[1] after (branches
            # are otherwise identical).
            if epoch < 25:
                for _ in range(self.conf.ncritic[0]):
                    # Random interpolate for the gradient penalty term.
                    epsilon = np.random.uniform(0, 1,
                                                size=(old_img.shape[0], 1, 1, 1))
                    interpolation = epsilon * old_img + (
                        1 - epsilon) * gen_old_img
                    h_d = self.model.critic_model.fit([
                        old_img, old_age, old_AD, gen_old_img, old_age,
                        old_AD, interpolation, old_age, old_AD
                    ], [real_pred, fake_pred, dummy],
                        epochs=1, verbose=0)  # , callbacks=[clr])
                    # d_loss_bce = np.mean([h_real.history['binary_crossentropy'], h_fake.history['binary_crossentropy']])
            else:
                for _ in range(self.conf.ncritic[1]):
                    epsilon = np.random.uniform(0, 1,
                                                size=(old_img.shape[0], 1, 1, 1))
                    interpolation = epsilon * old_img + (
                        1 - epsilon) * gen_old_img
                    h_d = self.model.critic_model.fit([
                        old_img, old_age, old_AD, gen_old_img, old_age,
                        old_AD, interpolation, old_age, old_AD
                    ], [real_pred, fake_pred, dummy],
                        epochs=1, verbose=0)  # , callbacks=[clr])
                    # d_loss_bce = np.mean(h_real.history['d_loss'])
            print('d_real_loss', np.mean(h_d.history['d_real_loss']),
                  'd_fake_loss', np.mean(h_d.history['d_fake_loss']))
            # Losses come from the last critic fit of this batch only.
            d_loss_bce = np.mean(
                [h_d.history['d_real_loss'], h_d.history['d_fake_loss']])
            d_loss_real = np.mean(h_d.history['d_real_loss'])
            d_loss_fake = np.mean(h_d.history['d_fake_loss'])
            d_loss_gp = np.mean(h_d.history['gp_loss'])
            epoch_loss['Discriminator_loss'].append(d_loss_bce)
            epoch_loss['Discriminator_real_loss'].append(d_loss_real)
            epoch_loss['Discriminator_fake_loss'].append(d_loss_fake)
            epoch_loss['Discriminator_gp_loss'].append(d_loss_gp)
            # --------------------
            #  Train Generator
            # --------------------
            # Train the generator, want discriminator to mistake images as real
            h = self.model.gan.fit(
                [yng_img, old_age, diff_age, age_gap, old_AD],
                [real_pred, dummy_Img], epochs=1, verbose=0)  # , callbacks=[clr])
            # print(h.history)
            g_loss_bce = h.history['discriminator_loss']
            g_loss_l1 = h.history['map_l1_reg_loss']
            # Deal with epoch loss
            epoch_loss['Generator_fake_loss'].append(g_loss_bce)
            epoch_loss['Generator_l1_reg_loss'].append(g_loss_l1)
            #-----------------------------------------
            # Train Generator by self-regularization
            #-----------------------------------------
            # Zero age gap: the generator should reproduce the input image.
            diff_age_zero = yng_age - yng_age
            h = self.model.GAN_zero_reg([yng_img, diff_age_zero, yng_AD],
                                        yng_img, epochs=1, verbose=0)
            g_zero_reg = np.mean(h.history['self_reg'])
            epoch_loss['Generator_zero_gre_loss'].append(g_zero_reg)
            # Plot the progress
            progress_bar.update((batch + 1) * self.conf.batch_size)
        # Reduce per-batch losses to one mean per loss for this epoch.
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(
            str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
            % ((epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}
        # CSVLogger needs a host model exposing stop_training.
        cl.model = self.model.discriminator
        cl.model.stop_training = False
        cl.on_epoch_end(epoch, logs)
        sl.on_epoch_end(epoch, logs)
        img_clb.on_epoch_end(epoch, yng_img, yng_age, old_img, old_age)
def train(self):
    """Training loop with exponential LR decay, periodic test/validation
    evaluation, best-model tracking and (late-phase) early stopping.

    Deletes stale per-run performance CSVs, then trains epoch by epoch,
    evaluating on the test and validation splits every
    `eval_train_interval` epochs.
    """

    def _learning_rate_schedule(epoch):
        # Exponential decay of the base learning rate per epoch.
        return self.conf.lr * math.exp(self.lr_schedule_coef * (-epoch - 1))

    # Remove performance logs left over from a previous run.
    if os.path.exists(os.path.join(self.conf.folder, 'test-performance.csv')):
        os.remove(os.path.join(self.conf.folder, 'test-performance.csv'))
    if os.path.exists(os.path.join(self.conf.folder, 'validation-performance.csv')):
        os.remove(os.path.join(self.conf.folder, 'validation-performance.csv'))
    log.info('Training Model')
    dice_record = 0
    # Evaluate roughly 50 times over the whole run, at least every epoch.
    self.eval_train_interval = int(max(1, self.conf.epochs / 50))
    self.init_train_data()
    lr_callback = LearningRateScheduler(_learning_rate_schedule)
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    es = EarlyStopping('Validate_Dice', self.conf.min_delta, self.conf.patience)
    es.model = self.model.Segmentor
    es.on_train_begin()
    loss_names = self.get_loss_names()
    loss_names.sort()  # stable column/report order
    total_loss = {n: [] for n in loss_names}
    progress_bar = Progbar(target=self.conf.batches)
    # self.img_clb.on_epoch_end(self.epoch)
    best_performance = 0.
    test_performance = 0.
    total_iters = 0
    for self.epoch in range(self.conf.epochs):
        # NOTE(review): total_iters is bumped once per epoch here AND once
        # per batch below; looks redundant -- confirm intended semantics.
        total_iters += 1
        log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []
        for self.batch in range(self.conf.batches):
            total_iters += 1
            self.train_batch(epoch_loss, lr_callback)
            progress_bar.update(self.batch + 1)
        val_dice = self.validate(epoch_loss)
        if val_dice > dice_record:
            dice_record = val_dice
        # CSVLogger needs a host model exposing stop_training.
        cl.model = self.model.D_Reconstruction
        cl.model.stop_training = False
        self.model.save_models()
        # Plot some example images and run test/validation evaluation on
        # the configured interval (always on the final epoch).
        if self.epoch % self.eval_train_interval == 0 or self.epoch == self.conf.epochs - 1:
            self.img_clb.on_epoch_end(self.epoch)
            folder = os.path.join(os.path.join(self.conf.folder, 'test_during_train'),
                                  'test_results_%s_epoch%d' % (self.conf.test_dataset, self.epoch))
            if not os.path.exists(folder):
                os.makedirs(folder)
            test_performance = self.test_modality(folder, self.conf.modality, 'test', False)
            if test_performance > best_performance:
                best_performance = test_performance
                self.model.save_models('BestModel')
                log.info("BestModel@Epoch%d" % self.epoch)
            folder = os.path.join(os.path.join(self.conf.folder, 'test_during_train'),
                                  'validation_results_%s_epoch%d' % (self.conf.test_dataset, self.epoch))
            if not os.path.exists(folder):
                os.makedirs(folder)
            validation_performance = self.test_modality(folder, self.conf.modality, 'validation', False)
            # NOTE(review): check_batch_iters is not defined in this view;
            # presumably a module-level constant -- verify.
            if self.conf.batches > check_batch_iters:
                self.write_csv(os.path.join(self.conf.folder, 'test-performance.csv'),
                               self.epoch, self.batch, test_performance)
                self.write_csv(os.path.join(self.conf.folder, 'validation-performance.csv'),
                               self.epoch, self.batch, validation_performance)
        # Outside the interval, this re-appends the last measured value.
        epoch_loss['Test_Performance_Dice'].append(test_performance)
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        if self.epoch < 5:
            # Not enough history yet for the 5-step trend report below.
            log.info(str('Epoch %d/%d:\n' + ''.join([l + ' Loss = %.3f\n' for l in loss_names]))
                     % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        else:
            # Report the last five epoch values of every loss as a trend.
            info_str = str('Epoch %d/%d:\n' % (self.epoch, self.conf.epochs))
            loss_info = ''
            for l in loss_names:
                loss_info = loss_info + l + ' Loss = %.3f->%.3f->%.3f->%.3f->%.3f\n' % \
                    (total_loss[l][-5], total_loss[l][-4], total_loss[l][-3],
                     total_loss[l][-2], total_loss[l][-1])
            log.info(info_str + loss_info)
        log.info("BestTest:%f" % best_performance)
        log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))
        logs = {l: total_loss[l][-1] for l in loss_names}
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)
        # Early stopping is only honoured in the second half of training.
        if self.stop_criterion(es, logs) and self.epoch > self.conf.epochs / 2:
            log.info('Finished training from early stopping criterion')
            self.img_clb.on_epoch_end(self.epoch)
            break
def train_model(model, data, config, include_tensorboard):
    """Train *model* over all datasets in *data*, one Keras epoch at a time.

    Callbacks (checkpointing, early stopping, CSV logging and optionally
    TensorBoard) are driven manually around repeated fit_generator calls.
    """
    # Aggregated history across the whole run, driven by hand.
    run_history = History()
    run_history.on_train_begin()

    checkpointer = ModelCheckpoint(full_path(config.model_file()),
                                   verbose=1, save_best_only=True, period=1)
    checkpointer.set_model(model)

    stopper = EarlyStopping(min_delta=config.min_delta,
                            patience=config.patience, verbose=1)
    stopper.set_model(model)
    stopper.on_train_begin()

    csv_logger = CSVLogger(full_path(config.csv_log_file()))
    csv_logger.on_train_begin()

    # Fall back to a no-op Callback when TensorBoard logging is disabled.
    if include_tensorboard:
        tb_logger = TensorBoard(histogram_freq=10, write_images=True)
        tb_logger.set_model(model)
    else:
        tb_logger = Callback()

    epoch, stop = 0, False
    while epoch <= config.max_epochs and not stop:
        # Per-epoch history collects logs from every fit_generator call.
        epoch_history = History()
        epoch_history.on_train_begin()
        valid_sizes, train_sizes = [], []

        print("Epoch:", epoch)
        for dataset in data.datasets:
            print("dataset:", dataset.name)
            model.reset_states()
            dataset.reset_generators()

            # First generator: train with validation.
            valid_sizes.append(dataset.valid_generators[0].size())
            train_sizes.append(dataset.train_generators[0].size())
            fit_history = model.fit_generator(
                dataset.train_generators[0],
                dataset.train_generators[0].size(),
                nb_epoch=1,
                verbose=0,
                validation_data=dataset.valid_generators[0],
                nb_val_samples=dataset.valid_generators[0].size())
            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

            # Second generator: train only, no validation.
            train_sizes.append(dataset.train_generators[1].size())
            fit_history = model.fit_generator(
                dataset.train_generators[1],
                dataset.train_generators[1].size(),
                nb_epoch=1,
                verbose=0)
            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

        # Size-weighted average of this epoch's logs, fanned out to callbacks.
        epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
        run_history.on_epoch_end(epoch, logs=epoch_logs)
        checkpointer.on_epoch_end(epoch, logs=epoch_logs)
        stopper.on_epoch_end(epoch, epoch_logs)
        csv_logger.on_epoch_end(epoch, epoch_logs)
        tb_logger.on_epoch_end(epoch, epoch_logs)

        epoch += 1
        if stopper.stopped_epoch > 0:
            stop = True

    stopper.on_train_end()
    csv_logger.on_train_end()
    tb_logger.on_train_end({})
def train(self):
    """Adversarial training loop with sanity checks that the generator (and,
    when applicable, the discriminator) weights actually change each epoch.

    Per epoch: batches, weight-change assertions, validation, CSV/loss
    logging, image plotting, model saving and early stopping.
    """
    log.info('Training Model')
    self.init_train_data()
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    es = EarlyStopping('val_loss', min_delta=0.01, patience=100)
    es.model = self.model.Segmentor
    es.on_train_begin()
    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}  # epoch-mean history per loss
    progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []
        # Mean-of-means weight summaries to verify training has an effect.
        D_initial_weights = np.mean(
            [np.mean(w) for w in self.model.D_trainer.get_weights()])
        G_initial_weights = np.mean(
            [np.mean(w) for w in self.model.G_trainer.get_weights()])
        for self.batch in range(self.conf.batches):
            # real_pools = self.add_to_pool(data, real_pools)
            self.train_batch(epoch_loss)
            progress_bar.update((self.batch + 1) * self.conf.batch_size)
        G_final_weights = np.mean(
            [np.mean(w) for w in self.model.G_trainer.get_weights()])
        D_final_weights = np.mean(
            [np.mean(w) for w in self.model.D_trainer.get_weights()])
        # Discriminator must change only when unlabelled data is used and it
        # is trainable; the generator must always change.
        assert self.gen_unlabelled is None or not self.model.D_trainer.trainable \
            or D_initial_weights != D_final_weights
        assert G_initial_weights != G_final_weights
        self.validate(epoch_loss)
        # Reduce per-batch losses to one mean per loss name.
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(
            str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
            % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}
        # CSVLogger needs a host model exposing stop_training.
        cl.model = self.model.D_Mask
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)
        # Plot some example images
        self.img_clb.on_epoch_end(self.epoch)
        self.model.save_models()
        if self.stop_criterion(es, logs):
            log.info('Finished training from early stopping criterion')
            break