Example #1
0
    def test_abinit_dftp_output(self):
        """Store a DDB file on an AbinitDftpOutputMixin document and read it back.

        The database part only runs when has_mongodb() is true.
        """
        document_cls = self.get_document_class_from_mixin(AbinitDftpOutputMixin)
        document = document_cls()

        if not has_mongodb():
            return

        document.structure = self.si_structure.as_dict()
        reference_ddb = os.path.join(abidata.dirpath, 'refs', 'znse_phonons',
                                     'ZnSe_hex_qpt_DDB')

        # binary mode is used for py3k compatibility with mongoengine
        with open(reference_ddb, "rb") as ddb_handle:
            document.ddb.put(ddb_handle)

        document.save()

        stored = document_cls.objects()
        assert len(stored) == 1

        fetched = stored[0]
        with tempfile.NamedTemporaryFile(mode="wb") as dump:
            # dump the stored file and compare it with the reference on disk
            dump.write(fetched.ddb.read())
            dump.seek(0)

            assert filecmp.cmp(reference_ddb, dump.name)
Example #2
0
    def test_ground_state_output(self):
        """Build a GroundStateOutputMixin document and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(GroundStateOutputMixin)
        document = document_cls(
            final_energy=1.1,
            efermi=0.1,
            total_magnetization=-0.1,
            structure=self.si_structure.as_dict(),
        )

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #3
0
    def test_calculation_type(self):
        """Build a CalculationTypeMixin document and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(CalculationTypeMixin)
        document = document_cls(
            xc_functional="PBE",
            pseudo_type="nc",
            is_hubbard=False,
            pseudo_dojo_table="table",
        )

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #4
0
    def test_calculation_metadata(self):
        """Build a CalculationMetadataMixin document and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(CalculationMetadataMixin)
        run_date = datetime(2010, 1, 1, 0, 1, 1)
        document = document_cls(user="******", cluster="cluster", execution_date=run_date)

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #5
0
    def test_date(self):
        """Build a DateMixin document with both dates set and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(DateMixin)
        created = datetime(2010, 1, 1, 0, 1, 1)
        modified = datetime(2011, 1, 1, 0, 1, 1)
        document = document_cls(created_on=created, modified_on=modified)

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #6
0
    def test_run_stats(self):
        """Build a RunStatsMixin document and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(RunStatsMixin)
        document = document_cls(
            core_num=10,
            elapsed_time=10.5,
            maximum_memory_used=1000,
            number_of_restarts=2,
            number_of_errors=1,
            number_of_warnings=2,
        )

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #7
0
    def test_dfpt_result(self):
        """Fill a DfptResult document, save it and compare the stored files with the references.

        The inputs and the output structure are always set; attaching files,
        saving and the round-trip comparison only run when has_mongodb() is true.
        """
        doc = DfptResult(mp_id="mp-1",
                         time_report={"test": 1},
                         fw_id=1,
                         relax_db={"test": 1},
                         relax_id="id_string")

        # Each input field is assigned exactly once (the original assigned
        # dde_input twice in a row, which had no additional effect).
        doc.abinit_input.structure = self.si_structure.as_dict()
        doc.abinit_input.gs_input = self.scf_inp.as_dict()
        doc.abinit_input.ddk_input = self.scf_inp.as_dict()
        doc.abinit_input.dde_input = self.scf_inp.as_dict()
        doc.abinit_input.wfq_input = self.scf_inp.as_dict()
        doc.abinit_input.strain_input = self.scf_inp.as_dict()
        doc.abinit_input.dte_input = self.scf_inp.as_dict()
        doc.abinit_input.phonon_input = self.scf_inp.as_dict()
        doc.abinit_input.kppa = 1000
        doc.abinit_output.structure = self.si_structure.as_dict()

        if not has_mongodb():
            return

        gsr_path = abidata.ref_file('si_scf_GSR.nc')

        with open(gsr_path, "rb") as gsr:
            doc.abinit_output.gs_gsr.put(gsr)

        gs_outfile_path = self.out_file

        # read/write in binary for py3k compatibility with mongoengine
        with open(gs_outfile_path, "rb") as gs_outfile:
            doc.abinit_output.gs_outfile.put(gs_outfile)

        anaddb_nc_path = abidata.ref_file('ZnSe_hex_886.anaddb.nc')

        with open(anaddb_nc_path, "rb") as anaddb_nc:
            doc.abinit_output.anaddb_nc.put(anaddb_nc)

        doc.save()

        query_result = DfptResult.objects()

        assert len(query_result) == 1
        saved_doc = query_result[0]

        # Dump each stored file to a temporary file and compare it with its source.
        with tempfile.NamedTemporaryFile(mode="wb") as db_file:
            db_file.write(saved_doc.abinit_output.gs_gsr.read())
            db_file.seek(0)
            assert filecmp.cmp(gsr_path, db_file.name)

        with tempfile.NamedTemporaryFile(mode="wt") as db_file:
            # this field is written out through its own unzip() helper
            saved_doc.abinit_output.gs_outfile.unzip(filepath=db_file.name)
            db_file.seek(0)
            assert filecmp.cmp(gs_outfile_path, db_file.name)

        with tempfile.NamedTemporaryFile(mode="wb") as db_file:
            db_file.write(saved_doc.abinit_output.anaddb_nc.read())
            db_file.seek(0)
            assert filecmp.cmp(anaddb_nc_path, db_file.name)
Example #8
0
    def test_date(self):
        """Build a DateMixin document with both dates set and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(DateMixin)
        timestamps = {
            "created_on": datetime(2010, 1, 1, 0, 1, 1),
            "modified_on": datetime(2011, 1, 1, 0, 1, 1),
        }
        document = document_cls(**timestamps)

        if not has_mongodb():
            return

        document.save()
        assert len(document_cls.objects()) == 1
    def test_dte_result(self):
        """Fill a DteResult document, save it and compare the stored files with the references.

        Attaching files, saving and the comparison only run when has_mongodb() is true.
        """
        result = DteResult(
            mp_id="mp-1",
            time_report={"test": 1},
            fw_id=1,
            relax_db={"test": 1},
            relax_id="id_string",
        )

        inputs = result.abinit_input
        inputs.structure = self.si_structure.as_dict()
        inputs.gs_input = self.scf_inp.as_dict()
        inputs.ddk_input = self.scf_inp.as_dict()
        inputs.dde_input = self.scf_inp.as_dict()
        inputs.dte_input = self.scf_inp.as_dict()
        inputs.phonon_input = self.scf_inp.as_dict()
        inputs.kppa = 1000
        inputs.with_phonons = True

        outputs = result.abinit_output
        outputs.structure = self.si_structure.as_dict()
        outputs.emacro = np.eye(3).tolist()
        outputs.emacro_rlx = np.eye(3).tolist()
        outputs.dchide = np.arange(36).reshape((4, 3, 3)).tolist()
        outputs.dchidt = np.arange(36).reshape((2, 2, 3, 3)).tolist()

        if not has_mongodb():
            return

        gsr_reference = abidata.ref_file('si_scf_GSR.nc')
        with open(gsr_reference, "rb") as gsr_handle:
            result.abinit_output.gs_gsr.put(gsr_handle)

        outfile_reference = self.out_file
        # binary mode is used for py3k compatibility with mongoengine
        with open(outfile_reference, "rb") as outfile_handle:
            result.abinit_output.gs_outfile.put(outfile_handle)

        anaddb_reference = abidata.ref_file('ZnSe_hex_886.anaddb.nc')
        with open(anaddb_reference, "rb") as anaddb_handle:
            result.abinit_output.anaddb_nc.put(anaddb_handle)

        result.save()

        stored = DteResult.objects()
        assert len(stored) == 1
        fetched = stored[0]

        with tempfile.NamedTemporaryFile(mode="wb") as dump:
            dump.write(fetched.abinit_output.gs_gsr.read())
            dump.seek(0)
            assert filecmp.cmp(gsr_reference, dump.name)

        with tempfile.NamedTemporaryFile(mode="wt") as dump:
            fetched.abinit_output.gs_outfile.unzip(filepath=dump.name)
            dump.seek(0)
            assert filecmp.cmp(outfile_reference, dump.name)

        with tempfile.NamedTemporaryFile(mode="wb") as dump:
            dump.write(fetched.abinit_output.anaddb_nc.read())
            dump.seek(0)
            assert filecmp.cmp(anaddb_reference, dump.name)
Example #10
0
    def test_material(self):
        """Fill a MaterialMixin document from the Si structure and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(MaterialMixin)
        document = document_cls()
        document.set_material_data_from_structure(self.si_structure)
        assert document.nelements == 1

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #11
0
    def test_material(self):
        """Check nelements derived from the Si structure, then save if MongoDB is available."""
        material_cls = self.get_document_class_from_mixin(MaterialMixin)
        material = material_cls()
        material.set_material_data_from_structure(self.si_structure)
        assert material.nelements == 1

        # the rest needs a database
        if not has_mongodb():
            return

        material.save()
        assert len(material_cls.objects()) == 1
Example #12
0
    def test_spacegroup(self):
        """Fill a SpaceGroupMixin document from the Si structure and save it when MongoDB is available."""
        document_cls = self.get_document_class_from_mixin(SpaceGroupMixin)
        document = document_cls()
        document.set_space_group_from_structure(self.si_structure)
        assert document.number == 227

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #13
0
    def test_spacegroup(self):
        """Check the space group number set from the Si structure, then save if MongoDB is available."""
        spacegroup_cls = self.get_document_class_from_mixin(SpaceGroupMixin)
        spacegroup = spacegroup_cls()
        spacegroup.set_space_group_from_structure(self.si_structure)
        assert spacegroup.number == 227

        # the rest needs a database
        if not has_mongodb():
            return

        spacegroup.save()
        assert len(spacegroup_cls.objects()) == 1
Example #14
0
    def test_calculation_type(self):
        """Build a CalculationTypeMixin document and save it when MongoDB is available."""
        calc_type_cls = self.get_document_class_from_mixin(CalculationTypeMixin)
        settings = {
            "xc_functional": "PBE",
            "pseudo_type": "nc",
            "is_hubbard": False,
            "pseudo_dojo_table": "table",
        }
        calc_type = calc_type_cls(**settings)

        if not has_mongodb():
            return

        calc_type.save()
        assert len(calc_type_cls.objects()) == 1
Example #15
0
    def test_calculation_metadata(self):
        """Build a CalculationMetadataMixin document and save it when MongoDB is available."""
        metadata_cls = self.get_document_class_from_mixin(CalculationMetadataMixin)
        metadata = metadata_cls(
            user="******",
            cluster="cluster",
            execution_date=datetime(2010, 1, 1, 0, 1, 1),
        )

        if not has_mongodb():
            return

        metadata.save()
        assert len(metadata_cls.objects()) == 1
Example #16
0
    def test_ground_state_output(self):
        """Build a GroundStateOutputMixin document and save it when MongoDB is available."""
        gs_output_cls = self.get_document_class_from_mixin(GroundStateOutputMixin)
        structure_dict = self.si_structure.as_dict()
        gs_output = gs_output_cls(final_energy=1.1, efermi=0.1,
                                  total_magnetization=-0.1,
                                  structure=structure_dict)

        if not has_mongodb():
            return

        gs_output.save()
        assert len(gs_output_cls.objects()) == 1
Example #17
0
    def test_hubbards(self):
        """Check that HubbardMixin rejects the key "AA" and that a "Ti" document can be saved."""
        document_cls = self.get_document_class_from_mixin(HubbardMixin)

        # "AA" must fail validation
        rejected = document_cls(hubbards={"AA": 1.5})
        with self.assertRaises(ValidationError):
            rejected.validate()

        accepted = document_cls(hubbards={"Ti": 1.5})

        if not has_mongodb():
            return

        accepted.save()
        stored = document_cls.objects()
        assert len(stored) == 1
Example #18
0
    def test_hubbards(self):
        """Validate hubbards keys: "AA" raises ValidationError, "Ti" is saved if MongoDB is up."""
        hubbard_cls = self.get_document_class_from_mixin(HubbardMixin)

        bad_doc = hubbard_cls(hubbards={"AA": 1.5})
        with self.assertRaises(ValidationError):
            bad_doc.validate()

        good_doc = hubbard_cls(hubbards={"Ti": 1.5})

        # the rest needs a database
        if not has_mongodb():
            return

        good_doc.save()
        assert len(hubbard_cls.objects()) == 1
Example #19
0
    def test_run_stats(self):
        """Build a RunStatsMixin document and save it when MongoDB is available."""
        stats_cls = self.get_document_class_from_mixin(RunStatsMixin)
        counters = {
            "core_num": 10,
            "elapsed_time": 10.5,
            "maximum_memory_used": 1000,
            "number_of_restarts": 2,
            "number_of_errors": 1,
            "number_of_warnings": 2,
        }
        stats = stats_cls(**counters)

        if not has_mongodb():
            return

        stats.save()
        assert len(stats_cls.objects()) == 1
Example #20
0
class TestMongoEngineDBInsertionTask(AbiflowsTest):
    """Tests for MongoEngineDBInsertionTask: serialization and a run through fireworks."""

    @classmethod
    def setUpClass(cls):
        cls.setup_fireworks()

    @classmethod
    def tearDownClass(cls):
        cls.teardown_fireworks(module_dir=MODULE_DIR)

    def test_class(self):
        """The task built from a DatabaseData must be serializable by fireworks."""
        database = DatabaseData(
            "test_db",
            collection="test_collection",
            username="******",
            password="******",
        )
        insertion_task = MongoEngineDBInsertionTask(database)

        self.assertFwSerializable(insertion_task)

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_run(self):
        """Run the task in a one-firework workflow and check the inserted document."""
        database = DatabaseData(
            self.lp.name,
            collection="test_MongoEngineDBInsertionTask",
            username=self.lp.username,
            password=self.lp.password,
        )
        insertion_task = MongoEngineDBInsertionTask(database)
        firework = Firework([insertion_task], fw_id=1,
                            spec={"_add_launchpad_and_fw_id": True})
        workflow_metadata = {
            'workflow_class': SaveDataWorkflow.workflow_class,
            'workflow_module': SaveDataWorkflow.workflow_module
        }
        workflow = Workflow([firework], metadata=workflow_metadata)
        self.lp.add_wf(workflow)

        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR, nlaunches=1)

        finished = self.lp.get_wf_by_fw_id(1)
        assert finished.state == "COMPLETED"

        # retrieve the saved object
        # error if not imported locally
        from abiflows.fireworks.tasks.tests.mock_objects import DataDocument
        database.connect_mongoengine()
        with database.switch_collection(DataDocument) as switched_document:
            records = switched_document.objects()

            assert len(records) == 1

            assert records[0].test_field_string == "test_text"
            assert records[0].test_field_int == 5
Example #21
0
    def test_abinit_basic_input(self):
        """Fill an AbinitBasicInputMixin document from the SCF input and check the stored values."""
        document_cls = self.get_document_class_from_mixin(AbinitBasicInputMixin)
        document = document_cls()
        document.set_abinit_basic_from_abinit_input(self.scf_inp)

        if not has_mongodb():
            return

        document.save()

        stored = document_cls.objects()
        assert len(stored) == 1

        fetched = stored[0]
        assert fetched.ecut == document.ecut
        # the k-point settings must match the ones of the source input
        self.assertArrayEqual(self.scf_inp['ngkpt'], fetched.ngkpt)
        self.assertArrayEqual(self.scf_inp['shiftk'], fetched.shiftk)
Example #22
0
class TestAbinitPseudoData(AbiflowsTest):
    """Tests for AbinitPseudoData: filling it from an input or a files file, and saving it."""

    @classmethod
    def setUpClass(cls):
        """Build the Si structure and the SCF input shared by all tests."""
        cls.si_structure = abilab.Structure.from_file(
            abidata.cif_file("si.cif"))
        cls.scf_inp = scf_input(cls.si_structure,
                                abidata.pseudos("14si.pspnc"),
                                ecut=2,
                                kppa=10)

    def setUp(self):
        self.setup_mongoengine()

    @classmethod
    def tearDownClass(cls):
        cls.teardown_mongoengine()

    def test_class(self):
        """Fill AbinitPseudoData from an abinit input and from a files file."""
        pseudo = AbinitPseudoData()
        pseudo.set_pseudos_from_abinit_input(self.scf_inp)

        pseudo_paths = [
            abidata.pseudo("C.oncvpsp").filepath,
            abidata.pseudo("Ga.oncvpsp").filepath
        ]
        with tempfile.NamedTemporaryFile("wt") as files_file:
            # The original literal contained "\out": "\o" is not a valid
            # escape sequence, so a stray backslash ended up in the file.
            files_file.write("run.abi\nrun.abo\nin\nout\ntmp\n")
            # writelines() adds no separators, so terminate each path
            # explicitly to keep one pseudopotential per line.
            files_file.writelines(path + "\n" for path in pseudo_paths)
            # Only the file name is handed over below, so push the buffered
            # text to disk first (presumably the method re-opens the file).
            files_file.flush()
            pseudo.set_pseudos_from_files_file(files_file.name, 2)

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_save(self):
        """Save a document embedding AbinitPseudoData and check one entry is stored."""
        class TestDocument(Document):
            meta = {'collection': "test_AbinitPseudoData"}
            pseudo = EmbeddedDocumentField(AbinitPseudoData,
                                           default=AbinitPseudoData)

        pseudo = AbinitPseudoData()
        pseudo.set_pseudos_from_abinit_input(self.scf_inp)

        doc = TestDocument()
        doc.pseudo = pseudo
        doc.save()

        query_result = TestDocument.objects()

        assert len(query_result) == 1
Example #23
0
    def test_phonon_result(self):
        """Fill a PhononResult document, save it and compare the stored files with the references.

        Attaching files, saving and the comparison only run when has_mongodb() is true.
        """
        result = PhononResult(
            mp_id="mp-1",
            time_report={"test": 1},
            fw_id=1,
            relax_db={"test": 1},
            relax_id="id_string",
        )

        inputs = result.abinit_input
        inputs.structure = self.si_structure.as_dict()
        inputs.gs_input = self.scf_inp.as_dict()
        inputs.ddk_input = self.scf_inp.as_dict()
        inputs.dde_input = self.scf_inp.as_dict()
        inputs.wfq_input = self.scf_inp.as_dict()
        inputs.phonon_input = self.scf_inp.as_dict()
        inputs.kppa = 1000
        inputs.ngqpt = [4, 4, 4]
        inputs.qppa = 1000
        result.abinit_output.structure = self.si_structure.as_dict()

        if not has_mongodb():
            return

        gsr_reference = abidata.ref_file('si_scf_GSR.nc')
        with open(gsr_reference, "rb") as gsr_handle:
            result.abinit_output.gs_gsr.put(gsr_handle)

        outfile_reference = self.out_file
        # binary mode is used for py3k compatibility with mongoengine
        with open(outfile_reference, "rb") as outfile_handle:
            result.abinit_output.gs_outfile.put(outfile_handle)

        result.save()

        stored = PhononResult.objects()
        assert len(stored) == 1
        fetched = stored[0]

        with tempfile.NamedTemporaryFile(mode="wb") as dump:
            dump.write(fetched.abinit_output.gs_gsr.read())
            dump.seek(0)
            assert filecmp.cmp(gsr_reference, dump.name)

        with tempfile.NamedTemporaryFile(mode="wt") as dump:
            fetched.abinit_output.gs_outfile.unzip(filepath=dump.name)
            dump.seek(0)
            assert filecmp.cmp(outfile_reference, dump.name)
Example #24
0
    def test_relax_result(self):
        """Fill a RelaxResult document, save it and compare the stored HIST file with the reference.

        Attaching files, saving and the comparison only run when has_mongodb() is true.
        """
        result = RelaxResult(history={"test": 1}, mp_id="mp-1",
                             time_report={"test": 1}, fw_id=1)

        result.abinit_input.structure = self.si_structure.as_dict()
        result.abinit_input.last_input = self.scf_inp.as_dict()
        result.abinit_input.kppa = 1000

        result.abinit_output.structure = self.si_structure.as_dict()

        if not has_mongodb():
            return

        hist_reference = abidata.ref_file('sic_relax_HIST.nc')

        with open(hist_reference, "rb") as hist_handle:
            # build a file proxy with the class and collection name declared
            # on the hist_files field
            hist_field = RelaxResult.abinit_output.default.hist_files.field
            hist_proxy = hist_field.proxy_class(
                collection_name=hist_field.collection_name)
            hist_proxy.put(hist_handle)

            result.abinit_output.hist_files = {'test_hist': hist_proxy}

        # binary mode is used for py3k compatibility with mongoengine
        with open(self.out_file, "rb") as outfile_handle:
            result.abinit_output.outfile_ioncell.put(outfile_handle)

        result.save()

        stored = RelaxResult.objects()
        assert len(stored) == 1

        fetched = stored[0]
        with tempfile.NamedTemporaryFile(mode="wb") as dump:
            stored_hist = fetched.abinit_output.hist_files["test_hist"]
            dump.write(stored_hist.read())
            dump.seek(0)
            assert filecmp.cmp(hist_reference, dump.name)
Example #25
0
    def test_directory(self):
        """Run a two-firework workflow and check DirectoryMixin records both directory names."""
        document_cls = self.get_document_class_from_mixin(DirectoryMixin)
        document = document_cls()

        # create and run a fireworks workflow
        sleep_task = PyTask(func="time.sleep", args=[0.5])
        first_fw = Firework(sleep_task, spec={'wf_task_index': "1"}, fw_id=1)
        second_fw = Firework(sleep_task, spec={'wf_task_index': "2"}, fw_id=2)
        self.lp.add_wf(Workflow([first_fw, second_fw]))
        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
        finished = self.lp.get_wf_by_fw_id(1)

        document.set_dir_names_from_fws_wf(finished)
        assert len(document.dir_names) == 2

        if not has_mongodb():
            return

        document.save()
        stored = document_cls.objects()
        assert len(stored) == 1
        assert len(stored[0].dir_names) == 2
Example #26
0
class TestBaseClassMethods(AbiflowsTest):
    """Tests for the helper methods of the workflow base class, using ScfFWWorkflow."""

    @classmethod
    def setUpClass(cls):
        """Build the Si structure and SCF input, then start the fireworks fixtures."""
        structure = abilab.Structure.from_file(abidata.cif_file("si.cif"))
        cls.si_structure = structure
        cls.scf_inp = scf_input(
            cls.si_structure,
            abidata.pseudos("14si.pspnc"),
            ecut=2,
            kppa=10,
        )
        cls.setup_fireworks()

    @classmethod
    def tearDownClass(cls):
        cls.teardown_fireworks()

    def setUp(self):
        # a fresh workflow for every test
        self.scf_wf = ScfFWWorkflow(self.scf_inp)

    def tearDown(self):
        if not self.lp:
            return
        self.lp.reset(password=None, require_password=False)

    def test_add_fws(self):
        """Each add_* helper must append exactly one firework to the workflow."""
        workflow = self.scf_wf
        assert len(workflow.wf.fws) == 1

        workflow.add_final_cleanup(out_exts=["WFK", "DEN"])
        assert len(workflow.wf.fws) == 2

        workflow.add_mongoengine_db_insertion(DatabaseData(TESTDB_NAME))
        assert len(workflow.wf.fws) == 3

        workflow.add_cut3d_den_to_cube_task()
        assert len(workflow.wf.fws) == 4

    def test_fireworks_methods(self):
        """Call the metadata/spec helpers; only the metadata content is asserted."""
        workflow = self.scf_wf
        workflow.add_metadata(self.si_structure, {"test": 1})
        assert "nsites" in workflow.wf.metadata

        workflow.fix_fworker("test_worker")
        workflow.get_reduced_formula(self.scf_inp)
        workflow.set_short_single_core_to_spec()
        workflow.set_preserve_fworker()
        workflow.add_spec_to_all_fws({"test_spec": 1})

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_add_to_db(self):
        """Add the workflow to the launchpad."""
        self.scf_wf.add_to_db(self.lp)
Example #27
0
    def test_dte_result(self):
        """Fill a DteResult document, attach the reference files and save it.

        Attaching files and saving only run when has_mongodb() is true.
        """
        result = DteResult(
            mp_id="mp-1",
            time_report={"test": 1},
            fw_id=1,
            relax_db={"test": 1},
            relax_id="id_string",
        )

        inputs = result.abinit_input
        inputs.structure = self.si_structure.as_dict()
        inputs.gs_input = self.scf_inp.as_dict()
        inputs.ddk_input = self.scf_inp.as_dict()
        inputs.dde_input = self.scf_inp.as_dict()
        inputs.dte_input = self.scf_inp.as_dict()
        inputs.phonon_input = self.scf_inp.as_dict()
        inputs.kppa = 1000
        inputs.with_phonons = True

        outputs = result.abinit_output
        outputs.structure = self.si_structure.as_dict()
        outputs.epsinf = np.eye(3).tolist()
        outputs.eps0 = np.eye(3).tolist()
        outputs.dchide = np.arange(36).reshape((4, 3, 3)).tolist()
        outputs.dchidt = np.arange(36).reshape((2, 2, 3, 3)).tolist()

        if not has_mongodb():
            return

        gsr_reference = abidata.ref_file('si_scf_GSR.nc')
        with open(gsr_reference, "rb") as gsr_handle:
            result.abinit_output.gs_gsr.put(gsr_handle)

        outfile_reference = self.out_file
        # binary mode is used for py3k compatibility with mongoengine
        with open(outfile_reference, "rb") as outfile_handle:
            result.abinit_output.gs_outfile.put(outfile_handle)

        anaddb_reference = abidata.ref_file('ZnSe_hex_886.anaddb.nc')
        with open(anaddb_reference, "rb") as anaddb_handle:
            result.abinit_output.anaddb_nc.put(anaddb_handle)

        result.save()
Example #28
0
    def test_directory(self):
        """Check that DirectoryMixin collects one directory name per firework of a finished workflow."""
        directory_cls = self.get_document_class_from_mixin(DirectoryMixin)
        directory_doc = directory_cls()

        # create and run a fireworks workflow
        sleep_task = PyTask(func="time.sleep", args=[0.5])
        fireworks = [
            Firework(sleep_task, spec={'wf_task_index': str(index)}, fw_id=index)
            for index in (1, 2)
        ]
        self.lp.add_wf(Workflow(fireworks))
        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)
        completed_wf = self.lp.get_wf_by_fw_id(1)

        directory_doc.set_dir_names_from_fws_wf(completed_wf)
        assert len(directory_doc.dir_names) == 2

        # the rest needs a database
        if not has_mongodb():
            return

        directory_doc.save()
        saved_docs = directory_cls.objects()
        assert len(saved_docs) == 1
        assert len(saved_docs[0].dir_names) == 2
Example #29
0
    def test_relax_result(self):
        """Save a RelaxResult with a HIST file attached and check the file survives the round trip.

        Attaching files, saving and the comparison only run when has_mongodb() is true.
        """
        relax_doc = RelaxResult(
            history={"test": 1},
            mp_id="mp-1",
            time_report={"test": 1},
            fw_id=1,
        )

        relax_doc.abinit_input.structure = self.si_structure.as_dict()
        relax_doc.abinit_input.last_input = self.scf_inp.as_dict()
        relax_doc.abinit_input.kppa = 1000

        relax_doc.abinit_output.structure = self.si_structure.as_dict()

        if not has_mongodb():
            return

        hist_path = abidata.ref_file('sic_relax_HIST.nc')

        with open(hist_path, "rb") as hist_stream:
            # the proxy class and collection name come from the hist file field
            field_spec = RelaxResult.abinit_output.default.hist_files.field
            proxy_cls = field_spec.proxy_class
            target_collection = field_spec.collection_name
            hist_file = proxy_cls(collection_name=target_collection)
            hist_file.put(hist_stream)

            relax_doc.abinit_output.hist_files = {'test_hist': hist_file}

        outfile_path = self.out_file

        # binary mode is used for py3k compatibility with mongoengine
        with open(outfile_path, "rb") as outfile_stream:
            relax_doc.abinit_output.outfile_ioncell.put(outfile_stream)

        relax_doc.save()

        saved_docs = RelaxResult.objects()
        assert len(saved_docs) == 1

        reloaded = saved_docs[0]
        with tempfile.NamedTemporaryFile(mode="wb") as scratch:
            scratch.write(reloaded.abinit_output.hist_files["test_hist"].read())
            scratch.seek(0)
            assert filecmp.cmp(hist_path, scratch.name)
Example #30
0
    def test_abinit_phonon_output(self):
        """Store PHBST, PHDOS and anaddb files on an AbinitPhononOutputMixin document and read them back.

        The database part only runs when has_mongodb() is true.
        """
        document_cls = self.get_document_class_from_mixin(AbinitPhononOutputMixin)
        document = document_cls()

        if not has_mongodb():
            return

        document.structure = self.si_structure.as_dict()
        phbst_reference = abidata.ref_file('ZnSe_hex_886.out_PHBST.nc')
        phdos_reference = abidata.ref_file('ZnSe_hex_886.out_PHDOS.nc')
        anaddb_reference = abidata.ref_file('ZnSe_hex_886.anaddb.nc')

        with open(phbst_reference, "rb") as phbst_handle:
            document.phonon_bs.put(phbst_handle)
        with open(phdos_reference, "rb") as phdos_handle:
            document.phonon_dos.put(phdos_handle)
        with open(anaddb_reference, "rb") as anaddb_handle:
            document.anaddb_nc.put(anaddb_handle)

        document.save()

        stored = document_cls.objects()
        assert len(stored) == 1
        fetched = stored[0]

        # dump every stored file and compare it with its source on disk
        with tempfile.NamedTemporaryFile(mode="wb") as dump:
            dump.write(fetched.phonon_bs.read())
            dump.seek(0)
            assert filecmp.cmp(phbst_reference, dump.name)

        with tempfile.NamedTemporaryFile(mode="r+b") as dump:
            dump.write(fetched.phonon_dos.read())
            dump.seek(0)
            assert filecmp.cmp(phdos_reference, dump.name)

        with tempfile.NamedTemporaryFile(mode="r+b") as dump:
            dump.write(fetched.anaddb_nc.read())
            dump.seek(0)
            assert filecmp.cmp(anaddb_reference, dump.name)
Example #31
0
    def test_abinit_gs_output(self):
        """Store a GSR file on an AbinitGSOutputMixin document and read it back unchanged.

        The database part only runs when has_mongodb() is true.
        """
        document_cls = self.get_document_class_from_mixin(AbinitGSOutputMixin)
        document = document_cls()

        if not has_mongodb():
            return

        document.structure = self.si_structure.as_dict()
        gsr_reference = abidata.ref_file('si_scf_GSR.nc')

        with open(gsr_reference, "rb") as gsr_handle:
            document.gsr.put(gsr_handle)

        document.save()

        stored = document_cls.objects()
        assert len(stored) == 1

        fetched = stored[0]
        with tempfile.NamedTemporaryFile(mode="wb") as dump:
            dump.write(fetched.gsr.read())
            dump.seek(0)

            assert filecmp.cmp(gsr_reference, dump.name)
Example #32
0
class TestFinalCleanUpTask(AbiflowsTest):
    """Tests for FinalCleanUpTask: extension handling, file deletion and a run in a workflow."""

    @classmethod
    def setUpClass(cls):
        cls.setup_fireworks()

    @classmethod
    def tearDownClass(cls):
        cls.teardown_fireworks(module_dir=MODULE_DIR)

    def tearDown(self):
        if self.lp:
            self.lp.reset(password=None, require_password=False)

    def test_class(self):
        """Test of basic methods"""

        task = FinalCleanUpTask()
        assert "WFK" in task.out_exts
        assert "1WF" in task.out_exts

        task = FinalCleanUpTask(out_exts=["DEN", "WFK"])
        assert "WFK" in task.out_exts
        assert "DEN" in task.out_exts
        assert "1WF" not in task.out_exts

        # a comma separated string is accepted as well
        task = FinalCleanUpTask(out_exts="DEN,WFK")
        assert "WFK" in task.out_exts
        assert "DEN" in task.out_exts
        assert "1WF" not in task.out_exts

        self.assertFwSerializable(task)

        with ScratchDir(".") as tmp_dir:
            # The DEN file must survive the clean-up, so the context manager
            # can own it and remove it on exit.
            with tempfile.NamedTemporaryFile(suffix="DEN", dir=tmp_dir) as denfile:
                # The WFK file is expected to be removed by the task itself:
                # create it with delete=False and close the handle right away,
                # so no open handle tries to unlink a file that is already gone.
                with tempfile.NamedTemporaryFile(suffix="_WFK", dir=tmp_dir,
                                                 delete=False) as wfkfile:
                    pass

                FinalCleanUpTask.delete_files(tmp_dir, exts=["WFK", "1WF"])

                assert not os.path.isfile(wfkfile.name)
                assert os.path.isfile(denfile.name)

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_run(self):
        """Run a create-outputs firework followed by the clean-up and check the remaining files."""
        create_fw = Firework(
            [mock_objects.CreateOutputsTask(extensions=["WFK", "DEN"])],
            fw_id=1)
        delete_fw = Firework([FinalCleanUpTask(["WFK", "1WF"])],
                             parents=create_fw,
                             fw_id=2,
                             spec={"_add_launchpad_and_fw_id": True})

        wf = Workflow([create_fw, delete_fw])

        self.lp.add_wf(wf)

        # first launch: only the firework creating the files
        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR, nlaunches=1)

        # check that the files have been created
        create_fw = self.lp.get_fw_by_id(1)
        create_ldir = create_fw.launches[0].launch_dir

        for d in ["tmpdata", "outdata", "indata"]:
            assert os.path.isfile(os.path.join(create_ldir, d, "tmp_WFK"))
            assert os.path.isfile(os.path.join(create_ldir, d, "tmp_DEN"))

        # second launch: the clean-up firework
        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR, nlaunches=1)

        wf = self.lp.get_wf_by_fw_id(1)

        assert wf.state == "COMPLETED"

        # tmpdata and indata lose both files
        for d in ["tmpdata", "indata"]:
            assert not os.path.isfile(os.path.join(create_ldir, d, "tmp_WFK"))
            assert not os.path.isfile(os.path.join(create_ldir, d, "tmp_DEN"))

        # in outdata only the WFK file is removed; DEN is not in the given extensions
        assert not os.path.isfile(
            os.path.join(create_ldir, "outdata", "tmp_WFK"))
        assert os.path.isfile(os.path.join(create_ldir, "outdata", "tmp_DEN"))
Example #33
0
class TestFunctions(AbiflowsTest):
    """Tests for the queue-spec helpers, the workflow time report and LpTask."""

    @classmethod
    def setUpClass(cls):
        """Set up the fireworks fixtures and build the shared FWTaskManager.

        The manager configuration is read from ``fw_manager_ok.yaml`` in
        ``test_dir`` and its ``abipy_manager`` entry is pointed at
        ``manager_ok.yml`` in the same directory.
        """
        cls.setup_fireworks()
        ftm_path = os.path.join(test_dir, "fw_manager_ok.yaml")
        with open(ftm_path, "rt") as fh:
            # safe_load instead of load: calling yaml.load() without an
            # explicit Loader is unsafe, deprecated since PyYAML 5.1 and a
            # TypeError from PyYAML 6 on.
            conf = yaml.safe_load(fh)
        conf['fw_policy']['abipy_manager'] = os.path.join(
            test_dir, "manager_ok.yml")
        cls.ftm = FWTaskManager(**conf)

    @classmethod
    def tearDownClass(cls):
        """Tear down the fireworks fixtures created in setUpClass."""
        cls.teardown_fireworks(module_dir=MODULE_DIR)

    def tearDown(self):
        """Reset the launchpad, when one is available, after every test."""
        if self.lp:
            self.lp.reset(password=None, require_password=False)

    def test_get_short_single_core_spec(self):
        """A 610 s time limit yields a one-task spec with time '0-0:10:10'."""
        spec = get_short_single_core_spec(self.ftm, timelimit=610)

        assert spec['ntasks'] == 1
        assert spec['time'] == '0-0:10:10'

    def test_set_short_single_core_to_spec(self):
        """An empty spec gets a single-task queue adapter and mpi_ncpus == 1."""
        spec = {}
        spec = set_short_single_core_to_spec(spec, fw_manager=self.ftm)

        assert spec['_queueadapter']['ntasks'] == 1
        assert spec['mpi_ncpus'] == 1

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_get_time_report_for_wf(self):
        """Run two 0.5 s sleep Fireworks and check the resulting time report."""
        task = PyTask(func="time.sleep", args=[0.5])
        fw1 = Firework([task],
                       spec={
                           'wf_task_index': "test1_1",
                           "nproc": 16
                       },
                       fw_id=1)
        fw2 = Firework([task],
                       spec={
                           'wf_task_index': "test2_1",
                           "nproc": 16
                       },
                       fw_id=2)
        wf = Workflow([fw1, fw2])
        self.lp.add_wf(wf)

        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR)

        wf = self.lp.get_wf_by_fw_id(1)

        assert wf.state == "COMPLETED"

        tr = get_time_report_for_wf(wf)

        assert tr.n_fws == 2
        assert tr.total_run_time > 1

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_get_lp_and_fw_id_from_task(self):
        """
        Tests the get_lp_and_fw_id_from_task. This test relies on the fact that the LaunchPad loaded from auto_load
        will be different from what is defined in TESTDB_NAME. If this is not the case the test will be skipped.
        """
        lp = LaunchPad.auto_load()

        if not lp or lp.db.name == TESTDB_NAME:
            # The two literals are concatenated, so the first one needs a
            # trailing space to keep "different from" as separate words.
            raise unittest.SkipTest(
                "LaunchPad lp {} is not suitable for this test. Should be available and different "
                "from {}".format(lp, TESTDB_NAME))

        task = LpTask()
        # this will pass the lp
        fw1 = Firework([task],
                       spec={'_add_launchpad_and_fw_id': True},
                       fw_id=1)
        # this will not have the lp and should fail
        fw2 = Firework([task], spec={}, fw_id=2, parents=[fw1])
        wf = Workflow([fw1, fw2])
        self.lp.add_wf(wf)

        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR, nlaunches=1)

        fw = self.lp.get_fw_by_id(1)

        assert fw.state == "COMPLETED"

        rapidfire(self.lp, self.fworker, m_dir=MODULE_DIR, nlaunches=1)

        fw = self.lp.get_fw_by_id(2)

        assert fw.state == "FIZZLED"
Example #34
0
class TestDatabaseData(AbiflowsTest):
    """Tests for DatabaseData serialization and mongoengine connections."""

    def setUp(self):
        """Set up the mongoengine fixtures before each test."""
        self.setup_mongoengine()

    @classmethod
    def tearDownClass(cls):
        """Tear down the mongoengine fixtures once all tests have run."""
        cls.teardown_mongoengine()

    def test_class(self):
        """DatabaseData is MSONable and can be dumped without credentials."""
        db_data = DatabaseData(TESTDB_NAME,
                               host="localhost",
                               port=27017,
                               collection="test_collection",
                               username="******",
                               password="******")

        self.assertMSONable(db_data)
        d = db_data.as_dict_no_credentials()
        assert "username" not in d
        assert "password" not in d

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_connection(self):
        """Save a document through DatabaseData and read it back with pymongo."""
        db_data = DatabaseData(TESTDB_NAME, collection="test_collection")

        db_data.connect_mongoengine()

        class TestDocument(Document):
            test = StringField()

        with db_data.switch_collection(TestDocument) as TestDocument:
            TestDocument(test="abc").save()

        # check the collection with pymongo
        client = MongoClient()
        try:
            db = client[TESTDB_NAME]
            collection = db[db_data.collection]
            # Materialize the cursor and use len(): Cursor.count() was
            # deprecated in PyMongo 3.7 and removed in PyMongo 4.
            documents = list(collection.find())
        finally:
            # Do not leak the client's sockets between tests.
            client.close()
        assert len(documents) == 1
        assert documents[0]['test'] == "abc"

    @unittest.skipUnless(has_mongodb(), "A local mongodb is required.")
    def test_connection_alias(self):
        """Save a document through an aliased connection to a second database.

        The second database is dropped at the end, whatever the outcome, when
        a connection is available on the test instance.
        """

        test_db_name_2 = TESTDB_NAME + "_2"
        try:
            db_data2 = DatabaseData(test_db_name_2,
                                    collection="test_collection")
            db_data2.connect_mongoengine(alias="test_alias")

            db_data1 = DatabaseData(TESTDB_NAME, collection="test_collection")
            db_data1.connect_mongoengine()

            class TestDocument(Document):
                test = StringField()

            with switch_db(TestDocument, "test_alias") as TestDocument:
                with db_data2.switch_collection(TestDocument) as TestDocument:
                    TestDocument(test="abc").save()

            # check the collection with pymongo
            client = MongoClient()
            try:
                db = client[test_db_name_2]
                collection = db[db_data2.collection]
                # Materialize the cursor and use len(): Cursor.count() was
                # deprecated in PyMongo 3.7 and removed in PyMongo 4.
                documents = list(collection.find())
            finally:
                # Do not leak the client's sockets between tests.
                client.close()
            assert len(documents) == 1
            assert documents[0]['test'] == "abc"
        finally:
            if self._connection:
                self._connection.drop_database(test_db_name_2)