Example #1
    def test_get_design_matrix(self):
        # load data with pre-calculated PRDFs (partial radial distribution functions)
        target_list, structure_list = load_descriptor(
            desc_files=self.prdf_binaries_desc_file, configs=self.configs)

        # calculate design matrix
        design_matrix = get_design_matrix(structure_list,
                                          total_bins=50,
                                          max_dist=25)

        self.assertIsInstance(design_matrix, scipy.sparse.csr_matrix)
        self.assertNotIsInstance(design_matrix, np.ndarray)
        self.assertEqual(design_matrix.shape[0], len(structure_list))
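
# The test above pins down the contract of get_design_matrix: the result is a
# SciPy CSR sparse matrix (not a NumPy array) with one row per structure.
# A minimal standalone sketch of that contract, using a synthetic matrix
# rather than one produced by get_design_matrix:

import numpy as np
import scipy.sparse

# synthetic stand-in for a design matrix: 3 structures x 50 bins, mostly zeros
dense = np.zeros((3, 50))
dense[0, 4] = 1.2
dense[2, 17] = 0.7
design_matrix = scipy.sparse.csr_matrix(dense)

assert isinstance(design_matrix, scipy.sparse.csr_matrix)
assert not isinstance(design_matrix, np.ndarray)
assert design_matrix.shape[0] == 3       # one row per structure
print(design_matrix.nnz)                 # stored non-zeros: 2
print(design_matrix.toarray().shape)     # densify only for inspection: (3, 50)
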
    ase_atoms_list = read_ase_db(db_path=ase_db_file)

    desc_file_path = calc_descriptor(
        descriptor=descriptor,
        configs=configs,
        ase_atoms_list=ase_atoms_list,
        tmp_folder=tmp_folder,
        desc_folder=desc_folder,
        desc_info_file=desc_info_file,
        desc_file=str(desc_file_name) + '.tar.gz',
        format_geometry='aims',
        operations_on_structure=operations_on_structure_list[1],
        nb_jobs=-1)

    # use a precomputed descriptor file instead of the one just calculated
    desc_file_path = '/home/ziletti/Documents/nomadml_docs/desc_folder/try1.tar.gz'
    target_list, structure_list = load_descriptor(desc_files=desc_file_path,
                                                  configs=configs)

    ase_db_file = write_ase_db(ase_atoms_list=structure_list,
                               db_name='elemental_solids_ncomms_7_classes_new',
                               main_folder=main_folder,
                               folder_name='db_ase')

    df, sprite_atlas = generate_facets_input(
        structure_list=structure_list,
        desc_metadata='diffraction_2d_intensity')

    prefix_file = "/home/ziletti/Documents/calc_xray/2d_nature_comm/bcc_to_sc/"
    suffix_file = "_bcc_to_sc.json.tar.gz"
    # round to two decimals to avoid float artifacts (e.g. 0.15000000000000002) in file names
    target_b_contributions = np.linspace(0.0, 1.0, num=21).round(2).tolist()
    bcc_to_sc_list = []
    for target_b_contrib in target_b_contributions:
        bcc_to_sc_list.append(prefix_file + str(target_b_contrib) +
                              suffix_file)

    desc_file_path = []
    desc_file_7_classes = [
        '/home/ziletti/Documents/calc_xray/2d_nature_comm/desc_folder/7_classes.tar.gz'
    ]
    desc_file_path.extend(desc_file_7_classes)
    desc_file_path.extend(bcc_to_sc_list)

    target_list, ase_atoms_list = load_descriptor(desc_files=desc_file_path,
                                                  configs=configs)

    new_labels = {
        "bct_139": ["139"],
        "bct_141": ["141"],
        "hex/rh": ["166", "194"],
        "sc": ["221"],
        "fcc": ["225"],
        "diam": ["227"],
        "bcc": ["229"]
    }

    path_to_x_train, path_to_y_train, path_to_summary_train = prepare_dataset(
        structure_list=ase_atoms_list,
        target_list=target_list,
        desc_metadata='diffraction_2d_intensity',
        configs=configs)
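
# The new_labels dictionary above maps each class name to the space-group
# numbers it covers. A small self-contained sketch of how such a mapping can
# be inverted to relabel per-structure space-group targets; the
# spacegroup_targets list is made up for illustration:

new_labels = {
    "bct_139": ["139"], "bct_141": ["141"], "hex/rh": ["166", "194"],
    "sc": ["221"], "fcc": ["225"], "diam": ["227"], "bcc": ["229"]
}

# invert the mapping: space-group number -> class name
spacegroup_to_class = {sg: name for name, sgs in new_labels.items() for sg in sgs}

spacegroup_targets = ["229", "225", "194"]   # hypothetical raw targets
class_targets = [spacegroup_to_class[sg] for sg in spacegroup_targets]
print(class_targets)   # ['bcc', 'fcc', 'hex/rh']
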
Example #4
def make_strided_pattern_matching_dataset(polycrystal_file,
                                          descriptor,
                                          desc_metadata,
                                          configs,
                                          operations_on_structure=None,
                                          stride_size=(4., 4., 4.),
                                          box_size=12.0,
                                          init_sliding_volume=(14., 14., 14.),
                                          desc_file=None,
                                          desc_only=False,
                                          show_plot_lengths=True,
                                          desc_file_suffix_name='_pristine',
                                          nb_jobs=-1,
                                          padding_ratio=None,
                                          min_nb_atoms=20):
    if desc_file is None:
        logger.info("Calculating system's representation.")
        if nb_jobs == 1:
            # serial path: build the boxed structures and compute the
            # descriptor in the current process
            ase_atoms_list = get_structures_by_boxes(
                polycrystal_file,
                stride_size=stride_size,
                box_size=box_size,
                show_plot_lengths=show_plot_lengths,
                padding_ratio=padding_ratio,
                init_sliding_volume=init_sliding_volume)

            # deferred import of the serial descriptor-calculation helper
            from ai4materials.wrappers import _calc_descriptor
            desc_file = _calc_descriptor(
                ase_atoms_list=ase_atoms_list,
                descriptor=descriptor,
                configs=configs,
                logger=logger,
                desc_folder=configs['io']['desc_folder'],
                tmp_folder=configs['io']['tmp_folder'])
        else:
            # parallel path: delegate to calc_polycrystal_desc,
            # which distributes the work over nb_jobs processes
            desc_file = calc_polycrystal_desc(
                polycrystal_file,
                stride_size,
                box_size,
                descriptor,
                configs,
                desc_file_suffix_name,
                operations_on_structure,
                nb_jobs,
                show_plot_lengths,
                padding_ratio=padding_ratio,
                init_sliding_volume=init_sliding_volume)
    else:
        logger.info("Using the precomputed user-specified descriptor file.")

    # initialize the return values so they are defined even when desc_only=True
    path_to_x_test = path_to_y_test = path_to_summary_test = None
    path_to_strided_pattern_pos = None

    if not desc_only:
        target_list, structure_list = load_descriptor(desc_files=desc_file,
                                                      configs=configs)

        # create dataset
        polycrystal_name = os.path.basename(polycrystal_file)
        dataset_name = '{0}_stride_{1}_{2}_{3}_box_size_{4}_{5}.tar.gz'.format(
            polycrystal_name, stride_size[0], stride_size[1], stride_size[2],
            box_size, desc_file_suffix_name)

        # if the total number of atoms is below the cutoff, set the descriptor to NaN
        for structure in structure_list:
            # TODO: decide whether the cutoff should be strict (<) or inclusive (<=)
            if structure.get_number_of_atoms() <= min_nb_atoms:
                structure.info['descriptor'][desc_metadata][:] = np.nan

        path_to_x_test, path_to_y_test, path_to_summary_test = prepare_dataset(
            structure_list=structure_list,
            target_list=target_list,
            desc_metadata=desc_metadata,
            dataset_name=dataset_name,
            target_name='target',
            target_categorical=True,
            input_dims=(52, 32),
            configs=configs,
            dataset_folder=configs['io']['dataset_folder'],
            main_folder=configs['io']['main_folder'],
            desc_folder=configs['io']['desc_folder'],
            tmp_folder=configs['io']['tmp_folder'])

        # get the number of strides in each direction in order to reshape properly
        strided_pattern_pos = []
        for structure in structure_list:
            strided_pattern_pos.append(
                structure.info['strided_pattern_positions'])

        strided_pattern_pos = np.asarray(strided_pattern_pos)

        path_to_strided_pattern_pos = os.path.abspath(
            os.path.normpath(
                os.path.join(
                    configs['io']['dataset_folder'],
                    '{0}_strided_pattern_pos.pkl'.format(dataset_name))))

        # write to file
        with open(path_to_strided_pattern_pos, 'wb') as output:
            pickle.dump(strided_pattern_pos, output, pickle.HIGHEST_PROTOCOL)
            logger.info("Writing strided pattern positions to {0}".format(
                path_to_strided_pattern_pos))

        logger.info("Dataset created at {}".format(
            configs['io']['dataset_folder']))
        logger.info("Strided pattern positions saved at {}".format(
            configs['io']['dataset_folder']))

    return path_to_x_test, path_to_y_test, path_to_summary_test, path_to_strided_pattern_pos
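
# make_strided_pattern_matching_dataset pickles the array of strided-pattern
# positions next to the dataset. A minimal sketch of how a consumer could read
# it back and recover the stride-grid shape; the path is a placeholder and the
# array is synthetic:

import pickle
import numpy as np

# synthetic stand-in: positions on a 2 x 3 x 4 stride grid
grid = np.array([(z, y, x) for z in range(2) for y in range(3) for x in range(4)])

path = '/tmp/example_strided_pattern_pos.pkl'   # placeholder path
with open(path, 'wb') as output:
    pickle.dump(grid, output, pickle.HIGHEST_PROTOCOL)

with open(path, 'rb') as handle:
    strided_pattern_pos = pickle.load(handle)

# same recipe used in the function: max index along each axis + 1 = grid shape
(z_max, y_max, x_max) = np.amax(strided_pattern_pos, axis=0) + 1
print(z_max, y_max, x_max)   # 2 3 4
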
def get_classification_map(polycrystal_file,
                           descriptor,
                           desc_metadata,
                           configs,
                           checkpoint_dir,
                           checkpoint_filename,
                           operations_on_structure=None,
                           stride_size=(4., 4., 4.),
                           box_size=12.0,
                           init_sliding_volume=(14., 14., 14.),
                           desc_file=None,
                           desc_only=False,
                           show_plot_lengths=True,
                           calc_uncertainty=True,
                           mc_samples=10,
                           desc_file_suffix_name='_pristine',
                           nb_jobs=-1,
                           interpolation='none',
                           results_file=None,
                           conf_matrix_file=None,
                           train_set_name='hcp-bcc-sc-diam-fcc-pristine',
                           padding_ratio=None,
                           cmap_uncertainty='hot',
                           interpolation_uncertainty='none'):
    if desc_file is None:
        logger.info("Calculating system's representation.")
        desc_file = calc_polycrystal_desc(
            polycrystal_file,
            stride_size,
            box_size,
            descriptor,
            configs,
            desc_file_suffix_name,
            operations_on_structure,
            nb_jobs,
            show_plot_lengths,
            padding_ratio=padding_ratio,
            init_sliding_volume=init_sliding_volume)
    else:
        logger.info("Using the precomputed user-specified descriptor file.")

    if not desc_only:
        target_list, structure_list = load_descriptor(desc_files=desc_file,
                                                      configs=configs)

        # create dataset (use the basename so the full path does not end up in the name)
        polycrystal_name = os.path.basename(polycrystal_file)
        dataset_name = '{0}_stride_{1}_{2}_{3}_box_size_{4}_{5}.tar.gz'.format(
            polycrystal_name, stride_size[0], stride_size[1], stride_size[2],
            box_size, desc_file_suffix_name)

        path_to_x_test, path_to_y_test, path_to_summary_test = prepare_dataset(
            structure_list=structure_list,
            target_list=target_list,
            desc_metadata=desc_metadata,
            dataset_name=dataset_name,
            target_name='target',
            target_categorical=True,
            input_dims=(52, 32),
            configs=configs,
            dataset_folder=configs['io']['dataset_folder'],
            main_folder=configs['io']['main_folder'],
            desc_folder=configs['io']['desc_folder'],
            tmp_folder=configs['io']['tmp_folder'])

        path_to_x_train = os.path.join(configs['io']['dataset_folder'],
                                       train_set_name + '_x.pkl')
        path_to_y_train = os.path.join(configs['io']['dataset_folder'],
                                       train_set_name + '_y.pkl')
        path_to_summary_train = os.path.join(configs['io']['dataset_folder'],
                                             train_set_name + '_summary.json')

        x_train, y_train, dataset_info_train = load_dataset_from_file(
            path_to_x=path_to_x_train,
            path_to_y=path_to_y_train,
            path_to_summary=path_to_summary_train)

        x_test, y_test, dataset_info_test = load_dataset_from_file(
            path_to_x=path_to_x_test,
            path_to_y=path_to_y_test,
            path_to_summary=path_to_summary_test)

        params_cnn = {
            "nb_classes": dataset_info_train["data"][0]["nb_classes"],
            "classes": dataset_info_train["data"][0]["classes"],
            "batch_size": 32,
            "img_channels": 1
        }

        text_labels = np.asarray(dataset_info_test["data"][0]["text_labels"])
        numerical_labels = np.asarray(
            dataset_info_test["data"][0]["numerical_labels"])

        # wrap the test data in a dataset object used only for prediction
        data_set_predict = make_data_sets(x_train_val=x_test,
                                          y_train_val=y_test,
                                          split_train_val=False,
                                          test_size=0.1,
                                          x_test=x_test,
                                          y_test=y_test)

        target_pred_class, target_pred_probs, prob_predictions, conf_matrix, uncertainty = predict(
            data_set_predict,
            params_cnn["nb_classes"],
            configs=configs,
            batch_size=params_cnn["batch_size"],
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename=checkpoint_filename,
            show_model_acc=False,
            mc_samples=mc_samples,
            predict_probabilities=True,
            plot_conf_matrix=True,
            conf_matrix_file=conf_matrix_file,
            numerical_labels=numerical_labels,
            text_labels=text_labels,
            results_file=results_file,
            normalize=True)

        predictive_mean = prob_predictions

        # get the number of strides in each direction in order to reshape properly
        strided_pattern_positions = []
        for structure in structure_list:
            strided_pattern_positions.append(
                structure.info['strided_pattern_positions'])

        class_plot_pos = np.asarray(strided_pattern_positions)
        (z_max, y_max, x_max) = np.amax(class_plot_pos, axis=0) + 1

        # build a dataframe to order the probability predictions;
        # this is needed when reading from file, because the structures are
        # re-ordered after saving (it matters only if more than 10 strides
        # are used along any direction)
        df_positions = pd.DataFrame(data=class_plot_pos,
                                    columns=[
                                        'strided_pattern_positions_z',
                                        'strided_pattern_positions_y',
                                        'strided_pattern_positions_x'
                                    ])

        # sort predictive mean
        df_predictive_mean = pd.DataFrame(data=predictive_mean)
        # concatenate column-wise; both frames share the same default index
        # (pandas removed the join_axes keyword in 1.0)
        df = pd.concat([df_positions, df_predictive_mean], axis=1)
        df_predictive_mean_sorted = df.sort_values(
            ['strided_pattern_positions_z', 'strided_pattern_positions_y',
             'strided_pattern_positions_x'], ascending=True)

        predictive_mean_sorted = df_predictive_mean_sorted.drop(columns=[
            'strided_pattern_positions_z', 'strided_pattern_positions_y',
            'strided_pattern_positions_x'
        ]).values

        for idx_class in range(predictive_mean_sorted.shape[1]):
            prob_prediction_class = predictive_mean_sorted[:, idx_class].reshape(
                z_max, y_max, x_max)

            plot_prediction_heatmaps(prob_prediction_class,
                                     title='Probability',
                                     class_name=str(idx_class),
                                     prefix='prob',
                                     main_folder=configs['io']['main_folder'],
                                     cmap='viridis',
                                     interpolation=interpolation)

        if calc_uncertainty:
            df_uncertainty = pd.DataFrame()
            for key in uncertainty.keys():
                df_uncertainty[key] = uncertainty[key]

            df = pd.concat([df_positions, df_uncertainty], axis=1)
            df_uncertainty_sorted = df.sort_values(
                ['strided_pattern_positions_z', 'strided_pattern_positions_y',
                 'strided_pattern_positions_x'], ascending=True)

            uncertainty_sorted = df_uncertainty_sorted.drop(columns=[
                'strided_pattern_positions_z', 'strided_pattern_positions_y',
                'strided_pattern_positions_x'
            ])

            for key in uncertainty.keys():
                uncertainty_prediction = uncertainty_sorted[
                    key].values.reshape(z_max, y_max, x_max)

                plot_prediction_heatmaps(
                    uncertainty_prediction,
                    title='Uncertainty ({})'.format(str(key)),
                    main_folder=configs['io']['main_folder'],
                    cmap=cmap_uncertainty,
                    prefix='uncertainty',
                    suffix=str(key),
                    interpolation=interpolation_uncertainty)
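
# The reorder-and-reshape step in get_classification_map is what turns the
# per-box predictions back into a 3D map: predictions arrive in file order and
# must be sorted by stride position before reshaping. A self-contained sketch
# of that pattern with synthetic positions and probabilities
# (plot_prediction_heatmaps is not involved):

import numpy as np
import pandas as pd

# synthetic stride positions in scrambled order on a 2 x 2 x 2 grid
positions = np.array([[1, 1, 1], [0, 0, 0], [1, 0, 1], [0, 1, 0],
                      [1, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 1]])
probs = np.random.rand(8, 2)   # two classes

cols = ['strided_pattern_positions_z', 'strided_pattern_positions_y',
        'strided_pattern_positions_x']
df = pd.concat([pd.DataFrame(positions, columns=cols),
                pd.DataFrame(probs)], axis=1)
df_sorted = df.sort_values(cols, ascending=True)

(z_max, y_max, x_max) = np.amax(positions, axis=0) + 1
class_0_map = df_sorted.drop(columns=cols).values[:, 0].reshape(z_max, y_max, x_max)
print(class_0_map.shape)   # (2, 2, 2)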