示例#1
0
 def __init__(self, inputZincModelFile, inputZincDataFile, location,
              identifier):
     """
     Build the fitter model. This creates the Fitter, sets up graphics
     modules and installs default display settings. It then calls
     _loadSettings() and loads the fitter.
     :param inputZincModelFile: Zinc model file passed to Fitter.
     :param inputZincDataFile: Zinc data file passed to Fitter.
     :param location: Path to folder for mapclient step name.
     :param identifier: Step identifier; self._location is location/identifier.
     """
     self._fitter = Fitter(inputZincModelFile, inputZincDataFile)
     # For diagnostic output, call self._fitter.setDiagnosticLevel(1) here.
     self._identifier = identifier
     self._location = os.path.join(location, self._identifier)
     self._initGraphicsModules()
     # Default display settings; presumably overridden by _loadSettings().
     self._settings = dict(
         # axes and markers
         displayAxes=True,
         displayMarkerDataPoints=True,
         displayMarkerDataNames=False,
         displayMarkerDataProjections=True,
         displayMarkerPoints=True,
         displayMarkerNames=False,
         # data points
         displayDataPoints=True,
         displayDataProjections=True,
         displayDataProjectionPoints=True,
         # nodes
         displayNodePoints=False,
         displayNodeNumbers=False,
         displayNodeDerivatives=False,
         displayNodeDerivativeLabels=nodeDerivativeLabels[:3],
         # elements, lines and surfaces
         displayElementNumbers=False,
         displayElementAxes=False,
         displayLines=True,
         displayLinesExterior=False,
         displaySurfaces=True,
         displaySurfacesExterior=True,
         displaySurfacesTranslucent=True,
         displaySurfacesWireframe=False,
     )
     self._loadSettings()
     self._fitter.load()
示例#2
0
 def __init__(self, fitter: Fitter):
     """
     Create an align step with identity transformation: zero rotation,
     unit scale and zero translation.
     Marker alignment is switched on only when the fitter supplies every
     marker model field and every marker data field.
     :param fitter: Owning Fitter object, passed to the base class.
     """
     super().__init__(fitter)
     markerNodeGroup, markerLocation, markerCoordinates, markerName = \
         fitter.getMarkerModelFields()
     markerDataGroup, markerDataCoordinates, markerDataName = \
         fitter.getMarkerDataFields()
     # marker alignment needs all marker fields on both model and data sides
     self._alignMarkers = all((
         markerNodeGroup, markerLocation, markerCoordinates, markerName,
         markerDataGroup, markerDataCoordinates, markerDataName))
     self._rotation = [0.0] * 3
     self._scale = 1.0
     self._translation = [0.0] * 3
示例#3
0
def decodeJSONFitterSteps(fitter: Fitter, dct):
    """
    JSON object hook that turns FitterStep encodings back into FitterStep
    objects. It is intended as the decoder argument to
    Fitter.decodeSettingsJSON(). When used as the json.loads object_hook,
    wrap it in a lambda to pass the fitter:
    lambda dct: decodeJSONFitterSteps(fitter, dct)
    A FitterStepConfig decoded while the fitter has exactly one step is
    applied to the fitter's initial config step. Otherwise, a new step of the
    matching type is constructed and added to the fitter.
    :param fitter: Owning Fitter object for new FitterSteps.
    :param dct: Dict decoded from one JSON object.
    :return: The decoded FitterStep if dct contains the JSON type id of
    FitterStepAlign, FitterStepConfig or FitterStepFit, otherwise dct unchanged.
    """
    matchingType = next(
        (stepType for stepType in (FitterStepAlign, FitterStepConfig, FitterStepFit)
         if stepType.getJsonTypeId() in dct),
        None)
    if matchingType is None:
        # not a fitter step encoding: leave the plain dict for json.loads
        return dct
    # NOTE(review): if several FitterStepConfig encodings arrive while the
    # fitter still holds a single step, each one appears to be decoded onto
    # the same initial config; confirm against Fitter.decodeSettingsJSON.
    if matchingType is FitterStepConfig and len(fitter.getFitterSteps()) == 1:
        fitterStep = fitter.getInitialFitterStepConfig()
    else:
        fitterStep = matchingType()
        fitter.addFitterStep(fitterStep)
    fitterStep.decodeSettingsJSONDict(dct)
    return fitterStep
示例#4
0
 def test_projection_error(self):
     """
     Check Fitter's data projection RMS and maximum errors for the square
     model against square_error_data.exf. The expected RMS error is
     sqrt(0.12) and the expected maximum error is 0.5.
     """
     resources = os.path.join(here, "resources")
     fitter = Fitter(os.path.join(resources, "square.exf"),
                     os.path.join(resources, "square_error_data.exf"))
     fitter.setDiagnosticLevel(1)
     fitter.load()
     rmsError, maxError = fitter.getDataRMSAndMaximumProjectionError()
     tolerance = 1.0E-10
     self.assertAlmostEqual(rmsError, 0.34641016151377546,
                            delta=tolerance)  # sqrt(0.12)
     self.assertAlmostEqual(maxError, 0.5, delta=tolerance)
示例#5
0
    def test_alignFixedRandomData(self):
        """
        Apply a user-specified scale, translation and rotation with
        FitterStepAlign, with marker alignment switched off. Check that the
        mean coordinates of groups bottom, sides and top move by exactly that
        transformation.
        """
        resources = os.path.join(here, "resources")
        fitter = Fitter(os.path.join(resources, "cube_to_sphere.exf"),
                        os.path.join(resources, "cube_to_sphere_data_random.exf"))
        fitter.setDiagnosticLevel(1)
        fitter.load()

        self.assertEqual(fitter.getModelCoordinatesField().getName(), "coordinates")
        self.assertEqual(fitter.getDataCoordinatesField().getName(), "data_coordinates")
        self.assertEqual(fitter.getMarkerGroup().getName(), "marker")

        groupNames = ("bottom", "sides", "top")
        centresBefore = [
            fitter.evaluateNodeGroupMeanCoordinates(groupName, "coordinates", isData=False)
            for groupName in groupNames]
        expectedCentresBefore = ([0.5, 0.5, 0.0], [0.5, 0.5, 0.5], [0.5, 0.5, 1.0])
        for centre, expectedCentre in zip(centresBefore, expectedCentresBefore):
            assertAlmostEqualList(self, centre, expectedCentre, delta=1.0E-7)

        align = FitterStepAlign(fitter)
        align.setScale(1.1)
        align.setTranslation([0.1, -0.2, 0.3])
        align.setRotation([math.pi / 4.0, math.pi / 8.0, math.pi / 2.0])
        # markers are available, but this test applies the transformation above
        self.assertTrue(align.isAlignMarkers())
        align.setAlignMarkers(False)
        align.run()

        rotation = align.getRotation()
        scale = align.getScale()
        translation = align.getTranslation()
        transformationMatrix = [entry * scale for entry in getRotationMatrix(rotation)]
        expectedCentresAfter = transformCoordinatesList(
            centresBefore, transformationMatrix, translation)
        centresAfter = [
            fitter.evaluateNodeGroupMeanCoordinates(groupName, "coordinates", isData=False)
            for groupName in groupNames]
        for centre, expectedCentre in zip(centresAfter, expectedCentresAfter):
            assertAlmostEqualList(self, centre, expectedCentre, delta=1.0E-7)
示例#6
0
    def test_alignMarkersFitRegularData(self):
        """
        Test automatic alignment of model and data using fiducial markers.
        It then fits with marker and curvature penalty weights and checks
        surface area and volume after each step. Finally it checks a JSON
        encode -> decode -> encode round trip of the fitter settings.
        """
        zinc_model_file = os.path.join(here, "resources", "cube_to_sphere.exf")
        zinc_data_file = os.path.join(here, "resources", "cube_to_sphere_data_regular.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        fitter.setDiagnosticLevel(1)
        fitter.load()
        coordinates = fitter.getModelCoordinatesField()
        self.assertEqual(coordinates.getName(), "coordinates")
        self.assertEqual(fitter.getDataCoordinatesField().getName(), "data_coordinates")
        self.assertEqual(fitter.getMarkerGroup().getName(), "marker")
        #fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry1.exf"))
        fieldmodule = fitter.getFieldmodule()
        surfaceAreaField = createFieldMeshIntegral(coordinates, fitter.getMesh(2), number_of_points=4)
        volumeField = createFieldMeshIntegral(coordinates, fitter.getMesh(3), number_of_points=3)
        fieldcache = fieldmodule.createFieldcache()
        # before alignment: expected surface area 6 and volume 1
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 6.0, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 1.0, delta=1.0E-7)

        align = FitterStepAlign(fitter)
        self.assertTrue(align.isAlignMarkers())
        align.setAlignMarkers(True)
        align.run()
        #fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry2.exf"))
        rotation = align.getRotation()
        scale = align.getScale()
        translation = align.getTranslation()
        assertAlmostEqualList(self, rotation, [ -0.25*math.pi, 0.0, 0.0 ], delta=1.0E-4)
        self.assertAlmostEqual(scale, 0.8047378476539072, places=5)
        assertAlmostEqualList(self, translation, [ -0.5690355950594247, 1.1068454682130484e-05, -0.4023689233125251 ], delta=1.0E-6)
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 3.885618020657802, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 0.5211506471189844, delta=1.0E-6)

        fit1 = FitterStepFit(fitter)
        fit1.setMarkerWeight(1.0)
        fit1.setCurvaturePenaltyWeight(0.1)
        fit1.setNumberOfIterations(3)
        fit1.setUpdateReferenceState(True)
        fit1.run()
        #fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry3.exf"))

        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 3.1892231780263853, delta=1.0E-4)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 0.5276229458448985, delta=1.0E-4)

        # test json serialisation: decode into a second fitter and re-encode
        # THAT fitter, so the comparison covers a real round trip rather than
        # encoding the original fitter twice
        s = fitter.encodeSettingsJSON()
        fitter2 = Fitter(zinc_model_file, zinc_data_file)
        fitter2.decodeSettingsJSON(s, decodeJSONFitterSteps)
        fitterSteps = fitter2.getFitterSteps()
        self.assertEqual(2, len(fitterSteps))
        self.assertIsInstance(fitterSteps[0], FitterStepAlign)
        self.assertIsInstance(fitterSteps[1], FitterStepFit)
        #fitter2.load()
        #for fitterStep in fitterSteps:
        #    fitterStep.run()
        s2 = fitter2.encodeSettingsJSON()
        self.assertEqual(s, s2)
示例#7
0
    def test_preAlignment(self):
        """
        Test the prealignment step. Models at different translation, scale and
        rotation should all return close to the same aligned model.
        For each transformation the model is reloaded and its coordinates are
        transformed. Marker alignment is then run, and nodes 1-8 are compared
        with fixed expected coordinates.
        """
        zinc_model_file = os.path.join(here, "resources", "cube_to_sphere.exf")
        zinc_data_file = os.path.join(here, "resources",
                                      "cube_to_sphere_data_random.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        self.assertEqual(1, len(fitter.getFitterSteps()))
        fitter.setDiagnosticLevel(1)

        # Each entry is [rotation, scale, translation]
        transformationList = [
            [[0.0, 0.0, 0.0], 1.0, [0.0, 0.0, 0.0]],
            [[math.pi * 20 / 180, 0.0, 0.0], 1.0, [0.0, 0.0, 0.0]],
            [[math.pi * 135 / 180, 0.0, 0.0], 1.0, [0.0, 0.0, 0.0]],
            [[math.pi * 250 / 180, math.pi * -45 / 180, 0.0], 1.0,
             [0.0, 0.0, 0.0]],
            [[math.pi * 45 / 180, math.pi * 45 / 180, math.pi * 45 / 180], 1.0,
             [0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0], 0.05, [0.0, 0.0, 0.0]],
            [[math.pi * 70 / 180, math.pi * 10 / 180, math.pi * -300 / 180],
             0.2, [0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0], 1.0, [15.0, 15.0, 15.0]],
            [[0.0, 0.0, 0.0], 20.0, [50.0, 0.0, 10.0]],
            [[math.pi * 90 / 180, math.pi * 200 / 180, math.pi * 5 / 180], 1.0,
             [-10.0, -20.0, 100.0]],
            [[math.pi * -45 / 180, math.pi * 120 / 180, math.pi * 10 / 180],
             500.0, [100.0, 100.0, 100.0]],
        ]

        # expected coordinates of nodes 1..8 (in order) after alignment
        expectedAlignedNodes = [
            [-0.5690355951820659, 1.1070979208244695e-05, -0.40236892417087866],
            [-1.1077595833408616e-05, -0.5690355904946871, -0.4023689227447479],
            [1.1066291829453512e-05, 0.5690355885654408, -0.4023689255966489],
            [0.569035583878062, -1.1072908454692232e-05, -0.4023689241705181],
            [-0.5690355951822816, 1.1072995806778281e-05, 0.4023689241678401],
            [-1.107759604912495e-05, -0.5690355884780887, 0.40236892559397086],
            [1.10662916138482e-05, 0.5690355905820392, 0.4023689227420698],
            [0.5690355838778464, -1.1070891856158648e-05, 0.4023689241682007],
        ]

        align = FitterStepAlign()
        fitter.addFitterStep(align)
        self.assertTrue(align.setAlignMarkers(True))
        self.assertTrue(align.isAlignMarkers())

        for rotation, scale, translation in transformationList:
            # reload so every transformation starts from the original model
            fitter.load()

            fieldmodule = fitter.getFieldmodule()
            fieldcache = fieldmodule.createFieldcache()
            modelCoordinates = fitter.getModelCoordinatesField()

            # overwrite the model coordinates with their transformed values
            modelCoordinatesTransformed = createFieldsTransformations(
                modelCoordinates, rotation, scale, translation)[0]
            fieldassignment = modelCoordinates.createFieldassignment(
                modelCoordinatesTransformed)
            fieldassignment.assign()

            align.run()
            nodeset = fieldmodule.findNodesetByFieldDomainType(
                Field.DOMAIN_TYPE_NODES)

            for nodeIdentifier, expectedCoordinates in enumerate(expectedAlignedNodes, start=1):
                node = nodeset.findNodeByIdentifier(nodeIdentifier)
                fieldcache.setNode(node)
                result, x = modelCoordinates.getNodeParameters(
                    fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, 3)
                # fail clearly if the parameters could not be read
                self.assertEqual(RESULT_OK, result)
                assertAlmostEqualList(self,
                                      x,
                                      expectedCoordinates,
                                      delta=1.0E-3)
示例#8
0
    def test_groupSettings(self):
        """
        Test per-group data proportion settings on FitterStepConfig steps. It covers:
        - setting, clearing and resetting values, and their effect on the
          active data nodeset;
        - inheritance and override by later config steps;
        - JSON serialisation of the settings.
        """
        zinc_model_file = os.path.join(here, "resources", "cube_to_sphere.exf")
        zinc_data_file = os.path.join(here, "resources",
                                      "cube_to_sphere_data_regular.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        fitter.setDiagnosticLevel(1)

        def assertActiveDataGroupSizes(expectedTotalSize, expectedGroupSizes):
            """
            Assert the size of the fitter's active data nodeset group and the
            number of its nodes in each named group.
            :param expectedTotalSize: Expected size of the active data nodeset group.
            :param expectedGroupSizes: Sequence of (groupName, expectedSize) pairs.
            """
            activeNodeset = fitter.getActiveDataNodesetGroup()
            self.assertEqual(expectedTotalSize, activeNodeset.getSize())
            for groupName, expectedSize in expectedGroupSizes:
                self.assertEqual(
                    expectedSize,
                    getNodesetConditionalSize(
                        activeNodeset,
                        fitter.getFieldmodule().findFieldByName(groupName)))

        # NOTE(review): getGroupDataProportion appears to return
        # (proportion, setLocally, inherited) tuples; confirm against FitterStepConfig
        config1 = fitter.getInitialFitterStepConfig()
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(0, len(groupNames))
        self.assertEqual((1.0, False, False),
                         config1.getGroupDataProportion("sides"))
        config1.setGroupDataProportion("sides", -0.1)
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(1, len(groupNames))
        # a negative proportion is stored as 0.0
        self.assertEqual((0.0, True, False),
                         config1.getGroupDataProportion("sides"))
        config1.setGroupDataProportion("sides", 0.25)
        # the non-numeric value "A" is expected to leave 0.25 in place (asserted below)
        config1.setGroupDataProportion("sides", "A")
        config1.setGroupDataProportion("top", 0.4)
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(2, len(groupNames))
        self.assertTrue("sides" in groupNames)
        self.assertTrue("top" in groupNames)
        self.assertEqual((0.25, True, False),
                         config1.getGroupDataProportion("sides"))
        self.assertEqual((0.4, True, False),
                         config1.getGroupDataProportion("top"))
        self.assertEqual((1.0, False, False),
                         config1.getGroupDataProportion("bottom"))
        config1.setGroupDataProportion("bottom", 0.1)
        self.assertEqual((0.1, True, False),
                         config1.getGroupDataProportion("bottom"))
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(3, len(groupNames))
        # setting a non-inheriting value to None clears it:
        config1.setGroupDataProportion("bottom", None)
        self.assertEqual((1.0, False, False),
                         config1.getGroupDataProportion("bottom"))
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(2, len(groupNames))
        config1.setGroupDataProportion("bottom", 0.12)
        self.assertEqual((0.12, True, False),
                         config1.getGroupDataProportion("bottom"))
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(3, len(groupNames))
        config1.clearGroupDataProportion("bottom")
        self.assertEqual((1.0, False, False),
                         config1.getGroupDataProportion("bottom"))
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(2, len(groupNames))
        self.assertTrue("sides" in groupNames)
        self.assertTrue("top" in groupNames)
        fitter.load()
        assertActiveDataGroupSizes(141, (("bottom", 72), ("sides", 36), ("top", 29), ("marker", 4)))
        # test override and inherit
        config2 = FitterStepConfig()
        fitter.addFitterStep(config2)
        config2.setGroupDataProportion("top", None)
        groupNames = config2.getGroupSettingsNames()
        self.assertEqual(1, len(groupNames))
        self.assertTrue("top" in groupNames)
        self.assertEqual((0.25, False, True),
                         config2.getGroupDataProportion("sides"))
        # test that the reset proportion has setLocally None
        self.assertEqual((1.0, None, True),
                         config2.getGroupDataProportion("top"))
        config2.run()
        assertActiveDataGroupSizes(184, (("bottom", 72), ("sides", 36), ("top", 72), ("marker", 4)))
        # test inherit through 2 previous configs and cancel/None in config2
        config3 = FitterStepConfig()
        fitter.addFitterStep(config3)
        groupNames = config3.getGroupSettingsNames()
        self.assertEqual(0, len(groupNames))
        self.assertEqual((0.25, False, True),
                         config3.getGroupDataProportion("sides"))
        self.assertEqual((1.0, False, False),
                         config3.getGroupDataProportion("top"))
        config3.run()
        assertActiveDataGroupSizes(184, (("bottom", 72), ("sides", 36), ("top", 72), ("marker", 4)))
        del config1
        del config2
        del config3

        # test json serialisation
        s = fitter.encodeSettingsJSON()
        fitter2 = Fitter(zinc_model_file, zinc_data_file)
        fitter2.decodeSettingsJSON(s, decodeJSONFitterSteps)
        fitterSteps = fitter2.getFitterSteps()
        self.assertEqual(3, len(fitterSteps))
        config1, config2, config3 = fitterSteps
        self.assertIsInstance(config1, FitterStepConfig)
        self.assertIsInstance(config2, FitterStepConfig)
        self.assertIsInstance(config3, FitterStepConfig)
        groupNames = config1.getGroupSettingsNames()
        self.assertEqual(2, len(groupNames))
        self.assertTrue("sides" in groupNames)
        self.assertTrue("top" in groupNames)
        self.assertEqual((0.25, True, False),
                         config1.getGroupDataProportion("sides"))
        self.assertEqual((0.4, True, False),
                         config1.getGroupDataProportion("top"))
        self.assertEqual((1.0, False, False),
                         config1.getGroupDataProportion("bottom"))
        groupNames = config2.getGroupSettingsNames()
        self.assertEqual(1, len(groupNames))
        self.assertTrue("top" in groupNames)
        self.assertEqual((0.25, False, True),
                         config2.getGroupDataProportion("sides"))
        self.assertEqual((1.0, None, True),
                         config2.getGroupDataProportion("top"))
示例#9
0
    def test_fitRegularDataGroupWeight(self):
        """
        Test fitting with per-group data weights. It aligns model and data
        using fiducial markers, then fits with data weights set for groups
        "bottom" and "sides". It checks the data_weight field on each group's
        projected data points, and the resulting surface area and volume.
        """
        zinc_model_file = os.path.join(here, "resources", "cube_to_sphere.exf")
        zinc_data_file = os.path.join(here, "resources",
                                      "cube_to_sphere_data_regular.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        # there is always an initial FitterStepConfig
        self.assertEqual(1, len(fitter.getFitterSteps()))
        fitter.setDiagnosticLevel(1)
        fitter.load()

        coordinates = fitter.getModelCoordinatesField()
        self.assertEqual(coordinates.getName(), "coordinates")
        fieldmodule = fitter.getFieldmodule()
        surfaceAreaField = createFieldMeshIntegral(coordinates,
                                                   fitter.getMesh(2),
                                                   number_of_points=4)
        volumeField = createFieldMeshIntegral(coordinates,
                                              fitter.getMesh(3),
                                              number_of_points=3)
        fieldcache = fieldmodule.createFieldcache()

        align = FitterStepAlign()
        fitter.addFitterStep(align)
        self.assertEqual(2, len(fitter.getFitterSteps()))
        self.assertTrue(align.setAlignMarkers(True))
        align.run()

        fit1 = FitterStepFit()
        fitter.addFitterStep(fit1)
        self.assertEqual(3, len(fitter.getFitterSteps()))
        fit1.setGroupDataWeight("bottom", 0.5)
        fit1.setGroupDataWeight("sides", 0.1)
        groupNames = fit1.getGroupSettingsNames()
        self.assertEqual(2, len(groupNames))
        self.assertEqual((0.5, True, False), fit1.getGroupDataWeight("bottom"))
        self.assertEqual((0.1, True, False), fit1.getGroupDataWeight("sides"))
        fit1.setCurvaturePenaltyWeight(0.01)
        fit1.setNumberOfIterations(3)
        fit1.setUpdateReferenceState(True)
        fit1.run()
        dataWeightField = fieldmodule.findFieldByName(
            "data_weight").castFiniteElement()
        self.assertTrue(dataWeightField.isValid())
        # expected (number of projected data points, data weight) per group;
        # "top" has no group weight set on fit1 and is expected to get 1.0
        groupData = {
            "bottom": (72, 0.5),
            "sides": (144, 0.1),
            "top": (72, 1.0)
        }
        for groupName, (expectedSize, expectedWeight) in groupData.items():
            group = fieldmodule.findFieldByName(groupName).castGroup()
            dataGroup = fitter.getGroupDataProjectionNodesetGroup(group)
            size = dataGroup.getSize()
            self.assertEqual(size, expectedSize)
            dataIterator = dataGroup.createNodeiterator()
            node = dataIterator.next()
            while node.isValid():
                fieldcache.setNode(node)
                result, weight = dataWeightField.evaluateReal(fieldcache, 1)
                self.assertEqual(result, RESULT_OK)
                self.assertAlmostEqual(weight, expectedWeight, delta=1.0E-10)
                node = dataIterator.next()

        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 3.2298953613027956, delta=1.0E-4)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 0.5156233237703589, delta=1.0E-4)
示例#10
0
    def test_alignGroupsFitEllipsoidRegularData(self):
        """
        Test automatic alignment of model and data using groups & fit two cubes model to ellipsoid data.
        Also tests:
        - per-group strain and curvature penalties, including inheritance and
          override between fit steps;
        - a custom fibre orientation field;
        - a JSON encode -> decode -> encode round trip of the fitter settings.
        """
        zinc_model_file = os.path.join(here, "resources",
                                       "two_cubes_hermite_nocross_groups.exf")
        zinc_data_file = os.path.join(here, "resources",
                                      "two_cubes_ellipsoid_data_regular.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        fitter.setDiagnosticLevel(1)
        fitter.load()

        coordinates = fitter.getModelCoordinatesField()
        self.assertEqual(coordinates.getName(), "coordinates")
        self.assertEqual(fitter.getDataCoordinatesField().getName(),
                         "data_coordinates")
        fieldmodule = fitter.getFieldmodule()
        # surface area includes interior surface in this case
        surfaceAreaField = createFieldMeshIntegral(coordinates,
                                                   fitter.getMesh(2),
                                                   number_of_points=4)
        volumeField = createFieldMeshIntegral(coordinates,
                                              fitter.getMesh(3),
                                              number_of_points=3)
        fieldcache = fieldmodule.createFieldcache()
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 11.0, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 2.0, delta=1.0E-6)

        align = FitterStepAlign()
        fitter.addFitterStep(align)
        self.assertEqual(2, len(fitter.getFitterSteps()))
        self.assertTrue(align.setAlignGroups(True))
        self.assertTrue(align.isAlignGroups())
        align.run()
        rotation = align.getRotation()
        scale = align.getScale()
        translation = align.getTranslation()
        assertAlmostEqualList(self, rotation, [0.0, 0.0, 0.0], delta=1.0E-5)
        self.assertAlmostEqual(scale, 1.040599599095245, places=5)
        assertAlmostEqualList(
            self,
            translation,
            [-1.0405995643008867, -0.5202997843515198, -0.5202997827678563],
            delta=1.0E-6)
        # pure scale/translation: area scales by scale**2, volume by scale**3
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 11.0 * scale * scale, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume,
                               2.0 * scale * scale * scale,
                               delta=1.0E-6)

        fit1 = FitterStepFit()
        fitter.addFitterStep(fit1)
        self.assertEqual(3, len(fitter.getFitterSteps()))
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            None)
        assertAlmostEqualList(self, strainPenalty, [0.0], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertFalse(inheritable)
        curvaturePenalty, locallySet, inheritable = fit1.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.0], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertFalse(inheritable)
        fit1.setGroupStrainPenalty(None, [0.1])
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            None)
        assertAlmostEqualList(self, strainPenalty, [0.1], delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        fit1.setGroupCurvaturePenalty(None, [0.01])
        curvaturePenalty, locallySet, inheritable = fit1.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.01], delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        # test specifying number of components:
        curvaturePenalty, locallySet, inheritable = fit1.getGroupCurvaturePenalty(
            None, count=5)
        assertAlmostEqualList(self,
                              curvaturePenalty, [0.01, 0.01, 0.01, 0.01, 0.01],
                              delta=1.0E-7)
        # group "two" strain penalty will initially fall back to default value
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self, strainPenalty, [0.1], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertFalse(inheritable)
        fit1.setGroupStrainPenalty(
            "two", [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0])
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        fit1.setNumberOfIterations(1)
        fit1.run()
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 11.097773862300704, delta=1.0E-4)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 2.323461787566051, delta=1.0E-4)

        # test fibre orientation field
        fitter.load()
        fieldmodule = fitter.getFieldmodule()
        self.assertEqual(None, fitter.getFibreField())
        fibreField = fieldmodule.createFieldConstant(
            [0.0, 0.0, 0.25 * math.pi])
        fibreField.setName("custom fibres")
        fibreField.setManaged(True)
        fitter.setFibreField(fibreField)
        self.assertEqual(fibreField, fitter.getFibreField())
        coordinates = fitter.getModelCoordinatesField()
        align.run()
        fit1.run()
        # get end node coordinate to prove twist
        nodeExpectedCoordinates = {
            3: [0.8487623099139301, -0.5012613734076182, -0.5306482017126274],
            6: [0.8487623092159226, 0.2617063557585618, -0.5464896371028911],
            9: [0.8487623062422882, -0.2617063537282271, 0.5464896401724635],
            12: [0.8487623124370356, 0.5012613792923117, 0.5306482045212996]
        }
        fieldcache = fieldmodule.createFieldcache()
        nodes = fieldmodule.findNodesetByFieldDomainType(
            Field.DOMAIN_TYPE_NODES)
        for nodeIdentifier, expectedCoordinates in nodeExpectedCoordinates.items():
            node = nodes.findNodeByIdentifier(nodeIdentifier)
            self.assertEqual(RESULT_OK, fieldcache.setNode(node))
            result, x = coordinates.getNodeParameters(fieldcache, -1,
                                                      Node.VALUE_LABEL_VALUE,
                                                      1, 3)
            self.assertEqual(RESULT_OK, result)
            assertAlmostEqualList(self, x, expectedCoordinates, delta=1.0E-6)

        # test inheritance and override of penalties
        fit2 = FitterStepFit()
        fitter.addFitterStep(fit2)
        self.assertEqual(4, len(fitter.getFitterSteps()))
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            None)
        assertAlmostEqualList(self, strainPenalty, [0.1], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertTrue(inheritable)
        curvaturePenalty, locallySet, inheritable = fit2.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.01], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertTrue(inheritable)
        fit2.setGroupCurvaturePenalty(None, None)
        curvaturePenalty, locallySet, inheritable = fit2.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.0], delta=1.0E-7)
        self.assertTrue(locallySet is None)
        self.assertTrue(inheritable)
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0],
                              delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertTrue(inheritable)
        fit2.setGroupStrainPenalty("two", [0.5, 0.9, 0.2])
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            "two", count=9)
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.5, 0.9, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertTrue(inheritable)

        # test json serialisation: decode into a second fitter and re-encode
        # THAT fitter, so the final comparison covers a real round trip rather
        # than encoding the original fitter twice
        s = fitter.encodeSettingsJSON()
        fitter2 = Fitter(zinc_model_file, zinc_data_file)
        fitter2.decodeSettingsJSON(s, decodeJSONFitterSteps)
        fitterSteps = fitter2.getFitterSteps()
        self.assertEqual(4, len(fitterSteps))
        self.assertIsInstance(fitterSteps[0], FitterStepConfig)
        self.assertIsInstance(fitterSteps[1], FitterStepAlign)
        self.assertIsInstance(fitterSteps[2], FitterStepFit)
        self.assertIsInstance(fitterSteps[3], FitterStepFit)
        fit1 = fitterSteps[2]
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        fit2 = fitterSteps[3]
        curvaturePenalty, locallySet, inheritable = fit2.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.0], delta=1.0E-7)
        self.assertTrue(locallySet is None)
        self.assertTrue(inheritable)
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty, [0.5, 0.9, 0.2],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertTrue(inheritable)
        s2 = fitter2.encodeSettingsJSON()
        self.assertEqual(s, s2)
示例#11
0
class GeometryFitterModel(object):
    """
    Geometric fit model adding visualisations to github.com/ABI-Software/scaffoldfitter
    """
    def __init__(self, inputZincModelFile, inputZincDataFile, location,
                 identifier):
        """
        Create the fitter, prepare graphics resources, apply any saved
        settings, then load the fitter.

        :param inputZincModelFile: Zinc model file passed to Fitter.
        :param inputZincDataFile: Zinc data file passed to Fitter.
        :param location: Path to folder for mapclient step name.
        :param identifier: Step identifier; joined onto location to give the
            stem used for settings and output file names.
        """
        self._fitter = Fitter(inputZincModelFile, inputZincDataFile)
        self._location = os.path.join(location, identifier)
        self._identifier = identifier
        self._initGraphicsModules()
        # Default display settings; any saved display settings file read by
        # _loadSettings() overrides these entries.
        self._settings = dict(
            displayAxes=True,
            displayMarkerDataPoints=True,
            displayMarkerDataNames=False,
            displayMarkerDataProjections=True,
            displayMarkerPoints=True,
            displayMarkerNames=False,
            displayDataPoints=True,
            displayDataProjections=True,
            displayDataProjectionPoints=True,
            displayNodePoints=False,
            displayNodeNumbers=False,
            displayNodeDerivatives=False,
            displayNodeDerivativeLabels=nodeDerivativeLabels[:3],
            displayElementNumbers=False,
            displayElementAxes=False,
            displayLines=True,
            displayLinesExterior=False,
            displaySurfaces=True,
            displaySurfacesExterior=True,
            displaySurfacesTranslucent=True,
            displaySurfacesWireframe=False,
        )
        self._loadSettings()
        self._fitter.load()

    def _initGraphicsModules(self):
        """
        Prepare graphics resources on the fitter's context.

        This defines the standard materials plus the managed "solid_blue" and
        "trans_blue" materials, and the standard glyphs. It also sets the
        default tessellation refinement factor to 12.
        """
        context = self._fitter.getContext()
        self._materialmodule = context.getMaterialmodule()
        materials = self._materialmodule

        def defineBlue(materialName, alpha=None):
            # Both blue materials share colours; only the translucent one sets alpha.
            material = materials.createMaterial()
            material.setName(materialName)
            material.setManaged(True)
            colours = (
                (Material.ATTRIBUTE_AMBIENT, [0.0, 0.2, 0.6]),
                (Material.ATTRIBUTE_DIFFUSE, [0.0, 0.7, 1.0]),
                (Material.ATTRIBUTE_EMISSION, [0.0, 0.0, 0.0]),
                (Material.ATTRIBUTE_SPECULAR, [0.1, 0.1, 0.1]),
            )
            for attribute, rgb in colours:
                material.setAttributeReal3(attribute, rgb)
            if alpha is not None:
                material.setAttributeReal(Material.ATTRIBUTE_ALPHA, alpha)
            material.setAttributeReal(Material.ATTRIBUTE_SHININESS, 0.2)

        with ChangeManager(materials):
            materials.defineStandardMaterials()
            defineBlue("solid_blue")
            defineBlue("trans_blue", alpha=0.3)
        context.getGlyphmodule().defineStandardGlyphs()
        tessellationmodule = context.getTessellationmodule()
        tessellationmodule.getDefaultTessellation().setRefinementFactors([12])

    def _getFitSettingsFileName(self):
        return self._location + "-settings.json"

    def _getDisplaySettingsFileName(self):
        return self._location + "-display-settings.json"

    def _loadSettings(self):
        # try:
        fitSettingsFileName = self._getFitSettingsFileName()
        if os.path.isfile(fitSettingsFileName):
            with open(fitSettingsFileName, "r") as f:
                self._fitter.decodeSettingsJSON(f.read(),
                                                decodeJSONFitterSteps)
        # except:
        #    print('_loadSettings FitSettings EXCEPTION')
        #    raise()
        # try:
        displaySettingsFileName = self._getDisplaySettingsFileName()
        if os.path.isfile(displaySettingsFileName):
            with open(displaySettingsFileName, "r") as f:
                savedSettings = json.loads(f.read())
                self._settings.update(savedSettings)
        # except:
        #    print('_loadSettings DisplaySettings EXCEPTION')
        #    pass

    def _saveSettings(self):
        with open(self._getFitSettingsFileName(), "w") as f:
            f.write(self._fitter.encodeSettingsJSON())
        with open(self._getDisplaySettingsFileName(), "w") as f:
            f.write(json.dumps(self._settings, sort_keys=False, indent=4))

    def getOutputModelFileNameStem(self):
        return self._location

    def getOutputModelFileName(self):
        return self._location + ".exf"

    def done(self):
        self._saveSettings()
        self._fitter.run(endStep=None,
                         modelFileNameStem=self.getOutputModelFileNameStem())
        self._fitter.writeModel(self.getOutputModelFileName())

    def getIdentifier(self):
        return self._identifier

    def getContext(self):
        return self._fitter.getContext()

    def getFitter(self):
        return self._fitter

    def getRegion(self):
        return self._fitter.getRegion()

    def getFieldmodule(self):
        return self._fitter.getFieldmodule()

    def getScene(self):
        return self._fitter.getRegion().getScene()

    def _getVisibility(self, graphicsName):
        return self._settings[graphicsName]

    def _setVisibility(self, graphicsName, show):
        self._settings[graphicsName] = show
        graphics = self.getScene().findGraphicsByName(graphicsName)
        graphics.setVisibilityFlag(show)

    def _setMultipleGraphicsVisibility(self, graphicsName, show):
        '''
        Ensure visibility of all graphics with graphicsName is set to boolean show.
        '''
        self._settings[graphicsName] = show
        scene = self.getScene()
        graphics = scene.findGraphicsByName(graphicsName)
        while graphics.isValid():
            graphics.setVisibilityFlag(show)
            while True:
                graphics = scene.getNextGraphics(graphics)
                if (not graphics.isValid()) or (graphics.getName()
                                                == graphicsName):
                    break

    def isDisplayAxes(self):
        return self._getVisibility("displayAxes")

    def setDisplayAxes(self, show):
        self._setVisibility("displayAxes", show)

    def isDisplayElementNumbers(self):
        return self._getVisibility("displayElementNumbers")

    def setDisplayElementNumbers(self, show):
        self._setVisibility("displayElementNumbers", show)

    def isDisplayLines(self):
        return self._getVisibility("displayLines")

    def setDisplayLines(self, show):
        self._setVisibility("displayLines", show)

    def isDisplayLinesExterior(self):
        return self._settings["displayLinesExterior"]

    def setDisplayLinesExterior(self, isExterior):
        self._settings["displayLinesExterior"] = isExterior
        lines = self.getScene().findGraphicsByName("displayLines")
        lines.setExterior(self.isDisplayLinesExterior())

    def isDisplayNodeDerivatives(self):
        return self._getVisibility("displayNodeDerivatives")

    def setDisplayNodeDerivatives(self, show):
        """
        Store the master node-derivative flag and update every per-label
        derivative graphics.

        Each graphics is visible only if show is on AND its label is selected.
        """
        self._settings["displayNodeDerivatives"] = show
        scene = self.getScene()
        for label in nodeDerivativeLabels:
            derivativeGraphics = scene.findGraphicsByName(
                "displayNodeDerivatives" + label)
            visible = show and self.isDisplayNodeDerivativeLabels(label)
            derivativeGraphics.setVisibilityFlag(visible)

    def isDisplayNodeDerivativeLabels(self, nodeDerivativeLabel):
        """
        :param nodeDerivativeLabel: Label from nodeDerivativeLabels ("D1", "D2" ...)
        """
        return nodeDerivativeLabel in self._settings[
            "displayNodeDerivativeLabels"]

    def setDisplayNodeDerivativeLabels(self, nodeDerivativeLabel, show):
        """
        :param nodeDerivativeLabel: Label from nodeDerivativeLabels ("D1", "D2" ...)
        """
        shown = nodeDerivativeLabel in self._settings[
            "displayNodeDerivativeLabels"]
        if show:
            if not shown:
                # keep in same order as nodeDerivativeLabels
                nodeDerivativeLabels = []
                for label in nodeDerivativeLabels:
                    if (label == nodeDerivativeLabel
                        ) or self.isDisplayNodeDerivativeLabels(label):
                        nodeDerivativeLabels.append(label)
                self._settings[
                    "displayNodeDerivativeLabels"] = nodeDerivativeLabels
        else:
            if shown:
                self._settings["displayNodeDerivativeLabels"].remove(
                    nodeDerivativeLabel)
        graphics = self.getScene().findGraphicsByName(
            "displayNodeDerivatives" + nodeDerivativeLabel)
        graphics.setVisibilityFlag(show and self.isDisplayNodeDerivatives())

    def isDisplayMarkerDataPoints(self):
        return self._getVisibility("displayMarkerDataPoints")

    def setDisplayMarkerDataPoints(self, show):
        self._setVisibility("displayMarkerDataPoints", show)

    def isDisplayMarkerDataNames(self):
        return self._getVisibility("displayMarkerDataNames")

    def setDisplayMarkerDataNames(self, show):
        self._setVisibility("displayMarkerDataNames", show)

    def isDisplayMarkerDataProjections(self):
        return self._getVisibility("displayMarkerDataProjections")

    def setDisplayMarkerDataProjections(self, show):
        self._setVisibility("displayMarkerDataProjections", show)

    def isDisplayMarkerPoints(self):
        return self._getVisibility("displayMarkerPoints")

    def setDisplayMarkerPoints(self, show):
        self._setVisibility("displayMarkerPoints", show)

    def isDisplayMarkerNames(self):
        return self._getVisibility("displayMarkerNames")

    def setDisplayMarkerNames(self, show):
        self._setVisibility("displayMarkerNames", show)

    def isDisplayDataPoints(self):
        return self._getVisibility("displayDataPoints")

    def setDisplayDataPoints(self, show):
        self._setVisibility("displayDataPoints", show)

    def isDisplayDataProjections(self):
        return self._getVisibility("displayDataProjections")

    def setDisplayDataProjections(self, show):
        self._setMultipleGraphicsVisibility("displayDataProjections", show)

    def isDisplayDataProjectionPoints(self):
        return self._getVisibility("displayDataProjectionPoints")

    def setDisplayDataProjectionPoints(self, show):
        self._setMultipleGraphicsVisibility("displayDataProjectionPoints",
                                            show)

    def isDisplayNodeNumbers(self):
        return self._getVisibility("displayNodeNumbers")

    def setDisplayNodeNumbers(self, show):
        self._setVisibility("displayNodeNumbers", show)

    def isDisplayNodePoints(self):
        return self._getVisibility("displayNodePoints")

    def setDisplayNodePoints(self, show):
        self._setVisibility("displayNodePoints", show)

    def isDisplaySurfaces(self):
        return self._getVisibility("displaySurfaces")

    def setDisplaySurfaces(self, show):
        self._setVisibility("displaySurfaces", show)

    def isDisplaySurfacesExterior(self):
        return self._settings["displaySurfacesExterior"]

    def setDisplaySurfacesExterior(self, isExterior):
        self._settings["displaySurfacesExterior"] = isExterior
        surfaces = self.getScene().findGraphicsByName("displaySurfaces")
        surfaces.setExterior(self.isDisplaySurfacesExterior() if (
            self._fitter.getHighestDimensionMesh().getDimension() == 3
        ) else False)

    def isDisplaySurfacesTranslucent(self):
        return self._settings["displaySurfacesTranslucent"]

    def setDisplaySurfacesTranslucent(self, isTranslucent):
        self._settings["displaySurfacesTranslucent"] = isTranslucent
        surfaces = self.getScene().findGraphicsByName("displaySurfaces")
        surfacesMaterial = self._materialmodule.findMaterialByName(
            "trans_blue" if isTranslucent else "solid_blue")
        surfaces.setMaterial(surfacesMaterial)

    def isDisplaySurfacesWireframe(self):
        return self._settings["displaySurfacesWireframe"]

    def setDisplaySurfacesWireframe(self, isWireframe):
        """Store the wireframe flag and switch the surface polygon render mode."""
        self._settings["displaySurfacesWireframe"] = isWireframe
        surfaces = self.getScene().findGraphicsByName("displaySurfaces")
        if isWireframe:
            renderMode = Graphics.RENDER_POLYGON_MODE_WIREFRAME
        else:
            renderMode = Graphics.RENDER_POLYGON_MODE_SHADED
        surfaces.setRenderPolygonMode(renderMode)

    def isDisplayElementAxes(self):
        return self._getVisibility("displayElementAxes")

    def setDisplayElementAxes(self, show):
        self._setVisibility("displayElementAxes", show)

    def needPerturbLines(self):
        """
        Return if solid surfaces are drawn with lines, requiring perturb lines to be activated.
        """
        region = self.getRegion()
        if region is None:
            return False
        mesh2d = region.getFieldmodule().findMeshByDimension(2)
        if mesh2d.getSize() == 0:
            return False
        return self.isDisplayLines() and self.isDisplaySurfaces(
        ) and not self.isDisplaySurfacesTranslucent()

    def setSelectHighlightGroup(self, group: FieldGroup):
        """
        Select and highlight objects in the group.

        Any existing scene selection group is cleared (or a new one is created).
        The group's elements of all dimensions are then added, followed by its
        nodes and datapoints. With no group, an existing selection group is
        cleared and the scene selection field is reset.
        :param group: FieldGroup to select, or None to clear selection.
        """
        fieldmodule = self.getFieldmodule()
        with ChangeManager(fieldmodule):
            scene = self.getScene()
            # can't use SUBELEMENT_HANDLING_MODE_FULL as some groups have been tweaked to omit some faces
            handlingMode = FieldGroup.SUBELEMENT_HANDLING_MODE_NONE
            selectionGroup = get_scene_selection_group(
                scene, subelementHandlingMode=handlingMode)
            if not group:
                if selectionGroup:
                    selectionGroup.clear()
                    scene.setSelectionField(Field())
                return
            if selectionGroup:
                selectionGroup.clear()
            else:
                selectionGroup = create_scene_selection_group(
                    scene, subelementHandlingMode=handlingMode)
            group_add_group_elements(selectionGroup,
                                     group,
                                     highest_dimension_only=False)
            for domainType in (Field.DOMAIN_TYPE_NODES,
                               Field.DOMAIN_TYPE_DATAPOINTS):
                nodeset = fieldmodule.findNodesetByFieldDomainType(domainType)
                group_add_group_nodes(selectionGroup, group, nodeset)

    def setSelectHighlightGroupByName(self, groupName):
        """
        Select and highlight objects in the group by name.
        :param groupName: Name of group to select, or None to clear selection.
        """
        fieldmodule = self.getFieldmodule()
        group = None
        if groupName:
            group = fieldmodule.findFieldByName(groupName).castGroup()
            if not group.isValid():
                group = None
        self.setSelectHighlightGroup(group)

    def createGraphics(self):
        fieldmodule = self.getFieldmodule()
        mesh = self._fitter.getHighestDimensionMesh()
        meshDimension = mesh.getDimension()
        modelCoordinates = self._fitter.getModelCoordinatesField()
        componentsCount = modelCoordinates.getNumberOfComponents()

        # prepare fields and calculate axis and glyph scaling
        with ChangeManager(fieldmodule):
            # fields in same order as nodeDerivativeLabels
            nodeDerivativeFields = [
                fieldmodule.createFieldNodeValue(modelCoordinates, derivative,
                                                 1)
                for derivative in [
                    Node.VALUE_LABEL_D_DS1, Node.VALUE_LABEL_D_DS2,
                    Node.VALUE_LABEL_D_DS3, Node.VALUE_LABEL_D2_DS1DS2,
                    Node.VALUE_LABEL_D2_DS1DS3, Node.VALUE_LABEL_D2_DS2DS3,
                    Node.VALUE_LABEL_D3_DS1DS2DS3
                ]
            ]
            elementDerivativesField = fieldmodule.createFieldConcatenate([
                fieldmodule.createFieldDerivative(modelCoordinates, d + 1)
                for d in range(meshDimension)
            ])
            cmiss_number = fieldmodule.findFieldByName("cmiss_number")
            markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields(
            )

            # get sizing for axes
            axesScale = 1.0
            nodes = fieldmodule.findNodesetByFieldDomainType(
                Field.DOMAIN_TYPE_NODES)
            minX, maxX = evaluateFieldNodesetRange(modelCoordinates, nodes)
            if componentsCount == 1:
                maxRange = maxX - minX
            else:
                maxRange = maxX[0] - minX[0]
                for c in range(1, componentsCount):
                    maxRange = max(maxRange, maxX[c] - minX[c])
            if maxRange > 0.0:
                while axesScale * 10.0 < maxRange:
                    axesScale *= 10.0
                while axesScale * 0.1 > maxRange:
                    axesScale *= 0.1

            # fixed width glyph size is based on average element size in all dimensions
            mesh1d = fieldmodule.findMeshByDimension(1)
            meanLineLength = 0.0
            lineCount = mesh1d.getSize()
            if lineCount > 0:
                one = fieldmodule.createFieldConstant(1.0)
                sumLineLength = fieldmodule.createFieldMeshIntegral(
                    one, modelCoordinates, mesh1d)
                cache = fieldmodule.createFieldcache()
                result, totalLineLength = sumLineLength.evaluateReal(cache, 1)
                glyphWidth = 0.1 * totalLineLength / lineCount
                del cache
                del sumLineLength
                del one
            if (lineCount == 0) or (glyphWidth == 0.0):
                # use function of coordinate range if no elements
                if componentsCount == 1:
                    maxScale = maxX - minX
                else:
                    first = True
                    for c in range(componentsCount):
                        scale = maxX[c] - minX[c]
                        if first or (scale > maxScale):
                            maxScale = scale
                            first = False
                if maxScale == 0.0:
                    maxScale = 1.0
                glyphWidth = 0.01 * maxScale
            glyphWidthSmall = 0.25 * glyphWidth

        # make graphics
        scene = self._fitter.getRegion().getScene()
        with ChangeManager(scene):
            scene.removeAllGraphics()

            axes = scene.createGraphicsPoints()
            pointattr = axes.getGraphicspointattributes()
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_AXES_XYZ)
            pointattr.setBaseSize([axesScale, axesScale, axesScale])
            pointattr.setLabelText(1, "  " + str(axesScale))
            axes.setMaterial(self._materialmodule.findMaterialByName("grey50"))
            axes.setName("displayAxes")
            axes.setVisibilityFlag(self.isDisplayAxes())

            # marker points, projections

            markerGroup = self._fitter.getMarkerGroup()
            markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields(
            )
            markerDataLocation, markerDataLocationCoordinates, markerDataDelta = self._fitter.getMarkerDataLocationFields(
            )
            markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields(
            )
            markerDataLocationGroupField = self._fitter.getMarkerDataLocationGroupField(
            )

            markerDataPoints = scene.createGraphicsPoints()
            markerDataPoints.setFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
            if markerDataLocationGroupField:
                markerDataPoints.setSubgroupField(markerDataLocationGroupField)
            elif markerGroup:
                markerDataPoints.setSubgroupField(markerGroup)
            if markerDataCoordinates:
                markerDataPoints.setCoordinateField(markerDataCoordinates)
            pointattr = markerDataPoints.getGraphicspointattributes()
            pointattr.setBaseSize(
                [glyphWidthSmall, glyphWidthSmall, glyphWidthSmall])
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_CROSS)
            markerDataPoints.setMaterial(
                self._materialmodule.findMaterialByName("yellow"))
            markerDataPoints.setName("displayMarkerDataPoints")
            markerDataPoints.setVisibilityFlag(
                self.isDisplayMarkerDataPoints())

            markerDataNames = scene.createGraphicsPoints()
            markerDataNames.setFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
            if markerDataLocationGroupField:
                markerDataNames.setSubgroupField(markerDataLocationGroupField)
            elif markerGroup:
                markerDataNames.setSubgroupField(markerGroup)
            if markerDataCoordinates:
                markerDataNames.setCoordinateField(markerDataCoordinates)
            pointattr = markerDataNames.getGraphicspointattributes()
            pointattr.setBaseSize(
                [glyphWidthSmall, glyphWidthSmall, glyphWidthSmall])
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_NONE)
            if markerDataName:
                pointattr.setLabelField(markerDataName)
            markerDataNames.setMaterial(
                self._materialmodule.findMaterialByName("yellow"))
            markerDataNames.setName("displayMarkerDataNames")
            markerDataNames.setVisibilityFlag(self.isDisplayMarkerDataNames())

            markerDataProjections = scene.createGraphicsPoints()
            markerDataProjections.setFieldDomainType(
                Field.DOMAIN_TYPE_DATAPOINTS)
            if markerDataLocationGroupField:
                markerDataProjections.setSubgroupField(
                    markerDataLocationGroupField)
            elif markerGroup:
                markerDataProjections.setSubgroupField(markerGroup)
            if markerDataCoordinates:
                markerDataProjections.setCoordinateField(markerDataCoordinates)
            pointAttr = markerDataProjections.getGraphicspointattributes()
            pointAttr.setGlyphShapeType(Glyph.SHAPE_TYPE_LINE)
            pointAttr.setBaseSize([0.0, 1.0, 1.0])
            pointAttr.setScaleFactors([1.0, 0.0, 0.0])
            if markerDataDelta:
                pointAttr.setOrientationScaleField(markerDataDelta)
            markerDataProjections.setMaterial(
                self._materialmodule.findMaterialByName("magenta"))
            markerDataProjections.setName("displayMarkerDataProjections")
            markerDataProjections.setVisibilityFlag(
                self.isDisplayMarkerDataProjections())

            markerPoints = scene.createGraphicsPoints()
            markerPoints.setFieldDomainType(Field.DOMAIN_TYPE_NODES)
            if markerGroup:
                markerPoints.setSubgroupField(markerGroup)
            if markerCoordinates:
                markerPoints.setCoordinateField(markerCoordinates)
            pointattr = markerPoints.getGraphicspointattributes()
            pointattr.setBaseSize(
                [glyphWidthSmall, glyphWidthSmall, glyphWidthSmall])
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_CROSS)
            markerPoints.setMaterial(
                self._materialmodule.findMaterialByName("white"))
            markerPoints.setName("displayMarkerPoints")
            markerPoints.setVisibilityFlag(self.isDisplayMarkerPoints())

            markerNames = scene.createGraphicsPoints()
            markerNames.setFieldDomainType(Field.DOMAIN_TYPE_NODES)
            if markerGroup:
                markerNames.setSubgroupField(markerGroup)
            if markerCoordinates:
                markerNames.setCoordinateField(markerCoordinates)
            pointattr = markerNames.getGraphicspointattributes()
            pointattr.setBaseSize(
                [glyphWidthSmall, glyphWidthSmall, glyphWidthSmall])
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_NONE)
            if markerName:
                pointattr.setLabelField(markerName)
            markerNames.setMaterial(
                self._materialmodule.findMaterialByName("white"))
            markerNames.setName("displayMarkerNames")
            markerNames.setVisibilityFlag(self.isDisplayMarkerNames())

            # data points, projections and projection points

            dataCoordinates = self._fitter.getDataCoordinatesField()
            dataPoints = scene.createGraphicsPoints()
            dataPoints.setFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
            if dataCoordinates:
                dataPoints.setCoordinateField(dataCoordinates)
            pointattr = dataPoints.getGraphicspointattributes()
            # pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_DIAMOND)
            # pointattr.setBaseSize([glyphWidthSmall, glyphWidthSmall, glyphWidthSmall])
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_POINT)
            dataPoints.setRenderPointSize(3.0)
            dataPoints.setMaterial(
                self._materialmodule.findMaterialByName("grey50"))
            dataPoints.setName("displayDataPoints")
            dataPoints.setVisibilityFlag(self.isDisplayDataPoints())

            for projectionMeshDimension in range(1, 3):
                dataProjectionNodeGroup = self._fitter.getDataProjectionNodeGroupField(
                    projectionMeshDimension)
                if dataProjectionNodeGroup.getNodesetGroup().getSize() == 0:
                    continue
                dataProjectionCoordinates = self._fitter.getDataProjectionCoordinatesField(
                    projectionMeshDimension)
                dataProjectionDelta = self._fitter.getDataProjectionDeltaField(
                    projectionMeshDimension)
                dataProjectionError = self._fitter.getDataProjectionErrorField(
                    projectionMeshDimension)

                dataProjections = scene.createGraphicsPoints()
                dataProjections.setFieldDomainType(
                    Field.DOMAIN_TYPE_DATAPOINTS)
                dataProjections.setSubgroupField(dataProjectionNodeGroup)
                if dataCoordinates:
                    dataProjections.setCoordinateField(dataCoordinates)
                pointAttr = dataProjections.getGraphicspointattributes()
                pointAttr.setGlyphShapeType(Glyph.SHAPE_TYPE_LINE)
                pointAttr.setBaseSize([0.0, 1.0, 1.0])
                pointAttr.setScaleFactors([1.0, 0.0, 0.0])
                dataProjections.setRenderLineWidth(2.0 if (
                    projectionMeshDimension == 1) else 1.0)
                if dataProjectionDelta:
                    pointAttr.setOrientationScaleField(dataProjectionDelta)
                if dataProjectionError:
                    dataProjections.setDataField(dataProjectionError)
                spectrummodule = scene.getSpectrummodule()
                spectrum = spectrummodule.getDefaultSpectrum()
                dataProjections.setSpectrum(spectrum)
                dataProjections.setName("displayDataProjections")
                dataProjections.setVisibilityFlag(
                    self.isDisplayDataProjections())

                dataProjectionPoints = scene.createGraphicsPoints()
                dataProjectionPoints.setFieldDomainType(
                    Field.DOMAIN_TYPE_DATAPOINTS)
                dataProjectionPoints.setSubgroupField(dataProjectionNodeGroup)
                if dataProjectionCoordinates:
                    dataProjectionPoints.setCoordinateField(
                        dataProjectionCoordinates)
                pointattr = dataProjectionPoints.getGraphicspointattributes()
                # pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_DIAMOND)
                # pointattr.setBaseSize([glyphWidthSmall, glyphWidthSmall, glyphWidthSmall])
                pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_POINT)
                dataProjectionPoints.setRenderPointSize(3.0)
                dataProjectionPoints.setMaterial(
                    self._materialmodule.findMaterialByName("grey50"))
                dataProjectionPoints.setName("displayDataProjectionPoints")
                dataProjectionPoints.setVisibilityFlag(
                    self.isDisplayDataProjectionPoints())

            nodePoints = scene.createGraphicsPoints()
            nodePoints.setFieldDomainType(Field.DOMAIN_TYPE_NODES)
            nodePoints.setCoordinateField(modelCoordinates)
            pointattr = nodePoints.getGraphicspointattributes()
            pointattr.setBaseSize([glyphWidth, glyphWidth, glyphWidth])
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_SPHERE)
            nodePoints.setMaterial(
                self._materialmodule.findMaterialByName("white"))
            nodePoints.setName("displayNodePoints")
            nodePoints.setVisibilityFlag(self.isDisplayNodePoints())

            nodeNumbers = scene.createGraphicsPoints()
            nodeNumbers.setFieldDomainType(Field.DOMAIN_TYPE_NODES)
            nodeNumbers.setCoordinateField(modelCoordinates)
            pointattr = nodeNumbers.getGraphicspointattributes()
            pointattr.setLabelField(cmiss_number)
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_NONE)
            nodeNumbers.setMaterial(
                self._materialmodule.findMaterialByName("green"))
            nodeNumbers.setName("displayNodeNumbers")
            nodeNumbers.setVisibilityFlag(self.isDisplayNodeNumbers())

            # names in same order as nodeDerivativeLabels "D1", "D2", "D3", "D12", "D13", "D23", "D123" and nodeDerivativeFields
            nodeDerivativeMaterialNames = [
                "gold", "silver", "green", "cyan", "magenta", "yellow", "blue"
            ]
            derivativeScales = [1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.25]
            for i in range(len(nodeDerivativeLabels)):
                nodeDerivativeLabel = nodeDerivativeLabels[i]
                nodeDerivatives = scene.createGraphicsPoints()
                nodeDerivatives.setFieldDomainType(Field.DOMAIN_TYPE_NODES)
                nodeDerivatives.setCoordinateField(modelCoordinates)
                pointattr = nodeDerivatives.getGraphicspointattributes()
                pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_ARROW_SOLID)
                pointattr.setOrientationScaleField(nodeDerivativeFields[i])
                pointattr.setBaseSize([0.0, glyphWidth, glyphWidth])
                pointattr.setScaleFactors([derivativeScales[i], 0.0, 0.0])
                material = self._materialmodule.findMaterialByName(
                    nodeDerivativeMaterialNames[i])
                nodeDerivatives.setMaterial(material)
                nodeDerivatives.setSelectedMaterial(material)
                nodeDerivatives.setName("displayNodeDerivatives" +
                                        nodeDerivativeLabel)
                nodeDerivatives.setVisibilityFlag(
                    self.isDisplayNodeDerivatives() and
                    self.isDisplayNodeDerivativeLabels(nodeDerivativeLabel))

            elementNumbers = scene.createGraphicsPoints()
            elementNumbers.setFieldDomainType(
                Field.DOMAIN_TYPE_MESH_HIGHEST_DIMENSION)
            elementNumbers.setCoordinateField(modelCoordinates)
            pointattr = elementNumbers.getGraphicspointattributes()
            pointattr.setLabelField(cmiss_number)
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_NONE)
            elementNumbers.setMaterial(
                self._materialmodule.findMaterialByName("cyan"))
            elementNumbers.setName("displayElementNumbers")
            elementNumbers.setVisibilityFlag(self.isDisplayElementNumbers())

            elementAxes = scene.createGraphicsPoints()
            elementAxes.setFieldDomainType(
                Field.DOMAIN_TYPE_MESH_HIGHEST_DIMENSION)
            elementAxes.setCoordinateField(modelCoordinates)
            pointattr = elementAxes.getGraphicspointattributes()
            pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_AXES_123)
            pointattr.setOrientationScaleField(elementDerivativesField)
            if meshDimension == 1:
                pointattr.setBaseSize([0.0, 2 * glyphWidth, 2 * glyphWidth])
                pointattr.setScaleFactors([0.25, 0.0, 0.0])
            elif meshDimension == 2:
                pointattr.setBaseSize([0.0, 0.0, 2 * glyphWidth])
                pointattr.setScaleFactors([0.25, 0.25, 0.0])
            else:
                pointattr.setBaseSize([0.0, 0.0, 0.0])
                pointattr.setScaleFactors([0.25, 0.25, 0.25])
            elementAxes.setMaterial(
                self._materialmodule.findMaterialByName("yellow"))
            elementAxes.setName("displayElementAxes")
            elementAxes.setVisibilityFlag(self.isDisplayElementAxes())

            lines = scene.createGraphicsLines()
            lines.setCoordinateField(modelCoordinates)
            lines.setExterior(self.isDisplayLinesExterior())
            lines.setName("displayLines")
            lines.setVisibilityFlag(self.isDisplayLines())

            surfaces = scene.createGraphicsSurfaces()
            surfaces.setCoordinateField(modelCoordinates)
            surfaces.setRenderPolygonMode(
                Graphics.RENDER_POLYGON_MODE_WIREFRAME if self.
                isDisplaySurfacesWireframe(
                ) else Graphics.RENDER_POLYGON_MODE_SHADED)
            surfaces.setExterior(self.isDisplaySurfacesExterior() if (
                meshDimension == 3) else False)
            surfacesMaterial = self._materialmodule.findMaterialByName(
                "trans_blue" if self.isDisplaySurfacesTranslucent(
                ) else "solid_blue")
            surfaces.setMaterial(surfacesMaterial)
            surfaces.setName("displaySurfaces")
            surfaces.setVisibilityFlag(self.isDisplaySurfaces())

    def autorangeSpectrum(self):
        """
        Autorange the default spectrum of the fitter region's scene.

        The range is computed with a default-constructed Scenefilter
        (presumably meaning "no filtering" in Zinc -- confirm against the Zinc
        Spectrum.autorange documentation).
        """
        regionScene = self._fitter.getRegion().getScene()
        defaultSpectrum = regionScene.getSpectrummodule().getDefaultSpectrum()
        defaultSpectrum.autorange(regionScene, Scenefilter())

    # === Align Utilities ===

    def isStateAlign(self):
        """
        Report whether the model is in interactive align state.

        :return: Always False, because align state is not implemented here.
        """
        alignStateImplemented = False
        return alignStateImplemented

    def rotateModel(self, axis, angle):
        """
        Apply an incremental rotation on top of the current align rotation.

        The rotation for (axis, angle) pre-multiplies the rotation matrix built
        from the stored "euler_angles". The product is stored back as Euler
        angles, and the align settings are then re-applied.
        :param axis: Rotation axis passed to axis_angle_to_rotation_matrix.
        :param angle: Rotation angle passed to axis_angle_to_rotation_matrix
            (units as that helper expects -- presumably radians, confirm).
        """
        settings = self._alignSettings
        incrementalRotation = axis_angle_to_rotation_matrix(axis, angle)
        currentRotation = euler_to_rotation_matrix(settings["euler_angles"])
        combinedRotation = matrix_mult(incrementalRotation, currentRotation)
        settings["euler_angles"] = rotation_matrix_to_euler(combinedRotation)
        self._applyAlignSettings()

    def scaleModel(self, factor):
        """
        Multiply the current align scale by factor and re-apply align settings.

        :param factor: Multiplier applied to the stored "scale" value.
        """
        settings = self._alignSettings
        settings["scale"] *= factor
        self._applyAlignSettings()

    def translateModel(self, relativeOffset):
        """
        Shift the align offset by relativeOffset and re-apply align settings.

        :param relativeOffset: Increment combined with the stored "offset"
            using add().
        """
        currentOffset = self._alignSettings["offset"]
        self._alignSettings["offset"] = add(currentOffset, relativeOffset)
        self._applyAlignSettings()

    def _autorangeSpectrum(self):
        """
        Autorange the scene's default spectrum using the scene's default
        scenefilter.
        """
        currentScene = self.getScene()
        defaultSpectrum = currentScene.getSpectrummodule().getDefaultSpectrum()
        defaultFilter = currentScene.getScenefiltermodule().getDefaultScenefilter()
        defaultSpectrum.autorange(currentScene, defaultFilter)
示例#12
0
    def test_alignMarkersFitRegularData(self):
        """
        Test automatic alignment of model and data using fiducial markers.

        Steps:
        1. Load the model and regular data, then check the initial model
           geometry (surface area 6.0, volume 1.0) and the active data group
           sizes.
        2. Align using markers, then check the rotation, scale, translation
           and the resulting surface area and volume.
        3. Run a 3-iteration fit with a marker data weight and a curvature
           penalty, then check the surface area and volume again.
        4. Check JSON settings serialisation round-trips. Decode the encoded
           settings into a fresh Fitter, check the step types, then re-encode
           from the decoded fitter and compare.
        """
        zinc_model_file = os.path.join(here, "resources", "cube_to_sphere.exf")
        zinc_data_file = os.path.join(here, "resources",
                                      "cube_to_sphere_data_regular.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        # there is always an initial FitterStepConfig
        self.assertEqual(1, len(fitter.getFitterSteps()))
        fitter.setDiagnosticLevel(1)
        fitter.load()
        dataScale = fitter.getDataScale()
        self.assertAlmostEqual(dataScale, 1.0, delta=1.0E-7)

        coordinates = fitter.getModelCoordinatesField()
        self.assertEqual(coordinates.getName(), "coordinates")
        self.assertEqual(fitter.getDataCoordinatesField().getName(),
                         "data_coordinates")
        self.assertEqual(fitter.getMarkerGroup().getName(), "marker")
        # fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry1.exf"))
        fieldmodule = fitter.getFieldmodule()
        # integrals over the 2D (surface) and 3D (volume) meshes, re-evaluated
        # after each step because the model coordinates are modified in place
        surfaceAreaField = createFieldMeshIntegral(coordinates,
                                                   fitter.getMesh(2),
                                                   number_of_points=4)
        volumeField = createFieldMeshIntegral(coordinates,
                                              fitter.getMesh(3),
                                              number_of_points=3)
        fieldcache = fieldmodule.createFieldcache()
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 6.0, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 1.0, delta=1.0E-7)
        activeNodeset = fitter.getActiveDataNodesetGroup()
        self.assertEqual(292, activeNodeset.getSize())
        groupSizes = {"bottom": 72, "sides": 144, "top": 72, "marker": 4}
        for groupName, count in groupSizes.items():
            self.assertEqual(
                count,
                getNodesetConditionalSize(
                    activeNodeset,
                    fitter.getFieldmodule().findFieldByName(groupName)))

        align = FitterStepAlign()
        fitter.addFitterStep(align)
        self.assertEqual(2, len(fitter.getFitterSteps()))
        self.assertTrue(align.setAlignMarkers(True))
        self.assertTrue(align.isAlignMarkers())
        align.run()
        # fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry2.exf"))
        rotation = align.getRotation()
        scale = align.getScale()
        translation = align.getTranslation()
        assertAlmostEqualList(self,
                              rotation, [-0.25 * math.pi, 0.0, 0.0],
                              delta=1.0E-4)
        self.assertAlmostEqual(scale, 0.8047378476539072, places=5)
        assertAlmostEqualList(
            self,
            translation,
            [-0.5690355950594247, 1.1068454682130484e-05, -0.4023689233125251],
            delta=1.0E-6)
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 3.885618020657802, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 0.5211506471189844, delta=1.0E-6)

        fit1 = FitterStepFit()
        fitter.addFitterStep(fit1)
        self.assertEqual(3, len(fitter.getFitterSteps()))
        fit1.setGroupDataWeight("marker", 1.0)
        fit1.setGroupCurvaturePenalty(None, [0.01])
        fit1.setNumberOfIterations(3)
        fit1.setUpdateReferenceState(True)
        fit1.run()
        # fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry3.exf"))

        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 3.18921662820759, delta=1.0E-4)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 0.5276212500499845, delta=1.0E-4)

        # test json serialisation
        s = fitter.encodeSettingsJSON()
        fitter2 = Fitter(zinc_model_file, zinc_data_file)
        fitter2.decodeSettingsJSON(s, decodeJSONFitterSteps)
        fitterSteps = fitter2.getFitterSteps()
        self.assertEqual(3, len(fitterSteps))
        self.assertTrue(isinstance(fitterSteps[0], FitterStepConfig))
        self.assertTrue(isinstance(fitterSteps[1], FitterStepAlign))
        self.assertTrue(isinstance(fitterSteps[2], FitterStepFit))
        # fitter2.load()
        # for fitterStep in fitterSteps:
        #    fitterStep.run()
        # Re-encode from the DECODED fitter. Re-encoding the original fitter
        # would compare s with itself and never exercise decodeSettingsJSON.
        s2 = fitter2.encodeSettingsJSON()
        self.assertEqual(s, s2)
示例#13
0
    def test_fit_breast2d(self):
        """
        Test 2D fit with curvature penalty requiring fibre field to be set.

        First checks that running a fit with a curvature penalty, before any
        fibre field is set, raises AssertionError with the expected message.
        Then it sets the model's "zero fibres" field as the fibre field,
        reloads, re-runs the fit and checks the surface area of the fitted
        coordinates.
        """
        zinc_model_file = os.path.join(here, "resources", "breast_plate.exf")
        zinc_data_file = os.path.join(here, "resources", "breast_data.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        fitter.setDiagnosticLevel(1)
        fitter.load()

        fit1 = FitterStepFit()
        fitter.addFitterStep(fit1)
        # initial FitterStepConfig plus fit1
        self.assertEqual(2, len(fitter.getFitterSteps()))
        # presumably None sets the default penalty for all groups -- confirm
        fit1.setGroupCurvaturePenalty(None, [100.0])
        # can't use a curvature penalty without a fibre field
        with self.assertRaises(AssertionError) as cm:
            fit1.run()
        self.assertEqual(
            str(cm.exception),
            "Must supply a fibre field to use strain/curvature penalties "
            "with mesh dimension < coordinate components.")

        # set the in-built zero fibres field
        fieldmodule = fitter.getFieldmodule()
        zeroFibreField = fieldmodule.findFieldByName("zero fibres")
        self.assertTrue(zeroFibreField.isValid())
        fitter.setFibreField(zeroFibreField)
        fitter.load()

        # check these now as different after re-load
        fieldmodule = fitter.getFieldmodule()
        coordinates = fitter.getModelCoordinatesField()
        self.assertEqual(coordinates.getName(), "coordinates")
        self.assertEqual(fitter.getDataCoordinatesField().getName(),
                         "data_coordinates")

        fit1.run()

        # check surface area of fitted coordinates
        # Note name is only prefixed with "fitted " when written with Fitter.writeModel
        surfaceAreaField = createFieldMeshIntegral(coordinates,
                                                   fitter.getMesh(2),
                                                   number_of_points=4)
        # NOTE(review): 'valid' is never used; the assertion below re-queries isValid()
        valid = surfaceAreaField.isValid()
        self.assertTrue(surfaceAreaField.isValid())
        fieldcache = fieldmodule.createFieldcache()
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 104501.36293993103, delta=1.0E-1)