class BertEmbeddings:
    """Thin wrapper around ``BertEmbedding`` that validates the model/corpus
    choice up front and exposes a simple ``predict`` API.

    Parameters
    ----------
    model : str
        Pre-trained BERT architecture name; one of ``_MODELS``.
    corpus : str
        Name of the dataset the weights were trained on; one of ``_CORPORA``.

    Raises
    ------
    ValueError
        If ``model`` or ``corpus`` is not a recognized value.
    """

    # Closed sets of supported configurations.
    _MODELS = ('bert_12_768_12', 'bert_24_1024_16')
    _CORPORA = (
        'book_corpus_wiki_en_uncased',
        'book_corpus_wiki_en_cased',
        'wiki_multilingual',
        'wiki_multilingual_cased',
    )

    def __init__(self, model='bert_24_1024_16', corpus='book_corpus_wiki_en_cased'):
        # Validate with explicit exceptions instead of ``assert`` --
        # asserts are silently stripped when Python runs with -O.
        if model not in self._MODELS:
            raise ValueError("Model is not recognized.")
        if corpus not in self._CORPORA:
            raise ValueError("Corpus is unknown.")
        self.__model = model
        self.__corpus = corpus
        self.__bert = BertEmbedding(model=self.__model, dataset_name=self.__corpus)

    def predict(self, text):
        """Return BERT embeddings for ``text``.

        Accepts either a single string or a list of strings; a lone string is
        wrapped in a list before being handed to ``BertEmbedding.embedding``.
        """
        if not isinstance(text, list):
            text = [text]
        return self.__bert.embedding(text)
def to_dataset(samples, labels, ctx=mx.gpu(), batch_size=64, max_seq_length=25):
    '''Embed each sample's fields with BERT and pair them with labels.

    Parameters
    ----------
    samples : iterable of list-of-str
        Each sample is a list of sentences (fields) to embed.
    labels : sequence or None
        Per-sample labels; if falsy, the raw embeddings are returned.
    ctx : mxnet context
        Device to run the embedding on (default: first GPU).
    batch_size : int
        Batch size forwarded to ``BertEmbedding``.
    max_seq_length : int
        Maximum token length per sequence forwarded to ``BertEmbedding``.

    Returns
    -------
    list
        ``[[*field_embeddings, label], ...]`` when labels are given,
        otherwise ``[field_embeddings, ...]``.
    '''
    from tqdm import tqdm

    # Bug fix: honor the ``ctx`` argument -- the original hard-coded
    # ``mx.gpu()`` here, so requesting a CPU context had no effect.
    bertembedding = BertEmbedding(ctx=ctx,
                                  batch_size=batch_size,
                                  max_seq_length=max_seq_length)
    logger.info('Construct bert embedding for sentences')
    embs = []
    for sample in tqdm(samples):
        # ``embedding`` yields (tokens, token_vectors) pairs per sentence;
        # keep only the vectors, one ndarray per field.
        tokens_embs = bertembedding.embedding(sample)
        embs.append([np.asarray(token_emb[1]) for token_emb in tokens_embs])
    if labels:
        dataset = [[*obs_hyp, label] for obs_hyp, label in zip(embs, labels)]
    else:
        dataset = embs
    return dataset
def main():
    """CLI entry point: embed sentences (from --sentences or --file) with BERT
    and print the sentence- and token-level embeddings."""
    np.set_printoptions(threshold=5)
    parser = argparse.ArgumentParser(
        description='Get embeddings from BERT',
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--gpu',
        type=int,
        default=None,
        help='id of the gpu to use. Set it to empty means to use cpu.')
    parser.add_argument('--model', type=str, default='bert_12_768_12',
                        help='pre-trained model')
    parser.add_argument('--dataset_name', type=str,
                        default='book_corpus_wiki_en_uncased',
                        help='dataset')
    parser.add_argument('--max_seq_length', type=int, default=25,
                        help='max length of each sequence')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size')
    parser.add_argument(
        '--oov_way',
        type=str,
        default='avg',
        help='how to handle oov\n'
        'avg: average all oov embeddings to represent the original token\n'
        'sum: sum all oov embeddings to represent the original token\n'
        'last: use last oov embeddings to represent the original token\n')
    parser.add_argument('--sentences', type=str, nargs='+', default=None,
                        help='sentence for encoding')
    parser.add_argument('--file', type=str, default=None,
                        help='file for encoding')
    args = parser.parse_args()

    # Bug fix: ``if args.gpu`` treats GPU id 0 as falsy and silently falls
    # back to CPU; compare against None so ``--gpu 0`` selects gpu(0).
    context = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    bert = BertEmbedding(ctx=context, model=args.model,
                         dataset_name=args.dataset_name,
                         max_seq_length=args.max_seq_length,
                         batch_size=args.batch_size)
    result = []
    sents = []
    if args.sentences:
        sents = args.sentences
        result = bert.embedding(sents, oov_way=args.oov_way)
    elif args.file:
        with io.open(args.file, 'r', encoding='utf8') as in_file:
            for line in in_file:
                sents.append(line.strip())
        result = bert.embedding(sents, oov_way=args.oov_way)
    else:
        # Bug fix: the flag is spelled --sentences, not --sentence.
        print('Please specify --sentences or --file')

    if result:
        for sent, embeddings in zip(sents, result):
            print('Text: {}'.format(sent))
            # Each result entry unpacks as
            # (sentence_embedding, tokens, tokens_embedding).
            sentence_embedding, _, tokens_embedding = embeddings
            print('Sentence embedding: {}'.format(sentence_embedding))
            print('Tokens embedding: {}'.format(tokens_embedding))