Example #1
    @classmethod
    def from_config(cls, config, for_train=True):
        # Select the casing mode; both modes share the same loader factory.
        if config["data_type"] == "bert_cased":
            do_lower_case = False
            fn = get_bert_data_loaders
        elif config["data_type"] == "bert_uncased":
            do_lower_case = True
            fn = get_bert_data_loaders
        else:
            raise NotImplementedError("No requested mode :(.")
        # Build train/valid data loaders only when both paths are given and
        # we are in training mode; otherwise create just the tokenizer.
        if config["train_path"] and config["valid_path"] and for_train:
            fn_res = fn(config["train_path"],
                        config["valid_path"],
                        config["vocab_file"],
                        config["batch_size"],
                        config["cuda"],
                        config["is_cls"],
                        do_lower_case,
                        config["max_seq_len"],
                        config["is_meta"],
                        label2idx=config["label2idx"],
                        cls2idx=config["cls2idx"])
        else:
            fn_res = (None, None,
                      tokenization.FullTokenizer(
                          vocab_file=config["vocab_file"],
                          do_lower_case=do_lower_case),
                      config["label2idx"],
                      config["max_seq_len"],
                      config["cls2idx"])
        return cls(config["train_path"],
                   config["valid_path"],
                   config["vocab_file"],
                   config["data_type"],
                   *fn_res,
                   batch_size=config["batch_size"],
                   cuda=config["cuda"],
                   is_meta=config["is_meta"])
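For orientation, a minimal usage sketch: the config keys mirror those read above, the file paths are placeholders, and the owning class (not shown in this snippet) is assumed to be the data container exposing from_config, called BertNerData here purely for illustration.

# Hypothetical usage sketch; paths and the class name are placeholders.
config = {
    "data_type": "bert_cased",
    "train_path": "data/train.csv",
    "valid_path": "data/valid.csv",
    "vocab_file": "bert/vocab.txt",
    "batch_size": 16,
    "cuda": True,
    "is_cls": False,
    "max_seq_len": 424,
    "is_meta": False,
    "label2idx": None,
    "cls2idx": None,
}
data = BertNerData.from_config(config)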
Example #2
def get_bert_data_loaders(train,
                          valid,
                          vocab_file,
                          batch_size=16,
                          cuda=True,
                          is_cls=False,
                          do_lower_case=False,
                          max_seq_len=424,
                          is_meta=False,
                          label2idx=None,
                          cls2idx=None):
    # Read the raw train/valid CSVs into DataFrames.
    train = pd.read_csv(train)
    valid = pd.read_csv(valid)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                           do_lower_case=do_lower_case)
    train_f, label2idx = get_data(train,
                                  tokenizer,
                                  label2idx,
                                  cls2idx=cls2idx,
                                  is_cls=is_cls,
                                  max_seq_len=max_seq_len,
                                  is_meta=is_meta)
    # In joint (cls) mode, get_data returns a (label2idx, cls2idx) pair.
    if is_cls:
        label2idx, cls2idx = label2idx
    train_dl = DataLoaderForTrain(train_f,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  cuda=cuda)
    valid_f, label2idx = get_data(valid,
                                  tokenizer,
                                  label2idx,
                                  cls2idx=cls2idx,
                                  is_cls=is_cls,
                                  max_seq_len=max_seq_len,
                                  is_meta=is_meta)
    if is_cls:
        label2idx, cls2idx = label2idx
    valid_dl = DataLoaderForTrain(valid_f,
                                  batch_size=batch_size,
                                  cuda=cuda,
                                  shuffle=False)
    if is_cls:
        return train_dl, valid_dl, tokenizer, label2idx, max_seq_len, cls2idx
    return train_dl, valid_dl, tokenizer, label2idx, max_seq_len
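A minimal call sketch, assuming the project-internal helpers used above (tokenization, get_data, DataLoaderForTrain) are importable and with placeholder CSV and vocab paths. Note the return arity: with is_cls=True a sixth element, cls2idx, is appended.

# Hypothetical usage sketch; paths are placeholders.
train_dl, valid_dl, tokenizer, label2idx, max_seq_len = get_bert_data_loaders(
    "data/train.csv",
    "data/valid.csv",
    "bert/vocab.txt",
    batch_size=32,
    cuda=False)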
Example #3
    @classmethod
    def create(cls,
               train_path,
               valid_path,
               vocab_file,
               batch_size=16,
               cuda=True,
               is_cls=False,
               data_type="bert_cased",
               max_seq_len=424,
               is_meta=False,
               for_train=True,
               label2idx={},
               cls2idx={}):
        if data_type == "bert_cased":
            do_lower_case = False
            fn = get_bert_data_loaders
        elif data_type == "bert_uncased":
            do_lower_case = True
            fn = get_bert_data_loaders
        else:
            raise NotImplementedError("No requested mode :(.")

        if for_train:
            return cls(train_path,
                       valid_path,
                       vocab_file,
                       data_type,
                       *fn(train_path, valid_path, vocab_file, batch_size,
                           cuda, is_cls, do_lower_case, max_seq_len, is_meta),
                       batch_size=batch_size,
                       cuda=cuda,
                       is_meta=is_meta)
        else:
            return cls(train_path,
                       valid_path,
                       vocab_file,
                       data_type,
                       *(None, None,
                         tokenization.FullTokenizer(
                             vocab_file, do_lower_case=do_lower_case),
                         label2idx, max_seq_len, cls2idx),
                       batch_size=batch_size,
                       cuda=cuda,
                       is_meta=is_meta)
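A hedged sketch of the inference-time path (for_train=False), where no CSVs are read and only the tokenizer plus the supplied vocab maps are forwarded to the constructor; the owning class name and the label2idx contents are illustrative assumptions.

# Hypothetical usage sketch; class name and label2idx contents are assumptions.
data = BertNerData.create(None, None, "bert/vocab.txt",
                          data_type="bert_uncased",
                          for_train=False,
                          label2idx={"<pad>": 0, "O": 1})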
Example #4
    @classmethod
    def create(cls,
               bert_vocab_file,
               config_path=None,
               train_path=None,
               valid_path=None,
               idx2label=None,
               bert_model_type="bert_cased",
               idx2cls=None,
               max_seq_len=424,
               batch_size=16,
               is_cls=False,
               idx2label_path=None,
               idx2cls_path=None,
               pad="<pad>",
               device="cuda:0",
               clear_cache=True,
               data_columns=["0", "1", "2"],
               shuffle=True,
               dir_config=None,
               prc_text=preprocess_text):
        """
        Create or skip data loaders, and load or create vocabs.
        The DataFrame should have 2 or 3 columns; see the data_columns description for the structure.

        Parameters
        ----------
        bert_vocab_file : str
            Path of vocabulary for BERT tokenizer.
        config_path : str or None, optional (default=None)
            Path to the BertNerData config.
        train_path : str or None, optional (default=None)
            Path to the train data frame. If not None, updates idx2label, idx2cls, idx2meta.
        valid_path : str or None, optional (default=None)
            Path to the valid data frame. If not None, updates idx2label, idx2cls, idx2meta.
        idx2label : list or None, optional (default=None)
            Map from index to label.
        bert_model_type : str, optional (default="bert_cased")
            Mode of BERT model (CASED or UNCASED).
        idx2cls : list or None, optional (default=None)
            Map from index to cls.
        max_seq_len : int, optional (default=424)
            Max sequence length.
        batch_size : int, optional (default=16)
            Batch size.
        is_cls : bool, optional (default=False)
            Whether to use the joint model or the single one.
        idx2label_path : str or None, optional (default=None)
            Path to the idx2label map. If not None and idx2label is None, idx2label is loaded from this path.
        idx2cls_path : str or None, optional (default=None)
            Path to the idx2cls map. If not None and idx2cls is None, idx2cls is loaded from this path.
        pad : str, optional (default="<pad>")
            Padding token.
        device : str, optional (default="cuda:0")
            Run the model on GPU or CPU. If "cpu", tensors in the data loaders are not pinned to GPU.
            Notation is the same as for torch.cuda.device.
        clear_cache : bool, optional (default=True)
            If True, rewrite all vocabs and BertNerData config.
        data_columns : list[str]
            Columns of the pandas.DataFrame:
                data_columns[0] - the labels column; labels should be joined by spaces;
                data_columns[1] - the tokens column; the input sequence should be tokenized and joined by spaces;
                data_columns[2] - the cls column (if is_cls is set).
        shuffle : bool, optional (default=True)
            Whether to shuffle the data.
        dir_config : str or None, optional (default=None)
            Directory for storing vocabs if the paths above are not set.
        prc_text : callable, optional (default=preprocess_text)
            Function for preprocessing text. By default it removes some bad unicode tokens.
            Note: it does not look inside words; only tokens that fully match a bad symbol are removed.

        Returns
        ----------
        data : BertNerData
            Created object of BertNerData.
        """
        idx2label_path = if_none(
            idx2label_path,
            os.path.join(dir_config, "idx2label.json")
            if dir_config is not None else None)

        if idx2label is None and idx2label_path is None:
            raise ValueError("Must set idx2label or idx2label_path.")

        if bert_model_type == "bert_cased":
            do_lower_case = False
        elif bert_model_type == "bert_uncased":
            do_lower_case = True
        else:
            raise NotImplementedError("No requested mode :(.")

        tokenizer = tokenization.FullTokenizer(vocab_file=bert_vocab_file,
                                               do_lower_case=do_lower_case)

        if idx2label is None and os.path.exists(
                str(idx2label_path)) and not clear_cache:
            idx2label = read_json(idx2label_path)
        if is_cls:
            idx2cls_path = if_none(
                idx2cls_path,
                os.path.join(dir_config, "idx2cls.json")
                if dir_config is not None else None)
        if is_cls and idx2cls is None and os.path.exists(
                str(idx2cls_path)) and not clear_cache:
            idx2cls = read_json(idx2cls_path)

        config_path = if_none(
            config_path,
            os.path.join(dir_config, "data_ner.json")
            if dir_config is not None else None)

        data = cls(bert_vocab_file=bert_vocab_file,
                   train_path=train_path,
                   valid_path=valid_path,
                   idx2label=idx2label,
                   config_path=config_path,
                   tokenizer=tokenizer,
                   bert_model_type=bert_model_type,
                   idx2cls=idx2cls,
                   max_seq_len=max_seq_len,
                   batch_size=batch_size,
                   is_cls=is_cls,
                   idx2label_path=idx2label_path,
                   idx2cls_path=idx2cls_path,
                   pad=pad,
                   device=device,
                   data_columns=data_columns,
                   shuffle=shuffle,
                   prc_text=prc_text)

        if train_path is not None:
            _ = data.load_train_dl(train_path)

        if valid_path is not None:
            _ = data.load_valid_dl(valid_path)

        if clear_cache:
            data.save_vocabs_and_config()
        return data
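Finally, a sketch of a typical training-time call for this version of create, assuming train/valid CSVs with the default "0"/"1"/"2" columns described in the docstring; all paths are placeholders. With dir_config set, the idx2label/idx2cls/config paths are derived automatically via if_none.

# Hypothetical usage sketch; paths are placeholders.
data = BertNerData.create(
    "bert/vocab.txt",
    train_path="data/train.csv",
    valid_path="data/valid.csv",
    dir_config="data/",
    device="cpu",
    clear_cache=True)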