예제 #1
0
    def test_unit_handling(self):
        """
        Check unit coercion in a simple rectangle-area model (a = l1 * l2).

        One input length is given in meters and the other without explicit
        units; the expected result (1 m * 2 cm = 200 cm^2) shows the unitless
        value is taken in the symbol's centimeters. Verifies that inputs are
        coerced to canonical units, that the output carries the output
        symbol's units (cm^2), and that the computed magnitude is correct.
        """
        L = Symbol('l', ['L'], ['L'], units=[1.0, [['centimeter', 1.0]]], shape=[1])
        A = Symbol('a', ['A'], ['A'], units=[1.0, [['centimeter', 2.0]]], shape=[1])
        model = EquationModel(
            name='area',
            equations=['a = l1 * l2'],
            symbol_property_map={"a": A, "l1": L, "l2": L})
        out = model.evaluate({'l1': QuantityFactory.create_quantity(L, 1, 'meter'),
                              'l2': QuantityFactory.create_quantity(L, 2)},
                             allow_failure=False)

        # 1 m * 2 cm == 100 cm * 2 cm == 200 cm^2
        self.assertTrue(math.isclose(out['a'].magnitude, 200.0),
                        msg="unexpected area magnitude: {}".format(out['a'].magnitude))
        # assertEqual (instead of assertTrue(x == y)) reports both values on failure.
        self.assertEqual(out['a'].units, A.units)
예제 #2
0
    def test_model_returns_nan(self):
        """
        A model whose evaluation yields a scalar NaN must report failure.

        Only the scalar case is covered here; the Quantity tests exercise NaN
        detection more thoroughly.
        """
        sym_a = Symbol('a', ['A'], ['A'], units='dimensionless', shape=1)
        sym_b = Symbol('b', ['B'], ['B'], units='dimensionless', shape=1)

        # Put both symbols and their units into the registries by hand.
        for symbol in (sym_b, sym_a):
            Registry("symbols")[symbol] = symbol
            Registry("units")[symbol] = symbol.units

        model = EquationModel(
            name='equality',
            equations=['a = b'],
            variable_symbol_map={"a": sym_a, "b": sym_b})

        nan_input = QuantityFactory.create_quantity(sym_b, float('nan'))
        result = model.evaluate({'b': nan_input}, allow_failure=True)

        self.assertFalse(result['successful'])
        self.assertEqual(result['message'],
                         'Evaluation returned invalid values (NaN)')
예제 #3
0
    def test_nan_checking(self):
        """
        contains_nan_value() must flag NaN in real and complex quantities,
        both scalar and 2x2, and must not flag NaN-free quantities.
        """
        sym_a = Symbol('a', ['A'], ['A'], units='dimensionless', shape=1)
        sym_b = Symbol('b', ['B'], ['B'], units='dimensionless', shape=[2, 2])
        sym_c = Symbol('c', ['C'], ['C'], units='dimensionless', shape=1)
        sym_d = Symbol('d', ['D'], ['D'], units='dimensionless', shape=[2, 2])

        # Real scalar, real 2x2, complex scalar, complex 2x2 -- each containing NaN.
        nan_quantities = [
            Quantity(sym_a, float('nan')),
            Quantity(sym_b, [[1.0, float('nan')], [float('nan'), 1.0]]),
            Quantity(sym_c, complex('nan+nanj')),
            Quantity(sym_d, [[complex(1.0), complex('nanj')],
                             [complex('nan'), complex(1.0)]]),
        ]
        for quantity in nan_quantities:
            self.assertTrue(quantity.contains_nan_value())

        # The same four kinds, with finite values only.
        finite_quantities = [
            Quantity(sym_a, 1.0),
            Quantity(sym_b, [[1.0, 2.0], [2.0, 1.0]]),
            Quantity(sym_c, complex('1+1j')),
            Quantity(sym_d, [[complex(1.0), complex('5j')],
                             [complex('5'), complex(1.0)]]),
        ]
        for quantity in finite_quantities:
            self.assertFalse(quantity.contains_nan_value())
예제 #4
0
def import_materials(mp_ids, api_key=None):
    """
    Query the Materials Project and wrap each returned entry in a Material.

    For every entry, the structure and its lattice matrix are added as
    properties, followed by every non-None field whose key has an entry in
    PROPNET_FROM_MP_NAME_MAPPING.

    Args:
        mp_ids (list<str> or str): material ids whose information will be
            retrieved; a single id string is treated as a one-element list.
        api_key (str): api key passed straight to MPRester.
    Returns:
        (list<Material>): one Material per entry returned by the query, in the
        order the query returned them.
    """
    # A bare string would otherwise be handed to '$in' unchanged.
    if isinstance(mp_ids, str):
        mp_ids = [mp_ids]

    # Context manager closes MPRester's HTTP session once the query is done.
    with MPRester(api_key) as mpr:
        query = mpr.query(criteria={"task_id": {"$in": mp_ids}},
                          properties=AVAILABLE_MP_PROPERTIES)

    materials = []
    for data in query:
        # properties of one mp-id, all tagged with its task id
        tag_string = data['task_id']
        mat = Material()
        mat.add_property(Symbol('structure', data['structure'], [tag_string]))
        mat.add_property(
            Symbol('lattice_unit_cell', data['structure'].lattice.matrix,
                   [tag_string]))
        for key, value in data.items():
            # Skip missing values and fields propnet has no mapping for.
            if value is None or key not in PROPNET_FROM_MP_NAME_MAPPING:
                continue
            prop_type = DEFAULT_SYMBOL_TYPES[PROPNET_FROM_MP_NAME_MAPPING[key]]
            mat.add_property(Symbol(prop_type, value, [tag_string]))
        materials.append(mat)
    return materials
예제 #5
0
    def test_model_returns_complex(self):
        """
        A model that turns a real scalar input into a complex output must
        report failure, while a complex input evaluates successfully.

        Only scalars are covered here; the Quantity tests check complex
        detection more thoroughly.
        """
        sym_a = Symbol('a', ['A'], ['A'], units='dimensionless', shape=1)
        sym_b = Symbol('b', ['B'], ['B'], units='dimensionless', shape=1)
        model = EquationModel(
            name='add_complex_value',
            equations=['a = b + 1j'],
            symbol_property_map={"a": sym_a, "b": sym_b})

        # Real input -> complex output: rejected.
        real_result = model.evaluate({'b': Quantity(sym_b, 5)}, allow_failure=True)
        self.assertFalse(real_result['successful'])
        self.assertEqual(real_result['message'],
                         'Evaluation returned invalid values (complex)')

        # Complex input -> complex output: accepted (5j + 1j == 6j).
        complex_result = model.evaluate({'b': Quantity(sym_b, 5j)}, allow_failure=True)
        self.assertTrue(complex_result['successful'])
        self.assertTrue(np.isclose(complex_result['a'].magnitude, 6j))
예제 #6
0
    def test_model_register_unregister(self):
        """
        Models register themselves on construction by default, can be
        unregistered/re-registered explicitly, refuse to overwrite an existing
        registration when told not to, and otherwise replace it.
        """
        sym_a = Symbol('a', ['A'], ['A'], units='dimensionless', shape=1)
        sym_b = Symbol('b', ['B'], ['B'], units='dimensionless', shape=1)
        sym_c = Symbol('c', ['C'], ['C'], units='dimensionless', shape=1)
        sym_d = Symbol('d', ['D'], ['D'], units='dimensionless', shape=1)
        model_name = 'equation_model_to_remove'

        def build_ab_model(**registration_kwargs):
            # Fresh "a = b * 3" model under the shared name.
            return EquationModel(model_name, ['a = b * 3'],
                                 variable_symbol_map={'a': sym_a, 'b': sym_b},
                                 **registration_kwargs)

        # Registered on construction by default.
        model = build_ab_model()
        self.assertIn(model.name, Registry("models"))
        self.assertTrue(model.registered)

        model.unregister()
        self.assertNotIn(model.name, Registry("models"))
        self.assertFalse(model.registered)

        model.register()
        self.assertTrue(model.registered)
        # Re-registering without permission to overwrite is rejected.
        with self.assertRaises(KeyError):
            model.register(overwrite_registry=False)

        # Construction with registration disabled.
        model.unregister()
        model = build_ab_model(register=False)
        self.assertNotIn(model.name, Registry("models"))
        self.assertFalse(model.registered)

        model.register()
        # A same-named model may not displace it when overwriting is disabled...
        with self.assertRaises(KeyError):
            build_ab_model(register=True, overwrite_registry=False)

        # ...but with default arguments it replaces the registered one.
        replacement = EquationModel(model_name, ['c = d * 3'],
                                    variable_symbol_map={'c': sym_c, 'd': sym_d})

        registered = Registry("models")[model_name]
        self.assertIs(registered, replacement)
        self.assertIsNot(registered, model)
예제 #7
0
    def generate_symbols():
        """
        Build the Symbol objects used in testing.

        Returns: (dict<str, Symbol>) dimensionless numeric symbols 'A' and 'B'
            plus an object symbol 'C' whose object_type is str, keyed by name.
        """
        symbols = {}
        symbols['A'] = Symbol('A', ['A'], ['A'], units="dimensionless", shape=[1])
        symbols['B'] = Symbol('B', ['B'], ['B'], units="dimensionless", shape=[1])
        symbols['C'] = Symbol('C', ['C'], ['C'], category="object", object_type=str)
        return symbols
예제 #8
0
    def generate_symbols():
        """
        Build the Symbol objects used in testing and store each one in the
        "symbols" registry, keyed by the Symbol object itself.

        Returns: (dict<str, Symbol>) dimensionless numeric symbols 'A' and 'B'
            plus an object symbol 'C' whose object_type is str, keyed by name.
        """
        by_name = {
            'A': Symbol('A', ['A'], ['A'], units="dimensionless", shape=[1]),
            'B': Symbol('B', ['B'], ['B'], units="dimensionless", shape=[1]),
            'C': Symbol('C', ['C'], ['C'], category="object", object_type=str),
        }

        for symbol in by_name.values():
            Registry("symbols")[symbol] = symbol

        return by_name
예제 #9
0
    def test_property_construction(self):
        """
        A Symbol built from keyword arguments must equal the one produced by
        Symbol.from_dict for the same specification.
        """
        def youngs_modulus_spec():
            # Fresh dict on every call so the two construction paths share nothing.
            return {
                'name': 'youngs_modulus',
                'units': [1.0, [["gigapascal", 1.0]]],
                'display_names': ["Young's modulus", "Elastic modulus"],
                'display_symbols': ["E"],
                'shape': 1,
                'comment': ""
            }

        from_kwargs = Symbol(**youngs_modulus_spec())
        from_dict = Symbol.from_dict(youngs_modulus_spec())

        self.assertEqual(from_kwargs, from_dict)
예제 #10
0
def add_builtin_symbols_to_registry():
    """
    Load every built-in symbol definition file and turn it into a builtin Symbol.

    Each file in _DEFAULT_SYMBOL_TYPE_FILES is loaded, flagged with
    ``is_builtin=True`` and ``overwrite_registry=True`` and passed to
    ``Symbol.from_dict`` (presumably this registers the symbol, replacing any
    existing entry -- TODO confirm in Symbol). Afterwards, every builtin symbol
    in the "symbols" registry is exposed as a module-level global named after
    its registry key.

    Raises:
        ValueError: if a file is not named exactly "<symbol name>.yaml".
    """
    import os

    for path in _DEFAULT_SYMBOL_TYPE_FILES:
        d = loadfn(path)
        d['is_builtin'] = True
        d['overwrite_registry'] = True
        symbol_type = Symbol.from_dict(d)
        # Exact basename comparison: a substring test would accept e.g.
        # symbol "a" defined in ".../beta.yaml".
        if os.path.basename(path) != "{}.yaml".format(symbol_type.name):
            raise ValueError('Name/filename mismatch in {}'.format(path))

    # This is just to enable importing this module
    for name, symbol in Registry("symbols").items():
        if symbol.is_builtin:
            globals()[name] = symbol
예제 #11
0
    def _evaluate(self, symbol_values):
        """
        Estimate A- and B-site ionic radii for an ABC3-type oxide structure.

        Placeholder heuristic (to be partly replaced with
        CrystalPrototypeClassifier): only structures that contain oxygen and
        whose anonymized formula is 'ABC3' are accepted. The larger of the
        non-oxygen ionic radii is reported for the A site, the smaller for B.

        Args:
            symbol_values (dict): holds the structure under key 's'
                (presumably a pymatgen Structure -- TODO confirm).

        Returns:
            (dict or None): {'r_A': Symbol, 'r_B': Symbol}, or None if the
            structure has no oxygen, is not ABC3, or a non-oxygen species has
            no usable ionic radius.
        """
        structure = symbol_values['s']
        species = structure.types_of_specie

        # support other anions too?
        if not any(sp.symbol == 'O' for sp in species):
            return None

        if structure.composition.anonymized_formula != 'ABC3':
            return None

        # `ionic_radius` may be None, or absent, depending on the species
        # object (NOTE(review): confirm against pymatgen). max()/min() cannot
        # order None, so treat that like any other unsupported structure.
        radii = [getattr(sp, 'ionic_radius', None)
                 for sp in species if sp.symbol != 'O']
        if not radii or any(radius is None for radius in radii):
            return None

        # NOTE(review): the A and B sites of an ABO3 oxide are cations; the
        # 'anion_*' tags look mislabelled but are kept for existing consumers.
        return {
            'r_A': Symbol('ionic_radius', max(radii), tags='anion_A'),
            'r_B': Symbol('ionic_radius', min(radii), tags='anion_B')
        }
예제 #12
0
 def generate_canonical_symbols():
     """
     Build the dimensionless, shape-[1] test symbols A, B, C, D, G and F.

     Returns: (dict<str, Symbol>) each symbol keyed by its own name.
     """
     return {
         label: Symbol(label, [label], [label], units="dimensionless", shape=[1])
         for label in ('A', 'B', 'C', 'D', 'G', 'F')
     }
예제 #13
0
 def setUp(self):
     """Create two dimensionless symbols named "A": plain, and constrained to A > 0."""
     plain = Symbol("A", units='dimensionless')
     self.custom_symbol = plain
     constrained = Symbol("A", constraint="A > 0", units='dimensionless')
     self.constraint_symbol = constrained
예제 #14
0
    def test_complex_and_imaginary_checking(self):
        """
        Check complex-type and imaginary-value detection.

        Quantity.is_complex_type() is exercised on Quantity objects, Python
        primitives, numpy arrays and unit-registry quantities. The member
        functions contains_complex_type() and contains_imaginary_value() are
        exercised on Quantity objects. Complex values whose imaginary part is
        zero or ~1e-10 count as complex type but not as imaginary values.
        """
        sym_a = Symbol('a', ['A'], ['A'], units='dimensionless', shape=1)
        sym_b = Symbol('b', ['B'], ['B'], units='dimensionless', shape=[2, 2])
        # TODO: Revisit this when splitting quantity class into non-numerical and numerical
        sym_c = Symbol('c', ['C'], ['C'], category='object', shape=1)

        real_float_scalar = Quantity(sym_a, 1.0)
        real_float_non_scalar = Quantity(sym_b, [[1.0, 1.0], [1.0, 1.0]])

        complex_scalar = Quantity(sym_a, complex(1 + 1j))
        complex_non_scalar = Quantity(
            sym_b, [[complex(1.0), complex(1.j)], [complex(1.j), complex(1.0)]])

        complex_scalar_zero_imaginary = Quantity(sym_a, complex(1.0))
        complex_non_scalar_zero_imaginary = Quantity(
            sym_b, [[complex(1.0), complex(1.0)], [complex(1.0), complex(1.0)]])

        complex_scalar_appx_zero_imaginary = Quantity(sym_a, complex(1.0 + 1e-10j))
        complex_non_scalar_appx_zero_imaginary = Quantity(
            sym_b, [[complex(1.0), complex(1.0 + 1e-10j)],
                    [complex(1.0 + 1e-10j), complex(1.0)]])

        non_numerical = Quantity(sym_c, 'test')

        def expect(outcome, value):
            # Route to assertTrue/assertFalse so truthiness semantics are unchanged.
            (self.assertTrue if outcome else self.assertFalse)(value)

        # is_complex_type() on: Quantity objects, primitive types,
        # np.array types, and ureg Quantity objects -- in that order.
        is_complex_cases = [
            (real_float_scalar, False),
            (real_float_non_scalar, False),
            (complex_scalar, True),
            (complex_non_scalar, True),
            (complex_scalar_zero_imaginary, True),
            (complex_non_scalar_zero_imaginary, True),
            (complex_scalar_appx_zero_imaginary, True),
            (complex_non_scalar_appx_zero_imaginary, True),
            (non_numerical, False),
            (1, False),
            (1., False),
            (1j, True),
            ('test', False),
            (np.array([1]), False),
            (np.array([1.]), False),
            (np.array([1j]), True),
            (np.array(['test']), False),
            (ureg.Quantity(1), False),
            (ureg.Quantity(1.), False),
            (ureg.Quantity(1j), True),
            (ureg.Quantity([1]), False),
            (ureg.Quantity([1.]), False),
            (ureg.Quantity([1j]), True),
        ]
        for candidate, is_complex in is_complex_cases:
            expect(is_complex, Quantity.is_complex_type(candidate))

        # Member functions:
        # (quantity, contains_complex_type(), contains_imaginary_value())
        member_cases = [
            (real_float_scalar, False, False),
            (real_float_non_scalar, False, False),
            (complex_scalar, True, True),
            (complex_non_scalar, True, True),
            (complex_scalar_zero_imaginary, True, False),
            (complex_non_scalar_zero_imaginary, True, False),
            (complex_scalar_appx_zero_imaginary, True, False),
            (complex_non_scalar_appx_zero_imaginary, True, False),
            (non_numerical, False, False),
        ]
        for quantity, has_complex_type, has_imaginary_value in member_cases:
            expect(has_complex_type, quantity.contains_complex_type())
            expect(has_imaginary_value, quantity.contains_imaginary_value())
예제 #15
0
import os
from glob import glob

from monty.serialization import loadfn

from propnet.core.symbols import Symbol

# Auto loading of all allowed properties

# Every symbol loaded from the YAML definitions below, keyed by symbol name
# (values are Symbol instances).
DEFAULT_SYMBOLS = {}
# Default values keyed by the Symbol object itself (not its name); only
# symbols whose definition provides a default_value appear here.
DEFAULT_SYMBOL_VALUES = {}

# All YAML symbol definitions under ../symbols, searched recursively. Sorted so
# that load order -- and hence the order of DEFAULT_SYMBOL_TYPE_NAMES -- does
# not depend on the filesystem's directory-listing order.
_DEFAULT_SYMBOL_TYPE_FILES = sorted(glob(
    os.path.join(os.path.dirname(__file__), '../symbols/**/*.yaml'),
    recursive=True))

for f in _DEFAULT_SYMBOL_TYPE_FILES:
    symbol_type = Symbol.from_dict(loadfn(f))
    # Each file must be named exactly "<symbol name>.yaml"; a substring test
    # would accept e.g. symbol "a" defined in ".../beta.yaml".
    if os.path.basename(f) != "{}.yaml".format(symbol_type.name):
        raise ValueError('Name/filename mismatch in {}'.format(f))
    DEFAULT_SYMBOLS[symbol_type.name] = symbol_type
    if symbol_type.default_value is not None:
        DEFAULT_SYMBOL_VALUES[symbol_type] = symbol_type.default_value

# Units of every loaded symbol, keyed by symbol name.
DEFAULT_UNITS = {name: symbol.units
                 for name, symbol in DEFAULT_SYMBOLS.items()}
# Names of all loaded symbols, as a tuple.
DEFAULT_SYMBOL_TYPE_NAMES = tuple(DEFAULT_SYMBOLS.keys())

# Expose each symbol as a module-level global so it can be imported by name.
for name, symbol in DEFAULT_SYMBOLS.items():
    globals()[name] = symbol
예제 #16
0
    def evaluate(self, material=None, property_type=None):
        """
        Expand the graph by evaluating every model whose inputs can be supplied.
        Mutates the graph instance variable.

        Starting from the Symbol nodes already on the graph, each candidate Model
        (a Model adjacent to a SymbolType node that has at least one input Symbol)
        is run on every combination of available input Symbols that passes the
        model's constraint check. Each new output becomes a Symbol node with an
        edge to its SymbolType node and edges from the Material nodes its inputs
        were connected to. Models that produced nothing are retried in the next
        round, and models adjacent to newly derived SymbolTypes are queued for
        it. Evaluation stops after a round in which no model produced outputs.

        Optional arguments limit which Symbols are used as inputs:
            material: only Symbol nodes adjacent to the graph node holding this
                Material are used as initial inputs.
            property_type: only pre-existing Symbols whose type is in this
                collection are used as initial inputs. Derived Symbols of other
                types are still added to the input lookup, but they do not queue
                additional models.

        Regardless of `material`, an output whose inputs are connected to exactly
        one Material node is also added to that Material's own graph (edge from
        its root_node); see "Strategy A" below.

        Args:
            material (Material): optional limit on which material's properties will be expanded (default: all materials)
            property_type (list<SymbolType>): optional limit on which Symbols will be considered as input.
        Returns:
            None; results are written into the graph.
        Raises:
            ValueError: if `material` is given and no graph node, or more than
                one, holds an equal Material.
            Exception: if a model references a SymbolType name that does not
                appear in the graph.
        """

        ##
        # Get existing Symbol nodes, 'active' SymbolType nodes, and 'candidate' Models.
        # Filter by provided material and property_type arguments.
        ##

        if not material:
            # All symbol_nodes are candidates for evaluation.
            symbol_nodes = list(self.nodes_by_type('Symbol'))
        else:
            # Only symbol_nodes connected to the given Material object are candidates for evaluation.
            material_nodes = self.nodes_by_type('Material')
            material_node = None
            for node in material_nodes:
                if node.node_value == material:
                    if material_node:
                        raise ValueError('Multiple identical materials found.')
                    material_node = node
            if not material_node:
                raise ValueError('Specified material not found.')
            symbol_nodes = []
            for node in self.graph.neighbors(material_node):
                if node.node_type == PropnetNodeType['Symbol'] and node not in symbol_nodes:
                    symbol_nodes.append(node)

        if property_type:
            # Only SymbolType objects in the property_type list are candidates for evaluation.
            # Filters symbol_nodes in place; each list.remove is O(n), so this is
            # quadratic in the number of Symbol nodes.
            c = 0
            while c < len(symbol_nodes):
                if symbol_nodes[c].node_value.type not in property_type:
                    symbol_nodes.remove(symbol_nodes[c])
                else:
                    c += 1

        # Get set of SymbolTypes that have values provided.
        active_symbol_type_nodes = set()
        for node in symbol_nodes:
            for neighbor in self.graph.neighbors(node):
                if neighbor.node_type != PropnetNodeType.SymbolType:
                    continue
                active_symbol_type_nodes.add(neighbor)

        # Get set of Models that have values provided to inputs.
        candidate_models = set()
        for node in active_symbol_type_nodes:
            for neighbor in self.graph.neighbors(node):
                if neighbor.node_type != PropnetNodeType.Model:
                    continue
                candidate_models.add(neighbor.node_value)

        ##
        # Define helper data structures and methods.
        ##

        # Create fast-lookup data structure (SymbolType -> Symbol):
        lookup_dict = {}
        for node in symbol_nodes:
            if node.node_value.type not in lookup_dict:
                lookup_dict[node.node_value.type] = [node.node_value]
            else:
                lookup_dict[node.node_value.type] += [node.node_value]

        # Create fast-lookup data structure (Symbol -> MaterialNode)
        source_dict = {}

        # Create fast-lookup data structure (str (SymbolType name) -> SymbolType)
        symbol_type_nodes = self.nodes_by_type('SymbolType')
        symbol_types = {x.node_value.name: x.node_value for x in symbol_type_nodes}

        def get_source_nodes(graph, node):
            """
            Given a Symbol node on the graph, returns the Material nodes that have
            an edge into it. This list symbolizes the set of materials for which
            this Symbol is a property.
            Args:
                graph (networkx.MultiDiGraph): graph on which the node is stored.
                node (PropnetNode): node on the graph whose connected material nodes are to be found.
            Returns:
                (list<PropnetNode>): list of material type nodes that are connected to this node.
            """
            to_return = []
            for n in graph.in_edges(node):
                # n is an (source, target) edge tuple; keep Material sources.
                if n[0].node_type == PropnetNodeType['Material']:
                    to_return.append(n[0])
            return to_return

        for node in symbol_nodes:
            source_dict[node.node_value] = get_source_nodes(self.graph, node)

        ##
        # For each candidate model, check if we have active property types to match inputs and conditions.
        # If so, produce the available output properties using all possible permutations of inputs & add
        #     new models that can be calculated from previously-derived properties.
        # If certain outputs have already been generated, do not generate duplicate outputs.
        ##

        # `added_on_loop` records whether any model produced outputs in the current round.
        # The outer loop terminates after a round in which no model produced outputs.

        original_models = {x for x in candidate_models}
        # NOTE(review): evaluated_models is filled below but never read in this method.
        evaluated_models = set()
        next_round_models = candidate_models
        while True:
            added_on_loop = False
            candidate_models = next_round_models
            next_round_models = set()
            while len(candidate_models) > 0:
                model = candidate_models.pop()
                # Each entry: {"output": result of model.evaluate, "source": set of Material nodes}.
                outputs = []
                # Cache necessary data from model_node: input symbols, types, and conditions.
                legend = model.symbol_mapping
                sym_inputs = model.input_symbols
                # Append the model's constraint symbols to every input set.
                # NOTE(review): this mutates the lists returned by model.input_symbols in
                # place; if that property returns the model's stored lists, the model is
                # changed as a side effect -- confirm.
                for i in sym_inputs:
                    for c in model.constraint_symbols:
                        if c not in i:
                            i.append(c)

                def get_types(symbols_in, legend, symbol_types):
                    """Converts symbols used in equations to SymbolType objects.

                    Each entry of symbols_in (a single symbol or a list of them) becomes a
                    list of SymbolTypes, looked up via legend (symbol -> SymbolType name)
                    and symbol_types (name -> SymbolType).
                    """
                    to_return = []
                    for l in symbols_in:
                        if not isinstance(l, list):
                            l = [l]
                        out = []
                        for i in l:
                            to_append = symbol_types.get(legend[i])
                            if not to_append:
                                # NOTE(review): the two literals below join without a space
                                # ("...SymbolTypeobjects...").
                                raise Exception('Error evaluating graph: Model references SymbolType'
                                                'objects that do not appear in the graph.')
                            out.append(to_append)
                        to_return.append(out)
                    return to_return

                # list<list<SymbolType>>, representing sets of input properties the model accepts.
                type_inputs = get_types(sym_inputs, legend, symbol_types)

                # Recursive helper method.
                # Look through all input sets and match with all combinations from lookup_dict.
                def gen_input_dicts(symbols, candidate_props, level):
                    """
                    Recursively generates all possible combinations of input arguments
                    (the cartesian product of candidate_props[0..level]). If any candidate
                    list is empty, no combinations are produced.
                    Args:
                        symbols (list<str>):
                            one set of input symbols required by the model.
                        candidate_props (list<list<Symbol>>):
                            list of potential values that can be plugged into each symbol,
                            the outer list corresponds by ordering to the symbols list,
                            the inner list gives values that can be plugged in to each symbol.
                        level (int):
                            internal parameter used for recursion, says which symbol is being enumerated, should
                                     be set to the final index value of symbols.
                    Returns:
                        (list<dict<String, Symbol>>) list of dictionaries giving symbol strings mapped to values.
                    """
                    current_level = []
                    candidates = candidate_props[level]
                    for candidate in candidates:
                        current_level.append({symbols[level]: candidate})
                    if level == 0:
                        return current_level
                    else:
                        others = gen_input_dicts(symbols, candidate_props, level-1)
                        to_return = []
                        for entry1 in current_level:
                            for entry2 in others:
                                merged_dict = {}
                                for (k, v) in entry1.items():
                                    merged_dict[k] = v
                                for (k, v) in entry2.items():
                                    merged_dict[k] = v
                                to_return.append(merged_dict)
                        return to_return

                # Get candidate input Symbols for the given model.
                # Skip over any input Symbol lists that have already been evaluated.
                for i in range(0, len(type_inputs)):
                    candidate_properties = []
                    for j in range(0, len(type_inputs[i])):
                        candidate_properties.append(lookup_dict.get(type_inputs[i][j], []))
                    input_sets = gen_input_dicts(sym_inputs[i], candidate_properties, len(candidate_properties)-1)
                    for input_set in input_sets:
                        # Skip combinations that violate the model's constraints.
                        if not model.check_constraints(input_set):
                            continue
                        plug_in_set = {}
                        sourcing = set()
                        for (k, v) in input_set.items():
                            plug_in_set[k] = v.value
                            # An output is sourced from every material any of its inputs came from.
                            for elem in source_dict[v]:
                                sourcing.add(elem)
                        outputs.append({"output": model.evaluate(plug_in_set), "source": sourcing})

                # For any new outputs generated, create the appropriate SymbolNode and connections to SymbolTypeNodes
                # For any new outputs generated, create the appropriate connections from Material Nodes
                # For any new outputs generated, add new models connected to the derived SymbolTypeNodes to the
                #     candidate_models list & update convenience data structures.
                # Mutates this graph.
                symbol_outputs = []
                output_sources = []
                # Models that produced nothing are retried next round (more inputs may exist by then).
                if len(outputs) == 0:
                    next_round_models.add(model)
                else:
                    added_on_loop = True
                    evaluated_models.add(model)
                for entry in outputs:
                    for (k, v) in entry['output'].items():
                        # Output keys that don't map to a known SymbolType are ignored.
                        prop_type = symbol_types.get(legend.get(k))
                        if not prop_type:
                            continue
                        symbol_outputs.append(Symbol(prop_type, v, None))
                        output_sources.append(entry['source'])
                for i in range(0, len(symbol_outputs)):
                    # Add outputs to graph.
                    symbol = symbol_outputs[i]
                    symbol_node = PropnetNode(node_type=PropnetNodeType['Symbol'], node_value=symbol)
                    # A node equal to this one is already present: treat it as a duplicate output.
                    if symbol_node in self.graph:
                        continue
                    symbol_type_node = PropnetNode(node_type=PropnetNodeType['SymbolType'], node_value=symbol.type)
                    self.graph.add_edge(symbol_node, symbol_type_node)
                    for source_node in output_sources[i]:
                        self.graph.add_edge(source_node, symbol_node)

                    # Strategy A: only when the output derives from exactly one material is it
                    # also recorded in that Material's own graph.

                    if len(output_sources[i]) == 1:
                        store = output_sources[i].__iter__().__next__()
                        store.node_value.graph.add_edge(store.node_value.root_node, symbol_node)

                    # Strategy B (disabled -- kept as an unused string literal): record the
                    # output in every source Material's graph.
                    """
                    for store in output_sources[i]:
                        store.node_value.graph.add_edge(store.node_value.root_node, symbol_node)
                    """

                    # Update helper data structures etc. for next cycle.
                    source_dict[symbol] = get_source_nodes(self.graph, symbol_node)
                    if symbol.type not in lookup_dict:
                        lookup_dict[symbol.type] = [symbol]
                    else:
                        lookup_dict[symbol.type] += [symbol]
                    # Queue models fed by this SymbolType, except those from the initial candidate set.
                    if not property_type or symbol.type in property_type:
                        for neighbor in self.graph.neighbors(symbol_type_node):
                            if neighbor.node_type == PropnetNodeType['Model']:
                                if neighbor.node_value not in original_models:
                                    next_round_models.add(neighbor.node_value)
            if not added_on_loop:
                break
예제 #17
0
    def test_symbol_register_unregister(self):
        """
        Symbols register themselves on construction by default, can be
        unregistered/re-registered explicitly, refuse to overwrite an existing
        registration when told not to, and otherwise replace it.
        """
        def build_a(**registration_kwargs):
            # Fresh dimensionless, scalar symbol named 'a'.
            return Symbol('a', ['A'], ['A'],
                          units='dimensionless',
                          shape=1,
                          **registration_kwargs)

        # Registered on construction by default.
        sym = build_a()
        self.assertIn(sym.name, Registry("symbols"))
        self.assertTrue(sym.registered)

        sym.unregister()
        self.assertNotIn(sym.name, Registry("symbols"))
        self.assertFalse(sym.registered)

        sym.register()
        self.assertTrue(sym.registered)
        # Re-registering without permission to overwrite is rejected.
        with self.assertRaises(KeyError):
            sym.register(overwrite_registry=False)

        # Construction with registration disabled.
        sym.unregister()
        sym = build_a(register=False)
        self.assertNotIn(sym.name, Registry("symbols"))
        self.assertFalse(sym.registered)

        sym.register()
        # A same-named symbol may not displace it when overwriting is disabled...
        with self.assertRaises(KeyError):
            build_a(register=True, overwrite_registry=False)

        # ...but with default arguments it replaces the registered one.
        replacement = Symbol('a', ['A^*'], ['A^*'],
                             units='kilogram',
                             shape=1)

        registered = Registry("symbols")['a']
        self.assertIs(registered, replacement)
        self.assertIsNot(registered, sym)