コード例 #1
0
 def __init__(self, config):
     """Give every model-related attribute its initial value.

     Args:
         config: Accepted by the constructor; it is not read by any of
             the statements in this method.
     """
     # No cost and no data loader are set at construction time.
     self._cost = self._data_loader = None

     # Empty containers for metrics and data variables.
     self._metrics = {}
     self._data_var = []

     # NOTE(review): presumably a fetch/print interval counted in steps --
     # the unit is not visible here, confirm against the training loop.
     self._fetch_interval = 20
     self._namespace = "train.model"

     # Remember the platform reported by the envs helper.
     self._platform = envs.get_platform()
コード例 #2
0
    def processor_register(self):
        """Register the handler for each processing status, in order.

        The 'train_pass' handler depends on the platform: ``dataset_train``
        when ``envs.get_platform()`` returns "LINUX", ``dataloader_train``
        otherwise.
        """
        register = self.regist_context_processor

        register('uninit', self.instance)
        register('init_pass', self.init)

        # Only one of the two training handlers is looked up and registered.
        on_linux = envs.get_platform() == "LINUX"
        train_handler = self.dataset_train if on_linux else self.dataloader_train
        register('train_pass', train_handler)

        register('infer_pass', self.infer)
        register('terminal_pass', self.terminal)
コード例 #3
0
def single_engine(args):
    """Set up the runtime environment for a single-engine run.

    Args:
        args: Object exposing ``model`` and ``device`` attributes.

    Returns:
        Whatever ``TrainerFactory.create(args.model)`` returns.
    """
    banner = "use single engine to run model: {}".format(args.model)
    print(banner)

    # Key order matches the order in which the settings are applied.
    single_envs = {
        "train.trainer.trainer": "SingleTrainer",
        "train.trainer.threads": "2",
        "train.trainer.engine": "single",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(single_envs, args.model)

    return TrainerFactory.create(args.model)
コード例 #4
0
def cluster_mpi_engine(args):
    """Set up the runtime environment for the cluster MPI engine.

    Args:
        args: Object exposing ``model`` and ``device`` attributes.

    Returns:
        Whatever ``TrainerFactory.create(args.model)`` returns.
    """
    banner = "launch cluster engine with cluster to run model: {}".format(
        args.model)
    print(banner)

    cluster_envs = {
        "train.trainer.trainer": "CtrCodingTrainer",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(cluster_envs, args.model)

    return TrainerFactory.create(args.model)
コード例 #5
0
    def processor_register(self):
        """Initialise fleet and register the handlers for this process's role.

        Both roles register 'uninit' and 'init_pass'. A server then adds
        'server_pass' only; any other role adds a platform-dependent
        'train_pass' handler followed by 'terminal_pass'.
        """
        fleet.init(PaddleCloudRoleMaker())

        # Query the role before registering anything, as the branches below
        # depend on it.
        is_server = fleet.is_server()
        register = self.regist_context_processor

        # Registrations shared by both roles.
        register('uninit', self.instance)
        register('init_pass', self.init)

        if is_server:
            register('server_pass', self.server)
            return

        # "LINUX" uses the dataset path; every other platform uses the
        # data-loader path.
        if envs.get_platform() == "LINUX":
            train_handler = self.dataset_train
        else:
            train_handler = self.dataloader_train
        register('train_pass', train_handler)
        register('terminal_pass', self.terminal)
コード例 #6
0
def local_mpi_engine(args):
    """Build a LocalMPIEngine launcher for the given model.

    Args:
        args: Object exposing ``model`` and ``device`` attributes.

    Returns:
        A ``LocalMPIEngine`` constructed from the environment settings and
        ``args.model``.

    Raises:
        RuntimeError: If ``util.run_which("mpirun")`` returns a falsy value.
    """
    model = args.model
    print("launch cluster engine with cluster to run model: {}".format(
        model))
    # Imported here rather than at module level, as in the original code.
    from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

    print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(
        model))

    mpirun_path = util.run_which("mpirun")
    if not mpirun_path:
        raise RuntimeError("can not find mpirun, please check environment")

    # NOTE(review): the engine value is "local_cluster" even though this is
    # the local MPI launcher -- confirm this is intentional.
    cluster_envs = {
        "mpirun": mpirun_path,
        "train.trainer.trainer": "CtrCodingTrainer",
        "log_dir": "logs",
        "train.trainer.engine": "local_cluster",
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
    }
    set_runtime_envs(cluster_envs, args.model)

    return LocalMPIEngine(cluster_envs, args.model)
コード例 #7
0
def local_cluster_engine(args):
    """Build a LocalClusterEngine launcher for the given model.

    The settings request one server and one worker, with ports starting at
    36001 and logs written under "logs".

    Args:
        args: Object exposing ``model`` and ``device`` attributes.

    Returns:
        A ``LocalClusterEngine`` constructed from the environment settings
        and ``args.model``.
    """
    banner = "launch cluster engine with cluster to run model: {}".format(
        args.model)
    print(banner)
    # Imported here rather than at module level, as in the original code.
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    cluster_envs = {
        # Cluster layout and logging.
        "server_num": 1,
        "worker_num": 1,
        "start_port": 36001,
        "log_dir": "logs",
        # Trainer selection and behaviour.
        "train.trainer.trainer": "ClusterTrainer",
        "train.trainer.strategy": "async",
        "train.trainer.threads": "2",
        "train.trainer.engine": "local_cluster",
        # Values taken from the caller and the host.
        "train.trainer.device": args.device,
        "train.trainer.platform": envs.get_platform(),
        "CPU_NUM": "2",
    }
    set_runtime_envs(cluster_envs, args.model)

    return LocalClusterEngine(cluster_envs, args.model)