コード例 #1
0
 def conv_1d_fn(lhs, rhs):
     """Run a 1-D XLA convolution of ``lhs`` with the kernel ``rhs``.

     The dimension numbers place batch at dimension 0, feature at
     dimension 1 and the single spatial dimension at dimension 2 for both
     the input and the output; the kernel uses dimension 0 for output
     features, dimension 1 for input features and dimension 2 for its
     spatial dimension.

     ``precision`` is not a parameter: it is read from the enclosing scope
     (not visible in this snippet). When it is truthy, the same value is
     requested for both operands via a ``PrecisionConfig``.

     Args:
         lhs: Convolution input operand.
         rhs: Convolution kernel operand.

     Returns:
         The result of ``xla.conv`` with stride 1, padding ``(2, 1)``,
         no input dilation and a kernel dilation of 2.
     """
     num_spatial_dims = 1
     # Spatial dimensions follow the two leading (batch/feature) dimensions
     # in the input, the kernel and the output alike.
     spatial_dims = range(2, 2 + num_spatial_dims)

     dnums = xla_data_pb2.ConvolutionDimensionNumbers()
     dnums.input_batch_dimension = 0
     dnums.input_feature_dimension = 1
     dnums.output_batch_dimension = 0
     dnums.output_feature_dimension = 1
     dnums.kernel_output_feature_dimension = 0
     dnums.kernel_input_feature_dimension = 1
     dnums.input_spatial_dimensions.extend(spatial_dims)
     dnums.kernel_spatial_dimensions.extend(spatial_dims)
     dnums.output_spatial_dimensions.extend(spatial_dims)

     precision_config = None
     if precision:
         precision_config = xla_data_pb2.PrecisionConfig()
         # One entry per operand: lhs and rhs use the same precision.
         precision_config.operand_precision.extend(
             [precision, precision])

     # Forward precision_config; previously it was built and then dropped,
     # so the requested precision never reached the convolution.
     return xla.conv(lhs,
                     rhs,
                     window_strides=(1, ),
                     padding=((2, 1), ),
                     lhs_dilation=(1, ),
                     rhs_dilation=(2, ),
                     dimension_numbers=dnums,
                     precision_config=precision_config)
コード例 #2
0
 def dot_fn(lhs, rhs):
   """Compute a batched general dot product of ``lhs`` and ``rhs``.

   Dimension 0 is the batch dimension of both operands; dimension 2 of
   ``lhs`` is contracted against dimension 1 of ``rhs``.

   ``precision`` is read from the enclosing scope (not visible in this
   snippet). When it is truthy, it is applied to both operands through a
   ``PrecisionConfig``; otherwise no precision config is passed.

   Args:
     lhs: Left-hand operand.
     rhs: Right-hand operand.

   Returns:
     The result of ``xla.dot_general`` for the operands.
   """
   dimension_numbers = xla_data_pb2.DotDimensionNumbers(
       lhs_contracting_dimensions=[2],
       rhs_contracting_dimensions=[1],
       lhs_batch_dimensions=[0],
       rhs_batch_dimensions=[0])
   # One precision entry per operand, or no config at all when unset.
   operand_config = (
       xla_data_pb2.PrecisionConfig(
           operand_precision=[precision, precision])
       if precision else None)
   return xla.dot_general(
       lhs,
       rhs,
       dimension_numbers=dimension_numbers,
       precision_config=operand_config)