Ejemplo n.º 1
0
 def initialize_logging(self):
     """Set up trajectory-window logging state, then defer to the base class."""
     # Completion counters plus a sliding window of recent trajectory infos.
     self._cum_completed_trajs = 0
     self._new_completed_trajs = 0
     self._traj_infos = deque(maxlen=self.log_traj_window)
     logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
     super().initialize_logging()
     # Progress bar spans one logging interval.
     self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 2
0
 def initialize_logging(self):
     """Reset optimizer-info accumulators and timing state for a new run."""
     # One empty accumulator list per optimizer diagnostic field.
     self._opt_infos = {field: [] for field in self.algo.opt_info_fields}
     now = time.time()
     self._start_time = now
     self._last_time = now
     self._cum_time = 0.0
     if self.snapshot_gap_intervals is not None:
         # Snapshot gap is expressed in units of logging intervals.
         gap = self.snapshot_gap_intervals * self.log_interval_updates
         logger.set_snapshot_gap(gap)
     self.pbar = ProgBarCounter(self.log_interval_updates)
Ejemplo n.º 3
0
    def log_diagnostics(self, itr, sample_itr, throttle_time):
        """Write one row of diagnostics via the logger and restart the pbar.

        :param itr: optimizer iteration count.
        :param sample_itr: sampler iteration count (tracked separately from
            ``itr``, presumably because sampling runs asynchronously — TODO
            confirm against the enclosing runner).
        :param throttle_time: time the optimizer spent throttled (waiting).
        """
        self.pbar.stop()
        self.save_itr_snapshot(itr, sample_itr)
        new_time = time.time()
        time_elapsed = new_time - self._last_time
        # Rates are undefined on the very first iteration; log NaN instead.
        samples_per_second = (float('nan') if itr == 0 else
            self.log_interval_itrs * self.itr_batch_size / time_elapsed)
        updates_per_second = (float('nan') if itr == 0 else
            self.algo.updates_per_optimize * (itr - self._last_itr) / time_elapsed)
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('SamplerIteration', sample_itr)
        logger.record_tabular('CumTime (s)', new_time - self._start_time)
        logger.record_tabular('CumSteps', sample_itr * self.itr_batch_size)
        logger.record_tabular('CumUpdates', itr * self.algo.updates_per_optimize)
        logger.record_tabular('SamplesPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        # Fraction of the elapsed interval actually spent optimizing
        # (i.e. not throttled).
        logger.record_tabular('OptThrottle', (time_elapsed - throttle_time) /
            time_elapsed)

        self._log_infos()
        # Bookkeeping for the next interval's time/iteration deltas.
        self._last_time = new_time
        self._last_itr = itr
        logger.dump_tabular(with_prefix=False)
        logger.log(f"Optimizing over {self.log_interval_itrs} sampler "
            "iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 4
0
    def log_diagnostics(self, itr, eval_traj_infos, eval_time):
        """
        Write diagnostics (including stored ones) to csv via the logger.
        ONE NEW LINE VS REGULAR RUNNER, TO LOG ENV STEPS = steps*frame_skip

        :param itr: current iteration number.
        :param eval_traj_infos: trajectory infos collected during evaluation.
        :param eval_time: seconds spent evaluating (excluded from the
            training-rate computations below).
        """
        if not eval_traj_infos:
            logger.log("WARNING: had no complete trajectories in eval.")
        steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
        logger.record_tabular("StepsInEval", steps_in_eval)
        logger.record_tabular("TrajsInEval", len(eval_traj_infos))
        self._cum_eval_time += eval_time
        logger.record_tabular("CumEvalTime", self._cum_eval_time)

        if itr > 0:
            self.pbar.stop()
        # Only snapshot once learning has begun (min_itr_learn).
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        # Training time excludes the eval time that just elapsed.
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = self.sampler.batch_size * self.world_size * self.log_interval_itrs
        # Rates are undefined on the first iteration; log NaN.
        updates_per_second = (
            float("nan") if itr == 0 else new_updates / train_time_elapsed
        )
        samples_per_second = (
            float("nan") if itr == 0 else new_samples / train_time_elapsed
        )
        # Data consumed by gradient updates relative to new samples gathered.
        replay_ratio = (
            new_updates * self.algo.batch_size * self.world_size / new_samples
        )
        cum_replay_ratio = (
            self.algo.batch_size
            * self.algo.update_counter
            / ((itr + 1) * self.sampler.batch_size)
        )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        if self._eval:
            logger.record_tabular(
                "CumTrainTime", self._cum_time - self._cum_eval_time
            )  # Already added new eval_time.
        logger.record_tabular("Iteration", itr)
        logger.record_tabular("CumTime (s)", self._cum_time)
        logger.record_tabular("CumSteps", cum_steps)
        logger.record_tabular("EnvSteps", cum_steps * self._frame_skip)  # NEW LINE
        logger.record_tabular("CumCompletedTrajs", self._cum_completed_trajs)
        logger.record_tabular("CumUpdates", self.algo.update_counter)
        logger.record_tabular("StepsPerSecond", samples_per_second)
        logger.record_tabular("UpdatesPerSecond", updates_per_second)
        logger.record_tabular("ReplayRatio", replay_ratio)
        logger.record_tabular("CumReplayRatio", cum_replay_ratio)
        self._log_infos(eval_traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Bookkeeping for the next interval's deltas.
        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 5
0
    def log_diagnostics(self, itr, eval_traj_infos, eval_time):
        """Log evaluation stats plus cumulative train/eval/total timing.

        :param itr: current iteration number.
        :param eval_traj_infos: trajectory infos gathered during evaluation.
        :param eval_time: seconds spent evaluating this interval.
        """
        self.save_itr_snapshot(itr)
        if not eval_traj_infos:
            logger.log("WARNING: had no complete trajectories in eval.")
        steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumSteps', itr * self.itr_batch_size)
        logger.record_tabular('StepsInEval', steps_in_eval)
        logger.record_tabular('TrajsInEval', len(eval_traj_infos))

        self._log_infos(eval_traj_infos)

        new_time = time.time()
        # Split the interval's wall time into train vs. eval portions.
        log_interval_time = new_time - self._last_time
        new_train_time = log_interval_time - eval_time
        self.cum_train_time += new_train_time
        self.cum_eval_time += eval_time
        self.cum_total_time += log_interval_time
        self._last_time = new_time
        # Sampling rate is undefined on the first iteration; log NaN.
        samples_per_second = (float('nan')
                              if itr == 0 else self.log_interval_itrs *
                              self.itr_batch_size / new_train_time)

        logger.record_tabular('CumTrainTime', self.cum_train_time)
        logger.record_tabular('CumEvalTime', self.cum_eval_time)
        logger.record_tabular('CumTotalTime', self.cum_total_time)
        logger.record_tabular('SamplesPerSecond', samples_per_second)

        logger.dump_tabular(with_prefix=False)

        logger.log(f"optimizing over {self.log_interval_itrs} iterations")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 6
0
    def log_diagnostics(self,
                        itr,
                        sampler_itr,
                        throttle_time,
                        prefix='Diagnostics/'):
        """Write asynchronous-runner diagnostics under a tabular prefix.

        :param itr: optimizer iteration count.
        :param sampler_itr: sampler iteration count (advances independently
            of ``itr``).
        :param throttle_time: time the optimizer spent waiting (throttled).
        :param prefix: tabular key prefix for the diagnostics section.
        """
        self.pbar.stop()
        self.save_itr_snapshot(itr, sampler_itr)
        new_time = time.time()
        time_elapsed = new_time - self._last_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = self.sampler.batch_size * (sampler_itr -
                                                 self._last_sampler_itr)
        # Rates are undefined on the first iteration; log NaN.
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              time_elapsed)
        if self._eval:
            # Subtract evaluation time to also report a pure sampling rate.
            new_eval_time = self.ctrl.eval_time.value
            eval_time_elapsed = new_eval_time - self._last_eval_time
            non_eval_time_elapsed = time_elapsed - eval_time_elapsed
            non_eval_samples_per_second = (float('nan') if itr == 0 else
                                           new_samples / non_eval_time_elapsed)
            self._last_eval_time = new_eval_time
        cum_steps = sampler_itr * self.sampler.batch_size  # No * world_size.
        # max(1, ...) guards against divide-by-zero before any samples exist.
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        max(1, new_samples))
        cum_replay_ratio = (self.algo.update_counter * self.algo.batch_size *
                            self.world_size / max(1, cum_steps))

        with logger.tabular_prefix(prefix):
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('SamplerIteration', sampler_itr)
            logger.record_tabular('CumTime (s)', new_time - self._start_time)
            logger.record_tabular('CumSteps', cum_steps)
            logger.record_tabular('CumUpdates', self.algo.update_counter)
            logger.record_tabular('ReplayRatio', replay_ratio)
            logger.record_tabular('CumReplayRatio', cum_replay_ratio)
            logger.record_tabular('StepsPerSecond', samples_per_second)
            if self._eval:
                logger.record_tabular('NonEvalSamplesPerSecond',
                                      non_eval_samples_per_second)
            logger.record_tabular('UpdatesPerSecond', updates_per_second)
            # Fraction of the interval actually spent optimizing.
            logger.record_tabular(
                'OptThrottle', (time_elapsed - throttle_time) / time_elapsed)

        self._log_infos()
        # Bookkeeping for the next interval's deltas.
        self._last_time = new_time
        self._last_itr = itr
        self._last_sampler_itr = sampler_itr
        self._last_update_counter = self.algo.update_counter
        logger.dump_tabular(with_prefix=False)
        logger.log(f"Optimizing over {self.log_interval_itrs} sampler "
                   "iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
    def log_diagnostics(self,
                        itr,
                        traj_infos=None,
                        eval_time=0,
                        prefix='Diagnostics/'):
        """
        Write diagnostics (including stored ones) to csv via the logger.

        :param itr: current iteration number.
        :param traj_infos: trajectory infos forwarded to the trajectory-stats
            logger.
        :param eval_time: seconds spent evaluating this interval (excluded
            from the training-rate computations).
        :param prefix: tabular key prefix for the diagnostics section.
        """
        if itr > 0:
            self.pbar.stop()
        # Only snapshot once learning has begun (min_itr_learn).
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
                       self.log_interval_itrs)
        # Rates are undefined on the first iteration; log NaN.
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              train_time_elapsed)
        # Data consumed by gradient updates relative to new samples gathered.
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        new_samples)
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            ((itr + 1) * self.sampler.batch_size)
                            )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        # Add summaries etc
        with logger.tabular_prefix(prefix):
            if self._eval:
                logger.record_tabular(
                    'CumTrainTime', self._cum_time -
                    self._cum_eval_time)  # Already added new eval_time.
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CumTime (s)', self._cum_time)
            logger.record_tabular('CumSteps', cum_steps)
            logger.record_tabular('CumCompletedTrajs',
                                  self._cum_completed_trajs)
            logger.record_tabular('CumUpdates', self.algo.update_counter)
            logger.record_tabular('StepsPerSecond', samples_per_second)
            logger.record_tabular('UpdatesPerSecond', updates_per_second)
            logger.record_tabular('ReplayRatio', replay_ratio)
            logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        self._log_trajectory_stats(traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Bookkeeping for the next interval's deltas.
        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 8
0
 def log_diagnostics(self, itr, val_info, *args, **kwargs):
     """Snapshot and log training/validation statistics for pretraining.

     :param itr: current update (iteration) number.
     :param val_info: namedtuple of validation metrics; each field is logged
         with a ``val_`` prefix.
     """
     self.save_itr_snapshot(itr)
     new_time = time.time()
     self._cum_time = new_time - self._start_time
     # Epochs = fraction of the non-validation replay data consumed so far.
     epochs = itr * self.algo.batch_size / (
         self.algo.replay_buffer.size * (1 - self.algo.validation_split))
     logger.record_tabular("Iteration", itr)
     logger.record_tabular("Epochs", epochs)
     logger.record_tabular("CumTime (s)", self._cum_time)
     logger.record_tabular("UpdatesPerSecond", itr / self._cum_time)
     if self._opt_infos:
         for k, v in self._opt_infos.items():
             logger.record_tabular_misc_stat(k, v)
     for k, v in zip(val_info._fields, val_info):
         logger.record_tabular_misc_stat("val_" + k, v)
     self._opt_infos = {k: list() for k in self._opt_infos}  # (reset)
     logger.dump_tabular(with_prefix=False)
     if itr < self.n_updates - 1:
         logger.log(
             f"Optimizing over {self.log_interval_updates} iterations.")
         self.pbar = ProgBarCounter(self.log_interval_updates)
Ejemplo n.º 9
0
    def log_diagnostics(self, itr, traj_infos=None, eval_time=0):
        """
        Record diagnostic information (write logs); also saves the model
        parameters and other snapshot state.

        :param itr: which iteration this is.
        :param traj_infos: trajectory infos forwarded to the stats logger.
        :param eval_time: seconds spent evaluating this interval.
        """
        if itr > 0:
            self.pbar.stop()  # Stop updating the progress bar.
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
                       self.log_interval_itrs)
        # Rates are undefined on the first iteration; log NaN.
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              train_time_elapsed)
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        new_samples)
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            ((itr + 1) * self.sampler.batch_size)
                            )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        # Write some extra statistics into the log.
        if self._eval:
            logger.record_tabular(
                'CumTrainTime', self._cum_time -
                self._cum_eval_time)  # Already added new eval_time.
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumTime (s)', self._cum_time)
        logger.record_tabular('CumSteps', cum_steps)
        logger.record_tabular(
            'CumCompletedTrajs',
            self._cum_completed_trajs)  # Only counts trajs flagged "traj_done".
        logger.record_tabular('CumUpdates', self.algo.update_counter)
        logger.record_tabular('StepsPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        logger.record_tabular('ReplayRatio', replay_ratio)
        logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        self._log_infos(traj_infos)
        logger.dump_tabular(with_prefix=False)  # Write the log file.

        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)  # Progress bar.
Ejemplo n.º 10
0
    def log_diagnostics(self, itr):
        """Dump one interval's worth of online training diagnostics."""
        self.pbar.stop()
        self.save_itr_snapshot(itr)
        now = time.time()
        interval_steps = self.log_interval_itrs * self.itr_batch_size
        samples_per_second = interval_steps / (now - self._last_time)
        window_steps = sum(info["Length"] for info in self._traj_infos)
        # Emit the tabular row in a fixed key order.
        for key, value in (
                ('Iteration', itr),
                ('CumSteps', (itr + 1) * self.itr_batch_size),
                ('CumTime (s)', now - self._start_time),
                ('SamplesPerSecond', samples_per_second),
                ('CumCompletedTrajs', self._cum_completed_trajs),
                ('NewCompletedTrajs', self._new_completed_trajs),
                ('StepsInTrajWindow', window_steps),
        ):
            logger.record_tabular(key, value)
        self._log_infos()

        self._last_time = now
        logger.dump_tabular(with_prefix=False)

        # Reset the per-interval counter; restart the bar unless finished.
        self._new_completed_trajs = 0
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 11
0
    def log_diagnostics(self, itr, traj_infos=None, eval_time=0):
        """Write diagnostics to the logger, including cumulative simulator
        ("pyflex") steps read from each trajectory's final env_info.

        :param itr: current iteration number.
        :param traj_infos: trajectory infos; each is expected to carry
            'env_infos' whose last entry has a 'total_steps' attribute.
        :param eval_time: seconds spent evaluating this interval.
        """
        if itr > 0:
            self.pbar.stop()
        self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
                       self.log_interval_itrs)
        # Rates are undefined on the first iteration; log NaN.
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              train_time_elapsed)
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        new_samples)
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            ((itr + 1) * self.sampler.batch_size)
                            )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        if self._eval:
            logger.record_tabular(
                'CumTrainTime', self._cum_time -
                self._cum_eval_time)  # Already added new eval_time.
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumTime (s)', self._cum_time)
        logger.record_tabular('CumSteps', cum_steps)
        logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
        logger.record_tabular('CumUpdates', self.algo.update_counter)
        logger.record_tabular('StepsPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        logger.record_tabular('ReplayRatio', replay_ratio)
        logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        # Accumulate underlying-simulator steps from each trajectory's final
        # env_info entry.
        self._cum_pyflex_steps += sum(
            getattr(info['env_infos'][-1], 'total_steps')
            for info in traj_infos)
        logger.record_tabular('CumPyflexSteps', self._cum_pyflex_steps)
        self._log_infos(traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Bookkeeping for the next interval's deltas.
        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 12
0
def do_training_mt(loader, model, opt, dev, aug_model, min_bc_module,
                   n_batches):
    """Run ``n_batches`` steps of multi-task behavioral-cloning training.

    :param loader: data loader yielding dicts with 'obs' and 'acts' entries.
    :param model: policy network; called as ``model(obs, task_ids=...)`` and
        expected to return (logits, value); the value output is unused here.
    :param opt: optimizer stepping ``model``'s parameters.
    :param dev: torch device each batch is moved onto.
    :param aug_model: optional augmentation module applied to observations.
    :param min_bc_module: optional module producing per-sample loss weights.
    :param n_batches: number of optimization steps to perform.
    :return: tuple ``(loss_ewma, losses, per_task_losses)`` — EWMA of the
        per-batch mean loss, list of per-batch mean losses, and a dict of
        per-(task, variant) mean-loss lists.
    """
    # @torch.jit.script
    def do_loss_forward_back(obs_batch_obs, obs_batch_task, obs_batch_var,
                             obs_batch_source, acts_batch):
        # We don't use the value output.
        logits_flat, _ = model(obs_batch_obs, task_ids=obs_batch_task)
        losses = F.cross_entropy(logits_flat,
                                 acts_batch.long(),
                                 reduction='none')
        if min_bc_module is not None:
            # Weight using a model-dependent strategy.
            mbc_weights = min_bc_module(obs_batch_task, obs_batch_var,
                                        obs_batch_source)
            assert mbc_weights.shape == losses.shape, (mbc_weights.shape,
                                                       losses.shape)
            loss = (losses * mbc_weights).sum()
        else:
            # No weighting.
            loss = losses.mean()
        loss.backward()
        return losses.detach().cpu().numpy()

    # Make sure we're in train mode.
    model.train()

    # For logging.
    loss_ewma = None
    losses = []
    per_task_losses = collections.defaultdict(list)
    progress = ProgBarCounter(n_batches)
    inf_batch_iter = repeat_dataset(loader)
    # FIX: range(1, n_batches) ran only n_batches - 1 steps even though the
    # progress bar was sized for n_batches; include the final batch.
    ctr_batch_iter = zip(range(1, n_batches + 1), inf_batch_iter)
    for batches_done, loader_batch in ctr_batch_iter:
        # (task_ids_batch, obs_batch, acts_batch)
        # Copy to the target device.
        obs_batch = loader_batch['obs']
        acts_batch = loader_batch['acts']
        # Reminder: attributes are .observation, .task_id, .variant_id.
        obs_batch = tree_map(lambda t: t.to(dev), obs_batch)
        acts_batch = acts_batch.to(dev)

        if aug_model is not None:
            # Apply augmentations to the observations only.
            obs_batch = obs_batch._replace(
                observation=aug_model(obs_batch.observation))

        # Compute loss & take opt step.
        opt.zero_grad()
        batch_losses = do_loss_forward_back(obs_batch.observation,
                                            obs_batch.task_id,
                                            obs_batch.variant_id,
                                            obs_batch.source_id, acts_batch)
        opt.step()

        # For logging.
        progress.update(batches_done)
        f_loss = np.mean(batch_losses)
        # Exponentially-weighted moving average of the per-batch mean loss.
        loss_ewma = f_loss if loss_ewma is None \
            else 0.9 * loss_ewma + 0.1 * f_loss
        losses.append(f_loss)

        # Also track losses separately for each (task, variant) pair.
        tv_ids = torch.stack((obs_batch.task_id, obs_batch.variant_id), axis=1)
        np_tv_ids = tv_ids.cpu().numpy()
        assert len(np_tv_ids.shape) == 2 and np_tv_ids.shape[1] == 2, \
            np_tv_ids.shape
        for tv_id in np.unique(np_tv_ids, axis=0):
            tv_mask = np.all(np_tv_ids == tv_id[None], axis=-1)
            rel_losses = batch_losses[tv_mask]
            if len(rel_losses) > 0:
                task_id, variant_id = tv_id
                per_task_losses[(task_id, variant_id)] \
                    .append(np.mean(rel_losses))

    progress.stop()

    return loss_ewma, losses, per_task_losses
Ejemplo n.º 13
0
 def initialize_logging(self):
     """Reset trajectory and eval-time bookkeeping, then run base-class setup."""
     self._last_eval_time = 0.0
     self._traj_infos = []
     super().initialize_logging()
     # Progress bar counts iterations within one logging interval.
     self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 14
0
class UnsupervisedLearning(BaseRunner):
    """Runner for unsupervised pretraining: repeatedly calls ``algo.optimize``
    (no environment sampler) and periodically validates, snapshots, and logs.
    """

    def __init__(
            self,
            algo,
            n_updates,
            seed=None,
            affinity=None,
            log_interval_updates=1e3,
            snapshot_gap_intervals=None,  # units: log_intervals
    ):
        # save__init__args stores every constructor argument onto self.
        n_updates = int(n_updates)
        affinity = dict() if affinity is None else affinity
        save__init__args(locals())

    def startup(self):
        """Set CPU affinity and torch threads, seed, init algo and logging."""
        p = psutil.Process()
        try:
            if self.affinity.get("master_cpus",
                                 None) is not None and self.affinity.get(
                                     "set_affinity", True):
                p.cpu_affinity(self.affinity["master_cpus"])
            cpu_affin = p.cpu_affinity()
        except AttributeError:
            # psutil lacks cpu_affinity on some platforms (e.g. MacOS).
            cpu_affin = "UNAVAILABLE MacOS"
        logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
                   f"{cpu_affin}.")
        if self.affinity.get("master_torch_threads", None) is not None:
            torch.set_num_threads(self.affinity["master_torch_threads"])
        logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
                   f"{torch.get_num_threads()}.")
        if self.seed is None:
            self.seed = make_seed()
        set_seed(self.seed)
        # self.rank = rank = getattr(self, "rank", 0)
        # self.world_size = world_size = getattr(self, "world_size", 1)
        self.algo.initialize(
            n_updates=self.n_updates,
            cuda_idx=self.affinity.get("cuda_idx", None),
        )
        self.initialize_logging()

    def initialize_logging(self):
        """Reset optimizer-info accumulators, timers, and the progress bar."""
        self._opt_infos = {k: list() for k in self.algo.opt_info_fields}
        self._start_time = self._last_time = time.time()
        self._cum_time = 0.0
        if self.snapshot_gap_intervals is not None:
            # Snapshot gap is given in units of logging intervals.
            logger.set_snapshot_gap(self.snapshot_gap_intervals *
                                    self.log_interval_updates)
        self.pbar = ProgBarCounter(self.log_interval_updates)

    def shutdown(self):
        """Announce completion and stop the progress bar."""
        logger.log("Pretraining complete.")
        self.pbar.stop()

    def get_itr_snapshot(self, itr):
        """Return the state (iteration + algo state dict) to snapshot."""
        return dict(
            itr=itr,
            algo_state_dict=self.algo.state_dict(),
        )

    def save_itr_snapshot(self, itr):
        """
        Calls the logger to save training checkpoint/snapshot (logger itself
        may or may not save, depending on mode selected).
        """
        logger.log("saving snapshot...")
        params = self.get_itr_snapshot(itr)
        logger.save_itr_params(itr, params)
        logger.log("saved")

    def store_diagnostics(self, itr, opt_info):
        """Accumulate per-update optimizer info and advance the progress bar."""
        for k, v in self._opt_infos.items():
            new_v = getattr(opt_info, k, [])
            v.extend(new_v if isinstance(new_v, list) else [new_v])
        self.pbar.update((itr + 1) % self.log_interval_updates)

    def log_diagnostics(self, itr, val_info, *args, **kwargs):
        """Snapshot and write training/validation statistics to the logger.

        :param itr: current update number.
        :param val_info: namedtuple of validation metrics; each field is
            logged with a ``val_`` prefix.
        """
        self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        # Epochs = fraction of the non-validation replay data consumed.
        epochs = (itr * self.algo.batch_size /
                  (self.algo.replay_buffer.size *
                   (1 - self.algo.validation_split)))
        logger.record_tabular("Iteration", itr)
        logger.record_tabular("Epochs", epochs)
        logger.record_tabular("CumTime (s)", self._cum_time)
        logger.record_tabular("UpdatesPerSecond", itr / self._cum_time)
        if self._opt_infos:
            for k, v in self._opt_infos.items():
                logger.record_tabular_misc_stat(k, v)
        for k, v in zip(val_info._fields, val_info):
            logger.record_tabular_misc_stat("val_" + k, v)
        self._opt_infos = {k: list() for k in self._opt_infos}  # (reset)
        logger.dump_tabular(with_prefix=False)
        if itr < self.n_updates - 1:
            logger.log(
                f"Optimizing over {self.log_interval_updates} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_updates)

    def train(self):
        """Main loop: one optimize() per update; validate/log at intervals."""
        self.startup()
        self.algo.train()
        for itr in range(self.n_updates):
            logger.set_iteration(itr)
            with logger.prefix(f"itr #{itr} "):
                opt_info = self.algo.optimize(itr)  # perform one update
                self.store_diagnostics(itr, opt_info)
                if (itr + 1) % self.log_interval_updates == 0:
                    self.algo.eval()
                    val_info = self.algo.validation(itr)
                    self.log_diagnostics(itr, val_info)
                    self.algo.train()
        self.shutdown()
Ejemplo n.º 15
0
class MinibatchRl(MinibatchRlBase):
    """Runs RL on minibatches; tracks performance online using learning
    trajectories."""

    def __init__(self, log_traj_window=100, **kwargs):
        """
        :param log_traj_window: number of recent complete trajectories to
            keep for windowed statistics.
        """
        super().__init__(**kwargs)
        self.log_traj_window = int(log_traj_window)

    def train(self):
        """Main loop: sample, optimize, store diagnostics; log at intervals."""
        n_itr = self.startup()
        for itr in range(n_itr):
            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(itr)  # Might not be this agent sampling.
                samples, traj_infos = self.sampler.obtain_samples(itr)
                self.agent.train_mode(itr)
                opt_info = self.algo.optimize_agent(itr, samples)
                self.store_diagnostics(itr, traj_infos, opt_info)
                if (itr + 1) % self.log_interval_itrs == 0:
                    self.log_diagnostics(itr)
        self.shutdown()

    def initialize_logging(self):
        """Set up the trajectory window and counters, then base-class init."""
        self._traj_infos = deque(maxlen=self.log_traj_window)
        self._cum_completed_trajs = 0
        self._new_completed_trajs = 0
        logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
        super().initialize_logging()
        self.pbar = ProgBarCounter(self.log_interval_itrs)

    def store_diagnostics(self, itr, traj_infos, opt_info):
        """Accumulate completed-trajectory counts and optimizer info."""
        self._cum_completed_trajs += len(traj_infos)
        self._new_completed_trajs += len(traj_infos)
        self._traj_infos.extend(traj_infos)
        for k, v in self._opt_infos.items():
            new_v = getattr(opt_info, k, [])
            v.extend(new_v if isinstance(new_v, list) else [new_v])
        self.pbar.update((itr + 1) % self.log_interval_itrs)

    def log_diagnostics(self, itr):
        """Dump one interval of online diagnostics; reset interval counters."""
        self.pbar.stop()
        self.save_itr_snapshot(itr)
        new_time = time.time()
        samples_per_second = (self.log_interval_itrs *
            self.itr_batch_size) / (new_time - self._last_time)
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumSteps', (itr + 1) * self.itr_batch_size)
        logger.record_tabular('CumTime (s)', new_time - self._start_time)
        logger.record_tabular('SamplesPerSecond', samples_per_second)
        logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
        logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
        logger.record_tabular('StepsInTrajWindow',
            sum(info["Length"] for info in self._traj_infos))
        self._log_infos()

        self._last_time = new_time
        logger.dump_tabular(with_prefix=False)

        # New-trajectory count is per-interval; cumulative count persists.
        self._new_completed_trajs = 0
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
Ejemplo n.º 16
0
    def log_diagnostics(self,
                        itr,
                        player_traj_infos=None,
                        observer_traj_infos=None,
                        eval_time=0,
                        prefix='Diagnostics/'):
        """
        Write diagnostics (including stored ones) to csv via the logger.

        Logs parallel player/observer statistics; when the agent is serial,
        the observer consumes ``agent.n_serial`` times the player's samples.

        :param itr: current iteration number.
        :param player_traj_infos: player trajectory infos for the stats logger.
        :param observer_traj_infos: observer trajectory infos for the stats
            logger.
        :param eval_time: seconds spent evaluating (excluded from train rates).
        :param prefix: tabular key prefix for the diagnostics section.
        """
        if itr > 0:
            self.pbar.stop()
        # Only snapshot once learning has begun (min_itr_learn).
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        player_new_updates = self.player_algo.update_counter - self._player_last_update_counter
        observer_new_updates = self.observer_algo.update_counter - self._observer_last_update_counter
        player_new_samples = (self.sampler.batch_size * self.world_size *
                              self.log_interval_itrs)
        if self.agent.serial:
            # Serial agent: observer sees n_serial samples per player sample.
            observer_new_samples = self.agent.n_serial * player_new_samples
        else:
            observer_new_samples = player_new_samples
        # Rates are undefined on the first iteration; log NaN.
        player_updates_per_second = (float('nan') if itr == 0 else
                                     player_new_updates / train_time_elapsed)
        observer_updates_per_second = (float('nan')
                                       if itr == 0 else observer_new_updates /
                                       train_time_elapsed)
        player_samples_per_second = (float('nan') if itr == 0 else
                                     player_new_samples / train_time_elapsed)
        observer_samples_per_second = (float('nan')
                                       if itr == 0 else observer_new_samples /
                                       train_time_elapsed)
        player_replay_ratio = (player_new_updates *
                               self.player_algo.batch_size * self.world_size /
                               player_new_samples)
        observer_replay_ratio = (observer_new_updates *
                                 self.observer_algo.batch_size *
                                 self.world_size / observer_new_samples)
        player_cum_replay_ratio = (
            self.player_algo.batch_size * self.player_algo.update_counter /
            ((itr + 1) * self.sampler.batch_size))  # world_size cancels.
        player_cum_steps = (itr +
                            1) * self.sampler.batch_size * self.world_size
        if self.agent.serial:
            observer_cum_replay_ratio = (
                self.observer_algo.batch_size *
                self.observer_algo.update_counter /
                ((itr + 1) * (self.agent.n_serial * self.sampler.batch_size))
            )  # world_size cancels.
            observer_cum_steps = self.agent.n_serial * player_cum_steps
        else:
            observer_cum_replay_ratio = (self.observer_algo.batch_size *
                                         self.observer_algo.update_counter /
                                         ((itr + 1) * self.sampler.batch_size)
                                         )  # world_size cancels.
            observer_cum_steps = player_cum_steps

        with logger.tabular_prefix(prefix):
            if self._eval:
                logger.record_tabular(
                    'CumTrainTime', self._cum_time -
                    self._cum_eval_time)  # Already added new eval_time.
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CumTime (s)', self._cum_time)
            logger.record_tabular('PlayerCumSteps', player_cum_steps)
            logger.record_tabular('ObserverCumSteps', observer_cum_steps)
            logger.record_tabular('PlayerCumCompletedTrajs',
                                  self._player_cum_completed_trajs)
            logger.record_tabular('ObserverCumCompletedTrajs',
                                  self._observer_cum_completed_trajs)
            logger.record_tabular('PlayerCumUpdates',
                                  self.player_algo.update_counter)
            logger.record_tabular('ObserverCumUpdates',
                                  self.observer_algo.update_counter)
            logger.record_tabular('PlayerStepsPerSecond',
                                  player_samples_per_second)
            logger.record_tabular('ObserverStepsPerSecond',
                                  observer_samples_per_second)
            logger.record_tabular('PlayerUpdatesPerSecond',
                                  player_updates_per_second)
            logger.record_tabular('ObserverUpdatesPerSecond',
                                  observer_updates_per_second)
            logger.record_tabular('PlayerReplayRatio', player_replay_ratio)
            logger.record_tabular('ObserverReplayRatio', observer_replay_ratio)
            logger.record_tabular('PlayerCumReplayRatio',
                                  player_cum_replay_ratio)
            logger.record_tabular('ObserverCumReplayRatio',
                                  observer_cum_replay_ratio)
        self._log_infos(player_traj_infos, observer_traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Bookkeeping for the next interval's deltas.
        self._last_time = new_time
        self._player_last_update_counter = self.player_algo.update_counter
        self._observer_last_update_counter = self.observer_algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)