예제 #1
0
    def __init__(self, config):
        """Initialise the trainer from a configuration dict.

        Resolves ``config['output_path']`` to an absolute path and writes the
        result back into ``config``. Then builds the output-path generator and
        registers one context processor per context name.

        Args:
            config: Trainer configuration. Keys read here are
                ``'output_path'``, ``config['io']['afs']`` and the optional
                ``'path_generator'`` (extra path templates).
        """
        Trainer.__init__(self, config)
        # Resolve the output root once and store it back into the shared
        # config, so every later reader of ``config`` sees the absolute path.
        output_root = util.get_absolute_path(
            config['output_path'], config['io']['afs'])
        config['output_path'] = output_root

        self.global_config = config
        self._metrics = {}

        # (template name, path suffix appended to the output root).
        # NOTE: the 'xbox_delta_done' marker file is named xbox_patch_done.txt.
        template_specs = [
            ('xbox_base_done', '/xbox_base_done.txt'),
            ('xbox_delta_done', '/xbox_patch_done.txt'),
            ('xbox_base', '/xbox/{day}/base/'),
            ('xbox_delta', '/xbox/{day}/delta-{pass_id}/'),
            ('batch_model', '/batch_model/{day}/{pass_id}/'),
        ]
        self._path_generator = util.PathGenerator({
            'templates': [
                {'name': name, 'template': output_root + suffix}
                for name, suffix in template_specs
            ]
        })
        # Optional user-supplied templates extend the defaults above.
        if 'path_generator' in config:
            self._path_generator.add_path_template(config['path_generator'])

        # Context name -> handler method name, registered in this order.
        processor_specs = (
            ('uninit', 'init'),
            ('startup', 'startup'),
            ('begin_day', 'begin_day'),
            ('train_pass', 'train_pass'),
            ('end_day', 'end_day'),
        )
        for context_name, method_name in processor_specs:
            self.regist_context_processor(context_name,
                                          getattr(self, method_name))
예제 #2
0
    def __init__(self, config):
        """Initialise the trainer from ``config``.

        Runs the base ``Trainer`` initialiser, keeps a reference to the
        configuration as ``global_config``, starts with an empty metrics
        dict and finally calls ``processor_register``.
        """
        Trainer.__init__(self, config)

        # Both attributes are independent; set them together.
        self.global_config, self._metrics = config, {}
        self.processor_register()
예제 #3
0
 def __init__(self, config=None):
     """Initialise the trainer, optionally binding a GPU executor.

     The device is read from the global env key ``"train.device"``
     (default ``"cpu"``). Only the value ``'gpu'`` has an effect here: it
     sets ``self._place`` to CUDA device 0 and ``self._exe`` to an executor
     on that place. Afterwards the processors are registered and the
     model-related attributes are reset.
     """
     Trainer.__init__(self, config)
     if envs.get_global_env("train.device", "cpu") == 'gpu':
         # Only the first CUDA device (id 0) is used.
         cuda_place = fluid.CUDAPlace(0)
         self._place = cuda_place
         self._exe = fluid.Executor(cuda_place)
     self.processor_register()
     self.model = None
     self.inference_models, self.increment_models = [], []
예제 #4
0
 def __init__(self, config=None):
     """Initialise the trainer and record runner-related names.

     After base initialisation and processor registration, sets
     ``abs_dir`` to the directory containing this module file. It also sets
     ``runner_env_name`` to ``"runner."`` followed by
     ``self._context["runner_name"]``.
     """
     Trainer.__init__(self, config)
     self.processor_register()
     module_file = os.path.abspath(__file__)
     self.abs_dir = os.path.dirname(module_file)
     # NOTE(review): assumes self._context is populated by Trainer.__init__
     # or processor_register() and contains "runner_name" — confirm.
     runner_name = self._context["runner_name"]
     self.runner_env_name = "runner." + runner_name