예제 #1
0
    def train(self, train_ds, train_hooks=None):
        """Train on a `Dataset`.

        Builds the training program with `self._build_for_train`, assembles
        the run hooks and drives a `MonitoredExecutor` over every feed yielded
        by `train_ds.start()`.

        Run hooks are registered in this order: a `StopAtStepHook` (from
        `run_config.max_steps` / `run_config.run_steps`), a `LoggingHook`
        writing to `<model_dir>/train_history`, the model spec's
        `train_hooks` (if any), the caller's `train_hooks`, and finally, on
        the master process only, a `CheckpointSaverHook`.

        Args:
            train_ds: `Dataset` that yields the training feeds.
            train_hooks: optional iterable of extra run hooks. `None` (the
                default) means no extra hooks.

        Returns:
            `mon_exe.result` once the loop has finished (presumably the
            per-hook results in hook order -- TODO confirm against
            `MonitoredExecutor`).

        Raises:
            ValueError: if `train_ds` is not a `Dataset` instance.
        """
        if not isinstance(train_ds, Dataset):
            raise ValueError('expect dataset to be instance of Dataset, got %s'
                             % repr(train_ds))
        # `None` default instead of `[]` so no list object is shared between
        # calls.
        if train_hooks is None:
            train_hooks = []

        train_program, model_spec, summary_record = self._build_for_train(
            train_ds)
        train_run_hooks = [
            hooks.StopAtStepHook(self.run_config.max_steps,
                                 self.run_config.run_steps),
            hooks.LoggingHook(
                model_spec.loss,
                summary_record=summary_record,
                summary_writer=_get_summary_writer(
                    os.path.join(self.run_config.model_dir, 'train_history')),
                per_step=self.run_config.log_steps,
                skip_step=self.run_config.skip_steps),
        ]
        if model_spec.train_hooks is not None:
            train_run_hooks.extend(model_spec.train_hooks)
        train_run_hooks.extend(train_hooks)

        train_executor = F.Executor(_get_one_place())

        mon_exe = MonitoredExecutor(
            train_executor,
            train_program,
            loss=model_spec.loss,
            run_config=self.run_config,
            run_hooks=train_run_hooks,
            warm_start_setting=self.warm_start_setting)

        # Set up the distributed environment for `train_program`: after the
        # executor is built, before variables are initialized/restored.
        distribution.init_distribuition_env(train_program)
        mon_exe.init_or_restore_variables()
        # Only the master process registers a checkpoint-saver hook.
        if distribution.status.is_master:
            mon_exe._hooks.append(
                hooks.CheckpointSaverHook(
                    mon_exe._saver,
                    per_step=mon_exe._save_steps,
                    skip_step=mon_exe._skip_steps))

        try:
            with mon_exe:
                for data in train_ds.start():
                    mon_exe.run(feed=data)
        except (StopException, F.core.EOFException):
            # These two exceptions are treated as normal loop termination,
            # not as errors.
            pass

        return mon_exe.result
예제 #2
0
    def evaluate(self, eval_dataset, eval_hooks=None):
        """Evaluate on a `Dataset`.

        Builds the eval program with `self._build_for_eval`, runs it on a
        single device over every feed of `eval_dataset`, and logs the
        `EvalHook` result via `_log_eval_result` to
        `<model_dir>/eval_history`.

        Args:
            eval_dataset: `Dataset` that yields the evaluation feeds.
            eval_hooks: optional iterable of extra run hooks, registered after
                the built-in hooks and the model spec's `eval_hooks`. `None`
                (the default) means no extra hooks.

        Returns:
            `mon_exe.result`: the per-hook results in hook order
            (`[StopAtStepHook, EvalHook, <extra hooks>...]`).

        Raises:
            ValueError: if `eval_dataset` is not a `Dataset` instance.
        """
        if not isinstance(eval_dataset, Dataset):
            raise ValueError('expect dataset to be instance of Dataset, got %s'
                             % repr(eval_dataset))
        # `None` default instead of `[]` so no list object is shared between
        # calls.
        if eval_hooks is None:
            eval_hooks = []

        program, model_spec = self._build_for_eval(eval_dataset)
        single_card_place = _get_one_place()
        eval_executor = F.Executor(single_card_place)

        eval_run_hooks = [
            hooks.StopAtStepHook(self.run_config.eval_max_steps,
                                 self.run_config.eval_max_steps),
            hooks.EvalHook(model_spec.metrics),
        ]
        if model_spec.eval_hooks is not None:
            eval_run_hooks.extend(model_spec.eval_hooks)
        eval_run_hooks.extend(eval_hooks)

        mon_exe = MonitoredExecutor(
            eval_executor,
            program,
            run_config=self.run_config,
            run_hooks=eval_run_hooks)
        mon_exe.init_or_restore_variables()

        try:
            with mon_exe:
                for data in eval_dataset.start(places=[single_card_place]):
                    mon_exe.run(feed=data)
        except (StopException, F.core.EOFException):
            # These two exceptions are treated as normal loop termination,
            # not as errors.
            pass

        # Results follow hook order: [StopAtStepHook, EvalHook, ...]. Index
        # instead of unpacking into two names so that extra hooks (from the
        # model spec or the caller) don't raise "too many values to unpack".
        eval_result = mon_exe.result[1]

        summary_writer = _get_summary_writer(
            os.path.join(self.run_config.model_dir, 'eval_history'))
        _log_eval_result('eval', eval_result, summary_writer, mon_exe.state)

        return mon_exe.result
예제 #3
0
 def after_run(self, _, state):
     """Run a periodic evaluation pass over every eval dataset.

     Evaluation only happens when `state.step` is past
     `run_config.skip_steps` and `state.gstep` is a multiple of
     `run_config.eval_steps`; otherwise an empty dict is returned.

     For each `(name, ds)` in `eval_dataset`, `self.program` is run on a
     single device with a `StopAtStepHook`, an `EvalHook` (writing to
     `self.summary_writers[name]`) and the extra `eval_hooks`. The `EvalHook`
     result is stored under `name` and logged. Afterwards every exporter in
     `exporters` is called with the collected results.

     NOTE(review): `run_config`, `eval_dataset`, `est`, `eval_hooks` and
     `exporters` are not defined in this method; presumably they are closure
     variables of an enclosing function -- confirm.

     Args:
         _: ignored.
         state: run state; its `step` and `gstep` attributes are read.

     Returns:
         dict mapping each eval-dataset name to its `EvalHook` result, or an
         empty dict when no evaluation was due.
     """
     if not (state.step > run_config.skip_steps
             and state.gstep % run_config.eval_steps == 0):
         return {}

     # The place and executor do not depend on the dataset, so build them
     # once. This also keeps `eval_executor` bound for the exporters below
     # when `eval_dataset` is empty.
     single_card_place = _get_one_place()
     eval_executor = F.Executor(single_card_place)

     eval_results = {}
     for name, ds in six.iteritems(eval_dataset):
         ehooks = [
             hooks.StopAtStepHook(est.run_config.eval_max_steps,
                                  est.run_config.eval_max_steps),
             hooks.EvalHook(
                 self.model_spec.metrics,
                 summary_writer=self.summary_writers[name],
             ),
         ]
         mon_exe = MonitoredExecutor(
             eval_executor,
             self.program,
             run_config=est.run_config,
             run_hooks=ehooks + eval_hooks)
         try:
             with mon_exe:
                 for data in ds.start(places=[single_card_place]):
                     mon_exe.run(feed=data)
         except (StopException, F.core.EOFException):
             # These two exceptions are treated as normal loop termination,
             # not as errors.
             pass
         # Hook results follow hook order: [StopAtStepHook, EvalHook, ...].
         eval_res = mon_exe.result[1]
         eval_results[name] = eval_res
         _log_eval_result(name, eval_res, self.summary_writers[name], state)

     for exporter in exporters:
         exporter.export(eval_executor, self.program, self.model_spec,
                         eval_results, state)
     return eval_results