def test_construction_from_url(self):
    """Open the same fake data set directly, from an RDB file and from a file URL."""
    view, cbid, sn, _, _ = make_fake_data_source(self.telstate, self.store, (20, 16, 40))
    reference = TelstateDataSource(view, cbid, sn, self.store)
    # Lay the RDB file out as e.g. 'tempdir/cb/cb_sdp_l0.rdb', so that
    # 'tempdir' plays the role of a real S3 bucket
    bucket_dir = os.path.join(self.tempdir, cbid)
    os.mkdir(bucket_dir)
    rdb_path = os.path.join(bucket_dir, f'{cbid}_{sn}.rdb')
    # metawriter puts the CBID and stream name at the top level, so mimic that
    self.telstate['capture_block_id'] = cbid
    self.telstate['stream_name'] = sn
    with RDBWriter(rdb_path) as writer:
        writer.save(self.telstate)
    # A plain filename is enough: the chunk store is inferred automatically
    from_file = open_data_source(rdb_path)
    assert_telstate_data_source_equal(from_file, reference)
    # The capture_block_id and stream name can also be given as query parameters
    params = {'capture_block_id': cbid, 'stream_name': sn}
    url = urllib.parse.urlunparse(
        ('file', '', rdb_path, '', urllib.parse.urlencode(params), ''))
    from_url = TelstateDataSource.from_url(url, chunk_store=self.store)
    assert_telstate_data_source_equal(from_url, reference)
    # Invalid sources: an unsupported scheme, and the RDB path minus its '.rdb' suffix
    for bad_source in ('ftp://unsupported', rdb_path[:-4]):
        with assert_raises(DataSourceNotFound):
            open_data_source(bad_source)
def test_van_vleck(self):
    """Exercise the `van_vleck` option: 'off', 'autocorr' and an invalid value."""
    view, cbid, sn, l0_data, _ = make_fake_data_source(
        self.telstate, self.store, (20, 16, 40))
    # With the correction off, the visibilities are exactly what was stored
    uncorrected_source = TelstateDataSource(view, cbid, sn, self.store, van_vleck='off')
    uncorrected = uncorrected_source.data.vis
    np.testing.assert_array_equal(uncorrected.compute(), l0_data['correlator_data'])
    # With 'autocorr', they match the raw visibilities passed through
    # correct_autocorr_quantisation
    corrected_source = TelstateDataSource(view, cbid, sn, self.store, van_vleck='autocorr')
    corrected = corrected_source.data.vis
    reference = correct_autocorr_quantisation(uncorrected, view['bls_ordering'])
    np.testing.assert_array_equal(corrected.compute(), reference.compute())
    # Any other value is rejected
    with assert_raises(ValueError):
        TelstateDataSource(view, cbid, sn, self.store, van_vleck='blah')
def test_upgrade_flags(self):
    """Flags come from L1 by default and from L0 when upgrading is disabled."""
    view, cbid, sn, l0_data, l1_flags_data = make_fake_datasource(
        self.telstate, self.store, self.cbid, (20, 16, 40))
    # (extra constructor keywords, flags expected from the resulting source)
    scenarios = [
        ({}, l1_flags_data['flags']),
        # Again, now explicitly disabling the upgrade
        ({'upgrade_flags': False}, l0_data['flags']),
    ]
    for extra_kwargs, wanted_flags in scenarios:
        data_source = TelstateDataSource(view, cbid, sn, self.store, **extra_kwargs)
        # The visibilities are the L0 data in both scenarios
        np.testing.assert_array_equal(data_source.data.vis.compute(),
                                      l0_data['correlator_data'])
        np.testing.assert_array_equal(data_source.data.flags.compute(), wanted_flags)
def test_timestamps(self):
    """Check the timestamps reported for a 20-dump fake data set."""
    view, cbid, sn, _, _ = make_fake_datasource(
        self.telstate, self.store, self.cbid, (20, 64, 40))
    data_source = TelstateDataSource(view, cbid, sn, self.store)
    # Build the expectation in float64. A float32 array plus a Python int
    # stays float32 under NumPy 2 promotion rules (NEP 50), and float32
    # cannot represent these epoch values exactly.
    expected_timestamps = np.arange(20, dtype=np.float64) * 2 + 123456912
    np.testing.assert_array_equal(data_source.timestamps, expected_timestamps)
def test_upgrade_flags_extend_l0(self, l0_chunk_overrides=None, l1_flags_chunk_overrides=None):
    """L1 flags has more dumps than L0

    The data source is expected to span the longer L1 flags stream, with
    zeroed visibilities and DATA_LOST flags on the dumps that L0 lacks.

    Both parameters are forwarded unchanged to `make_fake_data_source`.
    """
    l0_shape = (18, 16, 40)
    l1_flags_shape = (20, 16, 40)
    n_l0_dumps = l0_shape[0]
    n_l1_dumps = l1_flags_shape[0]
    view, cbid, sn, l0_data, l1_flags_data = make_fake_data_source(
        self.telstate, self.store, l0_shape, l1_flags_shape=l1_flags_shape,
        l0_chunk_overrides=l0_chunk_overrides,
        l1_flags_chunk_overrides=l1_flags_chunk_overrides)
    data_source = TelstateDataSource(view, cbid, sn, self.store)
    # Build the expectation in float64. A float32 array plus a Python int
    # stays float32 under NumPy 2 promotion rules (NEP 50), and float32
    # cannot represent these epoch values exactly.
    expected_timestamps = np.arange(n_l1_dumps, dtype=np.float64) * 2 + 1600000123
    np.testing.assert_array_equal(data_source.timestamps, expected_timestamps)
    expected_vis = np.zeros(l1_flags_shape, l0_data['correlator_data'].dtype)
    expected_vis[:n_l0_dumps] = l0_data['correlator_data']
    expected_flags = l1_flags_data['flags'].copy()
    # The visibilities for this extension are lost, so the flags will mark it as such
    expected_flags[n_l0_dumps:n_l1_dumps] |= DATA_LOST
    np.testing.assert_array_equal(data_source.data.vis.compute(), expected_vis)
    np.testing.assert_array_equal(data_source.data.flags.compute(), expected_flags)
def test_upgrade_flags_extend_l1(self, l0_chunk_overrides=None, l1_flags_chunk_overrides=None):
    """L1 flags has fewer dumps than L0

    The data source is expected to span the longer L0 stream, with the
    dumps missing from L1 flags marked as DATA_LOST.

    Both parameters are forwarded unchanged to `make_fake_data_source`.
    """
    l0_shape = (20, 16, 40)
    l1_flags_shape = (18, 16, 40)
    n_l1_dumps = l1_flags_shape[0]
    view, cbid, sn, l0_data, l1_flags_data = make_fake_data_source(
        self.telstate, self.store, l0_shape, l1_flags_shape=l1_flags_shape,
        l0_chunk_overrides=l0_chunk_overrides,
        l1_flags_chunk_overrides=l1_flags_chunk_overrides)
    data_source = TelstateDataSource(view, cbid, sn, self.store)
    # Build the expectation in float64. A float32 array plus a Python int
    # stays float32 under NumPy 2 promotion rules (NEP 50), and float32
    # cannot represent these epoch values exactly.
    expected_timestamps = np.arange(l0_shape[0], dtype=np.float64) * 2 + 1600000123
    np.testing.assert_array_equal(data_source.timestamps, expected_timestamps)
    np.testing.assert_array_equal(data_source.data.vis.compute(), l0_data['correlator_data'])
    expected_flags = np.zeros(l0_shape, np.uint8)
    expected_flags[:n_l1_dumps] = l1_flags_data['flags']
    # Dumps beyond the end of the L1 flags carry only the DATA_LOST bit
    expected_flags[n_l1_dumps:] = DATA_LOST
    np.testing.assert_array_equal(data_source.data.flags.compute(), expected_flags)
def test_upgrade_flags_shape_mismatch(self):
    """L1 flags shape is incompatible with L0"""
    # The second axis differs (16 vs 8) and constructing the source must fail
    fake = make_fake_datasource(self.telstate, self.store, self.cbid,
                                (18, 16, 40), (20, 8, 40))
    view, cbid, sn, _, _ = fake
    with assert_raises(ValueError):
        TelstateDataSource(view, cbid, sn, self.store)
def test_timestamps_preselect(self):
    """Check that preselecting dumps 2..9 restricts the timestamps to match."""
    view, cbid, sn, _, _ = make_fake_data_source(
        self.telstate, self.store, (20, 64, 40))
    data_source = TelstateDataSource(view, cbid, sn, self.store,
                                     preselect=dict(dumps=np.s_[2:10]))
    # Build the expectation in float64. A float32 array plus a Python int
    # stays float32 under NumPy 2 promotion rules (NEP 50), and float32
    # cannot represent these epoch values exactly.
    expected_timestamps = np.arange(2, 10, dtype=np.float64) * 2 + 1600000123
    np.testing.assert_array_equal(data_source.timestamps, expected_timestamps)
def test_rdb_support(self):
    """Build a data source from an RDB file uploaded next to the chunks."""
    telstate = katsdptelstate.TelescopeState()
    view, cbid, sn, _, _ = make_fake_data_source(telstate, self.store, (5, 16, 40), PREFIX)
    telstate['capture_block_id'] = cbid
    telstate['stream_name'] = sn
    # RDBWriter wants a filename rather than a handle, so go via a temp file
    rdb_name = f'{cbid}_{sn}.rdb'
    local_path = os.path.join(self.tempdir, rdb_name)
    with RDBWriter(local_path) as writer:
        writer.save(telstate)
    # Slurp the file contents and upload them to S3
    with open(local_path, mode='rb') as local_file:
        payload = local_file.read()
    remote_url = urllib.parse.urljoin(self.store_url, self.store.join(cbid, rdb_name))
    self.store.create_array(cbid)
    self.store.complete_request('PUT', remote_url, data=payload)
    # The URL alone (with auto chunk store) must yield the same data source
    # as direct construction
    source_from_url = TelstateDataSource.from_url(remote_url, **self.store_kwargs)
    source_direct = TelstateDataSource(view, cbid, sn, self.store)
    assert_telstate_data_source_equal(source_from_url, source_direct)
def test_preselect(self):
    """Preselecting dumps and channels slices both vis and flags accordingly."""
    view, cbid, sn, l0_data, _ = make_fake_data_source(
        self.telstate, self.store, (20, 64, 40))
    dump_range = np.s_[2:10]
    channel_range = np.s_[-20:]
    data_source = TelstateDataSource(
        view, cbid, sn, self.store, upgrade_flags=False,
        preselect={'dumps': dump_range, 'channels': channel_range})
    # The same selection expressed as a plain NumPy index on the first two axes
    selection = (dump_range, channel_range)
    np.testing.assert_array_equal(data_source.data.vis.compute(),
                                  l0_data['correlator_data'][selection])
    np.testing.assert_array_equal(data_source.data.flags.compute(),
                                  l0_data['flags'][selection])
def test_bad_preselect(self):
    """Check that unsupported `preselect` values raise IndexError."""
    view, cbid, sn, _, _ = make_fake_data_source(
        self.telstate, self.store, (20, 64, 40))
    bad_preselects = [
        dict(dumps=np.s_[[1, 2]]),     # a list of indices instead of a slice
        dict(dumps=np.s_[5:0:-1]),     # a slice with a negative step
        # 'frequencies' is not a key used elsewhere in these tests
        # (presumably an unknown dimension name - verify against TelstateDataSource)
        dict(frequencies=np.s_[:]),
    ]
    for preselect in bad_preselects:
        # The constructor itself must fail, so its result is never bound
        with assert_raises(IndexError):
            TelstateDataSource(view, cbid, sn, self.store, preselect=preselect)
def test_basic_timestamps(self):
    """Check name and timestamps of a data source built without a chunk store."""
    # Add a sensor to telstate to exercise the relevant code paths in TelstateDataSource
    self.telstate.add('obs_script_log', 'Digitisers synced', ts=1600000000.)
    view, cbid, sn, _, _ = make_fake_data_source(self.telstate, self.store, (20, 64, 40))
    data_source = TelstateDataSource(view, cbid, sn, chunk_store=None, source_name='hello')
    assert 'hello' in data_source.name
    # Without a chunk store there is no visibility data, only metadata
    assert data_source.data is None
    # Build the expectation in float64. A float32 array plus a Python int
    # stays float32 under NumPy 2 promotion rules (NEP 50), and float32
    # cannot represent these epoch values exactly.
    expected_timestamps = np.arange(20, dtype=np.float64) * 2 + 1600000123
    np.testing.assert_array_equal(data_source.timestamps, expected_timestamps)
def main():
    """Rechunk the arrays of a data set and write them, plus an updated RDB file, to `args.dest`.

    Steps, all driven by the parsed command-line arguments:

    1. Open the source without data to obtain the telstate and capture block ID.
    2. Collect every array of the selected streams, refusing to overwrite an
       existing destination directory.
    3. OR the lost-data flags of each stream (and its source streams) into
       that stream's flags array.
    4. Apply the requested rechunking specs and write all arrays to an
       `NpyFileChunkStore` rooted at `args.dest`.
    5. Update `chunk_info` (and `s3_endpoint_url`) in telstate and save it as
       a new RDB file under `args.dest`.

    Raises
    ------
    RuntimeError
        If a stream has no chunk info, a destination directory already
        exists, or a rechunking spec names an unknown array.
    Exception
        Any exception object returned per chunk by `put_dask_array` is re-raised.
    """
    args = parse_args()
    dask.config.set(num_workers=args.workers)

    # Lightweight open with no data - just to create telstate and identify the CBID
    ds = TelstateDataSource.from_url(args.source, upgrade_flags=False, chunk_store=None)
    # View the CBID, but not any specific stream
    cbid = ds.capture_block_id
    telstate = ds.telstate.root().view(cbid)
    streams = get_streams(telstate, args.streams)

    # Find all arrays in the selected streams, and also ensure we're not
    # trying to write things back on top of an existing dataset.
    arrays = {}
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        try:
            chunk_info = sts['chunk_info']
        except KeyError as exc:
            # Chain the original KeyError so the traceback shows the real cause
            raise RuntimeError('Could not get chunk info for {!r}: {}'.format(
                stream_name, exc)) from exc
        for array_name, array_info in chunk_info.items():
            if args.new_prefix is not None:
                array_info['prefix'] = args.new_prefix + '-' + stream_name.replace('_', '-')
            prefix = array_info['prefix']
            path = os.path.join(args.dest, prefix)
            if os.path.exists(path):
                raise RuntimeError('Directory {!r} already exists'.format(path))
            store = get_chunk_store(args.source, sts, array_name)
            # Older files have dtype as an object that can't be encoded in msgpack
            dtype = np.dtype(array_info['dtype'])
            array_info['dtype'] = np.lib.format.dtype_to_descr(dtype)
            arrays[(stream_name, array_name)] = Array(stream_name, array_name,
                                                      store, array_info)

    # Apply DATA_LOST bits to the flags arrays. This is a less efficient approach than
    # datasources.py, but much simpler.
    for stream_name in streams:
        flags_array = arrays.get((stream_name, 'flags'))
        # dict.get signals absence with None; test for that explicitly rather
        # than relying on the truthiness of the Array object
        if flags_array is None:
            continue
        sources = [stream_name]
        sts = view_capture_stream(telstate, cbid, stream_name)
        sources += sts['src_streams']
        for src_stream in sources:
            if src_stream not in streams:
                continue
            src_ts = view_capture_stream(telstate, cbid, src_stream)
            for array_name in src_ts['chunk_info']:
                if array_name == 'flags' and src_stream != stream_name:
                    # Upgraded flags completely replace the source stream's
                    # flags, rather than augmenting them. Thus, data lost in
                    # the source stream has no effect.
                    continue
                lost_flags = arrays[(src_stream, array_name)].lost_flags
                lost_flags = lost_flags.rechunk(
                    flags_array.data.chunks[:lost_flags.ndim])
                # weights_channel doesn't have a baseline axis
                while lost_flags.ndim < flags_array.data.ndim:
                    lost_flags = lost_flags[..., np.newaxis]
                lost_flags = da.broadcast_to(lost_flags, flags_array.data.shape,
                                             chunks=flags_array.data.chunks)
                flags_array.data |= lost_flags

    # Apply the rechunking specs
    for spec in args.spec:
        key = (spec.stream, spec.array)
        if key not in arrays:
            raise RuntimeError('{}/{} is not a known array'.format(
                spec.stream, spec.array))
        arrays[key].data = arrays[key].data.rechunk({0: spec.time, 1: spec.freq})

    # Write out the new data
    dest_store = NpyFileChunkStore(args.dest)
    stores = []
    for array in arrays.values():
        full_name = dest_store.join(array.chunk_info['prefix'], array.array_name)
        dest_store.create_array(full_name)
        stores.append(dest_store.put_dask_array(full_name, array.data))
        array.chunk_info['chunks'] = array.data.chunks
    stores = da.compute(*stores)
    # put_dask_array returns an array with an exception object per chunk
    for result_set in stores:
        for result in result_set.flat:
            if result is not None:
                raise result

    # Fix up chunk_info for new chunking
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        chunk_info = sts['chunk_info']
        # Only values are replaced, so iterating the dict while assigning is safe
        for array_name in chunk_info:
            chunk_info[array_name] = arrays[(stream_name, array_name)].chunk_info
        sts.wrapped.delete('chunk_info')
        sts.wrapped['chunk_info'] = chunk_info
        # s3_endpoint_url is for the old version of the data
        sts.wrapped.delete('s3_endpoint_url')
        if args.s3_endpoint_url is not None:
            sts.wrapped['s3_endpoint_url'] = args.s3_endpoint_url

    # Write updated RDB file
    url_parts = urllib.parse.urlparse(args.source, scheme='file')
    dest_file = os.path.join(args.dest, args.new_prefix or cbid,
                             os.path.basename(url_parts.path))
    os.makedirs(os.path.dirname(dest_file), exist_ok=True)
    with RDBWriter(dest_file) as writer:
        writer.save(telstate.backend)
def _katdal_open(self, filename, **kwargs):
    """Stand in for katdal.open, ignoring `filename` and serving this object's fake data.

    The keyword arguments are passed to both the data source and the
    data set built on top of it.
    """
    source = TelstateDataSource(
        self.view, self.cbid, self.stream_name, chunk_store=self.store, **kwargs)
    dataset = VisibilityDataV4(source, **kwargs)
    return dataset