Example #1
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                                  "continue...")

        self.shutdown_worker()
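
For context, a BatchPolopt-style train() loop like this is usually driven by constructing the algorithm with an environment, policy, and baseline and then calling train(). A minimal sketch using the stock rllab classes (exact module paths may differ between rllab forks):

from rllab.algos.trpo import TRPO
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalize import normalize
from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy

env = normalize(CartpoleEnv())
policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
    env=env,
    policy=policy,
    baseline=baseline,
    batch_size=4000,
    max_path_length=100,
    n_itr=40,
    discount=0.99,
    step_size=0.01,
)
algo.train()  # runs an iteration loop like the one shown above
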
Example #2
    def train(self):
        self.start_worker()
        self.init_opt()
        rets = []
        for itr in range(self.start_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.obtain_samples(itr)
                print(("BatchPolopt:train len(paths)", len(paths)))
                samples_data, total_returns_per_episode = self.process_samples(itr, paths)
                rets.append(total_returns_per_episode)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        eval(input("Plotting evaluation run: Press Enter to "
                                  "continue..."))

        self.shutdown_worker()
        return rets
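
This variant additionally returns rets, one entry per iteration holding the per-episode total returns from process_samples. A caller could turn that into a simple learning curve; a sketch, assuming numpy and an algo instance of this class:

import numpy as np

rets = algo.train()                      # list of per-iteration return arrays
mean_curve = [np.mean(r) for r in rets]  # average return per iteration
print("final average return:", mean_curve[-1])
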
Example #3
 def train(self):
     with tf.Session() as sess:
         sess.run(tf.initialize_all_variables())
         self.start_worker()
         start_time = time.time()
         for itr in range(self.start_itr, self.n_itr):
             itr_start_time = time.time()
             with logger.prefix('itr #%d | ' % itr):
                 logger.log("Obtaining samples...")
                 paths = self.obtain_samples(itr)
                 logger.log("Processing samples...")
                 samples_data = self.process_samples(itr, paths)
                 logger.log("Logging diagnostics...")
                 self.log_diagnostics(paths)
                 logger.log("Optimizing policy...")
                 self.optimize_policy(itr, samples_data)
                 logger.log("Saving snapshot...")
                 params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                 if self.store_paths:
                     params["paths"] = samples_data["paths"]
                 logger.save_itr_params(itr, params)
                 logger.log("Saved")
                 logger.record_tabular('Time', time.time() - start_time)
                 logger.record_tabular('ItrTime', time.time() - itr_start_time)
                 logger.dump_tabular(with_prefix=False)
                 if self.plot:
                     self.update_plot()
                     if self.pause_for_plot:
                         input("Plotting evaluation run: Press Enter to "
                               "continue...")
     self.shutdown_worker()
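
Note that tf.initialize_all_variables() has been deprecated since TensorFlow 0.12 in favor of tf.global_variables_initializer(); on a TF1 install the session setup would typically look like this (only the initializer call differs from the example above):

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... same iteration loop as in the example above ...
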
Example #4
def reset():
    sh.agent.reset()

    if sh.observations is not None:
        path = dict(
            observations=tensor_utils.stack_tensor_list(sh.observations),
            actions=tensor_utils.stack_tensor_list(sh.actions),
            rewards=tensor_utils.stack_tensor_list(sh.rewards),
            agent_infos=tensor_utils.stack_tensor_dict_list(sh.agent_infos),
            env_infos=tensor_utils.stack_tensor_dict_list(sh.env_infos),
        )

        sh.paths.append(path)
        sh.count += len(sh.observations)

        # check if it is time to update
        if sh.count > sh.algor.batch_size:
            itr = sh.itera
            with logger.prefix('itr #%d | ' % itr):
                paths = sh.paths
                samples_data = sh.algor.sampler.process_samples(itr, paths)
                sh.algor.log_diagnostics(paths)
                sh.algor.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = sh.algor.get_itr_snapshot(itr, samples_data)
                sh.algor.current_itr = itr + 1
                params["algo"] = sh.algor
                if sh.algor.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)

                sh.paths = []

                if sh.algor.plot:
                    sh.algor.update_plot()
                    if sh.algor.pause_for_plot:
                        raw_input("Plotting evaluation run: Press Enter to "
                                  "continue...")

            sh.itera += 1
            sh.count = 0

    # reset arrays
    sh.observations, sh.actions, sh.rewards, sh.agent_infos, sh.env_infos = [], [], [], [], []
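
This reset() callback assumes a shared-state object sh whose per-step buffers are filled elsewhere before being stacked into a path here. A purely hypothetical sketch of that step side (sh and its agent attribute are assumptions carried over from the example):

def step(observation, reward, env_info):
    # query the agent and append one transition to the shared buffers;
    # reset() above stacks these buffers into a path when the episode ends
    action, agent_info = sh.agent.get_action(observation)
    sh.observations.append(observation)
    sh.actions.append(action)
    sh.rewards.append(reward)
    sh.agent_infos.append(agent_info)
    sh.env_infos.append(env_info)
    return action
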
Example #5
    def process_samples(self, itr, paths):
        # count visitations or whatever the bonus wants to do. This should not modify the paths
        for b_eval in self.bonus_evaluator:
            logger.log("fitting bonus evaluator before processing...")
            b_eval.fit_before_process_samples(paths)
            logger.log("fitted")
        # save the real undiscounted rewards before changing them
        undiscounted_returns = [sum(path["rewards"]) for path in paths]
        logger.record_tabular('TrueAverageReturn',
                              np.mean(undiscounted_returns))
        for path in paths:
            path['true_rewards'] = list(path['rewards'])

        # If using a latent regressor (and possibly adding MI to the reward):
        if isinstance(self.latent_regressor, Latent_regressor):
            with logger.prefix(' Latent_regressor '):
                self.latent_regressor.fit(paths)

                if self.reward_regressor_mi:
                    for i, path in enumerate(paths):
                        # per-path log-likelihood of the latent under the regressor
                        path['logli_latent_regressor'] = \
                            self.latent_regressor.predict_log_likelihood(
                                [path], [path['agent_infos']['latents']])[0]
                        # the log-likelihood of the latent is the mutual-information term added to the reward
                        path['rewards'] += (self.reward_regressor_mi *
                                            path['logli_latent_regressor'])

        # for the extra bonus
        for b, b_eval in enumerate(self.bonus_evaluator):
            for i, path in enumerate(paths):
                bonuses = b_eval.predict(path)
                path['rewards'] += self.reward_coef_bonus[b] * bonuses

        real_samples = ext.extract_dict(
            BatchSampler.process_samples(self, itr, paths),
            # no need to re-process the hallucinated samples: the returns, advantages, etc. are the same
            "observations",
            "actions",
            "advantages",
            "env_infos",
            "agent_infos")
        real_samples["importance_weights"] = np.ones_like(
            real_samples["advantages"])

        return real_samples
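
This process_samples() assumes each entry of self.bonus_evaluator exposes fit_before_process_samples(paths) and predict(path). A hypothetical, minimal evaluator with that interface (a count-based bonus is just an illustration; the real evaluators may differ substantially):

import numpy as np


class CountBonusEvaluator(object):
    """Illustrative bonus evaluator exposing the interface assumed above."""

    def __init__(self, coeff=1.0):
        self.coeff = coeff
        self.visit_counts = {}

    def fit_before_process_samples(self, paths):
        # update visitation statistics; must not modify the paths
        for path in paths:
            for obs in path["observations"]:
                key = tuple(np.round(obs, 1))
                self.visit_counts[key] = self.visit_counts.get(key, 0) + 1

    def predict(self, path):
        # one bonus value per timestep, e.g. 1 / sqrt(visit count)
        counts = [self.visit_counts.get(tuple(np.round(obs, 1)), 1)
                  for obs in path["observations"]]
        return self.coeff / np.sqrt(np.asarray(counts, dtype=np.float64))
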
Example #6
    def train(self):
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)
        self.start_worker()
        start_time = time.time()

        returns = []
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("path rewards shape: {}".format(
                    paths[0]['rewards'].shape))
                logger.log("Processing samples...")
                paths = self.compute_irl(paths, itr=itr)

                returns.append(self.log_avg_returns(paths))
                samples_data = self.process_samples(itr, paths)

                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        return
Example #7
    def train(self, continue_learning=False):
        self.start_worker()
        if not continue_learning:
            self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                #test = np.concatenate([np.arange(12), [1, 0]])
                #test = test[0:-2]
                #print('BF opt: ', self.policy.get_action(test))
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                import joblib, copy
                #joblib.dump(samples_data, 'data/local/experiment/hopper_torso0110_sd3_additionaldim_twotask_zerobaseline/training_data'+str(itr), compress=True)
                '''saved_data = joblib.load('data/local/experiment/hopper_torso0110_sd3_additionaldim_twotask_zerobaseline/training_data'+str(itr))
                samples_data = copy.deepcopy(saved_data)
                samples_data["observations"] = []
                for obs in saved_data['observations']:
                    if obs[-1] < 0.5:
                        obs = np.concatenate([obs, [1,0]])
                    else:
                        obs = np.concatenate([obs, [0, 1]])
                    samples_data["observations"].append(obs)
                gd=self.get_grad(samples_data)
                print(gd)'''

                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
                #print(self.policy.get_param_values())
                #print('AFT opt: ', self.policy.get_action(test))
                #abc
        self.shutdown_worker()
Example #8
    def train(self):
        self.start_worker()
        self.init_opt()

        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                timeit.reset()
                timeit.start('Total')
                logger.log("Collecting rollouts")
                timeit.start("CollectingRollouts")
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                timeit.stop("CollectingRollouts")
                logger.log("Logging diagnostics")
                timeit.start("LoggingDiagnostics")
                self.log_diagnostics(paths)
                timeit.stop("LoggingDiagnostics")
                logger.log('Optimizing policy')
                timeit.start('OptimizingPolicy')
                self.optimize_policy(itr, samples_data)
                timeit.stop('OptimizingPolicy')
                logger.log("Miscellaneous")
                timeit.start('Miscellaneous')
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                timeit.stop('Miscellaneous')
                logger.log("saved")
                timeit.stop('Total')
                for line in str(timeit).split('\n'):
                    logger.log(line)
                logger.dump_tabular(with_prefix=False)

                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
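
The timeit object used above is not an rllab utility; the calls imply a small stopwatch registry with reset()/start()/stop() and a printable summary. A hypothetical sketch matching that usage:

import time
from collections import OrderedDict


class _Timeit(object):
    """Minimal stopwatch registry matching the reset/start/stop/str usage above."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._starts = {}
        self._totals = OrderedDict()

    def start(self, name):
        self._starts[name] = time.time()

    def stop(self, name):
        elapsed = time.time() - self._starts.pop(name)
        self._totals[name] = self._totals.get(name, 0.0) + elapsed

    def __str__(self):
        return '\n'.join('%s: %.3fs' % (k, v) for k, v in self._totals.items())


timeit = _Timeit()
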
Example #9
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                '''
                paths = list of 20 dicts, where each has --- observations,actions,agent_infos,rewards,env_infos
                '''
                if (itr == -1):
                    print("I'm bringing in my own data.")
                    from six.moves import cPickle
                    f = open(
                        '/home/anagabandi/learn_dyn_nips/all_rollouts_to_save_good.save',
                        'rb')
                    paths = cPickle.load(f)
                    f.close()
                else:
                    paths = self.sampler.obtain_samples(itr)

                samples_data = self.sampler.process_samples(itr, paths)
                '''
                NOW:
                paths = list of 20 dicts, where each has --- actions, advantages, agent_infos, env_infos, observations, returns, rewards
                samples_data = dict of --- paths, actions, returns, rewards, env_infos, advantages, observations, agent_infos
                '''

                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
Example #10
    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session(config=get_session_config())
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                save_itr_params_pickle(itr, params)
                prune_old_snapshots(itr,
                                    keep_every=self.snap_keep_every,
                                    keep_latest=self.snap_keep_latest)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        if created_session:
            sess.close()
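
save_itr_params_pickle and prune_old_snapshots are not rllab built-ins; the keep_every/keep_latest arguments suggest a simple snapshot-rotation scheme. A hypothetical sketch of such helpers (the snapshot directory and file naming are assumptions):

import os
import pickle

SNAPSHOT_DIR = "snapshots"  # assumed location


def save_itr_params_pickle(itr, params):
    # write one pickle file per iteration
    os.makedirs(SNAPSHOT_DIR, exist_ok=True)
    with open(os.path.join(SNAPSHOT_DIR, "itr_%d.pkl" % itr), "wb") as f:
        pickle.dump(params, f)


def prune_old_snapshots(itr, keep_every, keep_latest):
    # keep every `keep_every`-th snapshot plus the `keep_latest` most recent ones
    for old_itr in range(itr):
        if old_itr % keep_every == 0 or old_itr >= itr - keep_latest:
            continue
        path = os.path.join(SNAPSHOT_DIR, "itr_%d.pkl" % old_itr)
        if os.path.exists(path):
            os.remove(path)
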
Example #11
    def _evaluate(self, epoch):
        """Perform evaluation for the current policy.

        We always use the most recent policy, but for computational efficiency
        we sometimes use a stale version of the metapolicy.
        During evaluation, our policy expects an un-augmented observation.

        :param epoch: The epoch number.
        :return: None
        """

        if self._eval_n_episodes < 1:
            return

        if epoch % self._find_best_skill_interval == 0:
            self._single_option_policy = self._get_best_single_option_policy()
        for (policy, policy_name) in [(self._single_option_policy, 'best_single_option_policy')]:
            with logger.tabular_prefix(policy_name + '/'), logger.prefix(policy_name + '/'):
                with self._policy.deterministic(self._eval_deterministic):
                    if self._eval_render:
                        paths = rollouts(self._eval_env, policy,
                                         self._max_path_length, self._eval_n_episodes,
                                         render=True, render_mode='rgb_array')
                    else:
                        paths = rollouts(self._eval_env, policy,
                                         self._max_path_length, self._eval_n_episodes)

                total_returns = [path['rewards'].sum() for path in paths]
                episode_lengths = [len(p['rewards']) for p in paths]

                logger.record_tabular('return-average', np.mean(total_returns))
                logger.record_tabular('return-min', np.min(total_returns))
                logger.record_tabular('return-max', np.max(total_returns))
                logger.record_tabular('return-std', np.std(total_returns))
                logger.record_tabular('episode-length-avg', np.mean(episode_lengths))
                logger.record_tabular('episode-length-min', np.min(episode_lengths))
                logger.record_tabular('episode-length-max', np.max(episode_lengths))
                logger.record_tabular('episode-length-std', np.std(episode_lengths))

                self._eval_env.log_diagnostics(paths)

        batch = self._pool.random_batch(self._batch_size)
        self.log_diagnostics(batch)
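
rollouts here refers to a batch-rollout helper as found in softqlearning-style codebases. A minimal sketch of the assumed behaviour, built on rllab's single-episode rollout utility (the real helper likely also handles render_mode and collects rendered frames):

from rllab.sampler.utils import rollout


def rollouts(env, policy, path_length, n_paths, render=False, render_mode=None):
    # collect n_paths episodes using the single-episode rollout helper;
    # render_mode handling (e.g. returning rgb_array frames) is omitted here
    return [rollout(env, policy, max_path_length=path_length, animated=render)
            for _ in range(n_paths)]
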
Example #12
    def train(self):
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)
        self.start_worker()
        start_time = time.time()

        returns = []
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)

                logger.log("Processing samples...")
                paths = self.compute_irl(paths, itr=itr)
                returns.append(self.log_avg_returns(paths))
                samples_data = self.process_samples(itr, paths)

                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        return 
Example #13
    def train(self):
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)
        self.start_worker()
        start_time = time.time()

        returns = []
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log('Obtaining samples...')
                paths = self.sampler.obtain_samples(itr)
                logger.log('Processing samples...')
                # Update the Reward function
                paths = self.compute_irl(paths, itr=itr)
                # returns.append(self.log_avg_returns(paths))
                samples_data = self.sampler.process_samples(itr, paths)

                logger.log('Logging diagnostics...')
                self.log_diagnostics(paths)
                logger.log('Optimizing policy...')
                self.optimize_policy(itr, samples_data)
                logger.log('Saving snapshot...')
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params['paths'] = samples_data['paths']
                logger.save_itr_params(itr, params)
                logger.log('Saved')
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input('Plotting evaluation run: Press Enter to '
                              'continue...')
        self.shutdown_worker()
        return
Example #14
    def train(self):
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr + 1):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing actor policy...")
                    self.optimize_policy(itr, samples_data)

                    if itr % self.save_param_update == 0:
                        logger.log("Saving snapshot...")
                        params = self.get_itr_snapshot(
                            itr, samples_data)  # , **kwargs)
                        if self.store_paths:
                            if isinstance(samples_data, list):
                                params["paths"] = [
                                    sd["paths"] for sd in samples_data
                                ]
                            else:
                                params["paths"] = samples_data["paths"]
                        logger.save_itr_params(itr, params)
                        logger.log("Saved")

                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        self.shutdown_worker()
Example #15
    def train(self):
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            if self.load_params_args is not None:
                self.policy.load_params(*self.load_params_args)
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                self.policy.save_params(itr)
                itr_start_time = time.time()
                if itr >= self.temporal_noise_thresh:
                    self.env._noise_indicies = None

                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    # env, policy, baseline have individual log_diagnostics methods
                    # for overriding
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(
                        itr, samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular(
                        'ItrTime', time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        self.shutdown_worker()
Example #16
    def fit_with_samples(self, paths, samples_data):
        inputs = [
            samples_data["observations"], samples_data["returns"],
            samples_data["valids"]
        ]

        self.f_update_stats(samples_data["returns"], samples_data["valids"])

        with logger.prefix("Vf | "), logger.tabular_prefix("Vf."):
            if self.log_loss_before:
                logger.log("Computing loss before training")
                loss_before, _ = self.optimizer.loss_diagnostics(inputs)
                logger.log("Computed")

            epoch_losses = []

            def record_data(loss, diagnostics, *args, **kwargs):
                epoch_losses.append(loss)
                return True

            self.optimizer.optimize(inputs, callback=record_data)

            if self.log_loss_after:
                logger.log("Computing loss after training")
                loss_after, _ = self.optimizer.loss_diagnostics(inputs)
                logger.log("Computed")

            # perform minibatch gradient descent on the surrogate loss, while monitoring the KL divergence

            if self.log_loss_before:
                logger.record_tabular('LossBefore', loss_before)
            else:
                # Log approximately
                logger.record_tabular('FirstEpoch.Loss', epoch_losses[0])
            if self.log_loss_after:
                logger.record_tabular('LossAfter', loss_after)
            else:
                logger.record_tabular('LastEpoch.Loss', epoch_losses[-1])
            if self.log_loss_before and self.log_loss_after:
                logger.record_tabular('dLoss', loss_before - loss_after)
Example #17
def oracle_train(algo, sess=None):
    """
    This is necessary so that we don't wipe away already-initialized policy params.
    Ideally, this should be pull-requested into RLlab as an option and removed from here once that is done.
    """
    created_session = sess is None
    if sess is None:
        sess = tf.Session()
        sess.__enter__()

    rollout_cache = []
    initialize_uninitialized(sess)
    algo.start_worker()
    start_time = time.time()
    for itr in range(algo.start_itr, algo.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):

            logger.log("Obtaining samples...")
            paths = algo.obtain_samples(itr)
            logger.log("Processing samples...")
            samples_data = algo.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            algo.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            algo.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = algo.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if algo.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)

    algo.shutdown_worker()
    if created_session:
        sess.close()
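
initialize_uninitialized is a common TF1 utility that initializes only variables the session has not seen yet, which is what keeps pre-loaded policy parameters intact. A sketch of such a helper:

import tensorflow as tf


def initialize_uninitialized(sess):
    # initialize only the global variables the session has not initialized yet,
    # leaving already-loaded parameters (e.g. a restored policy) untouched
    global_vars = tf.global_variables()
    is_initialized = sess.run([tf.is_variable_initialized(v) for v in global_vars])
    not_initialized = [v for v, flag in zip(global_vars, is_initialized) if not flag]
    if not_initialized:
        sess.run(tf.variables_initializer(not_initialized))
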
Example #18
    def optimize_policy(self, itr, samples_data):
        # samples_data must come with latents; see train() in batch_polopt
        print("##############################")
        print("optimize policy in npo_snn_rewards")


        all_input_values = tuple(ext.extract(  # it will be in agent_infos!!! under key "latents"
            samples_data,
            "observations", "actions", "advantages"
        ))
        print("advantage function:", samples_data["advantages"])
        sum_adv = 0
        for x in samples_data["advantages"]:
            # print(x)
            sum_adv += x
        print("adv_sum", sum_adv)
        agent_infos = samples_data["agent_infos"]
        # "latents" has already been processed: it is the concatenation of all
        # latents, but it keeps the key "latents"
        all_input_values += (agent_infos["latents"],)
        # dist info (mean and variance) recorded at rollout time, corresponding
        # to old_dist_info_vars_list (the symbolic variables)
        info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)

        loss_before = self.optimizer.loss(all_input_values)
        # this should always be 0; if it's not, there is a problem.
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        logger.record_tabular('MeanKL_Before', mean_kl_before)

        with logger.prefix(' PolicyOptimize | '):
            self.optimizer.optimize(all_input_values)

        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()
Example #19
def custom_train(algo, sess=None):
    """
    This is necessary so that we don't wipe away already-initialized policy params.
    Ideally, this should be pull-requested into RLlab as an option and removed from here once that is done.
    """
    created_session = sess is None
    if sess is None:
        sess = tf.Session()
        sess.__enter__()

    rollout_cache = []
    initialize_uninitialized(sess)
    algo.start_worker()
    start_time = time.time()
    for itr in range(algo.start_itr, algo.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = algo.obtain_samples(itr)
            logger.log("Processing samples...")
            samples_data = algo.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            algo.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            algo.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = algo.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if algo.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)

    algo.shutdown_worker()
    if created_session:
        sess.close()
Example #20
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in xrange(self.start_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.obtain_samples(itr)
                samples_data = self.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        raw_input("Plotting evaluation run: Press Enter to "
                                  "continue...")

        self.shutdown_worker()
Example #21
 def train(self):
     with tf.Session() as sess:
         sess.run(tf.initialize_all_variables())
         self.start_worker()
         start_time = time.time()
         num_samples = 0
         for itr in range(self.start_itr, self.n_itr):
             itr_start_time = time.time()
             with logger.prefix('itr #%d | ' % itr):
                 logger.log("Obtaining new samples...")
                 paths = self.obtain_samples(itr)
                 for path in paths:
                     num_samples += len(path["rewards"])
                 logger.log("total num samples..." + str(num_samples))
                 logger.log("Processing samples...")
                 samples_data = self.process_samples(itr, paths)
                 logger.log("Logging diagnostics...")
                 self.log_diagnostics(paths)
                 logger.log("Optimizing policy...")
                 self.optimize_policy(samples_data)
                 logger.log("Saving snapshot...")
                 params = self.get_itr_snapshot(itr,
                                                samples_data)  # , **kwargs)
                 if self.store_paths:
                     params["paths"] = samples_data["paths"]
                 logger.save_itr_params(itr, params)
                 logger.log("Saved")
                 logger.record_tabular('Time', time.time() - start_time)
                 logger.record_tabular('ItrTime',
                                       time.time() - itr_start_time)
                 logger.dump_tabular(with_prefix=False)
                 if self.plot:
                     self.update_plot()
                     if self.pause_for_plot:
                         input("Plotting evaluation run: Press Enter to "
                               "continue...")
     self.shutdown_worker()
Example #22
 def train(self, sess=None):
     created_session = sess is None
     if sess is None:
         sess = tf.Session()
         sess.__enter__()
         
     sess.run(tf.global_variables_initializer())
     self.start_worker()
     start_time = time.time()
     for itr in range(self.start_itr, self.n_itr):
         itr_start_time = time.time()
         with logger.prefix('itr #%d | ' % itr):
             logger.log("Obtaining samples...")
             paths = self.obtain_samples(itr)
             logger.log("Processing samples...")
             samples_data = self.process_samples(itr, paths)
             logger.log("Logging diagnostics...")
             self.log_diagnostics(paths)
             logger.log("Optimizing policy...")
             self.optimize_policy(itr, samples_data)
             logger.log("Saving snapshot...")
             params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
             if self.store_paths:
                 params["paths"] = samples_data["paths"]
             logger.save_itr_params(itr, params)
             logger.log("Saved")
             logger.record_tabular('Time', time.time() - start_time)
             logger.record_tabular('ItrTime', time.time() - itr_start_time)
             logger.dump_tabular(with_prefix=False)
             if self.plot:
                 rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
                 if self.pause_for_plot:
                     input("Plotting evaluation run: Press Enter to "
                           "continue...")
     self.shutdown_worker()
     if created_session:
         sess.close()
Example #23
    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        if not self.transfer:
            sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            if itr == 0 or itr % self.policy_save_interval == 0:
                params = dict(
                    params1=self.policy.get_param_values(),
                    params2=self.policy2.get_param_values(),
                )
                joblib.dump(params,
                            self.policy_path + '/params' + str(itr) + '.pkl',
                            compress=3)

            itr_start_time = time.time()

            for n1 in range(self.N1):
                with logger.prefix('itr #%d ' % itr + 'n1 #%d |' % n1):
                    logger.log("training policy 1...")
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr, 1)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths, 1)

                    if self.record_rewards:
                        undiscounted_returns = [
                            sum(path["rewards"]) for path in paths
                        ]
                        average_discounted_return = np.mean(
                            [path["returns"][0] for path in paths])
                        AverageReturn = np.mean(undiscounted_returns)
                        StdReturn = np.std(undiscounted_returns)
                        MaxReturn = np.max(undiscounted_returns)
                        MinReturn = np.min(undiscounted_returns)
                        self.rewards['average_discounted_return1'].append(
                            average_discounted_return)
                        self.rewards['AverageReturn1'].append(AverageReturn)
                        self.rewards['StdReturn1'].append(StdReturn)
                        self.rewards['MaxReturn1'].append(MaxReturn)
                        self.rewards['MinReturn1'].append(MinReturn)

                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths, 1)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data, 1)

                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)

            for n2 in range(self.N2):
                if itr != self.n_itr - 1:  # don't train the adversary at the last iteration
                    with logger.prefix('itr #%d ' % itr + 'n2 #%d |' % n2):
                        logger.log("training policy 2...")
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(itr, 2)
                        logger.log("Processing samples...")
                        samples_data = self.process_samples(itr, paths, 2)

                        if self.record_rewards:
                            undiscounted_returns = [
                                sum(path["rewards"]) for path in paths
                            ]
                            average_discounted_return = np.mean(
                                [path["returns"][0] for path in paths])
                            AverageReturn = np.mean(undiscounted_returns)
                            StdReturn = np.std(undiscounted_returns)
                            MaxReturn = np.max(undiscounted_returns)
                            MinReturn = np.min(undiscounted_returns)
                            self.rewards['average_discounted_return2'].append(
                                average_discounted_return)
                            self.rewards['AverageReturn2'].append(
                                AverageReturn)
                            self.rewards['StdReturn2'].append(StdReturn)
                            self.rewards['MaxReturn2'].append(MaxReturn)
                            self.rewards['MinReturn2'].append(MinReturn)

                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(paths, 2)
                        logger.log("Optimizing policy...")
                        self.optimize_policy(itr, samples_data, 2)

                        logger.record_tabular('Time', time.time() - start_time)
                        logger.record_tabular('ItrTime',
                                              time.time() - itr_start_time)
                        logger.dump_tabular(with_prefix=False)

            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr)  # , **kwargs)
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            # logger.record_tabular('Time', time.time() - start_time)
            # logger.record_tabular('ItrTime', time.time() - itr_start_time)
            # logger.dump_tabular(with_prefix=False)

        self.shutdown_worker()
        if created_session:
            sess.close()
Example #24
    def train(self, sess=None):
        if sess is None:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
            sess.__enter__()
            sess.run(tf.initialize_all_variables())
        else:
            sess.run(
                tf.initialize_variables(
                    list(
                        tf.get_variable(name) for name in sess.run(
                            tf.report_uninitialized_variables()))))

        self.start_worker()
        start_time = time.time()

        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()

            with logger.prefix('itr #%d | ' % itr):
                all_paths = []
                logger.log("Obtaining samples...")
                for sampler in self.local_samplers:
                    all_paths.append(sampler.obtain_samples(itr))

                logger.log("Processing samples...")
                all_samples_data = []
                for n, (sampler,
                        paths) in enumerate(zip(self.local_samplers,
                                                all_paths)):
                    with logger.tabular_prefix(str(n)):
                        all_samples_data.append(
                            sampler.process_samples(itr, paths))

                logger.log("Logging diagnostics...")
                self.log_diagnostics(all_paths)

                if self.should_optimize_policy:
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, all_samples_data)

                if self.test_env is not None:
                    logger.log("Obtaining test samples...")
                    test_paths = self.test_sampler.obtain_samples(itr)
                    with logger.tabular_prefix("Test"):
                        test_samples = self.test_sampler.process_samples(
                            itr, test_paths)
                        logger.record_tabular(
                            "TestSuccessRate",
                            np.mean(test_samples["env_infos"]["success"]))

                successes = 0.0
                trials = 0.0
                for i, samples_data in enumerate(all_samples_data):
                    success = samples_data["env_infos"]["success"]
                    logger.record_tabular("SuccessRate{}".format(i),
                                          np.mean(success))
                    successes += np.sum(success)
                    trials += success.shape[0]

                success_rate = successes / trials
                logger.record_tabular("SuccessRate", success_rate)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               all_samples_data)  # , **kwargs)
                logger.save_itr_params(itr, params)

                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)

        self.shutdown_worker()
Example #25
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]

        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = self.frac_gpu

        with tf.Session(config=config) as sess:
            # Code for loading a previous policy. Somewhat hacky because needs to be in sess.
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            self.initialize_uninitialized_variables(sess)

            self.all_paths = []

            self.start_worker()
            start_time = time.time()
            n_env_timesteps = 0

            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.record_tabular("mean_inner_stepsize", self.policy.get_mean_step_size())

                    ''' sample environment configuration '''
                    env = self.env
                    while not ('sample_env_params' in dir(env) or 'sample_goals' in dir(env)):
                        env = env._wrapped_env
                    if 'sample_goals' in dir(env):
                        learner_env_params = env.sample_goals(self.meta_batch_size)
                    elif 'sample_env_params' in dir(env):
                        learner_env_params = env.sample_env_params(self.meta_batch_size)

                    ''' get rollouts from the environment'''

                    time_env_sampling_start = time.time()

                    if self.initial_random_samples and itr == 0:
                        logger.log("Obtaining random samples from the environment...")
                        new_env_paths = self.obtain_random_samples(itr, log=True)

                        n_env_timesteps += self.initial_random_samples
                        logger.record_tabular("n_timesteps", n_env_timesteps)

                        self.all_paths.extend(new_env_paths)
                        samples_data_dynamics = self.random_sampler.process_samples(itr, self.all_paths,
                                                                                    log=True,
                                                                                    log_prefix='EnvTrajs-')  # must log in the same way as the model sampler below

                    else:
                        if self.reset_policy_std:
                            logger.log("Resetting policy std")
                            self.policy.set_std()

                        if not self.tailored_exploration:
                            logger.log("Disabling tailored exploration. Using pre-update policy to collect samples.")
                            self.policy.switch_to_init_dist()

                        logger.log("Obtaining samples from the environment using the policy...")
                        new_env_paths = self.obtain_env_samples(itr, reset_args=learner_env_params,
                                                                log_prefix='EnvSampler-')
                        n_env_timesteps += self.batch_size
                        logger.record_tabular("n_timesteps", n_env_timesteps)

                        # flatten dict of paths per task/mode --> list of paths
                        new_env_paths = [path for task_paths in new_env_paths.values() for path in task_paths]
                        # self.all_paths.extend(new_env_paths)
                        logger.log("Processing environment samples...")
                        # first processing just for logging purposes
                        self.model_sampler.process_samples(itr, new_env_paths, log=True, log_prefix='EnvTrajs-')

                        new_samples_data_dynamics = self.process_samples_for_dynamics(itr, new_env_paths)
                        for k, v in samples_data_dynamics.items():
                            samples_data_dynamics[k] = np.concatenate([v, new_samples_data_dynamics[k]], axis=0)[-int(self.dynamics_data_buffer_size):]

                    logger.record_tabular('Time-EnvSampling', time.time() - time_env_sampling_start)

                    if self.log_real_performance:
                        logger.log("Evaluating the performance of the real policy")
                        self.policy.switch_to_init_dist()
                        new_env_paths = self.obtain_env_samples(itr, reset_args=learner_env_params,
                                                                log_prefix='PrePolicy-')
                        samples_data = {}
                        for key in new_env_paths.keys():
                            samples_data[key] = self.process_samples_for_policy(itr, new_env_paths[key], log=False)
                        _ = self.process_samples_for_policy(itr, flatten_list(new_env_paths.values()), log_prefix='PrePolicy-')
                        self.policy.compute_updated_dists(samples_data)
                        new_env_paths = self.obtain_env_samples(itr, reset_args=learner_env_params, log_prefix='PostPolicy-',)
                        _ = self.process_samples_for_policy(itr,  flatten_list(new_env_paths.values()), log_prefix='PostPolicy-')

                    ''' --------------- fit dynamics model --------------- '''

                    time_fit_start = time.time()

                    epochs = self.dynamic_model_epochs[min(itr, len(self.dynamic_model_epochs) - 1)]
                    if self.reinit_model and itr % self.reinit_model == 0:
                        self.dynamics_model.reinit_model()
                        epochs = self.dynamic_model_epochs[0]
                    logger.log("Training dynamics model for %i epochs ..." % (epochs))
                    self.dynamics_model.fit(samples_data_dynamics['observations_dynamics'],
                                            samples_data_dynamics['actions_dynamics'],
                                            samples_data_dynamics['next_observations_dynamics'],
                                            epochs=epochs, verbose=True, log_tabular=True)

                    logger.record_tabular('Time-ModelFit', time.time() - time_fit_start)

                    ''' --------------- MAML steps --------------- '''

                    times_dyn_sampling = []
                    times_dyn_sample_processing = []
                    times_inner_step = []
                    times_outer_step = []

                    time_maml_steps_start = time.time()

                    kl_pre_post = []
                    model_std = []

                    for maml_itr in range(self.num_maml_steps_per_iter):

                        self.policy.switch_to_init_dist()  # Switch to pre-update policy

                        all_samples_data_maml_iter, all_paths_maml_iter = [], []
                        for step in range(self.num_grad_updates + 1):

                            ''' --------------- Sampling from Dynamics Models --------------- '''

                            logger.log("MAML Step %i%s of %i - Obtaining samples from the dynamics model..." % (
                                maml_itr + 1, chr(97 + step), self.num_maml_steps_per_iter))

                            time_dyn_sampling_start = time.time()

                            if self.reset_from_env_traj:
                                new_model_paths = self.obtain_model_samples(itr, traj_starting_obs=samples_data_dynamics['observations_dynamics'],
                                                                            traj_starting_ts=samples_data_dynamics['timesteps_dynamics'])
                            else:
                                new_model_paths = self.obtain_model_samples(itr)

                            assert type(new_model_paths) == dict and len(new_model_paths) == self.meta_batch_size
                            all_paths_maml_iter.append(new_model_paths)

                            times_dyn_sampling.append(time.time() - time_dyn_sampling_start)

                            ''' --------------- Processing Dynamics Samples --------------- '''

                            logger.log("Processing samples...")
                            time_dyn_sample_processing_start = time.time()
                            samples_data = {}

                            for key in new_model_paths.keys():  # the keys are the tasks
                                # don't log because this will spam the console with every task.
                                samples_data[key] = self.process_samples_for_policy(itr, new_model_paths[key], log=False)
                            all_samples_data_maml_iter.append(samples_data)

                            # for logging purposes
                            _, mean_reward = self.process_samples_for_policy(itr,
                                                                             flatten_list(new_model_paths.values()),
                                                                             log='reward',
                                                                             log_prefix="DynTrajs%i%s-" % (
                                                                                 maml_itr + 1, chr(97 + step)),
                                                                             return_reward=True)

                            times_dyn_sample_processing.append(time.time() - time_dyn_sample_processing_start)

                            ''' --------------- Inner Policy Update --------------- '''

                            time_inner_step_start = time.time()

                            if step < self.num_grad_updates:
                                logger.log("Computing policy updates...")
                                self.policy.compute_updated_dists(samples_data)

                            times_inner_step.append(time.time() - time_inner_step_start)

                        '''---------- Computing KL divergence of the policies and variance of the model --------'''
                        # self.policy.switch_to_init_dist()
                        last_samples = all_samples_data_maml_iter[-1]
                        for idx in range(self.meta_batch_size):
                            _, agent_infos_pre = self.policy.get_actions(last_samples[idx]['observations'])
                            # compute KL divergence between pre and post update policy
                            kl_pre_post.append(
                                self.policy.distribution.kl(agent_infos_pre, last_samples[idx]['agent_infos']).mean())
                            model_std.append(self.dynamics_model.predict_std(last_samples[idx]['observations'],
                                                                             last_samples[idx]['actions']).mean())

                        '''------------------------------------------------------------------------------------------'''

                        if maml_itr == 0:
                            prev_rolling_reward_mean = mean_reward
                            rolling_reward_mean = mean_reward
                        else:
                            prev_rolling_reward_mean = rolling_reward_mean
                            rolling_reward_mean = 0.8 * rolling_reward_mean + 0.2 * mean_reward


                        # stop gradient steps when mean_reward decreases
                        if self.retrain_model_when_reward_decreases and rolling_reward_mean < prev_rolling_reward_mean:
                            logger.log(
                                "Stopping policy gradients steps since rolling mean reward decreased from %.2f to %.2f" % (
                                    prev_rolling_reward_mean, rolling_reward_mean))
                            # complete some logging stuff
                            for i in range(maml_itr + 1, self.num_maml_steps_per_iter):
                                logger.record_tabular('DynTrajs%ia-AverageReturn' % (i+1), 0.0)
                                logger.record_tabular('DynTrajs%ib-AverageReturn' % (i+1), 0.0)
                            break

                        ''' --------------- Meta Policy Update --------------- '''

                        logger.log("MAML Step %i of %i - Optimizing policy..." % (maml_itr + 1, self.num_maml_steps_per_iter))
                        time_outer_step_start = time.time()

                        # This needs to take all samples_data so that it can construct graph for meta-optimization.
                        self.optimize_policy(itr, all_samples_data_maml_iter, log=False)
                        if itr == 0: sess.graph.finalize()

                        times_outer_step.append(time.time() - time_outer_step_start)



                    ''' --------------- Logging Stuff --------------- '''
                    logger.record_tabular("KL-pre-post", np.mean(kl_pre_post))
                    logger.record_tabular("Variance Models", np.mean(model_std))

                    logger.record_tabular('Time-MAMLSteps', time.time() - time_maml_steps_start)
                    logger.record_tabular('Time-DynSampling', np.mean(times_dyn_sampling))
                    logger.record_tabular('Time-DynSampleProc', np.mean(times_dyn_sample_processing))
                    logger.record_tabular('Time-InnerStep', np.mean(times_inner_step))
                    logger.record_tabular('Time-OuterStep', np.mean(times_outer_step))


                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, all_samples_data_maml_iter[-1])  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = all_samples_data_maml_iter[-1]["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time-Overall', time.time() - start_time)
                    logger.record_tabular('Time-Itr', time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)


            self.shutdown_worker()
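
# --- Hedged sketch (not from the original repo): the rolling-mean stopping rule
# used in the MAML-step loop above, isolated as a standalone helper so the early
# stopping behaviour is easy to test. Names and test values are illustrative.
def rolling_mean_should_stop(mean_rewards, alpha=0.8):
    """Return True once the exponential moving average of the rewards decreases."""
    rolling = prev_rolling = None
    for reward in mean_rewards:
        if rolling is None:
            rolling = prev_rolling = reward
        else:
            prev_rolling = rolling
            rolling = alpha * rolling + (1 - alpha) * reward
        if rolling < prev_rolling:
            return True
    return False

# A late dip in the per-step mean reward triggers the stop; a monotone increase does not.
assert rolling_mean_should_stop([1.0, 1.2, 1.3, 0.2]) is True
assert rolling_mean_should_stop([1.0, 1.2, 1.3, 1.4]) is False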
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]

        with tf.Session() as sess:
            # Code for loading a previous policy. Somewhat hacky because it needs to run inside the session.
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = []
            for var in tf.global_variables():
                # note - this is hacky; there may be a better way to do this in newer TF.
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))

            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Sampling set of tasks/goals for this meta-batch...")

                    env = self.env
                    while 'sample_goals' not in dir(env):
                        env = env.wrapped_env
                    learner_env_goals = env.sample_goals(self.meta_batch_size)

                    self.policy.switch_to_init_dist()  # Switch to pre-update policy

                    all_samples_data, all_paths = [], []
                    for step in range(self.num_grad_updates+1):
                        #if step > 0:
                        #    import pdb; pdb.set_trace() # test param_vals functions.
                        logger.log('** Step ' + str(step) + ' **')
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(itr, reset_args=learner_env_goals, log_prefix=str(step))
                        all_paths.append(paths)
                        logger.log("Processing samples...")
                        samples_data = {}
                        for key in paths.keys():  # the keys are the tasks
                            # don't log because this will spam the console with every task.
                            samples_data[key] = self.process_samples(itr, paths[key], log=False)
                        all_samples_data.append(samples_data)
                        # for logging purposes only
                        self.process_samples(itr, flatten_list(paths.values()), prefix=str(step), log=True)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(flatten_list(paths.values()), prefix=str(step))
                        if step < self.num_grad_updates:
                            logger.log("Computing policy updates...")
                            self.policy.compute_updated_dists(samples_data)


                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    self.optimize_policy(itr, all_samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, all_samples_data[-1])  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = all_samples_data[-1]["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

                    # The rest is some example plotting code.
                    # Plotting code is useful for visualizing trajectories across a few different tasks.
                    if False and itr % 2 == 0 and self.env.observation_space.shape[0] <= 4: # point-mass
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            plt.plot(learner_env_goals[ind][0], learner_env_goals[ind][1], 'k*', markersize=10)
                            plt.hold(True)

                            preupdate_paths = all_paths[0]
                            postupdate_paths = all_paths[-1]

                            pre_points = preupdate_paths[ind][0]['observations']
                            post_points = postupdate_paths[ind][0]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-b', linewidth=1)

                            pre_points = preupdate_paths[ind][1]['observations']
                            post_points = postupdate_paths[ind][1]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '--r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '--b', linewidth=1)

                            pre_points = preupdate_paths[ind][2]['observations']
                            post_points = postupdate_paths[ind][2]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-.r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-.b', linewidth=1)

                            plt.plot(0,0, 'k.', markersize=5)
                            plt.xlim([-0.8, 0.8])
                            plt.ylim([-0.8, 0.8])
                            plt.legend(['goal', 'preupdate path', 'postupdate path'])
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'prepost_path'+str(ind)+'.png'))
                    elif False and itr % 2 == 0:  # swimmer or cheetah
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            goal_vel = learner_env_goals[ind]
                            plt.title('Swimmer paths, goal vel='+str(goal_vel))
                            plt.hold(True)

                            prepathobs = all_paths[0][ind][0]['observations']
                            postpathobs = all_paths[-1][ind][0]['observations']
                            plt.plot(prepathobs[:,0], prepathobs[:,1], '-r', linewidth=2)
                            plt.plot(postpathobs[:,0], postpathobs[:,1], '--b', linewidth=1)
                            plt.plot(prepathobs[-1,0], prepathobs[-1,1], 'r*', markersize=10)
                            plt.plot(postpathobs[-1,0], postpathobs[-1,1], 'b*', markersize=10)
                            plt.xlim([-1.0, 5.0])
                            plt.ylim([-1.0, 1.0])

                            plt.legend(['preupdate path', 'postupdate path'], loc=2)
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'swim1d_prepost_itr'+str(itr)+'_id'+str(ind)+'.pdf'))
        self.shutdown_worker()
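
The two train() variants above share the same MAML control flow: reset to the pre-update policy, alternate pre-update sampling with per-task inner updates, then meta-optimize on all collected samples. A minimal sketch of that skeleton in plain Python, with stub callables standing in for the rllab machinery (all names here are illustrative, not from the original code):

def maml_outer_loop(policy, sample, inner_update, meta_update, n_itr, num_grad_updates):
    """Structure of the MAML training loop: K inner adaptation steps, one meta step."""
    for itr in range(n_itr):
        policy.switch_to_init_dist()              # back to the pre-update parameters
        all_samples_data = []
        for step in range(num_grad_updates + 1):
            samples_data = sample(itr, step)      # pre-update rollouts for step < K, post-update for step == K
            all_samples_data.append(samples_data)
            if step < num_grad_updates:
                inner_update(samples_data)        # per-task adaptation (compute_updated_dists above)
        meta_update(itr, all_samples_data)        # outer step needs both pre- and post-update data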
Beispiel #27
0
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        self.initialize_unitialized_variables(sess)

        self.start_worker()
        start_time = time.time()
        n_env_timesteps = 0

        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()

            with logger.prefix('itr #%d | ' % itr):
                if self.initial_random_samples and itr == 0:
                    logger.log(
                        "Obtaining random samples from the environment...")
                    new_env_paths = self.obtain_random_samples(itr, log=True)

                    n_env_timesteps += self.initial_random_samples
                    logger.record_tabular("n_timesteps", n_env_timesteps)

                    samples_data_dynamics = self.random_sampler.process_samples(
                        itr, new_env_paths, log=True, log_prefix='EnvTrajs-'
                    )  # must log in the same way as the model sampler below
                else:
                    logger.log(
                        "Obtaining samples from the environment using the policy..."
                    )
                    new_env_paths = self.obtain_env_samples(itr)

                    n_env_timesteps += self.batch_size
                    logger.record_tabular("n_timesteps", n_env_timesteps)

                    logger.log("Processing environment samples...")
                    # first processing just for logging purposes
                    self.process_samples_for_dynamics(itr,
                                                      new_env_paths,
                                                      log=True,
                                                      log_prefix='EnvTrajs-')

                    new_samples_data_dynamics = self.process_samples_for_dynamics(
                        itr, new_env_paths)
                    for k, v in samples_data_dynamics.items():
                        samples_data_dynamics[k] = np.concatenate(
                            [v, new_samples_data_dynamics[k]],
                            axis=0)[-MAX_BUFFER:]

                epochs = self.dynamic_model_max_epochs[min(
                    itr,
                    len(self.dynamic_model_max_epochs) - 1)]
                # fit dynamics model
                if self.reinit_model and itr % self.reinit_model == 0:
                    self.dynamics_model.reinit_model()
                    epochs = self.dynamic_model_max_epochs[0]
                logger.log("Training dynamics model for %i epochs ..." %
                           (epochs))
                self.dynamics_model.fit(
                    samples_data_dynamics['observations_dynamics'],
                    samples_data_dynamics['actions_dynamics'],
                    samples_data_dynamics['next_observations_dynamics'],
                    epochs=epochs,
                    verbose=True)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, None)  # , **kwargs)
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        if created_session:
            sess.close()
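
Beispiel #27 grows its dynamics-model dataset every iteration by concatenating the newly processed transitions onto the old ones and keeping only the newest MAX_BUFFER entries. A minimal sketch of that buffer update for NumPy arrays keyed by name (the cap value and helper name here are placeholders, not the original constants):

import numpy as np

MAX_BUFFER = 100000  # placeholder cap; the original constant is defined elsewhere in the module

def update_dynamics_buffer(buffer_data, new_data, max_size=MAX_BUFFER):
    """Append new transition arrays and truncate each to the newest max_size rows."""
    for key, new_values in new_data.items():
        buffer_data[key] = np.concatenate([buffer_data[key], new_values], axis=0)[-max_size:]
    return buffer_data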
Beispiel #28
0
    def optimize_policy(self):

        iteration = self.policy_opt_params["max_iters"]
        cost_np_vec = self.cost_np_vec
        algo = self.algo_policy
        real_env = self.env
        """ Re-initialize Policy std parameters. """
        if self.non_increase_counter == self.reset_non_increasing:
            self.tf_sess.run(
                tf.variables_initializer(tf.global_variables(self.scope_name)))
            self.non_increase_counter = 0
            self.min_validation_cost = np.inf

        logging.debug("Before reset policy std %s " % np.array2string(
            np.exp(self.training_policy._l_std_param.param.eval()),
            formatter={'float_kind': '{0:.5f}'.format}))
        self.tf_sess.run([self.reset_op])
        """ Optimize policy via rllab. """

        min_iter = self.min_iters
        min_validation_cost = np.inf  # self.min_validation_cost
        min_idx = 0
        mean_validation_costs, real_validation_costs = [], []
        reset_idx = np.arange(len(self.policy_validation_reset_init))

        for j in range(iteration):
            np.random.shuffle(reset_idx)
            reset_val = reset_idx[:len(self.policy_validation_reset_init) //
                                  20]
            algo.start_worker()

            with rllab_logger.prefix('itr #%d | ' % int(j + 1)):
                paths = algo.obtain_samples(j)
                samples_data = algo.process_samples(j, paths)
                algo.optimize_policy(j, samples_data)
            """ Do validation cost """

            if (j + 1) % self.policy_opt_params["log_every"] == 0:

                if self.bnn_model:
                    estimate_validation_cost = self._evaluate_cost_bnn_env(
                        self.bnn_model, self.n_timestep,
                        self.policy_validation_init)
                else:
                    estimate_validation_cost = evaluate_fixed_init_trajectories_2(
                        real_env,
                        self.session_policy_out,
                        self.policy_validation_reset_init[reset_val],
                        cost_np_vec,
                        self.tf_sess,
                        max_timestep=self.n_timestep,
                        gamma=1.00,
                    )

                mean_validation_cost = np.mean(estimate_validation_cost)
                validation_cost = mean_validation_cost

                np.random.shuffle(reset_idx)
                real_validation_cost = evaluate_fixed_init_trajectories_2(
                    real_env,
                    self.session_policy_out,
                    self.policy_validation_reset_init[reset_val],
                    cost_np_vec,
                    self.tf_sess,
                    max_timestep=self.n_timestep,
                    gamma=1.00)
                real_validation_costs.append(real_validation_cost)
                mean_validation_costs.append(mean_validation_cost)

                logging.info('iter %d' % j)
                logging.info(
                    "%s\n"
                    "\tVal cost:\t%.3f\n"
                    "\tReal cost:\t%.3f\n" % (np.array2string(
                        estimate_validation_cost,
                        formatter={'float_kind': '{0:.5f}'.format
                                   }), validation_cost, real_validation_cost))
                """ Store current best policy """
                if validation_cost < min_validation_cost:
                    min_idx = j
                    min_validation_cost = validation_cost

                    # Save
                    logging.info('\tSaving policy')
                    self.policy_saver.save(self.tf_sess,
                                           os.path.join(
                                               self.log_dir, 'policy.ckpt'),
                                           write_meta_graph=False)

                if j - min_idx > min_iter and mean_validation_cost - min_validation_cost > 1.0:  # tolerance
                    break
        """ Log and restore """
        logging.info(
            "Stop at iteration %d and restore the current best at %d: %.3f" %
            (j + 1, min_idx + 1, min_validation_cost))
        self.policy_saver.restore(self.tf_sess,
                                  os.path.join(self.log_dir, 'policy.ckpt'))

        min_real_cost = min(real_validation_costs)
        if min_real_cost < self.min_validation_cost:
            self.min_validation_cost = min_real_cost
            self.non_increase_counter = 0
        else:
            self.non_increase_counter += 1

        real_final_cost = evaluate_fixed_init_trajectories_2(
            real_env,
            self.session_policy_out,
            self.policy_validation_reset_init,
            cost_np_vec,
            self.tf_sess,
            max_timestep=self.n_timestep,
            gamma=1.00)
        real_validation_costs.append(real_final_cost)

        logging.info("Final Real cost: %.3f" % real_final_cost)

        logging.info("Best in all iters %.3f, non increasing in %d" %
                     (self.min_validation_cost, self.non_increase_counter))

        return mean_validation_costs, real_validation_costs
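
Beispiel #28 keeps the best policy seen so far by saving a checkpoint whenever the validation cost improves and restoring that checkpoint once the loop stops. A minimal TF1-style sketch of that pattern (function and path names are illustrative):

import os

import numpy as np
import tensorflow as tf


def optimize_with_best_checkpoint(sess, train_step, validation_cost, log_dir, max_iters):
    """Run train_step, track the lowest validation cost, and restore that checkpoint at the end."""
    saver = tf.train.Saver()
    ckpt_path = os.path.join(log_dir, 'policy.ckpt')
    best_cost, best_iter = np.inf, 0
    for j in range(max_iters):
        train_step(j)
        cost = validation_cost()
        if cost < best_cost:
            best_cost, best_iter = cost, j
            saver.save(sess, ckpt_path, write_meta_graph=False)
    saver.restore(sess, ckpt_path)
    return best_cost, best_iter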
Beispiel #29
0
    def train(self):

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            if self.load_policy is not None:
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars (I know, it's ugly)

            uninit_vars = []
            for var in tf.all_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.initialize_variables(uninit_vars))
            #sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                if self.comet_logger:
                    self.comet_logger.set_step(itr + self.outer_iteration)
                if itr == self.n_itr - 1:
                    self.policy.std_modifier = 0.00001
                    #self.policy.std_modifier = 1
                    self.policy.recompute_dist_for_adjusted_std()
                if itr in self.goals_for_ET_dict.keys():
                    # self.policy.std_modifier = 0.0001
                    # self.policy.recompute_dist_for_adjusted_std()
                    goals = self.goals_for_ET_dict[itr]
                    noise = self.action_noise_test
                    self.batch_size = self.batch_size_expert_traj
                else:
                    if self.reset_arg is None:
                        goals = [None]
                    else:
                        goals = [self.reset_arg]
                    noise = self.action_noise_train
                    self.batch_size = self.batch_size_train
                paths_to_save = {}
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.log("Obtaining samples...")

                    preupdate = True if itr < self.n_itr - 1 else False
                    # paths_for_goal = self.obtain_samples(itr=itr, reset_args=[{'goal': goal, 'noise': noise}])  # when using oracle environments with changing noise, use this line!
                    paths = self.obtain_samples(itr=itr,
                                                reset_args=[self.reset_arg],
                                                preupdate=preupdate)

                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)

                    #new_param_values = self.policy.get_variable_values(self.policy.all_params)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr,
                                                   samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]

                    logger.save_itr_params(itr,
                                           samples_data["paths"],
                                           file_name=str(self.reset_arg) +
                                           '.pkl')
                    # print(samples_data['paths'])
                    logger.log("Saved")
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)

                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

        if self.log_dir is not None:
            logger.remove_tabular_output(self.log_dir + '/progress.csv')
        self.shutdown_worker()
Beispiel #30
0
    def train(self, sess=None):
        global_step = tf.Variable(0,
                                  name='global_step',
                                  trainable=False,
                                  dtype=tf.int32)
        increment_global_step_op = tf.assign(global_step, global_step + 1)
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
        sess.run(tf.global_variables_initializer())
        if self.qf is not None:
            self.pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_dim=self.env.observation_space.flat_dim,
                action_dim=self.env.action_space.flat_dim,
                replacement_prob=self.replacement_prob,
                env=self.env,
            )
        self.start_worker()
        self.init_opt()
        # This initializes the optimizer parameters
        sess.run(tf.global_variables_initializer())
        if self.restore_auto: self.restore()
        itr = sess.run(global_step)
        start_time = time.time()
        t0 = time.time()
        while itr < self.n_itr:
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Mem: %f" % memory_usage_resource())
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                if self.qf is not None:
                    logger.log("Adding samples to replay pool...")
                    self.add_pool(itr, paths, self.pool)
                    logger.log("Optimizing critic before policy...")
                    self.optimize_critic(itr, self.pool, samples_data)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                self.log_critic_training()
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
                if time.time() - t0 > 10:
                    gc.collect()
                    t0 = time.time()
                itr = sess.run(increment_global_step_op)
                if self.save_freq > 0 and (itr - 1) % self.save_freq == 0:
                    self.save()

        self.shutdown_worker()
        if created_session:
            sess.close()
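
Beispiel #30 drives its loop with a non-trainable global_step variable and an explicit increment op, so the iteration counter lives in the graph and survives a save/restore. The essential pattern, distilled (TF1-style, as in the snippets; the loop body is elided):

import tensorflow as tf

global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
increment_global_step = tf.assign(global_step, global_step + 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    itr = sess.run(global_step)           # 0 on a fresh run, or the restored value
    while itr < 3:
        # ... one training iteration goes here ...
        itr = sess.run(increment_global_step)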
Beispiel #31
0
    def train(self):
        with tf.Session() as sess:
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars (I know, it's ugly)
            uninit_vars = []
            for var in tf.all_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.initialize_variables(uninit_vars))
            #sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    #new_param_values = self.policy.get_variable_values(self.policy.all_params)

                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)

                    #import pickle
                    #with open('paths_itr'+str(itr)+'.pkl', 'wb') as f:
                    #    pickle.dump(paths, f)

                    # debugging
                    """
                    if itr % 1 == 0:
                        logger.log("Saving visualization of paths")
                        import matplotlib.pyplot as plt;
                        for ind in range(5):
                            plt.clf(); plt.hold(True)
                            points = paths[ind]['observations']
                            plt.plot(points[:,0], points[:,1], '-r', linewidth=2)
                            plt.xlim([-1.0, 1.0])
                            plt.ylim([-1.0, 1.0])
                            plt.legend(['path'])
                            plt.savefig('/home/cfinn/path'+str(ind)+'.png')
                    """
                    # end debugging

                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        self.shutdown_worker()
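
Several of these snippets repeat the same inline trick for initializing only the variables that were not loaded from a snapshot: try to read each variable and collect the ones that raise FailedPreconditionError. Extracted into a helper (TF1-style; the helper name is illustrative):

import tensorflow as tf


def initialize_uninitialized_variables(sess):
    """Initialize only those global variables that do not have a value yet."""
    uninit_vars = []
    for var in tf.global_variables():
        try:
            sess.run(var)  # raises FailedPreconditionError if var is uninitialized
        except tf.errors.FailedPreconditionError:
            uninit_vars.append(var)
    sess.run(tf.variables_initializer(uninit_vars))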
Beispiel #32
0
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
        if not self.transfer:
            sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                # self.env._wrapped_env.generate_grid=True
                # self.env._wrapped_env.generate_b0_start_goal=True
                # self.env.reset()
                # self.env._wrapped_env.generate_grid=False
                # self.env._wrapped_env.generate_b0_start_goal=False
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)

                if self.record_rewards:
                    logger.log("recording rewards...")
                    undiscounted_returns = [
                        sum(path["rewards"]) for path in paths
                    ]
                    average_discounted_return = np.mean(
                        [path["returns"][0] for path in paths])
                    AverageReturn = np.mean(undiscounted_returns)
                    StdReturn = np.std(undiscounted_returns)
                    MaxReturn = np.max(undiscounted_returns)
                    MinReturn = np.min(undiscounted_returns)
                    self.rewards['average_discounted_return'].append(
                        average_discounted_return)
                    self.rewards['AverageReturn'].append(AverageReturn)
                    self.rewards['StdReturn'].append(StdReturn)
                    self.rewards['MaxReturn'].append(MaxReturn)
                    self.rewards['MinReturn'].append(MinReturn)
                    print("AverageReturn: ", AverageReturn)
                    print("MaxReturn: ", MaxReturn)
                    print("MinReturn: ", MinReturn)
                    # print("returns: ",samples_data["returns"])
                    # print("valids: ",samples_data["valids"])

                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        if created_session:
            sess.close()
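
The reward bookkeeping in Beispiel #32 reduces to computing summary statistics of the undiscounted returns of a batch of paths. A small self-contained version (the helper name is illustrative):

import numpy as np


def return_statistics(paths):
    """Summary statistics of the undiscounted returns of a batch of paths."""
    undiscounted_returns = [sum(path["rewards"]) for path in paths]
    return {
        "AverageReturn": np.mean(undiscounted_returns),
        "StdReturn": np.std(undiscounted_returns),
        "MaxReturn": np.max(undiscounted_returns),
        "MinReturn": np.min(undiscounted_returns),
    }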
Beispiel #33
0
    def _train(self, rank, shared_dict):
        # Initialize separate exemplar per process
        if self.exemplar_cls is not None:
            self.exemplar = self.exemplar_cls(**self.exemplar_args)

        self.init_rank(rank)
        self.init_shared_dict(shared_dict)
        if self.rank == 0:
            start_time = time.time()

        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                self.update_algo_params(itr)
                if rank == 0:
                    logger.log("Collecting samples ...")
                paths = self.sampler.obtain_samples()
                if rank == 0:
                    logger.log("Processing paths...")
                self.process_paths(paths)

                if rank == 0:
                    logger.log("processing samples...")
                    if self.plot_exemplar:
                        if 'bonus_rewards' in paths[0]:
                            log_paths(paths[:20], 'traj_rewards', itr=itr)

                samples_data, dgnstc_data = self.sampler.process_samples(paths)

                self.log_diagnostics(itr, samples_data,
                                     dgnstc_data)  # (parallel)
                if rank == 0:
                    logger.log("optimizing policy...")

                if self.path_replayer is not None:
                    replayed_paths = self.path_replayer.replay_paths()
                    if len(replayed_paths) > 0:
                        self.process_paths(replayed_paths)
                        replayed_samples_data, _ = self.sampler.process_samples(
                            replayed_paths)
                        samples_data = self.sampler.combine_samples(
                            [samples_data, replayed_samples_data])
                    self.path_replayer.record_paths(paths)
                self.optimize_policy(itr, samples_data)  # (parallel)
                if rank == 0:
                    logger.log("fitting baseline...")
                # self.baseline.fit_by_samples_data(samples_data)  # (parallel)
                self.baseline.fit(paths)
                if rank == 0:
                    logger.log("fitted")
                    logger.log("saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    params["algo"] = self
                    if self.store_paths:
                        # NOTE: Only paths from rank==0 worker will be saved.
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("saved")

                    logger.record_tabular("ElapsedTime",
                                          time.time() - start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
                if self.log_memory_usage:
                    process = psutil.Process(os.getpid())
                    print("Process %d memory usage: %.4f GB" %
                          (rank, process.memory_info().rss / (1024**3)))
                    if self.rank == 0 and sys.platform == "linux":
                        print("Shared memory usage: %.4f GB" %
                              (process.memory_info().shared / (1024**3)))
                self.current_itr = itr + 1
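
# --- Hedged sketch (names illustrative): the per-process memory logging used at
# the end of _train above, extracted into a helper. The .shared field is only
# available on Linux, hence the platform check.
import os
import sys

import psutil


def log_memory_usage(rank):
    process = psutil.Process(os.getpid())
    print("Process %d memory usage: %.4f GB" % (rank, process.memory_info().rss / (1024 ** 3)))
    if rank == 0 and sys.platform == "linux":
        print("Shared memory usage: %.4f GB" % (process.memory_info().shared / (1024 ** 3)))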
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # with tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) as sess:
        with tf.Session(config=config) as sess:
            tf.set_random_seed(1)
            # Code for loading a previous policy. Somewhat hacky because it needs to run inside the session.
            if self.load_policy is not None:
                self.policy = joblib.load(self.load_policy)['policy']

            self.init_opt()
            self.init_experts_opt()
            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = []
            # sess.run(tf.global_variables_initializer())
            for var in tf.global_variables():
                # note - this is hacky; there may be a better way to do this in newer TF.
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))
            self.start_worker()
            start_time = time.time()
            self.metaitr = 0

            self.expertLearning_itrs = [30 * i for i in range(100)]

            expertPaths = []
            for itr in range(self.start_itr, self.n_itr):

                if itr in self.expertLearning_itrs:
                    expertPathsDict = self.trainExperts(self.expert_num_itrs)

                # trainIndices = np.random.choice(np.arange(0, len(self.trainGoals)), self.meta_batch_size, replace = False)
                # curr_trainGoals = self.trainGoals[trainIndices]
                # curr_expertPaths = {i : expertPathsDict[key] for i, key in enumerate(trainIndices)}
                curr_trainGoals = self.trainGoals
                curr_expertPaths = expertPathsDict

                itr_start_time = time.time()
                np.random.seed(self.seed + itr)
                tf.set_random_seed(self.seed + itr)
                rd.seed(self.seed + itr)
                with logger.prefix('itr #%d | ' % itr):
                    all_paths_for_plotting = []
                    all_postupdate_paths = []
                    self.beta_steps = min(
                        self.beta_steps,
                        self.beta_curve[min(itr,
                                            len(self.beta_curve) - 1)])
                    beta_steps_range = range(
                        self.beta_steps
                    ) if itr not in self.testing_itrs else range(
                        self.test_goals_mult)
                    beta0_step0_paths = None
                    num_inner_updates = self.num_grad_updates_for_testing if itr in self.testing_itrs else self.num_grad_updates

                    for beta_step in beta_steps_range:
                        all_samples_data_for_betastep = []
                        print("debug, pre-update std modifier")
                        self.policy.std_modifier = self.pre_std_modifier

                        self.policy.switch_to_init_dist()
                        self.policy.perTask_switch_to_init_dist(
                        )  # Switch to pre-update policy

                        if itr in self.testing_itrs:

                            # env = self.env
                            # while 'sample_goals' not in dir(env):

                            #     env = env.wrapped_env
                            #if self.test_on_training_goals:

                            goals_to_use = curr_trainGoals
                            # else:
                            #     goals_to_use = env.sample_goals(self.meta_batch_size)

                        for step in range(num_inner_updates + 1):  # inner loop
                            logger.log('** Betastep %s ** Step %s **' %
                                       (str(beta_step), str(step)))
                            logger.log("Obtaining samples...")

                            if itr in self.testing_itrs:
                                if step < num_inner_updates:
                                    print(
                                        'debug12.0.0, test-time sampling step=',
                                        step)  #, goals_to_use)
                                    paths = self.obtain_samples(
                                        itr=itr,
                                        reset_args=goals_to_use,
                                        log_prefix=str(beta_step) + "_" +
                                        str(step),
                                        testitr=True,
                                        preupdate=True,
                                        mode='vec')

                                    paths = store_agent_infos(
                                        paths
                                    )  # agent_infos_orig is populated here

                                elif step == num_inner_updates:
                                    print(
                                        'debug12.0.1, test-time sampling step=',
                                        step)  #, goals_to_use)

                                    paths = self.obtain_samples(
                                        itr=itr,
                                        reset_args=goals_to_use,
                                        log_prefix=str(beta_step) + "_" +
                                        str(step),
                                        testitr=True,
                                        preupdate=False,
                                        mode=self.updateMode)

                                    all_postupdate_paths.extend(paths.values())

                            elif self.expert_trajs_dir is None or (
                                    beta_step == 0
                                    and step < num_inner_updates):
                                print("debug12.1, regular sampling"
                                      )  #, self.goals_to_use_dict[itr])

                                paths = self.obtain_samples(
                                    itr=itr,
                                    reset_args=curr_trainGoals,
                                    log_prefix=str(beta_step) + "_" +
                                    str(step),
                                    preupdate=True,
                                    mode='vec')

                                if beta_step == 0 and step == 0:
                                    paths = store_agent_infos(
                                        paths
                                    )  # agent_infos_orig is populated here
                                    beta0_step0_paths = deepcopy(paths)
                            elif step == num_inner_updates:
                                print("debug12.2, expert traj")
                                paths = curr_expertPaths

                            else:
                                assert False, "we shouldn't be able to get here"

                            all_paths_for_plotting.append(paths)
                            logger.log("Processing samples...")
                            samples_data = {}

                            for tasknum in paths.keys(
                            ):  # the keys are the tasks
                                # don't log because this will spam the console with every task.

                                if self.use_maml_il and step == num_inner_updates:
                                    fast_process = True
                                else:
                                    fast_process = False
                                if itr in self.testing_itrs:
                                    testitr = True
                                else:
                                    testitr = False
                                samples_data[tasknum] = self.process_samples(
                                    itr,
                                    paths[tasknum],
                                    log=False,
                                    fast_process=fast_process,
                                    testitr=testitr,
                                    metalearn_baseline=self.metalearn_baseline)

                            all_samples_data_for_betastep.append(samples_data)

                            # for logging purposes
                            self.process_samples(
                                itr,
                                flatten_list(paths.values()),
                                prefix=str(step),
                                log=True,
                                fast_process=True,
                                testitr=testitr,
                                metalearn_baseline=self.metalearn_baseline)
                            if itr in self.testing_itrs:
                                self.log_diagnostics(flatten_list(
                                    paths.values()),
                                                     prefix=str(step))

                            if step == num_inner_updates:
                                #logger.record_tabular("AverageReturnLastTest", self.parallel_sampler.memory["AverageReturnLastTest"], front=True)  # TODO: add functionality for multiple grad steps
                                logger.record_tabular(
                                    "TestItr", ("1" if testitr else "0"),
                                    front=True)
                                logger.record_tabular("MetaItr",
                                                      self.metaitr,
                                                      front=True)

                            if step == num_inner_updates - 1:
                                if itr not in self.testing_itrs:
                                    print(
                                        "debug, post update train std modifier"
                                    )
                                    self.policy.std_modifier = self.post_std_modifier_train * self.policy.std_modifier
                                else:
                                    print(
                                        "debug, post update test std modifier")
                                    self.policy.std_modifier = self.post_std_modifier_test * self.policy.std_modifier
                                if (itr in self.testing_itrs
                                        or not self.use_maml_il
                                        or step < num_inner_updates - 1
                                    ) and step < num_inner_updates:
                                    # do not update on last grad step, and do not update on second to last step when training MAMLIL
                                    logger.log("Computing policy updates...")
                                    self.policy.compute_updated_dists(
                                        samples=samples_data)

                        logger.log("Optimizing policy...")
                        # This needs to take all samples_data so that it can construct graph for meta-optimization.
                        start_loss = self.optimize_policy(
                            itr, all_samples_data_for_betastep)

                    if itr not in self.testing_itrs:
                        self.metaitr += 1
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(
                        itr, all_samples_data_for_betastep[-1])  # , **kwargs)
                    print("debug123, params", params)
                    if self.store_paths:
                        params["paths"] = all_samples_data_for_betastep[-1][
                            "paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)

                    #self.plotTrajs(itr, all_paths_for_plotting)
        self.shutdown_worker()
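
# --- Hedged sketch: several train() variants in these examples re-define the
# flatten_list lambda inline and flag it with "TODO - make this a util"; this is
# the obvious module-level version of that utility.
def flatten_list(nested):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return [item for sublist in nested for item in sublist]


assert flatten_list([[1, 2], [3]]) == [1, 2, 3]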
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]

        with tf.Session() as sess:
            # Code for loading a previous policy. Somewhat hacky because it needs to run inside the session.
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = []
            for var in tf.global_variables():
                # note - this is hacky; there may be a better way to do this in newer TF.
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))

            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log(
                        "Sampling set of tasks/goals for this meta-batch...")

                    env = self.env
                    while 'sample_goals' not in dir(env):
                        env = env.wrapped_env
                    learner_env_goals = env.sample_goals(self.meta_batch_size)

                    self.policy.switch_to_init_dist(
                    )  # Switch to pre-update policy

                    all_samples_data, all_paths = [], []
                    for step in range(self.num_grad_updates + 1):
                        #if step > 0:
                        #    import pdb; pdb.set_trace() # test param_vals functions.
                        logger.log('** Step ' + str(step) + ' **')
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(
                            itr,
                            reset_args=learner_env_goals,
                            log_prefix=str(step))
                        all_paths.append(paths)
                        logger.log("Processing samples...")
                        samples_data = {}
                        for key in paths.keys():  # the keys are the tasks
                            # don't log because this will spam the console with every task.
                            samples_data[key] = self.process_samples(
                                itr, paths[key], log=False)
                        all_samples_data.append(samples_data)
                        # for logging purposes only
                        self.process_samples(itr,
                                             flatten_list(paths.values()),
                                             prefix=str(step),
                                             log=True)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(flatten_list(paths.values()),
                                             prefix=str(step))
                        if step < self.num_grad_updates:
                            logger.log("Computing policy updates...")
                            self.policy.compute_updated_dists(samples_data)

                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    self.optimize_policy(itr, all_samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(
                        itr, all_samples_data[-1])  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = all_samples_data[-1]["paths"]
                    #logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

                    # The rest is some example plotting code.
                    # Plotting code is useful for visualizing trajectories across a few different tasks.
                    if False and itr % 2 == 0 and self.env.observation_space.shape[
                            0] <= 4:  # point-mass
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            plt.plot(learner_env_goals[ind][0],
                                     learner_env_goals[ind][1],
                                     'k*',
                                     markersize=10)
                            plt.hold(True)

                            preupdate_paths = all_paths[0]
                            postupdate_paths = all_paths[-1]

                            pre_points = preupdate_paths[ind][0][
                                'observations']
                            post_points = postupdate_paths[ind][0][
                                'observations']
                            plt.plot(pre_points[:, 0],
                                     pre_points[:, 1],
                                     '-r',
                                     linewidth=2)
                            plt.plot(post_points[:, 0],
                                     post_points[:, 1],
                                     '-b',
                                     linewidth=1)

                            pre_points = preupdate_paths[ind][1][
                                'observations']
                            post_points = postupdate_paths[ind][1][
                                'observations']
                            plt.plot(pre_points[:, 0],
                                     pre_points[:, 1],
                                     '--r',
                                     linewidth=2)
                            plt.plot(post_points[:, 0],
                                     post_points[:, 1],
                                     '--b',
                                     linewidth=1)

                            pre_points = preupdate_paths[ind][2][
                                'observations']
                            post_points = postupdate_paths[ind][2][
                                'observations']
                            plt.plot(pre_points[:, 0],
                                     pre_points[:, 1],
                                     '-.r',
                                     linewidth=2)
                            plt.plot(post_points[:, 0],
                                     post_points[:, 1],
                                     '-.b',
                                     linewidth=1)

                            plt.plot(0, 0, 'k.', markersize=5)
                            plt.xlim([-0.8, 0.8])
                            plt.ylim([-0.8, 0.8])
                            plt.legend(
                                ['goal', 'preupdate path', 'postupdate path'])
                            plt.savefig(
                                osp.join(logger.get_snapshot_dir(),
                                         'prepost_path' + str(ind) + '.png'))
                    elif False and itr % 2 == 0:  # swimmer or cheetah
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            goal_vel = learner_env_goals[ind]
                            plt.title('Swimmer paths, goal vel=' +
                                      str(goal_vel))
                            # (plt.hold was removed in matplotlib >= 3.0; overlaying is the default, so no call is needed)

                            prepathobs = all_paths[0][ind][0]['observations']
                            postpathobs = all_paths[-1][ind][0]['observations']
                            plt.plot(prepathobs[:, 0],
                                     prepathobs[:, 1],
                                     '-r',
                                     linewidth=2)
                            plt.plot(postpathobs[:, 0],
                                     postpathobs[:, 1],
                                     '--b',
                                     linewidth=1)
                            plt.plot(prepathobs[-1, 0],
                                     prepathobs[-1, 1],
                                     'r*',
                                     markersize=10)
                            plt.plot(postpathobs[-1, 0],
                                     postpathobs[-1, 1],
                                     'b*',
                                     markersize=10)
                            plt.xlim([-1.0, 5.0])
                            plt.ylim([-1.0, 1.0])

                            plt.legend(['preupdate path', 'postupdate path'],
                                       loc=2)
                            plt.savefig(
                                osp.join(
                                    logger.get_snapshot_dir(),
                                    'swim1d_prepost_itr' + str(itr) + '_id' +
                                    str(ind) + '.pdf'))
        self.shutdown_worker()
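
The loop above interleaves `num_grad_updates` inner adaptation steps with a single outer meta-update per iteration. A minimal, self-contained sketch of that control flow (using hypothetical `sample_paths`, `adapt`, and `meta_update` callables rather than the rllab/MAML API) might look like this:

def meta_train_sketch(tasks, n_itr, num_grad_updates,
                      sample_paths, adapt, meta_update):
    """Sketch only: mirrors the inner/outer loop structure of the example above."""
    for itr in range(n_itr):
        all_samples = []                                 # one dict of per-task samples per step
        params = {task: None for task in tasks}          # None means "pre-update policy"
        for step in range(num_grad_updates + 1):
            samples = {task: sample_paths(task, params[task]) for task in tasks}
            all_samples.append(samples)
            if step < num_grad_updates:
                # inner adaptation: per-task updated parameters from this step's samples
                params = {task: adapt(task, samples[task]) for task in tasks}
        # outer step: meta-optimize using the samples collected at every inner step
        meta_update(all_samples)
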
Beispiel #36
0
    def curriculum_train(self, curriculum):
        from collections import defaultdict
        from rllab.misc.evaluate import evaluate
        import numpy as np

        task_dist = np.ones(len(curriculum.tasks))
        task_dist[0] = len(curriculum.tasks)
        min_reward = np.inf
        task_eval_reward = defaultdict(float)
        task_counts = defaultdict(int)
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            start_time = time.time()
            start_itr = self.start_itr
            end_itr = self.n_itr
            # while True:
            schedule = {
                'i_curr': [
                    0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3,
                    4, 5, 6, 7
                ],
                'iter': [
                    100, 100, 100, 100, 200, 200, 300, 300, 100, 100, 100, 100,
                    200, 200, 300, 300, 100, 100, 100, 100, 200, 200, 300, 300
                ]
            }  # curriculum
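            # The two lists are parallel: entry k means "run schedule['iter'][k]
            # iterations on curriculum task schedule['i_curr'][k]" before moving on.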
            # schedule = {'i_curr' : [5],
            #             'iter' : [sum([100, 100, 100, 100, 200, 200, 300, 300, 100, 100, 100, 100, 200, 200, 300, 300])]} # direct
            # schedule = {'i_curr' : [0, 1, 2],
            #             'iter' : [300, 600, 600]}
            i_counter = 0
            i_schedule = 0
            while i_schedule < len(schedule['i_curr']):
                task = curriculum.tasks[schedule['i_curr'][i_schedule]]
                logger.log("Lesson: number of agents = {}".format(task.prop))
                # for ctrial in range(curriculum.n_trials):
                # for ctrial in range(schedule['iter'][i_schedule]):
                # task_prob = np.random.dirichlet(task_dist)
                # task = np.random.choice(curriculum.tasks, p=task_prob)
                # logger.log("Lesson: number of agents = {}".format(task.prop))
                self.env.set_param_values(task.prop)
                self.start_worker()
                # for itr in range(start_itr, end_itr):
                for itr in range(schedule['iter'][i_schedule]):
                    itr_start_time = time.time()
                    # with logger.prefix('curr_trial: #%d itr #%d |' % (ctrial, itr)):
                    with logger.prefix(
                            'curr: #%d itr #%d |' %
                        (schedule['i_curr'][i_schedule], i_counter)):
                        i_counter += 1
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(itr)
                        # print('number of ma_paths = {}'.format(len(paths)))
                        logger.log("Processing samples...")
                        # TODO: process samples appropriately for concurrent or decentralized modes
                        samples_data = self.process_samples(itr, paths)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(paths)
                        logger.log("Optimizing policy...")
                        self.optimize_policy(itr, samples_data)
                        logger.log("Saving snapshot...")
                        params = self.get_itr_snapshot(
                            itr, samples_data)  # , **kwargs)
                        if self.store_paths:
                            if isinstance(samples_data, list):
                                params["paths"] = [
                                    sd["paths"] for sd in samples_data
                                ]
                            else:
                                params["paths"] = samples_data["paths"]
                        logger.save_itr_params(itr, params)
                        logger.log("Saved")
                        logger.record_tabular('Time', time.time() - start_time)
                        logger.record_tabular('ItrTime',
                                              time.time() - itr_start_time)
                        logger.dump_tabular(with_prefix=False)
                        if self.plot:
                            self.update_plot()
                            if self.pause_for_plot:
                                input(
                                    "Plotting evaluation run: Press Enter to "
                                    "continue...")

                start_itr = end_itr
                end_itr += self.n_itr
                logger.log("Evaluating...")
                evres = evaluate(self.env,
                                 self.policy,
                                 max_path_length=self.max_path_length,
                                 n_paths=curriculum.eval_trials,
                                 ma_mode=self.ma_mode,
                                 disc=self.discount)
                task_eval_reward[task] += np.mean(evres[curriculum.metric])
                task_counts[task] += 1

                i_schedule += 1

                # Check how we have progressed
                scores = []
                for i, task in enumerate(curriculum.tasks):
                    if task_counts[task] > 0:
                        score = 1.0 * task_eval_reward[task] / task_counts[task]
                        logger.log("task #{} {}".format(i, score))
                        scores.append(score)
                    else:
                        # scores.append(min(scores))
                        scores.append(-np.inf)

                logger.log('Eval scores = {}'.format(scores))
                min_reward = min(min_reward, min(scores))
                rel_reward = scores[np.argmax(task_dist)]
                # print('min_reward = {}'.format(min_reward))
                # print('curriculum.stop_threshold = {}'.format(curriculum.stop_threshold))
                if rel_reward > curriculum.lesson_threshold:
                    logger.log("task: {} breached, reward: {}!".format(
                        np.argmax(task_dist), rel_reward))
                    task_dist = np.roll(task_dist, 1)  # update distribution
                # if min_reward > curriculum.stop_threshold or i_counter > 4500:
                #     break

        self.shutdown_worker()
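
Stripped of the rllab plumbing, the schedule-driven curriculum above amounts to: train for a fixed budget on each lesson, evaluate all tasks, and roll the focus distribution forward once the current lesson clears its threshold. A compact sketch of that progression logic (with hypothetical `train_for` and `evaluate_task` callables) could be:

import numpy as np

def run_curriculum_sketch(tasks, schedule, train_for, evaluate_task, lesson_threshold):
    """Sketch only: mirrors the lesson-progression logic of the example above."""
    task_dist = np.ones(len(tasks))
    task_dist[0] = len(tasks)                  # put most of the weight on the first lesson
    for i_curr, n_iter in zip(schedule['i_curr'], schedule['iter']):
        train_for(tasks[i_curr], n_iter)       # train on this lesson for a fixed budget
        scores = [evaluate_task(t) for t in tasks]
        if scores[np.argmax(task_dist)] > lesson_threshold:
            task_dist = np.roll(task_dist, 1)  # shift the focus to the next task
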
Beispiel #37
0
    def train(self):
        with tf.Session() as sess:
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars (I know, it's ugly)
            uninit_vars = []
            for var in tf.all_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.initialize_variables(uninit_vars))
            #sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    #new_param_values = self.policy.get_variable_values(self.policy.all_params)

                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr,
                                                   samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)

                    import pickle
                    with open('paths_itr' + str(itr) + '.pkl', 'wb') as f:
                        pickle.dump(paths, f)

                    # debugging
                    """
                    if itr % 1 == 0:
                        logger.log("Saving visualization of paths")
                        import matplotlib.pyplot as plt;
                        for ind in range(5):
                            plt.clf(); plt.hold(True)
                            points = paths[ind]['observations']
                            plt.plot(points[:,0], points[:,1], '-r', linewidth=2)
                            plt.xlim([-1.0, 1.0])
                            plt.ylim([-1.0, 1.0])
                            plt.legend(['path'])
                            plt.savefig('/home/cfinn/path'+str(ind)+'.png')
                    """
                    # end debugging

                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        self.shutdown_worker()
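
The try/except probe above runs every variable once just to check whether it is initialized. In graph-mode TensorFlow 1.x the same effect can usually be achieved with `tf.report_uninitialized_variables`; a sketch (not the rllab code) is:

import tensorflow as tf

def initialize_uninitialized(sess):
    """Sketch only: initialize exactly the variables the session has not initialized yet."""
    # report_uninitialized_variables returns the names (as bytes) of uninitialized variables
    uninit_names = set(sess.run(tf.report_uninitialized_variables()))
    uninit_vars = [v for v in tf.global_variables()
                   if v.name.split(':')[0].encode() in uninit_names]
    if uninit_vars:
        sess.run(tf.variables_initializer(uninit_vars))
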
Beispiel #38
0
    def train(self):
        with tf.Session() as sess:
            if self.load_policy is not None:
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars (I know, it's ugly)
            uninit_vars = []
            for var in tf.all_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.initialize_variables(uninit_vars))
            #sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                if itr == self.n_itr - 1:
                    self.policy.std_modifier = 0.0001
                    self.policy.recompute_dist_for_adjusted_std()
                if itr in self.goals_for_ET_dict.keys():
                    # self.policy.std_modifier = 0.0001
                    # self.policy.recompute_dist_for_adjusted_std()
                    goals = self.goals_for_ET_dict[itr]
                    noise = self.action_noise_test
                    self.batch_size = self.batch_size_expert_traj
                else:
                    if self.reset_arg is None:
                        goals = [None]
                    else:
                        goals = [self.reset_arg]
                    noise = self.action_noise_train
                    self.batch_size = self.batch_size_train
                paths_to_save = {}
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.log("Obtaining samples...")
                    paths = []
                    for goalnum, goal in enumerate(goals):
                        preupdate = itr < self.n_itr - 1
                        # paths_for_goal = self.obtain_samples(itr=itr, reset_args=[{'goal': goal, 'noise': noise}])  # when using oracle environments with changing noise, use this line!
                        paths_for_goal = self.obtain_samples(
                            itr=itr,
                            reset_args=[{
                                'goal': goal,
                                'noise': noise
                            }],
                            preupdate=preupdate)
                        print("debug, goal 1", goal)
                        paths.extend(
                            paths_for_goal
                        )  # we need this to be flat because we process all of them together
                        # TODO: there's a bunch of sample processing happening below that we should abstract away
                        if itr in self.expert_traj_itrs_to_pickle:
                            logger.log("Saving trajectories...")
                            paths_no_goalobs = self.clip_goal_from_obs(
                                paths_for_goal)
                            # drop agent_infos from each path before saving
                            for path in paths_no_goalobs:
                                path.pop('agent_infos')
                            paths_to_save[goalnum] = paths_no_goalobs
                    if itr in self.expert_traj_itrs_to_pickle:
                        logger.log("Pickling trajectories...")
                        assert len(
                            paths_to_save.keys()
                        ) == 1, "we're going through ET goals one at a time now 10/24/17"
                        joblib_dump_safe(
                            paths_to_save[0],
                            self.save_expert_traj_dir + str(itr) + ".pkl")
                        logger.log("Fast-processing returns...")
                        undiscounted_returns = [
                            sum(path['rewards']) for path in paths
                        ]
                        print("debug", undiscounted_returns)
                        logger.record_tabular('AverageReturn',
                                              np.mean(undiscounted_returns))

                    else:
                        logger.log("Processing samples...")
                        samples_data = self.process_samples(itr, paths)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(paths)
                        logger.log("Optimizing policy...")
                        self.optimize_policy(itr, samples_data)
                        #new_param_values = self.policy.get_variable_values(self.policy.all_params)
                        logger.log("Saving snapshot...")
                        params = self.get_itr_snapshot(
                            itr, samples_data)  # , **kwargs)
                        if self.store_paths:
                            params["paths"] = samples_data["paths"]

                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)

                    if itr % 16 == 0 and 7 < self.env.observation_space.shape[0] < 12:  # ReacherEnvOracleNoise
                        logger.log("Saving visualization of paths")
                        plt.clf()
                        # (plt.hold was removed in matplotlib >= 3.0; overlaying is the default, so no call is needed)

                        goal = paths[0]['observations'][0][-2:]
                        plt.plot(goal[0], goal[1], 'k*', markersize=10)

                        goal = paths[1]['observations'][0][-2:]
                        plt.plot(goal[0], goal[1], 'k*', markersize=10)

                        goal = paths[2]['observations'][0][-2:]
                        plt.plot(goal[0], goal[1], 'k*', markersize=10)

                        points = np.array(
                            [obs[6:8] for obs in paths[0]['observations']])
                        plt.plot(points[:, 0], points[:, 1], '-r', linewidth=2)

                        points = np.array(
                            [obs[6:8] for obs in paths[1]['observations']])
                        plt.plot(points[:, 0],
                                 points[:, 1],
                                 '--r',
                                 linewidth=2)

                        points = np.array(
                            [obs[6:8] for obs in paths[2]['observations']])
                        plt.plot(points[:, 0],
                                 points[:, 1],
                                 '-.r',
                                 linewidth=2)

                        plt.plot(0, 0, 'k.', markersize=5)
                        plt.xlim([-0.25, 0.25])
                        plt.ylim([-0.25, 0.25])
                        plt.legend(['path'])
                        plt.savefig(
                            osp.join(logger.get_snapshot_dir(), 'path' +
                                     str(0) + '_' + str(itr) + '.png'))
                        print(
                            osp.join(logger.get_snapshot_dir(), 'path' +
                                     str(0) + '_' + str(itr) + '.png'))

                    # if self.make_video and itr % 2 == 0 or itr in [0,1,2,3,4,5,6,7,8]: # and itr in self.goals_for_ET_dict.keys() == 0:
                    if self.make_video:  # itr always lies in [0, n_itr - 1] inside this loop, so the old range check was redundant
                        logger.log("Saving videos...")
                        self.env.reset(reset_args=goals[0])
                        video_filename = osp.join(
                            logger.get_snapshot_dir(),
                            'post_path_%s_0_%s.gif' %
                            (itr, time.strftime("%H%M%S")))
                        rollout(
                            env=self.env,
                            agent=self.policy,
                            max_path_length=self.max_path_length,
                            animated=True,
                            speedup=2,
                            save_video=True,
                            video_filename=video_filename,
                            reset_arg=goals[0],
                            use_maml=False,
                        )
                        # self.env.reset(reset_args=goals[0])
                        # video_filename = osp.join(logger.get_snapshot_dir(), 'post_path_%s_1_%s.gif' % (itr,time.strftime("%H%M%S")))
                        # rollout(env=self.env, agent=self.policy, max_path_length=self.max_path_length,
                        #         animated=True, speedup=2, save_video=True, video_filename=video_filename,
                        #         reset_arg=goals[0],
                        #         use_maml=False, )
                        # self.env.reset(reset_args=goals[0])
                        # video_filename = osp.join(logger.get_snapshot_dir(), 'post_path_%s_2_%s.gif' % (itr,time.strftime("%H%M%S")))
                        # rollout(env=self.env, agent=self.policy, max_path_length=self.max_path_length,
                        #         animated=True, speedup=2, save_video=True, video_filename=video_filename,
                        #         reset_arg=goals[0],
                        #         use_maml=False, )

                    # debugging
                    """
                    if itr % 1 == 0:
                        logger.log("Saving visualization of paths")
                        import matplotlib.pyplot as plt;
                        for ind in range(5):
                            plt.clf(); plt.hold(True)
                            points = paths[ind]['observations']
                            plt.plot(points[:,0], points[:,1], '-r', linewidth=2)
                            plt.xlim([-1.0, 1.0])
                            plt.ylim([-1.0, 1.0])
                            plt.legend(['path'])
                            plt.savefig('/home/cfinn/path'+str(ind)+'.png')
                    """
                    # end debugging

                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

        self.shutdown_worker()
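
The last example alternates between ordinary training iterations and designated expert-trajectory iterations, on which the sampled paths are saved instead of being used for a policy update. Reduced to its branching, the per-iteration logic (hypothetical helpers, not the actual rllab API) is roughly:

def train_or_collect_sketch(itr, sample, optimize, save_expert, expert_itrs):
    """Sketch only: one iteration of the 'train or collect expert data' loop above."""
    paths = sample(itr)
    if itr in expert_itrs:
        save_expert(itr, paths)                            # dump trajectories; no gradient step
        return [sum(path['rewards']) for path in paths]    # undiscounted returns, for logging
    optimize(itr, paths)                                   # ordinary policy optimization iteration
    return None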