示例#1
0
    def __init__(
        self,
        params: Union[str, List[str]],
        init_sparsity: float,
        final_sparsity: float,
        start_epoch: float,
        end_epoch: float,
        update_frequency: float,
        inter_func: str = "cubic",
        log_types: Union[str, List[str]] = ALL_TOKEN,
        mask_type: Union[str, List[int]] = "unstructured",
        leave_enabled: bool = True,
        **kwargs,
    ):
        """
        Store the pruning schedule configuration and validate it.

        :param params: parameter name(s) or regex pattern(s); validated into a
            list of str by validate_str_iterable
        :param init_sparsity: sparsity value stored for the start of the schedule
        :param final_sparsity: sparsity value stored for the end of the schedule
        :param start_epoch: forwarded to the parent initializer
        :param end_epoch: forwarded to the parent initializer
        :param update_frequency: forwarded to the parent initializer
        :param inter_func: name of the interpolation function; stored as given
        :param log_types: forwarded to the parent initializer
        :param mask_type: mask type specification; stored as given
        :param leave_enabled: normalized with convert_to_bool before storing
        :param kwargs: extra keyword arguments for the parent initializer;
            min_frequency defaults to -1.0 when the caller does not pass it
        """
        # Only fill in min_frequency when the caller did not supply one.
        kwargs.setdefault("min_frequency", -1.0)
        super().__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            **kwargs,
        )
        self._params = validate_str_iterable(params, "{} for params".format(
            self.__class__.__name__))  # type: List[str]
        self._init_sparsity = init_sparsity
        self._final_sparsity = final_sparsity
        # Assigned exactly once (the original set this attribute twice).
        self._leave_enabled = convert_to_bool(leave_enabled)
        self._inter_func = inter_func
        self._mask_type = mask_type

        self.validate()
示例#2
0
 def __init__(self, log_types: Union[str, List[str]] = ALL_TOKEN, **kwargs):
     """
     Initialize the logging-related state of the modifier.

     :param log_types: logger type name(s); a truthy value is validated into a
         list of str by validate_str_iterable, while any falsy value (e.g. None
         or an empty list) is stored as None
     :param kwargs: forwarded unchanged to the parent initializer
     """
     super().__init__(**kwargs)
     if log_types:
         owner_name = self.__class__.__name__
         self._log_types = validate_str_iterable(
             log_types, "log_types for {}".format(owner_name)
         )
     else:
         self._log_types = None
     self._initialized = False
     self._enabled = True
示例#3
0
 def __init__(
     self,
     params: Union[str, List[str]],
     start_epoch: float = -1.0,
     min_start: float = -1.0,
     end_epoch: float = -1.0,
     min_end: float = -1.0,
     end_comparator: Union[int, None] = 0,
     update_frequency: float = -1.0,
     min_frequency: float = -1.0,
     log_types: Union[str, List[str]] = None,
 ):
     """
     Forward the scheduling options to the parent initializer and record the
     target params.

     :param params: full parameter name(s) or regex pattern(s) of names;
         validated into a list of str by validate_str_iterable
     :param start_epoch: forwarded to the parent initializer
     :param min_start: forwarded to the parent initializer
     :param end_epoch: forwarded to the parent initializer
     :param min_end: forwarded to the parent initializer
     :param end_comparator: forwarded to the parent initializer
     :param update_frequency: forwarded to the parent initializer
     :param min_frequency: forwarded to the parent initializer
     :param log_types: forwarded to the parent initializer
     """
     schedule = dict(
         start_epoch=start_epoch,
         min_start=min_start,
         end_epoch=end_epoch,
         min_end=min_end,
         end_comparator=end_comparator,
         update_frequency=update_frequency,
         min_frequency=min_frequency,
     )
     super().__init__(log_types=log_types, **schedule)

     owner = self.__class__.__name__
     self._params = validate_str_iterable(params, f"{owner} for params")

     # Mask, analyzer and logging state all start out unset (None).
     self._module_masks = None  # type: Optional[ModuleParamPruningMask]
     self._analyzers = None  # type: Optional[List[ModulePruningAnalyzer]]
     self._last_logged_epoch = None
示例#4
0
 def layers(self, value: Union[str, List[str]]):
     """
     Set the layers the AS modifier applies to.

     :param value: a str or list of str naming the layers; the token __ALL__
         can be used to specify all layers
     """
     description = "{} for layers".format(self.__class__.__name__)
     self._layers = validate_str_iterable(value, description)
示例#5
0
 def params(self, value: Union[str, List[str]]):
     """
     Set the parameters that pruning is applied to.

     :param value: A list of full parameter names or regex patterns of names to
         apply pruning to. Regex patterns must be specified with the prefix
         're:'. __ALL__ will match to all parameters.
     """
     owner_name = self.__class__.__name__
     self._params = validate_str_iterable(value, f"{owner_name} for params")
示例#6
0
 def params(self, value: Union[str, List[str]]):
     """
     Set the variables the pruning modifier is applied to.

     :param value: List of str for the variable names or regex patterns of
         names to apply the pruning modifier to. Regex patterns must be
         specified with the prefix 're:'.
     """
     label = f"{self.__class__.__name__} for params"
     self._params = validate_str_iterable(value, label)
示例#7
0
    def __init__(
        self,
        params: Union[str, List[str]],
        start_epoch: float = -1.0,
        min_start: float = -1.0,
        end_epoch: float = -1.0,
        min_end: float = -1.0,
        end_comparator: Union[int, None] = 0,
        update_frequency: float = -1.0,
        min_frequency: float = -1.0,
        log_types: Union[str, List[str]] = None,
        global_sparsity: bool = False,
        allow_reintroduction: bool = False,
        parent_class_kwarg_names: Optional[List[str]] = None,
        leave_enabled: bool = False,
        **kwargs,
    ):
        """
        Set up the scheduling, scoring and masking state of the pruning modifier.

        :param params: parameter name(s) or regex pattern(s); validated into a
            list of str by validate_str_iterable
        :param start_epoch: forwarded to the parent initializer
        :param min_start: forwarded to the parent initializer
        :param end_epoch: forwarded to the parent initializer
        :param min_end: forwarded to the parent initializer
        :param end_comparator: forwarded to the parent initializer
        :param update_frequency: forwarded to the parent initializer
        :param min_frequency: forwarded to the parent initializer
        :param log_types: forwarded to the parent initializer
        :param global_sparsity: stored as given
        :param allow_reintroduction: stored as given
        :param parent_class_kwarg_names: when not None, only the entries of
            kwargs whose names appear in this list are forwarded to the parent
            initializer; if "params" is listed, params is forwarded too
        :param leave_enabled: stored as given
        :param kwargs: extra keyword arguments for the parent initializer
            (filtered by parent_class_kwarg_names when that is given)
        """
        if parent_class_kwarg_names is not None:
            # Forward only the extra kwargs a parent initializer declares,
            # e.g. parent_class_kwarg_names = ["params", "init_sparsity", ...]
            allowed = parent_class_kwarg_names
            kwargs = {name: val for name, val in kwargs.items() if name in allowed}
            if "params" in allowed:
                kwargs["params"] = params
        super().__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            min_start=min_start,
            end_epoch=end_epoch,
            min_end=min_end,
            end_comparator=end_comparator,
            update_frequency=update_frequency,
            min_frequency=min_frequency,
            **kwargs,
        )
        owner = self.__class__.__name__
        self._params = validate_str_iterable(params, f"{owner} for params")

        # Mask/analyzer/logging state starts out unset.
        self._module_masks = None  # type: Optional[ModuleParamPruningMask]
        self._analyzers = None  # type: Optional[List[ModulePruningAnalyzer]]
        self._last_logged_epoch = None

        # Scorer and mask creator are not created here.
        self._scorer = None  # type: PruningParamsScorer
        self._mask_creator = None  # type: PruningMaskCreator

        self._global_sparsity = global_sparsity
        self._allow_reintroduction = allow_reintroduction
        self._leave_enabled = leave_enabled

        # Application-progress bookkeeping.
        self._applied_sparsity = None
        self._pre_step_completed = False
        self._sparsity_applied = False
示例#8
0
 def params(self, value: Union[str, List[str], None]):
     """
     Set the parameters to prune and re-derive the final sparsity with them.

     :param value: A list of full parameter names or regex patterns of names to
         apply pruning to.
         Regex patterns must be specified with the prefix 're:'. __ALL__
         will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
         and Linear layers' weights
     """
     # Keep the raw value; the resolved params/sparsity are derived from it
     # together with the stored original final sparsity.
     self._params_orig = value
     resolved_params, final_sparsity = self._get_params_and_final_sparsity(
         self._params_orig, self._final_sparsity_orig
     )
     self._final_sparsity = final_sparsity
     self._params = validate_str_iterable(
         resolved_params, f"{self.__class__.__name__} for params"
     )
示例#9
0
    def __init__(
        self,
        params: Union[str, List[str]],
        init_val: Any,
        final_val: Any,
        start_epoch: float,
        end_epoch: float,
        update_frequency: float,
        inter_func: str = "linear",
        params_strict: bool = True,
    ):
        """
        :param params: A list of full parameter names or regex patterns of names
            that the modifier targets.
            Regex patterns must be specified with the prefix 're:'. __ALL__
            will match to all parameters.
        :param init_val: The initial value to set for the given param in the
            given layers at start_epoch
        :param final_val: The final value to set for the given param in the
            given layers at end_epoch
        :param start_epoch: The epoch to start the modifier at
        :param end_epoch: The epoch to end the modifier at
        :param update_frequency: The number of epochs or fraction of epochs to
            update at between start and end
        :param inter_func: the type of interpolation function to use:
            [linear, cubic, inverse_cubic]; default is linear
        :param params_strict: True if every regex pattern in params must match at
            least one parameter name in the module;
            False if missing params are ok -- will not raise an err
        """
        super().__init__(
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            min_end=0.0,
            end_comparator=1,
        )
        self._params = validate_str_iterable(
            params, f"{self.__class__.__name__} for params"
        )

        # Endpoint values as given; the *_tens slots start out as None.
        self._init_val = init_val
        self._final_val = final_val
        self._init_val_tens = self._final_val_tens = None

        self._inter_func = inter_func
        self._params_strict = params_strict
        self._module_params = []  # type: List[Parameter]

        self.validate()
示例#10
0
    def __init__(
        self,
        params: Union[str, List[str]],
        start_epoch: float = -1,
        end_epoch: float = -1,
        log_types: Union[str, List[str]] = ALL_TOKEN,
        **kwargs,
    ):
        """
        Forward the schedule to the parent initializer and record the params.

        :param params: parameter name(s) or regex pattern(s); validated into a
            list of str by validate_str_iterable
        :param start_epoch: forwarded to the parent initializer
        :param end_epoch: forwarded to the parent initializer
        :param log_types: forwarded to the parent initializer
        :param kwargs: extra keyword arguments for the parent initializer;
            end_comparator defaults to None when the caller does not pass it
        """
        kwargs.setdefault("end_comparator", None)
        super().__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            **kwargs,
        )

        owner = self.__class__.__name__
        self._params = validate_str_iterable(
            params, f"{owner} for params"
        )  # type: List[str]
示例#11
0
 def __init__(
     self,
     params: Union[str, List[str]],
     val: Any,
     params_strict: bool = True,
     start_epoch: float = 0.0,
     end_epoch: float = -1.0,
 ):
     """
     Record the target params and the value associated with them.

     :param params: parameter name(s) or regex pattern(s); validated into a
         list of str by validate_str_iterable
     :param val: the value stored for the matched params
     :param params_strict: stored as given
     :param start_epoch: forwarded to the parent initializer
     :param end_epoch: forwarded to the parent initializer
     """
     super().__init__(
         start_epoch=start_epoch,
         end_epoch=end_epoch,
         end_comparator=None,
     )
     label = "{} for params".format(self.__class__.__name__)
     self._params = validate_str_iterable(params, label)
     self._val = val
     self._params_strict = params_strict
     self._module_params = []  # type: List[Parameter]
示例#12
0
 def __init__(
     self,
     params: Union[str, List[str]],
     start_epoch: float = -1.0,
     end_epoch: float = -1.0,
     log_types: Union[str, List[str]] = ALL_TOKEN,
 ):
     """
     Forward the schedule (with end_comparator=-1) to the parent initializer
     and record the target params.

     :param params: parameter name(s) or regex pattern(s); validated into a
         list of str by validate_str_iterable
     :param start_epoch: forwarded to the parent initializer
     :param end_epoch: forwarded to the parent initializer
     :param log_types: forwarded to the parent initializer
     """
     super().__init__(
         log_types=log_types,
         start_epoch=start_epoch,
         end_epoch=end_epoch,
         end_comparator=-1,
     )
     self._params = validate_str_iterable(
         params, f"{self.__class__.__name__} for params"
     )
     # Mask, analyzer and logging state all start out unset.
     self._module_masks = None  # type: ModuleParamPruningMask
     self._analyzers = self._last_logged_epoch = None
示例#13
0
 def __init__(
     self,
     params: Union[str, List[str]],
     trainable: bool,
     params_strict: bool = True,
     start_epoch: float = -1.0,
     end_epoch: float = -1.0,
 ):
     """
     Record which params to target and the trainable flag to apply to them.

     :param params: parameter name(s) or regex pattern(s); validated into a
         list of str by validate_str_iterable
     :param trainable: normalized with convert_to_bool before storing
     :param params_strict: normalized with convert_to_bool before storing
     :param start_epoch: forwarded to the parent initializer and also stored
         directly on self._start_epoch afterwards
     :param end_epoch: forwarded to the parent initializer
     """
     super().__init__(
         start_epoch=start_epoch, end_epoch=end_epoch, end_comparator=-1
     )
     # Store the raw start_epoch again after the parent initializer has run.
     self._start_epoch = start_epoch
     owner = self.__class__.__name__
     self._params = validate_str_iterable(params, f"{owner} for params")
     self._trainable = convert_to_bool(trainable)
     self._params_strict = convert_to_bool(params_strict)
     self._module_params = []  # type: List[Parameter]
     self._original = []
示例#14
0
 def __init__(
     self,
     params: Union[str, List[str]],
     start_epoch: float = -1,
     end_epoch: float = -1,
     log_types: Union[str, List[str]] = ALL_TOKEN,
 ):
     """
     Forward the schedule (with end_comparator=None) to the parent initializer
     and record the params.

     :param params: parameter name(s) or regex pattern(s); validated into a
         list of str by validate_str_iterable
     :param start_epoch: forwarded to the parent initializer
     :param end_epoch: forwarded to the parent initializer
     :param log_types: forwarded to the parent initializer
     """
     super(ConstantPruningModifier, self).__init__(
         log_types=log_types,
         start_epoch=start_epoch,
         end_epoch=end_epoch,
         end_comparator=None,
     )
     label = f"{self.__class__.__name__} for params"
     self._params = validate_str_iterable(params, label)  # type: List[str]
     # Op/update/sparsity slots are left unset at construction.
     self._prune_op_vars = self._update_ready = self._sparsity = None
示例#15
0
    def __init__(
        self,
        params: Union[str, List[str]],
        init_sparsity: float,
        final_sparsity: float,
        start_epoch: float,
        end_epoch: float,
        update_frequency: float,
        inter_func: str = "cubic",
        log_types: Union[str, List[str]] = ALL_TOKEN,
        mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
        leave_enabled: bool = True,
    ):
        """
        Store the gradual-pruning configuration and validate it.

        :param params: parameter name(s) or regex pattern(s); validated into a
            list of str by validate_str_iterable, and each entry is mapped to a
            layer name with get_layer_name_from_param
        :param init_sparsity: sparsity value stored for the start of the schedule
        :param final_sparsity: sparsity value stored for the end of the schedule
        :param start_epoch: forwarded to the parent initializer
        :param end_epoch: forwarded to the parent initializer
        :param update_frequency: forwarded to the parent initializer
        :param inter_func: name of the interpolation function; stored as given
        :param log_types: forwarded to the parent initializer
        :param mask_type: either a PruningMaskCreator (used as-is) or a
            specification passed to load_mask_creator
        :param leave_enabled: normalized with convert_to_bool before storing
        """
        super(GMPruningModifier, self).__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            min_start=-1.0,
            end_epoch=end_epoch,
            min_end=0.0,
            end_comparator=1,
            update_frequency=update_frequency,
            min_frequency=-1.0,
        )
        self._params = validate_str_iterable(
            params, "{} for params".format(self.__class__.__name__)
        )  # type: List[str]
        # One layer name per entry of self._params.
        self._layer_names = [get_layer_name_from_param(p) for p in self._params]
        self._init_sparsity = init_sparsity
        self._final_sparsity = final_sparsity
        # Assigned exactly once (the original set this attribute twice).
        self._leave_enabled = convert_to_bool(leave_enabled)
        self._inter_func = inter_func
        self._mask_type = mask_type
        # A ready-made PruningMaskCreator is kept as-is; anything else is
        # turned into one via load_mask_creator.
        self._mask_creator = mask_type
        if not isinstance(mask_type, PruningMaskCreator):
            self._mask_creator = load_mask_creator(mask_type)
        # Slots that start out unset at construction.
        self._prune_op_vars = None
        self._update_ready = None
        self._sparsity = None
        self._mask_initializer = None

        self._masked_layers = []

        self.validate()
示例#16
0
    def __init__(
        self,
        init_sparsity: float,
        final_sparsity: float,
        start_epoch: float,
        end_epoch: float,
        update_frequency: float,
        params: Union[str, List[str]],
        leave_enabled: bool = True,
        inter_func: str = "cubic",
        log_types: Union[str, List[str]] = ALL_TOKEN,
        mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
        global_sparsity: bool = False,
    ):
        """
        Store the gradual-pruning configuration and validate it.

        :param init_sparsity: sparsity value stored for the start of the schedule
        :param final_sparsity: sparsity value stored for the end of the schedule
        :param start_epoch: forwarded to the parent initializer
        :param end_epoch: forwarded to the parent initializer
        :param update_frequency: forwarded to the parent initializer
        :param params: parameter name(s) or regex pattern(s); validated into a
            list of str by validate_str_iterable
        :param leave_enabled: normalized with convert_to_bool before storing
        :param inter_func: name of the interpolation function; stored as given
        :param log_types: forwarded to the parent initializer
        :param mask_type: either a PruningMaskCreator (used as-is) or a
            specification passed to load_mask_creator
        :param global_sparsity: stored as given
        """
        super().__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            min_end=0.0,
            end_comparator=1,
        )
        self._init_sparsity = init_sparsity
        self._final_sparsity = final_sparsity
        owner = self.__class__.__name__
        self._params = validate_str_iterable(params, f"{owner} for params")
        self._leave_enabled = convert_to_bool(leave_enabled)
        self._inter_func = inter_func
        self._mask_type = mask_type
        self._mask_creator = (
            mask_type
            if isinstance(mask_type, PruningMaskCreator)
            else load_mask_creator(mask_type)
        )
        self._global_sparsity = global_sparsity

        # Mask, sparsity and logging bookkeeping start out unset.
        self._module_masks = None  # type: ModuleParamPruningMask
        self._applied_sparsity = None
        self._last_logged_sparsity = None
        self._last_logged_epoch = None
        self._analyzers = None

        self._non_serializable_props = {}

        self.validate()
示例#17
0
    def __init__(
        self,
        params: Union[str, List[str]],
        start_epoch: float = -1,
        end_epoch: float = -1,
        log_types: Union[str, List[str]] = ALL_TOKEN,
    ):
        """
        Forward the schedule (with end_comparator=None) to the parent
        initializer, record the params and their layer names, and build an
        "unstructured" mask creator.

        :param params: parameter name(s) or regex pattern(s); validated into a
            list of str by validate_str_iterable
        :param start_epoch: forwarded to the parent initializer
        :param end_epoch: forwarded to the parent initializer
        :param log_types: forwarded to the parent initializer
        """
        super(ConstantPruningModifier, self).__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            end_comparator=None,
        )
        validated = validate_str_iterable(
            params, "{} for params".format(self.__class__.__name__)
        )
        self._params = validated  # type: List[str]
        # One layer name per validated param entry.
        self._layer_names = list(map(get_layer_name_from_param, self._params))
        self._masked_layers = []

        self._sparsity_scheduler = None
        self._mask_creator = load_mask_creator("unstructured")
示例#18
0
    def __init__(
        self,
        layers: Union[str, List[str]],
        alpha: Union[float, List[float]],
        layer_normalized: bool = False,
        reg_func: str = "l1",
        reg_tens: str = "inp",
        start_epoch: float = -1.0,
        end_epoch: float = -1.0,
    ):
        """
        Record the target layers and regularization settings, then validate.

        :param layers: layer name(s); validated into a list of str by
            validate_str_iterable
        :param alpha: alpha value(s); stored as given
        :param layer_normalized: normalized with convert_to_bool before storing
        :param reg_func: name of the regularization function; stored as given
        :param reg_tens: name of the tensor to regularize; stored as given
        :param start_epoch: forwarded to the parent initializer
        :param end_epoch: forwarded to the parent initializer
        """
        super().__init__(
            start_epoch=start_epoch, end_epoch=end_epoch, end_comparator=-1
        )
        layer_label = f"{self.__class__.__name__} for layers"
        self._layers = validate_str_iterable(layers, layer_label)
        self._alpha = alpha
        self._layer_normalized = convert_to_bool(layer_normalized)
        self._reg_func, self._reg_tens = reg_func, reg_tens
        self._trackers = []  # type: List[ASLayerTracker]

        self.validate()
示例#19
0
 def __init__(
     self,
     params: Union[str, List[str]],
     trainable: bool,
     params_strict: bool = True,
     start_epoch: float = -1.0,
     end_epoch: float = -1.0,
     **kwargs,
 ):
     """
     Record which params to target and the trainable flag to apply, then
     validate.

     :param params: parameter name(s) or regex pattern(s); validated into a
         list of str by validate_str_iterable
     :param trainable: normalized with convert_to_bool before storing
     :param params_strict: normalized with convert_to_bool before storing
     :param start_epoch: forwarded to the parent initializer
     :param end_epoch: forwarded to the parent initializer
     :param kwargs: extra keyword arguments for the parent initializer;
         end_comparator defaults to -1 when the caller does not pass it
     """
     kwargs.setdefault("end_comparator", -1)
     super(TrainableParamsModifier, self).__init__(
         start_epoch=start_epoch, end_epoch=end_epoch, **kwargs
     )
     owner = self.__class__.__name__
     self._params = validate_str_iterable(params, f"{owner} for params")
     self._trainable = convert_to_bool(trainable)
     self._params_strict = convert_to_bool(params_strict)
     self._vars_to_trainable_orig = {}
     self.validate()
示例#20
0
def test_validate_str_iterable(test_list, output):
    """validate_str_iterable maps the given input to the expected output."""
    expected = output
    assert validate_str_iterable(test_list, "") == expected
示例#21
0
def test_validate_str_iterable_negative():
    """The string "will fail" is expected to be rejected with a ValueError."""
    bad_value = "will fail"
    with pytest.raises(ValueError):
        validate_str_iterable(bad_value, "")