Esempio n. 1
0
 def _simulate_tiny_pipeline(self, cfg_tiny):
     """Simulate tiny pipeline by using one sample one epoch.

     Every step listed in ``PipelineConfig.steps`` has its entry in
     ``cfg_tiny`` forced to a non-distributed, single-epoch trainer and is
     loaded into ``PipeStepConfig``. Only the first step is actually run;
     the runtime of its last report record is stored in ``self.epoch_time``
     and the local worker directory is removed afterwards.

     :param cfg_tiny: config queried with ``get(step_name)`` for each step;
         the returned step configs are modified in place.
     """
     # Function-scope imports keep this method self-contained.
     import logging
     import shutil

     def _remove(path, recursive=False):
         """Best-effort removal of ``path`` without going through a shell."""
         try:
             if recursive and os.path.isdir(path) and not os.path.islink(path):
                 shutil.rmtree(path)
             else:
                 os.remove(path)
         except OSError as err:
             # Cleanup must not abort the simulation; report and carry on.
             logging.warning("Failed to remove %s: %s", path, err)

     report = Report()
     for index, step_name in enumerate(PipelineConfig.steps):
         step_cfg = cfg_tiny.get(step_name)
         step_cfg.trainer.distributed = False
         step_cfg.trainer.epochs = 1
         self.restrict_config.trials[step_name] = 1
         General.step_name = step_name
         PipeStepConfig.from_json(step_cfg)
         pipestep = PipeStep()
         if index == 0:
             # Only the first step is executed to measure the epoch time.
             pipestep.do()
             record = report.get_step_records(step_name)[-1]
             self.epoch_time = record.runtime
             worker_path = TaskOps().local_base_path
             if os.path.exists(worker_path):
                 _remove(worker_path, recursive=True)
         if step_cfg.pipe_step.type == 'NasPipeStep':
             # Record the sample budget of the step's search algorithm.
             self.params_dict[step_name][
                 'max_samples'] = pipestep.generator.search_alg.max_samples
         generator_file = os.path.join(TaskOps().step_path, ".generator")
         if os.path.exists(generator_file):
             _remove(generator_file)
Esempio n. 2
0
import argparse
import logging
import os
import pickle

import horovod.torch as hvd

from zeus.common import ClassFactory
from zeus.common.general import General
from zeus.common import FileOps
from vega.core.pipeline.conf import PipeStepConfig

# Entry script for a Horovod worker: restore the pickled ClassFactory
# registry and configs passed via --cf_file, then run the trainer.
# The statements stay at module level on purpose: the exec() below runs in
# this module's global namespace, which a wrapping function would change.
parser = argparse.ArgumentParser(description='Horovod Fully Train')
parser.add_argument('--cf_file', type=str, help='ClassFactory pickle file')
args = parser.parse_args()

if 'VEGA_INIT_ENV' in os.environ:
    # SECURITY NOTE(review): this executes arbitrary Python taken from an
    # environment variable. Only safe when the launcher fully controls the
    # worker environment -- confirm.
    exec(os.environ['VEGA_INIT_ENV'])
logging.info('start horovod setting')
hvd.init()
try:
    # moxing is optional; configure it only when it can be imported.
    import moxing as mox
    mox.file.set_auth(obs_client_log=False)
except Exception as err:
    # Best effort: continue without moxing, but leave a trace of the reason.
    logging.debug('moxing setup skipped: %s', err)
FileOps.copy_file(args.cf_file, './cf_file.pickle')
hvd.join()
with open('./cf_file.pickle', 'rb') as f:
    # SECURITY NOTE(review): pickle.load can execute arbitrary code; the
    # file named by --cf_file must come from a trusted source.
    cf_content = pickle.load(f)
ClassFactory.__registry__ = cf_content.get('registry')
General.from_json(cf_content.get('general_config'))
PipeStepConfig.from_json(cf_content.get('pipe_step_config'))
cls_trainer = ClassFactory.get_cls('trainer', "Trainer")
trainer = cls_trainer(None, 0)
trainer.train_process()