def main(args):
    """Train an autoencoder end to end from parsed command-line arguments.

    Steps: validate the options, build the ``AutoEncoder``, wrap the input
    directory in training/validation generators, search for a learning
    rate seeded with ``args.lr_estimate``, fit for ``args.epochs`` using
    ``args.policy``, save, and optionally run visual inspection.

    NOTE(review): a later ``def main`` in this file rebinds the name, so
    this version is shadowed as the file stands -- confirm which one is
    meant to survive.

    Parameters
    ----------
    args : argparse-style namespace
        Must expose ``input_dir``, ``architecture``, ``color``, ``loss``,
        ``batch``, ``epochs``, ``lr_estimate``, ``policy`` and ``inspect``.
    """
    # Pull every option off ``args`` up front, so a missing attribute
    # fails before any model or data work starts.
    (input_dir, architecture, color_mode, loss,
     batch_size, epochs, lr_estimate, policy) = (
        args.input_dir, args.architecture, args.color, args.loss,
        args.batch, args.epochs, args.lr_estimate, args.policy,
    )

    check_arguments(architecture, color_mode, loss)

    autoencoder = AutoEncoder(
        input_dir, architecture, color_mode, loss, batch_size)

    # The data pipeline is configured from the autoencoder's resolved
    # settings (rescale, shape, color mode, batch size), not the raw CLI.
    preprocessor = Preprocessor(
        input_directory=input_dir,
        rescale=autoencoder.rescale,
        shape=autoencoder.shape,
        color_mode=autoencoder.color_mode,
    )
    train_generator = preprocessor.get_train_generator(
        batch_size=autoencoder.batch_size,
        shuffle=True,
    )
    validation_generator = preprocessor.get_val_generator(
        batch_size=autoencoder.batch_size,
        shuffle=False,
        purpose="val",
    )

    best_lr = autoencoder.find_lr_opt(
        train_generator, validation_generator, lr_estimate)
    autoencoder.fit(lr_opt=best_lr, epochs=epochs, policy=policy)

    # Persist the trained model together with its configuration.
    autoencoder.save()

    if args.inspect:
        # Visual assessment of validation and test reconstructions.
        inspection.inspect_images(model_path=autoencoder.save_path)
    logger.info("done.")


def _inspect_split(autoencoder, make_generator, color_mode, group, dir_name):
    """Reconstruct one image split and save its inspection plots.

    Parameters
    ----------
    autoencoder : AutoEncoder
        Trained wrapper; ``model``, ``save_dir``, ``vmin``, ``vmax`` and
        ``loss`` are read from it.
    make_generator : callable
        Zero-argument callable returning the data generator. Only the
        first ``next()`` batch is used, so the caller sizes that batch to
        cover the whole split.
    color_mode : str
        When ``"rgb"``, inputs and predictions are converted to grayscale
        before residual maps are computed.
    group : str
        Split name (e.g. ``"validation"``/``"test"``), used in log messages
        and forwarded to ``generate_inspection_plots``.
    dir_name : str
        Sub-directory of ``autoencoder.save_dir`` that receives the plots.
    """
    logger.info("generating inspection plots of %s images...", group)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    save_dir = os.path.join(autoencoder.save_dir, dir_name)
    os.makedirs(save_dir, exist_ok=True)

    generator = make_generator()
    imgs_input = generator.next()[0]
    filenames = generator.filenames

    # get reconstructed images (i.e. predictions) for this split
    logger.info("reconstructing %s images...", group)
    imgs_pred = autoencoder.model.predict(imgs_input)

    if color_mode == "rgb":
        imgs_input = tf.image.rgb_to_grayscale(imgs_input).numpy()
        imgs_pred = tf.image.rgb_to_grayscale(imgs_pred).numpy()

    # Drop the trailing channel axis; images are single-channel by now.
    imgs_input = imgs_input[:, :, :, 0]
    imgs_pred = imgs_pred[:, :, :, 0]

    # compute residual maps and write the plots
    tensor = postprocessing.TensorImages(
        imgs_input=imgs_input,
        imgs_pred=imgs_pred,
        vmin=autoencoder.vmin,
        vmax=autoencoder.vmax,
        method=autoencoder.loss,
        dtype="float64",
        filenames=filenames,
    )
    tensor.generate_inspection_plots(group=group, save_dir=save_dir)


def main(args):
    """Train and save an autoencoder, optionally plotting reconstructions.

    Validates the options, builds the ``AutoEncoder`` and its data
    generators, runs the learning-rate search, fits, saves, and, when
    ``args.inspect`` is set, writes inspection plots for the validation and
    test splits under ``autoencoder.save_dir``.

    NOTE(review): this file defines ``main`` twice; this later definition
    is the one bound to the name at import time.

    Parameters
    ----------
    args : argparse-style namespace
        Must expose ``input_dir``, ``architecture``, ``color``, ``loss``,
        ``batch`` and ``inspect``.
    """
    # get parsed arguments from user
    input_dir = args.input_dir
    architecture = args.architecture
    color_mode = args.color
    loss = args.loss
    batch_size = args.batch

    check_arguments(architecture, color_mode, loss)

    autoencoder = AutoEncoder(input_dir, architecture, color_mode, loss,
                              batch_size)

    # load data as generators that yield batches of preprocessed images
    preprocessor = Preprocessor(
        input_directory=input_dir,
        rescale=autoencoder.rescale,
        shape=autoencoder.shape,
        color_mode=autoencoder.color_mode,
        preprocessing_function=autoencoder.preprocessing_function,
    )
    train_generator = preprocessor.get_train_generator(
        batch_size=autoencoder.batch_size, shuffle=True)
    # NOTE(review): validation data is shuffled here; confirm the LR search
    # / fit do not depend on validation order.
    validation_generator = preprocessor.get_val_generator(
        batch_size=autoencoder.batch_size, shuffle=True)

    # find best learning rate, then train and save
    autoencoder.find_opt_lr(train_generator, validation_generator)
    autoencoder.fit()
    autoencoder.save()

    if args.inspect:
        # One batch spanning the whole validation split.
        _inspect_split(
            autoencoder,
            lambda: preprocessor.get_val_generator(
                batch_size=autoencoder.learner.val_data.samples,
                shuffle=False),
            color_mode,
            "validation",
            "inspection_val",
        )
        # One batch spanning the whole test split.
        _inspect_split(
            autoencoder,
            lambda: preprocessor.get_test_generator(
                batch_size=preprocessor.get_total_number_test_images(),
                shuffle=False),
            color_mode,
            "test",
            "inspection_test",
        )

    logger.info("done.")