def pretrain(model, embedding_layer, utils, epoch_num=1):
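    """Pretrain the shared transformer on both corpora.

    Each epoch runs one pass over the Y-domain data and one over the
    X-domain data, checkpoints to 'model_pretrain.ckpt', then greedy-decodes
    one held-out batch per domain and prints ORG/GEN pairs for inspection.
    """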
    criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
    model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
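    # NoamOpt wraps Adam with the Transformer schedule from "Attention Is All You Need":
    # lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5),
    # i.e. linear warmup over `warmup` steps followed by inverse-square-root decay,
    # which is why the inner Adam is created with lr=0.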
    X_test_batch = None
    Y_test_batch = None
    for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)):
        X_test_batch = batch
        break

    for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)):
        Y_test_batch = batch
        break
    model.to(device)
    for epoch in range(epoch_num):
        model.train()
        print("EPOCH %d:" % (epoch + 1))
        pretrain_run_epoch(data_gen(utils.data_generator("Y"), utils.sents2idx), model,
                SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer)
        pretrain_run_epoch(data_gen(utils.data_generator("X"), utils.sents2idx), model,
                SimpleLossCompute(model.generator, criterion, model_opt), utils.train_step_num, embedding_layer)
        model.eval()
        torch.save(model.state_dict(), 'model_pretrain.ckpt')
        x = utils.idx2sent(greedy_decode(model, embedding_layer, X_test_batch.src, X_test_batch.src_mask, max_len=20, start_symbol=2))
        y = utils.idx2sent(greedy_decode(model, embedding_layer, Y_test_batch.src, Y_test_batch.src_mask, max_len=20, start_symbol=2))

        for i, j in zip(X_test_batch.src, x):
            print("===")
            k = utils.idx2sent([i])[0]
            print("ORG:", " ".join(k[:k.index('<eos>') + 1] if '<eos>' in k else k))
            print("--")
            print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
            print("===")
        print("=====")
        for i, j in zip(Y_test_batch.src, y):
            print("===")
            k = utils.idx2sent([i])[0]
            print("ORG:", " ".join(k[:k.index('<eos>') + 1] if '<eos>' in k else k))
            print("--")
            print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
            print("===")
class CycleGAN(nn.Module):
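    """Cycle-consistency GAN for text style transfer.

    G maps X-domain sentences toward the Y domain, R (initialized as a deep
    copy of G) maps generated sentences back for the reconstruction loss,
    and D classifies embedded sentences as real Y-domain data or not.
    """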
    def __init__(self, discriminator, generator, utils, embedder):
        super(CycleGAN, self).__init__()
        self.D = discriminator
        self.G = generator
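        # R starts as an independent copy of G; their weights diverge during training.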
        self.R = copy.deepcopy(generator)
        self.D_opt = torch.optim.Adam(self.D.parameters())
        # self.G_opt = torch.optim.Adam(self.G.parameters())
        self.G_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
            torch.optim.Adam(self.G.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        # self.R_opt = torch.optim.Adam(self.R.parameters())
        self.R_opt = NoamOpt(utils.emb_mat.shape[1], 1, 4000,
            torch.optim.Adam(self.R.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        self.embed = embedder

        self.utils = utils
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse = nn.MSELoss()
        self.cos = nn.CosineSimilarity(dim=-1)
        self.cosloss = nn.CosineEmbeddingLoss()
        self.r_criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
        self.r_loss_compute = SimpleLossCompute(self.R.generator, self.r_criterion, self.R_opt)

    def save_model(self, d_path="Dis_model.ckpt", g_path="Gen_model.ckpt", r_path="Res_model.ckpt"):
        torch.save(self.D.state_dict(), d_path)
        torch.save(self.G.state_dict(), g_path)
        torch.save(self.R.state_dict(), r_path)

    def load_model(self, path="", g_file=None, d_file=None, r_file=None):
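        """Load any subset of the G/D/R weights from `path`; arguments left as None are skipped."""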
        if g_file is not None:
            self.G.load_state_dict(torch.load(os.path.join(path, g_file), map_location=device))
        if d_file is not None:
            self.D.load_state_dict(torch.load(os.path.join(path, d_file), map_location=device))
        if r_file is not None:
            self.R.load_state_dict(torch.load(os.path.join(path, r_file), map_location=device))
        print("model loaded!")

    def pretrain_disc(self, num_epochs=100):
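        """Pretrain D to tell embedded Y-domain sentences (label 1) from
        X-domain sentences (label 0); no generator output is involved yet."""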
        # self.D.to(self.D.device)
        # self.G.to(self.G.device)
        # self.R.to(self.R.device)
        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_steps = self.utils.train_step_num
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs))
            for step, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)):
                # 1. Train D on real (Y-domain) and "fake" (X-domain) examples
                self.D.zero_grad()

                #  1A: Train D on real
                d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

                #  1B: Train D on "fake" data (X-domain sentences stand in for generated output here)
                d_fake_pred = self.D(self.embed(X_data.src.to(device)))
                d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake
                (d_fake_error + d_real_error).backward()
                self.D_opt.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
                d_ct.flush(info={"D_loss": d_fake_error.item()})
        torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt")

    def train_model(self, num_epochs=100, d_steps=20, g_steps=80, g_scale=1.0, r_scale=1.0):
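        """Main adversarial loop.

        Per epoch (training is skipped on epoch 0, so the first pass only
        logs decoded samples): train D on real Y data vs. G's decoded output,
        train G to fool D, train R (and G, via the shared graph) on the
        cycle-reconstruction loss, then print ORG/GEN/REC triples for one
        held-out batch per domain.
        """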
        # self.D.to(self.D.device)
        # self.G.to(self.G.device)
        # self.R.to(self.R.device)
        for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)):
            X_test_batch = batch
            break

        for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)):
            Y_test_batch = batch
            break
        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
            if epoch>0:
                for i, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)):
                    # 1. Train D on real and fake (generated) examples
                    self.D.zero_grad()

                    #  1A: Train D on real
                    d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                    d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

                    #  1B: Train D on fake
                    self.G.to(device)
                    d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0).detach()  # detach to avoid training G on these labels
                    d_fake_pred = self.D(d_fake_data)
                    d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake
                    (d_fake_error + d_real_error).backward()
                    self.D_opt.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
                    d_ct.flush(info={"D_loss":d_fake_error.item()})

            g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs))
            r_ct = Clock(g_steps, title="Train Reconstructor(%d/%d)" % (epoch, num_epochs))
            if epoch>0:
                for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)):
                    # 2. Train G on D's response (but DO NOT train D on these labels)
                    self.G.zero_grad()
                    g_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0)
                    dg_fake_pred = self.D(g_fake_data)
                    g_error = self.criterion(dg_fake_pred, torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # we want to fool, so pretend it's all genuine
            
                    g_error.backward(retain_graph=True)
                    self.G_opt.step()  # Only optimizes G's parameters
                    self.G.zero_grad()
                    g_ct.flush(info={"G_loss": g_error.item()})

                    # 3. Train the reconstructor R (cycle consistency)
                    # way_3: teacher-forced reconstruction of the original X sentence through R
                    out = self.R.forward(g_fake_data, self.embed(X_data.trg.to(device)),
                        None, X_data.trg_mask)
                    r_loss = self.r_loss_compute(out, X_data.trg_y, X_data.ntokens)
                    # way_2
                    # r_reco_data = prob_backward(self.R, self.embed, g_fake_data, None, max_len=self.utils.max_len, raw=True)
                    # x_orgi_data = X_data.src[:, 1:]
                    # r_loss = SimpleLossCompute(None, criterion, self.R_opt)(r_reco_data, x_orgi_data, X_data.ntokens)
                    # way_1
                    # viewed_num = r_reco_data.shape[0]*r_reco_data.shape[1]
                    # r_error = r_scale*self.cosloss(r_reco_data.float().view(-1, self.embed.weight.shape[1]), x_orgi_data.float().view(-1, self.embed.weight.shape[1]), torch.ones(viewed_num, dtype=torch.float32).to(device))
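                    # r_loss_compute already backpropagated through g_fake_data,
                    # so this step also applies the reconstruction gradients to G.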
                    self.G_opt.step()
                    self.G_opt.optimizer.zero_grad()
                    r_ct.flush(info={"G_loss": g_error.item(),
                    "R_loss": r_loss / X_data.ntokens.float().to(device)})
                
            with torch.no_grad():
                x_cont, x_ys = backward_decode(self.G, self.embed, X_test_batch.src, X_test_batch.src_mask, max_len=25, start_symbol=2)
                x = self.utils.idx2sent(x_ys)
                y_cont, y_ys = backward_decode(self.G, self.embed, Y_test_batch.src, Y_test_batch.src_mask, max_len=25, start_symbol=2)
                y = self.utils.idx2sent(y_ys)
                r_x = self.utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None, max_len=self.utils.max_len, raw=True, return_term=1))
                r_y = self.utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None, max_len=self.utils.max_len, raw=True, return_term=1))

                for i, j, l in zip(X_test_batch.src, x, r_x):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('<eos>') + 1] if '<eos>' in k else k))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('<eos>') + 1] if '<eos>' in l else l))
                    print("===")
                print("=====")
                for i, j, l in zip(Y_test_batch.src, y, r_y):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('<eos>') + 1] if '<eos>' in k else k))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('<eos>') + 1] if '<eos>' in l else l))
                    print("===")
    if args.mode == "train":
        # Discriminator configured as in the "disc" mode below.
        disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20)
        main_model = CycleGAN(disc, model, utils, embedding_layer)
        main_model.to(device)
        main_model.load_model(g_file="model_pretrain.ckpt", r_file="model_pretrain.ckpt", d_file=args.disc_name)
        main_model.train_model()
    if args.mode == "disc":
        disc = Discriminator(word_dim=utils.emb_mat.shape[1], inner_dim=512, seq_len=20)
        main_model = CycleGAN(disc, model, utils, embedding_layer)
        main_model.to(device)
        main_model.pretrain_disc(2)

    if args.mode == "dev":
        model = Transformer(N=2)
        utils = Utils(X_data_path="big_cou.txt", Y_data_path="big_cna.txt")
        model.generator = Generator(d_model = utils.emb_mat.shape[1], vocab=utils.emb_mat.shape[0])
        criterion = LabelSmoothing(size=utils.emb_mat.shape[0], padding_idx=0, smoothing=0.0)
        model_opt = NoamOpt(utils.emb_mat.shape[1], 1, 400,
                torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        X_test_batch = None
        Y_test_batch = None
        for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)):
            X_test_batch = batch
            break

        for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)):
            Y_test_batch = batch
            break

    # if args.load_model:
    #     model.load_model(filename=args.model_name)
    # if args.mode == "train":
    #     model.train_model(num_epochs=int(args.epoch))
    #     print("========= Testing =========")