Code Example #1
def simulate_policy(args):
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    paths = []
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        paths.append(path)
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        else:
            logger.record_dict(
                eval_util.get_generic_path_information(paths),
                prefix="evaluation/",
            )
        logger.dump_tabular()
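
Note: the examples in this collection all follow the same rlkit tabular-logging pattern: collect scalar diagnostics in a dict, record them with logger.record_dict (usually under a prefix such as 'trainer/' or 'evaluation/'), optionally record a few bare scalars with logger.record_tabular, then flush one row with logger.dump_tabular. Below is a minimal sketch of that pattern; the helper name log_epoch_sketch is illustrative and not part of rlkit.

from collections import OrderedDict

from rlkit.core import logger


def log_epoch_sketch(epoch, diagnostics):
    """Record one tabular row: a dict of scalars under a prefix plus an
    un-prefixed 'Epoch' column, then flush it to the CSV/console outputs."""
    stats = OrderedDict()
    stats.update(diagnostics)
    logger.record_dict(stats, prefix='trainer/')
    logger.record_tabular('Epoch', epoch)
    logger.dump_tabular(with_prefix=True, with_timestamp=False)
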
Code Example #2
def train_ae(ae_trainer,
             training_distrib,
             num_epochs=100,
             num_batches_per_epoch=500,
             batch_size=512,
             goal_key='image_desired_goal',
             rl_csv_fname='progress.csv'):
    from rlkit.core import logger

    logger.remove_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.add_tabular_output('ae_progress.csv', relative_to_snapshot_dir=True)

    for epoch in range(num_epochs):
        for batch_num in range(num_batches_per_epoch):
            goals = ptu.from_numpy(
                training_distrib.sample(batch_size)[goal_key])
            batch = dict(raw_next_observations=goals)
            ae_trainer.train_from_torch(batch)
        log = OrderedDict()
        log['epoch'] = epoch
        append_log(log, ae_trainer.eval_statistics, prefix='ae/')
        logger.record_dict(log)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        ae_trainer.end_epoch(epoch)

    logger.add_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.remove_tabular_output('ae_progress.csv',
                                 relative_to_snapshot_dir=True)
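
The train_ae example above (like Examples #7, #14, and #15 below) temporarily removes progress.csv from the tabular outputs, writes a phase-specific CSV instead, and restores progress.csv once the phase ends. A hedged sketch of that idea wrapped in a context manager follows; tabular_output_redirected is a hypothetical helper that only chains the remove_tabular_output/add_tabular_output calls shown above.

from contextlib import contextmanager

from rlkit.core import logger


@contextmanager
def tabular_output_redirected(csv_fname, restore_fname='progress.csv'):
    """Send tabular rows to csv_fname inside the with-block, then switch
    logging back to restore_fname (paths relative to the snapshot dir)."""
    logger.remove_tabular_output(restore_fname, relative_to_snapshot_dir=True)
    logger.add_tabular_output(csv_fname, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        logger.remove_tabular_output(csv_fname, relative_to_snapshot_dir=True)
        logger.add_tabular_output(restore_fname, relative_to_snapshot_dir=True)


# Usage sketch, mirroring train_ae:
#     with tabular_output_redirected('ae_progress.csv'):
#         ...train, then logger.record_dict(...) and logger.dump_tabular()...
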
Code Example #3
 def _log_stats(self, epoch):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
     """
     Evaluation
     """
     logger.record_dict(
         self.eval_data_collector.get_diagnostics(),
         prefix='evaluation/',
     )
     eval_paths = self.eval_data_collector.get_epoch_paths()
     if hasattr(self.eval_env, 'get_diagnostics'):
         logger.record_dict(
             self.eval_env.get_diagnostics(eval_paths),
             prefix='evaluation/',
         )
     logger.record_dict(
         eval_util.get_generic_path_information(eval_paths),
         prefix="evaluation/",
     )
     """
     Misc
     """
     gt.stamp('logging')
     logger.record_dict(_get_epoch_timings())
     logger.record_tabular('Epoch', epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
Code Example #4
 def run(self):
     if self.progress_csv_file_name != 'progress.csv':
         logger.remove_tabular_output('progress.csv',
                                      relative_to_snapshot_dir=True)
         logger.add_tabular_output(
             self.progress_csv_file_name,
             relative_to_snapshot_dir=True,
         )
     timer.return_global_times = True
     for _ in range(self.num_iters):
         self._begin_epoch()
         timer.start_timer('saving')
         logger.save_itr_params(self.epoch, self._get_snapshot())
         timer.stop_timer('saving')
         log_dict, _ = self._train()
         logger.record_dict(log_dict)
         logger.dump_tabular(with_prefix=True, with_timestamp=False)
         self._end_epoch()
     logger.save_itr_params(self.epoch, self._get_snapshot())
     if self.progress_csv_file_name != 'progress.csv':
         logger.remove_tabular_output(
             self.progress_csv_file_name,
             relative_to_snapshot_dir=True,
         )
         logger.add_tabular_output(
             'progress.csv',
             relative_to_snapshot_dir=True,
         )
Code Example #5
 def _log_vae_stats(self):
     logger.record_dict(
         self.vae_trainer_original.get_diagnostics(),
         prefix='vae_trainer_original/',
     )
     logger.record_dict(
         self.vae_trainer_segmented.get_diagnostics(),
         prefix='vae_trainer_segmented/',
     )
Code Example #6
 def train(self):
     timer.return_global_times = True
     for _ in range(self.num_epochs):
         self._begin_epoch()
         timer.start_timer('saving')
         logger.save_itr_params(self.epoch, self._get_snapshot())
         timer.stop_timer('saving')
         log_dict, _ = self._train()
         logger.record_dict(log_dict)
         logger.dump_tabular(with_prefix=True, with_timestamp=False)
         self._end_epoch()
     logger.save_itr_params(self.epoch, self._get_snapshot())
Code Example #7
    def pretrain_q_with_bc_data(self, batch_size):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                if self.do_pretrain_rollouts:
                    self.eval_statistics[
                        "pretrain_bc/avg_return"] = total_ret / 20
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
Code Example #8
 def train(self):
     # first train only the Q function
     iteration = 0
     for i in range(self.num_batches):
         train_data = self.replay_buffer.random_batch(self.batch_size)
         train_data = np_to_pytorch_batch(train_data)
         obs = train_data['observations']
         next_obs = train_data['next_observations']
         train_data['observations'] = obs
         train_data['next_observations'] = next_obs
         self.trainer.train_from_torch(train_data)
         if i % self.logging_period == 0:
             stats_with_prefix = add_prefix(
                 self.trainer.eval_statistics, prefix="trainer/")
             self.trainer.end_epoch(iteration)
             iteration += 1
             logger.record_dict(stats_with_prefix)
             logger.dump_tabular(with_prefix=True, with_timestamp=False)
Code Example #9
def simulate_policy(fpath,
                    env_name,
                    seed,
                    max_path_length,
                    num_eval_steps,
                    headless,
                    max_eps,
                    verbose=True,
                    pause=False):
    data = torch.load(fpath, map_location=ptu.device)
    policy = data['evaluation/policy']
    policy.to(ptu.device)
    # make a new env; reloading data['evaluation/env'] seems to cause a bug
    env = gym.make(env_name, **{"headless": headless, "verbose": False})
    env.seed(seed)
    if pause:
        input("Waiting to start.")
    path_collector = MdpPathCollector(env, policy)
    paths = path_collector.collect_new_paths(
        max_path_length,
        num_eval_steps,
        discard_incomplete_paths=True,
    )

    if max_eps:
        paths = paths[:max_eps]
    if verbose:
        completions = sum([
            info["completed"] for path in paths for info in path["env_infos"]
        ])
        print("Completed {} out of {}".format(completions, len(paths)))
        # plt.plot(paths[0]["actions"])
        # plt.show()
        # plt.plot(paths[2]["observations"])
        # plt.show()
        logger.record_dict(
            eval_util.get_generic_path_information(paths),
            prefix="evaluation/",
        )
        logger.dump_tabular()
    return paths
Code Example #10
    def pretrain_q_with_bc_data(self, batch_size):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
Code Example #11
    def train(self):
        # first train only the Q function
        iteration = 0
        timer.return_global_times = True
        timer.reset()
        for i in range(self.num_batches):
            if self.use_meta_learning_buffer:
                train_data = self.meta_replay_buffer.sample_meta_batch(
                    rl_batch_size=self.batch_size,
                    meta_batch_size=self.meta_batch_size,
                    embedding_batch_size=self.task_embedding_batch_size,
                )
                train_data = np_to_pytorch_batch(train_data)
            else:
                task_indices = np.random.choice(
                    self.train_tasks, self.meta_batch_size,
                )
                train_data = self.replay_buffer.sample_batch(
                    task_indices,
                    self.batch_size,
                )
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                train_data['observations'] = obs
                train_data['next_observations'] = next_obs
                train_data['context'] = (
                    self.task_embedding_replay_buffer.sample_context(
                        task_indices,
                        self.task_embedding_batch_size,
                    ))
            timer.start_timer('train', unique=False)
            self.trainer.train_from_torch(train_data)
            timer.stop_timer('train')
            if i % self.logging_period == 0 or i == self.num_batches - 1:
                stats_with_prefix = add_prefix(
                    self.trainer.eval_statistics, prefix="trainer/")
                self.trainer.end_epoch(iteration)
                logger.record_dict(stats_with_prefix)
                timer.start_timer('extra_fns', unique=False)
                for fn in self._extra_eval_fns:
                    extra_stats = fn()
                    logger.record_dict(extra_stats)
                timer.stop_timer('extra_fns')


                # TODO: evaluate during offline RL
                # eval_stats = self.get_eval_statistics()
                # eval_stats_with_prefix = add_prefix(eval_stats, prefix="eval/")
                # logger.record_dict(eval_stats_with_prefix)

                logger.record_tabular('iteration', iteration)
                logger.record_dict(_get_epoch_timings())
                try:
                    import os
                    import psutil
                    process = psutil.Process(os.getpid())
                    logger.record_tabular('RAM Usage (Mb)', int(process.memory_info().rss / 1000000))
                except ImportError:
                    pass
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                iteration += 1
Code Example #12
 def _log_stats(self, epoch, num_epochs_per_eval=0):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Replay Buffer
     """
     logger.record_dict(self.replay_buffer.get_diagnostics(),
                        prefix='replay_buffer/')
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
     """
     Exploration
     """
     logger.record_dict(self.expl_data_collector.get_diagnostics(),
                        prefix='exploration/')
     expl_paths = self.expl_data_collector.get_epoch_paths()
     if hasattr(self.expl_env, 'get_diagnostics'):
         logger.record_dict(
             self.expl_env.get_diagnostics(expl_paths),
             prefix='exploration/',
         )
     logger.record_dict(
         eval_util.get_generic_path_information(expl_paths),
         prefix="exploration/",
     )
     """
     Evaluation
     """
     try:
         # if epoch % num_epochs_per_eval == 0:
         logger.record_dict(
             self.eval_data_collector.get_diagnostics(),
             prefix='evaluation/',
         )
         eval_paths = self.eval_data_collector.get_epoch_paths()
         if hasattr(self.eval_env, 'get_diagnostics'):
             logger.record_dict(
                 self.eval_env.get_diagnostics(eval_paths),
                 prefix='evaluation/',
             )
         logger.record_dict(
             eval_util.get_generic_path_information(eval_paths),
             prefix="evaluation/",
         )
     except:
         pass
     """
     Misc
     """
     gt.stamp('logging')
     logger.record_dict(_get_epoch_timings())
     logger.record_tabular('Epoch', epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
Code Example #13
File: rl_algorithm.py Project: ngthanhtin/cog
 def _log_stats(self, epoch):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Replay Buffer
     """
     logger.record_dict(self.replay_buffer.get_diagnostics(),
                        prefix='replay_buffer/')
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
     """
     Exploration
     """
     logger.record_dict(self.expl_data_collector.get_diagnostics(),
                        prefix='exploration/')
     expl_paths = self.expl_data_collector.get_epoch_paths()
     # import ipdb; ipdb.set_trace()
     if hasattr(self.expl_env, 'get_diagnostics'):
         logger.record_dict(
             self.expl_env.get_diagnostics(expl_paths),
             prefix='exploration/',
         )
     if not self.batch_rl or self.eval_both:
         logger.record_dict(
             eval_util.get_generic_path_information(expl_paths),
             prefix="exploration/",
         )
     """
     Evaluation
     """
     logger.record_dict(
         self.eval_data_collector.get_diagnostics(),
         prefix='evaluation/',
     )
     eval_paths = self.eval_data_collector.get_epoch_paths()
     if hasattr(self.eval_env, 'get_diagnostics'):
         logger.record_dict(
             self.eval_env.get_diagnostics(eval_paths),
             prefix='evaluation/',
         )
     logger.record_dict(
         eval_util.get_generic_path_information(eval_paths),
         prefix="evaluation/",
     )
     """
     Misc
     """
     gt.stamp('logging')
     logger.record_dict(_get_epoch_timings())
     logger.record_tabular('Epoch', epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
Code Example #14
File: awac_trainer.py Project: DrawZeroPoint/RoRL
    def pretrain_q_with_bc_data(self):
        """

        :return:
        """
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)
            if i % self.pretraining_logging_period == 0:
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)
Code Example #15
File: awac_trainer.py Project: DrawZeroPoint/RoRL
    def pretrain_policy_with_bc(
        self,
        policy,
        train_buffer,
        test_buffer,
        steps,
        label="policy",
    ):
        """Given a policy, first get its optimizer, then run the policy on the train buffer, get the
        losses, and back propagate the loss. After training on a batch, test on the test buffer and
        get the statistics

        :param policy:
        :param train_buffer:
        :param test_buffer:
        :param steps:
        :param label:
        :return:
        """
        logger.remove_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )

        optimizer = self.optimizers[policy]
        prev_time = time.time()
        for i in range(steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_stats = self.run_bc_batch(
                train_buffer, policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            optimizer.zero_grad()
            train_policy_loss.backward()
            optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_stats = self.run_bc_batch(
                test_buffer, policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch":
                    i,
                    "pretrain_bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time":
                    time.time() - prev_time,
                }

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(
                    self.policy,
                    open(logger.get_snapshot_dir() + '/bc_%s.pkl' % label,
                         "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
Code Example #16
File: rl_algorithm.py Project: rstrudel/rlkit
 def _log_stats(self, epoch):
     logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
     """
     Replay Buffer
     """
     logger.record_dict(
         self.replay_buffer.get_diagnostics(),
         global_step=epoch,
         prefix="replay_buffer/",
     )
     """
     Trainer
     """
     logger.record_dict(self.trainer.get_diagnostics(),
                        global_step=epoch,
                        prefix="trainer/")
     """
     Exploration
     """
     logger.record_dict(
         self.expl_data_collector.get_diagnostics(),
         global_step=epoch,
         prefix="exploration/",
     )
     expl_paths = self.expl_data_collector.get_epoch_paths()
     if hasattr(self.expl_env, "get_diagnostics"):
         logger.record_dict(
             self.expl_env.get_diagnostics(expl_paths),
             global_step=epoch,
             prefix="exploration/",
         )
     logger.record_dict(
         eval_util.get_generic_path_information(expl_paths),
         global_step=epoch,
         prefix="exploration/",
     )
     """
     Evaluation
     """
     logger.record_dict(
         self.eval_data_collector.get_diagnostics(),
         global_step=epoch,
         prefix="evaluation/",
     )
     eval_paths = self.eval_data_collector.get_epoch_paths()
     if hasattr(self.eval_env, "get_diagnostics"):
         logger.record_dict(
             self.eval_env.get_diagnostics(eval_paths),
             global_step=epoch,
             prefix="evaluation/",
         )
     logger.record_dict(
         eval_util.get_generic_path_information(eval_paths),
         global_step=epoch,
         prefix="evaluation/",
     )
     """
     Misc
     """
     gt.stamp("logging")
     logger.record_dict(_get_epoch_timings(), global_step=epoch)
     logger.record_tabular("Epoch", epoch)
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
Code Example #17
File: kitchen_train.py Project: mihdalal/rad
    def run_eval_loop(sample_stochastically=True):
        start_time = time.time()
        prefix = "stochastic_" if sample_stochastically else ""
        for i in range(num_episodes):
            obs = env.reset()
            video.init(enabled=(i == 0))
            done = False
            episode_reward = 0
            ep_infos = []
            while not done:
                # center crop image
                if encoder_type == "pixel" and "crop" in data_augs:
                    obs = utils.center_crop_image(obs, image_size)
                if encoder_type == "pixel" and "translate" in data_augs:
                    # first crop the center with pre_transform_image_size
                    obs = utils.center_crop_image(obs, pre_transform_image_size)
                    # then translate cropped to center
                    obs = utils.center_translate(obs, image_size)
                with utils.eval_mode(agent):
                    if sample_stochastically:
                        action = agent.sample_action(obs / 255.0)
                    else:
                        action = agent.select_action(obs / 255.0)
                obs, reward, done, info = env.step(action)
                video.record(env)
                episode_reward += reward
                ep_infos.append(info)

            video.save("%d.mp4" % step)
            L.log("eval/" + prefix + "episode_reward", episode_reward, step)
            all_ep_rewards.append(episode_reward)
            all_infos.append(ep_infos)

        L.log("eval/" + prefix + "eval_time", time.time() - start_time, step)
        mean_ep_reward = np.mean(all_ep_rewards)
        best_ep_reward = np.max(all_ep_rewards)
        std_ep_reward = np.std(all_ep_rewards)
        L.log("eval/" + prefix + "mean_episode_reward", mean_ep_reward, step)
        L.log("eval/" + prefix + "best_episode_reward", best_ep_reward, step)
        rlkit_logger.record_dict(
            {"Average Returns": mean_ep_reward}, prefix="evaluation/"
        )
        statistics = compute_path_info(all_infos)
        rlkit_logger.record_dict(statistics, prefix="evaluation/")

        filename = (
            work_dir
            + "/"
            + env_name
            + "-"
            + data_augs
            + "--s"
            + str(seed)
            + "--eval_scores.npy"
        )
        key = env_name + "-" + data_augs
        try:
            log_data = np.load(filename, allow_pickle=True)
            log_data = log_data.item()
        except:
            log_data = {}

        if key not in log_data:
            log_data[key] = {}

        log_data[key][step] = {}
        log_data[key][step]["step"] = step
        log_data[key][step]["mean_ep_reward"] = mean_ep_reward
        log_data[key][step]["max_ep_reward"] = best_ep_reward
        log_data[key][step]["std_ep_reward"] = std_ep_reward
        log_data[key][step]["env_step"] = step * action_repeat

        np.save(filename, log_data)
Code Example #18
File: kitchen_train.py Project: mihdalal/rad
def experiment(variant):
    gym.logger.set_level(40)
    work_dir = rlkit_logger.get_snapshot_dir()
    args = parse_args()
    seed = int(variant["seed"])
    utils.set_seed_everywhere(seed)
    os.makedirs(work_dir, exist_ok=True)
    agent_kwargs = variant["agent_kwargs"]
    data_augs = agent_kwargs["data_augs"]
    encoder_type = agent_kwargs["encoder_type"]
    discrete_continuous_dist = agent_kwargs["discrete_continuous_dist"]

    env_suite = variant["env_suite"]
    env_name = variant["env_name"]
    env_kwargs = variant["env_kwargs"]
    pre_transform_image_size = variant["pre_transform_image_size"]
    image_size = variant["image_size"]
    frame_stack = variant["frame_stack"]
    batch_size = variant["batch_size"]
    replay_buffer_capacity = variant["replay_buffer_capacity"]
    num_train_steps = variant["num_train_steps"]
    num_eval_episodes = variant["num_eval_episodes"]
    eval_freq = variant["eval_freq"]
    action_repeat = variant["action_repeat"]
    init_steps = variant["init_steps"]
    log_interval = variant["log_interval"]
    use_raw_actions = variant["use_raw_actions"]
    pre_transform_image_size = (
        pre_transform_image_size if "crop" in data_augs else image_size
    )

    if data_augs == "crop":
        pre_transform_image_size = 100
    elif data_augs == "translate":
        pre_transform_image_size = 100
        image_size = 108

    if env_suite == 'kitchen':
        env_kwargs['imwidth'] = pre_transform_image_size
        env_kwargs['imheight'] = pre_transform_image_size
    else:
        env_kwargs['image_kwargs']['imwidth'] = pre_transform_image_size
        env_kwargs['image_kwargs']['imheight'] = pre_transform_image_size

    expl_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs)
    eval_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs)
    # stack several consecutive frames together
    if encoder_type == "pixel":
        expl_env = utils.FrameStack(expl_env, k=frame_stack)
        eval_env = utils.FrameStack(eval_env, k=frame_stack)

    # make directory
    ts = time.gmtime()
    ts = time.strftime("%m-%d", ts)
    exp_name = (
        env_name
        + "-"
        + ts
        + "-im"
        + str(image_size)
        + "-b"
        + str(batch_size)
        + "-s"
        + str(seed)
        + "-"
        + encoder_type
    )
    work_dir = work_dir + "/" + exp_name

    utils.make_dir(work_dir)
    video_dir = utils.make_dir(os.path.join(work_dir, "video"))
    model_dir = utils.make_dir(os.path.join(work_dir, "model"))
    buffer_dir = utils.make_dir(os.path.join(work_dir, "buffer"))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(work_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if use_raw_actions:
        continuous_action_dim = expl_env.action_space.low.size
        discrete_action_dim = 0
    else:
        num_primitives = expl_env.num_primitives
        max_arg_len = expl_env.max_arg_len
        if discrete_continuous_dist:
            continuous_action_dim = max_arg_len
            discrete_action_dim = num_primitives
        else:
            continuous_action_dim = max_arg_len + num_primitives
            discrete_action_dim = 0

    if encoder_type == "pixel":
        obs_shape = (3 * frame_stack, image_size, image_size)
        pre_aug_obs_shape = (
            3 * frame_stack,
            pre_transform_image_size,
            pre_transform_image_size,
        )
    else:
        obs_shape = expl_env.observation_space.shape
        pre_aug_obs_shape = obs_shape

    replay_buffer = utils.ReplayBuffer(
        obs_shape=pre_aug_obs_shape,
        action_size=continuous_action_dim + discrete_action_dim,
        capacity=replay_buffer_capacity,
        batch_size=batch_size,
        device=device,
        image_size=image_size,
        pre_image_size=pre_transform_image_size,
    )

    agent = make_agent(
        obs_shape=obs_shape,
        continuous_action_dim=continuous_action_dim,
        discrete_action_dim=discrete_action_dim,
        args=args,
        device=device,
        agent_kwargs=agent_kwargs,
    )

    L = Logger(work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    epoch_start_time = time.time()
    train_expl_st = time.time()
    total_train_expl_time = 0
    all_infos = []
    ep_infos = []
    num_train_calls = 0
    for step in range(num_train_steps):
        # evaluate agent periodically

        if step % eval_freq == 0:
            total_train_expl_time += time.time() - train_expl_st
            L.log("eval/episode", episode, step)
            evaluate(
                eval_env,
                agent,
                video,
                num_eval_episodes,
                L,
                step,
                encoder_type,
                data_augs,
                image_size,
                pre_transform_image_size,
                env_name,
                action_repeat,
                work_dir,
                seed,
            )
            if args.save_model:
                agent.save_curl(model_dir, step)
            if args.save_buffer:
                replay_buffer.save(buffer_dir)
            train_expl_st = time.time()
        if done:
            if step > 0:
                if step % log_interval == 0:
                    L.log("train/duration", time.time() - epoch_start_time, step)
                    L.dump(step)
            if step % log_interval == 0:
                L.log("train/episode_reward", episode_reward, step)
            obs = expl_env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            if step % log_interval == 0:
                all_infos.append(ep_infos)

                L.log("train/episode", episode, step)
                statistics = compute_path_info(all_infos)

                rlkit_logger.record_dict(statistics, prefix="exploration/")
                rlkit_logger.record_tabular(
                    "time/epoch (s)", time.time() - epoch_start_time
                )
                rlkit_logger.record_tabular("time/total (s)", time.time() - start_time)
                rlkit_logger.record_tabular("time/training and exploration (s)", total_train_expl_time)
                rlkit_logger.record_tabular("trainer/num train calls", num_train_calls)
                rlkit_logger.record_tabular("exploration/num steps total", step)
                rlkit_logger.record_tabular("Epoch", step // log_interval)
                rlkit_logger.dump_tabular(with_prefix=False, with_timestamp=False)
                all_infos = []
                epoch_start_time = time.time()
            ep_infos = []


        # sample action for data collection
        if step < init_steps:
            action = expl_env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs / 255.0)

        # run training update
        if step >= init_steps:
            num_updates = 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)
                num_train_calls += 1

        next_obs, reward, done, info = expl_env.step(action)
        ep_infos.append(info)
        # allow infinite bootstrap
        done_bool = (
            0 if episode_step + 1 == expl_env._max_episode_steps else float(done)
        )
        episode_reward += reward
        replay_buffer.add(obs, action, reward, next_obs, done_bool)

        obs = next_obs
        episode_step += 1
Code Example #19
    def _log_stats(self, epoch):
        logger.log(f"Epoch {epoch} finished", with_timestamp=True)

        """
        Replay Buffer
        """
        logger.record_dict(
            self.replay_buffer.get_diagnostics(), prefix="replay_buffer/"
        )

        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix="trainer/")

        """
        Exploration
        """
        logger.record_dict(
            self.expl_data_collector.get_diagnostics(), prefix="exploration/"
        )

        expl_paths = self.expl_data_collector.get_epoch_paths()
        if len(expl_paths) > 0:
            if hasattr(self.expl_env, "get_diagnostics"):
                logger.record_dict(
                    self.expl_env.get_diagnostics(expl_paths),
                    prefix="exploration/",
                )

            logger.record_dict(
                eval_util.get_generic_path_information(expl_paths),
                prefix="exploration/",
            )

        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix="evaluation/",
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, "get_diagnostics"):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix="evaluation/",
            )

        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )

        """
        Misc
        """
        gt.stamp("logging")
        timings = _get_epoch_timings()
        timings["time/training and exploration (s)"] = self.total_train_expl_time
        logger.record_dict(timings)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Code Example #20
    def pretrain_policy_with_bc(self):
        if self.buffer_for_bc_training == "demos":
            self.bc_training_buffer = self.demo_train_buffer
            self.bc_test_buffer = self.demo_test_buffer
        elif self.buffer_for_bc_training == "replay_buffer":
            self.bc_training_buffer = self.replay_buffer.train_replay_buffer
            self.bc_test_buffer = self.replay_buffer.validation_replay_buffer
        else:
            self.bc_training_buffer = None
            self.bc_test_buffer = None

        if self.load_policy_path:
            self.policy = load_local_or_remote_file(self.load_policy_path)
            ptu.copy_model_params_from_to(self.policy, self.target_policy)
            return

        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_policy.csv',
                                  relative_to_snapshot_dir=True)
        if self.do_pretrain_rollouts:
            total_ret = self.do_rollouts()
            print("INITIAL RETURN", total_ret / 20)

        prev_time = time.time()
        for i in range(self.bc_num_pretrain_steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            self.policy_optimizer.zero_grad()
            train_policy_loss.backward()
            self.policy_optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(
                self.demo_test_buffer, self.policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch":
                    i,
                    "pretrain_bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time":
                    time.time() - prev_time,
                }

                if self.do_pretrain_rollouts:
                    stats["pretrain_bc/avg_return"] = total_ret / 20

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(self.policy,
                            open(logger.get_snapshot_dir() + '/bc.pkl', "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        ptu.copy_model_params_from_to(self.policy, self.target_policy)

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
Code Example #21
    def _try_to_eval(self, epoch):
        if epoch % self.logging_period != 0:
            return
        if epoch in self.save_extra_manual_epoch_set:
            logger.save_extra_data(
                self.get_extra_data_to_save(epoch),
                file_name='extra_snapshot_itr{}'.format(epoch),
                mode='cloudpickle',
            )
        if self._save_extra_every_epoch:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
        gt.stamp('save-extra')
        if self._can_evaluate():
            self.evaluate(epoch)
            gt.stamp('eval')

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            gt.stamp('save-snapshot')
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_dict(
                self.trainer.get_diagnostics(),
                prefix='trainer/',
            )

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            save_extra_time = times_itrs['save-extra'][-1]
            save_snapshot_time = times_itrs['save-snapshot'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + save_extra_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('in_unsupervised_model',
                                  float(self.in_unsupervised_phase))
            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Save Extra Time (s)', save_extra_time)
            logger.record_tabular('Save Snapshot Time (s)', save_snapshot_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Code Example #22
def experiment(variant):
    import numpy as np
    import torch
    from torch import nn, optim
    from tqdm import tqdm

    import rlkit.torch.pytorch_util as ptu
    from rlkit.core import logger
    from rlkit.envs.primitives_make_env import make_env
    from rlkit.torch.model_based.dreamer.mlp import Mlp, MlpResidual
    from rlkit.torch.model_based.dreamer.train_world_model import (
        compute_world_model_loss,
        get_dataloader,
        get_dataloader_rt,
        get_dataloader_separately,
        update_network,
        visualize_rollout,
        world_model_loss_rt,
    )
    from rlkit.torch.model_based.dreamer.world_models import (
        LowlevelRAPSWorldModel,
        WorldModel,
    )

    env_suite, env_name, env_kwargs = (
        variant["env_suite"],
        variant["env_name"],
        variant["env_kwargs"],
    )
    max_path_length = variant["env_kwargs"]["max_path_length"]
    low_level_primitives = variant["low_level_primitives"]
    num_low_level_actions_per_primitive = variant[
        "num_low_level_actions_per_primitive"]
    low_level_action_dim = variant["low_level_action_dim"]
    dataloader_kwargs = variant["dataloader_kwargs"]
    env = make_env(env_suite, env_name, env_kwargs)
    world_model_kwargs = variant["model_kwargs"]
    optimizer_kwargs = variant["optimizer_kwargs"]
    gradient_clip = variant["gradient_clip"]
    if low_level_primitives:
        world_model_kwargs["action_dim"] = low_level_action_dim
    else:
        world_model_kwargs["action_dim"] = env.action_space.low.shape[0]
    image_shape = env.image_shape
    world_model_kwargs["image_shape"] = image_shape
    scaler = torch.cuda.amp.GradScaler()
    world_model_loss_kwargs = variant["world_model_loss_kwargs"]
    clone_primitives = variant["clone_primitives"]
    clone_primitives_separately = variant["clone_primitives_separately"]
    clone_primitives_and_train_world_model = variant.get(
        "clone_primitives_and_train_world_model", False)
    batch_len = variant.get("batch_len", 100)
    num_epochs = variant["num_epochs"]
    loss_to_use = variant.get("loss_to_use", "both")

    logdir = logger.get_snapshot_dir()

    if clone_primitives_separately:
        (
            train_dataloaders,
            test_dataloaders,
            train_datasets,
            test_datasets,
        ) = get_dataloader_separately(
            variant["datafile"],
            num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
            num_primitives=env.num_primitives,
            env=env,
            **dataloader_kwargs,
        )
    elif clone_primitives_and_train_world_model:
        print("LOADING DATA")
        (
            train_dataloader,
            test_dataloader,
            train_dataset,
            test_dataset,
        ) = get_dataloader_rt(
            variant["datafile"],
            max_path_length=max_path_length *
            num_low_level_actions_per_primitive + 1,
            **dataloader_kwargs,
        )
    elif low_level_primitives or clone_primitives:
        print("LOADING DATA")
        (
            train_dataloader,
            test_dataloader,
            train_dataset,
            test_dataset,
        ) = get_dataloader(
            variant["datafile"],
            max_path_length=max_path_length *
            num_low_level_actions_per_primitive + 1,
            **dataloader_kwargs,
        )
    else:
        train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataloader(
            variant["datafile"],
            max_path_length=max_path_length + 1,
            **dataloader_kwargs,
        )

    if clone_primitives_and_train_world_model:
        if variant["mlp_act"] == "elu":
            mlp_act = nn.functional.elu
        elif variant["mlp_act"] == "relu":
            mlp_act = nn.functional.relu
        if variant["mlp_res"]:
            mlp_class = MlpResidual
        else:
            mlp_class = Mlp
        criterion = nn.MSELoss()
        primitive_model = mlp_class(
            hidden_sizes=variant["mlp_hidden_sizes"],
            output_size=low_level_action_dim,
            input_size=250 + env.action_space.low.shape[0] + 1,
            hidden_activation=mlp_act,
        ).to(ptu.device)
        world_model_class = LowlevelRAPSWorldModel
        world_model = world_model_class(
            primitive_model=primitive_model,
            **world_model_kwargs,
        ).to(ptu.device)
        optimizer = optim.Adam(
            world_model.parameters(),
            **optimizer_kwargs,
        )
        best_test_loss = np.inf
        for i in tqdm(range(num_epochs)):
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_primitive_loss = 0
            total_world_model_loss = 0
            total_div_loss = 0
            total_image_pred_loss = 0
            total_transition_loss = 0
            total_entropy_loss = 0
            total_pred_discount_loss = 0
            total_reward_pred_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    (
                        high_level_actions,
                        obs,
                        rewards,
                        terminals,
                    ), low_level_actions = data
                    obs = obs.to(ptu.device).float()
                    low_level_actions = low_level_actions.to(
                        ptu.device).float()
                    high_level_actions = high_level_actions.to(
                        ptu.device).float()
                    rewards = rewards.to(ptu.device).float()
                    terminals = terminals.to(ptu.device).float()
                    assert all(terminals[:, -1] == 1)
                    rt_idxs = np.arange(
                        num_low_level_actions_per_primitive,
                        obs.shape[1],
                        num_low_level_actions_per_primitive,
                    )
                    rt_idxs = np.concatenate(
                        [[0], rt_idxs]
                    )  # reset obs, then the effect of the first primitive, the second, and so on

                    batch_start = np.random.randint(0,
                                                    obs.shape[1] - batch_len,
                                                    size=(obs.shape[0]))
                    batch_indices = np.linspace(
                        batch_start,
                        batch_start + batch_len,
                        batch_len,
                        endpoint=False,
                    ).astype(int)
                    (
                        post,
                        prior,
                        post_dist,
                        prior_dist,
                        image_dist,
                        reward_dist,
                        pred_discount_dist,
                        _,
                        action_preds,
                    ) = world_model(
                        obs,
                        (high_level_actions, low_level_actions),
                        use_network_action=False,
                        batch_indices=batch_indices,
                        rt_idxs=rt_idxs,
                    )
                    obs = world_model.flatten_obs(
                        obs[np.arange(batch_indices.shape[1]),
                            batch_indices].permute(1, 0, 2),
                        (int(np.prod(image_shape)), ),
                    )
                    rewards = rewards.reshape(-1, rewards.shape[-1])
                    terminals = terminals.reshape(-1, terminals.shape[-1])
                    (
                        world_model_loss,
                        div,
                        image_pred_loss,
                        reward_pred_loss,
                        transition_loss,
                        entropy_loss,
                        pred_discount_loss,
                    ) = world_model_loss_rt(
                        world_model,
                        image_shape,
                        image_dist,
                        reward_dist,
                        {
                            key: value[np.arange(batch_indices.shape[1]),
                                       batch_indices].permute(1, 0, 2).reshape(
                                           -1, value.shape[-1])
                            for key, value in prior.items()
                        },
                        {
                            key: value[np.arange(batch_indices.shape[1]),
                                       batch_indices].permute(1, 0, 2).reshape(
                                           -1, value.shape[-1])
                            for key, value in post.items()
                        },
                        prior_dist,
                        post_dist,
                        pred_discount_dist,
                        obs,
                        rewards,
                        terminals,
                        **world_model_loss_kwargs,
                    )

                    batch_start = np.random.randint(
                        0,
                        low_level_actions.shape[1] - batch_len,
                        size=(low_level_actions.shape[0]),
                    )
                    batch_indices = np.linspace(
                        batch_start,
                        batch_start + batch_len,
                        batch_len,
                        endpoint=False,
                    ).astype(int)
                    primitive_loss = criterion(
                        action_preds[np.arange(batch_indices.shape[1]),
                                     batch_indices].permute(1, 0, 2).reshape(
                                         -1, action_preds.shape[-1]),
                        low_level_actions[:, 1:]
                        [np.arange(batch_indices.shape[1]),
                         batch_indices].permute(1, 0, 2).reshape(
                             -1, action_preds.shape[-1]),
                    )
                    total_primitive_loss += primitive_loss.item()
                    total_world_model_loss += world_model_loss.item()
                    total_div_loss += div.item()
                    total_image_pred_loss += image_pred_loss.item()
                    total_transition_loss += transition_loss.item()
                    total_entropy_loss += entropy_loss.item()
                    total_pred_discount_loss += pred_discount_loss.item()
                    total_reward_pred_loss += reward_pred_loss.item()

                    if loss_to_use == "wm":
                        loss = world_model_loss
                    elif loss_to_use == "primitive":
                        loss = primitive_loss
                    else:
                        loss = world_model_loss + primitive_loss
                    total_train_steps += 1

                update_network(world_model, optimizer, loss, gradient_clip,
                               scaler)
                scaler.update()
            eval_statistics["train/primitive_loss"] = (total_primitive_loss /
                                                       total_train_steps)
            eval_statistics["train/world_model_loss"] = (
                total_world_model_loss / total_train_steps)
            eval_statistics["train/image_pred_loss"] = (total_image_pred_loss /
                                                        total_train_steps)
            eval_statistics["train/transition_loss"] = (total_transition_loss /
                                                        total_train_steps)
            eval_statistics["train/entropy_loss"] = (total_entropy_loss /
                                                     total_train_steps)
            eval_statistics["train/pred_discount_loss"] = (
                total_pred_discount_loss / total_train_steps)
            eval_statistics["train/reward_pred_loss"] = (
                total_reward_pred_loss / total_train_steps)
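            # Stash the current weights; visualization below temporarily
            # loads the best checkpoint and these are restored afterwards.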
            latest_state_dict = world_model.state_dict().copy()
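            # Evaluate on the held-out set without tracking gradients.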
            with torch.no_grad():
                total_primitive_loss = 0
                total_world_model_loss = 0
                total_div_loss = 0
                total_image_pred_loss = 0
                total_transition_loss = 0
                total_entropy_loss = 0
                total_pred_discount_loss = 0
                total_reward_pred_loss = 0
                total_loss = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        (
                            high_level_actions,
                            obs,
                            rewards,
                            terminals,
                        ), low_level_actions = data
                        obs = obs.to(ptu.device).float()
                        low_level_actions = low_level_actions.to(
                            ptu.device).float()
                        high_level_actions = high_level_actions.to(
                            ptu.device).float()
                        rewards = rewards.to(ptu.device).float()
                        terminals = terminals.to(ptu.device).float()
                        assert all(terminals[:, -1] == 1)
                        rt_idxs = np.arange(
                            num_low_level_actions_per_primitive,
                            obs.shape[1],
                            num_low_level_actions_per_primitive,
                        )
                        rt_idxs = np.concatenate(
                            [[0], rt_idxs]
                        )  # reset obs, effect of first primitive, second primitive, so on

                        batch_start = np.random.randint(0,
                                                        obs.shape[1] -
                                                        batch_len,
                                                        size=(obs.shape[0]))
                        batch_indices = np.linspace(
                            batch_start,
                            batch_start + batch_len,
                            batch_len,
                            endpoint=False,
                        ).astype(int)
                        (
                            post,
                            prior,
                            post_dist,
                            prior_dist,
                            image_dist,
                            reward_dist,
                            pred_discount_dist,
                            _,
                            action_preds,
                        ) = world_model(
                            obs,
                            (high_level_actions, low_level_actions),
                            use_network_action=False,
                            batch_indices=batch_indices,
                            rt_idxs=rt_idxs,
                        )
                        obs = world_model.flatten_obs(
                            obs[np.arange(batch_indices.shape[1]),
                                batch_indices].permute(1, 0, 2),
                            (int(np.prod(image_shape)), ),
                        )
                        rewards = rewards.reshape(-1, rewards.shape[-1])
                        terminals = terminals.reshape(-1, terminals.shape[-1])
                        (
                            world_model_loss,
                            div,
                            image_pred_loss,
                            reward_pred_loss,
                            transition_loss,
                            entropy_loss,
                            pred_discount_loss,
                        ) = world_model_loss_rt(
                            world_model,
                            image_shape,
                            image_dist,
                            reward_dist,
                            {
                                key: value[np.arange(batch_indices.shape[1]),
                                           batch_indices].permute(
                                               1, 0, 2).reshape(
                                                   -1, value.shape[-1])
                                for key, value in prior.items()
                            },
                            {
                                key: value[np.arange(batch_indices.shape[1]),
                                           batch_indices].permute(
                                               1, 0, 2).reshape(
                                                   -1, value.shape[-1])
                                for key, value in post.items()
                            },
                            prior_dist,
                            post_dist,
                            pred_discount_dist,
                            obs,
                            rewards,
                            terminals,
                            **world_model_loss_kwargs,
                        )

                        batch_start = np.random.randint(
                            0,
                            low_level_actions.shape[1] - batch_len,
                            size=(low_level_actions.shape[0]),
                        )
                        batch_indices = np.linspace(
                            batch_start,
                            batch_start + batch_len,
                            batch_len,
                            endpoint=False,
                        ).astype(int)
                        primitive_loss = criterion(
                            action_preds[np.arange(batch_indices.shape[1]),
                                         batch_indices].permute(
                                             1, 0, 2).reshape(
                                                 -1, action_preds.shape[-1]),
                            low_level_actions[:, 1:]
                            [np.arange(batch_indices.shape[1]),
                             batch_indices].permute(1, 0, 2).reshape(
                                 -1, action_preds.shape[-1]),
                        )
                        total_primitive_loss += primitive_loss.item()
                        total_world_model_loss += world_model_loss.item()
                        total_div_loss += div.item()
                        total_image_pred_loss += image_pred_loss.item()
                        total_transition_loss += transition_loss.item()
                        total_entropy_loss += entropy_loss.item()
                        total_pred_discount_loss += pred_discount_loss.item()
                        total_reward_pred_loss += reward_pred_loss.item()
                        total_loss += (world_model_loss.item() +
                                       primitive_loss.item())
                        total_test_steps += 1
                eval_statistics["test/primitive_loss"] = (
                    total_primitive_loss / total_test_steps)
                eval_statistics["test/world_model_loss"] = (
                    total_world_model_loss / total_test_steps)
                eval_statistics["test/image_pred_loss"] = (
                    total_image_pred_loss / total_test_steps)
                eval_statistics["test/transition_loss"] = (
                    total_transition_loss / total_test_steps)
                eval_statistics["test/entropy_loss"] = (total_entropy_loss /
                                                        total_test_steps)
                eval_statistics["test/pred_discount_loss"] = (
                    total_pred_discount_loss / total_test_steps)
                eval_statistics["test/reward_pred_loss"] = (
                    total_reward_pred_loss / total_test_steps)
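                # Checkpoint the world model whenever the combined held-out
                # loss (world model + primitive) improves.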
                if (total_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    best_wm_state_dict = world_model.state_dict().copy()
                    torch.save(
                        best_wm_state_dict,
                        logdir + "/models/world_model.pt",
                    )
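                # Periodically visualize with the best checkpoint so far,
                # then restore the latest weights to continue training.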
                if i % variant["plotting_period"] == 0:
                    print("Best test loss", best_test_loss)
                    world_model.load_state_dict(best_wm_state_dict)
                    visualize_wm(
                        env,
                        world_model,
                        train_dataset.outputs,
                        train_dataset.inputs[1],
                        test_dataset.outputs,
                        test_dataset.inputs[1],
                        logdir,
                        max_path_length,
                        low_level_primitives,
                        num_low_level_actions_per_primitive,
                        primitive_model=primitive_model,
                    )
                    world_model.load_state_dict(latest_state_dict)
                logger.record_dict(eval_statistics, prefix="")
                logger.dump_tabular(with_prefix=False, with_timestamp=False)

    elif clone_primitives_separately:
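        # Behavior-clone one MLP per environment primitive on top of the
        # pre-trained world model; only the primitive MLPs are optimized.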
        world_model.load_state_dict(torch.load(variant["world_model_path"]))
        criterion = nn.MSELoss()
        primitives = []
        for p in range(env.num_primitives):
            arguments_size = train_datasets[p].inputs[0].shape[-1]
            m = Mlp(
                hidden_sizes=variant["mlp_hidden_sizes"],
                output_size=low_level_action_dim,
                input_size=world_model.feature_size + arguments_size,
                hidden_activation=torch.nn.functional.relu,
            ).to(ptu.device)
            if variant.get("primitives_path", None):
                m.load_state_dict(
                    torch.load(variant["primitives_path"] +
                               "primitive_model_{}.pt".format(p)))
            primitives.append(m)

        optimizers = [
            optim.Adam(p.parameters(), **optimizer_kwargs) for p in primitives
        ]
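        # Track the best held-out loss for each primitive across epochs.
        best_test_losses = [np.inf for _ in primitives]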
        for i in tqdm(range(num_epochs)):
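            # Periodically roll out the current primitives in the real env
            # under different forcing modes (none / teacher / self) for
            # qualitative inspection.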
            if i % variant["plotting_period"] == 0:
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="none",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=
                    num_low_level_actions_per_primitive,
                    primitive_model=primitives,
                    use_separate_primitives=True,
                )
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="teacher",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=
                    num_low_level_actions_per_primitive,
                    primitive_model=primitives,
                    use_separate_primitives=True,
                )
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="self",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=
                    num_low_level_actions_per_primitive,
                    primitive_model=primitives,
                    use_separate_primitives=True,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            for p, (
                    train_dataloader,
                    test_dataloader,
                    primitive_model,
                    optimizer,
            ) in enumerate(
                    zip(train_dataloaders, test_dataloaders, primitives,
                        optimizers)):
                total_loss = 0
                total_train_steps = 0
                for data in train_dataloader:
                    with torch.cuda.amp.autocast():
                        (arguments, obs), actions = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        arguments = arguments.to(ptu.device).float()
                        action_preds = world_model(
                            obs,
                            (arguments, actions),
                            primitive_model,
                            use_network_action=False,
                        )[-1]
                        loss = criterion(action_preds, actions)
                        total_loss += loss.item()
                        total_train_steps += 1

                    update_network(primitive_model, optimizer, loss,
                                   gradient_clip, scaler)
                    scaler.update()
                eval_statistics["train/primitive_loss {}".format(p)] = (
                    total_loss / total_train_steps)
                with torch.no_grad():
                    total_loss = 0
                    total_test_steps = 0
                    for data in test_dataloader:
                        with torch.cuda.amp.autocast():
                            (high_level_actions, obs), actions = data
                            obs = obs.to(ptu.device).float()
                            actions = actions.to(ptu.device).float()
                            high_level_actions = high_level_actions.to(
                                ptu.device).float()
                            action_preds = world_model(
                                obs,
                                (high_level_actions, actions),
                                primitive_model,
                                use_network_action=False,
                            )[-1]
                            loss = criterion(action_preds, actions)
                            total_loss += loss.item()
                            total_test_steps += 1
                    eval_statistics["test/primitive_loss {}".format(p)] = (
                        total_loss / total_test_steps)
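                    # Save this primitive's model whenever its held-out loss
                    # improves.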
                    if total_loss / total_test_steps <= best_test_losses[p]:
                        best_test_losses[p] = total_loss / total_test_steps
                        os.makedirs(logdir + "/models/", exist_ok=True)
                        torch.save(
                            primitive_model.state_dict(),
                            logdir + "/models/primitive_model_{}.pt".format(p),
                        )
            logger.record_dict(eval_statistics, prefix="")
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        visualize_rollout(
            env,
            None,
            None,
            world_model,
            logdir,
            max_path_length,
            use_env=True,
            forcing="none",
            tag="none",
            low_level_primitives=low_level_primitives,
            num_low_level_actions_per_primitive=
            num_low_level_actions_per_primitive,
            primitive_model=primitives,
            use_separate_primitives=True,
        )

    elif clone_primitives:
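        # Behavior-clone a single shared primitive MLP on top of the
        # pre-trained world model; its input is the world-model feature
        # concatenated with the high-level action (plus one extra scalar,
        # presumably a phase/progress input).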
        world_model.load_state_dict(torch.load(variant["world_model_path"]))
        criterion = nn.MSELoss()
        primitive_model = Mlp(
            hidden_sizes=variant["mlp_hidden_sizes"],
            output_size=low_level_action_dim,
            input_size=world_model.feature_size +
            env.action_space.low.shape[0] + 1,
            hidden_activation=torch.nn.functional.relu,
        ).to(ptu.device)
        optimizer = optim.Adam(
            primitive_model.parameters(),
            **optimizer_kwargs,
        )
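        # Best held-out loss seen so far; used to decide when to checkpoint.
        best_test_loss = np.inf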
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="none",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=
                    num_low_level_actions_per_primitive,
                    primitive_model=primitive_model,
                )
                visualize_rollout(
                    env,
                    train_dataset.outputs,
                    train_dataset.inputs[1],
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=False,
                    forcing="teacher",
                    tag="train",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=
                    num_low_level_actions_per_primitive - 1,
                )
                visualize_rollout(
                    env,
                    test_dataset.outputs,
                    test_dataset.inputs[1],
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=False,
                    forcing="teacher",
                    tag="test",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=
                    num_low_level_actions_per_primitive - 1,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    (high_level_actions, obs), actions = data
                    obs = obs.to(ptu.device).float()
                    actions = actions.to(ptu.device).float()
                    high_level_actions = high_level_actions.to(
                        ptu.device).float()
                    action_preds = world_model(
                        obs,
                        (high_level_actions, actions),
                        primitive_model,
                        use_network_action=False,
                    )[-1]
                    loss = criterion(action_preds, actions)
                    total_loss += loss.item()
                    total_train_steps += 1

                update_network(primitive_model, optimizer, loss, gradient_clip,
                               scaler)
                scaler.update()
            eval_statistics[
                "train/primitive_loss"] = total_loss / total_train_steps
            with torch.no_grad():
                total_loss = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        (high_level_actions, obs), actions = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        high_level_actions = high_level_actions.to(
                            ptu.device).float()
                        action_preds = world_model(
                            obs,
                            (high_level_actions, actions),
                            primitive_model,
                            use_network_action=False,
                        )[-1]
                        loss = criterion(action_preds, actions)
                        total_loss += loss.item()
                        total_test_steps += 1
                eval_statistics[
                    "test/primitive_loss"] = total_loss / total_test_steps
                if (total_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    torch.save(
                        primitive_model.state_dict(),
                        logdir + "/models/primitive_model.pt",
                    )
                logger.record_dict(eval_statistics, prefix="")
                logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
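        # Plain world-model training from scratch on (action, observation)
        # sequences; no primitive cloning.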
        world_model = WorldModel(**world_model_kwargs).to(ptu.device)
        optimizer = optim.Adam(
            world_model.parameters(),
            **optimizer_kwargs,
        )
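        # Best held-out world-model loss seen so far; used to decide when
        # to checkpoint.
        best_test_loss = np.inf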
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                visualize_wm(
                    env,
                    world_model,
                    train_dataset.inputs,
                    train_dataset.outputs,
                    test_dataset.inputs,
                    test_dataset.outputs,
                    logdir,
                    max_path_length,
                    low_level_primitives,
                    num_low_level_actions_per_primitive,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_wm_loss = 0
            total_div_loss = 0
            total_image_pred_loss = 0
            total_transition_loss = 0
            total_entropy_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    actions, obs = data
                    obs = obs.to(ptu.device).float()
                    actions = actions.to(ptu.device).float()
                    post, prior, post_dist, prior_dist, image_dist = world_model(
                        obs, actions)[:5]
                    obs = world_model.flatten_obs(obs.permute(
                        1, 0, 2), (int(np.prod(image_shape)), ))
                    (
                        world_model_loss,
                        div,
                        image_pred_loss,
                        transition_loss,
                        entropy_loss,
                    ) = compute_world_model_loss(
                        world_model,
                        image_shape,
                        image_dist,
                        prior,
                        post,
                        prior_dist,
                        post_dist,
                        obs,
                        **world_model_loss_kwargs,
                    )
                    total_wm_loss += world_model_loss.item()
                    total_div_loss += div.item()
                    total_image_pred_loss += image_pred_loss.item()
                    total_transition_loss += transition_loss.item()
                    total_entropy_loss += entropy_loss.item()
                    total_train_steps += 1

                update_network(world_model, optimizer, world_model_loss,
                               gradient_clip, scaler)
                scaler.update()
            eval_statistics[
                "train/wm_loss"] = total_wm_loss / total_train_steps
            eval_statistics[
                "train/div_loss"] = total_div_loss / total_train_steps
            eval_statistics["train/image_pred_loss"] = (total_image_pred_loss /
                                                        total_train_steps)
            eval_statistics["train/transition_loss"] = (total_transition_loss /
                                                        total_train_steps)
            eval_statistics["train/entropy_loss"] = (total_entropy_loss /
                                                     total_train_steps)
            with torch.no_grad():
                total_wm_loss = 0
                total_div_loss = 0
                total_image_pred_loss = 0
                total_transition_loss = 0
                total_entropy_loss = 0
                total_train_steps = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        actions, obs = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        post, prior, post_dist, prior_dist, image_dist = world_model(
                            obs, actions)[:5]
                        obs = world_model.flatten_obs(obs.permute(
                            1, 0, 2), (int(np.prod(image_shape)), ))
                        (
                            world_model_loss,
                            div,
                            image_pred_loss,
                            transition_loss,
                            entropy_loss,
                        ) = compute_world_model_loss(
                            world_model,
                            image_shape,
                            image_dist,
                            prior,
                            post,
                            prior_dist,
                            post_dist,
                            obs,
                            **world_model_loss_kwargs,
                        )
                        total_wm_loss += world_model_loss.item()
                        total_div_loss += div.item()
                        total_image_pred_loss += image_pred_loss.item()
                        total_transition_loss += transition_loss.item()
                        total_entropy_loss += entropy_loss.item()
                        total_test_steps += 1
                eval_statistics[
                    "test/wm_loss"] = total_wm_loss / total_test_steps
                eval_statistics[
                    "test/div_loss"] = total_div_loss / total_test_steps
                eval_statistics["test/image_pred_loss"] = (
                    total_image_pred_loss / total_test_steps)
                eval_statistics["test/transition_loss"] = (
                    total_transition_loss / total_test_steps)
                eval_statistics["test/entropy_loss"] = (total_entropy_loss /
                                                        total_test_steps)
                if (total_wm_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_wm_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    torch.save(
                        world_model.state_dict(),
                        logdir + "/models/world_model.pt",
                    )
                logger.record_dict(eval_statistics, prefix="")
                logger.dump_tabular(with_prefix=False, with_timestamp=False)

        world_model.load_state_dict(
            torch.load(logdir + "/models/world_model.pt"))
        visualize_wm(
            env,
            world_model,
            train_dataset,
            test_dataset,
            logdir,
            max_path_length,
            low_level_primitives,
            num_low_level_actions_per_primitive,
        )
Code Example #23
0
 def _log_vae_stats(self):
     logger.record_dict(
         self.vae_trainer.get_diagnostics(),
         prefix='vae_trainer/',
     )
Code Example #24
0
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
        """
        dump video of policy
        """
        if epoch % self.dump_video_interval == 0:
            imsize = self.expl_env.imsize
            env = self.eval_env

            dump_path = logger.get_snapshot_dir()

            rollout_func = rollout

            video_name = "{}_sample_goal.gif".format(epoch)
            latent_distance_name = "{}_latent_sample_goal.png".format(epoch)
            dump_video(env,
                       self.eval_data_collector._policy,
                       osp.join(dump_path, video_name),
                       rollout_function=rollout_func,
                       imsize=imsize,
                       horizon=self.max_path_length,
                       rows=1,
                       columns=8)
            plot_latent_dist(env,
                             self.eval_data_collector._policy,
                             save_name=osp.join(dump_path,
                                                latent_distance_name),
                             rollout_function=rollout_func,
                             horizon=self.max_path_length)

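            # Temporarily switch to reset-based goals without latent
            # decoding for a second video, then restore the previous
            # goal-sampling settings.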
            old_goal_sampling_mode, old_decode_goals = env._goal_sampling_mode, env.decode_goals
            env._goal_sampling_mode = 'reset_of_env'
            env.decode_goals = False
            video_name = "{}_reset.gif".format(epoch)
            latent_distance_name = "{}_latent_reset.png".format(epoch)
            dump_video(env,
                       self.eval_data_collector._policy,
                       osp.join(dump_path, video_name),
                       rollout_function=rollout_func,
                       imsize=imsize,
                       horizon=self.max_path_length,
                       rows=1,
                       columns=8)
            plot_latent_dist(env,
                             self.eval_data_collector._policy,
                             save_name=osp.join(dump_path,
                                                latent_distance_name),
                             rollout_function=rollout_func,
                             horizon=self.max_path_length)
            self.eval_env._goal_sampling_mode = old_goal_sampling_mode
            self.eval_env.decode_goals = old_decode_goals
        """
        Replay Buffer
        """
        logger.record_dict(self.replay_buffer.get_diagnostics(),
                           prefix='replay_buffer/')
        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
        """
        Exploration
        """
        logger.record_dict(self.expl_data_collector.get_diagnostics(),
                           prefix='exploration/')
        expl_paths = self.expl_data_collector.get_epoch_paths()
        if hasattr(self.expl_env, 'get_diagnostics'):
            logger.record_dict(
                self.expl_env.get_diagnostics(expl_paths),
                prefix='exploration/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(expl_paths),
            prefix="exploration/",
        )
        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix='evaluation/',
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, 'get_diagnostics'):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix='evaluation/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )
        """
        Misc
        """
        gt.stamp('logging')
        logger.record_dict(_get_epoch_timings())
        logger.record_tabular('Epoch', epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)