def load_model(self, load_path):
    ckpt = torch.load(load_path, map_location="cpu")
    self.ac.pi.load_state_dict(ckpt["actor"])
    self.pi_optimizer = Adam(
        trainable_parameters(self.ac.pi), lr=self.pi_lr, eps=1e-8
    )
    self.pi_optimizer.load_state_dict(ckpt["pi_optimizer"])
    if "entropy_coeff" in ckpt:
        self.entropy_coeff = ckpt["entropy_coeff"]
    if ckpt["nagents"] == self.nagents:
        # Same agent count: the centralized critic is compatible, so
        # restore both its weights and the optimizer state.
        self.ac.v.load_state_dict(ckpt["critic"])
        self.vf_optimizer = Adam(
            trainable_parameters(self.ac.v), lr=self.vf_lr, eps=1e-8
        )
        self.vf_optimizer.load_state_dict(ckpt["vf_optimizer"])
    else:
        # Different agent count: reinitialize the critic optimizer.
        self.vf_optimizer = Adam(
            trainable_parameters(self.ac.v), lr=self.vf_lr, eps=1e-8
        )
        self.logger.log("The checkpoint was trained with a different nagents")
        if (
            "permutation_invariant" in self.ac_params
            and self.ac_params["permutation_invariant"]
        ):
            # A permutation-invariant critic doesn't depend on nagents,
            # so its weights and optimizer state can still be restored.
            self.ac.v.load_state_dict(ckpt["critic"])
            self.vf_optimizer.load_state_dict(ckpt["vf_optimizer"])
            self.logger.log(
                "Agent doesn't depend on nagents, so continuing fine-tuning"
            )
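# For reference: `load_model` above expects the checkpoint dict to carry the
# keys "actor", "pi_optimizer", "critic", "vf_optimizer", "nagents", and
# optionally "entropy_coeff". The sketch below is a hypothetical writer for
# that layout, assuming the attribute names used in this class; it is not
# the repository's actual save method.
def save_model_sketch(self, save_path):
    torch.save(
        {
            "actor": self.ac.pi.state_dict(),
            "critic": self.ac.v.state_dict(),
            "pi_optimizer": self.pi_optimizer.state_dict(),
            "vf_optimizer": self.vf_optimizer.state_dict(),
            "nagents": self.nagents,
            "entropy_coeff": self.entropy_coeff,
        },
        save_path,
    )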
def load_model(self, load_path):
    # Variant that restores only the actor and its optimizer state.
    ckpt = torch.load(load_path, map_location="cpu")
    self.actor.load_state_dict(ckpt["actor"])
    self.pi_optimizer = Adam(
        trainable_parameters(self.actor), lr=self.pi_lr, eps=1e-8
    )
    self.pi_optimizer.load_state_dict(ckpt["pi_optimizer"])
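# `trainable_parameters` is used above but not defined in this section. A
# minimal sketch, assuming it simply filters out frozen parameters (the
# actual helper may differ):
def trainable_parameters_sketch(module):
    # Yield only parameters that require gradients, so frozen layers
    # are excluded from the optimizer.
    return (p for p in module.parameters() if p.requires_grad)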
def __init__(
    self,
    env,
    env_params: dict,
    log_dir: str,
    ac_kwargs: dict = {},
    seed: int = 0,
    steps_per_epoch: int = 4000,
    epochs: int = 50,
    gamma: float = 0.99,
    clip_ratio: float = 0.2,
    pi_lr: float = 3e-4,
    vf_lr: float = 1e-3,
    train_iters: int = 100,
    entropy_coeff: float = 1e-2,
    lam: float = 0.97,
    target_kl: float = 0.01,
    save_freq: int = 10,
    load_path=None,
    render_train: bool = False,
    wandb_id: Optional[str] = None,
    **kwargs,
):
    self.log_dir = log_dir
    self.render_dir = os.path.join(log_dir, "renders")
    self.ckpt_dir = os.path.join(log_dir, "checkpoints")
    if hvd.rank() == 0:
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.render_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)
    self.softlink = os.path.abspath(
        os.path.join(self.ckpt_dir, "ckpt_latest.pth")
    )
    self.ac_params_file = os.path.join(log_dir, "ac_params.json")
    hparams = convert_json(locals())
    self.logger = EpochLogger(output_dir=self.log_dir, exp_name=wandb_id)

    if torch.cuda.is_available():
        # Horovod: pin GPU to local rank.
        dev_id = int(
            torch.cuda.device_count() * hvd.local_rank() / hvd.local_size()
        )
        torch.cuda.set_device(dev_id)
        device = torch.device(f"cuda:{dev_id}")
        torch.cuda.manual_seed(seed)
    else:
        device = torch.device("cpu")

    # env_params.update({"device": device})
    self.env = env(**env_params)
    self.ac_params = {k: v for k, v in ac_kwargs.items()}
    self.ac_params.update({
        "observation_space": self.env.observation_space,
        "action_space": self.env.action_space,
        "nagents": self.env.nagents,
    })

    self.entropy_coeff = entropy_coeff
    self.entropy_coeff_decay = entropy_coeff / epochs

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    torch.save(self.ac_params, self.ac_params_file)

    if os.path.isfile(self.softlink):
        self.logger.log("Restarting from latest checkpoint", color="red")
        load_path = self.softlink

    # Random seed
    seed += 10000 * hvd.rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    self.nagents = self.env.nagents
    self.ac = PPOLidarActorCritic(
        self.env.observation_space,
        self.env.action_space,
        nagents=self.nagents,
        centralized=True,
        **ac_kwargs,
    )

    self.device = device
    self.pi_lr = pi_lr
    self.vf_lr = vf_lr

    self.load_path = load_path
    if load_path is not None:
        self.load_model(load_path)
    else:
        self.pi_optimizer = Adam(
            trainable_parameters(self.ac.pi), lr=self.pi_lr, eps=1e-8
        )
        self.vf_optimizer = Adam(
            trainable_parameters(self.ac.v), lr=self.vf_lr, eps=1e-8
        )

    # Sync params across processes
    hvd.broadcast_parameters(self.ac.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(self.pi_optimizer, root_rank=0)
    hvd.broadcast_optimizer_state(self.vf_optimizer, root_rank=0)

    self.ac = self.ac.to(device)
    self.move_optimizer_to_device(self.pi_optimizer)
    self.move_optimizer_to_device(self.vf_optimizer)

    if hvd.rank() == 0:
        if wandb_id is None:
            eid = (
                log_dir.split("/")[-2]
                if load_path is None
                else load_path.split("/")[-4]
            )
        else:
            eid = wandb_id
        wandb.init(
            name=eid,
            id=eid,
            project="Social Driving",
            resume=load_path is not None,
        )
        wandb.watch_called = False

        if "self" in hparams:
            del hparams["self"]
        wandb.config.update(hparams, allow_val_change=True)

        wandb.watch(self.ac.pi, log="all")
        wandb.watch(self.ac.v, log="all")

    # Count variables
    var_counts = tuple(
        count_vars(module) for module in [self.ac.pi, self.ac.v]
    )
    self.logger.log(
        "\nNumber of parameters: \t pi: %d, \t v: %d\n" % var_counts,
        color="green",
    )

    # Set up experience buffer
    self.steps_per_epoch = steps_per_epoch
    self.local_steps_per_epoch = int(steps_per_epoch / hvd.size())
    self.buf = CentralizedPPOBuffer(
        self.env.observation_space[0].shape,
        self.env.observation_space[1].shape,
        self.env.action_space.shape,
        self.local_steps_per_epoch,
        gamma,
        lam,
        self.env.nagents,
        device=self.device,
    )

    self.gamma = gamma
    self.clip_ratio = clip_ratio
    self.train_iters = train_iters
    self.target_kl = target_kl
    self.epochs = epochs
    self.save_freq = save_freq
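# `move_optimizer_to_device` is called in __init__ above but not defined in
# this section. A minimal sketch, assuming the standard PyTorch pattern of
# moving per-parameter optimizer state tensors (e.g. Adam's exp_avg buffers,
# left on CPU by torch.load(..., map_location="cpu")) onto self.device; the
# actual helper may differ.
def move_optimizer_to_device_sketch(self, optimizer):
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(self.device)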