Ejemplo n.º 1
0
Archivo: pba.py Proyecto: zeyefkey/vega
 def _init_next_rung(self, exploit_count=4):
     """Init next rung to search.

     Steps, in order:

     1. If ``rung_id + 1`` is not below ``total_rungs``, only advance ``rung_id``.
     2. Carry every config's current-rung policy forward to the next rung.
     3. Rank configs by ``best_score_dict[rung_id]``. For up to
        ``exploit_count`` (better, worse) pairs, replace the worse config's
        cached checkpoint directory and config with copies of the better
        one's. Then replace the copied next-rung policy with
        ``self.explore(...)`` of it.
     4. Reset every config's next-rung score to ``-inf``, add a
        ``StatusType.WAITTING`` row to the board for it, and advance
        ``rung_id``.

     :param exploit_count: maximum number of (better, worse) pairs to exploit.
         It is capped at ``config_count // 2`` so that no config is both donor
         and recipient in the same round. Defaults to 4, the previously
         hard-coded value.
     """
     import copy

     next_rung_id = self.rung_id + 1
     if next_rung_id >= self.total_rungs:
         self.rung_id += 1
         return
     for config_id in range(self.config_count):
         self.all_config_dict[config_id][next_rung_id] = self.all_config_dict[config_id][self.rung_id]
     # Ascending by score: ranked[0] is the worst config, ranked[-1] the best.
     ranked = sorted(
         ((config_id, self.best_score_dict[self.rung_id][config_id])
          for config_id in range(self.config_count)),
         key=lambda pair: pair[1])
     # A fixed count of 4 made the two groups overlap when fewer than 8
     # configs were running, and raised IndexError when fewer than 4 were.
     n_pairs = min(exploit_count, self.config_count // 2)
     for i in range(n_pairs):
         better_id = ranked[-1 - i][0]
         worse_id = ranked[i][0]
         better_worker_result_path = FileOps.join_path(self.local_base_path, 'cache', 'pba',
                                                       str(better_id), 'checkpoint')
         FileOps.make_dir(better_worker_result_path)
         worse_worker_result_path = FileOps.join_path(self.local_base_path, 'cache', 'pba',
                                                      str(worse_id), 'checkpoint')
         # Create first so the rmtree below always has a directory to remove.
         FileOps.make_dir(worse_worker_result_path)
         shutil.rmtree(worse_worker_result_path)
         shutil.copytree(better_worker_result_path, worse_worker_result_path)
         # Deep copy: a plain assignment aliased both configs, so the explored
         # policy written below also overwrote the better config's policy.
         self.all_config_dict[worse_id] = copy.deepcopy(self.all_config_dict[better_id])
         policy_unchanged = self.all_config_dict[worse_id][next_rung_id]
         self.all_config_dict[worse_id][next_rung_id] = self.explore(policy_unchanged)
     for config_id in range(self.config_count):
         self.best_score_dict[next_rung_id][config_id] = float('-inf')
         self._add_to_board({'config_id': config_id,
                             'rung_id': next_rung_id,
                             'status': StatusType.WAITTING})
     self.rung_id += 1
Ejemplo n.º 2
0
    def update(self, record):
        """Update the hpo score board with a worker record and cache its output.

        After delegating to the parent ``update``, the worker's local result
        directory for the record's step is copied to
        ``<local_base_path>/cache/pba/<worker_id>/checkpoint``, replacing any
        previous copy there.

        :param record: finished-worker record. ``worker_id`` and ``step_name``
            are read from it with ``record.get``.
        """
        super().update(record)
        config_id = str(record.get('worker_id'))
        step_name = record.get('step_name')
        worker_result_path = self.get_local_worker_path(step_name, config_id)
        new_worker_result_path = FileOps.join_path(self.local_base_path, 'cache', 'pba', config_id, 'checkpoint')
        # Ensure the source exists so copytree has something to copy.
        FileOps.make_dir(worker_result_path)
        FileOps.make_dir(new_worker_result_path)
        # copytree requires that the destination does not exist yet.
        if os.path.exists(new_worker_result_path):
            shutil.rmtree(new_worker_result_path)
        shutil.copytree(worker_result_path, new_worker_result_path)
Ejemplo n.º 3
0
 def __init__(self, search_space):
     """Set up the PruneEA search state and its output file locations.

     Reads population settings from ``self.policy``, builds a ``Codec`` from
     ``self.cfg.codec``, zeroes the progress counters, and derives the CSV and
     directory paths under ``<local_output_path>/<step_name>``. The
     ``pareto_front`` directory is created.
     """
     super(PruneEA, self).__init__(search_space)
     policy = self.policy
     self.length = policy.length
     self.num_individual = policy.num_individual
     self.num_generation = policy.num_generation
     # Names of the two objectives.
     self.x_axis, self.y_axis = 'flops', 'acc'
     self.random_models = policy.random_models
     self.codec = Codec(self.cfg.codec, search_space)
     # Progress counters.
     self.random_count = self.ea_count = self.ea_epoch = 0
     step_dir = FileOps.join_path(self.local_output_path, self.cfg.step_name)
     self.step_path = step_dir
     self.pd_file_name = FileOps.join_path(step_dir, "performance.csv")
     self.pareto_front_file = FileOps.join_path(step_dir, "pareto_front.csv")
     self.pd_path = FileOps.join_path(step_dir, "pareto_front")
     FileOps.make_dir(self.pd_path)
Ejemplo n.º 4
0
 def search(self):
     """Search an id and hps from hpo.

     :return: ``None`` when ``self.hpo.propose()`` yields nothing. Otherwise a
         dict with ``worker_id`` (the proposal's ``config_id``), ``desc`` (the
         hyper-parameter overrides) and ``info`` (the proposal's ``rung_id``).
     """
     proposal = self.hpo.propose()
     if proposal is None:
         return None
     proposal = copy.deepcopy(proposal)
     config_id = proposal.get('config_id')
     transformer = {'type': 'PBATransformer', 'para_array': proposal.get('configs'),
                    'operation_names': self.operation_names}
     desc = {'dataset.transforms': [transformer]}
     ckpt_dir = FileOps.join_path(self.local_base_path, 'cache', 'pba', str(config_id), 'checkpoint')
     FileOps.make_dir(ckpt_dir)
     # NOTE(review): make_dir has presumably just created ckpt_dir, so this
     # check looks always-true -- confirm whether that is intended.
     if os.path.exists(ckpt_dir):
         desc['trainer.checkpoint_path'] = ckpt_dir
     if 'epoch' in proposal:
         desc['trainer.epochs'] = proposal.get('epoch')
     return dict(worker_id=config_id, desc=desc, info=proposal.get('rung_id'))
Ejemplo n.º 5
0
def write_ip(ip_address, port, args):
    """Write the ip and port in a system path.

    The target file is
    ``<General.task.local_base_path>/<task_id>/ip_address/ip_address.txt``.
    It gets two lines: the ip address, then the port. Any existing content is
    overwritten.

    :param str ip_address: The `ip_address` need to write.
    :param port: The `port` need to write. It is written via ``str.format``,
        so an ``int`` port works as well as a ``str`` one.
    :param argparse.ArgumentParser args: not used by this function; kept for
        interface compatibility.

    """
    local_base_path = General.task.local_base_path
    local_task_id = General.task.task_id
    local_path = os.path.join(local_base_path, local_task_id, 'ip_address/')
    if not os.path.exists(local_path):
        FileOps.make_dir(local_path)

    file_path = os.path.join(local_path, 'ip_address.txt')
    logging.info("write ip, file path=%s", file_path)
    with open(file_path, 'w') as f:
        # format() instead of `port + "\n"`, which raised TypeError for int ports.
        f.write("{}\n{}\n".format(ip_address, port))
Ejemplo n.º 6
0
def get_write_ip(args):
    """Get the ip and port written to the task's ``ip_address.txt``.

    Reads the first two lines of
    ``<local_base_path>/<task_id>/ip_address/ip_address.txt``, with both path
    parts taken from ``UserConfig().data.general.task``. This is the same
    layout that ``write_ip`` writes. The ``ip_address`` directory is created
    if it is missing.

    :param argparse.ArgumentParser args: not used by this function; kept for
        interface compatibility.
    :return: the ip and port, stripped of surrounding whitespace. A missing
        line yields an empty string.
    :rtype: str, str.
    :raises FileNotFoundError: if ``ip_address.txt`` has not been written yet.

    """
    # One config lookup serves both path components.
    task = UserConfig().data.general.task
    local_path = os.path.join(task.local_base_path, task.task_id, 'ip_address/')
    if not os.path.exists(local_path):
        FileOps.make_dir(local_path)
    file_path = os.path.join(local_path, 'ip_address.txt')
    with open(file_path, 'r') as f:
        ip = f.readline().strip()
        port = f.readline().strip()
    logging.info("get write ip, ip=%s, port=%s", ip, port)
    return ip, port
Ejemplo n.º 7
0
    def update_performance(self, hps, performance):
        """Update current performance into hpo score board.

        The score recorded via ``self.hpo.add_score`` is ``performance[0]``
        when ``performance`` is a non-empty list or tuple. Otherwise it is -1
        and an error is logged. The worker's local result directory is then
        copied to ``<local_base_path>/cache/pba/<config_id>/checkpoint``,
        replacing any previous copy.

        :param hps: hyper parameters of the finished trial. ``config_id`` and
            ``rung_id`` are read from it with ``hps.get``.
        :param performance: trainer performance, normally a list or tuple whose
            first element is the score.
        """
        config_id = int(hps.get('config_id'))
        rung_id = int(hps.get('rung_id'))
        # Tuples were previously treated as "empty" and scored as -1.
        has_score = isinstance(performance, (list, tuple)) and len(performance) > 0
        self.hpo.add_score(config_id, rung_id, performance[0] if has_score else -1)
        if not has_score:
            logging.error("hpo get empty performance!")
        # Path names use the raw id string, as before (not the int-converted id).
        worker_id = str(hps.get('config_id'))
        worker_result_path = self.get_local_worker_path(self.cfg.step_name, worker_id)
        new_worker_result_path = FileOps.join_path(self.local_base_path, 'cache', 'pba',
                                                   worker_id, 'checkpoint')
        FileOps.make_dir(worker_result_path)
        FileOps.make_dir(new_worker_result_path)
        # copytree requires that the destination does not exist yet.
        if os.path.exists(new_worker_result_path):
            shutil.rmtree(new_worker_result_path)
        shutil.copytree(worker_result_path, new_worker_result_path)
Ejemplo n.º 8
0
    def sample(self):
        """Sample an id and hps from hpo.

        A deep copy of the proposal is stored in ``self._hps_cache`` under
        ``str(config_id)`` as ``[proposal, []]``.

        :return: ``(config_id, hps)``, or ``(None, None)`` when
            ``self.hpo.propose()`` yields nothing.
        :rtype: int, dict
        """
        proposal = self.hpo.propose()
        if proposal is None:
            return None, None
        proposal = copy.deepcopy(proposal)
        sample_id = proposal.get('config_id')
        self._hps_cache[str(sample_id)] = [copy.deepcopy(proposal), []]
        hps = {
            'dataset.transforms': [{'type': 'PBATransformer',
                                    'para_array': proposal.get('configs'),
                                    'operation_names': self.operation_names}],
        }
        ckpt_dir = FileOps.join_path(self.local_base_path, 'cache', 'pba', str(sample_id), 'checkpoint')
        FileOps.make_dir(ckpt_dir)
        # NOTE(review): presumably always true right after make_dir -- confirm intent.
        if os.path.exists(ckpt_dir):
            hps['trainer.checkpoint_path'] = ckpt_dir
        if 'epoch' in proposal:
            hps['trainer.epochs'] = proposal.get('epoch')
        return sample_id, hps
Ejemplo n.º 9
0
 def performance_path(self, worker_result_path):
     """Return ``<worker_result_path>/performance/performance.pkl``.

     The ``performance`` sub-directory is created first if it does not exist.
     """
     perf_dir = os.path.join(worker_result_path, 'performance')
     dir_missing = not os.path.exists(perf_dir)
     if dir_missing:
         FileOps.make_dir(perf_dir)
     return os.path.join(perf_dir, 'performance.pkl')