예제 #1
0
    def init_pending_tasks(self) -> List[Task]:
        """
        Build the initial pending tasks from the stored original model.

        The model and masks are read with ``torch.load`` from
        ``self._origin_model_path`` and ``self._origin_masks_path``. They are
        packed into a ``TaskResult`` with the id ``'origin'`` and no score,
        and that result is handed to ``self.generate_tasks``.

        Returns
        -------
        List[Task]
            Whatever ``self.generate_tasks`` produces for the origin result.
        """
        model = torch.load(self._origin_model_path)
        masks = torch.load(self._origin_masks_path)

        # The same masks object is deliberately passed for both mask slots.
        return self.generate_tasks(
            TaskResult('origin', model, masks, masks, None))
예제 #2
0
파일: amc_pruner.py 프로젝트: microsoft/nni
    def init_pending_tasks(self) -> List[Task]:
        """
        Reset the per-run state, create the environment and agent, and
        generate the initial tasks from the stored original model.

        Steps, in order:

        1. Load the original model and masks with ``torch.load`` and the
           original config list with ``json_tricks.load``.
        2. Reset ``self.T``, ``self.action`` and ``self.observation``, and
           set ``self.warmup_episode`` from ``self.ddpg_params['warmup']``
           when that key exists, otherwise to ``int(self.total_episode / 4)``.
        3. Read ``total_sparsity`` (required) and ``max_sparsity_per_layer``
           (defaults to ``1.``) from the first entry of the canonicalized
           config list.
        4. Create ``self.env`` (``AMCEnv``) and ``self.agent`` (``DDPG``) and
           set ``self.agent.is_training`` to ``True``.
        5. Wrap the original model in a ``TaskResult`` with id ``'origin'``
           and pass it to ``self.generate_tasks``.

        Returns
        -------
        List[Task]
            Whatever ``self.generate_tasks`` produces for the origin result.
        """
        model = torch.load(self._origin_model_path)
        masks = torch.load(self._origin_masks_path)
        with open(self._origin_config_list_path, "r") as config_file:
            config_list = json_tricks.load(config_file)

        # Fresh bookkeeping for a new run.
        self.T = []
        self.action = None
        self.observation = None
        if 'warmup' in self.ddpg_params:
            self.warmup_episode = self.ddpg_params['warmup']
        else:
            self.warmup_episode = int(self.total_episode / 4)

        # Only the first entry of the canonicalized list is consulted.
        first_config = config_list_canonical(model, config_list)[0]
        total_sparsity = first_config['total_sparsity']
        max_sparsity_per_layer = first_config.get('max_sparsity_per_layer', 1.)

        # The environment is given the config list as loaded, not the
        # canonicalized copy used above.
        self.env = AMCEnv(
            model,
            config_list,
            self.dummy_input,
            total_sparsity,
            max_sparsity_per_layer,
            self.target,
        )
        self.agent = DDPG(len(self.env.state_feature), 1, self.ddpg_params)
        self.agent.is_training = True

        # The same masks object is passed for both mask slots.
        return self.generate_tasks(
            TaskResult('origin', model, masks, masks, None))
예제 #3
0
    def init_pending_tasks(self) -> List[Task]:
        """
        Save a copy of the original model and masks in the intermediate
        result directory, then generate the initial tasks.

        The model and masks are read with ``torch.load`` from
        ``self._origin_model_path`` and ``self._origin_masks_path``. Copies
        are written with ``torch.save`` to ``origin_compact_model.pth`` and
        ``origin_compact_model_masks.pth`` under
        ``self._intermediate_result_dir``; those locations are kept in
        ``self.temp_model_path`` and ``self.temp_masks_path``.

        Returns
        -------
        List[Task]
            Whatever ``self.generate_tasks`` produces for a ``TaskResult``
            with id ``'origin'`` and no score.
        """
        model = torch.load(self._origin_model_path)
        masks = torch.load(self._origin_masks_path)

        # Both path attributes are set before anything is written.
        result_dir = self._intermediate_result_dir
        self.temp_model_path = Path(result_dir, 'origin_compact_model.pth')
        self.temp_masks_path = Path(result_dir, 'origin_compact_model_masks.pth')
        torch.save(model, self.temp_model_path)
        torch.save(masks, self.temp_masks_path)

        # The same masks object is passed for both mask slots.
        return self.generate_tasks(
            TaskResult('origin', model, masks, masks, None))
예제 #4
0
def run_task_generator_(task_generator):
    """
    Drain ``task_generator`` and report how many tasks it handed out.

    Each task returned by ``task_generator.next()`` is answered with a
    ``TaskResult`` holding a new ``TorchModel()``, two empty mask dicts and
    a score of ``1 - factor``. ``factor`` starts at 0.9 and is squared
    before each answer, so the first score is ``1 - 0.81``. The loop stops
    when ``next()`` returns ``None``.

    Parameters
    ----------
    task_generator
        Object providing ``next()`` and ``receive_task_result(result)``.

    Returns
    -------
    int
        Number of tasks that were answered.
    """
    remaining = 0.9
    served = 0
    while True:
        task = task_generator.next()
        if task is None:
            break
        remaining = remaining ** 2
        served += 1
        task_generator.receive_task_result(
            TaskResult(task.task_id, TorchModel(), {}, {}, 1 - remaining))
    return served
예제 #5
0
    def pruning_one_step_normal(self, task: Task) -> TaskResult:
        """
        Run one pruning step for ``task``:
        generate masks -> speedup -> finetune -> evaluate.

        Parameters
        ----------
        task
            Supplies the model, masks and config list through ``load_data()``
            and the ``speedup`` / ``finetune`` / ``evaluate`` flags that
            enable each optional stage.

        Returns
        -------
        TaskResult
            Carries the task id, the compact model, the masks for the compact
            model (an empty dict once speedup has run), the masks produced by
            the pruner, and the evaluation score (``None`` if evaluation was
            skipped).
        """
        model, masks, config_list = task.load_data()
        pruner = self.pruner
        pruner.reset(model, config_list)
        pruner.load_masks(masks)

        # generate masks
        compact_model, pruner_generated_masks = pruner.compress()
        compact_model_masks = deepcopy(pruner_generated_masks)

        # report the pruning effect, then leave the model unwrapped
        pruner.show_pruned_weights()
        pruner._unwrap_model()

        def apply_to_compact_model(func):
            # With speedup enabled on the scheduler, the model is used as is;
            # otherwise it is wrapped for the duration of the call.
            # NOTE(review): this tests ``self.speedup`` only, not
            # ``task.speedup`` -- when the scheduler enables speedup but the
            # task disables it, ``func`` sees the unwrapped, non-sped-up
            # model. Confirm this is intended.
            if self.speedup:
                return func(compact_model)
            pruner._wrap_model()
            outcome = func(compact_model)
            pruner._unwrap_model()
            return outcome

        # speedup
        if self.speedup and task.speedup:
            speedup_runner = ModelSpeedup(compact_model, self.dummy_input,
                                          pruner_generated_masks)
            speedup_runner.speedup_model()
            compact_model_masks = {}

        # finetune
        if self.finetuner is not None and task.finetune:
            apply_to_compact_model(self.finetuner)

        # evaluate
        score = None
        if self.evaluator is not None and task.evaluate:
            score = apply_to_compact_model(self.evaluator)

        # clear model references
        pruner.clear_model_references()

        return TaskResult(task.task_id, compact_model, compact_model_masks,
                          pruner_generated_masks, score)