Example #1
    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
        compact_model = task_result.compact_model
        compact_model_masks = task_result.compact_model_masks

        # save intermediate result
        model_path = Path(self._intermediate_result_dir,
                          '{}_compact_model.pth'.format(task_result.task_id))
        masks_path = Path(
            self._intermediate_result_dir,
            '{}_compact_model_masks.pth'.format(task_result.task_id))
        torch.save(compact_model, model_path)
        torch.save(compact_model_masks, masks_path)

        # get current2origin_sparsity and compact2origin_sparsity
        origin_model = torch.load(self._origin_model_path)
        current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(
            origin_model, compact_model, compact_model_masks,
            self.target_sparsity)
        _logger.debug(
            '\nTask %s total real sparsity compared with original model is:\n%s',
            str(task_result.task_id),
            json_tricks.dumps(current2origin_sparsity, indent=4))
        if task_result.task_id != 'origin':
            self._tasks[task_result.task_id].state[
                'current2origin_sparsity'] = current2origin_sparsity

        # once total_iteration is reached, no more tasks are generated
        if self.current_iteration > self.total_iteration:
            return []

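        # compute the config list for this iteration and allocate the sparsity across layers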
        task_id = self._task_id_candidate
        new_config_list = self.generate_config_list(self.target_sparsity,
                                                    self.current_iteration,
                                                    compact2origin_sparsity)
        new_config_list = self.allocate_sparsity(new_config_list,
                                                 compact_model,
                                                 compact_model_masks)
        config_list_path = Path(self._intermediate_result_dir,
                                '{}_config_list.json'.format(task_id))

        with config_list_path.open('w') as f:
            json_tricks.dump(new_config_list, f, indent=4)
        task = Task(task_id, model_path, masks_path, config_list_path)

        self._tasks[task_id] = task

        self._task_id_candidate += 1
        self.current_iteration += 1

        return [task]
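Usage note: generate_tasks() is the hook a pruning task scheduler calls after each task finishes; it receives a TaskResult for the previous task and returns the next list of Tasks (an empty list stops the iteration). The sketch below is only a hypothetical illustration of that driving loop; drive() and run_task are made-up names, and only the attribute names (task_id, compact_model, compact_model_masks, score) come from the examples on this page.

def drive(task_generator, run_task, origin_result):
    """Hypothetical scheduler loop (not NNI's actual API).

    origin_result: a TaskResult-like object with task_id == 'origin'.
    run_task: callable that prunes/finetunes/evaluates one Task and
    returns a TaskResult for it.
    """
    pending = list(task_generator.generate_tasks(origin_result))
    while pending:
        task = pending.pop(0)               # FIFO over generated tasks
        result = run_task(task)             # produces the next TaskResult
        pending.extend(task_generator.generate_tasks(result))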
Example #2
File: amc_pruner.py Project: maxpark/nni
    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
        # append experience & update agent policy
        if self.action is not None:
            action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
            self.T.append([reward, self.observation, observation, self.action, done])
            self.observation = observation.copy()

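            # episode finished: replay the stored transitions with the final reward (score - 1)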
            if done:
                assert task_result.score is not None, 'task_result.score should not be None if environment is done.'
                final_reward = task_result.score - 1
                # agent observe and update policy
                for _, s_t, s_t1, a_t, d_t in self.T:
                    self.agent.observe(final_reward, s_t, s_t1, a_t, d_t)
                    if self.current_episode > self.warmup_episode:
                        self.agent.update_policy()

                self.current_episode += 1
                self.T = []
                self.action = None
                self.observation = None

            # record the compact model's sparsity, FLOPs and parameter count in the task state for the log file
            origin_model = torch.load(self._origin_model_path)
            compact_model = task_result.compact_model
            compact_model_masks = task_result.compact_model_masks
            current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.temp_config_list)
            self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
            current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.config_list_copy)
            self._tasks[task_result.task_id].state['current_total_sparsity'] = current2origin_sparsity
            flops, params, _ = count_flops_params(compact_model, self.dummy_input, verbose=False)
            self._tasks[task_result.task_id].state['current_flops'] = '{:.2f} M'.format(flops / 1e6)
            self._tasks[task_result.task_id].state['current_params'] = '{:.2f} M'.format(params / 1e6)

        # generate new action
        if self.current_episode < self.total_episode:
            if self.observation is None:
                self.observation = self.env.reset().copy()
                self.temp_config_list = []
                compact_model = torch.load(self._origin_model_path)
                compact_model_masks = torch.load(self._origin_masks_path)
            else:
                compact_model = task_result.compact_model
                compact_model_masks = task_result.compact_model_masks
            if self.current_episode <= self.warmup_episode:
                action = self.agent.random_action()
            else:
                action = self.agent.select_action(self.observation, episode=self.current_episode)
            action = action.tolist()[0]

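            # correct the raw action to a feasible sparsity for the current layer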
            self.action = self.env.correct_action(action, compact_model)
            sub_config_list = [{'op_names': [self.env.current_op_name], 'total_sparsity': self.action}]
            self.temp_config_list.extend(sub_config_list)

            task_id = self._task_id_candidate
            if self.env.is_first_layer() or self.env.is_final_layer():
                task_config_list = self.temp_config_list
            else:
                task_config_list = sub_config_list

            config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
            with config_list_path.open('w') as f:
                json_tricks.dump(task_config_list, f, indent=4)

            model_path = Path(self._intermediate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
            masks_path = Path(self._intermediate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
            torch.save(compact_model, model_path)
            torch.save(compact_model_masks, masks_path)

            task = Task(task_id, model_path, masks_path, config_list_path)
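            # only the final-layer task in an episode is finetuned and evaluated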
            if not self.env.is_final_layer():
                task.finetune = False
                task.evaluate = False

            self._tasks[task_id] = task
            self._task_id_candidate += 1
            return [task]
        else:
            return []
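In example #2, the config list each task writes to its {task_id}_config_list.json has the shape built in the snippet: one {'op_names': [...], 'total_sparsity': ...} entry per pruned layer, and the final-layer task of an episode carries the accumulated list. A hypothetical example (layer names and sparsities are made up):

task_config_list = [
    {'op_names': ['features.0'], 'total_sparsity': 0.31},
    {'op_names': ['features.3'], 'total_sparsity': 0.52},
    {'op_names': ['classifier.1'], 'total_sparsity': 0.18},
]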