def generate_onestep(self):
    """Generate one canvas update from the middle (latent) layer.

    Call reset() first; the generated image does not depend on the image
    batch passed to reset() (only on the sampled latent z).
    [input]
        x    : BxHxW[mono] / 3BxHxW[color] matrix (Variable)
        errx : BxHxW[mono] / 3BxHxW[color] matrix (Variable)
    [output]
        y    : BxHxW[mono] / 3BxHxW[color] matrix (Variable)
    output nonlinearity -- [normal]: sigmoid or clipped relu, [whitened]: tanh
    """
    # sample z ~ N(0, I); F.gaussian takes (mean, ln_var)
    zero_mat = XP.fzeros((self.B, self.z_dim))
    z = F.gaussian(zero_mat, zero_mat)
    (self.c2, self.h2, inc_canvas,
     Wmean_x, Wmean_y, Wln_var, Wln_stride, Wln_gamma,
     Rmean_x, Rmean_y, Rln_var, Rln_stride, Rln_gamma) = self.decode(self.c2, self.h2, z)
    # build the write-attention filter and paste the decoded patch onto the canvas
    self.W_filter.mkFilter(Wmean_x, Wmean_y, Wln_var, Wln_stride, Wln_gamma)
    inc_canvas = F.reshape(
        inc_canvas, (self.B * self.C, self.Write_patch, self.Write_patch))
    inc_canvas = self.W_filter.InvFilter(inc_canvas)
    self.canvas += inc_canvas
    # clipped-linear output, equivalent to clip(canvas + 0.5, 0, 1);
    # alternatives: F.sigmoid(self.canvas) [normal], tanh [whitened]
    y = F.relu(self.canvas + 0.5) - F.relu(self.canvas - 0.5)
    self.errx = self.x - y
    self.t += 1
    return y
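# --- illustrative check (not part of the model classes) ---------------------
# A minimal NumPy sketch verifying that the clipped-linear output above,
# relu(c + 0.5) - relu(c - 0.5), equals the "hard sigmoid" clip(c + 0.5, 0, 1):
# it saturates to [0, 1] like a sigmoid but is linear in between. The function
# name `_demo_clipped_output` is hypothetical, for illustration only.
import numpy as np

def _demo_clipped_output():
    c = np.linspace(-2.0, 2.0, 9)
    relu = lambda v: np.maximum(v, 0.0)
    assert np.allclose(relu(c + 0.5) - relu(c - 0.5), np.clip(c + 0.5, 0.0, 1.0))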
def reset(self, image_batch):
    """Initialization.

    target image   : x      BxHW[mono] / Bx3HW[color] matrix (Variable)
    current canvas : canvas BxHW[mono] / Bx3HW[color] matrix (Variable)
    initial canvas : [normal] val=0.5, [whitened] val=0.0,
                     BxHW[mono] / Bx3HW[color] matrix (Variable)
    error          : target - current canvas
    {encoder,decoder} {cell,hidden}: 0
    t: time-step counter within the current minibatch
    """
    self.cleargrads()
    self.B = image_batch.shape[0]
    self.canvas = XP.fzeros((self.B, self.C * self.height * self.width))
    self.x = F.reshape(XP.farray(image_batch), (self.B, self.C * self.height * self.width))
    # initial canvas value is 0.0, so the initial error is simply x - 0
    self.errx = self.x - XP.fnonzeros((self.B, self.C * self.height * self.width), val=0.0)
    self.c = XP.fzeros((self.B, self.H_enc))   # encoder cell
    self.h = XP.fzeros((self.B, self.H_enc))   # encoder hidden
    self.c2 = XP.fzeros((self.B, self.H_dec))  # decoder cell
    self.h2 = XP.fzeros((self.B, self.H_dec))  # decoder hidden
    self.t = 0
def __init__(self, C, height, width, patchsize):
    """
    [input]
        C        : number of color channels, 1[mono] / 3[color] (int)
        height   : vertical size of the image (int)
        width    : horizontal size of the image (int)
        patchsize: edge length of the square attention-window patch (int)
    """
    self.C = C
    self.height = height
    self.width = width
    self.patchsize = patchsize
    self.L_edge = width if width > height else height  # longer image edge
    self.Harray = XP.farray(np.arange(height))  # (H,) Variable
    self.Warray = XP.farray(np.arange(width))   # (W,) Variable
    # centered patch offsets, shape (P,1): arange(P) - 0.5*(P+1)
    self.Parray = XP.farray(
        np.arange(patchsize, dtype=np.float32).reshape(patchsize, 1)
        - 0.5 * (patchsize + 1.0))
def reset(self, image_batch):
    """Initialization.

    target image   : x      BxHxW[mono] / 3BxHxW[color] matrix (Variable)
    current canvas : canvas BxHxW[mono] / 3BxHxW[color] matrix (Variable)
    initial canvas : [normal] val=0.5, [whitened] val=0.0,
                     BxHxW[mono] / 3BxHxW[color] matrix (Variable)
    error          : target - current canvas
    {encoder,decoder} {cell,hidden}: 0
    [attentional window patch]
        [position] (mean_x, mean_y): center of each image (0.5*width, 0.5*height)
        [variance] (ln_var)   : -6.9 (var = 0.001)
        [stride]   (ln_stride): 1.1  (stride = 3.0)
        [gamma]    (ln_gamma) : 0.0  (gamma = 1.0)
    t: time-step counter within the current minibatch
    """
    self.cleargrads()
    self.B = image_batch.shape[0]
    self.canvas = XP.fnonzeros((self.B * self.C, self.height, self.width), val=0.0)
    self.x = F.reshape(XP.farray(image_batch), (self.B * self.C, self.height, self.width))
    self.errx = self.x - F.sigmoid(
        XP.fnonzeros((self.B * self.C, self.height, self.width), val=0.0))
    self.c = XP.fzeros((self.B, self.H_enc))   # encoder cell
    self.h = XP.fzeros((self.B, self.H_enc))   # encoder hidden
    self.c2 = XP.fzeros((self.B, self.H_dec))  # decoder cell
    self.h2 = XP.fzeros((self.B, self.H_dec))  # decoder hidden
    # initial attention parameters come from the (zero) decoder hidden state;
    # e.g. initial var 0.001 -> ln_var = log(0.001) = -6.9,
    # initial stride 3.0 -> ln_stride = log(3.0) = 1.1, gamma 1.0 -> ln_gamma 0.0
    h_dec = self.h2
    Wmean_x = self.dec_hd_Wmeanx(h_dec)
    Wmean_y = self.dec_hd_Wmeany(h_dec)
    Wln_var = self.dec_hd_Wlnvar(h_dec)
    Wln_stride = self.dec_hd_Wlnstride(h_dec)
    Wln_gamma = self.dec_hd_Wlngamma(h_dec)
    Rmean_x = self.dec_hd_Rmeanx(h_dec)
    Rmean_y = self.dec_hd_Rmeany(h_dec)
    Rln_var = self.dec_hd_Rlnvar(h_dec)
    Rln_stride = self.dec_hd_Rlnstride(h_dec)
    Rln_gamma = self.dec_hd_Rlngamma(h_dec)
    self.R_filter.mkFilter(Rmean_x, Rmean_y, Rln_var, Rln_stride, Rln_gamma)
    self.W_filter.mkFilter(Wmean_x, Wmean_y, Wln_var, Wln_stride, Wln_gamma)
    self.t = 0
def generate_onestep(self): """ generate from middle layer #call reset() first, but no relation between img_array[input] and generated image[output] [input] x : BxHW[mono] Bx3HW[color] matrix (Variable) errx : BxHW[mono] Bx3HW[color] matrix (Variable) [output] y : BxHW[mono] Bx3HW[color] matrix (Variable) [normal]:sigmoid,relu [whitened]:tanh """ zero_mat = XP.fzeros((self.B,self.z_dim)) z = F.gaussian(zero_mat,zero_mat) #F.gaussian(mean,ln_var) self.c2,self.h2,inc_canvas = self.decode(self.c2,self.h2,z) self.canvas += inc_canvas y = F.sigmoid(self.canvas) #y = F.relu(self.canvas+0.5)-F.relu(self.canvas-0.5) return y
def main():
    args = parse_args()
    XP.set_library(args)
    # timestamp string like "2016_1_2_3_4_5" used to tag this run
    date = time.localtime()[:6]
    D = "_".join(str(i) for i in date)
    save_path = args.save_path
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    if args.model_path is not None:
        print("continuing existing model!! loading recipe of {}".format(args.model_path))
        with open(args.model_path + '/recipe.json', 'r') as f:
            recipe = json.load(f)
        vae_enc = recipe["network"]["IM"]["vae_enc"]
        vae_z = recipe["network"]["IM"]["vae_z"]
        vae_dec = recipe["network"]["IM"]["vae_dec"]
        times = recipe["network"]["IM"]["times"]
        alpha = recipe["network"]["IM"]["KLcoefficient"]
        batchsize = recipe["setting"]["batchsize"]
        maxepoch = args.maxepoch
        weightdecay = recipe["setting"]["weightdecay"]
        grad_clip = recipe["setting"]["grad_clip"]
        cur_epoch = recipe["setting"]["cur_epoch"] + 1
        ini_lr = recipe["setting"]["initial_learningrate"]
        cur_lr = recipe["setting"]["cur_lr"]
        with open(args.model_path + "/../trainloss.json", 'r') as f:
            trainloss_dic = json.load(f)
        with open(args.model_path + "/../valloss.json", 'r') as f:
            valloss_dic = json.load(f)
    else:
        vae_enc = args.vae_enc
        vae_z = args.vae_z
        vae_dec = args.vae_dec
        times = args.times
        alpha = args.alpha
        batchsize = args.batchsize
        maxepoch = args.maxepoch
        weightdecay = args.weightdecay
        grad_clip = 5
        cur_epoch = 0
        ini_lr = args.lr
        cur_lr = ini_lr
        trainloss_dic = {}
        valloss_dic = {}

    print('this experiment started at: {}'.format(D))
    print('***Experiment settings***')
    print('[IM]vae encoder hidden size :{}'.format(vae_enc))
    print('[IM]vae hidden layer size :{}'.format(vae_z))
    print('[IM]vae decoder hidden layer size :{}'.format(vae_dec))
    print('[IM]sequence length :{}'.format(times))
    print('max epoch :{}'.format(maxepoch))
    print('mini batch size :{}'.format(batchsize))
    print('initial learning rate :{}'.format(cur_lr))
    print('weight decay :{}'.format(weightdecay))
    print('optimization by :{}'.format("Adam"))
    print('VAE KL coefficient :{}'.format(alpha))
    print('*************************')

    vae = VAE_bernoulli_noattention(vae_enc, vae_z, vae_dec, 28, 28, 1)
    opt = optimizers.Adam(alpha=cur_lr)
    opt.setup(vae)
    if args.model_path is not None:
        print('loading model ...')
        serializers.load_npz(args.model_path + '/VAEweights', vae)
        serializers.load_npz(args.model_path + '/optimizer', opt)
    else:
        print('making [[new]] model ...')
        for param in vae.params():
            data = param.data
            data[:] = np.random.uniform(-0.1, 0.1, data.shape)
    opt.add_hook(optimizer.GradientClipping(grad_clip))
    opt.add_hook(optimizer.WeightDecay(weightdecay))
    if args.gpu >= 0:
        vae.to_gpu()

    mnist = MNIST(binarize=True)
    train_size = mnist.train_size
    test_size = mnist.test_size
    eps = 1e-8

    for epoch in range(cur_epoch + 1, maxepoch + 1):
        print('\nepoch {}'.format(epoch))
        LX = 0.0
        LZ = 0.0
        counter = 0
        for iter, (img_array, label_array) in enumerate(mnist.gen_train(batchsize, Random=True)):
            B = img_array.shape[0]
            Lz = XP.fzeros(())
            vae.reset(img_array)
            # steps 1..T: the KL term accumulates every step, while the
            # reconstruction term uses only the final canvas
            for j in range(times):
                y, kl = vae.free_energy_onestep()
                Lz += alpha * kl
            Lx = Bernoulli_nll_wesp(vae.x, y, eps)
            LZ += Lz.data
            LX += Lx.data
            loss = (Lx + Lz) / batchsize
            loss.backward()
            opt.update()
            counter += B
            sys.stdout.write('\rnow training ... epoch {}, {}/{} '.format(epoch, counter, mnist.train_size))
            sys.stdout.flush()
            if (iter + 1) % 100 == 0:
                print("({}-th batch mean loss) Lx:{:03.3f} Lz:{:03.3f}".format(
                    counter, float(Lx.data) / B, float(Lz.data) / B))

        # dump the first five reconstructions of the last training batch
        img_array = cuda.to_cpu(y.data)
        im_array = img_array.reshape(batchsize * 28, 28)
        img = im_array[:28 * 5]
        plt.clf()
        plt.imshow(img, cmap=cm.gray)
        plt.colorbar(orientation='horizontal')
        plt.savefig(save_path + "/" + "img{}.png".format(epoch))

        trace(save_path + "/trainloss.txt",
              "epoch {} Lx:{} Lz:{} Lx+Lz:{}".format(
                  epoch, LX / train_size, LZ / train_size, (LX + LZ) / train_size))
        trainloss_dic[str(epoch).zfill(3)] = {
            "Lx": float(LX / train_size),
            "Lz": float(LZ / train_size),
            "Lx+Lz": float((LX + LZ) / train_size)}
        with open(save_path + "/trainloss.json", 'w') as f:
            json.dump(trainloss_dic, f, indent=4)

        print('save model ...')
        prefix = save_path + "/" + str(epoch).zfill(3)
        if not os.path.exists(prefix):
            os.mkdir(prefix)
        serializers.save_npz(prefix + '/VAEweights', vae)
        serializers.save_npz(prefix + '/optimizer', opt)

        print('save recipe...')
        recipe_dic = {
            "date": D,
            "setting": {
                "maxepoch": maxepoch,
                "batchsize": batchsize,
                "weightdecay": weightdecay,
                "grad_clip": grad_clip,
                "opt": "Adam",
                "initial_learningrate": ini_lr,
                "cur_epoch": epoch,
                "cur_lr": cur_lr},
            "network": {
                "IM": {
                    "x_size": 784,
                    "vae_enc": vae_enc,
                    "vae_z": vae_z,
                    "vae_dec": vae_dec,
                    "times": times,
                    "KLcoefficient": alpha},
            },
        }
        with open(prefix + '/recipe.json', 'w') as f:
            json.dump(recipe_dic, f, indent=4)

        if epoch % 1 == 0:
            print("\nvalidation step")
            LX = 0.0
            LZ = 0.0
            counter = 0
            for iter, (img_array, label_array) in enumerate(mnist.gen_test(batchsize)):
                B = img_array.shape[0]
                Lz = XP.fzeros(())
                vae.reset(img_array)
                for j in range(times):
                    y, kl = vae.free_energy_onestep()
                    Lz += alpha * kl
                Lx = Bernoulli_nll_wesp(vae.x, y, eps)
                LZ += Lz.data.reshape(())
                LX += Lx.data.reshape(())
                counter += B
                sys.stdout.write('\rnow testing ... epoch {}, {}/{} '.format(epoch, counter, test_size))
                sys.stdout.flush()
            print("")
            trace(save_path + "/valloss.txt",
                  "epoch {} Lx:{} Lz:{} Lx+Lz:{}".format(
                      epoch, LX / test_size, LZ / test_size, (LX + LZ) / test_size))
            valloss_dic[str(epoch).zfill(3)] = {
                "Lx": float(LX / test_size),
                "Lz": float(LZ / test_size),
                "Lx+Lz": float((LX + LZ) / test_size)}
            with open(save_path + "/valloss.json", 'w') as f:
                json.dump(valloss_dic, f, indent=4)
            img_array = cuda.to_cpu(y.data)
            im_array = img_array.reshape(batchsize * 28, 28)
            img = im_array[:28 * 5]
            plt.clf()
            plt.imshow(img, cmap=cm.gray)
            plt.colorbar(orientation='horizontal')
            plt.savefig(save_path + "/" + "img_test{}.png".format(epoch))
    print('finished.')
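# --- illustrative helper sketch (hypothetical) --------------------------------
# Bernoulli_nll_wesp is called above but defined elsewhere in this repo; a
# plausible reading of the name is a Bernoulli negative log-likelihood "with
# eps" guarding the logs. The sketch below shows that computation; it is an
# assumption, not the repo's verified implementation.
import chainer.functions as F

def _demo_bernoulli_nll_weps(x, y, eps=1e-8):
    # sum over all pixels of -[x*log(y+eps) + (1-x)*log(1-y+eps)]
    return -F.sum(x * F.log(y + eps) + (1 - x) * F.log(1 - y + eps))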
def mkFilter(self, mean_x, mean_y, ln_var, ln_stride, ln_gamma):
    """Make the attention filters.

    B filters are needed for a minibatch of B data; filters are shared
    between the color maps of one image.
    [input]
        C        : 1[mono], 3[color]
        mean_x   : Bx1 (chainer.Variable)
        mean_y   : Bx1 (chainer.Variable)
        ln_var   : Bx1 (chainer.Variable)
        ln_stride: Bx1 (chainer.Variable)
        ln_gamma : Bx1 (chainer.Variable)
    [output]
        Fx   : BxPxW[mono] / 3BxPxW[color] matrix (Variable)
        Fy   : BxPxH[mono] / 3BxPxH[color] matrix (Variable)
        Gamma: BxPxP[mono] / 3BxPxP[color] (Variable)
    """
    eps = 1e-8
    P = self.patchsize
    B = mean_x.data.shape[0]
    H = self.height
    W = self.width
    # map (mean_x, mean_y) from [-1, 1] to image coordinates
    mean_x = 0.5 * (W + 1.0) * (mean_x + 1.0)  # (B,1)
    mean_y = 0.5 * (H + 1.0) * (mean_y + 1.0)  # (B,1)
    var = F.exp(ln_var)
    stride = (self.L_edge - 1.0) / (P - 1.0) * F.exp(ln_stride)
    gamma = F.exp(ln_gamma)
    mu_x = F.broadcast_to(mean_x, (P, B, 1))    # (B,1) -> (P,B,1)
    mu_x = F.transpose(mu_x, (1, 0, 2))         # -> (B,P,1)
    mu_y = F.broadcast_to(mean_y, (P, B, 1))    # (B,1) -> (P,B,1)
    mu_y = F.transpose(mu_y, (1, 0, 2))         # -> (B,P,1)
    stride = F.broadcast_to(stride, (P, B, 1))  # (B,1) -> (P,B,1)
    stride = F.transpose(stride, (1, 0, 2))     # -> (B,P,1)
    var_x = F.broadcast_to(var, (P, W, B, 1))   # (B,1) -> (P,W,B,1)
    var_x = F.transpose(var_x, (2, 0, 1, 3))    # -> (B,P,W,1)
    var_y = F.broadcast_to(var, (P, H, B, 1))   # (B,1) -> (P,H,B,1)
    var_y = F.transpose(var_y, (2, 0, 1, 3))    # -> (B,P,H,1)
    # grid centers: mean + centered patch offset * stride
    mu_x = mu_x + F.broadcast_to(self.Parray, (B, P, 1)) * stride  # (B,P,1)
    mu_y = mu_y + F.broadcast_to(self.Parray, (B, P, 1)) * stride  # (B,P,1)
    # distances from every image column/row to every grid center
    mu_x = F.transpose(F.broadcast_to(mu_x, (self.width, B, P, 1)), (1, 2, 0, 3))
    mu_x = F.broadcast_to(self.Warray, (B, P, W)) - F.reshape(mu_x, (B, P, W))
    mu_y = F.transpose(F.broadcast_to(mu_y, (self.height, B, P, 1)), (1, 2, 0, 3))
    mu_y = F.broadcast_to(self.Harray, (B, P, H)) - F.reshape(mu_y, (B, P, H))
    var_x = F.reshape(var_x, (B, P, W))
    var_y = F.reshape(var_y, (B, P, H))
    # unnormalized Gaussian responses; note that `var` enters as a scale,
    # i.e. exp(-0.5 * (d / var)^2)
    x_gauss = F.exp(-0.5 * (mu_x / var_x) ** 2)  # (B,P,W)
    y_gauss = F.exp(-0.5 * (mu_y / var_y) ** 2)  # (B,P,H)
    # normalize each filter row to sum to 1, guarding against division by ~0
    xsum = F.sum(x_gauss, 2)  # (B,P)
    ysum = F.sum(y_gauss, 2)  # (B,P)
    Zx_prev = F.transpose(F.broadcast_to(xsum, (W, B, P)), (1, 2, 0))
    enable = Variable(Zx_prev.data > eps)
    Zx = F.where(enable, Zx_prev, XP.fnonzeros(Zx_prev.data.shape, val=1.0) * eps)
    Zy_prev = F.transpose(F.broadcast_to(ysum, (H, B, P)), (1, 2, 0))
    enable = Variable(Zy_prev.data > eps)
    Zy = F.where(enable, Zy_prev, XP.fnonzeros(Zy_prev.data.shape, val=1.0) * eps)
    Fx = x_gauss / Zx
    Fy = y_gauss / Zy
    # replicate gamma and the filters across the C color maps
    gamma_ = F.broadcast_to(gamma, (P, P, self.C, B, 1))  # (B,1) -> (P,P,C,B,1)
    Gamma = F.reshape(F.transpose(gamma_, (4, 3, 2, 0, 1)), (self.C * B, P, P))
    Fx_ = F.broadcast_to(Fx, (self.C, B, P, W))
    Fy_ = F.broadcast_to(Fy, (self.C, B, P, H))
    Fx = F.reshape(F.transpose(Fx_, (1, 0, 2, 3)), (self.C * B, P, W))
    Fy = F.reshape(F.transpose(Fy_, (1, 0, 2, 3)), (self.C * B, P, H))
    self.Fx = Fx
    self.Fy = Fy
    self.Gamma = Gamma
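# --- illustrative math sketch (not part of the filter class) -----------------
# A self-contained NumPy sketch of the per-axis filterbank math above,
# grounded in mkFilter: coordinate remapping, centered grid offsets, Gaussian
# responses, and eps-guarded row normalization. The name `_demo_axis_filter`
# and the default sizes are illustrative assumptions.
import numpy as np

def _demo_axis_filter(mean, ln_var, ln_stride, W=28, P=5, L_edge=28, eps=1e-8):
    """Return a (P, W) row-normalized Gaussian filter for one image axis."""
    mu = 0.5 * (W + 1.0) * (mean + 1.0)                # [-1,1] -> image coords
    stride = (L_edge - 1.0) / (P - 1.0) * np.exp(ln_stride)
    offsets = np.arange(P) - 0.5 * (P + 1.0)           # same role as Parray
    centers = mu + offsets * stride                    # (P,) grid centers
    d = np.arange(W)[None, :] - centers[:, None]       # (P, W) distances
    g = np.exp(-0.5 * (d / np.exp(ln_var)) ** 2)       # Gaussian responses
    Z = np.maximum(g.sum(axis=1, keepdims=True), eps)  # guard near-zero rows
    return g / Z

Fx = _demo_axis_filter(mean=0.0, ln_var=0.0, ln_stride=0.0)
assert Fx.shape == (5, 28) and np.allclose(Fx.sum(axis=1), 1.0)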