def train_gan(dataf):
    """Build the GAN and train it on the faces stored in the HDF5 file *dataf*."""
    # Create the models: generator, discriminator, and the combined GAN.
    gen, disc, gan = build_networks()
    # CSVLogger is driven manually (outside model.fit), so its lifecycle
    # hooks must be called explicitly to open and close the csv file.
    logger = CSVLogger('loss.csv')
    logger.on_train_begin()
    # Run training for 5000 batch iterations.
    # NOTE(review): the original (Russian) comment said "500 epochs", but the
    # code clearly runs range(5000) batches — the comment was stale.
    with h5py.File(dataf, 'r') as f:
        faces = f.get('faces')  # h5py dataset; only valid while the file is open
        run_batches(gen, disc, gan, faces, logger, range(5000))
    logger.on_train_end()
def train_gan(dataf):
    """Build the GAN components and train them on the faces dataset in *dataf*."""
    gen, disc, gan = build_networks()

    # To resume from a snapshot (or pretrained generator weights), re-enable:
    # load_weights(gen, Args.genw)
    # load_weights(disc, Args.discw)

    # CSVLogger works standalone; drive its lifecycle hooks by hand.
    loss_log = CSVLogger('loss.csv')
    loss_log.on_train_begin()  # initialize the csv file

    with h5py.File(dataf, 'r') as hdf:
        dataset = hdf.get('faces')
        run_batches(gen, disc, gan, dataset, loss_log, range(1000000))
    loss_log.on_train_end()
def train_gan(dataf, iters=1000000, disc_start=20, cont=False):
    """Train the GAN for *iters* batch iterations on faces from *dataf*.

    Args:
        dataf: path to an HDF5 file containing a 'faces' dataset.
        iters: number of batch iterations to run.
        disc_start: forwarded to run_batches (discriminator start parameter).
        cont: when True, resume training from the latest snapshot weights.
    """
    gen, disc, gan = build_networks()
    # Idiomatic truthiness test instead of the original `cont == True`;
    # stale commented-out load_weights(Args.genw/discw) lines removed.
    if cont:
        # Resume generator/discriminator from the last saved snapshot.
        load_weights(gen, "snapshots/{}.gen.hdf5".format(Args.batch_len - 1))
        load_weights(disc, "snapshots/{}.disc.hdf5".format(Args.batch_len - 1))
    # CSVLogger is used standalone, so call its lifecycle hooks manually.
    logger = CSVLogger('loss.csv')  # yeah, you can use callbacks independently
    logger.on_train_begin()  # initialize csv file
    with h5py.File(dataf, 'r') as f:
        faces = f.get('faces')  # dataset handle; valid only while file is open
        run_batches(gen, disc, gan, faces, logger, range(iters), disc_start)
    logger.on_train_end()
def train(self):
    """Run the full training loop: per-epoch batch training, SWA weight
    updates, validation, manual callback bookkeeping, snapshots, and an
    early-stopping exit that installs the SWA-averaged models as final."""
    log.info('Training Model')
    self.init_train_data()
    self.init_image_callback()
    # Callbacks are driven manually (no model.fit): SaveLoss plots losses,
    # CSVLogger writes training.csv, EarlyStopping monitors the fused
    # modality-2 validation loss via the Segmentor model.
    sl = SaveLoss(self.conf.folder)
    cl = CSVLogger(self.conf.folder + '/training.csv')
    cl.on_train_begin()
    es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
    es.model = self.model.Segmentor
    es.on_train_begin()

    loss_names = self.get_loss_names()
    total_loss = {n: [] for n in loss_names}  # per-epoch mean loss history

    progress_bar = Progbar(target=self.batches * self.conf.batch_size)

    for self.epoch in range(self.conf.epochs):
        log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))
        epoch_loss = {n: [] for n in loss_names}  # per-batch losses this epoch
        epoch_loss_list = []

        for self.batch in range(self.batches):
            self.train_batch(epoch_loss)
            progress_bar.update((self.batch + 1) * self.conf.batch_size)

        # Refresh the stochastic-weight-averaging copies after each epoch.
        self.set_swa_model_weights()
        for swa_m in self.get_swa_models():
            swa_m.on_epoch_end(self.epoch)

        self.validate(epoch_loss)

        # Reduce batch losses to epoch means and append to the history.
        for n in loss_names:
            epoch_loss_list.append((n, np.mean(epoch_loss[n])))
            total_loss[n].append(np.mean(epoch_loss[n]))
        log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.5f' for l in loss_names])) % ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
        logs = {l: total_loss[l][-1] for l in loss_names}

        # CSVLogger needs a model with a stop_training attribute to run.
        # NOTE(review): D_Mask looks like an arbitrary choice here — confirm.
        cl.model = self.model.D_Mask
        cl.model.stop_training = False
        cl.on_epoch_end(self.epoch, logs)
        sl.on_epoch_end(self.epoch, logs)

        # print images
        self.img_callback.on_epoch_end(self.epoch)

        self.save_models()

        if self.stop_criterion(es, logs):
            log.info('Finished training from early stopping criterion')
            es.on_train_end(logs)
            cl.on_train_end(logs)
            for swa_m in self.get_swa_models():
                swa_m.on_train_end()

            # Set final model parameters based on SWA
            self.model.D_Mask = self.swa_D_Mask.model
            self.model.D_Image1 = self.swa_D_Image1.model
            self.model.D_Image2 = self.swa_D_Image2.model
            self.model.Encoders_Anatomy[0] = self.swa_Enc_Anatomy1.model
            self.model.Encoders_Anatomy[1] = self.swa_Enc_Anatomy2.model
            self.model.Enc_Modality = self.swa_Enc_Modality.model
            self.model.Anatomy_Fuser = self.swa_Anatomy_Fuser.model
            self.model.Segmentor = self.swa_Segmentor.model
            self.model.Decoder = self.swa_Decoder.model
            self.model.Balancer = self.swa_Balancer.model
            self.save_models()
            break
class ExtendedLogger(Callback):
    """Composite Keras callback that forwards events to a CSVLogger and a
    TensorBoard callback, and on each epoch end additionally predicts on the
    validation data, computes extra metrics, and saves predictions/plots."""

    # NOTE(review): class-level mutable dict — metrics registered through
    # add_validation_metric(s) are shared across ALL instances. Confirm this
    # sharing is intentional; otherwise it belongs in __init__.
    val_data_metrics = {}

    def __init__(self, prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):
        # A stateful model must know how often to reset its states.
        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If model is stateful, then seq-len has to be defined!')
        super(ExtendedLogger, self).__init__()

        # Output layout under output_dir: csv/, tensorboard/, predictions/, plots/
        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')
        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies
        # The two wrapped callbacks every event is forwarded to.
        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        # Name of the layer whose output is used for validation predictions.
        self.prediction_layer = prediction_layer

    def set_params(self, params):
        """Propagate training params to the wrapped callbacks."""
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        """Propagate the model to the wrapped callbacks."""
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Predict on validation data, compute extra metrics, dump
        predictions/plots, then forward the enriched logs downstream."""
        with timeit('metrics'):
            # Build a sub-model that outputs the configured prediction layer.
            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)

            batch_size = self.params['batch_size']
            # validation_data may carry a trailing sample-weight/learning-phase
            # float; strip it so val_data holds only inputs + targets.
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]
            y_true = val_data[1]

            callback = None
            if self.stateful:
                # Reset RNN states every stateful_reset_interval samples.
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            # NOTE(review): `callback=` is not a standard Keras predict kwarg —
            # presumably a patched/custom predict; confirm.
            y_pred = self.prediction_model.predict(val_data[:-1],
                                                   batch_size=batch_size,
                                                   verbose=1,
                                                   callback=callback)

            print(y_true.shape, y_pred.shape)

            self.write_prediction(epoch, y_true, y_pred)

            # Flatten to (N, 7) pose vectors — presumably xyz + quaternion,
            # given the [6, 3, 4, 5] reordering used in save_trajectory.
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            # Evaluate all registered validation metrics and merge into logs.
            new_logs = {
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            }
            logs.update(new_logs)

            homo_logs = self.try_add_homoscedastic_params()
            logs.update(homo_logs)

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        """Register several named validation metrics at once."""
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        """Register a single validation metric callable under *name*."""
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        """Return the learned homoscedastic log-variance weights as logs,
        or an empty dict when the model has no such layers."""
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')
        # NOTE(review): only the pos layer is checked; if it exists without
        # the quat layer this raises on get_weights — confirm both co-exist.
        if homo_pos_loss_layer:
            homo_pos_log_vars = np.array(homo_pos_loss_layer.get_weights()[0])
            homo_quat_log_vars = np.array(
                homo_quat_loss_layer.get_weights()[0])
            return {
                'pos_log_var': np.array(homo_pos_log_vars),
                'quat_log_var': np.array(homo_quat_log_vars),
            }
        else:
            return {}

    def write_prediction(self, epoch, y_true, y_pred):
        """Save the raw epoch predictions and targets as an .npy file."""
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self, epoch, y_true, y_pred,
                                  max_segment=1000):
        """Plot top-down trajectory segments between consecutive starting
        indicies, splitting any segment longer than *max_segment*."""
        if self.starting_indicies is None:
            # NOTE(review): range + list concatenation is Python 2 only;
            # under Python 3 this needs list(range(...)) — confirm runtime.
            self.starting_indicies = {'valid': range(0, 4000, 1000) + [4000]}

        for begin, end in pairwise(self.starting_indicies['valid']):
            diff = end - begin
            if diff > max_segment:
                subindicies = range(begin, end, max_segment) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)
            # The full segment is plotted as well, even when it was split.
            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        """Plot true vs. predicted xy trajectory for samples [begin, end)."""
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        # Quaternions are stored as [x, y, z, w] in cols 3..6; reorder to
        # [w, x, y, z] for the quaternion package, then take a Euler angle.
        true_q = quaternion.as_quat_array(y_true[begin:end, [6, 3, 4, 5]])
        true_q = quaternion.as_euler_angles(true_q)[1]

        pred_q = quaternion.as_quat_array(y_pred[begin:end, [6, 3, 4, 5]])
        pred_q = quaternion.as_euler_angles(pred_q)[1]

        plt.clf()
        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        # Connect each true point to its prediction with a faint line.
        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k', linestyle='-', linewidth=0.3, alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')
        # Frame the plot around the true trajectory with a small margin.
        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)
        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
            .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        """Save empirical CDF plots of absolute position and angle errors."""
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        size = len(y_true)
        ys = np.arange(size) / float(size)  # CDF y-axis: fraction of samples

        plt.clf()
        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
def train_model(model, data, config, include_tensorboard):
    """Train *model* over every dataset in *data*, driving the Keras
    callbacks manually so model states can be reset between datasets.

    Args:
        model: compiled (stateful) Keras model.
        data: object exposing .datasets, each with train/valid generators.
        config: run configuration (paths, min_delta, patience, max_epochs).
        include_tensorboard: when True, also log to TensorBoard.
    """
    model_history = History()
    model_history.on_train_begin()
    saver = ModelCheckpoint(full_path(config.model_file()), verbose=1,
                            save_best_only=True, period=1)
    saver.set_model(model)
    early_stopping = EarlyStopping(min_delta=config.min_delta,
                                   patience=config.patience, verbose=1)
    early_stopping.set_model(model)
    early_stopping.on_train_begin()
    csv_logger = CSVLogger(full_path(config.csv_log_file()))
    csv_logger.on_train_begin()
    if include_tensorboard:
        # Fixed typo: local was misspelled "tensorborad".
        tensorboard = TensorBoard(histogram_freq=10, write_images=True)
        tensorboard.set_model(model)
    else:
        tensorboard = Callback()  # no-op stand-in so calls below are unconditional

    epoch = 0
    stop = False
    # Idiomatic loop condition (was: while(... and stop == False)).
    while epoch <= config.max_epochs and not stop:
        epoch_history = History()
        epoch_history.on_train_begin()
        valid_sizes = []
        train_sizes = []
        print("Epoch:", epoch)
        for dataset in data.datasets:
            print("dataset:", dataset.name)
            # Stateful model: clear states between datasets.
            model.reset_states()
            dataset.reset_generators()

            valid_sizes.append(dataset.valid_generators[0].size())
            train_sizes.append(dataset.train_generators[0].size())
            fit_history = model.fit_generator(
                dataset.train_generators[0],
                dataset.train_generators[0].size(),
                nb_epoch=1,
                verbose=0,
                validation_data=dataset.valid_generators[0],
                nb_val_samples=dataset.valid_generators[0].size())
            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

            # Second training pass (no validation) on the other generator.
            # NOTE(review): train_sizes gets two entries per dataset while
            # valid_sizes gets one — confirm average_logs expects that.
            train_sizes.append(dataset.train_generators[1].size())
            fit_history = model.fit_generator(
                dataset.train_generators[1],
                dataset.train_generators[1].size(),
                nb_epoch=1,
                verbose=0)
            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

        # Aggregate the per-dataset logs and notify every callback.
        epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
        model_history.on_epoch_end(epoch, logs=epoch_logs)
        saver.on_epoch_end(epoch, logs=epoch_logs)
        early_stopping.on_epoch_end(epoch, epoch_logs)
        csv_logger.on_epoch_end(epoch, epoch_logs)
        tensorboard.on_epoch_end(epoch, epoch_logs)
        epoch += 1

        if early_stopping.stopped_epoch > 0:
            stop = True

    early_stopping.on_train_end()
    csv_logger.on_train_end()
    tensorboard.on_train_end({})