# Example 1
def mineBitext(src_sents, trg_sents, x, y, x2y_ind, x2y_mean, y2x_ind,
               y2x_mean, outputFSrc, outputFTgt, outputFScore, encoding,
               margin, retrieval, threshold, verbose):
    """Mine parallel sentence pairs and write them to three parallel files.

    Args:
        src_sents, trg_sents: source / target sentences, indexed by row of
            the corresponding embedding matrix.
        x, y: source / target embedding matrices (one row per sentence).
        x2y_ind, y2x_ind: knn neighbor indices in each direction.
        x2y_mean, y2x_mean: per-sentence mean knn similarity per direction.
        outputFSrc, outputFTgt, outputFScore: output file names (opened via
            openOutputF with `encoding`).
        margin: margin function combining a similarity with a neighborhood
            mean, passed through to score_candidates.
        retrieval: candidate selection strategy — 'fwd', 'bwd', 'intersect'
            or 'max'.
        threshold: minimum margin score; when > 0, lower-scoring pairs are
            dropped in every retrieval mode.
        verbose: forwarded to score_candidates.
    """
    logger.info(' - mining for parallel data')
    # Margin scores in both directions: each direction is penalized by the
    # other direction's neighborhood mean.
    fwd_scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin,
                                  verbose)
    bwd_scores = score_candidates(y, x, y2x_ind, y2x_mean, x2y_mean, margin,
                                  verbose)
    # Best-scoring target per source and best-scoring source per target.
    fwd_best = x2y_ind[np.arange(x.shape[0]), fwd_scores.argmax(axis=1)]
    bwd_best = y2x_ind[np.arange(y.shape[0]), bwd_scores.argmax(axis=1)]

    logger.info(' - writing mined output to {:s}, {:s}, {:s}'.format(
        outputFSrc, outputFTgt, outputFScore))
    if threshold > 0:
        logger.info(' - with threshold of {:f}'.format(threshold))

    def _keep(score):
        # BUGFIX: the threshold was previously applied only for
        # retrieval == 'max', although it was logged for every mode.
        # A non-positive threshold keeps everything (backward compatible).
        return threshold <= 0 or score > threshold

    foutSrc = openOutputF(outputFSrc, encoding)
    foutTgt = openOutputF(outputFTgt, encoding)
    foutScore = openOutputF(outputFScore, encoding)
    try:
        def _printTriplet(src, tgt, score):
            foutSrc.write('{:s}\n'.format(src))
            foutTgt.write('{:s}\n'.format(tgt))
            foutScore.write('{:f}\n'.format(score))

        if retrieval == 'fwd':
            # Best target for every source sentence.
            for i, j in enumerate(fwd_best):
                best_score = fwd_scores[i].max()
                if _keep(best_score):
                    _printTriplet(src_sents[i], trg_sents[j], best_score)
        if retrieval == 'bwd':
            # Best source for every target sentence.
            for j, i in enumerate(bwd_best):
                best_score = bwd_scores[j].max()
                if _keep(best_score):
                    _printTriplet(src_sents[i], trg_sents[j], best_score)
        if retrieval == 'intersect':
            # Keep only mutual best matches.
            for i, j in enumerate(fwd_best):
                best_score = fwd_scores[i].max()
                if bwd_best[j] == i and _keep(best_score):
                    _printTriplet(src_sents[i], trg_sents[j], best_score)
        if retrieval == 'max':
            # Greedy one-to-one matching over the union of both directions,
            # highest margin score first.
            indices = np.stack(
                (np.concatenate((np.arange(x.shape[0]), bwd_best)),
                 np.concatenate((fwd_best, np.arange(y.shape[0])))),
                axis=1)
            scores = np.concatenate(
                (fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
            seen_src, seen_trg = set(), set()
            for i in np.argsort(-scores):
                src_ind, trg_ind = indices[i]
                if src_ind not in seen_src and trg_ind not in seen_trg:
                    seen_src.add(src_ind)
                    seen_trg.add(trg_ind)
                    # Original 'max' semantics preserved exactly: strict
                    # '> threshold' even when threshold == 0.
                    if scores[i] > threshold:
                        _printTriplet(src_sents[src_ind], trg_sents[trg_ind],
                                      scores[i])
    finally:
        # Close the outputs even if scoring or writing raises.
        foutSrc.close()
        foutTgt.close()
        foutScore.close()
# Example 2
def Mine(src_doc_ind, trg_doc_ind, src, trg, encoding, src_embeddings,
         trg_embeddings, output, unify, mode, retrieval, margin, neighborhood,
         gpu, dim, threshold, verbose):
    """Search, score or mine bitexts from precomputed sentence embeddings.

    Args:
        src_doc_ind, trg_doc_ind: document identifiers echoed into the
            'max'-retrieval mining output.
        src, trg: source / target text inputs for TextLoadUnify.
        encoding: text encoding for inputs and output.
        src_embeddings, trg_embeddings: embedding files for EmbedLoad.
        output: output path ('' / None writes to stdout; '.xz' uses lzma).
        unify: deduplicate sentences/embeddings before processing.
        mode: 'search', 'score' or 'mine'.
        retrieval: 'fwd', 'bwd', 'intersect' or 'max' (mining only).
        margin: margin name — 'absolute', 'distance' or 'ratio'.
        neighborhood: knn size (capped at the other side's corpus size).
        gpu: run knn on GPUs if True.
        dim: embedding dimensionality for EmbedLoad.
        threshold: minimum margin score for mined pairs.
        verbose: emit progress messages on stderr.
    """
    print('LASER: tool to search, score or mine bitexts', file=sys.stderr)
    if gpu:
        print(' - knn will run on all available GPUs (recommended)',
              file=sys.stderr)
    else:
        print(' - knn will run on CPU (slow)', file=sys.stderr)

    args = AttrDict({"encoding": encoding, "unify": unify, "verbose": verbose})
    src_inds, src_sents = TextLoadUnify(src, args)
    trg_inds, trg_sents = TextLoadUnify(trg, args)

    def unique_embeddings(emb, ind, verbose=False):
        # Keep one embedding row per unique sentence index, ordered by index.
        aux = {j: i for i, j in enumerate(ind)}
        if verbose:
            print(' - unify embeddings: {:d} -> {:d}'.format(
                len(emb), len(aux)),
                  file=sys.stderr)
        return emb[[aux[i] for i in range(len(aux))]]

    # Load the embeddings; L2-normalize so inner product equals cosine.
    x = EmbedLoad(src_embeddings, dim, verbose=verbose)
    if unify:
        x = unique_embeddings(x, src_inds, verbose)
    faiss.normalize_L2(x)
    y = EmbedLoad(trg_embeddings, dim, verbose=verbose)
    if unify:
        y = unique_embeddings(y, trg_inds, verbose)
    faiss.normalize_L2(y)

    # Calculate knn in both directions.
    # BUGFIX: these passes used to be skipped for retrieval == 'bwd' /
    # 'fwd' respectively, but every mode below needs BOTH directions'
    # statistics (score_candidates takes the opposite direction's mean as
    # the margin penalty, and mode 'mine' also needs both index arrays),
    # which raised NameError. They now always run.
    if verbose:
        print(' - perform {:d}-nn source against target'.format(neighborhood),
              file=sys.stderr)
    x2y_sim, x2y_ind = knn(x, y, min(y.shape[0], neighborhood), gpu)
    x2y_mean = x2y_sim.mean(axis=1)

    if verbose:
        print(' - perform {:d}-nn target against source'.format(neighborhood),
              file=sys.stderr)
    y2x_sim, y2x_ind = knn(y, x, min(x.shape[0], neighborhood), gpu)
    y2x_mean = y2x_sim.mean(axis=1)

    # Margin function, bound to a fresh name instead of shadowing the
    # 'margin' string parameter.
    if margin == 'absolute':

        def margin_fn(a, b):
            return a
    elif margin == 'distance':

        def margin_fn(a, b):
            return a - b
    else:  # margin == 'ratio':

        def margin_fn(a, b):
            return a / b

    if output:
        if output.endswith('.xz'):
            fout = lzma.open(output,
                             mode='at',
                             encoding=encoding,
                             errors='surrogateescape')
        else:
            fout = open(output,
                        mode='a',
                        encoding=encoding,
                        errors='surrogateescape')
    else:
        output = "stdout"
        fout = sys.stdout

    def _keep(score):
        # BUGFIX: honor the threshold in every mining retrieval mode (it was
        # previously applied only for 'max'); threshold <= 0 keeps everything.
        return threshold <= 0 or score > threshold

    try:
        if mode == 'search':
            if verbose:
                print(' - Searching for closest sentences in target',
                      file=sys.stderr)
                print(' - writing alignments to {:s}'.format(output),
                      file=sys.stderr)
            scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean,
                                      margin_fn, verbose)
            best = x2y_ind[np.arange(x.shape[0]), scores.argmax(axis=1)]

            # Error rate against the identity alignment.
            nbex = x.shape[0]
            ref = np.arange(nbex)  # [0, nbex)
            err = nbex - np.equal(best.reshape(nbex), ref).astype(int).sum()
            print(' - errors: {:d}={:.2f}%'.format(err, 100 * err / nbex),
                  file=sys.stderr)
            for i in src_inds:
                print(trg_sents[best[i]], file=fout)

        elif mode == 'score':
            for i, j in zip(src_inds, trg_inds):
                s = score(x[i], y[j], x2y_mean[i], y2x_mean[j], margin_fn)
                print(s, src_sents[i], trg_sents[j], sep='\t', file=fout)

        elif mode == 'mine':
            if verbose:
                print(' - mining for parallel data', file=sys.stderr)
            fwd_scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean,
                                          margin_fn, verbose)
            bwd_scores = score_candidates(y, x, y2x_ind, y2x_mean, x2y_mean,
                                          margin_fn, verbose)
            fwd_best = x2y_ind[np.arange(x.shape[0]),
                               fwd_scores.argmax(axis=1)]
            bwd_best = y2x_ind[np.arange(y.shape[0]),
                               bwd_scores.argmax(axis=1)]
            if verbose:
                print(' - writing alignments to {:s}'.format(output),
                      file=sys.stderr)
                if threshold > 0:
                    print(' - with threshold of {:f}'.format(threshold),
                          file=sys.stderr)
            if retrieval == 'fwd':
                # Best target for every source sentence.
                for i, j in enumerate(fwd_best):
                    best_score = fwd_scores[i].max()
                    if _keep(best_score):
                        print(best_score,
                              src_sents[i],
                              trg_sents[j],
                              sep='\t',
                              file=fout)
            if retrieval == 'bwd':
                # Best source for every target sentence.
                for j, i in enumerate(bwd_best):
                    best_score = bwd_scores[j].max()
                    if _keep(best_score):
                        print(best_score,
                              src_sents[i],
                              trg_sents[j],
                              sep='\t',
                              file=fout)
            if retrieval == 'intersect':
                # Keep only mutual best matches.
                for i, j in enumerate(fwd_best):
                    best_score = fwd_scores[i].max()
                    if bwd_best[j] == i and _keep(best_score):
                        print(best_score,
                              src_sents[i],
                              trg_sents[j],
                              sep='\t',
                              file=fout)
            if retrieval == 'max':
                # Greedy one-to-one matching over the union of both
                # directions, highest margin score first.
                indices = np.stack(
                    (np.concatenate((np.arange(x.shape[0]), bwd_best)),
                     np.concatenate((fwd_best, np.arange(y.shape[0])))),
                    axis=1)
                scores = np.concatenate(
                    (fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
                seen_src, seen_trg = set(), set()
                for i in np.argsort(-scores):
                    src_ind, trg_ind = indices[i]
                    if src_ind not in seen_src and trg_ind not in seen_trg:
                        seen_src.add(src_ind)
                        seen_trg.add(trg_ind)
                        # Original 'max' semantics preserved exactly: strict
                        # '> threshold' even when threshold == 0.
                        if scores[i] > threshold:
                            print(src_doc_ind,
                                  trg_doc_ind,
                                  src_sents[src_ind],
                                  trg_sents[trg_ind],
                                  scores[i],
                                  sep='\t',
                                  file=fout)
    finally:
        # Close the output even if knn/scoring/writing raises.
        if fout != sys.stdout:
            fout.close()