def conv_1d_fn(lhs, rhs):
  """Runs a stride-1, kernel-dilated 1-D convolution through ``xla.conv``.

  Input and output use a (batch, feature, spatial) axis order; the kernel
  uses (output_feature, input_feature, spatial). The window has stride 1,
  low/high padding (2, 1), no input dilation and a kernel dilation of 2.
  No precision config is requested. ``preferred_element_type`` is a free
  variable taken from the enclosing scope (not visible here).

  Args:
    lhs: the input operand.
    rhs: the kernel operand.

  Returns:
    The result of ``xla.conv`` for these operands.
  """
  # A single spatial axis, sitting right after the batch/feature axes.
  spatial_axes = range(2, 3)

  dims = xla_data_pb2.ConvolutionDimensionNumbers()
  dims.input_batch_dimension = 0
  dims.input_feature_dimension = 1
  dims.output_batch_dimension = 0
  dims.output_feature_dimension = 1
  dims.kernel_output_feature_dimension = 0
  dims.kernel_input_feature_dimension = 1
  for repeated_field in (dims.input_spatial_dimensions,
                         dims.kernel_spatial_dimensions,
                         dims.output_spatial_dimensions):
    repeated_field.extend(spatial_axes)

  conv_kwargs = dict(
      window_strides=(1,),
      padding=((2, 1),),
      lhs_dilation=(1,),
      rhs_dilation=(2,),
      dimension_numbers=dims,
      precision_config=None,
      preferred_element_type=preferred_element_type,
  )
  return xla.conv(lhs, rhs, **conv_kwargs)
def conv_1d_fn(lhs, rhs):
  """Applies a 1-D dilated convolution via ``xla.conv`` with optional precision.

  Input and output are laid out as (batch, feature, spatial); the kernel as
  (output_feature, input_feature, spatial). The convolution uses stride 1,
  low/high padding (2, 1), no input dilation and a kernel dilation of 2.

  ``precision`` is a free variable read from the enclosing scope (not visible
  here). When it is truthy, an ``xla_data_pb2.PrecisionConfig`` requesting that
  precision for both operands is built and forwarded to ``xla.conv``.

  Args:
    lhs: the input operand.
    rhs: the kernel operand.

  Returns:
    The result of ``xla.conv`` for these operands.
  """
  dnums = xla_data_pb2.ConvolutionDimensionNumbers()
  num_spatial_dims = 1
  dnums.input_batch_dimension = 0
  dnums.input_feature_dimension = 1
  dnums.output_batch_dimension = 0
  dnums.output_feature_dimension = 1
  dnums.kernel_output_feature_dimension = 0
  dnums.kernel_input_feature_dimension = 1
  dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))

  precision_config = None
  if precision:
    precision_config = xla_data_pb2.PrecisionConfig()
    # One entry per operand: lhs then rhs.
    precision_config.operand_precision.extend([precision, precision])

  # The config must be forwarded; otherwise the requested precision is
  # silently ignored by the op.
  return xla.conv(
      lhs,
      rhs,
      window_strides=(1,),
      padding=((2, 1),),
      lhs_dilation=(1,),
      rhs_dilation=(2,),
      dimension_numbers=dnums,
      precision_config=precision_config)
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
                       rhs_dilation, dimension_numbers):
  """Enqueues a ConvGeneralDilated operation onto the computation.

  Args:
    lhs: LocalOp for the rank N+2 array of inputs.
    rhs: LocalOp for the rank N+2 array of kernel weights.
    window_strides: length-N array-like of integer kernel strides.
    padding: length-N array-like of pairs of integers of (low, high) padding.
    lhs_dilation: length-N array-like of integer dilation factors.
    rhs_dilation: length-N array-like of integer dilation factors.
    dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a
      triple (lhs_spec, rhs_spec, out_spec) where each element is a string of
      length N+2 identifying by position (1) batch dimensions in lhs, rhs, and
      the output with the character 'N', (2) feature dimensions in lhs and the
      output with the character 'C', (3) input and output feature dimensions
      in rhs with the characters 'I' and 'O' respectively, and (4) spatial
      dimension correspondences between lhs, rhs, and the output using any
      distinct characters. For example, to indicate dimension numbers
      consistent with the Conv operation with two spatial dimensions, one
      could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
      dimension numbers consistent with the TensorFlow Conv2D operation, one
      could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
      convolution dimension specification, window strides are associated with
      spatial dimension character labels according to the order in which the
      labels appear in the rhs_spec string, so that window_strides[0] is
      matched with the dimension corresponding to the first character
      appearing in rhs_spec that is not 'I' or 'O'.

  Returns:
    a LocalOp representing the ConvGeneralDilated operation.

  Raises:
    ValueError: when the string form is used and a spec lacks a required
      label ('N'/'C' in lhs_spec or out_spec, 'I'/'O' in rhs_spec), or when a
      spatial label of lhs_spec or out_spec does not appear in rhs_spec
      (raised by ``index``).
  """
  if not isinstance(dimension_numbers,
                    xla_data_pb2.ConvolutionDimensionNumbers):
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    kernel_feature_labels = {'I', 'O'}
    batch_feature_labels = {'N', 'C'}

    dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    # Kernel spatial dims are taken in the order they appear in rhs_spec; this
    # order defines which window_strides/padding/dilation entry applies to
    # which spatial axis.
    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in kernel_feature_labels)
    # Input and output spatial dims are reordered to follow the same rhs_spec
    # order, so the k-th entry of each list refers to the same spatial label.
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec)
                if c not in batch_feature_labels),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec)
                if c not in batch_feature_labels),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
                                         lhs_dilation, rhs_dilation,
                                         dimension_numbers)
def _GetConvDimensionNumbers(self, num_spatial_dims):
  """Builds ConvolutionDimensionNumbers for a batch/feature-leading layout.

  Input and output put batch on axis 0 and features on axis 1; the kernel puts
  output features on axis 0 and input features on axis 1. All three place
  their ``num_spatial_dims`` spatial axes at 2, 3, ... in the same order
  (e.g. 'NCHW' / 'OIHW' / 'NCHW' for two spatial dims).

  Args:
    num_spatial_dims: number of spatial dimensions of the convolution.

  Returns:
    A populated xla_data_pb2.ConvolutionDimensionNumbers proto.
  """
  result = xla_data_pb2.ConvolutionDimensionNumbers()

  # Leading axes: (batch, feature) for activations, (out, in) for the kernel.
  result.input_batch_dimension = 0
  result.input_feature_dimension = 1
  result.output_batch_dimension = 0
  result.output_feature_dimension = 1
  result.kernel_output_feature_dimension = 0
  result.kernel_input_feature_dimension = 1

  # Spatial axes follow the two leading axes, identically for all operands.
  spatial_axes = range(2, 2 + num_spatial_dims)
  for repeated_field in (result.input_spatial_dimensions,
                         result.kernel_spatial_dimensions,
                         result.output_spatial_dimensions):
    repeated_field.extend(spatial_axes)
  return result
def _conv_general_proto(dimension_numbers):
  """Translates lax.ConvDimensionNumbers into an XLA ConvolutionDimensionNumbers.

  Each of the three specs is a sequence of axis indices whose first two
  entries are the (batch, feature) axes for lhs/out, or the
  (output_feature, input_feature) axes for rhs; the remaining entries are the
  spatial axes, copied through in order.

  Args:
    dimension_numbers: a lax.ConvDimensionNumbers (lhs_spec, rhs_spec,
      out_spec) triple.

  Returns:
    The equivalent xla_data_pb2.ConvolutionDimensionNumbers proto.
  """
  assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
  lhs_axes, rhs_axes, out_axes = dimension_numbers
  result = xla_data_pb2.ConvolutionDimensionNumbers()

  result.input_batch_dimension = lhs_axes[0]
  result.input_feature_dimension = lhs_axes[1]
  result.output_batch_dimension = out_axes[0]
  result.output_feature_dimension = out_axes[1]
  result.kernel_output_feature_dimension = rhs_axes[0]
  result.kernel_input_feature_dimension = rhs_axes[1]

  # Everything after the two leading axes is spatial.
  spatial_pairs = (
      (result.input_spatial_dimensions, lhs_axes),
      (result.kernel_spatial_dimensions, rhs_axes),
      (result.output_spatial_dimensions, out_axes),
  )
  for repeated_field, axes in spatial_pairs:
    repeated_field.extend(axes[2:])
  return result