def collect_legends(*axs):
    """
    Collect legend data from multiple axes, return input for legend().

    Parameters
    ----------
    *axs : axes objects
        Anything with a ``get_legend_handles_labels()`` method returning a
        ``(handles, labels)`` pair of sequences (e.g. matplotlib Axes).

    Returns
    -------
    handles, labels : list, list
        Handles and labels of all axes, concatenated in the order of `axs`.
        Both are empty lists if no axes are given.

    Examples
    --------
    >>> fig, ax = mpl.fig_ax()
    >>> ax.plot([1,2,3], label='ax')
    >>> ax2 = ax.twinx()
    >>> ax2.plot([3,2,1], 'r', label='ax2')
    >>> ax.legend(*mpl.collect_legends(ax, ax2))
    """
    handles = []
    labels = []
    for ax in axs:
        # Concatenate exactly one level: each handle is kept as one object
        # even if it is itself iterable (e.g. tuple-based artist
        # containers), so handles and labels stay aligned 1:1.
        ax_handles, ax_labels = ax.get_legend_handles_labels()
        handles.extend(ax_handles)
        labels.extend(ax_labels)
    return handles, labels
def collect_legends(self, axnames=('ax',)):
    """If self has more then one axis object attached, then collect legends
    from all axes specified in axnames. Useful for handling legend entries
    of lines on different axes (in case of twinx, for instance).

    Parameters
    ----------
    axnames : sequence of strings
        Attribute names of ``self`` that hold axis objects, i.e.
        ``getattr(self, name)`` must have ``get_legend_handles_labels()``.
        Default: ``('ax',)``.

    Returns
    -------
    tuple of lines and labels
        ([line1, line2, ...], ['foo', 'bar', ...])
        where lines and labels are taken from all axes, in the order of
        `axnames`. Use this as input for any axis's legend() method. Both
        lists are empty if `axnames` is empty.
    """
    handles = []
    labels = []
    for axname in axnames:
        ax_handles, ax_labels = getattr(self, axname).get_legend_handles_labels()
        # one-level concatenation keeps each handle intact so that handles
        # and labels stay aligned 1:1
        handles.extend(ax_handles)
        labels.extend(ax_labels)
    return handles, labels
y = np.sin(x)
fig3, hostax = make_axes_grid_fig(3)
hostax.set_xlabel('hostax bottom')
hostax.set_ylabel('hostax left')

# location -> (off, wsadd), forwarded to new_axis() below:
# {'left': (off, wsadd),
#  ...}
off_dct = dict(left=(60, .1),
               right=(60, .1),
               top=(60, .15),
               bottom=(50, .15))

# One extra parasite axis per location, each showing a different curve.
for n, item in enumerate(off_dct.items()):
    loc, off, wsadd = flatten(item)
    fig3, hostax, parax = new_axis(fig3,
                                   hostax=hostax,
                                   loc=loc,
                                   off=off,
                                   label=loc,
                                   wsadd=wsadd)
    parax.plot(x * n, y**n)

# zero offset on the right: behaves like twinx()
new_axis(fig3,
         hostax=hostax,
         loc='right',
         off=0,
         wsadd=0,
         label="hostax right, I'm like twinx()")
def conv_table(xx, yy, ffmt="%15.4f", sfmt="%15s", mode='last', orig=False,
               absdiff=False):
    """Convergence table.

    Assume that quantity `xx` was varied, resulting in `yy` values. Return a
    string (table) listing::

        x, dy1, dy2, ...

    Useful for quickly viewing the results of a convergence study, where we
    assume that the sequence of `yy` values converges to a constant value.

    Parameters
    ----------
    xx : 1d sequence
        Values of the varied quantity, one table row each (formatted with
        `sfmt`).
    yy : 1d sequence, nested 1d sequences, 2d array
        Values varied with `xx`. Each row is one parameter.
    ffmt, sfmt : str
        Format strings for floats (`ffmt`) and strings (`sfmt`)
    mode : str
        'next' or 'last'. Difference to the next value ``y[i+1] - y[i]`` or
        to the last ``y[-1] - y[i]``. In 'next' mode the last row is 0.
    orig : bool
        Print original `yy` data as well (each difference column is followed
        by the corresponding original value).
    absdiff : bool
        absolute values of differences

    Returns
    -------
    str
        One newline-terminated line per entry in `xx`.

    Raises
    ------
    ValueError
        If `mode` is not 'next' or 'last'.

    Examples
    --------
    >>> kpoints = ['2 2 2', '4 4 4', '8 8 8']
    >>> etot = [-300.0, -310.0, -312.0]
    >>> forces_rms = [0.3, 0.2, 0.1]
    >>> print(batch.conv_table(kpoints, etot, mode='last'))
              2 2 2       -12.0000
              4 4 4        -2.0000
              8 8 8         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last'))
              2 2 2       -12.0000        -0.2000
              4 4 4        -2.0000        -0.1000
              8 8 8         0.0000         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last', orig=True))
              2 2 2       -12.0000      -300.0000        -0.2000         0.3000
              4 4 4        -2.0000      -310.0000        -0.1000         0.2000
              8 8 8         0.0000      -312.0000         0.0000         0.1000
    >>> print(batch.conv_table(kpoints, np.array([etot,forces_rms]), mode='next'))
              2 2 2       -10.0000        -0.1000
              4 4 4        -2.0000        -0.1000
              8 8 8         0.0000         0.0000
    """
    if mode not in ('next', 'last'):
        raise ValueError("unknown mode")
    npoints = len(xx)
    yy = np.asarray(yy).copy()
    # internal layout: rows = points in `xx`, columns = parameters
    if yy.ndim == 1:
        yy = yy[:, None]
    else:
        yy = yy.T
    ny = yy.shape[1]
    if mode == 'next':
        # the last point has no successor -> difference 0
        dyy = np.zeros_like(yy)
        dyy[:-1, :] = np.diff(yy, axis=0)
    else:
        # yy[-1:] (not yy[-1]) also works for empty input
        dyy = yy[-1:, :] - yy
    if absdiff:
        dyy = np.abs(dyy)
    ncols = 2 * ny if orig else ny
    fmtstr = sfmt + ffmt * ncols + "\n"
    lines = []
    for idx in range(npoints):
        row = [xx[idx]]
        for iy in range(ny):
            row.append(dyy[idx, iy])
            if orig:
                row.append(yy[idx, iy])
        lines.append(fmtstr % tuple(row))
    return ''.join(lines)
def write_input(self, mode='a', backup=True, sleep=0, excl=True):
    """
    Create calculation dir(s) for each parameter set and write input files
    based on ``templates``. Write sqlite database storing all relevant
    parameters. Write (bash) shell script to start all calculations (run
    locally or submit batch job file, depending on ``machine.subcmd``).

    Parameters
    ----------
    mode : str, optional
        Fine tune how to write input files (based on ``templates``) to calc
        dirs calc_foo/0/, calc_foo/1/, ... . Note that this doesn't change the
        base dir calc_foo at all, only the subdirs for each calc.
        {'a', 'w'}

        | 'a': Append mode (default). If a previous database is found, then
        |     subsequent calculations are numbered based on the last 'idx'.
        |     calc_foo/0 # old
        |     calc_foo/1 # old
        |     calc_foo/2 # new
        |     calc_foo/3 # new
        | 'w': Write mode. The target dirs are purged and overwritten. Also,
        |     the database (self.dbfn) is overwritten. Use this to
        |     iteratively tune your inputs, NOT for working on already
        |     present results!
        |     calc_foo/0 # new
        |     calc_foo/1 # new
    backup : bool, optional
        Before writing anything, do a backup of the database and of each
        per-machine calc dir if they already exist.
    sleep : int, optional
        For the script to start (submit) all jobs: time in seconds for the
        shell sleep(1) command.
    excl : bool
        If in append mode, a file <calc_root>/excl_push with all indices of
        calculations from old revisions is written. Can be used with
        ``rsync --exclude-from=excl_push`` when pushing appended new
        calculations to a cluster.

    Raises
    ------
    RuntimeError
        Append mode on an existing database whose table has no 'idx' column.
    """
    assert mode in ['a', 'w'], "Unknown mode: '%s'" % mode
    if os.path.exists(self.dbfn):
        if backup:
            common.backup(self.dbfn)
        if mode == 'w':
            os.remove(self.dbfn)
    have_new_db = not os.path.exists(self.dbfn)
    common.makedirs(self.calc_root)
    # this call creates a file ``self.dbfn`` if it doesn't exist
    sqldb = SQLiteDB(self.dbfn, table=self.db_table)
    # max_idx: counter for calc dir numbering, new calcs start at max_idx + 1
    revision = 0
    if have_new_db or mode == 'w':
        max_idx = -1
    else:
        if sqldb.has_column('idx'):
            max_idx = sqldb.execute("select max(idx) from %s"
                                    % self.db_table).fetchone()[0]
            # max() over an empty table is NULL -> start numbering at 0
            if max_idx is None:
                max_idx = -1
        else:
            raise RuntimeError("database '%s': table '%s' has no "
                               "column 'idx', don't know how to number calcs"
                               % (self.dbfn, self.db_table))
        if sqldb.has_column('revision'):
            last_revision = sqldb.get_single("select max(revision) from %s"
                                             % self.db_table)
            if last_revision is not None:
                revision = int(last_revision) + 1
    sql_records = []
    hostnames = []
    for imach, machine in enumerate(self.machines):
        hostnames.append(machine.hostname)
        calc_dir = pj(self.calc_root,
                      self.calc_dir_prefix + '_%s' % machine.hostname)
        if os.path.exists(calc_dir):
            if backup:
                common.backup(calc_dir)
            if mode == 'w':
                # no shell: paths with spaces or special characters are safe
                shutil.rmtree(calc_dir)
        run_txt = "here=$(pwd)\n"
        for _idx, params in enumerate(self.params_lst):
            params = common.flatten(params)
            idx = max_idx + _idx + 1
            calc_subdir = pj(calc_dir, str(idx))
            extra_dct = {
                'revision': revision,
                'study_name': self.study_name,
                'idx': idx,
                'calc_name': self.study_name + "_run%i" % idx,
            }
            extra_params = [SQLEntry(key=key, sqlval=val)
                            for key, val in extra_dct.items()]
            # templates[:] to copy b/c they may be modified in Calculation
            calc = Calculation(machine=machine,
                               templates=self.templates[:],
                               params=params + extra_params,
                               calc_dir=calc_subdir,
                               )
            if mode == 'w' and os.path.exists(calc_subdir):
                shutil.rmtree(calc_subdir)
            calc.write_input()
            run_txt += "cd %i && %s %s && cd $here && sleep %i\n" % (
                idx, machine.subcmd, machine.get_jobfile_basename(), sleep)
            # records are collected only once (first machine); the list of
            # all machines goes into the 'hostname' column below
            if imach == 0:
                sql_records.append(calc.get_sql_record())
        common.file_write(pj(calc_dir, 'run.sh'), run_txt)
    for record in sql_records:
        record['hostname'] = SQLEntry(sqlval=','.join(hostnames))
    # for incomplete parameters: collect header parts from all records and
    # make a set = unique entries
    raw_header = [(key, entry.sqltype.upper())
                  for record in sql_records
                  for key, entry in record.items()]
    header = list(set(raw_header))
    if have_new_db:
        sqldb.create_table(header)
    else:
        for record in sql_records:
            for key, entry in record.items():
                if not sqldb.has_column(key):
                    sqldb.add_column(key, entry.sqltype.upper())
    for record in sql_records:
        # take keys and values from one items() pass so they line up
        items = list(record.items())
        keys = [key for key, _ in items]
        cmd = "insert into %s (%s) values (%s)" % (
            self.db_table, ",".join(keys), ",".join(['?'] * len(keys)))
        sqldb.execute(cmd, tuple(entry.sqlval for _, entry in items))
    if excl and revision > 0 and sqldb.has_column('revision'):
        # query this study's table (self.db_table), not a fixed table name
        old_idx_lst = [str(x) for x, in sqldb.execute(
            "select idx from %s where revision < ?" % self.db_table,
            (revision,))]
        common.file_write(pj(self.calc_root, 'excl_push'),
                          '\n'.join(old_idx_lst))
    sqldb.finish()
def nested_loops(lists, ret_all=False, flatten=False):
    """Nonrecursive version of nested loops of arbitrary depth.

    Pure Python version (no numpy), implemented with itertools.product().

    Parameters
    ----------
    lists : list of lists
        The objects to permute. len(lists) == the depth (nesting levels) of
        the equivalent nested loops. Individual lists may contain a mix of
        different types/objects, e.g. [['a', 'b'], [Foo(), Bar(), Baz()],
        [1,2,3,4,5,6,7]]. Entries without a length (iterators such as
        ``zip(...)`` objects in Python 3) are converted to lists first.
    ret_all : bool
        True: return perms, perm_idxs
        False: return perms
    flatten : bool
        Flatten each entry in returned list.

    Returns
    -------
    perms : list of lists with permuted objects
    perm_idxs : list of lists with indices of the permutation if ret_all=True

    Notes
    -----
    The order is that of the equivalent hand-written nested loops (the last
    list varies fastest), i.e. the same as itertools.product()::

        >>> [x for x in itertools.product([1,2],[33,44,55],[sin,cos])]
        >>> nested_loops([[1,2],[33,44,55],[sin,cos]])

    Note that nested_loops() takes a list of lists, while itertools.product()
    takes the lists as separate arguments.

    An empty `lists` gives ``[[]]`` (one empty permutation, as
    itertools.product()). If any individual list is empty, the result is
    ``[]``.

    Examples
    --------
    >>> a=[1,2]; b=[3,4]; c=[5,6];
    >>> perms=[]
    >>> for aa in a:
    ....:     for bb in b:
    ....:         for cc in c:
    ....:             perms.append([aa,bb,cc])
    ....:
    >>> perms
    [[1, 3, 5],
     [1, 3, 6],
     [1, 4, 5],
     [1, 4, 6],
     [2, 3, 5],
     [2, 3, 6],
     [2, 4, 5],
     [2, 4, 6]]
    >>> nested_loops([a,b,c], ret_all=True)
    ([[1, 3, 5],
      [1, 3, 6],
      [1, 4, 5],
      [1, 4, 6],
      [2, 3, 5],
      [2, 3, 6],
      [2, 4, 5],
      [2, 4, 6]],
     [[0, 0, 0],
      [0, 0, 1],
      [0, 1, 0],
      [0, 1, 1],
      [1, 0, 0],
      [1, 0, 1],
      [1, 1, 0],
      [1, 1, 1]])
    >>> nested_loops([[1,2], ['a','b','c'], [sin, cos]])
    [[1, 'a', <ufunc 'sin'>],
     [1, 'a', <ufunc 'cos'>],
     [1, 'b', <ufunc 'sin'>],
     [1, 'b', <ufunc 'cos'>],
     [1, 'c', <ufunc 'sin'>],
     [1, 'c', <ufunc 'cos'>],
     [2, 'a', <ufunc 'sin'>],
     [2, 'a', <ufunc 'cos'>],
     [2, 'b', <ufunc 'sin'>],
     [2, 'b', <ufunc 'cos'>],
     [2, 'c', <ufunc 'sin'>],
     [2, 'c', <ufunc 'cos'>]]

    # If values of different lists should be varied together, use zip(). Note
    # that you get nested lists back. Use flatten=True to get flattened lists.
    >>> nested_loops([zip([1,2], ['a', 'b']), [88, 99]])
    [[(1, 'a'), 88], [(1, 'a'), 99], [(2, 'b'), 88], [(2, 'b'), 99]]

    >>> nested_loops([zip([1,2], ['a', 'b']), [88, 99]], flatten=True)
    [[1, 'a', 88], [1, 'a', 99], [2, 'b', 88], [2, 'b', 99]]
    """
    import itertools
    # Materialize one-shot iterables (no len(), no indexing) so that every
    # level can be indexed; real sequences are used as they are.
    lists = [lst if hasattr(lst, '__len__') else list(lst) for lst in lists]
    perm_idxs = [list(idxs) for idxs in
                 itertools.product(*[range(len(lst)) for lst in lists])]
    perms = [[lists[j][k] for j, k in enumerate(idxs)] for idxs in perm_idxs]
    if flatten:
        perms = [common.flatten(xx) for xx in perms]
    if ret_all:
        return perms, perm_idxs
    return perms
def get_list1d(self, *args, **kwargs):
    """Run a query and return all result rows as a single flat list.

    Convenience wrapper for the frequent single-column case, where
    ``self.cur.fetchall()`` would give a list of 1-tuples like
    ``[(1,), (2,)]``. All arguments are forwarded unchanged to
    ``self.execute()``. The rows are then fetched with ``fetchall()`` and
    flattened with ``common.flatten()``.
    """
    cursor = self.execute(*args, **kwargs)
    rows = cursor.fetchall()
    flat = common.flatten(rows)
    return flat
y = np.sin(x)
fig3, hostax = make_axes_grid_fig(3)
hostax.set_xlabel('hostax bottom')
hostax.set_ylabel('hostax left')
# location -> (off, wsadd), passed to new_axis() below:
# {'left': (off, wsadd),
#  ...}
off_dct = dict(left=(60, .1),
               right=(60, .1),
               top=(60, .15),
               bottom=(50, .15))
# items() works on Python 2 and 3 (iteritems() exists only on Python 2)
for n, val in enumerate(off_dct.items()):
    loc, off, wsadd = tuple(flatten(val))
    fig3, hostax, parax = new_axis(fig3, hostax=hostax, loc=loc, off=off,
                                   label=loc, wsadd=wsadd)
    parax.plot(x * n, y**n)
new_axis(fig3, hostax=hostax, loc='right', off=0, wsadd=0,
         label="hostax right, I'm like twinx()")
new_axis(fig3, hostax=hostax, loc='top', off=0, wsadd=0,
         label="hostax top, I'm like twiny()")

# many lines
fig4, hostax = make_axes_grid_fig(4)
def conv_table(xx, yy, ffmt="%15.4f", sfmt="%15s", mode='last', orig=False,
               absdiff=False):
    """Convergence table.

    Assume that quantity `xx` was varied, resulting in `yy` values. Return a
    string (table) listing::

        x, dy1, dy2, ...

    Useful for quickly viewing the results of a convergence study, where we
    assume that the sequence of `yy` values converges to a constant value.

    Parameters
    ----------
    xx : 1d sequence
        Values of the varied quantity, one table row each (formatted with
        `sfmt`).
    yy : 1d sequence, nested 1d sequences, 2d array
        Values varied with `xx`. Each row is one parameter.
    ffmt, sfmt : str
        Format strings for floats (`ffmt`) and strings (`sfmt`)
    mode : str
        'next' or 'last'. Difference to the next value ``y[i+1] - y[i]`` or
        to the last ``y[-1] - y[i]``. In 'next' mode the last row is 0.
    orig : bool
        Print original `yy` data as well (each difference column is followed
        by the corresponding original value).
    absdiff : bool
        absolute values of differences

    Returns
    -------
    str
        One newline-terminated line per entry in `xx`.

    Raises
    ------
    ValueError
        If `mode` is not 'next' or 'last'.

    Examples
    --------
    >>> kpoints = ['2 2 2', '4 4 4', '8 8 8']
    >>> etot = [-300.0, -310.0, -312.0]
    >>> forces_rms = [0.3, 0.2, 0.1]
    >>> print(batch.conv_table(kpoints, etot, mode='last'))
              2 2 2       -12.0000
              4 4 4        -2.0000
              8 8 8         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last'))
              2 2 2       -12.0000        -0.2000
              4 4 4        -2.0000        -0.1000
              8 8 8         0.0000         0.0000
    >>> print(batch.conv_table(kpoints, [etot,forces_rms], mode='last', orig=True))
              2 2 2       -12.0000      -300.0000        -0.2000         0.3000
              4 4 4        -2.0000      -310.0000        -0.1000         0.2000
              8 8 8         0.0000      -312.0000         0.0000         0.1000
    >>> print(batch.conv_table(kpoints, np.array([etot,forces_rms]), mode='next'))
              2 2 2       -10.0000        -0.1000
              4 4 4        -2.0000        -0.1000
              8 8 8         0.0000         0.0000
    """
    if mode not in ('next', 'last'):
        raise ValueError("unknown mode")
    npoints = len(xx)
    yy = np.asarray(yy).copy()
    # internal layout: rows = points in `xx`, columns = parameters
    if yy.ndim == 1:
        yy = yy[:, None]
    else:
        yy = yy.T
    ny = yy.shape[1]
    if mode == 'next':
        # the last point has no successor -> difference 0
        dyy = np.zeros_like(yy)
        dyy[:-1, :] = np.diff(yy, axis=0)
    else:
        # yy[-1:] (not yy[-1]) also works for empty input
        dyy = yy[-1:, :] - yy
    if absdiff:
        dyy = np.abs(dyy)
    ncols = 2 * ny if orig else ny
    fmtstr = sfmt + ffmt * ncols + "\n"
    lines = []
    for idx in range(npoints):
        row = [xx[idx]]
        for iy in range(ny):
            row.append(dyy[idx, iy])
            if orig:
                row.append(yy[idx, iy])
        lines.append(fmtstr % tuple(row))
    return ''.join(lines)
def write_input(self, mode='a', backup=True, sleep=0, excl=True):
    """
    Create calculation dir(s) for each parameter set and write input files
    based on ``templates``. Write sqlite database storing all relevant
    parameters. Write (bash) shell script to start all calculations (run
    locally or submit batch job file, depending on ``machine.subcmd``).

    Parameters
    ----------
    mode : str, optional
        Fine tune how to write input files (based on ``templates``) to calc
        dirs calc_foo/0/, calc_foo/1/, ... . Note that this doesn't change the
        base dir calc_foo at all, only the subdirs for each calc.
        {'a', 'w'}

        | 'a': Append mode (default). If a previous database is found, then
        |     subsequent calculations are numbered based on the last 'idx'.
        |     calc_foo/0 # old
        |     calc_foo/1 # old
        |     calc_foo/2 # new
        |     calc_foo/3 # new
        | 'w': Write mode. The target dirs are purged and overwritten. Also,
        |     the database (self.dbfn) is overwritten. Use this to
        |     iteratively tune your inputs, NOT for working on already
        |     present results!
        |     calc_foo/0 # new
        |     calc_foo/1 # new
    backup : bool, optional
        Before writing anything, do a backup of the database and of each
        per-machine calc dir if they already exist.
    sleep : int, optional
        For the script to start (submit) all jobs: time in seconds for the
        shell sleep(1) command.
    excl : bool
        If in append mode, a file <calc_root>/excl_push with all indices of
        calculations from old revisions is written. Can be used with
        ``rsync --exclude-from=excl_push`` when pushing appended new
        calculations to a cluster.

    Raises
    ------
    RuntimeError
        Append mode on an existing database whose table has no 'idx' column.
    """
    assert mode in ['a', 'w'], "Unknown mode: '%s'" % mode
    if os.path.exists(self.dbfn):
        if backup:
            common.backup(self.dbfn)
        if mode == 'w':
            os.remove(self.dbfn)
    have_new_db = not os.path.exists(self.dbfn)
    common.makedirs(self.calc_root)
    # this call creates a file ``self.dbfn`` if it doesn't exist
    sqldb = SQLiteDB(self.dbfn, table=self.db_table)
    # max_idx: counter for calc dir numbering, new calcs start at max_idx + 1
    revision = 0
    if have_new_db or mode == 'w':
        max_idx = -1
    else:
        if sqldb.has_column('idx'):
            max_idx = sqldb.execute("select max(idx) from %s"
                                    % self.db_table).fetchone()[0]
            # max() over an empty table is NULL -> start numbering at 0
            if max_idx is None:
                max_idx = -1
        else:
            raise RuntimeError("database '%s': table '%s' has no "
                               "column 'idx', don't know how to number calcs"
                               % (self.dbfn, self.db_table))
        if sqldb.has_column('revision'):
            last_revision = sqldb.get_single("select max(revision) from %s"
                                             % self.db_table)
            if last_revision is not None:
                revision = int(last_revision) + 1
    sql_records = []
    hostnames = []
    for imach, machine in enumerate(self.machines):
        hostnames.append(machine.hostname)
        calc_dir = pj(self.calc_root,
                      self.calc_dir_prefix + '_%s' % machine.hostname)
        if os.path.exists(calc_dir):
            if backup:
                common.backup(calc_dir)
            if mode == 'w':
                # no shell: paths with spaces or special characters are safe
                shutil.rmtree(calc_dir)
        run_txt = "here=$(pwd)\n"
        for _idx, params in enumerate(self.params_lst):
            params = common.flatten(params)
            idx = max_idx + _idx + 1
            calc_subdir = pj(calc_dir, str(idx))
            extra_dct = {
                'revision': revision,
                'study_name': self.study_name,
                'idx': idx,
                'calc_name': self.study_name + "_run%i" % idx,
            }
            extra_params = [SQLEntry(key=key, sqlval=val)
                            for key, val in extra_dct.items()]
            # templates[:] to copy b/c they may be modified in Calculation
            calc = Calculation(
                machine=machine,
                templates=self.templates[:],
                params=params + extra_params,
                calc_dir=calc_subdir,
            )
            if mode == 'w' and os.path.exists(calc_subdir):
                shutil.rmtree(calc_subdir)
            calc.write_input()
            run_txt += "cd %i && %s %s && cd $here && sleep %i\n" % (
                idx, machine.subcmd, machine.get_jobfile_basename(), sleep)
            # records are collected only once (first machine); the list of
            # all machines goes into the 'hostname' column below
            if imach == 0:
                sql_records.append(calc.get_sql_record())
        common.file_write(pj(calc_dir, 'run.sh'), run_txt)
    for record in sql_records:
        record['hostname'] = SQLEntry(sqlval=','.join(hostnames))
    # for incomplete parameters: collect header parts from all records and
    # make a set = unique entries
    raw_header = [(key, entry.sqltype.upper())
                  for record in sql_records
                  for key, entry in record.items()]
    header = list(set(raw_header))
    if have_new_db:
        sqldb.create_table(header)
    else:
        for record in sql_records:
            for key, entry in record.items():
                if not sqldb.has_column(key):
                    sqldb.add_column(key, entry.sqltype.upper())
    for record in sql_records:
        # take keys and values from one items() pass so they line up
        items = list(record.items())
        keys = [key for key, _ in items]
        cmd = "insert into %s (%s) values (%s)" % (
            self.db_table, ",".join(keys), ",".join(['?'] * len(keys)))
        sqldb.execute(cmd, tuple(entry.sqlval for _, entry in items))
    if excl and revision > 0 and sqldb.has_column('revision'):
        # query this study's table (self.db_table), not a fixed table name
        old_idx_lst = [str(x) for x, in sqldb.execute(
            "select idx from %s where revision < ?" % self.db_table,
            (revision,))]
        common.file_write(pj(self.calc_root, 'excl_push'),
                          '\n'.join(old_idx_lst))
    sqldb.finish()
def test_scell():
    cell = np.identity(3)
    coords_frac = np.array([[0.5, 0.5, 0.5], [1, 1, 1]])
    symbols = ['Al', 'N']
    sc = crys.scell(Structure(coords_frac=coords_frac, cell=cell,
                              symbols=symbols), (2, 2, 2))
    sc_coords_frac = \
        np.array([[0.25, 0.25, 0.25],
                  [0.25, 0.25, 0.75],
                  [0.25, 0.75, 0.25],
                  [0.25, 0.75, 0.75],
                  [0.75, 0.25, 0.25],
                  [0.75, 0.25, 0.75],
                  [0.75, 0.75, 0.25],
                  [0.75, 0.75, 0.75],
                  [0.5, 0.5, 0.5],
                  [0.5, 0.5, 1.],
                  [0.5, 1., 0.5],
                  [0.5, 1., 1.],
                  [1., 0.5, 0.5],
                  [1., 0.5, 1.],
                  [1., 1., 0.5],
                  [1., 1., 1.]])
    sc_symbols = ['Al'] * 8 + ['N'] * 8
    sc_cell = \
        np.array([[2., 0., 0.],
                  [0., 2., 0.],
                  [0., 0., 2.]])
    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)

    # non-orthorhombic cell
    cell = \
        np.array([[1., 0.5, 0.5],
                  [0.25, 1., 0.2],
                  [0.2, 0.5, 1.]])
    sc = crys.scell(Structure(coords_frac=coords_frac, cell=cell,
                              symbols=symbols), (2, 2, 2))
    sc_cell = \
        np.array([[2., 1., 1.],
                  [0.5, 2., 0.4],
                  [0.4, 1., 2.]])
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    # crystal coords are cell-independent
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)

    # slab
    #
    # Test if old and new implementation behave for a tricky case:
    # natoms == 2 == mask.shape[0], i.e. if reshape() behaves correctly.
    # Reference generated with old implementation. Default is new.
    cell = np.identity(3)
    coords_frac = np.array([[0.5, 0.5, 0.5], [1, 1, 1]])
    symbols = ['Al', 'N']
    sc = crys.scell(Structure(coords_frac=coords_frac, cell=cell,
                              symbols=symbols), (1, 1, 2))
    sc_coords_frac = \
        np.array([[0.5, 0.5, 0.25],
                  [0.5, 0.5, 0.75],
                  [1., 1., 0.5],
                  [1., 1., 1.]])
    sc_cell = \
        np.array([[1., 0., 0.],
                  [0., 1., 0.],
                  [0., 0., 2.]])
    sc_symbols = ['Al', 'Al', 'N', 'N']
    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)

    sc = crys.scell(Structure(coords_frac=coords_frac, cell=cell,
                              symbols=symbols), (1, 2, 1))
    sc_coords_frac = \
        np.array([[0.5, 0.25, 0.5],
                  [0.5, 0.75, 0.5],
                  [1., 0.5, 1.],
                  [1., 1., 1.]])
    sc_cell = \
        np.array([[1., 0., 0.],
                  [0., 2., 0.],
                  [0., 0., 1.]])
    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)

    sc = crys.scell(Structure(coords_frac=coords_frac, cell=cell,
                              symbols=symbols), (2, 1, 1))
    sc_coords_frac = \
        np.array([[0.25, 0.5, 0.5],
                  [0.75, 0.5, 0.5],
                  [0.5, 1., 1.],
                  [1., 1., 1.]])
    sc_cell = \
        np.array([[2., 0., 0.],
                  [0., 1., 0.],
                  [0., 0., 1.]])
    assert sc.symbols == sc_symbols
    np.testing.assert_array_almost_equal(sc.cell, sc_cell)
    np.testing.assert_array_almost_equal(sc.coords_frac, sc_coords_frac)

    # symbols = None
    sc = crys.scell(Structure(coords_frac=coords_frac, cell=cell,
                              symbols=None), (2, 2, 2))
    assert sc.symbols is None

    # Trajectory
    natoms = 4
    nstep = 100
    # next(syms) works on Python 2.6+ and 3 (syms.next() is Python 2 only)
    symbols = [next(syms) for ii in range(natoms)]
    # cell 2d
    coords_frac = rand(nstep, natoms, 3)
    cell = rand(3, 3)
    dims = (2, 3, 4)
    nmask = np.prod(dims)
    sc = crys.scell(Trajectory(coords_frac=coords_frac, cell=cell,
                               symbols=symbols), dims=dims)
    assert sc.coords_frac.shape == (nstep, nmask * natoms, 3)
    assert sc.symbols == common.flatten([[sym] * nmask for sym in symbols])
    assert sc.cell.shape == (nstep, 3, 3)
    np.testing.assert_array_almost_equal(
        sc.cell,
        num.extend_array(cell * np.asarray(dims)[:, None], sc.nstep, axis=0))
    # cell 3d
    cell = rand(nstep, 3, 3)
    sc = crys.scell(Trajectory(coords_frac=coords_frac, cell=cell,
                               symbols=symbols), dims=dims)
    assert sc.coords_frac.shape == (nstep, nmask * natoms, 3)
    coords_frac2 = np.array(
        [crys.scell(Structure(coords_frac=coords_frac[ii, ...],
                              cell=cell[ii, ...],
                              symbols=symbols), dims=dims).coords_frac
         for ii in range(nstep)])
    np.testing.assert_array_almost_equal(sc.coords_frac, coords_frac2)
    assert sc.symbols == common.flatten([[sym] * nmask for sym in symbols])
    assert sc.cell.shape == (nstep, 3, 3)
    np.testing.assert_array_almost_equal(
        sc.cell, cell * np.asarray(dims)[None, :, None])

    # methods
    natoms = 20
    coords_frac = rand(natoms, 3)
    cell = rand(3, 3)
    dims = (2, 3, 4)
    symbols = [next(syms) for ii in range(natoms)]
    struct = Structure(coords_frac=coords_frac, cell=cell, symbols=symbols)
    sc1 = crys.scell(struct, dims=dims, method=1)
    sc2 = crys.scell(struct, dims=dims, method=2)
    d1 = {key: getattr(sc1, key) for key in sc1.attr_lst}
    d2 = {key: getattr(sc2, key) for key in sc2.attr_lst}
    tools.assert_dict_with_all_types_almost_equal(d1, d2)
def test_scell():
    def struct_scell(frac, lattice, syms_, scale):
        # fresh Structure per call, supercell dims passed positionally
        return crys.scell(Structure(coords_frac=frac, cell=lattice,
                                    symbols=syms_), scale)

    def cube_grid(vals):
        # all [a, b, c] with a, b, c taken from `vals`, last one fastest
        return [[a, b, c] for a in vals for b in vals for c in vals]

    two_atoms = np.array([[0.5, 0.5, 0.5], [1, 1, 1]])
    al_n = ['Al', 'N']
    eye = np.identity(3)

    # 2x2x2 supercell of a cubic cell: every atom appears 8 times
    sc = struct_scell(two_atoms, eye, al_n, (2, 2, 2))
    expected_frac = np.array(cube_grid((0.25, 0.75)) + cube_grid((0.5, 1.)))
    assert sc.symbols == ['Al'] * 8 + ['N'] * 8
    np.testing.assert_array_almost_equal(sc.coords_frac, expected_frac)
    np.testing.assert_array_almost_equal(sc.cell, 2.0 * eye)

    # non-orthorhombic cell: the cell is scaled, fractional coords are
    # cell-independent
    skewed = np.array([[1., 0.5, 0.5],
                       [0.25, 1., 0.2],
                       [0.2, 0.5, 1.]])
    sc = struct_scell(two_atoms, skewed, al_n, (2, 2, 2))
    np.testing.assert_array_almost_equal(sc.cell, 2.0 * skewed)
    np.testing.assert_array_almost_equal(sc.coords_frac, expected_frac)

    # Slabs. Tricky case for the old vs. new implementation:
    # natoms == 2 == mask.shape[0], i.e. check that reshape() behaves
    # correctly. References were generated with the old implementation;
    # the default is the new one.
    two_atoms = np.array([[0.5, 0.5, 0.5], [1, 1, 1]])
    eye = np.identity(3)
    slab_refs = [
        ((1, 1, 2), [[0.5, 0.5, 0.25], [0.5, 0.5, 0.75],
                     [1., 1., 0.5], [1., 1., 1.]]),
        ((1, 2, 1), [[0.5, 0.25, 0.5], [0.5, 0.75, 0.5],
                     [1., 0.5, 1.], [1., 1., 1.]]),
        ((2, 1, 1), [[0.25, 0.5, 0.5], [0.75, 0.5, 0.5],
                     [0.5, 1., 1.], [1., 1., 1.]]),
    ]
    for slab_dims, slab_frac in slab_refs:
        sc = struct_scell(two_atoms, eye, al_n, slab_dims)
        assert sc.symbols == ['Al', 'Al', 'N', 'N']
        np.testing.assert_array_almost_equal(
            sc.cell, np.diag(np.array(slab_dims, dtype=float)))
        np.testing.assert_array_almost_equal(sc.coords_frac,
                                             np.array(slab_frac))

    # symbols=None is passed through
    sc = struct_scell(two_atoms, eye, None, (2, 2, 2))
    assert sc.symbols is None

    # Trajectory
    natoms = 4
    nstep = 100
    traj_symbols = [next(syms) for _ in range(natoms)]
    dims = (2, 3, 4)
    nmask = np.prod(dims)
    ncells = np.asarray(dims)

    # one 2d cell shared by all steps
    traj_frac = rand(nstep, natoms, 3)
    cell2d = rand(3, 3)
    sc = crys.scell(Trajectory(coords_frac=traj_frac, cell=cell2d,
                               symbols=traj_symbols), dims=dims)
    assert sc.coords_frac.shape == (nstep, nmask * natoms, 3)
    repeated_symbols = common.flatten([[sym] * nmask for sym in traj_symbols])
    assert sc.symbols == repeated_symbols
    assert sc.cell.shape == (nstep, 3, 3)
    np.testing.assert_array_almost_equal(
        sc.cell,
        num.extend_array(cell2d * ncells[:, None], sc.nstep, axis=0))

    # 3d cell: one cell per step, must match step-wise Structure supercells
    cell3d = rand(nstep, 3, 3)
    sc = crys.scell(Trajectory(coords_frac=traj_frac, cell=cell3d,
                               symbols=traj_symbols), dims=dims)
    assert sc.coords_frac.shape == (nstep, nmask * natoms, 3)
    stepwise = []
    for step in range(nstep):
        one = crys.scell(Structure(coords_frac=traj_frac[step, ...],
                                   cell=cell3d[step, ...],
                                   symbols=traj_symbols), dims=dims)
        stepwise.append(one.coords_frac)
    np.testing.assert_array_almost_equal(sc.coords_frac, np.array(stepwise))
    assert sc.symbols == repeated_symbols
    assert sc.cell.shape == (nstep, 3, 3)
    np.testing.assert_array_almost_equal(sc.cell,
                                         cell3d * ncells[None, :, None])

    # both scell methods must agree on all attributes
    natoms = 20
    frac = rand(natoms, 3)
    lattice = rand(3, 3)
    dims = (2, 3, 4)
    struct_symbols = [next(syms) for _ in range(natoms)]
    struct = Structure(coords_frac=frac, cell=lattice, symbols=struct_symbols)
    sc1 = crys.scell(struct, dims=dims, method=1)
    sc2 = crys.scell(struct, dims=dims, method=2)
    attrs1 = {name: getattr(sc1, name) for name in sc1.attr_lst}
    attrs2 = {name: getattr(sc2, name) for name in sc2.attr_lst}
    tools.assert_dict_with_all_types_almost_equal(attrs1, attrs2)