import os
import shutil
import sys

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam

# Project-local helpers (module paths assumed; this page does not show them).
from data_generator import DataGenerator
from models import GAPNet, create_pointnet, model_names
from training_utils import onetenth_50_75, plot_history, save_history


def main():

    # Check command line arguments. model_name is required below, so this
    # check must not stay commented out.
    if len(sys.argv) != 2 or sys.argv[1] not in model_names:
        print("Must provide name of model.")
        print("Options: " + " ".join(model_names))
        sys.exit(0)
    model_name = sys.argv[1]

    # Data preparation.
    nb_classes = 40
    train_file = './ModelNet40/ply_data_train.h5'
    test_file = './ModelNet40/ply_data_test.h5'

    # Hyperparameters.
    number_of_points = 1024
    epochs = 100
    batch_size = 32

    # Data generators for training and validation.
    train = DataGenerator(train_file,
                          batch_size,
                          number_of_points,
                          nb_classes,
                          train=True)
    val = DataGenerator(test_file,
                        batch_size,
                        number_of_points,
                        nb_classes,
                        train=False)

    # Create the model.
    if model_name == "pointnet":
        model = create_pointnet(number_of_points, nb_classes)
    elif model_name == "gapnet":
        model = GAPNet()
    model.summary()

    # Ensure output paths. training_name is assumed here; the original
    # snippet uses it below without defining it.
    training_name = "default"
    output_path = "logs"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    output_path = os.path.join(output_path, model_name)
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    output_path = os.path.join(output_path, training_name)
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.mkdir(output_path)

    # Compile the model.
    lr = 0.0001
    adam = Adam(lr=lr)
    model.compile(optimizer=adam,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Checkpoint callback.
    checkpoint = ModelCheckpoint(os.path.join(output_path, "model.h5"),
                                 monitor="val_acc",
                                 save_weights_only=True,
                                 save_best_only=True,
                                 verbose=1)

    # Logging training progress with tensorboard.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=output_path,
        histogram_freq=0,
        batch_size=32,
        write_graph=True,
        write_grads=False,
        write_images=True,
        embeddings_freq=0,
        embeddings_layer_names=None,
        embeddings_metadata=None,
        embeddings_data=None,
        update_freq="epoch")

    callbacks = []
    #callbacks.append(checkpoint)
    callbacks.append(onetenth_50_75(lr))
    callbacks.append(tensorboard_callback)

    # Train the model.
    history = model.fit_generator(train.generator(),
                                  steps_per_epoch=9840 // batch_size,
                                  epochs=epochs,
                                  validation_data=val.generator(),
                                  validation_steps=2468 // batch_size,
                                  callbacks=callbacks,
                                  verbose=1)

    # Save history and model.
    plot_history(history, output_path)
    save_history(history, output_path)
    model.save_weights(os.path.join(output_path, "model_weights.h5"))
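
DataGenerator and onetenth_50_75 are project-local helpers that this page does not show. The two sketches below are assumptions reconstructed only from how Example #1 calls them, not the original implementations. First, a minimal HDF5-backed generator (the real class likely also shuffles or augments the point clouds):

import h5py
import numpy as np
from tensorflow.keras.utils import to_categorical

class DataGenerator:
    """Yields (points, one-hot labels) batches from a ply_data_*.h5 file.

    Assumes the standard ModelNet40 HDF5 layout: "data" holds point clouds
    of shape (N, 2048, 3) and "label" holds integer class ids of shape (N, 1).
    """

    def __init__(self, file_name, batch_size, number_of_points, nb_classes, train=True):
        with h5py.File(file_name, "r") as f:
            self.data = f["data"][:, :number_of_points, :]
            self.labels = to_categorical(f["label"][:].flatten(), nb_classes)
        self.batch_size = batch_size
        self.train = train

    def generator(self):
        indices = np.arange(len(self.data))
        while True:
            if self.train:
                np.random.shuffle(indices)
            for start in range(0, len(indices) - self.batch_size + 1, self.batch_size):
                batch = indices[start:start + self.batch_size]
                yield self.data[batch], self.labels[batch]

Second, the learning-rate callback. Assuming the name means "drop the rate to one tenth at 50% and again at 75% of the 100 epochs", it is a standard step-decay schedule:

from tensorflow.keras.callbacks import LearningRateScheduler

def onetenth_50_75(lr):
    # Step decay: base rate until epoch 50, lr/10 until epoch 75,
    # lr/100 afterwards.
    def schedule(epoch):
        if epoch < 50:
            return lr
        if epoch < 75:
            return lr / 10.0
        return lr / 100.0
    return LearningRateScheduler(schedule)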
Example #2
    log_dir="logs",
    histogram_freq=0,
    write_graph=True,
    write_grads=False,
    write_images=True,
    embeddings_freq=0,
    embeddings_layer_names=None,
    embeddings_metadata=None,
    embeddings_data=None,
    update_freq="epoch")
training_callbacks.append(tensorboard_callback)

# Compile the model.
lr = 0.0001
adam = optimizers.Adam(lr=lr)
model.compile(optimizer=adam, loss="mse", metrics=["mae"])

batch_size = 128
epochs = 500
model.fit(dataset_training.batch(batch_size),
          validation_data=dataset_validate.batch(batch_size),
          epochs=epochs,
          callbacks=training_callbacks)

# Save the model.
logger.info('Saving and uploading weights...')
path = "gapnet_weights.h5"
model.save_weights(path)
run.upload_file(name="gapnet_weights.h5", path_or_stream=path)

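Example #2 assumes an Azure ML context: run is a handle to the current experiment run and logger a standard Python logger. A minimal sketch of how both might be set up (the logging configuration is illustrative):

import logging

from azureml.core import Run

# Plain stdout logger for the training script (illustrative settings).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Handle to the current Azure ML run; outside Azure ML,
# Run.get_context() returns an offline stub.
run = Run.get_context()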