Example #1
    def test_init_image(self):
        """Unit tests for initializing MarCCD from image"""

        # Initialize from image, no attributes provided
        mccd = marccd.MarCCD(self.testImage)
        self.assertFalse(np.array_equal(np.empty((0, 0)), mccd.image))
        self.assertEqual(basename(self.testImage), mccd.name)
        self.assertEqual(199.995, mccd.distance)
        self.assertEqual((1989.0, 1974.0), mccd.center)
        self.assertEqual((88.6, 88.6), mccd.pixelsize)
        self.assertEqual(1.0264, mccd.wavelength)
        self.assertNotEqual(b'\x00' * 3072, mccd._mccdheader)

        # Initialize from image, provide attributes to ensure they are
        # prioritized over MCCD header
        mccd = marccd.MarCCD(self.testImage,
                             name="name",
                             distance=200.0,
                             center=(1985.3, 1975.4),
                             pixelsize=(88.0, 88.0),
                             wavelength=1.0255)
        self.assertFalse(np.array_equal(np.empty((0, 0)), mccd.image))
        self.assertEqual("name", mccd.name)
        self.assertEqual(200.0, mccd.distance)
        self.assertEqual((1985.3, 1975.4), mccd.center)
        self.assertEqual((88.0, 88.0), mccd.pixelsize)
        self.assertEqual(1.0255, mccd.wavelength)
        self.assertNotEqual(b'\x00' * 3072, mccd._mccdheader)

        return
Example #2
    def test_init_empty(self):
        """Unit tests for MarCCD default empty constructor"""

        # Empty image, no attributes provided
        mccd = marccd.MarCCD()
        self.assertTrue(np.array_equal(np.empty((0, 0)), mccd.image))
        self.assertIsNone(mccd.name)
        self.assertIsNone(mccd.distance)
        self.assertIsNone(mccd.center)
        self.assertIsNone(mccd.pixelsize)
        self.assertIsNone(mccd.wavelength)
        self.assertEqual(b'\x00' * 3072, mccd._mccdheader)

        # Empty image, provide attributes
        mccd = marccd.MarCCD(name="name",
                             distance=200.0,
                             center=(1985.3, 1975.4),
                             pixelsize=(88.6, 88.6),
                             wavelength=1.0255)
        self.assertTrue(np.array_equal(np.empty((0, 0)), mccd.image))
        self.assertEqual("name", mccd.name)
        self.assertEqual(200.0, mccd.distance)
        self.assertEqual((1985.3, 1975.4), mccd.center)
        self.assertEqual((88.6, 88.6), mccd.pixelsize)
        self.assertEqual(1.0255, mccd.wavelength)
        self.assertEqual(b'\x00' * 3072, mccd._mccdheader)

        # Invalid data argument
        with self.assertRaises(ValueError):
            marccd.MarCCD(10)
        with self.assertRaises(ValueError):
            marccd.MarCCD(data=10)

        return
Example #3
    def test_repr(self):
        """Unit tests for Marccd.__repr__() method"""
        mccd = marccd.MarCCD()
        dims = mccd.dimensions
        self.assertTrue(
            f"<marccd.MarCCD with {dims[0]}x{dims[1]} pixels" in str(mccd))

        randimage = np.random.randint(0, 100, (500, 500), np.uint16)
        mccd = marccd.MarCCD(randimage)
        dims = mccd.dimensions
        self.assertTrue(
            f"<marccd.MarCCD with {dims[0]}x{dims[1]} pixels" in str(mccd))

        return
Example #4
def worker(wrk_num, header, imgfiles, queue, algorithm, ADCthresh, MinSNR,
           MinPixCount, MaxPixCount, LocalBGRadius, MinPeakSeparation, d_min,
           d_max):
    """
    Code taken from yamtbx/dataproc/myspotfinder/command_line/spot_finder_backend.py (New BSD License).
    """

    cheetah = CheetahSinglePanel()
    cheetah.set_params(ADCthresh=ADCthresh,
                       MinSNR=MinSNR,
                       MinPixCount=MinPixCount,
                       MaxPixCount=MaxPixCount,
                       LocalBGRadius=LocalBGRadius,
                       MinPeakSeparation=MinPeakSeparation,
                       Dmin=d_min,
                       Dmax=d_max)

    for f in imgfiles:
        startt = time.time()
        data = marccd.MarCCD(f).read_data()
        result = cheetah_worker(header, data, algorithm, cheetah)
        result["frame"] = f
        eltime = time.time() - startt
        #print "Wrkr%3d %6d done in %.2f msec " % (wrk_num, frame, eltime*1.e3)
        print "%s %6d %.2f" % (f, len(result["spots"]), eltime * 1.e3)
        queue.put(result)
Example #5
    def test_dimensions(self):
        """Unit tests for MarCCD dimensions attribute"""

        for dims in [(500, 500), (0, 0), (1000, 1300)]:
            mccd = marccd.MarCCD(np.zeros(dims, dtype=np.uint16))
            self.assertEqual(dims, mccd.dimensions)

        return
Example #6
    def test_init_ndarray(self):
        """Unit tests for initializing MarCCD from ndarray"""

        randimage = np.random.randint(low=0,
                                      high=(2**16) - 1,
                                      size=(500, 500),
                                      dtype=np.uint16)

        # ndarray image, no attributes provided
        mccd = marccd.MarCCD(randimage)
        self.assertEqual((500, 500), mccd.image.shape)
        self.assertIsNone(mccd.name)
        self.assertIsNone(mccd.distance)
        self.assertIsNone(mccd.center)
        self.assertIsNone(mccd.pixelsize)
        self.assertIsNone(mccd.wavelength)
        self.assertEqual(b'\x00' * 3072, mccd._mccdheader)

        # ndarray image, provide attributes
        mccd = marccd.MarCCD(randimage,
                             name="name",
                             distance=200.0,
                             center=(1985.3, 1975.4),
                             pixelsize=(88.6, 88.6),
                             wavelength=1.0255)
        self.assertEqual((500, 500), mccd.image.shape)
        self.assertEqual("name", mccd.name)
        self.assertEqual(200.0, mccd.distance)
        self.assertEqual((1985.3, 1975.4), mccd.center)
        self.assertEqual((88.6, 88.6), mccd.pixelsize)
        self.assertEqual(1.0255, mccd.wavelength)
        self.assertEqual(b'\x00' * 3072, mccd._mccdheader)

        # Providing a dtype other than np.uint16 should generate a warning
        randimage = np.random.randint(0, 100, (500, 500), np.int16)
        with self.assertWarns(UserWarning):
            mccd = marccd.MarCCD(randimage)

        return
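
As an aside, a minimal sketch (not part of the test suite): casting to np.uint16 before construction should avoid the dtype warning exercised above, assuming the values fit in the uint16 range.

    randimage = np.random.randint(0, 100, (500, 500), np.int16)
    mccd = marccd.MarCCD(randimage.astype(np.uint16))  # no UserWarning expected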
Example #7
    def test_readwrite(self):
        """Unit test for MarCCD reading and writing"""
        import filecmp, os

        # Test round trip leaves MCCD image unchanged
        mccd = marccd.MarCCD(self.testImage)
        datadir = dirname(self.testImage)
        temp = join(datadir, "temp.mccd")
        mccd.write(temp)
        self.assertTrue(filecmp.cmp(self.testImage, temp))
        os.remove(temp)

        return
Example #8
    def test_write(self):
        """Unit tests for mccd.write()"""

        mccdobj = marccd.MarCCD(self.testImage)
        datadir = dirname(self.testImage)
        temp = join(datadir, "temp.mccd")

        # _mccdheader attribute exists
        mccd.write(mccdobj, temp)
        self.assertTrue(os.path.exists(temp))
        os.remove(temp)

        # _mccdheader attribute does not exist
        mccdobj._mccdheader = None
        with self.assertRaises(AttributeError):
            mccd.write(mccdobj, temp)
        os.remove(temp)

        return
Example #9
def make_geom(f, geom_out, beam_x=None, beam_y=None, clen=None):
    im = marccd.MarCCD(f)

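    # "beam_x == beam_x" is False only when beam_x is NaN, so these guards skip NaN
    # (and falsy) overrides and keep the values read from the image header.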
    if beam_x and beam_x == beam_x: im.beam_x = beam_x  # in px
    if beam_y and beam_y == beam_y: im.beam_y = beam_y  # in px
    if clen and clen == clen: im.distance = clen  # in mm

    # According to XDS, the gain of /xustrg0/rayonix/2018A/LRE1/lre-1343-1/001/sample_000????.img is ~0.2.
    # A CCD's gain is usually underestimated (i.e., there are actually fewer photons) by a factor of ~4 (a=4 in XDS),
    # so 0.2*4=0.8 should be an appropriate value for adu_per_photon.

    s = """\
clen = %(clen)f
res = %(res).4f
data = /%%/data
photon_energy = /%%/photon_energy_ev

rigid_group_q0 = q0
rigid_group_collection_connected = q0
rigid_group_collection_independent = q0

q0/adu_per_eV = 0.8e-4 ; dummy for old hdfsee
q0/adu_per_photon = 0.8
q0/max_adu = %(max_adu)d
q0/min_fs = 0
q0/max_fs = %(fsmax)d
q0/min_ss = 0
q0/max_ss = %(ssmax)d
q0/corner_x = %(cornerx).2f
q0/corner_y = %(cornery).2f
q0/fs = -x
q0/ss = -y
""" % dict(fsmax=im.nfast - 1,
           ssmax=im.nslow - 1,
           cornerx=im.beam_x,
           cornery=im.beam_y,
           clen=im.distance / 1000.,
           res=1. / (im.pixel_x * 1.e-3),
           max_adu=im.saturated_value)

    open(geom_out, "w").write(s)
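
For reference, a minimal sketch of the unit conversions in the template above, using hypothetical header values (CrystFEL geometry takes clen in metres and res in pixels per metre):

    distance_mm = 200.0            # hypothetical im.distance, in mm
    pixel_mm = 0.0886              # hypothetical im.pixel_x, in mm
    clen = distance_mm / 1000.     # -> 0.2 m
    res = 1. / (pixel_mm * 1.e-3)  # -> ~11287 pixels per metre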
Example #10
def run(runid):
    img_files = sorted(
        glob.glob(os.path.join(rayonix_root, str(runid), "data_*.img")))
    img_num = map(lambda f: int(f[f.rindex("_") + 1:f.rindex(".")]), img_files)

    acq_time, header_time, save_time, ctime = [], [], [], []
    for f in img_files:
        img = marccd.MarCCD(f)
        acq_time.append(img.acquire_time)
        header_time.append(img.header_time)
        save_time.append(img.save_time)
        ctime.append(datetime.datetime.fromtimestamp(os.path.getctime(f)))

    #for i in xrange(1, len(img_files)):
    #    print img_num[i]-img_num[i-1], acq_time[i]-acq_time[i-1]

    print "run num file time.acq time.header time.save time.ctime"
    for i in xrange(len(img_files)):
        print "%d %5d %s" % (runid, img_num[i], img_files[i]),
        print acq_time[i].strftime('"%Y-%m-%d %H:%M:%S.%f"'),
        print header_time[i].strftime('"%Y-%m-%d %H:%M:%S.%f"'),
        print save_time[i].strftime('"%Y-%m-%d %H:%M:%S.%f"'),
        print ctime[i].strftime('"%Y-%m-%d %H:%M:%S.%f"')
Example #11
def run(opts):
    eltime_from = time.time()
    print("#\n#Configurations:")
    print("# runNumber (-r/--run):         %d" % opts.runid)
    print("# output H5 file (-o/--output): %s (default = run######.h5)" %
          opts.outputH5)
    print("# beamline (--bl):              %d (default = 3)" % opts.bl)
    print("# img root (--rayonix-root):    %s" % opts.rayonix_root)
    print("# distance (--clen):            %s" % opts.clen)
    print("# beam center (--beam-x/y):     %s,%s" % (opts.beam_x, opts.beam_y))
    print("# Cheetah settings")
    print("#  --dmin, --dmax:              %s,%s" % (opts.d_min, opts.d_max))
    print("#  --adc-threshold:             %s" % opts.ADCthresh)
    print("#  --min-snr:                   %s" % opts.MinSNR)
    print("#  --min/max-pixcount:          %s,%s" %
          (opts.MinPixCount, opts.MaxPixCount))
    print("#  --local-bgradius:            %s" % opts.LocalBGRadius)
    print("#  --min-peaksep:               %s" % opts.MinPeakSeparation)
    print("#  --min-spots:                 %s" % opts.min_spots)
    print("#  --algorithm:                 %s" % opts.algorithm)
    print("# PD1 threshold (--pd1_thresh): %.3f (default = 0; ignore.)" %
          opts.pd1_threshold)
    print("# PD2 threshold (--pd2_thresh): %.3f (default = 0; ignore.)" %
          opts.pd2_threshold)
    print("# PD3 threshold (--pd3_thresh): %.3f (default = 0; ignore.)" %
          opts.pd3_threshold)
    print("# PD1 sensor name (--pd1_name): %s" % opts.pd1_sensor_name)
    print("# PD2 sensor name (--pd2_name): %s" % opts.pd2_sensor_name)
    print("# PD3 sensor name (--pd3_name): %s" % opts.pd3_sensor_name)
    print(
        "# nFrame after light:           %d (default = -1: accept all images; -2: accept all dark images)"
        % opts.light_dark)
    print(
        "# parallel_block:               %d (default = -1; no parallelization)"
        % opts.parallel_block)
    print("# nproc:                        %d (default = 1)" % opts.nproc)
    print("")

    assert opts.algorithm in (6, 8)
    assert opts.runid is not None
    assert opts.bl is not None

    # Beamline specific constants
    if opts.bl == 2:
        sensor_spec = "xfel_bl_2_tc_spec_1/energy"
        sensor_shutter = "xfel_bl_2_shutter_1_open_valid/status"
    elif opts.bl == 3:
        sensor_spec = "xfel_bl_3_tc_spec_1/energy"
        sensor_shutter = "xfel_bl_3_shutter_1_open_valid/status"
    else:
        error_status("BadBeamline")
        return -1

    # Get run info
    try:
        run_info = dbpy.read_runinfo(opts.bl, opts.runid)
    except:
        error_status("BadRunID")
        return -1

    high_tag = dbpy.read_hightagnumber(opts.bl, opts.runid)
    start_tag = run_info['start_tagnumber']
    end_tag = run_info['end_tagnumber']

    tag_list = numpy.array(dbpy.read_taglist_byrun(opts.bl, opts.runid))
    print "# Run %d: HighTag %d, Tags %d (inclusive) to %d (exclusive), thus %d tags" % (
        opts.runid, high_tag, start_tag, end_tag, len(tag_list))
    comment = dbpy.read_comment(opts.bl, opts.runid)
    print "# Comment: %s" % comment
    print

    # Get shutter status and find images
    try:
        shutter = numpy.array(
            map(
                str2float,
                dbpy.read_syncdatalist(sensor_shutter, high_tag,
                                       tuple(tag_list))))
    except:
        print traceback.format_exc()
        error_status("NoShutterStatus")
        return -1

    # XXX
    valid_tags = tag_list[shutter == 1]  # [tag for tag, is_open in zip(tag_list, shutter) if is_open == 1]
    print "# DEBUG:: shutter=", shutter
    print "# DEBUG:: valid_tags=", valid_tags
    if 0:
        tag_offset = 3
        tag_list = tag_list[tag_offset:]
        valid_tags = tag_list[numpy.arange(1, len(tag_list) + 1) % 6 == 0]

    if valid_tags.size == 0:
        error_status("NoValidTags")
        return -1

    # Find images
    img_files = sorted(
        glob.glob(
            os.path.join(opts.rayonix_root, str(opts.runid), "data_*.img")))
    print "# DEBUG:: img_files=%d valid_tags=%d" % (len(img_files),
                                                    len(valid_tags))
    if len(img_files) + 1 != len(valid_tags):  # last valid tag is not saved.
        print "# WARNING!! img_files and valid_tag number mismatch"

        img_numbers = map(lambda x: int(x[x.rindex("_") + 1:-4]), img_files)
        dropped_frames = sorted(
            set(range(1, len(valid_tags))).difference(img_numbers))
        print "# Unsaved frame numbers =", tuple(dropped_frames)
        print "# DEBUG::", len(img_files) - len(dropped_frames) + 1, len(
            valid_tags)
        if len(img_files) + len(dropped_frames) + 1 == len(valid_tags):
            print "#  %d unsaved img files found, which explains number mismatch" % len(
                dropped_frames)
            valid_tags = numpy.delete(valid_tags,
                                      numpy.array(dropped_frames) - 1)
            assert len(img_files) + 1 == len(valid_tags)
        else:
            print "# Assuming last %d img files are generated after stopping run.." % (
                len(img_files) - len(valid_tags) + 1)
            img_files = img_files[:len(valid_tags) - 1]
            assert len(img_files) + 1 == len(valid_tags)

    # Get photon energies
    photon_energies_in_keV = numpy.array([
        str2float(s) for s in dbpy.read_syncdatalist(sensor_spec, high_tag,
                                                     tuple(valid_tags))
    ])
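    # NaN-safe selection: "x == x" keeps valid entries, while "x != x" matches NaN.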
    valid = photon_energies_in_keV == photon_energies_in_keV
    mean_photon_energy = numpy.mean(photon_energies_in_keV[valid])  # XXX if no valid data?
    print "# Photon energies obtained: %d valid numbers, %d invalid, average=%f sd=%f" % (
        len(photon_energies_in_keV),
        sum(photon_energies_in_keV != photon_energies_in_keV),
        mean_photon_energy,
        numpy.std(photon_energies_in_keV[photon_energies_in_keV ==
                                         photon_energies_in_keV]))
    photon_energies_in_keV[
        photon_energies_in_keV != photon_energies_in_keV] = mean_photon_energy

    # Get PD values
    pd1_values, pd2_values, pd3_values = [], [], []
    if opts.pd1_threshold != 0:
        pd1_values = numpy.array(
            map(
                str2float,
                dbpy.read_syncdatalist(opts.pd1_sensor_name, high_tag,
                                       tuple(valid_tags))))
    if opts.pd2_threshold != 0:
        pd2_values = numpy.array(
            map(
                str2float,
                dbpy.read_syncdatalist(opts.pd2_sensor_name, high_tag,
                                       tuple(valid_tags))))
    if opts.pd3_threshold != 0:
        pd3_values = numpy.array(
            map(
                str2float,
                dbpy.read_syncdatalist(opts.pd3_sensor_name, high_tag,
                                       tuple(valid_tags))))

    # Identify bad tags
    # XXX Not tested!! This feature must not be used; tags with bad PDs must be detected after the experiment.
    frame_after_light = 9999
    bad_tag_idxes = []
    for i in xrange(len(valid_tags)):
        light = True
        if (opts.pd1_threshold != 0
                and not (opts.pd1_threshold > 0
                         and opts.pd1_threshold <= pd1_values[i])
                and not (opts.pd1_threshold < 0
                         and -opts.pd1_threshold > pd1_values[i])):
            light = False
        if (opts.pd2_threshold != 0
                and not (opts.pd2_threshold > 0
                         and opts.pd2_threshold <= pd2_values[i])
                and not (opts.pd2_threshold < 0
                         and -opts.pd2_threshold > pd2_values[i])):
            light = False
        if (opts.pd3_threshold != 0
                and not (opts.pd3_threshold > 0
                         and opts.pd3_threshold <= pd3_values[i])
                and not (opts.pd3_threshold < 0
                         and -opts.pd3_threshold > pd3_values[i])):
            light = False

        if light:
            frame_after_light = 0
        else:
            frame_after_light += 1

        if ((opts.light_dark >= 0 and frame_after_light != opts.light_dark) or
            (opts.light_dark == PD_DARK_ANY and frame_after_light == 0)):
            print "# PD check: %d is bad tag!" % valid_tags[i]
            bad_tag_idxes.append(i)

    if bad_tag_idxes:
        valid_tags = numpy.delete(valid_tags, numpy.array(bad_tag_idxes))
        for i in reversed(bad_tag_idxes):
            del img_files[i]

    # Debug code; this takes too much time!
    try:
        if 0 and opts.parallel_block == 0:
            tag_timestamp = map(
                lambda x: datetime.datetime.fromtimestamp(
                    dbpy.read_timestamp_fromtag(high_tag, x, sensor_shutter)).
                strftime('%Y-%m-%d %H:%M:%S.%f'), valid_tags)
            img_timestamp = map(
                lambda x: marccd.MarCCD(x).acquire_time.strftime(
                    '%Y-%m-%d %H:%M:%S.%f'), img_files)
            ofs = open("tag_file_time.dat", "w")
            ofs.write("run tag file tag.time file.time\n")
            for i in xrange(len(img_files)):
                ofs.write('%d %d %s "%s" "%s"\n' %
                          (opts.runid, valid_tags[i], img_files[i],
                           tag_timestamp[i], img_timestamp[i]))
            ofs.close()
    except:
        pass

    # Block splitting
    # TODO: the DB query may be slow; it may only need to be done once.
    if opts.parallel_block >= 0:
        width = len(valid_tags) // parallel_size
        i_start = opts.parallel_block * width
        i_end = (opts.parallel_block + 1
                 ) * width if opts.parallel_block < parallel_size - 1 else None
        valid_tags = valid_tags[i_start:i_end]
        photon_energies_in_keV = photon_energies_in_keV[i_start:i_end]
        img_files = img_files[i_start:i_end]
        print "# parallel_block=%d: %d tags will be processed (%d..%d)" % (
            opts.parallel_block, len(valid_tags), valid_tags[0],
            valid_tags[-1])

    make_geom(img_files[0],
              opts.output_geom,
              beam_x=opts.beam_x,
              beam_y=opts.beam_y,
              clen=opts.clen)

    # Hit-finding
    results = process_images(img_files, mean_photon_energy, opts)
    file_tag_ene = []
    for frame, tag, ene in zip(sorted(results), valid_tags,
                               photon_energies_in_keV):
        if len(results[frame]["spots"]) < opts.min_spots:
            continue
        file_tag_ene.append((frame, tag, ene))

    # TODO on-the-fly status updating
    open("status.txt", "w").write("""\
# Cheetah status
Update time: %(ctime)s
Elapsed time: %(eltime)f sec
Status: Total=%(ntotal)d,Processed=%(ntotal)d,LLFpassed=%(ntotal)d,Hits=%(nhits)d,Status=WritingH5
Frames processed: %(ntotal)d
Number of hits: %(nhits)d
""" % dict(ctime=time.ctime(),
           eltime=time.time() - eltime_from,
           ntotal=len(img_files),
           nhits=len(file_tag_ene)))

    # Save h5
    # TODO implement on-the-fly h5 file writing in hit-finding to avoid reading img file twice.
    make_h5(out=opts.outputH5, file_tag_ene=file_tag_ene, comment=comment)

    open("status.txt", "w").write("""\
# Cheetah status
Update time: %(ctime)s
Elapsed time: %(eltime)f sec
Status: Total=%(ntotal)d,Processed=%(ntotal)d,LLFpassed=%(ntotal)d,Hits=%(nhits)d,Status=Finished
Frames processed: %(ntotal)d
Number of hits: %(nhits)d
""" % dict(ctime=time.ctime(),
           eltime=time.time() - eltime_from,
           ntotal=len(img_files),
           nhits=len(file_tag_ene)))

    ofs = open("cheetah.dat", "w")
    ofs.write("file tag nspots total_snr\n")
    for frame, tag in zip(sorted(results), valid_tags):
        ret = results[frame]
        n_spots = len(ret["spots"])
        total_snr = sum(map(lambda x: x[2], ret["spots"]))
        ofs.write("%s %d %6d %.3e\n" % (frame, tag, n_spots, total_snr))
    ofs.close()

    if opts.gen_adx:
        for frame in sorted(results):
            ret = results[frame]
            adx_out = open(os.path.basename(frame) + ".adx", "w")
            for x, y, snr, d in ret["spots"]:
                adx_out.write("%6d %6d %.2e\n" % (x, y, snr))
            adx_out.close()
Example #12
def process_images(img_files, mean_photon_energy, opts):
    startt = time.time()

    nframes = len(img_files)
    first_img = img_files[0]  # assumes all images share the same conditions
    im = marccd.MarCCD(first_img)

    header = dict(beam_center_x=im.beam_x,
                  beam_center_y=im.beam_y,
                  pixel_size_x=im.pixel_x,
                  distance=im.distance,
                  wavelength=im.wavelength)

    if opts.clen:
        print "# overriding camera distance = %f (header value: %f)" % (
            opts.clen, header["distance"])
        header["distance"] = opts.clen

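    # hc ~= 12.3984 keV*Angstrom, so wavelength [A] = 12.3984 / E [keV].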
    if mean_photon_energy and mean_photon_energy == mean_photon_energy:
        print "# overriding wavelength = %f (header value: %f)" % (
            12.3984 / mean_photon_energy, header["wavelength"])
        header["wavelength"] = 12.3984 / mean_photon_energy

    if opts.beam_x and opts.beam_x == opts.beam_x:
        print "# overriding beam_x = %f (header value: %f)" % (
            opts.beam_x, header["beam_center_x"])
        header["beam_center_x"] = opts.beam_x

    if opts.beam_y and opts.beam_y == opts.beam_y:
        print "# overriding beam_y = %f (header value: %f)" % (
            opts.beam_y, header["beam_center_y"])
        header["beam_center_y"] = opts.beam_y

    print "#%s" % header
    print "#frames= %d" % nframes

    ranges = map(tuple, numpy.array_split(numpy.arange(nframes), opts.nproc))
    queue = Queue()
    pp = []
    print "frame nspots eltime"
    for i, rr in enumerate(ranges):
        if not rr: continue
        p = Process(target=worker,
                    args=(i, header, map(lambda x: img_files[x], rr), queue,
                          opts.algorithm, opts.ADCthresh, opts.MinSNR,
                          opts.MinPixCount, opts.MaxPixCount, opts.LocalBGRadius,
                          opts.MinPeakSeparation, opts.d_min, opts.d_max))
        p.start()
        pp.append(p)

    results = {}

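    # Busy-poll: drain the queue while workers run, so children never block on a full queue.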
    while any(map(lambda p: p.is_alive(), pp)):
        while not queue.empty():
            ret = queue.get()
            results[ret["frame"]] = ret

    for p in pp:
        p.join()

    print "# hit finding for %d images finished in %.2f sec." % (
        len(img_files), time.time() - startt)

    return results
Example #13
    def write_par(i):
        f, tag, ene = file_tag_ene[i]
        tmp = time.time()
        data = marccd.MarCCD(f).read_data()
        grp = of.create_group("tag-%d" % tag)
        if ene != ene: ene = default_energy  # "ene != ene" is True only for NaN

        if compression == "shuf+gz":
            if data.shape[0] % 384 == 0:
                chunks = (384, 384)
            elif data.shape[0] % 256 == 0:
                chunks = (256, 256)
            else:
                chunks = data.shape

            assert len(data.shape) == 2
            assert len(chunks) == 2
            assert data.shape[1] % chunks[1] == 0  # because we didn't implement padding to fill a chunk

            as_uint8 = data.view(dtype=numpy.uint8)  # ONLY the length of the fast axis is doubled
            itemsize = data.dtype.itemsize

            cy = int(numpy.ceil(float(data.shape[0]) / chunks[0]))  # float() so ceil works under Python 2 integer division
            cx = int(numpy.ceil(float(data.shape[1]) / chunks[1]))

            compressed_chunks = [None] * (cy * cx)
            for iy in xrange(cy):
                for ix in xrange(cx):
                    sy = iy * chunks[0]
                    sx = ix * chunks[1]
                    ey = (iy + 1) * chunks[0]
                    ex = (ix + 1) * chunks[1]
                    if ey > data.shape[0]: ey = data.shape[0]
                    if ex > data.shape[1]: ex = data.shape[1]

                    my_chunk = as_uint8[sy:ey, (sx * itemsize):(ex * itemsize)]
                    shuffled = my_chunk.reshape((-1, itemsize)).transpose().reshape(-1)
                    compressed_chunks[iy * cx + ix] = zlib.compress(shuffled.tobytes(), compression_level)

        lock.acquire()
        grp["photon_energy_ev"] = ene * 1000.
        grp["photon_wavelength_A"] = 12.3984 / ene
        grp["original_file"] = f

        if not compression:
            grp.create_dataset("data", data.shape, dtype=data.dtype, data=data)
        elif compression == "bslz4":
            grp.create_dataset(
                "data",
                data.shape,
                compression=bitshuffle.h5.H5FILTER,
                compression_opts=(0, bitshuffle.h5.H5_COMPRESS_LZ4),
                dtype=data.dtype,
                data=data)
        elif compression == "shuf+gz":
            dataset = grp.create_dataset("data",
                                         data.shape,
                                         chunks=chunks,
                                         compression="gzip",
                                         shuffle=True,
                                         dtype=data.dtype)

            for iy in xrange(cy):
                for ix in xrange(cx):
                    sy = iy * chunks[0]
                    sx = ix * chunks[1]

                    dataset.id.write_direct_chunk(
                        offsets=(sy, sx),
                        data=compressed_chunks[iy * cx + ix],
                        filter_mask=0)
        else:
            raise ValueError("Unknown compression name (%s)" % compression)

        print "# converted: %s %d %.4f %.2f" % (f, tag, ene,
                                                (time.time() - tmp) * 1.e3)
        lock.release()
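
A minimal standalone sketch of the byte-shuffle + gzip scheme implemented in the "shuf+gz" branch above, assuming only numpy and zlib: shuffling groups the low bytes of all pixels before the high bytes, which typically compresses smooth detector images better.

    import numpy, zlib

    data = numpy.arange(12, dtype=numpy.uint16).reshape(3, 4)
    as_uint8 = data.view(numpy.uint8)  # same buffer; fast-axis length is doubled
    shuffled = as_uint8.reshape((-1, data.dtype.itemsize)).transpose().reshape(-1)
    compressed = zlib.compress(shuffled.tobytes(), 4)  # payload for write_direct_chunk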
Example #14
def make_h5(out,
            file_tag_ene,
            comment,
            default_energy=None,
            compression="shuf+gz",
            compression_level=4):
    startt = time.time()

    from multiprocessing.dummy import Pool as ThreadPool
    import threading
    import zlib

    lock = threading.Lock()
    of = h5py.File(out, "w")

    of["/metadata/detector"] = "Rayonix MX300HS"
    #of["/metadata/distance_in_mm"] = opts.clen_mm
    if comment: of["/metadata/run_comment"] = comment
    if file_tag_ene:
        tmp = marccd.MarCCD(file_tag_ene[0][0])
        of["/metadata/pixelsize_in_um"] = tmp.pixel_x * 1000.
        if not default_energy: default_energy = 12.3984 / tmp.wavelength

    def write_par(i):
        f, tag, ene = file_tag_ene[i]
        tmp = time.time()
        data = marccd.MarCCD(f).read_data()
        grp = of.create_group("tag-%d" % tag)
        if ene != ene: ene = default_energy  # "ene != ene" is True only for NaN

        if compression == "shuf+gz":
            if data.shape[0] % 384 == 0:
                chunks = (384, 384)
            elif data.shape[0] % 256 == 0:
                chunks = (256, 256)
            else:
                chunks = data.shape

            assert len(data.shape) == 2
            assert len(chunks) == 2
            assert data.shape[1] % chunks[1] == 0  # because we didn't implement padding to fill a chunk

            as_uint8 = data.view(dtype=numpy.uint8)  # ONLY the length of the fast axis is doubled
            itemsize = data.dtype.itemsize

            cy = int(numpy.ceil(float(data.shape[0]) / chunks[0]))  # float() so ceil works under Python 2 integer division
            cx = int(numpy.ceil(float(data.shape[1]) / chunks[1]))

            compressed_chunks = [None] * (cy * cx)
            for iy in xrange(cy):
                for ix in xrange(cx):
                    sy = iy * chunks[0]
                    sx = ix * chunks[1]
                    ey = (iy + 1) * chunks[0]
                    ex = (ix + 1) * chunks[1]
                    if ey > data.shape[0]: ey = data.shape[0]
                    if ex > data.shape[1]: ex = data.shape[1]

                    my_chunk = as_uint8[sy:ey, (sx * itemsize):(ex * itemsize)]
                    shuffled = my_chunk.reshape((-1, itemsize)).transpose().reshape(-1)
                    compressed_chunks[iy * cx + ix] = zlib.compress(shuffled.tobytes(), compression_level)

        lock.acquire()
        grp["photon_energy_ev"] = ene * 1000.
        grp["photon_wavelength_A"] = 12.3984 / ene
        grp["original_file"] = f

        if not compression:
            grp.create_dataset("data", data.shape, dtype=data.dtype, data=data)
        elif compression == "bslz4":
            grp.create_dataset(
                "data",
                data.shape,
                compression=bitshuffle.h5.H5FILTER,
                compression_opts=(0, bitshuffle.h5.H5_COMPRESS_LZ4),
                dtype=data.dtype,
                data=data)
        elif compression == "shuf+gz":
            dataset = grp.create_dataset("data",
                                         data.shape,
                                         chunks=chunks,
                                         compression="gzip",
                                         shuffle=True,
                                         dtype=data.dtype)

            for iy in xrange(cy):
                for ix in xrange(cx):
                    sy = iy * chunks[0]
                    sx = ix * chunks[1]

                    dataset.id.write_direct_chunk(
                        offsets=(sy, sx),
                        data=compressed_chunks[iy * cx + ix],
                        filter_mask=0)
        else:
            raise ValueError("Unknown compression name (%s)" % compression)

        print "# converted: %s %d %.4f %.2f" % (f, tag, ene,
                                                (time.time() - tmp) * 1.e3)
        lock.release()

    pool = ThreadPool(opts.nproc)
    pool.map(write_par, xrange(len(file_tag_ene)))

    of.close()

    eltime = time.time() - startt
    print "# Processed: %s (%.2f sec for %d images)" % (out, eltime,
                                                        len(file_tag_ene))
def run(opts):
    eltime_from = time.time()

    assert opts.runid is not None
    assert opts.bl is not None

    # Beamline specific constants
    if opts.bl == 2:
        sensor_spec = "xfel_bl_2_tc_spec_1/energy"
        sensor_shutter = "xfel_bl_2_shutter_1_open_valid/status"
    elif opts.bl == 3:
        sensor_spec = "xfel_bl_3_tc_spec_1/energy"
        sensor_shutter = "xfel_bl_3_shutter_1_open_valid/status"
    else:
        error_status("BadBeamline")
        return -1

    # Get run info
    try:
        run_info = dbpy.read_runinfo(opts.bl, opts.runid)
    except:
        error_status("BadRunID")
        return -1

    high_tag = dbpy.read_hightagnumber(opts.bl, opts.runid)
    start_tag = run_info['start_tagnumber']
    end_tag = run_info['end_tagnumber']

    tag_list = numpy.array(dbpy.read_taglist_byrun(opts.bl, opts.runid))
    print "# Run %d: HighTag %d, Tags %d (inclusive) to %d (exclusive), thus %d images" % (opts.runid, high_tag, start_tag, end_tag, len(tag_list))
    comment = dbpy.read_comment(opts.bl, opts.runid)
    print "# Comment: %s" % comment
    print

    # Get shutter status and find images
    try:
        shutter = numpy.array(map(str2float, dbpy.read_syncdatalist(sensor_shutter, high_tag, tuple(tag_list))))
    except:
        print traceback.format_exc()
        error_status("NoShutterStatus")
        return -1

    # XXX
    valid_tags = tag_list[shutter==1] # [tag for tag, is_open in zip(tag_list, shutter) if is_open == 1]
    print "DEBUG:: shutter=", shutter
    print "DEBUG:: valid_tags=", valid_tags
    if 0:
        tag_offset = 3
        tag_list = tag_list[tag_offset:]
        valid_tags = tag_list[numpy.arange(1, len(tag_list)+1)%6==0]
        
    if valid_tags.size == 0:
        error_status("NoValidTags")
        return -1

    # Find images
    img_files = sorted(glob.glob(os.path.join(opts.rayonix_root, str(opts.runid), "data_*.img")))
    print "# DEBUG:: img_files=%d valid_tags=%d" % (len(img_files), len(valid_tags))
    if len(img_files)+1 != len(valid_tags): # last valid tag is not saved.
        print "# WARNING!! img_files and valid_tag number mismatch"

        img_numbers = map(lambda x: int(x[x.rindex("_")+1:-4]), img_files)
        dropped_frames = sorted(set(range(1, len(valid_tags))).difference(img_numbers))
        print "# Unsaved frame numbers =", tuple(dropped_frames)
        print "# DEBUG::", len(img_files)-len(dropped_frames)+1, len(valid_tags)
        if len(img_files)+len(dropped_frames)+1 == len(valid_tags):
            print "#  %d unsaved img files found, which explains number mismatch" % len(dropped_frames)
            valid_tags = numpy.delete(valid_tags, numpy.array(dropped_frames)-1)
            assert len(img_files)+1 == len(valid_tags)
        else:
            print "# Assuming last %d img files are generated after stopping run.." % (len(img_files)-len(valid_tags)+1)
            img_files = img_files[:len(valid_tags)-1]
            assert len(img_files)+1 == len(valid_tags)
    
    tag_timestamp = map(lambda x: datetime.datetime.fromtimestamp(dbpy.read_timestamp_fromtag(high_tag, x, sensor_shutter)).strftime('%Y-%m-%d %H:%M:%S.%f'), valid_tags)
    img_timestamp = map(lambda x: marccd.MarCCD(x).acquire_time.strftime('%Y-%m-%d %H:%M:%S.%f'), img_files)
    ofs = open("tag_file_time.dat", "w")
    ofs.write("run tag file tag.time file.time\n")
    for i in xrange(len(img_files)):
        ofs.write('%d %d %s "%s" "%s"\n'%(opts.runid, valid_tags[i], img_files[i], tag_timestamp[i], img_timestamp[i]))
    ofs.close()