def test_serialization(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dsets = self._create_dummy_dataset_dict()
            dsets.save_to_disk(tmp_dir)
            dsets = DatasetDict.load_from_disk(tmp_dir)
            self.assertListEqual(sorted(dsets), ["test", "train"])
            self.assertEqual(len(dsets["train"]), 30)
            self.assertListEqual(dsets["train"].column_names, ["filename"])
            self.assertEqual(len(dsets["test"]), 30)
            self.assertListEqual(dsets["test"].column_names, ["filename"])

            del dsets["test"]
            dsets.save_to_disk(tmp_dir)
            dsets = DatasetDict.load_from_disk(tmp_dir)
            self.assertListEqual(sorted(dsets), ["train"])
            self.assertEqual(len(dsets["train"]), 30)
            self.assertListEqual(dsets["train"].column_names, ["filename"])
def get_dataset(tokenizer, args, output_all_cols=False, data_dir=''):
    ds_path = os.path.join(data_dir, 'task1/preprocessed_data',
                           args.transformer)
    print(f'Dataset path: {ds_path}')
    try:
        encoded_ds = DatasetDict.load_from_disk(ds_path)
        print('Reloaded persisted dataset.')
    except FileNotFoundError:
        # No cached copy on disk yet: build, preprocess, and persist the dataset.
        ds: DatasetDict = load_dataset("humicroedit", "subtask-1")
        glove = torchtext.vocab.GloVe(name='840B',
                                      dim=300,
                                      cache=os.path.join(
                                          os.environ['HOME'], '.vector_cache'))
        synset_sizes = get_synsets_sizes(ds)

        ds = ds.rename_column('edit', 'word_fin')
        ds = ds.map(
            get_preprocess_ds(glove=glove,
                              synset_sizes=synset_sizes,
                              add_amb_feat=True))
        ds = ds.remove_columns(['original'])

        ds = ds.rename_column('meanGrade', 'grade')
        encode_fn = get_encode(tokenizer)
        encoded_ds = ds.map(encode_fn, batched=True, batch_size=100)

        print('Saving preprocessed dataset.')
        os.makedirs(ds_path, exist_ok=True)
        encoded_ds.save_to_disk(ds_path)

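    # Expose only the transformer input columns (plus the regression target
    # 'grade') as torch tensors; remaining columns are returned as Python
    # objects only when output_all_cols is True.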
    encoded_ds_cols = get_encoded_ds_cols(args)
    for _ds in encoded_ds.values():
        _ds.set_format(type='torch',
                       columns=encoded_ds_cols + ['grade'],
                       output_all_columns=output_all_cols)
    return encoded_ds
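

# Usage sketch (not part of the original module), assuming args.transformer holds
# a Hugging Face model name and that a PyTorch DataLoader consumer is wanted: the
# function name, batch size, and AutoTokenizer call below are illustrative.
def _example_task1_dataloaders(args, data_dir='', batch_size=32):
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(args.transformer)
    encoded_ds = get_dataset(tokenizer, args, data_dir=data_dir)
    # set_format(type='torch') above makes each example a dict of tensors, so the
    # splits can be handed to DataLoader directly.
    return {
        split: DataLoader(ds, batch_size=batch_size, shuffle=(split == 'train'))
        for split, ds in encoded_ds.items()
    }

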
def get_dataset(tokenizer,
                model_id=0,
                args=None,
                output_all_cols=False,
                data_dir=''):
    ds_path = os.path.join(data_dir,
                           f'task2/preprocessed_data/model{model_id}',
                           args.transformer)
    print(f'Dataset path: {ds_path}')
    try:
        encoded_ds = DatasetDict.load_from_disk(ds_path)
        print('Reloaded persisted dataset.')
    except FileNotFoundError:
        # Nothing cached for this model variant yet: rebuild from the raw dataset.
        ds: DatasetDict = load_dataset("humicroedit", "subtask-2")
        glove, synset_sizes = None, None
        if model_id == 0:
            glove = torchtext.vocab.GloVe(name='840B',
                                          dim=300,
                                          cache=os.path.join(
                                              os.environ['HOME'],
                                              '.vector_cache'))
            synset_sizes = get_synsets_sizes(ds, task=2)

        for i in range(2):
            ds = ds.rename_column(f'edit{i+1}', f'word_fin{i+1}')
            ds = ds.map(
                get_preprocess_ds(glove=glove,
                                  synset_sizes=synset_sizes,
                                  idx=i + 1))
            ds = ds.remove_columns([f'original{i+1}'])
            ds = ds.rename_column(f'meanGrade{i+1}', f'grade{i+1}')

        if model_id == 2:
            ds = ds.map(add_T5_input)

        ds = ds.rename_column('label', 'labels')
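        # Drop examples with label 0 (the "equally funny" class in humicroedit
        # subtask-2) and shift the remaining labels from {1, 2} to {0, 1} so the
        # task becomes binary classification; the ClassLabel feature is rebuilt
        # to match the reduced label set.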
        binary_ds = ds.filter(lambda ex: ex['labels'] != 0).map(
            lambda ex: {'labels': ex['labels'] - 1})
        binary_ds_features = ds['train'].features.copy()
        binary_ds_features['labels'] = ClassLabel(
            names=ds['train'].features['labels'].names[1:])
        binary_ds = binary_ds.cast(binary_ds_features)

        encode_fn = get_encode(tokenizer, model_id=model_id)
        encoded_ds = binary_ds.map(encode_fn, batched=True, batch_size=100)

        print('Saving preprocessed dataset.')
        os.makedirs(ds_path, exist_ok=True)
        encoded_ds.save_to_disk(ds_path)

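    # Columns exposed as torch tensors depend on the model variant: model 0 reuses
    # task1's per-sentence columns, duplicated for both edits, plus the two
    # regression grades; the other variants keep the standard transformer inputs
    # (DistilBERT tokenizers do not emit token_type_ids, hence the extra check).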
    if model_id == 0:
        from task1.data import get_encoded_ds_cols
        encoded_ds_cols = get_encoded_ds_cols(args)
        encoded_ds_cols = [
            f'{col}{i+1}' for i in range(2) for col in encoded_ds_cols
        ]
        encoded_ds_cols += ['grade1', 'grade2']
    elif model_id == 1 and args.transformer != 'distilbert-base-cased':
        encoded_ds_cols = ['input_ids', 'token_type_ids', 'attention_mask']
    else:
        encoded_ds_cols = ['input_ids', 'attention_mask']

    for _ds in encoded_ds.values():
        _ds.set_format(type='torch',
                       columns=encoded_ds_cols + ['labels'],
                       output_all_columns=output_all_cols)

    return encoded_ds
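

# Usage sketch (not in the original code): pulling one torch-formatted batch from
# the task-2 pipeline above. The model_id, batch size, split name, and the
# AutoTokenizer call are illustrative assumptions.
def _example_task2_batch(args, model_id=1, batch_size=16):
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(args.transformer)
    encoded_ds = get_dataset(tokenizer, model_id=model_id, args=args)
    loader = DataLoader(encoded_ds['train'], batch_size=batch_size, shuffle=True)
    batch = next(iter(loader))
    # After the filter/shift above, batch['labels'] contains only 0s and 1s.
    return batch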