def _test_scalar_type(self, spark_type, numpy_type, bits):
    """Round-trip the signed range boundaries of ``numpy_type`` through a ScalarCodec.

    The smallest and largest values of a ``bits``-wide signed integer
    (``-2**(bits-1)`` and ``2**(bits-1) - 1``) are wrapped in ``numpy_type``,
    encoded and decoded, and the decoded values are compared to the originals.

    :param spark_type: Spark SQL type class (instantiated here) given to the codec.
    :param numpy_type: numpy scalar type used both as the field dtype and to wrap inputs.
    :param bits: bit width of the signed integer type under test.
    """
    codec = ScalarCodec(spark_type())
    field = UnischemaField(name='field_int', numpy_dtype=numpy_type, shape=(), codec=codec, nullable=False)

    def roundtrip(raw):
        # Wrap in the numpy type, encode, then decode back through the same codec/field.
        return codec.decode(field, codec.encode(field, numpy_type(raw)))

    half_range = 2 ** (bits - 1)
    lowest = -half_range
    highest = half_range - 1

    self.assertEqual(roundtrip(lowest), lowest)
    self.assertEqual(roundtrip(highest), highest)
    # The decoded minimum must not compare equal to the value just below the range.
    self.assertNotEqual(roundtrip(lowest), lowest - 1)
def test_bad_encoded_data_shape():
    """Round-tripping a 1-D array through an integer ScalarCodec must raise TypeError."""
    int_codec = ScalarCodec(IntegerType())
    int_field = UnischemaField(name='field_int', numpy_dtype=np.int32, shape=(), codec=int_codec, nullable=False)
    with pytest.raises(TypeError):
        # A two-element array is not a scalar, so the codec is expected to reject it.
        non_scalar = np.asarray([10, 10])
        encoded = int_codec.encode(int_field, non_scalar)
        int_codec.decode(int_field, encoded)
def test_unicode():
    """String values, including the empty string, survive a ScalarCodec round trip.

    The field dtype is ``np.str_``. The ``np.unicode_`` alias was removed in
    NumPy 2.0, whereas ``np.str_`` exists in every NumPy version and is the
    same type that ``np.unicode_`` aliased in NumPy 1.x.
    """
    codec = ScalarCodec(StringType())
    field = UnischemaField(name='field_string', numpy_dtype=np.str_, shape=(), codec=codec, nullable=False)
    assert codec.decode(field, codec.encode(field, 'abc')) == 'abc'
    assert codec.decode(field, codec.encode(field, '')) == ''
def test_scalar_codec_unicode(self):
    """String values, including the empty string, survive a ScalarCodec round trip.

    The field dtype is ``np.str_``. The ``np.unicode_`` alias was removed in
    NumPy 2.0, whereas ``np.str_`` exists in every NumPy version and is the
    same type that ``np.unicode_`` aliased in NumPy 1.x.
    """
    codec = ScalarCodec(StringType())
    field = UnischemaField(name='field_string', numpy_dtype=np.str_, shape=(), codec=codec, nullable=False)
    self.assertEqual(codec.decode(field, codec.encode(field, 'abc')), 'abc')
    self.assertEqual(codec.decode(field, codec.encode(field, '')), '')
def test_numeric_types(spark_numpy_types):
    """Round-trip the min and max of an integer numpy type through a ScalarCodec.

    ``spark_numpy_types`` is unpacked into a ``(spark_type, numpy_type)`` pair;
    presumably it is supplied by a parametrized pytest fixture defined
    elsewhere in the file (TODO confirm).
    """
    spark_type, numpy_type = spark_numpy_types
    codec = ScalarCodec(spark_type())
    field = UnischemaField(name='field_int', numpy_dtype=numpy_type, shape=(), codec=codec, nullable=False)
    limits = np.iinfo(numpy_type)
    # Check the lower boundary first, then the upper one.
    for boundary in (limits.min, limits.max):
        restored = codec.decode(field, codec.encode(field, numpy_type(boundary)))
        assert restored == boundary
def test_scalar_codec_decimal():
    """A Decimal value round-trips unchanged through a DecimalType ScalarCodec.

    The declared Spark type must be able to hold the test value: '123.4567'
    has 7 significant digits, 4 of them after the decimal point, so it needs
    ``DecimalType(precision=7, scale=4)``. A ``DecimalType(4, 3)`` column can
    only store values up to 9.999, and so could never hold this value.
    """
    codec = ScalarCodec(DecimalType(7, 4))
    field = UnischemaField(name='field_decimal', numpy_dtype=Decimal, shape=(), codec=codec, nullable=False)

    value = Decimal('123.4567')
    assert codec.decode(field, codec.encode(field, value)) == value