def init_pending_tasks(self) -> List[Task]:
    """Load the original artifacts, build the AMC environment and DDPG agent, and seed the task queue.

    Reads the original model and masks with ``torch.load`` and the original
    config list with ``json_tricks``. Resets the per-search bookkeeping
    (``T``, ``action``, ``observation``, ``warmup_episode``), builds an
    ``AMCEnv`` from the sparsity budget of the first canonical config, and
    creates a ``DDPG`` agent in training mode.

    Returns:
        Whatever ``self.generate_tasks`` produces for an ``'origin'``
        ``TaskResult`` wrapping the unpruned model and its masks.
    """
    origin_model = torch.load(self._origin_model_path)
    origin_masks = torch.load(self._origin_masks_path)
    with open(self._origin_config_list_path, "r") as f:
        origin_config_list = json_tricks.load(f)

    # Fresh search state.
    self.T = []
    self.action = None
    self.observation = None
    # Warm-up length: an explicit `warmup` entry in ddpg_params wins, otherwise a quarter of total_episode.
    if 'warmup' in self.ddpg_params:
        self.warmup_episode = self.ddpg_params['warmup']
    else:
        self.warmup_episode = int(self.total_episode / 4)

    # Only the first canonical config supplies the sparsity budget.
    first_config = config_list_canonical(origin_model, origin_config_list)[0]
    total_sparsity = first_config['total_sparsity']
    # NOTE(review): AMCEnv annotates this parameter as Dict[str, float], yet the fallback here is the
    # float 1.; confirm which form AMCEnv actually consumes.
    max_sparsity_per_layer = first_config.get('max_sparsity_per_layer', 1.)
    self.env = AMCEnv(origin_model, origin_config_list, self.dummy_input, total_sparsity, max_sparsity_per_layer, self.target)

    # Agent sized from the env's state features; the `1` is presumably the action dimension — verify against DDPG.
    self.agent = DDPG(len(self.env.state_feature), 1, self.ddpg_params)
    self.agent.is_training = True

    # The original masks fill both mask slots of the seed result; no score yet.
    task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
    return self.generate_tasks(task_result)
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
    """Restart iteration counting, store the canonical target config, and delegate to the parent reset.

    Args:
        model: The model being pruned.
        config_list: Pruning configs; their canonical form becomes ``self.target_sparsity``.
        masks: Existing masks, forwarded untouched to the parent ``reset``.
    """
    # NOTE(review): `[]` / `{}` defaults are shared between calls; safe only while nothing downstream mutates them.
    if self.skip_first_iteration:
        first_iteration = 1
    else:
        first_iteration = 0
    self.current_iteration = first_iteration
    canonical_configs = config_list_canonical(model, config_list)
    self.target_sparsity = canonical_configs
    # The parent receives the caller's original (non-canonical) config list.
    super().reset(model, config_list=config_list, masks=masks)
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float, max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
    """Collect the layers selected by ``config_list`` and profile ``model``'s baseline cost.

    Args:
        model: Model to inspect; layers are enumerated with ``named_modules``.
        config_list: Pruning configs; only the ``op_names`` of their canonical
            form are used to select layers.
        dummy_input: Input passed to ``count_flops_params`` to profile ``model``.
        total_sparsity: Fraction of the selected layers' ``target`` metric that
            is expected to be pruned.
        max_sparsity_per_layer: Stored unchanged on the instance.
        target: ``'flops'`` or ``'params'``; the statistic key being budgeted.

    Raises:
        AssertionError: If ``target`` is neither ``'flops'`` nor ``'params'``.
    """
    # Validate before doing any work (note: stripped under `python -O`).
    assert target in ['flops', 'params'], "target must be 'flops' or 'params', got {!r}".format(target)

    # Every op name covered by the canonical configs; a set keeps the per-module check O(1).
    pruning_op_names = {
        op_name
        for config in config_list_canonical(model, config_list)
        for op_name in config['op_names']
    }

    # name -> (index in named_modules, layer type name, stride, kernel size), in model order.
    self.pruning_ops = OrderedDict()
    seen_types = []
    for i, (name, layer) in enumerate(model.named_modules()):
        if name not in pruning_op_names:
            continue
        op_type = type(layer).__name__
        # Multi-dimensional stride / kernel_size collapse to their geometric mean; layers without them get 0 / 1.
        stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0  # type: ignore
        kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1  # type: ignore
        self.pruning_ops[name] = (i, op_type, stride, kernel_size)
        seen_types.append(op_type)
    # Order-preserving dedupe: list(set(...)) ordered these strings by their hash, which is
    # randomized per interpreter process, so the ordering was not reproducible between runs.
    self.pruning_types = list(dict.fromkeys(seen_types))
    self.pruning_op_names = list(self.pruning_ops.keys())

    self.dummy_input = dummy_input
    self.total_sparsity = total_sparsity
    # NOTE(review): annotated Dict[str, float], but callers may pass a plain float (e.g. a 1. fallback); confirm.
    self.max_sparsity_per_layer = max_sparsity_per_layer
    self.target = target

    # NOTE(review): count_flops_params presumably returns (flops, params, per-layer results); if so,
    # origin_target holds FLOPs even when target == 'params' — confirm that is intended.
    self.origin_target, self.origin_params_num, origin_statistics = count_flops_params(model, dummy_input, verbose=False)
    # Per-layer statistics indexed by layer name.
    self.origin_statistics = {result['name']: result for result in origin_statistics}
    # Total `target` metric of the selected layers, and the share of it the search aims to remove.
    self.under_pruning_target = sum(self.origin_statistics[name][self.target] for name in self.pruning_op_names)
    # "excepted" (sic, for "expected") is kept as-is: the attribute name may be read elsewhere.
    self.excepted_pruning_target = self.total_sparsity * self.under_pruning_target
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
    """Start counting at iteration 1, store the canonical target config, and reset via the grandparent.

    Args:
        model: The model being pruned.
        config_list: Pruning configs; their canonical form becomes ``self.target_sparsity``.
        masks: Existing masks, forwarded untouched to the next ``reset`` in the MRO.
    """
    # NOTE(review): `[]` / `{}` defaults are shared between calls; safe only while nothing downstream mutates them.
    self.current_iteration = 1
    canonical_configs = config_list_canonical(model, config_list)
    self.target_sparsity = canonical_configs
    # Two-argument super() begins the MRO lookup *after* FunctionBasedTaskGenerator,
    # so FunctionBasedTaskGenerator.reset is deliberately bypassed here.
    next_in_mro = super(FunctionBasedTaskGenerator, self)
    next_in_mro.reset(model, config_list=config_list, masks=masks)
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
    """Rewind the temperature, recompute weight statistics and targets, clear search state, then reset the parent.

    Args:
        model: The model being pruned.
        config_list: Pruning configs; used raw for the weight statistics and
            canonicalized into ``self.target_sparsity_list``.
        masks: Existing masks, used for the weight statistics and forwarded to the parent ``reset``.
    """
    # NOTE(review): `[]` / `{}` defaults are shared between calls; safe only while nothing downstream mutates them.
    self.current_temperature = self.start_temperature
    # Weight statistics come from the raw config list, before canonicalization.
    numel, masked = get_model_weights_numel(model, config_list, masks)
    self.weights_numel = numel
    self.masked_rate = masked
    self.target_sparsity_list = config_list_canonical(model, config_list)
    # Presumably rescales the targets using weights_numel / masked_rate — verify in _adjust_target_sparsity.
    self._adjust_target_sparsity()
    # Per-search bookkeeping starts out empty.
    self._temp_config_list = self._current_sparsity_list = self._current_score = None
    super().reset(model, config_list=config_list, masks=masks)
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
    """Rewind the temperature, recompute weight statistics and targets, clear search state, then reset the parent.

    Logs a single warning if any config uses ``sparsity`` / ``sparsity_per_layer``
    instead of ``total_sparsity``.

    Args:
        model: The model being pruned.
        config_list: Pruning configs; used raw for the weight statistics and
            canonicalized into ``self.target_sparsity_list``.
        masks: Existing masks, used for the weight statistics and forwarded to the parent ``reset``.
    """
    # NOTE(review): `[]` / `{}` defaults are shared between calls; safe only while nothing downstream mutates them.
    self.current_temperature = self.start_temperature
    # TODO: replace with validation here
    # Warn once, rather than repeating the identical message for every offending config.
    if any('sparsity' in config or 'sparsity_per_layer' in config for config in config_list):
        _logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
    # Weight statistics come from the raw config list, before canonicalization.
    self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
    self.target_sparsity_list = config_list_canonical(model, config_list)
    # Presumably rescales the targets using weights_numel / masked_rate — verify in _adjust_target_sparsity.
    self._adjust_target_sparsity()
    # Per-search bookkeeping starts empty.
    self._temp_config_list = []
    self._current_sparsity_list = []
    self._current_score = 0.
    super().reset(model, config_list=config_list, masks=masks)
def validate_config(self, model: Module, config_list: List[Dict]):
    """Check ``config_list`` against ``model`` and keep its canonical form.

    The raw list is first handed to ``_validate_config_before_canonical``
    (presumably it raises on an invalid config); the output of
    ``config_list_canonical`` is then stored as ``self.config_list``.
    """
    pre_check = self._validate_config_before_canonical
    pre_check(model, config_list)
    canonical = config_list_canonical(model, config_list)
    self.config_list = canonical