Ejemplo n.º 1
0
def main():
    """Feed the detections of one CSV file to a ``Tracker``, frame by frame.

    The input filename is taken from ``sys.argv[1]``. Consecutive rows that
    share a frame number are collected into a list of ``(x, y)`` tuples and
    passed to ``Tracker.handle_frame_data`` once a row with a larger frame
    number is seen (and once more after the last row). The tracker is
    constructed with ``<input basename>.kalmanized.csv`` as its results
    filename.
    """
    filename = sys.argv[1]
    df = read_detection_csv.get_df(filename)
    # Only the second returned matrix is needed here.
    _A, Ainv = read_detection_csv.get_A_Ainv(filename)
    dt = calc_dt(df)

    results_filename = os.path.splitext(filename)[0] + '.kalmanized.csv'
    tracker = Tracker(dt, results_filename, meters_to_pixels=Ainv)
    # Always close the tracker, even if processing a frame raises, so that
    # whatever it holds open (presumably the results file) is released.
    try:
        frame = None
        observations = []
        for _, row in df.iterrows():
            if frame is None:
                # First row: initialise the current frame number.
                frame = int(row['frame'])

            if row['frame'] > frame:
                # The frame number advanced: flush what was collected for
                # the previous frame and start a new batch.
                tracker.handle_frame_data(frame, observations)
                observations = []
                frame = int(row['frame'])

            observations.append((row['x'], row['y']))

        # Flush the observations of the final frame.
        tracker.handle_frame_data(frame, observations)
    finally:
        tracker.close()
Ejemplo n.º 2
0
def get_dt(kalman_filename):
    """Return the ``dt`` value belonging to the file ``kalman_filename``.

    If the first line of the file starts with ``'#'``, the remainder of that
    line (after the leading ``'# '``) is parsed as YAML and its ``'dt'``
    entry is returned.

    Otherwise two extensions are stripped from ``kalman_filename``, ``.csv``
    is appended, and ``dt`` is computed from the ``'timestamp'`` column of
    that file via ``read_detection_csv.calculate_dt_1``.

    Raises:
        AssertionError: if the header line starts with ``'#'`` but not with
            ``'# '``.
        KeyError: if the YAML header has no ``'dt'`` entry.
    """
    # Read the whole first line (not a fixed-size prefix) so a long YAML
    # header is never truncated, and close the file as soon as we have it.
    with open(kalman_filename, mode='r') as fd:
        line0 = fd.readline().rstrip('\n')

    if line0.startswith('#'):
        assert line0.startswith('# '), (
            'expected header line to start with "# ": %r' % line0)
        yaml_data = yaml.safe_load(line0[2:])
        return yaml_data['dt']

    # No header: fall back to the CSV obtained by dropping two extensions.
    base1 = os.path.splitext(kalman_filename)[0]
    base2 = os.path.splitext(base1)[0]
    csv_filename = base2 + '.csv'
    df = read_detection_csv.get_df(csv_filename)
    return read_detection_csv.calculate_dt_1(df['timestamp'].values)
Ejemplo n.º 3
0
                else:
                    print('******************* no on-food bout detected: %s' %
                          fname)

                if skip_indiv_plots:
                    continue

                if count > 5:
                    continue
                count += 1

                # ----------------------------------------

                if 1:

                    raw_df = read_detection_csv.get_df(csv_filename)

                    plt.figure()
                    plt.imshow(jpeg,
                               interpolation='nearest',
                               cmap='gray',
                               zorder=-100)
                    plt.colorbar()
                    for obj_id, obj_df in kalman_df.groupby('obj_id'):
                        plt.plot(obj_df['pos_x_pix'], obj_df['pos_y_pix'], '-')
                    plt.plot(raw_df['x_px'],
                             raw_df['y_px'],
                             'k.',
                             ms=5,
                             zorder=-99)
                    plt.xlabel('x (px)')
Ejemplo n.º 4
0
import pandas as pd
import sys
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy.misc
import read_detection_csv

# Input file given on the command line; it is read as CSV with '#' lines
# treated as comments and is expected to have 'obj_id', 'pos_x_pix' and
# 'pos_y_pix' columns (used below).
fname = sys.argv[1]

df = pd.read_csv(fname, comment='#')
# Strip two extensions from the input name to find the sibling '.csv'
# and '.jpg' files.
base1 = os.path.splitext(fname)[0]
base2 = os.path.splitext(base1)[0]
raw_df = read_detection_csv.get_df(base2 + '.csv')

if 1:
    jpeg_fname = base2 + '.jpg'
    # scipy.misc.imread was removed from SciPy (1.2); matplotlib's imread
    # returns the image as an array as well. Reading JPEG files requires
    # the pillow package.
    jpeg = plt.imread(jpeg_fname)

    # Show the image as a background, pushed behind everything else.
    plt.figure()
    plt.imshow(jpeg, interpolation='nearest', cmap='gray', zorder=-100)
    plt.colorbar()
    plt.xlabel('x (px)')
    plt.ylabel('y (px)')

    # One line per object id, drawn over the image.
    for (obj_id, fly_df) in df.groupby('obj_id'):
        plt.plot(fly_df['pos_x_pix'],
                 fly_df['pos_y_pix'],
                 lw=2,
                 label='id %d' % obj_id)
ax_cntl = fig.add_subplot(2, 1, 2)

for (category, category_df) in metadata_df.groupby('category'):
    print("category: %s ------" % category)
    for (genotype, genotype_df) in category_df.groupby('genotype'):
        print("  genotype: %s ------" % genotype)
        for (food_name, food_df) in genotype_df.groupby('food_type'):
            print("    food_name: %s ------" % food_name)
            for (fname, fly_df) in food_df.groupby('filename'):
                print("      fname: %s" % fname)
                food_x = fly_df.iloc[0]['food_x']
                food_y = fly_df.iloc[0]['food_y']

                jpeg_fname = os.path.splitext(fname)[0] + '.jpg'
                jpeg = scipy.misc.imread(jpeg_fname)
                df = read_detection_csv.get_df(fname)

                food_distance = np.sqrt((df['x'] - food_x)**2 +
                                        (df['y'] - food_y)**2)
                food_dist_threshold = fly_df['food_radius'].iloc[0]

                first_valid_frame = fly_df['first_valid_frame'].iloc[0]
                print('first_valid_frame', first_valid_frame)

                on_food_condition = food_distance < food_dist_threshold
                on_food_idxs = np.nonzero(on_food_condition)[0]

                print(on_food_idxs.shape)
                if len(on_food_idxs) >= 1:
                    if not np.isnan(first_valid_frame):
                        # We had some bad tracking prior to this time, do not consider