def __init__(self, is_master=None):
    if is_master is None:
        self.is_master = heyhi.is_master()
    else:
        self.is_master = is_master
    if self.is_master:
        self.writer = torch.utils.tensorboard.SummaryWriter(log_dir="tb")
        self.jsonl_writer = open("metrics.jsonl", "a")
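# A minimal sketch (not the repo's actual implementation) of the companion
# log_metrics method implied by the fields above and by the trainer's
# `logger.log_metrics(epoch_scalars, epoch_id)` call site. Assumes `import json`
# at module level and that `metrics` is a flat dict of scalar values; the real
# method and the jsonl record layout may differ.
def log_metrics(self, metrics, step):
    if not self.is_master:
        return
    for key, value in metrics.items():
        self.writer.add_scalar(key, float(value), global_step=step)
    self.jsonl_writer.write(json.dumps({"global_step": step, **metrics}) + "\n")
    self.jsonl_writer.flush()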
def main(task, cfg):
    heyhi.setup_logging()
    logging.info("Cwd: %s", os.getcwd())
    logging.info("Task: %s", task)
    logging.info("Cfg:\n%s", cfg)
    heyhi.log_git_status()
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())
    if task not in TASKS:
        raise ValueError("Unknown task: %s. Known tasks: %s" % (task, sorted(TASKS)))
    return TASKS[task](cfg)
def main(cfg):
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    resource.setrlimit(
        resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    logging.info("resource.RLIMIT_CORE: %s", resource.RLIMIT_CORE)
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())
    task = getattr(cfvpy.tasks, cfg.task)
    return task(cfg)
def main(cfg):
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    # resource.setrlimit(resource.RLIMIT_CORE, (16, 16))
    logging.info("resource.RLIMIT_CORE: %s", resource.RLIMIT_CORE)
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())
    # task = getattr(cfvpy.tasks, cfg.task)
    # return task(cfg)
    return cfvpy.tasks.selfplay(cfg)
def __call__(self):
    cfg = self.cfg
    if not torch.cuda.is_available():
        logging.warning("No CUDA found!")
        device = "cpu"
    else:
        device = "cuda"  # Training device.

    self._init_state(device)
    if heyhi.is_master():
        CKPT_DIR.mkdir(exist_ok=True, parents=True)
    if REQUEUE_CKPT.exists():
        logging.info("Found requeue checkpoint: %s", REQUEUE_CKPT.resolve())
        self.state = TrainerState.load(REQUEUE_CKPT, self.state, device)
    elif CKPT_DIR.exists() and list(CKPT_DIR.iterdir()):
        last_ckpt = max(CKPT_DIR.iterdir(), key=str)
        logging.info(
            "Found existing checkpoint folder. Will load last one: %s", last_ckpt)
        self.state = TrainerState.load(last_ckpt, self.state, device)
    else:
        logging.info("No checkpoints to restart from found")
    # TODO(akhti): wipe the folder.
    # TODO(akhti): handle multi machine.
    ckpt_syncer = CkptSyncer(
        fairdiplomacy.selfplay.data_loader.get_ckpt_sync_dir(), create_dir=True)
    ckpt_syncer.save_state_dict(self.state.model)
    logger = self.logger = fairdiplomacy.selfplay.metrics.Logger()
    data_loader = self.data_loader = fairdiplomacy.selfplay.data_loader.DataLoader(
        cfg.model_path, cfg.rollout)

    self.state.model.train()
    if self.cfg.trainer.train_as_eval:
        self.state.model.eval()
        # Cast cuDNN RNN back to train mode.
        self.state.model.apply(_lstm_to_train)
    elif self.cfg.trainer.train_encoder_as_eval:
        self.state.model.encoder.eval()
    elif self.cfg.trainer.train_decoder_as_eval:
        self.state.model.policy_decoder.eval()
        self.state.model.policy_decoder.apply(_lstm_to_train)
    elif self.cfg.trainer.train_as_eval_but_batchnorm:
        self.state.model.eval()
        self.state.model.apply(_lstm_to_train)
        self.state.model.apply(_bn_to_train)

    max_epochs = self.cfg.trainer.max_epochs or 10 ** 9
    self.state.model.to(device)
    for self.state.epoch_id in range(self.state.epoch_id, max_epochs):
        # Save the state each epoch in case we need to requeue.
        self.state.save(REQUEUE_CKPT)
        if (self.cfg.trainer.save_checkpoint_every
                and self.state.epoch_id % self.cfg.trainer.save_checkpoint_every == 0):
            self.state.save(CKPT_DIR / (CKPT_TPL % self.state.epoch_id))

        # Counters accumulate different statistics over the epoch. The default
        # accumulation strategy is averaging.
        counters = collections.defaultdict(
            fairdiplomacy.selfplay.metrics.FractionCounter)
        use_grad_clip = self.cfg.optimizer.grad_clip > 1e-10
        if use_grad_clip:
            counters["optim/grad_max"] = grad_max_counter = (
                fairdiplomacy.selfplay.metrics.MaxCounter())
        counters["score/num_games"] = fairdiplomacy.selfplay.metrics.SumCounter()
        for p in POWERS:
            counters[f"score_{p}/num_games"] = (
                fairdiplomacy.selfplay.metrics.SumCounter())
        # For the LR just record its value at the start of the epoch.
        counters["optim/lr"].update(
            next(iter(self.state.optimizer.param_groups))["lr"])

        epoch_start_time = time.time()
        for _ in range(self.cfg.trainer.epoch_size):
            timings = TimingCtx()
            with timings("data_gen"):
                (
                    (power_ids, obs, rewards, actions, behavior_action_logprobs, done),
                    rollout_scores_per_power,
                ) = data_loader.get_batch()
            with timings("to_cuda"):
                actions = actions.to(device)
                rewards = rewards.to(device)
                power_ids = power_ids.to(device)
                obs = {k: v.to(device) for k, v in obs.items()}
                cand_actions = obs.pop("cand_indices")
                behavior_action_logprobs = behavior_action_logprobs.to(device)
                done = done.to(device)
            with timings("net"):
                # Shapes: _, [B, 17], [B, S, 469], [B, 7].
                # policy_cand_actions has the same information as actions,
                # but uses local indices to match policy logits.
                assert EOS_IDX == -1, "Rewrite the code to remove the assumption"
                _, _, policy_logits, sc_values = self.state.model(
                    **obs,
                    temperature=1.0,
                    teacher_force_orders=actions.clamp(0),  # EOS_IDX = -1 -> 0.
                    x_power=to_onehot(power_ids, len(POWERS)),
                )
                cand_actions = cand_actions[:, :policy_logits.shape[1]]
                # Shape: [B].
                sc_values = sc_values.gather(1, power_ids.unsqueeze(1)).squeeze(1)
                # Removing absolute order ids to not use them by accident.
                # Will use relative order ids (cand_actions) from now on.
                del actions

                if self.cfg.rollout.do_not_split_rollouts:
                    # Assumes that the episode actually ends.
                    bootstrap_value = torch.zeros_like(sc_values[-1])
                else:
                    # Reducing the batch size by one. Deleting tensors we don't
                    # bother to adjust, to avoid using them by accident.
                    bootstrap_value = sc_values[-1].detach()
                    sc_values = sc_values[:-1]
                    cand_actions = cand_actions[:-1]
                    policy_logits = policy_logits[:-1]
                    rewards = rewards[:-1]
                    power_ids = power_ids[:-1]
                    del obs
                    behavior_action_logprobs = behavior_action_logprobs[:-1]
                    done = done[:-1]

                # Shape: [B].
                discounts = (~done).float() * self.cfg.discounting
                # Shape: [B, 17].
                mask = (cand_actions != EOS_IDX).float()
                # Shape: [B].
                policy_action_logprobs = order_logits_to_action_logprobs(
                    policy_logits, cand_actions, mask)
                vtrace_returns = vtrace_from_logprobs_no_batch(
                    log_rhos=policy_action_logprobs - behavior_action_logprobs,
                    discounts=discounts,
                    rewards=rewards,
                    values=sc_values,
                    bootstrap_value=bootstrap_value,
                )
                critic_mses = 0.5 * ((vtrace_returns.vs.detach() - sc_values) ** 2)
                losses = dict(
                    actor=compute_policy_gradient_loss(
                        policy_action_logprobs, vtrace_returns.pg_advantages),
                    critic=critic_mses.mean(),
                    # TODO(akhti): it's incorrect to apply this to the
                    # per-position order distribution instead of the action
                    # distribution.
                    entropy=compute_entropy_loss(policy_logits, mask),
                )
                loss = (losses["actor"]
                        + cfg.critic_weight * losses["critic"]
                        + cfg.entropy_weight * losses["entropy"])
                if cfg.sampled_entropy_weight:
                    loss = loss + cfg.sampled_entropy_weight * compute_sampled_entropy_loss(
                        policy_action_logprobs)
                self.state.optimizer.zero_grad()
                loss.backward()
                if use_grad_clip:
                    g_norm_tensor = clip_grad_norm_(
                        self.state.model.parameters(), self.cfg.optimizer.grad_clip)
                if (not self.cfg.trainer.max_updates
                        or self.state.global_step < self.cfg.trainer.max_updates):
                    self.state.optimizer.step()
                # Sync to make sure timing is correct.
                loss.item()
            with timings("metrics"), torch.no_grad():
                last_count = done.long().sum()
                critic_end_mses = critic_mses[done].sum()
                if use_grad_clip:
                    g_norm = g_norm_tensor.item()
                    grad_max_counter.update(g_norm)
                    counters["optim/grad_mean"].update(g_norm)
                    counters["optim/grad_clip_ratio"].update(
                        int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5))
                for key, value in losses.items():
                    counters[f"loss/{key}"].update(value)
                counters["loss/total"].update(loss.item())
                for power_id, rollout_scores in rollout_scores_per_power.items():
                    prefix = (f"score_{POWERS[power_id]}"
                              if power_id is not None else "score")
                    for key, value in rollout_scores.items():
                        if key != "num_games":
                            counters[f"{prefix}/{key}"].update(
                                value, rollout_scores["num_games"])
                        else:
                            counters[f"{prefix}/{key}"].update(value)
                counters["loss/critic_last"].update(critic_end_mses, last_count)
                counters["reward/mean"].update(rewards.sum(), len(rewards))
                # Rewards at the end of episodes. We precompute everything
                # before adding to counters to pipeline things when possible.
                last_rewards = rewards[done]
                last_sum = last_rewards.sum()
                # Tensor of shape [num_powers, num_dones].
                last_power_masks = (
                    power_ids[done].unsqueeze(0)
                    == torch.arange(len(POWERS), device=power_ids.device).unsqueeze(1)
                ).float()
                last_power_rewards = (last_power_masks * last_rewards.unsqueeze(0)).sum(1)
                last_power_counts = last_power_masks.sum(1)
                counters["reward/last"].update(last_sum, last_count)
                for power, reward, counts in zip(
                        POWERS, last_power_rewards.cpu(), last_power_counts.cpu()):
                    counters[f"reward/last_{power}"].update(reward, counts)
                # To match the entropy loss we don't negate logprobs, so this
                # is an estimate of the negative entropy.
                counters["loss/entropy_sampled"].update(policy_action_logprobs.mean())
                # Measure how off-policy the data is.
                counters["loss/rho"].update(
                    vtrace_returns.rhos.sum(), vtrace_returns.rhos.numel())
                counters["loss/rhos_clipped"].update(
                    vtrace_returns.clipped_rhos.sum(), vtrace_returns.clipped_rhos.numel())
                bsz = len(rewards)
                counters["size/batch"].update(bsz)
                counters["size/episode"].update(bsz, last_count)
            with timings("sync"), torch.no_grad():
                if self.state.global_step % self.cfg.trainer.save_sync_checkpoint_every == 0:
                    ckpt_syncer.save_state_dict(self.state.model)
            # Doing this outside of the context to capture the context's timing.
            for key, value in timings.items():
                counters[f"time/{key}"].update(value)
            if (self.state.global_step < 128
                    or (self.state.global_step & (self.state.global_step + 1)) == 0):
                logging.info(
                    "Metrics (global_step=%d): %s",
                    self.state.global_step,
                    {k: v.value() for k, v in sorted(counters.items())},
                )
            self.state.global_step += 1

        epoch_scalars = {k: v.value() for k, v in sorted(counters.items())}
        average_batch_size = epoch_scalars["size/batch"]
        epoch_scalars["speed/loop_bps"] = self.cfg.trainer.epoch_size / (
            time.time() - epoch_start_time + 1e-5)
        epoch_scalars["speed/loop_eps"] = (
            epoch_scalars["speed/loop_bps"] * average_batch_size)
        # Speed for to_cuda + forward + backward.
        torch_time = epoch_scalars["time/net"] + epoch_scalars["time/to_cuda"]
        epoch_scalars["speed/train_bps"] = 1.0 / torch_time
        epoch_scalars["speed/train_eps"] = average_batch_size / torch_time
        eval_scores = data_loader.extract_eval_scores()
        if eval_scores is not None:
            for k, v in eval_scores.items():
                epoch_scalars[f"score_eval/{k}"] = v
        logging.info(
            "Finished epoch %d. Metrics:\n%s",
            self.state.epoch_id,
            format_metrics_for_print(epoch_scalars),
        )
        logger.log_metrics(epoch_scalars, self.state.epoch_id)

    logging.info("End of training")
    data_loader.terminate()
    logging.info("Exiting main function")
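# Illustrative sketch (not the repo's vtrace_from_logprobs_no_batch) of the
# V-trace target computation used by the training loop above, following
# Espeholt et al. (2018), "IMPALA". Assumes 1-D tensors of shape [T] for
# log_rhos, discounts, rewards and values, a scalar bootstrap_value, and the
# common clipping thresholds rho_bar = c_bar = 1.0; the actual helper may use
# different shapes, field names, or a separate clip for the policy-gradient rhos.
import collections

import torch

VTraceReturns = collections.namedtuple(
    "VTraceReturns", ["vs", "pg_advantages", "rhos", "clipped_rhos"])


def vtrace_from_logprobs_sketch(
    log_rhos, discounts, rewards, values, bootstrap_value, rho_bar=1.0, c_bar=1.0
):
    with torch.no_grad():
        rhos = torch.exp(log_rhos)  # Importance weights pi(a|x) / mu(a|x).
        clipped_rhos = torch.clamp(rhos, max=rho_bar)
        cs = torch.clamp(rhos, max=c_bar)
        # V(x_{t+1}), with the bootstrap value appended for the final step.
        values_next = torch.cat([values[1:], bootstrap_value.reshape(1)])
        # TD terms: delta_t = rho_t * (r_t + gamma_t * V(x_{t+1}) - V(x_t)).
        deltas = clipped_rhos * (rewards + discounts * values_next - values)
        # Backward recursion:
        # vs_t - V(x_t) = delta_t + gamma_t * c_t * (vs_{t+1} - V(x_{t+1})).
        vs_minus_v = torch.zeros_like(values)
        acc = torch.zeros_like(bootstrap_value)
        for t in reversed(range(values.shape[0])):
            acc = deltas[t] + discounts[t] * cs[t] * acc
            vs_minus_v[t] = acc
        vs = vs_minus_v + values
        # Advantages for the policy-gradient term use vs_{t+1} as the target.
        vs_next = torch.cat([vs[1:], bootstrap_value.reshape(1)])
        pg_advantages = clipped_rhos * (rewards + discounts * vs_next - values)
    return VTraceReturns(
        vs=vs, pg_advantages=pg_advantages, rhos=rhos, clipped_rhos=clipped_rhos)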