Code example #1
File: datasets.py Project: erikhou45/Lipspeak
def get_splits_datasetv(cmu_dict_path, data_struct_path, splitname):
    i_field = data.Field(lambda x: x)
    g_field = data.Field(
        init_token='<s>', tokenize=(
            lambda x: list(x.split('(')[0])))  #sequence reversing removed
    p_field = data.Field(init_token='<os>',
                         eos_token='</os>',
                         tokenize=(lambda x: x.split('#')[0].split()))

    if data_struct_path == "../data/lrw/DsplitsLRW.json":
        Wstruct = CMUDict.splits_dataset_lrw(cmu_dict_path, i_field, g_field,
                                             p_field)
    elif splitname == "test":
        Wstruct = CMUDict.splits_datasetv_test644(cmu_dict_path, i_field,
                                                  g_field, p_field)
        return Wstruct
    ##### Additional Code to perform multi-word queries by Erik #####
    elif splitname == "phrase":
        Wstruct = CMUDict.splits_dataset_phrases(cmu_dict_path, i_field,
                                                 g_field, p_field)
        return Wstruct
    ##### End of Additional Code to perform multi-word queries by Erik #####
    else:
        WsplitsLst = CMUDict.splits_datasetv(cmu_dict_path, i_field, g_field,
                                             p_field)
        S = ['train', 'val']
        Wsplits = {}
        for i, s in enumerate(S):
            Wsplits[s] = WsplitsLst[i]
        Wstruct = Wsplits[splitname]
    return Wstruct
Code example #2
    def get_data(self, batch_size, convert_digits=True):

        # Setup fields with batch dimension first
        inputs = data.Field(
            init_token="<bos>",
            eos_token="<eos>",
            batch_first=True,
            tokenize='spacy',
            preprocessing=data.Pipeline(
                lambda w: '0' if convert_digits and w.isdigit() else w))
        tags = data.Field(init_token="<bos>",
                          eos_token="<eos>",
                          batch_first=True)
        fields = (('text', inputs), ('label', tags))
        # Download and load the default data.
        # train, val, test = datasets.sequence_tagging.SequenceTaggingDataset.splits(root=self.path, fields=fields)
        train, val, test = Ingredients.splits(name=self.ner_type,
                                              fields=tuple(fields),
                                              root=self.path)

        print('---------- NYT INGREDIENTS NER ---------')
        print('Train size: %d' % (len(train)))
        print('Validation size: %d' % (len(val)))
        print('Test size: %d' % (len(test)))

        # Build vocab
        inputs.build_vocab(train, val, max_size=50000)
        # , vectors=[GloVe(name='6B', dim='200'), CharNGram()])
        tags.build_vocab(train.label)

        logger.info('Input vocab size:%d' % (len(inputs.vocab)))
        logger.info('Tagset size: %d' % (len(tags.vocab)))

        self.inputs = inputs
        self.tags = tags

        self.dictionary.word2idx = dict(self.inputs.vocab.stoi)
        self.dictionary.idx2word = list(self.dictionary.word2idx.keys())
        self.tag_vocab = self.tags.vocab.stoi

        # Get iterators
        self.train, self.valid, self.test = data.BucketIterator.splits(
            (train, val, test),
            batch_size=batch_size,
            device=torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu"))

        self.train.repeat = False
        self.valid.repeat = False
        self.test.repeat = False
        """
        return {
            'task': 'nyt_ingredients.ner',
            'iters': (self.train, self.val, self.test),
            'vocabs': (inputs.vocab, tags.vocab)
        }
        """
        """
Code example #3
def get_field(tokenizer):

    # step 1 get special tokens indices
    init_token_idx = tokenizer.cls_token_id
    eos_token_idx = tokenizer.sep_token_id
    pad_token_idx = tokenizer.pad_token_id
    unk_token_idx = tokenizer.unk_token_id
    print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

    # step 2 get max_length
    max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']
    print('max input size {}'.format(max_input_length))

    # step 3 define tokenize
    def tokenize_with_cut(sentence):
        tokens = tokenizer.tokenize(sentence)
        tokens = tokens[:max_input_length - 2]

        return tokens

    # step 4 define field
    TEXT = data.Field(batch_first=True,
                      use_vocab=False,
                      tokenize=tokenize_with_cut,
                      preprocessing=tokenizer.convert_tokens_to_ids,
                      init_token=init_token_idx,
                      eos_token=eos_token_idx,
                      pad_token=pad_token_idx,
                      unk_token=unk_token_idx)

    LABEL = data.LabelField(dtype=torch.float)

    return TEXT, LABEL
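A minimal usage sketch for get_field(), assuming the Hugging Face transformers package and a hypothetical reviews.csv with "text" and "label" columns; on recent torchtext versions the data module lives under torchtext.legacy.

import torch
from torchtext.legacy import data  # plain `torchtext.data` on older versions
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
TEXT, LABEL = get_field(tokenizer)

# reviews.csv with "text" and "label" columns is a hypothetical input file
dataset = data.TabularDataset(path='reviews.csv', format='csv', skip_header=True,
                              fields=[('text', TEXT), ('label', LABEL)])
train_ds, valid_ds = dataset.split(0.8)

# TEXT stores token ids directly (use_vocab=False), so only LABEL needs a vocab
LABEL.build_vocab(train_ds)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iter, valid_iter = data.BucketIterator.splits(
    (train_ds, valid_ds), batch_size=32,
    sort_key=lambda ex: len(ex.text), device=device)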
Code example #4
 def __init__(self):
     # dataset + loaders
     import torchtext.legacy.data as data
     import torchtext.legacy.datasets as datasets
     max_len = 200
     text = data.Field(sequential=True,
                       fix_length=max_len,
                       batch_first=True,
                       lower=True,
                       dtype=torch.long)
     label = data.LabelField(sequential=False, dtype=torch.long)
     ds_train, ds_test = datasets.IMDB.splits(
         text, label, path='/home/seymour/data/IMDB/aclImdb')
     print('train.fields :', ds_train.fields)
     ds_train, ds_valid = ds_train.split(0.9)
     print('train : ', len(ds_train))
     print('valid : ', len(ds_valid))
     print('test : ', len(ds_test))
     num_words = 50000
     text.build_vocab(ds_train, max_size=num_words)
     label.build_vocab(ds_train)
     vocab = text.vocab
     batch_size = 164
     self.train_loader, self.valid_loader, self.test_loader = data.BucketIterator.splits(
         (ds_train, ds_valid, ds_test),
         batch_size=batch_size,
         sort_key=lambda x: len(x.text),
         repeat=False)
     print('train_loader : ', len(self.train_loader))
     print('valid_loader : ', len(self.valid_loader))
     print('test_loader : ', len(self.test_loader))
     # model etc
     self.model = TransformerClassifier().to(device=DEVICE)
     self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001)
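The class above only builds the loaders, model and optimizer; a minimal sketch of a training epoch over them follows. The cross-entropy loss and the loop structure are assumptions (TransformerClassifier is assumed to return unnormalized class scores); DEVICE is the constant already used in the class.

import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0.0
    for batch in loader:
        x = batch.text.to(DEVICE)   # [batch_size, max_len]
        y = batch.label.to(DEVICE)  # [batch_size]
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# usage with the attributes set in __init__:
# train_one_epoch(self.model, self.train_loader, self.optimizer)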
Code example #5
def main_stupid3():
    """We use legacy stuff, like in the notebook"""
    # dataset + loaders
    import torchtext.legacy.data as data
    import torchtext.legacy.datasets as datasets
    max_len = 200
    text = data.Field(sequential=True,
                      fix_length=max_len,
                      batch_first=True,
                      lower=True,
                      dtype=torch.long)
    label = data.LabelField(sequential=False, dtype=torch.long)
    # datasets.IMDB.download('/home/seymour/data')
    ds_train, ds_test = datasets.IMDB.splits(
        text, label, path='/home/seymour/data/IMDB/aclImdb')
    # print('ds_train', len(ds_train))
    # print('ds_test', len(ds_test))
    print('train.fields :', ds_train.fields)
    ds_train, ds_valid = ds_train.split(0.9)
    print('train : ', len(ds_train))
    print('valid : ', len(ds_valid))
    print('test : ', len(ds_test))
    num_words = 50000
    text.build_vocab(ds_train, max_size=num_words)
    label.build_vocab(ds_train)
    vocab = text.vocab
    batch_size = 164
    train_loader, valid_loader, test_loader = data.BucketIterator.splits(
        (ds_train, ds_valid, ds_test),
        batch_size=batch_size,
        sort_key=lambda x: len(x.text),
        repeat=False)
    # model etc
    # model = MultiHeadAttention(d_model=32, num_heads=2)
    # model = Embeddings(d_model=32, vocab_size=50002, max_position_embeddings=10000, p=0.1)
    # model = Encoder(num_layers=1, d_model=32, num_heads=2, ff_hidden_dim=128, input_vocab_size=50002,
    #                 max_position_embeddings=10000)
    model = TransformerClassifier()
    model.to(device=DEVICE)

    if False:
        batch = next(iter(train_loader))
        print('batch:', type(batch), len(batch))
        x = batch.text
        y = batch.label
        print_it(x, 'x')  # [164, 200]
        print_it(y, 'y')  # [164]
        print(y)
        print(x[0])

    if True:
        batch = next(iter(train_loader))
        x = batch.text.to(DEVICE)
        y = batch.label.to(DEVICE)
        print_it(x, 'x')  # [164, 200]
        print_it(y, 'y')  # [164]
        out = model(x)
        print_it(out, 'out')
Code example #6
def load_data(batch_size, device):
    # text field
    text = data.Field(sequential=True,
                      tokenize=tokenizer,
                      lower=True,
                      batch_first=True)
    # label field
    label = data.Field(sequential=True, tokenize=tokenizer, batch_first=True)
    # build the datasets
    train, valid, test = data.TabularDataset.splits(
        path='./data/',
        skip_header=True,
        train='train.csv',
        validation='valid.csv',
        test='test.csv',
        format='csv',
        fields=[('TEXT', text), ('LABEL', label)],
    )
    # build the vocabularies
    text.build_vocab(train, vectors=Vectors(name='./data/glove.6B.300d.txt'))

    label.build_vocab(train)

    # build the iterators
    train_iter, valid_iter = data.BucketIterator.splits(
        datasets=(train, valid),
        sort_key=lambda x: len(x.TEXT),
        sort_within_batch=False,
        shuffle=True,
        batch_size=batch_size,
        repeat=False,
        device=device)

    test_iter = data.Iterator(test,
                              sort=False,
                              shuffle=False,
                              sort_within_batch=False,
                              batch_size=batch_size,
                              repeat=False,
                              device=device)

    return train_iter, valid_iter, test_iter, text.vocab
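The tokenizer passed to both fields above is defined elsewhere in the project. A minimal whitespace-splitting stand-in, purely as an assumption (a language-specific segmenter such as jieba would be a drop-in replacement):

def tokenizer(text):
    # stand-in for the project's tokenizer; any callable that maps a string
    # to a list of tokens can be used here
    return text.strip().split()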
Code example #7
    def get_data(self):
        # Define the fields associated with the sequences.
        self.inputs = data.Field(init_token="<bos>",
                                 eos_token="<eos>",
                                 tokenize='spacy')
        self.UD_TAG = data.Field(init_token="<bos>", eos_token="<eos>")
        self.PTB_TAG = data.Field(init_token="<bos>", eos_token="<eos>")

        # Download and load the default data.
        print(self.path)
        # was udtag
        train_data, val_data, test_data = datasets.UDPOS.splits(
            root=self.path,
            fields=(('text', self.inputs), ('label', self.UD_TAG),
                    ('ptbtag', self.PTB_TAG)))

        self.inputs.build_vocab(train_data, min_freq=3)
        self.dictionary.word2idx = dict(self.inputs.vocab.stoi)
        self.dictionary.idx2word = list(self.dictionary.word2idx.keys())

        self.UD_TAG.build_vocab(train_data.label)
        self.PTB_TAG.build_vocab(train_data.ptbtag)

        if self.ud:
            self.tag_vocab = self.UD_TAG.vocab.stoi
            self.ptb_vocab = self.PTB_TAG.vocab.stoi
        else:
            self.ud_vocab = self.UD_TAG.vocab.stoi
            self.tag_vocab = self.PTB_TAG.vocab.stoi

        self.train, self.valid, self.test = data.BucketIterator.splits(
            (train_data, val_data, test_data),
            batch_size=self.batch_size,
            device=self.devicer)

        self.train.repeat = False
        self.valid.repeat = False
        self.test.repeat = False
Code example #8
def main(args):
    if args.device:
        device = args.device
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    text_field = data.Field(tokenize=list)
    datasets = WikiText2.splits(text_field)
    text_field.build_vocab(datasets[0])

    train_iter, val_iter, test_iter = data.BPTTIterator.splits(datasets,
                                                               batch_size=32,
                                                               bptt_len=512,
                                                               device=device)

    vocab = text_field.vocab

    print(f'Vocab size: {len(vocab)}')

    model_args = dict(rnn_type='lstm',
                      ntoken=args.num_latents,
                      ninp=256,
                      nhid=1024,
                      nlayers=2)
    if args.model_args:
        model_args.update(dict(eval(args.model_args)))

    model = SHARNN(**model_args).to(device)
    model.train()

    criterion = nn.NLLLoss()

    #optim = torch.optim.SGD(model.parameters(), lr=5.0)
    optim = torch.optim.Adam(model.parameters(), lr=2e-3)

    for epoch in range(10):
        hidden = None
        mems = None

        total_loss = 0

        for step, batch in enumerate(train_iter):
            optim.zero_grad()

            if hidden is not None:
                hidden = repackage_hidden(hidden)
            if mems is not None:
                mems = repackage_hidden(mems)

            output, hidden, mems, attn_outs, _ = model(batch.text,
                                                       hidden,
                                                       return_h=True,
                                                       mems=mems)

            logits = model.decoder(output)
            logits = F.log_softmax(logits, dim=-1)

            assert logits.size(1) == batch.target.size(1)

            loss = criterion(logits.view(-1, logits.size(-1)),
                             batch.target.view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

            optim.step()

            total_loss += loss.data

            if step % args.log_interval == 0 and step > 0:
                cur_loss = total_loss / args.log_interval
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | '
                      'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                          epoch, step, len(train_iter),
                          optim.param_groups[0]['lr'], cur_loss,
                          math.exp(cur_loss), cur_loss / math.log(2)))
                total_loss = 0
Code example #9

BATCH_SIZE = 128
HIDDEN_SIZE = 100
device = torch.device("cuda")

tag_to_ix = {'start_tag': 0, 'stop_tag': 29, 'pad_tag': 30}
tag_size = 31
working_path = '/home/jongsu/jupyter/pytorch_dialogue_ie/'
WV_PATH = '/home/jongsu/jupyter/pytorch_dialogue_ie/parameter/dialogue_wv'

wv_model = word2vec.Word2Vec(size=100, window=5, min_count=5, workers=4)
wv_model = word2vec.Word2Vec.load(WV_PATH)

my_fields = {
    'dial': ('Text', data.Field(sequential=True)),
    'emo': ('labels_1', data.Field(sequential=False)),
    'act': ('labels_2', data.Field(sequential=False))
}
print("make data")
train_data = MyTabularDataset.splits(path=working_path,
                                     train='data_jsonfile/full_data_test.json',
                                     fields=my_fields)
train_data = sorted(train_data, key=lambda x: sentence_maxlen_per_dialogue(x))
train_data = train_data  # exclude dialogues with extremely long sentences (0~11117 => 0~9999)
train = sorted(train_data, key=lambda x: -len(x.Text))  # reorder by number of sentences, descending
# dialogues with more sentences come first because torch pad_sequence is used later
print(train[0].labels_1)
print(train[0].labels_2)
Code example #10
def main():
    global USE_CUDA, DEVICE
    global args, best_met, valid_minib_counter
    global LOWER, PAD_INDEX, NAMES, TRG_NAMES
    global BOS_WORD, EOS_WORD, BLANK_WORD

    # global UNK_TOKEN,PAD_TOKEN,SOS_TOKEN,EOS_TOKEN,TRG_NAMES,LOWER,PAD_INDEX,NAMES,MIN_FREQ

    label_writers = []

    LOWER = False
    BOS_WORD = '['
    EOS_WORD = ']'
    BLANK_WORD = "!"
    MAX_LEN = 100
    MIN_FREQ = 1

    ID = data.Field(sequential=False, use_vocab=False)

    NAMES = data.Field(tokenize=tokenize,
                       batch_first=True,
                       lower=LOWER,
                       include_lengths=False,
                       pad_token=BLANK_WORD,
                       init_token=None,
                       eos_token=EOS_WORD)

    TRG_NAMES = data.Field(tokenize=tokenize,
                           batch_first=True,
                           lower=LOWER,
                           include_lengths=False,
                           pad_token=BLANK_WORD,
                           init_token=BOS_WORD,
                           eos_token=EOS_WORD)

    LBL = data.Field(sequential=False, use_vocab=False)

    CNT = data.Field(sequential=False, use_vocab=False)

    datafields = [("id", ID), ("src", NAMES), ("trg", TRG_NAMES), ("clf", LBL),
                  ("cn", CNT)]

    trainval_data = data.TabularDataset(path=args.train_df_path,
                                        format='csv',
                                        skip_header=True,
                                        fields=datafields)

    if not args.predict and not args.evaluate:
        train_data = data.TabularDataset(path=args.trn_df_path,
                                         format='csv',
                                         skip_header=True,
                                         fields=datafields)

        val_data = data.TabularDataset(path=args.val_df_path,
                                       format='csv',
                                       skip_header=True,
                                       fields=datafields)

        print('Train length {}, val length {}'.format(len(train_data),
                                                      len(val_data)))

    if args.predict:
        test_data = data.TabularDataset(path=args.test_df_path,
                                        format='csv',
                                        skip_header=True,
                                        fields=datafields)

    MIN_FREQ = args.min_freq  # NOTE: we limit the vocabulary to frequent words for speed
    NAMES.build_vocab(trainval_data.src, min_freq=MIN_FREQ)
    TRG_NAMES.build_vocab(trainval_data.trg, min_freq=MIN_FREQ)
    PAD_INDEX = TRG_NAMES.vocab.stoi[BLANK_WORD]

    del trainval_data
    gc.collect()

    if not args.predict and not args.evaluate:
        train_iter = data.BucketIterator(train_data,
                                         batch_size=args.batch_size,
                                         train=True,
                                         sort_within_batch=True,
                                         sort_key=lambda x:
                                         (len(x.src), len(x.trg)),
                                         repeat=False,
                                         device=DEVICE,
                                         shuffle=True)

        valid_iter_batch = data.Iterator(val_data,
                                         batch_size=args.batch_size,
                                         train=False,
                                         sort_within_batch=True,
                                         sort_key=lambda x:
                                         (len(x.src), len(x.trg)),
                                         repeat=False,
                                         device=DEVICE,
                                         shuffle=False)

        val_ids = []
        for b in valid_iter_batch:
            val_ids.extend(list(b.id.cpu().numpy()))

        print('Preparing data for validation')

        train_df = pd.read_csv('../data/proc_train.csv')
        train_df = train_df.set_index('id')

        # val_gts = train_df.loc[val_ids,'fullname_true'].values
        # val_ors = train_df.loc[val_ids,'fullname'].values
        # incorrect_idx = list(train_df[train_df.target==1].index.values)
        # incorrect_val_ids = list(set(val_ids).intersection(set(incorrect_idx)))
        # correct_val_ids = list(set(val_ids)-set(incorrect_val_ids))

        print('Making dictionaries')

        id2gt = dict(train_df['fullname_true'])
        id2clf_gt = dict(train_df['target'])
        val_gts = [id2gt[_] for _ in val_ids]
        val_clf_gts = [id2clf_gt[_] for _ in val_ids]
        del train_df
        gc.collect()

    if args.evaluate:
        val_data = data.TabularDataset(path=args.val_df_path,
                                       format='csv',
                                       skip_header=True,
                                       fields=datafields)

        valid_iter_batch = data.Iterator(val_data,
                                         batch_size=args.batch_size,
                                         train=False,
                                         sort_within_batch=True,
                                         sort_key=lambda x:
                                         (len(x.src), len(x.trg)),
                                         repeat=False,
                                         device=DEVICE,
                                         shuffle=False)

        val_ids = []
        for b in valid_iter_batch:
            val_ids.extend(list(b.id.cpu().numpy()))

        print('Preparing data for validation')

        train_df = pd.read_csv('../data/proc_train.csv')
        train_df = train_df.set_index('id')

        print('Making dictionaries')

        id2gt = dict(train_df['fullname_true'])
        id2clf_gt = dict(train_df['target'])
        val_gts = [id2gt[_] for _ in val_ids]
        val_clf_gts = [id2clf_gt[_] for _ in val_ids]
        del train_df
        gc.collect()

    if args.predict:
        test_iter_batch = data.Iterator(test_data,
                                        batch_size=args.batch_size,
                                        train=False,
                                        sort_within_batch=True,
                                        sort_key=lambda x:
                                        (len(x.src), len(x.trg)),
                                        repeat=False,
                                        device=DEVICE,
                                        shuffle=False)

        test_ids = []
        for b in test_iter_batch:
            test_ids.extend(list(b.id.cpu().numpy()))

    model = make_model(
        len(NAMES.vocab),
        len(TRG_NAMES.vocab),
        N=args.num_layers,
        d_model=args.hidden_size,
        d_ff=args.ff_size,
        h=args.att_heads,
        dropout=args.dropout,
        num_classes=args.num_classes,
    )
    model.to(DEVICE)

    loaded_from_checkpoint = False

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_met = checkpoint['best_met']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

            loaded_from_checkpoint = True
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        args.start_epoch = 0

    if args.tensorboard:
        writer = SummaryWriter('runs_encdec/{}'.format(tb_name))

    if args.evaluate:
        print('Running prediction on val set')
        if not os.path.exists('eval/'):
            os.makedirs('eval/')

        preds, clf_preds = predict(
            (rebatch(PAD_INDEX, x) for x in valid_iter_batch),
            model,
            max_len=70,
            src_vocab=NAMES.vocab,
            trg_vocab=TRG_NAMES.vocab,
            num_batches=len(valid_iter_batch),
            return_logits=True)

        predict_df = pd.DataFrame({
            'id': val_ids,
            'target': clf_preds,
            'fullname_true': preds
        })

        predict_df.set_index('id').to_csv('eval/{}.csv'.format(args.tb_name))
    if args.predict:
        print('Running prediction on test set')
        if not os.path.exists('predictions/'):
            os.makedirs('predictions/')

        preds, clf_preds = predict(
            (rebatch(PAD_INDEX, x) for x in test_iter_batch),
            model,
            max_len=70,
            src_vocab=NAMES.vocab,
            trg_vocab=TRG_NAMES.vocab,
            num_batches=len(test_iter_batch),
            return_logits=True)

        predict_df = pd.DataFrame({
            'id': test_ids,
            'target': clf_preds,
            'fullname_true': preds
        })

        predict_df.set_index('id').to_csv('predictions/{}.csv'.format(
            args.tb_name))
    if not args.predict and not args.evaluate:
        print('Training starts...')
        dev_perplexity, dev_clf_loss, preds, clf_preds = train(
            model,
            lr=1e-4,
            num_epochs=args.epochs,
            print_every=args.print_freq,
            train_iter=train_iter,
            valid_iter_batch=valid_iter_batch,
            val_ids=val_ids,
            val_clf_gts=val_clf_gts,
            val_gts=val_gts,
            writer=writer)
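The tokenize callable used by NAMES and TRG_NAMES is not shown in this listing. Since the special tokens here are single characters, a character-level tokenizer is a plausible stand-in (an assumption, not necessarily the project's actual implementation):

def tokenize(text):
    # hypothetical stand-in: split names into individual characters
    return list(text)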
Code example #11
File: cnn.py Project: realjanpaulus/wordembeddings
def main():

    # ================
    # time managment #
    # ================

    program_st = time.time()

    # =====================
    # cnn logging handler #
    # =====================

    logging_filename = f"../logs/cnn.log"
    logging.basicConfig(level=logging.INFO,
                        filename=logging_filename,
                        filemode="w")
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(levelname)s: %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    punctuation = list(r"!#$%&'()*+,-./:;<=>?@[\]^_`{|}~") + ["`", "``"]

    # =================
    # hyperparameters #
    # =================

    BATCH_SIZE = args.batch_size
    DATA_PATH = args.datapath
    DROPOUT = 0.5
    EPOCHS = args.epochs

    FILTER_SIZES = [3, 4, 5]
    LEARNING_RATE = args.learning_rate
    MAX_FEATURES = args.max_features
    N_FILTERS = 100

    # ============
    # embeddings #
    # ============

    EMBEDDING_TYPE = args.embedding_type

    if EMBEDDING_TYPE == "fasttext-en":
        EMBEDDING_NAME = "fasttext.en.300d"
        EMBEDDING_DIM = 300
    elif EMBEDDING_TYPE == "fasttext-simple":
        EMBEDDING_NAME = "fasttext.simple.300d"
        EMBEDDING_DIM = 300
    elif EMBEDDING_TYPE == "glove-840":
        EMBEDDING_NAME = "glove.840B.300d"
        EMBEDDING_DIM = 300
    elif EMBEDDING_TYPE == "glove-6":
        EMBEDDING_NAME = "glove.6B.300d"
        EMBEDDING_DIM = 300
    elif EMBEDDING_TYPE == "glove-twitter":
        EMBEDDING_NAME = "glove.twitter.27B.200d"
        EMBEDDING_DIM = 200
    else:
        EMBEDDING_NAME = "unknown"
        EMBEDDING_DIM = 300

    # ===============
    # preprocessing #
    # ===============

    TEXT = data.Field(tokenize="toktok", lower=True)

    LABEL = data.LabelField(dtype=torch.long)
    assigned_fields = {"review": ("text", TEXT), "rating": ("label", LABEL)}

    train_data, val_data, test_data = data.TabularDataset.splits(
        path=DATA_PATH,
        train=f"train{args.splitnumber}.json",
        validation=f"val{args.splitnumber}.json",
        test=f"test{args.splitnumber}.json",
        format="json",
        fields=assigned_fields,
        skip_header=True,
    )

    TEXT.build_vocab(
        train_data,
        vectors=EMBEDDING_NAME,
        unk_init=torch.Tensor.normal_,
        max_size=MAX_FEATURES,
    )
    LABEL.build_vocab(train_data)

    INPUT_DIM = len(TEXT.vocab)
    OUTPUT_DIM = len(LABEL.vocab)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info(
            f"There are {torch.cuda.device_count()} GPU(s) available.")
        logging.info(f"Device name: {torch.cuda.get_device_name(0)}")
    else:
        logging.info("No GPU available, using the CPU instead.")
        device = torch.device("cpu")

    train_iterator, val_iterator, test_iterator = data.BucketIterator.splits(
        (train_data, val_data, test_data),
        batch_size=BATCH_SIZE,
        device=device,
        sort_key=lambda x: len(x.text),
        sort=False,
        sort_within_batch=False,
    )

    # ===========
    # CNN Model #
    # ===========

    print("\n")
    logging.info("#####################################")
    logging.info(f"Input dimension (= vocab size): {INPUT_DIM}")
    logging.info(f"Output dimension (= n classes): {OUTPUT_DIM}")
    logging.info(f"Embedding dimension: {EMBEDDING_DIM}")
    logging.info(f"Embedding type: {EMBEDDING_TYPE}")
    logging.info(f"Number of filters: {N_FILTERS}")
    logging.info(f"Filter sizes: {FILTER_SIZES}")
    logging.info(f"Dropout: {DROPOUT}")
    logging.info("#####################################")
    print("\n")

    if args.model == "kimcnn":
        model = models.KimCNN(
            input_dim=INPUT_DIM,
            output_dim=OUTPUT_DIM,
            embedding_dim=EMBEDDING_DIM,
            embedding_type=EMBEDDING_TYPE,
            n_filters=N_FILTERS,
            filter_sizes=FILTER_SIZES,
            dropout=DROPOUT,
        )

        OPTIMIZER = optim.Adadelta(model.parameters(), lr=LEARNING_RATE)
        CRITERION = nn.CrossEntropyLoss()
    else:
        logging.info(
            f"Model '{args.model}' does not exist. Script will be stopped.")
        exit()

    # for pt model
    output_add = f"_bs{BATCH_SIZE}_mf{MAX_FEATURES}_{EMBEDDING_TYPE}"
    output_file = f"savefiles/cnnmodel{output_add}.pt"

    if args.load_savefile:
        model.load_state_dict(torch.load(output_file))

    # load embeddings
    pretrained_embeddings = TEXT.vocab.vectors
    UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]
    PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

    model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
    model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

    # put model and loss criterion to device (cpu or gpu)
    model = model.to(device)
    CRITERION = CRITERION.to(device)

    # ================
    # train function #
    # ================

    def train(model, iterator, optimizer, criterion):

        epoch_loss = 0
        epoch_acc = 0

        model.train()

        for batch in iterator:
            optimizer.zero_grad()
            predictions = model(batch.text)
            loss = criterion(predictions, batch.label)
            acc = categorical_accuracy(predictions, batch.label)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += acc.item()

        return epoch_loss / len(iterator), epoch_acc / len(iterator)

    # =====================
    # evaluation function #
    # =====================

    def evaluate(model, iterator, criterion, return_lists=False):

        epoch_loss = 0
        epoch_acc = 0

        model.eval()

        if return_lists:
            pred_labels, true_labels = [], []

        with torch.no_grad():
            for batch in iterator:
                predictions = model(batch.text)
                loss = criterion(predictions, batch.label)
                acc = categorical_accuracy(predictions, batch.label)

                epoch_loss += loss.item()
                epoch_acc += acc.item()

                if return_lists:
                    predictions = predictions.detach().cpu().numpy()
                    batch_labels = batch.label.to("cpu").numpy()
                    pred_labels.append(predictions)
                    true_labels.append(batch_labels)

        if return_lists:
            return (
                epoch_loss / len(iterator),
                epoch_acc / len(iterator),
                pred_labels,
                true_labels,
            )
        else:
            return epoch_loss / len(iterator), epoch_acc / len(iterator)

    # =================
    # actual training #
    # =================

    best_val_loss = float("inf")

    train_losses = []
    val_losses = []
    val_losses_epochs = {}
    total_train_time = time.time()

    for epoch in range(EPOCHS):

        start_time = time.time()

        train_loss, train_acc = train(model, train_iterator, OPTIMIZER,
                                      CRITERION)
        val_loss, val_acc = evaluate(model, val_iterator, CRITERION)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_losses_epochs[f"epoch{epoch+1}"] = val_loss

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), output_file)

        logging.info(
            f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s")
        logging.info(
            f"\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%"
        )
        logging.info(
            f"\tVal. Loss: {val_loss:.3f} |  Val. Acc: {val_acc*100:.2f}%")

        if early_stopping(val_losses_epochs, patience=args.patience):
            logging.info(f"Stopping epoch run early (Epoch {epoch}).")
            break

    logging.info("Training took {:} (h:mm:ss) \n".format(
        format_time(time.time() - total_train_time)))
    print("--------------------------------\n")

    plt.plot(train_losses, label="Training loss")
    plt.plot(val_losses, label="Validation loss")
    plt.legend()
    plt.title(f"Losses (until epoch {epoch})")
    plt.savefig(
        f"../results/{args.model}_loss_{args.embedding_type}_{args.splitnumber}_bs{args.batch_size}_mf{args.max_features}_lr{args.learning_rate}.png"
    )

    # ============
    # Test model #
    # ============

    total_test_time = time.time()
    model.load_state_dict(torch.load(output_file))

    if args.save_confusion_matrices:
        test_loss, test_acc, pred_labels, true_labels = evaluate(
            model, test_iterator, CRITERION, return_lists=True)

        flat_predictions = np.concatenate(pred_labels, axis=0)
        flat_predictions = np.argmax(flat_predictions, axis=1).flatten()
        flat_true_labels = np.concatenate(true_labels, axis=0)

        logging.info("Saving confusion matrices.")
        testd_ = load_jsonl_to_df(
            f"{args.datapath}/test{args.splitnumber}.json")
        classes = testd_["rating"].drop_duplicates().tolist()

        wrong_ratings = []
        for idx, (p, t) in enumerate(zip(flat_predictions, flat_true_labels)):
            if p != t:
                wrong_ratings.append(idx)
        testd_ = testd_.drop(wrong_ratings)
        testd_.to_csv("../results/misclassifications.csv")

        cm_df = pd.DataFrame(
            confusion_matrix(flat_true_labels, flat_predictions),
            index=classes,
            columns=classes,
        )
        cm_df.to_csv(
            f"../results/cm_{args.embedding_type}_{args.splitnumber}_bs{args.batch_size}_mf{args.max_features}_lr{args.learning_rate}.csv"
        )

    else:
        test_loss, test_acc = evaluate(model, test_iterator, CRITERION)

    test_output = f"\nTest Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%"
    test_outputfile = f"../results/{args.embedding_type}_{args.splitnumber}_bs{args.batch_size}_mf{args.max_features}_lr{args.learning_rate}.txt"

    with open(test_outputfile, "w") as txtfile:
        txtfile.write(f"Last epoch: {epoch}{test_output}")

    logging.info(test_output)
    logging.info("Testing took {:} (h:mm:ss) \n".format(
        format_time(time.time() - total_test_time)))
    print("--------------------------------\n")
    logging.info("Total duration {:} (h:mm:ss) \n".format(
        format_time(time.time() - program_st)))
Code example #12
def load_data(batch_size, device):
    # label field
    LABEL = data.Field(sequential=False, use_vocab=False, batch_first=True)
    # sentence fields
    SEN1 = data.Field(sequential=True, tokenize=tokenizer,  fix_length=50, lower=True, batch_first=True)
    SEN2 = data.Field(sequential=True, tokenize=tokenizer,  fix_length=50, lower=True, batch_first=True)

    # build the datasets
    train, valid = data.TabularDataset.splits(
        path='./snli_1.0/',
        skip_header=True,
        train="train4.csv",
        validation="dev3.csv",
        format='csv',
        fields=[("label", LABEL), ("sentence1", SEN1), ("sentence2", SEN2)],
    )

    test = data.TabularDataset(
        path='./snli_1.0/test3.csv',
        skip_header=True,
        format='csv',
        fields=[("sentence1", SEN1), ("sentence2", SEN2)],
    )

    # build the vocabulary (shared between the two sentence fields)
    SEN1.build_vocab(train.sentence1, train.sentence2, vectors=Vectors(name='/data/yinli/dataset/glove.840B.300d.txt'))
    SEN2.vocab = SEN1.vocab

    # build the iterators
    train_iter = data.BucketIterator(train,
                                sort_key=lambda x: len(x.sentence1),
                                sort_within_batch=False,
                                shuffle=True,
                                batch_size=batch_size,
                                repeat=False,
                                device=device)

    valid_iter = data.Iterator(valid,
                              sort=False,
                              shuffle=False,
                              sort_within_batch=False,
                              batch_size=batch_size,
                              repeat=False,
                              train=False,
                              device=device)

    test_iter = data.Iterator(test,
                               sort=False,
                               shuffle=False,
                               sort_within_batch=False,
                               batch_size=batch_size,
                               repeat=False,
                               train=False,
                               device=device)

    return train_iter, valid_iter, test_iter, SEN1.vocab, SEN2.vocab

# Load the dataset and build the iterator
# def load_data(batch_size, device):
#     # label field
#     LABEL = data.Field(sequential=True, batch_first=True)
#     # sentence fields
#     SEN1 = data.Field(sequential=True, tokenize=tokenizer, lower=True, batch_first=True)
#     SEN2 = data.Field(sequential=True, tokenize=tokenizer, lower=True, batch_first=True)
#
#     # build the dataset
#     train = data.TabularDataset(
#         path='./snli_1.0/train2.csv',
#         skip_header=True,
#         format='csv',
#         fields=[("label", LABEL), ("sentence1", SEN1), ("sentence2", SEN2)],
#     )
#
#     # build the vocabularies
#     SEN1.build_vocab(train, vectors=Vectors(name='/data/yinli/dataset/glove.840B.300d.txt'))
#     SEN2.build_vocab(train, vectors=Vectors(name='/data/yinli/dataset/glove.840B.300d.txt'))
#     LABEL.build_vocab(train)
#
#     # build the iterator
#     train_iter = data.BucketIterator(train,
#                                 sort_key=lambda x: len(x.SEN1),
#                                 sort_within_batch=False,
#                                 shuffle=True,
#                                 batch_size=batch_size,
#                                 repeat=False,
#                                 device=device)
#
#     return train_iter, SEN1.vocab, SEN2.vocab


# device = torch.device("cuda:1")
# train_iter, dev_iter, test_iter, sentence1_vocab, sentence2_vocab = load_data(5, 50, device)
#
# for batch in train_iter:
#     print(batch.label)
#     print(batch.sentence1)
#     print(batch.sentence2)
#     break
# print(len(sentence1_vocab.vectors))
#
# print(sentence1_vocab.stoi['frown'])
# print(sentence2_vocab.stoi['frown'])
# print(sentence1_vocab.stoi['<unk>'])
#
# del train_iter
# del dev_iter
# del test_iter
# del sentence1_vocab
# del sentence2_vocab

#
# embedding = torch.cat((sentence2_vocab.vectors ,sentence1_vocab.vectors[2:]), 0)
# print(embedding.size())
# vocab_size, embed_size = embedding.size()
# print(vocab_size)
# print(embed_size)
# print(len(label_vocab))
# print(label_vocab.stoi)
#label2id = {'<unk>': 0, '<pad>': 1, 'neutral': 2, 'contradiction': 3, 'entailment': 4}
Code example #13
def main():
    global USE_CUDA,DEVICE
    global UNK_TOKEN,PAD_TOKEN,SOS_TOKEN,EOS_TOKEN,TRG_NAMES,LOWER,PAD_INDEX,NAMES,MIN_FREQ
    global args,best_met,valid_minib_counter

    label_writers = []

    UNK_TOKEN = "!"
    PAD_TOKEN = "_"    
    SOS_TOKEN = "["
    EOS_TOKEN = "]"
    LOWER = False

    ID = data.Field(sequential=False,
                    use_vocab=False)

    NAMES = data.Field(tokenize=tokenize,
                       batch_first=True,
                       lower=LOWER,
                       include_lengths=True,
                       unk_token=UNK_TOKEN,
                       pad_token=PAD_TOKEN,
                       init_token=None,
                       eos_token=EOS_TOKEN)

    TRG_NAMES = data.Field(tokenize=tokenize, 
                           batch_first=True,
                           lower=LOWER,
                           include_lengths=True,
                           unk_token=UNK_TOKEN,
                           pad_token=PAD_TOKEN,
                           init_token=SOS_TOKEN,
                           eos_token=EOS_TOKEN)

    LBL = data.Field(sequential=False,
                     use_vocab=False)

    CNT = data.Field(sequential=False,
                     use_vocab=False)

    datafields = [("id", ID),
                  ("src", NAMES),
                  ("trg", TRG_NAMES),
                  ("clf", LBL),
                  ("cn", CNT)
                 ]

    train_data = data.TabularDataset(path=args.train_df_path,
                                     format='csv',
                                     skip_header=True,
                                     fields=datafields)

    train_data, valid_data = train_data.split(split_ratio=args.split_ratio,
                                              stratified=args.stratified,
                                              strata_field=args.strata_field)
    
    
    print('Train length {}, val length {}'.format(len(train_data),len(valid_data)))

    MIN_FREQ = args.min_freq  # NOTE: we limit the vocabulary to frequent words for speed
    NAMES.build_vocab(train_data.src, min_freq=MIN_FREQ)
    TRG_NAMES.build_vocab(train_data.trg, min_freq=MIN_FREQ)
    PAD_INDEX = TRG_NAMES.vocab.stoi[PAD_TOKEN]

    train_iter = data.BucketIterator(train_data,
                                     batch_size=args.batch_size,
                                     train=True, 
                                     sort_within_batch=True, 
                                     sort_key=lambda x: (len(x.src), len(x.trg)),
                                     repeat=False,
                                     device=DEVICE,
                                     shuffle=True)

    valid_iter_batch = data.Iterator(valid_data,
                               batch_size=args.batch_size,
                               train=False,
                               sort_within_batch=True,
                               sort_key=lambda x: (len(x.src), len(x.trg)),
                               repeat=False, 
                               device=DEVICE,
                               shuffle=False)

    val_ids = []
    for b in valid_iter_batch:
        val_ids.extend(list(b.id.cpu().numpy()))    
    
    print('Preparing data for validation')

    train_df = pd.read_csv('../data/proc_train.csv')
    train_df = train_df.set_index('id')
    val_gts = train_df.loc[val_ids,'fullname_true'].values
    val_ors = train_df.loc[val_ids,'fullname'].values
    incorrect_idx = list(train_df[train_df.target==1].index.values)

    incorrect_val_ids = list(set(val_ids).intersection(set(incorrect_idx)))
    correct_val_ids = list(set(val_ids)-set(incorrect_val_ids))
    
    print('Making dictionaries')
    
    id2gt = dict(train_df['fullname_true'])
    id2clf_gt = dict(train_df['target'])
    val_gts = [id2gt[_] for _ in val_ids]
    val_clf_gts = [id2clf_gt[_] for _ in val_ids]    
    del train_df
    gc.collect()
    
    model = make_model(len(NAMES.vocab),
                       len(TRG_NAMES.vocab),
                       device=DEVICE,
                       emb_size=args.emb_size,
                       hidden_size=args.hidden_size,
                       num_layers=args.num_layers,
                       dropout=args.dropout,
                       num_classes=args.num_classes)
    
    loaded_from_checkpoint = False

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_met = checkpoint['best_met']
            model.load_state_dict(checkpoint['state_dict'])           
            print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            loaded_from_checkpoint = True
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))    
    else:
        args.start_epoch = 0
        
    criterion = nn.CrossEntropyLoss(reduce=False).to(DEVICE)

    if args.tensorboard:
        writer = SummaryWriter('runs_encdec/{}'.format(tb_name))
    
    if args.evaluate:
        pass
    elif args.predict:
        pass
    else:
        print('Training starts...') 
        dev_perplexity,dev_clf_loss,preds,clf_preds = train(model,
                                                            lr=args.lr,
                                                            num_epochs=args.epochs,
                                                            print_every=args.print_freq,
                                                            train_iter=train_iter,
                                                            valid_iter_batch=valid_iter_batch,
                                                            val_ids=val_ids,
                                                            val_clf_gts=val_clf_gts,
                                                            val_gts=val_gts,
                                                            writer=writer)        
Code example #14
# Device: set up kwargs for future
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Define tokenizer and field (field will be deprecated soon)
my_tok = spacy.load('en')


def spacy_tok(x):
    return [tok.text for tok in my_tok.tokenizer(x)]


TEXT = data.Field(tokenize=spacy_tok)

# Define data and iterator
train = torchtext.datasets.LanguageModelingDataset(path='poems.txt',
                                                   text_field=TEXT,
                                                   newline_eos=True)
test = torchtext.datasets.LanguageModelingDataset(path='poems_test.txt',
                                                  text_field=TEXT,
                                                  newline_eos=True)

TEXT.build_vocab(train, vectors="glove.6B.200d")

train_iter, test_iter = data.BPTTIterator.splits((train, test),
                                                 batch_size=32,
                                                 bptt_len=30,
                                                 device=device)
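The snippet stops right after building the iterators. A minimal sketch of consuming the BPTT batches follows; the embedding layer is an assumption added here and is not part of the original.

import torch.nn as nn

vocab = TEXT.vocab
embedding = nn.Embedding(len(vocab), 200)
embedding.weight.data.copy_(vocab.vectors)  # the glove.6B.200d vectors loaded above

for batch in train_iter:
    tokens = batch.text     # [bptt_len, batch_size]
    targets = batch.target  # same shape, shifted by one token
    embedded = embedding(tokens)
    break                   # only demonstrating the batch layout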
Code example #15
File: iwslt.py Project: AlexisDrch/Transformer
    def build(language_pair: LanguagePair, split: Split, max_length=100, min_freq=2,
              start_token="<s>", eos_token="</s>", blank_token="<blank>",
              batch_size_train=32, batch_size_validation=32,
              batch_size_test=32, device='cpu'):
        """
        Initializes an iterator over the IWSLT dataset.
        The iterator then yields batches of size `batch_size`.

        Returns one iterator for each split alongside the input & output vocab sets.

        Example:

        >>> dataset_iterator, _, _, src_vocab, trg_vocab = IWSLTDatasetBuilder.build(
        ...                                                   language_pair=language_pair,
        ...                                                   split=Split.Train,
        ...                                                   max_length=5,
        ...                                                   batch_size_train=batch_size_train)
        >>> batch = next(iter(dataset_iterator))

        :param language_pair: The language pair for which to create a vocabulary.
        :param split: The split type.
        :param max_length: Max length of sequence.
        :param min_freq: The minimum frequency a word should have to be included in the vocabulary
        :param start_token: The token that marks the beginning of a sequence.
        :param eos_token: The token that marks an end of sequence.
        :param blank_token: The token to pad with.
        :param batch_size_train: Desired size of each training batch.
        :param batch_size_validation: Desired size of each validation batch.
        :param batch_size_test: Desired size of each testing batch.
        :param device: The device on which to store the batches.
        :type device: str or torch.device

        :returns: (train_iterator, validation_iterator, test_iterator,
                   source_field.vocab, target_field.vocab)
        """
        # load corresponding tokenizer
        source_tokenizer, target_tokenizer = language_pair.tokenizer()
        # create torchtext data fields to generate the vocabulary
        source_field = data.Field(tokenize=source_tokenizer, pad_token=blank_token)
        target_field = data.Field(tokenize=target_tokenizer, init_token=start_token,
                                  eos_token=eos_token, pad_token=blank_token)

        # Generates train and validation datasets
        settings = dict()
        for key, split_type in [
            # ("validation", Split.Validation),  # Due to a bug in TorchText, cannot set to None
            ("test", Split.Test),
        ]:
            if not (split & split_type):
                settings[key] = None  # Disable this split; otherwise keep the default
        # noinspection PyTypeChecker
        train, validation, *out = datasets.IWSLT.splits(
            root=ROOT_DATASET_DIR,  # To check if the dataset was already downloaded
            exts=language_pair.extensions(),
            fields=(source_field, target_field),
            filter_pred=lambda x: all(len(val) <= max_length for val in (x.src, x.trg)),
            **settings
        )

        # Build vocabulary on training set
        source_field.build_vocab(train, min_freq=min_freq)
        target_field.build_vocab(train, min_freq=min_freq)

        train_iterator, validation_iterator, test_iterator = None, None, None

        def sort_func(x):
            return data.interleave_keys(len(x.src), len(x.trg))

        if split & Split.Train:
            train_iterator = data.BucketIterator(
                dataset=train, batch_size=batch_size_train, repeat=False,
                device=device, sort_key=sort_func)
        if split & Split.Validation:
            validation_iterator = data.BucketIterator(
                dataset=validation, batch_size=batch_size_validation, repeat=False,
                device=device, sort_key=sort_func)
        if split & Split.Test:
            test, *out = out
            test_iterator = data.BucketIterator(
                dataset=test, batch_size=batch_size_test, repeat=False,
                device=device, sort_key=sort_func)

        return (
            train_iterator,
            validation_iterator,
            test_iterator,
            source_field.vocab,
            target_field.vocab,
        )
Code example #16
File: data.py Project: romilly/scholarly
def load_data(tsv_fname: str = 'arxiv_data',
              data_dir: str = '.data',
              batch_size: int = 32,
              split_ratio: float = 0.95,
              random_seed: int = 42,
              vectors: str = 'fasttext') -> tuple:
    ''' 
    Loads the preprocessed data, tokenises it, builds a vocabulary,
    splits into a training- and validation set, numeralises the texts,
    batches the data into batches of similar text lengths and pads 
    every batch.

    INPUT
        tsv_fname: str = 'arxiv_data'
            The name of the tsv file, without file extension
        data_dir: str = '.data'
            The data directory
        batch_size: int = 32,
            The size of each batch
        split_ratio: float = 0.95
            The proportion of the dataset reserved for training
        vectors: {'fasttext', 'glove'} = 'fasttext'
            The type of word vectors to use. Here the FastText vectors are
            trained on the abstracts and the GloVe vectors are pretrained
            on the 6B corpus
        random_seed: int = 42
            A random seed to ensure that the same training/validation split
            is achieved every time. If set to None then no seed is used.

    OUTPUT
        A triple (train_dl, val_dl, vocab), with train_dl and val_dl being
        batch-wrapped iterators over the training and validation samples,
        respectively, and vocab being the vocabulary built on the training
        texts; its `vectors` attribute holds the embedding matrix
    '''
    from torchtext import data, vocab
    from utils import get_cats
    import random

    # Define the two types of fields in the tsv file
    TXT = data.Field()
    CAT = data.Field(sequential=False, use_vocab=False, is_target=True)

    # Set up the columns in the tsv file with their associated fields
    cats = get_cats(data_dir=data_dir)['id']
    fields = [('text', TXT)] + [(cat, CAT) for cat in cats]

    # Load in the dataset and tokenise the texts
    dataset = data.TabularDataset(path=get_path(data_dir) / f'{tsv_fname}.tsv',
                                  format='tsv',
                                  fields=fields,
                                  skip_header=True)

    # Split into a training- and validation set
    if random_seed is None:
        train, val = dataset.split(split_ratio=split_ratio)
    else:
        random.seed(random_seed)
        train, val = dataset.split(split_ratio=split_ratio,
                                   random_state=random.getstate())

    # Get the word vectors
    vector_cache = get_path(data_dir)
    base_url = 'https://filedn.com/lRBwPhPxgV74tO0rDoe8SpH/scholarly_data/'
    vecs = vocab.Vectors(name=vectors,
                         cache=vector_cache,
                         url=base_url + vectors)

    # Build the vocabulary of the training set
    TXT.build_vocab(train, vectors=vecs)

    # Numericalise the texts, batch them into batches of similar text
    # lengths and pad the texts in each batch
    train_iter, val_iter = data.BucketIterator.splits(
        datasets=(train, val),
        batch_size=batch_size,
        sort_key=lambda sample: len(sample.text))

    # Wrap the iterators to ensure that we output tensors
    train_dl = BatchWrapper(train_iter, vectors=vectors, cats=cats)
    val_dl = BatchWrapper(val_iter, vectors=vectors, cats=cats)

    del dataset, train, val, train_iter, val_iter
    return train_dl, val_dl, TXT.vocab
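BatchWrapper, get_path and get_cats come from the surrounding project and are not shown here. A minimal sketch of what such a wrapper might look like, assuming its job is to stack the per-category columns into a single label tensor (the attribute handling is a guess):

import torch

class BatchWrapper:
    '''Hypothetical sketch: yields (text, labels) pairs from a torchtext iterator.'''

    def __init__(self, iterator, vectors: str, cats: list):
        self.iterator = iterator
        self.vectors = vectors
        self.cats = cats

    def __iter__(self):
        for batch in self.iterator:
            text = batch.text                                    # [seq_len, batch_size]
            labels = torch.stack([getattr(batch, cat).float()    # each category column: [batch_size]
                                  for cat in self.cats], dim=1)  # -> [batch_size, n_cats]
            yield text, labels

    def __len__(self):
        return len(self.iterator)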