예제 #1
0
def UFLFunction(grid, name, order, expr, renumbering=None, virtualize=True, tempVars=True,
                predefined=None, **kwargs):
    """Generate, compile and return a local-function class for a UFL expression.

    ``expr`` may be a UFL expression, a list/tuple of components (turned into
    a vector with ``ufl.as_vector``) or a plain ``int``/``float``. Scalar
    input is wrapped into a one-component vector, and the returned class
    records ``scalar = True``.

    Args:
        grid: grid (view) handed to ``UFLFunctionSource``.
        name: name handed to ``UFLFunctionSource``.
        order: order handed to ``UFLFunctionSource``.
        expr: the expression to compile (see above).
        renumbering: optional mapping from coefficients/constants to their
            index in the generated code. It is computed from the source's
            coefficient list when omitted.
        virtualize, tempVars, predefined: forwarded to ``UFLFunctionSource``.
        **kwargs: accepted for interface compatibility; not used here.

    Returns:
        A subclass of the generated ``UFLLocalFunction`` type, or ``None``
        when ``expr`` has no ``ufl_shape`` (i.e. is not a UFL expression).

    Raises:
        AttributeError: if the expression has rank greater than one.
    """
    scalar = False
    if isinstance(expr, (list, tuple)):
        expr = ufl.as_vector(expr)
    elif type(expr) in (int, float):
        # exact type test on purpose: bool (an int subclass) is not converted
        # here and is rejected by the shape check below
        expr = ufl.as_vector([expr])
        scalar = True

    # Anything without a UFL shape is not an expression we can compile.
    shape = getattr(expr, "ufl_shape", None)
    if shape is None:
        return None
    if shape == ():
        expr = ufl.as_vector([expr])
        scalar = True

    # Swap every scalar, non-cellwise-constant coefficient for the object
    # returned by its toVectorCoefficient()[0].
    _, coeff_ = ufl.algorithms.analysis.extract_arguments_and_coefficients(expr)
    coeff = {c: c.toVectorCoefficient()[0] for c in coeff_
             if len(c.ufl_shape) == 0 and not c.is_cellwise_constant()}
    expr = replace(expr, coeff)

    if len(expr.ufl_shape) > 1:
        raise AttributeError("can only generate grid functions from vector values UFL "
                             "expressions not from expressions with shape="
                             + str(expr.ufl_shape))

    # set up the source class
    source = UFLFunctionSource(grid, expr,
                               name, order,
                               tempVars=tempVars, virtualize=virtualize,
                               predefined=predefined)

    coefficients = source.coefficientList
    if renumbering is None:
        # non-constant coefficients are numbered in order of c.count(),
        # cellwise constant ones in their order within the coefficient list
        nonConstant = sorted((c for c in coefficients if not c.is_cellwise_constant()),
                             key=lambda c: c.count())
        renumbering = {c: i for i, c in enumerate(nonConstant)}
        renumbering.update((c, i) for i, c in enumerate(
            c for c in coefficients if c.is_cellwise_constant()))

    # call code generator
    from dune.generator import builder
    module = builder.load(source.name(), source, "UFLLocalFunction")

    assert hasattr(module, "UFLLocalFunction"), \
        "GridViews of coefficients need to be compatible with the grid view of the ufl local functions"

    class LocalFunction(module.UFLLocalFunction):
        def __init__(self, gridView, name, order, *args, **kwargs):
            self.base = module.UFLLocalFunction
            self._coefficientNames = {n: i for i, n in enumerate(source.coefficientNames)}
            if renumbering is not None:
                self._renumbering = renumbering
                # keep the compiled setConstant reachable and route calls
                # through the module-level setConstant helper
                self._setConstant = self.setConstant  # module.UFLLocalFunction.__dict__['setConstant']
                self.setConstant = lambda *args: setConstant(self, *args)
            self.constantShape = source._constantShapes
            self._constants = [c for c in source.constantList if isinstance(c, Constant)]
            self.scalar = scalar
            init(self, gridView, name, order, *args, **kwargs)

    return LocalFunction
예제 #2
0
def function(gv, callback, includeFiles=None, *args, name=None, order=None, dimRange=None):
    """Create a grid function on the grid view ``gv``.

    Two kinds of ``callback`` are supported:

    * A string naming a C++ function template. A small pybind11 module is
      generated that evaluates ``callback<GridView>(args...)`` and registers
      the result as a grid function. ``includeFiles`` must then provide the
      code defining it. It may be a file name, a list/tuple of file names,
      or a readable stream whose content is pasted into the generated
      source. A file name without a directory part is read and pasted
      verbatim; one with a directory part is emitted as ``#include <...>``.
    * A Python callable taking either one argument (a global coordinate) or
      two arguments ``(element, x)``. One-argument callables are wrapped so
      that the grid function evaluates them at
      ``element.geometry.toGlobal(x)``. Calling the result with a single
      argument forwards it to the original callable.

    Args:
        gv: the grid view.
        callback: C++ function name or Python callable (see above).
        includeFiles: required when ``callback`` is a string (see above).
        *args: extra arguments passed to the C++ function; their C++ types
            are obtained with ``cppType``.
        name: name of the grid function; defaults to ``"tmp<N>"`` using a
            counter stored on the grid-view class.
        order: stored on the returned object.
        dimRange: Python callbacks only. If ``None``, it is inferred by
            evaluating the callback on the first element; ``0`` means a
            scalar function.

    Returns:
        The grid function, with ``plot``, ``name`` and ``order`` set.

    Raises:
        ValueError: ``callback`` is a string but no ``includeFiles`` given.
        TypeError: ``dimRange`` could not be inferred because evaluating the
            callback raised ``ArithmeticError``.
    """
    if name is None:
        name = "tmp" + str(gv._gfCounter)
        gv.__class__._gfCounter += 1

    if isString(callback):
        if includeFiles is None:
            raise ValueError("""if `callback` is the name of a C++ function
            then at least one include file containing that function must be
            provided""")

        # unique header guard is added further down
        source = '#include <config.h>\n\n'
        source += '#define USING_DUNE_PYTHON 1\n\n'
        includes = []
        if isString(includeFiles):
            includeFiles = [includeFiles]
        if hasattr(includeFiles, "readable"):  # for IOString
            with includeFiles as include:
                source += include.read()
            source += "\n"
        elif isinstance(includeFiles, (list, tuple)):
            for includeFile in includeFiles:
                # bare file names are pasted into the source, anything with a
                # directory part is emitted as an #include line
                if not os.path.dirname(includeFile):
                    with open(includeFile, "r") as include:
                        source += include.read()
                    source += "\n"
                else:
                    source += "#include <" + includeFile + ">\n"
                    includes.append(includeFile)
        includes += gv.cppIncludes
        argTypes = []
        for arg in args:
            t, i = cppType(arg)
            argTypes.append(t)
            includes += i

        signature = callback + "( " + ", ".join(argTypes) + " )"
        moduleName = "gf_" + hashIt(signature) + "_" + hashIt(source)

        # add unique header guard with moduleName
        source = '#ifndef Guard_' + moduleName + '\n' + \
                 '#define Guard_' + moduleName + '\n\n' + \
                 source

        includes = sorted(set(includes))
        source += "".join(["#include <" + i + ">\n" for i in includes])
        source += "\n"
        source += '#include <dune/python/grid/function.hh>\n'
        source += '#include <dune/python/pybind11/pybind11.h>\n'
        source += '\n'

        argDecls = "".join(", " + t + " arg" + str(i) for i, t in enumerate(argTypes))
        argNames = ",".join("arg" + str(i) for i in range(len(argTypes)))
        # keep gv (1) and every extra argument (2..) alive with the result (0)
        keepAlive = ",".join("pybind11::keep_alive<0," + str(i + 1) + ">()"
                             for i in range(len(argTypes) + 1))
        source += "PYBIND11_MODULE( " + moduleName + ", module )\n"
        source += "{\n"
        source += "  module.def( \"gf\", [module] ( " + gv.cppTypeName + " &gv" + argDecls + " ) {\n"
        source += "      auto callback=" + callback + "<" + gv.cppTypeName + ">( " + argNames + "); \n"
        source += "      return Dune::Python::registerGridFunction<" + gv.cppTypeName + ",decltype(callback)>(module,pybind11::cast(gv),\"tmp\",callback);\n"
        source += "    },"
        source += "    " + keepAlive
        source += ");\n"
        source += "}\n"
        source += "#endif\n"
        gf = builder.load(moduleName, source, signature).gf(gv, *args)
    else:
        if len(inspect.signature(callback).parameters) == 1:  # global function, turn into a local function
            callback_ = callback
            callback = lambda e, x: callback_(e.geometry.toGlobal(x))
        else:
            callback_ = None
        if dimRange is None:
            # if no `dimRange` attribute is set on the callback,
            # try to evaluate the function to determine the dimension of
            # the return value. This can fail if the function is singular in
            # the computational domain in which case an exception is raised
            e = next(iter(gv.elements))
            try:
                y = callback(e, e.referenceElement.position(0, 0))
            except ArithmeticError:
                try:
                    y = callback(e, e.referenceElement.position(0, 2))
                except ArithmeticError:
                    raise TypeError("can not determin dimension of range of " +
                                    "given grid function due to arithmetic exceptions being " +
                                    "raised. Add a `dimRange` parameter to the grid function to " +
                                    "solve this issue - set `dimRange`=0 for a scalar function.")
            try:
                dimRange = len(y)
            except TypeError:
                dimRange = 0
        scalar = "false" if dimRange > 0 else "true"
        FieldVector(dimRange * [0])  # register FieldVector for the return value

        # One compiled class per dimRange, shared by all grid functions of
        # this grid-view class - so it must not capture any callback.
        functions = gv.__class__._functions
        if dimRange not in functions:
            # unique header key is added further down
            source = '#include <config.h>\n\n'
            source += '#define USING_DUNE_PYTHON 1\n\n'
            includes = gv.cppIncludes

            signature = gv.cppTypeName + "::gf<" + str(dimRange) + ">"
            moduleName = "gf_" + hashIt(signature) + "_" + hashIt(source)

            # add unique header guard with moduleName
            source = '#ifndef Guard_' + moduleName + '\n' + \
                     '#define Guard_' + moduleName + '\n\n' + \
                     source

            includes = sorted(set(includes))
            source += "".join(["#include <" + i + ">\n" for i in includes])
            source += "\n"
            source += '#include <dune/python/grid/function.hh>\n'
            source += '#include <dune/python/pybind11/pybind11.h>\n'
            source += '\n'

            source += "PYBIND11_MODULE( " + moduleName + ", module )\n"
            source += "{\n"
            source += "  typedef pybind11::function Evaluate;\n"
            source += "  Dune::Python::registerGridFunction< " + gv.cppTypeName + ", Evaluate, " + str(dimRange) + " >( module, \"gf\", " + scalar + " );\n"
            source += "}\n"
            source += "#endif\n"
            gfModule = builder.load(moduleName, source, signature)
            functions[dimRange] = getattr(gfModule, "gf" + str(dimRange))

        gfClass = functions[dimRange]
        if callback_ is not None:
            # Build the global-evaluation wrapper per call so that it uses
            # *this* callback, not the one of whichever call filled the cache.
            baseCall = gfClass.__call__

            def localCall(self, e, x):
                return baseCall(self, e, x)

            def feval(self, e, x=None):
                # a single argument goes straight to the global callback
                return callback_(e) if x is None else baseCall(self, e, x)

            gfClass = type(gfClass.__name__, (gfClass,),
                           {"__call__": feval, "localCall": localCall})
        gf = gfClass(gv, callback)

    def gfPlot(gf, *args, **kwargs):
        gf.grid.plot(gf, *args, **kwargs)
    gf.plot = gfPlot.__get__(gf)
    gf.name = name
    gf.order = order
    return gf
예제 #3
0
def load(grid, form, *args, renumbering=None, tempVars=True,
        virtualize=True, modelPatch=(None, None),
        includes=None):
    """Generate, compile and return an ``Integrands`` model class for ``form``.

    Args:
        grid: grid (view) passed on to ``Source``.
        form: a ``ufl.Form``, a ``ufl.Equation`` (turned into ``lhs - rhs``)
            or an already constructed ``Integrands`` object.
        *args: forwarded to ``Source``. ``DirichletBC`` instances among them
            also contribute their ``ufl_value`` to the model's UFL
            expressions when ``form`` is a Form.
        renumbering: optional mapping of coefficients/constants to the index
            used in the generated C++ code. It is computed when omitted and
            ``form`` is a ``ufl.Form``.
        tempVars, virtualize: forwarded to ``Source`` (``virtualize`` also
            to ``Integrands``).
        modelPatch: a pair ``(patch, extraExpressions)`` or a single callable
            ``patch``. ``patch(integrands)`` is called before code
            generation. ``extraExpressions`` (a list) is appended to the
            model's UFL expressions when ``form`` is a Form.
        includes: forwarded to ``Source``.

    Returns:
        A subclass ``Model`` of the generated ``Integrands`` type. Its
        ``__init__`` calls ``init`` and registers the model with every
        constant of the integrands.

    Raises:
        ValueError: if ``form`` is neither an ``Integrands`` object nor a
            ``ufl.Form`` with exactly two arguments.
    """
    if not isinstance(modelPatch, (list, tuple)):
        modelPatch = [modelPatch, None]

    if isinstance(form, Equation):
        form = form.lhs - form.rhs

    if isinstance(form, Integrands):
        integrands = form
    else:
        # check the type first so that a non-Form fails with this message
        # rather than an AttributeError from form.arguments()
        if not isinstance(form, Form):
            raise ValueError("ufl.Form or ufl.Equation expected.")
        arguments = form.arguments()
        # the unpacking below needs exactly a test and a trial argument
        if len(arguments) != 2:
            raise ValueError("Integrands model requires form with exactly two arguments.")
        phi_, u_ = arguments

        # for scalar spaces use an argument on toVectorSpace() and substitute
        # its first component into the form
        if phi_.ufl_function_space().scalar:
            phi = TestFunction(phi_.ufl_function_space().toVectorSpace())
            form = replace(form, {phi_: phi[0]})
        else:
            phi = phi_
        if u_.ufl_function_space().scalar:
            u = TrialFunction(u_.ufl_function_space().toVectorSpace())
            form = replace(form, {u_: u[0]})
        else:
            u = u_

        _, coeff_ = extract_arguments_and_coefficients(form)
        coeff_ = set(coeff_)

        # added for dirichlet treatment same as conservationlaw model
        dirichletBCs = [a for a in args if isinstance(a, DirichletBC)]
        # remove the dirichletBCs
        # NOTE(review): `arg` (non-Dirichlet args plus the coefficient-replaced
        # boundary conditions) is built but never used - Source below still
        # receives the original `args`. Confirm which one is intended.
        arg = [a for a in args if not isinstance(a, DirichletBC)]
        for dBC in dirichletBCs:
            _, coeff__ = extract_arguments_and_coefficients(dBC.ufl_value)
            coeff_ |= set(coeff__)
        coeff = {c: c.toVectorCoefficient()[0] for c in coeff_
                 if len(c.ufl_shape) == 0 and not c.is_cellwise_constant()}

        form = replace(form, coeff)
        uflExpr = [form]
        for dBC in dirichletBCs:
            arg.append(dBC.replace(coeff))
            uflExpr.append(dBC.ufl_value)

        if modelPatch[1] is not None:
            uflExpr += modelPatch[1]

        derivatives = gatherDerivatives(form, [phi, u])

        derivatives_phi = derivatives[0]
        derivatives_u = derivatives[1]

        integrands = Integrands(u,
                                (d.ufl_shape for d in derivatives_u),
                                (d.ufl_shape for d in derivatives_phi),
                                uflExpr, virtualize)

    if modelPatch[0] is not None:
        modelPatch[0](integrands)

    # set up the source class
    source = Source(integrands, grid, includes, form, *args,
                    tempVars=tempVars, virtualize=virtualize)

    # ufl coefficients and constants only have numbers which depend on the
    # order in which they were generated - we need to keep track of how
    # these numbers are translated into the tuple numbering in the
    # generated C++ code
    if isinstance(form, Form):
        # de-duplicate while keeping list order; iterating a set would make
        # the numbering of the constants below depend on hash order
        coefficients = list(dict.fromkeys(integrands.coefficientList + integrands.constantList))
        if renumbering is None:
            nonConstant = sorted((c for c in coefficients if not c.is_cellwise_constant()),
                                 key=lambda c: c.count())
            renumbering = {c: i for i, c in enumerate(nonConstant)}
            renumbering.update((c, i) for i, c in enumerate(
                c for c in coefficients if c.is_cellwise_constant()))
        coefficientNames = integrands._coefficientNames
    else:
        coefficientNames = form.coefficientNames

    # call code generator
    from dune.generator import builder
    module = builder.load(source.name(), source, "Integrands")

    assert hasattr(module, "Integrands"), \
        "GridViews of coefficients need to be compatible with the grid view of the ufl model"

    rangeValueTuple, domainValueTuple = source.valueTuples()
    module.Integrands._domainValueType = domainValueTuple
    module.Integrands._rangeValueType = rangeValueTuple

    # redirect the __init__ method to take care of setting coefficient and renumbering
    class Model(module.Integrands):
        def __init__(self, *args, **kwargs):
            self.base = module.Integrands
            init(self, integrands, *args, **kwargs)
            for c in integrands.constantList:
                c.registerModel(self)

    Model._coefficientNames = {n: i for i, n in enumerate(coefficientNames)}
    if renumbering is not None:
        Model._renumbering = renumbering
        Model._setConstant = module.Integrands.__dict__['setConstant']
        Model.setConstant = setConstant

    return Model
예제 #4
0
def load(grid, model, *args, modelPatch=(None, None), virtualize=True, **kwargs):
    """Generate, compile and return the Python module of a conservation-law model.

    Args:
        grid: grid (view) whose ``cppTypeName``/``cppIncludes`` the generated
            code uses.
        model: a ``ufl.Equation``/``ufl.Form`` (compiled with
            ``compileUFL``), an already compiled model object, or the path of
            a source file previously written via the ``header`` keyword.
        *args: forwarded to ``compileUFL``.
        modelPatch: a pair ``(patch, extra)`` or a single callable ``patch``.
            ``extra`` is forwarded to ``compileUFL``, and ``patch(model)`` is
            called before code generation.
        virtualize: if true, the model is exported through
            ``model.modelWrapper`` with an abstract ``ModelBase`` base class;
            otherwise the model type itself is exported.
        **kwargs: forwarded to ``compileUFL``. Also read here:
            ``renumbering`` (used when ``model`` is not an Equation/Form),
            ``virtualModel`` (header of the C++ model interface, default
            ``dune/fem/schemes/conservationlawmodel.hh``) and ``header`` (a
            file the generated source is written to).

    Returns:
        The loaded module. Unless already done, ``Model.__init__`` is
        replaced by ``initModel`` when a renumbering is available.
    """
    if not isinstance(modelPatch, (list, tuple)):
        modelPatch = [modelPatch, None]

    from dune.generator import builder
    if isinstance(model, (Equation, Form)):
        model = compileUFL(model, modelPatch[1], *args, **kwargs)
        renumbering = model.coefficients.copy()
        renumbering.update(model.constants)
    else:
        renumbering = kwargs.get("renumbering")

    if isinstance(model, str):
        # the module name is read back from the PYBIND11_MODULE line of the file
        with open(model, 'r') as modelFile:
            data = modelFile.read()
        name = data.split('PYBIND11_MODULE( ')[1].split(',')[0]
        endPos = name.find('_')
        modelName = name[0:endPos]
        module = builder.load(name, data, modelName)
        # fall back to an empty renumbering only if the caller supplied none
        if renumbering is None:
            renumbering = {}
        module.Model._renumbering = renumbering
        # same guard as for generated models below: never wrap initModel in
        # itself if the returned Model class has already been patched
        if module.Model.__dict__['__init__'] != initModel:
            module.Model._init = module.Model.__dict__['__init__']
            module.Model.__init__ = initModel
        return module

    if modelPatch[0] is not None:
        modelPatch[0](model)
    else:
        modelPatch = None

    signature = ("" if virtualize else "nv") + model.signature + "_" + hashIt(grid.cppTypeName)
    name = model.baseName + '_' + signature

    writer = SourceWriter()

    writer.emit("#ifndef GuardModelImpl_" + signature)
    writer.emit("#define GuardModelImpl_" + signature)
    writer.emit("#define USING_DUNE_PYTHON 1")

    writer.emit('#include <config.h>')
    writer.emit(["#include <" + i + ">" for i in grid.cppIncludes])
    writer.emit('')
    writer.emit('#include <dune/fem/misc/boundaryidprovider.hh>')
    writer.emit('')
    writer.emit('#include <dune/python/pybind11/pybind11.h>')
    writer.emit('#include <dune/python/pybind11/extensions.h>')
    writer.emit('')
    writer.emit('#include <dune/fempy/py/grid/gridpart.hh>')
    if model.hasCoefficients:
        writer.emit('#include <dune/fempy/function/virtualizedgridfunction.hh>')
        writer.emit('')
    virtualModel = kwargs.pop('virtualModel', 'dune/fem/schemes/conservationlawmodel.hh')
    writer.emit('#include <' + virtualModel + '>')

    nameSpace = NameSpace("ModelImpl_" + signature)
    if modelPatch:
        nameSpace.append(model.code(model))
    else:
        nameSpace.append(model.code())

    writer.emit(nameSpace)

    writer.openNameSpace("ModelImpl_" + signature)
    gridPartType = "typename Dune::FemPy::GridPart< " + grid.cppTypeName + " >"
    rangeTypes = ["Dune::FieldVector< " +
                  SourceWriter.cpp_fields(c['field']) + ", " + str(c['dimRange']) + " >"
                  for c in model._coefficients]
    coefficients = ["Dune::FemPy::VirtualizedGridFunction<" + gridPartType + ", " + r + " >"
                    if not c['typeName'].startswith("Dune::Python::SimpleGridFunction")
                    else c['typeName']
                    for r, c in zip(rangeTypes, model._coefficients)]
    modelType = nameSpace.name + "::Model< " + ", ".join([gridPartType] + coefficients) + " >"
    if virtualize:
        wrapperType = model.modelWrapper.replace(" Model ", modelType)
    else:
        wrapperType = modelType
    if model.hasConstants:
        model.exportSetConstant(writer, modelClass=modelType, wrapperClass=wrapperType)

    writer.closeNameSpace("ModelImpl_" + signature)
    writer.openPythonModule(name)
    code = [TypeAlias("GridPart", gridPartType)]
    code += [TypeAlias("Model", nameSpace.name + "::Model< " + ", ".join(["GridPart"] + coefficients) + " >")]
    if virtualize:
        modelType = model.modelWrapper
        code += [TypeAlias("ModelWrapper", model.modelWrapper)]
        code += [TypeAlias("ModelBase", "typename ModelWrapper::Base")]
    else:
        # modelType still holds the full Model type computed above
        code += [TypeAlias("ModelWrapper", "Model")]
    writer.emit(code)

    if virtualize:
        writer.emit('// export abstract base class')
        writer.emit('if( !pybind11::already_registered< ModelBase >() )')
        writer.emit('  pybind11::class_< ModelBase >( module, "ModelBase" );')
        writer.emit('')
        writer.emit('// actual wrapper class for model derived from abstract base')
        writer.emit('auto cls = Dune::Python::insertClass<ModelWrapper,ModelBase>(module,"Model",' +
                    'Dune::Python::GenerateTypeName("' + modelType + '"),' +
                    'Dune::Python::IncludeFiles({})).first;')
    else:
        writer.emit('auto cls = Dune::Python::insertClass<ModelWrapper>(module,"Model",' +
                    'Dune::Python::GenerateTypeName("' + modelType + '"),' +
                    'Dune::Python::IncludeFiles({"python/dune/generated/' + name + '.cc"})).first;')
    writer.emit('cls.def_property_readonly( "dimRange", [] ( ModelWrapper & ) { return ' + str(model.dimRange) + '; } );')
    hasDirichletBC = 'true' if model.hasDirichletBoundary else 'false'
    writer.emit('cls.def_property_readonly( "hasDirichletBoundary", [] ( ModelWrapper& ) -> bool { return ' + hasDirichletBC + ';});')
    writer.emit('')
    for n, number in model._constantNames.items():
        writer.emit('cls.def_property( "' + n + '", ' +
                    '[] ( ModelWrapper &self ) { return self.template constant<' + str(number) + '>(); }, ' +
                    '[] ( ModelWrapper &self, typename ModelWrapper::ConstantType<' + str(number) + '>& value) { self.template constant<' + str(number) + '>() = value; }' +
                    ');')
    writer.emit('')

    model.export(writer, 'Model', 'ModelWrapper', nameSpace="ModelImpl_" + signature)
    writer.closePythonModule(name)
    writer.emit("#endif // GuardModelImpl_" + signature)

    source = writer.writer.getvalue()
    writer.close()

    if "header" in kwargs:
        with open(kwargs["header"], 'w') as modelFile:
            modelFile.write(source)

    endPos = name.find('_')
    modelName = name[0:endPos]
    module = builder.load(name, source, modelName)
    # only patch once: the returned Model class may already use initModel
    if (renumbering is not None) and (module.Model.__dict__['__init__'] != initModel):
        module.Model._renumbering = renumbering
        module.Model._coefficientNames = {c['name']: i for i, c in enumerate(model._coefficients)}
        module.Model._init = module.Model.__dict__['__init__']
        module.Model.__init__ = initModel
    return module