예제 #1
0
파일: train.py 프로젝트: Russ76/OpenBot
def do_training(tr: Training,
                callback: tf.keras.callbacks.Callback,
                verbose=0):
    """Build (or resume) the driving model and fit it on ``tr.train_ds``.

    Side effects on ``tr``: sets ``model_name``, ``checkpoint_path``,
    ``custom_objects``, ``loss_fn``, ``metric_list``, ``log_path`` and
    ``history`` (the object returned by ``model.fit``).

    Args:
        tr: Training state. Must already provide ``hyperparameters``,
            ``NETWORK_IMG_WIDTH``/``NETWORK_IMG_HEIGHT``, ``train_ds``,
            ``test_ds`` and ``image_count_train``.
        callback: Keras callback appended to the ``fit`` callbacks. It must
            also expose ``broadcast(kind, payload)``, which is used here for
            progress messages.
        verbose: Forwarded to ``model.fit``. When truthy, the model summary
            and model name are also printed.
    """
    # dataset_name and models_dir are module-level globals defined elsewhere.
    tr.model_name = dataset_name + "_" + str(tr.hyperparameters)
    tr.checkpoint_path = os.path.join(models_dir, tr.model_name, "checkpoints")
    tr.custom_objects = {
        "direction_metric": metrics.direction_metric,
        "angle_metric": metrics.angle_metric,
    }
    append_logs = False
    model: tf.keras.Model
    if tr.hyperparameters.USE_LAST:
        # Resuming: keep appending to the existing training log.
        append_logs = True
        dirs = utils.list_dirs(tr.checkpoint_path)
        # Lexicographically last checkpoint directory; presumably the newest
        # epoch if checkpoint names are zero-padded (verify checkpoint_cb).
        last_checkpoint = sorted(dirs)[-1]
        model = tf.keras.models.load_model(
            os.path.join(tr.checkpoint_path, last_checkpoint),
            custom_objects=tr.custom_objects,
            compile=False,  # recompiled below with the current optimizer
        )
    else:
        # Instantiate the architecture named by hyperparameters.MODEL.
        model = getattr(models, tr.hyperparameters.MODEL)(
            tr.NETWORK_IMG_WIDTH,
            tr.NETWORK_IMG_HEIGHT,
            tr.hyperparameters.BATCH_NORM,
        )

    tr.loss_fn = losses.sq_weighted_mse_angle
    tr.metric_list = [
        "mean_absolute_error",
        tr.custom_objects["direction_metric"],
        tr.custom_objects["angle_metric"],
    ]
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=tr.hyperparameters.LEARNING_RATE)

    model.compile(optimizer=optimizer, loss=tr.loss_fn, metrics=tr.metric_list)
    if verbose:
        # summary() prints on its own and returns None; wrapping it in
        # print() used to emit a stray "None" line.
        model.summary()

    tr.log_path = os.path.join(models_dir, tr.model_name, "logs")
    if verbose:
        print(tr.model_name)

    # Keras expects an integer step count; np.ceil returns a float. Rounding
    # up makes the final partial batch count as a step.
    steps_per_epoch = int(
        np.ceil(tr.image_count_train / tr.hyperparameters.TRAIN_BATCH_SIZE))
    callback.broadcast("message", "Fit model...")
    tr.history = model.fit(
        tr.train_ds,
        epochs=tr.hyperparameters.NUM_EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=tr.test_ds,
        verbose=verbose,
        callbacks=[
            callbacks.checkpoint_cb(tr.checkpoint_path),
            callbacks.tensorboard_cb(tr.log_path),
            callbacks.logger_cb(tr.log_path, append_logs),
            callback,
        ],
    )
예제 #2
0
파일: train.py 프로젝트: Russ76/OpenBot
def _save_history_plot(history, metric, ylabel, path):
    """Plot the training and validation curves of one metric, then save them.

    Args:
        history: Object returned by ``model.fit``; ``history.history`` maps
            metric names to per-epoch value lists.
        metric: Training metric key. Its validation twin is ``"val_" + metric``.
        ylabel: Y-axis label.
        path: Output image path passed to ``savefig``.
    """
    val_metric = "val_" + metric
    plt.plot(history.history[metric], label=metric)
    plt.plot(history.history[val_metric], label=val_metric)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend(loc="lower right")
    savefig(path)


def do_evaluation(tr: Training,
                  callback: tf.keras.callbacks.Callback,
                  verbose=0):
    """Plot training curves, export TFLite models and evaluate the best one.

    Reads ``tr.history``, ``tr.log_path``, ``tr.checkpoint_path``,
    ``tr.loss_fn``, ``tr.metric_list``, ``tr.custom_objects``, ``tr.test_ds``,
    ``tr.image_count_test`` and ``tr.hyperparameters.TEST_BATCH_SIZE``. These
    are presumably populated by a prior ``do_training`` call.

    Args:
        tr: Training state after fitting.
        callback: Object exposing ``broadcast(kind, payload)`` for progress
            messages.
        verbose: Unused; kept for signature compatibility.
    """
    callback.broadcast("message", "Generate plots...")
    history = tr.history
    log_path = tr.log_path
    # (history key, y-axis label, output file) for each curve plot.
    for metric, ylabel, filename in (
        ("mean_absolute_error", "Mean Absolute Error", "error.png"),
        ("direction_metric", "Direction Metric", "direction.png"),
        ("angle_metric", "Angle Metric", "angle.png"),
        ("loss", "Loss", "loss.png"),
    ):
        _save_history_plot(history, metric, ylabel,
                           os.path.join(log_path, filename))

    callback.broadcast("message", "Generate tflite models...")
    checkpoint_path = tr.checkpoint_path
    print("checkpoint_path", checkpoint_path)
    val_angle = history.history["val_angle_metric"]
    val_direction = history.history["val_direction_metric"]
    # Best epoch = highest sum of the two validation metrics. This assumes
    # one "cp-%04d.ckpt" checkpoint per epoch, numbered from 1 (hence +1).
    best_index = np.argmax(np.array(val_angle) + np.array(val_direction))
    best_checkpoint = "cp-%04d.ckpt" % (best_index + 1)
    best_tflite = utils.generate_tflite(checkpoint_path, best_checkpoint)
    utils.save_tflite(best_tflite, checkpoint_path, "best")
    print("Best Checkpoint (val_angle: %s, val_direction: %s): %s" % (
        val_angle[best_index],
        val_direction[best_index],
        best_checkpoint,
    ))

    # NOTE(review): the newest checkpoint on disk may come from an earlier,
    # longer run if this run trained fewer epochs into the same directory.
    # In that case it would not match the last history entry printed below.
    # Confirm.
    last_checkpoint = sorted(utils.list_dirs(checkpoint_path))[-1]
    last_tflite = utils.generate_tflite(checkpoint_path, last_checkpoint)
    utils.save_tflite(last_tflite, checkpoint_path, "last")
    print("Last Checkpoint (val_angle: %s, val_direction: %s): %s" % (
        val_angle[-1],
        val_direction[-1],
        last_checkpoint,
    ))

    callback.broadcast("message", "Evaluate model...")
    best_model = utils.load_model(
        os.path.join(checkpoint_path, best_checkpoint),
        tr.loss_fn,
        tr.metric_list,
        tr.custom_objects,
    )
    # Integer step count, rounded up so a final partial batch is included.
    # This matches how steps_per_epoch is computed for training.
    test_steps = int(
        np.ceil(tr.image_count_test / tr.hyperparameters.TEST_BATCH_SIZE))
    # Whatever evaluate returns (presumably loss followed by the compiled
    # metrics) is printed as-is.
    res = best_model.evaluate(
        tr.test_ds,
        steps=test_steps,
        verbose=2,
    )
    print(res)

    num_samples = 15
    (image_batch, cmd_batch), label_batch = next(iter(tr.test_ds))
    # Slice indexing clamps to the batch length. tf.slice with a fixed size
    # fails when the batch holds fewer than num_samples items.
    pred_batch = best_model.predict((
        image_batch[:num_samples],
        cmd_batch[:num_samples],
    ))
    utils.show_test_batch(image_batch.numpy(), cmd_batch.numpy(),
                          label_batch.numpy(), pred_batch)
    savefig(os.path.join(log_path, "test_preview.png"))
    utils.compare_tf_tflite(best_model, best_tflite)
예제 #3
0
def do_training(tr: Training,
                callback: tf.keras.callbacks.Callback,
                verbose=0):
    """Build (or reload) the driving model, fit it and save it to disk.

    Side effects on ``tr``: sets ``model_name``, ``checkpoint_path``,
    ``custom_objects``, ``loss_fn``, ``metric_list``, ``log_path`` and
    ``history`` (the object returned by ``model.fit``). The trained model is
    saved under ``models_dir/<model_name>/model``. When
    ``hyperparameters.WANDB`` is set, the run is also logged to Weights &
    Biases.

    Args:
        tr: Training state. Must already provide ``hyperparameters``,
            ``NETWORK_IMG_WIDTH``/``NETWORK_IMG_HEIGHT``, ``train_ds``,
            ``test_ds`` and ``image_count_train``.
        callback: Keras callback appended to the ``fit`` callbacks. It must
            also expose ``broadcast(kind, payload)``, which is used here for
            model/progress messages.
        verbose: Forwarded to ``model.fit``. When truthy, the model summary
            and model name are also printed.
    """
    # dataset_name and models_dir are module-level globals defined elsewhere.
    tr.model_name = dataset_name + "_" + str(tr.hyperparameters)
    run_dir = os.path.join(models_dir, tr.model_name)
    tr.checkpoint_path = os.path.join(run_dir, "checkpoints")
    tr.custom_objects = {
        "direction_metric": metrics.direction_metric,
        "angle_metric": metrics.angle_metric,
    }
    model_path = os.path.join(run_dir, "model")
    # Read once so the conditional import below and every later use of
    # `wandb` agree.
    use_wandb = tr.hyperparameters.WANDB

    if use_wandb:
        # Optional dependency: only imported when W&B logging is requested.
        import wandb
        from wandb.keras import WandbCallback

        wandb.init(project="openbot")

        config = wandb.config
        config.epochs = tr.hyperparameters.NUM_EPOCHS
        config.learning_rate = tr.hyperparameters.LEARNING_RATE
        config.batch_size = tr.hyperparameters.TRAIN_BATCH_SIZE
        config["model_name"] = tr.model_name

    append_logs = False
    model: tf.keras.Model
    if tr.hyperparameters.USE_LAST:
        # Resuming: keep appending to the existing training log.
        append_logs = True
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=tr.custom_objects,
            compile=False,  # recompiled below with the current optimizer
        )
    else:
        # Instantiate the architecture named by hyperparameters.MODEL.
        model = getattr(models, tr.hyperparameters.MODEL)(
            tr.NETWORK_IMG_WIDTH,
            tr.NETWORK_IMG_HEIGHT,
            tr.hyperparameters.BATCH_NORM,
        )
        dot_img_file = os.path.join(run_dir, "model.png")
        # On a fresh run nothing may have created the run directory yet.
        # Checkpoints and logs are only written during fit, so create it
        # before writing the diagram.
        os.makedirs(run_dir, exist_ok=True)
        tf.keras.utils.plot_model(model,
                                  to_file=dot_img_file,
                                  show_shapes=True)

    callback.broadcast("model", tr.model_name)

    tr.loss_fn = losses.sq_weighted_mse_angle
    tr.metric_list = [
        "mean_absolute_error",
        tr.custom_objects["direction_metric"],
        tr.custom_objects["angle_metric"],
    ]
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=tr.hyperparameters.LEARNING_RATE)

    model.compile(optimizer=optimizer, loss=tr.loss_fn, metrics=tr.metric_list)
    if verbose:
        # summary() prints on its own and returns None; wrapping it in
        # print() used to emit a stray "None" line.
        model.summary()

    tr.log_path = os.path.join(run_dir, "logs")
    if verbose:
        print(tr.model_name)

    # Keras expects an integer step count; np.ceil returns a float. Rounding
    # up makes the final partial batch count as a step.
    steps_per_epoch = int(
        np.ceil(tr.image_count_train / tr.hyperparameters.TRAIN_BATCH_SIZE))
    callback.broadcast("message", "Fit model...")
    callback_list = [
        callbacks.checkpoint_cb(tr.checkpoint_path),
        callbacks.tensorboard_cb(tr.log_path),
        callbacks.logger_cb(tr.log_path, append_logs),
        callback,
    ]

    if use_wandb:
        callback_list.append(WandbCallback())

    try:
        tr.history = model.fit(
            tr.train_ds,
            epochs=tr.hyperparameters.NUM_EPOCHS,
            steps_per_epoch=steps_per_epoch,
            validation_data=tr.test_ds,
            verbose=verbose,
            callbacks=callback_list,
        )
        model.save(model_path)
    except BaseException:
        # Close the W&B run as failed rather than leaving it open when
        # fitting or saving raises (including KeyboardInterrupt), then
        # propagate the original error.
        if use_wandb:
            wandb.finish(exit_code=1)
        raise

    if use_wandb:
        wandb.save(model_path)
        wandb.finish()