Example #1
0
    def _write(self, atoms, key_value_pairs, data):
        """Insert a row into the JSON database file, or replace one.

        atoms: Atoms or AtomsRow
            If this is an ``AtomsRow`` whose ``unique_id`` is already stored,
            that entry is replaced and keeps its id.  Otherwise the row is
            appended under the next free id.  Plain Atoms are first wrapped
            in a new ``AtomsRow`` stamped with ``ctime`` and ``$USER``.
        key_value_pairs: dict
            Stored as ``key_value_pairs``; when empty, the row's own
            key-value pairs are used instead.
        data: dict
            Stored as ``data``; when empty, the row's own data is used.

        Returns the id the row was stored under.
        """
        Database._write(self, atoms, key_value_pairs, data)

        # Start from an empty database; replaced by the file contents when
        # an existing, parsable file is found.
        bigdct, ids, nextid = {}, [], 1
        file_exists = (isinstance(self.filename, basestring)
                       and os.path.isfile(self.filename))
        if file_exists:
            try:
                bigdct, ids, nextid = self._read_json()
            except (SyntaxError, ValueError):
                # An unparsable file is treated as empty and overwritten.
                pass

        if isinstance(atoms, AtomsRow):
            row = atoms
            wanted = row.unique_id
            # Reuse the id of a stored entry with the same unique_id.
            id = next((candidate for candidate in ids
                       if bigdct[candidate]['unique_id'] == wanted), None)
            mtime = now()
        else:
            row = AtomsRow(atoms)
            row.ctime = mtime = now()
            row.user = os.getenv('USER')
            id = None

        # Copy public attributes, skipping private ones, key-value pairs
        # (stored separately below) and the id itself.
        dct = {key: row[key] for key in row.__dict__
               if not (key[0] == '_' or key in row._keys or key == 'id')}
        dct['mtime'] = mtime

        optional = (
            ('key_value_pairs', key_value_pairs or row.key_value_pairs),
            ('data', data or row.get('data')),
            ('constraints', row.get('constraints')),
        )
        for name, value in optional:
            if value:
                dct[name] = value

        if id is None:
            id = nextid
            ids.append(id)
            nextid += 1

        bigdct[id] = dct
        self._write_json(bigdct, ids, nextid)
        return id
Example #2
0
    def _write(self, atoms, key_value_pairs, data):
        """Write *atoms* to the JSON database file and return its id.

        An ``AtomsRow`` whose ``unique_id`` matches a stored entry replaces
        that entry in place.  Anything else becomes a new entry under the
        next free id; plain Atoms get ``ctime`` and ``user`` filled in.
        Empty *key_value_pairs* / *data* fall back to the row's own values.
        """
        Database._write(self, atoms, key_value_pairs, data)

        entries = {}
        order = []
        next_free = 1
        if isinstance(self.filename, basestring):
            if os.path.isfile(self.filename):
                try:
                    entries, order, next_free = self._read_json()
                except (SyntaxError, ValueError):
                    pass  # unparsable file: start from an empty database

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime = now()
            row.user = os.getenv('USER')
            row_id = None
        else:
            row = atoms
            target = row.unique_id
            row_id = None
            for existing in order:
                if entries[existing]['unique_id'] == target:
                    row_id = existing
                    break
            mtime = now()

        record = {}
        for attr in row.__dict__:
            skip = attr[0] == '_' or attr in row._keys or attr == 'id'
            if not skip:
                record[attr] = row[attr]
        record['mtime'] = mtime

        pairs = key_value_pairs or row.key_value_pairs
        if pairs:
            record['key_value_pairs'] = pairs
        payload = data or row.get('data')
        if payload:
            record['data'] = payload
        constraints = row.get('constraints')
        if constraints:
            record['constraints'] = constraints

        if row_id is None:
            row_id = next_free
            order.append(row_id)
            next_free += 1

        entries[row_id] = record
        self._write_json(entries, order, next_free)
        return row_id
Example #3
0
File: jsondb.py Project: jboes/ase
 def _select(self, keys, cmps, explain=False, verbosity=0,
             limit=None, offset=0, sort=None):
     """Yield rows of the JSON file that match *keys* and *cmps*.

     keys: list of str
         Names that must be present in a row for it to match.
     cmps: list of (key, op, value)
         Comparisons; *op* is a string looked up in ``ops``.  An integer
         key compares the number of atoms with that atomic number; the
         ``'pbc'`` key is compared as a ``'TTF'``-style string.
     explain: bool
         If true, yield a single pseudo query-plan dict and stop.
     verbosity: int
         Accepted for interface compatibility; not used here.
     limit, offset: int
         Skip the first *offset* matches, then yield at most *limit*
         rows.  A falsy *limit* means "no limit".
     sort: str or None
         Key to sort on; a leading ``'-'`` sorts in descending order.
     """
     if explain:
         yield {'explain': (0, 0, 0, 'scan table')}
         return

     if sort:
         if sort[0] == '-':
             reverse = True
             sort = sort[1:]
         else:
             reverse = False

         def f(row):
             return row[sort]

         # Requiring *sort* as a key guarantees row[sort] exists.
         rows = sorted(self._select(keys + [sort], cmps),
                       key=f, reverse=reverse)
         # Apply the offset even without a limit, as the unsorted path
         # below does; previously it was silently ignored here.
         if limit:
             rows = rows[offset:offset + limit]
         else:
             rows = rows[offset:]
         for row in rows:
             yield row
         return

     try:
         bigdct, ids, _ = self._read_json()
     except IOError:
         return

     if not limit:
         # n - offset is never below -offset, so this value is never
         # reached: it means "no limit".
         limit = -offset - 1

     cmps = [(key, ops[op], val) for key, op, val in cmps]
     n = 0
     for id in ids:
         if n - offset == limit:
             return
         row = AtomsRow(bigdct[id])
         row.id = id
         for key in keys:
             if key not in row:
                 break
         else:
             for key, op, val in cmps:
                 if isinstance(key, int):
                     value = np.equal(row.numbers, key).sum()
                 else:
                     value = row.get(key)
                     if key == 'pbc':
                         assert op in [ops['='], ops['!=']]
                         # int(): NumPy bools are not valid string indices.
                         value = ''.join('FT'[int(x)] for x in value)
                 if value is None or not op(value, val):
                     break
             else:
                 if n >= offset:
                     yield row
                 n += 1
Example #4
0
def Json2Atoms(jsonstring):
    """Decode a JSON string into an Atoms object.

    The decoded dictionary is wrapped in an ``AtomsRow`` and converted with
    ``toatoms``; no calculator is attached, but the additional row
    information is included.
    """
    from ase.io.jsonio import decode
    from ase.db.row import AtomsRow

    options = dict(attach_calculator=False, add_additional_information=True)
    return AtomsRow(decode(jsonstring)).toatoms(**options)
Example #5
0
    def _write(self, atoms, key_value_pairs, data, id):
        """Store *atoms* in the JSON database file.

        atoms: Atoms or AtomsRow
            Plain Atoms are wrapped in a new ``AtomsRow`` stamped with the
            current time and ``$USER``.
        key_value_pairs, data: dict
            Written to the entry only when non-empty.
        id: int or None
            None appends under the next free id; otherwise the entry with
            this id is replaced and must already exist.

        Returns the id used.
        """
        Database._write(self, atoms, key_value_pairs, data)

        bigdct, ids, nextid = {}, [], 1
        if isinstance(self.filename, str) and os.path.isfile(self.filename):
            try:
                bigdct, ids, nextid = self._read_json()
            except (SyntaxError, ValueError):
                pass  # unparsable file: behave as if the database were empty

        mtime = now()

        if isinstance(atoms, AtomsRow):
            row = atoms
        else:
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')

        # Public attributes only: no private ones, key-value pairs or id.
        dct = {name: row[name] for name in row.__dict__
               if not (name[0] == '_' or name in row._keys or name == 'id')}
        dct['mtime'] = mtime

        constraints = row.get('constraints')
        for name, value in (('key_value_pairs', key_value_pairs),
                            ('data', data),
                            ('constraints', constraints)):
            if value:
                dct[name] = value

        if id is None:
            id = nextid
            ids.append(id)
            nextid += 1
        else:
            # Updates may only target an existing entry.
            assert id in bigdct

        bigdct[id] = dct
        self._write_json(bigdct, ids, nextid)
        return id
Example #6
0
 def _get_row(self, id):
     """Return the stored entry *id* as an ``AtomsRow`` with ``id`` set.

     When *id* is None the file must hold exactly one entry, which is
     returned.
     """
     table, all_ids, _nextid = self._read_json()
     if id is None:
         # An implicit id is only unambiguous for a single-entry file.
         assert len(all_ids) == 1
         id = all_ids[0]
     record = table[id]
     record['id'] = id
     return AtomsRow(record)
Example #7
0
def get_traj_str(filename):
    """Read a structure file and return it as a JSON object string.

    The structure is loaded with ``read_ase`` and wrapped in an
    ``AtomsRow``.  Its public attributes (excluding key-value pairs and
    ``id``), plus any constraints, are encoded one by one and joined in
    sorted key order.
    """
    from ase.db.row import AtomsRow
    from ase.io.jsonio import encode
    row = AtomsRow(read_ase(filename))

    fields = {}
    for name in row.__dict__:
        if name[0] != '_' and name not in row._keys and name != 'id':
            fields[name] = row[name]

    constraints = row.get('constraints')
    if constraints:
        fields['constraints'] = constraints

    members = ['"{0}": {1}'.format(name, encode(fields[name]))
               for name in sorted(fields)]
    txt = ','.join(members)
    return '{{{0}}}'.format(txt)
Example #8
0
    def _convert_tuple_to_row(self, values):
        """Build an ``AtomsRow`` from one tuple of the ``systems`` table.

        *values* is normalized with ``self._old2new`` and then read by
        position.  Columns 0-7 are always copied.  The nullable columns
        8-24 are copied only when not None.  Key-value pairs are copied only
        when column 25 is not ``'{}'``, and data only when column 26 exists
        and is not ``'null'``.
        """
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = {'id': values[0],
               'unique_id': values[1],
               'ctime': values[2],
               'mtime': values[3],
               'user': values[4],
               'numbers': deblob(values[5], np.int32),
               'positions': deblob(values[6], shape=(-1, 3)),
               'cell': deblob(values[7], shape=(3, 3))}

        def as_is(value):
            return value

        def ints(value):
            return deblob(value, np.int32)

        def vectors(value):
            return deblob(value, shape=(-1, 3))

        def pbc_flags(value):
            # Bit mask -> three booleans.
            return (value & np.array([1, 2, 4])).astype(bool)

        # (column index, dictionary key, converter) for nullable columns.
        optional_columns = (
            (8, 'pbc', pbc_flags),
            (9, 'initial_magmoms', deblob),
            (10, 'initial_charges', deblob),
            (11, 'masses', deblob),
            (12, 'tags', ints),
            (13, 'momenta', vectors),
            (14, 'constraints', as_is),
            (15, 'calculator', as_is),
            (16, 'calculator_parameters', decode),
            (17, 'energy', as_is),
            (18, 'free_energy', as_is),
            (19, 'forces', vectors),
            (20, 'stress', deblob),
            (21, 'dipole', deblob),
            (22, 'magmoms', deblob),
            (23, 'magmom', as_is),
            (24, 'charges', deblob),
        )
        for column, key, convert in optional_columns:
            raw = values[column]
            if raw is not None:
                dct[key] = convert(raw)

        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            dct['data'] = decode(values[26])

        return AtomsRow(dct)
Example #9
0
def atoms2json(structure, additional_information=None):
    """Serialize an ASE Atoms object to a JSON string.

    The layout follows an ASE database row: the attributes of an
    ``AtomsRow`` built from *structure* (without the private ``_keys``,
    ``_data`` and ``_constraints`` entries), ``ctime``, ``mtime``,
    ``user``, ``key_value_pairs`` and, when present, ``constraints``.

    Parameters:

    structure: Atoms
        The structure to serialize.
    additional_information: dict or None
        Stored as ``key_value_pairs``; an empty dict when not given.

    Returns the JSON text, with keys sorted.
    """
    import json
    import os

    from ase.db.core import now
    from ase.db.row import AtomsRow
    from ase.io.jsonio import MyEncoder as AseJsonEncoder

    row = AtomsRow(structure)  # this is what ASE would store in its DB
    # AtomsRow has a ctime attribute but no mtime; mtime is added below.
    row.ctime = mtime = now()
    row.user = os.getenv("USER")

    dct = row.__dict__.copy()
    # Drop private bookkeeping entries.  pop() instead of del so a missing
    # attribute does not raise KeyError.
    for private in ("_keys", "_data", "_constraints"):
        dct.pop(private, None)

    # The constraints live in the private _constraints entry dropped above.
    # Store them under the public "constraints" key that AtomsRow reads
    # back, so constrained structures survive the round trip.
    constraints = row.get("constraints")
    if constraints:
        dct["constraints"] = constraints

    dct["mtime"] = mtime
    dct["key_value_pairs"] = additional_information or {}

    return json.dumps(dct, sort_keys=True, cls=AseJsonEncoder)
Example #10
0
 def _read(self):
     """Load every system from the JSON file, once.

     For each stored entry, ``key_value_pairs['name']`` becomes the system
     name.  The entry is converted to Atoms in ``self._systems[name]``, and
     the remaining key-value pairs (string keys) go to
     ``self._data[name]``.  Does nothing if names were already loaded.
     """
     if self._names:
         return
     content = read_json(self.filename)
     for entry_id in content['ids']:
         entry = content[entry_id]
         pairs = entry['key_value_pairs']
         system_name = str(pairs['name'])
         self._names.append(system_name)
         self._systems[system_name] = AtomsRow(entry).toatoms()
         del pairs['name']
         self._data[system_name] = {str(k): v for k, v in pairs.items()}
Example #11
0
    def fromdict(self, dct):
        """Restore calculator from a :func:`~ase.calculators.vasp.Vasp2.asdict`
        dictionary.

        Parameters:

        dct: Dictionary
            The dictionary which is used to restore the calculator state.
            Recognized keys are ``'vasp_version'``, ``'inputs'`` (keyword
            arguments passed to ``set``), ``'atoms'`` (a row dictionary
            converted to Atoms via ``AtomsRow``) and ``'results'`` (merged
            into ``self.results``).  Missing keys are skipped.
        """
        if 'vasp_version' in dct:
            self.version = dct['vasp_version']
        if 'inputs' in dct:
            self.set(**dct['inputs'])
            self._store_param_state()
        if 'atoms' in dct:
            from ase.db.row import AtomsRow
            atoms = AtomsRow(dct['atoms']).toatoms()
            self.set_atoms(atoms)
        if 'results' in dct:
            self.results.update(dct['results'])
Example #12
0
    def _convert_tuple_to_row(self, values):
        """Turn one ``systems`` row tuple into an ``AtomsRow``.

        After normalizing *values* with ``self._old2new``:

        * the mandatory columns 0-7 are always copied;
        * the nullable columns 8-24 are copied only when not None;
        * key-value pairs are copied only when the column is not ``'{}'``;
        * data (decoded with ``lazy=True``) is copied only when column 26
          exists and is not ``'null'``.

        Entries read from the external tables are merged in last, keyed by
        table name.
        """
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = dict(id=values[0],
                   unique_id=values[1],
                   ctime=values[2],
                   mtime=values[3],
                   user=values[4],
                   numbers=deblob(values[5], np.int32),
                   positions=deblob(values[6], shape=(-1, 3)),
                   cell=deblob(values[7], shape=(3, 3)))

        # Nullable columns 8..24 in column order.  Keys without an entry in
        # ``converters`` are stored exactly as read.
        nullable = ('pbc', 'initial_magmoms', 'initial_charges', 'masses',
                    'tags', 'momenta', 'constraints', 'calculator',
                    'calculator_parameters', 'energy', 'free_energy',
                    'forces', 'stress', 'dipole', 'magmoms', 'magmom',
                    'charges')
        converters = {
            'pbc': lambda raw: (raw & np.array([1, 2, 4])).astype(bool),
            'initial_magmoms': deblob,
            'initial_charges': deblob,
            'masses': deblob,
            'tags': lambda raw: deblob(raw, np.int32),
            'momenta': lambda raw: deblob(raw, shape=(-1, 3)),
            'calculator_parameters': decode,
            'forces': lambda raw: deblob(raw, shape=(-1, 3)),
            'stress': deblob,
            'dipole': deblob,
            'magmoms': deblob,
            'charges': deblob,
        }
        for column, key in enumerate(nullable, start=8):
            raw = values[column]
            if raw is None:
                continue
            convert = converters.get(key)
            dct[key] = raw if convert is None else convert(raw)

        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            dct['data'] = decode(values[26], lazy=True)

        # Merge the per-row entries of every external table.
        dct.update({tab: self._read_external_table(tab, dct["id"])
                    for tab in self._get_external_table_names()})
        return AtomsRow(dct)
Example #13
0
    def _write(self, atoms, key_value_pairs, data, id):
        """Insert *atoms* into the ``systems`` table, or update row *id*.

        Also fills the ``species``, ``keys``, ``text_key_values`` and
        ``number_key_values`` tables and any external tables.

        atoms: Atoms or AtomsRow
            Plain Atoms are wrapped in a new ``AtomsRow`` stamped with the
            current time and ``$USER``.  For an ``AtomsRow``, non-empty
            attributes named after existing external tables are written to
            those tables.
        key_value_pairs: dict
            An ``"external_tables"`` entry is popped from this dict (so the
            dict is mutated) and written to external tables instead.
        data: dict, str or bytes
            Falls back to the row's own data when empty.
        id: int or None
            None inserts a new row.  Otherwise the row's entries in the
            keys/key-value/species tables are deleted and the ``systems``
            row is overwritten.

        Returns the id of the written row.
        """
        ext_tables = key_value_pairs.pop("external_tables", {})
        Database._write(self, atoms, key_value_pairs, data)

        mtime = now()

        encode = self.encode
        blob = self.blob

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms
            # Extract the external tables from AtomsRow
            names = self._get_external_table_names()
            for name in names:
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        # For a new row with neither explicit key-value pairs nor external
        # tables, fall back to the row's own key-value pairs.
        if not id and not key_value_pairs and not ext_tables:
            key_value_pairs = row.key_value_pairs

        # Ensure each external table exists, with a column type guessed
        # from the values to be stored.
        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        # A list of constraints is JSON-encoded.  Other non-empty values
        # (presumably already encoded) are stored as-is; empty -> NULL.
        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None

        # NOTE(review): the order of this tuple must match the column order
        # of the systems table.  pbc is packed into a 3-bit integer mask.
        values = (row.unique_id, row.ctime, mtime, row.user, blob(row.numbers),
                  blob(row.positions), blob(row.cell),
                  int(np.dot(row.pbc,
                             [1, 2, 4])), blob(row.get('initial_magmoms')),
                  blob(row.get('initial_charges')), blob(row.get('masses')),
                  blob(row.get('tags')), blob(row.get('momenta')), constraints)

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data

        with self.managed_connection() as con:
            # Schema version 9 and later encode data with binary=True.
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)

            values += (row.get('energy'), row.get('free_energy'),
                       blob(row.get('forces')), blob(row.get('stress')),
                       blob(row.get('dipole')), blob(row.get('magmoms')),
                       row.get('magmom'), blob(row.get('charges')),
                       encode(key_value_pairs), data, len(row.numbers),
                       float_if_not_none(row.get('fmax')),
                       float_if_not_none(row.get('smax')),
                       float_if_not_none(row.get('volume')), float(row.mass),
                       float(row.charge))

            cur = con.cursor()
            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute('INSERT INTO systems VALUES ({})'.format(q),
                            values)
                id = self.get_last_id(cur)
            else:
                # Updating: clear the old index rows; they are re-inserted
                # below from the new values.
                self._delete(cur, [id], [
                    'keys', 'text_key_values', 'number_key_values', 'species'
                ])
                # columnnames[0] is presumably the id column, which is not
                # updated.
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute('UPDATE systems SET {} WHERE id=?'.format(q),
                            values + (id, ))

            count = row.count_atoms()
            if count:
                species = [(atomic_numbers[symbol], n, id)
                           for symbol, n in count.items()]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                                species)

            # Real numbers (incl. bools) are indexed as floats; every other
            # value must be a string.
            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            # Insert entries in the valid tables
            # NOTE(review): this sets 'id' in the caller's/row's table dicts.
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(cur,
                                               name=tabname,
                                               entries=ext_tables[tabname])

        return id
Example #14
0
    def _write(self, atoms, key_value_pairs, data, id):
        """Insert *atoms* into the ``systems`` table, or update row *id*.

        Also fills the ``species``, ``keys``, ``text_key_values`` and
        ``number_key_values`` tables.

        atoms: Atoms or AtomsRow
            Plain Atoms are wrapped in a new ``AtomsRow`` stamped with the
            current time and ``$USER``.
        key_value_pairs: dict
            For a new row, an empty dict falls back to the row's own
            key-value pairs.
        data: dict or str
            Falls back to the row's own data when empty; non-strings are
            encoded.
        id: int or None
            None inserts a new row.  Otherwise the row's entries in the
            keys/key-value/species tables are deleted and the ``systems``
            row is overwritten.

        Returns the id of the written row.  If ``self.connection`` is None,
        a temporary connection is opened, committed on success and closed
        in all cases.  Previously it leaked when an exception was raised.
        """
        Database._write(self, atoms, key_value_pairs, data)
        encode = self.encode

        con = self.connection or self._connect()
        try:
            self._initialize(con)
            cur = con.cursor()

            mtime = now()

            blob = self.blob

            if not isinstance(atoms, AtomsRow):
                row = AtomsRow(atoms)
                row.ctime = mtime
                row.user = os.getenv('USER')
            else:
                row = atoms

            if id:
                # Updating: clear the old index rows; re-inserted below.
                self._delete(cur, [id], ['keys', 'text_key_values',
                                         'number_key_values', 'species'])
            elif not key_value_pairs:
                key_value_pairs = row.key_value_pairs

            # A list of constraints is JSON-encoded.  Other non-empty values
            # (presumably already encoded) are kept; empty -> NULL.
            constraints = row._constraints
            if constraints:
                if isinstance(constraints, list):
                    constraints = encode(constraints)
            else:
                constraints = None

            # NOTE(review): tuple order must match the systems table columns.
            # pbc is packed into a 3-bit integer mask.
            values = (row.unique_id,
                      row.ctime,
                      mtime,
                      row.user,
                      blob(row.numbers),
                      blob(row.positions),
                      blob(row.cell),
                      int(np.dot(row.pbc, [1, 2, 4])),
                      blob(row.get('initial_magmoms')),
                      blob(row.get('initial_charges')),
                      blob(row.get('masses')),
                      blob(row.get('tags')),
                      blob(row.get('momenta')),
                      constraints)

            if 'calculator' in row:
                values += (row.calculator, encode(row.calculator_parameters))
            else:
                values += (None, None)

            if not data:
                data = row._data
            if not isinstance(data, basestring):
                data = encode(data)

            values += (row.get('energy'),
                       row.get('free_energy'),
                       blob(row.get('forces')),
                       blob(row.get('stress')),
                       blob(row.get('dipole')),
                       blob(row.get('magmoms')),
                       row.get('magmom'),
                       blob(row.get('charges')),
                       encode(key_value_pairs),
                       data,
                       len(row.numbers),
                       float_if_not_none(row.get('fmax')),
                       float_if_not_none(row.get('smax')),
                       float_if_not_none(row.get('volume')),
                       float(row.mass),
                       float(row.charge))

            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute('INSERT INTO systems VALUES ({})'.format(q),
                            values)
                id = self.get_last_id(cur)
            else:
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute('UPDATE systems SET {} WHERE id=?'.format(q),
                            values + (id,))

            count = row.count_atoms()
            if count:
                species = [(atomic_numbers[symbol], n, id)
                           for symbol, n in count.items()]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                                species)

            # Real numbers (incl. bools) are indexed as floats; every other
            # value must be a string.
            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, basestring)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            if self.connection is None:
                con.commit()
        finally:
            # Close only a connection opened here.  A caller-managed
            # self.connection stays open.  Nothing is committed on error.
            if self.connection is None:
                con.close()

        return id
Example #15
0
    def _write(self, atoms, key_value_pairs, data):
        """Insert *atoms* into the ``systems`` table or update a stored row.

        atoms: Atoms or AtomsRow
            An ``AtomsRow`` whose ``unique_id`` is already stored updates
            that row.  Otherwise a new row is inserted; plain Atoms get
            ``ctime`` and ``$USER`` filled in.
        key_value_pairs: dict or None
            None falls back to the row's own key-value pairs.  Real-number
            values (including NumPy scalars and bools) are indexed as
            numbers; all other values must be strings.
        data: dict or str
            Falls back to the row's own data when empty.

        Returns the id of the written row.  If ``self.connection`` is None,
        a temporary connection is opened, committed on success and closed
        in all cases.
        """
        import numbers

        Database._write(self, atoms, key_value_pairs, data)

        con = self.connection or self._connect()
        try:
            self._initialize(con)
            cur = con.cursor()

            id = None

            if not isinstance(atoms, AtomsRow):
                row = AtomsRow(atoms)
                row.ctime = mtime = now()
                row.user = os.getenv('USER')
            else:
                row = atoms

                cur.execute('SELECT id FROM systems WHERE unique_id=?',
                            (row.unique_id,))
                results = cur.fetchall()
                if results:
                    id = results[0][0]
                    # Species rows are not deleted here; they are only
                    # inserted for new rows below.
                    self._delete(cur, [id], ['keys', 'text_key_values',
                                             'number_key_values'])
                mtime = now()

            # A list of constraints is JSON-encoded.  Other non-empty values
            # (presumably already encoded) are kept; empty -> NULL.
            constraints = row._constraints
            if constraints:
                if isinstance(constraints, list):
                    constraints = encode(constraints)
            else:
                constraints = None

            # NOTE(review): tuple order must match the systems table columns.
            # pbc is packed into a 3-bit integer mask.
            values = (row.unique_id,
                      row.ctime,
                      mtime,
                      row.user,
                      blob(row.numbers),
                      blob(row.positions),
                      blob(row.cell),
                      int(np.dot(row.pbc, [1, 2, 4])),
                      blob(row.get('initial_magmoms')),
                      blob(row.get('initial_charges')),
                      blob(row.get('masses')),
                      blob(row.get('tags')),
                      blob(row.get('momenta')),
                      constraints)

            if 'calculator' in row:
                if not isinstance(row.calculator_parameters, basestring):
                    row.calculator_parameters = encode(
                        row.calculator_parameters)
                values += (row.calculator,
                           row.calculator_parameters)
            else:
                values += (None, None)

            if key_value_pairs is None:
                key_value_pairs = row.key_value_pairs

            if not data:
                data = row._data
            if not isinstance(data, basestring):
                data = encode(data)

            values += (row.get('energy'),
                       row.get('free_energy'),
                       blob(row.get('forces')),
                       blob(row.get('stress')),
                       blob(row.get('dipole')),
                       blob(row.get('magmoms')),
                       row.get('magmom'),
                       blob(row.get('charges')),
                       encode(key_value_pairs),
                       data,
                       len(row.numbers),
                       float_if_not_none(row.get('fmax')),
                       float_if_not_none(row.get('smax')),
                       float_if_not_none(row.get('volume')),
                       float(row.mass),
                       float(row.charge))

            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute('INSERT INTO systems VALUES ({0})'.format(q),
                            values)
            else:
                # Column names are taken from the CREATE TABLE statement,
                # skipping its first two lines.
                q = ', '.join(line.split()[0].lstrip() + '=?'
                              for line in init_statements[0].splitlines()[2:])
                cur.execute('UPDATE systems SET {0} WHERE id=?'.format(q),
                            values + (id,))

            if id is None:
                id = self.get_last_id(cur)

                count = row.count_atoms()
                if count:
                    species = [(atomic_numbers[symbol], n, id)
                               for symbol, n in count.items()]
                    cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                                    species)

            # numbers.Real also covers NumPy integer/float scalars, which
            # the former (float, int) check sent to the string assertion.
            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, basestring)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            if self.connection is None:
                con.commit()
        finally:
            # Close only a connection opened here; nothing is committed on
            # error.
            if self.connection is None:
                con.close()

        return id
Example #16
0
    def _select(self,
                keys,
                cmps,
                explain=False,
                verbosity=0,
                limit=None,
                offset=0,
                sort=None,
                include_data=True):
        """Yield rows of the JSON file that match *keys* and *cmps*.

        keys: list of str
            Names that must be present in a row for it to match.
        cmps: list of (key, op, value)
            Comparisons; *op* is a string looked up in ``ops``.  An integer
            key compares the number of atoms with that atomic number; the
            ``'pbc'`` key is compared as a ``'TTF'``-style string.
        explain: bool
            If true, yield a single pseudo query-plan dict and stop.
        verbosity: int
            Accepted for interface compatibility; not used here.
        limit, offset: int
            Skip the first *offset* matches, then yield at most *limit*
            rows.  A falsy *limit* means "no limit".
        sort: str or None
            Key to sort on; a leading ``'-'`` sorts in descending order.
        include_data: bool
            If false, the ``'data'`` entry is dropped before each row is
            built.
        """
        if explain:
            yield {'explain': (0, 0, 0, 'scan table')}
            return

        if sort:
            if sort[0] == '-':
                reverse = True
                sort = sort[1:]
            else:
                reverse = False

            def f(row):
                return row[sort]

            # Requiring *sort* as a key guarantees row[sort] exists.
            # include_data is forwarded so sorted queries honour it too.
            rows = sorted(self._select(keys + [sort], cmps,
                                       include_data=include_data),
                          key=f,
                          reverse=reverse)
            # Apply the offset even without a limit, as the unsorted path
            # below does; previously it was silently ignored here.
            if limit:
                rows = rows[offset:offset + limit]
            else:
                rows = rows[offset:]
            for row in rows:
                yield row
            return

        try:
            bigdct, ids, _ = self._read_json()
        except IOError:
            return

        if not limit:
            # n - offset is never below -offset, so this value is never
            # reached: it means "no limit".
            limit = -offset - 1

        cmps = [(key, ops[op], val) for key, op, val in cmps]
        n = 0
        for id in ids:
            if n - offset == limit:
                return
            dct = bigdct[id]
            if not include_data:
                dct.pop('data', None)
            row = AtomsRow(dct)
            row.id = id
            for key in keys:
                if key not in row:
                    break
            else:
                for key, op, val in cmps:
                    if isinstance(key, int):
                        value = np.equal(row.numbers, key).sum()
                    else:
                        value = row.get(key)
                        if key == 'pbc':
                            assert op in [ops['='], ops['!=']]
                            # int(): NumPy bools are not valid string indices.
                            value = ''.join('FT'[int(x)] for x in value)
                    if value is None or not op(value, val):
                        break
                else:
                    if n >= offset:
                        yield row
                    n += 1
Example #17
0
    def _write(self, atoms, key_value_pairs, data, id):
        ext_tables = key_value_pairs.pop("external_tables", {})
        Database._write(self, atoms, key_value_pairs, data)
        encode = self.encode

        con = self.connection or self._connect()
        self._initialize(con)
        cur = con.cursor()

        mtime = now()

        blob = self.blob

        text_key_values = []
        number_key_values = []

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms

            # Extract the external tables from AtomsRow
            names = self._get_external_table_names(db_con=con)
            for name in names:
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        if id:
            self._delete(
                cur, [id],
                ['keys', 'text_key_values', 'number_key_values', 'species'])
        else:
            if not key_value_pairs:
                key_value_pairs = row.key_value_pairs

        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None

        values = (row.unique_id, row.ctime, mtime, row.user, blob(row.numbers),
                  blob(row.positions), blob(row.cell),
                  int(np.dot(row.pbc,
                             [1, 2, 4])), blob(row.get('initial_magmoms')),
                  blob(row.get('initial_charges')), blob(row.get('masses')),
                  blob(row.get('tags')), blob(row.get('momenta')), constraints)

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data
        if not isinstance(data, basestring):
            data = encode(data)

        values += (row.get('energy'), row.get('free_energy'),
                   blob(row.get('forces')), blob(row.get('stress')),
                   blob(row.get('dipole')), blob(row.get('magmoms')),
                   row.get('magmom'), blob(row.get('charges')),
                   encode(key_value_pairs), data, len(row.numbers),
                   float_if_not_none(row.get('fmax')),
                   float_if_not_none(row.get('smax')),
                   float_if_not_none(row.get('volume')), float(row.mass),
                   float(row.charge))

        if id is None:
            q = self.default + ', ' + ', '.join('?' * len(values))
            cur.execute('INSERT INTO systems VALUES ({})'.format(q), values)
            id = self.get_last_id(cur)
        else:
            q = ', '.join(name + '=?' for name in self.columnnames[1:])
            cur.execute('UPDATE systems SET {} WHERE id=?'.format(q),
                        values + (id, ))

        count = row.count_atoms()
        if count:
            species = [(atomic_numbers[symbol], n, id)
                       for symbol, n in count.items()]
            cur.executemany('INSERT INTO species VALUES (?, ?, ?)', species)

        text_key_values = []
        number_key_values = []
        for key, value in key_value_pairs.items():
            if isinstance(value, (numbers.Real, np.bool_)):
                number_key_values.append([key, float(value), id])
            else:
                assert isinstance(value, basestring)
                text_key_values.append([key, value, id])

        cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                        text_key_values)
        cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                        number_key_values)
        cur.executemany('INSERT INTO keys VALUES (?, ?)',
                        [(key, id) for key in key_value_pairs])

        # Update external tables
        valid_entries = []
        for k, v in ext_tables.items():
            try:
                # Guess the type of the value
                dtype = self._guess_type(v)
                self._create_table_if_not_exists(k, dtype, db_con=con)
                v["id"] = id
                valid_entries.append(k)
            except ValueError as exc:
                # Close the connection without committing
                if self.connection is None:
                    con.close()
                # Raise error again
                raise ValueError(exc)

        # Insert entries in the valid tables
        for tabname in valid_entries:
            try:
                self._insert_in_external_table(cur,
                                               name=tabname,
                                               entries=ext_tables[tabname])
            except ValueError as exc:
                # Close the connection without committing
                if self.connection is None:
                    con.close()
                # Raise the error again
                raise ValueError(exc)

        if self.connection is None:
            con.commit()
            con.close()

        return id
Example #18
0
    def update(self,
               id,
               atoms=None,
               delete_keys=(),
               data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of a single row.

        id: int
            ID of row to update.  A list raises ValueError (call update
            once per id instead); any other non-integer raises TypeError.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
            The row's key-value pairs, data, ctime, user and id are
            carried over to the new row.
        data: dict
            Data dict to be merged into the existing data.
        delete_keys: iterable of str
            Keys to remove.  Keys that are not present are ignored.

        Use keyword arguments to add new key-value pairs (existing keys
        with the same name are overwritten).

        Returns a tuple (m, n): m is how many key-value pairs were added
        (overwriting a key that already exists does not count) and n is
        how many were removed.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        # Apply deletions first, then additions, recording the size after
        # each step so the two counts can be reported separately.
        nkvp_before = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        nkvp_after_delete = len(kvp)
        kvp.update(add_key_value_pairs)
        m = len(kvp) - nkvp_after_delete
        n = nkvp_before - nkvp_after_delete

        # Merge the new data into the row's existing data dict; an empty
        # result is normalized to None.
        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        # Compare against None instead of testing truthiness: Atoms
        # presumably defines __len__ (number of atoms), so an Atoms object
        # with zero atoms would be falsy and silently ignored.
        replace_atoms = atoms is not None

        if replace_atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        # New atoms, or a JSON backend, require rewriting the whole row;
        # otherwise only the key-value pairs and data are updated.
        if replace_atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n