def _get_fileset(self):
    first_fn = self._get_files()[0]
    first_file = fileDM(first_fn, on_memory=True)

    if first_file.numObjects == 1:
        idx = 0
    else:
        idx = 1

    try:
        raw_dtype = first_file._DM2NPDataType(first_file.dataType[idx])
        shape = (first_file.ySize[idx], first_file.xSize[idx])
    except IndexError as e:
        raise DataSetException("could not determine dtype or signal shape") from e

    start_idx = 0
    files = []
    for fn in self._get_files():
        z_size = self._z_sizes[fn]
        f = StackedDMFile(
            path=fn,
            start_idx=start_idx,
            end_idx=start_idx + z_size,
            sig_shape=shape,
            native_dtype=raw_dtype,
            file_header=self._offsets[fn],
        )
        files.append(f)
        # advance by the number of frames in this file, so consecutive
        # files don't overlap  # FIXME: .nav.size?
        start_idx += z_size
    return DMFileSet(files)

def get_metadata_from_dmFile(fp):
    """ Accepts a filepath to a dm file and returns a Metadata instance """
    metadata = Metadata()

    with dm.fileDM(fp, on_memory=False) as dmFile:
        pixelSizes = dmFile.scale
        pixelUnits = dmFile.scaleUnit
        assert pixelSizes[0] == pixelSizes[1], "Rx and Ry pixel sizes don't match"
        assert pixelSizes[2] == pixelSizes[3], "Qx and Qy pixel sizes don't match"
        assert pixelUnits[0] == pixelUnits[1], "Rx and Ry pixel units don't match"
        assert pixelUnits[2] == pixelUnits[3], "Qx and Qy pixel units don't match"
        for i in range(len(pixelUnits)):
            if pixelUnits[i] == "":
                pixelUnits[i] = "pixels"
        metadata.set_R_pixel_size__microscope(pixelSizes[0])
        metadata.set_R_pixel_size_units__microscope(pixelUnits[0])
        metadata.set_Q_pixel_size__microscope(pixelSizes[2])
        metadata.set_Q_pixel_size_units__microscope(pixelUnits[2])

    return metadata

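# Usage sketch (hedged): the path below is a placeholder for any calibrated
# .dm3/.dm4 4D-STEM file. get_metadata_from_dmFile only reads tags, so it is
# cheap to call before deciding how to load the data itself; empty calibration
# units are mapped to "pixels".
md = get_metadata_from_dmFile("example_4dstem_scan.dm4")
print(md)
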
def _get_nav_shape(path):
    with dm.fileDM(_get_gtg_path(path), on_memory=True) as dm_file:
        nav_shape_y = dm_file.allTags.get('.SI Dimensions.Size Y')
        nav_shape_x = dm_file.allTags.get('.SI Dimensions.Size X')
        if nav_shape_y is not None and nav_shape_x is not None:
            return (int(nav_shape_y), int(nav_shape_x))
        return None

def _scansize_without_flyback(self):
    with dm.fileDM(_get_gtg_path(self._path), on_memory=True) as dm_file:
        ss = (
            dm_file.allTags['.SI Image Tags.SI.Acquisition.Spatial Sampling.Height (pixels)'],
            dm_file.allTags['.SI Image Tags.SI.Acquisition.Spatial Sampling.Width (pixels)'],
        )
        return tuple(int(i) for i in ss)

def read_dm_mmap(filename):
    """
    Read a .dm3/.dm4 file, using dm.py to read the data into a memory-mapped
    np.memmap object, which is stored in the output DataCube.data.
    Read the metadata with hyperspy.
    """
    assert filename.endswith('.dm3') or filename.endswith('.dm4'), \
        'File must be a .dm3 or .dm4'

    # Get metadata
    metadata = Metadata(init='hs', filepath=filename)

    # Load .dm3/.dm4 files with dm.py
    with fileDM(filename, verbose=False) as dmfile:
        data = dmfile.getMemmap(0)

    # Get datacube
    datacube = DataCube(data=data)

    # Link metadata and data
    datacube.metadata = metadata

    # Set scan shape, if in metadata
    try:
        R_Nx = int(metadata.get_metadata_item('scan_size_Nx'))
        R_Ny = int(metadata.get_metadata_item('scan_size_Ny'))
        datacube.set_scan_shape(R_Nx, R_Ny)
    except ValueError:
        print("Warning: scan shape not detected in metadata; please check / set manually.")

    return datacube

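# Usage sketch (hedged): the filename is a placeholder. Because the returned
# DataCube wraps a np.memmap, only the slices that are actually touched get
# read from disk.
dc = read_dm_mmap("example_4dstem_scan.dm4")
print(dc.data.shape, dc.data.dtype)
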
def _get_fileset(self):
    first_fn = self._get_files()[0]
    first_file = fileDM(first_fn, on_memory=True)

    if first_file.numObjects == 1:
        idx = 0
    else:
        idx = 1

    try:
        raw_dtype = first_file._DM2NPDataType(first_file.dataType[idx])
        shape = (first_file.ySize[idx], first_file.xSize[idx])
    except IndexError as e:
        raise DataSetException("could not determine dtype or signal shape") from e

    start_idx = 0
    files = []
    for fn in self._get_files():
        f = StackedDMFile(
            path=fn,
            start_idx=start_idx,
            offset=self._offsets[fn],
            shape=shape,
            dtype=raw_dtype,
        )
        files.append(f)
        start_idx += f.num_frames
    return DMFileSet(files)

def check_valid(self):
    first_fn = self._get_files()[0]
    try:
        with fileDM(first_fn, on_memory=True):
            pass
        return True
    except (IOError, OSError) as e:
        raise DataSetException("invalid dataset: %s" % e)

def from_dm4_file(filename):
    m = Metadata4D()
    with fileDM(filename) as f:
        m.E_ev = f.allTags['.ImageList.2.ImageTags.Microscope Info.Voltage']
        m.scan_step = np.array(f.scale[-2:]) * 10
        m.wavelength = wavelength(m.E_ev)
    return m

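# Usage sketch (hedged): the filename is a placeholder. from_dm4_file pulls the
# accelerating voltage and scan step from the DM tags and derives the electron
# wavelength from the voltage, so the returned Metadata4D carries the basic
# acquisition geometry.
meta = from_dm4_file("example_4dstem_scan.dm4")
print(meta.E_ev, meta.wavelength, meta.scan_step)
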
def _get_offset(path):
    fh = fileDM(path, on_memory=True)
    if fh.numObjects == 1:
        idx = 0
    else:
        idx = 1
    offset = fh.dataOffset[idx]
    return offset

def _nav_shape_without_flyback(path):
    with dm.fileDM(_get_gtg_path(path), on_memory=True) as dm_file:
        nav_shape_y = dm_file.allTags.get(
            '.SI Image Tags.SI.Acquisition.Spatial Sampling.Height (pixels)')
        nav_shape_x = dm_file.allTags.get(
            '.SI Image Tags.SI.Acquisition.Spatial Sampling.Width (pixels)')
        if nav_shape_y is not None and nav_shape_x is not None:
            return (int(nav_shape_y), int(nav_shape_x))
        return None

def _get_metadata(path):
    fh = fileDM(path, on_memory=True)
    if fh.numObjects == 1:
        idx = 0
    else:
        idx = 1
    return {
        'offset': fh.dataOffset[idx],
        'zsize': fh.zSize[idx],
    }

def check_valid(self):
    first_fn = self._get_files()[0]
    try:
        with fileDM(first_fn, on_memory=True):
            pass
        if (self._scan_size is not None
                and np.prod(self._scan_size) != len(self._get_files())):
            raise DataSetException("incompatible scan_size")
        return True
    except (IOError, OSError) as e:
        raise DataSetException("invalid dataset: %s" % e)

def dm_to_png(source_file, dest_file, fixed_dimensions=None):
    """
    Saves the DM3 or DM4 source_file as PNG dest_file.

    If the data has three or four dimensions, the image taken is the middle
    image along those dimensions.
    """
    f = fileDM(source_file, on_memory=True)
    f.parseHeader()
    ds = f.getDataset(0)
    img = ds['data']
    img = extract_dimension(img, fixed_dimensions)
    imsave(dest_file, img, format="png", cmap=cm.gray)
    return f

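# Usage sketch (hedged): both paths are placeholders. For a 3D or 4D dataset
# the middle slice is rendered by default; fixed_dimensions (handled by
# extract_dimension) can presumably pin specific indices instead.
dm_to_png("example_stack.dm4", "example_stack_preview.png")
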
def _get_sig_shape_and_native_dtype(self):
    first_fn = self._get_files()[0]
    first_file = fileDM(first_fn, on_memory=True)
    if first_file.numObjects == 1:
        idx = 0
    else:
        idx = 1
    try:
        raw_dtype = first_file._DM2NPDataType(first_file.dataType[idx])
        native_sig_shape = (first_file.ySize[idx], first_file.xSize[idx])
    except IndexError as e:
        raise DataSetException("could not determine dtype or signal shape") from e
    return native_sig_shape, raw_dtype

def parse_dm4(fname):
    """
    Parse dm4 data into a 4D ndarray.

    Parameters
    ----------
    fname: str
        Path to .dm4 file.

    Returns
    -------
    dm: dict
        Dictionary of data, metadata etc. The data is a (j, i, x, y) ndarray of
        the 4D data of dataset 0 in the file, where (j, i) are probe positions
        and (x, y) are the coordinates of the image data. Data accessed as
        dm['data'].
    """
    # Read dm4 data and reshape into a 4D stack
    dm1 = dm.fileDM(fname)
    dm1.parseHeader()
    im1 = dm1.getDataset(0)
    # print(dm1.allTags)

    scanI = int(dm1.allTags[".ImageList.2.ImageTags.Series.nimagesx"])  # number of images x
    scanJ = int(dm1.allTags[".ImageList.2.ImageTags.Series.nimagesy"])  # number of images y
    numkI = im1["data"].shape[2]  # number of pixels in x
    numkJ = im1["data"].shape[1]  # number of pixels in y

    im1["data"] = im1["data"].reshape([scanJ, scanI, numkJ, numkI])
    im1["metadata"] = dm1.allTags
    return im1

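# Usage sketch (hedged): the path is a placeholder; the file must carry the
# '.ImageList.2.ImageTags.Series.nimagesx/y' tags that record the scan grid.
scan = parse_dm4("example_series_scan.dm4")
print(scan["data"].shape)      # (scanJ, scanI, numkJ, numkI)
print(len(scan["metadata"]))   # number of DM tags copied over
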
def tempGetData(self, *args, **kwargs):
    # Try to load the data
    dPath = Path(r'C:/Users/Peter.000/Data/Te NP 4D-STEM')
    fPath = Path('07_45x8 ss=5nm_spot11_CL=100 0p1s_alpha=4p63mrad_bin=4_300kV.dm4')

    # Get the filename from the header.
    msg.logMessage('NCEM: File path = {}'.format(
        self.header.startdoc.get('sample_name', '????')))
    # This only prints the file name, not the full path.

    with dm.fileDM((dPath / fPath).as_posix()) as dm1:
        try:
            scanI = int(dm1.allTags['.ImageList.2.ImageTags.Series.nimagesx'])
            scanJ = int(dm1.allTags['.ImageList.2.ImageTags.Series.nimagesy'])
            im1 = dm1.getDataset(0)
            numkI = im1['data'].shape[2]
            numkJ = im1['data'].shape[1]
            data = im1['data'].reshape([scanJ, scanI, numkJ, numkI])
        except:
            print('Data is not a 4D DM3 or DM4 stack.')
            raise
    return data

def load_data(cls, path_to_dmfile, load_additional_data=False):
    """
    INPUT:
        path_to_dmfile: str, path to spectral image file (.dm3 or .dm4 extension)
    OUTPUT:
        image -- Spectral_image, object of Spectral_image class containing the data of the dm-file
    """
    dmfile_tot = dm.fileDM(path_to_dmfile)
    additional_data = []
    for i in range(dmfile_tot.numObjects - dmfile_tot.thumbnail * 1):
        dmfile = dmfile_tot.getDataset(i)
        if dmfile['data'].ndim == 3:
            dmfile = dmfile_tot.getDataset(i)
            data = np.swapaxes(np.swapaxes(dmfile['data'], 0, 1), 1, 2)
            if not load_additional_data:
                break
        elif load_additional_data:
            additional_data.append(dmfile_tot.getDataset(i))
        if i == dmfile_tot.numObjects - dmfile_tot.thumbnail * 1 - 1:
            if (len(additional_data) == i + 1) or not load_additional_data:
                print("No spectral image detected")
                dmfile = dmfile_tot.getDataset(0)
                data = dmfile['data']

    ddeltaE = dmfile['pixelSize'][0]
    pixelsize = np.array(dmfile['pixelSize'][1:])
    energyUnit = dmfile['pixelUnit'][0]
    ddeltaE *= cls.get_prefix(energyUnit, 'eV')
    pixelUnit = dmfile['pixelUnit'][1]
    pixelsize *= cls.get_prefix(pixelUnit, 'm')

    image = cls(data, ddeltaE, pixelsize=pixelsize, name=path_to_dmfile[:-4])
    if load_additional_data:
        image.additional_data = additional_data
    return image

def read_dm(fp, mem="RAM", binfactor=1, metadata=False, **kwargs):
    """
    Read a digital micrograph 4D-STEM file.

    Args:
        fp: str or Path
            Path to the file
        mem (str, optional): Specifies how the data should be stored; must be
            "RAM" or "MEMMAP". See docstring for py4DSTEM.file.io.read. Default
            is "RAM".
        binfactor (int, optional): Bin the data, in diffraction space, as it's
            loaded. See docstring for py4DSTEM.file.io.read. Default is 1.
        metadata (bool, optional): if True, returns the file metadata as a
            Metadata instance.

    Returns:
        (variable): The return value depends on usage:
            * if metadata==False, returns the 4D-STEM dataset as a DataCube
            * if metadata==True, returns the metadata as a Metadata instance
            Note that metadata is read either way - in the latter case ONLY
            metadata is read and returned; in the former case a DataCube is
            returned with the metadata attached at datacube.metadata
    """
    assert isinstance(fp, (str, Path)), "Error: filepath fp must be a string or pathlib.Path"
    assert mem in ["RAM", "MEMMAP"], 'Error: argument mem must be either "RAM" or "MEMMAP"'
    assert isinstance(binfactor, int), "Error: argument binfactor must be an integer"
    assert binfactor >= 1, "Error: binfactor must be >= 1"

    md = get_metadata_from_dmFile(fp)
    if metadata:
        return md

    if (mem, binfactor) == ("RAM", 1):
        with dm.fileDM(fp, on_memory=True) as dmFile:
            # loop through the datasets until a >2D one is found:
            i = 0
            valid_data = False
            while not valid_data:
                data = dmFile.getMemmap(i)
                if len(np.squeeze(data).shape) > 2:
                    valid_data = True
                    dataSet = dmFile.getDataset(i)
                i += 1
            dc = DataCube(data=dataSet["data"])
    elif (mem, binfactor) == ("MEMMAP", 1):
        with dm.fileDM(fp, on_memory=False) as dmFile:
            # loop through the datasets until a >2D one is found:
            i = 0
            valid_data = False
            while not valid_data:
                memmap = dmFile.getMemmap(i)
                if len(np.squeeze(memmap).shape) > 2:
                    valid_data = True
                i += 1
            dc = DataCube(data=memmap)
    elif mem == "RAM":
        with dm.fileDM(fp, on_memory=True) as dmFile:
            # loop through the datasets until a >2D one is found:
            i = 0
            valid_data = False
            while not valid_data:
                memmap = dmFile.getMemmap(i)
                if len(np.squeeze(memmap).shape) > 2:
                    valid_data = True
                i += 1
            if "dtype" in kwargs.keys():
                dtype = kwargs["dtype"]
            else:
                dtype = memmap.dtype
            R_Nx, R_Ny, Q_Nx, Q_Ny = memmap.shape
            Q_Nx, Q_Ny = Q_Nx // binfactor, Q_Ny // binfactor
            data = np.empty((R_Nx, R_Ny, Q_Nx, Q_Ny), dtype=dtype)
            for Rx in range(R_Nx):
                for Ry in range(R_Ny):
                    data[Rx, Ry, :, :] = bin2D(memmap[Rx, Ry, :, :], binfactor, dtype=dtype)
            dc = DataCube(data=data)
    else:
        raise Exception(
            "Memory mapping and on-load binning together is not supported. "
            "Either set binfactor=1 or mem='RAM'."
        )

    dc.metadata = md
    return dc

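# Usage sketches (hedged): the path is a placeholder; the calls correspond to
# the (mem, binfactor) branches above.
md = read_dm("example_4dstem_scan.dm4", metadata=True)              # tags only
dc = read_dm("example_4dstem_scan.dm4")                             # full read into RAM
dc_mm = read_dm("example_4dstem_scan.dm4", mem="MEMMAP")            # lazy, memory-mapped
dc_b2 = read_dm("example_4dstem_scan.dm4", binfactor=2)             # bin diffraction space 2x while loading
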
    crossing = (die_fun_avg[:-delta] < 0) * (die_fun_avg[delta:] >= 0)
    deltaE = deltaE[deltaE > 0]
    deltaE = deltaE[50:-50]
    crossing_E = deltaE[crossing]
    n = len(crossing_E)
    return crossing_E, n

#%%
#data = np.load("area03-eels-SI-aligned.npy")
#energies = np.load("area03-eels-SI-aligned_energy.npy")
#dielectric_function_im_avg, dielectric_function_im_std = im_dielectric_function(data, energies)

#%%
crossings_E, crossings_n = crossings_im(dielectric_function_im_avg, energies)

#%%
plt.figure()
#plt.imshow(crossings_n, cmap='hot', interpolation='nearest')
#plt.
ax = sns.heatmap(crossings_n)
plt.show()

#%%
dmfile = dm.fileDM('/path/to/area03-eels-SI-aligned.dm4')
data2 = dmfile.getDataset(0)

def acquireProjAngle(self, file):
    """Acquires angles from metadata of .dm4 files"""
    with dm.fileDM(file) as inDM:
        alphaTag = ".ImageList.2.ImageTags.Microscope Info.Stage Position.Stage Alpha"
        return inDM.allTags[alphaTag]

def __init__(self, filepath, hidden_stripe_noise_reduction=True):
    from ncempy.io import dm
    import os, glob

    # first parse the input and get the path to the *.gtg
    if not os.path.isdir(filepath):
        filepath = os.path.dirname(filepath)
    os.chdir(filepath)

    assert len(glob.glob('*.bin')) == 8, "Wrong path, or wrong number of bin files."
    assert len(glob.glob('*.gtg')) == 1, "Wrong path, or wrong number of gtg files."

    gtgpath = os.path.join(filepath, glob.glob('*.gtg')[0])
    binprefix = gtgpath[:-4]

    self._gtg_file = gtgpath
    self._bin_prefix = binprefix

    # open the *.gtg and read the metadata
    gtg = dm.fileDM(gtgpath)
    gtg.parseHeader()

    # get the important metadata
    try:
        R_Ny = gtg.allTags['.SI Dimensions.Size Y']
        R_Nx = gtg.allTags['.SI Dimensions.Size X']
    except ValueError:
        print('Warning: scan shape not detected. Please check/set manually.')
        R_Nx = self._guess_number_frames()
        R_Ny = 1

    try:
        # this may be wrong for binned data... in which case the reader doesn't work anyway!
        Q_Nx = gtg.allTags['.SI Image Tags.Acquisition.Parameters.Detector.height']
        Q_Ny = gtg.allTags['.SI Image Tags.Acquisition.Parameters.Detector.width']
    except:
        print('Warning: diffraction pattern shape not detected!')
        print('Assuming 1920x1792 as the diffraction pattern size!')
        Q_Nx = 1792
        Q_Ny = 1920

    self.shape = (R_Nx, R_Ny, Q_Nx, Q_Ny)
    self._hidden_stripe_noise_reduction = hidden_stripe_noise_reduction

    self._attach_to_files()

    self._stripe_dtype = np.dtype([
        ('sync', '>u4', 1),
        ('pad1', np.void, 5),
        ('shutter', '>u1', 1),
        ('pad2', np.void, 6),
        ('block', '>u4', 1),
        ('pad4', np.void, 4),
        ('frame', '>u4', 1),
        ('coords', '>u2', 4),
        ('pad3', np.void, 4),
        ('data', '>u1', 22320),
    ])

    self._shutter_offsets = np.zeros((8,), dtype=np.uint32)
    self._find_offsets()
    print('Shutter flags are:', self._shutter_offsets)

    self._gtg_meta = gtg.allTags

    super().__init__()

def __init__(self, filepath, sync_block_IDs=True, hidden_stripe_noise_reduction=True):
    from ncempy.io import dm
    import os
    import glob

    # first parse the input and get the path to the *.gtg
    if not os.path.isdir(filepath):
        filepath = os.path.dirname(filepath)

    assert len(glob.glob(os.path.join(filepath, "*.bin"))) == 8, \
        "Wrong path, or wrong number of bin files."
    assert len(glob.glob(os.path.join(filepath, "*.gtg"))) == 1, \
        "Wrong path, or wrong number of gtg files."

    gtgpath = os.path.join(filepath, glob.glob(os.path.join(filepath, "*.gtg"))[0])
    binprefix = gtgpath[:-4]

    self._gtg_file = gtgpath
    self._bin_prefix = binprefix

    # open the *.gtg and read the metadata
    gtg = dm.fileDM(gtgpath)
    gtg.parseHeader()

    # get the important metadata
    try:
        R_Ny = gtg.allTags[".SI Dimensions.Size Y"]
        R_Nx = gtg.allTags[".SI Dimensions.Size X"]
    except ValueError:
        print("Warning: scan shape not detected. Please check/set manually.")
        R_Nx = self._guess_number_frames() // 32
        R_Ny = 1

    try:
        # this may be wrong for binned data... in which case the reader doesn't work anyway!
        Q_Nx = gtg.allTags[".SI Image Tags.Acquisition.Parameters.Detector.height"]
        Q_Ny = gtg.allTags[".SI Image Tags.Acquisition.Parameters.Detector.width"]
    except:
        print("Warning: diffraction pattern shape not detected!")
        print("Assuming 1920x1792 as the diffraction pattern size!")
        Q_Nx = 1792
        Q_Ny = 1920

    self.shape = (int(R_Nx), int(R_Ny), int(Q_Nx), int(Q_Ny))
    self._hidden_stripe_noise_reduction = hidden_stripe_noise_reduction
    self.sync_block_IDs = sync_block_IDs

    self._stripe_dtype = np.dtype([
        ("sync", ">u4"),
        ("pad1", np.void, 5),
        ("shutter", ">u1"),
        ("pad2", np.void, 6),
        ("block", ">u4"),
        ("pad4", np.void, 4),
        ("frame", ">u4"),
        ("coords", ">u2", (4,)),
        ("pad3", np.void, 4),
        ("data", ">u1", (22320,)),
    ])

    self._attach_to_files()

    self._shutter_offsets = np.zeros((8,), dtype=np.uint32)
    self._find_offsets()
    print("Shutter flags are:", self._shutter_offsets)

    self._gtg_meta = gtg.allTags

    self._user_noise_reduction = False

    self._temp = np.zeros((32,), dtype=self._stripe_dtype)
    self._Qx, self._Qy = self._parse_slices((slice(None), slice(None)), "diffraction")

    # needed for Dask support:
    self.ndims = 4
    self.dtype = np.int16

    super().__init__()

def ingest_NCEM_DM(paths):
    assert len(paths) == 1
    path = paths[0]

    # Compose run start
    run_bundle = event_model.compose_run()  # type: event_model.ComposeRunBundle
    start_doc = _metadata(path)
    start_doc.update(run_bundle.start_doc)
    start_doc["sample_name"] = Path(paths[0]).resolve().stem
    yield 'start', start_doc

    dm_handle = dm.fileDM(path, on_memory=True)
    num_t = _num_t(dm_handle)
    num_z = _num_z(dm_handle)
    first_frame = dm_handle.getSlice(0, 0, sliceZ2=0)['data']  # Most DM files have only 1 dataset
    shape = first_frame.shape
    dtype = first_frame.dtype

    delayed_get_slice = dask.delayed(get_slice)
    if num_z > 1:
        dask_data = da.stack([[
            da.from_delayed(delayed_get_slice(dm_handle, t, z), shape=shape, dtype=dtype)
            for z in range(num_z)
        ] for t in range(num_t)])
    else:
        dask_data = da.stack([
            da.from_delayed(delayed_get_slice(dm_handle, t, 0), shape=shape, dtype=dtype)
            for t in range(num_t)
        ])

    # Compose descriptor
    source = 'NCEM'
    frame_data_keys = {'raw': {'source': source, 'dtype': 'number', 'shape': dask_data.shape}}
    frame_stream_name = 'primary'
    frame_stream_bundle = run_bundle.compose_descriptor(
        data_keys=frame_data_keys,
        name=frame_stream_name,
        # configuration=_metadata(path)
    )
    yield 'descriptor', frame_stream_bundle.descriptor_doc

    # NOTE: Resource document may be meaningful in the future. For transient access it is not useful
    # # Compose resource
    # resource = run_bundle.compose_resource(root=Path(path).root, resource_path=path, spec='NCEM_DM', resource_kwargs={})
    # yield 'resource', resource.resource_doc

    # Compose datum_page
    # z_indices, t_indices = zip(*itertools.product(z_indices, t_indices))
    # datum_page_doc = resource.compose_datum_page(datum_kwargs={'index_z': list(z_indices), 'index_t': list(t_indices)})
    # datum_ids = datum_page_doc['datum_id']
    # yield 'datum_page', datum_page_doc

    yield 'event', frame_stream_bundle.compose_event(
        data={'raw': dask_data}, timestamps={'raw': time.time()})

    yield 'stop', run_bundle.compose_stop()

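# Usage sketch (hedged): ingest_NCEM_DM is a generator of (name, document)
# pairs in the event-model sense; the path is a placeholder and the frame data
# stays lazy (dask) until the event document is consumed.
for name, doc in ingest_NCEM_DM(["example_stack.dm4"]):
    print(name)
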
def _metadata(path):
    metaData = {}
    with dm.fileDM(path, on_memory=True) as dm1:
        # Save most useful metaData
        # Only keep the most useful tags as meta data
        for kk, ii in dm1.allTags.items():
            # Most useful starting tags
            prefix1 = 'ImageList.{}.ImageTags.'.format(dm1.numObjects)
            prefix2 = 'ImageList.{}.ImageData.'.format(dm1.numObjects)
            pos1 = kk.find(prefix1)
            pos2 = kk.find(prefix2)
            if pos1 > -1:
                sub = kk[pos1 + len(prefix1):]
                metaData[sub] = ii
            elif pos2 > -1:
                sub = kk[pos2 + len(prefix2):]
                metaData[sub] = ii

        # Remove unneeded keys
        for jj in list(metaData):
            if jj.find('frame sequence') > -1:
                del metaData[jj]
            elif jj.find('Private') > -1:
                del metaData[jj]
            elif jj.find('Reference Images') > -1:
                del metaData[jj]
            elif jj.find('Frame.Intensity') > -1:
                del metaData[jj]
            elif jj.find('Area.Transform') > -1:
                del metaData[jj]
            elif jj.find('Parameters.Objects') > -1:
                del metaData[jj]
            elif jj.find('Device.Parameters') > -1:
                del metaData[jj]

        # Store the X and Y pixel size, offset and unit
        try:
            metaData['PhysicalSizeX'] = metaData['Calibrations.Dimension.1.Scale']
            metaData['PhysicalSizeXOrigin'] = metaData['Calibrations.Dimension.1.Origin']
            metaData['PhysicalSizeXUnit'] = metaData['Calibrations.Dimension.1.Units']
            metaData['PhysicalSizeY'] = metaData['Calibrations.Dimension.2.Scale']
            metaData['PhysicalSizeYOrigin'] = metaData['Calibrations.Dimension.2.Origin']
            metaData['PhysicalSizeYUnit'] = metaData['Calibrations.Dimension.2.Units']
        except:
            metaData['PhysicalSizeX'] = 1
            metaData['PhysicalSizeXOrigin'] = 0
            metaData['PhysicalSizeXUnit'] = ''
            metaData['PhysicalSizeY'] = 1
            metaData['PhysicalSizeYOrigin'] = 0
            metaData['PhysicalSizeYUnit'] = ''

    metaData['FileName'] = path

    return metaData

def _get_scansize(self):
    with dm.fileDM(_get_gtg_path(self._path), on_memory=True) as dm_file:
        return (int(dm_file.allTags['.SI Dimensions.Size Y']),
                int(dm_file.allTags['.SI Dimensions.Size X']))

def run_MPI():
    import argparse
    import os

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    Nworkers = comm.Get_size()
    HEAD_WORKER = rank == 0

    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(description="Launch TV Denoising using MPI.")
    parser.add_argument("-i", "--input", type=os.path.abspath, nargs=1, help="input file")
    parser.add_argument("-o", "--output", type=os.path.abspath, nargs=1, help="output file")
    parser.add_argument("-d", "--dimensions", type=int, nargs=1, help="Number of Dimensions (3 or 4)")
    parser.add_argument(
        "-f",
        "--fista",
        type=str2bool,
        nargs=1,
        help="Use acceleration? 0 or 1.",
        default=False,
    )
    parser.add_argument(
        "-n",
        "--niterations",
        type=int,
        nargs="+",
        help="Number of iterations (Specify 2 values for hybrid.)",
    )
    parser.add_argument("-L", "--lambda", type=float, nargs="+")
    parser.add_argument("-m", "--mu", type=float, nargs="+")
    parser.add_argument("-v", "--verbose", type=str2bool, default=False)

    args = vars(parser.parse_args())

    # VERBOSE = args["verbose"]
    VERBOSE = True

    ndim = args["dimensions"][0]
    FISTA = args["fista"][0]
    niter = args["niterations"]
    BC_mode = 2

    lam = np.array(args["lambda"])
    mu = np.array(args["mu"])

    outfile = args["output"][0]

    if HEAD_WORKER:
        logger.info(f"Running MPI denoising with arguments: {args}")
        logger.info(f"Python sees OMP_NUM_THREADS as {os.environ['OMP_NUM_THREADS']}")

    # each worker must load a memory map into the data:
    t_read_start = time()
    if ndim == 3:
        # load EELS SI data using ncempy
        dmf = fileDM(args["input"][0])
        data = dmf.getMemmap(2)
        # squeeze while retaining the memmap (native numpy squeeze tries to load the array into RAM)
        while data.shape[0] == 1:
            data = data.reshape(data.shape[1:])
        size = data.shape[:2]
    elif ndim == 4:
        # load 4D data using py4DSTEM
        # load DM data:
        if "dm" in args["input"][0][-3:]:
            data = py4DSTEM.file.io.read(args["input"][0], load="dmmmap")
            size = data.shape[:2]
        # load EMD data:
        elif any(ftype in args["input"][0].split(".")[-1] for ftype in ("h5", "emd")):
            fb = py4DSTEM.file.io.FileBrowser(args["input"][0])
            # hack to swap out the HDF driver:
            # fb.file.close()
            fb.file = h5py.File(args["input"][0], "r", driver="mpio", comm=MPI.COMM_WORLD)
            dc = fb.get_dataobject(0, memory_map=True)
            data = dc.data
            size = data.shape[:2]
        else:
            if HEAD_WORKER:
                raise NotImplementedError("Incompatible File type...")
    else:
        if HEAD_WORKER:
            raise AssertionError("Bad number of dimensions...")

    if HEAD_WORKER:
        logger.info(f"Loaded memory map. Data size is: {data.shape}")
        logger.info(f"Loading memory map took {time()-t_read_start} seconds.")

    # calculate the best division of labor:
    edges = np.zeros((Nworkers,))
    for i in range(1, Nworkers + 1):
        if Nworkers % i == 0:
            # this is a factor of the number of workers, so a valid size
            wx = i
            wy = Nworkers / i
            # x and y sizes of the chunks, not including overlap
            sx = np.ceil(size[0] / wx)
            sy = np.ceil(size[1] / wy)
            edges[i - 1] = (Nworkers - 1) * (2 * sx + 2 * sy)
        else:
            # this is not a valid tiling shape
            edges[i - 1] = np.nan

    # Get the number of tiles in X and Y for the grid of workers:
    wx = int(np.nanargmin(edges) + 1)
    wy = int(Nworkers / wx)

    if HEAD_WORKER:
        logger.info(f"Dividing work over a {wx} by {wy} grid...")

    # Figure out the slices that this worker is responsible for:
    tile_x, tile_y = np.unravel_index(rank, (wx, wy))
    logger.debug(f"Worker {rank} is doing tile {tile_x},{tile_y}.")

    # first get the size in each direction
    nx = int(np.ceil(size[0] / wx))
    ny = int(np.ceil(size[1] / wy))

    # get the slices for which this worker's data is valid (i.e. the slice before adding overlaps)
    valid_slice_x = slice(
        tile_x * nx, (tile_x + 1) * nx if (tile_x + 1) * nx <= size[0] else size[0]
    )
    valid_slice_y = slice(
        tile_y * ny, (tile_y + 1) * ny if (tile_y + 1) * ny <= size[1] else size[1]
    )

    # now get the slices to actually read
    read_slice_x = slice(
        valid_slice_x.start - 1 if valid_slice_x.start > 0 else 0,
        valid_slice_x.stop + 1 if valid_slice_x.stop + 1 <= size[0] else size[0],
    )
    read_slice_y = slice(
        valid_slice_y.start - 1 if valid_slice_y.start > 0 else 0,
        valid_slice_y.stop + 1 if valid_slice_y.stop + 1 <= size[1] else size[1],
    )

    logger.debug(
        f"Worker {rank} at tile {tile_x},{tile_y} is reading slice {read_slice_x},{read_slice_y}..."
    )

    # set some flags for determining if this worker should shift data at each step:
    SHIFT_X_POS = tile_x < (wx - 1)
    SHIFT_X_NEG = tile_x > 0
    SHIFT_Y_POS = tile_y < (wy - 1)
    SHIFT_Y_NEG = tile_y > 0

    # get the slice *relative to the local chunk* that represents valid data
    # (this is used later for deciding what data from the local chunk is saved)
    local_valid_slice_x = slice(1 if SHIFT_X_NEG else 0, -1 if SHIFT_X_POS else None)
    local_valid_slice_y = slice(1 if SHIFT_Y_NEG else 0, -1 if SHIFT_Y_POS else None)

    # figure out the sources and destinations for each shift
    RANK_X_POS = (
        np.ravel_multi_index((tile_x + 1, tile_y), (wx, wy)) if SHIFT_X_POS else None
    )
    RANK_X_NEG = (
        np.ravel_multi_index((tile_x - 1, tile_y), (wx, wy)) if SHIFT_X_NEG else None
    )
    RANK_Y_POS = (
        np.ravel_multi_index((tile_x, tile_y + 1), (wx, wy)) if SHIFT_Y_POS else None
    )
    RANK_Y_NEG = (
        np.ravel_multi_index((tile_x, tile_y - 1), (wx, wy)) if SHIFT_Y_NEG else None
    )

    logger.debug(
        f"Rank {rank} has neighbors: +x {RANK_X_POS} \t -x: {RANK_X_NEG} \t +y: {RANK_Y_POS} \t -y: {RANK_Y_NEG}"
    )

    # load in the data and make it contiguous
    t_load_start = time()
    if ndim == 3:
        raw = np.ascontiguousarray(data[read_slice_x, read_slice_y, :]).astype(np.float32)
    elif args["dimensions"][0] == 4:
        # TODO: fix this for non-py4DSTEM files!!!
        raw = np.zeros(
            (
                read_slice_x.stop - read_slice_x.start,
                read_slice_y.stop - read_slice_y.start,
                data.shape[2],
                data.shape[3],
            ),
            dtype=np.float32,
        )
        logger.debug(f"Raw is shape {raw.shape}")
        data.read_direct(raw, source_sel=np.s_[read_slice_x, read_slice_y, :, :])
    # TODO: make dtype a flag

    if HEAD_WORKER:
        logger.info("Head worker finished reading raw data...")
        logger.info(
            f"Reading raw data took {time()-t_load_start} seconds. "
            f"Data size is {filesize.size(raw.nbytes, system=filesize.alternative)}"
        )

    recon = raw.copy()

    lambdaInv = (1.0 / lam).astype(recon.dtype)
    lam_mu = (lam / mu).astype(recon.dtype)

    if ndim == 3:
        # 3D is boring, I'll implement it later...
        if HEAD_WORKER:
            logger.error("Oops... Haven't implemented 3D yet. Sorry")
    elif ndim == 4:
        # allocate accumulators
        t_accum_start = time()
        acc0 = np.zeros_like(recon)
        acc1 = np.zeros_like(recon)
        acc2 = np.zeros_like(recon)
        acc3 = np.zeros_like(recon)

        # allocate MPI sync buffers
        x_pos_buffer = np.zeros((raw.shape[1], raw.shape[2], raw.shape[3]), dtype=np.float32)
        x_neg_buffer = np.zeros_like(x_pos_buffer)
        y_pos_buffer = np.zeros((raw.shape[0], raw.shape[2], raw.shape[3]), dtype=np.float32)
        y_neg_buffer = np.zeros_like(y_pos_buffer)

        if HEAD_WORKER:
            logger.info(
                f"Allocating the main accumulators and buffers took {time() - t_accum_start} seconds"
            )

        if FISTA:
            d1 = np.zeros_like(recon)
            d2 = np.zeros_like(recon)
            d3 = np.zeros_like(recon)
            d4 = np.zeros_like(recon)

            # allocate MPI sync buffers
            x_pos_buffer_FISTA = np.zeros(
                (raw.shape[1], raw.shape[2], raw.shape[3]), dtype=np.float32
            )
            x_neg_buffer_FISTA = np.zeros_like(x_pos_buffer)
            y_pos_buffer_FISTA = np.zeros(
                (raw.shape[0], raw.shape[2], raw.shape[3]), dtype=np.float32
            )
            y_neg_buffer_FISTA = np.zeros_like(y_pos_buffer)

            tk = 1.0

        if HEAD_WORKER:
            logger.info(
                f"With all accumulators allocated, free RAM is "
                f"{filesize.size(psutil.virtual_memory().available, system=filesize.alternative)}."
            )
        else:
            logger.debug(
                f"With all accumulators allocated, free RAM on rank {rank} is "
                f"{filesize.size(psutil.virtual_memory().available, system=filesize.alternative)}."
            )

        # create the iterators (so that only the head spits out tqdm stuff)
        iterator = tqdm(range(niter[0])) if HEAD_WORKER else range(niter[0])

        if FISTA:
            logger.error("Oops, haven't done FISTA yet...")
        else:
            for i in iterator:
                # perform an update step along dim 0
                t0 = time()
                tv.accumulator_update_4D(recon, acc0, 0, lambdaInv[0], BC_mode=BC_mode)
                logger.debug(
                    f"X update step : rank {rank} : iteration {i} : took {time()-t0} sec"
                )

                t0 = time()
                # start comms to send data right, receive data left:
                if SHIFT_X_POS:
                    x_pos_buffer[:] = np.squeeze(acc0[-1, :, :, :])
                    mpi_send_x_pos = comm.Isend(x_pos_buffer, dest=RANK_X_POS)
                if SHIFT_X_NEG:
                    # shift x left <=> receive data x left
                    x_neg_buffer[:] = 0
                    mpi_recv_x_neg = comm.Irecv(x_neg_buffer, source=RANK_X_NEG)
                logger.debug(
                    f"X MPI sync step : rank {rank} : iteration {i} : took {time()-t0} sec"
                )

                # perform an update step along dim 1
                t0 = time()
                tv.accumulator_update_4D(recon, acc1, 1, lambdaInv[1], BC_mode=BC_mode)
                logger.debug(
                    f"Y update step : rank {rank} : iteration {i} : took {time()-t0} sec"
                )

                t0 = time()
                # start comms to send data right, receive data left:
                if SHIFT_Y_POS:
                    y_pos_buffer[:] = np.squeeze(acc1[:, -1, :, :])
                    mpi_send_y_pos = comm.Isend(y_pos_buffer, dest=RANK_Y_POS)
                if SHIFT_Y_NEG:
                    # shift y left <=> receive data y left
                    y_neg_buffer[:] = 0
                    mpi_recv_y_neg = comm.Irecv(y_neg_buffer, source=RANK_Y_NEG)
                logger.debug(
                    f"Y MPI sync step : rank {rank} : iteration {i} : took {time()-t0} sec"
                )

                # perform update steps on the non-communicating directions
                if VERBOSE and HEAD_WORKER:
                    logger.info("Starting Qx/Qy acc update")
                t0 = time()
                tv.accumulator_update_4D(recon, acc2, 2, lambdaInv[2], BC_mode=BC_mode)
                tv.accumulator_update_4D(recon, acc3, 3, lambdaInv[3], BC_mode=BC_mode)
                logger.debug(
                    f"Qx/Qy update step : rank {rank} : iteration {i} : took {time()-t0} sec"
                )

                comm.Barrier()  # block until communication finishes, then copy buffered data
                if HEAD_WORKER:
                    logger.info(
                        f"Passed accumulator barrier on iteration {i} and entering sync block."
                    )
                else:
                    logger.debug(
                        f"Rank {rank} passed accumulator barrier and entering sync block."
                    )

                t_comm_wait = time()
                if SHIFT_X_NEG:
                    mpi_recv_x_neg.Wait()
                    acc0[0, :, :, :] = x_neg_buffer
                if SHIFT_Y_NEG:
                    mpi_recv_y_neg.Wait()
                    acc1[:, 0, :, :] = y_neg_buffer
                if SHIFT_X_POS:
                    mpi_send_x_pos.Wait()
                if SHIFT_Y_POS:
                    mpi_send_y_pos.Wait()

                if HEAD_WORKER:
                    logger.info(
                        f"Rank {rank} at iteration {i} spent {time()-t_comm_wait} seconds waiting for accumulator communication"
                    )
                else:
                    logger.debug(
                        f"Rank {rank} at iteration {i} spent {time()-t_comm_wait} seconds waiting for accumulator communication"
                    )

                # perform a datacube update step:
                if VERBOSE and HEAD_WORKER:
                    logger.info("Starting datacube update")
                t0 = time()
                tv.datacube_update_4D(
                    raw, recon, acc0, acc1, acc2, acc3, lam_mu, BC_mode=BC_mode
                )
                logger.debug(
                    f"Datacube update step : rank {rank} : iteration {i} : took {time()-t0} sec"
                )

                t_comm_wait = time()
                # start comms to send data left, receive data right
                if SHIFT_X_NEG:
                    x_neg_buffer[:] = np.squeeze(recon[0, :, :, :])
                    mpi_send_x_neg = comm.Isend(x_neg_buffer, dest=RANK_X_NEG)
                if SHIFT_X_POS:
                    x_pos_buffer[:] = 0
                    mpi_recv_x_pos = comm.Irecv(x_pos_buffer, source=RANK_X_POS)
                if SHIFT_Y_NEG:
                    y_neg_buffer[:] = np.squeeze(recon[:, 0, :, :])
                    mpi_send_y_neg = comm.Isend(y_neg_buffer, dest=RANK_Y_NEG)
                if SHIFT_Y_POS:
                    y_pos_buffer[:] = 0
                    mpi_recv_y_pos = comm.Irecv(y_pos_buffer, source=RANK_Y_POS)

                # Block until communication finishes
                comm.Barrier()
                if VERBOSE and HEAD_WORKER:
                    logger.info("Passed second barrier and entering sync block.")

                t_comm_wait = time()
                if SHIFT_X_POS:
                    mpi_recv_x_pos.Wait()
                    recon[-1, :, :, :] = x_pos_buffer
                if SHIFT_Y_POS:
                    mpi_recv_y_pos.Wait()
                    recon[:, -1, :, :] = y_pos_buffer
                if SHIFT_X_NEG:
                    mpi_send_x_neg.Wait()
                if SHIFT_Y_NEG:
                    mpi_send_y_neg.Wait()

                if VERBOSE and HEAD_WORKER:
                    logger.info(
                        f"Rank {rank} at iteration {i} spent {time()-t_comm_wait} seconds waiting for reconstruction communication"
                    )

    # temporary kludge for writing output files
    t_save_start = time()
    logger.info(f"Rank {rank} is saving data...")

    fout = h5py.File(
        outfile.split(".")[-2] + ".emd", "w", driver="mpio", comm=MPI.COMM_WORLD
    )

    group_toplevel = fout.create_group("4DSTEM_experiment")
    group_toplevel.attrs.create("emd_group_type", 2)
    group_toplevel.attrs.create("version_major", 0)
    group_toplevel.attrs.create("version_minor", 7)

    # Write data groups
    group_toplevel.create_group("metadata")
    group_data = group_toplevel.create_group("data")
    group_datacubes = group_data.create_group("datacubes")
    group_data.create_group("counted_datacubes")
    group_data.create_group("diffractionslices")
    group_data.create_group("realslices")
    group_data.create_group("pointlists")
    group_data.create_group("pointlistarrays")

    grp_dc = group_datacubes.create_group("datacube_0")
    dset = grp_dc.create_dataset("data", data.shape)
    grp_dc.attrs.create("emd_group_type", 1)
    grp_dc.attrs.create("metadata", -1)

    data_datacube = grp_dc["data"]
    R_Nx, R_Ny, Q_Nx, Q_Ny = data_datacube.shape
    data_R_Nx = grp_dc.create_dataset("dim1", (R_Nx,))
    data_R_Ny = grp_dc.create_dataset("dim2", (R_Ny,))
    data_Q_Nx = grp_dc.create_dataset("dim3", (Q_Nx,))
    data_Q_Ny = grp_dc.create_dataset("dim4", (Q_Ny,))

    if rank == 0:
        # Populate uncalibrated dimensional axes
        data_R_Nx[...] = np.arange(0, R_Nx)
        data_R_Nx.attrs.create("name", np.string_("R_x"))
        data_R_Nx.attrs.create("units", np.string_("[pix]"))
        data_R_Ny[...] = np.arange(0, R_Ny)
        data_R_Ny.attrs.create("name", np.string_("R_y"))
        data_R_Ny.attrs.create("units", np.string_("[pix]"))
        data_Q_Nx[...] = np.arange(0, Q_Nx)
        data_Q_Nx.attrs.create("name", np.string_("Q_x"))
        data_Q_Nx.attrs.create("units", np.string_("[pix]"))
        data_Q_Ny[...] = np.arange(0, Q_Ny)
        data_Q_Ny.attrs.create("name", np.string_("Q_y"))
        data_Q_Ny.attrs.create("units", np.string_("[pix]"))

    dset.write_direct(
        recon,
        source_sel=np.s_[local_valid_slice_x, local_valid_slice_y, :, :],
        dest_sel=np.s_[valid_slice_x, valid_slice_y, :, :],
    )
    # dset[valid_slice_x, valid_slice_y, :, :] = recon[
    #     local_valid_slice_x, local_valid_slice_y, :, :
    # ]

    fout.close()
    logger.info(f"Rank {rank} is done! Writing data took {time()-t_save_start} seconds")

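# Usage sketch (hedged): run_MPI is driven entirely by the argparse flags it
# defines. The script name 'tv_denoise_mpi.py' is a placeholder, and the four
# lambda/mu values correspond to the four dimensions of the 4D datacube:
#
#   mpirun -n 4 python tv_denoise_mpi.py \
#       -i scan.dm4 -o denoised.emd -d 4 -f 0 -n 50 \
#       -L 0.1 0.1 0.1 0.1 -m 0.5 0.5 0.5 0.5
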