# Example #1
class FLTrainer(object):
    """Base trainer for a federated-learning worker.

    Holds the trainer-side programs and metadata extracted from a compiled
    FL job, runs the local training loop through a ``fluid.Executor``, and
    coordinates round membership with the scheduler via ``FLWorkerAgent``.
    """

    def __init__(self):
        # Logger for trainer lifecycle / per-batch debug messages.
        self._logger = logging.getLogger("FLTrainer")

    def set_trainer_job(self, job):
        """Copy trainer programs and metadata out of a compiled FL job.

        Args:
            job: compiled job object exposing ``_trainer_startup_program``,
                ``_trainer_main_program``, ``_strategy._inner_step``,
                ``_feed_names``, ``_target_names`` and ``_scheduler_ep``.
        """
        self._startup_program = job._trainer_startup_program
        self._main_program = job._trainer_main_program
        self._step = job._strategy._inner_step
        self._feed_names = job._feed_names
        self._target_names = job._target_names
        self._scheduler_ep = job._scheduler_ep
        # Worker endpoint is not wired up yet (see start()); None until then.
        self._current_ep = None
        self.cur_step = 0

    def start(self, place):
        """Connect to the scheduler and run the startup program on ``place``.

        Args:
            place: a fluid place (CPU/GPU) for the executor.
        """
        # NOTE(review): self._current_ep is still None here; the worker
        # endpoint ("current_ep") is marked "to be added" upstream.
        self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
        self.agent.connect_scheduler()
        self.exe = fluid.Executor(place)
        self.exe.run(self._startup_program)

    def run(self, feed, fetch):
        """Run one batch of the main program and advance the step counter.

        Args:
            feed: feed dict for the executor.
            fetch: list of variables/names to fetch.
        """
        self._logger.debug("begin to run")
        self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
        self._logger.debug("end to run current batch")
        self.cur_step += 1

    def save_inference_program(self, output_folder):
        """Save an inference-only clone of the main program.

        Args:
            output_folder: directory the inference model is written to.
        """
        infer_program = self._main_program.clone(for_test=True)
        # Resolve the target variables by name from the root block.
        target_vars = [
            self._main_program.block(0)._find_var_recursive(name)
            for name in self._target_names
        ]
        fluid.io.save_inference_model(
            output_folder,
            self._feed_names,
            target_vars,
            self.exe,
            main_program=infer_program)

    def stop(self):
        """Synchronize with the scheduler between training rounds.

        If at least one step has run, blocks until the scheduler confirms
        the round finished, then blocks until this worker may join the
        next round.

        Returns:
            bool: always ``False``.
        """
        # TODO(guru4elephant): add connection with master
        # NOTE(review): both waits busy-spin on the agent; a short sleep
        # between polls would avoid pegging a CPU core.
        if self.cur_step != 0:
            while not self.agent.finish_training():
                self._logger.debug("Wait others finish")
        while not self.agent.can_join_training():
            self._logger.debug("Wait permit")
        self._logger.debug("Ready to train")
        return False
# Example #2
class FLTrainer(object):
    """Base trainer for a federated-learning worker.

    Holds the trainer-side programs and metadata extracted from a compiled
    FL job, runs the local training loop through a ``fluid.Executor``,
    coordinates with the scheduler via ``FLWorkerAgent``, and can export
    checkpoints, inference programs and serving models.
    """

    def __init__(self):
        # Logger for trainer lifecycle / per-batch debug messages.
        self._logger = logging.getLogger("FLTrainer")

    def set_trainer_job(self, job):
        """Copy trainer programs and metadata out of a compiled FL job.

        Args:
            job: compiled job object exposing ``_trainer_startup_program``,
                ``_trainer_main_program``, ``_strategy._inner_step``,
                ``_feed_names``, ``_target_names`` and ``_scheduler_ep``.
        """
        self._startup_program = job._trainer_startup_program
        self._main_program = job._trainer_main_program
        self._step = job._strategy._inner_step
        self._feed_names = job._feed_names
        self._target_names = job._target_names
        self._scheduler_ep = job._scheduler_ep
        # Worker endpoint is not wired up yet (see start()); None until then.
        self._current_ep = None
        self.cur_step = 0

    def start(self, place):
        """Connect to the scheduler and run the startup program on ``place``.

        Args:
            place: a fluid place (CPU/GPU) for the executor.
        """
        # NOTE(review): self._current_ep is still None here; the worker
        # endpoint ("current_ep") is marked "to be added" upstream.
        self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
        self.agent.connect_scheduler()
        self.exe = fluid.Executor(place)
        self.exe.run(self._startup_program)

    def run(self, feed, fetch):
        """Run one batch of the main program and advance the step counter.

        Args:
            feed: feed dict for the executor.
            fetch: list of variables/names to fetch.
        """
        self._logger.debug("begin to run")
        self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
        self._logger.debug("end to run current batch")
        self.cur_step += 1

    def save_inference_program(self, output_folder):
        """Save an inference-only clone of the main program.

        Args:
            output_folder: directory the inference model is written to.
        """
        infer_program = self._main_program.clone(for_test=True)
        # Resolve the target variables by name from the root block.
        target_vars = [
            self._main_program.block(0)._find_var_recursive(name)
            for name in self._target_names
        ]
        fluid.io.save_inference_model(
            output_folder,
            self._feed_names,
            target_vars,
            self.exe,
            main_program=infer_program)

    def save(self, parameter_dir, model_path):
        """Persist parameters, optimizer state and the program description.

        Writes ``model_path + ".pdparams"`` (pickled parameter tensors),
        ``model_path + ".pdopt"`` (pickled optimizer variables) and
        ``model_path + ".pdmodel"`` (serialized program desc).

        Args:
            parameter_dir: directory containing ``para_info`` with one
                parameter name per line.
            model_path: checkpoint path prefix in ``dirname/filename`` form.

        Raises:
            ValueError: if ``model_path`` has no basename component.
        """
        base_name = os.path.basename(model_path)
        if base_name == "":
            # Raise instead of assert: asserts are stripped under `python -O`,
            # which would silently skip this validation.
            raise ValueError(
                "The input model_path MUST be format of dirname/filename "
                "[dirname\\filename in Windows system], but received "
                "model_path is empty string.")

        dir_name = os.path.dirname(model_path)
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)

        def _get_tensor(var_name):
            # Pull the current value of a variable out of the global scope.
            t = global_scope().find_var(var_name).get_tensor()
            return numpy.array(t)

        # Parameter names were recorded in <parameter_dir>/para_info,
        # one name per line.
        parameter_list = []
        with open(os.path.join(parameter_dir, 'para_info'), 'r') as fin:
            for line in fin:
                # rstrip('\n') instead of line[:-1]: the latter chops the
                # last character when the file lacks a trailing newline.
                parameter_list.append(line.rstrip('\n'))

        param_dict = {p: _get_tensor(p) for p in parameter_list}
        with open(model_path + ".pdparams", 'wb') as f:
            # protocol=2 keeps the checkpoint loadable from Python 2.
            pickle.dump(param_dict, f, protocol=2)

        optimizer_var_list = list(
            filter(is_belong_to_optimizer, self._main_program.list_vars()))
        opt_dict = {p.name: _get_tensor(p.name) for p in optimizer_var_list}
        with open(model_path + ".pdopt", 'wb') as f:
            pickle.dump(opt_dict, f, protocol=2)

        # Serialize the program desc. Operate on the clone consistently:
        # the original code set version/compat info on the clone but then
        # serialized the un-versioned original desc, so the saved .pdmodel
        # lacked version information (cf. fluid.io.save).
        main_program = self._main_program.clone()
        self._main_program.desc.flush()
        main_program.desc._set_version()
        fluid.core.save_op_compatible_info(main_program.desc)

        with open(model_path + ".pdmodel", "wb") as f:
            f.write(main_program.desc.serialize_to_string())

    def save_serving_model(self, model_path, client_conf_path):
        """Export the main program in Paddle Serving format.

        Args:
            model_path: directory for the serving model.
            client_conf_path: directory for the client configuration.
        """
        root_block = self._main_program.block(0)
        feed_vars = {
            name: root_block._find_var_recursive(name)
            for name in self._feed_names
        }
        target_vars = {
            name: root_block._find_var_recursive(name)
            for name in self._target_names
        }
        serving_io.save_model(model_path, client_conf_path, feed_vars,
                              target_vars, self._main_program)

    def stop(self):
        """Synchronize with the scheduler between training rounds.

        If at least one step has run, blocks until the scheduler confirms
        the round finished, then blocks until this worker may join the
        next round.

        Returns:
            bool: always ``False``.
        """
        # TODO(guru4elephant): add connection with master
        # NOTE(review): both waits busy-spin on the agent; a short sleep
        # between polls would avoid pegging a CPU core.
        if self.cur_step != 0:
            while not self.agent.finish_training():
                self._logger.debug("Wait others finish")
        while not self.agent.can_join_training():
            self._logger.debug("Wait permit")
        self._logger.debug("Ready to train")
        return False