Example #1
def main(task, cfg):
    heyhi.setup_logging()
    logging.info("Cwd: %s", os.getcwd())
    logging.info("Task: %s", task)
    logging.info("Cfg:\n%s", cfg)
    heyhi.log_git_status()
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())

    if task not in TASKS:
        raise ValueError("Unknown task: %s. Known tasks: %s" % (task, sorted(TASKS)))
    return TASKS[task](cfg)
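The TASKS mapping that main() dispatches on is not shown in this example. Below is a minimal sketch of such a name-to-function registry, assuming the task functions live in cfvpy.tasks as in Example #5; the exact entries are illustrative, not the original definition.

# Inferred sketch of the TASKS registry; the real definition is not part of this example.
import cfvpy.tasks

TASKS = {
    "selfplay": cfvpy.tasks.selfplay,  # this task appears in Example #5
}

# main("selfplay", cfg) then resolves to TASKS["selfplay"](cfg);
# an unknown name raises ValueError listing sorted(TASKS).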
Example #2
def compute_exploitability(model_path, cfg_cfr_eval, cfr_binary="build/cfr"):
    root_dir = pathlib.Path(__file__).parent.parent.resolve()
    if cfg_cfr_eval.args.num_threads:
        num_threads = cfg_cfr_eval.args.num_threads
    else:
        num_threads = 10 if heyhi.is_on_slurm() else 40
    cmd = ("%s -t %d -linear -alternate -decompose -cfr_ave -cfvnet ") % (
        cfr_binary,
        num_threads,
    )
    for k, v in cfg_cfr_eval.args.items():
        if k == "num_threads":
            continue
        if v is True:
            cmd += f" -{k}"
        elif v is False:
            pass
        else:
            cmd += f" -{k} {v}"
    logging.debug("Going to run: %s", cmd)
    output = subprocess.check_output(
        cmd.split() + ["-model", str(model_path.resolve())], cwd=root_dir)
    values = []
    for line in output.decode("utf8").split("\n"):
        line = line.strip()
        if line.startswith("Summed Exploitability:"):
            values.append(float(line.split()[-1]))
    return values
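A usage sketch for compute_exploitability; the checkpoint path and the cfg object below are assumptions for illustration, not part of the example.

# Hypothetical call; the path and the cfg.cfr_eval node are illustrative only.
import pathlib

model_path = pathlib.Path("ckpt/epoch100.torchscript")
values = compute_exploitability(model_path, cfg.cfr_eval)
# Internally this shells out to roughly:
#   build/cfr -t <num_threads> -linear -alternate -decompose -cfr_ave -cfvnet \
#       [-flag or "-flag value" for each entry in cfg.cfr_eval.args] -model <model_path>
# and collects every "Summed Exploitability: <float>" line printed to stdout.
if values:
    print("Final exploitability:", values[-1])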
Example #3
def main(cfg):
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    resource.setrlimit(resource.RLIMIT_CORE,
                       (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    logging.info("resource.RLIMIT_CORE: %s", resource.RLIMIT_CORE)
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("is on slurm:%s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())

    task = getattr(cfvpy.tasks, cfg.task)
    return task(cfg)
Example #4
def get_ckpt_sync_dir():
    # TODO(akhti): check for multinode.
    if heyhi.is_on_slurm():
        return f"/scratch/slurm_tmpdir/{heyhi.get_slurm_job_id()}/ckpt_syncer/ckpt"
    else:
        # Use NFS. Slow, but at least don't have to clean or resolve conflicts.
        return "ckpt_syncer/ckpt"
Example #5
def main(cfg):
    print("HERHEHRHEHREHRHERHEHR")
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    #resource.setrlimit(
    #    resource.RLIMIT_CORE, (16, 16)
    #)
    logging.info("resource.RLIMIT_CORE: %s", resource.RLIMIT_CORE)
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("is on slurm:%s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())

    #task = getattr(cfvpy.tasks, cfg.task)
    #return task(cfg)
    return cfvpy.tasks.selfplay(cfg)
Example #6
    def __init__(self, cfg):

        self.cfg = cfg
        #self.device = cfg.device or "cuda"
        self.device = "cuda"
        ckpt_path = "."
        if heyhi.is_on_slurm():
            self.rank = int(os.environ["SLURM_PROCID"])
            self.is_master = self.rank == 0
            n_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
        else:
            self.rank = 0
            self.is_master = True
            n_nodes = 1
        logging.info(
            "Setup: is_master=%s n_nodes=%s rank=%s ckpt_path=%s",
            self.is_master,
            n_nodes,
            self.rank,
            ckpt_path,
        )

        self.num_actions = cfg.env.num_dice * cfg.env.num_faces * 2 + 1

        self.net = _build_model(self.device, self.cfg.env, self.cfg.model)
        if self.is_master:
            if self.cfg.load_checkpoint:
                logging.info("Loading checkpoint: %s",
                             self.cfg.load_checkpoint)
                self.net.load_state_dict(
                    torch.load(self.cfg.load_checkpoint),
                    strict=not self.cfg.load_checkpoint_loose,
                )
        if self.cfg.selfplay.data_parallel:
            logging.info("data parallel")
            assert self.cfg.selfplay.num_master_threads == 0
            self.net = torch.nn.DataParallel(self.net)
        else:
            logging.info("Single machine mode")

        self.train_timer = cfvpy.utils.MultiStopWatchTimer()

        if cfg.seed:
            logging.info("Setting pytorch random seed to %s", cfg.seed)
            torch.manual_seed(cfg.seed)
Example #7
    def run_trainer(self):
        # Fix version so that training always continues.
        if self.is_master:
            logger = pl_logging.TestTubeLogger(save_dir=os.getcwd(), version=0)

        # Storing the whole dict to preserve ref_models.
        datagen = self.initialize_datagen()
        context = datagen["context"]
        replay = datagen["replay"]
        policy_replay = datagen["policy_replay"]

        if self.cfg.data.train_preload:
            # Must preload data before starting generators to avoid deadlocks.
            _preload_data(self.cfg.data.train_preload, replay)
            preloaded_size = replay.size()
        else:
            preloaded_size = 0

        self.opt, self.policy_opt = self.configure_optimizers()
        self.scheduler = self.configure_scheduler(self.opt)

        context.start()

        if self.cfg.benchmark_data_gen:
            # Benchmark generation speed and exit.
            time.sleep(self.cfg.benchmark_data_gen)
            context.terminate()
            size = replay.num_add()
            logging.info("BENCHMARK size %s speed %.2f", size,
                         size / context.running_time)
            return

        train_size = self.cfg.data.train_epoch_size or 128 * 1000
        logging.info("Train set size (forced): %s", train_size)

        assert self.cfg.data.train_batch_size
        batch_size = self.cfg.data.train_batch_size
        epoch_size = train_size // batch_size

        if self.is_master:
            val_datasets = []

        logging.info(
            "model size is %s",
            sum(p.numel() for p in self.net.parameters() if p.requires_grad),
        )
        save_dir = pathlib.Path("ckpt")
        if self.is_master and not save_dir.exists():
            logging.info(f"Creating savedir: {save_dir}")
            save_dir.mkdir(parents=True)

        burn_in_frames = batch_size * 2
        while replay.size() < burn_in_frames or (
                policy_replay is not None
                and policy_replay.size() < burn_in_frames):
            logging.info("warming up replay buffer: %d/%d", replay.size(),
                         burn_in_frames)
            if policy_replay is not None:
                logging.info(
                    "warming up POLICY replay buffer: %d/%d",
                    policy_replay.size(),
                    burn_in_frames,
                )
            time.sleep(30)

        def compute_gen_bps():
            return ((replay.num_add() - preloaded_size) /
                    context.running_time / batch_size)

        def compute_gen_bps_policy():
            return policy_replay.num_add() / context.running_time / batch_size

        metrics = None
        num_decays = 0
        for epoch in range(self.cfg.max_epochs):
            self.train_timer.start("start")
            if (epoch % self.cfg.decrease_lr_every
                    == self.cfg.decrease_lr_every - 1
                    and self.scheduler is None):
                if (not self.cfg.decrease_lr_times
                        or num_decays < self.cfg.decrease_lr_times):
                    for param_group in self.opt.param_groups:
                        param_group["lr"] /= 2
                    num_decays += 1
            if (self.cfg.create_validation_set_every and self.is_master
                    and epoch % self.cfg.create_validation_set_every == 0):
                logging.info("Adding new validation set")
                val_batches = [
                    replay.sample(batch_size, "cpu")[0]
                    for _ in range(512 * 100 // batch_size)
                ]
                val_datasets.append(
                    (f"valid_snapshot_{epoch:04d}", val_batches))

            if (self.cfg.selfplay.dump_dataset_every_epochs and
                    epoch % self.cfg.selfplay.dump_dataset_every_epochs == 0
                    and (not self.cfg.data.train_preload or epoch > 0)):
                dataset_folder = pathlib.Path("dumped_data").resolve()
                dataset_folder.mkdir(exist_ok=True, parents=True)
                dataset_path = dataset_folder / f"data_{epoch:03d}.dat"
                logging.info(
                    "Saving replay buffer as supervised dataset to %s",
                    dataset_path)
                replay.save(str(dataset_path))

            metrics = {}
            metrics["optim/lr"] = next(iter(self.opt.param_groups))["lr"]
            metrics["epoch"] = epoch
            counters = collections.defaultdict(cfvpy.utils.FractionCounter)
            if self.cfg.grad_clip:
                counters["optim/grad_max"] = cfvpy.utils.MaxCounter()
                if self.cfg.train_policy:
                    counters["optim_policy/grad_max"] = cfvpy.utils.MaxCounter(
                    )
            use_progress_bar = not heyhi.is_on_slurm(
            ) or self.cfg.show_progress_bar
            train_loader = range(epoch_size)
            train_device = self.device
            train_iter = tqdm.tqdm(
                train_loader) if use_progress_bar else train_loader
            training_start = time.time()

            if self.cfg.train_gen_ratio:
                while True:
                    if (replay.num_add() * self.cfg.train_gen_ratio
                            >= train_size * (epoch + 1)):
                        break
                    logging.info(
                        "Throttling to satisfy |replay| * ratio >= train_size * epochs:"
                        " %s * %s >= %s %s",
                        replay.num_add(),
                        self.cfg.train_gen_ratio,
                        train_size,
                        epoch + 1,
                    )
                    time.sleep(60)
            assert self.cfg.replay.use_priority is False, "Not supported"

            value_loss = policy_loss = 0  # For progress bar.
            for iter_id in train_iter:
                self.train_timer.start("train-get_batch")
                use_policy_net = iter_id % 2 and policy_replay is not None
                if use_policy_net:
                    batch, _ = policy_replay.sample(batch_size, train_device)
                    suffix = "_policy"
                else:
                    batch, _ = replay.sample(batch_size, train_device)
                    suffix = ""
                self.train_timer.start("train-forward")
                self.net.train()
                loss_dict = self._compute_loss_dict(batch,
                                                    train_device,
                                                    use_policy_net,
                                                    timer_prefix="train-")
                self.train_timer.start("train-backward")
                loss = loss_dict["loss"]
                opt = self.policy_opt if use_policy_net else self.opt
                params = (self.get_policy_params()
                          if use_policy_net else self.get_value_params())
                opt.zero_grad()
                loss.backward()

                if self.cfg.grad_clip:
                    g_norm = clip_grad_norm_(params, self.cfg.grad_clip)
                else:
                    g_norm = None
                opt.step()
                loss.item()  # Force sync.
                self.train_timer.start("train-rest")
                if g_norm is not None:
                    g_norm = g_norm.item()
                    counters[f"optim{suffix}/grad_max"].update(g_norm)
                    counters[f"optim{suffix}/grad_mean"].update(g_norm)
                    counters[f"optim{suffix}/grad_clip_ratio"].update(
                        int(g_norm >= self.cfg.grad_clip - 1e-5))
                counters[f"loss{suffix}/train"].update(loss)
                for num_cards, partial_data in loss_dict["partials"].items():
                    counters[f"loss{suffix}/train_{num_cards}"].update(
                        partial_data["loss_sum"],
                        partial_data["count"],
                    )
                    counters[f"val{suffix}/train_{num_cards}"].update(
                        partial_data["val_sum"],
                        partial_data["count"],
                    )
                    counters[f"shares{suffix}/train_{num_cards}"].update(
                        partial_data["count"], batch_size)

                if use_progress_bar:
                    if use_policy_net:
                        policy_loss = loss.detach().item()
                    else:
                        value_loss = loss.detach().item()
                    pbar_fields = dict(
                        policy_loss=policy_loss,
                        value_loss=value_loss,
                        buffer_size=replay.size(),
                        gen_bps=compute_gen_bps(),
                    )
                    if policy_replay is not None:
                        pbar_fields["pol_buffer_size"] = policy_replay.size()
                    train_iter.set_postfix(**pbar_fields)
                if self.cfg.fake_training:
                    # Generation benchmarking mode in which training is
                    # skipped. The goal is to measure generation speed without
                    # sample() calls.
                    break
            if self.cfg.fake_training:
                # Fake training epoch takes a minute.
                time.sleep(60)

            if len(train_loader) > 0:
                metrics["bps/train"] = len(train_loader) / (time.time() -
                                                            training_start)
                metrics[
                    "bps/train_examples"] = metrics["bps/train"] * batch_size
            logging.info(
                "[Train] epoch %d complete, avg error is %f",
                epoch,
                counters["loss/train"].value(),
            )
            if self.scheduler is not None:
                self.scheduler.step()
            for name, counter in counters.items():
                metrics[name] = counter.value()
            metrics["buffer/size"] = replay.size()
            metrics["buffer/added"] = replay.num_add()
            metrics["bps/gen"] = compute_gen_bps()
            metrics["bps/gen_examples"] = metrics["bps/gen"] * batch_size
            if policy_replay is not None:
                metrics["buffer/policy_size"] = policy_replay.size()
                metrics["buffer/policy_added"] = policy_replay.num_add()
                metrics["bps/gen_policy"] = compute_gen_bps_policy()
                metrics["bps/gen_policy_examples"] = (
                    metrics["bps/gen_policy"] * batch_size)

            if (epoch + 1) % self.cfg.selfplay.network_sync_epochs == 0 or epoch < 15:
                logging.info("Copying current network to the eval network")
                for model_locker in datagen["model_lockers"]:
                    model_locker.update_model(self.get_model())
            if self.cfg.purging_epochs and (epoch + 1) in self.cfg.purging_epochs:
                new_size = max(
                    burn_in_frames,
                    int((self.cfg.purging_share_keep or 0.0) * replay.size()),
                )
                logging.info(
                    "Going to purge everything but %d elements in the buffer",
                    new_size,
                )
                replay.pop_until(new_size)

            if self.is_master and epoch % 10 == 0:
                with torch.no_grad():
                    for i, (name, val_loader) in enumerate(val_datasets):
                        self.train_timer.start("valid-acc-extra")
                        eval_errors = []
                        val_iter = (tqdm.tqdm(val_loader, desc="Eval")
                                    if use_progress_bar else val_loader)
                        for data in val_iter:
                            self.net.eval()
                            loss = self._compute_loss_dict(
                                data, train_device,
                                use_policy_net=False)["loss"]
                            eval_errors.append(loss.detach().item())
                        current_error = sum(eval_errors) / len(eval_errors)
                        logging.info(
                            "[Eval] epoch %d complete, data is %s, avg error is %f",
                            epoch,
                            name,
                            current_error,
                        )
                        metrics[f"loss/{name}"] = current_error

                self.train_timer.start("valid-trace")
                ckpt_path = save_dir / f"epoch{epoch}.ckpt"
                torch.save(self.get_model().state_dict(), ckpt_path)
                bin_path = ckpt_path.with_suffix(".torchscript")
                torch.jit.save(torch.jit.script(self.get_model()),
                               str(bin_path))

                self.train_timer.start("valid-exploit")
                if self.cfg.exploit and epoch % 20 == 0:
                    bin_path = pathlib.Path("tmp.torchscript")
                    torch.jit.save(torch.jit.script(self.get_model()),
                                   str(bin_path))
                    (
                        exploitability,
                        mse_net_traverse,
                        mse_fp_traverse,
                    ) = cfvpy.rela.compute_stats_with_net(
                        create_mdp_config(self.cfg.env), str(bin_path))
                    logging.info("Exploitability to leaf (epoch=%d): %.2f",
                                 epoch, exploitability)
                    metrics["exploitability_last"] = exploitability
                    metrics["eval_mse/net_reach"] = mse_net_traverse
                    metrics["eval_mse/fp_reach"] = mse_fp_traverse

            if len(train_loader) > 0:
                metrics["bps/loop"] = len(train_loader) / (time.time() -
                                                           training_start)
            total = 1e-5
            for k, v in self.train_timer.timings.items():
                metrics[f"timing/{k}"] = v / (epoch + 1)
                total += v
            for k, v in self.train_timer.timings.items():
                metrics[f"timing_pct/{k}"] = v * 100 / total
            logging.info("Metrics: %s", metrics)
            if self.is_master:
                logger.log_metrics(metrics)
                logger.save()
        return metrics
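The counter classes used throughout run_trainer are not defined in this example. Inferred purely from how they are called above, a minimal sketch of what cfvpy.utils.FractionCounter and cfvpy.utils.MaxCounter could look like follows; this is an assumption, not the real implementation.

# Inferred sketch; the real classes live in cfvpy.utils.
class FractionCounter:
    """Running weighted average: update(total, count) accumulates, value() divides."""

    def __init__(self):
        self.numerator = 0.0
        self.denominator = 0.0

    def update(self, total, count=1):
        self.numerator += float(total)
        self.denominator += count

    def value(self):
        return self.numerator / max(self.denominator, 1e-9)


class MaxCounter:
    """Tracks the maximum value passed to update()."""

    def __init__(self):
        self.max_value = float("-inf")

    def update(self, value):
        self.max_value = max(self.max_value, float(value))

    def value(self):
        return self.max_value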