def load_config(config_file):
    """Restore the pickled runtime configuration stored in ``config_file``.

    The file is unpickled into a mapping. Every entry of its ``"env"``
    section whose value is not ``None`` is exported into ``os.environ``;
    the vega backend is then selected from the ``BACKEND_TYPE`` and
    ``DEVICE_CATEGORY`` environment variables. Finally the class-factory
    registry and the general, dataset, model, trainer, evaluator and
    pipe-step configurations are restored from the remaining sections.

    :param config_file: path of the pickle file to read.
    """
    import os
    import pickle
    import vega

    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # presumably the file is written by the same job -- confirm with caller.
    with open(config_file, 'rb') as stream:
        settings = pickle.load(stream)

    for name, setting in settings["env"].items():
        if value_is_missing := setting is None:
            continue
        os.environ[name] = setting

    backend_type = os.environ['BACKEND_TYPE'].lower()
    device_category = os.environ["DEVICE_CATEGORY"]
    vega.set_backend(backend_type, device_category)

    # These imports intentionally stay below set_backend(), exactly as in the
    # original ordering (presumably they depend on the selected backend).
    from vega.common.class_factory import ClassFactory
    from vega.common.general import General
    from vega.datasets.conf.dataset import DatasetConfig
    from vega.networks.model_config import ModelConfig
    from vega.trainer.conf import TrainerConfig
    from vega.evaluator.conf import EvaluatorConfig
    from vega.core.pipeline.conf import PipeStepConfig

    ClassFactory.__registry__ = settings["class_factory"]

    # (config class, section key) pairs, restored in this exact order.
    restore_plan = (
        (General, "general"),
        (DatasetConfig, "dataset"),
        (ModelConfig, "model"),
        (TrainerConfig, "trainer"),
        (EvaluatorConfig, "evaluator"),
        (PipeStepConfig, "pipe_step"),
    )
    for conf_cls, section in restore_plan:
        conf_cls.from_dict(settings[section])
def _simulate_tiny_pipeline(self, cfg_tiny):
    """Simulate tiny pipeline by using one sample one epoch.

    Walks ``PipelineConfig.steps`` and processes only the steps whose
    ``pipe_step.type`` is ``'SearchPipeStep'``. For each such step the
    trainer is forced to non-distributed mode with a single epoch, the
    trial count in ``self.restrict_config.trials`` is set to 1, and a
    ``PipeStep`` is built from the step configuration.

    The pipe step is actually executed only when the step sits at index 0
    of ``PipelineConfig.steps``; the runtime of its last report record is
    stored in ``self.epoch_time`` and the local worker directory is removed
    afterwards. For every processed step the search algorithm's
    ``max_samples`` is recorded in ``self.params_dict`` and the step's
    ``.generator`` file is deleted.

    :param cfg_tiny: mapping-like object returning the tiny configuration
        of a step via ``cfg_tiny.get(step_name)``.
    """
    # Local imports: only the cleanup code below needs these modules.
    import contextlib
    import shutil

    report = ReportServer()
    for i, step_name in enumerate(PipelineConfig.steps):
        step_cfg = cfg_tiny.get(step_name)
        if step_cfg.pipe_step.type != 'SearchPipeStep':
            continue

        # Shrink the step to the smallest possible run.
        step_cfg.trainer.distributed = False
        step_cfg.trainer.epochs = 1
        self.restrict_config.trials[step_name] = 1
        General.step_name = step_name
        PipeStepConfig.from_dict(step_cfg)
        pipestep = PipeStep()

        # NOTE(review): the timing run happens only for index 0 of
        # PipelineConfig.steps, not for the first *search* step -- confirm
        # this is intended when the first step is not a search step.
        if i == 0:
            pipestep.do()
            record = report.get_step_records(step_name)[-1]
            self.epoch_time = record.runtime
            _worker_path = TaskOps().local_base_path
            if os.path.exists(_worker_path):
                # Replaces os.system('rm -rf ...'): no shell is involved, so
                # paths with spaces or metacharacters are handled correctly.
                # Errors are ignored, matching the best-effort 'rm -rf'
                # (assumes the path is a directory -- TODO confirm).
                shutil.rmtree(_worker_path, ignore_errors=True)

        # Re-checked as in the original; the guard at the top of the loop
        # has already skipped steps of any other type.
        if step_cfg.pipe_step.type == 'SearchPipeStep':
            self.params_dict[step_name]['max_samples'] = \
                pipestep.generator.search_alg.max_samples

        _file = os.path.join(TaskOps().step_path, ".generator")
        # Best-effort removal without spawning a shell; a missing file or a
        # failed delete is ignored, like the unchecked 'rm' it replaces.
        with contextlib.suppress(OSError):
            os.remove(_file)
import horovod.torch as hvd
from zeus.common import ClassFactory
from zeus.common.general import General
from vega.core.pipeline.conf import PipeStepConfig

# Command line: a single option giving the pickled job description file.
parser = argparse.ArgumentParser(description='Horovod Fully Train')
parser.add_argument('--cf_file', type=str, help='ClassFactory pickle file')
args = parser.parse_args()

# SECURITY NOTE(review): the content of VEGA_INIT_ENV is executed as Python
# code. Kept as-is because callers rely on it; only safe while the
# environment is controlled by the launcher -- confirm.
if 'VEGA_INIT_ENV' in os.environ:
    exec(os.environ['VEGA_INIT_ENV'])

logging.info('start horovod setting')
hvd.init()

# moxing is optional: set it up when available, otherwise carry on. The
# failure is logged at debug level instead of being dropped silently.
try:
    import moxing as mox
    mox.file.set_auth(obs_client_log=False)
except Exception:
    logging.debug('moxing setup skipped', exc_info=True)

# Called before the job description is loaded (presumably to synchronize
# the workers -- confirm against the Horovod documentation).
hvd.join()

# SECURITY NOTE(review): pickle.load can execute arbitrary code; the file is
# presumably written by the launching process -- confirm.
with open(args.cf_file, 'rb') as f:
    cf_content = pickle.load(f)

model_desc = cf_content.get('model_desc')
worker_id = cf_content.get('worker_id')

# Restore the class registry and configurations captured in the pickle
# before looking up the trainer class.
ClassFactory.__registry__ = cf_content.get('registry')
General.from_dict(cf_content.get('general_config'))
PipeStepConfig.from_dict(cf_content.get('pipe_step_config'))

cls_trainer = ClassFactory.get_cls('trainer')
trainer = cls_trainer(model_desc=model_desc, id=worker_id)
trainer.train_process()