コード例 #1
0
ファイル: aff.py プロジェクト: engelund/SIESTA_ASE
def print_aff_info(filename, verbose=False, *args):
    """Print a header line for an aff file followed by a dump of its items.

    filename: str
        Name of the file; opened for reading with ``affopen``.
    verbose: bool
        Forwarded to each item's ``tostr`` method.
    *args:
        Optional extra positional arguments.  If any are given, the last
        one is converted with ``int`` and only that item is printed;
        otherwise every item in the file is printed.
    """
    b = affopen(filename, 'r')
    # ``args`` is a tuple (it comes from *args), so it has no pop() method;
    # read the last element by index instead.
    indices = [int(args[-1])] if args else range(len(b))
    print('{0}  (tag: "{1}", {2})'.format(filename, b.get_tag(),
                                          plural(len(b), 'item')))
    for i in indices:
        print('item #{0}:'.format(i))
        print(b[i].tostr(verbose))
コード例 #2
0
    def set_positions(self, atoms=None):
        """Update the atomic positions and initialize the wave functions.

        The positions returned by ``initialize_positions`` are handed to
        ``self.wfs.initialize``, which reports how many bands were taken
        from the LCAO basis set and how many from random numbers.  Both
        counts are logged when at least one is non-zero.  Afterwards the
        eigensolver and the SCF object are reset and the positions are
        printed to the log.
        """
        spos_ac = self.initialize_positions(atoms)
        n_lcao, n_random = self.wfs.initialize(self.density,
                                               self.hamiltonian,
                                               spos_ac)

        if n_lcao + n_random:
            self.log('Creating initial wave functions:')
            origins = ((n_lcao, 'from LCAO basis set'),
                       (n_random, 'from random numbers'))
            for count, origin in origins:
                # Only mention a source that actually contributed bands.
                if count:
                    self.log(' ', plural(count, 'band'), origin)
            self.log()

        self.wfs.eigensolver.reset()
        self.scf.reset()
        print_positions(self.atoms, self.log)
コード例 #3
0
def print_ulm_info(filename, index=None, verbose=False):
    """Print a header line for a ulm file followed by a dump of its items.

    filename: str
        Name of the file; opened for reading with ``ulmopen``.
    index: int or None
        Item to print.  ``None`` (the default) prints every item.
    verbose: bool
        Forwarded to each item's ``tostr`` method.
    """
    reader = ulmopen(filename, 'r')
    selected = range(len(reader)) if index is None else [index]

    tag = reader.get_tag()
    size = plural(len(reader), 'item')
    print('{0}  (tag: "{1}", {2})'.format(filename, tag, size))

    for number in selected:
        print('item #{0}:'.format(number))
        print(reader[number].tostr(verbose))
コード例 #4
0
ファイル: ulm.py プロジェクト: rosswhitfield/ase
def print_ulm_info(filename, index=None, verbose=False):
    """Show the tag, item count and item contents of a ulm file.

    filename: str
        File to open (read mode) with ``ulmopen``.
    index: int or None
        When given, only this item is shown; otherwise all items are.
    verbose: bool
        Passed on to ``tostr`` of every item that is shown.
    """
    ulm_file = ulmopen(filename, 'r')

    # Default to the single requested item; widen to all items if none
    # was requested.
    wanted = [index]
    if index is None:
        wanted = range(len(ulm_file))

    header = '{0}  (tag: "{1}", {2})'.format(
        filename, ulm_file.get_tag(), plural(len(ulm_file), 'item'))
    print(header)

    for item_number in wanted:
        print('item #{0}:'.format(item_number))
        item = ulm_file[item_number]
        print(item.tostr(verbose))
コード例 #5
0
def main(args):
    """Execute one database command selected by the attributes of *args*.

    *args* is an object with the option attributes read below
    (``database``, ``query``, ``sort``, ``quiet``, ``verbose``, ...).
    The database named by ``args.database`` is opened, and the first
    option that is set (in the order the checks appear below) is handled,
    after which the function returns.  If no special option is set, the
    selected rows are written as a table (or as CSV when ``args.csv`` is
    set).

    Side effect: ``args.sort`` is rewritten from "key-" to "-key" form.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    # A purely numeric query is passed on as an integer.
    if query.isdigit():
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            key, value = pair.split('=')
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*args):
        # Informational output; suppressed when verbosity is 0 or less.
        if verbosity > 0:
            print(*args)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        # Count in how many selected rows each key occurs.
        keys = defaultdict(int)
        for row in db.select(query):
            for key in row._keys:
                keys[key] += 1

        # max() raises ValueError on an empty sequence, so skip the
        # listing entirely when no keys were found.
        if keys:
            n = max(len(key) for key in keys) + 1
            for key, number in keys.items():
                print('{:{}} {}'.format(key + ':', n, number))
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        # Keys for which at least one non-string value was seen.
        numbers = set()
        for row in db.select(query):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if not isinstance(value, str):
                        numbers.add(key)

        n = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers:
                # Show the range of values.
                print('{:{}} [{}..{}]'.format(key + ':', n, min(vals),
                                              max(vals)))
            else:
                # Show every distinct value with its occurrence count.
                print('{:{}} {}'.format(
                    key + ':', n,
                    ', '.join('{}({})'.format(v, count)
                              for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print('%s' % plural(n, 'row'))
        return

    if args.explain:
        for row in db.select(query,
                             explain=True,
                             verbosity=verbosity,
                             limit=args.limit,
                             offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if args.insert_into:
        # nkvp ends up as the net number of key-value pairs that were new
        # to the copied rows.
        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            for row in db.select(query, sort=args.sort):
                kvp = row.get('key_value_pairs', {})
                nkvp -= len(kvp)
                kvp.update(add_key_value_pairs)
                nkvp += len(kvp)
                if args.unique:
                    row['unique_id'] = '%x' % randint(16**31, 16**32 - 1)
                if args.strip_data:
                    db2.write(row.toatoms(), **kvp)
                else:
                    db2.write(row, data=row.get('data'), **kvp)
                nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in db.select(query)]
        M = 0  # accumulated first return value of db.update
        N = 0  # accumulated second return value of db.update
        with db:
            for row_id in ids:
                m, n = db.update(row_id,
                                 delete_keys=delete_keys,
                                 **add_key_value_pairs)
                M += m
                N += n
        out('Added %s (%s updated)' %
            (plural(M, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - M, 'pair')))
        out('Removed', plural(N, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query)]
        # Ask for confirmation unless --yes was given or nothing matches.
        if ids and not args.yes:
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if args.plot_data:
        from ase.db.plot import dct2plot
        dct2plot(db.get(query).data, args.plot_data)
        return

    if args.plot:
        # Format: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # maps string x-values to integer positions
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, basestring):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            # zip() returns an iterator on Python 3; materialize it so the
            # columns can be indexed and sliced.
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    db.python = args.metadata_from_python_script

    if args.long:
        db.meta = process_metadata(db, html=args.open_web_browser)
        row = db.get(query)
        summary = Summary(row, db.meta)
        summary.write()
        return

    if args.open_web_browser:
        try:
            import ase.db.app as app
        except ImportError:
            print('Please install Flask: pip install flask')
            return
        app.databases['default'] = db
        app.initialize_databases()
        # NOTE(review): host='0.0.0.0' listens on all network interfaces
        # and debug=True presumably enables the web framework's debug
        # mode -- confirm this is acceptable outside a trusted network.
        app.app.run(host='0.0.0.0', debug=True)
        return

    if args.write_summary_files:
        prefix = args.write_summary_files
        db.meta = process_metadata(db, html=args.open_web_browser)
        ukey = db.meta.get('unique_key', 'id')
        for row in db.select(query):
            uid = row.get(ukey)
            # The object is not used afterwards; presumably constructing
            # Summary with a prefix writes the files -- TODO confirm.
            Summary(row,
                    db.meta,
                    prefix='{}-{}-'.format(prefix, uid))
        return

    # Default action: print a table.  ``args.columns`` syntax:
    #   "++..."      start from the default columns plus all keys found
    #   "+a,b"       add columns to the defaults
    #   "-a,..."     remove columns from the defaults
    #   "a,b"        use exactly these columns
    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        keys = set()
        for row in db.select(query,
                             limit=args.limit,
                             offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
コード例 #6
0
ファイル: cli.py プロジェクト: rosswhitfield/ase
def run(opts, args, verbosity):
    """Execute one database command selected by the options in *opts*.

    opts:
        Object with the option attributes read below.
    args: list of str
        First element is the database filename (it is popped off, so the
        caller's list is modified); the remaining elements are joined
        with commas to form the query.
    verbosity: int
        Informational messages are printed only when this is positive.

    The first option that is set (in the order of the checks below) is
    handled and the function returns.  With no special option, the
    selected rows are written as a table, or as CSV when ``opts.csv`` is
    set.
    """
    filename = args.pop(0)
    query = ','.join(args)

    # A purely numeric query is passed on as an integer.
    if query.isdigit():
        query = int(query)

    add_key_value_pairs = {}
    if opts.add_key_value_pairs:
        for pair in opts.add_key_value_pairs.split(','):
            key, value = pair.split('=')
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    if opts.delete_keys:
        delete_keys = opts.delete_keys.split(',')
    else:
        delete_keys = []

    con = connect(filename, use_lock_file=not opts.no_lock_file)

    def out(*args):
        # Informational output; suppressed when verbosity is 0 or less.
        if verbosity > 0:
            print(*args)

    if opts.analyse:
        con.analyse()
        return

    if opts.add_from_file:
        filename = opts.add_from_file
        # "calculator:filename" reads the atoms through a calculator;
        # a plain filename is read with ase.io.read.
        if ':' in filename:
            calculator_name, filename = filename.split(':')
            atoms = get_calculator(calculator_name)(filename).get_atoms()
        else:
            atoms = ase.io.read(filename)
        con.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added {0} from {1}'.format(atoms.get_chemical_formula(),
                                        filename))
        return

    if opts.count:
        n = con.count(query)
        print('%s' % plural(n, 'row'))
        return

    if opts.explain:
        for row in con.select(query, explain=True,
                              verbosity=verbosity,
                              limit=opts.limit, offset=opts.offset):
            print(row['explain'])
        return

    if opts.insert_into:
        # nkvp ends up as the net number of key-value pairs that were new
        # to the copied rows.
        nkvp = 0
        nrows = 0
        with connect(opts.insert_into,
                     use_lock_file=not opts.no_lock_file) as con2:
            for row in con.select(query):
                kvp = row.get('key_value_pairs', {})
                nkvp -= len(kvp)
                kvp.update(add_key_value_pairs)
                nkvp += len(kvp)
                if opts.unique:
                    row['unique_id'] = '%x' % randint(16**31, 16**32 - 1)
                con2.write(row, data=row.get('data'), **kvp)
                nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in con.select(query)]
        m, n = con.update(ids, delete_keys, **add_key_value_pairs)
        out('Added %s (%s updated)' %
            (plural(m, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - m, 'pair')))
        out('Removed', plural(n, 'key-value pair'))

        return

    if opts.delete:
        ids = [row['id'] for row in con.select(query)]
        # Ask for confirmation unless --yes was given or nothing matches.
        if ids and not opts.yes:
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        con.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if opts.plot_data:
        from ase.db.plot import dct2plot
        dct2plot(con.get(query).data, opts.plot_data)
        return

    if opts.plot:
        # Format: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."
        if ':' in opts.plot:
            tags, keys = opts.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = opts.plot
        keys = keys.split(',')
        plots = collections.defaultdict(list)
        X = {}  # maps string x-values to integer positions
        labels = []
        for row in con.select(query, sort=opts.sort):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, basestring):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            # zip() returns an iterator on Python 3; materialize it so the
            # columns can be indexed and sliced.
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if opts.long:
        row = con.get(query)
        summary = Summary(row)
        summary.write()
    elif opts.json:
        row = con.get(query)
        con2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        con2.write(row, data=row.get('data'), **kvp)
    else:
        if opts.open_web_browser:
            import ase.db.app as app
            app.db = con
            # NOTE(review): host='0.0.0.0' listens on all network
            # interfaces and debug=True presumably enables the web
            # framework's debug mode -- confirm this is acceptable
            # outside a trusted network.
            app.app.run(host='0.0.0.0', debug=True)
        else:
            # Print a table.  ``opts.columns`` syntax:
            #   "++..."   start from the defaults plus all keys found
            #   "+a,b"    add columns to the defaults
            #   "-a,..."  remove columns from the defaults
            #   "a,b"     use exactly these columns
            columns = list(all_columns)
            c = opts.columns
            if c and c.startswith('++'):
                keys = set()
                for row in con.select(query,
                                      limit=opts.limit, offset=opts.offset):
                    keys.update(row._keys)
                columns.extend(keys)
                if c[2:3] == ',':
                    c = c[3:]
                else:
                    c = ''
            if c:
                if c[0] == '+':
                    c = c[1:]
                elif c[0] != '-':
                    columns = []
                for col in c.split(','):
                    if col[0] == '-':
                        columns.remove(col[1:])
                    else:
                        columns.append(col.lstrip('+'))

            table = Table(con, verbosity, opts.cut)
            table.select(query, columns, opts.sort, opts.limit, opts.offset)
            if opts.csv:
                table.write_csv()
            else:
                table.write(query)
コード例 #7
0
ファイル: bulk.py プロジェクト: essil1/ase-laser
def bulk(name, crystalstructure=None, a=None, c=None, covera=None, u=None,
         orthorhombic=False, cubic=False):
    """Create a bulk system.

    Missing crystal structure and lattice constant(s) are taken from the
    reference data for the element, when available.

    name: str
        Chemical symbol or symbols as in 'MgO' or 'NaCl'.
    crystalstructure: str
        Must be one of sc, fcc, bcc, hcp, diamond, zincblende,
        rocksalt, cesiumchloride, fluorite or wurtzite.
    a: float
        Lattice constant.
    c: float
        Lattice constant.
    covera: float
        c/a ratio used for hcp.  Default is ideal ratio: sqrt(8/3).
    u: float
        Internal coordinate for Wurtzite structure.
    orthorhombic: bool
        Construct orthorhombic unit cell instead of primitive cell
        which is the default.
    cubic: bool
        Construct cubic unit cell if possible.

    Raises ValueError for inconsistent or insufficient input and KeyError
    when the reference data has no lattice parameter "a".
    """

    if c is not None and covera is not None:
        raise ValueError("Don't specify both c and c/a!")

    # Reference state of the element (only for single-symbol names).
    ref = {}
    xref = None
    if name in chemical_symbols:
        ref = reference_states[atomic_numbers[name]]
        if ref is not None:
            xref = ref['symmetry']

    # Number of chemical symbols that ``name`` must contain for each
    # supported structure.
    structures = {'sc': 1, 'fcc': 1, 'bcc': 1, 'hcp': 1, 'diamond': 1,
                  'zincblende': 2, 'rocksalt': 2, 'cesiumchloride': 2,
                  'fluorite': 3, 'wurtzite': 2}

    if crystalstructure is None:
        # Fall back on the reference structure.
        if xref not in structures:
            raise ValueError('No suitable reference data for bulk {}.'
                             '  Reference data: {}'
                             .format(name, ref))
        crystalstructure = xref
    elif crystalstructure not in structures:
        raise ValueError('Unknown structure: {}.'
                         .format(crystalstructure))

    # Check name:
    nsymbols = len(string2symbols(name))
    nexpected = structures[crystalstructure]
    if nsymbols != nexpected:
        raise ValueError('Please specify {} for {} and not {}'
                         .format(plural(nexpected, 'atom'),
                                 crystalstructure, nsymbols))

    if a is None:
        # The reference lattice constant only applies to the reference
        # structure.
        if xref != crystalstructure:
            raise ValueError('You need to specify the lattice constant.')
        try:
            a = ref['a']
        except KeyError:
            raise KeyError('No reference lattice parameter "a" for "{}"'
                           .format(name))

    if crystalstructure in ('hcp', 'wurtzite'):
        cubic = False
        if c is not None:
            covera = c / a
        elif covera is None:
            covera = ref['c/a'] if xref == crystalstructure else sqrt(8 / 3)

    # Non-primitive cells are delegated to helpers (never for 'sc').
    if crystalstructure != 'sc':
        if orthorhombic:
            return _orthorhombic_bulk(name, crystalstructure, a, covera, u)
        if cubic:
            if crystalstructure in ('bcc', 'cesiumchloride'):
                return _orthorhombic_bulk(name, crystalstructure, a, covera)
            return _cubic_bulk(name, crystalstructure, a)

    if crystalstructure == 'sc':
        return Atoms(name, cell=(a, a, a), pbc=True)

    if crystalstructure == 'fcc':
        half = a / 2
        return Atoms(name,
                     cell=[(0, half, half), (half, 0, half), (half, half, 0)],
                     pbc=True)

    if crystalstructure == 'bcc':
        half = a / 2
        return Atoms(name,
                     cell=[(-half, half, half),
                           (half, -half, half),
                           (half, half, -half)],
                     pbc=True)

    if crystalstructure == 'hcp':
        return Atoms(2 * name,
                     scaled_positions=[(0, 0, 0),
                                       (1 / 3, 2 / 3, 0.5)],
                     cell=[(a, 0, 0),
                           (-a / 2, a * sqrt(3) / 2, 0),
                           (0, 0, covera * a)],
                     pbc=True)

    if crystalstructure == 'diamond':
        return bulk(2 * name, 'zincblende', a)

    if crystalstructure == 'zincblende':
        first, second = string2symbols(name)
        atoms = bulk(first, 'fcc', a) + bulk(second, 'fcc', a)
        atoms.positions[1] += a / 4
        return atoms

    if crystalstructure == 'rocksalt':
        first, second = string2symbols(name)
        atoms = bulk(first, 'fcc', a) + bulk(second, 'fcc', a)
        atoms.positions[1, 0] += a / 2
        return atoms

    if crystalstructure == 'cesiumchloride':
        first, second = string2symbols(name)
        atoms = bulk(first, 'sc', a) + bulk(second, 'sc', a)
        atoms.positions[1, :] += a / 2
        return atoms

    if crystalstructure == 'fluorite':
        first, second, third = string2symbols(name)
        atoms = (bulk(first, 'fcc', a) + bulk(second, 'fcc', a) +
                 bulk(third, 'fcc', a))
        atoms.positions[1, :] += a / 4
        atoms.positions[2, :] += a * 3 / 4
        return atoms

    if crystalstructure == 'wurtzite':
        # A falsy u (None or 0) selects the default internal coordinate.
        if not u:
            u = 0.25 + 1 / 3 / covera**2
        return Atoms(2 * name,
                     scaled_positions=[(0, 0, 0),
                                       (1 / 3, 2 / 3, 0.5 - u),
                                       (1 / 3, 2 / 3, 0.5),
                                       (0, 0, 1 - u)],
                     cell=[(a, 0, 0),
                           (-a / 2, a * sqrt(3) / 2, 0),
                           (0, 0, a * covera)],
                     pbc=True)

    raise ValueError('Unknown crystal structure: ' + crystalstructure)
コード例 #8
0
def rst1(dataset, nlfer, energies):
    """Build the reST text for one dataset and save its convergence figure.

    dataset: str
        Dataset name; the figure is saved as ``dataset + '.png'`` and the
        name is inserted into the ``.. image::`` directive.
    nlfer:
        Iterable of ``(n, l, f, e, rcut)`` tuples.  ``n``, ``l`` and ``f``
        are converted with ``int``; ``n == -1`` is shown as an empty
        label and ``l`` indexes into ``'spdf'``.  ``e`` is multiplied by
        ``Hartree`` for display.
    energies:
        Sequence of seven items unpacked as
        ``epw, depw, efd, defd, elcao, delcao, deegg``.

    Returns ``(nv, text)`` where ``nv`` is the sum of ``f`` over the
    entries with a truthy ``rcut`` and ``text`` is the formatted reST.

    Note: ``cutoffs``, ``symbol``, ``plt``, ``Hartree`` and ``plural``
    are not defined in this function; they are taken from an enclosing
    scope.
    """
    # First csv-table: one row per entry of nlfer.
    rows = []
    nv = 0
    for n, l, f, e, rcut in nlfer:
        n, l, f = int(n), int(l), int(f)
        label = '' if n == -1 else n
        row = '    {}{},{},{:.3f},'.format(label, 'spdf'[l], f, e * Hartree)
        if rcut:
            # Only entries with a cutoff get the last column and are
            # counted in nv.
            row += '{:.2f}'.format(rcut)
            nv += f
        rows.append(row + '\n')
    table1 = ''.join(rows)

    rst = """

{electrons}
====================

Radial cutoffs and eigenvalues:

.. csv-table::
    :header: id, occ, eig [eV], cutoff [Bohr]

{table1}

The figure shows convergence of the absolute energy (red line)
and atomization energy (green line) of a {symbol} dimer relative
to completely converged numbers (plane-wave calculation at 1500 eV).
Also shown are finite-difference and LCAO (dzp) calculations at gridspacings
0.143 Å, 0.167 Å and 0.200 Å.

.. image:: {dataset}.png

Egg-box errors in finite-difference mode:

.. csv-table::
    :header: grid-spacing [Å], energy error [eV]

{table2}"""

    epw, depw, efd, defd, elcao, delcao, deegg = energies

    # Second csv-table: egg-box errors paired with fixed grid spacings.
    table2 = ''.join('    {:.2f},{:.4f}\n'.format(spacing, error)
                     for spacing, error in zip([0.16, 0.18, 0.2], deegg))

    fig = plt.figure(figsize=(8, 5))

    # Left panel: plane-wave errors versus cutoff (last point left out).
    ax1 = plt.subplot(121)
    pw_curves = [(epw, 'r', 'pw, absolute'),
                 (depw, 'g', 'pw, atomization')]
    for values, fmt, label in pw_curves:
        ax1.semilogy(cutoffs[:-1], values[:-1], fmt, label=label)
    plt.xticks([200, 400, 600])
    plt.xlabel('plane-wave cutoff [eV]')
    plt.ylabel('error [eV/atom]')
    plt.legend(loc='best')

    # Right panel: fd and lcao errors versus grid spacing, sharing the
    # y axis with the left panel.
    ax2 = plt.subplot(122, sharey=ax1)
    spacings = [4.0 / g for g in (20, 24, 28)]
    grid_curves = [(efd, '-rs', 'fd, absolute'),
                   (defd, '-gs', 'fd, atomization'),
                   (elcao, '-ro', 'lcao, absolute'),
                   (delcao, '-go', 'lcao, atomization')]
    for values, fmt, label in grid_curves:
        ax2.semilogy(spacings, values, fmt, label=label)
    plt.xticks([0.16, 0.18, 0.2])
    plt.xlim(0.14, 0.2)
    plt.xlabel(u'grid-spacing [Å]')
    plt.legend(loc='best')
    plt.setp(ax2.get_yticklabels(), visible=False)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0)
    plt.savefig(dataset + '.png')
    plt.close(fig)

    text = rst.format(electrons=plural(nv, 'valence electron'),
                      table1=table1,
                      table2=table2,
                      symbol=symbol,
                      dataset=dataset)
    return nv, text
コード例 #9
0
ファイル: cli.py プロジェクト: jboes/ase
def run(opts, args, verbosity):
    """Execute one database command selected by the options in *opts*.

    opts:
        Object with the option attributes read below.
    args: list of str
        First element is the database filename (it is popped off, so the
        caller's list is modified); the remaining elements are joined
        with commas to form the query.
    verbosity: int
        Informational messages are printed only when this is positive.

    The first option that is set (in the order of the checks below) is
    handled and the function returns.  With no special option, the
    selected rows are written as a table, or as CSV when ``opts.csv`` is
    set.
    """
    filename = args.pop(0)
    query = ",".join(args)

    # A purely numeric query is passed on as an integer.
    if query.isdigit():
        query = int(query)

    add_key_value_pairs = {}
    if opts.add_key_value_pairs:
        for pair in opts.add_key_value_pairs.split(","):
            key, value = pair.split("=")
            add_key_value_pairs[key] = convert_str_to_float_or_str(value)

    if opts.delete_keys:
        delete_keys = opts.delete_keys.split(",")
    else:
        delete_keys = []

    con = connect(filename, use_lock_file=not opts.no_lock_file)

    def out(*args):
        # Informational output; suppressed when verbosity is 0 or less.
        if verbosity > 0:
            print(*args)

    if opts.analyse:
        con.analyse()
        return

    if opts.add_from_file:
        filename = opts.add_from_file
        # "calculator:filename" reads the atoms through a calculator;
        # a plain filename is read with ase.io.read.
        if ":" in filename:
            calculator_name, filename = filename.split(":")
            atoms = get_calculator(calculator_name)(filename).get_atoms()
        else:
            atoms = ase.io.read(filename)
        con.write(atoms, key_value_pairs=add_key_value_pairs)
        out("Added {0} from {1}".format(atoms.get_chemical_formula(), filename))
        return

    if opts.count:
        n = con.count(query)
        print("%s" % plural(n, "row"))
        return

    if opts.explain:
        for dct in con.select(query, explain=True, verbosity=verbosity, limit=opts.limit, offset=opts.offset):
            print(dct["explain"])
        return

    if opts.insert_into:
        # nkvp ends up as the net number of key-value pairs that were new
        # to the copied rows.
        nkvp = 0
        nrows = 0
        with connect(opts.insert_into, use_lock_file=not opts.no_lock_file) as con2:
            for dct in con.select(query):
                kvp = dct.get("key_value_pairs", {})
                nkvp -= len(kvp)
                kvp.update(add_key_value_pairs)
                nkvp += len(kvp)
                if opts.unique:
                    dct["unique_id"] = "%x" % randint(16 ** 31, 16 ** 32 - 1)
                con2.write(dct, data=dct.get("data"), **kvp)
                nrows += 1

        out(
            "Added %s (%s updated)"
            % (plural(nkvp, "key-value pair"), plural(len(add_key_value_pairs) * nrows - nkvp, "pair"))
        )
        out("Inserted %s" % plural(nrows, "row"))
        return

    if add_key_value_pairs or delete_keys:
        ids = [dct["id"] for dct in con.select(query)]
        m, n = con.update(ids, delete_keys, **add_key_value_pairs)
        out(
            "Added %s (%s updated)"
            % (plural(m, "key-value pair"), plural(len(add_key_value_pairs) * len(ids) - m, "pair"))
        )
        out("Removed", plural(n, "key-value pair"))

        return

    if opts.delete:
        ids = [dct["id"] for dct in con.select(query)]
        # Ask for confirmation unless --yes was given or nothing matches.
        if ids and not opts.yes:
            msg = "Delete %s? (yes/No): " % plural(len(ids), "row")
            if input(msg).lower() != "yes":
                return
        con.delete(ids)
        out("Deleted %s" % plural(len(ids), "row"))
        return

    if opts.plot:
        # Format: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."
        if ":" in opts.plot:
            tags, keys = opts.plot.split(":")
            tags = tags.split(",")
        else:
            tags = []
            keys = opts.plot
        keys = keys.split(",")
        plots = collections.defaultdict(list)
        X = {}  # maps string x-values to integer positions
        labels = []
        for row in con.select(query, sort=opts.sort):
            # Tag values need not be strings (str.join would raise
            # TypeError on e.g. an int), so convert them explicitly.
            name = ",".join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, (unicode, str)):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt

        for name, plot in plots.items():
            # zip() returns an iterator on Python 3; materialize it so the
            # columns can be indexed and sliced.
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if opts.long:
        dct = con.get(query)
        summary = Summary(dct)
        summary.write()
    elif opts.json:
        dct = con.get(query)
        con2 = connect(sys.stdout, "json", use_lock_file=False)
        kvp = dct.get("key_value_pairs", {})
        con2.write(dct, data=dct.get("data"), **kvp)
    else:
        if opts.open_web_browser:
            import ase.db.app as app

            app.db = con
            # NOTE(review): host="0.0.0.0" listens on all network
            # interfaces and debug=True presumably enables the web
            # framework's debug mode -- confirm this is acceptable
            # outside a trusted network.
            app.app.run(host="0.0.0.0", debug=True)
        else:
            # Print a table.  ``opts.columns`` syntax:
            #   "++..."   start from the defaults plus all keys found
            #   "+a,b"    add columns to the defaults
            #   "-a,..."  remove columns from the defaults
            #   "a,b"     use exactly these columns
            columns = list(all_columns)
            c = opts.columns
            if c and c.startswith("++"):
                keys = set()
                for row in con.select(query, limit=opts.limit, offset=opts.offset):
                    keys.update(row._keys)
                columns.extend(keys)
                if c[2:3] == ",":
                    c = c[3:]
                else:
                    c = ""
            if c:
                if c[0] == "+":
                    c = c[1:]
                elif c[0] != "-":
                    columns = []
                for col in c.split(","):
                    if col[0] == "-":
                        columns.remove(col[1:])
                    else:
                        columns.append(col.lstrip("+"))

            table = Table(con, verbosity, opts.cut)
            table.select(query, columns, opts.sort, opts.limit, opts.offset)
            if opts.csv:
                table.write_csv()
            else:
                table.write(query)
コード例 #10
0
def run(opts, args, verbosity):
    """Run the database command-line action selected by ``opts``.

    ``args`` is consumed as ``[filename, query_part, ...]``: the first item
    is the database to connect to and the remaining items are joined with
    commas to form the query.  A query made up of digits only is converted
    to an ``int``.

    The options are checked in a fixed order and the first one that is set
    decides what happens: analyse, add_from_file, count, explain,
    insert_into, add/delete key-value pairs, delete, plot_data, plot, long,
    json, open_web_browser and, if none of these is set, a table of the
    selected rows is written.

    opts:
        Parsed command-line options (only attribute access is used).
    args: list of str
        Positional arguments; modified in place (the filename is popped).
    verbosity: int
        Informational messages are printed only when this is positive.
    """
    filename = args.pop(0)
    query = ','.join(args)

    if query.isdigit():
        query = int(query)

    add_key_value_pairs = {}
    if opts.add_key_value_pairs:
        for pair in opts.add_key_value_pairs.split(','):
            # Split on the first '=' only so that values may contain '='.
            key, value = pair.split('=', 1)
            add_key_value_pairs[key] = convert_str_to_float_or_str(value)

    if opts.delete_keys:
        delete_keys = opts.delete_keys.split(',')
    else:
        delete_keys = []

    con = connect(filename, use_lock_file=not opts.no_lock_file)

    def out(*messages):
        # Print only when not running quietly.
        if verbosity > 0:
            print(*messages)

    if opts.analyse:
        con.analyse()
        return

    if opts.add_from_file:
        filename = opts.add_from_file
        if ':' in filename:
            # "<calculator>:<file>" - split once so the file part may
            # itself contain ':' characters.
            calculator_name, filename = filename.split(':', 1)
            atoms = get_calculator(calculator_name)(filename).get_atoms()
        else:
            atoms = ase.io.read(filename)
        con.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added {0} from {1}'.format(atoms.get_chemical_formula(),
                                        filename))
        return

    if opts.count:
        n = con.count(query)
        print('%s' % plural(n, 'row'))
        return

    if opts.explain:
        for dct in con.select(query,
                              explain=True,
                              verbosity=verbosity,
                              limit=opts.limit,
                              offset=opts.offset):
            print(dct['explain'])
        return

    if opts.insert_into:
        # nkvp ends up as the net number of key-value pairs that were new
        # (i.e. not already present) summed over all copied rows.
        nkvp = 0
        nrows = 0
        with connect(opts.insert_into,
                     use_lock_file=not opts.no_lock_file) as con2:
            for dct in con.select(query):
                kvp = dct.get('key_value_pairs', {})
                nkvp -= len(kvp)
                kvp.update(add_key_value_pairs)
                nkvp += len(kvp)
                if opts.unique:
                    dct['unique_id'] = '%x' % randint(16**31, 16**32 - 1)
                con2.write(dct, data=dct.get('data'), **kvp)
                nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    if add_key_value_pairs or delete_keys:
        ids = [dct['id'] for dct in con.select(query)]
        m, n = con.update(ids, delete_keys, **add_key_value_pairs)
        out('Added %s (%s updated)' %
            (plural(m, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - m, 'pair')))
        out('Removed', plural(n, 'key-value pair'))

        return

    if opts.delete:
        ids = [dct['id'] for dct in con.select(query)]
        if ids and not opts.yes:
            # Anything but an explicit "yes" aborts the deletion.
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        con.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if opts.plot_data:
        from ase.db.plot import dct2plot
        dct2plot(con.get(query).data, opts.plot_data)
        return

    if opts.plot:
        # Format: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."
        if ':' in opts.plot:
            tags, keys = opts.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = opts.plot
        keys = keys.split(',')
        plots = collections.defaultdict(list)
        X = {}  # maps string x-values to integer positions on the x-axis
        labels = []
        for row in con.select(query, sort=opts.sort):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, basestring):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            # zip() is a one-shot iterator on Python 3 and cannot be
            # indexed; materialise the transposed columns first.
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if opts.long:
        dct = con.get(query)
        summary = Summary(dct)
        summary.write()
    elif opts.json:
        dct = con.get(query)
        con2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = dct.get('key_value_pairs', {})
        con2.write(dct, data=dct.get('data'), **kvp)
    else:
        if opts.open_web_browser:
            import ase.db.app as app
            app.db = con
            # NOTE(review): this listens on all interfaces with debug=True;
            # confirm that exposing a debug server to the network is
            # intended.
            app.app.run(host='0.0.0.0', debug=True)
        else:
            columns = list(all_columns)
            c = opts.columns
            if c and c.startswith('++'):
                # "++" means: also show every key found in the selected rows.
                keys = set()
                for row in con.select(query,
                                      limit=opts.limit,
                                      offset=opts.offset):
                    keys.update(row._keys)
                columns.extend(keys)
                if c[2:3] == ',':
                    c = c[3:]
                else:
                    c = ''
            if c:
                if c[0] == '+':
                    c = c[1:]
                elif c[0] != '-':
                    # No leading +/-: the list replaces the default columns.
                    columns = []
                for col in c.split(','):
                    if col[0] == '-':
                        columns.remove(col[1:])
                    else:
                        columns.append(col.lstrip('+'))

            table = Table(con, verbosity, opts.cut)
            table.select(query, columns, opts.sort, opts.limit, opts.offset)
            if opts.csv:
                table.write_csv()
            else:
                table.write(query)
コード例 #11
0
def main(args):
    """Run the database command-line action selected by ``args``.

    ``args`` is a parsed-arguments namespace.  The query is built by joining
    ``args.query`` with commas; a query made up of digits only is converted
    to an ``int``.  A sort key written with a trailing ``-`` is rewritten to
    have the ``-`` in front.

    The options are checked in a fixed order and the first one that is set
    decides what happens: analyse, add_from_file, count, explain,
    show_metadata, set_metadata, insert_into, add/delete key-value pairs,
    delete, plot_data, plot, json, long, open_web_browser and, if none of
    these is set, a table of the selected rows is written.

    Informational messages are printed only when
    ``1 - args.quiet + args.verbose`` is positive.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only so that values may contain '='.
            key, value = pair.split('=', 1)
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*messages):
        # Print only when not running quietly.
        if verbosity > 0:
            print(*messages)

    if args.analyse:
        db.analyse()
        return

    if args.add_from_file:
        filename = args.add_from_file
        if ':' in filename:
            # "<calculator>:<file>" - split once so the file part may
            # itself contain ':' characters.
            calculator_name, filename = filename.split(':', 1)
            atoms = get_calculator(calculator_name)(filename).get_atoms()
        else:
            atoms = ase.io.read(filename)
        db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added {0} from {1}'.format(atoms.get_chemical_formula(),
                                        filename))
        return

    if args.count:
        n = db.count(query)
        print('%s' % plural(n, 'row'))
        return

    if args.explain:
        for row in db.select(query,
                             explain=True,
                             verbosity=verbosity,
                             limit=args.limit,
                             offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if args.insert_into:
        # nkvp ends up as the net number of key-value pairs that were new
        # (i.e. not already present) summed over all copied rows.
        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            for row in db.select(query, sort=args.sort):
                kvp = row.get('key_value_pairs', {})
                nkvp -= len(kvp)
                kvp.update(add_key_value_pairs)
                nkvp += len(kvp)
                if args.unique:
                    row['unique_id'] = '%x' % randint(16**31, 16**32 - 1)
                db2.write(row, data=row.get('data'), **kvp)
                nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in db.select(query)]
        m, n = db.update(ids, delete_keys, **add_key_value_pairs)
        out('Added %s (%s updated)' %
            (plural(m, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - m, 'pair')))
        out('Removed', plural(n, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query)]
        if ids and not args.yes:
            # Anything but an explicit "yes" aborts the deletion.
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if args.plot_data:
        from ase.db.plot import dct2plot
        dct2plot(db.get(query).data, args.plot_data)
        return

    if args.plot:
        # Format: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = collections.defaultdict(list)
        X = {}  # maps string x-values to integer positions on the x-axis
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, basestring):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            # zip() is a one-shot iterator on Python 3 and cannot be
            # indexed; materialise the transposed columns first.
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    db.python = args.metadata_from_python_script
    db.meta = process_metadata(db, html=args.open_web_browser)

    if args.long:
        # Remove .png files so that new ones will be created.
        for func, filenames in db.meta.get('functions', []):
            for filename in filenames:
                try:
                    os.remove(filename)
                except OSError:  # Python 3 only: FileNotFoundError
                    pass

        row = db.get(query)
        summary = Summary(row, db.meta)
        summary.write()
    else:
        if args.open_web_browser:
            import ase.db.app as app
            app.databases['default'] = db
            # NOTE(review): this listens on all interfaces with debug=True;
            # confirm that exposing a debug server to the network is
            # intended.
            app.app.run(host='0.0.0.0', debug=True)
        else:
            columns = list(all_columns)
            c = args.columns
            if c and c.startswith('++'):
                # "++" means: also show every key found in the selected rows.
                keys = set()
                for row in db.select(query,
                                     limit=args.limit,
                                     offset=args.offset,
                                     include_data=False):
                    keys.update(row._keys)
                columns.extend(keys)
                if c[2:3] == ',':
                    c = c[3:]
                else:
                    c = ''
            if c:
                if c[0] == '+':
                    c = c[1:]
                elif c[0] != '-':
                    # No leading +/-: the list replaces the default columns.
                    columns = []
                for col in c.split(','):
                    if col[0] == '-':
                        columns.remove(col[1:])
                    else:
                        columns.append(col.lstrip('+'))

            table = Table(db, verbosity, args.cut)
            table.select(query, columns, args.sort, args.limit, args.offset)
            if args.csv:
                table.write_csv()
            else:
                table.write(query)
コード例 #12
0
ファイル: bulk.py プロジェクト: shuchingou/ase
def bulk(name,
         crystalstructure=None,
         a=None,
         b=None,
         c=None,
         *,
         alpha=None,
         covera=None,
         u=None,
         orthorhombic=False,
         cubic=False,
         basis=None):
    """Creating bulk systems.

    Crystal structure and lattice constant(s) will be guessed if not
    provided.

    name: str
        Chemical symbol or symbols as in 'MgO' or 'NaCl'.
    crystalstructure: str
        Must be one of sc, fcc, bcc, tetragonal, bct, hcp, rhombohedral,
        orthorhombic, mcl, diamond, zincblende, rocksalt, cesiumchloride,
        fluorite or wurtzite.
    a: float
        Lattice constant.
    b: float
        Lattice constant.  If only a and b is given, b will be interpreted
        as c instead.
    c: float
        Lattice constant.
    alpha: float
        Angle in degrees for rhombohedral lattice.
    covera: float
        c/a ratio used for hcp.  Default is ideal ratio: sqrt(8/3).
    u: float
        Internal coordinate for Wurtzite structure.
    orthorhombic: bool
        Construct orthorhombic unit cell instead of primitive cell
        which is the default.
    cubic: bool
        Construct cubic unit cell if possible.
    basis: array-like
        Atomic basis.  For 'bct' it is used as the scaled positions of the
        atoms (falling back to the basis of the reference state); for
        'rhombohedral' it is passed on to the rhombohedral builder.

    Returns the constructed atoms object.

    Raises ValueError for an unknown structure, a name with the wrong
    number of chemical symbols, a missing lattice constant or when both
    c and c/a are given; RuntimeError when the reference state cannot be
    used; KeyError when the reference data has no lattice constant "a".
    """

    if c is None and b is not None:
        # If user passes (a, b) positionally, we want it as (a, c) instead:
        c, b = b, c

    if covera is not None and c is not None:
        raise ValueError("Don't specify both c and c/a!")

    xref = None
    ref = {}

    if name in chemical_symbols:
        Z = atomic_numbers[name]
        ref = reference_states[Z]
        if ref is not None:
            xref = ref['symmetry']

            # If user did not specify crystal structure, and no basis
            # is given, and the reference state says we need one, but
            # does not have one, then we can't proceed.
            if (crystalstructure is None and basis is None and 'basis' in ref
                    and ref['basis'] is None):
                # XXX This is getting much too complicated, we need to split
                # this function up.  A lot.
                raise RuntimeError('This structure requires an atomic basis')

        if ref is None:
            ref = {}  # easier to 'get' things from empty dictionary than None

        if xref == 'cubic':
            # P and Mn are listed as 'cubic' but the lattice constants
            # are 7 and 9.  They must be something other than simple cubic
            # then. We used to just return the cubic one but that must
            # have been wrong somehow.  --askhl
            raise RuntimeError('Only simple cubic ("sc") supported')

    # Mapping of structure name to the number of chemical symbols that
    # `name` must contain (checked below against string2symbols(name)).
    structures = {
        'sc': 1,
        'fcc': 1,
        'bcc': 1,
        'tetragonal': 1,
        'bct': 1,
        'hcp': 1,
        'rhombohedral': 1,
        'orthorhombic': 1,
        'mcl': 1,
        'diamond': 1,
        'zincblende': 2,
        'rocksalt': 2,
        'cesiumchloride': 2,
        'fluorite': 3,
        'wurtzite': 2
    }

    if crystalstructure is None:
        crystalstructure = xref
        if crystalstructure not in structures:
            raise ValueError('No suitable reference data for bulk {}.'
                             '  Reference data: {}'.format(name, ref))

    if crystalstructure not in structures:
        raise ValueError('Unknown structure: {}.'.format(crystalstructure))

    # Check name:
    natoms = len(string2symbols(name))
    natoms0 = structures[crystalstructure]
    if natoms != natoms0:
        raise ValueError('Please specify {} for {} and not {}'.format(
            plural(natoms0, 'atom'), crystalstructure, natoms))

    if alpha is None:
        alpha = ref.get('alpha')

    if a is None:
        # The reference lattice constant only applies to the reference
        # crystal structure.
        if xref != crystalstructure:
            raise ValueError('You need to specify the lattice constant.')
        try:
            a = ref['a']
        except KeyError:
            raise KeyError(
                'No reference lattice parameter "a" for "{}"'.format(name))

    if b is None:
        bovera = ref.get('b/a')
        if bovera is not None and a is not None:
            b = bovera * a

    if crystalstructure in ['hcp', 'wurtzite']:
        if cubic:
            raise incompatible_cell(want='cubic', have=crystalstructure)

        if c is not None:
            covera = c / a
        elif covera is None:
            if xref == crystalstructure:
                covera = ref['c/a']
            else:
                covera = sqrt(8 / 3)

    if covera is None:
        covera = ref.get('c/a')
        if c is None and covera is not None:
            c = covera * a

    if orthorhombic and crystalstructure not in [
            'sc', 'tetragonal', 'orthorhombic'
    ]:
        return _orthorhombic_bulk(name, crystalstructure, a, covera, u)

    if cubic and crystalstructure in ['bcc', 'cesiumchloride']:
        return _orthorhombic_bulk(name, crystalstructure, a, covera)

    if cubic and crystalstructure != 'sc':
        return _cubic_bulk(name, crystalstructure, a)

    if crystalstructure == 'sc':
        atoms = Atoms(name, cell=(a, a, a), pbc=True)
    elif crystalstructure == 'fcc':
        b = a / 2
        atoms = Atoms(name, cell=[(0, b, b), (b, 0, b), (b, b, 0)], pbc=True)
    elif crystalstructure == 'bcc':
        b = a / 2
        atoms = Atoms(name,
                      cell=[(-b, b, b), (b, -b, b), (b, b, -b)],
                      pbc=True)
    elif crystalstructure == 'hcp':
        atoms = Atoms(2 * name,
                      scaled_positions=[(0, 0, 0), (1 / 3, 2 / 3, 0.5)],
                      cell=[(a, 0, 0), (-a / 2, a * sqrt(3) / 2, 0),
                            (0, 0, covera * a)],
                      pbc=True)
    elif crystalstructure == 'diamond':
        # Diamond is built as zincblende with both sites of the same kind.
        atoms = bulk(2 * name, 'zincblende', a)
    elif crystalstructure == 'zincblende':
        s1, s2 = string2symbols(name)
        atoms = bulk(s1, 'fcc', a) + bulk(s2, 'fcc', a)
        atoms.positions[1] += a / 4
    elif crystalstructure == 'rocksalt':
        s1, s2 = string2symbols(name)
        atoms = bulk(s1, 'fcc', a) + bulk(s2, 'fcc', a)
        atoms.positions[1, 0] += a / 2
    elif crystalstructure == 'cesiumchloride':
        s1, s2 = string2symbols(name)
        atoms = bulk(s1, 'sc', a) + bulk(s2, 'sc', a)
        atoms.positions[1, :] += a / 2
    elif crystalstructure == 'fluorite':
        s1, s2, s3 = string2symbols(name)
        atoms = bulk(s1, 'fcc', a) + bulk(s2, 'fcc', a) + bulk(s3, 'fcc', a)
        atoms.positions[1, :] += a / 4
        atoms.positions[2, :] += a * 3 / 4
    elif crystalstructure == 'wurtzite':
        u = u or 0.25 + 1 / 3 / covera**2
        atoms = Atoms(2 * name,
                      scaled_positions=[(0, 0, 0), (1 / 3, 2 / 3, 0.5 - u),
                                        (1 / 3, 2 / 3, 0.5), (0, 0, 1 - u)],
                      cell=[(a, 0, 0), (-a / 2, a * sqrt(3) / 2, 0),
                            (0, 0, a * covera)],
                      pbc=True)
    elif crystalstructure == 'bct':
        from ase.lattice import BCT
        if basis is None:
            basis = ref.get('basis')
        if basis is not None:
            natoms = len(basis)
        lat = BCT(a=a, c=c)
        atoms = Atoms([name] * natoms,
                      cell=lat.tocell(),
                      pbc=True,
                      scaled_positions=basis)
    elif crystalstructure == 'rhombohedral':
        atoms = _build_rhl(name, a, alpha, basis)
    elif crystalstructure == 'orthorhombic':
        atoms = Atoms(name, cell=[a, b, c], pbc=True)
    else:
        raise ValueError('Unknown crystal structure: ' + crystalstructure)

    if orthorhombic:
        assert atoms.cell.orthorhombic
    if cubic:
        # Compare every angle with 90 degrees first and reduce afterwards;
        # reducing with .all() before the comparison would test a single
        # boolean against the tolerance instead of the angles.
        assert (abs(atoms.cell.angles() - 90) < 1e-10).all()
    return atoms