def test_match_test():
    """Check that ``MF.match_test`` reports the expected maximum event.

    A mocked INS (10 times x 20 channels x 1 pol, all ones) gets three
    injected events of value 10: a single-sample spike at (t=3, f=5), a
    full-band row at t=5, and a channel block at (t=7, f=7:13) meant to
    match the custom ``'shape'`` entry. ``match_test`` is expected to return
    the full-band row at time 5, spanning all channels, classified as
    ``'streak'``.
    """
    obs = '1061313128_99bl_1pol_half_time'
    insfile = os.path.join(DATA_PATH, f'{obs}_SSINS.h5')
    ins = INS(insfile)

    # Mock a simple metric_array and freq_array
    ins.metric_array = np.ma.ones([10, 20, 1])
    ins.weights_array = np.copy(ins.metric_array)
    ins.freq_array = np.arange(20)

    # Make a shape dictionary for a shape that will be injected later;
    # the bounds are in the same units as freq_array (channel indices here).
    shape = [7.9, 12.1]
    shape_dict = {'shape': shape}
    sig_thresh = {'shape': 5, 'narrow': 5, 'streak': 5}
    mf = MF(ins.freq_array, sig_thresh, shape_dict=shape_dict)

    # Inject a narrow (t=3, f=5), streak (all of t=5) and shape
    # (t=7, f=7:13) event, then recompute the mean-subtracted spectrum.
    ins.metric_array[3, 5] = 10
    ins.metric_array[5] = 10
    ins.metric_array[7, 7:13] = 10
    ins.metric_ms = ins.mean_subtract()

    t_max, f_max, R_max, shape_max = mf.match_test(ins)

    assert t_max == 5, "Wrong time"
    assert f_max == slice(0, 20), "Wrong freq"
    assert shape_max == 'streak', "Wrong shape"
def test_apply_match_test():
    """Check the flags and event record produced by ``MF.apply_match_test``.

    A mocked INS (10 times x 20 channels x 1 pol, all ones) gets a narrow
    spike at (t=3, f=5), a full-band row at t=5 and a channel block at
    (t=7, f=7:13). After ``apply_match_test`` exactly those samples must be
    masked, and ``ins.match_events`` must list them as streak, shape, narrow
    (in that order), each with a significance of at least 5.

    A second pass then masks every time except t=0 in the shape's channels
    and checks that the whole shape ends up flagged.
    """
    obs = '1061313128_99bl_1pol_half_time'
    insfile = os.path.join(DATA_PATH, f'{obs}_SSINS.h5')
    ins = INS(insfile)

    # Mock a simple metric_array and freq_array
    ins.metric_array = np.ma.ones([10, 20, 1])
    ins.weights_array = np.copy(ins.metric_array)
    ins.freq_array = np.arange(20)

    # Make a shape dictionary for a shape that will be injected later
    shape = [7.9, 12.1]
    shape_dict = {'shape': shape}
    sig_thresh = {'shape': 5, 'narrow': 5, 'streak': 5}
    mf = MF(ins.freq_array, sig_thresh, shape_dict=shape_dict)

    # Inject a narrow (t=3, f=5), streak (all of t=5) and shape
    # (t=7, f=7:13) event.
    ins.metric_array[3, 5] = 10
    ins.metric_array[5] = 10
    ins.metric_array[7, 7:13] = 10
    ins.metric_ms = ins.mean_subtract()
    ins.sig_array = np.ma.copy(ins.metric_ms)

    mf.apply_match_test(ins, event_record=True)

    # Check that exactly the injected samples are flagged
    test_mask = np.zeros(ins.metric_array.shape, dtype=bool)
    test_mask[3, 5] = 1
    test_mask[5] = 1
    test_mask[7, 7:13] = 1
    assert np.all(test_mask == ins.metric_array.mask), "Flags are incorrect"

    # Each recorded event is (time index, freq slice, shape name, significance);
    # compare everything but the trailing significance.
    test_match_events_slc = [(5, slice(0, 20), 'streak'),
                             (7, slice(7, 13), 'shape'),
                             (3, slice(5, 6), 'narrow')]
    for i, event in enumerate(test_match_events_slc):
        assert ins.match_events[i][:-1] == event, f"{i}th event is wrong"
    assert not np.any([event[-1] < 5 for event in ins.match_events[:3]]), \
        "Some significances were less than 5"

    # Exercise a special-case branch that is required when the last
    # remaining time in a shape is flagged: mask every time but t=0 in the
    # shape's channels, then make t=0 an outlier there.
    ins.metric_array[1:, 7:13] = np.ma.masked
    ins.metric_ms[0, 7:13] = 10
    mf.apply_match_test(ins, event_record=True)
    assert np.all(ins.metric_ms.mask[:, 7:13]), \
        "All the times were not flagged for the shape"
def test_samp_thresh():
    """Check the samp_thresh test and round-tripping of its events to YAML.

    Channels 9 and 10 of a mocked 10-time INS are masked for times 3 onward,
    and an outlier is injected at (t=0, f=11). With ``N_samp_thresh=5`` and
    ``apply_samp_thresh=True``, the outlier is expected to be flagged as
    ``'narrow'``, and times 0-2 of channels 9 and 10 as ``'samp_thresh'``.
    The recorded events must survive a write/read cycle. Asking for more
    samples than the INS has times must raise ``ValueError``.
    """
    obs = '1061313128_99bl_1pol_half_time'
    insfile = os.path.join(DATA_PATH, f'{obs}_SSINS.h5')
    out_prefix = os.path.join(DATA_PATH, f'{obs}_test')
    match_outfile = f'{out_prefix}_SSINS_match_events.yml'

    ins = INS(insfile)

    # Mock a simple metric_array and freq_array
    ins.metric_array = np.ma.ones([10, 20, 1])
    ins.weights_array = np.copy(ins.metric_array)
    ins.metric_ms = ins.mean_subtract()
    ins.sig_array = np.ma.copy(ins.metric_ms)
    ins.freq_array = np.arange(20)

    # Only narrowband matching is configured; streak matching is disabled.
    sig_thresh = {'narrow': 5}
    mf = MF(ins.freq_array, sig_thresh, streak=False, N_samp_thresh=5)

    # Arbitrarily flag enough data in channels 9 and 10
    ins.metric_array[3:, 10] = np.ma.masked
    ins.metric_array[3:, 9] = np.ma.masked
    # Put in an outlier so it gets to samp_thresh_test
    ins.metric_array[0, 11] = 10
    ins.metric_ms = ins.mean_subtract()

    bool_ind = np.zeros(ins.metric_array.shape, dtype=bool)
    bool_ind[:, 10] = 1
    bool_ind[:, 9] = 1
    bool_ind[0, 11] = 1

    mf.apply_match_test(ins, event_record=True, apply_samp_thresh=True)

    test_match_events = [(0, slice(11, 12), 'narrow')]
    test_match_events += [(ind, slice(9, 10), 'samp_thresh') for ind in range(3)]
    test_match_events += [(ind, slice(10, 11), 'samp_thresh') for ind in range(3)]

    # Test stuff
    assert np.all(ins.metric_array.mask == bool_ind), \
        "The right flags were not applied"
    for i, event in enumerate(test_match_events):
        assert ins.match_events[i][:-1] == event, \
            "The events weren't appended correctly"

    # Test that writing with samp_thresh flags is OK. Remove the temporary
    # file even if reading it back fails, so it does not leak into DATA_PATH.
    ins.write(out_prefix, output_type='match_events')
    try:
        test_match_events_read = ins.match_events_read(match_outfile)
    finally:
        os.remove(match_outfile)
    assert ins.match_events == test_match_events_read

    # Test that exception is raised when N_samp_thresh is too high
    # (100 > the 10 times in the mocked INS).
    with pytest.raises(ValueError):
        mf = MF(ins.freq_array, {'narrow': 5, 'streak': 5}, N_samp_thresh=100)
        mf.apply_samp_thresh_test(ins)
def test_combine_ins_errors():
    """Check the warning and error paths of ``util.combine_ins``.

    Two INS built from disjoint baseline subsets of the same observation are
    combined after perturbing one ``sig_array``; this must emit a
    ``UserWarning``. Combining an auto INS with a cross INS must raise
    ``ValueError``. Mismatched pols, frequencies and times must each raise
    ``ValueError`` with a specific message.
    """
    obs = "1061313128_99bl_1pol_half_time"
    testfile = os.path.join(DATA_PATH, f"{obs}.uvfits")
    autofile = os.path.join(DATA_PATH, "1061312640_autos.uvfits")
    mixfile = os.path.join(DATA_PATH, "1061312640_mix.uvfits")

    ss = SS()
    ss.read(testfile, diff=True)

    # Split the baselines into two disjoint subsets to get two spectra
    # with matching metadata.
    all_bls = ss.get_antpairs()
    first_50 = all_bls[:50]
    remaining = all_bls[50:]
    ss_first_50 = ss.select(bls=first_50, inplace=False)
    ss_remaining = ss.select(bls=remaining, inplace=False)
    ins_first_50 = INS(ss_first_50, use_integration_weights=True)
    ins_remaining = INS(ss_remaining, use_integration_weights=True)

    # A perturbed sig_array is expected to produce a warning, not an error.
    ins_first_50.sig_array = ins_first_50.sig_array + 1
    with pytest.warns(UserWarning, match="sig_array attribute"):
        util.combine_ins(ins_first_50, ins_remaining)

    # Mixing spectrum types is rejected.
    ss_autos = SS()
    ss_autos.read(autofile, diff=True)
    ss_cross = SS()
    ss_cross.read(mixfile, diff=True)
    auto_ins = INS(ss_autos, spectrum_type="auto")
    cross_ins = INS(ss_cross, spectrum_type="cross")
    with pytest.raises(ValueError, match="ins1 is of type"):
        util.combine_ins(auto_ins, cross_ins)

    # The mismatches below are applied cumulatively to ins_remaining.
    # Each expected message therefore only appears if combine_ins checks
    # times before frequencies, and frequencies before pols.
    ins_remaining.polarization_array = np.array([-6])
    with pytest.raises(ValueError, match="The spectra do not have the same pols"):
        util.combine_ins(ins_first_50, ins_remaining)

    ins_remaining.freq_array = ins_remaining.freq_array + 1
    with pytest.raises(ValueError, match="The spectra do not have the same frequencies"):
        util.combine_ins(ins_first_50, ins_remaining)

    ins_remaining.time_array = ins_remaining.time_array + 1
    with pytest.raises(ValueError, match="The spectra do not have matching time"):
        util.combine_ins(ins_first_50, ins_remaining)