Ejemplo n.º 1
0
def collect_legends(*axs):
    """
    Collect legend data from multiple axes, return input for legend().

    Parameters
    ----------
    *axs : axes objects
        Each one must provide ``get_legend_handles_labels()``.

    Returns
    -------
    tuple ``(handles, labels)``
        The handles and the labels of all `axs`, each passed through
        ``common.flatten()``. ``([], [])`` if no axes are given.

    Examples
    --------
    >>> fig, ax = mpl.fig_ax()
    >>> ax.plot([1,2,3], label='ax')
    >>> ax2 = ax.twinx()
    >>> ax2.plot([3,2,1], 'r', label='ax2')
    >>> ax.legend(*mpl.collect_legends(ax, ax2))
    """
    # Without any axes, zip() below yields nothing and indexing the result
    # would raise IndexError. An empty legend is the natural answer.
    if not axs:
        return [], []
    axhls = tuple(ax.get_legend_handles_labels() for ax in axs)
    # zip(*axhls) regroups ((h1, l1), (h2, l2), ...) into
    # ((h1, h2, ...), (l1, l2, ...))
    ret = [common.flatten(x) for x in zip(*axhls)]
    return ret[0], ret[1]
Ejemplo n.º 2
0
    def collect_legends(self, axnames=('ax',)):
        """If self has more than one axis object attached, then collect
        legends from all axes specified in axnames. Useful for handling legend
        entries of lines on different axes (in case of twinx, for instance).

        Parameters
        ----------
        axnames : sequence of strings
            Names of attributes of `self`, each of which must provide
            ``get_legend_handles_labels()``.

        Returns
        -------
        tuple of lines and labels
            ([line1, line2, ...], ['foo', 'bar', ...])
        where lines and labels are taken from all axes. Use this as input for
        any axis's legend() method. ``([], [])`` if `axnames` is empty.
        """
        # The default is an immutable tuple so that it can never be modified
        # by accident and shared between calls.
        #
        # With no names, zip() below yields nothing and indexing the result
        # would raise IndexError.
        if not axnames:
            return [], []
        axhls = [getattr(self, axname).get_legend_handles_labels()
                 for axname in axnames]
        # regroup ((h1, l1), (h2, l2), ...) -> ((h1, h2, ...), (l1, l2, ...))
        ret = [common.flatten(x) for x in zip(*axhls)]
        return ret[0], ret[1]
Ejemplo n.º 3
0
    y = np.sin(x)

    fig3, hostax = make_axes_grid_fig(3)

    hostax.set_xlabel('hostax bottom')
    hostax.set_ylabel('hostax left')

    # {'left': (off, wsadd),
    # ...}
    off_dct = dict(left=(60, .1),
                   right=(60, .1),
                   top=(60, .15),
                   bottom=(50, .15))

    for n, val in enumerate(off_dct.items()):
        loc, off, wsadd = tuple(flatten(val))
        fig3, hostax, parax = new_axis(fig3,
                                       hostax=hostax,
                                       loc=loc,
                                       off=off,
                                       label=loc,
                                       wsadd=wsadd)
        parax.plot(x * n, y**n)

    new_axis(fig3,
             hostax=hostax,
             loc='right',
             off=0,
             wsadd=0,
             label="hostax right, I'm like twinx()")
Ejemplo n.º 4
0
def conv_table(xx, yy, ffmt="%15.4f", sfmt="%15s", mode='last', orig=False,
               absdiff=False):
    """Convergence table. Assume that quantity `xx` was varied, resulting in
    `yy` values. Return a string (table) listing::

        x, dy1, dy2, ...

    Useful for quickly viewing the results of a convergence study, where we
    assume that the sequence of `yy` values converges to a constant value.

    Parameters
    ----------
    xx : 1d sequence
    yy : 1d sequence, nested 1d sequences, 2d array
        Values varied with `xx`. Each row is one parameter.
    ffmt, sfmt : str
        Format strings for floats (`ffmt`) and strings (`sfmt`)
    mode : str
        'next' or 'last'. Difference to the next value ``y[i+1] - y[i]`` or to
        the last ``y[-1] - y[i]``.
    orig : bool
        Print original `yy` data as well.
    absdiff : bool
        absolute values of differences

    Returns
    -------
    str
        One line per entry of `xx`, each terminated by a newline.

    Raises
    ------
    ValueError
        If `mode` is neither 'next' nor 'last'.

    Examples
    --------
    >>> kpoints = ['2 2 2', '4 4 4', '8 8 8']
    >>> etot = [-300.0, -310.0, -312.0]
    >>> forces_rms = [0.3, 0.2, 0.1]
    >>> print(batch.conv_table(kpoints, etot, mode='last'))
          2 2 2       -12.0000
          4 4 4        -2.0000
          8 8 8         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last'))
          2 2 2       -12.0000        -0.2000
          4 4 4        -2.0000        -0.1000
          8 8 8         0.0000         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last', orig=True))
          2 2 2       -12.0000      -300.0000        -0.2000         0.3000
          4 4 4        -2.0000      -310.0000        -0.1000         0.2000
          8 8 8         0.0000      -312.0000         0.0000         0.1000
    >>> print(batch.conv_table(kpoints, np.array([etot,forces_rms]), mode='next'))
          2 2 2       -10.0000        -0.1000
          4 4 4        -2.0000        -0.1000
          8 8 8         0.0000         0.0000
    """
    # Validate once, before any work. ValueError exists in Python 2 and 3;
    # the StandardError used before is Python 2 only (where it is a base
    # class of ValueError, so existing ``except StandardError`` still works).
    if mode not in ('next', 'last'):
        raise ValueError("unknown mode")
    npoints = len(xx)
    yy = np.asarray(yy).copy()
    # Bring yy into shape (npoints, ny): one column per quantity.
    if yy.ndim == 1:
        yy = yy[:,None]
    else:
        yy = yy.T
    ny = yy.shape[1]
    dyy = np.empty_like(yy)
    for iy in range(ny):
        if mode == 'next':
            # there is no "next" value for the last point
            dyy[-1,iy] = 0.0
            dyy[:-1,iy] = np.diff(yy[:,iy])
        else:
            dyy[:,iy] = yy[-1,iy] - yy[:,iy]
    if absdiff:
        dyy = np.abs(dyy)
    # One string column for x, then one float column per difference (and one
    # more per original value if orig=True).
    if orig:
        fmtstr = ("%s"*(2*ny+1) + "\n") %((sfmt,) + (ffmt,)*2*ny)
    else:
        fmtstr = ("%s"*(ny+1)   + "\n") %((sfmt,) + (ffmt,)*ny)
    rows = []
    for idx in range(npoints):
        if orig:
            # interleave: dy1, y1, dy2, y2, ...
            repl = (xx[idx],) + \
                tuple(common.flatten([dyy[idx,iy],yy[idx,iy]] for iy in range(ny)))
        else:
            repl = (xx[idx],) + tuple(dyy[idx,iy] for iy in range(ny))
        rows.append(fmtstr %repl)
    return ''.join(rows)
Ejemplo n.º 5
0
 def write_input(self, mode='a', backup=True, sleep=0, excl=True):
     """
     Create calculation dir(s) for each parameter set and write input files
     based on ``templates``. Write sqlite database storing all relevant
     parameters. Write (bash) shell script to start all calculations (run
     locally or submit batch job file, depending on ``machine.subcmd``).

     Parameters
     ----------
     mode : str, optional
         Fine tune how to write input files (based on ``templates``) to calc
         dirs calc_foo/0/, calc_foo/1/, ... . Note that this doesn't change
         the base dir calc_foo at all, only the subdirs for each calc.
         {'a', 'w'}

         | 'a': Append mode (default). If a previous database is found, then
         |     subsequent calculations are numbered based on the last 'idx'.
         |     calc_foo/0 # old
         |     calc_foo/1 # old
         |     calc_foo/2 # new
         |     calc_foo/3 # new
         | 'w': Write mode. The target dirs are purged and overwritten. Also,
         |     the database (self.dbfn) is overwritten. Use this to
         |     iteratively tune your inputs, NOT for working on already
         |     present results!
         |     calc_foo/0 # new
         |     calc_foo/1 # new
     backup : bool, optional
         Before writing anything, do a backup of ``self.dbfn`` and of each
         per-machine calc dir if they already exist.
     sleep : int, optional
         For the script to start (submit) all jobs: time in seconds for the
         shell sleep(1) command.
     excl : bool
         If in append mode, a file <calc_root>/excl_push with all indices of
         calculations from old revisions is written. Can be used with
         ``rsync --exclude-from=excl_push`` when pushing appended new
         calculations to a cluster.

     Raises
     ------
     AssertionError
         If `mode` is not 'a' or 'w'.
     RuntimeError
         In append mode, if the existing table has no 'idx' column.
     """
     assert mode in ['a', 'w'], "Unknown mode: '%s'" %mode
     if os.path.exists(self.dbfn):
         if backup:
             common.backup(self.dbfn)
         if mode == 'w':
             os.remove(self.dbfn)
     have_new_db = not os.path.exists(self.dbfn)
     common.makedirs(self.calc_root)
     # this call creates a file ``self.dbfn`` if it doesn't exist
     sqldb = SQLiteDB(self.dbfn, table=self.db_table)
     # max_idx: counter for calc dir numbering
     revision = 0
     if have_new_db:
         max_idx = -1
     else:
         if mode == 'a':
             if sqldb.has_column('idx'):
                 max_idx = sqldb.execute("select max(idx) from %s" \
                 %self.db_table).fetchone()[0]
                 # max() over a table without rows is NULL -> None. Start
                 # numbering at 0 instead of failing on ``None + int`` below.
                 if max_idx is None:
                     max_idx = -1
             else:
                 # RuntimeError exists in Python 2 and 3 and is a
                 # StandardError subclass in Python 2.
                 raise RuntimeError("database '%s': table '%s' has no "
                       "column 'idx', don't know how to number calcs"
                       %(self.dbfn, self.db_table))
             if sqldb.has_column('revision'):
                 revision = int(sqldb.get_single("select max(revision) \
                     from %s" %self.db_table)) + 1
         elif mode == 'w':
             max_idx = -1
     sql_records = []
     hostnames = []
     for imach, machine in enumerate(self.machines):
         hostnames.append(machine.hostname)
         calc_dir = pj(self.calc_root, self.calc_dir_prefix + \
                       '_%s' %machine.hostname)
         if os.path.exists(calc_dir):
             if backup:
                 common.backup(calc_dir)
             if mode == 'w':
                 # NOTE(review): shell string, a calc_dir containing spaces
                 # or shell meta characters would break this; left as is.
                 common.system("rm -r %s" %calc_dir, wait=True)
         run_txt = "here=$(pwd)\n"
         for _idx, params in enumerate(self.params_lst):
             params = common.flatten(params)
             idx = max_idx + _idx + 1
             calc_subdir = pj(calc_dir, str(idx))
             extra_dct = \
                 {'revision': revision,
                  'study_name': self.study_name,
                  'idx': idx,
                  'calc_name' : self.study_name + "_run%i" %idx,
                  }
             # items()/values() instead of iteritems()/itervalues(): works
             # on Python 2 and 3
             extra_params = [SQLEntry(key=key, sqlval=val) for key,val in \
                             extra_dct.items()]
             # templates[:] to copy b/c they may be modified in Calculation
             calc = Calculation(machine=machine,
                                templates=self.templates[:],
                                params=params + extra_params,
                                calc_dir=calc_subdir,
                                )
             if mode == 'w' and os.path.exists(calc_subdir):
                 shutil.rmtree(calc_subdir)
             calc.write_input()
             run_txt += "cd %i && %s %s && cd $here && sleep %i\n" %(idx,\
                         machine.subcmd, machine.get_jobfile_basename(), sleep)
             # the database gets one record per parameter set, not one per
             # machine: collect them only while handling the first machine
             if imach == 0:
                 sql_records.append(calc.get_sql_record())
         common.file_write(pj(calc_dir, 'run.sh'), run_txt)
     for record in sql_records:
         record['hostname'] = SQLEntry(sqlval=','.join(hostnames))
     # for incomplete parameters: collect header parts from all records and
     # make a set = unique entries
     raw_header = [(key, entry.sqltype.upper()) for record in sql_records \
         for key, entry in record.items()]
     header = list(set(raw_header))
     if have_new_db:
         sqldb.create_table(header)
     else:
         for record in sql_records:
             for key, entry in record.items():
                 if not sqldb.has_column(key):
                     sqldb.add_column(key, entry.sqltype.upper())
     for record in sql_records:
         # fix the key order once, so that column names and values match
         keys = list(record.keys())
         cmd = "insert into %s (%s) values (%s)"\
             %(self.db_table,
               ",".join(keys),
               ",".join(['?']*len(keys)))
         sqldb.execute(cmd, tuple(record[key].sqlval for key in keys))
     if excl and revision > 0 and sqldb.has_column('revision'):
         # Query self.db_table like everywhere else in this method; the
         # table name was hard-coded to 'calc' here before.
         old_idx_lst = [str(x) for x, in sqldb.execute(
             "select idx from %s where revision < ?" %self.db_table,
             (revision,))]
         common.file_write(pj(self.calc_root, 'excl_push'),
                           '\n'.join(old_idx_lst))
     sqldb.finish()
Ejemplo n.º 6
0
def nested_loops(lists, ret_all=False, flatten=False):
    """Nested loops of arbitrary depth, without recursion. Pure Python
    version (no numpy).

    Parameters
    ----------
    lists : list of lists
        The objects to permute. len(lists) == the depth (nesting levels) of the
        equivalent nested loops. Individual lists may contain a mix of
        different types/objects, e.g. [['a', 'b'], [Foo(), Bar(), Baz()],
        [1,2,3,4,5,6,7]]. Entries may also be one-shot iterables such as
        zip(...) in Python 3; each one is converted to a list first.
    ret_all : bool
        True: return perms, perm_idxs
        False: return perms
    flatten : bool
        Flatten each entry in returned list.

    Returns
    -------
    perms : list of lists with permuted objects
    perm_idxs : list of lists with indices of the permutation if ret_all=True

    Notes
    -----
    This is a thin layer on top of itertools.product(), which it predates.

    >>> [x for x in itertools.product([1,2],[33,44,55],[sin,cos])]
    >>> nested_loops([[1,2],[33,44,55],[sin,cos]])

    Note that nested_loops() takes a list of lists, while itertools.product()
    only the lists itself.

    As with itertools.product(), an empty `lists` gives exactly one empty
    permutation (``[[]]``), and any empty entry in `lists` gives no
    permutation at all (``[]``).

    Examples
    --------
    >>> a=[1,2]; b=[3,4]; c=[5,6];
    >>> perms=[]
    >>> for aa in a:
    ....:   for bb in b:
    ....:       for cc in c:
    ....:           perms.append([aa,bb,cc])
    ....:
    >>> perms
    [[1, 3, 5],
     [1, 3, 6],
     [1, 4, 5],
     [1, 4, 6],
     [2, 3, 5],
     [2, 3, 6],
     [2, 4, 5],
     [2, 4, 6]]
    >>> nested_loops([a,b,c], ret_all=True)
    ([[1, 3, 5],
      [1, 3, 6],
      [1, 4, 5],
      [1, 4, 6],
      [2, 3, 5],
      [2, 3, 6],
      [2, 4, 5],
      [2, 4, 6]],
     [[0, 0, 0],
      [0, 0, 1],
      [0, 1, 0],
      [0, 1, 1],
      [1, 0, 0],
      [1, 0, 1],
      [1, 1, 0],
      [1, 1, 1]])
    >>> nested_loops([[1,2], ['a','b','c'], [sin, cos]])
    [[1, 'a', <ufunc 'sin'>],
     [1, 'a', <ufunc 'cos'>],
     [1, 'b', <ufunc 'sin'>],
     [1, 'b', <ufunc 'cos'>],
     [1, 'c', <ufunc 'sin'>],
     [1, 'c', <ufunc 'cos'>],
     [2, 'a', <ufunc 'sin'>],
     [2, 'a', <ufunc 'cos'>],
     [2, 'b', <ufunc 'sin'>],
     [2, 'b', <ufunc 'cos'>],
     [2, 'c', <ufunc 'sin'>],
     [2, 'c', <ufunc 'cos'>]]
    # If values of different lists should be varied together, use zip(). Note
    # that you get nested lists back. Use flatten=True to get flattened lists.
    >>> nested_loops([zip([1,2], ['a', 'b']), [88, 99]])
    [[(1, 'a'), 88], [(1, 'a'), 99], [(2, 'b'), 88], [(2, 'b'), 99]]
    >>> nested_loops([zip([1,2], ['a', 'b']), [88, 99]], flatten=True)
    [[1, 'a', 88], [1, 'a', 99], [2, 'b', 88], [2, 'b', 99]]
    """
    # Local import: the rest of the file may not import itertools.
    import itertools
    # Materialize every level: len() and indexing do not work on iterators
    # (zip() and map() return iterators in Python 3).
    seqs = [list(seq) for seq in lists]
    # product() varies the last index fastest, i.e. the innermost loop, which
    # is the order of the equivalent hand-written nested loops.
    perm_idxs = [list(idxs) for idxs in
                 itertools.product(*[range(len(seq)) for seq in seqs])]
    # idxs[j] is the index into the j-th list
    perms = [[seqs[j][k] for j,k in enumerate(idxs)] for idxs in perm_idxs]
    if flatten:
        perms = [common.flatten(xx) for xx in perms]
    if ret_all:
        return perms, perm_idxs
    else:
        return perms
Ejemplo n.º 7
0
 def get_list1d(self, *args, **kwargs):
     """Run a query and return all fetched rows as one flattened list.

     All arguments are handed to ``self.execute()``. Meant for queries which
     select a single column: there, ``fetchall()`` returns a list of tuples
     like ``[(1,), (2,)]``, which is passed through ``common.flatten()``.
     """
     rows = self.execute(*args, **kwargs).fetchall()
     return common.flatten(rows)
Ejemplo n.º 8
0
    y = np.sin(x)

    fig3, hostax = make_axes_grid_fig(3)
    
    hostax.set_xlabel('hostax bottom')
    hostax.set_ylabel('hostax left')

    # {'left': (off, wsadd),
    # ...}
    off_dct = dict(left=(60, .1), 
                   right=(60, .1), 
                   top=(60, .15), 
                   bottom=(50, .15))

    for n, val in enumerate(off_dct.iteritems()):
        loc, off, wsadd = tuple(flatten(val))
        fig3, hostax, parax = new_axis(fig3, hostax=hostax, 
                                       loc=loc, off=off, label=loc, 
                                       wsadd=wsadd)
        parax.plot(x*n, y**n)
    
    new_axis(fig3, hostax=hostax, loc='right', off=0, wsadd=0, 
             label="hostax right, I'm like twinx()")
    
    new_axis(fig3, hostax=hostax, loc='top', off=0, wsadd=0, 
             label="hostax top, I'm like twiny()")
    

    # many lines 

    fig4, hostax = make_axes_grid_fig(4)
Ejemplo n.º 9
0
def conv_table(xx,
               yy,
               ffmt="%15.4f",
               sfmt="%15s",
               mode='last',
               orig=False,
               absdiff=False):
    """Convergence table. Assume that quantity `xx` was varied, resulting in
    `yy` values. Return a string (table) listing::

        x, dy1, dy2, ...

    Useful for quickly viewing the results of a convergence study, where we
    assume that the sequence of `yy` values converges to a constant value.

    Parameters
    ----------
    xx : 1d sequence
    yy : 1d sequence, nested 1d sequences, 2d array
        Values varied with `xx`. Each row is one parameter.
    ffmt, sfmt : str
        Format strings for floats (`ffmt`) and strings (`sfmt`)
    mode : str
        'next' or 'last'. Difference to the next value ``y[i+1] - y[i]`` or to
        the last ``y[-1] - y[i]``.
    orig : bool
        Print original `yy` data as well.
    absdiff : bool
        absolute values of differences

    Returns
    -------
    str
        One line per entry of `xx`, each terminated by a newline.

    Raises
    ------
    ValueError
        If `mode` is neither 'next' nor 'last'.

    Examples
    --------
    >>> kpoints = ['2 2 2', '4 4 4', '8 8 8']
    >>> etot = [-300.0, -310.0, -312.0]
    >>> forces_rms = [0.3, 0.2, 0.1]
    >>> print(batch.conv_table(kpoints, etot, mode='last'))
          2 2 2       -12.0000
          4 4 4        -2.0000
          8 8 8         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last'))
          2 2 2       -12.0000        -0.2000
          4 4 4        -2.0000        -0.1000
          8 8 8         0.0000         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last', orig=True))
          2 2 2       -12.0000      -300.0000        -0.2000         0.3000
          4 4 4        -2.0000      -310.0000        -0.1000         0.2000
          8 8 8         0.0000      -312.0000         0.0000         0.1000
    >>> print(batch.conv_table(kpoints, np.array([etot,forces_rms]), mode='next'))
          2 2 2       -10.0000        -0.1000
          4 4 4        -2.0000        -0.1000
          8 8 8         0.0000         0.0000
    """
    # Validate once, before any work, and raise a specific exception type.
    # ValueError is an Exception subclass, so ``except Exception`` in calling
    # code keeps working.
    if mode not in ('next', 'last'):
        raise ValueError("unknown mode")
    npoints = len(xx)
    yy = np.asarray(yy).copy()
    # Bring yy into shape (npoints, ny): one column per quantity.
    if yy.ndim == 1:
        yy = yy[:, None]
    else:
        yy = yy.T
    ny = yy.shape[1]
    # Differences along the points axis, for all columns at once.
    if mode == 'next':
        dyy = np.empty_like(yy)
        # there is no "next" value for the last point
        dyy[-1, :] = 0.0
        dyy[:-1, :] = np.diff(yy, axis=0)
    else:
        dyy = yy[-1, :] - yy
    if absdiff:
        dyy = np.abs(dyy)
    # One string column for x, then one float column per difference (and one
    # more per original value if orig=True).
    if orig:
        fmtstr = ("%s" * (2 * ny + 1) + "\n") % ((sfmt, ) + (ffmt, ) * 2 * ny)
    else:
        fmtstr = ("%s" * (ny + 1) + "\n") % ((sfmt, ) + (ffmt, ) * ny)
    rows = []
    for idx in range(npoints):
        if orig:
            # interleave: dy1, y1, dy2, y2, ...
            repl = (xx[idx],) + \
                tuple(common.flatten([dyy[idx,iy],yy[idx,iy]] for iy in range(ny)))
        else:
            repl = (xx[idx], ) + tuple(dyy[idx, iy] for iy in range(ny))
        rows.append(fmtstr % repl)
    return ''.join(rows)
Ejemplo n.º 10
0
 def write_input(self, mode='a', backup=True, sleep=0, excl=True):
     """
     Create calculation dir(s) for each parameter set and write input files
     based on ``templates``. Write sqlite database storing all relevant
     parameters. Write (bash) shell script to start all calculations (run
     locally or submit batch job file, depending on ``machine.subcmd``).

     Parameters
     ----------
     mode : str, optional
         Fine tune how to write input files (based on ``templates``) to calc
         dirs calc_foo/0/, calc_foo/1/, ... . Note that this doesn't change
         the base dir calc_foo at all, only the subdirs for each calc.
         {'a', 'w'}

         | 'a': Append mode (default). If a previous database is found, then
         |     subsequent calculations are numbered based on the last 'idx'.
         |     calc_foo/0 # old
         |     calc_foo/1 # old
         |     calc_foo/2 # new
         |     calc_foo/3 # new
         | 'w': Write mode. The target dirs are purged and overwritten. Also,
         |     the database (self.dbfn) is overwritten. Use this to
         |     iteratively tune your inputs, NOT for working on already
         |     present results!
         |     calc_foo/0 # new
         |     calc_foo/1 # new
     backup : bool, optional
         Before writing anything, do a backup of ``self.dbfn`` and of each
         per-machine calc dir if they already exist.
     sleep : int, optional
         For the script to start (submit) all jobs: time in seconds for the
         shell sleep(1) command.
     excl : bool
         If in append mode, a file <calc_root>/excl_push with all indices of
         calculations from old revisions is written. Can be used with
         ``rsync --exclude-from=excl_push`` when pushing appended new
         calculations to a cluster.

     Raises
     ------
     AssertionError
         If `mode` is not 'a' or 'w'.
     RuntimeError
         In append mode, if the existing table has no 'idx' column.
     """
     assert mode in ['a', 'w'], "Unknown mode: '%s'" % mode
     if os.path.exists(self.dbfn):
         if backup:
             common.backup(self.dbfn)
         if mode == 'w':
             os.remove(self.dbfn)
     have_new_db = not os.path.exists(self.dbfn)
     common.makedirs(self.calc_root)
     # this call creates a file ``self.dbfn`` if it doesn't exist
     sqldb = SQLiteDB(self.dbfn, table=self.db_table)
     # max_idx: counter for calc dir numbering
     revision = 0
     if have_new_db:
         max_idx = -1
     else:
         if mode == 'a':
             if sqldb.has_column('idx'):
                 max_idx = sqldb.execute("select max(idx) from %s" \
                 %self.db_table).fetchone()[0]
                 # max() over a table without rows is NULL -> None. Start
                 # numbering at 0 instead of failing on ``None + int`` below.
                 if max_idx is None:
                     max_idx = -1
             else:
                 # RuntimeError is an Exception subclass, so callers which
                 # catch Exception are not affected.
                 raise RuntimeError(
                     "database '%s': table '%s' has no "
                     "column 'idx', don't know how to number calcs" %
                     (self.dbfn, self.db_table))
             if sqldb.has_column('revision'):
                 revision = int(
                     sqldb.get_single("select max(revision) \
                     from %s" % self.db_table)) + 1
         elif mode == 'w':
             max_idx = -1
     sql_records = []
     hostnames = []
     for imach, machine in enumerate(self.machines):
         hostnames.append(machine.hostname)
         calc_dir = pj(self.calc_root, self.calc_dir_prefix + \
                       '_%s' %machine.hostname)
         if os.path.exists(calc_dir):
             if backup:
                 common.backup(calc_dir)
             if mode == 'w':
                 # NOTE(review): shell string, a calc_dir containing spaces
                 # or shell meta characters would break this; left as is.
                 common.system("rm -r %s" % calc_dir, wait=True)
         run_txt = "here=$(pwd)\n"
         for _idx, params in enumerate(self.params_lst):
             params = common.flatten(params)
             idx = max_idx + _idx + 1
             calc_subdir = pj(calc_dir, str(idx))
             extra_dct = \
                 {'revision': revision,
                  'study_name': self.study_name,
                  'idx': idx,
                  'calc_name' : self.study_name + "_run%i" %idx,
                  }
             extra_params = [SQLEntry(key=key, sqlval=val) for key,val in \
                             extra_dct.items()]
             # templates[:] to copy b/c they may be modified in Calculation
             calc = Calculation(
                 machine=machine,
                 templates=self.templates[:],
                 params=params + extra_params,
                 calc_dir=calc_subdir,
             )
             if mode == 'w' and os.path.exists(calc_subdir):
                 shutil.rmtree(calc_subdir)
             calc.write_input()
             run_txt += "cd %i && %s %s && cd $here && sleep %i\n" %(idx,\
                         machine.subcmd, machine.get_jobfile_basename(), sleep)
             # the database gets one record per parameter set, not one per
             # machine: collect them only while handling the first machine
             if imach == 0:
                 sql_records.append(calc.get_sql_record())
         common.file_write(pj(calc_dir, 'run.sh'), run_txt)
     for record in sql_records:
         record['hostname'] = SQLEntry(sqlval=','.join(hostnames))
     # for incomplete parameters: collect header parts from all records and
     # make a set = unique entries
     raw_header = [(key, entry.sqltype.upper()) for record in sql_records \
         for key, entry in record.items()]
     header = list(set(raw_header))
     if have_new_db:
         sqldb.create_table(header)
     else:
         for record in sql_records:
             for key, entry in record.items():
                 if not sqldb.has_column(key):
                     sqldb.add_column(key, entry.sqltype.upper())
     for record in sql_records:
         # fix the key order once, so that column names and values match
         keys = list(record.keys())
         cmd = "insert into %s (%s) values (%s)"\
             %(self.db_table,
               ",".join(keys),
               ",".join(['?']*len(keys)))
         sqldb.execute(cmd, tuple(record[key].sqlval for key in keys))
     if excl and revision > 0 and sqldb.has_column('revision'):
         # Query self.db_table like everywhere else in this method; the
         # table name was hard-coded to 'calc' here before.
         old_idx_lst = [
             str(x) for x, in sqldb.execute(
                 "select idx from %s where revision < ?" % self.db_table,
                 (revision, ))
         ]
         common.file_write(pj(self.calc_root, 'excl_push'),
                           '\n'.join(old_idx_lst))
     sqldb.finish()
Ejemplo n.º 11
0
 def get_list1d(self, *args, **kwargs):
     """Convenience wrapper: execute a query, fetch everything, flatten it.

     A query for a single column makes ``fetchall()`` return a list of
     tuples like ``[(1,), (2,)]``. Here, that result is handed to
     ``common.flatten()`` and returned. All arguments go to
     ``self.execute()`` unchanged.
     """
     cursor = self.execute(*args, **kwargs)
     fetched = cursor.fetchall()
     return common.flatten(fetched)
Ejemplo n.º 12
0
def test_scell():
    """Test crys.scell() with Structure and Trajectory input.

    Covers: a 2x2x2 supercell of a cubic and of a non-orthorhombic cell,
    supercells extended along only one axis, ``symbols=None``, Trajectory
    input with 2d and 3d `cell`, and agreement of ``method=1`` and
    ``method=2``.
    """
    cell = np.identity(3)
    coords_frac = np.array([[0.5, 0.5, 0.5],
                            [1,1,1]])
    symbols = ['Al', 'N']
    sc = crys.scell(Structure(coords_frac=coords_frac,
                              cell=cell,
                              symbols=symbols), (2,2,2))

    sc_coords_frac = \
        np.array([[ 0.25,  0.25,  0.25],
                  [ 0.25,  0.25,  0.75],
                  [ 0.25,  0.75,  0.25],
                  [ 0.25,  0.75,  0.75],
                  [ 0.75,  0.25,  0.25],
                  [ 0.75,  0.25,  0.75],
                  [ 0.75,  0.75,  0.25],
                  [ 0.75,  0.75,  0.75],
                  [ 0.5 ,  0.5 ,  0.5 ],
                  [ 0.5 ,  0.5 ,  1.  ],
                  [ 0.5 ,  1.  ,  0.5 ],
                  [ 0.5 ,  1.  ,  1.  ],
                  [ 1.  ,  0.5 ,  0.5 ],
                  [ 1.  ,  0.5 ,  1.  ],
                  [ 1.  ,  1.  ,  0.5 ],
                  [ 1.  ,  1.  ,  1.  ]])

    sc_symbols = ['Al']*8 + ['N']*8
    sc_cell = \
        np.array([[ 2.,  0.,  0.],
                  [ 0.,  2.,  0.],
                  [ 0.,  0.,  2.]])

    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)

    # non-orthorhombic cell
    cell = \
        np.array([[ 1.,  0.5,  0.5],
                  [ 0.25,  1.,  0.2],
                  [ 0.2,  0.5,  1.]])

    sc = crys.scell(Structure(coords_frac=coords_frac,
                              cell=cell,
                              symbols=symbols), (2,2,2))
    sc_cell = \
        np.array([[ 2. ,  1. ,  1. ],
                  [ 0.5,  2. ,  0.4],
                  [ 0.4,  1. ,  2. ]])
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    # crystal coords are cell-independent
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)


    # slab
    #
    # Test if old and new implementation behave for a tricky case: natoms == 2
    # mask.shape[0], i.e. if reshape() behaves correctly.
    # Reference generated with old implementation. Default is new.
    cell = np.identity(3)
    coords_frac = np.array([[0.5, 0.5, 0.5],
                            [1,1,1]])
    symbols = ['Al', 'N']
    sc = crys.scell(Structure(coords_frac=coords_frac,
                              cell=cell,
                              symbols=symbols), (1,1,2))
    sc_coords_frac = \
        np.array([[ 0.5 ,  0.5 ,  0.25],
                  [ 0.5 ,  0.5 ,  0.75],
                  [ 1.  ,  1.  ,  0.5 ],
                  [ 1.  ,  1.  ,  1.  ]])
    sc_cell = \
        np.array([[ 1.,  0.,  0.],
                  [ 0.,  1.,  0.],
                  [ 0.,  0.,  2.]])
    sc_symbols = ['Al', 'Al', 'N', 'N']
    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)

    sc = crys.scell(Structure(coords_frac=coords_frac,
                              cell=cell,
                              symbols=symbols), (1,2,1))
    sc_coords_frac = \
        np.array([[ 0.5 ,  0.25,  0.5 ],
                  [ 0.5 ,  0.75,  0.5 ],
                  [ 1.  ,  0.5 ,  1.  ],
                  [ 1.  ,  1.  ,  1.  ]])
    sc_cell = \
        np.array([[ 1.,  0.,  0.],
                  [ 0.,  2.,  0.],
                  [ 0.,  0.,  1.]])
    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)

    sc = crys.scell(Structure(coords_frac=coords_frac,
                              cell=cell,
                              symbols=symbols), (2,1,1))
    sc_coords_frac = \
        np.array([[ 0.25,  0.5 ,  0.5 ],
                  [ 0.75,  0.5 ,  0.5 ],
                  [ 0.5 ,  1.  ,  1.  ],
                  [ 1.  ,  1.  ,  1.  ]])
    sc_cell = \
        np.array([[ 2.,  0.,  0.],
                  [ 0.,  1.,  0.],
                  [ 0.,  0.,  1.]])
    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)

    # symbols = None
    sc = crys.scell(Structure(coords_frac=coords_frac,
                              cell=cell,
                              symbols=None), (2,2,2))
    assert sc.symbols is None

    # Trajectory
    natoms = 4
    nstep = 100
    # next(syms) instead of syms.next(): the builtin works on Python 2.6+ and
    # on Python 3, where iterators have no next() method.
    symbols = [next(syms) for ii in range(natoms)]
    # cell 2d
    coords_frac = rand(nstep,natoms,3)
    cell = rand(3,3)
    dims = (2,3,4)
    # number of copies of the unit cell in the supercell
    nmask = np.prod(dims)
    sc = crys.scell(Trajectory(coords_frac=coords_frac,
                               cell=cell,
                               symbols=symbols),
                    dims=dims)
    assert sc.coords_frac.shape == (nstep, nmask*natoms, 3)
    # each symbol is expected nmask times in a row
    assert sc.symbols == common.flatten([[sym]*nmask for sym in symbols])
    assert sc.cell.shape == (nstep,3,3)
    # cell row i is scaled by dims[i]; the 2d cell is the same for all steps
    np.testing.assert_array_almost_equal(sc.cell,
                                         num.extend_array(cell * np.asarray(dims)[:,None],
                                                          sc.nstep,
                                                          axis=0))
    # cell 3d
    cell = rand(nstep,3,3)
    sc = crys.scell(Trajectory(coords_frac=coords_frac,
                               cell=cell,
                               symbols=symbols),
                    dims=dims)
    assert sc.coords_frac.shape == (nstep, nmask*natoms, 3)
    # reference: build the supercell for each time step on its own
    coords_frac2 = np.array([crys.scell(Structure(coords_frac=coords_frac[ii,...,],
                                                  cell=cell[ii,...],
                                                  symbols=symbols), dims=dims).coords_frac \
                             for ii in range(nstep)])
    np.testing.assert_array_almost_equal(sc.coords_frac, coords_frac2)
    assert sc.symbols == common.flatten([[sym]*nmask for sym in symbols])
    assert sc.cell.shape == (nstep,3,3)
    np.testing.assert_array_almost_equal(sc.cell,
                                         cell * np.asarray(dims)[None,:,None])

    # methods: both implementations must give the same attributes
    natoms = 20
    coords_frac = rand(natoms,3)
    cell = rand(3,3)
    dims = (2,3,4)
    symbols = [next(syms) for ii in range(natoms)]
    struct = Structure(coords_frac=coords_frac,
                       cell=cell,
                       symbols=symbols)
    sc1 = crys.scell(struct, dims=dims, method=1)
    sc2 = crys.scell(struct, dims=dims, method=2)
    d1 = dict((key, getattr(sc1, key)) for key in sc1.attr_lst)
    d2 = dict((key, getattr(sc2, key)) for key in sc2.attr_lst)
    tools.assert_dict_with_all_types_almost_equal(d1, d2)
Ejemplo n.º 13
0
def test_scell():
    """Exercise ``crys.scell`` on Structure and Trajectory inputs.

    Sections, in order:

    * 2-atom Structure in a unit cube, replicated (2, 2, 2)
    * same atoms in a non-orthorhombic cell
    * slab supercells where only one axis is doubled
    * ``symbols=None`` is passed through
    * Trajectory input with a 2d and a 3d ``cell``
    * ``method=1`` and ``method=2`` give almost-equal attributes
    """
    base_coords = np.array([[0.5, 0.5, 0.5], [1, 1, 1]])
    base_symbols = ['Al', 'N']
    unit_cell = np.identity(3)

    # --- unit cube, doubled along every axis --------------------------------
    supercell = crys.scell(
        Structure(coords_frac=base_coords,
                  cell=unit_cell,
                  symbols=base_symbols),
        (2, 2, 2))

    # Reference fractional coords: for each atom, every combination of the
    # two values per axis, with the last axis varying fastest. 8 rows for
    # 'Al' (0.25 / 0.75) followed by 8 rows for 'N' (0.5 / 1.0).
    ref_coords_222 = np.array(
        [[x, y, z]
         for lo, hi in ((0.25, 0.75), (0.5, 1.0))
         for x in (lo, hi)
         for y in (lo, hi)
         for z in (lo, hi)])

    assert supercell.symbols == ['Al'] * 8 + ['N'] * 8
    np.testing.assert_array_almost_equal(supercell.coords_frac,
                                         ref_coords_222)
    np.testing.assert_array_almost_equal(supercell.cell,
                                         np.diag([2.0, 2.0, 2.0]))

    # --- non-orthorhombic cell ----------------------------------------------
    skewed_cell = np.array([[1.0, 0.5, 0.5],
                            [0.25, 1.0, 0.2],
                            [0.2, 0.5, 1.0]])
    supercell = crys.scell(
        Structure(coords_frac=base_coords,
                  cell=skewed_cell,
                  symbols=base_symbols),
        (2, 2, 2))
    np.testing.assert_array_almost_equal(
        supercell.cell,
        np.array([[2.0, 1.0, 1.0],
                  [0.5, 2.0, 0.4],
                  [0.4, 1.0, 2.0]]))
    # Fractional (crystal) coords do not depend on the cell, so the reference
    # from the cubic case must still match.
    np.testing.assert_array_almost_equal(supercell.coords_frac,
                                         ref_coords_222)

    # --- slabs ----------------------------------------------------------------
    # Tricky case: the number of atoms (2) equals the number of cell copies
    # (2). Per the original note this checks that reshape() inside scell
    # behaves correctly; the reference values were generated with the old
    # implementation while the default method is the new one.
    slab_cases = (
        ((1, 1, 2), [[0.5, 0.5, 0.25],
                     [0.5, 0.5, 0.75],
                     [1.0, 1.0, 0.5],
                     [1.0, 1.0, 1.0]]),
        ((1, 2, 1), [[0.5, 0.25, 0.5],
                     [0.5, 0.75, 0.5],
                     [1.0, 0.5, 1.0],
                     [1.0, 1.0, 1.0]]),
        ((2, 1, 1), [[0.25, 0.5, 0.5],
                     [0.75, 0.5, 0.5],
                     [0.5, 1.0, 1.0],
                     [1.0, 1.0, 1.0]]),
    )
    for slab_dims, slab_coords in slab_cases:
        slab = crys.scell(
            Structure(coords_frac=base_coords,
                      cell=unit_cell,
                      symbols=base_symbols),
            slab_dims)
        assert slab.symbols == ['Al', 'Al', 'N', 'N']
        # Unit cube scaled per axis -> diagonal matrix holding the dims.
        np.testing.assert_array_almost_equal(
            slab.cell,
            np.diag([float(d) for d in slab_dims]))
        np.testing.assert_array_almost_equal(slab.coords_frac,
                                             np.array(slab_coords))

    # --- symbols = None -------------------------------------------------------
    anonymous = crys.scell(
        Structure(coords_frac=base_coords, cell=unit_cell, symbols=None),
        (2, 2, 2))
    assert anonymous.symbols is None

    # --- Trajectory -----------------------------------------------------------
    natoms = 4
    nstep = 100
    traj_symbols = [next(syms) for _ in range(natoms)]
    dims = (2, 3, 4)
    nmask = np.prod(dims)
    # Each symbol is repeated once per cell copy, atom by atom.
    ref_symbols = common.flatten([[sym] * nmask for sym in traj_symbols])
    ref_coords_shape = (nstep, nmask * natoms, 3)

    # 2d cell, i.e. one cell for all steps
    traj_coords = rand(nstep, natoms, 3)
    cell_2d = rand(3, 3)
    traj_sc = crys.scell(
        Trajectory(coords_frac=traj_coords,
                   cell=cell_2d,
                   symbols=traj_symbols),
        dims=dims)
    assert traj_sc.coords_frac.shape == ref_coords_shape
    assert traj_sc.symbols == ref_symbols
    assert traj_sc.cell.shape == (nstep, 3, 3)
    scaled_cell_2d = cell_2d * np.asarray(dims)[:, None]
    np.testing.assert_array_almost_equal(
        traj_sc.cell,
        num.extend_array(scaled_cell_2d, traj_sc.nstep, axis=0))

    # 3d cell, i.e. one cell per step
    cell_3d = rand(nstep, 3, 3)
    traj_sc = crys.scell(
        Trajectory(coords_frac=traj_coords,
                   cell=cell_3d,
                   symbols=traj_symbols),
        dims=dims)
    assert traj_sc.coords_frac.shape == ref_coords_shape
    # Building the supercell step by step from Structures must give the same
    # coords as the Trajectory call.
    stepwise_coords = np.array([
        crys.scell(
            Structure(coords_frac=step_coords,
                      cell=step_cell,
                      symbols=traj_symbols),
            dims=dims).coords_frac
        for step_coords, step_cell in zip(traj_coords, cell_3d)])
    np.testing.assert_array_almost_equal(traj_sc.coords_frac,
                                         stepwise_coords)
    assert traj_sc.symbols == ref_symbols
    assert traj_sc.cell.shape == (nstep, 3, 3)
    np.testing.assert_array_almost_equal(
        traj_sc.cell,
        cell_3d * np.asarray(dims)[None, :, None])

    # --- methods --------------------------------------------------------------
    natoms = 20
    big_coords = rand(natoms, 3)
    big_cell = rand(3, 3)
    dims = (2, 3, 4)
    big_symbols = [next(syms) for _ in range(natoms)]
    struct = Structure(coords_frac=big_coords,
                       cell=big_cell,
                       symbols=big_symbols)
    by_method = [crys.scell(struct, dims=dims, method=method)
                 for method in (1, 2)]
    # Compare every attribute listed in attr_lst between the two methods.
    attrs_m1, attrs_m2 = [{key: getattr(sc, key) for key in sc.attr_lst}
                          for sc in by_method]
    tools.assert_dict_with_all_types_almost_equal(attrs_m1, attrs_m2)