def constrain_weight(self, weight_arr):
    # Shrink the weights by 10% per pass until the NaN-ignoring sum of
    # squares drops below the configured limit.
    square_weight_arr = weight_arr * weight_arr
    while nd.nansum(square_weight_arr) > self.__weight_limit:
        weight_arr = weight_arr * 0.9
        square_weight_arr = weight_arr * weight_arr

    return weight_arr
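A minimal standalone sketch of the same shrink-until-under-limit idea, usable outside the class; the 0.9 decay factor mirrors the method above, while the function name, the limit, and the sample array are illustrative.

from mxnet import nd

def shrink_until_under_limit(weight_arr, weight_limit, decay=0.9):
    # nd.nansum ignores NaN entries, so NaN weights do not block convergence.
    while nd.nansum(weight_arr * weight_arr).asscalar() > weight_limit:
        weight_arr = weight_arr * decay
    return weight_arr

w = nd.array([3.0, 4.0, float('nan')])
print(shrink_until_under_limit(w, weight_limit=1.0))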
Example #2
def log_pdf(self, y):
    # Total log-likelihood: per-sample sums of the log-softmax probabilities
    # of the observed one-hot targets y, added up over the batch.
    return nd.sum(
        nd.nansum(y * nd.log_softmax(self.unnormalized_mean),
                  axis=0,
                  exclude=True))
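With axis=0 and exclude=True, nd.nansum reduces over every axis except the batch axis, so the inner call yields one log-likelihood per sample and the outer nd.sum adds them up. A small hedged illustration with made-up logits and one-hot targets, where the logits stand in for unnormalized_mean above:

from mxnet import nd

logits = nd.array([[2.0, 0.5, 0.1],
                   [0.2, 1.5, 0.3]])   # 2 samples x 3 classes
y = nd.array([[1, 0, 0],
              [0, 1, 0]])              # one-hot targets
per_sample = nd.nansum(y * nd.log_softmax(logits), axis=0, exclude=True)
print(per_sample)          # log-probability of the observed class, per sample
print(nd.sum(per_sample))  # the scalar that log_pdf above would return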
Example #3
def softmax_cross_entropy(yhat_linear, y):  # cross-entropy loss
    # return - nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)
    return -nd.nansum(
        y * nd.log(transform_softmax(yhat_linear)), axis=0, exclude=True)
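transform_softmax is not defined in this snippet; it is presumably a hand-written softmax (the commented-out line shows the equivalent nd.log_softmax form). A hedged sketch of what such a helper might look like; the max-shift for numerical stability is an assumption, not taken from the example:

from mxnet import nd

def transform_softmax(yhat_linear):
    # Hypothetical helper: row-wise softmax with a max shift for stability.
    shifted = yhat_linear - nd.max(yhat_linear, axis=1, keepdims=True)
    exp = nd.exp(shifted)
    return exp / nd.sum(exp, axis=1, keepdims=True)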
Example #4
def train(pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD, trainerEn, trainerDe, trainerD, lambda1, batch_size, expname, append=True, useAE=False):
    
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1,opt.ndf,  opt.ngf, opt.append)
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):

        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        #print('learning rate : '+str(trainerD.learning_rate ))
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            soft_zero = 1e-10
            fake_latent = netEn(real_in)
            fake_latent = nd.squeeze(fake_latent)
            mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
            mu = mu_lv[0]
            lv = mu_lv[1]
            KL = 0.5*nd.nansum(1+lv-mu*mu-nd.exp(lv+soft_zero))
            eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, 2048), ctx=ctx)
            z = mu + nd.exp(0.5*lv)*eps
            z = nd.expand_dims(nd.expand_dims(z, 2), 2)
            y = netDe(z)
            fake_out = y

            logloss = nd.nansum(real_in*nd.log(y+soft_zero) + (1-real_in)*nd.log(1-y+soft_zero))
            loss = -logloss - KL
            fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
            with autograd.record():
                # Train with fake image
                # Use image pooling to utilize history imagesi
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                real_concat =  nd.concat(real_in, real_out, dim=1) if append else  real_out
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([real_label, ], [output, ])

            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ###########################
            with autograd.record():
                fake_latent = nd.squeeze(netEn(real_in))
                mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
                mu = mu_lv[0]
                lv = mu_lv[1]
                KL = 0.5*nd.nansum(1+lv-mu*mu-nd.exp(lv+soft_zero))
                eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, 2048), ctx=ctx)
                z = mu + nd.exp(0.5*lv)*eps
                z = nd.expand_dims(nd.expand_dims(z, 2), 2)
                y = netDe(z)
                fake_out = y
                # Inputs/outputs in [-1, 1] are rescaled to [0, 1] before the Bernoulli log-likelihood.
                logloss = nd.nansum((real_in+1)*0.5*nd.log(0.5*(y+1)+soft_zero) + (1-0.5*(real_in+1))*nd.log(1-0.5*(y+1)+soft_zero))
                loss = -logloss - KL
                fake_concat =  nd.concat(real_in, fake_out, dim=1) if append else  fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG = GAN_loss(output, real_label) + loss*lambda1  # L1_loss(real_out, fake_out) * lambda1
                errR = logloss  # L1_loss(real_out, fake_out)
                errG.backward()
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            loss_rec_G.append(nd.mean(errG).asscalar() - nd.mean(errR).asscalar()*lambda1)
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR).asscalar())
            name, acc = metric.get()
            acc_rec.append(acc)
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                #print(errD)
                logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f reconstruction error = %f at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(),
                                nd.mean(errG).asscalar(), acc, nd.mean(errR).asscalar(), iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        train_data.reset()

        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:
            text_file = open(expname + "_validtest.txt", "a")
            filename = "checkpoints/"+expname+"_"+str(epoch)+"_D.params"
            netD.save_params(filename)
            filename = "checkpoints/"+expname+"_"+str(epoch)+"_En.params"
            netEn.save_params(filename)
            filename = "checkpoints/"+expname+"_"+str(epoch)+"_De.params"
            netDe.save_params(filename)
            fake_img1 = nd.concat(real_in[0],real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1],real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2],real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3],real_out[3], fake_out[3], dim=1)
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:

                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)

                fake_latent = netEn(real_in)
                mu_lv = nd.split(fake_latent, axis=1, num_outputs=2)
                mu = mu_lv[0]
                lv = mu_lv[1]
                eps = nd.random_normal(loc=0, scale=1, shape=(batch_size//5, 2048, 1, 1), ctx=ctx)
                z = mu + nd.exp(0.5*lv)*eps
                y = netDe(z)
                fake_out = y
                KL = 0.5*nd.sum(1+lv-mu*mu-nd.exp(lv), axis=1)
                logloss = nd.sum(real_in*nd.log(y+soft_zero) + (1-real_in)*nd.log(1-y+soft_zero), axis=1)
                loss = logloss + KL
                metric2.update([fake_out, ], [real_out, ])
                _, acc2 = metric2.get()
            text_file.write("%s %s %s\n" % (str(epoch), nd.mean(errR).asscalar(), str(acc2)))
            metric2.reset()

            fake_img1T = nd.concat(real_in[0],real_out[0], fake_out[0], dim=1)
            fake_img2T = nd.concat(real_in[1],real_out[1], fake_out[1], dim=1)
            fake_img3T = nd.concat(real_in[2],real_out[2], fake_out[2], dim=1)
            #fake_img4T = nd.concat(real_in[3],real_out[3], fake_out[3], dim=1)
            fake_img = nd.concat(fake_img1,fake_img2, fake_img3,fake_img1T,fake_img2T, fake_img3T,dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/'+expname+'_'+str(epoch)+'.png')
            text_file.close()
    return [loss_rec_D, loss_rec_G, loss_rec_R, acc_rec]
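The encoder/decoder update above is a standard VAE objective: the encoder output is split in half into a mean mu and a log-variance lv, a latent sample is drawn with the reparameterization trick z = mu + exp(0.5*lv)*eps, and the loss is -logloss - KL. Note that, ignoring the tiny soft_zero added inside the exponent, the KL variable in the code equals 0.5*sum(1 + lv - mu^2 - exp(lv)), i.e. the negative of KL(q(z|x) || N(0, I)), so the total is the usual negative ELBO. A standalone sketch of just that piece, keeping the 2048-dimensional latent from the code above and making up everything else:

import mxnet as mx
from mxnet import nd

ctx = mx.cpu()
batch_size, latent_dim = 4, 2048
# Stand-in for netEn(real_in): mean and log-variance stacked along axis 1.
enc_out = nd.random_normal(shape=(batch_size, 2 * latent_dim), ctx=ctx)

mu, lv = nd.split(enc_out, axis=1, num_outputs=2)
KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv))   # = -KL(q(z|x) || N(0, I))
eps = nd.random_normal(loc=0, scale=1, shape=(batch_size, latent_dim), ctx=ctx)
z = mu + nd.exp(0.5 * lv) * eps                        # reparameterization trick
print(KL.asscalar(), z.shape)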
Example #5
def softmax_cross_entropy(yhat_linear, y):
    return -nd.nansum(
        y * nd.log_softmax(yhat_linear), axis=0, exclude=True)
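This is the numerically stable form of Example #3 (the commented-out line there is identical). A short hedged usage check against gluon's built-in loss, with made-up logits and a one-hot label; the two values should agree:

from mxnet import nd, gluon

yhat_linear = nd.array([[2.0, 0.5, 0.1]])
y = nd.array([[1, 0, 0]])

manual = -nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)
builtin = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)(yhat_linear, y)
print(manual, builtin)  # both are the per-sample cross-entropy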