def constrain_weight(self, weight_arr):
    # Rescale the weights until the sum of squared weights falls below
    # the configured limit (shrink by 10% per iteration).
    square_weight_arr = weight_arr * weight_arr
    while nd.nansum(square_weight_arr) > self.__weight_limit:
        weight_arr = weight_arr * 0.9
        square_weight_arr = weight_arr * weight_arr
    return weight_arr
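
# A free-standing sketch of the same squared-norm constraint on a dummy
# weight matrix. The limit value 1.0 is an assumption; the class attribute
# self.__weight_limit used above is not shown in this section.
from mxnet import nd

def constrain_weight_demo(weight_arr, weight_limit=1.0):
    # Shrink the weights by 10% until the sum of squares is within the limit.
    while nd.nansum(weight_arr * weight_arr).asscalar() > weight_limit:
        weight_arr = weight_arr * 0.9
    return weight_arr

w = nd.array([[1.0, -2.0], [0.5, 1.5]])
print(constrain_weight_demo(w))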
def log_pdf(self, y):
    # Log-likelihood of the targets under the softmax of the unnormalized
    # mean, summed per sample and then over the batch.
    return nd.sum(
        nd.nansum(y * nd.log_softmax(self.unnormalized_mean), axis=0, exclude=True))
def softmax_cross_entropy(yhat_linear, y):
    # Cross-entropy loss.
    # return -nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)
    return -nd.nansum(
        y * nd.log(transform_softmax(yhat_linear)), axis=0, exclude=True)
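
# `transform_softmax` is called above but not defined in this section. The
# sketch below is an assumption about what it computes (a numerically stable
# softmax over the last axis); it is not the repository's own definition.
def transform_softmax(yhat_linear):
    # Subtract the per-row maximum before exponentiating to avoid overflow,
    # then normalize so each row sums to one.
    shifted = yhat_linear - nd.max(yhat_linear, axis=-1, keepdims=True)
    exp = nd.exp(shifted)
    return exp / nd.sum(exp, axis=-1, keepdims=True)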
def train(pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD,
          trainerEn, trainerDe, trainerD, lambda1, batch_size, expname,
          append=True, useAE=False):
    # Create/truncate the validation log for this experiment.
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    # netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1, opt.ndf, opt.ngf, opt.append)
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        # print('learning rate : ' + str(trainerD.learning_rate))
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            soft_zero = 1e-10
            # Encode, split into mean / log-variance, sample a latent and decode.
            # No gradient recording here: the encoder/decoder are fixed for the D step.
            fake_latent = netEn(real_in)
            fake_latent = np.squeeze(fake_latent)
            mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
            mu = mu_lv[0]
            lv = mu_lv[1]
            KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv + soft_zero))
            eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, 2048), ctx=ctx)
            z = mu + nd.exp(0.5 * lv) * eps
            z = nd.expand_dims(nd.expand_dims(z, 2), 2)
            y = netDe(z)
            fake_out = y
            logloss = nd.nansum(real_in * nd.log(y + soft_zero)
                                + (1 - real_in) * nd.log(1 - y + soft_zero))
            loss = -logloss - KL
            fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
            with autograd.record():
                # Train with fake image
                # Use image pooling to utilize history images
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                # Train with real image
                real_concat = nd.concat(real_in, real_out, dim=1) if append else real_out
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([real_label, ], [output, ])
            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ############################
            with autograd.record():
                fake_latent = np.squeeze(netEn(real_in))
                mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
                mu = mu_lv[0]
                lv = mu_lv[1]
                KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv + soft_zero))
                eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, 2048), ctx=ctx)
                z = mu + nd.exp(0.5 * lv) * eps
                z = nd.expand_dims(nd.expand_dims(z, 2), 2)
                y = netDe(z)
                fake_out = y
                # Bernoulli log-likelihood with inputs/outputs rescaled from [-1, 1] to [0, 1].
                logloss = nd.nansum((real_in + 1) * 0.5 * nd.log(0.5 * (y + 1) + soft_zero)
                                    + (1 - 0.5 * (real_in + 1)) * nd.log(1 - 0.5 * (y + 1) + soft_zero))
                loss = -logloss - KL
                fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG = GAN_loss(output, real_label) + loss * lambda1  # L1_loss(real_out, fake_out) * lambda1
                errR = logloss  # L1_loss(real_out, fake_out)
                errG.backward()
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            loss_rec_G.append(nd.mean(errG).asscalar() - nd.mean(errR).asscalar() * lambda1)
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR).asscalar())
            name, acc = metric.get()
            acc_rec.append(acc)
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                # print(errD)
                logging.info('discriminator loss = %f, generator loss = %f, '
                             'binary training acc = %f reconstruction error = %f '
                             'at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(),
                                acc, nd.mean(errR).asscalar(), iter, epoch))
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()
        train_data.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:
            # Save checkpoints and run validation every ten epochs.
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D.params"
            netD.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_En.params"
            netEn.save_params(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_De.params"
            netDe.save_params(filename)
            fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:
                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)
                fake_latent = netEn(real_in)
                mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
                mu = mu_lv[0]
                lv = mu_lv[1]
                eps = nd.random_normal(loc=0, scale=1,
                                       shape=(batch_size // 5, 2048, 1, 1), ctx=ctx)
                z = mu + nd.exp(0.5 * lv) * eps
                y = netDe(z)
                fake_out = y
                KL = 0.5 * nd.sum(1 + lv - mu * mu - nd.exp(lv), axis=1)
                logloss = nd.sum(real_in * nd.log(y + soft_zero)
                                 + (1 - real_in) * nd.log(1 - y + soft_zero), axis=1)
                loss = logloss + KL
                metric2.update([fake_out, ], [real_out, ])
            _, acc2 = metric2.get()
            text_file.write("%s %s %s\n" % (str(epoch), nd.mean(errR).asscalar(), str(acc2)))
            metric2.reset()
            fake_img1T = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2T = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3T = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            # fake_img4T = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            fake_img = nd.concat(fake_img1, fake_img2, fake_img3,
                                 fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
            text_file.close()
    return [loss_rec_D, loss_rec_G, loss_rec_R, acc_rec]
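
# A self-contained sketch of the reparameterization step and KL term used in
# the training loop above, run on dummy data. The 2048-dimensional latent
# size mirrors the value hard-coded in `train`; it is otherwise arbitrary.
import mxnet as mx
from mxnet import nd

ctx = mx.cpu()
batch_size, latent_dim = 4, 2048
soft_zero = 1e-10

# Pretend encoder output: first half is the mean, second half the log-variance.
fake_latent = nd.random_normal(shape=(batch_size, 2 * latent_dim), ctx=ctx)
mu, lv = nd.split(fake_latent, axis=1, num_outputs=2)

# KL divergence between N(mu, exp(lv)) and N(0, 1), as computed in the loop above.
KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv + soft_zero))

# Reparameterization trick: sample z = mu + sigma * eps with eps ~ N(0, 1),
# so gradients can flow through mu and lv.
eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, latent_dim), ctx=ctx)
z = mu + nd.exp(0.5 * lv) * eps
print(KL.asscalar(), z.shape)  # z would then be reshaped and fed to the decoder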
def softmax_cross_entropy(yhat_linear, y):
    # Cross-entropy loss; log_softmax fuses the log and softmax for
    # numerical stability.
    return -nd.nansum(
        y * nd.log_softmax(yhat_linear), axis=0, exclude=True)
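
# Quick check that the two cross-entropy variants above agree on well-scaled
# inputs. nd.softmax stands in for `transform_softmax` here as an assumption;
# the log_softmax form is the safer choice for large logits.
import mxnet as mx
from mxnet import nd

yhat_linear = nd.array([[2.0, 0.5, -1.0],
                        [0.1, 1.2, 0.3]])
y = nd.array([[1, 0, 0],
              [0, 1, 0]])  # one-hot labels

loss_stable = -nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)
loss_naive = -nd.nansum(y * nd.log(nd.softmax(yhat_linear)), axis=0, exclude=True)
print(loss_stable, loss_naive)  # per-sample losses, nearly identical here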