def forward(self,
            observation,
            reparameterize=True,
            deterministic=False,
            return_log_prob=False):
    """
    Forward pass. Assumes the input is a torch tensor.

    :type observation: torch.Tensor
    """
    layer_input = observation
    for fc in self.fcs:
        layer_input = self.hidden_activation(fc(layer_input))
    network_output = self.output_activation(self.last_fc(layer_input))

    # Beta parameters must be strictly positive; EPSILON guards against zeros.
    alpha = network_output[:, 0].unsqueeze(1) + EPSILON
    beta = network_output[:, 1].unsqueeze(1) + EPSILON
    distribution = Beta(alpha, beta)
    distribution_mean = distribution.mean

    if deterministic:
        sample = distribution_mean
    elif reparameterize:
        sample = distribution.rsample()
    else:
        sample = distribution.sample()

    # Transform from (0, 1) to the action range (min, max).
    action = self.min + self.max_min_difference * sample
    mean = self.min + self.max_min_difference * distribution_mean
    variance = self.max_min_difference_squared * distribution.variance
    std = torch.sqrt(variance)
    log_std = torch.log(std)
    log_prob = distribution.log_prob(sample)
    entropy = distribution.entropy()
    mean_action_log_prob = None
    pre_tanh_value = None

    return (action, mean, log_std, log_prob, entropy, std,
            mean_action_log_prob, pre_tanh_value)
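# The forward pass above assumes the enclosing policy module defines `self.fcs`,
# `self.last_fc`, `self.hidden_activation`, `self.output_activation`, `self.min`,
# `self.max_min_difference`, `self.max_min_difference_squared`, and a global
# EPSILON. A minimal, hypothetical sketch of such a module (the names and layer
# choices below are assumptions for illustration, not the original implementation):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta

EPSILON = 1e-6  # assumed small constant keeping the Beta parameters positive


class BetaPolicySketch(nn.Module):
    def __init__(self, obs_dim, hidden_sizes, action_min, action_max):
        super().__init__()
        self.fcs = nn.ModuleList()
        last_size = obs_dim
        for size in hidden_sizes:
            self.fcs.append(nn.Linear(last_size, size))
            last_size = size
        self.last_fc = nn.Linear(last_size, 2)  # one output each for alpha and beta
        self.hidden_activation = F.relu
        self.output_activation = F.softplus  # keeps the raw alpha/beta outputs non-negative
        self.min = action_min
        self.max_min_difference = action_max - action_min
        self.max_min_difference_squared = self.max_min_difference ** 2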
from typing import Tuple, Union

import torch
from torch.distributions import Beta


def _adapted_beta(shape: Union[Tuple, torch.Size],
                  a: Union[float, int, torch.Tensor],
                  b: Union[float, int, torch.Tensor],
                  same_on_batch: bool = False) -> torch.Tensor:
    r"""Beta sampling function that accepts ``same_on_batch``.

    If ``same_on_batch`` is True, all batch elements (``shape[0]``) receive
    exactly the same values. By default, ``same_on_batch`` is False.
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a, dtype=torch.float32)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b, dtype=torch.float32)
    dist = Beta(a, b)
    if same_on_batch:
        # Draw once for a single batch element, then repeat along the batch dimension.
        return dist.rsample((1, *shape[1:])).repeat(shape[0], *([1] * (len(shape) - 1)))
    return dist.rsample(shape)
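# Hedged usage sketch for _adapted_beta; the shape and concentration values are
# illustrative only:
batched = _adapted_beta((4, 3), a=2.0, b=5.0, same_on_batch=False)
shared = _adapted_beta((4, 3), a=2.0, b=5.0, same_on_batch=True)
assert batched.shape == (4, 3) and shared.shape == (4, 3)
# With same_on_batch=True every batch element is identical.
assert torch.allclose(shared[0], shared[1])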
import torch
from torch.distributions import Beta, Categorical, Gamma


def construct_priors(alpha, lambda_0, lambda_1, T):
    # `mix_weights` (defined elsewhere; see the sketch after this function) maps
    # the T - 1 stick-breaking Beta draws to T mixture weights.
    p_beta = Beta(1, alpha)
    p_lambda = Gamma(lambda_0, lambda_1)
    p_zeta = Categorical(torch.tensor(mix_weights(p_beta.rsample([T - 1]))))
    return p_beta, p_lambda, p_zeta
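# `construct_priors` depends on a `mix_weights` helper that is not shown in this
# file. A common stick-breaking implementation (an assumption here, not
# necessarily the original one) turns T - 1 Beta draws into T weights that sum to 1:
import torch.nn.functional as F


def mix_weights(beta):
    # beta: tensor of shape [T - 1] with stick-breaking fractions in (0, 1)
    beta1m_cumprod = (1 - beta).cumprod(-1)
    # Pad so the last stick absorbs all remaining probability mass.
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)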
import torch
import torch.nn.functional as F
from torch.distributions import Beta

# `Actor`, `Critic`, `device`, and the replay buffer are assumed to be defined
# elsewhere in the repository.


class TD3(object):
    def __init__(self, state_dim, action_dim, max_action, use_target_q,
                 target_distance_weight):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_static = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_static.load_state_dict(self.actor.state_dict())
        # self.actor_optimizer = RAdam(self.actor.parameters())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())
        # self.actor_optimizer = torch.optim.SGD(self.actor.parameters(), lr=0.0001, momentum=0.1)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        # self.critic_optimizer = RAdam(self.critic.parameters())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
        # self.critic_optimizer = torch.optim.SGD(self.critic.parameters(), lr=0.01, momentum=0.1)

        self.max_action = max_action
        self.use_target_q = use_target_q
        self.target_distance_weight = target_distance_weight
        # Symmetric Beta(4, 4) exploration noise; samples lie in (0, 1) and are
        # rescaled to (-noise_clip, noise_clip) in train().
        self.noise_sampler = Beta(torch.FloatTensor([4.0]),
                                  torch.FloatTensor([4.0]))

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def select_action_target(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor_target(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99,
              tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2,
              update_target_actor=True, update_target_q=True):
        abs_actor_loss = 0
        abs_critic_loss = 0
        for it in range(iterations):

            # Sample replay buffer
            x, y, u, r, d, _ = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(x).to(device)
            action = torch.FloatTensor(u).to(device)
            next_state = torch.FloatTensor(y).to(device)
            done = torch.FloatTensor(1 - d).to(device)  # 1 while the episode continues, 0 at terminal states
            reward = torch.FloatTensor(r).to(device)

            # Select action according to the target policy and add clipped noise
            # noise = torch.FloatTensor(u).data.normal_(0, policy_noise).to(device)
            # noise = noise.clamp(-noise_clip, noise_clip)
            with torch.no_grad():
                # Beta(4, 4) samples lie in (0, 1); map them to (-noise_clip, noise_clip).
                noise = (self.noise_sampler.rsample(
                    (action.shape[0], action.shape[1])).view(
                        action.shape[0],
                        action.shape[1]) * 2 - 1).to(device) * noise_clip
                target_action = (self.actor_target(next_state) + noise).clamp(
                    -1, 1) * self.max_action

                # Compute the target Q value
                if self.use_target_q:
                    target_Q1, target_Q2 = self.critic_target(
                        next_state, target_action)
                else:
                    target_Q1, target_Q2 = self.critic(next_state, target_action)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = reward + (done * discount * target_Q)

            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(state, action)

            # Compute critic loss
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
                current_Q2, target_Q)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()
            abs_critic_loss += abs(critic_loss.item())

            # Delayed policy updates
            if it % policy_freq == 0:

                # Compute actor loss: maximize Q1 while penalizing the distance
                # between the actor's action and the noised target-policy action.
                action = self.actor(state) * self.max_action
                actor_loss = -self.critic.Q1(state, action).mean() + F.mse_loss(
                    action, target_action) * self.target_distance_weight

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()
                abs_actor_loss += abs(actor_loss.item())

                # Update the frozen target models
                if update_target_q:
                    for param, target_param in zip(
                            self.critic.parameters(),
                            self.critic_target.parameters()):
                        target_param.data.copy_(tau * param.data +
                                                (1 - tau) * target_param.data)
                if update_target_actor:
                    for param, target_param in zip(
                            self.actor.parameters(),
                            self.actor_target.parameters()):
                        target_param.data.copy_(tau * param.data +
                                                (1 - tau) * target_param.data)

        return abs_critic_loss / iterations, abs_actor_loss / iterations * policy_freq

    def save(self, filename, directory):
        torch.save(self.actor.state_dict(),
                   '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(),
                   '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        self.actor.load_state_dict(
            torch.load('%s/%s_actor.pth' % (directory, filename)))
        self.critic.load_state_dict(
            torch.load('%s/%s_critic.pth' % (directory, filename)))
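# Self-contained check of the exploration-noise construction used in train():
# Beta(4, 4) samples lie in (0, 1); scaling by 2 and subtracting 1 maps them to
# (-1, 1), and multiplying by noise_clip bounds them to (-noise_clip, noise_clip).
# The batch and action sizes below are illustrative only.
import torch
from torch.distributions import Beta

noise_clip = 0.5
sampler = Beta(torch.FloatTensor([4.0]), torch.FloatTensor([4.0]))
noise = (sampler.rsample((256, 6)).view(256, 6) * 2 - 1) * noise_clip
assert noise.abs().max() <= noise_clip
print(noise.mean().item(), noise.std().item())  # mean close to 0, std close to noise_clip / 3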