Example #1
 def test_flatten(self):
     dset_split = Dataset.from_dict(
         {
             "a": [{
                 "b": {
                     "c": ["text"]
                 }
             }] * 10,
             "foo": [1] * 10
         },
         features=Features({
             "a": {
                 "b": Sequence({"c": Value("string")})
             },
             "foo": Value("int64")
         }),
     )
     dset = DatasetDict({"train": dset_split, "test": dset_split})
     dset = dset.flatten()
     self.assertDictEqual(dset.column_names, {
         "train": ["a.b.c", "foo"],
         "test": ["a.b.c", "foo"]
     })
     self.assertListEqual(sorted(dset["train"].features.keys()),
                          ["a.b.c", "foo"])
     self.assertDictEqual(
         dset["train"].features,
         Features({
             "a.b.c": Sequence(Value("string")),
             "foo": Value("int64")
         }))
     del dset
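The same flatten() call can be used on a DatasetDict directly, outside the test harness. A minimal sketch mirroring the fixtures above (same columns and features as in the test):

from datasets import Dataset, DatasetDict, Features, Sequence, Value

# One nested column "a" (flattened to "a.b.c") and one flat column "foo".
split = Dataset.from_dict(
    {"a": [{"b": {"c": ["text"]}}] * 10, "foo": [1] * 10},
    features=Features({"a": {"b": Sequence({"c": Value("string")})}, "foo": Value("int64")}),
)
dset = DatasetDict({"train": split, "test": split})

# flatten() is applied to every split; nested fields get dotted column names.
flat = dset.flatten()
print(flat.column_names)  # {'train': ['a.b.c', 'foo'], 'test': ['a.b.c', 'foo']}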
Example #2
 def test_align_labels_with_mapping(self):
     train_features = Features({
         "input_text": Value("string"),
         "input_labels": ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]),
     })
     test_features = Features({
         "input_text": Value("string"),
         "input_labels": ClassLabel(num_classes=3, names=["entailment", "contradiction", "neutral"]),
     })
     train_data = {
         "input_text": ["a", "a", "b", "b", "c", "c"],
         "input_labels": [0, 0, 1, 1, 2, 2]
     }
     test_data = {
         "input_text": ["a", "a", "c", "c", "b", "b"],
         "input_labels": [0, 0, 1, 1, 2, 2]
     }
     label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1}
     id2label = {v: k for k, v in label2id.items()}
     train_expected_labels = [2, 2, 1, 1, 0, 0]
     test_expected_labels = [2, 2, 0, 0, 1, 1]
     train_expected_label_names = [
         id2label[idx] for idx in train_expected_labels
     ]
     test_expected_label_names = [
         id2label[idx] for idx in test_expected_labels
     ]
     dsets = DatasetDict({
         "train": Dataset.from_dict(train_data, features=train_features),
         "test": Dataset.from_dict(test_data, features=test_features),
     })
     dsets = dsets.align_labels_with_mapping(label2id, "input_labels")
     self.assertListEqual(train_expected_labels,
                          dsets["train"]["input_labels"])
     self.assertListEqual(test_expected_labels,
                          dsets["test"]["input_labels"])
     train_aligned_label_names = [
         dsets["train"].features["input_labels"].int2str(idx)
         for idx in dsets["train"]["input_labels"]
     ]
     test_aligned_label_names = [
         dsets["test"].features["input_labels"].int2str(idx)
         for idx in dsets["test"]["input_labels"]
     ]
     self.assertListEqual(train_expected_label_names,
                          train_aligned_label_names)
     self.assertListEqual(test_expected_label_names,
                          test_aligned_label_names)
Example #3
def test_dummy_dataset_serialize_s3(s3, dataset):
    dsets = DatasetDict({"train": dataset, "test": dataset.select(range(2))})
    mock_bucket = s3_test_bucket_name
    dataset_path = f"s3://{mock_bucket}/datasets/dict"
    column_names = dsets["train"].column_names
    lengths = [len(dset) for dset in dsets.values()]
    dsets.save_to_disk(dataset_path, s3)
    dsets = DatasetDict.load_from_disk(dataset_path, s3)

    assert sorted(dsets) == ["test", "train"]
    assert [len(dset) for dset in dsets.values()] == lengths
    assert dsets["train"].column_names == column_names
    assert dsets["test"].column_names == column_names
Example #4
def test_datasetdict_from_csv(split, features, keep_in_memory, csv_path,
                              tmp_path):
    if split:
        path = {split: csv_path}
    else:
        split = "train"
        path = {"train": csv_path, "test": csv_path}
    cache_dir = tmp_path / "cache"
    # CSV file loses col_1 string dtype information: default now is "int64" instead of "string"
    default_expected_features = {
        "col_1": "int64",
        "col_2": "int64",
        "col_3": "float64"
    }
    expected_features = features.copy() if features else default_expected_features
    features = Features({feature: Value(dtype) for feature, dtype in features.items()}) if features else None
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = DatasetDict.from_csv(path,
                                       features=features,
                                       cache_dir=cache_dir,
                                       keep_in_memory=keep_in_memory)
    assert isinstance(dataset, DatasetDict)
    dataset = dataset[split]
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    assert dataset.split == split
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
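As the comment above notes, CSV type inference turns col_1 into int64 unless features are passed explicitly. A minimal direct-use sketch (the file paths here are hypothetical):

from datasets import DatasetDict, Features, Value

# Hypothetical CSV files; pass features to keep col_1 as a string instead of the inferred int64.
features = Features({"col_1": Value("string"), "col_2": Value("int64"), "col_3": Value("float64")})
dataset = DatasetDict.from_csv({"train": "train.csv", "test": "test.csv"}, features=features)
print(dataset["train"].features)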
Example #5
def test_datasetdict_from_text(split, features, keep_in_memory, text_path,
                               tmp_path):
    if split:
        path = {split: text_path}
    else:
        split = "train"
        path = {"train": text_path, "test": text_path}
    cache_dir = tmp_path / "cache"
    default_expected_features = {"text": "string"}
    expected_features = features.copy() if features else default_expected_features
    features = Features({feature: Value(dtype) for feature, dtype in features.items()}) if features else None
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = DatasetDict.from_text(path,
                                        features=features,
                                        cache_dir=cache_dir,
                                        keep_in_memory=keep_in_memory)
    assert isinstance(dataset, DatasetDict)
    dataset = dataset[split]
    assert dataset.num_rows == 4
    assert dataset.num_columns == 1
    assert dataset.column_names == ["text"]
    assert dataset.split == split
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
Example #6
 def _create_dummy_dataset_dict(self, multiple_columns=False) -> DatasetDict:
     return DatasetDict(
         {
             "train": self._create_dummy_dataset(multiple_columns=multiple_columns),
             "test": self._create_dummy_dataset(multiple_columns=multiple_columns),
         }
     )
Example #7
def test_datasetdict_from_json(
    split,
    features,
    keep_in_memory,
    jsonl_path,
    tmp_path,
):
    file_path = jsonl_path
    field = None
    if split:
        path = {split: file_path}
    else:
        split = "train"
        path = {"train": file_path, "test": file_path}
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = Features({feature: Value(dtype) for feature, dtype in features.items()}) if features else None
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = DatasetDict.from_json(
            path, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, field=field
        )
    assert isinstance(dataset, DatasetDict)
    dataset = dataset[split]
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    assert dataset.split == split
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
Example #8
    def test_serialization(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dsets = self._create_dummy_dataset_dict()
            dsets.save_to_disk(tmp_dir)
            dsets = DatasetDict.load_from_disk(tmp_dir)
            self.assertListEqual(sorted(dsets), ["test", "train"])
            self.assertEqual(len(dsets["train"]), 30)
            self.assertListEqual(dsets["train"].column_names, ["filename"])
            self.assertEqual(len(dsets["test"]), 30)
            self.assertListEqual(dsets["test"].column_names, ["filename"])

            del dsets["test"]
            dsets.save_to_disk(tmp_dir)
            dsets = DatasetDict.load_from_disk(tmp_dir)
            self.assertListEqual(sorted(dsets), ["train"])
            self.assertEqual(len(dsets["train"]), 30)
            self.assertListEqual(dsets["train"].column_names, ["filename"])
Example #9
def load_and_split_dataset(dataset_args, split_percentage=5):
    """Alternative: if no validation set available, manuallly split the train set"""

    dataset = DatasetDict()
    dataset["train"] = load_dataset(**dataset_args,
                                    split=f"train[{split_percentage}%:]")
    dataset["validation"] = load_dataset(**dataset_args,
                                         split=f"train[:{split_percentage}%]")
    return dataset
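A usage sketch for the helper above; the dataset arguments are illustrative and can be any keyword arguments accepted by load_dataset:

# Hypothetical arguments: reserve the first 5% of "train" as a validation split.
dataset_args = {"path": "wikitext", "name": "wikitext-2-raw-v1"}
dataset = load_and_split_dataset(dataset_args, split_percentage=5)
print(dataset)  # DatasetDict with "train" and "validation" splits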
Example #10
def test_datasetdict_from_text_keep_in_memory(keep_in_memory, text_path,
                                              tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"text": "string"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = DatasetDict.from_text({"train": text_path},
                                        cache_dir=cache_dir,
                                        keep_in_memory=keep_in_memory)
    _check_text_datasetdict(dataset, expected_features)
Example #11
def test_datasetdict_from_text_split(split, text_path, tmp_path):
    if split:
        path = {split: text_path}
    else:
        split = "train"
        path = {"train": text_path, "test": text_path}
    cache_dir = tmp_path / "cache"
    expected_features = {"text": "string"}
    dataset = DatasetDict.from_text(path, cache_dir=cache_dir)
    _check_text_datasetdict(dataset,
                            expected_features,
                            splits=list(path.keys()))
    assert all(dataset[split].split == split for split in path.keys())
Example #12
def test_datasetdict_from_text_features(features, text_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"text": "string"}
    expected_features = features.copy() if features else default_expected_features
    features = Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    dataset = DatasetDict.from_text({"train": text_path},
                                    features=features,
                                    cache_dir=cache_dir)
    _check_text_datasetdict(dataset, expected_features)
Example #13
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length",
                         truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns(['text'])
    tokenized_datasets = tokenized_datasets.rename_column('label', 'labels')
    tokenized_datasets.set_format('torch')

    return tokenized_datasets
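A short usage sketch; the dataset name is an assumption, and any DatasetDict with "text" and "label" columns would work the same way:

from datasets import load_dataset

raw_datasets = load_dataset("imdb")          # illustrative choice: has "text" and "label" columns
tokenized = tokenize_dataset(raw_datasets)   # torch-formatted DatasetDict, "label" renamed to "labels"
print(tokenized["train"].column_names)       # e.g. ['labels', 'input_ids', 'token_type_ids', 'attention_mask']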
Example #14
def test_datasetdict_from_parquet_keep_in_memory(keep_in_memory, parquet_path,
                                                 tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {
        "col_1": "string",
        "col_2": "int64",
        "col_3": "float64"
    }
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = DatasetDict.from_parquet({"train": parquet_path},
                                           cache_dir=cache_dir,
                                           keep_in_memory=keep_in_memory)
    _check_parquet_datasetdict(dataset, expected_features)
Example #15
def test_datasetdict_from_parquet_split(split, parquet_path, tmp_path):
    if split:
        path = {split: parquet_path}
    else:
        split = "train"
        path = {"train": parquet_path, "test": parquet_path}
    cache_dir = tmp_path / "cache"
    expected_features = {
        "col_1": "string",
        "col_2": "int64",
        "col_3": "float64"
    }
    dataset = DatasetDict.from_parquet(path, cache_dir=cache_dir)
    _check_parquet_datasetdict(dataset,
                               expected_features,
                               splits=list(path.keys()))
    assert all(dataset[split].split == split for split in path.keys())
Example #16
def test_datasetdict_from_parquet_features(features, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {
        "col_1": "string",
        "col_2": "int64",
        "col_3": "float64"
    }
    expected_features = features.copy() if features else default_expected_features
    features = Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    dataset = DatasetDict.from_parquet({"train": parquet_path},
                                       features=features,
                                       cache_dir=cache_dir)
    _check_parquet_datasetdict(dataset, expected_features)
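Outside the test suite, the same constructor takes a split-to-file mapping; a minimal sketch with hypothetical Parquet files (unlike CSV, the column types come from the Parquet schema):

from datasets import DatasetDict

# Hypothetical paths; each split maps to one Parquet file.
dataset = DatasetDict.from_parquet({"train": "train.parquet", "test": "test.parquet"})
print(dataset["train"].column_names)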
Example #17
def load_and_concatenate_datasets(data_args):
    """Load and concatenate multiple compatible datasets"""
    train_datasets, validation_datasets = [], []
    for name, config in zip(data_args.dataset_name,
                            data_args.dataset_config_name):

        dataset = load_dataset(name, config)
        if "validation" not in dataset.keys():
            validation_ds = load_dataset(
                name,
                config,
                split=f"train[:{data_args.validation_split_percentage}%]",
            )
            train_ds = load_dataset(
                name,
                config,
                split=f"train[{data_args.validation_split_percentage}%:]",
            )
        else:
            validation_ds = dataset["validation"]
            train_ds = dataset["train"]

        # Dataset-specific preprocessing to align fields on known datasets;
        # extraneous fields not used in language modeling are also removed
        # after preprocessing.
        if name == "wikipedia":
            train_ds.remove_columns_("title")
            validation_ds.remove_columns_("title")
        elif name == "ptb_text_only":
            train_ds.rename_column_("sentence", "text")
            validation_ds.rename_column_("sentence", "text")

        train_datasets.append(train_ds)
        validation_datasets.append(validation_ds)

    for ds_idx in range(1, len(train_datasets)):
        assert train_datasets[ds_idx].features.type == \
            train_datasets[ds_idx - 1].features.type, \
            "Features name and type must match between all datasets"

    datasets = DatasetDict()
    datasets["train"] = concatenate_datasets(train_datasets)
    datasets["validation"] = concatenate_datasets(validation_datasets)

    return datasets
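A usage sketch for load_and_concatenate_datasets; the data_args fields mirror exactly what the function reads above, while the dataset choice itself is illustrative:

from types import SimpleNamespace

# Illustrative arguments only; lists must be parallel (one config per dataset name).
data_args = SimpleNamespace(
    dataset_name=["wikitext"],
    dataset_config_name=["wikitext-2-raw-v1"],
    validation_split_percentage=5,
)
datasets = load_and_concatenate_datasets(data_args)
print(datasets)  # DatasetDict with concatenated "train" and "validation" splits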
Example #18
def test_datasetdict_from_csv_features(features, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    # CSV file loses col_1 string dtype information: default now is "int64" instead of "string"
    default_expected_features = {
        "col_1": "int64",
        "col_2": "int64",
        "col_3": "float64"
    }
    expected_features = features.copy() if features else default_expected_features
    features = Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    dataset = DatasetDict.from_csv({"train": csv_path},
                                   features=features,
                                   cache_dir=cache_dir)
    _check_csv_datasetdict(dataset, expected_features)
Example #19
def encode_glue_samples_to_hf_dataset(samples,
                                      encode_datasets_fn,
                                      glue_labels,
                                      tokenizer,
                                      data_args,
                                      is_test=False):
    arrow_dataset = convert_glue_samples_to_hf_dataset(samples,
                                                       glue_labels,
                                                       is_test=is_test)

    hf_datasets = DatasetDict({
        'dataset': arrow_dataset,
    })

    label_to_id = {v: i for i, v in enumerate(glue_labels)}
    sentence1_key = "sentence1"
    sentence2_key = "sentence2"

    hf_datasets = encode_datasets_fn(tokenizer, hf_datasets, data_args,
                                     label_to_id, sentence1_key, sentence2_key)

    return hf_datasets['dataset']
Example #20
def get_dataset(tokenizer, args, output_all_cols=False, data_dir=''):
    ds_path = os.path.join(data_dir, f'task1/preprocessed_data',
                           args.transformer)
    print(f'Dataset path: {ds_path}')
    try:
        encoded_ds = DatasetDict.load_from_disk(ds_path)
        print('Reloaded persisted dataset.')
    except:
        ds: DatasetDict = load_dataset("humicroedit", "subtask-1")
        glove = torchtext.vocab.GloVe(name='840B',
                                      dim=300,
                                      cache=os.path.join(
                                          os.environ['HOME'], '.vector_cache'))
        synset_sizes = get_synsets_sizes(ds)

        ds = ds.rename_column('edit', 'word_fin')
        ds = ds.map(
            get_preprocess_ds(glove=glove,
                              synset_sizes=synset_sizes,
                              add_amb_feat=True))
        ds = ds.remove_columns(['original'])

        ds = ds.rename_column('meanGrade', 'grade')
        encode_fn = get_encode(tokenizer)
        encoded_ds = ds.map(encode_fn, batched=True, batch_size=100)

        print('Saving preprocessed dataset.')
        os.makedirs(ds_path)
        encoded_ds.save_to_disk(ds_path)

    encoded_ds_cols = get_encoded_ds_cols(args)
    for _ds in encoded_ds.values():
        _ds.set_format(type='torch',
                       columns=encoded_ds_cols + ['grade'],
                       output_all_columns=output_all_cols)
    return encoded_ds
Example #21
def main(train_function):

    # ----- Parse local_rank for torch.distributed.launch -----------

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    local_rank = parser.parse_args().local_rank
    if local_rank is None:
        local_rank = 0

    # ----- Setup logging -----------

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(local_rank) else logging.WARN)

    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # ----- Configurable Params -----------

    # List of dicts with the configuration for each dataset to be loaded.
    # See available datasets in the Hub: https://huggingface.co/datasets. Sizes
    # are of the generated dataset and can be an order of magnitude larger after tokenization.
    # Not all datasets can be concatenated without preprocessing; their features must align.
    datasets_args = [
        dict(path="wikitext", name="wikitext-2-raw-v1"),  # 12.91 MB
        # dict(path="wikitext", name="wikitext-103-raw-v1"),  # 524 MB
        # dict(path="ptb_text_only"), # 5.7 MB
        # dict(path="bookcorpus"),  # 4.63 GB
        # dict(path="wikipedia"),  # 35.38 GB
    ]

    # Training params
    # note: in V100 bs=8 uses 11/16 of available gpu mem, bs=12 uses 15/16

    output_dir = os.path.expanduser("~/nta/results/bert")
    training_args = TrainingArguments(
        # Logging
        output_dir=output_dir,
        logging_first_step=True,
        logging_steps=10,  # also define eval_steps
        eval_steps=10,
        max_steps=30,  # num_train_epochs replaced by steps
        disable_tqdm=True,
        run_name="debug_run",  # used for wandb, not for Ray
        # hyperparams
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=1e-4,
        lr_scheduler_type="linear",
        warmup_steps=500,
        weight_decay=1e-6,
    )

    # "Evaluate" refers to evaluating the perplexity of the trained model on the validation set;
    # it does not refer to finetuning and evaluating on downstream tasks such as GLUE.
    seed = random.randint(0, 1000000)

    # Changing the tokenizer will result in re-tokenizing the dataset.
    # As a reference, BERT tokenization will take ~ 3 hours for a 5GB dataset
    config_class = BertConfig
    tokenizer_name = "bert-base-cased"

    # ----- Seed -----------

    set_seed(seed)
    print(f"Seed to reproduce: {seed}")

    # ----- Dataset -----------

    # Load multiple datasets and concatenate.
    # using only 'train' and 'validation' sets, could also include 'test'
    # if no split is defined, load_dataset returns DatasetDict with all available splits
    train_datasets = [load_dataset(**args, split="train") for args in datasets_args]
    val_datasets = [load_dataset(**args, split="validation") for args in datasets_args]

    dataset = DatasetDict()
    dataset["train"] = concatenate_datasets(train_datasets)
    dataset["validation"] = concatenate_datasets(val_datasets)

    def load_and_split_dataset(dataset_args, split_percentage=5):
        """Alternative: if no validation set available, manuallly split the train set"""

        dataset = DatasetDict()
        dataset["train"] = load_dataset(
            **dataset_args, split=f"train[{split_percentage}%:]"
        )
        dataset["validation"] = load_dataset(
            **dataset_args, split=f"train[:{split_percentage}%]"
        )
        return dataset

    # ----- Load Model -----------

    # Load model
    config = config_class()
    model = AutoModelForMaskedLM.from_config(config)

    # Load tokenizer
    # use_fast falls back to tokenizer lib implementation under the hood
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    model.resize_token_embeddings(len(tokenizer))

    # ----- Preprocess dataset -----------

    # Only use the text column name when doing language modeling
    # this feature might have a different name depending on the dataset
    # might need to change column names prior to concatenating, if that is the case
    column_names = dataset["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # Setting overwrite_cache to True will re-tokenize the dataset;
    # do not overwrite the cache if using a shared cache repository.
    overwrite_cache = False
    preprocessing_num_workers = None

    # We tokenize every text, then concatenate them together before splitting in smaller
    # parts. We use `return_special_tokens_mask=True` given
    # DataCollatorForLanguageModeling is more efficient when it
    # receives the `special_tokens_mask`.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=column_names,
        num_proc=preprocessing_num_workers,
        load_from_cache_file=not overwrite_cache,
    )

    # Main data processing function that will concatenate all texts from our dataset and
    # generate chunks of max_seq_length.
    max_seq_length = tokenizer.model_max_length

    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it
        # instead of this drop, you can customize this part to your needs.
        total_length = (total_length // max_seq_length) * max_seq_length
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + max_seq_length] for i in
                range(0, total_length, max_seq_length)]
            for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so
    # group_texts throws away a remainder for each of those groups of 1,000 texts.
    # You can adjust batch_size here but a higher value will be slower to preprocess.
    tokenized_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
        num_proc=preprocessing_num_workers,
        load_from_cache_file=not overwrite_cache,
    )

    # Data collator
    # This one will take care of randomly masking the tokens.
    # Q: what about dynamic masking, used in Roberta?
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=0.15
    )

    # ----- Setup Trainer -----------

    # Initialize Trainer. Similar to Vernon's Experiment class.
    # dataloader and training loop are contained in Trainer abstraction
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # ----- Functions to train and evaluate -----------

    if train_function == "huggingface":
        # Tested
        run_hf(trainer, logger, output_dir, save_model=True, evaluate=True)

    elif train_function == "ray_single_node":
        # Tested
        run_ray_single_instance(
            trainer,
            logger,
            name="bert_test",
            config=None,
            num_samples=1,
            local_dir=os.path.expanduser("~/nta/results/experiments/transformers"),
            keep_checkpoints_num=1,
            resources_per_trial={"cpu": 8},
            # note: checkpoint arguments cannot be used with a checkpointable function
        )

    elif train_function == "ray_multiple_nodes":
        # Untested
        run_ray_distributed(
            trainer,
            logger,
            name="bert_test",
            config=None,
            num_samples=1,
            local_dir=os.path.expanduser("~/nta/results/experiments/transformers"),
            keep_checkpoints_num=1,
            queue_trials=True,
            verbose=2,
            resources_per_trial={"gpu": 4},
        )
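To make the chunking in group_texts concrete, here is a tiny worked example with a fake max_seq_length of 4 (the real value comes from the tokenizer):

max_seq_length = 4
examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]}

concatenated = {k: sum(examples[k], []) for k in examples}   # {'input_ids': [1, 2, 3, 4, 5, 6, 7, 8]}
total_length = (len(concatenated["input_ids"]) // max_seq_length) * max_seq_length  # 8, remainder dropped
chunks = {k: [t[i:i + max_seq_length] for i in range(0, total_length, max_seq_length)]
          for k, t in concatenated.items()}
print(chunks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}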
Example #22
set_seed(seed)
print(f"Seed to reproduce: {seed}")

# ----- Dataset -----------

# Load multiple datasets and concatenate.
# using only 'train' and 'validation' sets, could also include 'test'
# if no split is defined, load_dataset returns a DatasetDict with all available splits
train_datasets = [
    load_dataset(**args, split="train") for args in datasets_args
]
val_datasets = [
    load_dataset(**args, split="validation") for args in datasets_args
]

dataset = DatasetDict()
dataset["train"] = concatenate_datasets(train_datasets)
dataset["validation"] = concatenate_datasets(val_datasets)


def load_and_split_dataset(dataset_args, split_percentage=5):
    """Alternative: if no validation set available, manuallly split the train set"""

    dataset = DatasetDict()
    dataset["train"] = load_dataset(**dataset_args,
                                    split=f"train[{split_percentage}%:]")
    dataset["validation"] = load_dataset(**dataset_args,
                                         split=f"train[:{split_percentage}%]")
    return dataset

def get_dataset(tokenizer,
                model_id=0,
                args=None,
                output_all_cols=False,
                data_dir=''):
    ds_path = os.path.join(data_dir,
                           f'task2/preprocessed_data/model{model_id}',
                           args.transformer)
    print(f'Dataset path: {ds_path}')
    try:
        encoded_ds = DatasetDict.load_from_disk(ds_path)
        print('Reloaded persisted dataset.')
    except:
        ds: DatasetDict = load_dataset("humicroedit", "subtask-2")
        glove, synset_sizes = None, None
        if model_id == 0:
            glove = torchtext.vocab.GloVe(name='840B',
                                          dim=300,
                                          cache=os.path.join(
                                              os.environ['HOME'],
                                              '.vector_cache'))
            synset_sizes = get_synsets_sizes(ds, task=2)

        for i in range(2):
            ds = ds.rename_column(f'edit{i+1}', f'word_fin{i+1}')
            ds = ds.map(
                get_preprocess_ds(glove=glove,
                                  synset_sizes=synset_sizes,
                                  idx=i + 1))
            ds = ds.remove_columns([f'original{i+1}'])
            ds = ds.rename_column(f'meanGrade{i+1}', f'grade{i+1}')

        if model_id == 2:
            ds = ds.map(add_T5_input)

        ds = ds.rename_column('label', 'labels')
        binary_ds = ds.filter(lambda ex: ex['labels'] != 0).\
            map(lambda ex: {'labels': ex['labels'] - 1})
        binary_ds_features = ds['train'].features.copy()
        binary_ds_features['labels'] = ClassLabel(
            names=ds['train'].features['labels'].names[1:])
        binary_ds = binary_ds.cast(binary_ds_features)

        encode_fn = get_encode(tokenizer, model_id=model_id)
        encoded_ds = binary_ds.map(encode_fn, batched=True, batch_size=100)

        print('Saving preprocessed dataset.')
        os.makedirs(ds_path)
        encoded_ds.save_to_disk(ds_path)

    if model_id == 0:
        from task1.data import get_encoded_ds_cols
        encoded_ds_cols = get_encoded_ds_cols(args)
        encoded_ds_cols = [
            f'{col}{i+1}' for i in range(2) for col in encoded_ds_cols
        ]
        encoded_ds_cols += ['grade1', 'grade2']
    elif model_id == 1 and args.transformer != 'distilbert-base-cased':
        encoded_ds_cols = ['input_ids', 'token_type_ids', 'attention_mask']
    else:
        encoded_ds_cols = ['input_ids', 'attention_mask']

    for _ds in encoded_ds.values():
        _ds.set_format(type='torch',
                       columns=encoded_ds_cols + ['labels'],
                       output_all_columns=output_all_cols)

    return encoded_ds