Beispiel #1
0
def main():
    """Command-line entry point: load flydra trajectories and open the alignment GUI.

    Parses command-line arguments and loads (x, y, z) trajectories for the
    selected object ids from one or more flydra .hdf5 files. It then computes
    a per-point speed by central differences and hands the data, together with
    a ``reconstruct.Reconstructor``, to an ``IVTKWithCalGUI`` viewer. The
    stimulus from ``--stim-xml`` is added to the same scene, and the function
    then blocks in the GUI event loop.

    Speeds are in position units per second. They use the frame rate stored
    in the first data file that reports one, or 100 fps if none does.

    Raises:
        ValueError: if an input file is not HDF5, if no stimulus XML is
            given, if vertex and speed lengths disagree, or if no trajectory
            data was loaded.
        RuntimeError: if no reconstructor can be determined, or if
            ``--obj-only``/``--obj-filelist`` is used with more than one file.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('filename', nargs='+',
                        help='name of flydra .hdf5 file',
                        )

    parser.add_argument("--stim-xml",
                        type=str,
                        default=None,
                        help="name of XML file with stimulus info",
                        required=True,
                        )

    parser.add_argument("--align-json",
                        type=str,
                        default=None,
                        help="previously exported json file containing s,R,T",
                        )

    parser.add_argument("--radius", type=float,
                        help="radius of line (in meters)",
                        default=0.002,
                        metavar="RADIUS")

    parser.add_argument("--obj-only", type=str)

    parser.add_argument("--obj-filelist", type=str,
                        help="use object ids from list in text file",
                        )

    parser.add_argument(
        "-r", "--reconstructor", dest="reconstructor_path",
        type=str,
        help=("calibration/reconstructor path (if not specified, "
              "defaults to FILE)"))

    args = parser.parse_args()
    options = args  # optparse OptionParser backwards compatibility

    reconstructor_path = args.reconstructor_path
    fps = None

    ca = core_analysis.get_global_CachingAnalyzer()
    # Maps each input filename to (object ids to load, open data file).
    by_file = {}

    for h5_filename in args.filename:
        if not tables.is_hdf5_file(h5_filename):
            raise ValueError('not an HDF5 file: %r' % (h5_filename,))
        obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
            h5_filename)
        this_fps = result_utils.get_fps(data_file, fail_on_error=False)
        # The first file that reports a frame rate sets fps for all files.
        if fps is None and this_fps is not None:
            fps = this_fps
        if reconstructor_path is None:
            # Default to the calibration stored in the first data file.
            reconstructor_path = data_file
        by_file[h5_filename] = (use_obj_ids, data_file)

    if options.obj_only is not None:
        obj_only = core_analysis.parse_seq(options.obj_only)
    else:
        obj_only = None

    if reconstructor_path is None:
        raise RuntimeError('must specify reconstructor from CLI if not using .h5 files')

    R = reconstruct.Reconstructor(reconstructor_path)

    # Keep the frame rate read from the data; fall back only when none was found.
    if fps is None:
        fps = 100.0
        warnings.warn('Setting fps to default value of %f' % fps)

    if options.stim_xml is None:
        raise ValueError(
            'stim_xml must be specified (how else will you align the data?')

    stim_xml = xml_stimulus.xml_stimulus_from_filename(
        options.stim_xml,
        )
    try:
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
    except xml_stimulus.WrongXMLTypeError:
        pass  # not a fanout file: keep the object ids found in the data files
    else:
        # A fanout file selects object ids per recording, keyed by the
        # timestamp embedded in each data filename.
        for h5_filename, (file_obj_ids, data_file) in list(by_file.items()):
            file_timestamp = data_file.filename[4:19]
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp)
            if include_obj_ids is not None:
                file_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                file_obj_ids = list(set(file_obj_ids).difference(
                    exclude_obj_ids))
            by_file[h5_filename] = (file_obj_ids, data_file)
        print('using object ids specified in fanout .xml file')
    if stim_xml.has_reconstructor():
        stim_xml.verify_reconstructor(R)

    obj_filelist = options.obj_filelist

    # An explicit id selection (--obj-only or --obj-filelist) overrides the
    # ids from the data file and the fanout XML; it needs a single input file.
    if obj_only is not None or obj_filelist is not None:
        if len(by_file) != 1:
            raise RuntimeError("specifying obj_only can only be done for a single file")
        if obj_filelist is not None:
            # Object ids are the first comma-separated column. ndmin=2 keeps
            # single-row and single-column files two-dimensional.
            data = np.loadtxt(obj_filelist, delimiter=',', ndmin=2)
            obj_only = np.array(data[:, 0], dtype='int')
            print(obj_only)

        h5_filename = next(iter(by_file))
        _prev_use_obj_ids, data_file = by_file[h5_filename]
        by_file[h5_filename] = (numpy.array(obj_only), data_file)

    x = []
    y = []
    z = []
    speed = []
    dt = 1.0 / fps

    for use_obj_ids, data_file in by_file.values():
        for obj_id in use_obj_ids:
            rows = ca.load_data(obj_id, data_file,
                                use_kalman_smoothing=False,
                                )
            verts = numpy.array([rows['x'], rows['y'], rows['z']]).T
            if len(verts) >= 3:
                # Central difference: v[i] = (p[i+1] - p[i-1]) / (2*dt).
                vels = (verts[2:, :] - verts[:-2, :]) / (2 * dt)
                speeds = numpy.sqrt(numpy.sum(vels ** 2, axis=1))
                # Repeat the end values so there is one speed per vertex.
                speeds = numpy.concatenate(([speeds[0]], speeds, [speeds[-1]]))
            else:
                speeds = numpy.zeros((verts.shape[0],))

            if verts.shape[0] != len(speeds):
                raise ValueError('mismatch length of x data and speeds')
            x.append(verts[:, 0])
            y.append(verts[:, 1])
            z.append(verts[:, 2])
            speed.append(speeds)
        data_file.close()

    if not x:
        raise ValueError('no trajectory data loaded for the selected object ids')

    x = np.concatenate(x)
    y = np.concatenate(y)
    z = np.concatenate(z)
    w = np.ones_like(x)
    speed = np.concatenate(speed)

    # homogeneous coords, shape (4, n_points)
    verts = np.array([x, y, z, w])

    #######################################################

    # Create the MayaVi engine and start it.
    e = Engine()
    # start does nothing much but useful if someone is listening to
    # your engine.
    e.start()

    # Create a new scene.
    viewer = IVTKWithCalGUI(size=(800, 600))
    viewer.open()
    e.new_scene(viewer)

    viewer.cal_align.set_data(verts, speed, R, args.align_json)

    show_engine_tree = False  # developer toggle: show the MayaVi tree view UI
    if show_engine_tree:
        ev = EngineView(engine=e)
        ui = ev.edit_traits()

    # view aligned data
    e.add_source(viewer.cal_align.source)

    v = Vectors()
    v.glyph.scale_mode = 'data_scaling_off'
    v.glyph.color_mode = 'color_by_scalar'
    v.glyph.glyph_source.glyph_position = 'center'
    v.glyph.glyph_source.glyph_source = tvtk.SphereSource(
        radius=options.radius,
        )
    e.add_module(v)

    actors = stim_xml.get_tvtk_actors()
    viewer.scene.add_actors(actors)

    gui = GUI()
    gui.start_event_loop()
Beispiel #2
0
            del sys.path[0]
    except IOError, err:
        print 'not a .mat file at %s, treating as .hdf5 file' % (os.path.join(
            data_path, data_filename))

    ca = core_analysis.get_global_CachingAnalyzer()
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        filename)

    if obj_only is not None:
        use_obj_ids = obj_only

    if options.stim_xml is not None:
        file_timestamp = data_file.filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml,
            timestamp_string=file_timestamp,
        )
    try:
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
    except xml_stimulus.WrongXMLTypeError:
        pass
    else:
        include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
            timestamp_string=file_timestamp)
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
        print 'using object ids specified in fanout .xml file'

    if dynamic_model_name is None:
Beispiel #3
0
def plot_top_and_side_views(
    subplot=None, options=None, obs_mew=None, scale=1.0, units="m",
):
    """
    Plot trajectories from a flydra data file in top (xy) and side (xz) views.

    inputs
    ------
    subplot - a dictionary of matplotlib axes instances. Both the 'xy' (top
              view) and 'xz' (side view) keys are required.
    options - options object providing kalman_filename, fps, dynamic_model,
              use_kalman_smoothing, start, stop, obj_only, stim_xml and
              up_dir. The optional attributes show_track_ends, unicolor,
              show_landing, ellipsoids, show_observations, show_obj_id,
              markersize and max_z are set to defaults when missing.
              If options.fps is None, the frame rate is read from the data
              file, with a fallback of 100 fps.
    obs_mew - marker edge width for the observation ('x') markers
    scale - factor applied to coordinates before plotting
    units - unit label used on the axes

    returns
    -------
    a collections.defaultdict(list) with keys:
      'lines' - the xz trajectory line of each plotted object
      'ellipse_lines' - all covariance-ellipse lines (both views)
      'MLE_line' - the xz observation line of each plotted object (None when
                   observations are not shown)
    """
    assert subplot is not None

    assert options is not None

    # Fill in optional display settings that callers may not define.
    for attr_name, default in (
        ("show_track_ends", False),
        ("unicolor", False),
        ("show_landing", False),
        ("ellipsoids", False),
        ("show_observations", False),
        ("show_obj_id", False),
        ("markersize", 0.5),
        ("max_z", None),
    ):
        if not hasattr(options, attr_name):
            setattr(options, attr_name, default)

    kalman_filename = options.kalman_filename
    fps = options.fps
    dynamic_model = options.dynamic_model
    use_kalman_smoothing = options.use_kalman_smoothing

    if options.ellipsoids and use_kalman_smoothing:
        warnings.warn(
            "plotting ellipsoids while using Kalman smoothing does not reveal original error estimates"
        )

    assert kalman_filename is not None

    start = options.start
    stop = options.stop
    obj_only = options.obj_only

    if not use_kalman_smoothing:
        if dynamic_model is not None:
            print(
                "ERROR: disabling Kalman smoothing (--disable-kalman-smoothing) is incompatable with setting dynamic model options (--dynamic-model)",
                file=sys.stderr,
            )
            sys.exit(1)

    ca = core_analysis.get_global_CachingAnalyzer()

    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        kalman_filename
    )

    if not is_mat_file:
        if fps is None:
            fps = result_utils.get_fps(data_file, fail_on_error=False)
        reconstructor = reconstruct.Reconstructor(data_file)
    else:
        reconstructor = None

    # Applies to .mat input too, which otherwise passed fps=None to load_data.
    if fps is None:
        fps = 100.0
        warnings.warn("Setting fps to default value of %f" % fps)

    if options.stim_xml:
        file_timestamp = data_file.filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp,
        )
        try:
            fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        except xml_stimulus.WrongXMLTypeError:
            walking_start_stops = []
        else:
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp
            )
            walking_start_stops = fanout.get_walking_start_stops_for_timestamp(
                timestamp_string=file_timestamp
            )
            if include_obj_ids is not None:
                use_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
            stim_xml = fanout.get_stimulus_for_timestamp(
                timestamp_string=file_timestamp
            )
        if stim_xml.has_reconstructor():
            stim_xml.verify_reconstructor(reconstructor)
    else:
        walking_start_stops = []

    if dynamic_model is None:
        dynamic_model = extra.get("dynamic_model_name", None)

    if dynamic_model is None:
        if use_kalman_smoothing:
            warnings.warn(
                "no kalman smoothing will be performed because no "
                "dynamic model specified or found."
            )
            use_kalman_smoothing = False
    else:
        print('detected file loaded with dynamic model "%s"' % dynamic_model)
        if use_kalman_smoothing:
            if dynamic_model.startswith("EKF "):
                dynamic_model = dynamic_model[4:]
            print('  for smoothing, will use dynamic model "%s"' % dynamic_model)

    if obj_only is not None:
        use_obj_ids = [i for i in use_obj_ids if i in obj_only]

    subplot["xy"].set_aspect("equal")
    subplot["xz"].set_aspect("equal")

    subplot["xy"].set_xlabel("x (%s)" % units)
    subplot["xy"].set_ylabel("y (%s)" % units)

    subplot["xz"].set_xlabel("x (%s)" % units)
    subplot["xz"].set_ylabel("z (%s)" % units)

    if options.stim_xml:
        stim_xml.plot_stim(
            subplot["xy"], projection=xml_stimulus.SimpleOrthographicXYProjection()
        )
        stim_xml.plot_stim(
            subplot["xz"], projection=xml_stimulus.SimpleOrthographicXZProjection()
        )

    results = collections.defaultdict(list)
    for obj_id in use_obj_ids:
        line = None
        ellipse_lines = []
        MLE_line = None
        try:
            kalman_rows = ca.load_data(
                obj_id,
                data_file,
                use_kalman_smoothing=use_kalman_smoothing,
                dynamic_model_name=dynamic_model,
                frames_per_second=fps,
                up_dir=options.up_dir,
            )
        except core_analysis.ObjectIDDataError:
            continue

        if options.show_observations:
            kobs_rows = ca.load_dynamics_free_MLE_position(obj_id, data_file)

        frame = kalman_rows["frame"]
        if options.show_observations:
            frame_obs = kobs_rows["frame"]

        # Restrict estimates (and observations) to the [start, stop] frame range.
        if (start is not None) or (stop is not None):
            valid_cond = numpy.ones(frame.shape, dtype=bool)
            if options.show_observations:
                valid_obs_cond = np.ones(frame_obs.shape, dtype=bool)

            if start is not None:
                valid_cond = valid_cond & (frame >= start)
                if options.show_observations:
                    valid_obs_cond = valid_obs_cond & (frame_obs >= start)

            if stop is not None:
                valid_cond = valid_cond & (frame <= stop)
                if options.show_observations:
                    valid_obs_cond = valid_obs_cond & (frame_obs <= stop)

            kalman_rows = kalman_rows[valid_cond]
            if options.show_observations:
                kobs_rows = kobs_rows[valid_obs_cond]
            if not len(kalman_rows):
                continue

        frame = kalman_rows["frame"]

        Xx = kalman_rows["x"]
        Xy = kalman_rows["y"]
        Xz = kalman_rows["z"]

        if options.max_z is not None:
            # Mask (rather than drop) points above max_z so indices stay aligned.
            cond = Xz <= options.max_z

            frame = numpy.ma.masked_where(~cond, frame)
            Xx = numpy.ma.masked_where(~cond, Xx)
            Xy = numpy.ma.masked_where(~cond, Xy)
            Xz = numpy.ma.masked_where(~cond, Xz)
            with keep_axes_dimensions_if(subplot["xz"], options.stim_xml):
                subplot["xz"].axhline(options.max_z)

        kws = {"markersize": options.markersize}

        if options.unicolor:
            kws["color"] = "k"

        # Indices of frames where a walking segment starts.
        landing_idxs = []
        for walkstart, walkstop in walking_start_stops:
            if walkstart in frame:
                tmp_idx = numpy.nonzero(frame == walkstart)[0][0]
                landing_idxs.append(tmp_idx)

        with keep_axes_dimensions_if(subplot["xy"], options.stim_xml):
            (line,) = subplot["xy"].plot(
                Xx * scale, Xy * scale, ".", label="obj %d" % obj_id, **kws
            )
            # Reuse this trajectory's color for all related artists.
            kws["color"] = line.get_color()
            if options.ellipsoids:
                for i in range(len(Xx)):
                    rowi = kalman_rows[i]
                    mu = [rowi["x"], rowi["y"], rowi["z"]]
                    va = get_covariance(rowi)
                    ellx, elly = densities.gauss_ell(mu, va, [0, 1], 30, 0.39)
                    (ellipse_line,) = subplot["xy"].plot(
                        ellx * scale, elly * scale, color=kws["color"]
                    )
                    ellipse_lines.append(ellipse_line)
            if options.show_track_ends:
                subplot["xy"].plot(
                    [Xx[0] * scale, Xx[-1] * scale],
                    [Xy[0] * scale, Xy[-1] * scale],
                    "cd",
                    ms=6,
                    label="track end",
                )
            if options.show_obj_id:
                subplot["xy"].text(Xx[0] * scale, Xy[0] * scale, str(obj_id))
            if options.show_landing:
                for landing_idx in landing_idxs:
                    subplot["xy"].plot(
                        [Xx[landing_idx] * scale],
                        [Xy[landing_idx] * scale],
                        "rD",
                        ms=10,
                        label="landing",
                    )
            if options.show_observations:
                mykw = {}
                mykw.update(kws)
                mykw["markersize"] *= 5
                mykw["mew"] = obs_mew

                badcond = np.isnan(kobs_rows["x"])
                Xox = np.ma.masked_where(badcond, kobs_rows["x"])
                Xoy = np.ma.masked_where(badcond, kobs_rows["y"])

                (MLE_line,) = subplot["xy"].plot(
                    Xox * scale, Xoy * scale, "x", label="obj %d" % obj_id, **mykw
                )

        with keep_axes_dimensions_if(subplot["xz"], options.stim_xml):
            (line,) = subplot["xz"].plot(
                Xx * scale, Xz * scale, ".", label="obj %d" % obj_id, **kws
            )
            kws["color"] = line.get_color()
            if options.ellipsoids:
                for i in range(len(Xx)):
                    rowi = kalman_rows[i]
                    mu = [rowi["x"], rowi["y"], rowi["z"]]
                    va = get_covariance(rowi)
                    ellx, ellz = densities.gauss_ell(mu, va, [0, 2], 30, 0.39)
                    (ellipse_line,) = subplot["xz"].plot(
                        ellx * scale, ellz * scale, color=kws["color"]
                    )
                    ellipse_lines.append(ellipse_line)

            if options.show_track_ends:
                subplot["xz"].plot(
                    [Xx[0] * scale, Xx[-1] * scale],
                    [Xz[0] * scale, Xz[-1] * scale],
                    "cd",
                    ms=6,
                    label="track end",
                )
            if options.show_obj_id:
                subplot["xz"].text(Xx[0] * scale, Xz[0] * scale, str(obj_id))
            if options.show_landing:
                for landing_idx in landing_idxs:
                    subplot["xz"].plot(
                        [Xx[landing_idx] * scale],
                        [Xz[landing_idx] * scale],
                        "rD",
                        ms=10,
                        label="landing",
                    )
            if options.show_observations:
                mykw = {}
                mykw.update(kws)
                mykw["markersize"] *= 5
                mykw["mew"] = obs_mew

                badcond = np.isnan(kobs_rows["x"])
                Xox = np.ma.masked_where(badcond, kobs_rows["x"])
                Xoz = np.ma.masked_where(badcond, kobs_rows["z"])

                (MLE_line,) = subplot["xz"].plot(
                    Xox * scale, Xoz * scale, "x", label="obj %d" % obj_id, **mykw
                )
        results["lines"].append(line)
        results["ellipse_lines"].extend(ellipse_lines)
        results["MLE_line"].append(MLE_line)
    return results
Beispiel #4
0
def doit(
    filename,
    obj_only=None,
    min_length=10,
    use_kalman_smoothing=True,
    data_fps=100.0,
    save_fps=25,
    vertical_scale=False,
    max_vel="auto",
    draw_stim_func_str=None,
    floor=True,
    animation_path_fname=None,
    output_dir=".",
    cam_only_move_duration=5.0,
    options=None,
):
    """Render trajectories from a flydra data file as a sequence of PNG frames.

    Loads trajectory metrics for the selected object ids from ``filename``.
    The scene contains speed-colored spheres, the stimulus from
    ``options.stim_xml`` (if given), an axes actor and a speed scalar bar.
    One PNG per saved frame is written into ``output_dir`` while the camera
    follows the path from ``animation_path_fname``.

    With a single object id the trajectory grows over time. With several ids,
    all points are drawn at once. Afterwards the camera keeps moving for
    ``cam_only_move_duration`` more seconds.

    Parameters
    ----------
    filename : path to the data file.
    obj_only : optional sequence of object ids to render instead of the ids
        found in the file (a fanout stimulus XML may still override them).
    min_length, draw_stim_func_str, floor : accepted but not used here.
    use_kalman_smoothing : passed to ``calculate_trajectory_metrics``.
    data_fps : frame rate of the data, in frames per second.
    save_fps : frame rate of the saved movie, in frames per second.
    vertical_scale : draw the speed scalar bar vertically.
    max_vel : upper end of the speed color map, or "auto" for the maximum
        speed in the data.
    animation_path_fname : camera path file; defaults to the packaged
        ``kdmovie_saver_default_path.kmp``.
    output_dir : directory for the PNG files; created if missing.
    cam_only_move_duration : seconds of camera-only motion at the end.
    options : object providing ``dynamic_model``, ``stim_xml``, ``start``
        and ``radius``.

    Raises
    ------
    ValueError : if there are no object ids to render.
    """
    # Read the dynamic model before checking it against the smoothing flag.
    dynamic_model_name = options.dynamic_model
    if not use_kalman_smoothing:
        if dynamic_model_name is not None:
            print(
                "ERROR: disabling Kalman smoothing (--disable-kalman-smoothing) is incompatable with setting dynamic model option (--dynamic-model)",
                file=sys.stderr,
            )
            sys.exit(1)

    if animation_path_fname is None:
        animation_path_fname = pkg_resources.resource_filename(
            __name__, "kdmovie_saver_default_path.kmp"
        )
    camera_animation_path = AnimationPath(animation_path_fname)

    mat_data = None
    try:
        try:
            data_path, data_filename = os.path.split(filename)
            data_path = os.path.expanduser(data_path)
            # loadmat is given only the bare filename; the file's directory
            # is temporarily put on sys.path.
            sys.path.insert(0, data_path)
            mat_data = scipy.io.mio.loadmat(data_filename)
        finally:
            del sys.path[0]
    except IOError:
        print(
            "not a .mat file at %s, treating as .hdf5 file"
            % (os.path.join(data_path, data_filename))
        )

    ca = core_analysis.get_global_CachingAnalyzer()
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(filename)

    if obj_only is not None:
        use_obj_ids = obj_only

    stim_xml = None
    if options.stim_xml is not None:
        file_timestamp = data_file.filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp,
        )
        # The fanout lookup needs both a stimulus file and file_timestamp.
        try:
            fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        except xml_stimulus.WrongXMLTypeError:
            pass
        else:
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp
            )
            if include_obj_ids is not None:
                use_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
            print("using object ids specified in fanout .xml file")

    if dynamic_model_name is None:
        if "dynamic_model_name" in extra:
            dynamic_model_name = extra["dynamic_model_name"]
            print('detected file loaded with dynamic model "%s"' % dynamic_model_name)
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]
            print('  for smoothing, will use dynamic model "%s"' % dynamic_model_name)
        else:
            print(
                "no dynamic model name specified, and it could not be determined from the file, either"
            )

    filename_trimmed = os.path.split(os.path.splitext(filename)[0])[-1]

    assert use_obj_ids is not None
    if len(use_obj_ids) == 0:
        raise ValueError("no object ids to render")

    #################
    rw = tvtk.RenderWindow(size=(1024, 768))

    ren = tvtk.Renderer(background=(1.0, 1.0, 1.0))
    camera = ren.active_camera

    rw.add_renderer(ren)

    lut = tvtk.LookupTable(hue_range=(0.667, 0.0))
    #################
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # A single trajectory is drawn progressively (the path grows during the
    # movie); several trajectories are drawn all at once.
    animate_path = len(use_obj_ids) == 1

    obj_verts = []
    speeds = []
    real_frames = []
    for obj_id in use_obj_ids:

        print("loading %d" % obj_id)
        results = ca.calculate_trajectory_metrics(
            obj_id,
            data_file,
            use_kalman_smoothing=use_kalman_smoothing,
            frames_per_second=data_fps,
            dynamic_model_name=dynamic_model_name,
            method_params={"downsample": 1,},
        )
        obj_verts.append(results["X_kalmanized"])
        speeds.append(results["speed_kalmanized"])
        real_frames.append(results["frame"])

    if animate_path:
        obj_verts = obj_verts[0]
        speeds = speeds[0]
        real_frames = real_frames[0]
    else:
        obj_verts = numpy.concatenate(obj_verts, axis=0)
        speeds = numpy.concatenate(speeds, axis=0)
        real_frames = numpy.concatenate(real_frames, axis=0)

    # Filter after concatenation so this works for one or many objects.
    if options.start is not None:
        good_cond = real_frames >= options.start
        obj_verts = obj_verts[good_cond]
        speeds = speeds[good_cond]
        real_frames = real_frames[good_cond]

    ####################### start draw permanently on stuff ############################

    if stim_xml is not None:

        if not is_mat_file:
            R = reconstruct.Reconstructor(data_file)
            if stim_xml.has_reconstructor():
                stim_xml.verify_reconstructor(R)
            assert data_file.filename.startswith("DATA") and (
                data_file.filename.endswith(".h5")
                or data_file.filename.endswith(".kh5")
            )
        actors = stim_xml.get_tvtk_actors()
        for actor in actors:
            ren.add_actor(actor)

    axes = tvtk.AxesActor()
    axes.set(
        shaft_type="cylinder",
        total_length=(0.15, 0.15, 0.15),
    )
    # hide the axis labels
    axes.x_axis_label_text = ""
    axes.y_axis_label_text = ""
    axes.z_axis_label_text = ""
    axes.origin = (-0.5, 1, 0)
    ren.add_actor(axes)

    #######################

    if max_vel == "auto":
        max_vel = speeds.max()
    else:
        max_vel = float(max_vel)
    vel_mapper = tvtk.PolyDataMapper()
    vel_mapper.lookup_table = lut
    vel_mapper.scalar_range = 0.0, max_vel

    # Create a scalar bar
    if vertical_scale:
        scalar_bar = tvtk.ScalarBarActor(
            orientation="vertical", width=0.08, height=0.4
        )
    else:
        scalar_bar = tvtk.ScalarBarActor(
            orientation="horizontal", width=0.4, height=0.08
        )
    scalar_bar.title = "Speed (m/s)"
    scalar_bar.lookup_table = vel_mapper.lookup_table

    scalar_bar.property.color = 0.0, 0.0, 0.0  # black

    scalar_bar.title_text_property.color = 0.0, 0.0, 0.0
    scalar_bar.title_text_property.shadow = False

    scalar_bar.label_text_property.color = 0.0, 0.0, 0.0
    scalar_bar.label_text_property.shadow = False

    scalar_bar.position_coordinate.coordinate_system = "normalized_viewport"
    if vertical_scale:
        scalar_bar.position_coordinate.value = 0.01, 0.01, 0.0
    else:
        scalar_bar.position_coordinate.value = 0.1, 0.01, 0.0

    ren.add_actor(scalar_bar)

    imf = tvtk.WindowToImageFilter(input=rw, read_front_buffer="off")
    writer = tvtk.PNGWriter()

    ####################### end draw permanently on stuff ############################

    save_dt = 1.0 / save_fps

    if animate_path:
        data_dt = 1.0 / data_fps
        n_frames = len(obj_verts)
        dur = n_frames * data_dt
    else:
        # dur == 0 renders exactly one frame showing all points.
        data_dt = 0.0
        dur = 0.0

    print("data_fps", data_fps)
    print("data_dt", data_dt)
    print("save_fps", save_fps)

    t_now = 0.0
    frame_number = 0
    while t_now <= dur:
        frame_number += 1
        t_now += save_dt
        print("t_now", t_now)

        pos, ori = camera_animation_path.get_pos_ori(t_now)
        focal_point, view_up = pos_ori2fu(pos, ori)

        camera.position = tuple(pos)
        camera.focal_point = tuple(focal_point)
        camera.view_up = tuple(view_up)

        # Number of data points elapsed at this movie time.
        if data_dt != 0.0:
            draw_n_frames = int(round(t_now / data_dt))
        else:
            draw_n_frames = len(obj_verts)
        print("frame_number, draw_n_frames", frame_number, draw_n_frames)

        #################

        pd = tvtk.PolyData()
        pd.points = obj_verts[:draw_n_frames]
        real_frame_number = real_frames[:draw_n_frames][-1]
        # one speed scalar per drawn point
        pd.point_data.scalars = speeds[:draw_n_frames]
        if numpy.any(speeds > max_vel):
            print(
                "WARNING: maximum speed (%.3f m/s) exceeds color map max"
                % (speeds.max(),)
            )

        g = tvtk.Glyph3D(
            scale_mode="data_scaling_off", vector_mode="use_vector", input=pd
        )
        vel_mapper.input = g.output
        ss = tvtk.SphereSource(radius=options.radius)
        g.source = ss.output
        a = tvtk.Actor(mapper=vel_mapper)

        ##################

        ren.add_actor(a)

        imf.update()
        imf.modified()
        writer.input = imf.output
        fname = "movie_%s_%03d_frame%05d.png" % (
            filename_trimmed,
            obj_id,
            real_frame_number,
        )
        print("saving", fname)
        full_fname = os.path.join(output_dir, fname)
        writer.file_name = full_fname
        writer.write()

        ren.remove_actor(a)

    ren.add_actor(a)  # keep the final trajectory visible during camera-only frames
    dur = dur + cam_only_move_duration

    while t_now < dur:
        frame_number += 1
        t_now += save_dt
        print("t_now", t_now)

        pos, ori = camera_animation_path.get_pos_ori(t_now)
        focal_point, view_up = pos_ori2fu(pos, ori)
        camera.position = tuple(pos)
        camera.focal_point = tuple(focal_point)
        camera.view_up = tuple(view_up)

        imf.update()
        imf.modified()
        writer.input = imf.output
        if len(use_obj_ids) == 1:
            fname = "movie_%s_%03d_frame%05d.png" % (
                filename_trimmed,
                obj_id,
                frame_number,
            )
        else:
            fname = "movie_%s_many_frame%05d.png" % (filename_trimmed, frame_number)
        full_fname = os.path.join(output_dir, fname)
        writer.file_name = full_fname
        writer.write()

    if not is_mat_file:
        data_file.close()
Beispiel #5
0
    def show_it(
        self,
        fig,
        filename,
        kalman_filename=None,
        frame_start=None,
        frame_stop=None,
        show_nth_frame=None,
        obj_only=None,
        reconstructor_filename=None,
        options=None,
    ):
        """Plot 2D detections per camera and, optionally, reprojected 3D tracks.

        One subplot per camera listed in ``filename`` is created in ``fig``
        (via ``auto_subplot``) and stored in ``self.subplot_by_cam_id``.  The
        points of the ``data2d_distorted`` table whose frame lies within
        ``[frame_start, frame_stop]`` are drawn on it, optionally over the
        stored background image, a stimulus overlay and orientation segments.
        When ``kalman_filename`` is given, the smoothed 3D trajectories loaded
        via the global ``CachingAnalyzer`` are projected into every camera
        with ``self.reconstructor.find2d`` and drawn as lines.

        Parameters
        ----------
        fig : matplotlib figure to draw into.
        filename : path of the file holding the ``data2d_distorted`` table.
        kalman_filename : optional path of the file holding 3D tracks; may be
            the same file as ``filename``.
        frame_start, frame_stop : optional inclusive frame bounds.
        show_nth_frame : if nonzero, write the frame number next to every
            N-th 2D point.
        obj_only : optional collection of obj_ids; if given, only these 3D
            trajectories are drawn.
        reconstructor_filename : optional calibration source. Defaults to the
            Kalman file, else to ``filename`` if it has a ``calibration`` node.
        options : object providing ``stim_xml``, ``show_orientation``,
            ``autozoom`` and ``save_fig`` (required despite the ``None``
            default).
        """

        if show_nth_frame == 0:
            show_nth_frame = None

        results = result_utils.get_results(filename, mode="r")
        opened_kresults = False
        kresults = None
        if kalman_filename is not None:
            if os.path.abspath(kalman_filename) == os.path.abspath(filename):
                kresults = results
            else:
                kresults = PT.open_file(kalman_filename, mode="r")
                opened_kresults = True

            # copied from plot_timeseries_2d_3d.py
            ca = core_analysis.get_global_CachingAnalyzer()
            (
                _kobj_ids,
                kalman_use_obj_ids,
                _kis_mat_file,
                _kdata_file,
                extra,
            ) = ca.initial_file_load(kalman_filename)
            fps = extra["frames_per_second"]
            dynamic_model_name = extra.get("dynamic_model_name", None)
            if dynamic_model_name is None:
                dynamic_model_name = dynamic_models.DEFAULT_MODEL
                warnings.warn('no dynamic model specified, using "%s"' %
                              dynamic_model_name)
            else:
                print('detected file loaded with dynamic model "%s"' %
                      dynamic_model_name)
            # smoothing is requested with the model name minus "EKF " prefix
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]
            print('  for smoothing, will use dynamic model "%s"' %
                  dynamic_model_name)

        if hasattr(results.root, "images"):
            img_table = results.root.images
        else:
            img_table = None

        # Pick the calibration source, reusing already-open files if possible.
        if reconstructor_filename is None:
            if kresults is not None:
                reconstructor_source = kresults
            elif hasattr(results.root, "calibration"):
                reconstructor_source = results
            else:
                reconstructor_source = None
        elif os.path.abspath(reconstructor_filename) == os.path.abspath(
                filename):
            reconstructor_source = results
        elif (kalman_filename is not None and
              os.path.abspath(reconstructor_filename) ==
              os.path.abspath(kalman_filename)):
            reconstructor_source = kresults
        else:
            reconstructor_source = reconstructor_filename

        if reconstructor_source is not None:
            self.reconstructor = flydra_core.reconstruct.Reconstructor(
                reconstructor_source)

        # stays None unless a stimulus file was requested; checked below
        stim_xml = None
        if options.stim_xml:
            file_timestamp = results.filename[4:19]
            stim_xml = xml_stimulus.xml_stimulus_from_filename(
                options.stim_xml, timestamp_string=file_timestamp)
            if self.reconstructor is not None:
                stim_xml.verify_reconstructor(self.reconstructor)

        if self.reconstructor is not None:
            self.reconstructor = self.reconstructor.get_scaled()

        camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(results)

        data2d = results.root.data2d_distorted  # make sure we have 2d data table

        print("reading frames...")
        frames = data2d.read(field="frame")
        print("OK")

        # Row indices of data2d whose frame is within [frame_start, frame_stop].
        keep = numpy.ones(frames.shape, dtype=bool)
        if frame_start is not None:
            print("selecting frames after start")
            keep &= frames >= frame_start
        if frame_stop is not None:
            print("selecting frames before stop")
            keep &= frames <= frame_stop
        print("finding all frames")
        use_idxs = numpy.nonzero(keep)[0]

        # OK, we have data coords, plot

        print("reading cameras")
        frames = frames[use_idxs]
        if len(frames):
            print("frame range: %d - %d (%d frames total)" %
                  (frames[0], frames[-1], len(frames)))
        camns = data2d.read(field="camn")[use_idxs]
        unique_camns = numpy.unique(camns)
        cam_ids_with_data = set(camn2cam_id[camn] for camn in unique_camns)
        print("%d cameras with data" % (len(cam_ids_with_data), ))

        # Plot all cameras, not just those with data. sorted() is used because
        # dict.keys() returns a view without .sort() on Python 3.
        unique_cam_ids = sorted(cam_id2camns.keys())

        if len(unique_cam_ids) == 1:
            n_rows = 1
            n_cols = 1
        elif len(unique_cam_ids) <= 6:
            n_rows = 2
            n_cols = 3
        elif len(unique_cam_ids) <= 12:
            n_rows = 3
            n_cols = 4
        else:
            n_rows = 4
            # float() so the ceiling also works under Python 2 division
            n_cols = int(math.ceil(len(unique_cam_ids) / float(n_rows)))

        for i, cam_id in enumerate(unique_cam_ids):
            ax = auto_subplot(fig, i, n_rows=n_rows, n_cols=n_cols)
            ax.set_title("%s: %s" % (cam_id, str(cam_id2camns[cam_id])))
            # running data extent, used by the autozoom option below
            ax.this_minx = np.inf
            ax.this_maxx = -np.inf
            ax.this_miny = np.inf
            ax.this_maxy = -np.inf
            self.subplot_by_cam_id[cam_id] = ax

        for cam_id in unique_cam_ids:
            ax = self.subplot_by_cam_id[cam_id]
            if img_table is not None:
                bg_arr_h5 = getattr(img_table, cam_id)
                bg_arr = bg_arr_h5.read()
                ax.imshow(bg_arr.squeeze(), origin="lower", cmap=cm.pink)
                ax.set_autoscale_on(True)
                ax.autoscale_view()
                pylab.draw()
                ax.set_autoscale_on(False)

            if self.reconstructor is not None:
                if cam_id in self.reconstructor.get_cam_ids():
                    res = self.reconstructor.get_resolution(cam_id)
                    ax.set_xlim([0, res[0]])
                    ax.set_ylim([res[1], 0])

            if stim_xml is not None:
                stim_xml.plot_stim_over_distorted_image(ax, cam_id)

        for camn in unique_camns:
            cam_id = camn2cam_id[camn]
            ax = self.subplot_by_cam_id[cam_id]
            this_camn_idxs = use_idxs[camns == camn]

            xs = data2d.read_coordinates(this_camn_idxs, field="x")

            valid_idx = numpy.nonzero(~numpy.isnan(xs))[0]
            if not len(valid_idx):
                continue
            ys = data2d.read_coordinates(this_camn_idxs, field="y")
            if options.show_orientation:
                slope = data2d.read_coordinates(this_camn_idxs, field="slope")

            idx_first_valid = valid_idx[0]
            idx_last_valid = valid_idx[-1]
            tmp_frames = data2d.read_coordinates(this_camn_idxs, field="frame")

            ax.plot([xs[idx_first_valid]], [ys[idx_first_valid]],
                    "ro",
                    label="first point")

            ax.this_minx = min(np.min(xs[valid_idx]), ax.this_minx)
            ax.this_maxx = max(np.max(xs[valid_idx]), ax.this_maxx)

            ax.this_miny = min(np.min(ys[valid_idx]), ax.this_miny)
            ax.this_maxy = max(np.max(ys[valid_idx]), ax.this_maxy)

            ax.plot(xs[valid_idx], ys[valid_idx], "g.", label="all points")

            if options.show_orientation:
                # short segment of half-length r along the slope direction
                angle = np.arctan(slope)
                r = 20.0
                dx = r * np.cos(angle)
                dy = r * np.sin(angle)
                x0 = xs - dx
                x1 = xs + dx
                y0 = ys - dy
                y1 = ys + dy
                segs = [((x0[i], y0[i]), (x1[i], y1[i])) for i in valid_idx]
                line_segments = collections.LineCollection(
                    segs,
                    linewidths=[1],
                    colors=[(0, 1, 0)],
                )
                ax.add_collection(line_segments)

            ax.plot([xs[idx_last_valid]], [ys[idx_last_valid]],
                    "bo",
                    label="last point")

            if show_nth_frame is not None:
                for i, f in enumerate(tmp_frames):
                    if f % show_nth_frame == 0:
                        ax.text(xs[i], ys[i], "%d" % (f, ))

        fig.canvas.mpl_connect("key_press_event", self.on_key_press)

        if options.autozoom:
            for ax in self.subplot_by_cam_id.values():
                extent = (ax.this_minx, ax.this_maxx, ax.this_miny,
                          ax.this_maxy)
                if not np.all(np.isfinite(extent)):
                    # camera without valid points: infinite limits would
                    # make matplotlib raise, so keep the current limits
                    continue
                ax.set_xlim((ax.this_minx - 10, ax.this_maxx + 10))
                ax.set_ylim((ax.this_miny - 10, ax.this_maxy + 10))

        if options.save_fig:
            for ax in self.subplot_by_cam_id.values():
                ax.set_xticks([])
                ax.set_yticks([])

        if kalman_filename is None:
            # no 3D data to overlay; release the 2D file before returning
            results.close()
            return

        # Reproject the core_analysis smoothed 3D trajectories into each view.
        use_obj_ids = kalman_use_obj_ids
        if obj_only is not None:
            use_obj_ids = [oid for oid in use_obj_ids if oid in obj_only]

        recon_cam_ids = self.reconstructor.get_cam_ids()
        for obj_id in use_obj_ids:
            try:
                rows = ca.load_data(
                    obj_id,
                    kalman_filename,
                    use_kalman_smoothing=True,
                    frames_per_second=fps,
                    dynamic_model_name=dynamic_model_name,
                )
            except core_analysis.NotEnoughDataToSmoothError:
                warnings.warn(
                    "not enough data to smooth obj_id %d, skipping." %
                    (obj_id, ))
                continue
            if frame_start is not None:
                c1 = rows["frame"] >= frame_start
            else:
                c1 = np.ones((len(rows), ), dtype=bool)
            if frame_stop is not None:
                c2 = rows["frame"] <= frame_stop
            else:
                c2 = np.ones((len(rows), ), dtype=bool)
            rows = rows[c1 & c2]
            if len(rows) < 2:
                # too few points to draw a line for this object
                continue
            # homogeneous coordinates (x, y, z, 1), one row per sample
            X3d = np.array((rows["x"], rows["y"], rows["z"],
                            np.ones_like(rows["z"]))).T

            for cam_id, ax in self.subplot_by_cam_id.items():
                if cam_id not in recon_cam_ids:
                    # camera not in the calibration; cannot project into it
                    continue
                newx, newy = self.reconstructor.find2d(cam_id,
                                                       X3d,
                                                       distorted=True)
                ax.plot(newx, newy, "-", label="k: %d" % obj_id)

        results.close()
        if opened_kresults:
            kresults.close()
Beispiel #6
0
def plot_timeseries(subplot=None, options=None):
    """Plot per-object time series (position, direction, velocity, ...).

    Every trajectory of ``options.kalman_filename`` is loaded through the
    global ``CachingAnalyzer`` and drawn into whichever axes are present in
    ``subplot``.  Recognized keys are ``'frame'``, ``'P55'``, ``'x'``,
    ``'y'``, ``'z'``, ``'dx'``, ``'dy'``, ``'dz'``, ``'vel'``, ``'xy_vel'``,
    ``'accel'``, ``'z_hist'`` and ``'vel_hist'``.  Each trajectory is split
    into a "flying" and a "walking" part using the walking bouts from the
    stimulus XML (if any); frames outside every walking bout count as flying.

    Parameters
    ----------
    subplot : dict mapping the keys above to matplotlib axes; ``None`` is
        treated as an empty dict.
    options : namespace read for ``kalman_filename``, ``start``, ``stop``,
        ``obj_only``, ``fps``, ``dynamic_model``, ``use_kalman_smoothing``,
        ``stim_xml``, ``smooth_orientations`` and ``up_dir`` (and, in some
        branches, ``fuse`` and ``show_3d_orientations``).  The optional
        attributes ``frames``, ``show_landing``, ``unicolor``,
        ``show_obj_id``, ``show_track_ends``, ``timestamp_file`` and
        ``ori_qual`` are set to defaults on ``options`` when absent.

    Returns
    -------
    dict mapping each plotted matplotlib line to its obj_id (for picking).

    Raises
    ------
    ValueError
        If no Kalman filename is given, or a frame is missing from
        ``options.timestamp_file``.

    Calls ``sys.exit(0)`` when there are no obj_ids left to plot.
    """
    import warnings

    if subplot is None:
        subplot = {}

    kalman_filename = options.kalman_filename

    # Defaults for optional attributes older callers may not set.
    for attr_name, default in (
            ('frames', False),
            ('show_landing', False),
            ('unicolor', False),
            ('show_obj_id', True),
            ('show_track_ends', False),
            ('timestamp_file', None),
            ('ori_qual', None),
    ):
        if not hasattr(options, attr_name):
            setattr(options, attr_name, default)

    start = options.start
    stop = options.stop
    obj_only = options.obj_only
    fps = options.fps
    dynamic_model = options.dynamic_model
    use_kalman_smoothing = options.use_kalman_smoothing

    if not use_kalman_smoothing and dynamic_model is not None:
        sys.stderr.write('WARNING: disabling Kalman smoothing '
                         '(--disable-kalman-smoothing) is incompatable '
                         'with setting dynamic model options (--dynamic-model)'
                         '\n')

    ca = core_analysis.get_global_CachingAnalyzer()

    if kalman_filename is None:
        raise ValueError('No kalman_filename given. Nothing to do.')

    # Hash in chunks so a large data file is never held in memory at once.
    m = hashlib.md5()
    with open(kalman_filename, mode='rb') as fd:
        for chunk in iter(lambda: fd.read(1 << 20), b''):
            m.update(chunk)
    actual_md5 = m.hexdigest()
    (obj_ids, use_obj_ids, is_mat_file, data_file,
     extra) = ca.initial_file_load(kalman_filename)
    print('opened kalman file %s %s, %d obj_ids' %
          (kalman_filename, actual_md5, len(use_obj_ids)))

    # Keep only obj_ids having a row of extra['frames'] within [start, stop].
    if 'frames' in extra and ((start is not None) or (stop is not None)):
        valid_frames = np.ones((len(extra['frames']), ), dtype=bool)
        if start is not None:
            valid_frames &= extra['frames'] >= start
        if stop is not None:
            valid_frames &= extra['frames'] <= stop
        this_use_obj_ids = np.unique(obj_ids[valid_frames])
        use_obj_ids = list(set(use_obj_ids).intersection(this_use_obj_ids))

    do_fuse = False
    walking_start_stops = []
    if options.stim_xml:
        file_timestamp = data_file.filename[4:19]
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
            timestamp_string=file_timestamp)
        walking_start_stops = fanout.get_walking_start_stops_for_timestamp(
            timestamp_string=file_timestamp)
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
        if options.fuse:
            do_fuse = True

    if dynamic_model is None:
        dynamic_model = extra['dynamic_model_name']
        print('detected file loaded with dynamic model "%s"' % dynamic_model)
        if dynamic_model.startswith('EKF '):
            dynamic_model = dynamic_model[4:]
        print('  for smoothing, will use dynamic model "%s"' % dynamic_model)

    # tz is only read from non-.mat data files; it was previously left
    # undefined for .mat files, which crashed when building the axes.
    tz = None
    if not is_mat_file:
        if fps is None:
            fps = result_utils.get_fps(data_file, fail_on_error=False)
        tz = result_utils.get_tz(data_file)

    if fps is None:
        fps = 100.0
        warnings.warn('Setting fps to default value of %f' % fps)

    dt = 1.0 / fps

    if obj_only is not None:
        use_obj_ids = [i for i in use_obj_ids if i in obj_only]

    all_vels = []
    frame0 = None
    line2obj_id = {}
    Xz_all = []
    fuse_did_once = False

    if options.timestamp_file is not None:
        h5 = tables.open_file(options.timestamp_file, mode='r')
        try:
            print('reading timestamps and frames')
            table_data2d_frames = h5.root.data2d_distorted.read(field='frame')
            table_data2d_timestamps = h5.root.data2d_distorted.read(
                field='timestamp')
            print('done')
        finally:
            h5.close()
        table_data2d_frames_find = utils.FastFinder(table_data2d_frames)

    if len(use_obj_ids) == 0:
        print('No obj_ids to plot, quitting')
        sys.exit(0)

    def identity(x):
        return x

    def plot_obj_line(ax, xdata, ydata, obj_id, flystate, kws):
        # Plot one labelled trace, register it for picking, and make later
        # traces of the same object/state reuse its color.
        line, = ax.plot(xdata,
                        ydata,
                        label='obj %d (%s)' % (obj_id, flystate),
                        **kws)
        line2obj_id[line] = obj_id
        kws['color'] = line.get_color()
        return line

    time0 = 0.0  # set default value

    for obj_id in use_obj_ids:
        if not do_fuse:
            try:
                kalman_rows = ca.load_data(
                    obj_id,
                    data_file,
                    use_kalman_smoothing=use_kalman_smoothing,
                    dynamic_model_name=dynamic_model,
                    return_smoothed_directions=options.smooth_orientations,
                    frames_per_second=fps,
                    up_dir=options.up_dir,
                    min_ori_quality_required=options.ori_qual,
                )
            except core_analysis.ObjectIDDataError:
                continue
        else:
            if options.show_3d_orientations:
                raise NotImplementedError('orientation data is not supported '
                                          'when fusing obj_ids')
            if fuse_did_once:
                break
            fuse_did_once = True
            kalman_rows = flydra_analysis.a2.flypos.fuse_obj_ids(
                use_obj_ids,
                data_file,
                dynamic_model_name=dynamic_model,
                frames_per_second=fps)
        frame = kalman_rows['frame']

        if (start is not None) or (stop is not None):
            valid_cond = numpy.ones(frame.shape, dtype=bool)
            if start is not None:
                valid_cond = valid_cond & (frame >= start)
            if stop is not None:
                valid_cond = valid_cond & (frame <= stop)
            kalman_rows = kalman_rows[valid_cond]
            if not len(kalman_rows):
                continue

        walking_and_flying_kalman_rows = kalman_rows  # preserve original data

        for flystate in ['flying', 'walking']:
            if len(walking_start_stops):
                frame = walking_and_flying_kalman_rows['frame']
                # assume flying unless a walking bout covers the frame
                if flystate == 'flying':
                    state_cond = numpy.ones(frame.shape, dtype=bool)
                else:
                    state_cond = numpy.zeros(frame.shape, dtype=bool)
                for walkstart, walkstop in walking_start_stops:
                    walking_bout = numpy.ones(frame.shape, dtype=bool)
                    if walkstart is not None:
                        walking_bout &= (frame >= walkstart)
                    if walkstop is not None:
                        walking_bout &= (frame <= walkstop)
                    if flystate == 'flying':
                        state_cond &= ~walking_bout
                    else:
                        state_cond |= walking_bout

                kalman_rows = np.take(walking_and_flying_kalman_rows,
                                      np.nonzero(state_cond)[0])
                assert len(kalman_rows) == np.sum(state_cond)
            elif flystate == 'walking':
                # No walking bouts at all: nothing is walking. (Previously
                # the unfiltered data was re-plotted labelled "walking".)
                continue
            else:
                kalman_rows = walking_and_flying_kalman_rows

            if not len(kalman_rows):
                continue  # no samples in this state for this object
            frame = kalman_rows['frame']

            if frame0 is None:
                frame0 = int(frame[0])

            time0 = 0.0
            if options.timestamp_file is not None:
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(frame0)
                if len(frame_idxs):
                    time0 = table_data2d_timestamps[frame_idxs[0]]
                else:
                    raise ValueError(
                        'could not fine frame %d in timestamp file' % frame0)

            Xx = kalman_rows['x']
            Xy = kalman_rows['y']
            Xz = kalman_rows['z']

            Dx = Dy = Dz = None
            if options.smooth_orientations:
                Dx = kalman_rows['dir_x']
                Dy = kalman_rows['dir_y']
                Dz = kalman_rows['dir_z']
            elif 'rawdir_x' in kalman_rows.dtype.fields:
                Dx = kalman_rows['rawdir_x']
                Dy = kalman_rows['rawdir_y']
                Dz = kalman_rows['rawdir_z']

            if not options.frames:
                f2t = Frames2Time(frame0, fps, time0)
            else:
                f2t = identity

            kws = {
                'linewidth': 2,
                'picker': 5,
            }
            if options.unicolor:
                kws['color'] = 'k'

            if 'frame' in subplot:
                subplot['frame'].plot(f2t(frame), frame)

            if 'P55' in subplot:
                subplot['P55'].plot(f2t(frame), kalman_rows['P55'])

            if 'x' in subplot:
                plot_obj_line(subplot['x'], f2t(frame), Xx, obj_id, flystate,
                              kws)

            if 'y' in subplot:
                plot_obj_line(subplot['y'], f2t(frame), Xy, obj_id, flystate,
                              kws)

            if 'z' in subplot:
                # getdata works whether or not the arrays are masked
                frame_data = numpy.ma.getdata(frame)
                Xz_data = numpy.ma.getdata(Xz)

                # plot landing time (only once, in the flying pass)
                if options.show_landing and flystate == 'flying':
                    for walkstart, walkstop in walking_start_stops:
                        if walkstart in frame_data:
                            landing_idx = numpy.nonzero(
                                frame_data == walkstart)[0][0]
                            subplot['z'].plot([f2t(walkstart)],
                                              [Xz_data[landing_idx]],
                                              'rD',
                                              ms=10,
                                              label='landing')

                if options.show_track_ends and flystate == 'flying':
                    subplot['z'].plot(f2t([frame_data[0], frame_data[-1]]),
                                      [Xz_data[0], Xz_data[-1]],
                                      'cd',
                                      ms=6,
                                      label='track end')

                line = plot_obj_line(subplot['z'], f2t(frame), Xz, obj_id,
                                     flystate, kws)

                if flystate == 'flying' and options.show_obj_id:
                    subplot['z'].text(f2t(frame_data[0]), Xz_data[0],
                                      '%d' % (obj_id, ))
                    line2obj_id[line] = obj_id

            if flystate == 'flying':
                Xz_all.append(np.ma.array(Xz).compressed())

            for (dir_var, Dd) in [('dx', Dx), ('dy', Dy), ('dz', Dz)]:
                if dir_var in subplot:
                    plot_obj_line(subplot[dir_var], f2t(frame), Dd, obj_id,
                                  flystate, kws)

            # 3 x N position array; ma.vstack keeps masks on every numpy
            # version (replaces a lexicographic numpy.__version__ check).
            X = numpy.ma.vstack((Xx, Xy, Xz))

            dist_central_diff = (X[:, 2:] - X[:, :-2])
            vel_central_diff = dist_central_diff / (2 * dt)

            vel2mag = numpy.ma.sqrt(numpy.ma.sum(vel_central_diff**2, axis=0))
            xy_vel2mag = numpy.ma.sqrt(
                numpy.ma.sum(vel_central_diff[:2, :]**2, axis=0))

            frames2 = frame[1:-1]  # frames at which central differences apply

            accel4mag = (vel2mag[2:] - vel2mag[:-2]) / (2 * dt)
            frames4 = frames2[1:-1]

            if 'vel' in subplot:
                plot_obj_line(subplot['vel'], f2t(frames2), vel2mag, obj_id,
                              flystate, kws)

            if 'xy_vel' in subplot:
                plot_obj_line(subplot['xy_vel'], f2t(frames2), xy_vel2mag,
                              obj_id, flystate, kws)

            if len(accel4mag.compressed()) and 'accel' in subplot:
                plot_obj_line(subplot['accel'], f2t(frames4), accel4mag,
                              obj_id, flystate, kws)

            if flystate == 'flying':
                all_vels.append(vel2mag.compressed())

    if len(all_vels):
        all_vels = numpy.hstack(all_vels)
    else:
        all_vels = numpy.array([], dtype=float)

    # drop velocities of 2.0 m/s and above before the histogram
    cond = all_vels < 2.0
    if numpy.ma.sum(cond) != len(all_vels):
        all_vels = all_vels[cond]
        warnings.warn('clipping all velocities > 2.0 m/s')

    xlabel = 'frame' if options.frames else 'time (s)'

    for ax in subplot.values():
        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%s"))

    # With absolute timestamps (time0 != 0) the x axis is handled by
    # FixupAxesWithTimeZone; otherwise it just gets the xlabel.
    if time0 != 0.0:
        fixup_ax = FixupAxesWithTimeZone(tz).fixup_ax
    else:
        fixup_ax = None

    def label_time_axis(ax):
        if fixup_ax is not None:
            fixup_ax(ax)
        else:
            ax.set_xlabel(xlabel)

    if 'frame' in subplot:
        label_time_axis(subplot['frame'])

    if 'x' in subplot:
        subplot['x'].set_ylim([-1, 1])
        subplot['x'].set_ylabel(r'x (m)')
        label_time_axis(subplot['x'])

    if 'y' in subplot:
        subplot['y'].set_ylim([-0.5, 1.5])
        subplot['y'].set_ylabel(r'y (m)')
        label_time_axis(subplot['y'])

    max_z = None
    if options.stim_xml:
        file_timestamp = options.kalman_filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp)
        post_max_zs = [
            max(post['verts'][0][2], post['verts'][1][2])  # max post height
            for post in stim_xml.iterate_posts()
        ]
        if len(post_max_zs):
            max_z = min(post_max_zs)  # take shortest of posts

    if 'z' in subplot:
        subplot['z'].set_ylim([0, 1])
        subplot['z'].set_ylabel(r'z (m)')
        if max_z is not None:
            subplot['z'].axhline(max_z, color='m')
        label_time_axis(subplot['z'])

    for dir_var in ['dx', 'dy', 'dz']:
        if dir_var in subplot:
            subplot[dir_var].set_ylabel(dir_var)
            label_time_axis(subplot[dir_var])

    if 'z_hist' in subplot:
        ax = subplot['z_hist']
        if len(Xz_all):  # np.hstack raises on an empty list
            bins = np.linspace(0, .8, 30)
            ax.hist(np.hstack(Xz_all), bins=bins, orientation='horizontal')
        ax.set_xticks([])
        ax.set_yticks([])
        xlim = tuple(ax.get_xlim())  # matplotlib 0.98.3 returned np.array view
        ax.set_xlim((xlim[1], xlim[0]))
        if max_z is not None:
            ax.axhline(max_z, color='m')

    if 'vel' in subplot:
        subplot['vel'].set_ylim([0, 2])
        subplot['vel'].set_ylabel(r'vel (m/s)')
        subplot['vel'].set_xlabel(xlabel)
        label_time_axis(subplot['vel'])

    if 'xy_vel' in subplot:
        subplot['xy_vel'].set_ylabel(r'horiz vel (m/s)')
        subplot['xy_vel'].set_xlabel(xlabel)
        label_time_axis(subplot['xy_vel'])

    if 'accel' in subplot:
        subplot['accel'].set_ylabel(r'acceleration (m/(s^2))')
        subplot['accel'].set_xlabel(xlabel)
        label_time_axis(subplot['accel'])

    if 'vel_hist' in subplot:
        ax = subplot['vel_hist']
        bins = numpy.linspace(0, 2, 50)
        ax.set_title('excluding walking')
        # density=True replaces normed=True (removed in matplotlib 3.1)
        ax.hist(all_vels, bins=bins, density=True)
        ax.set_xlim(0, 2)
        ax.set_ylabel('probability density')
        ax.set_xlabel('velocity (m/s)')

    return line2obj_id