示例#1
0
文件: cvpriter.py 项目: zkzt/OCGAN
def train(pool_size,
          epochs,
          train_data,
          val_data,
          ctx,
          netG,
          netD,
          trainerG,
          trainerD,
          lambda1,
          batch_size,
          expname,
          append=True,
          useAE=False):
    """Train a conditional GAN (netG vs. netD), or netG alone as an auto-encoder.

    GAN mode: every 5 epochs both networks are checkpointed under
    ``checkpoints/``, netG is evaluated on ``val_data`` with an MSE metric, the
    result is appended to ``<expname>_validtest.txt`` and a comparison image
    (training and validation samples) is written to ``outputs/``.
    Auto-encoder mode: netG is checkpointed and visualised every 10 epochs.

    Args:
        pool_size: Accepted for interface compatibility; not used.
        epochs: Number of passes over ``train_data``.
        train_data: Resettable batch iterator; ``batch.data[0]`` is the input
            and ``batch.data[1]`` the target.
        val_data: Resettable batch iterator with the same layout, used for
            validation in GAN mode.
        ctx: Context that every batch is copied to.
        netG: Generator mapping an input batch to an output batch.
        netD: Discriminator; sees input and image concatenated on axis 1
            when ``append`` is True, otherwise the image alone.
        trainerG, trainerD: Trainers for netG / netD. Their initial learning
            rates are decayed linearly after epoch 250 (GAN mode).
        lambda1: Weight of the reconstruction term in the generator loss.
        batch_size: Only used to report samples/s in the log.
        expname: Prefix for checkpoint, output-image and log file names.
        append: Concatenate the input with the image before netD.
        useAE: Train netG only, with the reconstruction loss.

    Returns:
        ``[loss_rec_D, loss_rec_G, loss_rec_R, acc_rec]``: per-batch
        discriminator loss, generator loss minus the weighted reconstruction
        term, reconstruction loss, and discriminator accuracy. The lists stay
        empty in auto-encoder mode.
    """
    # Create/truncate the validation log for this experiment.
    with open(expname + "_validtest.txt", "w"):
        pass
    dlr = trainerD.learning_rate
    glr = trainerG.learning_rate
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    rec_loss = gluon.loss.L2Loss()  # reconstruction loss (squared error)
    metric = mx.metric.CustomMetric(facc)  # discriminator accuracy
    # Reconstruction MSE: training error in AE mode, validation error in GAN mode.
    metric2 = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    logging.basicConfig(level=logging.DEBUG)

    def _panels(inputs, targets, preds, count=3):
        """Return ``count`` images, image i being inputs[i] | targets[i] | preds[i] on axis 1."""
        return [nd.concat(inputs[i], targets[i], preds[i], dim=1)
                for i in range(count)]

    for epoch in range(epochs):
        prefix = "checkpoints/" + expname + "_" + str(epoch)

        if useAE:
            # Auto-encoder mode: optimise netG on the reconstruction loss only.
            train_data.reset()
            for batch in train_data:
                real_in = batch.data[0].as_in_context(ctx)
                real_out = batch.data[1].as_in_context(ctx)
                # The forward pass must be recorded for backward() to work.
                with autograd.record():
                    fake_out = netG(real_in)
                    loss = rec_loss(real_out, fake_out)
                    loss.backward()
                trainerG.step(batch.data[0].shape[0])
                metric2.update([real_out, ], [fake_out, ])
            train_data.reset()
            if epoch % 10 == 0:
                # Checkpoint and visualise the last batch once per epoch.
                netG.save_params(prefix + "_G.params")
                visual.visualize(
                    nd.concat(*_panels(real_in, real_out, fake_out), dim=2))
                plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
            _, train_mse = metric2.get()
            metric2.reset()
            print("training mse: " + str(train_mse))
            continue

        tic = time.time()
        btic = time.time()
        train_data.reset()
        if epoch > 250:
            # Linear learning-rate decay after epoch 250. Float division keeps
            # the decay working under Python 2; the clamp stops the rate from
            # turning negative once epoch exceeds 1250.
            decay = max(0.0, 1.0 - (epoch - 250) / 1000.0)
            trainerD.set_learning_rate(dlr * decay)
            trainerG.set_learning_rate(glr * decay)

        for step, batch in enumerate(train_data):
            batch_n = batch.data[0].shape[0]
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)

            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            # Generated outside autograd.record(): no gradient reaches netG here.
            fake_out = netG(real_in)
            fake_concat = nd.concat(real_in, fake_out,
                                    dim=1) if append else fake_out
            with autograd.record():
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                real_concat = nd.concat(real_in, real_out,
                                        dim=1) if append else real_out
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([real_label, ], [output, ])
            trainerD.step(batch_n)

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L2(y, G(x, z))
            ###########################
            with autograd.record():
                fake_out = netG(real_in)
                fake_concat = nd.concat(real_in, fake_out,
                                        dim=1) if append else fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errR = rec_loss(real_out, fake_out)
                errG = GAN_loss(output, real_label) + errR * lambda1
                errG.backward()
            trainerG.step(batch_n)

            errD_val = nd.mean(errD).asscalar()
            errG_val = nd.mean(errG).asscalar()
            errR_val = nd.mean(errR).asscalar()
            loss_rec_G.append(errG_val - errR_val * lambda1)
            loss_rec_D.append(errD_val)
            loss_rec_R.append(errR_val)
            name, acc = metric.get()
            acc_rec.append(acc)
            # Log every five batches.
            if step % 5 == 0:
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
                logging.info(
                    'discriminator loss = %f, generator loss = %f, binary training acc = %f reconstruction error= %f at iter %d epoch %d'
                    % (errD_val, errG_val, acc, errR_val, step, epoch))
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        train_data.reset()

        logging.info('\nbinary training acc at epoch %d: %s=%f' %
                     (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 5 == 0:
            netD.save_params(prefix + "_D.params")
            netG.save_params(prefix + "_G.params")
            # Panels from the last training batch; built before the validation
            # loop below rebinds real_in / real_out / fake_out.
            train_panels = _panels(real_in, real_out, fake_out)
            val_data.reset()
            for vbatch in val_data:
                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)
                fake_out = netG(real_in)
                metric2.update([fake_out, ], [real_out, ])
            _, val_mse = metric2.get()
            metric2.reset()
            # One line per validation: epoch, last training reconstruction
            # loss, validation MSE.
            with open(expname + "_validtest.txt", "a") as text_file:
                text_file.write("%s %s %s\n" %
                                (str(epoch), errR_val, str(val_mse)))
            val_panels = _panels(real_in, real_out, fake_out)
            visual.visualize(nd.concat(*(train_panels + val_panels), dim=2))
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')

    return [loss_rec_D, loss_rec_G, loss_rec_R, acc_rec]
示例#2
0
文件: threeway.py 项目: zkzt/OCGAN
def train(pool_size, epochs, train_data, ctx, netEn, netDe, netD, trainerEn, trainerDe, trainerD, lambda1, batch_size, expname):
    """Adversarially train an encoder/decoder against a three-way discriminator.

    Each batch first updates ``netD`` with softmax cross-entropy over three
    classes: 0 for reconstructions ``netDe(netEn(x))``, 1 for real targets,
    and 2 for images decoded from latent codes drawn uniformly in [-1, 1].
    It then updates ``netEn``/``netDe`` so reconstructions are scored as
    class 1 while staying close (L1, weighted by ``lambda1``) to the target.

    Args:
        pool_size: Only used to build an image pool that is never queried.
        epochs: Number of passes over ``train_data``.
        train_data: Resettable iterator; ``batch.data[0]`` is the input and
            ``batch.data[1]`` the target.
        ctx: Context every batch is copied to.
        netEn, netDe: Encoder and decoder.
        netD: Three-way discriminator.
        trainerEn, trainerDe, trainerD: Trainers for the three networks.
        lambda1: Weight of the L1 reconstruction term in the generator loss.
        batch_size: Only used for the samples/s log line.
        expname: Prefix of checkpoint and output-image file names.

    Every 10 epochs the three networks are saved under ``checkpoints/`` and an
    input | target | reconstruction image is written to ``outputs/``.
    """
    ce_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()  # not used below
    l1_loss = gluon.loss.L1Loss()
    image_pool = imagePool.ImagePool(pool_size)  # not used below
    metric = mx.metric.CustomMetric(facc)

    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')  # not used below
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        epoch_start = time.time()
        batch_start = time.time()
        train_data.reset()
        step = 0
        for batch in train_data:
            n = batch.data[0].shape[0]
            x = batch.data[0].as_in_context(ctx)
            target = batch.data[1].as_in_context(ctx)

            # ---- (1) discriminator update: recon -> 0, real -> 1, noise -> 2 ----
            code = netEn(x)
            recon = netDe(code)
            with autograd.record():
                logits = netD(recon)
                lbl_recon = nd.zeros(logits.shape[0], ctx=ctx)
                loss_recon = ce_loss(logits, lbl_recon)
                metric.update([lbl_recon, ], [logits, ])

                logits = netD(target)
                lbl_real = nd.ones(logits.shape[0], ctx=ctx)
                loss_real = ce_loss(logits, lbl_real)
                metric.update([lbl_real, ], [logits, ])

                # Latent codes shaped like the encoder output, uniform in [-1, 1].
                noise = nd.random.uniform(-1, 1, code.shape, ctx=ctx)
                logits_ab = netD(netDe(noise))
                lbl_ab = 2 * nd.ones(logits_ab.shape[0], ctx=ctx)
                errD_ab = ce_loss(logits_ab, lbl_ab)
                errD = (loss_real + loss_recon + errD_ab) * 0.33
                errD.backward()
            trainerD.step(n)

            # ---- (2) encoder/decoder update: make reconstructions look real ----
            with autograd.record():
                recon = netDe(netEn(x))
                logits = netD(recon)
                lbl_real = nd.ones(logits.shape[0], ctx=ctx)
                errG = ce_loss(logits, lbl_real) + l1_loss(target, recon) * lambda1
                errR = l1_loss(target, recon)
                errG.backward()
            trainerEn.step(n)
            trainerDe.step(n)

            # Log every ten batches.
            if not step % 10:
                name, acc = metric.get()
                elapsed = time.time() - batch_start
                logging.info('speed: {} samples/s'.format(batch_size / elapsed))
                stats = (nd.mean(errD).asscalar(),
                         nd.mean(errG).asscalar(),
                         nd.mean(errD_ab).asscalar(),
                         acc,
                         nd.mean(errR).asscalar(),
                         step,
                         epoch)
                logging.info(
                    'discriminator loss = %f, generator loss = %f, latent error = %f,  binary training acc = %f, reconstruction error= %f at iter %d epoch %d'
                    % stats)
            step += 1
            batch_start = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - epoch_start))
        if not epoch % 10:
            prefix = "checkpoints/" + expname + "_" + str(epoch)
            for suffix, net in (("_D.params", netD),
                                ("_En.params", netEn),
                                ("_De.params", netDe)):
                net.save_params(prefix + suffix)
            # Input | target | reconstruction of the first sample of the last batch.
            visual.visualize(nd.concat(x[0], target[0], recon[0], dim=1))
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
示例#3
0
def train(pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD, netD2, trainerEn, trainerDe, trainerD, trainerD2, lambda1, batch_size, expname, append=True, useAE=False):
    """Train an encoder/decoder with an image discriminator and a latent discriminator.

    Per batch:
      1. ``netD`` learns to score real targets as 1 and images decoded from
         prior samples ``sigma**2 * N(0, 1)`` as 0. ``netD2`` learns to score
         prior samples as 1 and encoder outputs as 0.
      2. ``netEn``/``netDe`` minimise
         ``10 * BCE(netD2(netEn(x)), 1) + BCE(netD(netDe(prior)), 1)
         + lambda1 * L2(target, netDe(netEn(x))) + mean(sigma**2)``.
      3. After epoch 50, ``sigma`` takes a manual SGD step.

    Each epoch appends the last batch's losses and accuracies to
    ``<expname>_trainloss.txt``. Every 10 epochs all four networks are saved
    under ``checkpoints/``, the validation MSE is appended to
    ``<expname>_validtest.txt`` and sample images are written to ``outputs/``.

    Args:
        pool_size, useAE: Accepted for interface compatibility; not used.
        epochs: Number of passes over ``train_data``.
        train_data, val_data: Resettable iterators; ``batch.data[0]`` is the
            input and ``batch.data[1]`` the target.
        ctx: Context every batch and random tensor is placed on.
        netEn, netDe: Encoder and decoder.
        netD: Image discriminator. netD2: latent discriminator.
        trainerEn, trainerDe, trainerD, trainerD2: Trainers for the networks.
        lambda1: Weight of the L2 reconstruction term.
        batch_size: Used for the samples/s log line and the size of ``mu``.
        expname: Prefix of checkpoint, image and log file names.
        append: Concatenate the input with the image (axis 1) before ``netD``.
            NOTE(review): ``netD`` is also fed decoded prior samples without
            that concatenation, so this presumably only works with
            ``append=False``. Confirm.

    Returns:
        ``[loss_rec_D, loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2,
        loss_rec_G2, acc2_rec]``: per-batch histories. ``loss_rec_G`` is the
        generator loss excluding the prior-image term and the reconstruction
        term.
    """
    # Create/truncate the per-experiment log files.
    with open(expname + "_trainloss.txt", "w"):
        pass
    with open(expname + "_validtest.txt", "w"):
        pass
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    rec_loss = gluon.loss.L2Loss()  # reconstruction loss (squared error)
    metric = mx.metric.CustomMetric(facc)   # netD accuracy
    metric2 = mx.metric.CustomMetric(facc)  # netD2 accuracy
    metricMSE = mx.metric.MSE()             # validation reconstruction error
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    acc2_rec = []
    loss_rec_D2 = []
    loss_rec_G2 = []
    sigma_lr = 0.002  # step size of the manual SGD update on sigma
    # NOTE(review): mu is not read anywhere in this function.
    # Integer division keeps the shape integral on Python 3 as well.
    mu = nd.random.uniform(low=-1, high=1, shape=(batch_size // 2, 64, 1, 1), ctx=ctx)
    sigma = nd.ones((64, 1, 1), ctx=ctx)
    mu.attach_grad()
    sigma.attach_grad()
    logging.basicConfig(level=logging.DEBUG)

    def _prior_sample(shape):
        """Draw a latent prior sample sigma**2 * N(0, 1) of the given shape."""
        return nd.multiply(nd.power(sigma, 2),
                           nd.random_normal(loc=0, scale=1, shape=shape, ctx=ctx))

    def _panels(first, second, third):
        """Return 3 images, image i being first[i] | second[i] | third[i] on axis 1."""
        return [nd.concat(first[i], second[i], third[i], dim=1) for i in range(3)]

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        for step, batch in enumerate(train_data):
            batch_n = batch.data[0].shape[0]
            ############################
            # (1) Update D and D2.
            ###########################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            fake_latent = netEn(real_in)
            real_latent = _prior_sample(fake_latent.shape)
            fake_out = netDe(fake_latent)
            fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
            with autograd.record():
                output = netD(fake_concat)
                output2 = netD2(fake_latent)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                fake_latent_label = nd.zeros(output2.shape, ctx=ctx)
                eps2 = _prior_sample(fake_latent.shape)
                rec_output = netD(netDe(eps2))
                # netD's "fake" loss uses decoded prior samples; reconstructions
                # only feed the accuracy metric.
                errD_fake = GAN_loss(rec_output, fake_label)
                errD2_fake = GAN_loss(output2, fake_latent_label)
                metric.update([fake_label, ], [output, ])
                metric2.update([fake_latent_label, ], [output2, ])
                real_concat = nd.concat(real_in, real_out, dim=1) if append else real_out
                output = netD(real_concat)
                output2 = netD2(real_latent)
                real_label = nd.ones(output.shape, ctx=ctx)
                real_latent_label = nd.ones(output2.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD2_real = GAN_loss(output2, real_latent_label)
                errD = (errD_real + errD_fake) * 0.5
                errD2 = (errD2_real + errD2_fake) * 0.5
                totalerrD = errD + errD2
                totalerrD.backward()
                metric.update([real_label, ], [output, ])
                metric2.update([real_latent_label, ], [output2, ])
            trainerD.step(batch_n)
            trainerD2.step(batch_n)
            ############################
            # (2) Update encoder/decoder.
            ###########################
            with autograd.record():
                eps2 = _prior_sample(fake_latent.shape)
                rec_output = netD(netDe(eps2))
                fake_latent = netEn(real_in)
                output2 = netD2(fake_latent)
                fake_out = netDe(fake_latent)
                fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                real_latent_label = nd.ones(output2.shape, ctx=ctx)
                errG2 = GAN_loss(rec_output, real_label)
                errR = rec_loss(real_out, fake_out) * lambda1
                errG = 10.0 * GAN_loss(output2, real_latent_label) + errG2 + errR + nd.mean(nd.power(sigma, 2))
                errG.backward()
            if epoch > 50:
                # Manual SGD step on sigma using the gradient from errG.backward().
                sigma -= sigma_lr / sigma.shape[0] * sigma.grad
                print(sigma)
            trainerDe.step(batch_n)
            trainerEn.step(batch_n)

            errG_val = nd.mean(errG).asscalar()
            errG2_val = nd.mean(errG2).asscalar()
            errR_val = nd.mean(errR).asscalar()
            errD_val = nd.mean(errD).asscalar()
            errD2_val = nd.mean(errD2).asscalar()
            loss_rec_G2.append(errG2_val)
            loss_rec_G.append(errG_val - errG2_val - errR_val)
            loss_rec_D.append(errD_val)
            loss_rec_R.append(errR_val)
            loss_rec_D2.append(errD2_val)
            _, acc2 = metric2.get()
            name, acc = metric.get()
            acc_rec.append(acc)
            acc2_rec.append(acc2)

            # Log every ten batches.
            if step % 10 == 0:
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator loss = %f, D2 loss = %f, generator loss = %f, G2 loss = %f,  binary training acc = %f , D2 acc = %f, reconstruction error= %f  at iter %d epoch %d'
                             % (errD_val, errD2_val, loss_rec_G[-1], errG2_val, acc, acc2, errR_val, step, epoch))
            btic = time.time()

        name, acc = metric.get()
        _, acc2 = metric2.get()
        # Last-batch losses and epoch accuracies, space separated.
        with open(expname + "_trainloss.txt", "a") as tp_file:
            tp_file.write(" ".join(str(v) for v in (
                errG2_val, loss_rec_G[-1], errD_val, errD2_val, errR_val, acc, acc2)) + "\n")
        metric.reset()
        metric2.reset()
        train_data.reset()

        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:
            prefix = "checkpoints/" + expname + "_" + str(epoch)
            netD.save_params(prefix + "_D.params")
            netD2.save_params(prefix + "_D2.params")
            netEn.save_params(prefix + "_En.params")
            netDe.save_params(prefix + "_De.params")
            # Panels from the last training batch; built before the validation
            # loop rebinds real_in / real_out / fake_out.
            train_panels = _panels(real_in, real_out, fake_out)
            val_data.reset()
            for vbatch in val_data:
                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)
                fake_out = netDe(netEn(real_in))
                metricMSE.update([fake_out, ], [real_out, ])
            _, val_mse = metricMSE.get()
            with open(expname + "_validtest.txt", "a") as text_file:
                text_file.write("%s %s %s\n" % (str(epoch), errR_val, str(val_mse)))
            metricMSE.reset()

            # 3x3 grid of images decoded from the last prior sample eps2.
            images = netDe(eps2)
            sample_rows = [nd.concat(images[3 * r], images[3 * r + 1], images[3 * r + 2], dim=1)
                           for r in range(3)]
            samples_img = nd.concat(*sample_rows, dim=2)
            visual.visualize(samples_img)
            plt.savefig('outputs/' + expname + '_fakes_' + str(epoch) + '.png')

            val_panels = _panels(real_in, real_out, fake_out)
            visual.visualize(nd.concat(*(train_panels + val_panels), dim=2))
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')

            # '_fakespost_' is written from the same samples as '_fakes_'
            # (no sampler update runs in between).
            visual.visualize(samples_img)
            plt.savefig('outputs/' + expname + '_fakespost_' + str(epoch) + '.png')
    return [loss_rec_D, loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2, loss_rec_G2, acc2_rec]
示例#4
0
def train():
    """Train the module-level pix2pix-style pair ``netG``/``netD``.

    Takes no arguments: it reads ``pool_size``, ``epochs``, ``train_data``,
    ``ctx``, ``netG``, ``netD``, ``trainerG``, ``trainerD``, ``GAN_loss``,
    ``L1_loss``, ``lambda1`` and ``batch_size`` from the enclosing module.
    The discriminator sees (input, image) pairs concatenated on axis 1, and
    generated pairs are passed through an image pool before reaching it.
    Every 10 epochs both networks are saved as ``checkpoints/testnet_*`` and
    the first generated image of the last batch is written to ``outputs/``.
    """
    pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)

    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')  # not used below
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        epoch_start = time.time()
        batch_start = time.time()
        train_data.reset()
        step = 0
        for batch in train_data:
            n = batch.data[0].shape[0]
            x = batch.data[0].as_in_context(ctx)
            y = batch.data[1].as_in_context(ctx)

            # ---- (1) discriminator: real pairs -> 1, pooled fake pairs -> 0 ----
            y_hat = netG(x)
            pooled_fake = pool.query(nd.concat(x, y_hat, dim=1))
            with autograd.record():
                d_out = netD(pooled_fake)
                zeros = nd.zeros(d_out.shape, ctx=ctx)
                errD_fake = GAN_loss(d_out, zeros)
                metric.update([zeros, ], [d_out, ])

                d_out = netD(nd.concat(x, y, dim=1))
                ones = nd.ones(d_out.shape, ctx=ctx)
                errD_real = GAN_loss(d_out, ones)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([ones, ], [d_out, ])
            trainerD.step(n)

            # ---- (2) generator: fool netD while staying close to the target ----
            with autograd.record():
                y_hat = netG(x)
                d_out = netD(nd.concat(x, y_hat, dim=1))
                ones = nd.ones(d_out.shape, ctx=ctx)
                errG = GAN_loss(d_out, ones) + L1_loss(y, y_hat) * lambda1
                errG.backward()
            trainerG.step(n)

            # Log every ten batches.
            if not step % 10:
                name, acc = metric.get()
                elapsed = time.time() - batch_start
                logging.info('speed: {} samples/s'.format(batch_size / elapsed))
                stats = (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(),
                         acc, step, epoch)
                logging.info(
                    'discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                    % stats)
            step += 1
            batch_start = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' %
                     (epoch, name, acc))
        logging.info('time: %f' % (time.time() - epoch_start))
        if not epoch % 10:
            for suffix, net in (("_D.params", netD), ("_G.params", netG)):
                net.save_params("checkpoints/testnet_" + str(epoch) + suffix)
            visual.visualize(y_hat[0])
            plt.savefig('outputs/testnet_' + str(epoch) + '.png')
示例#5
0
def train(pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD,
          trainerEn, trainerDe, trainerD, lambda1, batch_size, expname,
          append=True, useAE=False):
    """Adversarially train a VAE-style encoder/decoder against ``netD``.

    Every batch performs two updates:

    1. ``netD`` learns to label ``(real_in, real_out)`` as real and
       ``(real_in, netDe(z))`` as fake, where ``z = mu + exp(lv / 2) * eps``
       is sampled from the encoder output split into ``(mu, lv)``.
    2. ``netEn``/``netDe`` learn to fool ``netD`` plus ``lambda1`` times the
       negative ELBO (Bernoulli reconstruction log-likelihood + KL term).

    Every 10 epochs the three networks are checkpointed under
    ``checkpoints/``. The validation reconstruction MSE is appended to
    ``<expname>_validtest.txt``. A grid of train/validation reconstructions
    is saved under ``outputs/``.

    Args:
        pool_size: Unused (kept for signature compatibility).
        epochs: Number of epochs to run, starting at epoch 0.
        train_data, val_data: Resettable iterators whose batches carry the
            network input in ``data[0]`` and the target in ``data[1]``.
        ctx: MXNet context the batches are copied to.
        netEn, netDe, netD: Encoder, decoder and image discriminator.
        trainerEn, trainerDe, trainerD: Trainers stepping the matching net.
        lambda1: Weight of the VAE loss in the generator objective.
        batch_size: Nominal batch size, used only for the throughput log.
        expname: Prefix of checkpoint, log and image file names.
        append: If True, ``netD`` sees the input concatenated with the
            real/fake output along axis 1; otherwise only the output.
        useAE: Unused.

    Returns:
        ``[loss_rec_D, loss_rec_G, loss_rec_R, acc_rec]`` recorded once per
        training batch.
    """
    # Start a fresh validation log; it is appended to every 10 epochs.
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.MSE()
    soft_zero = 1e-10  # keeps log()/exp() arguments away from exactly 0
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iteration = 0
        for batch in train_data:
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            num_samples = batch.data[0].shape[0]

            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            # The fake sample is generated outside autograd.record(), so the
            # D update does not propagate gradients into netEn/netDe.
            fake_latent = np.squeeze(netEn(real_in))
            mu, lv = nd.split(fake_latent, axis=1, num_outputs=2)
            # eps follows mu's shape rather than a hard-coded
            # (batch_size, 2048), so a short final batch still broadcasts.
            eps = nd.random_normal(loc=0, scale=1, shape=mu.shape, ctx=ctx)
            z = mu + nd.exp(0.5 * lv) * eps
            z = nd.expand_dims(nd.expand_dims(z, 2), 2)
            fake_out = netDe(z)
            fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
            with autograd.record():
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                real_concat = nd.concat(real_in, real_out, dim=1) if append else real_out
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([real_label, ], [output, ])
            trainerD.step(num_samples)

            ############################
            # (2) Update G network (netEn + netDe): fool D and minimize the
            # lambda1-weighted negative ELBO.
            ###########################
            with autograd.record():
                fake_latent = np.squeeze(netEn(real_in))
                mu, lv = nd.split(fake_latent, axis=1, num_outputs=2)
                KL = 0.5 * nd.nansum(1 + lv - mu * mu - nd.exp(lv + soft_zero))
                eps = nd.random_normal(loc=0, scale=1, shape=mu.shape, ctx=ctx)
                z = mu + nd.exp(0.5 * lv) * eps
                z = nd.expand_dims(nd.expand_dims(z, 2), 2)
                fake_out = netDe(z)
                # (x + 1) / 2 maps both tensors into [0, 1] for the Bernoulli
                # log-likelihood; assumes tanh-range [-1, 1] data (TODO confirm).
                real01 = 0.5 * (real_in + 1)
                fake01 = 0.5 * (fake_out + 1)
                logloss = nd.nansum(real01 * nd.log(fake01 + soft_zero)
                                    + (1 - real01) * nd.log(1 - fake01 + soft_zero))
                loss = -logloss - KL
                fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG = GAN_loss(output, real_label) + loss * lambda1
                errR = logloss
                errG.backward()
            # Step encoder/decoder on every batch, right after its backward pass.
            trainerDe.step(num_samples)
            trainerEn.step(num_samples)

            # NOTE(review): errG already contains -lambda1 * logloss, so
            # subtracting lambda1 * errR does not isolate the GAN term; kept as-is.
            loss_rec_G.append(nd.mean(errG).asscalar() - nd.mean(errR).asscalar() * lambda1)
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR).asscalar())
            name, acc = metric.get()
            acc_rec.append(acc)
            # Print log information every ten batches
            if iteration % 10 == 0:
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info(
                    'discriminator loss = %f, generator loss = %f, binary training acc = %f reconstruction error= %f at iter %d epoch %d'
                    % (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc,
                       nd.mean(errR).asscalar(), iteration, epoch))
            iteration = iteration + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        train_data.reset()

        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 10 == 0:
            prefix = "checkpoints/" + expname + "_" + str(epoch)
            netD.save_params(prefix + "_D.params")
            netEn.save_params(prefix + "_En.params")
            netDe.save_params(prefix + "_De.params")
            # Last training batch: input | target | reconstruction, per sample.
            fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            val_data.reset()
            for vbatch in val_data:
                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)
                # Unlike training, the encoder output is not squeezed here, so
                # z reaches netDe with the encoder's output rank.
                mu, lv = nd.split(netEn(real_in), axis=1, num_outputs=2)
                eps = nd.random_normal(loc=0, scale=1, shape=mu.shape, ctx=ctx)
                fake_out = netDe(mu + nd.exp(0.5 * lv) * eps)
                metric2.update([fake_out, ], [real_out, ])
            # Read once after the loop; yields NaN (not NameError) if val_data is empty.
            _, acc2 = metric2.get()
            metric2.reset()
            with open(expname + "_validtest.txt", "a") as text_file:
                text_file.write("%s %s %s\n" % (str(epoch), nd.mean(errR).asscalar(), str(acc2)))

            # Last validation batch, same layout as the training rows above.
            fake_img1T = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2T = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3T = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            fake_img = nd.concat(fake_img1, fake_img2, fake_img3,
                                 fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/' + expname + '_' + str(epoch) + '.png')
    return [loss_rec_D, loss_rec_G, loss_rec_R, acc_rec]
示例#6
0
def train(cep,
          pool_size,
          epochs,
          train_data,
          val_data,
          ctx,
          netEn,
          netDe,
          netD,
          netD2,
          netDS,
          trainerEn,
          trainerDe,
          trainerD,
          trainerD2,
          trainerSD,
          lambda1,
          batch_size,
          expname,
          append=True,
          useAE=False):
    """Train an encoder/decoder against image, latent and "strong" discriminators.

    Per batch:

    1. ``netD`` (images) and ``netD2`` (latents) are updated. Real images and
       uniform(-1, 1) latents are labelled 1. Decoded random latents and
       encoded latents are labelled 0.
    2. ``netDS`` is updated to label decoded random latents 0 and
       reconstructions of real inputs 1.
    3. ``netEn``/``netDe`` are updated to fool ``netD2`` on encoded latents
       (weight 10) and ``netD`` on decoded latents ``eps2``. They also get
       ``lambda1`` times an L2 reconstruction loss.

    After epoch 150, ``eps2`` is first "mined". One gradient step on the
    latent lowers ``netDS``'s loss against label 0 on the decoded image, and
    the result is passed through ``tanh``.

    Every 2 epochs all five networks are checkpointed under ``checkpoints/``.
    The validation reconstruction MSE is then appended to
    ``<expname>_validtest.txt``.

    Args:
        cep: Epoch of the checkpoint to resume from, or -1 to start fresh.
            Training runs epochs ``max(cep, 0) + 1 .. epochs - 1``.
        pool_size, useAE: Unused (kept for signature compatibility).
        epochs: Exclusive upper bound of the epoch range.
        train_data, val_data: Resettable iterators whose batches carry the
            input in ``data[0]`` and the target in ``data[1]``.
        ctx: List of contexts each batch is split across (``split_and_load``).
        netEn, netDe: Encoder and decoder.
        netD, netD2, netDS: Image, latent and strong discriminators.
        trainerEn, trainerDe, trainerD, trainerD2, trainerSD: Matching trainers.
        lambda1: Weight of the reconstruction loss.
        batch_size: Fixes the shape of sampled latents and labels, so every
            batch is assumed to contain exactly this many samples.
        expname: Prefix of checkpoint, log and image file names.
        append: If True, images given to ``netD``/``netDS`` are the input
            concatenated with the real/fake output along axis 1.

    Returns:
        ``[loss_rec_D, loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2,
        loss_rec_G2, acc2_rec]`` recorded once per batch. Losses are taken
        from the first device.
    """
    # Start fresh loss/validation logs for this experiment.
    with open(expname + "_trainloss.txt", "w"):
        pass
    with open(expname + "_validtest.txt", "w"):
        pass
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()  # despite the name, an L2 loss
    metric = mx.metric.CustomMetric(facc)  # netD accuracy
    metric2 = mx.metric.CustomMetric(facc)  # netD2 accuracy
    metricStrong = mx.metric.CustomMetric(facc)  # netDS accuracy
    metricMSE = mx.metric.MSE()  # validation reconstruction error
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    acc2_rec = []
    loss_rec_D2 = []
    loss_rec_G2 = []
    lr = 2.0 * 512  # step size of the latent-mining update
    logging.basicConfig(level=logging.DEBUG)

    # ctx is a list of contexts (split_and_load needs one). Full-batch tensors
    # that are split afterwards are created on its first entry; comparing the
    # list itself with mx.cpu() would never match.
    ct = ctx[0] if isinstance(ctx, (list, tuple)) else ctx
    # Fixed shapes of sampled latents and discriminator labels; presumably
    # they match netEn's code and the discriminator outputs - TODO confirm.
    latent_shape = (batch_size, 512, 1, 1)
    out_l_shape = (batch_size, 1, 1, 1)  # netD2 output
    out_i_shape = (batch_size, 1, 1, 1)  # netD output
    out_s_shape = (batch_size, 1, 1, 1)  # netDS output

    def _concat_pairs(left, right):
        """Concatenate two per-device lists element-wise along axis 1."""
        return [nd.concat(a, b, dim=1) for a, b in zip(left, right)]

    def _save_grid(images, path):
        """Tile images[0..8] into a 3x3 grid, visualize it and save it to path."""
        rows = [nd.concat(images[i], images[i + 1], images[i + 2], dim=1)
                for i in (0, 3, 6)]
        visual.visualize(nd.concat(*rows, dim=2))
        plt.savefig(path)

    if cep == -1:
        cep = 0
    else:
        # Resume from checkpoints named like the ones saved in this function.
        prefix = 'checkpoints/' + expname + '_' + str(cep)
        netEn.load_parameters(prefix + '_En.params', ctx=ctx)
        netDe.load_parameters(prefix + '_De.params', ctx=ctx)
        netD.load_parameters(prefix + '_D.params', ctx=ctx)
        netD2.load_parameters(prefix + '_D2.params', ctx=ctx)
        netDS.load_parameters(prefix + '_SD.params', ctx=ctx)

    iteration = 0
    for epoch in range(cep + 1, epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        for batch in train_data:
            num_samples = batch.data[0].shape[0]
            real_in = gluon.utils.split_and_load(batch.data[0], ctx)
            real_out = gluon.utils.split_and_load(batch.data[1], ctx)
            fake_latent = [netEn(r) for r in real_in]
            real_latent = gluon.utils.split_and_load(
                nd.random.uniform(low=-1, high=1, shape=latent_shape), ctx)
            fake_out = [netDe(f) for f in fake_latent]
            # real_in/fake_out are per-device lists: concatenate pairwise.
            fake_concat = _concat_pairs(real_in, fake_out) if append else fake_out
            eps2 = gluon.utils.split_and_load(
                nd.random.uniform(low=-1, high=1, shape=latent_shape, ctx=ct),
                ctx)

            if epoch > 150:
                # Latent mining: one gradient step on fresh random latents
                # that lowers netDS's loss against label 0, then tanh.
                print('Mining..')
                mu = nd.random.uniform(low=-1, high=1, shape=latent_shape, ctx=ct)
                mu.attach_grad()
                _save_grid(netDe(mu),
                           'outputs/' + expname + '_fakespre_' + str(epoch) + '.png')
                eps2 = gluon.utils.split_and_load(mu, ctx)
                for e in eps2:
                    e.attach_grad()
                for ep2 in range(1):
                    with autograd.record():
                        rec_output = [netDS(netDe(e)) for e in eps2]
                        fake_label = gluon.utils.split_and_load(
                            nd.zeros(out_s_shape), ctx)
                        errGS = [
                            GAN_loss(r, f)
                            for r, f in zip(rec_output, fake_label)
                        ]
                        for e in errGS:
                            e.backward()
                    eps2 = [nd.tanh(e - lr / e.shape[0] * e.grad) for e in eps2]
                _save_grid(netDe(eps2[0]),
                           'outputs/' + expname + str(ep2) + '_fakespost_' +
                           str(epoch) + '.png')

            ############################
            # (1) Update D (images) and D2 (latents)
            ###########################
            with autograd.record():
                # Forward pass kept (it runs in training mode), but its output
                # is not part of the D loss; only random-latent fakes are.
                output = [netD(f) for f in fake_concat]
                output2 = [netD2(f) for f in fake_latent]
                fake_label = gluon.utils.split_and_load(nd.zeros(out_i_shape), ctx)
                fake_latent_label = gluon.utils.split_and_load(
                    nd.zeros(out_l_shape), ctx)
                eps = gluon.utils.split_and_load(
                    nd.random.uniform(low=-1, high=1, shape=latent_shape), ctx)
                rec_output = [netD(netDe(e)) for e in eps]
                errD_fake = [
                    GAN_loss(r, f) for r, f in zip(rec_output, fake_label)
                ]
                errD2_fake = [
                    GAN_loss(o, f) for o, f in zip(output2, fake_latent_label)
                ]
                for f, o in zip(fake_label, rec_output):
                    metric.update([f], [o])
                for f, o in zip(fake_latent_label, output2):
                    metric2.update([f], [o])
                real_concat = _concat_pairs(real_in, real_out) if append else real_out
                output = [netD(r) for r in real_concat]
                output2 = [netD2(r) for r in real_latent]
                real_label = gluon.utils.split_and_load(nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(
                    nd.ones(out_l_shape), ctx)
                errD_real = [
                    GAN_loss(o, r) for o, r in zip(output, real_label)
                ]
                errD2_real = [
                    GAN_loss(o, r) for o, r in zip(output2, real_latent_label)
                ]
                for e1, e2, e4, e5 in zip(errD_real, errD_fake, errD2_real,
                                          errD2_fake):
                    err = (e1 + e2) * 0.5 + (e5 + e4) * 0.5
                    err.backward()
                for f, o in zip(real_label, output):
                    metric.update([f], [o])
                for f, o in zip(real_latent_label, output2):
                    metric2.update([f], [o])
            trainerD.step(num_samples)
            trainerD2.step(num_samples)
            nd.waitall()

            ############################
            # (2) Update the strong discriminator DS
            ###########################
            with autograd.record():
                strong_output = [netDS(netDe(e)) for e in eps]
                strong_real = [netDS(f) for f in fake_concat]
                errs1 = [
                    GAN_loss(r, f) for r, f in zip(strong_output, fake_label)
                ]
                errs2 = [
                    GAN_loss(r, f) for r, f in zip(strong_real, real_label)
                ]
                for f, s in zip(fake_label, strong_output):
                    metricStrong.update([f], [s])
                for f, s in zip(real_label, strong_real):
                    metricStrong.update([f], [s])
                for e1, e2 in zip(errs1, errs2):
                    strongerr = 0.5 * (e1 + e2)
                    strongerr.backward()
            trainerSD.step(num_samples)
            nd.waitall()

            ############################
            # (3) Update G (netEn + netDe)
            ###########################
            with autograd.record():
                rec_output = [netD(netDe(e)) for e in eps2]
                fake_latent = [netEn(r) for r in real_in]
                output2 = [netD2(f) for f in fake_latent]
                fake_out = [netDe(f) for f in fake_latent]
                fake_concat = _concat_pairs(real_in, fake_out) if append else fake_out
                # Forward pass kept; `output` is not used in the G loss.
                output = [netD(f) for f in fake_concat]
                real_label = gluon.utils.split_and_load(nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(
                    nd.ones(out_l_shape), ctx)
                errG2 = [
                    GAN_loss(r, f) for r, f in zip(rec_output, real_label)
                ]
                errR = [
                    L1_loss(r, f) * lambda1
                    for r, f in zip(real_out, fake_out)
                ]
                errG = [
                    10 * GAN_loss(r, f)
                    for r, f in zip(output2, real_latent_label)
                ]
                for e1, e2, e3 in zip(errG, errG2, errR):
                    e = e1 + e2 + e3
                    e.backward()
            trainerDe.step(num_samples)
            trainerEn.step(num_samples)
            nd.waitall()

            errD = (errD_real[0] + errD_fake[0]) * 0.5
            errD2 = (errD2_real[0] + errD2_fake[0]) * 0.5
            loss_rec_G2.append(nd.mean(errG2[0]).asscalar())
            # NOTE(review): errG holds only the latent term; errG2 and errR
            # are separate losses, so this is not a decomposition of errG.
            loss_rec_G.append(
                nd.mean(errG[0]).asscalar() -
                nd.mean(errG2[0]).asscalar() - nd.mean(errR[0]).asscalar())
            # Batch means, like the other records (errD[0] would be one sample).
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR[0]).asscalar())
            loss_rec_D2.append(nd.mean(errD2).asscalar())
            _, acc2 = metric2.get()
            name, acc = metric.get()
            acc_rec.append(acc)
            acc2_rec.append(acc2)

            # Log throughput every ten batches.
            if iteration % 10 == 0:
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
            # Advance on every batch so the every-ten-batches schedule repeats.
            iteration = iteration + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        metric2.reset()
        train_data.reset()
        metricStrong.reset()

        logging.info('\nbinary training acc at epoch %d: %s=%f' %
                     (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 2 == 0:
            prefix = "checkpoints/" + expname + "_" + str(epoch)
            netD.save_parameters(prefix + "_D.params")
            netD2.save_parameters(prefix + "_D2.params")
            netEn.save_parameters(prefix + "_En.params")
            netDe.save_parameters(prefix + "_De.params")
            netDS.save_parameters(prefix + "_SD.params")
            val_data.reset()
            for vbatch in val_data:
                real_in = gluon.utils.split_and_load(vbatch.data[0], ctx)
                real_out = gluon.utils.split_and_load(vbatch.data[1], ctx)
                fake_out = [netDe(netEn(r)) for r in real_in]
                for f, r in zip(fake_out, real_out):
                    metricMSE.update([f], [r])
            _, acc2 = metricMSE.get()
            toterrR = sum(nd.mean(e).asscalar() for e in errR)
            with open(expname + "_validtest.txt", "a") as text_file:
                text_file.write("%s %s %s\n" % (str(epoch), toterrR, str(acc2)))
            metricMSE.reset()
    return ([
        loss_rec_D, loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2, loss_rec_G2,
        acc2_rec
    ])