def train(self):
    pool = ReplayPool(max_pool_size=self.replay_pool_size,
                      observation_dim=self._observation_dim,
                      action_dim=self._action_dim)
    terminal = False
    observation = self.env.reset()
    path_length = 0
    path_return = 0
    itr = 0
    for epoch in range(self.n_epoch):
        print('Starting epoch #%d' % epoch)
        for epoch_itr in range(self.epoch_length):
            if terminal:
                print(path_return, path_length)
                observation = self.env.reset()
                path_length = 0
                path_return = 0
            # if self.render:
            #     self.env.render()
            action = self.policy.get_action(observation)
            next_observation, reward, terminal, _ = self.env.step(action)
            path_length += 1
            path_return += reward
            if not terminal and path_length >= self.epoch_length:
                terminal = True
            pool.add_sample(observation, action, reward, terminal)
            observation = next_observation
            if pool.size >= self.min_pool_size:
                batch = pool.random_batch(self.batch_size)
                self._do_training(itr, batch)
            itr += 1
class SAC_Agent:

    def __init__(self, seed, state_dim, action_dim, lr=3e-4, gamma=0.99, tau=5e-3,
                 batchsize=256, hidden_size=256, update_interval=1,
                 buffer_size=int(1e6), target_entropy=None):
        self.gamma = gamma
        self.tau = tau
        self.target_entropy = target_entropy if target_entropy else -action_dim
        self.batchsize = batchsize
        self.update_interval = update_interval

        torch.manual_seed(seed)

        # aka critic
        self.q_funcs = DoubleQFunc(state_dim, action_dim, hidden_size=hidden_size).to(device)
        self.target_q_funcs = copy.deepcopy(self.q_funcs)
        self.target_q_funcs.eval()
        for p in self.target_q_funcs.parameters():
            p.requires_grad = False

        # aka actor
        self.policy = Policy(state_dim, action_dim, hidden_size=hidden_size).to(device)

        # aka temperature
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)

        self.q_optimizer = torch.optim.Adam(self.q_funcs.parameters(), lr=lr)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.temp_optimizer = torch.optim.Adam([self.log_alpha], lr=lr)

        self.replay_pool = ReplayPool(action_dim=action_dim, state_dim=state_dim,
                                      capacity=int(buffer_size))

    def get_action(self, state, state_filter=None, deterministic=False):
        if state_filter:
            state = state_filter(state)
        with torch.no_grad():
            action, _, mean = self.policy(torch.Tensor(state).view(1, -1).to(device))
        if deterministic:
            return mean.squeeze().cpu().numpy()
        return np.atleast_1d(action.squeeze().cpu().numpy())

    def update_target(self):
        """moving average update of target networks"""
        with torch.no_grad():
            for target_q_param, q_param in zip(self.target_q_funcs.parameters(),
                                               self.q_funcs.parameters()):
                target_q_param.data.copy_(self.tau * q_param.data +
                                          (1.0 - self.tau) * target_q_param.data)

    def update_q_functions(self, state_batch, action_batch, reward_batch, nextstate_batch, done_batch):
        with torch.no_grad():
            nextaction_batch, logprobs_batch, _ = self.policy(nextstate_batch, get_logprob=True)
            q_t1, q_t2 = self.target_q_funcs(nextstate_batch, nextaction_batch)
            # take min to mitigate positive bias in q-function training
            q_target = torch.min(q_t1, q_t2)
            value_target = reward_batch + (1.0 - done_batch) * self.gamma * \
                (q_target - self.alpha * logprobs_batch)
        q_1, q_2 = self.q_funcs(state_batch, action_batch)
        loss_1 = F.mse_loss(q_1, value_target)
        loss_2 = F.mse_loss(q_2, value_target)
        return loss_1, loss_2

    def update_policy_and_temp(self, state_batch):
        action_batch, logprobs_batch, _ = self.policy(state_batch, get_logprob=True)
        q_b1, q_b2 = self.q_funcs(state_batch, action_batch)
        qval_batch = torch.min(q_b1, q_b2)
        policy_loss = (self.alpha * logprobs_batch - qval_batch).mean()
        temp_loss = -self.alpha * (logprobs_batch.detach() + self.target_entropy).mean()
        return policy_loss, temp_loss

    def optimize(self, n_updates, state_filter=None):
        q1_loss, q2_loss, pi_loss, a_loss = 0, 0, 0, 0
        for i in range(n_updates):
            samples = self.replay_pool.sample(self.batchsize)

            if state_filter:
                state_batch = torch.FloatTensor(state_filter(samples.state)).to(device)
                nextstate_batch = torch.FloatTensor(state_filter(samples.nextstate)).to(device)
            else:
                state_batch = torch.FloatTensor(samples.state).to(device)
                nextstate_batch = torch.FloatTensor(samples.nextstate).to(device)
            action_batch = torch.FloatTensor(samples.action).to(device)
            reward_batch = torch.FloatTensor(samples.reward).to(device).unsqueeze(1)
            done_batch = torch.FloatTensor(samples.real_done).to(device).unsqueeze(1)

            # update q-funcs
            q1_loss_step, q2_loss_step = self.update_q_functions(state_batch, action_batch,
                                                                 reward_batch, nextstate_batch,
                                                                 done_batch)
            q_loss_step = q1_loss_step + q2_loss_step
            self.q_optimizer.zero_grad()
            q_loss_step.backward()
            self.q_optimizer.step()

            # update policy and temperature parameter
            for p in self.q_funcs.parameters():
                p.requires_grad = False
            pi_loss_step, a_loss_step = self.update_policy_and_temp(state_batch)
            self.policy_optimizer.zero_grad()
            pi_loss_step.backward()
            self.policy_optimizer.step()
            self.temp_optimizer.zero_grad()
            a_loss_step.backward()
            self.temp_optimizer.step()
            for p in self.q_funcs.parameters():
                p.requires_grad = True

            q1_loss += q1_loss_step.detach().item()
            q2_loss += q2_loss_step.detach().item()
            pi_loss += pi_loss_step.detach().item()
            a_loss += a_loss_step.detach().item()

            if i % self.update_interval == 0:
                self.update_target()

        return q1_loss, q2_loss, pi_loss, a_loss

    @property
    def alpha(self):
        return self.log_alpha.exp()
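# --- Usage sketch (illustrative, not part of the original source) ---
# A minimal driver for SAC_Agent against a Gym-style environment with the classic
# 4-tuple step API, matching the train() loop above. The warm-up length, the update
# cadence, and the replay_pool.push(...) call are assumptions; the real ReplayPool
# insertion API is not shown in these listings.
def run_sac_example(env, total_steps=10_000, warmup_steps=1_000):
    agent = SAC_Agent(seed=0,
                      state_dim=env.observation_space.shape[0],
                      action_dim=env.action_space.shape[0])
    state = env.reset()
    for step in range(total_steps):
        # random warm-up exploration, then stochastic actions from the policy
        action = env.action_space.sample() if step < warmup_steps else agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.replay_pool.push((state, action, reward, next_state, done))  # assumed API
        state = env.reset() if done else next_state
        if step >= warmup_steps:
            q1_loss, q2_loss, pi_loss, a_loss = agent.optimize(n_updates=1)
    return agent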
class DOPE_Agent:

    def __init__(self,
                 seed: int,
                 state_dim: int,
                 action_dim: int,
                 action_lim: int = 1,
                 lr: float = 3e-4,
                 gamma: float = 0.99,
                 tau: float = 5e-3,
                 batchsize: int = 256,
                 hidden_size: int = 256,
                 update_interval: int = 2,
                 buffer_size: int = int(1e6),
                 target_noise: float = 0.2,
                 target_noise_clip: float = 0.5,
                 explore_noise: float = 0.1,
                 n_quantiles: int = 100,
                 kappa: float = 1.0,
                 beta: float = 0.0,
                 bandit_lr: float = 0.1) -> None:
        """Initialize DOPE agent.

        Args:
            seed (int): random seed
            state_dim (int): state dimension
            action_dim (int): action dimension
            action_lim (int, optional): max action value. Defaults to 1.
            lr (float, optional): learning rate. Defaults to 3e-4.
            gamma (float, optional): discount factor. Defaults to 0.99.
            tau (float, optional): mixing rate for target nets. Defaults to 5e-3.
            batchsize (int, optional): batch size. Defaults to 256.
            hidden_size (int, optional): hidden layer size for policy. Defaults to 256.
            update_interval (int, optional): delay for actor, target updates. Defaults to 2.
            buffer_size (int, optional): size of replay buffer. Defaults to int(1e6).
            target_noise (float, optional): smoothing noise for target action. Defaults to 0.2.
            target_noise_clip (float, optional): limit for target. Defaults to 0.5.
            explore_noise (float, optional): noise for exploration. Defaults to 0.1.
            n_quantiles (int, optional): number of quantiles. Defaults to 100.
            kappa (float, optional): constant for Huber loss. Defaults to 1.0.
            bandit_lr (float, optional): bandit learning rate. Defaults to 0.1.
        """
        self.gamma = gamma
        self.tau = tau
        self.batchsize = batchsize
        self.update_interval = update_interval
        self.action_lim = action_lim
        self.target_noise = target_noise
        self.target_noise_clip = target_noise_clip
        self.explore_noise = explore_noise

        torch.manual_seed(seed)

        # init critic(s)
        self.q_funcs = QuantileDoubleQFunc(state_dim, action_dim, n_quantiles=n_quantiles,
                                           hidden_size=hidden_size).to(device)
        self.target_q_funcs = copy.deepcopy(self.q_funcs)
        self.target_q_funcs.eval()
        for p in self.target_q_funcs.parameters():
            p.requires_grad = False

        # init actor
        self.policy = Policy(state_dim, action_dim, hidden_size=hidden_size).to(device)
        self.target_policy = copy.deepcopy(self.policy)
        for p in self.target_policy.parameters():
            p.requires_grad = False

        # set distributional parameters
        taus = torch.arange(
            0, n_quantiles + 1, device=device, dtype=torch.float32) / n_quantiles
        self.tau_hats = ((taus[1:] + taus[:-1]) / 2.0).view(1, n_quantiles)
        self.n_quantiles = n_quantiles
        self.kappa = kappa

        # bandit top-down controller
        self.TDC = ExpWeights(arms=[-1, 0], lr=bandit_lr, init=0.0, use_std=True)

        # init optimizers
        self.q_optimizer = torch.optim.Adam(self.q_funcs.parameters(), lr=lr)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

        self.replay_pool = ReplayPool(capacity=int(buffer_size))

        self._update_counter = 0

    def reallocate_replay_pool(self, new_size: int) -> None:
        """Reset buffer

        Args:
            new_size (int): new maximum buffer size.
        """
        assert new_size != self.replay_pool.capacity, \
            "Error, you've tried to allocate a new pool which has the same length"
        new_replay_pool = ReplayPool(capacity=new_size)
        new_replay_pool.initialise(self.replay_pool)
        self.replay_pool = new_replay_pool

    def get_action(self, state: np.ndarray, state_filter: Callable = None,
                   deterministic: bool = False) -> np.ndarray:
        """given the current state, produce an action

        Args:
            state (np.ndarray): state input.
            state_filter (Callable): pre-processing function for state input. Defaults to None.
            deterministic (bool, optional): whether the action is deterministic or stochastic.
                Defaults to False.

        Returns:
            np.ndarray: the action.
        """
        if state_filter:
            state = state_filter(state)
        state = torch.Tensor(state).view(1, -1).to(device)
        with torch.no_grad():
            action = self.policy(state)
        if not deterministic:
            action += self.explore_noise * torch.randn_like(action)
        action.clamp_(-self.action_lim, self.action_lim)
        return np.atleast_1d(action.squeeze().cpu().numpy())

    def update_target(self) -> None:
        """moving average update of target networks"""
        with torch.no_grad():
            for target_q_param, q_param in zip(
                    self.target_q_funcs.parameters(), self.q_funcs.parameters()):
                target_q_param.data.copy_(self.tau * q_param.data +
                                          (1.0 - self.tau) * target_q_param.data)
            for target_pi_param, pi_param in zip(
                    self.target_policy.parameters(), self.policy.parameters()):
                target_pi_param.data.copy_(self.tau * pi_param.data +
                                           (1.0 - self.tau) * target_pi_param.data)

    def update_q_functions(
            self, state_batch: torch.Tensor, action_batch: torch.Tensor,
            reward_batch: torch.Tensor, nextstate_batch: torch.Tensor,
            done_batch: torch.Tensor, beta: float
    ) -> [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """compute quantile losses for critics

        Args:
            state_batch (torch.Tensor): batch of states
            action_batch (torch.Tensor): batch of actions
            reward_batch (torch.Tensor): batch of rewards
            nextstate_batch (torch.Tensor): batch of next states
            done_batch (torch.Tensor): batch of booleans describing whether episode ended.
            beta (float): optimism parameter

        Returns:
            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                critic 1 loss, critic 2 loss, critic 1 quantiles, critic 2 quantiles
        """
        with torch.no_grad():
            # get next action from target network
            nextaction_batch = self.target_policy(nextstate_batch)
            # add noise
            target_noise = self.target_noise * torch.randn_like(nextaction_batch)
            target_noise.clamp_(-self.target_noise_clip, self.target_noise_clip)
            nextaction_batch += target_noise
            nextaction_batch.clamp_(-self.action_lim, self.action_lim)
            # get quantiles at (s', \tilde a)
            quantiles_t1, quantiles_t2 = self.target_q_funcs(nextstate_batch, nextaction_batch)
            # compute mean and std
            quantiles_all = torch.stack([quantiles_t1, quantiles_t2], dim=-1)  # [batch_size, n_quantiles, 2]
            mu = torch.mean(quantiles_all, axis=-1)  # [batch_size, n_quantiles]
            # compute std by hand for stability
            sigma = torch.sqrt((torch.pow(quantiles_t1 - mu, 2) +
                                torch.pow(quantiles_t2 - mu, 2)) + 1e-4)
            # construct belief distribution
            belief_dist = mu + beta * sigma  # [batch_size, n_quantiles]
            # compute the targets as batch_size x 1 x n_quantiles
            n_quantiles = belief_dist.shape[-1]
            quantile_target = reward_batch[..., None] + (1.0 - done_batch[..., None]) \
                * self.gamma * belief_dist[:, None, :]  # [batch_size, 1, n_quantiles]
        # get quantiles at (s, a)
        quantiles_1, quantiles_2 = self.q_funcs(state_batch, action_batch)
        # compute pairwise td errors
        td_errors_1 = quantile_target - quantiles_1[..., None]  # [batch_size, n_quantiles, n_quantiles]
        td_errors_2 = quantile_target - quantiles_2[..., None]  # [batch_size, n_quantiles, n_quantiles]
        # compute quantile losses
        loss_1 = calculate_quantile_huber_loss(td_errors_1, self.tau_hats,
                                               weights=None, kappa=self.kappa)
        loss_2 = calculate_quantile_huber_loss(td_errors_2, self.tau_hats,
                                               weights=None, kappa=self.kappa)
        return loss_1, loss_2, quantiles_1, quantiles_2

    def update_policy(self, state_batch: torch.Tensor, beta: float) -> torch.Tensor:
        """update the actor.

        Args:
            state_batch (torch.Tensor): batch of states.
            beta (float): optimism parameter.

        Returns:
            torch.Tensor: DPG loss.
        """
        # get actions a
        action_batch = self.policy(state_batch)
        # compute quantiles (s,a)
        quantiles_b1, quantiles_b2 = self.q_funcs(state_batch, action_batch)
        # construct belief distribution
        quantiles_all = torch.stack([quantiles_b1, quantiles_b2], dim=-1)  # [batch_size, n_quantiles, 2]
        mu = torch.mean(quantiles_all, axis=-1)  # [batch_size, n_quantiles]
        eps1, eps2 = 1e-4, 1.1e-4  # small constants for stability
        sigma = torch.sqrt((torch.pow(quantiles_b1 + eps1 - mu, 2) +
                            torch.pow(quantiles_b2 + eps2 - mu, 2)) + eps1)
        belief_dist = mu + beta * sigma  # [batch_size, n_quantiles]
        # DPG loss
        qval_batch = torch.mean(belief_dist, axis=-1)
        policy_loss = (-qval_batch).mean()
        return policy_loss

    def optimize(
            self, n_updates: int, beta: float, state_filter: Callable = None
    ) -> [float, float, float, float, torch.Tensor, torch.Tensor]:
        """sample transitions from the buffer and update parameters

        Args:
            n_updates (int): number of updates to perform.
            beta (float): optimism parameter.
            state_filter (Callable, optional): state pre-processing function. Defaults to None.

        Returns:
            [float, float, float, float, torch.Tensor, torch.Tensor]:
                critic 1 loss, critic 2 loss, actor loss, WD,
                critic 1 quantiles, critic 2 quantiles
        """
        q1_loss, q2_loss, wd, pi_loss = 0, 0, 0, None
        for i in range(n_updates):
            samples = self.replay_pool.sample(self.batchsize)

            if state_filter:
                state_batch = torch.FloatTensor(state_filter(samples.state)).to(device)
                nextstate_batch = torch.FloatTensor(state_filter(samples.nextstate)).to(device)
            else:
                state_batch = torch.FloatTensor(samples.state).to(device)
                nextstate_batch = torch.FloatTensor(samples.nextstate).to(device)
            action_batch = torch.FloatTensor(samples.action).to(device)
            reward_batch = torch.FloatTensor(samples.reward).to(device).unsqueeze(1)
            done_batch = torch.FloatTensor(samples.real_done).to(device).unsqueeze(1)

            # update q-funcs
            q1_loss_step, q2_loss_step, quantiles1_step, quantiles2_step = self.update_q_functions(
                state_batch, action_batch, reward_batch, nextstate_batch, done_batch, beta)
            q_loss_step = q1_loss_step + q2_loss_step

            # measure wasserstein distance
            wd_step = compute_wd_quantile(quantiles1_step, quantiles2_step)
            wd += wd_step.detach().item()

            # take gradient step for critics
            self.q_optimizer.zero_grad()
            q_loss_step.backward()
            self.q_optimizer.step()

            self._update_counter += 1

            q1_loss += q1_loss_step.detach().item()
            q2_loss += q2_loss_step.detach().item()

            # every update_interval steps update actor, target nets
            if self._update_counter % self.update_interval == 0:
                if not pi_loss:
                    pi_loss = 0
                # update policy
                for p in self.q_funcs.parameters():
                    p.requires_grad = False
                pi_loss_step = self.update_policy(state_batch, beta)
                self.policy_optimizer.zero_grad()
                pi_loss_step.backward()
                self.policy_optimizer.step()
                for p in self.q_funcs.parameters():
                    p.requires_grad = True
                # update target policy and q-functions using Polyak averaging
                self.update_target()
                pi_loss += pi_loss_step.detach().item()

        return q1_loss, q2_loss, pi_loss, wd / n_updates, quantiles1_step, quantiles2_step
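# --- Usage sketch (illustrative, not part of the original source) ---
# Shows how the bandit top-down controller (agent.TDC) might supply the optimism
# parameter beta to optimize() and then receive a feedback signal. The ExpWeights
# method names sample_arm()/update() and the return-difference feedback are
# assumptions; the real ExpWeights interface is not shown in these listings.
def dope_update_example(agent, episode_return_delta, n_updates=1):
    beta = agent.TDC.sample_arm()            # assumed: draw an arm from {-1, 0}
    q1_loss, q2_loss, pi_loss, wd, q1_quants, q2_quants = agent.optimize(
        n_updates=n_updates, beta=beta)
    agent.TDC.update(episode_return_delta)   # assumed: credit the chosen arm
    return beta, pi_loss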
class OffPolicyAgent:

    def __init__(self, seed, state_dim, action_dim, action_lim=1, lr=3e-4, gamma=0.99,
                 tau=5e-3, batch_size=256, hidden_size=256, update_interval=2,
                 buffer_size=1e6):
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.update_interval = update_interval
        self.action_lim = action_lim

        torch.manual_seed(seed)

        # aka critic
        self.q_funcs = DoubleQFunc(state_dim, action_dim, hidden_size=hidden_size).to(device)
        self.target_q_funcs = copy.deepcopy(self.q_funcs)
        self.target_q_funcs.eval()
        for p in self.target_q_funcs.parameters():
            p.requires_grad = False

        # aka actor
        self.policy = Policy(state_dim, action_dim, hidden_size=hidden_size).to(device)

        self.q_optimizer = torch.optim.Adam(self.q_funcs.parameters(), lr=lr)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

        self.replay_pool = ReplayPool(action_dim=action_dim, state_dim=state_dim,
                                      capacity=int(buffer_size))

        self._seed = seed
        self._update_counter = 0

    def reallocate_replay_pool(self, new_size: int):
        assert new_size != self.replay_pool.capacity, \
            "Error, you've tried to allocate a new pool which has the same length"
        new_replay_pool = ReplayPool(capacity=new_size)
        new_replay_pool.initialise(self.replay_pool)
        self.replay_pool = new_replay_pool

    @property
    def is_soft(self):
        raise NotImplementedError

    @property
    def alg_name(self):
        raise NotImplementedError

    def get_action(self, state, state_filter=None, deterministic=False):
        raise NotImplementedError

    def update_target(self):
        raise NotImplementedError

    def update_q_functions(self, state_batch, action_batch, reward_batch, nextstate_batch, done_batch):
        raise NotImplementedError

    def update_policy(self, state_batch):
        raise NotImplementedError

    def optimize(self, n_updates, state_filter=None):
        q1_loss, q2_loss, pi_loss, a_loss = 0, 0, None, None
        for i in range(n_updates):
            samples = self.replay_pool.sample(self.batch_size)

            if state_filter:
                state_batch = torch.FloatTensor(state_filter(samples.state)).to(device)
                nextstate_batch = torch.FloatTensor(state_filter(samples.nextstate)).to(device)
            else:
                state_batch = torch.FloatTensor(samples.state).to(device)
                nextstate_batch = torch.FloatTensor(samples.nextstate).to(device)
            action_batch = torch.FloatTensor(samples.action).to(device)
            reward_batch = torch.FloatTensor(samples.reward).to(device).unsqueeze(1)
            done_batch = torch.FloatTensor(samples.real_done).to(device).unsqueeze(1)

            # update q-funcs
            q1_loss_step, q2_loss_step = self.update_q_functions(state_batch, action_batch,
                                                                 reward_batch, nextstate_batch,
                                                                 done_batch)
            q_loss_step = q1_loss_step + q2_loss_step
            self.q_optimizer.zero_grad()
            q_loss_step.backward()
            self.q_optimizer.step()

            self._update_counter += 1

            q1_loss += q1_loss_step.detach().item()
            q2_loss += q2_loss_step.detach().item()

            if self._update_counter % self.update_interval == 0:
                if not pi_loss:
                    pi_loss = 0
                # update policy
                for p in self.q_funcs.parameters():
                    p.requires_grad = False
                pi_loss_step = self.update_policy(state_batch)
                # if the policy is soft (i.e., max-ent), update_policy also returns a
                # temperature loss, so we take a temperature step as well
                if self.is_soft:
                    if not a_loss:
                        a_loss = 0
                    pi_loss_step, a_loss_step = pi_loss_step
                    self.temp_optimizer.zero_grad()
                    a_loss_step.backward()
                    self.temp_optimizer.step()
                    a_loss += a_loss_step.detach().item()
                self.policy_optimizer.zero_grad()
                pi_loss_step.backward()
                self.policy_optimizer.step()
                for p in self.q_funcs.parameters():
                    p.requires_grad = True
                # update target policy and q-functions using Polyak averaging
                self.update_target()
                pi_loss += pi_loss_step.detach().item()

        return q1_loss, q2_loss, pi_loss, a_loss
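# --- Subclass sketch (illustrative, not part of the original source; class name is hypothetical) ---
# Minimal outline of the contract OffPolicyAgent.optimize() expects from a concrete agent:
# the is_soft/alg_name properties plus the four hooks. A non-soft agent returns a single
# policy loss from update_policy(); a soft (max-ent) agent must instead return
# (policy_loss, temperature_loss) and define self.temp_optimizer.
class MyTD3LikeAgent(OffPolicyAgent):

    @property
    def is_soft(self):
        return False

    @property
    def alg_name(self):
        return 'my_td3_like'

    def get_action(self, state, state_filter=None, deterministic=False):
        ...  # map state -> action via self.policy

    def update_target(self):
        ...  # Polyak-average self.q_funcs into self.target_q_funcs

    def update_q_functions(self, state_batch, action_batch, reward_batch,
                           nextstate_batch, done_batch):
        ...  # return (q1_loss, q2_loss)

    def update_policy(self, state_batch):
        ...  # return a scalar policy loss (non-soft case)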
def load_checkpoint(self, checkpoint_path, env_name):
    load_dict = torch.load(checkpoint_path)
    assert load_dict['alg_name'] == self.alg_name, \
        "Incorrect checkpoint, this is a {} policy, but you're loading a {} policy.".format(
            self.alg_name, load_dict['alg_name'])
    assert load_dict['env_name'] == env_name, \
        "Incorrect checkpoint, this env is {}, but the policy was trained on {}.".format(
            env_name, load_dict['env_name'])
    self.q_funcs.load_state_dict(load_dict['double_q_state_dict'])
    self.target_q_funcs.load_state_dict(load_dict['target_double_q_state_dict'])
    self.policy.load_state_dict(load_dict['policy_state_dict'])
    if self.is_soft:
        self._log_alpha = load_dict['log_alpha']
    if hasattr(self, "target_policy"):
        self.target_policy.load_state_dict(load_dict['target_policy_state_dict'])
    num_steps = int(load_dict['num_steps'])
    self._update_counter = load_dict['num_updates']
    self.replay_pool = load_dict['replay_pool'] if load_dict['replay_pool'] else self.replay_pool
    return num_steps
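# --- Save counterpart (illustrative, not part of the original source) ---
# load_checkpoint() above implies a checkpoint dict with the keys used below; this sketch
# writes the matching structure. Whether the replay pool is persisted, and the _log_alpha
# attribute name, are assumptions inferred from the loading code.
def save_checkpoint(self, checkpoint_path, env_name, num_steps, save_replay_pool=False):
    save_dict = {
        'alg_name': self.alg_name,
        'env_name': env_name,
        'double_q_state_dict': self.q_funcs.state_dict(),
        'target_double_q_state_dict': self.target_q_funcs.state_dict(),
        'policy_state_dict': self.policy.state_dict(),
        'num_steps': num_steps,
        'num_updates': self._update_counter,
        'replay_pool': self.replay_pool if save_replay_pool else None,
    }
    if self.is_soft:
        save_dict['log_alpha'] = self._log_alpha  # assumed attribute, mirrors load_checkpoint()
    if hasattr(self, "target_policy"):
        save_dict['target_policy_state_dict'] = self.target_policy.state_dict()
    torch.save(save_dict, checkpoint_path)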
class TD3_Agent:

    def __init__(self, seed, state_dim, action_dim, action_lim=1, lr=3e-4, gamma=0.99,
                 tau=5e-3, batchsize=256, hidden_size=256, update_interval=2,
                 buffer_size=1e6, target_noise=0.2, target_noise_clip=0.5,
                 explore_noise=0.1):
        self.gamma = gamma
        self.tau = tau
        self.batchsize = batchsize
        self.update_interval = update_interval
        self.action_lim = action_lim
        self.target_noise = target_noise
        self.target_noise_clip = target_noise_clip
        self.explore_noise = explore_noise

        torch.manual_seed(seed)

        # aka critic
        self.q_funcs = DoubleQFunc(state_dim, action_dim, hidden_size=hidden_size).to(device)
        self.target_q_funcs = copy.deepcopy(self.q_funcs)
        self.target_q_funcs.eval()
        for p in self.target_q_funcs.parameters():
            p.requires_grad = False

        # aka actor
        self.policy = Policy(state_dim, action_dim, hidden_size=hidden_size).to(device)
        self.target_policy = copy.deepcopy(self.policy)
        for p in self.target_policy.parameters():
            p.requires_grad = False

        self.q_optimizer = torch.optim.Adam(self.q_funcs.parameters(), lr=lr)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

        self.replay_pool = ReplayPool(action_dim=action_dim, state_dim=state_dim,
                                      capacity=int(buffer_size))

        self._update_counter = 0

    def reallocate_replay_pool(self, new_size: int):
        assert new_size != self.replay_pool.capacity, \
            "Error, you've tried to allocate a new pool which has the same length"
        new_replay_pool = ReplayPool(capacity=new_size)
        new_replay_pool.initialise(self.replay_pool)
        self.replay_pool = new_replay_pool

    def get_action(self, state, state_filter=None, deterministic=False):
        if state_filter:
            state = state_filter(state)
        state = torch.Tensor(state).view(1, -1).to(device)
        with torch.no_grad():
            action = self.policy(state)
        if not deterministic:
            action += self.explore_noise * torch.randn_like(action)
        action.clamp_(-self.action_lim, self.action_lim)
        return np.atleast_1d(action.squeeze().cpu().numpy())

    def update_target(self):
        """moving average update of target networks"""
        with torch.no_grad():
            for target_q_param, q_param in zip(
                    self.target_q_funcs.parameters(), self.q_funcs.parameters()):
                target_q_param.data.copy_(self.tau * q_param.data +
                                          (1.0 - self.tau) * target_q_param.data)
            for target_pi_param, pi_param in zip(
                    self.target_policy.parameters(), self.policy.parameters()):
                target_pi_param.data.copy_(self.tau * pi_param.data +
                                           (1.0 - self.tau) * target_pi_param.data)

    def update_q_functions(self, state_batch, action_batch, reward_batch, nextstate_batch, done_batch):
        with torch.no_grad():
            nextaction_batch = self.target_policy(nextstate_batch)
            target_noise = self.target_noise * torch.randn_like(nextaction_batch)
            target_noise.clamp_(-self.target_noise_clip, self.target_noise_clip)
            nextaction_batch += target_noise
            nextaction_batch.clamp_(-self.action_lim, self.action_lim)
            q_t1, q_t2 = self.target_q_funcs(nextstate_batch, nextaction_batch)
            # take min to mitigate positive bias in q-function training
            q_target = torch.min(q_t1, q_t2)
            value_target = reward_batch + (1.0 - done_batch) * self.gamma * q_target
        q_1, q_2 = self.q_funcs(state_batch, action_batch)
        loss_1 = F.mse_loss(q_1, value_target)
        loss_2 = F.mse_loss(q_2, value_target)
        return loss_1, loss_2

    def update_policy(self, state_batch):
        action_batch = self.policy(state_batch)
        q_b1, q_b2 = self.q_funcs(state_batch, action_batch)
        qval_batch = torch.min(q_b1, q_b2)
        policy_loss = (-qval_batch).mean()
        return policy_loss

    def optimize(self, n_updates, state_filter=None):
        q1_loss, q2_loss, pi_loss = 0, 0, None
        for i in range(n_updates):
            samples = self.replay_pool.sample(self.batchsize)

            if state_filter:
                state_batch = torch.FloatTensor(state_filter(samples.state)).to(device)
                nextstate_batch = torch.FloatTensor(state_filter(samples.nextstate)).to(device)
            else:
                state_batch = torch.FloatTensor(samples.state).to(device)
                nextstate_batch = torch.FloatTensor(samples.nextstate).to(device)
            action_batch = torch.FloatTensor(samples.action).to(device)
            reward_batch = torch.FloatTensor(samples.reward).to(device).unsqueeze(1)
            done_batch = torch.FloatTensor(samples.real_done).to(device).unsqueeze(1)

            # update q-funcs
            q1_loss_step, q2_loss_step = self.update_q_functions(state_batch, action_batch,
                                                                 reward_batch, nextstate_batch,
                                                                 done_batch)
            q_loss_step = q1_loss_step + q2_loss_step
            self.q_optimizer.zero_grad()
            q_loss_step.backward()
            self.q_optimizer.step()

            self._update_counter += 1

            q1_loss += q1_loss_step.detach().item()
            q2_loss += q2_loss_step.detach().item()

            if self._update_counter % self.update_interval == 0:
                if not pi_loss:
                    pi_loss = 0
                # update policy
                for p in self.q_funcs.parameters():
                    p.requires_grad = False
                pi_loss_step = self.update_policy(state_batch)
                self.policy_optimizer.zero_grad()
                pi_loss_step.backward()
                self.policy_optimizer.step()
                for p in self.q_funcs.parameters():
                    p.requires_grad = True
                # update target policy and q-functions using Polyak averaging
                self.update_target()
                pi_loss += pi_loss_step.detach().item()

        return q1_loss, q2_loss, pi_loss
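# --- Evaluation sketch (illustrative, not part of the original source) ---
# TD3_Agent.get_action() adds Gaussian exploration noise unless deterministic=True,
# so evaluation rollouts typically request the noise-free action. The Gym-style
# environment with a 4-tuple step API and the episode budget are assumptions.
def evaluate_td3(agent, env, n_episodes=10):
    returns = []
    for _ in range(n_episodes):
        state, done, ep_return = env.reset(), False, 0.0
        while not done:
            action = agent.get_action(state, deterministic=True)  # no explore_noise
            state, reward, done, _ = env.step(action)
            ep_return += reward
        returns.append(ep_return)
    return sum(returns) / len(returns)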