-
Notifications
You must be signed in to change notification settings - Fork 0
/
trim.py
executable file
·132 lines (106 loc) · 4.38 KB
/
trim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
'''
Designed to work on the cluster, removing 'ejected species' after each bomb or quench step
Make sure the source files are named in the form `output%d.gen`
'''
# imports
import os
import re
import sys
# scipy
import numpy as np
import pandas as pd
#ase
from ase.io import gen, vasp, xyz, extxyz, dftb
from ase.geometry.analysis import Analysis
#researchscripts
from researchscripts.structure import Graph
# Species always reported in removedspecies.csv, in this row order, with a
# zero count when none were removed; any other element found among the removed
# atoms is added on demand instead of raising KeyError.
_REPORTED_SPECIES = ("Si", "N", "H", "Ar", "F", "C")


def _read_keyed_files(datadir, name_filter, reader):
    """Read every file in *datadir* whose name contains *name_filter*.

    Returns a dict mapping the first run of digits in each file name (as a
    string) to ``reader(path)``. Files without digits are skipped. Only the
    first digit run is used, so e.g. ``output1-2.gen`` and ``output1-3.gen``
    share key ``"1"`` and the one listed last by ``os.listdir`` wins.
    """
    found = {}
    for fname in os.listdir(datadir):
        if name_filter not in fname:
            continue
        match = re.search(r"\d+", fname)
        if match:
            found[match.group(0)] = reader(os.path.join(datadir, fname))
    return found


def _read_velocities(path):
    """Load a header-less, whitespace-separated table of floats."""
    return pd.read_csv(path, header=None, dtype=float, sep=r"\s+")


def _connected_components(geom):
    """Return the bonded connected components of *geom*.

    Bonds come from the first adjacency matrix of ASE's ``Analysis``. Only
    its non-zero entries are visited (sorted, i.e. in row-major order) rather
    than probing every atom pair. The component format is whatever
    ``Graph.connectedComponents`` returns (presumably lists of atom indices).
    """
    adjmat = Analysis(geom).adjacency_matrix[0]
    graph = Graph(adjmat.shape[0])
    rows, cols = adjmat.nonzero()
    for i, j in sorted(zip(rows.tolist(), cols.tolist())):
        graph.addEdge(i, j)
    return graph.connectedComponents()


def _find_slab_and_ejected(geom, hbondrange, zmincutoff):
    """Split *geom* into the slab and the set of atom indices to delete.

    The slab is the bonded component containing atom 0 (the script assumes
    atom 0 belongs to the surface). Every other component is marked for
    removal when all of its atoms lie above the slab's highest atom plus
    *hbondrange*, or all lie below *zmincutoff*. Atoms whose z exceeds the
    z-component of the third cell vector (wrapped through the top of the
    cell) are marked as well.

    Returns ``(slab, to_remove)``: an ``Atoms`` object and a set of indices.
    """
    components = _connected_components(geom)
    slab_indices = sorted([comp for comp in components if 0 in comp][0])
    slab = geom[slab_indices]
    zcutoff = slab.positions[:, 2].max() + hbondrange

    z = geom.positions[:, 2]
    to_remove = set()
    for comp in components:
        if 0 in comp:
            continue
        comp_z = z[list(comp)]
        if np.all(comp_z > zcutoff) or np.all(comp_z < zmincutoff):
            to_remove.update(int(i) for i in comp)

    # Atoms that wrapped around through the top of the cell (e.g. H). Skipped
    # when the cell has no z extent (cluster-format input), where this test
    # would otherwise flag every atom above z = 0.
    cell_top = geom.cell[2, 2]
    if cell_top > 0:
        to_remove.update(np.flatnonzero(z > cell_top).tolist())
    return slab, to_remove


def _count_by_species(geom, indices):
    """Return a Series counting the atoms of *geom* at *indices* per element."""
    counts = dict.fromkeys(_REPORTED_SPECIES, 0)
    symbols = geom.get_chemical_symbols()
    for idx in indices:
        counts[symbols[idx]] = counts.get(symbols[idx], 0) + 1
    return pd.Series(counts)


def main(
    datadir="temp/",  # directory holding the geometry and velocity files to trim
    outputdir="temp.new/",  # directory for trimmed files (created if missing)
    hbondrange=6,  # offset above the slab top corresponding to H-bond range
    zmincutoff=0.1,  # somewhat arbitrary z below which fragments count as lost into bulk
    output_geom_name="output",  # substring identifying geometry (.gen) files in datadir
    output_velos_name="velos",  # substring identifying velocity files; also the output prefix
):
    """Remove ejected species from every geometry/velocity pair in *datadir*.

    Files are paired by the first run of digits in their names (the "key").
    For each key the following are written:

    * ``outputdir/slab<key>.gen``: the bonded component containing atom 0;
    * ``outputdir/input<key>.gen``: the geometry with ejected atoms deleted;
    * ``outputdir/<output_velos_name><key>.in``: the velocity rows of the
      remaining atoms (space separated, no header or index).

    A per-key tally of removed atoms by element is written to
    ``removedspecies.csv`` in the current working directory (not outputdir).

    All parameters may be passed as strings (as from the command line);
    *hbondrange* and *zmincutoff* are converted to float.

    Raises:
        KeyError: a geometry file has no velocity file with the same key.
    """
    hbondrange = float(hbondrange)
    zmincutoff = float(zmincutoff)
    if outputdir:
        os.makedirs(outputdir, exist_ok=True)

    geometries = _read_keyed_files(datadir, output_geom_name, gen.read_gen)
    velos = _read_keyed_files(datadir, output_velos_name, _read_velocities)

    trimmedgeoms = {}
    trimmedvelos = {}
    removedspecies = {}
    for key, geom in geometries.items():
        if key not in velos:
            raise KeyError(
                "no '{}' velocity file in {} for geometry key {}".format(
                    output_velos_name, datadir, key))

        slab, to_remove = _find_slab_and_ejected(geom, hbondrange, zmincutoff)
        gen.write_gen(os.path.join(outputdir, "slab{}.gen".format(key)), slab)

        removedspecies[key] = _count_by_species(geom, to_remove)

        trimmed = geom.copy()
        del trimmed[sorted(to_remove)]
        trimmedgeoms[key] = trimmed

        # Velocity rows are positional: row i belongs to atom i.
        velo = velos[key]
        keep = np.array([row not in to_remove for row in range(len(velo))],
                        dtype=bool)
        trimmedvelos[key] = velo.iloc[keep]

    # Collect the per-key tallies into one table; species missing from a key's
    # tally (only possible for elements outside _REPORTED_SPECIES) count as 0.
    pd.DataFrame(removedspecies).fillna(0).astype(int).to_csv("removedspecies.csv")

    for key, geom in trimmedgeoms.items():
        gen.write_gen(os.path.join(outputdir, "input{}.gen".format(key)), geom)
    for key, velo in trimmedvelos.items():
        velo.to_csv(
            os.path.join(outputdir, "{}{}.in".format(output_velos_name, key)),
            sep=" ", index=False, header=False)
if __name__ == "__main__":
    # Command-line entry point.
    #
    #   trim.py [datadir [outputdir [hbondrange [zmincutoff
    #           [output_geom_name [output_velos_name]]]]]]
    #
    # Positional arguments map, in order, onto main()'s parameters; omitted
    # ones keep main()'s defaults. Values arrive as strings and the numeric
    # ones are converted inside main().
    USAGE = ("usage: {} [datadir [outputdir [hbondrange [zmincutoff "
             "[output_geom_name [output_velos_name]]]]]]").format(sys.argv[0])
    cli_args = sys.argv[1:]
    if any(arg in ("-h", "--help") for arg in cli_args):
        print(USAGE)
        sys.exit(0)
    if len(cli_args) > 6:
        sys.exit("{}\nerror: no more than 6 arguments allowed, got {}: {}".format(
            USAGE, len(cli_args), cli_args))
    main(*cli_args)