示例#1
0
def consider_stimulus(h5file,
                      verbose_problems=False,
                      fanout_name="fanout.xml"):
    """Parse the fanout XML next to *h5file* and find the IDs and stimulus.

    The fanout file is looked up in the same directory as *h5file* under
    the name *fanout_name*.

    Parameters
    ----------
    h5file : str
        Path of the data file.
    verbose_problems : bool
        If true, expected problems (missing fanout file, WrongXMLTypeError,
        ValueError) are logged at error level.  Unexpected exceptions are
        always logged.
    fanout_name : str
        File name of the fanout XML inside the directory of *h5file*.

    Returns
    -------
    tuple
        ``(valid, use_obj_ids, stimulus)``.  ``valid`` is False if something
        went wrong, in which case the other two values are None.  Otherwise
        ``use_obj_ids`` are the file's object IDs after applying the fanout
        include/exclude lists and ``stimulus`` is the third element of the
        fanout episode for this file's timestamp.
    """
    # Bound before the ``try`` so the exception handlers below can always
    # reference it, even if the failure happens before it is computed.
    file_timestamp = None

    def _where():
        # Identify the file in log messages; fall back to the path when the
        # timestamp could not be determined yet.
        return file_timestamp if file_timestamp is not None else h5file

    try:
        dirname = os.path.dirname(h5file)
        fanout_xml = os.path.join(dirname, fanout_name)
        if not os.path.exists(fanout_xml):
            if verbose_problems:
                logger.error("Stim_xml path not found '%s' for file '%s'" %
                             (fanout_xml, h5file))
            return False, None, None

        ca = core_analysis.get_global_CachingAnalyzer()
        (_, use_obj_ids, _, _, _) = ca.initial_file_load(h5file)

        file_timestamp = timestamp_string_from_filename(h5file)

        fanout = xml_stimulus.xml_fanout_from_filename(fanout_xml)
        include_obj_ids, exclude_obj_ids = \
            fanout.get_obj_ids_for_timestamp(timestamp_string=file_timestamp)
        # An include list replaces the file's IDs; an exclude list is then
        # subtracted from whatever selection is current.
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))

        episode = fanout._get_episode_for_timestamp(
            timestamp_string=file_timestamp)
        (_, _, stim_fname) = episode
        return True, use_obj_ids, stim_fname

    except xml_stimulus.WrongXMLTypeError:
        if verbose_problems:
            logger.error("Caught WrongXMLTypeError for '%s'" % _where())
        return False, None, None
    except ValueError as ex:
        if verbose_problems:
            logger.error("Caught ValueError for '%s': %s" %
                         (_where(), ex))
        return False, None, None
    except Exception as ex:
        # Deliberate best-effort boundary: report and signal "invalid".
        logger.error('Not predicted exception while reading %s; %s' %
                     (h5file, ex))
        return False, None, None
示例#2
0
def main():
    """Command-line entry point for the calibration-alignment GUI.

    Loads the object trajectories from one or more flydra .hdf5 files,
    computes a per-point speed by central differences, and hands the points
    (in homogeneous coordinates) and speeds to ``IVTKWithCalGUI`` together
    with the reconstructor and the optional ``--align-json`` file.  The
    stimulus described by ``--stim-xml`` is drawn in the same scene.  If
    ``--stim-xml`` is a fanout file, its include/exclude object-ID lists
    are applied to each loaded file.
    """
    import os  # local: only needed to derive per-file timestamps below

    parser = argparse.ArgumentParser()

    parser.add_argument('filename', nargs='+',
                        help='name of flydra .hdf5 file',
                        )

    parser.add_argument("--stim-xml",
                        type=str,
                        default=None,
                        help="name of XML file with stimulus info",
                        required=True,
                        )

    parser.add_argument("--align-json",
                        type=str,
                        default=None,
                        help="previously exported json file containing s,R,T",
                        )

    parser.add_argument("--radius", type=float,
                        help="radius of line (in meters)",
                        default=0.002,
                        metavar="RADIUS")

    parser.add_argument("--obj-only", type=str)

    parser.add_argument("--obj-filelist", type=str,
                        help="use object ids from list in text file",
                        )

    parser.add_argument(
        "-r", "--reconstructor", dest="reconstructor_path",
        type=str,
        help=("calibration/reconstructor path (if not specified, "
              "defaults to FILE)"))

    args = parser.parse_args()
    options = args  # optparse OptionParser backwards compatibility

    reconstructor_path = args.reconstructor_path
    fps = None

    ca = core_analysis.get_global_CachingAnalyzer()
    by_file = {}  # h5 filename -> (object IDs to use, open data file)

    for h5_filename in args.filename:
        if not tables.is_hdf5_file(h5_filename):
            raise ValueError('not an HDF5 file: %r' % (h5_filename,))
        _, use_obj_ids, _, data_file, _ = ca.initial_file_load(h5_filename)
        this_fps = result_utils.get_fps(data_file, fail_on_error=False)
        # The first file that reports a frame rate determines fps.
        if fps is None and this_fps is not None:
            fps = this_fps
        if reconstructor_path is None:
            reconstructor_path = data_file
        by_file[h5_filename] = (use_obj_ids, data_file)

    if options.obj_only is not None:
        obj_only = core_analysis.parse_seq(options.obj_only)
    else:
        obj_only = None

    if reconstructor_path is None:
        raise RuntimeError('must specify reconstructor from CLI if not using .h5 files')

    R = reconstruct.Reconstructor(reconstructor_path)

    # Keep a frame rate read from the file; only fall back to a default
    # when none was found, so speeds are always in m/s.
    if fps is None:
        fps = 100.0
        warnings.warn('Setting fps to default value of %f' % fps)

    if options.stim_xml is None:
        raise ValueError(
            'stim_xml must be specified (how else will you align the data?')

    stim_xml = xml_stimulus.xml_stimulus_from_filename(options.stim_xml)
    try:
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
    except xml_stimulus.WrongXMLTypeError:
        pass  # plain stimulus XML, no per-file object-ID lists
    else:
        # Apply the fanout include/exclude lists to each loaded file and
        # store the result back so the loading loop below uses it.
        for h5_filename, (use_obj_ids, data_file) in list(by_file.items()):
            # NOTE: assumes the "DATAYYYYmmdd_HHMMSS..." naming used by the
            # other flydra scripts (timestamp = characters 4..18).
            file_timestamp = os.path.basename(data_file.filename)[4:19]
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp)
            if include_obj_ids is not None:
                use_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                use_obj_ids = list(set(use_obj_ids).difference(
                    exclude_obj_ids))
            by_file[h5_filename] = (use_obj_ids, data_file)
        print('using object ids specified in fanout .xml file')
    if stim_xml.has_reconstructor():
        stim_xml.verify_reconstructor(R)

    obj_filelist = options.obj_filelist
    if obj_only is not None or obj_filelist is not None:
        if len(by_file) != 1:
            raise RuntimeError("specifying obj_only can only be done for a single file")
        if obj_filelist is not None:
            # ndmin=2 keeps one-row / one-column files two-dimensional so
            # that column 0 can always be selected.
            data = np.loadtxt(obj_filelist, delimiter=',', ndmin=2)
            obj_only = np.array(data[:, 0], dtype='int')
            print(obj_only)

        h5_filename = next(iter(by_file))  # the single loaded file
        _, data_file = by_file[h5_filename]
        by_file[h5_filename] = (numpy.array(obj_only), data_file)

    x = []
    y = []
    z = []
    speed = []
    dt = 1.0 / fps

    for use_obj_ids, data_file in by_file.values():
        for obj_id in use_obj_ids:
            rows = ca.load_data(obj_id, data_file,
                                use_kalman_smoothing=False,
                                )
            verts = numpy.array([rows['x'], rows['y'], rows['z']]).T
            if len(verts) >= 3:
                # Central differences; each endpoint reuses its neighbour's
                # speed so there is exactly one speed per vertex.
                vels = (verts[2:, :] - verts[:-2, :]) / (2 * dt)
                speeds = numpy.sqrt(numpy.sum(vels ** 2, axis=1))
                speeds = numpy.array([speeds[0]] + list(speeds) + [speeds[-1]])
            else:
                speeds = numpy.zeros((verts.shape[0],))

            if verts.shape[0] != len(speeds):
                raise ValueError('mismatch length of x data and speeds')
            x.append(verts[:, 0])
            y.append(verts[:, 1])
            z.append(verts[:, 2])
            speed.append(speeds)
        data_file.close()

    if not x:
        raise ValueError('no trajectory data loaded; nothing to display')

    x = np.concatenate(x)
    y = np.concatenate(y)
    z = np.concatenate(z)
    w = np.ones_like(x)
    speed = np.concatenate(speed)

    # homogeneous coords
    verts = np.array([x, y, z, w])

    #######################################################

    # Create the MayaVi engine and start it.
    e = Engine()
    # start does nothing much but useful if someone is listening to
    # your engine.
    e.start()

    # Create a new scene.
    viewer = IVTKWithCalGUI(size=(800, 600))
    viewer.open()
    e.new_scene(viewer)

    viewer.cal_align.set_data(verts, speed, R, args.align_json)

    # view aligned data
    e.add_source(viewer.cal_align.source)

    v = Vectors()
    v.glyph.scale_mode = 'data_scaling_off'
    v.glyph.color_mode = 'color_by_scalar'
    v.glyph.glyph_source.glyph_position = 'center'
    v.glyph.glyph_source.glyph_source = tvtk.SphereSource(
        radius=options.radius,
        )
    e.add_module(v)

    if stim_xml is not None:
        actors = stim_xml.get_tvtk_actors()
        viewer.scene.add_actors(actors)

    gui = GUI()
    gui.start_event_loop()
示例#3
0
    # NOTE(review): excerpt from inside a function whose header is not
    # visible here; filename, obj_only, options and dynamic_model_name come
    # from that enclosing scope. Uses Python 2 print statements.
    ca = core_analysis.get_global_CachingAnalyzer()
    # initial_file_load returns (obj_ids, use_obj_ids, is_mat_file,
    # data_file, extra) for the given file.
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        filename)

    # An explicit obj_only selection replaces the IDs returned above.
    if obj_only is not None:
        use_obj_ids = obj_only

    if options.stim_xml is not None:
        # Timestamp is characters 4..18 of the data file name (presumably
        # the part after a "DATA" prefix -- TODO confirm).
        file_timestamp = data_file.filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml,
            timestamp_string=file_timestamp,
        )
    # NOTE(review): this try/else is not guarded by the "if" above, so when
    # options.stim_xml is None the fanout lookup receives None and
    # file_timestamp may be unbound in the else branch. It likely belongs
    # inside the "if".
    try:
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
    except xml_stimulus.WrongXMLTypeError:
        # Not a fanout file: keep the object IDs selected above.
        pass
    else:
        # The fanout include list replaces the current selection (including
        # obj_only); the exclude list is then subtracted.
        include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
            timestamp_string=file_timestamp)
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
        print 'using object ids specified in fanout .xml file'

    # Without an explicit model name, fall back to the one recorded in the
    # file's extra info, if present.
    if dynamic_model_name is None:
        if 'dynamic_model_name' in extra:
            dynamic_model_name = extra['dynamic_model_name']
            print 'detected file loaded with dynamic model "%s"' % dynamic_model_name
示例#4
0
def plot_top_and_side_views(
    subplot=None, options=None, obs_mew=None, scale=1.0, units="m",
):
    """Plot trajectories from a Kalman data file in top (xy) and side (xz) views.

    inputs
    ------
    subplot - a dictionary of matplotlib axes instances; both the 'xy' and
              the 'xz' key are used
    options - options object.  Read attributes: kalman_filename, fps,
              dynamic_model, use_kalman_smoothing, start, stop, obj_only,
              stim_xml, up_dir, max_z, show_obj_id.  Defaulted here when
              missing: show_track_ends, unicolor, show_landing, ellipsoids,
              show_observations, markersize.  options.fps is the frame rate
              of the data; for non-.mat files it is read from the file when
              None, falling back to 100.0.
    obs_mew - marker edge width of the raw observation markers
    scale - factor applied to every plotted coordinate
    units - unit name used in the axis labels

    returns
    -------
    A collections.defaultdict(list) with keys 'lines', 'ellipse_lines' and
    'MLE_line' holding the matplotlib artists created per object.
    """
    assert subplot is not None

    assert options is not None

    # Defaults for optional flags so that a minimal options object works.
    if not hasattr(options, "show_track_ends"):
        options.show_track_ends = False

    if not hasattr(options, "unicolor"):
        options.unicolor = False

    if not hasattr(options, "show_landing"):
        options.show_landing = False

    kalman_filename = options.kalman_filename
    fps = options.fps
    dynamic_model = options.dynamic_model
    use_kalman_smoothing = options.use_kalman_smoothing

    if not hasattr(options, "ellipsoids"):
        options.ellipsoids = False

    if not hasattr(options, "show_observations"):
        options.show_observations = False

    if not hasattr(options, "markersize"):
        options.markersize = 0.5

    if options.ellipsoids and use_kalman_smoothing:
        warnings.warn(
            "plotting ellipsoids while using Kalman smoothing does not reveal original error estimates"
        )

    assert kalman_filename is not None

    start = options.start
    stop = options.stop
    obj_only = options.obj_only

    if not use_kalman_smoothing:
        if dynamic_model is not None:
            print(
                "ERROR: disabling Kalman smoothing (--disable-kalman-smoothing) is incompatable with setting dynamic model options (--dynamic-model)",
                file=sys.stderr,
            )
            sys.exit(1)

    ca = core_analysis.get_global_CachingAnalyzer()

    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        kalman_filename
    )

    if not is_mat_file:
        if fps is None:
            fps = result_utils.get_fps(data_file, fail_on_error=False)

        if fps is None:
            fps = 100.0
            warnings.warn("Setting fps to default value of %f" % fps)
        reconstructor = reconstruct.Reconstructor(data_file)
    else:
        reconstructor = None

    if options.stim_xml:
        file_timestamp = data_file.filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp,
        )
        try:
            fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        except xml_stimulus.WrongXMLTypeError:
            walking_start_stops = []
        else:
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp
            )
            walking_start_stops = fanout.get_walking_start_stops_for_timestamp(
                timestamp_string=file_timestamp
            )
            if include_obj_ids is not None:
                use_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
            # The fanout file determines the stimulus for this timestamp.
            stim_xml = fanout.get_stimulus_for_timestamp(
                timestamp_string=file_timestamp
            )
        if stim_xml.has_reconstructor():
            stim_xml.verify_reconstructor(reconstructor)
    else:
        walking_start_stops = []

    if dynamic_model is None:
        dynamic_model = extra.get("dynamic_model_name", None)

    if dynamic_model is None:
        if use_kalman_smoothing:
            warnings.warn(
                "no kalman smoothing will be performed because no "
                "dynamic model specified or found."
            )
            use_kalman_smoothing = False
    else:
        print('detected file loaded with dynamic model "%s"' % dynamic_model)
        if use_kalman_smoothing:
            if dynamic_model.startswith("EKF "):
                dynamic_model = dynamic_model[4:]
            print('  for smoothing, will use dynamic model "%s"' % dynamic_model)

    if obj_only is not None:
        use_obj_ids = [i for i in use_obj_ids if i in obj_only]

    subplot["xy"].set_aspect("equal")
    subplot["xz"].set_aspect("equal")

    subplot["xy"].set_xlabel("x (%s)" % units)
    subplot["xy"].set_ylabel("y (%s)" % units)

    subplot["xz"].set_xlabel("x (%s)" % units)
    subplot["xz"].set_ylabel("z (%s)" % units)

    if options.stim_xml:
        stim_xml.plot_stim(
            subplot["xy"], projection=xml_stimulus.SimpleOrthographicXYProjection()
        )
        stim_xml.plot_stim(
            subplot["xz"], projection=xml_stimulus.SimpleOrthographicXZProjection()
        )

    results = collections.defaultdict(list)
    for obj_id in use_obj_ids:
        line = None
        ellipse_lines = []
        MLE_line = None
        try:
            kalman_rows = ca.load_data(
                obj_id,
                data_file,
                use_kalman_smoothing=use_kalman_smoothing,
                dynamic_model_name=dynamic_model,
                frames_per_second=fps,
                up_dir=options.up_dir,
            )
        except core_analysis.ObjectIDDataError:
            continue

        if options.show_observations:
            kobs_rows = ca.load_dynamics_free_MLE_position(obj_id, data_file)

        frame = kalman_rows["frame"]
        if options.show_observations:
            frame_obs = kobs_rows["frame"]

        if (start is not None) or (stop is not None):
            # builtin ``bool``: the ``numpy.bool`` alias was removed in
            # NumPy 1.24.
            valid_cond = numpy.ones(frame.shape, dtype=bool)
            if options.show_observations:
                valid_obs_cond = np.ones(frame_obs.shape, dtype=bool)

            if start is not None:
                valid_cond = valid_cond & (frame >= start)
                if options.show_observations:
                    valid_obs_cond = valid_obs_cond & (frame_obs >= start)

            if stop is not None:
                valid_cond = valid_cond & (frame <= stop)
                if options.show_observations:
                    valid_obs_cond = valid_obs_cond & (frame_obs <= stop)

            kalman_rows = kalman_rows[valid_cond]
            if options.show_observations:
                kobs_rows = kobs_rows[valid_obs_cond]
            if not len(kalman_rows):
                continue

        frame = kalman_rows["frame"]

        Xx = kalman_rows["x"]
        Xy = kalman_rows["y"]
        Xz = kalman_rows["z"]

        if options.max_z is not None:
            # Mask (rather than drop) points above max_z so indices stay
            # aligned with kalman_rows.
            cond = Xz <= options.max_z

            frame = numpy.ma.masked_where(~cond, frame)
            Xx = numpy.ma.masked_where(~cond, Xx)
            Xy = numpy.ma.masked_where(~cond, Xy)
            Xz = numpy.ma.masked_where(~cond, Xz)
            with keep_axes_dimensions_if(subplot["xz"], options.stim_xml):
                subplot["xz"].axhline(options.max_z)

        kws = {"markersize": options.markersize}

        if options.unicolor:
            kws["color"] = "k"

        landing_idxs = []
        for walkstart, walkstop in walking_start_stops:
            if walkstart in frame:
                tmp_idx = numpy.nonzero(frame == walkstart)[0][0]
                landing_idxs.append(tmp_idx)

        with keep_axes_dimensions_if(subplot["xy"], options.stim_xml):
            (line,) = subplot["xy"].plot(
                Xx * scale, Xy * scale, ".", label="obj %d" % obj_id, **kws
            )
            # Reuse this object's colour for all further artists.
            kws["color"] = line.get_color()
            if options.ellipsoids:
                for i in range(len(Xx)):
                    rowi = kalman_rows[i]
                    mu = [rowi["x"], rowi["y"], rowi["z"]]
                    va = get_covariance(rowi)
                    ellx, elly = densities.gauss_ell(mu, va, [0, 1], 30, 0.39)
                    (ellipse_line,) = subplot["xy"].plot(
                        ellx * scale, elly * scale, color=kws["color"]
                    )
                    ellipse_lines.append(ellipse_line)
            if options.show_track_ends:
                subplot["xy"].plot(
                    [Xx[0] * scale, Xx[-1] * scale],
                    [Xy[0] * scale, Xy[-1] * scale],
                    "cd",
                    ms=6,
                    label="track end",
                )
            if options.show_obj_id:
                subplot["xy"].text(Xx[0] * scale, Xy[0] * scale, str(obj_id))
            if options.show_landing:
                for landing_idx in landing_idxs:
                    subplot["xy"].plot(
                        [Xx[landing_idx] * scale],
                        [Xy[landing_idx] * scale],
                        "rD",
                        ms=10,
                        label="landing",
                    )
            if options.show_observations:
                mykw = {}
                mykw.update(kws)
                mykw["markersize"] *= 5
                mykw["mew"] = obs_mew

                badcond = np.isnan(kobs_rows["x"])
                Xox = np.ma.masked_where(badcond, kobs_rows["x"])
                Xoy = np.ma.masked_where(badcond, kobs_rows["y"])

                (MLE_line,) = subplot["xy"].plot(
                    Xox * scale, Xoy * scale, "x", label="obj %d" % obj_id, **mykw
                )

        with keep_axes_dimensions_if(subplot["xz"], options.stim_xml):
            (line,) = subplot["xz"].plot(
                Xx * scale, Xz * scale, ".", label="obj %d" % obj_id, **kws
            )
            kws["color"] = line.get_color()
            if options.ellipsoids:
                for i in range(len(Xx)):
                    rowi = kalman_rows[i]
                    mu = [rowi["x"], rowi["y"], rowi["z"]]
                    va = get_covariance(rowi)
                    ellx, ellz = densities.gauss_ell(mu, va, [0, 2], 30, 0.39)
                    (ellipse_line,) = subplot["xz"].plot(
                        ellx * scale, ellz * scale, color=kws["color"]
                    )
                    ellipse_lines.append(ellipse_line)

            if options.show_track_ends:
                subplot["xz"].plot(
                    [Xx[0] * scale, Xx[-1] * scale],
                    [Xz[0] * scale, Xz[-1] * scale],
                    "cd",
                    ms=6,
                    label="track end",
                )
            if options.show_obj_id:
                subplot["xz"].text(Xx[0] * scale, Xz[0] * scale, str(obj_id))
            if options.show_landing:
                for landing_idx in landing_idxs:
                    subplot["xz"].plot(
                        [Xx[landing_idx] * scale],
                        [Xz[landing_idx] * scale],
                        "rD",
                        ms=10,
                        label="landing",
                    )
            if options.show_observations:
                mykw = {}
                mykw.update(kws)
                mykw["markersize"] *= 5
                mykw["mew"] = obs_mew

                badcond = np.isnan(kobs_rows["x"])
                Xox = np.ma.masked_where(badcond, kobs_rows["x"])
                Xoz = np.ma.masked_where(badcond, kobs_rows["z"])

                (MLE_line,) = subplot["xz"].plot(
                    Xox * scale, Xoz * scale, "x", label="obj %d" % obj_id, **mykw
                )
        results["lines"].append(line)
        results["ellipse_lines"].extend(ellipse_lines)
        results["MLE_line"].append(MLE_line)
    return results
示例#5
0
def doit(
    filename,
    obj_only=None,
    min_length=10,
    use_kalman_smoothing=True,
    data_fps=100.0,
    save_fps=25,
    vertical_scale=False,
    max_vel="auto",
    draw_stim_func_str=None,
    floor=True,
    animation_path_fname=None,
    output_dir=".",
    cam_only_move_duration=5.0,
    options=None,
):
    """Render PNG movie frames of trajectories from *filename*.

    Each frame draws the trajectory points as spheres coloured by speed,
    the stimulus from ``options.stim_xml`` (if given), an axes actor and a
    speed colour bar.  The camera follows the animation path stored in
    *animation_path_fname* (default: the bundled
    ``kdmovie_saver_default_path.kmp``).  With exactly one object, the path
    grows over time at *data_fps*.  With several objects, all points are
    drawn at once.  Afterwards the camera keeps moving for a further
    *cam_only_move_duration* seconds.

    Parameters
    ----------
    filename : path of the data file.
    obj_only : optional object IDs that replace the file's selection.
    use_kalman_smoothing : passed to the trajectory-metric calculation.
    data_fps : frame rate of the data (metrics and animation timing).
    save_fps : frame rate of the written movie frames.
    vertical_scale : orientation of the colour bar.
    max_vel : "auto" (maximum speed) or a number for the top of the colour
        scale.
    output_dir : directory for the PNG files; created if missing.
    options : object providing dynamic_model, stim_xml, start and radius.

    ``min_length``, ``draw_stim_func_str`` and ``floor`` are accepted for
    interface compatibility but are not used here.
    """
    # Must be bound before it is checked below.
    dynamic_model_name = options.dynamic_model
    if not use_kalman_smoothing:
        if dynamic_model_name is not None:
            print(
                "ERROR: disabling Kalman smoothing (--disable-kalman-smoothing) is incompatable with setting dynamic model option (--dynamic-model)",
                file=sys.stderr,
            )
            sys.exit(1)

    if animation_path_fname is None:
        animation_path_fname = pkg_resources.resource_filename(
            __name__, "kdmovie_saver_default_path.kmp"
        )
    camera_animation_path = AnimationPath(animation_path_fname)

    mat_data = None
    try:
        try:
            data_path, data_filename = os.path.split(filename)
            data_path = os.path.expanduser(data_path)
            sys.path.insert(0, data_path)
            mat_data = scipy.io.mio.loadmat(data_filename)
        finally:
            del sys.path[0]
    except IOError:
        print(
            "not a .mat file at %s, treating as .hdf5 file"
            % (os.path.join(data_path, data_filename))
        )

    ca = core_analysis.get_global_CachingAnalyzer()
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(filename)

    if obj_only is not None:
        use_obj_ids = obj_only

    if options.stim_xml is not None:
        file_timestamp = data_file.filename[4:19]
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp,
        )
        # The fanout lookup needs both the XML path and the timestamp, so
        # it only makes sense when a stimulus XML was given.
        try:
            fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        except xml_stimulus.WrongXMLTypeError:
            pass
        else:
            include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
                timestamp_string=file_timestamp
            )
            if include_obj_ids is not None:
                use_obj_ids = include_obj_ids
            if exclude_obj_ids is not None:
                use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
            print("using object ids specified in fanout .xml file")

    if dynamic_model_name is None:
        if "dynamic_model_name" in extra:
            dynamic_model_name = extra["dynamic_model_name"]
            print('detected file loaded with dynamic model "%s"' % dynamic_model_name)
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]
            print('  for smoothing, will use dynamic model "%s"' % dynamic_model_name)
        else:
            print(
                "no dynamic model name specified, and it could not be determined from the file, either"
            )

    filename_trimmed = os.path.split(os.path.splitext(filename)[0])[-1]

    assert use_obj_ids is not None

    #################
    rw = tvtk.RenderWindow(size=(1024, 768))

    ren = tvtk.Renderer(background=(1.0, 1.0, 1.0))
    camera = ren.active_camera

    rw.add_renderer(ren)

    lut = tvtk.LookupTable(hue_range=(0.667, 0.0))
    #################
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # With a single object the path grows during the movie; otherwise all
    # objects are drawn at once.
    animate_path = len(use_obj_ids) == 1

    obj_verts = []
    speeds = []
    real_frames = []
    for obj_id in use_obj_ids:

        print("loading %d" % obj_id)
        results = ca.calculate_trajectory_metrics(
            obj_id,
            data_file,
            use_kalman_smoothing=use_kalman_smoothing,
            frames_per_second=data_fps,
            dynamic_model_name=dynamic_model_name,
            method_params={"downsample": 1,},
        )
        obj_verts.append(results["X_kalmanized"])
        speeds.append(results["speed_kalmanized"])
        real_frames.append(results["frame"])

    if animate_path:
        obj_verts = obj_verts[0]
        speeds = speeds[0]
        real_frames = real_frames[0]
    else:
        obj_verts = numpy.concatenate(obj_verts, axis=0)
        speeds = numpy.concatenate(speeds, axis=0)
        real_frames = numpy.concatenate(real_frames, axis=0)

    # Filter only after the per-object arrays were combined, so the boolean
    # mask can be applied in the multi-object case as well.
    if options.start is not None:
        good_cond = real_frames >= options.start
        obj_verts = obj_verts[good_cond]
        speeds = speeds[good_cond]
        real_frames = real_frames[good_cond]

    ####################### start draw permanently on stuff ############################

    if options.stim_xml is not None:

        if not is_mat_file:
            R = reconstruct.Reconstructor(data_file)
            stim_xml.verify_reconstructor(R)
            assert data_file.filename.startswith("DATA") and (
                data_file.filename.endswith(".h5")
                or data_file.filename.endswith(".kh5")
            )
        actors = stim_xml.get_tvtk_actors()
        for actor in actors:
            ren.add_actor(actor)

    axes = tvtk.AxesActor()
    axes.set(
        shaft_type="cylinder", total_length=(0.15, 0.15, 0.15),
    )
    axes.x_axis_label_text = ""
    axes.y_axis_label_text = ""
    axes.z_axis_label_text = ""
    axes.origin = (-0.5, 1, 0)
    ren.add_actor(axes)

    #######################

    if max_vel == "auto":
        max_vel = speeds.max()
    else:
        max_vel = float(max_vel)
    vel_mapper = tvtk.PolyDataMapper()
    vel_mapper.lookup_table = lut
    vel_mapper.scalar_range = 0.0, max_vel

    # Create a scalar bar
    if vertical_scale:
        scalar_bar = tvtk.ScalarBarActor(
            orientation="vertical", width=0.08, height=0.4
        )
    else:
        scalar_bar = tvtk.ScalarBarActor(
            orientation="horizontal", width=0.4, height=0.08
        )
    scalar_bar.title = "Speed (m/s)"
    scalar_bar.lookup_table = vel_mapper.lookup_table

    scalar_bar.property.color = 0.0, 0.0, 0.0  # black

    scalar_bar.title_text_property.color = 0.0, 0.0, 0.0
    scalar_bar.title_text_property.shadow = False

    scalar_bar.label_text_property.color = 0.0, 0.0, 0.0
    scalar_bar.label_text_property.shadow = False

    scalar_bar.position_coordinate.coordinate_system = "normalized_viewport"
    if vertical_scale:
        scalar_bar.position_coordinate.value = 0.01, 0.01, 0.0
    else:
        scalar_bar.position_coordinate.value = 0.1, 0.01, 0.0

    ren.add_actor(scalar_bar)

    imf = tvtk.WindowToImageFilter(input=rw, read_front_buffer="off")
    writer = tvtk.PNGWriter()

    ####################### end draw permanently on stuff ############################

    save_dt = 1.0 / save_fps

    if animate_path:
        data_dt = 1.0 / data_fps
        n_frames = len(obj_verts)
        dur = n_frames * data_dt
    else:
        data_dt = 0.0
        dur = 0.0

    print("data_fps", data_fps)
    print("data_dt", data_dt)
    print("save_fps", save_fps)

    t_now = 0.0
    frame_number = 0
    while t_now <= dur:
        frame_number += 1
        t_now += save_dt
        print("t_now", t_now)

        pos, ori = camera_animation_path.get_pos_ori(t_now)
        focal_point, view_up = pos_ori2fu(pos, ori)

        camera.position = tuple(pos)
        camera.focal_point = tuple(focal_point)
        camera.view_up = tuple(view_up)

        if data_dt != 0.0:
            draw_n_frames = int(round(t_now / data_dt))
        else:
            draw_n_frames = len(obj_verts)
        print("frame_number, draw_n_frames", frame_number, draw_n_frames)

        #################

        pd = tvtk.PolyData()
        pd.points = obj_verts[:draw_n_frames]
        real_frame_number = real_frames[:draw_n_frames][-1]
        # One scalar per drawn point.
        pd.point_data.scalars = speeds[:draw_n_frames]
        if numpy.any(speeds > max_vel):
            print(
                "WARNING: maximum speed (%.3f m/s) exceeds color map max"
                % (speeds.max(),)
            )

        g = tvtk.Glyph3D(
            scale_mode="data_scaling_off", vector_mode="use_vector", input=pd
        )
        vel_mapper.input = g.output
        ss = tvtk.SphereSource(radius=options.radius)
        g.source = ss.output
        a = tvtk.Actor(mapper=vel_mapper)

        ##################

        ren.add_actor(a)

        imf.update()
        imf.modified()
        writer.input = imf.output
        fname = "movie_%s_%03d_frame%05d.png" % (
            filename_trimmed,
            obj_id,
            real_frame_number,
        )
        print("saving", fname)
        full_fname = os.path.join(output_dir, fname)
        writer.file_name = full_fname
        writer.write()

        ren.remove_actor(a)

    ren.add_actor(a)  # restore actors removed
    dur = dur + cam_only_move_duration

    # Camera-only phase: the data stays fixed while the camera keeps moving.
    while t_now < dur:
        frame_number += 1
        t_now += save_dt
        print("t_now", t_now)

        pos, ori = camera_animation_path.get_pos_ori(t_now)
        focal_point, view_up = pos_ori2fu(pos, ori)
        camera.position = tuple(pos)
        camera.focal_point = tuple(focal_point)
        camera.view_up = tuple(view_up)
        imf.update()
        imf.modified()
        writer.input = imf.output
        if len(use_obj_ids) == 1:
            fname = "movie_%s_%03d_frame%05d.png" % (
                filename_trimmed,
                obj_id,
                frame_number,
            )
        else:
            fname = "movie_%s_many_frame%05d.png" % (filename_trimmed, frame_number)
        full_fname = os.path.join(output_dir, fname)
        writer.file_name = full_fname
        writer.write()

    if not is_mat_file:
        data_file.close()
示例#6
0
def plot_timeseries(subplot=None, options=None):
    """Plot per-object time series from a Kalman-filtered data file.

    For every selected obj_id the trajectory is split into a "flying" and a
    "walking" segment (walking bouts come from the fanout named by
    ``options.stim_xml``; without one, everything is treated as flying) and
    each segment is drawn into whichever panels are present in ``subplot``.

    Parameters
    ----------
    subplot : dict
        Maps panel names to matplotlib axes. Recognised keys are 'frame',
        'P55', 'x', 'y', 'z', 'dx', 'dy', 'dz', 'z_hist', 'vel', 'xy_vel',
        'accel' and 'vel_hist'; panels that are absent are skipped.
    options : object
        Attribute bag (e.g. parsed command-line options). Read directly:
        kalman_filename, start, stop, obj_only, fps, dynamic_model,
        use_kalman_smoothing, stim_xml, smooth_orientations and up_dir, plus
        fuse (only when stim_xml is set) and show_3d_orientations (only when
        fusing). frames, show_landing, unicolor, show_obj_id,
        show_track_ends, timestamp_file and ori_qual are optional; missing
        ones are set on ``options`` itself to their defaults.

    Returns
    -------
    dict
        Maps each registered matplotlib line to the obj_id it shows.

    Raises
    ------
    ValueError
        If ``options.kalman_filename`` is None, or the first plotted frame
        is missing from ``options.timestamp_file``.

    Calls ``sys.exit(0)`` when no obj_ids are left to plot.
    """
    import os
    import warnings

    # Optional attributes; missing ones are written back onto ``options``.
    for attr_name, default in (('frames', False),
                               ('show_landing', False),
                               ('unicolor', False),
                               ('show_obj_id', True),
                               ('show_track_ends', False),
                               ('timestamp_file', None),
                               ('ori_qual', None)):
        if not hasattr(options, attr_name):
            setattr(options, attr_name, default)

    kalman_filename = options.kalman_filename
    start = options.start
    stop = options.stop
    obj_only = options.obj_only
    fps = options.fps
    dynamic_model = options.dynamic_model
    use_kalman_smoothing = options.use_kalman_smoothing

    if not use_kalman_smoothing and dynamic_model is not None:
        sys.stderr.write('WARNING: disabling Kalman smoothing '
                         '(--disable-kalman-smoothing) is incompatable '
                         'with setting dynamic model options (--dynamic-model)'
                         '\n')

    if kalman_filename is None:
        raise ValueError('No kalman_filename given. Nothing to do.')

    ca = core_analysis.get_global_CachingAnalyzer()

    # Hash in 1 MiB chunks so the data file is never held in memory whole,
    # and close the handle deterministically.
    m = hashlib.md5()
    with open(kalman_filename, mode='rb') as fd:
        for chunk in iter(lambda: fd.read(1 << 20), b''):
            m.update(chunk)
    actual_md5 = m.hexdigest()

    (obj_ids, use_obj_ids, is_mat_file, data_file,
     extra) = ca.initial_file_load(kalman_filename)
    print('opened kalman file %s %s, %d obj_ids' %
          (kalman_filename, actual_md5, len(use_obj_ids)))

    if 'frames' in extra and (start is not None or stop is not None):
        valid_frames = np.ones((len(extra['frames']), ), dtype=bool)
        if start is not None:
            valid_frames &= extra['frames'] >= start
        if stop is not None:
            valid_frames &= extra['frames'] <= stop
        this_use_obj_ids = np.unique(obj_ids[valid_frames])
        use_obj_ids = list(set(use_obj_ids).intersection(this_use_obj_ids))

    # The timestamp is characters 4..18 of the file's *base* name; slicing a
    # path that contains directories would pick up the wrong characters.
    # NOTE(review): assumes names like "DATA20080528_204034..." -- confirm.
    file_timestamp = os.path.basename(kalman_filename)[4:19]

    do_fuse = False
    walking_start_stops = []
    if options.stim_xml:
        fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml)
        include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
            timestamp_string=file_timestamp)
        walking_start_stops = fanout.get_walking_start_stops_for_timestamp(
            timestamp_string=file_timestamp)
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))
        do_fuse = bool(options.fuse)

    if dynamic_model is None:
        dynamic_model = extra['dynamic_model_name']
        print('detected file loaded with dynamic model "%s"' % dynamic_model)
        if dynamic_model.startswith('EKF '):
            dynamic_model = dynamic_model[4:]
        print('  for smoothing, will use dynamic model "%s"' % dynamic_model)

    if fps is None and not is_mat_file:
        fps = result_utils.get_fps(data_file, fail_on_error=False)
    if fps is None:
        # Applies to every input type so that dt below is always defined.
        fps = 100.0
        warnings.warn('Setting fps to default value of %f' % fps)
    dt = 1.0 / fps

    # No time-zone information is read for .mat input.
    tz = None if is_mat_file else result_utils.get_tz(data_file)

    if obj_only is not None:
        use_obj_ids = [i for i in use_obj_ids if i in obj_only]

    if options.timestamp_file is not None:
        h5 = tables.open_file(options.timestamp_file, mode='r')
        try:
            print('reading timestamps and frames')
            table_data2d_frames = h5.root.data2d_distorted.read(field='frame')
            table_data2d_timestamps = h5.root.data2d_distorted.read(
                field='timestamp')
            print('done')
        finally:
            h5.close()
        table_data2d_frames_find = utils.FastFinder(table_data2d_frames)

    if len(use_obj_ids) == 0:
        print('No obj_ids to plot, quitting')
        sys.exit(0)

    line2obj_id = {}

    def plot_track(ax, xs, ys, obj_id, label, kws):
        """Plot one segment on *ax*, remember its obj_id for picking and
        store its colour in *kws* so later panels of the segment reuse it."""
        line, = ax.plot(xs, ys, label=label, **kws)
        line2obj_id[line] = obj_id
        kws['color'] = line.get_color()
        return line

    all_vels = []
    Xz_all = []
    frame0 = None
    time0 = 0.0  # set default value
    fuse_did_once = False

    for obj_id in use_obj_ids:
        if not do_fuse:
            try:
                kalman_rows = ca.load_data(
                    obj_id,
                    data_file,
                    use_kalman_smoothing=use_kalman_smoothing,
                    dynamic_model_name=dynamic_model,
                    return_smoothed_directions=options.smooth_orientations,
                    frames_per_second=fps,
                    up_dir=options.up_dir,
                    min_ori_quality_required=options.ori_qual,
                )
            except core_analysis.ObjectIDDataError:
                continue
        else:
            if options.show_3d_orientations:
                raise NotImplementedError('orientation data is not supported '
                                          'when fusing obj_ids')
            if fuse_did_once:
                break
            # All obj_ids are fused into one trajectory, so load only once.
            fuse_did_once = True
            kalman_rows = flydra_analysis.a2.flypos.fuse_obj_ids(
                use_obj_ids,
                data_file,
                dynamic_model_name=dynamic_model,
                frames_per_second=fps)

        if (start is not None) or (stop is not None):
            frame = kalman_rows['frame']
            valid_cond = numpy.ones(frame.shape, dtype=bool)
            if start is not None:
                valid_cond = valid_cond & (frame >= start)
            if stop is not None:
                valid_cond = valid_cond & (frame <= stop)
            kalman_rows = kalman_rows[valid_cond]
            if not len(kalman_rows):
                continue

        # Keep the unsplit rows; each flight state below selects from them.
        walking_and_flying_kalman_rows = kalman_rows
        all_frames = walking_and_flying_kalman_rows['frame']

        for flystate in ['flying', 'walking']:
            if len(walking_start_stops):
                # assume flying unless a walking bout covers the frame
                if flystate == 'flying':
                    state_cond = numpy.ones(all_frames.shape, dtype=bool)
                else:
                    state_cond = numpy.zeros(all_frames.shape, dtype=bool)
                for walkstart, walkstop in walking_start_stops:
                    # handle each bout of walking (None = open-ended)
                    walking_bout = numpy.ones(all_frames.shape, dtype=bool)
                    if walkstart is not None:
                        walking_bout &= (all_frames >= walkstart)
                    if walkstop is not None:
                        walking_bout &= (all_frames <= walkstop)
                    if flystate == 'flying':
                        state_cond &= ~walking_bout
                    else:
                        state_cond |= walking_bout
                kalman_rows = np.take(walking_and_flying_kalman_rows,
                                      np.nonzero(state_cond)[0])
                assert len(kalman_rows) == np.sum(state_cond)
            elif flystate == 'flying':
                kalman_rows = walking_and_flying_kalman_rows
            else:
                # No walking bouts known: everything was plotted as flying.
                continue

            if not len(kalman_rows):
                continue  # nothing to plot for this flight state
            frame = kalman_rows['frame']

            if frame0 is None:
                frame0 = int(frame[0])

            time0 = 0.0
            if options.timestamp_file is not None:
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(frame0)
                if len(frame_idxs):
                    time0 = table_data2d_timestamps[frame_idxs[0]]
                else:
                    raise ValueError(
                        'could not fine frame %d in timestamp file' % frame0)

            Xx = kalman_rows['x']
            Xy = kalman_rows['y']
            Xz = kalman_rows['z']

            Dx = Dy = Dz = None
            if options.smooth_orientations:
                Dx = kalman_rows['dir_x']
                Dy = kalman_rows['dir_y']
                Dz = kalman_rows['dir_z']
            elif 'rawdir_x' in kalman_rows.dtype.fields:
                Dx = kalman_rows['rawdir_x']
                Dy = kalman_rows['rawdir_y']
                Dz = kalman_rows['rawdir_z']

            if options.frames:
                def f2t(x):
                    return x
            else:
                f2t = Frames2Time(frame0, fps, time0)

            kws = {
                'linewidth': 2,
                'picker': 5,
            }
            if options.unicolor:
                kws['color'] = 'k'
            label = 'obj %d (%s)' % (obj_id, flystate)
            t_frame = f2t(frame)

            if 'frame' in subplot:
                subplot['frame'].plot(t_frame, frame)

            if 'P55' in subplot:
                subplot['P55'].plot(t_frame, kalman_rows['P55'])

            if 'x' in subplot:
                plot_track(subplot['x'], t_frame, Xx, obj_id, label, kws)

            if 'y' in subplot:
                plot_track(subplot['y'], t_frame, Xy, obj_id, label, kws)

            if 'z' in subplot:
                ax_z = subplot['z']
                # getdata works whether or not the arrays are masked
                frame_data = numpy.ma.getdata(frame)
                z_data = numpy.ma.getdata(Xz)

                # plot landing time (only once, on the flying pass)
                if options.show_landing and flystate == 'flying':
                    # A walking bout contains its own start frame, so the
                    # flying rows never do; search the unsplit rows instead.
                    all_frame_data = numpy.ma.getdata(all_frames)
                    all_z_data = numpy.ma.getdata(
                        walking_and_flying_kalman_rows['z'])
                    for walkstart, _ in walking_start_stops:
                        if walkstart is None:
                            continue
                        landing_idxs = numpy.nonzero(
                            all_frame_data == walkstart)[0]
                        if len(landing_idxs):
                            ax_z.plot([f2t(walkstart)],
                                      [all_z_data[landing_idxs[0]]],
                                      'rD',
                                      ms=10,
                                      label='landing')

                if options.show_track_ends and flystate == 'flying':
                    # array (not list) so a numeric frame->time map applies
                    ends = numpy.asarray([frame_data[0], frame_data[-1]])
                    ax_z.plot(f2t(ends), [z_data[0], z_data[-1]],
                              'cd',
                              ms=6,
                              label='track end')

                plot_track(ax_z, t_frame, Xz, obj_id, label, kws)

                if flystate == 'flying' and options.show_obj_id:
                    ax_z.text(f2t(frame_data[0]), z_data[0],
                              '%d' % (obj_id, ))

            if flystate == 'flying':
                Xz_all.append(np.ma.array(Xz).compressed())

            for dir_var, Dd in (('dx', Dx), ('dy', Dy), ('dz', Dz)):
                if dir_var in subplot:
                    plot_track(subplot[dir_var], t_frame, Dd, obj_id, label,
                               kws)

            # Stack the (possibly masked) coordinates into a 3 x N array.
            # numpy.ma.vstack is used instead of numpy.ma.array(tuple); see
            # http://scipy.org/scipy/numpy/ticket/820 for older numpy.
            X = numpy.ma.vstack((Xx, Xy, Xz))

            # central differences: speed at frames[1:-1], accel at [2:-2]
            vel_central_diff = (X[:, 2:] - X[:, :-2]) / (2 * dt)
            vel2mag = numpy.ma.sqrt(numpy.ma.sum(vel_central_diff**2, axis=0))
            xy_vel2mag = numpy.ma.sqrt(
                numpy.ma.sum(vel_central_diff[:2, :]**2, axis=0))
            frames2 = frame[1:-1]

            accel4mag = (vel2mag[2:] - vel2mag[:-2]) / (2 * dt)
            frames4 = frames2[1:-1]

            if 'vel' in subplot:
                plot_track(subplot['vel'], f2t(frames2), vel2mag, obj_id,
                           label, kws)

            if 'xy_vel' in subplot:
                plot_track(subplot['xy_vel'], f2t(frames2), xy_vel2mag,
                           obj_id, label, kws)

            if len(accel4mag.compressed()) and 'accel' in subplot:
                plot_track(subplot['accel'], f2t(frames4), accel4mag, obj_id,
                           label, kws)

            if flystate == 'flying':
                all_vels.append(vel2mag.compressed())

    if all_vels:
        all_vels = numpy.hstack(all_vels)
    else:
        all_vels = numpy.array([], dtype=float)

    cond = all_vels < 2.0
    if numpy.ma.sum(cond) != len(all_vels):
        all_vels = all_vels[cond]
        warnings.warn('clipping all velocities > 2.0 m/s')

    xlabel = 'frame' if options.frames else 'time (s)'

    for ax in subplot.values():
        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%s"))

    # Absolute-time ticks are used only when a timestamp file produced a
    # nonzero time0, so the fixer (which needs tz) is built only then.
    fixup_ax = FixupAxesWithTimeZone(tz).fixup_ax if time0 != 0.0 else None

    def finish_x_axis(ax):
        """Apply absolute-time ticks if available, else the plain xlabel."""
        if fixup_ax is not None:
            fixup_ax(ax)
        else:
            ax.set_xlabel(xlabel)

    if 'frame' in subplot:
        finish_x_axis(subplot['frame'])

    if 'x' in subplot:
        subplot['x'].set_ylim([-1, 1])
        subplot['x'].set_ylabel(r'x (m)')
        finish_x_axis(subplot['x'])

    if 'y' in subplot:
        subplot['y'].set_ylim([-0.5, 1.5])
        subplot['y'].set_ylabel(r'y (m)')
        finish_x_axis(subplot['y'])

    # height of the shortest post in the stimulus description, if any
    max_z = None
    if options.stim_xml:
        stim_xml = xml_stimulus.xml_stimulus_from_filename(
            options.stim_xml, timestamp_string=file_timestamp)
        post_max_zs = [
            max(post['verts'][0][2], post['verts'][1][2])  # max post height
            for post in stim_xml.iterate_posts()
        ]
        if post_max_zs:
            max_z = min(post_max_zs)  # take shortest of posts

    if 'z' in subplot:
        subplot['z'].set_ylim([0, 1])
        subplot['z'].set_ylabel(r'z (m)')
        if max_z is not None:
            subplot['z'].axhline(max_z, color='m')
        finish_x_axis(subplot['z'])

    for dir_var in ['dx', 'dy', 'dz']:
        if dir_var in subplot:
            subplot[dir_var].set_ylabel(dir_var)
            finish_x_axis(subplot[dir_var])

    if 'z_hist' in subplot:
        ax = subplot['z_hist']
        if Xz_all:
            ax.hist(np.hstack(Xz_all),
                    bins=np.linspace(0, .8, 30),
                    orientation='horizontal')
        ax.set_xticks([])
        ax.set_yticks([])
        xlim = tuple(ax.get_xlim())  # matplotlib 0.98.3 returned np.array view
        ax.set_xlim((xlim[1], xlim[0]))
        if max_z is not None:
            ax.axhline(max_z, color='m')

    if 'vel' in subplot:
        subplot['vel'].set_ylim([0, 2])
        subplot['vel'].set_ylabel(r'vel (m/s)')
        subplot['vel'].set_xlabel(xlabel)
        finish_x_axis(subplot['vel'])

    if 'xy_vel' in subplot:
        subplot['xy_vel'].set_ylabel(r'horiz vel (m/s)')
        subplot['xy_vel'].set_xlabel(xlabel)
        finish_x_axis(subplot['xy_vel'])

    if 'accel' in subplot:
        subplot['accel'].set_ylabel(r'acceleration (m/(s^2))')
        subplot['accel'].set_xlabel(xlabel)
        finish_x_axis(subplot['accel'])

    if 'vel_hist' in subplot:
        ax = subplot['vel_hist']
        ax.set_title('excluding walking')
        # NOTE(review): `normed` was removed in matplotlib 3.1 (use
        # `density`); kept for the old matplotlib this file targets.
        ax.hist(all_vels, bins=numpy.linspace(0, 2, 50), normed=True)
        ax.set_xlim(0, 2)
        ax.set_ylabel('probability density')
        ax.set_xlabel('velocity (m/s)')

    return line2obj_id