Example #1
    def step(self, random=False):
        time_step = time.time()
        '''------------- Obtaining samples from the environment -----------'''

        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random,
            log_prefix='Data-EnvSampler-',
        )
        '''-------------- Processing environment samples -------------------'''

        if self.verbose:
            logger.log("Data is processing environment samples...")
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths,
            log=True,
            log_prefix='Data-EnvTrajs-',
        )

        self.samples_data_arr.append(samples_data)
        time_step = time.time() - time_step

        time_sleep = max(self.simulation_sleep - time_step, 0)
        time.sleep(time_sleep)

        logger.logkv('Data-TimeStep', time_step)
        logger.logkv('Data-TimeSleep', time_sleep)
Example #2
    def step(self, random=False):
        """
        Uses self.env_sampler which samples data under policy.
        Outcome: generate samples_data.
        """

        time_env_sampling_start = time.time()

        logger.log(
            "Obtaining samples from the environment using the policy...")
        env_sampler = self.server.pull.remote(['env_sampler'])[0]
        env_paths = env_sampler.obtain_samples(log=True,
                                               random=random,
                                               log_prefix='EnvSampler-')

        #        logger.record_tabular('Time-EnvSampling', time.time() - time_env_sampling_start)
        logger.log("Processing environment samples...")

        # first processing just for logging purposes
        time_env_samp_proc = time.time()
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths, log=True, log_prefix='EnvTrajs-')
        #        self.env.log_diagnostics(env_paths, prefix='EnvTrajs-')
        #        logger.record_tabular('Time-EnvSampleProc', time.time() - time_env_samp_proc)

        self.server.push.remote('samples_data', ray.put(samples_data))
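Example #3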
    def optimize(self, input_val_dict):
        """
        Carries out the optimization step

        Args:
            input_val_dict (dict): dict containing the values to be fed into the computation graph

        Returns:
            (float) loss before optimization

        """

        sess = tf.get_default_session()
        feed_dict = self.create_feed_dict(input_val_dict)

        # Todo: reimplement minibatches

        loss_before_opt = None
        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % epoch)

            loss, _ = sess.run([self._loss, self._train_op], feed_dict)

            if loss_before_opt is None: loss_before_opt = loss

        return loss_before_opt
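A hypothetical helper for the "# Todo: reimplement minibatches" note above (not from the repo; it assumes every value in the feed dict is a numpy array sharing the same leading axis):

import numpy as np

def iterate_minibatch_feed_dicts(feed_dict, batch_size, rng=np.random):
    # Yield copies of feed_dict restricted to random minibatches; assumes every
    # value is a numpy array with the same length along axis 0.
    n = len(next(iter(feed_dict.values())))
    idxs = rng.permutation(n)
    for start in range(0, n, batch_size):
        batch = idxs[start:start + batch_size]
        yield {ph: val[batch] for ph, val in feed_dict.items()}

Each yielded dict could then be passed as the feed_dict to sess.run([self._loss, self._train_op], ...) inside the epoch loop.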
Example #4
    def optimize_policy(self, all_samples_data, log=True):
        """
        Performs MAML outer step

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        meta_op_input_dict = self._extract_input_dict_meta_op(
            all_samples_data, self._optimization_keys)

        if log: logger.log("Optimizing")
        loss_before = self.optimizer.optimize(
            input_val_dict=meta_op_input_dict)

        if log: logger.log("Computing statistics")
        loss_after = self.optimizer.loss(input_val_dict=meta_op_input_dict)

        if log:
            logger.logkv('LossBefore', loss_before)
            logger.logkv('LossAfter', loss_after)
Example #5
    def process_queue(self):
        do_push, do_synch = 0, 0
        data = None

        while True:
            try:
                if self.verbose:
                    logger.log('{} try'.format(self.name))
                new_data = self.queue.get_nowait()
                if new_data == 'push':  # only happens when next worker has need_query = True
                    if do_push == 0:  # only push once
                        do_push += 1
                        self.push()
                else:
                    do_synch = 1
                    data = new_data
            except Empty:
                break

        if do_synch:
            self._synch(data)

        do_step = 1  # - do_synch

        if self.verbose:
            logger.log(
                '{} finishes processing queue with {}, {}, {}......'.format(
                    self.name, do_push, do_synch, do_step))

        return do_push, do_synch, do_step
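process_queue() returns the (do_push, do_synch, do_step) flags it logs; a hypothetical driver loop consuming them (not taken from the repo) might look like:

def run_worker_loop(worker, stop_cond):
    # Illustrative only: drain the queue, then do one unit of work and
    # forward the updated state to the next worker in the chain.
    while not stop_cond():
        do_push, do_synch, do_step = worker.process_queue()
        if do_step:
            worker.step()
            worker.push()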
Example #6
    def optimize_policy(self, all_samples_data, log=True):
        """
        Performs MAML outer step

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        meta_op_input_dict = self._extract_input_dict_meta_op(
            all_samples_data, self._optimization_keys)
        logger.log("Computing KL before")
        mean_kl_before = self.optimizer.constraint_val(meta_op_input_dict)

        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(meta_op_input_dict)
        logger.log("Optimizing")
        self.optimizer.optimize(meta_op_input_dict)
        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(meta_op_input_dict)

        logger.log("Computing KL after")
        mean_kl = self.optimizer.constraint_val(meta_op_input_dict)
        if log:
            logger.logkv('MeanKLBefore', mean_kl_before)
            logger.logkv('MeanKL', mean_kl)

            logger.logkv('LossBefore', loss_before)
            logger.logkv('LossAfter', loss_after)
            logger.logkv('dLoss', loss_before - loss_after)
Example #7
    def _synch(self, samples_data_arr, check_init=False):
        time_synch = time.time()
        if self.verbose:
            logger.log('Model at {} is synchronizing...'.format(
                self.itr_counter))
        obs = np.concatenate([
            samples_data['observations'] for samples_data in samples_data_arr
        ])
        act = np.concatenate(
            [samples_data['actions'] for samples_data in samples_data_arr])
        obs_next = np.concatenate([
            samples_data['next_observations']
            for samples_data in samples_data_arr
        ])
        self.dynamics_model.update_buffer(
            obs=obs,
            act=act,
            obs_next=obs_next,
            check_init=check_init,
        )

        # Reset variables for early stopping condition
        logger.logkv('Model-AvgEpochs',
                     self.sum_model_itr / self.dynamics_model.num_models)
        self.sum_model_itr = 0
        self.with_new_data = True
        self.remaining_model_idx = list(range(self.dynamics_model.num_models))
        self.valid_loss_rolling_average = None
        time_synch = time.time() - time_synch

        logger.logkv('Model-TimeSynch', time_synch)
Example #8
    def optimize_reward(self,
                        samples_data,
                        log=True,
                        prefix='',
                        verbose=False):
        """
        Performs MAML outer step

        Args:
            samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        input_dict = self._extract_input_dict(samples_data,
                                              self._optimization_r_keys,
                                              prefix='train')

        if verbose: logger.log("Optimizing")
        loss_before = self.optimizer_r.optimize(input_val_dict=input_dict)

        if verbose: logger.log("Computing statistics")
        loss_after = self.optimizer_r.loss(input_val_dict=input_dict)

        if log:
            logger.logkv(prefix + 'RewardLossBefore', loss_before)
            logger.logkv(prefix + 'RewardLossAfter', loss_after)
Example #9
    def optimize_policy(self,
                        samples_data,
                        log=True,
                        prefix='',
                        verbose=False):
        """
        Performs MAML outer step

        Args:
            samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        input_dict = self._extract_input_dict(samples_data,
                                              self._optimization_keys,
                                              prefix='train')
        entropy_loss, reward_loss = self.optimizer.compute_loss_variations(
            input_dict, self.entropy_loss, self.reward_loss, self.log_values)

        if verbose: logger.log("Optimizing")

        # Update model
        loss_before = self.optimizer.optimize(input_val_dict=input_dict)

        if verbose: logger.log("Computing statistics")
        loss_after = self.optimizer.loss(input_val_dict=input_dict)

        if log:
            logger.logkv(prefix + 'Loss/LossBefore', loss_before)
            logger.logkv(prefix + 'Loss/LossAfter', loss_after)
            logger.logkv(prefix + 'Loss/PartialLossEntropy', entropy_loss)
            logger.logkv(prefix + 'Loss/PartialLossReward', reward_loss)
Example #10
    def _synch(self, dynamics_model_state_pickle):
        time_synch = time.time()
        if self.verbose:
            logger.log('Policy is synchronizing...')
        dynamics_model_state = pickle.loads(dynamics_model_state_pickle)
        assert isinstance(dynamics_model_state, dict)
        self.model_sampler.dynamics_model.set_shared_params(
            dynamics_model_state)
        if hasattr(self.model_sampler, 'vec_env'):
            self.model_sampler.vec_env.dynamics_model.set_shared_params(
                dynamics_model_state)
        time_synch = time.time() - time_synch

        logger.logkv('Policy-TimeSynch', time_synch)
Example #11
    def push(self):
        time_push = time.time()
        state_pickle = pickle.dumps(
            self.dynamics_model.get_shared_param_values())
        assert state_pickle is not None
        while self.queue_next.qsize() > 5:
            try:
                logger.log('Model is offloading data from queue_next...')
                _ = self.queue_next.get_nowait()
            except Empty:
                break
        self.queue_next.put(state_pickle)
        time_push = time.time() - time_push

        logger.logkv('Model-TimePush', time_push)
Example #12
    def push(self):
        time_push = time.time()
        policy_state_pickle = pickle.dumps(
            self.policy.get_shared_param_values())
        assert policy_state_pickle is not None
        while self.queue_next.qsize() > 5:
            try:
                logger.log('Policy is offloading data from queue_next...')
                _ = self.queue_next.get_nowait()
            except Empty:
                # very rare chance to reach here
                break
        self.queue_next.put(policy_state_pickle)
        time_push = time.time() - time_push

        logger.logkv('Policy-TimePush', time_push)
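Examples #10-#12 together form a pickle-over-queue handoff between workers; the pattern, condensed into a self-contained sketch (the function names are mine, the operations are the ones above):

import pickle
from queue import Empty

def push_state(queue_next, shared_params, max_backlog=5):
    # Producer side (cf. the push() methods above): drop stale states, enqueue the newest.
    state_pickle = pickle.dumps(shared_params)
    while queue_next.qsize() > max_backlog:
        try:
            queue_next.get_nowait()
        except Empty:
            break
    queue_next.put(state_pickle)

def synch_state(queue, set_shared_params):
    # Consumer side (cf. _synch() in Example #10): unpickle and apply the newest state.
    state = pickle.loads(queue.get())
    assert isinstance(state, dict)
    set_shared_params(state)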
Example #13
    def step(self):
        time_step = time.time()
        ''' --------------- MAML steps --------------- '''

        self.policy.switch_to_pre_update()  # Switch to pre-update policy
        all_samples_data = []

        for step in range(self.num_inner_grad_steps + 1):
            if self.verbose:
                logger.log("Policy Adaptation-Step %d **" % step)
            """ -------------------- Sampling --------------------------"""

            #time_sampling = time.time()
            paths = self.model_sampler.obtain_samples(log=True,
                                                      log_prefix='Policy-',
                                                      buffer=None)
            #time_sampling = time.time() - time_sampling
            """ ----------------- Processing Samples ---------------------"""

            #time_sample_proc = time.time()
            samples_data = self.model_sample_processor.process_samples(
                paths, log='all', log_prefix='Policy-')
            all_samples_data.append(samples_data)
            #time_sample_proc = time.time() - time_sample_proc

            self.log_diagnostics(sum(list(paths.values()), []),
                                 prefix='Policy-')
            """ ------------------- Inner Policy Update --------------------"""

            #time_algo_adapt = time.time()
            if step < self.num_inner_grad_steps:
                self.algo._adapt(samples_data)
            #time_algo_adapt = time.time() - time_algo_adapt
        """ ------------------ Outer Policy Update ---------------------"""

        if self.verbose:
            logger.log("Policy is optimizing...")
        # This needs to take all samples_data so that it can construct graph for meta-optimization.
        #time_algo_opt = time.time()
        self.algo.optimize_policy(all_samples_data, prefix='Policy-')
        #time_algo_opt = time.time() - time_algo_opt

        time_step = time.time() - time_step
        self.policy = self.model_sampler.policy

        logger.logkv('Policy-TimeStep', time_step)
Example #14
    def optimize_policy(self,
                        samples_data,
                        log=True,
                        prefix='',
                        verbose=False):
        """
        Performs MAML outer step

        Args:
            samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        input_dict = self._extract_input_dict(samples_data,
                                              self._optimization_keys,
                                              prefix='train')

        if verbose:
            logger.log("Computing KL before")
        mean_kl_before = self.optimizer.constraint_val(
            input_val_dict=input_dict)

        if verbose:
            logger.log("Computing loss before")
        loss_before = self.optimizer.loss(input_val_dict=input_dict)
        if verbose:
            logger.log("Optimizing")
        self.optimizer.optimize(input_val_dict=input_dict)
        if verbose:
            logger.log("Computing loss after")
        loss_after = self.optimizer.loss(input_val_dict=input_dict)

        if verbose:
            logger.log("Computing KL after")
        mean_kl = self.optimizer.constraint_val(input_val_dict=input_dict)
        if log:
            logger.logkv(prefix + 'MeanKLBefore', mean_kl_before)
            logger.logkv(prefix + 'MeanKL', mean_kl)

            logger.logkv(prefix + 'LossBefore', loss_before)
            logger.logkv(prefix + 'LossAfter', loss_after)
            logger.logkv(prefix + 'dLoss', loss_before - loss_after)
Example #15
    def step(self, random=False):
        time_step = time.time()
        '''------------- Obtaining samples from the environment -----------'''

        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random,
            log_prefix='Data-EnvSampler-',
        )
        '''-------------- Processing environment samples -------------------'''

        if self.verbose:
            logger.log("Data is processing samples...")
        if type(env_paths) is dict or type(env_paths) is OrderedDict:
            env_paths = list(env_paths.values())
            idxs = np.random.choice(range(len(env_paths)),
                                    size=self.num_rollouts_per_iter,
                                    replace=False)
            env_paths = sum([env_paths[idx] for idx in idxs], [])

        elif type(env_paths) is list:
            idxs = np.random.choice(range(len(env_paths)),
                                    size=self.num_rollouts_per_iter,
                                    replace=False)
            env_paths = [env_paths[idx] for idx in idxs]

        else:
            raise TypeError
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths,
            log=True,
            log_prefix='Data-EnvTrajs-',
        )

        self.samples_data_arr.append(samples_data)
        time_step = time.time() - time_step

        time_sleep = max(self.simulation_sleep - time_step, 0)
        time.sleep(time_sleep)

        logger.logkv('Data-TimeStep', time_step)
        logger.logkv('Data-TimeSleep', time_sleep)
Example #16
    def step(self):
        '''
        Pulls samples_data (accumulated in sequential order) from the server.
        Outcome: the dynamics model is updated with the pulled samples_data.
        '''
        time_fit_start = time.time()
        ''' --------------- fit dynamics model --------------- '''

        logger.log("Training dynamics model for %i epochs ..." %
                   (self.dynamics_model_max_epochs))
        dynamics_model, samples_data = self.server.pull.remote(
            ['dynamics_model', 'samples_data'])
        self.dynamics_model.fit(samples_data['observations'],
                                samples_data['actions'],
                                samples_data['next_observations'],
                                epochs=self.dynamics_model_max_epochs,
                                verbose=False,
                                log_tabular=True)

        self.server.push.remote('dynamics_data', ray.put(dynamics_model))
Example #17
    def step(self, obs=None, act=None, obs_next=None):
        time_model_fit = time.time()
        """ --------------- fit dynamics model --------------- """

        if self.verbose:
            logger.log(
                'Model at iteration {} is training for one epoch...'.format(
                    self.itr_counter))
        self.remaining_model_idx, self.valid_loss_rolling_average = self.dynamics_model.fit_one_epoch(
            remaining_model_idx=self.remaining_model_idx,
            valid_loss_rolling_average_prev=self.valid_loss_rolling_average,
            with_new_data=self.with_new_data,
            verbose=self.verbose,
            log_tabular=True,
            prefix='Model-',
        )
        self.with_new_data = False
        time_model_fit = time.time() - time_model_fit

        logger.logkv('Model-TimeStep', time_model_fit)
Example #18
    def optimize_supervised(self,
                            samples_data,
                            log=True,
                            prefix='',
                            verbose=False):
        input_dict = self._extract_input_dict(samples_data,
                                              self._optimization_keys,
                                              prefix='train')
        self.optimizer_s.compute_loss_variations(input_dict, None, None,
                                                 self.log_values_sup)

        if verbose: logger.log("Optimizing Supervised Model")
        loss_before = self.optimizer_s.optimize(input_val_dict=input_dict)

        if verbose: logger.log("Computing statistics")
        loss_after = self.optimizer_s.loss(input_val_dict=input_dict)

        if log:
            logger.logkv(prefix + 'SupervisedLossBefore', loss_before)
            logger.logkv(prefix + 'SupervisedLossAfter', loss_after)
Example #19
    def step(self):
        time_step = time.time()
        """ -------------------- Sampling --------------------------"""

        if self.verbose:
            logger.log("Policy is obtaining samples ...")
        paths = self.model_sampler.obtain_samples(log=True,
                                                  log_prefix='Policy-')
        """ ----------------- Processing Samples ---------------------"""

        if self.verbose:
            logger.log("Policy is processing samples ...")
        samples_data = self.model_sample_processor.process_samples(
            paths, log='all', log_prefix='Policy-')

        if type(paths) is list:
            self.log_diagnostics(paths, prefix='Policy-')
        else:
            self.log_diagnostics(sum(paths.values(), []), prefix='Policy-')
        """ ------------------ Policy Update ---------------------"""

        if self.verbose:
            logger.log("Policy optimization...")
        # This needs to take all samples_data so that it can construct graph for meta-optimization.
        self.algo.optimize_policy(samples_data,
                                  log=True,
                                  verbose=self.verbose,
                                  prefix='Policy-')

        self.policy = self.model_sampler.policy
        time_step = time.time() - time_step

        logger.logkv('Policy-TimeStep', time_step)
Example #20
    def step(self):
        """
        Uses self.model_sampler which is asynchrounously updated by worker_model.
        Outcome: policy is updated by PPO on one fictitious trajectory. 
        """

        itr_start_time = time.time()
        """ -------------------- Sampling --------------------------"""

        logger.log("Obtaining samples from the model...")
        time_env_sampling_start = time.time()
        model_sampler = self.server.pull.remote(['model_sampler'])[0]
        paths = model_sampler.obtain_samples(log=True, log_prefix='train-')
        sampling_time = time.time() - time_env_sampling_start
        """ ----------------- Processing Samples ---------------------"""

        logger.log("Processing samples from the model...")
        time_proc_samples_start = time.time()
        samples_data = self.model_sample_processor.process_samples(
            paths, log='all', log_prefix='train-')
        proc_samples_time = time.time() - time_proc_samples_start
        """ ------------------ Policy Update ---------------------"""

        logger.log("Optimizing policy...")
        # This needs to take all samples_data so that it can construct graph for meta-optimization.
        time_optimization_step_start = time.time()
        self.algo.optimize_policy(samples_data)
        optimization_time = time.time() - time_optimization_step_start

        #        times_dyn_sampling.append(sampling_time)
        #        times_dyn_sample_processing.append(proc_samples_time)
        #        times_optimization.append(optimization_time)
        #        times_step.append(time.time() - itr_start_time)

        return None
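Examples #2, #16 and #20 rely on a shared self.server with push/pull; one hypothetical way to build such a store as a Ray actor (a sketch, not the repo's implementation):

import ray

@ray.remote
class ParamServer:
    # Minimal key/value store shared between workers (illustrative only).
    def __init__(self):
        self._store = {}

    def push(self, key, value):
        self._store[key] = value

    def pull(self, keys):
        return [self._store.get(k) for k in keys]

A caller would typically wrap the remote calls with ray.get, e.g. samples_data = ray.get(server.pull.remote(['samples_data']))[0].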
Example #21
    def optimize_policy(self, all_samples_data, log=True):
        """
        Performs MAML outer step

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics

        Returns:
            None
        """
        meta_op_input_dict = self._extract_input_dict_meta_op(
            all_samples_data, self._optimization_keys)

        # add kl_coeffs / clip_eps to meta_op_input_dict
        meta_op_input_dict['inner_kl_coeff'] = self.inner_kl_coeff
        if self.clip_outer:
            meta_op_input_dict['clip_eps'] = self.clip_eps
        else:
            meta_op_input_dict['outer_kl_coeff'] = self.outer_kl_coeff

        if log: logger.log("Optimizing")
        loss_before = self.optimizer.optimize(
            input_val_dict=meta_op_input_dict)

        if log: logger.log("Computing statistics")
        loss_after, inner_kls, outer_kl = self.optimizer.compute_stats(
            input_val_dict=meta_op_input_dict)

        if self.adaptive_inner_kl_penalty:
            if log: logger.log("Updating inner KL loss coefficients")
            self.inner_kl_coeff = self.adapt_kl_coeff(self.inner_kl_coeff,
                                                      inner_kls,
                                                      self.target_inner_step)

        if self.adaptive_outer_kl_penalty:
            if log: logger.log("Updating outer KL loss coefficients")
            self.outer_kl_coeff = self.adapt_kl_coeff(self.outer_kl_coeff,
                                                      outer_kl,
                                                      self.target_outer_step)

        if log:
            logger.logkv('LossBefore', loss_before)
            logger.logkv('LossAfter', loss_after)
            logger.logkv('KLInner', np.mean(inner_kls))
            logger.logkv('KLCoeffInner', np.mean(self.inner_kl_coeff))
            if not self.clip_outer: logger.logkv('KLOuter', outer_kl)
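adapt_kl_coeff is not shown on this page; a common PPO-penalty-style adaptation rule (illustrative only, not necessarily what this repo implements) is:

def adapt_kl_coeff(kl_coeff, kl_values, kl_target):
    # Grow each coefficient when the measured KL overshoots the target band,
    # shrink it when it undershoots (1.5x band, factor-2 adjustment).
    scalar = not isinstance(kl_coeff, (list, tuple))
    coeffs = [kl_coeff] if scalar else list(kl_coeff)
    kls = [kl_values] if scalar else list(kl_values)
    new_coeffs = []
    for coeff, kl in zip(coeffs, kls):
        if kl > 1.5 * kl_target:
            coeff = coeff * 2.0
        elif kl < kl_target / 1.5:
            coeff = coeff * 0.5
        new_coeffs.append(coeff)
    return new_coeffs[0] if scalar else new_coeffs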
Example #22
    def train(self):
        """
        Trains policy on env using algo
        """
        worker_data_queue, worker_model_queue, worker_policy_queue = self.queues
        worker_data_remote, worker_model_remote, worker_policy_remote = self.remotes

        for p in self.ps:
            p.start()
        ''' --------------- worker warm-up --------------- '''

        logger.log('Prepare start...')

        worker_data_remote.send('prepare start')
        worker_data_queue.put(self.initial_random_samples)
        assert worker_data_remote.recv() == 'loop ready'

        worker_model_remote.send('prepare start')
        assert worker_model_remote.recv() == 'loop ready'

        worker_policy_remote.send('prepare start')
        assert worker_policy_remote.recv() == 'loop ready'

        time_total = time.time()
        ''' --------------- worker looping --------------- '''

        logger.log('Start looping...')
        for remote in self.remotes:
            remote.send('start loop')
        ''' --------------- collect info --------------- '''

        for remote in self.remotes:
            assert remote.recv() == 'loop done'
        logger.log('\n------------all workers exit loops -------------')
        for remote in self.remotes:
            assert remote.recv() == 'worker closed'

        for p in self.ps:
            p.terminate()

        logger.logkv('Trainer-TimeTotal', time.time() - time_total)
        logger.dumpkvs()
        logger.log("*****Training finished")
Example #23
    def step(self, random_sinusoid=(False, False)):
        time_step = time.time()

        if self.itr_counter == 1 and self.env_sampler.policy.dynamics_model.normalization is None:
            if self.verbose:
                logger.log('Data starts first step...')
            self.env_sampler.policy.dynamics_model = pickle.loads(
                self.queue.get())
            if self.verbose:
                logger.log('Data first step done...')
        '''------------- Obtaining samples from the environment -----------'''

        if self.verbose:
            logger.log("Data is obtaining samples...")
        env_paths = self.env_sampler.obtain_samples(
            log=True,
            random=random_sinusoid[0],
            sinusoid=random_sinusoid[1],
            log_prefix='Data-EnvSampler-',
        )
        '''-------------- Processing environment samples -------------------'''

        if self.verbose:
            logger.log("Data is processing samples...")
        samples_data = self.dynamics_sample_processor.process_samples(
            env_paths,
            log=True,
            log_prefix='Data-EnvTrajs-',
        )

        self.samples_data_arr.append(samples_data)
        time_step = time.time() - time_step

        time_sleep = max(self.simulation_sleep - time_step, 0)
        time.sleep(time_sleep)

        logger.logkv('Data-TimeStep', time_step)
        logger.logkv('Data-TimeSleep', time_sleep)
Example #24
    def process_queue(self):
        do_push = 0
        samples_data_arr = []
        while True:
            try:
                if not self.remaining_model_idx:
                    logger.log(
                        'Model at iteration {} is blocked waiting for data'.
                        format(self.itr_counter))
                    # FIXME: check stop_cond
                    time_wait = time.time()
                    samples_data_arr_pickle = self.queue.get()
                    time_wait = time.time() - time_wait
                    logger.logkv('Model-TimeBlockWait', time_wait)
                    self.remaining_model_idx = list(
                        range(self.dynamics_model.num_models))
                else:
                    if self.verbose:
                        logger.log('Model trying get_nowait...')
                    samples_data_arr_pickle = self.queue.get_nowait()
                if samples_data_arr_pickle == 'push':
                    # Only push once before executing another step
                    if do_push == 0:
                        do_push = 1
                        self.push()
                else:
                    samples_data_arr.extend(
                        pickle.loads(samples_data_arr_pickle))
            except Empty:
                break

        do_synch = len(samples_data_arr)
        if do_synch:
            self._synch(samples_data_arr)

        do_step = 1

        if self.verbose:
            logger.log(
                'Model finishes processing queue with {}, {}, {}......'.format(
                    do_push, do_synch, do_step))

        return do_push, do_synch, do_step
Example #25
    def fit(self,
            obs,
            ret,
            epochs=1000,
            compute_normalization=True,
            verbose=False,
            valid_split_ratio=None,
            rolling_average_persitency=None,
            log_tabular=False):
        """
        Fits the NN dynamics model
        :param obs: observations - numpy array of shape (n_samples, ndim_obs)
        :param act: actions - numpy array of shape (n_samples, ndim_act)
        :param obs_next: observations after taking action - numpy array of shape (n_samples, ndim_obs)
        :param epochs: number of training epochs
        :param compute_normalization: boolean indicating whether normalization shall be (re-)computed given the data
        :param valid_split_ratio: relative size of validation split (float between 0.0 and 1.0)
        :param verbose: logging verbosity
        """
        assert obs.ndim == 2 and obs.shape[1] == self.obs_space_dims

        if valid_split_ratio is None:
            valid_split_ratio = self.valid_split_ratio
        if rolling_average_persitency is None:
            rolling_average_persitency = self.rolling_average_persitency

        assert 1 > valid_split_ratio >= 0

        sess = tf.get_default_session()

        if (self.normalization is None
                or compute_normalization) and self.normalize_input:
            self.compute_normalization(obs, ret)

        if self.normalize_input:
            # normalize data
            obs, ret = self._normalize_data(obs, ret)
            assert obs.ndim == 2 and ret.ndim == 1

        # split into valid and test set

        obs_train, ret_train, obs_test, ret_test = train_test_split(
            obs, ret, test_split_ratio=valid_split_ratio)

        if self._dataset_test is None:
            self._dataset_test = dict(obs=obs_test, ret=ret_test)
            self._dataset_train = dict(obs=obs_train, ret=ret_train)

        else:
            # n_test_new_samples = len(obs_test)
            # n_max_test = self.buffer_size - n_test_new_samples
            # n_train_new_samples = len(obs_train)
            # n_max_train = self.buffer_size - n_train_new_samples
            # self._dataset_test['obs'] = np.concatenate([self._dataset_test['obs'][-n_max_test:], obs_test])
            # self._dataset_test['ret'] = np.concatenate([self._dataset_test['ret'][-n_max_test:], ret_test])
            #
            # self._dataset_train['obs'] = np.concatenate([self._dataset_train['obs'][-n_max_train:], obs_train])
            # self._dataset_train['ret'] = np.concatenate([self._dataset_train['ret'][-n_max_train:], ret_train])
            # FIXME: Hack so it always has on-policy samples

            self._dataset_test = dict(obs=obs_test, ret=ret_test)
            self._dataset_train = dict(obs=obs_train, ret=ret_train)

        # Create data queue
        if self.next_batch is None:
            self.next_batch, self.iterator = self._data_input_fn(
                self._dataset_train['obs'],
                self._dataset_train['ret'],
                batch_size=self.batch_size,
                buffer_size=self.buffer_size)

        if (self.normalization is None
                or compute_normalization) and self.normalize_input:
            self.compute_normalization(self._dataset_train['obs'],
                                       self._dataset_train['ret'])

        if self.normalize_input:
            # Normalize data
            obs_train, ret_train = self._normalize_data(
                self._dataset_train['obs'], self._dataset_train['ret'])
            assert obs.ndim == 2 and ret.ndim == 1
        else:
            obs_train, ret_train = self._dataset_train[
                'obs'], self._dataset_train['ret']

        valid_loss_rolling_average = None

        # Training loop
        for epoch in range(epochs):

            # initialize data queue
            sess.run(self.iterator.initializer,
                     feed_dict={
                         self.obs_dataset_ph: obs_train,
                         self.ret_dataset_ph: ret_train
                     })

            batch_losses = []
            while True:
                try:
                    obs_batch, ret_batch = sess.run(self.next_batch)

                    # run train op
                    batch_loss, _ = sess.run([self.loss, self.train_op],
                                             feed_dict={
                                                 self.obs_ph: obs_batch,
                                                 self.ret_ph: ret_batch
                                             })
                    batch_losses.append(batch_loss)

                except tf.errors.OutOfRangeError:

                    if self.normalize_input:
                        # normalize data
                        obs_test, ret_test = self._normalize_data(
                            self._dataset_test['obs'],
                            self._dataset_test['ret'])
                    else:
                        obs_test, ret_test = self._dataset_test[
                            'obs'], self._dataset_test['ret']

                    # compute validation loss
                    valid_loss = sess.run(self.loss,
                                          feed_dict={
                                              self.obs_ph: obs_test,
                                              self.ret_ph: ret_test
                                          })

                    if valid_loss_rolling_average is None:
                        valid_loss_rolling_average = 1.5 * valid_loss  # set the initial rolling average high to avoid stopping too early
                        valid_loss_rolling_average_prev = 2.0 * valid_loss

                    valid_loss_rolling_average = rolling_average_persitency * valid_loss_rolling_average + (
                        1.0 - rolling_average_persitency) * valid_loss

                    if verbose:
                        logger.log(
                            "Training NNDynamicsModel - finished epoch %i -- train loss: %.4f  valid loss: %.4f  valid_loss_mov_avg: %.4f"
                            % (epoch, float(np.mean(batch_losses)), valid_loss,
                               valid_loss_rolling_average))
                    break

            if valid_loss_rolling_average_prev < valid_loss_rolling_average:
                logger.log(
                    'Stopping DynamicsEnsemble training since valid_loss_rolling_average stopped decreasing'
                )
                break
            valid_loss_rolling_average_prev = valid_loss_rolling_average
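The early-stopping rule in the loop above can be isolated into a small helper; a minimal sketch under the same assumptions:

def update_valid_loss_rolling_average(valid_loss, rolling_average, persistency=0.99):
    # Exponential moving average of the validation loss; initialized above the
    # first observed value so the earliest epochs cannot trigger early stopping.
    if rolling_average is None:
        rolling_average = 1.5 * valid_loss
    return persistency * rolling_average + (1.0 - persistency) * valid_loss

Training then stops once the value for the current epoch exceeds the previous epoch's value, mirroring the valid_loss_rolling_average_prev check above.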
Example #26
    def train(self):
        """
        Trains policy on env using algo

        Pseudocode:
            for itr in n_itr:
                for step in num_inner_grad_steps:
                    sampler.sample()
                    algo.compute_updated_dists()
                algo.optimize_policy()
                sampler.update_goals()
        """
        with self.sess.as_default() as sess:

            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = [
                var for var in tf.global_variables()
                if not sess.run(tf.is_variable_initialized(var))
            ]
            sess.run(tf.variables_initializer(uninit_vars))

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                time_env_sampling_start = time.time()

                logger.log(
                    "Obtaining samples from the environment using the policy..."
                )
                env_paths = self.sampler.obtain_samples(log=True,
                                                        log_prefix='')

                logger.record_tabular('Time-EnvSampling',
                                      time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

                # first processing just for logging purposes
                time_env_samp_proc = time.time()
                samples_data = self.sample_processor.process_samples(
                    env_paths, log=True, log_prefix='EnvTrajs-')

                logger.record_tabular('Time-EnvSampleProc',
                                      time.time() - time_env_samp_proc)
                ''' --------------- fit dynamics model --------------- '''

                time_fit_start = time.time()

                logger.log("Training dynamics model for %i epochs ..." %
                           self.dynamics_model_max_epochs)
                self.dynamics_model.fit(samples_data['observations'],
                                        samples_data['actions'],
                                        samples_data['next_observations'],
                                        epochs=self.dynamics_model_max_epochs,
                                        verbose=False,
                                        log_tabular=True,
                                        early_stopping=True,
                                        compute_normalization=False)

                logger.log("Training the value function for %i epochs ..." %
                           self.vfun_max_epochs)
                self.value_function.fit(samples_data['observations'],
                                        samples_data['returns'],
                                        epochs=self.vfun_max_epochs,
                                        verbose=False,
                                        log_tabular=True,
                                        compute_normalization=False)

                logger.log("Training the policy ...")
                self.algo.optimize_policy(samples_data)

                logger.record_tabular('Time-ModelFit',
                                      time.time() - time_fit_start)
                """ ------------------- Logging Stuff --------------------------"""
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                self.log_diagnostics(env_paths, '')
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()
Example #27
    def train(self):
        """
        Trains policy on env using algo

        Pseudocode:
            for itr in n_itr:
                for step in num_inner_grad_steps:
                    sampler.sample()
                    algo.compute_updated_dists()
                algo.optimize_policy()
                sampler.update_goals()
        """
        with self.sess.as_default() as sess:

            # initialize uninitialized vars  (only initialize vars that were not loaded)
            # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            # sess.run(tf.variables_initializer(uninit_vars))
            sess.run(tf.global_variables_initializer())

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                time_env_sampling_start = time.time()

                if self.initial_random_samples and itr == 0:
                    logger.log(
                        "Obtaining random samples from the environment...")
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, random=True, log_prefix='Data-EnvSampler-')

                else:
                    logger.log(
                        "Obtaining samples from the environment using the policy..."
                    )
                    env_paths = self.env_sampler.obtain_samples(
                        log=True, log_prefix='Data-EnvSampler-')

                # Add sleeping time to match parallel experiment
                # time.sleep(10)

                logger.record_tabular('Data-TimeEnvSampling',
                                      time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")

                # first processing just for logging purposes
                time_env_samp_proc = time.time()

                samples_data = self.dynamics_sample_processor.process_samples(
                    env_paths, log=True, log_prefix='Data-EnvTrajs-')

                self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')

                logger.record_tabular('Data-TimeEnvSampleProc',
                                      time.time() - time_env_samp_proc)
                ''' --------------- fit dynamics model --------------- '''

                time_fit_start = time.time()

                self.dynamics_model.update_buffer(
                    samples_data['observations'],
                    samples_data['actions'],
                    samples_data['next_observations'],
                    check_init=True)

                buffer = None if not self.sample_from_buffer else samples_data

                logger.record_tabular('Model-TimeModelFit',
                                      time.time() - time_fit_start)
                ''' --------------- MAML steps --------------- '''
                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_optimization = []
                times_step = []
                remaining_model_idx = list(
                    range(self.dynamics_model.num_models))
                valid_loss_rolling_average_prev = None

                with_new_data = True
                for id_step in range(self.repeat_steps):

                    for epoch in range(self.num_epochs_per_step):
                        logger.log(
                            "Training dynamics model for %i epochs ..." % 1)
                        remaining_model_idx, valid_loss_rolling_average = self.dynamics_model.fit_one_epoch(
                            remaining_model_idx,
                            valid_loss_rolling_average_prev,
                            with_new_data,
                            log_tabular=True,
                            prefix='Model-')
                        with_new_data = False

                    for step in range(self.num_grad_policy_per_step):

                        logger.log(
                            "\n ---------------- Grad-Step %d ----------------"
                            % int(itr * self.repeat_steps *
                                  self.num_grad_policy_per_step + id_step *
                                  self.num_grad_policy_per_step + step))
                        step_start_time = time.time()
                        """ -------------------- Sampling --------------------------"""

                        logger.log("Obtaining samples from the model...")
                        time_env_sampling_start = time.time()
                        paths = self.model_sampler.obtain_samples(
                            log=True, log_prefix='Policy-', buffer=buffer)
                        sampling_time = time.time() - time_env_sampling_start
                        """ ----------------- Processing Samples ---------------------"""

                        logger.log("Processing samples from the model...")
                        time_proc_samples_start = time.time()
                        samples_data = self.model_sample_processor.process_samples(
                            paths, log='all', log_prefix='Policy-')
                        proc_samples_time = time.time(
                        ) - time_proc_samples_start

                        if type(paths) is list:
                            self.log_diagnostics(paths, prefix='Policy-')
                        else:
                            self.log_diagnostics(sum(paths.values(), []),
                                                 prefix='Policy-')
                        """ ------------------ Policy Update ---------------------"""

                        logger.log("Optimizing policy...")
                        # This needs to take all samples_data so that it can construct graph for meta-optimization.
                        time_optimization_step_start = time.time()
                        self.algo.optimize_policy(samples_data)
                        optimization_time = time.time(
                        ) - time_optimization_step_start

                        times_dyn_sampling.append(sampling_time)
                        times_dyn_sample_processing.append(proc_samples_time)
                        times_optimization.append(optimization_time)
                        times_step.append(time.time() - step_start_time)
                """ ------------------- Logging Stuff --------------------------"""
                logger.logkv('Iteration', itr)
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled)
                logger.logkv('Policy-TimeSampleProc',
                             np.sum(times_dyn_sample_processing))
                logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
                logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
                logger.logkv('Policy-TimeStep', np.sum(times_step))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

            logger.logkv('Trainer-TimeTotal', time.time() - start_time)

        logger.log("Training finished")
        self.sess.close()
Example #28
from meta_mb.logger import logger
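This bare import is how every snippet above obtains its logger; a minimal usage sketch (the keys and values below are illustrative, the calls themselves all appear in the examples on this page):

from meta_mb.logger import logger

logger.log("Starting iteration...")     # free-form log message
logger.logkv('Itr', 0)                  # accumulate tabular key/value pairs
logger.logkv('Policy-TimeStep', 1.23)
logger.dumpkvs()                        # flush the accumulated key/value pairs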
Example #29
    def train(self):
        """
        Trains policy on env using algo

        Pseudocode:
            for itr in n_itr:
                for step in num_inner_grad_steps:
                    sampler.sample()
                    algo.compute_updated_dists()
                algo.optimize_policy()
                sampler.update_goals()
        """
        start_time = time.time()
        rollout_time = 0
        success_rate = 0
        itrs_on_level = 0

        for itr in range(self.start_itr, self.n_itr):

            # high_entropy = False
            # # At the very beginning of a new level, have high entropy
            # if itrs_on_level < 3:
            #     high_entropy = True
            # # At the very beginning, have high entropy
            # if self.curriculum_step == 0 and itrs_on_level < 20:
            #     high_entropy = True
            # # If the model never succeeds
            # if itrs_on_level > 20 and success_rate < .2:
            #     high_entropy = True
            # logger.logkv("HighE", high_entropy)
            logger.logkv("ItrsOnLEvel", itrs_on_level)
            itrs_on_level += 1

            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" % itr)
            logger.log("Sampling set of tasks/goals for this meta-batch...")

            """ -------------------- Sampling --------------------------"""

            logger.log("Obtaining samples...")
            time_env_sampling_start = time.time()
            samples_data, episode_logs = self.algo.collect_experiences(self.teacher_train_dict)
            assert len(samples_data.action.shape) == 1, (samples_data.action.shape)
            time_collection = time.time() - time_env_sampling_start
            time_training_start = time.time()
            # if high_entropy:
            #     entropy = self.algo.entropy_coef# * 10
            # else:
            #     entropy = self.algo.entropy_coef
            summary_logs = self.algo.optimize_policy(samples_data, teacher_dict=self.teacher_train_dict)
            time_training = time.time() - time_training_start
            self._log(episode_logs, summary_logs, tag="Train")
            logger.logkv('Curriculum Step', self.curriculum_step)
            advance_curriculum = self.check_advance_curriculum(episode_logs, summary_logs)
            logger.logkv('Train/Advance', int(advance_curriculum))
            time_env_sampling = time.time() - time_env_sampling_start
            #
            # """ ------------------ Reward Predictor Splicing ---------------------"""
            rp_start_time = time.time()
            # samples_data = self.use_reward_predictor(samples_data)  # TODO: update
            rp_splice_time = time.time() - rp_start_time

            """ ------------------ End Reward Predictor Splicing ---------------------"""

            """ ------------------ Policy Update ---------------------"""

            logger.log("Optimizing policy...")
            # # This needs to take all samples_data so that it can construct graph for meta-optimization.
            time_rp_train_start = time.time()
            # self.train_rp(samples_data)
            time_rp_train = time.time() - time_rp_train_start

            """ ------------------ Distillation ---------------------"""
            if False:  # TODO: remove False once I've checked things work with dictionaries
            # if self.supervised_model is not None and advance_curriculum:
                time_distill_start = time.time()
                for _ in range(3):  # TODO: tune this!
                    distill_log = self.distill(samples_data, is_training=True)
                for k, v in distill_log.items():
                    logger.logkv(f"Distill/{k}_Train", v)
                distill_time = time.time() - time_distill_start
                advance_curriculum = distill_log['Accuracy'] >= self.accuracy_threshold
                logger.logkv('Distill/Advance', int(advance_curriculum))
            else:
                distill_time = 0

            """ ------------------ Policy rollouts ---------------------"""
            run_policy_time = 0
            if False:  # TODO: remove False once I've checked things work with dictionaries
            # if advance_curriculum or (itr % self.eval_every == 0) or (itr == self.n_itr - 1):  # TODO: collect rollouts with and without the teacher
                train_advance_curriculum = advance_curriculum
                with torch.no_grad():
                    if self.supervised_model is not None:
                        # Distilled model
                        time_run_supervised_start = time.time()
                        self.sampler.supervised_model.reset(dones=[True] * len(samples_data.obs))
                        logger.log("Running supervised model")
                        advance_curriculum_sup = self.run_supervised(self.il_trainer.acmodel, self.no_teacher_dict, "DRollout/")
                        run_supervised_time = time.time() - time_run_supervised_start
                    else:
                        run_supervised_time = 0
                        advance_curriculum_sup = True
                    # Original Policy
                    time_run_policy_start = time.time()
                    self.algo.acmodel.reset(dones=[True] * len(samples_data.obs))
                    logger.log("Running model with teacher")
                    advance_curriculum_policy = self.run_supervised(self.algo.acmodel, self.teacher_train_dict, "Rollout/")
                    run_policy_time = time.time() - time_run_policy_start

                    advance_curriculum = advance_curriculum_policy and advance_curriculum_sup and train_advance_curriculum
                    print("ADvancing curriculum???", advance_curriculum)

                    logger.logkv('Advance', int(advance_curriculum))
            else:
                run_supervised_time = 0
                # advance_curriculum = False

            """ ------------------- Logging Stuff --------------------------"""
            logger.logkv('Itr', itr)
            logger.logkv('n_timesteps', self.sampler.total_timesteps_sampled)

            logger.logkv('Time/Total', time.time() - start_time)
            logger.logkv('Time/Itr', time.time() - itr_start_time)

            logger.logkv('Curriculum Percent', self.curriculum_step / len(self.env.levels_list))

            process = psutil.Process(os.getpid())
            memory_use = process.memory_info().rss / float(2 ** 20)
            logger.logkv('Memory MiB', memory_use)

            logger.log(self.exp_name)

            logger.logkv('Time/Sampling', time_env_sampling)
            logger.logkv('Time/Training', time_training)
            logger.logkv('Time/Collection', time_collection)
            logger.logkv('Time/RPUse', rp_splice_time)
            logger.logkv('Time/RPTrain', time_rp_train)
            logger.logkv('Time/RunwTeacher', run_policy_time)
            logger.logkv('Time/Distillation', distill_time)
            logger.logkv('Time/RunDistilled', run_supervised_time)
            logger.logkv('Time/VidRollout', rollout_time)
            logger.dumpkvs()

            """ ------------------ Video Saving ---------------------"""

            should_save_video = (itr % self.save_videos_every == 0) or (itr == self.n_itr - 1) or advance_curriculum
            if should_save_video:
                # TODO: consider generating videos with every element in the powerset of feedback types
                time_rollout_start = time.time()
                if self.supervised_model is not None:
                    self.il_trainer.acmodel.reset(dones=[True])
                    self.save_videos(itr, self.il_trainer.acmodel, save_name='distilled_video_stoch',
                                   num_rollouts=5,
                                   teacher_dict=self.teacher_train_dict,
                                   save_video=should_save_video, log_prefix="DVidRollout/Stoch", stochastic=True)
                    self.il_trainer.acmodel.reset(dones=[True])
                    self.save_videos(itr, self.il_trainer.acmodel, save_name='distilled_video_det',
                                     num_rollouts=5,
                                   teacher_dict=self.no_teacher_dict,
                                     save_video=should_save_video, log_prefix="DVidRollout/Det", stochastic=False)

                self.algo.acmodel.reset(dones=[True])
                self.save_videos(self.curriculum_step, self.algo.acmodel, save_name='withTeacher_video_stoch',
                                 num_rollouts=5,
                                 teacher_dict=self.teacher_train_dict,
                                 save_video=should_save_video, log_prefix="VidRollout/Stoch", stochastic=True)
                self.algo.acmodel.reset(dones=[True])
                self.save_videos(self.curriculum_step, self.algo.acmodel, save_name='withTeacher_video_det',
                                 num_rollouts=5,
                                 teacher_dict=self.no_teacher_dict,
                                 save_video=should_save_video, log_prefix="VidRollout/Det", stochastic=False)
                rollout_time = time.time() - time_rollout_start
            else:
                rollout_time = 0

            params = self.get_itr_snapshot(itr)
            step = self.curriculum_step

            if self.log_and_save:
                if (itr % self.save_every == 0) or (itr == self.n_itr - 1) or advance_curriculum:
                    logger.log("Saving snapshot...")
                    logger.save_itr_params(itr, step, params)
                    logger.log("Saved")

            if advance_curriculum and self.advance_levels:
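                # Advance every component to the next curriculum level and reset the per-level iteration counter.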
                self.curriculum_step += 1
                self.sampler.advance_curriculum()
                self.algo.advance_curriculum()
                # self.algo.set_optimizer()
                itrs_on_level = 0


        logger.log("Training finished")
    def train(self):
        """
        Trains policy on env using algo

        Pseudocode:
            for itr in n_itr:
                for step in num_inner_grad_steps:
                    sampler.sample()
                    algo.compute_updated_dists()
                algo.optimize_policy()
                sampler.update_goals()
        """
        with self.sess.as_default() as sess:

            # Initialize all variables. The commented-out variant below would only
            # initialize variables that were not already loaded.
            # uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            # sess.run(tf.variables_initializer(uninit_vars))
            sess.run(tf.global_variables_initializer())

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)

                time_env_sampling_start = time.time()

                if itr == 0:
                    logger.log("Obtaining initial samples from the environment...")
                    self.env_sampler.total_samples *= self.num_rollouts_per_iter
                    env_paths = self.env_sampler.obtain_samples(
                        log=True,
                        random=self.initial_random_samples,
                        log_prefix='Data-EnvSampler-',
                        verbose=True)
                    self.env_sampler.total_samples /= self.num_rollouts_per_iter

                time_env_samp_proc = time.time()
                samples_data = self.dynamics_sample_processor.process_samples(
                    env_paths, log=True, log_prefix='Data-EnvTrajs-')
                self.env.log_diagnostics(env_paths, prefix='Data-EnvTrajs-')
                logger.record_tabular('Data-TimeEnvSampleProc',
                                      time.time() - time_env_samp_proc)

                buffer = samples_data if self.sample_from_buffer else None
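                # If sample_from_buffer is set, the freshly processed environment samples are handed to the
                # model sampler below, presumably to seed model rollouts with real states.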
                ''' --------------- fit dynamics model --------------- '''
                logger.log("Training dynamics model for %i epochs ..." %
                           self.dynamics_model_max_epochs)
                time_fit_start = time.time()
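                # Fit the dynamics model on (observation, action, next_observation) transitions from the real environment.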
                self.dynamics_model.fit(samples_data['observations'],
                                        samples_data['actions'],
                                        samples_data['next_observations'],
                                        epochs=self.dynamics_model_max_epochs,
                                        verbose=False,
                                        log_tabular=True,
                                        prefix='Model-')

                logger.record_tabular('Model-TimeModelFit',
                                      time.time() - time_fit_start)

                env_paths = []
                # Accumulate timing stats across all rollouts in this iteration.
                times_dyn_sampling = []
                times_dyn_sample_processing = []
                times_optimization = []
                times_step = []
                for id_rollout in range(self.num_rollouts_per_iter):

                    grad_steps_per_rollout = self.grad_steps_per_rollout
                    for step in range(grad_steps_per_rollout):

                        # logger.log("\n ---------------- Grad-Step %d ----------------" % int(grad_steps_per_rollout*itr*self.num_rollouts_per_iter,
                        #                                                             + id_rollout * grad_steps_per_rollout + step))
                        step_start_time = time.time()
                        """ -------------------- Sampling --------------------------"""

                        logger.log("Obtaining samples from the model...")
                        time_model_sampling_start = time.time()
                        paths = self.model_sampler.obtain_samples(
                            log=True, log_prefix='Policy-', buffer=buffer)
                        sampling_time = time.time() - time_model_sampling_start
                        """ ----------------- Processing Samples ---------------------"""

                        logger.log("Processing samples from the model...")
                        time_proc_samples_start = time.time()
                        samples_data = self.model_sample_processor.process_samples(
                            paths, log='all', log_prefix='Policy-')
                        proc_samples_time = time.time() - time_proc_samples_start

                        if type(paths) is list:
                            self.log_diagnostics(paths, prefix='Policy-')
                        else:
                            self.log_diagnostics(sum(paths.values(), []),
                                                 prefix='Policy-')
                        """ ------------------ Policy Update ---------------------"""

                        logger.log("Optimizing policy...")
                        time_optimization_step_start = time.time()
                        self.algo.optimize_policy(samples_data)
                        optimization_time = time.time() - time_optimization_step_start

                        times_dyn_sampling.append(sampling_time)
                        times_dyn_sample_processing.append(proc_samples_time)
                        times_optimization.append(optimization_time)
                        times_step.append(time.time() - step_start_time)

                    logger.log("Obtaining samples from the environment...")
                    env_paths.extend(
                        self.env_sampler.obtain_samples(
                            log=True,
                            log_prefix='Data-EnvSampler-',
                            verbose=True))

                logger.record_tabular('Data-TimeEnvSampling',
                                      time.time() - time_env_sampling_start)
                logger.log("Processing environment samples...")
                """ ------------------- Logging Stuff --------------------------"""
                logger.logkv('Iteration', itr)
                logger.logkv('n_timesteps',
                             self.env_sampler.total_timesteps_sampled)
                logger.logkv('Policy-TimeSampleProc',
                             np.sum(times_dyn_sample_processing))
                logger.logkv('Policy-TimeSampling', np.sum(times_dyn_sampling))
                logger.logkv('Policy-TimeAlgoOpt', np.sum(times_optimization))
                logger.logkv('Policy-TimeStep', np.sum(times_step))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                if itr == 0:
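                    # Freeze the TF graph after the first iteration so no new ops can be added inadvertently.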
                    sess.graph.finalize()

            logger.logkv('Trainer-TimeTotal', time.time() - start_time)

        logger.log("Training finished")
        self.sess.close()