def train(self):
    ''' build mixture weights over generators and discriminators '''
    self.G_weights_layer = nn.softmax_weights(self.args.ng, LL.InputLayer(shape=(), input_var=self.dummy_input))
    self.D_weights_layer = nn.softmax_weights(self.args.nd, LL.InputLayer(shape=(), input_var=self.dummy_input))
    self.G_weights = LL.get_output(self.G_weights_layer, None, deterministic=True)
    self.D_weights = LL.get_output(self.D_weights_layer, None, deterministic=True)
    # cross-entropy between the uniform distribution and the mixture weights;
    # minimizing it pushes the weights toward uniform
    self.Disc_weights_entropy = T.sum((-1. / self.args.nd) * T.log(self.D_weights + 1e-6), [0, 1])
    self.Gen_weights_entropy = T.sum((-1. / self.args.ng) * T.log(self.G_weights + 1e-6), [0, 1])

    ''' build the generator and discriminator ensembles '''
    for i in range(self.args.ng):
        gen_layers_i, gen_x_i = self.get_generator(self.meanx, self.z, self.y_1hot)
        self.G_layers.append(gen_layers_i)
        self.Gen_x_list.append(gen_x_i)
    self.Gen_x = T.concatenate(self.Gen_x_list, axis=0)
    for i in range(self.args.nd):
        disc_layers_i, disc_layer_adv_i, disc_layer_z_recon_i = self.get_discriminator()
        self.D_layers.append(disc_layers_i)
        self.D_layer_adv.append(disc_layer_adv_i)
        self.D_layer_z_recon.append(disc_layer_z_recon_i)
    #T.set_subtensor(self.Gen_x[i*self.args.batch_size:(i+1)*self.args.batch_size], gen_x_i)
    #self.samplers.append(self.sampler(self.z[i], self.y))

    ''' forward pass '''
    loss_gen0_cond_list = []
    loss_disc0_class_list = []
    loss_disc0_adv_list = []
    loss_gen0_ent_list = []
    loss_gen0_adv_list = []
    for i in range(self.args.ng):
        # class predictions of the pretrained encoder for each generator's samples
        self.y_recon_list.append(LL.get_output(self.enc_layer_fc4, self.Gen_x_list[i], deterministic=True))
    for i in range(self.args.ng):
        #loss_gen0_cond = T.mean((recon_fc3_list[i] - self.real_fc3)**2)  # feature loss, euclidean distance in feature space
        loss_gen0_cond = T.mean(T.nnet.categorical_crossentropy(self.y_recon_list[i], self.y))
        loss_disc0_class = 0
        loss_disc0_adv = 0
        loss_gen0_ent = 0
        loss_gen0_adv = 0
        for j in range(self.args.nd):
            output_before_softmax_real0 = LL.get_output(self.D_layer_adv[j], self.x, deterministic=False)
            # discriminator logits for generated images, plus its reconstruction of z
            output_before_softmax_gen0, recon_z0 = LL.get_output([self.D_layer_adv[j], self.D_layer_z_recon[j]], self.Gen_x_list[i], deterministic=False)

            ''' loss for discriminator and Q '''
            l_lab0 = output_before_softmax_real0[T.arange(self.args.batch_size), self.y]
            l_unl0 = nn.log_sum_exp(output_before_softmax_real0)
            l_gen0 = nn.log_sum_exp(output_before_softmax_gen0)
            # loss for not correctly classifying the category of real images
            loss_disc0_class += T.dot(self.D_weights[0, j], -T.mean(l_lab0) + T.mean(nn.log_sum_exp(output_before_softmax_real0)))
            loss_real0 = -T.mean(l_unl0) + T.mean(T.nnet.softplus(l_unl0))  # loss for classifying real as fake
            loss_fake0 = T.mean(T.nnet.softplus(l_gen0))  # loss for classifying fake as real
            loss_disc0_adv += T.dot(self.D_weights[0, j], 0.5 * loss_real0 + 0.5 * loss_fake0)
            loss_gen0_ent += T.dot(self.D_weights[0, j], T.mean((recon_z0 - self.z) ** 2))

            ''' loss for generator '''
            loss_gen0_adv += T.dot(self.D_weights[0, j], -T.mean(T.nnet.softplus(l_gen0)))

        loss_gen0_cond_list.append(T.dot(self.G_weights[0, i], loss_gen0_cond))
        loss_disc0_class_list.append(T.dot(self.G_weights[0, i], loss_disc0_class))
        loss_disc0_adv_list.append(T.dot(self.G_weights[0, i], loss_disc0_adv))
        loss_gen0_ent_list.append(T.dot(self.G_weights[0, i], loss_gen0_ent))
        loss_gen0_adv_list.append(T.dot(self.G_weights[0, i], loss_gen0_adv))

    self.loss_gen0_cond = sum(loss_gen0_cond_list)
    self.loss_disc0_class = sum(loss_disc0_class_list)
    self.loss_disc0_adv = sum(loss_disc0_adv_list)
    self.loss_gen0_ent = sum(loss_gen0_ent_list)
    self.loss_gen0_adv = sum(loss_gen0_adv_list)
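    # A minimal sketch of the mixture-weight regularizer defined above, assuming
    # nn.softmax_weights simply wraps a trainable vector in a softmax so that
    # D_weights/G_weights are (1, n) stochastic rows. The helper below is
    # hypothetical, for illustration only:
    #
    #   def uniform_xent(w, n, eps=1e-6):
    #       # cross-entropy between the uniform distribution and w; smallest
    #       # when all mixture weights are equal, so adding it to the loss
    #       # discourages the mixture from collapsing onto one component
    #       return np.sum((-1. / n) * np.log(w + eps))
    #
    #   uniform_xent(np.array([[0.2, 0.2, 0.2, 0.2, 0.2]]), 5)        # ~1.61
    #   uniform_xent(np.array([[0.96, 0.01, 0.01, 0.01, 0.01]]), 5)   # ~3.69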
    ''' total losses '''
    self.loss_disc = (self.args.labloss_weight * self.loss_disc0_class
                      + self.args.advloss_weight * self.loss_disc0_adv
                      + self.args.entloss_weight * self.loss_gen0_ent
                      + self.args.mix_entloss_weight * self.Disc_weights_entropy)
    self.loss_gen = (self.args.advloss_weight * self.loss_gen0_adv
                     + self.args.condloss_weight * self.loss_gen0_cond
                     + self.args.entloss_weight * self.loss_gen0_ent
                     + self.args.mix_entloss_weight * self.Gen_weights_entropy)

    if self.args.load_epoch is not None:
        print("loading model")
        self.load_model(self.args.load_epoch)
        print("success")

    ''' collect parameter updates for discriminators '''
    Disc_params = LL.get_all_params(self.D_weights_layer, trainable=True)
    Disc_bn_updates = []
    Disc_bn_params = []
    # discriminators are only updated while loss_disc exceeds this threshold
    self.threshold = (self.mincost
                      + self.args.labloss_weight * self.loss_disc0_class
                      + self.args.entloss_weight * self.loss_gen0_ent
                      + self.args.mix_entloss_weight * self.Disc_weights_entropy)
    for i in range(self.args.nd):
        Disc_params.extend(LL.get_all_params(self.D_layers[i], trainable=True))
        Disc_bn_updates.extend([u for l in LL.get_all_layers(self.D_layers[i][-1]) for u in getattr(l, 'bn_updates', [])])
        for l in LL.get_all_layers(self.D_layers[i][-1]):
            if hasattr(l, 'avg_batch_mean'):
                Disc_bn_params.append(l.avg_batch_mean)
                Disc_bn_params.append(l.avg_batch_var)
    Disc_param_updates = nn.adam_conditional_updates(Disc_params, self.loss_disc, mincost=self.threshold, lr=self.disc_lr, mom1=0.5)  # if loss_disc < threshold, don't update the discriminators
    Disc_param_avg = [th.shared(np.cast[th.config.floatX](0. * p.get_value())) for p in Disc_params]  # historical averages, initialized with 0
    Disc_avg_updates = [(a, a + 0.0001 * (p - a)) for p, a in zip(Disc_params, Disc_param_avg)]  # online update of historical parameters
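    # Sketch of the conditional update rule, assuming nn.adam_conditional_updates
    # (defined in nn.py, not shown here) behaves like standard Adam with every
    # step gated on the loss, roughly:
    #
    #   from theano.ifelse import ifelse
    #   new_p = ifelse(T.lt(loss, mincost), p, p - adam_step)
    #
    # so the discriminators hold still once loss_disc falls below the threshold,
    # which keeps them from overpowering the generators. Disc_param_avg keeps an
    # exponential moving average a <- a + 0.0001 * (p - a) of each parameter.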
    ''' collect parameter updates for generators '''
    Gen_params = LL.get_all_params(self.G_weights_layer, trainable=True)
    Gen_bn_updates = []
    Gen_bn_params = []
    for i in range(self.args.ng):
        Gen_params.extend(LL.get_all_params(self.G_layers[i][-1], trainable=True))
        Gen_bn_updates.extend([u for l in LL.get_all_layers(self.G_layers[i][-1]) for u in getattr(l, 'bn_updates', [])])
        for l in LL.get_all_layers(self.G_layers[i][-1]):
            if hasattr(l, 'avg_batch_mean'):
                Gen_bn_params.append(l.avg_batch_mean)
                Gen_bn_params.append(l.avg_batch_var)
    Gen_param_updates = nn.adam_updates(Gen_params, self.loss_gen, lr=self.gen_lr, mom1=0.5)

    ''' define training and testing functions '''
    train_batch_disc = th.function(inputs=[self.dummy_input, self.meanx, self.x, self.y, self.y_1hot, self.mincost, self.disc_lr],
                                   outputs=[self.loss_disc0_class, self.loss_disc0_adv],
                                   updates=Disc_param_updates + Disc_bn_updates + Disc_avg_updates)
    train_batch_gen = th.function(inputs=[self.dummy_input, self.meanx, self.y, self.y_1hot, self.gen_lr],
                                  outputs=[self.loss_gen0_adv, self.loss_gen0_cond, self.loss_gen0_ent],
                                  updates=Gen_param_updates + Gen_bn_updates)
    # samplefun = th.function(inputs=[meanx, y_1hot], outputs=gen_x_joint)  # sample images by stacking all generators
    reconfun = th.function(inputs=[self.meanx, self.y_1hot],
                           outputs=self.Gen_x)  # generate images from every generator, conditioned on the class labels
    mix_weights = th.function(inputs=[self.dummy_input],
                              outputs=[self.D_weights, self.Disc_weights_entropy, self.G_weights, self.Gen_weights_entropy])

    ''' load data '''
    print("Loading data...")
    meanimg, data = load_cifar_data(self.args.data_dir)
    trainx = data['X_train']
    trainy = data['Y_train']
    nr_batches_train = int(trainx.shape[0] / self.args.batch_size)
    # testx = data['X_test']
    # testy = data['Y_test']
    # nr_batches_test = int(testx.shape[0] / self.args.batch_size)

    ''' perform training '''
    logs = {'loss_gen0_adv': [], 'loss_gen0_cond': [], 'loss_gen0_ent': [], 'loss_disc0_class': []}  # training logs
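    # Usage sketch for the compiled functions (the shapes are assumptions based
    # on the CIFAR pipeline below; dummy_input only feeds the weight layers):
    #
    #   l_class, l_adv = train_batch_disc(0.0, meanimg, batchx, batchy,
    #                                     batchy_1hot, mincost, disc_lr)
    #   l_adv, l_cond, l_ent = train_batch_gen(0.0, meanimg, batchy,
    #                                          batchy_1hot, gen_lr)
    #   imgs = reconfun(meanimg, batchy_1hot)   # ng * batch_size images, with
    #                                           # generators stacked along axis 0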
    start_epoch = 0 if self.args.load_epoch is None else self.args.load_epoch + 1
    for epoch in range(start_epoch, self.args.num_epoch):
        begin = time.time()

        ''' shuffle the training set '''
        inds = rng.permutation(trainx.shape[0])
        trainx = trainx[inds]
        trainy = trainy[inds]

        for t in range(nr_batches_train):
            ''' construct minibatch '''
            #batchz = np.random.uniform(size=(self.args.batch_size, self.args.z0dim)).astype(np.float32)
            batchx = trainx[t * self.args.batch_size:(t + 1) * self.args.batch_size]
            batchy = trainy[t * self.args.batch_size:(t + 1) * self.args.batch_size]
            batchy_1hot = np.zeros((self.args.batch_size, 10), dtype=np.float32)
            batchy_1hot[np.arange(self.args.batch_size), batchy] = 1  # convert to one-hot labels
            # randomy = np.random.randint(10, size=(self.args.batch_size,))
            # randomy_1hot = np.zeros((self.args.batch_size, 10), dtype=np.float32)
            # randomy_1hot[np.arange(self.args.batch_size), randomy] = 1

            ''' train discriminators '''
            l_disc0_class, l_disc0_adv = train_batch_disc(0.0, meanimg, batchx, batchy, batchy_1hot, self.args.mincost, self.args.disc_lr)

            ''' train generators: take more steps the better the discriminators are doing (see the sketch after this loop) '''
            if l_disc0_adv > 0.65:
                n_iter = 1
            elif l_disc0_adv > 0.5:
                n_iter = 3
            elif l_disc0_adv > 0.3:
                n_iter = 5
            else:
                n_iter = 7
            for i in range(n_iter):
                #l_gen0_adv, l_gen0_cond, l_gen0_ent = train_batch_gen(0.0, batchx, meanimg, batchz, batchy_1hot, self.args.gen_lr)
                l_gen0_adv, l_gen0_cond, l_gen0_ent = train_batch_gen(0.0, meanimg, batchy, batchy_1hot, self.args.gen_lr)
            d_mix_weights, d_entloss, g_mix_weights, g_entloss = mix_weights(0.0)

            ''' store log information '''
            logs['loss_gen0_adv'].append(l_gen0_adv)
            logs['loss_gen0_cond'].append(l_gen0_cond)
            logs['loss_gen0_ent'].append(l_gen0_ent)
            logs['loss_disc0_class'].append(l_disc0_class)

        print("---Epoch %d, time = %ds" % (epoch, time.time() - begin))
        print("D_weights=[%.6f, %.6f, %.6f, %.6f, %.6f] loss = %.6f" % (d_mix_weights[0, 0], d_mix_weights[0, 1], d_mix_weights[0, 2], d_mix_weights[0, 3], d_mix_weights[0, 4], d_entloss))
        print("G_weights=[%.6f, %.6f, %.6f, %.6f, %.6f] loss = %.6f" % (g_mix_weights[0, 0], g_mix_weights[0, 1], g_mix_weights[0, 2], g_mix_weights[0, 3], g_mix_weights[0, 4], g_entloss))
        print("loss_disc0_adv = %.4f, loss_gen0_adv = %.4f, loss_gen0_cond = %.4f, loss_gen0_ent = %.4f, loss_disc0_class = %.4f" % (l_disc0_adv, l_gen0_adv, l_gen0_cond, l_gen0_ent, l_disc0_class))

        # ''' sample images by stacking all generators '''
        # imgs = samplefun(meanimg, refy_1hot)
        # imgs = np.transpose(np.reshape(imgs[:100,], (100, 3, 32, 32)), (0, 2, 3, 1))
        # imgs = [imgs[i] for i in range(100)]
        # rows = []
        # for i in range(10):
        #     rows.append(np.concatenate(imgs[i::10], 1))
        # imgs = np.concatenate(rows, 0)
        # scipy.misc.imsave(self.args.out_dir + "/mnist_sample_epoch{}.png".format(epoch), imgs)
        """
        ''' original images in the training set '''
        orix = np.transpose(np.reshape(batchx[:100,], (100, 3, 32, 32)), (0, 2, 3, 1))
        orix = [orix[i] for i in range(100)]
        rows = []
        for i in range(10):
            rows.append(np.concatenate(orix[i::10], 1))
        orix = np.concatenate(rows, 0)
        scipy.misc.imsave(self.args.out_dir + "/mnist_ori_epoch{}.png".format(epoch), orix)
        """
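        # The n_iter schedule above is a heuristic balancing rule: a high
        # discriminator adversarial loss means the fakes are already hard to
        # tell apart, so the generators take fewer steps. A standalone
        # (hypothetical) equivalent of the branching logic:
        #
        #   def gen_steps(l_disc_adv):
        #       for bound, steps in [(0.65, 1), (0.5, 3), (0.3, 5)]:
        #           if l_disc_adv > bound:
        #               return steps
        #       return 7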
"/disc0_params_epoch{}.npz".format(epoch), *LL.get_all_param_values(disc0_layers)) #np.savez(self.args.out_dir + '/gen0_params_epoch{}.npz'.format(epoch), *LL.get_all_param_values(gen0_layers)) np.savez(self.args.out_dir + '/Dweights_params_epoch{}.npz'.format(epoch), *LL.get_all_param_values(self.D_weights_layer)) np.savez(self.args.out_dir + '/Gweights_params_epoch{}.npz'.format(epoch), *LL.get_all_param_values(self.G_weights_layer)) for i in range(self.args.ng): np.savez(self.args.out_dir + ("/disc%d_params_epoch%d.npz" % (i,epoch)), *LL.get_all_param_values(self.D_layers[i])) np.savez(self.args.out_dir + ("/gen%d_params_epoch%d.npz" % (i,epoch)), *LL.get_all_param_values(self.G_layers[i])) np.save(self.args.out_dir + '/logs.npy',logs) ''' reconstruct images ''' reconx = reconfun(meanimg, batchy_1hot) + meanimg width = np.round(np.sqrt(self.args.batch_size)).astype(int) for i in range(self.args.ng): reconx_i = np.transpose(np.reshape(reconx[i*self.args.batch_size:(i+1)*self.args.batch_size], (self.args.batch_size, 3, 32, 32)), (0, 2, 3, 1)) reconx_i = [reconx_i[j] for j in range(self.args.batch_size)] rows = [] for j in range(width): rows.append(np.concatenate(reconx_i[j::width], 1)) reconx_i = np.concatenate(rows, 0) scipy.misc.imsave(self.args.out_dir + ("/cifar_recon_%d_epoch%d.png"%(i,epoch)), reconx_i)