Example #1
    def forward(self,
                observation,
                reparameterize=True,
                deterministic=False,
                return_log_prob=False):
        """
        Forward pass.
        Assumes input is a torch tensor.

        :type observation: torch.Tensor
        """
        layer_input = observation
        for fc in self.fcs:
            layer_input = self.hidden_activation(fc(layer_input))
        network_output = self.output_activation(self.last_fc(layer_input))

        alpha = network_output[:, 0].unsqueeze(1) + EPSILON
        beta = network_output[:, 1].unsqueeze(1) + EPSILON
        distribution = Beta(alpha, beta)
        distribution_mean = distribution.mean
        if deterministic:
            # deterministic mode returns the distribution mean instead of a sample
            sample = distribution_mean
        elif reparameterize:
            sample = distribution.rsample()
        else:
            sample = distribution.sample()
        # transform to range (min, max)
        action = self.min + self.max_min_difference * sample
        mean = self.min + self.max_min_difference * distribution_mean
        variance = self.max_min_difference_squared * distribution.variance
        std = torch.sqrt(variance)
        log_std = torch.log(std)
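        # log-probability of the untransformed Beta sample in (0, 1); no Jacobian
        # correction is applied for the affine rescaling above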
        log_prob = distribution.log_prob(sample)
        entropy = distribution.entropy()
        mean_action_log_prob = None
        pre_tanh_value = None
        return action, mean, log_std, log_prob, entropy, std, mean_action_log_prob, pre_tanh_value
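
For context, here is a minimal standalone sketch of the affine trick used above: a Beta sample lives in (0, 1) and is rescaled into an action range, with the mean and standard deviation transformed the same way. The bounds below are made-up placeholders for `self.min` and `self.max_min_difference`.

import torch
from torch.distributions import Beta

action_min, action_range = -2.0, 4.0           # placeholders for self.min / self.max_min_difference
dist = Beta(torch.tensor([2.0]), torch.tensor([5.0]))
sample = dist.rsample()                        # differentiable draw in (0, 1)
action = action_min + action_range * sample    # now in (-2, 2)
mean = action_min + action_range * dist.mean
std = action_range * dist.variance.sqrt()
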
Example #2
from typing import Tuple, Union

import torch
from torch.distributions import Beta


def _adapted_beta(shape: Union[Tuple, torch.Size],
                  a: Union[float, int, torch.Tensor],
                  b: Union[float, int, torch.Tensor],
                  same_on_batch: bool = False) -> torch.Tensor:
    r""" The beta sampling function that accepts 'same_on_batch'.
    If same_on_batch is True, all values generated will be exactly same given a batch_size (shape[0]).
    By default, same_on_batch is set to False.
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a, dtype=torch.float32)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b, dtype=torch.float32)
    dist = Beta(a, b)
    if same_on_batch:
        # draw one sample and repeat it along the batch dimension
        return dist.rsample((1, *shape[1:])).repeat(shape[0], *([1] * (len(shape) - 1)))
    else:
        return dist.rsample(shape)
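
A quick illustration of the `same_on_batch` switch (assuming `torch` and the function above are in scope):

batch = _adapted_beta((4, 2, 3), a=2.0, b=5.0, same_on_batch=True)
assert torch.allclose(batch[0], batch[1])       # one draw repeated over the batch

independent = _adapted_beta((4, 2, 3), a=2.0, b=5.0, same_on_batch=False)
print(independent.shape)                        # torch.Size([4, 2, 3])
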
Example #3
def construct_priors(alpha, lambda_0, lambda_1, T):
    p_beta = Beta(1, alpha)
    p_lambda = Gamma(lambda_0, lambda_1)
    p_zeta = Categorical(torch.tensor(mix_weights(p_beta.rsample([T - 1]))))
    return p_beta, p_lambda, p_zeta
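
`mix_weights` is not defined in this snippet; under the usual Dirichlet-process stick-breaking construction it would look roughly like the following (a hypothetical sketch, not the original helper):

import torch
import torch.nn.functional as F

def mix_weights(beta_draws):
    # Hypothetical stick-breaking: turn T-1 Beta draws in (0, 1) into T mixture
    # weights that sum to 1.
    beta1m_cumprod = (1 - beta_draws).cumprod(-1)
    return F.pad(beta_draws, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)
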
Example #4
class TD3(object):
    def __init__(self, state_dim, action_dim, max_action, use_target_q,
                 target_distance_weight):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_static = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_static.load_state_dict(self.actor.state_dict())
        # self.actor_optimizer = RAdam(self.actor.parameters())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())
        # self.actor_optimizer = torch.optim.SGD(self.actor.parameters(), lr=0.0001, momentum=0.1)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        # self.critic_optimizer = RAdam(self.critic.parameters())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
        # self.critic_optimizer = torch.optim.SGD(self.critic.parameters(), lr=0.01, momentum=0.1)

        self.max_action = max_action
        self.use_target_q = use_target_q
        self.target_distance_weight = target_distance_weight

        self.noise_sampler = Beta(torch.FloatTensor([4.0]),
                                  torch.FloatTensor([4.0]))

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def select_action_target(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor_target(state).cpu().data.numpy().flatten()

    def train(self,
              replay_buffer,
              iterations,
              batch_size=100,
              discount=0.99,
              tau=0.005,
              policy_noise=0.2,
              noise_clip=0.5,
              policy_freq=2,
              update_target_actor=True,
              update_target_q=True):

        abs_actor_loss = 0
        abs_critic_loss = 0

        for it in range(iterations):

            # Sample replay buffer
            x, y, u, r, d, _ = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(x).to(device)
            action = torch.FloatTensor(u).to(device)
            next_state = torch.FloatTensor(y).to(device)
            not_done = torch.FloatTensor(1 - d).to(device)  # 1 while the episode continues, 0 at terminal
            reward = torch.FloatTensor(r).to(device)

            # Select action according to policy and add clipped noise
            # noise = torch.FloatTensor(u).data.normal_(0, policy_noise).to(device)
            # noise = noise.clamp(-noise_clip, noise_clip)

            with torch.no_grad():
                # clipped exploration noise from the Beta(4, 4) sampler, mapped to (-noise_clip, noise_clip)
                noise = (self.noise_sampler.rsample(
                    (action.shape[0], action.shape[1])).view(
                        action.shape[0], action.shape[1]) * 2 -
                         1).to(device) * noise_clip

                target_action = (self.actor_target(next_state) + noise).clamp(
                    -1, 1) * self.max_action

                # Compute the target Q value
                if self.use_target_q:
                    target_Q1, target_Q2 = self.critic_target(
                        next_state, target_action)
                else:
                    target_Q1, target_Q2 = self.critic(next_state,
                                                       target_action)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = reward + (not_done * discount * target_Q)

            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(state, action)

            # Compute critic loss
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
                current_Q2, target_Q)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            abs_critic_loss += abs(critic_loss.item())

            # Delayed policy updates
            if it % policy_freq == 0:

                # Compute actor loss
                action = self.actor(state) * self.max_action
                actor_loss = (-self.critic.Q1(state, action).mean() +
                              F.mse_loss(action, target_action, reduction='mean') *
                              self.target_distance_weight)

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                abs_actor_loss += abs(actor_loss.item())

                # Update the frozen target models
                if update_target_q:
                    for param, target_param in zip(
                            self.critic.parameters(),
                            self.critic_target.parameters()):
                        target_param.data.copy_(tau * param.data +
                                                (1 - tau) * target_param.data)

                if update_target_actor:
                    for param, target_param in zip(
                            self.actor.parameters(),
                            self.actor_target.parameters()):
                        target_param.data.copy_(tau * param.data +
                                                (1 - tau) * target_param.data)

        return abs_critic_loss / iterations, abs_actor_loss / iterations * policy_freq

    def save(self, filename, directory):
        torch.save(self.actor.state_dict(),
                   '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(),
                   '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        self.actor.load_state_dict(
            torch.load('%s/%s_actor.pth' % (directory, filename)))
        self.critic.load_state_dict(
            torch.load('%s/%s_critic.pth' % (directory, filename)))
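
A minimal usage sketch, assuming `Actor`, `Critic` and `device` are defined as this snippet expects, and a replay buffer whose `sample(batch_size)` returns `(state, next_state, action, reward, done, info)` as unpacked in `train()`; `env`, `replay_buffer` and the dimensions below are placeholders:

policy = TD3(state_dim=17, action_dim=6, max_action=1.0,
             use_target_q=True, target_distance_weight=0.1)

state = env.reset()                          # hypothetical environment
action = policy.select_action(state)         # numpy action for env.step(action)

# once the buffer has enough transitions, run one round of updates
critic_loss, actor_loss = policy.train(replay_buffer, iterations=1000, batch_size=100)
policy.save('td3_beta_noise', './checkpoints')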