Exemplo n.º 1
0
 def conv_1d_fn(lhs, rhs):
     """Convolve `lhs` with `rhs` along one spatial axis using `xla.conv`.

     The dimension numbers below give `lhs` and the result the layout
     (batch, feature, spatial) and give `rhs` the layout
     (output feature, input feature, spatial). The window uses stride 1,
     (low=2, high=1) padding, lhs dilation 1 and rhs dilation 2.
     No precision config is requested; `preferred_element_type` is read from
     the enclosing scope.
     """
     # Exactly one spatial axis, sitting right after the two leading axes.
     spatial_axes = [2]
     conv_dims = xla_data_pb2.ConvolutionDimensionNumbers(
         input_batch_dimension=0,
         input_feature_dimension=1,
         input_spatial_dimensions=spatial_axes,
         kernel_output_feature_dimension=0,
         kernel_input_feature_dimension=1,
         kernel_spatial_dimensions=spatial_axes,
         output_batch_dimension=0,
         output_feature_dimension=1,
         output_spatial_dimensions=spatial_axes,
     )
     return xla.conv(
         lhs,
         rhs,
         window_strides=(1,),
         padding=((2, 1),),
         lhs_dilation=(1,),
         rhs_dilation=(2,),
         dimension_numbers=conv_dims,
         precision_config=None,
         preferred_element_type=preferred_element_type,
     )
 def conv_1d_fn(lhs, rhs):
     """Convolve `lhs` with `rhs` along one spatial axis using `xla.conv`.

     `lhs` and the result use the layout (batch, feature, spatial); `rhs` uses
     (output feature, input feature, spatial). The window uses stride 1,
     (low=2, high=1) padding, lhs dilation 1 and rhs dilation 2.

     `precision` is read from the enclosing scope. When it is truthy, a
     PrecisionConfig listing that value once per operand is built and
     forwarded to `xla.conv`. Previously it was built but never passed, so
     the requested precision was silently ignored.
     """
     dnums = xla_data_pb2.ConvolutionDimensionNumbers()
     num_spatial_dims = 1
     dnums.input_batch_dimension = 0
     dnums.input_feature_dimension = 1
     dnums.output_batch_dimension = 0
     dnums.output_feature_dimension = 1
     dnums.kernel_output_feature_dimension = 0
     dnums.kernel_input_feature_dimension = 1
     # Spatial axes follow the two leading (batch/feature) axes.
     dnums.input_spatial_dimensions.extend(
         range(2, 2 + num_spatial_dims))
     dnums.kernel_spatial_dimensions.extend(
         range(2, 2 + num_spatial_dims))
     dnums.output_spatial_dimensions.extend(
         range(2, 2 + num_spatial_dims))
     precision_config = None
     if precision:
         precision_config = xla_data_pb2.PrecisionConfig()
         # One entry per operand (lhs, rhs), both at the same precision.
         precision_config.operand_precision.extend(
             [precision, precision])
     return xla.conv(lhs,
                     rhs,
                     window_strides=(1, ),
                     padding=((2, 1), ),
                     lhs_dilation=(1, ),
                     rhs_dilation=(2, ),
                     dimension_numbers=dnums,
                     precision_config=precision_config)
Exemplo n.º 3
0
    def ConvGeneralDilated(self, lhs, rhs, window_strides, padding,
                           lhs_dilation, rhs_dilation, dimension_numbers):
        """Enqueues a ConvGeneralDilated operation onto the computation.

        Args:
          lhs: LocalOp for the rank N+2 array of inputs.
          rhs: LocalOp for the rank N+2 array of kernel weights.
          window_strides: length-N array-like of integer kernel strides.
          padding: length-N array-like of pairs of integers of (low, high)
            padding.
          lhs_dilation: length-N array-like of integer dilation factors.
          rhs_dilation: length-N array-like of integer dilation factors.
          dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers
            or a triple (lhs_spec, rhs_spec, out_spec). Each element of the
            triple is a string of length N+2 that identifies, by position:
            (1) the batch dimension in lhs and the output, with the character
            'N'; (2) the feature dimension in lhs and the output, with the
            character 'C'; (3) the input and output feature dimensions in rhs,
            with the characters 'I' and 'O' respectively; and (4) spatial
            dimension correspondences between lhs, rhs, and the output, using
            any other distinct characters.
            For example, dimension numbers consistent with the Conv operation
            with two spatial dimensions are ('NCHW', 'OIHW', 'NCHW'). Dimension
            numbers consistent with the TensorFlow Conv2D operation are
            ('NHWC', 'HWIO', 'NHWC').
            With the string-triple form, window strides are associated with
            spatial dimension labels in the order in which those labels
            appear in rhs_spec. So window_strides[0] is matched with the
            dimension of the first character in rhs_spec that is not 'I' or
            'O'.

        Returns:
          A LocalOp representing the ConvGeneralDilated operation.

        Raises:
          ValueError: if the string-triple form is used and a spec is missing
            a required label ('N' or 'C' in lhs_spec/out_spec, 'I' or 'O' in
            rhs_spec), or if lhs_spec/out_spec contains a spatial label that
            does not appear in rhs_spec.
        """
        if not isinstance(dimension_numbers,
                          xla_data_pb2.ConvolutionDimensionNumbers):
            lhs_spec, rhs_spec, out_spec = dimension_numbers
            # Labels that mark non-spatial dimensions in each kind of spec.
            batch_and_feature = {'N', 'C'}
            kernel_features = {'I', 'O'}
            dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()

            dimension_numbers.input_batch_dimension = lhs_spec.index('N')
            dimension_numbers.input_feature_dimension = lhs_spec.index('C')
            dimension_numbers.output_batch_dimension = out_spec.index('N')
            dimension_numbers.output_feature_dimension = out_spec.index('C')
            dimension_numbers.kernel_output_feature_dimension = rhs_spec.index(
                'O')
            dimension_numbers.kernel_input_feature_dimension = rhs_spec.index(
                'I')

            # Kernel spatial dims are listed in rhs_spec order. That order
            # decides which spatial dim each window_strides entry refers to.
            dimension_numbers.kernel_spatial_dimensions.extend(
                i for i, c in enumerate(rhs_spec) if c not in kernel_features)
            # Input/output spatial dims are sorted by where their label sits
            # in rhs_spec, so the k-th entry of each list names the same
            # spatial label as the k-th kernel spatial dim.
            dimension_numbers.input_spatial_dimensions.extend(
                sorted(
                    (i for i, c in enumerate(lhs_spec)
                     if c not in batch_and_feature),
                    key=lambda i: rhs_spec.index(lhs_spec[i])))
            dimension_numbers.output_spatial_dimensions.extend(
                sorted(
                    (i for i, c in enumerate(out_spec)
                     if c not in batch_and_feature),
                    key=lambda i: rhs_spec.index(out_spec[i])))
        return self._client.ConvGeneralDilated(lhs, rhs, window_strides,
                                               padding, lhs_dilation,
                                               rhs_dilation, dimension_numbers)
Exemplo n.º 4
0
 def _GetConvDimensionNumbers(self, num_spatial_dims):
   """Build ConvolutionDimensionNumbers for a channels-first convolution.

   Axis 0 is batch for the input/output and output feature for the kernel.
   Axis 1 is feature for the input/output and input feature for the kernel.
   Axes 2 .. 1 + num_spatial_dims are spatial for all three operands.

   Args:
     num_spatial_dims: how many spatial dimensions the convolution has.

   Returns:
     A populated xla_data_pb2.ConvolutionDimensionNumbers proto.
   """
   spatial_axes = list(range(2, 2 + num_spatial_dims))
   return xla_data_pb2.ConvolutionDimensionNumbers(
       input_batch_dimension=0,
       input_feature_dimension=1,
       input_spatial_dimensions=spatial_axes,
       kernel_output_feature_dimension=0,
       kernel_input_feature_dimension=1,
       kernel_spatial_dimensions=spatial_axes,
       output_batch_dimension=0,
       output_feature_dimension=1,
       output_spatial_dimensions=spatial_axes)
Exemplo n.º 5
0
def _conv_general_proto(dimension_numbers):
    """Translate a lax.ConvDimensionNumbers into the XLA proto form.

    Each spec in the (lhs_spec, rhs_spec, out_spec) triple is a sequence of
    axis indices. Its first two entries name the batch and feature axes (for
    lhs and out) or the output-feature and input-feature axes (for rhs). The
    remaining entries name the spatial axes.

    Args:
      dimension_numbers: a lax.ConvDimensionNumbers instance.

    Returns:
      An xla_data_pb2.ConvolutionDimensionNumbers proto.
    """
    assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
    in_spec, kernel_spec, result_spec = dimension_numbers
    msg = xla_data_pb2.ConvolutionDimensionNumbers()

    # Input operand (lhs).
    msg.input_batch_dimension = in_spec[0]
    msg.input_feature_dimension = in_spec[1]
    # Result.
    msg.output_batch_dimension = result_spec[0]
    msg.output_feature_dimension = result_spec[1]
    # Kernel operand (rhs).
    msg.kernel_output_feature_dimension = kernel_spec[0]
    msg.kernel_input_feature_dimension = kernel_spec[1]

    # Everything after the two leading entries is spatial.
    msg.input_spatial_dimensions.extend(in_spec[2:])
    msg.kernel_spatial_dimensions.extend(kernel_spec[2:])
    msg.output_spatial_dimensions.extend(result_spec[2:])
    return msg