コード例 #1
0
def run_frozensony(mod: dict):
    """Train the frozen Sony model and save its training history.

    Builds the Sony image data pipeline and the train/validation generators,
    trains with ``train_frozen_model`` (250 epochs, lr=1e-4, model type
    ``'bl_cd_pn_ag'``), then writes the training history to
    ``review/freeze_sony_bl_cd_pn_ag_<YYYYmmdd-HHMMSS>.json`` under the
    current working directory.

    Args:
        mod: Model configuration forwarded to ``train_frozen_model``. If
            ``None``, a fresh one is built with ``functional_sony()``.

    Returns:
        None. Errors while saving the history are logged, not raised.
    """
    logger.info("STARTED running frozen sony model updates.")

    # Specify Image Data Pipeline

    idp = ImageDataPipeline(preprocessing_function='sony',
                            stride=32,
                            batch_size=32,
                            patch_size=(64, 64),
                            random_seed=42,
                            meanm_fpath='simulation_mean.pkl',
                            covm_fpath='simulation_cov.pkl',
                            num_images=10)

    # Specify train/val generators (both share the same pipeline settings)

    train_file = 'Sony_RGB/Sony_train_list.txt'
    train_dataflow = SonyDataGenerator(train_file, idp)

    val_file = 'Sony_RGB/Sony_val_list.txt'
    val_dataflow = SonyDataGenerator(val_file, idp)

    # Fit model

    # Honour the caller's model config. Previously it was unconditionally
    # replaced here, so the `mod` argument had no effect.
    if mod is None:
        mod = functional_sony()
    model_type = 'bl_cd_pn_ag'
    _, history = train_frozen_model(train_dataflow, val_dataflow,
                                    epochs=250, mod=mod,
                                    model_type=model_type,
                                    lr=1e-4)

    # Save history (best effort: any error is logged and swallowed)

    try:
        review_dir = os.path.join(os.getcwd(), 'review')
        os.makedirs(review_dir, exist_ok=True)

        model_id = 'freeze_sony'
        model_name = '{}_{}'.format(model_id, model_type)

        datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
        model_history_name = '{}_{}.json'.format(model_name, datetime_now)
        mh_filepath = os.path.join(review_dir, model_history_name)

        # NOTE(review): this writes str(history.history) as a single JSON
        # string, not as a JSON object. Kept unchanged so existing readers of
        # these review files keep working.
        with open(mh_filepath, 'w') as outfile:
            json.dump(str(history.history), outfile)

        logger.info('Saved model history: {}'.format(mh_filepath))

    except Exception as exc:
        logger.exception(exc)
コード例 #2
0
def main():
    """Build the functional Sony model config and hand it to ``run_simulation``."""
    run_simulation(functional_sony())
コード例 #3
0
        model = restore_model(mod, model_name)

        y_pred_ij = []
        for X_patch in X_patches:

            # Predict against augmented X_patch

            y_pred = model.predict(np.expand_dims(X_patch, axis=0))
            y_pred_ij.append(y_pred[0])

        # Reconstruct Y_pred

        Y_pred_patches = np.array(y_pred_ij)
        Y_pred = idp.reconstruct_patches(Y_pred_patches, Y_test.shape)

        # Write out image comparison

        review_dir = os.path.join(os.getcwd(), 'review')
        if not os.path.isdir(review_dir):
            os.makedirs(review_dir)

        model_chkpt_name = '{}_{}.png'.format(mod.get('model_id', ''),
                                              checkpoint_name)
        mi_filepath = os.path.join(review_dir, model_chkpt_name)

        plot_images(mi_filepath, X_test, Y_pred, Y_test)


if __name__ == '__main__':
    # Build the model config once (kept as module-level `mod`) and launch
    # the GIPH run for the chosen model type.
    mod = functional_sony()
    chosen_model_type = 'bl_cd_pn_ag'
    run_raise_giph(mod, chosen_model_type)