Example #1
    def __init__(self, src_h5, dst_h5):
        self.ca = core_analysis.get_global_CachingAnalyzer()
        self.src_h5 = src_h5
        self.dst_h5 = dst_h5
        self.ca.initial_file_load(self.src_h5)

        tmp = self.ca.initial_file_load(self.dst_h5)
        obj_ids, unique_obj_ids, is_mat_file, data_file, extra = tmp

        self.dst_frames = extra["frames"]
        self.dst_obj_ids = obj_ids
        self.dst_unique_obj_ids = unique_obj_ids
        self.ff = utils.FastFinder(self.dst_frames)
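
All of the examples in this listing revolve around utils.FastFinder: it indexes a 1-D integer array (here, the destination file's frame numbers) so that repeated equality lookups are cheap. The snippet below is a minimal stand-in written against plain numpy to illustrate the lookup semantics the examples rely on; it is an assumption-based sketch, not the flydra implementation.

import numpy as np

class NaiveFastFinder:
    """Illustration-only stand-in for utils.FastFinder.

    Assumes the semantics used throughout these examples: the constructor
    indexes a 1-D array, and get_idxs_of_equal(value) returns the positions
    where that array equals value.
    """

    def __init__(self, values):
        self.values = np.asarray(values)

    def get_idxs_of_equal(self, value):
        # the real class presumably pre-sorts and binary-searches;
        # this naive version just scans the whole array
        return np.nonzero(self.values == value)[0]

# mirror the __init__ above: index the destination frames once, then ask
# which rows belong to a given frame number
dst_frames = np.array([100, 100, 101, 102, 102, 102])
ff = NaiveFastFinder(dst_frames)
print(ff.get_idxs_of_equal(102))  # -> [3 4 5]
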
Example #2
def convert(
        infilename,
        outfilename,
        frames_per_second=None,
        save_timestamps=True,
        file_time_data=None,
        do_nothing=False,  # set to True to test for file existence
        start_obj_id=None,
        stop_obj_id=None,
        obj_only=None,
        dynamic_model_name=None,
        hdf5=False,
        show_progress=False,
        show_progress_json=False,
        **kwargs):
    if start_obj_id is None:
        start_obj_id = -numpy.inf
    if stop_obj_id is None:
        stop_obj_id = numpy.inf

    smoothed_data_filename = os.path.split(infilename)[1]
    raw_data_filename = smoothed_data_filename

    ca = core_analysis.get_global_CachingAnalyzer()
    with ca.kalman_analysis_context(infilename,
                                    data2d_fname=file_time_data) as h5_context:

        extra_vars = {}
        tzname = None

        if save_timestamps:
            print 'STAGE 1: finding timestamps'
            table_kobs = h5_context.get_pytable_node('ML_estimates')

            tzname = h5_context.get_tzname0()
            fps = h5_context.get_fps()

            try:
                table_data2d = h5_context.get_pytable_node('data2d_distorted',
                                                           from_2d_file=True)
            except tables.exceptions.NoSuchNodeError, err:
                print >> sys.stderr, "No timestamps in file. Either specify not to save timestamps ('--no-timestamps') or specify the original .h5 file with the timestamps ('--time-data=FILE2D')"
                sys.exit(1)

            print 'caching raw 2D data...',
            sys.stdout.flush()
            table_data2d_frames = table_data2d.read(field='frame')
            assert numpy.max(table_data2d_frames) < 2**63
            table_data2d_frames = table_data2d_frames.astype(numpy.int64)
            #table_data2d_frames_find = fastsearch.binarysearch.BinarySearcher( table_data2d_frames )
            table_data2d_frames_find = utils.FastFinder(table_data2d_frames)
            table_data2d_camns = table_data2d.read(field='camn')
            table_data2d_timestamps = table_data2d.read(field='timestamp')
            print 'done'
            print '(cached index of %d frame values of dtype %s)' % (
                len(table_data2d_frames), str(table_data2d_frames.dtype))

            drift_estimates = h5_context.get_drift_estimates()
            camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

            gain = {}
            offset = {}
            print 'hostname time_gain time_offset'
            print '-------- --------- -----------'
            for i, hostname in enumerate(drift_estimates.get('hostnames', [])):
                tgain, toffset = result_utils.model_remote_to_local(
                    drift_estimates['remote_timestamp'][hostname][::10],
                    drift_estimates['local_timestamp'][hostname][::10])
                gain[hostname] = tgain
                offset[hostname] = toffset
                print '  ', repr(hostname), tgain, toffset
            print

            if do_nothing:
                return

            print 'caching Kalman obj_ids...'
            obs_obj_ids = table_kobs.read(field='obj_id')
            fast_obs_obj_ids = utils.FastFinder(obs_obj_ids)
            print 'finding unique obj_ids...'
            unique_obj_ids = numpy.unique(obs_obj_ids)
            print '(found %d)' % (len(unique_obj_ids), )
            unique_obj_ids = unique_obj_ids[unique_obj_ids >= start_obj_id]
            unique_obj_ids = unique_obj_ids[unique_obj_ids <= stop_obj_id]

            if obj_only is not None:
                unique_obj_ids = numpy.array(
                    [oid for oid in unique_obj_ids if oid in obj_only])
                print 'filtered to obj_only', obj_only

            print '(will export %d)' % (len(unique_obj_ids), )
            print 'finding 2d data for each obj_id...'
            timestamp_time = numpy.zeros(unique_obj_ids.shape,
                                         dtype=numpy.float64)
            table_kobs_frame = table_kobs.read(field='frame')
            if len(table_kobs_frame) == 0:
                raise ValueError('no 3D data, cannot convert')
            assert numpy.max(table_kobs_frame) < 2**63
            table_kobs_frame = table_kobs_frame.astype(numpy.int64)
            assert table_kobs_frame.dtype == table_data2d_frames.dtype  # otherwise very slow

            all_idxs = fast_obs_obj_ids.get_idx_of_equal(unique_obj_ids)
            for obj_id_enum, obj_id in enumerate(unique_obj_ids):
                idx0 = all_idxs[obj_id_enum]
                framenumber = table_kobs_frame[idx0]
                remote_timestamp = numpy.nan

                this_camn = None
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(
                    framenumber)
                if len(frame_idxs):
                    frame_idx = frame_idxs[0]
                    this_camn = table_data2d_camns[frame_idx]
                    remote_timestamp = table_data2d_timestamps[frame_idx]

                if this_camn is None:
                    print 'skipping frame %d (obj %d): no data2d_distorted data' % (
                        framenumber, obj_id)
                    continue

                cam_id = camn2cam_id[this_camn]
                try:
                    remote_hostname = cam_id2hostname(cam_id, h5_context)
                except ValueError, e:
                    print 'error getting hostname of cam: %s' % e.message
                    continue
                if remote_hostname not in gain:
                    warnings.warn('no host %s in timestamp data. making up '
                                  'data.' % remote_hostname)
                    gain[remote_hostname] = 1.0
                    offset[remote_hostname] = 0.0
                mainbrain_timestamp = remote_timestamp * gain[
                    remote_hostname] + offset[
                        remote_hostname]  # find mainbrain timestamp

                timestamp_time[obj_id_enum] = mainbrain_timestamp

            extra_vars['obj_ids'] = unique_obj_ids
            extra_vars['timestamps'] = timestamp_time

            print 'STAGE 2: running Kalman smoothing operation'
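
The timestamp loop above maps each camera's remote_timestamp onto the mainbrain clock with a per-host linear model, mainbrain_time = gain * remote_time + offset, where result_utils.model_remote_to_local fits gain and offset from paired timestamp samples. The following sketch illustrates that clock model with an ordinary least-squares fit in plain numpy; the sample values are made up and this is not the flydra routine itself.

import numpy as np

# made-up paired clock samples: remote camera clock vs. local mainbrain clock
remote = np.array([10.0, 11.0, 12.0, 13.0, 14.0])
local = 1.000002 * remote + 5.25  # pretend drift: tiny gain error plus an offset

# least-squares fit of local ~= gain * remote + offset
A = np.vstack([remote, np.ones_like(remote)]).T
gain, offset = np.linalg.lstsq(A, local, rcond=None)[0]

# apply the model the same way the loop above does for each 2-D observation
remote_timestamp = 12.5
mainbrain_timestamp = remote_timestamp * gain + offset
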
Example #3
    def find_equiv(self, src_obj_id, mean_distance_maximum=None):
        """find the obj_id in the dst file that corresponds to src_obj_id

        arguments
        ---------
        src_obj_id : int
            The obj_id of the object in src_h5 to find.
        mean_distance_maximum : float or None
            The maximum allowed mean distance between matched points in
            dst and src; if the best candidate exceeds this, None is
            returned.

        returns
        -------
        dst_obj_id : int or None
            The obj_id in dst_h5 that corresponds to src_obj_id, or None
            if no equivalent object could be found.
        """
        # get information from source to identify trace in dest
        src_rows = self.ca.load_data(src_obj_id,
                                     self.src_h5,
                                     use_kalman_smoothing=False)
        src_frame = src_rows["frame"]

        if len(src_frame) < 2:
            raise ValueError("Can only find equivalent obj_id if "
                             "2 or more frames present")

        src_X = np.vstack((src_rows["x"], src_rows["y"], src_rows["z"]))
        src_timestamp = src_rows["timestamp"]

        candidate_obj_id = set()
        for f in src_frame:
            idxs = self.ff.get_idxs_of_equal(f)
            for obj_id in self.dst_obj_ids[idxs]:
                candidate_obj_id.add(obj_id)

        candidate_obj_id = list(candidate_obj_id)
        ## print 'candidate_obj_id',candidate_obj_id
        error = []
        for obj_id in candidate_obj_id:
            # get array for each candidate obj_id in destination
            dst_rows = self.ca.load_data(obj_id,
                                         self.dst_h5,
                                         use_kalman_smoothing=False)
            dst_frame = dst_rows["frame"]
            dst_X = np.vstack((dst_rows["x"], dst_rows["y"], dst_rows["z"]))
            dst_ff = utils.FastFinder(dst_frame)

            # get indices into destination array for each frame of source
            dst_idxs = dst_ff.get_idx_of_equal(src_frame, missing_ok=1)

            assert len(dst_idxs) == len(src_frame)

            missing_cond = dst_idxs == -1  # these points are in source but not dest
            n_missing = np.sum(missing_cond)
            n_total = len(src_frame)
            present_cond = ~missing_cond

            final_dst_idxs = dst_idxs[present_cond]
            final_src_idxs = np.arange(len(src_frame))[present_cond]

            src_X_i = src_X[:, final_src_idxs]
            dst_X_i = dst_X[:, final_dst_idxs]

            diff = src_X_i - dst_X_i
            dist = np.sqrt(np.sum(diff**2, axis=0))
            av_dist = np.mean(dist)

            frac_missing = n_missing / float(
                n_total)  # 0 = none missing, 1 = all
            ## print 'candidate dst obj_id %d: %s dist, %s missing'%(
            ##     obj_id, av_dist, frac_missing)
            if frac_missing > 0.1:
                this_error = np.inf
            else:
                this_error = av_dist
            error.append(this_error)
        idx = np.argmin(error)
        best_error = error[idx]
        if not np.isfinite(best_error):
            return None  # could not find answer
        else:
            if (mean_distance_maximum is None) or (best_error <=
                                                   mean_distance_maximum):
                return candidate_obj_id[idx]
            else:
                return None
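
Examples #1 and #3 together form a small matcher class: the constructor indexes the destination file by frame, and find_equiv scores every destination obj_id sharing frames with the source trace by mean 3-D distance, discarding candidates with more than 10% missing frames. A hypothetical usage sketch follows; the class name ObjIdMatcher and the file paths are placeholders, not names from the source, and (judging by Example #5, where initial_file_load is given a filename) plain .h5 paths are assumed to be acceptable arguments.

# ObjIdMatcher stands for the class whose __init__ and find_equiv methods
# are shown in Examples #1 and #3; the paths are placeholders
matcher = ObjIdMatcher("source_kalmanized.h5", "dest_kalmanized.h5")

# look up the destination obj_id matching source obj_id 42, requiring the
# mean 3-D distance between the matched traces to stay below 5 mm
dst_obj_id = matcher.find_equiv(42, mean_distance_maximum=0.005)
if dst_obj_id is None:
    print("no equivalent obj_id found")
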
Example #4
def iterate_frames(h5_filename,
                   ufmf_fnames, # or fmfs
                   white_background=False,
                   max_n_frames = None,
                   start = None,
                   stop = None,
                   rgb8_if_color=False,
                   movie_cam_ids=None,
                   camn2cam_id = None,
                   ):
    """yield frame-by-frame data"""

    # First pass over .ufmf files: get intersection of timestamps
    first_ufmf_ts = -np.inf
    last_ufmf_ts = np.inf
    ufmfs = {}
    cam_ids = []
    global_data = {'width_heights': {}}
    for movie_idx,ufmf_fname in enumerate(ufmf_fnames):
        if movie_cam_ids is not None:
            cam_id = movie_cam_ids[movie_idx]
        else:
            cam_id = get_cam_id_from_ufmf_fname(ufmf_fname)
        cam_ids.append( cam_id )
        kwargs = {}
        extra = {}
        if ufmf_fname.lower().endswith('.fmf'):
            ufmf = fmf_mod.FlyMovie(ufmf_fname)
            bg_fmf_filename = os.path.splitext(ufmf_fname)[0] + '_mean.fmf'
            if os.path.exists(bg_fmf_filename):
                extra['bg_fmf'] = fmf_mod.FlyMovie(bg_fmf_filename)
                extra['bg_tss'] = extra['bg_fmf'].get_all_timestamps()
                extra['bg_fmf'].seek(0)

        else:
            ufmf = ufmf_mod.FlyMovieEmulator(ufmf_fname,
                                             white_background=white_background,
                                             **kwargs)

        global_data['width_heights'][cam_id] = ( ufmf.get_width(), ufmf.get_height() )
        tss = ufmf.get_all_timestamps()
        ufmf.seek(0)
        ufmfs[ufmf_fname] = (ufmf, cam_id, tss, extra)
        min_ts = np.min(tss)
        max_ts = np.max(tss)
        if min_ts > first_ufmf_ts:
            first_ufmf_ts = min_ts
        if max_ts < last_ufmf_ts:
            last_ufmf_ts = max_ts

    assert first_ufmf_ts < last_ufmf_ts, ".ufmf files don't all overlap in time"

    ufmf_fnames.sort()
    cam_ids.sort()

    with open_file_safe( h5_filename, mode='r' ) as h5:
        if camn2cam_id is None:
            camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)
        parsed = result_utils.read_textlog_header(h5)
        flydra_version = parsed.get('flydra_version',None)
        if flydra_version is not None and flydra_version >= '0.4.45':
            # camnode.py saved timestamps into .ufmf file given by
            # time.time() (camn_receive_timestamp). Compare with
            # mainbrain's data2d_distorted column
            # 'cam_received_timestamp'.
            old_camera_timestamp_source = False
            timestamp_name = 'cam_received_timestamp'
        else:
            # camnode.py saved timestamps into .ufmf file given by
            # camera driver. Compare with mainbrain's data2d_distorted
            # column 'timestamp'.
            old_camera_timestamp_source = True
            timestamp_name = 'timestamp'

        h5_data = h5.root.data2d_distorted[:]

    if 1:
        # narrow search to local region of .h5
        cond = ((first_ufmf_ts <= h5_data[timestamp_name]) &
                (h5_data[timestamp_name] <= last_ufmf_ts))
        narrow_h5_data = h5_data[cond]

        narrow_camns = narrow_h5_data['camn']
        narrow_timestamps = narrow_h5_data[timestamp_name]

        # Find the camn for each .ufmf file
        cam_id2camn = {}
        for cam_id in cam_ids:
            cam_id_camn_already_found = False
            for ufmf_fname in ufmfs.keys():
                (ufmf, test_cam_id, tss, extra) = ufmfs[ufmf_fname]
                if cam_id != test_cam_id:
                    continue
                assert not cam_id_camn_already_found
                cam_id_camn_already_found = True

                umin=np.min(tss)
                umax=np.max(tss)
                cond = (umin<=narrow_timestamps) & (narrow_timestamps<=umax)
                ucamns = narrow_camns[cond]
                ucamns = np.unique(ucamns)
                camns = []
                for camn in ucamns:
                    if camn2cam_id[camn]==cam_id:
                        camns.append(camn)

                assert len(camns)<2, "can't handle multiple camns per cam_id"
                if len(camns):
                    cam_id2camn[cam_id] = camns[0]

        ff = utils.FastFinder(narrow_h5_data['frame'])
        unique_frames = list(np.unique(narrow_h5_data['frame']))
        unique_frames.sort()
        unique_frames = np.array( unique_frames )
        if start is not None:
            unique_frames = unique_frames[ unique_frames >= start ]
        if stop is not None:
            unique_frames = unique_frames[ unique_frames <= stop ]

        if max_n_frames is not None:
            unique_frames = unique_frames[:max_n_frames]
        for frame_enum,frame in enumerate(unique_frames):
            narrow_idxs = ff.get_idxs_of_equal(frame)

            # trim data under consideration to just this frame
            this_h5_data = narrow_h5_data[narrow_idxs]
            this_camns = this_h5_data['camn']
            this_tss = this_h5_data[timestamp_name]

            # a couple more checks
            if np.any( this_tss < first_ufmf_ts):
                continue
            if np.any( this_tss >= last_ufmf_ts):
                break

            per_frame_dict = {}
            for ufmf_fname in ufmf_fnames:
                ufmf, cam_id, tss, extra = ufmfs[ufmf_fname]
                if cam_id not in cam_id2camn:
                    continue
                camn = cam_id2camn[cam_id]
                this_camn_cond = this_camns == camn
                this_cam_h5_data = this_h5_data[this_camn_cond]
                this_camn_tss = this_cam_h5_data[timestamp_name]
                if not len(this_camn_tss):
                    # no h5 data for this cam_id at this frame
                    continue
                this_camn_ts=np.unique(this_camn_tss)
                assert len(this_camn_ts)==1
                this_camn_ts = this_camn_ts[0]

                if isinstance(ufmf, ufmf_mod.FlyMovieEmulator):
                    is_real_ufmf = True
                else:
                    is_real_ufmf = False

                # optimistic: get next frame. it's probably the one we want
                try:
                    if is_real_ufmf:
                        image,image_ts,more  = ufmf.get_next_frame(_return_more=True)
                    else:
                        image,image_ts = ufmf.get_next_frame()
                        more = fill_more_for( extra, image_ts )
                except ufmf_mod.NoMoreFramesException:
                    image_ts = None
                if this_camn_ts != image_ts:
                    # It was not the frame we wanted. Find it.
                    ufmf_frame_idxs = np.nonzero(tss == this_camn_ts)[0]
                    if (len(ufmf_frame_idxs)==0 and
                        old_camera_timestamp_source):
                        warnings.warn(
                            'low-precision timestamp comparison in '
                            'use due to outdated .ufmf timestamp '
                            'saving.')
                        # 2.5 msec precision required
                        ufmf_frame_idxs = np.nonzero(
                            abs( tss - this_camn_ts ) < 0.0025)[0]
                    assert len(ufmf_frame_idxs)==1
                    ufmf_frame_no = ufmf_frame_idxs[0]
                    if is_real_ufmf:
                        image,image_ts,more = ufmf.get_frame(ufmf_frame_no,
                                                             _return_more=True)
                    else:
                        image,image_ts = ufmf.get_frame(ufmf_frame_no)
                        more = fill_more_for( extra, image_ts )

                    del ufmf_frame_no, ufmf_frame_idxs
                coding = ufmf.get_format()
                if imops.is_coding_color(coding):
                    if rgb8_if_color:
                        image = imops.to_rgb8(coding,image)
                    else:
                        warnings.warn('color image not converted to color')
                per_frame_dict[ufmf_fname] = {
                    'image':image,
                    'cam_id':cam_id,
                    'camn':camn,
                    'timestamp':this_cam_h5_data['timestamp'][0],
                    'cam_received_timestamp':
                    this_cam_h5_data['cam_received_timestamp'][0],
                    'ufmf_frame_timestamp':this_cam_h5_data[timestamp_name][0],
                    }
                if more is not None:
                    per_frame_dict[ufmf_fname].update(more)
            per_frame_dict['tracker_data']=this_h5_data
            per_frame_dict['global_data']=global_data # on every iteration, pass our global data
            yield (per_frame_dict,frame)
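
iterate_frames is a generator: each iteration yields (per_frame_dict, frame), where per_frame_dict maps every movie filename to its decoded image plus timestamps and also carries 'tracker_data' (the 2-D rows for that frame) and 'global_data'. A hedged consumption sketch, with placeholder filenames and cam ids:

# filenames and cam ids are placeholders, not values from the source
h5_filename = "mainbrain_data.h5"
ufmf_fnames = ["cam1.ufmf", "cam2.ufmf"]
movie_cam_ids = ["cam1_0", "cam2_0"]

for per_frame_dict, frame in iterate_frames(h5_filename,
                                             ufmf_fnames,
                                             movie_cam_ids=movie_cam_ids,
                                             rgb8_if_color=True,
                                             max_n_frames=100):
    tracker_rows = per_frame_dict['tracker_data']  # 2-D tracking rows for this frame
    for ufmf_fname in ufmf_fnames:
        if ufmf_fname not in per_frame_dict:
            continue  # this camera contributed no data at this frame
        datum = per_frame_dict[ufmf_fname]
        image = datum['image']
        cam_id = datum['cam_id']
        # ... process (frame, cam_id, image) here ...
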
Example #5
def plot_timeseries(subplot=None, options=None):
    kalman_filename = options.kalman_filename

    if not hasattr(options, 'frames'):
        options.frames = False

    if not hasattr(options, 'show_landing'):
        options.show_landing = False

    if not hasattr(options, 'unicolor'):
        options.unicolor = False

    if not hasattr(options, 'show_obj_id'):
        options.show_obj_id = True

    if not hasattr(options, 'show_track_ends'):
        options.show_track_ends = False

    start = options.start
    stop = options.stop
    obj_only = options.obj_only
    fps = options.fps
    dynamic_model = options.dynamic_model
    use_kalman_smoothing = options.use_kalman_smoothing

    if not use_kalman_smoothing:
        if (dynamic_model is not None):
            print >> sys.stderr, (
                'WARNING: disabling Kalman smoothing '
                '(--disable-kalman-smoothing) is incompatible '
                'with setting dynamic model options (--dynamic-model)')

    ca = core_analysis.get_global_CachingAnalyzer()

    if kalman_filename is None:
        raise ValueError('No kalman_filename given. Nothing to do.')

    m = hashlib.md5()
    m.update(open(kalman_filename, mode='rb').read())
    actual_md5 = m.hexdigest()
    (obj_ids, use_obj_ids, is_mat_file, data_file,
     extra) = ca.initial_file_load(kalman_filename)
    print 'opened kalman file %s %s, %d obj_ids' % (
        kalman_filename, actual_md5, len(use_obj_ids))

    if 'frames' in extra:
        if (start is not None) or (stop is not None):
            valid_frames = np.ones((len(extra['frames']), ), dtype=np.bool)
            if start is not None:
                valid_frames &= extra['frames'] >= start
            if stop is not None:
                valid_frames &= extra['frames'] <= stop
            this_use_obj_ids = np.unique(obj_ids[valid_frames])
            use_obj_ids = list(set(use_obj_ids).intersection(this_use_obj_ids))

    include_obj_ids = None
    exclude_obj_ids = None
    do_fuse = False
    if options.stim_xml:
        file_timestamp = data_file.filename[4:19]
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
            timestamp_string=file_timestamp)
        walking_start_stops = fanout.get_walking_start_stops_for_timestamp(
            timestamp_string=file_timestamp)
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
        if options.fuse:
            do_fuse = True
    else:
        walking_start_stops = []

    if dynamic_model is None:
        dynamic_model = extra['dynamic_model_name']
        print 'detected file loaded with dynamic model "%s"' % dynamic_model
        if dynamic_model.startswith('EKF '):
            dynamic_model = dynamic_model[4:]
        print '  for smoothing, will use dynamic model "%s"' % dynamic_model

    if not is_mat_file:
        mat_data = None

        if fps is None:
            fps = result_utils.get_fps(data_file, fail_on_error=False)

        if fps is None:
            fps = 100.0
            import warnings
            warnings.warn('Setting fps to default value of %f' % fps)

        tz = result_utils.get_tz(data_file)

    dt = 1.0 / fps

    all_vels = []

    if obj_only is not None:
        use_obj_ids = [i for i in use_obj_ids if i in obj_only]

    allX = {}
    frame0 = None

    line2obj_id = {}
    Xz_all = []

    fuse_did_once = False

    if not hasattr(options, 'timestamp_file'):
        options.timestamp_file = None

    if not hasattr(options, 'ori_qual'):
        options.ori_qual = None

    if options.timestamp_file is not None:
        h5 = tables.open_file(options.timestamp_file, mode='r')
        print 'reading timestamps and frames'
        table_data2d_frames = h5.root.data2d_distorted.read(field='frame')
        table_data2d_timestamps = h5.root.data2d_distorted.read(
            field='timestamp')
        print 'done'
        h5.close()
        table_data2d_frames_find = utils.FastFinder(table_data2d_frames)

    if len(use_obj_ids) == 0:
        print 'No obj_ids to plot, quitting'
        sys.exit(0)

    time0 = 0.0  # set default value

    for obj_id in use_obj_ids:
        if not do_fuse:
            try:
                kalman_rows = ca.load_data(
                    obj_id,
                    data_file,
                    use_kalman_smoothing=use_kalman_smoothing,
                    dynamic_model_name=dynamic_model,
                    return_smoothed_directions=options.smooth_orientations,
                    frames_per_second=fps,
                    up_dir=options.up_dir,
                    min_ori_quality_required=options.ori_qual,
                )
            except core_analysis.ObjectIDDataError:
                continue
            #kobs_rows = ca.load_dynamics_free_MLE_position( obj_id, data_file )
        else:
            if options.show_3d_orientations:
                raise NotImplementedError('orientation data is not supported '
                                          'when fusing obj_ids')
            if fuse_did_once:
                break
            fuse_did_once = True
            kalman_rows = flydra_analysis.a2.flypos.fuse_obj_ids(
                use_obj_ids,
                data_file,
                dynamic_model_name=dynamic_model,
                frames_per_second=fps)
        frame = kalman_rows['frame']

        if (start is not None) or (stop is not None):
            valid_cond = numpy.ones(frame.shape, dtype=numpy.bool)

            if start is not None:
                valid_cond = valid_cond & (frame >= start)

            if stop is not None:
                valid_cond = valid_cond & (frame <= stop)

            kalman_rows = kalman_rows[valid_cond]
            if not len(kalman_rows):
                continue

        walking_and_flying_kalman_rows = kalman_rows  # preserve original data

        for flystate in ['flying', 'walking']:
            frame = walking_and_flying_kalman_rows['frame']  # restore
            if flystate == 'flying':
                # assume flying unless we're told it's walking
                state_cond = numpy.ones(frame.shape, dtype=numpy.bool)
            else:
                state_cond = numpy.zeros(frame.shape, dtype=numpy.bool)

            if len(walking_start_stops):
                for walkstart, walkstop in walking_start_stops:
                    frame = walking_and_flying_kalman_rows['frame']  # restore

                    # handle each bout of walking
                    walking_bout = numpy.ones(frame.shape, dtype=numpy.bool)
                    if walkstart is not None:
                        walking_bout &= (frame >= walkstart)
                    if walkstop is not None:
                        walking_bout &= (frame <= walkstop)
                    if flystate == 'flying':
                        state_cond &= ~walking_bout
                    else:
                        state_cond |= walking_bout

                kalman_rows = np.take(walking_and_flying_kalman_rows,
                                      np.nonzero(state_cond)[0])
                assert len(kalman_rows) == np.sum(state_cond)
                frame = kalman_rows['frame']

            if frame0 is None:
                frame0 = int(frame[0])

            time0 = 0.0
            if options.timestamp_file is not None:
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(frame0)
                if len(frame_idxs):
                    time0 = table_data2d_timestamps[frame_idxs[0]]
                else:
                    raise ValueError(
                        'could not find frame %d in timestamp file' % frame0)

            Xx = kalman_rows['x']
            Xy = kalman_rows['y']
            Xz = kalman_rows['z']

            Dx = Dy = Dz = None
            if options.smooth_orientations:
                Dx = kalman_rows['dir_x']
                Dy = kalman_rows['dir_y']
                Dz = kalman_rows['dir_z']
            elif 'rawdir_x' in kalman_rows.dtype.fields:
                Dx = kalman_rows['rawdir_x']
                Dy = kalman_rows['rawdir_y']
                Dz = kalman_rows['rawdir_z']

            if not options.frames:
                f2t = Frames2Time(frame0, fps, time0)
            else:

                def identity(x):
                    return x

                f2t = identity

            kws = {
                'linewidth': 2,
                'picker': 5,
            }
            if options.unicolor:
                kws['color'] = 'k'

            line = None

            if 'frame' in subplot:
                subplot['frame'].plot(f2t(frame), frame)

            if 'P55' in subplot:
                subplot['P55'].plot(f2t(frame), kalman_rows['P55'])

            if 'x' in subplot:
                line, = subplot['x'].plot(f2t(frame),
                                          Xx,
                                          label='obj %d (%s)' %
                                          (obj_id, flystate),
                                          **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if 'y' in subplot:
                line, = subplot['y'].plot(f2t(frame),
                                          Xy,
                                          label='obj %d (%s)' %
                                          (obj_id, flystate),
                                          **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if 'z' in subplot:
                frame_data = numpy.ma.getdata(
                    frame)  # works if frame is masked or not

                # plot landing time
                if options.show_landing:
                    if flystate == 'flying':  # only do this once
                        for walkstart, walkstop in walking_start_stops:
                            if walkstart in frame_data:
                                landing_dix = numpy.nonzero(
                                    frame_data == walkstart)[0][0]
                                subplot['z'].plot([f2t(walkstart)],
                                                  [Xz.data[landing_dix]],
                                                  'rD',
                                                  ms=10,
                                                  label='landing')

                if options.show_track_ends:
                    if flystate == 'flying':  # only do this once
                        subplot['z'].plot(f2t([frame_data[0], frame_data[-1]]),
                                          [
                                              numpy.ma.getdata(Xz)[0],
                                              numpy.ma.getdata(Xz)[-1]
                                          ],
                                          'cd',
                                          ms=6,
                                          label='track end')

                line, = subplot['z'].plot(f2t(frame),
                                          Xz,
                                          label='obj %d (%s)' %
                                          (obj_id, flystate),
                                          **kws)
                kws['color'] = line.get_color()
                line2obj_id[line] = obj_id

                if flystate == 'flying':
                    # only do this once
                    if options.show_obj_id:
                        subplot['z'].text(f2t(frame_data[0]),
                                          numpy.ma.getdata(Xz)[0],
                                          '%d' % (obj_id, ))
                        line2obj_id[line] = obj_id

            if flystate == 'flying':
                Xz_all.append(np.ma.array(Xz).compressed())
                #bins = np.linspace(0,.8,30)
                #print 'Xz.shape',Xz.shape
                #pylab.hist(Xz, bins=bins)

            for (dir_var, Dd) in [('dx', Dx), ('dy', Dy), ('dz', Dz)]:
                if dir_var in subplot:
                    line, = subplot[dir_var].plot(f2t(frame),
                                                  Dd,
                                                  label='obj %d (%s)' %
                                                  (obj_id, flystate),
                                                  **kws)
                    line2obj_id[line] = obj_id
                    kws['color'] = line.get_color()

            if numpy.__version__ >= '1.2.0':
                X = numpy.ma.array((Xx, Xy, Xz))
            else:
                # See http://scipy.org/scipy/numpy/ticket/820
                X = numpy.ma.vstack(
                    (Xx[numpy.newaxis, :], Xy[numpy.newaxis, :],
                     Xz[numpy.newaxis, :]))

            dist_central_diff = (X[:, 2:] - X[:, :-2])
            vel_central_diff = dist_central_diff / (2 * dt)

            vel2mag = numpy.ma.sqrt(numpy.ma.sum(vel_central_diff**2, axis=0))
            xy_vel2mag = numpy.ma.sqrt(
                numpy.ma.sum(vel_central_diff[:2, :]**2, axis=0))

            frames2 = frame[1:-1]

            accel4mag = (vel2mag[2:] - vel2mag[:-2]) / (2 * dt)
            frames4 = frames2[1:-1]

            if 'vel' in subplot:
                line, = subplot['vel'].plot(f2t(frames2),
                                            vel2mag,
                                            label='obj %d (%s)' %
                                            (obj_id, flystate),
                                            **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if 'xy_vel' in subplot:
                line, = subplot['xy_vel'].plot(f2t(frames2),
                                               xy_vel2mag,
                                               label='obj %d (%s)' %
                                               (obj_id, flystate),
                                               **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if len(accel4mag.compressed()) and 'accel' in subplot:
                line, = subplot['accel'].plot(f2t(frames4),
                                              accel4mag,
                                              label='obj %d (%s)' %
                                              (obj_id, flystate),
                                              **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if flystate == 'flying':
                valid_vel2mag = vel2mag.compressed()
                all_vels.append(valid_vel2mag)
    if len(all_vels):
        all_vels = numpy.hstack(all_vels)
    else:
        all_vels = numpy.array([], dtype=float)

    if 1:
        cond = all_vels < 2.0
        if numpy.ma.sum(cond) != len(all_vels):
            all_vels = all_vels[cond]
            import warnings
            warnings.warn('clipping all velocities > 2.0 m/s')

    if not options.frames:
        xlabel = 'time (s)'
    else:
        xlabel = 'frame'

    for ax in subplot.itervalues():
        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%s"))

    fixup_ax = FixupAxesWithTimeZone(tz).fixup_ax

    if 'frame' in subplot:
        if time0 != 0.0:
            fixup_ax(subplot['frame'])
        else:
            subplot['frame'].set_xlabel(xlabel)

    if 'x' in subplot:
        subplot['x'].set_ylim([-1, 1])
        subplot['x'].set_ylabel(r'x (m)')
        if time0 != 0.0:
            fixup_ax(subplot['x'])
        else:
            subplot['x'].set_xlabel(xlabel)

    if 'y' in subplot:
        subplot['y'].set_ylim([-0.5, 1.5])
        subplot['y'].set_ylabel(r'y (m)')
        if time0 != 0.0:
            fixup_ax(subplot['y'])
        else:
            subplot['y'].set_xlabel(xlabel)

    max_z = None
    if options.stim_xml:
        file_timestamp = options.kalman_filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp)
        post_max_zs = []
        for post_num, post in enumerate(stim_xml.iterate_posts()):
            post_max_zs.append(max(post['verts'][0][2],
                                   post['verts'][1][2]))  # max post height
        if len(post_max_zs):
            max_z = min(post_max_zs)  # take shortest of posts

    if 'z' in subplot:
        subplot['z'].set_ylim([0, 1])
        subplot['z'].set_ylabel(r'z (m)')
        if max_z is not None:
            subplot['z'].axhline(max_z, color='m')
        if time0 != 0.0:
            fixup_ax(subplot['z'])
        else:
            subplot['z'].set_xlabel(xlabel)

    for dir_var in ['dx', 'dy', 'dz']:
        if dir_var in subplot:
            subplot[dir_var].set_ylabel(dir_var)
            if time0 != 0.0:
                fixup_ax(subplot[dir_var])
            else:
                subplot[dir_var].set_xlabel(xlabel)

    if 'z_hist' in subplot:  # and flystate=='flying':
        Xz_all = np.hstack(Xz_all)
        bins = np.linspace(0, .8, 30)
        ax = subplot['z_hist']
        ax.hist(Xz_all, bins=bins, orientation='horizontal')
        ax.set_xticks([])
        ax.set_yticks([])
        xlim = tuple(ax.get_xlim())  # matplotlib 0.98.3 returned np.array view
        ax.set_xlim((xlim[1], xlim[0]))
        if max_z is not None:
            ax.axhline(max_z, color='m')

    if 'vel' in subplot:
        subplot['vel'].set_ylim([0, 2])
        subplot['vel'].set_ylabel(r'vel (m/s)')
        subplot['vel'].set_xlabel(xlabel)
        if time0 != 0.0:
            fixup_ax(subplot['vel'])
        else:
            subplot['vel'].set_xlabel(xlabel)

    if 'xy_vel' in subplot:
        #subplot['xy_vel'].set_ylim([0,2])
        subplot['xy_vel'].set_ylabel(r'horiz vel (m/s)')
        subplot['xy_vel'].set_xlabel(xlabel)
        if time0 != 0.0:
            fixup_ax(subplot['xy_vel'])
        else:
            subplot['xy_vel'].set_xlabel(xlabel)

    if 'accel' in subplot:
        subplot['accel'].set_ylabel(r'acceleration (m/(s^2))')
        subplot['accel'].set_xlabel(xlabel)
        if time0 != 0.0:
            fixup_ax(subplot['accel'])
        else:
            subplot['accel'].set_xlabel(xlabel)

    if 'vel_hist' in subplot:
        ax = subplot['vel_hist']
        bins = numpy.linspace(0, 2, 50)
        ax.set_title('excluding walking')
        pdf, bins, patches = ax.hist(all_vels, bins=bins, normed=True)
        ax.set_xlim(0, 2)
        ax.set_ylabel('probability density')
        ax.set_xlabel('velocity (m/s)')

    return line2obj_id
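
plot_timeseries is driven by two arguments: subplot, a dict mapping panel names ('frame', 'x', 'y', 'z', 'dx', 'dy', 'dz', 'vel', 'xy_vel', 'accel', 'z_hist', 'vel_hist', 'P55') to matplotlib axes, and options, an object carrying the command-line style settings read at the top of the function. The sketch below wires up a minimal call; the attribute defaults are assumptions chosen only to satisfy the attributes the function reads, and the kalman filename is a placeholder.

import argparse
import matplotlib.pyplot as plt

# only the attributes plot_timeseries reads; all defaults are assumptions
options = argparse.Namespace(
    kalman_filename="DATA20230101_120000_kalmanized.h5",  # placeholder path
    start=None, stop=None, obj_only=None,
    fps=None, dynamic_model=None,
    use_kalman_smoothing=True,
    smooth_orientations=False,
    up_dir=None, stim_xml=None, fuse=False,
    show_3d_orientations=False,
)

# panel name -> matplotlib Axes; any subset of the recognized names works
fig, (ax_x, ax_y, ax_z, ax_vel) = plt.subplots(4, 1, sharex=True)
subplot = {'x': ax_x, 'y': ax_y, 'z': ax_z, 'vel': ax_vel}

line2obj_id = plot_timeseries(subplot=subplot, options=options)
plt.show()
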
Example #6
def convert(
    infilename,
    outfilename,
    frames_per_second=None,
    save_timestamps=True,
    file_time_data=None,
    do_nothing=False,  # set to True to test for file existence
    start_obj_id=None,
    stop_obj_id=None,
    obj_only=None,
    dynamic_model_name=None,
    hdf5=False,
    show_progress=False,
    show_progress_json=False,
    **kwargs
):
    if start_obj_id is None:
        start_obj_id = -numpy.inf
    if stop_obj_id is None:
        stop_obj_id = numpy.inf

    smoothed_data_filename = os.path.split(infilename)[1]
    raw_data_filename = smoothed_data_filename

    ca = core_analysis.get_global_CachingAnalyzer()
    with ca.kalman_analysis_context(
        infilename, data2d_fname=file_time_data
    ) as h5_context:

        extra_vars = {}
        tzname = None

        if save_timestamps:
            print("STAGE 1: finding timestamps")
            table_kobs = h5_context.get_pytable_node("ML_estimates")

            tzname = h5_context.get_tzname0()
            fps = h5_context.get_fps()

            try:
                table_data2d = h5_context.get_pytable_node(
                    "data2d_distorted", from_2d_file=True
                )
            except tables.exceptions.NoSuchNodeError as err:
                print(
                    "No timestamps in file. Either specify not to save timestamps ('--no-timestamps') or specify the original .h5 file with the timestamps ('--time-data=FILE2D')",
                    file=sys.stderr,
                )
                sys.exit(1)

            print("caching raw 2D data...", end=" ")
            sys.stdout.flush()
            table_data2d_frames = table_data2d.read(field="frame")
            assert numpy.max(table_data2d_frames) < 2 ** 63
            table_data2d_frames = table_data2d_frames.astype(numpy.int64)
            # table_data2d_frames_find = fastsearch.binarysearch.BinarySearcher( table_data2d_frames )
            table_data2d_frames_find = utils.FastFinder(table_data2d_frames)
            table_data2d_camns = table_data2d.read(field="camn")
            table_data2d_timestamps = table_data2d.read(field="timestamp")
            print("done")
            print(
                "(cached index of %d frame values of dtype %s)"
                % (len(table_data2d_frames), str(table_data2d_frames.dtype))
            )

            drift_estimates = h5_context.get_drift_estimates()
            camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

            gain = {}
            offset = {}
            print("hostname time_gain time_offset")
            print("-------- --------- -----------")
            for i, hostname in enumerate(drift_estimates.get("hostnames", [])):
                tgain, toffset = result_utils.model_remote_to_local(
                    drift_estimates["remote_timestamp"][hostname][::10],
                    drift_estimates["local_timestamp"][hostname][::10],
                )
                gain[hostname] = tgain
                offset[hostname] = toffset
                print("  ", repr(hostname), tgain, toffset)
            print()

            if do_nothing:
                return

            print("caching Kalman obj_ids...")
            obs_obj_ids = table_kobs.read(field="obj_id")
            fast_obs_obj_ids = utils.FastFinder(obs_obj_ids)
            print("finding unique obj_ids...")
            unique_obj_ids = numpy.unique(obs_obj_ids)
            print("(found %d)" % (len(unique_obj_ids),))
            unique_obj_ids = unique_obj_ids[unique_obj_ids >= start_obj_id]
            unique_obj_ids = unique_obj_ids[unique_obj_ids <= stop_obj_id]

            if obj_only is not None:
                unique_obj_ids = numpy.array(
                    [oid for oid in unique_obj_ids if oid in obj_only]
                )
                print("filtered to obj_only", obj_only)

            print("(will export %d)" % (len(unique_obj_ids),))
            print("finding 2d data for each obj_id...")
            timestamp_time = numpy.zeros(unique_obj_ids.shape, dtype=numpy.float64)
            table_kobs_frame = table_kobs.read(field="frame")
            if len(table_kobs_frame) == 0:
                raise ValueError("no 3D data, cannot convert")
            assert numpy.max(table_kobs_frame) < 2 ** 63
            table_kobs_frame = table_kobs_frame.astype(numpy.int64)
            assert (
                table_kobs_frame.dtype == table_data2d_frames.dtype
            )  # otherwise very slow

            all_idxs = fast_obs_obj_ids.get_idx_of_equal(unique_obj_ids)
            for obj_id_enum, obj_id in enumerate(unique_obj_ids):
                idx0 = all_idxs[obj_id_enum]
                framenumber = table_kobs_frame[idx0]
                remote_timestamp = numpy.nan

                this_camn = None
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(framenumber)
                if len(frame_idxs):
                    frame_idx = frame_idxs[0]
                    this_camn = table_data2d_camns[frame_idx]
                    remote_timestamp = table_data2d_timestamps[frame_idx]

                if this_camn is None:
                    print(
                        "skipping frame %d (obj %d): no data2d_distorted data"
                        % (framenumber, obj_id)
                    )
                    continue

                cam_id = camn2cam_id[this_camn]
                try:
                    remote_hostname = cam_id2hostname(cam_id, h5_context)
                except ValueError as e:
                    print("error getting hostname of cam: %s" % e.message)
                    continue
                if remote_hostname not in gain:
                    warnings.warn(
                        "no host %s in timestamp data. making up "
                        "data." % remote_hostname
                    )
                    gain[remote_hostname] = 1.0
                    offset[remote_hostname] = 0.0
                mainbrain_timestamp = (
                    remote_timestamp * gain[remote_hostname] + offset[remote_hostname]
                )  # find mainbrain timestamp

                timestamp_time[obj_id_enum] = mainbrain_timestamp

            extra_vars["obj_ids"] = unique_obj_ids
            extra_vars["timestamps"] = timestamp_time

            print("STAGE 2: running Kalman smoothing operation")

        # also save the experiment data if present
        uuid = None
        try:
            table_experiment = h5_context.get_pytable_node(
                "experiment_info", from_2d_file=True
            )
        except tables.exceptions.NoSuchNodeError:
            pass
        else:
            try:
                uuid = table_experiment.read(field="uuid")
            except (KeyError, tables.exceptions.HDF5ExtError):
                pass
            else:
                extra_vars["experiment_uuid"] = uuid

        recording_header = h5_context.read_textlog_header_2d()
        recording_flydra_version = recording_header["flydra_version"]

        # -----------------------------------------------

        obj_ids = h5_context.get_unique_obj_ids()
        smoothing_flydra_version = h5_context.get_extra_info()["header"][
            "flydra_version"
        ]

        obj_ids = obj_ids[obj_ids >= start_obj_id]
        obj_ids = obj_ids[obj_ids <= stop_obj_id]

        if obj_only is not None:
            obj_ids = numpy.array(obj_only)
            print("filtered to obj_only", obj_ids)

        if frames_per_second is None:
            frames_per_second = h5_context.get_fps()

        if dynamic_model_name is None:
            extra = h5_context.get_extra_info()
            orig_dynamic_model_name = extra.get("dynamic_model_name", None)
            dynamic_model_name = orig_dynamic_model_name
            if dynamic_model_name is None:
                dynamic_model_name = dynamic_models.DEFAULT_MODEL
                warnings.warn(
                    'no dynamic model specified, using "%s"' % dynamic_model_name
                )
            else:
                print(
                    'detected file loaded with dynamic model "%s"' % dynamic_model_name
                )
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]
            print('  for smoothing, will use dynamic model "%s"' % dynamic_model_name)

        allrows = []
        allqualrows = []
        failed_quality = False

        if show_progress:
            import progressbar

            class StringWidget(progressbar.Widget):
                def set_string(self, ts):
                    self.ts = ts

                def update(self, pbar):
                    if hasattr(self, "ts"):
                        return self.ts
                    else:
                        return ""

            string_widget = StringWidget()
            objs_per_sec_widget = progressbar.FileTransferSpeed(unit="obj_ids ")
            widgets = [
                string_widget,
                objs_per_sec_widget,
                progressbar.Percentage(),
                progressbar.Bar(),
                progressbar.ETA(),
            ]
            pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(obj_ids)).start()

        for i, obj_id in enumerate(obj_ids):
            if obj_id > stop_obj_id:
                break
            if show_progress:
                string_widget.set_string("[obj_id: % 5d]" % obj_id)
                pbar.update(i)
            if show_progress_json and i % 100 == 0:
                rough_percent_done = float(i) / len(obj_ids) * 100.0
                result_utils.do_json_progress(rough_percent_done)
            try:
                rows = h5_context.load_data(
                    obj_id,
                    dynamic_model_name=dynamic_model_name,
                    frames_per_second=frames_per_second,
                    **kwargs
                )
            except core_analysis.DiscontiguousFramesError:
                warnings.warn(
                    "discontiguous frames smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue
            except core_analysis.NotEnoughDataToSmoothError:
                # warnings.warn('not enough data to smooth obj_id %d, skipping.'%(obj_id,))
                continue
            except numpy.linalg.LinAlgError:
                warnings.warn(
                    "linear algebra error smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue
            except core_analysis.CouldNotCalculateOrientationError:
                warnings.warn(
                    "orientation error smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue

            allrows.append(rows)
            try:
                qualrows = compute_ori_quality(
                    h5_context, rows["frame"], obj_id, smooth_len=0
                )
                allqualrows.append(qualrows)
            except ValueError:
                failed_quality = True
        if show_progress:
            pbar.finish()

        allrows = numpy.concatenate(allrows)
        if not failed_quality:
            allqualrows = numpy.concatenate(allqualrows)
        else:
            allqualrows = None
        recarray = numpy.rec.array(allrows)

        smoothed_source = "kalman_estimates"

        flydra_analysis.analysis.flydra_analysis_convert_to_mat.do_it(
            rows=recarray,
            ignore_observations=True,
            newfilename=outfilename,
            extra_vars=extra_vars,
            orientation_quality=allqualrows,
            hdf5=hdf5,
            tzname=tzname,
            fps=fps,
            smoothed_source=smoothed_source,
            smoothed_data_filename=smoothed_data_filename,
            raw_data_filename=raw_data_filename,
            dynamic_model_name=orig_dynamic_model_name,
            recording_flydra_version=recording_flydra_version,
            smoothing_flydra_version=smoothing_flydra_version,
        )
        if show_progress_json:
            result_utils.do_json_progress(100)
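
Examples #2 and #6 are the Python 2 and Python 3 forms of the same convert() helper, which Kalman-smooths every trajectory in a kalmanized .h5 file and exports the result to a .mat (or hdf5) file. A hedged call sketch, using placeholder filenames and only keywords present in the signature above:

# filenames are placeholders; keyword names come from the signature above
convert(
    "DATA20230101_120000_kalmanized.h5",   # infilename: kalmanized tracking data
    "DATA20230101_120000_smoothed.mat",    # outfilename: smoothed export
    file_time_data="DATA20230101_120000.h5",  # original 2-D file with timestamps
    dynamic_model_name=None,               # autodetect from the input file
    obj_only=None,                         # export every obj_id
    hdf5=False,                            # write a .mat file rather than hdf5
    show_progress=True,
)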