Пример #1
0
def main(not_parsed_args):
    """Train the RBPN model with a hand-written batch loop.

    Builds the training set (and, when ``FLAGS.dataset_val`` is set, the
    validation set), restores weights through ``load_weights`` and resumes
    from the epoch/step it returns, logs per-batch metrics to TensorBoard,
    runs validation every ``FLAGS.val_interval`` steps and on the last step of
    each epoch, and saves weights every ``FLAGS.save_interval`` steps and on
    the last step of each epoch. The file name of the most recent checkpoint
    is written to ``<FLAGS.model_dir>/info``.

    Args:
        not_parsed_args: leftover command-line arguments; not used here.
    """
    logging.info('Build dataset')
    train_set = get_training_set(FLAGS.dataset_h, FLAGS.dataset_l,
                                 FLAGS.frames, FLAGS.scale, True,
                                 'filelist.txt', True, FLAGS.patch_size,
                                 FLAGS.future_frame)
    # Stays None when no validation dataset is configured; checked before
    # every validation pass so a missing dataset is never referenced.
    val_set = None
    if FLAGS.dataset_val:
        val_set = get_eval_set(FLAGS.dataset_val_h, FLAGS.dataset_val_l,
                               FLAGS.frames, FLAGS.scale, True, 'filelist.txt',
                               True, FLAGS.patch_size, FLAGS.future_frame)

    logging.info('Build model')
    model = RBPN()
    model.summary()
    last_epoch, last_step = load_weights(model)
    model.compile(optimizer=optimizers.Adam(FLAGS.lr),
                  loss=losses.mae,
                  metrics=[psnr])

    tensorboard = TensorBoard(log_dir='./tf_logs',
                              batch_size=FLAGS.batch_size,
                              write_graph=False,
                              write_grads=True,
                              write_images=True,
                              update_freq='batch')
    tensorboard.set_model(model)

    steps_per_epoch = len(train_set) // FLAGS.batch_size
    final_step = steps_per_epoch - 1

    logging.info('Training start')
    for e in range(last_epoch, FLAGS.epochs):
        tensorboard.on_epoch_begin(e)
        for s in range(last_step + 1, steps_per_epoch):
            tensorboard.on_batch_begin(s)
            x, y = train_set.batch(FLAGS.batch_size)
            loss = model.train_on_batch(x, y)
            print('Epoch %d step %d, loss %f psnr %f' %
                  (e, s, loss[0], loss[1]))
            tensorboard.on_batch_end(s, named_logs(model, loss, s))

            is_final_step = s == final_step

            # Parenthesised so the dataset check guards BOTH triggers; without
            # it the final-step trigger ran validation even when no
            # validation set had been built.
            if val_set is not None and (
                    (s > 0 and s % FLAGS.val_interval == 0) or is_final_step):
                logging.info('Validation start')
                n_val = len(val_set)
                val_loss = 0
                val_psnr = 0
                for _ in range(n_val):
                    x_val, y_val = val_set.batch(1)
                    score = model.test_on_batch(x_val, y_val)
                    val_loss += score[0]
                    val_psnr += score[1]
                # Avoid ZeroDivisionError on an empty validation set.
                if n_val:
                    val_loss /= n_val
                    val_psnr /= n_val
                logging.info('Validation average loss %f psnr %f' %
                             (val_loss, val_psnr))

            if (s > 0 and s % FLAGS.save_interval == 0) or is_final_step:
                logging.info('Saving model')
                filename = 'model_%d_%d.h5' % (e, s)
                path = os.path.join(FLAGS.model_dir, filename)
                path_info = os.path.join(FLAGS.model_dir, 'info')
                model.save_weights(path)
                # Record the latest checkpoint name; the context manager
                # guarantees the file is closed even if the write fails.
                with open(path_info, 'w') as f:
                    f.write(filename)
        tensorboard.on_epoch_end(e)
        # The resumed step offset only applies to the first (resumed) epoch.
        last_step = -1
Пример #2
0
class ExtendedLogger(Callback):
    """Keras callback that wraps a CSVLogger and a TensorBoard callback.

    Every callback event is forwarded to both wrapped loggers. At the end of
    each epoch the output of the layer named ``prediction_layer`` is computed
    on ``self.validation_data``. The raw predictions are saved, error-CDF and
    top-view trajectory plots are written, and the registered validation
    metrics are added to the epoch logs. Output goes to the ``csv``,
    ``tensorboard``, ``predictions`` and ``plots`` subdirectories of
    ``output_dir``.
    """

    # Class-level default mapping of metric name -> callable(y_true, y_pred).
    # Each instance copies it in __init__, so registering a metric on one
    # logger no longer leaks into every other ExtendedLogger.
    val_data_metrics = {}

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):
        """Create the output directories and the wrapped loggers.

        Args:
            prediction_layer: name of the model layer whose output is used as
                the prediction during end-of-epoch evaluation.
            output_dir: root directory for csv logs, TensorBoard logs,
                predictions and plots.
            stateful: if True, a ResetStatesCallback is passed to prediction.
            stateful_reset_interval: interval for ResetStatesCallback;
                required when ``stateful`` is True.
            starting_indicies: optional dict whose ``'valid'`` entry lists the
                segment boundaries used for trajectory plots; defaults to
                ``[0, 1000, 2000, 3000, 4000]`` on first use.

        Raises:
            ValueError: if ``stateful`` is True but no reset interval is given.
        """
        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If model is stateful, then seq-len has to be defined!')

        super(ExtendedLogger, self).__init__()

        # Per-instance copy of the class-level defaults (see class comment).
        self.val_data_metrics = dict(self.val_data_metrics)

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies
        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    def set_params(self, params):
        """Propagate training params to this callback and both wrapped loggers."""
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        """Attach ``model`` to this callback and both wrapped loggers."""
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the prediction layer on the validation data and log it.

        Saves the predictions, writes the error-CDF and trajectory plots for
        ``epoch``, and adds every registered validation metric to ``logs``.
        It also adds homoscedastic log-variances when those layers exist.
        ``logs`` is then forwarded to the wrapped TensorBoard and CSV loggers.
        """
        # Keras normally passes a dict; tolerate None rather than failing on
        # logs.update below.
        if logs is None:
            logs = {}

        with timeit('metrics'):

            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)

            batch_size = self.params['batch_size']

            # Strip the trailing non-data entries of validation_data: one
            # entry normally, two when the last item is a float (presumably
            # the learning-phase flag — TODO confirm against Keras version).
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]

            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            # NOTE(review): stock Keras Model.predict takes `callbacks=` (a
            # list), not `callback=`; this presumably relies on a patched
            # predict — confirm before upgrading Keras.
            y_pred = self.prediction_model.predict(val_data[:-1],
                                                   batch_size=batch_size,
                                                   verbose=1,
                                                   callback=callback)

            print(y_true.shape, y_pred.shape)

            self.write_prediction(epoch, y_true, y_pred)

            # Flatten to 7-wide pose rows; save_trajectory reads columns :2
            # as x/y and [6, 3, 4, 5] as a quaternion.
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            logs.update({
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            })
            logs.update(self.try_add_homoscedastic_params())

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        """Register several ``name -> callable(y_true, y_pred)`` metrics."""
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        """Register one ``callable(y_true, y_pred)`` metric under ``name``."""
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        """Return the first weight of the homoscedastic loss layers, if any.

        Looks up layers named ``homo_pos_loss`` and ``homo_quat_loss``.
        Each one that is found contributes its first weight as
        ``pos_log_var`` / ``quat_log_var``. Returns an empty dict when
        neither layer is present.
        """
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')

        # Each layer is checked independently; previously a model with only
        # the position layer crashed on homo_quat_loss_layer.get_weights().
        homo_logs = {}
        if homo_pos_loss_layer:
            homo_logs['pos_log_var'] = np.array(
                homo_pos_loss_layer.get_weights()[0])
        if homo_quat_loss_layer:
            homo_logs['quat_log_var'] = np.array(
                homo_quat_loss_layer.get_weights()[0])
        return homo_logs

    def write_prediction(self, epoch, y_true, y_pred):
        """Save ``{'y_pred': ..., 'y_true': ...}`` to ``<epoch>_predictions.npy``."""
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self,
                                  epoch,
                                  y_true,
                                  y_pred,
                                  max_segment=1000):
        """Plot a top-view trajectory for every validation segment.

        Segments are consecutive pairs from ``starting_indicies['valid']``.
        A segment longer than ``max_segment`` is additionally split into
        chunks of at most ``max_segment`` rows, each plotted separately. The
        full segment is always plotted as well.
        """
        if self.starting_indicies is None:
            # list(...) is required: on Python 3, range + list raises TypeError.
            self.starting_indicies = {
                'valid': list(range(0, 4000, 1000)) + [4000]
            }

        for begin, end in pairwise(self.starting_indicies['valid']):

            if end - begin > max_segment:
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)

            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        """Plot true (green) vs predicted (red) x/y for rows ``begin:end``.

        Matching points are joined by faint black lines. The axis limits
        follow the true trajectory with a 0.2 margin. The plot is saved as a
        PDF in ``plot_dir``.
        """
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        # Segment boundaries may lie beyond the validation set; an empty
        # slice would make np.min/np.max below raise, so skip it.
        if len(true_xy) == 0:
            return

        plt.clf()

        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k',
                     linestyle='-',
                     linewidth=0.3,
                     alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)

        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
          .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        """Plot empirical CDFs of absolute position and angle errors.

        The position-error CDF (x-axis clipped to 0-1.2 m) is drawn on the
        top subplot and the angle-error CDF (0-70 deg) on the bottom one. The
        figure is saved as ``<epoch>_cdf.pdf`` in ``plot_dir``.
        """
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()

        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
Пример #3
0
    def train(self, batch_size=4, epochs=25):
        """Train ``self.keras_model`` with a manual batch loop and EMA validation.

        After each epoch:
        1. The raw weights are backed up to ``cf.TEMP_MODEL_PATH``.
        2. ``QANetCallback.on_epoch_end`` applies the EMA weights, which are
           saved to ``cf.MODEL_PATH % ep``.
        3. That model is validated and the per-epoch results are rewritten
           to ``cf.RESULT_LOG`` as CSV.
        4. The raw weights are reloaded so training continues from them.

        Args:
            batch_size: training batch size.
            epochs: number of epochs to run; epoch numbering starts at 1.
        """
        cf = self.cf
        self.compile()
        model = self.keras_model
        word_vectors, char_vectors, train_ques_ids, X_train, y_train, val_ques_ids, X_valid, y_valid = self.data_train

        qanet_cb = QANetCallback(decay=cf.EMA_DECAY)
        tb = TensorBoard(log_dir=cf.TENSORBOARD_PATH,
                         histogram_freq=0,
                         write_graph=False,
                         write_images=False,
                         update_freq=cf.TENSORBOARD_UPDATE_FREQ)

        # Call set_model for all callbacks
        qanet_cb.set_model(model)
        tb.set_model(model)
        # Pair with tb.on_train_end below, as the Keras callback contract
        # expects; some TensorBoard versions pop state that begin pushed.
        tb.on_train_begin(None)

        ep_list = []
        avg_train_loss_list = []
        em_score_list = []
        f1_score_list = []

        gt_start_list, gt_end_list = y_valid[2:]
        for ep in range(1, epochs + 1):  # Epoch num start from 1
            print('----------- Training for epoch {}...'.format(ep))
            tb.on_epoch_begin(ep)
            # Train
            sum_loss = 0
            # Defined up front so an epoch that yields no batches does not
            # hit a NameError when avg_loss is recorded below.
            avg_loss = 0.0
            num_batches = (len(X_train[0]) - 1) // batch_size + 1
            batches = get_batch(X_train,
                                y_train,
                                batch_size=batch_size,
                                shuffle=True)
            for batch, (X_batch, y_batch) in enumerate(batches):
                batch_logs = {'batch': batch, 'size': len(X_batch[0])}
                tb.on_batch_begin(batch, batch_logs)

                loss, loss_p1, loss_p2, loss_start, loss_end = model.train_on_batch(
                    X_batch, y_batch)
                sum_loss += loss
                avg_loss = sum_loss / (batch + 1)
                print(
                    'Epoch: {}/{}, Batch: {}/{}, Accumulative average loss: {:.4f}, Loss: {:.4f}, Loss_P1: {:.4f}, Loss_P2: {:.4f}, Loss_start: {:.4f}, Loss_end: {:.4f}'
                    .format(ep, epochs, batch, num_batches, avg_loss, loss,
                            loss_p1, loss_p2, loss_start, loss_end))
                batch_logs.update({
                    'loss': loss,
                    'loss_p1': loss_p1,
                    'loss_p2': loss_p2
                })
                qanet_cb.on_batch_end(batch, batch_logs)
                tb.on_batch_end(batch, batch_logs)

            ep_list.append(ep)
            avg_train_loss_list.append(avg_loss)

            print('Backing up temp weights...')
            model.save_weights(cf.TEMP_MODEL_PATH)
            qanet_cb.on_epoch_end(ep)  # Apply EMA weights
            model.save_weights(cf.MODEL_PATH % str(ep))

            print('----------- Validating for epoch {}...'.format(ep))
            valid_scores = self.validate(X_valid,
                                         y_valid,
                                         gt_start_list,
                                         gt_end_list,
                                         batch_size=cf.BATCH_SIZE)
            em_score_list.append(valid_scores['exact_match'])
            f1_score_list.append(valid_scores['f1'])
            print(
                '------- Result of epoch: {}/{}, Average_train_loss: {:.6f}, EM: {:.4f}, F1: {:.4f}\n'
                .format(ep, epochs, avg_loss, valid_scores['exact_match'],
                        valid_scores['f1']))

            tb.on_epoch_end(ep, {
                'f1': valid_scores['f1'],
                'em': valid_scores['exact_match']
            })

            # Rewrite the full per-epoch history to CSV after every epoch.
            result = pd.DataFrame({
                'epoch': ep_list,
                'avg_train_loss': avg_train_loss_list,
                'em': em_score_list,
                'f1': f1_score_list
            })
            result.to_csv(cf.RESULT_LOG, index=False)

            # Restore the original weights to continue training
            print('Restoring temp weights...')
            model.load_weights(cf.TEMP_MODEL_PATH)

        tb.on_train_end(None)