コード例 #1
0
ファイル: adelaide_random.py プロジェクト: zhwzhong/vega
class AdelaideRandom(SearchAlgorithm):
    """Random search algorithm that samples decoder structures."""

    def __init__(self, search_space=None):
        """Build the sampler: keep the search space, create the codec, reset the counter.

        :param search_space: Config of the search space
        """
        super(AdelaideRandom, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.max_sample = self.cfg.max_sample
        self.sample_count = 0

    @property
    def is_completed(self):
        """Report whether the sampling budget has been used up.

        :return: True once `sample_count` has reached `max_sample`, False otherwise
        """
        budget_reached = self.sample_count >= self.max_sample
        return budget_reached

    def search(self):
        """Draw one random architecture and write it into the search-space description.

        The description's `config` entry is set to `[cell, connections]`, where
        `cell` is one operation index followed by three
        `[input, input, op, op]` groups and `connections` is three index pairs.

        :return: current number of samples, and the model
        """
        desc = self.search_space.search_space.custom
        op_total = len(desc.op_names)

        # Random draws happen in the order ops -> inputs -> connections.
        op_ids = [random.randrange(op_total) for _ in range(7)]
        # Inclusive upper bounds for the input indices are 1, 4 and 7 per level.
        input_ids = [random.randint(0, (level + 2) * 3 - 5)
                     for level in range(3)
                     for _ in range(2)]
        # Inclusive upper bounds for the connection indices are 3, 4 and 5 per level.
        conn_ids = [random.randint(0, level + 3)
                    for level in range(3)
                    for _ in range(2)]

        cell = [op_ids[0]]
        for k in range(3):
            first = 2 * k
            cell.append([input_ids[first], input_ids[first + 1],
                         op_ids[first + 1], op_ids[first + 2]])
        connections = [[conn_ids[2 * k], conn_ids[2 * k + 1]] for k in range(3)]

        desc['config'] = [cell, connections]
        desc['method'] = "random"
        # The encoded value is not used here; the model is built from the search space.
        self.codec.encode(desc)

        self.sample_count += 1
        return self.sample_count, NetworkDesc(self.search_space.search_space)

    def update(self, local_worker_path):
        """Do nothing: this sampler does not consume any feedback.

        :param local_worker_path: Local path that saved `performance.txt`
        :type local_worker_path: str
        """
コード例 #2
0
ファイル: sr_random.py プロジェクト: zhwzhong/vega
class SRRandom(SearchAlgorithm):
    """Random search algorithm that samples block structures."""

    def __init__(self, search_space=None):
        """Build the sampler: keep the search space, create the codec, reset the counter.

        :param search_space: Config of the search space
        """
        super(SRRandom, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.max_sample = self.policy.num_sample
        self.sample_count = 0

    @property
    def is_completed(self):
        """Report whether the sampling budget has been used up.

        :return: True once `sample_count` has reached `max_sample`, False otherwise
        """
        budget_reached = self.sample_count >= self.max_sample
        return budget_reached

    def search(self):
        """Draw one random block layout and write it into the search-space description.

        A random number of plain blocks is chosen from the candidates, then a
        random number of two-candidate lists ("cib" entries) is inserted at
        random positions of that layout.

        :return: current number of samples, and the model
        """
        desc = self.search_space.search_space.custom
        block_total = random.randint(*desc.block_range)
        cib_total = random.randint(*desc.cib_range)
        pool = desc.candidates

        layout = []
        for _ in range(block_total):
            layout.append(random.choice(pool))

        for _ in range(cib_total):
            pair = [random.choice(pool), random.choice(pool)]
            # randint is inclusive, so the pair may also land at the very end.
            slot = random.randint(0, len(layout))
            layout.insert(slot, pair)

        desc['blocks'] = layout
        desc['method'] = "random"
        # The encoded value is not used here; the model is built from the search space.
        self.codec.encode(desc)

        self.sample_count += 1
        return self.sample_count, NetworkDesc(self.search_space.search_space)

    def update(self, local_worker_path):
        """Do nothing: this sampler does not consume any feedback.

        :param local_worker_path: Local path that saved `performance.txt`
        :type local_worker_path: str
        """
コード例 #3
0
ファイル: sp_nas.py プロジェクト: zhwzhong/vega
class SpNas(SearchAlgorithm):
    """Search Algorithm Stage of SPNAS.

    Architecture codes handled here are strings of the form
    ``<block_type>_<serial>_<parallel>``; the serial and parallel parts are
    themselves ``-``-separated per-stage strings.
    """

    def __init__(self, search_space=None):
        """Init SpNas, optionally seeding it from a previous search result.

        When ``last_search_result`` is present in the config, that result file
        is loaded, a base architecture is selected from it and stored in
        ``self.init_code`` as the starting point of the search.

        :param search_space: input search space
        """
        super(SpNas, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.sample_level = self.cfg.sample_level
        self.max_sample = self.cfg.max_sample
        self.max_optimal = self.cfg.max_optimal
        self._total_list_name = self.cfg.total_list
        self.serial_settings = self.cfg.serial_settings

        self._total_list = ListDict()
        self.sample_count = 0
        self.init_code = None
        remote_output_path = FileOps.join_path(self.local_output_path, self.cfg.step_name)

        if 'last_search_result' in self.cfg:
            last_search_file = self.cfg.last_search_result
            assert FileOps.exists(os.path.join(remote_output_path, last_search_file)
                                  ), "Not found serial results!"
            # self.download_task_folder()
            last_search_results = os.path.join(self.local_output_path, last_search_file)
            last_search_results = ListDict.load_csv(last_search_results)
            pre_worker_id, pre_arch = self.select_from_remote(self.max_optimal, last_search_results)
            # re-write config template
            if self.cfg.regnition:
                self.codec.config_template['model']['backbone']['reignition'] = True
                assert FileOps.exists(os.path.join(remote_output_path,
                                                   pre_arch + '_imagenet.pth')
                                      ), "Not found {} pretrained .pth file!".format(pre_arch)
                pretrained_pth = os.path.join(self.local_output_path, pre_arch + '_imagenet.pth')
                self.codec.config_template['model']['pretrained'] = pretrained_pth
                pre_worker_id = -1
            # update config template
            self.init_code = dict(arch=pre_arch,
                                  pre_arch=pre_arch.split('_')[1],
                                  pre_worker_id=pre_worker_id)

        logging.info("inited SpNas {}-level search...".format(self.sample_level))

    @property
    def is_completed(self):
        """Check sampling if finished.

        :return: True is completed, or False otherwise
        :rtype: bool
        """
        return self.sample_count > self.max_sample

    @property
    def num_samples(self):
        """Get the number of sampled architectures.

        :return: The number of sampled architectures
        :rtype: int
        """
        return len(self._total_list)

    def select_from_remote(self, max_optimal, total_list):
        """Select base model to mutate.

        The first ``max_optimal`` records of ``total_list`` sorted by ``'mAP'``
        are the candidates. With more than one candidate allowed, each record
        gets a weight that is the sum of a log-rank term for its position in
        the mAP ordering and a log-rank term for its position in the
        descending ``'memory'`` ordering; one record is then drawn with
        probability proportional to that weight.

        :param max_optimal: number of top records to choose from
        :type max_optimal: int
        :param total_list: record table providing ``sort`` and column indexing
        :return: worker id and arch encode
        :rtype: int, str
        """
        def normalization(x):
            sum_ = sum(x)
            return [float(i) / float(sum_) for i in x]

        # rank with mAP and memory
        top_ = total_list.sort('mAP')[:max_optimal]
        if max_optimal > 1:
            prob = [round(np.log(i + 1e-2), 2) for i in range(1, len(top_) + 1)]
            # Snapshot the base weights. `prob` is updated in place below, so
            # reading from an alias would feed already-summed values back in.
            prob_temp = list(prob)
            sorted_ind = sorted(range(len(top_)), key=lambda k: top_['memory'][k], reverse=True)
            for idx, ind in enumerate(sorted_ind):
                prob[ind] += prob_temp[idx]
            ind = np.random.choice(len(top_), p=normalization(prob))
            worker_id, arch = top_['worker_id', 'arch'][ind]
        else:
            worker_id, arch = top_['worker_id', 'arch'][0]
        return worker_id, arch

    def search(self):
        """Search a sample.

        While no result has been recorded yet, ``self.init_code`` is used as
        is. Afterwards a base architecture is selected from the recorded
        results and mutated until an arch not yet present in the record list
        is produced.

        :return: sample count and info
        :rtype: int, dict
        :raises ValueError: if ``sample_level`` is neither 'serial' nor 'parallel'
        """
        code = self.init_code
        if self.num_samples > 0:
            pre_worker_id, pre_arch = self.select_from_remote(self.max_optimal, self._total_list)
            block_type, pre_serial, pre_paral = pre_arch.split('_')

            success = False
            while not success:
                serialnet, parallelnet = pre_serial, pre_paral
                if self.sample_level == 'serial':
                    serialnet = self._mutate_serialnet(serialnet, **self.serial_settings)
                    # Serial-level search resets the parallel part to all zeros.
                    parallelnet = '-'.join(['0'] * len(serialnet.split('-')))
                elif self.sample_level == 'parallel':
                    parallelnet = self._mutate_parallelnet(parallelnet)
                    # NOTE(review): requires init_code to have been set in __init__.
                    pre_worker_id = self.init_code['pre_worker_id']
                else:
                    raise ValueError("Unknown type of sample level")
                arch = self.codec.encode(block_type, serialnet, parallelnet)
                if arch not in self._total_list['arch'] or len(self._total_list['arch']) == 0:
                    success = True
            code = dict(arch=arch,
                        pre_arch=pre_serial,
                        pre_worker_id=pre_worker_id)

        self.sample_count += 1
        logging.info("The {}-th successfully sampling result: {}".format(self.sample_count, code))
        net_desc = self.codec.decode(code)
        return self.sample_count, net_desc

    def update(self, worker_result_path):
        """Update sampler with the performance record of a finished worker.

        :param worker_result_path: result folder of the worker
        :type worker_result_path: str
        """
        performance_file = self.performance_path(worker_result_path)
        logging.info(
            "SpNas.update(), performance file={}".format(performance_file))
        info = FileOps.load_pickle(performance_file)
        if info is not None:
            self._total_list.append(info)
        else:
            logging.info("SpNas.update(), file is not exited, "
                         "performance file={}".format(performance_file))
        self.save_output(self.local_output_path)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path, self.backup_base_path)

    def performance_path(self, worker_result_path):
        """Get performance path, creating the ``performance`` folder if missing.

        :param worker_result_path: result folder of the worker
        :type worker_result_path: str
        :return: path of ``performance.pkl`` inside the ``performance`` folder
        :rtype: str
        """
        performance_dir = os.path.join(worker_result_path, 'performance')
        if not os.path.exists(performance_dir):
            FileOps.make_dir(performance_dir)
        return os.path.join(performance_dir, 'performance.pkl')

    def save_output(self, local_output_path):
        """Save the list of all recorded results as csv.

        :param local_output_path: folder the csv file is written to
        :type local_output_path: str
        """
        local_totallist_path = os.path.join(local_output_path, self._total_list_name)
        self._total_list.to_csv(local_totallist_path)

    def _mutate_serialnet(self, arch, num_mutate=3, expend_ratio=0, addstage_ratio=0.95, max_stages=6):
        """Swap & Expend operation in Serial-level searching.

        Rounds of ``num_mutate`` random operations are repeated until the
        resulting encode differs from the input.

        :param arch: base arch encode
        :type arch: str
        :param num_mutate: number of mutate
        :type num_mutate: int
        :param expend_ratio: probability of expend block
        :type expend_ratio: float
        :param addstage_ratio: probability of expand new stage
        :type addstage_ratio: float
        :param max_stages:  max stage allowed to expand
        :type max_stages: int
        :return: arch encode after mutate
        :rtype: str
        """
        def is_valid(arch):
            # An encode is valid when no stage between '-' separators is empty.
            stages = arch.split('-')
            for stage in stages:
                if len(stage) == 0:
                    return False
            return True

        def expend(arc):
            # Insert a '1' at a random position after the first character.
            idx = np.random.randint(low=1, high=len(arc))
            arc = arc[:idx] + '1' + arc[idx:]
            return arc, idx

        def swap(arc, len_step=3):
            # Shift every '-' separator by a random step until a valid encode
            # different from the input is produced.
            is_not_valid = True
            arc_origin = arc  # str is immutable, no copy needed
            temp = arc.split('-')
            num_insert = len(temp) - 1
            if num_insert == 0:
                # Without separators nothing can move: the loop below would
                # rebuild the same string forever, so leave the arch as is.
                return arc
            while is_not_valid or arc == arc_origin:
                next_start = 0
                arc = list(''.join(temp))
                for i in range(num_insert):
                    pos = arc_origin[next_start:].find('-') + next_start
                    assert arc_origin[pos] == '-', "Wrong '-' is found!"
                    max_step = min(len_step, max(len(temp[i]), len(temp[i + 1])))
                    step_range = list(range(-1 * max_step, max_step))
                    step = random.choice(step_range)
                    next_start = pos + 1
                    pos = pos + step
                    arc.insert(pos, '-')
                arc = ''.join(arc)
                is_not_valid = (not is_valid(arc))
            return arc

        arch_origin = arch
        success = False
        k = 0
        while not success:
            k += 1
            arch = arch_origin
            ops = []
            for _ in range(num_mutate):
                op_idx = np.random.randint(low=0, high=3)
                # Once max_stages is reached the threshold becomes 1, which
                # random.random() can never exceed, so no stage is added.
                adds_thresh_ = addstage_ratio if len(arch.split('-')) < max_stages else 1
                if op_idx == 0 and random.random() > expend_ratio:
                    # NOTE(review): expend is applied twice per mutation, as in
                    # the original code - confirm this is intended.
                    arch, idx = expend(arch)
                    arch, idx = expend(arch)
                elif op_idx == 1:
                    arch = swap(arch)
                elif op_idx == 2 and random.random() > adds_thresh_:
                    arch = arch + '-1'
                    ops.append('add stage')
                else:
                    ops.append('Do Nothing.')
            success = arch != arch_origin
            flag = 'Success' if success else 'Failed'
            # k was already incremented for this attempt, so it is the 1-based index.
            logging.info('Serial-level Sample{}: {}. {}.'.format(k, ' -> '.join(ops), flag))
        return arch

    def _mutate_parallelnet(self, arch):
        """Mutate operation in Parallel-level searching.

        A fresh per-stage code is drawn, each stage being 0..3 with
        probabilities 0.4/0.3/0.2/0.1, until it differs from the input.

        :param arch: base arch encode
        :type arch: str
        :return: parallel arch encode after mutate
        :rtype: str
        """
        def limited_random(num_stage):
            p = [0.4, 0.3, 0.2, 0.1]
            drawn = np.random.choice(4, size=num_stage, replace=True, p=p)
            return '-'.join([str(i) for i in drawn])

        num_stage = len(arch.split('-'))
        success = False
        k = 0
        while not success:
            k += 1
            new_arch = limited_random(num_stage)
            success = new_arch != arch
            flag = 'Success' if success else 'Failed'
            # k was already incremented for this attempt, so it is the 1-based index.
            logging.info('Parallel-level Sample{}: {}.'.format(k, flag))
        return new_arch
コード例 #4
0
ファイル: backbone_nas.py プロジェクト: zhwzhong/vega
class BackboneNas(SearchAlgorithm):
    """BackboneNas.

    :param search_space: input search_space
    :type: SearchSpace
    """

    def __init__(self, search_space=None):
        """Init BackboneNas."""
        super(BackboneNas, self).__init__(search_space)
        # ea or random
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.num_mutate = self.policy.num_mutate
        self.random_ratio = self.policy.random_ratio
        self.max_sample = self.range.max_sample
        self.min_sample = self.range.min_sample
        self.sample_count = 0
        logging.info("inited BackboneNas")
        self.pareto_front = ParetoFront(self.cfg)
        self.random_search = RandomSearchAlgorithm(self.search_space)
        self._best_desc_file = 'nas_model_desc.json'
        if 'best_desc_file' in self.cfg and self.cfg.best_desc_file is not None:
            self._best_desc_file = self.cfg.best_desc_file

    @property
    def is_completed(self):
        """Check if NAS is finished.

        :return: True once more than ``max_sample`` samples were drawn
        :rtype: bool
        """
        return self.sample_count > self.max_sample

    def search(self):
        """Search in search_space and return a sample.

        A candidate comes either from random search or from the first entry
        of the current pareto front, is mutated with :meth:`ea_sample` and
        offered to the pareto board; the loop repeats until the board accepts
        a candidate that carries a ``'code'`` entry.

        :return: sample count and the network description of the sample
        """
        sample = {}
        while sample is None or 'code' not in sample:
            pareto_dict = self.pareto_front.get_pareto_front()
            pareto_list = list(pareto_dict.values())
            use_random = (self.pareto_front.size < self.min_sample
                          or random.random() < self.random_ratio
                          or len(pareto_list) == 0)
            if use_random:
                sample_desc = self.random_search.search()
                sample = self.codec.encode(sample_desc)
            else:
                sample = pareto_list[0]
            if sample is not None and 'code' in sample:
                code = sample['code']
                code = self.ea_sample(code)
                sample['code'] = code
            if not self.pareto_front._add_to_board(id=self.sample_count + 1,
                                                   config=sample):
                sample = None
        self.sample_count += 1
        logging.info(sample)
        sample_desc = self.codec.decode(sample)
        return self.sample_count, NetworkDesc(sample_desc)

    def random_sample(self):
        """Random sample from search_space.

        :return: encoded random sample
        """
        sample_desc = self.random_search.search()
        sample = self.codec.encode(sample_desc, is_random=True)
        return sample

    def ea_sample(self, code):
        """Use EA op to change a arch code.

        The input is left untouched: all mutations are applied to a deep copy.

        :param code: list of code for arch
        :type code: list
        :return: changed code
        :rtype: list
        """
        import copy

        # A shallow copy would share the nested lists with `code`, and the
        # in-place insert/remove/swap below would then modify the caller's data.
        new_arch = copy.deepcopy(code)
        self._insert(new_arch)
        self._remove(new_arch)
        self._swap(new_arch[0], self.num_mutate // 2)
        self._swap(new_arch[1], self.num_mutate // 2)
        return new_arch

    def update(self, worker_result_path):
        """Use train and evaluate result to update algorithm.

        The step name is taken from the parent folder name and the config id
        from the last path component of ``worker_result_path``.

        :param worker_result_path: current result path
        :type: str
        """
        step_name = os.path.basename(os.path.dirname(worker_result_path))
        config_id = int(os.path.basename(worker_result_path))
        performance = self._get_performance(step_name, config_id)
        logging.info("update performance={}".format(performance))
        self.pareto_front.add_pareto_score(config_id, performance)
        self.save_output(self.local_output_path)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_base_path, self.backup_base_path)

    def _get_performance(self, step_name, worker_id):
        """Read the performance records of one worker.

        Every non-empty line of ``performance.txt`` is parsed as JSON; when a
        line holds a list, only its first element is kept.

        :param step_name: name of the step the worker belongs to
        :type step_name: str
        :param worker_id: id of the worker
        :type worker_id: int
        :return: parsed records, or an empty list when the file is missing
        :rtype: list
        """
        saved_folder = self.get_local_worker_path(step_name, worker_id)
        performance_file = FileOps.join_path(saved_folder, "performance.txt")
        if not os.path.isfile(performance_file):
            logging.info("Performance file is not exited, file={}".format(
                performance_file))
            return []
        performance = []
        with open(performance_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line == "":
                    continue
                data = json.loads(line)
                if isinstance(data, list):
                    data = data[0]
                performance.append(data)
        logging.info("performance={}".format(performance))
        return performance

    def _insert(self, arch):
        """Random insert to arch code.

        A ``1`` is inserted in place at a random position of ``arch[0]`` and
        of ``arch[1]``.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        """
        idx = np.random.randint(low=0, high=len(arch[0]))
        arch[0].insert(idx, 1)
        idx = np.random.randint(low=0, high=len(arch[1]))
        arch[1].insert(idx, 1)
        return arch

    def _remove(self, arch):
        """Random remove one from arch code.

        One randomly chosen ``1`` is popped in place from ``arch[0]`` and from
        ``arch[1]``.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        :raises IndexError: if either part contains no ``1``
        """
        # random pop arch[0]
        ones_index = [i for i, char in enumerate(arch[0]) if char == 1]
        idx = random.choice(ones_index)
        arch[0].pop(idx)
        # random pop arch[1]
        ones_index = [i for i, char in enumerate(arch[1]) if char == 1]
        idx = random.choice(ones_index)
        arch[1].pop(idx)
        return arch

    def _swap(self, arch, R):
        """Random swap one in arch code.

        A random entry that is not ``1`` is exchanged in place with the entry
        1..R positions to its left or right. Moves that would leave the list
        are redrawn. The list is returned unchanged when no move is possible.

        :param arch: input arch code
        :type arch: list
        :param R: maximum swap distance
        :type R: int
        :return: changed arch code
        :rtype: list
        """
        # Nothing can be swapped with a non-positive distance or fewer than
        # two entries; retrying in those cases would never terminate.
        if R < 1 or len(arch) < 2:
            return arch
        # The list is only modified by the final swap, so the candidate
        # positions can be collected once.
        not_ones_index = [i for i, char in enumerate(arch) if char != 1]
        if not not_ones_index:
            return arch
        while True:
            idx = random.choice(not_ones_index)
            r = random.randint(1, R)
            direction = -r if random.random() > 0.5 else r
            target = idx + direction
            # Explicit bounds check: a negative index would not raise but
            # silently wrap around to the end of the list.
            if 0 <= target < len(arch):
                arch[idx], arch[target] = arch[target], arch[idx]
                break
        return arch

    def is_valid(self, arch):
        """Check if valid for arch code.

        Currently every code is accepted.

        :param arch: input arch code
        :type arch: list
        :return: if current arch is valid
        :rtype: bool
        """
        return True

    def save_output(self, output_path):
        """Save result to output_path.

        Writes the score board as ``nas_score_board.csv`` and, when the pareto
        front is not empty, the decoded description of its first entry as
        json. Both writes are best effort: failures are logged, not raised.

        :param output_path: the result save output_path
        :type: str
        """
        try:
            board_file = os.path.join(output_path, 'nas_score_board.csv')
            self.pareto_front.sieve_board.to_csv(board_file,
                                                 index=None,
                                                 header=True)
        except Exception as e:
            logging.error("write nas_score_board.csv error:{}".format(str(e)))
        try:
            pareto_dict = self.pareto_front.get_pareto_front()
            if len(pareto_dict) > 0:
                best_id = next(iter(pareto_dict))
                net_desc = self.codec.decode(pareto_dict[best_id])
                with open(os.path.join(output_path, self._best_desc_file),
                          'w') as fp:
                    json.dump(net_desc, fp)
        except Exception as e:
            logging.error("write best model error:{}".format(str(e)))