# Example No. 1
# 0
def main(connect_to_mainbrain=True,
         save_osg_info=None, # dict with info
         live_demo = False, # if True, never leave VR mode
         ):
    """Build the virtual-reality scene and run the stimulus render loop.

    Parameters
    ----------
    connect_to_mainbrain : bool
        If True, bind a UDP socket on port 8322, register this host with
        the flydra mainbrain (through a Pyro proxy) as a downstream Kalman
        host, and receive fly positions on a daemon listener thread.
        If False, positions come from a ``ReplayListener`` (when
        `save_osg_info` is given) or a dummy ``InterpolatingListener``,
        and a ``DummyMainbrain`` stands in for the mainbrain.
    save_osg_info : dict or None
        When given, render exactly one frame, save each wall's back
        buffer as ``<wallname>.png`` in ``save_osg_info['dirname']``,
        call ``save_wall_models`` and return. Keys read in this function:
        ``'kalman_h5'``, ``'obj'``, ``'frame'`` and ``'dirname'``.
        Must not be combined with `connect_to_mainbrain`.
    live_demo : bool
        If True, the experiment state is initialised to 0 instead of -1
        and the render loop never calls ``advance_state``.

    The loop ends on a pygame QUIT or mouse-button event (or after a
    single frame in save-osg mode). Module-level settings read here
    include ``SHOW``, ``FOREST``, ``UNIT_CUBE``, ``DRUM``, ``DOTS``,
    ``BGCOLOR``, ``STATIC_FLY_XYZ`` and ``EXPERIMENT_STATES``.
    """
    global EXPERIMENT_STATE
    global EXPERIMENT_STATE_START_TIME
    global mainbrain

    save_osg = save_osg_info is not None

    if connect_to_mainbrain:
        assert not save_osg, 'cannot connect to mainbrain and save to .osg file'

        # make connection to flydra mainbrain
        my_host = '' # bind on all interfaces; the FQDN is looked up below
        my_port = 8322 # arbitrary number

        # create UDP socket object, grab the port
        sockobj = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        print('binding %s' % ((my_host, my_port),))
        sockobj.bind(( my_host, my_port))

        # connect to mainbrain
        mainbrain_hostname = 'brain1'
        mainbrain_port = flydra.common_variables.mainbrain_port
        mainbrain_name = 'main_brain'
        remote_URI = "PYROLOC://%s:%d/%s" % (mainbrain_hostname,
                                             mainbrain_port,
                                             mainbrain_name)
        Pyro.core.initClient(banner=0)
        mainbrain = Pyro.core.getProxyForURI(remote_URI)
        mainbrain._setOneway(['log_message'])
        my_host_fqdn = socket.getfqdn(my_host)
        mainbrain.register_downstream_kalman_host(my_host_fqdn,my_port)

        listener = InterpolatingListener(sockobj)
        #listener = Listener(sockobj)
        listen_thread = threading.Thread(target=listener.run)
        listen_thread.daemon = True # don't keep the process alive on exit
        listen_thread.start()
    else:
        if save_osg:
            listener = ReplayListener(save_osg_info['kalman_h5'])
        else:
            listener = InterpolatingListener(None,dummy=True)
        mainbrain = DummyMainbrain()

    screen=Screen(size=(1024,768),
                  bgcolor=(0,0,0),
                  alpha_bits=8,
                  #fullscreen=True,
                  )

    if 1:
        # this loads faster, for debugging only
        mipmaps_enabled=False
    else:
        mipmaps_enabled=True

    DEBUG='none' # no debugging
    #DEBUG='flicker' # alternate views on successive frames

    vr_objects = []

    # callables invoked with the current fly position every frame
    on_eye_loc_update_funcs = []

    if 0:
        # floor stimulus

        # Get a texture
        filename = 'stones_and_cement.jpg'
        texture = Texture(filename)

        # floor is iso-z rectangle
        x0 = -1
        x1 = 1
        y0 = -.25
        y1 =  .25
        z0 = -.1
        z1 = -.1
        vr_object = FlatRect.FlatRect(texture=texture,
                                      shrink_texture_ok=1,
                                      lowerleft=(x0,y0,z0),
                                      upperleft=(x0,y1,z1),
                                      upperright=(x1,y1,z1),
                                      lowerright=(x1,y0,z0),
                                      tex_phase = 0.2,
                                      mipmaps_enabled=mipmaps_enabled,
                                      depth_test=True,
                                      )
        vr_objects.append( vr_object )

    if FOREST:
        # +y wall stimulus

        if 0:
            # wall iso-y rectangle
            x0 = -1
            x1 = 1
            y0 =  .25
            y1 =  .25
            z0 = -.1
            z1 = .5
            # Get a texture
            filename = 'stones_and_cement.jpg'
            texture = Texture(filename)

            vr_object = FlatRect.FlatRect(texture=texture,
                                          shrink_texture_ok=1,
                                          lowerleft=(x0,y0,z0),
                                          upperleft=(x0,y1,z1),
                                          upperright=(x1,y1,z1),
                                          lowerright=(x1,y0,z0),
                                          tex_phase = 0.2,
                                          mipmaps_enabled=mipmaps_enabled,
                                          depth_test=True,
                                          )
            vr_objects.append(vr_object)
        else:
            # forest of trees in +y direction
            filename = 'tree.png'
            tree_texture = Texture(filename)
            for x0 in np.arange(-1,1,.3):
                for y0 in np.arange(.1, 1.0, .2):
                #for y0 in [0.0]:
                    x1 = x0+0.1

                    y1 = y0
                    z0 = -0.5
                    z1 = 1.0
                    tree = TextureStimulus3D(texture = tree_texture,
                                             shrink_texture_ok=True,
                                             internal_format=gl.GL_RGBA,
                                             lowerleft=(x0,y0,z0),
                                             upperleft=(x0,y1,z1),
                                             upperright=(x1,y1,z1),
                                             lowerright=(x1,y0,z0),
                                             mipmaps_enabled=False,
                                             depth_test=True,
                                             )
                    vr_objects.append(tree)
    if FOREST:
        if 1:
            # forest of trees in -y direction
            filename = 'tree.png'
            tree_texture = Texture(filename)
            for x0 in np.arange(-1,1,.3):
                for y0 in np.arange(-.1, -1.0, -.2):
                #for y0 in [0.0]:
                    x1 = x0+0.1

                    y1 = y0
                    z0 = -0.5
                    z1 = 1.0
                    tree = TextureStimulus3D(texture = tree_texture,
                                             shrink_texture_ok=True,
                                             internal_format=gl.GL_RGBA,
                                             lowerleft=(x0,y0,z0),
                                             upperleft=(x0,y1,z1),
                                             upperright=(x1,y1,z1),
                                             lowerright=(x1,y0,z0),
                                             mipmaps_enabled=False,
                                             depth_test=True,
                                             )
                    vr_objects.append(tree)

    if UNIT_CUBE:
        arena = Arena()
        # each value lists the four vertex keys of one face, in LL, UL, UR, LR order
        for face_verts in arena.face_verts.values():
            face = TextureStimulus3D(texture = arena.texture,
                                     shrink_texture_ok=True,
                                     internal_format=gl.GL_RGBA,
                                     lowerleft=arena.verts[face_verts[0]],
                                     upperleft=arena.verts[face_verts[1]],
                                     upperright=arena.verts[face_verts[2]],
                                     lowerright=arena.verts[face_verts[3]],
                                     mipmaps_enabled=False,
                                     depth_test=True,
                                     )
            vr_objects.append(face)

    if DRUM:
        filename = os.path.join( os.path.split( __file__ )[0], 'panorama-checkerboard.png')
        texture = Texture(filename)
        # cylinder
        drum = SpinningDrum( position=(0.5,0.5,0.5), # center of WT
                             orientation=0.0,
                             texture=texture,
                             drum_center_elevation=90.0,
                             radius = (0.4),
                             height=0.3,
                             internal_format=gl.GL_RGBA,
                             )
        vr_objects.append(drum)

    if DOTS:
        # due to a bug in Vision Egg, these only show up if drawn last.
        if 1:
            # fly tracker negZ stimulus
            negZ_stim = Rectangle3D(color=(1,1,1,1),
                                     )
            updater = StimulusLocationUpdater(negZ_stim,offset=(0,0,-1),inc=0.1)
            on_eye_loc_update_funcs.append( updater.update )
            vr_objects.append( negZ_stim )

        if 1:
            # fly tracker plusY stimulus
            plusY_stim = Rectangle3D(color=(1,1,1,1),
                                     )
            updater = StimulusLocationUpdater(plusY_stim,
                                              offset=(0,1,0),
                                              inc1_dir=(1,0,0),
                                              inc2_dir=(0,0,1),
                                              inc=0.1)
            on_eye_loc_update_funcs.append( updater.update )
            vr_objects.append( plusY_stim )

        if 1:
            # fly tracker negY stimulus
            negY_stim = Rectangle3D(color=(1,1,1,1),
                                     )
            updater = StimulusLocationUpdater(negY_stim,
                                              offset=(0,-1,0),
                                              inc1_dir=(1,0,0),
                                              inc2_dir=(0,0,1),
                                              inc=0.1)
            on_eye_loc_update_funcs.append( updater.update )
            vr_objects.append( negY_stim )


    # projection walls: name -> dict returned by get_wall_dict()
    vr_walls = {}
    arena = Arena()
    if 0:
        # test
        ## corners_3d = [
        ##     ( 0.5,  0.15, 0.01),
        ##     (-0.5,  0.15, 0.01),
        ##     (-0.5, -0.15, 0.01),
        ##     ( 0.5, -0.15, 0.01),
        ##               ]
        testz = 0.
        ## corners_3d = [(-0.5, -0.15, testz),
        ##               ( 0.5, -0.15, testz),
        ##               ( 0.5,  0.15, testz),
        ##               (-0.5,  0.15, testz)]
        corners_3d = [(-0.5, 0.15, testz),
                      ( 0.5, 0.15, testz),
                      ( 0.5, -0.15, testz),
                      (-0.5, -0.15, testz)]

        corners_2d = [ (0,210),
                       (799,210),
                       (799,401),
                       (0,399)]
        name = 'test'
        approx_view_dir = (0,0,-1) # down
        vr_wall_dict = get_wall_dict( SHOW, screen,
                                      corners_3d, corners_2d, vr_objects,
                                      approx_view_dir, name)
        vr_walls[name] = vr_wall_dict
        del testz
    if 0:
        # test2
        testy = 0.15
        corners_3d = [(-0.5, testy, 0),
                      ( -0.5, testy, 0.3),
                      ( 0.5, testy, 0.3),
                      (0.5, testy, 0)]
        del testy

        corners_2d = [ (0,0),
                       (799,0),
                       (799,200),
                       (0,200)]
        name = 'test2'
        approx_view_dir = (0,0,-1) # down
        vr_wall_dict = get_wall_dict( SHOW, screen,
                                      corners_3d, corners_2d, vr_objects,
                                      approx_view_dir, name)
        vr_walls[name] = vr_wall_dict
    if 0:
        # floor
        # Measured Mar 24, 2009. Used calibration cal20081120.xml.
        corners_3d = arena.faces_3d['-z']
        corners_2d = [
            (1,401),
            (800,399),
            (800,209),
            (1,208),
            ]
        name = 'floor'
        approx_view_dir = None
        vr_wall_dict = get_wall_dict( SHOW, screen,
                                      corners_3d, corners_2d, vr_objects,
                                      approx_view_dir, name)
        vr_walls[name] = vr_wall_dict
    if 1:
        # order: LL, UL, UR, LR
        # +y wall
        corners_3d = arena.faces_3d['+y']

        corners_2d = [
            # LL
            (513, 1),
            # UL
            (513,768),
            #UR
            (1024,550),
            # LR
            (1024,50),
            ]

        name = '+y'
        approx_view_dir = None
        vr_wall_dict = get_wall_dict( SHOW, screen,
                                      corners_3d, corners_2d, vr_objects,
                                      approx_view_dir,name)
        vr_walls[name] = vr_wall_dict
    if 1:
        # order: LL, UL, UR, LR
        # -y wall
        # NOTE(review): this wall is named '-y' but uses the arena's '+x'
        # face; confirm this matches the physical setup before changing.
        corners_3d = arena.faces_3d['+x']

        corners_2d = [
            # LL
            (1,50),
            # UL
            (1,620),
            # UR
            (512,768),
            # LR
            (512,1),
            ]

        name = '-y'
        approx_view_dir = None
        vr_wall_dict = get_wall_dict( SHOW, screen,
                                      corners_3d, corners_2d, vr_objects,
                                      approx_view_dir,name)
        vr_walls[name] = vr_wall_dict

    if SHOW in ['overview','projector (calibrated)']:
        screen_stimuli = []
        for wall in vr_walls.values():
            screen_stimuli.extend( wall['display_stimuli'] )

        if SHOW=='overview':
            # draw dot where VR camera is for overview
            VR_eye_stim = Rectangle3D(color=(.2,.2,.2,1), # gray
                                      depth_test=True,#requires VE 1.1.1.1
                                      )
            fly_stim_updater=StimulusLocationUpdater(VR_eye_stim)
            on_eye_loc_update_funcs.append( fly_stim_updater.update )

            VR_stimuli = []
            for wall in vr_walls.values():
                VR_stimuli.extend( wall['cam_viewport'].parameters.stimuli )
            display_viewport = Viewport(
                screen=screen,
                projection=SimplePerspectiveProjection(fov_x=90.0),
                stimuli=VR_stimuli+screen_stimuli+[VR_eye_stim,
                                                   ],
                )

        elif SHOW=='projector (calibrated)':
            display_viewport = Viewport(
                screen=screen,
                stimuli=screen_stimuli,
                )
    else:
        # parse e.g. SHOW='cam:floor'
        camname = SHOW[4:]
        display_viewport = vr_walls[camname]['cam_viewport']

    # only used by the disabled log_message snippet in the render loop
    last_log_message_time = -np.inf

    # OpenGL textures must be power of 2
    def next_power_of_2(f):
        # Smallest power of two >= f. Integer arithmetic avoids the float
        # rounding of log(f)/log(2), which can land just above an integer
        # for exact powers of two and overshoot by a factor of two.
        n = int(math.ceil(f))
        if n <= 1:
            return 1
        return 1 << (n - 1).bit_length()
    fb_width_pow2  = int(next_power_of_2(screen.size[0]))
    fb_height_pow2  = int(next_power_of_2(screen.size[1]))

    if save_osg:
        listener.set_obj_frame(save_osg_info['obj'],save_osg_info['frame'])

    def read_fly_state(**kwargs):
        # Query the listener; return (obj_id, fly_xyz, framenumber), or
        # three Nones when it reports no current data.
        tmp = listener.get_fly_xyz(**kwargs)
        if tmp is None:
            return None, None, None
        return tmp

    # clipping planes for the per-wall fly-eye cameras
    near = 0.001
    far = 10.0

    # UDP socket passed to advance_state(); the try/finally guarantees it
    # is closed even if rendering raises.
    sendsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # initialize
        if live_demo:
            EXPERIMENT_STATE = 0
        else:
            EXPERIMENT_STATE = -1
        obj_id, fly_xyz, framenumber = read_fly_state()
        if not save_osg:
            advance_state(fly_xyz,obj_id,framenumber,sendsock)
        else:
            warnings.warn('save_osg mode -- forcing experiment state')
            EXPERIMENT_STATE = 1

        frame_timer = FrameTimer()
        quit_now = False
        while not quit_now:
            # infinite loop to draw stimuli
            if save_osg:
                quit_now = True # quit on next round

            # quit on window close or mouse click (keypress quitting is disabled)
            for event in pygame.event.get():
                #if event.type in (QUIT,KEYDOWN,MOUSEBUTTONDOWN):
                if event.type in (QUIT,MOUSEBUTTONDOWN):
                    quit_now = True

            ## now = time.time()
            ## if now-last_log_message_time > 5.0: # log a message every 5 seconds
            ##     mainbrain.log_message('<wtstim>',
            ##                           time.time(),
            ##                           'This is my message.' )
            ##     last_log_message_time = now

            obj_id, fly_xyz, framenumber = read_fly_state(prefer_obj_id=obj_id)

            if not save_osg:
                if should_i_advance_to_next_state(fly_xyz):
                    if not live_demo:
                        advance_state(fly_xyz,obj_id,framenumber,sendsock)

            state_string = EXPERIMENT_STATES[EXPERIMENT_STATE]
            state_string_split = state_string.split()

            draw_stimuli = True
            if state_string_split[0]=='waiting':
                draw_stimuli=False
            elif state_string=='doing gray only stimulus':
                draw_stimuli=False

            if fly_xyz is not None:
                if state_string.startswith('doing static'):
                    fly_xyz = STATIC_FLY_XYZ

                for wall in vr_walls.values():
                    wall['screen_data'].update_VE_viewport( wall['cam_viewport'],
                                                            fly_xyz, near, far,
                                                            avoid_clipping=True)

                for func in on_eye_loc_update_funcs: # only used to draw dots and in overview mode
                    func(fly_xyz)
            else:
                # no recent data
                draw_stimuli = False

            # render fly-eye views and copy to texture objects if necessary
            #screen.set(bgcolor=(.4,.4,0)) # ??
            for wallname,wall in vr_walls.items():
                if wallname=='test':
                    screen.set(bgcolor=(1,0,0)) # red
                else:
                    screen.set(bgcolor=BGCOLOR)
                screen.clear() # clear screen

                if draw_stimuli:
                    if DRUM and state_string.endswith('drum radius 0.1'):
                        drum.set(radius = 0.1)
                    elif DRUM and state_string.endswith('drum radius 0.3'):
                        drum.set(radius = 0.3)
                    # render fly-eye view
                    wall['cam_viewport'].draw()

                if SHOW in ['overview','projector (calibrated)']:
                    framebuffer_texture_object = wall['framebuffer_texture_object']

                    # copy screen back-buffer to texture
                    framebuffer_texture_object.put_new_framebuffer(
                        size=(fb_width_pow2,fb_height_pow2),
                        internal_format=gl.GL_RGB,
                        buffer='back',
                        )
                if save_osg:
                    # save screen back buffer to image file
                    pil_image = screen.get_framebuffer_as_image(
                        format=gl.GL_RGBA)
                    if not os.path.exists(save_osg_info['dirname']):
                        os.mkdir(save_osg_info['dirname'])
                    wall_fname = wallname+'.png'
                    wall_full_path = os.path.join( save_osg_info['dirname'],
                                                   wall_fname )
                    print('saving %s' % wall_fname)
                    pil_image.save(wall_full_path)

                if DEBUG=='flicker':
                    swap_buffers() # swap buffers
                    #time.sleep(3.0)

            if save_osg:
                save_wall_models( vr_walls, save_osg_info)

            if SHOW=='overview':
                # slowly circle the overview camera around a fixed point
                now = time.time()
                overview_movement_tf = 0.1
                theta = (2*pi*now * overview_movement_tf)
                overview_eye_loc = (-.5 + 0.1*np.cos( theta ), # x
                                    -1.5 + 0.1*np.sin( theta ), # y
                                    2.0) # z

                camera_matrix = ModelView()
                camera_matrix.look_at( overview_eye_loc, # eye
                                       eye_loc_default, # look at fly center
                                       #screen_data.t_3d[:3,0], # look at upper left corner
                                       (0,0,1), # up
                                       )

                display_viewport.set(camera_matrix=camera_matrix)

            # clear screen again
            if SHOW=='overview':
                screen.set(bgcolor=(0.0,0.0,0.8)) # blue
            else:
                screen.set(bgcolor=(0.0,0.0,0.0)) #black
            screen.clear() # clear screen
            display_viewport.draw() # draw the viewport and hence the stimuli

            swap_buffers() # swap buffers
            frame_timer.tick() # notify the frame time logger that we just drew a frame
        frame_timer.log_histogram() # print frame interval histogram
    finally:
        sendsock.close()