def main(task, cfg):
    heyhi.setup_logging()
    logging.info("Cwd: %s", os.getcwd())
    logging.info("Task: %s", task)
    logging.info("Cfg:\n%s", cfg)
    heyhi.log_git_status()
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())
    if task not in TASKS:
        raise ValueError("Unknown task: %s. Known tasks: %s" % (task, sorted(TASKS)))
    return TASKS[task](cfg)
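# Usage sketch: TASKS is assumed to be a module-level dict mapping task names
# to callables. The "demo" name and demo_task() below are hypothetical, purely
# to illustrate how main() dispatches a registered task.
#
#   def demo_task(cfg):
#       logging.info("Running demo task with cfg:\n%s", cfg)
#
#   TASKS["demo"] = demo_task
#   main("demo", cfg)  # -> dispatches to demo_task(cfg)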
def compute_exploitability(model_path, cfg_cfr_eval, cfr_binary="build/cfr"):
    """Run the external CFR binary on a model and parse exploitability values."""
    root_dir = pathlib.Path(__file__).parent.parent.resolve()
    if cfg_cfr_eval.args.num_threads:
        num_threads = cfg_cfr_eval.args.num_threads
    else:
        num_threads = 10 if heyhi.is_on_slurm() else 40
    cmd = "%s -t %d -linear -alternate -decompose -cfr_ave -cfvnet " % (
        cfr_binary,
        num_threads,
    )
    # Forward the remaining config args as CLI flags; booleans become
    # presence/absence flags.
    for k, v in cfg_cfr_eval.args.items():
        if k == "num_threads":
            continue
        if v is True:
            cmd += f" -{k}"
        elif v is False:
            pass
        else:
            cmd += f" -{k} {v}"
    logging.debug("Going to run: %s", cmd)
    output = subprocess.check_output(
        cmd.split() + ["-model", str(model_path.resolve())], cwd=root_dir)
    values = []
    for line in output.decode("utf8").split("\n"):
        line = line.strip()
        if line.startswith("Summed Exploitability:"):
            values.append(float(line.split()[-1]))
    return values
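# Example (hypothetical paths and config layout): evaluate a dumped TorchScript
# model with the CFR binary and log the parsed exploitability values. The
# cfg.cfr_eval attribute is an assumption that mirrors what
# compute_exploitability() reads above.
#
#   model = pathlib.Path("ckpt/epoch100.torchscript")
#   values = compute_exploitability(model, cfg.cfr_eval)
#   logging.info("Summed exploitability per report line: %s", values)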
def main(cfg):
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    # Allow unlimited core dumps for post-mortem debugging.
    resource.setrlimit(
        resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    logging.info("resource.RLIMIT_CORE: %s", resource.getrlimit(resource.RLIMIT_CORE))
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())
    task = getattr(cfvpy.tasks, cfg.task)
    return task(cfg)
def get_ckpt_sync_dir():
    # TODO(akhti): check for multinode.
    if heyhi.is_on_slurm():
        # Node-local scratch space tied to the SLURM job id.
        return f"/scratch/slurm_tmpdir/{heyhi.get_slurm_job_id()}/ckpt_syncer/ckpt"
    else:
        # Use NFS. Slow, but at least we don't have to clean up or resolve conflicts.
        return "ckpt_syncer/ckpt"
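# Usage sketch (assumption: the checkpoint syncer calls this on both master and
# workers so they agree on a shared location; "latest.ckpt" is an illustrative
# file name, not a name taken from this codebase):
#
#   sync_dir = get_ckpt_sync_dir()
#   os.makedirs(sync_dir, exist_ok=True)
#   torch.save(net.state_dict(), os.path.join(sync_dir, "latest.ckpt"))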
def main(cfg):
    heyhi.setup_logging()
    logging.info("CWD: %s", os.getcwd())
    logging.info("cfg:\n%s", cfg.pretty())
    logging.info("resource.RLIMIT_CORE: %s", resource.getrlimit(resource.RLIMIT_CORE))
    heyhi.log_git_status()
    logging.info("Is AWS: %s", heyhi.is_aws())
    logging.info("Is on slurm: %s", heyhi.is_on_slurm())
    if heyhi.is_on_slurm():
        logging.info("Slurm job id: %s", heyhi.get_slurm_job_id())
    logging.info("Is master: %s", heyhi.is_master())
    # This entry point is hard-wired to selfplay; the generic
    # getattr(cfvpy.tasks, cfg.task) dispatch is bypassed here.
    return cfvpy.tasks.selfplay(cfg)
def __init__(self, cfg):
    self.cfg = cfg
    # Training currently assumes a CUDA device.
    self.device = "cuda"
    ckpt_path = "."
    if heyhi.is_on_slurm():
        self.rank = int(os.environ["SLURM_PROCID"])
        self.is_master = self.rank == 0
        n_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
    else:
        self.rank = 0
        self.is_master = True
        n_nodes = 1
    logging.info(
        "Setup: is_master=%s n_nodes=%s rank=%s ckpt_path=%s",
        self.is_master,
        n_nodes,
        self.rank,
        ckpt_path,
    )
    self.num_actions = cfg.env.num_dice * cfg.env.num_faces * 2 + 1
    self.net = _build_model(self.device, self.cfg.env, self.cfg.model)
    if self.is_master:
        if self.cfg.load_checkpoint:
            logging.info("Loading checkpoint: %s", self.cfg.load_checkpoint)
            self.net.load_state_dict(
                torch.load(self.cfg.load_checkpoint),
                strict=not self.cfg.load_checkpoint_loose,
            )
        if self.cfg.selfplay.data_parallel:
            logging.info("Data parallel mode")
            assert self.cfg.selfplay.num_master_threads == 0
            self.net = torch.nn.DataParallel(self.net)
        else:
            logging.info("Single machine mode")
    self.train_timer = cfvpy.utils.MultiStopWatchTimer()
    if cfg.seed:
        logging.info("Setting pytorch random seed to %s", cfg.seed)
        torch.manual_seed(cfg.seed)
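# Minimal standalone sketch of the rank/master detection above. SLURM_PROCID
# and SLURM_JOB_NUM_NODES are standard SLURM environment variables; the
# fallback defaults are illustrative for local (non-SLURM) runs.
#
#   rank = int(os.environ.get("SLURM_PROCID", 0))
#   n_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
#   is_master = rank == 0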
def run_trainer(self):
    # Fix version so that training always continues.
    if self.is_master:
        logger = pl_logging.TestTubeLogger(save_dir=os.getcwd(), version=0)
    # Storing the whole dict to preserve ref_models.
    datagen = self.initialize_datagen()
    context = datagen["context"]
    replay = datagen["replay"]
    policy_replay = datagen["policy_replay"]
    if self.cfg.data.train_preload:
        # Must preload data before starting generators to avoid deadlocks.
        _preload_data(self.cfg.data.train_preload, replay)
        preloaded_size = replay.size()
    else:
        preloaded_size = 0
    self.opt, self.policy_opt = self.configure_optimizers()
    self.scheduler = self.configure_scheduler(self.opt)
    context.start()
    if self.cfg.benchmark_data_gen:
        # Benchmark generation speed and exit.
        time.sleep(self.cfg.benchmark_data_gen)
        context.terminate()
        size = replay.num_add()
        logging.info("BENCHMARK size %s speed %.2f", size, size / context.running_time)
        return
    train_size = self.cfg.data.train_epoch_size or 128 * 1000
    logging.info("Train set size (forced): %s", train_size)
    assert self.cfg.data.train_batch_size
    batch_size = self.cfg.data.train_batch_size
    epoch_size = train_size // batch_size
    if self.is_master:
        val_datasets = []
    logging.info(
        "model size is %s",
        sum(p.numel() for p in self.net.parameters() if p.requires_grad),
    )
    save_dir = pathlib.Path("ckpt")
    if self.is_master and not save_dir.exists():
        logging.info(f"Creating savedir: {save_dir}")
        save_dir.mkdir(parents=True)
    burn_in_frames = batch_size * 2
    while replay.size() < burn_in_frames or (
            policy_replay is not None and policy_replay.size() < burn_in_frames):
        logging.info("warming up replay buffer: %d/%d", replay.size(), burn_in_frames)
        if policy_replay is not None:
            logging.info(
                "warming up POLICY replay buffer: %d/%d",
                policy_replay.size(),
                burn_in_frames,
            )
        time.sleep(30)

    def compute_gen_bps():
        return (replay.num_add() - preloaded_size) / context.running_time / batch_size

    def compute_gen_bps_policy():
        return policy_replay.num_add() / context.running_time / batch_size

    metrics = None
    num_decays = 0
    for epoch in range(self.cfg.max_epochs):
        self.train_timer.start("start")
        # Manual LR halving is used only when no scheduler is configured.
        if (epoch % self.cfg.decrease_lr_every == self.cfg.decrease_lr_every - 1
                and self.scheduler is None):
            if (not self.cfg.decrease_lr_times
                    or num_decays < self.cfg.decrease_lr_times):
                for param_group in self.opt.param_groups:
                    param_group["lr"] /= 2
                num_decays += 1
        if (self.cfg.create_validation_set_every and self.is_master
                and epoch % self.cfg.create_validation_set_every == 0):
            logging.info("Adding new validation set")
            val_batches = [
                replay.sample(batch_size, "cpu")[0]
                for _ in range(512 * 100 // batch_size)
            ]
            val_datasets.append((f"valid_snapshot_{epoch:04d}", val_batches))
        if (self.cfg.selfplay.dump_dataset_every_epochs
                and epoch % self.cfg.selfplay.dump_dataset_every_epochs == 0
                and (not self.cfg.data.train_preload or epoch > 0)):
            dataset_folder = pathlib.Path("dumped_data").resolve()
            dataset_folder.mkdir(exist_ok=True, parents=True)
            dataset_path = dataset_folder / f"data_{epoch:03d}.dat"
            logging.info(
                "Saving replay buffer as supervised dataset to %s", dataset_path)
            replay.save(str(dataset_path))
        metrics = {}
        metrics["optim/lr"] = next(iter(self.opt.param_groups))["lr"]
        metrics["epoch"] = epoch
        counters = collections.defaultdict(cfvpy.utils.FractionCounter)
        if self.cfg.grad_clip:
            counters["optim/grad_max"] = cfvpy.utils.MaxCounter()
            if self.cfg.train_policy:
                counters["optim_policy/grad_max"] = cfvpy.utils.MaxCounter()
        use_progress_bar = not heyhi.is_on_slurm() or self.cfg.show_progress_bar
        train_loader = range(epoch_size)
        train_device = self.device
        train_iter = tqdm.tqdm(train_loader) if use_progress_bar else train_loader
        training_start = time.time()
        if self.cfg.train_gen_ratio:
            # Throttle training so it never runs ahead of data generation.
            while True:
                if (replay.num_add() * self.cfg.train_gen_ratio
                        >= train_size * (epoch + 1)):
                    break
                logging.info(
                    "Throttling to satisfy |replay| * ratio >= train_size * epochs:"
                    " %s * %s >= %s * %s",
                    replay.num_add(),
                    self.cfg.train_gen_ratio,
                    train_size,
                    epoch + 1,
                )
                time.sleep(60)
        assert self.cfg.replay.use_priority is False, "Not supported"
        value_loss = policy_loss = 0  # For the progress bar.
        for iter_id in train_iter:
            self.train_timer.start("train-get_batch")
            # Alternate between value and policy batches when a policy buffer exists.
            use_policy_net = iter_id % 2 and policy_replay is not None
            if use_policy_net:
                batch, _ = policy_replay.sample(batch_size, train_device)
                suffix = "_policy"
            else:
                batch, _ = replay.sample(batch_size, train_device)
                suffix = ""
            self.train_timer.start("train-forward")
            self.net.train()
            loss_dict = self._compute_loss_dict(
                batch, train_device, use_policy_net, timer_prefix="train-")
            self.train_timer.start("train-backward")
            loss = loss_dict["loss"]
            opt = self.policy_opt if use_policy_net else self.opt
            params = (self.get_policy_params()
                      if use_policy_net else self.get_value_params())
            opt.zero_grad()
            loss.backward()
            if self.cfg.grad_clip:
                g_norm = clip_grad_norm_(params, self.cfg.grad_clip)
            else:
                g_norm = None
            opt.step()
            loss.item()  # Force sync.
            self.train_timer.start("train-rest")
            if g_norm is not None:
                g_norm = g_norm.item()
                counters[f"optim{suffix}/grad_max"].update(g_norm)
                counters[f"optim{suffix}/grad_mean"].update(g_norm)
                counters[f"optim{suffix}/grad_clip_ratio"].update(
                    int(g_norm >= self.cfg.grad_clip - 1e-5))
            counters[f"loss{suffix}/train"].update(loss)
            for num_cards, partial_data in loss_dict["partials"].items():
                counters[f"loss{suffix}/train_{num_cards}"].update(
                    partial_data["loss_sum"], partial_data["count"])
                counters[f"val{suffix}/train_{num_cards}"].update(
                    partial_data["val_sum"], partial_data["count"])
                counters[f"shares{suffix}/train_{num_cards}"].update(
                    partial_data["count"], batch_size)
            if use_progress_bar:
                if use_policy_net:
                    policy_loss = loss.detach().item()
                else:
                    value_loss = loss.detach().item()
                pbar_fields = dict(
                    policy_loss=policy_loss,
                    value_loss=value_loss,
                    buffer_size=replay.size(),
                    gen_bps=compute_gen_bps(),
                )
                if policy_replay is not None:
                    pbar_fields["pol_buffer_size"] = policy_replay.size()
                train_iter.set_postfix(**pbar_fields)
            if self.cfg.fake_training:
                # Generation benchmarking mode in which training is skipped.
                # The goal is to measure generation speed without sample() calls.
                break
        if self.cfg.fake_training:
            # A fake training epoch takes a minute.
            time.sleep(60)
        if len(train_loader) > 0:
            metrics["bps/train"] = len(train_loader) / (time.time() - training_start)
            metrics["bps/train_examples"] = metrics["bps/train"] * batch_size
        logging.info(
            "[Train] epoch %d complete, avg error is %f",
            epoch,
            counters["loss/train"].value(),
        )
        if self.scheduler is not None:
            self.scheduler.step()
        for name, counter in counters.items():
            metrics[name] = counter.value()
        metrics["buffer/size"] = replay.size()
        metrics["buffer/added"] = replay.num_add()
        metrics["bps/gen"] = compute_gen_bps()
        metrics["bps/gen_examples"] = metrics["bps/gen"] * batch_size
        if policy_replay is not None:
            metrics["buffer/policy_size"] = policy_replay.size()
            metrics["buffer/policy_added"] = policy_replay.num_add()
            metrics["bps/gen_policy"] = compute_gen_bps_policy()
            metrics["bps/gen_policy_examples"] = metrics["bps/gen_policy"] * batch_size
        if (epoch + 1) % self.cfg.selfplay.network_sync_epochs == 0 or epoch < 15:
            logging.info("Copying current network to the eval network")
            for model_locker in datagen["model_lockers"]:
                model_locker.update_model(self.get_model())
        if self.cfg.purging_epochs and (epoch + 1) in self.cfg.purging_epochs:
            new_size = max(
                burn_in_frames,
                int((self.cfg.purging_share_keep or 0.0) * replay.size()),
            )
            logging.info(
                "Going to purge everything but %d elements in the buffer", new_size)
            replay.pop_until(new_size)
        if self.is_master and epoch % 10 == 0:
            with torch.no_grad():
                for i, (name, val_loader) in enumerate(val_datasets):
                    self.train_timer.start("valid-acc-extra")
                    eval_errors = []
                    val_iter = (tqdm.tqdm(val_loader, desc="Eval")
                                if use_progress_bar else val_loader)
                    for data in val_iter:
                        self.net.eval()
                        loss = self._compute_loss_dict(
                            data, train_device, use_policy_net=False)["loss"]
                        eval_errors.append(loss.detach().item())
                    current_error = sum(eval_errors) / len(eval_errors)
                    logging.info(
                        "[Eval] epoch %d complete, data is %s, avg error is %f",
                        epoch,
                        name,
                        current_error,
                    )
                    metrics[f"loss/{name}"] = current_error
            self.train_timer.start("valid-trace")
            ckpt_path = save_dir / f"epoch{epoch}.ckpt"
            torch.save(self.get_model().state_dict(), ckpt_path)
            bin_path = ckpt_path.with_suffix(".torchscript")
            torch.jit.save(torch.jit.script(self.get_model()), str(bin_path))
            self.train_timer.start("valid-exploit")
            if self.cfg.exploit and epoch % 20 == 0:
                bin_path = pathlib.Path("tmp.torchscript")
                torch.jit.save(torch.jit.script(self.get_model()), str(bin_path))
                (
                    exploitability,
                    mse_net_traverse,
                    mse_fp_traverse,
                ) = cfvpy.rela.compute_stats_with_net(
                    create_mdp_config(self.cfg.env), str(bin_path))
                logging.info(
                    "Exploitability to leaf (epoch=%d): %.2f", epoch, exploitability)
                metrics["exploitability_last"] = exploitability
                metrics["eval_mse/net_reach"] = mse_net_traverse
                metrics["eval_mse/fp_reach"] = mse_fp_traverse
        if len(train_loader) > 0:
            metrics["bps/loop"] = len(train_loader) / (time.time() - training_start)
        total = 1e-5
        for k, v in self.train_timer.timings.items():
            metrics[f"timing/{k}"] = v / (epoch + 1)
            total += v
        for k, v in self.train_timer.timings.items():
            metrics[f"timing_pct/{k}"] = v * 100 / total
        logging.info("Metrics: %s", metrics)
        if self.is_master:
            logger.log_metrics(metrics)
            logger.save()
    return metrics
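# Usage sketch (hypothetical class name: the __init__/run_trainer pair above is
# assumed to belong to a Trainer-style class constructed from a parsed config,
# as the selfplay task presumably does):
#
#   trainer = Trainer(cfg)
#   metrics = trainer.run_trainer()
#   logging.info("Final metrics: %s", metrics)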