def test_base_dist():
    for dims in [2, 3, 4, 5]:
        base_dists = [
            TransformedDistribution(
                Uniform(torch.zeros(dims), torch.ones(dims)),
                SigmoidTransform().inv),
            MultivariateNormal(torch.zeros(dims), torch.eye(dims)),
            GeneralisedNormal(torch.zeros(dims), torch.ones(dims), torch.tensor(8.0))
        ]
        for base_dist in base_dists:
            t = Trainer(dims, flow='cholesky', base_dist=base_dist)
            test_data = np.random.normal(size=(10, dims))
            test_data = torch.from_numpy(test_data).float()

            z, z_log_det = t.forward(test_data)
            assert z.shape == torch.Size([10, dims])
            assert z_log_det.shape == torch.Size([10])

            x, x_log_det = t.inverse(z)
            diff = torch.max(x - test_data).detach().cpu().numpy()
            assert np.abs(diff) <= max_forward_backward_diff
            diff = torch.max(x_log_det + z_log_det).detach().cpu().numpy()
            assert np.abs(diff) <= max_forward_backward_diff

            samples = t.get_synthetic_samples(10)
            assert samples.shape == torch.Size([10, dims])

            log_probs = t.log_probs(test_data)
            assert log_probs.shape == torch.Size([10])

def distribution(
    self,
    distr_args,
    loc: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
) -> Distribution:
    r"""
    Construct the associated distribution, given the collection of
    constructor arguments and, optionally, location and scale tensors.

    Parameters
    ----------
    distr_args
        Constructor arguments for the underlying Distribution type.
    loc
        Optional tensor, of the same shape as the
        batch_shape + event_shape of the resulting distribution.
    scale
        Optional tensor, of the same shape as the
        batch_shape + event_shape of the resulting distribution.
    """
    distr = self._base_distribution(distr_args)
    if loc is None and scale is None:
        return distr
    else:
        transform = AffineTransform(
            loc=0.0 if loc is None else loc,
            scale=1.0 if scale is None else scale,
        )
        return TransformedDistribution(distr, [transform])

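# Hedged usage sketch, not part of the library above: it only illustrates how the
# optional loc/scale arguments rescale a base distribution through AffineTransform.
# The Normal base and the loc=2.0 / scale=5.0 values are made-up assumptions.
import torch
from torch.distributions import AffineTransform, Normal, TransformedDistribution

base = Normal(torch.zeros(3), torch.ones(3))
scaled = TransformedDistribution(base, [AffineTransform(loc=2.0, scale=5.0)])

x = scaled.sample((1000,))
print(x.mean(0), x.std(0))               # roughly 2.0 and 5.0 per component
print(scaled.log_prob(torch.zeros(3)))   # N(-0.4; 0, 1) log-density minus log(5)
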
def __init__(
    self, model: "Model", elbo_hats: List[float], y: torch.Tensor, q: Distribution
):
    self.elbo_hats, self.model, self.y, self.q = elbo_hats, model, y, q
    self.input_length = len(y)
    sds = torch.sqrt(self.q.variance)

    for p in model.params:
        setattr(self, p.prior_name, getattr(model, p.prior_name))
        if p.dimension > 1:
            continue

        # construct marginals in optimization space
        if isinstance(p, TransformedModelParameter):
            tfm_post_marg = Normal(self.q.loc[p.index], sds[p.index])
            setattr(self, p.tfm_post_marg_name, tfm_post_marg)

            tfm_prior = getattr(model, p.tfm_prior_name)
            setattr(self, p.tfm_prior_name, tfm_prior)

            tfm = getattr(model, p.tfm_name)
            setattr(self, p.tfm_name, tfm)

            post_marg = TransformedDistribution(tfm_post_marg, tfm.inv)
            setattr(self, p.post_marg_name, post_marg)
        else:
            post_marg = Normal(self.q.loc[p.index], sds[p.index])
            setattr(self, p.post_marg_name, post_marg)

def test_compose_affine(event_dims):
    transforms = [
        AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))

def _define_transdist(loc: torch.Tensor, scale: torch.Tensor, inc_dist: Distribution, ndim: int):
    loc, scale = torch.broadcast_tensors(loc, scale)
    shape = loc.shape[:-ndim] if ndim > 0 else loc.shape

    return TransformedDistribution(
        inc_dist.expand(shape), AffineTransform(loc, scale, event_dim=ndim)
    )

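# Hedged sanity check, assumed rather than taken from the source module: for a Normal
# increment distribution and ndim == 0, the expand + AffineTransform construction above
# is equivalent to Normal(loc, scale), so the log-probabilities should coincide.
import torch
from torch.distributions import AffineTransform, Normal, TransformedDistribution

loc, scale = torch.tensor([0.5, -1.0]), torch.tensor([2.0, 0.3])
transdist = TransformedDistribution(
    Normal(0.0, 1.0).expand(loc.shape),
    AffineTransform(loc, scale, event_dim=0))
reference = Normal(loc, scale)

x = torch.randn(2)
assert torch.allclose(transdist.log_prob(x), reference.log_prob(x))
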
def get_transformed_dists(self) -> Tuple[Distribution, ...]:
    res = tuple()
    for bij, msk in zip(self._bijections, self._mask):
        dist = TransformedDistribution(
            Normal(self._mean[msk], self._log_std[msk].exp()), bij
        )
        res += (dist,)

    return res

def _init_and_train_flow(data, nh, l, prior_dist, epochs, device,
                         opt_method='adam', verbose=False):
    # initialise a normalizing flow and train it on `data` by maximum likelihood
    d = data.shape[1]
    if d > 2:
        print('using higher D implementation')
        affine_flow = AffineFullFlowGeneral
    else:
        affine_flow = AffineFullFlow

    if prior_dist == 'laplace':
        prior = Laplace(torch.zeros(d), torch.ones(d))
    else:
        prior = TransformedDistribution(
            Uniform(torch.zeros(d), torch.ones(d)), SigmoidTransform().inv)

    flows = [
        affine_flow(dim=d, nh=nh, parity=False, net_class=MLP1layer)
        for _ in range(l)
    ]
    flow = NormalizingFlowModel(prior, flows).to(device)

    dset = CustomSyntheticDatasetDensity(data.astype(np.float32), device=device)
    train_loader = DataLoader(dset, shuffle=True, batch_size=128)
    optimizer = optim.Adam(flow.parameters(), lr=1e-4, weight_decay=1e-5)
    if opt_method == 'scheduler':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.1, patience=3, verbose=verbose)

    flow.train()
    loss_vals = []
    for e in range(epochs):
        loss_val = 0
        for _, x in enumerate(train_loader):
            x = x.to(device)
            # compute loss
            _, prior_logprob, log_det = flow(x)
            loss = -torch.sum(prior_logprob + log_det)
            loss_val += loss.item()
            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if opt_method == 'scheduler':
            scheduler.step(loss_val / len(train_loader))
        if verbose:
            print('epoch {}/{} \tloss: {}'.format(e, epochs, loss_val))
        loss_vals.append(loss_val)

    return flow, loss_vals

def forward(self, x):
    mean = self.mean_head(x)
    logstd = torch.tanh(self.logstd_head(x))
    logstd = self.logstd_min + 0.5 * (1 + logstd) * (self.logstd_max - self.logstd_min)
    std = torch.exp(logstd)
    action_dist = TransformedDistribution(
        Independent(Normal(loc=mean, scale=std), 1), [TanhTransform()])
    return action_dist

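# Hedged usage sketch (illustrative only; the batch size, action dimension and the
# mean/std values are assumptions): what the tanh-squashed action distribution
# returned above supports.
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

mean, std = torch.zeros(4, 2), torch.ones(4, 2)       # batch of 4 states, action dim 2
action_dist = TransformedDistribution(
    Independent(Normal(loc=mean, scale=std), 1), [TanhTransform()])

action = action_dist.rsample()           # reparameterised sample, each entry in (-1, 1)
log_prob = action_dist.log_prob(action)  # per-state log-density, includes the tanh Jacobian
assert action.shape == (4, 2) and log_prob.shape == (4,)
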
def distribution(
    self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
    distr = Independent(Normal(*distr_args), 1)
    if scale is None:
        return distr
    else:
        return TransformedDistribution(
            distr, [AffineTransform(loc=0, scale=scale)]
        )

def transformed_dist(self):
    """
    Returns the prior distribution mapped to the unconstrained space.
    """
    if not self.trainable:
        raise ValueError('Prior is not a `Distribution` instance!')

    return TransformedDistribution(self._prior, [self.bijection.inv])

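# Hedged sketch, not taken from the class above: `self._prior` and `self.bijection`
# are stand-ins here (a Gamma prior with an assumed ExpTransform bijection) just to
# show what TransformedDistribution(prior, [bijection.inv]) computes.
import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import ExpTransform

prior = Gamma(2.0, 1.0)       # prior over a positive parameter
bijection = ExpTransform()    # maps unconstrained space to the positive reals
unconstrained = TransformedDistribution(prior, [bijection.inv])

z = torch.tensor(0.3)         # a point in the unconstrained (log) space
# prior density evaluated at exp(z), corrected by the log-Jacobian of the exp map
print(unconstrained.log_prob(z))
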
def distribution(
    self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
    if scale is None:
        return self.distr_cls(*distr_args)
    else:
        distr = self.distr_cls(*distr_args)
        return TransformedDistribution(
            distr, [AffineTransform(loc=0, scale=scale)]
        )

def distribution(
    self, distr_args, scale: Optional[torch.Tensor] = None
) -> Distribution:
    loc, scale_tri = distr_args
    distr = MultivariateNormal(loc=loc, scale_tril=scale_tri)
    if scale is None:
        return distr
    else:
        return TransformedDistribution(
            distr, [AffineTransform(loc=0, scale=scale)]
        )

def forward(self, x):
    for layer in self.feature_layers:
        x = F.relu(layer(x))
    mean = self.mean_head(x)
    logstd = self.logstd_head(x)
    logstd = torch.tanh(logstd)
    logstd = self.LOGSTD_MIN + 0.5 * (self.LOGSTD_MAX - self.LOGSTD_MIN) * (1 + logstd)
    std = torch.exp(logstd)
    dist = TransformedDistribution(
        Independent(Normal(mean, std), 1), [TanhTransform(cache_size=1)])
    return dist

def test_save_load_transform():
    # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)

    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    other = torch.load(stream)

    assert torch.allclose(log_prob, other.log_prob(x))

def test_log_prob_d2(eta):
    dist = LKJCorrCholesky(2, torch.tensor([eta]))
    test_dist = TransformedDistribution(Beta(eta, eta), AffineTransform(loc=-1., scale=2.0))

    samples = dist.sample(torch.Size([100]))
    lp = dist.log_prob(samples)
    x = samples[..., 1, 0]
    tst = test_dist.log_prob(x)

    assert_tensors_equal(lp, tst, prec=1e-6)

def distribution(self, distr_args, scale: Optional[torch.Tensor] = None) -> Distribution:
    mix_logits, loc, dist_scale = distr_args
    distr = MixtureSameFamily(Categorical(logits=mix_logits), Normal(loc, dist_scale))
    if scale is None:
        return distr
    else:
        return TransformedDistribution(
            distr, [AffineTransform(loc=0, scale=scale)])

def forward(self, state, mean_action=False):
    mu, log_std = self.network(state).chunk(2, dim=-1)
    log_std = torch.clamp(
        log_std, LOG_MIN, LOG_MAX)  # keep the policy neither too random nor too deterministic
    normal = TransformedDistribution(
        Independent(Normal(mu, log_std.exp()), 1),
        [TanhTransform(), AffineTransform(loc=self.loc, scale=self.scale)])
    if mean_action:
        return self.scale * torch.tanh(mu) + self.loc
    return normal

def test_log_prob_d2(concentration):
    dist = LKJCholesky(2, torch.tensor([concentration]))
    test_dist = TransformedDistribution(Beta(concentration, concentration),
                                        AffineTransform(loc=-1., scale=2.0))

    samples = dist.sample(torch.Size([100]))
    lp = dist.log_prob(samples)
    x = samples[..., 1, 0]
    tst = test_dist.log_prob(x)

    # LKJ prevents inf values in log_prob
    lp[tst == math.inf] = math.inf  # substitute inf for comparison

    assert_tensors_equal(lp, tst, prec=1e-3)

def i_sample(self, shape=None, as_dist=False):
    shape = size_getter(shape)

    dist = TransformedDistribution(
        self.noise0.expand(shape),
        AffineTransform(self.i_mean(), self.i_scale(), event_dim=self._event_dim))

    if as_dist:
        return dist

    return dist.sample()

def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
                                  num_transforms, sample_shape):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10)),
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape

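# Hedged companion sketch (an assumption, not one of the upstream test cases): a
# ReshapeTransform only rearranges the event shape, so its log-Jacobian is zero and
# log_prob is unchanged apart from reshaping the argument.
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import ReshapeTransform

base = Independent(Normal(torch.zeros(4, 5), 1.0), 2)    # event_shape (4, 5)
flat = TransformedDistribution(base, ReshapeTransform((4, 5), (20,)))

x = flat.sample()
assert x.shape == (20,)
assert torch.allclose(flat.log_prob(x), base.log_prob(x.reshape(4, 5)))
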
def distribution(self, distr_args, scale: Optional[torch.Tensor] = None) -> Distribution:
    # unpack into `distr_scale` so the components' scale does not shadow the optional
    # rescaling argument `scale`
    mix_logits, loc, distr_scale = distr_args
    comp_distr = Normal(loc, distr_scale)
    if scale is None:
        return MixtureSameFamily(Categorical(logits=mix_logits), comp_distr)
    else:
        scaled_comp_distr = TransformedDistribution(
            comp_distr, [AffineTransform(loc=0, scale=scale)])
        return MixtureSameFamily(Categorical(logits=mix_logits), scaled_comp_distr)

def _get_flow_arch(self, parity=False):
    """
    Returns a normalizing flow according to the config file.

    Parameters
    ----------
    parity: bool
        If True, the flow follows the (1, 2) permutation, otherwise it follows
        the (2, 1) permutation.
    """
    # this method only gets called by _train, which in turn is only called after
    # self.dim has been initialized
    dim = self.dim

    # prior
    if self.config.flow.prior_dist == 'laplace':
        prior = Laplace(torch.zeros(dim).to(self.device), torch.ones(dim).to(self.device))
    else:
        prior = TransformedDistribution(
            Uniform(torch.zeros(dim).to(self.device), torch.ones(dim).to(self.device)),
            SigmoidTransform().inv)

    # net type for flow parameters
    if self.config.flow.net_class.lower() == 'mlp':
        net_class = MLP1layer
    elif self.config.flow.net_class.lower() == 'mlp4':
        net_class = MLP4
    elif self.config.flow.net_class.lower() == 'armlp':
        net_class = ARMLP
    else:
        raise NotImplementedError('net_class {} not understood.'.format(self.config.flow.net_class))

    # flow type
    def ar_flow(hidden_dim):
        if self.config.flow.architecture.lower() in ['cl', 'realnvp']:
            return AffineCL(dim=dim, nh=hidden_dim, scale_base=self.config.flow.scale_base,
                            shift_base=self.config.flow.shift_base, net_class=net_class,
                            parity=parity, scale=self.config.flow.scale)
        elif self.config.flow.architecture.lower() == 'maf':
            return MAF(dim=dim, nh=hidden_dim, net_class=net_class, parity=parity)
        elif self.config.flow.architecture.lower() == 'spline':
            return NSF_AR(dim=dim, hidden_dim=hidden_dim, base_network=net_class)
        else:
            raise NotImplementedError('Architecture {} not understood.'.format(self.config.flow.architecture))

    # support training multiple flows of varying depth and width, and keep only the best
    self.n_layers = self.n_layers if type(self.n_layers) is list else [self.n_layers]
    self.n_hidden = self.n_hidden if type(self.n_hidden) is list else [self.n_hidden]

    normalizing_flows = []
    for nl in self.n_layers:      # typically a single item, e.g. self.n_layers = [5]
        for nh in self.n_hidden:  # typically a single item, e.g. self.n_hidden = [10]
            # construct normalizing flows
            flow_list = [ar_flow(nh) for _ in range(nl)]
            normalizing_flows.append(NormalizingFlowModel(prior, flow_list).to(self.device))

    return normalizing_flows

def forward(self, state):
    policy_mean, policy_log_std = self.policy(state).chunk(2, dim=1)
    policy_log_std = torch.clamp(policy_log_std, min=self.log_std_min, max=self.log_std_max)
    policy = TransformedDistribution(
        Independent(Normal(policy_mean, policy_log_std.exp()), 1),
        [
            TanhTransform(),
            AffineTransform(loc=self.action_loc, scale=self.action_scale)
        ])
    policy.mean_ = self.action_scale * torch.tanh(
        policy.base_dist.mean
    ) + self.action_loc  # TODO: See if mean attr can be overwritten
    return policy

def get_continuous_dist(self, mean, std):
    if self._dist == "tanh_normal_dreamer_v1":
        mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
        std = F.softplus(std + self.raw_init_std) + self._min_std
        dist = Normal(mean, std)
        dist = TransformedDistribution(dist, TanhBijector())
        dist = Independent(dist, 1)
        dist = SampleDist(dist)
    elif self._dist == "trunc_normal":
        mean = torch.tanh(mean)
        std = 2 * torch.sigmoid(std / 2) + self._min_std
        dist = SafeTruncatedNormal(mean, std, -1, 1)
        dist = Independent(dist, 1)
    return dist

def forward(self, x, get_logprob=False):
    mu_logstd = self.network(x)
    mu, logstd = mu_logstd.chunk(2, dim=1)
    logstd = torch.clamp(logstd, -20, 2)
    std = logstd.exp()
    dist = Normal(mu, std)
    transforms = [TanhTransform(cache_size=1)]
    dist = TransformedDistribution(dist, transforms)
    action = dist.rsample()
    if get_logprob:
        logprob = dist.log_prob(action).sum(axis=-1, keepdim=True)
    else:
        logprob = None
    mean = torch.tanh(mu)
    return action, logprob, mean

def test_compose_reshape(batch_shape):
    transforms = [ReshapeTransform((), ()),
                  ReshapeTransform((2,), (1, 2)),
                  ReshapeTransform((3, 1, 2), (6,)),
                  ReshapeTransform((6,), (2, 3))]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == 2
    assert transform.domain.event_dim == 2

    data = torch.randn(batch_shape + (3, 2))
    assert transform(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), transforms)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2

def forward(self, x, compute_pi=True, compute_log_pi=True):
    for layer in self.feature_layers:
        x = F.relu(layer(x))
    mu = self.mean_head(x)
    logstd = self.logstd_head(x)
    logstd = torch.tanh(logstd)
    logstd = LOGSTD_MIN + 0.5 * (LOGSTD_MAX - LOGSTD_MIN) * (logstd + 1)

    # TransformedDistribution applies the tanh squashing and its Jacobian correction,
    # replacing the previous manual `mu + noise * std` sampling and squashing code.
    dist = TransformedDistribution(
        Independent(Normal(mu, logstd.exp()), 1), [TanhTransform(cache_size=1)])

    if compute_pi:
        pi = dist.rsample()
    else:
        pi = None

    if compute_log_pi:
        log_pi = dist.log_prob(pi).unsqueeze(-1)
    else:
        log_pi = None

    mu = torch.tanh(mu)

    return mu, pi, log_pi

def forward(self, state, get_logprob=False):
    data = state
    for i in range(len(self.net_layers)):
        data = self.net_layers[i](data)
        data = self.activation(data)
    mu = self.mu_layer(data)
    logsigma = self.logsigma_layer(data)
    logstd = torch.clamp(logsigma, -20, 2)
    std = logstd.exp()
    dist = Normal(mu, std)
    transforms = [TanhTransform(cache_size=1)]
    dist = TransformedDistribution(dist, transforms)
    action = dist.rsample()
    if get_logprob:
        logprob = dist.log_prob(action).sum(axis=-1, keepdim=True)
    else:
        logprob = None
    mean = torch.tanh(mu)
    return action, logprob, mean

def _define_transdist(self, loc, scale):
    """
    Helper method for defining the transition density.
    :param loc: The mean
    :type loc: torch.Tensor
    :param scale: The scale
    :type scale: torch.Tensor
    :return: Distribution
    :rtype: Distribution
    """
    loc, scale = torch.broadcast_tensors(loc, scale)
    shape = _get_shape(loc, self.ndim)

    return TransformedDistribution(
        self.noise.expand(shape), AffineTransform(loc, scale, event_dim=self._event_dim)
    )

def make_conditional_flow(dim=2, hidden_dims=[16, 16], condition_dims={'x': 2},
                          B=10, K=16, num_layers=3, delta=False):
    # prior
    uniform = Uniform(torch.zeros(dim), torch.ones(dim))
    logistic = TransformedDistribution(uniform, SigmoidTransform().inv)

    # conditional flow layers
    coupling_layer = ConditionalNSF
    flows = [coupling_layer(
        dim=dim, K=K, B=B, hidden_dims=hidden_dims,
        condition_dims=condition_dims) for _ in range(num_layers)]
    convs = [Invertible1x1Conv(dim=dim) for _ in flows]
    norms = [ActNorm(dim=dim) for _ in flows]
    flows = list(itertools.chain(*zip(norms, convs, flows)))

    # compose the layers into a conditional flow model
    model = ConditionalNormalizingFlowModel(logistic, flows, delta=delta)
    return model

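# Hedged check, not part of the source: the prior built above, a Uniform(0, 1) pushed
# through SigmoidTransform().inv, is the standard logistic distribution, so its
# log-density can be compared against the closed form log f(x) = x - 2 * softplus(x).
import torch
import torch.nn.functional as F
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import SigmoidTransform

logistic = TransformedDistribution(
    Uniform(torch.zeros(2), torch.ones(2)), SigmoidTransform().inv)

x = torch.randn(5, 2)
expected = x - 2.0 * F.softplus(x)
assert torch.allclose(logistic.log_prob(x), expected, atol=1e-5)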