예제 #1
0
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         name=None):
  """Emits an XLA ConvGeneralDilated op for the given operands.

  ConvGeneralDilated is XLA's most general convolution; its semantics are
  described at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto; it is serialized
      before being handed to the op.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto. When falsy (e.g. `None`),
      an empty string is passed to the op instead of a serialized proto.
    name: an optional name for the operator

  Returns:
    A tensor representing the output of the convolution.
  """
  # The generated op takes proto-valued attributes in serialized form.
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=serialized_precision,
      name=name)
예제 #2
0
파일: xla.py 프로젝트: AnishShah/tensorflow
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         name=None):
  """Thin wrapper over the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general XLA convolution form. See
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
  for its documentation.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto, passed to the op
      in serialized form.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `PrecisionConfigProto` proto. If falsy (e.g. `None`),
      the op receives an empty string rather than a serialized proto.
    name: an optional name for the operator

  Returns:
    A tensor representing the output of the convolution.
  """
  # Proto attributes cross the op boundary as serialized strings.
  if precision_config:
    precision_attr = precision_config.SerializeToString()
  else:
    precision_attr = ""
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_attr,
      name=name)
예제 #3
0
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False,
         batch_group_count=1):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  The XlaConvV2 op is emitted when any of the following holds: a
  `preferred_element_type` was supplied, `lhs` and `rhs` have different
  dtypes, `batch_group_count` is greater than 1, or `use_v2` is set.
  Otherwise the original XlaConv op is emitted.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto. When falsy (e.g. `None`),
      an empty string is passed to the op instead of a serialized proto.
    preferred_element_type: the result `dtype`. When `None` and the V2 op is
      used, it defaults to `np_utils.result_type(lhs.dtype, rhs.dtype)`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # Test against None explicitly: `None` is the "not supplied" sentinel, and
  # some dtype objects are falsy (a non-structured numpy dtype has length 0).
  # A truthiness test would silently drop such a request and fall through to
  # the legacy op, which has no preferred_element_type attribute.
  needs_v2 = (
      preferred_element_type is not None or lhs.dtype != rhs.dtype or
      batch_group_count > 1)
  if needs_v2 or use_v2:
    # The default result type is only consumed by the V2 op, so resolve it
    # here rather than on every call.
    if preferred_element_type is None:
      preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)