Example #1
import torch

# EpochLogger and render_frame come from the surrounding project; their import
# paths are not shown in this example listing.


def run_policy(env,
               get_action,
               ckpt_num,
               con,
               max_ep_len=100,
               num_episodes=100,
               render=False,
               record=True,
               video_caption_off=False):

    assert env is not None, \
        "Environment not found!\n\n It looks like the environment wasn't saved, " + \
        "and we can't run the agent in it. :( \n\n Check out the readthedocs " + \
        "page on Experiment Outputs for how to handle this situation."

    logger = EpochLogger()
    # Pre-0.26 Gym API: reset() returns only the observation and step() returns
    # a 4-tuple (obs, reward, done, info).
    o, r, d, ep_ret, ep_len, n = env.reset(), 0, False, 0, 0, 0
    visual_obs = []
    while n < num_episodes:

        vob = render_frame(env,
                           ep_len,
                           ep_ret,
                           'AC',
                           render,
                           record,
                           caption_off=video_caption_off)
        visual_obs.append(vob)
        # Query the policy on a batch of one observation and step the
        # environment with the resulting action.
        a = get_action(torch.Tensor(o.reshape(1, -1)))[0]
        o, r, d, _ = env.step(a.detach().numpy()[0])
        ep_ret += r
        ep_len += 1

        # Episode over (terminal state or time limit): record the last frame,
        # log the episode statistics and reset the environment.
        if d or (ep_len == max_ep_len):
            vob = render_frame(env,
                               ep_len,
                               ep_ret,
                               'AC',
                               render,
                               record,
                               caption_off=video_caption_off)
            visual_obs.append(vob)  # add last frame
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            print('Episode %d \t EpRet %.3f \t EpLen %d' % (n, ep_ret, ep_len))
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            n += 1

    # Keys must match the ones used in logger.store() above.
    logger.log_tabular('EpRet', with_min_and_max=True)
    logger.log_tabular('EpLen', average_only=True)
    logger.dump_tabular()
    if record:
        # temp_info: [video_prefix, ckpt_num, ep_ret, ep_len, con]
        temp_info = ['', ckpt_num, ep_ret, ep_len, con]
        logger.save_video(visual_obs, temp_info)
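A minimal usage sketch for this variant, assuming a pre-0.26 Gym environment and a random stand-in policy with the call interface the example expects; the environment id, the stand-in policy, and the argument values are illustrative assumptions, not part of the original file.

import gym
import torch

env = gym.make('Pendulum-v1')  # assumed: any continuous-control env with the old 4-tuple step() API

def get_action(obs_batch):
    # Stand-in for a trained policy: must return something indexable whose
    # first element is a (1, act_dim) action tensor, as the example expects.
    a = torch.as_tensor(env.action_space.sample(), dtype=torch.float32).reshape(1, -1)
    return (a,)

run_policy(env, get_action, ckpt_num=0, con=0,
           num_episodes=2, render=False, record=False)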
Example #2

import os.path as osp
import time

import torch
import torch.nn.functional as F

# EpochLogger and render_frame come from the surrounding project; their import
# paths are not shown in this example listing.


def run_policy(env,
               get_action,
               ckpt_num,
               max_con,
               con,
               max_ep_len=100,
               num_episodes=100,
               fpath=None,
               render=False,
               record=True,
               video_caption_off=False):

    assert env is not None, \
        "Environment not found!\n\n It looks like the environment wasn't saved, " + \
        "and we can't run the agent in it. :( \n\n Check out the readthedocs " + \
        "page on Experiment Outputs for how to handle this situation."

    output_dir = osp.join(osp.abspath(osp.dirname(osp.dirname(__file__))),
                          "log/tmp/experiments/%i" % int(time.time()))
    logger = EpochLogger(output_dir=output_dir)
    o, r, d, ep_ret, ep_len, n = env.reset(), 0, False, 0, 0, 0
    visual_obs = []
    # One-hot encode the context index so it can be concatenated with the
    # observation before being fed to the policy.
    c_onehot = F.one_hot(torch.tensor(con), max_con).squeeze().float()
    while n < num_episodes:

        vob = render_frame(env,
                           ep_len,
                           ep_ret,
                           'AC',
                           render,
                           record,
                           caption_off=video_caption_off)
        visual_obs.append(vob)
        # Condition the policy on the context by concatenating the observation
        # with the one-hot context vector.
        concat_obs = torch.cat(
            [torch.Tensor(o.reshape(1, -1)),
             c_onehot.reshape(1, -1)], 1)
        a = get_action(concat_obs)
        o, r, d, _ = env.step(a[0].detach().numpy()[0])
        ep_ret += r
        ep_len += 1
        # Override the environment's done signal so that episodes only
        # terminate when ep_len reaches max_ep_len.
        d = False

        if d or (ep_len == max_ep_len):
            vob = render_frame(env,
                               ep_len,
                               ep_ret,
                               'AC',
                               render,
                               record,
                               caption_off=video_caption_off)
            visual_obs.append(vob)  # add last frame
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            print('Episode %d \t EpRet %.3f \t EpLen %d' % (n, ep_ret, ep_len))
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            n += 1

    logger.log_tabular('EpRet', with_min_and_max=True)
    logger.log_tabular('EpLen', average_only=True)
    logger.dump_tabular()
    if record:
        # temp_info: [video_prefix, ckpt_num, ep_ret, ep_len, con]
        temp_info = ['', ckpt_num, ep_ret, ep_len, con]
        logger.save_video(visual_obs, temp_info, fpath)
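As above, a minimal usage sketch for the context-conditioned variant; here get_action receives the observation concatenated with a one-hot context vector of length max_con. The environment id, stand-in policy, and argument values are illustrative assumptions, not part of the original file.

import gym
import torch

env = gym.make('Pendulum-v1')  # assumed environment with the old 4-tuple step() API

def get_action(obs_with_context):
    # Stand-in policy: receives the observation concatenated with the one-hot
    # context and must return a tuple whose first element is a (1, act_dim) tensor.
    a = torch.as_tensor(env.action_space.sample(), dtype=torch.float32).reshape(1, -1)
    return (a,)

run_policy(env, get_action, ckpt_num=0, max_con=4, con=1,
           num_episodes=2, render=False, record=False)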