Example #1
    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size

        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indices from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # reached the end of this iterator - restart it and
                        # keep drawing samples until "epoch_samples" is reached
                        sampler_iterators[i] = samplers_list[i].__iter__()
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
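The `push_index_val` offsets come straight from `ConcatDataset.cumulative_sizes`: each sub-sampler yields indices local to its own dataset, and adding the offset maps them into the index space of the concatenated dataset. A quick standalone illustration (the dataset sizes here are made up):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

# two sub-datasets of sizes 100 and 60 -> cumulative_sizes == [100, 160]
concat = ConcatDataset([TensorDataset(torch.randn(100, 4)),
                        TensorDataset(torch.randn(60, 4))])
push_index_val = [0] + concat.cumulative_sizes[:-1]
print(push_index_val)  # [0, 100]
# local index 5 of the second sub-dataset maps to global index 5 + 100 = 105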
Example #2
    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        datasets_length = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            if dataset_idx == 0:
                # the first dataset keeps a plain RandomSampler
                sampler = RandomSampler(cur_dataset)
            else:
                # the second, unbalanced dataset gets a dedicated sampler
                sampler = ExampleImbalancedDatasetSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)
            datasets_length.append(len(cur_dataset))

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        largest_dataset_index = torch.argmax(
            torch.as_tensor(datasets_length)).item()
        # here we want to see every sample of the largest dataset, which
        # forces us to resample from the smaller datasets
        epoch_samples = datasets_length[
            largest_dataset_index] * self.number_of_datasets

        final_samples_list = []  # this is a list of indices from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        if i == largest_dataset_index:
                            # the largest dataset's iterator is exhausted, so we
                            # can break; shrink samples_to_grab so the remaining
                            # datasets contribute equally small final batches
                            samples_to_grab = len(cur_samples)
                            break  # extend the final list and continue with the next dataset
                        else:
                            # restart the iterator - we want more samples until the largest dataset is finished
                            sampler_iterators[i] = samplers_list[i].__iter__()
                            cur_batch_sampler = sampler_iterators[i]
                            cur_sample_org = cur_batch_sampler.__next__()
                            cur_sample = cur_sample_org + push_index_val[i]
                            cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
Example #3
    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            if dataset_idx == 0:
                # the first dataset keeps a plain RandomSampler
                sampler = RandomSampler(cur_dataset)
            else:
                # the second, unbalanced dataset gets a dedicated sampler
                sampler = ExampleImbalancedDatasetSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        # here we want to see every sample of the largest dataset, which
        # forces us to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indices from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # reached the end of this iterator - restart it and keep
                        # drawing samples until "epoch_samples" is reached
                        sampler_iterators[i] = samplers_list[i].__iter__()
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
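For context, these `__iter__` methods live in a batch-scheduling `Sampler` over a `ConcatDataset`. Below is a minimal, self-contained sketch of how the enclosing class might look and be used; the class name `BatchSchedulerSampler` and the `__init__` fields are assumptions reconstructed from the attributes the methods read, and the `__iter__` is a condensed version of Example #3's logic:

import math
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from torch.utils.data.sampler import RandomSampler, Sampler

class BatchSchedulerSampler(Sampler):
    # hypothetical enclosing class, reconstructed from the attributes
    # the __iter__ methods above read (dataset, batch_size, ...)
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max(len(d) for d in dataset.datasets)

    def __len__(self):
        # each pass of the outer loop yields batch_size indices per dataset
        passes = math.ceil(self.largest_dataset_size / self.batch_size)
        return passes * self.batch_size * self.number_of_datasets

    def __iter__(self):
        # condensed form of Example #3: round-robin over the sub-datasets,
        # restarting each exhausted sampler until the largest is seen once
        iterators = [iter(RandomSampler(d)) for d in self.dataset.datasets]
        offsets = [0] + self.dataset.cumulative_sizes[:-1]
        out = []
        for _ in range(math.ceil(self.largest_dataset_size / self.batch_size)):
            for i in range(self.number_of_datasets):
                for _ in range(self.batch_size):
                    try:
                        out.append(next(iterators[i]) + offsets[i])
                    except StopIteration:
                        iterators[i] = iter(RandomSampler(self.dataset.datasets[i]))
                        out.append(next(iterators[i]) + offsets[i])
        return iter(out)

concat = ConcatDataset([TensorDataset(torch.randn(10, 3)),
                        TensorDataset(torch.randn(4, 3))])
sampler = BatchSchedulerSampler(concat, batch_size=2)
print(list(sampler))  # blocks of 2: indices 0-9 (first dataset), then 10-13, ...
loader = DataLoader(concat, sampler=sampler, batch_size=2)
# because the index list is laid out in contiguous chunks of batch_size per
# dataset, each DataLoader batch now comes from a single sub-dataset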
Example #4
    def get_data_order(self, train_examples, data_order_seed,
                       num_train_epochs):
        train_features = convert_examples_to_features(
            train_examples,
            self.label_map,
            self.rparams.max_seq_length,
            self.tokenizer,
            verbose=False,
        )
        train_data, train_tokens = convert_to_dataset(
            train_features,
            label_mode=get_label_mode(self.label_map),
        )

        iterators = []
        train_sampler = RandomSampler(train_data)
        print("Setting data order seed to {}".format(data_order_seed))
        torch.manual_seed(data_order_seed)
        for epoch in range(num_train_epochs):
            iterators.append(train_sampler.__iter__())

        return iterators
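The seeding matters here because `RandomSampler` derives its permutation from torch's global RNG when no explicit generator is passed, so calling `torch.manual_seed` once up front makes the whole sequence of per-epoch orders reproducible. A quick standalone check of that behavior (data and seed values are made up):

import torch
from torch.utils.data.sampler import RandomSampler

data = list(range(8))  # RandomSampler only needs len() of its data source
torch.manual_seed(1234)
orders_a = [list(iter(RandomSampler(data))) for _ in range(3)]
torch.manual_seed(1234)
orders_b = [list(iter(RandomSampler(data))) for _ in range(3)]
assert orders_a == orders_b  # same seed -> identical order in every epoch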
Example #5
    def __init__(self, dataset: ConcatNamedDataset, batch_size: int):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = dataset.get_datasets_nb()

        self.samplers_list: List[Sampler] = []
        self.sampler_iterators: List[Iterator] = []
        self.datasets_length: List[int] = []  # datasets arrive in descending size order, so index 0 is the largest
        self.datasets_by_desc_size = self.dataset.get_datasets_info_by_desc_count()

        for idx, (lang_id, lang, dataset_count) in enumerate(self.datasets_by_desc_size):
            cur_dataset = self.dataset.get_dataset_by_lang_id(lang_id)
            logger.debug(
                f"Sampling batches from Dataset[lang_id:{lang_id}, count:{dataset_count}, len:{len(cur_dataset)}, lang:{lang}]"
            )
            sampler = RandomSampler(cur_dataset)

            self.samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            self.sampler_iterators.append(cur_sampler_iterator)
            self.datasets_length.append(dataset_count)

        self.generate_sample_list()
Example #6
    def __iter__(self):
        for index in RandomSampler.__iter__(self):
            yield self.data_source[index]
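This variant subclasses `RandomSampler` and yields the stored items themselves instead of their indices, reusing the parent's shuffled index stream (`data_source` is the attribute `RandomSampler` stores its argument under). A minimal sketch of the enclosing class; the class name is assumed:

from torch.utils.data.sampler import RandomSampler

class RandomDataSampler(RandomSampler):  # hypothetical name
    def __iter__(self):
        # walk the parent's shuffled indices, but yield the items themselves
        for index in RandomSampler.__iter__(self):
            yield self.data_source[index]

sampler = RandomDataSampler(['a', 'b', 'c', 'd'])
print(list(sampler))  # e.g. ['c', 'a', 'd', 'b']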