示例#1
0
 def log_diagnostics(self,
                     itr,
                     player_eval_traj_infos,
                     observer_eval_traj_infos,
                     eval_time,
                     prefix='Diagnostics/'):
     """Record player/observer eval step and trajectory counts plus the
     cumulative eval time, then defer to the parent implementation.
     """
     if not player_eval_traj_infos:
         logger.log("WARNING: player had no complete trajectories in eval.")
     if not observer_eval_traj_infos:
         logger.log(
             "WARNING: observer had no complete trajectories in eval.")
     # Total environment steps over the completed eval trajectories.
     player_steps = sum(info["Length"] for info in player_eval_traj_infos)
     observer_steps = sum(info["Length"] for info in observer_eval_traj_infos)
     with logger.tabular_prefix(prefix):
         logger.record_tabular('PlayerStepsInEval', player_steps)
         logger.record_tabular('ObserverStepsInEval', observer_steps)
         logger.record_tabular('PlayerTrajsInEval',
                               len(player_eval_traj_infos))
         logger.record_tabular('ObserverTrajsInEval',
                               len(observer_eval_traj_infos))
         self._cum_eval_time += eval_time  # Accumulate across eval phases.
         logger.record_tabular('CumEvalTime', self._cum_eval_time)
     super().log_diagnostics(itr,
                             player_eval_traj_infos,
                             observer_eval_traj_infos,
                             eval_time,
                             prefix=prefix)
示例#2
0
 def log_diagnostics(self, itr, prefix='Diagnostics/'):
     """Log newly completed trajectories and total steps in the current
     trajectory window, then reset the new-trajectory counter.
     """
     window_steps = sum(info["Length"] for info in self._traj_infos)
     with logger.tabular_prefix(prefix):
         logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
         logger.record_tabular('StepsInTrajWindow', window_steps)
     super().log_diagnostics(itr, prefix=prefix)
     self._new_completed_trajs = 0
示例#3
0
    def _log_infos(self, traj_infos=None, prefix=''):
        """
        Writes trajectory info and optimizer info into csv via the logger.
        Resets stored optimizer info.
        """
        traj_infos = self._traj_infos if traj_infos is None else traj_infos
        if traj_infos:
            # Keys starting with "_" are internal bookkeeping; skip them.
            public_keys = [k for k in traj_infos[0] if not k.startswith("_")]
            with logger.tabular_prefix(prefix):
                for key in public_keys:
                    logger.record_tabular_misc_stat(
                        key, [info[key] for info in traj_infos])

        if self._opt_infos:
            with logger.tabular_prefix('Train_'):
                for key, values in self._opt_infos.items():
                    logger.record_tabular_misc_stat(key, values)
        # Reset: keep the same keys, drop the accumulated values.
        self._opt_infos = {key: [] for key in self._opt_infos}
示例#4
0
 def log_diagnostics(self, itr, sampler_itr, throttle_time, prefix='Diagnostics/'):
     """Log eval trajectory statistics and cumulative eval time, then clear
     the stored trajectory infos so the next eval starts fresh.
     """
     if not self._traj_infos:
         logger.log("WARNING: had no complete trajectories in eval.")
     eval_steps = sum(info["Length"] for info in self._traj_infos)
     with logger.tabular_prefix(prefix):
         logger.record_tabular('StepsInEval', eval_steps)
         logger.record_tabular('TrajsInEval', len(self._traj_infos))
         logger.record_tabular('CumEvalTime', self.ctrl.eval_time.value)
     super().log_diagnostics(itr, sampler_itr, throttle_time, prefix=prefix)
     self._traj_infos = []  # Clear after each eval.
示例#5
0
 def log_diagnostics(self, itr, eval_traj_infos, eval_time, prefix='Diagnostics/'):
     """Record eval step/trajectory counts, accumulate total eval time, and
     hand off to the parent implementation.
     """
     if not eval_traj_infos:
         logger.log("WARNING: had no complete trajectories in eval.")
     total_steps = sum(info["Length"] for info in eval_traj_infos)
     with logger.tabular_prefix(prefix):
         logger.record_tabular('StepsInEval', total_steps)
         logger.record_tabular('TrajsInEval', len(eval_traj_infos))
         self._cum_eval_time += eval_time  # Accumulate across eval phases.
         logger.record_tabular('CumEvalTime', self._cum_eval_time)
     super().log_diagnostics(itr, eval_traj_infos, eval_time, prefix=prefix)
示例#6
0
    def log_diagnostics(self,
                        itr,
                        sampler_itr,
                        throttle_time,
                        prefix='Diagnostics/'):
        """
        Write asynchronous-runner diagnostics (timing, throughput, replay
        ratios) to csv via the logger, then restart the progress bar.

        :param itr: Current optimizer iteration.
        :param sampler_itr: Current sampler iteration (decoupled from ``itr``
            in asynchronous mode).
        :param throttle_time: Seconds the optimizer spent throttled (waiting).
        :param prefix: Tabular key prefix for all recorded values.
        """
        self.pbar.stop()
        self.save_itr_snapshot(itr, sampler_itr)
        new_time = time.time()
        time_elapsed = new_time - self._last_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = self.sampler.batch_size * (sampler_itr -
                                                 self._last_sampler_itr)
        # Rates are undefined on the first iteration (no elapsed interval).
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              time_elapsed)
        if self._eval:
            # Subtract time spent in eval so sampling throughput can also be
            # reported over training time alone.
            new_eval_time = self.ctrl.eval_time.value
            eval_time_elapsed = new_eval_time - self._last_eval_time
            non_eval_time_elapsed = time_elapsed - eval_time_elapsed
            non_eval_samples_per_second = (float('nan') if itr == 0 else
                                           new_samples / non_eval_time_elapsed)
            self._last_eval_time = new_eval_time
        cum_steps = sampler_itr * self.sampler.batch_size  # No * world_size.
        # Ratio of samples trained on to samples collected (new / cumulative);
        # max(1, ...) guards the division at startup.
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        max(1, new_samples))
        cum_replay_ratio = (self.algo.update_counter * self.algo.batch_size *
                            self.world_size / max(1, cum_steps))

        with logger.tabular_prefix(prefix):
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('SamplerIteration', sampler_itr)
            logger.record_tabular('CumTime (s)', new_time - self._start_time)
            logger.record_tabular('CumSteps', cum_steps)
            logger.record_tabular('CumUpdates', self.algo.update_counter)
            logger.record_tabular('ReplayRatio', replay_ratio)
            logger.record_tabular('CumReplayRatio', cum_replay_ratio)
            logger.record_tabular('StepsPerSecond', samples_per_second)
            if self._eval:  # non_eval_samples_per_second only bound when _eval.
                logger.record_tabular('NonEvalSamplesPerSecond',
                                      non_eval_samples_per_second)
            logger.record_tabular('UpdatesPerSecond', updates_per_second)
            logger.record_tabular(
                'OptThrottle', (time_elapsed - throttle_time) / time_elapsed)

        self._log_infos()
        # Roll the bookkeeping forward for the next logging interval.
        self._last_time = new_time
        self._last_itr = itr
        self._last_sampler_itr = sampler_itr
        self._last_update_counter = self.algo.update_counter
        logger.dump_tabular(with_prefix=False)
        logger.log(f"Optimizing over {self.log_interval_itrs} sampler "
                   "iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
    def log_diagnostics(self,
                        itr,
                        traj_infos=None,
                        eval_time=0,
                        prefix='Diagnostics/'):
        """
        Write diagnostics (including stored ones) to csv via the logger.

        :param itr: Current training iteration.
        :param traj_infos: Optional trajectory infos forwarded to
            ``_log_trajectory_stats`` (stored infos used when ``None``).
        :param eval_time: Seconds spent in evaluation since the last call;
            subtracted so throughput reflects training time only.
        :param prefix: Tabular key prefix for all recorded values.
        """
        if itr > 0:
            self.pbar.stop()
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
                       self.log_interval_itrs)
        # Rates are undefined on the first iteration (no elapsed interval).
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              train_time_elapsed)
        # max(1, ...) guards against ZeroDivisionError under degenerate batch
        # settings, matching the asynchronous runner's replay-ratio guard.
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        max(1, new_samples))
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            ((itr + 1) * self.sampler.batch_size)
                            )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        # Add summaries etc
        with logger.tabular_prefix(prefix):
            if self._eval:
                logger.record_tabular(
                    'CumTrainTime', self._cum_time -
                    self._cum_eval_time)  # Already added new eval_time.
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CumTime (s)', self._cum_time)
            logger.record_tabular('CumSteps', cum_steps)
            logger.record_tabular('CumCompletedTrajs',
                                  self._cum_completed_trajs)
            logger.record_tabular('CumUpdates', self.algo.update_counter)
            logger.record_tabular('StepsPerSecond', samples_per_second)
            logger.record_tabular('UpdatesPerSecond', updates_per_second)
            logger.record_tabular('ReplayRatio', replay_ratio)
            logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        self._log_trajectory_stats(traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Roll the bookkeeping forward for the next logging interval.
        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
示例#8
0
 def log_diagnostics(self,
                     itr,
                     sampler_itr,
                     throttle_time,
                     prefix="Diagnostics/"):
     """Log cumulative/new completed-trajectory counts and the total steps
     in the trajectory window, then defer to the parent and reset the
     new-trajectory counter.
     """
     window_steps = sum(info["Length"] for info in self._traj_infos)
     with logger.tabular_prefix(prefix):
         logger.record_tabular("CumCompletedTrajs",
                               self._cum_completed_trajs)
         logger.record_tabular("NewCompletedTrajs",
                               self._new_completed_trajs)
         logger.record_tabular("StepsInTrajWindow", window_steps)
     super().log_diagnostics(itr, sampler_itr, throttle_time, prefix=prefix)
     self._new_completed_trajs = 0
    def shaping(samples, discount=0.995, warmup_itrs=1e3, log_interval=1e1):
        """
        Potential-based shaping reward from the weak agent's Q-values.

        Returns ``discount * max_a' Q(s', a') - max_a Q(s, a)`` per sample,
        or ``0`` while still inside the warm-up period.

        :param samples: Replay samples providing ``agent_inputs`` and
            ``target_inputs`` (observation, prev_action, prev_reward) --
            assumed shapes per the weak agent's interface; TODO confirm.
        :param discount: Discount applied to the next-state value
            (previously hard-coded 0.995).
        :param warmup_itrs: Iterations before shaping activates
            (previously the FIXME-flagged 1e3).
        :param log_interval: Record shaping stats every this many iterations
            (previously the FIXME-flagged 1e1).
        """
        # TODO eval/train mode here and in other places
        if logger._iteration <= warmup_itrs:
            return 0  # NOTE: int 0, not a tensor; callers must broadcast.

        with torch.no_grad():
            obs = (samples.agent_inputs.observation.to(device)
                   )  # TODO check if maybe better to keep it on cpu
            obsprim = samples.target_inputs.observation.to(device)
            qs = weak_agent(obs, samples.agent_inputs.prev_action.to(device),
                            samples.agent_inputs.prev_reward.to(device))
            qsprim = weak_agent(obsprim,
                                samples.target_inputs.prev_action.to(device),
                                samples.target_inputs.prev_reward.to(device))
            # Potential difference: discounted best next-Q minus best current-Q.
            vals = (discount * torch.max(qsprim, dim=1).values -
                    torch.max(qs, dim=1).values)

            if logger._iteration % log_interval == 0:
                with logger.tabular_prefix("Shaping"):
                    logger.record_tabular_misc_stat(
                        'ShapedReward',
                        vals.detach().cpu().numpy())
            return vals
示例#10
0
    def log_diagnostics(self,
                        itr,
                        player_traj_infos=None,
                        observer_traj_infos=None,
                        eval_time=0,
                        prefix='Diagnostics/'):
        """
        Write diagnostics (including stored ones) to csv via the logger.

        Player and observer agents are tracked separately, each with its own
        algo, update counter, and completed-trajectory count.  In serial mode
        the observer takes ``agent.n_serial`` steps per player step, which
        scales its sample counts and replay ratios accordingly.

        :param itr: Current training iteration.
        :param player_traj_infos: Optional player trajectory infos to log.
        :param observer_traj_infos: Optional observer trajectory infos to log.
        :param eval_time: Seconds spent in evaluation since the last call;
            subtracted so throughput reflects training time only.
        :param prefix: Tabular key prefix for all recorded values.
        """
        if itr > 0:
            self.pbar.stop()
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        player_new_updates = self.player_algo.update_counter - self._player_last_update_counter
        observer_new_updates = self.observer_algo.update_counter - self._observer_last_update_counter
        player_new_samples = (self.sampler.batch_size * self.world_size *
                              self.log_interval_itrs)
        # Serial agent: observer acts n_serial times per player step.
        if self.agent.serial:
            observer_new_samples = self.agent.n_serial * player_new_samples
        else:
            observer_new_samples = player_new_samples
        # Rates are undefined on the first iteration (no elapsed interval).
        player_updates_per_second = (float('nan') if itr == 0 else
                                     player_new_updates / train_time_elapsed)
        observer_updates_per_second = (float('nan')
                                       if itr == 0 else observer_new_updates /
                                       train_time_elapsed)
        player_samples_per_second = (float('nan') if itr == 0 else
                                     player_new_samples / train_time_elapsed)
        observer_samples_per_second = (float('nan')
                                       if itr == 0 else observer_new_samples /
                                       train_time_elapsed)
        # Ratio of samples trained on to samples collected, per agent.
        player_replay_ratio = (player_new_updates *
                               self.player_algo.batch_size * self.world_size /
                               player_new_samples)
        observer_replay_ratio = (observer_new_updates *
                                 self.observer_algo.batch_size *
                                 self.world_size / observer_new_samples)
        player_cum_replay_ratio = (
            self.player_algo.batch_size * self.player_algo.update_counter /
            ((itr + 1) * self.sampler.batch_size))  # world_size cancels.
        player_cum_steps = (itr +
                            1) * self.sampler.batch_size * self.world_size
        # Cumulative observer figures also scale by n_serial in serial mode.
        if self.agent.serial:
            observer_cum_replay_ratio = (
                self.observer_algo.batch_size *
                self.observer_algo.update_counter /
                ((itr + 1) * (self.agent.n_serial * self.sampler.batch_size))
            )  # world_size cancels.
            observer_cum_steps = self.agent.n_serial * player_cum_steps
        else:
            observer_cum_replay_ratio = (self.observer_algo.batch_size *
                                         self.observer_algo.update_counter /
                                         ((itr + 1) * self.sampler.batch_size)
                                         )  # world_size cancels.
            observer_cum_steps = player_cum_steps

        with logger.tabular_prefix(prefix):
            if self._eval:
                logger.record_tabular(
                    'CumTrainTime', self._cum_time -
                    self._cum_eval_time)  # Already added new eval_time.
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CumTime (s)', self._cum_time)
            logger.record_tabular('PlayerCumSteps', player_cum_steps)
            logger.record_tabular('ObserverCumSteps', observer_cum_steps)
            logger.record_tabular('PlayerCumCompletedTrajs',
                                  self._player_cum_completed_trajs)
            logger.record_tabular('ObserverCumCompletedTrajs',
                                  self._observer_cum_completed_trajs)
            logger.record_tabular('PlayerCumUpdates',
                                  self.player_algo.update_counter)
            logger.record_tabular('ObserverCumUpdates',
                                  self.observer_algo.update_counter)
            logger.record_tabular('PlayerStepsPerSecond',
                                  player_samples_per_second)
            logger.record_tabular('ObserverStepsPerSecond',
                                  observer_samples_per_second)
            logger.record_tabular('PlayerUpdatesPerSecond',
                                  player_updates_per_second)
            logger.record_tabular('ObserverUpdatesPerSecond',
                                  observer_updates_per_second)
            logger.record_tabular('PlayerReplayRatio', player_replay_ratio)
            logger.record_tabular('ObserverReplayRatio', observer_replay_ratio)
            logger.record_tabular('PlayerCumReplayRatio',
                                  player_cum_replay_ratio)
            logger.record_tabular('ObserverCumReplayRatio',
                                  observer_cum_replay_ratio)
        self._log_infos(player_traj_infos, observer_traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Roll the bookkeeping forward for the next logging interval.
        self._last_time = new_time
        self._player_last_update_counter = self.player_algo.update_counter
        self._observer_last_update_counter = self.observer_algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)