def conv_1d_fn(lhs, rhs):
    """Run a 1-D dilated convolution of `lhs` with `rhs` through `xla.conv`.

    Layout encoded in the dimension numbers built below:
      * input / output: dim 0 = batch, dim 1 = feature, dim 2 = spatial.
      * kernel: dim 0 = output feature, dim 1 = input feature, dim 2 = spatial.

    The window uses stride 1, padding (2, 1), lhs dilation 1 and rhs
    dilation 2. No precision config is supplied. `preferred_element_type`
    is not a parameter; it is read from the enclosing scope.

    Args:
      lhs: input operand.
      rhs: kernel operand.

    Returns:
      The result of `xla.conv` for these arguments.
    """
    # A single spatial dimension, located right after batch/feature (axis 2).
    spatial_axes = range(2, 2 + 1)

    dims = xla_data_pb2.ConvolutionDimensionNumbers()
    # Input operand axes.
    dims.input_batch_dimension = 0
    dims.input_feature_dimension = 1
    # Output axes mirror the input layout.
    dims.output_batch_dimension = 0
    dims.output_feature_dimension = 1
    # Kernel axes: output features first, then input features.
    dims.kernel_output_feature_dimension = 0
    dims.kernel_input_feature_dimension = 1
    for spatial_field in (dims.input_spatial_dimensions,
                          dims.kernel_spatial_dimensions,
                          dims.output_spatial_dimensions):
        spatial_field.extend(spatial_axes)

    return xla.conv(
        lhs,
        rhs,
        window_strides=(1,),
        padding=((2, 1),),
        lhs_dilation=(1,),
        rhs_dilation=(2,),
        dimension_numbers=dims,
        precision_config=None,
        preferred_element_type=preferred_element_type,
    )
 def conv_1d_fn(lhs, rhs):
     """Run a 1-D dilated convolution of `lhs` with `rhs` through `xla.conv`.

     Dimension numbers: for the input and output, dim 0 is batch, dim 1 is
     feature and dim 2 is the single spatial dim; for the kernel, dim 0 is
     output feature, dim 1 is input feature and dim 2 is spatial. The window
     uses stride 1, padding (2, 1), lhs dilation 1 and rhs dilation 2.

     `precision` is not a parameter; it is read from the enclosing scope.
     When it is truthy, a `PrecisionConfig` listing it twice (presumably
     once per operand -- TODO confirm) is built and passed to `xla.conv`.

     Args:
       lhs: input operand.
       rhs: kernel operand.

     Returns:
       The result of `xla.conv`.
     """
     dnums = xla_data_pb2.ConvolutionDimensionNumbers()
     num_spatial_dims = 1
     dnums.input_batch_dimension = 0
     dnums.input_feature_dimension = 1
     dnums.output_batch_dimension = 0
     dnums.output_feature_dimension = 1
     dnums.kernel_output_feature_dimension = 0
     dnums.kernel_input_feature_dimension = 1
     dnums.input_spatial_dimensions.extend(
         range(2, 2 + num_spatial_dims))
     dnums.kernel_spatial_dimensions.extend(
         range(2, 2 + num_spatial_dims))
     dnums.output_spatial_dimensions.extend(
         range(2, 2 + num_spatial_dims))
     precision_config = None
     if precision:
         precision_config = xla_data_pb2.PrecisionConfig()
         precision_config.operand_precision.extend(
             [precision, precision])
     # Forward the config: it used to be built and then dropped, so the
     # requested `precision` never reached the convolution.
     return xla.conv(lhs,
                     rhs,
                     window_strides=(1, ),
                     padding=((2, 1), ),
                     lhs_dilation=(1, ),
                     rhs_dilation=(2, ),
                     dimension_numbers=dnums,
                     precision_config=precision_config)
Beispiel #3
0
def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
                          rhs_dilation, dimension_numbers, feature_group_count,
                          batch_group_count, lhs_shape, rhs_shape, precision):
  """Implementation of lax.conv_general_dilated_p using XlaConv.

  Computes the expected output shape with `_conv_general_dilated_shape`,
  converts `dimension_numbers` with `_conv_general_proto`, runs the
  convolution via `tfxla.conv`, and attaches the computed static shape to
  the result.

  Note: `precision` is only passed to the shape helper; it is not forwarded
  to `tfxla.conv`.

  Raises:
    NotImplementedError: if `batch_group_count` is not 1.
  """
  # Explicit check rather than `assert`: asserts are stripped under
  # `python -O`. In that case `batch_group_count` would simply be dropped,
  # because the `tfxla.conv` call below never receives it.
  if batch_group_count != 1:
    # TODO(phawkins): implement batch_group_count
    raise NotImplementedError(
        "batch_group_count must be 1, got %r" % (batch_group_count,))
  out_shape = _conv_general_dilated_shape(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count, lhs_shape,
      rhs_shape, precision)
  # TODO(phawkins): handle precision
  dnums_proto = _conv_general_proto(dimension_numbers)
  out = tfxla.conv(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dnums_proto, feature_group_count)
  # TODO(tomhennigan): tf2xla should have a shape inference function.
  out.set_shape(out_shape)
  return out
Beispiel #4
0
 def conv_1d_fn(lhs, rhs):
   """Run a 1-D dilated convolution of `lhs` with `rhs` through `xla.conv`.

   Dimension numbers: for the input and output, dim 0 is batch, dim 1 is
   feature and dim 2 is the single spatial dim; for the kernel, dim 0 is
   output feature, dim 1 is input feature and dim 2 is spatial. The window
   uses stride 1, padding (2, 1), lhs dilation 1 and rhs dilation 2.

   `precision` is not a parameter; it is read from the enclosing scope.
   When it is truthy, a `PrecisionConfig` listing it twice (presumably once
   per operand -- TODO confirm) is built and passed to `xla.conv`.

   Args:
     lhs: input operand.
     rhs: kernel operand.

   Returns:
     The result of `xla.conv`.
   """
   dnums = xla_data_pb2.ConvolutionDimensionNumbers()
   num_spatial_dims = 1
   dnums.input_batch_dimension = 0
   dnums.input_feature_dimension = 1
   dnums.output_batch_dimension = 0
   dnums.output_feature_dimension = 1
   dnums.kernel_output_feature_dimension = 0
   dnums.kernel_input_feature_dimension = 1
   dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
   dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
   dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
   precision_config = None
   if precision:
     precision_config = xla_data_pb2.PrecisionConfig()
     precision_config.operand_precision.extend([precision, precision])
   # Forward the config: it used to be built and then dropped, so the
   # requested `precision` never reached the convolution.
   return xla.conv(
       lhs,
       rhs,
       window_strides=(1,),
       padding=((2, 1),),
       lhs_dilation=(1,),
       rhs_dilation=(2,),
       dimension_numbers=dnums,
       precision_config=precision_config)