def func_body_augmented(iteration, chosen_ids):
    """Performs one selection step: picks one more facility and appends it.

    Candidates are the ids of `all_ids` that are not yet in `chosen_ids`. The
    choice among them is delegated to `_find_loss_augmented_facility_idx`
    (per the original note: based on the clustering score and the NMI score).
    `all_ids`, `pairwise_distances`, `labels`, `margin_multiplier` and
    `margin_type` are read from the surrounding scope, not passed in.

    Args:
      iteration: Current step counter.
      chosen_ids: Ids of the facilities selected so far.

    Returns:
      A pair `(iteration + 1, chosen_ids with the new facility id appended)`.
    """
    # Only ids that have not been selected yet may be picked.
    remaining_ids, _ = array_ops.setdiff1d(all_ids, chosen_ids)
    next_facility = _find_loss_augmented_facility_idx(
        pairwise_distances,
        labels,
        chosen_ids,
        remaining_ids,
        margin_multiplier,
        margin_type,
    )
    extended_ids = array_ops.concat([chosen_ids, [next_facility]], 0)
    return iteration + 1, extended_ids
Пример #2
0
 def func_body_augmented(iteration, chosen_ids):
   """Adds one facility to the chosen set and advances the step counter.

   The next facility is taken from the ids of `all_ids` that are absent from
   `chosen_ids`; `_find_loss_augmented_facility_idx` decides which one (per
   the original note: using the clustering score and the NMI score).
   `all_ids`, `pairwise_distances`, `labels`, `margin_multiplier` and
   `margin_type` come from the enclosing scope.

   Args:
     iteration: Current step counter.
     chosen_ids: Ids of the facilities selected so far.

   Returns:
     A pair `(iteration + 1, chosen_ids with the new facility id appended)`.
   """
   # Ids that are still available for selection.
   unchosen_ids = array_ops.setdiff1d(all_ids, chosen_ids)[0]
   picked_idx = _find_loss_augmented_facility_idx(
       pairwise_distances, labels, chosen_ids, unchosen_ids,
       margin_multiplier, margin_type)
   updated_ids = array_ops.concat([chosen_ids, [picked_idx]], 0)
   return iteration + 1, updated_ids
Пример #3
0
def _ProdGrad(op, grad):
    """Gradient for Prod.

    Args:
      op: The Prod op; `op.inputs[0]` is the reduced tensor and
        `op.inputs[1]` holds the reduction indices.
      grad: Gradient with respect to the output of `op`.

    Returns:
      A pair `(gradient w.r.t. op.inputs[0], None)`; the reduction indices
      receive no gradient.
    """
    # Dividing the total product by each entry would also give the gradient,
    # but that breaks down when the input contains zeros. Instead, the
    # "product of all other entries" is assembled from two exclusive
    # cumulative products: one running forward and one running backward.
    x = op.inputs[0]
    axes_input = op.inputs[1]

    x_shape = array_ops.shape(x)
    # A scalar reduction index is flattened into a one-element vector.
    flat_axes = array_ops.reshape(axes_input, [-1])

    # Broadcast the incoming gradient back to the full input shape.
    kept_dims_shape = math_ops.reduced_shape(x_shape, axes_input)
    multiples = _safe_shape_div(x_shape, kept_dims_shape)
    grad_full = array_ops.tile(
        array_ops.reshape(grad, kept_dims_shape), multiples)

    # Move every reduced dimension to the front and collapse them into one,
    # so a single cumprod along axis 0 covers the whole reduction. Per the
    # original note: an empty reduction-dims list defaults to float32, hence
    # the cast; and the shape arithmetic is pinned to CPU to avoid copying
    # back and forth, since listdiff is CPU only.
    with ops.device("/cpu:0"):
        ndims = array_ops.rank(x)
        # Map negative axes into [0, ndims).
        flat_axes = (flat_axes + ndims) % ndims
        reduced_axes = math_ops.cast(flat_axes, dtypes.int32)
        all_axes = math_ops.range(0, ndims)
        kept_axes, _ = array_ops.setdiff1d(all_axes, reduced_axes)
        perm = array_ops.concat([reduced_axes, kept_axes], 0)
        reduced_size = math_ops.reduce_prod(
            array_ops.gather(x_shape, reduced_axes))
        kept_size = math_ops.reduce_prod(array_ops.gather(x_shape, kept_axes))
    x_permuted = array_ops.transpose(x, perm)
    x_permuted_shape = array_ops.shape(x_permuted)
    x_2d = array_ops.reshape(x_permuted, (reduced_size, kept_size))

    # Product of everything before / after each entry, excluding the entry.
    prefix = math_ops.cumprod(x_2d, axis=0, exclusive=True)
    suffix = math_ops.cumprod(x_2d, axis=0, exclusive=True, reverse=True)
    # For complex inputs, the gradient is in the conjugate direction.
    others_product = math_ops.conj(prefix) * math_ops.conj(suffix)
    others_product = array_ops.reshape(others_product, x_permuted_shape)

    # Undo the transpose; the closing reshape restores the statically known
    # shape information.
    inverse_perm = array_ops.invert_permutation(perm)
    result = grad_full * array_ops.transpose(others_product, inverse_perm)
    return array_ops.reshape(result, x_shape), None
Пример #4
0
def _ProdGrad(op, grad):
  """Gradient for Prod.

  Args:
    op: The Prod op; `op.inputs[0]` is the tensor that was reduced and
      `op.inputs[1]` holds the reduction indices.
    grad: Gradient flowing in from the output of `op`.

  Returns:
    A pair whose first item is the gradient with respect to `op.inputs[0]`
    and whose second item is `None` (no gradient for the indices).
  """
  # The gradient equals "product divided by the entry", but a literal division
  # fails for inputs containing zeros. A forward and a backward exclusive
  # cumprod, multiplied together, give the same quantity without dividing.
  values, indices = op.inputs[0], op.inputs[1]

  values_shape = array_ops.shape(values)
  # Reshape so that a scalar index is handled like a one-element list.
  indices_1d = array_ops.reshape(indices, [-1])

  # Expand grad to the full shape of the input.
  shape_with_kept_dims = math_ops.reduced_shape(values_shape, indices)
  tile_counts = _safe_shape_div(values_shape, shape_with_kept_dims)
  tiled_grad = array_ops.reshape(grad, shape_with_kept_dims)
  tiled_grad = array_ops.tile(tiled_grad, tile_counts)

  # All reduced dimensions are packed into a single leading one so that the
  # cumprod ops can run along axis 0. Per the original note: the cast is
  # needed because an empty reduction-dims list defaults to float32, and the
  # shape-related ops sit on CPU to avoid copying back and forth, since
  # listdiff is CPU only.
  with ops.device("/cpu:0"):
    num_dims = array_ops.rank(values)
    # Turn negative indices into their non-negative equivalents.
    indices_1d = (indices_1d + num_dims) % num_dims
    reduce_dims = math_ops.cast(indices_1d, dtypes.int32)
    dim_range = math_ops.range(0, num_dims)
    keep_dims, _ = array_ops.setdiff1d(dim_range, reduce_dims)
    perm = array_ops.concat([reduce_dims, keep_dims], 0)
    reduce_count = math_ops.reduce_prod(
        array_ops.gather(values_shape, reduce_dims))
    keep_count = math_ops.reduce_prod(
        array_ops.gather(values_shape, keep_dims))
  transposed = array_ops.transpose(values, perm)
  transposed_shape = array_ops.shape(transposed)
  flattened = array_ops.reshape(transposed, (reduce_count, keep_count))

  # Products of all entries strictly before / strictly after each position.
  before = math_ops.cumprod(flattened, axis=0, exclusive=True)
  after = math_ops.cumprod(flattened, axis=0, exclusive=True, reverse=True)
  # For complex inputs, the gradient is in the conjugate direction.
  leave_one_out = array_ops.reshape(
      math_ops.conj(before) * math_ops.conj(after), transposed_shape)

  # Invert the transpose, then reshape so the statically known shape
  # information is set on the result.
  restored = array_ops.transpose(leave_one_out,
                                 array_ops.invert_permutation(perm))
  input_grad = tiled_grad * restored
  return array_ops.reshape(input_grad, values_shape), None
Пример #5
0
def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation.

  Gathers the rows for `ids` from `params` and sums them per segment.

  With a single partition, the rows are gathered, optionally norm-clipped and
  then reduced with `segment_sum` (after multiplying by `weights`) or with
  `sparse_segment_sum` (when `weights` is None).

  With several partitions, `ids` are distributed over the partitions according
  to `partition_strategy`. Each partition gathers and pre-aggregates its own
  rows, colocated with its `params` entry. The per-partition results are then
  concatenated and combined with `unsorted_segment_sum`.

  Args:
    params: A single tensor/variable, a list of them, or a
      `PartitionedVariable`. Must not be None or an empty list.
    ids: Ids of the rows to look up.
    partition_strategy: "mod" or "div". Only consulted when there is more than
      one partition.
    name: Optional name for the name scope and the final op.
    max_norm: If not None, every looked-up row is clipped with
      `clip_ops.clip_by_norm` over all axes except the first.
    weights: Optional per-id weights. When given, gathered rows are multiplied
      by them (cast to the row dtype if needed) before summing.
    idx: Indices passed to `sparse_segment_sum`; only used when `weights` is
      None.
    segment_ids: Segment id for each entry; rows with the same segment id are
      summed together.

  Returns:
    The tensor of per-segment sums.

  Raises:
    ValueError: If `params` is None or empty, or if `partition_strategy` is
      not recognized (the latter only with more than one partition).
  """
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    """Clips `x` to `max_norm` over all axes but the first, if requested."""
    if max_norm is None:
      return x
    static_ndims = x.get_shape().ndims
    if static_ndims is not None:
      axes = list(range(1, static_ndims))
    else:
      # The rank is only known when the graph runs. Python's range() cannot
      # take a tensor bound, so the axes are built as a tensor instead.
      axes = math_ops.range(1, array_ops.rank(x))
    return clip_ops.clip_by_norm(x, max_norm, axes=axes)

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    # Number of partitions. Note: this local shadows any module-level `np`
    # alias for the remainder of the function.
    np = len(params)
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(_do_gather(params[0], ids))
        ignore_weights = weights is None
        if not ignore_weights:
          if weights.dtype != ret.dtype:
            weights = math_ops.cast(weights, ret.dtype)
          # Reshape to allow broadcast: append one size-1 dimension for every
          # trailing dimension of `ret`.
          ones = array_ops.fill(
              array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
          bcast_weights_shape = array_ops.concat(
              [array_ops.shape(weights), ones], 0)
          orig_weights_shape = weights.get_shape()
          weights = array_ops.reshape(weights, bcast_weights_shape)
          # Set weights shape after reshape
          if ret.get_shape().ndims is not None:
            weights.set_shape(
                orig_weights_shape.concatenate(
                    [1 for _ in range(ret.get_shape().ndims - 1)]))
          ret *= weights
          return math_ops.segment_sum(ret, segment_ids, name=name)
        else:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        # The first `extras` partitions hold one more id than the others.
        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                      flat_ids.dtype)
        new_ids = (
            is_in_first_extras_partitions * (flat_ids %
                                             (ids_per_partition + 1)) +
            (1 - is_in_first_extras_partitions) *
            ((flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for gather_ids[p] in
      # params[p].
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(_do_gather(params[p], gather_ids[p]))

      ignore_weights = weights is None
      if not ignore_weights:
        # Partition weights according to pindices.
        partitioned_weight = []
        for p in xrange(np):
          partitioned_weight.append(array_ops.gather(weights, pindices[p]))
      # Reshape each partition result.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([array_ops.shape(pindices[p]), element_shape],
                                 0))
      else:
        # Element shape is not fully known statically: read it at run time
        # from the first partition.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([
                    array_ops.shape(pindices[p]), array_ops.slice(
                        params_shape, [1], [-1])
                ], 0))
      # Normalize each partition result.
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = maybe_normalize(partitioned_result[p])
      if not ignore_weights:
        # Multiply each partition result with partition weights.
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
              partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
                                                    partitioned_result[p].dtype)
            # Reshape partition weights so they broadcast over the trailing
            # dimensions of the partition result.
            ones = array_ops.fill(
                array_ops.expand_dims(
                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
            bcast_weights_shape = array_ops.concat(
                [array_ops.shape(partitioned_weight[p]), ones], 0)
            orig_weights_shape = partitioned_weight[p].get_shape()
            partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
                                                      bcast_weights_shape)
            if partitioned_result[p].get_shape().ndims is not None:
              partitioned_weight[p].set_shape(
                  orig_weights_shape.concatenate([
                      1
                      for _ in range(partitioned_result[p].get_shape().ndims -
                                     1)
                  ]))
            partitioned_result[p] *= partitioned_weight[p]
      partitioned_segment_ids = []
      for p in xrange(np):
        if not ignore_weights:
          # Partition segment_ids according to pindices.
          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
          # Number the p_segment_ids to meet segment_sum's requirements. Note
          # that unique_p_segment_ids contains unique segment ids of this
          # partition and these ids' order is unchanged.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          partitioned_segment_ids.append(unique_p_segment_ids)
          # segment_sum this partition's result.
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.segment_sum(
                partitioned_result[p], unique_p_segment_idx)
        else:
          # Without weights, we need the positions of this partition's
          # elements within idx and segment_ids: take the positions that are
          # NOT in this partition, then complement them.
          _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
          all_idx = math_ops.range(array_ops.shape(idx)[0])
          _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
          # Gather segment_ids and idx according to those positions.
          p_segment_ids = array_ops.gather(segment_ids, include_idx)
          p_idx = array_ops.gather(idx, include_idx)
          # Number the p_segment_ids, same as in the weighted case above.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          _, unique_p_idx_idx = array_ops.unique(p_idx)
          partitioned_segment_ids.append(unique_p_segment_ids)
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.sparse_segment_sum(
                partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
      # Concat each partition's segment_ids and result for final segment_sum.
      concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
      concat_partitioned_result = array_ops.concat(partitioned_result, 0)
      return math_ops.unsorted_segment_sum(
          concat_partitioned_result,
          concat_segment_ids,
          math_ops.reduce_max(concat_segment_ids) + 1,
          name=name)
Пример #6
0
def norm(tensor,
         ord='euclidean',
         axis=None,
         keepdims=None,
         name=None,
         keep_dims=None):
    r"""Computes the norm of vectors, matrices, and tensors.

    Available vector norms: the 1-norm, the Euclidean (2-) norm, the inf-norm
    and the general p-norm for p > 0. Available matrix norms: Frobenius,
    1-norm, 2-norm and inf-norm.

    Args:
      tensor: `Tensor` of types `float32`, `float64`, `complex64`,
        `complex128`.
      ord: Order of the norm. Accepted values are 'fro', 'euclidean', `1`,
        `2`, `np.inf` and any positive real number p, which selects the
        p-norm. The default 'euclidean' is the Frobenius norm when `tensor`
        is treated as a matrix and the 2-norm when it is treated as a vector.
        Restrictions:
          a) 'fro' is rejected for vector norms,
          b) for a matrix norm (2-element `axis`) only 'euclidean', 'fro',
             `1`, `2` and `np.inf` are accepted.
      axis: Selects what the norm is computed over.
        `None` (default): the whole tensor is treated as one flat vector,
        i.e. `norm(tensor, ord=ord)` equals
        `norm(reshape(tensor, [-1]), ord=ord)`.
        A Python integer: the input is a batch of vectors and the norm is
        taken along that axis.
        A 2-tuple (or 2-element list) of Python integers: the input is a
        batch of matrices spanned by those two axes.
        Negative indices are supported. If the tensor may be either a matrix
        or a batch of matrices at runtime, pass `axis=[-2,-1]` rather than
        `axis=None` so that matrix norms are computed.
      keepdims: If True, the reduced axes are kept with size 1. Otherwise
        they are removed from the output shape. `None` is treated as False.
      name: The name of the op.
      keep_dims: Deprecated alias for `keepdims`.

    Returns:
      output: A `Tensor` of the same type as tensor holding the vector or
        matrix norms. With `keepdims` True its rank equals the rank of
        `tensor`. Otherwise it is a scalar for `axis=None`, has rank one less
        than `tensor` for an integer `axis`, and rank two less for a 2-tuple
        `axis`.

    Raises:
      ValueError: If `ord` or `axis` is invalid.

    @compatibility(numpy)
    Mostly equivalent to numpy.linalg.norm.
    Not supported: ord <= 0, nuclear norm.
    Other differences:
      a) If axis is `None`, treats the flattened `tensor` as a vector
       regardless of rank.
      b) Explicitly supports 'euclidean' norm as the default, including for
       higher order tensors.
    @end_compatibility
    """
    keepdims = deprecation.deprecated_argument_lookup('keepdims', keepdims,
                                                      'keep_dims', keep_dims)
    keepdims = False if keepdims is None else keepdims

    axis_error = (
        "'axis' must be None, an integer, or a tuple of 2 unique integers")

    # ---- Argument validation -------------------------------------------
    is_matrix_norm = isinstance(axis, (tuple, list)) and len(axis) == 2
    if is_matrix_norm:
        axis = tuple(axis)
        row_axis, col_axis = axis
        two_distinct_ints = (isinstance(row_axis, int)
                             and isinstance(col_axis, int)
                             and row_axis != col_axis)
        if not two_distinct_ints:
            raise ValueError(axis_error)
        supported_matrix_norms = ['euclidean', 'fro', 1, 2, np.inf]
        if ord not in supported_matrix_norms:
            raise ValueError(
                "'ord' must be a supported matrix norm in %s, got %s" %
                (supported_matrix_norms, ord))
    else:
        if axis is not None and not isinstance(axis, int):
            raise ValueError(axis_error)

        supported_vector_norms = ['euclidean', 1, 2, np.inf]
        not_positive_real = not np.isreal(ord) or ord <= 0
        if not_positive_real and ord not in supported_vector_norms:
            raise ValueError("'ord' must be a supported vector norm, got %s" %
                             ord)
        # From here on `axis` is either None or a tuple.
        if axis is not None:
            axis = (axis, )

    # ---- Computation (always with kept dims; squeezed at the end) -------
    with ops.name_scope(name, 'norm', [tensor]):
        tensor = ops.convert_to_tensor(tensor)

        if ord in ['fro', 'euclidean', 2, 2.0]:
            if is_matrix_norm and ord in [2, 2.0]:
                # Matrix 2-norm: largest singular value. The two matrix axes
                # are moved to the end for the SVD and moved back afterwards.
                rank = array_ops.rank(tensor)

                def _non_negative(i):
                    return control_flow_ops.cond(i >= 0, lambda: i,
                                                 lambda: i + rank)

                positive_axis = map_fn.map_fn(_non_negative,
                                              ops.convert_to_tensor(axis))
                axes = math_ops.range(rank)
                batch_axes, _ = array_ops.setdiff1d(axes, positive_axis)
                perm_before = array_ops.concat([batch_axes, positive_axis],
                                               axis=0)

                def _position_in_perm(i):
                    hits = array_ops.where_v2(math_ops.equal(perm_before, i))
                    return math_ops.cast(array_ops.squeeze(hits),
                                         dtype=dtypes.int32)

                # Inverse permutation of perm_before.
                perm_after = map_fn.map_fn(_position_in_perm, axes)
                permed = array_ops.transpose(tensor, perm=perm_before)
                singular_values = gen_linalg_ops.svd(permed,
                                                     compute_uv=False)[0]
                largest = math_ops.reduce_max(math_ops.abs(singular_values),
                                              axis=-1,
                                              keepdims=True)
                matrix_2_norm = array_ops.expand_dims(largest, axis=-1)
                result = array_ops.transpose(matrix_2_norm, perm=perm_after)
            else:
                # Euclidean / Frobenius: sqrt of the sum of |x|^2.
                conjugate = math_ops.conj(tensor)
                squared_sum = math_ops.reduce_sum(tensor * conjugate,
                                                  axis,
                                                  keepdims=True)
                result = math_ops.sqrt(squared_sum)
                # TODO(rmlarsen): Replace with the following, once gradients are defined
                # result = math_ops.reduce_euclidean_norm(tensor, axis, keepdims=True)
        else:
            result = math_ops.abs(tensor)
            lead_axis = None if axis is None else axis[0]
            if ord == 1:
                # Sum over the first axis; for matrices, then max over the
                # second axis.
                result = math_ops.reduce_sum(result, lead_axis, keepdims=True)
                if is_matrix_norm:
                    result = math_ops.reduce_max(result,
                                                 axis[-1],
                                                 keepdims=True)
            elif ord == np.inf:
                # For matrices, sum over the second axis first; then max over
                # the first axis.
                if is_matrix_norm:
                    result = math_ops.reduce_sum(result,
                                                 axis[1],
                                                 keepdims=True)
                result = math_ops.reduce_max(result, lead_axis, keepdims=True)
            else:
                # General p-norms (positive p only)
                powered_sum = math_ops.reduce_sum(math_ops.pow(result, ord),
                                                  axis,
                                                  keepdims=True)
                result = math_ops.pow(powered_sum, 1.0 / ord)
        if keepdims:
            return result
        return array_ops.squeeze(result, axis)
Пример #7
0
def norm(tensor,
         ord='euclidean',
         axis=None,
         keepdims=None,
         name=None,
         keep_dims=None):
  r"""Computes the norm of vectors, matrices, and tensors.

  Vector norms on offer: the 1-norm, the Euclidean (2-) norm, the inf-norm and
  the general p-norm for p > 0. Matrix norms on offer: Frobenius, 1-norm,
  2-norm and inf-norm.

  Args:
    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
    ord: Order of the norm. One of 'fro', 'euclidean', `1`, `2`, `np.inf`, or
      any positive real number p for the corresponding p-norm. The default,
      'euclidean', amounts to the Frobenius norm for a matrix and to the
      2-norm for a vector.
      Restrictions:
        a) 'fro' is rejected for vector norms,
        b) with a 2-element `axis` (matrix norm) only 'euclidean', 'fro', `1`,
           `2`, `np.inf` are accepted.
    axis: Determines how the input is interpreted.
      `None` (the default): one vector norm over all values, i.e.
      `norm(tensor, ord=ord)` is equivalent to
      `norm(reshape(tensor, [-1]), ord=ord)`.
      A Python integer: a batch of vectors, with the norm taken along that
      axis.
      A 2-tuple (or 2-element list) of Python integers: a batch of matrices
      spanned by those axes.
      Negative indices are supported. When the tensor may be a matrix or a
      batch of matrices at runtime, pass `axis=[-2,-1]` instead of
      `axis=None` to make sure matrix norms are computed.
    keepdims: If True, the axes named in `axis` are kept with size 1;
      otherwise they are dropped from the output shape. `None` means False.
    name: The name of the op.
    keep_dims: Deprecated alias for `keepdims`.

  Returns:
    output: A `Tensor` of the same type as tensor, containing the vector or
      matrix norms. If `keepdims` is True its rank equals that of `tensor`.
      Otherwise it is a scalar when `axis` is None, has rank one less than
      `tensor` when `axis` is an integer, and rank two less when `axis` is a
      2-tuple.

  Raises:
    ValueError: If `ord` or `axis` is invalid.

  @compatibility(numpy)
  Mostly equivalent to numpy.linalg.norm.
  Not supported: ord <= 0, nuclear norm.
  Other differences:
    a) If axis is `None`, treats the flattened `tensor` as a vector
     regardless of rank.
    b) Explicitly supports 'euclidean' norm as the default, including for
     higher order tensors.
  @end_compatibility
  """
  keepdims = deprecation.deprecated_argument_lookup('keepdims', keepdims,
                                                    'keep_dims', keep_dims)
  keepdims = False if keepdims is None else keepdims

  bad_axis_message = (
      "'axis' must be None, an integer, or a tuple of 2 unique integers")

  # Validate `axis` / `ord` and normalize `axis` to None or a tuple.
  is_matrix_norm = isinstance(axis, (tuple, list)) and len(axis) == 2
  if is_matrix_norm:
    axis = tuple(axis)
    axis_a, axis_b = axis
    axes_are_valid = (
        isinstance(axis_a, int) and isinstance(axis_b, int) and
        axis_a != axis_b)
    if not axes_are_valid:
      raise ValueError(bad_axis_message)
    supported_matrix_norms = ['euclidean', 'fro', 1, 2, np.inf]
    if ord not in supported_matrix_norms:
      raise ValueError("'ord' must be a supported matrix norm in %s, got %s" %
                       (supported_matrix_norms, ord))
  else:
    if axis is not None and not isinstance(axis, int):
      raise ValueError(bad_axis_message)

    supported_vector_norms = ['euclidean', 1, 2, np.inf]
    fails_p_check = not np.isreal(ord) or ord <= 0
    if fails_p_check and ord not in supported_vector_norms:
      raise ValueError("'ord' must be a supported vector norm, got %s" % ord)
    if axis is not None:
      axis = (axis,)

  # Every reduction below keeps its dims; they are squeezed once at the end.
  with ops.name_scope(name, 'norm', [tensor]):
    tensor = ops.convert_to_tensor(tensor)

    if ord in ['fro', 'euclidean', 2, 2.0]:
      if is_matrix_norm and ord in [2, 2.0]:
        # Matrix 2-norm = largest singular value. Move the matrix axes to the
        # end, run the SVD, then restore the original axis order.
        rank = array_ops.rank(tensor)

        def _wrap_negative(i):
          return control_flow_ops.cond(i >= 0, lambda: i, lambda: i + rank)

        positive_axis = map_fn.map_fn(_wrap_negative,
                                      ops.convert_to_tensor(axis))
        axes = math_ops.range(rank)
        leading_axes = array_ops.setdiff1d(axes, positive_axis)[0]
        perm_before = array_ops.concat([leading_axes, positive_axis], axis=0)

        def _index_in_perm(i):
          matches = array_ops.where(math_ops.equal(perm_before, i))
          return math_ops.cast(array_ops.squeeze(matches), dtype=dtypes.int32)

        # Inverse of perm_before.
        perm_after = map_fn.map_fn(_index_in_perm, axes)
        permed = array_ops.transpose(tensor, perm=perm_before)
        sigma = gen_linalg_ops.svd(permed, compute_uv=False)[0]
        sigma_max = math_ops.reduce_max(
            math_ops.abs(sigma), axis=-1, keepdims=True)
        matrix_2_norm = array_ops.expand_dims(sigma_max, axis=-1)
        result = array_ops.transpose(matrix_2_norm, perm=perm_after)
      else:
        # Euclidean / Frobenius: square root of the summed squared magnitudes.
        magnitude_sq = tensor * math_ops.conj(tensor)
        result = math_ops.sqrt(
            math_ops.reduce_sum(magnitude_sq, axis, keepdims=True))
    else:
      result = math_ops.abs(tensor)
      first_reduce_axis = None if axis is None else axis[0]
      if ord == 1:
        # Sum along the first axis, then (matrices only) max along the second.
        result = math_ops.reduce_sum(result, first_reduce_axis, keepdims=True)
        if is_matrix_norm:
          result = math_ops.reduce_max(result, axis[-1], keepdims=True)
      elif ord == np.inf:
        # (Matrices only) sum along the second axis, then max along the first.
        if is_matrix_norm:
          result = math_ops.reduce_sum(result, axis[1], keepdims=True)
        result = math_ops.reduce_max(result, first_reduce_axis, keepdims=True)
      else:
        # General p-norms (positive p only)
        sum_of_powers = math_ops.reduce_sum(
            math_ops.pow(result, ord), axis, keepdims=True)
        result = math_ops.pow(sum_of_powers, 1.0 / ord)
    if keepdims:
      return result
    return array_ops.squeeze(result, axis)