Example #1
  def sample(self, obs_with_hidden):

    mean, log_std, hidden = self.actor.forward(obs_with_hidden)

    log_std = torch.tanh(log_std)
    log_std = self.hyperps['log_std_min'] + 0.5 * (self.hyperps['log_std_max'] - self.hyperps['log_std_min']) * (log_std + 1)

    std = log_std.exp()
    normal = Normal(mean, std)
    
  
    x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
    y_t = torch.tanh(x_t)

    action = y_t * self.hyperps['action_scale'] + self.hyperps['action_bias']
    log_prob = normal.log_prob(x_t)
    
    # Enforcing Action Bound
    
    log_prob = log_prob - torch.log(self.hyperps['action_scale'] * (1 - y_t.pow(2)) + self.hyperps['epsilon'])
    log_prob = log_prob.sum(1, keepdim=True)
    
    mean = torch.tanh(mean) * self.hyperps['action_scale'] + self.hyperps['action_bias']

    return action, log_prob, mean, hidden
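
For reference, the "Enforcing Action Bound" step above is the change-of-variables correction for an action a = action_scale * tanh(u) + action_bias with u ~ N(mean, std); written out (the epsilon in the code only guards against log(0)):

\log \pi(a \mid s) = \log \mu(u \mid s) - \sum_i \log\big(\mathrm{action\_scale}\,(1 - \tanh^2(u_i))\big)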
Example #2
 def forward(self, batch_size, temp=1.0):
     first_dist = Normal(self.prior.first_mean,
                         self.prior.first_logvar.exp())
     results = []
     sample = temp * first_dist.sample(sample_shape=batch_size)[:, 0, :]
     out = self.bottom_up.first(sample)
     out = out.view(out.size(0), out.size(1), 1, 1)
     out = func.interpolate(out, scale_factor=self.scale)
     for idx, (block, mean, logvar, mf, lf, mod, zero) in enumerate(
             zip(self.bottom_up.blocks, self.prior.mean, self.prior.logvar,
                 self.prior.mean_factor, self.prior.logvar_factor,
                 self.bottom_up.modifiers, self.bottom_up.zeros)):
         mf = 1
         lf = 1
         results.append(out)
         pos = self.prior.position_embedding(out)
         dpos = torch.cat((out, pos), dim=1)
         dist = Normal(mean(dpos) * mf, (logvar(dpos) * lf).exp())
         sample = temp * dist.rsample()
         out = block(out + 0.1 * mod(sample))
         if (idx + 1) % self.bottom_up.level_repeat == 0 and idx < len(
                 self.bottom_up.blocks) - 1:
             out = func.interpolate(out, scale_factor=2)
     res = DiscretizedMixtureLogits(10, self.decoder.block(
         results[-1])).sample()
     res = ((res + 1) / 2).clamp(0, 1)
     return res
Example #3
    def sample(self, obs, msg):

        mean, log_std, hidden, msg = self.actor.forward(obs, msg)

        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample(
        )  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)

        action = y_t * self.hyperps[
            'action_scale']  #+ self.hyperps['action_bias']
        action[:, 0] += self.hyperps['action_bias']
        log_prob = normal.log_prob(x_t)

        # Enforcing Action Bound
        log_prob -= torch.log(self.hyperps['action_scale'] * (1 - y_t.pow(2)) +
                              self.hyperps['epsilon'])
        log_prob = log_prob.sum(1, keepdim=True)

        mean = torch.tanh(
            mean) * self.hyperps['action_scale'] + self.hyperps['action_bias']

        entropy = normal.entropy()
        entropy1, entropy2 = entropy[0][0].item(), entropy[0][1].item()

        #print('Std: {:2.3f}, {:2.3f}, log_std: {:2.3f},{:2.3f}, entropy:{:2.3f}, {:2.3f}'.format(std[0][0].item(),std[0][1].item(), log_std[0][0].item(), log_std[0][1].item(), entropy1, entropy2))
        return action, log_prob, mean, std, hidden, msg
Example #4
 def get_action(self, s):
     s = torch.tensor(data=s, dtype=torch.float)
     mean, std = self.actor(s)
     normal = Normal(mean, std)
     z = normal.rsample()
     a = torch.tanh(z)
     return a.detach().numpy().tolist()
Example #5
    def get_action(self, state, epsilon=1e-6, reparam=True):
        mean, log_std = self.forward(state)
        std = log_std.exp()

        normal = Normal(mean, std)
        if reparam:
            z = normal.rsample()    # reparameterization trick
Example #6
 def predict(self, x) -> dict:
     """
     :param x: tensor of shape [batch_size, num_features]
     :return: A dictionary containing prediction i.e.
     - latent_dist = torch.distributions.Normal instance of latent space
     - latent_mu = torch.Tensor mu (mean) parameter of latent Normal distribution
     - latent_sigma = torch.Tensor sigma (std) parameter of latent Normal distribution
     - recon_mu = torch.Tensor mu (mean) parameter of reconstructed Normal distribution
     - recon_sigma = torch.Tensor sigma (std) parameter of reconstructed Normal distribution
     - z = torch.Tensor sampled latent space from latent distribution
     """
     batch_size = len(x)
     latent_mu, latent_sigma = self.encoder(x).chunk(
         2, dim=1)  #both with size [batch_size, latent_size]
     latent_sigma = softplus(latent_sigma)
     dist = Normal(latent_mu, latent_sigma)
     z = dist.rsample([self.L])  # shape: [L, batch_size, latent_size]
     z = z.view(self.L * batch_size, self.latent_size)
     recon_mu, recon_sigma = self.decoder(z).chunk(2, dim=1)
     recon_sigma = softplus(recon_sigma)
     recon_mu = recon_mu.view(self.L, *x.shape)
     recon_sigma = recon_sigma.view(self.L, *x.shape)
     return dict(latent_dist=dist,
                 latent_mu=latent_mu,
                 latent_sigma=latent_sigma,
                 recon_mu=recon_mu,
                 recon_sigma=recon_sigma,
                 z=z)
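
A minimal usage sketch (not from the original repository) of how the dictionary returned by predict could be consumed to score reconstruction likelihood; model, model.L, and the batch x are assumed to match the example above:

import torch
from torch.distributions import Normal

def reconstruction_log_likelihood(model, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: average log p(x | z) over the L latent samples drawn inside predict()
    pred = model.predict(x)
    recon_dist = Normal(pred['recon_mu'], pred['recon_sigma'])  # parameters have shape [L, batch_size, num_features]
    log_px = recon_dist.log_prob(x.unsqueeze(0))                # broadcast x over the L samples
    return log_px.sum(dim=-1).mean(dim=0)                       # [batch_size]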
Example #7
    def produce_action_and_action_info(self,
                                       state,
                                       return_stats: bool = False):
        """Given the state, produces an action, the log probability of the action, and the tanh of the mean action"""
        if return_stats:
            actor_output, actor_stats = self.actor_local(state, return_stats)
        else:
            actor_output = self.actor_local(state)

        mean = actor_output[:, :self.action_size]
        log_std = actor_output[:, self.action_size:]
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample(
        )  # rsample means it is sampled using reparameterisation trick
        action = torch.tanh(x_t)
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log(1 - action.pow(2) + EPSILON)
        log_prob = log_prob.sum(1, keepdim=True)

        if return_stats:
            actor_stats['action'] = {
                'mean': mean.detach().cpu().numpy(),
                'std': std.detach().cpu().numpy(),
            }
            return action, log_prob, torch.tanh(mean), actor_stats
        else:
            return action, log_prob, torch.tanh(mean)
Example #8
    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, ...]]:
        n = int(torch.randint(1, 5, (1, )).item())

        # perm = torch.argsort(torch.rand(x.size(0), x.size(0), device=x.device), dim=1)
        # subsets = perm[:, :n]

        d = ((x.unsqueeze(0) - x.unsqueeze(1))**2).sum(dim=2)
        subsets = torch.argsort(d, dim=1)[:, :n]

        # create the subset index which will be used to get the subsets distance from the pairwise distance
        # matrix. linspace is needed because it needs to be a tuple of (x, y) coordinates
        # fmt: off
        sub_idx = (
            torch.linspace(0,
                           x.size(0) - 1, x.size(0), device=x.device).repeat(
                               n, 1).T.flatten().long(),  # type: ignore
            subsets.flatten().long())
        # fmt: on

        # print(f"subsets: {subsets.size()} sub idx: {sub_idx.size()}")

        z = x[subsets]

        mu = z.mean(dim=1)
        mx, _ = z.max(dim=1)
        z = torch.cat((mu, mx), dim=1)

        z = self.decoder(z)
        dist = Normal(z[:, :self.in_dim], torch.exp(z[:, self.in_dim:] / 2))
        x = x + dist.rsample()

        return x, dist.entropy().mean(), sub_idx
Example #9
def normal_tanh_reparameterised_sample(dis: Normal,
                                       epsilon=1e-6
                                       ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    The log-likelihood here is for the TanhNorm distribution instead of only Gaussian distribution.
    The TanhNorm forces the Gaussian with infinite action range to be finite.

    For the three terms in this log-likelihood estimation:
     (1). the first term is the log probability of action as in common stochastic Gaussian action policy
     (without Tanh); \
    (2). the second term is the caused by the Tanh(), as shown in appendix C. Enforcing Action Bounds of
    https://arxiv.org/pdf/1801.01290.pdf, the epsilon is for preventing the negative cases in log


@param dis:
@param epsilon:
@return:
"""

    z = dis.rsample()  # for reparameterisation trick (mean + std * N(0,1))
    action = torch.tanh(z)
    log_prob = torch.sum(dis.log_prob(z) -
                         torch.log(1 - action.pow(2) + epsilon),
                         dim=-1,
                         keepdim=True)
    return action, log_prob
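
A small, self-contained usage sketch for the helper above; the shapes and values are illustrative, not taken from the original code:

import torch
from torch.distributions import Normal

# Batch of 32 states with a 4-dimensional action space (illustrative numbers)
mean = torch.zeros(32, 4)
log_std = torch.full((32, 4), -0.5)
dist = Normal(mean, log_std.exp())

action, log_prob = normal_tanh_reparameterised_sample(dist)
print(action.shape, log_prob.shape)  # torch.Size([32, 4]) torch.Size([32, 1])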
Example #10
class TanhNormal(Distribution):
  """Distribution of X ~ tanh(Z) where Z ~ N(mean, std)
  Adapted from https://github.com/vitchyr/rlkit
  """
  def __init__(self, normal_mean, normal_std, epsilon=1e-6):
    self.normal_mean = normal_mean
    self.normal_std = normal_std
    self.normal = Normal(normal_mean, normal_std)
    self.epsilon = epsilon
    super().__init__(self.normal.batch_shape, self.normal.event_shape)

  def log_prob(self, x):
    assert hasattr(x, "pre_tanh_value")
    assert x.dim() == 2 and x.pre_tanh_value.dim() == 2
    return self.normal.log_prob(x.pre_tanh_value) - torch.log(
      1 - x * x + self.epsilon
    )

  def sample(self, sample_shape=torch.Size()):
    z = self.normal.sample(sample_shape)
    out = torch.tanh(z)
    out.pre_tanh_value = z
    return out

  def rsample(self, sample_shape=torch.Size()):
    z = self.normal.rsample(sample_shape)
    out = torch.tanh(z)
    out.pre_tanh_value = z
    return out
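
Illustrative use of the TanhNormal wrapper above: rsample() keeps the pre-tanh value attached to the returned tensor, so log_prob() can apply the change-of-variables correction without inverting tanh. Shapes here are arbitrary:

import torch

dist = TanhNormal(torch.zeros(8, 2), torch.ones(8, 2))
a = dist.rsample()                # squashed sample in (-1, 1); carries a.pre_tanh_value
log_p = dist.log_prob(a).sum(-1)  # per-sample log-density of tanh(Z)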
Example #11
    def forward(self, x):
        # Reshape data for net
        if len(x.shape)==4:
            batch, chan, h, w = x.shape
            x = x.view(batch,chan,h*w).squeeze(1)

        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=self.LOG_SIG_MIN, max=self.LOG_SIG_MAX)

        std = log_std.exp()
        normal = Normal(mean, std)
        delta = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        log_prob = normal.log_prob(delta)

        if self.decode:
            delta = self.linear(delta)
            # Problem here is log_prob can't be for delta unless same size...

        # Enforcing Action Bound
        # log_prob -= torch.log(1 - action.pow(2) + self.epsilon)
        # log_prob = log_prob.sum(-1, keepdim=True)


        # Shape noise to match original data
        delta = delta.unsqueeze(1)
        delta = delta.view(batch,chan,h,w)

        return delta, mean, log_std, log_prob
Example #12
    def forward(self, x, reparam=True):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        action = self.max_action * torch.tanh(self.l3(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        std = log_std.exp()

        normal = Normal(mean, std)

        if reparam == True:
            x_t = normal.rsample()
        else:
            x_t = normal.sample()

        log_prob = normal.log_prob(x_t)

        log_prob -= torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        entropy = normal.entropy()
        dist_entropy = entropy.sum(-1).mean()

        return action, dist_entropy, mean, log_std, log_prob
Example #13
    def _sample_z(self, mean, var, name, cell_coord):
        '''
        Performs the sampling step in VAE and stores the distribution for KL computation
        :param mean:
        :param var:
        :param name: name of the distribution
        :return: sampled value
        '''
        dist = Normal(loc=mean, scale=var)

        if name not in self.dist_param.keys():
            _, H, W = self.feature_space_dim
            self.dist_param[name] = {}
            self.dist_param[name]['mean'] = torch.empty(
                self.batch_size,
                mean.shape[-1],
                H,
                W,
            ).to(self.device)
            self.dist_param[name]['sigma'] = torch.empty(
                self.batch_size,
                mean.shape[-1],
                H,
                W,
            ).to(self.device)
        x, y = cell_coord
        self.dist_param[name]['mean'][:, :, x, y] = mean
        self.dist_param[name]['sigma'][:, :, x, y] = var

        return dist.rsample()
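
A hedged sketch (names assumed from the example above, not part of the original class) of how the stored dist_param entries could later be turned into a KL term against a unit-Gaussian prior, once every cell has written its mean and sigma:

import torch
from torch.distributions import Normal, kl_divergence

def kl_to_standard_normal(dist_param: dict) -> torch.Tensor:
    total = torch.zeros(())
    for name, params in dist_param.items():
        q = Normal(params['mean'], params['sigma'])
        p = Normal(torch.zeros_like(params['mean']), torch.ones_like(params['sigma']))
        # sum KL over latent and spatial dimensions, average over the batch
        total = total + kl_divergence(q, p).flatten(1).sum(-1).mean()
    return total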
Example #14
    def forward(self, x):
        # Sample the weights and forward it
        # perform all operations in the forward rather than in __init__ (including log[1+exp(rho)])
        variational_posterior = Normal(self.mu,
                                       torch.log1p(torch.exp(self.rho)))
        variational_posterior_bias = Normal(
            self.mu_bias, torch.log1p(torch.exp(self.rho_bias)))
        w = variational_posterior.rsample()
        b = variational_posterior_bias.rsample()

        # Get the log prob
        self.log_variational_posterior = (variational_posterior.log_prob(
            w)).sum() + (variational_posterior_bias.log_prob(b)).sum()
        self.log_prior = self.prior_weights.log_prob(
            w).sum() + self.prior_bias.log_prob(b).sum()
        return F.linear(x, w, b)
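
A hedged sketch of how the two log-probabilities stored by this layer are typically combined into a Bayes-by-Backprop style loss. The model here is assumed to expose log_variational_posterior and log_prior summed over all of its Bayesian layers, which the snippet does not show:

import torch
import torch.nn.functional as F

def bayes_by_backprop_loss(model, x, y, num_batches: int) -> torch.Tensor:
    # The forward pass samples fresh weights and refreshes the stored log-probabilities
    logits = model(x)
    nll = F.cross_entropy(logits, y, reduction='sum')
    # Minibatch-weighted complexity cost: E_q[log q(w) - log p(w)]
    kl = model.log_variational_posterior - model.log_prior
    return kl / num_batches + nll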
Example #15
    def sample(self, state):
        """ 
        Samples actions and log actions from the distribution

        Arguments:
        state : State vector containing state variables
        """
        epsilon = 1e-6
        mu, log_std = self.forward(state)
        std = log_std.exp()

        #Gaussian distribution with mu and std from the network
        normal = Normal(mu, std)

        #Action is sampled from the distribution
        z = normal.rsample()
        #Tanh sqeezes the action between (-1,1)
        action = torch.tanh(z)
        """
        Log probability for the action is calculated using log-likelihood formula
        Refer equation 21 : https://arxiv.org/pdf/1801.01290.pdf
        """
        log_prob = (normal.log_prob(z) -
                    torch.log(1 - (torch.tanh(z)).pow(2) + epsilon))
        log_prob = log_prob.sum(1, keepdim=True)

        return action, log_prob
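
Equation 21 cited in the docstring (appendix C of https://arxiv.org/pdf/1801.01290.pdf) is the identity implemented by the two log_prob lines above, for the unscaled squashing a = tanh(u):

\log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i=1}^{D} \log\big(1 - \tanh^2(u_i)\big)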
Example #16
 def forward(self, x):
     logN = torch.log(x.sum(axis=-1)).view(-1, 1)
     varz = torch.stack([self.variational_logvars] * len(logN))
     varz = torch.cat((varz, logN), dim=1)
     z_var = self.sigma_net(varz)
     z_mean = self.encode(x)
     qz = Normal(z_mean, torch.exp(0.5 * z_var))
     ql = Normal(0, torch.exp(0.5 * self.log_sigma_sq))
     z_sample = qz.rsample()
     l_sample = ql.rsample()
     x_out = self.decoder(z_sample) + l_sample
     kl_div = kl_divergence(qz, Normal(0, 1)).mean(0).sum()
     recon_loss = self.recon_model_loglik(x, x_out).mean(0).sum()
     elbo = recon_loss - kl_div
     loss = - elbo
     return loss
Example #17
    def get_z_sup_sample(self, zp_mean, zp_std):
        """Get z sample and log_lik of sample from state code.

        Args:
            zp_mean, zp_std (torch.Tensor), 2 * (nTo, 4): State dist. parameters.

        Returns:
            z_obj (torch.Tensor), (nTo, 4): Sampled states.
            log_q_xz (torch.Tensor), (nTo): Likelihood of samples for ELBO.

        """
        # get z from sampling, each gaussian has dim (4)
        # we need n4o samples per gaussian. dim of sampling is again n4o, 4
        z_dist = Normal(zp_mean, zp_std)

        # rsample can propagate gradients, no explicit reparametrization
        z_tmp = z_dist.rsample().to(self.c.device)
        # Get log lik of sample
        # approximated E_q(z|x) [log q(z|x)] with single sample
        # Sum (in log-domain) the probabilities for the z's [per image]
        # sum (n4o, 4) to n4o
        log_q_xz = z_dist.log_prob(z_tmp).sum(-1)

        # Get sy from sy/sx sy.
        z_obj = self.sy_from_quotient(z_tmp)

        return z_obj, log_q_xz
Example #18
    def sample_conditional_a(self, resid_image, var_so_far, pixel_1d):

        is_on = (pixel_1d < (self.n_discrete_latent - 1)).float()

        # pass through galaxy encoder
        pixel_2d = self.one_galaxy_vae.pixel_1d_to_2d(pixel_1d)
        z_mean, z_var = self.one_galaxy_vae.enc(resid_image, pixel_2d)

        # sample z
        q_z = Normal(z_mean, z_var.sqrt())
        z_sample = q_z.rsample()

        # kl term for continuous latent vars
        log_q_z = q_z.log_prob(z_sample).sum(1)
        p_z = Normal(torch.zeros_like(z_sample), torch.ones_like(z_sample))
        log_p_z = p_z.log_prob(z_sample).sum(1)
        kl_z = is_on * (log_q_z - log_p_z)

        # run through decoder
        recon_mean, recon_var = self.one_galaxy_vae.dec(is_on, pixel_2d, z_sample)

        # NOTE: we will have to add the recon means once we do more detections
        # recon_means = recon_mean + image_so_far
        # recon_vars = recon_var + var_so_far

        return recon_mean, recon_var, is_on, kl_z
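
The kl_z above is the standard single-sample Monte Carlo estimate of the KL divergence between the variational posterior and the standard-normal prior, gated by is_on:

\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big) = \mathbb{E}_{z \sim q}\big[\log q(z \mid x) - \log p(z)\big] \approx \log q(\tilde z \mid x) - \log p(\tilde z), \quad \tilde z \sim q(z \mid x)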
Example #19
    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Forward method implementation."""
        x = self.relu(self.bn1(self.conv1(state)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))

        x = self.lrelu(self.fc1(x.view(x.size(0), -1)))
        x = self.lrelu(self.fc2(x))
        x = self.lrelu(self.fc3(x))

        # get mean
        mu = self.mu_layer(x).tanh()

        # get std
        log_std = self.log_std_layer(x).tanh()
        log_std = self.log_std_min + 0.5 * (self.log_std_max -
                                            self.log_std_min) * (log_std + 1)
        std = torch.exp(log_std)

        # sample actions
        dist = Normal(mu, std)
        z = dist.rsample()

        # normalize action and log_prob
        # see appendix C of [2]
        action = z.tanh()
        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-7)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob
Example #20
    def sample_action(self,
                      state: np.ndarray,
                      deterministic: bool = False) -> np.ndarray:
        """
        sample action normal distribution parameterized by policy network

        :param state: Observation state
        :param deterministic: Is the greedy action being chosen?
        :type state: int, float, ...
        :type deterministic: bool
        :returns: action
        :returns: log likelihood of policy
        :returns: scaled mean of normal distribution
        :rtype: int, float, ...
        :rtype: float
        :rtype: float
        """
        mean, log_std = self.policy.forward(state)
        std = log_std.exp()

        # reparameterization trick
        distribution = Normal(mean, std)
        xi = distribution.rsample()
        yi = torch.tanh(xi)
        action = yi * self.action_scale + self.action_bias
        log_pi = distribution.log_prob(xi)

        # enforcing action bound (appendix of paper)
        log_pi -= torch.log(self.action_scale * (1 - yi.pow(2)) +
                            np.finfo(np.float32).eps)
        log_pi = log_pi.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action.float(), log_pi, mean
Example #21
 def reparameters(self, means, stds):
     distribution = Normal(means, stds)
     actions = distribution.rsample()
     news_actions = torch.tanh(actions)
     log_probs = distribution.log_prob(actions.detach()).sum(dim=1,
                                                              keepdim=True)
     return news_actions, log_probs
Example #22
    def forward(self, x, n_samples, squeeze=True, reparam=True):
        q = self.encoder(x)
        q_m = self.mean_encoder(q)
        q_v = self.var_encoder(q)

        # q_v = 16.0 * self.tanh(q_v)
        # q_v = torch.clamp(q_v, min=-17., max=14.)

        # PREVIOUS TO KEEP
        # q_m = torch.clamp(q_m, min=-1000, max=1000)

        # q_v = torch.clamp(q_v, min=-17.0, max=8.0)
        q_v = torch.clamp(q_v, min=-17.0, max=10.0)
        q_v = q_v.exp()
        # q_v = 1e-16 + q_v.exp()

        variational_dist = Normal(loc=q_m, scale=q_v.sqrt())

        if n_samples == 1 and squeeze:
            sample_shape = []
        else:
            sample_shape = (n_samples, )
        if reparam:
            latent = variational_dist.rsample(sample_shape=sample_shape)
        else:
            latent = variational_dist.sample(sample_shape=sample_shape)
        return dict(q_m=q_m, q_v=q_v, latent=latent)
Example #23
 def forward(self, input: Tensor) -> Tensor:
     params_ = self.conv(input)
     mu  = params_[..., :128]
     std = params_[..., 128:]
     n = Normal(mu, std)
     z = n.rsample()  # latent variable
     return z
Example #24
    def forward(self, state):
        """
        Given states input [batch, state_dim],
        """
        state = state.to(self.device)
        # x = F.relu(self.linear1(state))
        # x = F.relu(self.linear2(x))
        x = self.linear(state)

        mean = self.mean_linear(x)
        # return torch.tanh(mean), 0

        log_std = self.log_std_linear(x)
        # log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        std = torch.clamp(std, self.std_min, self.std_max)

        normal = Normal(mean, std)
        z = normal.rsample()
        a = torch.tanh(z)

        # compute log probability
        log_pi = normal.log_prob(z) - torch.log(1 - a.pow(2) + 1e-6)

        return a, log_pi
Example #25
    def get_action(self, state: torch.Tensor, deterministic: bool = False):
        state = torch.as_tensor(state).float()

        if self.actor.sac:
            mean, log_std = self.actor(state)
            std = log_std.exp()
            distribution = Normal(mean, std)

            action_probs = distribution.rsample()
            log_probs = distribution.log_prob(action_probs)
            action_probs = torch.tanh(action_probs)

            action = action_probs * self.action_scale + self.action_bias

            # enforcing action bound (appendix of SAC paper)
            log_probs -= torch.log(
                self.action_scale * (1 - action_probs.pow(2)) + np.finfo(np.float32).eps
            )
            log_probs = log_probs.sum(1, keepdim=True)
            mean = torch.tanh(mean) * self.action_scale + self.action_bias

            action = (action.float(), log_probs, mean)
        else:
            action = self.actor.get_action(state, deterministic=deterministic)

        return action
Example #26
    def sample(self, state):
        '''
        :param state: (batch_num, state_dim)
        :return: action: (batch_num, action_dim, option_num)
        log_prob: (batch_num, option_num)
        mean_mat: (batch_num, action_dim, option_num)
        '''
        mean_mat, log_std_mat = self.forward(state)
        std_mat = log_std_mat.exp()
        normal = Normal(mean_mat, std_mat)
        x_t = normal.rsample(
        )  # for reparameterization trick (mean + std * N(0,1))

        # print('x_t', x_t.shape)

        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)  # log(pi(at|st))

        # Enforcing Action Bound, because the Gaussian distribution changes from (-inf, inf) to (-1, 1)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)

        # print('log_prob', log_prob.shape)

        log_prob = log_prob.sum(1, keepdim=True)
        # print('log_prob_sum', log_prob.shape)

        mean_mat = torch.tanh(mean_mat) * self.action_scale + self.action_bias

        return action, log_prob, mean_mat
Example #27
    def forward(self, x, a = None):
        x = self.net(x)
        mu = self.mu(x)
        log_sigma = self.log_sigma(x)
        """ Note from Josh Achiam @ OpenAI
        Because algorithm maximizes trade-off of reward and entropy,
        entropy must be unique to state---and therefore log_stds need
        to be a neural network output instead of a shared-across-states
        learnable parameter vector. But for deep Relu and other nets,
        simply sticking an activationless dense layer at the end would
        be quite bad---at the beginning of training, a randomly initialized
        net could produce extremely large values for the log_stds, which
        would result in some actions being either entirely deterministic
        or too random to come back to earth. Either of these introduces
        numerical instability which could break the algorithm. To
        protect against that, we'll constrain the output range of the
        log_stds, to lie within [LOG_STD_MIN, LOG_STD_MAX]. This is
        slightly different from the trick used by the original authors of
        SAC---they used tf.clip_by_value instead of squashing and rescaling.
        I prefer this approach because it allows gradient propagation
        through log_std where clipping wouldn't, but I don't know if
        it makes much of a difference.
        """
        log_sigma = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_sigma + 1)
        sigma = torch.exp(log_sigma)
        dist = Normal(mu, sigma)
        # rsample() - https://pytorch.org/docs/stable/distributions.html#pathwise-derivative
        pi = dist.rsample() # reparametrization
        logp_pi = dist.log_prob(pi).sum(dim=1)

        mu *= self.act_limit
        pi *= self.act_limit
        mu, pi, logp_pi = apply_squashing_func(mu, pi, logp_pi)
        return mu, pi, logp_pi
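
A minimal sketch contrasting the two ways of bounding log_std that the note above discusses; the bounds and the explicit tanh are assumptions for illustration (several of the other snippets apply the tanh inside the log_std head instead):

import torch

LOG_STD_MIN, LOG_STD_MAX = -20.0, 2.0        # illustrative bounds

raw = torch.randn(5, 3, requires_grad=True)  # stand-in for the raw log_std head output

# (a) squash and rescale, as in the snippet above: smooth, gradients flow at every input
log_std_squashed = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (torch.tanh(raw) + 1)

# (b) hard clipping, as in the original SAC implementation: zero gradient outside the bounds
log_std_clipped = torch.clamp(raw, LOG_STD_MIN, LOG_STD_MAX)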
Example #28
    def noisy_action(self, state, return_only_action=True):

        if self.policy_type == 'GaussianPolicy':
            mean, log_std = self.clean_action(state, return_only_action=False)
            std = log_std.exp()
            normal = Normal(mean, std)
            x_t = normal.rsample(
            )  # for reparameterization trick (mean + std * N(0,1))
            action = torch.tanh(x_t)

            if return_only_action: return action

            log_prob = normal.log_prob(x_t)
            # Enforcing Action Bound
            log_prob -= torch.log(1 - action.pow(2) + epsilon)
            log_prob = log_prob.sum(-1, keepdim=True)

            #log_prob.clamp(-10, 0)

            return action, log_prob, x_t, mean, log_std

        elif self.policy_type == 'DeterministicPolicy':
            mean = self.clean_action(state)
            action = mean + self.noise.normal_(0., std=0.4)

            if return_only_action: return action
            else:
                return action, torch.tensor(0.), torch.tensor(
                    0.), mean, torch.tensor(0.)
Example #29
    def forward(self, x):
        x = super(ReparamGaussianPolicy, self).forward(x)

        mu = self.mu_layer(x)
        log_std = torch.tanh(self.log_std_layer(x))
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std +
                                                                     1)
        std = torch.exp(log_std)

        # https://pytorch.org/docs/stable/distributions.html#normal
        dist = Normal(mu, std)
        pi = dist.rsample()  # reparameterization trick (mean + std * N(0,1))
        log_pi = dist.log_prob(pi).sum(dim=-1)
        mu, pi, log_pi = self.apply_squashing_func(mu, pi, log_pi)

        if self.log_type == 'log':
            # make sure actions are in correct range
            mu = mu * self.action_scale
            pi = pi * self.action_scale
            return mu, pi, log_pi
        elif self.log_type == 'log-q':
            if self.q == 1.:
                log_q_pi = log_pi
            else:
                exp_log_pi = torch.exp(log_pi)
                log_q_pi = self.tsallis_entropy_log_q(exp_log_pi, self.q)
            # make sure actions are in correct range
            mu = mu * self.action_scale
            pi = pi * self.action_scale
            return mu, pi, log_q_pi
Example #30
    def select_action(self, state, deterministic=False):
        """
        Compute an action or vector of actions given a state or vector of states
        :param state: the input state(s)
        :param deterministic: whether the policy should be considered deterministic or not
        :return: the resulting action(s)
        """
        with torch.no_grad():
            # Forward pass
            mu, std = self.forward(state)
            pi_distribution = Normal(mu, std)

            if deterministic:
                # Only used for evaluating policy at test time.
                pi_action = mu
            else:
                pi_action = pi_distribution.rsample()

            # Finally applies tanh for squashing
            #If env is Pendulum:
            pi_action = 2 * torch.tanh(pi_action)
            # pi_action = torch.tanh(pi_action)
            if len(pi_action) == 1:
                pi_action = pi_action[0]
            return pi_action.data.numpy().astype(float)
Example #31
    def forward(self, sample_shape=()):
        if isinstance(self.prior, FactorisedPrior):
            sqrt_prec = 1. / math.sqrt(self.in_features)
            post_mean = self.post_mean * sqrt_prec
            post_log_var = self.post_log_var_scaled * self.log_var_lr + 2. * math.log(
                sqrt_prec)

            prior_prec = self.prior(1)

            KL_term = 0.5*((post_mean**2).sum() + post_log_var.exp().sum())*prior_prec.scale -\
                0.5*post_mean.numel() - 0.5*post_mean.numel()*t.log(prior_prec.scale) - 0.5*post_log_var.sum()

            self.logpq = -KL_term * t.ones(*sample_shape,
                                           device=KL_term.device)
            return post_mean, post_log_var.exp()
        else:
            post_log_var = self.post_log_var_scaled * self.log_var_lr
            sqrt_prec = 1. / math.sqrt(self.in_features)
            post_mean = self.post_mean * sqrt_prec
            Qw = Normal(post_mean, sqrt_prec * t.exp(0.5 * post_log_var))

            w = Qw.rsample(sample_shape=t.Size([sample_shape[0]]))
            prior_prec = self.prior(sample_shape[0])
            logP = mvnormal_log_prob(prior_prec, w.transpose(-1, -2))
            logQ = Qw.log_prob(w).sum((-1, -2))
            self.logpq = logP - logQ
            return post_mean, post_log_var.exp()
Example #32
 def forward(self, x_src):
     # Example variational parameters lambda
     mu, logvar = self.encoder(x_src)
     q_normal = Normal(loc=mu, scale=logvar.mul(0.5).exp())
     # Reparameterized sample.
     z_sample = q_normal.rsample()
     # z_sample = mu (no sampling)
     return self.decoder(z_sample), q_normal
Example #33
# TODO: to make this stochastic, shuffle and make smaller batches.
start = time.time()
theta.train()
for epoch in range(args.num_epochs*2):
    # Keep track of reconstruction loss and total kl
    total_recon_loss = 0
    total_kl = 0
    total = 0
    for img, _ in loader:
        # no need to Variable(img).cuda()
        optim1.zero_grad()
        optim2.zero_grad()
        q = Normal(loc=mu, scale=logvar.mul(0.5).exp())
        # Reparameterized sample.
        qsamp = q.rsample()
        kl = kl_divergence(q, p).sum() # KL term
        out = theta(qsamp)
        recon_loss = criterion(out, img) # reconstruction term
        loss = (recon_loss + args.alpha * kl) / args.batch_size
        total_recon_loss += recon_loss.item() / args.batch_size
        total_kl += kl.item() / args.batch_size
        total += 1
        loss.backward()
        if args.clip:
            torch.nn.utils.clip_grad_norm_(theta.parameters(), args.clip)
            torch.nn.utils.clip_grad_norm_(mu, args.clip)
            torch.nn.utils.clip_grad_norm_(theta.parameters(), args.clip)
        if epoch % 2:
            optim1.step()
            wv = 'Theta'