Code example #1
def test_shuffle_batch_sampler():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sort_key = lambda r: len(r)
    batch_size = 2
    batches = list(
        ShuffleBatchSampler(SortedSampler(data_source, sort_key=sort_key), batch_size, False))
    assert len(batches) == 3
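
This test only checks the batch count. For context, here is a minimal sketch of the behavior it assumes: the wrapped sampler's indices are grouped into fixed-size batches, and the batches themselves are yielded in random order. This is an illustrative reimplementation, not necessarily the library's exact code.

import random

from torch.utils.data.sampler import BatchSampler


class ShuffleBatchSampler(BatchSampler):
    """Batch the wrapped sampler's indices, then yield the batches in random order.

    Each batch preserves the inner sampler's order (e.g. sorted by length),
    but the sequence of batches is shuffled.
    """

    def __init__(self, sampler, batch_size, drop_last, shuffle=True):
        super().__init__(sampler, batch_size, drop_last)
        self.shuffle = shuffle

    def __iter__(self):
        batches = list(super().__iter__())
        if self.shuffle:
            random.shuffle(batches)
        return iter(batches)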
Code example #2
def test_sorted_sampler():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sort_key = lambda r: r[0]
    indexes = list(SortedSampler(data_source, sort_key=sort_key))
    assert len(indexes) == len(data_source)
    for i, j in enumerate(indexes):
        assert i == j
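
SortedSampler itself can be sketched as a sampler that yields the indices of data_source ordered by sort_key. A minimal, picklable version (illustrative; the real implementation may differ):

from torch.utils.data.sampler import Sampler


def _identity(x):
    return x


class SortedSampler(Sampler):
    """Yield indices of data_source sorted by sort_key(row)."""

    def __init__(self, data_source, sort_key=_identity):
        self.data_source = data_source
        # Precompute the sorted index order; only the plain list is stored,
        # so the sampler stays picklable even if sort_key is a lambda.
        self.sorted_indexes = sorted(
            range(len(data_source)), key=lambda i: sort_key(data_source[i])
        )

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data_source)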
Code example #3
def test_shuffle_batch_sampler_drop_last():
    data_source = [[1], [2], [3], [4], [5]]
    sort_key = lambda r: len(r)
    batch_size = 2
    batches = list(
        ShuffleBatchSampler(SortedSampler(data_source, sort_key), batch_size, drop_last=True))
    assert len(batches) == 2
Code example #4
import pickle


def test_pickleable():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sampler = SortedSampler(data_source)
    pickle.dumps(sampler)
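
A stricter variant would also restore the sampler and compare the iteration order, e.g.:

    assert list(pickle.loads(pickle.dumps(sampler))) == list(sampler)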
Code example #5
    def setup(self, stage):
        # https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html#setup

        self._dataset = []
        max_seq_len = 0
        skipped_examples = 0
        with open(self.json_path) as json_file:

            for line in itr.islice(json_file, self.n_samples):
                json_obj = json.loads(line)
                del json_obj["errors"]

                if len(json_obj["incorrect"]) > self._max_seq_len:
                    # skip sequences that are too long
                    skipped_examples += 1
                    continue

                self._dataset.append(json_obj)
                # compute some dataset stats
                max_seq_len = max(
                    max_seq_len, len(json_obj["correct"]), len(json_obj["incorrect"])
                )

        ds_len = len(self._dataset)
        self.dims = (ds_len, max_seq_len)

        stats = {
            "dataset_len": ds_len,
            "max_seq_len": max_seq_len,
            "skiped_examples": skipped_examples,
        }

        self.tokenizer.train(
            self._dataset, append_eos=True, append_sos=True, min_occurrences=1000
        )

        # FIXME: hardcoded output path; a candidate for a pull request
        self.tokenizer.save_vocab("model_corrector/")

        self.vocab_size = self.tokenizer.vocab_size
        self.padding_index = self.tokenizer.padding_index  # =0

        dataset_len = len(self._dataset)

        assert_msg = (
            f"length of all gathered dataset examples is {dataset_len}, which is not greater than the validation split. "
            f"Try increasing self._max_seq_len={self._max_seq_len} or decreasing self.valid_split_size={self.valid_split_size}"
        )
        assert dataset_len > self.valid_split_size, assert_msg

        last_idx = dataset_len - self.valid_split_size

        # list of dicts
        self.valid_ds = self._dataset[last_idx:]

        self.train_ds = self._dataset[0:last_idx]
        # random.shuffle(self.train_ds)

        self.train_sampler = SortedSampler(
            self.train_ds, sort_key=self._sampler_sort_func
        )

        self.val_sampler = SortedSampler(
            self.valid_ds, sort_key=self._sampler_sort_func
        )

        return stats
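
setup only builds the samplers; a sketch of how a matching train_dataloader could consume them is below. Here self.batch_size and self._collate_fn are assumptions (not shown in the original), and DataLoader comes from torch.utils.data.

    def train_dataloader(self):
        # assumes: from torch.utils.data import DataLoader
        # Wrap the length-sorted indices into fixed-size batches whose order
        # is shuffled: each batch stays length-homogeneous (little padding),
        # while the batch order varies between epochs.
        batch_sampler = ShuffleBatchSampler(
            self.train_sampler, self.batch_size, drop_last=True
        )
        return DataLoader(
            self.train_ds,
            batch_sampler=batch_sampler,
            collate_fn=self._collate_fn,  # hypothetical: tokenize + pad a list of dicts
        )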
Code example #6
    def setup(self, stage):
        # https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html#setup

        N_valid_size = self.N_valid_size

        # dataset = self._setup_task1(self.N_random_samples)
        dataset = self._setup_task2(self.N_random_samples)

        # list of dicts
        self.train_ds = dataset[0:-N_valid_size]

        self.valid_ds = dataset[-N_valid_size:]

        # load dataset build vocab and numericalize

        # TODO: bad design, change it; kept only for prototyping and learning
        dataset_example_gen = (ex["correct"] + " " + ex["incorrect"] for ex in dataset)

        
        self.tokenizer = CharacterEncoder(
            dataset_example_gen, append_eos=True, append_sos=True
        )
        with open("./abc_data_character_encoder.p", "wb") as encoder_file:
            pickle.dump(self.tokenizer, encoder_file)

        self.train_sampler = SortedSampler(
            self.train_ds, sort_key=self._sampler_sort_func
        )

        self.val_sampler = SortedSampler(
            self.valid_ds, sort_key=self._sampler_sort_func
        )

        # samplers from torchnlp; did not work with DistributedDataParallel
        # self.train_sampler = BucketBatchSampler(
        #     sampler=SequentialSampler(self.train_ds),
        #     # bucket_size_multiplier=1000,
        #     batch_size=self.batch_size,
        #     drop_last=True,
        #     sort_key=self._bucket_train_sort_func,
        #     #sort_key=lambda i: -len(self.train_ds[i]["incorrect"]),
        # )

        # self.val_sampler = BucketBatchSampler(
        #     sampler=SequentialSampler(self.valid_ds),
        #     batch_size=self.batch_size,
        #     drop_last=True,
        #     sort_key = self._bucket_val_sort_func,
        #     #sort_key=lambda i: -len(self.valid_ds[i]["incorrect"]),
        # )

        # samplers from catalyst: DynamicLenBatchSampler, DistributedSamplerWrapper
        # https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

        # train_sampler = RandomSampler(self.train_ds)
        # train_sampler = DynamicLenBatchSampler(train_sampler, self.batch_size, drop_last=True)

        # self.train_sampler = train_sampler
        # self.train_sampler = DistributedSamplerWrapper(train_sampler)

        # valid_sampler = RandomSampler(self.valid_ds)
        # valid_sampler = DynamicLenBatchSampler(valid_sampler, self.batch_size, drop_last=True)
        # self.val_sampler = valid_sampler
        # self.valid_sampler = DistributedSamplerWrapper(valid_sampler)

        ### TODO: to be replaced
        self.vocab_size = self.tokenizer.vocab_size
        self.padding_index = self.tokenizer.padding_index  # =0
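
Both setup variants reference self._sampler_sort_func without defining it. A plausible definition, given that each example is a dict with "correct"/"incorrect" strings (hypothetical, shown for completeness):

    def _sampler_sort_func(self, row):
        # Sort by source-sequence length so that a batch groups similarly
        # long examples and per-batch padding stays minimal.
        return len(row["incorrect"])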
Code example #7
import pickle


def test_pickleable():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sampler = ShuffleBatchSampler(SortedSampler(data_source), batch_size=2, drop_last=False)
    pickle.dumps(sampler)