Example #1
# Assumed import: BertEmbedding from the bert-embedding package (adjust the
# import path if you use GluonNLP's BERT scripts instead).
from bert_embedding import BertEmbedding


class BertEmbeddings:
    """Thin wrapper around BertEmbedding that validates the model and corpus names."""

    def __init__(self,
                 model='bert_24_1024_16',
                 corpus='book_corpus_wiki_en_cased'):
        self.__model = model
        self.__corpus = corpus

        assert self.__model in ['bert_12_768_12',
                                'bert_24_1024_16'], "Model is not recognized."
        assert self.__corpus in [
            'book_corpus_wiki_en_uncased', 'book_corpus_wiki_en_cased',
            'wiki_multilingual', 'wiki_multilingual_cased'
        ], "Corpus is unknown."

        self.__bert = BertEmbedding(model=self.__model,
                                    dataset_name=self.__corpus)

    def predict(self, text):
        # Accept either a single string or a list of sentences.
        if not isinstance(text, list):
            text = [text]

        bert_embeddings = self.__bert.embedding(text)
        return bert_embeddings
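
A minimal usage sketch for the wrapper above, assuming the bert-embedding package is installed; the sentence text is made up, and the exact layout of each returned item depends on the BertEmbedding version in use.

# Illustrative usage of the BertEmbeddings wrapper defined above.
embedder = BertEmbeddings(model='bert_12_768_12',
                          corpus='book_corpus_wiki_en_uncased')

# A single string is accepted; predict() wraps it into a list internally.
results = embedder.predict('BERT produces contextual token embeddings.')

for item in results:
    # Each item pairs the tokenized sentence with its per-token vectors;
    # the exact tuple layout depends on the BertEmbedding version installed.
    print(item)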
Example #2
# Assumed imports for this example.
import logging

import mxnet as mx
import numpy as np
from tqdm import tqdm
from bert_embedding import BertEmbedding

logger = logging.getLogger(__name__)


def to_dataset(samples,
               labels,
               ctx=mx.gpu(),
               batch_size=64,
               max_seq_length=25):
    '''
    Use BertEmbedding to compute embeddings for each field of every sample,
    pair them with the given labels (if any), and return them as one dataset.
    '''
    # Use the ctx argument instead of hard-coding mx.gpu().
    bertembedding = BertEmbedding(ctx=ctx,
                                  batch_size=batch_size,
                                  max_seq_length=max_seq_length)
    logger.info('Construct bert embedding for sentences')
    embs = []
    for sample in tqdm(samples):
        # embedding() yields one (tokens, token_embeddings) pair per field;
        # keep only the embedding arrays.
        tokens_embs = bertembedding.embedding(sample)
        embs.append([np.asarray(token_emb[1]) for token_emb in tokens_embs])

    if labels:
        # Attach each sample's label to its list of field embeddings.
        dataset = [[*obs_hyp, label] for obs_hyp, label in zip(embs, labels)]
    else:
        dataset = embs
    return dataset
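
A hedged usage sketch for to_dataset; the sample fields and labels below are invented, and a CPU context is passed so the call can run without a GPU.

# Hypothetical inputs: each sample is a list of text fields (e.g. an observation
# and a hypothesis); labels is an optional list of the same length.
samples = [['the weather is nice', 'it is sunny'],
           ['the weather is nice', 'it is raining']]
labels = [1, 0]

dataset = to_dataset(samples, labels,
                     ctx=mx.cpu(), batch_size=2, max_seq_length=25)
print(len(dataset))     # one entry per sample
print(len(dataset[0]))  # the sample's field embeddings followed by its label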
Example #3
# Assumed imports for this example.
import argparse
import io

import mxnet as mx
import numpy as np
from bert_embedding import BertEmbedding


def main():
    np.set_printoptions(threshold=5)
    parser = argparse.ArgumentParser(
        description='Get embeddings from BERT',
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--gpu',
        type=int,
        default=None,
        help='id of the GPU to use; leave unset to run on the CPU.')
    parser.add_argument('--model',
                        type=str,
                        default='bert_12_768_12',
                        help='pre-trained model')
    parser.add_argument('--dataset_name',
                        type=str,
                        default='book_corpus_wiki_en_uncased',
                        help='dataset')
    parser.add_argument('--max_seq_length',
                        type=int,
                        default=25,
                        help='max length of each sequence')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='batch size')
    parser.add_argument(
        '--oov_way',
        type=str,
        default='avg',
        help='how to handle oov\n'
        'avg: average all oov embeddings to represent the original token\n'
        'sum: sum all oov embeddings to represent the original token\n'
        'last: use last oov embeddings to represent the original token\n')
    parser.add_argument('--sentences',
                        type=str,
                        nargs='+',
                        default=None,
                        help='sentences to encode')
    parser.add_argument('--file',
                        type=str,
                        default=None,
                        help='file to encode, one sentence per line')

    args = parser.parse_args()
    # Check for None explicitly so that GPU id 0 is not mistaken for "no GPU".
    context = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    bert = BertEmbedding(ctx=context,
                         model=args.model,
                         dataset_name=args.dataset_name,
                         max_seq_length=args.max_seq_length,
                         batch_size=args.batch_size)
    result = []
    sents = []
    if args.sentences:
        sents = args.sentences
        result = bert.embedding(sents, oov_way=args.oov_way)
    elif args.file:
        with io.open(args.file, 'r', encoding='utf8') as in_file:
            for line in in_file:
                sents.append(line.strip())
        result = bert.embedding(sents, oov_way=args.oov_way)
    else:
        print('Please specify --sentences or --file')

    if result:
        for sent, embeddings in zip(sents, result):
            print('Text: {}'.format(sent))
            sentence_embedding, _, tokens_embedding = embeddings
            print('Sentence embedding: {}'.format(sentence_embedding))
            print('Tokens embedding: {}'.format(tokens_embedding))
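
A hypothetical entry point and a couple of example invocations; the script filename is made up, and the flags correspond to the parser defined in main().

# Hypothetical entry point; the script name bert_embed_cli.py is illustrative.
if __name__ == '__main__':
    main()

# Example invocations (shell):
#   python bert_embed_cli.py --gpu 0 --sentences "BERT encodes text into vectors."
#   python bert_embed_cli.py --file sentences.txt --oov_way sum --batch_size 64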