class NetCDFContainer(object):
    """Attribute/dict-style wrapper around a NetCDF file of 2D fields.

    Frames are accessed as ``NetCDFContainerFrame`` objects through indexing,
    iteration or :meth:`get_next_frame`. File-level variables and global
    attributes can be read and written as attributes (or string keys) of the
    container.
    """

    def __init__(self, fn, frame=0, double=False, store_force=False,
                 mode='r', format='NETCDF4'):
        """Open the NetCDF file *fn*.

        Parameters
        ----------
        fn : str
            Path of the NetCDF file.
        frame : int, optional
            Initial cursor for :meth:`get_next_frame`. Negative values count
            back from the number of frames in the file.
        double : bool, optional
            Record 'f8' (instead of 'f4') as the float type in
            ``self._float_str``.
        store_force : bool, optional
            Kept in ``self._store_force``; not read by this class itself.
        mode : str, optional
            File mode. Modes starting with 'w' leave the dimensions undefined
            until :meth:`set_shape` is called; other modes infer the field
            shape from the dimensions already present in the file.
        format : str, optional
            Passed to the netCDF4 ``Dataset``; ignored by the fallback
            backend.

        Raises
        ------
        RuntimeError
            If the file cannot be opened, or an existing file uses an
            unrecognised dimension convention.
        """
        self._fn = fn
        self._data = None
        try:
            if __have_netcdf4__:
                self._data = Dataset(fn, mode, format=format)
            else:
                # The fallback backend is not given the 'ws' mode.
                if mode == 'ws':
                    mode = 'w'
                self._data = NetCDFFile(fn, mode)
        except RuntimeError as e:
            raise RuntimeError('Error opening file "{0}": {1}'.format(fn, e))

        self._float_str = 'f8' if double else 'f4'
        self._store_force = store_force
        self._is_amber = False

        if mode[0] == 'w':
            self._data.program = 'PyCo'
            self._data.programVersion = 'N/A'
            # Dimensions are created lazily by the first set_shape() call.
            self._is_defined = False
        else:
            self._read_file_structure()
            self._is_defined = True

        if frame < 0:
            self._cur_frame = len(self) + frame
        else:
            self._cur_frame = frame

    @staticmethod
    def _dimension_length(dim):
        """Return the length of a NetCDF dimension entry.

        Depending on the backend, ``dimensions[...]`` yields either an object
        supporting ``len()`` or a plain integer; accept both.
        """
        try:
            return len(dim)
        except TypeError:
            return dim

    def _read_file_structure(self):
        """Infer the field shape of an existing file from its dimensions."""
        dims = self._data.dimensions
        if 'nx' in dims and 'ny' in dims:
            self._shape = (self._dimension_length(dims['nx']),
                           self._dimension_length(dims['ny']))
            self._shape2 = self._shape
        elif 'atom' in dims:
            # Files with an 'atom' dimension are flagged as AMBER files; the
            # atoms must form a square nx * nx grid.
            n = self._dimension_length(dims['atom'])
            nx = int(round(sqrt(n)))
            if nx * nx != n:
                raise RuntimeError(
                    'Number of atoms ({0}) in file {1} does not form a '
                    'square grid.'.format(n, self._fn))
            self._shape = (nx, nx)
            self._shape2 = (nx, nx)
            self._is_amber = True
        else:
            raise RuntimeError('Unknown NetCDF convention used for file '
                               '%s.' % self._fn)

    def __del__(self):
        # __init__ may have failed (or been bypassed) before _data was set.
        if getattr(self, '_data', None) is not None:
            self.close()

    def _define_file_structure(self, shape):
        """Create the 'frame', optional 'ndof', 'nx' and 'ny' dimensions.

        *shape* is either ``(nx, ny)`` or ``(ndof, nx, ny)``. Dimensions that
        already exist in the file are left untouched.
        """
        self._shape = shape
        if len(shape) == 3:
            ndof, nx, ny = shape
        else:
            ndof = 1
            nx, ny = shape
        self._shape2 = (nx, ny)
        if 'frame' not in self._data.dimensions:
            # Size None: unlimited dimension, so frames can be appended.
            self._data.createDimension('frame', None)
        if ndof > 1 and 'ndof' not in self._data.dimensions:
            self._data.createDimension('ndof', ndof)
        if 'nx' not in self._data.dimensions:
            self._data.createDimension('nx', nx)
        if 'ny' not in self._data.dimensions:
            self._data.createDimension('ny', ny)
        self._data.sync()
        self._is_defined = True

    def set_shape(self, x, ndof=None):
        """Define the file's field shape, or check *x* against it.

        Parameters
        ----------
        x : array-like or sequence of int
            An array (its ``.shape`` is used) or the 2D shape itself.
        ndof : int, optional
            Number of degrees of freedom per grid point; values > 1 prepend
            an 'ndof' axis to the shape.

        Raises
        ------
        RuntimeError
            If the structure is already defined with a different shape.
        """
        try:
            shape = x.shape
        except AttributeError:
            shape = x
        scalar = ndof is None or ndof == 1

        if not self._is_defined:
            if scalar:
                self._define_file_structure(shape)
            else:
                self._define_file_structure([ndof] + list(shape))
            return

        # Tuple comparison handles differing ranks cleanly (elementwise numpy
        # comparison of different-length arrays warns or raises).
        if scalar:
            if tuple(shape) != tuple(self._shape):
                raise RuntimeError(
                    'Shape mismatch: NetCDF file has shape '
                    '{0} x {1}, but someone is trying to '
                    'override this with shape {2} x {3}.'.format(
                        self._shape[0], self._shape[1], shape[0], shape[1]))
        else:
            requested = tuple([ndof] + list(shape))
            if requested != tuple(self._shape):
                raise RuntimeError(
                    'Shape mismatch: NetCDF file has shape {0}, but someone '
                    'is trying to override this with shape {1}.'.format(
                        tuple(self._shape), requested))

    def __len__(self):
        """Return the number of frames (length of the 'frame' dimension)."""
        return self._dimension_length(self._data.dimensions['frame'])

    def close(self):
        """Close the underlying file; further calls are no-ops."""
        if self._data is not None:
            self._data.close()
        self._is_defined = False
        self._data = None

    def has_h(self):
        """Return True if the file contains the rigid surface variable 'h'."""
        return 'h' in self._data.variables

    def set_rigid_surface(self, h, ndof=None):
        """Store the rigid surface heights *h* in the file variable 'h'.

        The field shape is first defined from (or checked against) *h*. If
        the variable does not exist yet and *h* differs from the field grid,
        it is created on separate 'rigid_nx'/'rigid_ny' dimensions.
        """
        self.set_shape(h, ndof=ndof)
        nx, ny = self._shape2
        hnx, hny = h.shape
        if 'h' not in self._data.variables:
            if hnx != nx or hny != ny:
                if 'rigid_nx' not in self._data.dimensions:
                    self._data.createDimension('rigid_nx', hnx)
                if 'rigid_ny' not in self._data.dimensions:
                    self._data.createDimension('rigid_ny', hny)
                self._data.createVariable('h', 'f8', ('rigid_nx', 'rigid_ny'))
            else:
                self._data.createVariable('h', 'f8', ('nx', 'ny'))
        self._data.variables['h'][:, :] = h

    # Backward compatibility
    set_h = set_rigid_surface

    def set_elastic_surface(self, h):
        """Store *h* in the variable 'elastic_surface' on the (nx, ny) grid.

        In a freshly created file the grid dimensions are defined from *h*
        first, since the variable cannot be created without them.
        """
        if not self._is_defined:
            self._define_file_structure(np.shape(h))
        if 'elastic_surface' not in self._data.variables:
            self._data.createVariable('elastic_surface', 'f8', ('nx', 'ny'))
        self._data.variables['elastic_surface'][:, :] = h

    def get_rigid_surface(self):
        """Return the rigid surface heights read from variable 'h'."""
        return self._data.variables['h'][:, :]

    # Backward compatibility
    get_h = get_rigid_surface

    def get_elastic_surface(self):
        """Return the 'elastic_surface' variable object itself.

        Unlike :meth:`get_rigid_surface`, this returns the variable unsliced.
        """
        return self._data.variables['elastic_surface']

    def get_filename(self):
        """Return the file name given to the constructor."""
        return self._fn

    def get_next_frame(self):
        """Return the frame at the cursor and advance the cursor by one."""
        frame = NetCDFContainerFrame(self, self._cur_frame)
        self._cur_frame += 1
        return frame

    def set_cursor(self, cur_frame):
        """Set the index of the frame returned by the next get_next_frame."""
        self._cur_frame = cur_frame

    def get_cursor(self):
        """Return the index of the frame returned by the next get_next_frame."""
        return self._cur_frame

    def __getattr__(self, name):
        """Resolve attributes not found by normal lookup.

        Names starting with '_' are private state: if missing, raise
        AttributeError as the attribute protocol requires (so ``hasattr``,
        ``getattr(..., default)`` and ``copy`` behave). Other names resolve to
        a file variable (read in full) or else to an attribute of the
        underlying dataset.
        """
        if name.startswith('_'):
            try:
                return self.__dict__[name]
            except KeyError:
                raise AttributeError(
                    "'{0}' object has no attribute '{1}'".format(
                        type(self).__name__, name))
        if name in self._data.variables:
            return self._data.variables[name][...]
        # getattr() also works for backends that define no __getattr__.
        return getattr(self._data, name)

    def __setattr__(self, name, value):
        """Store private state, array-valued variables or dataset attributes.

        Non-scalar ndarrays are written to a file variable, which is created
        on the (nx, ny) grid if its shape matches the field shape. Any other
        value is set as an attribute of the underlying dataset.
        """
        if name.startswith('_'):
            return object.__setattr__(self, name, value)
        if isinstance(value, np.ndarray) and value.shape != ():
            if name not in self._data.variables:
                if tuple(value.shape) == tuple(self._shape):
                    self._data.createVariable(name, 'f8', ('nx', 'ny'))
                else:
                    raise RuntimeError('Not sure how to guess NetCDF type for '
                                       'field "{0}" which is a numpy ndarray '
                                       'with type {1} and shape {2}.'.format(
                                           name, value.dtype, value.shape))
            self._data.variables[name][:, :] = value
            return
        return setattr(self._data, name, value)

    def __setitem__(self, i, value):
        """Set a variable/attribute by string key; frames cannot be set."""
        if isinstance(i, str):
            return self.__setattr__(i, value)
        raise RuntimeError('Cannot set full frame.')

    def __getitem__(self, i):
        """Return a variable (str key), a list of frames (slice) or a frame."""
        if isinstance(i, str):
            return self.__getattr__(i)
        if isinstance(i, slice):
            return [NetCDFContainerFrame(self, j)
                    for j in range(*i.indices(len(self)))]
        return NetCDFContainerFrame(self, i)

    def __iter__(self):
        """Yield every frame in order."""
        for i in range(len(self)):
            yield NetCDFContainerFrame(self, i)

    def get_size(self):
        """Return the field shape of the file."""
        return self._shape

    def sync(self):
        """Flush pending writes of the underlying dataset."""
        self._data.sync()

    def __str__(self):
        # str() so that non-str file names (e.g. path objects) are accepted.
        return str(self._fn)