Esempio n. 1
0
def load_config(config_file):
    """Restore a pickled Vega configuration into the current process.

    The file at ``config_file`` is unpickled into a dict. The steps are:

    1. Every non-``None`` entry of its ``"env"`` mapping is written into
       ``os.environ``.
    2. The backend is selected from the ``BACKEND_TYPE`` (lower-cased) and
       ``DEVICE_CATEGORY`` environment variables.
    3. The ``"class_factory"`` entry replaces ``ClassFactory.__registry__``.
    4. The ``"general"``, ``"dataset"``, ``"model"``, ``"trainer"``,
       ``"evaluator"`` and ``"pipe_step"`` sections are loaded, in that
       order, through the matching config classes' ``from_dict``.

    Note: ``pickle.load`` can execute arbitrary code, so ``config_file``
    must come from a trusted source.

    :param config_file: path to the pickled configuration file.
    :raises KeyError: if a required section or environment variable is missing.
    """
    import os
    import pickle
    import vega

    with open(config_file, 'rb') as stream:
        saved = pickle.load(stream)

    for env_name, env_value in saved["env"].items():
        if env_value is None:
            continue
        os.environ[env_name] = env_value

    backend_type = os.environ['BACKEND_TYPE'].lower()
    device_category = os.environ["DEVICE_CATEGORY"]
    vega.set_backend(backend_type, device_category)

    # Imported only after the backend is set; presumably these modules
    # depend on the selected backend (TODO confirm before moving them).
    from vega.common.class_factory import ClassFactory
    from vega.common.general import General
    from vega.datasets.conf.dataset import DatasetConfig
    from vega.networks.model_config import ModelConfig
    from vega.trainer.conf import TrainerConfig
    from vega.evaluator.conf import EvaluatorConfig
    from vega.core.pipeline.conf import PipeStepConfig

    ClassFactory.__registry__ = saved["class_factory"]
    # Load order matches the original sequence of from_dict calls.
    section_loaders = (
        (General, "general"),
        (DatasetConfig, "dataset"),
        (ModelConfig, "model"),
        (TrainerConfig, "trainer"),
        (EvaluatorConfig, "evaluator"),
        (PipeStepConfig, "pipe_step"),
    )
    for conf_cls, section in section_loaders:
        conf_cls.from_dict(saved[section])
Esempio n. 2
0
 def _simulate_tiny_pipeline(self, cfg_tiny):
     """Simulate tiny pipeline by using one sample one epoch.

     Each step listed in ``PipelineConfig.steps`` is processed only if its
     ``pipe_step.type`` is ``'SearchPipeStep'``; other steps are skipped.
     For a processed step, this method:

     - switches the step config to a single, non-distributed epoch;
     - sets the step's entry in ``self.restrict_config.trials`` to 1;
     - sets ``General.step_name`` and loads the config into
       ``PipeStepConfig``;
     - builds a ``PipeStep``.

     Only if that step is the first entry of ``PipelineConfig.steps`` is it
     actually executed. In that case the runtime of its last report record
     is stored in ``self.epoch_time`` and ``TaskOps().local_base_path`` is
     deleted.

     For every processed step, the generator's ``search_alg.max_samples`` is
     stored in ``self.params_dict[step_name]['max_samples']``. Any
     ``.generator`` file under ``TaskOps().step_path`` is then removed.

     :param cfg_tiny: mapping from step name to step config, read with
         ``cfg_tiny.get(step_name)``.
     """
     import shutil

     report = ReportServer()
     for i, step_name in enumerate(PipelineConfig.steps):
         step_cfg = cfg_tiny.get(step_name)
         if step_cfg.pipe_step.type != 'SearchPipeStep':
             continue
         step_cfg.trainer.distributed = False
         step_cfg.trainer.epochs = 1
         self.restrict_config.trials[step_name] = 1
         General.step_name = step_name
         PipeStepConfig.from_dict(step_cfg)
         pipestep = PipeStep()
         # NOTE(review): ``i`` also counts the skipped non-search steps. If
         # the first pipeline step is not a SearchPipeStep, no step is ever
         # executed and ``self.epoch_time`` is never set. Confirm this is
         # intended.
         if i == 0:
             pipestep.do()
             record = report.get_step_records(step_name)[-1]
             self.epoch_time = record.runtime
             worker_path = TaskOps().local_base_path
             if os.path.exists(worker_path):
                 # Delete in-process instead of via ``rm -rf {}``. An unquoted
                 # shell path containing spaces or shell metacharacters would
                 # make ``rm -rf`` delete the wrong targets. ignore_errors
                 # keeps the previous best-effort behaviour.
                 shutil.rmtree(worker_path, ignore_errors=True)
         # Every step reaching this point is a SearchPipeStep (others were
         # skipped above).
         search_alg = pipestep.generator.search_alg
         self.params_dict[step_name]['max_samples'] = search_alg.max_samples
         generator_file = os.path.join(TaskOps().step_path, ".generator")
         if os.path.exists(generator_file):
             os.remove(generator_file)
Esempio n. 3
0
import argparse
import logging
import os
import pickle

import horovod.torch as hvd
from zeus.common import ClassFactory
from zeus.common.general import General
from vega.core.pipeline.conf import PipeStepConfig

# Horovod fully-train entry script.
# Restores the class registry and configs from a pickled file, then builds the
# registered 'trainer' class and runs its train_process().
parser = argparse.ArgumentParser(description='Horovod Fully Train')
parser.add_argument('--cf_file', type=str, help='ClassFactory pickle file')
args = parser.parse_args()

# If set, VEGA_INIT_ENV holds Python source that is executed in this module's
# namespace.
# SECURITY: this runs arbitrary code taken from the environment; only launch
# this script in an environment you control.
if 'VEGA_INIT_ENV' in os.environ:
    exec(os.environ['VEGA_INIT_ENV'])
logging.info('start horovod setting')
hvd.init()
# moxing may not be installed; configuring its auth is best-effort only.
try:
    import moxing as mox
    mox.file.set_auth(obs_client_log=False)
except Exception as exc:
    logging.debug('skip moxing auth setup: %s', exc)
# Blocks until all ranks have called join().
hvd.join()

# SECURITY: pickle.load can execute arbitrary code; --cf_file must be trusted.
with open(args.cf_file, 'rb') as f:
    cf_content = pickle.load(f)
model_desc = cf_content.get('model_desc')
worker_id = cf_content.get('worker_id')
ClassFactory.__registry__ = cf_content.get('registry')
General.from_dict(cf_content.get('general_config'))
PipeStepConfig.from_dict(cf_content.get('pipe_step_config'))
cls_trainer = ClassFactory.get_cls('trainer')
trainer = cls_trainer(model_desc=model_desc, id=worker_id)
trainer.train_process()