コード例 #1
0
ファイル: analytic.py プロジェクト: akhi28/OpenMDAO-Framework
def _d_conn(expr_txt, target, ascope):
    """ Evaluates the derivative of a source-target variable connection.
    This includes the derivative of the expression as well as the
    derivative of the unit conversion factor."""
    
    expr = ascope._exprmapper.get_expr(expr_txt)
    source = expr.refs().pop()
        
    # Need derivative of the expression
    expr_deriv = expr.evaluate_gradient(scope=ascope,
                                        wrt=source)
    
    # We also need the derivative of the unit
    # conversion factor if there is one
    metadata = expr.get_metadata('units')
    source_unit = [x[1] for x in metadata if x[0] == source]
    if source_unit and source_unit[0]:
        dest_expr = ascope._exprmapper.get_expr(target)
        metadata = dest_expr.get_metadata('units')
        target_unit = [x[1] for x in metadata if x[0] == target]

        expr_deriv[source] = expr_deriv[source] * \
            convert_units(1.0, source_unit[0], target_unit[0])
        
    return source, expr_deriv
コード例 #2
0
    def test_prefix_plus_math(self):
        # From an issue: m**2 converts fine, but cm**2 does not.
        # The prefix factor must be squared too: 1 m**2 == 10000 cm**2.
        x1 = units.convert_units(1.0, "m**2", "cm**2")
        self.assertEqual(x1, 10000.0)

        # Make sure we can declare some complicated units; constructing the
        # quantity must not raise.
        units.PhysicalQuantity("7200nm**3/kPa*dl")
コード例 #3
0
    def test_prefix_plus_math(self):
        # From an issue: m**2 converts fine, but cm**2 does not.
        # The prefix factor must be squared too: 1 m**2 == 10000 cm**2.
        x1 = units.convert_units(1.0, 'm**2', 'cm**2')
        self.assertEqual(x1, 10000.0)

        # Make sure we can declare some complicated units; constructing the
        # quantity must not raise.
        units.PhysicalQuantity('7200nm**3/kPa*dl')
コード例 #4
0
 def test_assignment(self):
     """float1 starts at 3.1415926, defaults to 98.9, and converts to inches."""
     obj = self.hobj
     self.assertEqual(3.1415926, obj.float1)                  # starting value
     self.assertEqual(98.9, obj.get_trait('float1').default)  # default value

     # Assign float1, then store its value converted into inches in float2
     # using convert_units.
     obj.float1 = 3.
     value = obj.float1
     units_in = obj.get_trait('float1').units
     obj.float2 = convert_units(value, units_in, 'inch')
     self.assertAlmostEqual(36., obj.float2, 5)
コード例 #5
0
    def test_prefix_plus_math(self):
        # From an issue: m**2 converts fine, but cm**2 does not.
        # The prefix factor must be squared too: 1 m**2 == 10000 cm**2.
        x1 = units.convert_units(1.0, 'm**2', 'cm**2')
        self.assertEqual(x1, 10000.0)

        # Make sure we can declare some complicated units; constructing the
        # quantity must not raise.
        units.PhysicalQuantity('7200nm**3/kPa*dL')

        # From issue 825: make sure a single character before a '/' is
        # handled.
        units.PhysicalQuantity('1 g/kW')
コード例 #6
0
ファイル: test_units.py プロジェクト: jcchin/project_clippy
    def test_prefix_plus_math(self):
        # From an issue: m**2 converts fine, but cm**2 does not.
        # The prefix factor must be squared too: 1 m**2 == 10000 cm**2.
        x1 = convert_units(1.0, 'm**2', 'cm**2')
        self.assertEqual(x1, 10000.0)

        # Make sure we can declare some complicated units; constructing the
        # quantity must not raise.
        PhysicalQuantity('7200nm**3/kPa*dL')

        # From issue 825: make sure a single character before a '/' is
        # handled.
        PhysicalQuantity('1 g/kW')
コード例 #7
0
 def test_assignment(self):
     """arr1 starts at [1, 2, 3], defaults to [98.9], and converts to inches."""
     obj = self.hobj
     expected_start = array([1., 2., 3.])
     self.assertTrue(all(expected_start == obj.arr1))
     arr1_default = obj.get_trait('arr1').trait_type.default_value
     self.assertEqual([98.9], arr1_default)

     # Convert arr1 from its declared units into inches (stored in arr2) and
     # check the result element by element.
     obj.arr2 = convert_units(obj.arr1, obj.get_trait('arr1').units, 'inch')
     for idx, inches in enumerate((12., 24., 36.)):
         self.assertAlmostEqual(inches, obj.arr2[idx], 5)
コード例 #8
0
 def test_unit_conversion(self):
     """Convert inches to feet, then check an out-of-range value is rejected.

     float1 only accepts values in [0.0, 99.0]; assigning the wrapped
     float2 (1200 inches, i.e. 100 ft) to it must raise ValueError.
     """
     self.hobj.float2 = 12.  # inches
     self.hobj.float1 = convert_units(self.hobj.float2,
                                      self.hobj.get_trait('float2').units,
                                      'ft')
     self.assertEqual(self.hobj.float1, 1.)  # 12 inches = 1 ft

     # now set to a value that will violate constraint after conversion
     self.hobj.float2 = 1200.  # inches
     try:
         self.hobj.float1 = self.hobj.get_wrapped_attr('float2')
     except ValueError as err:
         self.assertEqual(str(err),
             ": Variable 'float1' must be a float in the range [0.0, 99.0], but a value of 100.0 <type 'float'> was specified.")
     else:
         # Without this branch the test silently passed when the
         # constraint was not enforced.
         self.fail('ValueError expected')
コード例 #9
0
    def test_assignment(self):
        """arr1 starts at [1, 2, 3], defaults to [98.9], converts to inches."""
        obj = self.hobj

        initial_ok = all(array([1., 2., 3.]) == obj.arr1)
        self.assertTrue(initial_ok)

        arr1_trait = obj.get_trait('arr1')
        self.assertEqual([98.9], arr1_trait.trait_type.default_value)

        # Convert arr1 from its declared units into inches and store in arr2.
        source_values = obj.arr1
        source_units = obj.get_trait('arr1').units
        obj.arr2 = convert_units(source_values, source_units, 'inch')

        for position, inches in zip((0, 1, 2), (12., 24., 36.)):
            self.assertAlmostEqual(inches, obj.arr2[position], 5)
コード例 #10
0
ファイル: test_float.py プロジェクト: cephdon/meta-core
    def test_unit_conversion(self):
        """Convert inches to feet, then check an out-of-range value is rejected.

        float1 only accepts values in [0.0, 99.0]; assigning the wrapped
        float2 (1200 inches, i.e. 100 ft) to it must raise ValueError.
        """
        self.hobj.float2 = 12.  # inches
        self.hobj.float1 = convert_units(self.hobj.float2,
                                         self.hobj.get_trait('float2').units,
                                         'ft')
        self.assertEqual(self.hobj.float1, 1.)  # 12 inches = 1 ft

        # now set to a value that will violate constraint after conversion
        self.hobj.float2 = 1200.  # inches
        try:
            self.hobj.float1 = self.hobj.get_wrapped_attr('float2')
        except ValueError as err:
            self.assertEqual(
                str(err),
                ": Variable 'float1' must be a float in the range [0.0, 99.0], but a value of 100.0 <type 'float'> was specified."
            )
        else:
            # Without this branch the test silently passed when the
            # constraint was not enforced.
            self.fail('ValueError expected')
コード例 #11
0
def _d_conn(expr_txt, target, ascope):
    """ Evaluates the derivative of a source-target variable connection.
    This includes the derivative of the expression as well as the
    derivative of the unit conversion factor."""

    expr = ascope._exprmapper.get_expr(expr_txt)
    source = expr.refs().pop()

    # Need derivative of the expression
    expr_deriv = expr.evaluate_gradient(scope=ascope, wrt=source)

    # We also need the derivative of the unit
    # conversion factor if there is one
    metadata = expr.get_metadata('units')
    source_unit = [x[1] for x in metadata if x[0] == source]
    if source_unit and source_unit[0]:
        dest_expr = ascope._exprmapper.get_expr(target)
        metadata = dest_expr.get_metadata('units')
        target_unit = [x[1] for x in metadata if x[0] == target]

        expr_deriv[source] = expr_deriv[source] * \
            convert_units(1.0, source_unit[0], target_unit[0])

    return source, expr_deriv
コード例 #12
0
ファイル: chain_rule.py プロジェクト: RubenvdBerg/SatToolAG
    def _recurse_assy(self, scope, upscope_derivs, upscope_param):
        """Enables assembly recursion by scope translation."""

        # Find all assembly boundary connections, and propagate
        # derivatives through the expressions.
        local_derivs = {}
        name = scope.name

        for item in scope._depgraph.var_edges('@xin'):
            src_expr = item[0].replace('@xin.', '')
            expr = scope._exprmapper.get_expr(src_expr)
            src = expr.refs().pop()
            upscope_src = src.replace('parent.', '')

            # Real connections on boundary
            dest = item[1]
            dest = dest.split('.')[1]
            dest_txt = dest.replace('@bin.', '')

            # Differentiate all expressions
            expr_deriv = expr.evaluate_gradient(scope=scope, wrt=src)

            # We also need the derivative of the unit
            # conversion factor if there is one
            metadata = expr.get_metadata('units')
            source_unit = [x[1] for x in metadata if x[0] == src]
            if source_unit and source_unit[0]:
                dest_expr = scope._exprmapper.get_expr(dest_txt)
                metadata = dest_expr.get_metadata('units')
                target_unit = [x[1] for x in metadata if x[0] == dest_txt]

                expr_deriv[src] = expr_deriv[src] * \
                           convert_units(1.0, source_unit[0], target_unit[0])

            if dest in local_derivs:
                local_derivs[dest] += \
                    upscope_derivs[upscope_src]*expr_deriv[src]
            else:
                local_derivs[dest] = \
                    upscope_derivs[upscope_src]*expr_deriv[src]

        param = upscope_param.split('.')
        if param[0] == name:
            param = param[1:].join('.')
        else:
            param = ''

        # Find derivatives for this assembly's workflow
        self._chain_workflow(local_derivs, scope.driver, param)

        # Convert scope and return gradient of connected components.
        for item in scope._depgraph.var_in_edges('@bout'):
            src = item[0]
            upscope_src = '%s.%s' % (name, src)
            dest = item[1]

            # Real connections on boundary need expressions differentiated
            if dest.count('.') < 2:

                upscope_dest = dest.replace('@bout', name)
                dest = dest.replace('@bout.', '')

                expr_txt = scope._depgraph.get_source(dest)
                expr = scope._exprmapper.get_expr(expr_txt)
                expr_deriv = expr.evaluate_gradient(scope=scope, wrt=src)

                upscope_derivs[
                    upscope_dest] = local_derivs[src] * expr_deriv[src]

            # Fake connection, so just add source
            else:
                upscope_derivs[upscope_src] = local_derivs[src]

        # Finally, stuff in all our extra unconnected outputs because they may
        # be referenced by an objective or constraint at the outer scope.
        for key, value in local_derivs.iteritems():

            if key[0] == '@':
                continue

            target = '%s.%s' % (name, key)
            if target not in upscope_derivs:
                upscope_derivs[target] = value
コード例 #13
0
    def _chain_workflow(self, derivs, scope, param):
        """Process a workflow calculating all intermediate derivatives
        using the chain rule. This can be called recursively to handle
        nested assemblies.

        Args:
            derivs: dict mapping variable path names to their derivative with
                respect to ``param``. Updated in place: every output of every
                processed component gets an entry.
            scope: the driver whose workflow is processed. Its
                ``get_pathname()`` keys ``self.edge_dicts``,
                ``self.dworkflow`` and ``self.fdhelpers``.
            param: path name of the parameter currently being
                differentiated against.

        Raises:
            NotImplementedError: for nested drivers, including assemblies
                whose driver is not a ``Run_Once``.
        """
        # Figure out what outputs we need
        scope_name = scope.get_pathname()
        if scope_name not in self.edge_dicts:
            self._find_edges(scope, scope)

        # Loop through each comp in the workflow
        for node_names in self.dworkflow[scope_name]:

            # A list is a set of components to finite difference together;
            # anything else names a single component.
            fdblock = isinstance(node_names, list)
            if not fdblock:
                node = scope.parent.get(node_names)
                node_names = [node_names]

            # Finite difference block
            if fdblock:

                fd = self.fdhelpers[scope_name][str(node_names)]

                input_dict = {}
                for item in fd.list_wrt():
                    input_dict[item] = scope.parent.get(item)

                output_dict = {}
                for item in fd.list_outs():
                    output_dict[item] = scope.parent.get(item)

                local_derivs = fd.run(input_dict, output_dict)

            # We don't handle nested drivers.
            elif isinstance(node, Driver):
                raise NotImplementedError('Nested drivers')

            # Recurse into assemblies.
            elif isinstance(node, Assembly):

                if not isinstance(node.driver, Run_Once):
                    raise NotImplementedError('Nested drivers')

                self._recurse_assy(node, derivs, param)
                continue

            # This component can determine its derivatives.
            elif hasattr(node, 'calculate_first_derivatives'):

                node.calc_derivatives(first=True)

                local_derivs = node.derivatives.first_derivatives

            # NOTE(review): a single component that is neither an Assembly nor
            # has calculate_first_derivatives falls through to the loop below
            # with local_derivs unset (first iteration) or still holding the
            # previous node's derivatives. Presumably _find_edges only places
            # such components in finite-difference blocks -- confirm.

            # The following executes for components with derivatives and
            # blocks that are finite-differenced
            for node_name in node_names:

                node = scope.parent.get(node_name)
                ascope = node.parent

                local_inputs = self.edge_dicts[scope_name][node_name][0]
                local_outputs = self.edge_dicts[scope_name][node_name][1]

                incoming_deriv_names = {}
                incoming_derivs = {}

                for input_name in local_inputs:

                    full_name = '.'.join([node_name, input_name])

                    # Inputs who are hooked directly to the current param
                    # (possibly via a grouped parameter)
                    if full_name == param or \
                       (full_name in self.grouped_param_names and
                        self.grouped_param_names[full_name] == param):

                        incoming_deriv_names[input_name] = full_name
                        incoming_derivs[full_name] = derivs[full_name]

                    # Do nothing for inputs connected to the other params
                    elif full_name in self.param_names:
                        pass

                    # Inputs who are connected to something with a derivative
                    else:

                        sources = ascope._depgraph.connections_to(full_name)
                        expr_txt = sources[0][0]
                        target = sources[0][1]

                        # Variables on an assembly boundary
                        if expr_txt.startswith('@bin'):
                            expr_txt = expr_txt.replace('@bin.', '')

                        # Need derivative of the expression (looked up once;
                        # the expression was previously fetched twice).
                        expr = ascope._exprmapper.get_expr(expr_txt)
                        source = expr.refs().pop()
                        expr_deriv = expr.evaluate_gradient(scope=ascope,
                                                            wrt=source)

                        # We also need the derivative of the unit
                        # conversion factor if there is one
                        metadata = expr.get_metadata('units')
                        source_unit = [x[1] for x in metadata
                                       if x[0] == source]
                        if source_unit and source_unit[0]:
                            dest_expr = ascope._exprmapper.get_expr(target)
                            metadata = dest_expr.get_metadata('units')
                            target_unit = [x[1] for x in metadata
                                           if x[0] == target]

                            expr_deriv[source] = expr_deriv[source] * \
                                convert_units(1.0, source_unit[0],
                                              target_unit[0])

                        # Store our derivatives to chain them
                        incoming_deriv_names[input_name] = full_name
                        contribution = derivs[source] * expr_deriv[source]
                        if full_name in incoming_derivs:
                            incoming_derivs[full_name] += contribution
                        else:
                            incoming_derivs[full_name] = contribution

                # CHAIN RULE
                # Propagate derivatives wrt parameter through current component
                for output_name in local_outputs:

                    full_output_name = '.'.join([node_name, output_name])
                    derivs[full_output_name] = 0.0

                    # Finite-difference results are keyed by full names.
                    local_out = full_output_name if fdblock else output_name

                    for input_name, full_input_name in \
                            incoming_deriv_names.items():

                        if fdblock:
                            local_in = "%s.%s" % (node_name, input_name)
                        else:
                            local_in = input_name

                        derivs[full_output_name] += \
                            local_derivs[local_out][local_in] * \
                            incoming_derivs[full_input_name]
コード例 #14
0
    def _chain_workflow(self, derivs, scope, param):
        """Process a workflow calculating all intermediate derivatives
        using the chain rule. This can be called recursively to handle
        nested assemblies.

        Args:
            derivs: dict mapping variable path names to their derivative with
                respect to ``param``; updated in place with an entry for
                every output of each component that provides derivatives.
            scope: the driver whose ``workflow`` is iterated.
            param: path name of the parameter being differentiated against.

        Raises:
            NotImplementedError: for nested drivers (including assemblies
                whose driver is not a ``Run_Once``) and for components that
                do not provide ``calculate_first_derivatives``.
        """

        # Loop through each comp in the workflow
        for node in scope.workflow.__iter__():
            
            node_name = node.name
            #print "processing ", node_name
    
            # Per-node maps: local input name -> full input name, and
            # full input name -> derivative of that input wrt param.
            incoming_deriv_names = {}
            incoming_derivs = {}
            
            # We don't handle nested drivers yet.
            if isinstance(node, Driver):
                raise NotImplementedError('Nested drivers')
            
            # Recurse into assemblies. (Nothing else runs for this node: the
            # chain-rule code below lives inside the next elif branch.)
            elif isinstance(node, Assembly):
                
                if not isinstance(node.driver, Run_Once):
                    raise NotImplementedError('Nested drivers')
                
                self._recurse_assy(node, derivs, param)
                                     
            # This component can determine its derivatives.
            elif hasattr(node, 'calculate_first_derivatives'):
                
                node.calc_derivatives(first=True)
                
                local_inputs = node.derivatives.in_names
                local_outputs = node.derivatives.out_names
                local_derivs = node.derivatives.first_derivatives
                
                for input_name in local_inputs:
                    
                    full_name = '.'.join([node_name, input_name])

                    # Inputs who are hooked directly to the parameters
                    if full_name == param and \
                            full_name in derivs:
                            
                        incoming_deriv_names[input_name] = full_name
                        incoming_derivs[full_name] = derivs[full_name]
                        
                    # Inputs who are connected to something with a derivative
                    else:
                
                        sources = node.parent._depgraph.connections_to(full_name)
                        
                        # This list keeps track of duplicated sources, which
                        # are a byproduct of a connection to multiple inputs
                        # across a fake boundary node.
                        used_sources = []
                        
                        # Every connection into this input may contribute;
                        # contributions are summed into incoming_derivs.
                        for source_tuple in sources:
                            
                            source = source_tuple[0]
                            expr_txt = node.parent._depgraph.get_source(source_tuple[1])
                            
                            # Variables on an assembly boundary
                            if source[0:4] == '@bin' and source.count('.') < 2:
                                source = source.replace('@bin.', '')
                            
                            # Only process inputs who are connected to outputs
                            # with derivatives in the chain
                            if expr_txt and source in derivs and \
                               source not in used_sources:
                                
                                # Need derivative of the expression
                                expr = node.parent._exprmapper.get_expr(expr_txt)
                                expr_deriv = expr.evaluate_gradient(scope=node.parent,
                                                                    wrt=source)
                                
                                # We also need the derivative of the unit
                                # conversion factor if there is one
                                metadata = expr.get_metadata('units')
                                source_unit = [x[1] for x in metadata if x[0]==source]
                                if source_unit and source_unit[0]:
                                    dest_expr = node.parent._exprmapper.get_expr(source_tuple[1])
                                    metadata = dest_expr.get_metadata('units')
                                    target_unit = [x[1] for x in metadata if x[0]==source_tuple[1]]

                                    # Scale by convert_units(1.0, src, dest).
                                    expr_deriv[source] = expr_deriv[source] * \
                                        convert_units(1.0, source_unit[0], target_unit[0])

                                incoming_deriv_names[input_name] = full_name
                                if full_name in incoming_derivs:
                                    incoming_derivs[full_name] += derivs[source] * \
                                        expr_deriv[source]
                                else:
                                    incoming_derivs[full_name] = derivs[source] * \
                                        expr_deriv[source]
                                    
                                used_sources.append(source)
                        
                            
                # CHAIN RULE
                # Propagate derivatives wrt parameter through current component:
                # d(out)/d(param) = sum over inputs of
                # d(out)/d(in) * d(in)/d(param).
                for output_name in local_outputs:
                    
                    full_output_name = '.'.join([node_name, output_name])
                    derivs[full_output_name] = 0.0
                    
                    for input_name, full_input_name in incoming_deriv_names.iteritems():
                        derivs[full_output_name] += \
                            local_derivs[output_name][input_name] * \
                            incoming_derivs[full_input_name]
                            
            # This component must be finite differenced.
            else:
                raise NotImplementedError('CRND cannot Finite Difference subblocks yet.')
コード例 #15
0
ファイル: chain_rule.py プロジェクト: hitej/meta-core
    def _chain_workflow(self, derivs, scope, param):
        """Process a workflow calculating all intermediate derivatives
        using the chain rule. This can be called recursively to handle
        nested assemblies.

        Args:
            derivs: dict mapping variable path names to their derivative with
                respect to ``param``. Updated in place: every output of every
                processed component gets an entry.
            scope: the driver whose workflow is processed. Its
                ``get_pathname()`` keys ``self.edge_dicts``,
                ``self.dworkflow`` and ``self.fdhelpers``.
            param: path name of the parameter currently being
                differentiated against.

        Raises:
            NotImplementedError: for nested drivers, including assemblies
                whose driver is not a ``Run_Once``.
        """

        # Figure out what outputs we need
        scope_name = scope.get_pathname()
        if scope_name not in self.edge_dicts:
            self._find_edges(scope, scope)

        # Loop through each comp in the workflow
        for node_names in self.dworkflow[scope_name]:

            # A list is a set of components to finite difference together;
            # anything else names a single component.
            fdblock = isinstance(node_names, list)
            if not fdblock:
                node = scope.parent.get(node_names)
                node_names = [node_names]

            # Finite difference block
            if fdblock:

                fd = self.fdhelpers[scope_name][str(node_names)]

                input_dict = {}
                for item in fd.list_wrt():
                    input_dict[item] = scope.parent.get(item)

                output_dict = {}
                for item in fd.list_outs():
                    output_dict[item] = scope.parent.get(item)

                local_derivs = fd.run(input_dict, output_dict)

            # We don't handle nested drivers.
            elif isinstance(node, Driver):
                raise NotImplementedError("Nested drivers")

            # Recurse into assemblies.
            elif isinstance(node, Assembly):

                if not isinstance(node.driver, Run_Once):
                    raise NotImplementedError("Nested drivers")

                self._recurse_assy(node, derivs, param)
                continue

            # This component can determine its derivatives.
            elif hasattr(node, "calculate_first_derivatives"):

                node.calc_derivatives(first=True)

                local_derivs = node.derivatives.first_derivatives

            # NOTE(review): a single component that is neither an Assembly nor
            # has calculate_first_derivatives falls through to the loop below
            # with local_derivs unset (first iteration) or still holding the
            # previous node's derivatives. Presumably _find_edges only places
            # such components in finite-difference blocks -- confirm.

            # The following executes for components with derivatives and
            # blocks that are finite-differenced
            for node_name in node_names:

                node = scope.parent.get(node_name)
                ascope = node.parent

                local_inputs = self.edge_dicts[scope_name][node_name][0]
                local_outputs = self.edge_dicts[scope_name][node_name][1]

                incoming_deriv_names = {}
                incoming_derivs = {}

                for input_name in local_inputs:

                    full_name = ".".join([node_name, input_name])

                    # Inputs who are hooked directly to the current param
                    # (possibly via a grouped parameter)
                    if full_name == param or (
                        full_name in self.grouped_param_names and self.grouped_param_names[full_name] == param
                    ):

                        incoming_deriv_names[input_name] = full_name
                        incoming_derivs[full_name] = derivs[full_name]

                    # Do nothing for inputs connected to the other params
                    elif full_name in self.param_names:
                        pass

                    # Inputs who are connected to something with a derivative
                    else:

                        sources = ascope._depgraph.connections_to(full_name)
                        expr_txt = sources[0][0]
                        target = sources[0][1]

                        # Variables on an assembly boundary
                        if expr_txt.startswith("@bin"):
                            expr_txt = expr_txt.replace("@bin.", "")

                        # Need derivative of the expression (looked up once;
                        # the expression was previously fetched twice).
                        expr = ascope._exprmapper.get_expr(expr_txt)
                        source = expr.refs().pop()
                        expr_deriv = expr.evaluate_gradient(scope=ascope, wrt=source)

                        # We also need the derivative of the unit
                        # conversion factor if there is one
                        metadata = expr.get_metadata("units")
                        source_unit = [x[1] for x in metadata if x[0] == source]
                        if source_unit and source_unit[0]:
                            dest_expr = ascope._exprmapper.get_expr(target)
                            metadata = dest_expr.get_metadata("units")
                            target_unit = [x[1] for x in metadata if x[0] == target]

                            expr_deriv[source] = expr_deriv[source] * convert_units(1.0, source_unit[0], target_unit[0])

                        # Store our derivatives to chain them
                        incoming_deriv_names[input_name] = full_name
                        contribution = derivs[source] * expr_deriv[source]
                        if full_name in incoming_derivs:
                            incoming_derivs[full_name] += contribution
                        else:
                            incoming_derivs[full_name] = contribution

                # CHAIN RULE
                # Propagate derivatives wrt parameter through current component
                for output_name in local_outputs:

                    full_output_name = ".".join([node_name, output_name])
                    derivs[full_output_name] = 0.0

                    # Finite-difference results are keyed by full names.
                    local_out = full_output_name if fdblock else output_name

                    for input_name, full_input_name in incoming_deriv_names.items():

                        if fdblock:
                            local_in = "%s.%s" % (node_name, input_name)
                        else:
                            local_in = input_name

                        derivs[full_output_name] += local_derivs[local_out][local_in] * incoming_derivs[full_input_name]
コード例 #16
0
ファイル: chain_rule.py プロジェクト: hitej/meta-core
    def _recurse_assy(self, scope, upscope_derivs, upscope_param):
        """Enables assembly recursion by scope translation."""

        # Find all assembly boundary connections, and propagate
        # derivatives through the expressions.
        local_derivs = {}
        name = scope.name

        for item in scope._depgraph.var_edges("@xin"):
            src_expr = item[0].replace("@xin.", "")
            expr = scope._exprmapper.get_expr(src_expr)
            src = expr.refs().pop()
            upscope_src = src.replace("parent.", "")

            # Real connections on boundary
            dest = item[1]
            dest = dest.split(".")[1]
            dest_txt = dest.replace("@bin.", "")

            # Differentiate all expressions
            expr_deriv = expr.evaluate_gradient(scope=scope, wrt=src)

            # We also need the derivative of the unit
            # conversion factor if there is one
            metadata = expr.get_metadata("units")
            source_unit = [x[1] for x in metadata if x[0] == src]
            if source_unit and source_unit[0]:
                dest_expr = scope._exprmapper.get_expr(dest_txt)
                metadata = dest_expr.get_metadata("units")
                target_unit = [x[1] for x in metadata if x[0] == dest_txt]

                expr_deriv[src] = expr_deriv[src] * convert_units(1.0, source_unit[0], target_unit[0])

            if dest in local_derivs:
                local_derivs[dest] += upscope_derivs[upscope_src] * expr_deriv[src]
            else:
                local_derivs[dest] = upscope_derivs[upscope_src] * expr_deriv[src]

        param = upscope_param.split(".")
        if param[0] == name:
            param = param[1:].join(".")
        else:
            param = ""

        # Find derivatives for this assembly's workflow
        self._chain_workflow(local_derivs, scope.driver, param)

        # Convert scope and return gradient of connected components.
        for item in scope._depgraph.var_in_edges("@bout"):
            src = item[0]
            upscope_src = "%s.%s" % (name, src)
            dest = item[1]

            # Real connections on boundary need expressions differentiated
            if dest.count(".") < 2:

                upscope_dest = dest.replace("@bout", name)
                dest = dest.replace("@bout.", "")

                expr_txt = scope._depgraph.get_source(dest)
                expr = scope._exprmapper.get_expr(expr_txt)
                expr_deriv = expr.evaluate_gradient(scope=scope, wrt=src)

                upscope_derivs[upscope_dest] = local_derivs[src] * expr_deriv[src]

            # Fake connection, so just add source
            else:
                upscope_derivs[upscope_src] = local_derivs[src]

        # Finally, stuff in all our extra unconnected outputs because they may
        # be referenced by an objective or constraint at the outer scope.
        for key, value in local_derivs.iteritems():

            if key[0] == "@":
                continue

            target = "%s.%s" % (name, key)
            if target not in upscope_derivs:
                upscope_derivs[target] = value
コード例 #17
0
ファイル: chain_rule.py プロジェクト: RubenvdBerg/SatToolAG
    def _chain_workflow(self, derivs, scope, param):
        """Process a workflow calculating all intermediate derivatives
        using the chain rule. This can be called recursively to handle
        nested assemblies.

        Components in ``scope.workflow`` are visited in workflow order.
        For every component that provides analytic first derivatives, the
        derivatives of its inputs with respect to the current parameter are
        gathered, multiplied through the component's local derivatives, and
        the resulting output derivatives are written back into ``derivs``.

        derivs: dict
            Maps full variable names to their derivative with respect to
            the current parameter. Read for incoming values and updated in
            place with the derivatives of each processed component's outputs.

        scope: object
            Owner of the workflow to process (callers in this file pass an
            assembly's ``driver``). Must provide ``get_pathname()`` and
            ``workflow``.

        param: str
            Full dotted name of the parameter input currently being
            differentiated. An input whose full name equals ``param`` takes
            its derivative directly from ``derivs[param]``.

        Raises NotImplementedError for a Driver in the workflow, for an
        assembly whose driver is not ``Run_Once``, and for components
        without ``calculate_first_derivatives`` (finite differencing of
        sub-blocks is not supported).
        """

        # Figure out what outputs we need (edge lists are cached per scope).
        scope_name = scope.get_pathname()
        if scope_name not in self.edge_dicts:
            self._find_edges(scope, scope)

        # Loop through each comp in the workflow
        for node in scope.workflow:

            node_name = node.name

            # We don't handle nested drivers yet.
            if isinstance(node, Driver):
                raise NotImplementedError('Nested drivers')

            # Recurse into assemblies.
            elif isinstance(node, Assembly):

                if not isinstance(node.driver, Run_Once):
                    raise NotImplementedError('Nested drivers')

                self._recurse_assy(node, derivs, param)

            # This component can determine its derivatives.
            elif hasattr(node, 'calculate_first_derivatives'):

                ascope = node.parent

                # local input name -> full input name, and
                # full input name -> d(input)/d(param)
                incoming_deriv_names = {}
                incoming_derivs = {}

                node.calc_derivatives(first=True)

                local_inputs = self.edge_dicts[scope_name][node_name][0]
                local_outputs = self.edge_dicts[scope_name][node_name][1]
                local_derivs = node.derivatives.first_derivatives

                for input_name in local_inputs:

                    full_name = '.'.join([node_name, input_name])

                    # Inputs who are hooked directly to the current param
                    if full_name == param:

                        incoming_deriv_names[input_name] = full_name
                        incoming_derivs[full_name] = derivs[full_name]

                    # Do nothing for inputs connected to the other params
                    elif full_name in self.param_names:
                        pass

                    # Inputs who are connected to something with a derivative
                    else:

                        sources = ascope._depgraph.connections_to(full_name)
                        expr_txt = sources[0][0]
                        target = sources[0][1]

                        # Connections from the assembly boundary carry an
                        # '@bin.' prefix that is stripped before the
                        # expression lookup.
                        if expr_txt.startswith('@bin'):
                            expr_txt = expr_txt.replace('@bin.', '')

                        expr = ascope._exprmapper.get_expr(expr_txt)
                        source = expr.refs().pop()

                        # Need derivative of the expression
                        expr_deriv = expr.evaluate_gradient(scope=ascope,
                                                            wrt=source)

                        # We also need the derivative of the unit
                        # conversion factor if there is one
                        metadata = expr.get_metadata('units')
                        source_unit = [x[1] for x in metadata
                                       if x[0] == source]
                        if source_unit and source_unit[0]:
                            dest_expr = ascope._exprmapper.get_expr(target)
                            metadata = dest_expr.get_metadata('units')
                            target_unit = [x[1] for x in metadata
                                           if x[0] == target]

                            # NOTE(review): target_unit[0] raises IndexError
                            # if the target has no 'units' metadata entry;
                            # presumably connections guarantee one -- confirm.
                            expr_deriv[source] = expr_deriv[source] * \
                                convert_units(1.0, source_unit[0],
                                              target_unit[0])

                        # Store our derivatives to chain them
                        incoming_deriv_names[input_name] = full_name
                        term = derivs[source] * expr_deriv[source]
                        if full_name in incoming_derivs:
                            incoming_derivs[full_name] += term
                        else:
                            incoming_derivs[full_name] = term

                # CHAIN RULE
                # Propagate derivatives wrt parameter through current
                # component: d(out)/d(param) = sum_i d(out)/d(in_i) *
                # d(in_i)/d(param). The 0.0 start value matches an output
                # with no contributing inputs.
                for output_name in local_outputs:

                    full_output_name = '.'.join([node_name, output_name])
                    derivs[full_output_name] = sum(
                        (local_derivs[output_name][input_name] *
                         incoming_derivs[full_input_name]
                         for input_name, full_input_name in
                         incoming_deriv_names.iteritems()),
                        0.0)

            # This component must be finite differenced.
            else:
                msg = 'CRND cannot Finite Difference subblocks yet.'
                raise NotImplementedError(msg)