コード例 #1
0
    def test_run_range_test_with_traindataloaderiter(self, mocker):
        """`range_test` accepts a `TrainDataLoaderIter` as its train loader.

        The iterator's `inputs_labels_from_batch` hook is spied on; after a
        range test of `iterations` steps it must have been called exactly
        once per step.
        """
        xor_task = mod_task.XORTask()
        finder = prepare_lr_finder(xor_task)
        iterations = 5

        train_iter = TrainDataLoaderIter(xor_task.train_loader)
        batch_spy = mocker.spy(train_iter, "inputs_labels_from_batch")

        finder.range_test(train_iter, num_iter=iterations)

        assert batch_spy.call_count == iterations
コード例 #2
0
    def test_traindataloaderiter(self):
        """A `TrainDataLoaderIter` can be driven beyond one pass of its loader.

        The same iterator is first run with `run_loader_iter`'s default
        settings, then reused for one more step than the loader has batches.
        """
        n_samples = 256
        per_batch = 32
        dataset = mod_dataset.RandomDataset(n_samples)
        loader = DataLoader(dataset, batch_size=per_batch)

        train_iter = TrainDataLoaderIter(loader)

        assert run_loader_iter(train_iter)

        # Reuse the very same iterator on purpose: `TrainDataLoaderIter` can
        # reset itself, so iterating it more than `len(loader)` times is
        # expected to succeed.
        runs_past_one_epoch = len(loader) + 1
        assert run_loader_iter(train_iter, desired_runs=runs_past_one_epoch)
コード例 #3
0
    def test_run_range_test_with_valloaderiter_without_subclassing(self):
        """`range_test` rejects an unsupported validation loader iterator.

        A `CustomLoaderIter` wrapping the validation loader is passed as
        `val_loader`. Judging by the test name, it does not subclass the
        supported validation iterator type (confirm against its definition).
        `range_test` must raise `ValueError` mentioning the unsupported type.
        """
        xor_task = mod_task.XORTask(validate=True)
        finder = prepare_lr_finder(xor_task)
        iterations = 5

        train_iter = TrainDataLoaderIter(xor_task.train_loader)
        custom_val_iter = CustomLoaderIter(xor_task.val_loader)

        expected_msg = "`val_loader` has unsupported type"
        with pytest.raises(ValueError, match=expected_msg):
            finder.range_test(
                train_iter, val_loader=custom_val_iter, num_iter=iterations
            )
コード例 #4
0
    def test_run_range_test_with_valdataloaderiter(self, mocker):
        """`range_test` drives both a train iterator and a `ValDataLoaderIter`.

        `inputs_labels_from_batch` is spied on for both iterators. The train
        hook must fire once per iteration. The validation hook must fire once
        per validation batch on every iteration.
        """
        xor_task = mod_task.XORTask(validate=True)
        finder = prepare_lr_finder(xor_task)
        iterations = 5

        train_iter = TrainDataLoaderIter(xor_task.train_loader)
        val_iter = ValDataLoaderIter(xor_task.val_loader)
        train_spy = mocker.spy(train_iter, "inputs_labels_from_batch")
        val_spy = mocker.spy(val_iter, "inputs_labels_from_batch")

        finder.range_test(train_iter, val_loader=val_iter, num_iter=iterations)

        # The expected validation count is a whole pass over the validation
        # loader for each of the `iterations` range-test steps.
        val_batches_per_pass = len(xor_task.val_loader)
        assert train_spy.call_count == iterations
        assert val_spy.call_count == iterations * val_batches_per_pass