batch_size, device, dtype) # Train C2ST-L np.random.seed(seed=819*kk + 1 + i + n) y = (torch.cat((torch.zeros(N1, 1), torch.ones(N2, 1)), 0)).squeeze(1).to(device, dtype).long() pred, STAT_C2ST_L, model_C2ST_L, w_C2ST_L, b_C2ST_L = C2ST_NN_fit(S, y, N1, x_in, H, x_out, 0.001, N_epoch_C, batch_size, device, dtype) # Train MMD-O np.random.seed(seed=1102) torch.manual_seed(1102) torch.cuda.manual_seed(1102) sigma0 = 2*d * torch.rand([1]).to(device, dtype) sigma0.requires_grad = True optimizer_sigma0 = torch.optim.Adam([sigma0], lr=learning_ratea) for t in range(N_epoch): TEMPa = MMDu(S, N1, S, 0, sigma0, is_smooth=False) mmd_value_tempa = -1 * (TEMPa[0]+10**(-8)) mmd_std_tempa = torch.sqrt(TEMPa[1]+10**(-8)) if mmd_std_tempa.item() == 0: print('error!!') if np.isnan(mmd_std_tempa.item()): print('error!!') STAT_adaptive = torch.div(mmd_value_tempa, mmd_std_tempa) J_star_adp[t] = STAT_adaptive.item() optimizer_sigma0.zero_grad() STAT_adaptive.backward(retain_graph=True) # Update sigma0 using gradient descent optimizer_sigma0.step() if t % 100 == 0: print("mmd_value: ", -1 * mmd_value_tempa.item(), "mmd_std: ", mmd_std_tempa.item(), "Statistic: ", -1 * STAT_adaptive.item())
# Pool the real and fake images into one batch (real rows first) and build
# the matching integer labels from the `valid` / `fake` targets.
X = torch.cat([real_imgs, Fake_imgs], 0)
Y = torch.cat([valid, fake], 0).squeeze().long()

# ------------------------------
#  Train deep network for MMD-D
# ------------------------------
# Clear the gradients held by optimizer_F before this step.
optimizer_F.zero_grad()
# Deep features of the pooled batch.
modelu_output = featurizer(X)
# Kernel parameters. torch.sigmoid(epsilonOPT) is mathematically equal to
# exp(epsilonOPT) / (1 + exp(epsilonOPT)), but the explicit ratio overflows
# to inf / inf = NaN in float32 once epsilonOPT exceeds roughly 88.
# Squaring keeps sigma and sigma0_u non-negative.
ep = torch.sigmoid(epsilonOPT)
sigma = sigmaOPT**2
sigma0_u = sigma0OPT**2
# Compute J (STAT_u). imgs.shape[0] is passed as the size of the first
# sample -- presumably the real images, which occupy the leading rows of X
# (TODO confirm real_imgs has imgs.shape[0] rows). TEMP[0] is used as the
# MMD estimate and TEMP[1] as its variance (its square root is taken).
TEMP = MMDu(modelu_output, imgs.shape[0], X.view(X.shape[0], -1), sigma, sigma0_u, ep)
mmd_value_temp = -1 * (TEMP[0])
# The 1e-8 offset keeps the sqrt argument (and its gradient) away from zero.
mmd_std_temp = torch.sqrt(TEMP[1] + 10**(-8))
# STAT_u = -MMD / std: minimizing it increases the MMD-to-std ratio.
STAT_u = torch.div(mmd_value_temp, mmd_std_temp)
# Backpropagate and apply one optimizer_F update.
STAT_u.backward()
optimizer_F.step()

# ------------------------------------------
#  Train deep network for C2ST-S and C2ST-L
# ------------------------------------------
optimizer_D.zero_grad()
# Classification loss between the discriminator's output on X and labels Y
# (cross-entropy per the original comment; adversarial_loss is defined
# elsewhere -- TODO confirm).
loss_C = adversarial_loss(discriminator(X), Y)
# Pairwise distances between the two samples, once on Sv and once on S_FEA.
Dxy = Pdist2(Sv[:N1, :], Sv[N1:, :])
# NOTE(review): Dxy_org is not used in the visible code; it may be consumed
# further down the file -- verify before removing.
Dxy_org = Pdist2(S_FEA[:N1, :], S_FEA[N1:, :])

# Trainable kernel parameters.
# epsilonOPT starts at log(u * 1e-10) with u ~ U[0, 1), so sigmoid(epsilonOPT)
# starts at roughly u * 1e-10.
epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10**(-10), device, dtype))
epsilonOPT.requires_grad = True
# Bandwidth sigma0 initialised to the median cross-sample distance on Sv.
# NOTE(review): unlike sigma, sigma0 is optimized directly rather than via a
# square, so nothing visible here keeps it positive. Setting requires_grad
# also assumes Dxy carries no autograd history.
sigma0 = Dxy.median()
sigma0.requires_grad = True
# sigma = sigmaOPT**2 starts at 2 * 32 * 32 (presumably 32x32 inputs).
sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
sigmaOPT.requires_grad = True
optimizer_sigma0 = torch.optim.Adam([sigma0] + [sigmaOPT] + [epsilonOPT], lr=0.0002)

for t in range(opt.n_epochs):
    # torch.sigmoid equals exp(x) / (1 + exp(x)) mathematically but does not
    # overflow to inf / inf = NaN for large x.
    ep = torch.sigmoid(epsilonOPT)
    sigma = sigmaOPT**2
    TEMPa = MMDu(Sv, N1, S_FEA, sigma, sigma0, ep, is_smooth=True)
    mmd_value_tempa = -1 * (TEMPa[0] + 10**(-8))
    mmd_std_tempa = torch.sqrt(TEMPa[1] + 10**(-8))
    # Objective: -MMD / std; minimizing it increases the MMD-to-std ratio.
    STAT_adaptive = torch.div(mmd_value_tempa, mmd_std_tempa)
    optimizer_sigma0.zero_grad()
    STAT_adaptive.backward(retain_graph=True)
    optimizer_sigma0.step()
    if t % 100 == 0:
        print("mmd: ", -1 * mmd_value_tempa.item(), "mmd_std: ", mmd_std_tempa.item(), "Statistic: ", -1 * STAT_adaptive.item())

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# Compute test power of MMD-D and baselines
H_adaptive = np.zeros(N)
# Pool the real and fake images into one batch (real rows first) and build
# the matching integer labels from the `valid` / `fake` targets.
X = torch.cat([real_imgs, Fake_imgs], 0)
Y = torch.cat([valid, fake], 0).squeeze().long()

# ------------------------------
#  Train deep network for MMD-D
# ------------------------------
# Clear the gradients held by optimizer_F before this step.
optimizer_F.zero_grad()
# Deep features of the pooled batch.
modelu_output = featurizer(X)
# Kernel parameters. torch.sigmoid(epsilonOPT) is mathematically equal to
# exp(epsilonOPT) / (1 + exp(epsilonOPT)), but the explicit ratio overflows
# to inf / inf = NaN in float32 once epsilonOPT exceeds roughly 88.
# Squaring keeps sigma and sigma0_u non-negative.
ep = torch.sigmoid(epsilonOPT)
sigma = sigmaOPT**2
sigma0_u = sigma0OPT**2
# Compute J (STAT_u_F). imgs.shape[0] is passed as the size of the first
# sample -- presumably the real images, which occupy the leading rows of X
# (TODO confirm real_imgs has imgs.shape[0] rows). TEMP[0] is used as the
# MMD estimate and TEMP[1] as its variance (its square root is taken).
TEMP = MMDu(modelu_output, imgs.shape[0], X.view(X.shape[0], -1), sigma, sigma0_u, ep)
mmd_value_temp = -1 * (TEMP[0])
# The 1e-8 offset keeps the sqrt argument (and its gradient) away from zero.
mmd_std_temp = torch.sqrt(TEMP[1] + 10**(-8))
# STAT_u_F = -MMD / std: minimizing it increases the MMD-to-std ratio.
STAT_u_F = torch.div(mmd_value_temp, mmd_std_temp)
# Backpropagate and apply one optimizer_F update.
STAT_u_F.backward()
optimizer_F.step()

# ------------------------------------------------------
#  Train deep network to distinguish two sets of samples
# ------------------------------------------------------
optimizer_D.zero_grad()
# Classification loss between the discriminator's output on X and labels Y
# (cross-entropy per the original comment; adversarial_loss is defined
# elsewhere -- TODO confirm).
loss_C = adversarial_loss(discriminator(X), Y)
# Stack both samples row-wise (s1 first) and move them to a tensor.
S = np.concatenate((s1, s2), axis=0)
S = MatConvert(S, device, dtype)

# Train deep networks for G+C and D+C: the first N1 rows of S get label 0,
# the next N2 rows label 1.
y = (
    torch.cat((torch.zeros(N1, 1), torch.ones(N2, 1)), 0)
    .squeeze(1)
    .to(device, dtype)
    .long()
)
pred, STAT_C2ST, model_C2ST, w_C2ST, b_C2ST = C2ST_NN_fit(
    S, y, N1, x_in, H, x_out, learning_rate_C2ST, N_epoch, batch_size, device, dtype
)

# Train G+J: update model_u1 through optimizer_u1 by minimizing
# STAT_u1 = -MMD / std computed on model_u1's features of S.
np.random.seed(seed=1102)
torch.manual_seed(1102)
torch.cuda.manual_seed(1102)
for t in range(N_epoch):
    modelu1_output = model_u1(S)
    TEMP1 = MMDu(modelu1_output, N1, S, sigma, sigma0_u, is_smooth=False)
    mmd_value_temp = -(TEMP1[0] + 10**(-8))
    mmd_std_temp = torch.sqrt(TEMP1[1] + 10**(-8))
    # Diagnostic only: a zero or NaN std is reported, training continues.
    if mmd_std_temp.item() == 0 or np.isnan(mmd_std_temp.item()):
        print('error!!')
    STAT_u1 = torch.div(mmd_value_temp, mmd_std_temp)
    # Keep the per-epoch value of the objective.
    J_star_u[t] = STAT_u1.item()
    optimizer_u1.zero_grad()
    # retain_graph=True is kept: sigma and sigma0_u are created outside this
    # loop, so if they carry autograd history that sub-graph is shared
    # across iterations.
    STAT_u1.backward(retain_graph=True)
    optimizer_u1.step()
    if t % 100 == 0:
        print("mmd: ", -mmd_value_temp.item(), "mmd_std: ", mmd_std_temp.item(), "Statistic: ", -STAT_u1.item())
# NOTE(review): this source line appears to have been collapsed from a
# multi-line block; as written, Python treats everything after the first "#"
# as a comment. Restore the original line breaks before relying on it.
# Purpose: set both sample sizes N1 and N2 to Num_clusters * n, seed the RNGs
# with 1102, then for each of N_epoch iterations compute
# STAT_u = -MMD / std on model_u's features of S, record it in J_star_u[t],
# and take one optimizer_u step (presumably over model_u's parameters and
# epsilonOPT / sigmaOPT / sigma0OPT -- the optimizer's construction is not
# visible here).
# ep = exp(x) / (1 + exp(x)) is the logistic sigmoid of epsilonOPT; in float32
# it evaluates to inf / inf = NaN once epsilonOPT exceeds roughly 88, so
# torch.sigmoid(epsilonOPT) would be the numerically safer form.
# The 'error!!' prints only report a zero or NaN std; training continues.
# The trailing "if t % 100 == 0:" continues past this line; its body (which,
# per the inline comment, prints MMD, std of MMD and J) is not visible here.
N1 = Num_clusters * n N2 = Num_clusters * n # Train deep kernel to maximize test power np.random.seed(seed=1102) torch.manual_seed(1102) torch.cuda.manual_seed(1102) for t in range(N_epoch): # Compute epsilon, sigma and sigma_0 ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT)) sigma = sigmaOPT**2 sigma0_u = sigma0OPT**2 # Compute output of the deep network modelu_output = model_u(S) # Compute J (STAT_u) TEMP = MMDu(modelu_output, N1, S, sigma, sigma0_u, ep) mmd_value_temp = -1 * (TEMP[0] + 10**(-8)) mmd_std_temp = torch.sqrt(TEMP[1] + 10**(-8)) if mmd_std_temp.item() == 0: print('error!!') if np.isnan(mmd_std_temp.item()): print('error!!') STAT_u = torch.div(mmd_value_temp, mmd_std_temp) J_star_u[t] = STAT_u.item() # Initialize optimizer and Compute gradient optimizer_u.zero_grad() STAT_u.backward(retain_graph=True) # Update weights using gradient descent optimizer_u.step() # Print MMD, std of MMD and J if t % 100 == 0:
# Convert the fake images to the active tensor type, then pool real and fake
# images into one batch (real rows first) with matching integer labels.
Fake_imgs = Variable(Fake_imgs.type(Tensor))
X = torch.cat([real_imgs, Fake_imgs], 0)
Y = torch.cat([valid, fake], 0).squeeze().long()

# -----------
#  Train G+J
# -----------
optimizer_F.zero_grad()
# Deep features of the pooled batch.
modelu_output = featurizer(X)
# Kernel parameters. torch.sigmoid(epsilonOPT) is mathematically equal to
# exp(epsilonOPT) / (1 + exp(epsilonOPT)), but the explicit ratio overflows
# to inf / inf = NaN in float32 once epsilonOPT exceeds roughly 88.
# Squaring keeps sigma and sigma0_u non-negative.
ep = torch.sigmoid(epsilonOPT)
sigma = sigmaOPT**2
sigma0_u = sigma0OPT**2
# imgs.shape[0] is passed as the size of the first sample -- presumably the
# real images at the top of X (TODO confirm). TEMP[0] is used as the MMD
# estimate and TEMP[1] as its variance (its square root is taken).
TEMP = MMDu(modelu_output, imgs.shape[0], X.view(X.shape[0], -1), sigma, sigma0_u, ep, is_smooth=False)
mmd_value_temp = -1 * (TEMP[0])
mmd_std_temp = torch.sqrt(TEMP[1] + 10**(-8))
# Diagnostics only; training continues either way.
# NOTE(review): because of the 1e-8 offset the std is zero only if TEMP[1]
# is exactly -1e-8, so the first check is practically unreachable; the
# second message says "mmd" but fires when the *std* is NaN.
if mmd_std_temp.item() == 0:
    print('error std!!')
if np.isnan(mmd_std_temp.item()):
    print('error mmd!!')
# STAT_u = -MMD / std: minimizing it increases the MMD-to-std ratio.
STAT_u = torch.div(mmd_value_temp, mmd_std_temp)
STAT_u.backward()
optimizer_F.step()

# -----------
#  Train L+J