Example #1
 def send_partial_params(cls, client: str, job_id: str, step: int,
                         params) -> None:
     # The client parameter here is currently taken to be the client address.
     # In standalone mode, the trainer saves the model produced in the current
     # training round under the designated path.
     client_params_dir = JobPath(job_id).client_params_dir(step, client)
     os.makedirs(client_params_dir, exist_ok=True)
     # Save the parameters under the file name job_id.pkl.
     path = PathUtils.join(client_params_dir, job_id + '.pkl')
     with open(path, 'wb') as f:
         pickle.dump(params, f)
     print("训练完成,已将模型保存至:" + str(client_params_dir))
Example #2
    def save_job_zip(cls, job_id: str, job: File) -> None:
        """
        Save the zip file of the job.

        :param job_id: job ID
        :param job: zip file
        """
        job_path = JobPath(job_id)
        if job.ipfs_hash is not None and job.ipfs_hash != "":
            file_obj = Ipfs.get(job.ipfs_hash)
        else:
            file_obj = job.file
        ZipUtils.extract(file_obj, job_path.root_dir)
Example #3
 def is_available(self):
     # Refresh the client list before checking for available models.
     self.update_clients()
     n_available_models = 0
     for client_addr in self.clients:
         client_params_path = JobPath(self.job_id).client_params_dir(
             self.step, client_addr) + f"/{self.job_id}.pkl"
         if os.path.exists(client_params_path):
             n_available_models += 1
             self.available_clients.add(client_addr)
     # Aggregation can start once enough clients have uploaded their models.
     return n_available_models >= self.job.aggregate_config.clients_per_round
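
is_available simply counts how many clients already have a parameter file on disk for the current step and compares that count against clients_per_round. The small stdlib sketch below captures that gating logic, reusing the hypothetical directory layout from the earlier sketch.

import os

def count_available_clients(base_dir, job_id, step, clients):
    # Clients whose <job_id>.pkl exists for this step count as available.
    available = set()
    for client in clients:
        path = os.path.join(base_dir, job_id, "client_params", str(step),
                            client, job_id + ".pkl")
        if os.path.exists(path):
            available.add(client)
    return available

def can_aggregate(base_dir, job_id, step, clients, clients_per_round):
    # Aggregation may start once enough clients have uploaded their models.
    available = count_available_clients(base_dir, job_id, step, clients)
    return len(available) >= clients_per_round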
Example #4
def load_job_zip(job_id: str) -> File:
    """
    Load job from a ZIP file.

    :param job_id: job ID
    """
    job_path = JobPath(job_id)
    file_obj = BytesIO()
    ZipUtils.compress([job_path.metadata_file, job_path.config_dir], file_obj)
    if GflConf.get_property("ipfs.enabled"):
        file_obj.seek(0)
        ipfs_hash = Ipfs.put(file_obj.read())
        return File(ipfs_hash=ipfs_hash)
    else:
        return File(file=file_obj)
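
load_job_zip and save_job_zip (example #2) are two halves of the same exchange: the packed job travels either as an IPFS hash or as an in-memory file object, and the File wrapper carries one or the other. The following sketch illustrates that branching with a hypothetical File dataclass standing in for the framework's real class.

from dataclasses import dataclass
from io import BytesIO
from typing import Optional

@dataclass
class File:
    # Hypothetical stand-in: the framework's File class may differ.
    ipfs_hash: Optional[str] = None
    file: Optional[BytesIO] = None

def pick_source(job: File) -> str:
    # Mirror the branch in save_job_zip: prefer the IPFS hash when present,
    # otherwise fall back to the in-memory file object.
    if job.ipfs_hash:
        return "fetch from IPFS: " + job.ipfs_hash
    return "use in-memory zip ({} bytes)".format(len(job.file.getvalue()))

print(pick_source(File(ipfs_hash="QmExampleHash")))     # IPFS-backed job
print(pick_source(File(file=BytesIO(b"PK\x03\x04"))))   # local in-memory job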
Example #5
    def load_topology_manager(cls, job_id: str):
        """
        Load topology_manager from JSON file.
        json->topology_config->topology_manager

        :param job_id: Job ID
        """
        job_path = JobPath(job_id)
        topology_config = cls.__load_json(job_path.topology_config_file,
                                          TopologyConfig)
        # __load_json may return None if the file is missing or invalid.
        if isinstance(topology_config, TopologyConfig):
            if topology_config.get_isCentralized() is True:
                topology_manager = CentralizedTopologyManager(
                    topology_config=topology_config)
                return topology_manager
        # Only centralized topologies are handled here.
        return None
Example #6
 def receive_global_params(cls, job_id: str, cur_round: int):
     # In standalone mode, the trainer fetches the global model of the current
     # aggregation round.
     # Build the path of the aggregated global model parameters from the
     # job_id and cur_round of the Job.
     global_params_dir = JobPath(job_id).global_params_dir(cur_round)
     model_params_path = PathUtils.join(global_params_dir, job_id + '.pkl')
     # Return the path if the model parameter file exists.
     if os.path.exists(global_params_dir) and os.path.isfile(
             model_params_path):
         print("Trainer received the global model")
         return model_params_path
     else:
         # One could wait for a while and return the path if the file appears
         # within that window; this case is not handled for now.
         # Otherwise the model parameter file is considered unavailable.
         return None
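
The comments in receive_global_params mention a behaviour that is deliberately left out: waiting for a short period and returning the file if it shows up within that window. A hedged stdlib sketch of such a polling wrapper around a path check is below; it is not part of the original code.

import os
import time

def wait_for_file(path, timeout_s=30.0, poll_s=0.5):
    # Poll until the file appears or the timeout elapses; return the path on
    # success, None otherwise.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.isfile(path):
            return path
        time.sleep(poll_s)
    return None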
Example #7
    def setUp(self) -> None:
        self.dataset = generate_dataset()
        print("dataset_id:" + self.dataset.dataset_id)
        self.job = generate_job()
        print("job_id:" + self.job.job_id)
        self.job.mount_dataset(self.dataset)
        GflNode.init_node()
        node = GflNode.default_node
        self.jobTrainerScheduler = JobTrainScheduler(node=node, job=self.job)
        self.jobTrainerScheduler.register()

        # The aggregator needs to initialize a random model.
        global_params_dir = JobPath(self.job.job_id).global_params_dir(
            self.job.cur_round)
        # print("global_params_dir:"+global_params_dir)
        os.makedirs(global_params_dir, exist_ok=True)
        model_params_path = PathUtils.join(global_params_dir,
                                           self.job.job_id + '.pth')
        # print("model_params_path:"+model_params_path)
        model = Net()
        torch.save(model.state_dict(), model_params_path)
Example #8
    def load_job(cls, job_id: str) -> Job:
        """
        Load job from JSON file.

        :param job_id: job ID
        """
        job_path = JobPath(job_id)
        metadata = cls.__load_json(job_path.metadata_file, JobMetadata)
        job_config = cls.__load_json(job_path.job_config_file, JobConfig)
        train_config = cls.__load_json(job_path.train_config_file, TrainConfig)
        aggregate_config = cls.__load_json(job_path.aggregate_config_file,
                                           AggregateConfig)
        module = ModuleUtils.import_module(job_path.module_dir,
                                           job_path.module_name)
        job_config.module = module
        train_config.module = module
        aggregate_config.module = module
        job = Job(job_id=job_id,
                  metadata=metadata,
                  job_config=job_config,
                  train_config=train_config,
                  aggregate_config=aggregate_config)
        job.module = module
        return job
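
load_job rebuilds a Job from several JSON config files in the job directory and attaches the dynamically imported module to each config. The sketch below shows the two building blocks it relies on, JSON-to-object loading and importing a module from a directory, using hypothetical helpers rather than the framework's own __load_json and ModuleUtils.import_module.

import importlib.util
import json
import os

def load_json_as(path, cls):
    # Read a JSON file and construct cls from its keys; return None when the
    # file is missing, mirroring how a failed load yields no config.
    if not os.path.isfile(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return cls(**json.load(f))

def import_module_from(module_dir, module_name):
    # Import <module_dir>/<module_name>.py as a standalone module object.
    spec = importlib.util.spec_from_file_location(
        module_name, os.path.join(module_dir, module_name + ".py"))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module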
Example #9
    def receive_partial_params(cls, client_address: str, job_id: str,
                               step: int) -> File:
        """
        获得指定客户端的模型参数
        Parameters
        ----------
        client_address: 客户端的地址
        job_id
        step: 训练任务目前的训练轮数

        Returns
        -------

        """
        client_params_path = JobPath(job_id).client_params_dir(
            step, client_address) + f"/{job_id}.pth"
        if os.path.exists(client_params_path):
            try:
                client_model_param = torch.load(client_params_path)
            except Exception as e:
                raise ValueError(f"Failed to load model from "
                                 f"{client_params_path}. Error: {e}")
            # Return the parameters loaded above instead of loading twice.
            return client_model_param
        else:
            return None
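
On the aggregator side, partial parameters are collected only from clients whose file exists; clients without an upload are skipped rather than treated as errors. The sketch below shows that collection step in plain stdlib code, using pickle instead of torch.load so it stays self-contained (the real code loads .pth files with torch).

import os
import pickle

def collect_partial_params(base_dir, job_id, step, clients):
    # Map each available client address to its loaded parameters.
    collected = {}
    for client in clients:
        path = os.path.join(base_dir, job_id, "client_params", str(step),
                            client, job_id + ".pkl")
        if not os.path.isfile(path):
            continue
        try:
            with open(path, "rb") as f:
                collected[client] = pickle.load(f)
        except Exception as e:
            raise ValueError(f"Failed to load model from {path}. Error: {e}")
    return collected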
Example #10
 def validate(self):
     job_path = JobPath(self.job_id)
     work_dir = job_path.client_work_dir(self.job.cur_round, self.client.address)
     os.makedirs(work_dir, exist_ok=True)
     with WorkDirContext(work_dir):
         self._validate()
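
validate runs _validate inside a WorkDirContext, that is, with the per-round, per-client work directory as the current working directory. A minimal sketch of such a context manager using the standard library follows; the framework's own WorkDirContext may do more.

import os
from contextlib import contextmanager

@contextmanager
def work_dir_context(work_dir):
    # Temporarily switch the working directory, restoring it on exit.
    prev = os.getcwd()
    os.chdir(work_dir)
    try:
        yield work_dir
    finally:
        os.chdir(prev)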
Example #11
def __get(job_id: str, statement: str, *params):
    job_path = JobPath(job_id)
    with SqliteContext(job_path.sqlite_file) as (_, cursor):
        cursor.execute(statement, tuple(params))
        ret = cursor.fetchall()
    return ret
Example #12
def __save(job_id: str, statement: str, *params):
    job_path = JobPath(job_id)
    with SqliteContext(job_path.sqlite_file) as (_, cursor):
        cursor.execute(statement, tuple(params))
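
__get and __save are thin wrappers that open the per-job SQLite database and run one parameterized statement. Here is a self-contained sketch of the same pattern with the standard sqlite3 module and a hypothetical table; the real SqliteContext and schema may differ.

import os
import sqlite3
import tempfile

def _save(db_file, statement, *params):
    # Execute a single write statement with bound parameters and commit.
    with sqlite3.connect(db_file) as conn:
        conn.execute(statement, tuple(params))

def _get(db_file, statement, *params):
    # Execute a single read statement and return all rows.
    with sqlite3.connect(db_file) as conn:
        return conn.execute(statement, tuple(params)).fetchall()

db = os.path.join(tempfile.mkdtemp(), "job.sqlite")
_save(db, "CREATE TABLE IF NOT EXISTS clients (address TEXT, step INTEGER)")
_save(db, "INSERT INTO clients VALUES (?, ?)", "client-a", 0)
print(_get(db, "SELECT address FROM clients WHERE step = ?", 0))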
Example #13
 def make_dir(self):
     cur_round = self.job.cur_round
     global_params_path = JobPath(self.job_id).global_params_dir(cur_round)
     os.makedirs(global_params_path, exist_ok=True)
Example #14
 def make_dir(self):
     cur_round = self.job.cur_round
     client_params_dir = JobPath(self.job_id).client_params_dir(
         cur_round, self.node.address)
     os.makedirs(client_params_dir, exist_ok=True)