dur_pad=data_repr_params['dur_pad'], dur_width=data_repr_params['dur_width'], num_step=data_repr_params['num_time_step'], note_emb_size=model_params['note_emb_size'], enc_notes_hid_size=model_params['enc_notes_hid_size'], enc_time_hid_size=model_params['enc_time_hid_size'], z_size=model_params['z_size'], dec_emb_hid_size=model_params['dec_emb_hid_size'], dec_time_hid_size=model_params['dec_time_hid_size'], dec_notes_hid_size=model_params['dec_notes_hid_size'], dec_z_in_size=model_params['dec_z_in_size'], dec_dur_hid_size=model_params['dec_dur_hid_size'], device=device) if INIT_WEIGHT: model.apply(utils.init_weights) print('The parameters in the model are initialized!') print(f'The model has {utils.count_parameters(model):,} trainable parameters') if PARALLEL: model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3]) model = model.to(device) model = model.module else: model = model.to(device) print('Model loaded!') ############################################################################### # Optimizer and Criterion ############################################################################### optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
class BiAVAE(object):
    """Two autoencoders (one per language) trained against a shared discriminator.

    ``X_AE`` encodes source-language embeddings and ``Y_AE`` encodes
    target-language embeddings (see :meth:`train`). ``D`` is trained to tell
    their latent codes apart, while the autoencoders are trained to fool ``D``
    (flipped labels) and to reconstruct their inputs.
    """

    def __init__(self, params):
        """Build the two autoencoders, the discriminator and the loss functions.

        Args:
            params: experiment configuration. Attributes read here:
                ``exp_id``, ``src_lang``, ``tgt_lang``, ``norm_embeddings``,
                ``eval_file`` and ``d_input_size`` / ``d_hidden_size`` /
                ``d_output_size``. ``VAE`` also receives the whole object.
        """
        self.params = params
        self.tune_dir = "{}/{}-{}/{}".format(params.exp_id, params.src_lang,
                                             params.tgt_lang,
                                             params.norm_embeddings)
        self.tune_best_dir = "{}/best".format(self.tune_dir)

        # NOTE(review): for any eval_file other than 'wiki'/'wacky',
        # self.eval_file is left unset.
        if self.params.eval_file == 'wiki':
            self.eval_file = '/data/dictionaries/{}-{}.5000-6500.txt'.format(
                self.params.src_lang, self.params.tgt_lang)
        elif self.params.eval_file == 'wacky':
            self.eval_file = '/data/dictionaries/{}-{}.test.txt'.format(
                self.params.src_lang, self.params.tgt_lang)

        self.X_AE = VAE(params)
        self.Y_AE = VAE(params)
        self.D = Discriminator(input_size=params.d_input_size,
                               hidden_size=params.d_hidden_size,
                               output_size=params.d_output_size)

        self.nets = [self.X_AE, self.Y_AE, self.D]
        self.loss_fn = torch.nn.BCELoss()
        # Used as a reconstruction loss: 1 - mean cosine similarity.
        self.loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def weights_init(self, m):
        """Orthogonal initialization for ``nn.Linear`` modules (bias set to 0.01)."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init2(self, m):
        """Xavier-normal initialization for ``nn.Linear`` modules (bias set to 0.01)."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init3(self, m):
        """Identity initialization for ``nn.Linear`` module weights.

        The weight is overwritten with a ``params.g_input_size`` identity
        matrix, so the layer is expected to be square of that size. The bias
        is left untouched.
        """
        if isinstance(m, torch.nn.Linear):
            m.weight.data.copy_(torch.eye(self.params.g_input_size))

    def freeze(self, m):
        """Set ``requires_grad = False`` on every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = False

    def defreeze(self, m):
        """Set ``requires_grad = True`` on every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = True

    def init_state(self, state=1, seed=-1):
        """Move the networks to the GPU (when available) and initialize them.

        Args:
            state: 1 -- fresh init (X_AE orthogonal, Y_AE identity, D
                xavier-normal) and Y_AE frozen.
                2 -- load the stage-1 best X_AE / Y_AE checkpoints for
                ``seed``, freeze X_AE, unfreeze Y_AE, re-initialize D.
                3 -- fresh init (both autoencoders orthogonal, D
                xavier-normal); ``requires_grad`` flags are not touched.
                Any other value only prints a warning.
            seed: selects the stage-1 checkpoints loaded when ``state == 2``.
        """
        if torch.cuda.is_available():
            # Move the networks and the losses to the GPU.
            for net in self.nets:
                net.cuda()
            self.loss_fn = self.loss_fn.cuda()
            self.loss_fn2 = self.loss_fn2.cuda()

        if state == 1:
            print('Init the model...')
            # The autoencoder init schemes can be swapped here.
            self.X_AE.apply(self.weights_init)
            self.Y_AE.apply(self.weights_init3)
            self.D.apply(self.weights_init2)
            self.Y_AE.apply(self.freeze)
        elif state == 2:
            self.X_AE.load_state_dict(
                torch.load('{}/seed_{}_dico_{}_stage_1_best_X.t7'.format(
                    self.tune_best_dir, seed, self.params.dico_build)))
            self.Y_AE.load_state_dict(
                torch.load('{}/seed_{}_dico_{}_stage_1_best_Y.t7'.format(
                    self.tune_best_dir, seed, self.params.dico_build)))
            self.Y_AE.apply(self.defreeze)
            self.X_AE.apply(self.freeze)
            # D is re-initialized instead of restored from its stage-1 checkpoint.
            self.D.apply(self.weights_init2)
        elif state == 3:
            print('Init3 the model...')
            self.X_AE.apply(self.weights_init)
            self.Y_AE.apply(self.weights_init)
            self.D.apply(self.weights_init2)
        else:
            print('Invalid state!')

    @staticmethod
    def _touch(path):
        """Create (or truncate) an empty marker file at ``path``."""
        with open(path, 'w'):
            pass

    @staticmethod
    def _save_curve(values, label, ylabel, path, xs=None):
        """Plot ``values`` (against ``xs``, default their index) and save to ``path``."""
        if xs is None:
            xs = range(0, len(values))
        fig = plt.figure()
        plt.plot(xs, values, color='b', label=label)
        plt.ylabel(ylabel)
        plt.xlabel('epochs')
        plt.legend()
        fig.savefig(path)

    @staticmethod
    def _write_pairs(path, pairs):
        """Write ``(word_a, word_b)`` pairs as UTF-8 ``'word_a word_b'`` lines."""
        with io.open(path, 'w', encoding='utf-8', newline='\n') as f:
            for first, second in pairs:
                f.write('{} {}\n'.format(first, second))

    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed, stage):
        """Adversarially train the autoencoders against the discriminator.

        Args:
            src_dico, tgt_dico: dictionaries; ``dico[1]`` is passed to the
                evaluation as the word->id mapping.
            src_emb, tgt_emb: embedding matrices encoded by ``X_AE`` and
                ``Y_AE`` respectively.
            seed: only used to name checkpoint / marker / plot files.
            stage: 1 forces 50 epochs, 2 forces 10 epochs; also used in names.

        Side effects: replaces ``self.params`` with ``_get_eval_params(params)``.
        Creates ``tune_dir`` / ``tune_best_dir``. Every ``print_every`` epochs,
        evaluates translation accuracy and the ``dist_mean_cosine`` score
        (stored as ``csls``). Saves the X/Y/D state dicts when that score
        improves. Writes empty marker files and PNG curves. On Ctrl-C, saves
        the current weights to the working directory and exits.
        """
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            print("Data path doesn't exists: %s" % params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)

        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]

        en = src_emb
        it = tgt_emb

        params = _get_eval_params(params)
        self.params = params
        evaluator = Evaluator(params, en, it, torch.cuda.is_available())

        # Only parameters left trainable by init_state (requires_grad) are optimized.
        AE_optimizer = optim.SGD(
            filter(lambda p: p.requires_grad,
                   list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),
            lr=params.g_learning_rate)
        D_optimizer = optim.SGD(list(self.D.parameters()),
                                lr=params.d_learning_rate)

        D_acc_epochs = []
        d_loss_epochs = []
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        g_loss_epochs = []
        acc_epochs = []
        csls_epochs = []
        eval_epochs = []  # epoch number of each entry in acc_epochs / csls_epochs
        best_valid_metric = -100

        # Header-only log file, overwritten on every call; nothing else writes
        # to it, so close it right away instead of leaking the handle.
        with open("log_src_tgt.txt", "w") as log_file:
            log_file.write(
                "epoch, disA_loss, disB_loss , disA_acc, disB_acc, g_AB_loss, g_BA_loss, g_AB_recon, g_BA_recon, CSLS, trans_Acc\n"
            )

        if stage == 1:
            self.params.num_epochs = 50
        if stage == 2:
            self.params.num_epochs = 10

        bs = params.mini_batch_size
        n_batches = params.iters_in_epoch // bs

        # Smoothed targets (loop-invariant, so built once).
        # D: first half (Y codes) -> 1 - smoothing, second half (X codes) -> smoothing.
        label_D = to_variable(torch.FloatTensor(2 * bs).zero_())
        label_D[:bs] = 1 - params.smoothing
        label_D[bs:] = params.smoothing
        # Generator targets are flipped: X codes aim for D's "Y" label and vice versa.
        label_G = to_variable(torch.FloatTensor(bs).zero_())
        label_G = label_G + 1 - params.smoothing
        label_G2 = to_variable(torch.FloatTensor(bs).zero_()) + params.smoothing

        try:
            for epoch in range(self.params.num_epochs):
                G_AB_recon = []
                G_BA_recon = []
                d_losses = []
                g_losses = []
                hit_A = 0
                total = 0
                start_time = timer()

                for mini_batch in range(0, n_batches):
                    # 1. Train D on detached latent codes.
                    for d_index in range(params.d_steps):
                        D_optimizer.zero_grad()
                        self.D.train()
                        view_X, view_Y = self.get_batch_data_fast_new(en, it)

                        _, Y_Z = self.Y_AE(view_Y)
                        _, X_Z = self.X_AE(view_X)
                        # Detach so D's loss does not backpropagate into the autoencoders.
                        d_input = torch.cat([Y_Z.detach(), X_Z.detach()], 0)

                        pred = self.D(d_input)
                        D_loss = self.loss_fn(pred, label_D)
                        D_loss.backward()
                        d_losses.append(to_numpy(D_loss.data))

                        # A hit is a Y code scored >= 0.5 or an X code scored < 0.5.
                        decision = to_numpy(pred.data)
                        hit_A += np.sum(decision[:bs] >= 0.5)
                        hit_A += np.sum(decision[bs:] < 0.5)

                        D_optimizer.step()  # only D's parameters
                        _clip(self.D, params.clip_value)  # clip D's weights

                        sys.stdout.write(
                            "[%d/%d] :: Discriminator Loss: %.3f \r" %
                            (mini_batch, n_batches, float(np.mean(d_losses))))
                        sys.stdout.flush()

                    total += 2 * bs * params.d_steps

                    # 2. Train the autoencoders on D's response (D is not updated).
                    for g_index in range(params.g_steps):
                        AE_optimizer.zero_grad()
                        self.D.eval()
                        view_X, view_Y = self.get_batch_data_fast_new(en, it)

                        X_recon, X_Z = self.X_AE(view_X)
                        Y_recon, Y_Z = self.Y_AE(view_Y)

                        # Adversarial losses with flipped labels.
                        D_X_loss = self.loss_fn(self.D(X_Z), label_G)
                        D_Y_loss = self.loss_fn(self.D(Y_Z), label_G2)
                        # Reconstruction losses: 1 - mean cosine similarity.
                        L_recon_X = 1.0 - torch.mean(self.loss_fn2(view_X, X_recon))
                        L_recon_Y = 1.0 - torch.mean(self.loss_fn2(view_Y, Y_recon))

                        G_loss = D_X_loss + D_Y_loss + L_recon_X + L_recon_Y
                        G_loss.backward()
                        g_losses.append(to_numpy(G_loss.data))
                        G_AB_recon.append(to_numpy(L_recon_X.data))
                        G_BA_recon.append(to_numpy(L_recon_Y.data))

                        AE_optimizer.step()  # only trainable autoencoder parameters

                        sys.stdout.write(
                            "[%d/%d] :: Generator Loss: %.3f Generator Y recon: %.3f\r" %
                            (mini_batch, n_batches, float(np.mean(g_losses)),
                             float(np.mean(G_BA_recon))))
                        sys.stdout.flush()

                # Per-epoch bookkeeping.
                d_acc = hit_A / total if total else 0.0
                D_acc_epochs.append(d_acc)
                G_AB_recon_epochs.append(float(np.mean(G_AB_recon)))
                G_BA_recon_epochs.append(float(np.mean(G_BA_recon)))
                d_loss_epochs.append(float(np.mean(d_losses)))
                g_loss_epochs.append(float(np.mean(g_losses)))

                print(
                    "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins"
                    .format(epoch, float(np.mean(d_losses)), d_acc,
                            float(np.mean(g_losses)),
                            (timer() - start_time) / 60))

                if (epoch + 1) % params.print_every == 0:
                    # Only the encoders are needed for evaluation.
                    _, X_Z = self.X_AE(Variable(en))
                    _, Y_Z = self.Y_AE(Variable(it))
                    X_Z = X_Z.data
                    Y_Z = Y_Z.data

                    mstart_time = timer()
                    method = params.eval_method
                    results = get_word_translation_accuracy(
                        params.src_lang, src_word2id, X_Z,
                        params.tgt_lang, tgt_word2id, Y_Z,
                        method=method, dico_eval='default')
                    acc1 = results[0][1]
                    print('{} takes {:.2f}s'.format(method, timer() - mstart_time))
                    print('Method:{} score:{:.4f}'.format(method, acc1))

                    csls = evaluator.dist_mean_cosine(X_Z, Y_Z)
                    if csls > best_valid_metric:
                        print("New csls value: {}".format(csls))
                        best_valid_metric = csls
                        self._touch(
                            self.tune_best_dir +
                            "/seed_{}_dico_{}_stage_{}_epoch_{}_acc_{:.3f}.tmp"
                            .format(seed, params.dico_build, stage, epoch, acc1))
                        torch.save(
                            self.X_AE.state_dict(),
                            self.tune_best_dir +
                            '/seed_{}_dico_{}_stage_{}_best_X.t7'.format(
                                seed, params.dico_build, stage))
                        torch.save(
                            self.Y_AE.state_dict(),
                            self.tune_best_dir +
                            '/seed_{}_dico_{}_stage_{}_best_Y.t7'.format(
                                seed, params.dico_build, stage))
                        torch.save(
                            self.D.state_dict(),
                            self.tune_best_dir +
                            '/seed_{}_dico_{}_stage_{}_best_D.t7'.format(
                                seed, params.dico_build, stage))

                    # Per-evaluation marker file.
                    self._touch(
                        self.tune_dir +
                        "/seed_{}_stage_{}_epoch_{}_acc_{:.3f}.tmp".format(
                            seed, stage, epoch, acc1))

                    acc_epochs.append(acc1)
                    csls_epochs.append(csls)
                    eval_epochs.append(epoch)

            if csls_epochs:
                # max() over (score, index) tuples: ties go to the later evaluation.
                csls_fb, idx_fb = max(
                    (score, index) for index, score in enumerate(csls_epochs))
                # Report the real epoch number, not the evaluation index.
                self._touch(
                    self.tune_best_dir +
                    "/seed_{}_dico_{}_stage_{}_epoch_{}_Acc_{:.3f}_{:.3f}.cslsfb"
                    .format(seed, params.dico_build, stage, eval_epochs[idx_fb],
                            acc_epochs[idx_fb], csls_fb))

            self._save_curve(
                acc_epochs, 'trans_acc1', 'trans_acc',
                self.tune_dir + '/seed_{}_stage_{}_trans_acc.png'.format(seed, stage),
                xs=eval_epochs)
            self._save_curve(
                csls_epochs, 'csls', 'csls',
                self.tune_dir + '/seed_{}_stage_{}_csls.png'.format(seed, stage),
                xs=eval_epochs)
            self._save_curve(
                g_loss_epochs, 'G_loss', 'g_loss',
                self.tune_dir + '/seed_{}_g_stage_{}_loss.png'.format(seed, stage))
            self._save_curve(
                d_loss_epochs, 'D_loss', 'D_loss',
                self.tune_dir + '/seed_{}_stage_{}_d_loss.png'.format(seed, stage))
            plt.close('all')

        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(self.X_AE.state_dict(), 'X_model_interrupt.t7')
            torch.save(self.Y_AE.state_dict(), 'Y_model_interrupt.t7')
            torch.save(self.D.state_dict(), 'd_model_interrupt.t7')
            sys.exit()

        return

    def get_batch_data_fast_new(self, emb_en, emb_it):
        """Sample one random mini-batch of rows from each embedding matrix.

        Indices are drawn independently for each side from
        ``[0, params.most_frequent_sampling_size)``, with replacement.

        Returns:
            ``(en_batch, it_batch)``, each with ``params.mini_batch_size`` rows.
        """
        params = self.params
        en_idx = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        it_idx = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        en_all = to_variable(emb_en)
        it_all = to_variable(emb_it)
        # Index on the embeddings' own device so this also works without CUDA.
        if en_all.is_cuda:
            en_idx = en_idx.cuda()
        if it_all.is_cuda:
            it_idx = it_idx.cuda()
        return en_all[en_idx], it_all[it_idx]

    def export_dict(self, src_dico, tgt_dico, emb_en, emb_it, seed):
        """Evaluate the saved encoders and export induced bilingual dictionaries.

        Loads ``<tune_dir>/best/seed_<seed>_best_{X,Y}.t7`` and prints
        translation accuracy in both directions. Then induces source->target
        and target->source dictionaries with ``build_dictionary``
        (``dico_build='S2T&T2S'``, ``dico_method='csls_knn_10'``). Writes them,
        plus their intersection and union, as ``.dict`` / ``.intersect`` /
        ``.union`` files under ``<tune_dir>/best``. ``dico[0]`` is used as the
        id->word mapping.

        Note: mutates ``self.params.dico_build`` and ``self.params.dico_method``.
        """
        params = self.params
        # Export adversarial dictionaries
        optim_X_AE = VAE(params).cuda()
        optim_Y_AE = VAE(params).cuda()
        print('Loading pre-trained models...')
        # NOTE(review): train() saves checkpoints as
        # 'seed_<s>_dico_<d>_stage_<k>_best_X.t7'; confirm which process
        # produces the 'seed_<s>_best_X.t7' files loaded here.
        optim_X_AE.load_state_dict(
            torch.load(self.tune_dir + '/best/seed_{}_best_X.t7'.format(seed)))
        optim_Y_AE.load_state_dict(
            torch.load(self.tune_dir + '/best/seed_{}_best_Y.t7'.format(seed)))
        X_Z = optim_X_AE.encode(Variable(emb_en)).data
        Y_Z = optim_Y_AE.encode(Variable(emb_it)).data

        mstart_time = timer()
        method = params.eval_method
        acc1 = get_word_translation_accuracy(
            params.src_lang, src_dico[1], X_Z, params.tgt_lang, tgt_dico[1],
            emb_it, method=method, dico_eval='default')[0][1]
        acc2 = get_word_translation_accuracy(
            params.tgt_lang, tgt_dico[1], Y_Z, params.src_lang, src_dico[1],
            emb_en, method=method, dico_eval='default')[0][1]
        print('{} takes {:.2f}s'.format(method, timer() - mstart_time))
        print('Method:{} score:{:.4f}-{:.4f}'.format(method, acc1, acc2))

        print('Building dictionaries...')
        params.dico_build = "S2T&T2S"
        params.dico_method = "csls_knn_10"
        # Row-normalize (L2) both sides before building each dictionary.
        X_Z = X_Z / X_Z.norm(2, 1, keepdim=True).expand_as(X_Z)
        emb_it = emb_it / emb_it.norm(2, 1, keepdim=True).expand_as(emb_it)
        f_dico_induce = build_dictionary(X_Z, emb_it, params).cpu().numpy()

        Y_Z = Y_Z / Y_Z.norm(2, 1, keepdim=True).expand_as(Y_Z)
        emb_en = emb_en / emb_en.norm(2, 1, keepdim=True).expand_as(emb_en)
        b_dico_induce = build_dictionary(Y_Z, emb_en, params).cpu().numpy()

        # Backward pairs are flipped to (src_id, tgt_id) so both sets are comparable.
        f_dico_set = {(a, b) for a, b in f_dico_induce}
        b_dico_set = {(b, a) for a, b in b_dico_induce}
        intersect = list(f_dico_set & b_dico_set)
        union = list(f_dico_set | b_dico_set)

        src_words = src_dico[0]
        tgt_words = tgt_dico[0]
        best_dir = self.tune_dir + '/best/'
        s2t = '{}-{}'.format(params.src_lang, params.tgt_lang)
        t2s = '{}-{}'.format(params.tgt_lang, params.src_lang)

        self._write_pairs(best_dir + s2t + '.dict',
                          ((src_words[i], tgt_words[j]) for i, j in f_dico_induce))
        self._write_pairs(best_dir + t2s + '.dict',
                          ((tgt_words[i], src_words[j]) for i, j in b_dico_induce))
        self._write_pairs(best_dir + s2t + '.intersect',
                          ((src_words[i], tgt_words[j]) for i, j in intersect))
        self._write_pairs(best_dir + t2s + '.intersect',
                          ((tgt_words[j], src_words[i]) for i, j in intersect))
        self._write_pairs(best_dir + s2t + '.union',
                          ((src_words[i], tgt_words[j]) for i, j in union))
        self._write_pairs(best_dir + t2s + '.union',
                          ((tgt_words[j], src_words[i]) for i, j in union))