Example No. 1
 def _test_missing_chunks(self, shape, chunk_overrides=None):
     # Put fake dataset into chunk store
     store = NpyFileChunkStore(self.tempdir)
     prefix = 'cb2'
     data, chunk_info = put_fake_dataset(store, prefix, shape, chunk_overrides)
     # Delete some random chunks in each array of the dataset
     missing_chunks = {}
     rs = random.Random(4)
     for array, info in chunk_info.items():
         array_name = store.join(prefix, array)
         slices = da.core.slices_from_chunks(info['chunks'])
         culled_slices = rs.sample(slices, len(slices) // 10 + 1)
         missing_chunks[array] = culled_slices
         for culled_slice in culled_slices:
             chunk_name, shape = store.chunk_metadata(array_name, culled_slice)
             os.remove(os.path.join(store.path, chunk_name) + '.npy')
     vfw = ChunkStoreVisFlagsWeights(store, chunk_info, None)
     assert_equal(vfw.store, store)
     assert_equal(vfw.vis_prefix, prefix)
     # Check that (only) missing chunks have been replaced by zeros
     vis = data['correlator_data']
     for culled_slice in missing_chunks['correlator_data']:
         vis[culled_slice] = 0.
     assert_array_equal(vfw.vis, vis)
     weights = data['weights'] * data['weights_channel'][..., np.newaxis]
     for culled_slice in missing_chunks['weights'] + missing_chunks['weights_channel']:
         weights[culled_slice] = 0.
     assert_array_equal(vfw.weights, weights)
     # Check that (only) missing chunks have been flagged as 'data lost'
     flags = data['flags']
     for culled_slice in missing_chunks['flags']:
         flags[culled_slice] = 0
     for culled_slice in itertools.chain(*missing_chunks.values()):
         flags[culled_slice] |= DATA_LOST
     assert_array_equal(vfw.flags, flags)
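The test above relies on da.core.slices_from_chunks to enumerate one index tuple per stored chunk, so that individual chunk files can be deleted. A minimal standalone sketch of that helper; the chunk layout below is illustrative, not taken from the test:

import dask.array as da

# Chunk layout: two chunks of 2 rows along axis 0, one chunk of 3 columns along axis 1.
chunks = ((2, 2), (3,))
for s in da.core.slices_from_chunks(chunks):
    print(s)
# (slice(0, 2, None), slice(0, 3, None))
# (slice(2, 4, None), slice(0, 3, None))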
Example No. 2
 def test_missing_chunks(self):
     # Put fake dataset into chunk store
     store = NpyFileChunkStore(self.tempdir)
     base_name = 'cb2'
     shape = (10, 64, 30)
     data, chunk_info = put_fake_dataset(store, base_name, shape)
     # Delete a random chunk in each array of the dataset
     missing_chunks = {}
     rs = random.Random(4)
     for array, info in chunk_info.items():
         array_name = store.join(base_name, array)
         slices = da.core.slices_from_chunks(info['chunks'])
         culled_slice = rs.choice(slices)
         missing_chunks[array] = culled_slice
         chunk_name, shape = store.chunk_metadata(array_name, culled_slice)
         os.remove(os.path.join(store.path, chunk_name) + '.npy')
     vfw = ChunkStoreVisFlagsWeights(store, base_name, chunk_info)
     # Check that (only) missing chunks have been replaced by zeros
     vis = data['correlator_data']
     vis[missing_chunks['correlator_data']] = 0.
     assert_array_equal(vfw.vis, vis)
     weights = data['weights'] * data['weights_channel'][..., np.newaxis]
     weights[missing_chunks['weights']] = 0.
     weights[missing_chunks['weights_channel']] = 0.
     assert_array_equal(vfw.weights, weights)
     # Check that (only) missing chunks have been flagged as 'data lost'
     flags = data['flags']
     flags[missing_chunks['flags']] = 0
     # 8 is the DATA_LOST flag bit (the constant used explicitly in Example No. 1)
     flags[missing_chunks['correlator_data']] |= 8
     flags[missing_chunks['weights']] |= 8
     flags[missing_chunks['weights_channel']] |= 8
     flags[missing_chunks['flags']] |= 8
     assert_array_equal(vfw.flags, flags)
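Both tests combine per-sample weights with per-channel weights via NumPy broadcasting before comparing against vfw.weights. A standalone sketch of that broadcast, with illustrative shapes:

import numpy as np

weights = np.ones((10, 64, 30), np.float32)            # (time, channel, baseline)
weights_channel = np.full((10, 64), 0.5, np.float32)   # (time, channel)
# A trailing axis lets the per-channel weights broadcast across the baseline axis.
combined = weights * weights_channel[..., np.newaxis]
assert combined.shape == (10, 64, 30)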
Example No. 3
def main():
    args = parse_args()
    dask.config.set(num_workers=args.workers)

    # Lightweight open with no data - just to create telstate and identify the CBID
    ds = TelstateDataSource.from_url(args.source,
                                     upgrade_flags=False,
                                     chunk_store=None)
    # View the CBID, but not any specific stream
    cbid = ds.capture_block_id
    telstate = ds.telstate.root().view(cbid)
    streams = get_streams(telstate, args.streams)

    # Find all arrays in the selected streams, and also ensure we're not
    # trying to write things back on top of an existing dataset.
    arrays = {}
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        try:
            chunk_info = sts['chunk_info']
        except KeyError as exc:
            raise RuntimeError('Could not get chunk info for {!r}: {}'.format(
                stream_name, exc))
        for array_name, array_info in chunk_info.items():
            if args.new_prefix is not None:
                array_info['prefix'] = (args.new_prefix + '-' +
                                        stream_name.replace('_', '-'))
            prefix = array_info['prefix']
            path = os.path.join(args.dest, prefix)
            if os.path.exists(path):
                raise RuntimeError(
                    'Directory {!r} already exists'.format(path))
            store = get_chunk_store(args.source, sts, array_name)
            # Older files have dtype as an object that can't be encoded in msgpack
            dtype = np.dtype(array_info['dtype'])
            array_info['dtype'] = np.lib.format.dtype_to_descr(dtype)
            arrays[(stream_name, array_name)] = Array(stream_name, array_name,
                                                      store, array_info)

    # Apply DATA_LOST bits to the flags arrays. This is a less efficient approach than
    # datasources.py, but much simpler.
    for stream_name in streams:
        flags_array = arrays.get((stream_name, 'flags'))
        if not flags_array:
            continue
        sources = [stream_name]
        sts = view_capture_stream(telstate, cbid, stream_name)
        sources += sts['src_streams']
        for src_stream in sources:
            if src_stream not in streams:
                continue
            src_ts = view_capture_stream(telstate, cbid, src_stream)
            for array_name in src_ts['chunk_info']:
                if array_name == 'flags' and src_stream != stream_name:
                    # Upgraded flags completely replace the source stream's
                    # flags, rather than augmenting them. Thus, data lost in
                    # the source stream has no effect.
                    continue
                lost_flags = arrays[(src_stream, array_name)].lost_flags
                lost_flags = lost_flags.rechunk(
                    flags_array.data.chunks[:lost_flags.ndim])
                # weights_channel doesn't have a baseline axis
                while lost_flags.ndim < flags_array.data.ndim:
                    lost_flags = lost_flags[..., np.newaxis]
                lost_flags = da.broadcast_to(lost_flags,
                                             flags_array.data.shape,
                                             chunks=flags_array.data.chunks)
                flags_array.data |= lost_flags

    # Apply the rechunking specs
    for spec in args.spec:
        key = (spec.stream, spec.array)
        if key not in arrays:
            raise RuntimeError('{}/{} is not a known array'.format(
                spec.stream, spec.array))
        arrays[key].data = arrays[key].data.rechunk({
            0: spec.time,
            1: spec.freq
        })

    # Write out the new data
    dest_store = NpyFileChunkStore(args.dest)
    stores = []
    for array in arrays.values():
        full_name = dest_store.join(array.chunk_info['prefix'],
                                    array.array_name)
        dest_store.create_array(full_name)
        stores.append(dest_store.put_dask_array(full_name, array.data))
        array.chunk_info['chunks'] = array.data.chunks
    stores = da.compute(*stores)
    # put_dask_array returns an array with an exception object per chunk
    for result_set in stores:
        for result in result_set.flat:
            if result is not None:
                raise result

    # Fix up chunk_info for new chunking
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        chunk_info = sts['chunk_info']
        for array_name in chunk_info.keys():
            chunk_info[array_name] = arrays[(stream_name,
                                             array_name)].chunk_info
        sts.wrapped.delete('chunk_info')
        sts.wrapped['chunk_info'] = chunk_info
        # s3_endpoint_url is for the old version of the data
        sts.wrapped.delete('s3_endpoint_url')
        if args.s3_endpoint_url is not None:
            sts.wrapped['s3_endpoint_url'] = args.s3_endpoint_url

    # Write updated RDB file
    url_parts = urllib.parse.urlparse(args.source, scheme='file')
    dest_file = os.path.join(args.dest, args.new_prefix or cbid,
                             os.path.basename(url_parts.path))
    os.makedirs(os.path.dirname(dest_file), exist_ok=True)
    with RDBWriter(dest_file) as writer:
        writer.save(telstate.backend)
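The rechunking step in main() passes a per-axis dict to dask's rechunk, leaving any axis not mentioned with its existing chunking. A minimal sketch with illustrative shapes and chunk sizes, not the script's real arrays:

import dask.array as da
import numpy as np

x = da.zeros((10, 64, 30), chunks=(1, 16, 30), dtype=np.uint8)
# Only axes 0 and 1 are respecified; axis 2 keeps its existing chunking.
y = x.rechunk({0: 4, 1: 32})
print(y.chunks)   # ((4, 4, 2), (32, 32), (30,))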
Example No. 4
class TestFlagWriterServer(BaseTestWriterServer):
    async def setup_server(self, **arg_overrides) -> FlagWriterServer:
        args = dict(host='127.0.0.1',
                    port=0,
                    loop=self.loop,
                    endpoints=self.endpoints,
                    flag_interface='lo',
                    flags_ibv=False,
                    chunk_store=self.chunk_store,
                    chunk_params=self.chunk_params,
                    telstate=self.telstate.root(),
                    input_name='sdp_l1_flags',
                    output_name='sdp_l1_flags',
                    rename_src={},
                    s3_endpoint_url=None,
                    max_workers=4,
                    buffer_dumps=2)
        args.update(arg_overrides)
        server = FlagWriterServer(**args)
        await server.start()
        self.addCleanup(server.stop)
        return server

    def setup_ig(self) -> spead2.send.ItemGroup:
        self.cbid = '1234567890'
        n_chans_per_substream = self.telstate['n_chans_per_substream']
        n_bls = self.telstate['n_bls']
        flags = np.random.randint(0, 256, (n_chans_per_substream, n_bls),
                                  np.uint8)

        ig = spead2.send.ItemGroup()
        # This is copied and adapted from katsdpcal
        ig.add_item(id=None,
                    name='flags',
                    description="Flags for visibilities",
                    shape=(self.telstate['n_chans_per_substream'],
                           self.telstate['n_bls']),
                    dtype=None,
                    format=[('u', 8)],
                    value=flags)
        ig.add_item(id=None,
                    name='timestamp',
                    description="Seconds since sync time",
                    shape=(),
                    dtype=None,
                    format=[('f', 64)],
                    value=100.0)
        ig.add_item(id=None,
                    name='dump_index',
                    description='Index in time',
                    shape=(),
                    dtype=None,
                    format=[('u', 64)],
                    value=0)
        ig.add_item(id=0x4103,
                    name='frequency',
                    description="Channel index of first channel in the heap",
                    shape=(),
                    dtype=np.uint32,
                    value=0)
        ig.add_item(id=None,
                    name='capture_block_id',
                    description='SDP capture block ID',
                    shape=(None, ),
                    dtype=None,
                    format=[('c', 8)],
                    value=self.cbid)
        return ig

    async def stop_server(self) -> None:
        for queue in self.inproc_queues.values():
            queue.stop()
        await self.server.stop()

    async def setUp(self) -> None:
        self.npy_path = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.npy_path)
        self.chunk_store = NpyFileChunkStore(self.npy_path)
        self.telstate = self.setup_telstate('sdp_l1_flags')
        self.telstate['src_streams'] = ['sdp_l0']
        self.chunk_channels = 128
        self.chunk_params = ChunkParams(
            self.telstate['n_bls'] * self.chunk_channels, self.chunk_channels)
        self.setup_sleep()
        self.setup_spead()
        self.server = await self.setup_server()
        self.client = await self.setup_client(self.server)
        self.ig = self.setup_ig()

    def _check_chunk_info(self,
                          output_name: str = 'sdp_l1_flags') -> Dict[str, Any]:
        n_chans = self.telstate['n_chans']
        n_bls = self.telstate['n_bls']
        capture_stream = '{}_{}'.format(self.cbid, output_name)

        view = self.telstate.root().view(capture_stream)
        chunk_info = view['chunk_info']
        n_chunks = n_chans // self.chunk_channels
        assert_equal(
            chunk_info, {
                'flags': {
                    'prefix': capture_stream.replace('_', '-'),
                    'shape': (1, n_chans, n_bls),
                    'chunks': ((1,), (self.chunk_channels,) * n_chunks, (n_bls,)),
                    'dtype': np.dtype(np.uint8)
                }
            })
        return chunk_info['flags']

    async def test_capture(self, output_name: str = 'sdp_l1_flags') -> None:
        n_chans_per_substream = self.telstate['n_chans_per_substream']
        self.assert_sensor_equals('status', Status.WAIT_DATA)
        self.assert_sensor_equals('capture-block-state', '{}')

        await self.client.request('capture-init', self.cbid)
        self.assert_sensor_equals('capture-block-state',
                                  '{"%s": "CAPTURING"}' % self.cbid)

        await self.send_heap(self.tx[0], self.ig.get_heap())
        self.assert_sensor_equals('status', Status.CAPTURING)

        await self.client.request('capture-done', self.cbid)
        # Should still be capturing
        self.assert_sensor_equals('status', Status.CAPTURING)
        self.assert_sensor_equals('capture-block-state', '{}')
        await self.stop_server()
        capture_stream = '{}_{}'.format(self.cbid, output_name)
        prefix = capture_stream.replace('_', '-')
        assert_true(self.chunk_store.is_complete(prefix))

        # Validate the data written
        chunk_info = self._check_chunk_info(output_name)
        data = self.chunk_store.get_dask_array(
            self.chunk_store.join(chunk_info['prefix'], 'flags'),
            chunk_info['chunks'], chunk_info['dtype']).compute()
        n_chans_per_substream = self.telstate['n_chans_per_substream']
        np.testing.assert_array_equal(self.ig['flags'].value[np.newaxis],
                                      data[:, :n_chans_per_substream, :])
        np.testing.assert_equal(0, data[:, n_chans_per_substream:, :])

    async def test_new_name(self) -> None:
        # Replace client and server with different args
        output_name = 'sdp_l1_flags_new'
        rename_src = {'sdp_l0': 'sdp_l0_new'}
        s3_endpoint_url = 'http://new.invalid/'
        await self.server.stop()
        self.server = await self.setup_server(output_name=output_name,
                                              rename_src=rename_src,
                                              s3_endpoint_url=s3_endpoint_url)
        self.client = await self.setup_client(self.server)
        await self.test_capture(output_name)
        telstate_output = self.telstate.root().view(output_name)
        assert_equal(telstate_output['inherit'], 'sdp_l1_flags')
        assert_equal(telstate_output['s3_endpoint_url'], s3_endpoint_url)
        assert_equal(telstate_output['src_streams'], ['sdp_l0_new'])

    async def test_failed_write(self) -> None:
        with mock.patch.object(NpyFileChunkStore,
                               'put_chunk',
                               side_effect=katdal.chunkstore.StoreUnavailable):
            await self.client.request('capture-init', self.cbid)
            await self.send_heap(self.tx[0], self.ig.get_heap())
            await self.client.request('capture-done', self.cbid)
        self._check_chunk_info()
        self.assert_sensor_equals('device-status', DeviceStatus.FAIL,
                                  {Sensor.Status.ERROR})

    async def test_double_init(self) -> None:
        await self.client.request('capture-init', self.cbid)
        with assert_raises_regex(aiokatcp.FailReply, 'already active'):
            await self.client.request('capture-init', self.cbid)
        self.assert_sensor_equals('capture-block-state',
                                  '{"%s": "CAPTURING"}' % self.cbid)

    async def test_done_without_init(self) -> None:
        with assert_raises_regex(aiokatcp.FailReply, 'unknown'):
            await self.client.request('capture-done', self.cbid)

    async def test_no_data(self) -> None:
        self.assert_sensor_equals('capture-block-state', '{}')
        await self.client.request('capture-init', self.cbid)
        self.assert_sensor_equals('capture-block-state',
                                  '{"%s": "CAPTURING"}' % self.cbid)
        with assert_logs('katsdpdatawriter.flag_writer', 'WARNING'):
            await self.client.request('capture-done', self.cbid)
        self.assert_sensor_equals('capture-block-state', '{}')

    async def test_data_after_done(self) -> None:
        await self.client.request('capture-init', self.cbid)
        await self.client.request('capture-done', self.cbid)
        with assert_logs('katsdpdatawriter.flag_writer', 'WARNING') as cm:
            await self.send_heap(self.tx[0], self.ig.get_heap())
        assert_regex(cm.output[0], 'outside of init/done')
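test_failed_write above injects a storage failure by patching put_chunk with a side_effect exception. A minimal, self-contained sketch of that mocking pattern; the DummyStore class and the exception used here are illustrative stand-ins, not the real katdal chunk-store API:

from unittest import mock


class DummyStore:
    def put_chunk(self, name, chunk):
        return None


# Every call to put_chunk now raises, mimicking an unavailable store.
with mock.patch.object(DummyStore, 'put_chunk',
                       side_effect=IOError('store unavailable')):
    try:
        DummyStore().put_chunk('flags/00000', b'')
    except IOError as exc:
        print('write failed as injected:', exc)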