Example #1
    def summary(self):
        Hexnet_print('Generator')
        self.generator_for_summary.summary()

        print_newline()
        Hexnet_print('Discriminator')
        self.discriminator_for_summary.summary()
Example #2
def transform_dataset(dataset,
                      output_dir,
                      mode='s2h',
                      rad_o=1.0,
                      width=64,
                      height=None,
                      method=0,
                      verbosity_level=2):

    if os.path.exists(output_dir):
        Hexnet_print(
            f'Dataset {output_dir} exists already (skipping transformation)')
        return

    if os.path.isfile(f'{output_dir}.h5'):
        Hexnet_print(
            f'Dataset {output_dir}.h5 exists already (skipping transformation)'
        )
        return

    Hexnet_print(
        f'Transforming dataset {dataset} with mode {mode} to {output_dir}')

    os.makedirs(output_dir, exist_ok=True)

    for dataset_set in natsorted(glob(os.path.join(dataset, '*'))):
        current_set = os.path.basename(dataset_set)

        if verbosity_level >= 1:
            print(f'\t> current_set={current_set}')

        output_dir_current_set = os.path.join(output_dir, current_set)
        os.makedirs(output_dir_current_set, exist_ok=True)

        for set_class in natsorted(glob(os.path.join(dataset_set, '*'))):
            current_class = os.path.basename(set_class)

            if verbosity_level >= 2:
                print(f'\t\t> current_class={current_class}')

            output_dir_current_class = os.path.join(output_dir_current_set,
                                                    current_class)
            os.makedirs(output_dir_current_class, exist_ok=True)

            if mode == 's2h':
                Hexsamp_s2h(
                    filename_s=os.path.join(set_class, '*'),
                    output_dir=output_dir_current_class,
                    rad_o=rad_o,
                    method=method,
                    increase_verbosity=verbosity_level >= 3)
            else:
                Sqsamp_s2s(
                    filename_s=os.path.join(set_class, '*'),
                    output_dir=output_dir_current_class,
                    width=width,
                    height=height,
                    method=method,
                    increase_verbosity=verbosity_level >= 3)
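
A minimal usage sketch for the function above (the paths are hypothetical, and the import assumes the function lives in Hexnet's datasets module, as in the later examples):

from datasets import transform_dataset  # assumption: Hexnet's datasets module

# Square-to-hexagonal resampling of an image-folder dataset; the call is skipped
# automatically if the output directory (or a matching .h5 file) already exists.
transform_dataset(dataset='datasets/MNIST',         # hypothetical input directory
                  output_dir='datasets/MNIST_s2h',  # hypothetical output directory
                  mode='s2h',
                  rad_o=1.0,
                  verbosity_level=2)
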
Example #3
def visualize_feature_maps(
	model,
	test_classes,
	test_data,
	test_labels,
	visualize_hexagonal,
	output_dir,
	max_images_per_class = 10,
	verbosity_level      =  2):

	feature_map_to_visualize = 0

	os.makedirs(output_dir, exist_ok=True)

	test_data_for_prediction   = []
	test_labels_for_prediction = []

	class_counter_dict = dict.fromkeys(test_classes, 0)

	for image, label in zip(test_data, test_labels):
		if class_counter_dict[label] < max_images_per_class:
			test_data_for_prediction.append(image)
			test_labels_for_prediction.append(label)
			class_counter_dict[label] += 1

	test_data_for_prediction = np.asarray(test_data_for_prediction)

	layers_outputs = [layer.output for layer in model.layers]

	if verbosity_level >= 3:
		Hexnet_print(f'(visualize_feature_maps) layers_outputs={layers_outputs}')

	model_outputs = Model(inputs=model.input, outputs=layers_outputs)

	if verbosity_level >= 3:
		Hexnet_print(f'(visualize_feature_maps) model_outputs={model_outputs}')

	predictions = model_outputs.predict(test_data_for_prediction)

	if verbosity_level >= 3:
		Hexnet_print(f'(visualize_feature_maps) predictions={predictions}')

	for layer_counter, (layer, feature_maps) in enumerate(zip(model.layers, predictions)):
		if feature_maps.ndim != 4:
			continue

		if verbosity_level >= 2:
			Hexnet_print(f'(visualize_feature_maps) layer={layer} (layer_counter={layer_counter}, layer.name={layer.name}): feature_maps.shape={feature_maps.shape}')

		for feature_map_counter, (feature_map, label) in enumerate(zip(tqdm(feature_maps), test_labels_for_prediction)):
			title = f'layer{str(layer_counter).zfill(3)}_{layer.name}_label{str(label).zfill(3)}_image{str(feature_map_counter).zfill(3)}_featuremap{feature_map_to_visualize}'
			feature_map_filename = os.path.join(output_dir, title)
			feature_map          = normalize_array(feature_map[:, :, feature_map_to_visualize])

			imsave(f'{feature_map_filename}.png', feature_map, cmap='viridis')

			if visualize_hexagonal:
				visualize_hexarray(feature_map, feature_map_filename, colormap='viridis')
Example #4
def visualize_feature_maps(model,
                           test_classes,
                           test_data,
                           test_labels,
                           output_dir,
                           max_images_per_class=10,
                           verbosity_level=2):

    feature_map_to_visualize = 0

    os.makedirs(output_dir, exist_ok=True)

    layers_outputs = [layer.output for layer in model.layers]

    if verbosity_level >= 3:
        Hexnet_print(
            f'(visualize_feature_maps) layers_outputs={layers_outputs}')

    model_outputs = Model(inputs=model.input, outputs=layers_outputs)

    if verbosity_level >= 3:
        Hexnet_print(f'(visualize_feature_maps) model_outputs={model_outputs}')

    predictions = model_outputs.predict(test_data)

    if verbosity_level >= 3:
        Hexnet_print(f'(visualize_feature_maps) predictions={predictions}')

    for layer_counter, (layer, feature_maps) in enumerate(
            zip(model.layers, predictions)):
        if feature_maps.ndim != 4:
            continue

        if verbosity_level >= 2:
            Hexnet_print(
                f'(visualize_feature_maps) layer={layer} (layer_counter={layer_counter}, layer.name={layer.name}): feature_maps.shape={feature_maps.shape}'
            )

        class_counter_dict = dict.fromkeys(test_classes, 0)

        for feature_map_counter, (feature_map, label) in enumerate(
                zip(feature_maps, test_labels)):
            if class_counter_dict[label] < max_images_per_class:
                feature_map_filename = os.path.join(
                    output_dir,
                    f'layer{str(layer_counter).zfill(3)}_{layer.name}_fm{feature_map_to_visualize}_label{label}_image{feature_map_counter}.png'
                )
                feature_map = feature_map[:, :, feature_map_to_visualize]
                imsave(feature_map_filename, feature_map, cmap='viridis')
                class_counter_dict[label] += 1
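
A hedged, self-contained sketch of calling the function above on a toy model (the model, data, and output directory are hypothetical; the import path is an assumption about Hexnet's visualization module):

import numpy as np
import tensorflow as tf

from visualization import visualize_feature_maps  # assumption: Hexnet's visualization module

# Hypothetical toy model; only layers with 4-dimensional outputs (e.g. Conv2D,
# MaxPooling2D) produce feature map images.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, padding='same')(inputs)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(2)(x)
model = tf.keras.Model(inputs, outputs)

test_data = np.random.rand(4, 32, 32, 3).astype(np.float32)  # hypothetical images
test_labels = np.array([0, 0, 1, 1])

visualize_feature_maps(model,
                       test_classes=[0, 1],
                       test_data=test_data,
                       test_labels=test_labels,
                       output_dir='feature_maps',  # hypothetical output directory
                       max_images_per_class=2,
                       verbosity_level=2)
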
Example #5
    def sample_images(self, epoch, tests_dir, run_title,
                      images_to_sample_per_class):
        Hexnet_print('Sampling images')

        r, c = images_to_sample_per_class, self.num_classes
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.array([num for _ in range(r) for num in range(c)])
        gen_imgs = self.generator.predict([noise, sampled_labels])
        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        tests_dir_samples = os.path.join(tests_dir, run_title)
        os.makedirs(tests_dir_samples, exist_ok=True)

        for image_counter, (image,
                            label) in enumerate(zip(gen_imgs, sampled_labels)):
            image_filename = f'epoch{epoch}_label{label}_image{image_counter}.png'
            imsave(os.path.join(tests_dir_samples, image_filename), image)
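
Note on the rescaling above: 0.5 * gen_imgs + 0.5 maps the generator's output from [-1, 1] to [0, 1] before saving, which assumes the generator's final activation is tanh.
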
Example #6
def visualize_filters(model, visualize_hexagonal, output_dir, verbosity_level=2):
	filter_to_visualize = 0

	os.makedirs(output_dir, exist_ok=True)

	for layer_counter, layer in enumerate(model.layers):
		if 'conv' not in layer.name:
			continue

		filters = layer.get_weights()[0]

		if verbosity_level >= 2:
			Hexnet_print(f'(visualize_filters) layer={layer} (layer_counter={layer_counter}, layer.name={layer.name}): filters.shape={filters.shape}')

		for channel in tqdm(range(filters.shape[2])):
			filter_filename = os.path.join(output_dir, f'layer{str(layer_counter).zfill(3)}_{layer.name}_channel{str(channel).zfill(3)}_filter{filter_to_visualize}')
			filter          = normalize_array(filters[:, :, channel, filter_to_visualize])

			imsave(f'{filter_filename}.png', filter, cmap='viridis')

			if visualize_hexagonal:
				visualize_hexarray(filter, filter_filename, colormap='viridis')
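
A hedged usage sketch for the function above (the toy model and output directory are hypothetical; the import path is an assumption about Hexnet's visualization module):

import tensorflow as tf

from visualization import visualize_filters  # assumption: Hexnet's visualization module

# Hypothetical model; only layers whose name contains 'conv' are visualized.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, name='conv2d_example')
])

visualize_filters(model,
                  visualize_hexagonal=False,
                  output_dir='filters',  # hypothetical output directory
                  verbosity_level=2)
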
Example #7
def load_dataset(dataset, create_h5=True, verbosity_level=2):
    Hexnet_print(f'Loading dataset {dataset}')

    train_classes = []
    train_data = []
    train_filenames = []
    train_labels = []
    test_classes = []
    test_data = []
    test_filenames = []
    test_labels = []

    if os.path.isfile(dataset) and dataset.endswith('.h5'):
        start_time = time()

        with h5py.File(dataset, 'r') as h5py_file:
            train_classes = np.array(h5py_file['train_classes'])
            train_data = np.array(h5py_file['train_data'])
            train_filenames = np.array(h5py_file['train_filenames'])
            train_labels = np.array(h5py_file['train_labels'])
            test_classes = np.array(h5py_file['test_classes'])
            test_data = np.array(h5py_file['test_data'])
            test_filenames = np.array(h5py_file['test_filenames'])
            test_labels = np.array(h5py_file['test_labels'])

        time_diff = time() - start_time

        Hexnet_print(f'Loaded dataset {dataset} in {time_diff:.3f} seconds')
    else:
        start_time = time()

        for dataset_set in natsorted(glob(os.path.join(dataset, '*'))):
            current_set = os.path.basename(dataset_set)

            if verbosity_level >= 1:
                Hexnet_print(f'\t> current_set={current_set}')

            for set_class in natsorted(glob(os.path.join(dataset_set, '*'))):
                current_class = os.path.basename(set_class)

                if verbosity_level >= 2:
                    Hexnet_print(f'\t\t> current_class={current_class}')

                if 'train' in current_set:
                    train_classes.append(current_class)
                elif 'test' in current_set:
                    test_classes.append(current_class)

                for class_image in natsorted(glob(os.path.join(set_class,
                                                               '*'))):
                    current_image = os.path.basename(class_image)

                    if verbosity_level >= 3:
                        Hexnet_print(f'\t\t\t> current_image={current_image}')

                    if 'train' in current_set:
                        train_data.append(
                            cv2.imread(class_image, cv2.IMREAD_COLOR))
                        train_filenames.append(current_image)
                        train_labels.append(current_class)
                    elif 'test' in current_set:
                        test_data.append(
                            cv2.imread(class_image, cv2.IMREAD_COLOR))
                        test_filenames.append(current_image)
                        test_labels.append(current_class)

        time_diff = time() - start_time

        Hexnet_print(f'Loaded dataset {dataset} in {time_diff:.3f} seconds')

        if create_h5:
            dataset = f'{dataset}.h5'
            train_classes = np.array(train_classes, dtype='string_')
            train_data = np.array(train_data)
            train_filenames = np.array(train_filenames, dtype='string_')
            train_labels = np.array(train_labels, dtype='string_')
            test_classes = np.array(test_classes, dtype='string_')
            test_data = np.array(test_data)
            test_filenames = np.array(test_filenames, dtype='string_')
            test_labels = np.array(test_labels, dtype='string_')

            create_dataset_h5(dataset, train_classes, train_data,
                              train_filenames, train_labels, test_classes,
                              test_data, test_filenames, test_labels)

    return ((train_classes, train_data, train_filenames, train_labels),
            (test_classes, test_data, test_filenames, test_labels))
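
A minimal usage sketch for the loader above (hypothetical dataset path; the import assumes the function lives in Hexnet's datasets module, which is how Example #8 calls it):

from datasets import load_dataset  # assumption: Hexnet's datasets module

# The dataset directory is expected to contain train*/test* subdirectories of
# per-class image folders; the two nested tuples unpack as follows.
((train_classes, train_data, train_filenames, train_labels),
 (test_classes, test_data, test_filenames, test_labels)) = load_dataset(
    dataset='datasets/MNIST', create_h5=True, verbosity_level=1)
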
Example #8
def run(args):

    model_string = args.model
    load_model = args.load_model
    load_weights = args.load_weights
    save_model = args.save_model
    save_weights = args.save_weights

    dataset = args.dataset
    resize_dataset = args.resize_dataset
    crop_dataset = args.crop_dataset
    augment_dataset = args.augment_dataset
    augmenter_string = args.augmenter
    augmentation_level = args.augmentation_level

    tests_dir = args.tests_dir
    show_dataset = args.show_dataset
    visualize_model = args.visualize_model
    show_results = args.show_results

    batch_size = args.batch_size
    epochs = args.epochs
    loss_string = args.loss
    runs = args.runs
    validation_split = args.validation_split

    cnn_kernel_size = args.cnn_kernel_size
    cnn_pool_size = args.cnn_pool_size

    verbosity_level = args.verbosity_level

    transform_s2h = args.transform_s2h
    transform_s2s = args.transform_s2s
    transform_rad_o = args.transform_rad_o
    transform_width = args.transform_width
    transform_height = args.transform_height

    if model_string is not None:
        model_is_provided = True

        model_is_custom = 'custom' in model_string
        model_is_standalone = 'standalone' in model_string

        model_is_autoencoder = 'autoencoder' in model_string
        model_is_CNN = 'CNN' in model_string
        model_is_GAN = 'GAN' in model_string
    else:
        model_is_provided = False

    if augmenter_string is not None:
        augmenter_is_custom = 'custom' in augmenter_string

    if loss_string is not None:
        loss_is_provided = True

        loss_is_subpixel_loss = 's2s' in loss_string or 's2h' in loss_string
    else:
        loss_is_provided = False

    train_classes = []
    train_data = []
    train_filenames = []
    train_labels = []
    train_labels_orig = []
    test_classes = []
    test_data = []
    test_filenames = []
    test_labels = []
    test_labels_orig = []

    ############################################################################
    # Transform the dataset
    ############################################################################

    if transform_s2h or transform_s2h is None or transform_s2s or transform_s2s is None:
        Hexnet_init()

        Hexnet_print('Dataset transformation')

        if transform_s2h or transform_s2h is None:
            if transform_s2h is None:
                transform_s2h = f'{dataset}_s2h'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2h,
                                       mode='s2h',
                                       rad_o=transform_rad_o,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_s2s or transform_s2s is None:
            if transform_s2s is None:
                transform_s2s = f'{dataset}_s2s'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2s,
                                       mode='s2s',
                                       width=transform_width,
                                       height=transform_height,
                                       method=0,
                                       verbosity_level=verbosity_level)

    ############################################################################
    # Load the dataset
    ############################################################################

    ((train_classes, train_data, train_filenames, train_labels_orig),
     (test_classes, test_data, test_filenames,
      test_labels_orig)) = datasets.load_dataset(
          dataset=dataset, create_h5=True, verbosity_level=verbosity_level)

    ############################################################################
    # Resize and crop the dataset
    ############################################################################

    if resize_dataset is not None:
        (train_data,
         test_data) = datasets.resize_dataset(dataset_s=(train_data,
                                                         test_data),
                                              resize_string=resize_dataset)

    if crop_dataset is not None:
        (train_data,
         test_data) = datasets.crop_dataset(dataset_s=(train_data, test_data),
                                            crop_string=crop_dataset)

    # TODO
    if model_is_provided and (model_is_autoencoder or model_is_GAN):
        if model_is_autoencoder:
            min_size_factor = 2**5
        else:
            min_size_factor = 2**4

        if train_data.shape[1] % min_size_factor:
            padding_h = min_size_factor - train_data.shape[1] % min_size_factor
            padding_h = (int(padding_h / 2) + padding_h % 2,
                         int(padding_h / 2))
        else:
            padding_h = (0, 0)

        if train_data.shape[2] % min_size_factor:
            padding_w = min_size_factor - train_data.shape[2] % min_size_factor
            padding_w = (int(padding_w / 2) + padding_w % 2,
                         int(padding_w / 2))
        else:
            padding_w = (0, 0)

        pad_width = ((0, 0), padding_h, padding_w, (0, 0))

        train_data = np.pad(train_data,
                            pad_width,
                            mode='constant',
                            constant_values=0)
        test_data = np.pad(test_data,
                           pad_width,
                           mode='constant',
                           constant_values=0)

    ############################################################################
    # Prepare the dataset
    ############################################################################

    class_labels_are_digits = True

    for class_label in train_classes:
        if not class_label.decode().isdigit():
            class_labels_are_digits = False
            break

    if class_labels_are_digits:
        train_labels = np.asarray(
            [int(label.decode()) for label in train_labels_orig])
        test_labels = np.asarray(
            [int(label.decode()) for label in test_labels_orig])
    else:
        train_labels = np.asarray([
            int(np.where(train_classes == label)[0])
            for label in train_labels_orig
        ])
        test_labels = np.asarray([
            int(np.where(test_classes == label)[0])
            for label in test_labels_orig
        ])

    # Use arrays (not lists) so the label offset subtraction below works elementwise
    train_classes = np.asarray(sorted(set(train_labels)))
    test_classes = np.asarray(sorted(set(test_labels)))

    if class_labels_are_digits:
        train_labels -= min(train_classes)
        test_labels -= min(test_classes)
        train_classes -= min(train_classes)
        test_classes -= min(test_classes)

    train_test_data_n = 255
    train_test_data_n /= 2
    train_data = (train_data - train_test_data_n) / train_test_data_n
    test_data = (test_data - train_test_data_n) / train_test_data_n

    ############################################################################
    # Augment the dataset
    ############################################################################

    print_newline()

    if augment_dataset is not None:
        Hexnet_print('Dataset augmentation')

        augmenter = vars(augmenters)[f'augmenter_{augmenter_string}']

        if augmenter_is_custom:
            augmenter = augmenter()
        else:
            augmenter = augmenter(augmentation_level)

        if 'train' in augment_dataset:
            train_data = augmenter(images=train_data)

        if 'test' in augment_dataset:
            test_data = augmenter(images=test_data)

        print_newline()

    print_newline()

    ############################################################################
    # Show the dataset
    ############################################################################

    if show_dataset:
        train_data_for_visualization = np.clip(train_data + 0.5, 0, 1)
        test_data_for_visualization = np.clip(test_data + 0.5, 0, 1)

        datasets.show_dataset_classes(train_classes,
                                      train_data_for_visualization,
                                      train_labels,
                                      test_classes,
                                      test_data_for_visualization,
                                      test_labels,
                                      max_images_per_class=1,
                                      max_classes_to_display=10)

    ############################################################################
    # No model was provided - returning
    ############################################################################

    if not model_is_provided:
        Hexnet_print('No model provided.')
        return 0

    ############################################################################
    # Shuffle the dataset
    ############################################################################

    train_data, train_labels = sklearn.utils.shuffle(train_data, train_labels)

    ############################################################################
    # Start a new run
    ############################################################################

    for run in range(1, runs + 1):
        run_string = f'run={run}/{runs}'

        dataset = os.path.basename(dataset)
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')

        run_title = f'{model_string}_{dataset}_{timestamp}_epochs{epochs}-bs{batch_size}'

        if runs > 1:
            run_title = f'{run_title}_run{run}'

        ########################################################################
        # Initialize / load the model
        ########################################################################

        Hexnet_print(f'({run_string}) Model initialization')

        input_shape = train_data.shape[1:4]
        output_shape = test_data.shape[1:4]
        classes = len(train_classes)

        if load_model is None:
            model = vars(models)[f'model_{model_string}']

            if model_is_custom or model_is_standalone:
                model = model(input_shape, classes)
            elif model_is_autoencoder:
                model = model(input_shape)
            elif model_is_CNN:
                model = model(input_shape, classes, cnn_kernel_size,
                              cnn_pool_size)
            else:
                model = model(input_shape, classes)
        else:
            model = tf.keras.models.load_model(load_model)

        if load_weights is not None:
            model.load_weights(load_weights)

        print_newline()

        ########################################################################
        # Fit the model
        ########################################################################

        if not loss_is_provided:
            if not model_is_autoencoder:
                loss = 'sparse_categorical_crossentropy'
            else:
                loss = 'mse'
        else:
            if not loss_is_subpixel_loss:
                loss = vars(losses)[f'loss_{loss_string}']()
            else:
                loss = vars(losses)[f'loss_{loss_string}'](input_shape,
                                                           output_shape)

        if not model_is_autoencoder:
            metrics = ['accuracy']
        else:
            metrics = None

        if not model_is_standalone:
            model.compile(optimizer='adam', loss=loss, metrics=metrics)
        else:
            model.compile()

        Hexnet_print(f'({run_string}) Model summary')
        model.summary()
        print_newline()

        Hexnet_print(f'({run_string}) Training')

        if model_is_standalone:
            model.fit(train_data, train_labels, batch_size, epochs, tests_dir,
                      run_title)
        elif model_is_autoencoder:
            history = model.fit(train_data,
                                train_data,
                                batch_size,
                                epochs,
                                validation_split=validation_split)
        else:
            history = model.fit(train_data,
                                train_labels,
                                batch_size,
                                epochs,
                                validation_split=validation_split)

        print_newline()

        ########################################################################
        # Visualize filters, feature maps, and training results
        ########################################################################

        if not model_is_standalone:
            if visualize_model is not None:
                Hexnet_print(f'({run_string}) Visualization')

                visualization.visualize_model(model,
                                              test_classes,
                                              test_data,
                                              test_labels,
                                              output_dir=visualize_model,
                                              max_images_per_class=10,
                                              verbosity_level=verbosity_level)

                print_newline()

            Hexnet_print(f'({run_string}) History')
            Hexnet_print(
                f'({run_string}) history.history.keys()={history.history.keys()}'
            )

            if tests_dir is not None or show_results:
                visualization.visualize_results(history, run_title, tests_dir,
                                                show_results)

            print_newline()

        ########################################################################
        # Evaluate the model and save test results
        ########################################################################

        Hexnet_print(f'({run_string}) Test')

        if model_is_standalone:
            model.evaluate(test_data,
                           test_labels,
                           batch_size,
                           epochs=10,
                           tests_dir=tests_dir,
                           run_title=run_title)
        elif model_is_autoencoder:
            test_loss = model.evaluate(test_data, test_data)
        else:
            test_loss, test_acc = model.evaluate(test_data, test_labels)

        if not model_is_standalone:
            predictions = model.predict(test_data)

            if not model_is_autoencoder:
                predictions_classes = predictions.argmax(axis=-1)
                Hexnet_print(
                    f'({run_string}) test_acc={test_acc:.8f}, test_loss={test_loss:.8f}'
                )
            else:
                Hexnet_print(f'({run_string}) test_loss={test_loss:.8f}')

            if tests_dir is not None:
                run_title_predictions = f'{run_title}_predictions'
                tests_dir_predictions = os.path.join(tests_dir,
                                                     run_title_predictions)

                if not model_is_autoencoder:
                    with open(f'{tests_dir_predictions}.csv',
                              'w') as predictions_file:
                        print(
                            'label_orig,filename,label,prediction_class,prediction',
                            file=predictions_file)

                        for label_orig, filename, label, prediction_class, prediction in zip(
                                test_labels_orig, test_filenames, test_labels,
                                predictions_classes, predictions):
                            prediction = [
                                float(format(class_confidence, '.8f'))
                                for class_confidence in prediction
                            ]
                            print(
                                f'{label_orig.decode()},{filename.decode()},{label},{prediction_class},{prediction}',
                                file=predictions_file)
                else:
                    os.makedirs(tests_dir_predictions, exist_ok=True)

                    for image_counter, (image, label) in enumerate(
                            zip(predictions, test_labels)):
                        image_filename = f'label{label}_image{image_counter}.png'
                        imsave(
                            os.path.join(tests_dir_predictions,
                                         image_filename), image)

        ########################################################################
        # Save the model
        ########################################################################

        if not model_is_standalone and tests_dir is not None:
            if save_model:
                model.save(os.path.join(tests_dir, f'{run_title}_model.h5'))

            if save_weights:
                model.save_weights(
                    os.path.join(tests_dir, f'{run_title}_weights.h5'))

        if run < runs:
            print_newline()
            print_newline()

    return 0
Example #9
        help=
        'square-to-hexagonal image transformation: outer radius of the hexagonal pixels'
    )
    parser.add_argument(
        '--transform-width',
        type=int,
        default=transform_width,
        help='square-to-square image transformation: output width')
    parser.add_argument(
        '--transform-height',
        type=int,
        default=transform_height,
        help='square-to-square image transformation: output height')

    return parser.parse_args(args, namespace)


################################################################################
# main
################################################################################

if __name__ == '__main__':
    args = parse_args()

    Hexnet_print(f'args={args}')
    print_newline()

    status = run(args)

    sys.exit(status)
Example #10
def load_dataset(dataset, create_h5=False, verbosity_level=2):
    Hexnet_print(f'Loading dataset {dataset}')

    start_time = time()

    loaded_dataset = False

    train_classes = []
    train_data = []
    train_filenames = []
    train_labels = []
    test_classes = []
    test_data = []
    test_filenames = []
    test_labels = []

    # Determine the type of dataset

    dataset_is_file = False
    dataset_is_dir_of_files = False
    dataset_is_dir_of_dirs = False

    dataset_has_allowed_filetype = False

    allowed_dataset_filetypes = ('.h5',)
    allowed_dataset_set_filetypes = ('.csv', '.h5', '.npy')
    dataset_set_filetypes_to_ignore = ('.log', '.md', '.txt')

    if os.path.isfile(dataset):
        dataset_is_file = True

        if dataset.lower().endswith(allowed_dataset_filetypes):
            dataset_has_allowed_filetype = True
    else:
        for dataset_set in glob(os.path.join(dataset, '*')):
            if os.path.isfile(dataset_set):
                dataset_set_lower = dataset_set.lower()

                if dataset_set_lower.endswith(
                        allowed_dataset_set_filetypes
                ) and not dataset_set_lower.endswith(
                        dataset_set_filetypes_to_ignore):
                    dataset_is_dir_of_files = True
                    dataset_is_dir_of_dirs = False

                    dataset_has_allowed_filetype = True

                    break
            else:
                dataset_is_dir_of_dirs = True

    # Load the dataset

    # Dataset is file
    if dataset_is_file and dataset_has_allowed_filetype:
        with h5py.File(dataset, 'r') as h5py_file:
            train_classes = np.asarray(h5py_file['train_classes']).astype('U')
            train_data = np.asarray(h5py_file['train_data'])
            train_filenames = np.asarray(
                h5py_file['train_filenames']).astype('U')
            train_labels = np.asarray(h5py_file['train_labels']).astype('U')
            test_classes = np.asarray(h5py_file['test_classes']).astype('U')
            test_data = np.asarray(h5py_file['test_data'])
            test_filenames = np.asarray(
                h5py_file['test_filenames']).astype('U')
            test_labels = np.asarray(h5py_file['test_labels']).astype('U')

        loaded_dataset = True

    # Dataset is directory of files
    elif dataset_is_dir_of_files and dataset_has_allowed_filetype:
        for dataset_set in glob(os.path.join(dataset, '*')):
            current_set = os.path.basename(dataset_set)

            if verbosity_level >= 1:
                Hexnet_print(f'\t> current_set={current_set}')

            current_set_lower = current_set.lower()

            if 'train' in current_set:
                if current_set_lower.endswith('.csv'):
                    file_data = pd.read_csv(dataset_set)

                    set_is_multilabel_set = type(file_data['label'][0]) is str

                    if not set_is_multilabel_set:
                        train_labels = np.asarray(
                            file_data['label']).astype('U')
                    else:
                        train_labels = np.asarray([
                            row.split(',') for row in file_data['label']
                        ]).astype('U')

                    train_classes = np.unique(train_labels)
                    train_data = np.asarray([
                        np.fromstring(row, sep=',')
                        for row in file_data['data']
                    ])
                    train_filenames = np.asarray(
                        file_data['filename']).astype('U')
                elif current_set_lower.endswith('.h5'):
                    with h5py.File(dataset_set, 'r') as h5py_file:
                        train_classes = np.asarray(
                            h5py_file['train_classes']).astype('U')
                        train_data = np.asarray(h5py_file['train_data'])
                        train_filenames = np.asarray(
                            h5py_file['train_filenames']).astype('U')
                        train_labels = np.asarray(
                            h5py_file['train_labels']).astype('U')
                else:  # npy
                    file_data = np.load(dataset_set, allow_pickle=True)

                    train_labels = np.asarray(file_data[0]).astype('U')
                    train_classes = np.unique(train_labels)
                    train_data = np.stack(file_data[2])
                    train_filenames = np.asarray(file_data[1]).astype('U')
            elif 'test' in current_set:
                if current_set_lower.endswith('.csv'):
                    file_data = pd.read_csv(dataset_set)

                    set_is_multilabel_set = type(file_data['label'][0]) is str

                    if not set_is_multilabel_set:
                        test_labels = np.asarray(
                            file_data['label']).astype('U')
                    else:
                        test_labels = np.asarray([
                            row.split(',') for row in file_data['label']
                        ]).astype('U')

                    test_classes = np.unique(test_labels)
                    test_data = np.asarray([
                        np.fromstring(row, sep=',')
                        for row in file_data['data']
                    ])
                    test_filenames = np.asarray(
                        file_data['filename']).astype('U')
                elif current_set_lower.endswith('.h5'):
                    with h5py.File(dataset_set, 'r') as h5py_file:
                        test_classes = np.asarray(
                            h5py_file['test_classes']).astype('U')
                        test_data = np.asarray(h5py_file['test_data'])
                        test_filenames = np.asarray(
                            h5py_file['test_filenames']).astype('U')
                        test_labels = np.asarray(
                            h5py_file['test_labels']).astype('U')
                else:  # npy
                    file_data = np.load(dataset_set, allow_pickle=True)

                    test_labels = np.asarray(file_data[0]).astype('U')
                    test_classes = np.unique(test_labels)
                    test_data = np.stack(file_data[2])
                    test_filenames = np.asarray(file_data[1]).astype('U')

        loaded_dataset = True

    # Dataset is directory of directories
    elif dataset_is_dir_of_dirs:
        for dataset_set in natsorted(glob(os.path.join(dataset, '*'))):
            current_set = os.path.basename(dataset_set)

            if verbosity_level >= 1:
                Hexnet_print(f'\t> current_set={current_set}')

            for set_class in natsorted(glob(os.path.join(dataset_set, '*'))):
                current_class = os.path.basename(set_class)

                if verbosity_level >= 2:
                    Hexnet_print(f'\t\t> current_class={current_class}')

                if 'train' in current_set:
                    train_classes.append(current_class)
                elif 'test' in current_set:
                    test_classes.append(current_class)

                for class_file in tqdm(
                        natsorted(glob(os.path.join(set_class, '*')))):
                    current_file = os.path.basename(class_file)

                    if verbosity_level >= 3:
                        Hexnet_print(f'\t\t\t> current_file={current_file}')

                    current_file_lower = current_file.lower()

                    if current_file_lower.endswith('.csv'):
                        file_data = np.loadtxt(class_file, delimiter=',')
                    elif current_file_lower.endswith('.npy'):
                        file_data = np.load(class_file)
                    else:
                        file_data = cv2.imread(class_file, cv2.IMREAD_COLOR)

                    if 'train' in current_set:
                        train_data.append(file_data)
                        train_filenames.append(current_file)
                        train_labels.append(current_class)
                    elif 'test' in current_set:
                        test_data.append(file_data)
                        test_filenames.append(current_file)
                        test_labels.append(current_class)

        # Zero-fill ragged nested sequences

        data_max_size = max([data.size for data in train_data + test_data])

        for data_index, data in enumerate(train_data):
            if data.size < data_max_size:
                train_data[data_index] = np.pad(train_data[data_index],
                                                (0, data_max_size - data.size),
                                                'constant',
                                                constant_values=0)

        for data_index, data in enumerate(test_data):
            if data.size < data_max_size:
                test_data[data_index] = np.pad(test_data[data_index],
                                               (0, data_max_size - data.size),
                                               'constant',
                                               constant_values=0)

        train_classes = np.asarray(train_classes)
        train_data = np.asarray(train_data)
        train_filenames = np.asarray(train_filenames)
        train_labels = np.asarray(train_labels)
        test_classes = np.asarray(test_classes)
        test_data = np.asarray(test_data)
        test_filenames = np.asarray(test_filenames)
        test_labels = np.asarray(test_labels)

        loaded_dataset = True

    # No dataset provided
    elif all(not dataset_type
             for dataset_type in (dataset_is_file, dataset_is_dir_of_files,
                                  dataset_is_dir_of_dirs)):
        Hexnet_print_warning('No dataset provided')

    # No identifiable dataset provided
    else:
        Hexnet_print_warning('No identifiable dataset provided')

    if loaded_dataset:
        time_diff = time() - start_time

        Hexnet_print(f'Loaded dataset {dataset} in {time_diff:.3f} seconds')

    if create_h5:
        dataset = f'{dataset}.h5'

        create_dataset_h5(dataset, train_classes, train_data, train_filenames,
                          train_labels, test_classes, test_data,
                          test_filenames, test_labels)

    return ((train_classes, train_data, train_filenames, train_labels),
            (test_classes, test_data, test_filenames, test_labels))
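
The extended loader above accepts three layouts: a single .h5 file, a directory of train*/test* files in .csv, .h5, or .npy format, or a directory of train*/test* subdirectories containing per-class folders of images, .csv, or .npy files; in the last case, ragged samples are zero-padded to the largest sample size before conversion to arrays.
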
Example #11
    def fit(self,
            train_data,
            train_labels,
            batch_size=128,
            epochs=10000,
            tests_dir=None,
            run_title=None,
            images_to_sample_per_class=100,
            sample_rate=100,
            disable_training=False):
        # Load the dataset
        X_train, y_train = train_data, train_labels

        # Configure inputs
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(1, epochs + 1):
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # The labels of the digits that the generator tries to create an
            # image representation of
            sampled_labels = np.random.randint(0, self.num_classes,
                                               (batch_size, 1))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict([noise, sampled_labels])

            # Image labels. 0-9
            img_labels = y_train[idx]

            # Train the discriminator
            if not disable_training:
                d_loss_real = self.discriminator.train_on_batch(
                    imgs, [valid, img_labels])
                d_loss_fake = self.discriminator.train_on_batch(
                    gen_imgs, [fake, sampled_labels])
            else:
                d_loss_real = self.discriminator.test_on_batch(
                    imgs, [valid, img_labels])
                d_loss_fake = self.discriminator.test_on_batch(
                    gen_imgs, [fake, sampled_labels])
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Train the generator
            if not disable_training:
                g_loss = self.combined.train_on_batch([noise, sampled_labels],
                                                      [valid, sampled_labels])
            else:
                g_loss = self.combined.test_on_batch([noise, sampled_labels],
                                                     [valid, sampled_labels])

            # Plot the progress
            Hexnet_print(
                f'(epoch={epoch:{len(str(epochs))}}/{epochs}) [D loss={d_loss[0]:11.8f}, acc={100*d_loss[3]:6.2f}%, op_acc={100*d_loss[4]:6.2f}%] [G loss={g_loss[0]:11.8f}]'
            )

            # If at save interval => save generated image samples
            if not epoch % sample_rate and tests_dir is not None:
                self.sample_images(epoch, tests_dir, run_title,
                                   images_to_sample_per_class)
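
Note that d_loss above averages the real and fake discriminator batches (0.5 * (d_loss_real + d_loss_fake)), and that with disable_training=True the loop calls test_on_batch instead of train_on_batch, so fit also doubles as an evaluation pass.
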
Example #12
def run(args):

    ############################################################################
    # Parameters
    ############################################################################

    disable_training = args.disable_training
    disable_testing = args.disable_testing
    disable_output = args.disable_output
    enable_tensorboard = args.enable_tensorboard

    model_string = args.model
    load_model = args.load_model
    load_weights = args.load_weights
    save_model = args.save_model
    save_weights = args.save_weights

    dataset = args.dataset
    create_dataset = args.create_dataset
    disable_rand = args.disable_rand
    create_h5 = args.create_h5
    resize_dataset = args.resize_dataset
    crop_dataset = args.crop_dataset
    pad_dataset = args.pad_dataset
    augment_dataset = args.augment_dataset
    augmenter_string = args.augmenter
    augmentation_level = args.augmentation_level
    augmentation_size = args.augmentation_size

    chunk_size = args.chunk_size

    output_dir = args.output_dir
    show_dataset = args.show_dataset
    visualize_dataset = args.visualize_dataset
    visualize_model = args.visualize_model
    visualize_hexagonal = args.visualize_hexagonal
    show_results = args.show_results

    optimizer = args.optimizer
    metrics = args.metrics
    batch_size = args.batch_size
    epochs = args.epochs
    loss_string = args.loss
    runs = args.runs
    validation_split = args.validation_split

    cnn_kernel_size = args.cnn_kernel_size
    cnn_pool_size = args.cnn_pool_size
    resnet_stacks = args.resnet_stacks
    resnet_n = args.resnet_n
    resnet_filter_size = args.resnet_filter_size

    verbosity_level = args.verbosity_level

    transform_s2h = args.transform_s2h
    transform_h2s = args.transform_h2s
    transform_h2h = args.transform_h2h
    transform_s2s = args.transform_s2s
    transform_s2h_rad_o = args.transform_s2h_rad_o
    transform_h2s_len = args.transform_h2s_len
    transform_h2h_rad_o = args.transform_h2h_rad_o
    transform_s2s_res = args.transform_s2s_res

    ############################################################################
    # Initialization
    ############################################################################

    if model_string:
        model_is_provided = True

        model_is_custom = 'custom' in model_string
        model_is_standalone = 'standalone' in model_string

        model_is_autoencoder = 'autoencoder' in model_string
        model_is_cnn = 'CNN' in model_string
        model_is_gan = 'GAN' in model_string
        model_is_resnet = 'ResNet' in model_string

        model_is_from_keras = 'keras' in model_string
        model_is_from_sklearn = 'sklearn' in model_string
    else:
        model_is_provided = False

    if augmenter_string:
        augmenter_is_provided = True

        augmenter_is_custom = 'custom' in augmenter_string
    else:
        augmenter_is_provided = False

    metrics_are_provided = metrics != 'auto'

    if loss_string != 'auto':
        loss_is_provided = True

        subpixel_loss_identifiers = ('s2s', 's2h')
        loss_is_subpixel_loss = any(
            identifier in loss_string
            for identifier in subpixel_loss_identifiers)

        loss_is_from_keras = loss_string.startswith('keras_')
    else:
        loss_is_provided = False

    if dataset:
        dataset = dataset.rstrip('/')

        dataset_is_provided = True

        if create_dataset: create_dataset = ast.literal_eval(create_dataset)
    else:
        dataset_is_provided = False

    if disable_output:
        output_dir = None
    elif not output_dir:
        disable_output = True

    disable_training |= epochs < 1
    enable_training = not disable_training
    enable_testing = not disable_testing
    enable_output = not disable_output
    disable_tensorboard = not enable_tensorboard
    enable_rand = not disable_rand

    train_classes = []
    train_classes_orig = []
    train_data = []
    train_filenames = []
    train_labels = []
    train_labels_orig = []
    test_classes = []
    test_classes_orig = []
    test_data = []
    test_filenames = []
    test_labels = []
    test_labels_orig = []

    classification_reports = []

    if enable_tensorboard and enable_output:
        fit_callbacks = [
            tf.keras.callbacks.TensorBoard(
                log_dir=os.path.normpath(output_dir), histogram_freq=1)
        ]
    else:
        fit_callbacks = None

    ############################################################################
    # No dataset provided - returning
    ############################################################################

    if not dataset_is_provided:
        print_newline()
        Hexnet_print('No dataset provided - returning')

        return 0

    ############################################################################
    # Create classification dataset
    ############################################################################

    if create_dataset:
        print_newline()

        datasets.create_dataset(dataset,
                                split_ratios=create_dataset,
                                randomized_assignment=enable_rand,
                                verbosity_level=verbosity_level)

        dataset = f'{dataset}_classification_dataset'

    ############################################################################
    # Transform the dataset
    ############################################################################

    if any(operation != False for operation in (transform_s2h, transform_h2s,
                                                transform_h2h, transform_s2s)):
        print_newline()
        Hexnet_init()
        print_newline()

        if transform_s2h != False:
            if transform_s2h is None:
                transform_s2h = f'{dataset}_s2h'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2h,
                                       mode='s2h',
                                       rad_o=transform_s2h_rad_o,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_h2s != False:
            if transform_h2s is None:
                transform_h2s = f'{dataset}_h2s'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_h2s,
                                       mode='h2s',
                                       len=transform_h2s_len,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_h2h != False:
            if transform_h2h is None:
                transform_h2h = f'{dataset}_h2h'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_h2h,
                                       mode='h2h',
                                       rad_o=transform_h2h_rad_o,
                                       method=0,
                                       verbosity_level=verbosity_level)

        if transform_s2s != False:
            if transform_s2s is None:
                transform_s2s = f'{dataset}_s2s'

            datasets.transform_dataset(dataset=dataset,
                                       output_dir=transform_s2s,
                                       mode='s2s',
                                       res=transform_s2s_res,
                                       method=0,
                                       verbosity_level=verbosity_level)

    ############################################################################
    # Visualize the dataset
    ############################################################################

    if visualize_dataset and os.path.isfile(
            dataset) and dataset.lower().endswith('.csv'):
        print_newline()

        datasets.visualize_dataset(dataset)

    ############################################################################
    # Load the dataset
    ############################################################################

    print_newline()

    ((train_classes_orig, train_data, train_filenames, train_labels_orig),
     (test_classes_orig,  test_data,  test_filenames,  test_labels_orig)) = \
     datasets.load_dataset(dataset, create_h5, verbosity_level)

    print_newline()

    datasets.create_dataset_overview(train_classes_orig, train_labels_orig,
                                     test_labels_orig, dataset, output_dir)

    if type(train_data) is not np.ndarray:
        disable_training = True
        enable_training = not disable_training

    if type(test_data) is not np.ndarray:
        disable_testing = True
        enable_testing = not disable_testing

    ############################################################################
    # Prepare the dataset
    ############################################################################

    print_newline()
    Hexnet_print('Dataset preparation')

    class_labels_are_digits = True

    if enable_training:
        for class_label in train_classes_orig:
            if not class_label.isdigit():
                class_labels_are_digits = False
                break
    elif enable_testing:
        for class_label in test_classes_orig:
            if not class_label.isdigit():
                class_labels_are_digits = False
                break

    if enable_training:
        if class_labels_are_digits:
            train_labels = np.asarray(
                [int(label) for label in train_labels_orig.flatten()])
        else:
            train_labels = np.asarray([
                np.where(label == train_classes_orig)[0][0]
                for label in train_labels_orig.flatten()
            ])

        train_classes = np.unique(train_labels)
        train_classes_len = len(train_classes)
        train_labels = np.reshape(train_labels,
                                  newshape=train_labels_orig.shape)

        if class_labels_are_digits:
            train_classes_min = min(train_classes)

            train_classes -= train_classes_min
            train_labels -= train_classes_min

        if train_labels.ndim > 1:
            train_labels = array_to_one_hot_array(train_labels,
                                                  train_classes_len)

    if enable_testing:
        if class_labels_are_digits:
            test_labels = np.asarray(
                [int(label) for label in test_labels_orig.flatten()])
        else:
            test_labels = np.asarray([
                np.where(label == test_classes_orig)[0][0]
                for label in test_labels_orig.flatten()
            ])

        test_classes = np.unique(test_labels)
        test_classes_len = len(test_classes)
        test_labels = np.reshape(test_labels, newshape=test_labels_orig.shape)

        if class_labels_are_digits:
            test_classes_min = min(test_classes)

            test_classes -= test_classes_min
            test_labels -= test_classes_min

        if test_labels.ndim > 1:
            test_labels = array_to_one_hot_array(test_labels, test_classes_len)

    ############################################################################
    # Preprocess the dataset
    ############################################################################

    if any(operation
           for operation in (resize_dataset, crop_dataset, pad_dataset)):
        Hexnet_print('Dataset preprocessing')

        if resize_dataset:
            (train_data,
             test_data) = datasets.resize_dataset(dataset_s=(train_data,
                                                             test_data),
                                                  resize_string=resize_dataset)

        if crop_dataset:
            (train_data,
             test_data) = datasets.crop_dataset(dataset_s=(train_data,
                                                           test_data),
                                                crop_string=crop_dataset)

        if pad_dataset:
            (train_data,
             test_data) = datasets.pad_dataset(dataset_s=(train_data,
                                                          test_data),
                                               pad_string=pad_dataset)

    ############################################################################
    # Augment the dataset
    ############################################################################

    if augment_dataset:
        Hexnet_print('Dataset augmentation')

        if augmentation_size != 1:
            if 'train' in augment_dataset:
                train_data, train_filenames, train_labels, train_labels_orig = \
                 augmenters.augment_size(train_data, train_filenames, train_labels, train_labels_orig, augmentation_size)

            if 'test' in augment_dataset:
                test_data, test_filenames, test_labels, test_labels_orig = \
                 augmenters.augment_size(test_data, test_filenames, test_labels, test_labels_orig, augmentation_size)

        if augmenter_is_provided:
            augmenter = vars(augmenters)[f'augmenter_{augmenter_string}']

            if not augmenter_is_custom:
                augmenter = augmenter(augmentation_level)
            else:
                augmenter = augmenter()

            if 'train' in augment_dataset:
                train_data = augmenter(images=train_data)

            if 'test' in augment_dataset:
                test_data = augmenter(images=test_data)

    ############################################################################
    # Show the dataset
    ############################################################################

    if show_dataset:
        datasets.show_dataset(train_classes_orig,
                              train_data,
                              train_labels_orig,
                              test_classes_orig,
                              test_data,
                              test_labels_orig,
                              max_images_per_class=1,
                              max_classes_to_display=10)

    ############################################################################
    # Visualize the dataset
    ############################################################################

    if visualize_dataset:
        print_newline()

        datasets.visualize_dataset(dataset, train_classes_orig, train_data,
                                   train_filenames, train_labels_orig,
                                   test_classes_orig, test_data,
                                   test_filenames, test_labels_orig,
                                   visualize_hexagonal, create_h5,
                                   verbosity_level)

    ############################################################################
    # No model provided - returning
    ############################################################################

    if not model_is_provided:
        print_newline()
        Hexnet_print('No model provided - returning')

        return 0

    ############################################################################
    # Standardize / normalize the dataset
    ############################################################################

    if not model_is_gan:
        if visualize_dataset: print_newline()
        Hexnet_print('Dataset standardization')

        std_eps = np.finfo(np.float32).eps

        if enable_training:
            if train_data.ndim > 3:
                mean_axis = (1, 2)
            else:
                mean_axis = 1

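            # Per-sample standardization in chunks: subtract each sample's mean
            # and divide by its standard deviation, limiting peak memory use.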
            for chunk_start in tqdm(range(0, train_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, train_data.shape[0])

                train_data_mean = np.mean(train_data[chunk_start:chunk_end],
                                          axis=mean_axis,
                                          keepdims=True)

                train_data_std = np.sqrt(
                    ((train_data[chunk_start:chunk_end] -
                      train_data_mean)**2).mean(axis=mean_axis, keepdims=True))
                train_data_std[train_data_std == 0] = std_eps

                train_data[chunk_start:chunk_end] = (
                    train_data[chunk_start:chunk_end] -
                    train_data_mean) / train_data_std

        if enable_testing:
            if test_data.ndim > 3:
                mean_axis = (1, 2)
            else:
                mean_axis = 1

            for chunk_start in tqdm(range(0, test_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, test_data.shape[0])

                test_data_mean = np.mean(test_data[chunk_start:chunk_end],
                                         axis=mean_axis,
                                         keepdims=True)

                test_data_std = np.sqrt(
                    ((test_data[chunk_start:chunk_end] -
                      test_data_mean)**2).mean(axis=mean_axis, keepdims=True))
                test_data_std[test_data_std == 0] = std_eps

                test_data[chunk_start:chunk_end] = (
                    test_data[chunk_start:chunk_end] -
                    test_data_mean) / test_data_std
    else:
        if visualize_dataset: print_newline()
        Hexnet_print('Dataset normalization')

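        # GAN inputs are min-max normalized to [0, 1] using the global data
        # range, again processed in chunks.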
        data_min = min(train_data.min(), test_data.min())
        data_max = max(train_data.max(), test_data.max())
        normalization_factor = data_max - data_min

        if enable_training:
            for chunk_start in tqdm(range(0, train_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, train_data.shape[0])

                train_data[chunk_start:chunk_end] = (
                    train_data[chunk_start:chunk_end] -
                    data_min) / normalization_factor

        if enable_testing:
            for chunk_start in tqdm(range(0, test_data.shape[0], chunk_size)):
                chunk_end = min(chunk_start + chunk_size, test_data.shape[0])

                test_data[chunk_start:chunk_end] = (
                    test_data[chunk_start:chunk_end] -
                    data_min) / normalization_factor

    ############################################################################
    # Shuffle the dataset
    ############################################################################

    if enable_training:
        Hexnet_print('Dataset shuffling')

        (train_data,
         train_labels) = sklearn.utils.shuffle(train_data, train_labels)

    ############################################################################
    # Start a new training and test run
    ############################################################################

    dataset = os.path.basename(dataset)
    tests_title = f'{model_string}__{dataset}'

    if not model_is_from_sklearn:
        tests_title = f'{tests_title}__epochs{epochs}-bs{batch_size}'

    print_newline()
    print_newline()

    for run in range(1, runs + 1):

        ########################################################################
        # Current run information
        ########################################################################

        run_string = f'run={run}/{runs}'
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        run_title = f'{tests_title}__{timestamp}_run{run}'

        ########################################################################
        # Initialize the model
        ########################################################################

        Hexnet_print(f'({run_string}) Model initialization')

        if enable_training:
            input_shape = train_data.shape[1:4]
            classes = train_classes_len
        elif enable_testing:
            input_shape = test_data.shape[1:4]
            classes = test_classes_len

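        # Resolve the model constructor by name from the models module; each
        # model family is called with the arguments it expects.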
        if not load_model:
            model = vars(models)[f'model_{model_string}']

            if model_is_custom or model_is_standalone:
                model = model(input_shape, classes)
            elif model_is_autoencoder:
                model = model(input_shape)
            elif model_is_cnn:
                model = model(input_shape, classes, cnn_kernel_size,
                              cnn_pool_size)
            elif model_is_resnet and not model_is_from_keras:
                model = model(input_shape, classes, resnet_stacks, resnet_n,
                              resnet_filter_size)
            elif model_is_from_sklearn:
                model = model()
            else:
                model = model(input_shape, classes)
        elif not (model_is_standalone or model_is_from_sklearn):
            model = tf.keras.models.load_model(load_model)

        if load_weights and not (model_is_standalone or model_is_from_sklearn):
            model.load_weights(load_weights)

        ########################################################################
        # Initialize loss and metrics
        ########################################################################

        if not (model_is_standalone or model_is_from_sklearn):
            Hexnet_print(f'({run_string}) Loss and metrics initialization')

            if not metrics_are_provided:
                if not model_is_autoencoder:
                    metrics = ['accuracy']
                else:
                    metrics = []

            if not loss_is_provided:
                if not model_is_autoencoder:
                    loss = 'SparseCategoricalCrossentropy'
                else:
                    loss = 'MeanSquaredError'

                loss = tf.losses.get(loss)
            else:
                if not loss_is_subpixel_loss:
                    if not loss_is_from_keras:
                        loss = vars(losses)[f'loss_{loss_string}']()
                    else:
                        loss = vars(
                            tf.keras.losses)[loss_string[len('keras_'):]]()
                else:
                    output_shape = test_data.shape[1:4]
                    loss = vars(losses)[f'loss_{loss_string}'](input_shape,
                                                               output_shape)

        ########################################################################
        # Compile the model
        ########################################################################

        if not model_is_from_sklearn:
            Hexnet_print(f'({run_string}) Model compilation')

            if not model_is_standalone:
                model.compile(optimizer, loss, metrics)
            else:
                model.compile()

        ########################################################################
        # Model summary
        ########################################################################

        print_newline()
        Hexnet_print(f'({run_string}) Model summary')

        if not model_is_from_sklearn:
            model.summary()
        else:
            Hexnet_print(model.get_params())

        ########################################################################
        # Train the model
        ########################################################################

        if enable_training:
            print_newline()
            Hexnet_print(f'({run_string}) Model training')

            if model_is_standalone:
                model.fit(train_data, train_labels, batch_size, epochs,
                          visualize_hexagonal, output_dir, run_title)
            elif model_is_autoencoder:
                history = model.fit(train_data,
                                    train_data,
                                    batch_size,
                                    epochs,
                                    validation_split=validation_split)
            elif model_is_from_sklearn:
                model.fit(
                    np.reshape(train_data, newshape=(train_data.shape[0], -1)),
                    train_labels)
            else:
                history = model.fit(train_data,
                                    train_labels,
                                    batch_size,
                                    epochs,
                                    validation_split=validation_split,
                                    callbacks=fit_callbacks)

        ########################################################################
        # Visualize filters, feature maps, activations, and training results
        ########################################################################

        if not (model_is_standalone or model_is_from_sklearn):
            if visualize_model and enable_output:
                print_newline()
                Hexnet_print(f'({run_string}) Visualization')

                output_dir_visualizations = os.path.join(
                    output_dir, f'{run_title}_visualizations')

                visualization.visualize_model(
                    model,
                    test_classes,
                    test_data,
                    test_labels,
                    visualize_hexagonal,
                    output_dir=output_dir_visualizations,
                    max_images_per_class=10,
                    verbosity_level=verbosity_level)

            if enable_training:
                print_newline()
                Hexnet_print(f'({run_string}) History')
                Hexnet_print(
                    f'({run_string}) history.history.keys()={history.history.keys()}'
                )

                if enable_output or show_results:
                    visualization.visualize_training_results(
                        history, run_title, output_dir, show_results)

        ########################################################################
        # Evaluate the model
        ########################################################################

        if enable_testing and not model_is_from_sklearn:
            print_newline()
            Hexnet_print(f'({run_string}) Model evaluation')

            if model_is_standalone:
                model.evaluate(test_data,
                               test_labels,
                               batch_size,
                               epochs=10,
                               visualize_hexagonal=visualize_hexagonal,
                               output_dir=output_dir,
                               run_title=run_title)
            elif model_is_autoencoder:
                test_loss_metrics = model.evaluate(test_data, test_data)
            else:
                test_loss_metrics = model.evaluate(test_data, test_labels)

            if not model_is_standalone:
                Hexnet_print(
                    f'({run_string}) test_loss_metrics={test_loss_metrics}')

        ########################################################################
        # Save test results
        ########################################################################

        if enable_testing and enable_output and not model_is_standalone:
            print_newline()
            Hexnet_print(f'({run_string}) Saving test results')

            if not model_is_from_sklearn:
                predictions = model.predict(test_data)
            else:
                predictions = model.predict_proba(
                    np.reshape(test_data, newshape=(test_data.shape[0], -1)))

            if not model_is_autoencoder:
                classification_report = visualization.visualize_test_results(
                    predictions, test_classes, test_classes_orig,
                    test_filenames, test_labels, test_labels_orig, run_title,
                    output_dir)

                print_newline()
                Hexnet_print(f'({run_string}) Classification report')
                pprint(classification_report)

                classification_reports.append(classification_report)
            else:
                loss_newshape = (test_data.shape[0], -1)
                test_losses = loss(
                    np.reshape(test_data, newshape=loss_newshape),
                    np.reshape(predictions, newshape=loss_newshape))

                output_dir_predictions = os.path.join(
                    output_dir, f'{run_title}_predictions')
                os.makedirs(output_dir_predictions, exist_ok=True)

                with open(f'{output_dir_predictions}.csv',
                          'w') as predictions_file:
                    print('label_orig,filename,label,loss',
                          file=predictions_file)

                    for label_orig, filename, label, loss_value in zip(
                            test_labels_orig, test_filenames, test_labels,
                            test_losses):
                        loss_value = float(format(loss_value, '.8f'))
                        print(f'{label_orig},{filename},{label},{loss_value}',
                              file=predictions_file)

                for image_counter, (image, label) in enumerate(
                        zip(tqdm(predictions), test_labels)):
                    image_filename = f'label{label}_image{image_counter}.png'

                    if not visualize_hexagonal:
                        imsave(
                            os.path.join(output_dir_predictions,
                                         image_filename), image)
                    else:
                        visualization.visualize_hexarray(
                            image,
                            os.path.join(output_dir_predictions,
                                         image_filename))

        ########################################################################
        # Save the model
        ########################################################################

        if (save_model or save_weights) and enable_output and not (
                model_is_standalone or model_is_from_sklearn):
            print_newline()
            Hexnet_print(f'({run_string}) Saving the model')

            if save_model:
                model.save(os.path.join(output_dir, f'{run_title}_model.h5'))

            if save_weights:
                model.save_weights(
                    os.path.join(output_dir, f'{run_title}_weights.h5'))

        if run < runs:
            print_newline()
            print_newline()

    ############################################################################
    # Save global test results
    ############################################################################

    if enable_testing and enable_output and runs > 1 and not (
            model_is_standalone or model_is_autoencoder):
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        tests_title = f'{tests_title}__{timestamp}'

        visualization.visualize_global_test_results(classification_reports,
                                                    tests_title, output_dir)

    return 0
Beispiel #13
0
def visualize_dataset(dataset,
                      train_classes=None,
                      train_data=None,
                      train_filenames=None,
                      train_labels=None,
                      test_classes=None,
                      test_data=None,
                      test_filenames=None,
                      test_labels=None,
                      visualize_hexagonal=None,
                      create_h5=False,
                      verbosity_level=2):

    Hexnet_print(f'Visualizing dataset {dataset}')

    start_time = time()

    if create_h5:
        dataset = f'{dataset}_visualized.h5'

        create_dataset_h5(dataset, train_classes, train_data, train_filenames,
                          train_labels, test_classes, test_data,
                          test_filenames, test_labels)
    elif os.path.isfile(dataset) and dataset.lower().endswith('.csv'):
        dataset_visualized = f'{dataset}_visualized'

        with open(dataset) as dataset_file:
            dataset_reader = csv.reader(dataset_file)
            dataset_data = list(dataset_reader)[1:]

        for label, filename, data in tqdm(dataset_data):
            current_output_dir = os.path.join(dataset_visualized, label)
            os.makedirs(current_output_dir, exist_ok=True)

            with open(os.path.join(current_output_dir, filename),
                      'w') as current_data_file:
                print(data.replace('"', ''), file=current_data_file)
    else:
        dataset_visualized = f'{dataset}_visualized'

        if os.path.isfile(dataset) and dataset.lower().endswith('.h5'):
            for current_class in train_classes:
                os.makedirs(os.path.join(dataset_visualized, 'train',
                                         current_class),
                            exist_ok=True)

            for current_class in test_classes:
                os.makedirs(os.path.join(dataset_visualized, 'test',
                                         current_class),
                            exist_ok=True)
        else:
            shutil.copytree(dataset,
                            dataset_visualized,
                            ignore=copytree_ignore_files)

        for current_set, current_data, current_filenames, current_labels in \
         zip(('train', 'test'), (train_data, test_data), (train_filenames, test_filenames), (train_labels, test_labels)):

            if verbosity_level >= 1:
                Hexnet_print(f'\t> current_set={current_set}')

            if not current_data.size:
                continue

            for file, filename, label in zip(tqdm(current_data),
                                             current_filenames,
                                             current_labels):
                filename = os.path.join(dataset_visualized, current_set, label,
                                        filename)

                if verbosity_level >= 3:
                    Hexnet_print(f'\t\t\t> filename={filename}')

                filename_lower = filename.lower()

                if filename_lower.endswith('.csv'):
                    np.savetxt(filename,
                               np.reshape(file, newshape=(1, file.shape[0])),
                               delimiter=',')
                elif filename_lower.endswith('.npy'):
                    np.save(filename, file)
                else:
                    if not visualize_hexagonal:
                        imsave(filename, file)
                    else:
                        filename = '.'.join(filename.split('.')[:-1])
                        visualize_hexarray(normalize_array(file), filename)

    time_diff = time() - start_time

    Hexnet_print(f'Visualized dataset {dataset} in {time_diff:.3f} seconds')
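
# Assumed usage sketch (argument values are illustrative, not from the source):
# visualize_dataset('datasets/MNIST', train_classes, train_data, train_filenames,
#                   train_labels, test_classes, test_data, test_filenames,
#                   test_labels, visualize_hexagonal=False)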
Beispiel #14
0
def transform_dataset(dataset,
                      output_dir,
                      mode='s2h',
                      rad_o=1.0,
                      len=1.0,
                      res=(64, 64),
                      method=0,
                      verbosity_level=2):

    if os.path.exists(output_dir):
        Hexnet_print(
            f'Dataset {output_dir} exists already (skipping transformation)')
        return

    if os.path.isfile(f'{output_dir}.h5'):
        Hexnet_print(
            f'Dataset {output_dir}.h5 exists already (skipping transformation)'
        )
        return

    Hexnet_print(f'Transforming dataset {dataset}')

    start_time = time()

    increase_verbosity = verbosity_level >= 3

    shutil.copytree(dataset, output_dir, ignore=copytree_ignore_files)

    for directory in natsorted(
            glob(os.path.join(dataset, '**/'), recursive=True)):
        if verbosity_level >= 1:
            Hexnet_print(f'\t> directory={directory}')

        directory_filename_s = os.path.join(directory, '*')

        found_file = False

        for file in glob(directory_filename_s):
            if os.path.isfile(file):
                found_file = True
                break

        if not found_file:
            continue

        directory_output_dir = os.path.join(
            output_dir, os.path.relpath(directory, dataset))

        if mode == 's2h':
            Hexsamp_s2h(filename_s=directory_filename_s,
                        output_dir=directory_output_dir,
                        rad_o=rad_o,
                        method=method,
                        increase_verbosity=increase_verbosity)
        elif mode == 'h2s':
            Hexsamp_h2s(filename_s=directory_filename_s,
                        output_dir=directory_output_dir,
                        len=len,
                        method=method,
                        increase_verbosity=increase_verbosity)
        elif mode == 'h2h':
            Hexsamp_h2h(filename_s=directory_filename_s,
                        output_dir=directory_output_dir,
                        rad_o=rad_o,
                        method=method,
                        increase_verbosity=increase_verbosity)
        elif mode == 's2s':
            Sqsamp_s2s(filename_s=directory_filename_s,
                       output_dir=directory_output_dir,
                       res=res,
                       method=method,
                       increase_verbosity=increase_verbosity)

    time_diff = time() - start_time

    Hexnet_print(f'Transformed dataset {dataset} in {time_diff:.3f} seconds')
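
# Assumed usage sketch (paths are illustrative): resample a directory-structured
# square-pixel dataset to a hexagonal lattice, mirroring the directory tree.
# transform_dataset('datasets/MNIST', 'datasets/MNIST_s2h', mode='s2h', rad_o=1.0)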
Beispiel #15
0
def create_dataset(dataset,
                   split_ratios,
                   randomized_assignment=True,
                   verbosity_level=2):
    Hexnet_print(f'Creating classification dataset from dataset {dataset}')

    start_time = time()

    split_ratios_len = len(split_ratios)
    split_ratios_sets = list(split_ratios.keys())
    split_ratios_fractions = list(split_ratios.values())

    classification_dataset = f'{dataset}_classification_dataset'
    os.makedirs(classification_dataset, exist_ok=True)

    max_files_to_copy = max([
        len(glob(os.path.join(set_class, '*')))
        for set_class in glob(os.path.join(dataset, '*'))
    ])
    max_files_to_copy_per_set = [
        math.ceil(fraction * max_files_to_copy)
        for fraction in split_ratios_fractions
    ]

    for set_class in natsorted(glob(os.path.join(dataset, '*'))):

        # Step 1: (randomized) file dataset set assignment

        current_class = os.path.basename(set_class)

        if verbosity_level >= 1:
            Hexnet_print(f'\t> current_class={current_class}')

        files_to_copy = glob(os.path.join(set_class, '*'))

        if not files_to_copy:
            continue

        files_to_copy_len = len(files_to_copy)
        copied_files = []
        copied_files_per_set = split_ratios_len * [0]

        if verbosity_level >= 2:
            Hexnet_print(
                f'\t\t> max_files_to_copy_per_set={max_files_to_copy_per_set} (files_to_copy_len={files_to_copy_len})'
            )

        for current_set in split_ratios_sets:
            os.makedirs(os.path.join(classification_dataset, current_set,
                                     current_class),
                        exist_ok=True)

        if randomized_assignment:
            files_to_copy = random.sample(files_to_copy, files_to_copy_len)

        for file_to_copy_index, file_to_copy in enumerate(tqdm(files_to_copy)):
            if not randomized_assignment:
                cumulative_split_ratio_fraction = 0

                for split_ratio_fraction_index, split_ratio_fraction in enumerate(
                        split_ratios_fractions):
                    cumulative_split_ratio_fraction += split_ratio_fraction

                    if file_to_copy_index < cumulative_split_ratio_fraction * files_to_copy_len:
                        set_selector = split_ratio_fraction_index
                        break

            while True:
                if randomized_assignment:
                    set_selector = random.randint(0, split_ratios_len - 1)

                if copied_files_per_set[
                        set_selector] < max_files_to_copy_per_set[set_selector]:
                    copy_file_to = os.path.join(
                        classification_dataset,
                        split_ratios_sets[set_selector], current_class,
                        os.path.basename(file_to_copy))

                    shutil.copyfile(file_to_copy, copy_file_to)

                    copied_files.append(copy_file_to)
                    copied_files_per_set[set_selector] += 1

                    break

        copied_files_len = len(copied_files)

        if verbosity_level >= 2:
            Hexnet_print(
                f'\t\t> copied_files_per_set={copied_files_per_set} (copied_files_len={copied_files_len})'
            )

        # Step 2: randomized file dataset set balancing: duplicate and hash assigned files

        for current_set_index, current_set in enumerate(split_ratios_sets):
            while copied_files_per_set[
                    current_set_index] < max_files_to_copy_per_set[
                        current_set_index]:
                file_selector = random.randint(0, copied_files_len - 1)

                file_to_copy = copied_files[file_selector]
                copy_file_to = os.path.basename(file_to_copy).split('.')
                copy_file_to = '.'.join(copy_file_to[:-1]) + '_' + str(
                    uuid.uuid4()) + '.' + copy_file_to[-1]
                copy_file_to = os.path.join(classification_dataset,
                                            current_set, current_class,
                                            copy_file_to)

                shutil.copyfile(file_to_copy, copy_file_to)

                copied_files_per_set[current_set_index] += 1

        if verbosity_level >= 2:
            copied_files_len = sum(copied_files_per_set)
            Hexnet_print(
                f'\t\t> copied_files_per_set={copied_files_per_set} (copied_files_len={copied_files_len}) after balancing'
            )

    time_diff = time() - start_time

    Hexnet_print(
        f'Created classification dataset from dataset {dataset} in {time_diff:.3f} seconds'
    )
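
# Assumed usage sketch (split ratios are illustrative): build a balanced
# train/test split from a class-per-directory dataset.
# create_dataset('datasets/MyData', split_ratios={'train': 0.8, 'test': 0.2})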
Beispiel #16
0
    def fit(self,
            train_data,
            train_labels,
            batch_size=100,
            epochs=100,
            visualize_hexagonal=False,
            output_dir=None,
            run_title=None,
            images_to_sample_per_class=10,
            disable_training=False):

        # Configure inputs
        train_labels = np.reshape(train_labels, newshape=(-1, 1))

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        batches = int(train_data.shape[0] / batch_size)

        for epoch in range(1, epochs + 1):
            for batch in tqdm(range(1, batches + 1)):

                ################################################################
                # Train the discriminator
                ################################################################

                # Select a random batch of images
                idx = np.random.randint(0, train_data.shape[0], batch_size)
                imgs = train_data[idx]

                # Sample noise as generator input
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                # The labels of the digits that the generator tries to create an image representation of
                sampled_labels = np.random.randint(0, self.classes,
                                                   (batch_size, 1))

                # Generate a half batch of new images
                gen_imgs = self.generator.predict([noise, sampled_labels])

                # Image labels
                img_labels = train_labels[idx]

                if not disable_training:
                    d_loss_real = self.discriminator.train_on_batch(
                        imgs, [valid, img_labels])
                    d_loss_fake = self.discriminator.train_on_batch(
                        gen_imgs, [fake, sampled_labels])
                else:
                    d_loss_real = self.discriminator.test_on_batch(
                        imgs, [valid, img_labels])
                    d_loss_fake = self.discriminator.test_on_batch(
                        gen_imgs, [fake, sampled_labels])

                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                ################################################################
                # Train the generator
                ################################################################

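                # Train the generator through the combined model: it is rewarded
                # for making the discriminator classify generated images as valid
                # with the sampled labels (the discriminator is assumed to be
                # frozen inside self.combined).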
                if not disable_training:
                    g_loss = self.combined.train_on_batch(
                        [noise, sampled_labels], [valid, sampled_labels])
                else:
                    g_loss = self.combined.test_on_batch(
                        [noise, sampled_labels], [valid, sampled_labels])

            Hexnet_print(
                f'(epoch={epoch:{len(str(epochs))}}/{epochs}) [D loss={d_loss[0]:11.8f}, acc={100*d_loss[3]:6.2f}%, op_acc={100*d_loss[4]:6.2f}%] [G loss={g_loss[0]:11.8f}]'
            )

            if output_dir is not None:
                self.sample_images(epoch, visualize_hexagonal, output_dir,
                                   run_title, images_to_sample_per_class)
Beispiel #17
0
def visualize_activations(
	model,
	test_classes,
	test_data,
	test_labels,
	visualize_hexagonal,
	output_dir,
	max_images_per_class = 10,
	verbosity_level      =  2):

	heatmap_intensity_factor = 0.66


	os.makedirs(output_dir, exist_ok=True)


	test_data_for_prediction    = []
	test_data_for_visualization = []
	test_labels_for_prediction  = []

	class_counter_dict = dict.fromkeys(test_classes, 0)

	for image, label in zip(test_data, test_labels):
		if class_counter_dict[label] < max_images_per_class:
			test_data_for_prediction.append(image)
			test_labels_for_prediction.append(label)
			class_counter_dict[label] += 1

	test_data_for_prediction    = np.asarray(test_data_for_prediction, dtype=np.float32)
	test_data_for_visualization = 255 * normalize_array(test_data_for_prediction)


	for layer_counter, layer in enumerate(model.layers):
		if 'conv' not in layer.name:
			continue

		if verbosity_level >= 2:
			Hexnet_print(f'(visualize_activations) layer={layer} (layer_counter={layer_counter}, layer.name={layer.name}): layer.output.shape={layer.output.shape}')


		model_outputs = Model(inputs = model.inputs, outputs = (model.output, layer.output))

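		# Grad-CAM-style weighting: gradients of the target-class score with
		# respect to the conv layer's output are averaged over the spatial
		# dimensions and used below to weight and sum the layer's feature maps.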
		with tf.GradientTape() as gradient_tape:
			predictions, layer_output = model_outputs(test_data_for_prediction)
			# predictions = tf.reduce_max(predictions, axis=1)
			predictions_indices = np.stack((np.arange(0, len(test_labels_for_prediction)), test_labels_for_prediction), axis=1)
			predictions = tf.gather_nd(predictions, indices=predictions_indices)
			gradients   = gradient_tape.gradient(predictions, layer_output)
			gradients   = tf.reduce_mean(gradients, axis = (1, 2))

		activations = tf.einsum('ijkl,il->ijkl', layer_output, gradients)
		# activations = tf.reduce_mean(activations, axis=3)
		activations = tf.reduce_sum(activations, axis=3)
		activations = np.maximum(activations, 0)


		for image_counter, (image, label, activation) in enumerate(zip(tqdm(test_data_for_visualization), test_labels_for_prediction, activations)):
			activation_max = activation.max()

			if activation_max:
				activation /= activation_max

			activation = (255 * activation).astype(np.uint8)
			activation = cv2.resize(activation, (image.shape[1], image.shape[0]))
			activation = cv2.equalizeHist(activation)
			heatmap    = cv2.applyColorMap(activation, cv2.COLORMAP_VIRIDIS)

			image_heatmapped = image + heatmap_intensity_factor * heatmap
			image_heatmapped = normalize_array(np.clip(image_heatmapped, 0, 255))


			title = f'layer{str(layer_counter).zfill(3)}_{layer.name}_label{str(label).zfill(3)}_image{str(image_counter).zfill(3)}'
			image_heatmapped_filename = os.path.join(output_dir, f'{title}_image_heatmapped')
			heatmap_filename          = os.path.join(output_dir, f'{title}_activations_heatmap')

			imsave(f'{image_heatmapped_filename}.png', image_heatmapped)
			imsave(f'{heatmap_filename}.png',          heatmap)

			if visualize_hexagonal:
				heatmap = heatmap / 255

				visualize_hexarray(image_heatmapped, image_heatmapped_filename)
				visualize_hexarray(heatmap,          heatmap_filename)
Beispiel #18
0
def create_dataset_overview(classes, train_labels, test_labels, dataset,
                            output_dir):
    # Prepare dataset overview table: entries

    total_string = 'Total'

    unique, counts = np.unique(train_labels, return_counts=True)
    train_labels_unique_counts = dict(zip(unique, counts))
    unique, counts = np.unique(test_labels, return_counts=True)
    test_labels_unique_counts = dict(zip(unique, counts))

    labels_unique_counts = {key: train_value + test_value for (key, train_value), (_, test_value) in \
     zip(train_labels_unique_counts.items(), test_labels_unique_counts.items())}

    train_labels_unique_counts_total = sum(train_labels_unique_counts.values())
    test_labels_unique_counts_total = sum(test_labels_unique_counts.values())
    labels_unique_counts_total = sum(labels_unique_counts.values())

    entries_max_len = max(
        np.vectorize(len)(classes).max(), len(total_string),
        len(str(labels_unique_counts_total)))

    # Create dataset overview table: rows and columns

    total_string = total_string.rjust(entries_max_len)

    header_entries = '|'.join(
        [f' {c.rjust(entries_max_len)} ' for c in classes])
    train_entries = '|'.join([
        f' {str(v).rjust(entries_max_len)} '
        for v in train_labels_unique_counts.values()
    ])
    test_entries = '|'.join([
        f' {str(v).rjust(entries_max_len)} '
        for v in test_labels_unique_counts.values()
    ])
    total_entries = '|'.join([
        f' {str(v).rjust(entries_max_len)} '
        for v in labels_unique_counts.values()
    ])

    train_entries_total = str(train_labels_unique_counts_total).rjust(
        entries_max_len, ' ')
    test_entries_total = str(test_labels_unique_counts_total).rjust(
        entries_max_len, ' ')
    total_entries_total = str(labels_unique_counts_total).rjust(
        entries_max_len, ' ')

    header = '| Set \\ Class |' + header_entries + '| ' + total_string + ' |'
    train = '| Train       |' + train_entries + '| ' + train_entries_total + ' |'
    test = '| Test        |' + test_entries + '| ' + test_entries_total + ' |'
    total = '| Total       |' + total_entries + '| ' + total_entries_total + ' |'
    hline = len(header) * '-'

    dataset_overview = \
     f'{hline}\n'  \
     f'{header}\n' \
     f'{hline}\n'  \
     f'{train}\n'  \
     f'{test}\n'   \
     f'{hline}\n'  \
     f'{total}\n'  \
     f'{hline}'

    # Output dataset overview table

    Hexnet_print(f'Dataset overview\n{dataset_overview}')

    if output_dir:
        filename = os.path.join(
            output_dir, f'{os.path.basename(dataset)}_dataset_overview.dat')

        os.makedirs(output_dir, exist_ok=True)

        with open(filename, 'w') as file:
            print(dataset_overview, file=file)