Code example #1
def _save_model_desc(self):
    search_space = SearchSpace()
    codec = Codec(self.cfg.codec, search_space)
    pareto_front_df = pd.read_csv(
        FileOps.join_path(self.result_path, "pareto_front.csv"))
    codes = pareto_front_df['Code']
    for i in range(len(codes)):
        search_desc = Config()
        search_desc.custom = deepcopy(search_space.search_space.custom)
        search_desc.modules = deepcopy(search_space.search_space.modules)
        code = codes.loc[i]
        search_desc.custom.code = code
        search_desc.custom.method = 'full'
        codec.decode(search_desc.custom)
        self.trainer.output_model_desc(i, search_desc)
Code example #2
class ESRTrainerCallback(Callback):
    """Construct the trainer of ESR-EA."""
    def before_train(self, epoch, logs=None):
        """Be called before the training process."""
        self.cfg = self.trainer.cfg
        # Use own save checkpoint and save performance function
        self.trainer.auto_save_ckpt = False
        self.trainer.auto_save_perf = False
        # This part is tricky: rebuild the trainer's model only when a model_desc is provided in the configs.
        model = ClassFactory.__configs__.get('model', None)
        if model:
            self.model_desc = model.get("model_desc", None)
            if self.model_desc is not None:
                model = self._init_model()
                self.trainer.build(model=model)

    def make_batch(self, batch):
        """Make batch for each training step."""
        input = batch["LR"]
        target = batch["HR"]
        if self.cfg.cuda:
            input = input.cuda()
            target = target.cuda()
        return input, target

    def after_epoch(self, epoch, logs=None):
        """Be called after one epoch training."""
        # Get summary perfs from the logs produced by the built-in MetricsEvaluator callback.
        self.performance = logs.get('summary_perfs', None)
        best_valid_perfs = self.performance['best_valid_perfs']
        best_valid = list(best_valid_perfs.values())[0]
        best_changed = self.performance['best_valid_perfs_changed']
        if best_changed:
            self._save_checkpoint({"Best PSNR": best_valid, "Epoch": epoch})

    def after_train(self, logs=None):
        """Be called after the whole train process."""
        # Extract performance logs. This can be moved into a builtin callback
        # if we can unify the performance content
        best_valid_perfs = self.performance['best_valid_perfs']
        best_valid = list(best_valid_perfs.values())[0]
        self._save_performance(best_valid)

    def _save_checkpoint(self, performance=None, model_name="best.pth"):
        """Save the trained model.

        :param performance: dict of all the result needed
        :type performance: dictionary
        :param model_name: name of the result file
        :type model_name: string
        :return: the path of the saved file
        :rtype: string
        """
        local_worker_path = self.trainer.get_local_worker_path()
        model_save_path = os.path.join(local_worker_path, model_name)
        torch.save(
            {
                'model_state_dict': self.trainer.model.state_dict(),
                **performance
            }, model_save_path)

        logging.info("model saved to {}".format(model_save_path))
        return model_save_path

    def _save_performance(self, performance, model_desc=None):
        """Save result of the model, and calculate pareto front.

        :param performance: The dict that contains all the result needed
        :type performance: dictionary
        :param model_desc: config of the model
        :type model_desc: dictionary
        """
        self.trainer._save_performance(performance)
        # FileOps.copy_file(self.performance_file, self.best_model_pfm)
        pd_path = os.path.join(self.trainer.local_output_path,
                               'population_fitness.csv')
        df = pd.DataFrame([[performance]], columns=["PSNR"])
        if not os.path.exists(pd_path):
            with open(pd_path, "w") as file:
                df.to_csv(file, index=False)
        else:
            with open(pd_path, "a") as file:
                df.to_csv(file, index=False, header=False)

    def _init_model(self):
        """Initialize the model architecture for full train step.

        :return: train model
        :rtype: class
        """
        search_space = Config({"search_space": self.model_desc})
        self.codec = Codec(self.cfg.codec, search_space)
        self.get_selected_arch()
        indiv_cfg = self.codec.decode(self.elitism)
        self.trainer.model_desc = self.elitism.active_net_list()
        # self.output_model_desc()
        net_desc = NetworkDesc(indiv_cfg)
        model = net_desc.to_model()
        return model

    def get_selected_arch(self):
        """Get the gene code of selected model architecture."""
        self.elitism = ESRIndividual(self.codec, self.cfg)
        if "model_arch" in self.cfg and self.cfg.model_arch is not None:
            self.elitism.update_gene(self.cfg.model_arch)
        else:
            sel_arch_file = self.cfg.model_desc_file
            sel_arch = np.load(sel_arch_file)
            self.elitism.update_gene(sel_arch[0])
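
A quick aside on the `_save_performance` pattern above: it appends one PSNR row per model to `population_fitness.csv`, writing the header only for the first row. A minimal standalone sketch of that append pattern (the file name and the PSNR values here are made up for illustration):

import os

import pandas as pd


def append_fitness(csv_path, psnr):
    """Append one PSNR row, writing the header only when the file does not exist yet."""
    df = pd.DataFrame([[psnr]], columns=["PSNR"])
    df.to_csv(csv_path, mode="a", index=False, header=not os.path.exists(csv_path))


if __name__ == "__main__":
    for value in (27.3, 28.1, 28.4):
        append_fitness("population_fitness_demo.csv", value)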
Code example #3
File: sp_nas.py  Project: zhwzhong/vega
class SpNas(SearchAlgorithm):
    """Search Algorithm Stage of SPNAS."""

    def __init__(self, search_space=None):
        super(SpNas, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.sample_level = self.cfg.sample_level
        self.max_sample = self.cfg.max_sample
        self.max_optimal = self.cfg.max_optimal
        self._total_list_name = self.cfg.total_list
        self.serial_settings = self.cfg.serial_settings

        self._total_list = ListDict()
        self.sample_count = 0
        self.init_code = None
        remote_output_path = FileOps.join_path(self.local_output_path, self.cfg.step_name)

        if 'last_search_result' in self.cfg:
            last_search_file = self.cfg.last_search_result
            assert FileOps.exists(os.path.join(remote_output_path, last_search_file)
                                  ), "Not found serial results!"
            # self.download_task_folder()
            last_search_results = os.path.join(self.local_output_path, last_search_file)
            last_search_results = ListDict.load_csv(last_search_results)
            pre_worker_id, pre_arch = self.select_from_remote(self.max_optimal, last_search_results)
            # re-write config template
            if self.cfg.regnition:
                self.codec.config_template['model']['backbone']['reignition'] = True
                assert FileOps.exists(os.path.join(remote_output_path,
                                                   pre_arch + '_imagenet.pth')
                                      ), "Not found {} pretrained .pth file!".format(pre_arch)
                pretrained_pth = os.path.join(self.local_output_path, pre_arch + '_imagenet.pth')
                self.codec.config_template['model']['pretrained'] = pretrained_pth
                pre_worker_id = -1
            # update config template
            self.init_code = dict(arch=pre_arch,
                                  pre_arch=pre_arch.split('_')[1],
                                  pre_worker_id=pre_worker_id)

        logging.info("inited SpNas {}-level search...".format(self.sample_level))

    @property
    def is_completed(self):
        """Check sampling if finished.

        :return: True if completed, False otherwise
        :rtype: bool
        """
        return self.sample_count > self.max_sample

    @property
    def num_samples(self):
        """Get the number of sampled architectures.

        :return: The number of sampled architectures
        :rtype: int
        """
        return len(self._total_list)

    def select_from_remote(self, max_optimal, total_list):
        """Select base model to mutate.

        :return: worker id and arch encode
        :rtype: int, str
        """
        def normalization(x):
            sum_ = sum(x)
            return [float(i) / float(sum_) for i in x]

        # rank with mAP and memory
        top_ = total_list.sort('mAP')[:max_optimal]
        if max_optimal > 1:
            prob = [round(np.log(i + 1e-2), 2) for i in range(1, len(top_) + 1)]
            prob_temp = list(prob)  # snapshot of the rank-based weights before they are modified below
            sorted_ind = sorted(range(len(top_)), key=lambda k: top_['memory'][k], reverse=True)
            for idx, ind in enumerate(sorted_ind):
                prob[ind] += prob_temp[idx]
            ind = np.random.choice(len(top_), p=normalization(prob))
            worker_id, arch = top_['worker_id', 'arch'][ind]
        else:
            worker_id, arch = top_['worker_id', 'arch'][0]
        return worker_id, arch

    def search(self):
        """Search a sample.

        :return: sample count and info
        :rtype: int, dict
        """
        code = self.init_code
        if self.num_samples > 0:
            pre_worker_id, pre_arch = self.select_from_remote(self.max_optimal, self._total_list)
            block_type, pre_serial, pre_paral = pre_arch.split('_')

            success = False
            while not success:
                serialnet, parallelnet = pre_serial, pre_paral
                if self.sample_level == 'serial':
                    serialnet = self._mutate_serialnet(serialnet, **self.serial_settings)
                    parallelnet = '-'.join(['0'] * len(serialnet.split('-')))
                elif self.sample_level == 'parallel':
                    parallelnet = self._mutate_parallelnet(parallelnet)
                    pre_worker_id = self.init_code['pre_worker_id']
                else:
                    raise ValueError("Unknown type of sample level")
                arch = self.codec.encode(block_type, serialnet, parallelnet)
                if arch not in self._total_list['arch'] or len(self._total_list['arch']) == 0:
                    success = True
            code = dict(arch=arch,
                        pre_arch=pre_serial,
                        pre_worker_id=pre_worker_id)

        self.sample_count += 1
        logging.info("The {}-th successfully sampling result: {}".format(self.sample_count, code))
        net_desc = self.codec.decode(code)
        return self.sample_count, net_desc

    def update(self, worker_result_path):
        """Update sampler."""
        performance_file = self.performance_path(worker_result_path)
        logging.info(
            "SpNas.update(), performance file={}".format(performance_file))
        info = FileOps.load_pickle(performance_file)
        if info is not None:
            self._total_list.append(info)
        else:
            logging.info("SpNas.update(), file is not exited, "
                         "performance file={}".format(performance_file))
        self.save_output(self.local_output_path)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path, self.backup_base_path)

    def performance_path(self, worker_result_path):
        """Get performance path."""
        performance_dir = os.path.join(worker_result_path, 'performance')
        if not os.path.exists(performance_dir):
            FileOps.make_dir(performance_dir)
        return os.path.join(performance_dir, 'performance.pkl')

    def save_output(self, local_output_path):
        """Save results."""
        local_totallist_path = os.path.join(local_output_path, self._total_list_name)
        self._total_list.to_csv(local_totallist_path)

    def _mutate_serialnet(self, arch, num_mutate=3, expend_ratio=0, addstage_ratio=0.95, max_stages=6):
        """Swap & Expend operation in Serial-level searching.

        :param arch: base arch encode
        :type arch: str
        :param num_mutate: number of mutations
        :type num_mutate: int
        :param expend_ratio: threshold for the expand-block operation (applied when a uniform draw exceeds it)
        :type expend_ratio: float
        :param addstage_ratio: threshold for adding a new stage
        :type addstage_ratio: float
        :param max_stages: maximum number of stages allowed
        :type max_stages: int
        :return: arch encoding after mutation
        :rtype: str
        """
        def is_valid(arch):
            stages = arch.split('-')
            for stage in stages:
                if len(stage) == 0:
                    return False
            return True

        def expend(arc):
            idx = np.random.randint(low=1, high=len(arc))
            arc = arc[:idx] + '1' + arc[idx:]
            return arc, idx

        def swap(arc, len_step=3):
            is_not_valid = True
            arc_origin = copy.deepcopy(arc)
            temp = arc.split('-')
            num_insert = len(temp) - 1
            while is_not_valid or arc == arc_origin:
                next_start = 0
                arc = list(''.join(temp))
                for i in range(num_insert):
                    pos = arc_origin[next_start:].find('-') + next_start
                    assert arc_origin[pos] == '-', "Wrong '-' is found!"
                    max_step = min(len_step, max(len(temp[i]), len(temp[i + 1])))
                    step_range = list(range(-1 * max_step, max_step))
                    step = random.choice(step_range)
                    next_start = pos + 1
                    pos = pos + step
                    arc.insert(pos, '-')
                arc = ''.join(arc)
                is_not_valid = (not is_valid(arc))
            return arc

        arch_origin = arch
        success = False
        k = 0
        while not success:
            k += 1
            arch = arch_origin
            ops = []
            for i in range(num_mutate):
                op_idx = np.random.randint(low=0, high=3)
                adds_thresh_ = addstage_ratio if len(arch.split('-')) < max_stages else 1
                if op_idx == 0 and random.random() > expend_ratio:
                    arch, idx = expend(arch)
                    arch, idx = expend(arch)
                elif op_idx == 1:
                    arch = swap(arch)
                elif op_idx == 2 and random.random() > adds_thresh_:
                    arch = arch + '-1'
                    ops.append('add stage')
                else:
                    ops.append('Do Nothing.')
            success = arch != arch_origin
            flag = 'Success' if success else 'Failed'
            logging.info('Serial-level Sample{}: {}. {}.'.format(k + 1, ' -> '.join(ops), flag))
        return arch

    def _mutate_parallelnet(self, arch):
        """Mutate operation in Parallel-level searching.

        :param arch: base arch encode
        :type arch: str
        :return: parallel arch encode after mutate
        :rtype: str
        """
        def limited_random(num_stage):
            p = [0.4, 0.3, 0.2, 0.1]
            l = np.random.choice(4, size=num_stage, replace=True, p=p)
            l = [str(i) for i in l]
            return '-'.join(l)

        num_stage = len(arch.split('-'))
        success = False
        k = 0
        while not success:
            k += 1
            new_arch = limited_random(num_stage)
            success = new_arch != arch
            flag = 'Success' if success else 'Failed'
            logging.info('Parallel-level Sample{}: {}.'.format(k + 1, flag))
        return new_arch
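
For readers unfamiliar with the weighted pick in `select_from_remote` above, here is a self-contained toy run: each candidate gets a log-scaled weight from its mAP rank plus a second weight assigned by memory rank, and one index is drawn from the normalized distribution. The mAP and memory values below are invented for illustration.

import numpy as np


def normalization(x):
    total = float(sum(x))
    return [float(i) / total for i in x]


maps = [0.31, 0.33, 0.35]    # hypothetical mAP values of the top candidates
memory = [900, 1100, 700]    # hypothetical memory footprints
prob = [round(np.log(i + 1e-2), 2) for i in range(1, len(maps) + 1)]
rank_bonus = list(prob)      # snapshot of the rank-based weights
sorted_ind = sorted(range(len(maps)), key=lambda k: memory[k], reverse=True)
for idx, ind in enumerate(sorted_ind):
    prob[ind] += rank_bonus[idx]
ind = np.random.choice(len(maps), p=normalization(prob))
print("selected candidate:", ind)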
Code example #4
File: quant_ea.py  Project: zhwzhong/vega
class QuantEA(SearchAlgorithm):
    """Class of Evolution Algorithm used to Quant Example.

    :param search_space: search space
    :type search_space: SearchSpace
    """
    def __init__(self, search_space):
        super(QuantEA, self).__init__(search_space)
        self.length = self.policy.length
        self.num_individual = self.policy.num_individual
        self.num_generation = self.policy.num_generation
        self.x_axis = 'flops'
        self.y_axis = 'acc'
        self.random_models = self.policy.random_models
        self.codec = Codec(self.cfg.codec, search_space)
        self.bit_candidates = self.codec.search_space.bit_candidates
        self.random_count = 0
        self.ea_count = 0
        self.ea_epoch = 0
        self.step_path = FileOps.join_path(self.local_output_path,
                                           self.cfg.step_name)
        self.pd_file_name = FileOps.join_path(self.step_path, "performace.csv")
        self.pareto_front_file = FileOps.join_path(self.step_path,
                                                   "pareto_front.csv")
        self.pd_path = FileOps.join_path(self.step_path, "pareto_front")
        FileOps.make_dir(self.pd_path)

    def get_pareto_front(self):
        """Get pareto front from remote result file.

        :return: pareto front
        :rtype: dataframe
        """
        with open(self.pd_file_name, "r") as file:
            df = pd.read_csv(file)
        fitness = df[[self.x_axis, self.y_axis]].values.transpose()
        # acc2error
        fitness[1, :] = 1 - fitness[1, :]
        _, _, selected = SortAndSelectPopulation(fitness, self.num_individual)
        result = df.loc[selected, :]
        if self.ea_count % self.num_individual == 0:
            file_name = "{}_epoch.csv".format(str(self.ea_epoch))
            pd_result_file = FileOps.join_path(self.pd_path, file_name)
            with open(pd_result_file, "w") as file:
                result.to_csv(file, index=False)
            with open(self.pareto_front_file, "w") as file:
                result.to_csv(file, index=False)
            self.ea_epoch += 1
        return result

    def crossover(self, ind0, ind1):
        """Cross over operation in EA algorithm.

        :param ind0: individual 0
        :type ind0: list of int
        :param ind1: individual 1
        :type ind1: list of int
        :return: new individual 0, new individual 1
        :rtype: list of int, list of int
        """
        two_idxs = np.random.randint(0, self.length, 2)
        start_idx, end_idx = np.min(two_idxs), np.max(two_idxs)
        a_copy = ind0.copy()
        b_copy = ind1.copy()
        a_copy[start_idx:end_idx] = ind1[start_idx:end_idx]
        b_copy[start_idx:end_idx] = ind0[start_idx:end_idx]
        return a_copy, b_copy

    def mutatation(self, ind):
        """Mutate operation in EA algorithm.

        :param ind: individual
        :type ind: list of int
        :return: new individual
        :rtype: list of int
        """
        two_idxs = np.random.randint(0, self.length, 2)
        start_idx, end_idx = np.min(two_idxs), np.max(two_idxs)
        a_copy = ind.copy()
        for k in range(start_idx, end_idx):
            candidates_ = self.bit_candidates.copy()
            candidates_.remove(ind[k])
            a_copy[k] = random.choice(candidates_)
        return a_copy

    def search(self):
        """Search one NetworkDesc from search space.

        :return: search id, network desc
        :rtype: int, NetworkDesc
        """
        if self.random_count < self.random_models:
            return self.random_count, self._random_sample()
        pareto_front_results = self.get_pareto_front()
        pareto_front = pareto_front_results["encoding"].tolist()
        if len(pareto_front) < 2:
            encoding1, encoding2 = pareto_front[0], pareto_front[0]
        else:
            encoding1, encoding2 = random.sample(pareto_front, 2)
        choice = random.randint(0, 1)
        # mutate
        if choice == 0:
            encoding1List = str2list(encoding1)
            encoding_new = self.mutatation(encoding1List)
        # crossover
        else:
            encoding1List = str2list(encoding1)
            encoding2List = str2list(encoding2)
            encoding_new, _ = self.crossover(encoding1List, encoding2List)
        self.ea_count += 1
        net_desc = self.codec.decode(encoding_new)
        return self.random_count + self.ea_count, net_desc

    def _random_sample(self):
        """Choose one sample randomly.

        :return: network desc
        :rtype: NetworkDesc
        """
        individual = []
        for _ in range(self.length):
            individual.append(random.choice(self.bit_candidates))
        self.random_count += 1
        return self.codec.decode(individual)

    def update(self, worker_path):
        """Update QuantEA."""
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path, self.backup_base_path)

    @property
    def is_completed(self):
        """Whether to complete algorithm.

        :return: True if completed, False otherwise
        :rtype: bool
        """
        return self.random_count >= self.random_models and self.ea_epoch >= self.num_generation
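
To make the two EA operators above concrete, here is a self-contained toy run of the same two-point crossover and segment mutation on a made-up bit-width encoding; `bit_candidates` and `length` are assumptions for the demo, not the project's defaults.

import random

import numpy as np

bit_candidates = [2, 4, 8]   # hypothetical quantization bit-width choices
length = 10


def crossover(ind0, ind1):
    """Two-point crossover: swap the slice between two random cut points."""
    two_idxs = np.random.randint(0, length, 2)
    start, end = np.min(two_idxs), np.max(two_idxs)
    a_copy, b_copy = ind0.copy(), ind1.copy()
    a_copy[start:end] = ind1[start:end]
    b_copy[start:end] = ind0[start:end]
    return a_copy, b_copy


def mutate(ind):
    """Resample every position between two random cut points to a different bit width."""
    two_idxs = np.random.randint(0, length, 2)
    start, end = np.min(two_idxs), np.max(two_idxs)
    a_copy = ind.copy()
    for k in range(start, end):
        a_copy[k] = random.choice([c for c in bit_candidates if c != ind[k]])
    return a_copy


parent0 = [random.choice(bit_candidates) for _ in range(length)]
parent1 = [random.choice(bit_candidates) for _ in range(length)]
print(crossover(parent0, parent1))
print(mutate(parent0))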
Code example #5
File: esr_search.py  Project: zhwzhong/vega
class ESRSearch(SearchAlgorithm):
    """Evolutionary search algorithm of the efficient super-resolution."""

    def __init__(self, search_space=None):
        """Construct the ESR EA search class.

        :param search_space: config of the search space
        :type search_space: dictionary
        """
        super(ESRSearch, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)

        self.individual_num = self.policy.num_individual
        self.generation_num = self.policy.num_generation
        self.elitism_num = self.policy.num_elitism
        self.mutation_rate = self.policy.mutation_rate
        self.min_active = self.range.min_active
        self.max_params = self.range.max_params
        self.min_params = self.range.min_params

        self.indiv_count = 0
        self.evolution_count = 0
        self.initialize_pop()
        self.elitism = [ESRIndividual(self.codec, self.cfg) for _ in range(self.elitism_num)]
        self.elit_fitness = [0] * self.elitism_num
        self.fitness_pop = [0] * self.individual_num
        self.fit_state = [0] * self.individual_num

    @property
    def is_completed(self):
        """Tell whether the search process is completed.

        :return: True if completed, False otherwise
        :rtype: bool
        """
        return self.indiv_count > self.generation_num * self.individual_num

    def update_fitness(self, evals):
        """Update the fitness of each individual.

        :param evals: the evaluation results
        :type evals: list
        """
        for i in range(self.individual_num):
            self.pop[i].update_fitness(evals[i])

    def update_elitism(self, evaluations):
        """Update the elitism and its fitness.

        :param evaluations: evaluations result
        :type evaluations: list
        """
        popu_all = [ESRIndividual(self.codec, self.cfg) for _ in range(self.elitism_num + self.individual_num)]
        for i in range(self.elitism_num + self.individual_num):
            if i < self.elitism_num:
                popu_all[i].copy(self.elitism[i])
            else:
                popu_all[i].copy(self.pop[i - self.elitism_num])
        fitness_all = self.elit_fitness + evaluations
        sorted_ind = sorted(range(len(fitness_all)), key=lambda k: fitness_all[k])
        for i in range(self.elitism_num):
            self.elitism[i].copy(popu_all[sorted_ind[len(fitness_all) - 1 - i]])
            self.elit_fitness[i] = fitness_all[sorted_ind[len(fitness_all) - 1 - i]]
        logging.info('Generation: {}, updated elitism fitness: {}'.format(self.evolution_count, self.elit_fitness))

    def _log_data(self, net_info_type='active_only', pop=None, value=0):
        """Get the evolution and network information of children.

        :param net_info_type:  defaults to 'active_only'
        :type net_info_type: str
        :param pop: defaults to None
        :type pop: list
        :param value:  defaults to 0
        :type value: int
        :return: log_list
        :rtype: list
        """
        log_list = [value, pop.parameter, pop.flops]
        if net_info_type == 'active_only':
            log_list.append(pop.active_net_list())
        elif net_info_type == 'full':
            log_list += pop.gene.flatten().tolist()
        else:
            pass
        return log_list

    def save_results(self):
        """Save the results of evolution contains the information of pupulation and elitism."""
        step_name = Config(deepcopy(UserConfig().data)).general.step_name
        _path = FileOps.join_path(self.local_output_path, step_name)
        FileOps.make_dir(_path)
        arch_file = FileOps.join_path(_path, 'arch.txt')
        arch_child = FileOps.join_path(_path, 'arch_child.txt')
        sel_arch_file = FileOps.join_path(_path, 'selected_arch.npy')
        sel_arch = []
        with open(arch_file, 'a') as fw_a, open(arch_child, 'a') as fw_ac:
            writer_a = csv.writer(fw_a, lineterminator='\n')
            writer_ac = csv.writer(fw_ac, lineterminator='\n')
            writer_ac.writerow(['Population Iteration: ' + str(self.evolution_count + 1)])
            for c in range(self.individual_num):
                writer_ac.writerow(
                    self._log_data(net_info_type='active_only', pop=self.pop[c],
                                   value=self.pop[c].fitness))

            writer_a.writerow(['Population Iteration: ' + str(self.evolution_count + 1)])
            for c in range(self.elitism_num):
                writer_a.writerow(self._log_data(net_info_type='active_only',
                                                 pop=self.elitism[c],
                                                 value=self.elit_fitness[c]))
                sel_arch.append(self.elitism[c].gene)
        sel_arch = np.stack(sel_arch)
        np.save(sel_arch_file, sel_arch)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path, self.backup_base_path)

    def parent_select(self, parent_num=2, select_type='Tournament'):
        """Select parent from a population with Tournament or Roulette.

        :param parent_num: number of parents
        :type parent_num: int
        :param select_type: select_type, defaults to 'Tournament'
        :type select_type: str
        :return: the selected parent individuals
        :rtype: list
        """
        popu_all = [ESRIndividual(self.codec, self.cfg) for _ in range(self.elitism_num + self.individual_num)]
        parent = [ESRIndividual(self.codec, self.cfg) for _ in range(parent_num)]
        fitness_all = self.elit_fitness
        for i in range(self.elitism_num + self.individual_num):
            if i < self.elitism_num:
                popu_all[i].copy(self.elitism[i])
            else:
                popu_all[i].copy(self.pop[i - self.elitism_num])
                fitness_all = fitness_all + [popu_all[i].fitness]
        fitness_all = np.asarray(fitness_all)
        if select_type == 'Tournament':
            for i in range(parent_num):
                tourn = sample(range(len(popu_all)), 2)
                if fitness_all[tourn[0]] >= fitness_all[tourn[1]]:
                    parent[i].copy(popu_all[tourn[0]])
                    fitness_all[tourn[0]] = 0
                else:
                    parent[i] = popu_all[tourn[1]]
                    fitness_all[tourn[1]] = 0
        elif select_type == 'Roulette':
            eval_submean = fitness_all - np.min(fitness_all)
            eval_norm = eval_submean / sum(eval_submean)
            eva_threshold = np.cumsum(eval_norm)
            for i in range(parent_num):
                ran = random()
                selec_id = bisect_right(eva_threshold, ran)
                parent[i].copy(popu_all[selec_id])
                eval_submean[selec_id] = 0
                eval_norm = eval_submean / sum(eval_submean)
                eva_threshold = np.cumsum(eval_norm)
        else:
            logging.info('Wrong selection type')
        return parent

    def initialize_pop(self):
        """Initialize the population of first generation."""
        self.pop = [ESRIndividual(self.codec, self.cfg) for _ in range(self.individual_num)]
        for i in range(self.individual_num):
            while self.pop[i].active_num < self.min_active:
                self.pop[i].mutation_using(self.mutation_rate)
            while self.pop[i].parameter > self.max_params or self.pop[i].parameter < self.min_params:
                self.pop[i].mutation_node(self.mutation_rate)

    def get_mutate_child(self, muta_num):
        """Generate the mutated children of the next offspring with mutation operation.

        :param muta_num: number of mutated children
        :type muta_num: int
        """
        for i in range(muta_num):
            if int(self.individual_num / 2) == len(self.elitism):
                self.pop[i].copy(self.elitism[i])
            else:
                self.pop[i].copy(sample(self.elitism, 1)[0])
            self.pop[i].mutation_using(self.mutation_rate)
            while self.pop[i].active_num < self.min_active:
                self.pop[i].mutation_using(self.mutation_rate)
            self.pop[i].mutation_node(self.mutation_rate)
            while self.pop[i].parameter > self.max_params or self.pop[i].parameter < self.min_params:
                self.pop[i].mutation_node(self.mutation_rate)

    def get_cross_child(self, muta_num):
        """Generate the children of the next offspring with crossover operation.

        :param muta_num: number of mutated children, used as the offset into the population
        :type muta_num: int
        """
        for i in range(int(self.individual_num / 4)):
            pop_id = muta_num + i * 2
            father, mother = self.parent_select(2, 'Roulette')
            length = np.random.randint(4, int(father.gene.shape[0] / 2))
            location = np.random.randint(0, father.gene.shape[0] - length)
            gene_1 = father.gene.copy()
            gene_2 = mother.gene.copy()
            gene_1[location:(location + length), :] = gene_2[location:(location + length), :]
            gene_2[location:(location + length), :] = father.gene[location:(location + length), :]
            self.pop[pop_id].update_gene(gene_1)
            self.pop[pop_id + 1].update_gene(gene_2)
            while self.pop[pop_id].active_num < self.min_active:
                self.pop[pop_id].mutation_using(self.mutation_rate)
            param = self.pop[pop_id].parameter
            while param > self.max_params or param < self.min_params:
                self.pop[pop_id].mutation_node(self.mutation_rate)
                param = self.pop[pop_id].parameter
            while self.pop[pop_id + 1].active_num < self.min_active:
                self.pop[pop_id + 1].mutation_using(self.mutation_rate)
            param = self.pop[pop_id + 1].parameter
            while param > self.max_params or param < self.min_params:
                self.pop[pop_id + 1].mutation_node(self.mutation_rate)
                param = self.pop[pop_id + 1].parameter

    def reproduction(self):
        """Generate the new offsprings."""
        muta_num = self.individual_num - (self.individual_num // 4) * 2
        self.get_mutate_child(muta_num)
        self.get_cross_child(muta_num)

    def update(self, local_worker_path):
        """Update function.

        :param local_worker_path: the local path that saved `performance.txt`.
        :type local_worker_path: str
        """
        update_id = int(os.path.basename(os.path.abspath(local_worker_path)))
        file_path = os.path.join(local_worker_path, 'performance.txt')
        with open(file_path, "r") as pf:
            fitness = float(pf.readline())
            self.fitness_pop[(update_id - 1) % self.individual_num] = fitness
            self.fit_state[(update_id - 1) % self.individual_num] = 1

    def get_fitness(self):
        """Get the evalutation of each individual.

        :return: a list of evaluations
        :rtype: list
        """
        pd_path = os.path.join(self.local_output_path, 'population_fitness.csv')
        with open(pd_path, "r") as file:
            df = pd.read_csv(file)
        fitness_all = df['PSNR'].values
        fitness = fitness_all[fitness_all.size - self.individual_num:]
        return list(fitness)

    def search(self):
        """Search one random model.

        :return: current number of samples, and the model
        :rtype: int and class
        """
        if self.indiv_count > 0 and self.indiv_count % self.individual_num == 0:
            if np.sum(np.asarray(self.fit_state)) < self.individual_num:
                return None, None
            else:
                self.update_fitness(self.fitness_pop)
                self.update_elitism(self.fitness_pop)
                self.save_results()
                self.reproduction()
                self.evolution_count += 1
                self.fitness_pop = [0] * self.individual_num
                self.fit_state = [0] * self.individual_num
        current_indiv = self.pop[self.indiv_count % self.individual_num]
        indiv_cfg = self.codec.decode(current_indiv)
        self.indiv_count += 1
        logging.info('model parameters:{}, model flops:{}'.format(current_indiv.parameter, current_indiv.flops))
        logging.info('model arch:{}'.format(current_indiv.active_net_list()))
        return self.indiv_count, NetworkDesc(indiv_cfg)
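
The 'Roulette' branch of `parent_select` above selects by cumulative probability. A minimal sketch of that mechanism in isolation, with invented fitness values:

from bisect import bisect_right
from random import random

import numpy as np

fitness = np.asarray([3.0, 1.0, 2.0, 4.0])   # hypothetical fitness values
shifted = fitness - np.min(fitness)          # shift so the worst individual gets weight 0
thresholds = np.cumsum(shifted / np.sum(shifted))
selected = bisect_right(thresholds, random())   # a uniform draw falls into one cumulative bin
print("selected index:", selected)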
Code example #6
File: sr_mutate.py  Project: zhwzhong/vega
class SRMutate(SearchAlgorithm):
    """Search algorithm of the mutated structures."""
    def __init__(self, search_space=None):
        """Construct the class SRMutate.

        :param search_space: Config of the search space
        """
        super(SRMutate, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.max_sample = self.policy.num_sample
        self.num_mutate = self.policy.num_mutate
        self.sample_count = 0

    @property
    def is_completed(self):
        """Tell whether the search process is completed.

        :return: True if completed, False otherwise
        """
        return self.sample_count >= self.max_sample

    def search(self):
        """Search one mutated model.

        :return: current number of samples, and the model
        """
        search_desc = self.search_space.search_space.custom
        pareto_front_folder = FileOps.join_path(self.local_base_path, "result")
        if 'pareto_folder' in self.search_space.cfg and self.search_space.cfg.pareto_folder is not None:
            pareto_front_folder = self.search_space.cfg.pareto_folder.replace(
                "{local_base_path}", self.local_base_path)
        pareto_front_df = pd.read_csv(
            FileOps.join_path(pareto_front_folder, "pareto_front.csv"))
        code_to_mutate = random.choice(pareto_front_df['Code'])

        current_mutate, code_mutated = 0, code_to_mutate
        num_candidates = len(search_desc["candidates"])
        while current_mutate < self.num_mutate:
            code_new = self.mutate_once(code_mutated, num_candidates)
            if code_new != code_mutated:
                current_mutate += 1
                code_mutated = code_new

        logging.info("Mutate from {} to {}".format(code_to_mutate,
                                                   code_mutated))
        search_desc['code'] = code_mutated
        search_desc['method'] = "mutate"
        search_desc = self.codec.decode(search_desc)
        self.sample_count += 1
        return self.sample_count, NetworkDesc(self.search_space.search_space)

    def mutate_once(self, code, num_largest):
        """Do one mutate.

        :param code: original code
        :param num_largest: number of candidates (largest number in code)
        :return: the mutated code
        """
        fun = random.choice(
            [self.flip_once, self.insert_once, self.erase, self.swap_once])
        return fun(code, num_largest)

    @staticmethod
    def flip_once(code, num_largest):
        """Flip one block.

        :param code: original code
        :param num_largest: number of candidates (largest number in code)
        :return: the mutated code
        """
        index_to_flip = random.choice(
            [index for index in range(len(code)) if code[index] != '+'])
        flip_choices = list(map(str, range(num_largest)))
        flip_choices.remove(code[index_to_flip])
        ch_flipped = random.choice(flip_choices)
        return code[:index_to_flip] + ch_flipped + code[index_to_flip + 1:]

    @staticmethod
    def insert_once(code, num_largest):
        """Insert one block.

        :param code: original code
        :param num_largest: number of candidates (largest number in code)
        :return: the mutated code
        """
        ch_insert = random.choice(list(map(str, range(num_largest))))
        place_insert = random.randint(0, len(code))
        return code[:place_insert] + ch_insert + code[place_insert:]

    @staticmethod
    def erase(code, num_largest):
        """Erase one block.

        :param code: original code
        :param num_largest: number of candidates (largest number in code)
        :return: the mutated code
        """
        place_choices, index = list(), 0
        while index < len(code):
            if code[index] == '+':
                index += 3
            else:
                place_choices.append(index)
                index += 1
        if len(place_choices) == 0:
            return code
        place_chosen = random.choice(place_choices)
        return code[:place_chosen] + code[place_chosen + 1:]

    @staticmethod
    def swap_once(code, num_largest):
        """Swap two adjacent blocks.

        :param code: original code
        :param num_largest: number of candidates (largest number in code)
        :return: the mutated code
        """
        parts, index = list(), 0
        while index < len(code):
            if code[index] == '+':
                parts.append(code[index:index + 3])
                index += 3
            else:
                parts.append(code[index])
                index += 1
        if len(parts) < 2:
            return code
        valid_choices = [
            index for index in range(len(parts) - 2)
            if parts[index] != parts[index + 1]
        ]
        if len(valid_choices) == 0:
            return code
        place_chosen = random.choice(valid_choices)
        parts[place_chosen], parts[place_chosen + 1] = parts[place_chosen + 1], parts[place_chosen]
        return ''.join(parts)

    def update(self, local_worker_path):
        """Update function.

        :param local_worker_path: Local path that saved `performance.txt`
        :type local_worker_path: str
        """
        pass
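
The codes handled by `erase` and `swap_once` above mix single-character blocks with three-character '+xy' groups. A standalone sketch of that tokenisation on a made-up code string:

def split_blocks(code):
    """Split a code string into blocks; a '+' starts a three-character block."""
    parts, index = [], 0
    while index < len(code):
        if code[index] == '+':
            parts.append(code[index:index + 3])
            index += 3
        else:
            parts.append(code[index])
            index += 1
    return parts


print(split_blocks("01+231"))   # ['0', '1', '+23', '1']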
Code example #7
File: backbone_nas.py  Project: zhwzhong/vega
class BackboneNas(SearchAlgorithm):
    """BackboneNas.

    :param search_space: input search_space
    :type: SearchSpace
    """
    def __init__(self, search_space=None):
        """Init BackboneNas."""
        super(BackboneNas, self).__init__(search_space)
        # ea or random
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.num_mutate = self.policy.num_mutate
        self.random_ratio = self.policy.random_ratio
        self.max_sample = self.range.max_sample
        self.min_sample = self.range.min_sample
        self.sample_count = 0
        logging.info("inited BackboneNas")
        self.pareto_front = ParetoFront(self.cfg)
        self.random_search = RandomSearchAlgorithm(self.search_space)
        self._best_desc_file = 'nas_model_desc.json'
        if 'best_desc_file' in self.cfg and self.cfg.best_desc_file is not None:
            self._best_desc_file = self.cfg.best_desc_file

    @property
    def is_completed(self):
        """Check if NAS is finished."""
        return self.sample_count > self.max_sample

    def search(self):
        """Search in search_space and return a sample."""
        sample = {}
        while sample is None or 'code' not in sample:
            pareto_dict = self.pareto_front.get_pareto_front()
            pareto_list = list(pareto_dict.values())
            if (self.pareto_front.size < self.min_sample
                    or random.random() < self.random_ratio
                    or len(pareto_list) == 0):
                sample_desc = self.random_search.search()
                sample = self.codec.encode(sample_desc)
            else:
                sample = pareto_list[0]
            if sample is not None and 'code' in sample:
                code = sample['code']
                code = self.ea_sample(code)
                sample['code'] = code
            if not self.pareto_front._add_to_board(id=self.sample_count + 1,
                                                   config=sample):
                sample = None
        self.sample_count += 1
        logging.info(sample)
        sample_desc = self.codec.decode(sample)
        return self.sample_count, NetworkDesc(sample_desc)

    def random_sample(self):
        """Random sample from search_space."""
        sample_desc = self.random_search.search()
        sample = self.codec.encode(sample_desc, is_random=True)
        return sample

    def ea_sample(self, code):
        """Use EA op to change a arch code.

        :param code: list of code for arch
        :type code: list
        :return: changed code
        :rtype: list
        """
        new_arch = code.copy()
        self._insert(new_arch)
        self._remove(new_arch)
        self._swap(new_arch[0], self.num_mutate // 2)
        self._swap(new_arch[1], self.num_mutate // 2)
        return new_arch

    def update(self, worker_result_path):
        """Use train and evaluate result to update algorithm.

        :param worker_result_path: current result path
        :type: str
        """
        step_name = os.path.basename(os.path.dirname(worker_result_path))
        config_id = int(os.path.basename(worker_result_path))
        performance = self._get_performance(step_name, config_id)
        logging.info("update performance={}".format(performance))
        self.pareto_front.add_pareto_score(config_id, performance)
        self.save_output(self.local_output_path)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_base_path, self.backup_base_path)

    def _get_performance(self, step_name, worker_id):
        saved_folder = self.get_local_worker_path(step_name, worker_id)
        performance_file = FileOps.join_path(saved_folder, "performance.txt")
        if not os.path.isfile(performance_file):
            logging.info("Performance file is not exited, file={}".format(
                performance_file))
            return []
        with open(performance_file, 'r') as f:
            performance = []
            for line in f.readlines():
                line = line.strip()
                if line == "":
                    continue
                data = json.loads(line)
                if isinstance(data, list):
                    data = data[0]
                performance.append(data)
            logging.info("performance={}".format(performance))
        return performance

    def _insert(self, arch):
        """Random insert to arch code.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        """
        idx = np.random.randint(low=0, high=len(arch[0]))
        arch[0].insert(idx, 1)
        idx = np.random.randint(low=0, high=len(arch[1]))
        arch[1].insert(idx, 1)
        return arch

    def _remove(self, arch):
        """Random remove one from arch code.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        """
        # random pop arch[0]
        ones_index = [i for i, char in enumerate(arch[0]) if char == 1]
        idx = random.choice(ones_index)
        arch[0].pop(idx)
        # random pop arch[1]
        ones_index = [i for i, char in enumerate(arch[1]) if char == 1]
        idx = random.choice(ones_index)
        arch[1].pop(idx)
        return arch

    def _swap(self, arch, R):
        """Random swap one in arch code.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        """
        while True:
            not_ones_index = [i for i, char in enumerate(arch) if char != 1]
            idx = random.choice(not_ones_index)
            r = random.randint(1, R)
            direction = -r if random.random() > 0.5 else r
            try:
                arch[idx], arch[idx + direction] = arch[idx + direction], arch[idx]
                break
            except Exception:
                continue
        return arch

    def is_valid(self, arch):
        """Check if valid for arch code.

        :param arch: input arch code
        :type arch: list
        :return: if current arch is valid
        :rtype: bool
        """
        return True

    def save_output(self, output_path):
        """Save result to output_path.

        :param output_path: the result save output_path
        :type: str
        """
        try:
            self.pareto_front.sieve_board.to_csv(
                os.path.join(output_path, 'nas_score_board.csv'),
                index=None, header=True)
        except Exception as e:
            logging.error("write nas_score_board.csv error:{}".format(str(e)))
        try:
            pareto_dict = self.pareto_front.get_pareto_front()
            if len(pareto_dict) > 0:
                id = list(pareto_dict.keys())[0]
                net_desc = pareto_dict[id]
                net_desc = self.codec.decode(net_desc)
                with open(os.path.join(output_path, self._best_desc_file),
                          'w') as fp:
                    json.dump(net_desc, fp)
        except Exception as e:
            logging.error("write best model error:{}".format(str(e)))
Code example #8
File: esr_evaluator.py  Project: zhwzhong/vega
class EsrGpuEvaluator(GpuEvaluator):
    """Evaluator is a gpu evaluator.

    :param args: arguments from user and default config file
    :type args: dict or Config, default to None
    :param train_data: training dataset
    :type train_data: torch dataset, default to None
    :param valid_data: validate dataset
    :type valid_data: torch dataset, default to None
    :param worker_info: dict of worker info for workers that have finished training.
    :type worker_info: dict or None.

    """
    def __init__(self, worker_info=None, model=None, hps=None, **kwargs):
        """Init GpuEvaluator."""
        super(EsrGpuEvaluator, self).__init__(self.cfg)

    def _init_model(self):
        """Initialize the model architecture for full train step.

        :return: train model
        :rtype: class
        """
        model_cfg = ClassFactory.__configs__.get('model')
        if 'model_desc' in model_cfg and model_cfg.model_desc is not None:
            model_desc = model_cfg.model_desc
        else:
            raise ValueError('Model_desc is None for evaluator')
        search_space = Config({"search_space": model_desc})
        self.codec = Codec(self.cfg.codec, search_space)
        self._get_selected_arch()
        indiv_cfg = self.codec.decode(self.elitism)
        logger.info('Model arch:{}'.format(self.elitism.active_net_list()))
        self.model_desc = self.elitism.active_net_list()
        net_desc = NetworkDesc(indiv_cfg)
        model = net_desc.to_model()
        return model

    def _get_selected_arch(self):
        self.elitism = ESRIndividual(self.codec, deepcopy(self.cfg))
        if "model_arch" in self.cfg and self.cfg.model_arch is not None:
            self.elitism.update_gene(self.cfg.model_arch)
        else:
            sel_arch_file = self.cfg.model_desc_file
            sel_arch = np.load(sel_arch_file)
            self.elitism.update_gene(sel_arch[0])

    def valid(self, loader):
        """Validate one step of model.

        :param loader: validation dataloader
        """
        metrics = Metrics(self.cfg.metric)
        self.model.eval()
        with torch.no_grad():
            for batch in loader:
                img_lr, img_hr = batch["LR"].cuda(), batch["HR"].cuda()
                image_sr = self.model(img_lr)
                metrics(image_sr, img_hr)
        performance = metrics.results
        logging.info('Valid metric: {}'.format(performance))
        return performance
Code example #9
class AdelaideMutate(SearchAlgorithm):
    """Search algorithm of the random structures."""

    def __init__(self, search_space=None):
        """Construct the AdelaideMutate class.

        :param search_space: Config of the search space
        """
        super(AdelaideMutate, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.max_sample = self.cfg.max_sample
        self.sample_count = 0
        self._copy_needed_file()

    def _copy_needed_file(self):
        if "pareto_front_file" not in self.cfg or self.cfg.pareto_front_file is None:
            raise FileNotFoundError("Config item paretor_front_file not found in config file.")
        init_pareto_front_file = self.cfg.pareto_front_file.replace("{local_base_path}", self.local_base_path)
        self.pareto_front_file = FileOps.join_path(self.local_output_path, self.cfg.step_name, "pareto_front.csv")
        FileOps.make_base_dir(self.pareto_front_file)
        FileOps.copy_file(init_pareto_front_file, self.pareto_front_file)
        if "random_file" not in self.cfg or self.cfg.random_file is None:
            raise FileNotFoundError("Config item random_file not found in config file.")
        init_random_file = self.cfg.random_file.replace("{local_base_path}", self.local_base_path)
        self.random_file = FileOps.join_path(self.local_output_path, self.cfg.step_name, "random.csv")
        FileOps.copy_file(init_random_file, self.random_file)

    @property
    def is_completed(self):
        """Tell whether the search process is completed.

        :return: True if completed, False otherwise
        """
        return self.sample_count >= self.max_sample

    def search(self):
        """Search one random model.

        :return: current number of samples, and the model
        """
        search_desc = self.search_space.search_space.custom
        pareto_front_df = pd.read_csv(self.pareto_front_file)
        num_ops = len(search_desc.op_names)
        upper_bounds = [num_ops, 2, 2, num_ops, num_ops, 5, 5, num_ops, num_ops,
                        8, 8, num_ops, num_ops, 4, 4, 5, 5, 6, 6]
        code_to_mutate = random.choice(pareto_front_df['Code'])
        index = random.randrange(len(upper_bounds))
        choices = list(range(upper_bounds[index]))
        choices.pop(int(code_to_mutate[index + 1], 36))
        choice = random.choice(choices)
        code_mutated = code_to_mutate[:index + 1] + str(choice) + code_to_mutate[index + 2:]
        search_desc['code'] = code_mutated
        search_desc['method'] = "mutate"
        logging.info("Mutate from {} to {}".format(code_to_mutate, code_mutated))
        search_desc = self.codec.decode(search_desc)
        self.sample_count += 1
        return self.sample_count, NetworkDesc(self.search_space.search_space)

    def update(self, local_worker_path):
        """Update function.

        :param local_worker_path: Local path that saved `performance.txt`
        :type local_worker_path: str
        """
        pass
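
Finally, a toy rerun of the single-position mutation used in `search` above: one digit of the code (read as a base-36 digit via `int(..., 36)`) is replaced by a different value below its per-position upper bound. The code string and bounds here are invented and far shorter than the real 19-entry encoding.

import random

upper_bounds = [3, 2, 2, 3]   # hypothetical per-position upper bounds
code = "x2101"                # the leading character is left untouched, as in the class above

index = random.randrange(len(upper_bounds))
choices = list(range(upper_bounds[index]))
choices.pop(int(code[index + 1], 36))   # drop the current value so the digit always changes
choice = random.choice(choices)
code_mutated = code[:index + 1] + str(choice) + code[index + 2:]
print(code, "->", code_mutated)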