Пример #1
0
def valid_simple_ml_predictor_data():
    """Return a serialized simple ML predictor payload for use in tests.

    The result is a plain dict holding module metadata (status, display name,
    ids) plus a ``config`` section of type ``'Simple'``. The config uses the
    dumped descriptors ``x`` (input), ``z`` (output) and ``y`` (latent
    variable), and a single dumped GEM table data source as training data.
    """
    from citrine.informatics.data_sources import GemTableDataSource
    from citrine.informatics.descriptors import RealDescriptor

    # The three descriptors differ only by key.
    x, y, z = [RealDescriptor(key, 0, 100, "") for key in ("x", "y", "z")]
    training_table = GemTableDataSource(
        table_id=uuid.UUID('e5c51369-8e71-4ec6-b027-1f92bdc14762'),
        table_version=2,
    )

    # Module-level metadata first; the predictor configuration is attached last.
    payload = {
        'module_type': 'PREDICTOR',
        'status': 'VALID',
        'status_info': [],
        'archived': False,
        'display_name': 'ML predictor',
        'schema_id': '08d20e5f-e329-4de0-a90a-4b5e36b91703',
        'id': str(uuid.uuid4()),
    }
    payload['config'] = {
        'type': 'Simple',
        'name': 'ML predictor',
        'description': 'Predicts z from input x and latent variable y',
        'inputs': [x.dump()],
        'outputs': [z.dump()],
        'latent_variables': [y.dump()],
        'training_data': [training_table.dump()],
    }
    return payload
Пример #2
0
def valid_generalized_mean_property_predictor_data():
    """Return a serialized generalized mean property predictor payload for tests.

    The ``config`` section has type ``'GeneralizedMeanProperty'``. Its input is
    the dumped 'simple mixture' formulation descriptor, and its training data
    is one dumped GEM table data source built with that same descriptor.
    """
    from citrine.informatics.descriptors import FormulationDescriptor
    from citrine.informatics.data_sources import GemTableDataSource

    mixture = FormulationDescriptor('simple mixture')

    # Module-level metadata first; the predictor configuration is attached last.
    payload = {
        'module_type': 'PREDICTOR',
        'status': 'VALID',
        'status_info': [],
        'archived': False,
        'display_name': 'Mean property predictor',
        'schema_id': '29e53222-3217-4f81-b3b8-4197a8211ade',
        'id': str(uuid.uuid4()),
    }
    serialized_input = mixture.dump()
    training_table = GemTableDataSource(uuid.uuid4(), 0, mixture)
    payload['config'] = {
        'type': 'GeneralizedMeanProperty',
        'name': 'Mean property predictor',
        'description': 'Computes mean ingredient properties',
        'input': serialized_input,
        'properties': ['density'],
        'p': 2,
        'training_data': [training_table.dump()],
        'impute_properties': True,
        'default_properties': {'density': 1.0},
        'label': 'solvent',
    }
    return payload
Пример #3
0
def valid_graph_predictor_data():
    """Return a serialized graph predictor payload for use in tests.

    The ``config`` section has type ``'Graph'`` and lists two nested
    predictors: one given only as a UUID string, and one given inline as an
    ``'Expression'`` predictor dict. Training data is a single dumped GEM
    table data source.
    """
    from citrine.informatics.data_sources import GemTableDataSource
    from citrine.informatics.descriptors import RealDescriptor

    # Module-level metadata first; the predictor configuration is attached last.
    payload = {
        'module_type': 'PREDICTOR',
        'status': 'VALID',
        'status_info': [],
        'archived': False,
        'display_name': 'Graph predictor',
        'schema_id': '43c61ad4-7e33-45d0-a3de-504acb4e0737',
        'id': str(uuid.uuid4()),
    }

    # First nested predictor: identified only by a UUID string.
    nested_predictor_id = str(uuid.uuid4())

    # Second nested predictor: an inline expression predictor definition.
    metric = RealDescriptor('Property~Some metric',
                            lower_bound=0,
                            upper_bound=1000,
                            units='W')
    expression_predictor = {
        'type': 'Expression',
        'name': 'Expression predictor',
        'description': 'mean of 2 outputs',
        'expression': '(X + Y)/2',
        'output': metric.dump(),
        'aliases': {"Property~X": "X", "Property~Y": "Y"},
    }

    training_table = GemTableDataSource(uuid.uuid4(), 0)
    payload['config'] = {
        'type': 'Graph',
        'name': 'Graph predictor',
        'description': 'description',
        'predictors': [nested_predictor_id, expression_predictor],
        'training_data': [training_table.dump()],
    }
    return payload
Пример #4
0
def valid_mean_property_predictor_data():
    """Return a serialized mean property predictor payload for use in tests.

    The ``config`` section has type ``'MeanProperty'``. Its input is the dumped
    'simple mixture' formulation descriptor, its only property is the dumped
    'density' real descriptor, and its training data is one dumped GEM table
    data source built with the formulation descriptor.
    """
    from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor
    from citrine.informatics.data_sources import GemTableDataSource

    mixture = FormulationDescriptor('simple mixture')
    density_descriptor = RealDescriptor(
        key='density', lower_bound=0, upper_bound=100, units='g/cm^3')

    # Module-level metadata first; the predictor configuration is attached last.
    payload = {
        'module_type': 'PREDICTOR',
        'status': 'READY',
        'status_info': [],
        'archived': False,
        'display_name': 'Mean property predictor',
        'id': str(uuid.uuid4()),
    }
    serialized_input = mixture.dump()
    serialized_properties = [density_descriptor.dump()]
    training_table = GemTableDataSource(uuid.uuid4(), 0, mixture)
    payload['config'] = {
        'type': 'MeanProperty',
        'name': 'Mean property predictor',
        'description': 'Computes mean ingredient properties',
        'input': serialized_input,
        'properties': serialized_properties,
        'p': 2,
        'training_data': [training_table.dump()],
        'impute_properties': True,
        'default_properties': {'density': 1.0},
        'label': 'solvent',
    }
    return payload
Пример #5
0
def test_unexpected_pattern():
    """Verify that auto_configure raises ValueError for an unrecognized pattern."""
    # Arrange: a predictor collection backed by a fake session.
    fake_session = FakeSession()
    collection = PredictorCollection(uuid.uuid4(), fake_session)

    # Act / assert: "yogurt" is not an expected pattern value.
    with pytest.raises(ValueError):
        collection.auto_configure(
            GemTableDataSource(uuid.uuid4(), 0),
            "yogurt",
        )
Пример #6
0
def test_graph_default_training_data():
    """Verify that graph predictors built without training data do not share a list."""

    def _bare_graph_payload(name):
        # Serialized graph predictor that defines no training data at all.
        return {
            'config': {'name': name, 'description': '', 'predictors': [], 'type': 'Graph'},
            'archived': False,
            'module_type': 'PREDICTOR',
            'display_name': name,
        }

    # Building each payload populates the default (empty) training data list.
    first: GraphPredictor = Predictor.build(_bare_graph_payload('one'))
    second: GraphPredictor = Predictor.build(_bare_graph_payload('two'))

    # Both predictors start without any training data.
    assert len(first.training_data) == 0
    assert len(second.training_data) == 0

    # Mutate only the first predictor's list ...
    first.training_data.append(GemTableDataSource(uuid.uuid4(), 1))

    # ... and confirm the second predictor's list is unaffected.
    assert len(first.training_data) == 1
    assert len(second.training_data) == 0
Пример #7
0
def test_returned_predictor(valid_graph_predictor_data):
    """Verify that auto_configure parses a successful response into a predictor."""
    # Arrange: the fake response carries the predictor body under "instance"
    # instead of "config".
    fake_session = FakeSession()
    payload = deepcopy(valid_graph_predictor_data)
    payload["instance"] = payload.pop("config")
    fake_session.set_response(payload)
    collection = PredictorCollection(uuid.uuid4(), fake_session)

    # Act
    predictor = collection.auto_configure(GemTableDataSource(uuid.uuid4(), 0), "PLAIN")

    # Assert: the top-level response is parsed into a graph predictor ...
    assert predictor.name == valid_graph_predictor_data["display_name"]
    assert isinstance(predictor, GraphPredictor)
    # ... and so are its nested predictors: one UUID reference, one inline definition.
    assert len(predictor.predictors) == 2
    assert isinstance(predictor.predictors[0], uuid.UUID)
    assert isinstance(predictor.predictors[1], DeprecatedExpressionPredictor)
Пример #8
0
def valid_simple_mixture_predictor_data():
    """Return a serialized simple mixture predictor payload for use in tests.

    The ``config`` section has type ``'SimpleMixture'`` with dumped input and
    output formulation descriptors. Training data is one dumped GEM table data
    source built with the input formulation descriptor.
    """
    from citrine.informatics.data_sources import GemTableDataSource
    from citrine.informatics.descriptors import FormulationDescriptor

    mixture_in = FormulationDescriptor('input formulation')
    mixture_out = FormulationDescriptor('output formulation')

    # Module-level metadata first; the predictor configuration is attached last.
    payload = {
        'module_type': 'PREDICTOR',
        'status': 'READY',
        'status_info': [],
        'archived': False,
        'display_name': 'Simple mixture predictor',
        'id': str(uuid.uuid4()),
    }
    serialized_input = mixture_in.dump()
    serialized_output = mixture_out.dump()
    training_table = GemTableDataSource(uuid.uuid4(), 0, mixture_in)
    payload['config'] = {
        'type': 'SimpleMixture',
        'name': 'Simple mixture predictor',
        'description': 'simple mixture description',
        'input': serialized_input,
        'output': serialized_output,
        'training_data': [training_table.dump()],
    }
    return payload
def test_descriptors_from_data_source():
    """Verify descriptors_from_data_source parses the response and POSTs to the right path."""
    fake_session = FakeSession()
    col = 'smiles'

    # Shortened sample response: two real descriptors that differ only by key.
    descriptor_keys = (
        f'khs.sNH3 KierHallSmarts for {col}',
        f'khs.dsN KierHallSmarts for {col}',
    )
    response_json = {
        'descriptors': [
            {
                'category': 'Real',
                'descriptor_key': descriptor_key,
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000,
            }
            for descriptor_key in descriptor_keys
        ]
    }
    fake_session.set_response(response_json)
    descriptors = DescriptorMethods(uuid4(), fake_session)
    data_source = GemTableDataSource('43357a66-3644-4959-8115-77b2630aca45',
                                     123)

    results = descriptors.descriptors_from_data_source(data_source)

    # Every entry of the response should come back as an equal RealDescriptor.
    expected = [
        RealDescriptor(key=entry['descriptor_key'],
                       lower_bound=entry['lower_bound'],
                       upper_bound=entry['upper_bound'],
                       units=entry['units'])
        for entry in response_json['descriptors']
    ]
    assert results == expected

    # The request must be a POST against the project's from-data-source route.
    expected_path = f'/projects/{descriptors.project_id}/material-descriptors/from-data-source'
    assert fake_session.last_call.path == expected_path
    assert fake_session.last_call.method == 'POST'
"""Tests for citrine.informatics.data_sources."""
import uuid

import pytest

from citrine.informatics.data_sources import DataSource, CSVDataSource, GemTableDataSource
from citrine.informatics.descriptors import RealDescriptor, FormulationDescriptor
from citrine.resources.file_link import FileLink


# Data sources fed to the ``data_source`` fixture: one CSV source, and GEM
# table sources given with an integer version, a string version, and a string
# version plus a formulation descriptor.
_DATA_SOURCE_CASES = [
    CSVDataSource(
        FileLink("foo.spam", "http://example.com"),
        {"spam": RealDescriptor("eggs", lower_bound=0, upper_bound=1.0, units="")},
        ["identifier"],
    ),
    GemTableDataSource(uuid.uuid4(), 1),
    GemTableDataSource(uuid.uuid4(), "2"),
    GemTableDataSource(uuid.uuid4(), "2", FormulationDescriptor("formulation")),
]


@pytest.fixture(params=_DATA_SOURCE_CASES)
def data_source(request):
    """Provide one data source per parametrized run of a test."""
    return request.param


def test_deser_from_parent(data_source):
    """Verify that a data source survives a dump / DataSource.build round trip."""
    # Rebuild through the parent class, then compare against the original.
    rebuilt = DataSource.build(data_source.dump())
    assert data_source == rebuilt

Пример #11
0
                               lower_bound=0,
                               upper_bound=100,
                               units='GPa')
# Real-valued property descriptors shared by the tests in this module.
youngs_modulus = RealDescriptor(
    key="Property~Young's modulus", lower_bound=0, upper_bound=100, units='GPa')
poissons_ratio = RealDescriptor(
    key="Property~Poisson's ratio", lower_bound=-1, upper_bound=0.5, units='')

# Formulation descriptors and ingredient quantities bounded to [0, 1].
formulation = FormulationDescriptor('formulation')
formulation_output = FormulationDescriptor('output formulation')
water_quantity = RealDescriptor(key='water quantity', lower_bound=0, upper_bound=1)
salt_quantity = RealDescriptor(key='salt quantity', lower_bound=0, upper_bound=1)

# GEM table data sources with fixed table ids; the second one also carries
# the formulation descriptor.
data_source = GemTableDataSource(
    table_id=uuid.UUID('e5c51369-8e71-4ec6-b027-1f92bdc14762'),
    table_version=0,
)
formulation_data_source = GemTableDataSource(
    uuid.UUID('6894a181-81d2-4304-9dfa-a6c5b114d8bc'), 0, formulation)


@pytest.fixture
def simple_predictor() -> SimpleMLPredictor:
    """Provide a SimpleMLPredictor mapping x to z via latent variable y."""
    # Descriptors and the data source come from the module-level definitions.
    predictor_kwargs = {
        'name': 'ML predictor',
        'description': 'Predicts z from input x and latent variable y',
        'inputs': [x],
        'outputs': [z],
        'latent_variables': [y],
        'training_data': [data_source],
    }
    return SimpleMLPredictor(**predictor_kwargs)