Code example #1
def prepare_computer(computername, workdir):
    """Create new computer in db or read computer from db if it already exists."""
    from aiida.orm import Computer
    from aiida.orm.querybuilder import QueryBuilder

    # first check whether the computer already exists in the database
    qb = QueryBuilder()
    qb.append(Computer, tag='computer')
    computer_found_in_db = False
    for entry in qb.dict():
        candidate = entry['computer']['*']
        if candidate.get_name() == computername:
            computer_found_in_db = True
            comp = candidate
    # if it is not there create a new one
    if not computer_found_in_db:
        comp = Computer(computername,
                        'test computer',
                        transport_type='local',
                        scheduler_type='direct',
                        workdir=workdir)
        comp.set_default_mpiprocs_per_machine(4)
        comp.store()
        print('computer stored, now configure')
        comp.configure()
    else:
        print('found computer in database')
    # return computer
    return comp
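A minimal usage sketch (hypothetical computer name and work directory; assumes AiiDA 0.x, where the database environment is loaded explicitly):

from aiida import is_dbenv_loaded, load_dbenv

if not is_dbenv_loaded():
    load_dbenv()

# created on the first call, fetched from the database on later calls
comp = prepare_computer('localhost', '/tmp/aiida_run')
print(comp.get_name())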
Code example #2
File: app.py Project: kjappelbaum/aiida_dashboard
def get_job_calc_data(computer_dict):
    all_job_calc_qb = QueryBuilder()
    all_job_calc_qb.append(
        JobCalculation,
        project=['type', 'id', 'ctime', 'state', 'dbcomputer_id'],
    )
    tmp = all_job_calc_qb.dict()
    # no explicit tag was passed to append(), so QueryBuilder auto-generates
    # the tag 'JobCalculation_1', which keys each result row
    result = [calc['JobCalculation_1'] for calc in tmp]
    return reformat_calc_list(result, computer_dict)
Code example #3
File: data_psf.py Project: bosonie/aiida-ape
def listfamilies(element, with_description):
    """
    Print on screen the list of installed PSF-pseudo families.
    """
    import click

    from aiida import is_dbenv_loaded, load_dbenv
    if not is_dbenv_loaded():
        load_dbenv()

    from aiida.orm import DataFactory
    from aiida_siesta.data.psf import PSFGROUP_TYPE

    PsfData = DataFactory('siesta.psf')
    from aiida.orm.querybuilder import QueryBuilder
    from aiida.orm.group import Group
    qb = QueryBuilder()
    qb.append(PsfData, tag='psfdata')

    if element:
        qb.add_filter(PsfData, {'attributes.element': {'in': element}})

    qb.append(
        Group,
        group_of='psfdata',
        tag='group',
        project=["name", "description"],
        filters={
            "type": {
                '==': PSFGROUP_TYPE
            }
        })

    qb.distinct()
    if qb.count() > 0:
        for res in qb.dict():
            group_name = res.get("group").get("name")
            group_desc = res.get("group").get("description")
            qb = QueryBuilder()
            qb.append(
                Group, tag='thisgroup', filters={
                    "name": {
                        'like': group_name
                    }
                })
            qb.append(PsfData, project=["id"], member_of='thisgroup')

            if with_description:
                description_string = ": {}".format(group_desc)
            else:
                description_string = ""

            click.echo("* {} [{} pseudos]{}".format(group_name,
                                                    qb.count(),
                                                    description_string))

    else:
        click.echo("No valid PSF pseudopotential family found.", err=True)
Code example #4
def slab_structure_list():
    """
    load the slab structure list and label in a list of dictionaries
    for structure matching
    """
    group_node = 199  # original slab structure generation without push-in oxygens
    qb = QueryBuilder()
    qb.append(Group, tag='group', filters={'id': {'in': [group_node]}})
    qb.append(StructureData, member_of='group', tag='slab_structure',
              filters={'label': {'ilike': '%slab%'}}, project=['id', 'label'])
    res = [x['slab_structure'] for x in qb.dict()]  # slab structure ids and labels
    return res
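Usage sketch (the group pk 199 is hard-coded above, so this only works in the database it was written for):

for slab in slab_structure_list():
    print("{} {}".format(slab['id'], slab['label']))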
Code example #5
def load_failed_environ_calc(group_node):
    """
    load output structure of previous relaxed structures
    """
    qb = QueryBuilder()
    qb.append(Group, tag='group', filters={'id': {'in': [group_node]}})
    qb.append(JobCalculation, member_of='group', project=['id'], tag='calculation',
              filters={
                  "or": [
                      # {'state': {'==': 'SUBMISSIONFAILED'}},
                      {'state': {'==': 'FAILED'}}]})

    id_list = qb.dict()
    res = [x['calculation']['id'] for x in id_list]
    return res
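Usage sketch (hypothetical group pk; returns the pks of calculations in the group whose state is FAILED):

failed_pks = load_failed_environ_calc(199)
print("{} failed calculations found".format(len(failed_pks)))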
Code example #6
File: app.py Project: kjappelbaum/aiida_dashboard
def get_computer_names():
    qb_computer_name = QueryBuilder()  # Instantiating instance. One instance -> one query
    qb_computer_name.append(
        Computer,
        project=['id', 'name'],
    )
    tmp = qb_computer_name.dict()
    computer_dict = {}
    for computer in tmp:
        computer_dict[computer['computer_1']['id']] = computer['computer_1']['name']

    return computer_dict
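Usage sketch printing the resulting pk-to-name mapping (assumes a loaded profile and the module-level Computer/QueryBuilder imports of the original app.py):

for pk, name in get_computer_names().items():
    print("{}: {}".format(pk, name))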
Code example #7
File: cmd_upf.py Project: chrisjsewell/aiida_core
def listfamilies(elements, with_description):
    """
    Print on screen the list of installed UPF families
    """
    from aiida.orm import DataFactory
    from aiida.orm.data.upf import UPFGROUP_TYPE

    # pylint: disable=invalid-name
    UpfData = DataFactory('upf')
    from aiida.orm.querybuilder import QueryBuilder
    from aiida.orm.group import Group
    qb = QueryBuilder()
    qb.append(UpfData, tag='upfdata')
    if elements is not None:
        qb.add_filter(UpfData, {'attributes.element': {'in': elements}})
    qb.append(Group,
              group_of='upfdata',
              tag='group',
              project=["name", "description"],
              filters={"type": {
                  '==': UPFGROUP_TYPE
              }})

    qb.distinct()
    if qb.count() > 0:
        for res in qb.dict():
            group_name = res.get("group").get("name")
            group_desc = res.get("group").get("description")
            qb = QueryBuilder()
            qb.append(Group,
                      tag='thisgroup',
                      filters={"name": {
                          'like': group_name
                      }})
            qb.append(UpfData, project=["id"], member_of='thisgroup')

            if with_description:
                description_string = ": {}".format(group_desc)
            else:
                description_string = ""

            echo.echo_success("* {} [{} pseudos]{}".format(
                group_name, qb.count(), description_string))

    else:
        echo.echo_warning("No valid UPF pseudopotential family found.")
Code example #8
def water_BE():

    qb = QueryBuilder()

    qb.append(Group,
              tag="group",
              project="id",
              filters={"id": {"in": [269, 271, 272, 273, 274, 275, 277, 279, 280]}})
    qb.append(JobCalculation,
              member_of="group",
              tag="calculation",
              filters={"state": {
                  "==": "FINISHED"
              }},
              project=["id"])
    calc_list = qb.dict()

    print "Total slab structures in water %s . . ." % len(calc_list)

    with open("water_slab_BE.txt", "w") as f:

        for bulk_calc in calc_list:
            shift_energy = shift_fermi(bulk_calc['calculation']['id'])
            VBM, CBM = BE(bulk_calc['calculation']['id'], shift_energy)
            A_site, B_site, term_site = site_term_atm(
                bulk_calc['calculation']['id'])
            f.write(
                str(A_site) + "    " + str(B_site) + "    " + str(term_site) +
                "    " + str(VBM) + "    " + str(CBM) + "    " +
                str(bulk_calc['calculation']['id']) + "\n")
Code example #9
File: exportfile.py Project: nvarini/aiida_core
    def run(self, *args):
        load_dbenv()

        import argparse

        from aiida.orm.querybuilder import QueryBuilder
        from aiida.orm import Group, Node, Computer
        from aiida.orm.importexport import export, export_zip

        parser = argparse.ArgumentParser(
            prog=self.get_full_command_name(),
            description='Export data from the DB.')
        parser.add_argument('-c',
                            '--computers',
                            nargs='+',
                            type=int,
                            metavar="PK",
                            help="Export the given computers")
        parser.add_argument('-n',
                            '--nodes',
                            nargs='+',
                            type=int,
                            metavar="PK",
                            help="Export the given nodes")
        parser.add_argument('-g',
                            '--groups',
                            nargs='+',
                            metavar="GROUPNAME",
                            help="Export all nodes in the given group(s), "
                            "identified by name.",
                            type=str)
        parser.add_argument('-G',
                            '--group_pks',
                            nargs='+',
                            metavar="PK",
                            help="Export all nodes in the given group(s), "
                            "identified by pk.",
                            type=str)
        parser.add_argument('-P',
                            '--no-parents',
                            dest='no_parents',
                            action='store_true',
                            help="Store only the nodes that are explicitly "
                            "given, without exporting the parents")
        parser.set_defaults(no_parents=False)
        parser.add_argument('-O',
                            '--no-calc-outputs',
                            dest='no_calc_outputs',
                            action='store_true',
                            help="If a calculation is included in the list of "
                            "nodes to export, do not export its outputs")
        parser.set_defaults(no_calc_outputs=False)
        parser.add_argument('-y',
                            '--overwrite',
                            dest='overwrite',
                            action='store_true',
                            help="Overwrite the output file, if it exists")
        parser.set_defaults(overwrite=False)

        zipsubgroup = parser.add_mutually_exclusive_group()
        zipsubgroup.add_argument(
            '-z',
            '--zipfile-compressed',
            dest='zipfilec',
            action='store_true',
            help="Store as zip file (experimental, should be "
            "faster")
        zipsubgroup.add_argument('-Z',
                                 '--zipfile-uncompressed',
                                 dest='zipfileu',
                                 action='store_true',
                                 help="Store as uncompressed zip file "
                                 "(experimental, should be faster")
        parser.set_defaults(zipfilec=False)
        parser.set_defaults(zipfileu=False)

        parser.add_argument('output_file',
                            type=str,
                            help='The output file name for the export file')

        parsed_args = parser.parse_args(args)

        if parsed_args.nodes is None:
            node_id_set = set()
        else:
            node_id_set = set(parsed_args.nodes)

        group_dict = dict()

        if parsed_args.groups is not None:
            qb = QueryBuilder()
            qb.append(Group,
                      tag='group',
                      project=['*'],
                      filters={'name': {
                          'in': parsed_args.groups
                      }})
            qb.append(Node, tag='node', member_of='group', project=['id'])
            res = qb.dict()

            group_dict.update(
                {_['group']['*'].name: _['group']['*'].dbgroup
                 for _ in res})
            node_id_set.update([_['node']['id'] for _ in res])

        if parsed_args.group_pks is not None:
            qb = QueryBuilder()
            qb.append(Group,
                      tag='group',
                      project=['*'],
                      filters={'id': {
                          'in': parsed_args.group_pks
                      }})
            qb.append(Node, tag='node', member_of='group', project=['id'])
            res = qb.dict()

            group_dict.update(
                {_['group']['*'].name: _['group']['*'].dbgroup
                 for _ in res})
            node_id_set.update([_['node']['id'] for _ in res])

        # The db_groups that correspond to what was searched above
        dbgroups_list = group_dict.values()

        # Getting the nodes that correspond to the ids that were found above
        if len(node_id_set) > 0:
            qb = QueryBuilder()
            qb.append(Node,
                      tag='node',
                      project=['*'],
                      filters={'id': {
                          'in': node_id_set
                      }})
            node_list = [_[0] for _ in qb.all()]
        else:
            node_list = list()

        # Check if any of the nodes wasn't found in the database.
        missing_nodes = node_id_set.difference(_.id for _ in node_list)
        for id in missing_nodes:
            print >> sys.stderr, ("WARNING! Node with pk= {} "
                                  "not found, skipping.".format(id))

        # The dbnodes of the above node list
        dbnode_list = [_.dbnode for _ in node_list]

        if parsed_args.computers is not None:
            qb = QueryBuilder()
            qb.append(Computer,
                      tag='comp',
                      project=['*'],
                      filters={'id': {
                          'in': set(parsed_args.computers)
                      }})
            computer_list = [_[0] for _ in qb.all()]
            missing_computers = set(parsed_args.computers).difference(
                _.id for _ in computer_list)
            for id in missing_computers:
                print >> sys.stderr, ("WARNING! Computer with pk= {} "
                                      "not found, skipping.".format(id))
        else:
            computer_list = []

        # The dbcomputers of the above computer list
        dbcomputer_list = [_.dbcomputer for _ in computer_list]

        what_list = dbnode_list + dbcomputer_list + dbgroups_list

        export_function = export
        additional_kwargs = {}
        if parsed_args.zipfileu:
            export_function = export_zip
            additional_kwargs.update({"use_compression": False})
        elif parsed_args.zipfilec:
            export_function = export_zip
            additional_kwargs.update({"use_compression": True})
        try:
            export_function(what=what_list,
                            also_parents=not parsed_args.no_parents,
                            also_calc_outputs=not parsed_args.no_calc_outputs,
                            outfile=parsed_args.output_file,
                            overwrite=parsed_args.overwrite,
                            **additional_kwargs)
        except IOError as e:
            print >> sys.stderr, "IOError: {}".format(e.message)
            sys.exit(1)
Code example #10
    def test_simple_query_django_2(self):
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.orm import Node
        from datetime import datetime
        from aiida.backends.querybuild.dummy_model import (
            DbNode, DbLink, DbAttribute, session
        )

        n0 = DbNode(
            label='hello',
            type='',
            description='', user_id=1,
        )
        n1 = DbNode(
            label='foo',
            type='',
            description='I am FoO', user_id=2,
        )
        n2 = DbNode(
            label='bar',
            type='',
            description='I am BaR', user_id=3,
        )

        DbAttribute(
            key='foo',
            datatype='txt',
            tval='bar',
            dbnode=n0
        )

        l1 = DbLink(input=n0, output=n1, label='random_1', type='')
        l2 = DbLink(input=n1, output=n2, label='random_2', type='')

        session.add_all([n0, n1, n2, l1, l2])


        qb1 = QueryBuilder()
        qb1.append(
            DbNode,
            filters={
                'label': 'hello',
            }
        )
        self.assertEqual(len(list(qb1.all())), 1)

        qh = {
            'path': [
                {
                    'cls': Node,
                    'tag': 'n1'
                },
                {
                    'cls': Node,
                    'tag': 'n2',
                    'output_of': 'n1'
                }
            ],
            'filters': {
                'n1': {
                    'label': {'ilike': '%foO%'},
                },
                'n2': {
                    'label': {'ilike': 'bar%'},
                }
            },
            'project': {
                'n1': ['id', 'uuid', 'ctime', 'label'],
                'n2': ['id', 'description', 'label'],
            }
        }

        qb2 = QueryBuilder(**qh)

        resdict = qb2.dict()
        self.assertEqual(len(resdict), 1)
        resdict = resdict[0]
        self.assertTrue(isinstance(resdict['n1']['ctime'], datetime))
        self.assertEqual(resdict['n2']['label'], 'bar')


        qh = {
            'path': [
                {
                    'cls': Node,
                    'label': 'n1'
                },
                {
                    'cls': Node,
                    'label': 'n2',
                    'output_of': 'n1'
                }
            ],
            'filters': {
                'n1--n2': {'label': {'like': '%_2'}}
            }
        }
        qb = QueryBuilder(**qh)
        self.assertEqual(qb.count(), 1)

        # Test the query caching: get_query() builds a new query object only
        # after the QueryBuilder has changed (here, by adding a filter)
        query1 = qb.get_query()
        qb.add_filter('n2', {'label': 'nonexistentlabel'})
        self.assertEqual(qb.count(), 0)
        query2 = qb.get_query()
        query3 = qb.get_query()

        self.assertTrue(id(query1) != id(query2))
        self.assertTrue(id(query2) == id(query3))
Code example #11
    def calculation_cleanworkdir(self, *args):
        """
        Clean the working directory of calculations by removing all the content of the
        associated RemoteFolder node. Calculations can be identified by pk with the -k flag
        or by specifying limits on the modification times with -p/-o flags
        """
        import argparse

        parser = argparse.ArgumentParser(
            prog=self.get_full_command_name(),
            description="""
                Clean all content of all output remote folders of calculations,
                passed as a list of pks, or identified by modification time.

                If a list of calculation PKs is not passed with the -k option, one or both
                of the -p and -o options has to be specified. If both are specified, a logical
                AND is done between the two, i.e. the calculations that will be cleaned have been
                modified AFTER [-p option] days from now but BEFORE [-o option] days from now.
                Passing the -f option will prevent the confirmation dialog from being prompted.
                """
        )
        parser.add_argument(
            '-k', '--pk', metavar='PK', type=int, nargs='+', dest='pk',
            help='The principal key (PK) of the calculations of which to clean the work directory'
        )
        parser.add_argument(
            '-f', '--force', action='store_true',
            help='Force the cleaning (no prompt)'
        )
        parser.add_argument(
            '-p', '--past-days', metavar='N', type=int, action='store', dest='past_days',
            help='Include calculations that have been modified within the last N days', 
        )
        parser.add_argument(
            '-o', '--older-than', metavar='N', type=int, action='store', dest='older_than',
            help='Include calculations that have been modified more than N days ago',
        )
        parser.add_argument(
            '-c', '--computers', metavar='label', nargs='+', type=str, action='store', dest='computer',
            help='Include only calculations that were run on these computers'
        )

        if not is_dbenv_loaded():
            load_dbenv()

        from aiida.backends.utils import get_automatic_user
        from aiida.backends.utils import get_authinfo
        from aiida.common.utils import query_yes_no
        from aiida.orm.computer import Computer as OrmComputer
        from aiida.orm.user import User as OrmUser
        from aiida.orm.calculation import Calculation as OrmCalculation
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.utils import timezone
        import datetime

        parsed_args = parser.parse_args(args)

        # If a pk is given then the -o & -p options should not be specified
        if parsed_args.pk is not None:
            if (parsed_args.past_days is not None or parsed_args.older_than is not None):
                print("You cannot specify both a list of calculation pks and the -p or -o options")
                return

        # If no pk is given then at least one of the -o & -p options should be specified
        else:
            if (parsed_args.past_days is None and parsed_args.older_than is None):
                print("You should specify at least a list of calculations or the -p, -o options")
                return

        qb_user_filters = dict()
        user = OrmUser(dbuser=get_automatic_user())
        qb_user_filters["email"] = user.email

        qb_computer_filters = dict()
        if parsed_args.computer is not None:
            qb_computer_filters["name"] = {"in": parsed_args.computer}

        qb_calc_filters = dict()
        if parsed_args.past_days is not None:
            pd_ts = timezone.now() - datetime.timedelta(days=parsed_args.past_days)
            qb_calc_filters["mtime"] = {">": pd_ts}
        if parsed_args.older_than is not None:
            ot_ts = timezone.now() - datetime.timedelta(days=parsed_args.older_than)
            qb_calc_filters["mtime"] = {"<": ot_ts}
        if parsed_args.pk is not None:
            print("parsed_args.pk: ", parsed_args.pk)
            qb_calc_filters["id"] = {"in": parsed_args.pk}

        qb = QueryBuilder()
        qb.append(OrmCalculation, tag="calc",
                  filters=qb_calc_filters,
                  project=["id", "uuid", "attributes.remote_workdir"])
        qb.append(OrmComputer, computer_of="calc", tag="computer",
                  project=["*"],
                  filters=qb_computer_filters)
        qb.append(OrmUser, creator_of="calc", tag="user",
                  project=["*"],
                  filters=qb_user_filters)

        no_of_calcs = qb.count()
        if no_of_calcs == 0:
            print("No calculations found with the given criteria.")
            return

        print("Found {} calculations with the given criteria.".format(
            no_of_calcs))

        if not parsed_args.force:
            if not query_yes_no("Are you sure you want to clean the work "
                                "directory?", "no"):
                return

        # get the uuids of all calculations matching the filters
        calc_list_data = qb.dict()

        # get all computers associated to the calc uuids above, and load them
        # we group them by uuid to avoid computer duplicates
        comp_uuid_to_computers = {_["computer"]["*"].uuid: _["computer"]["*"] for _ in calc_list_data}

        # now build a dictionary with the info of folders to delete
        remotes = {}
        for computer in comp_uuid_to_computers.values():
            # initialize a key of info for a given computer
            remotes[computer.name] = {
                'transport': get_authinfo(computer=computer,
                                          aiidauser=user._dbuser).get_transport(),
                'computer': computer,
            }

            # select the calc pks done on this computer
            this_calc_pks = [_["calc"]["id"] for _ in calc_list_data
                             if _["computer"]["*"].id == computer.id]

            this_calc_uuids = [unicode(_["calc"]["uuid"])
                               for _ in calc_list_data
                               if _["computer"]["*"].id == computer.id]

            remote_workdirs = [_["calc"]["attributes.remote_workdir"]
                               for _ in calc_list_data
                               if _["calc"]["id"] in this_calc_pks
                               if _["calc"]["attributes.remote_workdir"]
                               is not None]

            remotes[computer.name]['remotes'] = remote_workdirs
            remotes[computer.name]['uuids'] = this_calc_uuids

        # now proceed to cleaning
        for computer, dic in remotes.iteritems():
            print("Cleaning the work directory on computer {}.".format(computer))
            counter = 0
            t = dic['transport']
            with t:
                remote_user = t.whoami()
                aiida_workdir = dic['computer'].get_workdir().format(
                    username=remote_user)

                t.chdir(aiida_workdir)
                # Hardcoding the sharding equal to 3 parts!
                existing_folders = t.glob('*/*/*')

                folders_to_delete = [i for i in existing_folders if
                                     i.replace("/", "") in dic['uuids']]

                for folder in folders_to_delete:
                    t.rmtree(folder)
                    counter += 1
                    if counter % 20 == 0 and counter > 0:
                        print("Deleted work directories: {}".format(counter))

            print("{} remote folder(s) cleaned.".format(counter))
Code example #12
File: calculation.py Project: kriskornel/aiida_core
    def calculation_cleanworkdir(self, *args):
        """
        Clean all the content of all the output remote folders of calculations,
        passed as a list of pks, or identified by modification time.

        If a list of calculation PKs is not passed through the -k option, one of
        the options -p or -o has to be specified (if both are given, a logical
        AND is done between the two - you clean out calculations modified AFTER
        [-p option] days from now but BEFORE [-o option] days from now).
        If you also pass the -f option, no confirmation will be asked.
        """
        import argparse

        parser = argparse.ArgumentParser(
            prog=self.get_full_command_name(),
            description="Clean work directory (i.e. remote folder) of AiiDA "
            "calculations.")
        parser.add_argument("-k",
                            "--pk",
                            metavar="PK",
                            type=int,
                            nargs="+",
                            help="The principal key (PK) of the calculations "
                            "to clean the workdir of",
                            dest="pk")
        parser.add_argument("-f",
                            "--force",
                            action="store_true",
                            help="Force the cleaning (no prompt)")
        parser.add_argument("-p",
                            "--past-days",
                            metavar="N",
                            help="Add a filter to clean workdir of "
                            "calculations modified during the past N "
                            "days",
                            type=int,
                            action="store",
                            dest="past_days")
        parser.add_argument("-o",
                            "--older-than",
                            metavar="N",
                            help="Add a filter to clean workdir of "
                            "calculations that have been modified on a "
                            "date before N days ago",
                            type=int,
                            action="store",
                            dest="older_than")
        parser.add_argument("-c",
                            "--computers",
                            metavar="label",
                            nargs="+",
                            help="Add a filter to clean workdir of "
                            "calculations on this computer(s) only",
                            type=str,
                            action="store",
                            dest="computer")

        if not is_dbenv_loaded():
            load_dbenv()

        from aiida.backends.utils import get_automatic_user
        from aiida.backends.utils import get_authinfo
        from aiida.common.utils import query_yes_no
        from aiida.orm.computer import Computer as OrmComputer
        from aiida.orm.user import User as OrmUser
        from aiida.orm.calculation import Calculation as OrmCalculation
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.utils import timezone
        import datetime

        parsed_args = parser.parse_args(args)

        # If a pk is given then the -o & -p options should not be specified
        if parsed_args.pk is not None:
            if ((parsed_args.past_days is not None)
                    or (parsed_args.older_than is not None)):
                print(
                    "You cannot specify both a list of calculation pks and "
                    "the -p or -o options")
                return
        # If no pk is given then at least one of the -o & -p options should be
        # specified
        else:
            if ((parsed_args.past_days is None)
                    and (parsed_args.older_than is None)):
                print(
                    "You should specify at least a list of calculations or "
                    "the -p, -o options")
                return

        # At this point we know that either the pk or the -p -o options are
        # specified

        # We also check that not both -o & -p options are specified
        if ((parsed_args.past_days is not None)
                and (parsed_args.older_than is not None)):
            print(
                "Not both of the -p, -o options can be specified in the "
                "same time")
            return

        qb_user_filters = dict()
        user = OrmUser(dbuser=get_automatic_user())
        qb_user_filters["email"] = user.email

        qb_computer_filters = dict()
        if parsed_args.computer is not None:
            qb_computer_filters["name"] = {"in": parsed_args.computer}

        qb_calc_filters = dict()
        if parsed_args.past_days is not None:
            pd_ts = timezone.now() - datetime.timedelta(
                days=parsed_args.past_days)
            qb_calc_filters["mtime"] = {">": pd_ts}
        if parsed_args.older_than is not None:
            ot_ts = timezone.now() - datetime.timedelta(
                days=parsed_args.older_than)
            qb_calc_filters["mtime"] = {"<": ot_ts}
        if parsed_args.pk is not None:
            print("parsed_args.pk: ", parsed_args.pk)
            qb_calc_filters["id"] = {"in": parsed_args.pk}

        qb = QueryBuilder()
        qb.append(OrmCalculation,
                  tag="calc",
                  filters=qb_calc_filters,
                  project=["id", "uuid", "attributes.remote_workdir"])
        qb.append(OrmComputer,
                  computer_of="calc",
                  project=["*"],
                  filters=qb_computer_filters)
        qb.append(OrmUser,
                  creator_of="calc",
                  project=["*"],
                  filters=qb_user_filters)

        no_of_calcs = qb.count()
        if no_of_calcs == 0:
            print("No calculations found with the given criteria.")
            return

        print("Found {} calculations with the given criteria.".format(
            no_of_calcs))

        if not parsed_args.force:
            if not query_yes_no(
                    "Are you sure you want to clean the work "
                    "directory?", "no"):
                return

        # get the uuids of all calculations matching the filters
        calc_list_data = qb.dict()

        # get all computers associated to the calc uuids above, and load them
        # we group them by uuid to avoid computer duplicates
        comp_uuid_to_computers = {
            _["computer"]["*"].uuid: _["computer"]["*"]
            for _ in calc_list_data
        }

        # now build a dictionary with the info of folders to delete
        remotes = {}
        for computer in comp_uuid_to_computers.values():
            # initialize a key of info for a given computer
            remotes[computer.name] = {
                'transport': get_authinfo(computer=computer,
                                          aiidauser=user._dbuser).get_transport(),
                'computer': computer,
            }

            # select the calc pks done on this computer
            this_calc_pks = [
                _["calc"]["id"] for _ in calc_list_data
                if _["computer"]["*"].id == computer.id
            ]

            this_calc_uuids = [
                unicode(_["calc"]["uuid"]) for _ in calc_list_data
                if _["computer"]["*"].id == computer.id
            ]

            remote_workdirs = [
                _["calc"]["attributes.remote_workdir"] for _ in calc_list_data
                if _["calc"]["id"] in this_calc_pks
                if _["calc"]["attributes.remote_workdir"] is not None
            ]

            remotes[computer.name]['remotes'] = remote_workdirs
            remotes[computer.name]['uuids'] = this_calc_uuids

        # now proceed to cleaning
        for computer, dic in remotes.iteritems():
            print(
                "Cleaning the work directory on computer {}.".format(computer))
            counter = 0
            t = dic['transport']
            with t:
                remote_user = t.whoami()
                aiida_workdir = dic['computer'].get_workdir().format(
                    username=remote_user)

                t.chdir(aiida_workdir)
                # Hardcoding the sharding equal to 3 parts!
                existing_folders = t.glob('*/*/*')

                folders_to_delete = [
                    i for i in existing_folders
                    if i.replace("/", "") in dic['uuids']
                ]

                for folder in folders_to_delete:
                    t.rmtree(folder)
                    counter += 1
                    if counter % 20 == 0 and counter > 0:
                        print("Deleted work directories: {}".format(counter))

            print("{} remote folder(s) cleaned.".format(counter))
Code example #13
class BaseTranslator(object):
    """
    Generic class for translator. It contains the methods
    required to build a related QueryBuilder object
    """

    # A label associated to the present class
    __label__ = None
    # The AiiDA class one-to-one associated to the present class
    _aiida_class = None
    # The string name of the AiiDA class
    _aiida_type = None

    # The string associated to the AiiDA class in the query builder lexicon
    _qb_type = None

    # If True (False) the corresponding AiiDA class has (no) uuid property
    _has_uuid = None

    _result_type = __label__

    _default = _default_projections = ["**"]

    _schema_projections = {
        "column_order": [],
        "additional_info": {}
    }

    _is_qb_initialized = False
    _is_id_query = None
    _total_count = None

    def __init__(self, Class=None, **kwargs):
        """
        Initialise the parameters.
        Create the basic query_help

        keyword Class (default None, meaning this class): the class from
        which the initial values of the attributes are taken. By default
        it is this class, so that class attributes are translated into
        object attributes. In case of inheritance one can use the same
        constructor but pass the inheriting class to take over its attributes.
        """

        # Assume default class is this class (cannot be done in the
        # definition as it requires self)
        if Class is None:
            Class = self.__class__

        # Assign class parameters to the object
        self.__label__ = Class.__label__
        self._aiida_class = Class._aiida_class
        self._aiida_type = Class._aiida_type
        self._qb_type = Class._qb_type
        self._result_type = Class.__label__

        self._default = Class._default
        self._default_projections = Class._default_projections
        self._schema_projections = Class._schema_projections
        self._is_qb_initialized = Class._is_qb_initialized
        self._is_id_query = Class._is_id_query
        self._total_count = Class._total_count

        # Basic filter (dict) to set the identity of the uuid. None if
        #  no specific node is requested
        self._id_filter = None

        # basic query_help object
        self._query_help = {
            "path": [{
                "type": self._qb_type,
                "label": self.__label__
            }],
            "filters": {},
            "project": {},
            "order_by": {}
        }
        # query_builder object (No initialization)
        self.qb = QueryBuilder()

        self.LIMIT_DEFAULT = kwargs['LIMIT_DEFAULT']
        self.schema = None

    def __repr__(self):
        """
        This function is required for the caching system to be able to compare
        two NodeTranslator objects. Comparison is done on the value returned by __repr__

        :return: representation of NodeTranslator objects. Returns an empty
            string because the inputs of self.get_nodes are sufficient to
            determine the identity of two queries.
        """
        return ""

    def get_schema(self):

        # Construct the full class string
        class_string = 'aiida.orm.' + self._aiida_type

        # Load the corresponding ORM class
        orm_class = get_object_from_string(class_string)

        # Construct the json object to be returned
        basic_schema = orm_class.get_schema()

        schema = {}
        ordering = []

        # get additional info and column order from the translator class
        # and combine it with the basic schema
        if len(self._schema_projections["column_order"]) > 0:
            for field in self._schema_projections["column_order"]:

                # basic schema
                if field in basic_schema.keys():
                    schema[field] = basic_schema[field]
                else:
                    ## Note: if column name starts with user_* get the schema information from
                    # user class. It is added mainly to handle user_email case.
                    # TODO need to improve
                    field_parts = field.split("_")
                    if field_parts[0] == "user" and field != "user_id" and len(field_parts) > 1:
                        from aiida.orm.user import User
                        user_schema = User.get_schema()
                        if field_parts[1] in user_schema.keys():
                            schema[field] = user_schema[field_parts[1]]
                        else:
                            raise KeyError("{} is not present in user schema".format(field))
                    else:
                        raise KeyError("{} is not present in ORM basic schema".format(field))

                # additional info defined in translator class
                if field in self._schema_projections["additional_info"]:
                    schema[field].update(self._schema_projections["additional_info"][field])
                else:
                    raise KeyError("{} is not present in default projection additional info".format(field))

            # order
            ordering = self._schema_projections["column_order"]

        else:
            raise ConfigurationError("Define column order to get schema for {}".format(self._aiida_type))

        return dict(fields=schema, ordering=ordering)

    def init_qb(self):
        """
        Initialize query builder object by means of _query_help
        """
        self.qb.__init__(**self._query_help)
        self._is_qb_initialized = True

    def count(self):
        """
        Count the number of rows returned by the query and set total_count
        """
        if self._is_qb_initialized:
            self._total_count = self.qb.count()
        else:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

            # def caching_method(self):
            #     """
            #     class method for caching. It is a wrapper of the
            # flask_cache memoize
            #     method. To be used as a decorator
            #     :return: the flask_cache memoize method with the timeout kwarg
            #     corrispondent to the class
            #     """
            #     return cache.memoize()
            #

            #    @cache.memoize(timeout=CACHING_TIMEOUTS[self.__label__])

    def get_total_count(self):
        """
        Returns the number of rows of the query.

        :return: total_count
        """
        ## Count the results if needed
        if not self._total_count:
            self.count()

        return self._total_count

    def set_filters(self, filters={}):
        """
        Add filters in query_help.

        :param filters: a dictionary where keys are the tag names
            given in the path in query_help and their values are the dictionaries
            of filters you want to add for that tag name. Format for the filters
            dictionary::

                filters = {
                    "tag1" : {k1:v1, k2:v2},
                    "tag2" : {k1:v1, k2:v2},
                }

        :return: query_help dict including filters if any.
        """
        if isinstance(filters, dict):
            if len(filters) > 0:
                for tag, tag_filters in filters.iteritems():
                    if len(tag_filters) > 0 and isinstance(tag_filters, dict):
                        self._query_help["filters"][tag] = {}
                        for filter_key, filter_value in tag_filters.iteritems():
                            if filter_key == "pk":
                                filter_key = pk_dbsynonym
                            self._query_help["filters"][tag][filter_key] \
                                = filter_value
        else:
            raise InputValidationError("Pass data in dictionary format where "
                                       "keys are the tag names given in the "
                                       "path in query_help and and their values"
                                       " are the dictionary of filters want "
                                       "to add for that tag name.")

    def get_default_projections(self):
        """
        method to get default projections of the node
        :return: self._default_projections
        """
        return self._default_projections

    def set_default_projections(self):
        """
        It calls the set_projections() methods internally to add the
        default projections in query_help

        :return: None
        """
        self.set_projections({self.__label__: self._default_projections})

    def set_projections(self, projections):
        """
        add the projections in query_help

        :param projections: it is a dictionary where keys are the tag names
         given in the path in query_help and values are the list of the names
         you want to project in the final output
        :return: updated query_help with projections
        """
        if isinstance(projections, dict):
            if len(projections) > 0:
                for project_key, project_list in projections.iteritems():
                    self._query_help["project"][project_key] = project_list
        else:
            raise InputValidationError("Pass data in dictionary format where "
                                       "keys are the tag names given in the "
                                       "path in query_help and values are the "
                                       "list of the names you want to project "
                                       "in the final output")

    def set_order(self, orders):
        """
        Add order_by clause in query_help
        :param orders: dictionary of orders you want to apply on final
        results
        :return: None or exception if any.
        """
        ## Validate input
        if type(orders) is not dict:
            raise InputValidationError("orders has to be a dictionary"
                                       "compatible with the 'order_by' section"
                                       "of the query_help")

        ## Auxiliary function to get the ordering criterion
        def def_order(columns):
            """
            Takes a list of signed column names ex. ['id', '-ctime',
            '+mtime']
            and transforms it in a order_by compatible dictionary
            :param columns: (list of strings)
            :return: a dictionary
            """
            order_dict = {}
            for column in columns:
                if column[0] == '-':
                    order_dict[column[1:]] = 'desc'
                elif column[0] == '+':
                    order_dict[column[1:]] = 'asc'
                else:
                    order_dict[column] = 'asc'
            if 'pk' in order_dict:
                order_dict[pk_dbsynonym] = order_dict.pop('pk')
            return order_dict

        ## Assign orderby field query_help
        for tag, columns in orders.iteritems():
            self._query_help['order_by'][tag] = def_order(columns)

    def set_query(self, filters=None, orders=None, projections=None, id=None):
        """
        Adds filters, default projections, order specs to the query_help,
        and initializes the qb object

        :param filters: dictionary with the filters
        :param orders: dictionary with the order for each tag
        :param projections: dictionary with the projections
        :param id: id of a specific node
        :type id: int
        """

        tagged_filters = {}

        ## Check if filters are well defined and construct an ad-hoc filter
        # for id_query
        if id is not None:
            self._is_id_query = True
            if self._result_type == self.__label__ and len(filters) > 0:
                raise RestInputValidationError("selecting a specific id does "
                                               "not "
                                               "allow to specify filters")

            try:
                self._check_id_validity(id)
            except RestValidationError as e:
                raise RestValidationError(e.message)
            else:
                tagged_filters[self.__label__] = self._id_filter
                if self._result_type is not self.__label__:
                    tagged_filters[self._result_type] = filters
        else:
            tagged_filters[self.__label__] = filters

        ## Add filters
        self.set_filters(tagged_filters)

        ## Add projections
        if projections is None:
            self.set_default_projections()
        else:
            tagged_projections = {self._result_type: projections}
            self.set_projections(tagged_projections)

        ##Add order_by
        if orders is not None:
            tagged_orders = {self._result_type: orders}
            self.set_order(tagged_orders)

        ## Initialize the query_object
        self.init_qb()

    def get_query_help(self):
        """
        :return: return QB json dictionary
        """
        return self._query_help

    def set_limit_offset(self, limit=None, offset=None):
        """
        sets limits and offset directly to the query_builder object

        :param limit:
        :param offset:
        :return:
        """

        ## mandatory params
        # none

        ## non-mandatory params
        if limit is not None:
            try:
                limit = int(limit)
            except ValueError:
                raise InputValidationError("Limit value must be an integer")
            if limit > self.LIMIT_DEFAULT:
                raise RestValidationError("Limit and perpage cannot be bigger "
                                          "than {}".format(self.LIMIT_DEFAULT))
        else:
            limit = self.LIMIT_DEFAULT

        if offset is not None:
            try:
                offset = int(offset)
            except ValueError:
                raise InputValidationError("Offset value must be an "
                                           "integer")

        if self._is_qb_initialized:
            if limit is not None:
                self.qb.limit(limit)
            else:
                pass
            if offset is not None:
                self.qb.offset(offset)
            else:
                pass
        else:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

    def get_formatted_result(self, label):
        """
        Runs the query and retrieves results tagged as "label".

        :param label: the tag of the results to be extracted out of
          the query rows.
        :type label: str
        :return: a list of the query results
        """

        if not self._is_qb_initialized:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

        results = []
        if self._total_count > 0:
            results = [res[label] for res in self.qb.dict()]

        # TODO think how to make it less hardcoded
        if self._result_type == 'input_of':
            return {'inputs': results}
        elif self._result_type == 'output_of':
            return {'outputs': results}
        else:
            return {self.__label__: results}

    def get_results(self):
        """
        Returns either list of nodes or details of single node from database.

        :return: either list of nodes or details of single node from database
        """

        ## Check whether the querybuilder object has been initialized
        if not self._is_qb_initialized:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

        ## Count the total number of rows returned by the query (if not
        # already done)
        if self._total_count is None:
            self.count()

        ## Retrieve data
        data = self.get_formatted_result(self._result_type)
        return data

    def _check_id_validity(self, id):
        """
        Checks whether id corresponds to an object of the expected type,
        whenever type is a valid column of the database (ex. for nodes,
        but not for users)
        
        :param id: id (or id starting pattern)
        
        :return: True if id valid, False if invalid. If True, sets the id
          filter attribute correctly
            
        :raise RestValidationError: if no node is found or id pattern does
          not identify a unique node
        """
        from aiida.common.exceptions import MultipleObjectsError, NotExistent
        from aiida.orm.utils.loaders import IdentifierType, get_loader

        loader = get_loader(self._aiida_class)

        if self._has_uuid:

            # For consistency, check that id is a string
            if not isinstance(id, (str, unicode)):
                raise RestValidationError('parameter id has to be a string')

            identifier_type = IdentifierType.UUID
            qb, _ = loader.get_query_builder(id, identifier_type, sub_classes=(self._aiida_class,))
        else:

            # Similarly, check that id is an integer
            if not isinstance(id, int):
                raise RestValidationError('parameter id has to be an integer')

            identifier_type = IdentifierType.ID
            qb, _ = loader.get_query_builder(id, identifier_type, sub_classes=(self._aiida_class,))

        # For efficiency I don't go further than two results
        qb.limit(2)

        try:
            pk = qb.one()[0].pk
        except MultipleObjectsError:
            raise RestValidationError("More than one node found."
                                      " Provide longer starting pattern"
                                      " for id.")
        except NotExistent:
            raise RestValidationError("either no object's id starts"
                                      " with '{}' or the corresponding object"
                                      " is not of type aiida.orm.{}"
                                      .format(id, self._aiida_type))
        else:
            # create a permanent filter
            self._id_filter = {'id': {'==': pk}}
            return True
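A heavily hedged sketch of how the REST layer typically drives a concrete translator (NodeTranslator and all parameter values are hypothetical; only the methods called are taken from the class above):

translator = NodeTranslator(LIMIT_DEFAULT=400)    # hypothetical subclass of BaseTranslator
translator.set_query(filters={'id': {'>': 100}},  # filters keyed by the translator's own tag
                     orders=['-ctime'],           # signed column names, parsed by def_order()
                     projections=['id', 'uuid', 'ctime'])
translator.set_limit_offset(limit=20, offset=0)
print(translator.get_total_count())
print(translator.get_results())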
Code example #14
File: query.py Project: asle85/aiida-core
    def test_simple_query_2(self):
        from aiida.orm.querybuilder import QueryBuilder
        from aiida.orm import Node
        from datetime import datetime
        from aiida.common.exceptions import MultipleObjectsError, NotExistent
        n0 = Node()
        n0.label = 'hello'
        n0.description = ''
        n0._set_attr('foo', 'bar')

        n1 = Node()
        n1.label = 'foo'
        n1.description = 'I am FoO'

        n2 = Node()
        n2.label = 'bar'
        n2.description = 'I am BaR'

        n2.add_link_from(n1, label='random_2')
        n1.add_link_from(n0, label='random_1')

        for n in (n0, n1, n2):
            n.store()



        qb1 = QueryBuilder()
        qb1.append(Node, filters={'label': 'hello'})
        self.assertEqual(len(list(qb1.all())), 1)


        qh = {
            'path': [
                {
                    'cls': Node,
                    'tag': 'n1'
                },
                {
                    'cls': Node,
                    'tag': 'n2',
                    'output_of': 'n1'
                }
            ],
            'filters': {
                'n1': {
                    'label': {'ilike': '%foO%'},
                },
                'n2': {
                    'label': {'ilike': 'bar%'},
                }
            },
            'project': {
                'n1': ['id', 'uuid', 'ctime', 'label'],
                'n2': ['id', 'description', 'label'],
            }
        }

        qb2 = QueryBuilder(**qh)


        resdict = qb2.dict()
        self.assertEqual(len(resdict), 1)
        self.assertTrue(isinstance(resdict[0]['n1']['ctime'], datetime))


        res_one = qb2.one()
        self.assertTrue('bar' in res_one)




        qh = {
            'path': [
                {
                    'cls': Node,
                    'tag': 'n1'
                },
                {
                    'cls': Node,
                    'tag': 'n2',
                    'output_of': 'n1'
                }
            ],
            'filters': {
                'n1--n2': {'label': {'like': '%_2'}}
            }
        }
        qb = QueryBuilder(**qh)
        self.assertEqual(qb.count(), 1)


        # Test the query caching: get_query() builds a new query object only
        # after the QueryBuilder has changed (here, by adding a filter)
        query1 = qb.get_query()
        qb.add_filter('n2', {'label': 'nonexistentlabel'})
        self.assertEqual(qb.count(), 0)

        with self.assertRaises(NotExistent):
            qb.one()
        with self.assertRaises(MultipleObjectsError):
            QueryBuilder().append(Node).one()

        query2 = qb.get_query()
        query3 = qb.get_query()

        self.assertTrue(id(query1) != id(query2))
        self.assertTrue(id(query2) == id(query3))
Code example #15
class BaseTranslator(object):
    """
    Generic class for translator. It contains all the methods
    required to build a QueryBuilder object
    """

    # A label associated to the present class
    __label__ = None
    # The string name of the AiiDA class one-to-one associated to the present
    #  class
    _aiida_type = None
    # The string associated to the AiiDA class in the query builder lexicon
    _qb_type = None

    _result_type = __label__

    _default = _default_projections = []
    _is_qb_initialized = False
    _is_pk_query = None
    _total_count = None

    def __init__(self, Class=None, **kwargs):
        """
        Initialise the parameters.
        Create the basic query_help

        keyword Class (default None, in which case it becomes this class): the
        class from which one takes the initial values of the attributes. By
        default it is this class, so that class attributes are translated into
        object attributes. In case of inheritance one can use the same
        constructor but pass the inheriting class to pass its attributes.
        """

        # Assume default class is this class (cannot be done in the
        # definition as it requires self)
        if Class is None:
            Class = self.__class__

        # Assign class parameters to the object
        self.__label__ = Class.__label__
        self._aiida_type = Class._aiida_type
        self._qb_type = Class._qb_type
        self._result_type = Class.__label__

        self._default = Class._default
        self._default_projections = Class._default_projections
        self._is_qb_initialized = Class._is_qb_initialized
        self._is_pk_query = Class._is_pk_query
        self._total_count = Class._total_count

        # basic query_help object
        self._query_help = {
            "path": [{
                "type": self._qb_type,
                "label": self.__label__
            }],
            "filters": {},
            "project": {},
            "order_by": {}
        }
        # query_builder object (No initialization)
        self.qb = QueryBuilder()

        self.LIMIT_DEFAULT = kwargs['LIMIT_DEFAULT']

        if 'custom_schema' in kwargs:
            self.custom_schema = kwargs['custom_schema']
        else:
            self.custom_schema = None

    def __repr__(self):
        """
        This function is required for the caching system to be able to compare
        two NodeTranslator objects. Comparison is done on the value returned by
        __repr__
        :return: representation of NodeTranslator objects. Returns an empty
        string because the inputs of self.get_nodes are sufficient to determine
        the identity of two queries.
        """
        return ""

    def get_schema(self):

        # Construct the full class string
        class_string = 'aiida.orm.' + self._aiida_type

        # Load correspondent orm class
        orm_class = get_object_from_string(class_string)

        # Construct the json object to be returned
        basic_schema = orm_class.get_db_columns()

        if self._default_projections == ['**']:
            schema = basic_schema  # No custom schema, take the basic one
        else:
            schema = dict([(k, basic_schema[k]) for k in
                           self._default_projections
                           if k in basic_schema.keys()])

        # Convert the related_tablevalues to the RESTAPI resources
        # (orm class/db table ==> RESTapi resource)
        def table2resource(table_name):
            # TODO Consider ways to make this function backend independent (one
            # idea would be to go from table name to aiida class name, which is
            # unique)
            if BACKEND == BACKEND_DJANGO:
                (spam, resource_name) = issingular(table_name[2:].lower())
            elif BACKEND == BACKEND_SQLA:
                (spam, resource_name) = issingular(table_name[5:])
            elif BACKEND is None:
                raise ConfigurationError("settings.BACKEND has not been set.\n"
                                         "Hint: Have you called "
                                         "aiida.load_dbenv?")
            else:
                raise ConfigurationError("Unknown settings.BACKEND: {}".format(
                    BACKEND))
            return resource_name

        for k, v in schema.iteritems():

            # Add custom fields to the column dictionaries
            if 'fields' in self.custom_schema:
                if k in self.custom_schema['fields'].keys():
                    schema[k].update(self.custom_schema['fields'][k])

            # Convert python types values into strings
            schema[k]['type'] = str(schema[k]['type'])[7:-2]

            # Construct the 'related resource' field from the 'related_table'
            # field
            if v['is_foreign_key']:
                schema[k]['related_resource'] = table2resource(
                    schema[k].pop('related_table'))

        return dict(columns=schema)

    def init_qb(self):
        """
        Initialize query builder object by means of _query_help
        """
        self.qb.__init__(**self._query_help)
        self._is_qb_initialized = True

    def count(self):
        """
        Count the number of rows returned by the query and set total_count
        """
        if self._is_qb_initialized:
            self._total_count = self.qb.count()
        else:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

            # def caching_method(self):
            #     """
            #     class method for caching. It is a wrapper of the
            # flask_cache memoize
            #     method. To be used as a decorator
            #     :return: the flask_cache memoize method with the timeout kwarg
            #     corrispondent to the class
            #     """
            #     return cache.memoize()
            #

            #    @cache.memoize(timeout=CACHING_TIMEOUTS[self.__label__])

    def get_total_count(self):
        """
        Returns the number of rows of the query
        :return: total_count
        """
        ## Count the results if needed
        if not self._total_count:
            self.count()

        return self._total_count

    def set_filters(self, filters=None):
        """
        Add filters in query_help.

        :param filters: a dictionary where keys are the tag names
         given in the path in query_help and their values are the dictionary
         of filters you want to add for that tag name. Format for the filters
         dictionary:
         filters = {
                    "tag1" : {k1:v1, k2:v2},
                    "tag2" : {k1:v1, k2:v2},
                  }
        :return: query_help dict including filters if any.
        """
        if filters is None:
            filters = {}

        if isinstance(filters, dict):
            if len(filters) > 0:
                for tag, tag_filters in filters.iteritems():
                    if len(tag_filters) > 0 and isinstance(tag_filters, dict):
                        self._query_help["filters"][tag] = {}
                        for filter_key, filter_value in tag_filters.iteritems():
                            if filter_key == "pk":
                                filter_key = pk_dbsynonym
                            self._query_help["filters"][tag][filter_key] \
                                = filter_value
        else:
            raise InputValidationError("Pass data in dictionary format where "
                                       "keys are the tag names given in the "
                                       "path in query_help and their values "
                                       "are the dictionary of filters you want "
                                       "to add for that tag name.")

    def get_default_projections(self):
        """
        method to get default projections of the node
        :return: self._default_projections
        """
        return self._default_projections

    def set_default_projections(self):
        """
        It calls the set_projections() methods internally to add the
        default projections in query_help

        :return: None
        """
        self.set_projections({self.__label__: self._default_projections})

    def set_projections(self, projections):
        """
        add the projections in query_help

        :param projections: it is a dictionary where keys are the tag names
         given in the path in query_help and values are the list of the names
         you want to project in the final output
        :return: updated query_help with projections
        """
        if isinstance(projections, dict):
            if len(projections) > 0:
                for project_key, project_list in projections.iteritems():
                    self._query_help["project"][project_key] = project_list
        else:
            raise InputValidationError("Pass data in dictionary format where "
                                       "keys are the tag names given in the "
                                       "path in query_help and values are the "
                                       "list of the names you want to project "
                                       "in the final output")

    def set_order(self, orders):
        """
        Add order_by clause in query_help
        :param orders: dictionary of orders you want to apply on the final
        results
        :return: None; raises an exception on invalid input.
        """
        ## Validate input
        if type(orders) is not dict:
            raise InputValidationError("orders has to be a dictionary "
                                       "compatible with the 'order_by' section "
                                       "of the query_help")

        ## Auxiliary function to get the ordering criterion
        def def_order(columns):
            """
            Takes a list of signed column names, e.g. ['id', '-ctime',
            '+mtime'],
            and transforms it into an order_by-compatible dictionary
            :param columns: (list of strings)
            :return: a dictionary
            """
            order_dict = {}
            for column in columns:
                if column[0] == '-':
                    order_dict[column[1:]] = 'desc'
                elif column[0] == '+':
                    order_dict[column[1:]] = 'asc'
                else:
                    order_dict[column] = 'asc'
            if 'pk' in order_dict:
                order_dict[pk_dbsynonym] = order_dict.pop('pk')
            return order_dict

        ## Assign orderby field query_help
        for tag, columns in orders.iteritems():
            self._query_help['order_by'][tag] = def_order(columns)

    def set_query(self, filters=None, orders=None, projections=None, pk=None):
        """
        Adds filters, default projections, order specs to the query_help,
        and initializes the qb object

        :param filters: dictionary with the filters
        :param orders: dictionary with the order for each tag
        :param projections: dictionary with the projections
        :param pk: (integer) pk of a specific node
        """

        tagged_filters = {}

        ## Check if filters are well defined and construct an ad-hoc filter
        # for pk_query
        if pk is not None:
            self._is_pk_query = True
            if self._result_type == self.__label__ and filters:
                raise RestInputValidationError("selecting a specific pk does "
                                               "not allow to specify filters")
            elif not self._check_pk_validity(pk):
                raise RestValidationError(
                    "either the selected pk does not exist "
                    "or the corresponding object is not of "
                    "type aiida.orm.{}".format(self._aiida_type))
            else:
                tagged_filters[self.__label__] = {'id': {'==': pk}}
                if self._result_type is not self.__label__:
                    tagged_filters[self._result_type] = filters
        else:
            tagged_filters[self.__label__] = filters

        ## Add filters
        self.set_filters(tagged_filters)

        ## Add projections
        if projections is None:
            self.set_default_projections()
        else:
            tagged_projections = {self._result_type: projections}
            self.set_projections(tagged_projections)

        ##Add order_by
        if orders is not None:
            tagged_orders = {self._result_type: orders}
            self.set_order(tagged_orders)

        ## Initialize the query_object
        self.init_qb()

    def get_query_help(self):
        """
        :return: return QB json dictionary
        """
        return self._query_help

    def set_limit_offset(self, limit=None, offset=None):
        """
        sets limits and offset directly to the query_builder object

        :param limit:
        :param offset:
        :return:
        """

        ## mandatory params
        # none

        ## non-mandatory params
        if limit is not None:
            try:
                limit = int(limit)
            except ValueError:
                raise InputValidationError("Limit value must be an integer")
            if limit > self.LIMIT_DEFAULT:
                raise RestValidationError("Limit and perpage cannot be bigger "
                                          "than {}".format(self.LIMIT_DEFAULT))
        else:
            limit = self.LIMIT_DEFAULT

        if offset is not None:
            try:
                offset = int(offset)
            except ValueError:
                raise InputValidationError("Offset value must be an "
                                           "integer")

        if self._is_qb_initialized:
            self.qb.limit(limit)
            if offset is not None:
                self.qb.offset(offset)
        else:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

    def get_formatted_result(self, label):
        """
        Runs the query and retrieves results tagged as "label"
        :param label (string): the tag of the results to be extracted out of
        the query rows.
        :return: a list of the query results
        """

        if not self._is_qb_initialized:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

        results = []
        if self._total_count > 0:
            results = [res[label] for res in self.qb.dict()]

        # TODO think how to make it less hardcoded
        if self._result_type == 'input_of':
            return {'inputs': results}
        elif self._result_type == 'output_of':
            return {'outputs': results}
        else:
            return {self.__label__: results}

    def get_results(self):
        """
        Returns either list of nodes or details of single node from database

        :return: either list of nodes or details of single node
        from database
        """

        ## Check whether the querybuilder object has been initialized
        if not self._is_qb_initialized:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

        ## Count the total number of rows returned by the query (if not
        # already done)
        if self._total_count is None:
            self.count()

        ## Retrieve data
        data = self.get_formatted_result(self._result_type)
        return data

    def _check_pk_validity(self, pk):
        """
        Checks whether a pk corresponds to an object of the expected type,
        whenever type is a valid column of the database (ex. for nodes,
        but not for users)_
        :param pk: (integer) ok to check
        :return: True or False
        """
        # The logic could be to load the node or to use querybuilder. Let's
        # do with qb for consistency, although it would be easier to do it
        # with load_node

        query_help_base = {
            'path': [
                {
                    'type': self._qb_type,
                    'label': self.__label__,
                },
            ],
            'filters': {
                self.__label__:
                    {
                        'id': {'==': pk}
                    }
            }
        }

        qb_base = QueryBuilder(**query_help_base)
        return qb_base.count() == 1
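
In use, the REST resources drive this translator through a fixed sequence: set_query() fills the query_help and initializes the builder, set_limit_offset() pages it, and get_results() runs it. A minimal usage sketch, assuming a hypothetical concrete subclass ComputerTranslator whose class attributes (__label__, _aiida_type, _qb_type and the default projections) are invented here purely for illustration:

class ComputerTranslator(BaseTranslator):
    # Hypothetical values following the pre-1.0 query builder lexicon.
    __label__ = 'computers'
    _aiida_type = 'Computer'
    _qb_type = 'computer'
    _default = _default_projections = ['id', 'name', 'hostname']

translator = ComputerTranslator(LIMIT_DEFAULT=400)
translator.set_query(filters={'name': {'like': 'local%'}}, orders=['-id'])
translator.set_limit_offset(limit=20, offset=0)
print(translator.get_total_count())   # number of matching computers
print(translator.get_results())       # {'computers': [...]}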
Code example #16
def create(outfile, computers, groups, nodes, group_names, input_forward,
           create_reversed, return_reversed, call_reversed, overwrite,
           archive_format):
    """
    Export nodes and groups of nodes to an archive file for backup or sharing purposes
    """
    import sys
    from aiida.backends.utils import load_dbenv, is_dbenv_loaded
    # TODO: Replace with aiida.cmdline.utils.decorators.with_dbenv decorator
    # TODO: when we merge to develop
    if not is_dbenv_loaded():
        load_dbenv()
    from aiida.orm import Group, Node, Computer
    from aiida.orm.querybuilder import QueryBuilder
    from aiida.orm.importexport import export, export_zip

    node_id_set = set(nodes)
    group_dict = dict()

    if group_names:
        qb = QueryBuilder()
        qb.append(Group, tag='group', project=['*'], filters={'name': {'in': group_names}})
        qb.append(Node, tag='node', member_of='group', project=['id'])
        res = qb.dict()

        group_dict.update(
            {group['group']['*'].id: group['group']['*'] for group in res})
        node_id_set.update([node['node']['id'] for node in res])

    if groups:
        qb = QueryBuilder()
        qb.append(Group, tag='group', project=['*'], filters={'id': {'in': groups}})
        qb.append(Node, tag='node', member_of='group', project=['id'])
        res = qb.dict()

        group_dict.update(
            {group['group']['*'].id: group['group']['*'] for group in res})
        node_id_set.update([node['node']['id'] for node in res])

    groups_list = group_dict.values()

    # Getting the nodes that correspond to the ids that were found above
    if len(node_id_set) > 0:
        qb = QueryBuilder()
        qb.append(Node, tag='node', project=['*'], filters={'id': {'in': node_id_set}})
        node_list = [node[0] for node in qb.all()]
    else:
        node_list = list()

    # Check if any of the nodes wasn't found in the database.
    missing_nodes = node_id_set.difference(node.id for node in node_list)
    for node_id in missing_nodes:
        sys.stderr.write('WARNING! Node with pk={} not found, skipping\n'.format(node_id))

    if computers:
        qb = QueryBuilder()
        qb.append(Computer, tag='comp', project=['*'], filters={'id': {'in': set(computers)}})
        computer_list = [computer[0] for computer in qb.all()]
        missing_computers = set(computers).difference(computer.id for computer in computer_list)

        for computer_id in missing_computers:
            sys.stderr.write('WARNING! Computer with pk={} not found, skipping\n'.format(computer_id))
    else:
        computer_list = []

    what_list = node_list + computer_list + groups_list
    additional_kwargs = dict()

    if archive_format == 'zip':
        export_function = export_zip
        additional_kwargs.update({'use_compression': True})
    elif archive_format == 'zip-uncompressed':
        export_function = export_zip
        additional_kwargs.update({'use_compression': False})
    elif archive_format == 'tar.gz':
        export_function = export
    else:
        sys.stderr.write('invalid --archive-format value {}\n'.format(
            archive_format))
        sys.exit(1)

    try:
        export_function(
            what=what_list, input_forward=input_forward,
            create_reversed=create_reversed,
            return_reversed=return_reversed,
            call_reversed=call_reversed, outfile=outfile,
            overwrite=overwrite, **additional_kwargs
        )

    except IOError as e:
        sys.stderr.write('IOError: {}\n'.format(e))
        sys.exit(1)
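
Both group queries feed the same node_id_set, so a node belonging to several selected groups is exported only once, and the later set.difference against the re-queried nodes is what detects pks that do not exist. That existence check is easy to see in isolation; a compact sketch, assuming a loaded profile (the pks are placeholders):

from aiida.orm import Node
from aiida.orm.querybuilder import QueryBuilder

requested = {1, 999999}  # pks the user asked for (placeholder values)
qb = QueryBuilder()
qb.append(Node, project=['id'], filters={'id': {'in': requested}})
found = {row[0] for row in qb.all()}  # pks that actually exist

for missing in requested - found:
    print('WARNING! Node with pk={} not found, skipping'.format(missing))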
Code example #17
File: base.py Project: sponce24/aiida-core
class BaseTranslator(object):
    """
    Generic class for translator. It contains the methods
    required to build a related QueryBuilder object
    """
    # pylint: disable=too-many-instance-attributes,fixme

    # A label associated to the present class
    __label__ = None
    # The AiiDA class one-to-one associated to the present class
    _aiida_class = None
    # The string name of the AiiDA class
    _aiida_type = None

    # If True (False) the corresponding AiiDA class has (no) uuid property
    _has_uuid = None

    _result_type = __label__

    _default = _default_projections = ['**']

    _is_qb_initialized = False
    _is_id_query = None
    _total_count = None

    def __init__(self, **kwargs):
        """
        Initialise the parameters.
        Create the basic query_help

        keyword Class (default None but becomes this class): is the class
        from which one takes the initial values of the attributes. By default
        is this class so that class atributes are translated into object
        attributes. In case of inheritance one cane use the
        same constructore but pass the inheriting class to pass its attributes.
        """
        # Basic filter (dict) to set the identity of the uuid. None if
        #  no specific node is requested
        self._id_filter = None

        # basic query_help object
        self._query_help = {
            'path': [{
                'cls': self._aiida_class,
                'tag': self.__label__
            }],
            'filters': {},
            'project': {},
            'order_by': {}
        }
        # query_builder object (No initialization)
        self.qbobj = QueryBuilder()

        self.limit_default = kwargs['LIMIT_DEFAULT']
        self.schema = None

    def __repr__(self):
        """
        This function is required for the caching system to be able to compare
        two NodeTranslator objects. Comparison is done on the value returned by __repr__

        :return: representation of NodeTranslator objects. Returns an empty
            string because the inputs of self.get_nodes are sufficient to
            determine the identity of two queries.
        """
        return ''

    @staticmethod
    def get_projectable_properties():
        """
        This method is extended in specific translators classes.
        It returns a dict as follows:
        dict(fields=projectable_properties, ordering=ordering)
        where projectable_properties is a dict and ordering is a list
        """
        return {}

    def init_qb(self):
        """
        Initialize query builder object by means of _query_help
        """
        self.qbobj.__init__(**self._query_help)
        self._is_qb_initialized = True

    def count(self):
        """
        Count the number of rows returned by the query and set total_count
        """
        if self._is_qb_initialized:
            self._total_count = self.qbobj.count()
        else:
            raise InvalidOperation(
                'query builder object has not been initialized.')

            # def caching_method(self):
            #     """
            #     class method for caching. It is a wrapper of the
            # flask_cache memoize
            #     method. To be used as a decorator
            #     :return: the flask_cache memoize method with the timeout kwarg
            #     corrispondent to the class
            #     """
            #     return cache.memoize()
            #

            #    @cache.memoize(timeout=CACHING_TIMEOUTS[self.__label__])

    def get_total_count(self):
        """
        Returns the number of rows of the query.

        :return: total_count
        """
        ## Count the results if needed
        if not self._total_count:
            self.count()

        return self._total_count

    def set_filters(self, filters=None):
        """
        Add filters in query_help.

        :param filters: a dictionary where keys are the tag names
            given in the path in query_help and their values are the dictionary
            of filters you want to add for that tag name. Format for the filters
            dictionary::

                filters = {
                    "tag1" : {k1:v1, k2:v2},
                    "tag2" : {k1:v1, k2:v2},
                }

        :return: query_help dict including filters if any.
        """
        if filters is None:
            filters = {}

        if isinstance(filters, dict):  # pylint: disable=too-many-nested-blocks
            if filters:
                for tag, tag_filters in filters.items():
                    if tag_filters and isinstance(tag_filters, dict):
                        self._query_help['filters'][tag] = {}
                        for filter_key, filter_value in tag_filters.items():
                            if filter_key == 'pk':
                                filter_key = PK_DBSYNONYM
                            self._query_help['filters'][tag][filter_key] \
                                = filter_value
        else:
            raise InputValidationError(
                'Pass data in dictionary format where '
                'keys are the tag names given in the '
                'path in query_help and their values '
                'are the dictionary of filters you want '
                'to add for that tag name.')

    def get_default_projections(self):
        """
        method to get default projections of the node
        :return: self._default_projections
        """
        return self._default_projections

    def set_default_projections(self):
        """
        It calls the set_projections() methods internally to add the
        default projections in query_help

        :return: None
        """
        self.set_projections({self.__label__: self._default_projections})

    def set_projections(self, projections):
        """
        add the projections in query_help

        :param projections: it is a dictionary where keys are the tag names
         given in the path in query_help and values are the list of the names
         you want to project in the final output
        :return: updated query_help with projections
        """
        if isinstance(projections, dict):
            if projections:
                for project_key, project_list in projections.items():
                    self._query_help['project'][project_key] = project_list

        else:
            raise InputValidationError('Pass data in dictionary format where '
                                       'keys are the tag names given in the '
                                       'path in query_help and values are the '
                                       'list of the names you want to project '
                                       'in the final output')

    def set_order(self, orders):
        """
        Add order_by clause in query_help
        :param orders: dictionary of orders you want to apply on final
        results
        :return: None or exception if any.
        """
        ## Validate input
        if not isinstance(orders, dict):
            raise InputValidationError('orders has to be a dictionary '
                                       "compatible with the 'order_by' section "
                                       'of the query_help')

        ## Auxiliary function to get the ordering criterion
        def def_order(columns):
            """
            Takes a list of signed column names, e.g. ['id', '-ctime',
            '+mtime'],
            and transforms it into an order_by-compatible dictionary
            :param columns: (list of strings)
            :return: a dictionary
            """
            from collections import OrderedDict
            order_dict = OrderedDict()
            for column in columns:
                if column[0] == '-':
                    order_dict[column[1:]] = 'desc'
                elif column[0] == '+':
                    order_dict[column[1:]] = 'asc'
                else:
                    order_dict[column] = 'asc'
            if 'pk' in order_dict:
                order_dict[PK_DBSYNONYM] = order_dict.pop('pk')
            return order_dict

        ## Assign orderby field query_help
        if 'id' not in orders[self._result_type] and '-id' not in orders[
                self._result_type]:
            orders[self._result_type].append('id')
        for tag, columns in orders.items():
            self._query_help['order_by'][tag] = def_order(columns)

    def set_query(self,
                  filters=None,
                  orders=None,
                  projections=None,
                  query_type=None,
                  node_id=None,
                  attributes=None,
                  attributes_filter=None,
                  extras=None,
                  extras_filter=None):
        # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
        """
        Adds filters, default projections, order specs to the query_help,
        and initializes the qb object

        :param filters: dictionary with the filters
        :param orders: dictionary with the order for each tag
        :param projections: dictionary with the projection. It is discarded
            if query_type=='attributes'/'extras'
        :param query_type: (string) specify the result or the content ("attr")
        :param node_id: (integer) id of a specific node
        :param attributes: flag to show attributes in nodes endpoint
        :param attributes_filter: list of node attributes to query
        :param extras: flag to show extras in nodes endpoint
        :param extras_filter: list of node extras to query
        """

        tagged_filters = {}

        ## Check if filters are well defined and construct an ad-hoc filter
        # for id_query
        if node_id is not None:
            self._is_id_query = True
            if self._result_type == self.__label__ and filters:
                raise RestInputValidationError(
                    'selecting a specific id does not allow to specify filters'
                )

            try:
                self._check_id_validity(node_id)
            except RestValidationError as exc:
                raise RestValidationError(str(exc))
            else:
                tagged_filters[self.__label__] = self._id_filter
                if self._result_type is not self.__label__:
                    tagged_filters[self._result_type] = filters
        else:
            tagged_filters[self.__label__] = filters

        ## Add filters
        self.set_filters(tagged_filters)

        ## Add projections
        if projections is None:
            if attributes is None and extras is None:
                self.set_default_projections()
            else:
                default_projections = self.get_default_projections()

                if attributes is True:
                    if attributes_filter is None:
                        default_projections.append('attributes')
                    else:
                        ## Check if attributes_filter is not a list
                        if not isinstance(attributes_filter, list):
                            attributes_filter = [attributes_filter]
                        for attr in attributes_filter:
                            default_projections.append('attributes.' +
                                                       str(attr))
                elif attributes is not None and attributes is not False:
                    raise RestValidationError(
                        'The attributes filter is false by default and can only be set to true.'
                    )

                if extras is True:
                    if extras_filter is None:
                        default_projections.append('extras')
                    else:
                        ## Check if extras_filter is not a list
                        if not isinstance(extras_filter, list):
                            extras_filter = [extras_filter]
                        for extra in extras_filter:
                            default_projections.append('extras.' + str(extra))
                elif extras is not None and extras is not False:
                    raise RestValidationError(
                        'The extras filter is false by default and can only be set to true.'
                    )

                self.set_projections({self.__label__: default_projections})
        else:
            tagged_projections = {self._result_type: projections}
            self.set_projections(tagged_projections)

        ##Add order_by
        if orders is not None:
            tagged_orders = {self._result_type: orders}
            self.set_order(tagged_orders)

        ## Initialize the query_object
        self.init_qb()

    def get_query_help(self):
        """
        :return: return QB json dictionary
        """
        return self._query_help

    def set_limit_offset(self, limit=None, offset=None):
        """
        sets limits and offset directly to the query_builder object

        :param limit:
        :param offset:
        :return:
        """

        ## mandatory params
        # none

        ## non-mandatory params
        if limit is not None:
            try:
                limit = int(limit)
            except ValueError:
                raise InputValidationError('Limit value must be an integer')
            if limit > self.limit_default:
                raise RestValidationError(
                    'Limit and perpage cannot be bigger than {}'.format(
                        self.limit_default))
        else:
            limit = self.limit_default

        if offset is not None:
            try:
                offset = int(offset)
            except ValueError:
                raise InputValidationError('Offset value must be an integer')

        if self._is_qb_initialized:
            self.qbobj.limit(limit)
            if offset is not None:
                self.qbobj.offset(offset)
        else:
            raise InvalidOperation(
                'query builder object has not been initialized.')

    def get_formatted_result(self, label):
        """
        Runs the query and retrieves results tagged as "label".

        :param label: the tag of the results to be extracted out of
          the query rows.
        :type label: str
        :return: a list of the query results
        """

        if not self._is_qb_initialized:
            raise InvalidOperation(
                'query builder object has not been initialized.')

        results = []
        if self._total_count > 0:
            for res in self.qbobj.dict():
                tmp = res[label]

                # Note: in a future code cleanup and design change, remove this
                # node-dependent part from the base class and move it to the
                # node translator.
                if self._result_type in ['with_outgoing', 'with_incoming']:
                    tmp['link_type'] = res[self.__label__ + '--' +
                                           label]['type']
                    tmp['link_label'] = res[self.__label__ + '--' +
                                            label]['label']
                results.append(tmp)

        # TODO think how to make it less hardcoded
        if self._result_type == 'with_outgoing':
            result = {'incoming': results}
        elif self._result_type == 'with_incoming':
            result = {'outgoing': results}
        else:
            result = {self.__label__: results}

        return result

    def get_results(self):
        """
        Returns either list of nodes or details of single node from database.

        :return: either list of nodes or details of single node from database
        """

        ## Check whether the querybuilder object has been initialized
        if not self._is_qb_initialized:
            raise InvalidOperation(
                'query builder object has not been initialized.')

        ## Count the total number of rows returned by the query (if not
        # already done)
        if self._total_count is None:
            self.count()

        ## Retrieve data
        data = self.get_formatted_result(self._result_type)
        return data

    def _check_id_validity(self, node_id):
        """
        Checks whether id corresponds to an object of the expected type,
        whenever type is a valid column of the database (ex. for nodes,
        but not for users)

        :param node_id: id (or id starting pattern)

        :return: True if node_id valid, False if invalid. If True, sets the id
          filter attribute correctly

        :raise RestValidationError: if no node is found or id pattern does
          not identify a unique node
        """
        from aiida.common.exceptions import MultipleObjectsError, NotExistent
        from aiida.orm.utils.loaders import IdentifierType, get_loader

        loader = get_loader(self._aiida_class)

        if self._has_uuid:

            # For consistency check that id is a string
            if not isinstance(node_id, six.string_types):
                raise RestValidationError('parameter id has to be a string')

            identifier_type = IdentifierType.UUID
            qbobj, _ = loader.get_query_builder(
                node_id, identifier_type, sub_classes=(self._aiida_class, ))
        else:

            # Similarly, check that id is an integer
            if not isinstance(node_id, int):
                raise RestValidationError('parameter id has to be an integer')

            identifier_type = IdentifierType.ID
            qbobj, _ = loader.get_query_builder(
                node_id, identifier_type, sub_classes=(self._aiida_class, ))

        # For efficiency I don't go further than two results
        qbobj.limit(2)

        try:
            pk = qbobj.one()[0].pk
        except MultipleObjectsError:
            raise RestInputValidationError(
                'More than one node found. Provide longer starting pattern for id.'
            )
        except NotExistent:
            raise RestInputValidationError(
                "either no object's id starts"
                " with '{}' or the corresponding object"
                ' is not of type aiida.orm.{}'.format(node_id,
                                                      self._aiida_type))
        else:
            # create a permanent filter
            self._id_filter = {'id': {'==': pk}}
            return True
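
The def_order helper used by set_order above is pure dictionary manipulation, so its behaviour can be checked without a database. A standalone sketch of the same transformation (in aiida's REST code PK_DBSYNONYM is the string 'id'):

from collections import OrderedDict

PK_DBSYNONYM = 'id'

def def_order(columns):
    # '-col' sorts descending, '+col' or a bare 'col' ascending;
    # the user-facing 'pk' is mapped to its database synonym.
    order_dict = OrderedDict()
    for column in columns:
        if column[0] == '-':
            order_dict[column[1:]] = 'desc'
        elif column[0] == '+':
            order_dict[column[1:]] = 'asc'
        else:
            order_dict[column] = 'asc'
    if 'pk' in order_dict:
        order_dict[PK_DBSYNONYM] = order_dict.pop('pk')
    return order_dict

print(def_order(['-ctime', '+label', 'pk']))
# OrderedDict([('ctime', 'desc'), ('label', 'asc'), ('id', 'asc')])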
Code example #18
File: rules.py Project: ramirezfranciscof/AGE
class UpdateRule(Operation):
    def __init__(self,
                 querybuilder,
                 mode=MODES.APPEND,
                 max_iterations=1,
                 track_edges=False,
                 track_visits=True):
        def get_spec_from_path(queryhelp, idx):
            if (queryhelp['path'][idx]['type'].startswith('node')
                    or queryhelp['path'][idx]['type'].startswith('data')
                    or queryhelp['path'][idx]['type'] == ''):
                return 'nodes'
            elif queryhelp['path'][idx]['type'] == 'group':
                return 'groups'
            else:
                raise Exception("entity type not understood ( {} )".format(
                    queryhelp['path'][idx]['type']))

        queryhelp = querybuilder.get_json_compatible_queryhelp()
        for pathspec in queryhelp['path']:
            if not pathspec['type']:
                pathspec['type'] = 'node.Node.'
        self._querybuilder = QueryBuilder(**queryhelp)
        queryhelp = self._querybuilder.get_json_compatible_queryhelp()
        self._first_tag = queryhelp['path'][0]['tag']
        self._last_tag = queryhelp['path'][-1]['tag']

        self._entity_from = get_spec_from_path(queryhelp, 0)
        self._entity_to = get_spec_from_path(queryhelp, -1)
        super(UpdateRule, self).__init__(mode,
                                         max_iterations,
                                         track_edges=track_edges,
                                         track_visits=track_visits)

    def _init_run(self, entity_set):
        # Removing all other projections in the QueryBuilder instance:
        for tag in self._querybuilder._projections.keys():
            self._querybuilder._projections[tag] = []
        # priming querybuilder to add projection on the key I need:
        self._querybuilder.add_projection(
            self._last_tag, entity_set[self._entity_to].identifier)
        self._entity_to_identifier = entity_set[self._entity_to].identifier
        if self._track_edges:
            self._querybuilder.add_projection(
                self._first_tag, entity_set[self._entity_from].identifier)
            edge_set = entity_set._dict['{}_{}'.format(self._entity_from,
                                                       self._entity_to)]
            self._edge_label = '{}--{}'.format(self._first_tag, self._last_tag)
            self._edge_keys = tuple(
                [(self._first_tag, entity_set[self._entity_from].identifier),
                 (self._last_tag, entity_set[self._entity_to].identifier)] +
                [(self._edge_label, identifier)
                 for identifier in edge_set._additional_identifiers])
            try:
                self._querybuilder.add_projection(
                    self._edge_label, edge_set._additional_identifiers)
            except InputValidationError as e:
                raise KeyError(
                    "The key for the edge is invalid.\n"
                    "Are the entities really connected, or have you overwritten the edge-tag?"
                )

    def _load_results(self, target_set, operational_set):
        """
        :param target_set: The set to load the results into
        :param operational_set: Where the results originate from (walkers)
        """
        # I check that I have primary keys
        primkeys = operational_set[self._entity_from].get_keys()
        # Empty the target set, so that only these results are inside
        target_set.empty()
        if primkeys:
            self._querybuilder.add_filter(self._first_tag, {
                operational_set[self._entity_from].identifier: {
                    'in': primkeys
                }
            })
            qres = self._querybuilder.dict()
            # These are the new results returned by the query
            target_set[self._entity_to].add_entities([
                item[self._last_tag][self._entity_to_identifier]
                for item in qres
            ])
            if self._track_edges:
                target_set['{}_{}'.format(
                    self._entity_from, self._entity_to)].add_entities([
                        tuple(item[key1][key2]
                              for (key1, key2) in self._edge_keys)
                        for item in qres
                    ])
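
Stripped of the rule and entity-set machinery, what UpdateRule iterates is a plain frontier expansion: inject the current walker pks as a filter on the first tag, collect what the last tag returns, and repeat until nothing new appears. A simplified sketch of that loop with bare pk sets, using the same pre-1.0 QueryBuilder lexicon as the rest of this page (the starting pk is a placeholder):

from aiida.orm import Node
from aiida.orm.querybuilder import QueryBuilder

def expand_once(pks):
    # One traversal step: from the given node pks to the pks of their outputs.
    qb = QueryBuilder()
    qb.append(Node, tag='origin', filters={'id': {'in': pks}})
    qb.append(Node, tag='target', output_of='origin', project=['id'])
    return {row[0] for row in qb.all()}

visited = frontier = {1234}  # placeholder starting pk
while frontier:  # stop when an iteration discovers no new nodes
    frontier = expand_once(frontier) - visited
    visited |= frontier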
Code example #19
File: exportfile.py Project: mikeatm/aiida_core
def create(outfile, computers, groups, nodes, group_names, no_parents,
           no_calc_outputs, overwrite, archive_format):
    """
    Export nodes and groups of nodes to an archive file for backup or sharing purposes
    """
    import sys
    from aiida.backends.utils import load_dbenv
    load_dbenv()
    from aiida.orm import Group, Node, Computer
    from aiida.orm.querybuilder import QueryBuilder
    from aiida.orm.importexport import export, export_zip

    node_id_set = set(nodes)
    group_dict = dict()

    if group_names:
        qb = QueryBuilder()
        qb.append(Group,
                  tag='group',
                  project=['*'],
                  filters={'name': {
                      'in': group_names
                  }})
        qb.append(Node, tag='node', member_of='group', project=['id'])
        res = qb.dict()

        group_dict.update({
            group['group']['*'].name: group['group']['*'].dbgroup
            for group in res
        })
        node_id_set.update([node['node']['id'] for node in res])

    if groups:
        qb = QueryBuilder()
        qb.append(Group,
                  tag='group',
                  project=['*'],
                  filters={'id': {
                      'in': groups
                  }})
        qb.append(Node, tag='node', member_of='group', project=['id'])
        res = qb.dict()

        group_dict.update({
            group['group']['*'].name: group['group']['*'].dbgroup
            for group in res
        })
        node_id_set.update([node['node']['id'] for node in res])

    # The db_groups that correspond to what was searched above
    dbgroups_list = group_dict.values()

    # Getting the nodes that correspond to the ids that were found above
    if len(node_id_set) > 0:
        qb = QueryBuilder()
        qb.append(Node,
                  tag='node',
                  project=['*'],
                  filters={'id': {
                      'in': node_id_set
                  }})
        node_list = [node[0] for node in qb.all()]
    else:
        node_list = list()

    # Check if any of the nodes wasn't found in the database.
    missing_nodes = node_id_set.difference(node.id for node in node_list)
    for node_id in missing_nodes:
        sys.stderr.write(
            'WARNING! Node with pk={} not found, skipping\n'.format(node_id))

    # The dbnodes of the above node list
    dbnode_list = [node.dbnode for node in node_list]

    if computers:
        qb = QueryBuilder()
        qb.append(Computer,
                  tag='comp',
                  project=['*'],
                  filters={'id': {
                      'in': set(computers)
                  }})
        computer_list = [computer[0] for computer in qb.all()]
        missing_computers = set(computers).difference(
            computer.id for computer in computer_list)

        for computer_id in missing_computers:
            sys.stderr.write(
                'WARNING! Computer with pk={} not found, skipping\n'.format(
                    computer_id))
    else:
        computer_list = []

    # The dbcomputers of the above computer list
    dbcomputer_list = [computer.dbcomputer for computer in computer_list]

    what_list = dbnode_list + dbcomputer_list + dbgroups_list
    additional_kwargs = dict()

    if archive_format == 'zip':
        export_function = export_zip
        additional_kwargs.update({'use_compression': True})
    elif archive_format == 'zip-uncompressed':
        export_function = export_zip
        additional_kwargs.update({'use_compression': False})
    elif archive_format == 'tar.gz':
        export_function = export
    else:
        sys.stderr.write('invalid --archive-format value {}\n'.format(
            archive_format))
        sys.exit(1)

    try:
        export_function(what=what_list,
                        also_parents=not no_parents,
                        also_calc_outputs=not no_calc_outputs,
                        outfile=outfile,
                        overwrite=overwrite,
                        **additional_kwargs)
    except IOError as e:
        sys.stderr.write('IOError: {}\n'.format(e))
        sys.exit(1)
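
Compared with example #16 this older variant exports the bare db* objects, and like it selects the archive writer through an if/elif chain. The same selection can be written as a dispatch table, which keeps format names and writer options in one place; a sketch of that alternative, reusing the export, export_zip and sys names already imported in the function above:

# (archive format) -> (export function, extra keyword arguments)
EXPORTERS = {
    'zip': (export_zip, {'use_compression': True}),
    'zip-uncompressed': (export_zip, {'use_compression': False}),
    'tar.gz': (export, {}),
}

try:
    export_function, additional_kwargs = EXPORTERS[archive_format]
except KeyError:
    sys.stderr.write('invalid --archive-format value {}\n'.format(archive_format))
    sys.exit(1)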
Code example #20
File: base.py Project: wsmorgan/aiida_core
class BaseTranslator(object):
    """
    Generic class for translator. It contains the methods
    required to build a related QueryBuilder object
    """

    # A label associated to the present class
    __label__ = None
    # The AiiDA class one-to-one associated to the present class
    _aiida_class = None
    # The string name of the AiiDA class
    _aiida_type = None

    # The string associated to the AiiDA class in the query builder lexicon
    _qb_type = None

    # If True (False) the corresponding AiiDA class has (no) uuid property
    _has_uuid = None

    _result_type = __label__

    _default = _default_projections = []
    _is_qb_initialized = False
    _is_id_query = None
    _total_count = None

    def __init__(self, Class=None, **kwargs):
        """
        Initialise the parameters.
        Create the basic query_help

        keyword Class (default None, in which case it becomes this class): the
        class from which one takes the initial values of the attributes. By
        default it is this class, so that class attributes are translated into
        object attributes. In case of inheritance one can use the same
        constructor but pass the inheriting class to pass its attributes.
        """

        # Assume default class is this class (cannot be done in the
        # definition as it requires self)
        if Class is None:
            Class = self.__class__

        # Assign class parameters to the object
        self.__label__ = Class.__label__
        self._aiida_class = Class._aiida_class
        self._aiida_type = Class._aiida_type
        self._qb_type = Class._qb_type
        self._result_type = Class.__label__

        self._default = Class._default
        self._default_projections = Class._default_projections
        self._is_qb_initialized = Class._is_qb_initialized
        self._is_id_query = Class._is_id_query
        self._total_count = Class._total_count

        # Basic filter (dict) to set the identity of the uuid. None if
        #  no specific node is requested
        self._id_filter = None

        # basic query_help object
        self._query_help = {
            "path": [{
                "type": self._qb_type,
                "label": self.__label__
            }],
            "filters": {},
            "project": {},
            "order_by": {}
        }
        # query_builder object (No initialization)
        self.qb = QueryBuilder()

        self.LIMIT_DEFAULT = kwargs['LIMIT_DEFAULT']

        if 'custom_schema' in kwargs:
            self.custom_schema = kwargs['custom_schema']
        else:
            self.custom_schema = None

    def __repr__(self):
        """
        This function is required for the caching system to be able to compare
        two NodeTranslator objects. Comparison is done on the value returned by
        __repr__
        :return: representation of NodeTranslator objects. Returns an empty
        string because the inputs of self.get_nodes are sufficient to determine
        the identity of two queries.
        """
        return ""

    def get_schema(self):

        # Construct the full class string
        class_string = 'aiida.orm.' + self._aiida_type

        # Load correspondent orm class
        orm_class = get_object_from_string(class_string)

        # Construct the json object to be returned
        basic_schema = orm_class.get_db_columns()
        """
        Determine the API schema (spartially overlapping with the ORM/database one).
        When the ORM is based on django, however, attributes and extras are not colums of the database but are
        nevertheless         valid projections. We add them by hand into the API schema.
        """
        # TODO change the get_db_columns method to include also relationships such as attributes, extras, input,
        # and outputs        in order to have a more complete definition of the schema.

        if self._default_projections == ['**']:
            schema = basic_schema  # No custom schema, take the basic one
        else:

            # Non-schema possible projections (only for nodes when django is backend)
            non_schema_projs = ('attributes', 'extras')
            # Sub-projections of JSON fields (applies to both SQLA and Django)
            non_schema_proj_prefix = ('attributes.', 'extras.')

            schema_key = []
            schema_values = []

            for k in self._default_projections:
                if k in basic_schema.keys():
                    schema_key.append(k)
                    schema_values.append(basic_schema[k])
                elif k in non_schema_projs:
                    # Catches 'attributes' and 'extras'
                    schema_key.append(k)
                    value = dict(type=dict, is_foreign_key=False)
                    schema_values.append(value)
                elif k.startswith(non_schema_proj_prefix):
                    # Catches 'attributes.<key>' and 'extras.<key>'
                    schema_key.append(k)
                    value = dict(type=None, is_foreign_key=False)
                    schema_values.append(value)

            schema = dict(zip(schema_key, schema_values))

        def table2resource(table_name):
            """
            Convert the related_tablevalues to the RESTAPI resources
            (orm class/db table ==> RESTapi resource)

            :param table_name (str): name of the table (in SQLA is __tablename__)
            :return: resource_name (str): name of the API resource
            """
            # TODO Consider ways to make this function backend independent (one
            # idea would be to go from table name to aiida class name which is
            # unique)
            if BACKEND == BACKEND_DJANGO:
                (spam, resource_name) = issingular(table_name[2:].lower())
            elif BACKEND == BACKEND_SQLA:
                (spam, resource_name) = issingular(table_name[5:])
            elif BACKEND is None:
                raise ConfigurationError("settings.BACKEND has not been set.\n"
                                         "Hint: Have you called "
                                         "aiida.load_dbenv?")
            else:
                raise ConfigurationError(
                    "Unknown settings.BACKEND: {}".format(BACKEND))
            return resource_name

        for k, v in schema.iteritems():

            # Add custom fields to the column dictionaries
            if 'fields' in self.custom_schema:
                if k in self.custom_schema['fields'].keys():
                    schema[k].update(self.custom_schema['fields'][k])

            # Convert python types values into strings
            schema[k]['type'] = str(schema[k]['type'])[7:-2]

            # Construct the 'related resource' field from the 'related_table'
            # field
            if v['is_foreign_key']:
                schema[k]['related_resource'] = table2resource(
                    schema[k].pop('related_table'))

        # TODO Construct the ordering (all these things have to be moved in matcloud_backend)
        if self._default_projections != ['**']:
            ordering = self._default_projections
        else:
            # ordering is arbitrary if not set explicitly
            ordering = schema.keys()

        return dict(fields=schema, ordering=ordering)

    def init_qb(self):
        """
        Initialize query builder object by means of _query_help
        """
        self.qb.__init__(**self._query_help)
        self._is_qb_initialized = True

    def count(self):
        """
        Count the number of rows returned by the query and set total_count
        """
        if self._is_qb_initialized:
            self._total_count = self.qb.count()
        else:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

            # def caching_method(self):
            #     """
            #     class method for caching. It is a wrapper of the
            # flask_cache memoize
            #     method. To be used as a decorator
            #     :return: the flask_cache memoize method with the timeout kwarg
            #     corrispondent to the class
            #     """
            #     return cache.memoize()
            #

            #    @cache.memoize(timeout=CACHING_TIMEOUTS[self.__label__])

    def get_total_count(self):
        """
        Returns the number of rows of the query
        :return: total_count
        """
        ## Count the results if needed
        if not self._total_count:
            self.count()

        return self._total_count

    def set_filters(self, filters=None):
        """
        Add filters in query_help.

        :param filters: a dictionary where keys are the tag names
         given in the path in query_help and their values are the dictionary
         of filters you want to add for that tag name. Format for the filters
         dictionary:
         filters = {
                    "tag1" : {k1:v1, k2:v2},
                    "tag2" : {k1:v1, k2:v2},
                  }
        :return: query_help dict including filters if any.
        """
        if filters is None:
            filters = {}

        if isinstance(filters, dict):
            if len(filters) > 0:
                for tag, tag_filters in filters.iteritems():
                    if len(tag_filters) > 0 and isinstance(tag_filters, dict):
                        self._query_help["filters"][tag] = {}
                        for filter_key, filter_value in tag_filters.iteritems():
                            if filter_key == "pk":
                                filter_key = pk_dbsynonym
                            self._query_help["filters"][tag][filter_key] \
                                = filter_value
        else:
            raise InputValidationError(
                "Pass data in dictionary format where "
                "keys are the tag names given in the "
                "path in query_help and their values "
                "are the dictionary of filters you want "
                "to add for that tag name.")

    def get_default_projections(self):
        """
        method to get default projections of the node
        :return: self._default_projections
        """
        return self._default_projections

    def set_default_projections(self):
        """
        It calls the set_projections() methods internally to add the
        default projections in query_help

        :return: None
        """
        self.set_projections({self.__label__: self._default_projections})

    def set_projections(self, projections):
        """
        add the projections in query_help

        :param projections: it is a dictionary where keys are the tag names
         given in the path in query_help and values are the list of the names
         you want to project in the final output
        :return: updated query_help with projections
        """
        if isinstance(projections, dict):
            if len(projections) > 0:
                for project_key, project_list in projections.iteritems():
                    self._query_help["project"][project_key] = project_list
        else:
            raise InputValidationError("Pass data in dictionary format where "
                                       "keys are the tag names given in the "
                                       "path in query_help and values are the "
                                       "list of the names you want to project "
                                       "in the final output")

    def set_order(self, orders):
        """
        Add order_by clause in query_help
        :param orders: dictionary of orders you want to apply on final
        results
        :return: None or exception if any.
        """
        ## Validate input
        if not isinstance(orders, dict):
            raise InputValidationError("orders has to be a dictionary "
                                       "compatible with the 'order_by' "
                                       "section of the query_help")

        ## Auxiliary function to build the ordering criterion
        def def_order(columns):
            """
            Takes a list of signed column names ex. ['id', '-ctime',
            '+mtime']
            and transforms it in a order_by compatible dictionary
            :param columns: (list of strings)
            :return: a dictionary
            """
            order_dict = {}
            for column in columns:
                if column[0] == '-':
                    order_dict[column[1:]] = 'desc'
                elif column[0] == '+':
                    order_dict[column[1:]] = 'asc'
                else:
                    order_dict[column] = 'asc'
            if 'pk' in order_dict:
                order_dict[pk_dbsynonym] = order_dict.pop('pk')
            return order_dict

        ## Assign orderby field query_help
        for tag, columns in orders.iteritems():
            self._query_help['order_by'][tag] = def_order(columns)
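
    # Example of the transformation performed by def_order, assuming
    # pk_dbsynonym is 'id':
    #   ['-ctime', '+label', 'pk']  ->  {'ctime': 'desc', 'label': 'asc',
    #                                    'id': 'asc'}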

    def set_query(self, filters=None, orders=None, projections=None, id=None):
        """
        Adds filters, default projections, order specs to the query_help,
        and initializes the qb object

        :param filters: dictionary with the filters
        :param orders: dictionary with the order for each tag
        :param orders: dictionary with the projections
        :param id (integer): id of a specific node
        """

        tagged_filters = {}

        ## Check if filters are well defined and construct an ad-hoc filter
        # for id_query
        if id is not None:
            self._is_id_query = True
            if self._result_type == self.__label__ and filters:
                raise RestInputValidationError("selecting a specific id does "
                                               "not allow specifying filters")

            try:
                self._check_id_validity(id)
            except RestValidationError as e:
                raise RestValidationError(e.message)
            else:
                tagged_filters[self.__label__] = self._id_filter
                if self._result_type != self.__label__:
                    tagged_filters[self._result_type] = filters
        else:
            tagged_filters[self.__label__] = filters

        ## Add filters
        self.set_filters(tagged_filters)

        ## Add projections
        if projections is None:
            self.set_default_projections()
        else:
            tagged_projections = {self._result_type: projections}
            self.set_projections(tagged_projections)

        ## Add order_by
        if orders is not None:
            tagged_orders = {self._result_type: orders}
            self.set_order(tagged_orders)

        ## Initialize the query_object
        self.init_qb()
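
    # Usage sketch for set_query (values are hypothetical). Either pass
    # filters/orders/projections, or query a single node by id:
    #   translator.set_query(filters={'id': {'>': 12}},
    #                        orders=['-ctime'],
    #                        projections=['id', 'ctime'])
    #   translator.set_query(filters={}, id=123)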

    def get_query_help(self):
        """
        :return: return QB json dictionary
        """
        return self._query_help

    def set_limit_offset(self, limit=None, offset=None):
        """
        sets limits and offset directly to the query_builder object

        :param limit:
        :param offset:
        :return:
        """

        ## mandatory params
        # none

        ## non-mandatory params
        if limit is not None:
            try:
                limit = int(limit)
            except ValueError:
                raise InputValidationError("Limit value must be an integer")
            if limit > self.LIMIT_DEFAULT:
                raise RestValidationError("Limit and perpage cannot be bigger "
                                          "than {}".format(self.LIMIT_DEFAULT))
        else:
            limit = self.LIMIT_DEFAULT

        if offset is not None:
            try:
                offset = int(offset)
            except ValueError:
                raise InputValidationError("Offset value must be an "
                                           "integer")

        if self._is_qb_initialized:
            # at this point limit is always set (LIMIT_DEFAULT at worst)
            self.qb.limit(limit)
            if offset is not None:
                self.qb.offset(offset)
        else:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

    def get_formatted_result(self, label):
        """
        Runs the query and retrieves results tagged as "label"
        :param label (string): the tag of the results to be extracted out of
        the query rows.
        :return: a list of the query results
        """

        if not self._is_qb_initialized:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

        results = []
        if self._total_count > 0:
            results = [res[label] for res in self.qb.dict()]

        # TODO think how to make it less hardcoded
        if self._result_type == 'input_of':
            return {'inputs': results}
        elif self._result_type == 'output_of':
            return {'outputs': results}
        else:
            return {self.__label__: results}
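
    # Depending on self._result_type, the dictionary returned above is of
    # the form {'inputs': [...]}, {'outputs': [...]} or
    # {<self.__label__>: [...]}.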

    def get_results(self):
        """
        Returns either list of nodes or details of single node from database

        :return: either list of nodes or details of single node
        from database
        """

        ## Check whether the querybuilder object has been initialized
        if not self._is_qb_initialized:
            raise InvalidOperation("query builder object has not been "
                                   "initialized.")

        ## Count the total number of rows returned by the query (if not
        # already done)
        if self._total_count is None:
            self.count()

        ## Retrieve data
        data = self.get_formatted_result(self._result_type)
        return data
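
    # End-to-end sketch of the intended call sequence (hypothetical
    # translator instance and filter values):
    #   translator.set_query(filters={'id': {'>': 12}}, orders=['-ctime'])
    #   translator.set_limit_offset(limit=20, offset=0)
    #   data = translator.get_results()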

    def _check_id_validity(self, id):
        """
        Checks whether a id full id or id starting pattern) corresponds to
         an object of the expected type,
        whenever type is a valid column of the database (ex. for nodes,
        but not for users)
        
        :param id: id, or id starting pattern
        
        :return: True if id valid (invalid). If True, sets the
            id filter attribute correctly
            
        :raise: RestValidationError if No node is found or id pattern does
        not identify a unique node
        """
        from aiida.common.exceptions import MultipleObjectsError, NotExistent

        from aiida.orm.utils import create_node_id_qb

        if self._has_uuid:

            # For consistency, check that id is a string
            if not isinstance(id, (str, unicode)):
                raise RestValidationError('parameter id has to be a string')

            qb = create_node_id_qb(uuid=id, parent_class=self._aiida_class)
        else:

            # Similarly, check that id is an integer
            if not isinstance(id, int):
                raise RestValidationError('parameter id has to be an integer')

            qb = create_node_id_qb(pk=id, parent_class=self._aiida_class)

        # project only the pk
        qb.add_projection('node', ['id'])
        # for efficiency, don't go further than two results
        qb.limit(2)

        try:
            pk = qb.one()[0]
        except MultipleObjectsError:
            raise RestValidationError("More than one node found."
                                      " Provide longer starting pattern"
                                      " for id.")
        except NotExistent:
            raise RestValidationError("either no object's id starts"
                                      " with '{}' or the corresponding object"
                                      " is not of type aiida.orm.{}".format(
                                          id, self._aiida_type))
        else:
            # create a permanent filter
            self._id_filter = {'id': {'==': pk}}
            return True
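
    # On success, the method pins subsequent queries to the resolved pk:
    # e.g. a uuid pattern 'a1b2' matching a single node of pk 42 (values
    # are illustrative) leaves self._id_filter == {'id': {'==': 42}}.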