def cumulate(fn, opt: PreProcOpts):
    """Applies cumulative summation to a data set comprising movie frame stacks.
    At the moment, requires the summed frame stacks to have the same shape as the raw data.

    Arguments:
        fn {str or list} -- file/list name(s) of the input data set
        opt {PreProcOpts} -- pre-processing options

    Raises:
        err: any exception raised while writing the stack data

    Returns:
        [list] -- output file names
    """
    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - cumulate] '.format(datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    dssel = Dataset().from_list(fn)
    log('Cumulating from frame', opt.cum_first_frame)
    dssel.open_stacks(readonly=False)

    # chunks for aggregation
    chunks = tuple(dssel.shots.groupby(opt.idfields).count()['selected'].values)
    for k, stk in dssel.stacks.items():
        if stk.chunks[0] != chunks:
            if k == 'index':
                continue
            log(k, 'needs rechunking...')
            dssel.add_stack(k, stk.rechunk({0: chunks}), overwrite=True)
    dssel._zchunks = chunks

    def cumfunc(movie):
        movie_out = movie
        movie_out[opt.cum_first_frame:, ...] = np.cumsum(movie[opt.cum_first_frame:, ...], axis=0)
        return movie_out

    for k in opt.cum_stacks:
        dssel.stacks[k] = dssel.stacks[k].map_blocks(cumfunc, dtype=dssel.stacks[k].dtype)

    dssel.change_filenames(opt.cum_file_suffix)
    dssel.init_files(overwrite=True, keep_features=False)
    log('File initialized, writing tables...')
    dssel.store_tables(shots=True, features=True)

    try:
        dssel.open_stacks(readonly=False)
        log('Writing stack data...')
        dssel.store_stacks(overwrite=True, progress_bar=False)
    except Exception as err:
        log('Cumulative processing failed.')
        raise err
    finally:
        dssel.close_stacks()

    log('Cumulation done.')
    return dssel.files
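
# Usage sketch (hypothetical file names; assumes a YAML settings file compatible with PreProcOpts,
# e.g. 'preproc.yaml', that defines cum_first_frame, cum_stacks and cum_file_suffix):
#
#   opts = PreProcOpts('preproc.yaml')
#   new_files = cumulate(['run0001_agg.h5', 'run0002_agg.h5'], opts)
#   print('Cumulated stacks written to:', new_files)
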
class EDViewer(QWidget):
    """Viewer for serial electron diffraction data: shows diffraction patterns with found and
    predicted peaks plus an overview map of the mapped region, either internally (pyqtgraph)
    or through an external adxv instance."""

    def __init__(self, args):
        super().__init__()
        self.dataset = Dataset()
        self.args = args
        self.data_path = None
        self.current_shot = pd.Series()
        self.diff_image = np.empty((0, 0))
        self.map_image = np.empty((0, 0))
        self.init_widgets()
        self.adxv = None
        self.geom = None
        self.read_files()
        self.switch_shot(0)
        if self.args.internal:
            self.hist_img.setLevels(np.quantile(self.diff_image, 0.02),
                                    np.quantile(self.diff_image, 0.98))
        self.update()
        self.show()

    def closeEvent(self, a0: QtGui.QCloseEvent) -> None:
        if not self.args.internal:
            self.adxv.exit()
        a0.accept()

    def read_files(self):

        file_type = args.filename.rsplit('.', 1)[-1]

        if file_type == 'stream':
            print(f'Parsing stream file {args.filename}...')
            stream = StreamParser(args.filename)
            # with open('tmp.geom', 'w') as fh:
            #     fh.write('\n'.join(stream._geometry_string))
            # self.geom = load_crystfel_geometry('tmp.geom')
            # os.remove('tmp.geom')
            # if len(self.geom['panels']) == 1:
            #     print('Single-panel geometry, so ignoring transforms for now.')
            #     # TODO make this more elegant, e.g. by overwriting image transform func with identity
            #     self.geom = None
            self.geom = None

            try:
                self.data_path = stream.geometry['data']
            except KeyError:
                if args.geometry is None:
                    raise ValueError('No data location specified in geometry file. Please use -d parameter.')

            files = sorted(list(stream.shots['file'].unique()))
            # print('Loading data files found in stream... \n', '\n'.join(files))
            try:
                self.dataset = Dataset.from_files(files, load_tables=False, init_stacks=False,
                                                  open_stacks=False)
                self.dataset.load_tables(features=True)
                # print(self.dataset.shots.columns)
                self.dataset.merge_stream(stream)
                # get_selection would not be the right method to call (changes IDs), instead do...
                self.dataset._shots = self.dataset._shots.loc[self.dataset._shots.selected, :].reset_index(drop=True)
                # TODO get subset for incomplete coverage
                print('Merged stream and hdf5 shot lists')
            except Exception as err:
                self.dataset = Dataset()
                self.dataset._shots = stream.shots
                self.dataset._peaks = stream.peaks
                self.dataset._predict = stream.indexed
                self.dataset._shots['selected'] = True
                print('Could not load shot lists from H5 files, but have that from the stream file.')
                print(f'Reason: {err}')

        if args.geometry is not None:
            raise ValueError('Geometry files are currently not supported.')
            # self.geom = load_crystfel_geometry(args.geometry)

        if file_type in ['lst', 'h5', 'hdf', 'nxs']:
            self.dataset = Dataset.from_list(args.filename, load_tables=True, init_stacks=False,
                                             open_stacks=False)
            if not self.dataset.shots.selected.all():
                # dirty removal of unwanted shots is sufficient in this case:
                self.dataset._shots = self.dataset._shots.loc[self.dataset._shots.selected, :].reset_index(drop=True)

        if args.data_path is not None:
            self.data_path = args.data_path

        if self.data_path is None:
            # data path neither set via stream file, nor explicitly. We have to guess.
            try:
                with h5py.File(self.dataset.shots.file.iloc[0], 'r') as fh:
                    base = '/%/data'.replace('%', self.dataset.shots.subset.iloc[0])
                    self.data_path = '/%/data/' + fh[base].attrs['signal']
                    print('Found data path', self.data_path)
            except Exception as err:
                warn(str(err), RuntimeWarning)
                print('Could not find out data path. Assuming /%/data/raw_counts')
                self.data_path = '/%/data/raw_counts'

        if self.args.query:
            print('Only showing shots with', self.args.query)
            # self.dataset.select(self.args.query)
            # self.dataset = self.dataset.get_selection(self.args.query, file_suffix=None, reset_id=False)
            # print('cutting shot list only')
            self.dataset._shots = self.dataset._shots.query(args.query)

        if self.args.sort_crystals:
            print('Re-sorting shots by region/crystal/run.')
            self.dataset._shots = self.dataset._shots.sort_values(by=['sample', 'region', 'crystal_id', 'run'])

        if not self.args.internal:
            # adxv_args = {'wavelength': 0.0251, 'distance': 2280, 'pixelsize': 0.055}
            adxv_args = {}
            self.adxv = Adxv(hdf5_path=self.data_path.replace('%', 'entry'),
                             adxv_bin=self.args.adxv_bin, **adxv_args)

        self.b_goto.setMaximum(self.dataset.shots.shape[0] - 1)
        self.b_goto.setMinimum(0)

    def update_image(self):
        print(self.current_shot)
        with h5py.File(self.current_shot['file'], mode='r') as f:

            if self.args.internal:
                path = self.data_path.replace('%', self.current_shot.subset)
                print('Loading {}:{} from {}'.format(path, self.current_shot['shot_in_subset'],
                                                     self.current_shot['file']))
                if len(f[path].shape) == 3:
                    self.diff_image = f[path][int(self.current_shot['shot_in_subset']), ...]
                elif len(f[path].shape) == 2:
                    self.diff_image = f[path][:]
                self.diff_image[np.isnan(self.diff_image)] = 0
                self.hist_img.setHistogramRange(np.partition(self.diff_image.flatten(), 100)[100],
                                                np.partition(self.diff_image.flatten(), -100)[-100])
                levels = self.hist_img.getLevels()
                # levels = (max(levels[0], -1), levels[1])
                levels = (levels[0], levels[1])
                if self.geom is not None:
                    raise RuntimeError('This should not happen')
                    # self.diff_image = apply_geometry_to_data(self.diff_image, self.geom)
                self.img.setImage(self.diff_image, autoRange=False)
                self.img.setLevels(levels)
                self.hist_img.setLevels(levels[0], levels[1])

            if not self.args.no_map:
                try:
                    path = args.map_path.replace('%', self.current_shot['subset'])
                    self.map_image = f[path][...]
                    self.mapimg.setImage(self.map_image)
                except KeyError:
                    warn('No map found at {}!'.format(path), Warning)

        if not self.args.internal:
            self.adxv.load_image(self.current_shot.file)
            self.adxv.slab(self.current_shot.shot_in_subset + 1)

    def update_plot(self):

        allpk = []

        if self.b_peaks.isChecked():
            if (len(self.dataset.peaks) == 0) or args.cxi_peaks:
                path = args.cxi_peaks_path.replace('%', self.current_shot.subset)
                print('Loading CXI peaks of {}:{} from {}'.format(path, self.current_shot['shot_in_subset'],
                                                                  self.current_shot['file']))
                with h5py.File(self.current_shot.file) as fh:
                    ii = int(self.current_shot['shot_in_subset'])
                    Npk = fh[path + '/nPeaks'][ii]
                    x = fh[path + '/peakXPosRaw'][ii, :Npk]
                    y = fh[path + '/peakYPosRaw'][ii, :Npk]
                    peaks = pd.DataFrame((x, y), index=['fs/px', 'ss/px']).T
            else:
                peaks = self.dataset.peaks.loc[(self.dataset.peaks.file == self.current_shot.file) &
                                               (self.dataset.peaks.Event == self.current_shot.Event),
                                               ['fs/px', 'ss/px']] - 0.5
                x = peaks.loc[:, 'fs/px']
                y = peaks.loc[:, 'ss/px']

            if self.geom is not None:
                raise RuntimeError('Someone set geom to something. This should not happen.')
                # print('Projecting peaks...')
                # maps = compute_visualization_pix_maps(self.geom)
                # x = maps.x[y.astype(int), x.astype(int)]
                # y = maps.y[y.astype(int), x.astype(int)]

            if self.args.internal:
                ring_pen = pg.mkPen('g', width=0.8)
                self.found_peak_canvas.setData(x, y, symbol='o', size=13, pen=ring_pen,
                                               brush=(0, 0, 0, 0), antialias=True)
            else:
                allpk.append(peaks.assign(group=0))

        else:
            self.found_peak_canvas.clear()

        if self.b_pred.isChecked() and (self.dataset.predict.shape[0] > 0):
            pred = self.dataset.predict.loc[(self.dataset.predict.file == self.current_shot.file) &
                                            (self.dataset.predict.Event == self.current_shot.Event),
                                            ['fs/px', 'ss/px']] - 0.5

            if self.geom is not None:
                raise RuntimeError('Someone set geom to not None. This should not happen.')
                # print('Projecting predictions...')
                # maps = compute_visualization_pix_maps(self.geom)
                # x = maps.x[pred.loc[:, 'ss/px'].astype(int), pred.loc[:, 'fs/px'].astype(int)]
                # y = maps.y[pred.loc[:, 'ss/px'].astype(int), pred.loc[:, 'fs/px'].astype(int)]
            else:
                x = pred.loc[:, 'fs/px']
                y = pred.loc[:, 'ss/px']

            if self.args.internal:
                square_pen = pg.mkPen('r', width=0.8)
                self.predicted_peak_canvas.setData(x, y, symbol='s', size=13, pen=square_pen,
                                                   brush=(0, 0, 0, 0), antialias=True)
            else:
                allpk.append(pred.assign(group=1))

        else:
            self.predicted_peak_canvas.clear()

        if not self.args.internal and len(allpk) > 0:
            self.adxv.define_spot('green', 5, 0, 0)
            self.adxv.define_spot('red', 0, 10, 1)
            self.adxv.load_spots(pd.concat(allpk, axis=0, ignore_index=True).values)
        elif not self.args.internal:
            self.adxv.load_spots(np.empty((0, 3)))

        if self.dataset.features.shape[0] > 0:
            ring_pen = pg.mkPen('g', width=2)
            dot_pen = pg.mkPen('y', width=0.5)
            region_feat = self.dataset.features.loc[
                (self.dataset.features['region'] == self.current_shot['region']) &
                (self.dataset.features['sample'] == self.current_shot['sample'])]
            print('Number of region features:', region_feat.shape[0])

            if self.current_shot['crystal_id'] != -1:
                single_feat = region_feat.loc[region_feat['crystal_id'] == self.current_shot['crystal_id'], :]
                x0 = single_feat['crystal_x'].squeeze()
                y0 = single_feat['crystal_y'].squeeze()

                if self.b_locations.isChecked():
                    self.found_features_canvas.setData(region_feat['crystal_x'], region_feat['crystal_y'],
                                                       symbol='+', size=7, pen=dot_pen,
                                                       brush=(0, 0, 0, 0), pxMode=True)
                else:
                    self.found_features_canvas.clear()

                if self.b_zoom.isChecked():
                    self.map_box.setRange(xRange=(x0 - 5 * args.beam_diam, x0 + 5 * args.beam_diam),
                                          yRange=(y0 - 5 * args.beam_diam, y0 + 5 * args.beam_diam))
                    self.single_feature_canvas.setData([x0], [y0], symbol='o', size=args.beam_diam,
                                                       pen=ring_pen, brush=(0, 0, 0, 0), pxMode=False)
                    try:
                        # real-space cell vectors from the reciprocal ones, rescaled for display
                        c_real = np.cross(
                            [self.current_shot.astar_x, self.current_shot.astar_y, self.current_shot.astar_z],
                            [self.current_shot.bstar_x, self.current_shot.bstar_y, self.current_shot.bstar_z])
                        b_real = np.cross(
                            [self.current_shot.cstar_x, self.current_shot.cstar_y, self.current_shot.cstar_z],
                            [self.current_shot.astar_x, self.current_shot.astar_y, self.current_shot.astar_z])
                        a_real = np.cross(
                            [self.current_shot.bstar_x, self.current_shot.bstar_y, self.current_shot.bstar_z],
                            [self.current_shot.cstar_x, self.current_shot.cstar_y, self.current_shot.cstar_z])
                        a_real = 20 * a_real / np.sum(a_real ** 2) ** .5
                        b_real = 20 * b_real / np.sum(b_real ** 2) ** .5
                        c_real = 20 * c_real / np.sum(c_real ** 2) ** .5
                        self.a_dir.setData(x=x0 + np.array([0, a_real[0]]), y=y0 + np.array([0, a_real[1]]))
                        self.b_dir.setData(x=x0 + np.array([0, b_real[0]]), y=y0 + np.array([0, b_real[1]]))
                        self.c_dir.setData(x=x0 + np.array([0, c_real[0]]), y=y0 + np.array([0, c_real[1]]))
                    except:
                        print('Could not read lattice vectors.')

                else:
                    self.single_feature_canvas.setData([x0], [y0], symbol='o', size=13, pen=ring_pen,
                                                       brush=(0, 0, 0, 0), pxMode=True)
                    self.map_box.setRange(xRange=(0, self.map_image.shape[1]),
                                          yRange=(0, self.map_image.shape[0]))

            else:
                self.single_feature_canvas.setData([], [])

    def update(self):
        self.found_peak_canvas.clear()
        self.predicted_peak_canvas.clear()
        app.processEvents()
        self.update_image()
        if args.cxi_peaks and not args.internal:
            # give adxv some time to display the image before accessing the CXI data
            sleep(0.2)
        self.update_plot()
        print(self.current_shot)

    # CALLBACK FUNCTIONS

    def switch_shot(self, shot_id=None):
        if shot_id is None:
            shot_id = self.b_goto.value()

        self.shot_id = max(0, shot_id % self.dataset.shots.shape[0])
        self.current_shot = self.dataset.shots.iloc[self.shot_id, :]

        self.meta_table.setRowCount(self.current_shot.shape[0])
        self.meta_table.setColumnCount(2)
        for row, (k, v) in enumerate(self.current_shot.items()):
            self.meta_table.setItem(row, 0, QTableWidgetItem(k))
            self.meta_table.setItem(row, 1, QTableWidgetItem(str(v)))
        self.meta_table.resizeRowsToContents()

        shot = self.current_shot
        title = {'sample': '', 'region': 'Reg', 'feature': 'Feat', 'frame': 'Frame', 'event': 'Ev', 'file': ''}
        titlestr = ''
        for k, v in title.items():
            titlestr += f'{v} {shot[k]}' if k in shot.keys() else ''
        titlestr += f' ({shot.name} of {self.dataset.shots.shape[0]})'
        print(titlestr)
        self.setWindowTitle(titlestr)

        self.b_goto.blockSignals(True)
        self.b_goto.setValue(self.shot_id)
        self.b_goto.blockSignals(False)

        self.update()

    def switch_shot_rel(self, shift):
        self.switch_shot(self.shot_id + shift)

    def mouse_moved(self, evt):
        mousePoint = self.img.mapFromDevice(evt[0])
        x, y = round(mousePoint.x()), round(mousePoint.y())
        x = min(max(0, x), self.diff_image.shape[1] - 1)
        y = min(max(0, y), self.diff_image.shape[0] - 1)
        I = self.diff_image[y, x]
        # print(x, y, I)
        self.info_text.setPos(x, y)
        self.info_text.setText(f'{x:0.1f}, {y:0.1f}: {I:0.1f}')

    def init_widgets(self):

        self.imageWidget = pg.GraphicsLayoutWidget()

        # IMAGE DISPLAY
        # A plot area (ViewBox + axes) for displaying the image
        self.image_box = self.imageWidget.addViewBox()
        self.image_box.setAspectLocked()

        self.img = pg.ImageItem()
        self.img.setZValue(0)
        self.image_box.addItem(self.img)
        self.proxy = pg.SignalProxy(self.img.scene().sigMouseMoved, rateLimit=60, slot=self.mouse_moved)

        self.found_peak_canvas = pg.ScatterPlotItem()
        self.image_box.addItem(self.found_peak_canvas)
        self.found_peak_canvas.setZValue(2)

        self.predicted_peak_canvas = pg.ScatterPlotItem()
        self.image_box.addItem(self.predicted_peak_canvas)
        self.predicted_peak_canvas.setZValue(2)

        self.info_text = pg.TextItem(text='')
        self.image_box.addItem(self.info_text)
        self.info_text.setPos(0, 0)

        # Contrast/color control
        self.hist_img = pg.HistogramLUTItem(self.img, fillHistogram=False)
        self.imageWidget.addItem(self.hist_img)

        # MAP DISPLAY
        self.map_widget = pg.GraphicsLayoutWidget()
        self.map_widget.setWindowTitle('region map')

        # Map image control
        self.map_box = self.map_widget.addViewBox()
        self.map_box.setAspectLocked()

        self.mapimg = pg.ImageItem()
        self.mapimg.setZValue(0)
        self.map_box.addItem(self.mapimg)

        self.found_features_canvas = pg.ScatterPlotItem()
        self.map_box.addItem(self.found_features_canvas)
        self.found_features_canvas.setZValue(2)

        self.single_feature_canvas = pg.ScatterPlotItem()
        self.map_box.addItem(self.single_feature_canvas)
        self.single_feature_canvas.setZValue(2)

        # lattice vectors
        self.a_dir = pg.PlotDataItem(pen=pg.mkPen('r', width=1))
        self.b_dir = pg.PlotDataItem(pen=pg.mkPen('g', width=1))
        self.c_dir = pg.PlotDataItem(pen=pg.mkPen('b', width=1))
        self.map_box.addItem(self.a_dir)
        self.map_box.addItem(self.b_dir)
        self.map_box.addItem(self.c_dir)

        # Contrast/color control
        self.hist_map = pg.HistogramLUTItem(self.mapimg)
        self.map_widget.addItem(self.hist_map)

        ### CONTROL BUTTONS
        b_rand = QPushButton('rnd')
        b_plus10 = QPushButton('+10')
        b_minus10 = QPushButton('-10')
        b_last = QPushButton('last')
        self.b_peaks = QCheckBox('peaks')
        self.b_pred = QCheckBox('crystal')
        self.b_zoom = QCheckBox('zoom')
        self.b_locations = QCheckBox('locations')
        self.b_locations.setChecked(True)
        b_reload = QPushButton('reload')
        self.b_goto = QSpinBox()

        b_rand.clicked.connect(lambda: self.switch_shot(np.random.randint(0, self.dataset.shots.shape[0] - 1)))
        b_plus10.clicked.connect(lambda: self.switch_shot_rel(+10))
        b_minus10.clicked.connect(lambda: self.switch_shot_rel(-10))
        b_last.clicked.connect(lambda: self.switch_shot(self.dataset.shots.index.max()))
        self.b_peaks.stateChanged.connect(self.update)
        self.b_pred.stateChanged.connect(self.update)
        self.b_zoom.stateChanged.connect(self.update)
        self.b_locations.stateChanged.connect(self.update)
        b_reload.clicked.connect(lambda: self.read_files())
        self.b_goto.valueChanged.connect(lambda: self.switch_shot(None))

        self.button_layout = QtGui.QGridLayout()
        self.button_layout.addWidget(b_plus10, 0, 2)
        self.button_layout.addWidget(b_minus10, 0, 1)
        self.button_layout.addWidget(b_rand, 0, 4)
        self.button_layout.addWidget(b_last, 0, 3)
        self.button_layout.addWidget(self.b_goto, 0, 0)
        self.button_layout.addWidget(b_reload, 0, 10)
        self.button_layout.addWidget(self.b_peaks, 0, 21)
        self.button_layout.addWidget(self.b_pred, 0, 22)
        self.button_layout.addWidget(self.b_zoom, 0, 23)
        self.button_layout.addWidget(self.b_locations, 0, 24)

        self.meta_table = QTableWidget()
        self.meta_table.verticalHeader().setVisible(False)
        self.meta_table.horizontalHeader().setVisible(False)
        self.meta_table.setFont(QtGui.QFont('Helvetica', 10))

        # --- TOP-LEVEL ARRANGEMENT
        self.top_layout = QGridLayout()
        self.setLayout(self.top_layout)
        if self.args.internal:
            self.top_layout.addWidget(self.imageWidget, 0, 0)
            self.top_layout.setColumnStretch(0, 2)
        if not self.args.no_map:
            self.top_layout.addWidget(self.map_widget, 0, 1)
            self.top_layout.setColumnStretch(1, 1.5)
        self.top_layout.addWidget(self.meta_table, 0, 2)
        self.top_layout.addLayout(self.button_layout, 1, 0, 1, 3)
        self.top_layout.setColumnStretch(2, 0)
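
# Minimal usage sketch for the viewer (assumes the module is run as a script and that a
# module-level argparse namespace `args` with the fields used above -- filename, internal,
# query, no_map, etc. -- already exists, since the class methods refer to it directly):
#
#   app = QApplication([])
#   viewer = EDViewer(args)
#   app.exec_()
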
def broadcast(fn, opt: PreProcOpts):
    """Pre-processes a data set comprising movie frames in one go, by transferring the beam-center
    positions and diffraction spots (in CXI format) found earlier on an aggregated set.

    Arguments:
        fn {str or list} -- file/list name(s) of the aggregated data set
        opt {PreProcOpts} -- pre-processing options

    Raises:
        err: any exception raised during the broadcast computation

    Returns:
        [list] -- output file names
    """
    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any([isinstance(e, Exception) for e in args])):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - broadcast] '.format(datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    t0 = time()
    dsagg = Dataset.from_list(fn, load_tables=False)
    dsraw = Dataset.from_list(list(dsagg.shots.file_raw.unique()))
    dsraw.shots['file_raw'] = dsraw.shots['file']  # required for association later

    reference = imread(opt.reference)
    pxmask = imread(opt.pxmask)

    # And now: the interesting part...
    dsagg.shots['from_id'] = range(dsagg.shots.shape[0])  # label original (aggregated) shots
    in_agg = dsraw.shots[opt.idfields].merge(dsagg.shots[opt.idfields + ['selected']],
                                             on=opt.idfields, how='left')['selected'].fillna(False)
    dsraw.shots['selected'] = in_agg

    try:
        dsraw.open_stacks(readonly=True)
        dssel = dsraw.get_selection(f'({opt.select_query}) and selected',
                                    file_suffix=opt.single_suffix, new_folder=opt.proc_dir)
        shots = dssel.shots.merge(dsagg.shots[opt.idfields + ['from_id']],
                                  on=opt.idfields, validate='m:1')  # _inner_ merge
        log(f'{dsraw.shots.shape[0]} raw, {dssel.shots.shape[0]} selected, {dsagg.shots.shape[0]} aggregated.')

        # get the broadcasted image centers
        dsagg.open_stacks(readonly=True)
        ctr = dsagg.stacks[opt.center_stack][shots.from_id.values, :]

        # Flat-field and dead-pixel correction
        stack_rechunked = dssel.raw_counts.rechunk({0: ctr.chunks[0]})  # select and re-chunk the raw data
        if opt.correct_saturation:
            stack_ff = proc2d.apply_flatfield(
                proc2d.apply_saturation_correction(stack_rechunked, opt.shutter_time, opt.dead_time),
                reference)
        else:
            stack_ff = proc2d.apply_flatfield(stack_rechunked, reference)
        stack = proc2d.correct_dead_pixels(stack_ff, pxmask, strategy='replace',
                                           replace_val=-1, mask_gaps=True)
        centered = proc2d.center_image(stack, ctr[:, 0], ctr[:, 1], opt.xsize, opt.ysize, -1,
                                       parallel=True)

        # add the new stacks to the aggregated dataset
        alldata = {
            'center_of_mass': dsagg.stacks['center_of_mass'][shots.from_id.values, ...],
            'lorentz_fit': dsagg.stacks['lorentz_fit'][shots.from_id.values, ...],
            'beam_center': ctr,
            'centered': centered.astype(np.float32) if opt.float else centered.astype(np.int16),
            'pxmask_centered': (centered != -1).astype(np.uint16),
            'adf1': proc2d.apply_virtual_detector(centered, opt.r_adf1[0], opt.r_adf1[1]),
            'adf2': proc2d.apply_virtual_detector(centered, opt.r_adf2[0], opt.r_adf2[1])
        }

        if opt.broadcast_peaks:
            alldata.update({
                'nPeaks': dsagg.stacks['nPeaks'][shots.from_id.values, ...],
                'peakTotalIntensity': dsagg.stacks['peakTotalIntensity'][shots.from_id.values, ...],
                'peakXPosRaw': dsagg.stacks['peakXPosRaw'][shots.from_id.values, ...],
                'peakYPosRaw': dsagg.stacks['peakYPosRaw'][shots.from_id.values, ...],
            })

        for lbl, stk in alldata.items():
            dssel.add_stack(lbl, stk, overwrite=True)

        dssel.init_files(overwrite=True)
        dssel.store_tables(shots=True, features=True)
        dssel.open_stacks(readonly=False)
        dssel.delete_stack('raw_counts', from_files=False)  # we don't need the raw counts in the new files
        dssel.store_stacks(overwrite=True, progress_bar=False)  # this does the actual calculation
        log('Finished with', dssel.centered.shape[0], 'shots after', time() - t0, 'seconds')

    except Exception as err:
        log('Broadcast processing failed:', err)
        raise err

    finally:
        dsagg.close_stacks()
        dsraw.close_stacks()
        dssel.close_stacks()

    return dssel.files
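
# Usage sketch (hypothetical list name): broadcast() takes the *aggregated* data set, e.g. as
# produced by from_raw()/refine_center(), and writes per-frame files that carry over the refined
# beam centers and (if opt.broadcast_peaks is set) the CXI peak lists:
#
#   opts = PreProcOpts('preproc.yaml')
#   single_files = broadcast('runs_agg_refined.lst', opts)
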
def subtract_bg(fn, opt: PreProcOpts):
    """Subtracts the background of diffraction patterns by azimuthal integration, excluding the Bragg peaks.

    Arguments:
        fn {str or list} -- file/list name(s) of the input data set
        opt {PreProcOpts} -- pre-processing options

    Returns:
        [list] -- output file names
    """
    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any([isinstance(e, Exception) for e in args])):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - subtract_bg] '.format(datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    ds = Dataset().from_list(fn)
    ds.open_stacks(readonly=False)

    if opt.rerun_peak_finder:
        pks = find_peaks(ds, opt=opt)
        nPeaks = da.from_array(pks['nPeaks'][:, np.newaxis, np.newaxis],
                               chunks=(ds.centered.chunks[0], 1, 1))
        peakX = da.from_array(pks['peakXPosRaw'][:, :, np.newaxis],
                              chunks=(ds.centered.chunks[0], -1, 1))
        peakY = da.from_array(pks['peakYPosRaw'][:, :, np.newaxis],
                              chunks=(ds.centered.chunks[0], -1, 1))
    else:
        nPeaks = ds.nPeaks[:, np.newaxis, np.newaxis]
        peakX = ds.peakXPosRaw[:, :, np.newaxis]
        peakY = ds.peakYPosRaw[:, :, np.newaxis]

    original = ds.centered
    bg_corrected = da.map_blocks(proc2d.remove_background, original,
                                 original.shape[2] / 2, original.shape[1] / 2,
                                 nPeaks, peakX, peakY,
                                 peak_radius=opt.peak_radius, filter_len=opt.filter_len,
                                 dtype=np.float32 if opt.float else np.int32,
                                 chunks=original.chunks)

    ds.add_stack('centered', bg_corrected, overwrite=True)
    ds.change_filenames(opt.nobg_file_suffix)
    ds.init_files(keep_features=False, overwrite=True)
    ds.store_tables(shots=True, features=True)
    ds.open_stacks(readonly=False)

    # for lbl in ['nPeaks', 'peakTotalIntensity', 'peakXPosRaw', 'peakYPosRaw']:
    #     if lbl in ds.stacks:
    #         ds.delete_stack(lbl, from_files=False)

    try:
        ds.store_stacks(overwrite=True, progress_bar=False)
    except Exception as err:
        log('Error during background correction:', err)
        raise err
    finally:
        ds.close_stacks()

    return ds.files
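
# Usage sketch (hypothetical list name; peak_radius, filter_len, rerun_peak_finder and
# nobg_file_suffix are taken from the options file):
#
#   opts = PreProcOpts('preproc.yaml')
#   nobg_files = subtract_bg('runs_single.lst', opts)
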
def refine_center(fn, opt: PreProcOpts):
    """Refines the centering of diffraction patterns based on Friedel-mate positions.

    Arguments:
        fn {str} -- [file/list name of input files, can contain wildcards]
        opt {PreProcOpts} -- [pre-processing options]

    Raises:
        err: any exception raised while storing the center-refined images

    Returns:
        [list] -- output file names
    """
    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any([isinstance(e, Exception) for e in args])):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - refine_center] '.format(datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    ds = Dataset.from_list(fn)
    stream = find_peaks(ds, opt=opt, merge_peaks=False, return_cxi=False,
                        geo_params={'clen': opt.cam_length}, exc=opt.im_exc)
    p0 = [opt.xsize // 2, opt.ysize // 2]

    # get Friedel-refined center from stream file
    ctr = proc_peaks.center_friedel(stream.peaks, ds.shots, p0=p0,
                                    sigma=opt.peak_sigma, minpeaks=opt.min_peaks,
                                    maxres=opt.friedel_max_radius)

    maxcorr = ctr['friedel_cost'].values
    changed = np.logical_not(np.isnan(maxcorr))

    with ds.Stacks() as stk:
        beam_center_old = stk['beam_center'].compute()  # previous beam center

    beam_center_new = beam_center_old.copy()
    beam_center_new[changed, :] = np.ceil(beam_center_old[changed, :]) \
        + ctr.loc[changed, ['beam_x', 'beam_y']].values - p0

    if (np.abs(np.mean(beam_center_new - beam_center_old, axis=0) > .5)).any():
        log('WARNING: average shift is larger than 0.5!')

    # visualization
    log('{:g}% of shots refined. \n'.format((1 - np.isnan(maxcorr).sum() / len(maxcorr)) * 100),
        'Shift standard deviation: {} \n'.format(np.std(beam_center_new - beam_center_old, axis=0)),
        'Average shift: {} \n'.format(np.mean(beam_center_new - beam_center_old, axis=0)))

    # make new files and add the shifted images
    try:
        ds.open_stacks(readonly=False)
        centered2 = proc2d.center_image(ds.centered, ctr['beam_x'].values, ctr['beam_y'].values,
                                        1556, 616, -1)  # note: hard-coded output image size
        ds.add_stack('centered', centered2, overwrite=True)
        ds.add_stack('pxmask_centered', (centered2 != -1).astype(np.uint16), overwrite=True)
        ds.add_stack('beam_center', beam_center_new, overwrite=True)
        ds.change_filenames(opt.refined_file_suffix)
        print(ds.files)
        ds.init_files(keep_features=False, overwrite=True)
        ds.store_tables(shots=True, features=True)
        ds.open_stacks(readonly=False)
        ds.store_stacks(overwrite=True, progress_bar=False)
        ds.close_stacks()
        del centered2
    except Exception as err:
        log('Error during storing center-refined images', err)
        raise err
    finally:
        ds.close_stacks()

    # run peak finder again, this time on the refined images
    pks_cxi = find_peaks(ds, opt=opt, merge_peaks=opt.peaks_nexus,
                         return_cxi=True, geo_params={'clen': opt.cam_length}, exc=opt.im_exc)

    # export peaks to CXI-format arrays
    if opt.peaks_cxi:
        with ds.Stacks() as stk:
            for k, v in pks_cxi.items():
                if k in stk:
                    ds.delete_stack(k, from_files=True)
                ds.add_stack(k, v, overwrite=True)
            ds.store_stacks(list(pks_cxi.keys()), progress_bar=True, overwrite=True)

    return ds.files
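
# Usage sketch (hypothetical list name). refine_center() runs the peak finder twice: once to
# obtain Friedel pairs for the center refinement, and once more on the re-centered images,
# exporting the peaks to CXI arrays if opt.peaks_cxi is set:
#
#   opts = PreProcOpts('preproc.yaml')
#   refined_files = refine_center('runs_agg.lst', opts)
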
def from_raw(fn, opt: PreProcOpts):
    """Initial processing of raw data: aggregates/selects the shots, applies saturation, flat-field
    and dead-pixel corrections, finds the beam center (COM followed by a Lorentzian fit), and
    writes centered image stacks with virtual-detector signals to new files."""

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any([isinstance(e, Exception) for e in args])):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - from_raw] '.format(datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    t0 = time()
    dsraw = Dataset().from_list(fn)
    reference = imread(opt.reference)
    pxmask = imread(opt.pxmask)

    os.makedirs(opt.scratch_dir, exist_ok=True)
    os.makedirs(opt.proc_dir, exist_ok=True)

    dsraw.open_stacks(readonly=True)

    if opt.aggregate:
        dsagg = dsraw.aggregate(file_suffix=opt.agg_file_suffix, new_folder=opt.proc_dir,
                                force_commensurate=False, how={'raw_counts': 'sum'},
                                query=opt.agg_query)
    else:
        dsagg = dsraw.get_selection(opt.agg_query, new_folder=opt.proc_dir,
                                    file_suffix=opt.agg_file_suffix)

    log(f'{dsraw.shots.shape[0]} raw, {dsagg.shots.shape[0]} aggregated/selected.')

    if opt.rechunk is not None:
        dsagg.rechunk_stacks(opt.rechunk)

    # Saturation, flat-field and dead-pixel correction
    if opt.correct_saturation:
        stack_ff = proc2d.apply_flatfield(
            proc2d.apply_saturation_correction(dsagg.raw_counts, opt.shutter_time, opt.dead_time),
            reference)
    else:
        stack_ff = proc2d.apply_flatfield(dsagg.raw_counts, reference)
    stack = proc2d.correct_dead_pixels(stack_ff, pxmask, strategy='replace',
                                       replace_val=-1, mask_gaps=True)

    # Stack in central region along x (note that the gaps are not masked this time)
    xrng = slice((opt.xsize - opt.com_xrng) // 2, (opt.xsize + opt.com_xrng) // 2)
    stack_ct = proc2d.correct_dead_pixels(stack_ff[:, :, xrng], pxmask[:, xrng],
                                          strategy='replace', replace_val=-1, mask_gaps=False)

    # Define COM threshold as fraction of highest pixel (after discarding some too-high ones)
    thr = stack_ct.max(axis=1).topk(10, axis=1)[:, 9].reshape((-1, 1, 1)) * opt.com_threshold
    com = proc2d.center_of_mass2(stack_ct, threshold=thr) + [[(opt.xsize - opt.com_xrng) // 2, 0]]

    # Lorentzian fit in region around the found COM
    lorentz = compute.map_reduction_func(stack, proc2d.lorentz_fast, com[:, 0], com[:, 1],
                                         radius=opt.lorentz_radius, limit=opt.lorentz_maxshift,
                                         scale=7, threads=False, output_len=4)
    ctr = lorentz[:, 1:3]

    # calculate the centered image by shifting and padding with -1
    centered = proc2d.center_image(stack, ctr[:, 0], ctr[:, 1], opt.xsize, opt.ysize, -1,
                                   parallel=True)

    # add the new stacks to the aggregated dataset
    alldata = {
        'center_of_mass': com,
        'lorentz_fit': lorentz,
        'beam_center': ctr,
        'centered': centered,
        'pxmask_centered': (centered != -1).astype(np.uint16),
        'adf1': proc2d.apply_virtual_detector(centered, opt.r_adf1[0], opt.r_adf1[1]),
        'adf2': proc2d.apply_virtual_detector(centered, opt.r_adf2[0], opt.r_adf2[1])
    }
    for lbl, stk in alldata.items():
        print('adding', lbl, stk.shape)
        dsagg.add_stack(lbl, stk, overwrite=True)

    # make the files and crunch the data
    try:
        dsagg.init_files(overwrite=True)
        dsagg.store_tables(shots=True, features=True)
        dsagg.open_stacks(readonly=False)
        dsagg.delete_stack('raw_counts', from_files=False)  # we don't need the raw counts in the new files
        dsagg.store_stacks(overwrite=True, progress_bar=False)  # this does the actual calculation
        log('Finished first centering', dsagg.centered.shape[0], 'shots after', time() - t0, 'seconds')
    except Exception as err:
        log('Raw processing failed.', err)
        raise err
    finally:
        dsagg.close_stacks()
        dsraw.close_stacks()

    return dsagg.files
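
# Usage sketch: the per-file functions above can be mapped over many raw files in parallel with
# dask.distributed (hypothetical file pattern; assumes `opts` was loaded from the YAML settings
# file as elsewhere in this module):
#
#   from dask.distributed import Client
#   client = Client()
#   raw_files = io.expand_files('raw/*.nxs')
#   agg_files = client.gather(client.map(from_raw, raw_files, opt=opts))
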
def main():
    parser = argparse.ArgumentParser(
        description='Quick and dirty pre-processing for Serial Electron Diffraction data',
        allow_abbrev=False,
        epilog='Any other options are passed on as modifications to the option file')
    parser.add_argument('filename', type=str, nargs='*',
                        help='List or HDF5 file or glob pattern. Glob pattern must be given in SINGLE quotes.')
    parser.add_argument('-s', '--settings', type=str,
                        help='Option YAML file. Defaults to \'preproc.yaml\'.', default='preproc.yaml')
    parser.add_argument('-A', '--address', type=str,
                        help='Address of an existing dask.distributed cluster to use instead of making a new one. '
                             'Defaults to making a new one.', default=None)
    parser.add_argument('-N', '--nprocs', type=int,
                        help='Number of processes of a new dask.distributed cluster. Defaults to letting dask decide.',
                        default=None)
    parser.add_argument('-L', '--local-directory', type=str,
                        help='Fast (scratch) directory for computations. Defaults to the current directory.',
                        default=None)
    parser.add_argument('-c', '--chunksize', type=int,
                        help='Chunk size of raw data stack. Should be an integer multiple of the movie stack frames! '
                             'Defaults to 100.', default=100)
    parser.add_argument('-l', '--list-file', type=str, help='Name of output list file', default='processed.lst')
    parser.add_argument('-w', '--wait-for-files', help='Wait for files matching the wildcard pattern',
                        action='store_true')
    parser.add_argument('--include-existing', help='When using -w/--wait-for-files, also include existing files',
                        action='store_true')
    parser.add_argument('--append', help='Append to the list file instead of overwriting it', action='store_true')
    parser.add_argument('-d', '--data-path-old', type=str,
                        help='Raw data field in HDF5 file(s). Defaults to /entry/data/raw_counts',
                        default='/%/data/raw_counts')
    parser.add_argument('-n', '--data-path-new', type=str,
                        help='Corrected data field in HDF5 file(s). Defaults to /entry/data/corrected',
                        default='/%/data/corrected')
    parser.add_argument('--no-bgcorr', help='Skip background correction', action='store_true')
    parser.add_argument('--no-validate', help='Do not validate files before attempting to process them',
                        action='store_true')
    # parser.add_argument('ppopt', nargs=argparse.REMAINDER, help='Preprocessing options to be overridden')

    args, extra = parser.parse_known_args()
    # print(args, extra)
    # raise RuntimeError('thus far!')

    opts = pre_proc_opts.PreProcOpts(args.settings)
    label_raw = args.data_path_old.rsplit('/', 1)[-1]
    label = args.data_path_new.rsplit('/', 1)[-1]

    if extra:
        # If extra arguments have been supplied, overwrite existing option values
        opt_parser = argparse.ArgumentParser()
        for k, v in opts.__dict__.items():
            opt_parser.add_argument('--' + k, type=type(v), default=None)
        opts2 = opt_parser.parse_args(extra)
        for k, v in vars(opts2).items():
            if v is not None:
                if type(v) != type(opts.__dict__[k]):
                    warn('Mismatch of data types in overridden argument!', RuntimeWarning)
                print(f'Overriding option file setting {k} = {opts.__dict__[k]} ({type(opts.__dict__[k])}). ',
                      f'New value is {v} ({type(v)})')
                opts.__dict__[k] = v

    # raise RuntimeError('thus far!')

    print('Running on diffractem:', version())
    print('Current path is:', os.getcwd())

    if args.address is not None:
        print('Connecting to cluster scheduler at', args.address)
        try:
            client = Client(address=args.address, timeout=3)
        except:
            print(f'\n----\nThere seems to be no dask.distributed scheduler running at {args.address}.\n'
                  f'Please double-check the address, or have one started automatically by omitting '
                  f'the --address option.')
            return
    else:
        print('Creating a dask.distributed cluster...')
        client = Client(n_workers=args.nprocs, local_directory=args.local_directory, processes=True)
        print('\n\n---\nStarted dask.distributed cluster:')
        print(client)

    print('You can access the dashboard for monitoring at: ', client.dashboard_link)
    client.run(os.chdir, os.getcwd())

    if len(args.filename) == 1:
        args.filename = args.filename[0]
    # print(args.filename)

    seen_raw_files = [] if args.include_existing else io.expand_files(args.filename)

    while True:
        if args.wait_for_files:
            # slightly awkward sequence to only open finished files... (but believe me - it works!)
            fns = io.expand_files(args.filename)
            # print(fns)
            fns = [fn for fn in fns if fn not in seen_raw_files]

            # validation...
            try:
                fns = io.expand_files(fns, validate=not args.no_validate)
            except (OSError, IOError, RuntimeError) as err:
                print(f'Could not open file(s) {" ".join(fns)} because of', err)
                print('Possibly, it is still being written to. Waiting a bit...')
                sleep(5)
                continue

            if not fns:
                print('No new files, waiting a bit...')
                sleep(5)
                continue
            else:
                print('Found new file(s):\n', '\n'.join(fns))
                try:
                    ds_raw = Dataset.from_files(fns, chunking=args.chunksize)
                except Exception as err:
                    print(f'Could not open file(s) {" ".join(fns)} because of', err)
                    print('Possibly, it is still being written to. Waiting a bit...')
                    sleep(5)
                    continue
        else:
            fns = io.expand_files(args.filename, validate=not args.no_validate)
            if fns:
                ds_raw = Dataset.from_files(fns, chunking=args.chunksize)
            else:
                print(f'\n---\n\nFile(s) {args.filename} not found or (all of them) invalid.')
                return

        seen_raw_files.extend(ds_raw.files)

        print('---- Have dataset ----')
        print(ds_raw)

        # delete undesired stacks
        delstacks = [sn for sn in ds_raw.stacks.keys() if sn != args.data_path_old.rsplit('/', 1)[-1]]
        for sn in delstacks:
            ds_raw.delete_stack(sn)

        if opts.aggregate:
            print('---- Aggregating raw data ----')
            ds_compute = ds_raw.aggregate(query=opts.agg_query,
                                          by=['sample', 'region', 'run', 'crystal_id'],
                                          how='sum', new_folder=opts.proc_dir,
                                          file_suffix=opts.agg_file_suffix)
        else:
            ds_compute = ds_raw.get_selection(query=opts.select_query, file_suffix=opts.agg_file_suffix)

        print('Initializing data files...')
        os.makedirs(opts.proc_dir, exist_ok=True)
        ds_compute.init_files(overwrite=True)

        print('Storing meta tables...')
        ds_compute.store_tables(shots=True, features=True)

        print(f'Processing diffraction data... monitor progress at {client.dashboard_link} '
              f'(or forward port if remote)')
        chunk_info = quick_proc(ds_compute, opts, label_raw, label, client)

        # make sure that the calculation went consistent with the data set
        for (sh, sh_grp), (ch, ch_grp) in zip(ds_compute.shots.groupby(['file', 'subset']),
                                              chunk_info.groupby(['file', 'subset'])):
            if any(sh_grp.shot_in_subset.values != np.sort(np.concatenate(ch_grp.shot_in_subset.values))):
                raise ValueError(f'Inconsistency between calculated data and shot list in {sh[0]}: {sh[1]} found. '
                                 f'Please investigate.')

        ds_compute.write_list(args.list_file, append=args.append)

        print(f'Computation done. Processed files are in {args.list_file}')

        if not args.wait_for_files:
            break
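
# Example command lines (the program name is a placeholder for whatever console entry point this
# module is installed as; file names are hypothetical). Any attribute of the options file can be
# overridden via the extra arguments parsed above, e.g. --com_threshold:
#
#   preprocess 'raw/*.nxs' -s preproc.yaml -l processed.lst
#   preprocess 'raw/*.nxs' -w --include-existing --com_threshold 0.9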