Пример #1
0
    def setUpClass(cls):
        """Route DEBUG logging to stdout and load the test tasks into ``emmet_test``.

        Connects ``cls.tasks`` (a ``MongoStore`` on ``emmet_test.tasks``),
        passes its database to ``cleardb``, then inserts every
        ``*.json.gz`` task document under ``<test_dir>/tasks`` via
        ``VaspCalcDb.insert_task(..., parse_dos=True, parse_bs=True)``.
        """
        root = logging.getLogger()
        root.setLevel(logging.DEBUG)

        # NOTE(review): a new stdout handler is attached to the root logger on
        # every call and never removed here; if several test classes do this
        # in one process, log lines will be emitted more than once.
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        ch.setFormatter(formatter)
        root.addHandler(ch)

        cls.tasks = MongoStore("emmet_test", "tasks", lu_field="last_updated")
        cls.tasks.connect()
        cleardb(cls.tasks.collection.database)

        vaspdb = VaspCalcDb(database="emmet_test")
        tasks_dir = os.path.join(test_dir, "tasks")

        # Each file is a compressed JSON task document; zopen yields bytes here,
        # hence the explicit decode before json.loads.
        for task_path in glob.glob(os.path.join(tasks_dir, "*.json.gz")):
            with zopen(task_path) as f:
                task = json.loads(f.read().decode())
            vaspdb.insert_task(task, parse_dos=True, parse_bs=True)
Пример #2
0
def calcdb_from_mgrant(spec_or_dbfile):
    """Build a ``VaspCalcDb`` from a db file path or a ``host/db`` grant spec.

    Args:
        spec_or_dbfile: Either a path to an existing db file (loaded with
            ``VaspCalcDb.from_db_file``), or a ``"host/dbname_or_alias"``
            spec whose read-write credentials are looked up with
            ``Client().get_auth``.

    Returns:
        A ``VaspCalcDb`` on the ``tasks`` collection.

    Raises:
        ValueError: If the argument is not an existing path and is not of the
            form ``host/dbname_or_alias`` with both parts non-empty.
        Exception: If no credentials are available for the spec.
    """
    if os.path.exists(spec_or_dbfile):
        return VaspCalcDb.from_db_file(spec_or_dbfile)

    # Validate before touching the credential store so a malformed spec fails
    # with a clear message instead of a tuple-unpacking error.
    host, sep, dbname_or_alias = spec_or_dbfile.partition("/")
    if not (sep and host and dbname_or_alias):
        raise ValueError(
            "Expected an existing db file or a 'host/dbname_or_alias' spec, "
            "got {!r}".format(spec_or_dbfile)
        )

    client = Client()
    role = "rw"  # NOTE need write access to source to ensure indexes
    auth = client.get_auth(host, dbname_or_alias, role)
    if auth is None:
        raise Exception("No valid auth credentials available!")
    return VaspCalcDb(
        auth["host"],
        27017,  # hard-coded default MongoDB port; auth does not supply one here
        auth["db"],
        "tasks",
        auth["username"],
        auth["password"],
        authSource=auth["db"],
    )
Пример #3
0
    def plot_bs(db_credentials,
                material_id,
                task_query=None,
                filename=None,
                **plot_kwargs):
        """Plot a material's band structure, combined with its DOS when available.

        Looks up ``material_id`` in the ``materials`` collection, collects the
        numeric task ids from ``_tasksbuilder.all_task_ids``, and queries the
        ``tasks`` collection for ``"nscf line"`` (band structure) and
        ``"nscf uniform"`` (DOS) tasks. The last band structure task returned
        is plotted; if any DOS task exists, the DOS of the last one is added.

        Args:
            db_credentials: Mapping with ``host``, ``port``, ``database``,
                ``username`` and ``password`` keys.
            material_id: Value matched against ``material_id`` in the
                ``materials`` collection.
            task_query: Optional extra filter merged into both task queries;
                its keys override the default ``task_id``/``task_label``
                filters.
            filename: If given, the figure is saved to this path and the plot
                object is returned.
            **plot_kwargs: Forwarded to ``SBSPlotter.get_plot``. For combined
                BS+DOS plots, ``dos_aspect`` defaults to 4 and ``width`` to 8
                unless supplied.

        Returns:
            The plot object when ``filename`` is given, otherwise the PNG
            image (400 dpi) as base64-encoded bytes.

        Raises:
            RuntimeError: If the material is not found or has no band
                structure task.
        """
        task_query = task_query or {}

        def _find_tasks(db, task_ids, label):
            # The caller's filter is applied last so it can override defaults.
            query = {
                "task_id": {"$in": task_ids},
                "task_label": {"$in": [label]},
            }
            query.update(task_query)
            return list(db.tasks.find(query))

        # NOTE(review): this client is not authenticated (the authenticate
        # call is disabled upstream); only calc_db receives the credentials.
        client = MongoClient(db_credentials["host"], db_credentials["port"])
        calc_db = VaspCalcDb(db_credentials["host"], db_credentials["port"],
                             db_credentials["database"], "tasks",
                             db_credentials["username"],
                             db_credentials["password"])

        try:
            db = client[db_credentials["database"]]

            material_result = db.materials.find_one(
                {"material_id": material_id})
            if not material_result:
                raise RuntimeError(
                    "Material id not found in database: {}".format(
                        material_id))

            # Task ids are stored as e.g. "mp-123"; keep the numeric suffix.
            tasks_ids = [
                int(tid.split("-")[-1])
                for tid in material_result["_tasksbuilder"]["all_task_ids"]
            ]

            bs_tasks = _find_tasks(db, tasks_ids, "nscf line")
            if not bs_tasks:
                raise RuntimeError(
                    "No band structure available for: {}".format(material_id))

            dos_tasks = _find_tasks(db, tasks_ids, "nscf uniform")
        finally:
            # The ad-hoc client is only needed for the metadata queries above;
            # the task lists are already materialized.
            client.close()

        # get the band structure of the last band structure task
        band_structure = calc_db.get_band_structure(
            task_id=bs_tasks[-1]["task_id"])
        bs_plotter = SBSPlotter(band_structure)

        if dos_tasks:
            # get the DOS for the last DOS task
            dos = calc_db.get_dos(task_id=dos_tasks[-1]["task_id"])
            pdos = get_pdos(dos)
            dos_plotter = SDOSPlotter(dos, pdos)

            # better defaults for BS+DOS plots, without overriding the user
            plot_kwargs.setdefault('dos_aspect', 4)
            plot_kwargs.setdefault('width', 8)

            plt = bs_plotter.get_plot(dos_plotter=dos_plotter, **plot_kwargs)
        else:
            # if no DOS just plot band structure only
            plt = bs_plotter.get_plot(**plot_kwargs)

        if filename:
            plt.savefig(filename, dpi=400, bbox_inches='tight')
            return plt

        figfile = BytesIO()
        plt.savefig(figfile, format='png', dpi=400, bbox_inches='tight')

        # rewind to beginning of file and base64 encode
        figfile.seek(0)
        return base64.b64encode(figfile.getvalue())