def main(argv):
    """Train a model through backward RL, or visualize a saved checkpoint.

    The reference actions are loaded from ``DATA_DIR/FLAGS.ref_actions_path``
    and the clip name / start step are parsed from that same path. A
    vectorized training environment with ``FLAGS.num_workers`` workers and a
    single evaluation environment are built from them.

    * If ``FLAGS.visualize`` is set, the trainer config is read from
      ``FLAGS.config_path``, the checkpoint at
      ``OUTPUT_DIR/FLAGS.checkpoint_path`` is loaded and the policy is
      visualized in the evaluation environment.
    * Otherwise the config is built from the flags, written to
      ``OUTPUT_DIR/FLAGS.config_path``, and the trainer is run over a
      curriculum whose ``reset_step`` moves backwards from the last
      reference action to the first.

    Args:
        argv: Command-line arguments passed in by the launcher (unused).
    """
    ref_actions = np.load(os.path.join(DATA_DIR, FLAGS.ref_actions_path))
    clip_name, start_step = parse_path(FLAGS.ref_actions_path)

    def make_env_fn():
        # Every environment starts with reset_step=0; the curriculum below
        # overrides it before each task.
        return RefTrackingEnv(clip_name, ref_actions, start_step, reset_step=0)

    config_class = SACTrainerConfig
    train_class = SACTrainer
    checkpoint_file = os.path.join(OUTPUT_DIR, FLAGS.checkpoint_path)

    vec_env = SubprocVecEnv([make_env_fn for _ in range(FLAGS.num_workers)])
    eval_env = None
    try:
        eval_env = make_env_fn()

        if FLAGS.visualize:
            tconf = config_class.from_json(FLAGS.config_path)
            # The evaluation environment is bound to `eval_env`; the original
            # referenced an undefined `env` here and failed with NameError.
            trainer = train_class(vec_env, eval_env, tconf, OUTPUT_DIR)
            trainer.load_checkpoint(checkpoint_file)
            eval_env.visualize(trainer.policy, device='cpu')
        else:
            tconf = config_class.from_flags(FLAGS)
            tconf.to_json(os.path.join(OUTPUT_DIR, FLAGS.config_path))
            trainer = train_class(vec_env, eval_env, tconf, OUTPUT_DIR)

            # Generate the curriculum: task `idx` resets at
            # len(ref_actions) - (idx + 1), i.e. from the last step back to 0.
            curriculum = enumerate(reversed(range(len(ref_actions))))
            for idx, reset_step in curriculum:
                # Modify the environments to reflect the new reset_step
                vec_env.set_attr('reset_step', reset_step)
                vec_env.reset()
                eval_env.reset_step = reset_step
                eval_env.reset()

                target_return = eval_env.ref_returns[reset_step]
                print(
                    f'Curriculum Task {idx}: reset_step {reset_step} target_return {target_return:.3f}'
                )
                trainer.train(target_return)
                # Checkpoint after every task so progress survives a crash.
                trainer.save_checkpoint(checkpoint_file)
    finally:
        # Always release the worker subprocesses and the evaluation
        # environment, even when training or visualization raises.
        vec_env.close()
        if eval_env is not None:
            eval_env.close()
os.path.join(os.environ.get('AMLT_DATA_DIR', 'data'), 'eval/CMU_069_02_start_step_0.npy')) reset_step = len(ref_actions) - 16 env = RefTrackingEnv( clip_name='CMU_069_02', ref_actions=ref_actions, start_step=0, reset_step=reset_step, ) # Check set_attr for reset_step is correct make_env_fn = lambda: RefTrackingEnv( 'CMU_069_02', ref_actions, 0, reset_step=len(ref_actions) - 2) vec_env = SubprocVecEnv([make_env_fn for _ in range(2)]) # Change the reset step and make sure we don't get an epsiode termination until expected vec_env.set_attr('reset_step', reset_step) vec_env.reset() for idx, act in enumerate(ref_actions[reset_step:]): acts = np.tile(act, (2, 1)) obs, rew, done, info = vec_env.step(acts) # Ensure we're not done until the last step if idx < len(ref_actions[reset_step:]) - 1: assert not np.any(done) # Make sure we're done on the last step assert np.all(done) vec_env.close() # Check for correctness for _ in range(2): obs = env.reset() # Check that the reset is working properly and we get the expected sequences