Example #1
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache,
    task,
    train_batch_size: int,
    sample_weights_path=None,
    fix_seed_for_weighted_sampler=False,
):
    # TODO: Expose buffer_size parameter  (issue #1183)

    if sample_weights_path is not None:
        # Materialize the iterable cache into an indexable list so the
        # weighted sampler can draw examples by index; shuffling is handled
        # by the sampler rather than the cache.
        dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                   shuffle=False)
        dataset = _ListDataset(list(dataset))
        # Per-example sample weights: single-column TSV, no header.
        _sample_weights = pd.read_csv(sample_weights_path,
                                      sep='\t',
                                      header=None)[0]
        sampler = WeightedDatasetSampler(
            dataset, _sample_weights, fix_seed=fix_seed_for_weighted_sampler)
    else:
        dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                   shuffle=True)
        sampler = None

    train_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset,
        batch_size=train_batch_size,
        collate_fn=task.collate_fn,
        sampler=sampler)
    return train_dataloader
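
For orientation, here is a minimal, hypothetical call of the weighted-sampling variant above; the cache path, task object, and weights file are placeholders rather than names from the original code.

train_cache = caching.ChunkedFilesDataCache(cache_fol_path="/path/to/cache/mnli/train")
train_dataloader = get_train_dataloader_from_cache(
    train_cache=train_cache,
    task=task,  # placeholder: any task exposing collate_fn
    train_batch_size=16,
    sample_weights_path="/path/to/sample_weights.tsv",  # one weight per line, no header
    fix_seed_for_weighted_sampler=True,
)
for batch, batch_metadata in train_dataloader:
    ...  # training step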
Example #2
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache, task, train_batch_size: int
):
    # TODO: Expose buffer_size parameter  (issue #1183)
    dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
    train_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset, batch_size=train_batch_size, collate_fn=task.collate_fn,
    )
    return train_dataloader
Example #3
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache,
    task,
    train_batch_size: int,
    batch_method: str,
    min_batch_size: int,
    total_batches: int,
    matchlist_pickle_path: str,
):
    # TODO: Expose buffer_size parameter  (Issue #50)
    if batch_method == 'default':
        dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                   shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=dataset,
            batch_size=train_batch_size,
            collate_fn=task.collate_fn,
        )
    elif batch_method == 'clustered':
        dataset = train_cache.get_uniterable_dataset(buffer_size=10000)
        assert (
            total_batches > 0
        ), f"Must define total number of batches to generate. Given: {total_batches}."
        assert (
            train_batch_size > 0
        ), f"Max batch size must be greater than zero. Given: {train_batch_size}."

        # Currently only a pre-pickled match list is supported. On-the-fly
        # matching could be added, but may be slow for large datasets.

        assert os.path.exists(
            matchlist_pickle_path
        ), f"Pickled match list not found; create it first. Given: {matchlist_pickle_path}"
        with open(matchlist_pickle_path, 'rb') as f:
            match_list = pickle.load(f)

        matched_random_batch_sampler = torch_utils.MatchedRandomBatchSampler(
            min_batch_size=min_batch_size,
            max_batch_size=train_batch_size,
            drop_last=True,
            match_list=match_list,
            total_batches=total_batches,
        )

        train_dataloader = torch_utils.DataLoader(
            dataset=dataset,
            collate_fn=task.collate_fn,
            batch_sampler=matched_random_batch_sampler,
        )
    else:
        raise KeyError(f"Batching method not supported: {batch_method}")

    return train_dataloader
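
The 'clustered' branch requires a match list that has already been pickled to disk. A hedged sketch of preparing such a file and calling this path follows; the structure of the match list itself is an assumption, since the example does not show what MatchedRandomBatchSampler expects.

import pickle

match_list = build_match_list(...)  # hypothetical helper producing the match structure
with open("/path/to/matchlist.p", "wb") as f:
    pickle.dump(match_list, f)

train_dataloader = get_train_dataloader_from_cache(
    train_cache=train_cache,
    task=task,
    train_batch_size=32,  # used as max_batch_size by the sampler
    batch_method="clustered",
    min_batch_size=8,
    total_batches=10000,
    matchlist_pickle_path="/path/to/matchlist.p",
)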
Example #4
def get_eval_dataloader_from_cache(
    eval_cache: caching.ChunkedFilesDataCache,
    task,
    eval_batch_size: int,
    subset_num=None,
    explicit_subset=None,
):
    dataset = eval_cache.get_iterable_dataset(
        buffer_size=10000, shuffle=False, subset_num=subset_num, explicit_subset=explicit_subset,
    )
    eval_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset, batch_size=eval_batch_size, collate_fn=task.collate_fn,
    )
    return eval_dataloader
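
A short, hypothetical usage of the evaluation variant; subset_num appears to cap the number of evaluation examples while explicit_subset selects specific ones, and the values below are illustrative only.

eval_cache = caching.ChunkedFilesDataCache(cache_fol_path="/path/to/cache/mnli/val")
eval_dataloader = get_eval_dataloader_from_cache(
    eval_cache=eval_cache,
    task=task,
    eval_batch_size=32,
    subset_num=500,  # evaluate on a 500-example subset
)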
Example #5
def get_loss_weights_dict(self, start_position: int = None):
    if start_position is not None:
        raise Exception("Resuming from a start_position is not supported "
                        "when loading loss weights.")
    loss_weights_dict = {}
    for task_name in self.jiant_task_container.task_run_config.train_task_list:
        task_specific_config = self.jiant_task_container.task_specific_configs[
            task_name]
        logger.info('task="%s": loading loss weights from "%s"', task_name,
                    task_specific_config.train_loss_weights)
        train_batch_size = task_specific_config.train_batch_size
        if task_specific_config.train_loss_weights is not None:
            # Loss weights come from a single-column TSV (no header) and are
            # batched to mirror the training batch size.
            weights = pd.read_csv(task_specific_config.train_loss_weights,
                                  sep='\t',
                                  header=None)[0].values
            weights = torch.Tensor(weights).to(self.device)
            loss_weights_dict[task_name] = InfiniteYield(
                torch_utils.DataLoaderWithLength(
                    dataset=weights, batch_size=train_batch_size))
        else:
            loss_weights_dict[task_name] = None
    return loss_weights_dict
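
The loss-weights file referenced above is read as a tab-separated file with no header, taking only column 0. A minimal sketch of producing such a file with pandas (the path and values are placeholders):

import pandas as pd

weights = pd.Series([1.0, 0.5, 2.0, 1.0])  # one numeric weight per row
weights.to_csv("/path/to/train_loss_weights.tsv", sep='\t', header=False, index=False)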
Example #6
def main(args: RunConfiguration):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # === Shared model components setup === #
    model_type = "roberta-base"
    model_arch = ModelArchitectures.from_model_type(model_type=model_type)
    transformers_class_spec = model_setup.TRANSFORMERS_CLASS_SPEC_DICT[
        model_arch]
    ancestor_model = model_setup.get_ancestor_model(
        transformers_class_spec=transformers_class_spec,
        model_config_path=args.model_config_path,
    )
    encoder = model_setup.get_encoder(
        model_arch=model_arch,
        ancestor_model=ancestor_model,
    )
    tokenizer = shared_model_setup.get_tokenizer(
        model_type=model_type,
        tokenizer_path=args.model_tokenizer_path,
    )

    # === Taskmodels setup === #
    task_dict = {
        "mnli":
        tasks.create_task_from_config_path(
            os.path.join(
                args.task_config_base_path,
                "mnli.json",
            )),
        "qnli":
        tasks.create_task_from_config_path(
            os.path.join(
                args.task_config_base_path,
                "qnli.json",
            )),
        "rte":
        tasks.create_task_from_config_path(
            os.path.join(
                args.task_config_base_path,
                "qnli.json",
            ))
    }
    taskmodels_dict = {
        "nli":
        taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["mnli"].LABELS),
            ),
        ),
        "rte":
        taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["rte"].LABELS),
            ),
        ),
    }
    task_to_taskmodel_map = {
        "mnli": "nli",
        "qnli": "nli",
        "rte": "rte",
    }

    # === Final === #
    jiant_model = JiantModel(
        task_dict=task_dict,
        encoder=encoder,
        taskmodels_dict=taskmodels_dict,
        task_to_taskmodel_map=task_to_taskmodel_map,
        tokenizer=tokenizer,
    )
    jiant_model = jiant_model.to(device)

    # === Run === #
    task_dataloader_dict = {}
    for task_name, task in task_dict.items():
        train_cache = caching.ChunkedFilesDataCache(
            cache_fol_path=os.path.join(args.task_cache_base_path, task_name, "train"),
        )
        train_dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                         shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=train_dataset,
            batch_size=4,
            collate_fn=task.collate_fn,
        )
        task_dataloader_dict[task_name] = train_dataloader

    for task_name, task in task_dict.items():
        batch, batch_metadata = next(iter(task_dataloader_dict[task_name]))
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        print(task_name)
        print(model_output)
        print()
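
Since the script only reads four fields from RunConfiguration, a minimal invocation would look roughly like the sketch below; the paths are placeholders and RunConfiguration's actual constructor may differ.

if __name__ == "__main__":
    args = RunConfiguration(
        model_config_path="/path/to/roberta-base/config.json",
        model_tokenizer_path="/path/to/roberta-base/tokenizer",
        task_config_base_path="/path/to/task_configs",
        task_cache_base_path="/path/to/task_caches",
    )
    main(args)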