import pickle

# NOTE: these import paths are an assumption; older pytorch-nlp releases ship
# both samplers under torchnlp.samplers, but the project may define its own.
from torchnlp.samplers import ShuffleBatchSampler, SortedSampler


def test_shuffle_batch_sampler():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sort_key = lambda r: len(r)
    batch_size = 2
    batches = list(
        ShuffleBatchSampler(
            SortedSampler(data_source, sort_key=sort_key), batch_size, drop_last=False))
    assert len(batches) == 3
def test_sorted_sampler():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sort_key = lambda r: r[0]
    indexes = list(SortedSampler(data_source, sort_key=sort_key))
    assert len(indexes) == len(data_source)
    # data_source is already sorted, so the sampler should yield 0..5 in order
    for i, j in enumerate(indexes):
        assert i == j
def test_shuffle_batch_sampler_drop_last():
    data_source = [[1], [2], [3], [4], [5]]
    sort_key = lambda r: len(r)
    batch_size = 2
    batches = list(
        ShuffleBatchSampler(
            SortedSampler(data_source, sort_key=sort_key), batch_size, drop_last=True))
    # 5 examples with batch_size=2 and drop_last=True: the leftover example is dropped
    assert len(batches) == 2
def test_pickleable():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sampler = SortedSampler(data_source)
    pickle.dumps(sampler)
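# A hedged usage sketch (not part of the original tests): how these two samplers
# compose with a torch DataLoader. ShuffleBatchSampler yields lists of indices,
# so it goes into ``batch_sampler=`` rather than ``sampler=``.
def _example_dataloader(data_source, batch_size=2):
    from torch.utils.data import DataLoader

    batch_sampler = ShuffleBatchSampler(
        SortedSampler(data_source, sort_key=len), batch_size, drop_last=False)
    return DataLoader(data_source, batch_sampler=batch_sampler)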
def setup(self, stage):
    # https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html#setup
    self._dataset = []
    max_seq_len = 0
    skipped_examples = 0
    with open(self.json_path) as json_file:
        for line in itr.islice(json_file, self.n_samples):
            json_obj = json.loads(line)
            del json_obj["errors"]
            if len(json_obj["incorrect"]) > self._max_seq_len:
                # skip sequences that are too long
                skipped_examples += 1
                continue
            self._dataset.append(json_obj)
            # compute some dataset stats
            max_seq_len = max(
                max_seq_len, len(json_obj["correct"]), len(json_obj["incorrect"])
            )
    ds_len = len(self._dataset)
    self.dims = (ds_len, max_seq_len)
    stats = {
        "dataset_len": ds_len,
        "max_seq_len": max_seq_len,
        "skipped_examples": skipped_examples,
    }
    self.tokenizer.train(
        self._dataset, append_eos=True, append_sos=True, min_occurrences=1000
    )
    # TODO: hardcoded path; possible pull request to make this configurable
    self.tokenizer.save_vocab("model_corrector/")
    self.vocab_size = self.tokenizer.vocab_size
    self.padding_index = self.tokenizer.padding_index  # = 0
    dataset_len = len(self._dataset)
    assert_msg = (
        f"The number of gathered dataset examples is {dataset_len}, which is less than the validation split. "
        f"Try increasing self._max_seq_len={self._max_seq_len} or decreasing self.valid_split_size={self.valid_split_size}"
    )
    assert dataset_len > self.valid_split_size, assert_msg
    last_idx = dataset_len - self.valid_split_size
    # lists of dicts
    self.valid_ds = self._dataset[last_idx:]
    self.train_ds = self._dataset[0:last_idx]
    # random.shuffle(self.train_ds)
    self.train_sampler = SortedSampler(
        self.train_ds, sort_key=self._sampler_sort_func
    )
    self.val_sampler = SortedSampler(
        self.valid_ds, sort_key=self._sampler_sort_func
    )
    return stats
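# Hedged sketch (an assumption, not this module's confirmed code) of the
# LightningDataModule hook that would consume the samplers built in setup();
# ``self.batch_size`` and any padding collate_fn are assumed to exist elsewhere.
def train_dataloader(self):
    from torch.utils.data import DataLoader

    return DataLoader(
        self.train_ds,
        batch_size=self.batch_size,
        sampler=self.train_sampler,
    )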
def setup(self, stage):
    # https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html#setup
    N_valid_size = self.N_valid_size
    # dataset = self._setup_task1(self.N_random_samples)
    dataset = self._setup_task2(self.N_random_samples)
    # lists of dicts
    self.train_ds = dataset[0:-N_valid_size]
    self.valid_ds = dataset[-N_valid_size:]
    # Load the dataset, build the vocab, and numericalize.
    # TODO: bad design, only for prototyping and learning
    dataset_example_gen = (ex["correct"] + " " + ex["incorrect"] for ex in dataset)
    self.tokenizer = CharacterEncoder(
        dataset_example_gen, append_eos=True, append_sos=True
    )
    with open("./abc_data_character_encoder.p", "wb") as encoder_file:
        pickle.dump(self.tokenizer, encoder_file)
    self.train_sampler = SortedSampler(
        self.train_ds, sort_key=self._sampler_sort_func
    )
    self.val_sampler = SortedSampler(
        self.valid_ds, sort_key=self._sampler_sort_func
    )

    # Samplers from torchnlp; did not work with DistributedDataParallel:
    # self.train_sampler = BucketBatchSampler(
    #     sampler=SequentialSampler(self.train_ds),
    #     # bucket_size_multiplier=1000,
    #     batch_size=self.batch_size,
    #     drop_last=True,
    #     sort_key=self._bucket_train_sort_func,
    #     # sort_key=lambda i: -len(self.train_ds[i]["incorrect"]),
    # )
    # self.val_sampler = BucketBatchSampler(
    #     sampler=SequentialSampler(self.valid_ds),
    #     batch_size=self.batch_size,
    #     drop_last=True,
    #     sort_key=self._bucket_val_sort_func,
    #     # sort_key=lambda i: -len(self.valid_ds[i]["incorrect"]),
    # )

    # Samplers from catalyst: DynamicLenBatchSampler, DistributedSamplerWrapper
    # https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py
    # train_sampler = RandomSampler(self.train_ds)
    # train_sampler = DynamicLenBatchSampler(train_sampler, self.batch_size, drop_last=True)
    # self.train_sampler = train_sampler
    # self.train_sampler = DistributedSamplerWrapper(train_sampler)
    # valid_sampler = RandomSampler(self.valid_ds)
    # valid_sampler = DynamicLenBatchSampler(valid_sampler, self.batch_size, drop_last=True)
    # self.val_sampler = valid_sampler
    # self.valid_sampler = DistributedSamplerWrapper(valid_sampler)

    # TODO: to be replaced
    self.vocab_size = self.tokenizer.vocab_size
    self.padding_index = self.tokenizer.padding_index  # = 0
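# Hedged sketch of the sort key shared by both SortedSampler instances above;
# the real ``_sampler_sort_func`` is defined elsewhere in the module. The
# negative length mirrors the commented-out BucketBatchSampler keys, ordering
# the longest sequences first.
def _sampler_sort_func(self, example):
    return -len(example["incorrect"])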
def test_shuffle_batch_sampler_pickleable():
    data_source = [[1], [2], [3], [4], [5], [6]]
    sampler = ShuffleBatchSampler(SortedSampler(data_source), batch_size=2, drop_last=False)
    pickle.dumps(sampler)
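# Why picklability matters: DataLoader workers started with the "spawn" method
# (the default on Windows and macOS) receive the dataset and sampler via pickle,
# so an unpicklable sampler breaks ``num_workers > 0`` on those platforms.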