def test_getitem_index_real_tensor(self):
    """Index a tensor built with `math_ops.range` by a value from a Keras input.

    The model's only input supplies the index; the test checks that both a
    direct call and `predict` return the element found by plain indexing.
    """
    if not context.executing_eagerly():
      self.skipTest('Complex slicing like this fails in v1')

    batch_size = 7
    index = 6

    values = math_ops.range(10.0)
    index_input = keras.Input(shape=(), dtype='int32')
    picked = values[index_input[0]]

    model = keras.Model(inputs=index_input, outputs=picked)
    model.compile(
        adam.Adam(0.001), 'mse',
        run_eagerly=testing_utils.should_run_eagerly())

    args = constant_op.constant(index, shape=(batch_size,))
    expected = values[index]

    def check_outputs(candidate):
      # The direct call and `predict` must agree with plain indexing.
      self.assertAllEqual(candidate(args), expected)
      self.assertAllEqual(
          candidate.predict(args, batch_size=batch_size), expected)

    if keras_tensor.keras_tensors_enabled():
      layer_names = [layer.name for layer in model.layers]
      self.assertIn('tf.__operators__.getitem', layer_names)
      # TODO(b/161925288): Fix the bug then uncomment:
      # self.assertNotIn('tf.strided_slice', layer_names)
    check_outputs(model)

    # TODO(b/161925288): Fix the bug then uncomment:
    # # Make sure it can be successfully saved and loaded
    # config = model.get_config()
    # model = keras.Model.from_config(config)

    # Until the round trip above is enabled this re-checks the same model.
    check_outputs(model)
  def test_getitem_slice_with_stop_only(self):
    """Slice one Keras input with a stop bound read from another Keras input.

    Checks the layer names recorded for the slicing op, the values returned
    by a direct call and by `predict`, and that the same values come back
    after rebuilding the model from its config.
    """
    if not context.executing_eagerly():
      self.skipTest('Complex slicing like this fails in v1')

    batch_size = 7
    stop = 6

    data_input = keras.Input(shape=(4, 3, 8))
    stop_input = keras.Input(shape=(), dtype='int32')
    sliced = data_input[:stop_input[0]]

    model = keras.Model(inputs=[data_input, stop_input], outputs=sliced)
    model.compile(
        adam.Adam(0.001), 'mse',
        run_eagerly=testing_utils.should_run_eagerly())

    data = array_ops.stack([math_ops.range(8) for _ in range(batch_size)])
    args = [data, constant_op.constant(stop, shape=(batch_size,))]
    expected = data[:stop]

    def check_outputs(candidate):
      # The direct call and `predict` must agree with plain slicing.
      self.assertAllEqual(candidate(args), expected)
      self.assertAllEqual(
          candidate.predict(args, batch_size=batch_size), expected)

    if keras_tensor.keras_tensors_enabled():
      layer_names = [layer.name for layer in model.layers]
      self.assertIn('tf.__operators__.getitem', layer_names)
      self.assertNotIn('tf.strided_slice', layer_names)
    check_outputs(model)

    # Make sure it can be successfully saved and loaded
    restored = keras.Model.from_config(model.get_config())
    check_outputs(restored)
def _int32_manipulation_at_max_shape_dims_limit():
  """Build a model that manipulates an int32 tensor at the value-inference limit.

  The range tensor created here has exactly `keras_tensor._MAX_TENSOR_DIMS`
  elements, the largest size Keras will try to infer values for. The point
  is that the Functional API must not crash on it.

  Returns:
    A `keras.Model`. Under eager execution it maps the input to a Dense
    layer applied to the reshaped range; otherwise it maps the input to
    itself.
  """
  max_dims = keras_tensor._MAX_TENSOR_DIMS

  inputs = keras.Input(batch_size=2, shape=(10,))
  batch_size = array_ops.shape(inputs)[0]
  num_features = int(max_dims / int(inputs.shape[0]))
  flat_range = math_ops.range(batch_size * num_features, dtype='int32')
  assert flat_range.shape.as_list() == [max_dims]

  # Confirm a value really was inferred for a tensor that *might* represent
  # a shape: the last value of the range has to show up in the printed
  # inferred value.
  if keras_tensor.keras_tensors_enabled():
    assert str(max_dims - 1) in str(flat_range)

  features = array_ops.reshape(flat_range, (batch_size, num_features))
  features = math_ops.cast(features, dtype='float32')
  outputs = keras.layers.Dense(10)(features)

  if not context.executing_eagerly():
    # In V1 the op layer fails for some reason, and this util method has no
    # test case on which to call self.skip_test, so return an identity model.
    return keras.Model(inputs, inputs)
  return keras.Model(inputs, outputs)
    def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None):
        """Record one call of `layer` and link it into the layer graph.

        Args:
            layer: the layer this node belongs to; the node is appended to
                its `_inbound_nodes`.
            call_args: positional arguments of the call (defaults to `[]`).
            call_kwargs: keyword arguments of the call (defaults to `{}`).
            outputs: output structure of the call (defaults to `[]`); every
                flattened element gets a `_keras_history` attribute.
        """
        if call_args is None:
            call_args = []
        if call_kwargs is None:
            call_kwargs = {}
        if outputs is None:
            outputs = []

        self.layer = layer
        # A node without any call arguments is treated as an input node.
        self.is_input = not (call_args or call_kwargs)

        def _same(value):
            return value

        # The arguments come from the user, so rebuild the containers with
        # map_structure: later edits by the user then cannot change this
        # node's metadata. A shallow copy would miss nested containers, and
        # the leaves may not support (or be too costly for) copy/deepcopy.
        self.call_args = nest.map_structure(_same, call_args)
        self.call_kwargs = nest.map_structure(_same, call_kwargs)
        self.outputs = nest.map_structure(_same, outputs)

        # Cached for performance.
        self._flat_arguments = nest.flatten(
            (self.call_args, self.call_kwargs))
        # Lets the most common case skip expensive `nest` operations.
        self._single_positional_tensor_passed = (
            not self.call_kwargs
            and len(self.call_args) == 1
            and tensor_util.is_tensor(self.call_args[0]))

        if not keras_tensor.keras_tensors_enabled():
            # Create TensorFlowOpLayers if needed.
            for argument in self._flat_arguments:
                if not isinstance(argument, ops.Tensor):
                    continue
                if base_layer_utils.needs_keras_history(
                        argument, ignore_call_context=True):
                    base_layer_utils.create_keras_history(argument)

        # Collect the Keras tensors among the flattened arguments, together
        # with their ids and their positions in the flat list.
        self._keras_inputs = []
        self._keras_inputs_ids_and_indices = []
        for position, argument in enumerate(self._flat_arguments):
            if not is_keras_tensor(argument):
                continue
            self._keras_inputs.append(argument)
            self._keras_inputs_ids_and_indices.append(
                (str(id(argument)), position))

        # Wire up Node to Layers.
        self.layer._inbound_nodes.append(self)
        for keras_input in self.keras_inputs:
            producer = keras_input._keras_history.layer
            if producer is None:  # `None` for `Input` tensors.
                continue
            producer._outbound_nodes.append(self)

        # Set metadata on outputs. The node was just appended, so its index
        # is the last position in the layer's inbound node list.
        node_index = len(self.layer._inbound_nodes) - 1
        for tensor_index, tensor in enumerate(nest.flatten(outputs)):
            tensor._keras_history = KerasHistory(
                layer=layer, node_index=node_index, tensor_index=tensor_index)

        # Cached for performance.
        self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
        self.flat_output_ids = [
            str(id(t)) for t in nest.flatten(self.outputs)
        ]
Example #5
0
def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
    """Freeze a Keras model into a GraphDef and pass it to `graph_def_session`.

    Args:
        model: a `tf.keras.Model`, or a value accepted by
            `tf.keras.models.load_model` (it is loaded when it is not
            already a model instance).
        input_tensor_names: not read by this function; input names are taken
            from the `Placeholder` nodes of the frozen graph.
        output_tensor_names: not read by this function; output names are
            derived from `model.output_names`.
        **kwargs: NOTE(review): this name is rebound below to the model's
            input mapping, so caller-supplied keyword arguments are dropped
            and the input mapping is what reaches `graph_def_session`.
            Confirm this is intended.

    Returns:
        Whatever `graph_def_session` returns for the frozen graph.

    Raises:
        AssertionError: if the installed TensorFlow is not newer than 2.1.0.
    """
    import re

    def _release(version):
        # Compare releases numerically. Comparing the raw strings is
        # lexicographic and ranks '2.10.0' below '2.2.0'.
        return tuple(int(part) for part in re.findall(r'\d+', version)[:3])

    tf_release = _release(tf.version.VERSION)
    assert tf_release > (2, 1, 0), 'keras model need tensorflow version > 2.1.0....'
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
    if not isinstance(model, tf.keras.Model):
        model = tf.keras.models.load_model(model)
    kwargs = dict(zip(model.input_names, model.inputs))
    if tf_release > (2, 2, 0):
        from tensorflow.python.keras.engine import keras_tensor
        if keras_tensor.keras_tensors_enabled():
            # Trace with the type specs of the inputs, not the inputs
            # themselves.
            kwargs = {
                name: tensor.type_spec for name, tensor in kwargs.items()
            }
    full_model = tf.function(lambda **inputs: model(inputs.values()))
    concrete_function = full_model.get_concrete_function(**kwargs)
    frozen_model = convert_variables_to_constants_v2(concrete_function)
    graph_def = frozen_model.graph.as_graph_def()
    input_names = [
        node.name for node in graph_def.node if node.op == 'Placeholder'
    ]
    output_names = [output.split(':')[0] for output in model.output_names]
    # Rename the last Identity node fed by each output so that the graph
    # exposes the model's own output names.
    for output_name in output_names:
        for node in graph_def.node[::-1]:
            if node.op == 'Identity' and output_name in node.input[0]:
                node.name = output_name
                break

    return graph_def_session(graph_def, input_names, output_names, **kwargs)
Example #6
0
    def test_getitem_complex_slicing(self):
        """Slice a Keras input with an ellipsis, an index and a strided range.

        The index and the start/stop/stride bounds all come from further
        Keras inputs. Checks the recorded layer names, the values from a
        direct call and from `predict`, and the same values after rebuilding
        the model from its config.
        """
        if not context.executing_eagerly():
            self.skipTest('Complex slicing like this fails in v1')

        batch_size = 7
        start = 1
        stop = 6
        step = 2

        inp = keras.Input(shape=(4, 3, 8))
        first_dim = keras.Input(shape=(), dtype='int32')
        slice_start = keras.Input(shape=(), dtype='int32')
        slice_stop = keras.Input(shape=(), dtype='int32')
        slice_stride = keras.Input(shape=(), dtype='int32')

        out = inp[..., first_dim[0],
                  slice_start[0]:slice_stop[0]:slice_stride[0]]
        model = keras.Model(
            inputs=[inp, first_dim, slice_start, slice_stop, slice_stride],
            outputs=out)
        model.compile(
            adam.Adam(0.001),
            'mse',
            run_eagerly=testing_utils.should_run_eagerly())

        def stacked(factory, copies):
            # Stack `copies` freshly built tensors along a new leading axis.
            return array_ops.stack([factory() for _ in range(copies)])

        def plane():
            return stacked(lambda: math_ops.range(8), 3)

        def volume():
            return stacked(plane, 4)

        data = stacked(volume, batch_size)
        scalars = [
            constant_op.constant(value, shape=(batch_size, ))
            for value in (0, start, stop, step)
        ]
        args = [data] + scalars

        # Slice the innermost dim. Only grab one index from the
        # second-to-innermost dim, removing that dim from the shape.
        def expected_rows():
            return stacked(lambda: math_ops.range(8)[start:stop:step], 4)

        expected = stacked(expected_rows, batch_size)

        def check_outputs(candidate):
            # The direct call and `predict` must both give the expected slice.
            self.assertAllEqual(candidate(args), expected)
            self.assertAllEqual(
                candidate.predict(args, batch_size=batch_size), expected)

        if keras_tensor.keras_tensors_enabled():
            layer_names = [layer.name for layer in model.layers]
            self.assertIn('tf.__operators__.getitem', layer_names)
            self.assertNotIn('tf.strided_slice', layer_names)
        check_outputs(model)

        # Make sure it can be successfully saved and loaded
        restored = keras.Model.from_config(model.get_config())
        check_outputs(restored)
Example #7
0
 def testBody(self):
     """Append (mode, run-eagerly flag, KerasTensor flag) to the outer list `l`."""
     if context.executing_eagerly():
         mode = "eager"
     else:
         mode = "graph"
     run_eagerly = testing_utils.should_run_eagerly()
     kt_enabled = keras_tensor.keras_tensors_enabled()
     l.append((mode, run_eagerly, kt_enabled))
Example #8
0
    def __init__(self,
                 input_shape=None,
                 batch_size=None,
                 dtype=None,
                 input_tensor=None,
                 sparse=None,
                 name=None,
                 ragged=None,
                 type_spec=None,
                 **kwargs):
        """Set up the input layer and register its input node.

        The input tensor comes from one of three places: it is built from
        `type_spec`, created as a new backend placeholder, or taken from the
        caller-supplied `input_tensor`.

        Args:
            input_shape: shape without the batch dimension. A `TensorShape`
                is converted to a tuple and a bare int to a 1-tuple.
            batch_size: optional batch size. When the current distribution
                strategy supports global batch sizes it must be divisible by
                the replica count and is divided by it.
            dtype: dtype of the input. When falsy it defaults to
                `backend.floatx()`, or to the dtype of `input_tensor` if one
                is given.
            input_tensor: optional existing tensor to use as the input.
            sparse: whether the placeholder to create is sparse.
            name: layer name. When falsy one is generated as `input_<uid>`.
            ragged: whether the placeholder to create is ragged.
            type_spec: optional type spec to build the input tensor from.
                The other tensor-describing arguments are each passed to
                `_assert_other_arg_none` when it is given.
            **kwargs: only `batch_input_shape` is accepted; it is split into
                `batch_size` and `input_shape`.

        Raises:
            ValueError: if `batch_size` is not divisible by the replica
                count, if both `input_shape` and `batch_input_shape` are
                given, if unknown keyword arguments remain, if both `sparse`
                and `ragged` are truthy, if `dtype` differs from
                `input_tensor.dtype`, if `type_spec` is given while Keras
                tensors are disabled, or if a non-symbolic `input_tensor` is
                given while Keras tensors are disabled.
        """
        # Keep the arguments exactly as passed. The type_spec branch below
        # checks these originals, not the values adjusted later on.
        self._init_input_shape = input_shape
        self._init_batch_size = batch_size
        self._init_dtype = dtype
        self._init_sparse = sparse
        self._init_ragged = ragged
        self._init_type_spec = type_spec

        # Turn the supplied batch size into a per-replica batch size when
        # the strategy supports global batch sizes.
        strategy = distribution_strategy_context.get_strategy()
        if strategy and batch_size is not None and \
            distributed_training_utils.global_batch_size_supported(strategy):
            if batch_size % strategy.num_replicas_in_sync != 0:
                raise ValueError(
                    'The `batch_size` argument ({}) must be divisible by '
                    'the number of replicas ({})'.format(
                        batch_size, strategy.num_replicas_in_sync))
            batch_size = batch_size // strategy.num_replicas_in_sync

        # `batch_input_shape` replaces both `batch_size` and `input_shape`.
        if 'batch_input_shape' in kwargs:
            batch_input_shape = kwargs.pop('batch_input_shape')
            if input_shape and batch_input_shape:
                raise ValueError('Only provide the input_shape OR '
                                 'batch_input_shape argument to '
                                 'InputLayer, not both at the same time.')
            batch_size = batch_input_shape[0]
            input_shape = batch_input_shape[1:]
        if kwargs:
            raise ValueError('Unrecognized keyword arguments:', kwargs.keys())

        if sparse and ragged:
            raise ValueError(
                'Cannot set both sparse and ragged to True in a Keras input.')

        if not name:
            prefix = 'input'
            name = prefix + '_' + str(backend.get_uid(prefix))

        # Resolve the dtype, or check it against the supplied tensor.
        if not dtype:
            if input_tensor is None:
                dtype = backend.floatx()
            else:
                dtype = backend.dtype(input_tensor)
        elif input_tensor is not None and input_tensor.dtype != dtype:
            raise ValueError(
                '`input_tensor.dtype` differs from `dtype`: %s vs. %s' %
                (input_tensor.dtype, dtype))
        super(InputLayer, self).__init__(dtype=dtype, name=name)
        self.built = True
        self.sparse = True if sparse else False
        self.ragged = True if ragged else False
        self.batch_size = batch_size
        self.supports_masking = True

        # Normalize `input_shape` to a tuple where possible.
        if isinstance(input_shape, tensor_shape.TensorShape):
            input_shape = tuple(input_shape.as_list())
        elif isinstance(input_shape, int):
            input_shape = (input_shape, )

        if type_spec is not None:
            # Case 1: build the input tensor from the type spec.
            args_that_must_be_none = [
                ('(input_)shape', self._init_input_shape),
                ('batch_size', self._init_batch_size),
                ('dtype', self._init_dtype),
                ('input_tensor', input_tensor),
                ('sparse', self._init_sparse),
                ('ragged', self._init_ragged),
            ]
            for arg_name, arg in args_that_must_be_none:
                _assert_other_arg_none(arg_name, arg)
            # NOTE(review): the condition checks the KerasTensor switch while
            # the message talks about eager execution - confirm they match.
            if not keras_tensor.keras_tensors_enabled():
                raise ValueError(
                    'Creating Keras inputs from a type_spec is only '
                    'supported when eager execution is enabled.')
            input_tensor = keras_tensor.keras_tensor_from_type_spec(type_spec)
            # The sparse/ragged flags follow the kind of tensor created.
            if isinstance(input_tensor, keras_tensor.SparseKerasTensor):
                self.sparse = True
            if isinstance(input_tensor, keras_tensor.RaggedKerasTensor):
                self.ragged = True
            self.is_placeholder = True
            try:
                self._batch_input_shape = tuple(input_tensor.shape.as_list())
            except ValueError:
                # If the shape cannot be represented as a tuple (e.g. unknown rank)
                self._batch_input_shape = None
        elif input_tensor is None:
            # Case 2: no tensor supplied, so create a placeholder in the
            # backend graph.
            if input_shape is not None:
                batch_input_shape = (batch_size, ) + tuple(input_shape)
            else:
                batch_input_shape = None
            graph = backend.get_graph()
            with graph.as_default():
                input_tensor = backend.placeholder(shape=batch_input_shape,
                                                   dtype=dtype,
                                                   name=self.name,
                                                   sparse=sparse,
                                                   ragged=ragged)

            self.is_placeholder = True
            self._batch_input_shape = batch_input_shape
        else:
            # Case 3: use the tensor supplied by the caller.
            if keras_tensor.keras_tensors_enabled():
                if not isinstance(input_tensor, keras_tensor.KerasTensor):
                    input_tensor = keras_tensor.keras_tensor_from_tensor(
                        input_tensor)
            else:
                if not tf_utils.is_symbolic_tensor(input_tensor):
                    raise ValueError(
                        'You should not pass an EagerTensor to `Input`. '
                        'For example, instead of creating an '
                        'InputLayer, you should instantiate your model and '
                        'directly call it on your input.')
            self.is_placeholder = False
            try:
                self._batch_input_shape = tuple(input_tensor.shape.as_list())
            except ValueError:
                # If the shape cannot be represented as a tuple (e.g. unknown rank)
                self._batch_input_shape = None
        # Create an input node.
        input_tensor._keras_mask = None
        node_module.Node(layer=self, outputs=input_tensor)

        # Store type spec
        if isinstance(input_tensor, keras_tensor.KerasTensor) or (
                tf_utils.is_extension_type(input_tensor)):
            self._type_spec = input_tensor._type_spec  # pylint: disable=protected-access
        else:
            self._type_spec = tensor_spec.TensorSpec(shape=input_tensor.shape,
                                                     dtype=input_tensor.dtype,
                                                     name=self.name)
Example #9
0
def maybe_enter_backend_graph():
  """Return the context manager to use around backend-graph work.

  Returns:
    `backend.get_graph().as_default()` when `keras_tensor` is `None` or
    Keras tensors are disabled, otherwise a `NoOpContextManager`.
  """
  if keras_tensor is None or not keras_tensor.keras_tensors_enabled():
    return backend.get_graph().as_default()
  return NoOpContextManager()
Example #10
0
def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
    """Build session with keras model

    The model is frozen to a GraphDef, its output Identity nodes are renamed
    to the model's output names, Grappler constant folding is applied, and
    the result is passed to `graph_def_session`.

    Args:
        model (string or tf.keras.Model): model path or tf.keras.Model object
        input_tensor_names (list of string): not read by this function; input
            names are taken from the `Placeholder` nodes of the frozen graph
        output_tensor_names (list of string): not read by this function;
            output names are derived from `model.output_names`
        **kwargs: NOTE(review): this name is rebound below to the model's
            input mapping, so caller-supplied keyword arguments are dropped
            and the input mapping is what reaches `graph_def_session`.
            Confirm this is intended.

     Returns:
        The return value of `graph_def_session`, previously documented as:
        sess (tf.compat.v1.Session): tf.compat.v1.Session object
        input_tensor_names (list of string): validated input_tensor_names
        output_tensor_names (list of string): validated output_tensor_names

    Raises:
        AssertionError: if the installed TensorFlow is not newer than 2.1.0.
    """
    import re

    def _release(version):
        # Compare releases numerically. Comparing the raw strings is
        # lexicographic and ranks '2.10.0' below '2.2.0'.
        return tuple(int(part) for part in re.findall(r'\d+', version)[:3])

    tf_release = _release(tf.version.VERSION)
    assert tf_release > (2, 1, 0), 'keras model need tensorflow version > 2.1.0....'
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
    if not isinstance(model, tf.keras.Model):
        model = tf.keras.models.load_model(model)
    kwargs = dict(zip(model.input_names, model.inputs))
    if tf_release > (2, 2, 0):
        from tensorflow.python.keras.engine import keras_tensor
        if keras_tensor.keras_tensors_enabled():
            # Trace with the type specs of the inputs, not the inputs
            # themselves.
            kwargs = {
                name: tensor.type_spec for name, tensor in kwargs.items()
            }
    full_model = tf.function(lambda **inputs: model(inputs.values()))
    concrete_function = full_model.get_concrete_function(**kwargs)
    frozen_model = convert_variables_to_constants_v2(concrete_function)

    from tensorflow.python.training import saver
    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.grappler import tf_optimizer
    from tensorflow.core.protobuf import meta_graph_pb2
    graph_def = frozen_model.graph.as_graph_def()
    input_names = [
        node.name for node in graph_def.node if node.op == 'Placeholder'
    ]
    output_names = [output.split(':')[0] for output in model.output_names]
    # Rename the last Identity node fed by each output so that the graph
    # exposes the model's own output names.
    for output_name in output_names:
        for node in graph_def.node[::-1]:
            if node.op == 'Identity' and output_name in node.input[0]:
                node.name = output_name
                break

    grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph_def, graph=frozen_model.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in model.output_names:
        fetch_collection.node_list.value.append(array)
    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run only the constant-folding optimizer. NOTE(review): min_graph_nodes
    # of -1 presumably lifts the minimum graph size - confirm.
    grappler_session_config = config_pb2.ConfigProto()
    rewrite_options = grappler_session_config.graph_options.rewrite_options
    rewrite_options.optimizers.append('constfold')
    rewrite_options.min_graph_nodes = -1
    graph_def = tf_optimizer.OptimizeGraph(
        grappler_session_config, grappler_meta_graph_def,
        graph_id=b"tf_graph")

    return graph_def_session(graph_def, input_names, output_names, **kwargs)