Example #1
def main():
    global USE_CUDA, DEVICE
    global args, best_met, valid_minib_counter
    global LOWER, PAD_INDEX, NAMES, TRG_NAMES
    global BOS_WORD, EOS_WORD, BLANK_WORD

    # global UNK_TOKEN,PAD_TOKEN,SOS_TOKEN,EOS_TOKEN,TRG_NAMES,LOWER,PAD_INDEX,NAMES,MIN_FREQ

    label_writers = []

    LOWER = False
    BOS_WORD = '['
    EOS_WORD = ']'
    BLANK_WORD = "!"
    MAX_LEN = 100
    MIN_FREQ = 1

    ID = data.Field(sequential=False, use_vocab=False)

    NAMES = data.Field(tokenize=tokenize,
                       batch_first=True,
                       lower=LOWER,
                       include_lengths=False,
                       pad_token=BLANK_WORD,
                       init_token=None,
                       eos_token=EOS_WORD)

    TRG_NAMES = data.Field(tokenize=tokenize,
                           batch_first=True,
                           lower=LOWER,
                           include_lengths=False,
                           pad_token=BLANK_WORD,
                           init_token=BOS_WORD,
                           eos_token=EOS_WORD)

    LBL = data.Field(sequential=False, use_vocab=False)

    CNT = data.Field(sequential=False, use_vocab=False)

    datafields = [("id", ID), ("src", NAMES), ("trg", TRG_NAMES), ("clf", LBL),
                  ("cn", CNT)]

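    # The combined train+val file is only used to build the vocabularies below;
    # it is deleted afterwards to free memory.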
    trainval_data = data.TabularDataset(path=args.train_df_path,
                                        format='csv',
                                        skip_header=True,
                                        fields=datafields)

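    # Separate train and validation CSVs are only loaded when actually training.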
    if not args.predict and not args.evaluate:
        train_data = data.TabularDataset(path=args.trn_df_path,
                                         format='csv',
                                         skip_header=True,
                                         fields=datafields)

        val_data = data.TabularDataset(path=args.val_df_path,
                                       format='csv',
                                       skip_header=True,
                                       fields=datafields)

        print('Train length {}, val length {}'.format(len(train_data),
                                                      len(val_data)))

    if args.predict:
        test_data = data.TabularDataset(path=args.test_df_path,
                                        format='csv',
                                        skip_header=True,
                                        fields=datafields)

    MIN_FREQ = args.min_freq  # NOTE: we limit the vocabulary to frequent words for speed
    NAMES.build_vocab(trainval_data.src, min_freq=MIN_FREQ)
    TRG_NAMES.build_vocab(trainval_data.trg, min_freq=MIN_FREQ)
    PAD_INDEX = TRG_NAMES.vocab.stoi[BLANK_WORD]
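    # PAD_INDEX is passed to rebatch() below so padded positions can be identified.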

    del trainval_data
    gc.collect()

    if not args.predict and not args.evaluate:
        train_iter = data.BucketIterator(train_data,
                                         batch_size=args.batch_size,
                                         train=True,
                                         sort_within_batch=True,
                                         sort_key=lambda x:
                                         (len(x.src), len(x.trg)),
                                         repeat=False,
                                         device=DEVICE,
                                         shuffle=True)

        valid_iter_batch = data.Iterator(val_data,
                                         batch_size=args.batch_size,
                                         train=False,
                                         sort_within_batch=True,
                                         sort_key=lambda x:
                                         (len(x.src), len(x.trg)),
                                         repeat=False,
                                         device=DEVICE,
                                         shuffle=False)

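        # Collect validation ids in iteration order so predictions can be matched
        # back to the ground-truth labels.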
        val_ids = []
        for b in valid_iter_batch:
            val_ids.extend(list(b.id.cpu().numpy()))

        print('Preparing data for validation')

        train_df = pd.read_csv('../data/proc_train.csv')
        train_df = train_df.set_index('id')

        # val_gts = train_df.loc[val_ids,'fullname_true'].values
        # val_ors = train_df.loc[val_ids,'fullname'].values
        # incorrect_idx = list(train_df[train_df.target==1].index.values)
        # incorrect_val_ids = list(set(val_ids).intersection(set(incorrect_idx)))
        # correct_val_ids = list(set(val_ids)-set(incorrect_val_ids))

        print('Making dictionaries')

        id2gt = dict(train_df['fullname_true'])
        id2clf_gt = dict(train_df['target'])
        val_gts = [id2gt[_] for _ in val_ids]
        val_clf_gts = [id2clf_gt[_] for _ in val_ids]
        del train_df
        gc.collect()

    if args.evaluate:
        val_data = data.TabularDataset(path=args.val_df_path,
                                       format='csv',
                                       skip_header=True,
                                       fields=datafields)

        valid_iter_batch = data.Iterator(val_data,
                                         batch_size=args.batch_size,
                                         train=False,
                                         sort_within_batch=True,
                                         sort_key=lambda x:
                                         (len(x.src), len(x.trg)),
                                         repeat=False,
                                         device=DEVICE,
                                         shuffle=False)

        val_ids = []
        for b in valid_iter_batch:
            val_ids.extend(list(b.id.cpu().numpy()))

        print('Preparing data for validation')

        train_df = pd.read_csv('../data/proc_train.csv')
        train_df = train_df.set_index('id')

        print('Making dictionaries')

        id2gt = dict(train_df['fullname_true'])
        id2clf_gt = dict(train_df['target'])
        val_gts = [id2gt[_] for _ in val_ids]
        val_clf_gts = [id2clf_gt[_] for _ in val_ids]
        del train_df
        gc.collect()

    if args.predict:
        test_iter_batch = data.Iterator(test_data,
                                        batch_size=args.batch_size,
                                        train=False,
                                        sort_within_batch=True,
                                        sort_key=lambda x:
                                        (len(x.src), len(x.trg)),
                                        repeat=False,
                                        device=DEVICE,
                                        shuffle=False)

        test_ids = []
        for b in test_iter_batch:
            test_ids.extend(list(b.id.cpu().numpy()))

    model = make_model(
        len(NAMES.vocab),
        len(TRG_NAMES.vocab),
        N=args.num_layers,
        d_model=args.hidden_size,
        d_ff=args.ff_size,
        h=args.att_heads,
        dropout=args.dropout,
        num_classes=args.num_classes,
    )
    model.to(DEVICE)

    loaded_from_checkpoint = False

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_met = checkpoint['best_met']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

            loaded_from_checkpoint = True
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        args.start_epoch = 0

    if args.tensorboard:
        writer = SummaryWriter('runs_encdec/{}'.format(tb_name))

    if args.evaluate:
        print('Running prediction on val set')
        if not os.path.exists('eval/'):
            os.makedirs('eval/')

        preds, clf_preds = predict(
            (rebatch(PAD_INDEX, x) for x in valid_iter_batch),
            model,
            max_len=70,
            src_vocab=NAMES.vocab,
            trg_vocab=TRG_NAMES.vocab,
            num_batches=len(valid_iter_batch),
            return_logits=True)

        predict_df = pd.DataFrame({
            'id': val_ids,
            'target': clf_preds,
            'fullname_true': preds
        })

        predict_df.set_index('id').to_csv('eval/{}.csv'.format(args.tb_name))
    if args.predict:
        print('Running prediction on test set')
        if not os.path.exists('predictions/'):
            os.makedirs('predictions/')

        preds, clf_preds = predict(
            (rebatch(PAD_INDEX, x) for x in test_iter_batch),
            model,
            max_len=70,
            src_vocab=NAMES.vocab,
            trg_vocab=TRG_NAMES.vocab,
            num_batches=len(test_iter_batch),
            return_logits=True)

        predict_df = pd.DataFrame({
            'id': test_ids,
            'target': clf_preds,
            'fullname_true': preds
        })

        predict_df.set_index('id').to_csv('predictions/{}.csv'.format(
            args.tb_name))
    if not args.predict and not args.evaluate:
        print('Training starts...')
        dev_perplexity, dev_clf_loss, preds, clf_preds = train(
            model,
            lr=1e-4,
            num_epochs=args.epochs,
            print_every=args.print_freq,
            train_iter=train_iter,
            valid_iter_batch=valid_iter_batch,
            val_ids=val_ids,
            val_clf_gts=val_clf_gts,
            val_gts=val_gts,
            writer=writer)
Example #2
def main():

    BATCH_SIZE = 32
    MODEL_PATH = './output/transformer_model.pth'
    BEST_MODEL_PATH = './output/transformer_model_best.pth'
    # ---------------------------define field-------------------------
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    TEXT, LABEL = get_field(tokenizer)

    # -----------------get train, val and test data--------------------
    # load the data and create the validation splits
    train_data, test_data = datasets.IMDB.splits(TEXT,
                                                 LABEL,
                                                 root='../Dataset/IMDB')

    LABEL.build_vocab(train_data)
    print(LABEL.vocab.stoi)

    train_data, eval_data = train_data.split(random_state=random.seed(SEED))

    print('Number of train data {}'.format(len(train_data)))
    print('Number of evaluate data {}'.format(len(eval_data)))
    print('Number of test data {}'.format(len(test_data)))

    # generate dataloader
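    # BucketIterator groups examples of similar length to minimize padding per batch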
    train_iterator, eval_iterator = data.BucketIterator.splits(
        (train_data, eval_data), batch_size=BATCH_SIZE, device=device)

    test_iterator = data.BucketIterator(test_data,
                                        batch_size=BATCH_SIZE,
                                        device=device)

    for batch_data in train_iterator:
        print('text size {}'.format(batch_data.text.size()))
        print('label size {}'.format(batch_data.label.size()))
        break

    # --------------------------------- build model -------------------------------
    HIDDEN_SIZE = 256
    OUTPUT_SIZE = 1
    NUM_LAYERS = 2
    BIDIRECTIONAL = True
    DROPOUT = 0.25

    bert_model = BertModel.from_pretrained('bert-base-uncased')
    model = BertGRUSentiment(bert=bert_model,
                             hidden_size=HIDDEN_SIZE,
                             num_layers=NUM_LAYERS,
                             output_size=OUTPUT_SIZE,
                             bidirectional=BIDIRECTIONAL,
                             dropout=DROPOUT)

    # freeze the BERT parameters
    for name, param in model.named_parameters():
        if name.startswith('bert'):
            param.requires_grad = False
    # report the number of trainable parameters
    print('The model has {:,} trainable parameters'.format(
        count_trainable_parameters(model)))

    # list the parameters that remain trainable
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    # ---------------------------------- config -------------------------------------------

    optimizer = optim.Adam(params=model.parameters(), lr=0.001)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    criterion = nn.BCEWithLogitsLoss()

    # ----------------------------------- train -------------------------------------------
    NUM_EPOCH = 10
    model = model.to(device)

    # train and evaluate
    best_eval_loss = float('inf')
    for epoch in range(NUM_EPOCH):
        print('{}/{}'.format(epoch, NUM_EPOCH))
        train_acc, train_loss = train(model,
                                      train_iterator,
                                      optimizer=optimizer,
                                      criterion=criterion)
        eval_acc, eval_loss = evaluate(model,
                                       eval_iterator,
                                       criterion=criterion)
        scheduler.step()
        print('Train => acc {:.3f}, loss {:.4f}'.format(train_acc, train_loss))
        print('Eval => acc {:.3f}, loss {:.4f}'.format(eval_acc, eval_loss))

        # save model
        state = {
            'hidden_size': HIDDEN_SIZE,
            'output_size': OUTPUT_SIZE,
            'num_layer': NUM_LAYERS,
            'bidirectional': BIDIRECTIONAL,
            'dropout': DROPOUT,
            'state_dict': model.state_dict(),
        }
        os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
        torch.save(state, MODEL_PATH)
        if eval_loss < best_eval_loss:
            shutil.copy(MODEL_PATH, BEST_MODEL_PATH)
            best_eval_loss = eval_loss

    # test
    test_acc, test_loss = evaluate(model, test_iterator, criterion)
    print('Test => acc {:.3f}, loss {:.4f}'.format(test_acc, test_loss))
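
The checkpoint saved in the loop above stores the GRU hyperparameters next to the weights, so a later script can rebuild the model without hard-coding them. A minimal reload sketch, assuming the same BertModel / BertGRUSentiment definitions imported by the training script and that BEST_MODEL_PATH points at the copied best checkpoint:

# Rebuild the model from the saved hyperparameters and load the best weights.
checkpoint = torch.load(BEST_MODEL_PATH, map_location='cpu')
bert_model = BertModel.from_pretrained('bert-base-uncased')
model = BertGRUSentiment(bert=bert_model,
                         hidden_size=checkpoint['hidden_size'],
                         num_layers=checkpoint['num_layer'],
                         output_size=checkpoint['output_size'],
                         bidirectional=checkpoint['bidirectional'],
                         dropout=checkpoint['dropout'])
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # switch to inference mode before evaluating or predicting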
Example #3
def main():
    global USE_CUDA,DEVICE
    global UNK_TOKEN,PAD_TOKEN,SOS_TOKEN,EOS_TOKEN,TRG_NAMES,LOWER,PAD_INDEX,NAMES,MIN_FREQ
    global args,best_met,valid_minib_counter

    label_writers = []

    UNK_TOKEN = "!"
    PAD_TOKEN = "_"    
    SOS_TOKEN = "["
    EOS_TOKEN = "]"
    LOWER = False

    ID = data.Field(sequential=False,
                    use_vocab=False)

    NAMES = data.Field(tokenize=tokenize,
                       batch_first=True,
                       lower=LOWER,
                       include_lengths=True,
                       unk_token=UNK_TOKEN,
                       pad_token=PAD_TOKEN,
                       init_token=None,
                       eos_token=EOS_TOKEN)

    TRG_NAMES = data.Field(tokenize=tokenize, 
                           batch_first=True,
                           lower=LOWER,
                           include_lengths=True,
                           unk_token=UNK_TOKEN,
                           pad_token=PAD_TOKEN,
                           init_token=SOS_TOKEN,
                           eos_token=EOS_TOKEN)

    LBL = data.Field(sequential=False,
                     use_vocab=False)

    CNT = data.Field(sequential=False,
                     use_vocab=False)

    datafields = [("id", ID),
                  ("src", NAMES),
                  ("trg", TRG_NAMES),
                  ("clf", LBL),
                  ("cn", CNT)
                 ]

    train_data = data.TabularDataset(path=args.train_df_path,
                                     format='csv',
                                     skip_header=True,
                                     fields=datafields)

    train_data, valid_data = train_data.split(split_ratio=args.split_ratio,
                                              stratified=args.stratified,
                                              strata_field=args.strata_field)
    
    
    print('Train length {}, val length {}'.format(len(train_data),len(valid_data)))

    MIN_FREQ = args.min_freq  # NOTE: we limit the vocabulary to frequent words for speed
    NAMES.build_vocab(train_data.src, min_freq=MIN_FREQ)
    TRG_NAMES.build_vocab(train_data.trg, min_freq=MIN_FREQ)
    PAD_INDEX = TRG_NAMES.vocab.stoi[PAD_TOKEN]

    train_iter = data.BucketIterator(train_data,
                                     batch_size=args.batch_size,
                                     train=True, 
                                     sort_within_batch=True, 
                                     sort_key=lambda x: (len(x.src), len(x.trg)),
                                     repeat=False,
                                     device=DEVICE,
                                     shuffle=True)

    valid_iter_batch = data.Iterator(valid_data,
                               batch_size=args.batch_size,
                               train=False,
                               sort_within_batch=True,
                               sort_key=lambda x: (len(x.src), len(x.trg)),
                               repeat=False, 
                               device=DEVICE,
                               shuffle=False)

    val_ids = []
    for b in valid_iter_batch:
        val_ids.extend(list(b.id.cpu().numpy()))    
    
    print('Preparing data for validation')

    train_df = pd.read_csv('../data/proc_train.csv')
    train_df = train_df.set_index('id')
    val_gts = train_df.loc[val_ids,'fullname_true'].values
    val_ors = train_df.loc[val_ids,'fullname'].values
    incorrect_idx = list(train_df[train_df.target==1].index.values)

    incorrect_val_ids = list(set(val_ids).intersection(set(incorrect_idx)))
    correct_val_ids = list(set(val_ids)-set(incorrect_val_ids))
    
    print('Making dictionaries')
    
    id2gt = dict(train_df['fullname_true'])
    id2clf_gt = dict(train_df['target'])
    val_gts = [id2gt[_] for _ in val_ids]
    val_clf_gts = [id2clf_gt[_] for _ in val_ids]    
    del train_df
    gc.collect()
    
    model = make_model(len(NAMES.vocab),
                       len(TRG_NAMES.vocab),
                       device=DEVICE,
                       emb_size=args.emb_size,
                       hidden_size=args.hidden_size,
                       num_layers=args.num_layers,
                       dropout=args.dropout,
                       num_classes=args.num_classes)
    
    loaded_from_checkpoint = False

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_met = checkpoint['best_met']
            model.load_state_dict(checkpoint['state_dict'])           
            print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            loaded_from_checkpoint = True
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))    
    else:
        args.start_epoch = 0
        
    criterion = nn.CrossEntropyLoss(reduction='none').to(DEVICE)  # 'none' replaces the deprecated reduce=False

    if args.tensorboard:
        writer = SummaryWriter('runs_encdec/{}'.format(tb_name))
    
    if args.evaluate:
        pass
    elif args.predict:
        pass
    else:
        print('Training starts...') 
        dev_perplexity,dev_clf_loss,preds,clf_preds = train(model,
                                                            lr=args.lr,
                                                            num_epochs=args.epochs,
                                                            print_every=args.print_freq,
                                                            train_iter=train_iter,
                                                            valid_iter_batch=valid_iter_batch,
                                                            val_ids=val_ids,
                                                            val_clf_gts=val_clf_gts,
                                                            val_gts=val_gts,
                                                            writer=writer)        
Example #4
def load_data(batch_size, device):
    # label field
    LABEL = data.Field(sequential=False, use_vocab=False, batch_first=True)
    # text fields
    SEN1 = data.Field(sequential=True, tokenize=tokenizer,  fix_length=50, lower=True, batch_first=True)
    SEN2 = data.Field(sequential=True, tokenize=tokenizer,  fix_length=50, lower=True, batch_first=True)

    # build the train and validation datasets
    train, valid = data.TabularDataset.splits(
        path='./snli_1.0/',
        skip_header=True,
        train="train4.csv",
        validation="dev3.csv",
        format='csv',
        fields=[("label", LABEL), ("sentence1", SEN1), ("sentence2", SEN2)],
    )

    test = data.TabularDataset(
        path='./snli_1.0/test3.csv',
        skip_header=True,
        format='csv',
        fields=[("sentence1", SEN1), ("sentence2", SEN2)],
    )

    # build the vocabulary and share it between the two sentence fields
    SEN1.build_vocab(train.sentence1, train.sentence2,
                     vectors=Vectors(name='/data/yinli/dataset/glove.840B.300d.txt'))
    SEN2.vocab = SEN1.vocab

    # build the iterators
    train_iter = data.BucketIterator(train,
                                     sort_key=lambda x: len(x.sentence1),
                                     sort_within_batch=False,
                                     shuffle=True,
                                     batch_size=batch_size,
                                     repeat=False,
                                     device=device)

    valid_iter = data.Iterator(valid,
                              sort=False,
                              shuffle=False,
                              sort_within_batch=False,
                              batch_size=batch_size,
                              repeat=False,
                              train=False,
                              device=device)

    test_iter = data.Iterator(test,
                               sort=False,
                               shuffle=False,
                               sort_within_batch=False,
                               batch_size=batch_size,
                               repeat=False,
                               train=False,
                               device=device)

    return train_iter, valid_iter, test_iter, SEN1.vocab, SEN2.vocab

# Load the dataset and build the iterators
# def load_data(batch_size, device):
#     # label field
#     LABEL = data.Field(sequential=True, batch_first=True)
#     # text fields
#     SEN1 = data.Field(sequential=True, tokenize=tokenizer, lower=True, batch_first=True)
#     SEN2 = data.Field(sequential=True, tokenize=tokenizer, lower=True, batch_first=True)
#
#     # build the dataset
#     train = data.TabularDataset(
#         path='./snli_1.0/train2.csv',
#         skip_header=True,
#         format='csv',
#         fields=[("label", LABEL), ("sentence1", SEN1), ("sentence2", SEN2)],
#     )
#
#     # build the vocabularies
#     SEN1.build_vocab(train, vectors=Vectors(name='/data/yinli/dataset/glove.840B.300d.txt'))
#     SEN2.build_vocab(train, vectors=Vectors(name='/data/yinli/dataset/glove.840B.300d.txt'))
#     LABEL.build_vocab(train)
#
#     # build the iterator
#     train_iter = data.BucketIterator(train,
#                                 sort_key=lambda x: len(x.SEN1),
#                                 sort_within_batch=False,
#                                 shuffle=True,
#                                 batch_size=batch_size,
#                                 repeat=False,
#                                 device=device)
#
#     return train_iter, SEN1.vocab, SEN2.vocab


# device = torch.device("cuda:1")
# train_iter, dev_iter, test_iter, sentence1_vocab, sentence2_vocab = load_data(5, 50, device)
#
# for batch in train_iter:
#     print(batch.label)
#     print(batch.sentence1)
#     print(batch.sentence2)
#     break
# print(len(sentence1_vocab.vectors))
#
# print(sentence1_vocab.stoi['frown'])
# print(sentence2_vocab.stoi['frown'])
# print(sentence1_vocab.stoi['<unk>'])
#
# del train_iter
# del dev_iter
# del test_iter
# del sentence1_vocab
# del sentence2_vocab

#
# embedding = torch.cat((sentence2_vocab.vectors ,sentence1_vocab.vectors[2:]), 0)
# print(embedding.size())
# vocab_size, embed_size = embedding.size()
# print(vocab_size)
# print(embed_size)
# print(len(label_vocab))
# print(label_vocab.stoi)
#label2id = {'<unk>': 0, '<pad>': 1, 'neutral': 2, 'contradiction': 3, 'entailment': 4}
Example #5
    def build(language_pair: LanguagePair, split: Split, max_length=100, min_freq=2,
              start_token="<s>", eos_token="</s>", blank_token="<blank>",
              batch_size_train=32, batch_size_validation=32,
              batch_size_test=32, device='cpu'):
        """
        Initializes iterators over the IWSLT dataset.
        Each iterator yields batches whose size is set by the corresponding `batch_size_*` argument.

        Returns one iterator for each split alongside the input & output vocab sets.

        Example:

        >>> dataset_iterator, _, _, src_vocab, trg_vocab = IWSLTDatasetBuilder.build(
        ...                                                   language_pair=language_pair,
        ...                                                   split=Split.Train,
        ...                                                   max_length=5,
        ...                                                   batch_size_train=batch_size_train)
        >>> batch = next(iter(dataset_iterator))

        :param language_pair: The language pair for which to create a vocabulary.
        :param split: The split type.
        :param max_length: Max length of sequence.
        :param min_freq: The minimum frequency a word should have to be included in the vocabulary
        :param start_token: The token that marks the beginning of a sequence.
        :param eos_token: The token that marks an end of sequence.
        :param blank_token: The token to pad with.
        :param batch_size_train: Desired size of each training batch.
        :param batch_size_validation: Desired size of each validation batch.
        :param batch_size_test: Desired size of each testing batch.
        :param device: The device on which to store the batches.
        :type device: str or torch.device

        :returns: (train_iterator, validation_iterator, test_iterator,
                   source_field.vocab, target_field.vocab)
        """
        # load corresponding tokenizer
        source_tokenizer, target_tokenizer = language_pair.tokenizer()
        # create torchtext data fields used to build the vocabularies
        source_field = data.Field(tokenize=source_tokenizer, pad_token=blank_token)
        target_field = data.Field(tokenize=target_tokenizer, init_token=start_token,
                                  eos_token=eos_token, pad_token=blank_token)

        # Generate train/validation (and optionally test) datasets
        settings = dict()
        for key, split_type in [
            # ("validation", Split.Validation),  # Due to a bug in TorchText, cannot set to None
            ("test", Split.Test),
        ]:
            if not (split & split_type):
                settings[key] = None  # Disable this split; otherwise keep the default
        # noinspection PyTypeChecker
        train, validation, *out = datasets.IWSLT.splits(
            root=ROOT_DATASET_DIR,  # To check if the dataset was already downloaded
            exts=language_pair.extensions(),
            fields=(source_field, target_field),
            filter_pred=lambda x: all(len(val) <= max_length for val in (x.src, x.trg)),
            **settings
        )

        # Build vocabulary on training set
        source_field.build_vocab(train, min_freq=min_freq)
        target_field.build_vocab(train, min_freq=min_freq)

        train_iterator, validation_iterator, test_iterator = None, None, None

        def sort_func(x):
            return data.interleave_keys(len(x.src), len(x.trg))

        if split & Split.Train:
            train_iterator = data.BucketIterator(
                dataset=train, batch_size=batch_size_train, repeat=False,
                device=device, sort_key=sort_func)
        if split & Split.Validation:
            validation_iterator = data.BucketIterator(
                dataset=validation, batch_size=batch_size_validation, repeat=False,
                device=device, sort_key=sort_func)
        if split & Split.Test:
            test, *out = out
            test_iterator = data.BucketIterator(
                dataset=test, batch_size=batch_size_test, repeat=False,
                device=device, sort_key=sort_func)

        return (
            train_iterator,
            validation_iterator,
            test_iterator,
            source_field.vocab,
            target_field.vocab,
        )
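
Building on the docstring example, a minimal consumption sketch; language_pair is assumed to be a LanguagePair value from the surrounding module, and combining splits with | assumes Split is a flag enum, as the split & split_type checks above suggest:

train_iterator, validation_iterator, _, src_vocab, trg_vocab = IWSLTDatasetBuilder.build(
    language_pair=language_pair,
    split=Split.Train | Split.Validation,
    max_length=100,
    min_freq=2,
    batch_size_train=64,
    batch_size_validation=64,
    device='cpu')

for batch in train_iterator:
    # The fields are built without batch_first, so tensors are (seq_len, batch_size).
    src, trg = batch.src, batch.trg
    print(src.shape, trg.shape, len(src_vocab), len(trg_vocab))
    break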