예제 #1
0
    def test_roundtrip_images(self):
        """Round-trip every test movie through image files and back.

        For each movie in ``fmf_filenames`` (skipping ``*raw8.fmf``), the
        frames are exported with ``fmf2bmps.fmf2images`` into a temporary
        directory, re-assembled with ``images2fmf.images2fmf`` into
        ``output.fmf``, and every frame of the original and re-assembled
        movie is compared after conversion to RGB8.
        """
        for filename in fmf_filenames:
            if filename.endswith('raw8.fmf'):
                continue

            for version in [3]:
                # write a new movie
                tmp_dir = tempfile.mkdtemp()
                input_glob = os.path.join(tmp_dir, '*')
                out_fname = os.path.join(tmp_dir, 'output.fmf')
                fmf_in = None
                fmf_out = None
                try:
                    fmf2bmps.fmf2images(filename, outdir=tmp_dir)
                    # globbed before output.fmf exists, so only images match
                    input_list = sorted(glob.glob(input_glob))

                    images2fmf.images2fmf(input_list, out_fname)

                    fmf_in = FlyMovieFormat.FlyMovie(filename)
                    fmf_out = FlyMovieFormat.FlyMovie(out_fname)
                    assert fmf_in.get_n_frames() == fmf_out.get_n_frames()
                    for _ in range(fmf_in.get_n_frames()):
                        frame_in1, timestamp_in = fmf_in.get_next_frame()
                        frame_out1, timestamp_out = fmf_out.get_next_frame()

                        frame_in = imops.to_rgb8(fmf_in.format, frame_in1)
                        frame_out = imops.to_rgb8(fmf_out.format, frame_out1)

                        assert frame_in.shape == frame_out.shape
                        assert numpy.allclose(frame_in, frame_out)
                finally:
                    # Close the movies even when an assertion fails so the
                    # files are not held open while the directory is removed.
                    if fmf_in is not None:
                        fmf_in.close()
                    if fmf_out is not None:
                        fmf_out.close()
                    shutil.rmtree(tmp_dir)
예제 #2
0
def doit(input_fname,
         subtract_frame,
         start=None,
         stop=None,
         gain=1.0,
         offset=0.0,
         ):
    """Subtract a background frame from a movie and save ``<name>.sub.fmf``.

    Parameters
    ----------
    input_fname : str
        Input .fmf movie. The output is written next to it with the
        extension replaced by ``.sub.fmf``.
    subtract_frame : str
        Path of an .fmf movie whose first frame is used as the background.
    start, stop : int, optional
        First and last (inclusive) frame numbers to process; default to the
        whole movie.
    gain, offset : float
        Each difference image is mapped through ``diff * gain + offset`` and
        clipped to [0, 255] before being stored as 8-bit data.

    Raises
    ------
    NotImplementedError
        If ``subtract_frame`` is not an .fmf file.
    """
    output_fname = os.path.splitext(input_fname)[0] + '.sub.fmf'
    in_fmf = FMF.FlyMovie(input_fname)
    try:
        input_format = in_fmf.get_format()
        input_is_color = imops.is_coding_color(input_format)

        if not subtract_frame.endswith('.fmf'):
            raise NotImplementedError('only fmf supported for --subtract-frame')
        tmp_fmf = FMF.FlyMovie(subtract_frame)
        try:
            tmp_frame, tmp_timestamp = tmp_fmf.get_next_frame()
            tmp_format = tmp_fmf.get_format()
        finally:
            tmp_fmf.close()

        # Background and output share the coding of the converted input
        # frames: RGB8 for color input, MONO8 otherwise.
        if input_is_color:
            background = imops.to_rgb8(tmp_format, tmp_frame)
            output_format = 'RGB8'
        else:
            background = imops.to_mono8(tmp_format, tmp_frame)
            output_format = 'MONO8'
        background = background.astype(np.float32)  # force upconversion to float

        out_fmf = FMF.FlyMovieSaver(output_fname,
                                    version=3,
                                    format=output_format,
                                    )
        try:
            if stop is None:
                stop = in_fmf.get_n_frames() - 1
            if start is None:
                start = 0
            # Write every frame in [start, stop] exactly once.
            for fno in range(start, stop + 1):
                in_fmf.seek(fno)
                frame, timestamp = in_fmf.get_next_frame()

                if input_is_color:
                    frame = imops.to_rgb8(input_format, frame)
                else:
                    # Keep the frame 2D like the background: an (H, W, 1)
                    # frame minus an (H, W) background would broadcast to
                    # the wrong shape.
                    frame = imops.to_mono8(input_format, frame)
                new_frame = np.clip((frame - background) * gain + offset, 0, 255)
                out_fmf.add_frame(new_frame.astype(np.uint8), timestamp)
            out_fmf.close()
        except BaseException:
            # don't leave a partially written output movie behind
            os.unlink(output_fname)
            raise
    finally:
        in_fmf.close()
예제 #3
0
def convert(frame, format):
    """Convert a raw frame coded as ``format`` to a displayable image.

    RGB-like codings and Bayer codings (``MONO8:``, ``MONO32f:`` and
    ``RAW8:`` prefixes) are converted to RGB8; ``MONO8``/``MONO16`` are
    converted to MONO8.  Any other coding is returned unchanged after
    emitting a warning.
    """
    if format in ('RGB8', 'ARGB8', 'YUV411', 'YUV422', 'RGB32f'):
        frame = imops.to_rgb8(format, frame)
    elif format in ('MONO8', 'MONO16'):
        frame = imops.to_mono8(format, frame)
    elif format.startswith(('MONO8:', 'MONO32f:', 'RAW8:')):
        # bayer
        frame = imops.to_rgb8(format, frame)
    else:
        import warnings
        warnings.warn('unknown format "%s" conversion to displayable' % format)
    return frame
예제 #4
0
 def _convert_to_displayable(self,frame):
     if self.format in ['RGB8','ARGB8','YUV411','YUV422']:
         frame = imops.to_rgb8(self.format,frame)
     elif self.format in ['MONO8','MONO16']:
         frame = imops.to_mono8(self.format,frame)
     elif (self.format.startswith('MONO8:') or
           self.format.startswith('MONO32f:')):
         # bayer
         frame = imops.to_rgb8(self.format,frame)
     #frame = self.convert_to_matplotlib(frame)
     return frame
예제 #5
0
 def _convert_to_displayable(self, frame):
     if self.format in ["RGB8", "ARGB8", "YUV411", "YUV422", "RGB32f"]:
         frame = imops.to_rgb8(self.format, frame)
     elif self.format in ["MONO8", "MONO16"]:
         frame = imops.to_mono8(self.format, frame)
     elif self.format.startswith("MONO8:") or self.format.startswith("MONO32f:") or self.format.startswith("RAW8:"):
         # bayer
         frame = imops.to_rgb8(self.format, frame)
     else:
         warnings.warn('unknown format "%s" conversion to displayable' % self.format)
     # frame = self.convert_to_matplotlib(frame)
     return frame
예제 #6
0
 def _convert_to_displayable(self,frame):
     if self.format in ['RGB8','ARGB8','YUV411','YUV422','RGB32f']:
         frame = imops.to_rgb8(self.format,frame)
     elif self.format in ['MONO8','MONO16']:
         frame = imops.to_mono8(self.format,frame)
     elif (self.format.startswith('MONO8:') or
           self.format.startswith('MONO32f:')):
         # bayer
         frame = imops.to_rgb8(self.format,frame)
     else:
         warnings.warn('unknown format "%s" conversion to displayable'%
                       self.format)
     #frame = self.convert_to_matplotlib(frame)
     return frame
예제 #7
0
def fmf2images(filename, imgformat='png',
               startframe=0,endframe=-1,interval=1,
               prefix=None,outdir=None,progress=False):
    """Save frames of an .fmf movie as individual image files.

    Parameters
    ----------
    filename : str
        Path of the input movie; must end in ``.fmf``.
    imgformat : str
        Image file extension, e.g. ``'png'``, ``'bmp'`` or ``'jpg'``.
    startframe, endframe : int
        First and last (inclusive) frame to save.  A negative or
        out-of-range ``endframe`` means the last frame of the movie.
    interval : int
        Save every ``interval``-th frame.
    prefix : str, optional
        Base name of the image files; defaults to the movie's base name.
    outdir : str, optional
        Output directory; defaults to the movie's directory.
    progress : bool
        Show a ``progressbar`` progress bar.

    Files are named ``<outdir>/<prefix>_<frame number:08d>.<imgformat>``.
    Color and Bayer codings are saved as RGB images, everything else as
    8-bit grayscale (with a warning for codings other than MONO8/MONO16).

    Raises ``SystemExit`` with a message (non-zero exit status) if
    ``filename`` does not end in ``.fmf``.
    """
    base, ext = os.path.splitext(filename)
    if ext != '.fmf':
        # sys.exit(msg) prints msg to stderr and exits with status 1, so
        # shells and callers can see the failure.
        sys.exit('fmf_filename does not end in .fmf')

    path, base = os.path.split(base)
    if prefix is not None:
        base = prefix

    if outdir is None:
        outdir = path

    fly_movie = FlyMovieFormat.FlyMovie(filename)
    try:
        fmf_format = fly_movie.get_format()
        n_frames = fly_movie.get_n_frames()
        if endframe < 0 or endframe >= n_frames:
            endframe = n_frames - 1

        # The coding is the same for every frame, so decide once.
        is_color = (fmf_format in ('RGB8', 'RGB32f', 'ARGB8', 'YUV411', 'YUV422')
                    or fmf_format.startswith(('MONO8:', 'MONO32f:')))
        if not is_color and fmf_format not in ('MONO8', 'MONO16'):
            warnings.warn('converting unknown fmf format %s to mono' % (
                fmf_format,))

        fly_movie.seek(startframe)
        frames = range(startframe, endframe + 1, interval)
        if progress:
            import progressbar
            widgets = ['fmf2bmps', progressbar.Percentage(), ' ',
                       progressbar.Bar(), ' ', progressbar.ETA()]
            pbar = progressbar.ProgressBar(widgets=widgets,
                                           maxval=len(frames)).start()
        else:
            pbar = None

        out_base = os.path.join(outdir, base)
        for count, frame_number in enumerate(frames):
            if pbar is not None:
                pbar.update(count)
            frame, timestamp = fly_movie.get_frame(frame_number)

            if is_color:
                save_frame = imops.to_rgb8(fmf_format, frame)
                mode = 'RGB'
            else:
                save_frame = imops.to_mono8(fmf_format, frame)
                mode = 'L'
            h, w = save_frame.shape[:2]
            # NOTE(review): Image.fromstring was removed in Pillow 3.0 and
            # ndarray.tostring is deprecated in NumPy; newer environments
            # need Image.frombytes / ndarray.tobytes here.
            im = Image.fromstring(mode, (w, h), save_frame.tostring())
            im.save('%s_%08d.%s' % (out_base, frame_number, imgformat))
        if pbar is not None:
            pbar.finish()
    finally:
        fly_movie.close()
예제 #8
0
def main():
    """Measure camera-vs-display delays from QR-coded timestamps in a movie.

    Usage: ``<prog> fmf_filename``.  Every frame of the .fmf movie is saved
    as ``./zbartmp_<frame:08d>.bmp`` and decoded with the external
    ``zbarimg -q`` tool.  The decoded output after its first 8 characters
    is parsed as a float ("ds") timestamp.  For each frame the delay
    ``camera timestamp - ds timestamp`` is printed.  After all frames, the
    mean, standard deviation, maximum and minimum are printed over the
    frames that could be decoded, followed by how many frames were used.

    Frames whose code cannot be decoded (``zbarimg`` exits non-zero, or the
    text is not a number) get a NaN delay and are excluded from the
    statistics.  An ``OSError`` (e.g. ``zbarimg`` not installed) is fatal.
    """
    if len(sys.argv) < 2:
        sys.exit('Usage: %s fmf_filename' % sys.argv[0])
    filename = sys.argv[1]

    if os.path.splitext(filename)[1] != '.fmf':
        sys.exit('fmf_filename does not end in .fmf')

    fly_movie = FlyMovieFormat.FlyMovie(filename)
    delay_list = []
    try:
        n_frames = fly_movie.get_n_frames()
        fmf_format = fly_movie.get_format()

        is_color = (fmf_format in ('RGB8', 'RGB32f', 'ARGB8', 'YUV411', 'YUV422')
                    or fmf_format.startswith(('MONO8:', 'MONO32f:')))
        if not is_color and fmf_format not in ('MONO8', 'MONO16'):
            warnings.warn('converting unknown fmf format %s to mono' % (
                fmf_format,))

        for frame_number in range(n_frames):
            frame, timestamp = fly_movie.get_frame(frame_number)

            if is_color:
                save_frame = imops.to_rgb8(fmf_format, frame)
                mode = 'RGB'
            else:
                save_frame = imops.to_mono8(fmf_format, frame)
                mode = 'L'
            h, w = save_frame.shape[:2]
            im = Image.fromstring(mode, (w, h), save_frame.tostring())
            f = '%s_%08d.%s' % (os.path.join("./", "zbartmp"), frame_number, 'bmp')
            im.save(f)
            try:
                try:
                    TS = subprocess.check_output(['zbarimg', '-q', f])
                    # skip the 8-character prefix of zbarimg's output
                    # (presumably the "QR-Code:" symbology tag)
                    ts = float(TS[8:].strip())
                except (subprocess.CalledProcessError, ValueError):
                    # nothing (parsable) decoded in this frame
                    ts = float('nan')
            finally:
                # remove the temporary image even if zbarimg is missing
                os.unlink(f)
            delay = timestamp - ts
            delay_list.append(delay)
            print("ds: % 14.6f cam: % 14.6f delay: % 8.6f" % (ts, timestamp, delay))
    finally:
        fly_movie.close()

    delays = numpy.array(delay_list)
    used = delays[~numpy.isnan(delays)]
    if len(used):
        print("delay mean: % 8.6f std: % 8.6f" % (used.mean(), used.std()))
        print("delay max: % 8.6f min: % 8.6f" % (used.max(), used.min()))
    print("%i of %i frames used" % (len(used), len(delays)))
예제 #9
0
 def save( self, save_frame, timestamp ):
     """Write one frame to ``self.filename % self.count`` and bump the count.

     MONO8/RAW8 frames are saved as 8-bit grayscale.  MONO32f frames are
     clipped to [0, 255], cast to uint8 and saved as grayscale.  Every other
     coding is converted with ``imops.to_rgb8`` and saved as RGB.  If
     ``self.flip_upside_down`` is set, the rows are reversed first.
     ``timestamp`` is not used by this method.
     """
     fname = self.filename % self.count
     self.count += 1
     if self.flip_upside_down:
         save_frame = save_frame[::-1, :]  # flip
     if self.format in ('MONO8', 'RAW8', 'MONO32f'):
         if self.format == 'MONO32f':
             # Clip before casting: out-of-range floats do not saturate
             # when cast to uint8.
             save_frame = numpy.clip(save_frame, 0, 255).astype(numpy.uint8)
         height, width = save_frame.shape
         im = Image.fromstring('L', (width, height), save_frame.tostring())
     else:
         rgb8 = imops.to_rgb8(self.format, save_frame)
         height, width, depth = rgb8.shape
         im = Image.fromstring('RGB', (width, height), rgb8.tostring())
     im.save(fname)
예제 #10
0
def doit(input_fname,
         single_channel=False,
         start=None,
         stop=None,
         ):
    """Average frames of a movie and save the result as ``<name>.av.fmf``.

    Parameters
    ----------
    input_fname : str
        Input .fmf movie; the output replaces its extension with ``.av.fmf``.
    single_channel : str or None/False
        For color input, the name of one channel (``'red'``, ``'green'`` or
        ``'blue'``) to average into a MONO32f result.  A false value (the
        default ``False``, or ``None``) averages all channels.  It is
        ignored, with a warning, for non-color input.
    start, stop : int, optional
        First and last (inclusive) frame numbers; default to the whole movie.

    Frames are sampled at ``max(30, n_frames)`` evenly spaced (rounded)
    positions in [start, stop], so short movies count some frames more than
    once.  The float32 mean is written as a single MONO32f or RGB32f frame
    stamped with the timestamp returned by the first ``get_next_frame()``.
    """
    output_fname = os.path.splitext(input_fname)[0] + '.av.fmf'
    in_fmf = FMF.FlyMovie(input_fname)
    try:
        input_format = in_fmf.get_format()
        input_is_color = imops.is_coding_color(input_format)

        # Test truthiness, not "is not None": the default False must mean
        # "all channels", not "single channel".
        if single_channel:
            if not input_is_color:
                warnings.warn('ignoring --single-channel option for non-color input')
                single_channel = False
            output_format = 'MONO32f'
        elif input_is_color:
            output_format = 'RGB32f'
        else:
            output_format = 'MONO32f'
        out_fmf = FMF.FlyMovieSaver(output_fname,
                                    version=3,
                                    format=output_format,
                                    )
        try:
            channel_dict = {'red': 0, 'green': 1, 'blue': 2}

            def to_3d(frame):
                """Return ``frame`` as a (rows, cols, channels) array."""
                if input_is_color:
                    frame = imops.to_rgb8(input_format, frame)
                else:
                    frame = np.atleast_3d(frame)
                if single_channel:
                    # keep only the requested channel, as a length-1 axis
                    idx = channel_dict[single_channel]
                    frame = frame[:, :, idx:idx + 1]
                return frame

            if stop is None:
                stop = in_fmf.get_n_frames() - 1
            if start is None:
                start = 0
            n_frames = stop - start + 1
            n_samples = max(30, n_frames)
            frame, timestamp0 = in_fmf.get_next_frame()
            cumsum = np.zeros(to_3d(frame).shape, dtype=np.float32)
            for fno in np.linspace(start, stop, n_samples):
                in_fmf.seek(int(round(fno)))
                frame, timestamp = in_fmf.get_next_frame()
                cumsum += to_3d(frame)

            frame = cumsum / n_samples

            if output_format == 'MONO32f':
                # drop dimension
                assert frame.shape[2] == 1
                frame = frame[:, :, 0]
            out_fmf.add_frame(frame, timestamp0)
            out_fmf.close()
        except BaseException:
            # don't leave a partially written output movie behind
            os.unlink(output_fname)
            raise
    finally:
        in_fmf.close()
예제 #11
0
def main():
    """Command-line entry point: save frames of an .fmf movie as images.

    Parses the command line and delegates the work to ``fmf2images``,
    which performs the ``.fmf`` extension check, the output naming and the
    per-frame conversion.  Prints the help text and returns if no movie
    file is given; exits with a usage error if ``--interval`` is below 1.
    """
    usage = """%prog FILE [options]

Example:

fmf2bmps myvideo.fmf --start=10 --stop=100 --extension=jpg --outdir=tmp
"""

    parser = OptionParser(usage)
    parser.add_option('--start', type='int', default=0, help='first frame to save')
    parser.add_option('--stop', type='int', default=-1, help='last frame to save')
    parser.add_option('--interval', type='int', default=1, help='save every Nth frame')
    parser.add_option('--extension', type='string', default='bmp',
                      help='image extension (default: bmp)')
    parser.add_option('--outdir', type='string', default=None,
                      help='directory to save images (default: same as fmf)')
    parser.add_option('--progress', action='store_true', default=False,
                      help='show progress bar')
    parser.add_option('--prefix', default=None, type='str',
                      help='prefix for image filenames')
    (options, args) = parser.parse_args()

    if len(args) < 1:
        parser.print_help()
        return

    if options.interval < 1:
        # parser.error prints the usage and exits with status 2; unlike an
        # assert it is not stripped under "python -O".
        parser.error('--interval must be >= 1')

    fmf2images(args[0],
               imgformat=options.extension,
               startframe=options.start,
               endframe=options.stop,
               interval=options.interval,
               prefix=options.prefix,
               outdir=options.outdir,
               progress=options.progress)
예제 #12
0
def doit(input_fname, single_channel=False, start=None, stop=None, gain=1.0, offset=0.0, blur=0.0):
    """Contrast-stretch a movie and save it as ``<name>.highcontrast.fmf``.

    Pass 1 samples ``max(30, n_frames)`` evenly spaced frame positions in
    [start, stop].  For each channel it records the 5th and 95th percentile
    of every sampled frame.  The channel's input window runs from the
    smallest 5th percentile to the largest 95th percentile.

    Pass 2 maps every frame in [start, stop] through
    ``(v - center) / width * (127.5 * gain) + (127.5 + offset)``, where
    ``center`` and ``width`` are the window's midpoint and width.  It then
    optionally applies a per-channel Gaussian blur with sigma ``blur``,
    clips to [0, 255] and stores the result as 8-bit frames.

    Parameters
    ----------
    input_fname : str
        Input .fmf movie; the output replaces its extension with
        ``.highcontrast.fmf``.
    single_channel : str or None/False
        For color input, the channel (``'red'``, ``'green'`` or ``'blue'``)
        to keep, giving MONO8 output.  A false value (default) keeps all
        channels.  Ignored, with a warning, for non-color input.
    start, stop : int, optional
        First and last (inclusive) frame numbers; default to the whole movie.
    gain, offset : float
        Output scale and shift (see above).
    blur : float
        Gaussian sigma in pixels; 0 disables blurring.
    """
    if blur != 0:
        import scipy.ndimage

    output_fname = os.path.splitext(input_fname)[0] + ".highcontrast.fmf"
    in_fmf = FMF.FlyMovie(input_fname)
    try:
        input_format = in_fmf.get_format()
        input_is_color = imops.is_coding_color(input_format)

        # Test truthiness: the default False must mean "all channels"; an
        # "is not None" test would treat it as a channel name.
        if single_channel:
            if not input_is_color:
                warnings.warn("ignoring --single-channel option for non-color input")
                single_channel = False
            output_format = "MONO8"
        elif input_is_color:
            output_format = "RGB8"
        else:
            output_format = "MONO8"
        out_fmf = FMF.FlyMovieSaver(output_fname, version=3, format=output_format)
        try:
            if input_is_color:
                channels = [("red", 0), ("green", 1), ("blue", 2)]
            else:
                channels = [("gray", 0)]
            channel_dict = dict(channels)

            def read_frame():
                """Read the next frame as a (rows, cols, channels) array."""
                frame, timestamp = in_fmf.get_next_frame()
                if input_is_color:
                    frame = imops.to_rgb8(input_format, frame)
                else:
                    frame = np.atleast_3d(frame)
                return frame, timestamp

            if stop is None:
                stop = in_fmf.get_n_frames() - 1
            if start is None:
                start = 0

            # pass 1 - get 5,95 percentiles
            minvs = collections.defaultdict(list)
            maxvs = collections.defaultdict(list)
            n_frames = stop - start + 1
            n_samples = max(30, n_frames)
            for fno in np.linspace(start, stop, n_samples):
                in_fmf.seek(int(round(fno)))
                frame, timestamp = read_frame()
                for channel_name, channel_idx in channels:
                    minv, maxv = prctile(frame[:, :, channel_idx].ravel(), p=(5.0, 95.0))
                    minvs[channel_name].append(minv)
                    maxvs[channel_name].append(maxv)

            orig_center = {}
            orig_range = {}
            for channel_name, channel_idx in channels:
                minv = np.min(minvs[channel_name])
                maxv = np.max(maxvs[channel_name])
                orig_center[channel_name] = (minv + maxv) / 2.0
                # a constant channel would otherwise divide by zero
                orig_range[channel_name] = (maxv - minv) or 1.0

            new_center = 127.5 + offset
            new_range = 127.5 * gain

            # pass 2 - rescale and save every frame in [start, stop]
            in_fmf.seek(start)
            for fno in range(start, stop + 1):
                frame, timestamp = read_frame()

                if single_channel:
                    # input is, by definition, color
                    frame = frame[:, :, channel_dict[single_channel]]  # drop all but single_channel dim
                    frame = frame.astype(np.float32)
                    frame = (frame - orig_center[single_channel]) / orig_range[single_channel]
                    frame = frame * new_range + new_center
                    frame = np.atleast_3d(frame)  # add dim
                else:
                    frame = frame.astype(np.float32)
                    for channel_name, channel_idx in channels:
                        frame[:, :, channel_idx] = (
                            frame[:, :, channel_idx] - orig_center[channel_name]
                        ) / orig_range[channel_name]
                        frame[:, :, channel_idx] = frame[:, :, channel_idx] * new_range + new_center

                if blur != 0:
                    for chan in range(frame.shape[2]):
                        # filter each channel independently
                        frame[:, :, chan] = scipy.ndimage.gaussian_filter(frame[:, :, chan], blur)
                frame = np.clip(frame, 0, 255).astype(np.uint8)
                if output_format == "MONO8":
                    # drop dimension
                    assert frame.shape[2] == 1
                    frame = frame[:, :, 0]
                out_fmf.add_frame(frame, timestamp)
            out_fmf.close()
        except BaseException:
            # don't leave a partially written output movie behind
            os.unlink(output_fname)
            raise
    finally:
        in_fmf.close()
예제 #13
0
                            # should only be one .ufmf with this frame and cam_id
                            assert found is None

                            fmf_fno = fmf_fnos[0]
                            found = (fmf, fmf_fno )
                        if found is None:
                            print 'no image data for frame timestamp %s cam_id %s'%(
                                repr(frame_timestamp),cam_id)
                            continue
                        fmf, fmf_fno = found
                        image, fmf_timestamp = fmf.get_frame( fmf_fno )
                        mean_image = fmf.get_mean_for_timestamp(fmf_timestamp)
                        coding = fmf.get_format()
                        if imops.is_coding_color(coding):
                            image = imops.to_rgb8(coding,image)
                            mean_image = imops.to_rgb8(coding,mean_image)
                        else:
                            image = imops.to_mono8(coding,image)
                            mean_image = imops.to_mono8(coding,mean_image)

                        xy = (int(round(frame2d[idx]['x'])),
                              int(round(frame2d[idx]['y'])))
                        maxsize = (fmf.get_width(), fmf.get_height())

                        # Accumulate cropped images. Note that the region
                        # of the full image that the cropped image
                        # occupies changes over time as the tracked object
                        # moves. Thus, averaging these cropped-and-shifted
                        # images is not the same as simply averaging the
                        # full frame.
예제 #14
0
def doit(
    movie_fname=None,
    reconstructor_fname=None,
    h5_fname=None,
    cam_id=None,
    dest_dir=None,
    transform=None,
    start=None,
    stop=None,
    h5start=None,
    h5stop=None,
    show_obj_ids=False,
    obj_only=None,
    image_format=None,
    subtract_frame=None,
    save_framelist_fname=None,
):
    """Draw reprojected 3D tracking data over movie frames and save images.

    The loop covers movie frames ``start``..``stop``.  A frame is skipped
    unless its h5 frame number (from
    ``extra['time_model'].timestamp2framestamp``) lies within
    [h5start, h5stop].  Each remaining frame is processed as follows:

    1. Convert it to RGB8 (color movies) or MONO8.
    2. Optionally subtract the first frame of ``subtract_frame``.
    3. Draw it on a ``benu.Canvas``.
    4. For every object id with data at that frame, draw a circle at the
       reprojection of its 3D position into ``cam_id``, plus a line to the
       reprojection of the same x, y at z=0.
    5. Save the image.

    Parameters (selection)
    ----------------------
    dest_dir : str, optional
        Directory joined in front of the output name (default ``os.curdir``).
    image_format : str, optional
        Output image extension (default ``'png'``).
    subtract_frame : str, optional
        .fmf movie whose first frame is subtracted from every image.
    save_framelist_fname : str, optional
        If given, the rounded h5 frame number of every rendered image is
        written to this file, one per line.

    Raises
    ------
    NotImplementedError
        If ``movie_fname`` or ``cam_id`` is None, or ``subtract_frame`` does
        not end in ``.fmf``.
    """

    if dest_dir is None:
        dest_dir = os.curdir

    if movie_fname is None:
        raise NotImplementedError('')

    if image_format is None:
        image_format = 'png'

    if cam_id is None:
        raise NotImplementedError('')

    if movie_fname.lower().endswith('.fmf'):
        movie = fmf_mod.FlyMovie(movie_fname)
    else:
        movie = ufmf_mod.FlyMovieEmulator(movie_fname)

    if start is None:
        start = 0

    if stop is None:
        stop = movie.get_n_frames() - 1

    ca = core_analysis.get_global_CachingAnalyzer()
    (obj_ids, unique_obj_ids, is_mat_file, data_file, extra) = \
        ca.initial_file_load(h5_fname)
    if obj_only is not None:
        unique_obj_ids = obj_only

    dynamic_model_name = extra['dynamic_model_name']
    if dynamic_model_name.startswith('EKF'):
        dynamic_model_name = dynamic_model_name[4:]

    if reconstructor_fname is None:
        reconstructor = flydra_core.reconstruct.Reconstructor(data_file)
    else:
        reconstructor = flydra_core.reconstruct.Reconstructor(
            reconstructor_fname)

    fix_w = movie.get_width()
    fix_h = movie.get_height()
    is_color = imops.is_coding_color(movie.get_format())

    if subtract_frame is not None:
        if not subtract_frame.endswith('.fmf'):
            raise NotImplementedError(
                'only fmf supported for --subtract-frame')
        tmp_fmf = fmf_mod.FlyMovie(subtract_frame)
        try:
            tmp_frame, tmp_timestamp = tmp_fmf.get_next_frame()
            tmp_format = tmp_fmf.get_format()
        finally:
            tmp_fmf.close()
        if is_color:
            subtract_frame = imops.to_rgb8(tmp_format, tmp_frame)
        else:
            subtract_frame = imops.to_mono8(tmp_format, tmp_frame)
        subtract_frame = subtract_frame.astype(
            np.float32)  # force upconversion to float

    save_framelist_fd = None
    if save_framelist_fname is not None:
        save_framelist_fd = open(save_framelist_fname, mode='w')
    try:
        for movie_fno in range(start, stop + 1):
            movie.seek(movie_fno)
            image, timestamp = movie.get_next_frame()
            h5_frame = extra['time_model'].timestamp2framestamp(timestamp)
            if h5start is not None and h5_frame < h5start:
                continue
            if h5stop is not None and h5_frame > h5stop:
                continue
            if is_color:
                image = imops.to_rgb8(movie.get_format(), image)
            else:
                image = imops.to_mono8(movie.get_format(), image)
            if subtract_frame is not None:
                new_image = np.clip(image - subtract_frame, 0, 255)
                image = new_image.astype(np.uint8)
            warnings.warn('not implemented: interpolating data')
            h5_frame = int(round(h5_frame))
            if save_framelist_fd is not None:
                save_framelist_fd.write('%d\n' % h5_frame)

            # Output images are numbered by their frame number in the movie.
            # NOTE(review): if movie_fname is an absolute path, os.path.join
            # discards dest_dir and the image is written next to the movie.
            save_fname_path = os.path.splitext(
                movie_fname)[0] + '_frame%06d.%s' % (movie_fno, image_format)
            save_fname_path = os.path.join(dest_dir, save_fname_path)
            if transform in ['rot 90', 'rot -90']:
                device_rect = (0, 0, fix_h, fix_w)
                canv = benu.Canvas(save_fname_path, fix_h, fix_w)
            else:
                device_rect = (0, 0, fix_w, fix_h)
                canv = benu.Canvas(save_fname_path, fix_w, fix_h)
            user_rect = (0, 0, image.shape[1], image.shape[0])
            show_points = []
            with canv.set_user_coords(device_rect, user_rect,
                                      transform=transform):
                canv.imshow(image, 0, 0)
                for obj_id in unique_obj_ids:
                    try:
                        data = ca.load_data(
                            obj_id,
                            data_file,
                            frames_per_second=extra['frames_per_second'],
                            dynamic_model_name=dynamic_model_name,
                        )
                    except core_analysis.NotEnoughDataToSmoothError:
                        continue
                    idxs = np.nonzero(data['frame'] == h5_frame)[0]
                    if not len(idxs):
                        continue  # no data at this frame for this obj_id
                    assert len(idxs) == 1
                    row = data[idxs[0]]

                    # circle over data point
                    xyz = row['x'], row['y'], row['z']
                    x2d, y2d = reconstructor.find2d(cam_id, xyz, distorted=True)
                    radius = 10
                    canv.scatter([x2d], [y2d],
                                 color_rgba=green,
                                 markeredgewidth=3,
                                 radius=radius)

                    # z line to XY plane through origin
                    xyz0 = row['x'], row['y'], 0
                    x2d_z0, y2d_z0 = reconstructor.find2d(cam_id,
                                                          xyz0,
                                                          distorted=True)
                    warnings.warn('not distorting Z line')
                    # Start the line at the circle's edge rather than its
                    # center; checking radius first avoids dividing by a
                    # zero distance.
                    xdist = x2d - x2d_z0
                    ydist = y2d - y2d_z0
                    dist = np.sqrt(xdist**2 + ydist**2)
                    start_frac = 0 if radius > dist else radius / dist
                    x2d_r = x2d - xdist * start_frac
                    y2d_r = y2d - ydist * start_frac
                    canv.plot([x2d_r, x2d_z0], [y2d_r, y2d_z0],
                              color_rgba=green,
                              linewidth=3)
                    if show_obj_ids:
                        show_points.append((obj_id, x2d, y2d))
            for obj_id, x2d, y2d in show_points:
                x, y = canv.get_transformed_point(x2d,
                                                  y2d,
                                                  device_rect,
                                                  user_rect,
                                                  transform=transform)
                canv.text(
                    'obj_id %d' % obj_id,
                    x,
                    y,
                    color_rgba=(0, 1, 0, 1),
                    font_size=20,
                )
            canv.save()
    finally:
        # flush and close the frame list even if rendering fails
        if save_framelist_fd is not None:
            save_framelist_fd.close()
예제 #15
0
    def update_image_and_drawings(self,
                                  id_val,
                                  image,
                                  format=None,
                                  points=None,
                                  linesegs=None,
                                  point_colors=None,
                                  point_radii=None,
                                  lineseg_colors=None,
                                  lineseg_widths=None,
                                  xoffset=0,
                                  yoffset=0,
                                  doresize=None):
        """update the displayed image

        **Arguments**

        id_val : string
            An identifier for the particular source being updated
        image : numpy array
            The image data to update

        **Optional keyword arguments**

        format : string
            The image format (e.g. 'MONO8', 'RGB8', or 'YUV422'); defaults
            to 'MONO8' with a warning
        points : list of points
            Points to display (e.g. [(x0,y0),(x1,y1)])
        linesegs : list of line segments
            Line segments to display (e.g. [(x0,y0,x1,y1),(x1,y1,x2,y2)]).
            A segment with more than four values is drawn as a polyline
            (x0,y0,x1,y1,x2,y2,...).
        point_colors, lineseg_colors : list of (r, g, b)
            Colors with components in [0, 1]; default green.
        point_radii : list of int
            Circle radii in pixels; default 8.
        lineseg_widths : list of int
            Pen widths; default 1.
        xoffset, yoffset : int
            Where an image smaller than the stored full image is pasted.
        doresize : bool
            Scale image and drawings to fit the window; defaults to
            ``self.doresize``.

        Raises NotImplementedError if called with a second, different
        ``id_val``. The result is stored in ``self.bitmap``.
        """

        # create bitmap, don't paint on screen
        if points is None:
            points = []
        if linesegs is None:
            linesegs = []
        if format is None:
            format = 'MONO8'
            warnings.warn('format unspecified - assuming MONO8')

        # if doresize is not input, then use the default value
        if doresize is None:
            doresize = self.doresize

        rgb8 = imops.to_rgb8(format, image)

        if doresize:
            from scipy.misc.pilutil import imresize

            # how much should we resize the image
            windowwidth = self.GetRect().GetWidth()
            windowheight = self.GetRect().GetHeight()
            imageheight, imagewidth = rgb8.shape[:2]
            resizew = float(windowwidth) / float(imagewidth)
            resizeh = float(windowheight) / float(imageheight)
            self.resize = min(resizew, resizeh)
            # resize the image
            rgb8 = imresize(rgb8, self.resize)
            # scale all the points and lines
            points = [[pt[0] * self.resize, pt[1] * self.resize]
                      for pt in points]
            # scale every coordinate so polyline segments keep all vertices
            linesegs = [[v * self.resize for v in line] for line in linesegs]

        if self.id_val is None:
            self.id_val = id_val
        if id_val != self.id_val:
            raise NotImplementedError(
                "only 1 image source currently supported")

        h, w, three = rgb8.shape
        # get full image
        if self.full_image_numpy is not None:
            full_h, full_w, tmp = self.full_image_numpy.shape
            if h < full_h or w < full_w:
                # paste the partial image into the stored full frame
                self.full_image_numpy[yoffset:yoffset + h,
                                      xoffset:xoffset + w, :] = rgb8
                rgb8 = self.full_image_numpy
                h, w = full_h, full_w
        else:
            self.full_image_numpy = rgb8

        image = wx.EmptyImage(w, h)

        # XXX TODO could eliminate data copy here?
        image.SetData(rgb8.tostring())
        bmp = wx.BitmapFromImage(image)

        # now draw into bmp
        drawDC = wx.MemoryDC()
        drawDC.SelectObject(bmp)  # draw into bmp
        drawDC.SetBrush(wx.Brush(wx.Colour(255, 255, 255), wx.TRANSPARENT))

        if self.do_draw_points:
            if len(points) > 0:
                if point_radii is None:
                    point_radii = [8] * len(points)
                if point_colors is None:
                    point_colors = [(0, 1, 0)] * len(points)
            if len(linesegs) > 0:
                if lineseg_widths is None:
                    lineseg_widths = [1] * len(linesegs)
                if lineseg_colors is None:
                    lineseg_colors = [(0, 1, 0)] * len(linesegs)

            for i, pt in enumerate(points):
                ptcolor = point_colors[i]
                wxptcolor = wx.Colour(round(ptcolor[0] * 255),
                                      round(ptcolor[1] * 255),
                                      round(ptcolor[2] * 255))
                drawDC.SetPen(wx.Pen(colour=wxptcolor, width=1))
                drawDC.DrawCircle(int(pt[0]), int(pt[1]), point_radii[i])

            for i, lineseg in enumerate(linesegs):
                linesegcolor = lineseg_colors[i]
                wxlinesegcolor = wx.Colour(round(linesegcolor[0] * 255),
                                           round(linesegcolor[1] * 255),
                                           round(linesegcolor[2] * 255))
                drawDC.SetPen(wx.Pen(colour=wxlinesegcolor,
                                     width=lineseg_widths[i]))
                # wx drawing calls take integer pixel coordinates (resized
                # segments contain floats)
                coords = [int(v) for v in lineseg]
                if len(coords) <= 4:
                    drawDC.DrawLine(*coords)
                else:
                    for start_idx in range(0, len(coords) - 3, 2):
                        drawDC.DrawLine(*coords[start_idx:start_idx + 4])

        if id_val in self.lbrt:
            drawDC.SetPen(wx.Pen('GREEN', width=1))
            l, b, r, t = self.lbrt[id_val]
            drawDC.DrawLine(l, b, r, b)
            drawDC.DrawLine(r, b, r, t)
            drawDC.DrawLine(r, t, l, t)
            drawDC.DrawLine(l, t, l, b)

        # A bitmap must be selected out of a wx.MemoryDC before it is used
        # for anything else (here: converted to a wx.Image).
        drawDC.SelectObject(wx.NullBitmap)

        img = wx.ImageFromBitmap(bmp)
        # The image is mirrored when mirror_display is False and rotated by
        # 180 degrees (two 90-degree turns) when display_rotate_180 is
        # False, mirror first.
        if not self.mirror_display:
            img = img.Mirror(True)
        if not self.display_rotate_180:
            img = img.Rotate90()
            img = img.Rotate90()
        bmp = wx.BitmapFromImage(img)

        self.bitmap = bmp
예제 #16
0
def doit(
    h5_filename=None,
    output_h5_filename=None,
    ufmf_filenames=None,
    kalman_filename=None,
    start=None,
    stop=None,
    view=None,
    erode=0,
    save_images=False,
    save_image_dir=None,
    intermediate_thresh_frac=None,
    final_thresh=None,
    stack_N_images=None,
    stack_N_images_min=None,
    old_sync_timestamp_source=False,
    do_rts_smoothing=True,
):
    """

    Copy all data in .h5 file (specified by h5_filename) to a new .h5
    file in which orientations are set based on image analysis of
    .ufmf files. Tracking data to associate 2D points from subsequent
    frames is read from the .h5 kalman file specified by
    kalman_filename.

    """
    if view is None:
        view = ["orig" for f in ufmf_filenames]
    else:
        assert len(view) == len(ufmf_filenames)

    if intermediate_thresh_frac is None or final_thresh is None:
        raise ValueError("intermediate_thresh_frac and final_thresh must be "
                         "set")

    filename2view = dict(zip(ufmf_filenames, view))

    ca = core_analysis.get_global_CachingAnalyzer()
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        kalman_filename)
    try:
        ML_estimates_2d_idxs = data_file.root.ML_estimates_2d_idxs[:]
    except tables.exceptions.NoSuchNodeError as err1:
        # backwards compatibility
        try:
            ML_estimates_2d_idxs = data_file.root.kalman_observations_2d_idxs[:]
        except tables.exceptions.NoSuchNodeError as err2:
            raise err1

    if os.path.exists(output_h5_filename):
        raise RuntimeError("will not overwrite old file '%s'" %
                           output_h5_filename)
    with open_file_safe(output_h5_filename, delete_on_error=True,
                        mode="w") as output_h5:
        if save_image_dir is not None:
            if not os.path.exists(save_image_dir):
                os.mkdir(save_image_dir)

        with open_file_safe(h5_filename, mode="r") as h5:

            fps = result_utils.get_fps(h5, fail_on_error=True)

            for input_node in h5.root._f_iter_nodes():
                # copy everything from source to dest
                input_node._f_copy(output_h5.root, recursive=True)
            print("done copying")

            # Clear values in destination table that we may overwrite.
            dest_table = output_h5.root.data2d_distorted
            for colname in [
                    "x",
                    "y",
                    "area",
                    "slope",
                    "eccentricity",
                    "cur_val",
                    "mean_val",
                    "sumsqf_val",
            ]:
                if colname == "cur_val":
                    fill_value = 0
                else:
                    fill_value = np.nan
                clear_col(dest_table, colname, fill_value=fill_value)
            dest_table.flush()
            print("done clearing")

            camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)

            cam_id2fmfs = collections.defaultdict(list)
            cam_id2view = {}
            for ufmf_filename in ufmf_filenames:
                fmf = ufmf.FlyMovieEmulator(
                    ufmf_filename,
                    # darken=-50,
                    allow_no_such_frame_errors=True,
                )
                timestamps = fmf.get_all_timestamps()

                cam_id = get_cam_id_from_filename(fmf.filename,
                                                  cam_id2camns.keys())
                cam_id2fmfs[cam_id].append(
                    (fmf, result_utils.Quick1DIndexer(timestamps)))

                cam_id2view[cam_id] = filename2view[fmf.filename]

            # associate framenumbers with timestamps using 2d .h5 file
            data2d = h5.root.data2d_distorted[:]  # load to RAM
            data2d_idxs = np.arange(len(data2d))
            h5_framenumbers = data2d["frame"]
            h5_frame_qfi = result_utils.QuickFrameIndexer(h5_framenumbers)

            fpc = realtime_image_analysis.FitParamsClass(
            )  # allocate FitParamsClass

            for obj_id_enum, obj_id in enumerate(use_obj_ids):
                print("object %d of %d" % (obj_id_enum, len(use_obj_ids)))

                # get all images for this camera and this obj_id

                obj_3d_rows = ca.load_dynamics_free_MLE_position(
                    obj_id, data_file)

                this_obj_framenumbers = collections.defaultdict(list)
                if save_images:
                    this_obj_raw_images = collections.defaultdict(list)
                    this_obj_mean_images = collections.defaultdict(list)
                this_obj_absdiff_images = collections.defaultdict(list)
                this_obj_morphed_images = collections.defaultdict(list)
                this_obj_morph_failures = collections.defaultdict(list)
                this_obj_im_coords = collections.defaultdict(list)
                this_obj_com_coords = collections.defaultdict(list)
                this_obj_camn_pt_no = collections.defaultdict(list)

                for this_3d_row in obj_3d_rows:
                    # iterate over each sample in the current camera
                    framenumber = this_3d_row["frame"]
                    if start is not None:
                        if not framenumber >= start:
                            continue
                    if stop is not None:
                        if not framenumber <= stop:
                            continue
                    h5_2d_row_idxs = h5_frame_qfi.get_frame_idxs(framenumber)

                    frame2d = data2d[h5_2d_row_idxs]
                    frame2d_idxs = data2d_idxs[h5_2d_row_idxs]

                    obs_2d_idx = this_3d_row["obs_2d_idx"]
                    kobs_2d_data = ML_estimates_2d_idxs[int(obs_2d_idx)]

                    # Parse VLArray.
                    this_camns = kobs_2d_data[0::2]
                    this_camn_idxs = kobs_2d_data[1::2]

                    # Now, for each camera viewing this object at this
                    # frame, extract images.
                    for camn, camn_pt_no in zip(this_camns, this_camn_idxs):

                        # find 2D point corresponding to object
                        cam_id = camn2cam_id[camn]

                        movie_tups_for_this_camn = cam_id2fmfs[cam_id]
                        cond = (frame2d["camn"] == camn) & (
                            frame2d["frame_pt_idx"] == camn_pt_no)
                        idxs = np.nonzero(cond)[0]
                        assert len(idxs) == 1
                        idx = idxs[0]

                        orig_data2d_rownum = frame2d_idxs[idx]

                        if not old_sync_timestamp_source:
                            # Change the next line to 'timestamp' for old
                            # data (before May/June 2009 -- the switch to
                            # fview_ext_trig)
                            frame_timestamp = frame2d[idx][
                                "cam_received_timestamp"]
                        else:
                            # previous version
                            frame_timestamp = frame2d[idx]["timestamp"]
                        found = None
                        for fmf, fmf_timestamp_qi in movie_tups_for_this_camn:
                            fmf_fnos = fmf_timestamp_qi.get_idxs(
                                frame_timestamp)
                            if not len(fmf_fnos):
                                continue
                            assert len(fmf_fnos) == 1

                            # should only be one .ufmf with this frame and cam_id
                            assert found is None

                            fmf_fno = fmf_fnos[0]
                            found = (fmf, fmf_fno)
                        if found is None:
                            print(
                                "no image data for frame timestamp %s cam_id %s"
                                % (repr(frame_timestamp), cam_id))
                            continue
                        fmf, fmf_fno = found
                        image, fmf_timestamp = fmf.get_frame(fmf_fno)
                        mean_image = fmf.get_mean_for_timestamp(fmf_timestamp)
                        coding = fmf.get_format()
                        if imops.is_coding_color(coding):
                            image = imops.to_rgb8(coding, image)
                            mean_image = imops.to_rgb8(coding, mean_image)
                        else:
                            image = imops.to_mono8(coding, image)
                            mean_image = imops.to_mono8(coding, mean_image)

                        xy = (
                            int(round(frame2d[idx]["x"])),
                            int(round(frame2d[idx]["y"])),
                        )
                        maxsize = (fmf.get_width(), fmf.get_height())

                        # Accumulate cropped images. Note that the region
                        # of the full image that the cropped image
                        # occupies changes over time as the tracked object
                        # moves. Thus, averaging these cropped-and-shifted
                        # images is not the same as simply averaging the
                        # full frame.

                        roiradius = 25
                        warnings.warn(
                            "roiradius hard-coded to %d: could be set "
                            "from 3D tracking" % roiradius)
                        tmp = clip_and_math(image, mean_image, xy, roiradius,
                                            maxsize)
                        im_coords, raw_im, mean_im, absdiff_im = tmp

                        max_absdiff_im = absdiff_im.max()
                        intermediate_thresh = intermediate_thresh_frac * max_absdiff_im
                        absdiff_im[absdiff_im <= intermediate_thresh] = 0

                        if erode > 0:
                            morphed_im = scipy.ndimage.grey_erosion(absdiff_im,
                                                                    size=erode)
                            ## morphed_im = scipy.ndimage.binary_erosion(absdiff_im>1).astype(np.float32)*255.0
                        else:
                            morphed_im = absdiff_im

                        y0_roi, x0_roi = scipy.ndimage.center_of_mass(
                            morphed_im)
                        x0 = im_coords[0] + x0_roi
                        y0 = im_coords[1] + y0_roi

                        if 1:
                            morphed_im_binary = morphed_im > 0
                            labels, n_labels = scipy.ndimage.label(
                                morphed_im_binary)
                            morph_fail_because_multiple_blobs = False

                            if n_labels > 1:
                                x0, y0 = np.nan, np.nan
                                # More than one blob -- don't allow image.
                                if 1:
                                    # for min flattening
                                    morphed_im = np.empty(morphed_im.shape,
                                                          dtype=np.uint8)
                                    morphed_im.fill(255)
                                    morph_fail_because_multiple_blobs = True
                                else:
                                    # for mean flattening
                                    morphed_im = np.zeros_like(morphed_im)
                                    morph_fail_because_multiple_blobs = True

                        this_obj_framenumbers[camn].append(framenumber)
                        if save_images:
                            this_obj_raw_images[camn].append(
                                (raw_im, im_coords))
                            this_obj_mean_images[camn].append(mean_im)
                        this_obj_absdiff_images[camn].append(absdiff_im)
                        this_obj_morphed_images[camn].append(morphed_im)
                        this_obj_morph_failures[camn].append(
                            morph_fail_because_multiple_blobs)
                        this_obj_im_coords[camn].append(im_coords)
                        this_obj_com_coords[camn].append((x0, y0))
                        this_obj_camn_pt_no[camn].append(orig_data2d_rownum)
                        if 0:
                            fname = "obj%05d_%s_frame%07d_pt%02d.png" % (
                                obj_id,
                                cam_id,
                                framenumber,
                                camn_pt_no,
                            )
                            plot_image_subregion(
                                raw_im,
                                mean_im,
                                absdiff_im,
                                roiradius,
                                fname,
                                im_coords,
                                view=filename2view[fmf.filename],
                            )

                # Now, all the frames from all cameras for this obj_id
                # have been gathered. Do a camera-by-camera analysis.
                for camn in this_obj_absdiff_images:
                    cam_id = camn2cam_id[camn]
                    image_framenumbers = np.array(this_obj_framenumbers[camn])
                    if save_images:
                        raw_images = this_obj_raw_images[camn]
                        mean_images = this_obj_mean_images[camn]
                    absdiff_images = this_obj_absdiff_images[camn]
                    morphed_images = this_obj_morphed_images[camn]
                    morph_failures = np.array(this_obj_morph_failures[camn])
                    im_coords = this_obj_im_coords[camn]
                    com_coords = this_obj_com_coords[camn]
                    camn_pt_no_array = this_obj_camn_pt_no[camn]

                    all_framenumbers = np.arange(image_framenumbers[0],
                                                 image_framenumbers[-1] + 1)

                    com_coords = np.array(com_coords)
                    if do_rts_smoothing:
                        # Perform RTS smoothing on center-of-mass coordinates.

                        # Find first good datum.
                        fgnz = np.nonzero(~np.isnan(com_coords[:, 0]))
                        com_coords_smooth = np.empty(com_coords.shape,
                                                     dtype=np.float)
                        com_coords_smooth.fill(np.nan)

                        if len(fgnz[0]):
                            first_good = fgnz[0][0]

                            RTS_com_coords = com_coords[first_good:, :]

                            # Setup parameters for Kalman filter.
                            dt = 1.0 / fps
                            A = np.array(
                                [
                                    [1, 0, dt, 0],  # process update
                                    [0, 1, 0, dt],
                                    [0, 0, 1, 0],
                                    [0, 0, 0, 1],
                                ],
                                dtype=np.float,
                            )
                            C = np.array(
                                [[1, 0, 0, 0], [0, 1, 0, 0]
                                 ],  # observation matrix
                                dtype=np.float,
                            )
                            Q = 0.1 * np.eye(4)  # process noise
                            R = 1.0 * np.eye(2)  # observation noise
                            initx = np.array(
                                [
                                    RTS_com_coords[0, 0], RTS_com_coords[0, 1],
                                    0, 0
                                ],
                                dtype=np.float,
                            )
                            initV = 2 * np.eye(4)
                            initV[0, 0] = 0.1
                            initV[1, 1] = 0.1
                            y = RTS_com_coords
                            xsmooth, Vsmooth = adskalman.adskalman.kalman_smoother(
                                y, A, C, Q, R, initx, initV)
                            com_coords_smooth[first_good:] = xsmooth[:, :2]

                        # Now shift images

                        image_shift = com_coords_smooth - com_coords
                        bad_cond = np.isnan(image_shift[:, 0])
                        # broadcast zeros to places where no good tracking
                        image_shift[bad_cond, 0] = 0
                        image_shift[bad_cond, 1] = 0

                        shifted_morphed_images = [
                            shift_image(im, xy)
                            for im, xy in zip(morphed_images, image_shift)
                        ]

                        results = flatten_image_stack(
                            image_framenumbers,
                            shifted_morphed_images,
                            im_coords,
                            camn_pt_no_array,
                            N=stack_N_images,
                        )
                    else:
                        results = flatten_image_stack(
                            image_framenumbers,
                            morphed_images,
                            im_coords,
                            camn_pt_no_array,
                            N=stack_N_images,
                        )

                    # The variable fno (the first element of the results
                    # tuple) is guaranteed to be contiguous and to span
                    # the range from the first to last frames available.

                    for (
                            fno,
                            av_im,
                            lowerleft,
                            orig_data2d_rownum,
                            orig_idx,
                            orig_idxs_in_average,
                    ) in results:

                        # Clip image to reduce moment arms.
                        av_im[av_im <= final_thresh] = 0

                        fail_fit = False
                        fast_av_im = FastImage.asfastimage(
                            av_im.astype(np.uint8))
                        try:
                            (x0_roi, y0_roi, area, slope,
                             eccentricity) = fpc.fit(fast_av_im)
                        except realtime_image_analysis.FitParamsError as err:
                            fail_fit = True

                        this_morph_failures = morph_failures[
                            orig_idxs_in_average]
                        n_failed_images = np.sum(this_morph_failures)
                        n_good_images = stack_N_images - n_failed_images
                        if n_good_images >= stack_N_images_min:
                            n_images_is_acceptable = True
                        else:
                            n_images_is_acceptable = False

                        if fail_fit:
                            x0_roi = np.nan
                            y0_roi = np.nan
                            area, slope, eccentricity = np.nan, np.nan, np.nan

                        if not n_images_is_acceptable:
                            x0_roi = np.nan
                            y0_roi = np.nan
                            area, slope, eccentricity = np.nan, np.nan, np.nan

                        x0 = x0_roi + lowerleft[0]
                        y0 = y0_roi + lowerleft[1]

                        if 1:
                            for row in dest_table.iterrows(
                                    start=orig_data2d_rownum,
                                    stop=orig_data2d_rownum + 1):

                                row["x"] = x0
                                row["y"] = y0
                                row["area"] = area
                                row["slope"] = slope
                                row["eccentricity"] = eccentricity
                                row.update()  # save data

                        if save_images:
                            # Display debugging images
                            fname = "av_obj%05d_%s_frame%07d.png" % (
                                obj_id,
                                cam_id,
                                fno,
                            )
                            if save_image_dir is not None:
                                fname = os.path.join(save_image_dir, fname)

                            raw_im, raw_coords = raw_images[orig_idx]
                            mean_im = mean_images[orig_idx]
                            absdiff_im = absdiff_images[orig_idx]
                            morphed_im = morphed_images[orig_idx]
                            raw_l, raw_b = raw_coords[:2]

                            imh, imw = raw_im.shape[:2]
                            n_ims = 5

                            if 1:
                                # increase contrast
                                contrast_scale = 2.0
                                av_im_show = np.clip(av_im * contrast_scale, 0,
                                                     255)

                            margin = 10
                            scale = 3

                            # calculate the orientation line
                            yintercept = y0 - slope * x0
                            xplt = np.array([
                                lowerleft[0] - 5,
                                lowerleft[0] + av_im_show.shape[1] + 5,
                            ])
                            yplt = slope * xplt + yintercept
                            if 1:
                                # only send non-nan values to plot
                                plt_good = ~np.isnan(xplt) & ~np.isnan(yplt)
                                xplt = xplt[plt_good]
                                yplt = yplt[plt_good]

                            top_row_width = scale * imw * n_ims + (
                                1 + n_ims) * margin
                            SHOW_STACK = True
                            if SHOW_STACK:
                                n_stack_rows = 4
                                rw = scale * imw * stack_N_images + (
                                    1 + n_ims) * margin
                                row_width = max(top_row_width, rw)
                                col_height = (n_stack_rows * scale * imh +
                                              (n_stack_rows + 1) * margin)
                                stack_margin = 20
                            else:
                                row_width = top_row_width
                                col_height = scale * imh + 2 * margin
                                stack_margin = 0

                            canv = benu.Canvas(
                                fname,
                                row_width,
                                col_height + stack_margin,
                                color_rgba=(1, 1, 1, 1),
                            )

                            if SHOW_STACK:
                                for (stacki, s_orig_idx
                                     ) in enumerate(orig_idxs_in_average):

                                    (s_raw_im,
                                     s_raw_coords) = raw_images[s_orig_idx]
                                    s_raw_l, s_raw_b = s_raw_coords[:2]
                                    s_imh, s_imw = s_raw_im.shape[:2]
                                    user_rect = (s_raw_l, s_raw_b, s_imw,
                                                 s_imh)

                                    x_display = (stacki + 1) * margin + (
                                        scale * imw) * stacki
                                    for show in ["raw", "absdiff", "morphed"]:
                                        if show == "raw":
                                            y_display = scale * imh + 2 * margin
                                        elif show == "absdiff":
                                            y_display = 2 * scale * imh + 3 * margin
                                        elif show == "morphed":
                                            y_display = 3 * scale * imh + 4 * margin
                                        display_rect = (
                                            x_display,
                                            y_display + stack_margin,
                                            scale * raw_im.shape[1],
                                            scale * raw_im.shape[0],
                                        )

                                        with canv.set_user_coords(
                                                display_rect,
                                                user_rect,
                                                transform=cam_id2view[cam_id],
                                        ):

                                            if show == "raw":
                                                s_im = s_raw_im.astype(
                                                    np.uint8)
                                            elif show == "absdiff":
                                                tmp = absdiff_images[
                                                    s_orig_idx]
                                                s_im = tmp.astype(np.uint8)
                                            elif show == "morphed":
                                                tmp = morphed_images[
                                                    s_orig_idx]
                                                s_im = tmp.astype(np.uint8)

                                            canv.imshow(s_im, s_raw_l, s_raw_b)
                                            sx0, sy0 = com_coords[s_orig_idx]
                                            X = [sx0]
                                            Y = [sy0]
                                            # the raw coords in red
                                            canv.scatter(X,
                                                         Y,
                                                         color_rgba=(1, 0.5,
                                                                     0.5, 1))

                                            if do_rts_smoothing:
                                                sx0, sy0 = com_coords_smooth[
                                                    s_orig_idx]
                                                X = [sx0]
                                                Y = [sy0]
                                                # the RTS smoothed coords in green
                                                canv.scatter(
                                                    X,
                                                    Y,
                                                    color_rgba=(0.5, 1, 0.5,
                                                                1))

                                            if s_orig_idx == orig_idx:
                                                boxx = np.array([
                                                    s_raw_l,
                                                    s_raw_l,
                                                    s_raw_l + s_imw,
                                                    s_raw_l + s_imw,
                                                    s_raw_l,
                                                ])
                                                boxy = np.array([
                                                    s_raw_b,
                                                    s_raw_b + s_imh,
                                                    s_raw_b + s_imh,
                                                    s_raw_b,
                                                    s_raw_b,
                                                ])
                                                canv.plot(
                                                    boxx,
                                                    boxy,
                                                    color_rgba=(0.5, 1, 0.5,
                                                                1),
                                                )
                                        if show == "morphed":
                                            canv.text(
                                                "morphed %d" %
                                                (s_orig_idx - orig_idx, ),
                                                display_rect[0],
                                                (display_rect[1] +
                                                 display_rect[3] +
                                                 stack_margin - 20),
                                                font_size=font_size,
                                                color_rgba=(1, 0, 0, 1),
                                            )

                            # Display raw_im
                            display_rect = (
                                margin,
                                margin,
                                scale * raw_im.shape[1],
                                scale * raw_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(raw_im.astype(np.uint8), raw_l,
                                            raw_b)
                                canv.plot(
                                    xplt, yplt,
                                    color_rgba=(0, 1, 0,
                                                0.5))  # the orientation line
                            canv.text(
                                "raw",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            # Display mean_im
                            display_rect = (
                                2 * margin + (scale * imw),
                                margin,
                                scale * mean_im.shape[1],
                                scale * mean_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(mean_im.astype(np.uint8), raw_l,
                                            raw_b)
                            canv.text(
                                "mean",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            # Display absdiff_im
                            display_rect = (
                                3 * margin + (scale * imw) * 2,
                                margin,
                                scale * absdiff_im.shape[1],
                                scale * absdiff_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            absdiff_clip = np.clip(absdiff_im * contrast_scale,
                                                   0, 255)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(absdiff_clip.astype(np.uint8),
                                            raw_l, raw_b)
                            canv.text(
                                "absdiff",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            # Display morphed_im
                            display_rect = (
                                4 * margin + (scale * imw) * 3,
                                margin,
                                scale * morphed_im.shape[1],
                                scale * morphed_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            morphed_clip = np.clip(morphed_im * contrast_scale,
                                                   0, 255)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(morphed_clip.astype(np.uint8),
                                            raw_l, raw_b)
                            if 0:
                                canv.text(
                                    "morphed",
                                    display_rect[0],
                                    display_rect[1] + display_rect[3],
                                    font_size=font_size,
                                    color_rgba=(0.5, 0.5, 0.9, 1),
                                    shadow_offset=1,
                                )

                            # Display time-averaged absdiff_im
                            display_rect = (
                                5 * margin + (scale * imw) * 4,
                                margin,
                                scale * av_im_show.shape[1],
                                scale * av_im_show.shape[0],
                            )
                            user_rect = (
                                lowerleft[0],
                                lowerleft[1],
                                av_im_show.shape[1],
                                av_im_show.shape[0],
                            )
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(
                                    av_im_show.astype(np.uint8),
                                    lowerleft[0],
                                    lowerleft[1],
                                )
                                canv.plot(
                                    xplt, yplt,
                                    color_rgba=(0, 1, 0,
                                                0.5))  # the orientation line
                            canv.text(
                                "stacked/flattened",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            canv.text(
                                "%s frame % 7d: eccentricity % 5.1f, min N images %d, actual N images %d"
                                % (
                                    cam_id,
                                    fno,
                                    eccentricity,
                                    stack_N_images_min,
                                    n_good_images,
                                ),
                                0,
                                15,
                                font_size=font_size,
                                color_rgba=(0.6, 0.7, 0.9, 1),
                                shadow_offset=1,
                            )
                            canv.save()

                # Save results to new table
                if 0:
                    recarray = np.rec.array(list_of_rows_of_data2d,
                                            dtype=Info2DCol_description)
                    dest_table.append(recarray)
                    dest_table.flush()
            dest_table.attrs.has_ibo_data = True
        data_file.close()
예제 #17
0
파일: wxvideo.py 프로젝트: motmot/wxvideo
    def update_image_and_drawings(self,
                                  id_val,
                                  image,
                                  format=None,
                                  points=None,
                                  linesegs=None,
                                  point_colors=None,
                                  point_radii=None,
                                  lineseg_colors=None,
                                  lineseg_widths=None,
                                  xoffset=0,
                                  yoffset=0,
                                  doresize=None):
        """update the displayed image

        Converts *image* to RGB8, optionally rescales it (and the overlay
        coordinates) to fit the window, pastes it into the cached full-size
        frame when it is a smaller sub-region, draws the point / line-segment
        overlays and the optional ``self.lbrt`` rectangle, applies the
        mirror/rotation display settings and stores the result in
        ``self.bitmap``.

        **Arguments**

        id_val : string
            An identifier for the particular source being updated. Only one
            source is supported: an id different from the first one seen
            raises NotImplementedError.
        image : numpy array
            The image data to update

        **Optional keyword arguments**

        format : string
            The image format (e.g. 'MONO8', 'RGB8', or 'YUV422'). When
            omitted, 'MONO8' is assumed and a warning is issued.
        points : list of points
            Points to display (e.g. [(x0,y0),(x1,y1)])
        linesegs : list of line segments
            Line segments to display (e.g. [(x0,y0,x1,y1),(x1,y1,x2,y2)]).
            A segment with more than four coordinates is drawn as a polyline
            (x0,y0,x1,y1,x2,y2,...).
        point_colors, lineseg_colors : list of (r, g, b)
            Per-item colours with components as fractions of 1 (scaled by
            255 for wx). Default: green.
        point_radii : list of numbers
            Circle radius per point. Default: 8.
        lineseg_widths : list of numbers
            Pen width per line segment. Default: 5.
        xoffset, yoffset : int
            Where *image* is pasted into the cached full-size frame when it
            is smaller than that frame.
        doresize : bool or None
            Scale the image to fit the window; defaults to ``self.doresize``.
        """

        # create bitmap, don't paint on screen
        if points is None:
            points = []
        if linesegs is None:
            linesegs = []
        if format is None:
            format = 'MONO8'
            warnings.warn('format unspecified - assuming MONO8')

        # if doresize is not input, then use the default value
        if doresize is None:
            doresize = self.doresize

        rgb8 = imops.to_rgb8(format, image)

        if doresize:
            from scipy.misc.pilutil import imresize

            # how much should we resize the image: the largest uniform scale
            # at which the whole image still fits inside the window
            windowwidth = self.GetRect().GetWidth()
            windowheight = self.GetRect().GetHeight()
            imagewidth = rgb8.shape[1]
            imageheight = rgb8.shape[0]
            resizew = float(windowwidth) / float(imagewidth)
            resizeh = float(windowheight) / float(imageheight)
            self.resize = min(resizew, resizeh)
            # resize the image
            rgb8 = imresize(rgb8, self.resize)
            # Scale the overlays by the same factor. Every coordinate of a
            # line segment is scaled (not only the first four) so polylines,
            # which are drawn below, keep all of their vertices.
            points = [[pt[0] * self.resize, pt[1] * self.resize]
                      for pt in points]
            linesegs = [[coord * self.resize for coord in line]
                        for line in linesegs]

        if self.id_val is None:
            self.id_val = id_val
        if id_val != self.id_val:
            raise NotImplementedError("only 1 image source currently supported")

        h, w, _ = rgb8.shape
        # get full image
        if self.full_image_numpy is not None:
            full_h, full_w, _ = self.full_image_numpy.shape
            if h < full_h or w < full_w:
                # a sub-region arrived: paste it into the cached full frame
                # at (xoffset, yoffset) and display the whole frame
                self.full_image_numpy[yoffset:yoffset+h, xoffset:xoffset+w, :] = rgb8
                rgb8 = self.full_image_numpy
                h, w = full_h, full_w
        else:
            self.full_image_numpy = rgb8

        image = wx.EmptyImage(w, h)

        # XXX TODO could eliminate data copy here?
        # (tobytes() is the non-deprecated spelling of tostring())
        image.SetData(rgb8.tobytes())
        bmp = wx.BitmapFromImage(image)

        def as_wx_colour(rgb):
            # overlay colours are (r, g, b) fractions; wx.Colour wants 0-255
            return wx.Colour(round(rgb[0] * 255),
                             round(rgb[1] * 255),
                             round(rgb[2] * 255))

        # now draw into bmp
        drawDC = wx.MemoryDC()
        #assert drawDC.Ok(), "drawDC not OK"
        drawDC.SelectObject(bmp)  # draw into bmp
        try:
            drawDC.SetBrush(wx.Brush(wx.Colour(255, 255, 255), wx.TRANSPARENT))

            if self.do_draw_points and len(points) > 0:
                if point_radii is None:
                    point_radii = [8] * len(points)
                if point_colors is None:
                    point_colors = [(0, 1, 0)] * len(points)
            if self.do_draw_points and len(linesegs) > 0:
                if lineseg_widths is None:
                    lineseg_widths = [5] * len(linesegs)
                if lineseg_colors is None:
                    lineseg_colors = [(0, 1, 0)] * len(linesegs)

            if self.do_draw_points:
                for i, pt in enumerate(points):
                    drawDC.SetPen(wx.Pen(colour=as_wx_colour(point_colors[i]),
                                         width=1))
                    drawDC.DrawCircle(int(pt[0]), int(pt[1]), point_radii[i])

                for i, lineseg in enumerate(linesegs):
                    drawDC.SetPen(wx.Pen(colour=as_wx_colour(lineseg_colors[i]),
                                         width=lineseg_widths[i]))
                    if len(lineseg) <= 4:
                        drawDC.DrawLine(*lineseg)
                    else:
                        # polyline x0,y0,x1,y1,x2,y2,...: draw one line per
                        # pair of consecutive vertices
                        for start_idx in range(0, len(lineseg) - 3, 2):
                            drawDC.DrawLine(*lineseg[start_idx:start_idx + 4])

            if id_val in self.lbrt:
                # outline the rectangle whose edges are given by self.lbrt
                drawDC.SetPen(wx.Pen('GREEN', width=1))
                l, b, r, t = self.lbrt[id_val]
                drawDC.DrawLine(l, b, r, b)
                drawDC.DrawLine(r, b, r, t)
                drawDC.DrawLine(r, t, l, t)
                drawDC.DrawLine(l, t, l, b)
        finally:
            # The bitmap must be selected out of the memory DC before its
            # data is used elsewhere (wx.ImageFromBitmap below); otherwise
            # the drawing may not be flushed into the bitmap.
            drawDC.SelectObject(wx.NullBitmap)

        img = wx.ImageFromBitmap(bmp)
        if self.mirror_display:
            if not self.display_rotate_180:
                img = img.Rotate90()
                img = img.Rotate90()
        else:
            img = img.Mirror(True)
            if not self.display_rotate_180:
                img = img.Rotate90()
                img = img.Rotate90()
        bmp = wx.BitmapFromImage(img)

        self.bitmap = bmp
예제 #18
0
    def update_image_and_drawings(self,
                                  id_val,
                                  image,
                                  format=None,
                                  points=None,
                                  linesegs=None,
                                  point_colors=None,
                                  point_radii=None,
                                  lineseg_colors=None,
                                  lineseg_widths=None,
                                  xoffset=0,
                                  yoffset=0,
                                  doresize=None):
        """Update the displayed image and draw overlays onto it.

        Converts *image* to RGB8 using *format*, optionally rescales it (and
        the overlay coordinates) to fit the window, pastes it into the cached
        full-size frame when it is a smaller sub-region, draws points, line
        segments and the optional ``self.lbrt`` rectangle, applies the
        mirror/rotation display settings and stores the result in
        ``self.bitmap``.

        id_val : identifier of the image source. Only one source is
            supported; a different id than the first raises
            NotImplementedError.
        image : numpy array with the image data.
        format : image format string passed to ``imops.to_rgb8``; required
            (ValueError when None).
        points : list of (x, y).
        linesegs : list of (x0, y0, x1, y1).
        point_colors, lineseg_colors : lists of (r, g, b) fractions of 1;
            default green.
        point_radii : circle radius per point (also used as the pen width);
            default 8.
        lineseg_widths : pen width per segment; default 1.
        xoffset, yoffset : paste position of *image* inside the cached
            full-size frame.
        doresize : scale to the window; defaults to ``self.doresize``.
        """

        # create bitmap, don't paint on screen
        if points is None:
            points = []
        if linesegs is None:
            linesegs = []
        if format is None:
            raise ValueError("must specify format")

        # if doresize is not input, then use the default value
        if doresize is None:
            doresize = self.doresize

        rgb8 = imops.to_rgb8(format, image)

        if doresize:
            # how much should we resize the image: the largest uniform scale
            # at which the whole image still fits inside the window
            windowwidth = self.GetRect().GetWidth()
            windowheight = self.GetRect().GetHeight()
            imagewidth = rgb8.shape[1]
            imageheight = rgb8.shape[0]
            resizew = float(windowwidth) / float(imagewidth)
            resizeh = float(windowheight) / float(imageheight)
            self.resize = min(resizew, resizeh)
            # resize the image
            rgb8 = imresize(rgb8, self.resize)
            # scale all the points and lines (segments are x0,y0,x1,y1)
            points = [[pt[0] * self.resize, pt[1] * self.resize]
                      for pt in points]
            linesegs = [[line[0] * self.resize, line[1] * self.resize,
                         line[2] * self.resize, line[3] * self.resize]
                        for line in linesegs]

        if self.id_val is None:
            self.id_val = id_val
        if id_val != self.id_val:
            raise NotImplementedError("only 1 image source currently supported")

        h, w, _ = rgb8.shape
        # get full image
        if self.full_image_numpy is not None:
            full_h, full_w, _ = self.full_image_numpy.shape
            if h < full_h or w < full_w:
                # a sub-region arrived: paste it into the cached full frame
                # at (xoffset, yoffset) and display the whole frame
                self.full_image_numpy[yoffset:yoffset+h, xoffset:xoffset+w, :] = rgb8
                rgb8 = self.full_image_numpy
                h, w = full_h, full_w
        else:
            self.full_image_numpy = rgb8

        image = wx.EmptyImage(w, h)

        # XXX TODO could eliminate data copy here?
        # (tobytes() is the non-deprecated spelling of tostring())
        image.SetData(rgb8.tobytes())
        bmp = wx.BitmapFromImage(image)

        def as_wx_colour(rgb):
            # overlay colours are (r, g, b) fractions; wx.Colour wants 0-255
            return wx.Colour(round(rgb[0] * 255),
                             round(rgb[1] * 255),
                             round(rgb[2] * 255))

        # now draw into bmp
        drawDC = wx.MemoryDC()
        #assert drawDC.Ok(), "drawDC not OK"
        drawDC.SelectObject(bmp)  # draw into bmp
        try:
            drawDC.SetBrush(wx.Brush(wx.Colour(255, 255, 255), wx.TRANSPARENT))

            if self.do_draw_points and len(points) > 0:
                if point_radii is None:
                    point_radii = [8] * len(points)
                if point_colors is None:
                    point_colors = [(0, 1, 0)] * len(points)
            if self.do_draw_points and len(linesegs) > 0:
                if lineseg_widths is None:
                    lineseg_widths = [1] * len(linesegs)
                if lineseg_colors is None:
                    lineseg_colors = [(0, 1, 0)] * len(linesegs)

            if self.do_draw_points:
                for i, pt in enumerate(points):
                    ptradius = point_radii[i]
                    # the radius doubles as the pen width in this version
                    drawDC.SetPen(wx.Pen(colour=as_wx_colour(point_colors[i]),
                                         width=ptradius))
                    drawDC.DrawCircle(int(pt[0]), int(pt[1]), ptradius)

                for i, lineseg in enumerate(linesegs):
                    drawDC.SetPen(wx.Pen(colour=as_wx_colour(lineseg_colors[i]),
                                         width=lineseg_widths[i]))
                    drawDC.DrawLine(*lineseg)

            if id_val in self.lbrt:
                # outline the rectangle whose edges are given by self.lbrt
                drawDC.SetPen(wx.Pen('GREEN', width=1))
                l, b, r, t = self.lbrt[id_val]
                drawDC.DrawLine(l, b, r, b)
                drawDC.DrawLine(r, b, r, t)
                drawDC.DrawLine(r, t, l, t)
                drawDC.DrawLine(l, t, l, b)
        finally:
            # The bitmap must be selected out of the memory DC before its
            # data is used elsewhere (wx.ImageFromBitmap below); otherwise
            # the drawing may not be flushed into the bitmap.
            drawDC.SelectObject(wx.NullBitmap)

        img = wx.ImageFromBitmap(bmp)
        if self.mirror_display:
            if not self.display_rotate_180:
                img = img.Rotate90()
                img = img.Rotate90()
        else:
            img = img.Mirror(True)
            if not self.display_rotate_180:
                img = img.Rotate90()
                img = img.Rotate90()
        bmp = wx.BitmapFromImage(img)

        self.bitmap = bmp
예제 #19
0
def _version_tuple(version):
    """Return the leading numeric components of a dotted version string.

    Versions must be compared numerically, not as strings: string comparison
    orders '0.4.100' before '0.4.45' and '0.4.5' after it. Parsing stops at
    the first component that does not start with a digit, and after a
    component that has a non-numeric suffix:

    '0.4.45' -> (0, 4, 45); '0.4.100.dev' -> (0, 4, 100); '0.5rc1' -> (0, 5)
    """
    import re
    parts = []
    for component in str(version).split('.'):
        match = re.match(r'[0-9]+', component)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(component):
            break
    return tuple(parts)


def iterate_frames(h5_filename,
                   ufmf_fnames, # or fmfs
                   white_background=False,
                   max_n_frames = None,
                   start = None,
                   stop = None,
                   rgb8_if_color=False,
                   movie_cam_ids=None,
                   camn2cam_id = None,
                   ):
    """yield frame-by-frame data

    Opens every movie in *ufmf_fnames* (``.fmf`` files via ``fmf_mod``,
    anything else via ``ufmf_mod.FlyMovieEmulator``) and restricts the rows
    of ``h5.root.data2d_distorted`` in *h5_filename* to the time span that
    all movies share. For each frame number it yields
    ``(per_frame_dict, frame)``.

    ``per_frame_dict`` maps each movie filename that has .h5 data in that
    frame to a dict with keys 'image', 'cam_id', 'camn', 'timestamp',
    'cam_received_timestamp' and 'ufmf_frame_timestamp', updated with the
    reader's extra ("more") data when available. It also holds
    'tracker_data' (the .h5 rows of this frame) and 'global_data'
    (``{'width_heights': {cam_id: (width, height)}}``).

    Parameters
    ----------
    h5_filename : str
        The .h5 tracking data file.
    ufmf_fnames : sequence of str
        Movie filenames. The sequence is not modified; movies are processed
        in sorted filename order.
    white_background : bool
        Passed to ``ufmf_mod.FlyMovieEmulator``.
    max_n_frames : int or None
        Yield at most this many frames.
    start, stop : int or None
        Inclusive frame-number limits.
    rgb8_if_color : bool
        Convert color-coded images to RGB8; otherwise a warning is issued.
    movie_cam_ids : sequence of str or None
        cam_id for each entry of *ufmf_fnames*. When None, it is derived
        from each filename.
    camn2cam_id : dict or None
        camn -> cam_id mapping. When None, it is read from the .h5 file.
    """

    # First pass over .ufmf files: get intersection of timestamps
    first_ufmf_ts = -np.inf
    last_ufmf_ts = np.inf
    ufmfs = {}
    cam_ids = []
    global_data = {'width_heights': {}}
    for movie_idx, ufmf_fname in enumerate(ufmf_fnames):
        if movie_cam_ids is not None:
            cam_id = movie_cam_ids[movie_idx]
        else:
            cam_id = get_cam_id_from_ufmf_fname(ufmf_fname)
        cam_ids.append(cam_id)
        extra = {}
        if ufmf_fname.lower().endswith('.fmf'):
            ufmf = fmf_mod.FlyMovie(ufmf_fname)
            # optional companion "<name>_mean.fmf", kept as the 'bg_fmf'
            # movie together with its timestamps
            bg_fmf_filename = os.path.splitext(ufmf_fname)[0] + '_mean.fmf'
            if os.path.exists(bg_fmf_filename):
                extra['bg_fmf'] = fmf_mod.FlyMovie(bg_fmf_filename)
                extra['bg_tss'] = extra['bg_fmf'].get_all_timestamps()
                extra['bg_fmf'].seek(0)
        else:
            ufmf = ufmf_mod.FlyMovieEmulator(ufmf_fname,
                                             white_background=white_background)

        global_data['width_heights'][cam_id] = (ufmf.get_width(), ufmf.get_height())
        tss = ufmf.get_all_timestamps()
        ufmf.seek(0)
        ufmfs[ufmf_fname] = (ufmf, cam_id, tss, extra)
        # keep the intersection of all movies' time spans
        first_ufmf_ts = max(first_ufmf_ts, np.min(tss))
        last_ufmf_ts = min(last_ufmf_ts, np.max(tss))

    assert first_ufmf_ts < last_ufmf_ts, ".ufmf files don't all overlap in time"

    # Sort a local copy: the caller's sequence must not be reordered as a
    # side effect (and tuples have no .sort()).
    ufmf_fnames = sorted(ufmf_fnames)
    cam_ids.sort()

    with open_file_safe(h5_filename, mode='r') as h5:
        if camn2cam_id is None:
            camn2cam_id, _ = result_utils.get_caminfo_dicts(h5)
        parsed = result_utils.read_textlog_header(h5)
        flydra_version = parsed.get('flydra_version', None)
        if (flydra_version is not None and
                _version_tuple(flydra_version) >= (0, 4, 45)):
            # camnode.py saved timestamps into .ufmf file given by
            # time.time() (camn_receive_timestamp). Compare with
            # mainbrain's data2d_distorted column
            # 'cam_received_timestamp'.
            old_camera_timestamp_source = False
            timestamp_name = 'cam_received_timestamp'
        else:
            # camnode.py saved timestamps into .ufmf file given by
            # camera driver. Compare with mainbrain's data2d_distorted
            # column 'timestamp'.
            old_camera_timestamp_source = True
            timestamp_name = 'timestamp'

        h5_data = h5.root.data2d_distorted[:]

    # narrow search to local region of .h5
    cond = ((first_ufmf_ts <= h5_data[timestamp_name]) &
            (h5_data[timestamp_name] <= last_ufmf_ts))
    narrow_h5_data = h5_data[cond]

    narrow_camns = narrow_h5_data['camn']
    narrow_timestamps = narrow_h5_data[timestamp_name]

    # Find the camn for each .ufmf file
    cam_id2camn = {}
    for cam_id in cam_ids:
        cam_id_camn_already_found = False
        for _ufmf, test_cam_id, tss, _extra in ufmfs.values():
            if cam_id != test_cam_id:
                continue
            assert not cam_id_camn_already_found
            cam_id_camn_already_found = True

            umin = np.min(tss)
            umax = np.max(tss)
            cond = (umin <= narrow_timestamps) & (narrow_timestamps <= umax)
            ucamns = np.unique(narrow_camns[cond])
            camns = [camn for camn in ucamns if camn2cam_id[camn] == cam_id]

            assert len(camns) < 2, "can't handle multiple camns per cam_id"
            if camns:
                cam_id2camn[cam_id] = camns[0]

    ff = utils.FastFinder(narrow_h5_data['frame'])
    # np.unique already returns its values sorted
    unique_frames = np.unique(narrow_h5_data['frame'])
    if start is not None:
        unique_frames = unique_frames[unique_frames >= start]
    if stop is not None:
        unique_frames = unique_frames[unique_frames <= stop]

    if max_n_frames is not None:
        unique_frames = unique_frames[:max_n_frames]
    for frame in unique_frames:
        narrow_idxs = ff.get_idxs_of_equal(frame)

        # trim data under consideration to just this frame
        this_h5_data = narrow_h5_data[narrow_idxs]
        this_camns = this_h5_data['camn']
        this_tss = this_h5_data[timestamp_name]

        # a couple more checks
        if np.any(this_tss < first_ufmf_ts):
            continue
        if np.any(this_tss >= last_ufmf_ts):
            break

        per_frame_dict = {}
        for ufmf_fname in ufmf_fnames:
            ufmf, cam_id, tss, extra = ufmfs[ufmf_fname]
            if cam_id not in cam_id2camn:
                continue
            camn = cam_id2camn[cam_id]
            this_cam_h5_data = this_h5_data[this_camns == camn]
            this_camn_tss = this_cam_h5_data[timestamp_name]
            if not len(this_camn_tss):
                # no h5 data for this cam_id at this frame
                continue
            this_camn_ts = np.unique(this_camn_tss)
            assert len(this_camn_ts) == 1
            this_camn_ts = this_camn_ts[0]

            is_real_ufmf = isinstance(ufmf, ufmf_mod.FlyMovieEmulator)

            # optimistic: get next frame. it's probably the one we want
            more = None
            try:
                if is_real_ufmf:
                    image, image_ts, more = ufmf.get_next_frame(_return_more=True)
                else:
                    image, image_ts = ufmf.get_next_frame()
                    more = fill_more_for(extra, image_ts)
            except ufmf_mod.NoMoreFramesException:
                image_ts = None
            if this_camn_ts != image_ts:
                # It was not the frame we wanted. Find it.
                ufmf_frame_idxs = np.nonzero(tss == this_camn_ts)[0]
                if (len(ufmf_frame_idxs) == 0 and
                        old_camera_timestamp_source):
                    warnings.warn(
                        'low-precision timestamp comparison in '
                        'use due to outdated .ufmf timestamp '
                        'saving.')
                    # 2.5 msec precision required
                    ufmf_frame_idxs = np.nonzero(
                        abs(tss - this_camn_ts) < 0.0025)[0]
                assert len(ufmf_frame_idxs) == 1
                ufmf_frame_no = ufmf_frame_idxs[0]
                if is_real_ufmf:
                    image, image_ts, more = ufmf.get_frame(ufmf_frame_no,
                                                           _return_more=True)
                else:
                    image, image_ts = ufmf.get_frame(ufmf_frame_no)
                    more = fill_more_for(extra, image_ts)

                del ufmf_frame_no, ufmf_frame_idxs
            coding = ufmf.get_format()
            if imops.is_coding_color(coding):
                if rgb8_if_color:
                    image = imops.to_rgb8(coding, image)
                else:
                    warnings.warn('color image not converted to color')
            per_frame_dict[ufmf_fname] = {
                'image': image,
                'cam_id': cam_id,
                'camn': camn,
                'timestamp': this_cam_h5_data['timestamp'][0],
                'cam_received_timestamp':
                this_cam_h5_data['cam_received_timestamp'][0],
                'ufmf_frame_timestamp': this_cam_h5_data[timestamp_name][0],
                }
            if more is not None:
                per_frame_dict[ufmf_fname].update(more)
        per_frame_dict['tracker_data'] = this_h5_data
        per_frame_dict['global_data'] = global_data  # on every iteration, pass our global data
        yield (per_frame_dict, frame)
예제 #20
0
def _flydra_version_key(version):
    """Return a tuple of ints that orders flydra version strings numerically.

    Each dot-separated component contributes the integer value of its
    leading run of digits (0 if it has none), so ``'0.4.100'`` sorts after
    ``'0.4.45'``, which a plain string comparison gets wrong.
    """
    if isinstance(version, bytes):
        version = version.decode('ascii', 'replace')
    key = []
    for part in version.split('.'):
        digits = ''
        for char in part:
            if not char.isdigit():
                break
            digits += char
        key.append(int(digits) if digits else 0)
    return tuple(key)


def iterate_frames(h5_filename,
                   ufmf_fnames, # or fmfs
                   white_background=False,
                   max_n_frames = None,
                   start = None,
                   stop = None,
                   rgb8_if_color=False,
                   movie_cam_ids=None,
                   camn2cam_id = None,
                   ):
    """yield frame-by-frame data

    Pair rows of the ``data2d_distorted`` table in *h5_filename* with the
    images stored in the movie files *ufmf_fnames*. A movie whose name ends
    in ``.fmf`` is opened with ``fmf_mod.FlyMovie``. Any other movie is
    opened with ``ufmf_mod.FlyMovieEmulator``.

    Parameters
    ----------
    h5_filename : str
        flydra .h5 file providing ``data2d_distorted`` and the textlog header.
    ufmf_fnames : sequence of str
        Movie files (.ufmf or .fmf).
    white_background : bool
        Passed to ``ufmf_mod.FlyMovieEmulator``. Not used for .fmf files.
    max_n_frames : int or None
        If given, only the first *max_n_frames* candidate frame numbers are
        considered. The limit is applied after *start*/*stop*, and candidate
        frames skipped below still count toward it.
    start, stop : int or None
        Inclusive bounds on the frame numbers considered.
    rgb8_if_color : bool
        Convert color-coded images with ``imops.to_rgb8``. Otherwise a
        warning is issued and the image is yielded unconverted.
    movie_cam_ids : sequence of str or None
        cam_id for each entry of *ufmf_fnames*. When None, the cam_id is
        derived from each filename by ``get_cam_id_from_ufmf_fname``.
    camn2cam_id : mapping or None
        camn -> cam_id lookup. When None it is read from the .h5 file.

    Yields
    ------
    (per_frame_dict, frame)
        *frame* is a frame number from ``data2d_distorted``.
        *per_frame_dict* maps each movie filename that has data at this
        frame to a dict. That dict has the keys ``'image'``, ``'cam_id'``,
        ``'camn'``, ``'timestamp'``, ``'cam_received_timestamp'`` and
        ``'ufmf_frame_timestamp'``, plus any extra entries the .ufmf reader
        returns. The key ``'tracker_data'`` holds the ``data2d_distorted``
        rows for the frame.

    Frames are skipped while any row's timestamp precedes the movies' common
    time span. Iteration stops at the first frame where any row reaches the
    end of that span.

    Raises
    ------
    AssertionError
        On ambiguous matches: two movies for one cam_id, several camns for
        one cam_id, or no unique movie frame for a timestamp.
    """
    # First pass over the movie files: open each one and compute the
    # intersection of their timestamp ranges.
    first_ufmf_ts = -np.inf
    last_ufmf_ts = np.inf
    ufmfs = {}
    cam_ids = []
    for movie_idx, ufmf_fname in enumerate(ufmf_fnames):
        if movie_cam_ids is not None:
            cam_id = movie_cam_ids[movie_idx]
        else:
            cam_id = get_cam_id_from_ufmf_fname(ufmf_fname)
        cam_ids.append(cam_id)
        if ufmf_fname.lower().endswith('.fmf'):
            ufmf = fmf_mod.FlyMovie(ufmf_fname)
        else:
            ufmf = ufmf_mod.FlyMovieEmulator(ufmf_fname,
                                             white_background=white_background)
        tss = ufmf.get_all_timestamps()
        ufmfs[ufmf_fname] = (ufmf, cam_id, tss)
        min_ts = np.min(tss)
        max_ts = np.max(tss)
        if min_ts > first_ufmf_ts:
            first_ufmf_ts = min_ts
        if max_ts < last_ufmf_ts:
            last_ufmf_ts = max_ts

    with openFileSafe( h5_filename, mode='r' ) as h5:
        if camn2cam_id is None:
            camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)
        parsed = result_utils.read_textlog_header(h5)
        flydra_version = parsed.get('flydra_version',None)
        # Compare versions numerically: a string comparison would order
        # e.g. '0.4.100' before '0.4.45'.
        if (flydra_version is not None and
                _flydra_version_key(flydra_version) >=
                _flydra_version_key('0.4.45')):
            # camnode.py saved timestamps into .ufmf file given by
            # time.time() (camn_receive_timestamp). Compare with
            # mainbrain's data2d_distorted column
            # 'cam_received_timestamp'.
            old_camera_timestamp_source = False
            timestamp_name = 'cam_received_timestamp'
        else:
            # camnode.py saved timestamps into .ufmf file given by
            # camera driver. Compare with mainbrain's data2d_distorted
            # column 'timestamp'.
            old_camera_timestamp_source = True
            timestamp_name = 'timestamp'

        h5_data = h5.root.data2d_distorted[:]

    # All rows of data2d_distorted are searched; there is no narrowing to
    # the movies' common time window.
    narrow_h5_data = h5_data
    narrow_camns = narrow_h5_data['camn']
    narrow_timestamps = narrow_h5_data[timestamp_name]

    # Find the camn for each movie's cam_id: the camn (of that cam_id) that
    # has 2D data inside the movie's time span.
    cam_id2camn = {}
    for cam_id in cam_ids:
        cam_id_camn_already_found = False
        for ufmf, test_cam_id, tss in ufmfs.values():
            if cam_id != test_cam_id:
                continue
            assert not cam_id_camn_already_found
            cam_id_camn_already_found = True

            umin = np.min(tss)
            umax = np.max(tss)
            cond = (umin <= narrow_timestamps) & (narrow_timestamps <= umax)
            camns = [camn for camn in np.unique(narrow_camns[cond])
                     if camn2cam_id[camn] == cam_id]

            assert len(camns)<2, "can't handle multiple camns per cam_id"
            if camns:
                cam_id2camn[cam_id] = camns[0]

    ff = utils.FastFinder(narrow_h5_data['frame'])
    # np.unique returns its result already sorted.
    unique_frames = np.unique(narrow_h5_data['frame'])
    if start is not None:
        unique_frames = unique_frames[ unique_frames >= start ]
    if stop is not None:
        unique_frames = unique_frames[ unique_frames <= stop ]
    if max_n_frames is not None:
        unique_frames = unique_frames[:max_n_frames]

    for frame in unique_frames:
        narrow_idxs = ff.get_idxs_of_equal(frame)

        # trim data under consideration to just this frame
        this_h5_data = narrow_h5_data[narrow_idxs]
        this_camns = this_h5_data['camn']
        this_tss = this_h5_data[timestamp_name]

        # a couple more checks
        if np.any( this_tss < first_ufmf_ts):
            continue
        if np.any( this_tss >= last_ufmf_ts):
            break

        per_frame_dict = {}
        for ufmf_fname in ufmf_fnames:
            ufmf, cam_id, tss = ufmfs[ufmf_fname]
            if cam_id not in cam_id2camn:
                continue
            # .fmf readers return (image, timestamp) only and do not accept
            # the _return_more keyword used with .ufmf readers.
            is_real_ufmf = not ufmf_fname.lower().endswith('.fmf')
            camn = cam_id2camn[cam_id]
            this_cam_h5_data = this_h5_data[this_camns == camn]
            this_camn_tss = this_cam_h5_data[timestamp_name]
            if not len(this_camn_tss):
                # no h5 data for this cam_id at this frame
                continue
            this_camn_ts = np.unique(this_camn_tss)
            assert len(this_camn_ts)==1
            this_camn_ts = this_camn_ts[0]

            more = None
            # optimistic: get next frame. it's probably the one we want
            try:
                if is_real_ufmf:
                    image,image_ts,more = ufmf.get_next_frame(
                        _return_more=True)
                else:
                    image,image_ts = ufmf.get_next_frame()
            except (ufmf_mod.NoMoreFramesException,
                    fmf_mod.NoMoreFramesException):
                image_ts = None
            if this_camn_ts != image_ts:
                # It was not the frame we wanted. Find it.
                ufmf_frame_idxs = np.nonzero(tss == this_camn_ts)[0]
                if (len(ufmf_frame_idxs)==0 and
                    old_camera_timestamp_source):
                    warnings.warn(
                        'low-precision timestamp comparison in '
                        'use due to outdated .ufmf timestamp '
                        'saving.')
                    # 2.5 msec precision required
                    ufmf_frame_idxs = np.nonzero(
                        abs( tss - this_camn_ts ) < 0.0025)[0]
                assert len(ufmf_frame_idxs)==1
                ufmf_frame_no = ufmf_frame_idxs[0]
                if is_real_ufmf:
                    image,image_ts,more = ufmf.get_frame(ufmf_frame_no,
                                                         _return_more=True)
                else:
                    image,image_ts = ufmf.get_frame(ufmf_frame_no)
                del ufmf_frame_no, ufmf_frame_idxs
            coding = ufmf.get_format()
            if imops.is_coding_color(coding):
                if rgb8_if_color:
                    image = imops.to_rgb8(coding,image)
                else:
                    warnings.warn('color image not converted to color')
            per_frame_dict[ufmf_fname] = {
                'image':image,
                'cam_id':cam_id,
                'camn':camn,
                'timestamp':this_cam_h5_data['timestamp'][0],
                'cam_received_timestamp':
                this_cam_h5_data['cam_received_timestamp'][0],
                'ufmf_frame_timestamp':this_cam_h5_data[timestamp_name][0],
                }
            if more is not None:
                per_frame_dict[ufmf_fname].update(more)
        per_frame_dict['tracker_data']=this_h5_data
        yield (per_frame_dict,frame)