예제 #1
0
def main(argv):
    """Trains a model through backward RL, or visualizes a saved checkpoint.

    Loads reference actions from ``DATA_DIR/FLAGS.ref_actions_path`` and builds
    ``FLAGS.num_workers`` subprocess training environments plus one evaluation
    environment, all tracking the clip parsed from that path.

    If ``FLAGS.visualize`` is set: reads the trainer config from
    ``FLAGS.config_path``, loads the checkpoint at
    ``OUTPUT_DIR/FLAGS.checkpoint_path`` and visualizes the policy in the
    evaluation environment.

    Otherwise: builds the config from ``FLAGS``, writes it to
    ``OUTPUT_DIR/FLAGS.config_path`` and runs a backward curriculum. The
    episode reset step starts at the last reference step and moves one step
    earlier per task until it reaches 0. Each task trains towards the
    evaluation env's reference return at that reset step. Each task then
    overwrites the checkpoint at ``OUTPUT_DIR/FLAGS.checkpoint_path``.

    All environments are closed on exit, including when an exception is raised.

    Args:
        argv: Unused; presumably the positional command-line arguments from
            the app runner (e.g. absl ``app.run``) -- TODO confirm.
    """
    ref_actions = np.load(os.path.join(DATA_DIR, FLAGS.ref_actions_path))
    clip_name, start_step = parse_path(FLAGS.ref_actions_path)

    def make_env_fn():
        # Every env starts at reset_step=0; the curriculum loop moves it.
        return RefTrackingEnv(clip_name, ref_actions, start_step, reset_step=0)

    config_class = SACTrainerConfig
    train_class = SACTrainer

    vec_env = SubprocVecEnv([make_env_fn for _ in range(FLAGS.num_workers)])
    eval_env = None
    # try/finally so the worker subprocesses and envs are released even when
    # env construction, training or visualization raises.
    try:
        eval_env = make_env_fn()

        if FLAGS.visualize:
            # NOTE(review): the training branch saves the config under
            # OUTPUT_DIR, but this loads FLAGS.config_path as given (relative
            # to the working directory) -- confirm this asymmetry is intended.
            tconf = config_class.from_json(FLAGS.config_path)
            trainer = train_class(vec_env, eval_env, tconf, OUTPUT_DIR)
            trainer.load_checkpoint(
                os.path.join(OUTPUT_DIR, FLAGS.checkpoint_path))
            eval_env.visualize(trainer.policy, device='cpu')
        else:
            tconf = config_class.from_flags(FLAGS)
            tconf.to_json(os.path.join(OUTPUT_DIR, FLAGS.config_path))
            trainer = train_class(vec_env, eval_env, tconf, OUTPUT_DIR)

            # Backward curriculum: task 0 resets at the final reference step,
            # the last task resets at step 0.
            num_steps = len(ref_actions)
            for idx, reset_step in enumerate(reversed(range(num_steps))):
                # Point every env at the new reset_step before training.
                vec_env.set_attr('reset_step', reset_step)
                vec_env.reset()
                eval_env.reset_step = reset_step
                eval_env.reset()

                target_return = eval_env.ref_returns[reset_step]
                print(
                    f'Curriculum Task {idx}: reset_step {reset_step} target_return {target_return:.3f}'
                )

                trainer.train(target_return)
                # Same path every task: the latest task overwrites the last.
                trainer.save_checkpoint(
                    os.path.join(OUTPUT_DIR, FLAGS.checkpoint_path))
    finally:
        vec_env.close()
        if eval_env is not None:
            eval_env.close()
예제 #2
0
        os.path.join(os.environ.get('AMLT_DATA_DIR', 'data'),
                     'eval/CMU_069_02_start_step_0.npy'))
    reset_step = len(ref_actions) - 16
    env = RefTrackingEnv(
        clip_name='CMU_069_02',
        ref_actions=ref_actions,
        start_step=0,
        reset_step=reset_step,
    )

    # Check set_attr for reset_step is correct
    make_env_fn = lambda: RefTrackingEnv(
        'CMU_069_02', ref_actions, 0, reset_step=len(ref_actions) - 2)
    vec_env = SubprocVecEnv([make_env_fn for _ in range(2)])
    # Change the reset step and make sure we don't get an epsiode termination until expected
    vec_env.set_attr('reset_step', reset_step)
    vec_env.reset()
    for idx, act in enumerate(ref_actions[reset_step:]):
        acts = np.tile(act, (2, 1))
        obs, rew, done, info = vec_env.step(acts)
        # Ensure we're not done until the last step
        if idx < len(ref_actions[reset_step:]) - 1:
            assert not np.any(done)
    # Make sure we're done on the last step
    assert np.all(done)
    vec_env.close()

    # Check for correctness
    for _ in range(2):
        obs = env.reset()
        # Check that the reset is working properly and we get the expected sequences