Example #1: _get_model_desc
 def _get_model_desc(self):
     """Resolve the model description from the trainer, a desc file, an inline desc or a models folder."""
     model_desc = self.trainer.model_desc
     if not model_desc or 'modules' not in model_desc:
         if ModelConfig.model_desc_file is not None:
             desc_file = ModelConfig.model_desc_file
             desc_file = desc_file.replace("{local_base_path}",
                                           self.trainer.local_base_path)
             if ":" not in desc_file:
                 desc_file = os.path.abspath(desc_file)
             if ":" in desc_file:
                 local_desc_file = FileOps.join_path(
                     self.trainer.local_output_path,
                     os.path.basename(desc_file))
                 FileOps.copy_file(desc_file, local_desc_file)
                 desc_file = local_desc_file
             model_desc = Config(desc_file)
             logger.info("net_desc:{}".format(model_desc))
         elif ModelConfig.model_desc is not None:
             model_desc = ModelConfig.model_desc
         elif ModelConfig.models_folder is not None:
             folder = ModelConfig.models_folder.replace(
                 "{local_base_path}", self.trainer.local_base_path)
             pattern = FileOps.join_path(folder, "desc_*.json")
             desc_file = glob.glob(pattern)[0]
             model_desc = Config(desc_file)
         else:
             return None
     return model_desc
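The path handling here is a pattern that recurs throughout these snippets: the {local_base_path} placeholder is replaced with the trainer's working directory, and a ":" in the resulting path is taken to mean a remote URI (for example an OBS or S3 bucket) that must be copied to local storage before it can be read. Below is a minimal standalone sketch of that resolution step, assuming shutil.copyfile as a stand-in for FileOps.copy_file (which in Vega also handles remote paths); resolve_desc_file is an illustrative helper, not part of the Vega API.

import os
import shutil

def resolve_desc_file(desc_file, local_base_path, local_output_path):
    # Substitute the placeholder used in Vega configuration files.
    desc_file = desc_file.replace("{local_base_path}", local_base_path)
    if ":" not in desc_file:
        # Ordinary local path: normalise it and use it directly.
        return os.path.abspath(desc_file)
    # Path with a scheme (e.g. "obs://bucket/desc.json"): copy it into the
    # local output directory first and return the local copy.
    local_copy = os.path.join(local_output_path, os.path.basename(desc_file))
    shutil.copyfile(desc_file, local_copy)  # stand-in for FileOps.copy_file
    return local_copy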
Example #2: _save_best_model
 def _save_best_model(self):
     """Save best model."""
     if zeus.is_torch_backend():
         torch.save(self.trainer.model.state_dict(),
                    self.trainer.weights_file)
     elif zeus.is_tf_backend():
         worker_path = self.trainer.get_local_worker_path()
         model_id = "model_{}".format(self.trainer.worker_id)
         weights_folder = FileOps.join_path(worker_path, model_id)
         FileOps.make_dir(weights_folder)
         checkpoint_file = tf.train.latest_checkpoint(worker_path)
         ckpt_globs = glob.glob("{}.*".format(checkpoint_file))
         for _file in ckpt_globs:
             dst_file = model_id + os.path.splitext(_file)[-1]
             FileOps.copy_file(_file,
                               FileOps.join_path(weights_folder, dst_file))
         FileOps.copy_file(FileOps.join_path(worker_path, 'checkpoint'),
                           weights_folder)
     elif zeus.is_ms_backend():
         worker_path = self.trainer.get_local_worker_path()
         save_path = os.path.join(
             worker_path, "model_{}.ckpt".format(self.trainer.worker_id))
         for file in os.listdir(worker_path):
             if file.startswith("CKP") and file.endswith(".ckpt"):
                 self.weights_file = FileOps.join_path(worker_path, file)
                 os.rename(self.weights_file, save_path)
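In the TensorFlow branch, tf.train.latest_checkpoint returns a prefix such as model.ckpt-1000, and the weights are spread over several files sharing that prefix (.index, .meta, .data-*). The sketch below reproduces only that file collection with the standard library; worker_path, model_id and checkpoint_prefix are illustrative arguments, and shutil.copyfile stands in for FileOps.copy_file.

import glob
import os
import shutil

def collect_tf_checkpoint(worker_path, model_id, checkpoint_prefix):
    # Destination folder named after the worker, e.g. <worker_path>/model_3.
    weights_folder = os.path.join(worker_path, model_id)
    os.makedirs(weights_folder, exist_ok=True)
    # Gather every file of the latest checkpoint, e.g. model.ckpt-1000.index,
    # model.ckpt-1000.meta, model.ckpt-1000.data-00000-of-00001.
    for _file in glob.glob("{}.*".format(checkpoint_prefix)):
        dst_file = model_id + os.path.splitext(_file)[-1]
        shutil.copyfile(_file, os.path.join(weights_folder, dst_file))
    # TensorFlow's "checkpoint" bookkeeping file is copied alongside them.
    shutil.copyfile(os.path.join(worker_path, "checkpoint"),
                    os.path.join(weights_folder, "checkpoint"))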
Example #3: _do_horovod_fully_train
 def _do_horovod_fully_train(self):
     """Serialize the registry and configuration, then launch the Horovod training script."""
     pwd_dir = os.path.dirname(os.path.abspath(__file__))
     cf_file = os.path.join(pwd_dir, 'cf.pickle')
     cf_content = {
         'registry': ClassFactory.__registry__,
         'general_config': General().to_json(),
         'pipe_step_config': PipeStepConfig().to_json()
     }
     with open(cf_file, 'wb') as f:
         pickle.dump(cf_content, f)
     cf_file_remote = os.path.join(self.task.local_base_path, 'cf.pickle')
     FileOps.copy_file(cf_file, cf_file_remote)
     if os.environ.get('DLS_TASK_NUMBER') is None:
         # local cluster
         worker_ips = '127.0.0.1'
         if General.cluster.master_ip is not None and General.cluster.master_ip != '127.0.0.1':
             worker_ips = General.cluster.master_ip
             for ip in General.cluster.slaves:
                 worker_ips = worker_ips + ',' + ip
         cmd = [
             'bash',
             '{}/horovod/run_cluster_horovod_train.sh'.format(pwd_dir),
             str(self.world_device_size), cf_file_remote, worker_ips
         ]
     else:
         # Roma
         cmd = [
             'bash', '{}/horovod/run_horovod_train.sh'.format(pwd_dir),
             str(self.world_device_size), cf_file_remote
         ]
     proc = subprocess.Popen(cmd, env=os.environ)
     proc.wait()
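Example #3 is the launcher half of the Horovod flow: it pickles the ClassFactory registry together with the general and pipe-step configuration, copies the pickle to a path every node can reach, and then runs a Horovod shell script with the device count, the pickle path and, for a local cluster, the comma-separated worker IPs. Example #10 below is the worker half that reloads this pickle. Here is a minimal sketch of the same pickle-and-subprocess handoff using only the standard library; the payload contents and the child command are placeholders for the real registry and launch script.

import os
import pickle
import subprocess
import sys
import tempfile

# Placeholder payload; the real one carries ClassFactory.__registry__ and the
# general / pipe-step configuration dictionaries.
payload = {"registry": {"trainer": "Trainer"}, "general_config": {}}
cf_file = os.path.join(tempfile.gettempdir(), "cf.pickle")
with open(cf_file, "wb") as f:
    pickle.dump(payload, f)

# Example #3 launches a Horovod shell script at this point; a trivial child
# process stands in here so the sketch stays runnable anywhere.
proc = subprocess.Popen([sys.executable, "-c", "print('worker would start here')"],
                        env=os.environ)
proc.wait()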
Example #4: _get_pretrained_model_file
 def _get_pretrained_model_file(self):
     """Resolve the pretrained model file, copying remote files to local storage."""
     if ModelConfig.pretrained_model_file:
         model_file = ModelConfig.pretrained_model_file
         model_file = model_file.replace("{local_base_path}", self.trainer.local_base_path)
         model_file = model_file.replace("{worker_id}", str(self.trainer.worker_id))
         if ":" not in model_file:
             model_file = os.path.abspath(model_file)
         if ":" in model_file:
             local_model_file = FileOps.join_path(
                 self.trainer.local_output_path, os.path.basename(model_file))
             FileOps.copy_file(model_file, local_model_file)
             model_file = local_model_file
         return model_file
     else:
         return None
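Pretrained weights are resolved the same way as the description file in Example #1, with one addition: the {worker_id} placeholder is also substituted, so each trainer worker can point at its own pretrained model before the remote-copy step runs.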
Example #5: _save_descript
 def _save_descript(self):
     """Save result descript."""
     template_file = self.config.darts_template_file
     genotypes = self.search_alg.codec.calc_genotype(
         self._get_arch_weights())
     if template_file == "{default_darts_cifar10_template}":
         template = DartsNetworkTemplateConfig.cifar10
     elif template_file == "{default_darts_imagenet_template}":
         template = DartsNetworkTemplateConfig.imagenet
     else:
         dst = FileOps.join_path(self.trainer.get_local_worker_path(),
                                 os.path.basename(template_file))
         FileOps.copy_file(template_file, dst)
         template = Config(dst)
     model_desc = self._gen_model_desc(genotypes, template)
     self.trainer.config.codec = model_desc
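The genotype decoded from the architecture weights is merged into a network template: the sentinel values {default_darts_cifar10_template} and {default_darts_imagenet_template} select the built-in CIFAR-10 and ImageNet templates, while any other value is treated as a template file that is copied into the worker directory and loaded as a Config. The generated model description is then attached to the trainer configuration as codec.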
Example #6: _new_model_init
 def _new_model_init(self):
     """Init new model.

     :return: initial model after loading pretrained model
     :rtype: torch.nn.Module
     """
     init_model_file = self.config.init_model_file
     if ":" in init_model_file:
         local_path = FileOps.join_path(
             self.trainer.get_local_worker_path(),
             os.path.basename(init_model_file))
         FileOps.copy_file(init_model_file, local_path)
         self.config.init_model_file = local_path
     network_desc = copy.deepcopy(self.base_net_desc)
     network_desc.backbone.cfgs = network_desc.backbone.base_cfgs
     model_init = NetworkDesc(network_desc).to_model()
     return model_init
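As in the other snippets, a ":" in init_model_file marks a remote location that is first copied into the worker directory and recorded back into the config. The new model is then built from a deep copy of the base network description, with the backbone reset to its base configuration before being instantiated through NetworkDesc.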
Example #7: _save_best_model (Torch and TensorFlow branches)
 def _save_best_model(self):
     """Save best model."""
     if zeus.is_torch_backend():
         torch.save(self.trainer.model.state_dict(),
                    self.trainer.weights_file)
     elif zeus.is_tf_backend():
         worker_path = self.trainer.get_local_worker_path()
         model_id = "model_{}".format(self.trainer.worker_id)
         weights_folder = FileOps.join_path(worker_path, model_id)
         FileOps.make_dir(weights_folder)
         checkpoint_file = tf.train.latest_checkpoint(worker_path)
         ckpt_globs = glob.glob("{}.*".format(checkpoint_file))
         for _file in ckpt_globs:
             dst_file = model_id + os.path.splitext(_file)[-1]
             FileOps.copy_file(_file,
                               FileOps.join_path(weights_folder, dst_file))
         FileOps.copy_file(FileOps.join_path(worker_path, 'checkpoint'),
                           weights_folder)
Example #8: _copy_needed_file
 def _copy_needed_file(self):
     """Copy the pareto front and random files into the step output directory."""
     if self.config.pareto_front_file is None:
         raise FileNotFoundError(
             "Config item paretor_front_file not found in config file.")
     init_pareto_front_file = self.config.pareto_front_file.replace(
         "{local_base_path}", self.local_base_path)
     self.pareto_front_file = FileOps.join_path(
         self.local_output_path, self.step_name, "pareto_front.csv")
     FileOps.make_base_dir(self.pareto_front_file)
     FileOps.copy_file(init_pareto_front_file, self.pareto_front_file)
     if self.config.random_file is None:
         raise FileNotFoundError(
             "Config item random_file not found in config file.")
     init_random_file = self.config.random_file.replace(
         "{local_base_path}", self.local_base_path)
     self.random_file = FileOps.join_path(
         self.local_output_path, self.step_name, "random.csv")
     FileOps.copy_file(init_random_file, self.random_file)
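Both input files may contain the {local_base_path} placeholder; after substitution they are copied to fixed names, pareto_front.csv and random.csv, under the current step's output directory, with make_base_dir creating the destination directory before the first copy. A missing config item raises FileNotFoundError immediately instead of failing later in the pipeline.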
Example #9: _output_records
 def _output_records(self,
                     step_name,
                     records,
                     desc=True,
                     weights_file=False,
                     performance=False):
     """Dump records."""
     columns = ["worker_id", "performance", "desc"]
     outputs = []
     for record in records:
         record = record.serialize()
         _record = {}
         for key in columns:
             _record[key] = record[key]
         outputs.append(deepcopy(_record))
     data = pd.DataFrame(outputs)
     step_path = FileOps.join_path(TaskOps().local_output_path, step_name)
     FileOps.make_dir(step_path)
     _file = FileOps.join_path(step_path, "output.csv")
     try:
         data.to_csv(_file, index=False)
     except Exception:
         logging.error("Failed to save output file, file={}".format(_file))
     for record in outputs:
         worker_id = record["worker_id"]
         worker_path = TaskOps().get_local_worker_path(step_name, worker_id)
         outputs_globs = []
         if desc:
             outputs_globs += glob.glob(
                 FileOps.join_path(worker_path, "desc_*.json"))
         if weights_file:
             outputs_globs += glob.glob(
                 FileOps.join_path(worker_path, "model_*"))
         if performance:
             outputs_globs += glob.glob(
                 FileOps.join_path(worker_path, "performance_*.json"))
         for _file in outputs_globs:
             if os.path.isfile(_file):
                 FileOps.copy_file(_file, step_path)
             elif os.path.isdir(_file):
                 FileOps.copy_folder(
                     _file,
                     FileOps.join_path(step_path, os.path.basename(_file)))
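The records are flattened into fixed columns and written to output.csv; the per-worker artifacts (description JSONs, model weights, performance JSONs) are then copied next to it, folders included. Below is a minimal sketch of the CSV dump, assuming each serialized record is a dict exposing the three columns; the sample values are invented for illustration.

import pandas as pd

records = [
    {"worker_id": 0, "performance": {"accuracy": 0.91}, "desc": {"type": "ResNet"}},
    {"worker_id": 1, "performance": {"accuracy": 0.89}, "desc": {"type": "ResNet"}},
]
pd.DataFrame(records)[["worker_id", "performance", "desc"]].to_csv(
    "output.csv", index=False)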
Example #10: Horovod training entry script
import argparse
import os
import pickle
import logging
import horovod.torch as hvd
from zeus.common import ClassFactory
from zeus.common.general import General
from zeus.common import FileOps
from vega.core.pipeline.conf import PipeStepConfig

parser = argparse.ArgumentParser(description='Horovod Fully Train')
parser.add_argument('--cf_file', type=str, help='ClassFactory pickle file')
args = parser.parse_args()

if 'VEGA_INIT_ENV' in os.environ:
    exec(os.environ.copy()['VEGA_INIT_ENV'])
logging.info('start horovod setting')
hvd.init()
try:
    import moxing as mox
    mox.file.set_auth(obs_client_log=False)
except Exception:
    pass
FileOps.copy_file(args.cf_file, './cf_file.pickle')
hvd.join()
with open('./cf_file.pickle', 'rb') as f:
    cf_content = pickle.load(f)
ClassFactory.__registry__ = cf_content.get('registry')
General.from_json(cf_content.get('general_config'))
PipeStepConfig.from_json(cf_content.get('pipe_step_config'))
cls_trainer = ClassFactory.get_cls('trainer', "Trainer")
trainer = cls_trainer(None, 0)
trainer.train_process()
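This script is the worker-side counterpart of Example #3: every Horovod process restores the pickled class registry and configuration, rebuilds the Trainer class through the ClassFactory, and runs train_process, with Horovod coordinating the ranks. The optional moxing import appears to enable OBS file access when running on Huawei cloud and is skipped silently elsewhere. The run_*_horovod_train.sh scripts referenced in Example #3 are not shown; under plain Horovod a comparable manual launch would look roughly like horovodrun -np 4 python this_script.py --cf_file cf.pickle, where the script name is an assumption.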