Exemple #1
class lammps(object):
  """ctypes wrapper around the LAMMPS C library interface (liblammps).

  Methods forward to the matching ``lammps_*`` function of the shared
  library; text arguments are converted to bytes before each call.
  """

  # detect if Python is using version of mpi4py that can pass a communicator

  has_mpi4py_v2 = False
  try:
    from mpi4py import MPI
    from mpi4py import __version__ as mpi4py_version
    if mpi4py_version.split('.')[0] == '2':
      has_mpi4py_v2 = True
  except Exception:
    # best-effort probe: without a usable mpi4py no communicator is passed
    pass

  @staticmethod
  def _encode(text):
    """Return *text* as bytes for a C ``char *`` argument.

    None and bytes are returned unchanged; anything else is converted with
    its ``encode()`` method.
    """
    if text is None or isinstance(text, bytes):
      return text
    return text.encode()

  @staticmethod
  def _encode_args(args):
    """Return ``(narg, cargs)`` for handing *args* to lammps_open*().

    cargs is a ctypes ``char *`` array, or None when *args* is empty or
    None.  The caller's list is not modified.
    """
    if not args:
      return 0, None
    encoded = [lammps._encode(a) for a in args]
    return len(encoded), (c_char_p*len(encoded))(*encoded)

  # create instance of LAMMPS

  def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
    """Create a new LAMMPS instance, or wrap an existing one.

    name: shared-library suffix; "" loads liblammps.so, "g++" loads
      liblammps_g++.so.
    cmdargs: optional list of LAMMPS command-line arguments (str or bytes).
    ptr: handle to an existing LAMMPS instance (PyCapsule on Python 3,
      PyCObject on Python 2) when Python is embedded in a LAMMPS input
      script; no new instance is created in that case.
    comm: optional mpi4py communicator; only used with mpi4py v2.
    """
    self.comm = comm
    self.opened = 0
    self.lmp = None

    # determine module location

    modpath = dirname(abspath(getsourcefile(lambda:0)))
    self.lib = None

    # if a pointer to a LAMMPS object is handed in,
    # all symbols should already be available

    if ptr:
      try:
        self.lib = CDLL("",RTLD_GLOBAL)
      except OSError:
        self.lib = None

    # load liblammps.so unless name is given
    #   if name = "g++", load liblammps_g++.so
    # try loading the LAMMPS shared object from the location
    #   of lammps.py with an absolute path,
    #   so that LD_LIBRARY_PATH does not need to be set for regular install
    # fall back to loading with a relative path,
    #   typically requires LD_LIBRARY_PATH to be set appropriately

    if not self.lib:
      if name: libname = "liblammps_%s.so" % name
      else: libname = "liblammps.so"
      try:
        self.lib = CDLL(join(modpath,libname),RTLD_GLOBAL)
      except OSError:
        self.lib = CDLL(libname,RTLD_GLOBAL)

    # if no ptr provided, create an instance of LAMMPS
    #   don't know how to pass an MPI communicator from PyPar
    #   but we can pass an MPI communicator from mpi4py v2.0.0 and later
    #   no_mpi call lets LAMMPS use MPI_COMM_WORLD
    #   cargs = array of C strings from args
    # if ptr, then are embedding Python in LAMMPS input script
    #   ptr is the desired instance of LAMMPS
    #   just convert it to ctypes ptr and store in self.lmp

    if not ptr:
      narg, cargs = self._encode_args(cmdargs)
      self.lmp = c_void_p()

      if lammps.has_mpi4py_v2 and comm is not None:
        # with mpi4py v2, can pass MPI communicator to LAMMPS
        # need to adjust for type of MPI communicator object
        # allow for int (like MPICH) or void* (like OpenMPI)
        if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int):
          MPI_Comm = c_int
        else:
          MPI_Comm = c_void_p

        self.lib.lammps_open.argtypes = [c_int, POINTER(c_char_p),
                                         MPI_Comm, POINTER(c_void_p)]
        self.lib.lammps_open.restype = None
        comm_ptr = lammps.MPI._addressof(comm)
        comm_val = MPI_Comm.from_address(comm_ptr)
        self.lib.lammps_open(narg,cargs,comm_val,byref(self.lmp))
      else:
        self.lib.lammps_open_no_mpi(narg,cargs,byref(self.lmp))
      # only mark as opened once the library call has returned
      self.opened = 1

    else:
      # magic to convert ptr to ctypes ptr
      if sys.version_info >= (3, 0):
        # Python 3 (uses PyCapsule API)
        pythonapi.PyCapsule_GetPointer.restype = c_void_p
        pythonapi.PyCapsule_GetPointer.argtypes = [py_object, c_char_p]
        self.lmp = c_void_p(pythonapi.PyCapsule_GetPointer(ptr, None))
      else:
        # Python 2 (uses PyCObject API)
        pythonapi.PyCObject_AsVoidPtr.restype = c_void_p
        pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
        self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))

    # optional numpy support (lazy loading)
    self._numpy = None

    # set default types
    self.c_bigint = get_ctypes_int(self.extract_setting("bigint"))
    self.c_tagint = get_ctypes_int(self.extract_setting("tagint"))
    self.c_imageint = get_ctypes_int(self.extract_setting("imageint"))

  def __del__(self):
    """Close the LAMMPS instance on garbage collection if still open."""
    # getattr: __init__ may have raised before lmp/opened were assigned
    if getattr(self, "lmp", None) and getattr(self, "opened", 0):
      self.lib.lammps_close(self.lmp)
      self.opened = 0

  def close(self):
    """Shut down the LAMMPS instance if this object opened it."""
    if self.opened: self.lib.lammps_close(self.lmp)
    self.lmp = None
    self.opened = 0

  def version(self):
    """Return the value reported by lammps_version()."""
    return self.lib.lammps_version(self.lmp)

  def file(self,file):
    """Pass the input-script path *file* to lammps_file()."""
    self.lib.lammps_file(self.lmp,self._encode(file))

  # send a single command
  def command(self,cmd):
    """Execute a single LAMMPS command.

    Raises MPIAbortException (error type 2) or Exception carrying the
    LAMMPS error message when the library supports exceptions and reports
    an error after the command.
    """
    self.lib.lammps_command(self.lmp,self._encode(cmd))

    if self.uses_exceptions() and self.lib.lammps_has_error(self.lmp):
      sb = create_string_buffer(100)
      error_type = self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
      error_msg = sb.value.decode().strip()

      if error_type == 2:
        raise MPIAbortException(error_msg)
      raise Exception(error_msg)

  # send a list of commands

  def commands_list(self,cmdlist):
    """Execute every command in *cmdlist* (str or bytes entries)."""
    # keep bytes entries: filtering them out would leave NULL slots in the
    # array while still reporting len(cmdlist) commands to the library
    cmds = [self._encode(x) for x in cmdlist]
    args = (c_char_p * len(cmds))(*cmds)
    self.lib.lammps_commands_list(self.lmp,len(cmds),args)

  # send a string of commands

  def commands_string(self,multicmd):
    """Execute a block of commands given as one (possibly multi-line) string."""
    self.lib.lammps_commands_string(self.lmp,c_char_p(self._encode(multicmd)))

  # extract global info
  def extract_global(self,name,type):
    """Return global quantity *name*: type 0 as int, type 1 as double.

    Any other type returns None without querying the library.
    """
    name = self._encode(name)
    if type == 0:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_double)
    else: return None
    ptr = self.lib.lammps_extract_global(self.lmp,name)
    return ptr[0]

  # extract per-atom info
  def extract_atom(self,name,type):
    """Return a ctypes pointer to per-atom property *name*.

    type: 0 int*, 1 int**, 2 double*, 3 double**; other types return None.
    """
    name = self._encode(name)
    if type == 0:
      self.lib.lammps_extract_atom.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
    elif type == 2:
      self.lib.lammps_extract_atom.restype = POINTER(c_double)
    elif type == 3:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
    else: return None
    ptr = self.lib.lammps_extract_atom(self.lmp,name)
    return ptr

  # extract lammps type byte sizes

  def extract_setting(self, name):
    """Return integer setting *name* (e.g. the "bigint" byte size) as an int."""
    # set the restype on the function actually called below
    self.lib.lammps_extract_setting.restype = c_int
    return int(self.lib.lammps_extract_setting(self.lmp,self._encode(name)))

  def numpy(self):
    """Return a lazily created helper exposing per-atom data as NumPy arrays.

    The helper builds its arrays with ``np.frombuffer`` directly on the
    memory returned by extract_atom(), so they are views, not copies.
    """
    if not self._numpy:
      import numpy as np
      class LammpsNumpyWrapper:
        def __init__(self, lmp):
          self.lmp = lmp

        def _ctype_to_numpy_int(self, ctype_int):
          # map a ctypes integer type to the NumPy dtype of the same width
          if ctype_int == c_int32:
            return np.int32
          elif ctype_int == c_int64:
            return np.int64
          return np.intc

        def extract_atom_iarray(self, name, nelem, dim=1):
          """Return integer per-atom property *name* as an (nelem, dim) array."""
          # atom/molecule IDs and image flags use the tagint/imageint widths
          # queried in __init__; other integer properties are plain C ints
          if name in ['id', 'molecule']:
            c_int_type = self.lmp.c_tagint
          elif name in ['image']:
            c_int_type = self.lmp.c_imageint
          else:
            c_int_type = c_int

          np_int_type = self._ctype_to_numpy_int(c_int_type)

          if dim == 1:
            tmp = self.lmp.extract_atom(name, 0)
            ptr = cast(tmp, POINTER(c_int_type * nelem))
          else:
            # 2d data: rows are assumed contiguous starting at tmp[0]
            tmp = self.lmp.extract_atom(name, 1)
            ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim))

          a = np.frombuffer(ptr.contents, dtype=np_int_type)
          a.shape = (nelem, dim)
          return a

        def extract_atom_darray(self, name, nelem, dim=1):
          """Return double per-atom property *name* as an (nelem, dim) array."""
          if dim == 1:
            tmp = self.lmp.extract_atom(name, 2)
            ptr = cast(tmp, POINTER(c_double * nelem))
          else:
            # 2d data: rows are assumed contiguous starting at tmp[0]
            tmp = self.lmp.extract_atom(name, 3)
            ptr = cast(tmp[0], POINTER(c_double * nelem * dim))

          a = np.frombuffer(ptr.contents)
          a.shape = (nelem, dim)
          return a

      self._numpy = LammpsNumpyWrapper(self)
    return self._numpy

  # extract compute info
  def extract_compute(self,id,style,type):
    """Return data of compute *id*.

    type 0 returns a single double (style 0 only; other styles give None),
    type 1 a ctypes double pointer, type 2 a ctypes double** pointer;
    any other type returns None.
    """
    id = self._encode(id)
    if type == 0:
      if style > 0: return None
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr[0]
    if type == 1:
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    if type == 2:
      self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    return None

  # extract fix info
  # in case of global datum, free memory for 1 double via lammps_free()
  # double was allocated by library interface function

  def extract_fix(self,id,style,type,i=0,j=0):
    """Return data of fix *id* (i and j are forwarded to the library).

    style 0 returns one double and releases its library-allocated storage
    with lammps_free(); style 1 or 2 with type 1 or 2 returns a ctypes
    double / double** pointer; any other combination returns None.
    """
    id = self._encode(id)
    if style == 0:
      self.lib.lammps_extract_fix.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    elif (style == 1) or (style == 2):
      if type == 1:
        self.lib.lammps_extract_fix.restype = POINTER(c_double)
      elif type == 2:
        self.lib.lammps_extract_fix.restype = POINTER(POINTER(c_double))
      else:
        return None
      ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
      return ptr
    else:
      return None

  # extract variable info
  # free memory for 1 double or 1 vector of doubles via lammps_free()
  # for vector, must copy nlocal returned values to local c_double vector
  # memory was allocated by library interface function

  def extract_variable(self,name,group,type):
    """Return the value(s) of variable *name* (group is forwarded).

    type 0 returns one double; type 1 returns a ctypes double array with
    nlocal entries.  Library-allocated storage is released with
    lammps_free() after copying.  Other types return None.
    """
    name = self._encode(name)
    group = self._encode(group)
    if type == 0:
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    if type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
      nlocalptr = self.lib.lammps_extract_global(self.lmp,b"nlocal")
      nlocal = nlocalptr[0]
      result = (c_double*nlocal)()
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      for i in range(nlocal): result[i] = ptr[i]
      self.lib.lammps_free(ptr)
      return result
    return None

  # set variable value
  # value is converted to string
  # returns 0 for success, -1 if failed

  def set_variable(self,name,value):
    """Set variable *name* to str(value); returns 0 on success, -1 on failure."""
    name = self._encode(name)
    # convert every non-None value: falsy values such as 0 must still reach
    # the library as a C string, not as a bare integer
    if value is not None: value = str(value).encode()
    return self.lib.lammps_set_variable(self.lmp,name,value)

  # return current value of thermo keyword

  def get_thermo(self,name):
    """Return the current value of thermo keyword *name* as a float."""
    self.lib.lammps_get_thermo.restype = c_double
    return self.lib.lammps_get_thermo(self.lmp,self._encode(name))

  # return total number of atoms in system

  def get_natoms(self):
    """Return the total number of atoms reported by lammps_get_natoms()."""
    return self.lib.lammps_get_natoms(self.lmp)

  # return vector of atom properties gathered across procs, ordered by atom ID
  # name = atom property recognized by LAMMPS in atom->extract()
  # type = 0 for integer values, 1 for double values
  # count = number of per-atom valus, 1 for type or charge, 3 for x or f
  # NOTE: how could we insure are converting to correct Python type
  #   e.g. for Python list or NumPy, etc
  #   ditto for extact_atom() above
  def gather_atoms(self,name,type,count):
    """Return per-atom property *name* gathered from all procs.

    type 0 yields a ctypes int array, type 1 a ctypes double array, each of
    length count*natoms and filled by lammps_gather_atoms(); any other type
    returns None.
    """
    name = self._encode(name)
    natoms = self.lib.lammps_get_natoms(self.lmp)
    if type == 0:
      data = ((count*natoms)*c_int)()
    elif type == 1:
      data = ((count*natoms)*c_double)()
    else: return None
    self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
    return data

  # scatter vector of atom properties across procs, ordered by atom ID
  # name = atom property recognized by LAMMPS in atom->extract()
  # type = 0 for integer values, 1 for double values
  # count = number of per-atom valus, 1 for type or charge, 3 for x or f
  # assume data is of correct type and length, as created by gather_atoms()
  # NOTE: how could we insure are passing correct type to LAMMPS
  # e.g. for Python list or NumPy, etc
  def scatter_atoms(self,name,type,count,data):
    """Scatter per-atom property *name* from *data* across procs."""
    self.lib.lammps_scatter_atoms(self.lmp,self._encode(name),type,count,data)

  # create N atoms on all procs
  # N = global number of atoms
  # id = ID of each atom (optional, can be None)
  # type = type of each atom (1 to Ntypes) (required)
  # x = coords of each atom as (N,3) array (required)
  # v = velocity of each atom as (N,3) array (optional, can be None)
  # NOTE: how could we insure are passing correct type to LAMMPS
  #   e.g. for Python list or NumPy, etc
  #   ditto for gather_atoms() above

  def create_atoms(self,n,id,type,x,v,image=None,shrinkexceed=False):
    """Create *n* atoms via lammps_create_atoms().

    id and image are optional per-atom int sequences (falsy to omit);
    type is a required per-atom int sequence; x, v and shrinkexceed are
    passed through unchanged.
    """
    if id:
      id_lmp = (c_int * n)()
      id_lmp[:] = id
    else:
      id_lmp = id

    if image:
      image_lmp = (c_int * n)()
      image_lmp[:] = image
    else:
      image_lmp = image

    type_lmp = (c_int * n)()
    type_lmp[:] = type
    self.lib.lammps_create_atoms(self.lmp,n,id_lmp,type_lmp,x,v,image_lmp,
                                 shrinkexceed)

  def uses_exceptions(self):
    """ Return whether the LAMMPS shared library was compiled with C++ exceptions handling enabled """
    # ctypes raises AttributeError for missing symbols; the error query
    # function is only looked up, never called, here
    return hasattr(self.lib, "lammps_has_error")
Exemple #2
import sys, os, unittest
from lammps import lammps

has_mpi = False
has_mpi4py = False
has_exceptions = False

# probe for an mpi4py release usable with the LAMMPS Python module
try:
    from mpi4py import __version__ as mpi4py_version
    # tested to work with mpi4py versions 2 and 3
    has_mpi4py = mpi4py_version.split('.')[0] in ['2', '3']
except ImportError:
    pass

# optional library suffix ("" selects the default shared library)
machine = os.environ.get('LAMMPS_MACHINE_NAME', "")

# best-effort probe of the LAMMPS library's optional features; any failure
# (e.g. the shared library cannot be loaded) leaves the flags False
lmp = None
try:
    lmp = lammps(name=machine)
    has_mpi = lmp.has_mpi_support
    has_exceptions = lmp.has_exceptions
except Exception:
    pass
finally:
    # release the probe instance instead of keeping it alive for all tests
    if lmp is not None:
        lmp.close()

class PythonOpen(unittest.TestCase):
    """unittest fixture; setUp records which LAMMPS library variant to use."""

    def setUp(self):
        """Store LAMMPS_MACHINE_NAME from the environment, or None if unset."""
        self.machine = os.environ.get('LAMMPS_MACHINE_NAME')
Exemple #3
class lammps:
  """ctypes wrapper around the LAMMPS C library interface (liblammps).

  Methods forward to the matching ``lammps_*`` function of the shared
  library; text arguments are converted to bytes before each call.
  """
  # detect if Python is using version of mpi4py that can pass a communicator
  has_mpi4py_v2 = False
  try:
    from mpi4py import MPI
    from mpi4py import __version__ as mpi4py_version
    if mpi4py_version.split('.')[0] == '2':
      has_mpi4py_v2 = True
  except Exception:
    # best-effort probe: without a usable mpi4py no communicator is passed
    pass

  @staticmethod
  def _encode(text):
    """Return *text* as bytes for a C ``char *`` argument.

    None and bytes are returned unchanged; anything else is converted with
    its ``encode()`` method.
    """
    if text is None or isinstance(text, bytes):
      return text
    return text.encode()

  @staticmethod
  def _encode_args(args):
    """Return ``(narg, cargs)`` for handing *args* to lammps_open*().

    cargs is a ctypes ``char *`` array, or None when *args* is empty or
    None.  The caller's list is not modified.
    """
    if not args:
      return 0, None
    encoded = [lammps._encode(a) for a in args]
    return len(encoded), (c_char_p*len(encoded))(*encoded)

  # create instance of LAMMPS
  def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
    """Create a new LAMMPS instance, or wrap an existing one.

    name: shared-library suffix; "" loads liblammps.so, "g++" loads
      liblammps_g++.so, both from the directory of this module.
    cmdargs: optional list of LAMMPS command-line arguments.
    ptr: handle to an existing LAMMPS instance (PyCObject on Python 2,
      PyCapsule on Python 3) when Python is embedded in a LAMMPS input
      script; no new instance is created in that case.
    comm: optional mpi4py communicator; only used with mpi4py v2.

    Raises OSError if the shared library cannot be loaded.
    """
    self.opened = 0
    self.lmp = None

    # determine module location
    modpath = dirname(abspath(getsourcefile(lambda:0)))

    # load liblammps.so by default
    # if name = "g++", load liblammps_g++.so
    if name: libname = "liblammps_%s.so" % name
    else: libname = "liblammps.so"
    try:
      self.lib = CDLL(join(modpath,libname),RTLD_GLOBAL)
    except OSError:
      raise OSError("Could not load LAMMPS dynamic library from %s" % modpath)

    # if no ptr provided, create an instance of LAMMPS
    #   don't know how to pass an MPI communicator from PyPar
    #   but we can pass an MPI communicator from mpi4py v2.0.0 and later
    #   no_mpi call lets LAMMPS use MPI_COMM_WORLD
    #   cargs = array of C strings from args
    # if ptr, then are embedding Python in LAMMPS input script
    #   ptr is the desired instance of LAMMPS
    #   just convert it to ctypes ptr and store in self.lmp
    if not ptr:
      narg, cargs = self._encode_args(cmdargs)
      self.lmp = c_void_p()

      if lammps.has_mpi4py_v2 and comm is not None:
        # with mpi4py v2, can pass MPI communicator to LAMMPS
        # need to adjust for type of MPI communicator object
        # allow for int (like MPICH) or void* (like OpenMPI)
        if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int):
          MPI_Comm = c_int
        else:
          MPI_Comm = c_void_p

        self.lib.lammps_open.argtypes = [c_int, POINTER(c_char_p),
                                         MPI_Comm, POINTER(c_void_p)]
        self.lib.lammps_open.restype = None
        comm_ptr = lammps.MPI._addressof(comm)
        comm_val = MPI_Comm.from_address(comm_ptr)
        self.lib.lammps_open(narg,cargs,comm_val,byref(self.lmp))
      else:
        self.lib.lammps_open_no_mpi(narg,cargs,byref(self.lmp))
      self.opened = 1
    else:
      # opened stays 0 so close()/__del__ never shut down an instance
      # that this object did not create
      self.opened = 0
      # magic to convert ptr to ctypes ptr
      if sys.version_info >= (3, 0):
        # Python 3 (uses PyCapsule API)
        pythonapi.PyCapsule_GetPointer.restype = c_void_p
        pythonapi.PyCapsule_GetPointer.argtypes = [py_object, c_char_p]
        self.lmp = c_void_p(pythonapi.PyCapsule_GetPointer(ptr, None))
      else:
        # Python 2 (uses PyCObject API)
        pythonapi.PyCObject_AsVoidPtr.restype = c_void_p
        pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
        self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))

  def __del__(self):
    """Close the LAMMPS instance on garbage collection if still open."""
    # getattr: __init__ may have raised before lmp/opened were assigned
    if getattr(self, "lmp", None) and getattr(self, "opened", 0):
      self.lib.lammps_close(self.lmp)
      self.opened = 0

  def close(self):
    """Shut down the LAMMPS instance if this object opened it."""
    if self.opened: self.lib.lammps_close(self.lmp)
    self.lmp = None
    # reset so a second close() does not pass a NULL handle to lammps_close
    self.opened = 0

  def version(self):
    """Return the value reported by lammps_version()."""
    return self.lib.lammps_version(self.lmp)

  def file(self,file):
    """Pass the input-script path *file* to lammps_file()."""
    self.lib.lammps_file(self.lmp,self._encode(file))

  def command(self,cmd):
    """Pass a single command string to lammps_command()."""
    self.lib.lammps_command(self.lmp,self._encode(cmd))

  def extract_global(self,name,type):
    """Return global quantity *name*: type 0 as int, type 1 as double.

    Any other type returns None without querying the library.
    """
    name = self._encode(name)
    if type == 0:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_double)
    else: return None
    ptr = self.lib.lammps_extract_global(self.lmp,name)
    return ptr[0]

  def extract_atom(self,name,type):
    """Return a ctypes pointer to per-atom property *name*.

    type: 0 int*, 1 int**, 2 double*, 3 double**; other types return None.
    """
    name = self._encode(name)
    if type == 0:
      self.lib.lammps_extract_atom.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
    elif type == 2:
      self.lib.lammps_extract_atom.restype = POINTER(c_double)
    elif type == 3:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
    else: return None
    ptr = self.lib.lammps_extract_atom(self.lmp,name)
    return ptr

  def extract_compute(self,id,style,type):
    """Return data of compute *id*.

    type 0 returns a single double (style 0 only; other styles give None),
    type 1 a ctypes double pointer, type 2 a ctypes double** pointer;
    any other type returns None.
    """
    id = self._encode(id)
    if type == 0:
      if style > 0: return None
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr[0]
    if type == 1:
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    if type == 2:
      self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    return None

  # in case of global datum, free memory for 1 double via lammps_free()
  # double was allocated by library interface function
  def extract_fix(self,id,style,type,i=0,j=0):
    """Return data of fix *id* (i and j are forwarded to the library).

    style 0 returns one double and releases its library-allocated storage
    with lammps_free(); style 1 or 2 with type 1 or 2 returns a ctypes
    double / double** pointer; any other combination returns None.
    """
    id = self._encode(id)
    if style == 0:
      self.lib.lammps_extract_fix.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    elif (style == 1) or (style == 2):
      if type == 1:
        self.lib.lammps_extract_fix.restype = POINTER(c_double)
      elif type == 2:
        self.lib.lammps_extract_fix.restype = POINTER(POINTER(c_double))
      else:
        return None
      ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
      return ptr
    else:
      return None

  # free memory for 1 double or 1 vector of doubles via lammps_free()
  # for vector, must copy nlocal returned values to local c_double vector
  # memory was allocated by library interface function
  def extract_variable(self,name,group,type):
    """Return the value(s) of variable *name* (group is forwarded).

    type 0 returns one double; type 1 returns a ctypes double array with
    nlocal entries.  Library-allocated storage is released with
    lammps_free() after copying.  Other types return None.
    """
    name = self._encode(name)
    group = self._encode(group)
    if type == 0:
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    if type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
      nlocalptr = self.lib.lammps_extract_global(self.lmp,b"nlocal")
      nlocal = nlocalptr[0]
      result = (c_double*nlocal)()
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      for i in range(nlocal): result[i] = ptr[i]
      self.lib.lammps_free(ptr)
      return result
    return None

  # set variable value
  # value is converted to string
  # returns 0 for success, -1 if failed
  def set_variable(self,name,value):
    """Set variable *name* to str(value); returns 0 on success, -1 on failure."""
    return self.lib.lammps_set_variable(self.lmp,self._encode(name),
                                        self._encode(str(value)))

  # return total number of atoms in system
  def get_natoms(self):
    """Return the total number of atoms reported by lammps_get_natoms()."""
    return self.lib.lammps_get_natoms(self.lmp)

  # return vector of atom properties gathered across procs, ordered by atom ID

  def gather_atoms(self,name,type,count):
    """Return per-atom property *name* gathered from all procs.

    type 0 yields a ctypes int array, type 1 a ctypes double array, each of
    length count*natoms and filled by lammps_gather_atoms(); any other type
    returns None.
    """
    name = self._encode(name)
    natoms = self.lib.lammps_get_natoms(self.lmp)
    if type == 0:
      data = ((count*natoms)*c_int)()
    elif type == 1:
      data = ((count*natoms)*c_double)()
    else: return None
    self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
    return data

  # scatter vector of atom properties across procs, ordered by atom ID
  # assume vector is of correct type and length, as created by gather_atoms()

  def scatter_atoms(self,name,type,count,data):
    """Scatter per-atom property *name* from *data* across procs."""
    self.lib.lammps_scatter_atoms(self.lmp,self._encode(name),type,count,data)
Exemple #4
class lammps(object):

  # detect if Python is using version of mpi4py that can pass a communicator

  has_mpi4py = False
    from mpi4py import MPI
    from mpi4py import __version__ as mpi4py_version
    if mpi4py_version.split('.')[0] in ['2','3']: has_mpi4py = True

  # create instance of LAMMPS

  def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
    """Create a new LAMMPS instance, or wrap an existing one.

    name: library suffix; "" loads liblammps<ext>, "g++" loads
      liblammps_g++<ext>, where <ext> is ".dylib" if a liblammps*.dylib
      sits next to this module and ".so" otherwise.
    cmdargs: optional list of LAMMPS command-line arguments (str or bytes);
      the caller's list is not modified.
    ptr: handle to an existing LAMMPS instance (PyCapsule on Python 3,
      PyCObject on Python 2) when Python is embedded in a LAMMPS input
      script; no new instance is created in that case.
    comm: optional mpi4py communicator for the new instance.

    Raises Exception if *comm* is given but mpi4py 2.x/3.x is unavailable.
    """
    self.comm = comm
    self.opened = 0
    self.lmp = None

    def _to_cargs(args):
      # (narg, char* array or None) for lammps_open*(); text entries are
      # encoded to bytes without modifying the caller's list
      if not args:
        return 0, None
      encoded = [a.encode() if type(a) is str else a for a in args]
      return len(encoded), (c_char_p*len(encoded))(*encoded)

    # determine module location

    modpath = dirname(abspath(getsourcefile(lambda:0)))
    self.lib = None

    # if a pointer to a LAMMPS object is handed in,
    # all symbols should already be available

    if ptr:
      try:
        self.lib = CDLL("",RTLD_GLOBAL)
      except OSError:
        self.lib = None

    # load liblammps.so unless name is given
    #   if name = "g++", load liblammps_g++.so
    # try loading the LAMMPS shared object from the location
    #   of lammps.py with an absolute path,
    #   so that LD_LIBRARY_PATH does not need to be set for regular install
    # fall back to loading with a relative path,
    #   typically requires LD_LIBRARY_PATH to be set appropriately

    if not self.lib:
      if any(f.startswith('liblammps') and f.endswith('.dylib')
             for f in os.listdir(modpath)):
        lib_ext = ".dylib"
      else:
        lib_ext = ".so"
      if name: libname = "liblammps_%s" % name + lib_ext
      else: libname = "liblammps" + lib_ext
      try:
        self.lib = CDLL(join(modpath,libname),RTLD_GLOBAL)
      except OSError:
        self.lib = CDLL(libname,RTLD_GLOBAL)

    # define ctypes API for each library method
    # NOTE: should add one of these for each lib function

    self.lib.lammps_extract_box.argtypes = \
      [c_void_p,POINTER(c_double),POINTER(c_double),
       POINTER(c_double),POINTER(c_double),POINTER(c_double),
       POINTER(c_int),POINTER(c_int)]
    self.lib.lammps_extract_box.restype = None

    self.lib.lammps_reset_box.argtypes = \
      [c_void_p,POINTER(c_double),POINTER(c_double),c_double,c_double,c_double]
    self.lib.lammps_reset_box.restype = None

    self.lib.lammps_gather_atoms.argtypes = \
      [c_void_p,c_char_p,c_int,c_int,c_void_p]
    self.lib.lammps_gather_atoms.restype = None

    self.lib.lammps_gather_atoms_concat.argtypes = \
      [c_void_p,c_char_p,c_int,c_int,c_void_p]
    self.lib.lammps_gather_atoms_concat.restype = None

    self.lib.lammps_gather_atoms_subset.argtypes = \
      [c_void_p,c_char_p,c_int,c_int,c_int,POINTER(c_int),c_void_p]
    self.lib.lammps_gather_atoms_subset.restype = None

    self.lib.lammps_scatter_atoms.argtypes = \
      [c_void_p,c_char_p,c_int,c_int,c_void_p]
    self.lib.lammps_scatter_atoms.restype = None

    self.lib.lammps_scatter_atoms_subset.argtypes = \
      [c_void_p,c_char_p,c_int,c_int,c_int,POINTER(c_int),c_void_p]
    self.lib.lammps_scatter_atoms_subset.restype = None

    # if no ptr provided, create an instance of LAMMPS
    #   don't know how to pass an MPI communicator from PyPar
    #   but we can pass an MPI communicator from mpi4py v2.0.0 and later
    #   no_mpi call lets LAMMPS use MPI_COMM_WORLD
    #   cargs = array of C strings from args
    # if ptr, then are embedding Python in LAMMPS input script
    #   ptr is the desired instance of LAMMPS
    #   just convert it to ctypes ptr and store in self.lmp

    if not ptr:
      narg, cargs = _to_cargs(cmdargs)
      self.lmp = c_void_p()

      if comm:
        # with mpi4py v2, can pass MPI communicator to LAMMPS
        # need to adjust for type of MPI communicator object
        # allow for int (like MPICH) or void* (like OpenMPI)
        if not lammps.has_mpi4py:
          raise Exception('Python mpi4py version is not 2 or 3')
        if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int):
          MPI_Comm = c_int
        else:
          MPI_Comm = c_void_p

        self.lib.lammps_open.argtypes = [c_int, POINTER(c_char_p),
                                         MPI_Comm, POINTER(c_void_p)]
        self.lib.lammps_open.restype = None
        comm_ptr = lammps.MPI._addressof(comm)
        comm_val = MPI_Comm.from_address(comm_ptr)
        self.lib.lammps_open(narg,cargs,comm_val,byref(self.lmp))
      else:
        # no communicator given: LAMMPS uses MPI_COMM_WORLD, mirrored in
        # self.comm when mpi4py is available
        if lammps.has_mpi4py:
          self.comm = lammps.MPI.COMM_WORLD
        self.lib.lammps_open_no_mpi(narg,cargs,byref(self.lmp))
        # could use just this if LAMMPS lib interface supported it
        # self.lmp = self.lib.lammps_open_no_mpi(0,None)
      # only mark as opened once the library call has returned
      self.opened = 1

    else:
      # magic to convert ptr to ctypes ptr
      if sys.version_info >= (3, 0):
        # Python 3 (uses PyCapsule API)
        pythonapi.PyCapsule_GetPointer.restype = c_void_p
        pythonapi.PyCapsule_GetPointer.argtypes = [py_object, c_char_p]
        self.lmp = c_void_p(pythonapi.PyCapsule_GetPointer(ptr, None))
      else:
        # Python 2 (uses PyCObject API)
        pythonapi.PyCObject_AsVoidPtr.restype = c_void_p
        pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
        self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))

    # optional numpy support (lazy loading)
    self._numpy = None

    # set default types
    self.c_bigint = get_ctypes_int(self.extract_setting("bigint"))
    self.c_tagint = get_ctypes_int(self.extract_setting("tagint"))
    self.c_imageint = get_ctypes_int(self.extract_setting("imageint"))
    self._installed_packages = None

    # add way to insert Python callback for fix external
    self.callback = {}
    self.FIX_EXTERNAL_CALLBACK_FUNC = CFUNCTYPE(None, c_void_p, self.c_bigint, c_int, POINTER(self.c_tagint), POINTER(POINTER(c_double)), POINTER(POINTER(c_double)))
    self.lib.lammps_set_fix_external_callback.argtypes = [c_void_p, c_char_p, self.FIX_EXTERNAL_CALLBACK_FUNC, c_void_p]
    self.lib.lammps_set_fix_external_callback.restype = None

  # shut-down LAMMPS instance

  def __del__(self):
    if self.lmp and self.opened:
      self.opened = 0

  def close(self):
    if self.opened: self.lib.lammps_close(self.lmp)
    self.lmp = None
    self.opened = 0

  def version(self):
    return self.lib.lammps_version(self.lmp)

  def file(self,file):
    if file: file = file.encode()

  # send a single command

  def command(self,cmd):
    if cmd: cmd = cmd.encode()

    if self.has_exceptions and self.lib.lammps_has_error(self.lmp):
      sb = create_string_buffer(100)
      error_type = self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
      error_msg = sb.value.decode().strip()

      if error_type == 2:
        raise MPIAbortException(error_msg)
      raise Exception(error_msg)

  # send a list of commands

  def commands_list(self,cmdlist):
    cmds = [x.encode() for x in cmdlist if type(x) is str]
    args = (c_char_p * len(cmdlist))(*cmds)

  # send a string of commands

  def commands_string(self,multicmd):
    if type(multicmd) is str: multicmd = multicmd.encode()

  # extract lammps type byte sizes

  def extract_setting(self, name):
    if name: name = name.encode()
    self.lib.lammps_extract_setting.restype = c_int
    return int(self.lib.lammps_extract_setting(self.lmp,name))

  # extract global info

  def extract_global(self,name,type):
    if name: name = name.encode()
    if type == 0:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_double)
    else: return None
    ptr = self.lib.lammps_extract_global(self.lmp,name)
    return ptr[0]

  # extract global info

  def extract_box(self):
    boxlo = (3*c_double)()
    boxhi = (3*c_double)()
    xy = c_double()
    yz = c_double()
    xz = c_double()
    periodicity = (3*c_int)()
    box_change = c_int()


    boxlo = boxlo[:3]
    boxhi = boxhi[:3]
    xy = xy.value
    yz = yz.value
    xz = xz.value
    periodicity = periodicity[:3]
    box_change = box_change.value

    return boxlo,boxhi,xy,yz,xz,periodicity,box_change

  # extract per-atom info
  # NOTE: need to insure are converting to/from correct Python type
  #   e.g. for Python list or NumPy or ctypes

  def extract_atom(self,name,type):
    if name: name = name.encode()
    if type == 0:
      self.lib.lammps_extract_atom.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
    elif type == 2:
      self.lib.lammps_extract_atom.restype = POINTER(c_double)
    elif type == 3:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
    else: return None
    ptr = self.lib.lammps_extract_atom(self.lmp,name)
    return ptr

  def numpy(self):
    if not self._numpy:
      import numpy as np
      class LammpsNumpyWrapper:
        def __init__(self, lmp):
          self.lmp = lmp

        def _ctype_to_numpy_int(self, ctype_int):
          if ctype_int == c_int32:
            return np.int32
          elif ctype_int == c_int64:
            return np.int64
          return np.intc

        def extract_atom_iarray(self, name, nelem, dim=1):
          if name in ['id', 'molecule']:
            c_int_type = self.lmp.c_tagint
          elif name in ['image']:
            c_int_type = self.lmp.c_imageint
            c_int_type = c_int

          np_int_type = self._ctype_to_numpy_int(c_int_type)

          if dim == 1:
            tmp = self.lmp.extract_atom(name, 0)
            ptr = cast(tmp, POINTER(c_int_type * nelem))
            tmp = self.lmp.extract_atom(name, 1)
            ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim))

          a = np.frombuffer(ptr.contents, dtype=np_int_type)
          a.shape = (nelem, dim)
          return a

        def extract_atom_darray(self, name, nelem, dim=1):
          if dim == 1:
            tmp = self.lmp.extract_atom(name, 2)
            ptr = cast(tmp, POINTER(c_double * nelem))
            tmp = self.lmp.extract_atom(name, 3)
            ptr = cast(tmp[0], POINTER(c_double * nelem * dim))

          a = np.frombuffer(ptr.contents)
          a.shape = (nelem, dim)
          return a

      self._numpy = LammpsNumpyWrapper(self)
    return self._numpy

  # extract compute info

  def extract_compute(self,id,style,type):
    if id: id = id.encode()
    if type == 0:
      if style == 0:
        self.lib.lammps_extract_compute.restype = POINTER(c_double)
        ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
        return ptr[0]
      else if style == 1:
        return None
      else if style == 2:
        self.lib.lammps_extract_compute.restype = POINTER(c_int)
        return ptr[0]
    if type == 1:
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    if type == 2:
      self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    return None

  # extract fix info
  # in case of global datum, free memory for 1 double via lammps_free()
  # double was allocated by library interface function

  def extract_fix(self,id,style,type,i=0,j=0):
    if id: id = id.encode()
    if style == 0:
      self.lib.lammps_extract_fix.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
      result = ptr[0]
      return result
    elif (style == 1) or (style == 2):
      if type == 1:
        self.lib.lammps_extract_fix.restype = POINTER(c_double)
      elif type == 2:
        self.lib.lammps_extract_fix.restype = POINTER(POINTER(c_double))
        return None
      ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
      return ptr
      return None

  # extract variable info
  # free memory for 1 double or 1 vector of doubles via lammps_free()
  # for vector, must copy nlocal returned values to local c_double vector
  # memory was allocated by library interface function

  def extract_variable(self,name,group,type):
    if name: name = name.encode()
    if group: group = group.encode()
    if type == 0:
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      result = ptr[0]
      return result
    if type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
      nlocalptr = self.lib.lammps_extract_global(self.lmp,"nlocal".encode())
      nlocal = nlocalptr[0]
      result = (c_double*nlocal)()
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      for i in range(nlocal): result[i] = ptr[i]
      return result
    return None

  # return current value of thermo keyword

  def get_thermo(self,name):
    if name: name = name.encode()
    self.lib.lammps_get_thermo.restype = c_double
    return self.lib.lammps_get_thermo(self.lmp,name)

  # return total number of atoms in system

  def get_natoms(self):
    return self.lib.lammps_get_natoms(self.lmp)

  # set variable value
  # value is converted to string
  # returns 0 for success, -1 if failed

  def set_variable(self,name,value):
    if name: name = name.encode()
    if value: value = str(value).encode()
    return self.lib.lammps_set_variable(self.lmp,name,value)

  # reset simulation box size

  def reset_box(self,boxlo,boxhi,xy,yz,xz):
    cboxlo = (3*c_double)(*boxlo)
    cboxhi = (3*c_double)(*boxhi)

  # return vector of atom properties gathered across procs
  # 3 variants to match src/library.cpp
  # name = atom property recognized by LAMMPS in atom->extract()
  # type = 0 for integer values, 1 for double values
  # count = number of per-atom valus, 1 for type or charge, 3 for x or f
  # returned data is a 1d vector - doc how it is ordered?
  # NOTE: need to insure are converting to/from correct Python type
  #   e.g. for Python list or NumPy or ctypes

  def gather_atoms(self,name,type,count):
    if name: name = name.encode()
    natoms = self.lib.lammps_get_natoms(self.lmp)
    if type == 0:
      data = ((count*natoms)*c_int)()
    elif type == 1:
      data = ((count*natoms)*c_double)()
    else: return None
    return data

  def gather_atoms_concat(self,name,type,count):
    if name: name = name.encode()
    natoms = self.lib.lammps_get_natoms(self.lmp)
    if type == 0:
      data = ((count*natoms)*c_int)()
    elif type == 1:
      data = ((count*natoms)*c_double)()
    else: return None
    return data

  def gather_atoms_subset(self,name,type,count,ndata,ids):
    if name: name = name.encode()
    if type == 0:
      data = ((count*ndata)*c_int)()
    elif type == 1:
      data = ((count*ndata)*c_double)()
    else: return None
    return data

  # scatter vector of atom properties across procs
  # 2 variants to match src/library.cpp
  # name = atom property recognized by LAMMPS in atom->extract()
  # type = 0 for integer values, 1 for double values
  # count = number of per-atom valus, 1 for type or charge, 3 for x or f
  # assume data is of correct type and length, as created by gather_atoms()
  # NOTE: need to insure are converting to/from correct Python type
  #   e.g. for Python list or NumPy or ctypes

  def scatter_atoms(self,name,type,count,data):
    if name: name = name.encode()

  def scatter_atoms_subset(self,name,type,count,ndata,ids,data):
    if name: name = name.encode()

  # create N atoms on all procs
  # N = global number of atoms
  # id = ID of each atom (optional, can be None)
  # type = type of each atom (1 to Ntypes) (required)
  # x = coords of each atom as (N,3) array (required)
  # v = velocity of each atom as (N,3) array (optional, can be None)
  # NOTE: how could we insure are passing correct type to LAMMPS
  #   e.g. for Python list or NumPy, etc
  #   ditto for gather_atoms() above

  def create_atoms(self,n,id,type,x,v,image=None,shrinkexceed=False):
    if id:
      id_lmp = (c_int * n)()
      id_lmp[:] = id
      id_lmp = id

    if image:
      image_lmp = (c_int * n)()
      image_lmp[:] = image
      image_lmp = image

    type_lmp = (c_int * n)()
    type_lmp[:] = type

  def has_exceptions(self):
    """ Return whether the LAMMPS shared library was compiled with C++ exceptions handling enabled """
    return self.lib.lammps_config_has_exceptions() != 0

  def has_gzip_support(self):
    return self.lib.lammps_config_has_gzip_support() != 0

  def has_png_support(self):
    return self.lib.lammps_config_has_png_support() != 0

  def has_jpeg_support(self):
    return self.lib.lammps_config_has_jpeg_support() != 0

  def has_ffmpeg_support(self):
    return self.lib.lammps_config_has_ffmpeg_support() != 0

  def installed_packages(self):
    if self._installed_packages is None:
      self._installed_packages = []
      npackages = self.lib.lammps_config_package_count()
      sb = create_string_buffer(100)
      for idx in range(npackages):
        self.lib.lammps_config_package_name(idx, sb, 100)
    return self._installed_packages

  def set_fix_external_callback(self, fix_name, callback, caller=None):
    import numpy as np
    def _ctype_to_numpy_int(ctype_int):
          if ctype_int == c_int32:
            return np.int32
          elif ctype_int == c_int64:
            return np.int64
          return np.intc

    def callback_wrapper(caller_ptr, ntimestep, nlocal, tag_ptr, x_ptr, fext_ptr):
      if cast(caller_ptr,POINTER(py_object)).contents:
        pyCallerObj = cast(caller_ptr,POINTER(py_object)).contents.value
        pyCallerObj = None

      tptr = cast(tag_ptr, POINTER(self.c_tagint * nlocal))
      tag = np.frombuffer(tptr.contents, dtype=_ctype_to_numpy_int(self.c_tagint))
      tag.shape = (nlocal)

      xptr = cast(x_ptr[0], POINTER(c_double * nlocal * 3))
      x = np.frombuffer(xptr.contents)
      x.shape = (nlocal, 3)

      fptr = cast(fext_ptr[0], POINTER(c_double * nlocal * 3))
      f = np.frombuffer(fptr.contents)
      f.shape = (nlocal, 3)

      callback(pyCallerObj, ntimestep, nlocal, tag, x, f)

    cFunc   = self.FIX_EXTERNAL_CALLBACK_FUNC(callback_wrapper)
    cCaller = cast(pointer(py_object(caller)), c_void_p)

    self.callback[fix_name] = { 'function': cFunc, 'caller': caller }

    self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller)
# Example #5
class lammps(object):
  """ctypes wrapper around one instance of the LAMMPS shared library.

  A new LAMMPS instance is opened unless ``ptr`` is given, in which case
  Python is embedded in a LAMMPS input script and ``ptr`` wraps the running
  instance.
  """

  # detect if Python is using version of mpi4py that can pass a communicator

  has_mpi4py_v2 = False
  try:
    from mpi4py import MPI
    from mpi4py import __version__ as mpi4py_version
    if mpi4py_version.split('.')[0] == '2':
      has_mpi4py_v2 = True
  except Exception:
    # mpi4py is optional: without it no communicator can be passed in
    pass

  # create instance of LAMMPS

  def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
    """Load the LAMMPS shared library and open or wrap a LAMMPS instance.

    name    -- library suffix: "" loads liblammps.so, "g++" loads
               liblammps_g++.so
    cmdargs -- optional LAMMPS command-line arguments (str or bytes); the
               list passed in is not modified
    ptr     -- PyCapsule (Python 3) / PyCObject (Python 2) wrapping an
               existing LAMMPS instance
    comm    -- mpi4py communicator, used only when mpi4py v2 is available
    """
    self.comm = comm
    self.opened = 0

    # determine module location

    modpath = dirname(abspath(getsourcefile(lambda:0)))
    self.lib = None

    # if a pointer to a LAMMPS object is handed in,
    # all symbols should already be available

    if ptr:
      try:
        self.lib = CDLL("",RTLD_GLOBAL)
      except OSError:
        self.lib = None

    # load liblammps.so unless name is given
    #   if name = "g++", load liblammps_g++.so
    # try loading the LAMMPS shared object from the location
    #   of lammps.py with an absolute path,
    #   so that LD_LIBRARY_PATH does not need to be set for regular install
    # fall back to loading with a relative path,
    #   typically requires LD_LIBRARY_PATH to be set appropriately

    if not self.lib:
      if name: libname = "liblammps_%s.so" % name
      else: libname = "liblammps.so"
      try:
        self.lib = CDLL(join(modpath,libname),RTLD_GLOBAL)
      except OSError:
        self.lib = CDLL(libname,RTLD_GLOBAL)

    # if no ptr provided, create an instance of LAMMPS
    #   don't know how to pass an MPI communicator from PyPar
    #   but we can pass an MPI communicator from mpi4py v2.0.0 and later
    #   no_mpi call lets LAMMPS use MPI_COMM_WORLD
    #   cargs = array of C strings from args
    # if ptr, then are embedding Python in LAMMPS input script
    #   ptr is the desired instance of LAMMPS
    #   just convert it to ctypes ptr and store in self.lmp

    if not ptr:
      # build argv from a copy of cmdargs; argv[0] is a program name,
      # so the caller's arguments start at argv[1]
      narg = 0
      cargs = None
      if cmdargs:
        argv = ["lammps.py"] + list(cmdargs)
        argv = [a.encode() if type(a) is str else a for a in argv]
        narg = len(argv)
        cargs = (c_char_p*narg)(*argv)

      # with mpi4py v2, can pass MPI communicator to LAMMPS
      # need to adjust for type of MPI communicator object
      # allow for int (like MPICH) or void* (like OpenMPI)

      if lammps.has_mpi4py_v2 and comm is not None:
        if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int):
          MPI_Comm = c_int
        else:
          MPI_Comm = c_void_p
        if cargs is None: argv_type = c_void_p
        else: argv_type = c_char_p*narg
        self.lib.lammps_open.argtypes = [c_int, argv_type, MPI_Comm,
                                         POINTER(c_void_p)]
        self.lib.lammps_open.restype = None
        self.opened = 1
        self.lmp = c_void_p()
        comm_ptr = lammps.MPI._addressof(comm)
        comm_val = MPI_Comm.from_address(comm_ptr)
        self.lib.lammps_open(narg,cargs,comm_val,byref(self.lmp))

      else:
        self.opened = 1
        self.lmp = c_void_p()
        self.lib.lammps_open_no_mpi(narg,cargs,byref(self.lmp))
        # could use just this if LAMMPS lib interface supported it
        # self.lmp = self.lib.lammps_open_no_mpi(0,None)

    else:
      # magic to convert ptr to ctypes ptr
      import sys
      if sys.version_info >= (3, 0):
        # Python 3 (uses PyCapsule API)
        pythonapi.PyCapsule_GetPointer.restype = c_void_p
        pythonapi.PyCapsule_GetPointer.argtypes = [py_object, c_char_p]
        self.lmp = c_void_p(pythonapi.PyCapsule_GetPointer(ptr, None))
      else:
        # Python 2 (uses PyCObject API)
        pythonapi.PyCObject_AsVoidPtr.restype = c_void_p
        pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
        self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))

  def __del__(self):
    # __init__ may have failed before these attributes were assigned
    if getattr(self, 'lmp', None) and getattr(self, 'opened', 0):
      self.lib.lammps_close(self.lmp)
      self.opened = 0

  def close(self):
    """Close the LAMMPS instance if this wrapper opened it."""
    if self.opened: self.lib.lammps_close(self.lmp)
    self.lmp = None
    self.opened = 0

  def version(self):
    """Return the value of lammps_version() for this instance."""
    return self.lib.lammps_version(self.lmp)

  def file(self,file):
    """Pass the input-script path ``file`` to lammps_file()."""
    if file: file = file.encode()
    self.lib.lammps_file(self.lmp,file)

  def command(self,cmd):
    """Execute one LAMMPS command; raise if the library reports an error."""
    if cmd: cmd = cmd.encode()
    self.lib.lammps_command(self.lmp,cmd)

    if self.uses_exceptions() and self.lib.lammps_has_error(self.lmp):
      sb = create_string_buffer(100)
      error_type = self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
      error_msg = sb.value.decode().strip()

      if error_type == 2:
        raise MPIAbortException(error_msg)
      raise Exception(error_msg)

  def extract_global(self,name,type):
    """Return global quantity ``name``: type 0 as int, 1 as float, else None."""
    if name: name = name.encode()
    if type == 0:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_double)
    else: return None
    ptr = self.lib.lammps_extract_global(self.lmp,name)
    return ptr[0]

  def extract_atom(self,name,type):
    """Return a ctypes pointer to per-atom data ``name``.

    type 0: int*, 1: int**, 2: double*, 3: double**; other types give None.
    """
    if name: name = name.encode()
    if type == 0:
      self.lib.lammps_extract_atom.restype = POINTER(c_int)
    elif type == 1:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
    elif type == 2:
      self.lib.lammps_extract_atom.restype = POINTER(c_double)
    elif type == 3:
      self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
    else: return None
    ptr = self.lib.lammps_extract_atom(self.lmp,name)
    return ptr

  def extract_compute(self,id,style,type):
    """Return compute data: type 0 a float (style 0 only), 1/2 a ctypes pointer."""
    if id: id = id.encode()
    if type == 0:
      if style > 0: return None
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr[0]
    if type == 1:
      self.lib.lammps_extract_compute.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    if type == 2:
      self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
      ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
      return ptr
    return None

  # in case of global datum, free memory for 1 double via lammps_free()
  # double was allocated by library interface function

  def extract_fix(self,id,style,type,i=0,j=0):
    """Return fix data: style 0 a float (freed after copy), 1/2 a ctypes pointer."""
    if id: id = id.encode()
    if style == 0:
      self.lib.lammps_extract_fix.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    if style in (1, 2):
      if type == 1:
        self.lib.lammps_extract_fix.restype = POINTER(c_double)
      elif type == 2:
        self.lib.lammps_extract_fix.restype = POINTER(POINTER(c_double))
      else:
        return None
      return self.lib.lammps_extract_fix(self.lmp,id,style,type,i,j)
    return None

  # free memory for 1 double or 1 vector of doubles via lammps_free()
  # for vector, must copy nlocal returned values to local c_double vector
  # memory was allocated by library interface function

  def extract_variable(self,name,group,type):
    """Return variable ``name``: type 0 a float, type 1 a c_double array of nlocal values."""
    if name: name = name.encode()
    if group: group = group.encode()
    if type == 0:
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      result = ptr[0]
      self.lib.lammps_free(ptr)
      return result
    if type == 1:
      self.lib.lammps_extract_global.restype = POINTER(c_int)
      nlocalptr = self.lib.lammps_extract_global(self.lmp,b"nlocal")
      nlocal = nlocalptr[0]
      result = (c_double*nlocal)()
      self.lib.lammps_extract_variable.restype = POINTER(c_double)
      ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
      result[:] = ptr[:nlocal]
      self.lib.lammps_free(ptr)
      return result
    return None

  # set variable value
  # value is converted to string
  # returns 0 for success, -1 if failed

  def set_variable(self,name,value):
    if name: name = name.encode()
    # encode falsy values too; pass the bytes themselves, not str(bytes)
    if value is not None: value = str(value).encode()
    return self.lib.lammps_set_variable(self.lmp,name,value)

  # return current value of thermo keyword

  def get_thermo(self,name):
    if name: name = name.encode()
    self.lib.lammps_get_thermo.restype = c_double
    return self.lib.lammps_get_thermo(self.lmp,name)

  # return total number of atoms in system

  def get_natoms(self):
    return self.lib.lammps_get_natoms(self.lmp)

  # return vector of atom properties gathered across procs, ordered by atom ID

  def gather_atoms(self,name,type,count):
    if name: name = name.encode()
    natoms = self.lib.lammps_get_natoms(self.lmp)
    if type == 0:
      data = ((count*natoms)*c_int)()
    elif type == 1:
      data = ((count*natoms)*c_double)()
    else: return None
    self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
    return data

  # scatter vector of atom properties across procs, ordered by atom ID
  # assume vector is of correct type and length, as created by gather_atoms()

  def scatter_atoms(self,name,type,count,data):
    if name: name = name.encode()
    self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data)

  def uses_exceptions(self):
    """Return True if the loaded library exports ``lammps_has_error``."""
    # ctypes raises AttributeError for symbols missing from the library
    return hasattr(self.lib, 'lammps_has_error')
# Example #6
class lammps(object):
    # detect if Python is using version of mpi4py that can pass a communicator

    # True when mpi4py 2.x or 3.x can be imported, enabling communicator passing
    has_mpi4py = False
    try:
        from mpi4py import MPI
        from mpi4py import __version__ as mpi4py_version
        if mpi4py_version.split('.')[0] in ['2', '3']: has_mpi4py = True
    except Exception:
        # mpi4py is optional; any import failure just leaves has_mpi4py False
        pass

    # create instance of LAMMPS

#<<<<<<< master
#    def __init__(self, infile, label="", mesh=100., dmtol=0.001, \
    def __init__(self, infile, label="",
                 constraints=None, tdir="./", lunit="Ang", eunit="eV", md2ang=0.06466,
                 name="", cmdargs=args.split(), ptr=None, comm=None):
        """Load LAMMPS, open (or wrap) an instance and store the MD driver settings.

        infile      -- path of the LAMMPS input script read later by start()
        label, lunit, eunit, md2ang -- stored unchanged on the instance
        constraints -- stored on the instance; None (the default) becomes a
                       fresh empty list so instances never share it
        tdir        -- accepted for interface compatibility; not used here
        name        -- library suffix: "" loads liblammps.so, "g++" loads
                       liblammps_g++.so
        cmdargs     -- LAMMPS command-line arguments (str or bytes); a copy is
                       used, so neither the shared default nor the caller's
                       list is modified
        ptr         -- PyCapsule (Py3) / PyCObject (Py2) wrapping an existing
                       LAMMPS instance when Python is embedded in LAMMPS
        comm        -- mpi4py communicator to run LAMMPS on

        Raises Exception if ``comm`` is given but mpi4py 2/3 is unavailable.
        """
        self.comm = comm
        self.opened = 0

        # determine module location

        modpath = dirname(abspath(getsourcefile(lambda: 0)))
        self.lib = None

        # if a pointer to a LAMMPS object is handed in,
        # all symbols should already be available

        if ptr:
            try:
                self.lib = CDLL("", RTLD_GLOBAL)
            except OSError:
                self.lib = None

        # load liblammps.so unless name is given
        #   if name = "g++", load liblammps_g++.so
        # try loading the LAMMPS shared object from the location
        #   of lammps.py with an absolute path,
        #   so that LD_LIBRARY_PATH does not need to be set for regular install
        # fall back to loading with a relative path,
        #   typically requires LD_LIBRARY_PATH to be set appropriately

        if not self.lib:
            libname = "liblammps_%s.so" % name if name else "liblammps.so"
            try:
                self.lib = CDLL(join(modpath, libname), RTLD_GLOBAL)
            except OSError:
                self.lib = CDLL(libname, RTLD_GLOBAL)

        # define ctypes API for each library method
        # NOTE: should add one of these for each lib function

        self.lib.lammps_extract_box.argtypes = \
            [c_void_p, POINTER(c_double), POINTER(c_double),
             POINTER(c_double), POINTER(c_double), POINTER(c_double),
             POINTER(c_int), POINTER(c_int)]
        self.lib.lammps_extract_box.restype = None

        self.lib.lammps_reset_box.argtypes = \
            [c_void_p, POINTER(c_double), POINTER(c_double), c_double, c_double, c_double]
        self.lib.lammps_reset_box.restype = None

        self.lib.lammps_gather_atoms.argtypes = \
            [c_void_p, c_char_p, c_int, c_int, c_void_p]
        self.lib.lammps_gather_atoms.restype = None

        self.lib.lammps_gather_atoms_concat.argtypes = \
            [c_void_p, c_char_p, c_int, c_int, c_void_p]
        self.lib.lammps_gather_atoms_concat.restype = None

        self.lib.lammps_gather_atoms_subset.argtypes = \
            [c_void_p, c_char_p, c_int, c_int, c_int, POINTER(c_int), c_void_p]
        self.lib.lammps_gather_atoms_subset.restype = None

        self.lib.lammps_scatter_atoms.argtypes = \
            [c_void_p, c_char_p, c_int, c_int, c_void_p]
        self.lib.lammps_scatter_atoms.restype = None

        self.lib.lammps_scatter_atoms_subset.argtypes = \
            [c_void_p, c_char_p, c_int, c_int, c_int, POINTER(c_int), c_void_p]
        self.lib.lammps_scatter_atoms_subset.restype = None

        # if no ptr provided, create an instance of LAMMPS
        #   but we can pass an MPI communicator from mpi4py v2.0.0 and later
        #   no_mpi call lets LAMMPS use MPI_COMM_WORLD
        #   cargs = array of C strings from args
        # if ptr, then are embedding Python in LAMMPS input script
        #   ptr is the desired instance of LAMMPS
        #   just convert it to ctypes ptr and store in self.lmp

        if not ptr:
            # build argv from a COPY: the default cmdargs list is shared by
            # every call, so mutating it would stack up "lammps.py" entries
            narg = 0
            cargs = None
            if cmdargs:
                argv = ["lammps.py"] + list(cmdargs)
                argv = [a.encode() if type(a) is str else a for a in argv]
                narg = len(argv)
                cargs = (c_char_p * narg)(*argv)

            # with mpi4py v2, can pass MPI communicator to LAMMPS
            # need to adjust for type of MPI communicator object
            # allow for int (like MPICH) or void* (like OpenMPI)

            if comm:
                if not lammps.has_mpi4py:
                    raise Exception('Python mpi4py version is not 2 or 3')
                if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int):
                    MPI_Comm = c_int
                else:
                    MPI_Comm = c_void_p
                argv_type = c_void_p if cargs is None else c_char_p * narg
                self.lib.lammps_open.argtypes = [c_int, argv_type,
                                                 MPI_Comm, POINTER(c_void_p)]
                self.lib.lammps_open.restype = None
                self.opened = 1
                self.lmp = c_void_p()
                comm_ptr = lammps.MPI._addressof(comm)
                comm_val = MPI_Comm.from_address(comm_ptr)
                self.lib.lammps_open(narg, cargs, comm_val, byref(self.lmp))

            else:
                if lammps.has_mpi4py:
                    from mpi4py import MPI
                    self.comm = MPI.COMM_WORLD
                self.opened = 1
                self.lmp = c_void_p()
                self.lib.lammps_open_no_mpi(narg, cargs, byref(self.lmp))

        else:
            # magic to convert ptr to ctypes ptr
            if sys.version_info >= (3, 0):
                # Python 3 (uses PyCapsule API)
                pythonapi.PyCapsule_GetPointer.restype = c_void_p
                pythonapi.PyCapsule_GetPointer.argtypes = [py_object, c_char_p]
                self.lmp = c_void_p(pythonapi.PyCapsule_GetPointer(ptr, None))
            else:
                # Python 2 (uses PyCObject API)
                pythonapi.PyCObject_AsVoidPtr.restype = c_void_p
                pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
                self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))

        # optional numpy support (lazy loading)
        self._numpy = None
        # set default types
        self.c_bigint = get_ctypes_int(self.extract_setting("bigint"))
        self.c_tagint = get_ctypes_int(self.extract_setting("tagint"))
        self.c_imageint = get_ctypes_int(self.extract_setting("imageint"))
        self._installed_packages = None
        self.infile = infile
        self.md2ang = md2ang
        self.constraints = [] if constraints is None else constraints
        self.label = label
        self.lunit = lunit
        self.eunit = eunit

        #start lammps

    def start(self, np=1):
        """Run the LAMMPS input file and cache the atomic data used by the MD driver.

        Every line of ``self.infile`` is sent to LAMMPS as a command. Then it
        reads back:

        * the per-atom types;
        * the masses, looked up per atom from the per-type mass table;
        * the coordinates, stored twice: ``xyz`` as the reference and
          ``newxyz`` as a working buffer;
        * the conversion factors ``md2ang / sqrt(mass)``, repeated for x, y, z;
        * the atom count;
        * ``axyz``, a list of ``[name, x, y, z]`` entries.

        ``np`` is accepted for interface compatibility but not used.
        """
        print("lammps launched")
        # todo: better to set the unit to metals here again
        self.command("units metal")

        with open(self.infile, 'r') as fin:
            lines = fin.readlines()
        for line in lines:
            self.command(line)

        self.type = N.array(self.gather_atoms("type", 0, 1))
        # per-type mass table, indexed below with each atom's type
        self.mass = self.extract_atom("mass", 2)
        self.els = [self.mass[atom_type] for atom_type in self.type]
        self.xyz = self.gather_atoms("x", 1, 3)
        self.newxyz = self.gather_atoms("x", 1, 3)
        # conversion factor from eV/Ang (force from lammps)
        # to the internal unit of MD, repeated for x, y and z of each atom
        self.conv = self.md2ang * N.array(
            [3 * [1.0 / N.sqrt(mass)] for mass in self.els]).flatten()
        self.number = self.get_natoms()
        self.axyz = [
            [get_atomname(a), self.xyz[i * 3], self.xyz[i * 3 + 1],
             self.xyz[i * 3 + 2]]
            for i, a in enumerate(self.els)]

    def quit(self):
        print("Quit lammps!")

    def newx(self, q):
        for i in range(3 * self.number):
            self.newxyz[i] = self.xyz[i] + self.conv[i] * q[i]
        return self.newxyz

    def absforce(self, q):
        self.scatter_atoms("x", 1, 3, self.newx(q))
        self.command("run 1")
        return self.conv * N.array(self.gather_atoms("f", 1, 3))

    def initforce(self):
        print("Calculate zero displacement force")
        extq = N.zeros(3 * self.number)
        self.f0 = self.absforce(extq)

    def force(self, q):
        f = self.absforce(q) - self.f0
        return f

    # send a single command
    def command(self, cmd):
        if cmd: cmd = cmd.encode()
        self.lib.lammps_command(self.lmp, cmd)

        if self.has_exceptions and self.lib.lammps_has_error(self.lmp):
            sb = create_string_buffer(100)
            error_type = self.lib.lammps_get_last_error_message(
                self.lmp, sb, 100)
            error_msg = sb.value.decode().strip()

            if error_type == 2:
                raise MPIAbortException(error_msg)
            raise Exception(error_msg)

    # send a list of commands
    def commands_list(self, cmdlist):
        cmds = [x.encode() for x in cmdlist if type(x) is str]
        args = (c_char_p * len(cmdlist))(*cmds)
        self.lib.lammps_commands_list(self.lmp, len(cmdlist), args)

    # send a string of commands
    def commands_string(self, multicmd):
        if type(multicmd) is str: multicmd = multicmd.encode()
        self.lib.lammps_commands_string(self.lmp, c_char_p(multicmd))

    # extract lammps type byte sizes
    def extract_setting(self, name):
        if name: name = name.encode()
        self.lib.lammps_extract_atom.restype = c_int
        return int(self.lib.lammps_extract_setting(self.lmp, name))

    # extract global info
    def extract_global(self, name, type):
        if name: name = name.encode()
        if type == 0:
            self.lib.lammps_extract_global.restype = POINTER(c_int)
        elif type == 1:
            self.lib.lammps_extract_global.restype = POINTER(c_double)
            return None
        ptr = self.lib.lammps_extract_global(self.lmp, name)
        return ptr[0]

    # extract global info
    def extract_box(self):
        boxlo = (3 * c_double)()
        boxhi = (3 * c_double)()
        xy = c_double()
        yz = c_double()
        xz = c_double()
        periodicity = (3 * c_int)()
        box_change = c_int()

        self.lib.lammps_extract_box(self.lmp, boxlo, boxhi, byref(xy),
                                    byref(yz), byref(xz), periodicity,

        boxlo = boxlo[:3]
        boxhi = boxhi[:3]
        xy = xy.value
        yz = yz.value
        xz = xz.value
        periodicity = periodicity[:3]
        box_change = box_change.value

        return boxlo, boxhi, xy, yz, xz, periodicity, box_change

    # extract per-atom info
    # NOTE: need to insure are converting to/from correct Python type
    #   e.g. for Python list or NumPy or ctypes

    def extract_atom(self, name, type):
        if name: name = name.encode()
        if type == 0:
            self.lib.lammps_extract_atom.restype = POINTER(c_int)
        elif type == 1:
            self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
        elif type == 2:
            self.lib.lammps_extract_atom.restype = POINTER(c_double)
        elif type == 3:
            self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
            return None
        ptr = self.lib.lammps_extract_atom(self.lmp, name)
        return ptr

    # shut-down LAMMPS instance

    def __del__(self):
        if self.lmp and self.opened:
            self.opened = 0

    def close(self):
        if self.opened: self.lib.lammps_close(self.lmp)
        self.lmp = None
        self.opened = 0

    def version(self):
        return self.lib.lammps_version(self.lmp)

    def file(self, file):
        if file: file = file.encode()
        self.lib.lammps_file(self.lmp, file)

    def numpy(self):
        if not self._numpy:
            import numpy as np

            class LammpsNumpyWrapper:
                def __init__(self, lmp):
                    self.lmp = lmp

                def _ctype_to_numpy_int(self, ctype_int):
                    if ctype_int == c_int32:
                        return np.int32
                    elif ctype_int == c_int64:
                        return np.int64
                    return np.intc

                def extract_atom_iarray(self, name, nelem, dim=1):
                    if name in ['id', 'molecule']:
                        c_int_type = self.lmp.c_tagint
                    elif name in ['image']:
                        c_int_type = self.lmp.c_imageint
                        c_int_type = c_int

                    np_int_type = self._ctype_to_numpy_int(c_int_type)

                    if dim == 1:
                        tmp = self.lmp.extract_atom(name, 0)
                        ptr = cast(tmp, POINTER(c_int_type * nelem))
                        tmp = self.lmp.extract_atom(name, 1)
                        ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim))

                    a = np.frombuffer(ptr.contents, dtype=np_int_type)
                    a.shape = (nelem, dim)
                    return a

                def extract_atom_darray(self, name, nelem, dim=1):
                    if dim == 1:
                        tmp = self.lmp.extract_atom(name, 2)
                        ptr = cast(tmp, POINTER(c_double * nelem))
                        tmp = self.lmp.extract_atom(name, 3)
                        ptr = cast(tmp[0], POINTER(c_double * nelem * dim))

                    a = np.frombuffer(ptr.contents)
                    a.shape = (nelem, dim)
                    return a

            self._numpy = LammpsNumpyWrapper(self)
        return self._numpy

    # extract compute info

    def extract_compute(self, id, style, type):
        if id: id = id.encode()
        if type == 0:
            if style > 0: return None
            self.lib.lammps_extract_compute.restype = POINTER(c_double)
            ptr = self.lib.lammps_extract_compute(self.lmp, id, style, type)
            return ptr[0]
        if type == 1:
            self.lib.lammps_extract_compute.restype = POINTER(c_double)
            ptr = self.lib.lammps_extract_compute(self.lmp, id, style, type)
            return ptr
        if type == 2:
            if style == 0:
                self.lib.lammps_extract_compute.restype = POINTER(c_int)
                ptr = self.lib.lammps_extract_compute(self.lmp, id, style,
                return ptr[0]
                self.lib.lammps_extract_compute.restype = POINTER(
                ptr = self.lib.lammps_extract_compute(self.lmp, id, style,
                return ptr
        return None

    # extract fix info
    # in case of global datum, free memory for 1 double via lammps_free()
    # double was allocated by library interface function

    def extract_fix(self, id, style, type, i=0, j=0):
        """Return data from fix *id* via ``lammps_extract_fix``.

        *style*, *type*, *i* and *j* are passed through unchanged to the C
        library.

        Returns:
          * style 0: the value as a Python float; the double was allocated
            by the library interface and is released with ``lammps_free``;
          * style 1 or 2 with type 1: the raw ``POINTER(c_double)``;
          * style 1 or 2 with type 2: the raw ``POINTER(POINTER(c_double))``;
          * anything else: ``None``.
        """
        if id:
            id = id.encode()
        if style == 0:
            self.lib.lammps_extract_fix.restype = POINTER(c_double)
            ptr = self.lib.lammps_extract_fix(self.lmp, id, style, type, i, j)
            result = ptr[0]
            # the library allocated this single double for us; free it
            # after copying the value out, otherwise it leaks
            self.lib.lammps_free(ptr)
            return result
        if style in (1, 2):
            if type == 1:
                self.lib.lammps_extract_fix.restype = POINTER(c_double)
            elif type == 2:
                self.lib.lammps_extract_fix.restype = POINTER(
                    POINTER(c_double))
            else:
                return None
            ptr = self.lib.lammps_extract_fix(self.lmp, id, style, type, i, j)
            return ptr
        return None

    # extract variable info
    # free memory for 1 double or 1 vector of doubles via lammps_free()
    # for vector, must copy nlocal returned values to local c_double vector
    # memory was allocated by library interface function

    def extract_variable(self, name, group, type):
        """Return the value of LAMMPS variable *name* for atom *group*.

        Returns:
          * type 0: the value as a Python float;
          * type 1: a ``c_double`` array holding ``nlocal`` values copied
            from the library result;
          * any other type: ``None``.

        In both non-None cases the memory returned by
        ``lammps_extract_variable`` was allocated by the library interface
        and is released with ``lammps_free`` after its data is copied out.
        """
        if name:
            name = name.encode()
        if group:
            group = group.encode()
        if type == 0:
            self.lib.lammps_extract_variable.restype = POINTER(c_double)
            ptr = self.lib.lammps_extract_variable(self.lmp, name, group)
            result = ptr[0]
            self.lib.lammps_free(ptr)
            return result
        if type == 1:
            self.lib.lammps_extract_global.restype = POINTER(c_int)
            nlocalptr = self.lib.lammps_extract_global(self.lmp, b"nlocal")
            nlocal = nlocalptr[0]
            result = (c_double * nlocal)()
            self.lib.lammps_extract_variable.restype = POINTER(c_double)
            ptr = self.lib.lammps_extract_variable(self.lmp, name, group)
            # copy into Python-owned storage before freeing the C buffer
            result[:] = ptr[:nlocal]
            self.lib.lammps_free(ptr)
            return result
        return None

    # return current value of thermo keyword

    def get_thermo(self, name):
        """Return the current value of thermo keyword *name*.

        The C function's return type is declared as ``c_double`` before the
        call, so the result arrives as a Python float.
        """
        key = name.encode() if name else name
        thermo = self.lib.lammps_get_thermo
        thermo.restype = c_double
        return thermo(self.lmp, key)

    # return total number of atoms in system

    def get_natoms(self):
        """Return the total atom count reported by ``lammps_get_natoms``."""
        count_atoms = self.lib.lammps_get_natoms
        return count_atoms(self.lmp)

    # set variable value
    # value is converted to string
    # returns 0 for success, -1 if failed

    def set_variable(self, name, value):
        """Set LAMMPS variable *name* to ``str(value)``.

        Returns the library status code: 0 for success, -1 if it failed.
        Only ``None`` is passed through unconverted (ctypes sends it as a
        NULL pointer); every other value, including falsy ones such as
        ``0``, ``0.0`` or ``""``, is converted to an encoded string.
        """
        if name:
            name = name.encode()
        if value is not None:
            # a truthiness test would skip 0 / 0.0 / "" and hand the C side
            # an int or a Python str instead of a C string
            value = str(value).encode()
        return self.lib.lammps_set_variable(self.lmp, name, value)

    # reset simulation box size

    def reset_box(self, boxlo, boxhi, xy, yz, xz):
        """Reset the simulation box via ``lammps_reset_box``.

        *boxlo* and *boxhi* are each unpacked into a 3-element ``c_double``
        array; *xy*, *yz* and *xz* are passed to the library unchanged.
        """
        vec3 = c_double * 3
        lo = vec3(*boxlo)
        hi = vec3(*boxhi)
        self.lib.lammps_reset_box(self.lmp, lo, hi, xy, yz, xz)

    # return vector of atom properties gathered across procs
    # 3 variants to match src/library.cpp
    # name = atom property recognized by LAMMPS in atom->extract()
    # type = 0 for integer values, 1 for double values
    # count = number of per-atom values, 1 for type or charge, 3 for x or f
    # returned data is a 1d vector (TODO: document how it is ordered)
    # NOTE: need to ensure we are converting to/from the correct Python type
    #   e.g. for Python list or NumPy or ctypes

    def gather_atoms(self, name, type, count):
        """Gather per-atom property *name* from all procs into one vector.

        :param type: 0 for integer values, 1 for double values
        :param count: number of values per atom (1 for type or charge,
                      3 for x or f)
        :returns: a ctypes array of ``count * natoms`` ``c_int`` (type 0)
                  or ``c_double`` (type 1) filled by the library, or
                  ``None`` for any other *type*
        """
        if name:
            name = name.encode()
        natoms = self.lib.lammps_get_natoms(self.lmp)
        if type == 0:
            data = ((count * natoms) * c_int)()
        elif type == 1:
            data = ((count * natoms) * c_double)()
        else:
            return None
        self.lib.lammps_gather_atoms(self.lmp, name, type, count, data)
        return data

    def gather_atoms_concat(self, name, type, count):
        """Gather per-atom property *name* via ``lammps_gather_atoms_concat``.

        :param type: 0 for integer values, 1 for double values
        :param count: number of values per atom
        :returns: a ctypes array of ``count * natoms`` ``c_int`` (type 0)
                  or ``c_double`` (type 1) filled by the library, or
                  ``None`` for any other *type*
        """
        if name:
            name = name.encode()
        natoms = self.lib.lammps_get_natoms(self.lmp)
        if type == 0:
            data = ((count * natoms) * c_int)()
        elif type == 1:
            data = ((count * natoms) * c_double)()
        else:
            return None
        self.lib.lammps_gather_atoms_concat(self.lmp, name, type, count,
                                            data)
        return data

    def gather_atoms_subset(self, name, type, count, ndata, ids):
        """Gather property *name* for the *ndata* atoms listed in *ids*.

        :param type: 0 for integer values, 1 for double values
        :param count: number of values per atom
        :param ids: passed to the library unchanged
        :returns: a ctypes array of ``count * ndata`` ``c_int`` (type 0)
                  or ``c_double`` (type 1) filled by the library, or
                  ``None`` for any other *type*
        """
        if name:
            name = name.encode()
        if type == 0:
            data = ((count * ndata) * c_int)()
        elif type == 1:
            data = ((count * ndata) * c_double)()
        else:
            return None
        self.lib.lammps_gather_atoms_subset(self.lmp, name, type, count,
                                            ndata, ids, data)
        return data

    # scatter vector of atom properties across procs
    # 2 variants to match src/library.cpp
    # name = atom property recognized by LAMMPS in atom->extract()
    # type = 0 for integer values, 1 for double values
    # count = number of per-atom values, 1 for type or charge, 3 for x or f
    # assume data is of correct type and length, as created by gather_atoms()
    # NOTE: need to ensure we are converting to/from the correct Python type
    #   e.g. for Python list or NumPy or ctypes

    def scatter_atoms(self, name, type, count, data):
        """Hand *data* for property *name* to ``lammps_scatter_atoms``.

        *type*, *count* and *data* are forwarded unchanged; only *name* is
        encoded to bytes when it is non-empty.
        """
        key = name.encode() if name else name
        scatter = self.lib.lammps_scatter_atoms
        scatter(self.lmp, key, type, count, data)

    def scatter_atoms_subset(self, name, type, count, ndata, ids, data):
        """Hand *data* for the *ndata* atoms in *ids* to the library.

        Everything except *name* (encoded to bytes when non-empty) is
        forwarded unchanged to ``lammps_scatter_atoms_subset``.
        """
        key = name.encode() if name else name
        args = (self.lmp, key, type, count, ndata, ids, data)
        self.lib.lammps_scatter_atoms_subset(*args)

    # create N atoms on all procs
    # N = global number of atoms
    # id = ID of each atom (optional, can be None)
    # type = type of each atom (1 to Ntypes) (required)
    # x = coords of each atom as (N,3) array (required)
    # v = velocity of each atom as (N,3) array (optional, can be None)
    # NOTE: how can we ensure we are passing the correct type to LAMMPS,
    #   e.g. for Python list or NumPy, etc.?
    #   The same question applies to gather_atoms() above.

    def create_atoms(self, n, id, type, x, v, image=None, shrinkexceed=False):
        """Create *n* atoms on all procs via ``lammps_create_atoms``.

        *id*, *type* and *image* are copied into ``c_int`` arrays of length
        *n*. *id* and *image* may be ``None``, in which case they are passed
        as ``None`` (NULL). *x*, *v* and *shrinkexceed* are forwarded
        unchanged, so *x* and *v* must already be in a form ctypes accepts
        (TODO: document the expected layout).
        """
        def to_c_ints(values):
            # None stays None (NULL); anything else is copied into c_int[n].
            # `is None` rather than truthiness so array-likes such as NumPy
            # arrays (whose truth value is ambiguous) are accepted.
            if values is None:
                return None
            buf = (c_int * n)()
            buf[:] = values
            return buf

        id_lmp = to_c_ints(id)
        image_lmp = to_c_ints(image)
        type_lmp = (c_int * n)()
        type_lmp[:] = type
        self.lib.lammps_create_atoms(self.lmp, n, id_lmp, type_lmp, x, v,
                                     image_lmp, shrinkexceed)

    def has_exceptions(self):
        """Report whether liblammps was compiled with C++ exception handling.

        True when ``lammps_config_has_exceptions()`` returns a non-zero flag.
        """
        flag = self.lib.lammps_config_has_exceptions()
        return bool(flag)

    def has_gzip_support(self):
        """Return True if liblammps reports gzip support (non-zero flag)."""
        supported = self.lib.lammps_config_has_gzip_support()
        return bool(supported)

    def has_png_support(self):
        """Return True if liblammps reports PNG support (non-zero flag)."""
        supported = self.lib.lammps_config_has_png_support()
        return bool(supported)

    def has_jpeg_support(self):
        """Return True if liblammps reports JPEG support (non-zero flag)."""
        supported = self.lib.lammps_config_has_jpeg_support()
        return bool(supported)

    def has_ffmpeg_support(self):
        """Return True if liblammps reports FFmpeg support (non-zero flag)."""
        supported = self.lib.lammps_config_has_ffmpeg_support()
        return bool(supported)

    def installed_packages(self):
        """Return the names of the packages compiled into liblammps.

        The list is built on the first call and cached in
        ``self._installed_packages``; later calls return that same list.
        Each name is read into a 100-byte buffer and decoded to ``str``.
        """
        if self._installed_packages is None:
            self._installed_packages = []
            npackages = self.lib.lammps_config_package_count()
            bufsize = 100
            sb = create_string_buffer(bufsize)
            for idx in range(npackages):
                self.lib.lammps_config_package_name(idx, sb, bufsize)
                # the buffer is reused, so copy the name out on each pass
                self._installed_packages.append(sb.value.decode())
        return self._installed_packages