def __init__(self, dataType='BSD', batch_size=10, args=None):
    """Load an image dataset as paired low/mid/high-resolution tensors.

    Supported dataTypes:
      - 'BSD':    reads every image under ./dataset/ with matplotlib,
                  resizes to 128/64/32 px with cv2 and normalizes to [-1, 1].
      - 'CIFAR':  defers to utils.dataset_iterator / inf_train_gen.
      - 'PASCAL': reads the VOC2012 trainval split the same way as BSD.

    Parameters:
        dataType (str): which dataset to load ('BSD', 'CIFAR', 'PASCAL').
        batch_size (int): stored as self.batchSize for later batching.
        args: forwarded to utils.dataset_iterator for the CIFAR branch.
    """
    self.dataType = dataType
    if dataType == 'BSD':
        self.dataPath = './dataset/'
        self.imgList = os.listdir(self.dataPath)
        self.batchSize = batch_size
        self.len = len(self.imgList)
        # BUG FIX: these tensors were hard-coded to 300 entries while the
        # loop below fills self.len of them — any directory with more than
        # 300 images raised an IndexError.  Size them to self.len, matching
        # the PASCAL branch.
        self.loimgs = torch.zeros((self.len, 3, 32, 32))
        self.midImgs = torch.zeros((self.len, 3, 64, 64))
        self.HDImgs = torch.zeros((self.len, 3, 128, 128))
        self.iter = 0
        preprocess = torchTrans.Compose([
            torchTrans.ToTensor(),
            torchTrans.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        for i in range(self.len):
            # [:, :, 0:-1] drops the last channel — presumably an alpha
            # channel from PNG input; TODO confirm the files are RGBA.
            imgH = cv2.resize(
                mpimg.imread(self.dataPath + self.imgList[i + self.iter]),
                (128, 128))[:, :, 0:-1]
            imgM = cv2.resize(imgH, (64, 64))
            img = cv2.resize(imgM, (32, 32))
            imgH = preprocess(imgH)
            imgM = preprocess(imgM)
            img = preprocess(img)
            self.loimgs[i, :, :, :] = img
            self.midImgs[i, :, :, :] = imgM
            self.HDImgs[i, :, :, :] = imgH
    elif dataType == 'CIFAR':
        # CIFAR is streamed from the project's iterator rather than
        # pre-loaded into tensors.
        train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
        self.batchSize = batch_size
        self.gen = utils.inf_train_gen(train_gen)
    elif dataType == 'PASCAL':
        self.dataPath = './VOCdevkit/VOC2012/'
        self.imgList = []
        for line in open(self.dataPath + 'ImageSets/Main/trainval.txt'):
            # strip the trailing newline from each image id
            self.imgList.append(line[0:-1])
        self.batchSize = batch_size
        self.len = len(self.imgList)
        self.loimgs = torch.zeros((self.len, 3, 32, 32))
        self.midImgs = torch.zeros((self.len, 3, 64, 64))
        self.HDImgs = torch.zeros((self.len, 3, 128, 128))
        self.iter = 0
        preprocess = torchTrans.Compose([
            torchTrans.ToTensor(),
            torchTrans.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        for i in range(self.len):
            imgH = cv2.resize(
                mpimg.imread(self.dataPath + 'JPEGImages/' +
                             self.imgList[i + self.iter] + '.jpg'),
                (128, 128))
            imgM = cv2.resize(imgH, (64, 64))
            img = cv2.resize(imgH, (32, 32))
            imgH = preprocess(imgH)
            imgM = preprocess(imgM)
            img = preprocess(img)
            self.loimgs[i, :, :, :] = img
            self.midImgs[i, :, :, :] = imgM
            self.HDImgs[i, :, :, :] = imgH
def train():
    """Pre-train SRResNet with MSE, then fine-tune it as a WGAN-GP generator.

    Phase 1 (20k iterations): minimize pixel MSE between netG's output and
    the real data (the perceptual/VGG loss is left commented out).
    Phase 2 (args.epochs): autoencoder update (netE+netG), 5 critic steps
    (netD with gradient penalty), then a generator step driven by real
    images instead of noise.

    BUG FIXES vs. original:
      - optimizerD was commented out and optimizerE / ae_criterion were
        never defined, yet all three are used in phase 2 (NameError).
      - netE's parameters were frozen in phase 1 and never thawed, so
        optimizerE could not have trained the encoder; grads are re-enabled
        at the start of phase 2.
    """
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)
    # BUG FIX: these three were missing but are required by phase 2 below.
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    vgg_scale = 0.0784  # 1/12.75
    mse_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda(0)
    mone = (one * -1).cuda(0)
    gen = utils.inf_train_gen(train_gen)

    """ train SRResNet with MSE """
    for iteration in range(1, 20000):
        start_time = time.time()
        # for p in netD.parameters():
        #     p.requires_grad = False
        for p in netE.parameters():
            p.requires_grad = False
        netG.zero_grad()
        _data = next(gen)
        real_data, vgg_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        # Perceptual loss (disabled)
        # vgg_data_v = autograd.Variable(vgg_data)
        # vgg_features_real = netE(vgg_data_v)
        fake = netG(real_data_v)
        # vgg_features_fake = netE(fake)
        # diff = vgg_features_fake - vgg_features_real.cuda(0)
        # perceptual_loss = vgg_scale * ((diff.pow(2)).sum(3).mean())  # mean(sum(square(diff)))
        # perceptual_loss.backward(one)
        mse_loss = mse_criterion(fake, real_data_v)
        mse_loss.backward(one)
        optimizerG.step()
        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/mse cost SRResNet',
                  np.round(mse_loss.data.cpu().numpy(), 4))
        if iteration % 50 == 49:
            utils.generate_sr_image(iteration, netG, save_dir, args, real_data_v)
        if (iteration < 5) or (iteration % 50 == 49):
            plot.flush()
        plot.tick()
        if iteration % 5000 == 0:
            torch.save(netG.state_dict(), './SRResNet_PL.pt')

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        for p in netD.parameters():
            p.requires_grad = False
        # BUG FIX: re-enable encoder grads (frozen during MSE pre-training)
        # so the autoencoder update can actually train netE.
        for p in netE.parameters():
            p.requires_grad = True
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        # NOTE(review): phase 1 unpacked two values from stack_data but this
        # phase keeps the whole return — confirm stack_data's return shape.
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise, volatile=True)  # totally freeze netG
            # instead of noise, use image
            fake = autograd.Variable(netG(real_data_v).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)
            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()
        # Update generator network (GAN)
        # noise = torch.randn(args.batch_size, args.dim).cuda()
        # noisev = autograd.Variable(noise)
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        # again use real data instead of noise
        fake = netG(real_data_v)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()
        # Write logs and save samples
        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        # plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(), 4))
        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))
            # utils.generate_image(iteration, netG, save_dir, args)
            # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
            utils.generate_sr_image(iteration, netG, save_dir, args, real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
def train():
    """Jointly train an autoencoder (netE+netG) and a WGAN-GP critic (netD).

    Each epoch: one AE reconstruction step, 5 critic steps with gradient
    penalty, then one generator step from Gaussian noise.  Logs costs via
    plot.* and renders AE reconstructions every 100 iterations.

    Cleanup vs. original: a torchvision ToTensor/Normalize pipeline was
    built here and never used — removed (it also dragged in torchvision
    for nothing).
    """
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    np.set_printoptions(precision=4)
    netG, netD, netE = load_models(args)
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    gen = utils.inf_train_gen(train_gen)
    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise, volatile=True)  # totally freeze netG
            fake = autograd.Variable(netG(noisev).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)
            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()
        # Update generator network (GAN)
        noise = torch.randn(args.batch_size, args.dim).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()
        # Write logs and save samples
        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(), 4))
        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))
            # utils.generate_image(iteration, netG, save_dir, args)
            utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
def train(args):
    """Train a hypernetwork (netE encoder + W1..W5 weight generators) on
    CIFAR with a latent-code discriminator (DiscriminatorZ).

    Flow: optionally resume checkpoints, optionally pre-train the encoder
    to match a target Gaussian (moment-matching loss), then alternate a
    Z-adversary step (netD) with a classification step through the
    generated weights (train_clf).

    BUG FIXES vs. original resume branch:
      - 'optimrW1' was a typo for optimW1 (NameError on resume).
      - the W4 and W5 restores both unpacked into (W4, optimW3), which
        clobbered W4 with W5's weights and never restored optimW4/optimW5.
    """
    torch.manual_seed(8734)
    netE = Encoder(args).cuda()
    W1 = GeneratorW1(args).cuda()
    W2 = GeneratorW2(args).cuda()
    W3 = GeneratorW3(args).cuda()
    W4 = GeneratorW4(args).cuda()
    W5 = GeneratorW5(args).cuda()
    netD = DiscriminatorZ(args).cuda()
    print (netE, W1, W2, W3, W4, W5, netD)
    optimE = optim.Adam(netE.parameters(), lr=0.005, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)
    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    if args.resume:
        netE, optimE, stats = load_model(args, netE, optimE, 'single_code1')
        # BUG FIX: was 'optimrW1' (undefined name).
        W1, optimW1, stats = load_model(args, W1, optimW1, 'single_code1')
        W2, optimW2, stats = load_model(args, W2, optimW2, 'single_code1')
        W3, optimW3, stats = load_model(args, W3, optimW3, 'single_code1')
        # BUG FIX: both lines below unpacked into (W4, optimW3).
        W4, optimW4, stats = load_model(args, W4, optimW4, 'single_code1')
        W5, optimW5, stats = load_model(args, W5, optimW5, 'single_code1')
        netD, optimD, stats = load_model(args, netD, optimD, 'single_code1')
        best_test_acc, best_test_loss = stats
        print ('==> resuming models at ', stats)
    cifar_train, cifar_test = load_cifar(args)
    if args.use_x:
        base_gen = datagen.load(args)
        w1_gen = utils.inf_train_gen(base_gen[0])
        w2_gen = utils.inf_train_gen(base_gen[1])
        w3_gen = utils.inf_train_gen(base_gen[2])
        w4_gen = utils.inf_train_gen(base_gen[3])
        w5_gen = utils.inf_train_gen(base_gen[4])
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    if args.use_x:
        X = sample_x(args, [w1_gen, w2_gen, w3_gen, w4_gen, w5_gen], 0)
        # small epsilon keeps downstream logs/divisions away from exact zero
        X = list(map(lambda x: (x+1e-10).float(), X))
    print ("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.load_e:
        netE, optimE, _ = utils.load_model(args, netE, optimE, 'Encoder_cifar.pt')
        print ('==> loading pretrained encoder')
    if args.pretrain_e:
        # Moment-match each encoder code against a target z distribution.
        for j in range(200):
            #x = sample_x(args, [w1_gen, w2_gen, w3_gen, w4_gen, w5_gen], 0)
            x = sample_z_like((e_batch_size, args.ze))
            z = sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                # retain_graph: several codes backprop through the same netE pass
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print ('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print ('Finished Pretraining Encoder')
                break
        utils.save_model(args, netE, optimE)
    print ('==> Begin Training')
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):
            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            #netE.zero_grad(); W1.zero_grad(); W2.zero_grad()
            #W3.zero_grad(); W4.zero_grad(); W5.zero_grad()
            z = sample_z_like((args.batch_size, args.ze,))
            codes = netE(z)
            l1 = W1(codes[0]).mean(0)
            l2 = W2(codes[1]).mean(0)
            l3 = W3(codes[2]).mean(0)
            l4 = W4(codes[3]).mean(0)
            l5 = W5(codes[4]).mean(0)
            # Z Adversary: push encoder codes toward the sampled prior.
            free_params([netD])
            frozen_params([netE, W1, W2, W3, W4, W5])
            for code in codes:
                noise = sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1-d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            # Generator (Mean test)
            frozen_params([netD])
            free_params([netE, W1, W2, W3, W4, W5])
            d_costs = []
            for code in codes:
                d_costs.append(netD(code))
            d_loss = torch.cat(d_costs).mean()
            correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, target,
                                      val=True)
            scaled_loss = (args.beta*loss) + d_loss
            scaled_loss.backward()
            optimE.step(); optimW1.step(); optimW2.step()
            optimW3.step(); optimW4.step(); optimW5.step()
            loss = loss.item()
            """ Update Statistics """
            if batch_idx % 50 == 0:
                acc = (correct / 1)
                norm_z1 = np.linalg.norm(l1.data)
                norm_z2 = np.linalg.norm(l2.data)
                norm_z3 = np.linalg.norm(l3.data)
                norm_z4 = np.linalg.norm(l4.data)
                norm_z5 = np.linalg.norm(l5.data)
                print ('**************************************')
                print ('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print ('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, d_loss))
                print ('Filter norm: ', norm_z1)
                print ('Filter norm: ', norm_z2)
                print ('Filter norm: ', norm_z3)
                print ('Linear norm: ', norm_z4)
                print ('Linear norm: ', norm_z5)
                print ('best test loss: {}'.format(args.best_loss))
                print ('best test acc: {}'.format(args.best_acc))
                print ('**************************************')
            if batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = sample_z_like((args.batch_size, args.ze,))
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code).mean(0)
                    l2 = W2(w2_code).mean(0)
                    l3 = W3(w3_code).mean(0)
                    l4 = W4(w4_code).mean(0)
                    l5 = W5(w5_code).mean(0)
                    min_loss_batch = 10.
                    z_test = [l1, l2, l3, l4, l5]
                    correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, y,
                                              val=True)
                    if loss.item() < min_loss_batch:
                        min_loss_batch = loss.item()
                        z_test = [l1, l2, l3, l4, l5]
                    test_acc += correct.item()
                    test_loss += loss.item()
                #y_acc, y_loss = utils.test_samples(args, z_test, train=True)
                test_loss /= len(cifar_test.dataset)
                test_acc /= len(cifar_test.dataset)
                print ('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                # print ('FC Accuracy: {}, FC Loss: {}'.format(y_acc, y_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print ('==> new best stats, saving')
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc
train_sampler = torch.utils.data.distributed.DistributedSampler( dataset) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=train_sampler) else: dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) dataloader = inf_train_gen(dataloader) #models generator = Generator().to(device) discriminator = Discriminator().to(device) vgg = Vgg16(requires_grad=False).to(device) if args.pre_train: if args.distributed: g_checkpoint = torch.load( args.checkpoint_path + 'generator_checkpoint_{}.ckpt'.format(args.last_iter), map_location=lambda storage, loc: storage.cuda(args.local_rank )) d_checkpoint = torch.load( args.checkpoint_path +
def train(args):
    """Train the tied-code hypernetwork: one shared encoder code drives all
    five weight generators (W1..W5); each generated weight set is scored by
    train_clf on CIFAR and the scaled loss updates everything.

    BUG FIXES vs. original resume branch:
      - the first three load_model calls unpacked only two values while the
        rest unpack three (and 'stats' is read afterwards) — ValueError at
        runtime; all calls now unpack (model, optimizer, stats).
      - the W4 and W5 restores both unpacked into (W4, optimizerW3),
        clobbering W4 with W5's weights and never restoring
        optimizerW4/optimizerW5.
      - log typo 'resumeing' corrected.
    """
    torch.manual_seed(8734)
    netE = Encoder(args).cuda()
    W1 = GeneratorW1(args).cuda()
    W2 = GeneratorW2(args).cuda()
    W3 = GeneratorW3(args).cuda()
    W4 = GeneratorW4(args).cuda()
    W5 = GeneratorW5(args).cuda()
    print (netE, W1, W2, W3, W4, W5)
    optimizerE = optim.Adam(netE.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW1 = optim.Adam(W1.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW2 = optim.Adam(W2.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW3 = optim.Adam(W3.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW4 = optim.Adam(W4.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW5 = optim.Adam(W5.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    if args.resume:
        # NOTE(review): m0..m5 are not defined in this view — confirm they
        # are module-level checkpoint handles.
        netE, optimizerE, stats = load_model(args, netE, optimizerE, 'single_code1', m0)
        W1, optimizerW1, stats = load_model(args, W1, optimizerW1, 'single_code1', m1)
        W2, optimizerW2, stats = load_model(args, W2, optimizerW2, 'single_code1', m2)
        W3, optimizerW3, stats = load_model(args, W3, optimizerW3, 'single_code1', m3)
        W4, optimizerW4, stats = load_model(args, W4, optimizerW4, 'single_code1', m4)
        W5, optimizerW5, stats = load_model(args, W5, optimizerW5, 'single_code1', m5)
        best_test_acc, best_test_loss = stats
        print ('==> resuming models at ', stats)
    cifar_train, cifar_test = load_cifar()
    if args.use_x:
        base_gen = datagen.load(args)
        w1_gen = utils.inf_train_gen(base_gen[0])
        w2_gen = utils.inf_train_gen(base_gen[1])
        w3_gen = utils.inf_train_gen(base_gen[2])
        w4_gen = utils.inf_train_gen(base_gen[3])
        w5_gen = utils.inf_train_gen(base_gen[4])
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    if args.use_x:
        X = sample_x(args, [w1_gen, w2_gen, w3_gen, w4_gen, w5_gen], 0)
        X = list(map(lambda x: (x+1e-10).float(), X))
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):
            netE.zero_grad()
            W1.zero_grad()
            W2.zero_grad()
            W3.zero_grad()
            W4.zero_grad()
            W5.zero_grad()
            """
            if batch_idx % 50 == 0:
                if args.use_x:
                    acc, loss = 0., 0.
                    for i, (x, y) in enumerate(cifar_test):
                        correct, l = train_clf(args, X, x, y, val=True)
                        acc += correct.item()
                        loss += l.item()
                    print ("Functional Net: ", acc/len(cifar_test.dataset),
                           loss/len(cifar_test.dataset))
            """
            z = sample_z_like((args.batch_size, args.ze,))
            code = netE(z)
            l1 = W1(code)
            l2 = W2(code)
            l3 = W3(code)
            l4 = W4(code)
            l5 = W5(code)  #.contiguous().view(args.batch_size, -1))
            # One classifier evaluation per generated weight sample.
            for (z1, z2, z3, z4, z5) in zip(l1, l2, l3, l4, l5):
                correct, loss = train_clf(args, [z1, z2, z3, z4, z5],
                                          data, target, val=True)
                scaled_loss = (1000*loss)  #+ z1_loss + z2_loss + z3_loss
                # retain_graph: every sample backprops through the same netE pass
                scaled_loss.backward(retain_graph=True)
                optimizerE.step()
                optimizerW1.step()
                optimizerW2.step()
                optimizerW3.step()
                optimizerW4.step()
                optimizerW5.step()
                loss = loss.item()
            if batch_idx % 50 == 0:
                acc = (correct / 1)
                norm_z1 = np.linalg.norm(z1.data)
                norm_z2 = np.linalg.norm(z2.data)
                norm_z3 = np.linalg.norm(z3.data)
                norm_z4 = np.linalg.norm(z4.data)
                norm_z5 = np.linalg.norm(z5.data)
                print ('**************************************')
                print ('100 tied test')
                print ('Acc: {}, Loss: {}'.format(acc, loss))
                print ('Filter norm: ', norm_z1)
                print ('Filter norm: ', norm_z2)
                print ('Linear norm: ', norm_z3)
                print ('Linear norm: ', norm_z4)
                print ('Linear norm: ', norm_z5)
                print ('best test loss: {}'.format(args.best_loss))
                print ('best test acc: {}'.format(args.best_acc))
                print ('**************************************')
            if batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = sample_z_like((args.batch_size, args.ze,))
                    code = netE(z)
                    l1 = W1(code)
                    l2 = W2(code)
                    l3 = W3(code)
                    l4 = W4(code)
                    l5 = W5(code)
                    min_loss_batch = 10.
                    z_test = [l1[0], l2[0], l3[0], l4[0], l5[0]]
                    for (z1, z2, z3, z4, z5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [z1, z2, z3, z4, z5],
                                                  data, y, val=True)
                        if loss.item() < min_loss_batch:
                            min_loss_batch = loss.item()
                            z_test = [z1, z2, z3, z4, z5]
                        test_acc += correct.item()
                        test_loss += loss.item()
                #y_acc, y_loss = utils.test_samples(args, z_test, train=True)
                # divide by dataset size times the 32 weight samples per batch
                test_loss /= len(cifar_test.dataset) * 32
                test_acc /= len(cifar_test.dataset) * 32
                print ('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                # print ('FC Accuracy: {}, FC Loss: {}'.format(y_acc, y_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print ('==> new best stats, saving')
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc
def train():
    """Train an autoencoder (netE+netG) jointly with a WGAN-GP critic (netD),
    with exponential LR decay and periodic model checkpointing.

    Per epoch: one AE reconstruction step, 5 critic steps with gradient
    penalty, one generator step from noise, then all three schedulers step.
    Logs via plot.*; saves G and D state every 100 iterations.
    """
    args = load_args()
    # NOTE(review): other train() variants in this file unpack a 3-tuple from
    # dataset_iterator and call dev_gen(); here both are used directly and
    # dev_gen is iterated without calling — confirm this utils variant
    # returns iterables.
    train_gen = utils.dataset_iterator(args)
    dev_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)
    if args.use_spectral_norm:
        # spectral-norm layers register non-trainable params; only optimize
        # the ones that require grad
        optimizerD = optim.Adam(filter(lambda p: p.requires_grad,
                                netD.parameters()), lr=2e-4, betas=(0.0, 0.9))
    else:
        optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=2e-4, betas=(0.5, 0.9))
    schedulerD = optim.lr_scheduler.ExponentialLR(optimizerD, gamma=0.99)
    schedulerG = optim.lr_scheduler.ExponentialLR(optimizerG, gamma=0.99)
    schedulerE = optim.lr_scheduler.ExponentialLR(optimizerE, gamma=0.99)
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    gen = utils.inf_train_gen(train_gen)
    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        # freeze the critic while training the autoencoder
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        # real_data = stack_data(args, _data)
        real_data = _data
        real_data_v = autograd.Variable(real_data).cuda()
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """
        for p in netD.parameters():
            p.requires_grad = True
        # 5 critic iterations per generator step (WGAN-style)
        for i in range(5):
            _data = next(gen)
            # real_data = stack_data(args, _data)
            real_data = _data
            real_data_v = autograd.Variable(real_data).cuda()
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data; volatile freezes netG for this pass
            # (legacy pre-0.4 PyTorch inference flag)
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise, volatile=True)
            fake = autograd.Variable(netG(noisev).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)
            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(args, netD,
                                                         real_data_v.data, fake.data)
            gradient_penalty.backward()
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()
        # Update generator network (GAN)
        noise = torch.randn(args.batch_size, args.dim).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()
        # decay all three learning rates once per epoch
        schedulerD.step()
        schedulerG.step()
        schedulerE.step()
        # Write logs and save samples
        save_dir = './plots/'+args.dataset
        plot.plot(save_dir, '/disc cost', D_cost.cpu().data.numpy())
        plot.plot(save_dir, '/gen cost', G_cost.cpu().data.numpy())
        plot.plot(save_dir, '/w1 distance', Wasserstein_D.cpu().data.numpy())
        plot.plot(save_dir, '/ae cost', ae_loss.data.cpu().numpy())
        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for i, (images, _) in enumerate(dev_gen):
                # imgs = stack_data(args, images)
                imgs = images
                imgs_v = autograd.Variable(imgs, volatile=True).cuda()
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost', np.mean(dev_disc_costs))
            utils.generate_image(iteration, netG, save_dir, args)
            # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
        # checkpoint both networks every 100 iterations (including iter 0)
        if iteration % 100 == 0:
            utils.save_model(netG, optimizerG, iteration,
                             'models/{}/G_{}'.format(args.dataset, iteration))
            utils.save_model(netD, optimizerD, iteration,
                             'models/{}/D_{}'.format(args.dataset, iteration))