Example #1
 def test_list(self):
   a = [np.ones(10), np.ones(20)]
   model_inputs = training_utils.ModelInputs(a)
   self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
   vals = model_inputs.get_symbolic_inputs()
   self.assertTrue(tensor_util.is_tensor(vals[0]))
   self.assertTrue(tensor_util.is_tensor(vals[1]))
Example #2
def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates tensor_tracepoint.

  Args:
     layer: A keras layer.
     checkpoint_name: a string name for the checkpoint. This name has to be a
       unique name if used within model comparison. The tensors that have the
       same checkpoint identifier are compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tensor(outputs):
      tensor_tracepoint(outputs, '%s' % (checkpoint_name))
    else:
      idx = 0
      for output_tensor in outputs:
        if tensor_util.is_tensor(output_tensor):
          tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx))
        idx += 1
  except AttributeError:
    pass
  except RuntimeError:
    pass
  return layer
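A minimal usage sketch, assuming the `keras_layer_tracepoint` above and its `tensor_tracepoint` dependency are in scope; the model and layer name are illustrative only:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, name='dense_1', input_shape=(8,)),
    tf.keras.layers.Dense(1, name='dense_2'),
])
# Register the outputs of one layer under a unique checkpoint name; tensors
# sharing a checkpoint name are what model comparison later compares.
keras_layer_tracepoint(model.get_layer('dense_1'), 'dense_1_out')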
Example #3
 def test_dict(self):
   a = {'b': np.ones(10), 'a': np.ones(20)}
   model_inputs = training_utils.ModelInputs(a)
   self.assertEqual(['a', 'b'], model_inputs.get_input_names())
   vals = model_inputs.get_symbolic_inputs()
   self.assertTrue(tensor_util.is_tensor(vals['a']))
   self.assertTrue(tensor_util.is_tensor(vals['b']))
Example #4
 def test_single_thing(self):
   a = np.ones(10)
   model_inputs = training_utils.ModelInputs(a)
   self.assertEqual(['input_1'], model_inputs.get_input_names())
   vals = model_inputs.get_symbolic_inputs()
   self.assertTrue(tensor_util.is_tensor(vals))
   vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
   self.assertEqual(1, len(vals))
   self.assertTrue(tensor_util.is_tensor(vals[0]))
Example #5
def _split_dataset_batch(dataset, split_batch_by):
  """Divide a batch-ed dataset's batches into smaller batches."""
  # TODO(sourabhbajaj): Remove this in lieu of distributed datasets
  # pylint: disable=protected-access
  def _get_batch_dataset(d):
    """Get the underlying batch dataset from the dataset object."""
    if isinstance(d, dataset_ops.DatasetV1Adapter):
      d = d._dataset

    if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
      return d
    elif isinstance(d, dataset_ops.PrefetchDataset):
      return _get_batch_dataset(d._input_dataset)
    raise ValueError(
        "Unable to get batched dataset from the input dataset. `batch` "
        "`map_and_batch` need to be the last operations on the dataset. "
        "The batch operations can be followed by a prefetch.")

  batched_dataset = _get_batch_dataset(dataset)
  if isinstance(batched_dataset, dataset_ops.BatchDataset):
    batch_size = batched_dataset._batch_size
    drop_remainder = batched_dataset._drop_remainder
  elif isinstance(batched_dataset, batching._MapAndBatchDataset):
    batch_size = batched_dataset._batch_size_t
    drop_remainder = batched_dataset._drop_remainder_t

  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size
  # pylint: enable=protected-access

  if tensor_util.is_tensor(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tensor(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  if batch_size % split_batch_by:
    raise ValueError(
        "Batch size %s cannot be sharded evenly across replicas %s" % (
            batch_size, split_batch_by))
  new_batch_size = batch_size // split_batch_by

  dataset = dataset.apply(batching.unbatch())
  dataset = dataset.batch(new_batch_size, drop_remainder=drop_remainder)
  if prefetch_buffer is not None:
    dataset = dataset.prefetch(prefetch_buffer)
  return dataset
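A hedged usage sketch, assuming `_split_dataset_batch` above is in scope and a TF 1.x-style `tf.data` pipeline whose final ops are `batch` optionally followed by `prefetch`; a global batch of 64 split across 4 replicas becomes a per-replica batch of 16:

import tensorflow as tf

dataset = tf.data.Dataset.range(256).batch(64, drop_remainder=True).prefetch(2)
per_replica = _split_dataset_batch(dataset, split_batch_by=4)
# per_replica now yields batches of 16, with the original prefetch buffer
# re-applied after re-batching.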
Example #6
def test_on_batch(model, inputs, targets, sample_weights=None):
  """Calculates the loss for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.

  Returns:
      total loss, loss and metrics associated with each output.
  """
  if len(inputs) and not tensor_util.is_tensor(inputs[0]):
    inputs = [
        ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
    ]
    targets = [
        ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets
    ]
  if sample_weights:
    sample_weights = [
        ops.convert_to_tensor(val, dtype=backend.floatx())
        if val is not None else None for val in sample_weights
    ]
  outs, loss, loss_metrics = _model_loss(
      model, inputs, targets, sample_weights=sample_weights, training=False)
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(model, outs, targets)
  if not isinstance(loss, list):
    loss = [loss]
  return loss + loss_metrics + metrics_results
Example #7
    def _fetch_preprocesing_callback(f):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original fetches structure.

      Args:
        f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
          identifying a Tensor or Operation.

      Returns:
        `f` converted to a Tensor.
      """
      if isinstance(f, ops.Operation):
        operation_fetches.append(f)
        return f
      elif isinstance(f, meta_graph_pb2.TensorInfo):
        tensor_infos.append(f)
        decoded = _get_element_from_tensor_info(f, self._func_graph)
        if tensor_util.is_tensor(decoded):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(f, ops.Tensor):
        tensor_fetches.append(f)
        return f
      else:
        graph_element = self.graph.as_graph_element(f)
        return _fetch_preprocesing_callback(graph_element)
Example #8
def if_stmt(cond, body, orelse, get_state, set_state):
  """Functional form of an if statement.

  Args:
    cond: Boolean.
    body: Callable with no arguments, and outputs of the positive (if) branch
        as return type.
    orelse: Callable with no arguments, and outputs of the negative (else)
        branch as return type.
    get_state: Function that returns a tuple containing the values of all
        composite symbols modified within the conditional. This allows access to
        state that branches may mutate through side effects. This function is
        not needed and should not be called when dispatching to code matching
        Python's default semantics. This is useful for checkpointing to avoid
        unintended side-effects when staging requires evaluating all code-paths.
    set_state: Function to set the values of all composite symbols modified
        within the conditional. This is the complement to get_state, used to
        restore checkpointed values. The single argument is a tuple containing
        values for each composite symbol that may be modified in a branch of the
        conditional. This is usually the result of a call to get_state.

  Returns:
    Tuple containing the statement outputs.
  """
  if tensor_util.is_tensor(cond):
    return tf_if_stmt(cond, body, orelse, get_state, set_state)
  else:
    return _py_if_stmt(cond, body, orelse)
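A minimal dispatch sketch, assuming `if_stmt` above is in scope; a Tensor condition stages the TF branch (`tf_if_stmt`), while a plain Python bool falls through to the Python branch. The state callables are trivial here because the branches modify no composite symbols:

import tensorflow as tf

x = tf.constant(3)
result = if_stmt(
    cond=x > 0,                     # Tensor condition: staged as a tf.cond
    body=lambda: x + 1,
    orelse=lambda: x - 1,
    get_state=lambda: (),
    set_state=lambda _: None)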
Example #9
 def _find_any_tensor(batch_features):
   tensors = [
       x for x in nest.flatten(batch_features) if tensor_util.is_tensor(x)
   ]
   if not tensors:
     raise ValueError('Cannot find any Tensor in features dict.')
   return tensors[0]
Example #10
def list_pop(list_, i, opts):
  """The list pop function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will typically not be mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports pop semantics.
    i: Optional index to pop from. May be None.
    opts: A ListPopOpts.

  Returns:
    Tuple (x, out_list_):
      out_list_: same as list_, after the removal was performed.
      x: the removed element value.

  Raises:
    ValueError: if list_ is not of a known list-like type or the operation is
    not supported for that type.
  """
  assert isinstance(opts, ListPopOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    raise ValueError('TensorArray does not support item removal')
  elif tensor_util.is_tensor(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_pop(list_, i, opts)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_pop(list_, i)
Example #11
def get_item(target, i, opts):
  """The slice read operator (i.e. __getitem__).

  Note: it is unspecified whether target will be mutated or not. In general,
  if target is mutable (like Python lists), it will be mutated.

  Args:
    target: An entity that supports getitem semantics.
    i: Index to read from.
    opts: A GetItemOpts object.

  Returns:
    The read element.

  Raises:
    ValueError: if target is not of a supported type.
  """
  assert isinstance(opts, GetItemOpts)

  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_get_item(target, i)
  elif tensor_util.is_tensor(target):
    if target.dtype == dtypes.variant:
      return _tf_tensor_list_get_item(target, i, opts)
    elif target.dtype == dtypes.string and target.shape.ndims == 0:
      return _tf_tensor_string_get_item(target, i)
    else:
      return _tf_tensor_get_item(target, i)
  else:
    return _py_get_item(target, i)
Example #12
def list_append(list_, x):
  """The list append function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will typically not be mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports append semantics.
    x: The element to append.

  Returns:
    Same as list_, after the append was performed.

  Raises:
    ValueError: if list_ is not of a known list-like type.
  """
  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_append(list_, x)
  elif tensor_util.is_tensor(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_append(list_, x)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_append(list_, x)
Example #13
def while_stmt(test, body, init_state, extra_deps, opts=None):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  Args:
    test: Callable with the state as arguments, and boolean return type.
        The loop condition.
    body: Callable with the state as arguments, and state as return type.
        The actual loop body.
    init_state: Tuple containing the initial state.
    extra_deps: Tuple containing additional entities on which the loop may
        depend, such as loop invariants referenced by test. Used
        exclusively for dispatch control.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch.
  # That could be something as simple as a collection of dispatch rules, with
  # some prioritization.
  if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
    return _tf_while_stmt(test, body, init_state, opts)
  else:
    return _py_while_stmt(test, body, init_state, opts)
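A short sketch, assuming `while_stmt` above is in scope; because the initial state contains a Tensor, dispatch goes to the staged `tf.while_loop` path:

import tensorflow as tf

n = tf.constant(5)
final_state = while_stmt(
    test=lambda i: i < n,
    body=lambda i: (i + 1,),
    init_state=(tf.constant(0),),
    extra_deps=(n,))
# final_state is a tuple whose single element counts up to n.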
Example #14
def validate_per_device_inputs(distribution_strategy, x):
  """Validates PerDevice dataset input list.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`, `evaluate` and `predict`.
    x: A list of PerDevice objects that represent the input or
      target values.

  Returns:
    List containing the first element of each of the PerDevice objects in
    the input list.

  Raises:
    ValueError: If any of the objects in the `per_device_list` is not a tensor.

  """
  # Convert the inputs and targets into a list of PerDevice objects.
  per_device_list = nest.flatten(x)
  x_values_list = []
  for x in per_device_list:
    if not tensor_util.is_tensor(x):
      raise ValueError('Dataset input to the model should be tensors; instead '
                       'they are of type {}.'.format(type(x)))

    # At this point both x and y contain tensors in the `DistributedValues`
    # structure.
    x_values = distribution_strategy.unwrap(x)

    # Validate that the shape and dtype of all the elements in x are the same.
    validate_all_tensor_shapes(x, x_values)
    validate_all_tensor_types(x, x_values)

    x_values_list.append(x_values[0])
  return x_values_list
Example #15
def slice_arrays(arrays, indices, contiguous=True):
  """Slices batches out of provided arrays (workaround for eager tensors).

  Unfortunately eager tensors don't have the same slicing behavior as
  Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
  hence we cannot use `generic_utils.slice_arrays` directly
  and we have to implement this workaround based on `concat`. This has a
  performance cost.

  Arguments:
    arrays: Single array or list of arrays.
    indices: List of indices in the array that should be included in the output
      batch.
    contiguous: Boolean flag indicating whether the indices are contiguous.

  Returns:
    Slice of data (either single array or list of arrays).
  """
  converted_to_list = False
  if not isinstance(arrays, list):
    converted_to_list = True
    arrays = [arrays]
  if any(tensor_util.is_tensor(x) for x in arrays):
    if not contiguous:
      entries = [[x[i:i + 1] for i in indices] for x in arrays]
      slices = [array_ops.concat(x, axis=0) for x in entries]
    else:
      slices = [x[indices[0]:indices[-1] + 1] for x in arrays]
  else:
    slices = generic_utils.slice_arrays(arrays, indices)

  if converted_to_list:
    slices = slices[0]
  return slices
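A brief sketch, assuming `slice_arrays` above is in scope; eager tensors take the `concat`-based workaround, while NumPy arrays fall through to `generic_utils.slice_arrays`:

import numpy as np
import tensorflow as tf

arrays = [tf.constant(np.arange(10)), tf.constant(np.arange(10) * 2)]
contiguous_batch = slice_arrays(arrays, indices=[2, 3, 4])             # rows 2..4
scattered_batch = slice_arrays(arrays, indices=[0, 5, 9], contiguous=False)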
Example #16
def set_item(target, i, x):
  """The slice write operator (i.e. __setitem__).

  Note: it is unspecified whether target will be mutated or not. In general,
  if target is mutable (like Python lists), it will be mutated.

  Args:
    target: An entity that supports setitem semantics.
    i: Index to modify.
    x: The new element value.

  Returns:
    Same as target, after the update was performed.

  Raises:
    ValueError: if target is not of a supported type.
  """
  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_set_item(target, i, x)
  elif tensor_util.is_tensor(target):
    if target.dtype == dtypes.variant:
      return _tf_tensor_list_set_item(target, i, x)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % target)
  else:
    return _py_set_item(target, i, x)
Example #17
def list_stack(list_, opts):
  """The list stack function.

  This does not have a direct correspondent in Python. The closest idiom to
  this is tf.stack or np.stack. It's different from those in the sense that it
  accepts a Tensor list, rather than a list of tensors. It can also accept
  TensorArray. When the target is anything else, the dispatcher will rely on
  ctx.original_call for fallback.

  Args:
    list_: An entity that supports append semantics.
    opts: A ListStackOpts object.

  Returns:
    The output of the stack operation, typically a Tensor.
  """
  assert isinstance(opts, ListStackOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_stack(list_)
  elif tensor_util.is_tensor(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_stack(list_, opts)
    else:
      # No-op for primitive Tensor arguments.
      return list_
  else:
    return _py_list_stack(list_, opts)
Example #18
def is_tensor_list(t):
  # TODO(mdan): This is just a heuristic.
  # With TF lacking support for templated types, this is unfortunately the
  # closest we can get right now. A dedicated op ought to be possible to
  # construct.
  return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
          not t.shape.ndims)
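Because the heuristic only accepts scalar tensors of dtype `tf.variant`, an ordinary dense tensor is rejected while a TensorList handle passes. A small sketch, assuming `is_tensor_list` above is in scope; `tf.raw_ops.EmptyTensorList` is used here only to obtain a variant handle:

import tensorflow as tf

dense = tf.constant([1, 2, 3])
handle = tf.raw_ops.EmptyTensorList(
    element_shape=tf.constant([], dtype=tf.int32),  # scalar elements
    max_num_elements=-1,
    element_dtype=tf.int32)

is_tensor_list(dense)    # False: dtype is int32, not variant
is_tensor_list(handle)   # True: scalar tensor with dtype=tf.variant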
Example #19
def assert_stmt(expression1, expression2):
  """Functional form of an assert statement.

  This follows the semantics of the Python assert statement, however the
  concrete implementations may deviate from it. See the respective
  implementation for details.

  In general, the assert statement should not be used for control flow.
  Furthermore, it is encouraged that the assertion expressions should not have
  side effects.

  Args:
    expression1: Any
    expression2: Callable[[], Any], returns the expression to include in the
        error message when expression1 evaluates to False. When expression1 is
        True, the result of expression2 will not be evaluated, however,
        expression2 itself may be evaluated in some implementations.

  Returns:
    Any, implementation-dependent.

  Raises:
    ValueError: if any arguments are illegal.
  """
  if not callable(expression2):
    raise ValueError('{} must be a callable'.format(expression2))
  args, _, keywords, _ = tf_inspect.getargspec(expression2)
  if args or keywords:
    raise ValueError('{} may not have any arguments'.format(expression2))

  if tensor_util.is_tensor(expression1):
    return _tf_assert_stmt(expression1, expression2)
  else:
    return _py_assert_stmt(expression1, expression2)
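A minimal usage sketch, assuming `assert_stmt` above is in scope; the message must be a zero-argument callable so that it is only (potentially) evaluated when the assertion fails:

import tensorflow as tf

x = tf.constant(4)
assert_stmt(x > 0, lambda: 'x must be positive')   # Tensor: TF implementation
assert_stmt(4 > 0, lambda: 'x must be positive')   # Python bool: plain assert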
Example #20
def for_loop(iterated, extra_cond, loop_body, init_state):
  """Functional form of a for statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations, excluding the iterate. In what follows we
  refer to state as either a tuple of entities that represent an actual state,
  or a list of arguments of the corresponding types.

  Args:
    iterated: The entity being iterated over.
    extra_cond: Callable with the state as arguments, and boolean return type.
        An additional loop condition.
    loop_body: Callable with the iterate and the state as arguments, and
        state as return type. The actual loop body.
    init_state: Tuple containing the initial state.

  Returns:
    Tuple containing the final state.
  """
  if tensor_util.is_tensor(iterated):
    return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
  elif isinstance(iterated, dataset_ops.Dataset):
    return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
  else:
    return _py_for_loop(iterated, extra_cond, loop_body, init_state)
Example #21
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        new_resource = obj._create_resource()  # pylint: disable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          copied_tensor = constant_op.constant(
              tensor_util.constant_value(capture))
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Example #22
 def test_match_staging_level(self):
   some_tensor = constant_op.constant(0)
   tensor_one = special_functions.match_staging_level(1, some_tensor)
   python_one = special_functions.match_staging_level(1, 1)
   with self.cached_session() as sess:
     self.assertTrue(tensor_util.is_tensor(tensor_one))
     self.assertAllEqual(sess.run(tensor_one), 1)
     self.assertEqual(python_one, 1)
Example #23
def len_(s):
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  elif tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  elif tensor_util.is_tensor(s):
    return _tf_tensor_len(s)
  return _py_len(s)
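A short sketch, assuming `len_` above is in scope; plain Python containers use the builtin `len`, while a dense tensor is dispatched to `_tf_tensor_len`, which resolves the leading dimension (statically if known, otherwise via `tf.shape`):

import tensorflow as tf

len_([1, 2, 3])          # 3, via _py_len
len_(tf.zeros([4, 2]))   # 4 (possibly as a Tensor), via _tf_tensor_len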
Example #24
 def set_of_lengths(x):
   # Returns a set with the variation between
   # different shapes, with None => 0
   if x is None:
     return {}
   else:
     return set([y.shape[0] for y in x
                 if y is not None and not tensor_util.is_tensor(y)])
Example #25
def standardize_single_array(x):
  if x is None:
    return None
  elif tensor_util.is_tensor(x):
    return x
  elif x.ndim == 1:
    x = np.expand_dims(x, 1)
  return x
Example #26
def standardize_single_array(x):
  if x is None:
    return None
  if x.shape is not None and len(x.shape) == 1:
    if tensor_util.is_tensor(x):
      x = array_ops.expand_dims(x, axis=1)
    else:
      x = np.expand_dims(x, 1)
  return x
Example #27
def _is_not_callable(obj):
  # TODO(brianklee): Handle case when obj is a tensor dependent on a py_func.
  if isinstance(obj, (int, float, complex, str, bool)):
    return True
  if isinstance(obj, (np.ndarray, np.generic)):
    return True
  if tensor_util.is_tensor(obj):
    return True
  return False
Example #28
 def print_wrapper(*vals):
   vals = tuple(v.numpy() if tensor_util.is_tensor(v) else v for v in vals)
   if six.PY3:
     # TensorFlow doesn't seem to generate Unicode when passing strings to
     # py_func. This causes the print to add a "b'" wrapper to the output,
     # which is probably never what you want.
     vals = tuple(
         v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
   six.print_(*vals, **override_kwargs)
Example #29
def dynamic_len(list_or_tensor):
  """Implementation of len using dynamic dispatch."""
  if tensor_util.is_tensor(list_or_tensor):
    shape = list_or_tensor.shape
    if not shape:
      raise ValueError(
          'len requires non-zero rank for tensor "%s"' % list_or_tensor)
    return array_ops.shape(list_or_tensor)[0]
  return len(list_or_tensor)
Example #30
def _convert_tensor(x):
  """Create or cast tensor if needed."""
  if not tensor_util.is_tensor(x):
    # x is a numpy array
    x = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(x)
  if check_ops.is_numeric_tensor(x):
    # is_numeric_tensor returns False if provided with a numpy array
    x = _cast_tensor_to_floatx(x)
  return x
Example #31
def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
    if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)):
        return _tf_range(start_or_stop, stop, step)
    return _py_range(start_or_stop, stop, step)
Example #32
        def tpu_function(args, kwargs):
            """TF Function used to replicate the user computation."""
            if kwargs is None:
                kwargs = {}

            # Remove None at the end of args as they are not replicatable
            # If there are None in the middle we can't do anything about it
            # so let those cases fail.
            # For example when Keras model predict is used they pass the targets as
            # None. We want to handle it here so all client libraries don't have to
            # do this as other strategies can handle None values better.
            while args and args[-1] is None:
                args = args[:-1]

            # Used to re-structure flattened output tensors from `tpu.replicate()`
            # into a structured format.
            result = [[]]

            def replicated_fn(replica_id, replica_args, replica_kwargs):
                """Wraps user function to provide replica ID and `Tensor` inputs."""
                with _TPUReplicaContext(strategy,
                                        replica_id_in_sync_group=replica_id):
                    result[0] = fn(*replica_args, **replica_kwargs)
                return result[0]

            replicate_inputs = []  # By replica.
            for i in range(strategy.num_replicas_in_sync):
                replicate_inputs.append([
                    constant_op.constant(i, dtype=dtypes.int32),
                    values.select_replica(i, args),
                    values.select_replica(i, kwargs)
                ])

            # Construct and pass `maximum_shapes` so that we could support dynamic
            # shapes using dynamic padder.
            if options.experimental_enable_dynamic_batch_size and replicate_inputs:
                maximum_shapes = []
                flattened_list = nest.flatten(replicate_inputs[0])
                for input_tensor in flattened_list:
                    if tensor_util.is_tensor(input_tensor):
                        rank = input_tensor.get_shape().rank
                    else:
                        rank = np.ndim(input_tensor)
                    maximum_shape = tensor_shape.TensorShape([None] * rank)
                    maximum_shapes.append(maximum_shape)
                maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                                       maximum_shapes)
            else:
                maximum_shapes = None

            if options.experimental_bucketizing_dynamic_shape:
                padding_spec = tpu.PaddingSpec.POWER_OF_TWO
            else:
                padding_spec = None

            with strategy.scope():
                replicate_outputs = tpu.replicate(
                    replicated_fn,
                    replicate_inputs,
                    device_assignment=self._device_assignment,
                    maximum_shapes=maximum_shapes,
                    padding_spec=padding_spec)

            # Remove all no ops that may have been added during 'tpu.replicate()'
            if isinstance(result[0], list):
                result[0] = [
                    output for output in result[0]
                    if not isinstance(output, ops.Operation)
                ]

            # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
            if result[0] is None or isinstance(result[0], ops.Operation):
                replicate_outputs = [None] * len(replicate_outputs)
            else:
                replicate_outputs = [
                    nest.pack_sequence_as(result[0],
                                          nest.flatten(replica_output))
                    for replica_output in replicate_outputs
                ]
            return values.regroup(replicate_outputs)
Example #33
def has_tensors(ls):
    if isinstance(ls, (list, tuple)):
        return any(tensor_util.is_tensor(v) for v in ls)
    if isinstance(ls, dict):
        return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls))
    return tensor_util.is_tensor(ls)
Example #34
def abs_(x):
    if tensor_util.is_tensor(x):
        return _tf_abs(x)
    return _py_abs(x)
Example #35
 def testConstantTensor(self):
     np_val = np.random.rand(3).astype(np.int32)
     tf_val = constant_op.constant(np_val)
     self.assertFalse(tensor_util.is_tensor(np_val))
     self.assertTrue(tensor_util.is_tensor(tf_val))
Example #36
def has_tensors(ls):
    if isinstance(ls, (list, tuple)):
        return any(tensor_util.is_tensor(v) for v in ls)
    return tensor_util.is_tensor(ls)
Example #37
def abs_(x):
    if tensor_util.is_tensor(x):
        return _tf_abs(x)
    if isinstance(x, dataset_ops.DatasetV2):
        return _tf_dataset_abs(x)
    return _py_abs(x)
Example #38
    def set_vocabulary(self, vocabulary, idf_weights=None):
        """Sets vocabulary (and optionally document frequency) data for this layer.

    This method sets the vocabulary and idf weights for this layer directly,
    instead of analyzing a dataset through `adapt`. It should be used whenever
    the vocab (and optionally document frequency) information is already known.
    If vocabulary data is already present in the layer, this method will replace
    it.

    Args:
      vocabulary: An array, numpy array, or tensor of hashable tokens.
      idf_weights: An array, numpy array, or tensor of inverse document
        frequency weights with equal length to vocab. Only necessary if the
        layer output_mode is TF_IDF.

    Raises:
      ValueError: If there are too many inputs, the inputs do not match, or
        input data is missing.
      RuntimeError: If the vocabulary cannot be set when this function is
        called. This happens when the layer is used in `"multi_hot"`, `"count"`,
        or `"tfidf"` mode, `pad_to_max_tokens` is False, and the layer itself has
        already been called.
      RuntimeError: If a tensor vocabulary is passed outside of eager execution.
    """
        if self._has_static_table:
            raise RuntimeError(
                "Layer {} was created with a static file-based table "
                "because a file path was passed to the layer "
                "init. Layers created with static file-based tables "
                "do not support changing the vocabulary after "
                "creation.".format(self.name))

        if self.output_mode != TF_IDF and idf_weights is not None:
            raise ValueError(
                "`idf_weights` should only be set if output_mode is "
                "TF_IDF. output_mode is {}.".format(self.output_mode))

        if (self.output_mode in [MULTI_HOT, COUNT, TF_IDF] and self._called
                and not self.pad_to_max_tokens):
            raise RuntimeError(
                "When using {} mode and `pad_to_max_tokens` is "
                "False, the vocabulary cannot be changed after the "
                "layer is called.".format(self.output_mode))

        if not context.executing_eagerly() and (
                tensor_util.is_tensor(vocabulary)
                or tensor_util.is_tensor(idf_weights)):
            raise RuntimeError(
                "Cannot set a tensor vocabulary on {} layer {} when not executing "
                "eagerly. Create this layer or call `set_vocabulary` outside of "
                "any `tf.function`s and with eager execution enabled.".format(
                    self.__class__.__name__, self.name))

        # TODO(mattdangerw): for better performance we should rewrite this entire
        # function to operate on tensors and convert vocabulary to a tensor here.
        if tensor_util.is_tensor(vocabulary):
            vocabulary = self._tensor_vocab_to_numpy(vocabulary)
        if tensor_util.is_tensor(idf_weights):
            idf_weights = idf_weights.numpy()

        oov_start = self._oov_start_index()
        token_start = self._token_start_index()
        should_have_mask = (oov_start > 0)
        has_mask = should_have_mask and vocabulary[0] == self.mask_token

        should_have_oov = (self.num_oov_indices > 0)
        expected_oov = [self.oov_token] * self.num_oov_indices
        found_oov = vocabulary[oov_start:token_start]
        has_oov = should_have_oov and found_oov == expected_oov
        # If we get a numpy array, then has_oov may end up being a numpy array
        # instead of a bool. Fix this by collapsing the variable if it's not bool.
        if not isinstance(has_oov, bool):
            has_oov = any(has_oov)

        if all([should_have_mask, has_mask, should_have_oov]) and not has_oov:
            raise ValueError(
                "Invalid vocabulary format. The layer was created with "
                "`mask_token={mask}` and `oov_token={oov}`. These tokens should be "
                "included in the provided vocabulary. The passed vocabulary has the "
                "correct mask token `{mask}` at index 0, but does not have the OOV "
                "token `{oov}` in indices [{start}:{end}]. Instead, we found "
                "`{found}`. Was this vocabulary generated by a layer with "
                "incompatible settings?".format(mask=self.mask_token,
                                                oov=self.oov_token,
                                                start=oov_start,
                                                end=token_start,
                                                found=found_oov))

        if all([should_have_oov, has_oov, should_have_mask]) and not has_mask:
            raise ValueError(
                "Invalid vocabulary format. The layer was created with "
                "`mask_token={mask}` and `oov_token={oov}`. These tokens should be "
                "included in the provided vocabulary. The passed vocabulary has the "
                "correct OOV token `{oov}` at indices [{start}:{end}], but does not "
                "have the mask token `{mask}` in index 0. Instead, we found "
                "`{found}`. Was this vocabulary generated by a layer with "
                "incompatible settings?".format(mask=self.mask_token,
                                                oov=self.oov_token,
                                                start=oov_start,
                                                end=token_start,
                                                found=vocabulary[0]))

        found_special_tokens = has_oov or has_mask
        if found_special_tokens:
            tokens = vocabulary[token_start:]
        else:
            tokens = vocabulary

        repeated_tokens = table_utils.find_repeated_tokens(tokens)
        if repeated_tokens:
            raise ValueError(
                "The passed vocabulary has at least one repeated "
                "term. Please uniquify your dataset. The repeated terms "
                "are {}".format(repeated_tokens))

        if self.mask_token in tokens:
            raise ValueError(
                "Reserved mask token {} was found in the passed "
                "vocabulary at index {}. Please either remove the "
                "reserved token from the vocabulary or change the "
                "mask token for this layer.".format(
                    self.mask_token, tokens.index(self.mask_token)))
        if self.oov_token in tokens:
            raise ValueError(
                "Reserved OOV token {} was found in the passed "
                "vocabulary at index {}. Please either remove the "
                "reserved token from the vocabulary or change the "
                "OOV token for this layer.".format(
                    self.oov_token, tokens.index(self.oov_token)))

        self._vocab_size = token_start + len(tokens)
        if self.max_tokens is not None and self._vocab_size > self.max_tokens:
            raise ValueError(
                "Attempted to set a vocabulary larger than the maximum vocab size. "
                "Passed vocab size is {}, max vocab size is {}.".format(
                    self._vocab_size, self.max_tokens))

        if self.output_mode == TF_IDF:
            if idf_weights is None:
                raise ValueError(
                    "`idf_weights` must be set if output_mode is TF_IDF")
            if len(vocabulary) != len(idf_weights):
                raise ValueError(
                    "`idf_weights` must be the same length as vocabulary. "
                    "len(idf_weights) is {}, len(vocabulary) is {}".format(
                        len(vocabulary), len(idf_weights)))
            idf_weights = self._convert_to_ndarray(idf_weights)
            if idf_weights.ndim != 1:
                raise ValueError(
                    "TF-IDF data must be a 1-index array, but received {}".
                    format(type(idf_weights)))

        # We add the non-special vocab tokens and optionally the mask_token to our
        # hash table. OOV tokens are handled with the hash table default value and
        # not added directly.
        self._table_handler.clear()
        indices = np.arange(token_start,
                            len(tokens) + token_start,
                            dtype=np.int64)
        if self.invert:
            self._table_handler.insert(indices, tokens)
        else:
            self._table_handler.insert(tokens, indices)
        if self.mask_token is not None:
            self._table_handler.insert([self._mask_key], [self._mask_value])

        if self.output_mode == TF_IDF:
            # If the passed vocabulary has no special tokens, we need to pad the front
            # of idf_weights. We don't have real document frequencies for these tokens
            # so we will use an average of all idf_weights passed in as a reasonable
            # default.
            if found_special_tokens:
                front_padding = 0
                front_padding_value = 0
            else:
                front_padding = token_start
                front_padding_value = np.average(idf_weights)
            # If pad_to_max_tokens is true, and max_tokens is greater than our total
            # vocab size, we need to pad the back of idf_weights with zeros as well.
            back_padding_value = 0
            if self.pad_to_max_tokens and self.max_tokens is not None:
                back_padding = self.max_tokens - front_padding - len(
                    idf_weights)
            else:
                back_padding = 0
            idf_weights = np.pad(idf_weights, (front_padding, back_padding),
                                 "constant",
                                 constant_values=(front_padding_value,
                                                  back_padding_value))
            backend.set_value(self.tf_idf_weights, idf_weights)
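The same behaviour is reachable through the public lookup layers; a hedged sketch using `tf.keras.layers.StringLookup` (available in recent TF releases) rather than the internal base layer shown above:

import tensorflow as tf

layer = tf.keras.layers.StringLookup()
layer.set_vocabulary(['cat', 'dog', 'bird'])
layer(tf.constant([['dog', 'bird', 'unknown']]))
# -> integer indices; tokens outside the vocabulary map to the reserved OOV index.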
Example #39
def for_stmt(iter_,
             extra_test,
             body,
             get_state,
             set_state,
             init_vars,
             basic_symbol_names,
             composite_symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the iterate as well as the
  variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

  The state is represented by the variables geo_mean and arith_mean. The
  argument for init_vars may contain the tuple (1, 0); the body will
  receive the arguments geo_mean and arith_mean and will return a tuple
  with the new values for geo_mean and arith_mean, respectively.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return type.
      An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  if tensor_util.is_tensor(iter_):
    if tensors.is_range_tensor(iter_):
      return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)
    else:
      return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                    set_state, init_vars, basic_symbol_names,
                                    composite_symbol_names, opts)

  if isinstance(iter_, dataset_ops.DatasetV2):
    return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)

  if isinstance(iter_, iterator_ops.OwnedIterator):
    return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                                 init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)

  if isinstance(iter_, ragged_tensor.RaggedTensor):
    return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                               init_vars, basic_symbol_names,
                               composite_symbol_names, opts)

  if isinstance(iter_, input_lib.DistributedIterator):
    raise NotImplementedError(
        'distributed iterators not supported yet, use the distributed dataset'
        ' directly')

  if isinstance(iter_, input_lib.DistributedDataset):
    return _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_vars)

  return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
Example #40
def float_(x=0):
    if tensor_util.is_tensor(x):
        return _tf_float(x)
    return _py_float(x)
Example #41
 def assertValuesEqual(self, actual, expected):
     values = nest.map_structure(
         lambda x: self.evaluate(x)
         if tensor_util.is_tensor(x) else x, actual)
     self.assertAllEqual(values, expected)
Example #42
def _convert_tensor(x):
    """Create or cast tensor if needed."""
    if not tensor_util.is_tensor(x):
        # x is a numpy array
        x = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(x)
    return x
Example #43
    def __init__(self,
                 dtype,
                 graph_parents=None,
                 is_non_singular=None,
                 is_self_adjoint=None,
                 is_positive_definite=None,
                 is_square=None,
                 name=None):
        r"""Initialize the `LinearOperator`.

    **This is a private method for subclass use.**
    **Subclasses should copy-paste this `__init__` documentation.**

    Args:
      dtype: The type of this `LinearOperator`.  Arguments to `matmul` and
        `solve` will have to be this type.
      graph_parents: Python list of graph prerequisites of this `LinearOperator`.
        Typically tensors that are passed during initialization.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `dtype` is real, this is equivalent to being symmetric.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If any member of graph_parents is `None` or not a `Tensor`.
      ValueError:  If hints are set incorrectly.
    """
        # Check and auto-set flags.
        if is_positive_definite:
            if is_non_singular is False:
                raise ValueError(
                    "A positive definite matrix is always non-singular.")
            is_non_singular = True

        if is_non_singular:
            if is_square is False:
                raise ValueError("A non-singular matrix is always square.")
            is_square = True

        if is_self_adjoint:
            if is_square is False:
                raise ValueError("A self-adjoint matrix is always square.")
            is_square = True

        self._is_square_set_or_implied_by_hints = is_square

        graph_parents = [] if graph_parents is None else graph_parents
        for i, t in enumerate(graph_parents):
            if t is None or not tensor_util.is_tensor(t):
                raise ValueError("Graph parent item %d is not a Tensor; %s." %
                                 (i, t))
        self._dtype = dtype
        self._graph_parents = graph_parents
        self._is_non_singular = is_non_singular
        self._is_self_adjoint = is_self_adjoint
        self._is_positive_definite = is_positive_definite
        self._name = name or type(self).__name__
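The hint auto-setting above is observable through the public operators; a brief sketch using `tf.linalg.LinearOperatorDiag` (the hints are caller assertions and are not verified numerically):

import tensorflow as tf

op = tf.linalg.LinearOperatorDiag(
    tf.constant([1.0, 2.0, 3.0]),
    is_positive_definite=True,  # should imply non-singular, hence square
    is_self_adjoint=True)
op.is_non_singular  # expected True, auto-set by the checks in __init__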
Example #44
def standardize_input_data(data,
                           names,
                           shapes=None,
                           check_batch_axis=True,
                           exception_prefix=''):
    """Normalizes inputs and targets provided by users.

  Users may pass data as a list of arrays, dictionary of arrays,
  or as a single array. We normalize this to an ordered list of
  arrays (same order as `names`), while checking that the provided
  arrays have shapes that match the network's expectations.

  Arguments:
      data: User-provided input data (polymorphic).
      names: List of expected array names.
      shapes: Optional list of expected array shapes.
      check_batch_axis: Boolean; whether to check that
          the batch axis of the arrays matches the expected
          value found in `shapes`.
      exception_prefix: String prefix used for exception formatting.

  Returns:
      List of standardized input arrays (one array per model input).

  Raises:
      ValueError: in case of improperly formatted user-provided data.
  """
    if not names:
        if data is not None and hasattr(data, '__len__') and len(data):
            raise ValueError(
                'Error when checking model ' + exception_prefix + ': '
                'expected no data, but got:', data)
        return []
    if data is None:
        return [None for _ in range(len(names))]

    if isinstance(data, dict):
        try:
            data = [
                data[x].values
                if data[x].__class__.__name__ == 'DataFrame' else data[x]
                for x in names
            ]
        except KeyError as e:
            raise ValueError('No data provided for "' + e.args[0] +
                             '". Need data '
                             'for each key in: ' + str(names))
    elif isinstance(data, (list, tuple)):
        if isinstance(data[0], (list, tuple)):
            data = [np.asarray(d) for d in data]
        elif len(names) == 1 and isinstance(data[0], (float, int)):
            data = [np.asarray(data)]
        else:
            data = [
                x.values if x.__class__.__name__ == 'DataFrame' else x
                for x in data
            ]
    else:
        data = data.values if data.__class__.__name__ == 'DataFrame' else data
        data = [data]
    data = [standardize_single_array(x) for x in data]

    if len(data) != len(names):
        if data and hasattr(data[0], 'shape'):
            raise ValueError(
                'Error when checking model ' + exception_prefix +
                ': the list of Numpy arrays that you are passing to '
                'your model is not the size the model expected. '
                'Expected to see ' + str(len(names)) + ' array(s), '
                'but instead got the following list of ' + str(len(data)) +
                ' arrays: ' + str(data)[:200] + '...')
        elif len(names) > 1:
            raise ValueError(
                'Error when checking model ' + exception_prefix +
                ': you are passing a list as input to your model, '
                'but the model expects a list of ' + str(len(names)) +
                ' Numpy arrays instead. The list you passed was: ' +
                str(data)[:200])
        elif len(data) == 1 and not hasattr(data[0], 'shape'):
            raise TypeError('Error when checking model ' + exception_prefix +
                            ': data should be a Numpy array, or list/dict of '
                            'Numpy arrays. Found: ' + str(data)[:200] + '...')
        elif len(names) == 1:
            data = [np.asarray(data)]

    # Check shapes compatibility.
    if shapes:
        for i in range(len(names)):
            if shapes[i] is not None:
                if tensor_util.is_tensor(data[i]):
                    tensorshape = data[i].get_shape()
                    if not tensorshape:
                        continue
                    data_shape = tuple(tensorshape.as_list())
                else:
                    data_shape = data[i].shape
                shape = shapes[i]
                if len(data_shape) != len(shape):
                    raise ValueError('Error when checking ' +
                                     exception_prefix + ': expected ' +
                                     names[i] + ' to have ' + str(len(shape)) +
                                     ' dimensions, but got array '
                                     'with shape ' + str(data_shape))
                if not check_batch_axis:
                    data_shape = data_shape[1:]
                    shape = shape[1:]
                for dim, ref_dim in zip(data_shape, shape):
                    if ref_dim != dim and ref_dim is not None and dim is not None:
                        raise ValueError('Error when checking ' +
                                         exception_prefix + ': expected ' +
                                         names[i] + ' to have shape ' +
                                         str(shape) +
                                         ' but got array with shape ' +
                                         str(data_shape))
    return data
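A compact sketch, assuming `standardize_input_data` above is in scope; a dict keyed by input name is normalized to a list ordered like `names`, and shapes are checked while ignoring the batch axis:

import numpy as np

data = {'img': np.zeros((32, 8)), 'meta': np.zeros((32, 3))}
arrays = standardize_input_data(
    data,
    names=['img', 'meta'],
    shapes=[(None, 8), (None, 3)],
    check_batch_axis=False,
    exception_prefix='input')
# arrays == [data['img'], data['meta']], in the order given by names.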
Example #45
def sorted_(iterable, key=UNSPECIFIED, reverse=UNSPECIFIED):
    if tensor_util.is_tensor(iterable):
        return _tf_sorted(iterable, key, reverse)
    return _py_sorted(iterable, key, reverse)
Example #46
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names,
             opts):
    """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  The state is represented by the variables geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type.
      An additional loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
    if tensor_util.is_tensor(iter_):
        if tensors.is_range_tensor(iter_):
            _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                               symbol_names, opts)
        elif isinstance(iter_, ragged_tensor.RaggedTensor):
            _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                                symbol_names, opts)
        else:
            _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                   set_state, symbol_names, opts)

    elif isinstance(iter_, dataset_ops.DatasetV2):
        _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                             symbol_names, opts)

    elif isinstance(iter_, iterator_ops.OwnedIterator):
        _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                              symbol_names, opts)

    elif isinstance(iter_, ragged_tensor.RaggedTensor):
        _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                            symbol_names, opts)

    elif isinstance(iter_, distribute.Iterator):
        _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                              symbol_names, opts)

    elif isinstance(iter_, distribute.Iterable):
        # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
        _tf_distributed_iterable_for_stmt(iter_, extra_test, body, get_state,
                                          set_state, symbol_names, opts)

    else:
        _py_for_stmt(iter_, extra_test, body, None, None)
Example #47
def cast_single_tensor(x):
    if tensor_util.is_tensor(x) and x.dtype.is_floating:
        return math_ops.cast(x, dtype=K.floatx())
    return x
Example #48
def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
    """Overload of while_stmt that stages a TF while_stmt."""
    init_vars = get_state()
    orig_init_vars = init_vars

    nulls = tuple(_is_none_or_undef(v) for v in init_vars)
    if any(nulls):
        require_one_iteration, init_vars = _try_handling_undefineds(
            body, get_state, set_state, init_vars, nulls, symbol_names)
    else:
        require_one_iteration = False

    def aug_test(*loop_vars):
        if require_one_iteration:
            loop_vars = loop_vars[1:]

        set_state(loop_vars)
        return _verify_tf_condition(test(), 'while loop')

    def aug_body(*loop_vars):
        if require_one_iteration:
            loop_vars = loop_vars[1:]

        set_state(loop_vars)
        body()
        new_loop_vars = get_state()
        _verify_tf_loop_vars(init_vars, loop_vars, new_loop_vars, symbol_names,
                             opts)

        if require_one_iteration:
            new_loop_vars = (True, ) + new_loop_vars

        return new_loop_vars

    if 'shape_invariants' in opts:
        opts['shape_invariants'] = (
            _shape_invariants_mapping_to_positional_list(
                opts['shape_invariants'], init_vars))

    while_loop_opts = dict(opts)
    while_loop_opts.pop('iterate_names', None)

    # Non-v2 while_loop unpacks the results when there is only one return value.
    # This enforces consistency across versions.
    while_loop_opts['return_same_structure'] = True

    if require_one_iteration:
        aug_init_vars = (False, ) + init_vars
    else:
        aug_init_vars = init_vars

    final_loop_vars = control_flow_ops.while_loop(aug_test, aug_body,
                                                  aug_init_vars,
                                                  **while_loop_opts)

    if require_one_iteration:
        with ops.control_dependencies([
                control_flow_ops.Assert(final_loop_vars[0], [
                    _runtime_zero_iterations_errmsg(symbol_names, nulls,
                                                    orig_init_vars)
                ])
        ]):
            final_loop_vars = nest.map_structure(
                lambda v: (array_ops.identity(v)
                           if tensor_util.is_tensor(v) else v),
                final_loop_vars[1:],
            )

    set_state(final_loop_vars)
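As a rough illustration of what this staging achieves, the sketch below (assuming TF 2.x with AutoGraph enabled) writes a Python `while` loop over tensor state inside a `tf.function`; AutoGraph rewrites it into a `while_stmt` call, which for tensor conditions is staged through `_tf_while_stmt` into a single `tf.while_loop`:

```python
import tensorflow as tf

@tf.function
def count_down(n):
    total = tf.constant(0)
    # Because `n` is a tensor, this Python loop is staged as a TF while loop
    # rather than unrolled in Python.
    while n > 0:
        total += n
        n -= 1
    return total

print(count_down(tf.constant(4)))  # tf.Tensor(10, shape=(), dtype=int32)
```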
Example #49
    def __init__(self,
                 shift=None,
                 scale_identity_multiplier=None,
                 scale_diag=None,
                 scale_tril=None,
                 scale_perturb_factor=None,
                 scale_perturb_diag=None,
                 adjoint=False,
                 validate_args=False,
                 name="affine",
                 dtype=None):
        """Instantiates the `Affine` bijector.

    This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
    giving the forward operation:

    ```none
    Y = g(X) = scale @ X + shift
    ```

    where the `scale` term is logically equivalent to:

    ```python
    scale = (
      scale_identity_multiplier * tf.diag(tf.ones(d)) +
      tf.diag(scale_diag) +
      scale_tril +
      scale_perturb_factor @ diag(scale_perturb_diag) @
        tf.transpose(scale_perturb_factor)
    )
    ```

    If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are
    specified then `scale += IdentityMatrix`. Otherwise specifying a
    `scale` argument has the semantics of `scale += Expand(arg)`, i.e.,
    `scale_diag != None` means `scale += tf.diag(scale_diag)`.

    Args:
      shift: Floating-point `Tensor`. If this is set to `None`, no shift is
        applied.
      scale_identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
        When `scale_identity_multiplier = scale_diag = scale_tril = None` then
        `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
        to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape `[N1, N2, ...  k]`, which represents a k x k
        diagonal matrix.
        When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape `[N1, N2, ...  k, k]`, which represents a
        k x k lower triangular matrix.
        When `None` no `scale_tril` term is added to `scale`.
        The upper triangular elements above the diagonal are ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape `[N1, N2, ...  r]`, which
        represents an `r x r` diagonal matrix. When `None` low rank updates will
        take the form `scale_perturb_factor * scale_perturb_factor.T`.
      adjoint: Python `bool` indicating whether to use the `scale` matrix as
        specified or its adjoint.
        Default value: `False`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
      dtype: `tf.DType` to prefer when converting args to `Tensor`s. Else, we
        fall back to a common dtype inferred from the args, finally falling back
        to float32.

    Raises:
      ValueError: if `perturb_diag` is specified but not `perturb_factor`.
      TypeError: if `shift` has different `dtype` from `scale` arguments.
    """
        self._graph_parents = []
        self._name = name
        self._validate_args = validate_args

        # Ambiguous definition of low rank update.
        if scale_perturb_diag is not None and scale_perturb_factor is None:
            raise ValueError("When scale_perturb_diag is specified, "
                             "scale_perturb_factor must be specified.")

        # Special case, only handling a scaled identity matrix. We don't know its
        # dimensions, so this is special cased.
        # We don't check identity_multiplier, since below we set it to 1. if all
        # other scale args are None.
        self._is_only_identity_multiplier = (scale_tril is None
                                             and scale_diag is None
                                             and scale_perturb_factor is None)

        with self._name_scope("init",
                              values=[
                                  shift, scale_identity_multiplier, scale_diag,
                                  scale_tril, scale_perturb_diag,
                                  scale_perturb_factor
                              ]):

            if dtype is None:
                dtype = dtype_util.common_dtype([
                    shift, scale_identity_multiplier, scale_diag, scale_tril,
                    scale_perturb_diag, scale_perturb_factor
                ], tf.float32)

            if shift is not None:
                shift = tf.convert_to_tensor(shift, name="shift", dtype=dtype)
            self._shift = shift

            # When no args are specified, pretend the scale matrix is the identity
            # matrix.
            if (self._is_only_identity_multiplier
                    and scale_identity_multiplier is None):
                scale_identity_multiplier = tf.convert_to_tensor(1.,
                                                                 dtype=dtype)

            # self._create_scale_operator returns a LinearOperator in all cases
            # except if self._is_only_identity_multiplier; in which case it
            # returns a scalar Tensor.
            scale = self._create_scale_operator(
                identity_multiplier=scale_identity_multiplier,
                diag=scale_diag,
                tril=scale_tril,
                perturb_diag=scale_perturb_diag,
                perturb_factor=scale_perturb_factor,
                shift=shift,
                validate_args=validate_args,
                dtype=dtype)

            if scale is not None and not self._is_only_identity_multiplier:
                if (shift is not None
                        and shift.dtype.base_dtype != scale.dtype.base_dtype):
                    raise TypeError(
                        "shift.dtype({}) is incompatible with scale.dtype({})."
                        .format(shift.dtype, scale.dtype))

            self._scale = scale
            self._adjoint = adjoint
            super(Affine, self).__init__(
                forward_min_event_ndims=1,
                graph_parents=(
                    [self._scale] if tensor_util.is_tensor(
                        self._scale) else self._scale.graph_parents +
                    [self._shift] if self._shift is not None else []),
                is_constant_jacobian=True,
                dtype=dtype,
                validate_args=validate_args,
                name=name)
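To make the `scale` composition from the docstring concrete, here is a NumPy-only worked example (all values are made up; `d = 2` with a rank-1 perturbation):

```python
import numpy as np

d = 2
scale_identity_multiplier = 2.0
scale_diag = np.array([0.5, 1.5])
scale_tril = np.array([[0.0, 0.0],
                       [1.0, 0.0]])
scale_perturb_factor = np.array([[1.0],
                                 [2.0]])   # shape (d, r) with r = 1
scale_perturb_diag = np.array([3.0])

# Mirrors: scale = c * I + diag(D) + L + V @ diag(M) @ V^T
scale = (scale_identity_multiplier * np.eye(d) +
         np.diag(scale_diag) +
         scale_tril +
         scale_perturb_factor @ np.diag(scale_perturb_diag) @
         scale_perturb_factor.T)
print(scale)
# [[ 5.5  6. ]
#  [ 7.  15.5]]
```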
Example #50
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
    r"""Formats a string template using a list of tensors.

  Formats a string template using a list of tensors, abbreviating tensors by
  only printing the first and last `summarize` elements of each dimension
  (recursively). If formatting only one tensor into a template, the tensor does
  not have to be wrapped in a list.

  Example:
    Formatting a single-tensor template:
    ```python
    sess = tf.compat.v1.Session()
    with sess.as_default():
        tensor = tf.range(10)
        formatted = tf.strings.format("tensor: {}, suffix", tensor)
        out = sess.run(formatted)
        expected = "tensor: [0 1 2 ... 7 8 9], suffix"

        assert(out.decode() == expected)
    ```

    Formatting a multi-tensor template:
    ```python
    sess = tf.compat.v1.Session()
    with sess.as_default():
        tensor_one = tf.reshape(tf.range(100), [10, 10])
        tensor_two = tf.range(10)
        formatted = tf.strings.format("first: {}, second: {}, suffix",
          (tensor_one, tensor_two))

        out = sess.run(formatted)
        expected = ("first: [[0 1 2 ... 7 8 9]\n"
              " [10 11 12 ... 17 18 19]\n"
              " [20 21 22 ... 27 28 29]\n"
              " ...\n"
              " [70 71 72 ... 77 78 79]\n"
              " [80 81 82 ... 87 88 89]\n"
              " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")

        assert(out.decode() == expected)
    ```

  Args:
    template: A string template to format tensor values into.
    inputs: A list of `Tensor` objects, or a single Tensor.
      The list of tensors to format into the template string. If a solitary
      tensor is passed in, the input tensor will automatically be wrapped as a
      list.
    placeholder: An optional `string`. Defaults to `{}`.
      At each placeholder occurring in the template, a subsequent tensor
      will be inserted.
    summarize: An optional `int`. Defaults to `3`.
      When formatting the tensors, show the first and last `summarize`
      entries of each tensor dimension (recursively). If set to -1, all
      elements of the tensor will be shown.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`.

  Raises:
    ValueError: if the number of placeholders does not match the number of
      inputs.
  """
    # If there is only one tensor to format, we will automatically wrap it in a
    # list to simplify the user experience
    if tensor_util.is_tensor(inputs):
        inputs = [inputs]
    if template.count(placeholder) != len(inputs):
        raise ValueError(
            "%s placeholder(s) in template does not match %s tensor(s)"
            " provided as input" % (template.count(placeholder), len(inputs)))

    return gen_string_ops.string_format(inputs,
                                        template=template,
                                        placeholder=placeholder,
                                        summarize=summarize,
                                        name=name)
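The docstring examples above use a TF 1.x `Session`; under TF 2.x eager execution the same call (via the public `tf.strings.format` wrapper for this function) runs immediately:

```python
import tensorflow as tf

tensor = tf.range(10)
formatted = tf.strings.format("tensor: {}, suffix", tensor)
print(formatted.numpy().decode())  # tensor: [0 1 2 ... 7 8 9], suffix
```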
Example #51
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names,
             opts):
    """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means or some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  The state is represented by the variables geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return type.
      An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
    if tensor_util.is_tensor(iter_):
        if tensors.is_range_tensor(iter_):
            _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                               symbol_names, opts)
        else:
            _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                   set_state, symbol_names, opts)

    elif isinstance(iter_, dataset_ops.DatasetV2):
        _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                             symbol_names, opts)

    elif isinstance(iter_, iterator_ops.OwnedIterator):
        _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                              symbol_names, opts)

    elif isinstance(iter_, ragged_tensor.RaggedTensor):
        _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                            symbol_names, opts)

    elif isinstance(iter_, input_lib.DistributedIterator):
        raise NotImplementedError(
            'distributed iterators not supported yet, use the distributed dataset'
            ' directly')

    # TODO(mdan): Resolve the private access issue.
    elif isinstance(iter_, input_lib._IterableInput):  # pylint:disable=protected-access
        _tf_distributed_iterable_for_stmt(iter_, extra_test, body, get_state,
                                          set_state, symbol_names, opts)

    else:
        _py_for_stmt(iter_, extra_test, body, None, None)
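The sketch below (assuming TF 2.x) rewrites the docstring's geometric/arithmetic-mean loop over tensors inside a `tf.function`; AutoGraph turns the `for` statement into a call to `for_stmt`, and because `tf.range(n)` is a range tensor the `_tf_range_for_stmt` branch above is taken:

```python
import tensorflow as tf

@tf.function
def means(numbers):
    geo_mean = tf.constant(1.0)
    arith_mean = tf.constant(0.0)
    n = tf.shape(numbers)[0]
    for i in tf.range(n):          # staged via for_stmt / _tf_range_for_stmt
        a = numbers[i]
        geo_mean *= a
        arith_mean += a
    n_float = tf.cast(n, tf.float32)
    return geo_mean ** (1.0 / n_float), arith_mean / n_float

print(means(tf.constant([2.0, 8.0])))  # (4.0, 5.0)
```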
Example #52
def print_v2(*inputs, **kwargs):
    """Print the specified inputs.

  A TensorFlow operator that prints the specified inputs to a desired
  output stream or logging level. The inputs may be dense or sparse Tensors,
  primitive python objects, data structures that contain tensors, and printable
  Python objects. Printed tensors will recursively show the first and last
  elements of each dimension to summarize.

  Example:
    Single-input usage:

    ```python
    tensor = tf.range(10)
    tf.print(tensor, output_stream=sys.stderr)
    ```

    (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)

    Multi-input usage:

    ```python
    tensor = tf.range(10)
    tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
    ```

    (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
    sys.stdout)

    Changing the input separator:
    ```python
    tensor_a = tf.range(2)
    tensor_b = tensor_a * 2
    tf.print(tensor_a, tensor_b, output_stream=sys.stderr, sep=',')
    ```

    (This prints "[0 1],[0 2]" to sys.stderr)

    Usage in a `tf.function`:

    ```python
    @tf.function
    def f():
        tensor = tf.range(10)
        tf.print(tensor, output_stream=sys.stderr)
        return tensor

    range_tensor = f()
    ```

    (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)

  @compatibility(TF 1.x Graphs and Sessions)
  In graphs manually created outside of `tf.function`, this method returns
  the created TF operator that prints the data. To make sure the
  operator runs, users need to pass the produced op to
  `tf.compat.v1.Session`'s run method, or to use the op as a control
  dependency for executed ops by specifying
  `with tf.compat.v1.control_dependencies([print_op])`.
  @end_compatibility

    Compatibility usage in TF 1.x graphs:

    ```python
    sess = tf.compat.v1.Session()
    with sess.as_default():
        tensor = tf.range(10)
        print_op = tf.print("tensors:", tensor, {2: tensor * 2},
                            output_stream=sys.stdout)
        with tf.control_dependencies([print_op]):
          tripled_tensor = tensor * 3
        sess.run(tripled_tensor)
    ```

    (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
    sys.stdout)

  Note: In Jupyter notebooks and colabs, `tf.print` prints to the notebook
    cell outputs. It will not write to the notebook kernel's console logs.

  Args:
    *inputs: Positional arguments that are the inputs to print. Inputs in the
      printed output will be separated by spaces. Inputs may be python
      primitives, tensors, data structures such as dicts and lists that may
      contain tensors (with the data structures possibly nested in arbitrary
      ways), and printable python objects.
    output_stream: The output stream, logging level, or file to print to.
      Defaults to sys.stderr, but sys.stdout, tf.compat.v1.logging.info,
      tf.compat.v1.logging.warning, tf.compat.v1.logging.error,
      absl.logging.info, absl.logging.warning and absl.logging.error are also
      supported. To print to a file, pass a string started with "file://"
      followed by the file path, e.g., "file:///tmp/foo.out".
    summarize: The first and last `summarize` elements within each dimension are
      recursively printed per Tensor. If None, then the first 3 and last 3
      elements of each dimension are printed for each tensor. If set to -1, it
      will print all elements of every tensor.
    sep: The string to use to separate the inputs. Defaults to " ".
    end: End character that is appended at the end the printed string.
      Defaults to the newline character.
    name: A name for the operation (optional).

  Returns:
    None when executing eagerly. During graph tracing this returns
    a TF operator that prints the specified inputs in the specified output
    stream or logging level. This operator will be automatically executed
    except inside of `tf.compat.v1` graphs and sessions.

  Raises:
    ValueError: If an unsupported output stream is specified.
  """
    # Because we are using arbitrary-length positional arguments, python 2
    # does not support explicitly specifying the keyword arguments in the
    # function definition. So, we manually get the keyword arguments w/ default
    # values here.
    output_stream = kwargs.pop("output_stream", sys.stderr)
    name = kwargs.pop("name", None)
    summarize = kwargs.pop("summarize", 3)
    sep = kwargs.pop("sep", " ")
    end = kwargs.pop("end", os.linesep)
    if kwargs:
        raise ValueError("Unrecognized keyword arguments for tf.print: %s" %
                         kwargs)
    format_name = None
    if name:
        format_name = name + "_format"

    # Match the C++ string constants representing the different output streams.
    # Keep this updated!
    output_stream_to_constant = {
        sys.stdout: "stdout",
        sys.stderr: "stderr",
        tf_logging.INFO: "log(info)",
        tf_logging.info: "log(info)",
        tf_logging.WARN: "log(warning)",
        tf_logging.warning: "log(warning)",
        tf_logging.warn: "log(warning)",
        tf_logging.ERROR: "log(error)",
        tf_logging.error: "log(error)",
        logging.INFO: "log(info)",
        logging.info: "log(info)",
        logging.INFO: "log(info)",
        logging.WARNING: "log(warning)",
        logging.WARN: "log(warning)",
        logging.warning: "log(warning)",
        logging.warn: "log(warning)",
        logging.ERROR: "log(error)",
        logging.error: "log(error)",
    }

    if _is_filepath(output_stream):
        output_stream_string = output_stream
    else:
        output_stream_string = output_stream_to_constant.get(output_stream)
        if not output_stream_string:
            raise ValueError(
                "Unsupported output stream, logging level, or file." +
                str(output_stream) + ". Supported streams are sys.stdout, "
                "sys.stderr, tf.logging.info, "
                "tf.logging.warning, tf.logging.error. " +
                "File needs to be in the form of 'file://<filepath>'.")

    # If we are only printing a single string scalar, there is no need to format
    if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
            and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
            and (inputs[0].shape.ndims == 0)
            and (inputs[0].dtype == dtypes.string)):
        formatted_string = inputs[0]
    # Otherwise, we construct an appropriate template for the tensors we are
    # printing, and format the template using those tensors.
    else:
        # For each input to this print function, we extract any nested tensors,
        # and construct an appropriate template to format representing the
        # printed input.
        templates = []
        tensors = []
        tensor_free_structure = nest.map_structure(
            lambda x: "" if tensor_util.is_tensor(x) else x, inputs)
        tensor_free_template = " ".join(
            pprint.pformat(x) for x in tensor_free_structure)
        placeholder = _generate_placeholder_string(tensor_free_template)

        for input_ in inputs:
            placeholders = []
            # Use the nest utilities to flatten & process any nested elements in this
            # input. The placeholder for a tensor in the template should be the
            # placeholder string, and the placeholder for a non-tensor can just be
            # the printed value of the non-tensor itself.
            for x in nest.flatten(input_):
                # support sparse tensors
                if isinstance(x, sparse_tensor.SparseTensor):
                    tensors.extend([x.indices, x.values, x.dense_shape])
                    placeholders.append(
                        "SparseTensor(indices={}, values={}, shape={})".format(
                            placeholder, placeholder, placeholder))
                elif tensor_util.is_tensor(x):
                    tensors.append(x)
                    placeholders.append(placeholder)
                else:
                    placeholders.append(x)

            if isinstance(input_, six.string_types):
                # If the current input to format/print is a normal string, that string
                # can act as the template.
                cur_template = input_
            else:
                # We pack the placeholders into a data structure that matches the
                # input data structure format, then format that data structure
                # into a string template.
                #
                # NOTE: We must use pprint.pformat here for building the template for
                # unordered data structures such as `dict`, because `str` doesn't
                # guarantee orderings, while pprint prints in sorted order. pprint
                # will match the ordering of `nest.flatten`.
                # This even works when nest.flatten reorders OrderedDicts, because
                # pprint is printing *after* the OrderedDicts have been reordered.
                cur_template = pprint.pformat(
                    nest.pack_sequence_as(input_, placeholders))
            templates.append(cur_template)

        # We join the templates for the various inputs into a single larger
        # template. We also remove all quotes surrounding the placeholders, so that
        # the formatted/printed output will not contain quotes around tensors.
        # (example of where these quotes might appear: if we have added a
        # placeholder string into a list, then pretty-formatted that list)
        template = sep.join(templates)
        template = template.replace("'" + placeholder + "'", placeholder)
        formatted_string = string_ops.string_format(inputs=tensors,
                                                    template=template,
                                                    placeholder=placeholder,
                                                    summarize=summarize,
                                                    name=format_name)

    return gen_logging_ops.print_v2(formatted_string,
                                    output_stream=output_stream_string,
                                    name=name,
                                    end=end)
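As a small complement to the docstring examples, the sketch below (TF 2.x eager mode; the file path is only an illustration) shows the `"file://"` form of `output_stream` described in the Args section:

```python
import sys
import tensorflow as tf

tensor = tf.range(10)
tf.print("to stderr:", tensor, output_stream=sys.stderr)
tf.print("to a file:", tensor, output_stream="file:///tmp/tf_print_demo.out")
```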
Example #53
    def __init__(self,
                 filenames,
                 record_defaults,
                 compression_type=None,
                 buffer_size=None,
                 header=False,
                 field_delim=",",
                 use_quote_delim=True,
                 na_value="",
                 select_cols=None):
        """Creates a `CsvDataset` by reading and decoding CSV files.

    The elements of this dataset correspond to records from the file(s).
    RFC 4180 format is expected for CSV files
    (https://tools.ietf.org/html/rfc4180)
    Note that we allow leading and trailing spaces with int or float fields.


    For example, suppose we have a file 'my_file0.csv' with four CSV columns of
    different data types:
    ```
    abcdefg,4.28E10,5.55E6,12
    hijklmn,-5.3E14,,2
    ```

    We can construct a CsvDataset from it as follows:

    ```python
    tf.compat.v1.enable_eager_execution()

    dataset = tf.data.experimental.CsvDataset(
        "my_file*.csv",
        [tf.float32,  # Required field, use dtype or empty tensor
         tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
         tf.int32,  # Required field, use dtype or empty tensor
         ],
        select_cols=[1,2,3]  # Only parse last three columns
    )
    ```

    The expected output of its iterations is:

    ```python
    for element in dataset:
      print(element)

    >> (4.28e10, 5.55e6, 12)
    >> (-5.3e14, 0.0, 2)
    ```

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_columns` are specified, these must have the same
        lengths, and `column_defaults` is assumed to be sorted in order of
        increasing column index.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
        double quotation marks as regular characters inside of string fields
        (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will
        be treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns.
    """
        self._filenames = ops.convert_to_tensor(filenames,
                                                dtype=dtypes.string,
                                                name="filenames")
        self._compression_type = convert.optional_param_to_tensor(
            "compression_type",
            compression_type,
            argument_default="",
            argument_dtype=dtypes.string)
        record_defaults = [
            constant_op.constant([], dtype=x) if not tensor_util.is_tensor(x)
            and x in _ACCEPTABLE_CSV_TYPES else x for x in record_defaults
        ]
        self._record_defaults = ops.convert_n_to_tensor(record_defaults,
                                                        name="record_defaults")
        self._buffer_size = convert.optional_param_to_tensor(
            "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
        self._header = ops.convert_to_tensor(header,
                                             dtype=dtypes.bool,
                                             name="header")
        self._field_delim = ops.convert_to_tensor(field_delim,
                                                  dtype=dtypes.string,
                                                  name="field_delim")
        self._use_quote_delim = ops.convert_to_tensor(use_quote_delim,
                                                      dtype=dtypes.bool,
                                                      name="use_quote_delim")
        self._na_value = ops.convert_to_tensor(na_value,
                                               dtype=dtypes.string,
                                               name="na_value")
        self._select_cols = convert.optional_param_to_tensor(
            "select_cols",
            select_cols,
            argument_default=[],
            argument_dtype=dtypes.int64,
        )
        self._element_spec = tuple(
            tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults)
        variant_tensor = gen_experimental_dataset_ops.csv_dataset(
            filenames=self._filenames,
            record_defaults=self._record_defaults,
            buffer_size=self._buffer_size,
            header=self._header,
            output_shapes=self._flat_shapes,
            field_delim=self._field_delim,
            use_quote_delim=self._use_quote_delim,
            na_value=self._na_value,
            select_cols=self._select_cols,
            compression_type=self._compression_type)
        super(CsvDatasetV2, self).__init__(variant_tensor)
Example #54
def is_tensor_or_variable(x):
    return tensor_util.is_tensor(x) or isinstance(x, variables.Variable)
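A quick check of the helper above, rewritten against the public API (`tf.is_tensor` and `tf.Variable` stand in for the internal `tensor_util.is_tensor` and `variables.Variable`):

```python
import tensorflow as tf

def is_tensor_or_variable(x):
    return tf.is_tensor(x) or isinstance(x, tf.Variable)

print(is_tensor_or_variable(tf.constant(1.0)))  # True
print(is_tensor_or_variable(tf.Variable(2.0)))  # True
print(is_tensor_or_variable(3.0))               # False
```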
Example #55
def make_csv_dataset_v2(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=None,
    num_parallel_reads=None,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
):
    """Reads CSV files into a dataset.

  Reads CSV files into a dataset, where each element is a (features, labels)
  tuple that corresponds to a batch of CSV rows. The features dictionary
  maps feature column names to `Tensor`s containing the corresponding
  feature data, and labels is a `Tensor` containing the batch's label data.

  Args:
    file_pattern: List of files or patterns of file paths containing CSV
      records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    column_names: An optional list of strings that corresponds to the CSV
      columns, in order. One per column of the input record. If this is not
      provided, infers the column names from the first row of the records.
      These names will be the keys of the features dict of each dataset element.
    column_defaults: An optional list of default values for the CSV fields. One
      item per selected column of the input record. Each item in the list is
      either a valid CSV dtype (float32, float64, int32, int64, or string), or a
      `Tensor` with one of the aforementioned types. The tensor can either be
      a scalar default value (if the column is optional), or an empty tensor (if
      the column is required). If a dtype is provided instead of a tensor, the
      column is also treated as required. If this list is not provided, tries
      to infer types based on reading the first num_rows_for_inference rows of
      files specified, and assumes all columns are optional, defaulting to `0`
      for numeric values and `""` for string values. If both this and
      `select_columns` are specified, these must have the same lengths, and
      `column_defaults` is assumed to be sorted in order of increasing column
      index.
    label_name: An optional string corresponding to the label column. If
      provided, the data for this column is returned as a separate `Tensor` from
      the features dictionary, so that the dataset complies with the format
      expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
      function.
    select_columns: An optional list of integer indices or string column
      names, that specifies a subset of columns of CSV data to select. If
      column names are provided, these must correspond to names provided in
      `column_names` or inferred from the file header lines. When this argument
      is specified, only a subset of CSV columns will be parsed and returned,
      corresponding to the columns specified. Using this results in faster
      parsing and lower memory usage. If both this and `column_defaults` are
      specified, these must have the same lengths, and `column_defaults` is
      assumed to be sorted in order of increasing column index.
    field_delim: An optional `string`. Defaults to `","`. Char delimiter to
      separate fields in a record.
    use_quote_delim: An optional bool. Defaults to `True`. If false, treats
      double quotation marks as regular characters inside of the string fields.
    na_value: Additional string to recognize as NA/NaN.
    header: A bool that indicates whether the first rows of provided CSV files
      correspond to header lines with column names, and should not be included
      in the data.
    num_epochs: An int specifying the number of times this dataset is repeated.
      If None, cycles through the dataset forever.
    shuffle: A bool that indicates whether the input should be shuffled.
    shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
      ensures better shuffling, but increases memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: An int specifying the number of feature
      batches to prefetch for performance improvement. Recommended value is the
      number of batches consumed per training step. Defaults to auto-tune.
    num_parallel_reads: Number of threads used to read CSV records from files.
      If >1, the results will be interleaved. Defaults to `1`.
    sloppy: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    num_rows_for_inference: Number of rows of a file to use for type inference
      if record_defaults is not provided. If None, reads all the rows of all
      the files. Defaults to 100.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
    ignore_errors: (Optional.) If `True`, ignores errors with CSV file parsing,
      such as malformed data or empty lines, and moves on to the next valid
      CSV record. Otherwise, the dataset raises an error and stops processing
      when encountering any invalid records. Defaults to `False`.

  Returns:
    A dataset, where each element is a (features, labels) tuple that corresponds
    to a batch of `batch_size` CSV rows. The features dictionary maps feature
    column names to `Tensor`s containing the corresponding column data, and
    labels is a `Tensor` containing the column data for the label column
    specified by `label_name`.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
    if num_parallel_reads is None:
        num_parallel_reads = 1

    if prefetch_buffer_size is None:
        prefetch_buffer_size = dataset_ops.AUTOTUNE

    # Create dataset of all matching filenames
    filenames = _get_file_names(file_pattern, False)
    dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
    if shuffle:
        dataset = dataset.shuffle(len(filenames), shuffle_seed)

    # Clean arguments; figure out column names and defaults
    if column_names is None or column_defaults is None:
        # Find out which io function to open the file
        file_io_fn = lambda filename: file_io.FileIO(filename, "r")
        if compression_type is not None:
            compression_type_value = tensor_util.constant_value(
                compression_type)
            if compression_type_value is None:
                raise ValueError("Received unkown compression_type")
            if compression_type_value == "GZIP":
                file_io_fn = lambda filename: gzip.open(filename, "rt")
            elif compression_type_value == "ZLIB":
                raise ValueError(
                    "compression_type (%s) is not supported for probing columns"
                    % compression_type)
            elif compression_type_value != "":
                raise ValueError("compression_type (%s) is not supported" %
                                 compression_type)
    if column_names is None:
        if not header:
            raise ValueError(
                "Cannot infer column names without a header line.")
        # If column names are not provided, infer from the header lines
        column_names = _infer_column_names(filenames, field_delim,
                                           use_quote_delim, file_io_fn)
    if len(column_names) != len(set(column_names)):
        raise ValueError("Cannot have duplicate column names.")

    if select_columns is not None:
        select_columns = _get_sorted_col_indices(select_columns, column_names)

    if column_defaults is not None:
        column_defaults = [
            constant_op.constant([], dtype=x) if not tensor_util.is_tensor(x)
            and x in _ACCEPTABLE_CSV_TYPES else x for x in column_defaults
        ]
    else:
        # If column defaults are not provided, infer from records at graph
        # construction time
        column_defaults = _infer_column_defaults(filenames, len(column_names),
                                                 field_delim, use_quote_delim,
                                                 na_value, header,
                                                 num_rows_for_inference,
                                                 select_columns, file_io_fn)

    if select_columns is not None and len(column_defaults) != len(
            select_columns):
        raise ValueError(
            "If specified, column_defaults and select_columns must have same "
            "length.")
    if select_columns is not None and len(column_names) > len(select_columns):
        # Pick the relevant subset of column names
        column_names = [column_names[i] for i in select_columns]

    if label_name is not None and label_name not in column_names:
        raise ValueError("`label_name` provided must be one of the columns.")

    def filename_to_dataset(filename):
        dataset = CsvDataset(filename,
                             record_defaults=column_defaults,
                             field_delim=field_delim,
                             use_quote_delim=use_quote_delim,
                             na_value=na_value,
                             select_cols=select_columns,
                             header=header,
                             compression_type=compression_type)
        if ignore_errors:
            dataset = dataset.apply(error_ops.ignore_errors())
        return dataset

    def map_fn(*columns):
        """Organizes columns into a features dictionary.

    Args:
      *columns: list of `Tensor`s corresponding to one csv record.
    Returns:
      An OrderedDict of feature names to values for that particular record. If
      label_name is provided, extracts the label feature to be returned as the
      second element of the tuple.
    """
        features = collections.OrderedDict(zip(column_names, columns))
        if label_name is not None:
            label = features.pop(label_name)
            return features, label
        return features

    if num_parallel_reads == dataset_ops.AUTOTUNE:
        dataset = dataset.interleave(filename_to_dataset,
                                     num_parallel_calls=num_parallel_reads)
        options = dataset_ops.Options()
        options.experimental_deterministic = not sloppy
        dataset = dataset.with_options(options)
    else:
        # Read files sequentially (if num_parallel_reads=1) or in parallel
        dataset = dataset.apply(
            interleave_ops.parallel_interleave(filename_to_dataset,
                                               cycle_length=num_parallel_reads,
                                               sloppy=sloppy))

    dataset = _maybe_shuffle_and_repeat(dataset, num_epochs, shuffle,
                                        shuffle_buffer_size, shuffle_seed)

    # Apply batch before map for perf, because map has high overhead relative
    # to the size of the computation in each map.
    # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
    # improve the shape inference, because it makes the batch dimension static.
    # It is safe to do this because in that case we are repeating the input
    # indefinitely, and all batches will be full-sized.
    dataset = dataset.batch(batch_size=batch_size,
                            drop_remainder=num_epochs is None)
    dataset = dataset_ops.MapDataset(dataset,
                                     map_fn,
                                     use_inter_op_parallelism=False)
    dataset = dataset.prefetch(prefetch_buffer_size)

    return dataset
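A minimal usage sketch of the public wrapper, `tf.data.experimental.make_csv_dataset` (the file pattern and `label_name` below are hypothetical):

```python
import tensorflow as tf

dataset = tf.data.experimental.make_csv_dataset(
    "train-*.csv",          # hypothetical file pattern
    batch_size=32,
    label_name="label",     # hypothetical label column
    num_epochs=1,
    shuffle=True)

for features, labels in dataset.take(1):
    print(sorted(features.keys()), labels.shape)
```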
Example #56
  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
        object's graph that goes from `feeds` to `fetches`.
    """
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    feeds = nest.map_structure(self.graph.as_graph_element, feeds)
    flat_feeds = nest.flatten(feeds)
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = self.graph.internal_captures
    flat_feeds = [f for f in flat_feeds if f not in internal_captures]

    operation_fetches = []
    tensor_fetches = []
    tensor_infos = []

    def _fetch_preprocesing_callback(f):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original fetches structure.

      Args:
        f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
          identifying a Tensor or Operation.

      Returns:
        `f` converted to a Tensor.
      """
      if isinstance(f, ops.Operation):
        operation_fetches.append(f)
        return f
      elif isinstance(f, meta_graph_pb2.TensorInfo):
        tensor_infos.append(f)
        decoded = _get_element_from_tensor_info(f, self._func_graph)
        if tensor_util.is_tensor(decoded):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(f, ops.Tensor):
        tensor_fetches.append(f)
        return f
      else:
        graph_element = self.graph.as_graph_element(f)
        return _fetch_preprocesing_callback(graph_element)

    fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)

    for f in flat_feeds + tensor_fetches + operation_fetches:
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches "
                         "are from this graph (%s). Input %s is from graph %s" %
                         (self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + internal_captures)
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    for external_capture, internal_capture in self.graph.captures.items():
      pruned_graph.captures[external_capture] = lift_map[internal_capture]
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    pruned_graph.inputs.extend(pruned_graph.captures.values())
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tensor(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn
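A hedged sketch of how `prune` is typically reached from the public API (TF 2.x, via `tf.compat.v1.wrap_function`); the tiny function below is only for illustration:

```python
import tensorflow as tf

def f(x):
    y = x * 2.0
    return y + 1.0

wrapped = tf.compat.v1.wrap_function(f, [tf.TensorSpec([], tf.float32)])
# Extract the subgraph from the input placeholder to the final output.
pruned = wrapped.prune(feeds=wrapped.inputs[0],
                       fetches=wrapped.graph.outputs[0])
print(pruned(tf.constant(3.0)))  # tf.Tensor(7.0, shape=(), dtype=float32)
```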
Example #57
def standardize_weights(y,
                        sample_weight=None,
                        class_weight=None,
                        sample_weight_mode=None):
    """Performs sample weight validation and standardization.

  Everything gets normalized to a single sample-wise (or timestep-wise)
  weight array.

  Arguments:
      y: Numpy array of model targets to be weighted.
      sample_weight: User-provided `sample_weight` argument.
      class_weight: User-provided `class_weight` argument.
      sample_weight_mode: One of `None` or `"temporal"`.
          `"temporal"` indicated that we expect 2D weight data
          that will be applied to the last 2 dimensions of
          the targets (i.e. we are weighting timesteps, not samples).

  Returns:
      A numpy array of target weights, one entry per sample to weight.

  Raises:
      ValueError: In case of invalid user-provided arguments.
  """
    # Iterator may return sample_weight as 1-tuple
    if isinstance(sample_weight, tuple):
        sample_weight = sample_weight[0]
    if sample_weight_mode is not None:
        if sample_weight_mode != 'temporal':
            raise ValueError('"sample_weight_mode '
                             'should be None or "temporal". '
                             'Found: ' + str(sample_weight_mode))
        if len(y.shape) < 3:
            raise ValueError('Found a sample_weight array for '
                             'an input with shape ' + str(y.shape) + '. '
                             'Timestep-wise sample weighting (use of '
                             'sample_weight_mode="temporal") is restricted to '
                             'outputs that are at least 3D, i.e. that have '
                             'a time dimension.')
        if sample_weight is not None and len(sample_weight.shape) != 2:
            raise ValueError('Found a sample_weight array with shape ' +
                             str(sample_weight.shape) + '. '
                             'In order to use timestep-wise sample weighting, '
                             'you should pass a 2D sample_weight array.')
    else:
        if sample_weight is not None and len(sample_weight.shape) != 1:
            raise ValueError('Found a sample_weight array with shape ' +
                             str(sample_weight.shape) + '. '
                             'In order to use timestep-wise sample weights, '
                             'you should specify '
                             'sample_weight_mode="temporal" '
                             'in compile(). If you just mean to use '
                             'sample-wise weights, make sure your '
                             'sample_weight array is 1D.')

    if sample_weight is not None:
        if len(sample_weight.shape) > len(y.shape):
            raise ValueError('Found a sample_weight with shape ' +
                             str(sample_weight.shape) + '. '
                             'Expected sample_weight with rank '
                             'less than or equal to ' + str(len(y.shape)))

        if (not tensor_util.is_tensor(sample_weight)
                and y.shape[:sample_weight.ndim] != sample_weight.shape):
            raise ValueError('Found a sample_weight array with shape ' +
                             str(sample_weight.shape) +
                             ' for an input with shape ' + str(y.shape) + '. '
                             'sample_weight cannot be broadcast.')
        return sample_weight
    elif isinstance(class_weight, dict):
        if len(y.shape) > 2:
            raise ValueError('`class_weight` not supported for '
                             '3+ dimensional targets.')
        if y.shape[1] > 1:
            y_classes = np.argmax(y, axis=1)
        elif y.shape[1] == 1:
            y_classes = np.reshape(y, y.shape[0])
        else:
            y_classes = y

        weights = np.asarray(
            [class_weight[cls] for cls in y_classes if cls in class_weight])

        if len(weights) != len(y_classes):
            # subtract the sets to pick all missing classes
            existing_classes = set(y_classes)
            existing_class_weight = set(class_weight.keys())
            raise ValueError(
                '`class_weight` must contain all classes in the data.'
                ' The classes %s exist in the data but not in '
                '`class_weight`.' % (existing_classes - existing_class_weight))
        return weights
    else:
        return None
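A small worked example of the `class_weight` branch above (values are made up): one-hot targets plus a per-class weight dict become a per-sample weight vector.

```python
import numpy as np

y = np.array([[1, 0],
              [0, 1],
              [0, 1]])                  # 3 samples, 2 classes (one-hot)
class_weight = {0: 1.0, 1: 2.5}

weights = standardize_weights(y, class_weight=class_weight)
print(weights)  # [1.  2.5 2.5]
```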
Example #58
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = obj._create_resource()
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif ds_values.is_distributed_variable(obj):
        # Put both the distributed variable and component variable handles in
        # `captured_tensor_node_ids`.
        # Also create a new distributed variable for `object_map` with newly
        # created component variables.
        new_vars = []
        for v in obj.values:
          new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
          object_map[v] = new_variable
          new_vars.append(new_variable)
          resource_map[v.handle] = new_variable.handle
          self.captured_tensor_node_ids[v.handle] = node_id
        object_map[obj] = obj._clone_with_new_values(new_vars)  # pylint: disable=protected-access
        self.captured_tensor_node_ids[obj] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_constant_value)
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Example #59
  def __init__(self,
               graph_parents=None,
               is_constant_jacobian=False,
               validate_args=False,
               dtype=None,
               forward_min_event_ndims=None,
               inverse_min_event_ndims=None,
               name=None):
    """Constructs Bijector.

    A `Bijector` transforms random variables into new random variables.

    Examples:

    ```python
    # Create the Y = g(X) = X transform.
    identity = Identity()

    # Create the Y = g(X) = exp(X) transform.
    exp = Exp()
    ```

    See `Bijector` subclass docstring for more details and specific examples.

    Args:
      graph_parents: Python list of graph prerequisites of this `Bijector`.
      is_constant_jacobian: Python `bool` indicating that the Jacobian matrix is
        not a function of the input.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
        enforced.
      forward_min_event_ndims: Python `integer` indicating the minimum number of
        dimensions `forward` operates on.
      inverse_min_event_ndims: Python `integer` indicating the minimum number of
        dimensions `inverse` operates on. Will be set to
        `forward_min_event_ndims` by default, if no value is provided.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError:  If neither `forward_min_event_ndims` nor
        `inverse_min_event_ndims` is specified, or if either of them is
        negative.
      ValueError:  If a member of `graph_parents` is not a `Tensor`.
    """
    self._graph_parents = graph_parents or []

    if forward_min_event_ndims is None and inverse_min_event_ndims is None:
      raise ValueError("Must specify at least one of `forward_min_event_ndims` "
                       "and `inverse_min_event_ndims`.")
    elif inverse_min_event_ndims is None:
      inverse_min_event_ndims = forward_min_event_ndims
    elif forward_min_event_ndims is None:
      forward_min_event_ndims = inverse_min_event_ndims

    if not isinstance(forward_min_event_ndims, int):
      raise TypeError("Expected forward_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(forward_min_event_ndims).__name__))

    if not isinstance(inverse_min_event_ndims, int):
      raise TypeError("Expected inverse_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(inverse_min_event_ndims).__name__))

    if forward_min_event_ndims < 0:
      raise ValueError("forward_min_event_ndims must be a non-negative "
                       "integer.")
    if inverse_min_event_ndims < 0:
      raise ValueError("inverse_min_event_ndims must be a non-negative "
                       "integer.")

    self._forward_min_event_ndims = forward_min_event_ndims
    self._inverse_min_event_ndims = inverse_min_event_ndims
    self._is_constant_jacobian = is_constant_jacobian
    self._constant_ildj_map = {}
    self._validate_args = validate_args
    self._dtype = dtype
    self._from_y = {}
    self._from_x = {}
    if name:
      self._name = name
    else:
      # We want the default convention to be snake_case rather than CamelCase
      # since `Chain` uses bijector.name as the kwargs dictionary key.
      def camel_to_snake(name):
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
      self._name = camel_to_snake(type(self).__name__.lstrip("_"))

    for i, t in enumerate(self._graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
Example #60
def int_(x=0, base=UNSPECIFIED):
    if tensor_util.is_tensor(x):
        return _tf_int(x, base)
    return _py_int(x, base)
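A hedged sketch of where this overload kicks in (assuming TF 2.x with AutoGraph): calling the built-in `int` on a tensor inside a `tf.function` is dispatched through `int_`, which casts tensor arguments rather than converting them eagerly.

```python
import tensorflow as tf

@tf.function
def to_int(x):
    # AutoGraph routes this built-in call through int_ above; for a tensor
    # argument it becomes a cast to int32.
    return int(x)

print(to_int(tf.constant(3.7)))  # tf.Tensor(3, shape=(), dtype=int32)
```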