def test_run_range_test_with_traindataloaderiter(self, mocker):
    """Run ``range_test`` on a ``TrainDataLoaderIter`` instead of a plain loader.

    The iterator's ``inputs_labels_from_batch`` is spied on and must be
    called exactly once per requested iteration.
    """
    iterations = 5
    task = mod_task.XORTask()
    finder = prepare_lr_finder(task)
    train_iter = TrainDataLoaderIter(task.train_loader)

    # Install the spy before running so every batch extraction is counted.
    batch_spy = mocker.spy(train_iter, "inputs_labels_from_batch")
    finder.range_test(train_iter, num_iter=iterations)

    assert batch_spy.call_count == iterations
def test_traindataloaderiter(self):
    """Check that a ``TrainDataLoaderIter`` can be iterated and reused.

    The same iterator is run once with the default number of runs, then
    again for one more run than the wrapped ``DataLoader`` has batches.
    """
    data_length = 256
    batch_size = 32
    dataset = mod_dataset.RandomDataset(data_length)
    loader = DataLoader(dataset, batch_size=batch_size)
    train_iter = TrainDataLoaderIter(loader)

    assert run_loader_iter(train_iter)

    # `TrainDataLoaderIter` can reset itself, so the same instance may be
    # reused directly and iterated more than `len(loader)` times.
    past_one_epoch = len(loader) + 1
    assert run_loader_iter(train_iter, desired_runs=past_one_epoch)
def test_run_range_test_with_valloaderiter_without_subclassing(self):
    """Reject an unsupported validation iterator type.

    A ``CustomLoaderIter`` is passed as ``val_loader``. Per the test name,
    it presumably does not subclass ``ValDataLoaderIter``; confirm against
    its definition. ``range_test`` must raise ``ValueError`` saying the
    ``val_loader`` type is unsupported.
    """
    iterations = 5
    task = mod_task.XORTask(validate=True)
    finder = prepare_lr_finder(task)
    train_iter = TrainDataLoaderIter(task.train_loader)
    custom_val_iter = CustomLoaderIter(task.val_loader)

    expected_msg = "`val_loader` has unsupported type"
    with pytest.raises(ValueError, match=expected_msg):
        finder.range_test(
            train_iter,
            val_loader=custom_val_iter,
            num_iter=iterations,
        )
def test_run_range_test_with_valdataloaderiter(self, mocker):
    """Run ``range_test`` with both train and validation iterators.

    Both iterators' ``inputs_labels_from_batch`` methods are spied on. The
    train iterator must be called once per iteration. The validation
    iterator must be called ``len(task.val_loader)`` times per iteration.
    """
    iterations = 5
    task = mod_task.XORTask(validate=True)
    finder = prepare_lr_finder(task)
    train_iter = TrainDataLoaderIter(task.train_loader)
    val_iter = ValDataLoaderIter(task.val_loader)

    train_spy = mocker.spy(train_iter, "inputs_labels_from_batch")
    val_spy = mocker.spy(val_iter, "inputs_labels_from_batch")

    finder.range_test(train_iter, val_loader=val_iter, num_iter=iterations)

    assert train_spy.call_count == iterations
    # The expected count means the whole validation loader is consumed on
    # every iteration.
    expected_val_calls = iterations * len(task.val_loader)
    assert val_spy.call_count == expected_val_calls