def pretrain(model, embedding_layer, utils, epoch_num=1):
    criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
    model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    # Grab one fixed batch from each domain for qualitative checks after every epoch.
    X_test_batch = None
    Y_test_batch = None
    for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)):
        X_test_batch = batch
        break
    for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)):
        Y_test_batch = batch
        break

    model.to(device)
    for epoch in range(epoch_num):
        model.train()
        print("EPOCH %d:" % (epoch + 1))
        pretrain_run_epoch(data_gen(utils.data_generator("Y"), utils.sents2idx), model,
                           SimpleLossCompute(model.generator, criterion, model_opt),
                           utils.train_step_num, embedding_layer)
        pretrain_run_epoch(data_gen(utils.data_generator("X"), utils.sents2idx), model,
                           SimpleLossCompute(model.generator, criterion, model_opt),
                           utils.train_step_num, embedding_layer)
        model.eval()
        torch.save(model.state_dict(), 'model_pretrain.ckpt')

        # Decode the held-out batches and print original vs. generated sentences.
        x = utils.idx2sent(greedy_decode(model, embedding_layer, X_test_batch.src,
                                         X_test_batch.src_mask, max_len=20, start_symbol=2))
        y = utils.idx2sent(greedy_decode(model, embedding_layer, Y_test_batch.src,
                                         Y_test_batch.src_mask, max_len=20, start_symbol=2))
        for i, j in zip(X_test_batch.src, x):
            print("===")
            k = utils.idx2sent([i])[0]
            print("ORG:", " ".join(k[:k.index('<eos>') + 1]))
            print("--")
            print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
            print("===")
        print("=====")
        for i, j in zip(Y_test_batch.src, y):
            print("===")
            k = utils.idx2sent([i])[0]
            print("ORG:", " ".join(k[:k.index('<eos>') + 1]))
            print("--")
            print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
            print("===")
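# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of how pretrain() is expected to be wired up, mirroring
# the setup in the main script further below. The data paths, N=2, and the
# nn.Embedding construction from utils.emb_mat are assumptions for illustration,
# not fixed requirements of this repo.
#
#   utils = Utils(X_data_path="big_cou.txt", Y_data_path="big_cna.txt")
#   model = Transformer(N=2)
#   model.generator = Generator(d_model=utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0])
#   embedding_layer = nn.Embedding.from_pretrained(
#       torch.tensor(utils.emb_mat, dtype=torch.float32)).to(device)
#   pretrain(model, embedding_layer, utils, epoch_num=1)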
class CycleGAN(nn.Module):
    def __init__(self, discriminator, generator, utils, embedder):
        super(CycleGAN, self).__init__()
        self.D = discriminator
        self.G = generator
        self.R = copy.deepcopy(generator)
        self.D_opt = torch.optim.Adam(self.D.parameters())
        # self.G_opt = torch.optim.Adam(self.G.parameters())
        self.G_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
                             torch.optim.Adam(self.G.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        # self.R_opt = torch.optim.Adam(self.R.parameters())
        self.R_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
                             torch.optim.Adam(self.R.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        self.embed = embedder
        self.utils = utils
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse = nn.MSELoss()
        self.cos = nn.CosineSimilarity(dim=-1)
        self.cosloss = nn.CosineEmbeddingLoss()
        self.r_criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
        self.r_loss_compute = SimpleLossCompute(self.R.generator, self.r_criterion, self.R_opt)

    def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt", r_path="Res_model.ckpt"):
        torch.save(self.D.state_dict(), d_path)
        torch.save(self.G.state_dict(), g_path)
        torch.save(self.R.state_dict(), r_path)

    def load_model(self, path="", g_file=None, d_file=None, r_file=None):
        if g_file is not None:
            self.G.load_state_dict(torch.load(os.path.join(path, g_file), map_location=device))
        if d_file is not None:
            self.D.load_state_dict(torch.load(os.path.join(path, d_file), map_location=device))
        if r_file is not None:
            self.R.load_state_dict(torch.load(os.path.join(path, r_file), map_location=device))
        print("model loaded!")

    def pretrain_disc(self, num_epochs=100):
        # self.D.to(self.D.device)
        # self.G.to(self.G.device)
        # self.R.to(self.R.device)
        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_steps = self.utils.train_step_num
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
            for step, X_data, Y_data in zip(range(d_steps),
                                            data_gen(X_datagen, self.utils.sents2idx),
                                            data_gen(Y_datagen, self.utils.sents2idx)):
                # 1. Train D on real + fake
                # if epoch == 0:
                #     break
                self.D.zero_grad()

                # 1A: Train D on real
                d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                d_real_error = self.criterion(
                    d_real_pred,
                    torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

                # 1B: Train D on fake
                d_fake_pred = self.D(self.embed(X_data.src.to(device)))
                d_fake_error = self.criterion(
                    d_fake_pred,
                    torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake

                (d_fake_error + d_real_error).backward()
                self.D_opt.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()
                d_ct.flush(info={"D_loss": d_fake_error.item()})
            torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt")

    def train_model(self, num_epochs=100, d_steps=20, g_steps=80, g_scale=1.0, r_scale=1.0):
        # self.D.to(self.D.device)
        # self.G.to(self.G.device)
        # self.R.to(self.R.device)

        # Fixed test batches from each domain for the qualitative check after every epoch.
        X_test_batch = None
        Y_test_batch = None
        for i, batch in enumerate(data_gen(self.utils.data_generator("X"), self.utils.sents2idx)):
            X_test_batch = batch
            break
        for i, batch in enumerate(data_gen(self.utils.data_generator("Y"), self.utils.sents2idx)):
            Y_test_batch = batch
            break

        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
            if epoch > 0:
                for i, X_data, Y_data in zip(range(d_steps),
                                             data_gen(X_datagen, self.utils.sents2idx),
                                             data_gen(Y_datagen, self.utils.sents2idx)):
                    # 1. Train D on real + fake
                    # if epoch == 0:
                    #     break
                    self.D.zero_grad()

                    # 1A: Train D on real
                    d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                    d_real_error = self.criterion(
                        d_real_pred,
                        torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

                    # 1B: Train D on fake
                    self.G.to(device)
                    d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask,
                                                  max_len=self.utils.max_len,
                                                  return_term=0).detach()  # detach to avoid training G on these labels
                    d_fake_pred = self.D(d_fake_data)
                    d_fake_error = self.criterion(
                        d_fake_pred,
                        torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake

                    (d_fake_error + d_real_error).backward()
                    self.D_opt.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()
                    d_ct.flush(info={"D_loss": d_fake_error.item()})

            g_ct = Clock(g_steps, title="Train Generator(%d/%d)" % (epoch, num_epochs))
            r_ct = Clock(g_steps, title="Train Reconstructor(%d/%d)" % (epoch, num_epochs))
            if epoch > 0:
                for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)):
                    # 2. Train G on D's response (but DO NOT train D on these labels)
                    self.G.zero_grad()
                    g_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask,
                                                  max_len=self.utils.max_len, return_term=0)
                    dg_fake_pred = self.D(g_fake_data)
                    g_error = self.criterion(
                        dg_fake_pred,
                        torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # we want to fool, so pretend it's all genuine
                    g_error.backward(retain_graph=True)
                    self.G_opt.step()  # Only optimizes G's parameters
                    self.G.zero_grad()
                    g_ct.flush(info={"G_loss": g_error.item()})

                    # 3. Reconstructor
                    # way_3
                    out = self.R.forward(g_fake_data, self.embed(X_data.trg.to(device)), None, X_data.trg_mask)
                    r_loss = self.r_loss_compute(out, X_data.trg_y, X_data.ntokens)
                    # way_2
                    # r_reco_data = prob_backward(self.R, self.embed, g_fake_data, None, max_len=self.utils.max_len, raw=True)
                    # x_orgi_data = X_data.src[:, 1:]
                    # r_loss = SimpleLossCompute(None, criterion, self.R_opt)(r_reco_data, x_orgi_data, X_data.ntokens)
                    # way_1
                    # viewed_num = r_reco_data.shape[0] * r_reco_data.shape[1]
                    # r_error = r_scale * self.cosloss(r_reco_data.float().view(-1, self.embed.weight.shape[1]),
                    #                                  x_orgi_data.float().view(-1, self.embed.weight.shape[1]),
                    #                                  torch.ones(viewed_num, dtype=torch.float32).to(device))
                    self.G_opt.step()
                    self.G_opt.optimizer.zero_grad()
                    r_ct.flush(info={"G_loss": g_error.item(),
                                     "R_loss": r_loss / X_data.ntokens.float().to(device)})

            # Qualitative check: decode the fixed test batches and print original,
            # generated, and reconstructed sentences.
            with torch.no_grad():
                x_cont, x_ys = backward_decode(self.G, self.embed, X_test_batch.src, X_test_batch.src_mask,
                                               max_len=25, start_symbol=2)
                x = self.utils.idx2sent(x_ys)
                y_cont, y_ys = backward_decode(self.G, self.embed, Y_test_batch.src, Y_test_batch.src_mask,
                                               max_len=25, start_symbol=2)
                y = self.utils.idx2sent(y_ys)
                r_x = self.utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None,
                                                          max_len=self.utils.max_len, raw=True, return_term=1))
                r_y = self.utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None,
                                                          max_len=self.utils.max_len, raw=True, return_term=1))
                for i, j, l in zip(X_test_batch.src, x, r_x):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('<eos>') + 1]))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('<eos>') + 1] if '<eos>' in l else l))
                    print("===")
                print("=====")
                for i, j, l in zip(Y_test_batch.src, y, r_y):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('<eos>') + 1]))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('<eos>') + 1] if '<eos>' in l else l))
                    print("===")
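# --- Note on the NoamOpt learning-rate schedule used for G_opt / R_opt ---
# Both generator-side optimizers wrap Adam in NoamOpt(model_size, factor, warmup, optimizer),
# i.e. the warmup-then-decay schedule from the Annotated Transformer. The helper below is
# only an illustrative sketch of that rate (assuming the standard formula; the authoritative
# version is NoamOpt's own implementation):
def _noam_rate_sketch(step, model_size, factor=1, warmup=4000):
    """Illustrative: lr = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)."""
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)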
main_model = CycleGAN(disc, model, utils, embedding_layer)
main_model.to(device)
main_model.load_model(g_file="model_pretrain.ckpt", r_file="model_pretrain.ckpt", d_file=args.disc_name)
main_model.train_model()

if args.mode == "disc":
    disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20)
    main_model = CycleGAN(disc, model, utils, embedding_layer)
    main_model.to(device)
    main_model.pretrain_disc(2)

if args.mode == "dev":
    model = Transformer(N=2)
    utils = Utils(X_data_path="big_cou.txt", Y_data_path="big_cna.txt")
    model.generator = Generator(d_model=utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0])
    criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
    model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 400,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    X_test_batch = None
    Y_test_batch = None
    for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)):
        X_test_batch = batch
        break
    for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)):
        Y_test_batch = batch
        break

# if args.load_model:
#     model.load_model(filename=args.model_name)
# if args.mode == "train":
#     model.train_model(num_epochs=int(args.epoch))
#     print("========= Testing =========")
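# --- Checkpointing sketch (illustrative only) ---
# After train_model() finishes, the three sub-models can be written out and later
# restored with the helpers defined on CycleGAN; the file names below are simply
# save_model()'s defaults, not required values.
#
#   main_model.save_model()  # writes Dis_model.ckpt, Gen_model.ckpt, Res_model.ckpt
#   main_model.load_model(path="", g_file="Gen_model.ckpt",
#                         d_file="Dis_model.ckpt", r_file="Res_model.ckpt")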