Example #1
    def summary(self):
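        # Print the summaries of the generator and discriminator sub-models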
        Hexnet_print('Generator')
        self.generator_for_summary.summary()

        print_newline()
        Hexnet_print('Discriminator')
        self.discriminator_for_summary.summary()
Example #2
def Hexnet_init():
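	# Load Hexnet and print its information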
	Hexnet_load()
	Hexnet.print_info()
	print_newline()
Example #3
def run(args):

    model_string = args.model
    load_model = args.load_model
    load_weights = args.load_weights
    save_model = args.save_model
    save_weights = args.save_weights

    dataset = args.dataset
    resize_dataset = args.resize_dataset
    crop_dataset = args.crop_dataset
    augment_dataset = args.augment_dataset
    augmenter_string = args.augmenter
    augmentation_level = args.augmentation_level

    tests_dir = args.tests_dir
    show_dataset = args.show_dataset
    visualize_model = args.visualize_model
    show_results = args.show_results

    batch_size = args.batch_size
    epochs = args.epochs
    loss_string = args.loss
    runs = args.runs
    validation_split = args.validation_split

    cnn_kernel_size = args.cnn_kernel_size
    cnn_pool_size = args.cnn_pool_size

    verbosity_level = args.verbosity_level

    transform_s2h = args.transform_s2h
    transform_s2s = args.transform_s2s
    transform_rad_o = args.transform_rad_o
    transform_width = args.transform_width
    transform_height = args.transform_height

    if model_string is not None:
        model_is_provided = True

        model_is_custom = 'custom' in model_string
        model_is_standalone = 'standalone' in model_string

        model_is_autoencoder = 'autoencoder' in model_string
        model_is_CNN = 'CNN' in model_string
        model_is_GAN = 'GAN' in model_string
    else:
        model_is_provided = False

    if augmenter_string is not None:
        augmenter_is_custom = 'custom' in augmenter_string
    else:
        augmenter_is_custom = False

    if loss_string is not None:
        loss_is_provided = True

        loss_is_subpixel_loss = 's2s' in loss_string or 's2h' in loss_string
    else:
        loss_is_provided = False

    train_classes = []
    train_data = []
    train_filenames = []
    train_labels = []
    train_labels_orig = []
    test_classes = []
    test_data = []
    test_filenames = []
    test_labels = []
    test_labels_orig = []

    ############################################################################
    # Transform the dataset
    ############################################################################

    if transform_s2h or transform_s2h is None or transform_s2s or transform_s2s is None:
        Hexnet_init()

        Hexnet_print('Dataset transformation')

        if transform_s2h or transform_s2h is None:
            if transform_s2h is None:
                transform_s2h = f'{dataset}_s2h'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2h,
                                       mode='s2h',
                                       rad_o=transform_rad_o,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_s2s or transform_s2s is None:
            if transform_s2s is None:
                transform_s2s = f'{dataset}_s2s'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2s,
                                       mode='s2s',
                                       width=transform_width,
                                       height=transform_height,
                                       method=0,
                                       verbosity_level=verbosity_level)

    ############################################################################
    # Load the dataset
    ############################################################################

    ((train_classes, train_data, train_filenames, train_labels_orig),
     (test_classes, test_data, test_filenames,
      test_labels_orig)) = datasets.load_dataset(
          dataset=dataset, create_h5=True, verbosity_level=verbosity_level)

    ############################################################################
    # Resize and crop the dataset
    ############################################################################

    if resize_dataset is not None:
        (train_data,
         test_data) = datasets.resize_dataset(dataset_s=(train_data,
                                                         test_data),
                                              resize_string=resize_dataset)

    if crop_dataset is not None:
        (train_data,
         test_data) = datasets.crop_dataset(dataset_s=(train_data, test_data),
                                            crop_string=crop_dataset)

    # TODO
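    # Autoencoder and GAN models require input sizes that are a multiple of
    # min_size_factor: pad height and width with zeros, split as evenly as
    # possible between both sides.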
    if model_is_provided and (model_is_autoencoder or model_is_GAN):
        if model_is_autoencoder:
            min_size_factor = 2**5
        else:
            min_size_factor = 2**4

        if train_data.shape[1] % min_size_factor:
            padding_h = min_size_factor - train_data.shape[1] % min_size_factor
            padding_h = (int(padding_h / 2) + padding_h % 2,
                         int(padding_h / 2))
        else:
            padding_h = (0, 0)

        if train_data.shape[2] % min_size_factor:
            padding_w = min_size_factor - train_data.shape[2] % min_size_factor
            padding_w = (int(padding_w / 2) + padding_w % 2,
                         int(padding_w / 2))
        else:
            padding_w = (0, 0)

        pad_width = ((0, 0), padding_h, padding_w, (0, 0))

        train_data = np.pad(train_data,
                            pad_width,
                            mode='constant',
                            constant_values=0)
        test_data = np.pad(test_data,
                           pad_width,
                           mode='constant',
                           constant_values=0)

    ############################################################################
    # Prepare the dataset
    ############################################################################
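    # If every class label is a decimal string, the digits are used directly as
    # integer labels; otherwise each label is mapped to the index of its class.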

    class_labels_are_digits = True

    for class_label in train_classes:
        if not class_label.decode().isdigit():
            class_labels_are_digits = False
            break

    if class_labels_are_digits:
        train_labels = np.asarray(
            [int(label.decode()) for label in train_labels_orig])
        test_labels = np.asarray(
            [int(label.decode()) for label in test_labels_orig])
    else:
        train_labels = np.asarray([
            int(np.where(train_classes == label)[0])
            for label in train_labels_orig
        ])
        test_labels = np.asarray([
            int(np.where(test_classes == label)[0])
            for label in test_labels_orig
        ])

    train_classes = np.unique(train_labels)
    test_classes = np.unique(test_labels)

    if class_labels_are_digits:
        train_labels -= min(train_classes)
        test_labels -= min(test_classes)
        train_classes -= min(train_classes)
        test_classes -= min(test_classes)

    train_test_data_n = 255
    train_test_data_n /= 2
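    # Shift and scale 8-bit pixel values from [0, 255] to [-1, 1]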
    train_data = (train_data - train_test_data_n) / train_test_data_n
    test_data = (test_data - train_test_data_n) / train_test_data_n

    ############################################################################
    # Augment the dataset
    ############################################################################

    print_newline()

    if augment_dataset is not None:
        Hexnet_print('Dataset augmentation')

        augmenter = vars(augmenters)[f'augmenter_{augmenter_string}']

        if augmenter_is_custom:
            augmenter = augmenter()
        else:
            augmenter = augmenter(augmentation_level)

        if 'train' in augment_dataset:
            train_data = augmenter(images=train_data)

        if 'test' in augment_dataset:
            test_data = augmenter(images=test_data)

        print_newline()

    print_newline()

    ############################################################################
    # Show the dataset
    ############################################################################

    if show_dataset:
        train_data_for_visualization = np.clip(train_data + 0.5, 0, 1)
        test_data_for_visualization = np.clip(test_data + 0.5, 0, 1)

        datasets.show_dataset_classes(train_classes,
                                      train_data_for_visualization,
                                      train_labels,
                                      test_classes,
                                      test_data_for_visualization,
                                      test_labels,
                                      max_images_per_class=1,
                                      max_classes_to_display=10)

    ############################################################################
    # No model was provided - returning
    ############################################################################

    if not model_is_provided:
        Hexnet_print('No model provided.')
        return 0

    ############################################################################
    # Shuffle the dataset
    ############################################################################

    train_data, train_labels = sklearn.utils.shuffle(train_data, train_labels)

    ############################################################################
    # Start a new run
    ############################################################################

    for run in range(1, runs + 1):
        run_string = f'run={run}/{runs}'

        dataset = os.path.basename(dataset)
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')

        run_title = f'{model_string}_{dataset}_{timestamp}_epochs{epochs}-bs{batch_size}'

        if runs > 1:
            run_title = f'{run_title}_run{run}'

        ########################################################################
        # Initialize / load the model
        ########################################################################

        Hexnet_print(f'({run_string}) Model initialization')

        input_shape = train_data.shape[1:4]
        output_shape = test_data.shape[1:4]
        classes = len(train_classes)

        if load_model is None:
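            # Look up the model constructor by name in the models module
            # (models.model_<model_string>) and call it with the arguments its
            # model type expects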
            model = vars(models)[f'model_{model_string}']

            if model_is_custom or model_is_standalone:
                model = model(input_shape, classes)
            elif model_is_autoencoder:
                model = model(input_shape)
            elif model_is_CNN:
                model = model(input_shape, classes, cnn_kernel_size,
                              cnn_pool_size)
            else:
                model = model(input_shape, classes)
        else:
            model = tf.keras.models.load_model(load_model)

        if load_weights is not None:
            model.load_weights(load_weights)

        print_newline()

        ########################################################################
        # Fit the model
        ########################################################################
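        # Without an explicit loss, classifiers default to sparse categorical
        # cross-entropy and autoencoders to MSE; named losses are resolved from
        # the losses module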

        if not loss_is_provided:
            if not model_is_autoencoder:
                loss = 'sparse_categorical_crossentropy'
            else:
                loss = 'mse'
        else:
            if not loss_is_subpixel_loss:
                loss = vars(losses)[f'loss_{loss_string}']()
            else:
                loss = vars(losses)[f'loss_{loss_string}'](input_shape,
                                                           output_shape)

        if not model_is_autoencoder:
            metrics = ['accuracy']
        else:
            metrics = None

        if not model_is_standalone:
            model.compile(optimizer='adam', loss=loss, metrics=metrics)
        else:
            model.compile()

        Hexnet_print(f'({run_string}) Model summary')
        model.summary()
        print_newline()

        Hexnet_print(f'({run_string}) Training')

        if model_is_standalone:
            model.fit(train_data, train_labels, batch_size, epochs, tests_dir,
                      run_title)
        elif model_is_autoencoder:
            history = model.fit(train_data,
                                train_data,
                                batch_size,
                                epochs,
                                validation_split=validation_split)
        else:
            history = model.fit(train_data,
                                train_labels,
                                batch_size,
                                epochs,
                                validation_split=validation_split)

        print_newline()

        ########################################################################
        # Visualize filters, feature maps, and training results
        ########################################################################

        if not model_is_standalone:
            if visualize_model is not None:
                Hexnet_print(f'({run_string}) Visualization')

                visualization.visualize_model(model,
                                              test_classes,
                                              test_data,
                                              test_labels,
                                              output_dir=visualize_model,
                                              max_images_per_class=10,
                                              verbosity_level=verbosity_level)

                print_newline()

            Hexnet_print(f'({run_string}) History')
            Hexnet_print(
                f'({run_string}) history.history.keys()={history.history.keys()}'
            )

            if tests_dir is not None or show_results:
                visualization.visualize_results(history, run_title, tests_dir,
                                                show_results)

            print_newline()

        ########################################################################
        # Evaluate the model and save test results
        ########################################################################

        Hexnet_print(f'({run_string}) Test')

        if model_is_standalone:
            model.evaluate(test_data,
                           test_labels,
                           batch_size,
                           epochs=10,
                           tests_dir=tests_dir,
                           run_title=run_title)
        elif model_is_autoencoder:
            test_loss = model.evaluate(test_data, test_data)
        else:
            test_loss, test_acc = model.evaluate(test_data, test_labels)

        if not model_is_standalone:
            predictions = model.predict(test_data)

            if not model_is_autoencoder:
                predictions_classes = predictions.argmax(axis=-1)
                Hexnet_print(
                    f'({run_string}) test_acc={test_acc:.8f}, test_loss={test_loss:.8f}'
                )
            else:
                Hexnet_print(f'({run_string}) test_loss={test_loss:.8f}')

            if tests_dir is not None:
                run_title_predictions = f'{run_title}_predictions'
                tests_dir_predictions = os.path.join(tests_dir,
                                                     run_title_predictions)

                if not model_is_autoencoder:
                    with open(f'{tests_dir_predictions}.csv',
                              'w') as predictions_file:
                        print(
                            'label_orig,filename,label,prediction_class,prediction',
                            file=predictions_file)

                        for label_orig, filename, label, prediction_class, prediction in zip(
                                test_labels_orig, test_filenames, test_labels,
                                predictions_classes, predictions):
                            prediction = [
                                float(format(class_confidence, '.8f'))
                                for class_confidence in prediction
                            ]
                            print(
                                f'{label_orig.decode()},{filename.decode()},{label},{prediction_class},{prediction}',
                                file=predictions_file)
                else:
                    os.makedirs(tests_dir_predictions, exist_ok=True)

                    for image_counter, (image, label) in enumerate(
                            zip(predictions, test_labels)):
                        image_filename = f'label{label}_image{image_counter}.png'
                        imsave(
                            os.path.join(tests_dir_predictions,
                                         image_filename), image)

        ########################################################################
        # Save the model
        ########################################################################

        if not model_is_standalone and tests_dir is not None:
            if save_model:
                model.save(os.path.join(tests_dir, f'{run_title}_model.h5'))

            if save_weights:
                model.save_weights(
                    os.path.join(tests_dir, f'{run_title}_weights.h5'))

        if run < runs:
            print_newline()
            print_newline()

    return 0
Example #4
        help='outer radius of the hexagonal pixels for the square to '
        'hexagonal image transformation')
    parser.add_argument(
        '--transform-width',
        type=int,
        default=transform_width,
        help='square to square image transformation output width')
    parser.add_argument(
        '--transform-height',
        type=int,
        default=transform_height,
        help='square to square image transformation output height')

    return parser.parse_args(args, namespace)


################################################################################
# main
################################################################################

if __name__ == '__main__':
    args = parse_args()

    Hexnet_print(f'args={args}')
    print_newline()

    status = run(args)

    sys.exit(status)
Example #5
def run(args):

    ############################################################################
    # Parameters
    ############################################################################

    disable_training = args.disable_training
    disable_testing = args.disable_testing
    disable_output = args.disable_output
    enable_tensorboard = args.enable_tensorboard

    model_string = args.model
    load_model = args.load_model
    load_weights = args.load_weights
    save_model = args.save_model
    save_weights = args.save_weights

    dataset = args.dataset
    create_dataset = args.create_dataset
    disable_rand = args.disable_rand
    create_h5 = args.create_h5
    resize_dataset = args.resize_dataset
    crop_dataset = args.crop_dataset
    pad_dataset = args.pad_dataset
    augment_dataset = args.augment_dataset
    augmenter_string = args.augmenter
    augmentation_level = args.augmentation_level
    augmentation_size = args.augmentation_size

    chunk_size = args.chunk_size

    output_dir = args.output_dir
    show_dataset = args.show_dataset
    visualize_dataset = args.visualize_dataset
    visualize_model = args.visualize_model
    visualize_hexagonal = args.visualize_hexagonal
    show_results = args.show_results

    optimizer = args.optimizer
    metrics = args.metrics
    batch_size = args.batch_size
    epochs = args.epochs
    loss_string = args.loss
    runs = args.runs
    validation_split = args.validation_split

    cnn_kernel_size = args.cnn_kernel_size
    cnn_pool_size = args.cnn_pool_size
    resnet_stacks = args.resnet_stacks
    resnet_n = args.resnet_n
    resnet_filter_size = args.resnet_filter_size

    verbosity_level = args.verbosity_level

    transform_s2h = args.transform_s2h
    transform_h2s = args.transform_h2s
    transform_h2h = args.transform_h2h
    transform_s2s = args.transform_s2s
    transform_s2h_rad_o = args.transform_s2h_rad_o
    transform_h2s_len = args.transform_h2s_len
    transform_h2h_rad_o = args.transform_h2h_rad_o
    transform_s2s_res = args.transform_s2s_res

    ############################################################################
    # Initialization
    ############################################################################

    if model_string:
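        # Derive the model's properties from substrings of its name
        # ('custom', 'standalone', 'autoencoder', 'CNN', 'GAN', 'ResNet',
        # 'keras', 'sklearn')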
        model_is_provided = True

        model_is_custom = 'custom' in model_string
        model_is_standalone = 'standalone' in model_string

        model_is_autoencoder = 'autoencoder' in model_string
        model_is_cnn = 'CNN' in model_string
        model_is_gan = 'GAN' in model_string
        model_is_resnet = 'ResNet' in model_string

        model_is_from_keras = 'keras' in model_string
        model_is_from_sklearn = 'sklearn' in model_string
    else:
        model_is_provided = False

    if augmenter_string:
        augmenter_is_provided = True

        augmenter_is_custom = 'custom' in augmenter_string
    else:
        augmenter_is_provided = False

    if metrics != 'auto':
        metrics_are_provided = True
    else:
        metrics_are_provided = False

    if loss_string != 'auto':
        loss_is_provided = True

        subpixel_loss_identifiers = ('s2s', 's2h')
        loss_is_subpixel_loss = any(
            identifier in loss_string
            for identifier in subpixel_loss_identifiers)

        loss_is_from_keras = loss_string.startswith('keras_')
    else:
        loss_is_provided = False

    if dataset:
        dataset = dataset.rstrip('/')

        dataset_is_provided = True

        if create_dataset: create_dataset = ast.literal_eval(create_dataset)
    else:
        dataset_is_provided = False

    if disable_output:
        output_dir = None
    elif not output_dir:
        disable_output = True

    disable_training |= epochs < 1
    enable_training = not disable_training
    enable_testing = not disable_testing
    enable_output = not disable_output
    disable_tensorboard = not enable_tensorboard
    enable_rand = not disable_rand

    train_classes = []
    train_classes_orig = []
    train_data = []
    train_filenames = []
    train_labels = []
    train_labels_orig = []
    test_classes = []
    test_classes_orig = []
    test_data = []
    test_filenames = []
    test_labels = []
    test_labels_orig = []

    classification_reports = []

    if enable_tensorboard and enable_output:
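        # Log training to TensorBoard in the output directory when requested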
        fit_callbacks = [
            tf.keras.callbacks.TensorBoard(
                log_dir=os.path.normpath(output_dir), histogram_freq=1)
        ]
    else:
        fit_callbacks = None

    ############################################################################
    # No dataset provided - returning
    ############################################################################

    if not dataset_is_provided:
        print_newline()
        Hexnet_print('No dataset provided - returning')

        return 0

    ############################################################################
    # Create classification dataset
    ############################################################################

    if create_dataset:
        print_newline()

        datasets.create_dataset(dataset,
                                split_ratios=create_dataset,
                                randomized_assignment=enable_rand,
                                verbosity_level=verbosity_level)

        dataset = f'{dataset}_classification_dataset'

    ############################################################################
    # Transform the dataset
    ############################################################################

    if any(operation != False for operation in (transform_s2h, transform_h2s,
                                                transform_h2h, transform_s2s)):
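        # Each transform_* argument is either False (disabled), None (enabled
        # with a default output directory derived from the dataset name), or an
        # explicit output directory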
        print_newline()
        Hexnet_init()
        print_newline()

        if transform_s2h != False:
            if transform_s2h is None:
                transform_s2h = f'{dataset}_s2h'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2h,
                                       mode='s2h',
                                       rad_o=transform_s2h_rad_o,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_h2s != False:
            if transform_h2s is None:
                transform_h2s = f'{dataset}_h2s'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_h2s,
                                       mode='h2s',
                                       len=transform_h2s_len,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_h2h != False:
            if transform_h2h is None:
                transform_h2h = f'{dataset}_h2h'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_h2h,
                                       mode='h2h',
                                       rad_o=transform_h2h_rad_o,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_s2s != False:
            if transform_s2s is None:
                transform_s2s = f'{dataset}_s2s'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2s,
                                       mode='s2s',
                                       res=transform_s2s_res,
                                       method=0,
                                       verbosity_level=verbosity_level)

    ############################################################################
    # Visualize the dataset
    ############################################################################

    if visualize_dataset and os.path.isfile(
            dataset) and dataset.lower().endswith('.csv'):
        print_newline()

        datasets.visualize_dataset(dataset)

    ############################################################################
    # Load the dataset
    ############################################################################

    print_newline()

    ((train_classes_orig, train_data, train_filenames, train_labels_orig),
     (test_classes_orig,  test_data,  test_filenames,  test_labels_orig)) = \
     datasets.load_dataset(dataset, create_h5, verbosity_level)

    print_newline()

    datasets.create_dataset_overview(train_classes_orig, train_labels_orig,
                                     test_labels_orig, dataset, output_dir)

    if type(train_data) is not np.ndarray:
        disable_training = True
        enable_training = not disable_training

    if type(test_data) is not np.ndarray:
        disable_testing = True
        enable_testing = not disable_testing

    ############################################################################
    # Prepare the dataset
    ############################################################################

    print_newline()
    Hexnet_print('Dataset preparation')

    class_labels_are_digits = True

    if enable_training:
        for class_label in train_classes_orig:
            if not class_label.isdigit():
                class_labels_are_digits = False
                break
    elif enable_testing:
        for class_label in test_classes_orig:
            if not class_label.isdigit():
                class_labels_are_digits = False
                break

    if enable_training:
        if class_labels_are_digits:
            train_labels = np.asarray(
                [int(label) for label in train_labels_orig.flatten()])
        else:
            train_labels = np.asarray([
                np.where(label == train_classes_orig)[0][0]
                for label in train_labels_orig.flatten()
            ])

        train_classes = np.unique(train_labels)
        train_classes_len = len(train_classes)
        train_labels = np.reshape(train_labels,
                                  newshape=train_labels_orig.shape)

        if class_labels_are_digits:
            train_classes_min = min(train_classes)

            train_classes -= train_classes_min
            train_labels -= train_classes_min

        if train_labels.ndim > 1:
            train_labels = array_to_one_hot_array(train_labels,
                                                  train_classes_len)

    if enable_testing:
        if class_labels_are_digits:
            test_labels = np.asarray(
                [int(label) for label in test_labels_orig.flatten()])
        else:
            test_labels = np.asarray([
                np.where(label == test_classes_orig)[0][0]
                for label in test_labels_orig.flatten()
            ])

        test_classes = np.unique(test_labels)
        test_classes_len = len(test_classes)
        test_labels = np.reshape(test_labels, newshape=test_labels_orig.shape)

        if class_labels_are_digits:
            test_classes_min = min(test_classes)

            test_classes -= test_classes_min
            test_labels -= test_classes_min

        if test_labels.ndim > 1:
            test_labels = array_to_one_hot_array(test_labels, test_classes_len)

    ############################################################################
    # Preprocess the dataset
    ############################################################################

    if any(operation
           for operation in (resize_dataset, crop_dataset, pad_dataset)):
        Hexnet_print('Dataset preprocessing')

        if resize_dataset:
            (train_data,
             test_data) = datasets.resize_dataset(dataset_s=(train_data,
                                                             test_data),
                                                  resize_string=resize_dataset)

        if crop_dataset:
            (train_data,
             test_data) = datasets.crop_dataset(dataset_s=(train_data,
                                                           test_data),
                                                crop_string=crop_dataset)

        if pad_dataset:
            (train_data,
             test_data) = datasets.pad_dataset(dataset_s=(train_data,
                                                          test_data),
                                               pad_string=pad_dataset)

    ############################################################################
    # Augment the dataset
    ############################################################################

    if augment_dataset:
        Hexnet_print('Dataset augmentation')

        if augmentation_size != 1:
            if 'train' in augment_dataset:
                train_data, train_filenames, train_labels, train_labels_orig = \
                 augmenters.augment_size(train_data, train_filenames, train_labels, train_labels_orig, augmentation_size)

            if 'test' in augment_dataset:
                test_data, test_filenames, test_labels, test_labels_orig = \
                 augmenters.augment_size(test_data, test_filenames, test_labels, test_labels_orig, augmentation_size)

        if augmenter_is_provided:
            augmenter = vars(augmenters)[f'augmenter_{augmenter_string}']

            if not augmenter_is_custom:
                augmenter = augmenter(augmentation_level)
            else:
                augmenter = augmenter()

            if 'train' in augment_dataset:
                train_data = augmenter(images=train_data)

            if 'test' in augment_dataset:
                test_data = augmenter(images=test_data)

    ############################################################################
    # Show the dataset
    ############################################################################

    if show_dataset:
        datasets.show_dataset(train_classes_orig,
                              train_data,
                              train_labels_orig,
                              test_classes_orig,
                              test_data,
                              test_labels_orig,
                              max_images_per_class=1,
                              max_classes_to_display=10)

    ############################################################################
    # Visualize the dataset
    ############################################################################

    if visualize_dataset:
        print_newline()

        datasets.visualize_dataset(dataset, train_classes_orig, train_data,
                                   train_filenames, train_labels_orig,
                                   test_classes_orig, test_data,
                                   test_filenames, test_labels_orig,
                                   visualize_hexagonal, create_h5,
                                   verbosity_level)

    ############################################################################
    # No model provided - returning
    ############################################################################

    if not model_is_provided:
        print_newline()
        Hexnet_print('No model provided - returning')

        return 0

    ############################################################################
    # Standardize / normalize the dataset
    ############################################################################

    if not model_is_gan:
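        # Standardize the data chunk-wise to limit memory usage: subtract the
        # mean and divide by the standard deviation computed over each image's
        # spatial dimensions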
        if visualize_dataset: print_newline()
        Hexnet_print('Dataset standardization')

        std_eps = np.finfo(np.float32).eps

        if enable_training:
            if train_data.ndim > 3:
                mean_axis = (1, 2)
            else:
                mean_axis = 1

            for chunk_start in tqdm(range(0, train_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, train_data.shape[0])

                train_data_mean = np.mean(train_data[chunk_start:chunk_end],
                                          axis=mean_axis,
                                          keepdims=True)

                train_data_std = np.sqrt(
                    ((train_data[chunk_start:chunk_end] -
                      train_data_mean)**2).mean(axis=mean_axis, keepdims=True))
                train_data_std[train_data_std == 0] = std_eps

                train_data[chunk_start:chunk_end] = (
                    train_data[chunk_start:chunk_end] -
                    train_data_mean) / train_data_std

        if enable_testing:
            if test_data.ndim > 3:
                mean_axis = (1, 2)
            else:
                mean_axis = 1

            for chunk_start in tqdm(range(0, test_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, test_data.shape[0])

                test_data_mean = np.mean(test_data[chunk_start:chunk_end],
                                         axis=mean_axis,
                                         keepdims=True)

                test_data_std = np.sqrt(
                    ((test_data[chunk_start:chunk_end] -
                      test_data_mean)**2).mean(axis=mean_axis, keepdims=True))
                test_data_std[test_data_std == 0] = std_eps

                test_data[chunk_start:chunk_end] = (
                    test_data[chunk_start:chunk_end] -
                    test_data_mean) / test_data_std
    else:
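        # For GANs, min-max normalize the data to [0, 1] using the global
        # minimum and maximum over training and test data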
        if visualize_dataset: print_newline()
        Hexnet_print('Dataset normalization')

        data_min = min(train_data.min(), test_data.min())
        data_max = max(train_data.max(), test_data.max())
        normalization_factor = data_max - data_min

        if enable_training:
            for chunk_start in tqdm(range(0, train_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, train_data.shape[0])

                train_data[chunk_start:chunk_end] = (
                    train_data[chunk_start:chunk_end] -
                    data_min) / normalization_factor

        if enable_testing:
            for chunk_start in tqdm(range(0, test_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, test_data.shape[0])

                test_data[chunk_start:chunk_end] = (
                    test_data[chunk_start:chunk_end] -
                    data_min) / normalization_factor

    ############################################################################
    # Shuffle the dataset
    ############################################################################

    if enable_training:
        Hexnet_print('Dataset shuffling')

        (train_data,
         train_labels) = sklearn.utils.shuffle(train_data, train_labels)

    ############################################################################
    # Start a new training and test run
    ############################################################################

    dataset = os.path.basename(dataset)
    tests_title = f'{model_string}__{dataset}'

    if not model_is_from_sklearn:
        tests_title = f'{tests_title}__epochs{epochs}-bs{batch_size}'

    print_newline()
    print_newline()

    for run in range(1, runs + 1):

        ########################################################################
        # Current run information
        ########################################################################

        run_string = f'run={run}/{runs}'
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        run_title = f'{tests_title}__{timestamp}_run{run}'

        ########################################################################
        # Initialize the model
        ########################################################################

        Hexnet_print(f'({run_string}) Model initialization')

        if enable_training:
            input_shape = train_data.shape[1:4]
            classes = train_classes_len
        elif enable_testing:
            input_shape = test_data.shape[1:4]
            classes = test_classes_len

        if not load_model:
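            # Resolve the model constructor by name from the models module
            # (models.model_<model_string>) and build it with the arguments its
            # model type expects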
            model = vars(models)[f'model_{model_string}']

            if model_is_custom or model_is_standalone:
                model = model(input_shape, classes)
            elif model_is_autoencoder:
                model = model(input_shape)
            elif model_is_cnn:
                model = model(input_shape, classes, cnn_kernel_size,
                              cnn_pool_size)
            elif model_is_resnet and not model_is_from_keras:
                model = model(input_shape, classes, resnet_stacks, resnet_n,
                              resnet_filter_size)
            elif model_is_from_sklearn:
                model = model()
            else:
                model = model(input_shape, classes)
        elif not (model_is_standalone or model_is_from_sklearn):
            model = tf.keras.models.load_model(load_model)

        if load_weights and not (model_is_standalone or model_is_from_sklearn):
            model.load_weights(load_weights)

        ########################################################################
        # Initialize loss and metrics
        ########################################################################

        if not (model_is_standalone or model_is_from_sklearn):
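            # Defaults: an accuracy metric and sparse categorical cross-entropy
            # for classifiers, MSE for autoencoders; named losses are resolved
            # from the losses module or, with the 'keras_' prefix, from
            # tf.keras.losses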
            Hexnet_print(f'({run_string}) Loss and metrics initialization')

            if not metrics_are_provided:
                if not model_is_autoencoder:
                    metrics = ['accuracy']
                else:
                    metrics = []

            if not loss_is_provided:
                if not model_is_autoencoder:
                    loss = 'SparseCategoricalCrossentropy'
                else:
                    loss = 'MeanSquaredError'

                loss = tf.losses.get(loss)
            else:
                if not loss_is_subpixel_loss:
                    if not loss_is_from_keras:
                        loss = vars(losses)[f'loss_{loss_string}']()
                    else:
                        loss = vars(
                            tf.keras.losses)[loss_string[len('keras_'):]]()
                else:
                    output_shape = test_data.shape[1:4]
                    loss = vars(losses)[f'loss_{loss_string}'](input_shape,
                                                               output_shape)

        ########################################################################
        # Compile the model
        ########################################################################

        if not model_is_from_sklearn:
            Hexnet_print(f'({run_string}) Model compilation')

            if not model_is_standalone:
                model.compile(optimizer, loss, metrics)
            else:
                model.compile()

        ########################################################################
        # Model summary
        ########################################################################

        print_newline()
        Hexnet_print(f'({run_string}) Model summary')

        if not model_is_from_sklearn:
            model.summary()
        else:
            Hexnet_print(model.get_params())

        ########################################################################
        # Train the model
        ########################################################################

        if enable_training:
            print_newline()
            Hexnet_print(f'({run_string}) Model training')

            if model_is_standalone:
                model.fit(train_data, train_labels, batch_size, epochs,
                          visualize_hexagonal, output_dir, run_title)
            elif model_is_autoencoder:
                history = model.fit(train_data,
                                    train_data,
                                    batch_size,
                                    epochs,
                                    validation_split=validation_split)
            elif model_is_from_sklearn:
                model.fit(
                    np.reshape(train_data, newshape=(train_data.shape[0], -1)),
                    train_labels)
            else:
                history = model.fit(train_data,
                                    train_labels,
                                    batch_size,
                                    epochs,
                                    validation_split=validation_split,
                                    callbacks=fit_callbacks)

        ########################################################################
        # Visualize filters, feature maps, activations, and training results
        ########################################################################

        if not (model_is_standalone or model_is_from_sklearn):
            if visualize_model and enable_output:
                print_newline()
                Hexnet_print(f'({run_string}) Visualization')

                output_dir_visualizations = os.path.join(
                    output_dir, f'{run_title}_visualizations')

                visualization.visualize_model(
                    model,
                    test_classes,
                    test_data,
                    test_labels,
                    visualize_hexagonal,
                    output_dir=output_dir_visualizations,
                    max_images_per_class=10,
                    verbosity_level=verbosity_level)

            if enable_training:
                print_newline()
                Hexnet_print(f'({run_string}) History')
                Hexnet_print(
                    f'({run_string}) history.history.keys()={history.history.keys()}'
                )

                if enable_output or show_results:
                    visualization.visualize_training_results(
                        history, run_title, output_dir, show_results)

        ########################################################################
        # Evaluate the model
        ########################################################################

        if enable_testing and not model_is_from_sklearn:
            print_newline()
            Hexnet_print(f'({run_string}) Model evaluation')

            if model_is_standalone:
                model.evaluate(test_data,
                               test_labels,
                               batch_size,
                               epochs=10,
                               visualize_hexagonal=visualize_hexagonal,
                               output_dir=output_dir,
                               run_title=run_title)
            elif model_is_autoencoder:
                test_loss_metrics = model.evaluate(test_data, test_data)
            else:
                test_loss_metrics = model.evaluate(test_data, test_labels)

            if not model_is_standalone:
                Hexnet_print(
                    f'({run_string}) test_loss_metrics={test_loss_metrics}')

        ########################################################################
        # Save test results
        ########################################################################

        if enable_testing and enable_output and not model_is_standalone:
            print_newline()
            Hexnet_print(f'({run_string}) Saving test results')

            if not model_is_from_sklearn:
                predictions = model.predict(test_data)
            else:
                predictions = model.predict_proba(
                    np.reshape(test_data, newshape=(test_data.shape[0], -1)))

            if not model_is_autoencoder:
                classification_report = visualization.visualize_test_results(
                    predictions, test_classes, test_classes_orig,
                    test_filenames, test_labels, test_labels_orig, run_title,
                    output_dir)

                print_newline()
                Hexnet_print(f'({run_string}) Classification report')
                pprint(classification_report)

                classification_reports.append(classification_report)
            else:
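                # For autoencoders, write each test sample's reconstruction
                # loss to a CSV file and save the reconstructed images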
                loss_newshape = (test_data.shape[0], -1)
                test_losses = loss(
                    np.reshape(test_data, newshape=loss_newshape),
                    np.reshape(predictions, newshape=loss_newshape))

                output_dir_predictions = os.path.join(
                    output_dir, f'{run_title}_predictions')
                os.makedirs(output_dir_predictions, exist_ok=True)

                with open(f'{output_dir_predictions}.csv',
                          'w') as predictions_file:
                    print('label_orig,filename,label,loss',
                          file=predictions_file)

                    for label_orig, filename, label, loss in zip(
                            test_labels_orig, test_filenames, test_labels,
                            test_losses):
                        loss = float(format(loss, '.8f'))
                        print(f'{label_orig},{filename},{label},{loss}',
                              file=predictions_file)

                for image_counter, (image, label) in enumerate(
                        zip(tqdm(predictions), test_labels)):
                    image_filename = f'label{label}_image{image_counter}.png'

                    if not visualize_hexagonal:
                        imsave(
                            os.path.join(output_dir_predictions,
                                         image_filename), image)
                    else:
                        visualization.visualize_hexarray(
                            image,
                            os.path.join(output_dir_predictions,
                                         image_filename))

        ########################################################################
        # Save the model
        ########################################################################

        if (save_model or save_weights) and enable_output and not (
                model_is_standalone or model_is_from_sklearn):
            print_newline()
            Hexnet_print(f'({run_string}) Saving the model')

            if save_model:
                model.save(os.path.join(output_dir, f'{run_title}_model.h5'))

            if save_weights:
                model.save_weights(
                    os.path.join(output_dir, f'{run_title}_weights.h5'))

        if run < runs:
            print_newline()
            print_newline()

    ############################################################################
    # Save global test results
    ############################################################################

    if enable_testing and enable_output and runs > 1 and not (
            model_is_standalone or model_is_autoencoder):
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        tests_title = f'{tests_title}__{timestamp}'

        visualization.visualize_global_test_results(classification_reports,
                                                    tests_title, output_dir)

    return 0