Пример #1
0
    def dist_mean_cosine(self, src_emb, tgt_emb):
        """
        Mean-cosine model selection criterion.

        Both embedding matrices are divided row-wise by their L2 norm, a
        dictionary is induced from them with ``self.params``, and the mean
        dot product over (at most) the first 10000 induced pairs is printed
        and returned. When no dictionary is produced (``None``), the score
        is the sentinel ``-1e9``.
        """
        # Label used only in the printed report; the method actually applied
        # is whatever ``self.params`` configures.
        method_label = 'csls_knn_10'
        max_pairs = 10000

        # unit-length rows, so the row-wise dot product below is a cosine
        src_unit = src_emb / src_emb.norm(2, 1, keepdim=True).expand_as(src_emb)
        tgt_unit = tgt_emb / tgt_emb.norm(2, 1, keepdim=True).expand_as(tgt_emb)

        cfg = self.params
        forward = get_candidates(src_unit, tgt_unit, cfg)
        backward = get_candidates(tgt_unit, src_unit, cfg)
        pairs = build_dictionary(src_unit, tgt_unit, cfg, forward, backward)

        if pairs is not None:
            left = src_unit[pairs[:max_pairs, 0]]
            right = tgt_unit[pairs[:max_pairs, 1]]
            score = (left * right).sum(1).mean()
        else:
            score = -1e9

        report = ("Mean cosine (%s method, %s build, %i max size): %.5f" %
                  (method_label, cfg.dico_build, max_pairs, score))
        print(report)
        return score
Пример #2
0
 def build_dictionary(self):
     """
     Induce a dictionary from the aligned embeddings and keep it on the
     instance.

     The source embedding weights are first passed through ``self.mapping``;
     the target embedding weights are used as they are. The result of the
     module-level ``build_dictionary`` is stored in ``self.dico``.
     """
     mapped_src = self.mapping(self.src_emb.weight).data
     target = self.tgt_emb.weight.data
     induced = build_dictionary(mapped_src, target, self.params)
     self.dico = induced
Пример #3
0
 def build_dictionary(self):
     """
     Induce a dictionary from the row-normalized aligned embeddings.

     The source embedding weights are passed through ``self.mapping``; both
     sides are then divided row-wise by their L2 norm before being handed to
     the module-level ``build_dictionary``. The result is stored in
     ``self.dico``.
     """
     def unit_rows(matrix):
         # divide every row by its own 2-norm (computed along dim 1)
         return matrix / matrix.norm(2, 1, keepdim=True).expand_as(matrix)

     raw_src = self.mapping(self.src_emb.weight).data
     raw_tgt = self.tgt_emb.weight.data
     self.dico = build_dictionary(unit_rows(raw_src), unit_rows(raw_tgt),
                                  self.params)
Пример #4
0
    def dist_mean_cosine(self, to_log):
        """
        Mean-cosine model selection criterion.

        For each of the 'nn' and 'csls_knn_10' dictionary methods, a
        dictionary is induced from the mapped, row-normalized embeddings
        using a private copy of ``self.params``, and the mean dot product of
        (at most) the first 10000 induced pairs is logged and stored in
        ``to_log`` under ``'mean_cosine-<method>-<build>-<max size>'``.
        When no dictionary is produced the stored value is ``-1e9``.
        """
        max_pairs = 10000
        # settings forced on the temporary parameter copy for every method
        fixed_overrides = (
            ('dico_build', 'S2T'),
            ('dico_threshold', 0),
            ('dico_max_rank', 0),
            ('dico_min_size', 0),
            ('dico_max_size', max_pairs),
        )

        # mapped embeddings, scaled to unit-length rows
        src_unit = self.src_mapping(self.src_emb.weight, self.src_adj).data
        tgt_unit = self.tgt_mapping(self.tgt_emb.weight, self.tgt_adj).data
        src_unit = src_unit / src_unit.norm(2, 1, keepdim=True).expand_as(src_unit)
        tgt_unit = tgt_unit / tgt_unit.norm(2, 1, keepdim=True).expand_as(tgt_unit)

        for dico_method in ('nn', 'csls_knn_10'):
            # work on a copy so the instance parameters are left untouched
            cfg = deepcopy(self.params)
            cfg.dico_method = dico_method
            for attr, value in fixed_overrides:
                setattr(cfg, attr, value)

            forward = get_candidates(src_unit, tgt_unit, cfg)
            backward = get_candidates(tgt_unit, src_unit, cfg)
            pairs = build_dictionary(src_unit, tgt_unit, cfg, forward,
                                     backward)

            if pairs is not None:
                left = src_unit[pairs[:max_pairs, 0]]
                right = tgt_unit[pairs[:max_pairs, 1]]
                score = (left * right).sum(1).mean()
            else:
                score = -1e9
            # store a plain Python number rather than a tensor
            if isinstance(score, torch_tensor):
                score = score.item()

            message = ("Mean cosine (%s method, %s build, %i max size): %.5f" %
                       (dico_method, cfg.dico_build, max_pairs, score))
            logger.info(message)
            log_key = 'mean_cosine-%s-%s-%i' % (dico_method, cfg.dico_build,
                                                max_pairs)
            to_log[log_key] = score
Пример #5
0
    def export(self,
               src_dico,
               tgt_dico,
               emb_en,
               emb_it,
               seed,
               export_emb=False):
        """
        Evaluate the saved best auto-encoders for ``seed`` and export the
        dictionaries they induce.

        Steps, in order:
          1. load the ``best_X`` / ``best_Y`` checkpoints from
             ``<tune_dir>/best`` and encode ``emb_en`` / ``emb_it``;
          2. print the word-translation accuracy of the 'nn' and
             'csls_knn_10' methods in both directions, then the averaged
             forward/backward mean-cosine score;
          3. induce a forward and a backward dictionary and write them,
             their intersection and their union (each in both directions)
             under ``<tune_dir>/export``;
          4. when ``export_emb`` is true, also export the four embedding
             matrices in text format.

        :param src_dico: indexable; ``src_dico[0]`` is indexed with induced
            indices to obtain the text written to the files, ``src_dico[1]``
            is passed to the accuracy evaluation.
        :param tgt_dico: same layout as ``src_dico``, for the target side.
        :param emb_en: source-side embeddings (encoded by the X model).
        :param emb_it: target-side embeddings (encoded by the Y model).
        :param seed: value used in the checkpoint file names.
        :param export_emb: whether to export the embeddings as well.
        :return: None.
        """

        def _write_pairs(first_lang, second_lang, suffix, pairs, first_words,
                         second_words):
            # Write one "<first word> <second word>" line per index pair.
            path = self.tune_dir + '/export/{}-{}.{}'.format(
                first_lang, second_lang, suffix)
            with io.open(path, 'w', encoding='utf-8', newline='\n') as f:
                for first, second in pairs:
                    f.write('{} {}\n'.format(first_words[first],
                                             second_words[second]))

        def _to_pairs(dico):
            # Convert an induced dictionary to a NumPy array of index pairs.
            # A None dictionary is treated as "no pairs induced" instead of
            # crashing on ``.cpu()`` -- NOTE(review): confirm that
            # build_dictionary returns None for an empty result.
            return [] if dico is None else dico.cpu().numpy()

        params = _get_eval_params(self.params)
        evaluator = Evaluator(params, emb_en, emb_it,
                              torch.cuda.is_available())
        # Export adversarial dictionaries
        # NOTE(review): the models are moved to CUDA unconditionally although
        # the evaluator is told whether CUDA is available -- confirm that
        # CPU-only runs are not expected here.
        optim_X_AE = AE(params).cuda()
        optim_Y_AE = AE(params).cuda()
        print('Loading pre-trained models...')
        optim_X_AE.load_state_dict(
            torch.load(self.tune_dir +
                       '/best/seed_{}_dico_{}_best_X.t7'.format(
                           seed, params.dico_build)))
        optim_Y_AE.load_state_dict(
            torch.load(self.tune_dir +
                       '/best/seed_{}_dico_{}_best_Y.t7'.format(
                           seed, params.dico_build)))
        X_Z = optim_X_AE.encode(Variable(emb_en)).data
        Y_Z = optim_Y_AE.encode(Variable(emb_it)).data

        for method in ['nn', 'csls_knn_10']:
            # Restart the clock for every method so that the reported time
            # covers this method only, not the previous ones as well.
            mstart_time = timer()
            results = get_word_translation_accuracy(params.src_lang,
                                                    src_dico[1],
                                                    X_Z,
                                                    params.tgt_lang,
                                                    tgt_dico[1],
                                                    emb_it,
                                                    method=method,
                                                    dico_eval=self.eval_file,
                                                    device=params.cuda_device)
            acc1 = results[0][1]
            results = get_word_translation_accuracy(params.tgt_lang,
                                                    tgt_dico[1],
                                                    Y_Z,
                                                    params.src_lang,
                                                    src_dico[1],
                                                    emb_en,
                                                    method=method,
                                                    dico_eval=self.eval_file2,
                                                    device=params.cuda_device)
            acc2 = results[0][1]

            print('{} takes {:.2f}s'.format(method, timer() - mstart_time))
            print('Method:{} score:{:.4f}-{:.4f}'.format(method, acc1, acc2))

        # acc1 / acc2 below are those of the last method ('csls_knn_10').
        f_csls = evaluator.dist_mean_cosine(X_Z, emb_it)
        b_csls = evaluator.dist_mean_cosine(Y_Z, emb_en)
        csls = (f_csls + b_csls) / 2.0
        print("Seed:{},ACC:{:.4f}-{:.4f},CSLS_FB:{:.6f}".format(
            seed, acc1, acc2, csls))

        print('Building dictionaries...')
        # NOTE(review): this mutates the object returned by
        # _get_eval_params -- confirm it is a private copy of self.params.
        params.dico_build = "S2T&T2S"
        params.dico_method = "csls_knn_10"
        # Row-normalize before induction. The normalized matrices are also
        # the ones exported below when export_emb is set.
        X_Z = X_Z / X_Z.norm(2, 1, keepdim=True).expand_as(X_Z)
        emb_it = emb_it / emb_it.norm(2, 1, keepdim=True).expand_as(emb_it)
        f_dico_induce = _to_pairs(build_dictionary(X_Z, emb_it, params))
        Y_Z = Y_Z / Y_Z.norm(2, 1, keepdim=True).expand_as(Y_Z)
        emb_en = emb_en / emb_en.norm(2, 1, keepdim=True).expand_as(emb_en)
        b_dico_induce = _to_pairs(build_dictionary(Y_Z, emb_en, params))

        # Both sets hold (source index, target index) tuples; the backward
        # pairs are flipped so the two directions can be compared.
        f_dico_set = {(a, b) for a, b in f_dico_induce}
        b_dico_set = {(b, a) for a, b in b_dico_induce}

        intersect = list(f_dico_set & b_dico_set)
        union = list(f_dico_set | b_dico_set)

        src_lang, tgt_lang = params.src_lang, params.tgt_lang
        src_words, tgt_words = src_dico[0], tgt_dico[0]

        _write_pairs(src_lang, tgt_lang, 'dict', f_dico_induce, src_words,
                     tgt_words)
        _write_pairs(tgt_lang, src_lang, 'dict', b_dico_induce, tgt_words,
                     src_words)
        for suffix, pairs in (('intersect', intersect), ('union', union)):
            _write_pairs(src_lang, tgt_lang, suffix, pairs, src_words,
                         tgt_words)
            # same pairs, written target-first
            _write_pairs(tgt_lang, src_lang, suffix,
                         ((t, s) for s, t in pairs), tgt_words, src_words)

        if export_emb:
            # (first lang, second lang, language of the matrix, words, matrix)
            exports = [
                (src_lang, tgt_lang, src_lang, src_words, X_Z),
                (src_lang, tgt_lang, tgt_lang, tgt_words, emb_it),
                (tgt_lang, src_lang, tgt_lang, tgt_words, Y_Z),
                (tgt_lang, src_lang, src_lang, src_words, emb_en),
            ]
            for first_lang, second_lang, emb_lang, words, emb in exports:
                name = '{}-{}.{}'.format(first_lang, second_lang, emb_lang)
                print('Exporting {}'.format(name))
                loader.export_embeddings(
                    words,
                    emb,
                    path=self.tune_dir + '/export/' + name,
                    eformat='txt')
Пример #6
0
    def export_dict(self, src_dico, tgt_dico, emb_en, emb_it, seed):
        """
        Evaluate the saved best models for ``seed`` and export the
        dictionaries they induce.

        Steps, in order:
          1. load the ``best_X`` / ``best_Y`` checkpoints from
             ``<tune_dir>/best`` and encode ``emb_en`` / ``emb_it``;
          2. print the word-translation accuracy of ``params.eval_method``
             in both directions together with the total time taken;
          3. induce a forward and a backward dictionary and write them,
             their intersection and their union (each in both directions)
             under ``<tune_dir>/best``.

        :param src_dico: indexable; ``src_dico[0]`` is indexed with induced
            indices to obtain the text written to the files, ``src_dico[1]``
            is passed to the accuracy evaluation.
        :param tgt_dico: same layout as ``src_dico``, for the target side.
        :param emb_en: source-side embeddings (encoded by the X model).
        :param emb_it: target-side embeddings (encoded by the Y model).
        :param seed: value used in the checkpoint file names.
        :return: None.
        """
        import copy

        def _write_pairs(first_lang, second_lang, suffix, pairs, first_words,
                         second_words):
            # Write one "<first word> <second word>" line per index pair.
            path = self.tune_dir + '/best/{}-{}.{}'.format(
                first_lang, second_lang, suffix)
            with io.open(path, 'w', encoding='utf-8', newline='\n') as f:
                for first, second in pairs:
                    f.write('{} {}\n'.format(first_words[first],
                                             second_words[second]))

        def _to_pairs(dico):
            # Convert an induced dictionary to a NumPy array of index pairs.
            # A None dictionary is treated as "no pairs induced" instead of
            # crashing on ``.cpu()`` -- NOTE(review): confirm that
            # build_dictionary returns None for an empty result.
            return [] if dico is None else dico.cpu().numpy()

        params = self.params
        # Export adversarial dictionaries
        # NOTE(review): the models are moved to CUDA unconditionally --
        # confirm that CPU-only runs are not expected here.
        optim_X_AE = VAE(params).cuda()
        optim_Y_AE = VAE(params).cuda()
        print('Loading pre-trained models...')
        optim_X_AE.load_state_dict(
            torch.load(self.tune_dir + '/best/seed_{}_best_X.t7'.format(seed)))
        optim_Y_AE.load_state_dict(
            torch.load(self.tune_dir + '/best/seed_{}_best_Y.t7'.format(seed)))
        X_Z = optim_X_AE.encode(Variable(emb_en)).data
        Y_Z = optim_Y_AE.encode(Variable(emb_it)).data

        mstart_time = timer()
        method = params.eval_method
        results = get_word_translation_accuracy(params.src_lang,
                                                src_dico[1],
                                                X_Z,
                                                params.tgt_lang,
                                                tgt_dico[1],
                                                emb_it,
                                                method=method,
                                                dico_eval='default')
        acc1 = results[0][1]
        results = get_word_translation_accuracy(params.tgt_lang,
                                                tgt_dico[1],
                                                Y_Z,
                                                params.src_lang,
                                                src_dico[1],
                                                emb_en,
                                                method=method,
                                                dico_eval='default')
        acc2 = results[0][1]
        # the reported time covers both translation directions
        print('{} takes {:.2f}s'.format(method, timer() - mstart_time))
        print('Method:{} score:{:.4f}-{:.4f}'.format(method, acc1, acc2))

        print('Building dictionaries...')
        # ``params`` is ``self.params`` itself, so the induction settings are
        # applied to a shallow copy rather than written back to the shared
        # configuration object.
        dico_params = copy.copy(params)
        dico_params.dico_build = "S2T&T2S"
        dico_params.dico_method = "csls_knn_10"
        # Row-normalize before induction.
        X_Z = X_Z / X_Z.norm(2, 1, keepdim=True).expand_as(X_Z)
        emb_it = emb_it / emb_it.norm(2, 1, keepdim=True).expand_as(emb_it)
        f_dico_induce = _to_pairs(build_dictionary(X_Z, emb_it, dico_params))
        Y_Z = Y_Z / Y_Z.norm(2, 1, keepdim=True).expand_as(Y_Z)
        emb_en = emb_en / emb_en.norm(2, 1, keepdim=True).expand_as(emb_en)
        b_dico_induce = _to_pairs(build_dictionary(Y_Z, emb_en, dico_params))

        # Both sets hold (source index, target index) tuples; the backward
        # pairs are flipped so the two directions can be compared.
        f_dico_set = {(a, b) for a, b in f_dico_induce}
        b_dico_set = {(b, a) for a, b in b_dico_induce}

        intersect = list(f_dico_set & b_dico_set)
        union = list(f_dico_set | b_dico_set)

        src_lang, tgt_lang = params.src_lang, params.tgt_lang
        src_words, tgt_words = src_dico[0], tgt_dico[0]

        _write_pairs(src_lang, tgt_lang, 'dict', f_dico_induce, src_words,
                     tgt_words)
        _write_pairs(tgt_lang, src_lang, 'dict', b_dico_induce, tgt_words,
                     src_words)
        for suffix, pairs in (('intersect', intersect), ('union', union)):
            _write_pairs(src_lang, tgt_lang, suffix, pairs, src_words,
                         tgt_words)
            # same pairs, written target-first
            _write_pairs(tgt_lang, src_lang, suffix,
                         ((t, s) for s, t in pairs), tgt_words, src_words)