Example 1

    def _get_diagnostics(self):
        algo_log = OrderedDict()
        append_log(algo_log,
                   self.replay_buffer.get_diagnostics(),
                   prefix='replay_buffer/')
        append_log(algo_log, self.trainer.get_diagnostics(), prefix='trainer/')
        # Exploration
        append_log(algo_log,
                   self.expl_data_collector.get_diagnostics(),
                   prefix='exploration/')
        expl_paths = self.expl_data_collector.get_epoch_paths()
        if hasattr(self.expl_env, 'get_diagnostics'):
            append_log(algo_log,
                       self.expl_env.get_diagnostics(expl_paths),
                       prefix='exploration/')
        append_log(algo_log,
                   eval_util.get_generic_path_information(expl_paths),
                   prefix="exploration/")
        # Eval
        append_log(algo_log,
                   self.eval_data_collector.get_diagnostics(),
                   prefix='evaluation/')
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, 'get_diagnostics'):
            append_log(algo_log,
                       self.eval_env.get_diagnostics(eval_paths),
                       prefix='evaluation/')
        append_log(algo_log,
                   eval_util.get_generic_path_information(eval_paths),
                   prefix="evaluation/")

        timer.stamp('logging')
        append_log(algo_log, _get_epoch_timings())
        algo_log['epoch'] = self.epoch
        return algo_log
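
Example 1 (and Example 8 below) merge per-component diagnostics into a single log via an `append_log` helper that is not shown in these snippets. The sketch below only captures the behavior the examples assume, namely copying entries into the target dict under an optional key prefix; the library's actual helper may differ.

def append_log(log_dict, to_add, prefix=None):
    # Minimal sketch (assumed behavior): copy every entry of ``to_add`` into
    # ``log_dict``, prepending ``prefix`` to each key when one is given.
    for key, value in to_add.items():
        if prefix is not None:
            key = prefix + key
        log_dict[key] = value
    return log_dict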
Example 2
    def evaluate(self, epoch):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        """
        statistics = OrderedDict()

        train_batch = self.get_batch()
        statistics.update(self._statistics_from_batch(train_batch, "Train"))

        logger.log("Collecting samples for evaluation")
        test_paths = self._sample_eval_paths()
        statistics.update(
            get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(self._statistics_from_paths(test_paths, "Test"))
        average_returns = get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns

        statistics['Epoch'] = epoch

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        self.env.log_diagnostics(test_paths)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
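
Example 2 records an 'AverageReturn' statistic computed by `get_average_returns`. Assuming each path stores a per-step `rewards` array, a minimal sketch of that computation is:

import numpy as np


def get_average_returns(paths):
    # Sketch (assumed behavior): mean undiscounted return across paths,
    # where a path's return is the sum of its per-step rewards.
    returns = [np.sum(path["rewards"]) for path in paths]
    return float(np.mean(returns))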
Example 3
    def _statistics_from_paths(self, paths, stat_prefix):
        eval_replay_buffer = UpdatableSubtrajReplayBuffer(
            len(paths) * (self.max_path_length + 1),
            self.env,
            self.subtraj_length,
            self.memory_dim,
        )
        for path in paths:
            eval_replay_buffer.add_trajectory(path)
        raw_subtraj_batch = eval_replay_buffer.get_all_valid_subtrajectories()
        assert raw_subtraj_batch is not None
        subtraj_batch = create_torch_subtraj_batch(raw_subtraj_batch)
        if self.save_memory_gradients:
            subtraj_batch['memories'].requires_grad = True
        statistics = self._statistics_from_subtraj_batch(
            subtraj_batch, stat_prefix=stat_prefix
        )
        statistics.update(eval_util.get_generic_path_information(
            paths, stat_prefix="Test",
        ))
        # Actions are stored as (T, action_dim + memory_dim) arrays, so the
        # env actions and memory writes are split along the second axis.
        env_actions = np.vstack([
            path["actions"][:, :self.action_dim] for path in paths
        ])
        writes = np.vstack([
            path["actions"][:, self.action_dim:] for path in paths
        ])
        statistics.update(create_stats_ordered_dict(
            'Env Actions', env_actions, stat_prefix=stat_prefix
        ))
        statistics.update(create_stats_ordered_dict(
            'Writes', writes, stat_prefix=stat_prefix
        ))
        return statistics
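
Example 3 summarizes the environment actions and memory writes with `create_stats_ordered_dict`. A rough sketch of such a helper, assuming it reports mean/std/max/min of the data under a prefixed name (the real utility may compute additional statistics), is:

from collections import OrderedDict

import numpy as np


def create_stats_ordered_dict(name, data, stat_prefix=None):
    # Sketch (assumed behavior): summarize ``data`` as Mean/Std/Max/Min
    # entries keyed by ``name``, with an optional prefix.
    if stat_prefix is not None:
        name = "{} {}".format(stat_prefix, name)
    data = np.asarray(data)
    return OrderedDict([
        (name + ' Mean', np.mean(data)),
        (name + ' Std', np.std(data)),
        (name + ' Max', np.max(data)),
        (name + ' Min', np.min(data)),
    ])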
Example 4
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        # if len(self._exploration_paths) > 0:
        #     statistics.update(eval_util.get_generic_path_information(
        #         self._exploration_paths, stat_prefix="Exploration",
        #     ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True
Example 5
def simulate_policy(args):
    data = pickle.load(open(args.file, "rb"))
    policy_key = args.policy_type + '/policy'
    if policy_key in data:
        policy = data[policy_key]
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    env_key = args.env_type + '/env'
    if env_key in data:
        env = data[env_key]
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    # robosuite-specific setup: enable the renderer and set the camera
    env._wrapped_env.has_renderer = True
    env.reset()
    env.viewer.set_camera(camera_id=0)

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.gpu:
        ptu.set_gpu_mode(True)
    if hasattr(policy, "to"):
        policy.to(ptu.device)
    if hasattr(env, "vae"):
        env.vae.to(ptu.device)

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(
            deprecated_rollout(
                env,
                policy,
                max_path_length=args.H,
                render=not args.hide,
            ))
        if args.log_diagnostics:
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics(paths, logger)
            for k, v in eval_util.get_generic_path_information(paths).items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
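
Example 5 reads several attributes from `args` (`file`, `policy_type`, `env_type`, `H`, `gpu`, `hide`, `pause`, `enable_render`, `log_diagnostics`) whose parser is not included in the snippet. A hypothetical parser covering those flags, with assumed defaults, might look like this:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('file', type=str, help='path to the snapshot file')
# The key prefixes depend on how the snapshot was saved; these defaults
# are assumptions, not values taken from the original script.
parser.add_argument('--policy_type', type=str, default='trainer')
parser.add_argument('--env_type', type=str, default='evaluation')
parser.add_argument('--H', type=int, default=300, help='max rollout length')
parser.add_argument('--gpu', action='store_true')
parser.add_argument('--hide', action='store_true')
parser.add_argument('--pause', action='store_true')
parser.add_argument('--enable_render', action='store_true')
parser.add_argument('--log_diagnostics', action='store_true')
args = parser.parse_args()

simulate_policy(args)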
Example 6
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        if len(self._exploration_paths) > 0:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        if hasattr(self.eval_policy, "log_diagnostics"):
            self.eval_policy.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.eval_policy, "get_diagnostics"):
            statistics.update(self.eval_policy.get_diagnostics(test_paths))

        process = psutil.Process(os.getpid())
        statistics['RAM Usage (Mb)'] = int(process.memory_info().rss / 1000000)

        statistics['Exploration Policy Noise'] = self._exploration_policy_noise

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True
Example 7
def plot_performance(policy, env, nrolls):
    print("max_tau, distance")
    # fixed_goals = [-40, -30, 30, 40]
    fixed_goals = [-5, -3, 3, 5]
    taus = np.arange(10) * 10
    for row, fix_tau in enumerate([True, False]):
        for col, horizon_fixed in enumerate([True, False]):
            plot_num = row + 2 * col + 1
            plt.subplot(2, 2, plot_num)
            for fixed_goal in fixed_goals:
                distances = []
                for max_tau in taus:
                    paths = []
                    for _ in range(nrolls):
                        goal = env.sample_goal_for_rollout()
                        goal[0] = fixed_goal
                        path = multitask_rollout(
                            env,
                            policy,
                            goal,
                            init_tau=max_tau,
                            max_path_length=100 if horizon_fixed else max_tau + 1,
                            animated=False,
                            cycle_tau=True,
                            decrement_tau=not fix_tau,
                        )
                        paths.append(path)
                    env.log_diagnostics(paths)
                    for key, value in get_generic_path_information(
                            paths).items():
                        logger.record_tabular(key, value)
                    distance = float(
                        dict(logger._tabular)['Final Distance to goal Mean'])
                    distances.append(distance)

                plt.plot(taus, distances)
                print("line done")
            plt.legend([str(goal) for goal in fixed_goals])
            if fix_tau:
                plt.xlabel("Tau (Horizon-1)")
            else:
                plt.xlabel("Initial tau (=Horizon-1)")
            plt.ylabel("Final distance to goal")
            plt.title("Fix Tau = {}, Horizon Fixed to 100  = {}".format(
                fix_tau,
                horizon_fixed,
            ))
    # Save before showing; with many backends the figure is cleared once
    # plt.show() returns, so saving afterwards writes an empty image.
    plt.savefig('results/iclr2018/cheetah-sweep-tau-eval-5-3.jpg')
    plt.show()
Example 8

    def _get_diagnostics(self):
        algo_log = super()._get_diagnostics()
        if self.vae_eval_data_collector is not None:
            append_log(algo_log,
                       self.vae_eval_data_collector.get_diagnostics(),
                       prefix='evaluation_vae/')
            vae_eval_paths = self.vae_eval_data_collector.get_epoch_paths()
            if hasattr(self.eval_env, 'get_diagnostics'):
                append_log(algo_log,
                           self.eval_env.get_diagnostics(vae_eval_paths),
                           prefix='evaluation_vae/')
            append_log(algo_log,
                       eval_util.get_generic_path_information(vae_eval_paths),
                       prefix="evaluation_vae/")
        return algo_log
Example 9

    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.mode:
        env.mode(args.mode)

    while True:
        paths = []
        for _ in range(args.nrolls):
            if args.silent:
                goal = None
            else:
                goal = env.sample_goal_for_rollout()
            path = multitask_rollout(
                env,
                policy,
                init_tau=max_tau,
                goal=goal,
                max_path_length=args.H,
                # animated=not args.hide,
                cycle_tau=args.cycle or not args.ndc,
                decrement_tau=args.dt or not args.ndc,
                env_samples_goal_on_reset=args.silent,
                # get_action_kwargs={'deterministic': True},
            )
            print("last state", path['next_observations'][-1][21:24])
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
Example 10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the snapshot file')
    parser.add_argument('--H',
                        type=int,
                        default=300,
                        help='Max length of rollout')
    parser.add_argument('--nrolls',
                        type=int,
                        default=1,
                        help='Number of rollout per eval')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--mtau', type=float, help='Max tau value')
    parser.add_argument('--grid', action='store_true')
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--load', action='store_true')
    parser.add_argument('--hide', action='store_true')
    parser.add_argument('--pause', action='store_true')
    parser.add_argument('--cycle', help='cycle tau', action='store_true')
    args = parser.parse_args()

    data = joblib.load(args.file)
    env = data['env']
    if 'policy' in data:
        policy = data['policy']
    else:
        policy = data['exploration_policy']
    qf = data['qf']
    policy.train(False)
    qf.train(False)

    if args.pause:
        import ipdb
        ipdb.set_trace()

    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)

    if args.mtau is None:
        print("Defaulting max tau to 10.")
        max_tau = 10
    else:
        max_tau = args.mtau

    while True:
        paths = []
        for _ in range(args.nrolls):
            goal = env.sample_goal_for_rollout()
            print("goal", goal)
            env.set_goal(goal)
            policy.set_goal(goal)
            policy.set_tau(max_tau)
            path = rollout(
                env,
                policy,
                qf,
                init_tau=max_tau,
                max_path_length=args.H,
                animated=not args.hide,
                cycle_tau=args.cycle,
            )
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
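
All of the examples funnel evaluation paths through `get_generic_path_information` and record the resulting key/value pairs with the logger. As a rough, simplified sketch of what such a function returns (assuming each path carries a `rewards` array; the library's version also reports path lengths, final rewards, and more):

from collections import OrderedDict

import numpy as np


def get_generic_path_information(paths, stat_prefix=''):
    # Simplified sketch: summarize returns and rewards across the given
    # paths, keyed with an optional prefix, matching how the examples
    # consume the result (a flat dict of scalar statistics).
    statistics = OrderedDict()
    returns = np.array([np.sum(path["rewards"]) for path in paths])
    rewards = np.concatenate([np.ravel(path["rewards"]) for path in paths])
    for name, data in (("Returns", returns), ("Rewards", rewards)):
        statistics[stat_prefix + name + " Mean"] = np.mean(data)
        statistics[stat_prefix + name + " Std"] = np.std(data)
        statistics[stat_prefix + name + " Max"] = np.max(data)
        statistics[stat_prefix + name + " Min"] = np.min(data)
    statistics[stat_prefix + "Num Paths"] = len(paths)
    return statistics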