def _applycal(self, array, apply_correction):
    """Calibrate `array` with `apply_correction` and return the results.

    The input is wrapped in a dask array with chunks (10, 4, 6). Correction
    factors for CAL_PRODUCTS are obtained from `calc_correction` using
    `self.cache`, and `apply_correction` is applied block-wise to the data
    and the corrections.

    Returns
    -------
    corrected, correction : tuple
        The computed calibrated data and the computed correction factors.
    """
    array_dask = da.from_array(array, chunks=(10, 4, 6))
    # Only the correction factors are needed here; the list of final
    # calibration products returned by calc_correction is discarded.
    _, correction = calc_correction(
        array_dask.chunks, self.cache, CORRPRODS, CAL_PRODUCTS, FREQS,
        {'cal': CAL_FREQS})
    corrected = da.core.elemwise(apply_correction, array_dask, correction,
                                 dtype=array_dask.dtype)
    return corrected.compute(), correction.compute()
def test_skip_missing_products(self):
    """Check calc_correction with empty, invalid and missing cal products.

    - An empty product list yields no final products and no corrections.
    - The product name 'INVALID' raises ValueError.
    - An unknown product on the cal stream is dropped when
      ``skip_missing_products=True`` and raises KeyError when it is False.
    - Mixing CAL_PRODUCTS with an unknown product keeps only CAL_PRODUCTS,
      and the corrections for one dump / channel range match the expected
      per-corrprod values.
    """
    dump = 15
    channels = np.s_[22:38]
    shape = (N_DUMPS, N_CHANS, N_CORRPRODS)
    chunks = da.core.normalize_chunks((10, 5, -1), shape)

    # No products requested: nothing to correct
    final_cal_products, corrections = calc_correction(
        chunks, self.cache, CORRPRODS, [], FREQS, {'cal': CAL_FREQS})
    assert_equal(final_cal_products, [])
    # Identity check: None is a singleton, and `==` could be element-wise
    # if a regression ever returned an array-like here.
    assert corrections is None

    with assert_raises(ValueError):
        calc_correction(chunks, self.cache, CORRPRODS, ['INVALID'], FREQS,
                        {'cal': CAL_FREQS})

    # Only an unknown product, which is skipped: again nothing to correct
    unknown = CAL_STREAM + '.UNKNOWN'
    final_cal_products, corrections = calc_correction(
        chunks, self.cache, CORRPRODS, [unknown], FREQS,
        {'cal': CAL_FREQS}, skip_missing_products=True)
    assert_equal(final_cal_products, [])
    assert corrections is None

    # Valid products plus an unknown one: strict mode fails...
    cal_products = CAL_PRODUCTS + [unknown]
    with assert_raises(KeyError):
        calc_correction(chunks, self.cache, CORRPRODS, cal_products, FREQS,
                        {'cal': CAL_FREQS}, skip_missing_products=False)
    # ...while lenient mode keeps just the valid products
    final_cal_products, corrections = calc_correction(
        chunks, self.cache, CORRPRODS, cal_products, FREQS,
        {'cal': CAL_FREQS}, skip_missing_products=True)
    assert_equal(set(final_cal_products), set(CAL_PRODUCTS))
    corrections = corrections[dump:dump + 1, channels].compute()
    expected_corrections = corrections_per_corrprod([dump], channels,
                                                    final_cal_products)
    assert_array_equal(corrections, expected_corrections)
def test_calc_correction(self):
    """calc_correction returns every requested product and matching factors.

    All of CAL_PRODUCTS are requested. The reported products must equal that
    set, and the corrections for one dump over a band of channels must match
    the per-corrprod values from `corrections_per_corrprod`.
    """
    data_shape = (N_DUMPS, N_CHANS, N_CORRPRODS)
    dask_chunks = da.core.normalize_chunks((10, 5, -1), data_shape)
    products, correction_array = calc_correction(
        dask_chunks, self.cache, CORRPRODS, CAL_PRODUCTS, FREQS,
        {'cal': CAL_FREQS})
    assert_equal(set(products), set(CAL_PRODUCTS))

    # Spot-check a single dump over a contiguous channel range
    dump_index = 15
    channel_slice = np.s_[22:38]
    actual = correction_array[dump_index:dump_index + 1, channel_slice]
    actual = actual.compute()
    expected = corrections_per_corrprod([dump_index], channel_slice, products)
    assert_array_equal(actual, expected)