def __init__(self, src_h5, dst_h5):
    """Load both Kalman files and build a frame index for the destination.

    Parameters
    ----------
    src_h5
        Kalman file containing the traces to be looked up.
    dst_h5
        Kalman file that is searched for the equivalent traces.

    Both are handed to ``initial_file_load`` of the global CachingAnalyzer.
    """
    analyzer = core_analysis.get_global_CachingAnalyzer()
    self.ca = analyzer
    self.src_h5 = src_h5
    self.dst_h5 = dst_h5

    # The source load result is discarded here; presumably this call only
    # primes the (global, caching) analyzer -- TODO confirm.
    analyzer.initial_file_load(src_h5)

    loaded = analyzer.initial_file_load(dst_h5)
    dst_obj_ids, dst_unique_obj_ids, _is_mat, _data_file, dst_extra = loaded

    self.dst_frames = dst_extra["frames"]
    self.dst_obj_ids = dst_obj_ids
    self.dst_unique_obj_ids = dst_unique_obj_ids
    # frame number -> row indices lookup, used to collect candidate obj_ids
    self.ff = utils.FastFinder(self.dst_frames)
def convert(
        infilename,
        outfilename,
        frames_per_second=None,
        save_timestamps=True,
        file_time_data=None,
        do_nothing=False,  # set to true to test for file existance
        start_obj_id=None,
        stop_obj_id=None,
        obj_only=None,
        dynamic_model_name=None,
        hdf5=False,
        show_progress=False,
        show_progress_json=False,
        **kwargs):
    """Recover mainbrain timestamps for the obj_ids in a Kalman .h5 file.

    When ``save_timestamps`` is true this runs "STAGE 1": for every obj_id
    (optionally limited to ``start_obj_id``..``stop_obj_id`` and
    ``obj_only``) it looks up one 3D frame of that obj_id, finds a
    ``data2d_distorted`` row with the same frame number, and converts that
    row's camera timestamp to the mainbrain clock using a per-host linear
    model (gain, offset) fitted from the drift estimates. The results are
    stored in ``extra_vars['obj_ids']`` and ``extra_vars['timestamps']``.

    If ``do_nothing`` is true, returns right after printing the per-host
    clock model (only reached when ``save_timestamps`` is true).

    Raises ValueError if the file has no 3D (ML_estimates) frames. Calls
    ``sys.exit(1)`` if no ``data2d_distorted`` table can be found.

    NOTE(review): this definition appears truncated -- ``outfilename``,
    ``frames_per_second``, ``dynamic_model_name``, ``hdf5``,
    ``show_progress``, ``show_progress_json`` and ``kwargs`` are not used in
    the visible body, and ``extra_vars`` is never consumed.
    """
    if start_obj_id is None:
        start_obj_id = -numpy.inf
    if stop_obj_id is None:
        stop_obj_id = numpy.inf

    smoothed_data_filename = os.path.split(infilename)[1]
    raw_data_filename = smoothed_data_filename

    ca = core_analysis.get_global_CachingAnalyzer()

    with ca.kalman_analysis_context(infilename,
                                    data2d_fname=file_time_data) as h5_context:
        extra_vars = {}
        tzname = None

        if save_timestamps:
            print 'STAGE 1: finding timestamps'
            table_kobs = h5_context.get_pytable_node('ML_estimates')

            tzname = h5_context.get_tzname0()
            fps = h5_context.get_fps()

            try:
                table_data2d = h5_context.get_pytable_node('data2d_distorted', from_2d_file=True)
            except tables.exceptions.NoSuchNodeError, err:
                print >> sys.stderr, "No timestamps in file. Either specify not to save timestamps ('--no-timestamps') or specify the original .h5 file with the timestamps ('--time-data=FILE2D')"
                sys.exit(1)

            print 'caching raw 2D data...',
            sys.stdout.flush()
            table_data2d_frames = table_data2d.read(field='frame')
            # frame numbers are cast to int64 below; make sure they fit
            assert numpy.max(table_data2d_frames) < 2**63
            table_data2d_frames = table_data2d_frames.astype(numpy.int64)
            #table_data2d_frames_find = fastsearch.binarysearch.BinarySearcher( table_data2d_frames )
            table_data2d_frames_find = utils.FastFinder(table_data2d_frames)
            table_data2d_camns = table_data2d.read(field='camn')
            table_data2d_timestamps = table_data2d.read(field='timestamp')
            print 'done'
            print '(cached index of %d frame values of dtype %s)' % (
                len(table_data2d_frames), str(table_data2d_frames.dtype))

            drift_estimates = h5_context.get_drift_estimates()
            camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

            # per-host linear clock model: mainbrain = remote * gain + offset
            gain = {}
            offset = {}
            print 'hostname time_gain time_offset'
            print '-------- --------- -----------'
            for i, hostname in enumerate(drift_estimates.get('hostnames', [])):
                # fit on every 10th drift sample
                tgain, toffset = result_utils.model_remote_to_local(
                    drift_estimates['remote_timestamp'][hostname][::10],
                    drift_estimates['local_timestamp'][hostname][::10])
                gain[hostname] = tgain
                offset[hostname] = toffset
                print ' ', repr(hostname), tgain, toffset
            print

            if do_nothing:
                return

            print 'caching Kalman obj_ids...'
            obs_obj_ids = table_kobs.read(field='obj_id')
            fast_obs_obj_ids = utils.FastFinder(obs_obj_ids)
            print 'finding unique obj_ids...'
            unique_obj_ids = numpy.unique(obs_obj_ids)
            print '(found %d)' % (len(unique_obj_ids), )
            unique_obj_ids = unique_obj_ids[unique_obj_ids >= start_obj_id]
            unique_obj_ids = unique_obj_ids[unique_obj_ids <= stop_obj_id]

            if obj_only is not None:
                unique_obj_ids = numpy.array(
                    [oid for oid in unique_obj_ids if oid in obj_only])
                print 'filtered to obj_only', obj_only

            print '(will export %d)' % (len(unique_obj_ids), )
            print 'finding 2d data for each obj_id...'
            # obj_ids skipped in the loop below keep a timestamp of 0.0
            timestamp_time = numpy.zeros(unique_obj_ids.shape, dtype=numpy.float64)
            table_kobs_frame = table_kobs.read(field='frame')
            if len(table_kobs_frame) == 0:
                raise ValueError('no 3D data, cannot convert')
            assert numpy.max(table_kobs_frame) < 2**63
            table_kobs_frame = table_kobs_frame.astype(numpy.int64)
            assert table_kobs_frame.dtype == table_data2d_frames.dtype  # otherwise very slow

            # one row index into ML_estimates per obj_id
            all_idxs = fast_obs_obj_ids.get_idx_of_equal(unique_obj_ids)
            for obj_id_enum, obj_id in enumerate(unique_obj_ids):
                idx0 = all_idxs[obj_id_enum]
                framenumber = table_kobs_frame[idx0]
                remote_timestamp = numpy.nan

                this_camn = None
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(
                    framenumber)
                if len(frame_idxs):
                    # use the first 2D row recorded for that frame
                    frame_idx = frame_idxs[0]
                    this_camn = table_data2d_camns[frame_idx]
                    remote_timestamp = table_data2d_timestamps[frame_idx]

                if this_camn is None:
                    print 'skipping frame %d (obj %d): no data2d_distorted data' % (
                        framenumber, obj_id)
                    continue

                cam_id = camn2cam_id[this_camn]
                try:
                    remote_hostname = cam_id2hostname(cam_id, h5_context)
                except ValueError, e:
                    print 'error getting hostname of cam: %s' % e.message
                    continue
                if remote_hostname not in gain:
                    # no drift model for this host: fall back to identity
                    warnings.warn('no host %s in timestamp data. making up '
                                  'data.' % remote_hostname)
                    gain[remote_hostname] = 1.0
                    offset[remote_hostname] = 0.0
                mainbrain_timestamp = remote_timestamp * gain[
                    remote_hostname] + offset[
                        remote_hostname]  # find mainbrain timestamp

                timestamp_time[obj_id_enum] = mainbrain_timestamp

            extra_vars['obj_ids'] = unique_obj_ids
            extra_vars['timestamps'] = timestamp_time

            print 'STAGE 2: running Kalman smoothing operation'
def find_equiv(self, src_obj_id, mean_distance_maximum=None):
    """find the obj_id in the dst file that corresponds to src_obj_id

    Candidates are all obj_ids in dst_h5 observed on at least one frame of
    the source trajectory. Each candidate is scored by the mean 3D distance
    between its unsmoothed positions and the source positions on shared
    frames. A candidate missing more than 10% of the source frames is
    rejected.

    arguments
    ---------
    src_obj_id : int
        The obj_id of the object in src_h5 to find.
    mean_distance_maximum : float or None
        The maximum average distance between points in dst and src.
        None means no distance limit.

    returns
    -------
    dst_obj_id : int or None
        The obj_id in dst_h5 that corresponds to the src_obj_id. None if
        no acceptable candidate exists:
        - no dst obj_id shares a source frame,
        - every candidate misses more than 10% of the source frames, or
        - the best mean distance exceeds mean_distance_maximum.

    raises
    ------
    ValueError
        If the source trajectory has fewer than 2 frames.
    """
    # get information from source to identify trace in dest
    src_rows = self.ca.load_data(src_obj_id, self.src_h5,
                                 use_kalman_smoothing=False)
    src_frame = src_rows["frame"]
    if len(src_frame) < 2:
        raise ValueError("Can only find equivalent obj_id if "
                         "2 or more frames present")
    src_X = np.vstack((src_rows["x"], src_rows["y"], src_rows["z"]))

    # every dst obj_id seen on any source frame is a candidate
    candidate_obj_id = set()
    for f in src_frame:
        idxs = self.ff.get_idxs_of_equal(f)
        candidate_obj_id.update(self.dst_obj_ids[idxs])
    candidate_obj_id = list(candidate_obj_id)
    if not candidate_obj_id:
        # np.argmin() below would raise on an empty sequence
        return None  # could not find answer

    n_total = len(src_frame)
    all_src_idxs = np.arange(n_total)
    error = []
    for obj_id in candidate_obj_id:
        # get array for each candidate obj_id in destination
        dst_rows = self.ca.load_data(obj_id, self.dst_h5,
                                     use_kalman_smoothing=False)
        dst_frame = dst_rows["frame"]
        dst_X = np.vstack((dst_rows["x"], dst_rows["y"], dst_rows["z"]))
        dst_ff = utils.FastFinder(dst_frame)

        # indices into destination array for each frame of source;
        # -1 marks a source frame absent from dest
        dst_idxs = dst_ff.get_idx_of_equal(src_frame, missing_ok=1)
        assert len(dst_idxs) == n_total
        missing_cond = dst_idxs == -1
        present_cond = ~missing_cond

        frac_missing = np.sum(missing_cond) / float(n_total)  # 0 = none missing, 1 = all
        if frac_missing > 0.1:
            error.append(np.inf)
            continue

        diff = (src_X[:, all_src_idxs[present_cond]] -
                dst_X[:, dst_idxs[present_cond]])
        dist = np.sqrt(np.sum(diff ** 2, axis=0))
        error.append(np.mean(dist))

    idx = int(np.argmin(error))
    best_error = error[idx]
    if not np.isfinite(best_error):
        return None  # could not find answer
    if mean_distance_maximum is not None and best_error > mean_distance_maximum:
        return None
    return candidate_obj_id[idx]
def _version_tuple(version_string):
    """Return the leading dotted integer components of *version_string*.

    '0.4.45' -> (0, 4, 45); '0.6.1+git' -> (0, 6, 1); 'unknown' -> ().
    Parsing stops at the first component without a leading digit and right
    after the first component that carries a non-digit suffix.
    """
    parts = []
    for piece in str(version_string).split('.'):
        digits = ''
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break
    return tuple(parts)


def _version_at_least(version_string, minimum_string):
    """Return True if *version_string* is numerically >= *minimum_string*.

    Falls back to plain string comparison (the historical behaviour) when
    either string has no parseable numeric prefix.
    """
    have = _version_tuple(version_string)
    want = _version_tuple(minimum_string)
    if not have or not want:
        return version_string >= minimum_string
    return have >= want


def iterate_frames(h5_filename,
                   ufmf_fnames,  # or fmfs
                   white_background=False,
                   max_n_frames=None,
                   start=None,
                   stop=None,
                   rgb8_if_color=False,
                   movie_cam_ids=None,
                   camn2cam_id=None,
                   ):
    """yield frame-by-frame data

    Matches the frames of the given .ufmf/.fmf movies to the
    ``data2d_distorted`` rows of *h5_filename* by timestamp.

    Parameters
    ----------
    h5_filename
        The .h5 file providing ``data2d_distorted``.
    ufmf_fnames
        Sequence of movie filenames. Names ending in ``.fmf`` are opened
        with ``fmf_mod.FlyMovie``; all others with
        ``ufmf_mod.FlyMovieEmulator``. The sequence is not modified.
    white_background
        Passed to ``ufmf_mod.FlyMovieEmulator``.
    max_n_frames
        If given, at most this many frames are considered.
    start, stop
        If given, only frame numbers >= *start* and <= *stop* are
        considered.
    rgb8_if_color
        Convert color-coded images with ``imops.to_rgb8``.
    movie_cam_ids
        Optional cam_id per entry of *ufmf_fnames*. Otherwise the cam_id is
        derived from each filename.
    camn2cam_id
        Optional camn -> cam_id mapping. Otherwise it is read from the .h5.

    Yields
    ------
    ``(per_frame_dict, frame)`` where *per_frame_dict* contains:
    - one entry per movie filename that has data for *frame*. Each entry
      holds 'image', 'cam_id', 'camn', 'timestamp',
      'cam_received_timestamp' and 'ufmf_frame_timestamp', plus any
      "more" fields from the movie;
    - 'tracker_data': the .h5 rows for *frame*;
    - 'global_data': a dict with 'width_heights' per cam_id.
    """
    # First pass over .ufmf files: get intersection of timestamps
    first_ufmf_ts = -np.inf
    last_ufmf_ts = np.inf
    ufmfs = {}
    cam_ids = []
    global_data = {'width_heights': {}}
    for movie_idx, ufmf_fname in enumerate(ufmf_fnames):
        if movie_cam_ids is not None:
            cam_id = movie_cam_ids[movie_idx]
        else:
            cam_id = get_cam_id_from_ufmf_fname(ufmf_fname)
        cam_ids.append(cam_id)
        extra = {}
        if ufmf_fname.lower().endswith('.fmf'):
            ufmf = fmf_mod.FlyMovie(ufmf_fname)
            bg_fmf_filename = os.path.splitext(ufmf_fname)[0] + '_mean.fmf'
            if os.path.exists(bg_fmf_filename):
                extra['bg_fmf'] = fmf_mod.FlyMovie(bg_fmf_filename)
                extra['bg_tss'] = extra['bg_fmf'].get_all_timestamps()
                extra['bg_fmf'].seek(0)
        else:
            ufmf = ufmf_mod.FlyMovieEmulator(ufmf_fname,
                                             white_background=white_background)

        global_data['width_heights'][cam_id] = (ufmf.get_width(),
                                                ufmf.get_height())
        tss = ufmf.get_all_timestamps()
        ufmf.seek(0)
        ufmfs[ufmf_fname] = (ufmf, cam_id, tss, extra)
        min_ts = np.min(tss)
        max_ts = np.max(tss)
        if min_ts > first_ufmf_ts:
            first_ufmf_ts = min_ts
        if max_ts < last_ufmf_ts:
            last_ufmf_ts = max_ts

    assert first_ufmf_ts < last_ufmf_ts, ".ufmf files don't all overlap in time"

    # sort a local copy: do not reorder the caller's list
    ufmf_fnames = sorted(ufmf_fnames)
    cam_ids.sort()

    with open_file_safe(h5_filename, mode='r') as h5:
        if camn2cam_id is None:
            camn2cam_id, cam_id2camns = result_utils.get_caminfo_dicts(h5)
        parsed = result_utils.read_textlog_header(h5)
        flydra_version = parsed.get('flydra_version', None)
        # numeric (not lexicographic) comparison: '0.4.5' < '0.4.45' < '0.10'
        if flydra_version is not None and _version_at_least(flydra_version,
                                                             '0.4.45'):
            # camnode.py saved timestamps into .ufmf file given by
            # time.time() (camn_receive_timestamp). Compare with
            # mainbrain's data2d_distorted column
            # 'cam_received_timestamp'.
            old_camera_timestamp_source = False
            timestamp_name = 'cam_received_timestamp'
        else:
            # camnode.py saved timestamps into .ufmf file given by
            # camera driver. Compare with mainbrain's data2d_distorted
            # column 'timestamp'.
            old_camera_timestamp_source = True
            timestamp_name = 'timestamp'

        # full in-memory copy; the file is not needed after this
        h5_data = h5.root.data2d_distorted[:]

    # narrow search to local region of .h5
    cond = ((first_ufmf_ts <= h5_data[timestamp_name]) &
            (h5_data[timestamp_name] <= last_ufmf_ts))
    narrow_h5_data = h5_data[cond]

    narrow_camns = narrow_h5_data['camn']
    narrow_timestamps = narrow_h5_data[timestamp_name]

    # Find the camn for each .ufmf file
    cam_id2camn = {}
    for cam_id in cam_ids:
        cam_id_camn_already_found = False
        for ufmf_fname in ufmfs.keys():
            (ufmf, test_cam_id, tss, extra) = ufmfs[ufmf_fname]
            if cam_id != test_cam_id:
                continue
            assert not cam_id_camn_already_found
            cam_id_camn_already_found = True

            umin = np.min(tss)
            umax = np.max(tss)
            cond = (umin <= narrow_timestamps) & (narrow_timestamps <= umax)
            ucamns = np.unique(narrow_camns[cond])
            camns = [camn for camn in ucamns if camn2cam_id[camn] == cam_id]

            assert len(camns) < 2, "can't handle multiple camns per cam_id"
            if len(camns):
                cam_id2camn[cam_id] = camns[0]

    ff = utils.FastFinder(narrow_h5_data['frame'])
    unique_frames = np.unique(narrow_h5_data['frame'])  # already sorted
    if start is not None:
        unique_frames = unique_frames[unique_frames >= start]
    if stop is not None:
        unique_frames = unique_frames[unique_frames <= stop]

    if max_n_frames is not None:
        unique_frames = unique_frames[:max_n_frames]
    for frame in unique_frames:
        narrow_idxs = ff.get_idxs_of_equal(frame)

        # trim data under consideration to just this frame
        this_h5_data = narrow_h5_data[narrow_idxs]
        this_camns = this_h5_data['camn']
        this_tss = this_h5_data[timestamp_name]

        # a couple more checks
        if np.any(this_tss < first_ufmf_ts):
            continue
        if np.any(this_tss >= last_ufmf_ts):
            break

        per_frame_dict = {}
        for ufmf_fname in ufmf_fnames:
            ufmf, cam_id, tss, extra = ufmfs[ufmf_fname]
            if cam_id not in cam_id2camn:
                continue
            camn = cam_id2camn[cam_id]
            this_camn_cond = this_camns == camn
            this_cam_h5_data = this_h5_data[this_camn_cond]
            this_camn_tss = this_cam_h5_data[timestamp_name]
            if not len(this_camn_tss):
                # no h5 data for this cam_id at this frame
                continue
            this_camn_ts = np.unique(this_camn_tss)
            assert len(this_camn_ts) == 1
            this_camn_ts = this_camn_ts[0]

            is_real_ufmf = isinstance(ufmf, ufmf_mod.FlyMovieEmulator)

            # optimistic: get next frame. it's probably the one we want
            try:
                if is_real_ufmf:
                    image, image_ts, more = ufmf.get_next_frame(_return_more=True)
                else:
                    image, image_ts = ufmf.get_next_frame()
                    more = fill_more_for(extra, image_ts)
            except ufmf_mod.NoMoreFramesException:
                image_ts = None
            if this_camn_ts != image_ts:
                # It was not the frame we wanted. Find it.
                ufmf_frame_idxs = np.nonzero(tss == this_camn_ts)[0]
                if (len(ufmf_frame_idxs) == 0 and
                        old_camera_timestamp_source):
                    warnings.warn(
                        'low-precision timestamp comparison in '
                        'use due to outdated .ufmf timestamp '
                        'saving.')
                    # 2.5 msec precision required
                    ufmf_frame_idxs = np.nonzero(
                        abs(tss - this_camn_ts) < 0.0025)[0]
                assert len(ufmf_frame_idxs) == 1
                ufmf_frame_no = ufmf_frame_idxs[0]
                if is_real_ufmf:
                    image, image_ts, more = ufmf.get_frame(ufmf_frame_no,
                                                           _return_more=True)
                else:
                    image, image_ts = ufmf.get_frame(ufmf_frame_no)
                    more = fill_more_for(extra, image_ts)

                del ufmf_frame_no, ufmf_frame_idxs
            coding = ufmf.get_format()
            if imops.is_coding_color(coding):
                if rgb8_if_color:
                    image = imops.to_rgb8(coding, image)
                else:
                    warnings.warn('color image not converted to color')
            per_frame_dict[ufmf_fname] = {
                'image': image,
                'cam_id': cam_id,
                'camn': camn,
                'timestamp': this_cam_h5_data['timestamp'][0],
                'cam_received_timestamp':
                    this_cam_h5_data['cam_received_timestamp'][0],
                'ufmf_frame_timestamp': this_cam_h5_data[timestamp_name][0],
            }
            if more is not None:
                per_frame_dict[ufmf_fname].update(more)
        per_frame_dict['tracker_data'] = this_h5_data
        per_frame_dict['global_data'] = global_data  # on every iteration, pass our global data

        yield (per_frame_dict, frame)
def plot_timeseries(subplot=None, options=None): kalman_filename = options.kalman_filename if not hasattr(options, 'frames'): options.frames = False if not hasattr(options, 'show_landing'): options.show_landing = False if not hasattr(options, 'unicolor'): options.unicolor = False if not hasattr(options, 'show_obj_id'): options.show_obj_id = True if not hasattr(options, 'show_track_ends'): options.show_track_ends = False start = options.start stop = options.stop obj_only = options.obj_only fps = options.fps dynamic_model = options.dynamic_model use_kalman_smoothing = options.use_kalman_smoothing if not use_kalman_smoothing: if (dynamic_model is not None): print >> sys.stderr, ( 'WARNING: disabling Kalman smoothing ' '(--disable-kalman-smoothing) is incompatable ' 'with setting dynamic model options (--dynamic-model)') ca = core_analysis.get_global_CachingAnalyzer() if kalman_filename is None: raise ValueError('No kalman_filename given. Nothing to do.') m = hashlib.md5() m.update(open(kalman_filename, mode='rb').read()) actual_md5 = m.hexdigest() (obj_ids, use_obj_ids, is_mat_file, data_file, extra) = ca.initial_file_load(kalman_filename) print 'opened kalman file %s %s, %d obj_ids' % ( kalman_filename, actual_md5, len(use_obj_ids)) if 'frames' in extra: if (start is not None) or (stop is not None): valid_frames = np.ones((len(extra['frames']), ), dtype=np.bool) if start is not None: valid_frames &= extra['frames'] >= start if stop is not None: valid_frames &= extra['frames'] <= stop this_use_obj_ids = np.unique(obj_ids[valid_frames]) use_obj_ids = list(set(use_obj_ids).intersection(this_use_obj_ids)) include_obj_ids = None exclude_obj_ids = None do_fuse = False if options.stim_xml: file_timestamp = data_file.filename[4:19] fanout = xml_stimulus.xml_fanout_from_filename(options.stim_xml) include_obj_ids, exclude_obj_ids = fanout.get_obj_ids_for_timestamp( timestamp_string=file_timestamp) walking_start_stops = fanout.get_walking_start_stops_for_timestamp( 
timestamp_string=file_timestamp) if include_obj_ids is not None: use_obj_ids = include_obj_ids if exclude_obj_ids is not None: use_obj_ids = list(set(use_obj_ids).difference(exclude_obj_ids)) if options.fuse: do_fuse = True else: walking_start_stops = [] if dynamic_model is None: dynamic_model = extra['dynamic_model_name'] print 'detected file loaded with dynamic model "%s"' % dynamic_model if dynamic_model.startswith('EKF '): dynamic_model = dynamic_model[4:] print ' for smoothing, will use dynamic model "%s"' % dynamic_model if not is_mat_file: mat_data = None if fps is None: fps = result_utils.get_fps(data_file, fail_on_error=False) if fps is None: fps = 100.0 import warnings warnings.warn('Setting fps to default value of %f' % fps) tz = result_utils.get_tz(data_file) dt = 1.0 / fps all_vels = [] if obj_only is not None: use_obj_ids = [i for i in use_obj_ids if i in obj_only] allX = {} frame0 = None line2obj_id = {} Xz_all = [] fuse_did_once = False if not hasattr(options, 'timestamp_file'): options.timestamp_file = None if not hasattr(options, 'ori_qual'): options.ori_qual = None if options.timestamp_file is not None: h5 = tables.open_file(options.timestamp_file, mode='r') print 'reading timestamps and frames' table_data2d_frames = h5.root.data2d_distorted.read(field='frame') table_data2d_timestamps = h5.root.data2d_distorted.read( field='timestamp') print 'done' h5.close() table_data2d_frames_find = utils.FastFinder(table_data2d_frames) if len(use_obj_ids) == 0: print 'No obj_ids to plot, quitting' sys.exit(0) time0 = 0.0 # set default value for obj_id in use_obj_ids: if not do_fuse: try: kalman_rows = ca.load_data( obj_id, data_file, use_kalman_smoothing=use_kalman_smoothing, dynamic_model_name=dynamic_model, return_smoothed_directions=options.smooth_orientations, frames_per_second=fps, up_dir=options.up_dir, min_ori_quality_required=options.ori_qual, ) except core_analysis.ObjectIDDataError: continue #kobs_rows = ca.load_dynamics_free_MLE_position( obj_id, 
data_file ) else: if options.show_3d_orientations: raise NotImplementedError('orientation data is not supported ' 'when fusing obj_ids') if fuse_did_once: break fuse_did_once = True kalman_rows = flydra_analysis.a2.flypos.fuse_obj_ids( use_obj_ids, data_file, dynamic_model_name=dynamic_model, frames_per_second=fps) frame = kalman_rows['frame'] if (start is not None) or (stop is not None): valid_cond = numpy.ones(frame.shape, dtype=numpy.bool) if start is not None: valid_cond = valid_cond & (frame >= start) if stop is not None: valid_cond = valid_cond & (frame <= stop) kalman_rows = kalman_rows[valid_cond] if not len(kalman_rows): continue walking_and_flying_kalman_rows = kalman_rows # preserve original data for flystate in ['flying', 'walking']: frame = walking_and_flying_kalman_rows['frame'] # restore if flystate == 'flying': # assume flying unless we're told it's walking state_cond = numpy.ones(frame.shape, dtype=numpy.bool) else: state_cond = numpy.zeros(frame.shape, dtype=numpy.bool) if len(walking_start_stops): for walkstart, walkstop in walking_start_stops: frame = walking_and_flying_kalman_rows['frame'] # restore # handle each bout of walking walking_bout = numpy.ones(frame.shape, dtype=numpy.bool) if walkstart is not None: walking_bout &= (frame >= walkstart) if walkstop is not None: walking_bout &= (frame <= walkstop) if flystate == 'flying': state_cond &= ~walking_bout else: state_cond |= walking_bout kalman_rows = np.take(walking_and_flying_kalman_rows, np.nonzero(state_cond)[0]) assert len(kalman_rows) == np.sum(state_cond) frame = kalman_rows['frame'] if frame0 is None: frame0 = int(frame[0]) time0 = 0.0 if options.timestamp_file is not None: frame_idxs = table_data2d_frames_find.get_idxs_of_equal(frame0) if len(frame_idxs): time0 = table_data2d_timestamps[frame_idxs[0]] else: raise ValueError( 'could not fine frame %d in timestamp file' % frame0) Xx = kalman_rows['x'] Xy = kalman_rows['y'] Xz = kalman_rows['z'] Dx = Dy = Dz = None if 
options.smooth_orientations: Dx = kalman_rows['dir_x'] Dy = kalman_rows['dir_y'] Dz = kalman_rows['dir_z'] elif 'rawdir_x' in kalman_rows.dtype.fields: Dx = kalman_rows['rawdir_x'] Dy = kalman_rows['rawdir_y'] Dz = kalman_rows['rawdir_z'] if not options.frames: f2t = Frames2Time(frame0, fps, time0) else: def identity(x): return x f2t = identity kws = { 'linewidth': 2, 'picker': 5, } if options.unicolor: kws['color'] = 'k' line = None if 'frame' in subplot: subplot['frame'].plot(f2t(frame), frame) if 'P55' in subplot: subplot['P55'].plot(f2t(frame), kalman_rows['P55']) if 'x' in subplot: line, = subplot['x'].plot(f2t(frame), Xx, label='obj %d (%s)' % (obj_id, flystate), **kws) line2obj_id[line] = obj_id kws['color'] = line.get_color() if 'y' in subplot: line, = subplot['y'].plot(f2t(frame), Xy, label='obj %d (%s)' % (obj_id, flystate), **kws) line2obj_id[line] = obj_id kws['color'] = line.get_color() if 'z' in subplot: frame_data = numpy.ma.getdata( frame) # works if frame is masked or not # plot landing time if options.show_landing: if flystate == 'flying': # only do this once for walkstart, walkstop in walking_start_stops: if walkstart in frame_data: landing_dix = numpy.nonzero( frame_data == walkstart)[0][0] subplot['z'].plot([f2t(walkstart)], [Xz.data[landing_dix]], 'rD', ms=10, label='landing') if options.show_track_ends: if flystate == 'flying': # only do this once subplot['z'].plot(f2t([frame_data[0], frame_data[-1]]), [ numpy.ma.getdata(Xz)[0], numpy.ma.getdata(Xz)[-1] ], 'cd', ms=6, label='track end') line, = subplot['z'].plot(f2t(frame), Xz, label='obj %d (%s)' % (obj_id, flystate), **kws) kws['color'] = line.get_color() line2obj_id[line] = obj_id if flystate == 'flying': # only do this once if options.show_obj_id: subplot['z'].text(f2t(frame_data[0]), numpy.ma.getdata(Xz)[0], '%d' % (obj_id, )) line2obj_id[line] = obj_id if flystate == 'flying': Xz_all.append(np.ma.array(Xz).compressed()) #bins = np.linspace(0,.8,30) #print 'Xz.shape',Xz.shape 
#pylab.hist(Xz, bins=bins) for (dir_var, Dd) in [('dx', Dx), ('dy', Dy), ('dz', Dz)]: if dir_var in subplot: line, = subplot[dir_var].plot(f2t(frame), Dd, label='obj %d (%s)' % (obj_id, flystate), **kws) line2obj_id[line] = obj_id kws['color'] = line.get_color() if numpy.__version__ >= '1.2.0': X = numpy.ma.array((Xx, Xy, Xz)) else: # See http://scipy.org/scipy/numpy/ticket/820 X = numpy.ma.vstack( (Xx[numpy.newaxis, :], Xy[numpy.newaxis, :], Xz[numpy.newaxis, :])) dist_central_diff = (X[:, 2:] - X[:, :-2]) vel_central_diff = dist_central_diff / (2 * dt) vel2mag = numpy.ma.sqrt(numpy.ma.sum(vel_central_diff**2, axis=0)) xy_vel2mag = numpy.ma.sqrt( numpy.ma.sum(vel_central_diff[:2, :]**2, axis=0)) frames2 = frame[1:-1] accel4mag = (vel2mag[2:] - vel2mag[:-2]) / (2 * dt) frames4 = frames2[1:-1] if 'vel' in subplot: line, = subplot['vel'].plot(f2t(frames2), vel2mag, label='obj %d (%s)' % (obj_id, flystate), **kws) line2obj_id[line] = obj_id kws['color'] = line.get_color() if 'xy_vel' in subplot: line, = subplot['xy_vel'].plot(f2t(frames2), xy_vel2mag, label='obj %d (%s)' % (obj_id, flystate), **kws) line2obj_id[line] = obj_id kws['color'] = line.get_color() if len(accel4mag.compressed()) and 'accel' in subplot: line, = subplot['accel'].plot(f2t(frames4), accel4mag, label='obj %d (%s)' % (obj_id, flystate), **kws) line2obj_id[line] = obj_id kws['color'] = line.get_color() if flystate == 'flying': valid_vel2mag = vel2mag.compressed() all_vels.append(valid_vel2mag) if len(all_vels): all_vels = numpy.hstack(all_vels) else: all_vels = numpy.array([], dtype=float) if 1: cond = all_vels < 2.0 if numpy.ma.sum(cond) != len(all_vels): all_vels = all_vels[cond] import warnings warnings.warn('clipping all velocities > 2.0 m/s') if not options.frames: xlabel = 'time (s)' else: xlabel = 'frame' for ax in subplot.itervalues(): ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d")) ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%s")) fixup_ax = 
FixupAxesWithTimeZone(tz).fixup_ax if 'frame' in subplot: if time0 != 0.0: fixup_ax(subplot['frame']) else: subplot['frame'].set_xlabel(xlabel) if 'x' in subplot: subplot['x'].set_ylim([-1, 1]) subplot['x'].set_ylabel(r'x (m)') if time0 != 0.0: fixup_ax(subplot['x']) else: subplot['x'].set_xlabel(xlabel) if 'y' in subplot: subplot['y'].set_ylim([-0.5, 1.5]) subplot['y'].set_ylabel(r'y (m)') if time0 != 0.0: fixup_ax(subplot['y']) else: subplot['y'].set_xlabel(xlabel) max_z = None if options.stim_xml: file_timestamp = options.kalman_filename[4:19] stim_xml = xml_stimulus.xml_stimulus_from_filename( options.stim_xml, timestamp_string=file_timestamp) post_max_zs = [] for post_num, post in enumerate(stim_xml.iterate_posts()): post_max_zs.append(max(post['verts'][0][2], post['verts'][1][2])) # max post height if len(post_max_zs): max_z = min(post_max_zs) # take shortest of posts if 'z' in subplot: subplot['z'].set_ylim([0, 1]) subplot['z'].set_ylabel(r'z (m)') if max_z is not None: subplot['z'].axhline(max_z, color='m') if time0 != 0.0: fixup_ax(subplot['z']) else: subplot['z'].set_xlabel(xlabel) for dir_var in ['dx', 'dy', 'dz']: if dir_var in subplot: subplot[dir_var].set_ylabel(dir_var) if time0 != 0.0: fixup_ax(subplot[dir_var]) else: subplot[dir_var].set_xlabel(xlabel) if 'z_hist' in subplot: # and flystate=='flying': Xz_all = np.hstack(Xz_all) bins = np.linspace(0, .8, 30) ax = subplot['z_hist'] ax.hist(Xz_all, bins=bins, orientation='horizontal') ax.set_xticks([]) ax.set_yticks([]) xlim = tuple(ax.get_xlim()) # matplotlib 0.98.3 returned np.array view ax.set_xlim((xlim[1], xlim[0])) ax.axhline(max_z, color='m') if 'vel' in subplot: subplot['vel'].set_ylim([0, 2]) subplot['vel'].set_ylabel(r'vel (m/s)') subplot['vel'].set_xlabel(xlabel) if time0 != 0.0: fixup_ax(subplot['vel']) else: subplot['vel'].set_xlabel(xlabel) if 'xy_vel' in subplot: #subplot['xy_vel'].set_ylim([0,2]) subplot['xy_vel'].set_ylabel(r'horiz vel (m/s)') subplot['xy_vel'].set_xlabel(xlabel) if 
time0 != 0.0: fixup_ax(subplot['xy_vel']) else: subplot['xy_vel'].set_xlabel(xlabel) if 'accel' in subplot: subplot['accel'].set_ylabel(r'acceleration (m/(s^2))') subplot['accel'].set_xlabel(xlabel) if time0 != 0.0: fixup_ax(subplot['accel']) else: subplot['accel'].set_xlabel(xlabel) if 'vel_hist' in subplot: ax = subplot['vel_hist'] bins = numpy.linspace(0, 2, 50) ax.set_title('excluding walking') pdf, bins, patches = ax.hist(all_vels, bins=bins, normed=True) ax.set_xlim(0, 2) ax.set_ylabel('probability density') ax.set_xlabel('velocity (m/s)') return line2obj_id
def convert(
    infilename,
    outfilename,
    frames_per_second=None,
    save_timestamps=True,
    file_time_data=None,
    do_nothing=False,  # set to true to test for file existance
    start_obj_id=None,
    stop_obj_id=None,
    obj_only=None,
    dynamic_model_name=None,
    hdf5=False,
    show_progress=False,
    show_progress_json=False,
    **kwargs
):
    """Smooth the trajectories in a Kalman .h5 file and write them out.

    STAGE 1 (only if ``save_timestamps``):
        For each obj_id, take one 3D frame and find a ``data2d_distorted``
        row with that frame number. Convert the row's camera timestamp to
        the mainbrain clock using a per-host linear model fitted from the
        drift estimates. Results go into ``extra_vars['obj_ids']`` /
        ``extra_vars['timestamps']``. With ``do_nothing`` the function
        returns after printing the clock model.

    STAGE 2:
        Smooth every selected obj_id with ``h5_context.load_data`` (using
        *dynamic_model_name* and *frames_per_second*, defaulting to the
        file's values) and compute orientation quality. Then write
        everything to *outfilename* via
        ``flydra_analysis_convert_to_mat.do_it``.

    obj_ids are limited to ``start_obj_id``..``stop_obj_id``. If *obj_only*
    is given, it replaces the stage-2 obj_id list. Extra keyword arguments
    are passed to ``h5_context.load_data``.

    Raises ValueError if there is no 3D data, or if no obj_id could be
    smoothed. Calls ``sys.exit(1)`` when timestamps are requested but no
    ``data2d_distorted`` table exists.
    """
    if start_obj_id is None:
        start_obj_id = -numpy.inf
    if stop_obj_id is None:
        stop_obj_id = numpy.inf

    smoothed_data_filename = os.path.split(infilename)[1]
    raw_data_filename = smoothed_data_filename

    ca = core_analysis.get_global_CachingAnalyzer()

    with ca.kalman_analysis_context(
        infilename, data2d_fname=file_time_data
    ) as h5_context:
        extra_vars = {}
        tzname = None
        fps = None  # set in STAGE 1; fallback below when timestamps are skipped

        if save_timestamps:
            print("STAGE 1: finding timestamps")
            table_kobs = h5_context.get_pytable_node("ML_estimates")

            tzname = h5_context.get_tzname0()
            fps = h5_context.get_fps()

            try:
                table_data2d = h5_context.get_pytable_node(
                    "data2d_distorted", from_2d_file=True
                )
            except tables.exceptions.NoSuchNodeError:
                print(
                    "No timestamps in file. Either specify not to save timestamps ('--no-timestamps') or specify the original .h5 file with the timestamps ('--time-data=FILE2D')",
                    file=sys.stderr,
                )
                sys.exit(1)

            print("caching raw 2D data...", end=" ")
            sys.stdout.flush()
            table_data2d_frames = table_data2d.read(field="frame")
            # frame numbers are cast to int64 below; make sure they fit
            assert numpy.max(table_data2d_frames) < 2 ** 63
            table_data2d_frames = table_data2d_frames.astype(numpy.int64)
            table_data2d_frames_find = utils.FastFinder(table_data2d_frames)
            table_data2d_camns = table_data2d.read(field="camn")
            table_data2d_timestamps = table_data2d.read(field="timestamp")
            print("done")
            print(
                "(cached index of %d frame values of dtype %s)"
                % (len(table_data2d_frames), str(table_data2d_frames.dtype))
            )

            drift_estimates = h5_context.get_drift_estimates()
            camn2cam_id, cam_id2camns = h5_context.get_caminfo_dicts()

            # per-host linear clock model: mainbrain = remote * gain + offset
            gain = {}
            offset = {}
            print("hostname time_gain time_offset")
            print("-------- --------- -----------")
            for i, hostname in enumerate(drift_estimates.get("hostnames", [])):
                # fit on every 10th drift sample
                tgain, toffset = result_utils.model_remote_to_local(
                    drift_estimates["remote_timestamp"][hostname][::10],
                    drift_estimates["local_timestamp"][hostname][::10],
                )
                gain[hostname] = tgain
                offset[hostname] = toffset
                print(" ", repr(hostname), tgain, toffset)
            print()

            if do_nothing:
                return

            print("caching Kalman obj_ids...")
            obs_obj_ids = table_kobs.read(field="obj_id")
            fast_obs_obj_ids = utils.FastFinder(obs_obj_ids)
            print("finding unique obj_ids...")
            unique_obj_ids = numpy.unique(obs_obj_ids)
            print("(found %d)" % (len(unique_obj_ids),))
            unique_obj_ids = unique_obj_ids[unique_obj_ids >= start_obj_id]
            unique_obj_ids = unique_obj_ids[unique_obj_ids <= stop_obj_id]

            if obj_only is not None:
                unique_obj_ids = numpy.array(
                    [oid for oid in unique_obj_ids if oid in obj_only]
                )
                print("filtered to obj_only", obj_only)

            print("(will export %d)" % (len(unique_obj_ids),))
            print("finding 2d data for each obj_id...")
            # obj_ids skipped in the loop below keep a timestamp of 0.0
            timestamp_time = numpy.zeros(unique_obj_ids.shape, dtype=numpy.float64)
            table_kobs_frame = table_kobs.read(field="frame")
            if len(table_kobs_frame) == 0:
                raise ValueError("no 3D data, cannot convert")
            assert numpy.max(table_kobs_frame) < 2 ** 63
            table_kobs_frame = table_kobs_frame.astype(numpy.int64)
            assert (
                table_kobs_frame.dtype == table_data2d_frames.dtype
            )  # otherwise very slow

            # one row index into ML_estimates per obj_id
            all_idxs = fast_obs_obj_ids.get_idx_of_equal(unique_obj_ids)
            for obj_id_enum, obj_id in enumerate(unique_obj_ids):
                idx0 = all_idxs[obj_id_enum]
                framenumber = table_kobs_frame[idx0]
                remote_timestamp = numpy.nan

                this_camn = None
                frame_idxs = table_data2d_frames_find.get_idxs_of_equal(framenumber)
                if len(frame_idxs):
                    # use the first 2D row recorded for that frame
                    frame_idx = frame_idxs[0]
                    this_camn = table_data2d_camns[frame_idx]
                    remote_timestamp = table_data2d_timestamps[frame_idx]

                if this_camn is None:
                    print(
                        "skipping frame %d (obj %d): no data2d_distorted data"
                        % (framenumber, obj_id)
                    )
                    continue

                cam_id = camn2cam_id[this_camn]
                try:
                    remote_hostname = cam_id2hostname(cam_id, h5_context)
                except ValueError as e:
                    # Python 3 exceptions have no `.message` attribute
                    print("error getting hostname of cam: %s" % e)
                    continue
                if remote_hostname not in gain:
                    # no drift model for this host: fall back to identity
                    warnings.warn(
                        "no host %s in timestamp data. making up "
                        "data." % remote_hostname
                    )
                    gain[remote_hostname] = 1.0
                    offset[remote_hostname] = 0.0
                mainbrain_timestamp = (
                    remote_timestamp * gain[remote_hostname] + offset[remote_hostname]
                )  # find mainbrain timestamp

                timestamp_time[obj_id_enum] = mainbrain_timestamp

            extra_vars["obj_ids"] = unique_obj_ids
            extra_vars["timestamps"] = timestamp_time

            print("STAGE 2: running Kalman smoothing operation")

        # also save the experiment data if present
        uuid = None
        try:
            table_experiment = h5_context.get_pytable_node(
                "experiment_info", from_2d_file=True
            )
        except tables.exceptions.NoSuchNodeError:
            pass
        else:
            try:
                uuid = table_experiment.read(field="uuid")
            except (KeyError, tables.exceptions.HDF5ExtError):
                pass
            else:
                extra_vars["experiment_uuid"] = uuid

        recording_header = h5_context.read_textlog_header_2d()
        recording_flydra_version = recording_header["flydra_version"]

        # -----------------------------------------------

        obj_ids = h5_context.get_unique_obj_ids()
        smoothing_flydra_version = h5_context.get_extra_info()["header"][
            "flydra_version"
        ]

        obj_ids = obj_ids[obj_ids >= start_obj_id]
        obj_ids = obj_ids[obj_ids <= stop_obj_id]

        if obj_only is not None:
            obj_ids = numpy.array(obj_only)
            print("filtered to obj_only", obj_ids)

        if frames_per_second is None:
            frames_per_second = h5_context.get_fps()

        # Recorded in the output file. Previously unbound (NameError at
        # do_it) when the caller supplied dynamic_model_name.
        orig_dynamic_model_name = dynamic_model_name
        if dynamic_model_name is None:
            extra = h5_context.get_extra_info()
            orig_dynamic_model_name = extra.get("dynamic_model_name", None)
            dynamic_model_name = orig_dynamic_model_name
            if dynamic_model_name is None:
                dynamic_model_name = dynamic_models.DEFAULT_MODEL
                warnings.warn(
                    'no dynamic model specified, using "%s"' % dynamic_model_name
                )
            else:
                print(
                    'detected file loaded with dynamic model "%s"' % dynamic_model_name
                )
            if dynamic_model_name.startswith("EKF "):
                dynamic_model_name = dynamic_model_name[4:]
            print(' for smoothing, will use dynamic model "%s"' % dynamic_model_name)

        allrows = []
        allqualrows = []
        failed_quality = False

        if show_progress:
            import progressbar

            class StringWidget(progressbar.Widget):
                def set_string(self, ts):
                    self.ts = ts

                def update(self, pbar):
                    if hasattr(self, "ts"):
                        return self.ts
                    else:
                        return ""

            string_widget = StringWidget()
            objs_per_sec_widget = progressbar.FileTransferSpeed(unit="obj_ids ")
            widgets = [
                string_widget,
                objs_per_sec_widget,
                progressbar.Percentage(),
                progressbar.Bar(),
                progressbar.ETA(),
            ]
            pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(obj_ids)).start()

        for i, obj_id in enumerate(obj_ids):
            if obj_id > stop_obj_id:
                break
            if show_progress:
                string_widget.set_string("[obj_id: % 5d]" % obj_id)
                pbar.update(i)
            if show_progress_json and i % 100 == 0:
                rough_percent_done = float(i) / len(obj_ids) * 100.0
                result_utils.do_json_progress(rough_percent_done)
            try:
                rows = h5_context.load_data(
                    obj_id,
                    dynamic_model_name=dynamic_model_name,
                    frames_per_second=frames_per_second,
                    **kwargs
                )
            except core_analysis.DiscontiguousFramesError:
                warnings.warn(
                    "discontiguous frames smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue
            except core_analysis.NotEnoughDataToSmoothError:
                # too short to smooth: skipped silently
                continue
            except numpy.linalg.LinAlgError:
                warnings.warn(
                    "linear algebra error smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue
            except core_analysis.CouldNotCalculateOrientationError:
                warnings.warn(
                    "orientation error smoothing obj_id %d, skipping." % (obj_id,)
                )
                continue

            allrows.append(rows)
            try:
                qualrows = compute_ori_quality(
                    h5_context, rows["frame"], obj_id, smooth_len=0
                )
                allqualrows.append(qualrows)
            except ValueError:
                # any failure drops orientation quality for the whole file
                failed_quality = True
        if show_progress:
            pbar.finish()

        if not allrows:
            raise ValueError("no obj_ids could be smoothed, cannot convert")
        allrows = numpy.concatenate(allrows)
        if not failed_quality:
            allqualrows = numpy.concatenate(allqualrows)
        else:
            allqualrows = None
        recarray = numpy.rec.array(allrows)

        if fps is None:
            # fps is only read in STAGE 1; previously unbound (NameError)
            # when save_timestamps was False.
            fps = h5_context.get_fps()

        smoothed_source = "kalman_estimates"

        flydra_analysis.analysis.flydra_analysis_convert_to_mat.do_it(
            rows=recarray,
            ignore_observations=True,
            newfilename=outfilename,
            extra_vars=extra_vars,
            orientation_quality=allqualrows,
            hdf5=hdf5,
            tzname=tzname,
            fps=fps,
            smoothed_source=smoothed_source,
            smoothed_data_filename=smoothed_data_filename,
            raw_data_filename=raw_data_filename,
            dynamic_model_name=orig_dynamic_model_name,
            recording_flydra_version=recording_flydra_version,
            smoothing_flydra_version=smoothing_flydra_version,
        )
        if show_progress_json:
            result_utils.do_json_progress(100)