def test_train_neural_network(self):
        # Smoke test: run one training epoch end to end. Nothing is asserted
        # on the outcome; the test passes when the call raises no exception.

        # Model factory with every architecture hyper-parameter pre-bound.
        architecture = partial(
            cnn_nature_comm_ziletti2018,
            conv2d_filters=[32, 32, 16, 16, 8, 8],
            kernel_sizes=[3] * 6,
            max_pool_strides=[2, 2],
            hidden_layer_size=128)

        # Pristine samples are the train/validation pool (split_train_val=True,
        # test_size=0.1); the vac25 samples are supplied as the test split.
        splits = make_data_sets(x_train_val=self.x_pristine,
                                y_train_val=self.y_pristine,
                                split_train_val=True,
                                test_size=0.1,
                                x_test=self.x_vac25,
                                y_test=self.y_vac25)

        train_images = splits.train.images
        train_labels = splits.train.labels

        # The training arrays are deliberately reused as the validation set:
        # this test only checks that the training loop runs.
        train_neural_network(
            x_train=train_images,
            y_train=train_labels,
            x_val=train_images,
            y_val=train_labels,
            configs=self.configs,
            partial_model_architecture=architecture,
            nb_epoch=1)
    def test_predict(self):
        # Check the structure and value ranges of the dictionary returned by
        # predict() when it is run on the vac25 test split.
        dataset = make_data_sets(x_train_val=self.x_pristine,
                                 y_train_val=self.y_pristine,
                                 split_train_val=False,
                                 x_test=self.x_vac25,
                                 y_test=self.y_vac25)

        # integer-encode the text labels; both forms are handed to predict()
        label_encoder = preprocessing.LabelEncoder()
        label_encoder.fit(self.text_labels)
        numerical_labels = label_encoder.transform(self.text_labels)

        # load the data
        x_test = dataset.test.images
        y_test = dataset.test.labels

        results = predict(x_test,
                          y_test,
                          configs=self.configs,
                          numerical_labels=numerical_labels,
                          text_labels=self.text_labels)

        self.assertIsInstance(results, dict)

        # probabilities should be between 0. and 1.
        prob_predictions = results['prob_predictions']
        self.assertLessEqual(np.amax(prob_predictions), 1.0)
        self.assertGreaterEqual(np.amin(prob_predictions), 0.0)

        # target_pred_class: at most 7 distinct values, all within [0, 7].
        # The lower bound used to be checked against prob_predictions, which
        # only repeated the probability assertion above and left the
        # predicted class indices without a lower-bound check.
        target_pred_class = results['target_pred_class']
        self.assertLessEqual(len(set(target_pred_class)), 7)
        self.assertLessEqual(np.amax(target_pred_class), 7)
        self.assertGreaterEqual(np.amin(target_pred_class), 0)

        # confusion matrix should be a numpy array
        self.assertIsInstance(results['confusion_matrix'], np.ndarray)

        # string_probs are a list of strings - one element for each prediction
        self.assertIsInstance(results['string_probs'], list)
        "batch_size": 32,
        "img_channels": 3
    }

    text_labels = np.asarray(dataset_info_train["data"][0]["text_labels"])
    numerical_labels = np.asarray(
        dataset_info_train["data"][0]["numerical_labels"])
    classes = dataset_info_train["data"][0]["classes"]

    # text_labels = np.asarray(dataset_info_test["data"][0]["text_labels"])
    # numerical_labels = np.asarray(dataset_info_test["data"][0]["numerical_labels"])

    data_set = make_data_sets(x_train_val=x_train,
                              y_train_val=y_train,
                              x_test=x_test,
                              y_test=y_test,
                              split_train_val=True,
                              test_size=0.1,
                              stratified_splits=True)

    # =============================================================================
    # Neural network training and prediction
    # =============================================================================

    partial_model_architecture = partial(model_deep_cnn_struct_recognition,
                                         conv2d_filters=[32, 32, 16, 16, 8, 8],
                                         kernel_sizes=[7, 7, 7, 7, 7, 7],
                                         max_pool_strides=[2, 2],
                                         hidden_layer_size=128)

    # generate image of architecture
    x_test, y_test, dataset_info_test = load_dataset_from_file(path_to_x=path_to_x_test, path_to_y=path_to_y_test,
                                                               path_to_summary=path_to_summary_test)

    params_cnn = {"nb_classes": dataset_info_train["data"][0]["nb_classes"],
                  "classes": dataset_info_train["data"][0]["classes"],
                  # "checkpoint_filename": 'try_' + str(now.isoformat()),
                  "checkpoint_filename": 'enc_dec_no_batch_norm',
                  # "checkpoint_filename": 'fully_conv_acc100',
                  # "checkpoint_filename": 'rot_inv_kernel_15',
                  "batch_size": 32, "img_channels": 1}

    text_labels = np.asarray(dataset_info_test["data"][0]["text_labels"])
    numerical_labels = np.asarray(dataset_info_test["data"][0]["numerical_labels"])

    data_set_train = make_data_sets(x_train_val=x_train, y_train_val=y_train, split_train_val=True, test_size=0.1,
                                    x_test=x_test, y_test=y_test, flatten_images=False)

    # beautiful maps
    #        conv2d_filters=[32, 16, 12, 8, 4, 4],
    #        kernel_sizes=[3, 3, 3, 3, 3, 3],
    # hidden_layer_size = 32)

    # partial_model_architecture = partial(model_cnn_rot_inv, conv2d_filters=[32, 32, 16, 16, 16, 16],
    #                                      kernel_sizes=[3, 3, 3, 3, 3, 3], hidden_layer_size=64)

    # partial_model_architecture = partial(model_cnn_rot_inv, conv2d_filters=[32, 16, 8, 8, 16, 32],
    #                                  kernel_sizes=[3, 3, 3, 3, 3, 3], hidden_layer_size=64, dropout=0.25)

    # partial_model_architecture = partial(model_cnn_rot_inv, conv2d_filters=[8, 8, 8, 8, 8, 8],
    #                                  kernel_sizes=[3, 3, 3, 3, 3, 3], hidden_layer_size=64)
def _sort_by_grid_position(df_positions, df_values):
    """Return ``df_values`` with its rows ordered by ascending grid position.

    ``df_values`` is aligned column-wise with ``df_positions`` on the index of
    ``df_positions``, the combined frame is sorted by the position columns (in
    the order they appear in ``df_positions``), and the position columns are
    dropped again so only the value columns remain.
    """
    position_columns = list(df_positions.columns)
    # The ``join_axes`` keyword of ``pd.concat`` was removed in pandas 1.0;
    # reindexing the concatenated frame on the positions' index is the
    # documented replacement and also works on older pandas releases.
    df = pd.concat([df_positions, df_values],
                   axis=1).reindex(df_positions.index)
    df_sorted = df.sort_values(position_columns, ascending=True)
    return df_sorted.drop(columns=position_columns)


def get_classification_map(polycrystal_file,
                           descriptor,
                           desc_metadata,
                           configs,
                           checkpoint_dir,
                           checkpoint_filename,
                           operations_on_structure=None,
                           stride_size=(4., 4., 4.),
                           box_size=12.0,
                           init_sliding_volume=(14., 14., 14.),
                           desc_file=None,
                           desc_only=False,
                           show_plot_lengths=True,
                           calc_uncertainty=True,
                           mc_samples=10,
                           desc_file_suffix_name='_pristine',
                           nb_jobs=-1,
                           interpolation='none',
                           results_file=None,
                           conf_matrix_file=None,
                           train_set_name='hcp-bcc-sc-diam-fcc-pristine',
                           padding_ratio=None,
                           cmap_uncertainty='hot',
                           interpolation_uncertainty='none'):
    """Classify a polycrystal on a strided grid and plot the resulting maps.

    The function (1) obtains a descriptor file, either the one passed in or a
    freshly computed one, (2) turns it into a test dataset, (3) runs
    ``predict`` with the given checkpoint, and (4) plots one probability
    heat-map per class and, optionally, one heat-map per uncertainty
    quantity. Predictions are reordered by their (z, y, x) strided position
    before being reshaped to a 3D grid.

    Args:
        polycrystal_file: Forwarded to ``calc_polycrystal_desc`` and used as
            prefix of the generated dataset name.
        descriptor: Forwarded to ``calc_polycrystal_desc``.
        desc_metadata: Forwarded to ``prepare_dataset``.
        configs: Configuration mapping. ``configs['io']`` must provide
            ``dataset_folder``, ``main_folder``, ``desc_folder`` and
            ``tmp_folder``.
        checkpoint_dir: Forwarded to ``predict``.
        checkpoint_filename: Forwarded to ``predict``.
        operations_on_structure: Forwarded to ``calc_polycrystal_desc``.
        stride_size: Three-element sequence; forwarded to
            ``calc_polycrystal_desc`` and embedded in the dataset name.
        box_size: Forwarded to ``calc_polycrystal_desc`` and embedded in the
            dataset name.
        init_sliding_volume: Forwarded to ``calc_polycrystal_desc``.
        desc_file: Precomputed descriptor file. When ``None`` the descriptor
            is calculated with ``calc_polycrystal_desc``.
        desc_only: When true, stop after the descriptor step (no dataset,
            prediction or plots).
        show_plot_lengths: Forwarded to ``calc_polycrystal_desc``.
        calc_uncertainty: When true, additionally plot one heat-map for each
            key of the uncertainty mapping returned by ``predict``.
        mc_samples: Forwarded to ``predict``.
        desc_file_suffix_name: Forwarded to ``calc_polycrystal_desc`` and
            embedded in the dataset name.
        nb_jobs: Forwarded to ``calc_polycrystal_desc``.
        interpolation: ``interpolation`` argument of the probability plots.
        results_file: Forwarded to ``predict``.
        conf_matrix_file: Forwarded to ``predict``.
        train_set_name: Base name of the training dataset; the files
            ``<name>_x.pkl``, ``<name>_y.pkl`` and ``<name>_summary.json``
            are read from ``configs['io']['dataset_folder']``.
        padding_ratio: Forwarded to ``calc_polycrystal_desc``.
        cmap_uncertainty: Colormap of the uncertainty plots.
        interpolation_uncertainty: ``interpolation`` argument of the
            uncertainty plots.

    Returns:
        None. Results are produced as side effects (files written and plots
        created by the called helpers).
    """
    if desc_file is None:
        logger.info("Calculating system's representation.")
        desc_file = calc_polycrystal_desc(
            polycrystal_file,
            stride_size,
            box_size,
            descriptor,
            configs,
            desc_file_suffix_name,
            operations_on_structure,
            nb_jobs,
            show_plot_lengths,
            padding_ratio=padding_ratio,
            init_sliding_volume=init_sliding_volume)
    else:
        logger.info("Using the precomputed user-specified descriptor file.")

    if not desc_only:
        target_list, structure_list = load_descriptor(desc_files=desc_file,
                                                      configs=configs)

        # create dataset
        dataset_name = '{0}_stride_{1}_{2}_{3}_box_size_{4}_{5}.tar.gz'.format(
            polycrystal_file, stride_size[0], stride_size[1], stride_size[2],
            box_size, desc_file_suffix_name)

        # NOTE(review): input_dims is hard-coded to (52, 32) - confirm it
        # matches the descriptor in use.
        path_to_x_test, path_to_y_test, path_to_summary_test = prepare_dataset(
            structure_list=structure_list,
            target_list=target_list,
            desc_metadata=desc_metadata,
            dataset_name=dataset_name,
            target_name='target',
            target_categorical=True,
            input_dims=(52, 32),
            configs=configs,
            dataset_folder=configs['io']['dataset_folder'],
            main_folder=configs['io']['main_folder'],
            desc_folder=configs['io']['desc_folder'],
            tmp_folder=configs['io']['tmp_folder'])

        dataset_folder = configs['io']['dataset_folder']
        path_to_x_train = os.path.join(dataset_folder,
                                       train_set_name + '_x.pkl')
        path_to_y_train = os.path.join(dataset_folder,
                                       train_set_name + '_y.pkl')
        path_to_summary_train = os.path.join(dataset_folder,
                                             train_set_name + '_summary.json')

        # Only the summary of the training set is used below (number of
        # classes and class list); the training arrays are discarded.
        _, _, dataset_info_train = load_dataset_from_file(
            path_to_x=path_to_x_train,
            path_to_y=path_to_y_train,
            path_to_summary=path_to_summary_train)

        x_test, y_test, dataset_info_test = load_dataset_from_file(
            path_to_x=path_to_x_test,
            path_to_y=path_to_y_test,
            path_to_summary=path_to_summary_test)

        params_cnn = {
            "nb_classes": dataset_info_train["data"][0]["nb_classes"],
            "classes": dataset_info_train["data"][0]["classes"],
            "batch_size": 32,
            "img_channels": 1
        }

        text_labels = np.asarray(dataset_info_test["data"][0]["text_labels"])
        numerical_labels = np.asarray(
            dataset_info_test["data"][0]["numerical_labels"])

        # The test arrays are passed both as train/val pool and as test set;
        # no split is requested, so only prediction data is involved.
        data_set_predict = make_data_sets(x_train_val=x_test,
                                          y_train_val=y_test,
                                          split_train_val=False,
                                          test_size=0.1,
                                          x_test=x_test,
                                          y_test=y_test)

        # NOTE(review): this call expects predict() to return a 5-tuple -
        # confirm against the predict() implementation in use.
        target_pred_class, target_pred_probs, prob_predictions, conf_matrix, uncertainty = predict(
            data_set_predict,
            params_cnn["nb_classes"],
            configs=configs,
            batch_size=params_cnn["batch_size"],
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename=checkpoint_filename,
            show_model_acc=False,
            mc_samples=mc_samples,
            predict_probabilities=True,
            plot_conf_matrix=True,
            conf_matrix_file=conf_matrix_file,
            numerical_labels=numerical_labels,
            text_labels=text_labels,
            results_file=results_file,
            normalize=True)

        # get the number of strides in each direction in order to reshape
        # properly
        class_plot_pos = np.asarray([
            structure.info['strided_pattern_positions']
            for structure in structure_list
        ])
        (z_max, y_max, x_max) = np.amax(class_plot_pos, axis=0) + 1

        # make a dataframe to order the prob_predictions
        # this is needed when we read from file - the structures are ordered
        # in a different way after they are saved
        # this comes into play only if more than 10 values for each direction
        # are used
        df_positions = pd.DataFrame(data=class_plot_pos,
                                    columns=[
                                        'strided_pattern_positions_z',
                                        'strided_pattern_positions_y',
                                        'strided_pattern_positions_x'
                                    ])

        # sort predictive mean by grid position
        predictive_mean_sorted = _sort_by_grid_position(
            df_positions, pd.DataFrame(data=prob_predictions)).values

        for idx_class in range(predictive_mean_sorted.shape[1]):
            prob_prediction_class = predictive_mean_sorted[:, idx_class].reshape(
                z_max, y_max, x_max)

            plot_prediction_heatmaps(prob_prediction_class,
                                     title='Probability',
                                     class_name=str(idx_class),
                                     prefix='prob',
                                     main_folder=configs['io']['main_folder'],
                                     cmap='viridis',
                                     interpolation=interpolation)

        if calc_uncertainty:
            df_uncertainty = pd.DataFrame()
            for key in uncertainty:
                df_uncertainty[key] = uncertainty[key]

            uncertainty_sorted = _sort_by_grid_position(df_positions,
                                                        df_uncertainty)

            for key in uncertainty:
                uncertainty_prediction = uncertainty_sorted[
                    key].values.reshape(z_max, y_max, x_max)

                plot_prediction_heatmaps(
                    uncertainty_prediction,
                    title='Uncertainty ({})'.format(str(key)),
                    main_folder=configs['io']['main_folder'],
                    cmap=cmap_uncertainty,
                    prefix='uncertainty',
                    suffix=str(key),
                    interpolation=interpolation_uncertainty)