示例#1
0
def save_model_and_history(model, history, name):
    """Save a model, its weights and its training history to ``output_path``.

    All artifacts share the prefix ``<datetime>-<name>`` and are written to
    the module-level ``output_path`` directory:

    * ``<prefix>-model.h5``         -- full model (best effort, may fail)
    * ``<prefix>-model-weights.h5`` -- model weights only
    * ``<prefix>-history.p``        -- pickled ``history.history``

    A failure while saving the full model is caught and reported as a warning,
    so the weights file can be used instead. Errors while saving the weights or
    the history are not caught and propagate to the caller.

    Args:
        model: Model exposing ``save`` and ``save_weights``.
        history: Object with a ``history`` attribute (presumably a Keras
            ``History``); only ``history.history`` is pickled.
        name: Label inserted into every output filename.
    """
    print("Saving model and history...")

    prefix = utils.get_datetime_string() + "-" + name

    # Try to save the full model. This can fail, so it is best effort only.
    model_name = prefix + "-model.h5"
    try:
        model.save(os.path.join(output_path, model_name))
        print("Saved model to " + model_name)
    except Exception as e:
        # Report why it failed instead of silently discarding the error.
        print("WARNING! Failed to save model ({}). "
              "Use model-weights instead.".format(e))

    # Save the model weights.
    model_weights_name = prefix + "-model-weights.h5"
    model.save_weights(os.path.join(output_path, model_weights_name))
    print("Saved model weights to " + model_weights_name)

    # Save the history; the context manager guarantees the file is closed.
    history_name = prefix + "-history.p"
    with open(os.path.join(output_path, history_name), "wb") as history_file:
        pickle.dump(history.history, history_file)
    print("Saved history to " + history_name)
示例#2
0
def save_dataset(dataset_train, dataset_test, dataset_parameters):
    """Pickle a train/test dataset together with its parameters.

    The tuple ``(dataset_train, dataset_test, dataset_parameters)`` is written
    to ``<output_path>/<datetime>-<input_type>-dataset.p``. Here ``output_path``
    is the module-level output directory and ``input_type`` is read from
    ``dataset_parameters["input_type"]``.

    Args:
        dataset_train: Training split; pickled as-is.
        dataset_test: Test split; pickled as-is.
        dataset_parameters: Parameters describing the dataset. They must
            provide an ``"input_type"`` entry (concatenated into the filename,
            so presumably a string).
    """
    print("Saving dataset...")
    data = (dataset_train, dataset_test, dataset_parameters)
    datetime_string = utils.get_datetime_string()
    dataset_name = (datetime_string + "-" + dataset_parameters["input_type"] +
                    "-dataset.p")
    dataset_path = os.path.join(output_path, dataset_name)
    # Use a context manager so the file is flushed and closed deterministically.
    with open(dataset_path, "wb") as dataset_file:
        pickle.dump(data, dataset_file)
    print("Saved dataset to " + dataset_path + ".")
示例#3
0
def plot_histories(histories, names):
    """Plot every metric curve of several training histories into one figure.

    Each metric series is drawn with the label ``<name>-<metric>``. A legend
    is added, and the figure is saved to
    ``<output_path>/<datetime>-histories.png``, shown, and then closed.

    Args:
        histories: Mapping whose items are unpacked as ``(history, name)``.
            That is, it is keyed by objects with a ``history`` dict
            (metric -> sequence of values) and valued by the label prefix.
            NOTE(review): confirm against callers that this orientation
            (history object as key, name as value) is the intended one.
        names: Unused; kept for backward compatibility of the signature.
    """
    for history, name in histories.items():
        for key, data in history.history.items():
            plt.plot(data, label=name + "-" + key)

    # The labels set above are only rendered once a legend is requested.
    plt.legend()

    fig_name = utils.get_datetime_string() + "-histories.png"
    fig_path = os.path.join(output_path, fig_name)
    plt.savefig(fig_path)
    plt.show()
    plt.close()
示例#4
0
    # Fragment of a larger training function; its header and the definitions
    # of the names used below are not part of this snippet.

    # Training details: collect dataset, split and schedule settings in one dict.
    training_details = {
        "dataset_path": dataset_path,
        "qrcodes_train": qrcodes_train,
        "qrcodes_validate": qrcodes_validate,
        "steps_per_epoch": steps_per_epoch,
        "validation_steps": validation_steps,
        "epochs": epochs,
        "batch_size": batch_size,
        "random_seed": random_seed,
        "dataset_parameters": dataset_parameters
    }

    # Run identifier: datetime string, then "_<n_train>-<n_validate>", then the
    # output targets joined with "_".
    # NOTE(review): no separator is inserted between the validation count and
    # the first output target (e.g. "..._100-20height") -- confirm intended.
    datetime_string = utils.get_datetime_string() + "_{}-{}".format(
        len(qrcodes_train), len(qrcodes_validate)) + "_".join(
            dataset_parameters["output_targets"])

    # Output path: per-run directory under output_root_path, created if missing.
    output_path = os.path.join(output_root_path, datetime_string)
    if os.path.exists(output_path) == False:
        os.makedirs(output_path)
    print("Using output path:", output_path)

    # Helpers for the training loop: a pretty printer, a TensorBoard callback
    # logging to a hard-coded "/whhdata/models/logs/<run id>" directory, and an
    # empty dict for collecting training histories.
    pp = pprint.PrettyPrinter(indent=4)
    log_dir = os.path.join("/whhdata/models", "logs", datetime_string)
    tensorboard_callback = callbacks.TensorBoard(log_dir=log_dir)
    histories = {}
示例#5
0
def _create_logger(output_path):
    """Return the "train.py" logger writing to ``<output_path>/train.log``.

    The file handler records DEBUG and above; the console handler records INFO
    and above. Both use the "time - name - level - message" format.
    ``output_path`` must already exist, because ``logging.FileHandler`` opens
    its file immediately.
    """
    logger = logging.getLogger("train.py")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    file_handler = logging.FileHandler(os.path.join(output_path, "train.log"))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger


def _create_optimizer(model_parameters):
    """Build the optimizer named by ``model_parameters.optimizer``.

    Raises:
        Exception: if the optimizer name is neither "rmsprop" nor "adam".
    """
    if model_parameters.optimizer == "rmsprop":
        return optimizers.RMSprop(
            learning_rate=model_parameters.learning_rate)
    if model_parameters.optimizer == "adam":
        return optimizers.Adam(
            learning_rate=model_parameters.learning_rate,
            beta_1=model_parameters.beta_1,
            beta_2=model_parameters.beta_2,
            amsgrad=model_parameters.amsgrad)
    raise Exception("Unexpected optimizer {}".format(
        model_parameters.optimizer))


def _create_callbacks(config):
    """Return the Keras callbacks used for training.

    The callbacks are:
    * TensorBoard logging into the output path (always).
    * Early stopping on ``val_loss`` with patience 5, only when
      ``training_parameters.use_early_stopping`` is set.
    * A checkpoint into the output path that saves only when ``val_loss``
      improves (always).
    """
    output_path = config.global_parameters.output_path
    training_callbacks = []

    # Logging training progress with TensorBoard.
    training_callbacks.append(tf.keras.callbacks.TensorBoard(
        log_dir=output_path,
        histogram_freq=0,
        batch_size=32,
        write_graph=True,
        write_grads=False,
        write_images=True,
        embeddings_freq=0,
        embeddings_layer_names=None,
        embeddings_metadata=None,
        embeddings_data=None,
        update_freq="epoch"))

    # Early stopping.
    if config.training_parameters.use_early_stopping:
        training_callbacks.append(tf.keras.callbacks.EarlyStopping(
            monitor="val_loss",
            min_delta=config.training_parameters.early_stopping_threshold,
            patience=5,
            verbose=1))

    # Model checkpoint: keep the full model whenever val_loss improves.
    training_callbacks.append(tf.keras.callbacks.ModelCheckpoint(
        os.path.join(output_path,
                     "val_loss_{val_loss:.2f}_at_epoche_{epoch:2d}.hdf5"),
        monitor="val_loss",
        verbose=0,
        save_best_only=True,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch"))
    return training_callbacks


def main():
    """Train a PointNet model as configured by a JSON config file.

    Command line flags:
        -config_file      path of the JSON config file (required)
        -use_multi_gpu    wrap the compiled model with multi_gpu_model (2 GPUs)
        -resume_training  continue from ``<output_path>/model.h5``

    The following files are written into ``config.global_parameters.output_path``:
    * ``train.log``
    * a copy of the config file
    * TensorBoard logs
    * best-``val_loss`` checkpoints
    * ``model.h5``
    * ``results.json`` (train/validate split, timing, history)

    Training can be stopped with Ctrl+C. The model and results are still
    saved in that case.
    """
    # Parse command line arguments.
    parser = argparse.ArgumentParser(description="Training on GPU")
    parser.add_argument("-config_file",
                        action="store",
                        dest="config_file",
                        type=str,
                        required=True,
                        help="config file path")
    parser.add_argument("-use_multi_gpu",
                        action="store_true",
                        dest="use_multi_gpu",
                        help="set the training on multiple gpus")
    parser.add_argument("-resume_training",
                        action="store_true",
                        dest="resume_training",
                        help="resumes a previous training")
    arguments = parser.parse_args()

    # Load the config file. Each top-level section is wrapped in a Bunch too,
    # so values are read as attributes (config.section.key).
    with open(arguments.config_file, "r") as config_file:
        config = json.load(config_file)
    config = Bunch({key: Bunch(value) for key, value in config.items()})
    output_path = config.global_parameters.output_path

    # Ensure the output path exists *before* the logger opens train.log in it.
    os.makedirs(output_path, exist_ok=True)

    # Create logger.
    logger = _create_logger(output_path)
    logger.info("Starting training job...")
    logger.info("Using output path: %s", output_path)

    # Prepare results.
    results = Bunch()

    # Check if there is a GPU.
    if not utils.get_available_gpus():
        logger.warning("WARNING! No GPU available!")

    # Create datagenerator.
    datagenerator_instance = create_datagenerator_from_parameters(
        config.datagenerator_parameters.dataset_path,
        config.datagenerator_parameters)

    # Do a reproducible 80/20 train-validation split using the configured seed.
    qrcodes = datagenerator_instance.qrcodes[:]
    randomizer = random.Random(config.datagenerator_parameters.random_seed)
    randomizer.shuffle(qrcodes)
    split_index = int(0.8 * len(qrcodes))
    qrcodes_train = sorted(qrcodes[:split_index])
    qrcodes_validate = sorted(qrcodes[split_index:])
    del qrcodes
    results.qrcodes_train = qrcodes_train
    results.qrcodes_validate = qrcodes_validate

    # Create python generators.
    workers = 4
    generator_train = datagenerator_instance.generate(
        size=config.training_parameters.batch_size,
        qrcodes_to_use=qrcodes_train,
        workers=workers)
    generator_validate = datagenerator_instance.generate(
        size=config.training_parameters.batch_size,
        qrcodes_to_use=qrcodes_validate,
        workers=workers)

    # Copy config file next to the results for reproducibility.
    shutil.copy2(arguments.config_file, output_path)

    # Create the model path.
    model_path = os.path.join(output_path, "model.h5")

    # TODO: only PointNet models are supported so far. Explicit check instead
    # of assert so it is not stripped under `python -O`.
    if config.model_parameters.type != "pointnet":
        raise ValueError("Unsupported model type {}".format(
            config.model_parameters.type))

    if arguments.resume_training:
        # Resume training.
        if not os.path.exists(model_path):
            logger.error("Model does not exist. Cannot resume!")
            # Non-zero exit status so callers/schedulers see the failure.
            raise SystemExit(1)
        model = tf.keras.models.load_model(model_path)
        logger.info("Loaded model from {}.".format(model_path))
    else:
        # Start from scratch.
        model = modelutils.create_point_net(
            config.model_parameters.input_shape,
            config.model_parameters.output_size,
            config.model_parameters.hidden_sizes)
        logger.info("Created new model.")
    model.summary()

    # Compile model.
    optimizer = _create_optimizer(config.model_parameters)
    model.compile(optimizer=optimizer, loss="mse", metrics=["mae"])

    # Do training on multiple GPUs. Keep a handle on the single-device model,
    # which is the one that gets saved at the end.
    original_model = model
    if arguments.use_multi_gpu:
        model = tf.keras.utils.multi_gpu_model(model, gpus=2)

    # Create the callbacks.
    training_callbacks = _create_callbacks(config)

    # Start training.
    results.training_begin = utils.get_datetime_string()
    try:
        model.fit_generator(
            generator_train,
            steps_per_epoch=config.training_parameters.steps_per_epoch,
            epochs=config.training_parameters.epochs,
            validation_data=generator_validate,
            validation_steps=config.training_parameters.validation_steps,
            use_multiprocessing=False,
            workers=0,
            callbacks=training_callbacks)
    except KeyboardInterrupt:
        logger.info("Gracefully stopping training...")
        # NOTE(review): finish() is only called on interruption; confirm
        # whether it should also run after a normal end of training.
        datagenerator_instance.finish()
        results.interrupted_by_user = True

    # Training ended.
    results.training_end = utils.get_datetime_string()

    # Save the model. Make sure that it is the original model.
    original_model.save(model_path)

    # Store the history.
    results.model_history = model.history.history

    # Write the results.
    results_path = os.path.join(output_path, "results.json")
    with open(results_path, "w") as results_file:
        json.dump(results, results_file, indent=4, sort_keys=True)