示例#1
0
    def __init__(self,
                 model_fn=None,
                 model_dir=None,
                 config=None,
                 params=None):
        """Sets up the run configuration, model directory and model function.

        Args:
          model_fn: Model function. Checked together with `params` by
            `self._verify_model_fn_args` and stored on `self._model_fn`.
          model_dir: Optional model directory. If both this and
            `config.model_dir` are set, they must be equal.
          config: Optional `RunConfig`; a default `RunConfig()` is used when
            None.
          params: Optional params; stored as `{}` when falsy.

        Raises:
          ValueError: If `config` is not a `RunConfig` instance, or if
            `model_dir` and `config.model_dir` are both set but differ.
        """
        # Create a run configuration.
        if config is None:
            self._config = RunConfig()
            logging.info("Using default config.")
        else:
            if not isinstance(config, RunConfig):
                raise ValueError("config must be an instance of RunConfig, "
                                 "received {}.".format(config))
            self._config = config

        # The constructor argument and the RunConfig may not disagree.
        if (model_dir is not None and self._config.model_dir is not None and
                model_dir != self._config.model_dir):
            # pylint: disable=g-doc-exception
            raise ValueError(
                "model_dir are set both in constructor and RunConfig, but with "
                "different values. In constructor: '{}', in RunConfig: "
                "'{}' ".format(model_dir, self._config.model_dir))

        # Precedence: constructor argument, then RunConfig, then a generated
        # directory.
        self._model_dir = (model_dir or self._config.model_dir or
                           generate_model_dir())
        if self._config.model_dir is None:
            # Record the resolved directory on the config as well.
            self._config = self._config.replace(model_dir=self._model_dir)
        # Log only after model_dir has been resolved so the logged config shows
        # the directory actually in use (previously it could show None).
        logging.info("Using config: %s", vars(self._config))

        if self._config.session_config is None:
            # Default session config: allow soft placement.
            self._session_config = config_pb2.ConfigProto(
                allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        # Set device function depending if there are replicas or not.
        self._device_fn = _get_replica_device_setter(self._config)

        self._graph = None

        self._verify_model_fn_args(model_fn, params)

        self._model_fn = model_fn
        self._params = params or {}
示例#2
0
    def __init__(self, model_fn, model_dir=None, config=None, params=None):
        """Sets up the run configuration, model directory and model function.

        Args:
          model_fn: Model function. Checked together with `params` by
            `self._verify_model_fn_args` and stored on `self._model_fn`.
          model_dir: Optional model directory. If both this and
            `config.model_dir` are set, they must be equal.
          config: Optional `RunConfig`; a default `RunConfig()` is used when
            None.
          params: Optional params; stored as `{}` when falsy.

        Raises:
          ValueError: If `config` is not a `RunConfig` instance, or if
            `model_dir` and `config.model_dir` are both set but differ.
        """
        # Create a run configuration.
        if config is None:
            self._config = RunConfig()
            logging.info("Using default config.")
        else:
            if not isinstance(config, RunConfig):
                raise ValueError("config must be an instance of RunConfig, "
                                 "received {}.".format(config))
            self._config = config

        # The constructor argument and the RunConfig may not disagree.
        if (model_dir is not None and self._config.model_dir is not None and
                model_dir != self._config.model_dir):
            # pylint: disable=g-doc-exception
            raise ValueError(
                "model_dir are set both in constructor and RunConfig, but with "
                "different values. In constructor: '{}', in RunConfig: "
                "'{}' ".format(model_dir, self._config.model_dir))

        # Precedence: constructor argument, then RunConfig, then a generated
        # directory.
        self._model_dir = (model_dir or self._config.model_dir or
                           generate_model_dir())
        if self._config.model_dir is None:
            # Record the resolved directory on the config as well.
            self._config = self._config.replace(model_dir=self._model_dir)
        # Lazy %-style args: the message is only rendered if INFO is enabled.
        logging.info("Using config: %s", vars(self._config))

        if self._config.session_config is None:
            # Default session config: allow soft placement.
            self._session_config = config_pb2.ConfigProto(
                allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        # Set device function depending if there are replicas or not.
        self._device_fn = _get_replica_device_setter(self._config)

        self._graph = None

        self._verify_model_fn_args(model_fn, params)

        self._model_fn = model_fn
        self._params = params or {}
示例#3
0
 def __init__(self, module='Estimator', output_dir=None, params=None):
     """Stores the module name, output directory and params on the instance.

     Args:
       module: Stored on `self.module`; defaults to 'Estimator'.
       output_dir: Stored on `self.output_dir`. When falsy (None or an empty
         string), the value returned by `generate_model_dir()` is used.
       params: Stored unchanged on `self.params` (may be None).
     """
     self.module = module
     if output_dir:
         chosen_dir = output_dir
     else:
         # No usable directory supplied: fall back to a generated one.
         chosen_dir = generate_model_dir()
     self.output_dir = chosen_dir
     self.params = params
示例#4
0
 def __init__(self, module='Estimator', output_dir=None, params=None):
     """Records constructor arguments, generating an output directory if needed.

     Args:
       module: Saved as `self.module`; defaults to 'Estimator'.
       output_dir: Saved as `self.output_dir`; any falsy value is replaced by
         the result of `generate_model_dir()`.
       params: Saved as `self.params` exactly as given.
     """
     self.module = module
     # A missing or empty output_dir means "pick a directory for me".
     self.output_dir = output_dir if output_dir else generate_model_dir()
     self.params = params