Example #1
    def export(self,
               src_dico,
               tgt_dico,
               emb_en,
               emb_it,
               seed,
               export_emb=False):
        params = _get_eval_params(self.params)
        eval = Evaluator(params, emb_en, emb_it, torch.cuda.is_available())
        # Export adversarial dictionaries
        optim_X_AE = AE(params).cuda()
        optim_Y_AE = AE(params).cuda()
        print('Loading pre-trained models...')
        optim_X_AE.load_state_dict(
            torch.load(self.tune_dir +
                       '/best/seed_{}_dico_{}_best_X.t7'.format(
                           seed, params.dico_build)))
        optim_Y_AE.load_state_dict(
            torch.load(self.tune_dir +
                       '/best/seed_{}_dico_{}_best_Y.t7'.format(
                           seed, params.dico_build)))
        X_Z = optim_X_AE.encode(Variable(emb_en)).data
        Y_Z = optim_Y_AE.encode(Variable(emb_it)).data

        mstart_time = timer()
        for method in ['nn', 'csls_knn_10']:
            results = get_word_translation_accuracy(params.src_lang,
                                                    src_dico[1],
                                                    X_Z,
                                                    params.tgt_lang,
                                                    tgt_dico[1],
                                                    emb_it,
                                                    method=method,
                                                    dico_eval=self.eval_file,
                                                    device=params.cuda_device)
            acc1 = results[0][1]
            results = get_word_translation_accuracy(params.tgt_lang,
                                                    tgt_dico[1],
                                                    Y_Z,
                                                    params.src_lang,
                                                    src_dico[1],
                                                    emb_en,
                                                    method=method,
                                                    dico_eval=self.eval_file2,
                                                    device=params.cuda_device)
            acc2 = results[0][1]

            # csls = 0
            print('{} takes {:.2f}s'.format(method, timer() - mstart_time))
            print('Method:{} score:{:.4f}-{:.4f}'.format(method, acc1, acc2))

        f_csls = eval.dist_mean_cosine(X_Z, emb_it)
        b_csls = eval.dist_mean_cosine(Y_Z, emb_en)
        csls = (f_csls + b_csls) / 2.0
        print("Seed:{},ACC:{:.4f}-{:.4f},CSLS_FB:{:.6f}".format(
            seed, acc1, acc2, csls))
        #'''
        print('Building dictionaries...')
        params.dico_build = "S2T&T2S"
        params.dico_method = "csls_knn_10"
        X_Z = X_Z / X_Z.norm(2, 1, keepdim=True).expand_as(X_Z)
        emb_it = emb_it / emb_it.norm(2, 1, keepdim=True).expand_as(emb_it)
        f_dico_induce = build_dictionary(X_Z, emb_it, params)
        f_dico_induce = f_dico_induce.cpu().numpy()
        Y_Z = Y_Z / Y_Z.norm(2, 1, keepdim=True).expand_as(Y_Z)
        emb_en = emb_en / emb_en.norm(2, 1, keepdim=True).expand_as(emb_en)
        b_dico_induce = build_dictionary(Y_Z, emb_en, params)
        b_dico_induce = b_dico_induce.cpu().numpy()

        f_dico_set = set([(a, b) for a, b in f_dico_induce])
        b_dico_set = set([(b, a) for a, b in b_dico_induce])

        intersect = list(f_dico_set & b_dico_set)
        union = list(f_dico_set | b_dico_set)

        with io.open(
                self.tune_dir +
                '/export/{}-{}.dict'.format(params.src_lang, params.tgt_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in f_dico_induce:
                f.write('{} {}\n'.format(src_dico[0][item[0]],
                                         tgt_dico[0][item[1]]))

        with io.open(
                self.tune_dir +
                '/export/{}-{}.dict'.format(params.tgt_lang, params.src_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in b_dico_induce:
                f.write('{} {}\n'.format(tgt_dico[0][item[0]],
                                         src_dico[0][item[1]]))

        with io.open(self.tune_dir + '/export/{}-{}.intersect'.format(
                params.src_lang, params.tgt_lang),
                     'w',
                     encoding='utf-8',
                     newline='\n') as f:
            for item in intersect:
                f.write('{} {}\n'.format(src_dico[0][item[0]],
                                         tgt_dico[0][item[1]]))

        with io.open(self.tune_dir + '/export/{}-{}.intersect'.format(
                params.tgt_lang, params.src_lang),
                     'w',
                     encoding='utf-8',
                     newline='\n') as f:
            for item in intersect:
                f.write('{} {}\n'.format(tgt_dico[0][item[1]],
                                         src_dico[0][item[0]]))

        with io.open(
                self.tune_dir +
                '/export/{}-{}.union'.format(params.src_lang, params.tgt_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in union:
                f.write('{} {}\n'.format(src_dico[0][item[0]],
                                         tgt_dico[0][item[1]]))

        with io.open(
                self.tune_dir +
                '/export/{}-{}.union'.format(params.tgt_lang, params.src_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in union:
                f.write('{} {}\n'.format(tgt_dico[0][item[1]],
                                         src_dico[0][item[0]]))

        if export_emb:
            print('Exporting {}-{}.{}'.format(params.src_lang, params.tgt_lang,
                                              params.src_lang))
            loader.export_embeddings(
                src_dico[0],
                X_Z,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.src_lang, params.tgt_lang, params.src_lang),
                eformat='txt')
            print('Exporting {}-{}.{}'.format(params.src_lang, params.tgt_lang,
                                              params.tgt_lang))
            loader.export_embeddings(
                tgt_dico[0],
                emb_it,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.src_lang, params.tgt_lang, params.tgt_lang),
                eformat='txt')
            print('Exporting {}-{}.{}'.format(params.tgt_lang, params.src_lang,
                                              params.tgt_lang))
            loader.export_embeddings(
                tgt_dico[0],
                Y_Z,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.tgt_lang, params.src_lang, params.tgt_lang),
                eformat='txt')
            print('Exporting {}-{}.{}'.format(params.tgt_lang, params.src_lang,
                                              params.src_lang))
            loader.export_embeddings(
                src_dico[0],
                emb_en,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.tgt_lang, params.src_lang, params.src_lang),
                eformat='txt')
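The export above keeps translation pairs agreed on by both directions (intersection) as well as pairs proposed by either direction (union). A minimal, self-contained sketch of that set logic, using made-up index pairs in place of the `build_dictionary` output (the pairs are purely illustrative):

# Hypothetical induced pairs; in Example #1 these come from build_dictionary().
f_dico_induce = [(0, 3), (1, 5), (2, 7)]   # (source index, target index)
b_dico_induce = [(5, 1), (7, 2), (9, 4)]   # (target index, source index)

f_dico_set = {(s, t) for s, t in f_dico_induce}
b_dico_set = {(s, t) for t, s in b_dico_induce}  # flip to (source, target) order

intersect = sorted(f_dico_set & b_dico_set)  # pairs found by both directions
union = sorted(f_dico_set | b_dico_set)      # pairs found by either direction

print(intersect)  # [(1, 5), (2, 7)]
print(union)      # [(0, 3), (1, 5), (2, 7), (4, 9)]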
Example #2
File: trainer.py   Project: muyeby/Dou18
    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed, stage):
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            print("Data path doesn't exists: %s" % params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)

        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]

        en = src_emb
        it = tgt_emb

        params = _get_eval_params(params)
        self.params = params
        eval = Evaluator(params, en, it, torch.cuda.is_available())

        AE_optimizer = optim.SGD(filter(
            lambda p: p.requires_grad,
            list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),
                                 lr=params.g_learning_rate)
        # AE_optimizer = optim.SGD(G_params, lr=0.1, momentum=0.9)
        # AE_optimizer = optim.Adam(G_params, lr=params.g_learning_rate, betas=(0.9, 0.9))
        # AE_optimizer = optim.RMSprop(filter(lambda p: p.requires_grad, list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),lr=params.g_learning_rate,alpha=0.9)
        D_optimizer = optim.SGD(list(self.D.parameters()),
                                lr=params.d_learning_rate)
        # D_optimizer = optim.Adam(D_params, lr=params.d_learning_rate, betas=(0.5, 0.9))
        # D_optimizer = optim.RMSprop(list(self.D_X.parameters()) + list(self.D_Y.parameters()), lr=params.d_learning_rate , alpha=0.9)

        # true_dict = get_true_dict(params.data_dir)
        D_acc_epochs = []
        d_loss_epochs = []
        G_AB_loss_epochs = []
        G_BA_loss_epochs = []
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        g_loss_epochs = []
        acc_epochs = []

        csls_epochs = []
        best_valid_metric = -100

        # logs for plotting later
        log_file = open(
            "log_src_tgt.txt",
            "w")  # Being overwritten in every loop, not really required
        log_file.write(
            "epoch, disA_loss, disB_loss , disA_acc, disB_acc, g_AB_loss, g_BA_loss, g_AB_recon, g_BA_recon, CSLS, trans_Acc\n"
        )

        if stage == 1:
            self.params.num_epochs = 50
        if stage == 2:
            self.params.num_epochs = 10

        try:
            for epoch in range(self.params.num_epochs):

                G_AB_recon = []
                G_BA_recon = []
                G_X_loss = []
                G_Y_loss = []
                d_losses = []
                g_losses = []
                hit_A = 0
                total = 0
                start_time = timer()
                # lowest_loss = 1e5
                label_D = to_variable(
                    torch.FloatTensor(2 * params.mini_batch_size).zero_())
                label_D[:params.mini_batch_size] = 1 - params.smoothing
                label_D[params.mini_batch_size:] = params.smoothing

                label_G = to_variable(
                    torch.FloatTensor(params.mini_batch_size).zero_())
                label_G = label_G + 1 - params.smoothing
                label_G2 = to_variable(
                    torch.FloatTensor(
                        params.mini_batch_size).zero_()) + params.smoothing

                for mini_batch in range(
                        0, params.iters_in_epoch // params.mini_batch_size):
                    for d_index in range(params.d_steps):
                        D_optimizer.zero_grad()  # Reset the gradients
                        self.D.train()

                        view_X, view_Y = self.get_batch_data_fast_new(en, it)

                        # Discriminator X
                        _, Y_Z = self.Y_AE(view_Y)
                        _, X_Z = self.X_AE(view_X)
                        Y_Z = Y_Z.detach()
                        X_Z = X_Z.detach()
                        input = torch.cat([Y_Z, X_Z], 0)

                        pred = self.D(input)
                        D_loss = self.loss_fn(pred, label_D)
                        D_loss.backward(
                        )  # compute/store gradients, but don't change params
                        d_losses.append(to_numpy(D_loss.data))

                        discriminator_decision_A = to_numpy(pred.data)
                        hit_A += np.sum(
                            discriminator_decision_A[:params.mini_batch_size]
                            >= 0.5)
                        hit_A += np.sum(
                            discriminator_decision_A[params.mini_batch_size:] <
                            0.5)

                        D_optimizer.step(
                        )  # Only optimizes D's parameters; changes based on stored gradients from backward()

                        # Clip weights
                        _clip(self.D, params.clip_value)

                        sys.stdout.write(
                            "[%d/%d] :: Discriminator Loss: %.3f \r" %
                            (mini_batch,
                             params.iters_in_epoch // params.mini_batch_size,
                             np.asscalar(np.mean(d_losses))))
                        sys.stdout.flush()

                    total += 2 * params.mini_batch_size * params.d_steps

                    for g_index in range(params.g_steps):
                        # 2. Train G on D's response (but DO NOT train D on these labels)
                        AE_optimizer.zero_grad()
                        self.D.eval()

                        view_X, view_Y = self.get_batch_data_fast_new(en, it)

                        # Generator X_AE
                        ## adversarial loss
                        X_recon, X_Z = self.X_AE(view_X)
                        Y_recon, Y_Z = self.Y_AE(view_Y)

                        # input = torch.cat([Y_Z, X_Z], 0)

                        predx = self.D(X_Z)
                        D_X_loss = self.loss_fn(predx, label_G)
                        predy = self.D(Y_Z)
                        D_Y_loss = self.loss_fn(predy, label_G2)

                        L_recon_X = 1.0 - torch.mean(
                            self.loss_fn2(view_X, X_recon))
                        L_recon_Y = 1.0 - torch.mean(
                            self.loss_fn2(view_Y, Y_recon))

                        G_loss = D_X_loss + D_Y_loss + L_recon_X + L_recon_Y

                        G_loss.backward()

                        g_losses.append(to_numpy(G_loss.data))
                        G_X_loss.append(
                            to_numpy(D_X_loss.data + L_recon_X.data))
                        G_Y_loss.append(
                            to_numpy(D_Y_loss.data + L_recon_Y.data))
                        G_AB_recon.append(to_numpy(L_recon_X.data))
                        G_BA_recon.append(to_numpy(L_recon_Y.data))

                        AE_optimizer.step()  # Only optimizes G's parameters

                        sys.stdout.write(
                            "[%d/%d] ::                                     Generator Loss: %.3f Generator Y recon: %.3f\r"
                            % (mini_batch,
                               params.iters_in_epoch // params.mini_batch_size,
                               np.asscalar(np.mean(g_losses)),
                               np.asscalar(np.mean(G_BA_recon))))
                        sys.stdout.flush()
                '''for each epoch'''

                D_acc_epochs.append(hit_A / total)
                G_AB_recon_epochs.append(np.asscalar(np.mean(G_AB_recon)))
                G_BA_recon_epochs.append(np.asscalar(np.mean(G_BA_recon)))
                d_loss_epochs.append(np.asscalar(np.mean(d_losses)))
                g_loss_epochs.append(np.asscalar(np.mean(g_losses)))

                print(
                    "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins"
                    .format(epoch, np.asscalar(np.mean(d_losses)),
                            hit_A / total, np.asscalar(np.mean(g_losses)),
                            (timer() - start_time) / 60))

                if (epoch + 1) % params.print_every == 0:
                    # No need for discriminator weights

                    _, X_Z = self.X_AE(Variable(en))
                    _, Y_Z = self.Y_AE(Variable(it))
                    X_Z = X_Z.data
                    Y_Z = Y_Z.data

                    mstart_time = timer()
                    for method in [params.eval_method]:
                        results = get_word_translation_accuracy(
                            params.src_lang,
                            src_word2id,
                            X_Z,
                            params.tgt_lang,
                            tgt_word2id,
                            Y_Z,
                            method=method,
                            dico_eval='default')
                        acc1 = results[0][1]

                    print('{} takes {:.2f}s'.format(method,
                                                    timer() - mstart_time))
                    print('Method:{} score:{:.4f}'.format(method, acc1))

                    csls = eval.dist_mean_cosine(X_Z, Y_Z)

                    if csls > best_valid_metric:
                        print("New csls value: {}".format(csls))
                        best_valid_metric = csls
                        fp = open(
                            self.tune_best_dir +
                            "/seed_{}_dico_{}_stage_{}_epoch_{}_acc_{:.3f}.tmp"
                            .format(seed, params.dico_build, stage, epoch,
                                    acc1), 'w')
                        fp.close()
                        torch.save(
                            self.X_AE.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_stage_{}_best_X.t7'.format(
                                seed, params.dico_build, stage))
                        torch.save(
                            self.Y_AE.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_stage_{}_best_Y.t7'.format(
                                seed, params.dico_build, stage))
                        torch.save(
                            self.D.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_stage_{}_best_D.t7'.format(
                                seed, params.dico_build, stage))

                    # Saving generator weights
                    fp = open(
                        self.tune_dir +
                        "/seed_{}_stage_{}_epoch_{}_acc_{:.3f}.tmp".format(
                            seed, stage, epoch, acc1), 'w')
                    fp.close()

                    acc_epochs.append(acc1)
                    csls_epochs.append(csls)

            csls_fb, epoch_fb = max([
                (score, index) for index, score in enumerate(csls_epochs)
            ])
            fp = open(
                self.tune_best_dir +
                "/seed_{}_dico_{}_stage_{}_epoch_{}_Acc_{:.3f}_{:.3f}.cslsfb".
                format(seed, params.dico_build, stage, epoch_fb,
                       acc_epochs[epoch_fb], csls_fb), 'w')
            fp.close()

            # Save the plot for discriminator accuracy and generator loss
            # fig = plt.figure()
            # plt.plot(range(0, len(D_A_acc_epochs)), D_A_acc_epochs, color='b', label='D_A')
            # plt.plot(range(0, len(D_B_acc_epochs)), D_B_acc_epochs, color='r', label='D_B')
            # plt.ylabel('D_accuracy')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(self.tune_dir + '/seed_{}_stage_{}_D_acc.png'.format(seed, stage))
            #
            # fig = plt.figure()
            # plt.plot(range(0, len(D_A_loss_epochs)), D_A_loss_epochs, color='b', label='D_A')
            # plt.plot(range(0, len(D_B_loss_epochs)), D_B_loss_epochs, color='r', label='D_B')
            # plt.ylabel('D_losses')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(self.tune_dir + '/seed_{}_stage_{}_D_loss.png'.format(seed, stage))
            #
            # fig = plt.figure()
            # plt.plot(range(0, len(G_AB_loss_epochs)), G_AB_loss_epochs, color='b', label='G_AB')
            # plt.plot(range(0, len(G_BA_loss_epochs)), G_BA_loss_epochs, color='r', label='G_BA')
            # plt.ylabel('G_losses')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(self.tune_dir + '/seed_{}_stage_{}_G_loss.png'.format(seed,stage))
            #
            # fig = plt.figure()
            # plt.plot(range(0, len(G_AB_recon_epochs)), G_AB_recon_epochs, color='b', label='G_AB')
            # plt.plot(range(0, len(G_BA_recon_epochs)), G_BA_recon_epochs, color='r', label='G_BA')
            # plt.ylabel('G_recon_loss')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(self.tune_dir + '/seed_{}_stage_{}_G_Recon.png'.format(seed,stage))

            # fig = plt.figure()
            # plt.plot(range(0, len(L_Z_loss_epoches)), L_Z_loss_epoches, color='b', label='L_Z')
            # plt.ylabel('L_Z_loss')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(tune_dir + '/seed_{}_stage_{}_L_Z.png'.format(seed,stage))

            fig = plt.figure()
            plt.plot(range(0, len(acc_epochs)),
                     acc_epochs,
                     color='b',
                     label='trans_acc1')
            plt.ylabel('trans_acc')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir +
                        '/seed_{}_stage_{}_trans_acc.png'.format(seed, stage))

            fig = plt.figure()
            plt.plot(range(0, len(csls_epochs)),
                     csls_epochs,
                     color='b',
                     label='csls')
            plt.ylabel('csls')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir +
                        '/seed_{}_stage_{}_csls.png'.format(seed, stage))

            fig = plt.figure()
            plt.plot(range(0, len(g_loss_epochs)),
                     g_loss_epochs,
                     color='b',
                     label='G_loss')
            plt.ylabel('g_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir +
                        '/seed_{}_g_stage_{}_loss.png'.format(seed, stage))

            fig = plt.figure()
            plt.plot(range(0, len(d_loss_epochs)),
                     d_loss_epochs,
                     color='b',
                     label='D_loss')
            plt.ylabel('D_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir +
                        '/seed_{}_stage_{}_d_loss.png'.format(seed, stage))
            plt.close('all')

        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(self.X_AE.state_dict(), 'X_model_interrupt.t7')
            torch.save(self.Y_AE.state_dict(), 'Y_model_interrupt.t7')
            torch.save(self.D.state_dict(), 'd_model_interrupt.t7')
            log_file.close()
            exit()

        log_file.close()
        return
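A note on the discriminator targets built at the top of each epoch above: the first half of the concatenated batch (the target-side codes Y_Z) is labelled 1 - smoothing and the second half (the source-side codes X_Z) is labelled smoothing, i.e. one-sided label smoothing on both classes, while the generator trains against the flipped targets label_G / label_G2. A minimal self-contained sketch of that construction (mini_batch_size and smoothing are illustrative assumptions, not values from the example):

import torch

mini_batch_size = 4  # assumed value for illustration
smoothing = 0.1      # assumed value for illustration

# One target vector for the concatenated [Y_Z, X_Z] batch fed to D.
label_D = torch.zeros(2 * mini_batch_size)
label_D[:mini_batch_size] = 1 - smoothing  # first half treated as the "real" class
label_D[mini_batch_size:] = smoothing      # second half treated as the "fake" class

# The generator is trained against the flipped targets.
label_G = torch.zeros(mini_batch_size) + (1 - smoothing)
label_G2 = torch.zeros(mini_batch_size) + smoothing

print(label_D)  # 0.9 for the first half, 0.1 for the second half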
Example #3
    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed):
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            print("Data path doesn't exists: %s" % params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)
        if not os.path.exists(self.tune_export_dir):
            os.makedirs(self.tune_export_dir)

        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]

        en = src_emb
        it = tgt_emb

        params = _get_eval_params(params)
        self.params = params
        eval = Evaluator(params, en, it, torch.cuda.is_available())

        # for seed_index in range(params.num_random_seeds):

        AE_optimizer = optim.SGD(filter(
            lambda p: p.requires_grad,
            list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),
                                 lr=params.g_learning_rate)
        # AE_optimizer = optim.SGD(G_params, lr=0.1, momentum=0.9)
        # AE_optimizer = optim.Adam(G_params, lr=params.g_learning_rate, betas=(0.9, 0.9))
        # AE_optimizer = optim.RMSprop(filter(lambda p: p.requires_grad, list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),lr=params.g_learning_rate,alpha=0.9)
        D_optimizer = optim.SGD(list(self.D_X.parameters()) +
                                list(self.D_Y.parameters()),
                                lr=params.d_learning_rate)
        # D_optimizer = optim.Adam(D_params, lr=params.d_learning_rate, betas=(0.5, 0.9))
        # D_optimizer = optim.RMSprop(list(self.D_X.parameters()) + list(self.D_Y.parameters()), lr=params.d_learning_rate , alpha=0.9)

        # D_X=nn.DataParallel(D_X)
        # D_Y=nn.DataParallel(D_Y)
        # true_dict = get_true_dict(params.data_dir)
        D_A_acc_epochs = []
        D_B_acc_epochs = []
        D_A_loss_epochs = []
        D_B_loss_epochs = []
        G_AB_loss_epochs = []
        G_BA_loss_epochs = []
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        L_Z_loss_epoches = []

        acc1_epochs = []
        acc2_epochs = []

        csls_epochs = []
        f_csls_epochs = []
        b_csls_epochs = []
        best_valid_metric = -100

        # logs for plotting later
        log_file = open(
            "log_src_tgt.txt",
            "w")  # Being overwritten in every loop, not really required
        log_file.write("epoch, dis_loss, dis_acc, g_loss\n")

        try:
            for epoch in range(self.params.num_epochs):
                D_A_losses = []
                D_B_losses = []
                G_AB_losses = []
                G_AB_recon = []
                G_BA_losses = []
                G_adv_losses = []
                G_BA_recon = []
                L_Z_losses = []
                d_losses = []
                g_losses = []
                hit_A = 0
                hit_B = 0
                total = 0
                start_time = timer()
                # lowest_loss = 1e5
                # label_D = to_variable(torch.FloatTensor(2 * params.mini_batch_size).zero_())
                label_D = to_variable(
                    torch.FloatTensor(2 * params.mini_batch_size).zero_())
                label_D[:params.mini_batch_size] = 1 - params.smoothing
                label_D[params.mini_batch_size:] = params.smoothing

                label_G = to_variable(
                    torch.FloatTensor(params.mini_batch_size).zero_())
                label_G = label_G + 1 - params.smoothing

                for mini_batch in range(
                        0, params.iters_in_epoch // params.mini_batch_size):
                    for d_index in range(params.d_steps):
                        D_optimizer.zero_grad()  # Reset the gradients
                        self.D_X.train()
                        self.D_Y.train()

                        #print('D_X:', self.D_X.map1.weight.data)
                        #print('D_Y:', self.D_Y.map1.weight.data)

                        view_X, view_Y = self.get_batch_data_fast_new(en, it)
                        # Discriminator X
                        #print('View_Y',view_Y)
                        fake_X = self.Y_AE.encode(view_Y).detach()
                        #print('fakeX',fake_X)
                        input = torch.cat([view_X, fake_X], 0)

                        pred_A = self.D_X(input)
                        #print('Pred_A',pred_A)
                        D_A_loss = self.loss_fn(pred_A, label_D)
                        # print(view_Y)
                        # Discriminator Y
                        # print('View_X',view_X)
                        fake_Y = self.X_AE.encode(view_X).detach()
                        # print('fakeY:',fake_Y)

                        input = torch.cat([view_Y, fake_Y], 0)
                        pred_B = self.D_Y(input)
                        # print('Pred_B', pred_B)
                        D_B_loss = self.loss_fn(pred_B, label_D)

                        D_loss = (1.0) * D_A_loss + params.gate * D_B_loss

                        D_loss.backward(
                        )  # compute/store gradients, but don't change params
                        d_losses.append(to_numpy(D_loss.data))
                        D_A_losses.append(to_numpy(D_A_loss.data))
                        D_B_losses.append(to_numpy(D_B_loss.data))

                        discriminator_decision_A = to_numpy(pred_A.data)
                        hit_A += np.sum(
                            discriminator_decision_A[:params.mini_batch_size]
                            >= 0.5)
                        hit_A += np.sum(
                            discriminator_decision_A[params.mini_batch_size:] <
                            0.5)

                        discriminator_decision_B = to_numpy(pred_B.data)
                        hit_B += np.sum(
                            discriminator_decision_B[:params.mini_batch_size]
                            >= 0.5)
                        hit_B += np.sum(
                            discriminator_decision_B[params.mini_batch_size:] <
                            0.5)

                        D_optimizer.step(
                        )  # Only optimizes D's parameters; changes based on stored gradients from backward()

                        # Clip weights
                        _clip(self.D_X, params.clip_value)
                        _clip(self.D_Y, params.clip_value)
                        # print('D_loss',d_losses)

                        sys.stdout.write(
                            "[%d/%d] :: Discriminator Loss: %.3f \r" %
                            (mini_batch,
                             params.iters_in_epoch // params.mini_batch_size,
                             np.asscalar(np.mean(d_losses))))
                        sys.stdout.flush()

                    total += 2 * params.mini_batch_size * params.d_steps

                    for g_index in range(params.g_steps):
                        # 2. Train G on D's response (but DO NOT train D on these labels)
                        AE_optimizer.zero_grad()
                        self.D_X.eval()
                        self.D_Y.eval()
                        view_X, view_Y = self.get_batch_data_fast_new(en, it)

                        # Generator X_AE
                        ## adversarial loss
                        Y_fake = self.X_AE.encode(view_X)
                        # X_recon = self.X_AE.decode(X_Z)
                        # Y_fake = self.Y_AE.encode(X_Z)
                        pred_Y = self.D_Y(Y_fake)
                        L_adv_X = self.loss_fn(pred_Y, label_G)

                        X_Cycle = self.Y_AE.encode(Y_fake)
                        L_Cycle_X = 1.0 - torch.mean(
                            self.loss_fn2(view_X, X_Cycle))

                        # L_recon_X = 1.0 - torch.mean(self.loss_fn2(view_X, X_recon))
                        # L_G_AB = L_adv_X + params.recon_weight * L_recon_X

                        # Generator Y_AE
                        # adversarial loss
                        X_fake = self.Y_AE.encode(view_Y)
                        pred_X = self.D_X(X_fake)
                        L_adv_Y = self.loss_fn(pred_X, label_G)

                        ### Cycle Loss
                        Y_Cycle = self.X_AE.encode(X_fake)
                        L_Cycle_Y = 1.0 - torch.mean(
                            self.loss_fn2(view_Y, Y_Cycle))

                        # L_recon_Y = 1.0 - torch.mean(self.loss_fn2(view_Y, Y_recon))
                        # L_G_BA = L_adv_Y + params.recon_weight * L_recon_Y
                        # L_Z = 1.0 - torch.mean(self.loss_fn2(X_Z, Y_Z))

                        # G_loss = L_G_AB + L_G_BA + L_Z
                        G_loss = params.adv_weight * ( params.gate * L_adv_X + (1.0) * L_adv_Y) + \
                                 params.cycle_weight * (L_Cycle_X+L_Cycle_Y)

                        G_loss.backward()

                        g_losses.append(to_numpy(G_loss.data))
                        G_AB_losses.append(to_numpy(L_adv_X.data))
                        G_BA_losses.append(to_numpy(L_adv_Y.data))
                        G_adv_losses.append(to_numpy(L_adv_Y.data))
                        G_AB_recon.append(to_numpy(L_Cycle_X.data))
                        G_BA_recon.append(to_numpy(L_Cycle_Y.data))

                        AE_optimizer.step()  # Only optimizes G's parameters
                        self.orthogonalize(self.X_AE.map1.weight.data)
                        self.orthogonalize(self.Y_AE.map1.weight.data)

                        sys.stdout.write(
                            "[%d/%d] ::                                     Generator Loss: %.3f \r"
                            % (mini_batch,
                               params.iters_in_epoch // params.mini_batch_size,
                               np.asscalar(np.mean(g_losses))))
                        sys.stdout.flush()
                '''for each epoch'''
                D_A_acc_epochs.append(hit_A / total)
                D_B_acc_epochs.append(hit_B / total)
                G_AB_loss_epochs.append(np.asscalar(np.mean(G_AB_losses)))
                G_BA_loss_epochs.append(np.asscalar(np.mean(G_BA_losses)))
                D_A_loss_epochs.append(np.asscalar(np.mean(D_A_losses)))
                D_B_loss_epochs.append(np.asscalar(np.mean(D_B_losses)))
                G_AB_recon_epochs.append(np.asscalar(np.mean(G_AB_recon)))
                G_BA_recon_epochs.append(np.asscalar(np.mean(G_BA_recon)))
                # L_Z_loss_epoches.append(np.asscalar(np.mean(L_Z_losses)))

                print(
                    "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins"
                    .format(epoch, np.asscalar(np.mean(d_losses)),
                            0.5 * (hit_A + hit_B) / total,
                            np.asscalar(np.mean(g_losses)),
                            (timer() - start_time) / 60))

                # lr decay
                # g_optim_state = AE_optimizer.state_dict()
                # old_lr = g_optim_state['param_groups'][0]['lr']
                # g_optim_state['param_groups'][0]['lr'] = max(old_lr * params.lr_decay, params.lr_min)
                # AE_optimizer.load_state_dict(g_optim_state)
                # print("Changing the learning rate: {} -> {}".format(old_lr, g_optim_state['param_groups'][0]['lr']))
                # d_optim_state = D_optimizer.state_dict()
                # d_optim_state['param_groups'][0]['lr'] = max(
                #     d_optim_state['param_groups'][0]['lr'] * params.lr_decay, params.lr_min)
                # D_optimizer.load_state_dict(d_optim_state)

                if (epoch + 1) % params.print_every == 0:
                    # No need for discriminator weights
                    # torch.save(d.state_dict(), 'discriminator_weights_en_es_{}.t7'.format(epoch))

                    # all_precisions = eval.get_all_precisions(G_AB(src_emb.weight).data)
                    Vec_xy = self.X_AE.encode(Variable(en))
                    Vec_xyx = self.Y_AE.encode(Vec_xy)
                    Vec_yx = self.Y_AE.encode(Variable(it))
                    Vec_yxy = self.X_AE.encode(Vec_yx)

                    mstart_time = timer()

                    # for method in ['csls_knn_10']:
                    for method in [params.eval_method]:
                        results = get_word_translation_accuracy(
                            params.src_lang,
                            src_word2id,
                            Vec_xy.data,
                            params.tgt_lang,
                            tgt_word2id,
                            it,
                            method=method,
                            dico_eval=self.eval_file,
                            device=params.cuda_device)
                        acc1 = results[0][1]
                        results = get_word_translation_accuracy(
                            params.tgt_lang,
                            tgt_word2id,
                            Vec_yx.data,
                            params.src_lang,
                            src_word2id,
                            en,
                            method=method,
                            dico_eval=self.eval_file2,
                            device=params.cuda_device)
                        acc2 = results[0][1]
                        print('{} takes {:.2f}s'.format(
                            method,
                            timer() - mstart_time))
                        print('Method:{} test_score:{:.4f}-{:.4f}'.format(
                            method, acc1, acc2))
                    '''
                    # for method in ['csls_knn_10']:
                    for method in [params.eval_method]:
                        results = get_word_translation_accuracy(
                            params.src_lang, src_word2id, Vec_xyx.data,
                            params.src_lang, src_word2id, en,
                            method=method,
                            dico_eval='/data/dictionaries/{}-{}.wacky.dict'.format(params.src_lang,params.src_lang),
                            device=params.cuda_device
                        )
                        acc11 = results[0][1]
                    # for method in ['csls_knn_10']:
                    for method in [params.eval_method]:
                        results = get_word_translation_accuracy(
                            params.tgt_lang, tgt_word2id, Vec_yxy.data,
                            params.tgt_lang, tgt_word2id, it,
                            method=method,
                            dico_eval='/data/dictionaries/{}-{}.wacky.dict'.format(params.tgt_lang,params.tgt_lang),
                            device=params.cuda_device
                        )
                        acc22 = results[0][1]
                    print('Valid:{} score:{:.4f}-{:.4f}'.format(method, acc11, acc22))
                    avg_valid = (acc11+acc22)/2.0
                    # valid_x = torch.mean(self.loss_fn2(en, Vec_xyx.data))
                    # valid_y = torch.mean(self.loss_fn2(it, Vec_yxy.data))
                    # avg_valid = (valid_x+valid_y)/2.0
                    '''
                    # csls = 0
                    f_csls = eval.dist_mean_cosine(Vec_xy.data, it)
                    b_csls = eval.dist_mean_cosine(Vec_yx.data, en)
                    csls = (f_csls + b_csls) / 2.0
                    # csls = eval.calc_unsupervised_criterion(X_Z)
                    if csls > best_valid_metric:
                        print("New csls value: {}".format(csls))
                        best_valid_metric = csls
                        fp = open(
                            self.tune_dir +
                            "/best/seed_{}_dico_{}_epoch_{}_acc_{:.3f}-{:.3f}.tmp"
                            .format(seed, params.dico_build, epoch, acc1,
                                    acc2), 'w')
                        fp.close()
                        torch.save(
                            self.X_AE.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_X.t7'.format(
                                seed, params.dico_build))
                        torch.save(
                            self.Y_AE.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_Y.t7'.format(
                                seed, params.dico_build))
                        torch.save(
                            self.D_X.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_Dx.t7'.format(
                                seed, params.dico_build))
                        torch.save(
                            self.D_Y.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_Dy.t7'.format(
                                seed, params.dico_build))
                    # print(json.dumps(all_precisions))
                    # p_1 = all_precisions['validation']['adv']['without-ref']['nn'][1]
                    # p_1 = all_precisions['validation']['adv']['without-ref']['csls'][1]
                    # log_file.write(str(results) + "\n")
                    # print('Method: nn score:{:.4f}'.format(acc))
                    # Saving generator weights
                    # torch.save(X_AE.state_dict(), tune_dir+'/G_AB_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(seed,params.most_frequent_sampling_size,params.g_learning_rate,acc))
                    # torch.save(Y_AE.state_dict(), tune_dir+'/G_BA_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(seed,params.most_frequent_sampling_size,params.g_learning_rate,acc))
                    fp = open(
                        self.tune_dir +
                        "/seed_{}_epoch_{}_acc_{:.3f}-{:.3f}_valid_{:.4f}.tmp".
                        format(seed, epoch, acc1, acc2, csls), 'w')
                    fp.close()
                    acc1_epochs.append(acc1)
                    acc2_epochs.append(acc2)
                    csls_epochs.append(csls)
                    f_csls_epochs.append(f_csls)
                    b_csls_epochs.append(b_csls)

            csls_fb, epoch_fb = max([
                (score, index) for index, score in enumerate(csls_epochs)
            ])
            fp = open(
                self.tune_dir +
                "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsfb".format(
                    seed, epoch_fb, acc1_epochs[epoch_fb],
                    acc2_epochs[epoch_fb], csls_fb), 'w')
            fp.close()
            csls_f, epoch_f = max([
                (score, index) for index, score in enumerate(f_csls_epochs)
            ])
            fp = open(
                self.tune_dir +
                "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsf".format(
                    seed, epoch_f, acc1_epochs[epoch_f], acc2_epochs[epoch_f],
                    csls_f), 'w')
            fp.close()
            csls_b, epoch_b = max([
                (score, index) for index, score in enumerate(b_csls_epochs)
            ])
            fp = open(
                self.tune_dir +
                "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsb".format(
                    seed, epoch_b, acc1_epochs[epoch_b], acc2_epochs[epoch_b],
                    csls_b), 'w')
            fp.close()
            '''

            # Save the plot for discriminator accuracy and generator loss
            fig = plt.figure()
            plt.plot(range(0, len(D_A_acc_epochs)), D_A_acc_epochs, color='b', label='D_A')
            plt.plot(range(0, len(D_B_acc_epochs)), D_B_acc_epochs, color='r', label='D_B')
            plt.ylabel('D_accuracy')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_acc.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(D_A_loss_epochs)), D_A_loss_epochs, color='b', label='D_A')
            plt.plot(range(0, len(D_B_loss_epochs)), D_B_loss_epochs, color='r', label='D_B')
            plt.ylabel('D_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_loss_epochs)), G_AB_loss_epochs, color='b', label='G_AB')
            plt.plot(range(0, len(G_BA_loss_epochs)), G_BA_loss_epochs, color='r', label='G_BA')
            plt.ylabel('G_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_recon_epochs)), G_AB_recon_epochs, color='b', label='G_AB')
            plt.plot(range(0, len(G_BA_recon_epochs)), G_BA_recon_epochs, color='r', label='G_BA')
            plt.ylabel('G_Cycle_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_Cycle.png'.format(seed))

            # fig = plt.figure()
            # plt.plot(range(0, len(L_Z_loss_epoches)), L_Z_loss_epoches, color='b', label='L_Z')
            # plt.ylabel('L_Z_loss')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(tune_dir + '/seed_{}_stage_{}_L_Z.png'.format(seed,stage))

            fig = plt.figure()
            plt.plot(range(0, len(acc1_epochs)), acc1_epochs, color='b', label='trans_acc1')
            plt.plot(range(0, len(acc2_epochs)), acc2_epochs, color='r', label='trans_acc2')
            plt.ylabel('trans_acc')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_trans_acc.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(csls_epochs)), csls_epochs, color='b', label='csls')
            plt.plot(range(0, len(f_csls_epochs)), f_csls_epochs, color='r', label='csls_f')
            plt.plot(range(0, len(b_csls_epochs)), b_csls_epochs, color='g', label='csls_b')
            plt.ylabel('csls')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_csls.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(g_losses)), g_losses, color='b', label='G_loss')
            plt.ylabel('g_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_g_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(d_losses)), d_losses, color='b', label='csls')
            plt.ylabel('D_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_d_loss.png'.format(seed))
            plt.close('all')
            '''

        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(self.X_AE.state_dict(), 'g_model_interrupt.t7')
            torch.save(self.D_X.state_dict(), 'd_model_interrupt.t7')
            log_file.close()
            exit()

        log_file.close()
        return self.X_AE
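Examples #2 and #3 call a helper `_clip(self.D, params.clip_value)` (or `_clip(self.D_X, ...)`) after each discriminator update, but its definition is not shown here. A plausible sketch, assuming it performs WGAN-style weight clamping on the discriminator parameters (the actual helper in the project may differ):

import torch.nn as nn

def _clip(module: nn.Module, clip_value: float):
    """Clamp every parameter of `module` into [-clip_value, clip_value] in place.

    Mirrors how the examples above call _clip(...) right after D_optimizer.step();
    a non-positive clip_value disables clipping.
    """
    if clip_value > 0:
        for p in module.parameters():
            p.data.clamp_(-clip_value, clip_value)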
Example #4
class Trainer(object):
    def __init__(self, args):
        np.random.seed(args.seed)

        self.args = args

        self.logger = logger.Logger(args.output_dir)
        self.args.logger = self.logger

        current_commit_hash =\
            subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
        self.logger.log('current git commit hash: %s' % current_commit_hash)

        print('load vec')
        source_vecs, source_dico =\
            utils.load_word_vec_list(args.source_vec_file, args.source_lang)
        target_vecs, target_dico =\
            utils.load_word_vec_list(args.target_vec_file, args.target_lang)

        self.src_dico = source_dico
        self.tgt_dico = target_dico
        args.src_dico = source_dico
        args.tgt_dico = target_dico

        src_embed, tgt_embed =\
            utils.get_embeds_from_numpy(source_vecs, target_vecs)
        if args.use_cuda:
            self.src_embed = src_embed.cuda()
            self.tgt_embed = tgt_embed.cuda()
        else:
            self.src_embed = src_embed
            self.tgt_embed = tgt_embed

        print('setting models')
        netD = model.netD(self.args)
        netG = model.netG()
        netG.W.weight.data.copy_(torch.diag(torch.ones(300)))
        if args.multi_gpu:
            netD = nn.DataParallel(netD)
            netG = nn.DataParallel(netG)
        if args.use_cuda:
            netD = netD.cuda()
            netG = netG.cuda()
        self.netD = netD
        self.netG = netG
        self.optimizer_D = optim.Adam(self.netD.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, 0.999))
        self.optimizer_G = optim.Adam(self.netG.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, 0.999))
        self.criterion = nn.BCELoss()
        self.prefix = os.path.basename(args.output_dir)

        self.evaluator = Evaluator(self)

    def train(self):
        args = self.args

        for i_epoch in range(1, args.epoch + 1):

            error_D_list = []
            error_G_list = []

            for niter in tqdm(range(10000)):

                for _ in range(args.dis_step):
                    error_D = self.train_D()
                    error_D_list.append(error_D)

                error_G = self.train_G()

                error_G_list.append(error_G)

                if niter % 500 == 0:
                    print('error_D: ', np.mean(error_D_list))
                    print('error_G: ', np.mean(error_G_list))
                    if args.use_criteria:
                        print('dist cosine mean: ',
                              self.evaluator.dist_mean_cosine())
                        print('calculating word translation accuracy...')
                        self.evaluator.word_translation()

            result_ = {
                'epoch': i_epoch,
                'error_D': np.mean(error_D_list),
                'error_G': np.mean(error_G_list)
            }
            self.logger.dump(result_)
            if i_epoch % args.log_inter == 0:
                progress_path = os.path.join(args.output_dir, 'progress.json')
                imgpaths = slack_utils.output_progress(progress_path,
                                                       args.output_dir,
                                                       self.prefix)
                if args.slack_output:
                    for imgpath in imgpaths:
                        slack_utils.send_slack_img(imgpath)

    def get_batch_for_disc(self, volatile):
        args = self.args
        batch_size = args.batch_size
        src_ids = torch.LongTensor(batch_size).random_(75000)
        tgt_ids = torch.LongTensor(batch_size).random_(75000)
        if args.use_cuda:
            src_ids = src_ids.cuda()
            tgt_ids = tgt_ids.cuda()
        src_embed = self.src_embed(Variable(src_ids, volatile=True))
        tgt_embed = self.tgt_embed(Variable(tgt_ids, volatile=True))
        src_embed = self.netG(Variable(src_embed.data, volatile=volatile))
        tgt_embed = Variable(tgt_embed.data, volatile=volatile)

        x = torch.cat([src_embed, tgt_embed], 0)
        y = torch.FloatTensor(2 * batch_size).zero_()
        y[:batch_size] = 1 - 0.1
        y[batch_size:] = 0.1
        y = Variable(y.cuda() if args.use_cuda else y)
        return x, y

    def train_D(self):
        self.netD.train()
        self.netG.eval()

        x, y = self.get_batch_for_disc(volatile=True)
        preds = self.netD(Variable(x.data))
        loss = self.criterion(preds, y)
        self.optimizer_D.zero_grad()
        loss.backward()
        self.optimizer_D.step()
        return loss.data[0]

    def train_G(self):
        self.netD.eval()
        self.netG.train()

        x, y = self.get_batch_for_disc(volatile=False)
        preds = self.netD(x)
        loss = self.criterion(preds, 1 - y)

        self.optimizer_G.zero_grad()
        loss.backward()
        self.optimizer_G.step()
        self.orthogonalize()

        return loss.data[0]

    def build_dictionary(self):
        src_emb = self.netG(self.src_embed.weight).data
        tgt_emb = self.tgt_embed.weight.data
        src_emb = src_emb / src_emb.norm(2, 1, keepdim=True).expand_as(src_emb)
        tgt_emb = tgt_emb / tgt_emb.norm(2, 1, keepdim=True).expand_as(tgt_emb)
        self.dico = build_dictionary(src_emb, tgt_emb, self.args)

    def orthogonalize(self):
        if self.args.map_beta > 0:
            W = self.netG.W.weight.data
            beta = self.args.map_beta
            W.copy_((1 + beta) * W - beta * W.mm(W.transpose(0, 1).mm(W)))

    def save_netG_state(self):
        multi_gpu = self.args.multi_gpu
        odir = self.args.output_dir
        fpath = os.path.join(odir, 'netG_state.pth')
        self.logger.log('saving netG state to %s' % fpath)
        if multi_gpu:
            state_dict = self.netG.module.state_dict()
        else:
            state_dict = self.netG.state_dict()
        torch.save(state_dict, fpath)
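Example #4's orthogonalize applies the update W <- (1 + beta) * W - beta * W (W^T W) after every generator step, which keeps the linear mapping close to orthogonal. A small self-contained check of that behaviour; beta, the matrix size, the perturbation scale, and the iteration count are illustrative assumptions rather than values taken from the example:

import torch

beta = 0.01  # assumed stand-in for args.map_beta

# Start from an orthogonal matrix (the identity) nudged by a small random
# perturbation, mimicking the drift caused by one SGD step on netG.W.
W = torch.eye(300) + 0.001 * torch.randn(300, 300)

print('before:', (W.t().mm(W) - torch.eye(300)).abs().max())
for _ in range(200):
    # Same update as Trainer.orthogonalize above: W <- (1 + beta) W - beta W (W^T W)
    W = (1 + beta) * W - beta * W.mm(W.t().mm(W))
print('after :', (W.t().mm(W) - torch.eye(300)).abs().max())

The deviation of W^T W from the identity shrinks across iterations, which is why a single such update per training step is enough to act as a soft orthogonality constraint.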