def _test_missing_chunks(self, shape, chunk_overrides=None):
    # Put fake dataset into chunk store
    store = NpyFileChunkStore(self.tempdir)
    prefix = 'cb2'
    data, chunk_info = put_fake_dataset(store, prefix, shape, chunk_overrides)
    # Delete some random chunks in each array of the dataset
    missing_chunks = {}
    rs = random.Random(4)
    for array, info in chunk_info.items():
        array_name = store.join(prefix, array)
        slices = da.core.slices_from_chunks(info['chunks'])
        culled_slices = rs.sample(slices, len(slices) // 10 + 1)
        missing_chunks[array] = culled_slices
        for culled_slice in culled_slices:
            chunk_name, shape = store.chunk_metadata(array_name, culled_slice)
            os.remove(os.path.join(store.path, chunk_name) + '.npy')
    vfw = ChunkStoreVisFlagsWeights(store, chunk_info, None)
    assert_equal(vfw.store, store)
    assert_equal(vfw.vis_prefix, prefix)
    # Check that (only) missing chunks have been replaced by zeros
    vis = data['correlator_data']
    for culled_slice in missing_chunks['correlator_data']:
        vis[culled_slice] = 0.
    assert_array_equal(vfw.vis, vis)
    weights = data['weights'] * data['weights_channel'][..., np.newaxis]
    for culled_slice in missing_chunks['weights'] + missing_chunks['weights_channel']:
        weights[culled_slice] = 0.
    assert_array_equal(vfw.weights, weights)
    # Check that (only) missing chunks have been flagged as 'data lost'
    flags = data['flags']
    for culled_slice in missing_chunks['flags']:
        flags[culled_slice] = 0
    for culled_slice in itertools.chain(*missing_chunks.values()):
        flags[culled_slice] |= DATA_LOST
    assert_array_equal(vfw.flags, flags)
def test_missing_chunks(self):
    # Put fake dataset into chunk store
    store = NpyFileChunkStore(self.tempdir)
    base_name = 'cb2'
    shape = (10, 64, 30)
    data, chunk_info = put_fake_dataset(store, base_name, shape)
    # Delete a random chunk in each array of the dataset
    missing_chunks = {}
    rs = random.Random(4)
    for array, info in chunk_info.items():
        array_name = store.join(base_name, array)
        slices = da.core.slices_from_chunks(info['chunks'])
        culled_slice = rs.choice(slices)
        missing_chunks[array] = culled_slice
        chunk_name, shape = store.chunk_metadata(array_name, culled_slice)
        os.remove(os.path.join(store.path, chunk_name) + '.npy')
    vfw = ChunkStoreVisFlagsWeights(store, base_name, chunk_info)
    # Check that (only) missing chunks have been replaced by zeros
    vis = data['correlator_data']
    vis[missing_chunks['correlator_data']] = 0.
    assert_array_equal(vfw.vis, vis)
    weights = data['weights'] * data['weights_channel'][..., np.newaxis]
    weights[missing_chunks['weights']] = 0.
    weights[missing_chunks['weights_channel']] = 0.
    assert_array_equal(vfw.weights, weights)
    # Check that (only) missing chunks have been flagged as 'data lost'
    flags = data['flags']
    flags[missing_chunks['flags']] = 0
    flags[missing_chunks['correlator_data']] |= DATA_LOST
    flags[missing_chunks['weights']] |= DATA_LOST
    flags[missing_chunks['weights_channel']] |= DATA_LOST
    flags[missing_chunks['flags']] |= DATA_LOST
    assert_array_equal(vfw.flags, flags)
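# Illustrative sketch only, not taken from the original suite: a concrete test
# would typically drive the parametrised helper _test_missing_chunks above with
# a shape (and optionally a chunk_overrides mapping, as accepted by
# put_fake_dataset). The method name and shape below are hypothetical.
def test_missing_chunks_larger(self):
    self._test_missing_chunks((100, 256, 30))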
def main():
    args = parse_args()
    dask.config.set(num_workers=args.workers)

    # Lightweight open with no data - just to create telstate and identify the CBID
    ds = TelstateDataSource.from_url(args.source, upgrade_flags=False, chunk_store=None)
    # View the CBID, but not any specific stream
    cbid = ds.capture_block_id
    telstate = ds.telstate.root().view(cbid)
    streams = get_streams(telstate, args.streams)

    # Find all arrays in the selected streams, and also ensure we're not
    # trying to write things back on top of an existing dataset.
    arrays = {}
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        try:
            chunk_info = sts['chunk_info']
        except KeyError as exc:
            raise RuntimeError('Could not get chunk info for {!r}: {}'
                               .format(stream_name, exc))
        for array_name, array_info in chunk_info.items():
            if args.new_prefix is not None:
                array_info['prefix'] = args.new_prefix + '-' + stream_name.replace('_', '-')
            prefix = array_info['prefix']
            path = os.path.join(args.dest, prefix)
            if os.path.exists(path):
                raise RuntimeError('Directory {!r} already exists'.format(path))
            store = get_chunk_store(args.source, sts, array_name)
            # Older files have dtype as an object that can't be encoded in msgpack
            dtype = np.dtype(array_info['dtype'])
            array_info['dtype'] = np.lib.format.dtype_to_descr(dtype)
            arrays[(stream_name, array_name)] = Array(stream_name, array_name, store, array_info)

    # Apply DATA_LOST bits to the flags arrays. This is a less efficient
    # approach than datasources.py, but much simpler.
    for stream_name in streams:
        flags_array = arrays.get((stream_name, 'flags'))
        if not flags_array:
            continue
        sources = [stream_name]
        sts = view_capture_stream(telstate, cbid, stream_name)
        sources += sts['src_streams']
        for src_stream in sources:
            if src_stream not in streams:
                continue
            src_ts = view_capture_stream(telstate, cbid, src_stream)
            for array_name in src_ts['chunk_info']:
                if array_name == 'flags' and src_stream != stream_name:
                    # Upgraded flags completely replace the source stream's
                    # flags, rather than augmenting them. Thus, data lost in
                    # the source stream has no effect.
                    continue
                lost_flags = arrays[(src_stream, array_name)].lost_flags
                lost_flags = lost_flags.rechunk(flags_array.data.chunks[:lost_flags.ndim])
                # weights_channel doesn't have a baseline axis
                while lost_flags.ndim < flags_array.data.ndim:
                    lost_flags = lost_flags[..., np.newaxis]
                lost_flags = da.broadcast_to(lost_flags, flags_array.data.shape,
                                             chunks=flags_array.data.chunks)
                flags_array.data |= lost_flags

    # Apply the rechunking specs
    for spec in args.spec:
        key = (spec.stream, spec.array)
        if key not in arrays:
            raise RuntimeError('{}/{} is not a known array'.format(spec.stream, spec.array))
        arrays[key].data = arrays[key].data.rechunk({0: spec.time, 1: spec.freq})

    # Write out the new data
    dest_store = NpyFileChunkStore(args.dest)
    stores = []
    for array in arrays.values():
        full_name = dest_store.join(array.chunk_info['prefix'], array.array_name)
        dest_store.create_array(full_name)
        stores.append(dest_store.put_dask_array(full_name, array.data))
        array.chunk_info['chunks'] = array.data.chunks
    stores = da.compute(*stores)
    # put_dask_array returns an array with an exception object per chunk
    for result_set in stores:
        for result in result_set.flat:
            if result is not None:
                raise result

    # Fix up chunk_info for new chunking
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        chunk_info = sts['chunk_info']
        for array_name in chunk_info.keys():
            chunk_info[array_name] = arrays[(stream_name, array_name)].chunk_info
        sts.wrapped.delete('chunk_info')
        sts.wrapped['chunk_info'] = chunk_info
        # s3_endpoint_url is for the old version of the data
        sts.wrapped.delete('s3_endpoint_url')
        if args.s3_endpoint_url is not None:
            sts.wrapped['s3_endpoint_url'] = args.s3_endpoint_url

    # Write updated RDB file
    url_parts = urllib.parse.urlparse(args.source, scheme='file')
    dest_file = os.path.join(args.dest, args.new_prefix or cbid,
                             os.path.basename(url_parts.path))
    os.makedirs(os.path.dirname(dest_file), exist_ok=True)
    with RDBWriter(dest_file) as writer:
        writer.save(telstate.backend)
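# Minimal sketch of the usual script entry point, assuming this module is run
# directly as a command-line tool; the guard itself is not shown in the excerpt above.
if __name__ == '__main__':
    main()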
class TestFlagWriterServer(BaseTestWriterServer):
    async def setup_server(self, **arg_overrides) -> FlagWriterServer:
        args = dict(host='127.0.0.1', port=0, loop=self.loop,
                    endpoints=self.endpoints, flag_interface='lo',
                    flags_ibv=False, chunk_store=self.chunk_store,
                    chunk_params=self.chunk_params,
                    telstate=self.telstate.root(),
                    input_name='sdp_l1_flags', output_name='sdp_l1_flags',
                    rename_src={}, s3_endpoint_url=None,
                    max_workers=4, buffer_dumps=2)
        args.update(arg_overrides)
        server = FlagWriterServer(**args)
        await server.start()
        self.addCleanup(server.stop)
        return server

    def setup_ig(self) -> spead2.send.ItemGroup:
        self.cbid = '1234567890'
        n_chans_per_substream = self.telstate['n_chans_per_substream']
        n_bls = self.telstate['n_bls']
        flags = np.random.randint(0, 256, (n_chans_per_substream, n_bls), np.uint8)

        ig = spead2.send.ItemGroup()
        # This is copied and adapted from katsdpcal
        ig.add_item(id=None, name='flags', description="Flags for visibilities",
                    shape=(self.telstate['n_chans_per_substream'], self.telstate['n_bls']),
                    dtype=None, format=[('u', 8)], value=flags)
        ig.add_item(id=None, name='timestamp',
                    description="Seconds since sync time",
                    shape=(), dtype=None, format=[('f', 64)], value=100.0)
        ig.add_item(id=None, name='dump_index', description='Index in time',
                    shape=(), dtype=None, format=[('u', 64)], value=0)
        ig.add_item(id=0x4103, name='frequency',
                    description="Channel index of first channel in the heap",
                    shape=(), dtype=np.uint32, value=0)
        ig.add_item(id=None, name='capture_block_id',
                    description='SDP capture block ID',
                    shape=(None,), dtype=None, format=[('c', 8)], value=self.cbid)
        return ig

    async def stop_server(self) -> None:
        for queue in self.inproc_queues.values():
            queue.stop()
        await self.server.stop()

    async def setUp(self) -> None:
        self.npy_path = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.npy_path)
        self.chunk_store = NpyFileChunkStore(self.npy_path)
        self.telstate = self.setup_telstate('sdp_l1_flags')
        self.telstate['src_streams'] = ['sdp_l0']
        self.chunk_channels = 128
        self.chunk_params = ChunkParams(self.telstate['n_bls'] * self.chunk_channels,
                                        self.chunk_channels)
        self.setup_sleep()
        self.setup_spead()
        self.server = await self.setup_server()
        self.client = await self.setup_client(self.server)
        self.ig = self.setup_ig()

    def _check_chunk_info(self, output_name: str = 'sdp_l1_flags') -> Dict[str, Any]:
        n_chans = self.telstate['n_chans']
        n_bls = self.telstate['n_bls']
        capture_stream = '{}_{}'.format(self.cbid, output_name)
        view = self.telstate.root().view(capture_stream)
        chunk_info = view['chunk_info']
        n_chunks = n_chans // self.chunk_channels
        assert_equal(
            chunk_info,
            {
                'flags': {
                    'prefix': capture_stream.replace('_', '-'),
                    'shape': (1, n_chans, n_bls),
                    'chunks': ((1,), (self.chunk_channels,) * n_chunks, (n_bls,)),
                    'dtype': np.dtype(np.uint8)
                }
            })
        return chunk_info['flags']

    async def test_capture(self, output_name: str = 'sdp_l1_flags') -> None:
        n_chans_per_substream = self.telstate['n_chans_per_substream']
        self.assert_sensor_equals('status', Status.WAIT_DATA)
        self.assert_sensor_equals('capture-block-state', '{}')

        await self.client.request('capture-init', self.cbid)
        self.assert_sensor_equals('capture-block-state',
                                  '{"%s": "CAPTURING"}' % self.cbid)
        await self.send_heap(self.tx[0], self.ig.get_heap())
        self.assert_sensor_equals('status', Status.CAPTURING)
        await self.client.request('capture-done', self.cbid)
        self.assert_sensor_equals('status', Status.CAPTURING)   # Should still be capturing
        self.assert_sensor_equals('capture-block-state', '{}')
        await self.stop_server()

        capture_stream = '{}_{}'.format(self.cbid, output_name)
        prefix = capture_stream.replace('_', '-')
        assert_true(self.chunk_store.is_complete(prefix))

        # Validate the data written
        chunk_info = self._check_chunk_info(output_name)
        data = self.chunk_store.get_dask_array(
            self.chunk_store.join(chunk_info['prefix'], 'flags'),
            chunk_info['chunks'], chunk_info['dtype']).compute()
        n_chans_per_substream = self.telstate['n_chans_per_substream']
        np.testing.assert_array_equal(self.ig['flags'].value[np.newaxis],
                                      data[:, :n_chans_per_substream, :])
        np.testing.assert_equal(0, data[:, n_chans_per_substream:, :])

    async def test_new_name(self) -> None:
        # Replace client and server with different args
        output_name = 'sdp_l1_flags_new'
        rename_src = {'sdp_l0': 'sdp_l0_new'}
        s3_endpoint_url = 'http://new.invalid/'
        await self.server.stop()
        self.server = await self.setup_server(output_name=output_name,
                                              rename_src=rename_src,
                                              s3_endpoint_url=s3_endpoint_url)
        self.client = await self.setup_client(self.server)
        await self.test_capture(output_name)

        telstate_output = self.telstate.root().view(output_name)
        assert_equal(telstate_output['inherit'], 'sdp_l1_flags')
        assert_equal(telstate_output['s3_endpoint_url'], s3_endpoint_url)
        assert_equal(telstate_output['src_streams'], ['sdp_l0_new'])

    async def test_failed_write(self) -> None:
        with mock.patch.object(NpyFileChunkStore, 'put_chunk',
                               side_effect=katdal.chunkstore.StoreUnavailable):
            await self.client.request('capture-init', self.cbid)
            await self.send_heap(self.tx[0], self.ig.get_heap())
            await self.client.request('capture-done', self.cbid)
        self._check_chunk_info()
        self.assert_sensor_equals('device-status', DeviceStatus.FAIL,
                                  {Sensor.Status.ERROR})

    async def test_double_init(self) -> None:
        await self.client.request('capture-init', self.cbid)
        with assert_raises_regex(aiokatcp.FailReply, 'already active'):
            await self.client.request('capture-init', self.cbid)
        self.assert_sensor_equals('capture-block-state',
                                  '{"%s": "CAPTURING"}' % self.cbid)

    async def test_done_without_init(self) -> None:
        with assert_raises_regex(aiokatcp.FailReply, 'unknown'):
            await self.client.request('capture-done', self.cbid)

    async def test_no_data(self) -> None:
        self.assert_sensor_equals('capture-block-state', '{}')
        await self.client.request('capture-init', self.cbid)
        self.assert_sensor_equals('capture-block-state',
                                  '{"%s": "CAPTURING"}' % self.cbid)
        with assert_logs('katsdpdatawriter.flag_writer', 'WARNING'):
            await self.client.request('capture-done', self.cbid)
        self.assert_sensor_equals('capture-block-state', '{}')

    async def test_data_after_done(self) -> None:
        await self.client.request('capture-init', self.cbid)
        await self.client.request('capture-done', self.cbid)
        with assert_logs('katsdpdatawriter.flag_writer', 'WARNING') as cm:
            await self.send_heap(self.tx[0], self.ig.get_heap())
        assert_regex(cm.output[0], 'outside of init/done')