def train(pool_size, epochs, train_data, val_data, ctx, netG, netD, trainerG, trainerD, lambda1,
          batch_size, expname, append=True, useAE=False):
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1, opt.ndf, opt.ngf, opt.append)
    dlr = trainerD.learning_rate
    glr = trainerG.learning_rate
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()  # note: despite the name, this is a squared-error loss
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        if useAE:
            # Plain autoencoder training: only the generator is updated, with a reconstruction loss.
            train_data.reset()
            for batch in train_data:
                real_in = batch.data[0].as_in_context(ctx)
                real_out = batch.data[1].as_in_context(ctx)
                with autograd.record():
                    fake_out = netG(real_in)
                    loss = L1_loss(real_out, fake_out)
                loss.backward()
                trainerG.step(batch.data[0].shape[0])
                metric2.update([real_out, ], [fake_out, ])
            if epoch % 10 == 0:
                filename = "checkpoints/" + expname + "_" + str(epoch) + "_G.params"
                netG.save_params(filename)
                fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
                fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
                fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
                fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
                #fake_img4T = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
                fake_img = nd.concat(fake_img1, fake_img2, fake_img3, dim=2)
                visual.visualize(fake_img)
                plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
            train_data.reset()
            name, acc = metric2.get()
            metric2.reset()
            print("training acc: " + str(acc))
        else:
            tic = time.time()
            btic = time.time()
            train_data.reset()
            iter = 0
            if epoch > 250:
                # Linearly decay both learning rates after epoch 250.
                trainerD.set_learning_rate(dlr * (1 - int(epoch - 250) / 1000))
                trainerG.set_learning_rate(glr * (1 - int(epoch - 250) / 1000))
            #print('learning rate : ' + str(trainerD.learning_rate))
            for batch in train_data:
                ############################
                # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
                ############################
                real_in = batch.data[0].as_in_context(ctx)
                real_out = batch.data[1].as_in_context(ctx)
                fake_out = netG(real_in)
                fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                with autograd.record():
                    # Train with fake image
                    # Use image pooling to utilize history images
                    output = netD(fake_concat)
                    fake_label = nd.zeros(output.shape, ctx=ctx)
                    errD_fake = GAN_loss(output, fake_label)
                    metric.update([fake_label, ], [output, ])
                    # Train with real image
                    real_concat = nd.concat(real_in, real_out, dim=1) if append else real_out
                    output = netD(real_concat)
                    real_label = nd.ones(output.shape, ctx=ctx)
                    errD_real = GAN_loss(output, real_label)
                    errD = (errD_real + errD_fake) * 0.5
                    errD.backward()
                metric.update([real_label, ], [output, ])
                trainerD.step(batch.data[0].shape[0])

                ############################
                # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
                ############################
                with autograd.record():
                    fake_out = netG(real_in)
                    fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                    output = netD(fake_concat)
                    real_label = nd.ones(output.shape, ctx=ctx)
                    errG = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1
                    errR = L1_loss(real_out, fake_out)
                    errG.backward()
                trainerG.step(batch.data[0].shape[0])

                loss_rec_G.append(nd.mean(errG).asscalar() - nd.mean(errR).asscalar() * lambda1)
                loss_rec_D.append(nd.mean(errD).asscalar())
                loss_rec_R.append(nd.mean(errR).asscalar())
                name, acc = metric.get()
                acc_rec.append(acc)
                # Print log information every five batches
                if iter % 5 == 0:
                    name, acc = metric.get()
                    logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                    #print(errD)
                    logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f, reconstruction error = %f at iter %d epoch %d'
                                 % (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, nd.mean(errR).asscalar(), iter, epoch))
                iter = iter + 1
                btic = time.time()

            name, acc = metric.get()
            metric.reset()
            train_data.reset()
            logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
            logging.info('time: %f' % (time.time() - tic))
            if epoch % 5 == 0:
                text_file = open(expname + "_validtest.txt", "a")
                filename = "checkpoints/" + expname + "_" + str(epoch) + "_D.params"
                netD.save_params(filename)
                filename = "checkpoints/" + expname + "_" + str(epoch) + "_G.params"
                netG.save_params(filename)
                fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
                fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
                fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
                fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
                val_data.reset()
                for vbatch in val_data:
                    real_in = vbatch.data[0].as_in_context(ctx)
                    real_out = vbatch.data[1].as_in_context(ctx)
                    fake_out = netG(real_in)
                    metric2.update([fake_out, ], [real_out, ])
                _, acc2 = metric2.get()
                text_file.write("%s %s %s\n" % (str(epoch), nd.mean(errR).asscalar(), str(acc2)))
                metric2.reset()
                fake_img1T = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
                fake_img2T = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
                fake_img3T = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
                #fake_img4T = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
                fake_img = nd.concat(fake_img1, fake_img2, fake_img3, fake_img1T, fake_img2T, fake_img3T, dim=2)
                visual.visualize(fake_img)
                plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
                text_file.close()
    return ([loss_rec_D, loss_rec_G, loss_rec_R, acc_rec])
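
# The CustomMetric used throughout these training loops wraps a helper `facc`, which is defined
# elsewhere in the repository. For reference, a minimal sketch of the binary-accuracy function it
# is assumed to implement is given below; the name facc_sketch is ours and the project's own facc
# may differ. mx.metric.CustomMetric hands it labels and predictions as numpy arrays.
def facc_sketch(label, pred):
    # Threshold discriminator scores at 0.5 and report the fraction that matches the labels.
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()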
def train(pool_size, epochs, train_data, ctx, netEn, netDe, netD, trainerEn, trainerDe, trainerD,
          lambda1, batch_size, expname):
    threewayloss = gluon.loss.SoftmaxCrossEntropyLoss()
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L1Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            tempout = netEn(real_in)
            fake_out = netDe(tempout)
            fake_concat = fake_out
            #fake_concat = image_pool.query(fake_out)
            #fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                # Train with fake (reconstructed) image
                # Use image pooling to utilize history images
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape[0], ctx=ctx)
                errD_fake = threewayloss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                # Train with real image
                real_concat = real_out
                output = netD(real_concat)
                real_label = nd.ones(output.shape[0], ctx=ctx)
                errD_real = threewayloss(output, real_label)
                metric.update([real_label, ], [output, ])
                # Train with abnormal images decoded from random latent codes
                abinput = nd.random.uniform(-1, 1, tempout.shape, ctx=ctx)
                aboutput = netD(netDe(abinput))
                #print(aboutput.shape)
                #print(output.shape)
                ab_label = 2 * nd.ones(aboutput.shape[0], ctx=ctx)
                errD_ab = threewayloss(aboutput, ab_label)
                errD = (errD_real + errD_fake + errD_ab) * 0.33
                errD.backward()
            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ############################
            with autograd.record():
                fake_out = netDe(netEn(real_in))
                fake_concat = fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape[0], ctx=ctx)
                errG = threewayloss(output, real_label) + L1_loss(real_out, fake_out) * lambda1
                errR = L1_loss(real_out, fake_out)
                errG.backward()
            trainerEn.step(batch.data[0].shape[0])
            trainerDe.step(batch.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator loss = %f, generator loss = %f, latent error = %f, binary training acc = %f, reconstruction error = %f at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), nd.mean(errD_ab).asscalar(), acc, nd.mean(errR).asscalar(), iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D.params"
            netD.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_En.params"
            netEn.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_De.params"
            netDe.save_params(filename)
            # Visualize one generated image for each epoch
            fake_img = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            visual.visualize(fake_img)
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
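
# The discriminator above is trained as a three-class classifier: label 0 for reconstructions,
# 1 for real images, and 2 for images decoded from random latent codes, all fed through
# SoftmaxCrossEntropyLoss. A small self-contained sketch of that labelling convention follows;
# shapes and values are illustrative only, and the helper name is ours.
def _threeway_label_demo():
    from mxnet import nd, gluon
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    scores = nd.random.uniform(shape=(4, 3))   # stand-in for netD output on a batch of 4 images
    labels = 2 * nd.ones((4,))                 # e.g. a batch of decoded random codes
    return loss(scores, labels)                # per-sample loss, shape (4,)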
def train(pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD, netD2, trainerEn,
          trainerDe, trainerD, trainerD2, lambda1, batch_size, expname, append=True, useAE=False):
    tp_file = open(expname + "_trainloss.txt", "w")
    tp_file.close()
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1, opt.ndf, opt.ngf, opt.append)
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.CustomMetric(facc)
    metricMSE = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    acc2_rec = []
    loss_rec_D2 = []
    loss_rec_G2 = []
    lr = 0.002
    #mu = nd.random_normal(loc=0, scale=1, shape=(batch_size // 2, 64, 1, 1), ctx=ctx)
    mu = nd.random.uniform(low=-1, high=1, shape=(batch_size // 2, 64, 1, 1), ctx=ctx)
    #mu = nd.zeros((batch_size // 2, 64, 1, 1), ctx=ctx)
    sigma = nd.ones((64, 1, 1), ctx=ctx)
    mu.attach_grad()
    sigma.attach_grad()
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        #print('learning rate : ' + str(trainerD.learning_rate))
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            fake_latent = netEn(real_in)
            #real_latent = nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx)
            real_latent = nd.multiply(nd.power(sigma, 2), nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx))  #nd.random.uniform(low=-1, high=1, shape=fake_latent.shape, ctx=ctx)
            fake_out = netDe(fake_latent)
            fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
            with autograd.record():
                # Train with fake image
                # Use image pooling to utilize history images
                output = netD(fake_concat)
                output2 = netD2(fake_latent)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                fake_latent_label = nd.zeros(output2.shape, ctx=ctx)
                noiseshape = (fake_latent.shape[0] // 2, fake_latent.shape[1], fake_latent.shape[2], fake_latent.shape[3])
                eps2 = nd.multiply(nd.power(sigma, 2), nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx))
                #eps2 = nd.random_normal(loc=0, scale=sigma.asscalar(), shape=fake_latent.shape, ctx=ctx)
                #eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                rec_output = netD(netDe(eps2))
                errD_fake = GAN_loss(rec_output, fake_label)
                errD_fake2 = GAN_loss(output, fake_label)
                errD2_fake = GAN_loss(output2, fake_latent_label)
                metric.update([fake_label, ], [output, ])
                metric2.update([fake_latent_label, ], [output2, ])
                # Train with real image
                real_concat = nd.concat(real_in, real_out, dim=1) if append else real_out
                output = netD(real_concat)
                output2 = netD2(real_latent)
                real_label = nd.ones(output.shape, ctx=ctx)
                real_latent_label = nd.ones(output2.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD2_real = GAN_loss(output2, real_latent_label)
                #errD = (errD_real + 0.5 * (errD_fake + errD_fake2)) * 0.5
                errD = (errD_real + errD_fake) * 0.5
                errD2 = (errD2_real + errD2_fake) * 0.5
                totalerrD = errD + errD2
                totalerrD.backward()
                #errD2.backward()
            metric.update([real_label, ], [output, ])
            metric2.update([real_latent_label, ], [output2, ])
            trainerD.step(batch.data[0].shape[0])
            trainerD2.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ############################
            with autograd.record():
                sh = fake_latent.shape
                eps2 = nd.multiply(nd.power(sigma, 2), nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx))
                #eps2 = nd.random_normal(loc=0, scale=sigma.asscalar(), shape=fake_latent.shape, ctx=ctx)
                #eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                rec_output = netD(netDe(eps2))
                fake_latent = netEn(real_in)
                output2 = netD2(fake_latent)
                fake_out = netDe(fake_latent)
                fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                real_latent_label = nd.ones(output2.shape, ctx=ctx)
                errG2 = GAN_loss(rec_output, real_label)
                errR = L1_loss(real_out, fake_out) * lambda1
                errG = 10.0 * GAN_loss(output2, real_latent_label) + errG2 + errR + nd.mean(nd.power(sigma, 2))
                errG.backward()
            if epoch > 50:
                # Manual SGD step on the learned noise scale; sigma is not managed by a Trainer.
                sigma -= lr / sigma.shape[0] * sigma.grad
                print(sigma)
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            loss_rec_G2.append(nd.mean(errG2).asscalar())
            loss_rec_G.append(nd.mean(nd.mean(errG)).asscalar() - nd.mean(errG2).asscalar() - nd.mean(errR).asscalar())
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR).asscalar())
            loss_rec_D2.append(nd.mean(errD2).asscalar())
            _, acc2 = metric2.get()
            name, acc = metric.get()
            acc_rec.append(acc)
            acc2_rec.append(acc2)
            # Print log information every ten batches
            if iter % 10 == 0:
                _, acc2 = metric2.get()
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                #print(errD)
                logging.info('discriminator loss = %f, D2 loss = %f, generator loss = %f, G2 loss = %f, binary training acc = %f, D2 acc = %f, reconstruction error = %f at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(), nd.mean(errD2).asscalar(), nd.mean(errG - errG2 - errR).asscalar(), nd.mean(errG2).asscalar(), acc, acc2, nd.mean(errR).asscalar(), iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        _, acc2 = metric2.get()
        tp_file = open(expname + "_trainloss.txt", "a")
        tp_file.write(str(nd.mean(errG2).asscalar()) + " " + str(nd.mean(nd.mean(errG)).asscalar() - nd.mean(errG2).asscalar() - nd.mean(errR).asscalar()) + " " + str(nd.mean(errD).asscalar()) + " " + str(nd.mean(errD2).asscalar()) + " " + str(nd.mean(errR).asscalar()) + " " + str(acc) + " " + str(acc2) + "\n")
        tp_file.close()
        metric.reset()
        metric2.reset()
        train_data.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:  # and epoch > 0:
            text_file = open(expname + "_validtest.txt", "a")
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D.params"
            netD.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D2.params"
            netD2.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_En.params"
            netEn.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_De.params"
            netDe.save_params(filename)
            fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:
                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)
                fake_latent = netEn(real_in)
                y = netDe(fake_latent)
                fake_out = y
                metricMSE.update([fake_out, ], [real_out, ])
            _, acc2 = metricMSE.get()
            text_file.write("%s %s %s\n" % (str(epoch), nd.mean(errR).asscalar(), str(acc2)))
            metricMSE.reset()
            images = netDe(eps2)
            fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
            fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
            fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
            fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/' + expname + '_fakes_' + str(epoch) + '.png')
            text_file.close()
            # Do 10 iterations of sampler update
            fake_img1T = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2T = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3T = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            #fake_img4T = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            fake_img = nd.concat(fake_img1, fake_img2, fake_img3, fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
            '''if epoch > 100:
                for ep2 in range(10):
                    with autograd.record():
                        #eps = nd.random_normal(loc=0, scale=1, shape=noiseshape, ctx=ctx)
                        #eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                        eps2 = nd.random_normal(loc=0, scale=0.02, shape=noiseshape, ctx=ctx)
                        eps2 = nd.tanh(eps2 * sigma + mu)
                        eps2 = nd.concat(eps, eps2, dim=0)
                        rec_output = netD(netDe(eps2))
                        fake_label = nd.zeros(rec_output.shape, ctx=ctx)
                        errGS = GAN_loss(rec_output, fake_label)
                        errGS.backward()
                    mu -= lr / mu.shape[0] * mu.grad
                    sigma -= lr / sigma.shape[0] * sigma.grad
                    print('mu ' + str(mu[0, 0, 0, 0].asnumpy()) + ' sigma ' + str(sigma[0, 0, 0, 0].asnumpy()))
            '''
            images = netDe(eps2)
            fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
            fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
            fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
            fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/' + expname + '_fakespost_' + str(epoch) + '.png')
    return ([loss_rec_D, loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2, loss_rec_G2, acc2_rec])
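
# The sigma update above bypasses gluon.Trainer and applies a manual SGD step to an NDArray with
# an attached gradient. A minimal sketch of that pattern is shown below with a toy objective;
# the names and shapes are illustrative, not the exact quantities used in train().
def _manual_sigma_step_demo():
    from mxnet import nd, autograd
    sigma = nd.ones((64, 1, 1))
    sigma.attach_grad()
    lr = 0.002
    with autograd.record():
        obj = nd.mean(nd.power(sigma, 2))      # stands in for the sigma-dependent part of errG
    obj.backward()
    sigma -= lr / sigma.shape[0] * sigma.grad  # same update rule as used in the loop above
    return sigma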
def train():
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            fake_out = netG(real_in)
            fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                # Train with fake image
                # Use image pooling to utilize history images
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                # Train with real image
                real_concat = nd.concat(real_in, real_out, dim=1)
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([real_label, ], [output, ])
            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ############################
            with autograd.record():
                fake_out = netG(real_in)
                fake_concat = nd.concat(real_in, fake_out, dim=1)
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1
                errG.backward()
            trainerG.step(batch.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:
            filename = "checkpoints/testnet_" + str(epoch) + "_D.params"
            netD.save_params(filename)
            filename = "checkpoints/testnet_" + str(epoch) + "_G.params"
            netG.save_params(filename)
            # Visualize one generated image for each epoch
            fake_img = fake_out[0]
            visual.visualize(fake_img)
            plt.savefig('outputs/testnet_' + str(epoch) + '.png')
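
# train() above draws discriminator inputs from imagePool.ImagePool, the usual pix2pix/CycleGAN
# history buffer of previously generated images. The class lives elsewhere in the repository; the
# sketch below only illustrates the behaviour such a buffer is commonly assumed to have and is not
# the project's actual implementation.
import random
from mxnet import nd

class _ImagePoolSketch:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:                 # pooling disabled: pass the batch straight through
            return images
        out = []
        for i in range(images.shape[0]):
            img = nd.expand_dims(images[i], axis=0)
            if len(self.images) < self.pool_size:
                self.images.append(img)         # buffer not full yet: store and return the new image
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randint(0, self.pool_size - 1)
                out.append(self.images[idx])    # return an old image and replace it with the new one
                self.images[idx] = img
            else:
                out.append(img)                 # keep the buffer and return the new image
        return nd.concat(*out, dim=0)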
def train(pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD, trainerEn, trainerDe,
          trainerD, lambda1, batch_size, expname, append=True, useAE=False):
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1, opt.ndf, opt.ngf, opt.append)
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        #print('learning rate : ' + str(trainerD.learning_rate))
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            soft_zero = 1e-10
            fake_latent = netEn(real_in)
            fake_latent = nd.squeeze(fake_latent)  # collapse the 1x1 spatial axes of the encoder output
            mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
            mu = mu_lv[0]
            lv = mu_lv[1]
            KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv + soft_zero))
            eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, 2048), ctx=ctx)
            z = mu + nd.exp(0.5 * lv) * eps
            z = nd.expand_dims(nd.expand_dims(z, 2), 2)
            y = netDe(z)
            fake_out = y
            logloss = nd.nansum(real_in * nd.log(y + soft_zero) + (1 - real_in) * nd.log(1 - y + soft_zero))
            loss = -logloss - KL
            fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
            with autograd.record():
                # Train with fake image
                # Use image pooling to utilize history images
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                # Train with real image
                real_concat = nd.concat(real_in, real_out, dim=1) if append else real_out
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
            metric.update([real_label, ], [output, ])
            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ############################
            with autograd.record():
                fake_latent = nd.squeeze(netEn(real_in))
                mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
                mu = mu_lv[0]
                lv = mu_lv[1]
                KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv + soft_zero))
                eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, 2048), ctx=ctx)
                #KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv + soft_zero))
                z = mu + nd.exp(0.5 * lv) * eps
                z = nd.expand_dims(nd.expand_dims(z, 2), 2)
                y = netDe(z)
                fake_out = y
                # Inputs are in [-1, 1]; rescale to [0, 1] before the Bernoulli log-likelihood.
                logloss = nd.nansum((real_in + 1) * 0.5 * nd.log(0.5 * (y + 1) + soft_zero) + (1 - 0.5 * (real_in + 1)) * nd.log(1 - 0.5 * (y + 1) + soft_zero))
                loss = -logloss - KL
                fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG = GAN_loss(output, real_label) + loss * lambda1  #L1_loss(real_out, fake_out) * lambda1
                errR = logloss  #L1_loss(real_out, fake_out)
                errG.backward()
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            loss_rec_G.append(nd.mean(errG).asscalar() - nd.mean(errR).asscalar() * lambda1)
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR).asscalar())
            name, acc = metric.get()
            acc_rec.append(acc)
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                #print(errD)
                logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f, reconstruction error = %f at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, nd.mean(errR).asscalar(), iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        train_data.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:
            text_file = open(expname + "_validtest.txt", "a")
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D.params"
            netD.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_En.params"
            netEn.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_De.params"
            netDe.save_params(filename)
            fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:
                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)
                fake_latent = netEn(real_in)
                mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
                mu = mu_lv[0]
                lv = mu_lv[1]
                eps = nd.random_normal(loc=0, scale=1, shape=(batch_size // 5, 2048, 1, 1), ctx=ctx)
                z = mu + nd.exp(0.5 * lv) * eps
                y = netDe(z)
                fake_out = y
                KL = 0.5 * nd.sum(1 + lv - mu * mu - nd.exp(lv), axis=1)
                logloss = nd.sum(real_in * nd.log(y + soft_zero) + (1 - real_in) * nd.log(1 - y + soft_zero), axis=1)
                loss = logloss + KL
                metric2.update([fake_out, ], [real_out, ])
            _, acc2 = metric2.get()
            text_file.write("%s %s %s\n" % (str(epoch), nd.mean(errR).asscalar(), str(acc2)))
            metric2.reset()
            fake_img1T = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2T = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3T = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            #fake_img4T = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            fake_img = nd.concat(fake_img1, fake_img2, fake_img3, fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
            text_file.close()
    return ([loss_rec_D, loss_rec_G, loss_rec_R, acc_rec])
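
# The generator objective above is a VAE-style evidence lower bound: the encoder output is split
# into a mean mu and a log-variance lv, a latent code is drawn with the reparameterisation trick
# z = mu + exp(0.5 * lv) * eps, and 0.5 * sum(1 + lv - mu^2 - exp(lv)) (the negative KL divergence
# to a standard normal) is added to the reconstruction log-likelihood. A small stand-alone sketch
# of that sampling step, with illustrative shapes:
def _reparameterize_demo():
    from mxnet import nd
    mu = nd.zeros((2, 8))
    lv = nd.zeros((2, 8))                                         # log-variance, so exp(lv) is the variance
    eps = nd.random_normal(loc=0, scale=1, shape=mu.shape)
    z = mu + nd.exp(0.5 * lv) * eps                               # differentiable sample from N(mu, exp(lv))
    neg_kl = 0.5 * nd.sum(1 + lv - mu * mu - nd.exp(lv), axis=1)  # -KL(q(z|x) || N(0, I)), zero here
    return z, neg_kl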
def train(cep, pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD, netD2, netDS,
          trainerEn, trainerDe, trainerD, trainerD2, trainerSD, lambda1, batch_size, expname,
          append=True, useAE=False):
    tp_file = open(expname + "_trainloss.txt", "w")
    tp_file.close()
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1, opt.ndf, opt.ngf, opt.append)
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.CustomMetric(facc)
    metricStrong = mx.metric.CustomMetric(facc)
    metricMSE = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    acc2_rec = []
    loss_rec_D2 = []
    loss_rec_G2 = []
    lr = 2.0 * 512
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    # Resume from checkpoint epoch cep if one is given; cep == -1 means train from scratch.
    if cep == -1:
        cep = 0
    else:
        netEn.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_En.params', ctx=ctx)
        netDe.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_De.params', ctx=ctx)
        netD.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_D.params', ctx=ctx)
        netD2.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_D2.params', ctx=ctx)
        netDS.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_SD.params', ctx=ctx)
    iter = 0
    for epoch in range(cep + 1, epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        #print('learning rate : ' + str(trainerD.learning_rate))
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            if ctx == mx.cpu():
                ct = mx.cpu()
            else:
                ct = mx.gpu()
            real_in = batch.data[0]  #.as_in_context(ctx)
            real_out = batch.data[1]  #.as_in_context(ctx)
            if iter == 0:
                latent_shape = (batch_size, 512, 1, 1)  #code.shape
                out_l_shape = (batch_size, 1, 1, 1)  #netD2((code)).shape
                out_i_shape = (batch_size, 1, 1, 1)  #netD(netDe(code)).shape
                out_s_shape = (batch_size, 1, 1, 1)  #netSD(netDe(code)).shape
            real_in = gluon.utils.split_and_load(real_in, ctx)
            real_out = gluon.utils.split_and_load(real_out, ctx)
            fake_latent = [netEn(r) for r in real_in]
            real_latent = nd.random.uniform(low=-1, high=1, shape=latent_shape)
            real_latent = gluon.utils.split_and_load(real_latent, ctx)
            fake_out = [netDe(f) for f in fake_latent]
            # Note: with per-device lists, only the append=False branch is usable here.
            fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
            eps2 = nd.random.uniform(low=-1, high=1, shape=latent_shape, ctx=ct)
            eps2 = gluon.utils.split_and_load(eps2, ctx)
            if epoch > 150:  # (1/float(batch_size))*512*150: # and epoch % 10 == 0:
                # Latent mining: push random codes towards regions the strong discriminator rejects.
                print('Mining..')
                mu = nd.random.uniform(low=-1, high=1, shape=latent_shape, ctx=ct)
                #sigma = nd.ones((batch_size, 64, 1, 1), ctx=ctx) * 0.000001
                mu.attach_grad()
                #sigma.attach_grad()
                images = netDe(mu)
                fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
                fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
                fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
                fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
                visual.visualize(fake_img)
                plt.savefig('outputs/' + expname + '_fakespre_' + str(epoch) + '.png')
                eps2 = gluon.utils.split_and_load(mu, ctx)
                for e in eps2:
                    e.attach_grad()
                for ep2 in range(1):
                    with autograd.record():
                        #eps = nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx)
                        #eps2 = gluon.utils.split_and_load(nd.tanh(mu), ctx)  #+nd.multiply(eps, sigma)
                        #eps2 = nd.random.uniform(low=-1, high=1, shape=fake_latent.shape, ctx=ctx)
                        rec_output = [netDS(netDe(e)) for e in eps2]
                        fake_label = gluon.utils.split_and_load(nd.zeros(out_s_shape), ctx)
                        errGS = [GAN_loss(r, f) for r, f in zip(rec_output, fake_label)]
                    for e in errGS:
                        e.backward()
                    for idx, _ in enumerate(eps2):
                        eps2[idx] = nd.tanh(eps2[idx] - lr / eps2[idx].shape[0] * eps2[idx].grad)
                images = netDe(eps2[0])
                fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
                fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
                fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
                fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
                visual.visualize(fake_img)
                plt.savefig('outputs/' + expname + str(ep2) + '_fakespost_' + str(epoch) + '.png')
                #eps2 = nd.tanh(mu)  #+nd.multiply(eps, sigma)
            with autograd.record():
                #eps2 = gluon.utils.split_and_load(eps2, ctx)
                # Train with fake image
                # Use image pooling to utilize history images
                output = [netD(f) for f in fake_concat]
                output2 = [netD2(f) for f in fake_latent]
                fake_label = nd.zeros(out_i_shape)
                fake_label = gluon.utils.split_and_load(fake_label, ctx)
                fake_latent_label = nd.zeros(out_l_shape)
                fake_latent_label = gluon.utils.split_and_load(fake_latent_label, ctx)
                eps = gluon.utils.split_and_load(nd.random.uniform(low=-1, high=1, shape=latent_shape), ctx)
                rec_output = [netD(netDe(e)) for e in eps]
                errD_fake = [GAN_loss(r, f) for r, f in zip(rec_output, fake_label)]
                errD_fake2 = [GAN_loss(o, f) for o, f in zip(output, fake_label)]
                errD2_fake = [GAN_loss(o, f) for o, f in zip(output2, fake_latent_label)]
                for f, o in zip(fake_label, rec_output):
                    metric.update([f, ], [o, ])
                for f, o in zip(fake_latent_label, output2):
                    metric2.update([f, ], [o, ])
                # Train with real image
                real_concat = nd.concat(real_in, real_out, dim=1) if append else real_out
                output = [netD(r) for r in real_concat]
                output2 = [netD2(r) for r in real_latent]
                real_label = gluon.utils.split_and_load(nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(nd.ones(out_l_shape), ctx)
                errD_real = [GAN_loss(o, r) for o, r in zip(output, real_label)]
                errD2_real = [GAN_loss(o, r) for o, r in zip(output2, real_latent_label)]
                for e1, e2, e4, e5 in zip(errD_real, errD_fake, errD2_real, errD2_fake):
                    err = (e1 + e2) * 0.5 + (e5 + e4) * 0.5
                    err.backward()
            for f, o in zip(real_label, output):
                metric.update([f, ], [o, ])
            for f, o in zip(real_latent_label, output2):
                metric2.update([f, ], [o, ])
            trainerD.step(batch.data[0].shape[0])
            trainerD2.step(batch.data[0].shape[0])
            nd.waitall()
            # Update the strong discriminator on decoded random codes vs. reconstructions.
            with autograd.record():
                strong_output = [netDS(netDe(e)) for e in eps]
                strong_real = [netDS(f) for f in fake_concat]
                errs1 = [GAN_loss(r, f) for r, f in zip(strong_output, fake_label)]
                errs2 = [GAN_loss(r, f) for r, f in zip(strong_real, real_label)]
                for f, s in zip(fake_label, strong_output):
                    metricStrong.update([f, ], [s, ])
                for f, s in zip(real_label, strong_real):
                    metricStrong.update([f, ], [s, ])
                for e1, e2 in zip(errs1, errs2):
                    strongerr = 0.5 * (e1 + e2)
                    strongerr.backward()
            trainerSD.step(batch.data[0].shape[0])
            nd.waitall()

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ############################
            with autograd.record():
                sh = out_l_shape
                #eps2 = nd.random_normal(loc=0, scale=1, shape=noiseshape, ctx=ctx)
                #eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                #if epoch > 100:
                #    eps2 = nd.multiply(eps2, sigma) + mu
                #    eps2 = nd.tanh(eps2)
                #else:
                #    eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                #    eps2 = nd.concat(eps, eps2, dim=0)
                rec_output = [netD(netDe(e)) for e in eps2]
                fake_latent = [netEn(r) for r in real_in]
                output2 = [netD2(f) for f in fake_latent]
                fake_out = [netDe(f) for f in fake_latent]
                fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                output = [netD(f) for f in fake_concat]
                real_label = gluon.utils.split_and_load(nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(nd.ones(out_l_shape), ctx)
                errG2 = [GAN_loss(r, f) for r, f in zip(rec_output, real_label)]
                errR = [L1_loss(r, f) * lambda1 for r, f in zip(real_out, fake_out)]
                errG = [10 * GAN_loss(r, f) for r, f in zip(output2, real_latent_label)]  # +errG2+errR
                for e1, e2, e3 in zip(errG, errG2, errR):
                    e = e1 + e2 + e3
                    e.backward()
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            nd.waitall()
            errD = (errD_real[0] + errD_fake[0]) * 0.5
            errD2 = (errD2_real[0] + errD2_fake[0]) * 0.5
            loss_rec_G2.append(nd.mean(errG2[0]).asscalar())
            loss_rec_G.append(nd.mean(nd.mean(errG[0])).asscalar() - nd.mean(errG2[0]).asscalar() - nd.mean(errR[0]).asscalar())
            loss_rec_D.append(nd.mean(errD[0]).asscalar())
            loss_rec_R.append(nd.mean(errR[0]).asscalar())
            loss_rec_D2.append(nd.mean(errD2[0]).asscalar())
            _, acc2 = metric2.get()
            name, acc = metric.get()
            acc_rec.append(acc)
            acc2_rec.append(acc2)
            # Print log information every ten batches
            if iter % 10 == 0:
                _, acc2 = metric2.get()
                name, acc = metric.get()
                _, accStrong = metricStrong.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                #print(errD)
                #logging.info('discriminator loss = %f, D2 loss = %f, generator loss = %f, G2 loss = %f, SD loss = %f, D acc = %f, D2 acc = %f, DS acc = %f, reconstruction error = %f at iter %d epoch %d'
                #             % (nd.mean(errD[0]).asscalar(), nd.mean(errD2[0]).asscalar(),
                #                nd.mean(errG[0] - errG2[0] - errR[0]).asscalar(), nd.mean(errG2[0]).asscalar(),
                #                nd.mean(strongerr[0]).asscalar(), acc, acc2, accStrong[0], nd.mean(errR[0]).asscalar(), iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        _, acc2 = metric2.get()
        #tp_file = open(expname + "_trainloss.txt", "a")
        #tp_file.write(str(nd.mean(errG2).asscalar()) + " " + str(nd.mean(nd.mean(errG)).asscalar() - nd.mean(errG2).asscalar() - nd.mean(errR).asscalar()) + " " + str(nd.mean(errD).asscalar()) + " " + str(nd.mean(errD2).asscalar()) + " " + str(nd.mean(errR).asscalar()) + " " + str(acc) + " " + str(acc2) + "\n")
        #tp_file.close()
        metric.reset()
        metric2.reset()
        train_data.reset()
        metricStrong.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 2 == 0:  # and epoch > 0:
            text_file = open(expname + "_validtest.txt", "a")
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D.params"
            netD.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D2.params"
            netD2.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_En.params"
            netEn.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_De.params"
            netDe.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_SD.params"
            netDS.save_parameters(filename)
            fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:
                real_in = vbatch.data[0]
                real_out = vbatch.data[1]
                real_in = gluon.utils.split_and_load(real_in, ctx)
                real_out = gluon.utils.split_and_load(real_out, ctx)
                fake_latent = [netEn(r) for r in real_in]
                fake_out = [netDe(f) for f in fake_latent]
                for f, r in zip(fake_out, real_out):
                    metricMSE.update([f, ], [r, ])
            _, acc2 = metricMSE.get()
            toterrR = 0
            for e in errR:
                toterrR += nd.mean(e).asscalar()
            text_file.write("%s %s %s\n" % (str(epoch), toterrR, str(acc2)))
            metricMSE.reset()
    return ([loss_rec_D, loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2, loss_rec_G2, acc2_rec])
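
# The loop above keeps one shard per device: gluon.utils.split_and_load slices a batch across the
# context list and each network is applied to every shard in a Python list comprehension. A minimal
# sketch of that dispatch pattern follows; the context list and shapes are illustrative.
def _split_and_load_demo():
    import mxnet as mx
    from mxnet import nd, gluon
    ctx = [mx.cpu()]                                   # e.g. [mx.gpu(0), mx.gpu(1)] on a multi-GPU machine
    batch = nd.random.uniform(shape=(8, 3, 32, 32))
    shards = gluon.utils.split_and_load(batch, ctx)    # one NDArray per context
    outputs = [shard.mean() for shard in shards]       # stands in for [netD(s) for s in shards]
    return [o.asscalar() for o in outputs]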