def forward(self, x: torch.Tensor) -> dist.Normal:  # type: ignore
    """Output a Gaussian distribution parameterized by ``N(x, self.scale^2)``."""
    _validate_input(x)
    return dist.Normal(x, self.scale)
def prior(self):
    return distributions.Normal(self.prior_mu, torch.exp(self.prior_logstd))
def __init__(
    self,
    n_in,
    n_latent,
    n_cat_list,
    n_hidden,
    n_layers,
    t,
    dropout_rate=0.05,
    use_batch_norm=True,
    do_h=True,
):
    """
    Encoder using h representation as described in IAF paper

    :param n_in:
    :param n_latent:
    :param n_cat_list:
    :param n_hidden:
    :param n_layers:
    :param t:
    :param dropout_rate:
    :param use_batch_norm:
    :param do_h:
    """
    super().__init__()
    self.do_h = do_h
    msg = "" if do_h else "Not "
    logger.info(msg="{}Using Hidden State".format(msg))
    self.n_latent = n_latent
    self.encoders = torch.nn.ModuleList()
    self.encoders.append(
        EncoderH(
            n_in=n_in,
            n_out=n_latent,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            do_h=do_h,
            dropout_rate=dropout_rate,
            use_batch_norm=use_batch_norm,
            do_sigmoid=False,
        )
    )
    n_in = 2 * n_latent if do_h else n_latent
    for _ in range(t - 1):
        self.encoders.append(
            EncoderH(
                n_in=n_in,
                n_out=n_latent,
                n_cat_list=None,
                n_layers=n_layers,
                n_hidden=n_hidden,
                do_h=False,
                dropout_rate=dropout_rate,
                use_batch_norm=use_batch_norm,
                do_sigmoid=True,
            )
        )
    self.dist0 = db.Normal(
        loc=torch.zeros(n_latent, device="cuda"),
        scale=torch.ones(n_latent, device="cuda"),
    )
def base_dist(self, major_label):  # input_size
    return D.Normal(self.base_dist_mean[major_label, :],
                    self.base_dist_var[major_label, :])
    data[:, 0:2] = data[:, 0:2] / 20
    return data[:, 0:2], data[:, 2].astype(int)  # np.int was removed in recent NumPy; the built-in int is equivalent here

# %%
X, c = make_pinwheel_data(0.3, 0.05, 1, 512, 0.25)
X = torch.Tensor(X)
c = torch.Tensor(c)

plt.figure(figsize=(10, 10))
plt.scatter(X[:, 0].numpy(), X[:, 1].numpy(), c=c.numpy(), s=5)
plt.show()

# %%
base_dist = distrib.Normal(loc=torch.zeros(2), scale=torch.ones(2))

# %%
X0 = base_dist.sample((1000,)).numpy()

# %%
colors = np.zeros(len(X0))
idx_0 = np.logical_and(X0[:, 0] < 0, X0[:, 1] < 0)
colors[idx_0] = 0
idx_1 = np.logical_and(X0[:, 0] >= 0, X0[:, 1] < 0)
colors[idx_1] = 1
idx_2 = np.logical_and(X0[:, 0] >= 0, X0[:, 1] >= 0)
colors[idx_2] = 2
idx_3 = np.logical_and(X0[:, 0] < 0, X0[:, 1] >= 0)
colors[idx_3] = 3
def __init__(self, args):
    super().__init__()

    C, H, W = args.image_dims
    x_dim = C * H * W

    # --------------------
    # p model -- SSL paper generative semi-supervised model M2
    # --------------------

    self.p_y = D.OneHotCategorical(
        probs=1 / args.y_dim * torch.ones(1, args.y_dim, device=args.device))
    self.p_z = D.Normal(torch.tensor(0., device=args.device),
                        torch.tensor(1., device=args.device))

    # parametrized data likelihood p(x|y,z)
    self.decoder = nn.Sequential(
        # nn.Dropout(0.5),
        nn.Linear(args.z_dim + args.y_dim, args.hidden_dim),
        nn.BatchNorm1d(args.hidden_dim),
        nn.Softplus(),
        nn.Linear(args.hidden_dim, args.hidden_dim),
        nn.BatchNorm1d(args.hidden_dim),
        nn.Softplus(),
        # nn.Dropout(0.5),
        nn.Linear(args.hidden_dim, x_dim),
        nn.Softplus())

    # self.decoder_cnn = nn.Sequential(
    #     # nn.Dropout(0.5),
    #     nn.Conv2d(in_channels=args.image_dims[0], out_channels=10, kernel_size=3, stride=1, padding=1),  # <-- possibly add more channels
    #     nn.BatchNorm2d(10),  # batch normalization before the activation function, as suggested in Ioffe and Szegedy 2015
    #     nn.Softplus(),
    #     nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3, stride=1, padding=1),
    #     nn.BatchNorm2d(20),
    #     nn.Softplus(),
    #     # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
    #     # nn.Dropout(0.4),
    #     nn.Conv2d(in_channels=20, out_channels=args.image_dims[0], kernel_size=3, stride=1, padding=1),
    #     # nn.BatchNorm2d(args.image_dims[0]),
    #     # nn.Softplus()
    # )

    # Transposed conv test
    # Before: 1 -> 10 -> 20 -> 1
    self.decoder_tcnn = nn.Sequential(
        # nn.Dropout(0.5),
        nn.ConvTranspose2d(in_channels=args.image_dims[0], out_channels=10, kernel_size=3, stride=1, padding=1),  # <-- possibly add more channels
        nn.BatchNorm2d(10),  # batch normalization before the activation function, as suggested in Ioffe and Szegedy 2015
        nn.Softplus(),
        nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=2),
        nn.Softplus(),
        nn.ConvTranspose2d(in_channels=20, out_channels=20, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(20),
        nn.Softplus(),
        # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        # nn.Dropout(0.4),
        nn.ConvTranspose2d(in_channels=20, out_channels=args.image_dims[0], kernel_size=3, stride=1, padding=1),
        # nn.BatchNorm2d(args.image_dims[0]),
        # nn.Softplus()
    )

    # --------------------
    # q model -- SSL paper eq 4
    # --------------------

    # parametrized q(y|x) = Cat(y|pi_phi(x)) -- outputs the parametrization of a categorical distribution
    # Before: 1 -> 10 -> 20
    self.encoder_y_cnn = nn.Sequential(
        # nn.Dropout(0.5),
        nn.Conv2d(in_channels=args.image_dims[0], out_channels=10, kernel_size=5, stride=1, padding=2),
        nn.BatchNorm2d(10),  # batch normalization before the activation function, as suggested in Ioffe and Szegedy 2015
        nn.Softplus(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        # nn.Dropout(0.4),
        nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=2),
        # nn.BatchNorm2d(20),
        nn.Softplus(),
        nn.Conv2d(in_channels=20, out_channels=20, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(20),
        nn.Softplus(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0))

    self.encoder_y = nn.Sequential(
        # nn.Dropout(0.5),
        nn.Linear(20 * 7 * 7, args.hidden_dim),  # instead of x_dim
        # nn.BatchNorm1d(args.hidden_dim),
        nn.Softplus(),
        # nn.Linear(args.hidden_dim, args.hidden_dim),
        # nn.Softplus(),
        nn.Linear(args.hidden_dim, args.hidden_dim),
        # nn.BatchNorm1d(args.hidden_dim),
        nn.Softplus(),
        # nn.Dropout(0.5),
        nn.Linear(args.hidden_dim, args.y_dim))

    # parametrized q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x)))
    # -- outputs parametrizations for the mean and diagonal variance of a Normal distribution
    # Before: 1 -> 10 -> 20
    self.encoder_z_cnn = nn.Sequential(
        # nn.Dropout(0.5),
        nn.Conv2d(in_channels=args.image_dims[0], out_channels=10, kernel_size=5, stride=1, padding=2),
        nn.BatchNorm2d(10),  # batch normalization before the activation function, as suggested in Ioffe and Szegedy 2015
        nn.Softplus(),
        nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=2),
        # nn.BatchNorm2d(20),
        nn.Softplus(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        # nn.Dropout(0.4),
        nn.Conv2d(in_channels=20, out_channels=20, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(20),
        nn.Softplus(),
        # nn.Conv2d(in_channels=20, out_channels=20, kernel_size=3, stride=1, padding=1),
        # nn.BatchNorm2d(20),
        # nn.Softplus(),
        # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        # nn.Conv2d(in_channels=20, out_channels=20, kernel_size=3, stride=1, padding=1),
        # nn.BatchNorm2d(20),
        # nn.Softplus(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0))

    self.encoder_z = nn.Sequential(
        # nn.Dropout(0.5),
        nn.Linear(20 * 7 * 7 + args.y_dim, args.hidden_dim),  # instead of x_dim + args.y_dim
        # nn.BatchNorm1d(args.hidden_dim),
        nn.Softplus(),
        nn.Linear(args.hidden_dim, args.hidden_dim),
        # nn.BatchNorm1d(args.hidden_dim),
        nn.Softplus(),
        # nn.Dropout(0.5),
        nn.Linear(args.hidden_dim, 2 * args.z_dim))

    # initialize weights to N(0, 0.001) and biases to 0 (cf. SSL paper section 4.4)
    for p in self.parameters():
        p.data.normal_(0, 0.001)
        if p.ndimension() == 1:
            p.data.fill_(0.)
def __init__(self, loc, scale):
    self.loc = loc
    self.scale = scale

    self.base_dist = pyd.Normal(loc, scale)
    transforms = [TanhTransform()]
    super().__init__(self.base_dist, transforms)
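# The constructor above follows the common "squashed Normal" pattern: a Normal base
# distribution pushed through a tanh transform via TransformedDistribution. A minimal,
# self-contained sketch of that pattern, using torch's built-in TanhTransform rather
# than the custom one referenced above (names below are illustrative):
import torch
import torch.distributions as pyd

base = pyd.Normal(torch.zeros(3), torch.ones(3))
squashed = pyd.TransformedDistribution(base, [pyd.TanhTransform(cache_size=1)])
a = squashed.rsample((5,))             # reparameterized samples in (-1, 1)
log_pi = squashed.log_prob(a).sum(-1)  # log-density includes the tanh Jacobian correction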
def draw_message_distributions_tracker1_2d(mu, sigma, save_dir, azim):
    # sigma *= torch.where(sigma < 0.01, torch.ones(sigma.shape).cuda(), 10 * torch.ones(sigma.shape).cuda())
    mu = mu.view(-1, 2)
    sigma = sigma.view(-1, 2)
    s_mu = torch.Tensor([0.0, 0.0])
    s_sigma = torch.Tensor([1.0, 1.0])

    x = y = np.arange(100)
    t = np.meshgrid(x, y)

    d = D.Normal(s_mu, s_sigma)
    d1 = D.Normal(mu[0], sigma[0])
    d21 = D.Normal(mu[2], sigma[2])
    d22 = D.Normal(mu[3], sigma[3])
    d31 = D.Normal(mu[4], sigma[4])
    d32 = D.Normal(mu[5], sigma[5])

    print('Entropy')
    print(d1.entropy().detach().cpu().numpy())
    print(d21.entropy().detach().cpu().numpy())
    print(d22.entropy().detach().cpu().numpy())
    print(d31.entropy().detach().cpu().numpy())
    print(d32.entropy().detach().cpu().numpy())

    print('KL Divergence')
    for tt_i in range(3):
        d1 = D.Normal(mu[tt_i * 2 + 0], sigma[tt_i * 2 + 0])
        d2 = D.Normal(mu[tt_i * 2 + 1], sigma[tt_i * 2 + 1])
        print(tt_i,
              D.kl_divergence(d1, d2).sum().detach().cpu().numpy(),
              D.kl_divergence(d1, d).sum().detach().cpu().numpy(),
              D.kl_divergence(d2, d).sum().detach().cpu().numpy(),
              sigma[tt_i * 2 + 0].mean().detach().cpu().numpy(),
              sigma[tt_i * 2 + 1].mean().detach().cpu().numpy())

    # Numpy arrays of mu and sigma
    s_mu_ = s_mu.detach().cpu().numpy()
    mu_0 = mu[0].detach().cpu().numpy()
    mu_2 = mu[2].detach().cpu().numpy()
    mu_3 = mu[3].detach().cpu().numpy()
    mu_4 = mu[4].detach().cpu().numpy()
    mu_5 = mu[5].detach().cpu().numpy()
    s_sigma_ = s_sigma.detach().cpu().numpy()
    sigma_0 = sigma[0].detach().cpu().numpy()
    sigma_2 = sigma[2].detach().cpu().numpy()
    sigma_3 = sigma[3].detach().cpu().numpy()
    sigma_4 = sigma[4].detach().cpu().numpy()
    sigma_5 = sigma[5].detach().cpu().numpy()

    # Print
    print('mu and sigma')
    print(mu_0, sigma_0)
    print(mu_2, sigma_2)
    print(mu_3, sigma_3)
    print(mu_4, sigma_4)
    print(mu_5, sigma_5)

    # Create grid
    x = np.linspace(-5, 5, 5000)
    y = np.linspace(-5, 5, 5000)
    X, Y = np.meshgrid(x, y)
    pos = np.empty(X.shape + (2,))
    pos[:, :, 0] = X
    pos[:, :, 1] = Y

    rv = multivariate_normal(s_mu_, [[s_sigma[0], 0], [0, s_sigma[1]]])

    # Agent 1
    # Create multivariate normal
    rv1 = multivariate_normal(mu_0, [[sigma_0[0], 0], [0, sigma_0[1]]])

    # Make a 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') was removed in newer Matplotlib
    ax.plot_surface(X, Y, rv.pdf(pos) + rv1.pdf(pos), cmap='viridis', linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'agent1.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("agent1_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "agent1_90_0.png")
    # plt.show()
    plt.close()

    # Agent 2
    # Create multivariate normals
    rv21 = multivariate_normal(mu_2, [[sigma_2[0], 0], [0, sigma_2[1]]])
    rv22 = multivariate_normal(mu_3, [[sigma_3[0], 0], [0, sigma_3[1]]])

    # Make a 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X, Y, rv.pdf(pos) + rv21.pdf(pos) + rv22.pdf(pos), cmap='viridis', linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'agent2.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("agent2_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "agent2_90_0.png")
    # plt.show()
    plt.close()

    # Agent 3
    rv31 = multivariate_normal(mu_4, [[sigma_4[0], 0], [0, sigma_4[1]]])
    rv32 = multivariate_normal(mu_5, [[sigma_5[0], 0], [0, sigma_5[1]]])

    # Make a 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X, Y, rv.pdf(pos) + rv31.pdf(pos) + rv32.pdf(pos), cmap='viridis', linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'agent3.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("agent3_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "agent3_90_0.png")
    # plt.show()
    plt.close()

    # Overall
    # Make a 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X, Y,
                    rv.pdf(pos) + rv1.pdf(pos) + rv21.pdf(pos) + rv22.pdf(pos) + rv31.pdf(pos) + rv32.pdf(pos),
                    cmap='viridis', linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'overall.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("overall_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "overall_90_0.png")
    # plt.show()
    plt.close()
def base_dist(self):
    return D.Normal(self.base_dist_mean, self.base_dist_var)
def __init__(
        self,
        x_dim,
        y_dim,
        z_dim,
        w_dim,
        predict_x_var=False,
        labels_mutually_exclusive=True,
        use_bernoulli_y=False,
        z_prior_var=1,
        likelihood_partition=None,
        likelihood_params={
            'lik_var': 0.1**2,
            'lik_var_lognormal': 0.1**2
        },
        pw_y_means=torch.tensor([[-1.], [1.]]),
        pw_y_vars=torch.tensor([[0.5**2], [0.5**2]]),
        show_val_loss=False,
        beta1=20,   # reconstruction error
        beta2=1,    # how strongly we want the identification network for q(w|x,y) to mimic the prior p(w|y)
        beta3=0.2,  # how strongly we want the identification network q(z|x) to mimic the prior p(z)
        beta4=10,   # beta4 and beta5 control how strongly we want to enforce the mutual information
        beta5=1,    # minimization between w and z. If too small, z will encode information about w. Should be tweaked.
        optimizer1=None,
        optimizer2=None,
        device='cpu'):
    super().__init__()

    self.x_dim = x_dim
    self.y_dim = y_dim
    self.z_dim = z_dim
    self.w_dim = w_dim

    print(f"x_dim: {x_dim}")
    print(f"y_dim: {y_dim}")
    print(f"z_dim: {z_dim}")
    print(f"w_dim: {w_dim}")
    print(f"beta1: {beta1}")
    print(f"beta2: {beta2}")
    print(f"beta3: {beta3}")
    print(f"beta4: {beta4}")
    print(f"beta5: {beta5}")

    if likelihood_partition is None:
        # range for feature types
        self.likelihood_partition = {(0, x_dim - 1): 'real'}
    else:
        self.likelihood_partition = likelihood_partition

    self.likelihood_params = likelihood_params
    self.z_prior_var = z_prior_var
    self.device = device
    self.predict_x_var = predict_x_var
    self.labels_mutually_exclusive = labels_mutually_exclusive

    # prior on z
    p_z_loc = torch.zeros(z_dim).to(device)
    cov_diag = z_prior_var * torch.ones(z_dim).to(device)
    self.p_z = D.Normal(loc=p_z_loc, scale=cov_diag.sqrt())

    # prior on y
    if labels_mutually_exclusive:
        if not use_bernoulli_y:
            self.p_y = D.Categorical(probs=1 / y_dim * torch.ones(1, y_dim))
        else:
            if y_dim != 2:
                raise ValueError("using bernoulli y with a y_dim != 2")
            self.p_y = D.Bernoulli(probs=0.5)
    else:
        self.p_y = D.Bernoulli(probs=0.5 * torch.ones(y_dim))

    self.decoder_x = nn.Sequential(nn.Linear(z_dim + w_dim, 64), nn.ReLU(),
                                   nn.Linear(64, 64), nn.ReLU(),
                                   nn.Linear(64, x_dim))

    self.encoder_w = nn.Sequential(
        nn.Linear(x_dim + ((y_dim - 1) if y_dim == 2 else y_dim), 64),
        nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 2 * w_dim))

    self.encoder_z = nn.Sequential(nn.Linear(x_dim, 64), nn.ReLU(),
                                   nn.Linear(64, 64), nn.ReLU(),
                                   nn.Linear(64, 2 * z_dim))

    self.optim_1_params = chain(self.decoder_x.parameters(),
                                self.encoder_w.parameters(),
                                self.encoder_z.parameters())

    self.decoder_y = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(),
                                   nn.Linear(64, 64), nn.ReLU(),
                                   nn.Linear(64, y_dim))

    self.optim_2_params = chain(self.decoder_y.parameters())

    if optimizer1 is None:
        self.optimizer1 = torch.optim.Adam(self.optim_1_params, lr=1e-3)
    else:
        self.optimizer1 = optimizer1

    if optimizer2 is None:
        self.optimizer2 = torch.optim.Adam(self.optim_2_params, lr=1e-3)
    else:
        self.optimizer2 = optimizer2

    self.pw_y_means = pw_y_means
    self.pw_y_vars = pw_y_vars
    self.show_val_loss = show_val_loss

    self.beta1 = beta1
    self.beta2 = beta2
    self.beta3 = beta3
    self.beta4 = beta4
    self.beta5 = beta5

    self.pw_y0 = self.get_pw_y(torch.tensor(0))
    self.pw_y1 = self.get_pw_y(torch.tensor(1))
def kde_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
) -> DataLoader:
    # %
    gd_n_steps, gd_lr, gd_threshold = 5, 4e-1, 0.005

    # Spherical = each component has a single variance.
    bgm = BayesianGaussianMixture(
        n_components=batch_size,
        covariance_type="spherical",
        warm_start=True,
    )

    x_hat = torch.Tensor()
    for idx, batch in enumerate(iter(dm.train_dataloader())):
        x, _ = batch
        device = x.device
        x = x.detach().cpu().numpy()

        # The last batch might have fewer elements than the original n_components
        if x.shape[0] < bgm.n_components:
            bgm = BayesianGaussianMixture(
                n_components=x.shape[0],
                covariance_type="spherical",
            )

        # Estimate KDE
        bgm.fit(x)

        # [N_components, 1], [N_components, N_features], [N_components, 1]
        weights, means, variances = (
            torch.Tensor(bgm.weights_).to(device),
            torch.Tensor(bgm.means_).to(device),
            torch.Tensor(bgm.covariances_).to(device),
        )
        filter_weights_idx = weights >= 1e-5
        weights, means, variances = (
            weights[filter_weights_idx],
            means[filter_weights_idx],
            variances[filter_weights_idx][:, None],
        )
        n_selected_components = weights.shape[0]

        p_x = D.Independent(D.Normal(means, torch.sqrt(variances)), 1)
        mix = D.Categorical(weights)
        p_x = D.MixtureSameFamily(mix, p_x)

        # Sample according to the multiplier (cast to int: sample sizes must be integers)
        n_samples = int(
            n_selected_components
            * ((batch_size // n_selected_components) + 1)
            * N_hat_multiplier
        )
        x_start = p_x.sample((n_samples,)).reshape(-1, x.shape[1])

        # Use GD
        _x_hat = density_gradient_descent(
            p_x,
            x_start,
            {"N_steps": gd_n_steps, "lr": gd_lr, "threshold": gd_threshold},
        )

        # Ensure same device
        if x_hat.device != device:
            x_hat = x_hat.to(device)

        x_hat = torch.cat((x_hat, _x_hat.detach()))

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
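# A minimal sketch of the mixture construction used above: a MixtureSameFamily over
# diagonal Gaussians with Categorical mixing weights (all values are illustrative).
import torch
import torch.distributions as D

weights = torch.tensor([0.2, 0.5, 0.3])            # [N_components]
means = torch.randn(3, 4)                          # [N_components, N_features]
scales = torch.ones(3, 1).expand(3, 4)             # spherical: one scale per component
comp = D.Independent(D.Normal(means, scales), 1)   # batch_shape (3,), event_shape (4,)
p_x = D.MixtureSameFamily(D.Categorical(probs=weights), comp)
samples = p_x.sample((16,))                        # [16, 4]
logp = p_x.log_prob(samples)                       # [16]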
def forward(self, query, key, value):
    batch_size = query.shape[0]
    maxlen = query.shape[1]

    Q = self.fc_q(query)
    K = self.fc_k(key)
    V = self.fc_v(value)
    # Q = [batch size, query len, hid dim]
    # K = [batch size, key len, hid dim]
    # V = [batch size, value len, hid dim]

    Q = Q.view(batch_size, maxlen, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    K = K.view(batch_size, maxlen, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    V = V.view(batch_size, maxlen, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    # Q = [batch size, n heads, query len, head dim]
    # K = [batch size, n heads, key len, head dim]
    # V = [batch size, n heads, value len, head dim]

    KLD = torch.tensor(0.0)

    if self.args.att_type == 'dot':
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
    elif self.args.att_type == 'ikandirect':
        w1_proj = self.attsharedw.w1
        w2_proj = self.attsharedw.w2
        scores, norm = ika_ns(Q, K, self.args, self.scale, w1_proj, w2_proj,
                              2 * np.pi, self.training)
        energy = torch.log(scores + 1e-5) + norm
    elif self.args.att_type == 'mikan':
        # Copula-augmented estimation
        mu, logvar, L = self.attsharedw.copulanet(Q, K)
        mu = mu.squeeze(-1)
        logvar = logvar.squeeze(-1)
        var = torch.exp(logvar)

        dim_batch_size, num_head, num_head = L.size()
        dim = int(dim_batch_size / batch_size)

        pos_eps = torch.randn([dim, num_head, self.args.M // 2]).cuda()  # [64, 8, 128 (M/2)]
        X_pos = torch.einsum('ijk,ijl->ijl', L, pos_eps)                 # [64, 8, 128 (M/2)]
        X_pos = torch.clamp(X_pos, min=-2.0, max=2.0)
        U_pos = self.standard_normal_dist.cdf(X_pos)                     # [64, num_head, 128 (M/2)]

        neg_eps = torch.randn([dim, num_head, self.args.M // 2]).cuda()  # [64, 8, 128 (M/2)]
        X_neg = torch.einsum('ijk,ijl->ijl', L, neg_eps)                 # [64, 8, 128 (M/2)]
        X_neg = torch.clamp(X_neg, min=-2.0, max=2.0)
        U_neg = self.standard_normal_dist.cdf(X_neg)                     # [64, num_head, 128 (M/2)]

        marginal_pos = Normal(mu.unsqueeze(-1), var.unsqueeze(-1))       # mu: [64, num_head] / var: [64, num_head]
        marginal_neg = Normal(-1 * mu.unsqueeze(-1), var.unsqueeze(-1))  # mu: [64, num_head] / var: [64, num_head]
        Y_pos = marginal_pos.icdf(U_pos)                                 # [32, 4, 64]
        Y_neg = marginal_neg.icdf(U_neg)

        U = torch.cat([U_pos, U_neg])
        ent_copula = -1 * torch.sum(torch.mul(U, torch.log(U + 1e-5)))

        # Kernel and norm calculation
        z = torch.cat([Y_pos, Y_neg], -1)  # torch.Size([1, 64, 4, 256])
        w1_proj = self.attsharedw.wnet1(z)
        w2_proj = self.attsharedw.wnet2(z)
        scores, norm = ika_ns(Q, K, self.args, self.scale, w1_proj, w2_proj,
                              2 * np.pi, self.training)
        energy = torch.log(scores + 1e-5) + norm
        # energy = [batch size, n heads, query len, key len]

        q_dist = tdist.Normal(mu, logvar.exp())
        KLD = torch.distributions.kl_divergence(q_dist, self.p_dist)
        KLD = self.args.kl_lambda * torch.sum(KLD) + self.args.copula_lambda * ent_copula

    attention = torch.softmax(energy, dim=-1)
    # attention = [batch size, n heads, query len, key len]

    x = torch.matmul(self.dropout(attention), V)
    # x = [batch size, n heads, query len, head dim]

    x = x.permute(0, 2, 1, 3).contiguous()
    # x = [batch size, query len, n heads, head dim]

    x = x.view(batch_size, -1, self.args.KEY_DIM)
    # x = [batch size, query len, hid dim]

    x = self.fc_o(x)
    # x = [batch size, query len, hid dim]

    return x, attention, KLD
def likelihood(self, z):
    loc, scale = self.decoder(z)
    return dist.Normal(loc, scale)
def _train(replay_buffer, net, teacher_net, criterion, coord_converter, logger,
           config, episode):
    import torch
    from phase2_utils import _log_visuals, get_weight, repeat
    import torch.distributions as tdist

    noiser = tdist.Normal(torch.tensor(0.0), torch.tensor(config['speed_noise']))

    teacher_net.eval()

    for epoch in range(config['epoch_per_episode']):
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
        net.train()
        replay_buffer.init_new_weights()

        loader = torch.utils.data.DataLoader(
            replay_buffer,
            batch_size=config['batch_size'],
            num_workers=4,
            shuffle=True,
            drop_last=True)

        for i, (idxes, rgb_image, command, speed, target, birdview) in enumerate(loader):
            if i % 100 == 0:
                print("ITER: %d" % i)

            rgb_image = rgb_image.to(config['device']).float()
            birdview = birdview.to(config['device']).float()
            command = one_hot(command).to(config['device']).float()
            speed = speed.to(config['device']).float()

            if config['speed_noise'] > 0:
                speed += noiser.sample(speed.size()).to(speed.device)
                speed = torch.clamp(speed, 0, 10)

            if len(rgb_image.size()) > 4:
                B, batch_aug, c, h, w = rgb_image.size()
                rgb_image = rgb_image.view(B * batch_aug, c, h, w)
                birdview = repeat(birdview, batch_aug)
                command = repeat(command, batch_aug)
                speed = repeat(speed, batch_aug)
            else:
                B = rgb_image.size(0)
                batch_aug = 1

            with torch.no_grad():
                _teac_location, _teac_locations = teacher_net(birdview, speed, command)

            _pred_location, _pred_locations = net(rgb_image, speed, command)
            pred_location = coord_converter(_pred_location)
            pred_locations = coord_converter(_pred_locations)

            optimizer.zero_grad()
            loss = criterion(pred_locations, _teac_locations)

            # Compute resample weights
            pred_location_normed = pred_location / (0.5 * CROP_SIZE) - 1.
            weights = get_weight(pred_location_normed, _teac_location)
            weights = torch.mean(torch.stack(torch.chunk(weights, B)), dim=0)

            replay_buffer.update_weights(idxes, weights)

            loss_mean = loss.mean()
            loss_mean.backward()
            optimizer.step()

            should_log = False
            should_log |= i % config['log_iterations'] == 0

            if should_log:
                metrics = dict()
                metrics['loss'] = loss_mean.item()
                images = _log_visuals(
                    rgb_image, birdview, speed, command, loss,
                    pred_location,
                    (_pred_location + 1) * coord_converter._img_size / 2,
                    _teac_location)

                logger.scalar(loss_mean=loss_mean.item())
                logger.image(birdview=images)

        replay_buffer.normalize_weights()

        rgb_image, birdview, command, speed, target = replay_buffer.get_highest_k(32)
        rgb_image = rgb_image.to(config['device']).float()
        birdview = birdview.to(config['device']).float()
        command = one_hot(command).to(config['device']).float()
        speed = speed.to(config['device']).float()

        with torch.no_grad():
            _teac_location, _teac_locations = teacher_net(birdview, speed, command)

        net.eval()
        _pred_location, _pred_locations = net(rgb_image, speed, command)
        pred_location = coord_converter(_pred_location)
        pred_locations = coord_converter(_pred_locations)
        pred_location_normed = pred_location / (0.5 * CROP_SIZE) - 1.
        weights = get_weight(pred_location_normed, _teac_location)

        # TODO: Plot highest
        images = _log_visuals(
            rgb_image, birdview, speed, command, weights,
            pred_location,
            (_pred_location + 1) * coord_converter._img_size / 2,
            _teac_location)

        logger.image(topk=images)
        logger.end_epoch()

        if episode in SAVE_EPISODES:
            torch.save(net.state_dict(),
                       str(Path(config['log_dir']) / ('model-%d.th' % episode)))
def forward(self, x):
    return td.Normal(x, self.log_scale.exp())
def __call__(self, wav):
    # Note: Normal's second argument is the scale (standard deviation),
    # so self.var is treated as a std here.
    noiser = distributions.Normal(0, self.var)
    if np.random.uniform() < 0.5:
        wav += noiser.sample(wav.size())
    return wav.clamp(-1, 1)
                                       retain_graph=True, create_graph=True)[0]
            eps = torch.randn_like(dfdx)
            epsH = torch.autograd.grad(dfdx, x, grad_outputs=eps,
                                       create_graph=True, retain_graph=True)[0]

            trH = (epsH * eps).sum(1)
            norm_s = (dfdx * dfdx).sum(1)

            loss = (trH + .5 * norm_s).mean()
        elif args.mode == "nce":
            noise_dist = distributions.Normal(init_mu, init_std)
            x_fake = noise_dist.sample((x.size(0),))  # sample_n is deprecated; sample((n,)) is equivalent
            pos_logits = modelICA(x) + approx_normalizing_const - noise_dist.log_prob(x).sum(1)
            neg_logits = modelICA(x_fake) + approx_normalizing_const - noise_dist.log_prob(x_fake).sum(1)
            pos_loss = nn.BCEWithLogitsLoss()(pos_logits, torch.ones_like(pos_logits))
            neg_loss = nn.BCEWithLogitsLoss()(neg_logits, torch.zeros_like(neg_logits))

            loss = pos_loss + neg_loss
def MIWAE(X_miss, h=128, d=1, K=1000, L=20, bs=64, n_epochs=201):
    mask = (1 - np.isnan(X_miss).numpy()).astype(bool)
    xhat_0 = np.where(np.isnan(X_miss), 0, X_miss)

    n = np.shape(X_miss)[0]
    p = np.shape(X_miss)[1]

    p_z = td.Independent(
        td.Normal(loc=torch.zeros(d).cuda(), scale=torch.ones(d).cuda()), 1)

    decoder = nn.Sequential(
        torch.nn.Linear(d, h),
        torch.nn.ReLU(),
        torch.nn.Linear(h, h),
        torch.nn.ReLU(),
        torch.nn.Linear(h, 3 * p),  # the decoder outputs the mean, the scale, and the number of degrees of freedom (hence the 3*p)
    )

    encoder = nn.Sequential(
        torch.nn.Linear(p, h),
        torch.nn.ReLU(),
        torch.nn.Linear(h, h),
        torch.nn.ReLU(),
        torch.nn.Linear(h, 2 * d),  # the encoder outputs both the mean and the diagonal covariance
    )

    encoder.cuda()  # we'll use the GPU
    decoder.cuda()

    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

    xhat = np.copy(xhat_0)  # this will be our imputed data matrix

    encoder.apply(weights_init)
    decoder.apply(weights_init)

    for ep in range(1, n_epochs):
        perm = np.random.permutation(n)  # we use the "random reshuffling" version of SGD
        batches_data = np.array_split(xhat_0[perm, ], int(n / bs))
        batches_mask = np.array_split(mask[perm, ], int(n / bs))
        for it in range(len(batches_data)):
            optimizer.zero_grad()
            encoder.zero_grad()
            decoder.zero_grad()
            b_data = torch.from_numpy(batches_data[it]).float().cuda()
            b_mask = torch.from_numpy(batches_mask[it]).float().cuda()
            loss = miwae_loss(iota_x=b_data, mask=b_mask, d=d, K=K, p_z=p_z,
                              encoder=encoder, decoder=decoder)
            loss.backward()
            optimizer.step()

    # Imputation
    xhat[~mask] = miwae_impute(
        iota_x=torch.from_numpy(xhat_0).float().cuda(),
        mask=torch.from_numpy(mask).float().cuda(),
        L=L,
        d=d,
        p_z=p_z,
        encoder=encoder,
        decoder=decoder).cpu().data.numpy()[~mask]

    x_miwae = np.where(np.isnan(X_miss), xhat, X_miss)

    return x_miwae
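# A hedged sketch (not the miwae_loss used above) of the importance-weighted bound that
# MIWAE maximizes: log (1/K) * sum_k p(x_obs | z_k) p(z_k) / q(z_k | x), estimated
# stably with a logsumexp over K importance samples.
import math
import torch

def iwae_bound(log_p_x_given_z, log_p_z, log_q_z_given_x):
    """All inputs have shape [K, batch]; returns the per-datapoint bound, shape [batch]."""
    log_w = log_p_x_given_z + log_p_z - log_q_z_given_x
    return torch.logsumexp(log_w, dim=0) - math.log(log_w.shape[0])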
    screens, obs, acts, rewards, dones, next_screens, next_obs = batch
    batch = []
    k = 0

    # Double-Q(s,a)
    q1, q2 = crt_net(screens, obs, acts)
    q_min = torch.min(q1, q2).squeeze()
    tracker.track("Q(s,a)_batch_mean", q_min.mean(), exp_count)

    # Double-Q(s',a')
    next_mu, next_std = act_net(next_screens, next_obs)  # no action clipping here
    next_q1, next_q2 = tgt_crt_net(next_screens, next_obs, next_mu)  # take expected action, what else?

    # Q_ref
    next_acts_distr = distr.Normal(next_mu, next_std)
    v_next = torch.min(next_q1, next_q2).squeeze() - \
        ENTROPY_ALPHA * next_acts_distr.log_prob(next_mu).sum(dim=1)  # eqn (3) in 2nd SAC paper
    v_next[dones] = 0.0  # if terminated, local reward makes up the value already
    q_ref = (rewards + v_next * (GAMMA ** UNROLL)).unsqueeze(dim=-1)
    tracker.track("Q_ref(s,a)_batch_mean", q_ref.mean(), exp_count)

    # Critic loss (on both Double-Q heads to update all parameters)
    q1_loss = F.mse_loss(q1, q_ref.detach())  # eqn (5) in 2nd SAC paper
    q2_loss = F.mse_loss(q2, q_ref.detach())
    q_loss = q1_loss + q2_loss
    q_loss.backward(retain_graph=True)
    crt_opt.step()
    tracker.track("Loss_crt", q_loss, exp_count)
def forward(self, x: torch.Tensor):
    loss = -tdist.Normal(self.mu, self.sigma).log_prob(x)
    return torch.sum(loss) / (loss.size(0) * loss.size(1))
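# Sanity sketch for the negative log-likelihood above: Normal.log_prob is the closed-form
# Gaussian log-density -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2 * pi)
# (the shapes below are illustrative).
import math
import torch
import torch.distributions as tdist

x = torch.randn(4, 3)
mu, sigma = torch.zeros(3), torch.ones(3)
lp = tdist.Normal(mu, sigma).log_prob(x)
manual = -0.5 * ((x - mu) / sigma) ** 2 - torch.log(sigma) - 0.5 * math.log(2 * math.pi)
assert torch.allclose(lp, manual)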
def sample_normal(means, **kwargs):
    return sample_(D.Normal(means, **kwargs))
def test_non_empty_bit_vector(batch_shape=tuple(), K=3):
    assert K <= 10, "I test against explicit enumeration, K>10 might be too slow for that"

    # Uniform
    F = NonEmptyBitVector(torch.zeros(batch_shape + (K,)))

    # Shapes
    assert F.batch_shape == batch_shape, "NonEmptyBitVector has the wrong batch_shape"
    assert F.dim == K, "NonEmptyBitVector has the wrong dim"
    assert F.event_shape == (K,), "NonEmptyBitVector has the wrong event_shape"
    assert F.scores.shape == batch_shape + (K,), "NonEmptyBitVector.score has the wrong shape"
    assert F.arc_weight.shape == batch_shape + (K + 1, 3, 3), "NonEmptyBitVector.arc_weight has the wrong shape"
    assert F.state_value.shape == batch_shape + (K + 2, 3), "NonEmptyBitVector.state_value has the wrong shape"
    assert F.state_rvalue.shape == batch_shape + (K + 2, 3), "NonEmptyBitVector.state_rvalue has the wrong shape"

    # shape: [num_faces] + batch_shape + [K]
    support = F.enumerate_support()

    # test shape of support
    assert support.shape == (2 ** K - 1,) + batch_shape + (K,), "The support has the wrong shape"

    assert F.expand((2, 3) + batch_shape).batch_shape == (2, 3) + batch_shape, "Bad expand batch_shape"
    assert F.expand((2, 3) + batch_shape).event_shape == (K,), "Bad expand event_shape"
    assert F.expand((2, 3) + batch_shape).sample().shape == (2, 3) + batch_shape + (K,), "Bad expand single sample"
    assert F.expand((2, 3) + batch_shape).sample((13,)).shape == (13, 2, 3) + batch_shape + (K,), "Bad expand multiple samples"

    # Constraints
    assert (support.sum(-1) > 0).all(), "The support has an empty bit vector"
    for _ in range(100):
        # testing one sample at a time
        assert F.sample().sum(-1).all(), "I found an empty vector"
    # testing a batch of samples
    assert F.sample((100,)).sum(-1).all(), "I found an empty vector"
    # testing a complex batch of samples
    assert F.sample((2, 100,)).sum(-1).all(), "I found an empty vector"

    # Distribution
    # check for uniform probabilities
    assert torch.isclose(F.log_prob(support).exp(), torch.tensor(1. / F.support_size)).all(), "Non-uniform"
    # check for uniform marginal probabilities
    assert torch.isclose(F.sample((10000,)).float().mean(0), support.mean(0), atol=1e-1).all(), "Bad MC marginals"
    assert torch.isclose(F.marginals(), support.mean(0)).all(), "Bad exact marginals"

    # Entropy
    # [num_faces, B]
    log_prob = F.log_prob(support)
    assert torch.isclose(F.entropy(), -(log_prob.exp() * log_prob).sum(0), atol=1e-2).all(), "Problem in the entropy DP"

    # Non-Uniform

    # Entropy
    P = NonEmptyBitVector(td.Normal(torch.zeros(batch_shape + (K,)), torch.ones(batch_shape + (K,))).sample())
    log_p = P.log_prob(support)
    assert torch.isclose(P.entropy(), -(log_p.exp() * log_p).sum(0), atol=1e-2).all(), "Problem in the entropy DP"

    # Cross-Entropy
    Q = NonEmptyBitVector(td.Normal(torch.zeros(batch_shape + (K,)), torch.ones(batch_shape + (K,))).sample())
    log_q = Q.log_prob(support)
    assert torch.isclose(P.cross_entropy(Q), -(log_p.exp() * log_q).sum(0), atol=1e-2).all(), "Problem in the cross-entropy DP"

    # KL
    assert torch.isclose(td.kl_divergence(P, Q), (log_p.exp() * (log_p - log_q)).sum(0), atol=1e-2).all(), "Problem in KL"

    # Constraints
    for _ in range(100):
        # testing one sample at a time
        assert P.sample().sum(-1).all(), "I found an empty vector"
        assert Q.sample().sum(-1).all(), "I found an empty vector"
    # testing a batch of samples
    assert P.sample((100,)).sum(-1).all(), "I found an empty vector"
    assert Q.sample((100,)).sum(-1).all(), "I found an empty vector"
    # testing a complex batch of samples
    assert P.sample((2, 100,)).sum(-1).all(), "I found an empty vector"
    assert Q.sample((2, 100,)).sum(-1).all(), "I found an empty vector"
def simulate_Y(H, Params):
    mu, sigma = torch.matmul(H, Params.W_mu_y), torch.matmul(H, Params.W_sigma_y)
    Y = dist.Normal(mu, sigma).sample()
    return Y
            continue

        # generate labels for the real batch of data: the (k+1)-th element is 1, the rest are zero
        D_label_real = utils.get_labels(num_generators, -1, real_b_size, device)

        # forward pass for the real batch of data and then resize
        gen_input_noise = utils.generate_noise_for_generator(
            real_b_size // num_generators, n_z, device)
        gen_output = generator(gen_input_noise)  # , real_b_size // num_generators
        gen_out_d_in = gen_output.detach()

        ##############################################################
        norm = dist.Normal(torch.tensor([NOISE_MEAN]), torch.tensor([NOISE_DEV]))
        if add_noise == 1:
            x_noise = norm.sample(gen_out_d_in.size()).view(gen_out_d_in.size()).to(device)
            gen_out_d_in = gen_out_d_in + x_noise
        #################################################################

        D_Label_Fake = []
        for g in range(num_generators):
            D_Label_Fake.append(
                utils.get_labels(num_generators, g, real_b_size // num_generators, device))

        D_Label_Fake = torch.cat(D_Label_Fake)
        D_Labels = torch.cat([D_label_real, D_Label_Fake])
def _train_continuous(self, BATCH):
    q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)   # [T, B, 1]
    q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]

    if self.is_continuous:
        target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
        target_pi = dist.sample()  # [T, B, A]
        target_pi, target_log_pi = squash_action(
            target_pi, dist.log_prob(target_pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
    else:
        target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        target_cate_dist = td.Categorical(logits=target_logits)
        target_pi = target_cate_dist.sample()  # [T, B]
        target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
        target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]

    q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)   # [T, B, 1]
    q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         (q_target - self.alpha * target_log_pi),
                         BATCH.begin_mask).detach()  # [T, B, 1]

    td_error1 = q1 - dc_r  # [T, B, 1]
    td_error2 = q2 - dc_r  # [T, B, 1]

    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
    self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        pi = dist.rsample()  # [T, B, A]
        pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        entropy = dist.entropy().mean()  # 1
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        pi = _pi_diff + _pi  # [T, B, A]
        log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1

    q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
                        self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask))  # [T, B, 1]
    actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    summaries = {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/q1_loss': q1_loss,
        'LOSS/q2_loss': q2_loss,
        'LOSS/critic_loss': critic_loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/entropy': entropy,
        'Statistics/q_min': th.minimum(q1, q2).min(),
        'Statistics/q_mean': th.minimum(q1, q2).mean(),
        'Statistics/q_max': th.maximum(q1, q2).max()
    }
    if self.auto_adaption:
        alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean()  # 1
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
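# A hedged sketch of the tanh squashing assumed by squash_action above (the actual helper
# in this codebase may differ): squash the pre-activation action u with tanh and apply the
# standard SAC correction log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2 + eps).
import torch

def squash_action_sketch(u, log_pi_u, eps=1e-6):
    """u: [T, B, A] pre-tanh action, log_pi_u: [T, B, 1] Gaussian log-prob of u."""
    a = torch.tanh(u)                                                    # [T, B, A]
    correction = torch.log(1.0 - a.pow(2) + eps).sum(-1, keepdim=True)   # [T, B, 1]
    return a, log_pi_u - correction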
def __init__(self, loc, scale_diag, validate_args=None):
    base_distribution = dists.Normal(loc=loc, scale=scale_diag)
    super().__init__(base_distribution,
                     reinterpreted_batch_ndims=1,
                     validate_args=validate_args)
def _prior(self, batch_size, sequence_len):
    mu = torch.zeros(batch_size, sequence_len, self._z_dim)
    std = torch.ones_like(mu)
    return td.Independent(td.Normal(mu, std), 1)
def __init__(self, loc, scale, reinterpreted_batch_ndims=None, **kwargs):
    if reinterpreted_batch_ndims is None:
        # Handle 1D input, e.g. Normal(0, 1)
        reinterpreted_batch_ndims = 1 if isinstance(loc, torch.Tensor) else 0
    super().__init__(td.Normal(loc, scale, **kwargs), reinterpreted_batch_ndims)
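# Shape sketch for the Independent wrappers above: Independent moves the last
# `reinterpreted_batch_ndims` batch dimensions into the event, so log_prob sums over them
# (values below are illustrative).
import torch
import torch.distributions as td

loc, scale = torch.zeros(5, 3), torch.ones(5, 3)
plain = td.Normal(loc, scale)        # batch_shape (5, 3), event_shape ()
diag = td.Independent(plain, 1)      # batch_shape (5,),   event_shape (3,)
x = diag.sample()                    # shape (5, 3)
assert diag.log_prob(x).shape == (5,)  # one scalar per 3-dimensional event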
def forward(self, x=None):
    return dist.Normal(self.mu, self.sigma)
def _component_sample(self, idxs):
    mean, std = self.mean[idxs], self.log_std[idxs].exp()
    return distributions.Normal(mean, std).sample()