Exemplo n.º 1
0
def client(database):
    """Build a Flask test client that serves *database* via ``ase.db.app``.

    The calling test is skipped when Flask is not installed.
    """
    pytest.importorskip('flask')
    import ase.db.app as web

    web.add_project(database)
    flask_app = web.app
    # Testing mode is switched on before the client is created.
    flask_app.testing = True
    return flask_app.test_client()
Exemplo n.º 2
0
def test_db_web():
    """Smoke-test the ase.db Flask web app on a one-row database.

    A single H2O row with key-value pairs (foo, bar) and array data is
    written to ``test.db``.  The front page and several row/export URLs
    are then requested through a Flask test client.  Every response must
    have HTTP status 200, so missing routes and error pages fail the test.
    """
    from ase import Atoms
    from ase.db import connect
    from pytest import importorskip

    importorskip('flask')
    import ase.db.app as app

    db = connect('test.db', append=False)
    x = [0, 1, 2]
    t1 = [1, 2, 0]
    t2 = [[2, 3], [1, 1], [1, 0]]

    atoms = Atoms('H2O')
    atoms.center(vacuum=5)
    atoms.set_pbc(True)

    db.write(atoms,
             foo=42.0,
             bar='abc',
             data={'x': x,
                   't1': t1,
                   't2': t2})
    app.add_project(db)
    app.app.testing = True
    c = app.app.test_client()

    response = c.get('/')
    assert response.status_code == 200
    page = response.data.decode()
    assert 'foo' in page

    # NOTE(review): the '?x=1' URLs presumably refer to state created by
    # the front-page request above, so keep that request first.
    for url in ['/default/row/1',
                '/default/json/1',
                '/default/sqlite/1',
                '/default/sqlite?x=1',
                '/default/json?x=1']:
        response = c.get(url)
        assert response.status_code == 200, url
Exemplo n.º 3
0
def main(args):
    """Run the database command selected by the parsed command-line *args*.

    The options are checked in a fixed order, and the first one that is
    set decides the action:

    - analyse the database;
    - show keys or values;
    - add rows from a file;
    - count rows;
    - insert rows into another database;
    - explain the query;
    - show or set metadata;
    - add or delete key-value pairs;
    - delete rows;
    - plot;
    - print one row as JSON or in long form;
    - start the web app.

    When none of these is set, the selected rows are printed as a table
    (or as CSV with ``args.csv``).

    Parameters
    ----------
    args:
        Parsed command-line options accessed as attributes (presumably an
        ``argparse.Namespace``).  ``args.sort`` and ``args.limit`` may be
        rewritten in place.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A purely numeric query is passed on as an int (presumably a row id).
        query = int(query)

    # "key1=value1,key2=value2" -> {key: converted value}
    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            key, value = pair.split('=')
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*args):
        # Print only when verbosity is positive (i.e. not --quiet).
        if verbosity > 0:
            print(*args)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        keys = defaultdict(int)
        for row in db.select(query):
            for key in row._keys:
                keys[key] += 1

        # default=0: an empty selection (or rows without keys) must not make
        # max() raise ValueError.
        n = max((len(key) for key in keys), default=0) + 1
        for key, number in keys.items():
            print('{:{}} {}'.format(key + ':', n, number))
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        numbers = set()  # keys that have at least one non-string value
        for row in db.select(query):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if not isinstance(value, str):
                        numbers.add(key)

        n = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers:
                # Non-string values: print only the range.
                print('{:{}} [{}..{}]'.format(key + ':', n, min(vals),
                                              max(vals)))
            else:
                # String values: print each value with its number of rows.
                # "count" is used instead of reusing the column width "n".
                print('{:{}} {}'.format(
                    key + ':', n,
                    ', '.join('{}({})'.format(v, count)
                              for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print('%s' % plural(n, 'row'))
        return

    if args.insert_into:
        # -1 is the "not given" sentinel; here it means "no limit" (0).
        if args.limit == -1:
            args.limit = 0
        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            for row in db.select(query,
                                 sort=args.sort,
                                 limit=args.limit,
                                 offset=args.offset):
                kvp = row.get('key_value_pairs', {})
                # Count only the pairs that are new for this row.
                nkvp -= len(kvp)
                kvp.update(add_key_value_pairs)
                nkvp += len(kvp)
                if args.unique:
                    row['unique_id'] = '%x' % randint(16**31, 16**32 - 1)
                if args.strip_data:
                    db2.write(row.toatoms(), **kvp)
                else:
                    db2.write(row, data=row.get('data'), **kvp)
                nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out('Inserted %s' % plural(nrows, 'row'))
        return

    # -1 is the "not given" sentinel; default to 20 rows for display.
    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query,
                             explain=True,
                             verbosity=verbosity,
                             limit=args.limit,
                             offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in db.select(query)]
        M = 0  # pairs added
        N = 0  # pairs removed
        with db:
            for id in ids:
                m, n = db.update(id,
                                 delete_keys=delete_keys,
                                 **add_key_value_pairs)
                M += m
                N += n
        out('Added %s (%s updated)' %
            (plural(M, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - M, 'pair')))
        out('Removed', plural(N, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = 'Delete %s? (yes/No): ' % plural(len(ids), 'row')
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out('Deleted %s' % plural(len(ids), 'row'))
        return

    if args.plot:
        # Format: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."; rows are grouped
        # into one line per distinct combination of tag values.
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # string x-value -> integer position on the x-axis
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            # zip() returns an iterator in Python 3; materialize it so the
            # columns can be indexed and sliced.
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.add_project(db)
        # NOTE(review): debug=True together with host='0.0.0.0' exposes the
        # Werkzeug interactive debugger (which can execute arbitrary code)
        # to every machine that can reach this host.  Consider binding to
        # 127.0.0.1 or turning debug mode off.
        app.app.run(host='0.0.0.0', debug=True)
        return

    # Column selection:
    #   "++[,...]"  add every key present in the selected rows, then apply
    #               the rest;
    #   "+a,b"      add columns a and b to the defaults;
    #   "-a"        remove column a from the defaults;
    #   "a,b"       use exactly these columns.
    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        keys = set()
        for row in db.select(query,
                             limit=args.limit,
                             offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
Exemplo n.º 4
0
# Smoke test for the ase.db Flask web app, run as a script.
#
# A single H2O row with key-value pairs (foo, bar) and array data is
# written to test.db.  The front page and several row/export URLs are then
# requested through a Flask test client.  Every response must have HTTP
# status 200 instead of being silently discarded.
from ase import Atoms
from ase.db import connect
from pytest import importorskip

importorskip('flask')
import ase.db.app as app

db = connect('test.db', append=False)
x = [0, 1, 2]
t1 = [1, 2, 0]
t2 = [[2, 3], [1, 1], [1, 0]]

atoms = Atoms('H2O')
atoms.center(vacuum=5)
atoms.set_pbc(True)

db.write(atoms, foo=42.0, bar='abc', data={'x': x, 't1': t1, 't2': t2})
app.add_project(db)
app.app.testing = True
c = app.app.test_client()

response = c.get('/')
assert response.status_code == 200
page = response.data.decode()
assert 'foo' in page

# NOTE(review): the '?x=1' URLs presumably refer to state created by the
# front-page request above, so keep that request first.
for url in ['/default/row/1',
            '/default/json/1',
            '/default/sqlite/1',
            '/default/sqlite?x=1',
            '/default/json?x=1']:
    response = c.get(url)
    assert response.status_code == 200, url