def setUp(self):
    """
    Build a fresh test reactor and an open database for the current test.

    Everything happens inside a temporary working directory, and the database
    file is named after the running test method (``<testMethodName>.h5``).
    """
    # entered by hand rather than via ``with`` so the directory outlives setUp;
    # presumably exited again in the matching tearDown (not visible here)
    self.td = TemporaryDirectoryChanger()
    self.td.__enter__()

    op, reactor = test_reactors.loadTestReactor(TEST_ROOT)
    self.o, self.r = op, reactor

    dbFileName = self._testMethodName + ".h5"
    self.dbi = database3.DatabaseInterface(reactor, op.cs)
    self.dbi.initDB(fName=dbFileName)
    self.db: db.Database3 = self.dbi.database

    # the retainState() context is entered here and kept on the instance so it
    # can be exited later
    self.stateRetainer = reactor.retainState().__enter__()

    # serial numbers stashed while shuffling; used to test location-based history
    self.centralAssemSerialNums, self.centralTopBlockSerialNums = [], []
class TestCompareDB3(unittest.TestCase):
    """Tests for the compareDB3 module"""

    def setUp(self):
        # every test runs inside its own temporary working directory
        self.td = TemporaryDirectoryChanger()
        self.td.__enter__()

    def tearDown(self):
        self.td.__exit__(None, None, None)

    def test_outputWriter(self) -> None:
        """Text written through an OutputWriter ends up in the target file."""
        fileName = "test_outputWriter.txt"
        with OutputWriter(fileName) as out:
            out.writeln("Rubber Baby Buggy Bumpers")

        # read back through a context manager so the handle is closed deterministically
        with open(fileName, "r") as f:
            txt = f.read()
        self.assertIn("Rubber", txt)

    def test_diffResultsBasic(self) -> None:
        """Exercise the basic bookkeeping API of DiffResults."""
        # pylint: disable=protected-access
        # init an instance of the class
        dr = DiffResults(0.01)
        self.assertEqual(len(dr._columns), 0)
        self.assertEqual(len(dr._structureDiffs), 0)
        self.assertEqual(len(dr.diffs), 0)

        # simple test of addDiff: one call records three summary statistics
        dr.addDiff("thing", "what", 123.4, 122.2345, 555)
        self.assertEqual(len(dr._columns), 0)
        self.assertEqual(len(dr._structureDiffs), 0)
        self.assertEqual(len(dr.diffs), 3)
        self.assertEqual(dr.diffs["thing/what mean(abs(diff))"][0], 123.4)
        self.assertEqual(dr.diffs["thing/what mean(diff)"][0], 122.2345)
        self.assertEqual(dr.diffs["thing/what max(abs(diff))"][0], 555)

        # simple test of addTimeStep
        dr.addTimeStep("timeStep")
        self.assertEqual(dr._structureDiffs[0], 0)
        self.assertEqual(dr._columns[0], "timeStep")

        # simple test of addStructureDiffs
        dr.addStructureDiffs(7)
        self.assertEqual(len(dr._structureDiffs), 1)
        self.assertEqual(dr._structureDiffs[0], 7)

        # simple test of _getDefault
        self.assertEqual(len(dr._getDefault()), 0)

        # simple test of nDiffs
        self.assertEqual(dr.nDiffs(), 10)
def test_plotHexBlock(self):
    """Plot the core's first fuel block and check that the SVG file is written."""
    svgName = "blockDiagram23.svg"
    with TemporaryDirectoryChanger():
        fuelBlock = self.r.core.getFirstBlock(Flags.FUEL)
        fuelBlock.autoCreateSpatialGrids()
        plotting.plotBlockDiagram(fuelBlock, svgName, True)
        self.assertTrue(os.path.exists(svgName))
def test_concatenateLogs(self):
    """Simple test of the concat logs function: the per-rank log files are consumed."""
    with TemporaryDirectoryChanger():
        logDir = "test_concatenateLogs"

        # start from an empty log directory
        if os.path.exists(logDir):
            rmtree(logDir)
        context.createLogDir(0, logDir)

        # one stdout and one stderr file, each holding a single line of text
        logContents = {
            ".0.0.stdout": "hello world\n",
            ".0.0.stderr": "goodbye cruel world\n",
        }
        logPaths = []
        for suffix, text in logContents.items():
            path = os.path.join(logDir, logDir + suffix)
            with open(path, "w") as f:
                f.write(text)
            self.assertTrue(os.path.exists(path))
            logPaths.append(path)

        runLog.concatenateLogs(logDir=logDir)

        # the individual log files must be gone after concatenation
        for path in logPaths:
            self.assertFalse(os.path.exists(path))
def test_createLogDir(self):
    """createLogDir() creates the directory and tolerates being called again on it."""
    repeats = 10
    targetDir = "test_createLogDir"
    with TemporaryDirectoryChanger():
        # the directory must not exist before the first call
        self.assertFalse(os.path.exists(targetDir))

        # repeated calls on an already-existing directory must keep working
        attempt = 0
        while attempt < repeats:
            runLog.createLogDir(targetDir)
            self.assertTrue(os.path.exists(targetDir))
            attempt += 1
def test_dumpReactorXdmf(self):
    """
    Smoke test: write the reactor to a database, then dump it with the XDMF dumper.

    This does a lot and is hard to verify in detail; at least make sure it doesn't crash.
    """
    dbFile = "testDatabase.h5"
    with TemporaryDirectoryChanger(dumpOnException=False):
        database = Database3(dbFile, "w")
        with database:
            database.writeToDB(self.r)

        xdmfDumper = xdmf.XdmfDumper("testVtk", inputName=dbFile)
        with xdmfDumper:
            xdmfDumper.dumpState(self.r)
def test_plotCartesianBlock(self):
    """Plot every block design of the C5G7 tutorial case and check two of the SVGs."""
    from armi import settings
    from armi.reactor import blueprints, reactors

    settingsPath = os.path.join(TEST_ROOT, "tutorials", "c5g7-settings.yaml")
    with TemporaryDirectoryChanger():
        cs = settings.Settings(settingsPath)
        bp = blueprints.loadFromCs(cs)
        # the reactor is built but never inspected below.
        # NOTE(review): confirm whether construction side effects are still needed.
        _reactor = reactors.factory(cs, bp)

        for designName, blockDesign in bp.blockDesigns.items():
            block = blockDesign.construct(cs, bp, 0, 1, 1, "AA", {})
            plotting.plotBlockDiagram(block, "{}.svg".format(designName), True)

        for expected in ("uo2.svg", "mox.svg"):
            self.assertTrue(os.path.exists(expected))
def test_cleanPathMpi(self):
    """Simple tests of cleanPath(), in the MPI scenario"""

    def _writeText(path, text):
        # close each handle explicitly so no open file lingers when cleanPath runs
        with open(path, "w") as f:
            f.write(text)

    with TemporaryDirectoryChanger():
        # TEST 0: File is not safe to delete, due to name pathing
        filePath0 = "test0_cleanPathNoMpi"
        _writeText(filePath0, "something")
        self.assertTrue(os.path.exists(filePath0))
        with self.assertRaises(Exception):
            pathTools.cleanPath(filePath0, mpiRank=context.MPI_RANK)
        # synchronize outside the assertRaises block so it runs even though
        # cleanPath raised above
        context.waitAll()

        # TEST 1: Delete a single file
        filePath1 = "test1_cleanPathNoMpi_mongoose"
        _writeText(filePath1, "something")
        self.assertTrue(os.path.exists(filePath1))
        pathTools.cleanPath(filePath1, mpiRank=context.MPI_RANK)
        context.waitAll()
        self.assertFalse(os.path.exists(filePath1))

        # TEST 2: Delete an empty directory
        dir2 = "mongoose"
        os.mkdir(dir2)
        self.assertTrue(os.path.exists(dir2))
        pathTools.cleanPath(dir2, mpiRank=context.MPI_RANK)
        context.waitAll()
        self.assertFalse(os.path.exists(dir2))

        # TEST 3: Delete a directory with two files inside
        dir3 = "mongoose"
        os.mkdir(dir3)
        file1 = os.path.join(dir3, "file1.txt")
        file2 = os.path.join(dir3, "file2.txt")
        _writeText(file1, "something1")
        _writeText(file2, "something2")

        # delete the directory and test
        self.assertTrue(os.path.exists(dir3))
        self.assertTrue(os.path.exists(file1))
        self.assertTrue(os.path.exists(file2))
        pathTools.cleanPath(dir3, mpiRank=context.MPI_RANK)
        context.waitAll()
        self.assertFalse(os.path.exists(dir3))
def test_concatenateLogs(self):
    """Concatenate per-process stdout/stderr logs into one combined log file."""
    with TemporaryDirectoryChanger():
        logDir = "test_concatenateLogs"

        # start from a fresh, empty log directory
        if os.path.exists(logDir):
            rmtree(logDir)
        runLog.createLogDir(logDir)

        # (file-name template, contents) for two stdout logs and one stderr log;
        # every name is prefixed with the stdout logger name
        logSpecs = [
            ("{}.runLogTest.0000.stdout", "hello world\n"),
            ("{}.runLogTest.0001.stdout", "hello other world\n"),
            ("{}.runLogTest.0000.stderr", "goodbye cruel world\n"),
        ]
        logFiles = [
            os.path.join(logDir, template.format(runLog.STDOUT_LOGGER_NAME))
            for template, _ in logSpecs
        ]
        for path, (_, contents) in zip(logFiles, logSpecs):
            with open(path, "w") as f:
                f.write(contents)
        for path in logFiles:
            self.assertTrue(os.path.exists(path))

        runLog.concatenateLogs(logDir=logDir)

        # a combined log must exist and every individual log file must be gone
        combinedLogFile = os.path.join(logDir, "runLogTest-mpi.log")
        self.assertTrue(os.path.exists(combinedLogFile))
        for path in logFiles:
            self.assertFalse(os.path.exists(path))
def test_dumpReactorVtk(self):
    """
    Smoke test for the VTK dumper: dumping the reactor state must not crash.

    This does a lot and is hard to verify in detail.
    """
    with TemporaryDirectoryChanger(dumpOnException=False):
        vtkDumper = vtk.VtkDumper("testVtk", inputName=None)
        with vtkDumper:
            vtkDumper.dumpState(self.r)
class TestDatabase3(unittest.TestCase):
    r"""Tests for the Database3 class"""

    def setUp(self):
        """Load the test reactor and open a per-test database in a temp directory."""
        self.td = TemporaryDirectoryChanger()
        self.td.__enter__()
        self.o, self.r = test_reactors.loadTestReactor(TEST_ROOT)
        cs = self.o.cs

        self.dbi = database3.DatabaseInterface(self.r, cs)
        self.dbi.initDB(fName=self._testMethodName + ".h5")
        self.db: db.Database3 = self.dbi.database
        self.stateRetainer = self.r.retainState().__enter__()

        # used to test location-based history. see makeShuffleHistory()
        self.centralAssemSerialNums = []
        self.centralTopBlockSerialNums = []

    def tearDown(self):
        """Close the database, exit the state retainer and leave the temp directory."""
        self.db.close()
        self.stateRetainer.__exit__()
        self.td.__exit__(None, None, None)

    def makeHistory(self):
        """
        Walk the reactor through a few time steps and write them to the db.
        """
        for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)):
            self.r.p.cycle = cycle
            self.r.p.timeNode = node
            # something that splitDatabase won't change, so that we can make sure that
            # the right data went to the right new groups/cycles
            self.r.p.cycleLength = cycle

            self.db.writeToDB(self.r)

    def makeShuffleHistory(self):
        """
        Walk the reactor through a few time steps with some shuffling.
        """
        # Serial numbers *are not stable* (i.e., they can be different between test runs
        # due to parallelism and test run order). However, they are the simplest way to
        # check correctness of location-based history tracking. So we stash the serial
        # numbers at the location of interest so that we can use them later to check our
        # work.
        self.centralAssemSerialNums = []
        self.centralTopBlockSerialNums = []

        grid = self.r.core.spatialGrid

        for cycle in range(3):
            # swap the assembly at (cycle, 0, 0) with the central one (a no-op for cycle 0)
            a1 = self.r.core.childrenByLocator[grid[cycle, 0, 0]]
            a2 = self.r.core.childrenByLocator[grid[0, 0, 0]]
            olda1Loc = a1.spatialLocator
            a1.moveTo(a2.spatialLocator)
            a2.moveTo(olda1Loc)
            c = self.r.core.childrenByLocator[grid[0, 0, 0]]
            self.centralAssemSerialNums.append(c.p.serialNum)
            self.centralTopBlockSerialNums.append(c[-1].p.serialNum)

            for node in range(3):
                self.r.p.cycle = cycle
                self.r.p.timeNode = node
                # something that splitDatabase won't change, so that we can make sure
                # that the right data went to the right new groups/cycles
                self.r.p.cycleLength = cycle

                self.db.writeToDB(self.r)

        # add some more data that isn't written to the database to test the
        # DatabaseInterface API (cycleLength reuses the last loop value of ``cycle``)
        self.r.p.cycle = 3
        self.r.p.timeNode = 0
        self.r.p.cycleLength = cycle
        self.r.core[0].p.chargeTime = 3

    def _compareArrays(self, ref, src):
        """
        Compare two numpy arrays.

        Comparing numpy arrays that may have unsavory data (NaNs, Nones, jagged
        data, etc.) is really difficult. For now, convert to a list and compare
        element-by-element.
        """
        self.assertEqual(type(ref), type(src))
        if isinstance(ref, numpy.ndarray):
            ref = ref.tolist()
            src = src.tolist()

        # zip() silently truncates, so check the lengths explicitly first
        self.assertEqual(len(ref), len(src))
        for v1, v2 in zip(ref, src):
            # Entries may be None
            if isinstance(v1, numpy.ndarray):
                v1 = v1.tolist()
            if isinstance(v2, numpy.ndarray):
                v2 = v2.tolist()
            self.assertEqual(v1, v2)

    def _compareRoundTrip(self, data):
        """
        Make sure that data is unchanged by packing/unpacking.
        """
        packed, attrs = database3.packSpecialData(data, "testing")
        roundTrip = database3.unpackSpecialData(packed, attrs, "testing")
        self._compareArrays(data, roundTrip)

    def test_computeParents(self):
        # The below arrays represent a tree structure like this:
        #                 71 -----------------------.
        #                 |                          \
        #                12--.-----.------.          72
        #               / |  \      \      \
        #             22 30  4---.   6      18-.
        #            / |  |  | \  \        / |  \
        #           8 17  2 32 53 62      1  9  10
        #
        # This should cover a handful of corner cases
        numChildren = [2, 5, 2, 0, 0, 1, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0]
        serialNums = [71, 12, 22, 8, 17, 30, 2, 4, 32, 53, 62, 6, 18, 1, 9, 10, 72]

        expected_1 = [None, 71, 12, 22, 22, 12, 30, 12, 4, 4, 4, 12, 12, 18, 18, 18, 71]
        expected_2 = [
            None, None, 71, 12, 12, 71, 12, 71, 12, 12, 12, 71, 71, 12, 12, 12, None,
        ]
        expected_3 = [
            None, None, None, 71, 71, None, 71, None, 71, 71, 71, None, None, 71, 71, 71, None,
        ]

        self.assertEqual(
            database3.Layout.computeAncestors(serialNums, numChildren), expected_1
        )
        self.assertEqual(
            database3.Layout.computeAncestors(serialNums, numChildren, 2), expected_2
        )
        self.assertEqual(
            database3.Layout.computeAncestors(serialNums, numChildren, 3), expected_3
        )

    def test_history(self) -> None:
        self.makeShuffleHistory()

        grid = self.r.core.spatialGrid
        testAssem = self.r.core.childrenByLocator[grid[0, 0, 0]]
        testBlock = testAssem[-1]

        # Test assem
        hist = self.db.getHistoryByLocation(
            testAssem, params=["chargeTime", "serialNum"]
        )
        expectedSn = {
            (c, n): self.centralAssemSerialNums[c] for c in range(3) for n in range(3)
        }
        self.assertEqual(expectedSn, hist["serialNum"])

        # test block
        hists = self.db.getHistoriesByLocation(
            [testBlock], params=["serialNum"], timeSteps=[(0, 0), (1, 0), (2, 0)]
        )
        expectedSn = {(c, 0): self.centralTopBlockSerialNums[c] for c in range(3)}
        self.assertEqual(expectedSn, hists[testBlock]["serialNum"])

        # can't mix blocks and assems, since they are different distance from core
        with self.assertRaises(ValueError):
            self.db.getHistoriesByLocation([testAssem, testBlock], params=["serialNum"])

        # time step (3, 0) was never written to the database, yet the
        # DatabaseInterface still reports it -- presumably from the in-memory
        # reactor state set at the end of makeShuffleHistory()
        hist = self.dbi.getHistory(
            self.r.core[0], params=["chargeTime", "serialNum"], byLocation=True
        )
        self.assertIn((3, 0), hist["chargeTime"].keys())
        self.assertEqual(hist["chargeTime"][(3, 0)], 3)

    def test_replaceNones(self):
        """
        Round-trip awkward arrays through packSpecialData/unpackSpecialData.

        Covers regular, None-containing, jagged and dict-valued arrays.
        NOTE: coverage here is still basic.
        """
        data3 = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        data1 = numpy.array([1, 2, 3, 4, 5, 6, 7, 8])
        data1iNones = numpy.array([1, 2, None, 5, 6])
        data1fNones = numpy.array([None, 2.0, None, 5.0, 6.0])
        data2fNones = numpy.array(
            [None, [[1.0, 2.0, 6.0], [2.0, 3.0, 4.0]]], dtype=object
        )
        dataJag = numpy.array(
            [[[1, 2], [3, 4]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=object
        )
        dataJagNones = numpy.array(
            [[[1, 2], [3, 4]], [[1], [1]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
            dtype=object,
        )
        dataDict = numpy.array(
            [{"bar": 2, "baz": 3}, {"foo": 4, "baz": 6}, {"foo": 7, "bar": 8}]
        )
        self._compareRoundTrip(data3)
        self._compareRoundTrip(data1)
        self._compareRoundTrip(data1iNones)
        self._compareRoundTrip(data1fNones)
        self._compareRoundTrip(data2fNones)
        self._compareRoundTrip(dataJag)
        self._compareRoundTrip(dataJagNones)
        self._compareRoundTrip(dataDict)

    def test_mergeHistory(self):
        # pylint: disable=protected-access
        self.makeHistory()

        # put some big data in an HDF5 attribute. This will exercise the code that pulls
        # such attributes into a formal dataset and a reference.
        # NOTE(review): numpy.eye(6400) is ~330 MB of float64; a smaller array may
        # suffice if it still exceeds the size threshold _writeAttrs uses -- confirm.
        self.r.p.cycle = 1
        self.r.p.timeNode = 0
        tnGroup = self.db.getH5Group(self.r)
        database3._writeAttrs(
            tnGroup["layout/serialNum"],
            tnGroup,
            {
                "fakeBigData": numpy.eye(6400),
                "someString": "this isn't a reference to another dataset",
            },
        )

        db_path = "restartDB.h5"
        db2 = database3.Database3(db_path, "w")
        with db2:
            db2.mergeHistory(self.db, 2, 2)
            self.r.p.cycle = 1
            self.r.p.timeNode = 0
            tnGroup = db2.getH5Group(self.r)

            # this test is a little bit implementation-specific, but nice to be explicit
            self.assertEqual(
                tnGroup["layout/serialNum"].attrs["fakeBigData"],
                "@/c01n00/attrs/0_fakeBigData",
            )

            # actually exercise the _resolveAttrs function
            attrs = database3._resolveAttrs(tnGroup["layout/serialNum"].attrs, tnGroup)
            self.assertTrue(numpy.array_equal(attrs["fakeBigData"], numpy.eye(6400)))

    def test_splitDatabase(self):
        self.makeHistory()

        self.db.splitDatabase(
            [(c, n) for c in (1, 2) for n in range(3)], "-all-iterations"
        )

        # Closing to copy back from fast path
        self.db.close()

        with h5py.File("test_splitDatabase.h5", "r") as newDb:
            # the kept time steps are renumbered from cycle 0: the new c00n00 holds
            # what was written with cycleLength == 1, and there is no c02 any more
            self.assertEqual(newDb["c00n00/Reactor/cycle"][()], 0)
            self.assertEqual(newDb["c00n00/Reactor/cycleLength"][()], 1)
            self.assertNotIn("c02n00", newDb)
            self.assertEqual(newDb.attrs["databaseVersion"], database3.DB_VERSION)

            # validate that the min set of meta data keys exists
            meta_data_keys = [
                "appName",
                "armiLocation",
                "databaseVersion",
                "hostname",
                "localCommitHash",
                "machines",
                "platform",
                "platformArch",
                "platformRelease",
                "platformVersion",
                "pluginPaths",
                "python",
                "startTime",
                "successfulCompletion",
                "user",
                "version",
            ]
            for meta_key in meta_data_keys:
                self.assertIn(meta_key, newDb.attrs)
                self.assertIsNotNone(newDb.attrs[meta_key])
class TestDatabase3(unittest.TestCase):
    r"""Tests for the Database3 class"""

    def setUp(self):
        """Load the test reactor and open a per-test database in a temp directory."""
        self.td = TemporaryDirectoryChanger()
        self.td.__enter__()
        self.o, self.r = test_reactors.loadTestReactor(TEST_ROOT)

        self.dbi = database3.DatabaseInterface(self.r, self.o.cs)
        self.dbi.initDB(fName=self._testMethodName + ".h5")
        self.db: db.Database3 = self.dbi.database
        self.stateRetainer = self.r.retainState().__enter__()

        # used to test location-based history. see makeShuffleHistory()
        self.centralAssemSerialNums = []
        self.centralTopBlockSerialNums = []

    def tearDown(self):
        """Close the database, exit the state retainer and leave the temp directory."""
        self.db.close()
        self.stateRetainer.__exit__()
        self.td.__exit__(None, None, None)

    def makeHistory(self):
        """Walk the reactor through a few time steps and write them to the db."""
        for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)):
            self.r.p.cycle = cycle
            self.r.p.timeNode = node
            # something that splitDatabase won't change, so that we can make sure that
            # the right data went to the right new groups/cycles
            self.r.p.cycleLength = cycle

            self.db.writeToDB(self.r)

    def makeShuffleHistory(self):
        """Walk the reactor through a few time steps with some shuffling."""
        # Serial numbers *are not stable* (i.e., they can be different between test runs
        # due to parallelism and test run order). However, they are the simplest way to
        # check correctness of location-based history tracking. So we stash the serial
        # numbers at the location of interest so that we can use them later to check our
        # work.
        self.centralAssemSerialNums = []
        self.centralTopBlockSerialNums = []

        grid = self.r.core.spatialGrid

        for cycle in range(2):
            # swap the assembly at (cycle, 0, 0) with the central one (a no-op for cycle 0)
            a1 = self.r.core.childrenByLocator[grid[cycle, 0, 0]]
            a2 = self.r.core.childrenByLocator[grid[0, 0, 0]]
            olda1Loc = a1.spatialLocator
            a1.moveTo(a2.spatialLocator)
            a2.moveTo(olda1Loc)
            c = self.r.core.childrenByLocator[grid[0, 0, 0]]
            self.centralAssemSerialNums.append(c.p.serialNum)
            self.centralTopBlockSerialNums.append(c[-1].p.serialNum)

            for node in range(2):
                self.r.p.cycle = cycle
                self.r.p.timeNode = node
                # something that splitDatabase won't change, so that we can make sure
                # that the right data went to the right new groups/cycles
                self.r.p.cycleLength = cycle

                self.db.writeToDB(self.r)

        # add some more data that isn't written to the database to test the
        # DatabaseInterface API (cycleLength reuses the last loop value of ``cycle``)
        self.r.p.cycle = 2
        self.r.p.timeNode = 0
        self.r.p.cycleLength = cycle
        self.r.core[0].p.chargeTime = 2

        # add some fake missing parameter data to test allowMissing
        self.db.h5db["c00n00/Reactor/missingParam"] = "i don't exist"

    def _compareArrays(self, ref, src):
        """
        Compare two numpy arrays.

        Comparing numpy arrays that may have unsavory data (NaNs, Nones, jagged
        data, etc.) is really difficult. For now, convert to a list and compare
        element-by-element.
        """
        self.assertEqual(type(ref), type(src))
        if isinstance(ref, numpy.ndarray):
            ref = ref.tolist()
            src = src.tolist()

        # zip() silently truncates, so check the lengths explicitly first
        self.assertEqual(len(ref), len(src))
        for v1, v2 in zip(ref, src):
            # Entries may be None
            if isinstance(v1, numpy.ndarray):
                v1 = v1.tolist()
            if isinstance(v2, numpy.ndarray):
                v2 = v2.tolist()
            self.assertEqual(v1, v2)

    def _compareRoundTrip(self, data):
        """Make sure that data is unchanged by packing/unpacking."""
        packed, attrs = database3.packSpecialData(data, "testing")
        roundTrip = database3.unpackSpecialData(packed, attrs, "testing")
        self._compareArrays(data, roundTrip)

    @staticmethod
    def _writeReloadingDB(r, cs):
        """
        Write a 3-cycle x 2-node history of ``r`` to 'reloadingDB.h5' using ``cs``.

        Every time step is written with ``cycleLength = 2000``; the file is closed
        before returning.
        """
        reloadDbi = database3.DatabaseInterface(r, cs)
        reloadDbi.initDB(fName="reloadingDB.h5")
        reloadDb = reloadDbi.database

        # populate the db with something
        for cycle, node in ((cycle, node) for cycle in range(3) for node in range(2)):
            r.p.cycle = cycle
            r.p.timeNode = node
            r.p.cycleLength = 2000
            reloadDb.writeToDB(r)
        reloadDb.close()

    def test_prepRestartRun(self):
        """
        This test is based on the armiRun.yaml case that is loaded during the `setUp` above.

        In that cs, `reloadDBName` is set to 'reloadingDB.h5', `startCycle` = 1,
        and `startNode` = 2. The nonexistent 'reloadingDB.h5' must first be
        created here for this test.
        """
        # first, write a reload db and successfully call prepRestartRun
        o, r = test_reactors.loadTestReactor(TEST_ROOT)
        cs = o.cs.modified(
            newSettings={
                "nCycles": 3,
                "cycles": [
                    {"step days": [1000, 1000], "power fractions": [1, 1]},
                    {"step days": [1000, 1000], "power fractions": [1, 1]},
                    {"step days": [1000, 1000], "power fractions": [1, 1]},
                ],
                "reloadDBName": "something_fake.h5",
            }
        )
        self._writeReloadingDB(r, cs)
        self.dbi.prepRestartRun()

        # now make the cycle histories clash and confirm that an error is thrown
        cs = cs.modified(
            newSettings={
                "cycles": [
                    {"step days": [666, 666], "power fractions": [1, 1]},
                    {"step days": [666, 666], "power fractions": [1, 1]},
                    {"step days": [666, 666], "power fractions": [1, 1]},
                ],
            }
        )
        self._writeReloadingDB(r, cs)
        with self.assertRaises(ValueError):
            self.dbi.prepRestartRun()

    def test_computeParents(self):
        # The below arrays represent a tree structure like this:
        #                 71 -----------------------.
        #                 |                          \
        #                12--.-----.------.          72
        #               / |  \      \      \
        #             22 30  4---.   6      18-.
        #            / |  |  | \  \        / |  \
        #           8 17  2 32 53 62      1  9  10
        #
        # This should cover a handful of corner cases
        numChildren = [2, 5, 2, 0, 0, 1, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0]
        serialNums = [71, 12, 22, 8, 17, 30, 2, 4, 32, 53, 62, 6, 18, 1, 9, 10, 72]

        expected_1 = [None, 71, 12, 22, 22, 12, 30, 12, 4, 4, 4, 12, 12, 18, 18, 18, 71]
        expected_2 = [
            None, None, 71, 12, 12, 71, 12, 71, 12, 12, 12, 71, 71, 12, 12, 12, None,
        ]
        expected_3 = [
            None, None, None, 71, 71, None, 71, None, 71, 71, 71, None, None, 71, 71, 71, None,
        ]

        self.assertEqual(
            database3.Layout.computeAncestors(serialNums, numChildren), expected_1
        )
        self.assertEqual(
            database3.Layout.computeAncestors(serialNums, numChildren, 2), expected_2
        )
        self.assertEqual(
            database3.Layout.computeAncestors(serialNums, numChildren, 3), expected_3
        )

    def test_load(self):
        self.makeShuffleHistory()

        # the extra 'missingParam' dataset written by makeShuffleHistory() makes a
        # strict load raise KeyError, while allowMissing=True tolerates it
        with self.assertRaises(KeyError):
            _r = self.db.load(0, 0)

        _r = self.db.load(0, 0, allowMissing=True)

        # once the offending dataset is removed, a strict load succeeds
        del self.db.h5db["c00n00/Reactor/missingParam"]
        _r = self.db.load(0, 0, allowMissing=False)

        # we shouldn't be able to set the fileName if a file is open
        with self.assertRaises(RuntimeError):
            self.db.fileName = "whatever.h5"

    def test_history(self):
        self.makeShuffleHistory()

        grid = self.r.core.spatialGrid
        testAssem = self.r.core.childrenByLocator[grid[0, 0, 0]]
        testBlock = testAssem[-1]

        # Test assem
        hist = self.db.getHistoryByLocation(
            testAssem, params=["chargeTime", "serialNum"]
        )
        expectedSn = {
            (c, n): self.centralAssemSerialNums[c] for c in range(2) for n in range(2)
        }
        self.assertEqual(expectedSn, hist["serialNum"])

        # test block
        hists = self.db.getHistoriesByLocation(
            [testBlock], params=["serialNum"], timeSteps=[(0, 0), (1, 0)]
        )
        expectedSn = {(c, 0): self.centralTopBlockSerialNums[c] for c in range(2)}
        self.assertEqual(expectedSn, hists[testBlock]["serialNum"])

        # can't mix blocks and assems, since they are different distance from core
        with self.assertRaises(ValueError):
            self.db.getHistoriesByLocation([testAssem, testBlock], params=["serialNum"])

        # time step (2, 0) was never written to the database, yet the
        # DatabaseInterface still reports it -- presumably from the in-memory
        # reactor state set at the end of makeShuffleHistory()
        hist = self.dbi.getHistory(
            self.r.core[0], params=["chargeTime", "serialNum"], byLocation=True
        )
        self.assertIn((2, 0), hist["chargeTime"].keys())
        self.assertEqual(hist["chargeTime"][(2, 0)], 2)

    def test_auxData(self):
        path = self.db.getAuxiliaryDataPath((2, 0), "test_stuff")
        self.assertEqual(path, "c02n00/test_stuff")

        with self.assertRaises(KeyError):
            self.db.genAuxiliaryData((-1, -1))

    # TODO: This should be expanded.
    def test_replaceNones(self):
        """Super basic test that we handle Nones correctly in database read/writes"""
        data3 = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        data1 = numpy.array([1, 2, 3, 4, 5, 6, 7, 8])
        data1iNones = numpy.array([1, 2, None, 5, 6])
        data1fNones = numpy.array([None, 2.0, None, 5.0, 6.0])
        data2fNones = numpy.array(
            [None, [[1.0, 2.0, 6.0], [2.0, 3.0, 4.0]]], dtype=object
        )
        dataJag = numpy.array(
            [[[1, 2], [3, 4]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=object
        )
        dataJagNones = numpy.array(
            [[[1, 2], [3, 4]], [[1], [1]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
            dtype=object,
        )
        dataDict = numpy.array(
            [{"bar": 2, "baz": 3}, {"foo": 4, "baz": 6}, {"foo": 7, "bar": 8}]
        )
        self._compareRoundTrip(data3)
        self._compareRoundTrip(data1)
        self._compareRoundTrip(data1iNones)
        self._compareRoundTrip(data1fNones)
        self._compareRoundTrip(data2fNones)
        self._compareRoundTrip(dataJag)
        self._compareRoundTrip(dataJagNones)
        self._compareRoundTrip(dataDict)

    def test_mergeHistory(self):
        # pylint: disable=protected-access
        self.makeHistory()

        # put some big data in an HDF5 attribute. This will exercise the code that pulls
        # such attributes into a formal dataset and a reference.
        # NOTE(review): numpy.eye(6400) is ~330 MB of float64; a smaller array may
        # suffice if it still exceeds the size threshold _writeAttrs uses -- confirm.
        self.r.p.cycle = 1
        self.r.p.timeNode = 0
        tnGroup = self.db.getH5Group(self.r)
        database3._writeAttrs(
            tnGroup["layout/serialNum"],
            tnGroup,
            {
                "fakeBigData": numpy.eye(6400),
                "someString": "this isn't a reference to another dataset",
            },
        )

        db_path = "restartDB.h5"
        db2 = database3.Database3(db_path, "w")
        with db2:
            db2.mergeHistory(self.db, 2, 2)
            self.r.p.cycle = 1
            self.r.p.timeNode = 0
            tnGroup = db2.getH5Group(self.r)

            # this test is a little bit implementation-specific, but nice to be explicit
            self.assertEqual(
                tnGroup["layout/serialNum"].attrs["fakeBigData"],
                "@/c01n00/attrs/0_fakeBigData",
            )

            # actually exercise the _resolveAttrs function
            attrs = database3._resolveAttrs(tnGroup["layout/serialNum"].attrs, tnGroup)
            self.assertTrue(numpy.array_equal(attrs["fakeBigData"], numpy.eye(6400)))

            keys = sorted(db2.keys())
            self.assertEqual(len(keys), 8)
            self.assertEqual(keys[:3], ["/c00n00", "/c00n01", "/c00n02"])

    def test_splitDatabase(self):
        self.makeHistory()

        self.db.splitDatabase(
            [(c, n) for c in (1, 2) for n in range(3)], "-all-iterations"
        )

        # Closing to copy back from fast path
        self.db.close()

        with h5py.File("test_splitDatabase.h5", "r") as newDb:
            # the kept time steps are renumbered from cycle 0: the new c00n00 holds
            # what was written with cycleLength == 1, and there is no c02 any more
            self.assertEqual(newDb["c00n00/Reactor/cycle"][()], 0)
            self.assertEqual(newDb["c00n00/Reactor/cycleLength"][()], 1)
            self.assertNotIn("c02n00", newDb)
            self.assertEqual(newDb.attrs["databaseVersion"], database3.DB_VERSION)

            # validate that the min set of meta data keys exists
            meta_data_keys = [
                "appName",
                "armiLocation",
                "databaseVersion",
                "hostname",
                "localCommitHash",
                "machines",
                "platform",
                "platformArch",
                "platformRelease",
                "platformVersion",
                "pluginPaths",
                "python",
                "startTime",
                "successfulCompletion",
                "user",
                "version",
            ]
            for meta_key in meta_data_keys:
                self.assertIn(meta_key, newDb.attrs)
                self.assertIsNotNone(newDb.attrs[meta_key])

        # test an edge case - no DB to split. Only the call under test sits inside
        # assertRaises, so the setup line cannot satisfy it by accident.
        self.db.h5db = None
        with self.assertRaises(ValueError):
            self.db.splitDatabase(
                [(c, n) for c in (1, 2) for n in range(3)], "-all-iterations"
            )

    def test_grabLocalCommitHash(self):
        """test of static method to grab a local commit hash with ARMI version"""

        def runGit(*args):
            # run a git command quietly in the current (temporary) working
            # directory and return its exit code
            return subprocess.run(
                ["git", *args],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode

        # 1. test outside a Git repo
        localHash = database3.Database3.grabLocalCommitHash()
        self.assertEqual(localHash, "unknown")

        # 2. test inside an empty git repo
        self.assertEqual(runGit("init", "."), 0)
        localHash = database3.Database3.grabLocalCommitHash()
        self.assertEqual(localHash, "unknown")

        # 3. test inside a git repo with one tag
        # commit the empty repo
        code = runGit(
            "commit", "--allow-empty", "-m", '"init"', "--author", '"sam <>"'
        )
        if code == 128:
            # GitHub Actions blocks certain kinds of Git commands; report the rest
            # of this test as skipped instead of silently passing
            self.skipTest("git commit is not permitted in this environment")

        # create a tag off our new commit
        self.assertEqual(runGit("tag", "thanks", "-m", '"you_rock"'), 0)

        # test that we recover the correct commit hash
        localHash = database3.Database3.grabLocalCommitHash()
        self.assertEqual(localHash, "thanks")

        # remove untracked files, then untracked files and directories, from the
        # working tree. NOTE: `git clean` never deletes the .git directory itself;
        # presumably the temporary-directory cleanup takes care of that.
        self.assertEqual(runGit("clean", "-f"), 0)
        self.assertEqual(runGit("clean", "-f", "-d"), 0)

    def test_fileName(self):
        # test the file name getter
        self.assertEqual(str(self.db.fileName), "test_fileName.h5")

        # test the file name setter (only allowed once the file is closed)
        self.db.close()
        self.db.fileName = "thing.h5"
        self.assertEqual(str(self.db.fileName), "thing.h5")

    def test_readInputsFromDB(self):
        inputs = self.db.readInputsFromDB()
        self.assertEqual(len(inputs), 3)

        # inputs[0] looks like the settings text (contains "settings:")
        self.assertGreater(len(inputs[0]), 100)
        self.assertIn("metadata:", inputs[0])
        self.assertIn("settings:", inputs[0])

        # inputs[1] is empty for this test reactor
        self.assertEqual(len(inputs[1]), 0)

        # inputs[2] looks like the blueprints text (contains "blocks:")
        self.assertGreater(len(inputs[2]), 100)
        self.assertIn("custom isotopics:", inputs[2])
        self.assertIn("blocks:", inputs[2])

    def test_deleting(self):
        self.assertIs(type(self.db), database3.Database3)
        del self.db
        self.assertFalse(hasattr(self, "db"))
        # restore the attribute so tearDown can still close the database
        self.db = self.dbi.database

    def test_open(self):
        # the database was opened by initDB in setUp, so opening again must fail
        with self.assertRaises(ValueError):
            self.db.open()

    def test_loadCS(self):
        cs = self.db.loadCS()
        self.assertEqual(cs["numProcessors"], 1)
        self.assertEqual(cs["nCycles"], 6)

    def test_loadBlueprints(self):
        bp = self.db.loadBlueprints()
        self.assertIsNone(bp.nuclideFlags)
        self.assertEqual(len(bp.assemblies), 0)
def setUp(self):
    """Move into a fresh temporary working directory before each test."""
    # entered by hand rather than with ``with`` so the directory outlives setUp;
    # presumably exited again in the matching tearDown (not visible here)
    self.td = TemporaryDirectoryChanger()
    self.td.__enter__()
class TestBlockConverter(unittest.TestCase):
    """Tests for ``blockConverters``: component merging and block-to-cylinder conversion."""

    def setUp(self):
        # run every test inside a throwaway directory; some tests write plot files
        self.td = TemporaryDirectoryChanger()
        self.td.__enter__()

    def tearDown(self):
        self.td.__exit__(None, None, None)

    def test_dissolveWireIntoCoolant(self):
        self._test_dissolve(loadTestBlock(), "wire", "coolant")
        hotBlock = loadTestBlock(cold=False)
        self._test_dissolve(hotBlock, "wire", "coolant")
        hotBlock = self._perturbTemps(hotBlock, "wire", 127, 800)
        self._test_dissolve(hotBlock, "wire", "coolant")

    def test_dissolveLinerIntoClad(self):
        self._test_dissolve(loadTestBlock(), "outer liner", "clad")
        hotBlock = loadTestBlock(cold=False)
        self._test_dissolve(hotBlock, "outer liner", "clad")
        hotBlock = self._perturbTemps(hotBlock, "outer liner", 127, 800)
        self._test_dissolve(hotBlock, "outer liner", "clad")

    def _perturbTemps(self, block, cName, tCold, tHot):
        """Give the component different ref and hot temperatures than in test_Blocks.

        The component named ``cName`` is modified in place; ``block`` is returned
        for convenience.
        """
        c = block.getComponent(Flags.fromString(cName))
        c.refTemp, c.refHot = tCold, tHot
        c.applyHotHeightDensityReduction()
        c.setTemperature(tHot)
        return block

    def _test_dissolve(self, block, soluteName, solventName):
        """Merge ``soluteName`` into ``solventName`` and check area/composition are conserved."""
        converter = blockConverters.ComponentMerger(block, soluteName, solventName)
        convertedBlock = converter.convert()
        self.assertNotIn(soluteName, convertedBlock.getComponentNames())
        self._checkAreaAndComposition(block, convertedBlock)

    def test_build_NthRing(self):
        """Test building of one ring."""
        RING = 6
        block = loadTestBlock(cold=False)
        block.spatialGrid = grids.HexGrid.fromPitch(1.0)

        numPinsInRing = 30  # positions in hex ring 6: 6 * (RING - 1)
        converter = blockConverters.HexComponentsToCylConverter(block)
        fuel, clad = _buildJoyoFuel()
        pinComponents = [fuel, clad]
        converter._buildFirstRing(pinComponents)
        converter.pinPitch = 0.76
        converter._buildNthRing(pinComponents, RING)

        # named to avoid shadowing the ``components`` module used elsewhere in this class
        convertedComps = converter.convertedBlock
        self.assertEqual(
            convertedComps[3].name.split()[0], convertedComps[-1].name.split()[0]
        )
        self.assertAlmostEqual(
            clad.getNumberDensity("FE56"), convertedComps[1].getNumberDensity("FE56")
        )
        self.assertAlmostEqual(
            convertedComps[3].getArea() + convertedComps[-1].getArea(),
            clad.getArea() * numPinsInRing / clad.getDimension("mult"),
        )

    def test_convert(self):
        """Test conversion with no fuel driver."""
        block = (
            loadTestReactor(TEST_ROOT)[1]
            .core.getAssemblies(Flags.FUEL)[2]
            .getFirstBlock(Flags.FUEL)
        )
        block.spatialGrid = grids.HexGrid.fromPitch(1.0)

        area = block.getArea()
        converter = blockConverters.HexComponentsToCylConverter(block)
        converter.convert()
        self.assertAlmostEqual(area, converter.convertedBlock.getArea())
        self.assertAlmostEqual(area, block.getArea())

        for compType in [Flags.FUEL, Flags.CLAD, Flags.DUCT]:
            self.assertAlmostEqual(
                block.getComponent(compType).getArea(),
                sum(
                    component.getArea()
                    for component in converter.convertedBlock
                    if component.hasFlags(compType)
                ),
            )

        self._checkAreaAndComposition(block, converter.convertedBlock)
        self._checkCiclesAreInContact(converter.convertedBlock)

    def test_convertHexWithFuelDriver(self):
        """Test conversion with fuel driver."""
        driverBlock = (
            loadTestReactor(TEST_ROOT)[1]
            .core.getAssemblies(Flags.FUEL)[2]
            .getFirstBlock(Flags.FUEL)
        )
        block = loadTestReactor(TEST_ROOT)[1].core.getFirstBlock(Flags.CONTROL)

        driverBlock.spatialGrid = None
        block.spatialGrid = grids.HexGrid.fromPitch(1.0)
        self._testConvertWithDriverRings(
            block,
            driverBlock,
            blockConverters.HexComponentsToCylConverter,
            hexagon.numPositionsInRing,
        )

        # This should fail because a spatial grid is required
        # on the block.
        driverBlock.spatialGrid = None
        block.spatialGrid = None
        with self.assertRaises(ValueError):
            self._testConvertWithDriverRings(
                block,
                driverBlock,
                blockConverters.HexComponentsToCylConverter,
                hexagon.numPositionsInRing,
            )

        # The ``BlockAvgToCylConverter`` should work
        # without any spatial grid defined because it
        # assumes the grid based on the block type.
        driverBlock.spatialGrid = None
        block.spatialGrid = None
        self._testConvertWithDriverRings(
            block,
            driverBlock,
            blockConverters.BlockAvgToCylConverter,
            hexagon.numPositionsInRing,
        )

    def test_convertCartesianLatticeWithFuelDriver(self):
        """Test conversion with fuel driver."""
        r = loadTestReactor(TEST_ROOT, inputFileName="zpprTest.yaml")[1]
        driverBlock = r.core.getAssemblies(Flags.FUEL)[2].getFirstBlock(Flags.FUEL)
        block = r.core.getAssemblies(Flags.FUEL)[2].getFirstBlock(Flags.BLANKET)

        driverBlock.spatialGrid = grids.CartesianGrid.fromRectangle(1.0, 1.0)
        block.spatialGrid = grids.CartesianGrid.fromRectangle(1.0, 1.0)

        converter = blockConverters.BlockAvgToCylConverter
        self._testConvertWithDriverRings(
            block, driverBlock, converter, lambda n: (n - 1) * 8
        )

    def _testConvertWithDriverRings(
        self, block, driverBlock, converterToTest, getNumInRing
    ):
        """Convert ``block`` with 1 to 4 external driver rings and validate each result.

        ``getNumInRing(n)`` gives the number of lattice positions in ring ``n``; it
        is used to compute the expected total area of the converted block.
        """
        area = block.getArea()
        numExternalFuelRings = [1, 2, 3, 4]
        numBlocks = 1
        for externalRings in numExternalFuelRings:
            # cumulative count of block-equivalents covered by the center plus driver rings
            numBlocks += getNumInRing(externalRings + 1)
            converter = converterToTest(
                block, driverFuelBlock=driverBlock, numExternalRings=externalRings
            )
            convertedBlock = converter.convert()
            self.assertAlmostEqual(area * numBlocks, convertedBlock.getArea())
            self._checkCiclesAreInContact(convertedBlock)
            plotFile = "convertedBlock_{0}.svg".format(externalRings)
            converter.plotConvertedBlock(fName=plotFile)
            os.remove(plotFile)

            for c in list(reversed(convertedBlock))[:externalRings]:
                self.assertTrue(c.isFuel(), "c was {}".format(c.name))
                # remove external driver rings in preparation to check composition
                convertedBlock.remove(c)
            self._checkAreaAndComposition(block, convertedBlock)

    def _checkAreaAndComposition(self, block, convertedBlock):
        """Assert both blocks share area, nuclide number densities, and mass.

        Every nuclide found in either block is compared. A nuclide missing from one
        side is treated as having zero density rather than raising ``KeyError``.
        All mismatches are collected into one failure message. The original
        block's mass must also be positive, so an empty block cannot pass.
        """
        self.assertAlmostEqual(block.getArea(), convertedBlock.getArea())
        unmergedNucs = block.getNumberDensities()
        convDens = convertedBlock.getNumberDensities()
        errorMessage = ""
        nucs = set(unmergedNucs) | set(convDens)
        for nucName in sorted(nucs):
            n1 = unmergedNucs.get(nucName, 0.0)
            n2 = convDens.get(nucName, 0.0)
            try:
                self.assertAlmostEqual(n1, n2)
            except AssertionError:
                errorMessage += "\nnuc {} not equal. unmerged: {} merged: {}".format(
                    nucName, n1, n2
                )
        self.assertFalse(errorMessage, errorMessage)
        bMass = block.getMass()
        self.assertAlmostEqual(bMass, convertedBlock.getMass())
        self.assertGreater(bMass, 0.0)  # verify it isn't empty

    def _checkCiclesAreInContact(self, convertedCircleBlock):
        """Assert the block is a stack of touching ``Circle`` components.

        Requires more than one component, all of them circles. In sorted order, the
        first must have an ID of zero and each later ID must equal the previous OD.
        """
        numComponents = len(convertedCircleBlock)
        self.assertGreater(numComponents, 1)
        self.assertTrue(
            all(isinstance(c, components.Circle) for c in convertedCircleBlock)
        )

        lastCompOD = None
        lastComp = None
        for c in sorted(convertedCircleBlock):
            thisID = c.getDimension("id")
            thisOD = c.getDimension("od")
            if lastCompOD is None:
                self.assertEqual(
                    thisID,
                    0,
                    "The inner component {} should have an ID of zero".format(c),
                )
            else:
                self.assertEqual(
                    thisID,
                    lastCompOD,
                    "The component {} with id {} was not in contact with the "
                    "previous component ({}) that had od {}".format(
                        c, thisID, lastComp, lastCompOD
                    ),
                )
            lastCompOD = thisOD
            lastComp = c
class TestCompareDB3(unittest.TestCase):
    """Tests for the compareDB3 module"""

    def setUp(self):
        # each test writes files, so run it inside a throwaway directory
        self.td = TemporaryDirectoryChanger()
        self.td.__enter__()

    def tearDown(self):
        self.td.__exit__(None, None, None)

    def test_outputWriter(self):
        fileName = "test_outputWriter.txt"
        with OutputWriter(fileName) as out:
            out.writeln("Rubber Baby Buggy Bumpers")

        # read back through a context manager so the handle is always closed
        with open(fileName, "r") as f:
            txt = f.read()
        self.assertIn("Rubber", txt)

    def test_diffResultsBasic(self):
        # init an instance of the class
        dr = DiffResults(0.01)
        self.assertEqual(len(dr._columns), 0)
        self.assertEqual(len(dr._structureDiffs), 0)
        self.assertEqual(len(dr.diffs), 0)

        # simple test of addDiff
        dr.addDiff("thing", "what", 123.4, 122.2345, 555)
        self.assertEqual(len(dr._columns), 0)
        self.assertEqual(len(dr._structureDiffs), 0)
        self.assertEqual(len(dr.diffs), 3)
        self.assertEqual(dr.diffs["thing/what mean(abs(diff))"][0], 123.4)
        self.assertEqual(dr.diffs["thing/what mean(diff)"][0], 122.2345)
        self.assertEqual(dr.diffs["thing/what max(abs(diff))"][0], 555)

        # simple test of addTimeStep
        dr.addTimeStep("timeStep")
        self.assertEqual(dr._structureDiffs[0], 0)
        self.assertEqual(dr._columns[0], "timeStep")

        # simple test of addStructureDiffs
        dr.addStructureDiffs(7)
        self.assertEqual(len(dr._structureDiffs), 1)
        self.assertEqual(dr._structureDiffs[0], 7)

        # simple test of _getDefault
        self.assertEqual(len(dr._getDefault()), 0)

        # simple test of nDiffs
        self.assertEqual(dr.nDiffs(), 10)

    def test_compareDatabaseDuplicate(self):
        """end-to-end test of compareDatabases() on a photocopy database"""
        # build two super-simple H5 files for testing
        o, r = test_reactors.loadTestReactor(TEST_ROOT)

        # create two DBs, identical but for file names
        dbs = []
        for i in range(2):
            # create the tests DB
            dbi = database3.DatabaseInterface(r, o.cs)
            dbi.initDB(fName=self._testMethodName + str(i) + ".h5")
            db = dbi.database

            # validate the file exists, and force it to be readable again;
            # the context manager closes the file even if an assertion fails
            with h5py.File(db._fullPath, "r") as b:
                self.assertEqual(list(b.keys()), ["inputs"])
                self.assertEqual(
                    sorted(b["inputs"].keys()), ["blueprints", "geomFile", "settings"]
                )

            # append to lists
            dbs.append(db)

        # end-to-end validation that comparing a photocopy database works
        diffs = compareDatabases(dbs[0]._fullPath, dbs[1]._fullPath)
        self.assertEqual(len(diffs.diffs), 0)
        self.assertEqual(diffs.nDiffs(), 0)

    def test_compareDatabaseSim(self):
        """end-to-end test of compareDatabases() on very similar databases"""
        # build two super-simple H5 files for testing
        o, r = test_reactors.loadTestReactor(TEST_ROOT)

        # create two DBs that differ only in their number of cycles
        dbs = []
        for nCycles in range(1, 3):
            # build some test data
            days = 100 * nCycles
            cycles = [
                {"step days": [days, days], "power fractions": [1, 0.5]}
            ] * nCycles
            cs = o.cs.modified(
                newSettings={
                    "nCycles": nCycles,
                    "cycles": cycles,
                    "reloadDBName": "something_fake.h5",
                }
            )

            # create the tests DB
            dbi = database3.DatabaseInterface(r, cs)
            dbi.initDB(fName=self._testMethodName + str(nCycles) + ".h5")
            db = dbi.database

            # populate the db with two time nodes for each cycle 0..nCycles
            for cycle in range(nCycles + 1):
                for node in range(2):
                    r.p.cycle = cycle
                    r.p.timeNode = node
                    r.p.cycleLength = days * 2
                    db.writeToDB(r)

            # validate the file exists, and force it to be readable again;
            # the context manager closes the file even if an assertion fails
            with h5py.File(db._fullPath, "r") as b:
                dbKeys = sorted(b.keys())
                self.assertEqual(len(dbKeys), 2 * (nCycles + 1) + 1)
                self.assertIn("inputs", dbKeys)
                self.assertIn("c00n00", dbKeys)
                self.assertEqual(
                    sorted(b["inputs"].keys()), ["blueprints", "geomFile", "settings"]
                )

            # append to lists
            dbs.append(db)

        # end-to-end validation that comparing two similar databases reports their differences
        diffs = compareDatabases(dbs[0]._fullPath, dbs[1]._fullPath)
        self.assertEqual(len(diffs.diffs), 456)
        self.assertEqual(diffs.nDiffs(), 3)