Example 1
class TanhNormal(Distribution):
    """Copied from Kaixhi"""
    def __init__(self, loc, scale):
        super().__init__()
        self.normal = Independent(Normal(loc, scale), 1)

    def sample(self):
        return torch.tanh(self.normal.sample())

    # samples with re-parametrization trick (differentiable)
    def rsample(self):
        return torch.tanh(self.normal.rsample())

    # Calculates log probability of value using the change-of-variables technique
    # (uses log1p = log(1 + x) for extra numerical stability)
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|
        return self.normal.log_prob(inv_value) - torch.log1p(-value.pow(2) +
                                                             1e-6).sum(dim=1)

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)

    def get_std(self):
        return self.normal.stddev
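
A minimal usage sketch for the TanhNormal class above (not part of the original snippet; the import lines and the loc/scale shapes are assumptions):

import torch
from torch.distributions import Normal, Independent

loc = torch.zeros(4, 2)            # batch of 4, action dimension 2
scale = 0.5 * torch.ones(4, 2)

dist = TanhNormal(loc, scale)
action = dist.rsample()            # differentiable sample, squashed into (-1, 1)
logp = dist.log_prob(action)       # shape (4,): one joint log-density per batch element

# change-of-variables check: log p(a) = log N(artanh(a)) - sum log(1 - a^2)
manual = (Independent(Normal(loc, scale), 1).log_prob(torch.atanh(action))
          - torch.log1p(-action.pow(2) + 1e-6).sum(dim=1))
print(torch.allclose(logp, manual))   # True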
Example 2
 def get_loss(self, output, target, ignore_label=None):
     loss = super().get_loss(output, target, ignore_label)
     eps = 1e-12
     # construct a zero-mean diagonal Gaussian (scale self.sigma) with the same shape as the logits
     normal = Independent(
         Normal(loc=torch.FloatTensor(
             output['logits'].size()).fill_(0).to(device),
                scale=torch.FloatTensor(output['logits'].size()).fill_(
                    self.sigma).to(device)), 1)
     # sum over samples of softmax(logits + sigma * eps), i.e. the softmax of the logits
     # distorted by the predicted per-voxel scales; the log is taken further below
     sum_distorted_softmax = torch.sum(torch.stack([
         self.softmax(output['logits'] +
                      (output['sigma'] * normal.sample()))
         for _ in torch.arange(self.samples)
     ]),
                                       dim=0)
     # sum_distorted_softmax should have shape [batch, nclasses, x, y]
     one_hot = torch.zeros(output['logits'].shape).scatter_(
         1, target.unsqueeze(1), 1)
     # mask sum_distorted_softmax in order to obtain only the softmax probs for the gt class and take max
     # of the result, which will just select the prob of the gt class (reduce dim 1=nclasses)
     sum_distorted_softmax, _ = torch.max(sum_distorted_softmax * one_hot,
                                          1)
     # sum_distorted_softmax should now have shape [batch, x, y]
     # finally compute the categorical aleatoric loss
     aleatoric_loss = -0.0001 * torch.mean(torch.sum(
         torch.log(sum_distorted_softmax + eps) - np.log(self.samples),
         dim=(1, 2)),
                                           dim=0)
     output['sigma'] = output['sigma'].cpu().detach().numpy()
     output['logits'] = None
     self.current_aleatoric_loss = aleatoric_loss.detach().cpu().numpy()
     return loss + aleatoric_loss
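
A standalone sketch of the same Monte Carlo aleatoric term (shapes and the sample count are illustrative, not from the original class; gather replaces the one-hot/max trick for selecting the ground-truth class probability):

import math
import torch
from torch.distributions import Normal, Independent

logits = torch.randn(2, 5, 8, 8)                 # [batch, nclasses, x, y]
sigma = 0.1 * torch.rand_like(logits)            # predicted per-voxel scales
target = torch.randint(0, 5, (2, 8, 8))
n_samples = 10

noise = Independent(Normal(torch.zeros_like(logits), torch.ones_like(logits)), 1)
sum_softmax = torch.stack([
    torch.softmax(logits + sigma * noise.sample(), dim=1)
    for _ in range(n_samples)
]).sum(dim=0)
gt_probs = sum_softmax.gather(1, target.unsqueeze(1)).squeeze(1)   # [batch, x, y]
aleatoric_loss = -(torch.log(gt_probs + 1e-12)
                   - math.log(n_samples)).sum(dim=(1, 2)).mean()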
Example 3
class IndependentNormal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        self.base_dist = Independent(Normal(loc=loc,
                                            scale=scale,
                                            validate_args=validate_args),
                                     len(loc.shape) - 1,
                                     validate_args=validate_args)
        super(IndependentNormal, self).__init__(self.base_dist.batch_shape,
                                                self.base_dist.event_shape,
                                                validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
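
Shape-semantics sketch for the wrapper above (assumed tensors; validate_args=False sidesteps default argument validation, which on some PyTorch versions would look for loc/scale attributes the wrapper does not store): reinterpreting len(loc.shape) - 1 dims keeps only the leading dim as batch and folds the rest into the event.

import torch

loc = torch.zeros(3, 4, 5)
scale = torch.ones(3, 4, 5)

d = IndependentNormal(loc, scale, validate_args=False)
print(d.batch_shape, d.event_shape)     # torch.Size([3]) torch.Size([4, 5])
print(d.sample().shape)                 # torch.Size([3, 4, 5])
print(d.log_prob(d.sample()).shape)     # torch.Size([3])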
Example 4
class MeanField(BaseApproximation):
    def __init__(self):
        """
        Implements a mean field approximation of the state space
        """

        self._mean = None
        self._logstd = None

        self._sampledist = None     # type: Independent

    def entropy(self):
        return Independent(Normal(self._mean, self._logstd.exp()), 2).entropy()

    def initialize(self, data, ndim):
        self._mean = torch.zeros((data.shape[0] + 1, ndim), requires_grad=True)
        self._logstd = torch.zeros_like(self._mean, requires_grad=True)

        # ===== Start optimization ===== #
        self._sampledist = Independent(Normal(torch.zeros_like(self._mean), torch.ones_like(self._logstd)), 2)

        return self

    def get_parameters(self):
        return [self._mean, self._logstd]

    def sample(self, num_samples):
        samples = (num_samples,) if isinstance(num_samples, int) else num_samples
        return self._mean + self._logstd.exp() * self._sampledist.sample(samples)
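
A standalone sketch of the same reparameterized draw (independent of BaseApproximation; shapes and the toy objective are illustrative): because the randomness comes from a fixed N(0, 1) base distribution, gradients reach both variational parameters.

import torch
from torch.distributions import Normal, Independent

mean = torch.zeros(11, 3, requires_grad=True)
logstd = torch.zeros(11, 3, requires_grad=True)
base = Independent(Normal(torch.zeros_like(mean), torch.ones_like(logstd)), 2)

samples = mean + logstd.exp() * base.sample((100,))       # 100 x 11 x 3
entropy = Independent(Normal(mean, logstd.exp()), 2).entropy()
loss = samples.pow(2).mean() - entropy                    # toy objective
loss.backward()
print(mean.grad.shape, logstd.grad.shape)                 # both torch.Size([11, 3])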
Example 5
    def sample_trajectories(self, init_std=1.0, min_std=1e-6, output_size=2):
        # Sample batch_size complete trajectories using the current policy
        observations = self.envs.reset()
        with torch.no_grad():
            while not self.envs.dones.all():
                observations_tensor = torch.from_numpy(observations)
                """
                ******************************************************************
                """
                output = self.policy(observations_tensor)

                min_log_std = math.log(min_std)
                sigma = nn.Parameter(torch.Tensor(output_size))
                sigma.data.fill_(math.log(init_std))

                scale = torch.exp(torch.clamp(sigma, min=min_log_std))
                # loc is the mean of the Gaussian
                # scale is the standard deviation of the Gaussian
                p_normal = Independent(Normal(loc=output, scale=scale), 1)

                actions_tensor = p_normal.sample()
                actions = actions_tensor.cpu().numpy()

                # pi = policy(observations_tensor)
                # actions_tensor = pi.sample()
                # actions = actions_tensor.cpu().numpy()

                new_observations, rewards, _, infos = self.envs.step(actions)
                batch_ids = infos['batch_ids']
                yield (observations, actions, rewards, batch_ids)
                observations = new_observations
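
A sketch of the pattern this loop seems to intend (hypothetical policy head, not the original code): keep a single state-independent log-std parameter instead of re-creating an nn.Parameter at every step.

import math
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent

class GaussianHead(nn.Module):
    def __init__(self, obs_dim, act_dim, init_std=1.0, min_std=1e-6):
        super().__init__()
        self.mu = nn.Linear(obs_dim, act_dim)
        self.log_std = nn.Parameter(torch.full((act_dim,), math.log(init_std)))
        self.min_log_std = math.log(min_std)

    def forward(self, obs):
        scale = torch.exp(torch.clamp(self.log_std, min=self.min_log_std))
        return Independent(Normal(self.mu(obs), scale), 1)

head = GaussianHead(obs_dim=6, act_dim=2)
dist = head(torch.randn(5, 6))
a = dist.sample()
print(a.shape, dist.log_prob(a).shape)   # torch.Size([5, 2]) torch.Size([5])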
Example 6
 def choose_action(self, observation):
     mu, sigma  = self.actor.forward(observation)#.to(self.actor.device)
     sigma = T.exp(sigma)
     action_probs = Independent(Normal(mu, sigma),1)
     probs = action_probs.sample()
     self.log_probs = action_probs.log_prob(probs).to(self.actor.device)
     return probs
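
A sketch of how the stored log-probability is typically used afterwards (hypothetical REINFORCE-style update; the return value G is illustrative):

import torch
from torch.distributions import Normal, Independent

mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)

dist = Independent(Normal(mu, torch.exp(log_sigma)), 1)
action = dist.sample()
log_prob = dist.log_prob(action)      # scalar: joint log-density of the action vector

G = 1.5                               # illustrative return
loss = -log_prob * G
loss.backward()
print(mu.grad, log_sigma.grad)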
Example 7
def prop_state(x, f, g):
    bins = Independent(Binomial(x[:-1], f), 1)
    samp = bins.sample()

    s = x[0] - samp[..., 0]
    i = x[1] + samp[..., 0] - samp[..., 1]
    r = x[2] + samp[..., 1]

    return concater(s, i, r)
Example 8
def prop_state(x, beta, gamma, eta, dt):
    f = _f(x, beta, gamma, eta, dt)

    bins = Independent(Binomial(x[..., :-1], f), 1)
    samp = bins.sample()

    s = x[..., 0] - samp[..., 0]
    i = x[..., 1] + samp[..., 0] - samp[..., 1]
    r = x[..., 2] + samp[..., 1]

    return concater(s, i, r)
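
A minimal sketch of the stochastic SIR step in Examples 7 and 8 (illustrative values; _f is not shown above, so the transition probabilities here use a common 1 - exp(-rate * dt) discretization as an assumption):

import torch
from torch.distributions import Binomial, Independent

x = torch.tensor([[990., 10., 0.]])                                 # batch of 1: (S, I, R)
beta, gamma, dt = 0.3, 0.1, 1.0
p_si = 1 - torch.exp(-beta * x[..., 1] / x.sum(-1) * dt)            # infection probability
p_ir = (1 - torch.exp(torch.tensor(-gamma * dt))).expand_as(p_si)   # recovery probability
f = torch.stack([p_si, p_ir], dim=-1)                               # (..., 2)

bins = Independent(Binomial(x[..., :-1], f), 1)
samp = bins.sample()                                                # new infections, new recoveries
s = x[..., 0] - samp[..., 0]
i = x[..., 1] + samp[..., 0] - samp[..., 1]
r = x[..., 2] + samp[..., 1]
print(torch.stack([s, i, r], dim=-1))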
Example 9
def sample_xy_diag(mix_probs, means, scales):
    # get MVN means and scales
    mixtures = Categorical(mix_probs).sample() # (n,)
    means_sel = torch.stack([elt[i] for elt, i in zip(means, mixtures)]) # (n,d)
    scales_sel = torch.stack([elt[i] for elt, i in zip(scales, mixtures)]) # (n,d)

    # sample from MVNs
    norm = Normal(means_sel, scales_sel)
    mvn = Independent(norm, 1)
    samples = mvn.sample() # (n,d)

    return samples
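
The same draw can be done in one shot with torch.distributions.MixtureSameFamily (sketch with assumed shapes: mix_probs (n, k), means/scales (n, k, d)):

import torch
from torch.distributions import Categorical, Normal, Independent, MixtureSameFamily

n, k, d = 8, 3, 2
mix_probs = torch.softmax(torch.randn(n, k), dim=-1)
means, scales = torch.randn(n, k, d), torch.rand(n, k, d) + 0.1

gmm = MixtureSameFamily(Categorical(mix_probs), Independent(Normal(means, scales), 1))
samples = gmm.sample()   # (n, d), matching sample_xy_diag(mix_probs, means, scales)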
Example 10
 def test_independent_shape(self):
     for Dist, params in EXAMPLES:
         for param in params:
             base_dist = Dist(**param)
             x = base_dist.sample()
             base_log_prob_shape = base_dist.log_prob(x).shape
             for reinterpreted_batch_ndims in range(
                     len(base_dist.batch_shape) + 1):
                 indep_dist = Independent(base_dist,
                                          reinterpreted_batch_ndims)
                 indep_log_prob_shape = base_log_prob_shape[:len(
                     base_log_prob_shape) - reinterpreted_batch_ndims]
                 self.assertEqual(
                     indep_dist.log_prob(x).shape, indep_log_prob_shape)
                 self.assertEqual(indep_dist.sample().shape,
                                  base_dist.sample().shape)
                 self.assertEqual(indep_dist.has_rsample,
                                  base_dist.has_rsample)
                 if indep_dist.has_rsample:
                     self.assertEqual(indep_dist.sample().shape,
                                      base_dist.sample().shape)
                 try:
                     self.assertEqual(
                         indep_dist.enumerate_support().shape,
                         base_dist.enumerate_support().shape,
                     )
                     self.assertEqual(indep_dist.mean.shape,
                                      base_dist.mean.shape)
                 except NotImplementedError:
                     pass
                 try:
                     self.assertEqual(indep_dist.variance.shape,
                                      base_dist.variance.shape)
                 except NotImplementedError:
                     pass
                 try:
                     self.assertEqual(indep_dist.entropy().shape,
                                      indep_log_prob_shape)
                 except NotImplementedError:
                     pass
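
A concrete instance of the shape contract exercised by this test (sketch): each reinterpreted dim moves one trailing batch dim into the event, trimming log_prob's shape from the right.

import torch
from torch.distributions import Normal, Independent

base = Normal(torch.zeros(2, 3), torch.ones(2, 3))
x = base.sample()
for k in range(len(base.batch_shape) + 1):
    indep = Independent(base, k)
    print(k, indep.batch_shape, indep.event_shape, indep.log_prob(x).shape)
# 0 torch.Size([2, 3]) torch.Size([]) torch.Size([2, 3])
# 1 torch.Size([2]) torch.Size([3]) torch.Size([2])
# 2 torch.Size([]) torch.Size([2, 3]) torch.Size([])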
Example 11
    def get_action(self, obs):
        obs = torch.tensor(obs, dtype=torch.float).to(self.device)
        with torch.no_grad():
            mu, sigma = self.pi(obs)
            act_distribution = Independent(Normal(mu, sigma), 1)
            action = act_distribution.sample()

            log_prob = act_distribution.log_prob(action)
            val = self.V(obs)

        action = action.cpu().numpy()
        log_prob = log_prob.cpu().numpy()
        val = val.cpu().numpy()

        return action, log_prob, val
Example 12
class IndependentRescaledBeta(Distribution):
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive
    }
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.base_dist = Independent(RescaledBeta(concentration1,
                                                  concentration0,
                                                  validate_args),
                                     len(concentration1.shape) - 1,
                                     validate_args=validate_args)
        super(IndependentRescaledBeta,
              self).__init__(self.base_dist.batch_shape,
                             self.base_dist.event_shape,
                             validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
Example 13
class NormalApproximation(KernelDensityEstimate):
    def __init__(self, independent=True):
        super().__init__()
        self._dist = None  # type: torch.distributions.Distribution
        self._indep = independent
        self._shape = None

    def fit(self, x, w):
        self._shape = (x.shape[0], )

        if not self._indep:
            self._dist = _construct_mvn(x, w)
            return self

        mean = (w.unsqueeze(-1) * x).sum(0)
        var = robust_var(x, w, mean)

        self._dist = Independent(Normal(mean, var.sqrt()), 1)

        return self

    def sample(self, inds=None):
        return self._dist.sample(self._shape)
Example 14
def test_independent_normal() -> None:
    num_samples = 2000
    dim = 4

    loc = np.arange(0, dim) / float(dim)
    diag = np.arange(dim) / dim + 0.5
    Sigma = diag**2

    distr = Independent(
        Normal(loc=torch.Tensor(loc), scale=torch.Tensor(diag)), 1)

    assert np.allclose(
        distr.variance.numpy(), Sigma, atol=0.1, rtol=0.1
    ), f"did not match: sigma = {Sigma}, sigma_hat = {distr.variance.numpy()}"

    samples = distr.sample((num_samples, ))

    loc_hat, diag_hat = maximum_likelihood_estimate_sgd(
        NormalOutput(dim=dim),
        samples,
        learning_rate=0.01,
        num_epochs=10,
    )

    distr = Independent(
        Normal(loc=torch.Tensor(loc_hat), scale=torch.Tensor(diag_hat)), 1)

    Sigma_hat = distr.variance.numpy()

    assert np.allclose(
        loc_hat, loc, atol=0.2,
        rtol=0.1), f"mu did not match: loc = {loc}, loc_hat = {loc_hat}"

    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"
Example 15
    def forward(self, x, mean=False, z_q=None):
        blocks = []
        used_latents = []
        distributions = []
        if isinstance(mean, bool):
            mean = [mean] * self.num_latent_levels

        features = x
        for i, block in enumerate(self.res_layers):
            #print("Block",i,block)
            features = block(features)
            blocks.append(features)
            if i != self.num_levels - 1:
                features = self.Pool_layers[i](features)

        decoder_features = blocks[-1]
        #print(decoder_features.shape,1)

        for proba_level in range(self.num_latent_levels):
            #print(proba_level)
            latent_dim = self._latent_dims[proba_level]
            mu_log_sigma = self.probabilistic_block(decoder_features)
            #print(mu_log_sigma.shape,"mu logsigma shape")

            # mu_log_sigma = torch.squeeze(mu_log_sigma,dim=1)
            # print(mu_log_sigma.shape,"mu logsigma shape squeeze")
            # print(mu_log_sigma[Ellipsis,:latent_dim].shape,"mu  shape Ellipsis")
            # print(mu_log_sigma[Ellipsis,latent_dim:].shape,"logsigma shape Ellipsis")
            mu = mu_log_sigma[:, :latent_dim]
            #print("mu shape:",mu.shape)
            log_sigma = mu_log_sigma[:, latent_dim:]
            #print("Logsigma shape:",log_sigma.shape)

            # mu = mu_log_sigma[:,:latent_dim,...]
            # print("mu shape:",mu.shape)
            # log_sigma = mu_log_sigma[:,latent_dim:,...]
            # print("Logsigma shape:",log_sigma.shape)
            dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
            #dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)),0)

            distributions.append(dist)

            if z_q is not None:
                z = z_q[proba_level]
                #print(z.shape,"z_q")
            elif mean[proba_level]:
                z = dist.base_dist.loc
                #print(z.shape,"Proba level")
            else:
                z = dist.sample()
                #print(z.shape,"z shape")
            used_latents.append(z)
            # print(z.shape,"sample shape")
            decoder_output_lo = torch.cat([z, decoder_features], axis=1)
            # print(decoder_output_lo.shape,"decoder_lo")

            decoder_output_hi = self.interpolate(decoder_output_lo)
            # print(decoder_output_hi.shape,"decoder_hi")
            # print(blocks[::-1][proba_level + 1].shape,"block")
            decoder_features = torch.cat(
                [decoder_output_hi, blocks[::-1][proba_level + 1]], axis=1)
            # print(decoder_features.shape)

            decoder_features = self.decoder_layers[proba_level](
                decoder_features)

        #print('decoder features {}'.format(decoder_features.shape))

        return {
            'decoder_features': decoder_features,
            'encoder_features': blocks,
            'distributions': distributions,
            'used_latents': used_latents
        }
Example 16
class UNetVAEGenerator(GeneralVAE):
    def __init__(self,
                 imsize,
                 n_channels_in,
                 n_channels_out,
                 n_hidden,
                 z_dim,
                 device="cpu",
                 **kwargs):
        super(UNetVAEGenerator, self).__init__(n_channels_in, n_channels_out,
                                               device, **kwargs)

        self.z_dim = z_dim

        hidden_dims = [n_hidden, n_hidden * 2, n_hidden * 4, n_hidden * 8]

        # embedder
        self.enc1 = nn.Sequential(
            nn.Conv2d(n_channels_in,
                      hidden_dims[0],
                      kernel_size=3,
                      stride=2,
                      padding=1), nn.BatchNorm2d(hidden_dims[0]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc2 = nn.Sequential(
            nn.Conv2d(hidden_dims[0],
                      hidden_dims[1],
                      kernel_size=3,
                      stride=2,
                      padding=1), nn.BatchNorm2d(hidden_dims[1]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc3 = nn.Sequential(
            nn.Conv2d(hidden_dims[1],
                      hidden_dims[2],
                      kernel_size=3,
                      stride=2,
                      padding=1), nn.BatchNorm2d(hidden_dims[2]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc4 = nn.Sequential(
            nn.Conv2d(hidden_dims[2],
                      hidden_dims[3],
                      kernel_size=3,
                      stride=2,
                      padding=1), nn.BatchNorm2d(hidden_dims[3]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))

        enc_imsize = (1 + (imsize[0] - 1) // (2**4),
                      1 + (imsize[1] - 1) // (2**4))

        self.mu = nn.Sequential(
            Flatten(),
            nn.Linear(hidden_dims[3] * enc_imsize[0] * enc_imsize[1],
                      z_dim))  # n_channels depends on img resolution
        self.logvar = nn.Sequential(
            Flatten(),
            nn.Linear(hidden_dims[3] * enc_imsize[0] * enc_imsize[1], z_dim))

        self.project_z = nn.Sequential(
            nn.Linear(z_dim, hidden_dims[3] * enc_imsize[0] * enc_imsize[1]),
            UnFlatten(n_channels=hidden_dims[3], im_size=enc_imsize),
            nn.Conv2d(hidden_dims[3], hidden_dims[3], kernel_size=2,
                      padding=2), nn.BatchNorm2d(hidden_dims[3]),
            nn.LeakyReLU(0.2))

        self.dec0 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[3],
                               hidden_dims[2],
                               kernel_size=2,
                               stride=2), nn.BatchNorm2d(hidden_dims[2]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[3],
                               hidden_dims[1],
                               kernel_size=2,
                               stride=2), nn.BatchNorm2d(hidden_dims[1]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[2],
                               hidden_dims[0],
                               kernel_size=2,
                               stride=2), nn.BatchNorm2d(hidden_dims[0]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[1],
                               hidden_dims[0],
                               kernel_size=2,
                               stride=2), nn.BatchNorm2d(hidden_dims[0]),
            nn.LeakyReLU(0.2), nn.Dropout(0.1))

        self.zres1 = Noise_injector(hidden_dims[1],
                                    z_dim,
                                    n_channels_in,
                                    hidden_dims[1],
                                    device=device).to(device)
        self.zres2 = Noise_injector(hidden_dims[0],
                                    z_dim,
                                    n_channels_in,
                                    hidden_dims[0],
                                    device=device).to(device)
        self.out = Noise_injector(hidden_dims[0],
                                  z_dim,
                                  n_channels_in,
                                  n_channels_out,
                                  device=device).to(device)

        initialize_weights(self.dec3, self.dec2, self.dec1, self.dec0)
        self.mu.apply(weights_init)
        self.logvar.apply(weights_init)
        self.project_z.apply(weights_init)

    def forward(self, x, return_mu_logvar=False):

        mu, logvar = self.encode(x)

        if return_mu_logvar:
            return mu, logvar
        else:
            z = self.latent_dist.sample()
            return self.decode(z)

    def encode(self, x):
        self.down1 = self.enc1(x)
        self.down2 = self.enc2(self.down1)
        self.down3 = self.enc3(self.down2)
        self.down4 = self.enc4(self.down3)

        mu = self.mu(self.down4)
        logvar = self.logvar(self.down4).clamp(min=np.log(1e-7))

        std = logvar.mul(0.5).exp_()

        self.latent_dist = Independent(Normal(loc=mu, scale=std), 1)

        return mu, logvar

    def decode(self, z, ign_idxs=None):
        up1 = self.dec0(self.down4)
        up2 = self.dec1(torch.cat((up1, self.down3), dim=1))  #skip connection

        up2b = nn.functional.leaky_relu(self.zres1(up2, z))  # noise injection

        up3 = self.dec2(torch.cat((up2b, self.down2), dim=1))

        up3b = nn.functional.leaky_relu(self.zres2(up3, z))

        up4 = self.dec3(torch.cat((up3b, self.down1), dim=1))

        logits = self.out(up4, z)

        out = F.softmax(logits, dim=1)

        if ign_idxs is None:
            return out
        else:
            # set unlabelled pixels to class unlabelled for Cityscapes
            # masks the adv loss by preventing gradients from being formed in unlabelled pixels
            w = torch.ones(out.shape)
            w[ign_idxs[0], :, ign_idxs[1], ign_idxs[2]] = 0.

            r = torch.zeros(out.shape)
            r[ign_idxs[0], 24, ign_idxs[1], ign_idxs[2]] = 1.

            out = out * w.to(DEVICE) + r.to(DEVICE)

            return out

    def sample(self, x, n_samples=1, ign_idxs=None):

        self.encode(x)

        # sample z
        z = self.latent_dist.sample((n_samples, ))

        # serial decoding
        if ign_idxs is None:
            pred_dist = torch_comp_along_dim(self.decode, z, dim=0)
        else:
            pred_dist = torch_comp_along_dim(self.decode, z, ign_idxs, dim=0)

        avg_pred = pred_dist.mean(0)

        return pred_dist, None, avg_pred
Example 17
class TorchGaussianMixtureDistribution(TorchDistributionWrapper):
    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return prod((2, model_config['custom_model_config']['num_gaussians']) +
                    action_space.shape)

    def __init__(self, inputs: List[torch.Tensor], model: TorchModelV2):
        super(TorchDistributionWrapper, self).__init__(inputs, model)
        assert len(inputs.shape) == 2
        self.batch_size = self.inputs.shape[0]
        self.num_gaussians = model.model_config['custom_model_config'][
            'num_gaussians']
        self.monte_samples = model.model_config['custom_model_config'][
            'monte_samples']
        inputs = torch.reshape(self.inputs,
                               (self.batch_size, 2, self.num_gaussians, -1))
        self.action_dim = inputs.shape[-1]
        assert not torch.isnan(inputs).any(), "Input nan aborting"
        self.means = inputs[:,
                            0, :, :]  # batch_size x num_gaussians x action_dim
        self.sigmas = torch.exp(
            inputs[:, 1, :, :])  # batch_size x num_gaussians x action_dim

        self.cat = Categorical(
            torch.ones(self.batch_size,
                       self.num_gaussians,
                       device=inputs.device,
                       requires_grad=False))
        self.normals = Independent(Normal(self.means, self.sigmas), 1)

    def logp(self, actions: torch.Tensor):
        actions = actions.view(
            self.batch_size, 1,
            -1)  # batch_size x 1 (broadcast to num gaussians) x action_dim
        mix_lps = self.cat.logits  # batch_size x num_gaussians
        normal_lps = self.normals.log_prob(
            actions)  # batch_size x num_gaussians (action_dim summed out by Independent)
        assert not torch.isnan(mix_lps).any(), "output nan aborting"
        assert not torch.isnan(normal_lps).any(), "output nan aborting"
        return torch.logsumexp(mix_lps + normal_lps,
                               dim=1)  # reduce along num gaussians

    def deterministic_sample(self) -> torch.Tensor:
        self.last_sample = self.means[:,
                                      0, :]  # select the mode of the first gaussian
        return self.last_sample

    def __rsamples(self):
        """ Compute samples that can be differentiated through
        """
        # Using reparameterization trick i.e. rsample
        normal_samples = self.normals.rsample(
            (self.monte_samples,
             ))  # monte_samples x batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample(
            (self.monte_samples, ))  # monte_samples x batch_size
        # First we need to expand cat so that it has the same dimension as normal samples
        cat_samples = cat_samples.reshape(self.monte_samples, -1, 1,
                                          1).expand(-1, -1, -1,
                                                    self.action_dim)
        # We select the normal distribution based on the outputs of
        # the categorical distribution
        return torch.gather(normal_samples, 2, cat_samples).squeeze(
            dim=2)  # monte_samples x batch_size x action_dim

    def kl(self, q: ActionDistribution) -> torch.Tensor:
        """ KL(self || q) estimated with monte carlo sampling
        """
        rsamples = self.__rsamples().unbind(0)
        log_ratios = torch.stack(
            [self.logp(rsample) - q.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ratios).any(), "output nan aborting"
        return log_ratios.mean(0)

    def entropy(self) -> torch.Tensor:
        """ H(self) estimated with monte carlo sampling
        """
        rsamples = self.__rsamples().unbind(0)
        log_ps = torch.stack([-self.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ps).any(), "output nan aborting"
        return log_ps.mean(0)

    def sample(self):
        normal_samples = self.normals.sample(
        )  # batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample()  # batch_size
        # First we need to expand cat so that it has the same dimension as normal samples
        cat_samples = cat_samples.view(-1, 1,
                                       1).expand(-1, -1, self.action_dim)
        # We select the normal distribution based on the outputs of
        # the categorical distribution
        self.last_sample = torch.gather(normal_samples, 1,
                                        cat_samples).squeeze(
                                            dim=1)  # batch_size x action_dim
        assert len(
            self.last_sample.shape) == 2, f"shape, {self.last_sample.shape}"
        return self.last_sample
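
A sketch checking the logsumexp mixture log-density used in logp() against torch.distributions.MixtureSameFamily (assumed shapes; uniform mixture weights as in the class above):

import torch
from torch.distributions import Categorical, Normal, Independent, MixtureSameFamily

b, k, d = 4, 3, 2
means, sigmas = torch.randn(b, k, d), torch.rand(b, k, d) + 0.1
cat = Categorical(torch.ones(b, k))
comps = Independent(Normal(means, sigmas), 1)

actions = torch.randn(b, d)
manual = torch.logsumexp(cat.logits + comps.log_prob(actions.view(b, 1, -1)), dim=1)
reference = MixtureSameFamily(cat, comps).log_prob(actions)
print(torch.allclose(manual, reference))   # True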
Example 18
    def __call__(self, x, out_keys=['action'], info={}, **kwargs):
        # Output dictionary
        out_policy = {}

        # Forward pass of feature networks to obtain features
        if self.recurrent:
            out_network = self.network(x=x,
                                       hidden_states=self.rnn_states,
                                       mask=info.get('mask', None))
            features = out_network['output']
            # Update the tracking of current RNN hidden states
            self.rnn_states = out_network['hidden_states']
        else:
            features = self.network(x)

        # Forward pass through mean head to obtain mean values for Gaussian distribution
        mean = self.network.mean_head(features)
        # Obtain logvar based on the options
        if isinstance(self.network.logvar_head,
                      nn.Linear):  # linear layer, then do forward pass
            logvar = self.network.logvar_head(features)
        else:  # either Tensor or nn.Parameter
            logvar = self.network.logvar_head
            # Expand as same shape as mean
            logvar = logvar.expand_as(mean)

        # Forward pass of value head to obtain value function if required
        if 'state_value' in out_keys:
            out_policy['state_value'] = self.network.value_head(
                features).squeeze(-1)  # squeeze final single dim

        # Get std from logvar
        if self.std_style == 'exp':
            std = torch.exp(0.5 * logvar)
        elif self.std_style == 'softplus':
            std = F.softplus(logvar)

        # Lower-bound threshold for the std
        min_std = torch.full(std.size(),
                             self.min_std).type_as(std).to(self.device)
        std = torch.max(std, min_std)

        # Create independent Gaussian distributions i.e. Diagonal Gaussian
        action_dist = Independent(Normal(loc=mean, scale=std), 1)

        # Sample an action from the distribution (no gradient)
        # Do not use `rsample()`: it leads to a zero gradient for the mean head!
        action = action_dist.sample()
        out_policy['action'] = action

        # Calculate log-probability of the sampled action
        if 'action_logprob' in out_keys:
            out_policy['action_logprob'] = action_dist.log_prob(action)

        # Calculate policy entropy conditioned on state
        if 'entropy' in out_keys:
            out_policy['entropy'] = action_dist.entropy()

        # Calculate policy perplexity i.e. exp(entropy)
        if 'perplexity' in out_keys:
            out_policy['perplexity'] = action_dist.perplexity()

        # sanity check for NaN
        if torch.any(torch.isnan(action)):
            while True:
                msg = 'NaN ! A workaround is to learn state-independent std or use tanh rather than relu'
                msg2 = f'check: \n\t mean: {mean}, logvar: {logvar}'
                print(msg + msg2)

        # Constrain the action to the valid range
        out_policy['action'] = self.constraint_action(action)

        return out_policy
Example 19
class UNetGenerator(GeneralVAE):
    def __init__(self, imsize, n_channels_in,n_channels_out, n_hidden, z_dim, device = "cpu", **kwargs):
        super(UNetGenerator, self).__init__(n_channels_in, n_channels_out, device, **kwargs)

        self.z_dim = z_dim

        hidden_dims = [n_hidden, n_hidden*2, n_hidden*4, n_hidden*8]

        # embedder
        self.enc1 = nn.Sequential(nn.Conv2d(n_channels_in, hidden_dims[0], kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc2 = nn.Sequential(nn.Conv2d(hidden_dims[0], hidden_dims[1], kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(hidden_dims[1]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc3 = nn.Sequential(nn.Conv2d(hidden_dims[1], hidden_dims[2], kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(hidden_dims[2]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc4 = nn.Sequential(nn.Conv2d(hidden_dims[2], hidden_dims[3], kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(hidden_dims[3]), nn.LeakyReLU(0.2), nn.Dropout(0.1))

        self.dec0 = nn.Sequential(nn.ConvTranspose2d(hidden_dims[3], hidden_dims[2], kernel_size=2, stride=2), nn.BatchNorm2d(hidden_dims[2]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec1 = nn.Sequential(nn.ConvTranspose2d(hidden_dims[3], hidden_dims[1], kernel_size=2, stride=2), nn.BatchNorm2d(hidden_dims[1]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec2 = nn.Sequential(nn.ConvTranspose2d(hidden_dims[2], hidden_dims[0], kernel_size=2, stride=2), nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec3 = nn.Sequential(nn.ConvTranspose2d(hidden_dims[1],hidden_dims[0], kernel_size=2, stride=2), nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))

        self.zres1 = Noise_injector(hidden_dims[1], z_dim, n_channels_in, hidden_dims[1], device=device).to(device)
        self.zres2 = Noise_injector(hidden_dims[0], z_dim, n_channels_in, hidden_dims[0], device=device).to(device)
        self.out = Noise_injector(hidden_dims[0], z_dim, n_channels_in, n_channels_out, device=device).to(device)

        initialize_weights(self.dec3, self.dec2, self.dec1, self.dec0)

    def forward(self, x):

        self.encode(x)

        self.get_gauss(x)

        z = self.gauss.sample()

        return self.decode(z)


    def encode(self, x):
        self.down1 = self.enc1(x)
        self.down2 = self.enc2(self.down1)
        self.down3 = self.enc3(self.down2)
        self.down4 = self.enc4(self.down3)

    def decode(self, z, ign_idxs=None):
        up1 = self.dec0(self.down4)
        up2 = self.dec1(torch.cat((up1, self.down3),dim=1)) #skip connection

        up2b = nn.functional.leaky_relu(self.zres1(up2, z)) # noise injection

        up3 = self.dec2(torch.cat((up2b, self.down2), dim=1))

        up3b = nn.functional.leaky_relu(self.zres2(up3, z))

        up4 = self.dec3(torch.cat((up3b, self.down1),dim=1))

        logits = self.out(up4,z)

        out =  F.softmax(logits, dim=1)

        if ign_idxs is None:
            return out
        else:
            # set unlabelled pixels to class unlabelled for Cityscapes
            # masks the adv loss by preventing gradients from being formed in unlabelled pixels
            w = torch.ones(out.shape)
            w[ign_idxs[0], :, ign_idxs[1], ign_idxs[2]] = 0.

            r = torch.zeros(out.shape)
            r[ign_idxs[0], 24, ign_idxs[1], ign_idxs[2]] = 1.

            out = out * w.to(DEVICE) + r.to(DEVICE)

            return out

    def get_gauss(self, x):
        b_size = len(x)
        self.gauss = Independent(Normal(loc=torch.zeros((b_size, self.z_dim)).float().to(DEVICE),
                                        scale=torch.ones((b_size, self.z_dim)).float().to(DEVICE)), 1)

    def sample(self, x, ign_idxs = None, n_samples=1):

        self.get_gauss(x)

        # sample z
        z = self.gauss.sample((n_samples,))

        # encode z
        self.encode(x)

        # serial decoding
        if ign_idxs is None:
            pred_dist = torch_comp_along_dim(self.decode, z, dim=0)
        else:
            pred_dist = torch_comp_along_dim(self.decode, z, ign_idxs, dim=0)

        # compute the average prediction
        avg_pred = pred_dist.mean(0)

        return pred_dist, None, avg_pred