Esempio n. 1
0
    def train(self):
        """
        Drive the worker processes through warm-up and looping, then stop them.

        Every process in ``self.ps`` is started, then the three workers
        (data, model, policy) are coordinated through string messages sent
        over ``self.remotes``:

            1. each worker receives ``'prepare start'`` and must answer
               ``'loop ready'`` (the data worker is additionally handed
               ``self.initial_random_samples`` through its queue),
            2. all workers receive ``'start loop'``,
            3. each worker must report ``'loop done'`` and then
               ``'worker closed'``.

        The processes are terminated in a ``finally`` clause so that a failed
        handshake does not leave started workers running.

        Raises:
            AssertionError: if a worker replies with an unexpected message.
        """

        def expect(remote, expected):
            # Always perform the recv(): a bare ``assert remote.recv() == ...``
            # is stripped under ``python -O`` together with the recv() call,
            # which would leave the message unread. AssertionError is kept as
            # the exception type raised on a mismatch.
            reply = remote.recv()
            if reply != expected:
                raise AssertionError(
                    'expected %r from worker, got %r' % (expected, reply))

        # Only the data worker's queue is used here; unpacking still checks
        # that exactly three queues/remotes are present.
        worker_data_queue, _, _ = self.queues
        worker_data_remote, worker_model_remote, worker_policy_remote = self.remotes

        for p in self.ps:
            p.start()

        try:
            # --------------- worker warm-up ---------------
            logger.log('Prepare start...')

            worker_data_remote.send('prepare start')
            worker_data_queue.put(self.initial_random_samples)
            expect(worker_data_remote, 'loop ready')

            worker_model_remote.send('prepare start')
            expect(worker_model_remote, 'loop ready')

            worker_policy_remote.send('prepare start')
            expect(worker_policy_remote, 'loop ready')

            # The total time is measured from the end of the warm-up phase.
            time_total = time.time()

            # --------------- worker looping ---------------
            logger.log('Start looping...')
            for remote in self.remotes:
                remote.send('start loop')

            # --------------- collect info ---------------
            for remote in self.remotes:
                expect(remote, 'loop done')
            logger.log('\n------------all workers exit loops -------------')
            for remote in self.remotes:
                expect(remote, 'worker closed')
        finally:
            # Runs on success and on failure so no started worker is left behind.
            for p in self.ps:
                p.terminate()

        logger.logkv('Trainer-TimeTotal', time.time() - time_total)
        logger.dumpkvs()
        logger.log("*****Training finished")
    def train(self):
        """
        Trains policy on env using algo

        Each iteration fits the dynamics model on the current environment
        samples and then, for every rollout, performs
        ``self.grad_steps_per_rollout`` policy updates on model-generated
        samples followed by one round of real environment sampling. The
        environment paths gathered during an iteration are the training data
        of the next iteration's model fit.

        Pseudocode:
            for itr in n_itr:
                if itr == 0:
                    env_paths = env_sampler.obtain_samples(random=...)
                dynamics_model.fit(process(env_paths))
                env_paths = []
                for rollout in num_rollouts_per_iter:
                    for step in grad_steps_per_rollout:
                        paths = model_sampler.obtain_samples()
                        algo.optimize_policy(process(paths))
                    env_paths += env_sampler.obtain_samples()

        The session is closed once training has finished.
        """
        with self.sess.as_default() as sess:

            # All global variables are initialized here (not only the ones
            # that are still uninitialized).
            sess.run(tf.global_variables_initializer())

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                # Wall-clock time spent sampling the real environment during
                # this iteration. It is accumulated explicitly because model
                # sampling uses its own timer further down.
                time_env_sampling = 0.0

                # NOTE(review): env_paths is only created here on iteration 0;
                # later iterations reuse the paths collected at the end of the
                # previous one. A run with start_itr > 0 therefore has no
                # env_paths on its first iteration -- confirm this is handled
                # by the caller.
                if itr == 0:
                    logger.log(
                        "Obtaining random samples from the environment...")
                    time_env_sampling_start = time.time()
                    # Temporarily scale the sample budget. The original value
                    # is restored afterwards instead of being divided back,
                    # so an integer count does not turn into a float.
                    original_total_samples = self.env_sampler.total_samples
                    self.env_sampler.total_samples = (
                        original_total_samples * self.num_rollouts_per_iter)
                    try:
                        env_paths = self.env_sampler.obtain_samples(
                            log=True,
                            random=self.initial_random_samples,
                            log_prefix='Data-EnvSampler-',
                            verbose=True)
                    finally:
                        self.env_sampler.total_samples = original_total_samples
                    time_env_sampling += time.time() - time_env_sampling_start

                time_env_samp_proc = time.time()
                samples_data = self.dynamics_sample_processor.process_samples(
                    env_paths, log=True, log_prefix='Data-EnvTrajs-')
                self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')
                logger.record_tabular('Data-TimeEnvSampleProc',
                                      time.time() - time_env_samp_proc)

                # Keep the processed environment samples, since samples_data
                # is rebound to model samples inside the grad-step loop.
                buffer = samples_data if self.sample_from_buffer else None

                # --------------- fit dynamics model ---------------
                logger.log("Training dynamics model for %i epochs ..." %
                           self.dynamics_model_max_epochs)
                time_fit_start = time.time()
                self.dynamics_model.fit(samples_data['observations'],
                                        samples_data['actions'],
                                        samples_data['next_observations'],
                                        epochs=self.dynamics_model_max_epochs,
                                        verbose=False,
                                        log_tabular=True,
                                        prefix='Model-')

                logger.record_tabular('Model-TimeModelFit',
                                      time.time() - time_fit_start)

                # Timings are collected over ALL rollouts of the iteration, so
                # the sums logged below cover the whole iteration and exist
                # even when no rollout is performed.
                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_optimization = []
                times_step = []

                env_paths = []
                for id_rollout in range(self.num_rollouts_per_iter):
                    for step in range(self.grad_steps_per_rollout):
                        step_start_time = time.time()

                        # -------------------- Sampling --------------------
                        logger.log("Obtaining samples from the model...")
                        time_model_sampling_start = time.time()
                        paths = self.model_sampler.obtain_samples(
                            log=True, log_prefix='Policy-', buffer=buffer)
                        sampling_time = time.time() - time_model_sampling_start

                        # ---------------- Processing Samples ----------------
                        logger.log("Processing samples from the model...")
                        time_proc_samples_start = time.time()
                        samples_data = self.model_sample_processor.process_samples(
                            paths, log='all', log_prefix='Policy-')
                        proc_samples_time = time.time(
                        ) - time_proc_samples_start

                        # paths is either a flat list or a mapping whose
                        # values are lists that get concatenated.
                        if type(paths) is list:
                            self.log_diagnostics(paths, prefix='Policy-')
                        else:
                            self.log_diagnostics(sum(paths.values(), []),
                                                 prefix='Policy-')

                        # ------------------ Policy Update ------------------
                        logger.log("Optimizing policy...")
                        time_optimization_step_start = time.time()
                        self.algo.optimize_policy(samples_data)
                        optimization_time = time.time(
                        ) - time_optimization_step_start

                        times_dyn_sampling.append(sampling_time)
                        times_dyn_sample_processing.append(proc_samples_time)
                        times_optimization.append(optimization_time)
                        times_step.append(time.time() - step_start_time)

                    logger.log(
                        "Obtaining random samples from the environment...")
                    time_env_sampling_start = time.time()
                    env_paths.extend(
                        self.env_sampler.obtain_samples(
                            log=True,
                            log_prefix='Data-EnvSampler-',
                            verbose=True))
                    time_env_sampling += time.time() - time_env_sampling_start

                logger.record_tabular('Data-TimeEnvSampling',
                                      time_env_sampling)
                logger.log("Processing environment samples...")

                # ------------------- Logging Stuff -------------------
                logger.logkv('Iteration', itr)
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled)
                logger.logkv('Policy-TimeSampleProc',
                             np.sum(times_dyn_sample_processing))
                logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
                logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
                logger.logkv('Policy-TimeStep', np.sum(times_step))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                # Freeze the graph after the first iteration.
                if itr == 0:
                    sess.graph.finalize()

            logger.logkv('Trainer-TimeTotal', time.time() - start_time)

        logger.log("Training finished")
        self.sess.close()
Esempio n. 3
0
    def train(self):
        """
        Run the training loop for iterations ``self.start_itr`` .. ``self.n_itr - 1``.

        Per iteration, as currently written:
            1. collect experience with ``self.algo.collect_experiences`` and
               update the policy with ``self.algo.optimize_policy``,
            2. ask ``self.check_advance_curriculum`` whether to advance,
            3. log timings, memory use and curriculum progress,
            4. optionally save videos and a parameter snapshot,
            5. advance the curriculum when requested and
               ``self.advance_levels`` is set.

        The distillation and policy-rollout sections are disabled with
        ``if False:`` (see the TODOs), and the reward-predictor calls are
        commented out, so their timers only measure an empty span.
        """
        start_time = time.time()
        # rollout_time is logged under 'Time/VidRollout' BEFORE it is
        # recomputed at the end of the iteration, so each iteration logs the
        # video-rollout time of the previous one (0 on the first iteration).
        rollout_time = 0
        success_rate = 0
        # Number of iterations spent on the current curriculum level; reset
        # to 0 whenever the curriculum advances.
        itrs_on_level = 0

        for itr in range(self.start_itr, self.n_itr):

            # high_entropy = False
            # # At the very beginning of a new level, have high entropy
            # if itrs_on_level < 3:
            #     high_entropy = True
            # # At the very beginning, have high entropy
            # if self.curriculum_step == 0 and itrs_on_level < 20:
            #     high_entropy = True
            # # If the model never succeeds
            # if itrs_on_level > 20 and success_rate < .2:
            #     high_entropy = True
            # logger.logkv("HighE", high_entropy)
            logger.logkv("ItrsOnLEvel", itrs_on_level)
            itrs_on_level += 1

            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)
            logger.log("Sampling set of tasks/goals for this meta-batch...")

            """ -------------------- Sampling --------------------------"""

            logger.log("Obtaining samples...")
            time_env_sampling_start = time.time()
            samples_data, episode_logs = self.algo.collect_experiences(self.teacher_train_dict)
            # The action array is required to be one-dimensional.
            assert len(samples_data.action.shape) == 1, (samples_data.action.shape)
            time_collection = time.time() - time_env_sampling_start
            time_training_start = time.time()
            # if high_entropy:
            #     entropy = self.algo.entropy_coef# * 10
            # else:
            #     entropy = self.algo.entropy_coef
            summary_logs = self.algo.optimize_policy(samples_data, teacher_dict=self.teacher_train_dict)
            time_training = time.time() - time_training_start
            self._log(episode_logs, summary_logs, tag="Train")
            logger.logkv('Curriculum Step', self.curriculum_step)
            advance_curriculum = self.check_advance_curriculum(episode_logs, summary_logs)
            logger.logkv('Train/Advance', int(advance_curriculum))
            # Measured from the start of collection, so this span also covers
            # the policy update and the logging/curriculum check above.
            time_env_sampling = time.time() - time_env_sampling_start
            #
            # """ ------------------ Reward Predictor Splicing ---------------------"""
            # The reward-predictor call is commented out; the timer below
            # therefore measures an empty span.
            rp_start_time = time.time()
            # samples_data = self.use_reward_predictor(samples_data)  # TODO: update
            rp_splice_time = time.time() - rp_start_time

            """ ------------------ End Reward Predictor Splicing ---------------------"""

            """ ------------------ Policy Update ---------------------"""

            logger.log("Optimizing policy...")
            # # This needs to take all samples_data so that it can construct graph for meta-optimization.
            # Reward-predictor training is commented out as well; this timer
            # also measures an empty span.
            time_rp_train_start = time.time()
            # self.train_rp(samples_data)
            time_rp_train = time.time() - time_rp_train_start

            """ ------------------ Distillation ---------------------"""
            # Disabled: the branch below never runs, so distill_time is always 0.
            if False:  # TODO: remove False once I've checked things work with dictionaries
            # if self.supervised_model is not None and advance_curriculum:
                time_distill_start = time.time()
                for _ in range(3):  # TODO: tune this!
                    distill_log = self.distill(samples_data, is_training=True)
                for k, v in distill_log.items():
                    logger.logkv(f"Distill/{k}_Train", v)
                distill_time = time.time() - time_distill_start
                advance_curriculum = distill_log['Accuracy'] >= self.accuracy_threshold
                logger.logkv('Distill/Advance', int(advance_curriculum))
            else:
                distill_time = 0

            """ ------------------ Policy rollouts ---------------------"""
            run_policy_time = 0
            # Disabled: the branch below never runs, so run_policy_time and
            # run_supervised_time stay 0 and advance_curriculum keeps the
            # value returned by check_advance_curriculum.
            if False:  # TODO: remove False once I've checked things work with dictionaries
            # if advance_curriculum or (itr % self.eval_every == 0) or (itr == self.n_itr - 1):  # TODO: collect rollouts with and without the teacher
                train_advance_curriculum = advance_curriculum
                with torch.no_grad():
                    if self.supervised_model is not None:
                        # Distilled model
                        time_run_supervised_start = time.time()
                        self.sampler.supervised_model.reset(dones=[True] * len(samples_data.obs))
                        logger.log("Running supervised model")
                        advance_curriculum_sup = self.run_supervised(self.il_trainer.acmodel, self.no_teacher_dict, "DRollout/")
                        run_supervised_time = time.time() - time_run_supervised_start
                    else:
                        run_supervised_time = 0
                        advance_curriculum_sup = True
                    # Original Policy
                    time_run_policy_start = time.time()
                    self.algo.acmodel.reset(dones=[True] * len(samples_data.obs))
                    logger.log("Running model with teacher")
                    advance_curriculum_policy = self.run_supervised(self.algo.acmodel, self.teacher_train_dict, "Rollout/")
                    run_policy_time = time.time() - time_run_policy_start

                    advance_curriculum = advance_curriculum_policy and advance_curriculum_sup and train_advance_curriculum
                    print("ADvancing curriculum???", advance_curriculum)

                    logger.logkv('Advance', int(advance_curriculum))
            else:
                run_supervised_time = 0
                # advance_curriculum = False

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Itr', itr)
            logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

            logger.logkv('Time/Total', time.time() - start_time)
            logger.logkv('Time/Itr', time.time() - itr_start_time)

            logger.logkv('Curriculum Percent', self.curriculum_step / len(self.env.levels_list))

            # Resident set size of the current process, converted to MiB.
            process = psutil.Process(os.getpid())
            memory_use = process.memory_info().rss / float(2 ** 20)
            logger.logkv('Memory MiB', memory_use)

            logger.log(self.exp_name)

            logger.logkv('Time/Sampling', time_env_sampling)
            logger.logkv('Time/Training', time_training)
            logger.logkv('Time/Collection', time_collection)
            logger.logkv('Time/RPUse', rp_splice_time)
            logger.logkv('Time/RPTrain', time_rp_train)
            logger.logkv('Time/RunwTeacher', run_policy_time)
            logger.logkv('Time/Distillation', distill_time)
            logger.logkv('Time/RunDistilled', run_supervised_time)
            # Value from the previous iteration (see the note on rollout_time above).
            logger.logkv('Time/VidRollout', rollout_time)
            logger.dumpkvs()

            """ ------------------ Video Saving ---------------------"""

            # Videos are saved periodically, on the last iteration, and
            # whenever the curriculum is about to advance.
            should_save_video = (itr % self.save_videos_every == 0) or (itr == self.n_itr - 1) or advance_curriculum
            if should_save_video:
            # TODO: consider generating videos with every element in the powerset of feedback types
                time_rollout_start = time.time()
                if self.supervised_model is not None:
                    self.il_trainer.acmodel.reset(dones=[True])
                    self.save_videos(itr, self.il_trainer.acmodel, save_name='distilled_video_stoch',
                                   num_rollouts=5,
                                   teacher_dict=self.teacher_train_dict,
                                   save_video=should_save_video, log_prefix="DVidRollout/Stoch", stochastic=True)
                    self.il_trainer.acmodel.reset(dones=[True])
                    self.save_videos(itr, self.il_trainer.acmodel, save_name='distilled_video_det',
                                     num_rollouts=5,
                                   teacher_dict=self.no_teacher_dict,
                                     save_video=should_save_video, log_prefix="DVidRollout/Det", stochastic=False)

                # NOTE(review): the distilled-model videos above pass `itr` as
                # the first argument, while the two calls below pass
                # `self.curriculum_step` -- confirm which one save_videos expects.
                self.algo.acmodel.reset(dones=[True])
                self.save_videos(self.curriculum_step, self.algo.acmodel, save_name='withTeacher_video_stoch',
                               num_rollouts=5,
                               teacher_dict=self.teacher_train_dict,
                               save_video=should_save_video, log_prefix="VidRollout/Stoch", stochastic=True)
                self.algo.acmodel.reset(dones=[True])
                # NOTE(review): saved as 'withTeacher_video_det' but called
                # with self.no_teacher_dict -- confirm this is intended.
                self.save_videos(self.curriculum_step, self.algo.acmodel, save_name='withTeacher_video_det',
                                 num_rollouts=5,
                                   teacher_dict=self.no_teacher_dict,
                                 save_video=should_save_video, log_prefix="VidRollout/Det", stochastic=False)
                rollout_time = time.time() - time_rollout_start
            else:
                rollout_time = 0

            # The snapshot is built every iteration, even when it is not saved below.
            params = self.get_itr_snapshot(itr)
            step = self.curriculum_step

            if self.log_and_save:
                # Save periodically, on the last iteration, and right before
                # the curriculum advances.
                if (itr % self.save_every == 0) or (itr == self.n_itr - 1) or advance_curriculum:
                    logger.log("Saving snapshot...")
                    logger.save_itr_params(itr, step, params)
                    logger.log("Saved")

            # Move the trainer, the sampler and the algorithm to the next
            # level together and restart the per-level iteration counter.
            if advance_curriculum and self.advance_levels:
                self.curriculum_step += 1
                self.sampler.advance_curriculum()
                self.algo.advance_curriculum()
                # self.algo.set_optimizer()
                itrs_on_level = 0


        logger.log("Training finished")
Esempio n. 4
0
    def train(self):
        """
        Alternate between dynamics-model fitting and policy optimization.

        For every iteration in ``range(self.start_itr, self.n_itr)``:
            1. sample the real environment (with ``random=True`` on iteration
               0 when ``self.initial_random_samples`` is set, otherwise with
               the sampler's defaults),
            2. process those samples and push them into the dynamics model
               buffer,
            3. repeat ``self.repeat_steps`` times:
               ``self.num_epochs_per_step`` single-epoch model fits followed
               by ``self.num_grad_policy_per_step`` policy updates on
               model-generated samples,
            4. log the timings and save a snapshot.

        The session is closed once the loop has finished.
        """
        with self.sess.as_default() as session:

            # Every global variable is initialized here.
            session.run(tf.global_variables_initializer())

            training_began = time.time()
            for itr in range(self.start_itr, self.n_itr):
                iteration_began = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                env_sampling_began = time.time()

                if not self.initial_random_samples or itr != 0:
                    logger.log(
                        "Obtaining samples from the environment using the policy..."
                    )
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, log_prefix='Data-EnvSampler-')
                else:
                    logger.log(
                        "Obtaining random samples from the environment...")
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, random=True, log_prefix='Data-EnvSampler-')

                logger.record_tabular('Data-TimeEnvSampling',
                                      time.time() - env_sampling_began)
                logger.log("Processing environment samples...")

                # Process the environment paths and log their diagnostics.
                env_processing_began = time.time()
                env_samples = self.dynamics_sample_processor.process_samples(
                    env_paths, log=True, log_prefix='Data-EnvTrajs-')
                self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')
                logger.record_tabular('Data-TimeEnvSampleProc',
                                      time.time() - env_processing_began)

                # --------------- fit dynamics model ---------------
                # Only the buffer update is timed under 'Model-TimeModelFit';
                # the per-epoch fits happen further down.
                buffer_update_began = time.time()
                self.dynamics_model.update_buffer(
                    env_samples['observations'],
                    env_samples['actions'],
                    env_samples['next_observations'],
                    check_init=True)

                model_buffer = env_samples if self.sample_from_buffer else None

                logger.record_tabular('Model-TimeModelFit',
                                      time.time() - buffer_update_began)

                # --------------- MAML steps ---------------
                sampling_durations = []
                processing_durations = []
                optimization_durations = []
                step_durations = []
                active_model_idx = list(
                    range(self.dynamics_model.num_models))
                # NOTE(review): this stays None for every fit_one_epoch call;
                # the rolling average returned below is never fed back.
                # Confirm that is intended.
                prev_valid_loss_average = None

                # Only the first epoch of the iteration is flagged as having new data.
                fit_with_new_data = True
                for id_step in range(self.repeat_steps):

                    for _epoch in range(self.num_epochs_per_step):
                        logger.log(
                            "Training dynamics model for %i epochs ..." % 1)
                        active_model_idx, _valid_loss_average = self.dynamics_model.fit_one_epoch(
                            active_model_idx,
                            prev_valid_loss_average,
                            fit_with_new_data,
                            log_tabular=True,
                            prefix='Model-')
                        fit_with_new_data = False

                    for step in range(self.num_grad_policy_per_step):

                        # Running count of policy gradient steps across iterations.
                        grad_step_number = int(
                            itr * self.repeat_steps *
                            self.num_grad_policy_per_step + id_step *
                            self.num_grad_policy_per_step + step)
                        logger.log(
                            "\n ---------------- Grad-Step %d ----------------"
                            % grad_step_number)
                        step_began = time.time()

                        # -------------------- Sampling --------------------
                        logger.log("Obtaining samples from the model...")
                        model_sampling_began = time.time()
                        paths = self.model_sampler.obtain_samples(
                            log=True, log_prefix='Policy-', buffer=model_buffer)
                        sampling_elapsed = time.time() - model_sampling_began

                        # ---------------- Processing Samples ----------------
                        logger.log("Processing samples from the model...")
                        processing_began = time.time()
                        model_samples = self.model_sample_processor.process_samples(
                            paths, log='all', log_prefix='Policy-')
                        processing_elapsed = time.time() - processing_began

                        # A non-list result is treated as a mapping whose
                        # list values are concatenated before logging.
                        if type(paths) is list:
                            flat_paths = paths
                        else:
                            flat_paths = sum(paths.values(), [])
                        self.log_diagnostics(flat_paths, prefix='Policy-')

                        # ------------------ Policy Update ------------------
                        logger.log("Optimizing policy...")
                        optimization_began = time.time()
                        self.algo.optimize_policy(model_samples)
                        optimization_elapsed = time.time() - optimization_began

                        sampling_durations.append(sampling_elapsed)
                        processing_durations.append(processing_elapsed)
                        optimization_durations.append(optimization_elapsed)
                        step_durations.append(time.time() - step_began)

                # ------------------- Logging Stuff -------------------
                logger.logkv('Iteration', itr)
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled)
                logger.logkv('Policy-TimeSampleProc',
                             np.sum(processing_durations))
                logger.logkv('Policy-TimeSampling', np.sum(sampling_durations))
                logger.logkv('Policy-TimeAlgoOpt', np.sum(optimization_durations))
                logger.logkv('Policy-TimeStep', np.sum(step_durations))

                logger.logkv('Time', time.time() - training_began)
                logger.logkv('ItrTime', time.time() - iteration_began)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, snapshot)
                logger.log("Saved")

                logger.dumpkvs()
                # Freeze the graph after the first iteration.
                if itr == 0:
                    session.graph.finalize()

            logger.logkv('Trainer-TimeTotal', time.time() - training_began)

        logger.log("Training finished")
        self.sess.close()
Esempio n. 5
0
    def train(self):
        """
        Train the dynamics model, the value function and the policy.

        For every iteration in ``range(self.start_itr, self.n_itr)``:
            1. sample the environment with ``self.sampler`` and process the
               paths with ``self.sample_processor``,
            2. fit the dynamics model, then the value function, then call
               ``self.algo.optimize_policy`` on the same processed samples,
            3. log timings and diagnostics and save a snapshot.

        The session is closed once the loop has finished.
        """
        with self.sess.as_default() as session:

            # Initialize only the variables that report as uninitialized, so
            # values already present in the session are left alone.
            pending_vars = []
            for variable in tf.global_variables():
                if not session.run(tf.is_variable_initialized(variable)):
                    pending_vars.append(variable)
            session.run(tf.variables_initializer(pending_vars))

            training_began = time.time()
            for itr in range(self.start_itr, self.n_itr):
                iteration_began = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                sampling_began = time.time()
                logger.log(
                    "Obtaining samples from the environment using the policy..."
                )
                env_paths = self.sampler.obtain_samples(log=True,
                                                        log_prefix='')
                logger.record_tabular('Time-EnvSampling',
                                      time.time() - sampling_began)

                logger.log("Processing environment samples...")
                processing_began = time.time()
                batch = self.sample_processor.process_samples(
                    env_paths, log=True, log_prefix='EnvTrajs-')
                logger.record_tabular('Time-EnvSampleProc',
                                      time.time() - processing_began)

                # --------------- fit dynamics model ---------------
                # 'Time-ModelFit' spans all three fitting stages below.
                fitting_began = time.time()

                logger.log("Training dynamics model for %i epochs ..." %
                           self.dynamics_model_max_epochs)
                self.dynamics_model.fit(batch['observations'],
                                        batch['actions'],
                                        batch['next_observations'],
                                        epochs=self.dynamics_model_max_epochs,
                                        verbose=False,
                                        log_tabular=True,
                                        early_stopping=True,
                                        compute_normalization=False)

                logger.log("Training the value function for %i epochs ..." %
                           self.vfun_max_epochs)
                self.value_function.fit(batch['observations'],
                                        batch['returns'],
                                        epochs=self.vfun_max_epochs,
                                        verbose=False,
                                        log_tabular=True,
                                        compute_normalization=False)

                logger.log("Training the policy ...")
                self.algo.optimize_policy(batch)

                logger.record_tabular('Time-ModelFit',
                                      time.time() - fitting_began)

                # ------------------- Logging Stuff -------------------
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time', time.time() - training_began)
                logger.logkv('ItrTime', time.time() - iteration_began)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(itr)
                self.log_diagnostics(env_paths, '')
                logger.save_itr_params(itr, snapshot)
                logger.log("Saved")

                logger.dumpkvs()
                # Freeze the graph after the first iteration.
                if itr == 0:
                    session.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Esempio n. 6
0
    def train(self):
        """
        Trains policy on env using algo, alternating between fitting the
        dynamics model on real environment data and running MAML steps on
        samples obtained through ``self.model_sampler``.

        Pseudocode:
            for itr in range(start_itr, n_itr):
                env_paths = env_sampler.obtain_samples()
                dynamics_model.fit(processed env_paths)
                for meta_itr in range(meta_steps_per_iter[itr]):
                    for step in range(num_inner_grad_steps + 1):
                        model_sampler.obtain_samples()
                        algo._adapt()        # skipped on the last step
                    algo.optimize_policy()

        ``self.meta_steps_per_iter`` is either one count used for every
        iteration or a ``(first, last)`` tuple that is linearly interpolated
        over ``n_itr`` iterations and cast to integers.

        The graph is finalized at the end of iteration 0 and the session is
        closed once the loop has finished.

        Raises:
            TypeError: if the environment sampler returns neither a dict nor
                a list.
        """
        with self.sess.as_default() as sess:

            # Every global variable is (re-)initialized here. The selective
            # variant that skips already-initialized variables is disabled:
            # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            # sess.run(tf.variables_initializer(uninit_vars))
            sess.run(tf.global_variables_initializer())

            if isinstance(self.meta_steps_per_iter, tuple):
                # Builtin int instead of the np.int alias, which NumPy 1.24
                # removed; the resulting dtype is the same.
                meta_steps_per_iter = np.linspace(self.meta_steps_per_iter[0],
                                                  self.meta_steps_per_iter[1],
                                                  self.n_itr).astype(int)
            else:
                meta_steps_per_iter = [self.meta_steps_per_iter] * self.n_itr
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                time_env_sampling_start = time.time()

                if self.initial_random_samples and itr == 0:
                    logger.log(
                        "Obtaining random samples from the environment...")
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, random=True, log_prefix='EnvSampler-')

                else:
                    logger.log(
                        "Obtaining samples from the environment using the policy..."
                    )
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, log_prefix='EnvSampler-')

                logger.record_tabular('Time-EnvSampling',
                                      time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

                # Keep a random subset of num_rollouts_per_iter entries, drawn
                # without replacement, before processing.
                time_env_samp_proc = time.time()
                if isinstance(env_paths, dict):
                    # Each dict value is a list of paths; the chosen lists are
                    # concatenated into one flat list.
                    env_paths = list(env_paths.values())
                    idxs = np.random.choice(range(len(env_paths)),
                                            size=self.num_rollouts_per_iter,
                                            replace=False)
                    env_paths = sum([env_paths[idx] for idx in idxs], [])

                elif isinstance(env_paths, list):
                    idxs = np.random.choice(range(len(env_paths)),
                                            size=self.num_rollouts_per_iter,
                                            replace=False)
                    env_paths = [env_paths[idx] for idx in idxs]

                else:
                    raise TypeError(
                        "env_paths must be a dict or a list, got %s" %
                        type(env_paths).__name__)
                samples_data = self.dynamics_sample_processor.process_samples(
                    env_paths, log=True, log_prefix='EnvTrajs-')

                self.env.log_diagnostics(env_paths, prefix='EnvTrajs-')

                logger.record_tabular('Time-EnvSampleProc',
                                      time.time() - time_env_samp_proc)

                logger.record_tabular('Time-Data',
                                      time.time() - time_env_sampling_start)

                # --------------- fit dynamics model ---------------

                time_fit_start = time.time()

                logger.log("Training dynamics model for %i epochs ..." %
                           (self.dynamics_model_max_epochs))
                self.dynamics_model.fit(samples_data['observations'],
                                        samples_data['actions'],
                                        samples_data['next_observations'],
                                        epochs=self.dynamics_model_max_epochs,
                                        verbose=True,
                                        log_tabular=True)

                # Captured before samples_data is reassigned further down.
                buffer = None if not self.sample_from_buffer else samples_data

                logger.record_tabular('Time-ModelFit',
                                      time.time() - time_fit_start)

                # ------------ log real performance ---------------

                if self.log_real_performance:
                    logger.log("Evaluating the performance of the real policy")
                    self.policy.switch_to_pre_update()
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, log_prefix='PrePolicy-')
                    samples_data = self.model_sample_processor.process_samples(
                        env_paths, log='all', log_prefix='PrePolicy-')
                    self.algo._adapt(samples_data)
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, log_prefix='PostPolicy-')
                    self.model_sample_processor.process_samples(
                        env_paths, log='all', log_prefix='PostPolicy-')

                # --------------- MAML steps ---------------

                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_inner_step = []
                times_total_inner_step = []
                times_outer_step = []
                times_maml_steps = []

                for meta_itr in range(meta_steps_per_iter[itr]):

                    logger.log(
                        "\n ---------------- Meta-Step %d ----------------" %
                        int(sum(meta_steps_per_iter[:itr]) + meta_itr))
                    self.policy.switch_to_pre_update(
                    )  # Switch to pre-update policy

                    all_samples_data = []
                    list_sampling_time, list_inner_step_time, list_proc_samples_time = [], [], []
                    time_maml_steps_start = time.time()
                    start_total_inner_time = time.time()
                    for step in range(self.num_inner_grad_steps + 1):
                        logger.log("\n ** Adaptation-Step %d **" % step)

                        # -------------------- Sampling --------------------

                        logger.log("Obtaining samples...")
                        time_env_sampling_start = time.time()
                        paths = self.model_sampler.obtain_samples(
                            log=True,
                            log_prefix='Step_%d-' % step,
                            buffer=buffer)
                        list_sampling_time.append(time.time() -
                                                  time_env_sampling_start)

                        # ----------------- Processing Samples -----------------

                        logger.log("Processing samples...")
                        time_proc_samples_start = time.time()
                        samples_data = self.model_sample_processor.process_samples(
                            paths, log='all', log_prefix='Step_%d-' % step)
                        all_samples_data.append(samples_data)
                        list_proc_samples_time.append(time.time() -
                                                      time_proc_samples_start)

                        self.log_diagnostics(sum(list(paths.values()), []),
                                             prefix='Step_%d-' % step)

                        # ----------------- Inner Policy Update -----------------

                        time_inner_step_start = time.time()
                        # The extra, last step only samples; it does not adapt.
                        if step < self.num_inner_grad_steps:
                            logger.log("Computing inner policy updates...")
                            self.algo._adapt(samples_data)

                        list_inner_step_time.append(time.time() -
                                                    time_inner_step_start)
                    total_inner_time = time.time() - start_total_inner_time

                    # ------------------ Outer Policy Update ------------------

                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    time_outer_step_start = time.time()
                    self.algo.optimize_policy(all_samples_data)

                    times_inner_step.append(list_inner_step_time)
                    times_total_inner_step.append(total_inner_time)
                    times_outer_step.append(time.time() -
                                            time_outer_step_start)
                    times_dyn_sampling.append(list_sampling_time)
                    times_dyn_sample_processing.append(list_proc_samples_time)
                    times_maml_steps.append(time.time() -
                                            time_maml_steps_start)

                # ------------------- Logging Stuff -------------------
                logger.logkv('Itr', itr)
                if self.log_real_performance:
                    # NOTE(review): the env sampler is called three times per
                    # iteration in this mode, presumably the reason for the
                    # division by 3 -- confirm.
                    logger.logkv(
                        'n_timesteps',
                        self.env_sampler.total_timesteps_sampled /
                        (3 * self.policy.meta_batch_size) *
                        self.num_rollouts_per_iter)
                else:
                    logger.logkv(
                        'n_timesteps',
                        self.env_sampler.total_timesteps_sampled /
                        self.policy.meta_batch_size *
                        self.num_rollouts_per_iter)
                logger.logkv('AvgTime-OuterStep', np.mean(times_outer_step))
                logger.logkv('AvgTime-InnerStep', np.mean(times_inner_step))
                logger.logkv('AvgTime-TotalInner',
                             np.mean(times_total_inner_step))
                logger.logkv('AvgTime-SampleProc',
                             np.mean(times_dyn_sample_processing))
                logger.logkv('AvgTime-Sampling', np.mean(times_dyn_sampling))
                logger.logkv('AvgTime-MAMLSteps', np.mean(times_maml_steps))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Esempio n. 7
0
    def train(self):
        """
        Runs the training loop: one batch of samples and one policy
        optimization per iteration.

        Per iteration (from ``start_itr`` up to ``n_itr``):
            sampler.update_tasks()
            paths = sampler.obtain_samples()
            samples_data = sample_processor.process_samples(paths)
            algo.optimize_policy(samples_data)
            log timings and save a snapshot

        The graph is finalized at the end of iteration 0 and the session is
        closed after the loop.
        """
        with self.sess.as_default() as sess:

            # Only variables that are still uninitialized are initialized, so
            # values that were loaded beforehand stay untouched.
            pending_vars = []
            for variable in tf.global_variables():
                if not sess.run(tf.is_variable_initialized(variable)):
                    pending_vars.append(variable)
            sess.run(tf.variables_initializer(pending_vars))

            train_began = time.time()
            for itr in range(self.start_itr, self.n_itr):
                self.sampler.update_tasks()
                itr_began = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)
                logger.log(
                    "Sampling set of tasks/goals for this meta-batch...")

                # ---- Sampling ----
                logger.log("Obtaining samples...")
                sampling_began = time.time()
                paths = self.sampler.obtain_samples(log=True,
                                                    log_prefix='train-')
                sampling_time = time.time() - sampling_began

                # ---- Processing samples ----
                logger.log("Processing samples...")
                processing_began = time.time()
                samples_data = self.sample_processor.process_samples(
                    paths, log='all', log_prefix='train-')
                proc_samples_time = time.time() - processing_began

                # A list is passed on unchanged; otherwise the per-key lists
                # are concatenated into one flat list first.
                flat_paths = paths if type(paths) is list else sum(
                    paths.values(), [])
                self.log_diagnostics(flat_paths, prefix='train-')

                # ---- Policy update ----
                logger.log("Optimizing policy...")
                optimization_began = time.time()
                self.algo.optimize_policy(samples_data)

                # ---- Logging ----
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time-Optimization',
                             time.time() - optimization_began)
                logger.logkv('Time-SampleProc', np.sum(proc_samples_time))
                logger.logkv('Time-Sampling', sampling_time)

                logger.logkv('Time', time.time() - train_began)
                logger.logkv('ItrTime', time.time() - itr_began)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, snapshot)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Esempio n. 8
0
    def train(self):
        """
        Runs the meta-training loop.

        Per iteration (from ``start_itr`` up to ``n_itr``):
            sampler.update_tasks()
            policy.switch_to_pre_update()
            for step in range(num_inner_grad_steps + 1):
                paths = sampler.obtain_samples()
                samples_data = sample_processor.process_samples(paths)
                algo._adapt(samples_data)      # skipped on the last step
            algo.optimize_policy(all collected samples_data)
            log timings and save a snapshot

        The graph is finalized at the end of iteration 0 and the session is
        closed after the loop.
        """
        with self.sess.as_default() as sess:

            # Only variables that are still uninitialized are initialized, so
            # values that were loaded beforehand stay untouched.
            pending_vars = []
            for variable in tf.global_variables():
                if not sess.run(tf.is_variable_initialized(variable)):
                    pending_vars.append(variable)
            sess.run(tf.variables_initializer(pending_vars))

            train_began = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_began = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)
                logger.log(
                    "Sampling set of tasks/goals for this meta-batch...")

                self.sampler.update_tasks()
                # Switch to pre-update policy
                self.policy.switch_to_pre_update()

                all_samples_data = []
                all_paths = []
                sampling_times = []
                inner_step_times = []
                processing_times = []
                inner_phase_began = time.time()
                for step in range(self.num_inner_grad_steps + 1):
                    step_prefix = 'Step_%d-' % step
                    logger.log('** Step ' + str(step) + ' **')

                    # ---- Sampling ----
                    logger.log("Obtaining samples...")
                    tick = time.time()
                    paths = self.sampler.obtain_samples(
                        log=True, log_prefix=step_prefix)
                    sampling_times.append(time.time() - tick)
                    all_paths.append(paths)

                    # ---- Processing samples ----
                    logger.log("Processing samples...")
                    tick = time.time()
                    samples_data = self.sample_processor.process_samples(
                        paths, log='all', log_prefix=step_prefix)
                    all_samples_data.append(samples_data)
                    processing_times.append(time.time() - tick)

                    # Concatenate the per-key path lists into one flat list.
                    flat_paths = sum(list(paths.values()), [])
                    self.log_diagnostics(flat_paths, prefix=step_prefix)

                    # ---- Inner policy update ----
                    tick = time.time()
                    # The extra, last step only samples; it does not adapt.
                    if step < self.num_inner_grad_steps:
                        logger.log("Computing inner policy updates...")
                        self.algo._adapt(samples_data)
                    inner_step_times.append(time.time() - tick)
                total_inner_time = time.time() - inner_phase_began

                maml_opt_began = time.time()

                # ---- Outer policy update ----
                logger.log("Optimizing policy...")
                # optimize_policy receives the batches of every step, not
                # just the last one.
                outer_step_began = time.time()
                self.algo.optimize_policy(all_samples_data)

                # ---- Logging ----
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time-OuterStep',
                             time.time() - outer_step_began)
                logger.logkv('Time-TotalInner', total_inner_time)
                logger.logkv('Time-InnerStep', np.sum(inner_step_times))
                logger.logkv('Time-SampleProc', np.sum(processing_times))
                logger.logkv('Time-Sampling', np.sum(sampling_times))

                logger.logkv('Time', time.time() - train_began)
                logger.logkv('ItrTime', time.time() - itr_began)
                logger.logkv('Time-MAMLSteps', time.time() - maml_opt_began)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, snapshot)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Esempio n. 9
0
    def train(self):
        """
        Trains policy on env using algo, alternating between fitting the
        dynamics model (when there is one) on environment data and running
        random-search (RS) steps on samples from ``self.model_sampler``.

        Pseudocode:
            for itr in range(start_itr, n_itr):
                env_paths = env_sampler.obtain_samples()
                dynamics_model.fit(processed env_paths)   # if model is set
                for rs_itr in range(steps_per_iter[itr]):
                    deltas = policy.get_deltas(num_deltas)
                    policy.set_deltas(deltas)
                    samples_data = model_sampler.obtain_samples()
                    algo.optimize_policy(returns, deltas)

        ``self.steps_per_iter`` is either one count used for every iteration
        or a ``(first, last)`` tuple that is linearly interpolated over
        ``n_itr`` iterations and cast to integers.

        The graph is finalized at the end of iteration 0 and the session is
        closed once the loop has finished.
        """
        with self.sess.as_default() as sess:

            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = [
                var for var in tf.global_variables()
                if not sess.run(tf.is_variable_initialized(var))
            ]
            sess.run(tf.variables_initializer(uninit_vars))

            if isinstance(self.steps_per_iter, tuple):
                # Builtin int instead of the np.int alias, which NumPy 1.24
                # removed; the resulting dtype is the same.
                steps_per_iter = np.linspace(self.steps_per_iter[0],
                                             self.steps_per_iter[1],
                                             self.n_itr).astype(int)
            else:
                steps_per_iter = [self.steps_per_iter] * self.n_itr

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                time_env_sampling_start = time.time()

                if self.initial_random_samples and itr == 0:
                    logger.log(
                        "Obtaining random samples from the environment...")
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, random=True, log_prefix='EnvSampler-')

                else:
                    logger.log(
                        "Obtaining samples from the environment using the policy..."
                    )
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, log_prefix='EnvSampler-')
                    self.policy.obs_filter.stats_increment()

                logger.record_tabular('Time-EnvSampling',
                                      time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

                # first processing just for logging purposes
                time_env_samp_proc = time.time()
                samples_data = self.dynamics_sample_processor.process_samples(
                    env_paths, log=True, log_prefix='EnvTrajs-')

                self.env.log_diagnostics(env_paths, prefix='EnvTrajs-')

                logger.record_tabular('Time-EnvSampleProc',
                                      time.time() - time_env_samp_proc)

                # --------------- fit dynamics model ---------------

                time_fit_start = time.time()

                logger.log("Training dynamics model for %i epochs ..." %
                           self.dynamics_model_max_epochs)
                if self.dynamics_model is not None:
                    self.dynamics_model.fit(
                        samples_data['observations'],
                        samples_data['actions'],
                        samples_data['next_observations'],
                        epochs=self.dynamics_model_max_epochs,
                        verbose=False,
                        log_tabular=True,
                        compute_normalization=True)

                # Captured before samples_data is reassigned in the RS loop.
                buffer = None if not self.sample_from_buffer else samples_data

                logger.record_tabular('Time-ModelFit',
                                      time.time() - time_fit_start)

                # --------------- RS steps ---------------

                # One scalar per RS step, so the means logged below are plain
                # per-step averages.
                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_itr = []
                times_rs_steps = []
                for rs_itr in range(steps_per_iter[itr]):
                    time_itr_start = time.time()
                    logger.log("\n -------------- RS-Step %d --------------" %
                               int(sum(steps_per_iter[:itr]) + rs_itr))
                    deltas = self.policy.get_deltas(self.num_deltas)
                    self.policy.set_deltas(deltas,
                                           delta_std=self.delta_std,
                                           symmetrical=True)

                    # -------------------- Sampling --------------------
                    logger.log("Obtaining samples...")
                    time_env_sampling_start = time.time()
                    samples_data = self.model_sampler.obtain_samples(
                        log=True, log_prefix='Models-', buffer=buffer)
                    times_dyn_sampling.append(time.time() -
                                              time_env_sampling_start)

                    # -------------------- Processing --------------------
                    # TODO: Add preprocessing of the state to see what sort of update rule between the models we want
                    # Timed so that 'AvgTime-SampleProc' reports a real value
                    # rather than the mean of an empty list.
                    time_proc_samples_start = time.time()
                    samples_data = self.ars_sample_processor.process_samples(
                        samples_data, log=True, log_prefix='step%d-' % rs_itr)
                    times_dyn_sample_processing.append(
                        time.time() - time_proc_samples_start)

                    if self.dynamics_model is None:
                        self.policy.stats_increment()

                    # ------------------ Outer Policy Update ------------------
                    logger.log("Optimizing policy...")
                    time_rs_start = time.time()
                    self.algo.optimize_policy(samples_data['returns'], deltas)

                    times_rs_steps.append(time.time() - time_rs_start)
                    times_itr.append(time.time() - time_itr_start)

                # ------------------- Logging Stuff -------------------
                logger.logkv('Itr', itr)
                if self.dynamics_model is None:
                    logger.logkv('n_timesteps',
                                 self.model_sampler.total_timesteps_sampled)
                else:
                    logger.logkv('n_timesteps',
                                 self.env_sampler.total_timesteps_sampled)

                logger.logkv('AvgTime-RS', np.mean(times_rs_steps))
                logger.logkv('AvgTime-SampleProc',
                             np.mean(times_dyn_sample_processing))
                logger.logkv('AvgTime-Sampling', np.mean(times_dyn_sampling))
                logger.logkv('AvgTime-ModelItr', np.mean(times_itr))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Esempio n. 10
0
    # Preupdate: for each inner gradient step, sample paths, process them
    # (logging under the per-step prefix '<i>_'), log environment diagnostics
    # on the flattened paths and adapt the algorithm on the processed batch.
    for i in range(NUM_INNER_GRAD_STEPS):
        paths = sampler.obtain_samples(log=False)
        samples_data = sample_processor.process_samples(paths,
                                                        log=True,
                                                        log_prefix='%i_' % i)
        # paths.values() holds lists of paths; sum(..., []) concatenates them.
        env.log_diagnostics(sum(list(paths.values()), []), prefix='%i_' % i)
        algo._adapt(samples_data)

    # One more round of sampling and logging after the last adaptation step;
    # no further _adapt call follows.
    paths = sampler.obtain_samples(log=False)
    samples_data = sample_processor.process_samples(paths,
                                                    log=True,
                                                    log_prefix='%i_' %
                                                    NUM_INNER_GRAD_STEPS)
    env.log_diagnostics(sum(list(paths.values()), []),
                        prefix='%i_' % NUM_INNER_GRAD_STEPS)
    logger.dumpkvs()

    # Postupdate: render rollouts forever. Each pass picks a random task
    # index, sets that task on the environment and rolls out for at most
    # PATH_LENGTH steps (stopping early when the environment reports done).
    # The loop has no exit condition.
    while True:
        task_i = np.random.choice(range(META_BATCH_SIZE))
        env.set_task(tasks[task_i])
        print(tasks[task_i])
        obs = env.reset()
        for _ in range(PATH_LENGTH):
            env.render()
            # The second return value of get_action is discarded.
            action, _ = policy.get_action(obs, task_i)
            obs, reward, done, _ = env.step(action)
            # Short pause between rendered frames.
            time.sleep(0.001)
            if done:
                break
Esempio n. 11
0
    def train(self):
        """
        Runs the off-policy training loop backed by ``self.replay_buffer``.

        When starting from iteration 0 the target is synchronized with
        ``algo._update_target(tau=1.0)`` and, if
        ``n_initial_exploration_steps`` is positive, random samples are
        collected until the replay buffer holds at least that many entries.

        Per iteration (from ``start_itr`` up to ``n_itr``):
            paths = sampler.obtain_samples()
            store processed samples in the replay buffer
            sample and process deterministic 'eval-' paths
            algo.optimize_policy(replay_buffer, itr - start_itr, num_grad_steps)
            log timings and save a snapshot

        The graph is finalized at the end of iteration 0 and the session is
        closed after the loop.
        """

        def _store_in_replay_buffer(batch):
            # Add every transition of a processed batch to the replay buffer,
            # one sample at a time.
            for idx in range(batch['observations'].shape[0]):
                self.replay_buffer.add_sample(
                    batch['observations'][idx],
                    batch['actions'][idx],
                    batch['rewards'][idx],
                    batch['dones'][idx],
                    batch['next_observations'][idx],
                )

        with self.sess.as_default() as sess:
            # Only variables that are still uninitialized are initialized, so
            # values that were loaded beforehand stay untouched.
            pending_vars = []
            for variable in tf.global_variables():
                if not sess.run(tf.is_variable_initialized(variable)):
                    pending_vars.append(variable)
            sess.run(tf.variables_initializer(pending_vars))
            train_began = time.time()

            if self.start_itr == 0:
                self.algo._update_target(tau=1.0)
                if self.n_initial_exploration_steps > 0:
                    # Warm-up: fill the buffer with randomly sampled data.
                    while self.replay_buffer._size < self.n_initial_exploration_steps:
                        warmup_paths = self.sampler.obtain_samples(
                            log=True, log_prefix='train-', random=True)
                        warmup_batch = self.sample_processor.process_samples(
                            warmup_paths, log='all', log_prefix='train-')
                        _store_in_replay_buffer(warmup_batch)

            for itr in range(self.start_itr, self.n_itr):
                itr_began = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)
                logger.log(
                    "Sampling set of tasks/goals for this meta-batch...")

                # ---- Sampling ----
                logger.log("Obtaining samples...")
                sampling_began = time.time()
                train_paths = self.sampler.obtain_samples(
                    log=True, log_prefix='train-')
                sampling_time = time.time() - sampling_began

                # ---- Processing samples ----
                # The measured time also covers the replay-buffer insertion.
                logger.log("Processing samples...")
                processing_began = time.time()
                train_batch = self.sample_processor.process_samples(
                    train_paths, log='all', log_prefix='train-')
                _store_in_replay_buffer(train_batch)
                proc_samples_time = time.time() - processing_began

                # ---- Deterministic evaluation rollouts ----
                eval_paths = self.sampler.obtain_samples(log=True,
                                                         log_prefix='eval-',
                                                         deterministic=True)
                self.sample_processor.process_samples(eval_paths,
                                                      log='all',
                                                      log_prefix='eval-')

                # NOTE(review): the diagnostics use the eval paths but are
                # logged under the 'train-' prefix -- confirm this is intended.
                self.log_diagnostics(eval_paths, prefix='train-')

                # ---- Policy update ----
                logger.log("Optimizing policy...")

                optimization_began = time.time()
                self.algo.optimize_policy(self.replay_buffer,
                                          itr - self.start_itr,
                                          self.num_grad_steps)

                # ---- Logging ----
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time-Optimization',
                             time.time() - optimization_began)
                logger.logkv('Time-SampleProc', np.sum(proc_samples_time))
                logger.logkv('Time-Sampling', sampling_time)

                logger.logkv('Time', time.time() - train_began)
                logger.logkv('ItrTime', time.time() - itr_began)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, snapshot)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Esempio n. 12
0
    def __call__(
        self,
        exp_dir,
        policy_pickle,
        env_pickle,
        baseline_pickle,
        dynamics_model_pickle,
        feed_dict,
        queue_prev,
        queue,
        queue_next,
        remote,
        start_itr,
        n_itr,
        stop_cond,
        need_query,
        auto_push,
        config,
    ):
        """
        Worker entry point: build components, handshake with the trainer, then loop.

        Protocol over ``remote`` (in order):
            recv 'prepare start' -> self.prepare_start() -> send 'loop ready'
            recv 'start loop'    -> loop until ``stop_cond`` is set
            send 'loop done'     -> send 'worker closed'

        Args:
            exp_dir: base log directory; this worker logs to ``exp_dir/<process name>``.
            policy_pickle, env_pickle, baseline_pickle, dynamics_model_pickle, feed_dict:
                forwarded unchanged to ``self.construct_from_feed_dict``.
            queue_prev, queue, queue_next: stored on ``self``; ``queue_prev``
                additionally receives the string 'push' each loop when ``need_query``.
            remote: connection to the trainer, used via ``recv()`` / ``send()``.
            start_itr: initial value of ``self.itr_counter``.
            n_itr: stored as ``self.n_itr``.
            stop_cond: event-like object; the loop runs while ``is_set()`` is false.
            need_query: if truthy, put 'push' on ``queue_prev`` at the start of each loop.
            auto_push: if truthy, call ``self.push()`` after every ``self.step()``.
            config: passed to ``tf.Session(config=...)``.

        Raises:
            AssertionError: if a handshake message differs from the expected one,
                or if a step iteration does not end with exactly one push and one step.
        """
        time_start = time.time()

        self.name = current_process().name
        logger.configure(dir=exp_dir + '/' + self.name,
                         format_strs=['csv', 'stdout', 'log'],
                         snapshot_mode='last')

        self.n_itr = n_itr
        self.queue_prev = queue_prev
        self.queue = queue
        self.queue_next = queue_next
        self.stop_cond = stop_cond

        # FIXME: specify CPU/GPU usage here

        # Imported here rather than at module level, so it is loaded inside the
        # process that runs this call.
        import tensorflow as tf

        def _init_vars():
            sess = tf.get_default_session()
            sess.run(tf.initializers.global_variables())

        def _expect(expected):
            # The recv() must happen unconditionally. Writing it as
            # `assert remote.recv() == ...` would drop the read entirely under
            # `python -O` and leave the handshake out of sync.
            received = remote.recv()
            if received != expected:
                raise AssertionError(
                    '{} expected message {!r} but received {!r}'.format(
                        self.name, expected, received))

        with tf.Session(config=config).as_default():

            self.construct_from_feed_dict(
                policy_pickle,
                env_pickle,
                baseline_pickle,
                dynamics_model_pickle,
                feed_dict,
            )

            _init_vars()

            # warm up
            self.itr_counter = start_itr
            if self.verbose:
                print('{} waiting for starting msg from trainer...'.format(
                    self.name))
            _expect('prepare start')
            self.prepare_start()
            remote.send('loop ready')
            logger.dumpkvs()
            logger.log("\n============== {} is ready =============".format(
                self.name))

            _expect('start loop')
            total_push, total_synch, total_step = 0, 0, 0
            while not self.stop_cond.is_set():
                if self.verbose:
                    logger.log(
                        "\n------------------------- {} starting new loop ------------------"
                        .format(self.name))
                if need_query:  # poll
                    time_poll = time.time()
                    queue_prev.put('push')
                    time_poll = time.time() - time_poll
                    logger.logkv('{}-TimePoll'.format(self.name), time_poll)
                do_push, do_synch, do_step = self.process_queue()
                # step
                if do_step:
                    self.itr_counter += 1
                    self.step()
                    if auto_push:
                        do_push += 1
                        self.push()
                    # Assuming doing autopush for all
                    assert do_push == 1
                    assert do_step == 1

                # Running totals over the whole loop, logged every iteration.
                total_push += do_push
                total_synch += do_synch
                total_step += do_step
                logger.logkv(self.name + '-TimeSoFar',
                             time.time() - time_start)
                logger.logkv(self.name + '-TotalPush', total_push)
                logger.logkv(self.name + '-TotalSynch', total_synch)
                logger.logkv(self.name + '-TotalStep', total_step)
                if total_synch > 0:  # avoid dividing by zero before the first synch
                    logger.logkv(self.name + '-StepPerSynch',
                                 total_step / total_synch)
                logger.dumpkvs()
                logger.log(
                    "\n========================== {} {}, total {} ==================="
                    .format(
                        self.name,
                        (do_push, do_synch, do_step),
                        (total_push, total_synch, total_step),
                    ))
                self.set_stop_cond()

            remote.send('loop done')

        logger.log("\n================== {} closed ===================".format(
            self.name))

        remote.send('worker closed')