Example #1
    def setup(self):
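        # Wire up the discrete-SAC pipeline: seeding, reward/cost scaling and episode-type
        # control, then either checkpoint loading (eval) or an experience buffer (training),
        # the SAC learner itself, target-network syncing, periodic checkpointing, stats
        # recording, console printing and optional rendering.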
        args = p.parse_args()
        self.register_agent(SeedingAgent("SeedingAgent", self, args.seed))
        self.register_agent(RewardScalingAgent(
            "RewardScaler", self, reward_scaling=args.reward_scaling, cost_scaling=args.cost_scaling))

        self.register_agent(EpisodeTypeControlAgent('EpisodeTypeController', self, args.eval_mode,
                                                    args.min_explore_steps, args.exploit_freq))  # type: EpisodeTypeControlAgent

        model_loader = None
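        # Eval mode streams previously saved actor checkpoints; training mode builds an
        # n-step experience buffer for the learner instead.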
        if args.eval_mode:
            model_loader = self.register_agent(
                ModelLoadAgent('ModelLoader',
                               self,
                               None,
                               os.path.join(args.rl_logdir, args.env_id,
                                            args.eval_run, 'checkpoints'),
                               in_sequence=args.eval_in_sequence,
                               wait_for_new=False))
        exp_buff_agent = None
        if not args.eval_mode:
            exp_buff_agent = self.register_agent(
                ExperienceBufferAgent("ExpBuffAgent", self, args.nsteps,
                                      args.gamma, args.cost_gamma,
                                      args.exp_buff_len, None,
                                      not args.no_ignore_done_on_timelimit))

        convs = list(filter(lambda x: x != [0], [
                     args.conv1, args.conv2, args.conv3]))

        sac_agent = self.register_agent(
            SACDiscreteAgent(
                'SACDiscreteAgent', self, convs, args.hiddens,
                args.train_freq, args.sgd_steps, args.mb_size,
                args.dqn_mse_loss, args.gamma, args.nsteps, args.td_clip,
                args.grad_clip, args.lr, args.a_lr, args.eval_mode,
                args.min_explore_steps,
                None if args.eval_mode else exp_buff_agent.experience_buffer,
                args.sac_alpha, args.fix_alpha))  # type: SACDiscreteAgent
        if args.eval_mode:
            model_loader.model = sac_agent.actor

        if not args.eval_mode:
            self.register_agent(ModelCopyAgent('TargetNetCopier1', self, sac_agent.critic1,
                                               sac_agent.target_critic1, 1, args.polyak, args.min_explore_steps))
            self.register_agent(ModelCopyAgent('TargetNetCopier2', self, sac_agent.critic2,
                                               sac_agent.target_critic2, 1, args.polyak, args.min_explore_steps))

            self.register_agent(
                PeriodicAgent(
                    'ModelSaver',
                    self,
                    lambda step_id, ep_id:
                    (torch.save(
                        sac_agent.actor.state_dict(),
                        os.path.join(self.manager.logdir, 'checkpoints',
                                     f'step-{step_id}-ep+{ep_id}.model')),
                     torch.save(
                         sac_agent.actor.state_dict(),
                         os.path.join(self.manager.logdir, 'checkpoints',
                                      f'latest.model'))),
                    step_freq=args.model_save_freq))

        self.register_agent(
            StatsRecordingAgent(
                "StatsRecorder",
                self,
                reward_scaling=args.reward_scaling,
                cost_scaling=args.cost_scaling,
                record_unscaled=args.record_unscaled,
                gamma=args.gamma,
                cost_gamma=args.cost_gamma,
                record_undiscounted=not args.record_discounted,
                frameskip=self.frameskip,
                RPE_av_over=args.RPE_av_over,
                RPS_av_over=args.RPS_av_over))  # type: StatsRecordingAgent

        self.register_agent(ConsolePrintAgent("ConsolePrinter", self, lambda: {
            'Steps': self.manager.num_steps,
            'Episodes': self.manager.num_episodes,
            'Len': self.manager.num_episode_steps,
            'R': wandb.run.history._data['Episode/Reward'],
            f'R({args.RPE_av_over})': wandb.run.history._data[f'Average/RPE (Last {args.RPE_av_over})'],
            'loss': wandb.run.history._data['SAC/Loss'],
            'a_loss': wandb.run.history._data['SAC/A_Loss'],
            'v': wandb.run.history._data['SAC/Value'],
            'alpha': wandb.run.history._data['SAC/Alpha'],
            'entropy': wandb.run.history._data['SAC/Entropy']
        }, lambda: {
            'Total Steps': self.manager.num_steps,
            'Total Episodes': self.manager.num_episodes,
            'Average RPE': wandb.run.history._data['Average/RPE'],
            'Average CPE': wandb.run.history._data['Average/CPE']
        }))

        if not args.no_render:
            self.register_agent(SimpleRenderAgent("SimpleRenderAgent", self))
Example #2
    def setup(self):
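        # Random-play baseline: no learner is registered; a random-action agent runs with
        # the same scaling, stats-recording, console-printing and rendering agents.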
        args = p.parse_args()
        self.register_agent(SeedingAgent("SeedingAgent", self, args.seed))
        self.register_agent(
            RewardScalingAgent("RewardScaler",
                               self,
                               reward_scaling=args.reward_scaling,
                               cost_scaling=args.cost_scaling))
        self.manager.episode_type = 1

        self.register_agent(
            RandomPlayAgent("RandomAgent", self, play_for_steps=None))

        self.register_agent(
            StatsRecordingAgent(
                "StatsRecorder",
                self,
                reward_scaling=args.reward_scaling,
                cost_scaling=args.cost_scaling,
                record_unscaled=args.record_unscaled,
                gamma=args.gamma,
                cost_gamma=args.cost_gamma,
                record_undiscounted=not args.record_discounted,
                frameskip=self.frameskip,
                RPE_av_over=args.RPE_av_over,
                RPS_av_over=args.RPS_av_over))  # type: StatsRecordingAgent

        self.register_agent(
            ConsolePrintAgent(
                "ConsolePrinter", self, lambda: {
                    'Steps':
                    self.manager.num_steps,
                    'Episodes':
                    self.manager.num_episodes,
                    'Len':
                    self.manager.num_episode_steps,
                    'R':
                    wandb.run.history._data['Episode/Reward'],
                    f'R({args.RPE_av_over})':
                    wandb.run.history._data[
                        f'Average/RPE (Last {args.RPE_av_over})'],
                    'C':
                    wandb.run.history._data['Episode/Cost']
                }, lambda: {
                    'Total Steps': self.manager.num_steps,
                    'Total Episodes': self.manager.num_episodes,
                    'Average RPE': wandb.run.history._data['Average/RPE'],
                    'Average CPE': wandb.run.history._data['Average/CPE'],
                    'Average RPS': wandb.run.history._data['Average/RPS'],
                    'Average CPS': wandb.run.history._data['Average/CPS']
                }))

        if not args.no_render:
            self.register_agent(SimpleRenderAgent("SimpleRenderAgent", self))
Example #3
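# Driver fragment: persist the parsed CLI args, initialize wandb logging, then build and
# run the RL.Manager, logging any exception before re-raising it.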
import json

with open(os.path.join(logdir, 'args.json'), 'w') as args_file:
    # Dump the args as valid JSON (a plain str(vars(args)) would not be parseable later)
    json.dump(vars(args), args_file, indent=2, default=str)

wandb.init(dir=logdir, project=args.env_id,
           name=f'{args.algo_id}_{args.algo_suffix}', monitor_gym=True, tags=args.tags)
wandb.config.update(args)

for tag in args.tags:
    wandb.config.update({tag: True})
# wandb.config.update(unknown)
wandb.save(logfile)
wandb.save(logdir)
wandb.save(checkpoints_dir)
# images_dir = os.path.join(logdir, 'images')
# os.makedirs(images_dir)
# wandb.save(images_dir)

try:
    import RL.algorithms
    if 'BulletEnv-' in args.env_id:
        import pybullet  # noqa
        import pybullet_envs  # noqa
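    # Args are parsed again here; importing RL.algorithms (and the pybullet env packages)
    # may have registered additional CLI flags, and the final values are synced to wandb.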
    args = p.parse_args()
    wandb.config.update(args)
    m = RL.Manager(args.env_id, args.algo_id, args.algo_suffix, num_steps_to_run=args.num_steps_to_run,
                   num_episodes_to_run=args.num_episodes_to_run, logdir=logdir)
    m.run()
except Exception:
    logger.exception("The script crashed due to an exception")
    raise
Example #4
    def setup(self):
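        # DQN pipeline: same scaffolding as the SAC example, but with a DQNCoreAgent plus
        # epsilon annealing and a single target-network copier during training.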
        args = p.parse_args()
        self.register_agent(SeedingAgent("SeedingAgent", self, args.seed))
        self.register_agent(
            RewardScalingAgent("RewardScaler",
                               self,
                               reward_scaling=args.reward_scaling,
                               cost_scaling=args.cost_scaling))

        self.register_agent(
            EpisodeTypeControlAgent(
                'EpisodeTypeController', self, args.eval_mode,
                args.min_explore_steps,
                args.exploit_freq))  # type: EpisodeTypeControlAgent
        model_loader = None
        if args.eval_mode:
            model_loader = self.register_agent(
                ModelLoadAgent('ModelLoader',
                               self,
                               None,
                               os.path.join(args.rl_logdir, args.env_id,
                                            args.eval_run, 'checkpoints'),
                               in_sequence=args.eval_in_sequence,
                               wait_for_new=False))
        exp_buff_agent = None
        if not args.eval_mode:
            exp_buff_agent = self.register_agent(
                ExperienceBufferAgent("ExpBuffAgent", self, args.nsteps,
                                      args.gamma, args.cost_gamma,
                                      args.exp_buff_len, None,
                                      not args.no_ignore_done_on_timelimit))

        dqn_core_agent = self.register_agent(
            DQNCoreAgent(
                'DQNCoreAgent', self,
                list(
                    filter(lambda x: x != [0],
                           [args.conv1, args.conv2, args.conv3])),
                args.hiddens, args.train_freq, args.sgd_steps, args.mb_size,
                args.double_dqn, args.dueling_dqn, args.dqn_mse_loss,
                args.gamma, args.nsteps, args.td_clip, args.grad_clip, args.lr,
                args.ep, args.noisy_explore, args.eval_mode,
                args.min_explore_steps,
                None if args.eval_mode else exp_buff_agent.experience_buffer,
                args.dqn_ptemp))  # type: DQNCoreAgent

        if args.eval_mode:
            model_loader.model = dqn_core_agent.q

        if not args.eval_mode:
            self.register_agent(
                LinearAnnealingAgent('EpsilonAnnealer', self, dqn_core_agent,
                                     'epsilon', args.min_explore_steps, 1,
                                     args.ep, args.ep_anneal_steps))

            self.register_agent(
                ModelCopyAgent('TargetNetCopier', self, dqn_core_agent.q,
                               dqn_core_agent.target_q, args.target_q_freq,
                               args.target_q_polyak, args.min_explore_steps))
            self.register_agent(
                PeriodicAgent(
                    'ModelSaver',
                    self,
                    lambda step_id, ep_id:
                    (torch.save(
                        dqn_core_agent.q.state_dict(),
                        os.path.join(self.manager.logdir, 'checkpoints',
                                     f'step-{step_id}-ep+{ep_id}.model')),
                     torch.save(
                         dqn_core_agent.q.state_dict(),
                         os.path.join(self.manager.logdir, 'checkpoints',
                                      f'latest.model'))),
                    step_freq=args.model_save_freq))

        self.register_agent(
            StatsRecordingAgent(
                "StatsRecorder",
                self,
                reward_scaling=args.reward_scaling,
                cost_scaling=args.cost_scaling,
                record_unscaled=args.record_unscaled,
                gamma=args.gamma,
                cost_gamma=args.cost_gamma,
                record_undiscounted=not args.record_discounted,
                frameskip=self.frameskip,
                RPE_av_over=args.RPE_av_over,
                RPS_av_over=args.RPS_av_over))  # type: StatsRecordingAgent

        self.register_agent(
            ConsolePrintAgent(
                "ConsolePrinter", self, lambda: {
                    'Steps':
                    self.manager.num_steps,
                    'Episodes':
                    self.manager.num_episodes,
                    'Len':
                    self.manager.num_episode_steps,
                    'R':
                    wandb.run.history._data['Episode/Reward'],
                    f'R({args.RPE_av_over})':
                    wandb.run.history._data[
                        f'Average/RPE (Last {args.RPE_av_over})'],
                    'loss':
                    wandb.run.history._data['DQN/Loss'],
                    'mb_V':
                    wandb.run.history._data['DQN/Value'],
                    'ep':
                    wandb.run.history._data['DQN/Epsilon'],
                    'mb_QStd':
                    wandb.run.history._data['DQN/Q_Std']
                }, lambda: {
                    'Total Steps': self.manager.num_steps,
                    'Total Episodes': self.manager.num_episodes,
                    'Average RPE': wandb.run.history._data['Average/RPE'],
                    'Average CPE': wandb.run.history._data['Average/CPE']
                }))

        if not args.no_render:
            self.register_agent(SimpleRenderAgent("SimpleRenderAgent", self))
Example #5
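    # Environment wrapping: optional time limit and monitor, then Atari-RAM, Atari-image,
    # generic-image or low-dimensional preprocessing chosen from the env id and observation space.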
    def wrap_env(self, env: gym.Env):
        global logger
        logger = logging.getLogger(__name__)
        args = p.parse_args()
        if args.artificial_timelimit:
            logger.info('Wrapping with Timelimit')
            env = TimeLimit(env, max_episode_steps=args.artificial_timelimit)
        if not args.no_monitor:
            env = Monitor(
                env,
                osp.join(self.manager.logdir, 'openai_monitor'),
                video_callable=lambda ep_id: capped_quadratic_video_schedule(
                    ep_id, args.monitor_video_freq),
                force=True,
                mode='evaluation' if args.eval_mode else 'training')
        if '-ramNoFrameskip-v4' in self.manager.env_id:  # for playing atari from ram
            logger.info('Atari RAM env detected')
            logger.info('Wrapping with Fire Reset')
            env = FireResetEnv(env)
            if args.atari_episodic_life:
                logger.info('Wrapping with EpisodicLife')
                env = EpisodicLifeEnv(env)
            logger.info('Wrapping with NoopReset')
            env = NoopResetEnv(env, noop_max=args.atari_noop_max)
            logger.info('Wrapping with Frameskip')
            env = FrameSkipWrapper(env, skip=args.atari_frameskip)
            if args.framestack > 1:
                logger.info('Wrapping with Framestack')
                env = LinearFrameStackWrapper(env, k=args.framestack)
            if args.atari_clip_rewards:
                logger.info('Wrapping with ClipRewards')
                env = ClipRewardEnv(env)
            self.frameskip = args.atari_frameskip
            self.framestack = args.framestack
        # Some Image obs environment
        elif isinstance(
                env.observation_space,
                gym.spaces.Box) and len(env.observation_space.shape) >= 2:
            if 'NoFrameskip-v4' in self.manager.env_id:
                logger.info('Atari env detected')
                logger.info('Wrapping with Fire Reset')
                env = FireResetEnv(env)
                logger.info('Wrapping with AtariPreprocessing')
                env = AtariPreprocessing(
                    env,
                    noop_max=args.atari_noop_max,
                    frame_skip=args.atari_frameskip,
                    terminal_on_life_loss=args.atari_episodic_life)
                logger.info('Wrapping with Framestack')
                env = FrameStack(env, args.atari_framestack)
                if args.atari_clip_rewards:
                    logger.info('Wrapping with ClipRewards')
                    env = ClipRewardEnv(env)
                self.frameskip = args.atari_frameskip
                self.framestack = args.atari_framestack
            else:
                logger.info('Some image based env detected')
                if args.frameskip > 1:
                    logger.info('Wrapping with Frameskip')
                    env = FrameSkipWrapper(env, skip=args.frameskip)
                if args.framestack > 1:
                    logger.info('Wrapping with Framestack')
                    env = FrameStack(env, args.framestack)
                self.frameskip = args.frameskip
                self.framestack = args.framestack
        else:
            if args.frameskip > 1:
                logger.info('Wrapping with Frameskip')
                env = FrameSkipWrapper(env, skip=args.frameskip)
            if args.framestack > 1:
                logger.info('Wrapping with Framestack')
                env = LinearFrameStackWrapper(env, k=args.framestack)
            self.frameskip = args.frameskip
            self.framestack = args.framestack
        return env