Example #1
    def prepare_data(self):
        """
        Downloads the AG_NEWS or 20newsgroups dataset and initializes the BERT tokenizer.
        """
        np.random.seed(self.RANDOM_SEED)
        torch.manual_seed(self.RANDOM_SEED)

        if self.dataset == "20newsgroups":
            num_samples = self.args["num_samples"]
            self.news_group_df = get_20newsgroups(num_samples)
        else:
            train_iter, test_iter = AG_NEWS()
            self.train_dataset = to_map_style_dataset(train_iter)
            self.test_dataset = to_map_style_dataset(test_iter)

        self.tokenizer = BertTokenizer.from_pretrained(self.PRE_TRAINED_MODEL_NAME)
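
For context, a minimal sketch of how the tokenizer initialized above is typically used; "bert-base-uncased" is a hypothetical stand-in, since PRE_TRAINED_MODEL_NAME is not shown in this example:

from transformers import BertTokenizer

# "bert-base-uncased" is an assumed checkpoint name, not taken from the example
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer.encode_plus(
    "PyTorch text classification example.",
    max_length=32,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(encoding["input_ids"].shape)  # torch.Size([1, 32])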
Example #2
# torch.optim.lr_scheduler.StepLR is used here to adjust the learning rate
# through epochs.
#

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.datasets import AG_NEWS

# `model` and `collate_batch` are assumed to be defined earlier in the
# original script (see Example #3 for a collate_batch implementation).
# Hyperparameters
EPOCHS = 10  # number of training epochs
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = \
    random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             collate_fn=collate_batch)
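
To make the scheduler's behavior concrete, a minimal standalone sketch: each call to scheduler.step() multiplies the learning rate by gamma (Example #3 below only steps it when validation accuracy stops improving):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=5.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]["lr"])  # 0.5, then 0.05, then 0.005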
Example #3
def main():

    num_args = len(sys.argv)

    # Check that an input filename was specified
    if num_args < 2:
        sys.exit("Please specify an input file")

    filename = str(sys.argv[1])
    p = Path(filename)

    # Check that the input file exists
    if not p.exists():
        sys.exit("File not found")

    # Prepare data processing pipelines
    tokenizer = get_tokenizer('basic_english')
    train_iter = AG_NEWS(split='train')

    vocab = build_vocab_from_iterator(yield_tokens(train_iter, tokenizer),
                                      specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    text_pipeline = lambda x: vocab(tokenizer(x))
    label_pipeline = lambda x: int(x) - 1
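    # For instance, text_pipeline("here is an example") returns a list of
    # token ids (the exact ids depend on the built vocabulary), and
    # label_pipeline("3") returns 2.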

    # Generate data batch and iterator
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def collate_batch(batch):
        label_list, text_list, offsets = [], [], [0]
        for (_label, _text) in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text),
                                          dtype=torch.int64)
            text_list.append(processed_text)
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        text_list = torch.cat(text_list)
        return label_list.to(device), text_list.to(device), offsets.to(device)
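    # Note on the offsets pattern above: a model built around nn.EmbeddingBag
    # takes the flat, concatenated token ids in text_list plus offsets marking
    # where each text starts (texts of lengths 4 and 2 yield offsets [0, 4])
    # and returns one pooled embedding per text, so no padding is needed.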

    # Re-create the iterator: the first one was exhausted while building the
    # vocabulary, so reusing it here would raise an IndexError
    train_iter = AG_NEWS(split='train')
    dataloader = DataLoader(train_iter,
                            batch_size=8,
                            shuffle=False,
                            collate_fn=collate_batch)

    # Build a model instance
    num_class = len(set([label for (label, text) in train_iter]))
    vocab_size = len(vocab)
    emsize = 64
    model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

    # Split the dataset and run the model
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None
    train_iter, test_iter = AG_NEWS()
    train_dataset = to_map_style_dataset(train_iter)
    test_dataset = to_map_style_dataset(test_iter)
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = \
        random_split(train_dataset,
                     [num_train, len(train_dataset) - num_train])

    train_dataloader = DataLoader(split_train_,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=BATCH_SIZE,
                                 shuffle=True,
                                 collate_fn=collate_batch)

    # Run epochs
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader, model, optimizer, criterion, epoch)
        accu_val = evaluate(valid_dataloader, model, criterion)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch,
                                               time.time() - epoch_start_time,
                                               accu_val))
        print('-' * 59)

    print('Checking the results of test dataset.')
    accu_test = evaluate(test_dataloader, model, criterion)
    print('test accuracy {:8.3f}'.format(accu_test))

    # Run article prediction
    ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}

    with p.open() as readfile:
        ex_text_str = readfile.read()

    model = model.to("cpu")

    print("This is a %s news" %
          ag_news_label[predict(ex_text_str, text_pipeline, model)])
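
predict is called above but not defined in this example. A minimal sketch consistent with the call site (one text per "batch", returning a 1-based label id for ag_news_label); this is an assumption, not the original implementation:

def predict(text, text_pipeline, model):
    # Hypothetical sketch; the real predict() is defined elsewhere
    with torch.no_grad():
        processed = torch.tensor(text_pipeline(text), dtype=torch.int64)
        # A single offset of 0 marks one text in the batch
        output = model(processed, torch.tensor([0]))
        return output.argmax(1).item() + 1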