# ----------#
# Training loop for the encoder/decoder pair: one optimizer step per batch,
# with a progress line printed after every step and both networks
# checkpointed every `args.sample_interval` epochs.
# NOTE(review): this loop had been collapsed onto a single line; the nesting
# below is reconstructed from data flow (the print uses `batch_idx`, the
# checkpoint only depends on `epoch`) - confirm against the original script.
for epoch in range(1, args.n_epochs + 1):
    encoder.train()
    decoder.train()
    train_loss = 0  # running sum of the per-batch total loss for this epoch
    for batch_idx, (data, ) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        # The encoder yields the latent sample plus the mu/logvar pair that
        # loss_function needs alongside the reconstruction.
        z, mu, logvar = encoder(data)
        recon_batch = decoder(z)
        loss, rec_loss, kl_loss = loss_function(recon_batch, data, mu, logvar)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        # Each loss is divided by the batch length before being reported.
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} recon_loss:{:.6f} kl_loss:{:.6f}'
            .format(epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data),
                    rec_loss.item() / len(data),
                    kl_loss.item() / len(data)))
        batches_done = epoch * len(train_loader) + batch_idx

    # Save once per qualifying epoch, after the last batch has been applied.
    if (epoch) % args.sample_interval == 0:
        torch.save(
            decoder.state_dict(),
            output_dir + f'/decoder_64_VAE_{args.beta_vae}_epoch{epoch}.pth')
        torch.save(
            encoder.state_dict(),
            output_dir + f'/encoder_64_VAE_{args.beta_vae}_epoch{epoch}.pth')
# Fragment of a training step; the enclosing loop headers are outside this
# view, and the statements below are collapsed onto one line, so their
# original nesting cannot be determined from here.
# What the statements do, in order:
#   * move `x16_data` to `device` and clear the optimizer's gradients;
#   * encode `x64_data` with `encoder` to get `z`, `mu`, `logvar`;
#   * call `X_l_encoder(encoder_model, x16_data)` and keep only its second
#     return value as `z16`, then pass it through `upsample(z16, 16)`;
#   * concatenate `z16` and `z` along dim 1 and decode the result;
#   * compute the loss against `x64_data`, backpropagate, add the scalar loss
#     to `train_loss`, and take an optimizer step;
#   * print progress with each loss divided by `len(x64_data)`;
#   * when `epoch % args.sample_interval == 0`, save the decoder and encoder
#     state dicts under `output_dir`.
# NOTE(review): `X_l_encoder` is presumably a lower-resolution encoder whose
# latent conditions the decoder - confirm against its definition.
x16_data = x16_data.to(device) optimizer.zero_grad() z, mu, logvar = encoder(x64_data) _, z16, _ = X_l_encoder(encoder_model, x16_data) z16 = upsample(z16, 16) z_h = torch.cat((z16, z), 1) recon_batch = decoder(z_h) loss, rec_loss, kl_loss = loss_function(recon_batch, x64_data, mu, logvar) loss.backward() train_loss += loss.item() optimizer.step() print( 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} recon_loss:{:.6f} kl_loss:{:.6f}' .format(epoch, batch_idx * len(x64_data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item() / len(x64_data), rec_loss.item() / len(x64_data), kl_loss.item() / len(x64_data))) batches_done = epoch * len(train_loader) + batch_idx if (epoch) % args.sample_interval == 0: torch.save( decoder.state_dict(), output_dir + f'/decoder_16_64_VAE_0.5_{args.beta_vae}_epoch{epoch}.pth') torch.save( encoder.state_dict(), output_dir + f'/encoder_16_64_VAE_0.5_{args.beta_vae}_epoch{epoch}.pth')
# Fragment of a training step; the enclosing loop headers are outside this
# view, and the statements below are collapsed onto one line, so their
# original nesting cannot be determined from here.
# What the statements do, in order:
#   * clear the optimizer's gradients;
#   * encode `x64_data` with `encoder` to get `z`, `mu`, `logvar`;
#   * call `X_l_encoder(encoder_model, x16_data)` and keep only its second
#     return value as `z16`, then pass it through `upsample(z16, 16)`;
#   * concatenate `z16` and `z` along dim 1 and decode the result;
#   * compute the loss against `x64_data`, backpropagate, add the scalar loss
#     to `train_loss`, and take an optimizer step;
#   * print progress with each loss divided by `len(x64_data)`;
#   * when `epoch % args.sample_interval == 0`, call
#     `test(encoder_model, epoch, x64_test, x16_test)` and save the decoder
#     and encoder state dicts under `model_dir`.
# NOTE(review): `batches_done` is assigned but not read anywhere in this
# fragment - check whether later code uses it.
optimizer.zero_grad() z, mu, logvar = encoder(x64_data) _, z16, _ = X_l_encoder(encoder_model, x16_data) z16 = upsample(z16, 16) z_h = torch.cat((z16, z), 1) recon_batch = decoder(z_h) loss, rec_loss, kl_loss = loss_function(recon_batch, x64_data, mu, logvar) loss.backward() train_loss += loss.item() optimizer.step() print( 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} recon_loss:{:.6f} kl_loss:{:.6f}' .format(epoch, batch_idx * len(x64_data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item() / len(x64_data), rec_loss.item() / len(x64_data), kl_loss.item() / len(x64_data))) batches_done = epoch * len(train_loader) + batch_idx if (epoch) % args.sample_interval == 0: test(encoder_model, epoch, x64_test, x16_test) torch.save( decoder.state_dict(), model_dir + f'/decoder_16_64_VAE_1_{args.beta_vae}_epoch{epoch}.pth') torch.save( encoder.state_dict(), model_dir + f'/encoder_16_64_VAE_1_{args.beta_vae}_epoch{epoch}.pth')
# This line holds two collapsed pieces; their original nesting cannot be
# determined from here.
# 1. Tail of a loss function whose header and `BCE` computation are above
#    this view: `logvar` is reshaped to (-1, 16); `KLD` is
#    -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) summed over dim 1 and then
#    over dim 0; the function returns the tuple
#    (BCE + args.beta_vae * KLD, BCE, KLD).
# 2. A training loop over epochs 1..args.n_epochs: for each batch it encodes
#    `data`, decodes `z`, computes the loss, backpropagates, steps the
#    optimizer, and prints each loss divided by `len(data)`. When
#    `epoch % args.sample_interval == 0` it calls `test(epoch, x_test)` and
#    saves the decoder and encoder state dicts under `model_dir`.
# NOTE(review): the hard-coded 16 in the reshape is presumably the latent
# width - confirm against the encoder definition.
logvar=logvar.reshape(-1,16) KLD = torch.sum(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim = 1), dim = 0) return BCE + args.beta_vae*KLD, BCE , KLD # ----------# # Training # # ----------# for epoch in range(1,args.n_epochs+1): encoder.train() decoder.train() train_loss = 0 for batch_idx, (data, ) in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() z, mu, logvar = encoder(data) recon_batch = decoder(z) loss,rec_loss, kl_loss= loss_function(recon_batch, data, mu, logvar) loss.backward() train_loss += loss.item() optimizer.step() print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} recon_loss:{:.6f} kl_loss:{:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item() / len(data),rec_loss.item() / len(data),kl_loss.item() / len(data))) batches_done = epoch * len(train_loader) + batch_idx if (epoch) % args.sample_interval == 0: test(epoch,x_test) torch.save(decoder.state_dict(), model_dir + f'/decoder_16_VAE_{args.beta_vae}_epoch{epoch}.pth') torch.save(encoder.state_dict(), model_dir + f'/encoder_16_VAE_{args.beta_vae}_epoch{epoch}.pth')