def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0):
    """Build (or resume) a Keras model from ``tr``'s hyperparameters and fit it.

    Writes model name, checkpoint path, log path, loss, metric list and the
    resulting fit history back onto ``tr``.

    Args:
        tr: Training state object; paths, loss/metrics and ``tr.history``
            are populated as side effects.
        callback: extra Keras callback appended to the training callbacks;
            also used to ``broadcast`` progress messages to the caller.
        verbose: verbosity forwarded to ``model.fit`` (also gates the local
            summary/name prints).
    """
    tr.model_name = dataset_name + "_" + str(tr.hyperparameters)
    tr.checkpoint_path = os.path.join(models_dir, tr.model_name, "checkpoints")
    tr.custom_objects = {
        "direction_metric": metrics.direction_metric,
        "angle_metric": metrics.angle_metric,
    }

    append_logs = False
    model: tf.keras.Model
    if tr.hyperparameters.USE_LAST:
        # Resume from the newest checkpoint; checkpoint directory names
        # (e.g. cp-0001.ckpt) sort lexicographically, so sorted()[-1] is
        # the most recent one.
        append_logs = True
        dirs = utils.list_dirs(tr.checkpoint_path)
        last_checkpoint = sorted(dirs)[-1]
        model = tf.keras.models.load_model(
            os.path.join(tr.checkpoint_path, last_checkpoint),
            custom_objects=tr.custom_objects,
            compile=False,
        )
    else:
        # Instantiate a fresh architecture by name from the models module.
        model = getattr(models, tr.hyperparameters.MODEL)(
            tr.NETWORK_IMG_WIDTH,
            tr.NETWORK_IMG_HEIGHT,
            tr.hyperparameters.BATCH_NORM,
        )

    tr.loss_fn = losses.sq_weighted_mse_angle
    tr.metric_list = [
        "mean_absolute_error",
        tr.custom_objects["direction_metric"],
        tr.custom_objects["angle_metric"],
    ]
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=tr.hyperparameters.LEARNING_RATE)
    model.compile(optimizer=optimizer, loss=tr.loss_fn, metrics=tr.metric_list)
    if verbose:
        # BUG FIX: model.summary() prints the summary itself and returns
        # None, so the original print(model.summary()) emitted a stray
        # "None" line.
        model.summary()

    tr.log_path = os.path.join(models_dir, tr.model_name, "logs")
    if verbose:
        print(tr.model_name)

    # BUG FIX: np.ceil returns a float; steps_per_epoch must be an integer.
    STEPS_PER_EPOCH = int(np.ceil(
        tr.image_count_train / tr.hyperparameters.TRAIN_BATCH_SIZE))
    callback.broadcast("message", "Fit model...")
    tr.history = model.fit(
        tr.train_ds,
        epochs=tr.hyperparameters.NUM_EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=tr.test_ds,
        verbose=verbose,
        callbacks=[
            callbacks.checkpoint_cb(tr.checkpoint_path),
            callbacks.tensorboard_cb(tr.log_path),
            callbacks.logger_cb(tr.log_path, append_logs),
            callback,
        ],
    )
def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0):
    """Plot training curves, export best/last checkpoints to TFLite, and
    evaluate the best checkpoint on the test dataset.

    Args:
        tr: Training state produced by ``do_training`` (reads ``tr.history``,
            ``tr.log_path``, ``tr.checkpoint_path``, loss/metrics, datasets).
        callback: used only to ``broadcast`` progress messages.
        verbose: unused here; kept for signature symmetry with do_training.
    """
    callback.broadcast("message", "Generate plots...")
    history = tr.history
    log_path = tr.log_path

    def _plot_curves(keys, ylabel, filename):
        # Render one train/val metric pair into its own PNG.
        # BUG FIX: the original reused a single implicit pyplot figure, so
        # every successive PNG accumulated the curves of all earlier plots;
        # plt.figure() starts each chart fresh.
        plt.figure()
        for key in keys:
            plt.plot(history.history[key], label=key)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.legend(loc="lower right")
        savefig(os.path.join(log_path, filename))

    _plot_curves(["mean_absolute_error", "val_mean_absolute_error"],
                 "Mean Absolute Error", "error.png")
    _plot_curves(["direction_metric", "val_direction_metric"],
                 "Direction Metric", "direction.png")
    _plot_curves(["angle_metric", "val_angle_metric"],
                 "Angle Metric", "angle.png")
    _plot_curves(["loss", "val_loss"], "Loss", "loss.png")

    callback.broadcast("message", "Generate tflite models...")
    checkpoint_path = tr.checkpoint_path
    print("checkpoint_path", checkpoint_path)
    # Best epoch = highest combined validation angle + direction score.
    best_index = np.argmax(
        np.array(history.history["val_angle_metric"])
        + np.array(history.history["val_direction_metric"]))
    # Checkpoint files are 1-indexed (cp-0001.ckpt is epoch 1).
    best_checkpoint = "cp-%04d.ckpt" % (best_index + 1)
    best_tflite = utils.generate_tflite(checkpoint_path, best_checkpoint)
    utils.save_tflite(best_tflite, checkpoint_path, "best")
    print("Best Checkpoint (val_angle: %s, val_direction: %s): %s" % (
        history.history["val_angle_metric"][best_index],
        history.history["val_direction_metric"][best_index],
        best_checkpoint,
    ))

    last_checkpoint = sorted(utils.list_dirs(checkpoint_path))[-1]
    last_tflite = utils.generate_tflite(checkpoint_path, last_checkpoint)
    utils.save_tflite(last_tflite, checkpoint_path, "last")
    print("Last Checkpoint (val_angle: %s, val_direction: %s): %s" % (
        history.history["val_angle_metric"][-1],
        history.history["val_direction_metric"][-1],
        last_checkpoint,
    ))

    callback.broadcast("message", "Evaluate model...")
    best_model = utils.load_model(
        os.path.join(checkpoint_path, best_checkpoint),
        tr.loss_fn,
        tr.metric_list,
        tr.custom_objects,
    )
    # BUG FIX: `steps` must be an integer; the bare division produced a
    # float (and would silently drop the final partial batch if truncated),
    # so round up to cover the whole test set.
    res = best_model.evaluate(
        tr.test_ds,
        steps=int(np.ceil(
            tr.image_count_test / tr.hyperparameters.TEST_BATCH_SIZE)),
        verbose=2,
    )
    print(res)

    # Preview predictions on the first NUM_SAMPLES test examples.
    NUM_SAMPLES = 15
    (image_batch, cmd_batch), label_batch = next(iter(tr.test_ds))
    pred_batch = best_model.predict((
        tf.slice(image_batch, [0, 0, 0, 0], [NUM_SAMPLES, -1, -1, -1]),
        tf.slice(cmd_batch, [0], [NUM_SAMPLES]),
    ))
    utils.show_test_batch(image_batch.numpy(), cmd_batch.numpy(),
                          label_batch.numpy(), pred_batch)
    savefig(os.path.join(log_path, "test_preview.png"))
    # Sanity-check that the exported TFLite model matches the Keras model.
    utils.compare_tf_tflite(best_model, best_tflite)
def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0):
    """Build (or resume) a model, fit it with optional Weights & Biases
    tracking, and save the trained model.

    NOTE(review): this redefines ``do_training`` declared earlier in the
    file and shadows it; the earlier (non-wandb) version appears to be dead
    code — confirm and remove one of the two.

    Args:
        tr: Training state object; model/log paths, loss, metrics and
            ``tr.history`` are populated as side effects.
        callback: extra Keras callback appended to the training callbacks;
            also used to ``broadcast`` progress and the model name.
        verbose: verbosity forwarded to ``model.fit`` (also gates the local
            summary/name prints).
    """
    tr.model_name = dataset_name + "_" + str(tr.hyperparameters)
    tr.checkpoint_path = os.path.join(models_dir, tr.model_name, "checkpoints")
    tr.custom_objects = {
        "direction_metric": metrics.direction_metric,
        "angle_metric": metrics.angle_metric,
    }
    model_path = os.path.join(models_dir, tr.model_name, "model")

    if tr.hyperparameters.WANDB:
        # Optional dependency: only imported when wandb tracking is enabled.
        import wandb
        from wandb.keras import WandbCallback

        wandb.init(project="openbot")
        config = wandb.config
        config.epochs = tr.hyperparameters.NUM_EPOCHS
        config.learning_rate = tr.hyperparameters.LEARNING_RATE
        config.batch_size = tr.hyperparameters.TRAIN_BATCH_SIZE
        config["model_name"] = tr.model_name

    append_logs = False
    model: tf.keras.Model
    if tr.hyperparameters.USE_LAST:
        # Resume training from the previously saved full model.
        append_logs = True
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=tr.custom_objects,
            compile=False,
        )
    else:
        # Instantiate a fresh architecture by name from the models module.
        model = getattr(models, tr.hyperparameters.MODEL)(
            tr.NETWORK_IMG_WIDTH,
            tr.NETWORK_IMG_HEIGHT,
            tr.hyperparameters.BATCH_NORM,
        )
        dot_img_file = os.path.join(models_dir, tr.model_name, "model.png")
        tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

    callback.broadcast("model", tr.model_name)
    tr.loss_fn = losses.sq_weighted_mse_angle
    tr.metric_list = [
        "mean_absolute_error",
        tr.custom_objects["direction_metric"],
        tr.custom_objects["angle_metric"],
    ]
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=tr.hyperparameters.LEARNING_RATE)
    model.compile(optimizer=optimizer, loss=tr.loss_fn, metrics=tr.metric_list)
    if verbose:
        # BUG FIX: model.summary() prints the summary itself and returns
        # None, so the original print(model.summary()) emitted a stray
        # "None" line.
        model.summary()

    tr.log_path = os.path.join(models_dir, tr.model_name, "logs")
    if verbose:
        print(tr.model_name)

    # BUG FIX: np.ceil returns a float; steps_per_epoch must be an integer.
    STEPS_PER_EPOCH = int(np.ceil(
        tr.image_count_train / tr.hyperparameters.TRAIN_BATCH_SIZE))
    callback.broadcast("message", "Fit model...")
    callback_list = [
        callbacks.checkpoint_cb(tr.checkpoint_path),
        callbacks.tensorboard_cb(tr.log_path),
        callbacks.logger_cb(tr.log_path, append_logs),
        callback,
    ]
    if tr.hyperparameters.WANDB:
        callback_list.append(WandbCallback())

    tr.history = model.fit(
        tr.train_ds,
        epochs=tr.hyperparameters.NUM_EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=tr.test_ds,
        verbose=verbose,
        callbacks=callback_list,
    )
    model.save(model_path)
    if tr.hyperparameters.WANDB:
        # NOTE(review): model_path is a SavedModel *directory*; wandb.save
        # expects a glob string and may not upload a directory tree —
        # verify the artifact actually appears in the wandb run.
        wandb.save(model_path)
        wandb.finish()