def train(batch=128, epochs=10):
    """Train a class-conditional PriorGrad diffusion model on MNIST, then
    sample a 10x10 grid (one column per digit class 0-9) and save the
    generated images plus the learned conditional prior mean/std maps.

    Parameters
    ----------
    batch : int
        Mini-batch size for the MNIST train loader.
    epochs : int
        Number of passes over the training set.
    """
    denoise_fn = Cond_Unet(dim=16, channels=1, dim_mults=(1, 2, 4)).cuda()
    cond_fn = cond_Net(in_dim=28).cuda()
    # Linear beta schedule: 200 diffusion steps from 1e-4 to 5e-2.
    betas = np.linspace(1e-4, 5e-2, 200)
    model = PriorGrad(denoise_fn, betas, cond_fn, loss_type='l1').cuda()
    train_loader, _, _ = get_mnist_loaders(batch_size=batch)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        with tqdm(train_loader) as it:
            for x, label in it:
                optim.zero_grad()
                # Move the labels once and reuse; the original called
                # .to(device) for every use.
                label = label.to(device)
                # get_cond_par sets the condition-dependent prior
                # (model.mu / model.logstd) for this batch before the loss.
                model.get_cond_par(label)
                loss = model(x.to(device), cond=label)
                loss.backward()
                it.set_postfix(
                    ordered_dict={
                        "epoch": epoch,
                        'loss': loss.item()
                    }
                )
                optim.step()

    # Generation: 100 samples laid out as a 10x10 grid, labels cycle 0-9
    # so each column is one digit class.
    shape = (100, 1, 28, 28)
    label = torch.tensor(list(range(10)) * 10).long().to(device)
    model.eval()
    # Fix: sampling is pure inference — disable autograd so the reverse
    # process does not build a graph and waste memory.
    with torch.no_grad():
        model.get_cond_par(label)
        img1 = model.mu                 # conditional prior mean (per class)
        img2 = torch.exp(model.logstd)  # prior std; logstd stored in log-space
        img = model.sample(shape, cond=label)
    save_image(img, '../figures/priorgrad.jpg', nrow=10)
    save_image(img1, '../figures/priorgrad_mu.jpg', nrow=10)
    save_image(img2, '../figures/priorgrad_std.jpg', nrow=10)
[vpsde_beta_t(t, timesteps, 0.1, 20) for t in range(1, timesteps + 1)]) print(betas) denoise_fn = c_Gen(dim=16, dim_mults=(1, 2, 4), channels=1, latent_dim=z_dim).to(device) dis = c_Dis(dim=16, channels=2, dim_mults=(1, 2, 4)).to(device) gan = diffusion_gan(dis, denoise_fn, betas=betas, z_dim=z_dim, device=device) gen_optim = torch.optim.Adam(denoise_fn.parameters(), lr=2 * 1e-4, betas=(0.5, 0.9)) dis_optim = torch.optim.Adam(dis.parameters(), lr=1e-4, betas=(0.5, 0.9)) sdlG = torch.optim.lr_scheduler.ExponentialLR(gen_optim, gamma) sdlD = torch.optim.lr_scheduler.ExponentialLR(dis_optim, gamma) # load data train_loader, test_loader, val_loader = get_mnist_loaders( batch_size=batch, num=0, test_batch_size=batch) step = 0 gloss, dloss = 0, 0 for i in range(epochs): with tqdm(train_loader) as it: for data, label in it: step += 1 z = torch.randn(batch, z_dim).to(device) dloss = gan.train_dis_step(data.to(device), dis_optim, z, grad_penal=True, label=label.to(device)) z = torch.randn(batch, z_dim).to(device) gloss = gan.train_gen_step(data.to(device),
# num += 1 epochs = 5 md = logistic_Unet(dim=16, channels=1, dim_mults=(1, 2, 4)).cuda() betas = np.linspace(1e-4, 1e-2, 1000) diffusion = categorical_diffusion(betas, transition_bands=None, method='gaussian', num_bits=8, loss_type='hybrid', hybrid_coeff=0.001, model_prediction='x_start', model_output='logistic_pars').cuda() optim = torch.optim.Adam(md.parameters(), lr=1e-3) train_loader, test_loader, val_loader = get_mnist_loaders() num = 0 for epoch in range(epochs): with tqdm(train_loader) as it: for x, label in it: num += 1 optim.zero_grad() x *= 255 x = x.long().cuda() loss = torch.sum(diffusion.training_losses(md, x)) loss.backward() optim.step() it.set_postfix(ordered_dict={ 'train_loss': loss.item(), 'epoch': epoch
samples, n = sampling_fn(model) samples = samples.detach().cpu().numpy() plot_fig(samples) # plt.show() return sde train1 = False #True train2 = False #True load_res = False schedule_epochs = 10 schedule_path = father_path + '/sc_vpsde.pkl' path1 = father_path + '/vpsde.pkl' beta_path = father_path + '/beta_candi.txt' config = config_mnist() train_loader, _, _ = get_mnist_loaders() model = cont_Unet(dim=16, channels=1, dim_mults=(1, 2, 4)).cuda() s_phi = sigma_phi().cuda() sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) if train1: train(config, model, sde, train_loader) torch.save(model.state_dict(), path1) else: model.load_state_dict(torch.load(path1)) schedule = bddm_schedule(s_phi, denoise_fn=model, T=1000, tao=200,