def log_diagnostics(self, itr, prefix='Diagnostics/'):
    """Record trajectory-window stats under ``prefix``, then defer to the
    base class; resets the new-completed-trajectory counter afterwards."""
    window_steps = sum(info["Length"] for info in self._traj_infos)
    with logger.tabular_prefix(prefix):
        logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
        logger.record_tabular('StepsInTrajWindow', window_steps)
    super().log_diagnostics(itr, prefix=prefix)
    self._new_completed_trajs = 0
def log_diagnostics(self, itr, player_eval_traj_infos,
        observer_eval_traj_infos, eval_time, prefix='Diagnostics/'):
    """Record separate eval stats for the player and the observer agents,
    accumulate eval wall-clock time, then defer to the base class."""
    if not player_eval_traj_infos:
        logger.log("WARNING: player had no complete trajectories in eval.")
    if not observer_eval_traj_infos:
        logger.log(
            "WARNING: observer had no complete trajectories in eval.")
    # Total environment steps covered by each agent's eval trajectories.
    player_steps = sum(info["Length"] for info in player_eval_traj_infos)
    observer_steps = sum(info["Length"] for info in observer_eval_traj_infos)
    with logger.tabular_prefix(prefix):
        logger.record_tabular('PlayerStepsInEval', player_steps)
        logger.record_tabular('ObserverStepsInEval', observer_steps)
        logger.record_tabular('PlayerTrajsInEval',
            len(player_eval_traj_infos))
        logger.record_tabular('ObserverTrajsInEval',
            len(observer_eval_traj_infos))
        self._cum_eval_time += eval_time
        logger.record_tabular('CumEvalTime', self._cum_eval_time)
    super().log_diagnostics(itr, player_eval_traj_infos,
        observer_eval_traj_infos, eval_time, prefix=prefix)
def log_diagnostics(self, itr):
    """Record diagnostics for the current iteration.

    :param itr: iteration index.
    """
    window_steps = sum(info["Length"] for info in self._traj_infos)
    logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
    logger.record_tabular('StepsInTrajWindow', window_steps)
    super().log_diagnostics(itr)
    self._new_completed_trajs = 0
def log_diagnostics(self, itr, sampler_itr, throttle_time):
    """Record cumulative/new trajectory counts and the step total of the
    current trajectory window, then defer to the base class."""
    window_steps = sum(info["Length"] for info in self._traj_infos)
    logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
    logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
    logger.record_tabular('StepsInTrajWindow', window_steps)
    super().log_diagnostics(itr, sampler_itr, throttle_time)
    self._new_completed_trajs = 0
def log_diagnostics(self, itr, sampler_itr, throttle_time):
    """Record eval stats from the stored trajectory infos, then defer to
    the base class; the stored infos are cleared after each eval."""
    if not self._traj_infos:
        logger.log("WARNING: had no complete trajectories in eval.")
    eval_steps = sum(info["Length"] for info in self._traj_infos)
    logger.record_tabular('StepsInEval', eval_steps)
    logger.record_tabular('TrajsInEval', len(self._traj_infos))
    # Eval time lives in the shared controller in the async setup.
    logger.record_tabular('CumEvalTime', self.ctrl.eval_time.value)
    super().log_diagnostics(itr, sampler_itr, throttle_time)
    self._traj_infos = []  # Clear after each eval.
def log_diagnostics(self, itr, eval_traj_infos, eval_time):
    """Record eval trajectory stats and accumulate eval wall-clock time,
    then defer to the base class."""
    if not eval_traj_infos:
        logger.log("WARNING: had no complete trajectories in eval.")
    eval_steps = sum(info["Length"] for info in eval_traj_infos)
    logger.record_tabular('StepsInEval', eval_steps)
    logger.record_tabular('TrajsInEval', len(eval_traj_infos))
    self._cum_eval_time += eval_time
    logger.record_tabular('CumEvalTime', self._cum_eval_time)
    super().log_diagnostics(itr, eval_traj_infos, eval_time)
def log_diagnostics(self, itr, val_info, *args, **kwargs):
    """Snapshot the model, record timing, optimizer, and validation stats,
    dump the tabular log, and restart the progress bar if more updates
    remain.

    :param itr: current update iteration.
    :param val_info: namedtuple of validation statistics.
    """
    self.save_itr_snapshot(itr)
    now = time.time()
    self._cum_time = now - self._start_time
    # Fraction of the (non-validation) replay buffer consumed so far.
    epochs = itr * self.algo.batch_size / (
        self.algo.replay_buffer.size * (1 - self.algo.validation_split))
    logger.record_tabular("Iteration", itr)
    logger.record_tabular("Epochs", epochs)
    logger.record_tabular("CumTime (s)", self._cum_time)
    logger.record_tabular("UpdatesPerSecond", itr / self._cum_time)
    if self._opt_infos:
        for name, values in self._opt_infos.items():
            logger.record_tabular_misc_stat(name, values)
    for name, values in zip(val_info._fields, val_info):
        logger.record_tabular_misc_stat("val_" + name, values)
    self._opt_infos = {name: [] for name in self._opt_infos}  # (reset)
    logger.dump_tabular(with_prefix=False)
    if itr < self.n_updates - 1:
        logger.log(
            f"Optimizing over {self.log_interval_updates} iterations.")
        self.pbar = ProgBarCounter(self.log_interval_updates)
def log_diagnostics(self, itr, sampler_itr, throttle_time,
        prefix="Diagnostics/"):
    """Record eval stats (under ``prefix``) from the stored trajectory
    infos, then defer to the base class; stored infos are cleared."""
    if not self._traj_infos:
        logger.log("WARNING: had no complete trajectories in eval.")
    eval_steps = sum(info["Length"] for info in self._traj_infos)
    with logger.tabular_prefix(prefix):
        logger.record_tabular("StepsInEval", eval_steps)
        logger.record_tabular("TrajsInEval", len(self._traj_infos))
        # Eval time lives in the shared controller in the async setup.
        logger.record_tabular("CumEvalTime", self.ctrl.eval_time.value)
    super().log_diagnostics(itr, sampler_itr, throttle_time, prefix=prefix)
    self._traj_infos = []  # Clear after each eval.
def log_diagnostics(self, itr, eval_traj_infos, eval_time,
        prefix="Diagnostics/"):
    """Record eval trajectory stats (under ``prefix``) and accumulate eval
    wall-clock time, then defer to the base class."""
    if not eval_traj_infos:
        logger.log("WARNING: had no complete trajectories in eval.")
    eval_steps = sum(info["Length"] for info in eval_traj_infos)
    with logger.tabular_prefix(prefix):
        logger.record_tabular("StepsInEval", eval_steps)
        logger.record_tabular("TrajsInEval", len(eval_traj_infos))
        self._cum_eval_time += eval_time
        logger.record_tabular("CumEvalTime", self._cum_eval_time)
    super().log_diagnostics(itr, eval_traj_infos, eval_time, prefix=prefix)
def log_diagnostics(self, itr, sampler_itr, throttle_time,
        prefix="Diagnostics/"):
    """Record trajectory counters (under ``prefix``) and the step total of
    the trajectory window, then defer to the base class."""
    window_steps = sum(info["Length"] for info in self._traj_infos)
    with logger.tabular_prefix(prefix):
        logger.record_tabular("CumCompletedTrajs", self._cum_completed_trajs)
        logger.record_tabular("NewCompletedTrajs", self._new_completed_trajs)
        logger.record_tabular("StepsInTrajWindow", window_steps)
    super().log_diagnostics(itr, sampler_itr, throttle_time, prefix=prefix)
    self._new_completed_trajs = 0
def log_diagnostics(self, itr, eval_traj_infos, eval_time: float):
    """Record eval diagnostics and accumulate eval wall-clock time, then
    defer to the base class.

    NOTE(review): this signature differs from the parent's same-named
    method — flagged in the original author's comment as a design wart.

    :param itr: iteration index.
    :param eval_traj_infos: per-trajectory stats gathered during eval
        (e.g. nonzero reward values).
    :param eval_time: seconds spent on this evaluation.
    """
    if not eval_traj_infos:
        logger.log("WARNING: had no complete trajectories in eval.")
    eval_steps = sum(info["Length"] for info in eval_traj_infos)
    logger.record_tabular('StepsInEval', eval_steps)
    logger.record_tabular('TrajsInEval', len(eval_traj_infos))
    self._cum_eval_time += eval_time  # cumulative evaluation time
    logger.record_tabular('CumEvalTime', self._cum_eval_time)
    super().log_diagnostics(itr, eval_traj_infos, eval_time)
def log_diagnostics(itr, algo, agent, sampler):
    """Advance the curriculum schedules in ``shared_sampler_dict`` and log
    agent/env diagnostics for this iteration.

    NOTE(review): assumes ``shared_sampler_dict`` and ``options`` are
    module-level globals shared with the sampler — confirm against caller.
    """
    # Linearly widen the angle bound toward 1.0 over training.
    if itr > 0:
        shared_sampler_dict['angle_bound_scale'] = np.minimum(
            options.angle_bound_scale + 0.004 * (itr - 0), 1.)
    # After 500 iterations, optionally widen the horizontal-position std.
    if itr > 500 and not options.fixed_hpos_std:
        shared_sampler_dict['hpos_std'] = np.minimum(
            options.hpos_std + 0.0005 * (itr - 500), 0.25)
    record_tabular('agent/hpos_std', shared_sampler_dict['hpos_std'])
    record_tabular('agent/angle_bound_scale',
                   shared_sampler_dict['angle_bound_scale'])
    # Per-dimension policy std (exp of the learned log-std).
    std = agent.model.log_std.exp().data.cpu().numpy()
    for dim, value in enumerate(std):
        record_tabular('agent/std{}'.format(dim), value)
    record_tabular_misc_stat(
        'final_obj_position_x',
        sampler.samples_np.env.observation[sampler.samples_np.env.done, 8])
def log_diagnostics(self, itr, eval_traj_infos, eval_time):
    """Write eval and timing diagnostics to the logger and dump the table.

    :param itr: current training iteration.
    :param eval_traj_infos: per-trajectory info dicts collected during eval.
    :param eval_time: wall-clock seconds spent in this evaluation.
    """
    self.save_itr_snapshot(itr)
    if not eval_traj_infos:
        logger.log("WARNING: had no complete trajectories in eval.")
    steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('CumSteps', itr * self.itr_batch_size)
    logger.record_tabular('StepsInEval', steps_in_eval)
    logger.record_tabular('TrajsInEval', len(eval_traj_infos))
    self._log_infos(eval_traj_infos)
    # Split the interval since the last log into train vs. eval time.
    new_time = time.time()
    log_interval_time = new_time - self._last_time
    new_train_time = log_interval_time - eval_time
    self.cum_train_time += new_train_time
    self.cum_eval_time += eval_time
    self.cum_total_time += log_interval_time
    self._last_time = new_time
    # NaN on the first iteration: no full interval has elapsed yet.
    samples_per_second = (float('nan') if itr == 0 else
        self.log_interval_itrs * self.itr_batch_size / new_train_time)
    logger.record_tabular('CumTrainTime', self.cum_train_time)
    logger.record_tabular('CumEvalTime', self.cum_eval_time)
    logger.record_tabular('CumTotalTime', self.cum_total_time)
    logger.record_tabular('SamplesPerSecond', samples_per_second)
    logger.dump_tabular(with_prefix=False)
    logger.log(f"optimizing over {self.log_interval_itrs} iterations")
    self.pbar = ProgBarCounter(self.log_interval_itrs)
def log_diagnostics(self, itr, sampler_itr, throttle_time):
    """Write async-runner diagnostics: throughput (updates/samples per
    second), replay ratios, and stored opt infos, then dump the table.

    :param itr: optimizer iteration.
    :param sampler_itr: sampler iteration (decoupled from ``itr`` in the
        async setup).
    :param throttle_time: seconds the optimizer spent waiting on the
        sampler during this interval.
    """
    self.pbar.stop()
    self.save_itr_snapshot(itr, sampler_itr)
    new_time = time.time()
    time_elapsed = new_time - self._last_time
    new_updates = self.algo.update_counter - self._last_update_counter
    new_samples = self.sampler.batch_size * (sampler_itr -
        self._last_sampler_itr)
    # NaN on the first iteration: no full interval has elapsed yet.
    updates_per_second = (float('nan') if itr == 0 else
        new_updates / time_elapsed)
    samples_per_second = (float('nan') if itr == 0 else
        new_samples / time_elapsed)
    if self._eval:
        # Exclude time spent inside evaluation from the throughput rate.
        new_eval_time = self.ctrl.eval_time.value
        eval_time_elapsed = new_eval_time - self._last_eval_time
        non_eval_time_elapsed = time_elapsed - eval_time_elapsed
        non_eval_samples_per_second = (float('nan') if itr == 0 else
            new_samples / non_eval_time_elapsed)
        self._last_eval_time = new_eval_time
    cum_steps = sampler_itr * self.sampler.batch_size  # No * world_size.
    # max(1, ...) guards against division by zero before any sampling.
    replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
        max(1, new_samples))
    cum_replay_ratio = (self.algo.update_counter * self.algo.batch_size *
        self.world_size / max(1, cum_steps))
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('SamplerIteration', sampler_itr)
    logger.record_tabular('CumTime (s)', new_time - self._start_time)
    logger.record_tabular('CumSteps', cum_steps)
    logger.record_tabular('CumUpdates', self.algo.update_counter)
    logger.record_tabular('ReplayRatio', replay_ratio)
    logger.record_tabular('CumReplayRatio', cum_replay_ratio)
    logger.record_tabular('StepsPerSecond', samples_per_second)
    if self._eval:
        logger.record_tabular('NonEvalSamplesPerSecond',
            non_eval_samples_per_second)
    logger.record_tabular('UpdatesPerSecond', updates_per_second)
    logger.record_tabular('OptThrottle', (time_elapsed - throttle_time) /
        time_elapsed)
    self._log_infos()
    self._last_time = new_time
    self._last_itr = itr
    self._last_sampler_itr = sampler_itr
    self._last_update_counter = self.algo.update_counter
    logger.dump_tabular(with_prefix=False)
    logger.log(f"Optimizing over {self.log_interval_itrs} sampler "
               "iterations.")
    self.pbar = ProgBarCounter(self.log_interval_itrs)
def log_diagnostics(self, itr, traj_infos=None, eval_time=0,
        prefix='Diagnostics/'):
    """
    Write diagnostics (including stored ones) to csv via the logger.

    :param itr: current iteration.
    :param traj_infos: trajectory infos to forward to ``_log_infos``.
    :param eval_time: seconds spent in eval during this interval.
    :param prefix: tabular key prefix for the recorded diagnostics.
    """
    if itr > 0:
        self.pbar.stop()
    if itr >= self.min_itr_learn - 1:
        self.save_itr_snapshot(itr)
    new_time = time.time()
    self._cum_time = new_time - self._start_time
    train_time_elapsed = new_time - self._last_time - eval_time
    new_updates = self.algo.update_counter - self._last_update_counter
    new_samples = (self.sampler.batch_size * self.world_size *
        self.log_interval_itrs)
    # NaN on the first iteration: no full interval has elapsed yet.
    updates_per_second = (float('nan') if itr == 0 else
        new_updates / train_time_elapsed)
    samples_per_second = (float('nan') if itr == 0 else
        new_samples / train_time_elapsed)
    replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
        new_samples)
    cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
        ((itr + 1) * self.sampler.batch_size))  # world_size cancels.
    cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size
    with logger.tabular_prefix(prefix):
        if self._eval:
            logger.record_tabular('CumTrainTime',
                self._cum_time - self._cum_eval_time)  # Already added new eval_time.
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumTime (s)', self._cum_time)
        logger.record_tabular('CumSteps', cum_steps)
        logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
        logger.record_tabular('CumUpdates', self.algo.update_counter)
        logger.record_tabular('StepsPerSecond', samples_per_second)
        logger.record_tabular('UpdatesPerSecond', updates_per_second)
        logger.record_tabular('ReplayRatio', replay_ratio)
        logger.record_tabular('CumReplayRatio', cum_replay_ratio)
    self._log_infos(traj_infos)
    logger.dump_tabular(with_prefix=False)
    self._last_time = new_time
    self._last_update_counter = self.algo.update_counter
    if itr < self.n_itr - 1:
        logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
def log_diagnostics(self, itr, sample_itr, throttle_time):
    """Write async-runner diagnostics (throughput, cumulative counters,
    throttle fraction) and dump the table.

    :param itr: optimizer iteration.
    :param sample_itr: sampler iteration.
    :param throttle_time: seconds the optimizer spent waiting on the
        sampler during this interval.
    """
    self.pbar.stop()
    self.save_itr_snapshot(itr, sample_itr)
    new_time = time.time()
    time_elapsed = new_time - self._last_time
    # NaN on the first iteration: no full interval has elapsed yet.
    samples_per_second = (float('nan') if itr == 0 else
        self.log_interval_itrs * self.itr_batch_size / time_elapsed)
    updates_per_second = (float('nan') if itr == 0 else
        self.algo.updates_per_optimize * (itr - self._last_itr) /
        time_elapsed)
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('SamplerIteration', sample_itr)
    logger.record_tabular('CumTime (s)', new_time - self._start_time)
    logger.record_tabular('CumSteps', sample_itr * self.itr_batch_size)
    logger.record_tabular('CumUpdates',
        itr * self.algo.updates_per_optimize)
    logger.record_tabular('SamplesPerSecond', samples_per_second)
    logger.record_tabular('UpdatesPerSecond', updates_per_second)
    # Fraction of the interval spent actually optimizing (not throttled).
    logger.record_tabular('OptThrottle', (time_elapsed - throttle_time) /
        time_elapsed)
    self._log_infos()
    self._last_time = new_time
    self._last_itr = itr
    logger.dump_tabular(with_prefix=False)
    logger.log(f"Optimizing over {self.log_interval_itrs} sampler "
               "iterations.")
    self.pbar = ProgBarCounter(self.log_interval_itrs)
def log_diagnostics(self, itr, traj_infos=None, eval_time=0):
    """Write diagnostics (including stored ones) to csv via the logger.

    Fix vs. original: ``traj_infos`` defaults to ``None`` but was iterated
    unconditionally when accumulating pyflex steps, raising ``TypeError``
    whenever the default was used; iterate an empty tuple in that case.

    :param itr: current iteration.
    :param traj_infos: trajectory infos to forward to ``_log_infos``; may
        be ``None``.
    :param eval_time: seconds spent in eval during this interval.
    """
    if itr > 0:
        self.pbar.stop()
    self.save_itr_snapshot(itr)
    new_time = time.time()
    self._cum_time = new_time - self._start_time
    train_time_elapsed = new_time - self._last_time - eval_time
    new_updates = self.algo.update_counter - self._last_update_counter
    new_samples = (self.sampler.batch_size * self.world_size *
        self.log_interval_itrs)
    # NaN on the first iteration: no full interval has elapsed yet.
    updates_per_second = (float('nan') if itr == 0 else
        new_updates / train_time_elapsed)
    samples_per_second = (float('nan') if itr == 0 else
        new_samples / train_time_elapsed)
    replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
        new_samples)
    cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
        ((itr + 1) * self.sampler.batch_size))  # world_size cancels.
    cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size
    if self._eval:
        logger.record_tabular('CumTrainTime',
            self._cum_time - self._cum_eval_time)  # Already added new eval_time.
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('CumTime (s)', self._cum_time)
    logger.record_tabular('CumSteps', cum_steps)
    logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
    logger.record_tabular('CumUpdates', self.algo.update_counter)
    logger.record_tabular('StepsPerSecond', samples_per_second)
    logger.record_tabular('UpdatesPerSecond', updates_per_second)
    logger.record_tabular('ReplayRatio', replay_ratio)
    logger.record_tabular('CumReplayRatio', cum_replay_ratio)
    # Accumulate simulator (pyflex) steps from the last env_info of each
    # trajectory; guard against the None default (bug fix).
    self._cum_pyflex_steps += sum(
        info['env_infos'][-1].total_steps for info in (traj_infos or ()))
    logger.record_tabular('CumPyflexSteps', self._cum_pyflex_steps)
    self._log_infos(traj_infos)
    logger.dump_tabular(with_prefix=False)
    self._last_time = new_time
    self._last_update_counter = self.algo.update_counter
    if itr < self.n_itr - 1:
        logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
def train(demos, add_preproc, seed, batch_size, total_n_batches,
          eval_every_n_batches, out_dir, run_name, gpu_idx, cpu_list,
          eval_n_traj, snapshot_gap, omit_noop, net_width_mul, net_use_bn,
          net_dropout, net_coord_conv, net_attention, net_task_spec_layers,
          load_policy, aug_mode, min_bc):
    """Multi-task behavioral-cloning training loop: sets up devices, demo
    data, sampler, policy, and optimizer, then alternates training rounds
    with rollout evaluation until ``total_n_batches`` batches are done."""
    # TODO: abstract setup code. Seeds & GPUs should go in one function. Env
    # setup should go in another function (or maybe the same function). Dataset
    # loading should be simplified by having a single class that can provide
    # whatever form of data the current IL method needs, without having to do
    # unnecessary copies in memory. Maybe also just use Sacred, because YOLO.
    with contextlib.ExitStack() as exit_stack:
        # set up seeds & devices
        set_seeds(seed)
        mp.set_start_method('spawn')
        use_gpu = gpu_idx is not None and torch.cuda.is_available()
        dev = torch.device(["cpu", f"cuda:{gpu_idx}"][use_gpu])
        print(f"Using device {dev}, seed {seed}")
        if cpu_list is None:
            cpu_list = sample_cpu_list()
        affinity = dict(
            cuda_idx=gpu_idx if use_gpu else None,
            workers_cpus=cpu_list,
        )

        # register original envs
        import magical
        magical.register_envs()

        # TODO: split out part of the dataset for validation.
        demos_metas_dict = get_demos_meta(demo_paths=demos,
                                          omit_noop=omit_noop,
                                          transfer_variants=[],
                                          preproc_name=add_preproc)
        dataset_mt = demos_metas_dict['dataset_mt']
        loader_mt = make_loader_mt(dataset_mt, batch_size)
        variant_groups = demos_metas_dict['variant_groups']
        env_metas = demos_metas_dict['env_metas']
        num_demo_sources = demos_metas_dict['num_demo_sources']
        task_ids_and_demo_env_names = demos_metas_dict[
            'task_ids_and_demo_env_names']
        sampler_batch_B = batch_size
        # this doesn't really matter
        sampler_batch_T = 5
        sampler, sampler_batch_B = make_mux_sampler(
            variant_groups=variant_groups,
            num_demo_sources=num_demo_sources,
            env_metas=env_metas,
            use_gpu=use_gpu,
            batch_B=sampler_batch_B,
            batch_T=sampler_batch_T,
            # TODO: instead of doing this, try sampling in proportion to length
            # of horizon; that should get more samples from harder envs
            task_var_weights=None)

        # Either reload a saved policy or construct a fresh multi-head net.
        if load_policy is not None:
            try:
                pol_path = get_latest_path(load_policy)
            except ValueError:
                pol_path = load_policy
            policy_ctor = functools.partial(
                adapt_pol_loader,
                pol_path=pol_path,
                task_ids_and_demo_env_names=task_ids_and_demo_env_names)
            policy_kwargs = {}
        else:
            policy_kwargs = {
                'env_ids_and_names': task_ids_and_demo_env_names,
                'width': net_width_mul,
                'use_bn': net_use_bn,
                'dropout': net_dropout,
                'coord_conv': net_coord_conv,
                'attention': net_attention,
                'n_task_spec_layers': net_task_spec_layers,
                **get_policy_spec_magical(env_metas),
            }
            policy_ctor = MultiHeadPolicyNet
        agent = CategoricalPgAgent(ModelCls=MuxTaskModelWrapper,
                                   model_kwargs=dict(
                                       model_ctor=policy_ctor,
                                       model_kwargs=policy_kwargs))
        sampler.initialize(agent=agent,
                           seed=np.random.randint(1 << 31),
                           affinity=affinity)
        exit_stack.callback(lambda: sampler.shutdown())

        model_mt = policy_ctor(**policy_kwargs).to(dev)
        if min_bc:
            # Min-BC also learns per-(task, source) weights.
            num_tasks = len(task_ids_and_demo_env_names)
            weight_mod = MinBCWeightingModule(num_tasks, num_demo_sources) \
                .to(dev)
            all_params = it.chain(model_mt.parameters(),
                                  weight_mod.parameters())
        else:
            weight_mod = None
            all_params = model_mt.parameters()
        # Adam mostly works fine, but in very loose informal tests it seems
        # like SGD had fewer weird failures where mean loss would jump up by a
        # factor of 2x for a period (?). (I don't think that was solely due to
        # high LR; probably an architectural issue.)
        # opt_mt = torch.optim.Adam(model_mt.parameters(), lr=3e-4)
        opt_mt = torch.optim.SGD(all_params, lr=1e-3, momentum=0.1)

        try:
            aug_opts = MILBenchAugmentations.PRESETS[aug_mode]
        except KeyError:
            raise ValueError(f"unsupported mode '{aug_mode}'")
        if aug_opts:
            print("Augmentations:", ", ".join(aug_opts))
            aug_model = MILBenchAugmentations(**{k: True for k in aug_opts}) \
                .to(dev)
        else:
            print("No augmentations")
            aug_model = None

        n_uniq_envs = len(task_ids_and_demo_env_names)
        # Hyperparameters recorded alongside the run's logs.
        log_params = {
            'n_uniq_envs': n_uniq_envs,
            'n_demos': len(demos),
            'net_use_bn': net_use_bn,
            'net_width_mul': net_width_mul,
            'net_dropout': net_dropout,
            'net_coord_conv': net_coord_conv,
            'net_attention': net_attention,
            'aug_mode': aug_mode,
            'seed': seed,
            'omit_noop': omit_noop,
            'batch_size': batch_size,
            'eval_n_traj': eval_n_traj,
            'eval_every_n_batches': eval_every_n_batches,
            'total_n_batches': total_n_batches,
            'snapshot_gap': snapshot_gap,
            'add_preproc': add_preproc,
            'net_task_spec_layers': net_task_spec_layers,
        }
        with make_logger_ctx(out_dir, "mtbc", f"mt{n_uniq_envs}", run_name,
                             snapshot_gap=snapshot_gap,
                             log_params=log_params):
            # initial save
            torch.save(
                model_mt,
                os.path.join(logger.get_snapshot_dir(), 'full_model.pt'))

            # train for a while
            n_batches_done = 0
            n_rounds = int(np.ceil(total_n_batches / eval_every_n_batches))
            rnd = 1
            assert eval_every_n_batches > 0
            while n_batches_done < total_n_batches:
                batches_left_now = min(total_n_batches - n_batches_done,
                                       eval_every_n_batches)
                print(f"Done {n_batches_done}/{total_n_batches} "
                      f"({n_batches_done/total_n_batches*100:.2f}%, "
                      f"{rnd}/{n_rounds} rounds) batches; doing another "
                      f"{batches_left_now}")
                model_mt.train()
                loss_ewma, losses, per_task_losses = do_training_mt(
                    loader=loader_mt,
                    model=model_mt,
                    opt=opt_mt,
                    dev=dev,
                    aug_model=aug_model,
                    min_bc_module=weight_mod,
                    n_batches=batches_left_now)

                # TODO: record accuracy on a random subset of the train and
                # validation sets (both in eval mode, not train mode)

                print(f"Evaluating {eval_n_traj} trajectories on "
                      f"{variant_groups.num_tasks} tasks")
                record_misc_calls = []
                model_mt.eval()
                copy_model_into_agent_eval(model_mt, sampler.agent)
                scores_by_tv = eval_model(
                    sampler,
                    # shouldn't be any exploration
                    itr=0,
                    n_traj=eval_n_traj)
                for (task_id, variant_id), scores in scores_by_tv.items():
                    tv_id = (task_id, variant_id)
                    env_name = variant_groups.env_name_by_task_variant[
                        tv_id]
                    tag = make_env_tag(strip_mb_preproc_name(env_name))
                    logger.record_tabular_misc_stat("Score%s" % tag, scores)
                    env_losses = per_task_losses.get(tv_id, [])
                    record_misc_calls.append((f"Loss{tag}", env_losses))
                # we record score AFTER loss so that losses are all in one
                # place, and scores are all in another
                for args in record_misc_calls:
                    logger.record_tabular_misc_stat(*args)

                # finish logging for this epoch
                logger.record_tabular("Round", rnd)
                logger.record_tabular("LossEWMA", loss_ewma)
                logger.record_tabular_misc_stat("Loss", losses)
                logger.dump_tabular()
                logger.save_itr_params(
                    rnd, {
                        'model_state': model_mt.state_dict(),
                        'opt_state': opt_mt.state_dict(),
                    })

                # advance ctrs
                rnd += 1
                n_batches_done += batches_left_now
def log_diagnostics(self, itr, eval_traj_infos, eval_time):
    """
    Write diagnostics (including stored ones) to csv via the logger.
    ONE NEW LINE VS REGULAR RUNNER, TO LOG ENV STEPS = steps*frame_skip

    :param itr: current iteration.
    :param eval_traj_infos: trajectory infos collected during eval.
    :param eval_time: seconds spent in this evaluation.
    """
    if not eval_traj_infos:
        logger.log("WARNING: had no complete trajectories in eval.")
    steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
    logger.record_tabular('StepsInEval', steps_in_eval)
    logger.record_tabular('TrajsInEval', len(eval_traj_infos))
    self._cum_eval_time += eval_time
    logger.record_tabular('CumEvalTime', self._cum_eval_time)
    if itr > 0:
        self.pbar.stop()
    if itr >= self.min_itr_learn - 1:
        self.save_itr_snapshot(itr)
    new_time = time.time()
    self._cum_time = new_time - self._start_time
    train_time_elapsed = new_time - self._last_time - eval_time
    new_updates = self.algo.update_counter - self._last_update_counter
    new_samples = (self.sampler.batch_size * self.world_size *
        self.log_interval_itrs)
    # NaN on the first iteration: no full interval has elapsed yet.
    updates_per_second = (float('nan') if itr == 0 else
        new_updates / train_time_elapsed)
    samples_per_second = (float('nan') if itr == 0 else
        new_samples / train_time_elapsed)
    replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
        new_samples)
    cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
        ((itr + 1) * self.sampler.batch_size))  # world_size cancels.
    cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size
    if self._eval:
        logger.record_tabular('CumTrainTime',
            self._cum_time - self._cum_eval_time)  # Already added new eval_time.
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('CumTime (s)', self._cum_time)
    logger.record_tabular('CumSteps', cum_steps)
    logger.record_tabular('EnvSteps', cum_steps * self._frame_skip)  # NEW LINE
    logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
    logger.record_tabular('CumUpdates', self.algo.update_counter)
    logger.record_tabular('StepsPerSecond', samples_per_second)
    logger.record_tabular('UpdatesPerSecond', updates_per_second)
    logger.record_tabular('ReplayRatio', replay_ratio)
    logger.record_tabular('CumReplayRatio', cum_replay_ratio)
    self._log_infos(eval_traj_infos)
    logger.dump_tabular(with_prefix=False)
    self._last_time = new_time
    self._last_update_counter = self.algo.update_counter
    if itr < self.n_itr - 1:
        logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)
def log_diagnostics(self, itr, traj_infos=None, eval_time=0):
    """
    Record diagnostics (write to the log); also saves model parameters etc.

    :param itr: current iteration.
    :param traj_infos: trajectory infos to forward to ``_log_infos``.
    :param eval_time: seconds spent in eval during this interval.
    """
    if itr > 0:
        self.pbar.stop()  # stop updating the progress bar
    if itr >= self.min_itr_learn - 1:
        self.save_itr_snapshot(itr)
    new_time = time.time()
    self._cum_time = new_time - self._start_time
    train_time_elapsed = new_time - self._last_time - eval_time
    new_updates = self.algo.update_counter - self._last_update_counter
    new_samples = (self.sampler.batch_size * self.world_size *
        self.log_interval_itrs)
    # NaN on the first iteration: no full interval has elapsed yet.
    updates_per_second = (float('nan') if itr == 0 else
        new_updates / train_time_elapsed)
    samples_per_second = (float('nan') if itr == 0 else
        new_samples / train_time_elapsed)
    replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
        new_samples)
    cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
        ((itr + 1) * self.sampler.batch_size))  # world_size cancels.
    cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size
    # write some extra statistics to the log
    if self._eval:
        logger.record_tabular(
            'CumTrainTime',
            self._cum_time - self._cum_eval_time)  # Already added new eval_time.
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('CumTime (s)', self._cum_time)
    logger.record_tabular('CumSteps', cum_steps)
    logger.record_tabular(
        'CumCompletedTrajs',
        self._cum_completed_trajs)  # only counts trajs flagged "traj_done"
    logger.record_tabular('CumUpdates', self.algo.update_counter)
    logger.record_tabular('StepsPerSecond', samples_per_second)
    logger.record_tabular('UpdatesPerSecond', updates_per_second)
    logger.record_tabular('ReplayRatio', replay_ratio)
    logger.record_tabular('CumReplayRatio', cum_replay_ratio)
    self._log_infos(traj_infos)
    logger.dump_tabular(with_prefix=False)  # write the log file
    self._last_time = new_time
    self._last_update_counter = self.algo.update_counter
    if itr < self.n_itr - 1:
        logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)  # progress bar
def log_diagnostics(self, itr):
    """Snapshot the model, record throughput and trajectory counters, dump
    the tabular log, and restart the progress bar if iterations remain."""
    self.pbar.stop()
    self.save_itr_snapshot(itr)
    now = time.time()
    interval = now - self._last_time
    samples_per_second = (self.log_interval_itrs *
        self.itr_batch_size) / interval
    window_steps = sum(info["Length"] for info in self._traj_infos)
    logger.record_tabular('Iteration', itr)
    logger.record_tabular('CumSteps', (itr + 1) * self.itr_batch_size)
    logger.record_tabular('CumTime (s)', now - self._start_time)
    logger.record_tabular('SamplesPerSecond', samples_per_second)
    logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
    logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
    logger.record_tabular('StepsInTrajWindow', window_steps)
    self._log_infos()
    self._last_time = now
    logger.dump_tabular(with_prefix=False)
    self._new_completed_trajs = 0
    if itr < self.n_itr - 1:
        logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
        self.pbar = ProgBarCounter(self.log_interval_itrs)