def get_player(viz=False, train=False):
    """Build the wrapped Atari environment for this script.

    Args:
        viz: visualization flag forwarded to ``AtariPlayer``.
        train: when True, a lost life counts as end-of-episode and
            frame stacking is skipped (the experience-replay buffer
            supplies the history during training).

    Returns:
        The fully wrapped environment.
    """
    player = AtariPlayer(
        ROM_FILE,
        frame_skip=ACTION_REPEAT,
        viz=viz,
        live_lost_as_eoe=train,
        max_num_frames=30000,
    )
    player = FireResetEnv(player)
    # Downscale every observation to the fixed network input size.
    player = MapState(player, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    if not train:
        # Only evaluation needs explicit stacking; in training the
        # expreplay buffer takes care of history.
        player = FrameStack(player, FRAME_HISTORY)
    return player
def get_player(train=False, dumpdir=None):
    """Create the gym environment used by this script.

    Args:
        train: when True, cap episodes at 60000 steps via ``LimitLength``.
        dumpdir: optional directory; when given, wrap with gym's
            ``Monitor`` to record episodes there.

    Returns:
        The fully wrapped environment.
    """
    player = gym.make(ENV_NAME)
    if dumpdir:
        player = gym.wrappers.Monitor(player, dumpdir)
    player = FireResetEnv(player)
    # Resize raw frames to the fixed network input size.
    player = MapState(player, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    player = FrameStack(player, 4)
    if train:
        # Bound episode length so a stuck policy cannot stall training.
        player = LimitLength(player, 60000)
    return player
def get_player(rom, viz=False, train=False):
    """Build a wrapped Atari environment for the given ROM.

    Args:
        rom: path of the Atari ROM to load.
        viz: visualization flag forwarded to ``AtariPlayer``.
        train: when True, a lost life counts as end-of-episode and
            frame stacking is skipped (the replay buffer supplies
            temporal context during training).

    Returns:
        The fully wrapped environment.
    """
    player = AtariPlayer(
        rom,
        frame_skip=ACTION_REPEAT,
        viz=viz,
        live_lost_as_eoe=train,
        max_num_frames=60000,
    )
    player = FireResetEnv(player)
    # Downscale observations to the fixed network input size.
    player = MapState(player, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    if not train:
        # Evaluation stacks frames here; in training the expreplay
        # buffer takes care of context.
        player = FrameStack(player, CONTEXT_LEN)
    return player
def get_player(viz=False, train=False):
    """Build the environment, from gym or the raw ALE backend.

    Args:
        viz: visualization flag (only used by the ``AtariPlayer`` backend).
        train: when True, a lost life counts as end-of-episode, frame
            stacking is skipped (handled by the expreplay buffer), and
            gym episodes are length-limited.

    Returns:
        The fully wrapped environment.
    """
    if USE_GYM:
        player = gym.make(ENV_NAME)
    else:
        # Imported lazily so the gym-only path needs no ALE install.
        from atari import AtariPlayer
        player = AtariPlayer(
            ENV_NAME,
            frame_skip=4,
            viz=viz,
            live_lost_as_eoe=train,
            max_num_frames=60000,
        )
    player = FireResetEnv(player)
    # Resize while keeping the channel dimension intact.
    player = MapState(player, lambda frame: resize_keepdims(frame, IMAGE_SIZE))
    if not train:
        # Only evaluation stacks frames; training history comes from
        # the expreplay buffer.
        player = FrameStack(player, FRAME_HISTORY)
    if train and USE_GYM:
        # Gym episodes get a hard step cap during training.
        player = LimitLength(player, 60000)
    return player
def get_player(rom, image_size, viz=False, train=False, frame_skip=1, context_len=1):
    """Build a fully parameterized, wrapped Atari environment.

    Args:
        rom: path of the Atari ROM to load.
        image_size: target (width, height) for the resized observations.
        viz: visualization flag forwarded to ``AtariPlayer``.
        train: when True, a lost life counts as end-of-episode and
            frame stacking is skipped (the replay buffer supplies
            temporal context during training).
        frame_skip: number of emulator frames per agent action.
        context_len: frames to stack at evaluation time.

    Returns:
        The fully wrapped environment.
    """
    player = AtariPlayer(
        rom,
        frame_skip=frame_skip,
        viz=viz,
        live_lost_as_eoe=train,
        max_num_frames=60000,
    )
    player = FireResetEnv(player)
    # Downscale observations to the requested input size.
    player = MapState(player, lambda frame: cv2.resize(frame, image_size))
    if not train:
        # Evaluation stacks frames here; in training the expreplay
        # buffer takes care of context.
        player = FrameStack(player, context_len)
    return player
def get_player(train=False, dumpdir=None):
    """Build the environment, choosing the backend from ``ENV_NAME``.

    A name ending in ``.bin`` is treated as a raw ROM for the ALE
    backend; anything else goes through gym.

    Args:
        train: when True, gym episodes are capped at 60000 steps.
        dumpdir: optional directory; when given, record every episode
            there with gym's ``Monitor``.

    Returns:
        The fully wrapped environment.
    """
    is_gym_env = not ENV_NAME.endswith(".bin")
    if is_gym_env:
        player = gym.make(ENV_NAME)
    else:
        # Imported lazily so the gym-only path needs no ALE install.
        from atari import AtariPlayer
        player = AtariPlayer(
            ENV_NAME,
            frame_skip=4,
            viz=False,
            live_lost_as_eoe=train,
            max_num_frames=60000,
            grayscale=False,
        )
    if dumpdir:
        # Record every episode, not just the Monitor's default sample.
        player = gym.wrappers.Monitor(player, dumpdir, video_callable=lambda _: True)
    player = FireResetEnv(player)
    # Resize raw frames to the fixed network input size.
    player = MapState(player, lambda frame: cv2.resize(frame, IMAGE_SIZE))
    player = FrameStack(player, 4)
    if train and is_gym_env:
        player = LimitLength(player, 60000)
    return player
# NOTE(review): the next few statements are the tail of a method whose
# `def` header falls outside this chunk; they are shown flush-left only
# because the original chunk carries no indentation. Attribute semantics
# below are inferred from names only — confirm against the full class.
self.model.fit(ob_train, target_f, epochs=1, verbose=0)  # one fit pass on the prepared batch
self.teaching = False  # presumably leaves "teaching" mode after the fit — verify
self.frustration = self.frustration_max  # reset the frustration counter to its ceiling
self.has_taught = 0
# Decay epsilon-greedy exploration until it reaches its floor.
if self.epsilon > self.min_epsilon:
    self.epsilon *= self.epsilon_decay

if __name__ == "__main__":
    ENV_NAME = 'MsPacman-v0'
    # Action-space size read from a freshly constructed player.
    NUM_ACTIONS = get_player().action_space.n
    # Evaluation environment: fire-to-start, 84x84 resize, 4-frame stack.
    env = gym.make(ENV_NAME)
    env = FireResetEnv(env)
    env = MapState(env, lambda im: cv2.resize(im, (84,84)))
    env = FrameStack(env, 4)
    # Frozen teacher policy restored from a pre-trained checkpoint;
    # outputs the 'policy' head for a given 'state' input.
    pred = OfflinePredictor(PredictConfig(
        model=Model(),
        session_init=get_model_loader("models/MsPacman-v0.tfmodel"),
        input_names=['state'],
        output_names=['policy']))
    # Student agent that learns from the environment with the teacher's guidance.
    student = student_dqn(env, teacher=pred)
    episodes = 1000000
    # Bookkeeping containers for per-episode statistics.
    scores = []
    teacher_step_nums = []
    step_nums = []
    ep_avgs = [0]
    ep_avg = []