def test_shuffle():
    """
    Test that GLOBAL shuffle matches FILES shuffle followed by a shuffle op for TFRecord, TextFile and CLUE datasets
    """
    FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
    SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"

    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1.create_tuple_iterator(output_numpy=True),
                      data2.create_tuple_iterator(output_numpy=True)):
        for t1, t2 in zip(d1, d2):
            np.testing.assert_array_equal(t1, t2)

    ds.config.set_seed(1)
    DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
    data1 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1.create_tuple_iterator(output_numpy=True),
                      data2.create_tuple_iterator(output_numpy=True)):
        for t1, t2 in zip(d1, d2):
            np.testing.assert_array_equal(t1, t2)

    ds.config.set_seed(1)
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data1 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1.create_tuple_iterator(output_numpy=True),
                      data2.create_tuple_iterator(output_numpy=True)):
        for t1, t2 in zip(d1, d2):
            np.testing.assert_array_equal(t1, t2)


def test_clue_tnews():
    """
    Test TNEWS for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
    TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords': d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords': d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords': d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3


def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
    """Process TNEWS dataset"""
    ### Loading TNEWS from CLUEDataset
    assert data_usage in ['train', 'eval', 'test']
    if data_usage == 'train':
        dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='TNEWS',
                                 usage=data_usage, shuffle=shuffle_dataset)
    elif data_usage == 'eval':
        dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='TNEWS',
                                 usage=data_usage, shuffle=shuffle_dataset)
    else:
        dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='TNEWS',
                                 usage=data_usage, shuffle=shuffle_dataset)
    ### Processing label
    if data_usage == 'test':
        dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"],
                              columns_order=["id", "label_id", "sentence"], operations=ops.Duplicate())
        dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0))
    else:
        label_vocab = text.Vocab.from_list(label_list)
        label_lookup = text.Lookup(label_vocab)
        dataset = dataset.map(input_columns="label_desc", output_columns="label_id", operations=label_lookup)
    ### Processing sentence
    vocab = text.Vocab.from_file(bert_vocab_path)
    tokenizer = text.BertTokenizer(vocab, lower_case=True)
    lookup = text.Lookup(vocab, unknown_token='[UNK]')
    dataset = dataset.map(input_columns=["sentence"], operations=tokenizer)
    dataset = dataset.map(input_columns=["sentence"], operations=ops.Slice(slice(0, max_seq_len)))
    dataset = dataset.map(input_columns=["sentence"],
                          operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'),
                                                     append=np.array(["[SEP]"], dtype='S')))
    dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup)
    dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
    dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
                          columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate())
    dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
    dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"],
                          columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"],
                          operations=ops.Duplicate())
    dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0))
    dataset = dataset.batch(batch_size)
    label = []
    text_ids = []
    mask_ids = []
    segment_ids = []
    for data in dataset:
        label.append(data[0])
        text_ids.append(data[1])
        mask_ids.append(data[2])
        segment_ids.append(data[3])
    return label, text_ids, mask_ids, segment_ids


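# The following is a minimal, hypothetical usage sketch for
# process_tnews_clue_dataset. It is not collected by pytest (the name does not
# start with "test_"); the label list is only an illustrative subset of the
# TNEWS label descriptions, and "bert_vocab.txt" is a placeholder that must
# point to a real BERT vocabulary file before this can actually run.
def example_process_tnews_usage():
    labels = ["news_story", "news_culture", "news_entertainment"]  # hypothetical subset of TNEWS label descriptions
    label, text_ids, mask_ids, segment_ids = process_tnews_clue_dataset(
        data_dir="../data/dataset/testCLUE/tnews",
        label_list=labels,
        bert_vocab_path="./bert_vocab.txt",  # placeholder vocabulary path
        data_usage='train',
        shuffle_dataset=False,
        max_seq_len=128,
        batch_size=2)
    # Each returned list holds one numpy array per batch.
    assert len(label) == len(text_ids) == len(mask_ids) == len(segment_ids)

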
def test_clue_wsc():
    """
    Test WSC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_csl():
    """
    Test CSL for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
    TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_dataset_size():
    """
    Test get_dataset_size of CLUE dataset, with and without sharding
    """
    dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
    assert dataset.get_dataset_size() == 3

    dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False,
                                       num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2


def test_clue_iflytek():
    """
    Test IFLYTEK for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
    TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_to_device():
    """
    Test CLUE with to_device
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.to_device()
    data.send()


def test_clue_exception_file_path():
    """
    Test that the error message contains file info when an exception occurs in CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    def exception_func(item):
        raise Exception("Error occur!")

    try:
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
        data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
        data = data.map(operations=exception_func, input_columns=["sentence1"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
        data = data.map(operations=exception_func, input_columns=["sentence2"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


def test_clue_cmnli():
    """
    Test CMNLI for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
    TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3


def test_clue_afqmc():
    """
    Test AFQMC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # evaluation
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_global_shuffle_pass():
    """
    Test that GLOBAL shuffle matches FILES shuffle followed by a shuffle op when iterating datasets directly
    """
    FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
    SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"

    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)

    ds.config.set_seed(1)
    DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
    data1 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)

    ds.config.set_seed(1)
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data1 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)


def test_clue_num_samples():
    """
    Test num_samples param of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
    count = 0
    for _ in data.create_dict_iterator():
        count += 1
    assert count == 2


def test_clue_invalid_files():
    """
    Test CLUE with invalid files
    """
    AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
    afqmc_train_json = os.path.join(AFQMC_DIR)

    with pytest.raises(ValueError) as info:
        _ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
    assert "The following patterns did not match any files" in str(info.value)
    assert AFQMC_DIR in str(info.value)


def test_clue_num_shards():
    """
    Test num_shards param of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 1


def test_clue_padded_and_skip_with_0_samples():
    """
    Test padded dataset, skip and concat with 0 samples for CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 3

    data_copy1 = copy.deepcopy(data)

    sample = {"label": np.array(1, np.string_),
              "sentence1": np.array(1, np.string_),
              "sentence2": np.array(1, np.string_)}
    samples = [sample]
    padded_ds = ds.PaddedDataset(samples)
    dataset = data + padded_ds
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    dataset.use_sampler(testsampler)
    assert dataset.get_dataset_size() == 2
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2

    dataset = dataset.skip(count=2)  # the dataset now has no samples
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 0

    with pytest.raises(ValueError, match="There is no samples in the "):
        dataset = dataset.concat(data_copy1)
        count = 0
        for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            count += 1
        assert count == 2


def test_clue():
    """
    Test CLUE with repeat, skip and so on
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_get_column_name_clue():
    """
    Test get_col_names of CLUE dataset
    """
    data = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train")
    assert data.get_col_names() == ["label", "sentence1", "sentence2"]


def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64,
                               drop_remainder=True):
    """Process CMNLI dataset"""
    ### Loading CMNLI from CLUEDataset
    assert data_usage in ['train', 'eval', 'test']
    if data_usage == 'train':
        dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='CMNLI',
                                 usage=data_usage, shuffle=shuffle_dataset)
    elif data_usage == 'eval':
        dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='CMNLI',
                                 usage=data_usage, shuffle=shuffle_dataset)
    else:
        dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='CMNLI',
                                 usage=data_usage, shuffle=shuffle_dataset)
    ### Processing label
    if data_usage == 'test':
        dataset = dataset.map(operations=ops.Duplicate(), input_columns=["id"],
                              output_columns=["id", "label_id"],
                              column_order=["id", "label_id", "sentence1", "sentence2"])
        dataset = dataset.map(operations=ops.Fill(0), input_columns=["label_id"])
    else:
        label_vocab = text.Vocab.from_list(label_list)
        label_lookup = text.Lookup(label_vocab)
        dataset = dataset.map(operations=label_lookup, input_columns="label", output_columns="label_id")
    ### Processing sentence pairs
    vocab = text.Vocab.from_file(bert_vocab_path)
    tokenizer = text.BertTokenizer(vocab, lower_case=True)
    lookup = text.Lookup(vocab, unknown_token='[UNK]')
    ### Tokenizing sentences and truncate sequence pair
    dataset = dataset.map(operations=tokenizer, input_columns=["sentence1"])
    dataset = dataset.map(operations=tokenizer, input_columns=["sentence2"])
    dataset = dataset.map(operations=text.TruncateSequencePair(max_seq_len - 3),
                          input_columns=["sentence1", "sentence2"])
    ### Adding special tokens
    dataset = dataset.map(operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'),
                                                     append=np.array(["[SEP]"], dtype='S')),
                          input_columns=["sentence1"])
    dataset = dataset.map(operations=ops.Concatenate(append=np.array(["[SEP]"], dtype='S')),
                          input_columns=["sentence2"])
    ### Generating segment_ids
    dataset = dataset.map(operations=ops.Duplicate(), input_columns=["sentence1"],
                          output_columns=["sentence1", "type_sentence1"],
                          column_order=["sentence1", "type_sentence1", "sentence2", "label_id"])
    dataset = dataset.map(operations=ops.Duplicate(), input_columns=["sentence2"],
                          output_columns=["sentence2", "type_sentence2"],
                          column_order=["sentence1", "type_sentence1", "sentence2", "type_sentence2", "label_id"])
    dataset = dataset.map(operations=[lookup, ops.Fill(0)], input_columns=["type_sentence1"])
    dataset = dataset.map(operations=[lookup, ops.Fill(1)], input_columns=["type_sentence2"])
    dataset = dataset.map(operations=ops.Concatenate(),
                          input_columns=["type_sentence1", "type_sentence2"],
                          output_columns=["segment_ids"],
                          column_order=["sentence1", "sentence2", "segment_ids", "label_id"])
    dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["segment_ids"])
    ### Generating text_ids
    dataset = dataset.map(operations=ops.Concatenate(), input_columns=["sentence1", "sentence2"],
                          output_columns=["text_ids"],
                          column_order=["text_ids", "segment_ids", "label_id"])
    dataset = dataset.map(operations=lookup, input_columns=["text_ids"])
    dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["text_ids"])
    ### Generating mask_ids
    dataset = dataset.map(operations=ops.Duplicate(), input_columns=["text_ids"],
                          output_columns=["text_ids", "mask_ids"],
                          column_order=["text_ids", "mask_ids", "segment_ids", "label_id"])
    dataset = dataset.map(operations=ops.Mask(ops.Relational.NE, 0, mstype.int32),
                          input_columns=["mask_ids"])
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    return dataset

