Example #1
def main():
    arg_parser = argparse.ArgumentParser(
        description="Minimalist Transformer for Generation")
    arg_parser.add_argument("mode",
                            choices=["preprocess", "train"],
                            help="preprocess the dataset or train a model")
    arg_parser.add_argument('-f',
                            '--config',
                            default='default.yaml',
                            help='Configuration file to load.')
    args = arg_parser.parse_args()
    with open(args.config) as config_file:
        configs = yaml.load(config_file, Loader=yaml.FullLoader)

    if args.mode == "train":
        train_manager(configs)

    elif args.mode == "preprocess":
        train, _, test = reverse_dataset()
        text_encoder = WhitespaceEncoder(train['source'] + train['target'])
        text_encoder.stoi['</s>'] = 2
        print(text_encoder.stoi)
        with open('.preprocess.pkl', 'wb') as filehandler:
            pickle.dump((text_encoder, train, test), filehandler)

    else:
        raise ValueError("Unknown mode")
Example #2
def prepare_sample(
    sample: dict, text_encoder: WhitespaceEncoder
) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Function that receives a sample from the Dataset iterator and prepares t
    he input to feed the transformer model.

    :param sample: dictionary containing the inputs to build the batch 
        (e.g: [{'source': '9 0', 'target': '0 9'}, {'source': '34 3 4', 'target': '4 3 34'}])
    :param text_encoder: Torch NLP text encoder for tokenization and vectorization.
    """
    sample = collate_tensors(sample)
    input_seqs, input_lengths = text_encoder.batch_encode(sample['source'])
    target_seqs, target_lengths = text_encoder.batch_encode(sample['target'])
    # prepend BOS tokens and drop the last target token to build the shifted decoder input
    bos_tokens = torch.full([target_seqs.size(0), 1],
                            text_encoder.stoi['<s>'],
                            dtype=torch.long)
    shifted_target = torch.cat((bos_tokens, target_seqs[:, :-1]), dim=1)
    return input_seqs, input_lengths, target_seqs, shifted_target, target_lengths
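
A minimal usage sketch of the function above on the toy reversal data from its docstring. It assumes the encoder's reserved vocabulary already contains the '<s>' token; otherwise register it manually, as Example #1 does for '</s>'.

# Hypothetical illustration: fit an encoder on the toy batch and prepare it.
batch = [{'source': '9 0', 'target': '0 9'},
         {'source': '34 3 4', 'target': '4 3 34'}]
text_encoder = WhitespaceEncoder([row['source'] for row in batch] +
                                 [row['target'] for row in batch])
inputs, input_lengths, targets, shifted_targets, target_lengths = prepare_sample(
    batch, text_encoder)
print(inputs.shape, shifted_targets.shape)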
Example #3
    def _get_tokenizer(self, X1, X2):
        tokenizer_pickle_file = 'word2vec/data_tokenizer.pickle'

        if os.path.isfile(tokenizer_pickle_file):
            with open(tokenizer_pickle_file, 'rb') as handle:
                tokenizer = pickle.load(handle)
        else:
            X1 = list(map(str, list(X1)))
            X2 = list(map(str, list(X2)))

            # use all sentences to build the dictionary
            tokenizer = WhitespaceEncoder(X1 + X2)

            with open(tokenizer_pickle_file, 'wb') as handle:
                pickle.dump(tokenizer,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        return tokenizer
Example #4
def prepare_sample(
    sample: dict, text_encoder: WhitespaceEncoder, label_encoder: LabelEncoder,
    max_length: int
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Function that receives a sample from the Dataset iterator and prepares t
    he input to feed the transformer model.
    :param sample: dictionary containing the inputs to build the batch 
        (e.g: [{'source': 'This flight was amazing!', 'target': 'pos'}, 
               {'source': 'I hate Iberia', 'target': 'neg'}])
    :param text_encoder: Torch NLP text encoder for tokenization and vectorization.
    :param label_encoder: Torch NLP label encoder for vectorization of labels.
    :param max_length: Max length of the input sequences.
         If a sequence passes that value it is truncated.
    """
    sample = collate_tensors(sample)
    input_seqs, input_lengths = text_encoder.batch_encode(sample['source'])
    target_seqs = label_encoder.batch_encode(sample['target'])
    # Truncate inputs that exceed max_length and clamp the lengths so that the
    # padding mask matches the truncated sequences.
    if input_seqs.size(1) > max_length:
        input_seqs = input_seqs[:, :max_length]
        input_lengths = torch.clamp(input_lengths, max=max_length)
    input_mask = lengths_to_mask(input_lengths).unsqueeze(1)
    return input_seqs, input_mask, target_seqs
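
For context, a sketch of how a collate function like this is typically attached to a DataLoader; `train_dataset`, the fitted encoders, and the batch size / max length values are hypothetical placeholders.

from functools import partial

from torch.utils.data import DataLoader

# Hypothetical wiring: `train_dataset` yields {'source': ..., 'target': ...}
# dicts, and `text_encoder` / `label_encoder` were fitted beforehand.
collate_fn = partial(prepare_sample,
                     text_encoder=text_encoder,
                     label_encoder=label_encoder,
                     max_length=128)
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=collate_fn)
for input_seqs, input_mask, target_seqs in train_loader:
    pass  # feed the transformer encoder / classification head here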
Example #5
#!/usr/bin/python
import torch
from torchnlp.encoders.text import WhitespaceEncoder, stack_and_pad_tensors
from torchnlp.samplers import BucketBatchSampler
from torchnlp.utils import collate_tensors

loaded_data = ["now this ain't funny", "so don't you dare laugh"]
encoder = WhitespaceEncoder(loaded_data)
encoded_data = [encoder.encode(example) for example in loaded_data]

print("encoded_data", encoded_data)

# Replace the encoded data with dummy variable-length tensors to show how the
# sampler buckets sequences of similar length.
encoded_data = [torch.randn(2), torch.randn(3), torch.randn(4), torch.randn(5)]

train_sampler = torch.utils.data.sampler.SequentialSampler(encoded_data)
train_batch_sampler = BucketBatchSampler(
    train_sampler,
    batch_size=2,
    drop_last=False,
    sort_key=lambda i: encoded_data[i].shape[0])

batches = [[encoded_data[i] for i in batch] for batch in train_batch_sampler]
batches = [
    collate_tensors(batch, stack_tensors=stack_and_pad_tensors)
    for batch in batches
]
Example #6
def encoder(input_):
    return WhitespaceEncoder([input_])
Example #7
    torch.cuda.set_device(args.gpu)

# load dataset
train, dev, test = snli_dataset(train=True, dev=True, test=True)

# Preprocess
for row in itertools.chain(train, dev, test):
    row['premise'] = row['premise'].lower()
    row['hypothesis'] = row['hypothesis'].lower()

# Make Encoders
sentence_corpus = [row['premise'] for row in itertools.chain(train, dev, test)]
sentence_corpus += [
    row['hypothesis'] for row in itertools.chain(train, dev, test)
]
sentence_encoder = WhitespaceEncoder(sentence_corpus)

label_corpus = [row['label'] for row in itertools.chain(train, dev, test)]
label_encoder = LabelEncoder(label_corpus)

# Encode
for row in itertools.chain(train, dev, test):
    row['premise'] = sentence_encoder.encode(row['premise'])
    row['hypothesis'] = sentence_encoder.encode(row['hypothesis'])
    row['label'] = label_encoder.encode(row['label'])

config = args
config.n_embed = sentence_encoder.vocab_size
config.d_out = label_encoder.vocab_size
config.n_cells = config.n_layers
Example #8
import torch
from torchnlp.encoders.text import WhitespaceEncoder
from torchnlp.word_to_vector import GloVe

encoder = WhitespaceEncoder(["now this ain't funny", "so don't you dare laugh"])

vocab = set(encoder.vocab)
pretrained_embedding = GloVe(name='6B', dim=300, is_include=lambda w: w in vocab)
embedding_weights = torch.Tensor(encoder.vocab_size, pretrained_embedding.dim)
for i, token in enumerate(encoder.vocab):
    embedding_weights[i] = pretrained_embedding[token]
print("")
Example #10
class RasaIntentEntityDataset(torch.utils.data.Dataset):
    """
    RASA NLU markdown file lines based Custom Dataset Class

    Dataset Example in nlu.md

    ## intent:intent_데이터_자동_선물하기_멀티턴                <- intent name
    - T끼리 데이터 주기적으로 보내기                            <- utterance without entity
    - 인터넷 데이터 [달마다](Every_Month)마다 보내줄 수 있어?    <- utterance with entity
    
    """
    def __init__(
        self,
        markdown_lines: List[str],
        tokenizer,
        seq_len=128,
    ):
        self.intent_dict = {}
        self.entity_dict = {}
        self.entity_dict["O"] = 0  # using BIO tagging

        self.dataset = []
        self.seq_len = seq_len

        intent_value_list = []
        entity_type_list = []

        current_intent_focus = ""

        text_list = []

        for line in tqdm(
                markdown_lines,
                desc="Organizing Intent & Entity dictionary in NLU markdown file ...",
        ):
            if len(line.strip()) < 2:
                current_intent_focus = ""
                continue

            if "## " in line:
                if "intent:" in line:
                    intent_value_list.append(line.split(":")[1].strip())
                    current_intent_focus = line.split(":")[1].strip()
                else:
                    current_intent_focus = ""

            else:
                if current_intent_focus != "":
                    text = line[2:].strip().lower()

                    for type_str in re.finditer(r"\([a-zA-Z_1-2]+\)", text):
                        entity_type = text[type_str.start() + 1:type_str.end() - 1]
                        entity_type_list.append(entity_type)

                    # strip the "(...)" entity-type markers and the '[' ']' brackets
                    text = re.sub(r"\([a-zA-Z_1-2]+\)", "", text)
                    text = text.replace("[", "").replace("]", "")

                    if len(text) > 0:
                        text_list.append(text.strip())

        # dataset tokenizer setting
        if "ElectraTokenizer" in str(type(tokenizer)):
            self.tokenizer = tokenizer
            self.pad_token_id = 0
            self.unk_token_id = 1
            self.eos_token_id = 3  # [SEP] token
            self.bos_token_id = 2  # [CLS] token

        else:
            if tokenizer == 'char':
                self.tokenizer = CharacterEncoder(text_list)

                # torchnlp base special token indices
                self.pad_token_id = 0
                self.unk_token_id = 1
                self.eos_token_id = 2
                self.bos_token_id = 3
            elif tokenizer == 'space':
                self.tokenizer = WhitespaceEncoder(text_list)

                # torchnlp base special token indices
                self.pad_token_id = 0
                self.unk_token_id = 1
                self.eos_token_id = 2
                self.bos_token_id = 3
            elif tokenizer == 'kobert':
                self.tokenizer = kobert_tokenizer()
                self.pad_token_id = 1
                self.unk_token_id = 0
                self.eos_token_id = 3  # [SEP] token
                self.bos_token_id = 2  # [CLS] token
            else:
                raise ValueError('not supported tokenizer type')

        for intent_value in sorted(intent_value_list):
            if intent_value not in self.intent_dict:
                self.intent_dict[intent_value] = len(self.intent_dict)

        for entity_type in sorted(entity_type_list):
            if entity_type + '_B' not in self.entity_dict:
                self.entity_dict[entity_type + '_B'] = len(self.entity_dict)
            if entity_type + '_I' not in self.entity_dict:
                self.entity_dict[entity_type + '_I'] = len(self.entity_dict)

        current_intent_focus = ""

        for line in tqdm(
                markdown_lines,
                desc="Extracting Intent & Entity in NLU markdown files...",
        ):
            if len(line.strip()) < 2:
                current_intent_focus = ""
                continue

            if "## " in line:
                if "intent:" in line:
                    current_intent_focus = line.split(":")[1].strip()
                else:
                    current_intent_focus = ""
            else:
                if current_intent_focus != "":  # intent & entity sentence occur case
                    text = line[2:].strip().lower()

                    entity_value_list = []
                    for value in re.finditer(r"\[(.*?)\]", text):
                        entity_value = text[value.start() + 1:value.end() - 1]
                        entity_value_list.append(entity_value.replace("[", "").replace("]", ""))

                    entity_type_list = []
                    for type_str in re.finditer(r"\([a-zA-Z_1-2]+\)", text):
                        entity_type = text[type_str.start() + 1:type_str.end() - 1]
                        entity_type_list.append(entity_type)

                    # strip the "(...)" entity-type markers and the '[' ']' brackets
                    text = re.sub(r"\([a-zA-Z_1-2]+\)", "", text)
                    text = text.replace("[", "").replace("]", "")

                    if len(text) > 0:
                        each_data_dict = {}
                        each_data_dict["text"] = text.strip()
                        each_data_dict["intent"] = current_intent_focus
                        each_data_dict["intent_idx"] = self.intent_dict[
                            current_intent_focus]
                        each_data_dict["entities"] = []

                        for value, type_str in zip(entity_value_list,
                                                   entity_type_list):
                            for entity in re.finditer(value, text):
                                entity_tokens = self.tokenize(value)

                                for i, entity_token in enumerate(entity_tokens):
                                    BIO_type_str = type_str + ('_B' if i == 0 else '_I')

                                    token_start = text.find(
                                        entity_token, entity.start(), entity.end())
                                    each_data_dict["entities"].append({
                                        "start": token_start,
                                        "end": token_start + len(entity_token),
                                        "entity": type_str,
                                        "value": entity_token,
                                        "entity_idx": self.entity_dict[BIO_type_str],
                                    })

                        self.dataset.append(each_data_dict)

        print(f"Intents: {self.intent_dict}")
        print(f"Entities: {self.entity_dict}")

    def tokenize(self, text: str, skip_special_char=True):
        if isinstance(self.tokenizer, CharacterEncoder):
            return [char for char in text]
        elif isinstance(self.tokenizer, WhitespaceEncoder):
            return text.split()
        elif "KoBertTokenizer" in str(type(self.tokenizer)):
            if skip_special_char:
                return self.tokenizer.tokenize(text)
            else:
                return [
                    token.replace('▁', '')
                    for token in self.tokenizer.tokenize(text)
                ]
        elif "ElectraTokenizer" in str(type(self.tokenizer)):
            if skip_special_char:
                return self.tokenizer.tokenize(text)
            else:
                return [
                    token.replace('#', '')
                    for token in self.tokenizer.tokenize(text)
                ]
        else:
            raise ValueError('not supported tokenizer type')

    def encode(self,
               text: str,
               padding: bool = True,
               return_tensor: bool = True):
        tokens = self.tokenizer.encode(text)
        if type(tokens) == list:
            tokens = torch.tensor(tokens).long()
        else:
            tokens = tokens.long()

        # The kobert and koelectra tokenizers already prepend [CLS] and append [SEP],
        # so BOS/EOS tokens are only added manually for the torchnlp encoders below.
        if isinstance(self.tokenizer, CharacterEncoder) or isinstance(
                self.tokenizer, WhitespaceEncoder):
            bos_tensor = torch.tensor([self.bos_token_id])
            eos_tensor = torch.tensor([self.eos_token_id])
            tokens = torch.cat((bos_tensor, tokens, eos_tensor), 0)

        if padding:
            if len(tokens) >= self.seq_len:
                tokens = tokens[:self.seq_len]
            else:
                pad_tensor = torch.tensor([self.pad_token_id] *
                                          (self.seq_len - len(tokens)))

                tokens = torch.cat((tokens, pad_tensor), 0)

        if return_tensor:
            return tokens
        else:
            return tokens.numpy()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        tokens = self.encode(self.dataset[idx]["text"])

        intent_idx = torch.tensor([self.dataset[idx]["intent_idx"]])

        entity_idx = np.zeros(self.seq_len, dtype=np.int64)  # the "O" tag maps to index 0

        for entity_info in self.dataset[idx]["entities"]:
            if isinstance(self.tokenizer, CharacterEncoder):
                # Consider [CLS](bos) token
                for i in range(entity_info["start"] + 1,
                               entity_info["end"] + 2):
                    entity_idx[i] = entity_info["entity_idx"]

            elif isinstance(self.tokenizer, WhitespaceEncoder):
                # check whether the entity value is contained in the whitespace-split token
                for token_seq, token_value in enumerate(tokens):
                    # Consider [CLS](bos) token
                    if token_seq == 0:
                        continue

                    for entity_seq, entity_info in enumerate(
                            self.dataset[idx]["entities"]):
                        if entity_info["value"] in self.tokenizer.vocab[
                                token_value.item()]:
                            entity_idx[token_seq] = entity_info["entity_idx"]
                            break

            elif "KoBertTokenizer" in str(type(self.tokenizer)):
                # check whether the sub-word token is contained in the entity value
                for token_seq, token_value in enumerate(tokens):
                    # Consider [CLS](bos) token
                    if token_seq == 0:
                        continue

                    for entity_seq, entity_info in enumerate(
                            self.dataset[idx]["entities"]):
                        if (self.tokenizer.idx2token[token_value.item()]
                                in entity_info["value"]):
                            entity_idx[token_seq] = entity_info["entity_idx"]
                            break

            elif "ElectraTokenizer" in str(type(self.tokenizer)):
                # check whether the sub-word token is contained in the entity value
                for token_seq, token_value in enumerate(tokens):
                    # Consider [CLS](bos) token
                    if token_seq == 0:
                        continue

                    for entity_seq, entity_info in enumerate(
                            self.dataset[idx]["entities"]):
                        if (self.tokenizer.convert_ids_to_tokens(
                            [token_value.item()])[0] in entity_info["value"]):
                            entity_idx[token_seq] = entity_info["entity_idx"]
                            break

        entity_idx = torch.from_numpy(entity_idx)

        return tokens, intent_idx, entity_idx, self.dataset[idx]["text"]

    def get_intent_idx(self):
        return self.intent_dict

    def get_entity_idx(self):
        return self.entity_dict

    def get_vocab_size(self):
        return self.tokenizer.vocab_size

    def get_seq_len(self):
        return self.seq_len
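
A hedged usage sketch of this dataset class; the nlu.md path and the batch size are placeholders, and the whitespace tokenizer option shown is the one defined in the constructor above.

from torch.utils.data import DataLoader

# Hypothetical usage: read a RASA-style nlu.md and tokenize on whitespace.
with open('nlu.md', encoding='utf-8') as nlu_file:
    markdown_lines = nlu_file.readlines()

dataset = RasaIntentEntityDataset(markdown_lines, tokenizer='space', seq_len=128)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for tokens, intent_idx, entity_idx, text in loader:
    # tokens: (batch, seq_len), intent_idx: (batch, 1), entity_idx: (batch, seq_len)
    break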
Example #11
  torch.cuda.set_device(args.gpu)

# load dataset
train, dev, test = wmt_dataset(train=True, dev=True, test=True)

src_key = 'en'
tar_key = 'de'

# Preprocess
for row in itertools.chain(train, dev, test):
  row[src_key] = row[src_key].lower()
  row[tar_key] = row[tar_key].lower()

# Make Encoders
src_corpus = [row[src_key] for row in itertools.chain(train, dev, test)]
src_encoder = WhitespaceEncoder(src_corpus)

tar_corpus = [row[tar_key] for row in itertools.chain(train, dev, test)]
tar_encoder = WhitespaceEncoder(tar_corpus)

# Encode
for row in itertools.chain(train, dev, test):
  row[src_key] = src_encoder.encode(row[src_key])
  row[tar_key] = tar_encoder.encode(row[tar_key])

config = args
config.n_embed = src_encoder.vocab_size
config.d_out = tar_encoder.vocab_size
config.n_cells = config.n_layers
Example #12
def load_data(data_type,
              preprocessing=False,
              fine_grained=False,
              verbose=False,
              text_length=5000,
              encode=True):
    if data_type == 'imdb':
        train_data, test_data = imdb_dataset(preprocessing=preprocessing,
                                             verbose=verbose,
                                             text_length=text_length)
    elif data_type == 'newsgroups':
        train_data, test_data = newsgroups_dataset(preprocessing=preprocessing,
                                                   verbose=verbose,
                                                   text_length=text_length)
    elif data_type == 'reuters':
        train_data, test_data = reuters_dataset(preprocessing=preprocessing,
                                                fine_grained=fine_grained,
                                                verbose=verbose,
                                                text_length=text_length)
    elif data_type == 'webkb':
        train_data, test_data = webkb_dataset(preprocessing=preprocessing,
                                              verbose=verbose,
                                              text_length=text_length)
    elif data_type == 'cade':
        train_data, test_data = cade_dataset(preprocessing=preprocessing,
                                             verbose=verbose,
                                             text_length=text_length)
    elif data_type == 'dbpedia':
        train_data, test_data = dbpedia_dataset(preprocessing=preprocessing,
                                                verbose=verbose,
                                                text_length=text_length)
    elif data_type == 'agnews':
        train_data, test_data = agnews_dataset(preprocessing=preprocessing,
                                               verbose=verbose,
                                               text_length=text_length)
    elif data_type == 'yahoo':
        train_data, test_data = yahoo_dataset(preprocessing=preprocessing,
                                              verbose=verbose,
                                              text_length=text_length)
    elif data_type == 'sogou':
        train_data, test_data = sogou_dataset(preprocessing=preprocessing,
                                              verbose=verbose,
                                              text_length=text_length)
    elif data_type == 'yelp':
        train_data, test_data = yelp_dataset(preprocessing=preprocessing,
                                             fine_grained=fine_grained,
                                             verbose=verbose,
                                             text_length=text_length)
    elif data_type == 'amazon':
        train_data, test_data = amazon_dataset(preprocessing=preprocessing,
                                               fine_grained=fine_grained,
                                               verbose=verbose,
                                               text_length=text_length)
    else:
        raise ValueError('{} data type not supported.'.format(data_type))

    if encode:
        sentence_corpus = [row['text'] for row in datasets_iterator(train_data)]
        sentence_encoder = WhitespaceEncoder(
            sentence_corpus,
            reserved_tokens=[DEFAULT_PADDING_TOKEN, DEFAULT_UNKNOWN_TOKEN])
        label_corpus = [row['label'] for row in datasets_iterator(train_data)]
        label_encoder = LabelEncoder(label_corpus, reserved_labels=[])

        # Encode
        for row in datasets_iterator(train_data, test_data):
            row['text'] = sentence_encoder.encode(row['text'])
            row['label'] = label_encoder.encode(row['label'])
        return sentence_encoder, label_encoder, train_data, test_data
    else:
        return train_data, test_data
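
A short usage sketch of the loader above; the dataset name is a placeholder and the per-dataset loader functions (imdb_dataset, etc.) are defined elsewhere in the project.

# Hypothetical call: encode the IMDB corpus and inspect vocabulary sizes.
sentence_encoder, label_encoder, train_data, test_data = load_data(
    'imdb', preprocessing=True, encode=True)

print('vocab size:', sentence_encoder.vocab_size)
print('num labels:', label_encoder.vocab_size)
print('train/test examples:', len(train_data), len(test_data))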