Example No. 1
def _wait_model_to_load(grpc_proxy_client, max_seconds):
    """Wait TF Serving to load the model

    :param grpc_proxy_client: proxy client to make rpc call to TF Serving
    :param max_seconds: max number of seconds to wait
    """

    for i in range(max_seconds):
        try:
            grpc_proxy_client.cache_prediction_metadata()

            logger.info("TF Serving model successfully loaded")
            return
        except AbortionError as abort_err:
            if abort_err.code == StatusCode.UNAVAILABLE:
                _handle_rpc_exception(abort_err)
        # GRPC throws a _Rendezvous, which inherits from RpcError
        # _Rendezvous has a method for code instead of a parameter.
        # https://github.com/grpc/grpc/issues/9270
        except RpcError as rpc_error:
            if rpc_error.code() == StatusCode.UNAVAILABLE:
                _handle_rpc_exception(rpc_error)

    message = 'TF Serving failed to load the model under the maximum load time in seconds: {}'
    raise ValueError(message.format(max_seconds))
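The snippet above depends on a few module-level names that are not shown. A minimal sketch of them, assuming the legacy gRPC beta framework for AbortionError (the _handle_rpc_exception helper it calls appears in Example No. 21 below):

import logging

from grpc import RpcError, StatusCode
# Path assumed: AbortionError comes from the legacy gRPC beta framework API.
from grpc.framework.interfaces.face.face import AbortionError

logger = logging.getLogger(__name__)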
Example No. 3
def export_saved_model(checkpoint_dir, model_path, s3=boto3.client('s3')):
    if checkpoint_dir.startswith('s3://'):
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name,
                                          Prefix=prefix)["Contents"]
            saved_model_path_array = [
                x['Key'] for x in contents
                if x['Key'].endswith('saved_model.pb')
            ]

            if len(saved_model_path_array) == 0:
                logger.info(
                    "Failed to download saved model. File does not exist in {}"
                    .format(checkpoint_dir))
                return
        except KeyError as e:
            logger.error(
                "Failed to download saved model. File does not exist in {}".
                format(checkpoint_dir))
            raise e

        saved_model_path = saved_model_path_array[0]

        variables_path = [
            x['Key'] for x in contents if 'variables/variables' in x['Key']
        ]
        variable_names_to_paths = {
            v.split('/').pop(): v
            for v in variables_path
        }

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):]
        saved_model_filename = folders.pop()
        path_to_save_model = os.path.join(model_path, *folders)

        path_to_variables = os.path.join(path_to_save_model, 'variables')

        os.makedirs(path_to_variables)

        target = os.path.join(path_to_save_model, saved_model_filename)
        s3.download_file(bucket_name, saved_model_path, target)
        logger.info("Downloaded saved model at {}".format(target))

        for filename, full_path in variable_names_to_paths.items():
            key = full_path
            target = os.path.join(path_to_variables, filename)
            s3.download_file(bucket_name, key, target)
    else:
        if os.path.exists(checkpoint_dir):
            shutil.copy2(checkpoint_dir, model_path)
        else:
            logger.error(
                "Failed to copy saved model. File does not exist in {}".format(
                    checkpoint_dir))
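A usage sketch for the function above; the S3 URI and local path are placeholders, with /opt/ml/model being the conventional SageMaker model directory:

# Checkpoints written to S3 by a training job (placeholder bucket/prefix).
export_saved_model('s3://my-bucket/my-training-job/checkpoints', '/opt/ml/model')

# Local checkpoints are copied straight into the model directory.
export_saved_model('/tmp/checkpoints', '/opt/ml/model')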
Example No. 5
    def _build_run_config(self):
        valid_runconfig_keys = ['save_summary_steps', 'save_checkpoints_secs', 'save_checkpoints_steps',
                                'keep_checkpoint_max', 'keep_checkpoint_every_n_hours', 'log_step_count_steps']

        runconfig_params = {k: v for k, v in self.customer_params.items() if k in valid_runconfig_keys}

        logger.info('creating RunConfig:')
        logger.info(runconfig_params)

        run_config = tf.estimator.RunConfig(model_dir=self.model_path, **runconfig_params)
        return run_config
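The whitelisted keys above map one-to-one onto tf.estimator.RunConfig arguments, so checkpointing behavior can be tuned through hyperparameters. A minimal sketch of that filtering, assuming TensorFlow 1.x:

import tensorflow as tf

# Hyperparameters a customer might pass; only whitelisted keys reach RunConfig.
customer_params = {
    'save_checkpoints_secs': 300,
    'keep_checkpoint_max': 3,
    'learning_rate': 0.01,  # not a RunConfig key, so it is filtered out
}

valid_runconfig_keys = ['save_summary_steps', 'save_checkpoints_secs', 'save_checkpoints_steps',
                        'keep_checkpoint_max', 'keep_checkpoint_every_n_hours', 'log_step_count_steps']
runconfig_params = {k: v for k, v in customer_params.items() if k in valid_runconfig_keys}

run_config = tf.estimator.RunConfig(model_dir='/opt/ml/model', **runconfig_params)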
Example No. 6
def export_saved_model(checkpoint_dir,
                       model_path,
                       s3=boto3.client(
                           's3', region_name=os.environ.get('AWS_REGION'))):
    if checkpoint_dir.startswith('s3://'):
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name,
                                          Prefix=prefix)["Contents"]
            saved_model_path_array = [
                x['Key'] for x in contents
                if x['Key'].endswith('saved_model.pb')
            ]

            if len(saved_model_path_array) == 0:
                logger.info(
                    "Failed to download saved model. File does not exist in {}"
                    .format(checkpoint_dir))
                return
        except KeyError as e:
            logger.error(
                "Failed to download saved model. File does not exist in {}".
                format(checkpoint_dir))
            raise e
        # Select most recent saved_model.pb
        saved_model_path = saved_model_path_array[-1]
        saved_model_base_path = os.path.dirname(saved_model_path)

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):-1]
        path_to_save_model = os.path.join(model_path, *folders)

        def file_filter(x):
            return x['Key'].startswith(
                saved_model_base_path) and not x['Key'].endswith("/")

        paths_to_copy = [x['Key'] for x in contents if file_filter(x)]

        for key in paths_to_copy:
            target = re.sub(r"^" + saved_model_base_path, path_to_save_model,
                            key)
            _makedirs_for_file(target)
            s3.download_file(bucket_name, key, target)
        logger.info("Downloaded saved model at {}".format(path_to_save_model))
    else:
        if os.path.exists(checkpoint_dir):
            _recursive_copy(checkpoint_dir, model_path)
        else:
            logger.error(
                "Failed to copy saved model. File does not exist in {}".format(
                    checkpoint_dir))
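The example relies on two helpers that are not shown, _makedirs_for_file and _recursive_copy. Plausible sketches of them (assumptions, not the container's actual implementations):

import os
import shutil


def _makedirs_for_file(file_path):
    # Sketch: create the parent directory of file_path if it does not exist yet.
    directory = os.path.dirname(file_path)
    if directory and not os.path.exists(directory):
        os.makedirs(directory)


def _recursive_copy(source, destination):
    # Sketch: copy the whole tree under source into destination, creating
    # intermediate directories as needed.
    for root, _dirs, files in os.walk(source):
        relative_path = os.path.relpath(root, source)
        target_dir = os.path.join(destination, relative_path)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        for name in files:
            shutil.copy2(os.path.join(root, name), os.path.join(target_dir, name))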
def export_saved_model(checkpoint_dir, model_path, s3=boto3.client('s3')):
    if checkpoint_dir.startswith('s3://'):
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)["Contents"]
            saved_model_path_array = [x['Key'] for x in contents if x['Key'].endswith('saved_model.pb')]

            if len(saved_model_path_array) == 0:
                logger.info("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
                return
        except KeyError as e:
            logger.error("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
            raise e
        # Select most recent saved_model.pb
        saved_model_path = saved_model_path_array[-1]

        variables_path = [x['Key'] for x in contents if 'variables/variables' in x['Key']]
        variable_names_to_paths = {v.split('/').pop(): v for v in variables_path}

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):]
        saved_model_filename = folders.pop()
        path_to_save_model = os.path.join(model_path, *folders)

        path_to_variables = os.path.join(path_to_save_model, 'variables')

        os.makedirs(path_to_variables)

        target = os.path.join(path_to_save_model, saved_model_filename)
        s3.download_file(bucket_name, saved_model_path, target)
        logger.info("Downloaded saved model at {}".format(target))

        for filename, full_path in variable_names_to_paths.items():
            key = full_path
            target = os.path.join(path_to_variables, filename)
            s3.download_file(bucket_name, key, target)
    else:
        if os.path.exists(checkpoint_dir):
            _recursive_copy(checkpoint_dir, model_path)
        else:
            logger.error("Failed to copy saved model. File does not exist in {}".format(checkpoint_dir))
Example No. 9
def _wait_model_to_load(grpc_proxy_client, max_seconds):
    """Wait TF Serving to load the model

    :param grpc_proxy_client: proxy client to make rpc call to TF Serving
    :param max_seconds: max number of seconds to wait
    """

    for i in range(max_seconds):
        try:
            grpc_proxy_client.cache_prediction_metadata()

            logger.info("TF Serving model successfully loaded")
            return
        except AbortionError as err:
            if err.code == StatusCode.UNAVAILABLE:
                logger.info("Waiting for TF Serving to load the model")
                time.sleep(1)

    message = 'TF Serving failed to load the model under the maximum load time in seconds: {}'
    raise ValueError(message.format(max_seconds))
Example No. 10
    def cache_prediction_metadata(self):
        channel = implementations.insecure_channel(self.host,
                                                   self.tf_serving_port)
        stub = prediction_service_pb2.beta_create_PredictionService_stub(
            channel)
        request = get_model_metadata_pb2.GetModelMetadataRequest()

        request.model_spec.name = self.model_name
        request.metadata_field.append('signature_def')
        result = stub.GetModelMetadata(request, self.request_timeout)

        _logger.info(
            '---------------------------Model Spec---------------------------')
        _logger.info(json_format.MessageToJson(result))
        _logger.info(
            '----------------------------------------------------------------')

        signature_def = result.metadata['signature_def']
        signature_map = get_model_metadata_pb2.SignatureDefMap()
        signature_map.ParseFromString(signature_def.value)

        serving_default = signature_map.ListFields()[0][1]['serving_default']
        serving_inputs = serving_default.inputs

        self.input_type_map = {
            key: serving_inputs[key].dtype
            for key in serving_inputs.keys()
        }
        self.prediction_type = serving_default.method_name
    def cache_prediction_metadata(self):
        channel = grpc.insecure_channel('{}:{}'.format(self.host,
                                                       self.tf_serving_port),
                                        options=[
                                            ('grpc.max_send_message_length',
                                             MAX_GRPC_MESSAGE_SIZE),
                                            ('grpc.max_receive_message_length',
                                             MAX_GRPC_MESSAGE_SIZE)
                                        ])
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        request = get_model_metadata_pb2.GetModelMetadataRequest()

        request.model_spec.name = self.model_name
        request.metadata_field.append('signature_def')
        result = stub.GetModelMetadata(request, self.request_timeout)

        _logger.info(
            '---------------------------Model Spec---------------------------')
        _logger.info(json_format.MessageToJson(result))
        _logger.info(
            '----------------------------------------------------------------')

        signature_def = result.metadata['signature_def']
        signature_map = get_model_metadata_pb2.SignatureDefMap()
        signature_map.ParseFromString(signature_def.value)

        serving_default = signature_map.ListFields()[0][1]['serving_default']
        serving_inputs = serving_default.inputs

        self.input_type_map = {
            key: serving_inputs[key].dtype
            for key in serving_inputs.keys()
        }
        self.prediction_type = serving_default.method_name
        self.prediction_service_stub = stub
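Once cache_prediction_metadata has run, the stored stub, input_type_map and request_timeout can drive actual predictions. A rough sketch of such a call using the public TF Serving predict API (the helper below is hypothetical, not part of the class above):

import tensorflow as tf
from tensorflow_serving.apis import predict_pb2


def _predict_once(client, input_name, data):
    # Sketch: send a single Predict RPC using the metadata cached above.
    request = predict_pb2.PredictRequest()
    request.model_spec.name = client.model_name
    # The cached dtype is a TensorFlow DataType enum, which make_tensor_proto accepts.
    request.inputs[input_name].CopyFrom(
        tf.make_tensor_proto(data, dtype=client.input_type_map[input_name]))
    return client.prediction_service_stub.Predict(request, client.request_timeout)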
Example No. 12
    def _build_estimator(self, run_config):
        hyperparameters = self.customer_params

        if hasattr(self.customer_script, 'estimator_fn'):
            logger.info('invoking the user-provided estimator_fn')
            return self.customer_script.estimator_fn(run_config, hyperparameters)
        elif hasattr(self.customer_script, 'keras_model_fn'):
            logger.info('invoking the user-provided keras_model_fn')
            model = self.customer_script.keras_model_fn(hyperparameters)
            return tf.keras.estimator.model_to_estimator(keras_model=model, config=run_config)
        else:
            logger.info('creating an estimator from the user-provided model_fn')

            # We must wrap the model_fn from customer_script like this to maintain compatibility with our
            # existing behavior, which passes arguments to the customer model_fn positionally, not by name.
            # The TensorFlow Estimator checks the signature of the given model_fn for a parameter named "params":
            # https://github.com/tensorflow/tensorflow/blob/2c9a67ffb384a13cd533a0e89a96211058fa2631/tensorflow/python/estimator/estimator.py#L1215
            # Wrapping it in _model_fn allows the customer to use whatever parameter names they want. It's unclear whether
            # this behavior is desirable theoretically, but we shouldn't break existing behavior.

            def _model_fn(features, labels, mode, params):
                return self.customer_script.model_fn(features, labels, mode, params)

            return tf.estimator.Estimator(
                model_fn=_model_fn,
                params=hyperparameters,
                config=run_config)
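A minimal user script that the dispatcher above accepts through its model_fn branch might look like the following (the toy network and parameter names are illustrative, and only the TRAIN/EVAL modes are handled):

import tensorflow as tf


def model_fn(features, labels, mode, hyperparameters):
    # Toy linear model over a single feature named 'x'; the wrapper shown above
    # lets this function use any parameter names it likes.
    predictions = tf.squeeze(tf.layers.dense(features['x'], units=1), axis=-1)
    loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
    optimizer = tf.train.GradientDescentOptimizer(hyperparameters.get('learning_rate', 0.01))
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)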
def export_saved_model(checkpoint_dir, model_path, s3=boto3.client('s3', region_name=os.environ.get('AWS_REGION'))):
    if checkpoint_dir.startswith('s3://'):
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)["Contents"]
            saved_model_path_array = [x['Key'] for x in contents if x['Key'].endswith('saved_model.pb')]

            if len(saved_model_path_array) == 0:
                logger.info("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
                return
        except KeyError as e:
            logger.error("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
            raise e
        # Select most recent saved_model.pb
        saved_model_path = saved_model_path_array[-1]
        saved_model_base_path = os.path.dirname(saved_model_path)

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):-1]
        path_to_save_model = os.path.join(model_path, *folders)

        def file_filter(x): return x['Key'].startswith(saved_model_base_path) and not x['Key'].endswith("/")
        paths_to_copy = [x['Key'] for x in contents if file_filter(x)]

        for key in paths_to_copy:
            target = re.sub(r"^"+saved_model_base_path, path_to_save_model, key)
            _makedirs_for_file(target)
            s3.download_file(bucket_name, key, target)
        logger.info("Downloaded saved model at {}".format(path_to_save_model))
    else:
        if os.path.exists(checkpoint_dir):
            _recursive_copy(checkpoint_dir, model_path)
        else:
            logger.error("Failed to copy saved model. File does not exist in {}".format(checkpoint_dir))
Example No. 14
    def _build_estimator(self, run_config):
        hyperparameters = self.customer_params

        if hasattr(self.customer_script, 'estimator_fn'):
            logger.info('invoking the user-provided estimator_fn')
            return self.customer_script.estimator_fn(run_config, hyperparameters)
        elif hasattr(self.customer_script, 'keras_model_fn'):
            logger.info('invoking the user-provided keras_model_fn')
            model = self.customer_script.keras_model_fn(hyperparameters)
            return tf.keras.estimator.model_to_estimator(keras_model=model, config=run_config)
        else:
            logger.info('creating an estimator from the user-provided model_fn')

            return tf.estimator.Estimator(
                model_fn=self.customer_script.model_fn,
                params=hyperparameters,
                config=run_config)
    def _build_estimator(self, run_config, hparams):
        # hparams is an HParams object at this point, but all the interface functions assume a plain dict
        hyperparameters = hparams.values()

        if hasattr(self.customer_script, 'estimator_fn'):
            logger.info("invoking estimator_fn")
            return self.customer_script.estimator_fn(run_config, hyperparameters)
        elif hasattr(self.customer_script, 'keras_model_fn'):
            logger.info("involing keras_model_fn")
            model = self.customer_script.keras_model_fn(hyperparameters)
            return tf.keras.estimator.model_to_estimator(keras_model=model, config=run_config)
        else:
            logger.info("creating the estimator")

            def _model_fn(features, labels, mode, params):
                return self.customer_script.model_fn(features, labels, mode, params)

            return tf.estimator.Estimator(
                model_fn=_model_fn,
                params=hyperparameters,
                config=run_config)
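The comment above presumably refers to tf.contrib.training.HParams, whose values() method converts the object back into the plain dict the user-facing functions expect. A quick illustration with the TensorFlow 1.x contrib API:

import tensorflow as tf

hparams = tf.contrib.training.HParams(learning_rate=0.1, batch_size=128)
hyperparameters = hparams.values()
# hyperparameters is now a plain dict: {'learning_rate': 0.1, 'batch_size': 128}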
    def cache_prediction_metadata(self):
        channel = implementations.insecure_channel(self.host, self.tf_serving_port)
        stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
        request = get_model_metadata_pb2.GetModelMetadataRequest()

        request.model_spec.name = self.model_name
        request.metadata_field.append('signature_def')
        result = stub.GetModelMetadata(request, self.request_timeout)

        _logger.info('---------------------------Model Spec---------------------------')
        _logger.info(json_format.MessageToJson(result))
        _logger.info('----------------------------------------------------------------')

        signature_def = result.metadata['signature_def']
        signature_map = get_model_metadata_pb2.SignatureDefMap()
        signature_map.ParseFromString(signature_def.value)

        serving_default = signature_map.ListFields()[0][1]['serving_default']
        serving_inputs = serving_default.inputs

        self.input_type_map = {key: serving_inputs[key].dtype for key in serving_inputs.keys()}
        self.prediction_type = serving_default.method_name
        def _experiment_fn(run_config, hparams):
            valid_experiment_keys = ['eval_metrics', 'train_monitors', 'eval_hooks', 'local_eval_frequency',
                                     'eval_delay_secs', 'continuous_eval_throttle_secs', 'min_eval_frequency',
                                     'delay_workers_by_global_step', 'train_steps_per_iteration']

            experiment_params = {k: v for k, v in self.customer_params.items() if k in valid_experiment_keys}

            logger.info("creating Experiment:")
            logger.info(experiment_params)

            '''
            TensorFlow input functions (train_input_fn and eval_input_fn) can return features and
            labels, or a function that returns features and labels
            Examples of valid input functions:

                def train_input_fn(training_dir, hyperparameters):
                    ...
                    return tf.estimator.inputs.numpy_input_fn(x={"x": train_data}, y=train_labels)

                def train_input_fn(training_dir, hyperparameters):
                    ...
                    return features, labels
            '''
            def _train_input_fn():
                """Prepare parameters for the train_input_fn and invoke it"""
                declared_args = inspect.getargspec(self.customer_script.train_input_fn)
                invoke_args = {arg: self._resolve_value_for_training_input_fn_parameter(arg)
                               for arg in declared_args.args}
                return _function(self.customer_script.train_input_fn(**invoke_args))()

            def _eval_input_fn():
                declared_args = inspect.getargspec(self.customer_script.eval_input_fn)
                invoke_args = {arg: self._resolve_value_for_training_input_fn_parameter(arg)
                               for arg in declared_args.args}
                return _function(self.customer_script.eval_input_fn(**invoke_args))()

            '''
            TensorFlow serving input functions (serving_input_fn) can return a ServingInputReceiver object or a
            function that returns a ServingInputReceiver
            Examples of valid serving input functions:

                def serving_input_fn(params):
                    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}
                    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

                def serving_input_fn(hyperparameters):
                    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
                    return tf.estimator.export.ServingInputReceiver(inputs, inputs)
            '''
            def _serving_input_fn():
                return _function(self.customer_script.serving_input_fn(self.customer_params))()

            def _export_strategy():
                if self.saves_training():
                    return [saved_model_export_utils.make_export_strategy(
                        serving_input_fn=_serving_input_fn,
                        default_output_alternative_key=None,
                        exports_to_keep=1)]
                logger.warn("serving_input_fn not specified, model NOT saved, use checkpoints to reconstruct")
                return None

            return Experiment(
                estimator=self._build_estimator(run_config=run_config, hparams=hparams),
                train_input_fn=_train_input_fn,
                eval_input_fn=_eval_input_fn,
                export_strategies=_export_strategy(),
                train_steps=self.train_steps,
                eval_steps=self.eval_steps,
                **experiment_params
            )
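The input-function wrappers above call a _function helper that is not shown; judging by the docstring, its job is to accept either a plain return value or a callable and always hand back a callable. A plausible sketch (an assumption, not the container's actual code):

def _function(maybe_fn):
    # If the user's input function already returned a callable, use it directly;
    # otherwise wrap the returned value (e.g. a (features, labels) tuple) in a lambda.
    if callable(maybe_fn):
        return maybe_fn
    return lambda: maybe_fn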
Example No. 19
        def _experiment_fn(run_config, hparams):
            valid_experiment_keys = [
                'eval_metrics', 'train_monitors', 'eval_hooks',
                'local_eval_frequency', 'eval_delay_secs',
                'continuous_eval_throttle_secs', 'min_eval_frequency',
                'delay_workers_by_global_step', 'train_steps_per_iteration'
            ]

            experiment_params = {
                k: v
                for k, v in self.customer_params.items()
                if k in valid_experiment_keys
            }

            logger.info("creating Experiment:")
            logger.info(experiment_params)
            '''
            TensorFlow input functions (train_input_fn and eval_input_fn) can return features and
            labels, or a function that returns features and labels
            Examples of valid input functions:

                def train_input_fn(training_dir, hyperparameters):
                    ...
                    return tf.estimator.inputs.numpy_input_fn(x={"x": train_data}, y=train_labels)

                def train_input_fn(training_dir, hyperparameters):
                    ...
                    return features, labels
            '''
            def _train_input_fn():
                """Prepare parameters for the train_input_fn and invoke it"""
                declared_args = inspect.getargspec(
                    self.customer_script.train_input_fn)
                invoke_args = {
                    arg:
                    self._resolve_value_for_training_input_fn_parameter(arg)
                    for arg in declared_args.args
                }
                return _function(
                    self.customer_script.train_input_fn(**invoke_args))()

            def _eval_input_fn():
                declared_args = inspect.getargspec(
                    self.customer_script.eval_input_fn)
                invoke_args = {
                    arg:
                    self._resolve_value_for_training_input_fn_parameter(arg)
                    for arg in declared_args.args
                }
                return _function(
                    self.customer_script.eval_input_fn(**invoke_args))()

            '''
            TensorFlow serving input functions (serving_input_fn) can return a ServingInputReceiver object or a
            function that returns a ServingInputReceiver
            Examples of valid serving input functions:

                def serving_input_fn(params):
                    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}
                    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

                def serving_input_fn(hyperparameters):
                    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
                    return tf.estimator.export.ServingInputReceiver(inputs, inputs)
            '''

            def _serving_input_fn():
                return _function(
                    self.customer_script.serving_input_fn(
                        self.customer_params))()

            def _export_strategy():
                if self.saves_training():
                    return [
                        saved_model_export_utils.make_export_strategy(
                            serving_input_fn=_serving_input_fn,
                            default_output_alternative_key=None,
                            exports_to_keep=1)
                    ]
                logger.warn(
                    "serving_input_fn not specified, model NOT saved, use checkpoints to reconstruct"
                )
                return None

            return Experiment(estimator=self._build_estimator(
                run_config=run_config, hparams=hparams),
                              train_input_fn=_train_input_fn,
                              eval_input_fn=_eval_input_fn,
                              export_strategies=_export_strategy(),
                              train_steps=self.train_steps,
                              eval_steps=self.eval_steps,
                              **experiment_params)
def _handle_rpc_exception(err):
    logger.info("Waiting for TF Serving to load the model due to {}"
                .format(err.__class__.__name__))
    time.sleep(1)
Example No. 21
def _handle_rpc_exception(err):
    logger.info("Waiting for TF Serving to load the model due to {}".format(
        err.__class__.__name__))
    time.sleep(1)
Example No. 22
    def train(self):

        if "tunning" not in self.customer_params:
            super(TrainerBayesOptimizer, self).train()
            return self.model_path
        exploratoryParams = self.customer_params['tunning']
        self.model_path_base = self.model_path

        def addParamsTrain(**params):
            self.customer_params.update(params)
            self.model_path = os.path.join(self.model_path_base,
                                           self.params2Path(params))
            estimator = super(TrainerBayesOptimizer, self).train()
            invoke_args = self._resolve_input_fn_args(
                self.customer_script.eval_input_fn)
            res = estimator.evaluate(
                lambda: self.customer_script.eval_input_fn(**invoke_args))
            return res['accuracy']

        nnBO = BayesianOptimization(addParamsTrain, exploratoryParams)
        # do a first exploratory pass with many init points and a large kappa
        exploratory_tunning = self.customer_params[
            'exploratory_tunning'] if 'exploratory_tunning' in self.customer_params else 6
        nnBO.maximize(init_points=2,
                      n_iter=exploratory_tunning,
                      kappa=5,
                      acq='ei')
        # fine-tune with a small kappa
        fine_tunning = self.customer_params[
            'fine_tunning'] if 'fine_tunning' in self.customer_params else 6
        nnBO.maximize(init_points=0, n_iter=fine_tunning, kappa=2, acq='ei')

        logger.info("-------------------")
        logger.info("model results")
        for params, results in zip(nnBO.res['all']['params'],
                                   nnBO.res['all']['values']):
            logger.info("%s : %f" % (str(params), results))
        logger.info("-------------------")
        logger.info("best model")
        logger.info(nnBO.res['max']['max_params'])
        logger.info(nnBO.res['max']['max_val'])
        logger.info("-------------------")
        best_model_path = os.path.join(
            self.model_path_base,
            self.params2Path(nnBO.res['max']['max_params']))
        logger.info("best_model_path=%s" % best_model_path)
        return best_model_path
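For illustration, the hyperparameters driving this trainer might look as follows: the 'tunning' entry maps each tunable parameter to the (lower, upper) bounds expected by BayesianOptimization, while the two optional keys control the length of the broad and narrow search phases (all values are placeholders):

customer_params = {
    'tunning': {
        # parameter name -> (lower bound, upper bound) explored by the optimizer
        'learning_rate': (1e-4, 1e-1),
        'dropout_rate': (0.1, 0.5),
    },
    'exploratory_tunning': 6,  # iterations of the broad search (kappa=5)
    'fine_tunning': 6,         # iterations of the fine-tuning search (kappa=2)
}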