Example #1
0
 def load_model(self, load_path):
     """Restore actor/critic weights and optimizer state from a checkpoint.

     The actor (and its optimizer) is always restored.  The critic is only
     restored when the checkpoint was produced with the same ``nagents`` as
     this run, or when the critic is permutation invariant and therefore does
     not depend on the number of agents; otherwise training continues with a
     freshly initialized critic optimizer.

     Args:
         load_path: Path to a checkpoint dict saved by this trainer.
     """
     ckpt = torch.load(load_path, map_location="cpu")

     # The actor transfers across runs unconditionally.
     self.ac.pi.load_state_dict(ckpt["actor"])
     self.pi_optimizer = Adam(trainable_parameters(self.ac.pi),
                              lr=self.pi_lr,
                              eps=1e-8)
     self.pi_optimizer.load_state_dict(ckpt["pi_optimizer"])

     # Older checkpoints may predate the entropy coefficient being saved.
     if "entropy_coeff" in ckpt:
         self.entropy_coeff = ckpt["entropy_coeff"]

     # A critic optimizer is needed on every path, so build it once here.
     # Loading the critic's state dict afterwards is safe: load_state_dict
     # copies tensor data in place and does not replace the Parameter
     # objects this optimizer references.
     self.vf_optimizer = Adam(trainable_parameters(self.ac.v),
                              lr=self.vf_lr,
                              eps=1e-8)
     if ckpt["nagents"] == self.nagents:
         self.ac.v.load_state_dict(ckpt["critic"])
         self.vf_optimizer.load_state_dict(ckpt["vf_optimizer"])
     else:
         self.logger.log("The agent was trained with a different nagents")
         if self.ac_params.get("permutation_invariant"):
             # Critic architecture is agnostic to nagents, so its weights
             # and optimizer state are still valid for fine-tuning.
             self.ac.v.load_state_dict(ckpt["critic"])
             self.vf_optimizer.load_state_dict(ckpt["vf_optimizer"])
             self.logger.log(
                 "Agent doesn't depend on nagents. So continuing finetuning"
             )
Example #2
0
 def load_model(self, load_path):
     """Restore the actor's weights and optimizer state from ``load_path``."""
     checkpoint = torch.load(load_path, map_location="cpu")
     self.actor.load_state_dict(checkpoint["actor"])
     # Rebuild the optimizer over the (now restored) actor parameters,
     # then overwrite its state with the checkpointed one.
     optimizer = Adam(trainable_parameters(self.actor),
                      lr=self.pi_lr,
                      eps=1e-8)
     optimizer.load_state_dict(checkpoint["pi_optimizer"])
     self.pi_optimizer = optimizer
Example #3
0
    def __init__(
        self,
        env,
        env_params: dict,
        log_dir: str,
        ac_kwargs: dict = {},  # NOTE(review): mutable default — only read here, but prefer None
        seed: int = 0,
        steps_per_epoch: int = 4000,
        epochs: int = 50,
        gamma: float = 0.99,
        clip_ratio: float = 0.2,
        pi_lr: float = 3e-4,
        vf_lr: float = 1e-3,
        train_iters: int = 100,
        entropy_coeff: float = 1e-2,
        lam: float = 0.97,
        target_kl: float = 0.01,
        save_freq: int = 10,
        load_path=None,
        render_train: bool = False,
        wandb_id: Optional[str] = None,
        **kwargs,
    ):
        """Set up a Horovod-distributed, centralized-critic PPO trainer.

        Builds the environment, actor-critic networks, optimizers (optionally
        restored from a checkpoint), wandb logging (rank 0 only), and the
        experience buffer.

        Args:
            env: Environment class; instantiated as ``env(**env_params)``.
            env_params: Keyword arguments forwarded to ``env``.
            log_dir: Root directory for logs, renders, and checkpoints.
            ac_kwargs: Extra keyword arguments for ``PPOLidarActorCritic``;
                also merged into ``self.ac_params`` for serialization.
            seed: Base RNG seed; offset per Horovod rank below.
            steps_per_epoch: Total env steps per epoch across all workers.
            epochs: Number of training epochs (also drives the linear
                entropy-coefficient decay schedule).
            gamma: Discount factor for the GAE buffer.
            clip_ratio: PPO surrogate clipping parameter.
            pi_lr: Actor (policy) Adam learning rate.
            vf_lr: Critic (value function) Adam learning rate.
            train_iters: Gradient steps per epoch (stored for later use).
            entropy_coeff: Initial entropy bonus coefficient.
            lam: GAE lambda.
            target_kl: Early-stopping KL threshold (stored for later use).
            save_freq: Checkpoint frequency in epochs (stored for later use).
            load_path: Optional checkpoint to resume from; overridden by the
                latest-checkpoint softlink if one exists.
            render_train: Not used in this constructor — presumably consumed
                by training/rendering code elsewhere; TODO confirm.
            wandb_id: Optional wandb run id; also used as the logger exp name.
            **kwargs: Ignored here; accepted so broader configs can be passed.
        """
        # Output locations. Only rank 0 creates directories to avoid
        # filesystem races across Horovod workers.
        self.log_dir = log_dir
        self.render_dir = os.path.join(log_dir, "renders")
        self.ckpt_dir = os.path.join(log_dir, "checkpoints")
        if hvd.rank() == 0:
            os.makedirs(self.log_dir, exist_ok=True)
            os.makedirs(self.render_dir, exist_ok=True)
            os.makedirs(self.ckpt_dir, exist_ok=True)
        # Stable path that always points at the most recent checkpoint.
        self.softlink = os.path.abspath(
            os.path.join(self.ckpt_dir, f"ckpt_latest.pth"))
        self.ac_params_file = os.path.join(log_dir, "ac_params.json")
        # Snapshot of all constructor arguments for wandb.config below.
        # NOTE: this captures locals() at this exact point — do not add or
        # rename locals above this line without checking the logged config.
        hparams = convert_json(locals())
        self.logger = EpochLogger(output_dir=self.log_dir, exp_name=wandb_id)

        if torch.cuda.is_available():
            # Horovod: pin GPU to local rank.
            dev_id = int(torch.cuda.device_count() * hvd.local_rank() /
                         hvd.local_size())
            torch.cuda.set_device(dev_id)
            device = torch.device(f"cuda:{dev_id}")
            torch.cuda.manual_seed(seed)
        else:
            device = torch.device("cpu")

        #         env_params.update({"device": device})
        self.env = env(**env_params)
        # Serializable record of the actor-critic configuration (saved to
        # disk below so checkpoints can be rebuilt with matching params).
        self.ac_params = {k: v for k, v in ac_kwargs.items()}
        self.ac_params.update({
            "observation_space": self.env.observation_space,
            "action_space": self.env.action_space,
            "nagents": self.env.nagents,
        })

        # Entropy bonus decays linearly to ~0 over the full training run.
        self.entropy_coeff = entropy_coeff
        self.entropy_coeff_decay = entropy_coeff / epochs

        # Horovod: limit # of CPU threads to be used per worker.
        torch.set_num_threads(1)

        torch.save(self.ac_params, self.ac_params_file)

        # Resuming from the latest checkpoint takes precedence over any
        # explicitly supplied load_path.
        if os.path.isfile(self.softlink):
            self.logger.log("Restarting from latest checkpoint", color="red")
            load_path = self.softlink

        # Random seed — offset per rank so workers explore independently.
        seed += 10000 * hvd.rank()
        torch.manual_seed(seed)
        np.random.seed(seed)

        self.nagents = self.env.nagents
        # Centralized critic: the value function sees all agents' observations.
        self.ac = PPOLidarActorCritic(
            self.env.observation_space,
            self.env.action_space,
            nagents=self.nagents,
            centralized=True,
            **ac_kwargs,
        )

        self.device = device

        self.pi_lr = pi_lr
        self.vf_lr = vf_lr

        # Either restore networks + optimizers from a checkpoint, or build
        # fresh optimizers over the randomly initialized networks.
        self.load_path = load_path
        if load_path is not None:
            self.load_model(load_path)
        else:
            self.pi_optimizer = Adam(trainable_parameters(self.ac.pi),
                                     lr=self.pi_lr,
                                     eps=1e-8)
            self.vf_optimizer = Adam(trainable_parameters(self.ac.v),
                                     lr=self.vf_lr,
                                     eps=1e-8)

        # Sync params across processes (rank 0 is the source of truth),
        # then move everything to the chosen device.
        hvd.broadcast_parameters(self.ac.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(self.pi_optimizer, root_rank=0)
        hvd.broadcast_optimizer_state(self.vf_optimizer, root_rank=0)
        self.ac = self.ac.to(device)
        self.move_optimizer_to_device(self.pi_optimizer)
        self.move_optimizer_to_device(self.vf_optimizer)

        # wandb init/logging happens on rank 0 only.
        if hvd.rank() == 0:
            if wandb_id is None:
                # Derive an experiment id from the log/checkpoint path layout.
                eid = (log_dir.split("/")[-2]
                       if load_path is None else load_path.split("/")[-4])
            else:
                eid = wandb_id
            wandb.init(
                name=eid,
                id=eid,
                project="Social Driving",
                resume=load_path is not None,
            )
            wandb.watch_called = False

            # locals() snapshot includes `self`, which isn't JSON-friendly.
            if "self" in hparams:
                del hparams["self"]
            wandb.config.update(hparams, allow_val_change=True)

            wandb.watch(self.ac.pi, log="all")
            wandb.watch(self.ac.v, log="all")

        # Count variables
        var_counts = tuple(
            count_vars(module) for module in [self.ac.pi, self.ac.v])
        self.logger.log(
            "\nNumber of parameters: \t pi: %d, \t v: %d\n" % var_counts,
            color="green",
        )

        # Set up experience buffer — each worker collects its share of the
        # global steps_per_epoch.
        self.steps_per_epoch = steps_per_epoch
        self.local_steps_per_epoch = int(steps_per_epoch / hvd.size())
        self.buf = CentralizedPPOBuffer(
            self.env.observation_space[0].shape,
            self.env.observation_space[1].shape,
            self.env.action_space.shape,
            self.local_steps_per_epoch,
            gamma,
            lam,
            self.env.nagents,
            device=self.device,
        )

        self.gamma = gamma
        self.clip_ratio = clip_ratio
        self.train_iters = train_iters
        self.target_kl = target_kl
        self.epochs = epochs
        self.save_freq = save_freq