Ejemplo n.º 1
0
def GenerateModelV2(tf_saved_model_dir, tftrt_saved_model_dir):
    """Build a small TF2 model, export it, then convert it with TF-TRT.

    The unconverted model is saved to `tf_saved_model_dir`. It is then loaded
    by `TrtGraphConverterV2`, converted, summarized, and the converted model
    is saved to `tftrt_saved_model_dir`.
    """
    # Signature of the traced function: two float32 inputs of shape
    # [None, 1, 1].
    run_signature = [
        tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
    ]

    class SimpleModel(tracking.AutoTrackable):
        """Trackable object exposing a single TF function, `run`."""

        def __init__(self):
            # The variable is created lazily, on the first call of `run`.
            self.v = None

        @def_function.function(input_signature=run_signature)
        def run(self, input1, input2):
            if self.v is None:
                self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
            return GetGraph(input1, input2, self.v)

    model = SimpleModel()

    # Export the plain TF model, with `run` as the default serving signature.
    serving_signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: model.run,
    }
    # pylint: disable=not-callable
    save(model, tf_saved_model_dir, serving_signatures)

    # Convert the exported model to TensorRT.
    trt_converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=tf_saved_model_dir)
    trt_converter.convert()

    # Use the terminal width for the summary when it is wider than 160
    # columns; os.get_terminal_size() raises OSError when no terminal is
    # attached, in which case the 160-column minimum applies.
    try:
        terminal_width = os.get_terminal_size().columns
    except OSError:
        terminal_width = 0
    trt_converter.summary(
        line_length=max(160, terminal_width), detailed=True)

    trt_converter.save(tftrt_saved_model_dir)
Ejemplo n.º 2
0
def GenerateModelV2(tf_saved_model_dir, tftrt_saved_model_dir):
    """Build a small TF2 model, export it, then convert it with TF-TRT.

    The unconverted model is saved to `tf_saved_model_dir`; the model
    produced by `TrtGraphConverterV2` is saved to `tftrt_saved_model_dir`.
    """
    # Signature of the traced function: two float32 inputs of shape
    # [None, 1, 1].
    run_signature = [
        tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
    ]

    class SimpleModel(tracking.AutoTrackable):
        """Trackable object exposing a single TF function, `run`."""

        def __init__(self):
            # The variable is created lazily, on the first call of `run`.
            self.v = None

        @def_function.function(input_signature=run_signature)
        def run(self, input1, input2):
            if self.v is None:
                self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
            return GetGraph(input1, input2, self.v)

    model = SimpleModel()

    # Export the plain TF model, with `run` as the default serving signature.
    serving_signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: model.run,
    }
    save(model, tf_saved_model_dir, serving_signatures)

    # Convert the exported model to TensorRT and write out the result.
    trt_converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=tf_saved_model_dir)
    trt_converter.convert()
    trt_converter.save(tftrt_saved_model_dir)
Ejemplo n.º 3
0
def GenerateModelWithReadVariableOp(tf_saved_model_dir):
    """Export a `MyModel` instance to `tf_saved_model_dir`.

    The model's `__call__` is traced for two float32 inputs of shape
    [None, 1, 1] and the resulting concrete function is used as the saved
    signature. Going by this function's name, the exported graph is meant
    to contain ReadVariableOp nodes; `MyModel` is defined elsewhere.
    """
    model = MyModel()

    # One spec per positional input of `__call__`.
    input_specs = (
        tensor_spec.TensorSpec([None, 1, 1], dtypes.float32),
        tensor_spec.TensorSpec([None, 1, 1], dtypes.float32),
    )
    concrete_call = model.__call__.get_concrete_function(*input_specs)

    # pylint: disable=not-callable
    save(model, tf_saved_model_dir, signatures=concrete_call)
Ejemplo n.º 4
0
 def on_epoch_end(self, epoch, logs=None):
     """Save the model at the end of an epoch if the monitored value exists.

     The value stored in `logs` under `self.monitor` is only checked for
     presence; whenever it is present the model is saved to the path
     returned by `self._get_file_path(epoch, logs)`.

     Args:
         epoch: Epoch index, forwarded to `self._get_file_path`.
         logs: Mapping of metric names to values for this epoch. `None` is
             treated as an empty mapping.

     Raises:
         AttributeError: If `logs` has no value for `self.monitor`.
     """
     # Treat a missing logs mapping as empty so the check below reports the
     # real problem instead of failing on `None.get`.
     if logs is None:
         logs = {}

     if logs.get(self.monitor) is None:
         # Interpolate the monitor name explicitly: exceptions, unlike
         # logging calls, do not apply %-formatting to extra arguments.
         raise AttributeError(
             'Can save best model only with %s available, '
             'skipping.' % (self.monitor,))

     filepath = self._get_file_path(epoch, logs)
     save(self.model, filepath)