Example #1
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache,
    task,
    train_batch_size: int,
    sample_weights_path=None,
    fix_seed_for_weighted_sampler=False,
):
    # TODO: Expose buffer_size parameter  (issue #1183)

    if sample_weights_path is not None:
        # A sampler needs a map-style (indexable) dataset, so stream the
        # cache without shuffling and materialize it into a list-backed one.
        dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                   shuffle=False)
        dataset = _ListDataset(list(dataset))
        # One weight per training example, one per line (TSV, no header).
        _sample_weights = pd.read_csv(sample_weights_path,
                                      sep='\t',
                                      header=None)[0]
        sampler = WeightedDatasetSampler(
            dataset, _sample_weights, fix_seed=fix_seed_for_weighted_sampler)
    else:
        # No weighting: rely on the cache's buffer-level shuffling instead
        # of a sampler.
        dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                   shuffle=True)
        sampler = None

    train_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset,
        batch_size=train_batch_size,
        collate_fn=task.collate_fn,
        sampler=sampler)
    return train_dataloader
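
Example #1 references two helpers that are not shown in the snippet: _ListDataset and WeightedDatasetSampler. As a rough guide to what they need to provide, here is a minimal sketch, assuming WeightedDatasetSampler behaves like torch's WeightedRandomSampler; the class bodies and the seed value are assumptions, not the original implementations.

import torch
from torch.utils.data import Dataset, WeightedRandomSampler


class _ListDataset(Dataset):
    # Map-style (indexable) dataset backed by an in-memory list, so a
    # sampler can address examples by index.
    def __init__(self, elements):
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, idx):
        return self.elements[idx]


class WeightedDatasetSampler(WeightedRandomSampler):
    # Draws dataset indices with probability proportional to per-example
    # weights; fix_seed makes the draw reproducible (seed value assumed).
    def __init__(self, dataset, weights, fix_seed=False):
        generator = None
        if fix_seed:
            generator = torch.Generator()
            generator.manual_seed(0)
        super().__init__(
            weights=torch.as_tensor(list(weights), dtype=torch.double),
            num_samples=len(dataset),
            replacement=True,
            generator=generator,
        )
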
Example #2
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache, task, train_batch_size: int
):
    # TODO: Expose buffer_size parameter  (issue #1183)
    dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
    train_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset, batch_size=train_batch_size, collate_fn=task.collate_fn,
    )
    return train_dataloader
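
Example #2 is the minimal variant: no sampler, with shuffling delegated to the cache's streaming buffer. A hypothetical call site might look like the following; the cache path, task object, and batch size are placeholders, assuming a jiant-style setup where task.collate_fn turns a list of cached examples into a batch.

train_cache = caching.ChunkedFilesDataCache("./cache/mnli/train")
train_dataloader = get_train_dataloader_from_cache(
    train_cache=train_cache,
    task=task,
    train_batch_size=32,
)
for batch in train_dataloader:
    ...  # forward/backward pass over one collated batch
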
Example #3
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache,
    task,
    train_batch_size: int,
    batch_method: str,
    min_batch_size: int,
    total_batches: int,
    matchlist_pickle_path: str,
):
    # TODO: Expose buffer_size parameter  (Issue #50)
    if batch_method == 'default':
        dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                   shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=dataset,
            batch_size=train_batch_size,
            collate_fn=task.collate_fn,
        )
    elif batch_method == 'clustered':
        # A batch_sampler needs a map-style (indexable) dataset rather than
        # an iterable one.
        dataset = train_cache.get_uniterable_dataset(buffer_size=10000)
        assert (
            total_batches > 0
        ), f"Must define total number of batches to generate. Given: {total_batches}."
        assert (
            train_batch_size > 0
        ), f"Max batch size must be greater than zero. Given: {train_batch_size}."

        # Currently only supports a pickled match list. Matching could be
        # computed here instead, but may take long depending on the size of
        # the data.

        assert os.path.exists(
            matchlist_pickle_path
        ), f"Pickled match list not found; create it first. Given: {matchlist_pickle_path}"
        with open(matchlist_pickle_path, 'rb') as f:
            match_list = pickle.load(f)

        matched_random_batch_sampler = torch_utils.MatchedRandomBatchSampler(
            min_batch_size=min_batch_size,
            max_batch_size=train_batch_size,
            drop_last=True,
            match_list=match_list,
            total_batches=total_batches,
        )

        train_dataloader = torch_utils.DataLoader(
            dataset=dataset,
            collate_fn=task.collate_fn,
            batch_sampler=matched_random_batch_sampler,
        )
    else:
        raise KeyError(f"Batching method not supported: {batch_method}")

    return train_dataloader
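
Example #3 adds a 'clustered' batching mode driven by a pickled match list. The exact structure of the match list is defined by torch_utils.MatchedRandomBatchSampler, which is not shown here; the sketch below assumes it is a grouping of example indices that are eligible to share a batch, and shows how the pickle file might be produced and consumed.

import pickle

match_list = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]  # toy grouping (structure assumed)
with open("./matchlist.p", "wb") as f:
    pickle.dump(match_list, f)

train_dataloader = get_train_dataloader_from_cache(
    train_cache=train_cache,
    task=task,
    train_batch_size=16,
    batch_method="clustered",
    min_batch_size=4,
    total_batches=1000,
    matchlist_pickle_path="./matchlist.p",
)
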
Example #4
def get_eval_dataloader_from_cache(
    eval_cache: caching.ChunkedFilesDataCache,
    task,
    eval_batch_size: int,
    subset_num=None,
    explicit_subset=None,
):
    dataset = eval_cache.get_iterable_dataset(
        buffer_size=10000, shuffle=False, subset_num=subset_num, explicit_subset=explicit_subset,
    )
    eval_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset, batch_size=eval_batch_size, collate_fn=task.collate_fn,
    )
    return eval_dataloader
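
Example #4 builds the evaluation-side dataloader: no shuffling, and two optional ways to restrict the data. Both parameters are simply forwarded to get_iterable_dataset; subset_num presumably caps the number of examples, while explicit_subset selects specific ones. A hypothetical call, with the cache path and sizes as placeholders:

eval_cache = caching.ChunkedFilesDataCache("./cache/mnli/val")
eval_dataloader = get_eval_dataloader_from_cache(
    eval_cache=eval_cache,
    task=task,
    eval_batch_size=64,
    subset_num=500,  # evaluate on a 500-example subset (behavior assumed)
)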