Example #1
0
def sim(file="",
        isdeterm=True,
        speedup=1,
        max_path_length=200,
        animated=True,
        wrap_env=False):
    """Interactively replay a policy stored in a joblib snapshot.

    The snapshot at ``file`` is loaded inside a TensorFlow session and its
    ``'policy'`` and ``'env'`` entries are rolled out repeatedly. After each
    rollout the function drops into ``pdb`` (with the result bound to the
    local ``path``) and then asks whether to run another one.

    Args:
        file: Path handed to ``joblib.load``.
        isdeterm: If true, roll out with ``deterministic_rollout``; otherwise
            with ``rollout``.
        speedup: Forwarded to the rollout function as ``speedup``.
        max_path_length: Forwarded to the rollout function as
            ``max_path_length``.
        animated: Forwarded to the rollout function as ``animated``.
        wrap_env: If true, wrap the stored environment in
            ``VaryMassRolloutWrapper`` and print its ``m0`` and ``mf``.
    """
    run_episode = deterministic_rollout if isdeterm else rollout

    with tf.Session() as sess:
        snapshot = joblib.load(file)
        policy = snapshot['policy']
        if wrap_env:
            env = VaryMassRolloutWrapper(snapshot['env'])
            print("m0: ", env.env.m0)
            print("mf: ", env.env.mf)
        else:
            env = snapshot['env']

        keep_going = True
        while keep_going:
            path = run_episode(
                env,
                policy,
                max_path_length=max_path_length,
                animated=animated,
                speedup=speedup,
                always_return_paths=True,
            )
            # Pause in the debugger so the finished rollout (`path`) can be
            # inspected before deciding whether to continue.
            import pdb
            pdb.set_trace()
            keep_going = query_yes_no('Continue simulation?')
Example #2
0
def main():
    """Drive the empty environment with a fixed, open-loop action sequence.

    Each run takes 100 steps and renders after every step. The action is
    ``[1.0, a]`` where ``a`` is +15 degrees (in radians) for the first half of
    the run and -15 degrees for the second half. After each run the user is
    asked whether to start another one; the environment is only reset once,
    before the first run.
    """
    env = normalize(load_class('aa_simulation.envs.empty_env',
        Env, ["rllab", "envs"])())
    env.reset()
    env.render()

    steps_per_run = 100
    angle = np.deg2rad(15)

    keep_running = True
    while keep_running:
        for step in range(steps_per_run):
            # Flip the sign of the second action component at the midpoint.
            second = angle if step < steps_per_run / 2 else -angle
            action = np.array([1.0, second])
            # The step result is not used; unpack only to keep the 4-tuple
            # contract of env.step() checked.
            _obs, _reward, _done, _info = env.step(action)
            env.render()

        keep_running = query_yes_no('Continue simulation?')
Example #3
0
                        type=bool,
                        default=False,
                        help='Whether or not to prompt for more sim')
    args = parser.parse_args()

    # Run the simulation, and re-run it from scratch (new session, snapshot
    # reloaded) when the final rollout ended before args.max_path_length
    # steps, for at most `max_tries` attempts in total.
    max_tries = 10
    tri = 0  # number of attempts started so far
    while True:
        tri += 1
        with tf.Session() as sess:
            # NOTE: joblib.load unpickles the file; only load trusted snapshots.
            data = joblib.load(args.file)
            policy = data['policy']
            env = data['env']
            while True:
                path = rollout(env,
                               policy,
                               max_path_length=args.max_path_length,
                               animated=True,
                               speedup=args.speedup,
                               video_filename=args.video_filename)
                # Without --prompt a single rollout is performed; with it the
                # user decides after each rollout whether to run another.
                if args.prompt:
                    if not query_yes_no('Continue simulation?'):
                        break
                else:
                    break
        # Retry only while attempts remain. The previous test
        # (`tri >= max_tries`) could never be true on the first pass, which
        # made the retry branch unreachable and `max_tries` meaningless.
        if len(path['rewards']) < args.max_path_length and tri < max_tries:
            # Clear the graph so the snapshot can be loaded again cleanly.
            tf.reset_default_graph()
            continue
        break
Example #4
0
                        help='Max length of rollout')
    parser.add_argument('--speedup', type=float, default=1,
                        help='Speedup')
    parser.add_argument('--video_filename', type=str,
                        help='path to the out video file')
    # argparse applies `type` to the raw command-line string, and bool() of
    # any non-empty string is True, so `type=bool` made `--prompt False`
    # enable prompting. Parse the text instead: only explicit false-like
    # spellings disable the flag; every other value still enables it, so
    # invocations that worked before behave the same.
    parser.add_argument('--prompt',
                        type=lambda text: text.strip().lower() not in (
                            '', '0', 'false', 'f', 'no', 'n', 'off'),
                        default=False,
                        help='Whether or not to prompt for more sim')
    args = parser.parse_args()

    # Run the simulation, and re-run it from scratch (new session, snapshot
    # reloaded) when the final rollout ended before args.max_path_length
    # steps, for at most `max_tries` attempts in total.
    max_tries = 10
    tri = 0  # number of attempts started so far
    while True:
        tri += 1
        with tf.Session() as sess:
            # NOTE: joblib.load unpickles the file; only load trusted snapshots.
            data = joblib.load(args.file)
            policy = data['policy']
            env = data['env']
            while True:
                path = rollout(env, policy, max_path_length=args.max_path_length,
                               animated=True, speedup=args.speedup, video_filename=args.video_filename)
                # Without --prompt a single rollout is performed; with it the
                # user decides after each rollout whether to run another.
                if args.prompt:
                    if not query_yes_no('Continue simulation?'):
                        break
                else:
                    break
        # Retry only while attempts remain. The previous test
        # (`tri >= max_tries`) could never be true on the first pass, which
        # made the retry branch unreachable and `max_tries` meaningless.
        if len(path['rewards']) < args.max_path_length and tri < max_tries:
            # Clear the graph so the snapshot can be loaded again cleanly.
            tf.reset_default_graph()
            continue
        break