Example #1
    def forward(self, feat):
        square_sum = nd.sum(nd.square(feat), axis=self.axis, keepdims=True)
        inv_norm = nd.rsqrt(nd.maximum(square_sum, self.epsilon))
        l2_res = nd.multiply(feat, inv_norm)
        # print(l2_res.shape)
        return nd.multiply(l2_res.transpose([0, 2, 3, 1]),
                           self.scale.data()).transpose([0, 3, 1, 2])
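The forward pass above L2-normalizes the feature map along self.axis and then rescales it with a learned per-channel scale. A minimal standalone sketch of the same normalization arithmetic with plain mxnet.ndarray calls; the axis, epsilon value and input shape below are made up for illustration:

import mxnet as mx
from mxnet import nd

feat = nd.random.uniform(shape=(2, 8, 4, 4))             # hypothetical NCHW feature map
square_sum = nd.sum(nd.square(feat), axis=1, keepdims=True)
inv_norm = nd.rsqrt(nd.maximum(square_sum, 1e-10))        # epsilon guards against division by zero
l2_res = nd.multiply(feat, inv_norm)                      # (2, 1, 4, 4) broadcasts over the channel axis
print(nd.sum(nd.square(l2_res), axis=1)[0, 0, 0])         # ~1.0: unit L2 norm per spatial position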
Example #2
def perceptual_loss(X):
    # Compute ||X^T X||_* = tr(X^T X) = sum_ij X_ij^2 (for a PSD matrix the nuclear norm equals the trace)
    rank_est = nd.sum(nd.sum(nd.multiply(X, X), axis=2), axis=2)  # sum over the two trailing axes of a 4-D input
    rank_est = rank_est[0, 0]

    # Calculate gradient
    grad = nd.multiply(2.0, X)
    return (rank_est.asscalar(), grad)
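The scalar returned above is tr(XᵀX), i.e. the sum of squared entries over the two trailing axes, and its gradient with respect to X is 2X. A hedged usage sketch with a made-up 4-D input shape, assuming perceptual_loss as defined above:

from mxnet import nd

X = nd.random.uniform(shape=(1, 1, 3, 3))      # hypothetical (batch, channel, H, W) input
value, grad = perceptual_loss(X)
print(value, nd.sum(nd.square(X)).asscalar())  # the two numbers should match
print(grad.shape)                              # same shape as X; grad equals 2 * X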
Example #3
    def GRU_Cell(input, state):
        for x in input:
            z_t = nd.Activation(nd.FullyConnected(data=x, weight=wxz, no_bias=True, num_hidden=num_hidden) +
                                nd.FullyConnected(data=state, weight=whz, no_bias=True, num_hidden=num_hidden) + bz,
                                act_type="sigmoid")
            r_t = nd.Activation(nd.FullyConnected(data=x, weight=wxr, no_bias=True, num_hidden=num_hidden) +
                                nd.FullyConnected(data=state, weight=whr, no_bias=True, num_hidden=num_hidden) + br,
                                act_type="sigmoid")
            g_t = nd.Activation(nd.FullyConnected(data=x, weight=wxh, no_bias=True, num_hidden=num_hidden) +
                                nd.FullyConnected(data=r_t * state, weight=whh, no_bias=True, num_hidden=num_hidden) + bh,
                                act_type="tanh")

            state = nd.multiply(z_t, state) + nd.multiply(1 - z_t, g_t)

        output = nd.FullyConnected(data=state, weight=why, bias=by, num_hidden=num_outputs)
        output = nd.softmax(data=output)
        return output, state
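GRU_Cell assumes the gate parameters (wxz, whz, bz, wxr, whr, br, wxh, whh, bh), the output parameters (why, by) and the sizes num_hidden / num_outputs already exist in the enclosing scope. A minimal sketch of how such parameters could be allocated; the dimensions are made up, since the real ones are not shown in the example:

import mxnet as mx
from mxnet import nd

num_inputs, num_hidden, num_outputs = 28, 256, 10          # hypothetical sizes
ctx = mx.cpu()

def _param(*shape):
    return nd.random.normal(scale=0.01, shape=shape, ctx=ctx)

# update gate, reset gate and candidate-state parameters; FullyConnected expects
# weights of shape (num_hidden, input_dim), and the biases are added manually above
wxz, whz, bz = _param(num_hidden, num_inputs), _param(num_hidden, num_hidden), nd.zeros(num_hidden, ctx=ctx)
wxr, whr, br = _param(num_hidden, num_inputs), _param(num_hidden, num_hidden), nd.zeros(num_hidden, ctx=ctx)
wxh, whh, bh = _param(num_hidden, num_inputs), _param(num_hidden, num_hidden), nd.zeros(num_hidden, ctx=ctx)
# output layer parameters
why, by = _param(num_outputs, num_hidden), nd.zeros(num_outputs, ctx=ctx)

for p in [wxz, whz, bz, wxr, whr, br, wxh, whh, bh, why, by]:
    p.attach_grad()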
Example #4
def test_lab_to_rgb_np():
    cpu_context = mx.cpu()
    lab = rgb_to_lab(nd.array(test_image), ctx=cpu_context)
    rgb = lab_to_rgb(lab, ctx=cpu_context)
    rgb = nd.multiply(rgb, nd.array([256]))
    rgb = nd.cast(rgb, np.uint8)
    np.testing.assert_array_almost_equal(rgb.asnumpy(), test_image, 3)
Example #5
    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):

        pred = in_data[0]
        ll = in_data[1]

        out = nd.add(pred, ll)
        out = nd.divide(ll, out)
        # out = (ll/(pred+ll))**2
        out = - nd.multiply(out, out)
        self.assign(in_grad[0], req[0], out)
Example #6
def lab_to_rgb(lab, ctx=None):
    if ctx is None:
        raise ValueError("ctx can not be None")

    if lab is None:
        raise ValueError("lab can not be None")

    with mx.Context(ctx):
        lab = __check_image(lab)
        lab_pixels = lab.reshape([-1, 3])
        lab_to_fxfyfz = nd.array([
                #   fx      fy        fz
                [1 / 116.0, 1 / 116.0, 1 / 116.0],  # l
                [1 / 500.0, 0.0, 0.0],  # a
                [0.0, 0.0, -1 / 200.0],  # b
            ], ctx=ctx)
        fxfyfz_pixels = nd.dot(lab_pixels + nd.array([16.0, 0.0, 0.0], ctx=ctx), lab_to_fxfyfz)

        # convert to xyz
        epsilon = 6 / 29
        linear_mask = fxfyfz_pixels <= epsilon
        exponential_mask = fxfyfz_pixels > epsilon

        xyz_pixels = (3 * epsilon ** 2 * (fxfyfz_pixels - 4 / 29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

        xyz_pixels = nd.multiply(xyz_pixels, nd.array([0.950456, 1.0, 1.088754]))
        xyz_to_rgb = nd.array([
                #     r           g          b
                [3.2404542, -0.9692660, 0.0556434],  # x
                [-1.5371385, 1.8760108, -0.2040259],  # y
                [-0.4985314, 0.0415560, 1.0572252],  # z
        ])
        rgb_pixels = nd.dot(xyz_pixels, xyz_to_rgb)
        nd.clip(rgb_pixels, 0.0, 1.0, out=rgb_pixels)

        linear_mask = rgb_pixels <= 0.0031308
        exponential_mask = rgb_pixels > 0.0031308

        step1 = nd.multiply(nd.multiply(rgb_pixels, 12.92), linear_mask)
        step2 = nd.multiply(nd.multiply(nd.power(rgb_pixels, (1 / 2.4)), 1.055) - 0.055, exponential_mask)
        srgb_pixels = step1 + step2

        return srgb_pixels.reshape(lab.shape)
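The final block applies the linear-to-sRGB transfer function piecewise: values at or below 0.0031308 are scaled by 12.92, larger values go through the 1.055 * c**(1/2.4) - 0.055 gamma curve. A small standalone check of that piecewise curve on arbitrarily chosen values:

from mxnet import nd

c = nd.array([0.001, 0.01, 0.5, 1.0])          # arbitrary linear-light values
linear_mask = c <= 0.0031308
exponential_mask = c > 0.0031308
srgb = (nd.multiply(nd.multiply(c, 12.92), linear_mask) +
        nd.multiply(nd.multiply(nd.power(c, 1 / 2.4), 1.055) - 0.055, exponential_mask))
print(srgb)                                    # roughly [0.0129, 0.0998, 0.7354, 1.0]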
Example #7
    def LSTM_Cell(input, h_state, c_state):
        for x in input:
            f_t = nd.Activation(nd.FullyConnected(
                data=x, weight=wxhf, no_bias=True, num_hidden=num_hidden) +
                                nd.FullyConnected(data=h_state,
                                                  weight=whhf,
                                                  no_bias=True,
                                                  num_hidden=num_hidden) + bhf,
                                act_type="sigmoid")
            i_t = nd.Activation(nd.FullyConnected(
                data=x, weight=wxhi, no_bias=True, num_hidden=num_hidden) +
                                nd.FullyConnected(data=h_state,
                                                  weight=whhi,
                                                  no_bias=True,
                                                  num_hidden=num_hidden) + bhi,
                                act_type="sigmoid")
            o_t = nd.Activation(nd.FullyConnected(
                data=x, weight=wxho, no_bias=True, num_hidden=num_hidden) +
                                nd.FullyConnected(data=h_state,
                                                  weight=whho,
                                                  no_bias=True,
                                                  num_hidden=num_hidden) + bho,
                                act_type="sigmoid")
            g_t = nd.Activation(nd.FullyConnected(
                data=x, weight=wxhg, no_bias=True, num_hidden=num_hidden) +
                                nd.FullyConnected(data=h_state,
                                                  weight=whhg,
                                                  no_bias=True,
                                                  num_hidden=num_hidden) + bhg,
                                act_type="tanh")
            c_state = nd.multiply(f_t, c_state) + nd.multiply(i_t, g_t)
            h_state = nd.multiply(o_t, nd.tanh(c_state))

        output = nd.FullyConnected(data=h_state,
                                   weight=why,
                                   bias=by,
                                   num_hidden=num_outputs)
        output = nd.softmax(data=output)
        return output, h_state, c_state
Example #8
def rgb_to_lab(image_srgb, ctx=None):

    if ctx is None:
        raise ValueError("ctx can not be None")

    if image_srgb is None:
        raise ValueError("image_srgb can not be None")

    with mx.Context(ctx):

        srgb = __check_image(image_srgb)

        if nd.max(srgb).asscalar() > 1:
            srgb = __normalize_rgb_image(srgb)

        srgb_pixels = nd.reshape(srgb, [-1, 3])

        linear_mask = nd.cast(srgb_pixels <= 0.04045, dtype='float32')
        exponential_mask = nd.cast(srgb_pixels > 0.04045, dtype='float32')
        rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
        rgb_to_xyz = nd.array([
            #    X        Y          Z
            [0.412453, 0.212671, 0.019334],  # R
            [0.357580, 0.715160, 0.119193],  # G
            [0.180423, 0.072169, 0.950227],  # B
        ])
        xyz_pixels = nd.linalg_gemm2(rgb_pixels, rgb_to_xyz)

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
        # normalize for D65 white point
        xyz_normalized_pixels = nd.multiply(xyz_pixels, nd.array([1 / 0.950456, 1.0, 1 / 1.088754]))

        epsilon = 6 / 29
        linear_mask = nd.cast(xyz_normalized_pixels <= (epsilon ** 3), dtype='float32')
        exponential_mask = nd.cast(xyz_normalized_pixels > (epsilon ** 3), dtype='float32')
        fxfyfz_pixels = ((xyz_normalized_pixels / (3 * epsilon ** 2) + 4 / 29) * linear_mask +
                         (xyz_normalized_pixels ** (1 / 3)) * exponential_mask)

        # convert to lab
        fxfyfz_to_lab = nd.array([
            #  l       a       b
            [0.0, 500.0, 0.0],  # fx
            [116.0, -500.0, 200.0],  # fy
            [0.0, 0.0, -200.0],  # fz
        ])
        lab_pixels = nd.linalg_gemm2(fxfyfz_pixels, fxfyfz_to_lab) + nd.array([-16.0, 0.0, 0.0])

        return nd.reshape(lab_pixels, srgb.shape)
Example #9
    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
        if self.inference_update_stat:
            mean = x.mean(axis=(0, 2, 3))
            mean_expanded = F.expand_dims(F.expand_dims(F.expand_dims(mean,
                                                                      axis=0),
                                                        axis=2),
                                          axis=3)
            var = F.square(F.broadcast_minus(x,
                                             mean_expanded)).mean(axis=(0, 2,
                                                                        3))

            running_mean = F.add(
                F.multiply(self.running_mean.data(),
                           self.momentum.as_in_context(x.context)),
                F.multiply(mean, self.momentum_rest.as_in_context(x.context)))
            running_var = F.add(
                F.multiply(self.running_var.data(),
                           self.momentum.as_in_context(x.context)),
                F.multiply(var, self.momentum_rest.as_in_context(x.context)))
            self.running_mean.set_data(running_mean)
            self.running_var.set_data(running_var)
            return F.BatchNorm(x,
                               gamma,
                               beta,
                               mean,
                               var,
                               name='fwd',
                               **self._kwargs)
        else:
            return F.BatchNorm(x,
                               gamma,
                               beta,
                               running_mean,
                               running_var,
                               name='fwd',
                               **self._kwargs)
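When inference_update_stat is set, the running statistics are updated as a convex combination, running <- momentum * running + (1 - momentum) * batch_stat, which is what the F.add / F.multiply pairs compute, assuming self.momentum_rest holds 1 - momentum. A plain-ndarray sketch of that update rule with made-up numbers:

from mxnet import nd

momentum = 0.9                                   # hypothetical momentum value
running_mean = nd.zeros(3)
batch_mean = nd.array([0.2, -0.1, 0.5])

# running <- momentum * running + (1 - momentum) * batch statistic
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
print(running_mean)                              # [0.02, -0.01, 0.05]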
Example #10
def train(pool_size, epochs, train_data, val_data,  ctx, netEn, netDe,  netD, netD2, trainerEn, trainerDe, trainerD, trainerD2, lambda1, batch_size, expname,  append=True, useAE = False):
    tp_file = open(expname + "_trainloss.txt", "w")  
    tp_file.close()  
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1,opt.ndf,  opt.ngf, opt.append)
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.CustomMetric(facc)
    metricMSE = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    acc2_rec = []
    loss_rec_D2 = []
    loss_rec_G2 = []
    lr = 0.002
    #mu = nd.random_normal(loc=0, scale=1, shape=(batch_size/2,64,1,1), ctx=ctx) 
    mu = nd.random.uniform(low=-1, high=1, shape=(batch_size // 2, 64, 1, 1), ctx=ctx)
    #mu =  nd.zeros((batch_size/2,64,1,1),ctx=ctx)
    sigma = nd.ones((64,1,1),ctx=ctx)
    mu.attach_grad()
    sigma.attach_grad()    
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):

        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        #print('learning rate : '+str(trainerD.learning_rate ))
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            fake_latent= netEn(real_in)
            #real_latent = nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx)
            real_latent = nd.multiply(nd.power(sigma, 2), nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx))
            # nd.random.uniform(low=-1, high=1, shape=fake_latent.shape, ctx=ctx)
            fake_out = netDe(fake_latent)
            fake_concat =  nd.concat(real_in, fake_out, dim=1) if append else  fake_out
            with autograd.record():
                # Train with fake image
                # Use image pooling to utilize history imagesi
                output = netD(fake_concat)
                output2 = netD2(fake_latent)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                fake_latent_label = nd.zeros(output2.shape, ctx=ctx)
                noiseshape = (fake_latent.shape[0] // 2, fake_latent.shape[1], fake_latent.shape[2], fake_latent.shape[3])
                eps2 = nd.multiply(nd.power(sigma, 2), nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx))
                # eps2 = nd.random_normal(loc=0, scale=sigma.asscalar(), shape=fake_latent.shape, ctx=ctx)
                # eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                rec_output = netD(netDe(eps2))
                errD_fake = GAN_loss(rec_output, fake_label)
                errD_fake2 = GAN_loss(output, fake_label)
                errD2_fake = GAN_loss(output2, fake_latent_label)
                metric.update([fake_label, ], [output, ])
                metric2.update([fake_latent_label, ], [output2, ])
                real_concat =  nd.concat(real_in, real_out, dim=1) if append else  real_out
                output = netD(real_concat)
                output2 = netD2(real_latent)
                real_label = nd.ones(output.shape, ctx=ctx)
                real_latent_label =  nd.ones(output2.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD2_real =  GAN_loss(output2, real_latent_label)
                #errD = (errD_real + 0.5*(errD_fake+errD_fake2)) * 0.5
                errD = (errD_real + errD_fake) * 0.5
                errD2 = (errD2_real + errD2_fake) * 0.5
                totalerrD = errD + errD2
                totalerrD.backward()
                #errD2.backward()
                metric.update([real_label, ], [output, ])
                metric2.update([real_latent_label, ], [output2, ])
            trainerD.step(batch.data[0].shape[0])
            trainerD2.step(batch.data[0].shape[0])
            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ###########################
            with autograd.record():
                sh = fake_latent.shape
                eps2 = nd.multiply(nd.power(sigma, 2), nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx))
                # eps2 = nd.random_normal(loc=0, scale=sigma.asscalar(), shape=fake_latent.shape, ctx=ctx)
                # eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                rec_output = netD(netDe(eps2))
                fake_latent= (netEn(real_in))
                output2 = netD2(fake_latent)
                fake_out = netDe(fake_latent)
                fake_concat =  nd.concat(real_in, fake_out, dim=1) if append else  fake_out
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                real_latent_label = nd.ones(output2.shape, ctx=ctx)
                errG2 = GAN_loss(rec_output, real_label)
                errR = L1_loss(real_out, fake_out) * lambda1
                errG = 10.0 * GAN_loss(output2, real_latent_label) + errG2 + errR + nd.mean(nd.power(sigma, 2))
                errG.backward()
            if epoch > 50:
                sigma -= lr / sigma.shape[0] * sigma.grad
                print(sigma)
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            loss_rec_G2.append(nd.mean(errG2).asscalar())
            loss_rec_G.append(nd.mean(nd.mean(errG)).asscalar()-nd.mean(errG2).asscalar()-nd.mean(errR).asscalar())
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR).asscalar())
            loss_rec_D2.append(nd.mean(errD2).asscalar())
            _, acc2 = metric2.get()
            name, acc = metric.get()
            acc_rec.append(acc)
            acc2_rec.append(acc2)

            # Print log information every ten batches
            if iter % 10 == 0:
                _, acc2 = metric2.get()
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                #print(errD)
                logging.info('discriminator loss = %f, D2 loss = %f, generator loss = %f, G2 loss = %f, binary training acc = %f, D2 acc = %f, reconstruction error = %f at iter %d epoch %d'
                             % (nd.mean(errD).asscalar(), nd.mean(errD2).asscalar(),
                                nd.mean(errG - errG2 - errR).asscalar(), nd.mean(errG2).asscalar(), acc, acc2, nd.mean(errR).asscalar(), iter, epoch))
            iter = iter + 1
        btic = time.time()
        name, acc = metric.get()
        _, acc2 = metric2.get()
        tp_file = open(expname + "_trainloss.txt", "a")
        tp_file.write(str(nd.mean(errG2).asscalar()) + " " + str(
            nd.mean(nd.mean(errG)).asscalar() - nd.mean(errG2).asscalar() - nd.mean(errR).asscalar()) + " " + str(
            nd.mean(errD).asscalar()) + " " + str(nd.mean(errD2).asscalar()) + " " + str(nd.mean(errR).asscalar()) +" "+str(acc) + " " + str(acc2)+"\n")
        tp_file.close()
        metric.reset()
        metric2.reset()
        train_data.reset()

        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch%10 ==0:# and epoch>0:
            text_file = open(expname + "_validtest.txt", "a")
            filename = "checkpoints/"+expname+"_"+str(epoch)+"_D.params"
            netD.save_params(filename)
            filename = "checkpoints/"+expname+"_"+str(epoch)+"_D2.params"
            netD2.save_params(filename)
            filename = "checkpoints/"+expname+"_"+str(epoch)+"_En.params"
            netEn.save_params(filename)
            filename = "checkpoints/"+expname+"_"+str(epoch)+"_De.params"
            netDe.save_params(filename)
            fake_img1 = nd.concat(real_in[0],real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1],real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2],real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3],real_out[3], fake_out[3], dim=1)
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:
                
                real_in = vbatch.data[0].as_in_context(ctx)
                real_out = vbatch.data[1].as_in_context(ctx)
                fake_latent= netEn(real_in)
                y = netDe(fake_latent)
                fake_out = y
                metricMSE.update([fake_out, ], [real_out, ])
            _, acc2 = metricMSE.get()
            text_file.write("%s %s %s\n" % (str(epoch), nd.mean(errR).asscalar(), str(acc2)))
            metricMSE.reset()
            images = netDe(eps2)
            fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
            fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
            fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
            fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/'+expname+'_fakes_'+str(epoch)+'.png')
            text_file.close()

            # Do 10 iterations of sampler update
            fake_img1T = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2T = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3T = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            #fake_img4T = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            fake_img = nd.concat(fake_img1, fake_img2, fake_img3, fake_img1T, fake_img2T, fake_img3T, dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/'+expname+'_'+str(epoch)+'.png')
            '''if epoch > 100:
              for ep2 in range(10):
                with autograd.record():
                    #eps = nd.random_normal(loc=0, scale=1, shape=noiseshape, ctx=ctx)
                    eps = nd.random.uniform(low=-1, high=1, shape=noiseshape, ctx=ctx)
                    eps2 = nd.random_normal(loc=0, scale=0.02, shape=noiseshape, ctx=ctx)
                    eps2 = nd.tanh(eps2 * sigma + mu)
                    eps2 = nd.concat(eps, eps2, dim=0)
                    rec_output = netD(netDe(eps2))
                    fake_label = nd.zeros(rec_output.shape, ctx=ctx)
                    errGS = GAN_loss(rec_output, fake_label)
                    errGS.backward()
                mu -= lr / mu.shape[0] * mu.grad
                sigma -= lr / sigma.shape[0] * sigma.grad
                print('mu ' + str(mu[0, 0, 0, 0].asnumpy()) + '  sigma ' + str(sigma[0, 0, 0, 0].asnumpy()))
            '''
            images = netDe(eps2)
            fake_img1T = nd.concat(images[0],images[1], images[2], dim=1)
            fake_img2T = nd.concat(images[3],images[4], images[5], dim=1)
            fake_img3T = nd.concat(images[6],images[7], images[8], dim=1)
            fake_img = nd.concat(fake_img1T,fake_img2T, fake_img3T,dim=2)
            visual.visualize(fake_img)
            plt.savefig('outputs/'+expname+'_fakespost_'+str(epoch)+'.png')
    return([loss_rec_D,loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2, loss_rec_G2, acc2_rec])
Example #11
def l1_regularization(X):
    return (-nd.norm(X).asscalar(), nd.multiply(-1.0, nd.sign(X)))
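Note that nd.norm defaults to the L2 (Frobenius) norm, while the returned gradient -sign(X) is the gradient of the negative L1 norm, so the value and the gradient above come from different norms. A hedged sketch of a variant that keeps both consistent for the L1 case; the helper name is made up:

from mxnet import nd

def l1_regularization_consistent(X):
    # hypothetical variant: value is -||X||_1 and its gradient is -sign(X)
    return (-nd.sum(nd.abs(X)).asscalar(), nd.multiply(-1.0, nd.sign(X)))

X = nd.array([[1.0, -2.0], [0.5, 0.0]])
value, grad = l1_regularization_consistent(X)
print(value)   # -3.5
print(grad)    # [[-1.  1.] [-1.  0.]]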
Example #12
import mxnet as mx
Example #13
import cv2