示例#1
0
def kalmanize(
    src_filename,
    do_full_kalmanization=True,
    dest_filename=None,
    reconstructor=None,
    reconstructor_filename=None,
    start_frame=None,
    stop_frame=None,
    exclude_cam_ids=None,
    exclude_camns=None,
    dynamic_model_name=None,
    debug=False,
    frames_per_second=None,
    area_threshold=0,
    min_observations_to_save=0,
    options=None,
):
    """Run data association and tracking on the 2D data in ``src_filename``.

    The ``data2d_distorted`` table of the source file is scanned frame by
    frame. For every frame the spread of the per-camera timestamps is
    computed and compared against the synchronization threshold.

    With ``do_full_kalmanization=True`` each well-synchronized frame is
    handed to ``process_frame`` together with a ``Tracker``; finished
    trajectories are written to ``dest_filename`` through a ``KalmanSaver``.

    With ``do_full_kalmanization=False`` no tracking is done. Only the
    timestamp spreads are collected and saved to ``options.dest_file`` (or
    ``src_filename + '.spreadh5'``); the temporary destination file is
    removed again.

    Parameters
    ----------
    src_filename : path of the file holding the 2D data.
    do_full_kalmanization : track (True) or only check synchronization.
    dest_filename : output file; defaults to ``<src>.kalmanized.h5`` in
        full mode. Ignored (a temporary name is used) otherwise.
    reconstructor : ready-made ``Reconstructor``; mutually exclusive with
        ``reconstructor_filename``.
    reconstructor_filename : calibration source to build a reconstructor
        from; names ending in ``h5`` are opened with PyTables first.
    start_frame, stop_frame : optional inclusive frame range to process.
    exclude_cam_ids, exclude_camns : cameras whose rows are ignored.
    dynamic_model_name : name passed to ``dynamic_models.get_kalman_model``;
        if None it is taken from the source file's ``kalman_estimates``
        attributes when present, else ``'EKF mamarama, units: mm'``.
    debug : verbosity; values above 5 print per-frame details.
    frames_per_second : frame rate; read from the file when None.
    area_threshold : forwarded to ``Tracker``.
    min_observations_to_save : forwarded to ``KalmanSaver``.
    options : parsed option object; defaults come from ``get_parser()``.
        Attributes read here: ``force_minimum_eccentricity``,
        ``sync_error_threshold_msec``, ``fake_timestamp``,
        ``area_threshold_for_orientation``, ``disable_image_stat_gating``,
        ``orientation_consensus``, ``dest_file`` and ``keep_sync_errors``.

    Raises
    ------
    ValueError
        If the destination already exists (full mode), if the minimum
        eccentricity could not be forced, or if synchronization errors
        were found in full mode and ``options.keep_sync_errors`` is false
        (the result file is moved to a temporary location first).
    RuntimeError
        If frame numbers within a chunk are not increasing.
    SystemExit
        With status 1 in sync-check mode when synchronization errors were
        found and ``options.keep_sync_errors`` is false.

    Returns early (None) when the source contains no 2D data.
    """
    # NOTE: print calls below take a single pre-formatted string so they
    # produce the same output under Python 2 (print statement applied to a
    # parenthesized string) and Python 3 (print function).
    if options is None:
        # get default options
        parser = get_parser()
        (options, args) = parser.parse_args([])

    if debug:
        numpy.set_printoptions(precision=3, linewidth=120, suppress=False)

    if exclude_cam_ids is None:
        exclude_cam_ids = []

    if exclude_camns is None:
        exclude_camns = []

    use_existing_filename = True

    if reconstructor is not None:
        assert isinstance(reconstructor, flydra_core.reconstruct.Reconstructor)
        assert reconstructor_filename is None

    with open_file_safe(src_filename, mode='r') as results:
        camn2cam_id, cam_id2camns = get_caminfo_dicts(results)

        if do_full_kalmanization:
            if dynamic_model_name is None:
                if hasattr(results.root, 'kalman_estimates'):
                    if hasattr(results.root.kalman_estimates.attrs,
                               'dynamic_model_name'):
                        dynamic_model_name = (results.root.kalman_estimates.
                                              attrs.dynamic_model_name)
                        warnings.warn('dynamic model not specified. '
                                      'using "%s"' % dynamic_model_name)
            if dynamic_model_name is None:
                dynamic_model_name = 'EKF mamarama, units: mm'
                warnings.warn('dynamic model not specified. '
                              'using "%s"' % dynamic_model_name)
            else:
                print('using dynamic model "%s"' % dynamic_model_name)

            if reconstructor_filename is not None:
                if reconstructor_filename.endswith('h5'):
                    with PT.open_file(reconstructor_filename, mode='r') as fd:
                        reconstructor = flydra_core.reconstruct.Reconstructor(
                            fd,
                            minimum_eccentricity=options.
                            force_minimum_eccentricity)
                else:
                    reconstructor = flydra_core.reconstruct.Reconstructor(
                        reconstructor_filename,
                        minimum_eccentricity=options.force_minimum_eccentricity
                    )
            else:
                # reconstructor_filename is None
                if reconstructor is None:
                    reconstructor = flydra_core.reconstruct.Reconstructor(
                        results,
                        minimum_eccentricity=options.force_minimum_eccentricity
                    )

            if options.force_minimum_eccentricity is not None:
                if (reconstructor.minimum_eccentricity !=
                        options.force_minimum_eccentricity):
                    raise ValueError('could not force minimum_eccentricity')

            if dest_filename is None:
                dest_filename = os.path.splitext(
                    results.filename)[0] + '.kalmanized.h5'
        else:
            use_existing_filename = False
            dest_filename = tempfile.mktemp(suffix='.h5')

        # The reconstructor may legitimately be None here (sync-check mode
        # without an explicit reconstructor), so it must not be dereferenced
        # unconditionally.
        save_reconstructor_filename = None
        if reconstructor is not None:
            if reconstructor.cal_source_type == 'pytables':
                save_reconstructor_filename = reconstructor.cal_source.filename
            else:
                warnings.warn('unable to determine reconstructor source '
                              'filename for %r' %
                              reconstructor.cal_source_type)

        if frames_per_second is None:
            frames_per_second = get_fps(results)
            if do_full_kalmanization:
                print('read frames_per_second from file %s' %
                      (frames_per_second, ))

        dt = 1.0 / frames_per_second

        if options.sync_error_threshold_msec is None:
            # default is IFI/2
            sync_error_threshold = (0.5 * dt)
        else:
            sync_error_threshold = options.sync_error_threshold_msec / 1000.0

        if os.path.exists(dest_filename):
            if use_existing_filename:
                raise ValueError('%s already exists. Will not '
                                 'overwrite.' % dest_filename)
            else:
                os.unlink(dest_filename)

        with open_file_safe(dest_filename,
                            mode="w",
                            title="tracked Flydra data file",
                            delete_on_error=True) as h5file:

            if 'experiment_info' in results.root:
                results.root.experiment_info._f_copy(h5file.root,
                                                     recursive=True)

            if do_full_kalmanization:
                parsed = read_textlog_header(results)
                if 'trigger_CS3' not in parsed:
                    parsed['trigger_CS3'] = 'unknown'
                textlog_save_lines = [
                    'kalmanize running at %s fps, (top %s, trigger_CS3 %s, flydra_version %s)'
                    % (str(frames_per_second), str(parsed.get(
                        'top', 'unknown')), str(parsed['trigger_CS3']),
                       flydra_core.version.__version__),
                    'original file: %s' % (src_filename, ),
                    'dynamic model: %s' % (dynamic_model_name, ),
                    'reconstructor file: %s' % (save_reconstructor_filename, ),
                ]

                kalman_model = dynamic_models.get_kalman_model(
                    name=dynamic_model_name, dt=dt)

                h5saver = KalmanSaver(
                    h5file,
                    reconstructor,
                    cam_id2camns=cam_id2camns,
                    min_observations_to_save=min_observations_to_save,
                    textlog_save_lines=textlog_save_lines,
                    dynamic_model_name=dynamic_model_name,
                    dynamic_model=kalman_model,
                    debug=debug,
                    fake_timestamp=options.fake_timestamp,
                )

                tracker = Tracker(
                    reconstructor,
                    kalman_model=kalman_model,
                    save_all_data=True,
                    area_threshold=area_threshold,
                    area_threshold_for_orientation=options.
                    area_threshold_for_orientation,
                    disable_image_stat_gating=options.
                    disable_image_stat_gating,
                    orientation_consensus=options.orientation_consensus,
                    fake_timestamp=options.fake_timestamp,
                )

                # finished trajectories are handed to the saver
                tracker.set_killed_tracker_callback(h5saver.save_tro)

                # copy timestamp data into newly created kalmanized file
                if hasattr(results.root, 'trigger_clock_info'):
                    results.root.trigger_clock_info._f_copy(h5file.root)

            data2d = results.root.data2d_distorted

            frame_count = 0
            last_frame = None
            frame_data = collections.defaultdict(list)
            time_frame_all_cam_timestamps = []
            time_frame_all_camns = []

            time1 = time.time()
            if do_full_kalmanization:
                print('loading all frame numbers...')
            frames_array = numpy.asarray(data2d.read(field='frame'))
            time2 = time.time()
            if do_full_kalmanization:
                print('done in %.1f sec' % (time2 - time1))
                if (not options.disable_image_stat_gating
                        and 'cur_val' in data2d.colnames):
                    warnings.warn(
                        'No pre-filtering of data based on zero '
                        'probability -- more data association work is '
                        'being done than necessary')

            if len(frames_array) == 0:
                # no data
                print('No 2D data. Nothing to do.')
                return

            if do_full_kalmanization:
                print('2D data range: approximately %d<frame<%d' % (
                    frames_array[0], frames_array[-1]))

            if do_full_kalmanization:
                accum_frame_spread = None
            else:
                # sync-check mode: keep per-frame spread statistics
                accum_frame_spread = []
                accum_frame_spread_fno = []
                accum_frame_all_timestamps = []
                accum_frame_all_camns = []

            max_all_check_times = -np.inf

            for row_start, row_stop in utils.iter_non_overlapping_chunk_start_stops(
                    frames_array,
                    min_chunk_size=500000,
                    size_increment=1000,
                    status_fd=sys.stdout):

                print('Doing initial scan of approx frame range %d-%d.' % (
                    frames_array[row_start], frames_array[row_stop - 1]))

                # skip chunks lying entirely outside the requested range
                this_frames_array = frames_array[row_start:row_stop]
                if start_frame is not None:
                    if this_frames_array.max() < start_frame:
                        continue
                if stop_frame is not None:
                    if this_frames_array.min() > stop_frame:
                        continue

                data2d_recarray = data2d.read(start=row_start, stop=row_stop)
                this_frames = data2d_recarray['frame']
                print('Examining frames %d-%d in detail.' % (this_frames[0],
                                                             this_frames[-1]))
                this_row_idxs = np.argsort(this_frames)
                # One extra iteration past the last row flushes the final
                # frame of the chunk.
                for ii in range(len(this_row_idxs) + 1):

                    if ii >= len(this_row_idxs):
                        finish_frame = True
                    else:
                        finish_frame = False

                        this_row_idx = this_row_idxs[ii]

                        row = data2d_recarray[this_row_idx]

                        new_frame = row['frame']

                        if start_frame is not None:
                            if new_frame < start_frame:
                                continue
                        if stop_frame is not None:
                            if new_frame > stop_frame:
                                continue

                        if last_frame != new_frame:
                            # Explicit None guard: ordering comparisons
                            # against None are an error on Python 3.
                            if (last_frame is not None
                                    and new_frame < last_frame):
                                print('new_frame %s' % (new_frame, ))
                                print('last_frame %s' % (last_frame, ))
                                raise RuntimeError(
                                    "expected continuously increasing "
                                    "frame numbers")
                            finish_frame = True

                    if finish_frame:
                        # new frame
                        ########################################
                        # Data for this frame is complete
                        if last_frame is not None:

                            this_frame_spread = 0.0
                            if len(time_frame_all_cam_timestamps) > 1:
                                check_times = np.array(
                                    time_frame_all_cam_timestamps)
                                check_times -= check_times.min()
                                this_frame_spread = check_times.max()
                                if accum_frame_spread is not None:
                                    accum_frame_spread.append(
                                        this_frame_spread)
                                    accum_frame_spread_fno.append(last_frame)

                                    accum_frame_all_timestamps.append(
                                        time_frame_all_cam_timestamps)
                                    accum_frame_all_camns.append(
                                        time_frame_all_camns)

                                max_all_check_times = max(
                                    this_frame_spread, max_all_check_times)
                                if this_frame_spread > sync_error_threshold:
                                    # only report when a new maximum is hit
                                    if this_frame_spread == max_all_check_times:
                                        print('%s frame %d: sync diff: %.1f msec' % (
                                            os.path.split(
                                                results.filename)[-1],
                                            last_frame,
                                            this_frame_spread * 1000.0))

                            if debug > 5:
                                print('')
                                print('frame_data for frame %d' % (
                                    last_frame, ))
                                pprint.pprint(dict(frame_data))
                                print('')
                            if do_full_kalmanization:
                                if this_frame_spread > sync_error_threshold:
                                    if debug > 5:
                                        print(
                                            'frame sync error (spread %.1f msec), '
                                            'skipping' %
                                            (this_frame_spread * 1e3, ))
                                        print('')
                                    warnings.warn(
                                        'Synchronization error detected, '
                                        'but continuing analysis without '
                                        'potentially bad data.')
                                else:
                                    process_frame(reconstructor,
                                                  tracker,
                                                  last_frame,
                                                  frame_data,
                                                  camn2cam_id,
                                                  debug=debug)
                            frame_count += 1
                            if do_full_kalmanization and frame_count % 1000 == 0:
                                time2 = time.time()
                                dur = time2 - time1
                                fps = frame_count / dur
                                print('frame % 10d, kalmanization/data association speed: % 8.1f fps' % (
                                    last_frame, fps))
                                time1 = time2
                                frame_count = 0

                        ########################################
                        frame_data = collections.defaultdict(list)
                        time_frame_all_cam_timestamps = []  # clear values
                        time_frame_all_camns = []  # clear values
                        last_frame = new_frame

                    # NOTE(review): on the extra end-of-chunk iteration `row`
                    # still refers to the last row read, so that row is
                    # appended to the freshly cleared per-frame buffers a
                    # second time. Behavior kept as-is; confirm whether this
                    # is intended.
                    camn = row['camn']
                    try:
                        cam_id = camn2cam_id[camn]
                    except KeyError:
                        # This will happen if cameras were re-synchronized (and
                        # thus gain new cam_ids) immediately before saving was
                        # turned on in MainBrain. The reason is that the network
                        # buffers are still full of old data coming in from the
                        # cameras.
                        warnings.warn('WARNING: no cam_id for camn '
                                      '%d, skipping this row of data' % camn)
                        continue

                    if cam_id in exclude_cam_ids:
                        # exclude this camera
                        continue

                    if camn in exclude_camns:
                        # exclude this camera
                        continue

                    time_frame_all_cam_timestamps.append(row['timestamp'])
                    time_frame_all_camns.append(row['camn'])

                    if do_full_kalmanization:

                        x_distorted = row['x']
                        if numpy.isnan(x_distorted):
                            # drop point -- not found
                            continue
                        y_distorted = row['y']

                        (x_undistorted,
                         y_undistorted) = reconstructor.undistort(
                             cam_id, (x_distorted, y_distorted))

                        (area, slope, eccentricity,
                         frame_pt_idx) = (row['area'], row['slope'],
                                          row['eccentricity'],
                                          row['frame_pt_idx'])

                        # image statistics columns are optional
                        if 'cur_val' in row.dtype.fields:
                            cur_val = row['cur_val']
                        else:
                            cur_val = None
                        if 'mean_val' in row.dtype.fields:
                            mean_val = row['mean_val']
                        else:
                            mean_val = None
                        if 'sumsqf_val' in row.dtype.fields:
                            sumsqf_val = row['sumsqf_val']
                        else:
                            sumsqf_val = None

                        # FIXME: cache this stuff?
                        pmat_inv = reconstructor.get_pmat_inv(cam_id)
                        camera_center = reconstructor.get_camera_center(cam_id)
                        camera_center = numpy.hstack((camera_center[:,
                                                                    0], [1]))
                        camera_center_meters = reconstructor.get_camera_center(
                            cam_id)
                        camera_center_meters = numpy.hstack(
                            (camera_center_meters[:, 0], [1]))
                        helper = reconstructor.get_reconstruct_helper_dict(
                        )[cam_id]
                        rise = slope
                        run = 1.0
                        # represent a vertical line by a unit rise and zero
                        # run instead of an infinite slope
                        if np.isinf(rise):
                            if rise > 0:
                                rise = 1.0
                                run = 0.0
                            else:
                                rise = -1.0
                                run = 0.0

                        (p1, p2, p3, p4, ray0, ray1, ray2, ray3, ray4,
                         ray5) = do_3d_operations_on_2d_point(
                             helper, x_undistorted, y_undistorted, pmat_inv,
                             camera_center, x_distorted, y_distorted, rise,
                             run)
                        line_found = not numpy.isnan(p1)
                        pluecker_hz_meters = (ray0, ray1, ray2, ray3, ray4,
                                              ray5)

                        # Keep in sync with kalmanize.py and data_descriptions.py
                        pt_undistorted = (x_undistorted, y_undistorted, area,
                                          slope, eccentricity, p1, p2, p3, p4,
                                          line_found, frame_pt_idx, cur_val,
                                          mean_val, sumsqf_val)

                        projected_line_meters = geom.line_from_HZline(
                            pluecker_hz_meters)

                        frame_data[camn].append(
                            (pt_undistorted, projected_line_meters))

            if do_full_kalmanization:
                tracker.kill_all_trackers()  # done tracking

        if not do_full_kalmanization:
            # the destination was only a scratch file in sync-check mode
            os.unlink(dest_filename)

    if accum_frame_spread is not None:
        # save spread data to file for analysis
        accum_frame_spread = np.array(accum_frame_spread)
        accum_frame_spread_fno = np.array(accum_frame_spread_fno)
        if options.dest_file is not None:
            accum_frame_spread_filename = options.dest_file
        else:
            accum_frame_spread_filename = src_filename + '.spreadh5'

        # sorted() instead of keys().sort(): dict views have no sort()
        cam_ids = sorted(cam_id2camns.keys())
        camn_order = []
        for cam_id in cam_ids:
            camn_order.extend(cam_id2camns[cam_id])

        camn_order = np.array(camn_order)
        cam_id_array = np.array(cam_ids)

        N_cams = len(camn_order)
        N_frames = len(accum_frame_spread_fno)

        # one row per frame, one column per camn; NaN where a camera
        # contributed no timestamp
        all_timestamps = np.empty((N_frames, N_cams), dtype=float)
        all_timestamps.fill(np.nan)
        for i, (timestamps, camns) in enumerate(
                zip(accum_frame_all_timestamps, accum_frame_all_camns)):

            for j, camn in enumerate(camn_order):
                try:
                    idx = camns.index(camn)
                except ValueError:
                    continue  # not found, skip
                timestamp = timestamps[idx]
                all_timestamps[i, j] = timestamp

        # context manager guarantees the file is closed even when one of
        # the writes fails
        with tables.open_file(accum_frame_spread_filename, mode='w') as h5:
            h5.create_array(h5.root, 'spread', accum_frame_spread,
                            'frame timestamp spreads (sec)')
            h5.create_array(h5.root, 'framenumber', accum_frame_spread_fno,
                            'frame number')
            h5.create_array(h5.root, 'all_timestamps', all_timestamps,
                            'all timestamps')
            h5.create_array(h5.root, 'camn_order', camn_order, 'camn_order')
            h5.create_array(h5.root, 'cam_id_array', cam_id_array,
                            'cam_id_array')
        print('saved %s' % accum_frame_spread_filename)

    if max_all_check_times > sync_error_threshold:
        if not options.keep_sync_errors:
            if do_full_kalmanization:
                print('max_all_check_times %.2f msec' % (max_all_check_times *
                                                         1000.0))
                handle, target = tempfile.mkstemp(
                    os.path.split(dest_filename)[1])
                # mkstemp returns an open descriptor; close it so it does
                # not leak (only the unique name is needed)
                os.close(handle)
                os.unlink(target)  # remove original file there
                shutil.move(dest_filename, target)

                raise ValueError(
                    'Synchonization errors exist in the data. Moved result file'
                    ' to ensure it is not confused with valid data. The new '
                    'location is: %s' % (target, ))

            else:
                sys.exit(1)  # sync error
    else:
        if not do_full_kalmanization:
            print('%s no sync differences greater than %.1f msec' % (
                os.path.split(src_filename)[-1],
                sync_error_threshold * 1000.0,
            ))
示例#2
0
def main():
    """Command-line entry point: show 3D track data in an alignment GUI.

    Loads the trajectories of one or more flydra HDF5 files, computes a
    per-point speed by central differences, and hands the homogeneous
    vertex array, the speeds, the reconstructor and the optional
    ``--align-json`` file to ``IVTKWithCalGUI.cal_align.set_data``. The
    stimulus described by ``--stim-xml`` is drawn in the same scene, after
    which the GUI event loop is started.

    Raises
    ------
    ValueError
        If an input file is not HDF5, or if no stimulus XML was given.
    RuntimeError
        If no reconstructor can be determined, or if ``--obj-only`` /
        ``--obj-filelist`` is combined with more than one input file.
    """
    import os  # local import: only needed for the fanout timestamp lookup

    parser = argparse.ArgumentParser()

    parser.add_argument('filename', nargs='+',
                        help='name of flydra .hdf5 file',
                        )

    parser.add_argument("--stim-xml",
                        type=str,
                        default=None,
                        help="name of XML file with stimulus info",
                        required=True,
                        )

    parser.add_argument("--align-json",
                        type=str,
                        default=None,
                        help="previously exported json file containing s,R,T",
                        )

    parser.add_argument("--radius", type=float,
                      help="radius of line (in meters)",
                      default=0.002,
                      metavar="RADIUS")

    parser.add_argument("--obj-only", type=str)

    parser.add_argument("--obj-filelist", type=str,
                      help="use object ids from list in text file",
                      )

    parser.add_argument(
        "-r", "--reconstructor", dest="reconstructor_path",
        type=str,
        help=("calibration/reconstructor path (if not specified, "
              "defaults to FILE)"))

    args = parser.parse_args()
    options = args # optparse OptionParser backwards compatibility

    reconstructor_path = args.reconstructor_path
    fps = None

    ca = core_analysis.get_global_CachingAnalyzer()
    by_file = {}

    for h5_filename in args.filename:
        # explicit check instead of assert, which is stripped under -O
        if not tables.is_hdf5_file(h5_filename):
            raise ValueError('%s is not an HDF5 file' % (h5_filename,))
        obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
            h5_filename)
        this_fps = result_utils.get_fps( data_file, fail_on_error=False )
        # the first file that knows its frame rate determines fps
        if fps is None:
            if this_fps is not None:
                fps = this_fps
        # without -r, the first data file doubles as calibration source
        if reconstructor_path is None:
            reconstructor_path = data_file
        by_file[h5_filename] = (use_obj_ids, data_file)
    del h5_filename
    del obj_ids, use_obj_ids, is_mat_file, data_file, extra

    if options.obj_only is not None:
        obj_only = core_analysis.parse_seq(options.obj_only)
    else:
        obj_only = None

    if reconstructor_path is None:
        raise RuntimeError('must specify reconstructor from CLI if not using .h5 files')

    R = reconstruct.Reconstructor(reconstructor_path)

    # Fall back to a default only when no file provided a frame rate. The
    # previous code overwrote a successfully read fps with 1.0, which made
    # the lookup above pointless and left speeds in per-frame units.
    if fps is None:
        fps = 100.0
        warnings.warn('Setting fps to default value of %f'%fps)

    if options.stim_xml is None:
        raise ValueError(
            'stim_xml must be specified (how else will you align the data?')

    stim_xml = xml_stimulus.xml_stimulus_from_filename(
        options.stim_xml,
        )
    try:
        fanout = xml_stimulus.xml_fanout_from_filename( options.stim_xml )
    except xml_stimulus.WrongXMLTypeError:
        pass
    else:
        # Apply the fanout's include/exclude lists to every loaded file.
        # The previous code referenced an undefined `file_timestamp` and an
        # already deleted `use_obj_ids`, and never stored the result.
        for fanout_filename in list(by_file):
            (file_obj_ids, file_data) = by_file[fanout_filename]
            # NOTE(review): assumes the flydra naming scheme
            # 'DATAyyyymmdd_HHMMSS...' so that characters 4..18 of the
            # base name are the timestamp -- confirm against the fanout
            # XML contents.
            file_timestamp = os.path.basename(fanout_filename)[4:19]
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp )
            if include_obj_ids is not None:
                file_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                file_obj_ids = list( set(file_obj_ids).difference(
                    exclude_obj_ids ) )
            by_file[fanout_filename] = (file_obj_ids, file_data)
        print('using object ids specified in fanout .xml file')
    if stim_xml.has_reconstructor():
        stim_xml.verify_reconstructor(R)

    x = []
    y = []
    z = []
    speed = []

    if options.obj_filelist is not None:
        obj_filelist=options.obj_filelist
    else:
        obj_filelist=None

    if obj_filelist is not None:
        # placeholder so the branch below runs; replaced by the ids read
        # from the file list
        obj_only = 1

    if obj_only is not None:
        if len(by_file) != 1:
            raise RuntimeError("specifying obj_only can only be done for a single file")
        if obj_filelist is not None:
            data = np.loadtxt(obj_filelist,delimiter=',')
            obj_only = np.array(data[:,0], dtype='int')
            print(obj_only)

        use_obj_ids = numpy.array(obj_only)
        # list(...)[0] works for both a Python 2 list and a Python 3 view
        h5_filename = list(by_file)[0]
        (prev_use_ob_ids, data_file) = by_file[h5_filename]
        by_file[h5_filename] = (use_obj_ids, data_file)

    for h5_filename in by_file:
        (use_obj_ids, data_file) = by_file[h5_filename]
        for obj_id_enum,obj_id in enumerate(use_obj_ids):
            rows = ca.load_data( obj_id, data_file,
                                 use_kalman_smoothing=False,
                                 #dynamic_model_name = dynamic_model_name,
                                 #frames_per_second=fps,
                                 #up_dir=up_dir,
                                )
            verts = numpy.array( [rows['x'], rows['y'], rows['z']] ).T
            if len(verts)>=3:
                # central difference spans two frames, hence 2*dt
                verts_central_diff = verts[2:,:] - verts[:-2,:]
                dt = 1.0/fps
                vels = verts_central_diff/(2*dt)
                speeds = numpy.sqrt(numpy.sum(vels**2,axis=1))
                # pad end points
                speeds = numpy.array([speeds[0]] + list(speeds) + [speeds[-1]])
            else:
                speeds = numpy.zeros( (verts.shape[0],) )

            if verts.shape[0] != len(speeds):
                raise ValueError('mismatch length of x data and speeds')
            x.append( verts[:,0] )
            y.append( verts[:,1] )
            z.append( verts[:,2] )
            speed.append(speeds)
        data_file.close()
    del h5_filename, use_obj_ids, data_file

    if 0:
        # debug
        if stim_xml is not None:
            v = None
            for child in stim_xml.root:
                if child.tag == 'cubic_arena':
                    info = stim_xml._get_info_for_cubic_arena(child)
                    v=info['verts4x4']
            if v is not None:
                for vi in v:
                    print('adding',vi)
                    x.append( [vi[0]] )
                    y.append( [vi[1]] )
                    z.append( [vi[2]] )
                    speed.append( [100.0] )

    x = np.concatenate(x)
    y = np.concatenate(y)
    z = np.concatenate(z)
    w = np.ones_like(x)
    speed = np.concatenate(speed)

    # homogeneous coords
    verts = np.array([x,y,z,w])

    #######################################################

    # Create the MayaVi engine and start it.
    e = Engine()
    # start does nothing much but useful if someone is listening to
    # your engine.
    e.start()

    # Create a new scene.
    from tvtk.tools import ivtk
    #viewer = ivtk.IVTK(size=(600,600))
    viewer = IVTKWithCalGUI(size=(800,600))
    viewer.open()
    e.new_scene(viewer)

    viewer.cal_align.set_data(verts,speed,R,args.align_json)

    if 0:
        # Do this if you need to see the MayaVi tree view UI.
        ev = EngineView(engine=e)
        ui = ev.edit_traits()

    # view aligned data
    e.add_source(viewer.cal_align.source)

    v = Vectors()
    v.glyph.scale_mode = 'data_scaling_off'
    v.glyph.color_mode = 'color_by_scalar'
    v.glyph.glyph_source.glyph_position='center'
    v.glyph.glyph_source.glyph_source = tvtk.SphereSource(
        radius=options.radius,
        )
    e.add_module(v)

    if stim_xml is not None:
        if 0:
            stim_xml.draw_in_mayavi_scene(e)
        else:
            actors = stim_xml.get_tvtk_actors()
            viewer.scene.add_actors(actors)

    gui = GUI()
    gui.start_event_loop()
示例#3
0
def plot_top_and_side_views(
    subplot=None, options=None, obs_mew=None, scale=1.0, units="m",
):
    """Plot top (x-y) and side (x-z) views of tracked object trajectories.

    Parameters
    ----------
    subplot : dict
        Matplotlib axes keyed by view name.  Both ``'xy'`` and ``'xz'`` are
        accessed by this function.
    options : object
        Attribute container (e.g. parsed command-line options).  The
        attributes ``kalman_filename``, ``fps``, ``dynamic_model``,
        ``use_kalman_smoothing``, ``start``, ``stop``, ``obj_only``,
        ``stim_xml``, ``max_z`` and ``up_dir`` are read directly.  The
        display attributes ``show_track_ends``, ``unicolor``,
        ``show_landing``, ``show_obj_id``, ``ellipsoids``,
        ``show_observations`` and ``markersize`` are set to defaults on the
        object when they are missing.
    obs_mew : float or None
        Passed as the ``mew`` keyword when plotting observation markers.
    scale : float
        Multiplier applied to every plotted coordinate.
    units : str
        Unit name used in the axis labels.

    Returns
    -------
    collections.defaultdict
        Maps ``'lines'`` to the trajectory line drawn in the x-z axes for
        each plotted object, ``'ellipse_lines'`` to every covariance ellipse
        line (both views) and ``'MLE_line'`` to the x-z observation line of
        each object (``None`` when observations are not shown).

    Notes
    -----
    Calls ``sys.exit(1)`` when Kalman smoothing is disabled but a dynamic
    model was requested.
    """
    assert subplot is not None
    assert options is not None

    # Optional display settings: fill in defaults so callers may pass a
    # minimal options object.
    for attr_name, default in (
        ("show_track_ends", False),
        ("unicolor", False),
        ("show_landing", False),
        ("show_obj_id", False),
        ("ellipsoids", False),
        ("show_observations", False),
        ("markersize", 0.5),
    ):
        if not hasattr(options, attr_name):
            setattr(options, attr_name, default)

    kalman_filename = options.kalman_filename
    fps = options.fps
    dynamic_model = options.dynamic_model
    use_kalman_smoothing = options.use_kalman_smoothing

    if options.ellipsoids and use_kalman_smoothing:
        warnings.warn(
            "plotting ellipsoids while using Kalman smoothing does not reveal original error estimates"
        )

    assert kalman_filename is not None

    start = options.start
    stop = options.stop
    obj_only = options.obj_only

    if not use_kalman_smoothing and dynamic_model is not None:
        print(
            "ERROR: disabling Kalman smoothing (--disable-kalman-smoothing) is incompatable with setting dynamic model options (--dynamic-model)",
            file=sys.stderr,
        )
        sys.exit(1)

    ca = core_analysis.get_global_CachingAnalyzer()
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        kalman_filename
    )

    if is_mat_file:
        reconstructor = None
    else:
        if fps is None:
            fps = result_utils.get_fps(data_file, fail_on_error=False)
        if fps is None:
            fps = 100.0
            warnings.warn("Setting fps to default value of %f" % fps)
        reconstructor = reconstruct.Reconstructor(data_file)

    if options.stim_xml:
        # The timestamp is taken from a fixed slice of the data file name.
        file_timestamp = data_file.filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp,
        )
        try:
            fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        except xml_stimulus.WrongXMLTypeError:
            walking_start_stops = []
        else:
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp
            )
            walking_start_stops = fanout.get_walking_start_stops_for_timestamp(
                timestamp_string=file_timestamp
            )
            if include_obj_ids is not None:
                use_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
            stim_xml = fanout.get_stimulus_for_timestamp(
                timestamp_string=file_timestamp
            )
        if stim_xml.has_reconstructor():
            stim_xml.verify_reconstructor(reconstructor)
    else:
        walking_start_stops = []

    if dynamic_model is None:
        dynamic_model = extra.get("dynamic_model_name", None)

    if dynamic_model is None:
        if use_kalman_smoothing:
            warnings.warn(
                "no kalman smoothing will be performed because no "
                "dynamic model specified or found."
            )
            use_kalman_smoothing = False
    else:
        print('detected file loaded with dynamic model "%s"' % dynamic_model)
        if use_kalman_smoothing:
            if dynamic_model.startswith("EKF "):
                dynamic_model = dynamic_model[4:]
            print('  for smoothing, will use dynamic model "%s"' % dynamic_model)

    if obj_only is not None:
        use_obj_ids = [i for i in use_obj_ids if i in obj_only]

    for view_key, vertical_name in (("xy", "y"), ("xz", "z")):
        ax = subplot[view_key]
        ax.set_aspect("equal")
        ax.set_xlabel("x (%s)" % units)
        ax.set_ylabel("%s (%s)" % (vertical_name, units))

    if options.stim_xml:
        stim_xml.plot_stim(
            subplot["xy"], projection=xml_stimulus.SimpleOrthographicXYProjection()
        )
        stim_xml.plot_stim(
            subplot["xz"], projection=xml_stimulus.SimpleOrthographicXZProjection()
        )

    def _frame_mask(frames):
        # Boolean mask selecting frames within the inclusive [start, stop] range.
        mask = np.ones(frames.shape, dtype=bool)
        if start is not None:
            mask = mask & (frames >= start)
        if stop is not None:
            mask = mask & (frames <= stop)
        return mask

    results = collections.defaultdict(list)
    for obj_id in use_obj_ids:
        line = None
        MLE_line = None
        ellipse_lines = []
        try:
            kalman_rows = ca.load_data(
                obj_id,
                data_file,
                use_kalman_smoothing=use_kalman_smoothing,
                dynamic_model_name=dynamic_model,
                frames_per_second=fps,
                up_dir=options.up_dir,
            )
        except core_analysis.ObjectIDDataError:
            continue

        kobs_rows = None
        if options.show_observations:
            kobs_rows = ca.load_dynamics_free_MLE_position(obj_id, data_file)

        if (start is not None) or (stop is not None):
            kalman_rows = kalman_rows[_frame_mask(kalman_rows["frame"])]
            if options.show_observations:
                kobs_rows = kobs_rows[_frame_mask(kobs_rows["frame"])]
            if not len(kalman_rows):
                continue

        frame = kalman_rows["frame"]
        Xx = kalman_rows["x"]
        Xy = kalman_rows["y"]
        Xz = kalman_rows["z"]

        if options.max_z is not None:
            # Hide samples above max_z and mark the cut-off in the side view.
            too_high = ~(Xz <= options.max_z)
            frame = numpy.ma.masked_where(too_high, frame)
            Xx = numpy.ma.masked_where(too_high, Xx)
            Xy = numpy.ma.masked_where(too_high, Xy)
            Xz = numpy.ma.masked_where(too_high, Xz)
            with keep_axes_dimensions_if(subplot["xz"], options.stim_xml):
                subplot["xz"].axhline(options.max_z)

        kws = {"markersize": options.markersize}
        if options.unicolor:
            kws["color"] = "k"

        landing_idxs = []
        for walkstart, walkstop in walking_start_stops:
            if walkstart in frame:
                landing_idxs.append(numpy.nonzero(frame == walkstart)[0][0])

        # Each view: (axes key, vertical data, observation column, ellipse dims).
        # xy is drawn first so that its line color is reused for xz via kws.
        views = (
            ("xy", Xy, "y", (0, 1)),
            ("xz", Xz, "z", (0, 2)),
        )
        for view_key, Xv, obs_col, ell_dims in views:
            ax = subplot[view_key]
            with keep_axes_dimensions_if(ax, options.stim_xml):
                (line,) = ax.plot(
                    Xx * scale, Xv * scale, ".", label="obj %d" % obj_id, **kws
                )
                kws["color"] = line.get_color()
                if options.ellipsoids:
                    for rowi in kalman_rows:
                        mu = [rowi["x"], rowi["y"], rowi["z"]]
                        va = get_covariance(rowi)
                        ell_h, ell_v = densities.gauss_ell(
                            mu, va, list(ell_dims), 30, 0.39
                        )
                        (ellipse_line,) = ax.plot(
                            ell_h * scale, ell_v * scale, color=kws["color"]
                        )
                        ellipse_lines.append(ellipse_line)
                if options.show_track_ends:
                    ax.plot(
                        [Xx[0] * scale, Xx[-1] * scale],
                        [Xv[0] * scale, Xv[-1] * scale],
                        "cd",
                        ms=6,
                        label="track end",
                    )
                if options.show_obj_id:
                    ax.text(Xx[0] * scale, Xv[0] * scale, str(obj_id))
                if options.show_landing:
                    for landing_idx in landing_idxs:
                        ax.plot(
                            [Xx[landing_idx] * scale],
                            [Xv[landing_idx] * scale],
                            "rD",
                            ms=10,
                            label="landing",
                        )
                if options.show_observations:
                    mykw = dict(kws)
                    mykw["markersize"] *= 5
                    mykw["mew"] = obs_mew

                    # Frames without an observation have NaN in the x column.
                    badcond = np.isnan(kobs_rows["x"])
                    Xox = np.ma.masked_where(badcond, kobs_rows["x"])
                    Xov = np.ma.masked_where(badcond, kobs_rows[obs_col])

                    (MLE_line,) = ax.plot(
                        Xox * scale, Xov * scale, "x", label="obj %d" % obj_id, **mykw
                    )

        # `line` and `MLE_line` refer to the artists of the last view (xz).
        results["lines"].append(line)
        results["ellipse_lines"].extend(ellipse_lines)
        results["MLE_line"].append(MLE_line)
    return results
示例#4
0
def plot_ori(
    kalman_filename=None,
    h5=None,
    obj_only=None,
    start=None,
    stop=None,
    output_filename=None,
    options=None,
):
    """Plot per-camera and 3D orientation data for tracked objects.

    Five vertically stacked axes sharing the frame axis are drawn: per-camera
    ``theta`` values, per-camera ``dist`` values, raw and smoothed 3D
    orientation components, orientation quality, and the direction vectors
    returned by the caching analyzer.

    Parameters
    ----------
    kalman_filename : str
        File containing ``ML_estimates`` and ``ori_ekf_qual`` nodes.
    h5 : str or None
        Optional 2D data file.  When given, it supplies camera names for the
        legend and a frame rate that is used unless ``options.fps`` is set.
    obj_only : sequence of int or None
        Object ids to plot.  When ``None`` every object with orientation
        tables is plotted.
    start, stop : int or None
        Inclusive frame limits applied to the orientation and ML data.
    output_filename : str or None
        When given, the figure is saved there using the Agg backend;
        otherwise it is shown interactively.
    options : object
        Attribute container; ``use_kalman_smoothing``, ``fps``,
        ``dynamic_model``, ``print_status``, ``ori_qual``,
        ``smooth_orientations``, ``up_dir``, ``start`` and ``stop`` are read.
    """
    if output_filename is not None:
        import matplotlib

        matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.ticker as mticker

    use_kalman_smoothing = options.use_kalman_smoothing
    fps = options.fps
    dynamic_model = options.dynamic_model

    camn2cam_id = {}
    if h5 is not None:
        h5f = tables.open_file(h5, mode="r")
        try:
            camn2cam_id, _cam_id2camns = result_utils.get_caminfo_dicts(h5f)
            h5_fps = result_utils.get_fps(h5f)
        finally:
            # Always release the file, even if reading its metadata fails.
            h5f.close()
        # An explicitly requested frame rate wins over the one stored on disk.
        if fps is None:
            fps = h5_fps

    ca = core_analysis.get_global_CachingAnalyzer()

    # Only the `extra` metadata of the initial load is needed here.
    _, _, _, _, extra = ca.initial_file_load(kalman_filename)

    if dynamic_model is None:
        dynamic_model = extra["dynamic_model_name"]
        print('detected file loaded with dynamic model "%s"' % dynamic_model)
        if dynamic_model.startswith("EKF "):
            dynamic_model = dynamic_model[4:]
        print('  for smoothing, will use dynamic model "%s"' % dynamic_model)

    with open_file_safe(kalman_filename, mode="r") as kh5:
        if fps is None:
            fps = result_utils.get_fps(kh5, fail_on_error=True)

        kmle = kh5.root.ML_estimates[:]  # load into RAM

        if start is not None:
            kmle = kmle[kmle["frame"] >= start]

        if stop is not None:
            kmle = kmle[kmle["frame"] <= stop]

        all_mle_obj_ids = kmle["obj_id"]

        # Walk all tables to map each obj_id to its orientation table.
        all_obj_ids = {}
        for group in kh5.root.ori_ekf_qual._f_iter_nodes():
            for table in group._f_iter_nodes():
                assert table.name.startswith("obj")
                all_obj_ids[int(table.name[3:])] = table

        if obj_only is None:
            use_obj_ids = sorted(all_obj_ids)
            missing_objs = set(np.unique(all_mle_obj_ids)) - set(use_obj_ids)
            if missing_objs:
                warnings.warn("orientation not fit for %d obj_ids" %
                              (len(missing_objs), ))
        else:
            use_obj_ids = obj_only

        # now, generate plots
        fig = plt.figure()
        ax1 = fig.add_subplot(511)
        ax2 = fig.add_subplot(512, sharex=ax1)
        ax3 = fig.add_subplot(513, sharex=ax1)
        ax4 = fig.add_subplot(514, sharex=ax1)
        ax5 = fig.add_subplot(515, sharex=ax1)

        min_frame_range = np.inf
        max_frame_range = -np.inf

        if options.print_status:
            print("%d object IDs in file" % (len(use_obj_ids), ))
        for obj_id in use_obj_ids:
            table = all_obj_ids[obj_id]
            rows = table[:]

            if start is not None:
                rows = rows[rows["frame"] >= start]

            if stop is not None:
                rows = rows[rows["frame"] <= stop]

            if options.print_status:
                print("obj_id %d: %d rows of EKF data" % (obj_id, len(rows)))

            frame = rows["frame"]

            # Camera numbers are encoded in the "dist<camn>" column names.
            camns = [int(colname[4:]) for colname in table.colnames
                     if colname.startswith("dist")]

            # Track the overall frame range; skip empty selections, for
            # which min()/max() would raise.
            if camns and len(frame):
                frf = np.asarray(frame, dtype=float)
                min_frame_range = min(frf.min(), min_frame_range)
                max_frame_range = max(frf.max(), max_frame_range)

            for camn in camns:
                label = camn2cam_id.get(camn, "camn%d" % camn)
                theta = rows["theta%d" % camn]
                used = rows["used%d" % camn]
                dist = rows["dist%d" % camn]

                (line, ) = ax1.plot(frame,
                                    theta * R2D,
                                    "o",
                                    mew=0,
                                    ms=2.0,
                                    label=label)
                c = line.get_color()
                ax2.plot(frame[used],
                         dist[used] * R2D,
                         "o",
                         color=c,
                         mew=0,
                         label=label)
                ax2.plot(frame[~used],
                         dist[~used] * R2D,
                         "o",
                         color=c,
                         mew=0,
                         ms=2.0)

            # plot 3D orientation
            rows_this_obj = kmle[all_mle_obj_ids == obj_id]
            if options.print_status:
                print("obj_id %d: %d rows of ML data" %
                      (obj_id, len(rows_this_obj)))
            frame = rows_this_obj["frame"]
            hz = np.vstack(
                [rows_this_obj["hz_line%d" % i] for i in range(6)]).T
            orient = reconstruct.line_direction(hz)
            ax3.plot(frame, orient[:, 0], "ro", mew=0, ms=2.0, label="x")
            ax3.plot(frame, orient[:, 1], "go", mew=0, ms=2.0, label="y")
            ax3.plot(frame, orient[:, 2], "bo", mew=0, ms=2.0, label="z")

            qual = compute_ori_quality(kh5, rows_this_obj["frame"], obj_id)

            # Blank out low-quality orientations before smoothing.
            orinan = np.array(orient, copy=True)
            if options.ori_qual is not None and options.ori_qual != 0:
                orinan[qual < options.ori_qual] = np.nan
            try:
                sori = ori_smooth(orinan, frames_per_second=fps)
            except AssertionError:
                if options.print_status:
                    print("not plotting smoothed ori for object id %d" %
                          (obj_id, ))
            else:
                ax3.plot(frame, sori[:, 0], "r-", mew=0, ms=2.0)
                ax3.plot(frame, sori[:, 1], "g-", mew=0, ms=2.0)
                ax3.plot(frame, sori[:, 2], "b-", mew=0, ms=2.0)

            ax4.plot(frame, qual, "b-")

            # --------------
            kalman_rows = ca.load_data(
                obj_id,
                kh5,
                use_kalman_smoothing=use_kalman_smoothing,
                dynamic_model_name=dynamic_model,
                return_smoothed_directions=options.smooth_orientations,
                frames_per_second=fps,
                up_dir=options.up_dir,
                min_ori_quality_required=options.ori_qual,
            )
            frame = kalman_rows["frame"]
            # NOTE(review): this filter reads options.start/options.stop while
            # the rest of the function uses the start/stop arguments -- confirm
            # that callers keep the two in sync.
            cond = np.ones(frame.shape, dtype=bool)
            if options.start is not None:
                cond &= options.start <= frame
            if options.stop is not None:
                cond &= frame <= options.stop
            kalman_rows = kalman_rows[cond]

            frame = kalman_rows["frame"]
            Dx = Dy = Dz = None
            if options.smooth_orientations:
                Dx = kalman_rows["dir_x"]
                Dy = kalman_rows["dir_y"]
                Dz = kalman_rows["dir_z"]
            elif "rawdir_x" in kalman_rows.dtype.fields:
                Dx = kalman_rows["rawdir_x"]
                Dy = kalman_rows["rawdir_y"]
                Dz = kalman_rows["rawdir_z"]

            if Dx is not None:
                ax5.plot(frame, Dx, "r-", label="dx")
                ax5.plot(frame, Dy, "g-", label="dy")
                ax5.plot(frame, Dz, "b-", label="dz")

    ax1.xaxis.set_major_formatter(mticker.FormatStrFormatter("%d"))
    ax1.set_ylabel("theta (deg)")
    ax1.legend()

    ax2.set_ylabel("z (deg)")
    ax2.legend()

    ax3.set_ylabel("ori")
    ax3.legend()

    ax4.set_ylabel("quality")

    ax5.set_ylabel("dir")
    ax5.set_xlabel("frame")
    ax5.legend()

    # Limits stay infinite when nothing was plotted; matplotlib rejects those.
    if np.isfinite(min_frame_range) and np.isfinite(max_frame_range):
        ax1.set_xlim(min_frame_range, max_frame_range)
    if output_filename is None:
        plt.show()
    else:
        plt.savefig(output_filename)
def doit(
    filenames=None,
    start=None,
    stop=None,
    kalman_filename=None,
    fps=None,
    use_kalman_smoothing=True,
    dynamic_model=None,
    up_dir=None,
    options=None,
):
    if options.save_fig is not None:
        matplotlib.use('Agg')
    import pylab

    if not use_kalman_smoothing:
        if (fps is not None) or (dynamic_model is not None):
            print(
                'WARNING: disabling Kalman smoothing '
                '(--disable-kalman-smoothing) is '
                'incompatable with setting fps and '
                'dynamic model options (--fps and '
                '--dynamic-model)',
                file=sys.stderr)

    ax = None
    ax_by_cam = {}
    fig = pylab.figure()

    assert len(filenames) >= 1, 'must give at least one filename!'

    n_files = 0
    for filename in filenames:

        if options.show_source_name:
            figtitle = filename
            if kalman_filename is not None:
                figtitle += ' ' + kalman_filename
        else:
            figtitle = ''
        if options.obj_only is not None:
            figtitle += ' only showing objects: ' + ' '.join(
                map(str, options.obj_only))
        if figtitle != '':
            pylab.figtext(0.01, 0.01, figtitle, verticalalignment='bottom')

        with PT.open_file(filename, mode='r') as h5:
            if options.spreadh5 is not None:
                h5spread = PT.open_file(options.spreadh5, mode='r')
            else:
                h5spread = None

            if fps is None:
                fps = result_utils.get_fps(h5)

            camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)
            cam_ids = cam_id2camns.keys()
            cam_ids.sort()

            if start is not None or stop is not None:
                frames = h5.root.data2d_distorted.read(field='frame')
                valid_cond = numpy.ones(frames.shape, dtype=numpy.bool)
                if start is not None:
                    valid_cond = valid_cond & (frames >= start)
                if stop is not None:
                    valid_cond = valid_cond & (frames <= stop)
                read_idxs = np.nonzero(valid_cond)[0]
                all_data = []
                for start_stop in utils.iter_contig_chunk_idxs(read_idxs):
                    (read_idx_start_idx, read_idx_stop_idx) = start_stop
                    start_idx = read_idxs[read_idx_start_idx]
                    stop_idx = read_idxs[read_idx_stop_idx - 1]
                    these_rows = h5.root.data2d_distorted.read(start=start_idx,
                                                               stop=stop_idx +
                                                               1)
                    all_data.append(these_rows)
                if len(all_data) == 0:
                    print('file %s has no frames in range %s - %s' %
                          (filename, start, stop))
                    continue
                all_data = np.concatenate(all_data)
                del valid_cond, frames, start_idx, stop_idx, these_rows, read_idxs
            else:
                all_data = h5.root.data2d_distorted[:]

            tmp_frames = all_data['frame']
            if len(tmp_frames) == 0:
                print('file %s has no frames, skipping.' % filename)
                continue
            n_files += 1
            start_frame = tmp_frames.min()
            stop_frame = tmp_frames.max()
            del tmp_frames

            for cam_id_enum, cam_id in enumerate(cam_ids):
                if cam_id in ax_by_cam:
                    ax = ax_by_cam[cam_id]
                else:
                    n_subplots = len(cam_ids)
                    if kalman_filename is not None:
                        n_subplots += 1
                    if h5spread is not None:
                        n_subplots += 1
                    ax = pylab.subplot(n_subplots,
                                       1,
                                       cam_id_enum + 1,
                                       sharex=ax)
                    ax_by_cam[cam_id] = ax
                    ax.fmt_xdata = str
                    ax.fmt_ydata = str

                camns = cam_id2camns[cam_id]
                cam_id_n_valid = 0
                for camn in camns:
                    this_idx = numpy.nonzero(all_data['camn'] == camn)[0]
                    data = all_data[this_idx]

                    xdata = data['x']
                    valid = ~numpy.isnan(xdata)

                    data = data[valid]
                    del valid

                    if options.area_threshold > 0.0:
                        area = data['area']

                        valid2 = area >= options.area_threshold
                        data = data[valid2]
                        del valid2

                    if options.likely_only:
                        pt_area = data['area']
                        cur_val = data['cur_val']
                        mean_val = data['mean_val']
                        sumsqf_val = data['sumsqf_val']

                        p_y_x = some_rough_negative_log_likelihood(
                            pt_area, cur_val, mean_val, sumsqf_val)
                        valid3 = np.isfinite(p_y_x)
                        data = data[valid3]

                    n_valid = len(data)
                    cam_id_n_valid += n_valid
                    if options.timestamps:
                        xdata = data['timestamp']
                    else:
                        xdata = data['frame']
                    if n_valid >= 1:
                        ax.plot(xdata, data['x'], 'ro', ms=2, mew=0)
                        ax.plot(xdata, data['y'], 'go', ms=2, mew=0)
                ax.text(
                    0.1,
                    0,
                    '%s %s: %d pts' %
                    (cam_id, cam_id2camns[cam_id], cam_id_n_valid),
                    horizontalalignment='left',
                    verticalalignment='bottom',
                    transform=ax.transAxes,
                )
                ax.set_ylabel('pixels')
                if not options.timestamps:
                    ax.set_xlim((start_frame, stop_frame))
            ax.set_xlabel('frame')
            if options.timestamps:
                timezone = result_utils.get_tz(h5)
                df = DateFormatter(timezone)
                ax.xaxis.set_major_formatter(
                    ticker.FuncFormatter(df.format_date))
            else:
                ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
            ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
            if h5spread is not None:
                if options.timestamps:
                    raise NotImplementedError(
                        '--timestamps is currently incompatible with --spreadh5'
                    )
                ax_by_cam['h5spread'] = ax
                if kalman_filename is not None:
                    # this is 2nd to last
                    ax = pylab.subplot(n_subplots,
                                       1,
                                       n_subplots - 1,
                                       sharex=ax)
                else:
                    # this is last
                    ax = pylab.subplot(n_subplots, 1, n_subplots, sharex=ax)

                frames = h5spread.root.framenumber[:]
                spread = h5spread.root.spread[:]

                valid_cond = numpy.ones(frames.shape, dtype=numpy.bool)
                if start is not None:
                    valid_cond = valid_cond & (frames >= start)
                if stop is not None:
                    valid_cond = valid_cond & (frames <= stop)

                spread_msec = spread[valid_cond] * 1000.0
                ax.plot(frames[valid_cond], spread_msec, 'o', ms=2, mew=0)

                if spread_msec.max() < 1.0:
                    ax.set_ylim((0, 1))
                    ax.set_yticks([0, 1])
                ax.set_xlabel('frame')
                ax.set_ylabel('timestamp spread (msec)')
                ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
                ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
                h5spread.close()
                del frames
                del spread

    if options.timestamps:
        fig.autofmt_xdate()

    if kalman_filename is not None:
        if 1:
            ax = pylab.subplot(n_subplots, 1, n_subplots, sharex=ax)
            ax_by_cam['kalman pmean'] = ax
            ax.fmt_xdata = str
            ax.set_ylabel('3d error\nmeters')

        frame_start = start
        frame_stop = stop

        # copied from save_movies_overlay.py
        ca = core_analysis.get_global_CachingAnalyzer()
        (obj_ids, use_obj_ids, is_mat_file, data_file,
         extra) = ca.initial_file_load(kalman_filename)
        if options.timestamps:
            time_model = result_utils.get_time_model_from_data(data_file)
        if 'frames' in extra:
            frames = extra['frames']
            valid_cond = np.ones((len(frames, )), dtype=np.bool)
            if start is not None:
                valid_cond &= frames >= start
            if stop is not None:
                valid_cond &= frames <= stop
            obj_ids = obj_ids[valid_cond]
            use_obj_ids = np.unique(obj_ids)
            print('quick found use_obj_ids', use_obj_ids)
        if is_mat_file:
            raise ValueError('cannot use .mat file for kalman_filename '
                             'because it is missing the reconstructor '
                             'and ability to get framenumbers')
        R = reconstruct.Reconstructor(data_file)

        if options.obj_only is not None:
            use_obj_ids = options.obj_only

        if dynamic_model is None and use_kalman_smoothing:
            dynamic_model = extra['dynamic_model_name']
            print('detected file loaded with dynamic model "%s"' %
                  dynamic_model)
            if dynamic_model.startswith('EKF '):
                dynamic_model = dynamic_model[4:]
            print('  for smoothing, will use dynamic model "%s"' %
                  dynamic_model)

        if options.reproj_error:
            reproj_error = collections.defaultdict(list)
            max_reproj_error = {}
            kalman_rows = []
            for obj_id in use_obj_ids:
                kalman_rows.append(ca.load_observations(obj_id, data_file))
            kalman_rows = numpy.concatenate(kalman_rows)
            kalman_3d_frame = kalman_rows['frame']

            if start is not None or stop is not None:
                if start is None:
                    start = -numpy.inf
                if stop is None:
                    stop = numpy.inf
                valid_cond = ((kalman_3d_frame >= start) &
                              (kalman_3d_frame <= stop))

                kalman_rows = kalman_rows[valid_cond]
                kalman_3d_frame = kalman_3d_frame[valid_cond]

            # modified from save_movies_overlay
            for this_3d_row_enum, this_3d_row in enumerate(kalman_rows):
                if this_3d_row_enum % 100 == 0:
                    print('doing reprojection error for MLE 3d estimate for '
                          'row %d of %d' %
                          (this_3d_row_enum, len(kalman_rows)))
                vert = numpy.array(
                    [this_3d_row['x'], this_3d_row['y'], this_3d_row['z']])
                obj_id = this_3d_row['obj_id']
                if numpy.isnan(vert[0]):
                    # no observation this frame
                    continue
                obs_2d_idx = this_3d_row['obs_2d_idx']
                try:
                    kobs_2d_data = data_file.root.ML_estimates_2d_idxs[int(
                        obs_2d_idx)]
                except tables.exceptions.NoSuchNodeError, err:
                    # backwards compatibility
                    kobs_2d_data = data_file.root.kalman_observations_2d_idxs[
                        int(obs_2d_idx)]

                # parse VLArray
                this_camns = kobs_2d_data[0::2]
                this_camn_idxs = kobs_2d_data[1::2]

                # find original 2d data
                #   narrow down search
                obs2d = all_data[all_data['frame'] == this_3d_row['frame']]

                for camn, this_camn_idx in zip(this_camns, this_camn_idxs):
                    cam_id = camn2cam_id[camn]

                    # do projection to camera image plane
                    vert_image = R.find2d(cam_id, vert, distorted=True)

                    new_cond = ((obs2d['camn'] == camn) &
                                (obs2d['frame_pt_idx'] == this_camn_idx))
                    assert numpy.sum(new_cond) == 1

                    x = obs2d[new_cond]['x'][0]
                    y = obs2d[new_cond]['y'][0]

                    this_reproj_error = numpy.sqrt((vert_image[0] - x)**2 +
                                                   (vert_image[1] - y)**2)
                    if this_reproj_error > 100:
                        print('  reprojection error > 100 (%.1f) at frame %d '
                              'for camera %s, obj_id %d' %
                              (this_reproj_error, this_3d_row['frame'], cam_id,
                               obj_id))
                    if numpy.isnan(this_reproj_error):
                        print('error:')
                        print(this_camns, this_camn_idxs)
                        print(cam_id)
                        print(vert_image)
                        print(vert)
                        raise ValueError('nan at frame %d' %
                                         this_3d_row['frame'])
                    reproj_error[cam_id].append(this_reproj_error)
                    if cam_id in max_reproj_error:
                        (cur_max_frame, cur_max_reproj_error,
                         cur_obj_id) = max_reproj_error[cam_id]
                        if this_reproj_error > cur_max_reproj_error:
                            max_reproj_error[cam_id] = (this_3d_row['frame'],
                                                        this_reproj_error,
                                                        obj_id)
                    else:
                        max_reproj_error[cam_id] = (this_3d_row['frame'],
                                                    this_reproj_error, obj_id)

            del kalman_rows, kalman_3d_frame, obj_ids
            print('mean reprojection errors:')
            cam_ids = reproj_error.keys()
            cam_ids.sort()
            for cam_id in cam_ids:
                errors = reproj_error[cam_id]
                mean_error = numpy.mean(errors)
                worst_frame, worst_error, worst_obj_id = max_reproj_error[
                    cam_id]
                print(' %s: %.1f (worst: frame %d, obj_id %d, error %.1f)' %
                      (cam_id, mean_error, worst_frame, worst_obj_id,
                       worst_error))
            print()

        for kalman_smoothing in [True, False]:
            if use_kalman_smoothing == False and kalman_smoothing == True:
                continue
            print('loading frame numbers for kalman objects (estimates)')
            kalman_rows = []
            for obj_id in use_obj_ids:
                try:
                    my_rows = ca.load_data(
                        obj_id,
                        data_file,
                        use_kalman_smoothing=kalman_smoothing,
                        dynamic_model_name=dynamic_model,
                        frames_per_second=fps,
                        up_dir=up_dir,
                    )
                except core_analysis.NotEnoughDataToSmoothError, err:
                    # OK, we don't have data from this obj_id
                    continue
                else:
                    kalman_rows.append(my_rows)
            if not len(kalman_rows):
                # no data
                continue
            kalman_rows = numpy.concatenate(kalman_rows)
            kalman_3d_frame = kalman_rows['frame']

            if start is not None or stop is not None:
                if start is None:
                    start = -numpy.inf
                if stop is None:
                    stop = numpy.inf
                valid_cond = ((kalman_3d_frame >= start) &
                              (kalman_3d_frame <= stop))

                kalman_rows = kalman_rows[valid_cond]
                kalman_3d_frame = kalman_3d_frame[valid_cond]

            obj_ids = kalman_rows['obj_id']
            use_obj_ids = numpy.unique(obj_ids)
            non_nan_rows = ~np.isnan(kalman_rows['x'])
            print('plotting %d Kalman objects' % (len(use_obj_ids), ))
            for obj_id in use_obj_ids:
                cond = obj_ids == obj_id
                cond &= non_nan_rows
                x = kalman_rows['x'][cond]
                y = kalman_rows['y'][cond]
                z = kalman_rows['z'][cond]
                w = numpy.ones(x.shape)
                X = numpy.vstack((x, y, z, w)).T
                frame = kalman_rows['frame'][cond]
                #print '%d %d %d'%(frame[0],obj_id, len(frame))
                if options.timestamps:
                    time_est = time_model.framestamp2timestamp(frame)

                if kalman_smoothing:
                    kwprops = dict(lw=0.5)
                else:
                    kwprops = dict(lw=1)

                for cam_id in cam_ids:
                    if cam_id not in R.get_cam_ids():
                        print(
                            'no calibration for %s: not showing 3D projections'
                            % (cam_id, ))
                        continue
                    ax = ax_by_cam[cam_id]
                    x2d = R.find2d(cam_id, X, distorted=True)
                    ## print '%d %d %s (%f,%f)'%(
                    ##     obj_id,frame[0],cam_id,x2d[0,0],x2d[1,0])
                    if options.timestamps:
                        xdata = time_est
                    else:
                        xdata = frame
                    ax.text(xdata[0], x2d[0, 0], '%d' % obj_id)
                    thisline, = ax.plot(xdata,
                                        x2d[0, :],
                                        'b-',
                                        picker=5,
                                        **kwprops)  #5pt tolerance
                    all_kalman_lines[thisline] = obj_id
                    thisline, = ax.plot(xdata,
                                        x2d[1, :],
                                        'y-',
                                        picker=5,
                                        **kwprops)  #5pt tolerance
                    all_kalman_lines[thisline] = obj_id
                    ax.set_ylim([-100, 800])
                    if options.timestamps:
                        ## ax.set_xlim( *time_model.framestamp2timestamp(
                        ##     (start_frame, stop_frame) ))
                        pass
                    else:
                        ax.set_xlim((start_frame, stop_frame))
                if 1:
                    ax = ax_by_cam['kalman pmean']
                    P00 = kalman_rows['P00'][cond]
                    P11 = kalman_rows['P11'][cond]
                    P22 = kalman_rows['P22'][cond]
                    Pmean = numpy.sqrt(P00**2 + P11**2 + P22**2)  # variance
                    std = numpy.sqrt(Pmean)  # standard deviation (in meters)
                    if options.timestamps:
                        xdata = time_est
                    else:
                        xdata = frame
                    ax.plot(xdata, std, 'k-', **kwprops)

                    if options.timestamps:
                        ax.set_xlabel('time (sec)')
                        timezone = result_utils.get_tz(h5)
                        df = DateFormatter(timezone)
                        ax.xaxis.set_major_formatter(
                            ticker.FuncFormatter(df.format_date))
                        for label in ax.get_xticklabels():
                            label.set_rotation(30)

                    else:
                        ax.set_xlabel('frame')
                        ax.xaxis.set_major_formatter(
                            ticker.FormatStrFormatter("%d"))
                    ax.yaxis.set_major_formatter(
                        ticker.FormatStrFormatter("%s"))

            if not kalman_smoothing:
                # plot 2D data contributing to 3D object
                # this is forked from flydra_analysis_plot_kalman_2d.py

                kresults = ca.get_pytables_file_by_filename(kalman_filename)
                try:
                    kobs = kresults.root.ML_estimates
                except tables.exceptions.NoSuchNodeError:
                    # backward compatibility
                    kobs = kresults.root.kalman_observations
                kframes = kobs.read(field='frame')
                if frame_start is not None:
                    k_after_start = numpy.nonzero(kframes >= frame_start)[0]
                else:
                    k_after_start = None
                if frame_stop is not None:
                    k_before_stop = numpy.nonzero(kframes <= frame_stop)[0]
                else:
                    k_before_stop = None

                if k_after_start is not None and k_before_stop is not None:
                    k_use_idxs = numpy.intersect1d(k_after_start,
                                                   k_before_stop)
                elif k_after_start is not None:
                    k_use_idxs = k_after_start
                elif k_before_stop is not None:
                    k_use_idxs = k_before_stop
                else:
                    k_use_idxs = numpy.arange(kobs.nrows)

                obs_2d_idxs = kobs.read(field='obs_2d_idx')[k_use_idxs]
                kframes = kframes[k_use_idxs]

                try:
                    kobs_2d = kresults.root.ML_estimates_2d_idxs
                except tables.exceptions.NoSuchNodeError:
                    # backwards compatibility
                    kobs_2d = kresults.root.kalman_observations_2d_idxs
                # this will be slooow...
                used_cam_ids = collections.defaultdict(list)
                for obs_2d_idx, kframe in zip(obs_2d_idxs, kframes):
                    obs_2d_row = kobs_2d[int(obs_2d_idx)]
                    #print kframe,obs_2d_row
                    for camn in obs_2d_row[::2]:
                        try:
                            cam_id = camn2cam_id[camn]
                        except KeyError:
                            cam_id = None
                        if cam_id is not None:
                            used_cam_ids[cam_id].append(kframe)
                for cam_id, kframes_used in used_cam_ids.iteritems():
                    kframes_used = numpy.array(kframes_used)
                    yval = -99 * numpy.ones_like(kframes_used)
                    ax = ax_by_cam[cam_id]
                    if options.timestamps:
                        ax.plot(time_model.framestamp2timestamp(kframes_used),
                                yval, 'kx')
                    else:
                        ax.plot(kframes_used, yval, 'kx')
                        ax.set_xlim((start_frame, stop_frame))
                    ax.set_ylim([-100, 800])
示例#6
0
def doit(output_h5_filename=None,
         kalman_filename=None,
         data2d_filename=None,
         start=None,
         stop=None,
         gate_angle_threshold_degrees=40.0,
         area_threshold_for_orientation=0.0,
         obj_only=None,
         options=None):
    gate_angle_threshold_radians = gate_angle_threshold_degrees * D2R

    if options.show:
        import matplotlib.pyplot as plt
        import matplotlib.ticker as mticker

    M = SymobolicModels()
    x = sympy.DeferredVector('x')
    G_symbolic = M.get_observation_model(x)
    dx_symbolic = M.get_process_model(x)

    if 0:
        print 'G_symbolic'
        sympy.pprint(G_symbolic)
        print

    G_linearized = [G_symbolic.diff(x[i]) for i in range(7)]
    if 0:
        print 'G_linearized'
        for i in range(len(G_linearized)):
            sympy.pprint(G_linearized[i])
        print

    arg_tuple_x = (M.P00, M.P01, M.P02, M.P03, M.P10, M.P11, M.P12, M.P13,
                   M.P20, M.P21, M.P22, M.P23, M.Ax, M.Ay, M.Az, x)

    xm = sympy.DeferredVector('xm')
    arg_tuple_x_xm = (M.P00, M.P01, M.P02, M.P03, M.P10, M.P11, M.P12, M.P13,
                      M.P20, M.P21, M.P22, M.P23, M.Ax, M.Ay, M.Az, x, xm)

    eval_G = lambdify(arg_tuple_x, G_symbolic, 'numpy')
    eval_linG = lambdify(arg_tuple_x, G_linearized, 'numpy')

    # coord shift of observation model
    phi_symbolic = M.get_observation_model(xm)

    # H = G - phi
    H_symbolic = G_symbolic - phi_symbolic

    # We still take derivative wrt x (not xm).
    H_linearized = [H_symbolic.diff(x[i]) for i in range(7)]

    eval_phi = lambdify(arg_tuple_x_xm, phi_symbolic, 'numpy')
    eval_H = lambdify(arg_tuple_x_xm, H_symbolic, 'numpy')
    eval_linH = lambdify(arg_tuple_x_xm, H_linearized, 'numpy')

    if 0:
        print 'dx_symbolic'
        sympy.pprint(dx_symbolic)
        print

    eval_dAdt = drop_dims(lambdify(x, dx_symbolic, 'numpy'))

    debug_level = 0
    if debug_level:
        np.set_printoptions(linewidth=130, suppress=True)

    if os.path.exists(output_h5_filename):
        raise RuntimeError("will not overwrite old file '%s'" %
                           output_h5_filename)

    ca = core_analysis.get_global_CachingAnalyzer()
    with open_file_safe(output_h5_filename, mode='w') as output_h5:

        with open_file_safe(kalman_filename, mode='r') as kh5:
            with open_file_safe(data2d_filename, mode='r') as h5:
                for input_node in kh5.root._f_iter_nodes():
                    # copy everything from source to dest
                    input_node._f_copy(output_h5.root, recursive=True)

                try:
                    dest_table = output_h5.root.ML_estimates
                except tables.exceptions.NoSuchNodeError, err1:
                    # backwards compatibility
                    try:
                        dest_table = output_h5.root.kalman_observations
                    except tables.exceptions.NoSuchNodeError, err2:
                        raise err1
                for colname in ['hz_line%d' % i for i in range(6)]:
                    clear_col(dest_table, colname)
                dest_table.flush()

                if options.show:
                    fig1 = plt.figure()
                    ax1 = fig1.add_subplot(511)
                    ax2 = fig1.add_subplot(512, sharex=ax1)
                    ax3 = fig1.add_subplot(513, sharex=ax1)
                    ax4 = fig1.add_subplot(514, sharex=ax1)
                    ax5 = fig1.add_subplot(515, sharex=ax1)
                    ax1.xaxis.set_major_formatter(
                        mticker.FormatStrFormatter("%d"))

                    min_frame_range = np.inf
                    max_frame_range = -np.inf

                reconst = reconstruct.Reconstructor(kh5)

                camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)
                fps = result_utils.get_fps(h5)
                dt = 1.0 / fps

                used_camn_dict = {}

                # associate framenumbers with timestamps using 2d .h5 file
                data2d = h5.root.data2d_distorted[:]  # load to RAM
                if start is not None:
                    data2d = data2d[data2d['frame'] >= start]
                if stop is not None:
                    data2d = data2d[data2d['frame'] <= stop]
                data2d_idxs = np.arange(len(data2d))
                h5_framenumbers = data2d['frame']
                h5_frame_qfi = result_utils.QuickFrameIndexer(h5_framenumbers)

                ML_estimates_2d_idxs = (kh5.root.ML_estimates_2d_idxs[:])

                all_kobs_obj_ids = dest_table.read(field='obj_id')
                all_kobs_frames = dest_table.read(field='frame')
                use_obj_ids = np.unique(all_kobs_obj_ids)
                if obj_only is not None:
                    use_obj_ids = obj_only

                if hasattr(kh5.root.kalman_estimates.attrs,
                           'dynamic_model_name'):
                    dynamic_model = kh5.root.kalman_estimates.attrs.dynamic_model_name
                    if dynamic_model.startswith('EKF '):
                        dynamic_model = dynamic_model[4:]
                else:
                    dynamic_model = 'mamarama, units: mm'
                    warnings.warn(
                        'could not determine dynamic model name, using "%s"' %
                        dynamic_model)

                for obj_id_enum, obj_id in enumerate(use_obj_ids):
                    # Use data association step from kalmanization to load potentially
                    # relevant 2D orientations, but discard previous 3D orientation.
                    if obj_id_enum % 100 == 0:
                        print 'obj_id %d (%d of %d)' % (obj_id, obj_id_enum,
                                                        len(use_obj_ids))
                    if options.show:
                        all_xhats = []
                        all_ori = []

                    output_row_obj_id_cond = all_kobs_obj_ids == obj_id

                    obj_3d_rows = ca.load_dynamics_free_MLE_position(
                        obj_id, kh5)
                    if start is not None:
                        obj_3d_rows = obj_3d_rows[
                            obj_3d_rows['frame'] >= start]
                    if stop is not None:
                        obj_3d_rows = obj_3d_rows[obj_3d_rows['frame'] <= stop]

                    try:
                        smoothed_3d_rows = ca.load_data(
                            obj_id,
                            kh5,
                            use_kalman_smoothing=True,
                            frames_per_second=fps,
                            dynamic_model_name=dynamic_model)
                    except core_analysis.NotEnoughDataToSmoothError:
                        continue

                    smoothed_frame_qfi = result_utils.QuickFrameIndexer(
                        smoothed_3d_rows['frame'])

                    slopes_by_camn_by_frame = collections.defaultdict(dict)
                    x0d_by_camn_by_frame = collections.defaultdict(dict)
                    y0d_by_camn_by_frame = collections.defaultdict(dict)
                    pt_idx_by_camn_by_frame = collections.defaultdict(dict)
                    min_frame = np.inf
                    max_frame = -np.inf

                    start_idx = None
                    for this_idx, this_3d_row in enumerate(obj_3d_rows):
                        # iterate over each 3D row of the current object
                        framenumber = this_3d_row['frame']

                        if not np.isnan(this_3d_row['hz_line0']):
                            # We have a valid initial 3d orientation guess.
                            if framenumber < min_frame:
                                min_frame = framenumber
                                assert start_idx is None, "frames out of order?"
                                start_idx = this_idx

                        max_frame = max(max_frame, framenumber)
                        h5_2d_row_idxs = h5_frame_qfi.get_frame_idxs(
                            framenumber)

                        frame2d = data2d[h5_2d_row_idxs]
                        frame2d_idxs = data2d_idxs[h5_2d_row_idxs]

                        obs_2d_idx = this_3d_row['obs_2d_idx']
                        kobs_2d_data = ML_estimates_2d_idxs[int(obs_2d_idx)]

                        # Parse VLArray.
                        this_camns = kobs_2d_data[0::2]
                        this_camn_idxs = kobs_2d_data[1::2]

                        # Now, for each camera viewing this object at this
                        # frame, extract images.
                        for camn, camn_pt_no in zip(this_camns,
                                                    this_camn_idxs):
                            # find 2D point corresponding to object
                            cam_id = camn2cam_id[camn]

                            cond = ((frame2d['camn'] == camn) &
                                    (frame2d['frame_pt_idx'] == camn_pt_no))
                            idxs = np.nonzero(cond)[0]
                            if len(idxs) == 0:
                                continue
                            assert len(idxs) == 1
                            ## if len(idxs)!=1:
                            ##     raise ValueError('expected one (and only one) frame, got %d'%len(idxs))
                            idx = idxs[0]

                            orig_data2d_rownum = frame2d_idxs[idx]
                            frame_timestamp = frame2d[idx]['timestamp']

                            row = frame2d[idx]
                            assert framenumber == row['frame']
                            if ((row['eccentricity'] <
                                 reconst.minimum_eccentricity)
                                    or (row['area'] <
                                        area_threshold_for_orientation)):
                                slopes_by_camn_by_frame[camn][
                                    framenumber] = np.nan
                                x0d_by_camn_by_frame[camn][
                                    framenumber] = np.nan
                                y0d_by_camn_by_frame[camn][
                                    framenumber] = np.nan
                                pt_idx_by_camn_by_frame[camn][
                                    framenumber] = camn_pt_no
                            else:
                                slopes_by_camn_by_frame[camn][
                                    framenumber] = row['slope']
                                x0d_by_camn_by_frame[camn][framenumber] = row[
                                    'x']
                                y0d_by_camn_by_frame[camn][framenumber] = row[
                                    'y']
                                pt_idx_by_camn_by_frame[camn][
                                    framenumber] = camn_pt_no

                    if start_idx is None:
                        warnings.warn("skipping obj_id %d: "
                                      "could not find valid start frame" %
                                      obj_id)
                        continue

                    obj_3d_rows = obj_3d_rows[start_idx:]

                    # now collect in a numpy array for all cam

                    assert int(min_frame) == min_frame
                    assert int(max_frame + 1) == max_frame + 1
                    frame_range = np.arange(int(min_frame), int(max_frame + 1))
                    if debug_level >= 1:
                        print 'frame range %d-%d' % (frame_range[0],
                                                     frame_range[-1])
                    camn_list = slopes_by_camn_by_frame.keys()
                    camn_list.sort()
                    cam_id_list = [camn2cam_id[camn] for camn in camn_list]
                    n_cams = len(camn_list)
                    n_frames = len(frame_range)

                    save_cols = {}
                    save_cols['frame'] = []
                    for camn in camn_list:
                        save_cols['dist%d' % camn] = []
                        save_cols['used%d' % camn] = []
                        save_cols['theta%d' % camn] = []

                    # NxM array with rows being frames and cols being cameras
                    slopes = np.ones((n_frames, n_cams), dtype=np.float)
                    x0ds = np.ones((n_frames, n_cams), dtype=np.float)
                    y0ds = np.ones((n_frames, n_cams), dtype=np.float)
                    for j, camn in enumerate(camn_list):

                        slopes_by_frame = slopes_by_camn_by_frame[camn]
                        x0d_by_frame = x0d_by_camn_by_frame[camn]
                        y0d_by_frame = y0d_by_camn_by_frame[camn]

                        for frame_idx, absolute_frame_number in enumerate(
                                frame_range):

                            slopes[frame_idx, j] = slopes_by_frame.get(
                                absolute_frame_number, np.nan)
                            x0ds[frame_idx,
                                 j] = x0d_by_frame.get(absolute_frame_number,
                                                       np.nan)
                            y0ds[frame_idx,
                                 j] = y0d_by_frame.get(absolute_frame_number,
                                                       np.nan)

                        if options.show:
                            frf = np.array(frame_range, dtype=np.float)
                            min_frame_range = min(np.min(frf), min_frame_range)
                            max_frame_range = max(np.max(frf), max_frame_range)

                            ax1.plot(frame_range,
                                     slope2modpi(slopes[:, j]),
                                     '.',
                                     label=camn2cam_id[camn])

                    if options.show:
                        ax1.legend()

                    if 1:
                        # estimate orientation of initial frame
                        row0 = obj_3d_rows[:
                                           1]  # take only first row but keep as 1d array
                        hzlines = np.array([
                            row0['hz_line0'], row0['hz_line1'],
                            row0['hz_line2'], row0['hz_line3'],
                            row0['hz_line4'], row0['hz_line5']
                        ]).T
                        directions = reconstruct.line_direction(hzlines)
                        q0 = PQmath.orientation_to_quat(directions[0])
                        assert not np.isnan(
                            q0.x), "cannot start with missing orientation"
                        w0 = 0, 0, 0  # no angular rate
                        init_x = np.array(
                            [w0[0], w0[1], w0[2], q0.x, q0.y, q0.z, q0.w])

                        Pminus = np.zeros((7, 7))

                        # angular rate part of state variance is .5
                        for i in range(0, 3):
                            Pminus[i, i] = .5

                        # quaternion part of state variance is 1
                        for i in range(3, 7):
                            Pminus[i, i] = 1

                    if 1:
                        # setup of noise estimates
                        Q = np.zeros((7, 7))

                        # angular rate part of state variance
                        for i in range(0, 3):
                            Q[i, i] = Q_scalar_rate

                        # quaternion part of state variance
                        for i in range(3, 7):
                            Q[i, i] = Q_scalar_quat

                    preA = np.eye(7)

                    ekf = kalman_ekf.EKF(init_x, Pminus)
                    previous_posterior_x = init_x
                    if options.show:
                        _save_plot_rows = []
                        _save_plot_rows_used = []
                    for frame_idx, absolute_frame_number in enumerate(
                            frame_range):
                        # Evaluate the Jacobian of the process update
                        # using the previous frame's posterior estimate.
                        # (This is not quite the same as this frame's
                        # prior estimate: the prior is what results
                        # _after_ applying the process update model,
                        # and building that update is what we do here.)

                        if options.show:
                            _save_plot_rows.append(np.nan * np.ones(
                                (n_cams, )))
                            _save_plot_rows_used.append(np.nan * np.ones(
                                (n_cams, )))

                        this_dx = eval_dAdt(previous_posterior_x)
                        A = preA + this_dx * dt
                        if debug_level >= 1:
                            print
                            print 'frame', absolute_frame_number, '-' * 40
                            print 'previous posterior', previous_posterior_x
                            if debug_level > 6:
                                print 'A'
                                print A

                        xhatminus, Pminus = ekf.step1__calculate_a_priori(A, Q)
                        if debug_level >= 1:
                            print 'new prior', xhatminus

                        # 1. Gate per-camera orientations.

                        this_frame_slopes = slopes[frame_idx, :]
                        this_frame_theta_measured = slope2modpi(
                            this_frame_slopes)
                        this_frame_x0d = x0ds[frame_idx, :]
                        this_frame_y0d = y0ds[frame_idx, :]
                        if debug_level >= 5:
                            print 'this_frame_slopes', this_frame_slopes

                        save_cols['frame'].append(absolute_frame_number)
                        for j, camn in enumerate(camn_list):
                            # default to no detection, change below
                            save_cols['dist%d' % camn].append(np.nan)
                            save_cols['used%d' % camn].append(0)
                            save_cols['theta%d' % camn].append(
                                this_frame_theta_measured[j])

                        all_data_this_frame_missing = False
                        gate_vector = None

                        y = []  # observation (per camera)
                        hx = []  # expected observation (per camera)
                        C = []  # linearized observation model (per camera)
                        N_obs_this_frame = 0
                        cams_without_data = np.isnan(this_frame_slopes)
                        if np.all(cams_without_data):
                            all_data_this_frame_missing = True

                        smoothed_pos_idxs = smoothed_frame_qfi.get_frame_idxs(
                            absolute_frame_number)
                        if len(smoothed_pos_idxs) == 0:
                            all_data_this_frame_missing = True
                            smoothed_pos_idx = None
                            smooth_row = None
                            center_position = None
                        else:
                            try:
                                assert len(smoothed_pos_idxs) == 1
                            except:
                                print 'obj_id', obj_id
                                print 'absolute_frame_number', absolute_frame_number
                                if len(frame_range):
                                    print 'frame_range[0],frame_rang[-1]', frame_range[
                                        0], frame_range[-1]
                                else:
                                    print 'no frame range'
                                print 'len(smoothed_pos_idxs)', len(
                                    smoothed_pos_idxs)
                                raise
                            smoothed_pos_idx = smoothed_pos_idxs[0]
                            smooth_row = smoothed_3d_rows[smoothed_pos_idx]
                            assert smooth_row['frame'] == absolute_frame_number
                            center_position = np.array(
                                (smooth_row['x'], smooth_row['y'],
                                 smooth_row['z']))
                            if debug_level >= 2:
                                print 'center_position', center_position

                        if not all_data_this_frame_missing:
                            if expected_orientation_method == 'trust_prior':
                                state_for_phi = xhatminus  # use a priori
                            elif expected_orientation_method == 'SVD_line_fits':
                                # construct matrix of planes
                                P = []
                                for camn_idx in range(n_cams):
                                    this_x0d = this_frame_x0d[camn_idx]
                                    this_y0d = this_frame_y0d[camn_idx]
                                    slope = this_frame_slopes[camn_idx]
                                    plane, ray = reconst.get_3D_plane_and_ray(
                                        cam_id, this_x0d, this_y0d, slope)
                                    if np.isnan(plane[0]):
                                        continue
                                    P.append(plane)
                                if len(P) < 2:
                                    # not enough data to do SVD... fallback to prior
                                    state_for_phi = xhatminus  # use a priori
                                else:
                                    Lco = reconstruct.intersect_planes_to_find_line(
                                        P)
                                    q = PQmath.pluecker_to_quat(Lco)
                                    state_for_phi = cgtypes_quat2statespace(q)

                            cams_with_data = ~cams_without_data
                            possible_cam_idxs = np.nonzero(cams_with_data)[0]
                            if debug_level >= 6:
                                print 'possible_cam_idxs', possible_cam_idxs
                            gate_vector = np.zeros((n_cams, ), dtype=np.bool)
                            ## flip_vector = np.zeros( (n_cams,), dtype=np.bool)
                            for camn_idx in possible_cam_idxs:
                                cam_id = cam_id_list[camn_idx]
                                camn = camn_list[camn_idx]

                                # This ignores distortion. To incorporate
                                # distortion, this would require
                                # appropriate scaling of orientation
                                # vector, which would require knowing
                                # target's size. In which case we should
                                # track head and tail separately and not
                                # use this whole quaternion mess.

                                ## theta_measured=slope2modpi(
                                ##     this_frame_slopes[camn_idx])
                                theta_measured = this_frame_theta_measured[
                                    camn_idx]
                                if debug_level >= 6:
                                    print 'cam_id %s, camn %d' % (cam_id, camn)
                                if debug_level >= 3:
                                    a = reconst.find2d(cam_id, center_position)
                                    other_position = get_point_on_line(
                                        xhatminus, center_position)
                                    b = reconst.find2d(cam_id, other_position)
                                    theta_expected = find_theta_mod_pi_between_points(
                                        a, b)
                                    print('  theta_expected,theta_measured',
                                          theta_expected * R2D,
                                          theta_measured * R2D)

                                P = reconst.get_pmat(cam_id)
                                if 0:
                                    args_x = (P[0, 0], P[0, 1], P[0, 2],
                                              P[0, 3], P[1, 0], P[1, 1], P[1,
                                                                           2],
                                              P[1, 3], P[2, 0], P[2, 1], P[2,
                                                                           2],
                                              P[2, 3], center_position[0],
                                              center_position[1],
                                              center_position[2], xhatminus)
                                    this_y = theta_measured
                                    this_hx = eval_G(*args_x)
                                    this_C = eval_linG(*args_x)
                                else:
                                    args_x_xm = (P[0, 0], P[0, 1], P[0, 2],
                                                 P[0, 3], P[1, 0], P[1,
                                                                     1], P[1,
                                                                           2],
                                                 P[1, 3], P[2, 0], P[2,
                                                                     1], P[2,
                                                                           2],
                                                 P[2, 3], center_position[0],
                                                 center_position[1],
                                                 center_position[2], xhatminus,
                                                 state_for_phi)
                                    this_phi = eval_phi(*args_x_xm)
                                    this_y = angle_diff(theta_measured,
                                                        this_phi,
                                                        mod_pi=True)
                                    this_hx = eval_H(*args_x_xm)
                                    this_C = eval_linH(*args_x_xm)
                                    if debug_level >= 3:
                                        print('  this_phi,this_y',
                                              this_phi * R2D, this_y * R2D)

                                save_cols['dist%d' % camn][-1] = this_y  # save

                                # gate
                                if abs(this_y) < gate_angle_threshold_radians:
                                    save_cols['used%d' % camn][-1] = 1
                                    gate_vector[camn_idx] = 1
                                    if debug_level >= 3:
                                        print '    good'
                                    if options.show:
                                        _save_plot_rows_used[-1][
                                            camn_idx] = this_y
                                    y.append(this_y)
                                    hx.append(this_hx)
                                    C.append(this_C)
                                    N_obs_this_frame += 1

                                    # Save which camn and camn_pt_no was used.
                                    if absolute_frame_number not in used_camn_dict:
                                        used_camn_dict[
                                            absolute_frame_number] = []
                                    camn_pt_no = (pt_idx_by_camn_by_frame[camn]
                                                  [absolute_frame_number])
                                    used_camn_dict[
                                        absolute_frame_number].append(
                                            (camn, camn_pt_no))
                                else:
                                    if options.show:
                                        _save_plot_rows[-1][camn_idx] = this_y
                                    if debug_level >= 6:
                                        print '    bad'
                            if debug_level >= 1:
                                print 'gate_vector', gate_vector
                                #print 'flip_vector',flip_vector
                            all_data_this_frame_missing = not bool(
                                np.sum(gate_vector))

                        # 3. Construct observations model using all
                        # gated-in camera orientations.

                        if all_data_this_frame_missing:
                            C = None
                            R = None
                            hx = None
                        else:
                            C = np.array(C)
                            R = R_scalar * np.eye(N_obs_this_frame)
                            hx = np.array(hx)
                            if 0:
                                # crazy observation error scaling
                                for i in range(N_obs_this_frame):
                                    beyond = abs(y[i]) - 10 * D2R
                                    beyond = max(0, beyond)  # clip at zero
                                    R[i:i] = R_scalar * (1 + 10 * beyond)
                            if debug_level >= 6:
                                print 'full values'
                                print 'C', C
                                print 'hx', hx
                                print 'y', y
                                print 'R', R

                        if debug_level >= 1:
                            print 'all_data_this_frame_missing', all_data_this_frame_missing
                        xhat, P = ekf.step2__calculate_a_posteriori(
                            xhatminus,
                            Pminus,
                            y=y,
                            hx=hx,
                            C=C,
                            R=R,
                            missing_data=all_data_this_frame_missing)
                        if debug_level >= 1:
                            print 'xhat', xhat
                        previous_posterior_x = xhat
                        if center_position is not None:
                            # save
                            output_row_frame_cond = all_kobs_frames == absolute_frame_number
                            output_row_cond = output_row_frame_cond & output_row_obj_id_cond
                            output_idxs = np.nonzero(output_row_cond)[0]
                            if len(output_idxs) == 0:
                                pass
                            else:
                                assert len(output_idxs) == 1
                                idx = output_idxs[0]
                                hz = state_to_hzline(xhat, center_position)
                                for row in dest_table.iterrows(start=idx,
                                                               stop=(idx + 1)):
                                    for i in range(6):
                                        row['hz_line%d' % i] = hz[i]
                                    row.update()
                        ## xhat_results[ obj_id ][absolute_frame_number ] = (
                        ##     xhat,center_position)
                        if options.show:
                            all_xhats.append(xhat)
                            all_ori.append(state_to_ori(xhat))

                    # save to H5 file
                    names = [colname for colname in save_cols]
                    names.sort()
                    arrays = []
                    for name in names:
                        if name == 'frame':
                            dtype = np.int64
                        elif name.startswith('dist'):
                            dtype = np.float32
                        elif name.startswith('used'):
                            dtype = np.bool
                        elif name.startswith('theta'):
                            dtype = np.float32
                        else:
                            raise NameError('unknown name %s' % name)
                        arr = np.array(save_cols[name], dtype=dtype)
                        arrays.append(arr)
                    save_recarray = np.rec.fromarrays(arrays, names=names)
                    h5group = core_analysis.get_group_for_obj(obj_id,
                                                              output_h5,
                                                              writeable=True)
                    output_h5.create_table(h5group,
                                           'obj%d' % obj_id,
                                           save_recarray,
                                           filters=tables.Filters(
                                               1, complib='lzo'))

                    if options.show:
                        all_xhats = np.array(all_xhats)
                        all_ori = np.array(all_ori)
                        _save_plot_rows = np.array(_save_plot_rows)
                        _save_plot_rows_used = np.array(_save_plot_rows_used)

                        ax2.plot(frame_range, all_xhats[:, 0], '.', label='p')
                        ax2.plot(frame_range, all_xhats[:, 1], '.', label='q')
                        ax2.plot(frame_range, all_xhats[:, 2], '.', label='r')
                        ax2.legend()

                        ax3.plot(frame_range, all_xhats[:, 3], '.', label='a')
                        ax3.plot(frame_range, all_xhats[:, 4], '.', label='b')
                        ax3.plot(frame_range, all_xhats[:, 5], '.', label='c')
                        ax3.plot(frame_range, all_xhats[:, 6], '.', label='d')
                        ax3.legend()

                        ax4.plot(frame_range, all_ori[:, 0], '.', label='x')
                        ax4.plot(frame_range, all_ori[:, 1], '.', label='y')
                        ax4.plot(frame_range, all_ori[:, 2], '.', label='z')
                        ax4.legend()

                        colors = []
                        for i in range(n_cams):
                            line, = ax5.plot(frame_range,
                                             _save_plot_rows_used[:, i] * R2D,
                                             'o',
                                             label=cam_id_list[i])
                            colors.append(line.get_color())
                        for i in range(n_cams):
                            # loop again to get normal MPL color cycling
                            ax5.plot(frame_range,
                                     _save_plot_rows[:, i] * R2D,
                                     'o',
                                     mec=colors[i],
                                     ms=1.0)
                        ax5.set_ylabel('observation (deg)')
                        ax5.legend()
示例#7
0
def plot_timeseries(subplot=None, options=None):
    kalman_filename = options.kalman_filename

    if not hasattr(options, 'frames'):
        options.frames = False

    if not hasattr(options, 'show_landing'):
        options.show_landing = False

    if not hasattr(options, 'unicolor'):
        options.unicolor = False

    if not hasattr(options, 'show_obj_id'):
        options.show_obj_id = True

    if not hasattr(options, 'show_track_ends'):
        options.show_track_ends = False

    start = options.start
    stop = options.stop
    obj_only = options.obj_only
    fps = options.fps
    dynamic_model = options.dynamic_model
    use_kalman_smoothing = options.use_kalman_smoothing

    if not use_kalman_smoothing:
        if (dynamic_model is not None):
            print >> sys.stderr, (
                'WARNING: disabling Kalman smoothing '
                '(--disable-kalman-smoothing) is incompatable '
                'with setting dynamic model options (--dynamic-model)')

    ca = core_analysis.get_global_CachingAnalyzer()

    if kalman_filename is None:
        raise ValueError('No kalman_filename given. Nothing to do.')

    m = hashlib.md5()
    m.update(open(kalman_filename, mode='rb').read())
    actual_md5 = m.hexdigest()
    (obj_ids, use_obj_ids, is_mat_file, data_file,
     extra) = ca.initial_file_load(kalman_filename)
    print 'opened kalman file %s %s, %d obj_ids' % (
        kalman_filename, actual_md5, len(use_obj_ids))

    if 'frames' in extra:
        if (start is not None) or (stop is not None):
            valid_frames = np.ones((len(extra['frames']), ), dtype=np.bool)
            if start is not None:
                valid_frames &= extra['frames'] >= start
            if stop is not None:
                valid_frames &= extra['frames'] <= stop
            this_use_obj_ids = np.unique(obj_ids[valid_frames])
            use_obj_ids = list(set(use_obj_ids).intersection(this_use_obj_ids))

    include_obj_ids = None
    exclude_obj_ids = None
    do_fuse = False
    if options.stim_xml:
        file_timestamp = data_file.filename[4:19]
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
            timestamp_string=file_timestamp)
        walking_start_stops = fanout.get_walking_start_stops_for_timestamp(
            timestamp_string=file_timestamp)
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
        if options.fuse:
            do_fuse = True
    else:
        walking_start_stops = []

    if dynamic_model is None:
        dynamic_model = extra['dynamic_model_name']
        print 'detected file loaded with dynamic model "%s"' % dynamic_model
        if dynamic_model.startswith('EKF '):
            dynamic_model = dynamic_model[4:]
        print '  for smoothing, will use dynamic model "%s"' % dynamic_model

    if not is_mat_file:
        mat_data = None

        if fps is None:
            fps = result_utils.get_fps(data_file, fail_on_error=False)

        if fps is None:
            fps = 100.0
            import warnings
            warnings.warn('Setting fps to default value of %f' % fps)

        tz = result_utils.get_tz(data_file)

    dt = 1.0 / fps

    all_vels = []

    if obj_only is not None:
        use_obj_ids = [i for i in use_obj_ids if i in obj_only]

    allX = {}
    frame0 = None

    line2obj_id = {}
    Xz_all = []

    fuse_did_once = False

    if not hasattr(options, 'timestamp_file'):
        options.timestamp_file = None

    if not hasattr(options, 'ori_qual'):
        options.ori_qual = None

    if options.timestamp_file is not None:
        h5 = tables.open_file(options.timestamp_file, mode='r')
        print 'reading timestamps and frames'
        table_data2d_frames = h5.root.data2d_distorted.read(field='frame')
        table_data2d_timestamps = h5.root.data2d_distorted.read(
            field='timestamp')
        print 'done'
        h5.close()
        table_data2d_frames_find = utils.FastFinder(table_data2d_frames)

    if len(use_obj_ids) == 0:
        print 'No obj_ids to plot, quitting'
        sys.exit(0)

    time0 = 0.0  # set default value

    for obj_id in use_obj_ids:
        if not do_fuse:
            try:
                kalman_rows = ca.load_data(
                    obj_id,
                    data_file,
                    use_kalman_smoothing=use_kalman_smoothing,
                    dynamic_model_name=dynamic_model,
                    return_smoothed_directions=options.smooth_orientations,
                    frames_per_second=fps,
                    up_dir=options.up_dir,
                    min_ori_quality_required=options.ori_qual,
                )
            except core_analysis.ObjectIDDataError:
                continue
            #kobs_rows = ca.load_dynamics_free_MLE_position( obj_id, data_file )
        else:
            if options.show_3d_orientations:
                raise NotImplementedError('orientation data is not supported '
                                          'when fusing obj_ids')
            if fuse_did_once:
                break
            fuse_did_once = True
            kalman_rows = flydra_analysis.a2.flypos.fuse_obj_ids(
                use_obj_ids,
                data_file,
                dynamic_model_name=dynamic_model,
                frames_per_second=fps)
        frame = kalman_rows['frame']

        if (start is not None) or (stop is not None):
            valid_cond = numpy.ones(frame.shape, dtype=numpy.bool)

            if start is not None:
                valid_cond = valid_cond & (frame >= start)

            if stop is not None:
                valid_cond = valid_cond & (frame <= stop)

            kalman_rows = kalman_rows[valid_cond]
            if not len(kalman_rows):
                continue

        walking_and_flying_kalman_rows = kalman_rows  # preserve original data

        for flystate in ['flying', 'walking']:
            frame = walking_and_flying_kalman_rows['frame']  # restore
            if flystate == 'flying':
                # assume flying unless we're told it's walking
                state_cond = numpy.ones(frame.shape, dtype=numpy.bool)
            else:
                state_cond = numpy.zeros(frame.shape, dtype=numpy.bool)

            if len(walking_start_stops):
                for walkstart, walkstop in walking_start_stops:
                    frame = walking_and_flying_kalman_rows['frame']  # restore

                    # handle each bout of walking
                    walking_bout = numpy.ones(frame.shape, dtype=numpy.bool)
                    if walkstart is not None:
                        walking_bout &= (frame >= walkstart)
                    if walkstop is not None:
                        walking_bout &= (frame <= walkstop)
                    if flystate == 'flying':
                        state_cond &= ~walking_bout
                    else:
                        state_cond |= walking_bout

                kalman_rows = np.take(walking_and_flying_kalman_rows,
                                      np.nonzero(state_cond)[0])
                assert len(kalman_rows) == np.sum(state_cond)
                frame = kalman_rows['frame']

            if frame0 is None:
                frame0 = int(frame[0])

            time0 = 0.0
            if options.timestamp_file is not None:
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(frame0)
                if len(frame_idxs):
                    time0 = table_data2d_timestamps[frame_idxs[0]]
                else:
                    raise ValueError(
                        'could not fine frame %d in timestamp file' % frame0)

            Xx = kalman_rows['x']
            Xy = kalman_rows['y']
            Xz = kalman_rows['z']

            Dx = Dy = Dz = None
            if options.smooth_orientations:
                Dx = kalman_rows['dir_x']
                Dy = kalman_rows['dir_y']
                Dz = kalman_rows['dir_z']
            elif 'rawdir_x' in kalman_rows.dtype.fields:
                Dx = kalman_rows['rawdir_x']
                Dy = kalman_rows['rawdir_y']
                Dz = kalman_rows['rawdir_z']

            if not options.frames:
                f2t = Frames2Time(frame0, fps, time0)
            else:

                def identity(x):
                    return x

                f2t = identity

            kws = {
                'linewidth': 2,
                'picker': 5,
            }
            if options.unicolor:
                kws['color'] = 'k'

            line = None

            if 'frame' in subplot:
                subplot['frame'].plot(f2t(frame), frame)

            if 'P55' in subplot:
                subplot['P55'].plot(f2t(frame), kalman_rows['P55'])

            if 'x' in subplot:
                line, = subplot['x'].plot(f2t(frame),
                                          Xx,
                                          label='obj %d (%s)' %
                                          (obj_id, flystate),
                                          **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if 'y' in subplot:
                line, = subplot['y'].plot(f2t(frame),
                                          Xy,
                                          label='obj %d (%s)' %
                                          (obj_id, flystate),
                                          **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if 'z' in subplot:
                frame_data = numpy.ma.getdata(
                    frame)  # works if frame is masked or not

                # plot landing time
                if options.show_landing:
                    if flystate == 'flying':  # only do this once
                        for walkstart, walkstop in walking_start_stops:
                            if walkstart in frame_data:
                                landing_dix = numpy.nonzero(
                                    frame_data == walkstart)[0][0]
                                subplot['z'].plot([f2t(walkstart)],
                                                  [Xz.data[landing_dix]],
                                                  'rD',
                                                  ms=10,
                                                  label='landing')

                if options.show_track_ends:
                    if flystate == 'flying':  # only do this once
                        subplot['z'].plot(f2t([frame_data[0], frame_data[-1]]),
                                          [
                                              numpy.ma.getdata(Xz)[0],
                                              numpy.ma.getdata(Xz)[-1]
                                          ],
                                          'cd',
                                          ms=6,
                                          label='track end')

                line, = subplot['z'].plot(f2t(frame),
                                          Xz,
                                          label='obj %d (%s)' %
                                          (obj_id, flystate),
                                          **kws)
                kws['color'] = line.get_color()
                line2obj_id[line] = obj_id

                if flystate == 'flying':
                    # only do this once
                    if options.show_obj_id:
                        subplot['z'].text(f2t(frame_data[0]),
                                          numpy.ma.getdata(Xz)[0],
                                          '%d' % (obj_id, ))
                        line2obj_id[line] = obj_id

            if flystate == 'flying':
                Xz_all.append(np.ma.array(Xz).compressed())
                #bins = np.linspace(0,.8,30)
                #print 'Xz.shape',Xz.shape
                #pylab.hist(Xz, bins=bins)

            for (dir_var, Dd) in [('dx', Dx), ('dy', Dy), ('dz', Dz)]:
                if dir_var in subplot:
                    line, = subplot[dir_var].plot(f2t(frame),
                                                  Dd,
                                                  label='obj %d (%s)' %
                                                  (obj_id, flystate),
                                                  **kws)
                    line2obj_id[line] = obj_id
                    kws['color'] = line.get_color()

            if numpy.__version__ >= '1.2.0':
                X = numpy.ma.array((Xx, Xy, Xz))
            else:
                # See http://scipy.org/scipy/numpy/ticket/820
                X = numpy.ma.vstack(
                    (Xx[numpy.newaxis, :], Xy[numpy.newaxis, :],
                     Xz[numpy.newaxis, :]))

            dist_central_diff = (X[:, 2:] - X[:, :-2])
            vel_central_diff = dist_central_diff / (2 * dt)

            vel2mag = numpy.ma.sqrt(numpy.ma.sum(vel_central_diff**2, axis=0))
            xy_vel2mag = numpy.ma.sqrt(
                numpy.ma.sum(vel_central_diff[:2, :]**2, axis=0))

            frames2 = frame[1:-1]

            accel4mag = (vel2mag[2:] - vel2mag[:-2]) / (2 * dt)
            frames4 = frames2[1:-1]

            if 'vel' in subplot:
                line, = subplot['vel'].plot(f2t(frames2),
                                            vel2mag,
                                            label='obj %d (%s)' %
                                            (obj_id, flystate),
                                            **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if 'xy_vel' in subplot:
                line, = subplot['xy_vel'].plot(f2t(frames2),
                                               xy_vel2mag,
                                               label='obj %d (%s)' %
                                               (obj_id, flystate),
                                               **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if len(accel4mag.compressed()) and 'accel' in subplot:
                line, = subplot['accel'].plot(f2t(frames4),
                                              accel4mag,
                                              label='obj %d (%s)' %
                                              (obj_id, flystate),
                                              **kws)
                line2obj_id[line] = obj_id
                kws['color'] = line.get_color()

            if flystate == 'flying':
                valid_vel2mag = vel2mag.compressed()
                all_vels.append(valid_vel2mag)
    if len(all_vels):
        all_vels = numpy.hstack(all_vels)
    else:
        all_vels = numpy.array([], dtype=float)

    if 1:
        cond = all_vels < 2.0
        if numpy.ma.sum(cond) != len(all_vels):
            all_vels = all_vels[cond]
            import warnings
            warnings.warn('clipping all velocities > 2.0 m/s')

    if not options.frames:
        xlabel = 'time (s)'
    else:
        xlabel = 'frame'

    for ax in subplot.itervalues():
        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%s"))

    fixup_ax = FixupAxesWithTimeZone(tz).fixup_ax

    if 'frame' in subplot:
        if time0 != 0.0:
            fixup_ax(subplot['frame'])
        else:
            subplot['frame'].set_xlabel(xlabel)

    if 'x' in subplot:
        subplot['x'].set_ylim([-1, 1])
        subplot['x'].set_ylabel(r'x (m)')
        if time0 != 0.0:
            fixup_ax(subplot['x'])
        else:
            subplot['x'].set_xlabel(xlabel)

    if 'y' in subplot:
        subplot['y'].set_ylim([-0.5, 1.5])
        subplot['y'].set_ylabel(r'y (m)')
        if time0 != 0.0:
            fixup_ax(subplot['y'])
        else:
            subplot['y'].set_xlabel(xlabel)

    max_z = None
    if options.stim_xml:
        file_timestamp = options.kalman_filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp)
        post_max_zs = []
        for post_num, post in enumerate(stim_xml.iterate_posts()):
            post_max_zs.append(max(post['verts'][0][2],
                                   post['verts'][1][2]))  # max post height
        if len(post_max_zs):
            max_z = min(post_max_zs)  # take shortest of posts

    if 'z' in subplot:
        subplot['z'].set_ylim([0, 1])
        subplot['z'].set_ylabel(r'z (m)')
        if max_z is not None:
            subplot['z'].axhline(max_z, color='m')
        if time0 != 0.0:
            fixup_ax(subplot['z'])
        else:
            subplot['z'].set_xlabel(xlabel)

    for dir_var in ['dx', 'dy', 'dz']:
        if dir_var in subplot:
            subplot[dir_var].set_ylabel(dir_var)
            if time0 != 0.0:
                fixup_ax(subplot[dir_var])
            else:
                subplot[dir_var].set_xlabel(xlabel)

    if 'z_hist' in subplot:  # and flystate=='flying':
        Xz_all = np.hstack(Xz_all)
        bins = np.linspace(0, .8, 30)
        ax = subplot['z_hist']
        ax.hist(Xz_all, bins=bins, orientation='horizontal')
        ax.set_xticks([])
        ax.set_yticks([])
        xlim = tuple(ax.get_xlim())  # matplotlib 0.98.3 returned np.array view
        ax.set_xlim((xlim[1], xlim[0]))
        ax.axhline(max_z, color='m')

    if 'vel' in subplot:
        subplot['vel'].set_ylim([0, 2])
        subplot['vel'].set_ylabel(r'vel (m/s)')
        subplot['vel'].set_xlabel(xlabel)
        if time0 != 0.0:
            fixup_ax(subplot['vel'])
        else:
            subplot['vel'].set_xlabel(xlabel)

    if 'xy_vel' in subplot:
        #subplot['xy_vel'].set_ylim([0,2])
        subplot['xy_vel'].set_ylabel(r'horiz vel (m/s)')
        subplot['xy_vel'].set_xlabel(xlabel)
        if time0 != 0.0:
            fixup_ax(subplot['xy_vel'])
        else:
            subplot['xy_vel'].set_xlabel(xlabel)

    if 'accel' in subplot:
        subplot['accel'].set_ylabel(r'acceleration (m/(s^2))')
        subplot['accel'].set_xlabel(xlabel)
        if time0 != 0.0:
            fixup_ax(subplot['accel'])
        else:
            subplot['accel'].set_xlabel(xlabel)

    if 'vel_hist' in subplot:
        ax = subplot['vel_hist']
        bins = numpy.linspace(0, 2, 50)
        ax.set_title('excluding walking')
        pdf, bins, patches = ax.hist(all_vels, bins=bins, normed=True)
        ax.set_xlim(0, 2)
        ax.set_ylabel('probability density')
        ax.set_xlabel('velocity (m/s)')

    return line2obj_id
示例#8
0
def doit(
    h5_filename=None,
    output_h5_filename=None,
    ufmf_filenames=None,
    kalman_filename=None,
    start=None,
    stop=None,
    view=None,
    erode=0,
    save_images=False,
    save_image_dir=None,
    intermediate_thresh_frac=None,
    final_thresh=None,
    stack_N_images=None,
    stack_N_images_min=None,
    old_sync_timestamp_source=False,
    do_rts_smoothing=True,
):
    """

    Copy all data in .h5 file (specified by h5_filename) to a new .h5
    file in which orientations are set based on image analysis of
    .ufmf files. Tracking data to associate 2D points from subsequent
    frames is read from the .h5 kalman file specified by
    kalman_filename.

    """
    if view is None:
        view = ["orig" for f in ufmf_filenames]
    else:
        assert len(view) == len(ufmf_filenames)

    if intermediate_thresh_frac is None or final_thresh is None:
        raise ValueError("intermediate_thresh_frac and final_thresh must be "
                         "set")

    filename2view = dict(zip(ufmf_filenames, view))

    ca = core_analysis.get_global_CachingAnalyzer()
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        kalman_filename)
    try:
        ML_estimates_2d_idxs = data_file.root.ML_estimates_2d_idxs[:]
    except tables.exceptions.NoSuchNodeError as err1:
        # backwards compatibility
        try:
            ML_estimates_2d_idxs = data_file.root.kalman_observations_2d_idxs[:]
        except tables.exceptions.NoSuchNodeError as err2:
            raise err1

    if os.path.exists(output_h5_filename):
        raise RuntimeError("will not overwrite old file '%s'" %
                           output_h5_filename)
    with open_file_safe(output_h5_filename, delete_on_error=True,
                        mode="w") as output_h5:
        if save_image_dir is not None:
            if not os.path.exists(save_image_dir):
                os.mkdir(save_image_dir)

        with open_file_safe(h5_filename, mode="r") as h5:

            fps = result_utils.get_fps(h5, fail_on_error=True)

            for input_node in h5.root._f_iter_nodes():
                # copy everything from source to dest
                input_node._f_copy(output_h5.root, recursive=True)
            print("done copying")

            # Clear values in destination table that we may overwrite.
            dest_table = output_h5.root.data2d_distorted
            for colname in [
                    "x",
                    "y",
                    "area",
                    "slope",
                    "eccentricity",
                    "cur_val",
                    "mean_val",
                    "sumsqf_val",
            ]:
                if colname == "cur_val":
                    fill_value = 0
                else:
                    fill_value = np.nan
                clear_col(dest_table, colname, fill_value=fill_value)
            dest_table.flush()
            print("done clearing")

            camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)

            cam_id2fmfs = collections.defaultdict(list)
            cam_id2view = {}
            for ufmf_filename in ufmf_filenames:
                fmf = ufmf.FlyMovieEmulator(
                    ufmf_filename,
                    # darken=-50,
                    allow_no_such_frame_errors=True,
                )
                timestamps = fmf.get_all_timestamps()

                cam_id = get_cam_id_from_filename(fmf.filename,
                                                  cam_id2camns.keys())
                cam_id2fmfs[cam_id].append(
                    (fmf, result_utils.Quick1DIndexer(timestamps)))

                cam_id2view[cam_id] = filename2view[fmf.filename]

            # associate framenumbers with timestamps using 2d .h5 file
            data2d = h5.root.data2d_distorted[:]  # load to RAM
            data2d_idxs = np.arange(len(data2d))
            h5_framenumbers = data2d["frame"]
            h5_frame_qfi = result_utils.QuickFrameIndexer(h5_framenumbers)

            fpc = realtime_image_analysis.FitParamsClass(
            )  # allocate FitParamsClass

            for obj_id_enum, obj_id in enumerate(use_obj_ids):
                print("object %d of %d" % (obj_id_enum, len(use_obj_ids)))

                # get all images for this camera and this obj_id

                obj_3d_rows = ca.load_dynamics_free_MLE_position(
                    obj_id, data_file)

                this_obj_framenumbers = collections.defaultdict(list)
                if save_images:
                    this_obj_raw_images = collections.defaultdict(list)
                    this_obj_mean_images = collections.defaultdict(list)
                this_obj_absdiff_images = collections.defaultdict(list)
                this_obj_morphed_images = collections.defaultdict(list)
                this_obj_morph_failures = collections.defaultdict(list)
                this_obj_im_coords = collections.defaultdict(list)
                this_obj_com_coords = collections.defaultdict(list)
                this_obj_camn_pt_no = collections.defaultdict(list)

                for this_3d_row in obj_3d_rows:
                    # iterate over each sample in the current camera
                    framenumber = this_3d_row["frame"]
                    if start is not None:
                        if not framenumber >= start:
                            continue
                    if stop is not None:
                        if not framenumber <= stop:
                            continue
                    h5_2d_row_idxs = h5_frame_qfi.get_frame_idxs(framenumber)

                    frame2d = data2d[h5_2d_row_idxs]
                    frame2d_idxs = data2d_idxs[h5_2d_row_idxs]

                    obs_2d_idx = this_3d_row["obs_2d_idx"]
                    kobs_2d_data = ML_estimates_2d_idxs[int(obs_2d_idx)]

                    # Parse VLArray.
                    this_camns = kobs_2d_data[0::2]
                    this_camn_idxs = kobs_2d_data[1::2]

                    # Now, for each camera viewing this object at this
                    # frame, extract images.
                    for camn, camn_pt_no in zip(this_camns, this_camn_idxs):

                        # find 2D point corresponding to object
                        cam_id = camn2cam_id[camn]

                        movie_tups_for_this_camn = cam_id2fmfs[cam_id]
                        cond = (frame2d["camn"] == camn) & (
                            frame2d["frame_pt_idx"] == camn_pt_no)
                        idxs = np.nonzero(cond)[0]
                        assert len(idxs) == 1
                        idx = idxs[0]

                        orig_data2d_rownum = frame2d_idxs[idx]

                        if not old_sync_timestamp_source:
                            # Change the next line to 'timestamp' for old
                            # data (before May/June 2009 -- the switch to
                            # fview_ext_trig)
                            frame_timestamp = frame2d[idx][
                                "cam_received_timestamp"]
                        else:
                            # previous version
                            frame_timestamp = frame2d[idx]["timestamp"]
                        found = None
                        for fmf, fmf_timestamp_qi in movie_tups_for_this_camn:
                            fmf_fnos = fmf_timestamp_qi.get_idxs(
                                frame_timestamp)
                            if not len(fmf_fnos):
                                continue
                            assert len(fmf_fnos) == 1

                            # should only be one .ufmf with this frame and cam_id
                            assert found is None

                            fmf_fno = fmf_fnos[0]
                            found = (fmf, fmf_fno)
                        if found is None:
                            print(
                                "no image data for frame timestamp %s cam_id %s"
                                % (repr(frame_timestamp), cam_id))
                            continue
                        fmf, fmf_fno = found
                        image, fmf_timestamp = fmf.get_frame(fmf_fno)
                        mean_image = fmf.get_mean_for_timestamp(fmf_timestamp)
                        coding = fmf.get_format()
                        if imops.is_coding_color(coding):
                            image = imops.to_rgb8(coding, image)
                            mean_image = imops.to_rgb8(coding, mean_image)
                        else:
                            image = imops.to_mono8(coding, image)
                            mean_image = imops.to_mono8(coding, mean_image)

                        xy = (
                            int(round(frame2d[idx]["x"])),
                            int(round(frame2d[idx]["y"])),
                        )
                        maxsize = (fmf.get_width(), fmf.get_height())

                        # Accumulate cropped images. Note that the region
                        # of the full image that the cropped image
                        # occupies changes over time as the tracked object
                        # moves. Thus, averaging these cropped-and-shifted
                        # images is not the same as simply averaging the
                        # full frame.

                        roiradius = 25
                        warnings.warn(
                            "roiradius hard-coded to %d: could be set "
                            "from 3D tracking" % roiradius)
                        tmp = clip_and_math(image, mean_image, xy, roiradius,
                                            maxsize)
                        im_coords, raw_im, mean_im, absdiff_im = tmp

                        max_absdiff_im = absdiff_im.max()
                        intermediate_thresh = intermediate_thresh_frac * max_absdiff_im
                        absdiff_im[absdiff_im <= intermediate_thresh] = 0

                        if erode > 0:
                            morphed_im = scipy.ndimage.grey_erosion(absdiff_im,
                                                                    size=erode)
                            ## morphed_im = scipy.ndimage.binary_erosion(absdiff_im>1).astype(np.float32)*255.0
                        else:
                            morphed_im = absdiff_im

                        y0_roi, x0_roi = scipy.ndimage.center_of_mass(
                            morphed_im)
                        x0 = im_coords[0] + x0_roi
                        y0 = im_coords[1] + y0_roi

                        if 1:
                            morphed_im_binary = morphed_im > 0
                            labels, n_labels = scipy.ndimage.label(
                                morphed_im_binary)
                            morph_fail_because_multiple_blobs = False

                            if n_labels > 1:
                                x0, y0 = np.nan, np.nan
                                # More than one blob -- don't allow image.
                                if 1:
                                    # for min flattening
                                    morphed_im = np.empty(morphed_im.shape,
                                                          dtype=np.uint8)
                                    morphed_im.fill(255)
                                    morph_fail_because_multiple_blobs = True
                                else:
                                    # for mean flattening
                                    morphed_im = np.zeros_like(morphed_im)
                                    morph_fail_because_multiple_blobs = True

                        this_obj_framenumbers[camn].append(framenumber)
                        if save_images:
                            this_obj_raw_images[camn].append(
                                (raw_im, im_coords))
                            this_obj_mean_images[camn].append(mean_im)
                        this_obj_absdiff_images[camn].append(absdiff_im)
                        this_obj_morphed_images[camn].append(morphed_im)
                        this_obj_morph_failures[camn].append(
                            morph_fail_because_multiple_blobs)
                        this_obj_im_coords[camn].append(im_coords)
                        this_obj_com_coords[camn].append((x0, y0))
                        this_obj_camn_pt_no[camn].append(orig_data2d_rownum)
                        if 0:
                            fname = "obj%05d_%s_frame%07d_pt%02d.png" % (
                                obj_id,
                                cam_id,
                                framenumber,
                                camn_pt_no,
                            )
                            plot_image_subregion(
                                raw_im,
                                mean_im,
                                absdiff_im,
                                roiradius,
                                fname,
                                im_coords,
                                view=filename2view[fmf.filename],
                            )

                # Now, all the frames from all cameras for this obj_id
                # have been gathered. Do a camera-by-camera analysis.
                for camn in this_obj_absdiff_images:
                    cam_id = camn2cam_id[camn]
                    image_framenumbers = np.array(this_obj_framenumbers[camn])
                    if save_images:
                        raw_images = this_obj_raw_images[camn]
                        mean_images = this_obj_mean_images[camn]
                    absdiff_images = this_obj_absdiff_images[camn]
                    morphed_images = this_obj_morphed_images[camn]
                    morph_failures = np.array(this_obj_morph_failures[camn])
                    im_coords = this_obj_im_coords[camn]
                    com_coords = this_obj_com_coords[camn]
                    camn_pt_no_array = this_obj_camn_pt_no[camn]

                    all_framenumbers = np.arange(image_framenumbers[0],
                                                 image_framenumbers[-1] + 1)

                    com_coords = np.array(com_coords)
                    if do_rts_smoothing:
                        # Perform RTS smoothing on center-of-mass coordinates.

                        # Find first good datum.
                        fgnz = np.nonzero(~np.isnan(com_coords[:, 0]))
                        com_coords_smooth = np.empty(com_coords.shape,
                                                     dtype=np.float)
                        com_coords_smooth.fill(np.nan)

                        if len(fgnz[0]):
                            first_good = fgnz[0][0]

                            RTS_com_coords = com_coords[first_good:, :]

                            # Setup parameters for Kalman filter.
                            dt = 1.0 / fps
                            A = np.array(
                                [
                                    [1, 0, dt, 0],  # process update
                                    [0, 1, 0, dt],
                                    [0, 0, 1, 0],
                                    [0, 0, 0, 1],
                                ],
                                dtype=np.float,
                            )
                            C = np.array(
                                [[1, 0, 0, 0], [0, 1, 0, 0]
                                 ],  # observation matrix
                                dtype=np.float,
                            )
                            Q = 0.1 * np.eye(4)  # process noise
                            R = 1.0 * np.eye(2)  # observation noise
                            initx = np.array(
                                [
                                    RTS_com_coords[0, 0], RTS_com_coords[0, 1],
                                    0, 0
                                ],
                                dtype=np.float,
                            )
                            initV = 2 * np.eye(4)
                            initV[0, 0] = 0.1
                            initV[1, 1] = 0.1
                            y = RTS_com_coords
                            xsmooth, Vsmooth = adskalman.adskalman.kalman_smoother(
                                y, A, C, Q, R, initx, initV)
                            com_coords_smooth[first_good:] = xsmooth[:, :2]

                        # Now shift images

                        image_shift = com_coords_smooth - com_coords
                        bad_cond = np.isnan(image_shift[:, 0])
                        # broadcast zeros to places where no good tracking
                        image_shift[bad_cond, 0] = 0
                        image_shift[bad_cond, 1] = 0

                        shifted_morphed_images = [
                            shift_image(im, xy)
                            for im, xy in zip(morphed_images, image_shift)
                        ]

                        results = flatten_image_stack(
                            image_framenumbers,
                            shifted_morphed_images,
                            im_coords,
                            camn_pt_no_array,
                            N=stack_N_images,
                        )
                    else:
                        results = flatten_image_stack(
                            image_framenumbers,
                            morphed_images,
                            im_coords,
                            camn_pt_no_array,
                            N=stack_N_images,
                        )

                    # The variable fno (the first element of the results
                    # tuple) is guaranteed to be contiguous and to span
                    # the range from the first to last frames available.

                    for (
                            fno,
                            av_im,
                            lowerleft,
                            orig_data2d_rownum,
                            orig_idx,
                            orig_idxs_in_average,
                    ) in results:

                        # Clip image to reduce moment arms.
                        av_im[av_im <= final_thresh] = 0

                        fail_fit = False
                        fast_av_im = FastImage.asfastimage(
                            av_im.astype(np.uint8))
                        try:
                            (x0_roi, y0_roi, area, slope,
                             eccentricity) = fpc.fit(fast_av_im)
                        except realtime_image_analysis.FitParamsError as err:
                            fail_fit = True

                        this_morph_failures = morph_failures[
                            orig_idxs_in_average]
                        n_failed_images = np.sum(this_morph_failures)
                        n_good_images = stack_N_images - n_failed_images
                        if n_good_images >= stack_N_images_min:
                            n_images_is_acceptable = True
                        else:
                            n_images_is_acceptable = False

                        if fail_fit:
                            x0_roi = np.nan
                            y0_roi = np.nan
                            area, slope, eccentricity = np.nan, np.nan, np.nan

                        if not n_images_is_acceptable:
                            x0_roi = np.nan
                            y0_roi = np.nan
                            area, slope, eccentricity = np.nan, np.nan, np.nan

                        x0 = x0_roi + lowerleft[0]
                        y0 = y0_roi + lowerleft[1]

                        if 1:
                            for row in dest_table.iterrows(
                                    start=orig_data2d_rownum,
                                    stop=orig_data2d_rownum + 1):

                                row["x"] = x0
                                row["y"] = y0
                                row["area"] = area
                                row["slope"] = slope
                                row["eccentricity"] = eccentricity
                                row.update()  # save data

                        if save_images:
                            # Display debugging images
                            fname = "av_obj%05d_%s_frame%07d.png" % (
                                obj_id,
                                cam_id,
                                fno,
                            )
                            if save_image_dir is not None:
                                fname = os.path.join(save_image_dir, fname)

                            raw_im, raw_coords = raw_images[orig_idx]
                            mean_im = mean_images[orig_idx]
                            absdiff_im = absdiff_images[orig_idx]
                            morphed_im = morphed_images[orig_idx]
                            raw_l, raw_b = raw_coords[:2]

                            imh, imw = raw_im.shape[:2]
                            n_ims = 5

                            if 1:
                                # increase contrast
                                contrast_scale = 2.0
                                av_im_show = np.clip(av_im * contrast_scale, 0,
                                                     255)

                            margin = 10
                            scale = 3

                            # calculate the orientation line
                            yintercept = y0 - slope * x0
                            xplt = np.array([
                                lowerleft[0] - 5,
                                lowerleft[0] + av_im_show.shape[1] + 5,
                            ])
                            yplt = slope * xplt + yintercept
                            if 1:
                                # only send non-nan values to plot
                                plt_good = ~np.isnan(xplt) & ~np.isnan(yplt)
                                xplt = xplt[plt_good]
                                yplt = yplt[plt_good]

                            top_row_width = scale * imw * n_ims + (
                                1 + n_ims) * margin
                            SHOW_STACK = True
                            if SHOW_STACK:
                                n_stack_rows = 4
                                rw = scale * imw * stack_N_images + (
                                    1 + n_ims) * margin
                                row_width = max(top_row_width, rw)
                                col_height = (n_stack_rows * scale * imh +
                                              (n_stack_rows + 1) * margin)
                                stack_margin = 20
                            else:
                                row_width = top_row_width
                                col_height = scale * imh + 2 * margin
                                stack_margin = 0

                            canv = benu.Canvas(
                                fname,
                                row_width,
                                col_height + stack_margin,
                                color_rgba=(1, 1, 1, 1),
                            )

                            if SHOW_STACK:
                                for (stacki, s_orig_idx
                                     ) in enumerate(orig_idxs_in_average):

                                    (s_raw_im,
                                     s_raw_coords) = raw_images[s_orig_idx]
                                    s_raw_l, s_raw_b = s_raw_coords[:2]
                                    s_imh, s_imw = s_raw_im.shape[:2]
                                    user_rect = (s_raw_l, s_raw_b, s_imw,
                                                 s_imh)

                                    x_display = (stacki + 1) * margin + (
                                        scale * imw) * stacki
                                    for show in ["raw", "absdiff", "morphed"]:
                                        if show == "raw":
                                            y_display = scale * imh + 2 * margin
                                        elif show == "absdiff":
                                            y_display = 2 * scale * imh + 3 * margin
                                        elif show == "morphed":
                                            y_display = 3 * scale * imh + 4 * margin
                                        display_rect = (
                                            x_display,
                                            y_display + stack_margin,
                                            scale * raw_im.shape[1],
                                            scale * raw_im.shape[0],
                                        )

                                        with canv.set_user_coords(
                                                display_rect,
                                                user_rect,
                                                transform=cam_id2view[cam_id],
                                        ):

                                            if show == "raw":
                                                s_im = s_raw_im.astype(
                                                    np.uint8)
                                            elif show == "absdiff":
                                                tmp = absdiff_images[
                                                    s_orig_idx]
                                                s_im = tmp.astype(np.uint8)
                                            elif show == "morphed":
                                                tmp = morphed_images[
                                                    s_orig_idx]
                                                s_im = tmp.astype(np.uint8)

                                            canv.imshow(s_im, s_raw_l, s_raw_b)
                                            sx0, sy0 = com_coords[s_orig_idx]
                                            X = [sx0]
                                            Y = [sy0]
                                            # the raw coords in red
                                            canv.scatter(X,
                                                         Y,
                                                         color_rgba=(1, 0.5,
                                                                     0.5, 1))

                                            if do_rts_smoothing:
                                                sx0, sy0 = com_coords_smooth[
                                                    s_orig_idx]
                                                X = [sx0]
                                                Y = [sy0]
                                                # the RTS smoothed coords in green
                                                canv.scatter(
                                                    X,
                                                    Y,
                                                    color_rgba=(0.5, 1, 0.5,
                                                                1))

                                            if s_orig_idx == orig_idx:
                                                boxx = np.array([
                                                    s_raw_l,
                                                    s_raw_l,
                                                    s_raw_l + s_imw,
                                                    s_raw_l + s_imw,
                                                    s_raw_l,
                                                ])
                                                boxy = np.array([
                                                    s_raw_b,
                                                    s_raw_b + s_imh,
                                                    s_raw_b + s_imh,
                                                    s_raw_b,
                                                    s_raw_b,
                                                ])
                                                canv.plot(
                                                    boxx,
                                                    boxy,
                                                    color_rgba=(0.5, 1, 0.5,
                                                                1),
                                                )
                                        if show == "morphed":
                                            canv.text(
                                                "morphed %d" %
                                                (s_orig_idx - orig_idx, ),
                                                display_rect[0],
                                                (display_rect[1] +
                                                 display_rect[3] +
                                                 stack_margin - 20),
                                                font_size=font_size,
                                                color_rgba=(1, 0, 0, 1),
                                            )

                            # Display raw_im
                            display_rect = (
                                margin,
                                margin,
                                scale * raw_im.shape[1],
                                scale * raw_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(raw_im.astype(np.uint8), raw_l,
                                            raw_b)
                                canv.plot(
                                    xplt, yplt,
                                    color_rgba=(0, 1, 0,
                                                0.5))  # the orientation line
                            canv.text(
                                "raw",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            # Display mean_im
                            display_rect = (
                                2 * margin + (scale * imw),
                                margin,
                                scale * mean_im.shape[1],
                                scale * mean_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(mean_im.astype(np.uint8), raw_l,
                                            raw_b)
                            canv.text(
                                "mean",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            # Display absdiff_im
                            display_rect = (
                                3 * margin + (scale * imw) * 2,
                                margin,
                                scale * absdiff_im.shape[1],
                                scale * absdiff_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            absdiff_clip = np.clip(absdiff_im * contrast_scale,
                                                   0, 255)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(absdiff_clip.astype(np.uint8),
                                            raw_l, raw_b)
                            canv.text(
                                "absdiff",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            # Display morphed_im
                            display_rect = (
                                4 * margin + (scale * imw) * 3,
                                margin,
                                scale * morphed_im.shape[1],
                                scale * morphed_im.shape[0],
                            )
                            user_rect = (raw_l, raw_b, imw, imh)
                            morphed_clip = np.clip(morphed_im * contrast_scale,
                                                   0, 255)
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(morphed_clip.astype(np.uint8),
                                            raw_l, raw_b)
                            if 0:
                                canv.text(
                                    "morphed",
                                    display_rect[0],
                                    display_rect[1] + display_rect[3],
                                    font_size=font_size,
                                    color_rgba=(0.5, 0.5, 0.9, 1),
                                    shadow_offset=1,
                                )

                            # Display time-averaged absdiff_im
                            display_rect = (
                                5 * margin + (scale * imw) * 4,
                                margin,
                                scale * av_im_show.shape[1],
                                scale * av_im_show.shape[0],
                            )
                            user_rect = (
                                lowerleft[0],
                                lowerleft[1],
                                av_im_show.shape[1],
                                av_im_show.shape[0],
                            )
                            with canv.set_user_coords(
                                    display_rect,
                                    user_rect,
                                    transform=cam_id2view[cam_id],
                            ):
                                canv.imshow(
                                    av_im_show.astype(np.uint8),
                                    lowerleft[0],
                                    lowerleft[1],
                                )
                                canv.plot(
                                    xplt, yplt,
                                    color_rgba=(0, 1, 0,
                                                0.5))  # the orientation line
                            canv.text(
                                "stacked/flattened",
                                display_rect[0],
                                display_rect[1] + display_rect[3],
                                font_size=font_size,
                                color_rgba=(0.5, 0.5, 0.9, 1),
                                shadow_offset=1,
                            )

                            canv.text(
                                "%s frame % 7d: eccentricity % 5.1f, min N images %d, actual N images %d"
                                % (
                                    cam_id,
                                    fno,
                                    eccentricity,
                                    stack_N_images_min,
                                    n_good_images,
                                ),
                                0,
                                15,
                                font_size=font_size,
                                color_rgba=(0.6, 0.7, 0.9, 1),
                                shadow_offset=1,
                            )
                            canv.save()

                # Save results to new table
                if 0:
                    recarray = np.rec.array(list_of_rows_of_data2d,
                                            dtype=Info2DCol_description)
                    dest_table.append(recarray)
                    dest_table.flush()
            dest_table.attrs.has_ibo_data = True
        data_file.close()
            ML_estimates_2d_idxs = data_file.root.kalman_observations_2d_idxs[:]
        except tables.exceptions.NoSuchNodeError, err2:
            raise err1

    if os.path.exists( output_h5_filename ):
        raise RuntimeError(
            "will not overwrite old file '%s'"%output_h5_filename)
    with open_file_safe( output_h5_filename, delete_on_error=True,
                       mode='w') as output_h5:
        if save_image_dir is not None:
            if not os.path.exists( save_image_dir ):
                os.mkdir( save_image_dir )

        with open_file_safe( h5_filename, mode='r' ) as h5:

            fps = result_utils.get_fps( h5, fail_on_error=True )

            for input_node in h5.root._f_iter_nodes():
                # copy everything from source to dest
                input_node._f_copy(output_h5.root,recursive=True)
            print 'done copying'

            # Clear values in destination table that we may overwrite.
            dest_table = output_h5.root.data2d_distorted
            for colname in ['x','y','area','slope','eccentricity','cur_val',
                            'mean_val','sumsqf_val']:
                if colname=='cur_val':
                    fill_value = 0
                else:
                    fill_value = np.nan
                clear_col(dest_table,colname,fill_value=fill_value)