Example #1
 def gen():
     while True:
         if self.result:
             num_result = len(self.result)
             worker.send_multipart(
                 [ident, b'', pickle.dumps(self.result)])
             self.result = []
             time_used = time.perf_counter() - start
             logger.info('encoded %d strs from %s in %.2fs @ %d/s' %
                         (num_result, ident, time_used,
                          int(num_result / time_used)))
         ident, empty, msg = worker.recv_multipart()
         start = time.perf_counter()  # time.clock() was removed in Python 3.8
         msg = pickle.loads(msg)
         if self.is_valid_input(msg):
             tmp_f = list(
                 convert_lst_to_features(msg, self.max_len,
                                         self.tokenizer))
             yield {
                 'input_ids': [f.input_ids for f in tmp_f],
                 'input_mask': [f.input_mask for f in tmp_f],
                 'input_type_ids': [f.input_type_ids for f in tmp_f]
             }
         else:
             logger.warning(
                 'worker %d: received unsupported type! sending back None'
                 % self.id)
             worker.send_multipart([ident, b'', pickle.dumps(None)])
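A generator like the one above is not iterated by hand; in bert-as-service style workers it is handed to a TensorFlow 1.x Estimator via tf.data.Dataset.from_generator. A minimal sketch of that wiring, where the input_fn_builder name and the output types/shapes are assumptions rather than part of the snippet above:

import tensorflow as tf  # TF 1.x-style API


def input_fn_builder(gen, max_seq_len):
    """Wrap a dict-yielding generator so an Estimator can pull batches lazily."""
    def input_fn():
        # The keys and dtypes must match the dict yielded by gen(); the
        # (None, max_seq_len) shapes assume a variable batch size and a
        # fixed sequence length.
        return tf.data.Dataset.from_generator(
            gen,
            output_types={'input_ids': tf.int32,
                          'input_mask': tf.int32,
                          'input_type_ids': tf.int32},
            output_shapes={'input_ids': (None, max_seq_len),
                           'input_mask': (None, max_seq_len),
                           'input_type_ids': (None, max_seq_len)})
    return input_fn

# Typical use: estimator.predict(input_fn_builder(gen, max_seq_len),
#                                yield_single_examples=False)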
Example #2
        def gen_eilts_article():
            score = dict()
            with open(score_path, "r", encoding="utf-8") as sr:
                for line in sr:
                    score[line.split()[0]] = float(line.split()[1])

            for dirpath, dirnames, filenames in os.walk(essay_path):
                if filenames:
                    for filename in filenames:
                        filepath = os.path.join(dirpath, filename)
                        with open(filepath, "r") as dr:
                            lines = []
                            for line in dr:
                                if line.strip():
                                    lines.append(line.strip())
                            title_and_doc = " ".join(lines)
                            title = title_and_doc.split("\t", 1)[0].strip()
                            doc = title_and_doc.split("\t", 1)[1].strip()
                            sentences = sentence_tokenize(doc)
                            tmp_f = list(
                                convert_lst_to_features(
                                    sentences, self.max_seq_len,
                                    self.tokenizer))
                            yield {
                                "input_ids": [f.input_ids for f in tmp_f],
                                "input_mask": [f.input_mask for f in tmp_f],
                                "input_type_ids":
                                [f.input_type_ids for f in tmp_f],
                                "article_set": 9,
                                "domain1_score": float(score[filename]),
                                "article_id": int(filename)
                            }
Example #3
 def gen():
     for i in range(1):
         tmp_f = list(convert_lst_to_features(msg, max_seq_len, tokenizer))
         yield {
             'input_ids': [f.input_ids for f in tmp_f],
             'input_mask': [f.input_mask for f in tmp_f],
             'input_type_ids': [f.input_type_ids for f in tmp_f]
         }
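Example #3 relies on msg, max_seq_len and tokenizer being defined in the enclosing scope. A self-contained sketch of the same call pattern; the import paths and the vocab path are assumptions and depend on the installed bert-as-service / BERT version:

import os
# assumed module paths; adjust to your bert-as-service / BERT checkout
from bert_serving.server.bert.tokenization import FullTokenizer
from bert_serving.server.bert.extract_features import convert_lst_to_features

tokenizer = FullTokenizer(vocab_file=os.path.join('/path/to/bert_model', 'vocab.txt'))  # placeholder path
msg = ['hello world', 'a second sentence']
max_seq_len = 64

# convert_lst_to_features yields one InputFeatures object per input string
for f in convert_lst_to_features(msg, max_seq_len, tokenizer):
    print(len(f.input_ids), len(f.input_mask), len(f.input_type_ids))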
Example #4
 def gen():
     while not self.exit_flag.is_set():
         client_id, msg = worker.recv_multipart()
         msg = jsonapi.loads(msg)
         self.logger.info('new job %s, size: %d' % (client_id, len(msg)))
         if BertClient.is_valid_input(msg):
             tmp_f = list(convert_lst_to_features(msg, self.max_seq_len, self.tokenizer))
             yield {
                 'client_id': client_id,
                 'input_ids': [f.input_ids for f in tmp_f],
                 'input_mask': [f.input_mask for f in tmp_f],
                 'input_type_ids': [f.input_type_ids for f in tmp_f]
             }
         else:
             self.logger.error('unsupported type of job %s! sending back None' % client_id)
Example #5
    def gen():
        for sample in samples:
            l_r_sample = list(sample[0])
            m1, s1 = l_r_sample[0]
            m2, s2 = l_r_sample[1]
            label = sample[1]

            tokens1, mapping1 = tokenizer.tokenize(str(s1))
            tokens2, mapping2 = tokenizer.tokenize(str(s2))

            # Update the mapping with the CLS and SEP tag
            mapping1 = [('[CLS]', 0)] + [
                (token, token_index + 1) for (token, token_index) in mapping1
            ] + [('[SEP]', mapping1[-1][1] + 2)]
            mapping2 = [('[CLS]', 0)] + [
                (token, token_index + 1) for (token, token_index) in mapping2
            ] + [('[SEP]', mapping2[-1][1] + 2)]

            features = list(
                convert_lst_to_features([tokens1, tokens2],
                                        max_seq_length=max_seq_length,
                                        tokenizer=tokenizer))

            # Make sure all text is cleaned and in string format to avoid any encoding issues
            m1 = tokenizer.clean_text(m1)
            m2 = tokenizer.clean_text(m2)
            s1 = tokenizer.clean_text(s1)
            s2 = tokenizer.clean_text(s2)

            # Get the mention masks
            mention_mask1 = BertEncoder.get_mention_mask(
                max_seq_length, m1, s1, mapping1)
            mention_mask2 = BertEncoder.get_mention_mask(
                max_seq_length, m2, s2, mapping2)

            yield {
                'input_ids_left': [features[0].input_ids],
                'input_mask_left': [features[0].input_mask],
                'input_type_ids_left': [features[0].input_type_ids],
                'mention_mask_left': [mention_mask1],
                'input_ids_right': [features[1].input_ids],
                'input_mask_right': [features[1].input_mask],
                'input_type_ids_right': [features[1].input_type_ids],
                'mention_mask_right': [mention_mask2],
                'label': [label]
            }
Example #6
 def gen_asap_article():
     dataset = pd.read_csv(file_path)
     articles = dataset["essay"]
     articles_set = dataset["essay_set"]
     domain1_score = dataset["domain1_score"]
     articles_id = dataset["essay_id"]
     for i in range(len(articles)):
         doc = articles[i]
         sentences = sentence_tokenize(doc)
         tmp_f = list(
             convert_lst_to_features(sentences, self.max_seq_len,
                                     self.tokenizer))
         yield {
             "input_ids": [f.input_ids for f in tmp_f],
             "input_mask": [f.input_mask for f in tmp_f],
             "input_type_ids": [f.input_type_ids for f in tmp_f],
             "article_set": articles_set[i],
             "domain1_score": float(domain1_score[i]),
             "article_id": articles_id[i]
         }
Example #7
 def gen():
     while not self.exit_flag.is_set():
         self.dest, empty, msg = worker.recv_multipart()
         self._start_t = time.perf_counter()
         msg = pickle.loads(msg)
         if BertClient.is_valid_input(msg):
             tmp_f = list(
                 convert_lst_to_features(msg, self.max_seq_len,
                                         self.tokenizer))
             yield {
                 'input_ids': [f.input_ids for f in tmp_f],
                 'input_mask': [f.input_mask for f in tmp_f],
                 'input_type_ids': [f.input_type_ids for f in tmp_f]
             }
         else:
             logger.warning(
                 'worker %s: received unsupported type! sending back None'
                 % self.dest)
             worker.send_multipart([self.dest, b'', b''])
     worker.close()
Example #8
        def gen():
            tokenizer = FullTokenizer(
                vocab_file=os.path.join(self.args.bert_model_dir, 'vocab.txt'))
            # Windows does not support sharing a logger across processes, so create
            # a new logger inside this process for better compatibility
            logger = set_logger(
                colored('WORKER-%d' % self.worker_id, 'yellow'), self.verbose)

            poller = zmq.Poller()
            for sock in socks:
                poller.register(sock, zmq.POLLIN)

            logger.info('ready and listening!')
            while not self.exit_flag.is_set():
                events = dict(poller.poll())
                for sock_idx, sock in enumerate(socks):
                    if sock in events:
                        # receive the message from the client
                        client_id, raw_msg = sock.recv_multipart()
                        msg = jsonapi.loads(raw_msg)
                        logger.info(
                            'new job\tsocket: %d\tsize: %d\tclient: %s' %
                            (sock_idx, len(msg), client_id))
                        # check if msg is a list of lists; if so, the input is already tokenized
                        # tokenize the received strings and convert them to id format
                        # logger.info('get msg:%s, type:%s' % (msg[0], type(msg[0])))
                        is_tokenized = all(isinstance(el, list) for el in msg)
                        tmp_f = list(
                            convert_lst_to_features(msg, self.max_seq_len,
                                                    tokenizer, logger,
                                                    is_tokenized,
                                                    self.mask_cls_sep))
                        # print([f.input_ids for f in tmp_f])
                        yield {
                            'client_id': client_id,
                            'input_ids': [f.input_ids for f in tmp_f],
                            'input_mask': [f.input_mask for f in tmp_f],
                            'input_type_ids':
                            [f.input_type_ids for f in tmp_f]
                        }
Example #9
 def gen():
     while not self.exit_flag.is_set():
         client_id, empty, msg = worker.recv_multipart()
         msg = pickle.loads(msg)
         self.logger.info('received %4d from %s' %
                          (len(msg), client_id))
         if BertClient.is_valid_input(msg):
             tmp_f = list(
                 convert_lst_to_features(msg, self.max_seq_len,
                                         self.tokenizer))
             yield {
                 'client_id': client_id,
                 'input_ids': [f.input_ids for f in tmp_f],
                 'input_mask': [f.input_mask for f in tmp_f],
                 'input_type_ids': [f.input_type_ids for f in tmp_f]
             }
         else:
             self.logger.warning(
                 'received unsupported type from %s! sending back None'
                 % client_id)
             worker.send_multipart([client_id, b'', b''])
     worker.close()
Example #10
    def get_token_embeddings(self, original_sentences: List[str]):
        """
        Calculate the token embeddings for a given sentence.
        """
        # Tokenize the sentences
        tokenized_sentences = []
        tokens_mappings = []
        for sentence in original_sentences:
            tokens, mapping = self.tokenize(sentence)

            # Update the mapping with the CLS and SEP tag
            mapping = [('[CLS]', 0)] + [(token, token_index+1) for (token, token_index) in
                                        mapping] + [('[SEP]', mapping[-1][1]+2)]
            tokenized_sentences.append(tokens)
            tokens_mappings.append(mapping)

        # Check if the tokenized input format is correct
        self._check_input_lst_lst_str(tokenized_sentences)

        # Check if all sentences are shorter than the max seq len
        if not self._check_length(tokenized_sentences, self._seq_len, True):
            log.warning('Some of your sentences have more tokens than "max_seq_len=%d" set, '
                        'as a consequence you may get less-accurate or truncated embeddings or lose the '
                        'embedding for a specified phrase of a sentence.\n' % self._seq_len)

        all_token_embeddings = []
        all_feature_tokens = []

        batch = {
            self._input_ids: [],
            self._input_mask: [],
            self._input_type_ids: []
        }
        # Tokenizer is still needed for mapping and some other stuff that happens in the convert method
        for sample_index, data in enumerate(zip(convert_lst_to_features(tokenized_sentences,
                                                                        max_seq_length=self._seq_len,
                                                                        tokenizer=self._tokenizer),
                                                original_sentences, tokens_mappings)):
            feature, sentence, token_mapping = data

            batch[self._input_ids].append(feature.input_ids)
            batch[self._input_mask].append(feature.input_mask)
            batch[self._input_type_ids].append(feature.input_type_ids)

            all_feature_tokens.append(feature.tokens)
            if sample_index % self._batch_size == 0:
                batch_token_embeddings = self._sess.run(self._bert_output_layer, feed_dict=batch)

                for token_embeddings in batch_token_embeddings:
                    all_token_embeddings.append(token_embeddings)

                # Reset the batch
                batch = {
                    self._input_ids: [],
                    self._input_mask: [],
                    self._input_type_ids: []
                }

        # Handle leftover samples that did not fit in the last batch
        if len(batch[self._input_ids]) > 0:
            batch_mention_embeddings = self._sess.run(self._bert_output_layer, feed_dict=batch)

            for mention_embedding in batch_mention_embeddings:
                all_token_embeddings.append(mention_embedding)

        return all_token_embeddings, all_feature_tokens, tokenized_sentences, tokens_mappings
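The batch dict in Example #10 uses placeholder tensors (self._input_ids, self._input_mask, self._input_type_ids) as feed_dict keys; their declaration is not shown. A minimal TF 1.x sketch of how such placeholders are typically defined; the names and the [None, seq_len] shapes are assumptions:

import tensorflow as tf  # TF 1.x graph-mode API (tf.compat.v1 in TF 2.x)

seq_len = 128  # placeholder value; should equal the max_seq_length used above
input_ids = tf.placeholder(tf.int32, shape=[None, seq_len], name='input_ids')
input_mask = tf.placeholder(tf.int32, shape=[None, seq_len], name='input_mask')
input_type_ids = tf.placeholder(tf.int32, shape=[None, seq_len], name='input_type_ids')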