Example #1
    def __init__(self, args):

        super().__init__()
        self.dataset = Dataset()
        self.args = args
        self.data_path = None
        self.current_shot = pd.Series()
        self.diff_image = np.empty((0,0))
        self.map_image = np.empty((0,0))
        self.init_widgets()
        self.adxv = None
        self.geom = None

        self.read_files()
        self.switch_shot(0)

        if self.args.internal:
            self.hist_img.setLevels(np.quantile(self.diff_image, 0.02), np.quantile(self.diff_image, 0.98))

        self.update()

        self.show()
Example #2
    def read_files(self):

        file_type = args.filename.rsplit('.', 1)[-1]

        if file_type == 'stream':
            print(f'Parsing stream file {args.filename}...')
            stream = StreamParser(args.filename)
            # with open('tmp.geom', 'w') as fh:
            #     fh.write('\n'.join(stream._geometry_string))
            # self.geom = load_crystfel_geometry('tmp.geom')
            # os.remove('tmp.geom')
            # if len(self.geom['panels']) == 1:
            #     print('Single-panel geometry, so ignoring transforms for now.')
            #     #TODO make this more elegant, e.g. by overwriting image transform func with identity
            #     self.geom = None
            self.geom = None
            
            try:
                self.data_path = stream.geometry['data']
            except KeyError:
                if args.geometry is None:
                    raise ValueError('No data location specified in geometry file. Please use -d parameter.')

            files = sorted(list(stream.shots['file'].unique()))
            # print('Loading data files found in stream... \n', '\n'.join(files))
            try:
                self.dataset = Dataset.from_files(files, load_tables=False, init_stacks=False, open_stacks=False)
                self.dataset.load_tables(features=True)
                # print(self.dataset.shots.columns)
                self.dataset.merge_stream(stream)
                # get_selection would not be the right method to call (changes IDs), instead do...
                self.dataset._shots = self.dataset._shots.loc[self.dataset._shots.selected,:].reset_index(drop=True)
                # TODO get subset for incomplete coverage
                print('Merged stream and hdf5 shot lists')
            except Exception as err:
                self.dataset = Dataset()
                self.dataset._shots = stream.shots
                self.dataset._peaks = stream.peaks
                self.dataset._predict = stream.indexed
                self.dataset._shots['selected'] = True
                print('Could not load shot lists from the HDF5 files, but got them from the stream file.')
                print(f'Reason: {err}')

        if args.geometry is not None:
            raise ValueError('Geometry files are currently not supported.')
            # self.geom = load_crystfel_geometry(args.geometry)

        if file_type in ['lst', 'h5', 'hdf', 'nxs']:
            self.dataset = Dataset.from_list(args.filename, load_tables=True, init_stacks=False, open_stacks=False)
            if not self.dataset.shots.selected.all():
                # dirty removal of unwanted shots is sufficient in this case:
                self.dataset._shots = self.dataset._shots.loc[self.dataset._shots.selected,:].reset_index(drop=True)

        if args.data_path is not None:
            self.data_path = args.data_path

        if self.data_path is None:
            # Data path was set neither via the stream file nor explicitly, so we have to guess.
            try:
                with h5py.File(self.dataset.shots.file.iloc[0], 'r') as fh:
                    base = '/%/data'.replace('%', self.dataset.shots.subset.iloc[0])
                    self.data_path = '/%/data/' + fh[base].attrs['signal']
                print('Found data path', self.data_path)
            except Exception as err:
                warn(str(err), RuntimeWarning)
                print('Could not determine the data path. Assuming /%/data/raw_counts')
                self.data_path = '/%/data/raw_counts'

        if self.args.query:
            print('Only showing shots with', self.args.query)
            #self.dataset.select(self.args.query)
            #self.dataset = self.dataset.get_selection(self.args.query, file_suffix=None, reset_id=False)
            #print('cutting shot list only')
            self.dataset._shots = self.dataset._shots.query(args.query)

        if self.args.sort_crystals:
            print('Re-sorting shots by sample/region/crystal/run.')
            self.dataset._shots = self.dataset._shots.sort_values(by=['sample', 'region', 'crystal_id', 'run'])

        if not self.args.internal:
            #adxv_args = {'wavelength': 0.0251, 'distance': 2280, 'pixelsize': 0.055}
            adxv_args = {}
            self.adxv = Adxv(hdf5_path=self.data_path.replace('%', 'entry'),
                             adxv_bin=self.args.adxv_bin, **adxv_args)

        self.b_goto.setMaximum(self.dataset.shots.shape[0]-1)
        self.b_goto.setMinimum(0)
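
read_files() takes no explicit parameters; everything comes from the parsed command-line arguments, partly via self.args and partly via a module-level args object. A minimal sketch of the fields it relies on follows; the field names are taken from the code above, while the concrete values are purely illustrative assumptions.

from argparse import Namespace

# Illustrative argument namespace (values are assumptions, names come from read_files):
args = Namespace(
    filename='run0001.stream',  # stream, list or HDF5 file to load (hypothetical name)
    geometry=None,              # must stay None -- geometry files are rejected above
    data_path=None,             # explicit HDF5 data location, e.g. '/%/data/raw_counts'
    query='',                   # optional pandas query applied to the shot table
    sort_crystals=False,        # re-sort shots by sample/region/crystal/run
    internal=True,              # use the built-in viewer instead of an external adxv
    adxv_bin='adxv',            # adxv executable, only needed when internal is False
)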
Example #3
def cumulate(fn, opt: PreProcOpts):
    """Applies cumulative summation to a data set comprising movie frame stacks. At the moment, requires
    the summed frame stacks to have the same shape as the raw data.
    
    Arguments:
        fn {str or list} -- file, list file or list of input files
        opt {PreProcOpts} -- pre-processing options
    
    Raises:
        err: re-raised if writing the cumulated stacks fails
    
    Returns:
        list -- output file names
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - cumulate] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    dssel = Dataset().from_list(fn)
    log('Cumulating from frame', opt.cum_first_frame)
    dssel.open_stacks(readonly=False)

    # chunks for aggregation
    chunks = tuple(
        dssel.shots.groupby(opt.idfields).count()['selected'].values)
    for k, stk in dssel.stacks.items():
        if stk.chunks[0] != chunks:
            if k == 'index':
                continue
            log(k, 'needs rechunking...')
            dssel.add_stack(k, stk.rechunk({0: chunks}), overwrite=True)
    dssel._zchunks = chunks

    def cumfunc(movie):
        movie_out = movie
        movie_out[opt.cum_first_frame:,
                  ...] = np.cumsum(movie[opt.cum_first_frame:, ...], axis=0)
        return movie_out

    for k in opt.cum_stacks:
        dssel.stacks[k] = dssel.stacks[k].map_blocks(
            cumfunc, dtype=dssel.stacks[k].dtype)

    dssel.change_filenames(opt.cum_file_suffix)
    dssel.init_files(overwrite=True, keep_features=False)
    log('File initialized, writing tables...')
    dssel.store_tables(shots=True, features=True)

    try:
        dssel.open_stacks(readonly=False)
        log('Writing stack data...')
        dssel.store_stacks(overwrite=True, progress_bar=False)

    except Exception as err:
        log('Cumulative processing failed.')
        raise err

    finally:
        dssel.close_stacks()
        log('Cumulation done.')

    return dssel.files
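
A minimal usage sketch for cumulate(), assuming PreProcOpts is configured from a YAML file as in the command-line tool of Example #9; the input file name is hypothetical.

# Hypothetical call; 'run0001_agg.lst' stands for a list of aggregated movie files.
opts = PreProcOpts('preproc.yaml')  # provides cum_first_frame, cum_stacks, cum_file_suffix, ...
cum_files = cumulate('run0001_agg.lst', opts)
print('Cumulative-sum files written:', cum_files)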
Example #4
class EDViewer(QWidget):

    def __init__(self, args):

        super().__init__()
        self.dataset = Dataset()
        self.args = args
        self.data_path = None
        self.current_shot = pd.Series()
        self.diff_image = np.empty((0,0))
        self.map_image = np.empty((0,0))
        self.init_widgets()
        self.adxv = None
        self.geom = None

        self.read_files()
        self.switch_shot(0)

        if self.args.internal:
            self.hist_img.setLevels(np.quantile(self.diff_image, 0.02), np.quantile(self.diff_image, 0.98))

        self.update()

        self.show()

    def closeEvent(self, a0: QtGui.QCloseEvent) -> None:
        if not self.args.internal:
            self.adxv.exit()
        a0.accept()

    def read_files(self):

        file_type = args.filename.rsplit('.', 1)[-1]

        if file_type == 'stream':
            print(f'Parsing stream file {args.filename}...')
            stream = StreamParser(args.filename)
            # with open('tmp.geom', 'w') as fh:
            #     fh.write('\n'.join(stream._geometry_string))
            # self.geom = load_crystfel_geometry('tmp.geom')
            # os.remove('tmp.geom')
            # if len(self.geom['panels']) == 1:
            #     print('Single-panel geometry, so ignoring transforms for now.')
            #     #TODO make this more elegant, e.g. by overwriting image transform func with identity
            #     self.geom = None
            self.geom = None
            
            try:
                self.data_path = stream.geometry['data']
            except KeyError:
                if args.geometry is None:
                    raise ValueError('No data location specified in geometry file. Please use -d parameter.')

            files = sorted(list(stream.shots['file'].unique()))
            # print('Loading data files found in stream... \n', '\n'.join(files))
            try:
                self.dataset = Dataset.from_files(files, load_tables=False, init_stacks=False, open_stacks=False)
                self.dataset.load_tables(features=True)
                # print(self.dataset.shots.columns)
                self.dataset.merge_stream(stream)
                # get_selection would not be the right method to call (changes IDs), instead do...
                self.dataset._shots = self.dataset._shots.loc[self.dataset._shots.selected,:].reset_index(drop=True)
                # TODO get subset for incomplete coverage
                print('Merged stream and hdf5 shot lists')
            except Exception as err:
                self.dataset = Dataset()
                self.dataset._shots = stream.shots
                self.dataset._peaks = stream.peaks
                self.dataset._predict = stream.indexed
                self.dataset._shots['selected'] = True
                print('Could not load shot lists from the HDF5 files, but got them from the stream file.')
                print(f'Reason: {err}')

        if args.geometry is not None:
            raise ValueError('Geometry files are currently not supported.')
            # self.geom = load_crystfel_geometry(args.geometry)

        if file_type in ['lst', 'h5', 'hdf', 'nxs']:
            self.dataset = Dataset.from_list(args.filename, load_tables=True, init_stacks=False, open_stacks=False)
            if not self.dataset.shots.selected.all():
                # dirty removal of unwanted shots is sufficient in this case:
                self.dataset._shots = self.dataset._shots.loc[self.dataset._shots.selected,:].reset_index(drop=True)

        if args.data_path is not None:
            self.data_path = args.data_path

        if self.data_path is None:
            # Data path was set neither via the stream file nor explicitly, so we have to guess.
            try:
                with h5py.File(self.dataset.shots.file.iloc[0], 'r') as fh:
                    base = '/%/data'.replace('%', self.dataset.shots.subset.iloc[0])
                    self.data_path = '/%/data/' + fh[base].attrs['signal']
                print('Found data path', self.data_path)
            except Exception as err:
                warn(str(err), RuntimeWarning)
                print('Could not determine the data path. Assuming /%/data/raw_counts')
                self.data_path = '/%/data/raw_counts'

        if self.args.query:
            print('Only showing shots with', self.args.query)
            #self.dataset.select(self.args.query)
            #self.dataset = self.dataset.get_selection(self.args.query, file_suffix=None, reset_id=False)
            #print('cutting shot list only')
            self.dataset._shots = self.dataset._shots.query(args.query)

        if self.args.sort_crystals:
            print('Re-sorting shots by sample/region/crystal/run.')
            self.dataset._shots = self.dataset._shots.sort_values(by=['sample', 'region', 'crystal_id', 'run'])

        if not self.args.internal:
            #adxv_args = {'wavelength': 0.0251, 'distance': 2280, 'pixelsize': 0.055}
            adxv_args = {}
            self.adxv = Adxv(hdf5_path=self.data_path.replace('%', 'entry'),
                             adxv_bin=self.args.adxv_bin, **adxv_args)

        self.b_goto.setMaximum(self.dataset.shots.shape[0]-1)
        self.b_goto.setMinimum(0)

    def update_image(self):
        print(self.current_shot)
        with h5py.File(self.current_shot['file'], mode='r') as f:

            if self.args.internal:
                path = self.data_path.replace('%', self.current_shot.subset)
                print('Loading {}:{} from {}'.format(path,
                                                     self.current_shot['shot_in_subset'], self.current_shot['file']))                
                if len(f[path].shape) == 3:
                    self.diff_image = f[path][int(self.current_shot['shot_in_subset']), ...]
                elif len(f[path].shape) == 2:
                    self.diff_image = f[path][:]
                    
                self.diff_image[np.isnan(self.diff_image)] = 0
                self.hist_img.setHistogramRange(np.partition(self.diff_image.flatten(), 100)[100], np.partition(self.diff_image.flatten(), -100)[-100])

                levels = self.hist_img.getLevels()
                # levels = (max(levels[0], -1), levels[1])
                levels = (levels[0], levels[1])
                if self.geom is not None:
                    raise RuntimeError('This should not happen')
                    # self.diff_image = apply_geometry_to_data(self.diff_image, self.geom)
                self.img.setImage(self.diff_image, autoRange=False)
                
                self.img.setLevels(levels)
                self.hist_img.setLevels(levels[0], levels[1])

            if not self.args.no_map:
                try:
                    path = args.map_path.replace('%', self.current_shot['subset'])
                    self.map_image = f[path][...]
                    self.mapimg.setImage(self.map_image)
                except KeyError:
                    warn('No map found at {}!'.format(path), Warning)

        if not self.args.internal:
            self.adxv.load_image(self.current_shot.file)
            self.adxv.slab(self.current_shot.shot_in_subset + 1)

    def update_plot(self):

        allpk = []

        if self.b_peaks.isChecked():
            
            if (len(self.dataset.peaks) == 0) or args.cxi_peaks:
                path = args.cxi_peaks_path.replace('%', self.current_shot.subset)
                print('Loading CXI peaks of {}:{} from {}'.format(path,
                                                        self.current_shot['shot_in_subset'], self.current_shot['file']))     
                with h5py.File(self.current_shot.file, 'r') as fh:
                    ii = int(self.current_shot['shot_in_subset'])
                    Npk = fh[path + '/nPeaks'][ii]
                    x = fh[path + '/peakXPosRaw'][ii, :Npk]
                    y = fh[path + '/peakYPosRaw'][ii, :Npk]

                peaks = pd.DataFrame((x, y), index=['fs/px', 'ss/px']).T

            else:
                peaks = self.dataset.peaks.loc[(self.dataset.peaks.file == self.current_shot.file)
                                           & (self.dataset.peaks.Event == self.current_shot.Event),
                                           ['fs/px', 'ss/px']] - 0.5
                x = peaks.loc[:,'fs/px']
                y = peaks.loc[:,'ss/px']

            if self.geom is not None:
                raise RuntimeError('Someone set geom to something. This should not happen.')
                # print('Projecting peaks...')
                # maps = compute_visualization_pix_maps(self.geom)
                # x = maps.x[y.astype(int), x.astype(int)]
                # y = maps.y[y.astype(int), x.astype(int)]

            if self.args.internal:
                ring_pen = pg.mkPen('g', width=0.8)
                self.found_peak_canvas.setData(x, y,
                symbol='o', size=13, pen=ring_pen, brush=(0, 0, 0, 0), antialias=True)
            else:
                allpk.append(peaks.assign(group=0))

        else:
            self.found_peak_canvas.clear()

        if self.b_pred.isChecked() and (self.dataset.predict.shape[0] > 0):
            
            pred = self.dataset.predict.loc[(self.dataset.predict.file == self.current_shot.file)
                                           & (self.dataset.predict.Event == self.current_shot.Event),
                                           ['fs/px', 'ss/px']] - 0.5

            if self.geom is not None:
                raise RuntimeError('Someone set geom to not None. This should not happen.')
                # print('Projecting predictions...')
                # maps = compute_visualization_pix_maps(self.geom)
                # x = maps.x[pred.loc[:,'ss/px'].astype(int),
                #     pred.loc[:,'fs/px'].astype(int)]
                # y = maps.y[pred.loc[:,'ss/px'].astype(int),
                #     pred.loc[:,'fs/px'].astype(int)]
            else:
                x = pred.loc[:,'fs/px']
                y = pred.loc[:,'ss/px']

            if self.args.internal:
                square_pen = pg.mkPen('r', width=0.8)
                self.predicted_peak_canvas.setData(x, y,
                                              symbol='s', size=13, pen=square_pen, brush=(0, 0, 0, 0), antialias=True)
            else:
                allpk.append(pred.assign(group=1))

        else:
            self.predicted_peak_canvas.clear()

        if not self.args.internal and len(allpk) > 0:
            self.adxv.define_spot('green', 5, 0, 0)
            self.adxv.define_spot('red', 0, 10, 1)
            self.adxv.load_spots(pd.concat(allpk, axis=0, ignore_index=True).values)
        elif not self.args.internal:
            self.adxv.load_spots(np.empty((0,3)))

        if self.dataset.features.shape[0] > 0:
            ring_pen = pg.mkPen('g', width=2)
            dot_pen = pg.mkPen('y', width=0.5)

            region_feat = self.dataset.features.loc[(self.dataset.features['region'] == self.current_shot['region'])
                                               & (self.dataset.features['sample'] == self.current_shot['sample'])]
            
            print('Number of region features:', region_feat.shape[0])

            if self.current_shot['crystal_id'] != -1:
                single_feat = region_feat.loc[region_feat['crystal_id'] == self.current_shot['crystal_id'], :]
                x0 = single_feat['crystal_x'].squeeze()
                y0 = single_feat['crystal_y'].squeeze()
                if self.b_locations.isChecked():
                    self.found_features_canvas.setData(region_feat['crystal_x'], region_feat['crystal_y'],
                                                symbol='+', size=7, pen=dot_pen, brush=(0, 0, 0, 0), pxMode=True)
                else:
                    self.found_features_canvas.clear()

                if self.b_zoom.isChecked():
                    self.map_box.setRange(xRange=(x0 - 5 * args.beam_diam, x0 + 5 * args.beam_diam),
                                     yRange=(y0 - 5 * args.beam_diam, y0 + 5 * args.beam_diam))
                    self.single_feature_canvas.setData([x0], [y0],
                                                  symbol='o', size=args.beam_diam, pen=ring_pen,
                                                  brush=(0, 0, 0, 0), pxMode=False)
                    try:
                        c_real = np.cross([self.current_shot.astar_x, self.current_shot.astar_y, self.current_shot.astar_z],
                                          [self.current_shot.bstar_x, self.current_shot.bstar_y, self.current_shot.bstar_z])
                        b_real = np.cross([self.current_shot.cstar_x, self.current_shot.cstar_y, self.current_shot.cstar_z],
                                          [self.current_shot.astar_x, self.current_shot.astar_y, self.current_shot.astar_z])
                        a_real = np.cross([self.current_shot.bstar_x, self.current_shot.bstar_y, self.current_shot.bstar_z],
                                          [self.current_shot.cstar_x, self.current_shot.cstar_y, self.current_shot.cstar_z])
                        a_real = 20 * a_real / np.sum(a_real ** 2) ** .5
                        b_real = 20 * b_real / np.sum(b_real ** 2) ** .5
                        c_real = 20 * c_real / np.sum(c_real ** 2) ** .5
                        self.a_dir.setData(x=x0 + np.array([0, a_real[0]]), y=y0 + np.array([0, a_real[1]]))
                        self.b_dir.setData(x=x0 + np.array([0, b_real[0]]), y=y0 + np.array([0, b_real[1]]))
                        self.c_dir.setData(x=x0 + np.array([0, c_real[0]]), y=y0 + np.array([0, c_real[1]]))
                    except Exception:
                        print('Could not read lattice vectors.')
                else:
                    self.single_feature_canvas.setData([x0], [y0],
                                                  symbol='o', size=13, pen=ring_pen, brush=(0, 0, 0, 0), pxMode=True)
                    self.map_box.setRange(xRange=(0, self.map_image.shape[1]), yRange=(0, self.map_image.shape[0]))



            else:
                self.single_feature_canvas.setData([], [])

    def update(self):

        self.found_peak_canvas.clear()
        self.predicted_peak_canvas.clear()
        app.processEvents()

        self.update_image()   
        if args.cxi_peaks and not args.internal:
            # give adxv some time to display the image before accessing the CXI data
            sleep(0.2)
        self.update_plot()

        print(self.current_shot)

    # CALLBACK FUNCTIONS

    def switch_shot(self, shot_id=None):
        if shot_id is None:
            shot_id = self.b_goto.value()

        self.shot_id = max(0, shot_id % self.dataset.shots.shape[0])
        self.current_shot = self.dataset.shots.iloc[self.shot_id, :]
        self.meta_table.setRowCount(self.current_shot.shape[0])
        self.meta_table.setColumnCount(2)

        for row, (k, v) in enumerate(self.current_shot.items()):
            self.meta_table.setItem(row, 0, QTableWidgetItem(k))
            self.meta_table.setItem(row, 1, QTableWidgetItem(str(v)))

        self.meta_table.resizeRowsToContents()

        shot = self.current_shot
        title = {'sample': '', 'region': 'Reg', 'feature': 'Feat', 'frame': 'Frame', 'event': 'Ev', 'file': ''}
        titlestr = ''
        for k, v in title.items():
            titlestr += f'{v} {shot[k]}' if k in shot.keys() else ''
        titlestr += f' ({shot.name} of {self.dataset.shots.shape[0]})'
        print(titlestr)

        self.setWindowTitle(titlestr)

        self.b_goto.blockSignals(True)
        self.b_goto.setValue(self.shot_id)
        self.b_goto.blockSignals(False)

        self.update()

    def switch_shot_rel(self, shift):
        self.switch_shot(self.shot_id + shift)

    def mouse_moved(self, evt):
        mousePoint = self.img.mapFromDevice(evt[0])
        x, y = round(mousePoint.x()), round(mousePoint.y())
        x = min(max(0, x), self.diff_image.shape[1] - 1)
        y = min(max(0, y), self.diff_image.shape[0] - 1)
        I = self.diff_image[y, x]
        #print(x, y, I)
        self.info_text.setPos(x, y)
        self.info_text.setText(f'{x:0.1f}, {y:0.1f}: {I:0.1f}')

    def init_widgets(self):

        self.imageWidget = pg.GraphicsLayoutWidget()

        # IMAGE DISPLAY

        # A plot area (ViewBox + axes) for displaying the image
        self.image_box = self.imageWidget.addViewBox()
        self.image_box.setAspectLocked()

        self.img = pg.ImageItem()
        self.img.setZValue(0)
        self.image_box.addItem(self.img)
        self.proxy = pg.SignalProxy(self.img.scene().sigMouseMoved, rateLimit=60, slot=self.mouse_moved)

        self.found_peak_canvas = pg.ScatterPlotItem()
        self.image_box.addItem(self.found_peak_canvas)
        self.found_peak_canvas.setZValue(2)

        self.predicted_peak_canvas = pg.ScatterPlotItem()
        self.image_box.addItem(self.predicted_peak_canvas)
        self.predicted_peak_canvas.setZValue(2)

        self.info_text = pg.TextItem(text='')
        self.image_box.addItem(self.info_text)
        self.info_text.setPos(0, 0)

        # Contrast/color control
        self.hist_img = pg.HistogramLUTItem(self.img, fillHistogram=False)
        self.imageWidget.addItem(self.hist_img)

        # MAP DISPLAY

        self.map_widget = pg.GraphicsLayoutWidget()
        self.map_widget.setWindowTitle('region map')

        # Map image control
        self.map_box = self.map_widget.addViewBox()
        self.map_box.setAspectLocked()

        self.mapimg = pg.ImageItem()
        self.mapimg.setZValue(0)
        self.map_box.addItem(self.mapimg)

        self.found_features_canvas = pg.ScatterPlotItem()
        self.map_box.addItem(self.found_features_canvas)
        self.found_features_canvas.setZValue(2)

        self.single_feature_canvas = pg.ScatterPlotItem()
        self.map_box.addItem(self.single_feature_canvas)
        self.single_feature_canvas.setZValue(2)

        # lattice vectors
        self.a_dir = pg.PlotDataItem(pen=pg.mkPen('r', width=1))
        self.b_dir = pg.PlotDataItem(pen=pg.mkPen('g', width=1))
        self.c_dir = pg.PlotDataItem(pen=pg.mkPen('b', width=1))
        self.map_box.addItem(self.a_dir)
        self.map_box.addItem(self.b_dir)
        self.map_box.addItem(self.c_dir)

        # Contrast/color control
        self.hist_map = pg.HistogramLUTItem(self.mapimg)
        self.map_widget.addItem(self.hist_map)

        ### CONTROl BUTTONS

        b_rand = QPushButton('rnd')
        b_plus10 = QPushButton('+10')
        b_minus10 = QPushButton('-10')
        b_last = QPushButton('last')
        self.b_peaks = QCheckBox('peaks')
        self.b_pred = QCheckBox('crystal')
        self.b_zoom = QCheckBox('zoom')
        self.b_locations = QCheckBox('locations')
        self.b_locations.setChecked(True)
        b_reload = QPushButton('reload')
        self.b_goto = QSpinBox()

        b_rand.clicked.connect(lambda: self.switch_shot(np.random.randint(0, self.dataset.shots.shape[0])))
        b_plus10.clicked.connect(lambda: self.switch_shot_rel(+10))
        b_minus10.clicked.connect(lambda: self.switch_shot_rel(-10))
        b_last.clicked.connect(lambda: self.switch_shot(self.dataset.shots.index.max()))
        self.b_peaks.stateChanged.connect(self.update)
        self.b_pred.stateChanged.connect(self.update)
        self.b_zoom.stateChanged.connect(self.update)
        self.b_locations.stateChanged.connect(self.update)
        b_reload.clicked.connect(lambda: self.read_files())
        self.b_goto.valueChanged.connect(lambda: self.switch_shot(None))

        self.button_layout = QtGui.QGridLayout()
        self.button_layout.addWidget(b_plus10, 0, 2)
        self.button_layout.addWidget(b_minus10, 0, 1)
        self.button_layout.addWidget(b_rand, 0, 4)
        self.button_layout.addWidget(b_last, 0, 3)
        self.button_layout.addWidget(self.b_goto, 0, 0)
        self.button_layout.addWidget(b_reload, 0, 10)
        self.button_layout.addWidget(self.b_peaks, 0, 21)
        self.button_layout.addWidget(self.b_pred, 0, 22)
        self.button_layout.addWidget(self.b_zoom, 0, 23)
        self.button_layout.addWidget(self.b_locations, 0, 24)

        self.meta_table = QTableWidget()
        self.meta_table.verticalHeader().setVisible(False)
        self.meta_table.horizontalHeader().setVisible(False)
        self.meta_table.setFont(QtGui.QFont('Helvetica', 10))

        # --- TOP-LEVEL ARRANGEMENT
        self.top_layout = QGridLayout()
        self.setLayout(self.top_layout)

        if self.args.internal:
            self.top_layout.addWidget(self.imageWidget, 0, 0)
            self.top_layout.setColumnStretch(0, 4)

        if not self.args.no_map:
            self.top_layout.addWidget(self.map_widget, 0, 1)
            self.top_layout.setColumnStretch(1, 3)  # stretch factors must be integers
            
        self.top_layout.addWidget(self.meta_table, 0, 2)
        self.top_layout.addLayout(self.button_layout, 1, 0, 1, 3)
        
        self.top_layout.setColumnStretch(2, 0)
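
EDViewer relies on module-level app and args objects (the Qt application and the parsed argument namespace), as used in update() and read_files(). The bottom of the script presumably launches it roughly like the sketch below; parse_args() is a hypothetical stand-in for the script's actual argument parsing.

# Hedged launch sketch -- only works inside the script that defines EDViewer,
# since the class looks up 'app' and 'args' in its own module's global namespace.
from PyQt5.QtWidgets import QApplication
import sys

app = QApplication(sys.argv)   # module-level 'app' used by EDViewer.update()
args = parse_args()            # hypothetical helper returning the argparse namespace
viewer = EDViewer(args)        # reads the files and shows the window
sys.exit(app.exec_())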
Example #5
def broadcast(fn, opt: PreProcOpts):
    """Pre-processes in one go a dataset comprising movie frames, by transferring the found beam center
    positions and diffraction spots (in CXI format) from an aggregated set processed earlier.
    
    Arguments:
        fn {function} -- [description]
        opt {PreProcOpts} -- [description]
    
    Raises:
        err: [description]
    
    Returns:
        [type] -- [description]
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any(isinstance(e, Exception) for e in args)):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - broadcast] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    t0 = time()
    dsagg = Dataset.from_list(fn, load_tables=False)
    dsraw = Dataset.from_list(list(dsagg.shots.file_raw.unique()))
    dsraw.shots['file_raw'] = dsraw.shots[
        'file']  # required for association later

    reference = imread(opt.reference)
    pxmask = imread(opt.pxmask)

    # And now: the interesting part...
    dsagg.shots['from_id'] = range(
        dsagg.shots.shape[0])  # label original (aggregated shots)

    in_agg = dsraw.shots[opt.idfields].merge(
        dsagg.shots[opt.idfields + ['selected']], on=opt.idfields,
        how='left')['selected'].fillna(False)
    dsraw.shots['selected'] = in_agg

    try:

        dsraw.open_stacks(readonly=True)
        dssel = dsraw.get_selection(f'({opt.select_query}) and selected',
                                    file_suffix=opt.single_suffix,
                                    new_folder=opt.proc_dir)
        shots = dssel.shots.merge(dsagg.shots[opt.idfields + ['from_id']],
                                  on=opt.idfields,
                                  validate='m:1')  # _inner_ merge

        log(f'{dsraw.shots.shape[0]} raw, {dssel.shots.shape[0]} selected, {dsagg.shots.shape[0]} aggregated.'
            )

        # get the broadcasted image centers
        dsagg.open_stacks(readonly=True)
        ctr = dsagg.stacks[opt.center_stack][shots.from_id.values, :]

        # Flat-field and dead-pixel correction
        stack_rechunked = dssel.raw_counts.rechunk(
            {0: ctr.chunks[0]})  # select and re-chunk the raw data
        if opt.correct_saturation:
            stack_ff = proc2d.apply_flatfield(
                proc2d.apply_saturation_correction(stack_rechunked,
                                                   opt.shutter_time,
                                                   opt.dead_time), reference)
        else:
            stack_ff = proc2d.apply_flatfield(stack_rechunked, reference)
        stack = proc2d.correct_dead_pixels(stack_ff,
                                           pxmask,
                                           strategy='replace',
                                           replace_val=-1,
                                           mask_gaps=True)
        centered = proc2d.center_image(stack,
                                       ctr[:, 0],
                                       ctr[:, 1],
                                       opt.xsize,
                                       opt.ysize,
                                       -1,
                                       parallel=True)

        # add the new stacks to the aggregated dataset
        alldata = {
            'center_of_mass':
            dsagg.stacks['center_of_mass'][shots.from_id.values, ...],
            'lorentz_fit':
            dsagg.stacks['lorentz_fit'][shots.from_id.values, ...],
            'beam_center':
            ctr,
            'centered':
            centered.astype(np.float32)
            if opt.float else centered.astype(np.int16),
            'pxmask_centered': (centered != -1).astype(np.uint16),
            'adf1':
            proc2d.apply_virtual_detector(centered, opt.r_adf1[0],
                                          opt.r_adf1[1]),
            'adf2':
            proc2d.apply_virtual_detector(centered, opt.r_adf2[0],
                                          opt.r_adf2[1])
        }

        if opt.broadcast_peaks:
            alldata.update({
                'nPeaks':
                dsagg.stacks['nPeaks'][shots.from_id.values, ...],
                'peakTotalIntensity':
                dsagg.stacks['peakTotalIntensity'][shots.from_id.values, ...],
                'peakXPosRaw':
                dsagg.stacks['peakXPosRaw'][shots.from_id.values, ...],
                'peakYPosRaw':
                dsagg.stacks['peakYPosRaw'][shots.from_id.values, ...],
            })

        for lbl, stk in alldata.items():
            dssel.add_stack(lbl, stk, overwrite=True)

        dssel.init_files(overwrite=True)
        dssel.store_tables(shots=True, features=True)
        dssel.open_stacks(readonly=False)
        dssel.delete_stack(
            'raw_counts',
            from_files=False)  # we don't need the raw counts in the new files
        dssel.store_stacks(
            overwrite=True,
            progress_bar=False)  # this does the actual calculation
        log('Finished with', dssel.centered.shape[0], 'shots after',
            time() - t0, 'seconds')

    except Exception as err:
        log('Broadcast processing failed:', err)
        raise err

    finally:
        dsagg.close_stacks()
        dsraw.close_stacks()
        dssel.close_stacks()

    return dssel.files
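
A usage sketch for broadcast(), with hypothetical file names; the options object must provide the fields referenced above (reference, pxmask, idfields, center_stack, select_query, single_suffix, proc_dir, ...).

opts = PreProcOpts('preproc.yaml')
# 'run0001_agg.lst' is a hypothetical list of aggregated files; their centers and
# CXI peaks are transferred back onto the corresponding raw movie frames.
single_files = broadcast('run0001_agg.lst', opts)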
Example #6
def subtract_bg(fn, opt: PreProcOpts):
    """Subtracts the background of a diffraction pattern by azimuthal integration excluding the Bragg peaks.
    
    Arguments:
        fn {str or list} -- file, list file or list of input files
        opt {PreProcOpts} -- pre-processing options
    
    Returns:
        list -- output file names
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any(isinstance(e, Exception) for e in args)):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - subtract_bg] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    ds = Dataset().from_list(fn)
    ds.open_stacks(readonly=False)

    if opt.rerun_peak_finder:
        pks = find_peaks(ds, opt=opt)
        nPeaks = da.from_array(pks['nPeaks'][:, np.newaxis, np.newaxis],
                               chunks=(ds.centered.chunks[0], 1, 1))
        peakX = da.from_array(pks['peakXPosRaw'][:, :, np.newaxis],
                              chunks=(ds.centered.chunks[0], -1, 1))
        peakY = da.from_array(pks['peakYPosRaw'][:, :, np.newaxis],
                              chunks=(ds.centered.chunks[0], -1, 1))
    else:
        nPeaks = ds.nPeaks[:, np.newaxis, np.newaxis]
        peakX = ds.peakXPosRaw[:, :, np.newaxis]
        peakY = ds.peakYPosRaw[:, :, np.newaxis]

    original = ds.centered
    bg_corrected = da.map_blocks(proc2d.remove_background,
                                 original,
                                 original.shape[2] / 2,
                                 original.shape[1] / 2,
                                 nPeaks,
                                 peakX,
                                 peakY,
                                 peak_radius=opt.peak_radius,
                                 filter_len=opt.filter_len,
                                 dtype=np.float32 if opt.float else np.int32,
                                 chunks=original.chunks)

    ds.add_stack('centered', bg_corrected, overwrite=True)
    ds.change_filenames(opt.nobg_file_suffix)
    ds.init_files(keep_features=False, overwrite=True)
    ds.store_tables(shots=True, features=True)
    ds.open_stacks(readonly=False)

    # for lbl in ['nPeaks', 'peakTotalIntensity', 'peakXPosRaw', 'peakYPosRaw']:
    #    if lbl in ds.stacks:
    #        ds.delete_stack(lbl, from_files=False)

    try:
        ds.store_stacks(overwrite=True, progress_bar=False)
    except Exception as err:
        log('Error during background correction:', err)
        raise err
    finally:
        ds.close_stacks()

    return ds.files
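
A corresponding usage sketch for subtract_bg(), again with a hypothetical input list.

opts = PreProcOpts('preproc.yaml')  # provides peak_radius, filter_len, nobg_file_suffix, ...
nobg_files = subtract_bg('run0001_all.lst', opts)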
Example #7
def refine_center(fn, opt: PreProcOpts):
    """Refines the centering of diffraction patterns based on Friedel mate positions.
    
    Arguments:
        fn {str} -- [file/list name of input files, can contain wildcards]
        opt {PreProcOpts} -- [pre-processing options]
    
    Raises:
        err: re-raised if storing the center-refined images fails
    
    Returns:
        [list] -- [output files]
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any(isinstance(e, Exception) for e in args)):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - refine_center] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    ds = Dataset.from_list(fn)
    stream = find_peaks(ds,
                        opt=opt,
                        merge_peaks=False,
                        return_cxi=False,
                        geo_params={'clen': opt.cam_length},
                        exc=opt.im_exc)

    p0 = [opt.xsize // 2, opt.ysize // 2]

    # get Friedel-refined center from stream file
    ctr = proc_peaks.center_friedel(stream.peaks,
                                    ds.shots,
                                    p0=p0,
                                    sigma=opt.peak_sigma,
                                    minpeaks=opt.min_peaks,
                                    maxres=opt.friedel_max_radius)

    maxcorr = ctr['friedel_cost'].values
    changed = np.logical_not(np.isnan(maxcorr))

    with ds.Stacks() as stk:
        beam_center_old = stk['beam_center'].compute()  # previous beam center

    beam_center_new = beam_center_old.copy()
    beam_center_new[changed, :] = np.ceil(beam_center_old[
        changed, :]) + ctr.loc[changed, ['beam_x', 'beam_y']].values - p0
    if (np.abs(np.mean(beam_center_new - beam_center_old, axis=0)) > .5).any():
        log('WARNING: average shift is larger than 0.5!')

    # visualization
    log(
        '{:g}% of shots refined. \n'.format(
            (1 - np.isnan(maxcorr).sum() / len(maxcorr)) * 100),
        'Shift standard deviation: {} \n'.format(
            np.std(beam_center_new - beam_center_old,
                   axis=0)), 'Average shift: {} \n'.format(
                       np.mean(beam_center_new - beam_center_old, axis=0)))

    # make new files and add the shifted images
    try:
        ds.open_stacks(readonly=False)
        centered2 = proc2d.center_image(ds.centered, ctr['beam_x'].values,
                                        ctr['beam_y'].values, 1556, 616, -1)
        ds.add_stack('centered', centered2, overwrite=True)
        ds.add_stack('pxmask_centered', (centered2 != -1).astype(np.uint16),
                     overwrite=True)
        ds.add_stack('beam_center', beam_center_new, overwrite=True)
        ds.change_filenames(opt.refined_file_suffix)
        print(ds.files)
        ds.init_files(keep_features=False, overwrite=True)
        ds.store_tables(shots=True, features=True)
        ds.open_stacks(readonly=False)
        ds.store_stacks(overwrite=True, progress_bar=False)
        ds.close_stacks()
        del centered2
    except Exception as err:
        log('Error during storing center-refined images', err)
        raise err
    finally:
        ds.close_stacks()

    # run peak finder again, this time on the refined images
    pks_cxi = find_peaks(ds,
                         opt=opt,
                         merge_peaks=opt.peaks_nexus,
                         return_cxi=True,
                         geo_params={'clen': opt.cam_length},
                         exc=opt.im_exc)

    # export peaks to CXI-format arrays
    if opt.peaks_cxi:
        with ds.Stacks() as stk:
            for k, v in pks_cxi.items():
                if k in stk:
                    ds.delete_stack(k, from_files=True)
                ds.add_stack(k, v, overwrite=True)
            ds.store_stacks(list(pks_cxi.keys()),
                            progress_bar=True,
                            overwrite=True)

    return ds.files
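
A usage sketch for refine_center(), with a hypothetical input list; the relevant option fields are named in the code above.

opts = PreProcOpts('preproc.yaml')  # peak_sigma, min_peaks, friedel_max_radius, cam_length, ...
refined_files = refine_center('run0001_all.lst', opts)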
Example #8
def from_raw(fn, opt: PreProcOpts):
    """Initial processing of raw movie data: aggregation or selection, flat-field, saturation and
    dead-pixel correction, beam-center determination (COM followed by a Lorentzian fit), image
    centering, and virtual ADF detectors.
    
    Arguments:
        fn {str or list} -- file, list file or list of raw input files
        opt {PreProcOpts} -- pre-processing options
    
    Returns:
        list -- output file names
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        if not (opt.verbose or any(isinstance(e, Exception) for e in args)):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - from_raw] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    t0 = time()
    dsraw = Dataset().from_list(fn)

    reference = imread(opt.reference)
    pxmask = imread(opt.pxmask)

    os.makedirs(opt.scratch_dir, exist_ok=True)
    os.makedirs(opt.proc_dir, exist_ok=True)
    dsraw.open_stacks(readonly=True)

    if opt.aggregate:
        dsagg = dsraw.aggregate(file_suffix=opt.agg_file_suffix,
                                new_folder=opt.proc_dir,
                                force_commensurate=False,
                                how={'raw_counts': 'sum'},
                                query=opt.agg_query)
    else:
        dsagg = dsraw.get_selection(opt.agg_query,
                                    new_folder=opt.proc_dir,
                                    file_suffix=opt.agg_file_suffix)

    log(f'{dsraw.shots.shape[0]} raw, {dsagg.shots.shape[0]} aggregated/selected.'
        )

    if opt.rechunk is not None:
        dsagg.rechunk_stacks(opt.rechunk)

    # Saturation, flat-field and dead-pixel correction
    if opt.correct_saturation:
        stack_ff = proc2d.apply_flatfield(
            proc2d.apply_saturation_correction(dsagg.raw_counts,
                                               opt.shutter_time,
                                               opt.dead_time), reference)
    else:
        stack_ff = proc2d.apply_flatfield(dsagg.raw_counts, reference)

    stack = proc2d.correct_dead_pixels(stack_ff,
                                       pxmask,
                                       strategy='replace',
                                       replace_val=-1,
                                       mask_gaps=True)

    # Sub-stack in the central region along x (note that the gaps are not masked this time)
    xrng = slice((opt.xsize - opt.com_xrng) // 2,
                 (opt.xsize + opt.com_xrng) // 2)
    stack_ct = proc2d.correct_dead_pixels(stack_ff[:, :, xrng],
                                          pxmask[:, xrng],
                                          strategy='replace',
                                          replace_val=-1,
                                          mask_gaps=False)

    # Define COM threshold as a fraction of the highest pixel (after discarding the 10 highest as outliers)
    thr = stack_ct.max(axis=1).topk(10, axis=1)[:, 9].reshape(
        (-1, 1, 1)) * opt.com_threshold
    com = proc2d.center_of_mass2(
        stack_ct, threshold=thr) + [[(opt.xsize - opt.com_xrng) // 2, 0]]

    # Lorentzian fit in region around the found COM
    lorentz = compute.map_reduction_func(stack,
                                         proc2d.lorentz_fast,
                                         com[:, 0],
                                         com[:, 1],
                                         radius=opt.lorentz_radius,
                                         limit=opt.lorentz_maxshift,
                                         scale=7,
                                         threads=False,
                                         output_len=4)
    ctr = lorentz[:, 1:3]

    # calculate the centered image by shifting and padding with -1
    centered = proc2d.center_image(stack,
                                   ctr[:, 0],
                                   ctr[:, 1],
                                   opt.xsize,
                                   opt.ysize,
                                   -1,
                                   parallel=True)

    # add the new stacks to the aggregated dataset
    alldata = {
        'center_of_mass':
        com,
        'lorentz_fit':
        lorentz,
        'beam_center':
        ctr,
        'centered':
        centered,
        'pxmask_centered': (centered != -1).astype(np.uint16),
        'adf1':
        proc2d.apply_virtual_detector(centered, opt.r_adf1[0], opt.r_adf1[1]),
        'adf2':
        proc2d.apply_virtual_detector(centered, opt.r_adf2[0], opt.r_adf2[1])
    }
    for lbl, stk in alldata.items():
        print('adding', lbl, stk.shape)
        dsagg.add_stack(lbl, stk, overwrite=True)

    # make the files and crunch the data
    try:
        dsagg.init_files(overwrite=True)
        dsagg.store_tables(shots=True, features=True)
        dsagg.open_stacks(readonly=False)
        dsagg.delete_stack(
            'raw_counts',
            from_files=False)  # we don't need the raw counts in the new files
        dsagg.store_stacks(
            overwrite=True,
            progress_bar=False)  # this does the actual calculation

        log('Finished first centering', dsagg.centered.shape[0], 'shots after',
            time() - t0, 'seconds')

    except Exception as err:
        log('Raw processing failed.', err)
        raise err

    finally:
        dsagg.close_stacks()
        dsraw.close_stacks()

    return dsagg.files
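
A usage sketch for from_raw(), with a hypothetical raw acquisition file.

opts = PreProcOpts('preproc.yaml')  # reference, pxmask, xsize, ysize, com_xrng, r_adf1, r_adf2, ...
agg_files = from_raw('run0001.nxs', opts)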
Example #9
def main():

    parser = argparse.ArgumentParser(
        description=
        'Quick and dirty pre-processing for Serial Electron Diffraction data',
        allow_abbrev=False,
        epilog=
        'Any other options are passed on as modification to the option file')
    parser.add_argument(
        'filename',
        type=str,
        nargs='*',
        help=
        'List or HDF5 file or glob pattern. Glob pattern must be given in SINGLE quotes.'
    )
    parser.add_argument('-s',
                        '--settings',
                        type=str,
                        help='Option YAML file. Defaults to \'preproc.yaml\'.',
                        default='preproc.yaml')
    parser.add_argument(
        '-A',
        '--address',
        type=str,
        help=
        'Address of an existing dask.distributed cluster to use instead of making a new one. Defaults to making a new one.',
        default=None)
    parser.add_argument(
        '-N',
        '--nprocs',
        type=int,
        help=
        'Number of processes of a new dask.distributed cluster. Defaults to letting dask decide.',
        default=None)
    parser.add_argument(
        '-L',
        '--local-directory',
        type=str,
        help=
        'Fast (scratch) directory for computations. Defaults to the current directory.',
        default=None)
    parser.add_argument(
        '-c',
        '--chunksize',
        type=int,
        help=
        'Chunk size of raw data stack. Should be integer multiple of movie stack frames! Defaults to 100.',
        default=100)
    parser.add_argument('-l',
                        '--list-file',
                        type=str,
                        help='Name of output list file',
                        default='processed.lst')
    parser.add_argument('-w',
                        '--wait-for-files',
                        help='Wait for files matching wildcard pattern',
                        action='store_true')
    parser.add_argument(
        '--include-existing',
        help='When using -w/--wait-for-files, also include existing files',
        action='store_true')
    parser.add_argument('--append',
                        help='Append to list instead of overwrite',
                        action='store_true')
    parser.add_argument(
        '-d',
        '--data-path-old',
        type=str,
        help='Raw data field in HDF5 file(s). Defaults to /entry/data/raw_counts',
        default='/%/data/raw_counts')
    parser.add_argument(
        '-n',
        '--data-path-new',
        type=str,
        help=
        'Corrected data field in HDF5 file(s). Defaults to /entry/data/corrected',
        default='/%/data/corrected')
    parser.add_argument('--no-bgcorr',
                        help='Skip background correction',
                        action='store_true')
    parser.add_argument(
        '--no-validate',
        help='Do not validate files before attempting to process',
        action='store_true')
    # parser.add_argument('ppopt', nargs=argparse.REMAINDER, help='Preprocessing options to be overriden')

    args, extra = parser.parse_known_args()
    # print(args, extra)
    # raise RuntimeError('thus far!')
    opts = pre_proc_opts.PreProcOpts(args.settings)

    label_raw = args.data_path_old.rsplit('/', 1)[-1]
    label = args.data_path_new.rsplit('/', 1)[-1]

    if extra:
        # If extra arguments have been supplied, overwrite existing values
        opt_parser = argparse.ArgumentParser()
        for k, v in opts.__dict__.items():
            opt_parser.add_argument('--' + k, type=type(v), default=None)
        opts2 = opt_parser.parse_args(extra)

        for k, v in vars(opts2).items():
            if v is not None:
                if type(v) != type(opts.__dict__[k]):
                    warn('Mismatch of data types in overridden argument!',
                         RuntimeWarning)
                print(
                    f'Overriding option file setting {k} = {opts.__dict__[k]} ({type(opts.__dict__[k])}). ',
                    f'New value is {v} ({type(v)})')
                opts.__dict__[k] = v

    # raise RuntimeError('thus far!')
    print('Running on diffractem:', version())
    print('Current path is:', os.getcwd())

    # client = Client()
    if args.address is not None:
        print('Connecting to cluster scheduler at', args.address)

        try:
            client = Client(address=args.address, timeout=3)
        except:
            print(
                f'\n----\nThere seems to be no dask.distributed scheduler running at {args.address}.\n'
                'Please double-check the address, or start a new cluster by omitting the --address option.'
            )
            return
    else:
        print('Creating a dask.distributed cluster...')
        client = Client(n_workers=args.nprocs,
                        local_directory=args.local_directory,
                        processes=True)
        print('\n\n---\nStarted dask.distributed cluster:')
        print(client)
        print('You can access the dashboard for monitoring at: ',
              client.dashboard_link)

    client.run(os.chdir, os.getcwd())

    if len(args.filename) == 1:
        args.filename = args.filename[0]

    # print(args.filename)
    seen_raw_files = [] if args.include_existing else io.expand_files(
        args.filename)

    while True:

        if args.wait_for_files:

            # slightly awkward sequence to only open finished files... (but believe me - it works!)

            fns = io.expand_files(args.filename)
            # print(fns)
            fns = [fn for fn in fns if fn not in seen_raw_files]
            # validation...
            try:
                fns = io.expand_files(fns, validate=not args.no_validate)
            except (OSError, IOError, RuntimeError) as err:
                print(f'Could not open file(s) {" ".join(fns)} because of',
                      err)
                print(
                    'Possibly, it is still being written to. Waiting a bit...')
                sleep(5)
                continue

            if not fns:
                print('No new files, waiting a bit...')
                sleep(5)
                continue
            else:
                print('Found new file(s):\n', '\n'.join(fns))
                try:
                    ds_raw = Dataset.from_files(fns, chunking=args.chunksize)
                except Exception as err:
                    print(f'Could not open file(s) {" ".join(fns)} because of',
                          err)
                    print(
                        'Possibly, it is still being written to. Waiting a bit...'
                    )
                    sleep(5)
                    continue

        else:
            fns = io.expand_files(args.filename, validate=not args.no_validate)
            if fns:
                ds_raw = Dataset.from_files(fns, chunking=args.chunksize)
            else:
                print(
                    f'\n---\n\nFile(s) {args.filename} not found or (all of them) invalid.'
                )
                return

        seen_raw_files.extend(ds_raw.files)

        print('---- Have dataset ----')
        print(ds_raw)

        # delete undesired stacks
        delstacks = [
            sn for sn in ds_raw.stacks.keys()
            if sn != args.data_path_old.rsplit('/', 1)[-1]
        ]
        for sn in delstacks:
            ds_raw.delete_stack(sn)

        if opts.aggregate:
            print('---- Aggregating raw data ----')
            ds_compute = ds_raw.aggregate(
                query=opts.agg_query,
                by=['sample', 'region', 'run', 'crystal_id'],
                how='sum',
                new_folder=opts.proc_dir,
                file_suffix=opts.agg_file_suffix)
        else:
            ds_compute = ds_raw.get_selection(query=opts.select_query,
                                              file_suffix=opts.agg_file_suffix)

        print('Initializing data files...')
        os.makedirs(opts.proc_dir, exist_ok=True)
        ds_compute.init_files(overwrite=True)

        print('Storing meta tables...')
        ds_compute.store_tables(shots=True, features=True)

        print(
            f'Processing diffraction data... monitor progress at {client.dashboard_link} (or forward port if remote)'
        )
        chunk_info = quick_proc(ds_compute, opts, label_raw, label, client)

        # make sure that the calculation is consistent with the shot list of the data set
        for (sh, sh_grp), (ch, ch_grp) in zip(
                ds_compute.shots.groupby(['file', 'subset']),
                chunk_info.groupby(['file', 'subset'])):
            if any(sh_grp.shot_in_subset.values != np.sort(
                    np.concatenate(ch_grp.shot_in_subset.values))):
                raise ValueError(
                    f'Inconsistency between calculated data and shot list in {sh[0]}: {sh[1]} found. Please investigate.'
                )

        ds_compute.write_list(args.list_file, append=args.append)

        print(f'Computation done. Processed files are in {args.list_file}')

        if not args.wait_for_files:
            break
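
The script presumably ends with the usual entry-point guard; the command line in the comment is an illustrative assumption, not taken from the source.

if __name__ == '__main__':
    # e.g.:  python preprocess.py 'data/*.nxs' -s preproc.yaml -N 8 -w -l processed.lst
    main()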