def GetTrainer(self):
    """Build and return an ``IRLTrainer`` wired to this object's agent.

    The agent's ``generator`` is passed as the trainer's policy and its
    ``discriminator`` as the IRL component. The trainer arguments come from
    ``self._get_args_from_params()``, and the environment is
    ``self._environment``.

    Expert demonstrations are read from ``self._expert_trajs`` using the keys
    ``"obses"``, ``"next_obses"`` and ``"acts"``. The exact array
    shapes/dtypes of these entries are not visible here (TODO confirm
    against whatever populates ``_expert_trajs``).

    Returns:
        The newly constructed ``IRLTrainer``.
    """
    # Read inputs in the same order the trainer's keyword arguments would be
    # evaluated, so any side effects happen in an unchanged sequence.
    agent = self._agent
    generator = agent.generator
    discriminator = agent.discriminator
    trainer_args = self._get_args_from_params()
    environment = self._environment
    demos = self._expert_trajs
    return IRLTrainer(
        policy=generator,
        env=environment,
        args=trainer_args,
        irl=discriminator,
        expert_obs=demos["obses"],
        expert_next_obs=demos["next_obses"],
        expert_act=demos["acts"],
    )
import gym from tf2rl.algos.ddpg import DDPG from tf2rl.algos.vail import VAIL from tf2rl.experiments.irl_trainer import IRLTrainer from tf2rl.experiments.utils import restore_latest_n_traj if __name__ == '__main__': parser = IRLTrainer.get_argument() parser = VAIL.get_argument(parser) parser.add_argument('--env-name', type=str, default="RoboschoolReacher-v1") args = parser.parse_args() if args.expert_path_dir is None: print("Plaese generate demonstrations first") print("python examples/run_sac.py --env-name=RoboschoolReacher-v1 --save-test-path --test-interval=50000") exit() units = [400, 300] env = gym.make(args.env_name) test_env = gym.make(args.env_name) policy = DDPG( state_shape=env.observation_space.shape, action_dim=env.action_space.high.size, max_action=env.action_space.high[0], gpu=args.gpu, actor_units=units, critic_units=units, n_warmup=10000,