示例#1
0
 def log_diagnostics(self, itr, prefix='Diagnostics/'):
     """Record trajectory-window counters under *prefix*, then defer to the parent."""
     window_steps = sum(info["Length"] for info in self._traj_infos)
     with logger.tabular_prefix(prefix):
         logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
         logger.record_tabular('StepsInTrajWindow', window_steps)
     super().log_diagnostics(itr, prefix=prefix)
     # The "new" counter restarts after every logging pass.
     self._new_completed_trajs = 0
示例#2
0
 def log_diagnostics(self,
                     itr,
                     player_eval_traj_infos,
                     observer_eval_traj_infos,
                     eval_time,
                     prefix='Diagnostics/'):
     """Tabulate evaluation statistics for both the player and the observer
     agents, accumulate total eval time, then defer to the parent runner."""
     if not player_eval_traj_infos:
         logger.log("WARNING: player had no complete trajectories in eval.")
     if not observer_eval_traj_infos:
         logger.log(
             "WARNING: observer had no complete trajectories in eval.")
     player_steps = sum(info["Length"] for info in player_eval_traj_infos)
     observer_steps = sum(info["Length"] for info in observer_eval_traj_infos)
     self._cum_eval_time += eval_time
     with logger.tabular_prefix(prefix):
         for key, value in (
                 ('PlayerStepsInEval', player_steps),
                 ('ObserverStepsInEval', observer_steps),
                 ('PlayerTrajsInEval', len(player_eval_traj_infos)),
                 ('ObserverTrajsInEval', len(observer_eval_traj_infos)),
                 ('CumEvalTime', self._cum_eval_time)):
             logger.record_tabular(key, value)
     super().log_diagnostics(itr,
                             player_eval_traj_infos,
                             observer_eval_traj_infos,
                             eval_time,
                             prefix=prefix)
示例#3
0
    def log_diagnostics(self, itr):
        """
        Log diagnostic information.

        :param itr: the current iteration number.
        """
        logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
        logger.record_tabular('StepsInTrajWindow',
                              sum(info["Length"] for info in self._traj_infos))
        super().log_diagnostics(itr)
        # Reset the per-interval counter after each logging pass.
        self._new_completed_trajs = 0
示例#4
0
 def log_diagnostics(self, itr, sampler_itr, throttle_time):
     """Tabulate trajectory counters, then hand off to the base runner."""
     stats = (
         ('CumCompletedTrajs', self._cum_completed_trajs),
         ('NewCompletedTrajs', self._new_completed_trajs),
         ('StepsInTrajWindow',
          sum(info["Length"] for info in self._traj_infos)),
     )
     for key, value in stats:
         logger.record_tabular(key, value)
     super().log_diagnostics(itr, sampler_itr, throttle_time)
     # New-trajectory count restarts each logging interval.
     self._new_completed_trajs = 0
示例#5
0
 def log_diagnostics(self, itr, sampler_itr, throttle_time):
     """Record eval-trajectory statistics and reset the stored infos."""
     if not self._traj_infos:
         logger.log("WARNING: had no complete trajectories in eval.")
     total_len = 0
     for info in self._traj_infos:
         total_len += info["Length"]
     logger.record_tabular('StepsInEval', total_len)
     logger.record_tabular('TrajsInEval', len(self._traj_infos))
     logger.record_tabular('CumEvalTime', self.ctrl.eval_time.value)
     super().log_diagnostics(itr, sampler_itr, throttle_time)
     self._traj_infos = []  # Clear after each eval.
示例#6
0
 def log_diagnostics(self, itr, eval_traj_infos, eval_time):
     """Write eval statistics to the tabular log and accumulate eval time."""
     if not eval_traj_infos:
         logger.log("WARNING: had no complete trajectories in eval.")
     self._cum_eval_time += eval_time
     eval_steps = sum(info["Length"] for info in eval_traj_infos)
     for name, value in [('StepsInEval', eval_steps),
                         ('TrajsInEval', len(eval_traj_infos)),
                         ('CumEvalTime', self._cum_eval_time)]:
         logger.record_tabular(name, value)
     super().log_diagnostics(itr, eval_traj_infos, eval_time)
示例#7
0
 def log_diagnostics(self, itr, val_info, *args, **kwargs):
     """Snapshot the model, tabulate optimization and validation statistics,
     dump the log, and restart the progress bar for the next interval."""
     self.save_itr_snapshot(itr)
     now = time.time()
     self._cum_time = now - self._start_time
     # Epochs = fraction of passes over the training split of the buffer.
     train_size = self.algo.replay_buffer.size * (1 - self.algo.validation_split)
     epochs = itr * self.algo.batch_size / train_size
     logger.record_tabular("Iteration", itr)
     logger.record_tabular("Epochs", epochs)
     logger.record_tabular("CumTime (s)", self._cum_time)
     logger.record_tabular("UpdatesPerSecond", itr / self._cum_time)
     if self._opt_infos:
         for key, values in self._opt_infos.items():
             logger.record_tabular_misc_stat(key, values)
     for key, values in zip(val_info._fields, val_info):
         logger.record_tabular_misc_stat("val_" + key, values)
     self._opt_infos = {key: [] for key in self._opt_infos}  # (reset)
     logger.dump_tabular(with_prefix=False)
     if itr < self.n_updates - 1:
         logger.log(
             f"Optimizing over {self.log_interval_updates} iterations.")
         self.pbar = ProgBarCounter(self.log_interval_updates)
示例#8
0
 def log_diagnostics(self,
                     itr,
                     sampler_itr,
                     throttle_time,
                     prefix="Diagnostics/"):
     """Record eval statistics under *prefix* and clear the stored infos."""
     if not self._traj_infos:
         logger.log("WARNING: had no complete trajectories in eval.")
     eval_steps = sum(info["Length"] for info in self._traj_infos)
     with logger.tabular_prefix(prefix):
         for key, value in (("StepsInEval", eval_steps),
                            ("TrajsInEval", len(self._traj_infos)),
                            ("CumEvalTime", self.ctrl.eval_time.value)):
             logger.record_tabular(key, value)
     super().log_diagnostics(itr, sampler_itr, throttle_time, prefix=prefix)
     self._traj_infos = []  # Clear after each eval.
示例#9
0
 def log_diagnostics(self,
                     itr,
                     eval_traj_infos,
                     eval_time,
                     prefix="Diagnostics/"):
     """Tabulate eval statistics under *prefix*, then call the parent."""
     if not eval_traj_infos:
         logger.log("WARNING: had no complete trajectories in eval.")
     eval_steps = sum(info["Length"] for info in eval_traj_infos)
     self._cum_eval_time += eval_time
     with logger.tabular_prefix(prefix):
         logger.record_tabular("StepsInEval", eval_steps)
         logger.record_tabular("TrajsInEval", len(eval_traj_infos))
         logger.record_tabular("CumEvalTime", self._cum_eval_time)
     super().log_diagnostics(itr, eval_traj_infos, eval_time, prefix=prefix)
示例#10
0
 def log_diagnostics(self,
                     itr,
                     sampler_itr,
                     throttle_time,
                     prefix="Diagnostics/"):
     """Record completed-trajectory counters under *prefix*, then reset."""
     window_steps = sum(info["Length"] for info in self._traj_infos)
     with logger.tabular_prefix(prefix):
         logger.record_tabular("CumCompletedTrajs", self._cum_completed_trajs)
         logger.record_tabular("NewCompletedTrajs", self._new_completed_trajs)
         logger.record_tabular("StepsInTrajWindow", window_steps)
     super().log_diagnostics(itr, sampler_itr, throttle_time, prefix=prefix)
     # Window counter resets each logging interval.
     self._new_completed_trajs = 0
示例#11
0
    def log_diagnostics(self, itr, eval_traj_infos, eval_time: float):
        """
        [NOTE(review): poorly designed — this signature does not match the
        parent class's method of the same name.]
        Log diagnostic information; writes to the log.

        :param itr: the current iteration number.
        :param eval_traj_infos: per-step statistics collected from the
            trajectories during evaluation, e.g. non-zero reward values.
        :param eval_time: wall-clock time (seconds) for one evaluation pass.
        :return:
        """
        if not eval_traj_infos:
            logger.log("WARNING: had no complete trajectories in eval.")
        steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
        logger.record_tabular('StepsInEval', steps_in_eval)
        logger.record_tabular('TrajsInEval', len(eval_traj_infos))
        self._cum_eval_time += eval_time  # cumulative time spent evaluating
        logger.record_tabular('CumEvalTime', self._cum_eval_time)
        super().log_diagnostics(itr, eval_traj_infos, eval_time)
def log_diagnostics(itr, algo, agent, sampler):
    """Anneal sampler randomization settings and log agent statistics."""
    if itr > 0:
        # Grow the angle bound linearly, capped at its maximum of 1.
        scale = options.angle_bound_scale + 0.004 * itr
        shared_sampler_dict['angle_bound_scale'] = np.minimum(scale, 1.)

    if itr > 500:
        if not options.fixed_hpos_std:
            # After a warm-up period, widen the position noise, capped at 0.25.
            widened = options.hpos_std + 0.0005 * (itr - 500)
            shared_sampler_dict['hpos_std'] = np.minimum(widened, 0.25)

    record_tabular('agent/hpos_std', shared_sampler_dict['hpos_std'])
    record_tabular('agent/angle_bound_scale',
                   shared_sampler_dict['angle_bound_scale'])

    # Per-dimension action standard deviations from the policy model.
    std = agent.model.log_std.exp().data.cpu().numpy()
    for idx, value in enumerate(std):
        record_tabular('agent/std{}'.format(idx), value)

    record_tabular_misc_stat(
        'final_obj_position_x',
        sampler.samples_np.env.observation[sampler.samples_np.env.done, 8])
示例#13
0
    def log_diagnostics(self, itr, eval_traj_infos, eval_time):
        """
        Snapshot the model, tabulate eval and timing statistics, dump the
        log, then restart the progress bar for the next interval.

        :param itr: current training iteration.
        :param eval_traj_infos: completed-trajectory info dicts from eval;
            each must provide a "Length" entry.
        :param eval_time: seconds spent in this evaluation pass (subtracted
            from the interval to get pure training time).
        """
        self.save_itr_snapshot(itr)
        if not eval_traj_infos:
            logger.log("WARNING: had no complete trajectories in eval.")
        steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumSteps', itr * self.itr_batch_size)
        logger.record_tabular('StepsInEval', steps_in_eval)
        logger.record_tabular('TrajsInEval', len(eval_traj_infos))

        self._log_infos(eval_traj_infos)

        # Split the wall-clock interval into training vs evaluation time.
        new_time = time.time()
        log_interval_time = new_time - self._last_time
        new_train_time = log_interval_time - eval_time
        self.cum_train_time += new_train_time
        self.cum_eval_time += eval_time
        self.cum_total_time += log_interval_time
        self._last_time = new_time
        # Throughput is undefined on the very first call (itr == 0).
        samples_per_second = (float('nan')
                              if itr == 0 else self.log_interval_itrs *
                              self.itr_batch_size / new_train_time)

        logger.record_tabular('CumTrainTime', self.cum_train_time)
        logger.record_tabular('CumEvalTime', self.cum_eval_time)
        logger.record_tabular('CumTotalTime', self.cum_total_time)
        logger.record_tabular('SamplesPerSecond', samples_per_second)

        logger.dump_tabular(with_prefix=False)

        logger.log(f"optimizing over {self.log_interval_itrs} iterations")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
示例#14
0
    def log_diagnostics(self, itr, sampler_itr, throttle_time):
        """
        Snapshot model state and tabulate timing/throughput diagnostics for
        a runner whose optimizer and sampler iterate at different rates,
        then dump the log and restart the progress bar.

        :param itr: optimizer iteration count.
        :param sampler_itr: sampler iteration count.
        :param throttle_time: seconds the optimizer spent waiting; used for
            the OptThrottle fraction.
        """
        self.pbar.stop()
        self.save_itr_snapshot(itr, sampler_itr)
        new_time = time.time()
        time_elapsed = new_time - self._last_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = self.sampler.batch_size * (sampler_itr - self._last_sampler_itr)
        # Rates are undefined on the very first call (itr == 0).
        updates_per_second = (float('nan') if itr == 0 else
            new_updates / time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else
            new_samples / time_elapsed)
        if self._eval:
            # Subtract time spent in evaluation to also report a sampling
            # rate over the non-eval portion of the interval.
            new_eval_time = self.ctrl.eval_time.value
            eval_time_elapsed = new_eval_time - self._last_eval_time
            non_eval_time_elapsed = time_elapsed - eval_time_elapsed
            non_eval_samples_per_second = (float('nan') if itr == 0 else
                new_samples / non_eval_time_elapsed)
            self._last_eval_time = new_eval_time
        cum_steps = sampler_itr * self.sampler.batch_size  # No * world_size.
        # max(1, ...) guards against division by zero early in training.
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
            max(1, new_samples))
        cum_replay_ratio = (self.algo.update_counter * self.algo.batch_size *
            self.world_size / max(1, cum_steps))

        logger.record_tabular('Iteration', itr)
        logger.record_tabular('SamplerIteration', sampler_itr)
        logger.record_tabular('CumTime (s)', new_time - self._start_time)
        logger.record_tabular('CumSteps', cum_steps)
        logger.record_tabular('CumUpdates', self.algo.update_counter)
        logger.record_tabular('ReplayRatio', replay_ratio)
        logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        logger.record_tabular('StepsPerSecond', samples_per_second)
        if self._eval:
            logger.record_tabular('NonEvalSamplesPerSecond', non_eval_samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        logger.record_tabular('OptThrottle', (time_elapsed - throttle_time) /
            time_elapsed)

        self._log_infos()
        self._last_time = new_time
        self._last_itr = itr
        self._last_sampler_itr = sampler_itr
        self._last_update_counter = self.algo.update_counter
        logger.dump_tabular(with_prefix=False)
        logger.log(f"Optimizing over {self.log_interval_itrs} sampler "
            "iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
示例#15
0
    def log_diagnostics(self,
                        itr,
                        traj_infos=None,
                        eval_time=0,
                        prefix='Diagnostics/'):
        """
        Write diagnostics (including stored ones) to csv via the logger.

        :param itr: current training iteration (0-based).
        :param traj_infos: optional trajectory infos, forwarded to
            ``self._log_infos``.
        :param eval_time: seconds spent evaluating since the last log call;
            excluded from the training-time throughput figures.
        :param prefix: tabular key prefix for the logged diagnostics.
        """
        if itr > 0:
            self.pbar.stop()
        # NOTE(review): presumably delays snapshots until learning has
        # started (min_itr_learn) — verify against the runner setup.
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
                       self.log_interval_itrs)
        # Rates are undefined on the very first call (itr == 0).
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              train_time_elapsed)
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        new_samples)
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            ((itr + 1) * self.sampler.batch_size)
                            )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        with logger.tabular_prefix(prefix):
            if self._eval:
                logger.record_tabular(
                    'CumTrainTime', self._cum_time -
                    self._cum_eval_time)  # Already added new eval_time.
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CumTime (s)', self._cum_time)
            logger.record_tabular('CumSteps', cum_steps)
            logger.record_tabular('CumCompletedTrajs',
                                  self._cum_completed_trajs)
            logger.record_tabular('CumUpdates', self.algo.update_counter)
            logger.record_tabular('StepsPerSecond', samples_per_second)
            logger.record_tabular('UpdatesPerSecond', updates_per_second)
            logger.record_tabular('ReplayRatio', replay_ratio)
            logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        self._log_infos(traj_infos)
        logger.dump_tabular(with_prefix=False)

        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
示例#16
0
    def log_diagnostics(self, itr, sample_itr, throttle_time):
        """
        Snapshot model state and tabulate timing/throughput statistics,
        then dump the log and restart the progress bar.

        :param itr: optimizer iteration count.
        :param sample_itr: sampler iteration count (may differ from itr).
        :param throttle_time: seconds the optimizer spent waiting; used for
            the OptThrottle fraction.
        """
        self.pbar.stop()
        self.save_itr_snapshot(itr, sample_itr)
        new_time = time.time()
        time_elapsed = new_time - self._last_time
        # Rates are undefined on the very first call (itr == 0).
        samples_per_second = (float('nan') if itr == 0 else
            self.log_interval_itrs * self.itr_batch_size / time_elapsed)
        updates_per_second = (float('nan') if itr == 0 else
            self.algo.updates_per_optimize * (itr - self._last_itr) / time_elapsed)
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('SamplerIteration', sample_itr)
        logger.record_tabular('CumTime (s)', new_time - self._start_time)
        logger.record_tabular('CumSteps', sample_itr * self.itr_batch_size)
        logger.record_tabular('CumUpdates', itr * self.algo.updates_per_optimize)
        logger.record_tabular('SamplesPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        # Fraction of the interval not spent throttled.
        logger.record_tabular('OptThrottle', (time_elapsed - throttle_time) /
            time_elapsed)

        self._log_infos()
        self._last_time = new_time
        self._last_itr = itr
        logger.dump_tabular(with_prefix=False)
        logger.log(f"Optimizing over {self.log_interval_itrs} sampler "
            "iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
示例#17
0
    def log_diagnostics(self, itr, traj_infos=None, eval_time=0):
        """
        Write timing, throughput, replay-ratio, and cumulative pyflex-step
        diagnostics to the tabular log, then restart the progress bar.

        :param itr: current training iteration (0-based).
        :param traj_infos: completed-trajectory infos; each must carry an
            'env_infos' sequence whose last element has a ``total_steps``
            attribute.  ``None`` is treated as no completed trajectories.
        :param eval_time: seconds spent evaluating since the last log call;
            excluded from the training-time throughput figures.
        """
        if itr > 0:
            self.pbar.stop()
        self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
                       self.log_interval_itrs)
        # Rates are undefined on the very first call (itr == 0).
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              train_time_elapsed)
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        new_samples)
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            ((itr + 1) * self.sampler.batch_size)
                            )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        if self._eval:
            logger.record_tabular(
                'CumTrainTime', self._cum_time -
                self._cum_eval_time)  # Already added new eval_time.
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumTime (s)', self._cum_time)
        logger.record_tabular('CumSteps', cum_steps)
        logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
        logger.record_tabular('CumUpdates', self.algo.update_counter)
        logger.record_tabular('StepsPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        logger.record_tabular('ReplayRatio', replay_ratio)
        logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        # FIX: traj_infos defaults to None but was iterated unconditionally,
        # raising TypeError when called with the default; treat None as empty.
        self._cum_pyflex_steps += sum(
            getattr(info['env_infos'][-1], 'total_steps')
            for info in (traj_infos or ()))
        logger.record_tabular('CumPyflexSteps', self._cum_pyflex_steps)
        self._log_infos(traj_infos)
        logger.dump_tabular(with_prefix=False)

        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
示例#18
0
File: __main__.py  Project: qxcv/mtil
def train(demos, add_preproc, seed, batch_size, total_n_batches,
          eval_every_n_batches, out_dir, run_name, gpu_idx, cpu_list,
          eval_n_traj, snapshot_gap, omit_noop, net_width_mul, net_use_bn,
          net_dropout, net_coord_conv, net_attention, net_task_spec_layers,
          load_policy, aug_mode, min_bc):
    """Train a multi-task behavioral-cloning policy on MAGICAL demos.

    Sets up seeds and devices, loads the demonstration dataset, builds a
    multiplexed sampler plus policy network and optimizer, then alternates
    training rounds with evaluation, logging per-task scores and losses and
    saving parameter snapshots after each round.  Parameters mirror the CLI
    options of the same names.
    """

    # TODO: abstract setup code. Seeds & GPUs should go in one function. Env
    # setup should go in another function (or maybe the same function). Dataset
    # loading should be simplified by having a single class that can provide
    # whatever form of data the current IL method needs, without having to do
    # unnecessary copies in memory. Maybe also just use Sacred, because YOLO.

    with contextlib.ExitStack() as exit_stack:
        # set up seeds & devices
        set_seeds(seed)
        mp.set_start_method('spawn')
        use_gpu = gpu_idx is not None and torch.cuda.is_available()
        dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
        print(f"Using device {dev}, seed {seed}")
        if cpu_list is None:
            cpu_list = sample_cpu_list()
        affinity = dict(
            cuda_idx=gpu_idx if use_gpu else None,
            workers_cpus=cpu_list,
        )

        # register original envs
        import magical
        magical.register_envs()

        # TODO: split out part of the dataset for validation.
        demos_metas_dict = get_demos_meta(demo_paths=demos,
                                          omit_noop=omit_noop,
                                          transfer_variants=[],
                                          preproc_name=add_preproc)
        dataset_mt = demos_metas_dict['dataset_mt']
        loader_mt = make_loader_mt(dataset_mt, batch_size)
        variant_groups = demos_metas_dict['variant_groups']
        env_metas = demos_metas_dict['env_metas']
        num_demo_sources = demos_metas_dict['num_demo_sources']
        task_ids_and_demo_env_names = demos_metas_dict[
            'task_ids_and_demo_env_names']
        sampler_batch_B = batch_size
        # this doesn't really matter
        sampler_batch_T = 5
        sampler, sampler_batch_B = make_mux_sampler(
            variant_groups=variant_groups,
            num_demo_sources=num_demo_sources,
            env_metas=env_metas,
            use_gpu=use_gpu,
            batch_B=sampler_batch_B,
            batch_T=sampler_batch_T,
            # TODO: instead of doing this, try sampling in proportion to length
            # of horizon; that should get more samples from harder envs
            task_var_weights=None)
        # Either resume from a previously-saved policy or construct a fresh
        # network from the CLI architecture options.
        if load_policy is not None:
            try:
                pol_path = get_latest_path(load_policy)
            except ValueError:
                pol_path = load_policy
            policy_ctor = functools.partial(
                adapt_pol_loader,
                pol_path=pol_path,
                task_ids_and_demo_env_names=task_ids_and_demo_env_names)
            policy_kwargs = {}
        else:
            policy_kwargs = {
                'env_ids_and_names': task_ids_and_demo_env_names,
                'width': net_width_mul,
                'use_bn': net_use_bn,
                'dropout': net_dropout,
                'coord_conv': net_coord_conv,
                'attention': net_attention,
                'n_task_spec_layers': net_task_spec_layers,
                **get_policy_spec_magical(env_metas),
            }
            policy_ctor = MultiHeadPolicyNet
        agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))
        sampler.initialize(agent=agent,
                           seed=np.random.randint(1 << 31),
                           affinity=affinity)
        exit_stack.callback(lambda: sampler.shutdown())

        model_mt = policy_ctor(**policy_kwargs).to(dev)
        if min_bc:
            num_tasks = len(task_ids_and_demo_env_names)
            weight_mod = MinBCWeightingModule(num_tasks, num_demo_sources) \
                .to(dev)
            all_params = it.chain(model_mt.parameters(),
                                  weight_mod.parameters())
        else:
            weight_mod = None
            all_params = model_mt.parameters()
        # Adam mostly works fine, but in very loose informal tests it seems
        # like SGD had fewer weird failures where mean loss would jump up by a
        # factor of 2x for a period (?). (I don't think that was solely due to
        # high LR; probably an architectural issue.) opt_mt =
        # torch.optim.Adam(model_mt.parameters(), lr=3e-4)
        opt_mt = torch.optim.SGD(all_params, lr=1e-3, momentum=0.1)

        try:
            aug_opts = MILBenchAugmentations.PRESETS[aug_mode]
        except KeyError:
            raise ValueError(f"unsupported mode '{aug_mode}'")
        if aug_opts:
            print("Augmentations:", ", ".join(aug_opts))
            aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
                .to(dev)
        else:
            print("No augmentations")
            aug_model = None

        n_uniq_envs = len(task_ids_and_demo_env_names)
        # Hyperparameters recorded alongside the run for reproducibility.
        log_params = {
            'n_uniq_envs': n_uniq_envs,
            'n_demos': len(demos),
            'net_use_bn': net_use_bn,
            'net_width_mul': net_width_mul,
            'net_dropout': net_dropout,
            'net_coord_conv': net_coord_conv,
            'net_attention': net_attention,
            'aug_mode': aug_mode,
            'seed': seed,
            'omit_noop': omit_noop,
            'batch_size': batch_size,
            'eval_n_traj': eval_n_traj,
            'eval_every_n_batches': eval_every_n_batches,
            'total_n_batches': total_n_batches,
            'snapshot_gap': snapshot_gap,
            'add_preproc': add_preproc,
            'net_task_spec_layers': net_task_spec_layers,
        }
        with make_logger_ctx(out_dir,
                             "mtbc",
                             f"mt{n_uniq_envs}",
                             run_name,
                             snapshot_gap=snapshot_gap,
                             log_params=log_params):
            # initial save
            torch.save(
                model_mt,
                os.path.join(logger.get_snapshot_dir(), 'full_model.pt'))

            # train for a while
            n_batches_done = 0
            n_rounds = int(np.ceil(total_n_batches / eval_every_n_batches))
            rnd = 1
            assert eval_every_n_batches > 0
            while n_batches_done < total_n_batches:
                batches_left_now = min(total_n_batches - n_batches_done,
                                       eval_every_n_batches)
                print(f"Done {n_batches_done}/{total_n_batches} "
                      f"({n_batches_done/total_n_batches*100:.2f}%, "
                      f"{rnd}/{n_rounds} rounds) batches; doing another "
                      f"{batches_left_now}")
                model_mt.train()
                loss_ewma, losses, per_task_losses = do_training_mt(
                    loader=loader_mt,
                    model=model_mt,
                    opt=opt_mt,
                    dev=dev,
                    aug_model=aug_model,
                    min_bc_module=weight_mod,
                    n_batches=batches_left_now)

                # TODO: record accuracy on a random subset of the train and
                # validation sets (both in eval mode, not train mode)

                print(f"Evaluating {eval_n_traj} trajectories on "
                      f"{variant_groups.num_tasks} tasks")
                record_misc_calls = []
                model_mt.eval()

                copy_model_into_agent_eval(model_mt, sampler.agent)
                scores_by_tv = eval_model(
                    sampler,
                    # shouldn't be any exploration
                    itr=0,
                    n_traj=eval_n_traj)
                for (task_id, variant_id), scores in scores_by_tv.items():
                    tv_id = (task_id, variant_id)
                    env_name = variant_groups.env_name_by_task_variant[tv_id]
                    tag = make_env_tag(strip_mb_preproc_name(env_name))
                    logger.record_tabular_misc_stat("Score%s" % tag, scores)
                    env_losses = per_task_losses.get(tv_id, [])
                    record_misc_calls.append((f"Loss{tag}", env_losses))

                # we record score AFTER loss so that losses are all in one
                # place, and scores are all in another
                for args in record_misc_calls:
                    logger.record_tabular_misc_stat(*args)

                # finish logging for this epoch
                logger.record_tabular("Round", rnd)
                logger.record_tabular("LossEWMA", loss_ewma)
                logger.record_tabular_misc_stat("Loss", losses)
                logger.dump_tabular()
                logger.save_itr_params(
                    rnd, {
                        'model_state': model_mt.state_dict(),
                        'opt_state': opt_mt.state_dict(),
                    })

                # advance ctrs
                rnd += 1
                n_batches_done += batches_left_now
示例#19
0
    def log_diagnostics(self, itr, eval_traj_infos, eval_time):
        """
        Write diagnostics (including stored ones) to csv via the logger.
        ONE NEW LINE VS REGULAR RUNNER, TO LOG ENV STEPS = steps*frame_skip

        :param itr: current training iteration (0-based).
        :param eval_traj_infos: completed-trajectory infos from eval; each
            must provide a "Length" entry.
        :param eval_time: seconds spent in this evaluation pass; excluded
            from the training-time throughput figures.
        """
        if not eval_traj_infos:
            logger.log("WARNING: had no complete trajectories in eval.")
        steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
        logger.record_tabular('StepsInEval', steps_in_eval)
        logger.record_tabular('TrajsInEval', len(eval_traj_infos))
        self._cum_eval_time += eval_time
        logger.record_tabular('CumEvalTime', self._cum_eval_time)

        if itr > 0:
            self.pbar.stop()
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
            self.log_interval_itrs)
        # Rates are undefined on the very first call (itr == 0).
        updates_per_second = (float('nan') if itr == 0 else
            new_updates / train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else
            new_samples / train_time_elapsed)
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
            new_samples)
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
            ((itr + 1) * self.sampler.batch_size))  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        if self._eval:
            logger.record_tabular('CumTrainTime',
                self._cum_time - self._cum_eval_time)  # Already added new eval_time.
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumTime (s)', self._cum_time)
        logger.record_tabular('CumSteps', cum_steps)
        logger.record_tabular('EnvSteps', cum_steps * self._frame_skip)  # NEW LINE
        logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
        logger.record_tabular('CumUpdates', self.algo.update_counter)
        logger.record_tabular('StepsPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        logger.record_tabular('ReplayRatio', replay_ratio)
        logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        self._log_infos(eval_traj_infos)
        logger.dump_tabular(with_prefix=False)

        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)
示例#20
0
    def log_diagnostics(self, itr, traj_infos=None, eval_time=0):
        """
        Log diagnostic information (writes to the log); also saves model
        parameters and related state via a snapshot.

        :param itr: the current iteration number.
        """
        if itr > 0:
            self.pbar.stop()  # stop updating the progress bar
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)
        new_time = time.time()
        self._cum_time = new_time - self._start_time
        train_time_elapsed = new_time - self._last_time - eval_time
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (self.sampler.batch_size * self.world_size *
                       self.log_interval_itrs)
        updates_per_second = (float('nan') if itr == 0 else new_updates /
                              train_time_elapsed)
        samples_per_second = (float('nan') if itr == 0 else new_samples /
                              train_time_elapsed)
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        new_samples)
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            ((itr + 1) * self.sampler.batch_size)
                            )  # world_size cancels.
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        # Write some extra statistics into the log.
        if self._eval:
            logger.record_tabular(
                'CumTrainTime', self._cum_time -
                self._cum_eval_time)  # Already added new eval_time.
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumTime (s)', self._cum_time)
        logger.record_tabular('CumSteps', cum_steps)
        logger.record_tabular(
            'CumCompletedTrajs',
            self._cum_completed_trajs)  # only counts trajectories flagged "traj_done"
        logger.record_tabular('CumUpdates', self.algo.update_counter)
        logger.record_tabular('StepsPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        logger.record_tabular('ReplayRatio', replay_ratio)
        logger.record_tabular('CumReplayRatio', cum_replay_ratio)
        self._log_infos(traj_infos)
        logger.dump_tabular(with_prefix=False)  # write the log file

        self._last_time = new_time
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)  # progress bar
示例#21
0
    def log_diagnostics(self, itr):
        """Snapshot, tabulate sampling statistics, dump the log, and restart
        the progress bar for the next logging interval."""
        self.pbar.stop()
        self.save_itr_snapshot(itr)
        now = time.time()
        interval_samples = self.log_interval_itrs * self.itr_batch_size
        samples_per_second = interval_samples / (now - self._last_time)
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumSteps', (itr + 1) * self.itr_batch_size)
        logger.record_tabular('CumTime (s)', now - self._start_time)
        logger.record_tabular('SamplesPerSecond', samples_per_second)
        logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
        logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
        logger.record_tabular('StepsInTrajWindow',
                              sum(info["Length"] for info in self._traj_infos))
        self._log_infos()

        self._last_time = now
        logger.dump_tabular(with_prefix=False)

        # Window counter resets each interval; progress bar restarts unless done.
        self._new_completed_trajs = 0
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)