Ejemplo n.º 1
0
def load_3d_raw_data(kalman_filename, require_qual=True, **kwargs):
    """Load the dynamics-free MLE position rows of every object in a file.

    Parameters
    ----------
    kalman_filename
        File handed to ``open_file_safe`` and to the caching analyzer's
        ``initial_file_load``.
    require_qual : bool
        When true, an orientation-quality array is computed for each
        object and returned alongside the position rows.
    **kwargs
        Must contain ``min_ori_quality_required`` and
        ``ori_quality_smooth_len``; both are forwarded to
        ``load_dynamics_free_MLE_position``. Other keys are ignored.

    Returns
    -------
    ``data3d`` (the per-object rows concatenated), or the tuple
    ``(data3d, dataqual3d)`` when ``require_qual`` is true.

    Raises
    ------
    KeyError
        If one of the two required keyword arguments is missing.
    """
    with open_file_safe(kalman_filename, mode='r') as kh5:
        analyzer = core_analysis.get_global_CachingAnalyzer()
        loaded = analyzer.initial_file_load(kalman_filename)
        _all_obj_ids, obj_ids, _is_mat_file, data_file, _extra = loaded

        position_kwargs = dict(
            min_ori_quality_required=kwargs['min_ori_quality_required'],
            ori_quality_smooth_len=kwargs['ori_quality_smooth_len'],
        )

        per_obj_rows = []
        per_obj_quality = []
        for obj_id in obj_ids:
            rows = analyzer.load_dynamics_free_MLE_position(
                obj_id, data_file=kh5, **position_kwargs)
            per_obj_rows.append(rows)
            if not require_qual:
                continue

            # Only ask for a quality estimate when at least one row carries
            # an orientation (non-NaN 'hz_line0'); otherwise use zeros.
            has_orientation = np.any(~np.isnan(rows['hz_line0']))
            if has_orientation:
                quality = compute_ori_quality(
                    data_file, rows['frame'], obj_id, smooth_len=0)
            else:
                quality = np.zeros_like(rows['hz_line0'])
            per_obj_quality.append(quality)

    data3d = np.concatenate(per_obj_rows)
    if not require_qual:
        return data3d
    return data3d, np.concatenate(per_obj_quality)
Ejemplo n.º 2
0
def load_3d_data(kalman_filename,
                 start=None,
                 stop=None,
                 require_qual=True,
                 **kwargs):
    """Load Kalman-smoothed rows for every object within a frame range.

    Parameters
    ----------
    kalman_filename
        File handed to ``open_file_safe``, ``initial_file_load`` and
        ``load_data``.
    start, stop
        Passed to ``is_obj_in_frame_range`` to decide which objects are
        loaded; objects outside the range are skipped.
    require_qual : bool
        When true, an orientation-quality array is computed for each
        loaded object and returned alongside the rows.
    **kwargs
        Forwarded unchanged to ``load_data``.

    Returns
    -------
    ``data3d`` (the per-object rows concatenated), or the tuple
    ``(data3d, dataqual3d)`` when ``require_qual`` is true.

    Objects for which ``load_data`` raises ``NotEnoughDataToSmoothError``
    are skipped with a warning.
    """
    with open_file_safe(kalman_filename, mode="r") as kh5:
        analyzer = core_analysis.get_global_CachingAnalyzer()
        loaded = analyzer.initial_file_load(kalman_filename)
        all_obj_ids, obj_ids, _is_mat_file, data_file, extra = loaded
        fps = extra["frames_per_second"]

        # Prefer the model recorded in the file; fall back to the default.
        model_name = extra.get("dynamic_model_name", None)
        if model_name is not None:
            print('detected file loaded with dynamic model "%s"' %
                  model_name)
        else:
            model_name = dynamic_models.DEFAULT_MODEL
            warnings.warn('no dynamic model specified, using "%s"' %
                          model_name)
        # The smoother is given the model name without the "EKF " prefix.
        if model_name.startswith("EKF "):
            model_name = model_name[4:]
        print('  for smoothing, will use dynamic model "%s"' % model_name)

        per_obj_rows = []
        per_obj_quality = []
        for obj_id in obj_ids:
            in_range = is_obj_in_frame_range(
                obj_id, all_obj_ids, extra["frames"], start=start, stop=stop)
            if not in_range:
                # obj_id not in range of frames that we're analyzing now
                continue

            try:
                rows = analyzer.load_data(
                    obj_id,
                    kalman_filename,
                    use_kalman_smoothing=True,
                    frames_per_second=fps,
                    dynamic_model_name=model_name,
                    return_smoothed_directions=True,
                    **kwargs)
            except core_analysis.NotEnoughDataToSmoothError:
                warnings.warn(
                    "not enough data to smooth obj_id %d, skipping." %
                    (obj_id, ))
                continue

            per_obj_rows.append(rows)
            if require_qual:
                per_obj_quality.append(
                    compute_ori_quality(
                        data_file, rows["frame"], obj_id, smooth_len=0))

    data3d = np.concatenate(per_obj_rows)
    if not require_qual:
        return data3d
    return data3d, np.concatenate(per_obj_quality)
Ejemplo n.º 3
0
def compute_ori_quality(h5_context, orig_frames, obj_id, smooth_len=10):
    """Compute a per-frame quality score for the orientation estimate.

    For every frame in ``orig_frames`` the score is the sum of the
    ``used<camn>`` columns of that frame's row in the object's fit table
    (presumably the number of cameras that contributed to the fit --
    confirm against the code that writes the table). Frames with no row
    in the table get NaN.

    Parameters
    ----------
    h5_context
        Object providing ``get_or_make_group_for_obj`` and ``filename``.
    orig_frames
        Frame numbers to score; the result has one entry per frame.
    obj_id : int
        Object whose ``obj<obj_id>`` table is read.
    smooth_len : int
        If non-zero and shorter than the result, the scores are passed
        through ``smooth`` with this window length.

    Returns
    -------
    1-D float array with ``len(orig_frames)`` entries.
    """
    ca = core_analysis.get_global_CachingAnalyzer()
    group = h5_context.get_or_make_group_for_obj(obj_id)
    try:
        table = getattr(group, 'obj%d' % obj_id)
    except:
        # Give the user a hint on stderr, then let the error propagate.
        sys.stderr.write(
            'ERROR while getting EKF fit data for obj_id %d in file opening %s\n'
            % (obj_id, h5_context.filename))
        sys.stderr.write(
            'Hint: re-run orientation fitting for this file (for this obj_id).\n'
        )
        raise
    table_ram = table[:]
    table_frames = table_ram['frame']

    # Camera numbers are recovered from the 'dist<camn>' column names.
    camns = sorted(
        int(colname[4:]) for colname in table.colnames
        if colname.startswith('dist'))
    used_colnames = ['used%d' % camn for camn in camns]

    # start at zero quality
    results = np.zeros((len(orig_frames), ))
    for origi, frame in enumerate(orig_frames):
        matches = np.nonzero(table_frames == frame)[0]
        if len(matches) == 0:
            results[origi] = np.nan
            continue

        assert len(matches) == 1
        row = table_ram[matches[0]]
        results[origi] = np.sum(
            np.array([row[colname] for colname in used_colnames]))

    if smooth_len and len(results) > smooth_len:
        results = smooth(results, window_len=smooth_len)
    return results
Ejemplo n.º 4
0
    def __init__(self, src_h5, dst_h5):
        """Register both files with the caching analyzer and index the
        destination file's frames.

        ``src_h5`` and ``dst_h5`` are each passed to ``initial_file_load``;
        only the destination's result is kept (its object ids, unique
        object ids and ``extra["frames"]``), and a ``utils.FastFinder`` is
        built over the destination frames.
        """
        self.ca = core_analysis.get_global_CachingAnalyzer()
        self.src_h5 = src_h5
        self.dst_h5 = dst_h5

        # The source file is loaded for its effect on the analyzer only;
        # the returned values are not used here.
        self.ca.initial_file_load(self.src_h5)

        dst_loaded = self.ca.initial_file_load(self.dst_h5)
        (dst_obj_ids, dst_unique_obj_ids, _is_mat_file, _data_file,
         dst_extra) = dst_loaded

        self.dst_frames = dst_extra["frames"]
        self.dst_obj_ids = dst_obj_ids
        self.dst_unique_obj_ids = dst_unique_obj_ids
        self.ff = utils.FastFinder(self.dst_frames)
Ejemplo n.º 5
0
def consider_stimulus(h5file,
                      verbose_problems=False,
                      fanout_name="fanout.xml"):
    """Parse the fanout XML next to ``h5file`` and pick obj_ids and stimulus.

    Parameters
    ----------
    h5file
        Path of the data file; the fanout XML is looked up in the same
        directory.
    verbose_problems : bool
        When true, expected problems (missing fanout file,
        ``WrongXMLTypeError``, ``ValueError``) are logged as errors.
        Unexpected exceptions are always logged.
    fanout_name : str
        File name of the fanout XML inside the data file's directory.

    Returns
    -------
    tuple ``(valid, use_obj_ids, stimulus)``
        ``valid`` is False if something went wrong, in which case the
        other two values are None. Otherwise ``use_obj_ids`` is the
        include list given by the fanout (or every obj_id in the file when
        there is none) minus the fanout's exclude list, and ``stimulus`` is
        the third element of the episode found for the file's timestamp.
    """
    # Bound before the try block so that the exception handlers can always
    # build their message, even when the failure happens before the
    # timestamp has been extracted from the file name.
    file_timestamp = None

    try:
        dirname = os.path.dirname(h5file)
        fanout_xml = os.path.join(dirname, fanout_name)
        if not os.path.exists(fanout_xml):
            if verbose_problems:
                # The missing path is the XML file; the data file is h5file.
                logger.error("Stim_xml path not found '%s' for file '%s'" %
                             (fanout_xml, h5file))
            return False, None, None

        ca = core_analysis.get_global_CachingAnalyzer()
        (_, use_obj_ids, _, _, _) = ca.initial_file_load(h5file)

        file_timestamp = timestamp_string_from_filename(h5file)

        fanout = xml_stimulus.xml_fanout_from_filename(fanout_xml)
        include_obj_ids, exclude_obj_ids = \
            fanout.get_obj_ids_for_timestamp(timestamp_string=file_timestamp)
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids))

        episode = fanout._get_episode_for_timestamp(
            timestamp_string=file_timestamp)
        (_, _, stim_fname) = episode
        return True, use_obj_ids, stim_fname

    except xml_stimulus.WrongXMLTypeError:
        if verbose_problems:
            # Fall back to the file name when no timestamp is known yet.
            label = h5file if file_timestamp is None else file_timestamp
            logger.error("Caught WrongXMLTypeError for '%s'" % label)
        return False, None, None
    except ValueError as ex:
        if verbose_problems:
            label = h5file if file_timestamp is None else file_timestamp
            logger.error("Caught ValueError for '%s': %s" % (label, ex))
        return False, None, None
    except Exception as ex:
        # Deliberate best-effort boundary: report and signal "not valid".
        logger.error('Not predicted exception while reading %s; %s' %
                     (h5file, ex))
        return False, None, None
Ejemplo n.º 6
0
def test_retracking_without_data_association():
    """Retrack the packaged sample file and compare it with the original.

    ``retrack_reuse_data_association`` writes a retracked copy into a
    temporary directory. For every retracked obj_id except the first and
    the last, the unsmoothed track must start at the same frame as the
    original track and must not be longer than it.
    """
    ca = core_analysis.get_global_CachingAnalyzer()

    orig_fname = pkg_resources.resource_filename(
        'flydra_analysis.a2', 'sample_datafile-v0.4.28.h5')

    tmpdir = tempfile.mkdtemp()
    try:
        retracked_fname = os.path.join(tmpdir, 'retracked.h5')
        retrack_reuse_data_association(
            h5_filename=orig_fname,
            output_h5_filename=retracked_fname,
            kalman_filename=orig_fname,
        )
        with ca.kalman_analysis_context(orig_fname) as orig_ctx:
            _orig_obj_ids = orig_ctx.get_unique_obj_ids()
            extra = orig_ctx.get_extra_info()

            with ca.kalman_analysis_context(retracked_fname) as new_ctx:
                retracked_obj_ids = new_ctx.get_unique_obj_ids()

                assert len(retracked_obj_ids) > 10

                # Both files are read with identical, unsmoothed settings.
                load_kwargs = dict(
                    use_kalman_smoothing=False,
                    dynamic_model_name=extra['dynamic_model_name'],
                    frames_per_second=extra['frames_per_second'],
                )

                # Cycle over retracked obj_ids, which may be a subset of
                # the original ones (due to missing 2D data).
                for obj_id in retracked_obj_ids[1:-1]:
                    retracked_rows = new_ctx.load_data(obj_id, **load_kwargs)
                    orig_rows = orig_ctx.load_data(obj_id, **load_kwargs)

                    # The tracks start at the same frame...
                    first_new = retracked_rows['frame'][0]
                    first_old = orig_rows['frame'][0]
                    assert first_new == first_old
                    # ...and the retracked one is no longer than the
                    # original (it may be shorter).
                    assert len(retracked_rows) <= len(orig_rows)
    finally:
        shutil.rmtree(tmpdir)
Ejemplo n.º 7
0
def calculate_skipped_frames(
    h5_filename=None,
    output_h5_filename=None,
    kalman_filename=None,
):
    """Find the frame gaps in every object's track and save them to HDF5.

    For each obj_id in ``kalman_filename`` the dynamics-free MLE position
    rows are walked in order; whenever two consecutive rows are more than
    one frame apart, a record ``(obj_id, start_frame, stop_frame,
    duration)`` is collected, where ``duration`` is the frame difference.
    The records are appended as the ``skipped_info`` table of a new pandas
    HDF store at ``output_h5_filename``.

    Parameters
    ----------
    h5_filename
        Not used by this function.
    output_h5_filename
        Destination file; must not exist yet.
    kalman_filename
        File opened through the caching analyzer's
        ``kalman_analysis_context``.

    Raises
    ------
    RuntimeError
        If ``output_h5_filename`` already exists.
    """
    if os.path.exists(output_h5_filename):
        raise RuntimeError("will not overwrite old file '%s'" %
                           output_h5_filename)

    pre_df = {
        'obj_id': [],
        'start_frame': [],
        'stop_frame': [],
        'duration': []
    }
    ca = core_analysis.get_global_CachingAnalyzer()
    # Only the per-object position rows are needed; the reconstructor and
    # the 'ML_estimates_2d_idxs' table are deliberately not loaded.
    with ca.kalman_analysis_context(kalman_filename) as h5_context:
        for obj_id in h5_context.get_unique_obj_ids():
            obj_3d_rows = h5_context.load_dynamics_free_MLE_position(obj_id)
            prev_frame = None
            for this_3d_row in obj_3d_rows:
                # iterate over each 3D row of the current object
                framenumber = this_3d_row['frame']
                if prev_frame is not None and framenumber - prev_frame > 1:
                    pre_df['obj_id'].append(obj_id)
                    pre_df['start_frame'].append(prev_frame)
                    pre_df['stop_frame'].append(framenumber)
                    pre_df['duration'].append(framenumber - prev_frame)
                prev_frame = framenumber

    df = pd.DataFrame(pre_df)

    # save to disk; always close the store, even if the append fails
    store = pd.HDFStore(output_h5_filename)
    try:
        store.append('skipped_info', df)
    finally:
        store.close()
def calculate_reprojection_errors(h5_filename=None,
                                  output_h5_filename=None,
                                  kalman_filename=None,
                                  from_source=None,
                                  start=None,
                                  stop=None,
                                  show_progress=False,
                                  show_progress_json=False,
                                  ):
    """Collect the 2D reprojection distance of every 3D estimate.

    For each object and each of its dynamics-free MLE rows, the 3D position
    (either the MLE itself or the Kalman-smoothed position of the same
    frame) is reprojected into every camera listed for that row in
    ``ML_estimates_2d_idxs``. The Euclidean distance to the matching
    distorted 2D point is appended to ``out`` together with the camera
    number, frame, obj_id and the z coordinate.

    Parameters
    ----------
    h5_filename
        2D data file, given to the analysis context as ``data2d_fname``.
    output_h5_filename
        Must not exist yet.
    kalman_filename
        File opened through ``kalman_analysis_context``.
    from_source : {'ML_estimates', 'smoothed'}
        Which 3D position is reprojected.
    start, stop
        Optional inclusive frame limits; rows outside are skipped.
    show_progress : bool
        Display a progress bar.
    show_progress_json : bool
        Report rough progress through ``result_utils.do_json_progress``
        every 100 objects.

    Raises
    ------
    AssertionError
        If ``from_source`` is not one of the two accepted values.
    RuntimeError
        If ``output_h5_filename`` already exists.

    NOTE(review): the code visible in this block only fills ``out``;
    nothing here writes ``output_h5_filename`` or returns ``out``. Confirm
    whether a saving step belongs after the loop.
    """
    assert from_source in ['ML_estimates','smoothed']
    if os.path.exists( output_h5_filename ):
        raise RuntimeError(
            "will not overwrite old file '%s'"%output_h5_filename)

    out = {'camn':[],
           'frame':[],
           'obj_id':[],
           'dist':[],
           'z':[],
           }

    ca = core_analysis.get_global_CachingAnalyzer()
    with ca.kalman_analysis_context( kalman_filename,
                                     data2d_fname=h5_filename ) as h5_context:
        R = h5_context.get_reconstructor()
        ML_estimates_2d_idxs = h5_context.load_entire_table('ML_estimates_2d_idxs')
        use_obj_ids = h5_context.get_unique_obj_ids()

        extra = h5_context.get_extra_info()

        if from_source=='smoothed':
            # The smoother is given the model name without "EKF " prefix.
            dynamic_model_name = extra['dynamic_model_name']
            if dynamic_model_name.startswith('EKF '):
                dynamic_model_name = dynamic_model_name[4:]

        fps = h5_context.get_fps()
        camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

        # associate framenumbers with timestamps using 2d .h5 file
        data2d = h5_context.load_entire_table('data2d_distorted',
                                              from_2d_file=True )
        data2d_idxs = np.arange(len(data2d))
        h5_framenumbers = data2d['frame']
        h5_frame_qfi = result_utils.QuickFrameIndexer(h5_framenumbers)

        if show_progress:
            string_widget = StringWidget()
            objs_per_sec_widget = progressbar.FileTransferSpeed(unit='obj_ids ')
            widgets=[string_widget, objs_per_sec_widget,
                     progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()]
            pbar=progressbar.ProgressBar(widgets=widgets,maxval=len(use_obj_ids)).start()

        for obj_id_enum,obj_id in enumerate(use_obj_ids):
            if show_progress:
                string_widget.set_string( '[obj_id: % 5d]'%obj_id )
                pbar.update(obj_id_enum)
            if show_progress_json and obj_id_enum%100==0:
                rough_percent_done = float(obj_id_enum)/len(use_obj_ids)*100.0
                result_utils.do_json_progress(rough_percent_done)

            obj_3d_rows = h5_context.load_dynamics_free_MLE_position(obj_id)

            if from_source=='smoothed':

                # Stays None when smoothing is impossible for this object;
                # its positions are then reported as NaN below.
                smoothed_rows = None
                try:
                    smoothed_rows = h5_context.load_data(
                        obj_id,
                        use_kalman_smoothing=True,
                        dynamic_model_name = dynamic_model_name,
                        frames_per_second=fps,
                        )
                except core_analysis.NotEnoughDataToSmoothError:
                    # OK, we don't have data from this obj_id
                    pass
                except core_analysis.DiscontiguousFramesError:
                    pass

            for this_3d_row in obj_3d_rows:
                # iterate over each 3D row of the current object
                framenumber = this_3d_row['frame']
                if start is not None:
                    if not framenumber >= start:
                        continue
                if stop is not None:
                    if not framenumber <= stop:
                        continue
                h5_2d_row_idxs = h5_frame_qfi.get_frame_idxs(framenumber)
                if len(h5_2d_row_idxs) == 0:
                    # At the start, there may be 3d data without 2d data.
                    continue

                if from_source=='ML_estimates':
                    X3d = this_3d_row['x'], this_3d_row['y'], this_3d_row['z']
                elif from_source=='smoothed':
                    if smoothed_rows is None:
                        X3d = np.nan, np.nan, np.nan
                    else:
                        this_smoothed_rows = smoothed_rows[ smoothed_rows['frame']==framenumber ]
                        assert len(this_smoothed_rows) <= 1
                        if len(this_smoothed_rows) == 0:
                            X3d = np.nan, np.nan, np.nan
                        else:
                            X3d = this_smoothed_rows['x'][0], this_smoothed_rows['y'][0], this_smoothed_rows['z'][0]

                # If there was a 3D ML estimate, there must be 2D data.

                frame2d = data2d[h5_2d_row_idxs]
                frame2d_idxs = data2d_idxs[h5_2d_row_idxs]

                obs_2d_idx = this_3d_row['obs_2d_idx']
                kobs_2d_data = ML_estimates_2d_idxs[int(obs_2d_idx)]

                # Parse VLArray: entries alternate between a camera number
                # and the index of the point within that camera's frame.
                this_camns = kobs_2d_data[0::2]
                this_camn_idxs = kobs_2d_data[1::2]

                # Now, for each camera viewing this object at this
                # frame, compare the observed and the reprojected point.
                for camn, camn_pt_no in zip(this_camns, this_camn_idxs):
                    try:
                        cam_id = camn2cam_id[camn]
                    except KeyError:
                        warnings.warn('camn %d not found'%(camn, ))
                        continue

                    # find 2D point corresponding to object
                    cond = ((frame2d['camn']==camn) &
                            (frame2d['frame_pt_idx']==camn_pt_no))
                    idxs = np.nonzero(cond)[0]
                    if len(idxs)==0:
                        #no frame for that camera (start or stop of file)
                        continue
                    elif len(idxs)>1:
                        # Single pre-formatted string: prints the same text
                        # under both Python 2 and Python 3.
                        print("MEGA WARNING MULTIPLE 2D POINTS\n%s %s \n\n"
                              % (camn, camn_pt_no))
                        continue

                    idx = idxs[0]

                    frame2d_row = frame2d[idx]
                    x2d_real = frame2d_row['x'], frame2d_row['y']
                    x2d_reproj = R.find2d( cam_id, X3d, distorted = True )
                    dist = np.sqrt(np.sum((x2d_reproj - x2d_real)**2))

                    out['camn'].append(camn)
                    out['frame'].append(framenumber)
                    out['obj_id'].append(obj_id)
                    out['dist'].append(dist)
                    out['z'].append( X3d[2] )
Ejemplo n.º 9
0
def fuse_obj_ids(use_obj_ids, data_file,
                 dynamic_model_name = None,
                 frames_per_second=None):
    """Take multiple obj_id tracks and fuse them into one long trajectory.

    Current implementation
    ======================
    Load 'observations' (MLEs of fly positions) across all obj_ids,
    and then do Kalman smoothing across all this, which fills in gaps
    (but ignores single camera views).

    Parameters
    ----------
    use_obj_ids
        Iterable of obj_ids whose tracks are fused, in the given order.
    data_file
        Passed to ``load_dynamics_free_MLE_position``.
    dynamic_model_name
        Name given to ``get_kalman_model``.
    frames_per_second
        Frame rate; the model time step is ``1.0 / frames_per_second``, so
        a number must be supplied despite the ``None`` default.

    Returns
    -------
    numpy record array with fields ``frame, x, y, z, orig_data_present``,
    one row per frame from the first to the last collected frame.
    ``orig_data_present`` is True where an observation existed.

    Raises
    ------
    ValueError
        If no obj_id contributes any observation, or if the model has no
        linear observation matrix ``C``.
    """
    ca = core_analysis.get_global_CachingAnalyzer()

    frames = []
    xs = []
    ys = []
    zs = []
    for obj_id in use_obj_ids:
        kalman_rows = ca.load_dynamics_free_MLE_position( obj_id, data_file)

        if len(frames):
            tmp = kalman_rows['frame']
            assert numpy.all((tmp[1:]-tmp[:-1]) > 0) # ensure new frames are ordered

        warnings.warn('dropping last ML estimate of position in fuse_obj_ids because it is frequently noisy')
        # Indices of rows with a full observation (non-NaN x), minus the
        # last such row.
        full_obs_idx = numpy.nonzero(~numpy.isnan(kalman_rows['x']))[0][:-1]
        if not len(full_obs_idx):
            warnings.warn('no data used for obj_id %d in fuse_obj_ids()'%obj_id)
            continue # no data
        frames.append( kalman_rows['frame'][full_obs_idx] )
        xs.append( kalman_rows['x'][full_obs_idx] )
        ys.append( kalman_rows['y'][full_obs_idx] )
        zs.append( kalman_rows['z'][full_obs_idx] )

    if not frames:
        # numpy.hstack([]) would raise a ValueError too, but with a message
        # that does not point at the real cause.
        raise ValueError('no usable observations for any obj_id in fuse_obj_ids()')

    frames = numpy.hstack(frames)
    xs = numpy.hstack(xs)
    ys = numpy.hstack(ys)
    zs = numpy.hstack(zs)

    # One slot per frame between the first and last collected frame.
    frames_all = numpy.arange(frames[0],frames[-1]+1)
    idxs = frames_all.searchsorted( frames )

    # "obs" == "observations" == ML estimates of position without dynamics.
    # Nx3 array for N frames of data; frames without an observation stay NaN.
    obs = numpy.full( (len(frames_all), 3), numpy.nan )
    obs[idxs, 0] = xs
    obs[idxs, 1] = ys
    obs[idxs, 2] = zs
    orig_data_present = numpy.zeros( frames_all.shape, dtype=bool )
    orig_data_present[idxs] = True

    # now do kalman smoothing across all obj_ids
    model = flydra_core.kalman.dynamic_models.get_kalman_model(name=dynamic_model_name,
                                                          dt=(1.0/frames_per_second))
    # initial state guess: position = observation, other parameters = 0
    ss = model['ss']
    init_x = numpy.zeros( (ss,) )
    init_x[:3] = obs[0,:]

    P_k1=numpy.zeros((ss,ss)) # initial state error covariance guess

    for i in range(0,3):
        P_k1[i,i]=model['initial_position_covariance_estimate']
    for i in range(3,6):
        P_k1[i,i]=model.get('initial_velocity_covariance_estimate',0.0)
    if ss > 6:
        for i in range(6,9):
            P_k1[i,i]=model.get('initial_acceleration_covariance_estimate',0.0)

    if 'C' not in model:
        raise ValueError('model does not have a linear observation matrix "C".')
    xsmooth, Psmooth = adskalman.kalman_smoother(obs,
                                                 model['A'],
                                                 model['C'],
                                                 model['Q'],
                                                 model['R'],
                                                 init_x,
                                                 P_k1,
                                                 )
    X = xsmooth[:,:3] # kalman estimates of position

    recarray = numpy.rec.fromarrays([frames_all,
                                     X[:,0],
                                     X[:,1],
                                     X[:,2],
                                     orig_data_present,
                                     ],
                                    names='frame,x,y,z,orig_data_present')
    return recarray
Ejemplo n.º 10
0
def doit(output_h5_filename=None,
         kalman_filename=None,
         data2d_filename=None,
         start=None,
         stop=None,
         gate_angle_threshold_degrees=40.0,
         area_threshold_for_orientation=0.0,
         obj_only=None,
         options=None):
    gate_angle_threshold_radians = gate_angle_threshold_degrees * D2R

    if options.show:
        import matplotlib.pyplot as plt
        import matplotlib.ticker as mticker

    M = SymobolicModels()
    x = sympy.DeferredVector('x')
    G_symbolic = M.get_observation_model(x)
    dx_symbolic = M.get_process_model(x)

    if 0:
        print 'G_symbolic'
        sympy.pprint(G_symbolic)
        print

    G_linearized = [G_symbolic.diff(x[i]) for i in range(7)]
    if 0:
        print 'G_linearized'
        for i in range(len(G_linearized)):
            sympy.pprint(G_linearized[i])
        print

    arg_tuple_x = (M.P00, M.P01, M.P02, M.P03, M.P10, M.P11, M.P12, M.P13,
                   M.P20, M.P21, M.P22, M.P23, M.Ax, M.Ay, M.Az, x)

    xm = sympy.DeferredVector('xm')
    arg_tuple_x_xm = (M.P00, M.P01, M.P02, M.P03, M.P10, M.P11, M.P12, M.P13,
                      M.P20, M.P21, M.P22, M.P23, M.Ax, M.Ay, M.Az, x, xm)

    eval_G = lambdify(arg_tuple_x, G_symbolic, 'numpy')
    eval_linG = lambdify(arg_tuple_x, G_linearized, 'numpy')

    # coord shift of observation model
    phi_symbolic = M.get_observation_model(xm)

    # H = G - phi
    H_symbolic = G_symbolic - phi_symbolic

    # We still take derivative wrt x (not xm).
    H_linearized = [H_symbolic.diff(x[i]) for i in range(7)]

    eval_phi = lambdify(arg_tuple_x_xm, phi_symbolic, 'numpy')
    eval_H = lambdify(arg_tuple_x_xm, H_symbolic, 'numpy')
    eval_linH = lambdify(arg_tuple_x_xm, H_linearized, 'numpy')

    if 0:
        print 'dx_symbolic'
        sympy.pprint(dx_symbolic)
        print

    eval_dAdt = drop_dims(lambdify(x, dx_symbolic, 'numpy'))

    debug_level = 0
    if debug_level:
        np.set_printoptions(linewidth=130, suppress=True)

    if os.path.exists(output_h5_filename):
        raise RuntimeError("will not overwrite old file '%s'" %
                           output_h5_filename)

    ca = core_analysis.get_global_CachingAnalyzer()
    with open_file_safe(output_h5_filename, mode='w') as output_h5:

        with open_file_safe(kalman_filename, mode='r') as kh5:
            with open_file_safe(data2d_filename, mode='r') as h5:
                for input_node in kh5.root._f_iter_nodes():
                    # copy everything from source to dest
                    input_node._f_copy(output_h5.root, recursive=True)

                try:
                    dest_table = output_h5.root.ML_estimates
                except tables.exceptions.NoSuchNodeError, err1:
                    # backwards compatibility
                    try:
                        dest_table = output_h5.root.kalman_observations
                    except tables.exceptions.NoSuchNodeError, err2:
                        raise err1
                for colname in ['hz_line%d' % i for i in range(6)]:
                    clear_col(dest_table, colname)
                dest_table.flush()

                if options.show:
                    fig1 = plt.figure()
                    ax1 = fig1.add_subplot(511)
                    ax2 = fig1.add_subplot(512, sharex=ax1)
                    ax3 = fig1.add_subplot(513, sharex=ax1)
                    ax4 = fig1.add_subplot(514, sharex=ax1)
                    ax5 = fig1.add_subplot(515, sharex=ax1)
                    ax1.xaxis.set_major_formatter(
                        mticker.FormatStrFormatter("%d"))

                    min_frame_range = np.inf
                    max_frame_range = -np.inf

                reconst = reconstruct.Reconstructor(kh5)

                camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)
                fps = result_utils.get_fps(h5)
                dt = 1.0 / fps

                used_camn_dict = {}

                # associate framenumbers with timestamps using 2d .h5 file
                data2d = h5.root.data2d_distorted[:]  # load to RAM
                if start is not None:
                    data2d = data2d[data2d['frame'] >= start]
                if stop is not None:
                    data2d = data2d[data2d['frame'] <= stop]
                data2d_idxs = np.arange(len(data2d))
                h5_framenumbers = data2d['frame']
                h5_frame_qfi = result_utils.QuickFrameIndexer(h5_framenumbers)

                ML_estimates_2d_idxs = (kh5.root.ML_estimates_2d_idxs[:])

                all_kobs_obj_ids = dest_table.read(field='obj_id')
                all_kobs_frames = dest_table.read(field='frame')
                use_obj_ids = np.unique(all_kobs_obj_ids)
                if obj_only is not None:
                    use_obj_ids = obj_only

                if hasattr(kh5.root.kalman_estimates.attrs,
                           'dynamic_model_name'):
                    dynamic_model = kh5.root.kalman_estimates.attrs.dynamic_model_name
                    if dynamic_model.startswith('EKF '):
                        dynamic_model = dynamic_model[4:]
                else:
                    dynamic_model = 'mamarama, units: mm'
                    warnings.warn(
                        'could not determine dynamic model name, using "%s"' %
                        dynamic_model)

                for obj_id_enum, obj_id in enumerate(use_obj_ids):
                    # Use data association step from kalmanization to load potentially
                    # relevant 2D orientations, but discard previous 3D orientation.
                    if obj_id_enum % 100 == 0:
                        print 'obj_id %d (%d of %d)' % (obj_id, obj_id_enum,
                                                        len(use_obj_ids))
                    if options.show:
                        all_xhats = []
                        all_ori = []

                    output_row_obj_id_cond = all_kobs_obj_ids == obj_id

                    obj_3d_rows = ca.load_dynamics_free_MLE_position(
                        obj_id, kh5)
                    if start is not None:
                        obj_3d_rows = obj_3d_rows[
                            obj_3d_rows['frame'] >= start]
                    if stop is not None:
                        obj_3d_rows = obj_3d_rows[obj_3d_rows['frame'] <= stop]

                    try:
                        smoothed_3d_rows = ca.load_data(
                            obj_id,
                            kh5,
                            use_kalman_smoothing=True,
                            frames_per_second=fps,
                            dynamic_model_name=dynamic_model)
                    except core_analysis.NotEnoughDataToSmoothError:
                        continue

                    smoothed_frame_qfi = result_utils.QuickFrameIndexer(
                        smoothed_3d_rows['frame'])

                    slopes_by_camn_by_frame = collections.defaultdict(dict)
                    x0d_by_camn_by_frame = collections.defaultdict(dict)
                    y0d_by_camn_by_frame = collections.defaultdict(dict)
                    pt_idx_by_camn_by_frame = collections.defaultdict(dict)
                    min_frame = np.inf
                    max_frame = -np.inf

                    start_idx = None
                    for this_idx, this_3d_row in enumerate(obj_3d_rows):
                        # iterate over each 3D row loaded for the current object
                        framenumber = this_3d_row['frame']

                        if not np.isnan(this_3d_row['hz_line0']):
                            # We have a valid initial 3d orientation guess.
                            if framenumber < min_frame:
                                min_frame = framenumber
                                assert start_idx is None, "frames out of order?"
                                start_idx = this_idx

                        max_frame = max(max_frame, framenumber)
                        h5_2d_row_idxs = h5_frame_qfi.get_frame_idxs(
                            framenumber)

                        frame2d = data2d[h5_2d_row_idxs]
                        frame2d_idxs = data2d_idxs[h5_2d_row_idxs]

                        obs_2d_idx = this_3d_row['obs_2d_idx']
                        kobs_2d_data = ML_estimates_2d_idxs[int(obs_2d_idx)]

                        # Parse VLArray.
                        this_camns = kobs_2d_data[0::2]
                        this_camn_idxs = kobs_2d_data[1::2]

                        # Now, for each camera viewing this object at this
                        # frame, record the 2D slope, position and point index.
                        for camn, camn_pt_no in zip(this_camns,
                                                    this_camn_idxs):
                            # find 2D point corresponding to object
                            cam_id = camn2cam_id[camn]

                            cond = ((frame2d['camn'] == camn) &
                                    (frame2d['frame_pt_idx'] == camn_pt_no))
                            idxs = np.nonzero(cond)[0]
                            if len(idxs) == 0:
                                continue
                            assert len(idxs) == 1
                            ## if len(idxs)!=1:
                            ##     raise ValueError('expected one (and only one) frame, got %d'%len(idxs))
                            idx = idxs[0]

                            orig_data2d_rownum = frame2d_idxs[idx]
                            frame_timestamp = frame2d[idx]['timestamp']

                            row = frame2d[idx]
                            assert framenumber == row['frame']
                            if ((row['eccentricity'] <
                                 reconst.minimum_eccentricity)
                                    or (row['area'] <
                                        area_threshold_for_orientation)):
                                slopes_by_camn_by_frame[camn][
                                    framenumber] = np.nan
                                x0d_by_camn_by_frame[camn][
                                    framenumber] = np.nan
                                y0d_by_camn_by_frame[camn][
                                    framenumber] = np.nan
                                pt_idx_by_camn_by_frame[camn][
                                    framenumber] = camn_pt_no
                            else:
                                slopes_by_camn_by_frame[camn][
                                    framenumber] = row['slope']
                                x0d_by_camn_by_frame[camn][framenumber] = row[
                                    'x']
                                y0d_by_camn_by_frame[camn][framenumber] = row[
                                    'y']
                                pt_idx_by_camn_by_frame[camn][
                                    framenumber] = camn_pt_no

                    if start_idx is None:
                        warnings.warn("skipping obj_id %d: "
                                      "could not find valid start frame" %
                                      obj_id)
                        continue

                    obj_3d_rows = obj_3d_rows[start_idx:]

                    # now collect in a numpy array for all cam

                    assert int(min_frame) == min_frame
                    assert int(max_frame + 1) == max_frame + 1
                    frame_range = np.arange(int(min_frame), int(max_frame + 1))
                    if debug_level >= 1:
                        print 'frame range %d-%d' % (frame_range[0],
                                                     frame_range[-1])
                    camn_list = slopes_by_camn_by_frame.keys()
                    camn_list.sort()
                    cam_id_list = [camn2cam_id[camn] for camn in camn_list]
                    n_cams = len(camn_list)
                    n_frames = len(frame_range)

                    save_cols = {}
                    save_cols['frame'] = []
                    for camn in camn_list:
                        save_cols['dist%d' % camn] = []
                        save_cols['used%d' % camn] = []
                        save_cols['theta%d' % camn] = []

                    # NxM array with rows being frames and cols being cameras
                    slopes = np.ones((n_frames, n_cams), dtype=np.float)
                    x0ds = np.ones((n_frames, n_cams), dtype=np.float)
                    y0ds = np.ones((n_frames, n_cams), dtype=np.float)
                    for j, camn in enumerate(camn_list):

                        slopes_by_frame = slopes_by_camn_by_frame[camn]
                        x0d_by_frame = x0d_by_camn_by_frame[camn]
                        y0d_by_frame = y0d_by_camn_by_frame[camn]

                        for frame_idx, absolute_frame_number in enumerate(
                                frame_range):

                            slopes[frame_idx, j] = slopes_by_frame.get(
                                absolute_frame_number, np.nan)
                            x0ds[frame_idx,
                                 j] = x0d_by_frame.get(absolute_frame_number,
                                                       np.nan)
                            y0ds[frame_idx,
                                 j] = y0d_by_frame.get(absolute_frame_number,
                                                       np.nan)

                        if options.show:
                            frf = np.array(frame_range, dtype=np.float)
                            min_frame_range = min(np.min(frf), min_frame_range)
                            max_frame_range = max(np.max(frf), max_frame_range)

                            ax1.plot(frame_range,
                                     slope2modpi(slopes[:, j]),
                                     '.',
                                     label=camn2cam_id[camn])

                    if options.show:
                        ax1.legend()

                    if 1:
                        # estimate orientation of initial frame
                        row0 = obj_3d_rows[:
                                           1]  # take only first row but keep as 1d array
                        hzlines = np.array([
                            row0['hz_line0'], row0['hz_line1'],
                            row0['hz_line2'], row0['hz_line3'],
                            row0['hz_line4'], row0['hz_line5']
                        ]).T
                        directions = reconstruct.line_direction(hzlines)
                        q0 = PQmath.orientation_to_quat(directions[0])
                        assert not np.isnan(
                            q0.x), "cannot start with missing orientation"
                        w0 = 0, 0, 0  # no angular rate
                        init_x = np.array(
                            [w0[0], w0[1], w0[2], q0.x, q0.y, q0.z, q0.w])

                        Pminus = np.zeros((7, 7))

                        # angular rate part of state variance is .5
                        for i in range(0, 3):
                            Pminus[i, i] = .5

                        # quaternion part of state variance is 1
                        for i in range(3, 7):
                            Pminus[i, i] = 1

                    if 1:
                        # setup of noise estimates
                        Q = np.zeros((7, 7))

                        # angular rate part of state variance
                        for i in range(0, 3):
                            Q[i, i] = Q_scalar_rate

                        # quaternion part of state variance
                        for i in range(3, 7):
                            Q[i, i] = Q_scalar_quat

                    preA = np.eye(7)

                    ekf = kalman_ekf.EKF(init_x, Pminus)
                    previous_posterior_x = init_x
                    if options.show:
                        _save_plot_rows = []
                        _save_plot_rows_used = []
                    for frame_idx, absolute_frame_number in enumerate(
                            frame_range):
                        # Evaluate the Jacobian of the process update
                        # using the previous frame's posterior estimate.
                        # (This is not quite the same as this frame's prior
                        # estimate: this frame's prior estimate is only
                        # available _after_ applying the process update
                        # model, and we need this Jacobian to apply it.)

                        if options.show:
                            _save_plot_rows.append(np.nan * np.ones(
                                (n_cams, )))
                            _save_plot_rows_used.append(np.nan * np.ones(
                                (n_cams, )))

                        this_dx = eval_dAdt(previous_posterior_x)
                        A = preA + this_dx * dt
                        if debug_level >= 1:
                            print
                            print 'frame', absolute_frame_number, '-' * 40
                            print 'previous posterior', previous_posterior_x
                            if debug_level > 6:
                                print 'A'
                                print A

                        xhatminus, Pminus = ekf.step1__calculate_a_priori(A, Q)
                        if debug_level >= 1:
                            print 'new prior', xhatminus

                        # 1. Gate per-camera orientations.

                        this_frame_slopes = slopes[frame_idx, :]
                        this_frame_theta_measured = slope2modpi(
                            this_frame_slopes)
                        this_frame_x0d = x0ds[frame_idx, :]
                        this_frame_y0d = y0ds[frame_idx, :]
                        if debug_level >= 5:
                            print 'this_frame_slopes', this_frame_slopes

                        save_cols['frame'].append(absolute_frame_number)
                        for j, camn in enumerate(camn_list):
                            # default to no detection, change below
                            save_cols['dist%d' % camn].append(np.nan)
                            save_cols['used%d' % camn].append(0)
                            save_cols['theta%d' % camn].append(
                                this_frame_theta_measured[j])

                        all_data_this_frame_missing = False
                        gate_vector = None

                        y = []  # observation (per camera)
                        hx = []  # expected observation (per camera)
                        C = []  # linearized observation model (per camera)
                        N_obs_this_frame = 0
                        cams_without_data = np.isnan(this_frame_slopes)
                        if np.all(cams_without_data):
                            all_data_this_frame_missing = True

                        smoothed_pos_idxs = smoothed_frame_qfi.get_frame_idxs(
                            absolute_frame_number)
                        if len(smoothed_pos_idxs) == 0:
                            all_data_this_frame_missing = True
                            smoothed_pos_idx = None
                            smooth_row = None
                            center_position = None
                        else:
                            try:
                                assert len(smoothed_pos_idxs) == 1
                            except:
                                print 'obj_id', obj_id
                                print 'absolute_frame_number', absolute_frame_number
                                if len(frame_range):
                                    print 'frame_range[0],frame_rang[-1]', frame_range[
                                        0], frame_range[-1]
                                else:
                                    print 'no frame range'
                                print 'len(smoothed_pos_idxs)', len(
                                    smoothed_pos_idxs)
                                raise
                            smoothed_pos_idx = smoothed_pos_idxs[0]
                            smooth_row = smoothed_3d_rows[smoothed_pos_idx]
                            assert smooth_row['frame'] == absolute_frame_number
                            center_position = np.array(
                                (smooth_row['x'], smooth_row['y'],
                                 smooth_row['z']))
                            if debug_level >= 2:
                                print 'center_position', center_position

                        if not all_data_this_frame_missing:
                            if expected_orientation_method == 'trust_prior':
                                state_for_phi = xhatminus  # use a priori
                            elif expected_orientation_method == 'SVD_line_fits':
                                # construct matrix of planes
                                P = []
                                for camn_idx in range(n_cams):
                                    this_x0d = this_frame_x0d[camn_idx]
                                    this_y0d = this_frame_y0d[camn_idx]
                                    slope = this_frame_slopes[camn_idx]
                                    plane, ray = reconst.get_3D_plane_and_ray(
                                        cam_id, this_x0d, this_y0d, slope)
                                    if np.isnan(plane[0]):
                                        continue
                                    P.append(plane)
                                if len(P) < 2:
                                    # not enough data to do SVD... fallback to prior
                                    state_for_phi = xhatminus  # use a priori
                                else:
                                    Lco = reconstruct.intersect_planes_to_find_line(
                                        P)
                                    q = PQmath.pluecker_to_quat(Lco)
                                    state_for_phi = cgtypes_quat2statespace(q)

                            cams_with_data = ~cams_without_data
                            possible_cam_idxs = np.nonzero(cams_with_data)[0]
                            if debug_level >= 6:
                                print 'possible_cam_idxs', possible_cam_idxs
                            gate_vector = np.zeros((n_cams, ), dtype=np.bool)
                            ## flip_vector = np.zeros( (n_cams,), dtype=np.bool)
                            for camn_idx in possible_cam_idxs:
                                cam_id = cam_id_list[camn_idx]
                                camn = camn_list[camn_idx]

                                # This ignores distortion. To incorporate
                                # distortion, this would require
                                # appropriate scaling of orientation
                                # vector, which would require knowing
                                # target's size. In which case we should
                                # track head and tail separately and not
                                # use this whole quaternion mess.

                                ## theta_measured=slope2modpi(
                                ##     this_frame_slopes[camn_idx])
                                theta_measured = this_frame_theta_measured[
                                    camn_idx]
                                if debug_level >= 6:
                                    print 'cam_id %s, camn %d' % (cam_id, camn)
                                if debug_level >= 3:
                                    a = reconst.find2d(cam_id, center_position)
                                    other_position = get_point_on_line(
                                        xhatminus, center_position)
                                    b = reconst.find2d(cam_id, other_position)
                                    theta_expected = find_theta_mod_pi_between_points(
                                        a, b)
                                    print('  theta_expected,theta_measured',
                                          theta_expected * R2D,
                                          theta_measured * R2D)

                                P = reconst.get_pmat(cam_id)
                                if 0:
                                    args_x = (P[0, 0], P[0, 1], P[0, 2],
                                              P[0, 3], P[1, 0], P[1, 1], P[1,
                                                                           2],
                                              P[1, 3], P[2, 0], P[2, 1], P[2,
                                                                           2],
                                              P[2, 3], center_position[0],
                                              center_position[1],
                                              center_position[2], xhatminus)
                                    this_y = theta_measured
                                    this_hx = eval_G(*args_x)
                                    this_C = eval_linG(*args_x)
                                else:
                                    args_x_xm = (P[0, 0], P[0, 1], P[0, 2],
                                                 P[0, 3], P[1, 0], P[1,
                                                                     1], P[1,
                                                                           2],
                                                 P[1, 3], P[2, 0], P[2,
                                                                     1], P[2,
                                                                           2],
                                                 P[2, 3], center_position[0],
                                                 center_position[1],
                                                 center_position[2], xhatminus,
                                                 state_for_phi)
                                    this_phi = eval_phi(*args_x_xm)
                                    this_y = angle_diff(theta_measured,
                                                        this_phi,
                                                        mod_pi=True)
                                    this_hx = eval_H(*args_x_xm)
                                    this_C = eval_linH(*args_x_xm)
                                    if debug_level >= 3:
                                        print('  this_phi,this_y',
                                              this_phi * R2D, this_y * R2D)

                                save_cols['dist%d' % camn][-1] = this_y  # save

                                # gate
                                if abs(this_y) < gate_angle_threshold_radians:
                                    save_cols['used%d' % camn][-1] = 1
                                    gate_vector[camn_idx] = 1
                                    if debug_level >= 3:
                                        print '    good'
                                    if options.show:
                                        _save_plot_rows_used[-1][
                                            camn_idx] = this_y
                                    y.append(this_y)
                                    hx.append(this_hx)
                                    C.append(this_C)
                                    N_obs_this_frame += 1

                                    # Save which camn and camn_pt_no was used.
                                    if absolute_frame_number not in used_camn_dict:
                                        used_camn_dict[
                                            absolute_frame_number] = []
                                    camn_pt_no = (pt_idx_by_camn_by_frame[camn]
                                                  [absolute_frame_number])
                                    used_camn_dict[
                                        absolute_frame_number].append(
                                            (camn, camn_pt_no))
                                else:
                                    if options.show:
                                        _save_plot_rows[-1][camn_idx] = this_y
                                    if debug_level >= 6:
                                        print '    bad'
                            if debug_level >= 1:
                                print 'gate_vector', gate_vector
                                #print 'flip_vector',flip_vector
                            all_data_this_frame_missing = not bool(
                                np.sum(gate_vector))

                        # 3. Construct observations model using all
                        # gated-in camera orientations.

                        if all_data_this_frame_missing:
                            C = None
                            R = None
                            hx = None
                        else:
                            C = np.array(C)
                            R = R_scalar * np.eye(N_obs_this_frame)
                            hx = np.array(hx)
                            if 0:
                                # crazy observation error scaling
                                for i in range(N_obs_this_frame):
                                    beyond = abs(y[i]) - 10 * D2R
                                    beyond = max(0, beyond)  # clip at zero
                                    R[i:i] = R_scalar * (1 + 10 * beyond)
                            if debug_level >= 6:
                                print 'full values'
                                print 'C', C
                                print 'hx', hx
                                print 'y', y
                                print 'R', R

                        if debug_level >= 1:
                            print 'all_data_this_frame_missing', all_data_this_frame_missing
                        xhat, P = ekf.step2__calculate_a_posteriori(
                            xhatminus,
                            Pminus,
                            y=y,
                            hx=hx,
                            C=C,
                            R=R,
                            missing_data=all_data_this_frame_missing)
                        if debug_level >= 1:
                            print 'xhat', xhat
                        previous_posterior_x = xhat
                        if center_position is not None:
                            # save
                            output_row_frame_cond = all_kobs_frames == absolute_frame_number
                            output_row_cond = output_row_frame_cond & output_row_obj_id_cond
                            output_idxs = np.nonzero(output_row_cond)[0]
                            if len(output_idxs) == 0:
                                pass
                            else:
                                assert len(output_idxs) == 1
                                idx = output_idxs[0]
                                hz = state_to_hzline(xhat, center_position)
                                for row in dest_table.iterrows(start=idx,
                                                               stop=(idx + 1)):
                                    for i in range(6):
                                        row['hz_line%d' % i] = hz[i]
                                    row.update()
                        ## xhat_results[ obj_id ][absolute_frame_number ] = (
                        ##     xhat,center_position)
                        if options.show:
                            all_xhats.append(xhat)
                            all_ori.append(state_to_ori(xhat))

                    # save to H5 file
                    names = [colname for colname in save_cols]
                    names.sort()
                    arrays = []
                    for name in names:
                        if name == 'frame':
                            dtype = np.int64
                        elif name.startswith('dist'):
                            dtype = np.float32
                        elif name.startswith('used'):
                            dtype = np.bool
                        elif name.startswith('theta'):
                            dtype = np.float32
                        else:
                            raise NameError('unknown name %s' % name)
                        arr = np.array(save_cols[name], dtype=dtype)
                        arrays.append(arr)
                    save_recarray = np.rec.fromarrays(arrays, names=names)
                    h5group = core_analysis.get_group_for_obj(obj_id,
                                                              output_h5,
                                                              writeable=True)
                    output_h5.create_table(h5group,
                                           'obj%d' % obj_id,
                                           save_recarray,
                                           filters=tables.Filters(
                                               1, complib='lzo'))

                    if options.show:
                        all_xhats = np.array(all_xhats)
                        all_ori = np.array(all_ori)
                        _save_plot_rows = np.array(_save_plot_rows)
                        _save_plot_rows_used = np.array(_save_plot_rows_used)

                        ax2.plot(frame_range, all_xhats[:, 0], '.', label='p')
                        ax2.plot(frame_range, all_xhats[:, 1], '.', label='q')
                        ax2.plot(frame_range, all_xhats[:, 2], '.', label='r')
                        ax2.legend()

                        ax3.plot(frame_range, all_xhats[:, 3], '.', label='a')
                        ax3.plot(frame_range, all_xhats[:, 4], '.', label='b')
                        ax3.plot(frame_range, all_xhats[:, 5], '.', label='c')
                        ax3.plot(frame_range, all_xhats[:, 6], '.', label='d')
                        ax3.legend()

                        ax4.plot(frame_range, all_ori[:, 0], '.', label='x')
                        ax4.plot(frame_range, all_ori[:, 1], '.', label='y')
                        ax4.plot(frame_range, all_ori[:, 2], '.', label='z')
                        ax4.legend()

                        colors = []
                        for i in range(n_cams):
                            line, = ax5.plot(frame_range,
                                             _save_plot_rows_used[:, i] * R2D,
                                             'o',
                                             label=cam_id_list[i])
                            colors.append(line.get_color())
                        for i in range(n_cams):
                            # loop again to get normal MPL color cycling
                            ax5.plot(frame_range,
                                     _save_plot_rows[:, i] * R2D,
                                     'o',
                                     mec=colors[i],
                                     ms=1.0)
                        ax5.set_ylabel('observation (deg)')
                        ax5.legend()
Ejemplo n.º 11
0
def plot_ori(
    kalman_filename=None,
    h5=None,
    obj_only=None,
    start=None,
    stop=None,
    output_filename=None,
    options=None,
):
    """Plot orientation-related data of a Kalman-ized file on five axes.

    The figure has five stacked axes sharing the frame axis:

    1. per-camera ``theta`` columns of the ``ori_ekf_qual`` tables,
    2. per-camera ``dist`` columns of the same tables (default-size markers
       where the matching ``used`` column is set, small markers elsewhere),
    3. the direction computed by ``reconstruct.line_direction`` from the
       ``hz_line0`` .. ``hz_line5`` columns of ``ML_estimates`` and, when
       ``ori_smooth`` succeeds, its smoothed version,
    4. the quality returned by ``compute_ori_quality``,
    5. the ``dir_*`` (or ``rawdir_*``) columns returned by ``ca.load_data``.

    Parameters
    ----------
    kalman_filename : str
        File handed to ``ca.initial_file_load`` and ``open_file_safe``.
    h5 : str, optional
        2D data file from which the camera-number -> camera-id mapping and
        the frame rate are read. Without it cameras are labelled ``camnN``.
    obj_only : sequence of int, optional
        Object ids to plot. ``None`` plots every object that has a table
        below ``ori_ekf_qual``.
    start, stop : int, optional
        Inclusive frame bounds applied to the EKF and ML rows.
    output_filename : str, optional
        When given, the ``Agg`` backend is selected and the figure is saved
        to this path; otherwise the figure is shown interactively.
    options : object
        Must provide the attributes ``use_kalman_smoothing``, ``fps``,
        ``dynamic_model``, ``print_status``, ``ori_qual``,
        ``smooth_orientations``, ``up_dir``, ``start`` and ``stop``.
    """
    if output_filename is not None:
        import matplotlib

        matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.ticker as mticker

    fps = None
    camn2cam_id = {}
    if h5 is not None:
        h5f = tables.open_file(h5, mode="r")
        try:
            camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5f)
            fps = result_utils.get_fps(h5f)
        finally:
            # Release the handle even when reading the camera info fails.
            h5f.close()

    use_kalman_smoothing = options.use_kalman_smoothing
    # An explicit option wins; otherwise keep the rate read from the 2D file
    # (it used to be overwritten unconditionally, discarding that value).
    if options.fps is not None:
        fps = options.fps
    dynamic_model = options.dynamic_model

    ca = core_analysis.get_global_CachingAnalyzer()

    _, _, _, _, extra = ca.initial_file_load(kalman_filename)

    if dynamic_model is None:
        dynamic_model = extra["dynamic_model_name"]
        print('detected file loaded with dynamic model "%s"' % dynamic_model)
        if dynamic_model.startswith("EKF "):
            dynamic_model = dynamic_model[4:]
        print('  for smoothing, will use dynamic model "%s"' % dynamic_model)

    with open_file_safe(kalman_filename, mode="r") as kh5:
        if fps is None:
            fps = result_utils.get_fps(kh5, fail_on_error=True)

        kmle = kh5.root.ML_estimates[:]  # load into RAM

        if start is not None:
            kmle = kmle[kmle["frame"] >= start]

        if stop is not None:
            kmle = kmle[kmle["frame"] <= stop]

        all_mle_obj_ids = kmle["obj_id"]

        # Walk all tables below ori_ekf_qual to map obj_id -> table.
        all_obj_ids = {}
        parent = kh5.root.ori_ekf_qual
        for group in parent._f_iter_nodes():
            for table in group._f_iter_nodes():
                assert table.name.startswith("obj")
                obj_id = int(table.name[3:])
                all_obj_ids[obj_id] = table

        if obj_only is None:
            # sorted() works on Python 2 and 3; dict.keys() has no .sort()
            # on Python 3.
            use_obj_ids = sorted(all_obj_ids)
            missing_objs = set(np.unique(all_mle_obj_ids)) - set(use_obj_ids)
            if missing_objs:
                warnings.warn("orientation not fit for %d obj_ids" %
                              (len(missing_objs), ))
        else:
            use_obj_ids = obj_only

        # now, generate plots
        fig = plt.figure()
        ax1 = fig.add_subplot(511)
        ax2 = fig.add_subplot(512, sharex=ax1)
        ax3 = fig.add_subplot(513, sharex=ax1)
        ax4 = fig.add_subplot(514, sharex=ax1)
        ax5 = fig.add_subplot(515, sharex=ax1)

        min_frame_range = np.inf
        max_frame_range = -np.inf

        if options.print_status:
            print("%d object IDs in file" % (len(use_obj_ids), ))
        for obj_id in use_obj_ids:
            table = all_obj_ids[obj_id]
            rows = table[:]

            if start is not None:
                rows = rows[rows["frame"] >= start]

            if stop is not None:
                rows = rows[rows["frame"] <= stop]

            if options.print_status:
                print("obj_id %d: %d rows of EKF data" % (obj_id, len(rows)))

            frame = rows["frame"]
            # Camera numbers are encoded in the "dist<camn>" column names.
            camns = [
                int(colname[4:]) for colname in table.colnames
                if colname.startswith("dist")
            ]

            # The frame range does not depend on the camera, so compute it
            # once per object; skip empty selections, on which np.min raises.
            if camns and len(frame):
                min_frame_range = min(float(np.min(frame)), min_frame_range)
                max_frame_range = max(float(np.max(frame)), max_frame_range)

            for camn in camns:
                label = camn2cam_id.get(camn, "camn%d" % camn)
                theta = rows["theta%d" % camn]
                used = rows["used%d" % camn]
                dist = rows["dist%d" % camn]

                (line, ) = ax1.plot(frame,
                                    theta * R2D,
                                    "o",
                                    mew=0,
                                    ms=2.0,
                                    label=label)
                c = line.get_color()
                ax2.plot(frame[used],
                         dist[used] * R2D,
                         "o",
                         color=c,
                         mew=0,
                         label=label)
                ax2.plot(frame[~used],
                         dist[~used] * R2D,
                         "o",
                         color=c,
                         mew=0,
                         ms=2.0)

            # plot 3D orientation
            mle_row_cond = all_mle_obj_ids == obj_id
            rows_this_obj = kmle[mle_row_cond]
            if options.print_status:
                print("obj_id %d: %d rows of ML data" %
                      (obj_id, len(rows_this_obj)))
            frame = rows_this_obj["frame"]
            hz = np.vstack(
                [rows_this_obj["hz_line%d" % i] for i in range(6)]).T
            orient = reconstruct.line_direction(hz)
            ax3.plot(frame, orient[:, 0], "ro", mew=0, ms=2.0, label="x")
            ax3.plot(frame, orient[:, 1], "go", mew=0, ms=2.0, label="y")
            ax3.plot(frame, orient[:, 2], "bo", mew=0, ms=2.0, label="z")

            qual = compute_ori_quality(kh5, rows_this_obj["frame"], obj_id)

            # Blank out low-quality samples before smoothing.
            orinan = np.array(orient, copy=True)
            if options.ori_qual is not None and options.ori_qual != 0:
                orinan[qual < options.ori_qual] = np.nan
            try:
                sori = ori_smooth(orinan, frames_per_second=fps)
            except AssertionError:
                if options.print_status:
                    print("not plotting smoothed ori for object id %d" %
                          (obj_id, ))
            else:
                ax3.plot(frame, sori[:, 0], "r-", mew=0, ms=2.0)
                ax3.plot(frame, sori[:, 1], "g-", mew=0, ms=2.0)
                ax3.plot(frame, sori[:, 2], "b-", mew=0, ms=2.0)

            ax4.plot(frame, qual, "b-")

            # --------------
            kalman_rows = ca.load_data(
                obj_id,
                kh5,
                use_kalman_smoothing=use_kalman_smoothing,
                dynamic_model_name=dynamic_model,
                return_smoothed_directions=options.smooth_orientations,
                frames_per_second=fps,
                up_dir=options.up_dir,
                min_ori_quality_required=options.ori_qual,
            )
            frame = kalman_rows["frame"]
            # NOTE(review): this filter reads options.start/options.stop,
            # not the start/stop arguments used above - confirm intended.
            cond = np.ones(frame.shape, dtype=bool)
            if options.start is not None:
                cond &= options.start <= frame
            if options.stop is not None:
                cond &= frame <= options.stop
            kalman_rows = kalman_rows[cond]

            frame = kalman_rows["frame"]
            Dx = Dy = Dz = None
            if options.smooth_orientations:
                Dx = kalman_rows["dir_x"]
                Dy = kalman_rows["dir_y"]
                Dz = kalman_rows["dir_z"]
            elif "rawdir_x" in kalman_rows.dtype.fields:
                Dx = kalman_rows["rawdir_x"]
                Dy = kalman_rows["rawdir_y"]
                Dz = kalman_rows["rawdir_z"]

            if Dx is not None:
                ax5.plot(frame, Dx, "r-", label="dx")
                ax5.plot(frame, Dy, "g-", label="dy")
                ax5.plot(frame, Dz, "b-", label="dz")

    ax1.xaxis.set_major_formatter(mticker.FormatStrFormatter("%d"))
    ax1.set_ylabel("theta (deg)")
    ax1.legend()

    ax2.set_ylabel("z (deg)")
    ax2.legend()

    ax3.set_ylabel("ori")
    ax3.legend()

    ax4.set_ylabel("quality")

    ax5.set_ylabel("dir")
    ax5.set_xlabel("frame")
    ax5.legend()

    # Matplotlib rejects infinite axis limits, which is what is left when no
    # EKF rows were plotted.
    if np.isfinite(min_frame_range) and np.isfinite(max_frame_range):
        ax1.set_xlim(min_frame_range, max_frame_range)
    if output_filename is None:
        plt.show()
    else:
        plt.savefig(output_filename)
Ejemplo n.º 12
0
def main(argv=None):
    """Command-line entry point: detect saccades in ``*.kh5`` files.

    For every file returned by ``get_good_files`` the good smoothed tracks
    are concatenated, passed to ``geometric_saccade_detect`` and the
    resulting saccades are written with ``saccades_write_all`` below
    ``--output_dir``. Files whose ``-saccades.h5`` output already exists are
    skipped unless ``--nocache`` is given.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments without the program name. ``None`` (the
        default) makes argparse read ``sys.argv[1:]``.

    The process always terminates through ``sys.exit``: ``0`` on success,
    ``1`` when no usable input file is found and ``-2`` when processing
    raises.
    """
    np.seterr(all='raise')

    # Several detection defaults are multiples of the frame period, so --fps
    # has to be known before the full parser is built. A help-less
    # pre-parser reads it, so that --help still reaches the complete parser.
    fps_parser = argparse.ArgumentParser(add_help=False)
    fps_parser.add_argument("--fps", type=float, default=100.0)
    fps_options, _ = fps_parser.parse_known_args(argv)
    if fps_options.fps <= 0:
        fps_parser.error("--fps must be positive")
    dt = 1.0 / fps_options.fps

    # ArgumentDefaultsHelpFormatter appends the defaults to the help texts;
    # optparse-style "%default" placeholders are not valid in argparse.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("files",
                        nargs="*",
                        help="Files or directories to search for *.kh5 logs.")

    parser.add_argument("--output_dir",
                        default='saccade_detect_output',
                        help="Output directory")

    parser.add_argument("--min_frames_per_track",
                        default=400,
                        type=int,
                        help="Minimum number of frames per track")

    parser.add_argument("--confirm_problems",
                        help="Stop interactively on problems with log files "
                        "(e.g.: cannot find valid obj_ids)",
                        default=False,
                        action="store_true")

    parser.add_argument("--dynamic_model_name",
                        help="Smoothing dynamical model",
                        default="mamarama, units: mm")

    parser.add_argument("--debug_output",
                        help="Creates debug figures.",
                        default=False,
                        action="store_true")

    parser.add_argument("--nocache",
                        help="Ignores already computed results.",
                        default=False,
                        action="store_true")

    parser.add_argument("--smoothing",
                        help="Uses Kalman-smoothed data.",
                        default=False,
                        action="store_true")

    # A numeric option: as a store_true flag, "--fps" yielded True and thus
    # a frame period of one second.
    parser.add_argument("--fps",
                        help="Framerate of the recording.",
                        default=100.0,
                        type=float)

    # detection parameters
    parser.add_argument("--deltaT_inner_sec",
                        default=4 * dt,
                        type=float,
                        help="Inner interval.")
    parser.add_argument("--deltaT_outer_sec",
                        default=10 * dt,
                        type=float,
                        help="Outer interval.")
    parser.add_argument("--min_amplitude_deg",
                        default=25,
                        type=float,
                        help="Minimum saccade amplitude (deg).")
    parser.add_argument("--min_linear_velocity",
                        default=0.1,
                        type=float,
                        help="Minimum linear velocity when saccading (m/s).")
    parser.add_argument(
        "--max_linear_acceleration",
        default=20,
        type=float,
        help="Maximum linear acceleration when saccading (m/s^2).")
    parser.add_argument(
        "--max_angular_velocity",
        default=8000,
        type=float,
        help="Maximum angular velocity when saccading (deg/s).")
    parser.add_argument("--max_orientation_dispersion_deg",
                        default=15,
                        type=float,
                        help="Maximum dispersion (deg).")
    parser.add_argument("--minimum_interval_sec",
                        default=10 * dt,
                        type=float,
                        help="Minimum interval between saccades.")

    options = parser.parse_args(argv)

    # Create processed string
    processed = 'geometric_saccade_detector %s %s %s@%s Python %s' % \
                (__version__, datetime.now().strftime("%Y%m%d_%H%M%S"),
                get_user(), platform.node(), platform.python_version())

    if not os.path.exists(options.output_dir):
        os.makedirs(options.output_dir)

    good_files = get_good_files(where=options.files,
                                pattern="*.kh5",
                                confirm_problems=options.confirm_problems)

    if len(good_files) == 0:
        logger.error("No good files to process.")
        sys.exit(1)

    try:
        n = len(good_files)
        for i, (filename, obj_ids, stim_fname) in enumerate(good_files):
            # only maintain basename
            stim_fname = os.path.splitext(os.path.basename(stim_fname))[0]
            basename = os.path.splitext(os.path.basename(filename))[0]

            output_basename = os.path.join(options.output_dir,
                                           basename + '-saccades')
            output_saccades_hdf = output_basename + '.h5'

            if os.path.exists(output_saccades_hdf) and not options.nocache:
                logger.info('File %r exists; skipping. '
                            '(use --nocache to ignore)' % output_saccades_hdf)
                continue

            logger.info("File %d/%d %s %s %s " %
                        (i, n, str(filename), str(obj_ids), stim_fname))

            # concatenate all in one track
            all_data = None

            for _, rows in get_good_smoothed_tracks(
                    filename=filename,
                    obj_ids=obj_ids,
                    min_frames_per_track=options.min_frames_per_track,
                    dynamic_model_name=options.dynamic_model_name,
                    use_smoothing=options.smoothing):

                all_data = rows.copy() if all_data is None \
                            else np.concatenate((all_data, rows))

            if all_data is None:
                logger.info('Not enough data found for %s; skipping.' %
                            filename)
                continue

            params = {
                'deltaT_inner_sec': options.deltaT_inner_sec,
                'deltaT_outer_sec': options.deltaT_outer_sec,
                'min_amplitude_deg': options.min_amplitude_deg,
                'max_orientation_dispersion_deg':
                options.max_orientation_dispersion_deg,
                'minimum_interval_sec': options.minimum_interval_sec,
                'max_linear_acceleration': options.max_linear_acceleration,
                'min_linear_velocity': options.min_linear_velocity,
                'max_angular_velocity': options.max_angular_velocity,
            }
            saccades, annotated_data = geometric_saccade_detect(
                all_data, params)

            for saccade in saccades:
                check_saccade_is_well_formed(saccade)

            # other fields used for managing different samples,
            # used in the analysis
            saccades['species'] = 'Dmelanogaster'
            saccades['stimulus'] = stim_fname
            sample_name = 'DATA' + timestamp_string_from_filename(filename)
            saccades['sample'] = sample_name
            saccades['sample_num'] = -1  # will be filled in by someone else
            saccades['processed'] = processed

            logger.info("Writing to %s {h5,mat,pickle}" % output_basename)
            saccades_write_all(output_basename, saccades)

            # Write debug figures
            if options.debug_output:
                debug_output_dir = os.path.join(options.output_dir, basename)
                logger.info("Writing HTML+png to %s" % debug_output_dir)
                write_debug_output(debug_output_dir, basename, annotated_data,
                                   saccades)

    except Exception as e:
        # Top-level boundary: log everything and turn it into an exit code.
        logger.error('Error while processing. Exception and traceback follow.')
        logger.error(str(e))
        logger.error(traceback.format_exc())
        sys.exit(-2)

    finally:
        print('Closing flydra cache')
        ca = core_analysis.get_global_CachingAnalyzer()
        ca.close()

    sys.exit(0)
Ejemplo n.º 13
0
def doit(
    movie_fname=None,
    reconstructor_fname=None,
    h5_fname=None,
    cam_id=None,
    dest_dir=None,
    transform=None,
    start=None,
    stop=None,
    h5start=None,
    h5stop=None,
    show_obj_ids=False,
    obj_only=None,
    image_format=None,
    subtract_frame=None,
    save_framelist_fname=None,
):
    """Save movie frames as images with the tracked 3D objects drawn on top.

    Each movie frame in ``start``..``stop`` is mapped to a frame number of
    the tracking data through ``extra['time_model']``. For every object that
    has data at that frame, the 3D position is projected into camera
    ``cam_id``, circled, and connected by a line to the projection of the
    point with the same x/y and ``z = 0``.

    Parameters
    ----------
    movie_fname : str
        Movie to read; ``.fmf`` files use ``fmf_mod.FlyMovie``, anything
        else ``ufmf_mod.FlyMovieEmulator``. Required.
    reconstructor_fname : str, optional
        Source for the reconstructor; defaults to the loaded data file.
    h5_fname : str
        Tracking data file handed to ``ca.initial_file_load``.
    cam_id : str
        Camera whose view is rendered. Required.
    dest_dir : str, optional
        Output directory, defaults to the current directory.
    transform : str, optional
        Passed to the canvas; ``'rot 90'`` / ``'rot -90'`` swap the canvas
        width and height.
    start, stop : int, optional
        Inclusive range of movie frame numbers; defaults to the whole movie.
    h5start, h5stop : optional
        Inclusive bounds on the tracking-data frame number.
    show_obj_ids : bool
        Also write an ``obj_id N`` label next to every drawn object.
    obj_only : sequence of int, optional
        Restrict drawing to these object ids.
    image_format : str, optional
        File extension of the saved images, defaults to ``'png'``.
    subtract_frame : str, optional
        ``.fmf`` file whose first frame is subtracted from every image.
    save_framelist_fname : str, optional
        When given, the tracking-data frame number of every saved image is
        written to this file, one per line.

    Raises
    ------
    NotImplementedError
        If ``movie_fname`` or ``cam_id`` is missing, or ``subtract_frame``
        is not an ``.fmf`` file.
    """
    if dest_dir is None:
        dest_dir = os.curdir

    if movie_fname is None:
        raise NotImplementedError('')

    if image_format is None:
        image_format = 'png'

    if cam_id is None:
        raise NotImplementedError('')

    if movie_fname.lower().endswith('.fmf'):
        movie = fmf_mod.FlyMovie(movie_fname)
    else:
        movie = ufmf_mod.FlyMovieEmulator(movie_fname)

    if start is None:
        start = 0

    if stop is None:
        stop = movie.get_n_frames() - 1

    ca = core_analysis.get_global_CachingAnalyzer()
    (obj_ids, unique_obj_ids, is_mat_file, data_file, extra) = \
              ca.initial_file_load(h5_fname)
    if obj_only is not None:
        unique_obj_ids = obj_only

    dynamic_model_name = extra['dynamic_model_name']
    # Test for the full 4-character prefix that is stripped, so that a name
    # merely starting with "EKF" does not lose its fourth character.
    if dynamic_model_name.startswith('EKF '):
        dynamic_model_name = dynamic_model_name[4:]

    if reconstructor_fname is None:
        reconstructor = flydra_core.reconstruct.Reconstructor(data_file)
    else:
        reconstructor = flydra_core.reconstruct.Reconstructor(
            reconstructor_fname)

    fix_w = movie.get_width()
    fix_h = movie.get_height()
    is_color = imops.is_coding_color(movie.get_format())
    to_8bit = imops.to_rgb8 if is_color else imops.to_mono8

    if subtract_frame is not None:
        if not subtract_frame.endswith('.fmf'):
            raise NotImplementedError(
                'only fmf supported for --subtract-frame')
        tmp_fmf = fmf_mod.FlyMovie(subtract_frame)
        tmp_frame, tmp_timestamp = tmp_fmf.get_next_frame()
        subtract_frame = to_8bit(tmp_fmf.get_format(), tmp_frame)
        # force upconversion to float
        subtract_frame = subtract_frame.astype(np.float32)

    save_framelist_fd = None
    if save_framelist_fname is not None:
        save_framelist_fd = open(save_framelist_fname, mode='w')

    # NOTE(review): the stem keeps any directory part of movie_fname, so an
    # absolute movie path makes os.path.join() below ignore dest_dir.
    movie_stem = os.path.splitext(movie_fname)[0]

    try:
        for movie_fno in range(start, stop + 1):
            movie.seek(movie_fno)
            image, timestamp = movie.get_next_frame()
            h5_frame = extra['time_model'].timestamp2framestamp(timestamp)
            if h5start is not None and h5_frame < h5start:
                continue
            if h5stop is not None and h5_frame > h5stop:
                continue
            image = to_8bit(movie.get_format(), image)
            if subtract_frame is not None:
                new_image = np.clip(image - subtract_frame, 0, 255)
                image = new_image.astype(np.uint8)
            warnings.warn('not implemented: interpolating data')
            h5_frame = int(round(h5_frame))
            if save_framelist_fd is not None:
                save_framelist_fd.write('%d\n' % h5_frame)

            # The frame number in the file name is the movie frame number.
            save_fname_path = movie_stem + '_frame%06d.%s' % (movie_fno,
                                                             image_format)
            save_fname_path = os.path.join(dest_dir, save_fname_path)
            if transform in ['rot 90', 'rot -90']:
                device_rect = (0, 0, fix_h, fix_w)
                canv = benu.Canvas(save_fname_path, fix_h, fix_w)
            else:
                device_rect = (0, 0, fix_w, fix_h)
                canv = benu.Canvas(save_fname_path, fix_w, fix_h)
            user_rect = (0, 0, image.shape[1], image.shape[0])
            show_points = []
            with canv.set_user_coords(device_rect,
                                      user_rect,
                                      transform=transform):
                canv.imshow(image, 0, 0)
                for obj_id in unique_obj_ids:
                    try:
                        data = ca.load_data(
                            obj_id,
                            data_file,
                            frames_per_second=extra['frames_per_second'],
                            dynamic_model_name=dynamic_model_name,
                        )
                    except core_analysis.NotEnoughDataToSmoothError:
                        continue
                    cond = data['frame'] == h5_frame
                    idxs = np.nonzero(cond)[0]
                    if not len(idxs):
                        continue  # no data at this frame for this obj_id
                    assert len(idxs) == 1
                    row = data[idxs[0]]

                    # circle over data point
                    xyz = row['x'], row['y'], row['z']
                    x2d, y2d = reconstructor.find2d(cam_id,
                                                    xyz,
                                                    distorted=True)
                    radius = 10
                    canv.scatter([x2d], [y2d],
                                 color_rgba=green,
                                 markeredgewidth=3,
                                 radius=radius)

                    # z line to XY plane through origin
                    xyz0 = row['x'], row['y'], 0
                    x2d_z0, y2d_z0 = reconstructor.find2d(cam_id,
                                                          xyz0,
                                                          distorted=True)
                    warnings.warn('not distorting Z line')
                    xdist = x2d - x2d_z0
                    ydist = y2d - y2d_z0
                    dist = np.sqrt(xdist**2 + ydist**2)
                    # Start the line on the circle rather than at its
                    # centre. Check the distance before dividing: it is
                    # zero when both projections coincide.
                    start_frac = 0 if radius > dist else radius / dist
                    x2d_r = x2d - xdist * start_frac
                    y2d_r = y2d - ydist * start_frac
                    canv.plot([x2d_r, x2d_z0], [y2d_r, y2d_z0],
                              color_rgba=green,
                              linewidth=3)
                    if show_obj_ids:
                        show_points.append((obj_id, x2d, y2d))
            # Labels are drawn outside the user coordinate system, so the
            # points are transformed explicitly.
            for obj_id, x2d, y2d in show_points:
                x, y = canv.get_transformed_point(x2d,
                                                  y2d,
                                                  device_rect,
                                                  user_rect,
                                                  transform=transform)
                canv.text(
                    'obj_id %d' % obj_id,
                    x,
                    y,
                    color_rgba=(0, 1, 0, 1),
                    font_size=20,
                )
            canv.save()
    finally:
        # Flush and release the frame list even if rendering fails.
        if save_framelist_fd is not None:
            save_framelist_fd.close()
Ejemplo n.º 14
0
def calculate_reprojection_errors(
    h5_filename=None,
    output_h5_filename=None,
    kalman_filename=None,
    from_source=None,
    start=None,
    stop=None,
    show_progress=False,
    show_progress_json=False,
):
    """Compute per-observation reprojection distances and save them.

    For every 3D row of every object, the 3D position is reprojected into
    each camera listed in the row's ``ML_estimates_2d_idxs`` entry and the
    Euclidean distance to the matching 2D observation is recorded. The
    result is written to ``output_h5_filename`` as two tables:
    ``reprojection`` (columns ``camn``, ``frame``, ``obj_id``, ``dist``,
    ``z``) and ``cameras`` (columns ``camn``, ``cam_id``).

    Parameters
    ----------
    h5_filename : str
        File providing the 2D ``data2d_distorted`` table.
    output_h5_filename : str
        Destination file; must not exist yet.
    kalman_filename : str
        File providing the 3D data.
    from_source : {"ML_estimates", "smoothed"}
        Which 3D position is reprojected. With ``"smoothed"``, frames or
        objects without smoothed data get NaN positions.
    start, stop : int, optional
        Inclusive frame bounds.
    show_progress : bool
        Display a text progress bar.
    show_progress_json : bool
        Report progress through ``result_utils.do_json_progress``.

    Raises
    ------
    AssertionError
        If ``from_source`` is not one of the two supported values.
    RuntimeError
        If ``output_h5_filename`` already exists.
    """
    assert from_source in ["ML_estimates", "smoothed"]
    if os.path.exists(output_h5_filename):
        raise RuntimeError("will not overwrite old file '%s'" %
                           output_h5_filename)

    out = {
        "camn": [],
        "frame": [],
        "obj_id": [],
        "dist": [],
        "z": [],
    }

    ca = core_analysis.get_global_CachingAnalyzer()
    with ca.kalman_analysis_context(kalman_filename,
                                    data2d_fname=h5_filename) as h5_context:
        R = h5_context.get_reconstructor()
        ML_estimates_2d_idxs = h5_context.load_entire_table(
            "ML_estimates_2d_idxs")
        use_obj_ids = h5_context.get_unique_obj_ids()

        extra = h5_context.get_extra_info()

        if from_source == "smoothed":
            dynamic_model_name = extra["dynamic_model_name"]
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]

        fps = h5_context.get_fps()
        camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

        # associate framenumbers with timestamps using 2d .h5 file
        data2d = h5_context.load_entire_table("data2d_distorted",
                                              from_2d_file=True)
        h5_framenumbers = data2d["frame"]
        h5_frame_qfi = result_utils.QuickFrameIndexer(h5_framenumbers)

        if show_progress:
            string_widget = StringWidget()
            objs_per_sec_widget = progressbar.FileTransferSpeed(
                unit="obj_ids ")
            widgets = [
                string_widget,
                objs_per_sec_widget,
                progressbar.Percentage(),
                progressbar.Bar(),
                progressbar.ETA(),
            ]
            pbar = progressbar.ProgressBar(widgets=widgets,
                                           maxval=len(use_obj_ids)).start()

        for obj_id_enum, obj_id in enumerate(use_obj_ids):
            if show_progress:
                string_widget.set_string("[obj_id: % 5d]" % obj_id)
                pbar.update(obj_id_enum)
            if show_progress_json and obj_id_enum % 100 == 0:
                rough_percent_done = float(obj_id_enum) / len(
                    use_obj_ids) * 100.0
                result_utils.do_json_progress(rough_percent_done)

            obj_3d_rows = h5_context.load_dynamics_free_MLE_position(obj_id)

            if from_source == "smoothed":

                smoothed_rows = None
                try:
                    smoothed_rows = h5_context.load_data(
                        obj_id,
                        use_kalman_smoothing=True,
                        dynamic_model_name=dynamic_model_name,
                        frames_per_second=fps,
                    )
                except (core_analysis.NotEnoughDataToSmoothError,
                        core_analysis.DiscontiguousFramesError):
                    # No smoothed data for this obj_id; its positions are
                    # recorded as NaN below.
                    pass

            for this_3d_row in obj_3d_rows:
                # iterate over each 3D row of the current object
                framenumber = this_3d_row["frame"]
                if start is not None:
                    if not framenumber >= start:
                        continue
                if stop is not None:
                    if not framenumber <= stop:
                        continue
                h5_2d_row_idxs = h5_frame_qfi.get_frame_idxs(framenumber)
                if len(h5_2d_row_idxs) == 0:
                    # At the start, there may be 3d data without 2d data.
                    continue

                if from_source == "ML_estimates":
                    X3d = this_3d_row["x"], this_3d_row["y"], this_3d_row["z"]
                elif from_source == "smoothed":
                    if smoothed_rows is None:
                        X3d = np.nan, np.nan, np.nan
                    else:
                        this_smoothed_rows = smoothed_rows[
                            smoothed_rows["frame"] == framenumber]
                        assert len(this_smoothed_rows) <= 1
                        if len(this_smoothed_rows) == 0:
                            X3d = np.nan, np.nan, np.nan
                        else:
                            X3d = (
                                this_smoothed_rows["x"][0],
                                this_smoothed_rows["y"][0],
                                this_smoothed_rows["z"][0],
                            )

                # If there was a 3D ML estimate, there must be 2D data.

                frame2d = data2d[h5_2d_row_idxs]

                obs_2d_idx = this_3d_row["obs_2d_idx"]
                kobs_2d_data = ML_estimates_2d_idxs[int(obs_2d_idx)]

                # Parse VLArray: alternating (camn, point index) entries.
                this_camns = kobs_2d_data[0::2]
                this_camn_idxs = kobs_2d_data[1::2]

                # Now, for each camera that contributed to this 3D row,
                # compare the reprojection with the 2D observation.
                for camn, camn_pt_no in zip(this_camns, this_camn_idxs):
                    try:
                        cam_id = camn2cam_id[camn]
                    except KeyError:
                        warnings.warn("camn %d not found" % (camn, ))
                        continue

                    # find 2D point corresponding to object
                    cond = (frame2d["camn"] == camn) & (frame2d["frame_pt_idx"]
                                                        == camn_pt_no)
                    idxs = np.nonzero(cond)[0]
                    if len(idxs) == 0:
                        # no frame for that camera (start or stop of file)
                        continue
                    elif len(idxs) > 1:
                        print(
                            "MEGA WARNING MULTIPLE 2D POINTS\n",
                            camn,
                            camn_pt_no,
                            "\n\n",
                        )
                        continue

                    idx = idxs[0]

                    frame2d_row = frame2d[idx]
                    x2d_real = frame2d_row["x"], frame2d_row["y"]
                    x2d_reproj = R.find2d(cam_id, X3d, distorted=True)
                    dist = np.sqrt(np.sum((x2d_reproj - x2d_real)**2))

                    out["camn"].append(camn)
                    out["frame"].append(framenumber)
                    out["obj_id"].append(obj_id)
                    out["dist"].append(dist)
                    out["z"].append(X3d[2])

        if show_progress:
            # Complete the bar; update() above never reaches maxval.
            pbar.finish()

    # convert to numpy arrays
    reprojection = pandas.DataFrame(
        {key: np.array(values) for key, values in out.items()})
    del out  # free memory

    # new tables
    camns = list(camn2cam_id)
    cam_ids = [camn2cam_id[camn] for camn in camns]
    cam_df = pandas.DataFrame({
        "camn": np.array(camns),
        "cam_id": np.array(cam_ids),
    })

    # save to disk; the context manager closes the store even if an append
    # fails
    with pandas.HDFStore(output_h5_filename) as store:
        store.append("reprojection",
                     reprojection,
                     data_columns=reprojection.columns)
        store.append("cameras", cam_df)
    if show_progress_json:
        result_utils.do_json_progress(100)
Ejemplo n.º 15
0
def convert(
        infilename,
        outfilename,
        frames_per_second=None,
        save_timestamps=True,
        file_time_data=None,
        do_nothing=False,  # set to true to test for file existence
        start_obj_id=None,
        stop_obj_id=None,
        obj_only=None,
        dynamic_model_name=None,
        hdf5=False,
        show_progress=False,
        show_progress_json=False,
        **kwargs):
    """Find a mainbrain timestamp for every selected obj_id (stage 1).

    For each obj_id in ``ML_estimates`` that passes the
    ``start_obj_id``/``stop_obj_id``/``obj_only`` filters, one frame number
    is looked up, a ``data2d_distorted`` row with that frame number is found,
    and that row's camera timestamp is converted to the mainbrain clock with
    a per-host linear model (``timestamp * gain + offset``).

    The body shown here ends after announcing stage 2; ``outfilename``,
    ``frames_per_second``, ``dynamic_model_name``, ``hdf5``,
    ``show_progress``, ``show_progress_json`` and ``kwargs`` are accepted
    but not referenced.

    Args:
        infilename: Kalman-tracked data file handed to
            ``kalman_analysis_context``.
        outfilename: Not referenced in this body.
        save_timestamps: When false, stage 1 is skipped entirely.
        file_time_data: Optional 2D data file passed as ``data2d_fname``.
        do_nothing: Return right after the drift table is printed (only
            checked when ``save_timestamps`` is true).
        start_obj_id: Inclusive lower bound on exported obj_ids
            (``None`` means unbounded).
        stop_obj_id: Inclusive upper bound on exported obj_ids
            (``None`` means unbounded).
        obj_only: Optional container of obj_ids; when given, only these are
            kept.

    Raises:
        ValueError: If the 2D frame table or the 3D frame table is empty.
        SystemExit: If the ``data2d_distorted`` node cannot be found.
    """
    if start_obj_id is None:
        start_obj_id = -numpy.inf
    if stop_obj_id is None:
        stop_obj_id = numpy.inf

    # Neither name is referenced again in the code shown here.
    smoothed_data_filename = os.path.split(infilename)[1]
    raw_data_filename = smoothed_data_filename

    ca = core_analysis.get_global_CachingAnalyzer()
    with ca.kalman_analysis_context(infilename,
                                    data2d_fname=file_time_data) as h5_context:

        extra_vars = {}
        tzname = None

        if save_timestamps:
            print('STAGE 1: finding timestamps')
            table_kobs = h5_context.get_pytable_node('ML_estimates')

            tzname = h5_context.get_tzname0()
            fps = h5_context.get_fps()

            try:
                table_data2d = h5_context.get_pytable_node('data2d_distorted',
                                                           from_2d_file=True)
            except tables.exceptions.NoSuchNodeError:
                sys.stderr.write(
                    "No timestamps in file. Either specify not to save timestamps ('--no-timestamps') or specify the original .h5 file with the timestamps ('--time-data=FILE2D')"
                    "\n")
                sys.exit(1)

            # No trailing newline: 'done' below completes this line.
            sys.stdout.write('caching raw 2D data... ')
            sys.stdout.flush()
            table_data2d_frames = table_data2d.read(field='frame')
            if len(table_data2d_frames) == 0:
                # numpy.max() of an empty array raises an opaque ValueError;
                # fail with an explicit message, like the 3D check below.
                raise ValueError('no 2D data, cannot find timestamps')
            assert numpy.max(table_data2d_frames) < 2**63
            table_data2d_frames = table_data2d_frames.astype(numpy.int64)
            table_data2d_frames_find = utils.FastFinder(table_data2d_frames)
            table_data2d_camns = table_data2d.read(field='camn')
            table_data2d_timestamps = table_data2d.read(field='timestamp')
            print('done')
            print('(cached index of %d frame values of dtype %s)' % (
                len(table_data2d_frames), str(table_data2d_frames.dtype)))

            drift_estimates = h5_context.get_drift_estimates()
            camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

            # Per-host linear clock model: local = remote * gain + offset.
            gain = {}
            offset = {}
            print('hostname time_gain time_offset')
            print('-------- --------- -----------')
            for hostname in drift_estimates.get('hostnames', []):
                # Only every 10th drift sample is used for the fit.
                tgain, toffset = result_utils.model_remote_to_local(
                    drift_estimates['remote_timestamp'][hostname][::10],
                    drift_estimates['local_timestamp'][hostname][::10])
                gain[hostname] = tgain
                offset[hostname] = toffset
                print('   %s %s %s' % (repr(hostname), tgain, toffset))
            print('')

            if do_nothing:
                return

            print('caching Kalman obj_ids...')
            obs_obj_ids = table_kobs.read(field='obj_id')
            fast_obs_obj_ids = utils.FastFinder(obs_obj_ids)
            print('finding unique obj_ids...')
            unique_obj_ids = numpy.unique(obs_obj_ids)
            print('(found %d)' % (len(unique_obj_ids), ))
            unique_obj_ids = unique_obj_ids[unique_obj_ids >= start_obj_id]
            unique_obj_ids = unique_obj_ids[unique_obj_ids <= stop_obj_id]

            if obj_only is not None:
                unique_obj_ids = numpy.array(
                    [oid for oid in unique_obj_ids if oid in obj_only])
                print('filtered to obj_only %s' % (obj_only, ))

            print('(will export %d)' % (len(unique_obj_ids), ))
            print('finding 2d data for each obj_id...')
            # Entries of skipped obj_ids keep their initial value of 0.0.
            timestamp_time = numpy.zeros(unique_obj_ids.shape,
                                         dtype=numpy.float64)
            table_kobs_frame = table_kobs.read(field='frame')
            if len(table_kobs_frame) == 0:
                raise ValueError('no 3D data, cannot convert')
            assert numpy.max(table_kobs_frame) < 2**63
            table_kobs_frame = table_kobs_frame.astype(numpy.int64)
            assert table_kobs_frame.dtype == table_data2d_frames.dtype  # otherwise very slow

            # NOTE(review): presumably one row index per obj_id (its first
            # occurrence) -- confirm against utils.FastFinder.
            all_idxs = fast_obs_obj_ids.get_idx_of_equal(unique_obj_ids)
            for obj_id_enum, obj_id in enumerate(unique_obj_ids):
                idx0 = all_idxs[obj_id_enum]
                framenumber = table_kobs_frame[idx0]
                remote_timestamp = numpy.nan

                this_camn = None
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(
                    framenumber)
                if len(frame_idxs):
                    # Any 2D row of this frame will do; take the first.
                    frame_idx = frame_idxs[0]
                    this_camn = table_data2d_camns[frame_idx]
                    remote_timestamp = table_data2d_timestamps[frame_idx]

                if this_camn is None:
                    print('skipping frame %d (obj %d): no data2d_distorted data'
                          % (framenumber, obj_id))
                    continue

                cam_id = camn2cam_id[this_camn]
                try:
                    remote_hostname = cam_id2hostname(cam_id, h5_context)
                except ValueError as e:
                    print('error getting hostname of cam: %s' % (e, ))
                    continue
                if remote_hostname not in gain:
                    warnings.warn('no host %s in timestamp data. making up '
                                  'data.' % remote_hostname)
                    gain[remote_hostname] = 1.0
                    offset[remote_hostname] = 0.0
                # find mainbrain timestamp
                mainbrain_timestamp = (remote_timestamp *
                                       gain[remote_hostname] +
                                       offset[remote_hostname])

                timestamp_time[obj_id_enum] = mainbrain_timestamp

            extra_vars['obj_ids'] = unique_obj_ids
            extra_vars['timestamps'] = timestamp_time

            print('STAGE 2: running Kalman smoothing operation')
Ejemplo n.º 16
0
    def show_it(
        self,
        fig,
        filename,
        kalman_filename=None,
        frame_start=None,
        frame_stop=None,
        show_nth_frame=None,
        obj_only=None,
        reconstructor_filename=None,
        options=None,
    ):
        """Plot per-camera 2D detections and, optionally, reprojected tracks.

        One subplot is created for every camera in the file. Each subplot
        gets the camera's background image (when the file has an ``images``
        group), all valid 2D points in the selected frame range, markers for
        the first and last valid point, and optional orientation segments
        and frame-number labels. When ``kalman_filename`` is given, every
        Kalman-smoothed trajectory is reprojected into each camera with
        ``self.reconstructor`` and drawn as a line.

        Args:
            fig: Figure passed to ``auto_subplot`` and whose canvas receives
                the ``key_press_event`` handler.
            filename: 2D data file opened with ``result_utils.get_results``.
            kalman_filename: Optional file with 3D tracking results; may be
                the same file as ``filename``.
            frame_start: Optional inclusive first frame.
            frame_stop: Optional inclusive last frame.
            show_nth_frame: Label every point whose frame number is a
                multiple of this value; ``0`` or ``None`` disables labels.
            obj_only: Accepted for interface compatibility; the only code
                that read it was a disabled (``if 0:``) branch.
            reconstructor_filename: Optional calibration source. By default
                the Kalman file is used if given, otherwise ``filename``
                when it has a ``calibration`` group.
            options: Object with attributes ``stim_xml``,
                ``show_orientation``, ``autozoom`` and ``save_fig``; they
                are read unconditionally, so ``None`` is not usable here.
        """
        if show_nth_frame == 0:
            show_nth_frame = None

        results = result_utils.get_results(filename, mode="r")
        opened_kresults = False
        kresults = None
        if kalman_filename is not None:
            if os.path.abspath(kalman_filename) == os.path.abspath(filename):
                kresults = results
            else:
                kresults = PT.open_file(kalman_filename, mode="r")
                opened_kresults = True

            # copied from plot_timeseries_2d_3d.py
            ca = core_analysis.get_global_CachingAnalyzer()
            (
                xxobj_ids,
                xxuse_obj_ids,
                xxis_mat_file,
                xxdata_file,
                extra,
            ) = ca.initial_file_load(kalman_filename)
            fps = extra["frames_per_second"]
            dynamic_model_name = extra.get("dynamic_model_name", None)
            if dynamic_model_name is None:
                dynamic_model_name = dynamic_models.DEFAULT_MODEL
                warnings.warn('no dynamic model specified, using "%s"' %
                              dynamic_model_name)
            else:
                print('detected file loaded with dynamic model "%s"' %
                      dynamic_model_name)
            # Smoothing uses the plain model name, without the "EKF " prefix.
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]
            print('  for smoothing, will use dynamic model "%s"' %
                  dynamic_model_name)

        if hasattr(results.root, "images"):
            img_table = results.root.images
        else:
            img_table = None

        # Pick the calibration source, re-using an already opened file handle
        # whenever the requested path is one of the files opened above.
        reconstructor_source = None
        if reconstructor_filename is None:
            if kresults is not None:
                reconstructor_source = kresults
            elif hasattr(results.root, "calibration"):
                reconstructor_source = results
        else:
            if os.path.abspath(reconstructor_filename) == os.path.abspath(
                    filename):
                reconstructor_source = results
            elif (kalman_filename
                  is not None) and (os.path.abspath(reconstructor_filename)
                                    == os.path.abspath(kalman_filename)):
                reconstructor_source = kresults
            else:
                reconstructor_source = reconstructor_filename

        if reconstructor_source is not None:
            self.reconstructor = flydra_core.reconstruct.Reconstructor(
                reconstructor_source)

        # Always bind stim_xml so the per-camera loop below can test it
        # directly instead of re-reading options.stim_xml.
        stim_xml = None
        if options.stim_xml:
            file_timestamp = results.filename[4:19]
            stim_xml = xml_stimulus.xml_stimulus_from_filename(
                options.stim_xml, timestamp_string=file_timestamp)
            if self.reconstructor is not None:
                stim_xml.verify_reconstructor(self.reconstructor)

        if self.reconstructor is not None:
            self.reconstructor = self.reconstructor.get_scaled()

        camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(results)

        data2d = results.root.data2d_distorted  # make sure we have 2d data table

        print("reading frames...")
        frames = data2d.read(field="frame")
        print("OK")

        if frame_start is not None:
            print("selecting frames after start")
            after_start = numpy.nonzero(frames >= frame_start)[0]
        else:
            after_start = None

        if frame_stop is not None:
            print("selecting frames before stop")
            before_stop = numpy.nonzero(frames <= frame_stop)[0]
        else:
            before_stop = None

        print("finding all frames")
        if after_start is not None and before_stop is not None:
            use_idxs = numpy.intersect1d(after_start, before_stop)
        elif after_start is not None:
            use_idxs = after_start
        elif before_stop is not None:
            use_idxs = before_stop
        else:
            use_idxs = numpy.arange(data2d.nrows)

        # OK, we have data coords, plot

        print("reading cameras")
        frames = frames[use_idxs]
        if len(frames):
            print("frame range: %d - %d (%d frames total)" %
                  (frames[0], frames[-1], len(frames)))
        camns = data2d.read(field="camn")
        camns = camns[use_idxs]
        unique_camns = numpy.unique(camns)
        cam_ids_with_data = set(camn2cam_id[camn] for camn in unique_camns)
        print("%d cameras with data" % (len(cam_ids_with_data), ))

        # plot all cameras, not just those with data
        # (sorted() works whether keys() returns a list or a view)
        unique_cam_ids = sorted(cam_id2camns.keys())

        if len(unique_cam_ids) == 1:
            n_rows = 1
            n_cols = 1
        elif len(unique_cam_ids) <= 6:
            n_rows = 2
            n_cols = 3
        elif len(unique_cam_ids) <= 12:
            n_rows = 3
            n_cols = 4
        else:
            n_rows = 4
            # True division is required: with integer division the ceil()
            # is a no-op and the grid ends up with too few cells.
            n_cols = int(math.ceil(len(unique_cam_ids) / float(n_rows)))

        for i, cam_id in enumerate(unique_cam_ids):
            ax = auto_subplot(fig, i, n_rows=n_rows, n_cols=n_cols)
            ax.set_title("%s: %s" % (cam_id, str(cam_id2camns[cam_id])))
            # Running bounding box of the plotted points, used by autozoom.
            ax.this_minx = np.inf
            ax.this_maxx = -np.inf
            ax.this_miny = np.inf
            ax.this_maxy = -np.inf
            self.subplot_by_cam_id[cam_id] = ax

        for cam_id in unique_cam_ids:
            ax = self.subplot_by_cam_id[cam_id]
            if img_table is not None:
                bg_arr_h5 = getattr(img_table, cam_id)
                bg_arr = bg_arr_h5.read()
                ax.imshow(bg_arr.squeeze(), origin="lower", cmap=cm.pink)
                ax.set_autoscale_on(True)
                ax.autoscale_view()
                pylab.draw()
                ax.set_autoscale_on(False)

            if self.reconstructor is not None:
                if cam_id in self.reconstructor.get_cam_ids():
                    res = self.reconstructor.get_resolution(cam_id)
                    ax.set_xlim([0, res[0]])
                    ax.set_ylim([res[1], 0])  # y axis points down

            if stim_xml is not None:
                stim_xml.plot_stim_over_distorted_image(ax, cam_id)

        for camn in unique_camns:
            cam_id = camn2cam_id[camn]
            ax = self.subplot_by_cam_id[cam_id]
            this_camn_idxs = use_idxs[camns == camn]

            xs = data2d.read_coordinates(this_camn_idxs, field="x")

            # Rows whose x is NaN are treated as "no detection".
            valid_idx = numpy.nonzero(~numpy.isnan(xs))[0]
            if not len(valid_idx):
                continue
            ys = data2d.read_coordinates(this_camn_idxs, field="y")
            if options.show_orientation:
                slope = data2d.read_coordinates(this_camn_idxs, field="slope")

            idx_first_valid = valid_idx[0]
            idx_last_valid = valid_idx[-1]
            tmp_frames = data2d.read_coordinates(this_camn_idxs, field="frame")

            ax.plot([xs[idx_first_valid]], [ys[idx_first_valid]],
                    "ro",
                    label="first point")

            ax.this_minx = min(np.min(xs[valid_idx]), ax.this_minx)
            ax.this_maxx = max(np.max(xs[valid_idx]), ax.this_maxx)

            ax.this_miny = min(np.min(ys[valid_idx]), ax.this_miny)
            ax.this_maxy = max(np.max(ys[valid_idx]), ax.this_maxy)

            ax.plot(xs[valid_idx], ys[valid_idx], "g.", label="all points")

            if options.show_orientation:
                # Draw a segment of half-length r through each point along
                # the direction given by its slope.
                angle = np.arctan(slope)
                r = 20.0
                dx = r * np.cos(angle)
                dy = r * np.sin(angle)
                x0 = xs - dx
                x1 = xs + dx
                y0 = ys - dy
                y1 = ys + dy
                segs = [((x0[j], y0[j]), (x1[j], y1[j])) for j in valid_idx]
                line_segments = collections.LineCollection(
                    segs,
                    linewidths=[1],
                    colors=[(0, 1, 0)],
                )
                ax.add_collection(line_segments)

            ax.plot([xs[idx_last_valid]], [ys[idx_last_valid]],
                    "bo",
                    label="last point")

            if show_nth_frame is not None:
                for j, f in enumerate(tmp_frames):
                    if f % show_nth_frame == 0:
                        ax.text(xs[j], ys[j], "%d" % (f, ))

        fig.canvas.mpl_connect("key_press_event", self.on_key_press)

        if options.autozoom:
            for cam_id in self.subplot_by_cam_id:
                ax = self.subplot_by_cam_id[cam_id]
                bounds = [ax.this_minx, ax.this_maxx,
                          ax.this_miny, ax.this_maxy]
                if not np.all(np.isfinite(bounds)):
                    # Camera without any valid point: the bounding box is
                    # still +/-inf and cannot be used as axis limits.
                    continue
                ax.set_xlim((ax.this_minx - 10, ax.this_maxx + 10))
                ax.set_ylim((ax.this_miny - 10, ax.this_maxy + 10))

        if options.save_fig:
            for cam_id in self.subplot_by_cam_id:
                ax = self.subplot_by_cam_id[cam_id]
                ax.set_xticks([])
                ax.set_yticks([])

        if kalman_filename is None:
            # Nothing 3D to draw; release the 2D file before leaving.
            results.close()
            return

        # Reproject the core_analysis smoothed trajectories into each camera.
        for obj_id in xxuse_obj_ids:
            try:
                rows = ca.load_data(
                    obj_id,
                    kalman_filename,
                    use_kalman_smoothing=True,
                    frames_per_second=fps,
                    dynamic_model_name=dynamic_model_name,
                )
            except core_analysis.NotEnoughDataToSmoothError:
                warnings.warn(
                    "not enough data to smooth obj_id %d, skipping." %
                    (obj_id, ))
                # Without this, the rows of the previous obj_id (or an
                # unbound name) would be plotted under this obj_id.
                continue
            if frame_start is not None:
                c1 = rows["frame"] >= frame_start
            else:
                c1 = np.ones((len(rows), ), dtype=bool)
            if frame_stop is not None:
                c2 = rows["frame"] <= frame_stop
            else:
                c2 = np.ones((len(rows), ), dtype=bool)
            rows = rows[c1 & c2]
            if len(rows) <= 1:
                # Too few points in the frame range to draw a line.
                continue

            # Homogeneous coordinates, one row per time point.
            X3d = np.array((rows["x"], rows["y"], rows["z"],
                            np.ones_like(rows["z"]))).T

            for cam_id in self.subplot_by_cam_id:
                ax = self.subplot_by_cam_id[cam_id]
                newx, newy = self.reconstructor.find2d(cam_id,
                                                       X3d,
                                                       distorted=True)
                ax.plot(newx, newy, "-", label="k: %d" % obj_id)

        results.close()
        if opened_kresults:
            kresults.close()
Ejemplo n.º 17
0
def check_offline_reconstruction(with_water=False,
                                 use_kalman_smoothing=False,
                                 with_orientation=False,
                                 fps=120.0,
                                 with_distortion=True):
    """Round-trip synthetic data through save, kalmanize and load.

    Synthetic data from ``setup_data`` is saved as a 2D file, tracked with
    ``kalmanize`` into a 3D file, loaded back through the caching analyzer,
    and the mean Euclidean distance between the loaded and the original
    positions is asserted to be below ``MAX_MEAN_ERROR`` (doubled when
    Kalman smoothing is used).

    All files are created inside a private temporary directory that is
    removed afterwards, whether or not the check succeeds.

    Args:
        with_water: Forwarded to ``setup_data``.
        use_kalman_smoothing: Load smoothed rather than raw positions.
        with_orientation: Forwarded to ``setup_data``.
        fps: Frame rate used for saving and for loading.
        with_distortion: Forwarded to ``setup_data``.

    Raises:
        AssertionError: If the calibration does not round-trip, if the 3D
            file does not contain exactly one object, or if the shapes or
            the mean error are off.
    """
    import shutil  # only needed here, for removing the temporary directory

    D = setup_data(
        fps=fps,
        with_water=with_water,
        with_orientation=with_orientation,
        with_distortion=with_distortion,
    )

    # A fresh directory gives not-yet-existing file names without the race
    # inherent in tempfile.mktemp(), and one rmtree() cleans up every file
    # written next to the data files, even after a partial failure.
    tmpdir = tempfile.mkdtemp()
    try:
        data2d_fname = os.path.join(tmpdir, 'data2d.h5')
        flydra_analysis.offline_data_save.save_data(
            fname=data2d_fname,
            data2d=D['data2d'],
            fps=fps,
            reconstructor=D['reconstructor'],
            eccentricity=D['eccentricity'],
        )
        # The saved calibration must match the one it was written from.
        d1 = D['reconstructor'].get_intrinsic_nonlinear('cam03')
        d2 = flydra_core.reconstruct.Reconstructor(
            data2d_fname).get_intrinsic_nonlinear('cam03')
        assert np.allclose(d1, d2)

        data3d_fname = os.path.join(tmpdir, 'data3d.h5')
        kalmanize(
            data2d_fname,
            dest_filename=data3d_fname,
            dynamic_model_name=D['dynamic_model_name'],
            reconstructor=D['reconstructor'],
        )

        ca = core_analysis.get_global_CachingAnalyzer()
        (obj_ids, use_obj_ids, is_mat_file, data_file,
         extra) = ca.initial_file_load(data3d_fname)
        try:
            assert len(use_obj_ids) == 1
            obj_id = use_obj_ids[0]

            load_model = D['dynamic_model_name']
            if use_kalman_smoothing and load_model.startswith('EKF '):
                # Smoothing takes the model name without the 'EKF ' prefix.
                load_model = load_model[4:]

            my_rows = ca.load_data(
                obj_id,
                data_file,
                use_kalman_smoothing=use_kalman_smoothing,
                dynamic_model_name=load_model,
                frames_per_second=fps,
            )

            x_actual = my_rows['x']
            y_actual = my_rows['y']
            z_actual = my_rows['z']
        finally:
            # Close the handles even when loading fails, so the directory
            # can be removed below.
            data_file.close()
            ca.close()
    finally:
        # Best-effort cleanup, like the original per-file unlink.
        shutil.rmtree(tmpdir, ignore_errors=True)

    assert my_rows['x'].shape == D['x'].shape
    mean_error = np.mean(
        np.sqrt((D['x'] - x_actual)**2 + (D['y'] - y_actual)**2 +
                (D['z'] - z_actual)**2))

    # We should have very low error
    fudge = 2 if use_kalman_smoothing else 1
    assert mean_error < fudge * MAX_MEAN_ERROR
Ejemplo n.º 18
0
def doit(
    filename=None,
    obj_only=None,
    do_ransac=False,
    show=False,
):
    """Fit a plane to the tracked 3D points and save an aligned calibration.

    The dynamics-free MLE positions of the selected objects are loaded, a
    plane is fitted to them (least squares via ``PlaneModelHelper.fit`` or
    RANSAC), and the calibration found in the file is transformed so that
    this plane becomes ``z == 0`` with the mean of the points at the origin.
    A water interface is added and the result is written next to
    ``filename`` as ``<basename>-water-aligned.xml``.

    Args:
        filename: Tracking results file to read.
        obj_only: Optional sequence of obj_ids to use instead of the ones
            returned by ``initial_file_load``.
        do_ransac: Fit the plane with RANSAC instead of a plain fit.
        show: Plot the original and aligned points and the aligned cameras.

    Raises:
        ValueError: If there is no valid (non-NaN) 3D point to fit.
    """
    # get original 3D points -------------------------------
    ca = core_analysis.get_global_CachingAnalyzer()
    obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
        filename)
    try:
        if obj_only is not None:
            use_obj_ids = np.array(obj_only)

        xs = []
        ys = []
        zs = []
        for obj_id in use_obj_ids:
            obs_rows = ca.load_dynamics_free_MLE_position(obj_id, data_file)
            good_rows = obs_rows[~np.isnan(obs_rows['x'])]
            xs.append(good_rows['x'])
            ys.append(good_rows['y'])
            zs.append(good_rows['z'])
        if not xs:
            raise ValueError('no objects to load 3D points from')
        x = np.concatenate(xs)
        y = np.concatenate(ys)
        z = np.concatenate(zs)

        recon = Reconstructor(cal_source=data_file)
    finally:
        extra['kresults'].close()  # close file, also when loading failed

    if len(x) == 0:
        raise ValueError('no valid 3D points to fit a plane to')

    # np.float was removed from NumPy; float64 is what it aliased.
    data = np.empty((len(x), 3), dtype=np.float64)
    data[:, 0] = x
    data[:, 1] = y
    data[:, 2] = z

    # calculate plane-of-best fit ------------

    helper = PlaneModelHelper()
    if not do_ransac:
        plane_params = helper.fit(data)
    else:
        # do RANSAC
        # n: the minimum number of data values required to fit the model
        # k: the maximum number of iterations allowed in the algorithm
        # t: a threshold value for determining when a data point fits a model
        # d: the number of close data values required to assert that a model
        #    fits well to data
        n = 20
        k = 100
        t = np.mean([np.std(x), np.std(y), np.std(z)])
        d = 100
        plane_params = ransac.ransac(data, helper, n, k, t, d, debug=False)

    # Calculate rotation matrix from plane-of-best-fit to z==0 --------
    orig_normal = norm(plane_params[:3])
    new_normal = np.array([0, 0, 1], dtype=np.float64)
    rot_axis = norm(np.cross(orig_normal, new_normal))
    cos_angle = np.dot(orig_normal, new_normal)
    angle = np.arccos(cos_angle)
    q = cgtypes.quat().fromAngleAxis(angle, rot_axis)
    m = q.toMat3()
    R = cgmat2np(m)

    # Calculate aligned data without translation -----------------
    s = 1.0
    t = np.array([0, 0, 0], dtype=np.float64)

    aligned_data = align.align_points(s, R, t, data.T).T

    # Calculate aligned data so that mean point is origin -----------------
    t = -np.mean(aligned_data[:, :3], axis=0)
    aligned_data = align.align_points(s, R, t, data.T).T

    M = align.build_xform(s, R, t)
    r2 = recon.get_aligned_copy(M)
    wateri = water.WaterInterface(
        refractive_index=DEFAULT_WATER_REFRACTIVE_INDEX,
        water_roots_eps=WATER_ROOTS_EPS)
    r2.add_water(wateri)

    dst = os.path.splitext(filename)[0] + '-water-aligned.xml'
    r2.save_to_xml_filename(dst)
    print('saved to %s' % (dst, ))

    if show:
        import matplotlib.pyplot as plt
        from pymvg.plot_utils import plot_system
        # Imported for its side effect only; older matplotlib versions need
        # it to make projection='3d' available.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        fig = plt.figure()

        # Top row: original points; bottom row: aligned points.
        ax1 = fig.add_subplot(221)
        ax1.plot(data[:, 0], data[:, 1], 'b.')
        ax1.set_xlabel('x')
        ax1.set_ylabel('y')

        ax2 = fig.add_subplot(222)
        ax2.plot(data[:, 0], data[:, 2], 'b.')
        ax2.set_xlabel('x')
        ax2.set_ylabel('z')

        ax3 = fig.add_subplot(223)
        ax3.plot(aligned_data[:, 0], aligned_data[:, 1], 'b.')
        ax3.set_xlabel('x')
        ax3.set_ylabel('y')

        ax4 = fig.add_subplot(224)
        ax4.plot(aligned_data[:, 0], aligned_data[:, 2], 'b.')
        ax4.set_xlabel('x')
        ax4.set_ylabel('z')

        fig2 = plt.figure('cameras')
        ax = fig2.add_subplot(111, projection='3d')
        system = r2.convert_to_pymvg(ignore_water=True)
        plot_system(ax, system)
        # Draw a small grid of points in the z == 0 plane for reference.
        grid_x = np.linspace(-0.1, 0.1, 10)
        grid_y = np.linspace(-0.1, 0.1, 10)
        X, Y = np.meshgrid(grid_x, grid_y)
        Z = np.zeros_like(X)
        ax.plot(X.ravel(), Y.ravel(), Z.ravel(), 'b.')
        ax.set_title('aligned camera positions')

        plt.show()
Ejemplo n.º 19
0
def retrack_reuse_data_association(
    h5_filename=None,
    output_h5_filename=None,
    kalman_filename=None,
    start=None,
    stop=None,
    less_ram=False,
    show_progress=False,
    show_progress_json=False,
):
    """Re-track every object of a file while reusing its stored data association.

    For each obj_id found in ``kalman_filename`` the 3D rows returned by
    ``load_dynamics_free_MLE_position`` are walked frame by frame.  For every
    row the (camn, frame_pt_idx) pairs recorded in ``ML_estimates_2d_idxs``
    are looked up in the ``data2d_distorted`` table, undistorted, and fed to
    a ``TrackedObject`` with ``skip_data_association=True``.  The resulting
    tracked objects are written to ``output_h5_filename`` through a
    ``KalmanSaver``.

    Parameters
    ----------
    h5_filename :
        Passed as ``data2d_fname`` to ``kalman_analysis_context``; the tables
        read with ``from_2d_file=True`` come from it.
    output_h5_filename :
        File created (mode "w") for the results.  Must not already exist.
    kalman_filename :
        File holding the existing 3D estimates and the reconstructor.
    start, stop :
        Optional inclusive frame-number bounds; rows outside are skipped.
    less_ram :
        If true, work on pytables nodes instead of loading the
        ``ML_estimates_2d_idxs`` and ``data2d_distorted`` tables entirely.
    show_progress :
        If true, display a ``progressbar`` progress bar.
    show_progress_json :
        If true, report progress through ``result_utils.do_json_progress``
        every 100 objects and once more (100) at the end.

    Raises
    ------
    RuntimeError
        If ``output_h5_filename`` already exists.
    """
    # Refuse to clobber earlier results.
    if os.path.exists(output_h5_filename):
        raise RuntimeError("will not overwrite old file '%s'" %
                           output_h5_filename)

    ca = core_analysis.get_global_CachingAnalyzer()
    with ca.kalman_analysis_context(kalman_filename,
                                    data2d_fname=h5_filename) as h5_context:
        R = h5_context.get_reconstructor()
        # Each entry of ML_estimates_2d_idxs is indexed below by a 3D row's
        # 'obs_2d_idx' and parsed as interleaved (camn, frame_pt_idx) values.
        if less_ram:
            ML_estimates_2d_idxs = h5_context.get_pytable_node(
                'ML_estimates_2d_idxs')
        else:
            ML_estimates_2d_idxs = h5_context.load_entire_table(
                'ML_estimates_2d_idxs')
        use_obj_ids = h5_context.get_unique_obj_ids()
        extra = h5_context.get_extra_info()
        dt = 1.0 / extra['frames_per_second']
        dynamic_model_name = extra['dynamic_model_name']
        kalman_model = dynamic_models.get_kalman_model(name=dynamic_model_name,
                                                       dt=dt)
        # Effectively disable the skipped-frames limit of the model.
        kalman_model['max_frames_skipped'] = 2**62  # close to max i64

        # NOTE(review): same value already used for dt above.
        fps = extra['frames_per_second']
        camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

        parsed = h5_context.read_textlog_header()
        if 'trigger_CS3' not in parsed:
            parsed['trigger_CS3'] = 'unknown'

        # Lines handed to KalmanSaver as textlog_save_lines.
        # NOTE(review): both 'original file' and 'reconstructor file' report
        # kalman_filename -- confirm this is intended.
        textlog_save_lines = [
            'retrack_reuse_data_association running at %s fps, (top %s, trigger_CS3 %s, flydra_version %s)'
            %
            (str(fps), str(parsed.get('top', 'unknown')),
             str(parsed['trigger_CS3']), flydra_analysis.version.__version__),
            'original file: %s' % (kalman_filename, ),
            'dynamic model: %s' % (dynamic_model_name, ),
            'reconstructor file: %s' % (kalman_filename, ),
        ]

        # delete_on_error=True: presumably removes a partial output file if
        # an exception escapes the block -- verify against open_file_safe.
        with open_file_safe(output_h5_filename,
                            mode="w",
                            title="tracked Flydra data file",
                            delete_on_error=True) as output_h5:

            h5saver = KalmanSaver(
                output_h5,
                R,
                cam_id2camns=cam_id2camns,
                min_observations_to_save=0,
                textlog_save_lines=textlog_save_lines,
                dynamic_model_name=dynamic_model_name,
                dynamic_model=kalman_model,
            )

            # associate framenumbers with timestamps using 2d .h5 file
            if less_ram:
                data2d = h5_context.get_pytable_node('data2d_distorted',
                                                     from_2d_file=True)
                h5_framenumbers = data2d.cols.frame[:]
            else:
                data2d = h5_context.load_entire_table('data2d_distorted',
                                                      from_2d_file=True)
                h5_framenumbers = data2d['frame']
            # Index used below to find the 2D rows belonging to a frame number.
            h5_frame_qfi = result_utils.QuickFrameIndexer(h5_framenumbers)

            if show_progress:
                string_widget = StringWidget()
                objs_per_sec_widget = progressbar.FileTransferSpeed(
                    unit='obj_ids ')
                widgets = [
                    string_widget, objs_per_sec_widget,
                    progressbar.Percentage(),
                    progressbar.Bar(),
                    progressbar.ETA()
                ]
                # NOTE(review): pbar.finish() is never called in this
                # function -- confirm whether that is intentional.
                pbar = progressbar.ProgressBar(
                    widgets=widgets, maxval=len(use_obj_ids)).start()

            for obj_id_enum, obj_id in enumerate(use_obj_ids):
                if show_progress:
                    string_widget.set_string('[obj_id: % 5d]' % obj_id)
                    pbar.update(obj_id_enum)
                if show_progress_json and obj_id_enum % 100 == 0:
                    rough_percent_done = float(obj_id_enum) / len(
                        use_obj_ids) * 100.0
                    result_utils.do_json_progress(rough_percent_done)

                # tro stays None until a frame with at least two usable 2D
                # observations allows the tracked object to be created.
                tro = None
                first_frame_per_obj = True
                obj_3d_rows = h5_context.load_dynamics_free_MLE_position(
                    obj_id)
                for this_3d_row in obj_3d_rows:
                    # iterate over each 3D row of the current object
                    framenumber = this_3d_row['frame']
                    # start and stop are inclusive bounds
                    if start is not None:
                        if not framenumber >= start:
                            continue
                    if stop is not None:
                        if not framenumber <= stop:
                            continue
                    h5_2d_row_idxs = h5_frame_qfi.get_frame_idxs(framenumber)
                    if len(h5_2d_row_idxs) == 0:
                        # At the start, there may be 3d data without 2d data.
                        continue

                    # If there was a 3D ML estimate, there must be 2D data.

                    # All 2D rows (any camera) recorded for this frame.
                    frame2d = data2d[h5_2d_row_idxs]

                    obs_2d_idx = this_3d_row['obs_2d_idx']
                    kobs_2d_data = ML_estimates_2d_idxs[int(obs_2d_idx)]

                    # Parse VLArray: even positions are camns, odd positions
                    # are the matching frame_pt_idx values.
                    this_camns = kobs_2d_data[0::2]
                    this_camn_idxs = kobs_2d_data[1::2]

                    # Now, for each camera viewing this object at this
                    # frame, collect its 2D observation.
                    observation_camns = []
                    observation_idxs = []
                    data_dict = {}
                    used_camns_and_idxs = []
                    cam_ids_and_points2d = []

                    for camn, frame_pt_idx in zip(this_camns, this_camn_idxs):
                        try:
                            cam_id = camn2cam_id[camn]
                        except KeyError:
                            warnings.warn('camn %d not found' % (camn, ))
                            continue

                        # find 2D point corresponding to object
                        cond = ((frame2d['camn'] == camn) &
                                (frame2d['frame_pt_idx'] == frame_pt_idx))
                        idxs = np.nonzero(cond)[0]
                        if len(idxs) == 0:
                            #no frame for that camera (start or stop of file)
                            continue
                        elif len(idxs) > 1:
                            # Ambiguous match: skip this camera's observation.
                            print "MEGA WARNING MULTIPLE 2D POINTS\n", camn, frame_pt_idx, "\n\n"
                            continue

                        idx = idxs[0]

                        frame2d_row = frame2d[idx]
                        x2d_real = frame2d_row['x'], frame2d_row['y']
                        pt_undistorted = R.undistort(cam_id, x2d_real)
                        x2d_area = frame2d_row['area']

                        observation_camns.append(camn)
                        observation_idxs.append(idx)
                        # The candidate list is left empty; the association
                        # is supplied explicitly via used_camns_and_idxs.
                        candidate_point_list = []
                        data_dict[camn] = candidate_point_list
                        used_camns_and_idxs.append((camn, frame_pt_idx, None))

                        # with no orientation
                        observed_2d = (pt_undistorted[0], pt_undistorted[1],
                                       x2d_area)

                        cam_ids_and_points2d.append((cam_id, observed_2d))

                    if first_frame_per_obj:
                        if len(cam_ids_and_points2d) < 2:
                            # first_frame_per_obj stays True, so the next row
                            # of this object is tried as the first frame.
                            warnings.warn(
                                'some 2D data seems to be missing, cannot completely reconstruct'
                            )
                        else:
                            X3d = R.find3d(
                                cam_ids_and_points2d,
                                return_line_coords=False,
                                simulate_via_tracking_dynamic_model=kalman_model
                            )

                            # first frame
                            tro = TrackedObject(
                                R,
                                obj_id,
                                framenumber,
                                X3d,  # obs0_position
                                None,  # obs0_Lcoords
                                observation_camns,  # first_observation_camns
                                observation_idxs,  # first_observation_idxs
                                kalman_model=kalman_model,
                            )
                            del X3d
                            first_frame_per_obj = False
                    else:
                        # Subsequent frames: update the existing tracked
                        # object with the stored association.
                        tro.calculate_a_posteriori_estimate(
                            framenumber,
                            data_dict,
                            camn2cam_id,
                            skip_data_association=True,
                            original_camns_and_idxs=used_camns_and_idxs,
                            original_cam_ids_and_points2d=cam_ids_and_points2d,
                        )

                # done with all data for this obj_id
                if tro is not None:
                    tro.kill()
                    # Keep the original obj_id in the output file.
                    h5saver.save_tro(tro, force_obj_id=obj_id)
    if show_progress_json:
        result_utils.do_json_progress(100)
Ejemplo n.º 20
0
def get_good_smoothed_tracks(filename, obj_ids, min_frames_per_track,
                             use_smoothing, dynamic_model_name):
    """Yield ``(obj_id, rows)`` for each sufficiently long track in a file.

    Tracks with fewer than ``min_frames_per_track`` raw rows are skipped, as
    are tracks for which ``core_analysis.NotEnoughDataToSmoothError`` is
    raised.  The ``timestamp`` field of the rows is overwritten with
    ``frame * dt`` using a fixed rate of 60 frames per second.

    Parameters
    ----------
    filename :
        Data file handed to ``CachingAnalyzer.load_data``.
    obj_ids :
        Iterable of object ids to consider.
    min_frames_per_track :
        Minimum number of raw rows a track must have to be yielded.
    use_smoothing :
        If false, yield the raw rows; otherwise yield Kalman-smoothed rows.
    dynamic_model_name :
        Dynamic model used for smoothing.  For ``"mamarama, units: mm"`` the
        smoothed ``xvel``/``yvel`` fields are multiplied by 1000.

    Yields
    ------
    tuple
        ``(obj_id, extract_interesting_fields(rows, np.dtype(rows_dtype)))``.

    Notes
    -----
    ``ca.close()`` is only reached when the generator is run to exhaustion.
    """
    # Module-level flag so the fixed-dt message is logged once per process.
    global warned_fixed_dt

    frames_per_second = 60.0
    dt = 1 / frames_per_second

    ca = core_analysis.get_global_CachingAnalyzer()

    # Whether the units-workaround message was already logged by this call.
    warned = False
    needs_units_fix = dynamic_model_name == "mamarama, units: mm"

    data_file = filename

    for obj_id in obj_ids:
        try:
            frows = ca.load_data(obj_id, data_file, use_kalman_smoothing=False)

            # don't consider tracks too small
            if len(frows) < min_frames_per_track:
                continue

            # The 'timestamp' field returned by flydra is the time
            # when the computation was made, not the actual data timestamp.
            # For computing the actual timestamp, use the frame number
            # and multiply by dt
            if not warned_fixed_dt:
                warned_fixed_dt = True
                logger.info('Warning: We are assuming that the data is ' \
                      'equally spaced, and fps = %s.' % frames_per_second)

            # In-place write into the rows, as the element-wise loop did.
            frows['timestamp'][:] = frows['frame'] * dt

            # Within one object, timestamps must be strictly increasing.
            same_obj = frows['obj_id'][:-1] == frows['obj_id'][1:]
            increasing = frows['timestamp'][:-1] < frows['timestamp'][1:]
            assert np.all(increasing | ~same_obj)

            # return raw data if smoothing is not requested
            if not use_smoothing:
                yield (obj_id,
                       extract_interesting_fields(frows, np.dtype(rows_dtype)))
                continue

            # otherwise, run the smoothing
            srows = ca.load_data(obj_id,
                                 data_file,
                                 use_kalman_smoothing=True,
                                 frames_per_second=frames_per_second,
                                 dynamic_model_name=dynamic_model_name)

            # make a copy, just in case
            srows = srows.copy()

            srows['timestamp'][:] = srows['frame'] * dt

            # From Andrew:
            # I'm pretty sure there is an inconsistency in some of this
            # unit stuff. Basically, I used to do the camera calibrations
            # all in mm (so that the 3D coords would come out in mm). Then,
            # I started doing analyses in meters... And I think some of
            # the calibration and dynamic model stuff got defaulted to meters.
            # And basically there are inconsistencies in there.
            # Anyhow, I think the extent of the issue is that you'll be off
            # by 1000, so hopefully you can just determine that by looking
            # at the data.
            # quick fix
            if needs_units_fix:
                # Log only once, but rescale EVERY track.  Previously the
                # scaling lived inside the "not warned" branch, so only the
                # first smoothed track was corrected.
                if not warned:
                    warned = True
                    logger.info("Warning: Implementing simple workaround"
                                " for flydra's "
                                "units inconsistencies "
                                "(multiplying xvel,yvel by 1000).")
                srows['xvel'] *= 1000
                srows['yvel'] *= 1000

            yield obj_id, extract_interesting_fields(srows,
                                                     np.dtype(rows_dtype))

        except core_analysis.NotEnoughDataToSmoothError:
            # Too little data to smooth this obj_id: skip it silently.
            continue

    ca.close()
Ejemplo n.º 21
0
def main():
    """Command-line entry point: show 3D tracks and stimulus in an alignment GUI.

    Loads the trajectories of the selected objects from one or more flydra
    .hdf5 files, computes a per-sample speed by central differences, and
    passes the homogeneous vertices, the speeds, the reconstructor and the
    optional ``--align-json`` file to the ``cal_align`` component of an
    ``IVTKWithCalGUI`` viewer.  The actors of the ``--stim-xml`` stimulus are
    added to the same scene.  Blocks in the GUI event loop.

    Raises:
        ValueError: if an input file is not an HDF5 file, if a track's speed
            array does not match its vertices, or if no trajectory data was
            selected.
        RuntimeError: if no reconstructor can be determined, or if object ids
            are restricted while more than one file is given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('filename', nargs='+',
                        help='name of flydra .hdf5 file',
                        )

    parser.add_argument("--stim-xml",
                        type=str,
                        default=None,
                        help="name of XML file with stimulus info",
                        required=True,
                        )

    parser.add_argument("--align-json",
                        type=str,
                        default=None,
                        help="previously exported json file containing s,R,T",
                        )

    parser.add_argument("--radius", type=float,
                      help="radius of line (in meters)",
                      default=0.002,
                      metavar="RADIUS")

    parser.add_argument("--obj-only", type=str)

    parser.add_argument("--obj-filelist", type=str,
                      help="use object ids from list in text file",
                      )

    parser.add_argument(
        "-r", "--reconstructor", dest="reconstructor_path",
        type=str,
        help=("calibration/reconstructor path (if not specified, "
              "defaults to FILE)"))

    args = parser.parse_args()
    options = args # optparse OptionParser backwards compatibility

    reconstructor_path = args.reconstructor_path
    fps = None

    ca = core_analysis.get_global_CachingAnalyzer()
    # Maps each input filename to (obj_ids to use, opened data file).
    by_file = {}

    for h5_filename in args.filename:
        # Explicit check instead of assert, which is stripped under -O.
        if not tables.is_hdf5_file(h5_filename):
            raise ValueError('%r is not an HDF5 file' % (h5_filename,))
        obj_ids, use_obj_ids, is_mat_file, data_file, extra = ca.initial_file_load(
            h5_filename)
        this_fps = result_utils.get_fps( data_file, fail_on_error=False )
        # The first file that reports a frame rate decides it.
        if fps is None:
            fps = this_fps
        # Without -r, the first data file doubles as reconstructor source.
        if reconstructor_path is None:
            reconstructor_path = data_file
        by_file[h5_filename] = (use_obj_ids, data_file)
    # Drop the loop variables so they cannot be used by accident below.
    del h5_filename
    del obj_ids, use_obj_ids, is_mat_file, data_file, extra

    if options.obj_only is not None:
        obj_only = core_analysis.parse_seq(options.obj_only)
    else:
        obj_only = None

    if reconstructor_path is None:
        raise RuntimeError('must specify reconstructor from CLI if not using .h5 files')

    R = reconstruct.Reconstructor(reconstructor_path)

    # Only fall back to a default when no file reported a frame rate.  The
    # previous code reset a detected fps to 1.0, discarding the value that
    # had just been read from the data.
    if fps is None:
        fps = 100.0
        warnings.warn('Setting fps to default value of %f'%fps)

    if options.stim_xml is None:
        raise ValueError(
            'stim_xml must be specified (how else will you align the data?')

    stim_xml = xml_stimulus.xml_stimulus_from_filename(
        options.stim_xml,
        )
    try:
        fanout = xml_stimulus.xml_fanout_from_filename( options.stim_xml )
    except xml_stimulus.WrongXMLTypeError:
        # Not a fanout file: nothing to filter.
        pass
    else:
        # NOTE(review): this branch cannot work as written.
        # `file_timestamp` is not defined in this function and
        # `use_obj_ids` was deleted above, so a fanout XML raises NameError
        # here.  The computed ids are also never stored back into by_file.
        # Left unchanged because the intended timestamp source is not
        # visible from this function -- TODO confirm and fix.
        include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp(
            timestamp_string=file_timestamp )
        if include_obj_ids is not None:
            use_obj_ids = include_obj_ids
        if exclude_obj_ids is not None:
            use_obj_ids = list( set(use_obj_ids).difference(
                exclude_obj_ids ) )
        print('using object ids specified in fanout .xml file')
    if stim_xml.has_reconstructor():
        stim_xml.verify_reconstructor(R)

    # Per-track arrays, concatenated after all files are read.
    x = []
    y = []
    z = []
    speed = []

    obj_filelist = options.obj_filelist

    if obj_only is not None or obj_filelist is not None:
        if len(by_file) != 1:
            raise RuntimeError("specifying obj_only can only be done for a single file")
        if obj_filelist is not None:
            # The list file takes precedence over --obj-only.  ndmin=2 keeps
            # a single-row file two-dimensional so the column slice works.
            data = np.loadtxt(obj_filelist, delimiter=',', ndmin=2)
            obj_only = np.array(data[:,0], dtype='int')
            print(obj_only)

        use_obj_ids = np.array(obj_only)
        # by_file has exactly one entry here; dict views are not indexable
        # on Python 3, so take the key through an iterator.
        h5_filename = next(iter(by_file))
        (_, data_file) = by_file[h5_filename]
        by_file[h5_filename] = (use_obj_ids, data_file)

    for h5_filename in by_file:
        (use_obj_ids, data_file) = by_file[h5_filename]
        for obj_id in use_obj_ids:
            rows = ca.load_data( obj_id, data_file,
                                 use_kalman_smoothing=False,
                                )
            verts = np.array( [rows['x'], rows['y'], rows['z']] ).T
            if len(verts)>=3:
                # Central differences yield len-2 values ...
                verts_central_diff = verts[2:,:] - verts[:-2,:]
                dt = 1.0/fps
                vels = verts_central_diff/(2*dt)
                speeds = np.sqrt(np.sum(vels**2,axis=1))
                # ... so pad end points by repeating the edge values
                speeds = np.array([speeds[0]] + list(speeds) + [speeds[-1]])
            else:
                # Too short for a central difference.
                speeds = np.zeros( (verts.shape[0],) )

            if verts.shape[0] != len(speeds):
                raise ValueError('mismatch length of x data and speeds')
            x.append( verts[:,0] )
            y.append( verts[:,1] )
            z.append( verts[:,2] )
            speed.append(speeds)
        data_file.close()
    del h5_filename, use_obj_ids, data_file

    if 0:
        # debug
        if stim_xml is not None:
            v = None
            for child in stim_xml.root:
                if child.tag == 'cubic_arena':
                    info = stim_xml._get_info_for_cubic_arena(child)
                    v=info['verts4x4']
            if v is not None:
                for vi in v:
                    print('adding',vi)
                    x.append( [vi[0]] )
                    y.append( [vi[1]] )
                    z.append( [vi[2]] )
                    speed.append( [100.0] )

    if not x:
        # np.concatenate would fail with an opaque message on an empty list.
        raise ValueError('no trajectory data found for the selected object ids')

    x = np.concatenate(x)
    y = np.concatenate(y)
    z = np.concatenate(z)
    w = np.ones_like(x)
    speed = np.concatenate(speed)

    # homogeneous coords
    verts = np.array([x,y,z,w])

    #######################################################

    # Create the MayaVi engine and start it.
    e = Engine()
    # start does nothing much but useful if someone is listening to
    # your engine.
    e.start()

    # Create a new scene.
    from tvtk.tools import ivtk
    viewer = IVTKWithCalGUI(size=(800,600))
    viewer.open()
    e.new_scene(viewer)

    viewer.cal_align.set_data(verts,speed,R,args.align_json)

    if 0:
        # Do this if you need to see the MayaVi tree view UI.
        ev = EngineView(engine=e)
        ui = ev.edit_traits()

    # view aligned data
    e.add_source(viewer.cal_align.source)

    # Draw every sample as a fixed-size sphere coloured by its scalar.
    v = Vectors()
    v.glyph.scale_mode = 'data_scaling_off'
    v.glyph.color_mode = 'color_by_scalar'
    v.glyph.glyph_source.glyph_position='center'
    v.glyph.glyph_source.glyph_source = tvtk.SphereSource(
        radius=options.radius,
        )
    e.add_module(v)

    if stim_xml is not None:
        if 0:
            stim_xml.draw_in_mayavi_scene(e)
        else:
            actors = stim_xml.get_tvtk_actors()
            viewer.scene.add_actors(actors)

    gui = GUI()
    gui.start_event_loop()
Ejemplo n.º 22
0
def convert(
    infilename,
    outfilename,
    frames_per_second=None,
    save_timestamps=True,
    file_time_data=None,
    do_nothing=False,  # set to true to test for file existance
    start_obj_id=None,
    stop_obj_id=None,
    obj_only=None,
    dynamic_model_name=None,
    hdf5=False,
    show_progress=False,
    show_progress_json=False,
    **kwargs
):
    """Smooth the tracks of a flydra data file and export them.

    Optionally (``save_timestamps``) a timestamp is first computed for every
    exported obj_id from the 2D data, using per-host gain/offset models
    fitted on the drift estimates.  Then every obj_id is loaded through
    ``h5_context.load_data`` and the concatenated rows are written with
    ``flydra_analysis_convert_to_mat.do_it``.

    Parameters
    ----------
    infilename :
        File opened with ``kalman_analysis_context``.
    outfilename :
        Passed to ``do_it`` as ``newfilename``.
    frames_per_second :
        Frame rate used for smoothing; defaults to the rate of the file.
    save_timestamps :
        If true, compute per-object timestamps (requires a
        ``data2d_distorted`` table; the process exits with status 1 if it
        is missing).
    file_time_data :
        Passed as ``data2d_fname`` to ``kalman_analysis_context``.
    do_nothing :
        If true, return after the file checks without exporting anything.
    start_obj_id, stop_obj_id :
        Optional inclusive obj_id bounds.
    obj_only :
        Optional collection of obj_ids to restrict the export to.
    dynamic_model_name :
        Model used for smoothing; defaults to the one stored in the file
        (without a leading ``"EKF "``) or ``dynamic_models.DEFAULT_MODEL``.
    hdf5 :
        Passed through to ``do_it``.
    show_progress :
        If true, display a ``progressbar`` progress bar.
    show_progress_json :
        If true, report progress via ``result_utils.do_json_progress``.
    **kwargs :
        Forwarded to ``h5_context.load_data``.

    Raises
    ------
    ValueError
        If the file has no 3D data or no object could be exported.
    """
    if start_obj_id is None:
        start_obj_id = -numpy.inf
    if stop_obj_id is None:
        stop_obj_id = numpy.inf

    smoothed_data_filename = os.path.split(infilename)[1]
    raw_data_filename = smoothed_data_filename

    ca = core_analysis.get_global_CachingAnalyzer()
    with ca.kalman_analysis_context(
        infilename, data2d_fname=file_time_data
    ) as h5_context:

        extra_vars = {}
        tzname = None
        # Read unconditionally: fps is written to the output even when no
        # timestamps are saved.  It used to be bound only inside the
        # save_timestamps branch, which made save_timestamps=False fail
        # with a NameError at the end.
        fps = h5_context.get_fps()

        if save_timestamps:
            print("STAGE 1: finding timestamps")
            table_kobs = h5_context.get_pytable_node("ML_estimates")

            tzname = h5_context.get_tzname0()

            try:
                table_data2d = h5_context.get_pytable_node(
                    "data2d_distorted", from_2d_file=True
                )
            except tables.exceptions.NoSuchNodeError:
                print(
                    "No timestamps in file. Either specify not to save timestamps ('--no-timestamps') or specify the original .h5 file with the timestamps ('--time-data=FILE2D')",
                    file=sys.stderr,
                )
                sys.exit(1)

            print("caching raw 2D data...", end=" ")
            sys.stdout.flush()
            table_data2d_frames = table_data2d.read(field="frame")
            assert numpy.max(table_data2d_frames) < 2 ** 63
            table_data2d_frames = table_data2d_frames.astype(numpy.int64)
            table_data2d_frames_find = utils.FastFinder(table_data2d_frames)
            table_data2d_camns = table_data2d.read(field="camn")
            table_data2d_timestamps = table_data2d.read(field="timestamp")
            print("done")
            print(
                "(cached index of %d frame values of dtype %s)"
                % (len(table_data2d_frames), str(table_data2d_frames.dtype))
            )

            drift_estimates = h5_context.get_drift_estimates()
            camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

            # Per-host linear model: local = remote * gain + offset.
            gain = {}
            offset = {}
            print("hostname time_gain time_offset")
            print("-------- --------- -----------")
            for hostname in drift_estimates.get("hostnames", []):
                # Every 10th sample is used for the fit.
                tgain, toffset = result_utils.model_remote_to_local(
                    drift_estimates["remote_timestamp"][hostname][::10],
                    drift_estimates["local_timestamp"][hostname][::10],
                )
                gain[hostname] = tgain
                offset[hostname] = toffset
                print("  ", repr(hostname), tgain, toffset)
            print()

            if do_nothing:
                return

            print("caching Kalman obj_ids...")
            obs_obj_ids = table_kobs.read(field="obj_id")
            fast_obs_obj_ids = utils.FastFinder(obs_obj_ids)
            print("finding unique obj_ids...")
            unique_obj_ids = numpy.unique(obs_obj_ids)
            print("(found %d)" % (len(unique_obj_ids),))
            unique_obj_ids = unique_obj_ids[unique_obj_ids >= start_obj_id]
            unique_obj_ids = unique_obj_ids[unique_obj_ids <= stop_obj_id]

            if obj_only is not None:
                unique_obj_ids = numpy.array(
                    [oid for oid in unique_obj_ids if oid in obj_only]
                )
                print("filtered to obj_only", obj_only)

            print("(will export %d)" % (len(unique_obj_ids),))
            print("finding 2d data for each obj_id...")
            # Entries of objects skipped below keep their initial 0.0.
            timestamp_time = numpy.zeros(unique_obj_ids.shape, dtype=numpy.float64)
            table_kobs_frame = table_kobs.read(field="frame")
            if len(table_kobs_frame) == 0:
                raise ValueError("no 3D data, cannot convert")
            assert numpy.max(table_kobs_frame) < 2 ** 63
            table_kobs_frame = table_kobs_frame.astype(numpy.int64)
            assert (
                table_kobs_frame.dtype == table_data2d_frames.dtype
            )  # otherwise very slow

            all_idxs = fast_obs_obj_ids.get_idx_of_equal(unique_obj_ids)
            for obj_id_enum, obj_id in enumerate(unique_obj_ids):
                idx0 = all_idxs[obj_id_enum]
                framenumber = table_kobs_frame[idx0]
                remote_timestamp = numpy.nan

                # Use the first 2D row recorded for that frame, if any.
                this_camn = None
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(framenumber)
                if len(frame_idxs):
                    frame_idx = frame_idxs[0]
                    this_camn = table_data2d_camns[frame_idx]
                    remote_timestamp = table_data2d_timestamps[frame_idx]

                if this_camn is None:
                    print(
                        "skipping frame %d (obj %d): no data2d_distorted data"
                        % (framenumber, obj_id)
                    )
                    continue

                cam_id = camn2cam_id[this_camn]
                try:
                    remote_hostname = cam_id2hostname(cam_id, h5_context)
                except ValueError as err:
                    # Exceptions have no .message attribute on Python 3;
                    # format the exception itself.
                    print("error getting hostname of cam: %s" % err)
                    continue
                if remote_hostname not in gain:
                    warnings.warn(
                        "no host %s in timestamp data. making up "
                        "data." % remote_hostname
                    )
                    gain[remote_hostname] = 1.0
                    offset[remote_hostname] = 0.0
                mainbrain_timestamp = (
                    remote_timestamp * gain[remote_hostname] + offset[remote_hostname]
                )  # find mainbrain timestamp

                timestamp_time[obj_id_enum] = mainbrain_timestamp

            extra_vars["obj_ids"] = unique_obj_ids
            extra_vars["timestamps"] = timestamp_time

            print("STAGE 2: running Kalman smoothing operation")
        elif do_nothing:
            # Honour do_nothing even when no timestamps were requested.
            return

        # also save the experiment data if present
        uuid = None
        try:
            table_experiment = h5_context.get_pytable_node(
                "experiment_info", from_2d_file=True
            )
        except tables.exceptions.NoSuchNodeError:
            pass
        else:
            try:
                uuid = table_experiment.read(field="uuid")
            except (KeyError, tables.exceptions.HDF5ExtError):
                pass
            else:
                extra_vars["experiment_uuid"] = uuid

        recording_header = h5_context.read_textlog_header_2d()
        recording_flydra_version = recording_header["flydra_version"]

        # -----------------------------------------------

        obj_ids = h5_context.get_unique_obj_ids()
        extra = h5_context.get_extra_info()
        smoothing_flydra_version = extra["header"]["flydra_version"]

        obj_ids = obj_ids[obj_ids >= start_obj_id]
        obj_ids = obj_ids[obj_ids <= stop_obj_id]

        if obj_only is not None:
            obj_ids = numpy.array(obj_only)
            print("filtered to obj_only", obj_ids)

        if frames_per_second is None:
            frames_per_second = fps

        # The model name stored in the file is always recorded in the
        # output.  It used to be bound only when dynamic_model_name was
        # None, so passing a model name raised NameError at the end.
        orig_dynamic_model_name = extra.get("dynamic_model_name", None)

        if dynamic_model_name is None:
            dynamic_model_name = orig_dynamic_model_name
            if dynamic_model_name is None:
                dynamic_model_name = dynamic_models.DEFAULT_MODEL
                warnings.warn(
                    'no dynamic model specified, using "%s"' % dynamic_model_name
                )
            else:
                print(
                    'detected file loaded with dynamic model "%s"' % dynamic_model_name
                )
            # Strip the "EKF " prefix for smoothing.
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]
            print('  for smoothing, will use dynamic model "%s"' % dynamic_model_name)

        allrows = []
        allqualrows = []
        # Set once any object fails; orientation quality is then dropped
        # for the whole export.
        failed_quality = False

        if show_progress:
            import progressbar

            class StringWidget(progressbar.Widget):
                """Progress-bar widget that shows the last string set."""

                def set_string(self, ts):
                    self.ts = ts

                def update(self, pbar):
                    if hasattr(self, "ts"):
                        return self.ts
                    else:
                        return ""

            string_widget = StringWidget()
            objs_per_sec_widget = progressbar.FileTransferSpeed(unit="obj_ids ")
            widgets = [
                string_widget,
                objs_per_sec_widget,
                progressbar.Percentage(),
                progressbar.Bar(),
                progressbar.ETA(),
            ]
            pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(obj_ids)).start()

        for i, obj_id in enumerate(obj_ids):
            if obj_id > stop_obj_id:
                break
            if show_progress:
                string_widget.set_string("[obj_id: % 5d]" % obj_id)
                pbar.update(i)
            if show_progress_json and i % 100 == 0:
                rough_percent_done = float(i) / len(obj_ids) * 100.0
                result_utils.do_json_progress(rough_percent_done)
            # Objects that cannot be smoothed are skipped, not fatal.
            try:
                rows = h5_context.load_data(
                    obj_id,
                    dynamic_model_name=dynamic_model_name,
                    frames_per_second=frames_per_second,
                    **kwargs
                )
            except core_analysis.DiscontiguousFramesError:
                warnings.warn(
                    "discontiguous frames smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue
            except core_analysis.NotEnoughDataToSmoothError:
                continue
            except numpy.linalg.linalg.LinAlgError:
                warnings.warn(
                    "linear algebra error smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue
            except core_analysis.CouldNotCalculateOrientationError:
                warnings.warn(
                    "orientation error smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue

            allrows.append(rows)
            try:
                qualrows = compute_ori_quality(
                    h5_context, rows["frame"], obj_id, smooth_len=0
                )
                allqualrows.append(qualrows)
            except ValueError:
                failed_quality = True
        if show_progress:
            pbar.finish()

        if not allrows:
            # numpy.concatenate would raise a less helpful ValueError.
            raise ValueError("no data could be loaded for any obj_id, cannot convert")

        allrows = numpy.concatenate(allrows)
        if not failed_quality:
            allqualrows = numpy.concatenate(allqualrows)
        else:
            allqualrows = None
        recarray = numpy.rec.array(allrows)

        smoothed_source = "kalman_estimates"

        flydra_analysis.analysis.flydra_analysis_convert_to_mat.do_it(
            rows=recarray,
            ignore_observations=True,
            newfilename=outfilename,
            extra_vars=extra_vars,
            orientation_quality=allqualrows,
            hdf5=hdf5,
            tzname=tzname,
            fps=fps,
            smoothed_source=smoothed_source,
            smoothed_data_filename=smoothed_data_filename,
            raw_data_filename=raw_data_filename,
            dynamic_model_name=orig_dynamic_model_name,
            recording_flydra_version=recording_flydra_version,
            smoothing_flydra_version=smoothing_flydra_version,
        )
        if show_progress_json:
            result_utils.do_json_progress(100)