def train(self):
    log.info('Training Model')
    self.init_train_data()
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
    es.model = self.model.Segmentor
    es.on_train_begin()

    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}

    progress_bar = Progbar(target=self.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []

        for self.batch in range(self.batches):
            self.train_batch(epoch_loss)
            progress_bar.update((self.batch + 1) * self.conf.batch_size)

        self.validate(epoch_loss)

        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
                 % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}

        cl.model = self.model.D_Mask
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)

        # Plot some example images
        self.img_callback.on_epoch_end(self.epoch)

        self.model.save_models()

        if self.stop_criterion(es, logs):
            log.info('Finished training from early stopping criterion')
            break
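# The loop above drives Keras callbacks manually rather than through
# model.fit. CSVLogger.on_epoch_end reads `self.model.stop_training`, which
# is why the loop assigns cl.model and resets that flag every epoch. A
# standalone sketch of this manual-callback protocol (the toy model and loss
# values are illustrative only):
from keras.callbacks import CSVLogger
from keras.layers import Dense
from keras.models import Sequential

toy_model = Sequential([Dense(1, input_shape=(1,))])
cl = CSVLogger('training.csv')
cl.on_train_begin()
for epoch in range(3):
    cl.model = toy_model
    cl.model.stop_training = False          # consulted inside on_epoch_end
    cl.on_epoch_end(epoch, logs={'loss': 1.0 / (epoch + 1)})
cl.on_train_end()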
def train(self):
    log.info('Training Model')
    self.init_train_data()
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
    es.model = self.model.Segmentor
    es.on_train_begin()

    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}

    progress_bar = Progbar(target=self.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []

        for self.batch in range(self.batches):
            self.train_batch(epoch_loss)
            progress_bar.update((self.batch + 1) * self.conf.batch_size)

        # Update the stochastic weight averaging (SWA) copies of the models
        self.set_swa_model_weights()
        for swa_m in self.get_swa_models():
            swa_m.on_epoch_end(self.epoch)

        self.validate(epoch_loss)

        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.5f' for l in loss_names]))
                 % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}

        cl.model = self.model.D_Mask
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)

        # Plot some example images
        self.img_callback.on_epoch_end(self.epoch)

        self.save_models()

        if self.stop_criterion(es, logs):
            log.info('Finished training from early stopping criterion')
            es.on_train_end(logs)
            cl.on_train_end(logs)
            for swa_m in self.get_swa_models():
                swa_m.on_train_end()

            # Set final model parameters based on SWA
            self.model.D_Mask = self.swa_D_Mask.model
            self.model.D_Image1 = self.swa_D_Image1.model
            self.model.D_Image2 = self.swa_D_Image2.model
            self.model.Encoders_Anatomy[0] = self.swa_Enc_Anatomy1.model
            self.model.Encoders_Anatomy[1] = self.swa_Enc_Anatomy2.model
            self.model.Enc_Modality = self.swa_Enc_Modality.model
            self.model.Anatomy_Fuser = self.swa_Anatomy_Fuser.model
            self.model.Segmentor = self.swa_Segmentor.model
            self.model.Decoder = self.swa_Decoder.model
            self.model.Balancer = self.swa_Balancer.model
            self.save_models()
            break
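# `set_swa_model_weights` / `get_swa_models` are defined elsewhere. A minimal
# sketch of the stochastic weight averaging (SWA) idea they implement --
# the class and attribute names here are illustrative, not the repo's API:
class SWACopy(object):
    """Maintain a running average of a live Keras model's weights."""

    def __init__(self, live_model, swa_model, start_epoch=0):
        self.live = live_model
        self.model = swa_model          # holds the averaged weights
        self.start_epoch = start_epoch
        self.n_averaged = 0

    def on_epoch_end(self, epoch):
        if epoch < self.start_epoch:
            return
        if self.n_averaged == 0:
            self.model.set_weights(self.live.get_weights())
        else:
            # Incremental mean: avg_new = (avg * n + w) / (n + 1)
            avg = [(a * self.n_averaged + w) / (self.n_averaged + 1)
                   for a, w in zip(self.model.get_weights(),
                                   self.live.get_weights())]
            self.model.set_weights(avg)
        self.n_averaged += 1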
def train(self):
    self.init_train_data()

    # Build the augmented data generators
    gen_dict = self.get_datagen_params()
    p_gen = ImageDataGenerator(**gen_dict).flow(x=self.p_images, y=self.p_masks,
                                                batch_size=self.conf.batch_size)
    h_gen = ImageDataGenerator(**gen_dict).flow(x=self.h_images, y=self.h_masks,
                                                batch_size=self.conf.batch_size)
    random_p_masks = ImageDataGenerator(**gen_dict).flow(x=self.p_masks,
                                                         batch_size=self.conf.batch_size)

    # Initialise training
    batches = int(np.ceil(self.conf.data_len / self.conf.batch_size))
    progress_bar = Progbar(target=batches * self.conf.batch_size)
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    img_clb = ImageCallback(self.conf, self.model, self.comet_exp)
    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}

    # Start training
    for epoch in range(self.conf.epochs):
        log.info('Train epoch %d/%d' % (epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []
        pool_to_print_p_img, pool_to_print_p_msk = [], []
        pool_to_print_h_img, pool_to_print_h_msk = [], []

        for batch in range(batches):
            p_img, p_msk = next(p_gen)
            h_img, h_msk = next(h_gen)
            r_p_msk = next(random_p_masks)

            # Keep a small pool of examples for plotting
            if len(pool_to_print_p_img) < 30:
                pool_to_print_p_img.append(p_img[0])
                pool_to_print_p_msk.append(p_msk[0])
            if len(pool_to_print_h_img) < 30:
                pool_to_print_h_img.append(h_img[0])
                pool_to_print_h_msk.append(h_msk[0])

            # Adversarial ground truths (WGAN convention: -1 real, +1 fake)
            real_pred = -np.ones((h_img.shape[0], 1))
            fake_pred = np.ones((h_img.shape[0], 1))
            dummy = np.zeros((h_img.shape[0], 1))  # gradient-penalty target
            dummy_Img = np.ones(h_img.shape)       # not used in this loop

            if self.conf.self_rec:
                h_test_sr = self.model.train_self_rec.fit([h_img, h_msk], [h_img, h_img],
                                                          epochs=1, verbose=0)
                epoch_loss['test_self_rec_loss'].append(np.mean(h_test_sr.history['loss']))
            else:
                epoch_loss['test_self_rec_loss'].append(0)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Get a group of synthetic masks and images
            cy1_pse_h_img = self.model.G_d_to_h.predict(p_img)
            cy1_seg_d_msk = self.model.S_d_to_msk.predict(p_img)
            cy2_fake_h_img = self.model.G_h_to_d.predict([h_img, h_msk])

            # Train the critic for more iterations during early epochs
            ncritic = self.conf.ncritic[0] if epoch < 25 else self.conf.ncritic[1]
            for _ in range(ncritic):
                cy1_epsilon = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                cy1_average = cy1_epsilon * h_img + (1 - cy1_epsilon) * cy1_pse_h_img
                cy1_epsilon_msk = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                # Use the mask's own epsilon for the mask interpolation
                cy1_average_msk = cy1_epsilon_msk * r_p_msk + (1 - cy1_epsilon_msk) * cy1_seg_d_msk
                cy2_epsilon = np.random.uniform(0, 1, size=(h_img.shape[0], 1, 1, 1))
                cy2_average = cy2_epsilon * h_img + (1 - cy2_epsilon) * cy2_fake_h_img
                h_d = self.model.critic_model.fit(
                    [h_img, cy1_pse_h_img, cy1_average,
                     r_p_msk, cy1_seg_d_msk, cy1_average_msk,
                     h_img, cy2_fake_h_img, cy2_average],
                    [real_pred, fake_pred, dummy,
                     real_pred, fake_pred, dummy,
                     real_pred, fake_pred, dummy],
                    epochs=1, verbose=0)

            d_dis_pse_image_loss = np.mean([h_d.history['dis_cy1_I_pse_h_loss'],
                                            h_d.history['dis_cy2_I_pse_h_loss']])
            d_dis_r_image_loss = np.mean([h_d.history['dis_cy1_I_h_loss'],
                                          h_d.history['dis_cy2_I_h_loss']])
            d_dis_d_mask_loss = np.mean([h_d.history['dis_cy1_M_d_loss'],
                                         h_d.history['dis_cy1_M_seg_d_loss']])
            d_gp_loss = np.mean([h_d.history['gp_cy1_I_h_loss'],
                                 h_d.history['gp_cy2_I_h_loss'],
                                 h_d.history['gp_cy1_M_d_loss']])
            epoch_loss['d_dis_pse_image_loss'].append(d_dis_pse_image_loss)
            epoch_loss['d_dis_r_image_loss'].append(d_dis_r_image_loss)
            epoch_loss['d_dis_d_mask_loss'].append(d_dis_d_mask_loss)
            epoch_loss['d_gp_loss'].append(d_gp_loss)

            # --------------------
            #  Train Generator
            # --------------------
            h_g = self.model.gan.fit([p_img, h_img, h_msk],
                                     [real_pred, real_pred, p_img, real_pred, h_img, h_msk],
                                     epochs=1, verbose=0)
            g_dis_pse_image_loss = np.mean([h_g.history['cy1_dis_I_pse_h_loss'],
                                            h_g.history['cy2_dis_I_pse_d_loss']])
            g_rec_image_loss = np.mean([h_g.history['cy2_I_rec_h_loss'],
                                        h_g.history['cy1_I_rec_d_loss']])
            g_dis_d_mask_loss = np.mean(h_g.history['cy1_dis_M_seg_d_loss'])
            epoch_loss['g_dis_pse_image_loss'].append(g_dis_pse_image_loss)
            epoch_loss['g_rec_image_loss'].append(g_rec_image_loss)
            epoch_loss['g_dis_d_mask_loss'].append(g_dis_d_mask_loss)

            # Plot the progress
            progress_bar.update((batch + 1) * self.conf.batch_size)

        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
                 % ((epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}

        cl.model = self.model.D_pse_h
        cl.model.stop_training = False
        cl.on_epoch_end(epoch, logs)
        sl.on_epoch_end(epoch, logs)

        pool_to_print_p_img = np.asarray(pool_to_print_p_img)
        pool_to_print_p_msk = np.asarray(pool_to_print_p_msk)
        pool_to_print_h_img = np.asarray(pool_to_print_h_img)
        pool_to_print_h_msk = np.asarray(pool_to_print_h_msk)
        log.debug('pool_to_print_p_img: %s' % str(np.shape(pool_to_print_p_img)))
        img_clb.on_epoch_end(epoch, pool_to_print_p_img, pool_to_print_p_msk,
                             pool_to_print_h_img, pool_to_print_h_msk)
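# Why the critic takes interpolated "average" inputs with zero-valued dummy
# targets: this is the WGAN-GP gradient penalty (Gulrajani et al., 2017),
# which pushes the critic's gradient norm towards 1 at points sampled between
# real and fake examples. A minimal sketch of the penalty term, assuming the
# old Keras backend API (the repo's actual loss wiring may differ):
import numpy as np
from keras import backend as K

def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    # y_true is the zero-valued dummy target; only the gradient norm matters.
    gradients = K.gradients(y_pred, averaged_samples)[0]
    gradients_sqr_sum = K.sum(K.square(gradients),
                              axis=np.arange(1, len(gradients.shape)))
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    return K.mean(K.square(gradient_l2_norm - 1))

# In practice `averaged_samples` is bound with functools.partial so the
# function matches the (y_true, y_pred) loss signature Keras expects.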
def fit(self):
    """
    Train SDNet
    """
    log.info('Training SDNet')

    # Load data
    self.init_train()

    # Initialise callbacks
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    si = SDNetCallback(self.conf.folder, self.conf.batch_size, self.sdnet)
    es = EarlyStopping('val_loss', min_delta=0.001, patience=20)
    es.on_train_begin()

    loss_names = ['adv_M', 'adv_X', 'rec_X', 'rec_M', 'rec_Z',
                  'dis_M', 'dis_X', 'mask', 'image', 'val_loss']
    total_loss = {n: [] for n in loss_names}

    progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

        real_lb_pool, real_ul_pool = [], []  # these are used only for printing images
        epoch_loss = {n: [] for n in loss_names}

        D_initial_weights = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()])
        G_initial_weights = np.mean([np.mean(w) for w in self.sdnet.G_model.get_weights()])

        for self.batch in range(self.conf.batches):
            real_lb = next(self.gen_X_L)
            real_ul = next(self.gen_X_U)

            # Add image/mask batch to the data pool
            x, m = real_lb
            real_lb_pool.extend([(x[i:i + 1], m[i:i + 1]) for i in range(x.shape[0])])
            real_ul_pool.extend(real_ul)

            # The generator step must not touch the discriminator weights
            D_weights1 = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()])
            self.train_batch_generator(real_lb, real_ul, epoch_loss)
            D_weights2 = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()])
            assert D_weights1 == D_weights2

            self.train_batch_discriminator(real_lb, real_ul, epoch_loss)

            progress_bar.update((self.batch + 1) * self.conf.batch_size)

        G_final_weights = np.mean([np.mean(w) for w in self.sdnet.G_model.get_weights()])
        D_final_weights = np.mean([np.mean(w) for w in self.sdnet.D_model.get_weights()])

        # Check training is altering weights
        assert D_initial_weights != D_final_weights
        assert G_initial_weights != G_final_weights

        # Plot some example images
        si.on_epoch_end(self.epoch, np.array(real_lb_pool), np.array(real_ul_pool))

        self.validate(epoch_loss)

        # Calculate epoch losses
        for n in loss_names:
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
                 % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}
        sl.on_epoch_end(self.epoch, logs)

        # log losses to csv
        cl.model = self.sdnet.D_model
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)

        # save models
        self.sdnet.save_models()

        # early stopping
        if self.stop_criterion(es, self.epoch, logs):
            log.info('Finished training from early stopping criterion')
            break
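# `stop_criterion` is defined elsewhere in the trainer. A minimal sketch of
# how it could wrap the Keras EarlyStopping callback created above (assumed
# behaviour, not the repo's actual code; note that EarlyStopping needs
# `es.model` assigned before on_epoch_end can raise the stop flag):
def stop_criterion(self, es, epoch, logs):
    es.on_epoch_end(epoch, logs)   # updates the wait/best counters
    return es.stopped_epoch > 0    # non-zero once patience is exhausted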
def train(self):
    self.init_train_data()

    # Build the augmented data generators
    gen_dict = self.get_datagen_params()
    # Concatenate age and AD labels so a single ImageDataGenerator can carry both
    yng_labels = np.concatenate([self.train_age_yng, self.train_AD_yng], axis=1)
    old_labels = np.concatenate([self.train_age_old, self.train_AD_old], axis=1)
    old_gen = ImageDataGenerator(**gen_dict).flow(x=self.train_img_old, y=old_labels,
                                                  batch_size=self.conf.batch_size)
    yng_gen = ImageDataGenerator(**gen_dict).flow(x=self.train_img_yng, y=yng_labels,
                                                  batch_size=self.conf.batch_size)

    # Initialise training
    batches = int(np.ceil(self.conf.data_len / self.conf.batch_size))
    progress_bar = Progbar(target=batches * self.conf.batch_size)
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    img_clb = ImageCallback(self.conf, self.model, self.comet_exp)
    # clr = CyclicLR(base_lr=self.conf.lr / 5, max_lr=self.conf.lr,
    #                step_size=batches * 4, mode='triangular')
    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}

    # Start training
    for epoch in range(self.conf.epochs):
        log.info('Train epoch %d/%d' % (epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []
        pool_to_print_old, pool_to_print_yng = [], []

        for batch in range(batches):
            old_img, old_labels = next(old_gen)
            yng_img, yng_labels = next(yng_gen)

            # Split the concatenated labels back into age and AD vectors
            old_age = old_labels[:, :self.conf.age_dim, :]
            old_AD = old_labels[:, self.conf.age_dim:, :]
            yng_age = yng_labels[:, :self.conf.age_dim, :]
            yng_AD = yng_labels[:, self.conf.age_dim:, :]

            if len(pool_to_print_old) < 30:
                pool_to_print_old.append(old_img)
            if len(pool_to_print_yng) < 30:
                pool_to_print_yng.append(yng_img)

            # Adversarial ground truths (WGAN convention: -1 real, +1 fake)
            real_pred = -np.ones((old_img.shape[0], 1))
            fake_pred = np.ones((old_img.shape[0], 1))
            dummy = np.zeros((old_img.shape[0], 1))   # gradient-penalty target
            dummy_Img = np.ones(old_img.shape)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            age_gap = calculate_age_diff(yng_age, old_age)
            diff_age = get_age_ord_vector(age_gap, expand_dim=1, con=self.conf.age_con,
                                          ord=self.conf.age_ord, age_dim=self.conf.age_dim)

            # Get a group of synthetic masks and images
            gen_masks = self.model.generator.predict([yng_img, diff_age, old_AD])
            gen_old_img = np.tanh(gen_masks + yng_img) if self.conf.use_tanh \
                else gen_masks + yng_img

            # Train the critic for more iterations during early epochs
            ncritic = self.conf.ncritic[0] if epoch < 25 else self.conf.ncritic[1]
            for _ in range(ncritic):
                epsilon = np.random.uniform(0, 1, size=(old_img.shape[0], 1, 1, 1))
                interpolation = epsilon * old_img + (1 - epsilon) * gen_old_img
                h_d = self.model.critic_model.fit(
                    [old_img, old_age, old_AD,
                     gen_old_img, old_age, old_AD,
                     interpolation, old_age, old_AD],
                    [real_pred, fake_pred, dummy],
                    epochs=1, verbose=0)  # , callbacks=[clr])

            print('d_real_loss', np.mean(h_d.history['d_real_loss']),
                  'd_fake_loss', np.mean(h_d.history['d_fake_loss']))
            d_loss_bce = np.mean([h_d.history['d_real_loss'], h_d.history['d_fake_loss']])
            d_loss_real = np.mean(h_d.history['d_real_loss'])
            d_loss_fake = np.mean(h_d.history['d_fake_loss'])
            d_loss_gp = np.mean(h_d.history['gp_loss'])
            epoch_loss['Discriminator_loss'].append(d_loss_bce)
            epoch_loss['Discriminator_real_loss'].append(d_loss_real)
            epoch_loss['Discriminator_fake_loss'].append(d_loss_fake)
            epoch_loss['Discriminator_gp_loss'].append(d_loss_gp)

            # --------------------
            #  Train Generator
            # --------------------
            # Train the generator: we want the discriminator to mistake
            # generated images for real ones
            h = self.model.gan.fit([yng_img, old_age, diff_age, age_gap, old_AD],
                                   [real_pred, dummy_Img],
                                   epochs=1, verbose=0)  # , callbacks=[clr])
            g_loss_bce = h.history['discriminator_loss']
            g_loss_l1 = h.history['map_l1_reg_loss']
            epoch_loss['Generator_fake_loss'].append(g_loss_bce)
            epoch_loss['Generator_l1_reg_loss'].append(g_loss_l1)

            # -----------------------------------------
            #  Train Generator by self-regularization
            # -----------------------------------------
            # With a zero age difference the generator should reproduce the input
            diff_age_zero = yng_age - yng_age
            h = self.model.GAN_zero_reg.fit([yng_img, diff_age_zero, yng_AD], yng_img,
                                            epochs=1, verbose=0)
            g_zero_reg = np.mean(h.history['self_reg'])
            epoch_loss['Generator_zero_gre_loss'].append(g_zero_reg)

            # Plot the progress
            progress_bar.update((batch + 1) * self.conf.batch_size)

        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
                 % ((epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}

        cl.model = self.model.discriminator
        cl.model.stop_training = False
        cl.on_epoch_end(epoch, logs)
        sl.on_epoch_end(epoch, logs)

        img_clb.on_epoch_end(epoch, yng_img, yng_age, old_img, old_age)
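# `get_age_ord_vector` is defined elsewhere. A minimal sketch of ordinal age
# encoding, assuming the common ordinal-regression scheme where a gap of k
# sets the first k entries to 1 (illustrative only; the repo's `con`/`ord`
# options are not modelled here):
import numpy as np

def ordinal_encode(age_gaps, age_dim):
    """e.g. age_dim=6, gap=3 -> [1, 1, 1, 0, 0, 0]"""
    vec = np.zeros((len(age_gaps), age_dim))
    for i, gap in enumerate(age_gaps):
        vec[i, :int(gap)] = 1.0
    return vec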
def train(self):

    def _learning_rate_schedule(epoch):
        return self.conf.lr * math.exp(self.lr_schedule_coef * (-epoch - 1))

    # Remove stale performance logs from previous runs
    if os.path.exists(os.path.join(self.conf.folder, 'test-performance.csv')):
        os.remove(os.path.join(self.conf.folder, 'test-performance.csv'))
    if os.path.exists(os.path.join(self.conf.folder, 'validation-performance.csv')):
        os.remove(os.path.join(self.conf.folder, 'validation-performance.csv'))

    log.info('Training Model')
    dice_record = 0
    self.eval_train_interval = int(max(1, self.conf.epochs / 50))
    self.init_train_data()
    lr_callback = LearningRateScheduler(_learning_rate_schedule)
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    es = EarlyStopping('Validate_Dice', self.conf.min_delta, self.conf.patience)
    es.model = self.model.Segmentor
    es.on_train_begin()

    loss_names = self.get_loss_names()
    loss_names.sort()
    total_loss = {n: [] for n in loss_names}

    progress_bar = Progbar(target=self.conf.batches)
    # self.img_clb.on_epoch_end(self.epoch)

    best_performance = 0.
    test_performance = 0.
    total_iters = 0
    for self.epoch in range(self.conf.epochs):
        total_iters += 1
        log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []

        for self.batch in range(self.conf.batches):
            total_iters += 1
            self.train_batch(epoch_loss, lr_callback)
            progress_bar.update(self.batch + 1)

        val_dice = self.validate(epoch_loss)
        if val_dice > dice_record:
            dice_record = val_dice

        cl.model = self.model.D_Reconstruction
        cl.model.stop_training = False

        self.model.save_models()

        # Plot some example images and evaluate periodically
        if self.epoch % self.eval_train_interval == 0 or self.epoch == self.conf.epochs - 1:
            self.img_clb.on_epoch_end(self.epoch)

            folder = os.path.join(os.path.join(self.conf.folder, 'test_during_train'),
                                  'test_results_%s_epoch%d' % (self.conf.test_dataset, self.epoch))
            if not os.path.exists(folder):
                os.makedirs(folder)
            test_performance = self.test_modality(folder, self.conf.modality, 'test', False)
            if test_performance > best_performance:
                best_performance = test_performance
                self.model.save_models('BestModel')
                log.info('BestModel@Epoch%d' % self.epoch)

            folder = os.path.join(os.path.join(self.conf.folder, 'test_during_train'),
                                  'validation_results_%s_epoch%d' % (self.conf.test_dataset, self.epoch))
            if not os.path.exists(folder):
                os.makedirs(folder)
            validation_performance = self.test_modality(folder, self.conf.modality, 'validation', False)

            if self.conf.batches > check_batch_iters:
                self.write_csv(os.path.join(self.conf.folder, 'test-performance.csv'),
                               self.epoch, self.batch, test_performance)
                self.write_csv(os.path.join(self.conf.folder, 'validation-performance.csv'),
                               self.epoch, self.batch, validation_performance)

        epoch_loss['Test_Performance_Dice'].append(test_performance)

        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))

        if self.epoch < 5:
            log.info(str('Epoch %d/%d:\n' + ''.join([l + ' Loss = %.3f\n' for l in loss_names]))
                     % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        else:
            # Show the last five values of each loss so trends are visible
            info_str = str('Epoch %d/%d:\n' % (self.epoch, self.conf.epochs))
            loss_info = ''
            for l in loss_names:
                loss_info = loss_info + l + ' Loss = %.3f->%.3f->%.3f->%.3f->%.3f\n' % \
                    (total_loss[l][-5], total_loss[l][-4], total_loss[l][-3],
                     total_loss[l][-2], total_loss[l][-1])
            log.info(info_str + loss_info)
        log.info('BestTest:%f' % best_performance)
        log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))

        logs = {l: total_loss[l][-1] for l in loss_names}
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)

        if self.stop_criterion(es, logs) and self.epoch > self.conf.epochs / 2:
            log.info('Finished training from early stopping criterion')
            self.img_clb.on_epoch_end(self.epoch)
            break
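# `_learning_rate_schedule` above decays the learning rate exponentially:
# lr(epoch) = lr * exp(coef * (-epoch - 1)). A quick standalone check of its
# shape (the lr and coef values are illustrative, not the repo's defaults):
import math

lr, coef = 1e-4, 0.05
for epoch in (0, 10, 50):
    print(epoch, lr * math.exp(coef * (-epoch - 1)))
# 0  -> 1e-4 * e^-0.05 ~ 9.51e-05
# 10 -> 1e-4 * e^-0.55 ~ 5.77e-05
# 50 -> 1e-4 * e^-2.55 ~ 7.81e-06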
def train(self):
    log.info('Training Model')
    self.init_train_data()
    self.init_image_callback()
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    es = EarlyStopping('val_loss', min_delta=0.01, patience=100)
    es.model = self.model.Segmentor
    es.on_train_begin()

    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}

    progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)
    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}
        epoch_loss_list = []

        D_initial_weights = np.mean([np.mean(w) for w in self.model.D_trainer.get_weights()])
        G_initial_weights = np.mean([np.mean(w) for w in self.model.G_trainer.get_weights()])

        for self.batch in range(self.conf.batches):
            # real_pools = self.add_to_pool(data, real_pools)
            self.train_batch(epoch_loss)
            progress_bar.update((self.batch + 1) * self.conf.batch_size)

        G_final_weights = np.mean([np.mean(w) for w in self.model.G_trainer.get_weights()])
        D_final_weights = np.mean([np.mean(w) for w in self.model.D_trainer.get_weights()])

        assert self.gen_unlabelled is None or not self.model.D_trainer.trainable \
            or D_initial_weights != D_final_weights
        assert G_initial_weights != G_final_weights

        self.validate(epoch_loss)

        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names]))
                 % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}

        cl.model = self.model.D_Mask
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)

        # Plot some example images
        self.img_clb.on_epoch_end(self.epoch)

        self.model.save_models()

        if self.stop_criterion(es, logs):
            log.info('Finished training from early stopping criterion')
            break
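# The before/after weight-mean comparison above is a cheap sanity check that
# a training step actually updated (or deliberately froze) a network. A small
# helper capturing the idiom (hypothetical, not part of the repo):
import numpy as np

def mean_weight(model):
    """Scalar fingerprint of a Keras model's weights."""
    return np.mean([np.mean(w) for w in model.get_weights()])

# Usage: before = mean_weight(D); train_step(); assert before != mean_weight(D)
# Equal means do not strictly prove the weights are unchanged, but unequal
# means do prove they moved, which is enough to catch frozen-weight bugs.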