Exemplo n.º 1
0
    def test_fit_breast2d(self):
        """
        Test 2D fit with curvature penalty requiring fibre field to be set.

        Strain/curvature penalties need a fibre field when the mesh dimension
        is less than the number of coordinate components, so the first run is
        expected to fail. After setting the in-built "zero fibres" field and
        re-loading, the fit runs and the fitted surface area is checked.
        """
        zinc_model_file = os.path.join(here, "resources", "breast_plate.exf")
        zinc_data_file = os.path.join(here, "resources", "breast_data.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        fitter.setDiagnosticLevel(1)
        fitter.load()

        fit1 = FitterStepFit()
        fitter.addFitterStep(fit1)
        self.assertEqual(2, len(fitter.getFitterSteps()))
        fit1.setGroupCurvaturePenalty(None, [100.0])
        # can't use a curvature penalty without a fibre field
        with self.assertRaises(AssertionError) as cm:
            fit1.run()
        self.assertEqual(
            str(cm.exception),
            "Must supply a fibre field to use strain/curvature penalties "
            "with mesh dimension < coordinate components.")

        # set the in-built zero fibres field
        fieldmodule = fitter.getFieldmodule()
        zeroFibreField = fieldmodule.findFieldByName("zero fibres")
        self.assertTrue(zeroFibreField.isValid())
        fitter.setFibreField(zeroFibreField)
        fitter.load()

        # re-fetch the field module and coordinate fields: they differ after
        # re-load, so objects obtained before load() must not be reused
        fieldmodule = fitter.getFieldmodule()
        coordinates = fitter.getModelCoordinatesField()
        self.assertEqual(coordinates.getName(), "coordinates")
        self.assertEqual(fitter.getDataCoordinatesField().getName(),
                         "data_coordinates")

        fit1.run()

        # check surface area of fitted coordinates
        # Note: the name is only prefixed with "fitted " when written with
        # Fitter.writeModel
        surfaceAreaField = createFieldMeshIntegral(coordinates,
                                                   fitter.getMesh(2),
                                                   number_of_points=4)
        self.assertTrue(surfaceAreaField.isValid())
        fieldcache = fieldmodule.createFieldcache()
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 104501.36293993103, delta=1.0E-1)
Exemplo n.º 2
0
    def test_alignGroupsFitEllipsoidRegularData(self):
        """
        Test automatic alignment of model and data using groups & fit two cubes model to ellipsoid data.

        Also covers: the effect of a custom fibre orientation field on the fit,
        inheritance and override of strain/curvature penalties between fit
        steps, and a JSON serialisation round trip of the fitter settings.
        """
        zinc_model_file = os.path.join(here, "resources",
                                       "two_cubes_hermite_nocross_groups.exf")
        zinc_data_file = os.path.join(here, "resources",
                                      "two_cubes_ellipsoid_data_regular.exf")
        fitter = Fitter(zinc_model_file, zinc_data_file)
        fitter.setDiagnosticLevel(1)
        fitter.load()

        coordinates = fitter.getModelCoordinatesField()
        self.assertEqual(coordinates.getName(), "coordinates")
        self.assertEqual(fitter.getDataCoordinatesField().getName(),
                         "data_coordinates")
        fieldmodule = fitter.getFieldmodule()
        # surface area includes interior surface in this case
        surfaceAreaField = createFieldMeshIntegral(coordinates,
                                                   fitter.getMesh(2),
                                                   number_of_points=4)
        volumeField = createFieldMeshIntegral(coordinates,
                                              fitter.getMesh(3),
                                              number_of_points=3)
        fieldcache = fieldmodule.createFieldcache()
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 11.0, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 2.0, delta=1.0E-6)
        self.assertTrue(fitter.getActiveDataNodesetGroup().isValid())

        # align model to data using matching groups
        align = FitterStepAlign()
        fitter.addFitterStep(align)
        self.assertEqual(2, len(fitter.getFitterSteps()))
        self.assertTrue(align.setAlignGroups(True))
        self.assertTrue(align.isAlignGroups())
        align.run()
        rotation = align.getRotation()
        scale = align.getScale()
        translation = align.getTranslation()
        assertAlmostEqualList(self, rotation, [0.0, 0.0, 0.0], delta=1.0E-5)
        self.assertAlmostEqual(scale, 1.040599599095245, places=5)
        assertAlmostEqualList(
            self,
            translation,
            [-1.0405995643008867, -0.5202997843515198, -0.5202997827678563],
            delta=1.0E-6)
        # uniform scaling multiplies area by scale^2 and volume by scale^3
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 11.0 * scale * scale, delta=1.0E-6)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume,
                               2.0 * scale * scale * scale,
                               delta=1.0E-6)

        # fit step: default penalties, then locally set values
        fit1 = FitterStepFit()
        fitter.addFitterStep(fit1)
        self.assertEqual(3, len(fitter.getFitterSteps()))
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            None)
        assertAlmostEqualList(self, strainPenalty, [0.0], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertFalse(inheritable)
        curvaturePenalty, locallySet, inheritable = fit1.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.0], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertFalse(inheritable)
        fit1.setGroupStrainPenalty(None, [0.1])
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            None)
        assertAlmostEqualList(self, strainPenalty, [0.1], delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        fit1.setGroupCurvaturePenalty(None, [0.01])
        curvaturePenalty, locallySet, inheritable = fit1.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.01], delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        # test specifying number of components: the last value is repeated
        curvaturePenalty, locallySet, inheritable = fit1.getGroupCurvaturePenalty(
            None, count=5)
        assertAlmostEqualList(self,
                              curvaturePenalty, [0.01, 0.01, 0.01, 0.01, 0.01],
                              delta=1.0E-7)
        # group "two" strain penalty will initially fall back to default value
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self, strainPenalty, [0.1], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertFalse(inheritable)
        fit1.setGroupStrainPenalty(
            "two", [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0])
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        fit1.setNumberOfIterations(1)
        fit1.run()
        result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(surfaceArea, 11.097773862300704, delta=1.0E-4)
        result, volume = volumeField.evaluateReal(fieldcache, 1)
        self.assertEqual(result, RESULT_OK)
        self.assertAlmostEqual(volume, 2.323461787566051, delta=1.0E-4)

        # test fibre orientation field
        fitter.load()
        # field module and coordinates must be re-fetched after re-load
        fieldmodule = fitter.getFieldmodule()
        self.assertIsNone(fitter.getFibreField())
        fibreField = fieldmodule.createFieldConstant(
            [0.0, 0.0, 0.25 * math.pi])
        fibreField.setName("custom fibres")
        fibreField.setManaged(True)
        fitter.setFibreField(fibreField)
        self.assertEqual(fibreField, fitter.getFibreField())
        coordinates = fitter.getModelCoordinatesField()
        align.run()
        fit1.run()
        # get end node coordinate to prove twist
        nodeExpectedCoordinates = {
            3: [0.8487623099139301, -0.5012613734076182, -0.5306482017126274],
            6: [0.8487623092159226, 0.2617063557585618, -0.5464896371028911],
            9: [0.8487623062422882, -0.2617063537282271, 0.5464896401724635],
            12: [0.8487623124370356, 0.5012613792923117, 0.5306482045212996]
        }
        fieldcache = fieldmodule.createFieldcache()
        nodes = fieldmodule.findNodesetByFieldDomainType(
            Field.DOMAIN_TYPE_NODES)
        for nodeIdentifier, expectedCoordinates in nodeExpectedCoordinates.items():
            node = nodes.findNodeByIdentifier(nodeIdentifier)
            self.assertEqual(RESULT_OK, fieldcache.setNode(node))
            result, x = coordinates.getNodeParameters(fieldcache, -1,
                                                      Node.VALUE_LABEL_VALUE,
                                                      1, 3)
            self.assertEqual(RESULT_OK, result)
            assertAlmostEqualList(self, x, expectedCoordinates, delta=1.0E-6)

        # test inheritance and override of penalties
        fit2 = FitterStepFit()
        fitter.addFitterStep(fit2)
        self.assertEqual(4, len(fitter.getFitterSteps()))
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            None)
        assertAlmostEqualList(self, strainPenalty, [0.1], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertTrue(inheritable)
        curvaturePenalty, locallySet, inheritable = fit2.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.01], delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertTrue(inheritable)
        # setting None reports locallySet as None (not False) and a zero value
        fit2.setGroupCurvaturePenalty(None, None)
        curvaturePenalty, locallySet, inheritable = fit2.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.0], delta=1.0E-7)
        self.assertIsNone(locallySet)
        self.assertTrue(inheritable)
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0],
                              delta=1.0E-7)
        self.assertFalse(locallySet)
        self.assertTrue(inheritable)
        fit2.setGroupStrainPenalty("two", [0.5, 0.9, 0.2])
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            "two", count=9)
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.5, 0.9, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertTrue(inheritable)

        # test json serialisation: decode into a fresh Fitter
        s = fitter.encodeSettingsJSON()
        fitter2 = Fitter(zinc_model_file, zinc_data_file)
        fitter2.decodeSettingsJSON(s, decodeJSONFitterSteps)
        fitterSteps = fitter2.getFitterSteps()
        self.assertEqual(4, len(fitterSteps))
        self.assertIsInstance(fitterSteps[0], FitterStepConfig)
        self.assertIsInstance(fitterSteps[1], FitterStepAlign)
        self.assertIsInstance(fitterSteps[2], FitterStepFit)
        self.assertIsInstance(fitterSteps[3], FitterStepFit)
        fit1 = fitterSteps[2]
        strainPenalty, locallySet, inheritable = fit1.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty,
                              [0.1, 0.1, 0.1, 0.1, 20.0, 0.1, 0.1, 20.0, 2.0],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertFalse(inheritable)
        fit2 = fitterSteps[3]
        curvaturePenalty, locallySet, inheritable = fit2.getGroupCurvaturePenalty(
            None)
        assertAlmostEqualList(self, curvaturePenalty, [0.0], delta=1.0E-7)
        self.assertIsNone(locallySet)
        self.assertTrue(inheritable)
        strainPenalty, locallySet, inheritable = fit2.getGroupStrainPenalty(
            "two")
        assertAlmostEqualList(self,
                              strainPenalty, [0.5, 0.9, 0.2],
                              delta=1.0E-7)
        self.assertTrue(locallySet)
        self.assertTrue(inheritable)
        # re-encode the DECODED fitter: a true round trip must reproduce s
        # (re-encoding the original fitter would only compare s with itself)
        s2 = fitter2.encodeSettingsJSON()
        self.assertEqual(s, s2)