def test_upload_pif():
    client = CitrinationClient(environ['CITRINATION_API_KEY'], environ['CITRINATION_SITE'])
    dataset = loads(client.create_data_set(name="Tutorial dataset", description="Dataset for tutorial", share=0).content.decode('utf-8'))['id']
    pif = System()
    pif.id = 0

    with open("tmp.json", "w") as fp:
        dump(pif, fp)
    response = loads(client.upload_file("tmp.json", dataset))
    assert response["message"] == "Upload is complete."
Example #2

def test_upload_pif():
    client = CitrinationClient(environ['CITRINATION_API_KEY'], 'https://stage.citrination.com')
    dataset = loads(client.create_data_set(name="Tutorial dataset", description="Dataset for tutorial", share=0).content.decode('utf-8'))['id']
    pif = System()
    pif.id = 0

    with TemporaryDirectory() as tmpdir:
        tempname = join(tmpdir, "pif.json")
        with open(tempname, "w") as fp:
            dump(pif, fp)
        response = loads(client.upload_file(tempname, dataset))
    assert response["message"] == "Upload is complete."
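Both test snippets above assume imports that are not shown in the excerpt; a minimal reconstruction of what they would need, based on the calls used (the json and pypif helpers are assumptions):

# Assumed imports for the two test_upload_pif variants above (reconstruction, not shown in the source)
from os import environ
from os.path import join
from json import loads
from tempfile import TemporaryDirectory
from citrination_client import CitrinationClient
from pypif.obj import System
from pypif.pif import dump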
Example #3
def main():
    # Create an instance of the client
    client = CitrinationClient(os.environ["CITRINATION_API_KEY"],
                               'https://citrination.com')

    # Name of the file storing the data you would like to upload or predict:
    filename = "filename"
    # Number of the Dataset you would like to upload to (found in URL):
    dataset_id = 0000
    # Number of the DataView you would like to use for predictions (found in URL):
    dataview_id = 0000

    # Upload data
    upload(client, filename, dataset_id)

    # Retrieve predictions
    predict(client, filename, dataview_id)
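The upload and predict helpers called above are not part of this excerpt. A hedged sketch of what they might look like, assuming the client's data and models sub-clients; the names, return-value handling, and print statements here are illustrative, not the original author's code:

import csv

def upload(client, filename, dataset_id):
    # Upload the file to the dataset and report whether it was accepted.
    result = client.data.upload(dataset_id, filename)
    print("Upload successful" if result.successful() else "Upload failed")

def predict(client, filename, dataview_id):
    # Read candidate rows from the CSV and request predictions from the data view.
    with open(filename) as f:
        candidates = list(csv.DictReader(f))
    predictions = client.models.predict(str(dataview_id), candidates)
    for prediction in predictions:
        for key in prediction.all_keys():
            print(key, prediction.get_value(key).value)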
Example #4
    def __init__(self, api_key=None):
        """
        Args:
            api_key: (str) Your Citrine API key, or None if
                you've set the CITRINATION_API_KEY or CITRINE_KEY environment variable
        """
        if api_key:
            api_key = api_key
        elif "CITRINATION_API_KEY" in os.environ:
            api_key = os.environ["CITRINATION_API_KEY"]
        elif "CITRINE_KEY" in os.environ:
            api_key = os.environ["CITRINE_KEY"]
        else:
            raise AttributeError('''Citrine API key not found.

            You need to get an API key from Citrination, and either supply it as an argument to
            this class or set it as the value of the CITRINATION_API_KEY environment variable

            See https://citrineinformatics.github.io/api-documentation/quickstart/index.html
            for details on how to get an API key''')

        self.client = CitrinationClient(api_key, "https://citrination.com")
Example #5

def main():
	start_time = time.time()
	# Create instance of the client
	client = CitrinationClient(environ["CITRINATION_API_KEY"], 'https://citrination.com')
	# dataset_id is the id of the dataset you want to upload new data to
	dataset_id = 161880
	# dataview_id is the DataView you want to update after file upload 
	dataview_id = 4815
	# url is the url of the dataview we'd like to click save on.
	url = 'https://citrination.com/data_views/' + str(dataview_id) + '/ml_config'
	# Read in new data. For test we are using just chemical formula, DFT energy, and only predicting Tg
	filename = "predict_data.csv"
	
	# Read in csv file using pandas to make df
	new_data = pd.read_csv(filename)
	# Pull out the desired properties
	form = new_data['formula']
	energy = new_data['PROPERTY: Nearest DFT Formation Energy (eV)']
	tg = new_data['PROPERTY: Tg (K)']
	tl = new_data['PROPERTY: Tl (K)']
	tx = new_data['PROPERTY: Tx (K)']
	
	# Use itertools to create combinations of the data set to leave out/use for training
	# Make a list of the indexes of the dataset
	indices = list(range(len(form)))
	
	# Generate all combos of the indices of the extra data to use for training/testing
	combos = []
	# Set how many combinations to randomly pull for each size of subset in the power set
	num_rand_select = 4
	for i in range(0, len(indices)):
		combos = [list(x) for x in itertools.combinations(indices, i)]
		
		# Loop through each combo of the indices for training
		for c in range(0, num_rand_select):
			one_combo = randrange(0, len(combos))			
			# For each index in the chosen training combination, collect one row of data
			training_rows = []
			for i in combos[one_combo]:
				training_rows.append([form[i], energy[i], tg[i], tl[i], tx[i]])

			# Write to CSV and pass the csv file path to make_pif
			with open("training_data.csv", 'w', newline='') as training_csv:
				writer = csv.writer(training_csv)
				writer.writerow(['formula', 'PROPERTY: Nearest DFT Formation Energy (eV)', 'PROPERTY: Tg (K)', 'PROPERTY: Tl (K)', 'PROPERTY: Tx (K)'])
				writer.writerows(training_rows)
			
			pif_output = make_pif("training_data.csv")
		
			# Upload data. Params are (dataset id, file path)
			client.data.upload(dataset_id, pif_output)
			time.sleep(10)
			
			click_save(url)
			
			# Wait for model to retrain
			model_report_url = 'https://citrination.com/data_views/' + str(dataview_id) + '/data_summary'
			wait_for_train(model_report_url)
			
			# Make predictions and store them to a CSV
			# Make copies of the lists of all values so we can remove the training rows
			pred_form = form.tolist()
			pred_energy = energy.tolist()
			pred_tg = tg.tolist()
			# Remove the training indices (highest first so earlier positions stay valid)
			for i in sorted(combos[one_combo], reverse=True):
				del pred_form[i]
				del pred_energy[i]
				del pred_tg[i]
			
			# Write the formula and energy of the points to predict to a CSV
			predict_data_file = "testing_data.csv"
			with open(predict_data_file, 'w', newline='') as testing_csv:
				writer = csv.writer(testing_csv)
				writer.writerow(['formula', 'PROPERTY: Nearest DFT Formation Energy (eV)'])
				writer.writerows(zip(pred_form, pred_energy))
			
			# Predict the Tg, Tl, Tx of these data points
			# Retry on failure (e.g. a missing 'candidates' key) up to 10 times
			err_count = 0
			while True:
				try:
					make_predictions(client, predict_data_file, str(dataview_id))
					break
				except Exception:
					err_count += 1
					if err_count >= 10:
						break

			if err_count > 0:
				print("Errors encountered with set " + str(c))
	
	print("Run time: " + str(time.time() - start_time))
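Several helpers used above (make_pif, click_save, wait_for_train, make_predictions) are defined elsewhere and are not part of this excerpt. As one hedged illustration, make_pif could convert the training CSV into a PIF JSON file and return its path, which is what client.data.upload() receives above; this is a sketch under those assumptions, not the original implementation:

def make_pif(csv_path):
	# Sketch: turn each CSV row into a pypif ChemicalSystem with one Property per
	# non-formula column, dump the systems to JSON, and return the JSON path.
	import csv
	from pypif import pif
	from pypif.obj import ChemicalSystem, Property
	systems = []
	with open(csv_path) as f:
		for row in csv.DictReader(f):
			system = ChemicalSystem(chemical_formula=row['formula'])
			props = []
			for name, value in row.items():
				if name != 'formula':
					props.append(Property(name=name, scalars=value))
			system.properties = props
			systems.append(system)
	json_path = csv_path.replace('.csv', '.json')
	with open(json_path, 'w') as out:
		pif.dump(systems, out)
	return json_path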
Example #6
from citrination_client import CitrinationClient
from os import environ

client = CitrinationClient(environ["CITRINATION_API_KEY"],
                           environ["CITRINATION_SITE"])
data_views_client = client.data_views
Example #7
class CitrineDataRetrieval:
    def __init__(self, api_key=None):
        """
        Args:
            api_key: (str) Your Citrine API key, or None if you've set the CITRINE_KEY environment variable

        Returns: None
        """
        api_key = api_key if api_key else os.environ['CITRINE_KEY']
        self.client = CitrinationClient(api_key, 'https://citrination.com')

    def get_dataframe(self,
                      term=None,
                      formula=None,
                      property=None,
                      contributor=None,
                      reference=None,
                      min_measurement=None,
                      max_measurement=None,
                      from_record=None,
                      data_set_id=None,
                      max_results=None,
                      show_columns=None):
        """
        Gets data from Citrine in a dataframe format.
        See client docs at http://citrineinformatics.github.io/api-documentation/ for more details on these parameters.

        Args:
            term: (str) general search string; this is searched against all fields
            formula: (str) filter for the chemical formula field; only those results that have chemical formulas that
                contain this string will be returned
            property: (str) name of the property to search for
            contributor: (str) filter for the contributor field; only those results that have contributors that
                contain this string will be returned
            reference: (str) filter for the reference field; only those results that have references that
                contain this string will be returned
            min_measurement: (str/num) minimum of the property value range
            max_measurement: (str/num) maximum of the property value range
            from_record: (int) index of the first record to return (indexed from 0)
            data_set_id: (int) id of the particular data set to search on
            max_results: (int) number of records to limit the results to
            show_columns: (list) list of columns to show from the resulting dataframe

        Returns: (object) Pandas dataframe object containing the results
        """

        json_data = []
        start = from_record if from_record else 0
        per_page = 100
        refresh_time = 3  # seconds to wait between search calls

        while True:
            if max_results and max_results < per_page:  # use per_page=max_results, eg: in case of max_results=68 < 100
                data = self.client.search(term=term,
                                          formula=formula,
                                          property=property,
                                          contributor=contributor,
                                          reference=reference,
                                          min_measurement=min_measurement,
                                          max_measurement=max_measurement,
                                          from_record=start,
                                          per_page=max_results,
                                          data_set_id=data_set_id)
            else:
                data = self.client.search(term=term,
                                          formula=formula,
                                          property=property,
                                          contributor=contributor,
                                          reference=reference,
                                          min_measurement=min_measurement,
                                          max_measurement=max_measurement,
                                          from_record=start,
                                          per_page=per_page,
                                          data_set_id=data_set_id)
            size = len(data.json()['results'])
            start += size
            json_data.append(data.json()['results'])
            if max_results and len(json_data) * per_page > max_results:  # check if limit is reached
                # keep only whole pages up to the limit (integer division for Python 3)
                json_data = json_data[:(max_results // per_page)]
                json_data.append(
                    data.json()['results'][:max_results %
                                           per_page])  # get remaining records
                break
            if size < per_page:  # break out of last loop of results
                break
            time.sleep(refresh_time)

        non_meas_df = pd.DataFrame()  # df w/o measurement column
        meas_df = pd.DataFrame()  # df containing only measurement column

        counter = 0  # variable to keep count of sample hit and set indexes

        for page in json_data:
            # df = pd.concat((json_normalize(hit) for hit in set))   # Useful tool for the future
            for hit in tqdm(page):
                counter += 1
                if 'sample' in hit.keys():
                    sample_value = hit['sample']
                    sample_normdf = json_normalize(sample_value)
                    # Make a DF of all non-'measurement' fields
                    non_meas_cols = [
                        cols for cols in sample_normdf.columns
                        if "measurement" not in cols
                    ]
                    non_meas_row = pd.DataFrame()
                    for col in non_meas_cols:
                        non_meas_row[col] = sample_normdf[col]
                    non_meas_row.index = [counter] * len(sample_normdf)
                    non_meas_df = non_meas_df.append(non_meas_row)
                    # Make a DF of the 'measurement' array
                    if 'measurement' in sample_value:
                        meas_normdf = json_normalize(
                            sample_value['measurement'])
                        # Extract numbers of properties
                        if 'property.scalar' in meas_normdf.columns:
                            for row, col in enumerate(
                                    meas_normdf['property.scalar']):
                                for item in col:
                                    if 'value' in item:
                                        meas_normdf.xs(row)[
                                            'property.scalar'] = item['value']
                                    # TODO: ask Anubhav how to deal with these and rest of formats
                                    elif 'minimum' in item and 'maximum' in item:
                                        meas_normdf.xs(
                                            row
                                        )['property.scalar'] = 'Minimum = ' + item[
                                            'minimum'] + ', ' + 'Maximum = ' + item[
                                                'maximum']
                        # Take all property rows and convert them into columns
                        prop_df = pd.DataFrame()
                        prop_cols = [
                            cols for cols in meas_normdf.columns
                            if "property" in cols
                        ]
                        for col in prop_cols:
                            prop_df[col] = meas_normdf[col]
                        prop_df.index = [counter] * len(meas_normdf)
                        prop_df = prop_df.drop_duplicates(['property.name'])
                        if 'property.scalar' in meas_normdf.columns:
                            prop_df = prop_df.pivot(columns='property.name',
                                                    values='property.scalar')
                        elif 'property.matrix' in meas_normdf.columns:
                            prop_df = prop_df.pivot(columns='property.name',
                                                    values='property.matrix')
                        prop_df = prop_df.convert_objects(
                            convert_numeric=True
                        )  # Convert columns from object to num
                        # Making a single row DF of non-'measurement.property' columns
                        non_prop_df = pd.DataFrame()
                        non_prop_cols = [
                            cols for cols in meas_normdf.columns
                            if "property" not in cols
                        ]
                        for col in non_prop_cols:
                            non_prop_df['measurement.' +
                                        col] = meas_normdf[col]
                        if len(non_prop_df) > 0:  # Do not index empty DF (non-'measurement.property' columns absent)
                            non_prop_df.index = [counter] * len(meas_normdf)
                        non_prop_df = non_prop_df[:
                                                  1]  # Take only first row - does not collect non-unique rows
                        units_df = pd.DataFrame(
                        )  # Get property unit and insert it as a dict
                        if 'property.units' in meas_normdf.columns:
                            curr_units = dict(
                                zip(meas_normdf['property.name'],
                                    meas_normdf['property.units']))
                            units_df['property.units'] = [curr_units]
                            units_df.index = [counter] * len(meas_normdf)
                        meas_df = meas_df.append(
                            pd.concat([prop_df, non_prop_df, units_df],
                                      axis=1))

        df = pd.concat([non_meas_df, meas_df], axis=1)
        df.index.name = 'sample'
        if show_columns:
            for column in df.columns:
                if column not in show_columns:
                    df.drop(column, axis=1, inplace=True)
        return df
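A hypothetical usage sketch for the class above; the API key, formula, property name, and column names are illustrative placeholders:

# Hypothetical usage of CitrineDataRetrieval.get_dataframe (values are illustrative)
cdr = CitrineDataRetrieval(api_key="your_api_key")
df = cdr.get_dataframe(formula="Si", property="Band gap", max_results=100,
                       show_columns=["chemical_formula", "Band gap"])
print(df.head())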
Example #8
def test_start_client():
    client = CitrinationClient(environ['CITRINATION_API_KEY'],
                               environ['CITRINATION_SITE'])
Example #9
# -*- coding: utf-8 -*-
"""
Created on Thu Jul  8 09:04:52 2021

@author: tquah
"""

from citrination_client import CitrinationClient
import os
import glob
apikey_path = '/home/tquah/.citrine_api_key'
with open(apikey_path, 'r') as op:
    apikey = op.read().split('\n')

client = CitrinationClient(apikey[0])

data_client = client.data

# file_path = "./test.json"
# dataset_id = 195077
# data_client.upload(dataset_id, file_path)

os.chdir('/home/tquah/Projects/TESTPHASES')

filelist = glob.glob('test*')

dataset_id = 51213104
for filename in filelist:
    data_client.upload(dataset_id, filename)
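After the uploads, one might wait for ingestion to finish before querying the dataset; a hedged sketch, assuming the data client exposes get_ingest_status (check this against your citrination_client version):

import time

# Poll until the dataset reports that ingestion has finished (method name assumed;
# verify against the installed citrination_client before relying on it)
while data_client.get_ingest_status(dataset_id) != 'Finished':
    time.sleep(5)
print('Dataset', dataset_id, 'finished ingesting')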
Example #10
class CitrineFeatureGeneration(object):
    """
    Class to generate new features using Citrine data and dataframe containing material compositions

    Attributes:
        configdict <dict> : MASTML configfile object as dict
        dataframe <pandas dataframe> : dataframe containing x and y data and feature names
        api_key <str> : your Citrination API key

    Methods:
        generate_citrine_features : generates Citrine feature set based on compositions in dataframe
            args:
                save_to_csv <bool> : whether to save the Citrine feature set to a csv file
            returns:
                dataframe <pandas dataframe> : dataframe containing the Citrine feature set
    """
    def __init__(self, configdict, dataframe, api_key):
        self.configdict = configdict
        self.dataframe = dataframe
        self.api_key = api_key
        self.client = CitrinationClient(api_key, 'https://citrination.com')

    @timeit
    def generate_citrine_features(self, save_to_csv=True):
        logging.info(
            'WARNING: You have specified generation of features from Citrine. Based on which materials you are '
            'interested in, there may be many records to parse through, thus this routine may take a long time to complete!'
        )
        compositions = self.dataframe['Material compositions'].tolist()
        citrine_dict_property_min = dict()
        citrine_dict_property_max = dict()
        citrine_dict_property_avg = dict()
        for composition in compositions:
            pifquery = self._get_pifquery(composition=composition)
            property_name_list, property_value_list = self._get_pifquery_property_list(
                pifquery=pifquery)
            property_names_unique, parsed_property_min, parsed_property_max, parsed_property_avg = self._parse_pifquery_property_list(
                property_name_list=property_name_list,
                property_value_list=property_value_list)
            citrine_dict_property_min[composition] = parsed_property_min
            citrine_dict_property_max[composition] = parsed_property_max
            citrine_dict_property_avg[composition] = parsed_property_avg

        dataframe = self.dataframe
        citrine_dict_list = [
            citrine_dict_property_min, citrine_dict_property_max,
            citrine_dict_property_avg
        ]
        for citrine_dict in citrine_dict_list:
            dataframe_citrine = pd.DataFrame.from_dict(data=citrine_dict,
                                                       orient='index')
            # Need to reorder compositions in new dataframe to match input dataframe
            dataframe_citrine = dataframe_citrine.reindex(
                self.dataframe['Material compositions'].tolist())
            # Need to make compositions the first column, instead of the row names
            dataframe_citrine.index.name = 'Material compositions'
            dataframe_citrine.reset_index(inplace=True)
            # Need to delete duplicate column before merging dataframes
            del dataframe_citrine['Material compositions']
            # Merge Citrine feature dataframe with originally supplied dataframe
            dataframe = DataframeUtilities().merge_dataframe_columns(
                dataframe1=dataframe, dataframe2=dataframe_citrine)

        if save_to_csv:
            # Get y_feature in this dataframe, attach it to save path
            for column in dataframe.columns.values:
                if column in self.configdict['General Setup'][
                        'target_feature']:
                    filetag = column
            dataframe.to_csv(self.configdict['General Setup']['save_path'] +
                             "/" + 'input_with_citrine_features' + '_' +
                             str(filetag) + '.csv',
                             index=False)

        return dataframe

    def _get_pifquery(self, composition):
        pif_query = PifQuery(system=SystemQuery(
            chemical_formula=ChemicalFieldQuery(filter=ChemicalFilter(
                equal=composition))))
        # Check if any results found
        # Run the search once and reuse the result to avoid a duplicate API call
        result = self.client.search(pif_query).as_dictionary()
        if 'hits' not in result:
            raise KeyError('No results found!')
        return result['hits']

    def _get_pifquery_property_list(self, pifquery):
        property_name_list = list()
        property_value_list = list()
        accepted_properties_list = [
            'mass', 'space group', 'band', 'Band', 'energy', 'volume',
            'density', 'dielectric', 'Dielectric', 'Enthalpy', 'Convex',
            'Magnetization', 'Elements', 'Modulus', 'Shear', "Poisson's",
            'Elastic', 'Energy'
        ]
        for result_number, results in enumerate(pifquery):
            for system_heading, system_value in results.items():
                if system_heading == 'system':
                    # print('FOUND SYSTEM')
                    for property_name, property_value in system_value.items():
                        if property_name == 'properties':
                            # print('FOUND PROPERTIES')
                            # pprint(property_value)
                            for list_index, list_element in enumerate(
                                    property_value):
                                for name, value in property_value[
                                        list_index].items():
                                    if name == 'name':
                                        # Check that the property name is in the acceptable property list
                                        if value != "CIF File":
                                            for entry in accepted_properties_list:
                                                if entry in value:
                                                    # print('found acceptable name', entry, 'for name', value, 'with value',property_value[list_index]['scalars'][0]['value'] )
                                                    property_name_list.append(
                                                        value)
                                                    try:
                                                        property_value_list.append(
                                                            float(
                                                                property_value[
                                                                    list_index]
                                                                ['scalars'][0]
                                                                ['value']))
                                                    except (ValueError,
                                                            KeyError):
                                                        # print('found something to remove', property_value[list_index]['scalars'][0]['value'])
                                                        property_name_list.pop(
                                                            -1)
                                                        continue
        return property_name_list, property_value_list

    def _parse_pifquery_property_list(self, property_name_list,
                                      property_value_list):
        parsed_property_max = dict()
        parsed_property_min = dict()
        parsed_property_avg = dict()
        property_names_unique = list()
        if len(property_name_list) != len(property_value_list):
            print(
                'Error! Length of property name and property value lists are not the same. There must be a bug in the _get_pifquery_property_list method'
            )
            sys.exit()
        else:
            # Get unique property names
            for name in property_name_list:
                if name not in property_names_unique:
                    property_names_unique.append(name)
            for unique_name in property_names_unique:
                unique_property = list()
                unique_property_avg = 0
                count = 0
                for i, name in enumerate(property_name_list):
                    # Only include property values whose name are same as those in unique_name list
                    if name == unique_name:
                        count += 1  # count how many instances of the same property occur
                        unique_property_avg += property_value_list[i]
                        unique_property.append(property_value_list[i])
                unique_property_min = min(unique_property)
                unique_property_max = max(unique_property)
                unique_property_avg = unique_property_avg / count
                parsed_property_min[str(unique_name) +
                                    "_min"] = unique_property_min
                parsed_property_max[str(unique_name) +
                                    "_max"] = unique_property_max
                parsed_property_avg[str(unique_name) +
                                    "_avg"] = unique_property_avg

        return property_names_unique, parsed_property_min, parsed_property_max, parsed_property_avg
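A hypothetical usage sketch for the class above; the config keys mirror the lookups inside generate_citrine_features, and all other values are illustrative placeholders:

import pandas as pd

# Hypothetical usage (config keys follow what generate_citrine_features reads)
config = {'General Setup': {'target_feature': ['Band gap'], 'save_path': '.'}}
df_input = pd.DataFrame({'Material compositions': ['Al2O3', 'TiO2'], 'Band gap': [8.8, 3.2]})
featurizer = CitrineFeatureGeneration(configdict=config, dataframe=df_input, api_key="your_api_key")
df_with_features = featurizer.generate_citrine_features(save_to_csv=True)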
Example #11
from citrination_client import CitrinationClient
from os import environ

client = CitrinationClient("my_api_key")
Example #12

def test_predict():
    client = CitrinationClient(environ['CITRINATION_API_KEY'], environ['CITRINATION_SITE'])
    inputs = [{"CHEMICAL_FORMULA": "AlCu"}, ]
    resp = client.predict("betterdensitydemo", inputs)
    prediction = resp['candidates'][0]['Density']
    assert abs(prediction[0] - 5.786) < 0.1
Example #13
#### standard packages ####
from os import environ
from time import sleep

#### third party libraries ####
from citrination_client import CitrinationClient
import pandas as pd

#### Set up a citrination client
client = CitrinationClient(
    environ.get("SLAC_CITRINATION_API_KEY"),
    "https://slac.citrination.com"
)

#### Dataset and data view identifiers used below
dataset_id = 111
data_view_id = 97
model_client = client.models


view = model_client.get_data_view(data_view_id)

# get name
# get role
# get type
# get lower
# get upper
descriptor_list = []

for column in view.columns:
    name, role, group_by_key, column_type, categories, units, lower_bound, upper_bound, length, balance_element, basis = [None for i in range(0, 11)]
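    # Hedged sketch of how the truncated loop body above might continue: read a few of
    # the descriptor fields this column object exposes (attribute names assumed from the
    # variables initialized above), leaving missing ones as None.
    name = getattr(column, 'name', None)
    role = getattr(column, 'role', None)
    column_type = getattr(column, 'type', None)
    lower_bound = getattr(column, 'lower_bound', None)
    upper_bound = getattr(column, 'upper_bound', None)
    descriptor_list.append((name, role, column_type, lower_bound, upper_bound))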
Example #14
 def start(self):
     self.address = self.inputs['address'] 
     f = open(self.inputs['api_key_file'],'r')
     self.api_key = str(f.readline()).strip()
     f.close()
     self.ctn_client = CitrinationClient(api_key = self.api_key, site = self.address)
Example #15
class CitrineDataRetrieval:
    """
    CitrineDataRetrieval is used to retrieve data from
    the Citrination database.  See API client docs at
    http://citrineinformatics.github.io/api-documentation/
    """
    def __init__(self, api_key=None):
        """
        Args:
            api_key: (str) Your Citrine API key, or None if
                you've set the CITRINE_KEY environment variable

        Returns: None
        """
        api_key = api_key if api_key else os.environ["CITRINE_KEY"]
        self.client = CitrinationClient(api_key, "https://citrination.com")

    def get_dataframe(self,
                      formula=None,
                      prop=None,
                      data_type=None,
                      reference=None,
                      min_measurement=None,
                      max_measurement=None,
                      from_record=None,
                      data_set_id=None,
                      max_results=None,
                      show_columns=None):
        """
        Gets a Pandas dataframe object from data retrieved from
        the Citrine API.  See client docs at
        http://citrineinformatics.github.io/api-documentation/
        for more details on input parameters.

        Args:
            formula: (str) filter for the chemical formula field; only those
                results that have chemical formulas that contain this string
                will be returned
            prop: (str) name of the property to search for
            data_type: (str) 'EXPERIMENTAL'/'COMPUTATIONAL'/'MACHINE_LEARNING';
                filter for properties obtained from experimental work,
                computational methods, or machine learning.
            reference: (str) filter for the reference field; only those
                results that have contributors that contain this string
                will be returned
            min_measurement: (str/num) minimum of the property value range
            max_measurement: (str/num) maximum of the property value range
            from_record: (int) index of first record to return (indexed from 0)
            data_set_id: (int) id of the particular data set to search on
            max_results: (int) number of records to limit the results to
            show_columns: (list) list of columns to show from the
                resulting dataframe

        Returns: (object) Pandas dataframe object containing the results

        """
        # Get all of the jsons from client
        jsons = self.get_api_data(formula=formula,
                                  prop=prop,
                                  data_type=data_type,
                                  reference=reference,
                                  min_measurement=min_measurement,
                                  max_measurement=max_measurement,
                                  from_record=from_record,
                                  data_set_id=data_set_id,
                                  max_results=max_results)

        non_prop_df = pd.DataFrame()  # df w/o measurement column
        prop_df = pd.DataFrame()  # df containing only measurement column

        counter = 0  # variable to keep count of sample hit and set indexes

        for hit in tqdm(jsons):

            counter += 1  # Keep a count to appropriately index the rows

            if "system" in hit.keys(
            ):  # Check if 'system' key exists, else skip
                system_value = hit["system"]
                system_normdf = json_normalize(system_value)

                # Make a DF of all non-'properties' fields
                non_prop_cols = [
                    cols for cols in system_normdf.columns
                    if "properties" not in cols
                ]
                non_prop_row = pd.DataFrame()
                for col in non_prop_cols:
                    non_prop_row[col] = system_normdf[col]
                non_prop_row.index = [counter] * len(system_normdf)
                non_prop_df = non_prop_df.append(non_prop_row)

                # Make a DF of the 'properties' array
                if "properties" in system_value:

                    p_df = pd.DataFrame()

                    # Rename duplicate property names in a record with progressive numbering
                    all_prop_names = [
                        x["name"] for x in system_value["properties"]
                    ]

                    counts = {
                        k: v
                        for k, v in Counter(all_prop_names).items() if v > 1
                    }

                    for i in reversed(range(len(all_prop_names))):
                        item = all_prop_names[i]
                        if item in counts and counts[item]:
                            all_prop_names[i] += "_" + str(counts[item])
                            counts[item] -= 1

                    # add each property, and its associated fields, as a new column
                    for p_idx, prop in enumerate(system_value["properties"]):

                        # Rename property name according to above duplicate numbering
                        prop["name"] = all_prop_names[p_idx]

                        if "scalars" in prop:
                            p_df.set_value(counter, prop["name"],
                                           parse_scalars(prop["scalars"]))
                        elif "vectors" in prop:
                            p_df[prop["name"]] = prop["vectors"]
                        elif "matrices" in prop:
                            p_df[prop["name"]] = prop["matrices"]

                        # parse all keys in the Property object except 'name', 'scalars', 'vectors', and 'matrices'
                        for prop_key in prop:

                            if prop_key not in [
                                    "name", "scalars", "vectors", "matrices"
                            ]:

                                # If value is a list of multiple items, set the cell to the entire list by first
                                # converting to object type, else results in a ValueError/IndexError
                                if type(prop[prop_key]) == list and len(
                                        prop[prop_key]) > 1:
                                    p_df[prop["name"] + "-" +
                                         prop_key] = np.nan
                                    p_df[prop["name"] + "-" + prop_key] = \
                                        p_df[prop["name"] + "-" + prop_key].astype(object)

                                p_df.set_value(counter,
                                               prop["name"] + "-" + prop_key,
                                               prop[prop_key])

                    p_df.index = [counter]
                    prop_df = prop_df.append(p_df)

        # Concatenate 'properties' and 'non-properties' dataframes
        df = pd.concat([non_prop_df, prop_df], axis=1)
        df.index.name = "system"

        # Remove uninformative columns, such as 'category' and 'uid'
        df.drop(["category", "uid"], axis=1, inplace=True)

        # Filter out columns not selected
        if show_columns:
            df = df[show_columns]

        return df

    def get_api_data(self,
                     formula=None,
                     prop=None,
                     data_type=None,
                     reference=None,
                     min_measurement=None,
                     max_measurement=None,
                     from_record=None,
                     data_set_id=None,
                     max_results=None):
        """
        Gets raw api data from Citrine in json format. See client docs
        at http://citrineinformatics.github.io/api-documentation/
        for more details on these parameters.

        Args:
            formula: (str) filter for the chemical formula field; only those
                results that have chemical formulas that contain this string
                will be returned
            prop: (str) name of the property to search for
            data_type: (str) 'EXPERIMENTAL'/'COMPUTATIONAL'/'MACHINE_LEARNING';
                filter for properties obtained from experimental work,
                computational methods, or machine learning.
            reference: (str) filter for the reference field; only those
                results that have contributors that contain this string
                will be returned
            min_measurement: (str/num) minimum of the property value range
            max_measurement: (str/num) maximum of the property value range
            from_record: (int) index of first record to return (indexed from 0)
            data_set_id: (int) id of the particular data set to search on
            max_results: (int) number of records to limit the results to

        Returns: (list) of jsons/pifs returned by Citrine's API
        """

        json_data = []
        start = from_record if from_record else 0
        per_page = 100
        refresh_time = 3  # seconds to wait between search calls

        # Construct all of the relevant queries from input args
        formula_query = ChemicalFieldQuery(filter=ChemicalFilter(
            equal=formula))
        prop_query = PropertyQuery(
            name=FieldQuery(filter=Filter(equal=prop)),
            value=FieldQuery(
                filter=Filter(min=min_measurement, max=max_measurement)),
            data_type=FieldQuery(filter=Filter(equal=data_type)))
        ref_query = ReferenceQuery(doi=FieldQuery(filter=Filter(
            equal=reference)))

        system_query = PifSystemQuery(chemical_formula=formula_query,
                                      properties=prop_query,
                                      references=ref_query)
        dataset_query = DatasetQuery(id=Filter(equal=data_set_id))
        data_query = DataQuery(system=system_query, dataset=dataset_query)

        while True:
            # use per_page=max_results, eg: in case of max_results=68 < 100
            if max_results and max_results < per_page:
                pif_query = PifSystemReturningQuery(query=data_query,
                                                    from_index=start,
                                                    size=max_results)
            else:
                pif_query = PifSystemReturningQuery(query=data_query,
                                                    from_index=start,
                                                    size=per_page)

            # Run the search once, then check if any results were found
            result = self.client.search(pif_query).as_dictionary()
            if "hits" not in result:
                raise KeyError("No results found!")

            data = result["hits"]
            size = len(data)
            start += size
            json_data.extend(data)

            # check if limit is reached
            if max_results and len(json_data) > max_results:
                # keep only the first max_results records
                json_data = json_data[:max_results]
                break
            if size < per_page:  # break out of last loop of results
                break

            time.sleep(refresh_time)

        return json_data
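A hypothetical usage sketch of the raw-JSON path above; the formula, property name, and limit are illustrative placeholders:

# Hypothetical usage of get_api_data (values are illustrative)
cdr = CitrineDataRetrieval(api_key="your_api_key")
pif_hits = cdr.get_api_data(formula="GaN", prop="Band gap", max_results=10)
print(len(pif_hits), "PIF records returned")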
Example #16
class CitrineFeatureGeneration(object):
    """
    Class to generate new features using Citrine data and dataframe containing material compositions
    Dataframe must have a column named "Material compositions".

    Args:
        configdict (dict) : MASTML configfile object as dict
        dataframe (pandas dataframe) : dataframe containing x and y data and feature names
        api_key (str) : your Citrination API key

    Methods:
        generate_citrine_features : generates Citrine feature set based on compositions in dataframe

            Args:
                save_to_csv (bool) : whether to save the Citrine feature set to a csv file

            Returns:
                pandas dataframe : dataframe containing the Citrine feature set
    """
    def __init__(self, dataframe, api_key, composition_feature):
        self.dataframe = dataframe
        self.api_key = api_key
        self.client = CitrinationClient(api_key, 'https://citrination.com')
        self.composition_feature = composition_feature

    def generate_citrine_features(self):
        log.warning(
            'WARNING: You have specified generation of features from Citrine. Based on which'
            ' materials you are interested in, there may be many records to parse through, thus'
            ' this routine may take a long time to complete!')
        try:
            compositions = self.dataframe[self.composition_feature].tolist()
        except KeyError as e:
            log.error(f'original python error: {str(e)}')
            raise utils.MissingColumnError(
                f'Error! No column named {self.composition_feature} found in your input data file. '
                'To use this feature generation routine, you must supply a material composition for each data point'
            )
        citrine_dict_property_min = dict()
        citrine_dict_property_max = dict()
        citrine_dict_property_avg = dict()

        # The per-composition lookups were refactored into _load_composition so they can
        # be parallelized; uncomment the pool lines below to run the requests in parallel.
        # pool = multiprocessing.Pool(processes=20)
        # result_tuples = pool.map(self._load_composition, compositions)
        result_tuples = map(self._load_composition, compositions)

        for comp, (prop_min, prop_max,
                   prop_avg) in zip(compositions, result_tuples):
            citrine_dict_property_min[comp] = prop_min
            citrine_dict_property_max[comp] = prop_max
            citrine_dict_property_avg[comp] = prop_avg

        dataframe = self.dataframe
        citrine_dict_list = [
            citrine_dict_property_min, citrine_dict_property_max,
            citrine_dict_property_avg
        ]
        for citrine_dict in citrine_dict_list:
            dataframe_citrine = pd.DataFrame.from_dict(data=citrine_dict,
                                                       orient='index')
            # Need to reorder compositions in new dataframe to match input dataframe
            dataframe_citrine = dataframe_citrine.reindex(
                self.dataframe[self.composition_feature].tolist())
            # Need to make compositions the first column, instead of the row names
            dataframe_citrine.index.name = self.composition_feature
            dataframe_citrine.reset_index(inplace=True)
            # Need to delete duplicate column before merging dataframes
            del dataframe_citrine[self.composition_feature]
            # Merge Citrine feature dataframe with originally supplied dataframe
            dataframe = DataframeUtilities().merge_dataframe_columns(
                dataframe1=dataframe, dataframe2=dataframe_citrine)

        return dataframe

    def _load_composition(self, composition):
        pifquery = self._get_pifquery(composition=composition)
        property_name_list, property_value_list = self._get_pifquery_property_list(
            pifquery=pifquery)
        #print("Citrine Feature Generation: ", composition, property_name_list, property_value_list)
        property_names_unique, parsed_property_min, parsed_property_max, parsed_property_avg = self._parse_pifquery_property_list(
            property_name_list=property_name_list,
            property_value_list=property_value_list)
        return parsed_property_min, parsed_property_max, parsed_property_avg

    def _get_pifquery(self, composition):
        # TODO: does this stop csv generation on first invalid composition?
        # TODO: Is there a way to send many compositions in one call to citrine?
        pif_query = PifQuery(system=SystemQuery(
            chemical_formula=ChemicalFieldQuery(filter=ChemicalFilter(
                equal=composition))))
        # Run the search once and reuse the result to avoid a duplicate API call
        result = self.client.search(pif_query).as_dictionary()
        if 'hits' not in result:
            raise KeyError('No results found!')
        return result['hits']

    def _get_pifquery_property_list(self, pifquery):
        property_name_list = list()
        property_value_list = list()
        accepted_properties_list = [
            'mass', 'space group', 'band', 'Band', 'energy', 'volume',
            'density', 'dielectric', 'Dielectric', 'Enthalpy', 'Convex',
            'Magnetization', 'Elements', 'Modulus', 'Shear', "Poisson's",
            'Elastic', 'Energy'
        ]

        for result_number, results in enumerate(pifquery):
            for i, dictionary in enumerate(results['system']['properties']):
                if 'name' not in dictionary or dictionary['name'] == "CIF File":
                    continue
                value = dictionary['name']
                for entry in accepted_properties_list:
                    if entry not in value: continue
                    property_name_list.append(value)
                    try:
                        property_value_list.append(
                            float(dictionary['scalars'][0]['value']))
                    except (ValueError, KeyError):
                        property_name_list.pop(-1)
                        continue

        #for result_number, results in enumerate(pifquery):
        #    property_value = results['system']['properties']
        #    for list_index, list_element in enumerate(property_value):
        #        for name, value in property_value[list_index].items():
        #            if name == 'name' and value != "CIF File":
        #                for entry in accepted_properties_list:
        #                    if entry in value:
        #                        property_name_list.append(value)
        #                        try:
        #                            property_value_list.append(
        #                                float(property_value[list_index]['scalars'][0]['value']))
        #                        except (ValueError, KeyError):
        #                            # print('found something to remove', property_value[list_index]['scalars'][0]['value'])
        #                            property_name_list.pop(-1)
        #                            continue

        return property_name_list, property_value_list

    def _parse_pifquery_property_list(self, property_name_list,
                                      property_value_list):
        parsed_property_max = dict()
        parsed_property_min = dict()
        parsed_property_avg = dict()
        property_names_unique = list()
        if len(property_name_list) != len(property_value_list):
            print(
                'Error! Length of property name and property value lists are not the same. There must be a bug in the _get_pifquery_property_list method'
            )
            raise IndexError(
                "property_name_list and property_value_list are not the same size."
            )
        else:
            # Get unique property names
            for name in property_name_list:
                if name not in property_names_unique:
                    property_names_unique.append(name)
            for unique_name in property_names_unique:
                unique_property = list()
                unique_property_avg = 0
                count = 0
                for i, name in enumerate(property_name_list):
                    # Only include property values whose name are same as those in unique_name list
                    if name == unique_name:
                        count += 1  # count how many instances of the same property occur
                        unique_property_avg += property_value_list[i]
                        unique_property.append(property_value_list[i])
                unique_property_min = min(unique_property)
                unique_property_max = max(unique_property)
                unique_property_avg = unique_property_avg / count
                parsed_property_min[str(unique_name) +
                                    "_min"] = unique_property_min
                parsed_property_max[str(unique_name) +
                                    "_max"] = unique_property_max
                parsed_property_avg[str(unique_name) +
                                    "_avg"] = unique_property_avg

        return property_names_unique, parsed_property_min, parsed_property_max, parsed_property_avg
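A hypothetical usage sketch for this variant of the class; the composition column name must match composition_feature, and the values here are illustrative placeholders:

import pandas as pd

# Hypothetical usage of the composition_feature-based variant
df_input = pd.DataFrame({'composition': ['Al2O3', 'TiO2'], 'y_target': [8.8, 3.2]})
featurizer = CitrineFeatureGeneration(dataframe=df_input, api_key="your_api_key",
                                      composition_feature='composition')
df_with_citrine_features = featurizer.generate_citrine_features()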
Example #17
 def __init__(self, dataframe, api_key, composition_feature):
     self.dataframe = dataframe
     self.api_key = api_key
     self.client = CitrinationClient(api_key, 'https://citrination.com')
     self.composition_feature = composition_feature
Example #18
class CitrinationSaxsClassifer(object):
    """A set of classifiers to be used on SAXS spectra"""

    def __init__(self, address, api_key_file):
        with open(api_key_file, "r") as g:
            api_key = g.readline()
        a_key = api_key.strip()

        self.client = CitrinationClient(site = address, api_key=a_key)


    def citrination_classify(self,sample_params):
        """
        Parameters
        ----------
        sample_params : ordered dictionary
            ordered dictionary of floats representing features of test sample

        Returns
        -------
        flags : dict
            dictionary of (boolean,float) tuples,
            where the first item is the flag
            and the second is the probability,
            for each of the potential scattering populations
        """

        inputs = self.append_str_property(sample_params)

        flags = OrderedDict()
        resp = self.client.predict("24", inputs) # "24" is ID of dataview on Citrination
        flags['unidentified'] = resp['candidates'][0]['Property unidentified']
        flags['guinier_porod'] = resp['candidates'][0]['Property guinier_porod']
        flags['spherical_normal'] = resp['candidates'][0]['Property spherical_normal']
        flags['diffraction_peaks'] = resp['candidates'][0]['Property diffraction_peaks']

        return flags


    # helper function
    def append_str_property(self, sample_params):
        inputs = {}
        for k,v in sample_params.items():
            k = "Property " + k
            inputs[k] = v
        return inputs


    def citrination_predict(self, populations, sample_params, q_I):
        """Apply self.models and self.scalers to sample_params.

        Parameters
        ----------
        sample_params : ordered dictionary
            ordered dictionary of floats representing features of test sample

        Returns
        -------
        flags : dict
            dictionary of (boolean,float) tuples,
            where the first item is the flag
            and the second is the probability,
            for each of the potential scattering populations
        """

        features = self.append_str_property(sample_params)

        params = OrderedDict.fromkeys(saxs_math.all_parameter_keys)

        if populations['unidentified'][0] == '1':
            # TODO: we could still use a fit to 'predict' I0_floor...
            return params # all params are "None"

        if populations['spherical_normal'][0] == '1' and populations['diffraction_peaks'][0] == '0':
            resp = self.client.predict("27", features) # "27" is ID of dataview on Citrination
            params['r0_sphere'] = resp['candidates'][0]['Property r0_sphere']

            additional_features = saxs_math.spherical_normal_profile(q_I)
            additional_features = self.append_str_property(additional_features)
            ss_features = dict(features)
            ss_features.update(additional_features)
            resp = self.client.predict("28", ss_features)
            params['sigma_sphere'] = resp['candidates'][0]['Property sigma_sphere']

        if populations['guinier_porod'][0] == '1':
            additional_features = saxs_math.guinier_porod_profile(q_I)
            additional_features = self.append_str_property(additional_features)
            rg_features = dict(features)
            rg_features.update(additional_features)
            resp = self.client.predict("29", rg_features)
            params['rg_gp'] = resp['candidates'][0]['Property rg_gp']

        return params
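A hypothetical end-to-end sketch of using the classifier above; the site URL, key file, feature names, and values are illustrative placeholders, not real SAXS features:

from collections import OrderedDict

# Hypothetical usage (feature names/values below are placeholders)
classifier = CitrinationSaxsClassifer('https://slac.citrination.com', 'citrination_api_key.txt')
sample_features = OrderedDict([('q_Imax', 0.05), ('Imax_over_Imean', 4.2)])
flags = classifier.citrination_classify(sample_features)
print(flags)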
Example #19
    def __init__(self, address, api_key_file):
        with open(api_key_file, "r") as g:
            api_key = g.readline()
        a_key = api_key.strip()

        self.client = CitrinationClient(site = address, api_key=a_key)
Example #20

def citrine_upload(citrine_data,
                   api_key,
                   mdf_dataset,
                   previous_id=None,
                   public=True):
    import os
    from citrination_client import CitrinationClient

    cit_client = CitrinationClient(api_key).data
    source_id = mdf_dataset.get("mdf", {}).get("source_id", "NO_ID")
    try:
        cit_title = mdf_dataset["dc"]["titles"][0]["title"]
    except (KeyError, IndexError, TypeError):
        cit_title = "Untitled"
    try:
        cit_desc = " ".join([
            desc["description"] for desc in mdf_dataset["dc"]["descriptions"]
        ])
        if not cit_desc:
            raise KeyError
    except (KeyError, IndexError, TypeError):
        cit_desc = None

    # Create new version if dataset previously created
    if previous_id:
        try:
            rev_res = cit_client.create_dataset_version(previous_id)
            assert rev_res.number > 1
        except Exception:
            previous_id = "INVALID"
        else:
            cit_ds_id = previous_id
            cit_client.update_dataset(cit_ds_id,
                                      name=cit_title,
                                      description=cit_desc,
                                      public=False)
    # Create new dataset if not created
    if not previous_id or previous_id == "INVALID":
        try:
            cit_ds_id = cit_client.create_dataset(name=cit_title,
                                                  description=cit_desc,
                                                  public=False).id
            assert cit_ds_id > 0
        except Exception as e:
            print("{}: Citrine dataset creation failed: {}".format(
                source_id, repr(e)))
            if previous_id == "INVALID":
                return {
                    "success": False,
                    "error":
                    "Unable to create revision or new dataset in Citrine"
                }
            else:
                return {
                    "success":
                    False,
                    "error":
                    "Unable to create Citrine dataset, possibly due to duplicate entry"
                }

    success = 0
    failed = 0
    for path, _, files in os.walk(os.path.abspath(citrine_data)):
        for pif in files:
            up_res = cit_client.upload(cit_ds_id, os.path.join(path, pif))
            if up_res.successful():
                success += 1
            else:
                print("{}: Citrine upload failure: {}".format(
                    source_id, str(up_res)))
                failed += 1

    cit_client.update_dataset(cit_ds_id, public=public)

    return {
        "success": bool(success),
        "cit_ds_id": cit_ds_id,
        "success_count": success,
        "failure_count": failed
    }
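A hypothetical call to the function above; the directory, key source, and metadata are illustrative, and mdf_dataset only needs the keys the function reads:

import os

# Hypothetical invocation of citrine_upload (values are illustrative)
result = citrine_upload(
    citrine_data="./pifs",
    api_key=os.environ["CITRINATION_API_KEY"],
    mdf_dataset={"mdf": {"source_id": "my_dataset_v1"},
                 "dc": {"titles": [{"title": "My dataset"}],
                        "descriptions": [{"description": "Example upload"}]}},
    public=False)
print(result)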
Example #21
 def __init__(self, configdict, dataframe, api_key):
     self.configdict = configdict
     self.dataframe = dataframe
     self.api_key = api_key
     self.client = CitrinationClient(api_key, 'https://citrination.com')
Example #22
#### standard packages ####
from os import environ

#### third party libraries ####
from citrination_client import CitrinationClient
from citrination_client.views.data_view_builder import DataViewBuilder

#### Set up a citrination client
client = CitrinationClient(environ.get("CITRINATION_API_KEY"),
                           "https://citrination.com")

#### Data view to clone
data_view_id = '117'


def clone_data_view(client, target):
    """
    Creates a dataview if one does not exist at the supplied ID

    :param: client: a citrination client object
    :type: CitrinationClient
    :param: target: metadata of the dataview to be copied
    :type: dict
    :return: view_id: the view id of the created data view
    :type: str
    """

    # Create ML configuration
    dv_builder = DataViewBuilder()
    dv_builder.dataset_ids(target['configuration']['dataset_ids'])
    [
Example #23
import pandas as pd
import re
from matminer.utils.conversions import str_to_composition
from matminer.featurizers import composition
from os import environ
from citrination_client import CitrinationClient
from citrination_client import *
from pypif import pif


# # Fracture Toughness DataSet

# In[2]:


client = CitrinationClient(environ['CITRINATION_API_KEY'], 'https://citrination.com')
dataset_id = '151803'


# ## Query the dataset containing the fracture-toughness feature

# In[3]:


def parse_prop_and_temp(search_result, prop_name):
    rows = []
    pif_records = [x.system for x in search_result.hits]
    for system in pif_records:
        cryst_value = '0'
        for prop in system.properties:
            if prop.name == 'Crystallinity':
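# The body of parse_prop_and_temp is truncated above. A hedged sketch of how it would be
# driven: query the fracture-toughness dataset declared earlier and hand the result to
# the parser. The property name 'Fracture Toughness' is an assumption for illustration;
# the query mirrors the PifSystemReturningQuery pattern used elsewhere in these examples.
fracture_query = PifSystemReturningQuery(
    size=1000,
    query=DataQuery(dataset=DatasetQuery(id=Filter(equal=dataset_id))))
fracture_result = client.search.pif_search(fracture_query)
rows = parse_prop_and_temp(fracture_result, 'Fracture Toughness')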
Example #24
0
def load_data_zT():
    results_dir = setResDir()

    ## Metadata
    keys_response = [
        'Seebeck coefficient; squared', 'Electrical resistivity',
        'Thermal conductivity'
    ]
    sign = np.array([
        +1,  # Seebeck
        -1,  # Electric resistivity
        -1  # Thermal conductivity
    ])

    ## Load data, if possible
    # --------------------------------------------------
    try:
        df_X_all = pd.read_csv(results_dir + file_features)
        X_all = df_X_all.drop(df_X_all.columns[0], axis=1).values

        df_Y_all = pd.read_csv(results_dir + file_responses)
        Y_all = df_Y_all.drop(df_Y_all.columns[0], axis=1).values
        print("Cached data loaded.")

    except FileNotFoundError:
        ## Data Import
        # --------------------------------------------------
        # Initialize client
        print("Accessing data from Citrination...")
        site = 'https://citrination.com'  # Citrination
        client = CitrinationClient(api_key=os.environ['CITRINATION_API_KEY'],
                                   site=site)
        search_client = client.search
        # Aluminum dataset
        dataset_id = 178480  # ucsb_te_roomtemp_seebeck
        system_query = PifSystemReturningQuery(
            size=1000,
            query=DataQuery(dataset=DatasetQuery(id=Filter(
                equal=str(dataset_id)))))

        query_result = search_client.pif_search(system_query)
        print("    Found {} PIFs in dataset {}.".format(
            query_result.total_num_hits, dataset_id))

        ## Wrangle
        # --------------------------------------------------
        pifs = [x.system for x in query_result.hits]
        # Utility function will tabularize PIFs
        df_response = pifs2df(pifs)
        # Down-select columns to play well with to_numeric
        df_response = df_response[[
            'Seebeck coefficient', 'Electrical resistivity',
            'Thermal conductivity'
        ]]
        df_response = df_response.apply(pd.to_numeric)

        # Parse chemical compositions
        formulas = [pif.chemical_formula for pif in pifs]

        df_comp = pd.DataFrame(columns=['chemical_formula'], data=formulas)

        # Join
        df_data = pd.concat([df_comp, df_response], axis=1)
        print("    Accessed data.")

        # Featurize
        print("Featurizing data...")
        df_data['composition'] = df_data['chemical_formula'].apply(
            get_compostion)

        f = MultipleFeaturizer([
            cf.Stoichiometry(),
            cf.ElementProperty.from_preset("magpie"),
            cf.ValenceOrbital(props=['avg']),
            cf.IonProperty(fast=True)
        ])

        X = np.array(f.featurize_many(df_data['composition']))

        # Find valid response values
        keys_original = [
            'Seebeck coefficient', 'Electrical resistivity',
            'Thermal conductivity'
        ]

        index_valid_response = {
            key: df_data[key].dropna().index.values
            for key in keys_original
        }

        index_valid_all = df_data[keys_original].dropna().index.values
        X_all = X[index_valid_all, :]
        Y_all = df_data[keys_original].iloc[index_valid_all].values

        # Manipulate columns for proper objective values
        Y_all[:, 0] = Y_all[:, 0]**2  # Squared seebeck
        print("    Data prepared; {0:} valid observations.".format(
            X_all.shape[0]))

        # Cache data
        pd.DataFrame(data=X_all).to_csv(results_dir + file_features)
        pd.DataFrame(data=Y_all, columns=keys_response).to_csv(results_dir +
                                                               file_responses)
        print("Data cached in results directory.")

    return X_all, Y_all, sign, keys_response, prefix
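# load_data_zT relies on a pifs2df utility that is not shown in this snippet. The sketch
# below is a hypothetical stand-in (assumed behavior: one row per PIF, one column per
# property name, keeping only the first scalar value), not the original helper, which may
# also handle conditions and units.
import pandas as pd


def pifs2df(pifs):
    """Tabularize a list of PIF systems into a pandas DataFrame of property values."""
    rows = []
    for system in pifs:
        row = {}
        for prop in (system.properties or []):
            if prop.scalars:
                row[prop.name] = prop.scalars[0].value
        rows.append(row)
    return pd.DataFrame(rows)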
def run_sequential_learning(client: CitrinationClient, view_id: int, dataset_id: int,
                            num_candidates_per_iter: int,
                            design_effort: int, wait_time: int,
                            num_sl_iterations: int, input_properties: List[str],
                            target: List[str], print_output: bool,
                            true_function: Callable[[np.ndarray], float],
                            score_type: str,
                            ) -> Tuple[List[float], List[float]]:
    '''Runs sequential learning (SL) design iterations.

    :param client: Client object
    :type client: CitrinationClient
    :param view_id: View ID
    :type view_id: int
    :param dataset_id: Dataset ID
    :type dataset_id: int
    :param num_candidates_per_iter: Candidates in a batch
    :type num_candidates_per_iter: int
    :param design_effort: Effort from 1-30
    :type design_effort: int
    :param wait_time: Wait time in seconds before polling API
    :type wait_time: int
    :param num_sl_iterations: SL iterations to run
    :type num_sl_iterations: int
    :param input_properties: Inputs
    :type input_properties: List[str]
    :param target: ("Output property", {"Min", "Max"})
    :type target: List[str]
    :param print_output: Whether or not to print outputs
    :type print_output: bool
    :param true_function: Actual function for evaluating measured/true values
    :type true_function: Callable[[np.ndarray], float]
    :param score_type: MLI or MEI
    :type score_type: str
    :return: 2-tuple: list of predicted scores/uncertainties; list of measured scores/uncertainties
    :rtype: Tuple[List[float], List[float]]
    '''

    best_sl_pred_vals = []
    best_sl_measured_vals = []

    _wait_on_ingest(client, dataset_id, wait_time, print_output)

    for i in range(num_sl_iterations):
        if print_output:
            print(f"\n---STARTING SL ITERATION #{i+1}---")

        _wait_on_ingest(client, dataset_id, wait_time, print_output)
        _wait_on_data_view(client, dataset_id, view_id, wait_time, print_output)

        # Submit a design run
        design_id = client.submit_design_run(
                data_view_id=view_id,
                num_candidates=num_candidates_per_iter,
                effort=design_effort,
                target=Target(*target),
                constraints=[],
                sampler="Default"
            ).uuid

        if print_output:
            print(f"Created design run with ID {design_id}")

        _wait_on_design_run(client, design_id, view_id, wait_time, print_output)

        # Compute the best values with uncertainties as a list of (value, uncertainty)
        if score_type == "MEI":
            candidates = client.get_design_run_results(view_id, design_id).best_materials
        else:
            candidates = client.get_design_run_results(view_id, design_id).next_experiments
        values_w_uncertainties = [
            (
                m["descriptor_values"][target[0]],
                m["descriptor_values"][f"Uncertainty in {target[0]}"]
            ) for m in candidates
        ]

        # Find and save the best predicted value
        if target[1] == "Min":
            best_value_w_uncertainty = min(values_w_uncertainties, key=lambda x: x[0])
        else:
            best_value_w_uncertainty = max(values_w_uncertainties, key=lambda x: x[0])

        best_sl_pred_vals.append(best_value_w_uncertainty)
        if print_output:
            print(f"SL iter #{i+1}, best predicted (value, uncertainty) = {best_value_w_uncertainty}")

        # Update dataset w/ new candidates
        new_x_vals = []
        for material in candidates:
            new_x_vals.append(np.array(
                [float(material["descriptor_values"][x]) for x in input_properties]
            ))

        temp_dataset_fpath = f"design-{design_id}.json"
        write_dataset_from_func(true_function, temp_dataset_fpath, new_x_vals)
        upload_data_and_get_id(
            client,
            "", # No name needed for updating a dataset
            temp_dataset_fpath,
            given_dataset_id=dataset_id
        )

        _wait_on_ingest(client, dataset_id, wait_time, print_output)

        if print_output:
            print(f"Dataset updated: {len(new_x_vals)} candidates added")

        query_dataset = PifSystemReturningQuery(size=9999,
                            query=DataQuery(
                            dataset=DatasetQuery(
                                id=Filter(equal=str(dataset_id))
                        )))
        query_result = client.search.pif_search(query_dataset)

        if print_output:
            print(f"New dataset contains {query_result.total_num_hits} PIFs")

        # Update measured values in new dataset
        dataset_y_values = []
        for hit in query_result.hits:
            # Assume last prop is output if following this script
            dataset_y_values.append(
                float(hit.system.properties[-1].scalars[0].value)
            )

        if target[1] == "Min":
            best_sl_measured_vals.append(min(dataset_y_values))
        else:
            best_sl_measured_vals.append(max(dataset_y_values))

        # Retrain model w/ wait times
        client.data_views.retrain(view_id)
        _wait_on_data_view(client, dataset_id, view_id, wait_time, print_output)

    if print_output:
        print("SL finished!\n")

    return (best_sl_pred_vals, best_sl_measured_vals)
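# run_sequential_learning above calls write_dataset_from_func and upload_data_and_get_id,
# neither of which is shown here. The sketch below is a hypothetical write_dataset_from_func
# (assumptions: candidate inputs are written as properties x0, x1, ..., the measured
# response is written last under the name "Output" to match the "last property is the
# output" convention used when re-reading the dataset, and records are serialized with
# pypif). The real helper in the source may differ.
from pypif import pif
from pypif.obj import System, Property, Scalar


def write_dataset_from_func(true_function, filepath, x_vals):
    """Evaluate true_function on each candidate and write the results as PIF records."""
    systems = []
    for x in x_vals:
        system = System()
        system.properties = [
            Property(name=f"x{i}", scalars=[Scalar(value=float(v))])
            for i, v in enumerate(x)
        ]
        # Measured/true response goes last so downstream code can read properties[-1].
        system.properties.append(
            Property(name="Output", scalars=[Scalar(value=float(true_function(x)))]))
        systems.append(system)
    with open(filepath, "w") as fp:
        pif.dump(systems, fp)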
Example #26
0
def begin_convert(mdf_dataset, status_id):
    """Pull, back up, and convert metadata."""
    # Setup
    creds = {
        "app_name": "MDF Open Connect",
        "client_id": app.config["API_CLIENT_ID"],
        "client_secret": app.config["API_CLIENT_SECRET"],
        "services": ["transfer", "publish"]
    }
    clients = toolbox.confidential_login(creds)
    mdf_transfer_client = clients["transfer"]
    globus_publish_client = clients["publish"]

    # Download data locally, back up on MDF resources
    dl_res = download_and_backup(mdf_transfer_client,
                                 mdf_dataset.pop("data", {}), status_id)
    if dl_res["success"]:
        local_path = dl_res["local_path"]
        backup_path = dl_res["backup_path"]
    else:
        raise IOError("No data downloaded")
    # TODO: Update status - data downloaded
    print("DEBUG: Data downloaded")

    print("DEBUG: Conversions started")
    # Pop indexing args
    parse_params = mdf_dataset.pop("index", {})
    add_services = mdf_dataset.pop("services", [])

    # TODO: Stream data into files instead of holding feedstock in memory
    feedstock = [mdf_dataset]

    # tags = [sub["subject"] for sub in mdf_dataset.get("dc", {}).get("subjects", [])]
    # key_info = get_key_matches(tags or None)

    # List of all files, for bag
    all_files = []

    # Citrination setup
    cit_manager = IngesterManager()
    cit_client = CitrinationClient(app.config["CITRINATION_API_KEY"])
    # Get title and description
    try:
        cit_title = mdf_dataset["dc"]["titles"][0]["title"]
    except (KeyError, IndexError):
        cit_title = "Untitled"
    try:
        cit_desc = " ".join([
            desc["description"] for desc in mdf_dataset["dc"]["descriptions"]
        ])
        if not cit_desc:
            raise KeyError
    except (KeyError, IndexError):
        cit_desc = None
    cit_ds = cit_client.create_data_set(name=cit_title,
                                        description=cit_desc,
                                        share=0).json()
    cit_ds_id = cit_ds["id"]
    print("DEBUG: Citrine dataset ID:", cit_ds_id)

    for path, dirs, files in os.walk(os.path.abspath(local_path)):
        # Separate files into groups, process group as unit
        for group in group_files(files):
            # Get all file metadata
            group_file_md = [
                get_file_metadata(file_path=os.path.join(path, filename),
                                  backup_path=os.path.join(
                                      backup_path,
                                      path.replace(os.path.abspath(local_path),
                                                   ""), filename))
                for filename in group
            ]
            all_files.extend(group_file_md)

            group_paths = [os.path.join(path, filename) for filename in group]

            # MDF parsing
            mdf_res = omniparser.omniparse(group_paths, parse_params)

            # Citrine parsing
            cit_pifs = cit_manager.run_extensions(
                group_paths,
                include=None,
                exclude=[],
                args={"quality_report": False})
            if not isinstance(cit_pifs, list):
                cit_pifs = [cit_pifs]
            cit_full = []
            if len(cit_pifs) > 0:
                cit_res = []
                # Add UIDs
                cit_pifs = cit_utils.set_uids(cit_pifs)
                for pif in cit_pifs:
                    # Get PIF URL
                    pif_land_page = {
                        "mdf": {
                            "landing_page": cit_utils.get_url(pif, cit_ds_id)
                        }
                    } if cit_ds_id else {}
                    # Make PIF into feedstock and save
                    cit_res.append(
                        toolbox.dict_merge(pif_to_feedstock(pif),
                                           pif_land_page))
                    # Add DataCite metadata
                    pif = add_dc(pif, mdf_dataset.get("dc", {}))

                    cit_full.append(pif)
            else:  # No PIFs parsed
                # TODO: Send failed datatype to Citrine for logging
                # Pad cit_res to the same length as mdf_res for "merging"
                cit_res = [{} for i in range(len(mdf_res))]

            # If MDF parser failed to parse group, pad mdf_res to match PIF count
            if len(mdf_res) == 0:
                mdf_res = [{} for i in range(len(cit_res))]

            # If only one mdf record was parsed, merge all PIFs into that record
            if len(mdf_res) == 1:
                merged_cit = {}
                for cr in cit_res:
                    toolbox.dict_merge(merged_cit, cr)
                mdf_records = [toolbox.dict_merge(mdf_res[0], merged_cit)]
            # If the same number of MDF records and Citrine PIFs were parsed, merge in order
            elif len(mdf_res) == len(cit_res):
                mdf_records = [
                    toolbox.dict_merge(r_mdf, r_cit)
                    for r_mdf, r_cit in zip(mdf_res, cit_res)
                ]
            # Otherwise, keep the MDF records only
            else:
                print("DEBUG: Record mismatch:\nMDF parsed", len(mdf_res),
                      "records", "\nCitrine parsed", len(cit_res), "records"
                      "\nPIFs discarded")
                # TODO: Update status/log - Citrine records discarded
                mdf_records = mdf_res

            # Filter null records, save rest
            if not mdf_records:
                print("DEBUG: No MDF records in group:", group)
            for record in mdf_records:
                if record:
                    feedstock.append(
                        toolbox.dict_merge(record, {"files": group_file_md}))

            # Upload PIFs to Citrine
            for full_pif in cit_full:
                with tempfile.NamedTemporaryFile(mode="w+") as pif_file:
                    pif_dump(full_pif, pif_file)
                    pif_file.seek(0)
                    up_res = json.loads(
                        cit_client.upload(cit_ds_id, pif_file.name))
                    if up_res["success"]:
                        print("DEBUG: Citrine upload success")
                    else:
                        print("DEBUG: Citrine upload failure, error",
                              up_res.get("status"))

    # TODO: Update status - indexing success
    print("DEBUG: Indexing success")

    # Pass feedstock to /ingest
    with tempfile.TemporaryFile(mode="w+") as stock:
        for entry in feedstock:
            json.dump(entry, stock)
            stock.write("\n")
        stock.seek(0)
        ingest_res = requests.post(app.config["INGEST_URL"],
                                   data={"status_id": status_id},
                                   files={'file': stock})
    if not ingest_res.json().get("success"):
        # TODO: Update status? Ingest failed
        # TODO: Fail everything, delete Citrine dataset, etc.
        raise ValueError("In convert - Ingest failed: " + str(ingest_res.json()))

    # Additional service integrations

    # Finalize Citrine dataset
    # TODO: Turn on public dataset ingest (share=1)
    if "citrine" in add_services:
        try:
            cit_client.update_data_set(cit_ds_id, share=0)
        except Exception as e:
            # TODO: Update status, notify Citrine - Citrine ds failure
            print("DEBUG: Citrination dataset not updated")

    # Globus Publish
    # TODO: Test after Publish API is fixed
    if "globus_publish" in add_services:
        try:
            fin_res = globus_publish_data(globus_publish_client,
                                          mdf_transfer_client, mdf_dataset,
                                          local_path)
        except Exception as e:
            # TODO: Update status - Publish failed
            print("Publish ERROR:", repr(e))
        else:
            # TODO: Update status - Publish success
            print("DEBUG: Publish success:", fin_res)

    # Remove local data
    shutil.rmtree(local_path)
    # TODO: Update status - everything done
    return {"success": True, "status_id": status_id}