class TanhNormal(Distribution):
    """Copied from Kaixhi"""

    def __init__(self, loc, scale):
        super().__init__()
        self.normal = Independent(Normal(loc, scale), 1)

    def sample(self):
        return torch.tanh(self.normal.sample())

    # samples with the reparametrisation trick (differentiable)
    def rsample(self):
        return torch.tanh(self.normal.rsample())

    # Calculates the log probability of value using the change-of-variables technique
    # (uses log1p = log(1 + x) for extra numerical stability)
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|
        return self.normal.log_prob(inv_value) - torch.log1p(-value.pow(2) + 1e-6).sum(dim=1)

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)

    def get_std(self):
        return self.normal.stddev
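# --- Usage sketch (not part of the original example): a quick shape check of
# --- TanhNormal, assuming the torch / torch.distributions imports used above.
import torch
from torch.distributions import Normal, Independent

loc, scale = torch.zeros(4, 3), torch.ones(4, 3)
dist = TanhNormal(loc, scale)
a = dist.rsample()       # (4, 3), values squashed into (-1, 1)
lp = dist.log_prob(a)    # (4,), one log-density per batch row
print(a.shape, lp.shape, dist.mean.shape)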
def get_loss(self, output, target, ignore_label=None):
    loss = super().get_loss(output, target, ignore_label)
    eps = 1e-12

    # construct N(0, 1) diagonal covariance of the same size as y (output)
    normal = Independent(
        Normal(loc=torch.FloatTensor(output['logits'].size()).fill_(0).to(device),
               scale=torch.FloatTensor(output['logits'].size()).fill_(self.sigma).to(device)),
        1)

    # sum softmax(logits distorted with noise scaled by the predicted voxel sigmas);
    # we take the log of this sum below
    sum_distorted_softmax = torch.sum(torch.stack([
        self.softmax(output['logits'] + (output['sigma'] * normal.sample()))
        for _ in torch.arange(self.samples)
    ]), dim=0)
    # sum_distorted_softmax should have shape [batch, nclasses, x, y]

    one_hot = torch.zeros(output['logits'].shape).scatter_(1, target.unsqueeze(1), 1)

    # mask sum_distorted_softmax to keep only the softmax probs of the gt class and take
    # the max of the result, which selects the prob of the gt class (reduces dim 1 = nclasses)
    sum_distorted_softmax, _ = torch.max(sum_distorted_softmax * one_hot, 1)
    # sum_distorted_softmax should now have shape [batch, x, y]

    # finally compute the categorical aleatoric loss
    aleatoric_loss = -0.0001 * torch.mean(torch.sum(
        torch.log(sum_distorted_softmax + eps) - np.log(self.samples), dim=(1, 2)), dim=0)

    output['sigma'] = output['sigma'].cpu().detach().numpy()
    output['logits'] = None
    self.current_aleatoric_loss = aleatoric_loss.detach().cpu().numpy()
    return loss + aleatoric_loss
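# --- Illustrative sketch (not from the original source): the same sampling-based
# --- categorical aleatoric term on dummy tensors; shapes and constants are made up.
import math
import torch

batch, nclasses, h, w = 2, 4, 8, 8
logits = torch.randn(batch, nclasses, h, w)
sigma = torch.rand(batch, 1, h, w)                     # hypothetical per-pixel predicted std
target = torch.randint(0, nclasses, (batch, h, w))
n_samples, eps = 10, 1e-12

distorted_sum = torch.stack([
    torch.softmax(logits + sigma * torch.randn_like(logits), dim=1)
    for _ in range(n_samples)
]).sum(0)                                              # (batch, nclasses, h, w)

one_hot = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1)
gt_prob, _ = (distorted_sum * one_hot).max(dim=1)      # (batch, h, w)
aleatoric = -torch.mean(torch.sum(torch.log(gt_prob + eps) - math.log(n_samples), dim=(1, 2)))
print(aleatoric)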
class IndependentNormal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real  # a Normal is supported on the whole real line
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        self.base_dist = Independent(
            Normal(loc=loc, scale=scale, validate_args=validate_args),
            len(loc.shape) - 1, validate_args=validate_args)
        super(IndependentNormal, self).__init__(
            self.base_dist.batch_shape, self.base_dist.event_shape, validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
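# --- Usage sketch (not part of the original example): with loc/scale of shape (batch, d)
# --- the last dimension is the event, so log_prob returns one value per batch row.
import torch

loc, scale = torch.zeros(5, 3), torch.ones(5, 3)
dist = IndependentNormal(loc, scale)
x = dist.rsample()
print(x.shape, dist.log_prob(x).shape, dist.entropy().shape)
# torch.Size([5, 3]) torch.Size([5]) torch.Size([5])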
class MeanField(BaseApproximation):
    def __init__(self):
        """
        Implements a mean field approximation of the state space
        """
        self._mean = None
        self._logstd = None
        self._sampledist = None  # type: Independent

    def entropy(self):
        return Independent(Normal(self._mean, self._logstd.exp()), 2).entropy()

    def initialize(self, data, ndim):
        self._mean = torch.zeros((data.shape[0] + 1, ndim), requires_grad=True)
        self._logstd = torch.zeros_like(self._mean, requires_grad=True)

        # ===== Start optimization ===== #
        self._sampledist = Independent(
            Normal(torch.zeros_like(self._mean), torch.ones_like(self._logstd)), 2)

        return self

    def get_parameters(self):
        return [self._mean, self._logstd]

    def sample(self, num_samples):
        samples = (num_samples,) if isinstance(num_samples, int) else num_samples
        return self._mean + self._logstd.exp() * self._sampledist.sample(samples)
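# --- Usage sketch (not from the original source, and the BaseApproximation
# --- constructor is assumed to take no arguments): initialise on data, optimise
# --- the returned parameters, then draw reparameterised samples.
import torch

data = torch.randn(100, 1)                      # hypothetical observations
approx = MeanField().initialize(data, ndim=2)
opt = torch.optim.Adam(approx.get_parameters(), lr=1e-2)

samples = approx.sample(64)                     # (64, 101, 2): draws x (T + 1) x ndim
print(samples.shape, approx.entropy().shape)    # torch.Size([64, 101, 2]) torch.Size([])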
def sample_trajectories(self, init_std=1.0, min_std=1e-6, output_size=2):
    # Sample batch_size complete trajectories with the current policy
    observations = self.envs.reset()
    with torch.no_grad():
        while not self.envs.dones.all():
            observations_tensor = torch.from_numpy(observations)
            output = self.policy(observations_tensor)

            min_log_std = math.log(min_std)
            sigma = nn.Parameter(torch.Tensor(output_size))
            sigma.data.fill_(math.log(init_std))
            scale = torch.exp(torch.clamp(sigma, min=min_log_std))

            # loc is the mean of the Gaussian
            # scale is the standard deviation of the Gaussian
            p_normal = Independent(Normal(loc=output, scale=scale), 1)
            actions_tensor = p_normal.sample()
            actions = actions_tensor.cpu().numpy()

            # pi = policy(observations_tensor)
            # actions_tensor = pi.sample()
            # actions = actions_tensor.cpu().numpy()

            new_observations, rewards, _, infos = self.envs.step(actions)
            batch_ids = infos['batch_ids']
            yield (observations, actions, rewards, batch_ids)
            observations = new_observations
def choose_action(self, observation):
    mu, sigma = self.actor.forward(observation)  # .to(self.actor.device)
    sigma = T.exp(sigma)
    action_probs = Independent(Normal(mu, sigma), 1)
    probs = action_probs.sample()
    self.log_probs = action_probs.log_prob(probs).to(self.actor.device)
    return probs
def prop_state(x, f, g):
    bins = Independent(Binomial(x[:-1], f), 1)
    samp = bins.sample()

    s = x[0] - samp[..., 0]
    i = x[1] + samp[..., 0] - samp[..., 1]
    r = x[2] + samp[..., 1]

    return concater(s, i, r)
def prop_state(x, beta, gamma, eta, dt):
    f = _f(x, beta, gamma, eta, dt)

    bins = Independent(Binomial(x[..., :-1], f), 1)
    samp = bins.sample()

    s = x[..., 0] - samp[..., 0]
    i = x[..., 1] + samp[..., 0] - samp[..., 1]
    r = x[..., 2] + samp[..., 1]

    return concater(s, i, r)
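# --- Illustrative sketch (not from the original source): one stochastic SIR transition
# --- with made-up counts and probabilities; torch.stack stands in for concater here.
import torch
from torch.distributions import Binomial, Independent

x = torch.tensor([[990., 10., 0.]])     # (batch, 3): S, I, R counts
f = torch.tensor([[0.05, 0.10]])        # per-step infection / recovery probabilities

samp = Independent(Binomial(x[..., :-1], f), 1).sample()   # (batch, 2)
s = x[..., 0] - samp[..., 0]
i = x[..., 1] + samp[..., 0] - samp[..., 1]
r = x[..., 2] + samp[..., 1]
print(torch.stack((s, i, r), dim=-1))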
def sample_xy_diag(mix_probs, means, scales):
    # get MVN means and scales
    mixtures = Categorical(mix_probs).sample()  # (n,)
    means_sel = torch.stack([elt[i] for elt, i in zip(means, mixtures)])    # (n, d)
    scales_sel = torch.stack([elt[i] for elt, i in zip(scales, mixtures)])  # (n, d)

    # sample from MVNs
    norm = Normal(means_sel, scales_sel)
    mvn = Independent(norm, 1)
    samples = mvn.sample()  # (n, d)
    return samples
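# --- Usage sketch (not part of the original example): made-up mixture parameters for
# --- n = 4 points, k = 2 components, d = 3 dimensions.
import torch

n, k, d = 4, 2, 3
mix_probs = torch.full((n, k), 0.5)
means = torch.randn(n, k, d)
scales = torch.ones(n, k, d)
print(sample_xy_diag(mix_probs, means, scales).shape)   # torch.Size([4, 3])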
def test_independent_shape(self):
    for Dist, params in EXAMPLES:
        for param in params:
            base_dist = Dist(**param)
            x = base_dist.sample()
            base_log_prob_shape = base_dist.log_prob(x).shape
            for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                indep_log_prob_shape = base_log_prob_shape[:len(base_log_prob_shape) - reinterpreted_batch_ndims]
                self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
                self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
                self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
                if indep_dist.has_rsample:
                    self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
                try:
                    self.assertEqual(
                        indep_dist.enumerate_support().shape,
                        base_dist.enumerate_support().shape,
                    )
                    self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.variance.shape, base_dist.variance.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.entropy().shape, indep_log_prob_shape)
                except NotImplementedError:
                    pass
def get_action(self, obs):
    obs = torch.tensor(obs, dtype=torch.float).to(self.device)
    with torch.no_grad():
        mu, sigma = self.pi(obs)
        act_distribution = Independent(Normal(mu, sigma), 1)
        action = act_distribution.sample()
        log_prob = act_distribution.log_prob(action)
        val = self.V(obs)
        action = action.cpu().numpy()
        log_prob = log_prob.cpu().numpy()
        val = val.cpu().numpy()
    return action, log_prob, val
class IndependentRescaledBeta(Distribution):
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive
    }
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.base_dist = Independent(
            RescaledBeta(concentration1, concentration0, validate_args),
            len(concentration1.shape) - 1, validate_args=validate_args)
        super(IndependentRescaledBeta, self).__init__(
            self.base_dist.batch_shape, self.base_dist.event_shape, validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return entropy
class NormalApproximation(KernelDensityEstimate):
    def __init__(self, independent=True):
        super().__init__()
        self._dist = None  # type: torch.distributions.Distribution
        self._indep = independent
        self._shape = None

    def fit(self, x, w):
        self._shape = (x.shape[0],)

        if not self._indep:
            self._dist = _construct_mvn(x, w)
            return self

        mean = (w.unsqueeze(-1) * x).sum(0)
        var = robust_var(x, w, mean)
        self._dist = Independent(Normal(mean, var.sqrt()), 1)
        return self

    def sample(self, inds=None):
        return self._dist.sample(self._shape)
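# --- Usage sketch (not from the original source): fit to weighted particles and
# --- resample; robust_var and KernelDensityEstimate come from the surrounding module.
import torch

x = torch.randn(500, 2)                # particles
w = torch.full((500,), 1.0 / 500)      # normalised importance weights

kde = NormalApproximation(independent=True).fit(x, w)
print(kde.sample().shape)              # torch.Size([500, 2]), one draw per particle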
def test_independent_normal() -> None:
    num_samples = 2000
    dim = 4

    loc = np.arange(0, dim) / float(dim)
    diag = np.arange(dim) / dim + 0.5
    Sigma = diag**2

    distr = Independent(Normal(loc=torch.Tensor(loc), scale=torch.Tensor(diag)), 1)

    assert np.allclose(
        distr.variance.numpy(), Sigma, atol=0.1, rtol=0.1
    ), f"did not match: sigma = {Sigma}, sigma_hat = {distr.variance.numpy()}"

    samples = distr.sample((num_samples,))

    loc_hat, diag_hat = maximum_likelihood_estimate_sgd(
        NormalOutput(dim=dim),
        samples,
        learning_rate=0.01,
        num_epochs=10,
    )

    distr = Independent(Normal(loc=torch.Tensor(loc_hat), scale=torch.Tensor(diag_hat)), 1)
    Sigma_hat = distr.variance.numpy()

    assert np.allclose(
        loc_hat, loc, atol=0.2, rtol=0.1
    ), f"mu did not match: loc = {loc}, loc_hat = {loc_hat}"
    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"sigma did not match: sigma = {Sigma}, sigma_hat = {Sigma_hat}"
def forward(self, x, mean=False, z_q=None):
    blocks = []
    used_latents = []
    distributions = []

    if isinstance(mean, bool):
        mean = [mean] * self.num_latent_levels

    # Encoder: residual blocks with pooling between levels
    features = x
    for i, block in enumerate(self.res_layers):
        features = block(features)
        blocks.append(features)
        if i != self.num_levels - 1:
            features = self.Pool_layers[i](features)

    decoder_features = blocks[-1]

    # Decoder: one latent distribution per probabilistic level
    for proba_level in range(self.num_latent_levels):
        latent_dim = self._latent_dims[proba_level]
        mu_log_sigma = self.probabilistic_block(decoder_features)

        mu = mu_log_sigma[:, :latent_dim]
        log_sigma = mu_log_sigma[:, latent_dim:]

        dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
        distributions.append(dist)

        if z_q is not None:
            z = z_q[proba_level]
        elif mean[proba_level]:
            z = dist.base_dist.loc
        else:
            z = dist.sample()
        used_latents.append(z)

        decoder_output_lo = torch.cat([z, decoder_features], axis=1)
        decoder_output_hi = self.interpolate(decoder_output_lo)
        decoder_features = torch.cat(
            [decoder_output_hi, blocks[::-1][proba_level + 1]], axis=1)
        decoder_features = self.decoder_layers[proba_level](decoder_features)

    return {
        'decoder_features': decoder_features,
        'encoder_features': blocks,
        'distributions': distributions,
        'used_latents': used_latents
    }
class UNetVAEGenerator(GeneralVAE):
    def __init__(self, imsize, n_channels_in, n_channels_out, n_hidden, z_dim,
                 device="cpu", **kwargs):
        super(UNetVAEGenerator, self).__init__(n_channels_in, n_channels_out, device, **kwargs)

        self.z_dim = z_dim
        hidden_dims = [n_hidden, n_hidden * 2, n_hidden * 4, n_hidden * 8]

        # embedder
        self.enc1 = nn.Sequential(
            nn.Conv2d(n_channels_in, hidden_dims[0], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc2 = nn.Sequential(
            nn.Conv2d(hidden_dims[0], hidden_dims[1], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[1]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc3 = nn.Sequential(
            nn.Conv2d(hidden_dims[1], hidden_dims[2], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[2]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc4 = nn.Sequential(
            nn.Conv2d(hidden_dims[2], hidden_dims[3], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[3]), nn.LeakyReLU(0.2), nn.Dropout(0.1))

        enc_imsize = (1 + (imsize[0] - 1) // (2**4), 1 + (imsize[1] - 1) // (2**4))

        # latent heads: n_channels depends on img resolution
        self.mu = nn.Sequential(
            Flatten(),
            nn.Linear(hidden_dims[3] * enc_imsize[0] * enc_imsize[1], z_dim))
        self.logvar = nn.Sequential(
            Flatten(),
            nn.Linear(hidden_dims[3] * enc_imsize[0] * enc_imsize[1], z_dim))

        self.project_z = nn.Sequential(
            nn.Linear(z_dim, hidden_dims[3] * enc_imsize[0] * enc_imsize[1]),
            UnFlatten(n_channels=hidden_dims[3], im_size=enc_imsize),
            nn.Conv2d(hidden_dims[3], hidden_dims[3], kernel_size=2, padding=2),
            nn.BatchNorm2d(hidden_dims[3]), nn.LeakyReLU(0.2))

        self.dec0 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[3], hidden_dims[2], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[2]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[3], hidden_dims[1], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[1]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[2], hidden_dims[0], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[1], hidden_dims[0], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))

        self.zres1 = Noise_injector(hidden_dims[1], z_dim, n_channels_in, hidden_dims[1],
                                    device=device).to(device)
        self.zres2 = Noise_injector(hidden_dims[0], z_dim, n_channels_in, hidden_dims[0],
                                    device=device).to(device)
        self.out = Noise_injector(hidden_dims[0], z_dim, n_channels_in, n_channels_out,
                                  device=device).to(device)

        initialize_weights(self.dec3, self.dec2, self.dec1, self.dec0)
        self.mu.apply(weights_init)
        self.logvar.apply(weights_init)
        self.project_z.apply(weights_init)

    def forward(self, x, return_mu_logvar=False):
        mu, logvar = self.encode(x)
        if return_mu_logvar:
            return mu, logvar
        else:
            z = self.latent_dist.sample()
            return self.decode(z)

    def encode(self, x):
        self.down1 = self.enc1(x)
        self.down2 = self.enc2(self.down1)
        self.down3 = self.enc3(self.down2)
        self.down4 = self.enc4(self.down3)
        mu = self.mu(self.down4)
        logvar = self.logvar(self.down4).clamp(min=np.log(1e-7))
        std = logvar.mul(0.5).exp_()
        self.latent_dist = Independent(Normal(loc=mu, scale=std), 1)
        return mu, logvar

    def decode(self, z, ign_idxs=None):
        up1 = self.dec0(self.down4)
        up2 = self.dec1(torch.cat((up1, self.down3), dim=1))   # skip connection
        up2b = nn.functional.leaky_relu(self.zres1(up2, z))    # noise injection
        up3 = self.dec2(torch.cat((up2b, self.down2), dim=1))
        up3b = nn.functional.leaky_relu(self.zres2(up3, z))
        up4 = self.dec3(torch.cat((up3b, self.down1), dim=1))
        logits = self.out(up4, z)
        out = F.softmax(logits, dim=1)
        if ign_idxs is None:
            return out
        else:
            # set unlabelled pixels to the unlabelled class for Cityscapes;
            # masks the adv loss by preventing gradients from forming in unlabelled pixels
            w = torch.ones(out.shape)
            w[ign_idxs[0], :, ign_idxs[1], ign_idxs[2]] = 0.
            r = torch.zeros(out.shape)
            r[ign_idxs[0], 24, ign_idxs[1], ign_idxs[2]] = 1.
            out = out * w.to(DEVICE) + r.to(DEVICE)
            return out

    def sample(self, x, n_samples=1, ign_idxs=None):
        self.encode(x)
        # sample z
        z = self.latent_dist.sample((n_samples,))
        # serial decoding
        if ign_idxs is None:
            pred_dist = torch_comp_along_dim(self.decode, z, dim=0)
        else:
            pred_dist = torch_comp_along_dim(self.decode, z, ign_idxs, dim=0)
        avg_pred = pred_dist.mean(0)
        return pred_dist, None, avg_pred
class TorchGaussianMixtureDistribution(TorchDistributionWrapper):
    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return prod((2, model_config['custom_model_config']['num_gaussians']) + action_space.shape)

    def __init__(self, inputs: List[torch.Tensor], model: TorchModelV2):
        super(TorchDistributionWrapper, self).__init__(inputs, model)
        assert len(inputs.shape) == 2
        self.batch_size = self.inputs.shape[0]
        self.num_gaussians = model.model_config['custom_model_config']['num_gaussians']
        self.monte_samples = model.model_config['custom_model_config']['monte_samples']
        inputs = torch.reshape(self.inputs, (self.batch_size, 2, self.num_gaussians, -1))
        self.action_dim = inputs.shape[-1]
        assert not torch.isnan(inputs).any(), "Input nan aborting"
        self.means = inputs[:, 0, :, :]              # batch_size x num_gaussians x action_dim
        self.sigmas = torch.exp(inputs[:, 1, :, :])  # batch_size x num_gaussians x action_dim
        self.cat = Categorical(torch.ones(self.batch_size, self.num_gaussians,
                                          device=inputs.device, requires_grad=False))
        self.normals = Independent(Normal(self.means, self.sigmas), 1)

    def logp(self, actions: torch.Tensor):
        # batch_size x 1 (broadcast to num_gaussians) x action_dim
        actions = actions.view(self.batch_size, 1, -1)
        mix_lps = self.cat.logits                    # batch_size x num_gaussians
        normal_lps = self.normals.log_prob(actions)  # batch_size x num_gaussians
        assert not torch.isnan(mix_lps).any(), "output nan aborting"
        assert not torch.isnan(normal_lps).any(), "output nan aborting"
        # reduce along num_gaussians
        return torch.logsumexp(mix_lps + normal_lps, dim=1)

    def deterministic_sample(self) -> torch.Tensor:
        # select the mode of the first gaussian
        self.last_sample = self.means[:, 0, :]
        return self.last_sample

    def __rsamples(self):
        """ Compute samples that can be differentiated through """
        # Using the reparameterization trick, i.e. rsample
        normal_samples = self.normals.rsample(
            (self.monte_samples,))  # monte_samples x batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample((self.monte_samples,))  # monte_samples x batch_size
        # First we need to expand cat so that it has the same dimension as the normal samples
        cat_samples = cat_samples.reshape(self.monte_samples, -1, 1, 1).expand(-1, -1, -1, self.action_dim)
        # We select the normal distribution based on the outputs of the categorical distribution
        return torch.gather(normal_samples, 2, cat_samples).squeeze(
            dim=2)  # monte_samples x batch_size x action_dim

    def kl(self, q: ActionDistribution) -> torch.Tensor:
        """ KL(self || q) estimated with Monte Carlo sampling """
        rsamples = self.__rsamples().unbind(0)
        log_ratios = torch.stack([self.logp(rsample) - q.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ratios).any(), "output nan aborting"
        return log_ratios.mean(0)

    def entropy(self) -> torch.Tensor:
        """ H(self) estimated with Monte Carlo sampling """
        rsamples = self.__rsamples().unbind(0)
        log_ps = torch.stack([-self.logp(rsample) for rsample in rsamples])
        assert not torch.isnan(log_ps).any(), "output nan aborting"
        return log_ps.mean(0)

    def sample(self):
        normal_samples = self.normals.sample()  # batch_size x num_gaussians x action_dim
        cat_samples = self.cat.sample()         # batch_size
        # First we need to expand cat so that it has the same dimension as the normal samples
        cat_samples = cat_samples.view(-1, 1, 1).expand(-1, -1, self.action_dim)
        # We select the normal distribution based on the outputs of the categorical distribution
        self.last_sample = torch.gather(normal_samples, 1, cat_samples).squeeze(
            dim=1)  # batch_size x action_dim
        assert len(self.last_sample.shape) == 2, f"shape, {self.last_sample.shape}"
        return self.last_sample
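# --- Illustrative sketch (not from the original source): the logsumexp mixture
# --- log-probability used in logp() above, on stand-alone dummy tensors.
import torch
from torch.distributions import Categorical, Independent, Normal

batch, num_gaussians, action_dim = 4, 3, 2
means = torch.randn(batch, num_gaussians, action_dim)
sigmas = torch.rand(batch, num_gaussians, action_dim) + 0.1

cat = Categorical(torch.ones(batch, num_gaussians))   # uniform mixture weights
comps = Independent(Normal(means, sigmas), 1)         # one diagonal Gaussian per component

actions = torch.randn(batch, 1, action_dim)           # broadcast across components
log_p = torch.logsumexp(cat.logits + comps.log_prob(actions), dim=1)   # (batch,)
print(log_p.shape)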
def __call__(self, x, out_keys=['action'], info={}, **kwargs):
    # Output dictionary
    out_policy = {}

    # Forward pass of the feature network to obtain features
    if self.recurrent:
        out_network = self.network(x=x, hidden_states=self.rnn_states, mask=info.get('mask', None))
        features = out_network['output']
        # Update the tracking of current RNN hidden states
        self.rnn_states = out_network['hidden_states']
    else:
        features = self.network(x)

    # Forward pass through the mean head to obtain the mean of the Gaussian distribution
    mean = self.network.mean_head(features)
    # Obtain logvar based on the options
    if isinstance(self.network.logvar_head, nn.Linear):  # linear layer, then do forward pass
        logvar = self.network.logvar_head(features)
    else:  # either Tensor or nn.Parameter
        logvar = self.network.logvar_head
        # Expand to the same shape as mean
        logvar = logvar.expand_as(mean)

    # Forward pass of the value head to obtain the value function if required
    if 'state_value' in out_keys:
        out_policy['state_value'] = self.network.value_head(features).squeeze(-1)  # squeeze final single dim

    # Get std from logvar
    if self.std_style == 'exp':
        std = torch.exp(0.5 * logvar)
    elif self.std_style == 'softplus':
        std = F.softplus(logvar)

    # Lower-bound threshold for std
    min_std = torch.full(std.size(), self.min_std).type_as(std).to(self.device)
    std = torch.max(std, min_std)

    # Create an independent Gaussian distribution, i.e. a diagonal Gaussian
    action_dist = Independent(Normal(loc=mean, scale=std), 1)

    # Sample an action from the distribution (no gradient)
    # Do not use `rsample()`, it leads to zero gradient of the mean head!
    action = action_dist.sample()
    out_policy['action'] = action

    # Calculate the log-probability of the sampled action
    if 'action_logprob' in out_keys:
        out_policy['action_logprob'] = action_dist.log_prob(action)

    # Calculate the policy entropy conditioned on the state
    if 'entropy' in out_keys:
        out_policy['entropy'] = action_dist.entropy()

    # Calculate the policy perplexity, i.e. exp(entropy)
    if 'perplexity' in out_keys:
        out_policy['perplexity'] = action_dist.perplexity()

    # sanity check for NaN
    if torch.any(torch.isnan(action)):
        while True:
            msg = 'NaN! A workaround is to learn a state-independent std or use tanh rather than relu'
            msg2 = f'check: \n\t mean: {mean}, logvar: {logvar}'
            print(msg + msg2)

    # Constrain the action to the valid range
    out_policy['action'] = self.constraint_action(action)

    return out_policy
class UNetGenerator(GeneralVAE):
    def __init__(self, imsize, n_channels_in, n_channels_out, n_hidden, z_dim,
                 device="cpu", **kwargs):
        super(UNetGenerator, self).__init__(n_channels_in, n_channels_out, device, **kwargs)

        self.z_dim = z_dim
        hidden_dims = [n_hidden, n_hidden * 2, n_hidden * 4, n_hidden * 8]

        # embedder
        self.enc1 = nn.Sequential(
            nn.Conv2d(n_channels_in, hidden_dims[0], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc2 = nn.Sequential(
            nn.Conv2d(hidden_dims[0], hidden_dims[1], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[1]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc3 = nn.Sequential(
            nn.Conv2d(hidden_dims[1], hidden_dims[2], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[2]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.enc4 = nn.Sequential(
            nn.Conv2d(hidden_dims[2], hidden_dims[3], kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dims[3]), nn.LeakyReLU(0.2), nn.Dropout(0.1))

        self.dec0 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[3], hidden_dims[2], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[2]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[3], hidden_dims[1], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[1]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[2], hidden_dims[0], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[1], hidden_dims[0], kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dims[0]), nn.LeakyReLU(0.2), nn.Dropout(0.1))

        self.zres1 = Noise_injector(hidden_dims[1], z_dim, n_channels_in, hidden_dims[1],
                                    device=device).to(device)
        self.zres2 = Noise_injector(hidden_dims[0], z_dim, n_channels_in, hidden_dims[0],
                                    device=device).to(device)
        self.out = Noise_injector(hidden_dims[0], z_dim, n_channels_in, n_channels_out,
                                  device=device).to(device)

        initialize_weights(self.dec3, self.dec2, self.dec1, self.dec0)

    def forward(self, x):
        self.encode(x)
        self.get_gauss(x)
        z = self.gauss.sample()
        return self.decode(z)

    def encode(self, x):
        self.down1 = self.enc1(x)
        self.down2 = self.enc2(self.down1)
        self.down3 = self.enc3(self.down2)
        self.down4 = self.enc4(self.down3)

    def decode(self, z, ign_idxs=None):
        up1 = self.dec0(self.down4)
        up2 = self.dec1(torch.cat((up1, self.down3), dim=1))   # skip connection
        up2b = nn.functional.leaky_relu(self.zres1(up2, z))    # noise injection
        up3 = self.dec2(torch.cat((up2b, self.down2), dim=1))
        up3b = nn.functional.leaky_relu(self.zres2(up3, z))
        up4 = self.dec3(torch.cat((up3b, self.down1), dim=1))
        logits = self.out(up4, z)
        out = F.softmax(logits, dim=1)
        if ign_idxs is None:
            return out
        else:
            # set unlabelled pixels to the unlabelled class for Cityscapes;
            # masks the adv loss by preventing gradients from forming in unlabelled pixels
            w = torch.ones(out.shape)
            w[ign_idxs[0], :, ign_idxs[1], ign_idxs[2]] = 0.
            r = torch.zeros(out.shape)
            r[ign_idxs[0], 24, ign_idxs[1], ign_idxs[2]] = 1.
            out = out * w.to(DEVICE) + r.to(DEVICE)
            return out

    def get_gauss(self, x):
        b_size = len(x)
        self.gauss = Independent(
            Normal(loc=torch.zeros((b_size, self.z_dim)).float().to(DEVICE),
                   scale=torch.ones((b_size, self.z_dim)).float().to(DEVICE)), 1)

    def sample(self, x, ign_idxs=None, n_samples=1):
        self.get_gauss(x)
        # sample z
        z = self.gauss.sample((n_samples,))
        # encode x
        self.encode(x)
        # serial decoding
        if ign_idxs is None:
            pred_dist = torch_comp_along_dim(self.decode, z, dim=0)
        else:
            pred_dist = torch_comp_along_dim(self.decode, z, ign_idxs, dim=0)
        # compute the average prediction
        avg_pred = pred_dist.mean(0)
        return pred_dist, None, avg_pred