Example #1
0
def get_player(viz=False, train=False):
    """Construct the preprocessed Atari environment.

    Args:
        viz: visualization setting forwarded to AtariPlayer.
        train: if True, a lost life ends the episode, and frame stacking
            is skipped because the expreplay buffer assembles history.

    Returns:
        The fully wrapped environment.
    """
    player = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
                         live_lost_as_eoe=train, max_num_frames=30000)
    player = FireResetEnv(player)

    def _shrink(frame):
        return cv2.resize(frame, IMAGE_SIZE)

    player = MapState(player, _shrink)
    if train:
        # history is assembled by the experience-replay buffer during training
        return player
    return FrameStack(player, FRAME_HISTORY)
Example #2
0
def get_player(train=False, dumpdir=None):
    """Create the gym environment with the standard preprocessing chain.

    Args:
        train: if True, cap episode length with LimitLength.
        dumpdir: optional directory; when given, wrap with Monitor to
            record episodes there.

    Returns:
        The wrapped environment.
    """
    player = gym.make(ENV_NAME)
    if dumpdir:
        player = gym.wrappers.Monitor(player, dumpdir)
    player = FireResetEnv(player)
    player = MapState(player, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    player = FrameStack(player, 4)
    # cap episode length during training only
    return LimitLength(player, 60000) if train else player
Example #3
0
def get_player(rom, viz=False, train=False):
    """Build an AtariPlayer-based environment for the given ROM.

    Args:
        rom: path/identifier of the Atari ROM.
        viz: visualization setting forwarded to AtariPlayer.
        train: if True, a lost life ends the episode and no FrameStack
            is applied (the expreplay buffer supplies temporal context).

    Returns:
        The wrapped environment.
    """
    player = FireResetEnv(
        AtariPlayer(rom,
                    frame_skip=ACTION_REPEAT,
                    viz=viz,
                    live_lost_as_eoe=train,
                    max_num_frames=60000))
    player = MapState(player, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    if train:
        # context is taken care of in the expreplay buffer during training
        return player
    return FrameStack(player, CONTEXT_LEN)
Example #4
0
def get_player(viz=False, train=False):
    """Create either a gym or a raw AtariPlayer environment, preprocessed.

    Args:
        viz: visualization setting (only meaningful for AtariPlayer).
        train: if True, skip FrameStack (history comes from the expreplay
            buffer) and, for gym environments, cap episode length.

    Returns:
        The wrapped environment.
    """
    if USE_GYM:
        player = gym.make(ENV_NAME)
    else:
        from atari import AtariPlayer
        player = AtariPlayer(ENV_NAME, frame_skip=4, viz=viz,
                             live_lost_as_eoe=train, max_num_frames=60000)
    player = FireResetEnv(player)

    def _preprocess(frame):
        return resize_keepdims(frame, IMAGE_SIZE)

    player = MapState(player, _preprocess)
    if not train:
        # evaluation needs explicit frame history; expreplay handles it in training
        player = FrameStack(player, FRAME_HISTORY)
    elif USE_GYM:
        player = LimitLength(player, 60000)
    return player
Example #5
0
def get_player(rom,
               image_size,
               viz=False,
               train=False,
               frame_skip=1,
               context_len=1):
    """Build a fully parameterized AtariPlayer environment.

    Args:
        rom: path/identifier of the Atari ROM.
        image_size: target size passed to cv2.resize.
        viz: visualization setting forwarded to AtariPlayer.
        train: if True, a lost life ends the episode and FrameStack
            is skipped (the expreplay buffer supplies context).
        frame_skip: number of frames each action is repeated.
        context_len: stack depth used by FrameStack in evaluation mode.

    Returns:
        The wrapped environment.
    """
    base = AtariPlayer(rom,
                       frame_skip=frame_skip,
                       viz=viz,
                       live_lost_as_eoe=train,
                       max_num_frames=60000)
    player = MapState(FireResetEnv(base),
                      lambda frame: cv2.resize(frame, image_size))
    if train:
        # temporal context is provided by the expreplay buffer in training
        return player
    return FrameStack(player, context_len)
Example #6
0
def get_player(train=False, dumpdir=None):
    """Create a gym or AtariPlayer environment (chosen by ENV_NAME suffix).

    A ``.bin`` ENV_NAME selects the raw AtariPlayer backend; anything else
    goes through ``gym.make``.

    Args:
        train: if True and the backend is gym, cap episode length.
        dumpdir: optional directory; when given, Monitor records every
            episode there.

    Returns:
        The wrapped environment.
    """
    is_gym_env = not ENV_NAME.endswith(".bin")
    if is_gym_env:
        player = gym.make(ENV_NAME)
    else:
        from atari import AtariPlayer
        player = AtariPlayer(ENV_NAME,
                             frame_skip=4,
                             viz=False,
                             live_lost_as_eoe=train,
                             max_num_frames=60000,
                             grayscale=False)
    if dumpdir:
        # record a video of every episode, not just gym's default subset
        player = gym.wrappers.Monitor(player, dumpdir,
                                      video_callable=lambda _: True)
    player = FireResetEnv(player)
    player = MapState(player, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    player = FrameStack(player, 4)
    if train and is_gym_env:
        player = LimitLength(player, 60000)
    return player
Example #7
0
            self.model.fit(ob_train, target_f, epochs=1, verbose=0)
        
        self.teaching = False
        self.frustration = self.frustration_max
        self.has_taught = 0
        
        if self.epsilon > self.min_epsilon:
            self.epsilon *= self.epsilon_decay

if __name__ == "__main__":
    # Entry point: set up MsPacman, load a pretrained teacher policy, and
    # prepare bookkeeping for a student-DQN training run.
    # NOTE(review): the script appears to continue past this excerpt — the
    # accumulator lists below are initialized but not yet used here.
    ENV_NAME = 'MsPacman-v0'
    # NOTE(review): get_player() here presumably builds an env equivalent to
    # the one constructed below; confirm it reads the ENV_NAME just assigned.
    NUM_ACTIONS = get_player().action_space.n
    env = gym.make(ENV_NAME)
    
    # Standard preprocessing chain: fire-to-start, 84x84 resize, 4-frame stack.
    env = FireResetEnv(env)
    env = MapState(env, lambda im: cv2.resize(im, (84,84)))
    env = FrameStack(env, 4)

    # Teacher: an offline predictor restored from a saved MsPacman model,
    # mapping a 'state' input to a 'policy' output.
    pred = OfflinePredictor(PredictConfig(
                model=Model(),
                session_init=get_model_loader("models/MsPacman-v0.tfmodel"),
                input_names=['state'],
                output_names=['policy']))

    # Student agent distills from the teacher predictor.
    student = student_dqn(env, teacher=pred)
    episodes = 1000000
    # Per-episode statistics accumulators (consumed later in the script).
    scores = []
    teacher_step_nums = []
    step_nums = []
    ep_avgs = [0]
    ep_avg = []