예제 #1
0
def valid_mean_property_predictor_data():
    """Produce a serialized MeanProperty predictor module for tests.

    The config averages the ``density`` descriptor (0-100 g/cm^3) over the
    ``'simple mixture'`` formulation using ``label='solvent'`` and ``p=2``,
    with ``impute_properties`` enabled and ``default_properties`` setting
    density to 1.0.

    Returns
    -------
    dict
        Module payload with status ``'READY'``, a random ``id`` and a single
        GEM table training source that references a random table id.
    """
    from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor
    from citrine.informatics.data_sources import GemTableDataSource

    mixture = FormulationDescriptor('simple mixture')
    density = RealDescriptor(key='density', lower_bound=0, upper_bound=100, units='g/cm^3')

    module_id = str(uuid.uuid4())
    training_source = GemTableDataSource(uuid.uuid4(), 0, mixture)

    config = {
        'type': 'MeanProperty',
        'name': 'Mean property predictor',
        'description': 'Computes mean ingredient properties',
        'input': mixture.dump(),
        'properties': [density.dump()],
        'p': 2,
        'training_data': [training_source.dump()],
        'impute_properties': True,
        'default_properties': {'density': 1.0},
        'label': 'solvent',
    }
    return {
        'module_type': 'PREDICTOR',
        'status': 'READY',
        'status_info': [],
        'archived': False,
        'display_name': 'Mean property predictor',
        'id': module_id,
        'config': config,
    }
예제 #2
0
def valid_deprecated_expression_predictor_data():
    """Produce a serialized legacy ``Expression`` predictor module for tests.

    The expression computes shear modulus as ``Y / (2 * (1 + v))``. In this
    legacy payload each alias maps a symbol to a plain descriptor key string
    rather than to a dumped descriptor.
    """
    from citrine.informatics.descriptors import RealDescriptor

    output = RealDescriptor('Property~Shear modulus', lower_bound=0, upper_bound=100, units='GPa')
    symbol_to_key = {
        'Y': "Property~Young's modulus",
        'v': "Property~Poisson's ratio",
    }
    config = dict(
        type='Expression',
        name='Expression predictor',
        description='Computes shear modulus from Youngs modulus and Poissons ratio',
        expression='Y / (2 * (1 + v))',
        output=output.dump(),
        aliases=symbol_to_key,
    )
    return dict(
        module_type='PREDICTOR',
        status='VALID',
        status_info=[],
        archived=False,
        display_name='Expression predictor',
        schema_id='866e72a6-0a01-4c5f-8c35-146eb2540166',
        id=str(uuid.uuid4()),
        config=config,
    )
예제 #3
0
def valid_simple_ml_predictor_data():
    """Produce a serialized ``Simple`` ML predictor module for tests.

    The predictor maps input ``x`` to output ``z`` through latent variable
    ``y``. All three are real descriptors on [0, 100] with empty units, and
    training uses version 2 of a fixed GEM table.
    """
    from citrine.informatics.data_sources import GemTableDataSource
    from citrine.informatics.descriptors import RealDescriptor

    x, y, z = (RealDescriptor(key, 0, 100, "") for key in ("x", "y", "z"))
    table = GemTableDataSource(
        table_id=uuid.UUID('e5c51369-8e71-4ec6-b027-1f92bdc14762'),
        table_version=2,
    )
    config = dict(
        type='Simple',
        name='ML predictor',
        description='Predicts z from input x and latent variable y',
        inputs=[x.dump()],
        outputs=[z.dump()],
        latent_variables=[y.dump()],
        training_data=[table.dump()],
    )
    return dict(
        module_type='PREDICTOR',
        status='VALID',
        status_info=[],
        archived=False,
        display_name='ML predictor',
        schema_id='08d20e5f-e329-4de0-a90a-4b5e36b91703',
        id=str(uuid.uuid4()),
        config=config,
    )
예제 #4
0
def valid_expression_predictor_data():
    """Produce a serialized ``AnalyticExpression`` predictor module for tests.

    Shear modulus is computed as ``Y / (2 * (1 + v))``. Each alias maps a
    symbol to a fully dumped descriptor: Young's modulus on [0, 100] GPa and
    unitless Poisson's ratio on [-1, 0.5].
    """
    from citrine.informatics.descriptors import RealDescriptor

    def gpa_descriptor(key):
        # Both moduli in this fixture share the same [0, 100] GPa range.
        return RealDescriptor(key, lower_bound=0, upper_bound=100, units='GPa')

    shear_modulus = gpa_descriptor('Property~Shear modulus')
    youngs_modulus = gpa_descriptor("Property~Young's modulus")
    poissons_ratio = RealDescriptor("Property~Poisson's ratio",
                                    lower_bound=-1,
                                    upper_bound=0.5,
                                    units='')

    config = dict(
        type='AnalyticExpression',
        name='Expression predictor',
        description='Computes shear modulus from Youngs modulus and Poissons ratio',
        expression='Y / (2 * (1 + v))',
        output=shear_modulus.dump(),
        aliases={'Y': youngs_modulus.dump(), 'v': poissons_ratio.dump()},
    )
    return dict(
        module_type='PREDICTOR',
        status='VALID',
        status_info=[],
        archived=False,
        display_name='Expression predictor',
        schema_id='f1601161-bb98-4fa9-bdd2-a2a673547532',
        id=str(uuid.uuid4()),
        config=config,
    )
예제 #5
0
def product_design_space() -> ProductDesignSpace:
    """Build a three-dimensional ProductDesignSpace for testing.

    ``alpha`` and ``beta`` are real descriptors on [0, 100], each used as a
    continuous dimension over [0, 10]. ``gamma`` is categorical over a/b/c and
    is enumerated over ``'a'`` and ``'c'`` only.
    """
    alpha, beta = (RealDescriptor(key, 0, 100, "") for key in ('alpha', 'beta'))
    gamma = CategoricalDescriptor('gamma', ['a', 'b', 'c'])

    continuous = [ContinuousDimension(descriptor, 0, 10) for descriptor in (alpha, beta)]
    enumerated = EnumeratedDimension(gamma, ['a', 'c'])

    return ProductDesignSpace(
        name='my design space',
        description='does some things',
        dimensions=continuous + [enumerated],
    )
예제 #6
0
def valid_predictor_report_data():
    """Produce a serialized predictor report with two model summaries.

    The report contains an ML model (``x`` -> ``y``) with nested model
    settings and one feature-importance entry. It also contains an analytic
    model (``x``, ``y`` -> ``z``) that carries a
    ``predictor_configuration_uid``. Dumps of the ``x``, ``y`` and ``z``
    descriptors are included alongside the models.
    """
    from citrine.informatics.descriptors import RealDescriptor
    x = RealDescriptor("x", 0, 1, "")
    y = RealDescriptor("y", 0, 100, "")
    z = RealDescriptor("z", 0, 101, "")

    def leaf_setting(name, value):
        # A model setting entry without nested children.
        return dict(name=name, value=value, children=[])

    ml_model = dict(
        name='GeneralLoloModel_1',
        type='ML Model',
        inputs=[x.key],
        outputs=[y.key],
        display_name='ML Model',
        model_settings=[
            dict(name='Algorithm',
                 value='Ensemble of non-linear estimators',
                 children=[
                     leaf_setting('Number of estimators', 64),
                     leaf_setting('Leaf model', 'Mean'),
                     leaf_setting('Use jackknife', True),
                 ]),
        ],
        feature_importances=[
            dict(response_key='y', importances=dict(x=1.00), top_features=5),
        ],
        predictor_configuration_name="Predict y from x with ML",
    )

    analytic_model = dict(
        name='GeneralLosslessModel_2',
        type='Analytic Model',
        inputs=[x.key, y.key],
        outputs=[z.key],
        display_name='GeneralLosslessModel_2',
        model_settings=[leaf_setting("Expression", "(z) <- (x + y)")],
        feature_importances=[],
        predictor_configuration_name="Expression for z",
        predictor_configuration_uid="249bf32c-6f3d-4a93-9387-94cc877f170c",
    )

    return dict(
        id='7c2dda5d-675a-41b6-829c-e485163f0e43',
        module_id='31c7f311-6f3d-4a93-9387-94cc877f170c',
        status='OK',
        create_time='2020-04-23T15:46:26Z',
        update_time='2020-04-23T15:46:26Z',
        report=dict(models=[ml_model, analytic_model],
                    descriptors=[x.dump(), y.dump(), z.dump()]),
    )
예제 #7
0
def test_polymorphic_legacy_deserialization(valid_simple_ml_predictor_data):
    """Building via the Predictor base class must yield a sane SimpleMLPredictor."""
    predictor: SimpleMLPredictor = Predictor.build(valid_simple_ml_predictor_data)

    assert predictor.name == 'ML predictor'
    assert predictor.description == 'Predicts z from input x and latent variable y'

    # Each descriptor collection holds exactly one entry with the fixture's bounds.
    expected_single = {
        'inputs': RealDescriptor("x", 0, 100, ""),
        'outputs': RealDescriptor("z", 0, 100, ""),
        'latent_variables': RealDescriptor("y", 0, 100, ""),
    }
    for attribute, descriptor in expected_single.items():
        values = getattr(predictor, attribute)
        assert len(values) == 1
        assert values[0] == descriptor

    assert len(predictor.training_data) == 1
    expected_table = UUID('e5c51369-8e71-4ec6-b027-1f92bdc14762')
    assert predictor.training_data[0].table_id == expected_table
def basic_cartesian_space() -> EnumeratedDesignSpace:
    """Enumerate a 3 x 5 x 3 cartesian design space (alpha, beta, gamma) for testing."""
    grid_values = {
        'alpha': [0, 50, 100],
        'beta': [0, 25, 50, 75, 100],
        'gamma': ['a', 'b', 'c'],
    }
    descriptors = [
        RealDescriptor('alpha', 0, 100),
        RealDescriptor('beta', 0, 100),
        CategoricalDescriptor('gamma', ['a', 'b', 'c']),
    ]
    return enumerate_cartesian_product(
        design_grid=grid_values,
        descriptors=descriptors,
        name='basic space',
        description='',
    )
예제 #9
0
def test_simple_deserialization(valid_data):
    """RealDescriptor.build must restore the key, units and both bounds."""
    built = RealDescriptor.build(valid_data)
    expected = {
        'key': 'alpha',
        'units': '',
        'lower_bound': 5.0,
        'upper_bound': 10.0,
    }
    for attribute, value in expected.items():
        assert getattr(built, attribute) == value, attribute
예제 #10
0
def valid_graph_predictor_data():
    """Produce a serialized ``Graph`` predictor module for tests.

    The graph references one predictor by a random uid string and embeds one
    inline ``Expression`` predictor that averages two outputs. Training uses a
    single GEM table data source with a random table id.
    """
    from citrine.informatics.data_sources import GemTableDataSource
    from citrine.informatics.descriptors import RealDescriptor

    module_id = str(uuid.uuid4())
    referenced_predictor_id = str(uuid.uuid4())
    metric = RealDescriptor('Property~Some metric',
                            lower_bound=0,
                            upper_bound=1000,
                            units='W')
    averaging_expression = dict(
        type='Expression',
        name='Expression predictor',
        description='mean of 2 outputs',
        expression='(X + Y)/2',
        output=metric.dump(),
        # NOTE(review): these aliases map descriptor key -> symbol, whereas the
        # legacy Expression fixture maps symbol -> descriptor key; confirm intent.
        aliases={
            "Property~X": "X",
            "Property~Y": "Y"
        },
    )
    training_source = GemTableDataSource(uuid.uuid4(), 0)

    return dict(
        module_type='PREDICTOR',
        status='VALID',
        status_info=[],
        archived=False,
        display_name='Graph predictor',
        schema_id='43c61ad4-7e33-45d0-a3de-504acb4e0737',
        id=module_id,
        config=dict(
            type='Graph',
            name='Graph predictor',
            description='description',
            predictors=[referenced_predictor_id, averaging_expression],
            training_data=[training_source.dump()],
        ),
    )
def enumerated_design_space() -> EnumeratedDesignSpace:
    """Build a two-row EnumeratedDesignSpace over real ``x`` and categorical ``color``."""
    descriptors = [
        RealDescriptor('x', lower_bound=0.0, upper_bound=1.0),
        CategoricalDescriptor('color', ['r', 'g', 'b']),
    ]
    rows = [
        {'x': 0, 'color': 'r'},
        {'x': 1.0, 'color': 'b'},
    ]
    return EnumeratedDesignSpace('enumerated', 'desc', descriptors=descriptors, data=rows)
def test_enumerated_oversize_warnings():
    """Enumerating a 600 x 600 x 600 grid must raise the oversize warning."""
    keys = ('delta', 'epsilon', 'zeta')
    with pytest.raises(UserWarning, match="648000000"):
        # Escalate warnings to errors so the first one aborts the enumeration.
        with warnings.catch_warnings():
            warnings.simplefilter('error')
            descriptors = [RealDescriptor(key, 0, 100) for key in keys]
            oversized_grid = {key: np.linspace(0, 100, 600) for key in keys}
            enumerate_cartesian_product(
                design_grid=oversized_grid,
                descriptors=descriptors,
                name='too big space',
                description='',
            )
예제 #13
0
def test_continuous_bounds():
    """Test bounds are assigned correctly, even when a supplied bound is 0.

    An omitted bound falls back to the descriptor's bound. An explicit 0 is
    falsy, but it must still be kept as given.
    """
    beta = RealDescriptor('beta', -10, 10)

    cases = [
        (ContinuousDimension(beta, upper_bound=0), (-10, 0)),
        (ContinuousDimension(beta, lower_bound=0), (0, 10)),
    ]
    for dimension, (expected_lower, expected_upper) in cases:
        assert dimension.lower_bound == expected_lower
        assert dimension.upper_bound == expected_upper
예제 #14
0
def test_design_space_limits():
    """Test that the validation logic is triggered before post/put-ing enumerated design spaces."""
    n_columns = 128

    def build_space(n_rows):
        # Enumerated space of n_columns real columns filled with random values.
        return EnumeratedDesignSpace(
            "foo",
            "bar",
            descriptors=[RealDescriptor("R-{}".format(i), 0, 1, "") for i in range(n_columns)],
            data=[{"R-{}".format(i): random() for i in range(n_columns)}
                  for _ in range(n_rows)],
        )

    # Given
    session = FakeSession()
    collection = DesignSpaceCollection(uuid.uuid4(), session)
    too_big = build_space(2001)
    just_right = build_space(2000)

    # The fake session answers a post/put with this payload marked READY.
    ready_response = just_right.dump()
    ready_response["status"] = "READY"
    session.responses.append(ready_response)

    # Then: a space one row larger than just_right is rejected ...
    with pytest.raises(ValueError) as excinfo:
        collection.register(too_big)
    assert "only supports" in str(excinfo.value)

    # ... while just_right can be registered ...
    collection.register(just_right)

    # ... and updated, once the response has been queued again.
    session.responses.append(ready_response)
    collection.update(just_right)
예제 #15
0
def test_model_summary_init():
    """A ModelSummary can be constructed from descriptors, settings and importances."""
    x, y, z = (RealDescriptor(key, 0, 1) for key in ('x', 'y', 'z'))
    importance = FeatureImportanceReport(output_key='z', importances={'x': 0.8, 'y': 0.2})
    settings = {'optimization restarts': 15, 'backpropagation': False}

    ModelSummary(
        name='General model',
        type_="ML Model",
        inputs=[x, y],
        outputs=[z],
        model_settings=settings,
        feature_importances=[importance],
        predictor_name="a predictor",
    )
def test_dict_behavior():
    """PlatformVocabulary must act as a mapping over its entries (len, iteration, lookup)."""
    entries = {
        "density": RealDescriptor("density", lower_bound=0, upper_bound=100, units="g/cm^3"),
        "pressure": RealDescriptor("pressure", lower_bound=0, upper_bound=10000, units="GPa"),
    }

    vocabulary = PlatformVocabulary(entries)

    assert len(vocabulary) == 2
    assert set(vocabulary) == {"density", "pressure"}
    for key, descriptor in entries.items():
        assert vocabulary[key] == descriptor
예제 #17
0
def test_bad_predictor_report_build(valid_predictor_report_data):
    """Modify the predictor report to be non-ideal and check the behavior.

    * A duplicated descriptor key must produce exactly one RuntimeWarning.
    * A key referenced by a model's inputs/outputs that has no matching
      descriptor must raise RuntimeError.
    """
    too_many_descriptors = deepcopy(valid_predictor_report_data)
    # Multiple descriptors with the same key
    other_x = RealDescriptor("x", 0, 100, "")
    too_many_descriptors['report']['descriptors'].append(other_x.dump())
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the active filters. Otherwise an
        # "ignore" filter or per-location deduplication could hide or collapse
        # warnings, and the count below would not be exact.
        warnings.simplefilter("always")
        Report.build(too_many_descriptors)
        assert len(w) == 1
        assert issubclass(w[-1].category, RuntimeWarning)

    # A key that appears in inputs and/or outputs, but there is no corresponding descriptor.
    # This is done twice for coverage, once to catch a missing input and once for a missing output.
    too_few_descriptors = deepcopy(valid_predictor_report_data)
    too_few_descriptors['report']['descriptors'].pop()
    with pytest.raises(RuntimeError):
        Report.build(too_few_descriptors)
    too_few_descriptors['report']['descriptors'] = []
    with pytest.raises(RuntimeError):
        Report.build(too_few_descriptors)
예제 #18
0
def invalid_predictor_data():
    """Produce invalid predictor data used for tests.

    The module has status ``'INVALID'`` with two status messages and is
    archived. Its config has ``type='invalid'``, two inputs (``x``, ``y``)
    and a single ``output`` (``z``).
    """
    from citrine.informatics.descriptors import RealDescriptor
    x = RealDescriptor("x", 0, 100, "")
    y = RealDescriptor("y", 0, 100, "")
    z = RealDescriptor("z", 0, 100, "")
    return dict(module_type='PREDICTOR',
                status='INVALID',
                status_info=['Something is wrong', 'Very wrong'],
                archived=True,
                display_name='my predictor',
                id=str(uuid.uuid4()),
                config=dict(type='invalid',
                            name='my predictor',
                            description='does some things',
                            inputs=[x.dump(), y.dump()],
                            output=z.dump()))
def test_joined_oversize_warnings(large_joint_design_space):
    """Joining two small spaces with a large one must raise the oversize warning."""

    def small_space(real_keys, categorical_key, name):
        # Enumerate a 3 x 5 x 3 grid: two real dimensions plus one categorical.
        first, second = real_keys
        descriptors = [
            RealDescriptor(first, 0, 100),
            RealDescriptor(second, 0, 100),
            CategoricalDescriptor(categorical_key, ['a', 'b', 'c']),
        ]
        grid = {
            first: [0, 50, 100],
            second: [0, 25, 50, 75, 100],
            categorical_key: ['a', 'b', 'c'],
        }
        return enumerate_cartesian_product(
            design_grid=grid,
            descriptors=descriptors,
            name=name,
            description='',
        )

    with pytest.raises(UserWarning, match="239203125"):
        # Escalate warnings to errors so the first one aborts the join.
        with warnings.catch_warnings():
            warnings.simplefilter('error')
            subspaces = [
                small_space(('delta', 'epsilon'), 'zeta', 'basic space 2'),
                small_space(('eta', 'theta'), 'iota', 'basic space 3'),
                large_joint_design_space,
            ]
            cartesian_join_design_spaces(
                subspaces=subspaces,
                name='too big join space',
                description='',
            )
    def from_predictor_responses(self, predictor: Predictor, inputs: List[Descriptor]):
        """Return fake response descriptors for featurizers and mean-property predictors.

        Parameters
        ----------
        predictor: Predictor
            Predictor whose responses are faked.
        inputs: List[Descriptor]
            Not read by the code visible here.

        Returns
        -------
        List[Descriptor]
            For a MolecularStructureFeaturizer or ChemicalFormulaFeaturizer:
            ``self.num_properties`` real descriptors on [0, 1] followed by one
            categorical descriptor, all keyed off the input descriptor's key.
            For a MeanPropertyPredictor: one real descriptor on [0, 1] per entry
            of ``predictor.properties``.

        NOTE(review): any other predictor type falls through both branches and
        returns None in the code visible here; confirm whether branches are missing.
        """
        if isinstance(predictor, (MolecularStructureFeaturizer, ChemicalFormulaFeaturizer)):
            # The two featurizers expose their input descriptor under different attribute names.
            if isinstance(predictor, MolecularStructureFeaturizer):
                input_descriptor = predictor.descriptor
            else:
                input_descriptor = predictor.input_descriptor
            return [
                RealDescriptor(f"{input_descriptor.key} real property {i}", lower_bound=0, upper_bound=1, units="")
                       for i in range(self.num_properties)
            ] + [CategoricalDescriptor(f"{input_descriptor.key} categorical property", ["cat1", "cat2"])]

        elif isinstance(predictor, MeanPropertyPredictor):
            # Without a label the mean is described as covering "all ingredients".
            label_str = predictor.label or "all ingredients"
            return [
                RealDescriptor(
                    f"mean of {prop.key} for {label_str} in {predictor.input_descriptor.key}",
                    lower_bound=0,
                    upper_bound=1,
                    units=""
                )
                for prop in predictor.properties
            ]
예제 #21
0
def test_predictor_report_build(valid_predictor_report_data):
    """Build a predictor report and verify its top-level fields and both model summaries."""
    report = Report.build(valid_predictor_report_data)

    assert report.status == 'OK'
    assert str(report.uid) == valid_predictor_report_data['id']

    x = RealDescriptor("x", 0, 1, "")
    y = RealDescriptor("y", 0, 100, "")
    z = RealDescriptor("z", 0, 101, "")
    assert report.descriptors == [x, y, z]

    # First summary: the ML model. Its nested settings come back flattened into one dict.
    lolo: ModelSummary = report.model_summaries[0]
    assert (lolo.name, lolo.type_) == ('GeneralLoloModel_1', 'ML Model')
    assert (lolo.inputs, lolo.outputs) == ([x], [y])
    assert lolo.model_settings == {
        'Algorithm': 'Ensemble of non-linear estimators',
        'Number of estimators': 64,
        'Leaf model': 'Mean',
        'Use jackknife': True
    }
    expected_importance = FeatureImportanceReport('y', {'x': 1.0})
    assert lolo.feature_importances[0].dump() == expected_importance.dump()
    assert lolo.predictor_name == 'Predict y from x with ML'
    assert lolo.predictor_uid is None

    # Second summary: the analytic model, which also carries a predictor uid.
    analytic: ModelSummary = report.model_summaries[1]
    assert (analytic.name, analytic.type_) == ('GeneralLosslessModel_2', 'Analytic Model')
    assert (analytic.inputs, analytic.outputs) == ([x, y], [z])
    assert analytic.model_settings == {"Expression": "(z) <- (x + y)"}
    assert analytic.feature_importances == []
    assert analytic.predictor_name == 'Expression for z'
    assert analytic.predictor_uid == UUID("249bf32c-6f3d-4a93-9387-94cc877f170c")
예제 #22
0
def valid_ing_to_simple_mixture_predictor_data():
    """Produce a serialized IngredientsToSimpleMixture predictor module for tests.

    The ingredients water and salt each get a [0, 1] quantity descriptor and
    are labelled ``'solvent'`` and ``'solute'`` respectively.
    """
    from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor

    quantities = {
        'water': RealDescriptor('water quantity', 0, 1).dump(),
        'salt': RealDescriptor('salt quantity', 0, 1).dump(),
    }
    labels = {
        'solvent': ['water'],
        'solute': ['salt'],
    }
    config = dict(
        type='IngredientsToSimpleMixture',
        name='Ingredients to simple mixture predictor',
        description='Constructs mixtures from ingredients',
        output=FormulationDescriptor('simple mixture').dump(),
        id_to_quantity=quantities,
        labels=labels,
    )
    return dict(
        module_type='PREDICTOR',
        status='VALID',
        status_info=[],
        archived=False,
        display_name='Ingredients to simple mixture predictor',
        schema_id='873e4541-da8a-4698-a981-732c0c729c3d',
        id=str(uuid.uuid4()),
        config=config,
    )
예제 #23
0
def test_check_update_some():
    """Test the update check correctly builds a module."""
    # given: the backend reports an available AnalyticExpression update
    session = FakeSession()
    output = RealDescriptor("spam", 0, 1, "kg")
    updated_config = {
        "type": "AnalyticExpression",
        "name": "foo",
        "description": "bar",
        "expression": "2 * x",
        "output": output.dump(),
        "aliases": {}
    }
    session.set_response({"updatable": True, "update": updated_config})
    collection = PredictorCollection(uuid.uuid4(), session)
    predictor_id = uuid.uuid4()

    # when
    update_check = collection.check_for_update(predictor_id)

    # then: the update is built into an equivalent predictor carrying the requested uid
    expected = ExpressionPredictor("foo", "bar", "2 * x", output, {})
    assert update_check.dump() == expected.dump()
    assert update_check.uid == predictor_id
예제 #24
0
def valid_simple_ml_predictor_data(valid_gem_data_source_dict):
    """Produce a serialized ``Simple`` ML predictor module for tests.

    Parameters
    ----------
    valid_gem_data_source_dict
        Already-serialized data source, used as the only training data entry.
    """
    from citrine.informatics.descriptors import RealDescriptor
    dumped = {key: RealDescriptor(key, 0, 100, "").dump() for key in ("x", "y", "z")}

    return dict(
        module_type='PREDICTOR',
        status='READY',
        status_info=[],
        archived=False,
        display_name='ML predictor',
        id=str(uuid.uuid4()),
        config=dict(
            type='Simple',
            name='ML predictor',
            description='Predicts z from input x and latent variable y',
            inputs=[dumped["x"]],
            outputs=[dumped["z"]],
            latent_variables=[dumped["y"]],
            training_data=[valid_gem_data_source_dict],
        ),
    )
def test_deprecated_ingredients_to_simple_mixture():
    """Building and inspecting the predictor must emit exactly one DeprecationWarning."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones the active filters would hide.
        warnings.simplefilter("always")
        predictor = IngredientsToSimpleMixturePredictor(
            name="deprecated",
            description="",
            output=FormulationDescriptor("formulation"),
            id_to_quantity={
                "quantity 1": RealDescriptor("foo", lower_bound=0, upper_bound=1, units=""),
            },
            labels={"label": {"foo"}},
        )
        assert predictor.name == "deprecated"
        assert predictor.labels == {"label": {"foo"}}
        assert len(caught) == 1
        (only_warning,) = caught
        assert issubclass(only_warning.category, DeprecationWarning)
예제 #26
0
def test_data_source_create():
    """A DataSourceDesignSpace must survive a dump/build round trip with its fields intact."""
    csv_source = CSVDataSource(
        file_link=FileLink(filename="foo.csv", url="http://example.com/bar.csv"),
        column_definitions={
            "foo": RealDescriptor(key="bar", lower_bound=0, upper_bound=100, units="kg"),
        },
    )
    original = DataSourceDesignSpace(name="Test", description="ing", data_source=csv_source)

    rebuilt = DesignSpace.build(original.dump())

    for attribute in ("name", "description", "data_source"):
        assert getattr(original, attribute) == getattr(rebuilt, attribute)
    assert "DataSource" in str(original)
def template_to_descriptor(template: AttributeTemplate) -> Descriptor:
    """
    Convert a GEMD attribute template into an AI Engine Descriptor.

    IntegerBounds cannot be converted because they have no matching descriptor type.
    CompositionBounds can only be converted when every component is an element, in which case
    they are converted to ChemicalFormulaDescriptors.

    Parameters
    ----------
    template: AttributeTemplate
        Template to convert into a descriptor

    Returns
    -------
    Descriptor
        Descriptor with a key matching the template name and type corresponding to the bounds:
        RealBounds -> RealDescriptor (bounds and default units copied),
        CategoricalBounds -> CategoricalDescriptor (categories copied),
        MolecularStructureBounds -> MolecularStructureDescriptor,
        all-element CompositionBounds -> ChemicalFormulaDescriptor.

    Raises
    ------
    NoEquivalentDescriptorError
        If the bounds are IntegerBounds, or CompositionBounds with a non-element component.
    ValueError
        If the bounds are of any other type.

    """
    bounds = template.bounds
    if isinstance(bounds, RealBounds):
        return RealDescriptor(key=template.name,
                              lower_bound=bounds.lower_bound,
                              upper_bound=bounds.upper_bound,
                              units=bounds.default_units)
    if isinstance(bounds, CategoricalBounds):
        return CategoricalDescriptor(key=template.name,
                                     categories=bounds.categories)
    if isinstance(bounds, MolecularStructureBounds):
        return MolecularStructureDescriptor(key=template.name)
    if isinstance(bounds, CompositionBounds):
        # Only purely elemental compositions can be expressed as a chemical formula.
        if set(bounds.components).issubset(EmpiricalFormula.all_elements()):
            return ChemicalFormulaDescriptor(key=template.name)
        else:
            msg = "Cannot create descriptor for CompositionBounds with non-atomic components"
            raise NoEquivalentDescriptorError(msg)
    if isinstance(bounds, IntegerBounds):
        raise NoEquivalentDescriptorError(
            "Cannot create a descriptor for integer-valued data")
    raise ValueError("Template has unrecognized bounds: {}".format(
        type(bounds)))
예제 #28
0
def test_from_predictor_responses():
    """from_predictor_responses must parse the response into descriptors and POST to the right path.

    The expected descriptors are built from every field the response carries,
    including ``units``, so the comparison does not depend on a default for units.
    """
    session = FakeSession()
    col = 'smiles'
    response_json = {
        'responses': [  # shortened sample response
            {
                'category': 'Real',
                'descriptor_key': 'khs.sNH3 KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
            {
                'category': 'Real',
                'descriptor_key': 'khs.dsN KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
        ]
    }
    session.set_response(response_json)
    descriptors = DescriptorMethods(uuid4(), session)
    featurizer = MolecularStructureFeaturizer(
        name="Molecule featurizer",
        description="description",
        descriptor=MolecularStructureDescriptor(col),
        features=["all"],
        excludes=["standard"])
    results = descriptors.from_predictor_responses(
        featurizer, [MolecularStructureDescriptor(col)])
    assert results == [
        RealDescriptor(
            key=r['descriptor_key'],
            lower_bound=r['lower_bound'],
            upper_bound=r['upper_bound'],
            units=r['units'],
        ) for r in response_json['responses']
    ]
    assert session.last_call.path == '/projects/{}/material-descriptors/predictor-responses'\
        .format(descriptors.project_id)
    assert session.last_call.method == 'POST'
def test_descriptors_from_data_source():
    """descriptors_from_data_source must parse the response and POST to the right path."""
    session = FakeSession()
    col = 'smiles'
    raw_descriptors = [  # shortened sample response
        {
            'category': 'Real',
            'descriptor_key': 'khs.sNH3 KierHallSmarts for {}'.format(col),
            'units': '',
            'lower_bound': 0,
            'upper_bound': 1000000000
        },
        {
            'category': 'Real',
            'descriptor_key': 'khs.dsN KierHallSmarts for {}'.format(col),
            'units': '',
            'lower_bound': 0,
            'upper_bound': 1000000000
        },
    ]
    session.set_response({'descriptors': raw_descriptors})
    descriptors = DescriptorMethods(uuid4(), session)
    data_source = GemTableDataSource('43357a66-3644-4959-8115-77b2630aca45', 123)

    results = descriptors.descriptors_from_data_source(data_source)

    expected = [
        RealDescriptor(key=raw['descriptor_key'],
                       lower_bound=raw['lower_bound'],
                       upper_bound=raw['upper_bound'],
                       units=raw['units'])
        for raw in raw_descriptors
    ]
    assert results == expected
    expected_path = '/projects/{}/material-descriptors/from-data-source'.format(
        descriptors.project_id)
    assert session.last_call.path == expected_path
    assert session.last_call.method == 'POST'
"""Tests for citrine.informatics.descriptors."""
import uuid

import pytest

from citrine.informatics.data_sources import DataSource, CSVDataSource, GemTableDataSource
from citrine.informatics.descriptors import RealDescriptor, FormulationDescriptor
from citrine.resources.file_link import FileLink


@pytest.fixture(params=[
    # CSV source with one column definition and an ["identifier"] list
    CSVDataSource(
        FileLink("foo.spam", "http://example.com"),
        {"spam": RealDescriptor("eggs", lower_bound=0, upper_bound=1.0, units="")},
        ["identifier"],
    ),
    # GEM table sources: integer version, string version, and with a formulation descriptor
    GemTableDataSource(uuid.uuid4(), 1),
    GemTableDataSource(uuid.uuid4(), "2"),
    GemTableDataSource(uuid.uuid4(), "2", FormulationDescriptor("formulation")),
])
def data_source(request):
    """Provide each sample data source in turn, one test run per parameter."""
    sample = request.param
    return sample


def test_deser_from_parent(data_source):
    """Each data source must survive a dump / DataSource.build round trip unchanged."""
    # Serialize the data source, then deserialize it through the DataSource parent
    # class, making sure it is round-trip serializable.
    data = data_source.dump()
    data_source_deserialized = DataSource.build(data)
    assert data_source == data_source_deserialized