def _run_validation(model, val_set):
    """Evaluate ``model`` over ``val_set`` and log the average loss and psnr.

    Draws ``len(val_set)`` batches of size 1 through ``val_set.batch(1)`` and
    averages the first two values returned by ``model.test_on_batch``.
    """
    num_samples = len(val_set)
    if num_samples == 0:
        # Nothing to average over; dividing by zero below would abort training.
        logging.info('Validation skipped: validation set is empty')
        return

    logging.info('Validation start')
    val_loss = 0
    val_psnr = 0
    for _ in range(num_samples):
        x_val, y_val = val_set.batch(1)
        score = model.test_on_batch(x_val, y_val)
        val_loss += score[0]
        val_psnr += score[1]
    val_loss /= num_samples
    val_psnr /= num_samples
    logging.info('Validation average loss %f psnr %f', val_loss, val_psnr)


def _save_checkpoint(model, epoch, step):
    """Save the model weights and record the checkpoint file name.

    Weights go to ``FLAGS.model_dir/model_<epoch>_<step>.h5``; the bare file
    name is then written to ``FLAGS.model_dir/info``.
    """
    logging.info('Saving model')
    filename = 'model_%d_%d.h5' % (epoch, step)
    path = os.path.join(FLAGS.model_dir, filename)
    path_info = os.path.join(FLAGS.model_dir, 'info')
    model.save_weights(path)
    # Context manager guarantees the handle is closed even if the write fails.
    with open(path_info, 'w') as f:
        f.write(filename)


def main(not_parsed_args):
    """Build the datasets and the RBPN model, then run the training loop.

    Training resumes from the ``(last_epoch, last_step)`` pair returned by
    ``load_weights``: the first epoch starts at ``last_step + 1`` and every
    following epoch starts at step 0. Validation (only when
    ``FLAGS.dataset_val`` is set) and checkpointing run on their respective
    intervals and additionally on the final step of each epoch.

    Args:
        not_parsed_args: Unused; kept for the entry-point signature.
    """
    logging.info('Build dataset')
    train_set = get_training_set(FLAGS.dataset_h, FLAGS.dataset_l,
                                 FLAGS.frames, FLAGS.scale, True,
                                 'filelist.txt', True, FLAGS.patch_size,
                                 FLAGS.future_frame)
    if FLAGS.dataset_val:
        val_set = get_eval_set(FLAGS.dataset_val_h, FLAGS.dataset_val_l,
                               FLAGS.frames, FLAGS.scale, True,
                               'filelist.txt', True, FLAGS.patch_size,
                               FLAGS.future_frame)

    logging.info('Build model')
    model = RBPN()
    model.summary()
    last_epoch, last_step = load_weights(model)
    model.compile(optimizer=optimizers.Adam(FLAGS.lr), loss=losses.mae,
                  metrics=[psnr])

    tensorboard = TensorBoard(log_dir='./tf_logs',
                              batch_size=FLAGS.batch_size,
                              write_graph=False, write_grads=True,
                              write_images=True, update_freq='batch')
    tensorboard.set_model(model)

    logging.info('Training start')
    # Loop-invariant: number of full batches per epoch and its last index.
    steps_per_epoch = len(train_set) // FLAGS.batch_size
    final_step = steps_per_epoch - 1
    for e in range(last_epoch, FLAGS.epochs):
        tensorboard.on_epoch_begin(e)
        for s in range(last_step + 1, steps_per_epoch):
            tensorboard.on_batch_begin(s)
            x, y = train_set.batch(FLAGS.batch_size)
            loss = model.train_on_batch(x, y)
            print('Epoch %d step %d, loss %f psnr %f' %
                  (e, s, loss[0], loss[1]))
            tensorboard.on_batch_end(s, named_logs(model, loss, s))

            is_final_step = s == final_step
            # The dataset_val guard must cover the final-step case as well:
            # without the parentheses ``and`` binds tighter than ``or`` and
            # validation would run with ``val_set`` undefined.
            if FLAGS.dataset_val and (
                    (s > 0 and s % FLAGS.val_interval == 0) or is_final_step):
                _run_validation(model, val_set)

            if (s > 0 and s % FLAGS.save_interval == 0) or is_final_step:
                _save_checkpoint(model, e, s)
        tensorboard.on_epoch_end(e)
        # Only the resumed epoch starts mid-way; later epochs start at step 0.
        last_step = -1
    # Give the TensorBoard callback its end-of-training hook so it can
    # finalize its event writer.
    tensorboard.on_train_end(None)
class ExtendedLogger(Callback):
    """Callback that forwards events to CSV/TensorBoard loggers and, at the
    end of every epoch, dumps validation predictions, plots and extra metrics.

    Output is organised under ``output_dir`` in four sub-directories:
    ``csv``, ``tensorboard``, ``predictions`` and ``plots``.
    """

    # Class-level defaults ({name: callable(y_true, y_pred)}). Every instance
    # starts from its own copy (see __init__), so metrics registered on one
    # logger do not leak into other loggers.
    val_data_metrics = {}

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):
        """Create the output directories and the wrapped loggers.

        Args:
            prediction_layer: Name of the model layer whose output is used as
                the prediction in ``on_epoch_end``.
            output_dir: Root directory for all artefacts.
            stateful: Whether states are reset periodically while predicting.
            stateful_reset_interval: Interval passed to
                ``ResetStatesCallback``; required when ``stateful`` is true.
            starting_indicies: Optional dict with a ``'valid'`` sequence of
                segment boundaries used for the trajectory plots.

        Raises:
            ValueError: If ``stateful`` is set without a reset interval.
        """
        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If model is stateful, then seq-len has to be defined!')

        super(ExtendedLogger, self).__init__()

        # Shadow the shared class dict with a per-instance copy.
        self.val_data_metrics = dict(self.val_data_metrics)

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies

        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    def set_params(self, params):
        """Propagate the training params to the wrapped loggers."""
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        """Propagate the model to the wrapped loggers."""
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the validation data, save artefacts and extend ``logs``.

        The registered validation metrics and, when available, the
        homoscedastic log-variances are merged into ``logs`` before the
        wrapped loggers receive the epoch-end event.
        """
        # The signature allows ``logs=None``; updating None would crash.
        # A dict passed by the caller is still mutated in place.
        logs = {} if logs is None else logs

        # NOTE(review): the extent of the timed section is assumed to cover
        # the metric computation up to the log updates -- confirm.
        with timeit('metrics'):
            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)
            batch_size = self.params['batch_size']

            # Drop the trailing entries of validation_data; one more is
            # dropped when the last element is a float.
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]
            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            # NOTE(review): ``callback=`` (singular) is passed as in the
            # original; verify that the ``predict`` in use accepts it.
            y_pred = self.prediction_model.predict(val_data[:-1],
                                                   batch_size=batch_size,
                                                   verbose=1,
                                                   callback=callback)
            print(y_true.shape, y_pred.shape)

            # Predictions are stored in their original shape, then flattened
            # to rows of 7 values for the plots and metrics.
            self.write_prediction(epoch, y_true, y_pred)
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            new_logs = {
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            }
            logs.update(new_logs)

            homo_logs = self.try_add_homoscedastic_params()
            logs.update(homo_logs)

        self.tensorboard.validation_data = self.validation_data
        self.csv_logger.validation_data = self.validation_data

        self.tensorboard.on_epoch_end(epoch, logs=logs)
        self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        """Register several ``{name: metric(y_true, y_pred)}`` entries."""
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        """Register a single ``metric(y_true, y_pred)`` under ``name``."""
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        """Return the first weights of the homoscedastic loss layers, if any.

        Returns:
            ``{}`` when no ``'homo_pos_loss'`` layer is found; otherwise a
            dict with ``'pos_log_var'`` and, when a ``'homo_quat_loss'`` layer
            is found as well, ``'quat_log_var'``.
        """
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')

        if not homo_pos_loss_layer:
            return {}

        params = {
            'pos_log_var': np.array(homo_pos_loss_layer.get_weights()[0]),
        }
        # The quat layer is looked up separately, so check it separately
        # rather than dereferencing a missing layer.
        if homo_quat_loss_layer:
            params['quat_log_var'] = np.array(
                homo_quat_loss_layer.get_weights()[0])
        return params

    def write_prediction(self, epoch, y_true, y_pred):
        """Save ``{'y_pred', 'y_true'}`` to ``<pred_dir>/<epoch>_predictions.npy``."""
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self,
                                  epoch,
                                  y_true,
                                  y_pred,
                                  max_segment=1000):
        """Plot one trajectory per segment listed in ``starting_indicies``.

        Segments longer than ``max_segment`` are additionally plotted in
        chunks of at most ``max_segment`` rows; the full segment is always
        plotted as well.
        """
        if self.starting_indicies is None:
            # list(...) is required: a Python 3 range cannot be concatenated
            # with a list.
            self.starting_indicies = {
                'valid': list(range(0, 4000, 1000)) + [4000]
            }

        for begin, end in pairwise(self.starting_indicies['valid']):
            if end - begin > max_segment:
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)

            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        """Save a top-down plot of rows ``begin:end`` to the plot directory.

        The first two columns are plotted as x/y: ground truth in green,
        prediction in red, with a faint line joining corresponding points.
        Axis limits are taken from the ground truth with a 0.2 margin.
        """
        # The orientation angles formerly computed here were never used and
        # have been dropped.
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        plt.clf()
        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k', linestyle='-', linewidth=0.3, alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)

        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
            .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        """Save the empirical CDFs of the absolute position/angle errors."""
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        # Cumulative fraction for each sorted error value.
        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()

        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
def train(self, batch_size=4, epochs=25):
    """Train ``self.keras_model`` with a manual batch loop.

    For every epoch: train on all batches from ``get_batch``, back up the
    current weights, let ``QANetCallback.on_epoch_end`` run (the original
    comment notes this applies the EMA weights), save and validate the model,
    append the scores to the CSV at ``cf.RESULT_LOG`` and finally restore the
    backed-up weights so training continues from them.

    Args:
        batch_size: Batch size used for training batches.
        epochs: Number of epochs; epoch numbers start at 1.
    """
    cf = self.cf
    self.compile()
    model = self.keras_model

    # Only the train/validation arrays are used here; the remaining entries
    # of the 8-tuple are unpacked positionally and ignored.
    (_word_vectors, _char_vectors, _train_ques_ids, X_train, y_train,
     _val_ques_ids, X_valid, y_valid) = self.data_train

    qanet_cb = QANetCallback(decay=cf.EMA_DECAY)
    tb = TensorBoard(log_dir=cf.TENSORBOARD_PATH,
                     histogram_freq=0,
                     write_graph=False,
                     write_images=False,
                     update_freq=cf.TENSORBOARD_UPDATE_FREQ)

    # Call set_model for all callbacks
    qanet_cb.set_model(model)
    tb.set_model(model)

    ep_list = []
    avg_train_loss_list = []
    em_score_list = []
    f1_score_list = []

    gt_start_list, gt_end_list = y_valid[2:]
    # Loop-invariant: ceil(len(X_train[0]) / batch_size).
    num_batches = (len(X_train[0]) - 1) // batch_size + 1

    try:
        for ep in range(1, epochs + 1):  # Epoch num start from 1
            print('----------- Training for epoch {}...'.format(ep))

            # Train
            sum_loss = 0
            # Defined up front so an epoch that yields no batches reports NaN
            # instead of raising or reusing the previous epoch's average.
            avg_loss = float('nan')
            batches = get_batch(X_train, y_train,
                                batch_size=batch_size, shuffle=True)
            for batch, (X_batch, y_batch) in enumerate(batches):
                batch_logs = {'batch': batch, 'size': len(X_batch[0])}
                tb.on_batch_begin(batch, batch_logs)

                loss, loss_p1, loss_p2, loss_start, loss_end = \
                    model.train_on_batch(X_batch, y_batch)
                sum_loss += loss
                avg_loss = sum_loss / (batch + 1)
                print(
                    'Epoch: {}/{}, Batch: {}/{}, Accumulative average loss: {:.4f}, Loss: {:.4f}, Loss_P1: {:.4f}, Loss_P2: {:.4f}, Loss_start: {:.4f}, Loss_end: {:.4f}'
                    .format(ep, epochs, batch, num_batches, avg_loss, loss,
                            loss_p1, loss_p2, loss_start, loss_end))

                batch_logs.update({
                    'loss': loss,
                    'loss_p1': loss_p1,
                    'loss_p2': loss_p2
                })
                qanet_cb.on_batch_end(batch, batch_logs)
                tb.on_batch_end(batch, batch_logs)

            ep_list.append(ep)
            avg_train_loss_list.append(avg_loss)

            # Back up the current weights before the epoch-end callback
            # changes them; they are restored at the end of the epoch.
            print('Backing up temp weights...')
            model.save_weights(cf.TEMP_MODEL_PATH)
            qanet_cb.on_epoch_end(ep)  # Apply EMA weights
            model.save_weights(cf.MODEL_PATH % str(ep))

            print('----------- Validating for epoch {}...'.format(ep))
            valid_scores = self.validate(X_valid, y_valid,
                                         gt_start_list, gt_end_list,
                                         batch_size=cf.BATCH_SIZE)
            em_score_list.append(valid_scores['exact_match'])
            f1_score_list.append(valid_scores['f1'])
            print(
                '------- Result of epoch: {}/{}, Average_train_loss: {:.6f}, EM: {:.4f}, F1: {:.4f}\n'
                .format(ep, epochs, avg_loss, valid_scores['exact_match'],
                        valid_scores['f1']))
            tb.on_epoch_end(ep, {
                'f1': valid_scores['f1'],
                'em': valid_scores['exact_match']
            })

            # Write result to CSV file (rewritten in full after every epoch)
            result = pd.DataFrame({
                'epoch': ep_list,
                'avg_train_loss': avg_train_loss_list,
                'em': em_score_list,
                'f1': f1_score_list
            })
            result.to_csv(cf.RESULT_LOG, index=None)

            # Restore the original weights to continue training
            print('Restoring temp weights...')
            model.load_weights(cf.TEMP_MODEL_PATH)
    finally:
        # Always give TensorBoard its end-of-training hook, even when
        # training is interrupted by an exception.
        tb.on_train_end(None)