def test_create_n_file():
    """Test that create_n_file writes an int16 ``n.fits`` of the given shape.

    The file is written into the directory containing this test module and
    is removed afterwards, even when an assertion fails.
    """
    local = os.path.dirname(os.path.abspath(__file__))
    shape = (200, 200)
    expected_file_name = os.path.join(local, "n.fits")

    try:
        FitsHelper.create_n_file(shape, local)
        arr = fits.getdata(expected_file_name)
        assert arr.shape == shape
        # np.issubdtype(actual, expected): the stored dtype must be int16.
        assert np.issubdtype(arr.dtype, np.int16)
    finally:
        # Always clean up so a stale file cannot leak into later runs.
        if os.path.exists(expected_file_name):
            os.remove(expected_file_name)
def test_create_float64_file():
    """Test that create_file writes a 2D float64 FITS file.

    The file is written next to this test module and removed afterwards,
    even when creation, reading, or an assertion fails.
    """
    local = os.path.dirname(os.path.abspath(__file__))
    dummy_name = os.path.join(local, "dummy_float64.fits")
    dummy_shape = (200, 200)
    dummy_dtype = np.float64

    try:
        FitsHelper.create_file(dummy_name, dummy_shape, dummy_dtype)
        retrieved = fits.getdata(dummy_name)
        assert retrieved.shape == dummy_shape
        # np.issubdtype(actual, expected): the stored dtype must be float64.
        assert np.issubdtype(retrieved.dtype, dummy_dtype)
    finally:
        if os.path.exists(dummy_name):
            os.remove(dummy_name)
def test_create_file_3d():
    """Test that create_file writes a 3D uint8 FITS file.

    The file is written next to this test module and removed afterwards,
    even when creation, reading, or an assertion fails.
    """
    local = os.path.dirname(os.path.abspath(__file__))
    dummy_name = os.path.join(local, "dummy_uint8_3d.fits")
    dummy_shape = (200, 200, 5)
    dummy_dtype = np.uint8
    # The channels get moved to the front when stored as FITS, so the shape
    # read back is the requested shape reversed.
    expected_shape = tuple(reversed(dummy_shape))

    try:
        FitsHelper.create_file(dummy_name, dummy_shape, dummy_dtype)
        retrieved = fits.getdata(dummy_name)
        assert retrieved.shape == expected_shape
        # np.issubdtype(actual, expected): the stored dtype must be uint8.
        assert np.issubdtype(retrieved.dtype, dummy_dtype)
    finally:
        if os.path.exists(dummy_name):
            os.remove(dummy_name)
def test_create_rank_vote_files():
    """Test that create_rank_vote_files writes one float32 file per morphology.

    Each expected ``<morphology>.fits`` file is written next to this test
    module. All of them are removed afterwards, even if an assertion on an
    earlier file fails.
    """
    morphs = ["spheroid", "disk", "irregular", "point_source", "background"]
    local = os.path.dirname(os.path.abspath(__file__))
    shape = (200, 200)
    expected_file_names = [os.path.join(local, f"{m}.fits") for m in morphs]

    try:
        FitsHelper.create_rank_vote_files(shape, local)
        for f in expected_file_names:
            arr = fits.getdata(f)
            assert arr.shape == shape
            # np.issubdtype(actual, expected): each file must hold float32.
            assert np.issubdtype(arr.dtype, np.float32)
    finally:
        # Remove every file, not just those checked before a failure.
        for f in expected_file_names:
            if os.path.exists(f):
                os.remove(f)
def test_get_files():
    """Test that get_files returns the arrays stored in the given FITS files.

    Two float32 fixtures (all zeros and all ones) are written next to this
    test module. They are removed afterwards, even on failure.
    """
    local = os.path.dirname(os.path.abspath(__file__))
    dummy_file1 = os.path.join(local, "dummy1.fits")
    dummy_file2 = os.path.join(local, "dummy2.fits")
    arr1 = np.zeros([100, 100], dtype=np.float32)
    arr2 = np.ones([100, 100], dtype=np.float32)

    try:
        for f, a in zip([dummy_file1, dummy_file2], [arr1, arr2]):
            # overwrite=True: a stale fixture left by a crashed earlier run
            # would otherwise make writeto raise OSError.
            fits.PrimaryHDU(data=a).writeto(f, overwrite=True)
        # NOTE(review): the first return value is discarded. If it holds open
        # file handles, they are never closed here -- confirm against
        # FitsHelper.get_files.
        _, arrs = FitsHelper.get_files([dummy_file1, dummy_file2])
        assert np.array_equal(arr1, arrs[0])
        assert np.array_equal(arr2, arrs[1])
    finally:
        for f in (dummy_file1, dummy_file2):
            if os.path.exists(f):
                os.remove(f)
def test_create_file_rasies():
    """Check that create_file raises ValueError when given ``str`` as the dtype.

    (The misspelled function name is kept as-is so existing test selections
    that reference it keep working.)
    """
    invalid_dtype = str
    with pytest.raises(ValueError):
        FitsHelper.create_file("dummy.fits", [0], invalid_dtype)