Example #1
 def data_generator(
     self,
     seq_type,
     x_train,
     x_valid,
     y_train,
     y_valid,
     x_len_train=None,
     x_len_valid=None,
 ):
     if seq_type == 'bucket':
         logger.info('use bucket sequence to speed up model training')
         train_batches = BucketIterator(self.task_type, self.transformer,
                                        x_len_train, x_train, y_train,
                                        self.nb_bucket, self.batch_size)
         valid_batches = BucketIterator(self.task_type, self.transformer,
                                        x_len_valid, x_valid, y_valid,
                                        self.nb_bucket, self.batch_size)
     elif seq_type == 'basic':
         train_batches = BasicIterator(self.task_type, self.transformer,
                                       x_train, y_train, self.batch_size)
         valid_batches = BasicIterator(self.task_type, self.transformer,
                                       x_valid, y_valid, self.batch_size)
     else:
         logger.warning(
             'invalid data iterator type, only supports "basic" or "bucket"'
         )
         # fail fast instead of hitting an UnboundLocalError on the return below
         raise ValueError('seq_type must be either "basic" or "bucket"')
     return train_batches, valid_batches
Example #2
 def predict(self,
             x: Dict[str, List[List[str]]],
             batch_size=64,
             return_attention=False,
             return_prob=False):
     n_labels = len(self.transformer._label_vocab._id2token)
     x_c = deepcopy(x)
     start = time.time()
     x_len = [item[-1] for item in x_c['token']]
     x_c['token'] = [item[:-1] for item in x_c['token']]
     x_seq = BasicIterator('classification',
                           self.transformer,
                           x_c,
                           batch_size=batch_size)
     result = self.model.model.predict_generator(x_seq)
     if return_prob:
         y_pred = result[:, :n_labels]
     else:
         y_pred = self.transformer.inverse_transform(result[:, :n_labels])
     used_time = time.time() - start
     logger.info('predicted {} samples in {:4.1f}s'.format(
         len(x['token']), used_time))
     if result.shape[1] > n_labels and self.model_name == 'bi_lstm_att':
         attention = result[:, n_labels:]
         attention = [attention[idx][:l] for idx, l in enumerate(x_len)]
         return y_pred, attention
     else:
         return y_pred
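In the classification case above, the prediction matrix packs the class probabilities into its first n_labels columns; for the bi_lstm_att model the remaining columns hold padded attention weights, which are then cut back to each sample's real length. A minimal numpy sketch of that slicing, with made-up shapes (2 samples, 3 labels, attention padded to length 4):

    import numpy as np

    n_labels = 3
    # toy output: class probabilities followed by padded attention weights
    result = np.array([
        [0.7, 0.2, 0.1,  0.9, 0.05, 0.05, 0.0],
        [0.1, 0.1, 0.8,  0.3, 0.30, 0.20, 0.2],
    ])
    x_len = [3, 4]                         # true (unpadded) sequence lengths

    y_prob = result[:, :n_labels]          # class probabilities
    attention = result[:, n_labels:]       # padded attention weights
    attention = [attention[i][:l] for i, l in enumerate(x_len)]
    print(y_prob.shape, [a.shape for a in attention])   # (2, 3) [(3,), (4,)]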
Example #3
 def load_data(self):
     if self.task_type == 'classification':
         self.load_tc_data()
     elif self.task_type == 'sequence_labeling':
         if self.mode != 'predict':
             self.load_sl_data()
         else:
             with open(self.fname, 'r', encoding='utf8') as fin:
                 self.texts = [line.strip().split() for line in fin]
     logger.info('data loaded')
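The predict branch above simply turns every line of the input file into a whitespace-separated token list. A tiny self-contained illustration of that one-liner (using an in-memory file in place of self.fname):

    from io import StringIO

    # stand-in for open(self.fname, 'r', encoding='utf8')
    fake_file = StringIO('I like NLP\nbucket iterators are fast\n')
    texts = [line.strip().split() for line in fake_file]
    print(texts)   # [['I', 'like', 'NLP'], ['bucket', 'iterators', 'are', 'fast']]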
Example #4
    def __init__(self,
                 task_type: str,
                 transformer: IndexTransformer,
                 seq_lengths: List[int],
                 x: Dict[str, List[List[str]]],
                 y: List[List[str]],
                 num_buckets: int = 8,
                 batch_size=1):
        self.task_type = task_type
        self.t = transformer
        self.batch_size = batch_size
        self.x = x
        self.y = y
        if self.t.use_radical:
            self.radical_dict = self.t.radical_dict
        else:
            self.radical_dict = None

        # Count bucket sizes
        bucket_sizes, bucket_ranges = np.histogram(seq_lengths,
                                                   bins=num_buckets)
        # Looking for non-empty buckets
        actual_buckets = [
            bucket_ranges[i + 1] for i, bs in enumerate(bucket_sizes) if bs > 0
        ]
        actual_bucket_sizes = [bs for bs in bucket_sizes if bs > 0]
        self.bucket_seqlen = [int(math.ceil(bs)) for bs in actual_buckets]
        num_actual = len(actual_buckets)
        logger.info('Training with %d non-empty buckets' % num_actual)

        self.bins = [(defaultdict(list), []) for bs in actual_bucket_sizes]
        assert len(self.bins) == num_actual

        # Insert the sequences into the bins
        self.feature_keys = list(self.x.keys())
        for i, sl in enumerate(seq_lengths):
            for j in range(num_actual):
                bsl = self.bucket_seqlen[j]
                if sl < bsl or j == num_actual - 1:
                    for k in self.feature_keys:
                        self.bins[j][0][k].append(x[k][i])
                    self.bins[j][1].append(y[i])
                    break

        self.num_samples = len(self.x['token'])
        self.dataset_len = int(
            sum([
                math.ceil(bs / self.batch_size) for bs in actual_bucket_sizes
            ]))
        self._permute()
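The constructor derives its buckets from np.histogram: the upper edge of every non-empty bin becomes that bucket's maximum length, and each sequence is placed in the first bucket whose bound exceeds its length, with the last bucket catching everything else. A self-contained sketch of that bucketing logic with toy lengths:

    import math
    import numpy as np

    seq_lengths = [3, 4, 5, 20, 22, 25, 40]        # toy data
    num_buckets = 3

    bucket_sizes, bucket_ranges = np.histogram(seq_lengths, bins=num_buckets)
    # keep the upper edge of every non-empty bin as that bucket's max length
    bucket_seqlen = [int(math.ceil(bucket_ranges[i + 1]))
                     for i, bs in enumerate(bucket_sizes) if bs > 0]

    bins = [[] for _ in bucket_seqlen]
    for sl in seq_lengths:
        for j, bsl in enumerate(bucket_seqlen):
            if sl < bsl or j == len(bucket_seqlen) - 1:
                bins[j].append(sl)
                break

    print(bucket_seqlen, bins)   # [16, 28, 40] [[3, 4, 5], [20, 22, 25], [40]]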
Example #5
 def extend_vocab(self, new_vocab, max_tokens=10000):
     assert isinstance(new_vocab, list)
     if max_tokens < 0:
         max_tokens = 10000
     base_index = self.__len__()
     added = 0
     for word in new_vocab:
         if added >= max_tokens:
             break
         if word not in self._token2id:
             self._token2id[word] = base_index + added
             self._id2token.append(word)
             added += 1
     logger.info('%d new words have been added to vocab' % added)
     return added
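extend_vocab only appends words that are not already known, caps the number of additions, and keeps _token2id and _id2token in sync. A self-contained toy vocabulary (a minimal stand-in exposing just the fields the method touches, not the library's real class) shows the effect:

    import logging

    logger = logging.getLogger(__name__)

    class ToyVocab:
        """Minimal stand-in with the fields extend_vocab() relies on."""
        def __init__(self, tokens):
            self._id2token = list(tokens)
            self._token2id = {t: i for i, t in enumerate(tokens)}

        def __len__(self):
            return len(self._id2token)

        # copy of the method above, unchanged
        def extend_vocab(self, new_vocab, max_tokens=10000):
            assert isinstance(new_vocab, list)
            if max_tokens < 0:
                max_tokens = 10000
            base_index = self.__len__()
            added = 0
            for word in new_vocab:
                if added >= max_tokens:
                    break
                if word not in self._token2id:
                    self._token2id[word] = base_index + added
                    self._id2token.append(word)
                    added += 1
            logger.info('%d new words have been added to vocab' % added)
            return added

    v = ToyVocab(['<pad>', '<unk>', 'cat'])
    v.extend_vocab(['cat', 'dog', 'fish'], max_tokens=2)   # 'cat' is already known
    print(v._id2token)   # ['<pad>', '<unk>', 'cat', 'dog', 'fish']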
Example #6
 def predict(self,
             x: Dict[str, List[List[str]]],
             batch_size=64,
             return_prob=False):
     start = time.time()
     x_c = deepcopy(x)
     x_len = [item[-1] for item in x_c['token']]
     x_c['token'] = [item[:-1] for item in x_c['token']]
     x_seq = BasicIterator('sequence_labeling',
                           self.transformer,
                           x_c,
                           batch_size=batch_size)
     result = self.model.model.predict_generator(x_seq)
     if return_prob:
         y_pred = [result[idx][:l] for idx, l in enumerate(x_len)]
     else:
         y_pred = self.transformer.inverse_transform(result, lengths=x_len)
     used_time = time.time() - start
     logger.info('predicted {} samples in {:4.1f}s'.format(
         len(x['token']), used_time))
     return y_pred
Example #7
    def train(self,
              x_ori,
              y,
              transformer,
              seq_type='bucket',
              return_attention=False):
        self.transformer = transformer
        self.feature_keys = list(x_ori.keys())

        if self.train_mode == 'single':
            x = deepcopy(x_ori)
            x_len = [item[-1] for item in x['token']]
            x['token'] = [item[:-1] for item in x['token']]

            # model initialization
            self.single_model.forward()
            logger.info('%s model structure...' % self.model_name)
            self.single_model.model.summary()

            # split dataset
            indices = np.random.permutation(len(x['token']))
            cut_point = int(len(x['token']) * (1 - self.test_size))
            train_idx, valid_idx = indices[:cut_point], indices[cut_point:]
            x_train = {
                k: [x[k][i] for i in train_idx]
                for k in self.feature_keys
            }
            x_valid = {
                k: [x[k][i] for i in valid_idx]
                for k in self.feature_keys
            }
            y_train, y_valid = [y[i]
                                for i in train_idx], [y[i] for i in valid_idx]
            x_len_train, x_len_valid = [x_len[i] for i in train_idx
                                        ], [x_len[i] for i in valid_idx]
            logger.info('train/valid set: {}/{}'.format(
                train_idx.shape[0], valid_idx.shape[0]))

            # transform data to sequence data streamer
            train_batches, valid_batches = self.data_generator(
                seq_type, x_train, x_valid, y_train, y_valid, x_len_train,
                x_len_valid)

            # define callbacks
            history = History(self.metric)
            self.callbacks = get_callbacks(history=history,
                                           metric=self.metric[0],
                                           log_dir=self.checkpoint_path,
                                           valid=valid_batches,
                                           transformer=transformer,
                                           attention=return_attention)

            # model compile
            self.single_model.model.compile(
                loss=self.single_model.get_loss(),
                optimizer=self.optimizer,
                metrics=self.single_model.get_metrics())

            # save transformer and model parameters
            if not self.checkpoint_path.exists():
                self.checkpoint_path.mkdir()
            transformer.save(self.checkpoint_path / 'transformer.h5')
            invalid_params = self.single_model.invalid_params
            param_file = self.checkpoint_path / 'model_parameters.json'
            self.single_model.save_params(param_file, invalid_params)
            logger.info('saving model parameters and transformer to {}'.format(
                self.checkpoint_path))

            # actual training start
            self.single_model.model.fit_generator(
                generator=train_batches,
                epochs=self.max_epoch,
                callbacks=self.callbacks,
                shuffle=self.shuffle,
                validation_data=valid_batches)
            print('best {}: {:04.2f}'.format(
                self.metric[0],
                max(history.metrics[self.metric[0]]) * 100))
            return self.single_model.model, history

        elif self.train_mode == 'fold':
            x = deepcopy(x_ori)
            x_len = [item[-1] for item in x['token']]
            x['token'] = [item[:-1] for item in x['token']]
            x_token_first = x['token'][0]

            fold_size = len(x['token']) // self.fold_cnt
            scores = []
            logger.info('%d-fold starts!' % self.fold_cnt)

            for fold_id in range(self.fold_cnt):
                print('\n------------------------ fold ' + str(fold_id) +
                      '------------------------')

                assert x_token_first == x['token'][0]
                model_init = self.fold_model
                model_init.forward()

                fold_start = fold_size * fold_id
                fold_end = fold_start + fold_size
                if fold_id == self.fold_cnt - 1:
                    # the last fold absorbs any leftover samples
                    fold_end = len(x['token'])
                if fold_id == 0:
                    logger.info('%s model structure...' % self.model_name)
                    model_init.model.summary()

                x_train = {
                    k: x[k][:fold_start] + x[k][fold_end:]
                    for k in self.feature_keys
                }
                x_len_train = x_len[:fold_start] + x_len[fold_end:]
                y_train = y[:fold_start] + y[fold_end:]
                x_valid = {
                    k: x[k][fold_start:fold_end]
                    for k in self.feature_keys
                }
                x_len_valid = x_len[fold_start:fold_end]
                y_valid = y[fold_start:fold_end]

                train_batches, valid_batches = self.data_generator(
                    seq_type, x_train, x_valid, y_train, y_valid, x_len_train,
                    x_len_valid)

                history = History(self.metric)
                self.callbacks = get_callbacks(history=history,
                                               metric=self.metric[0],
                                               valid=valid_batches,
                                               transformer=transformer,
                                               attention=return_attention)

                model_init.model.compile(loss=model_init.get_loss(),
                                         optimizer=self.optimizer,
                                         metrics=model_init.get_metrics())

                model_init.model.fit_generator(generator=train_batches,
                                               epochs=self.max_epoch,
                                               callbacks=self.callbacks,
                                               shuffle=self.shuffle,
                                               validation_data=valid_batches)
                scores.append(max(history.metrics[self.metric[0]]))

            logger.info(
                'training finished! mean {} score: {:4.2f} (±{:4.2f})'.format(
                    self.metric[0],
                    np.mean(scores) * 100,
                    np.std(scores) * 100))
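In 'fold' mode the cross-validation split is plain list slicing: fold i validates on the block x[fold_start:fold_end] and trains on everything around it, with the last fold absorbing any remainder. A self-contained sketch of that slicing on toy data:

    # toy stand-in for the per-fold slicing in the 'fold' branch
    data = list(range(10))
    fold_cnt = 3
    fold_size = len(data) // fold_cnt        # 3

    for fold_id in range(fold_cnt):
        fold_start = fold_size * fold_id
        fold_end = fold_start + fold_size
        if fold_id == fold_cnt - 1:          # last fold takes the leftovers
            fold_end = len(data)
        valid = data[fold_start:fold_end]
        train = data[:fold_start] + data[fold_end:]
        print(fold_id, valid)   # 0 [0, 1, 2] / 1 [3, 4, 5] / 2 [6, 7, 8, 9]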
Example #8
    def __init__(self, mode, fname='', tran_fname='',
                 config=None, task_type=None, data_format=''):
        self.mode = mode
        self.fname = fname
        self.inner_char = False
        self.use_seg = False
        self.use_radical = False
        self.radical_dict = None
        
        if data_format != '':
            self.data_format = data_format

        if config:
            self.basic_token = config['data']['basic_token']
        self.html_texts = re.compile(r'('+'|'.join(REGEX_STR)+')', re.UNICODE)

        if task_type:
            if mode == 'train' and config is None:
                logger.error('please specify the config file path')
                sys.exit()
            self.task_type = task_type
        else:
            try:
                self.task_type = re.findall(r'config_(\w+)\.yaml', config)[0]
            except:
                logger.error('please check your config filename')
                sys.exit()

        if mode == 'train':
            if 'data' in config:
                self.config = config
                self.data_config = config['data']
                self.embed_config = config['embed']
                if self.task_type == 'sequence_labeling':
                    self.data_format = self.data_config['format']
                if self.basic_token == 'word':
                    self.max_tokens = self.data_config['max_words']
                    self.inner_char = self.data_config['inner_char']
                elif self.basic_token == 'char':
                    self.max_tokens = self.data_config['max_chars']
                    if self.task_type == 'sequence_labeling':
                        self.use_seg = self.data_config['use_seg']
                        self.use_radical = self.data_config['use_radical']
                        if self.config['train']['metric'] not in ['f1_seq']:
                            self.config['train']['metric'] = 'f1_seq'
                            logger.warning('sequence labeling task currently only supports the f1_seq callback')
                    elif self.task_type == 'classification':
                        if self.config['train']['metric'] in ['f1_seq']:
                            self.config['train']['metric'] = 'f1'
                            logger.warning('text classification task does not support the f1_seq callback, changed to f1')
                else:
                    logger.error('invalid token type, only supports word and char')
                    sys.exit()
            else:
                logger.error("please pass in the correct config dict")
                sys.exit()

            if self.basic_token == 'char':
                self.use_seg = config['data']['use_seg']
                self.use_radical = config['data']['use_radical']

            if self.use_radical:
                radical_file = Path(os.path.dirname(
                    os.path.realpath(__file__))) / 'data' / 'radical.txt'
                self.radical_dict = {line.split()[0]: line.split()[1].strip()
                                    for line in open(radical_file, encoding='utf8')}

            self.transformer = IndexTransformer(
                task_type=self.task_type,
                max_tokens=self.max_tokens,
                max_inner_chars=self.data_config['max_inner_chars'],
                use_inner_char=self.inner_char,
                use_seg=self.use_seg,
                use_radical=self.use_radical,
                radical_dict=self.radical_dict,
                basic_token=self.basic_token)

        elif mode != 'train':
            if len(tran_fname) > 0:
                self.transformer = IndexTransformer.load(tran_fname)
                logger.info('transformer loaded')
                self.basic_token = self.transformer.basic_token
                self.use_seg = self.transformer.use_seg
                self.use_radical = self.transformer.use_radical
                self.inner_char = self.transformer.use_inner_char
                self.max_tokens = self.transformer.max_tokens
            else:
                logger.error("please pass in the transformer's filepath")
                sys.exit()

        if fname:
            self.load_data()
            self.fit()
        else:
            self.texts = []
            self.labels = []
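When task_type is not passed explicitly, the constructor falls back to pulling it out of the config filename with a regular expression. A tiny illustration of that pattern (the filenames here are made up):

    import re

    for fname in ['config_classification.yaml', 'config_sequence_labeling.yaml']:
        task_type = re.findall(r'config_(\w+)\.yaml', fname)[0]
        print(task_type)   # classification / sequence_labeling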
Example #9
    def fit(self):
        if self.mode != 'predict':
            if self.basic_token == 'char':
                if self.task_type == 'sequence_labeling':
                    self.texts = [
                        word2char(x, y, self.task_type, self.use_seg, self.radical_dict)
                        for x, y in zip(self.texts, self.labels)]
                    self.texts = {k: [dic[k] for dic in self.texts] for k in self.texts[0]}
                    self.labels = self.texts['label']
                    del self.texts['label']
                else:
                    self.texts = {'token': [word2char(x, task_type=self.task_type) for x in self.texts]}
            else:
                self.texts = {'token': self.texts}
            if self.mode == 'train':
                self.config['mode'] = self.mode
                self.transformer.fit(self.texts['token'], self.labels)
                logger.info('transformer fitting complete')
                embed = {}
                if self.embed_config['pre']:
                    token_embed, dim = load_vectors(
                        self.embed_config[self.basic_token]['path'], self.transformer._token_vocab)
                    embed[self.basic_token] = token_embed
                    logger.info('loaded pre-trained embeddings')
                else:
                    logger.info('training embeddings from scratch')
                    dim = self.embed_config[self.basic_token]['dim']
                    embed[self.basic_token] = None
                # update config
                self.config['nb_classes'] = self.transformer.label_size
                self.config['nb_tokens'] = self.transformer.token_vocab_size
                self.config['extra_features'] = []
                if self.inner_char:
                    self.config['nb_char_tokens'] = self.transformer.char_vocab_size
                else:
                    self.config['nb_char_tokens'] = 0
                    self.config['use_inner_char'] = False
                if self.use_seg:
                    self.config['nb_seg_tokens'] = self.transformer.seg_vocab_size
                    self.config['extra_features'].append('seg')
                    self.config['use_seg'] = self.use_seg
                else:
                    self.config['nb_seg_tokens'] = 0
                    self.config['use_seg'] = False
                if self.use_radical:
                    self.config['nb_radical_tokens'] = self.transformer.radical_vocab_size
                    self.config['extra_features'].append('radical')
                    self.config['use_radical'] = self.use_radical
                else:
                    self.config['nb_radical_tokens'] = 0
                    self.config['use_radical'] = False
                self.config['embedding_dim'] = dim
                self.config['token_embeddings'] = embed[self.basic_token]
                self.config['maxlen'] = self.max_tokens
                self.config['task_type'] = self.task_type
        else:
            if self.basic_token == 'char':
                self.texts = [
                    word2char(x, None, self.task_type, self.use_seg, self.radical_dict)
                    for x in self.texts]
                self.texts = {k: [dic[k] for dic in self.texts]
                              for k in self.texts[0]}
            else:
                self.texts = {'token': self.texts}

        lengths = [len(item) if len(item) <= self.max_tokens else self.max_tokens
                   for item in self.texts['token']]
        self.texts['token'] = list(map(list, self.texts['token']))
        self.texts['token'] = [item + [lengths[idx]] for idx, item in enumerate(self.texts['token'])]
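The last two lines encode a small convention used throughout these examples: each sample's clipped length is appended as the final element of its token list, and the predict methods in Examples #2 and #6 later read it back with item[-1] and strip it with item[:-1]. A self-contained sketch of that round trip:

    max_tokens = 5
    texts = {'token': [['I', 'like', 'NLP'],
                       ['a', 'very', 'long', 'toy', 'sentence', 'indeed']]}

    # fit(): append each sample's clipped length after its tokens
    lengths = [min(len(item), max_tokens) for item in texts['token']]
    texts['token'] = [item + [lengths[i]] for i, item in enumerate(texts['token'])]
    print(texts['token'][0])              # ['I', 'like', 'NLP', 3]

    # predict(): recover the lengths and restore the plain token lists
    x_len = [item[-1] for item in texts['token']]
    texts['token'] = [item[:-1] for item in texts['token']]
    print(x_len)                          # [3, 5]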