import os.path as osp
import time

import torch
import torch.nn.functional as F

from spinup.utils.logx import EpochLogger  # assumed Spinning Up-style logger location
# render_frame is a project-specific rendering helper; import it from wherever
# it is defined in this repository.


def run_policy(env, get_action, ckpt_num, con, max_ep_len=100, num_episodes=100,
               render=False, record=True, video_caption_off=False):
    assert env is not None, \
        "Environment not found!\n\n It looks like the environment wasn't saved, " + \
        "and we can't run the agent in it. :( \n\n Check out the readthedocs " + \
        "page on Experiment Outputs for how to handle this situation."

    logger = EpochLogger()
    o, r, d, ep_ret, ep_len, n = env.reset(), 0, False, 0, 0, 0
    visual_obs = []
    while n < num_episodes:
        vob = render_frame(env, ep_len, ep_ret, 'AC', render, record,
                           caption_off=video_caption_off)
        visual_obs.append(vob)

        a = get_action(torch.Tensor(o.reshape(1, -1)))[0]
        o, r, d, _ = env.step(a.detach().numpy()[0])
        ep_ret += r
        ep_len += 1

        if d or (ep_len == max_ep_len):
            # Add the last frame of the episode before resetting.
            vob = render_frame(env, ep_len, ep_ret, 'AC', render, record,
                               caption_off=video_caption_off)
            visual_obs.append(vob)
            # Store under per-context keys so they match the log_tabular calls below.
            logger.store(**{'EpRet %d' % con: ep_ret, 'EpLen %d' % con: ep_len})
            print('Episode %d \t EpRet %.3f \t EpLen %d' % (n, ep_ret, ep_len))
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            n += 1

    logger.log_tabular('EpRet %d' % con, with_min_and_max=True)
    logger.log_tabular('EpLen %d' % con, average_only=True)
    logger.dump_tabular()

    if record:
        # temp_info: [video_prefix, ckpt_num, ep_ret, ep_len, con]
        temp_info = ['', ckpt_num, ep_ret, ep_len, con]
        logger.save_video(visual_obs, temp_info)
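# Minimal sketch of the `get_action` callable the function above expects: it takes a
# (1, obs_dim) tensor and returns something whose first element is a (1, act_dim)
# action tensor. The actor-critic attribute `ac.pi` below is an assumption in the
# style of Spinning Up, not necessarily this repository's interface.
def make_get_action(ac):
    def get_action(obs):
        with torch.no_grad():
            pi, _ = ac.pi(obs)   # policy head returns an action distribution
            a = pi.sample()      # (1, act_dim) action tensor
        return a, None           # run_policy uses [0], then .detach().numpy()[0]
    return get_action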
def run_policy(env, get_action, ckpt_num, max_con, con, max_ep_len=100,
               num_episodes=100, fpath=None, render=False, record=True,
               video_caption_off=False):
    assert env is not None, \
        "Environment not found!\n\n It looks like the environment wasn't saved, " + \
        "and we can't run the agent in it. :( \n\n Check out the readthedocs " + \
        "page on Experiment Outputs for how to handle this situation."

    output_dir = osp.join(osp.abspath(osp.dirname(osp.dirname(__file__))),
                          "log/tmp/experiments/%i" % int(time.time()))
    logger = EpochLogger(output_dir=output_dir)
    o, r, d, ep_ret, ep_len, n = env.reset(), 0, False, 0, 0, 0
    visual_obs = []
    # One-hot encoding of the context index, appended to every observation.
    c_onehot = F.one_hot(torch.tensor(con), max_con).squeeze().float()
    while n < num_episodes:
        vob = render_frame(env, ep_len, ep_ret, 'AC', render, record,
                           caption_off=video_caption_off)
        visual_obs.append(vob)

        concat_obs = torch.cat(
            [torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)
        a = get_action(concat_obs)
        o, r, d, _ = env.step(a[0].detach().numpy()[0])
        ep_ret += r
        ep_len += 1
        # Ignore the environment's done signal so every episode runs a full
        # max_ep_len steps (keeps recorded clips a fixed length).
        d = False

        if d or (ep_len == max_ep_len):
            # Add the last frame of the episode before resetting.
            vob = render_frame(env, ep_len, ep_ret, 'AC', render, record,
                               caption_off=video_caption_off)
            visual_obs.append(vob)
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            print('Episode %d \t EpRet %.3f \t EpLen %d' % (n, ep_ret, ep_len))
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            n += 1

    logger.log_tabular('EpRet', with_min_and_max=True)
    logger.log_tabular('EpLen', average_only=True)
    logger.dump_tabular()

    if record:
        # temp_info: [video_prefix, ckpt_num, ep_ret, ep_len, con]
        temp_info = ['', ckpt_num, ep_ret, ep_len, con]
        logger.save_video(visual_obs, temp_info, fpath)
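# Hypothetical invocation of the context-conditioned variant above. The environment
# name, checkpoint path, context count, and output path are placeholders, and
# make_get_action refers to the illustrative helper sketched earlier.
if __name__ == '__main__':
    import gym

    env = gym.make('HalfCheetah-v2')                    # placeholder environment
    ac = torch.load('log/pyt_save/model.pt')            # assumed checkpoint location
    get_action = make_get_action(ac)
    run_policy(env, get_action, ckpt_num=0, max_con=4, con=1,
               max_ep_len=100, num_episodes=5, fpath='log/videos',
               render=False, record=True)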