Example #1
    def evaluate(self, epoch):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(eval_util.get_generic_path_information(
            test_paths, stat_prefix="Test",
        ))
        statistics.update(eval_util.get_generic_path_information(
            self._exploration_paths, stat_prefix="Exploration",
        ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)

        average_returns = rlkit.core.eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.render_eval_paths:
            self.env.render_paths(test_paths)

        if self.plotter:
            self.plotter.draw()
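All of the examples on this page consume lists of rollout "paths". Below is a minimal, hypothetical sketch of that input, assuming the usual rlkit rollout dictionary keys; the fake_path helper and its dimensions are made up for illustration, not taken from any example above.

from collections import OrderedDict

import numpy as np

from rlkit.core import eval_util


def fake_path(horizon=10, obs_dim=4, act_dim=2):
    # Hypothetical stand-in for a path produced by rlkit's rollout utilities.
    return dict(
        observations=np.zeros((horizon, obs_dim)),
        actions=np.zeros((horizon, act_dim)),
        rewards=np.ones((horizon, 1)),
        next_observations=np.zeros((horizon, obs_dim)),
        terminals=np.zeros((horizon, 1), dtype=bool),
        agent_infos=[{} for _ in range(horizon)],
        env_infos=[{} for _ in range(horizon)],
    )


statistics = OrderedDict()
statistics.update(eval_util.get_generic_path_information(
    [fake_path() for _ in range(5)], stat_prefix="Test",
))
print(list(statistics.keys()))  # e.g. 'Test Rewards Mean', 'Test Returns Mean', ...
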
Example #2
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)

        """
        Replay Buffer
        """
        logger.record_dict(
            self.replay_buffer.get_diagnostics(),
            prefix='replay_buffer/'
        )

        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')

        """
        Exploration
        """
        logger.record_dict(
            self.expl_data_collector.get_diagnostics(),
            prefix='exploration/'
        )
        expl_paths = self.expl_data_collector.get_epoch_paths()
        # import ipdb; ipdb.set_trace()
        if hasattr(self.expl_env, 'get_diagnostics'):
            logger.record_dict(
                self.expl_env.get_diagnostics(expl_paths),
                prefix='exploration/',
            )
        if not self.batch_rl or self.eval_both:
            logger.record_dict(
                eval_util.get_generic_path_information(expl_paths),
                prefix="exploration/",
            )
        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix='evaluation/',
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, 'get_diagnostics'):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix='evaluation/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )

        """
        Misc
        """
        gt.stamp('logging')
        logger.record_dict(_get_epoch_timings())
        logger.record_tabular('Epoch', epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #3
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        if len(self._exploration_paths) > 0:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True
Example #4
 def _log_stats(self, epoch, num_epochs_per_eval=0):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Replay Buffer
     """
     logger.record_dict(self.replay_buffer.get_diagnostics(),
                        prefix='replay_buffer/')
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
     """
     Exploration
     """
     logger.record_dict(self.expl_data_collector.get_diagnostics(),
                        prefix='exploration/')
     expl_paths = self.expl_data_collector.get_epoch_paths()
     if hasattr(self.expl_env, 'get_diagnostics'):
         logger.record_dict(
             self.expl_env.get_diagnostics(expl_paths),
             prefix='exploration/',
         )
     logger.record_dict(
         eval_util.get_generic_path_information(expl_paths),
         prefix="exploration/",
     )
     """
     Evaluation
     """
     try:
         # if epoch % num_epochs_per_eval == 0:
         logger.record_dict(
             self.eval_data_collector.get_diagnostics(),
             prefix='evaluation/',
         )
         eval_paths = self.eval_data_collector.get_epoch_paths()
         if hasattr(self.eval_env, 'get_diagnostics'):
             logger.record_dict(
                 self.eval_env.get_diagnostics(eval_paths),
                 prefix='evaluation/',
             )
         logger.record_dict(
             eval_util.get_generic_path_information(eval_paths),
             prefix="evaluation/",
         )
     except Exception:
         # Evaluation data may be missing for some epochs; ignore and continue.
         pass
     """
     Misc
     """
     gt.stamp('logging')
     logger.record_dict(_get_epoch_timings())
     logger.record_tabular('Epoch', epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #5
    def _log_stats(self, epoch):
        expl_paths = self.expl_data_collector.get_epoch_paths()
        eval_paths = self.eval_data_collector.get_epoch_paths()

        expl_path_information = eval_util.get_generic_path_information(expl_paths)
        eval_path_information = eval_util.get_generic_path_information(eval_paths)

        for k, v in expl_path_information.items():
            if k == 'Returns Mean' or k == 'Returns Std':
                self.writer.add_scalar('expl/' + k, v, epoch)
        for k, v in eval_path_information.items():
            if k == 'Returns Mean' or k == 'Returns Std':
                self.writer.add_scalar('eval/' + k, v, epoch)
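Examples #5, #8, and #18 log through a self.writer attribute that rlkit itself does not create; it is presumably a TensorBoard SummaryWriter. The sketch below shows one way that assumption could be wired up; the mixin class name and log_dir are hypothetical.

from torch.utils.tensorboard import SummaryWriter


class TensorboardLoggingMixin:
    """Hypothetical mixin showing where the `self.writer` used above could come from."""

    def __init__(self, log_dir="runs/my_experiment"):  # log_dir is illustrative
        self.writer = SummaryWriter(log_dir=log_dir)

    def log_returns(self, eval_path_information, epoch):
        # Same calls as in Example #5, written against the assumed writer.
        for k in ('Returns Mean', 'Returns Std'):
            self.writer.add_scalar('eval/' + k, eval_path_information[k], epoch)
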
Example #6
 def _log_stats(self, epoch):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Replay Buffer
     """
     logger.record_dict(self.replay_buffer.get_diagnostics(),
                        prefix="replay_buffer/")
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix="trainer/")
     """
     Exploration
     """
     logger.record_dict(self.expl_data_collector.get_diagnostics(),
                        prefix="exploration/")
     expl_paths = self.expl_data_collector.get_epoch_paths()
     if hasattr(self.expl_env, "get_diagnostics"):
         logger.record_dict(
             self.expl_env.get_diagnostics(expl_paths),
             prefix="exploration/",
         )
     logger.record_dict(
         eval_util.get_generic_path_information(expl_paths),
         prefix="exploration/",
     )
     """
     Evaluation
     """
     logger.record_dict(
         self.eval_data_collector.get_diagnostics(),
         prefix="evaluation/",
     )
     eval_paths = self.eval_data_collector.get_epoch_paths()
     if hasattr(self.eval_env, "get_diagnostics"):
         logger.record_dict(
             self.eval_env.get_diagnostics(eval_paths),
             prefix="evaluation/",
         )
     logger.record_dict(
         eval_util.get_generic_path_information(eval_paths),
         prefix="evaluation/",
     )
     """
     Misc
     """
     gt.stamp("logging")
     logger.record_dict(_get_epoch_timings())
     logger.record_tabular("Epoch", epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #7
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        statistics = OrderedDict()
        try:
            statistics.update(self.eval_statistics)
            self.eval_statistics = None
        except Exception:
            print('No Stats to Eval')

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))

        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)
        if hasattr(self.env, "log_statistics"):
            statistics.update(self.env.log_statistics(test_paths))
        if epoch % self.freq_log_visuals == 0:
            if hasattr(self.env, "log_visuals"):
                self.env.log_visuals(test_paths, epoch,
                                     logger.get_snapshot_dir())

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        best_statistic = statistics[self.best_key]
        if best_statistic > self.best_statistic_so_far:
            self.best_statistic_so_far = best_statistic
            if self.save_best and epoch >= self.save_best_starting_from_epoch:
                data_to_save = {'epoch': epoch, 'statistics': statistics}
                data_to_save.update(self.get_epoch_snapshot(epoch))
                logger.save_extra_data(data_to_save, 'best.pkl')
                print('\n\nSAVED BEST\n\n')
Example #8
    def _log_stats(self, epoch):
        print()
        print(
            "####################### RESULT OF EPOCH {} #######################"
            .format(epoch))
        for k, v in self.trainer.get_diagnostics().items():
            # Log a fixed set of trainer diagnostics to stdout and TensorBoard.
            if str(k) in ("QF1 Loss", "QF2 Loss", "Policy Loss", "Alpha",
                          "Log Pis Mean", "Target Entropy"):
                print(str(k), v)
                self.writer.add_scalar(str(k), v, epoch)

        total_num_steps = (epoch + 1) * self.max_path_length

        expl_paths = self.expl_data_collector.get_epoch_paths()
        d = eval_util.get_generic_path_information(expl_paths)
        for k, v in d.items():
            if str(k) == "Average Returns":
                print("Exploration_rewards", v)
                self.writer.add_scalar("Episode_rewards", v, epoch)
                self.writer.add_scalar("Episode_rewards_envstep", v,
                                       total_num_steps)

        eval_util.print_returns_info(expl_paths, epoch)
        if self.eval_env:
            eval_paths = self.eval_data_collector.get_epoch_paths()
            d = eval_util.get_generic_path_information(eval_paths)
            for k, v in d.items():
                if str(k) == "Average Returns":
                    print("Evaluation_rewards", v)
                    self.writer.add_scalar("Evaluation_rewards", v, epoch)
                    self.writer.add_scalar("Evaluation_rewards_envstep", v,
                                           total_num_steps)
        print(
            "################################################################")
Example #9
    def _end_epoch(self, epoch):
        print('in _end_epoch, epoch is: {}'.format(epoch))
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        # trainer_obj = self.trainer
        # ckpt_path='ckpt.pkl'
        # logger.save_ckpt(epoch, trainer_obj, ckpt_path)
        # gt.stamp('saving')
        if epoch % 1 == 0:
            self.save_snapshot_2(epoch)
        expl_paths = self.expl_data_collector.get_epoch_paths()
        d = eval_util.get_generic_path_information(expl_paths)
        # print(d.keys())
        metric_val = d['Rewards Mean']

        cur_best_metric_val = self.get_cur_best_metric_val()
        if epoch != 0:
            self.save_snapshot_2_best_only(
                metric_val=metric_val,
                cur_best_metric_val=cur_best_metric_val,
                min_or_max='max')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example #10
def simulate_policy(args):
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    paths = []
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        paths.append(path)
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        else:
            logger.record_dict(
                eval_util.get_generic_path_information(paths),
                prefix="evaluation/",
            )
        logger.dump_tabular()
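Example #10's simulate_policy expects an args namespace with file, H, and gpu fields. A minimal argparse entry point matching those names could look like the sketch below; only the field names are taken from Example #10, while the defaults and help strings are assumptions.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the saved snapshot')
    parser.add_argument('--H', type=int, default=300, help='max rollout length')
    parser.add_argument('--gpu', action='store_true')
    simulate_policy(parser.parse_args())
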
Example #11
 def _log_stats(self, epoch):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
     """
     Evaluation
     """
     logger.record_dict(
         self.eval_data_collector.get_diagnostics(),
         prefix='evaluation/',
     )
     eval_paths = self.eval_data_collector.get_epoch_paths()
     if hasattr(self.eval_env, 'get_diagnostics'):
         logger.record_dict(
             self.eval_env.get_diagnostics(eval_paths),
             prefix='evaluation/',
         )
     logger.record_dict(
         eval_util.get_generic_path_information(eval_paths),
         prefix="evaluation/",
     )
     """
     Misc
     """
     gt.stamp('logging')
     logger.record_dict(_get_epoch_timings())
     logger.record_tabular('Epoch', epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #12
    def evaluate(self, epoch):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        # statistics.update(eval_util.get_generic_path_information(
        #     self._exploration_paths, stat_prefix="Exploration",
        # ))

        for mode in ['meta_train', 'meta_test']:
            logger.log("Collecting samples for evaluation")
            test_paths = self.obtain_eval_samples(epoch, mode=mode)

            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths,
                    stat_prefix="Test " + mode,
                ))
            # print(statistics.keys())
            if hasattr(self.env, "log_diagnostics"):
                self.env.log_diagnostics(test_paths)
            if hasattr(self.env, "log_statistics"):
                log_stats = self.env.log_statistics(test_paths)
                new_log_stats = OrderedDict(
                    (k + ' ' + mode, v) for k, v in log_stats.items())
                statistics.update(new_log_stats)

            average_returns = rlkit.core.eval_util.get_average_returns(
                test_paths)
            statistics['AverageReturn ' + mode] = average_returns

            if self.render_eval_paths:
                self.env.render_paths(test_paths)

        # meta_test_this_epoch = statistics['Percent_Solved meta_test']
        meta_test_this_epoch = statistics['AverageReturn meta_test']
        if meta_test_this_epoch >= self.best_meta_test:
            # make sure you set save_algorithm to true then call save_extra_data
            prev_save_alg = self.save_algorithm
            self.save_algorithm = True
            if self.save_best:
                if epoch > self.save_best_after_epoch:
                    temp = self.replay_buffer
                    self.replay_buffer = None
                    logger.save_extra_data(self.get_extra_data_to_save(epoch),
                                           'best_meta_test.pkl')
                    self.replay_buffer = temp
                    self.best_meta_test = meta_test_this_epoch
                    print('\n\nSAVED ALG AT EPOCH %d\n\n' % epoch)
            self.save_algorithm = prev_save_alg

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.plotter:
            self.plotter.draw()
Example #13
 def log_statistics(self, paths, split=''):
     self.eval_statistics.update(
         eval_util.get_generic_path_information(
             paths,
             stat_prefix="{}_task{}".format(split, self.task_idx),
         ))
     # TODO(KR) what are these?
     self.eval_statistics.update(
         eval_util.get_generic_path_information(
             self._exploration_paths,
             stat_prefix="Exploration_task{}".format(self.task_idx),
         )
     )  # something is wrong with these exploration paths i'm pretty sure...
     average_returns = eval_util.get_average_returns(paths)
     self.eval_statistics['AverageReturn_{}_task{}'.format(
         split, self.task_idx)] = average_returns
     goal = self.env._goal
     dprint('GoalPosition_{}_task'.format(split))
     dprint(goal)
     self.eval_statistics['GoalPosition_{}_task{}'.format(
         split, self.task_idx)] = goal
Example #14
def simulate_policy(fpath,
                    env_name,
                    seed,
                    max_path_length,
                    num_eval_steps,
                    headless,
                    max_eps,
                    verbose=True,
                    pause=False):
    data = torch.load(fpath, map_location=ptu.device)
    policy = data['evaluation/policy']
    policy.to(ptu.device)
    # make new env, reloading with data['evaluation/env'] seems to make bug
    env = gym.make(env_name, **{"headless": headless, "verbose": False})
    env.seed(seed)
    if pause:
        input("Waiting to start.")
    path_collector = MdpPathCollector(env, policy)
    paths = path_collector.collect_new_paths(
        max_path_length,
        num_eval_steps,
        discard_incomplete_paths=True,
    )

    if max_eps:
        paths = paths[:max_eps]
    if verbose:
        completions = sum([
            info["completed"] for path in paths for info in path["env_infos"]
        ])
        print("Completed {} out of {}".format(completions, len(paths)))
        # plt.plot(paths[0]["actions"])
        # plt.show()
        # plt.plot(paths[2]["observations"])
        # plt.show()
        logger.record_dict(
            eval_util.get_generic_path_information(paths),
            prefix="evaluation/",
        )
        logger.dump_tabular()
    return paths
Example #15
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
        """
        dump video of policy
        """
        if epoch % self.dump_video_interval == 0:
            imsize = self.expl_env.imsize
            env = self.eval_env

            dump_path = logger.get_snapshot_dir()

            rollout_func = rollout

            video_name = "{}_sample_goal.gif".format(epoch)
            latent_distance_name = "{}_latent_sample_goal.png".format(epoch)
            dump_video(env,
                       self.eval_data_collector._policy,
                       osp.join(dump_path, video_name),
                       rollout_function=rollout_func,
                       imsize=imsize,
                       horizon=self.max_path_length,
                       rows=1,
                       columns=8)
            plot_latent_dist(env,
                             self.eval_data_collector._policy,
                             save_name=osp.join(dump_path,
                                                latent_distance_name),
                             rollout_function=rollout_func,
                             horizon=self.max_path_length)

            old_goal_sampling_mode, old_decode_goals = env._goal_sampling_mode, env.decode_goals
            env._goal_sampling_mode = 'reset_of_env'
            env.decode_goals = False
            video_name = "{}_reset.gif".format(epoch)
            latent_distance_name = "{}_latent_reset.png".format(epoch)
            dump_video(env,
                       self.eval_data_collector._policy,
                       osp.join(dump_path, video_name),
                       rollout_function=rollout_func,
                       imsize=imsize,
                       horizon=self.max_path_length,
                       rows=1,
                       columns=8)
            plot_latent_dist(env,
                             self.eval_data_collector._policy,
                             save_name=osp.join(dump_path,
                                                latent_distance_name),
                             rollout_function=rollout_func,
                             horizon=self.max_path_length)
            self.eval_env._goal_sampling_mode = old_goal_sampling_mode
            self.eval_env.decode_goals = old_decode_goals
        """
        Replay Buffer
        """
        logger.record_dict(self.replay_buffer.get_diagnostics(),
                           prefix='replay_buffer/')
        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
        """
        Exploration
        """
        logger.record_dict(self.expl_data_collector.get_diagnostics(),
                           prefix='exploration/')
        expl_paths = self.expl_data_collector.get_epoch_paths()
        if hasattr(self.expl_env, 'get_diagnostics'):
            logger.record_dict(
                self.expl_env.get_diagnostics(expl_paths),
                prefix='exploration/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(expl_paths),
            prefix="exploration/",
        )
        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix='evaluation/',
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, 'get_diagnostics'):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix='evaluation/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )
        """
        Misc
        """
        gt.stamp('logging')
        logger.record_dict(_get_epoch_timings())
        logger.record_tabular('Epoch', epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #16
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        statistics = OrderedDict()
        try:
            statistics.update(self.eval_statistics)
            self.eval_statistics = None
        except Exception:
            print('No Stats to Eval')

        logger.log("Collecting samples for evaluation")

        test_paths = []
        sampled_task_params = self.test_task_params_sampler.sample_unique(
            self.num_eval_tasks)
        for i in range(self.num_eval_tasks):
            env = self.env_factory(sampled_task_params[i])
            for _ in range(self.num_rollouts_per_task_per_eval):
                test_paths.append(
                    rollout(
                        env,  # roll out in the per-task env created above
                        self.get_eval_policy(sampled_task_params[i]),
                        self.max_path_length,
                        no_terminal=self.no_terminal,
                        render=self.render,
                        render_kwargs=self.render_kwargs,
                    ))

        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))

        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)
        if hasattr(self.env, "log_statistics"):
            statistics.update(self.env.log_statistics(test_paths))
        if epoch % self.freq_log_visuals == 0:
            if hasattr(self.env, "log_visuals"):
                self.env.log_visuals(test_paths, epoch,
                                     logger.get_snapshot_dir())

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        best_statistic = statistics[self.best_key]
        if best_statistic > self.best_statistic_so_far:
            self.best_statistic_so_far = best_statistic
            if self.save_best and epoch >= self.save_best_starting_from_epoch:
                data_to_save = {'epoch': epoch, 'statistics': statistics}
                data_to_save.update(self.get_epoch_snapshot(epoch))
                logger.save_extra_data(data_to_save, 'best.pkl')
                print('\n\nSAVED BEST\n\n')
Example #17
    else:
        max_tau = args.mtau

    env = data['env']
    policy = data['policy']
    policy.train(False)

    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.cuda()

    while True:
        paths = []
        for _ in range(args.nrolls):
            goal = env.sample_goal_for_rollout()
            path = multitask_rollout(
                env,
                policy,
                init_tau=max_tau,
                goal=goal,
                max_path_length=args.H,
                animated=not args.hide,
                cycle_tau=True,
                decrement_tau=True,
            )
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
Example #18
    def _log_exploration_tb_stats(self, epoch):

        expl_paths = self.expl_data_collector.get_epoch_paths()
        #eval_paths = self.eval_data_collector.get_epoch_paths()
        '''
        self.writer.add_scalar('test_regular/eval', eval_util.get_generic_path_information(
            eval_paths)['env_infos/final/is_success Mean'], epoch)
        '''
        '''
        ates = []
        for path in expl_paths:
            ate = np.sqrt(eval_util.get_generic_path_information([path])[
                          'env_infos/squared_error_norm Mean'])
            ates.append(ate)
        ates = np.array(ates)
        '''

        self.writer.add_scalar(
            'expl/delta_size_penalty',
            eval_util.get_generic_path_information(expl_paths)
            ['env_infos/delta_size_penalty Mean'], epoch)

        self.writer.add_scalar(
            'expl/delta_size',
            eval_util.get_generic_path_information(expl_paths)
            ['env_infos/delta_size Mean'], epoch)

        self.writer.add_scalar(
            'expl/ate_penalty',
            eval_util.get_generic_path_information(expl_paths)
            ['env_infos/ate_penalty Mean'], epoch)

        self.writer.add_scalar(
            'expl/cosine_distance',
            eval_util.get_generic_path_information(expl_paths)
            ['env_infos/cosine_distance Mean'], epoch)

        self.writer.add_scalar(
            'expl/task_reward',
            eval_util.get_generic_path_information(expl_paths)
            ['env_infos/task_reward Mean'], epoch)

        self.writer.add_scalar(
            'expl/control_penalty',
            eval_util.get_generic_path_information(expl_paths)
            ['env_infos/control_penalty Mean'], epoch)

        self.writer.add_scalar(
            'expl/reward',
            eval_util.get_generic_path_information(expl_paths)['Rewards Mean'],
            epoch)

        self.writer.add_scalar(
            'expl/returns',
            eval_util.get_generic_path_information(expl_paths)['Returns Mean'],
            epoch)

        if 'State estimation loss' in self.trainer.get_diagnostics().keys():
            self.writer.add_scalar(
                'losses/eval/state',
                self.trainer.get_diagnostics()['State estimation loss'], epoch)

        self.writer.add_scalar('losses/q1',
                               self.trainer.get_diagnostics()['QF1 Loss'],
                               epoch)
        self.writer.add_scalar('losses/q2',
                               self.trainer.get_diagnostics()['QF2 Loss'],
                               epoch)
        self.writer.add_scalar(
            'losses/policy',
            self.trainer.get_diagnostics()['Raw Policy Loss'], epoch)
        self.writer.add_scalar('losses/logpi',
                               self.trainer.get_diagnostics()['Log Pi'], epoch)
        self.writer.add_scalar('losses/alpha',
                               self.trainer.get_diagnostics()['Alpha'], epoch)
        self.writer.add_scalar('losses/alphaloss',
                               self.trainer.get_diagnostics()['Alpha Loss'],
                               epoch)
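Example #18 recomputes eval_util.get_generic_path_information(expl_paths) for every scalar it logs. The sketch below computes the statistics once and logs the same values; the env_infos keys are copied from Example #18 and exist only for that particular environment, and log_exploration_scalars is a made-up helper name.

from rlkit.core import eval_util


def log_exploration_scalars(writer, expl_paths, epoch):
    # Compute the path statistics once instead of once per logged scalar.
    expl_stats = eval_util.get_generic_path_information(expl_paths)
    scalar_keys = {
        'expl/delta_size_penalty': 'env_infos/delta_size_penalty Mean',
        'expl/delta_size': 'env_infos/delta_size Mean',
        'expl/ate_penalty': 'env_infos/ate_penalty Mean',
        'expl/cosine_distance': 'env_infos/cosine_distance Mean',
        'expl/task_reward': 'env_infos/task_reward Mean',
        'expl/control_penalty': 'env_infos/control_penalty Mean',
        'expl/reward': 'Rewards Mean',
        'expl/returns': 'Returns Mean',
    }
    for tag, key in scalar_keys.items():
        writer.add_scalar(tag, expl_stats[key], epoch)
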
Example #19
    def evaluate(self, epoch):
        statistics = OrderedDict()
        if self.rewardf_eval_statistics is not None:
            statistics.update(self.rewardf_eval_statistics)
        # statistics.update(self.policy_optimizer.eval_statistics)
        self.rewardf_eval_statistics = None
        # self.policy_optimizer.eval_statistics = None

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))
        # print(statistics.keys())
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)
        if hasattr(self.env, "log_statistics"):
            env_log_stats = self.env.log_statistics(test_paths)
            statistics.update(env_log_stats)
        if hasattr(self.env, "log_new_ant_multi_statistics"):
            env_log_stats = self.env.log_new_ant_multi_statistics(
                test_paths, epoch, logger.get_snapshot_dir())
            statistics.update(env_log_stats)

        average_returns = rlkit.core.eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.render_eval_paths:
            self.env.render_paths(test_paths)

        if self.plotter:
            self.plotter.draw()

        # if self.best_success_rate < statistics['Success Rate']:
        #     self.best_success_rate = statistics['Success Rate']
        #     params = self.get_epoch_snapshot(-1)
        #     params['epoch'] = epoch
        #     logger.save_extra_data(params, 'best_params.pkl')

        if average_returns > self.max_returns:
            self.max_returns = average_returns
            if self.save_best and epoch >= self.save_best_starting_from_epoch:
                data_to_save = {
                    'algorithm': self,
                    'epoch': epoch,
                    'average_returns': average_returns,
                    'test_returns_mean': statistics['Test Returns Mean'],
                    'test_returns_std': statistics['Test Returns Std'],
                    'exp_returns_mean': statistics['Exploration Returns Mean'],
                    'exp_returns_std': statistics['Exploration Returns Std']
                }
                logger.save_extra_data(data_to_save, 'best_test.pkl')
                print('\n\nSAVED BEST\n\n')
Example #20
    def _train(self):
        start_time = time.time()
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        temp_policy_weights = copy.deepcopy(
            self.trainer._base_trainer.policy.state_dict())
        self.policy_weights_queue.put(temp_policy_weights)
        self.new_policy_event.set()

        print("Initialized policy")
        process = psutil.Process(os.getpid())

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            if epoch % self.save_policy_every_epoch == 0:
                torch.save(self.trainer._base_trainer.policy.state_dict(),
                           'async_policy/current_policy.mdl')
                torch.save(
                    self.trainer._base_trainer.alpha_optimizer.state_dict(),
                    'async_policy/current_alpha_optimizer.mdl')
                torch.save(
                    self.trainer._base_trainer.policy_optimizer.state_dict(),
                    'async_policy/current_policy_optimizer.mdl')
                torch.save(
                    self.trainer._base_trainer.qf1_optimizer.state_dict(),
                    'async_policy/current_qf1_optimizer.mdl')
                torch.save(
                    self.trainer._base_trainer.qf2_optimizer.state_dict(),
                    'async_policy/current_qf2_optimizer.mdl')

                print("Saved current policy")

            files = glob.glob('success_images/*')
            for f in files:
                os.remove(f)

            self.eval_data_collector.collect_new_paths(
                self.max_path_length, self.num_eval_rollouts_per_epoch)

            self.preset_eval_data_collector.collect_new_paths(
                self.max_path_length, self.num_eval_param_buckets)
            gt.stamp('evaluation sampling')

            eval_paths = self.eval_data_collector.get_epoch_paths()
            print(
                "EVAL SUCCESS RATRE",
                eval_util.get_generic_path_information(eval_paths)
                ['env_infos/final/is_success Mean'])

            print("Epoch", epoch)
            for cycle in range(self.num_train_loops_per_epoch):

                # TODO: Use for memory debug
                #print("Memory usage in train",process.memory_info().rss/10E9, "GB")
                train_steps = epoch*self.num_train_loops_per_epoch * \
                    self.num_trains_per_train_loop + cycle*self.num_trains_per_train_loop

                while train_steps > self.train_collect_ratio * self.num_collected_steps.value:
                    print("Waiting collector to catch up...", train_steps,
                          self.num_collected_steps.value)
                    time.sleep(3)

                start_cycle = time.time()
                self.training_mode(True)
                sam_times_cycle = 0
                train_train_times_cycle = 0

                for tren in range(self.num_trains_per_train_loop):
                    start_sam = time.time()
                    train_data = self.batch_queue.get()
                    self.batch_processed_event.set()

                    sam_time = time.time() - start_sam
                    sam_times_cycle += sam_time

                    start_train_train = time.time()
                    self.trainer.train_from_torch(train_data)
                    del train_data

                    train_train_time = time.time() - start_train_train

                    if not self.new_policy_event.is_set():
                        temp_policy_weights = copy.deepcopy(
                            self.trainer._base_trainer.policy.state_dict())
                        self.policy_weights_queue.put(temp_policy_weights)
                        self.new_policy_event.set()
                        #print("Updated policy")
                    if tren % 100 == 0:
                        print("--STATUS--")
                        print(tren, "/", self.num_trains_per_train_loop,
                              "Took to sample:", sam_time)
                        print(tren, "/", self.num_trains_per_train_loop,
                              "Took to train:", train_train_time)
                        print(
                            "Total train steps so far:",
                            epoch * self.num_train_loops_per_epoch *
                            self.num_trains_per_train_loop +
                            cycle * self.num_trains_per_train_loop + tren)
                        print("Total collected steps in train",
                              self.num_collected_steps.value)
                        print("Memory usages, train:",
                              process.memory_info().rss / 10E9, "buffer:",
                              self.buffer_memory_usage.value, "collector:",
                              self.collector_memory_usage.value, "envs:",
                              [emu.value
                               for emu in self.env_memory_usages], "\n")

                    train_train_times_cycle += train_train_time

                cycle_time = time.time() - start_cycle

                print("Cycle", cycle, "took: \n", cycle_time,
                      "\nAverage pure train: \n",
                      train_train_times_cycle / self.num_trains_per_train_loop,
                      "\nAverage sample time: \n",
                      sam_times_cycle / self.num_trains_per_train_loop,
                      "\nAverage full sample and train: \n",
                      cycle_time / self.num_trains_per_train_loop)

                gt.stamp('training', unique=False)
                self.training_mode(False)
                self.trainer._base_trainer.policy.vers += 1

            self._end_epoch(epoch)
            print("Seconds since start", time.time() - start_time)
Example #21
def simulate_policy(args):
    # import torch
    # torch.manual_seed(6199)
    if args.pause:
        import ipdb
        ipdb.set_trace()
    data = pickle.load(open(args.file, "rb"))
    policy = data['algorithm'].policy

    num_blocks = 6
    stack_only = True

    # env = data['env']
    env = gym.make(
        F"FetchBlockConstruction_{num_blocks}Blocks_IncrementalReward_DictstateObs_42Rendersize_{stack_only}Stackonly_AllCase-v1"
    )

    env = Monitor(env,
                  force=True,
                  directory="videos/",
                  video_callable=lambda x: x)

    print("Policy and environment loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    policy.train(False)
    failures = []
    successes = []
    for path_idx in range(100):
        path = multitask_rollout(
            env,
            policy,
            max_path_length=num_blocks * 50,
            animated=not args.hide,
            observation_key='observation',
            desired_goal_key='desired_goal',
            get_action_kwargs=dict(mask=np.ones((1, num_blocks)),
                                   deterministic=True),
        )

        if not is_solved(path, num_blocks):
            failures.append(path)
            print(F"Failed {path_idx}")
        else:
            print(F"Succeeded {path_idx}")
            successes.append(path)
        # if hasattr(env, "log_diagnostics"):
        #     env.log_diagnostics(paths)
        # if hasattr(env, "get_diagnostics"):
        #     for k, v in env.get_diagnostics(paths).items():
        #         logger.record_tabular(k, v)
        # logger.dump_tabular()
    print(f"Success rate {len(successes)/(len(successes) + len(failures))}")
    from rlkit.core.eval_util import get_generic_path_information
    path_info = get_generic_path_information(successes + failures,
                                             num_blocks=num_blocks)
    print(path_info)