예제 #1
0
class ExtendedLogger(Callback):
    """Keras callback combining CSV/TensorBoard logging with pose diagnostics.

    Every ``Callback`` hook is forwarded to a wrapped ``CSVLogger`` and a
    wrapped ``TensorBoard`` instance. ``on_epoch_end`` additionally runs the
    layer named ``prediction_layer`` over the validation data, saves the raw
    predictions, writes error-CDF and top-down trajectory plots, evaluates
    the registered validation metrics, and adds the results to ``logs``
    before forwarding them.
    """

    # Class-level defaults for the validation metrics
    # (``name -> metric(y_true, y_pred)``). Each instance works on its own
    # copy (made in ``__init__``) so registering a metric on one logger does
    # not silently register it on every other logger.
    val_data_metrics = {}

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):
        """Create the output directories and the wrapped callbacks.

        Args:
            prediction_layer: name of the layer whose output is used as the
                prediction (looked up with ``model.get_layer``).
            output_dir: root directory; ``csv``, ``tensorboard``,
                ``predictions`` and ``plots`` sub-directories are created in it.
            stateful: whether predictions need a ``ResetStatesCallback``.
            stateful_reset_interval: ``interval`` passed to
                ``ResetStatesCallback``; required when ``stateful`` is true.
            starting_indicies: optional dict whose ``'valid'`` entry is a
                sequence of row boundaries used to cut trajectory plots.

        Raises:
            ValueError: if ``stateful`` is set without
                ``stateful_reset_interval``.
        """
        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If model is stateful, then seq-len has to be defined!')

        super(ExtendedLogger, self).__init__()

        # Private per-instance registry seeded from the class-level defaults.
        self.val_data_metrics = dict(type(self).val_data_metrics)

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies
        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    def set_params(self, params):
        """Forward the training params to this and the wrapped callbacks."""
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        """Attach ``model`` to this and the wrapped callbacks."""
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        # Passed positionally: TensorBoard.on_train_end's parameter is not
        # named ``logs`` in every Keras version.
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Compute validation diagnostics, extend ``logs`` and forward them.

        Side effects: rebuilds ``self.prediction_model``, writes a
        predictions file plus CDF and trajectory plots, and updates ``logs``
        in place with the validation metrics and homoscedastic parameters.
        """
        # Hooks may be invoked with the default ``logs=None``; a dict is
        # needed for the updates below.
        if logs is None:
            logs = {}

        with timeit('metrics'):
            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)

            batch_size = self.params['batch_size']

            # Strip the trailing entries after inputs/targets: presumably the
            # sample weights, plus a float learning-phase flag when present.
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]

            # NOTE(review): assumes a single model input, so the target is at
            # index 1 and the inputs are everything before the last entry.
            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            y_pred = self.prediction_model.predict(val_data[:-1],
                                                   batch_size=batch_size,
                                                   verbose=1,
                                                   callback=callback)

            print(y_true.shape, y_pred.shape)

            self.write_prediction(epoch, y_true, y_pred)

            # Flatten to one 7-value pose per row for the metrics and plots.
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            logs.update({
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            })
            logs.update(self.try_add_homoscedastic_params())

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        """Register several ``name -> metric(y_true, y_pred)`` entries."""
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        """Register ``metric(y_true, y_pred)`` to be logged under ``name``."""
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        """Return the learned log-variances of the homoscedastic loss layers.

        Returns:
            ``{'pos_log_var': ..., 'quat_log_var': ...}`` holding the first
            weight of the ``homo_pos_loss`` and ``homo_quat_loss`` layers, or
            an empty dict unless both layers are found in the model.
        """
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')

        # Both layers are needed; a model with only one of them used to crash
        # here with an AttributeError on ``None``.
        if not (homo_pos_loss_layer and homo_quat_loss_layer):
            return {}

        return {
            'pos_log_var': np.array(homo_pos_loss_layer.get_weights()[0]),
            'quat_log_var': np.array(homo_quat_loss_layer.get_weights()[0]),
        }

    def write_prediction(self, epoch, y_true, y_pred):
        """Save ``{'y_pred': ..., 'y_true': ...}`` to ``<epoch>_predictions.npy``.

        ``np.save`` stores the dict as a pickled 0-d object array, so reading
        it back needs ``np.load(path, allow_pickle=True).item()``.
        """
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self,
                                  epoch,
                                  y_true,
                                  y_pred,
                                  max_segment=1000):
        """Plot each segment between consecutive ``starting_indicies['valid']``.

        Segments longer than ``max_segment`` rows are additionally plotted in
        ``max_segment``-sized chunks; the full segment is always plotted. If
        no boundaries were supplied, ``[0, 1000, 2000, 3000, 4000]`` is used.
        """
        if self.starting_indicies is None:
            # ``list(...)`` keeps this working on Python 3, where ``range``
            # cannot be concatenated with a list.
            self.starting_indicies = {
                'valid': list(range(0, 4000, 1000)) + [4000]
            }

        for begin, end in pairwise(self.starting_indicies['valid']):

            if end - begin > max_segment:
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)

            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        """Save a top-down (x, y) plot of rows ``begin:end``.

        Ground truth is drawn in green, predictions in red, and faint black
        lines join each ground-truth point to its prediction. Axis limits
        follow the ground truth with a 0.2 margin.
        """
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        plt.clf()

        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k',
                     linestyle='-',
                     linewidth=0.3,
                     alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)

        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
          .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        """Save empirical CDFs of absolute position and angle errors."""
        pos_errors = np.sort(PoseMetrics.abs_errors_position(y_true, y_pred))
        angle_errors = np.sort(
            PoseMetrics.abs_errors_orienation(y_true, y_pred))

        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()

        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
예제 #2
0
def _evaluate(model, val_set):
    """Return the mean ``(loss, psnr)`` of ``model`` over ``val_set``.

    Draws ``len(val_set)`` single-sample batches via ``val_set.batch(1)``.
    Assumes the model was compiled with one metric, so ``test_on_batch``
    returns ``[loss, metric]``.
    """
    n_samples = len(val_set)
    total_loss = 0.0
    total_psnr = 0.0
    for _ in range(n_samples):
        x_val, y_val = val_set.batch(1)
        score = model.test_on_batch(x_val, y_val)
        total_loss += score[0]
        total_psnr += score[1]
    return total_loss / n_samples, total_psnr / n_samples


def _save_checkpoint(model, model_dir, epoch, step):
    """Save weights to ``model_<epoch>_<step>.h5`` and record it in ``info``.

    The checkpoint file name is written to ``<model_dir>/info`` (presumably
    read back by ``load_weights`` when resuming; verify).
    """
    filename = 'model_%d_%d.h5' % (epoch, step)
    model.save_weights(os.path.join(model_dir, filename))
    with open(os.path.join(model_dir, 'info'), 'w') as info_file:
        info_file.write(filename)


def main(not_parsed_args):
    """Train RBPN with periodic validation and checkpointing.

    Args:
        not_parsed_args: leftover command-line arguments; unused, since all
            configuration is read from ``FLAGS``.
    """
    logging.info('Build dataset')
    train_set = get_training_set(FLAGS.dataset_h, FLAGS.dataset_l,
                                 FLAGS.frames, FLAGS.scale, True,
                                 'filelist.txt', True, FLAGS.patch_size,
                                 FLAGS.future_frame)
    val_set = None
    if FLAGS.dataset_val:
        val_set = get_eval_set(FLAGS.dataset_val_h, FLAGS.dataset_val_l,
                               FLAGS.frames, FLAGS.scale, True, 'filelist.txt',
                               True, FLAGS.patch_size, FLAGS.future_frame)

    logging.info('Build model')
    model = RBPN()
    model.summary()
    # (epoch, step) to resume from.
    last_epoch, last_step = load_weights(model)
    model.compile(optimizer=optimizers.Adam(FLAGS.lr),
                  loss=losses.mae,
                  metrics=[psnr])

    tensorboard = TensorBoard(log_dir='./tf_logs',
                              batch_size=FLAGS.batch_size,
                              write_graph=False,
                              write_grads=True,
                              write_images=True,
                              update_freq='batch')
    tensorboard.set_model(model)

    steps_per_epoch = len(train_set) // FLAGS.batch_size
    final_step = steps_per_epoch - 1

    logging.info('Training start')
    for e in range(last_epoch, FLAGS.epochs):
        tensorboard.on_epoch_begin(e)
        for s in range(last_step + 1, steps_per_epoch):
            tensorboard.on_batch_begin(s)
            x, y = train_set.batch(FLAGS.batch_size)
            loss = model.train_on_batch(x, y)
            print('Epoch %d step %d, loss %f psnr %f' %
                  (e, s, loss[0], loss[1]))
            tensorboard.on_batch_end(s, named_logs(model, loss, s))

            is_final_step = s == final_step

            # Parenthesised so validation only ever runs when a validation
            # set exists, including on the final step of the epoch.
            if val_set is not None and (
                    (s > 0 and s % FLAGS.val_interval == 0) or is_final_step):
                logging.info('Validation start')
                val_loss, val_psnr = _evaluate(model, val_set)
                logging.info('Validation average loss %f psnr %f',
                             val_loss, val_psnr)

            if (s > 0 and s % FLAGS.save_interval == 0) or is_final_step:
                logging.info('Saving model')
                _save_checkpoint(model, FLAGS.model_dir, e, s)
        tensorboard.on_epoch_end(e)
        # Only the resumed epoch starts mid-way; later epochs start at step 0.
        last_step = -1