예제 #1
0
    def _try_to_eval(self, epoch=0):
        """Run the end-of-epoch evaluation and flush the tabular log.

        Extra data is saved whenever ``epoch`` is a multiple of
        ``self.save_extra_data_interval``. When ``self._can_evaluate()`` is
        false, only a skip message is logged.

        :param epoch: Index of the epoch being evaluated.
        """
        if epoch % self.save_extra_data_interval == 0:
            extra_data = self.get_extra_data_to_save(epoch)
            logger.save_extra_data(extra_data, epoch)

        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        self.evaluate(epoch)

        snapshot = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, snapshot)

        # The set of tabular keys must be identical from one epoch to the next.
        current_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert current_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = current_keys

        # Attributes are looked up one at a time, right before being recorded.
        counters = (
            ("Number of train steps total", "_n_train_steps_total"),
            ("Number of env steps total", "_n_env_steps_total"),
            ("Number of rollouts total", "_n_rollouts_total"),
        )
        for label, attr_name in counters:
            logger.record_tabular(label, getattr(self, attr_name))

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
예제 #2
0
    def _try_to_eval(self, epoch):
        """Evaluate the current epoch and dump counters and timings.

        This variant neither saves extra data nor stores an epoch snapshot.
        When ``self._can_evaluate()`` is false, only a skip message is logged.

        :param epoch: Index of the epoch being evaluated.
        """
        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        self.evaluate(epoch)

        # The set of tabular keys must be identical from one epoch to the next.
        current_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert current_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = current_keys

        # Attributes are looked up one at a time, right before being recorded.
        counters = (
            ("Number of train steps total", "_n_train_steps_total"),
            ("Number of env steps total", "_n_env_steps_total"),
            ("Number of rollouts total", "_n_rollouts_total"),
        )
        for label, attr_name in counters:
            logger.record_tabular(label, getattr(self, attr_name))

        # Latest per-iteration stamp of each phase.
        stamps = gt.get_times().stamps.itrs
        train_secs = stamps['train'][-1]
        sample_secs = stamps['sample'][-1]
        # No eval stamp is read on the very first epoch.
        if epoch > 0:
            eval_secs = stamps['eval'][-1]
        else:
            eval_secs = 0
        epoch_secs = train_secs + sample_secs + eval_secs
        total_secs = gt.get_times().total

        timings = (
            ('Train Time (s)', train_secs),
            ('(Previous) Eval Time (s)', eval_secs),
            ('Sample Time (s)', sample_secs),
            ('Epoch Time (s)', epoch_secs),
            ('Total Train Time (s)', total_secs),
        )
        for label, seconds in timings:
            logger.record_tabular(label, seconds)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
예제 #3
0
    def _try_to_eval(self, epoch):
        """Evaluate the current epoch, snapshot it and dump the tabular log.

        Extra data is saved on every call. When ``self.environment_farming``
        is set, an environment is acquired from ``self.farmer`` and a fresh
        ``InPlacePathSampler`` is built around it for this evaluation; the
        environment is handed back to the farmer afterwards, even if the
        evaluation raises. Without farming, the existing ``self.eval_sampler``
        is used and the farmer is not touched.

        :param epoch: Index of the epoch being evaluated.
        """
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        if self.environment_farming:
            # Create a new eval_sampler for each evaluation to avoid the
            # released-environment problem.
            env_for_eval_sampler = self.farmer.force_acq_env()
            try:
                print(env_for_eval_sampler)
                self.eval_sampler = InPlacePathSampler(
                    env=env_for_eval_sampler,
                    policy=self.eval_policy,
                    max_samples=self.num_steps_per_eval + self.max_path_length,
                    max_path_length=self.max_path_length,
                )
                self.evaluate(epoch)
            finally:
                # Always put the env back on the free list so a failed
                # evaluation does not leak it.
                self.farmer.add_free_env(env_for_eval_sampler)
        else:
            # No env was acquired, so there is nothing to give back.
            self.evaluate(epoch)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)

        # The set of tabular keys must be identical from one epoch to the next.
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        # Latest per-iteration stamp of each phase; no eval stamp is read on
        # the very first epoch.
        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
예제 #4
0
    def _try_to_eval(self, epoch):
        """Evaluate the current epoch and dump counters and timings.

        Extra data and the epoch snapshot are only saved when ``epoch`` is a
        multiple of ``self.freq_saving``. When ``self._can_evaluate()`` is
        false, only a skip message is logged.

        :param epoch: Index of the epoch being evaluated.
        """
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))

        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        self.evaluate(epoch)

        if epoch % self.freq_saving == 0:
            snapshot = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, snapshot)

        # NOTE(review): the returned key set is not used in this variant; the
        # call is kept in case it has side effects -- confirm and remove.
        logger.get_table_key_set()

        # The train-step counter is not recorded in this variant.
        counters = (
            ("Number of env steps total", "_n_env_steps_total"),
            ("Number of rollouts total", "_n_rollouts_total"),
        )
        for label, attr_name in counters:
            logger.record_tabular(label, getattr(self, attr_name))

        # Latest per-iteration stamp of each phase.
        stamps = gt.get_times().stamps.itrs
        train_secs = stamps['train'][-1]
        sample_secs = stamps['sample'][-1]
        # No eval stamp is read on the very first epoch.
        if epoch > 0:
            eval_secs = stamps['eval'][-1]
        else:
            eval_secs = 0
        epoch_secs = train_secs + sample_secs + eval_secs
        total_secs = gt.get_times().total

        timings = (
            ('Train Time (s)', train_secs),
            ('(Previous) Eval Time (s)', eval_secs),
            ('Sample Time (s)', sample_secs),
            ('Epoch Time (s)', epoch_secs),
            ('Total Train Time (s)', total_secs),
        )
        for label, seconds in timings:
            logger.record_tabular(label, seconds)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
예제 #5
0
    def _try_to_eval(self, epoch):
        """Evaluate the current epoch, mirror the table to the writer, dump it.

        Extra data and an epoch snapshot are saved on every call. Every
        tabular entry except ``'Epoch'`` is converted to ``float`` and sent
        to ``self._writer.add_scalar`` under a tag derived from its key,
        before the table is dumped. The table-key consistency check is
        disabled in this variant; the key set is only remembered.

        :param epoch: Index of the epoch being evaluated; also used as the
            step passed to ``add_scalar``.
        """
        logger.save_extra_data(self.get_extra_data_to_save(epoch))

        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        self.evaluate(epoch)

        snapshot = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, snapshot)
        self._old_table_keys = logger.get_table_key_set()

        # Attributes are looked up one at a time, right before being recorded.
        counters = (
            ("Number of train steps total", "_n_train_steps_total"),
            ("Number of env steps total", "_n_env_steps_total"),
            ("Number of rollouts total", "_n_rollouts_total"),
        )
        for label, attr_name in counters:
            logger.record_tabular(label, getattr(self, attr_name))

        # Latest per-iteration stamp of each phase.
        stamps = gt.get_times().stamps.itrs
        train_secs = stamps['train'][-1]
        sample_secs = stamps['sample'][-1]
        # No eval stamp is read on the very first epoch.
        if epoch > 0:
            eval_secs = stamps['eval'][-1]
        else:
            eval_secs = 0
        epoch_secs = train_secs + sample_secs + eval_secs
        total_secs = gt.get_times().total

        timings = (
            ('Train Time (s)', train_secs),
            ('(Previous) Eval Time (s)', eval_secs),
            ('Sample Time (s)', sample_secs),
            ('Epoch Time (s)', epoch_secs),
            ('Total Train Time (s)', total_secs),
        )
        for label, seconds in timings:
            logger.record_tabular(label, seconds)

        logger.record_tabular("Epoch", epoch)

        # Mirror the pending tabular entries to the scalar writer, grouping
        # them by a tag derived from the key. Branch order matters: the
        # first matching rule wins.
        writer = self._writer
        for key, raw_value in logger._tabular:
            if key == 'Epoch':
                continue

            value = float(raw_value)
            if key.endswith('Loss'):
                tag = 'Loss/{}'.format(key)
            elif key.endswith(('Max', 'Min', 'Std')):
                # Drops the 3-letter suffix plus the character before it.
                tag = '{}/{}'.format(key[:-4], key)
            elif key.endswith('Mean'):
                # Drops the 4-letter suffix plus the character before it.
                tag = '{}/{}'.format(key[:-5], key)
            elif 'Time' in key:
                tag = 'Time/{}'.format(key)
            elif key.startswith('Num'):
                tag = 'Number/{}'.format(key)
            elif key.startswith('Exploration'):
                tag = 'Exploration/{}'.format(key)
            elif key.startswith('Test'):
                tag = 'Test/{}'.format(key)
            else:
                tag = key
            writer.add_scalar(tag, value, epoch)

        writer.file_writer.flush()

        logger.dump_tabular(with_prefix=False, with_timestamp=False)
예제 #6
0
    def _try_to_eval(self, epoch, eval_paths=None):
        """Save periodic artifacts, evaluate the epoch and dump the table.

        Saving (extra data every ``self.save_extra_data_interval`` epochs,
        parameters every ``self.num_epochs_per_param_save`` epochs) only
        happens when ``MPI`` is truthy and this process has rank 0.

        The reported train time is the sum of the latest value of every
        timer stamp except ``'sample'`` and ``'eval'``. Those two are
        reported separately and added on top to form the epoch time, so they
        are no longer counted twice.

        :param epoch: Index of the epoch being evaluated.
        :param eval_paths: Forwarded unchanged to ``self.evaluate``.
        :raises NotImplementedError: If the set of tabular keys differs from
            the previous call (header rewriting is not finished).
        """
        # NOTE(review): when MPI is falsy nothing is ever saved -- confirm
        # that this is intended for non-MPI runs.
        if MPI and MPI.COMM_WORLD.Get_rank() == 0:
            if epoch % self.save_extra_data_interval == 0:
                logger.save_extra_data(self.get_extra_data_to_save(epoch))

            if epoch % self.num_epochs_per_param_save == 0:
                print("Attemping itr param save...")
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
                print(F"Itr{epoch} param saved!")

        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        self.evaluate(epoch, eval_paths=eval_paths)

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs

        # Everything that is not sampling or evaluation counts as training.
        # Excluding those two keeps epoch_time below from double-counting.
        non_train_stamps = ('sample', 'eval')
        train_time = sum(
            times_itrs[name][-1]
            for name in times_itrs.keys()
            if name not in non_train_stamps
        )

        sample_time = times_itrs['sample'][-1]

        if epoch > 0:
            eval_time = times_itrs['eval'][-1]
        else:
            # Placeholder so the per-stamp loop below also emits an 'Eval'
            # column on the first epoch.
            times_itrs['eval'] = [0]
            eval_time = 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        # One column per timer stamp, titled after the stamp name.
        for name in times_itrs.keys():
            logger.record_tabular(name.title(), times_itrs[name][-1])

        logger.record_tabular('Train Time (s) ---', train_time)
        logger.record_tabular('(Previous) Eval Time (s) ---', eval_time)
        logger.record_tabular('Sample Time (s) ---', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)
        logger.record_tabular("Epoch", epoch)

        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None and table_keys != self._old_table_keys:
            # Changing keys is not supported yet: the header is updated and
            # the call then aborts.
            print("Table keys have changed. Rewriting header and filling with 0s")
            logger.update_header()
            raise NotImplementedError
        self._old_table_keys = table_keys

        logger.dump_tabular(with_prefix=False, with_timestamp=False)
예제 #7
0
    def _try_to_eval(self, epoch):
        """Evaluate on logging epochs and dump diagnostics, counters, timings.

        Nothing happens unless ``epoch`` is a multiple of
        ``self.logging_period``. Extra data is saved with cloudpickle for
        epochs listed in ``self.save_extra_manual_epoch_set`` and, when
        ``self._save_extra_every_epoch`` is set, on every logging epoch.
        Timer stamps are taken after the extra save, after evaluation and
        after the snapshot save.

        :param epoch: Index of the epoch being evaluated.
        """
        if epoch % self.logging_period != 0:
            return

        if epoch in self.save_extra_manual_epoch_set:
            manual_file_name = 'extra_snapshot_itr{}'.format(epoch)
            logger.save_extra_data(
                self.get_extra_data_to_save(epoch),
                file_name=manual_file_name,
                mode='cloudpickle',
            )
        if self._save_extra_every_epoch:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
        gt.stamp('save-extra')

        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        self.evaluate(epoch)
        gt.stamp('eval')

        snapshot = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('save-snapshot')

        # The set of tabular keys must be identical from one epoch to the next.
        current_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert current_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = current_keys

        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')

        # Attributes are looked up one at a time, right before being recorded.
        counters = (
            ("Number of train steps total", "_n_train_steps_total"),
            ("Number of env steps total", "_n_env_steps_total"),
            ("Number of rollouts total", "_n_rollouts_total"),
        )
        for label, attr_name in counters:
            logger.record_tabular(label, getattr(self, attr_name))

        # Latest per-iteration stamp of each phase.
        stamps = gt.get_times().stamps.itrs
        train_secs = stamps['train'][-1]
        sample_secs = stamps['sample'][-1]
        save_extra_secs = stamps['save-extra'][-1]
        save_snapshot_secs = stamps['save-snapshot'][-1]
        if epoch > 0:
            eval_secs = stamps['eval'][-1]
        else:
            eval_secs = 0
        # The snapshot-saving time is reported but not part of this sum.
        epoch_secs = train_secs + sample_secs + save_extra_secs + eval_secs
        total_secs = gt.get_times().total

        columns = (
            ('in_unsupervised_model', float(self.in_unsupervised_phase)),
            ('Train Time (s)', train_secs),
            ('(Previous) Eval Time (s)', eval_secs),
            ('Sample Time (s)', sample_secs),
            ('Save Extra Time (s)', save_extra_secs),
            ('Save Snapshot Time (s)', save_snapshot_secs),
            ('Epoch Time (s)', epoch_secs),
            ('Total Train Time (s)', total_secs),
        )
        for label, value in columns:
            logger.record_tabular(label, value)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)