def conv_1d_fn(lhs, rhs):
  """Applies a 1-D XLA convolution of kernel `rhs` over input `lhs`.

  Dimension layout: input/output use batch=0, feature=1, spatial=2; the
  kernel uses output-feature=0, input-feature=1, spatial=2. The window is
  unstrided with padding (2, 1) and kernel (rhs) dilation 2.

  NOTE(review): `preferred_element_type` is a free variable captured from
  the enclosing scope — confirm it is bound where this helper is defined.
  """
  # Single spatial dimension, living at index 2 in all three layouts.
  spatial_dims = list(range(2, 3))
  dim_nums = xla_data_pb2.ConvolutionDimensionNumbers()
  dim_nums.input_batch_dimension = 0
  dim_nums.input_feature_dimension = 1
  dim_nums.output_batch_dimension = 0
  dim_nums.output_feature_dimension = 1
  dim_nums.kernel_output_feature_dimension = 0
  dim_nums.kernel_input_feature_dimension = 1
  dim_nums.input_spatial_dimensions.extend(spatial_dims)
  dim_nums.kernel_spatial_dimensions.extend(spatial_dims)
  dim_nums.output_spatial_dimensions.extend(spatial_dims)
  return xla.conv(
      lhs,
      rhs,
      window_strides=(1,),
      padding=((2, 1),),
      lhs_dilation=(1,),
      rhs_dilation=(2,),
      dimension_numbers=dim_nums,
      precision_config=None,  # original bound this through a local set to None
      preferred_element_type=preferred_element_type)
def conv_1d_fn(lhs, rhs):
  """Applies a 1-D XLA convolution of kernel `rhs` over input `lhs`.

  Dimension layout: input/output use batch=0, feature=1, spatial=2; the
  kernel uses output-feature=0, input-feature=1, spatial=2. The window is
  unstrided with padding (2, 1) and kernel (rhs) dilation 2.

  NOTE(review): `precision` is a free variable captured from the enclosing
  scope — confirm it is bound where this helper is defined.
  """
  num_spatial_dims = 1
  dnums = xla_data_pb2.ConvolutionDimensionNumbers()
  dnums.input_batch_dimension = 0
  dnums.input_feature_dimension = 1
  dnums.output_batch_dimension = 0
  dnums.output_feature_dimension = 1
  dnums.kernel_output_feature_dimension = 0
  dnums.kernel_input_feature_dimension = 1
  dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  precision_config = None
  if precision:
    # Apply the requested precision to both operands (lhs and rhs).
    precision_config = xla_data_pb2.PrecisionConfig()
    precision_config.operand_precision.extend([precision, precision])
  return xla.conv(
      lhs,
      rhs,
      window_strides=(1,),
      padding=((2, 1),),
      lhs_dilation=(1,),
      rhs_dilation=(2,),
      dimension_numbers=dnums,
      # Fix: precision_config was built above but never forwarded, so the
      # requested precision was silently ignored (the sibling conv_1d_fn
      # variant does pass this kwarg).
      precision_config=precision_config)
def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
                          rhs_dilation, dimension_numbers,
                          feature_group_count, batch_group_count, lhs_shape,
                          rhs_shape, precision):
  """Implementation of lax.conv_general_dilated_p using XlaConv."""
  # Compute the expected result shape up front; tf2xla's conv op cannot
  # infer it (see TODO below), so we stamp it onto the output afterwards.
  expected_shape = _conv_general_dilated_shape(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count, lhs_shape,
      rhs_shape, precision)
  # TODO(phawkins): handle precision
  dnums_proto = _conv_general_proto(dimension_numbers)
  # TODO(phawkins): implement batch_group_count
  assert batch_group_count == 1
  result = tfxla.conv(lhs, rhs, window_strides, padding, lhs_dilation,
                      rhs_dilation, dnums_proto, feature_group_count)
  # TODO(tomhennigan): tf2xla should have a shape inference function.
  result.set_shape(expected_shape)
  return result
def conv_1d_fn(lhs, rhs):
  """Applies a 1-D XLA convolution of kernel `rhs` over input `lhs`.

  Dimension layout: input/output use batch=0, feature=1, spatial=2; the
  kernel uses output-feature=0, input-feature=1, spatial=2. The window is
  unstrided with padding (2, 1) and kernel (rhs) dilation 2.

  NOTE(review): `precision` is a free variable captured from the enclosing
  scope — confirm it is bound where this helper is defined.
  """
  num_spatial_dims = 1
  dnums = xla_data_pb2.ConvolutionDimensionNumbers()
  dnums.input_batch_dimension = 0
  dnums.input_feature_dimension = 1
  dnums.output_batch_dimension = 0
  dnums.output_feature_dimension = 1
  dnums.kernel_output_feature_dimension = 0
  dnums.kernel_input_feature_dimension = 1
  dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
  precision_config = None
  if precision:
    # Apply the requested precision to both operands (lhs and rhs).
    precision_config = xla_data_pb2.PrecisionConfig()
    precision_config.operand_precision.extend([precision, precision])
  return xla.conv(
      lhs,
      rhs,
      window_strides=(1,),
      padding=((2, 1),),
      lhs_dilation=(1,),
      rhs_dilation=(2,),
      dimension_numbers=dnums,
      # Fix: precision_config was built above but never forwarded, so the
      # requested precision was silently ignored (the sibling conv_1d_fn
      # variant does pass this kwarg).
      precision_config=precision_config)