Example #1
# DataLoader comes from torch; RandNDataset and infinite_data_loader are
# provided by the library under test (import paths are not shown in the snippet)
from torch.utils.data import DataLoader


def test_infinite_data_loader():
    # 100 random samples: a plain DataLoader with the default batch_size=1
    # is exhausted after exactly 100 batches
    dataset = RandNDataset(100, (3, 32, 32), True)
    data_loader = DataLoader(dataset)

    check_loader = infinite_data_loader(data_loader)
    check_count = 0

    for _ in check_loader:
        check_count += 1

        if check_count >= 150:
            break

    # reaching 150 iterations proves the loader wrapped past the dataset's end
    assert check_count == 150
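
For orientation, a minimal sketch of what a wrap-around generator like infinite_data_loader might look like; the library version also accepts early_stop_steps and cache arguments (seen in the later examples), which are omitted here:

from typing import Any, Iterator

from torch.utils.data import DataLoader


def infinite_loader_sketch(data_loader: DataLoader) -> Iterator[Any]:
    # re-iterate the underlying DataLoader each time it is exhausted,
    # yielding batches forever; the consumer decides when to stop
    while True:
        for batch in data_loader:
            yield batch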
Example #2
def pruning_loss_sens_one_shot(
    module: Module,
    data: DataLoader,
    loss: Union[LossWrapper, Callable[[Any, Any], Tensor]],
    device: str,
    steps_per_measurement: int,
    sparsity_levels: List[float] = default_pruning_sparsities_loss(False),
    loss_key: str = DEFAULT_LOSS_KEY,
    tester_run_funcs: ModuleRunFuncs = None,
    tester_loggers: List[BaseLogger] = None,
    show_progress: bool = True,
) -> PruningLossSensitivityAnalysis:
    """
    Run a one-shot sensitivity analysis for kernel sparsity.
    It does not retrain; instead it puts the model into eval mode.
    It moves layer by layer, calculating the sensitivity for each and
    resetting the previously run layers.
    Note that by default it caches the data, so data loading is not parallel
    and the first run can take longer, but subsequent sparsity checks for
    layers and levels will be much faster.

    :param module: the module to run the kernel sparsity sensitivity analysis
        over; all prunable layers will be extracted from it
    :param data: the data to run through the module for calculating the sensitivity
        analysis
    :param loss: the loss function to use for the sensitivity analysis
    :param device: the device to run the analysis on; ex: cpu, cuda
    :param steps_per_measurement: the number of samples or items to take for each
        measurement at each sparsity level
    :param sparsity_levels: the sparsity levels to check for each layer to calculate
        sensitivity
    :param loss_key: the key for the loss function to track in the returned dict
    :param tester_run_funcs: override functions to use for the ModuleTester
        that runs the analysis
    :param tester_loggers: loggers to log data to while running the analysis
    :param show_progress: track progress of the runs if True
    :return: the sensitivity results for every layer that is prunable
    """
    analysis = PruningLossSensitivityAnalysis()
    tester = ModuleTester(
        module,
        device,
        loss,
        loggers=tester_loggers,
        log_summary=False,
        log_steps=max(1, round(steps_per_measurement / 10)),
    )
    layers = get_prunable_layers(module)
    batch_end = _sensitivity_callback(
        layers, sparsity_levels, steps_per_measurement, analysis, loss_key
    )
    batch_end_hook = tester.run_hooks.register_batch_end_hook(batch_end)
    if tester_run_funcs is not None:
        tester.run_funcs.copy(tester_run_funcs)

    data_loader = infinite_data_loader(
        data, early_stop_steps=steps_per_measurement, cache=True
    )
    tester.run(
        data_loader,
        desc="KS Analysis",
        show_progress=show_progress,
        track_results=False,
        max_steps=steps_per_measurement * len(sparsity_levels) * len(layers),
    )
    batch_end_hook.remove()

    return analysis
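
A hypothetical invocation, assuming model and train_loader already exist; the (predictions, labels) argument order for the loss lambda is also an assumption, since the signature only requires a two-argument callable returning a Tensor:

import torch
import torch.nn.functional as F

# hypothetical usage sketch: `model` and `train_loader` are assumed to exist
analysis = pruning_loss_sens_one_shot(
    module=model,
    data=train_loader,
    loss=lambda preds, labels: F.cross_entropy(preds, labels),
    device="cuda" if torch.cuda.is_available() else "cpu",
    steps_per_measurement=15,
)

Leaving sparsity_levels at its default sweeps the library's standard set of sparsities for every prunable layer.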
Example #3
def lr_loss_sensitivity(
    module: Module,
    data: DataLoader,
    loss: Union[LossWrapper, Callable[[Any, Any], Tensor]],
    optim: Optimizer,
    device: str,
    steps_per_measurement: int,
    check_lrs: Union[List[float], Tuple[float, ...]] = default_exponential_check_lrs(),
    loss_key: str = DEFAULT_LOSS_KEY,
    trainer_run_funcs: ModuleRunFuncs = None,
    trainer_loggers: List[PyTorchLogger] = None,
    show_progress: bool = True,
) -> LRLossSensitivityAnalysis:
    """
    Implementation for handling running sensitivity analysis for
    learning rates on modules.

    :param module: the module to run the learning rate sensitivity analysis
        over; it is expected to already be on the correct device
    :param data: the data to run through the module for calculating
        the sensitivity analysis
    :param loss: the loss function to use for the sensitivity analysis
    :param optim: the optimizer to run the sensitivity analysis with
    :param device: the device to run the analysis on; ex: cpu, cuda.
        The module must already be on that device; this is used to place the
        data on that same device.
    :param steps_per_measurement: the number of batches to run through for
        the analysis at each LR
    :param check_lrs: the learning rates to check for analysis
        (will sort them small to large before running)
    :param loss_key: the key for the loss function to track in the returned dict
    :param trainer_run_funcs: override functions to use for the ModuleTrainer
        that runs the analysis
    :param trainer_loggers: loggers to log data to while running the analysis
    :param show_progress: track progress of the runs if True
    :return: the sensitivity analysis results, containing for each checked
        learning rate the loss measurements collected while running the module
        at that rate
    """
    analysis = LRLossSensitivityAnalysis()
    trainer = ModuleTrainer(
        module,
        device,
        loss,
        optim,
        loggers=trainer_loggers,
        log_summary=False,
        log_steps=max(1, round(steps_per_measurement / 10)),
    )
    batch_end, completed = _sensitivity_callback(
        check_lrs, steps_per_measurement, optim, analysis, loss_key
    )
    batch_end_hook = trainer.run_hooks.register_batch_end_hook(batch_end)
    if trainer_run_funcs is not None:
        trainer.run_funcs.copy(trainer_run_funcs)

    data_loader = infinite_data_loader(data)
    trainer.run(
        data_loader,
        desc="LR Analysis",
        show_progress=show_progress,
        track_results=False,
        max_steps=steps_per_measurement * len(check_lrs),
    )
    completed()
    batch_end_hook.remove()

    return analysis
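
As above, a hypothetical call with an assumed model and train_loader; the initial optimizer learning rate is a placeholder, since the analysis presumably steps through each value in check_lrs in turn:

import torch
import torch.nn.functional as F
from torch.optim import SGD

# hypothetical usage sketch: `model` and `train_loader` are assumed to exist
optim = SGD(model.parameters(), lr=0.001)
analysis = lr_loss_sensitivity(
    module=model,
    data=train_loader,
    loss=lambda preds, labels: F.cross_entropy(preds, labels),
    optim=optim,
    device="cuda" if torch.cuda.is_available() else "cpu",
    steps_per_measurement=20,
)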