def __init__(self, is_master=None):
    if is_master is None:
        self.is_master = heyhi.is_master()
    else:
        self.is_master = is_master
    # Only the master process writes TensorBoard and JSONL metrics.
    if self.is_master:
        self.writer = torch.utils.tensorboard.SummaryWriter(log_dir="tb")
        self.jsonl_writer = open("metrics.jsonl", "a")
Example 2
def main(task, cfg):
    heyhi.setup_logging()
    logging.info("Cwd: %s", os.getcwd())
    logging.info("Task: %s", task)
    logging.info("Cfg:\n%s", cfg)
    heyhi.log_git_status()
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())

    if task not in TASKS:
        raise ValueError("Unknown task: %s. Known tasks: %s" % (task, sorted(TASKS)))
    return TASKS[task](cfg)
Example 3
def main(cfg):
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    resource.setrlimit(resource.RLIMIT_CORE,
                       (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    logging.info("resource.RLIMIT_CORE: %s", resource.RLIMIT_CORE)
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("is on slurm:%s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())

    task = getattr(cfvpy.tasks, cfg.task)
    return task(cfg)
Example 4
def main(cfg):
    print("HERHEHRHEHREHRHERHEHR")
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    #resource.setrlimit(
    #    resource.RLIMIT_CORE, (16, 16)
    #)
    logging.info("resource.RLIMIT_CORE: %s", resource.RLIMIT_CORE)
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("is on slurm:%s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())

    #task = getattr(cfvpy.tasks, cfg.task)
    #return task(cfg)
    return cfvpy.tasks.selfplay(cfg)
Example 5
    def __call__(self):
        cfg = self.cfg
        if not torch.cuda.is_available():
            logging.warning("No CUDA found!")
            device = "cpu"
        else:
            device = "cuda"  # Training device.
        self._init_state(device)

        if heyhi.is_master():
            CKPT_DIR.mkdir(exist_ok=True, parents=True)
        if REQUEUE_CKPT.exists():
            logging.info("Found requeue checkpoint: %s",
                         REQUEUE_CKPT.resolve())
            self.state = TrainerState.load(REQUEUE_CKPT, self.state, device)
        elif CKPT_DIR.exists() and list(CKPT_DIR.iterdir()):
            last_ckpt = max(CKPT_DIR.iterdir(), key=str)
            logging.info(
                "Found existing checkpoint folder. Will load last one: %s",
                last_ckpt)
            self.state = TrainerState.load(last_ckpt, self.state, device)
        else:
            logging.info("No checkpoints to restart from found")

        # TODO(akhti): wipe the folder.
        # TODO(akhti): handle multi machine.
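        # Publish the initial model weights to the shared checkpoint-sync
        # directory (refreshed every cfg.trainer.save_sync_checkpoint_every
        # steps inside the training loop below).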
        ckpt_syncer = CkptSyncer(
            fairdiplomacy.selfplay.data_loader.get_ckpt_sync_dir(),
            create_dir=True)
        ckpt_syncer.save_state_dict(self.state.model)

        logger = self.logger = fairdiplomacy.selfplay.metrics.Logger()
        data_loader = self.data_loader = fairdiplomacy.selfplay.data_loader.DataLoader(
            cfg.model_path, cfg.rollout)
        self.state.model.train()
        if self.cfg.trainer.train_as_eval:
            self.state.model.eval()
            # Cast cuDNN RNN back to train mode.
            self.state.model.apply(_lstm_to_train)
        elif self.cfg.trainer.train_encoder_as_eval:
            self.state.model.encoder.eval()
        elif self.cfg.trainer.train_decoder_as_eval:
            self.state.model.policy_decoder.eval()
            self.state.model.policy_decoder.apply(_lstm_to_train)
        elif self.cfg.trainer.train_as_eval_but_batchnorm:
            self.state.model.eval()
            self.state.model.apply(_lstm_to_train)
            self.state.model.apply(_bn_to_train)

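        # max_epochs == 0 means "no limit" (capped at 1e9 epochs).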
        max_epochs = self.cfg.trainer.max_epochs or 10**9
        self.state.model.to(device)

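        # Resume from the epoch stored in the (possibly checkpoint-restored) state.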
        for self.state.epoch_id in range(self.state.epoch_id, max_epochs):
            # Save the state each epoch in case the job needs to be requeued.
            self.state.save(REQUEUE_CKPT)
            if (self.cfg.trainer.save_checkpoint_every
                    and self.state.epoch_id %
                    self.cfg.trainer.save_checkpoint_every == 0):
                self.state.save(CKPT_DIR / (CKPT_TPL % self.state.epoch_id))
            # Counters accumulate different statistics over the epoch. The
            # default accumulation strategy is averaging.
            counters = collections.defaultdict(
                fairdiplomacy.selfplay.metrics.FractionCounter)
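            # A grad_clip of 0 (anything <= 1e-10) disables gradient clipping.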
            use_grad_clip = self.cfg.optimizer.grad_clip > 1e-10
            if use_grad_clip:
                counters["optim/grad_max"] = grad_max_counter = (
                    fairdiplomacy.selfplay.metrics.MaxCounter())
            counters["score/num_games"] = (
                fairdiplomacy.selfplay.metrics.SumCounter())
            for p in POWERS:
                counters[f"score_{p}/num_games"] = (
                    fairdiplomacy.selfplay.metrics.SumCounter())
            # For LR, just record its value at the start of the epoch.
            counters["optim/lr"].update(
                next(iter(self.state.optimizer.param_groups))["lr"])
            epoch_start_time = time.time()
            for _ in range(self.cfg.trainer.epoch_size):
                timings = TimingCtx()
                with timings("data_gen"):
                    (
                        (power_ids, obs, rewards, actions,
                         behavior_action_logprobs, done),
                        rollout_scores_per_power,
                    ) = data_loader.get_batch()

                with timings("to_cuda"):
                    actions = actions.to(device)
                    rewards = rewards.to(device)
                    power_ids = power_ids.to(device)
                    obs = {k: v.to(device) for k, v in obs.items()}
                    cand_actions = obs.pop("cand_indices")
                    behavior_action_logprobs = behavior_action_logprobs.to(
                        device)
                    done = done.to(device)

                with timings("net"):
                    # Shape: _, [B, 17], [B, S, 469], [B, 7].
                    # policy_cand_actions has the same information as actions,
                    # but uses local indices to match policy logits.
                    assert EOS_IDX == -1, "Rewrite the code to remove the assumption"
                    _, _, policy_logits, sc_values = self.state.model(
                        **obs,
                        temperature=1.0,
                        teacher_force_orders=actions.clamp(0),  # EOS_IDX = -1 -> 0
                        x_power=to_onehot(power_ids, len(POWERS)),
                    )
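                    # Trim candidate actions to the number of decoded order steps.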
                    cand_actions = cand_actions[:, :policy_logits.shape[1]]

                    # Shape: [B].
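                    # Pick the value prediction for each sample's acting power.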
                    sc_values = sc_values.gather(
                        1, power_ids.unsqueeze(1)).squeeze(1)

                    # Remove absolute order ids so they are not used by accident.
                    # Use relative order ids (cand_actions) from now on.
                    del actions

                    if self.cfg.rollout.do_not_split_rollouts:
                        # Assumes that the episode actually ends.
                        bootstrap_value = torch.zeros_like(sc_values[-1])
                    else:
                        # Reduce the batch size by one. Delete tensors we do not
                        # bother adjusting so they cannot be used by accident.
                        bootstrap_value = sc_values[-1].detach()
                        sc_values = sc_values[:-1]
                        cand_actions = cand_actions[:-1]
                        policy_logits = policy_logits[:-1]
                        rewards = rewards[:-1]
                        power_ids = power_ids[:-1]
                        del obs
                        behavior_action_logprobs = behavior_action_logprobs[:-1]
                        done = done[:-1]

                    # Shape: [B].
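                    # Discount is zeroed where done=True so returns do not
                    # bootstrap across episode boundaries.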
                    discounts = (~done).float() * self.cfg.discounting

                    # Shape: [B, 17].
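                    # 1.0 at real order positions, 0.0 at EOS/padding positions.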
                    mask = (cand_actions != EOS_IDX).float()

                    # Shape: [B].
                    policy_action_logprobs = order_logits_to_action_logprobs(
                        policy_logits, cand_actions, mask)

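                    # V-trace off-policy correction (IMPALA): log_rhos are the
                    # log importance ratios between the learner policy and the
                    # behavior policy that generated the rollout batch.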
                    vtrace_returns = vtrace_from_logprobs_no_batch(
                        log_rhos=policy_action_logprobs -
                        behavior_action_logprobs,
                        discounts=discounts,
                        rewards=rewards,
                        values=sc_values,
                        bootstrap_value=bootstrap_value,
                    )

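                    # Critic loss: 0.5 * squared error between the detached
                    # V-trace value targets and the predicted SC values.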
                    critic_mses = 0.5 * (
                        (vtrace_returns.vs.detach() - sc_values)**2)

                    losses = dict(
                        actor=compute_policy_gradient_loss(
                            policy_action_logprobs,
                            vtrace_returns.pg_advantages),
                        critic=critic_mses.mean(),
                        # TODO(akhti): it's incorrect to apply this to
                        # per-position order distribution instead of action
                        # distribution.
                        entropy=compute_entropy_loss(policy_logits, mask),
                    )

                    loss = (losses["actor"] +
                            cfg.critic_weight * losses["critic"] +
                            cfg.entropy_weight * losses["entropy"])
                    if cfg.sampled_entropy_weight:
                        loss = loss + cfg.sampled_entropy_weight * compute_sampled_entropy_loss(
                            policy_action_logprobs)

                    self.state.optimizer.zero_grad()
                    loss.backward()

                    if use_grad_clip:
                        g_norm_tensor = clip_grad_norm_(
                            self.state.model.parameters(),
                            self.cfg.optimizer.grad_clip)

                    if (not self.cfg.trainer.max_updates
                            or self.state.global_step <
                            self.cfg.trainer.max_updates):
                        self.state.optimizer.step()
                    # Sync to make sure timing is correct.
                    loss.item()

                with timings("metrics"), torch.no_grad():
                    last_count = done.long().sum()
                    critic_end_mses = critic_mses[done].sum()

                    if use_grad_clip:
                        g_norm = g_norm_tensor.item()
                        grad_max_counter.update(g_norm)
                        counters["optim/grad_mean"].update(g_norm)
                        counters["optim/grad_clip_ratio"].update(
                            int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5))
                    for key, value in losses.items():
                        counters[f"loss/{key}"].update(value)
                    counters[f"loss/total"].update(loss.item())
                    for power_id, rollout_scores in rollout_scores_per_power.items(
                    ):
                        prefix = f"score_{POWERS[power_id]}" if power_id is not None else "score"
                        for key, value in rollout_scores.items():
                            if key != "num_games":
                                counters[f"{prefix}/{key}"].update(
                                    value, rollout_scores["num_games"])
                            else:
                                counters[f"{prefix}/{key}"].update(value)

                    counters["loss/critic_last"].update(
                        critic_end_mses, last_count)

                    counters["reward/mean"].update(rewards.sum(), len(rewards))
                    # Rewards at the end of episodes. We precompute everything
                    # before adding to counters to pipeline things when
                    # possible.
                    last_rewards = rewards[done]
                    last_sum = last_rewards.sum()
                    # tensor [num_powers, num_dones].
                    last_power_masks = (
                        power_ids[done].unsqueeze(0) == torch.arange(
                            len(POWERS),
                            device=power_ids.device).unsqueeze(1)).float()
                    last_power_rewards = (last_power_masks *
                                          last_rewards.unsqueeze(0)).sum(1)
                    last_power_counts = last_power_masks.sum(1)
                    counters["reward/last"].update(last_sum, last_count)
                    for power, reward, counts in zip(POWERS,
                                                     last_power_rewards.cpu(),
                                                     last_power_counts.cpu()):
                        counters[f"reward/last_{power}"].update(reward, counts)
                    # To match entropy loss we don't negate logprobs. So this
                    # is an estimate of the negative entropy.
                    counters["loss/entropy_sampled"].update(
                        policy_action_logprobs.mean())

                    # Measure off-policiness.
                    counters["loss/rho"].update(vtrace_returns.rhos.sum(),
                                                vtrace_returns.rhos.numel())
                    counters["loss/rhos_clipped"].update(
                        vtrace_returns.clipped_rhos.sum(),
                        vtrace_returns.clipped_rhos.numel())

                    bsz = len(rewards)
                    counters["size/batch"].update(bsz)
                    counters["size/episode"].update(bsz, last_count)

                with timings("sync"), torch.no_grad():
                    if self.state.global_step % self.cfg.trainer.save_sync_checkpoint_every == 0:
                        ckpt_syncer.save_state_dict(self.state.model)

                # Doing outside of the context to capture the context's timing.
                for key, value in timings.items():
                    counters[f"time/{key}"].update(value)

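                # Log every step for the first 128 steps, then only when
                # global_step + 1 is a power of two (exponential back-off).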
                if (self.state.global_step < 128
                        or (self.state.global_step
                            & self.state.global_step + 1) == 0):
                    logging.info(
                        "Metrics (global_step=%d): %s",
                        self.state.global_step,
                        {k: v.value()
                         for k, v in sorted(counters.items())},
                    )
                self.state.global_step += 1

            epoch_scalars = {k: v.value() for k, v in sorted(counters.items())}
            average_batch_size = epoch_scalars["size/batch"]
            epoch_scalars["speed/loop_bps"] = self.cfg.trainer.epoch_size / (
                time.time() - epoch_start_time + 1e-5)
            epoch_scalars["speed/loop_eps"] = epoch_scalars[
                "speed/loop_bps"] * average_batch_size
            # Speed for to_cuda + forward + backward.
            torch_time = epoch_scalars["time/net"] + epoch_scalars[
                "time/to_cuda"]
            epoch_scalars["speed/train_bps"] = 1.0 / torch_time
            epoch_scalars["speed/train_eps"] = average_batch_size / torch_time

            eval_scores = data_loader.extract_eval_scores()
            if eval_scores is not None:
                for k, v in eval_scores.items():
                    epoch_scalars[f"score_eval/{k}"] = v

            logging.info(
                "Finished epoch %d. Metrics:\n%s",
                self.state.epoch_id,
                format_metrics_for_print(epoch_scalars),
            )
            logger.log_metrics(epoch_scalars, self.state.epoch_id)
        logging.info("End of training")
        data_loader.terminate()
        logging.info("Exiting main funcion")