def test_intersect_nan_single():
    old_chunks = ((float('nan'),), (10,))
    new_chunks = ((float('nan'),), (5, 5))
    result = list(intersect_chunks(old_chunks, new_chunks))
    expected = [(((0, slice(0, None, None)), (0, slice(0, 5, None))),),
                (((0, slice(0, None, None)), (0, slice(5, 10, None))),)]
    assert result == expected
def test_intersect_1():
    """Convert 1-D chunks"""
    old = ((10, 10, 10, 10, 10),)
    new = ((25, 5, 20),)
    answer = ((((0, slice(0, 10, None)),), ((1, slice(0, 10, None)),),
               ((2, slice(0, 5, None)),)),
              (((2, slice(5, 10, None)),),),
              (((3, slice(0, 10, None)),), ((4, slice(0, 10, None)),)))
    cross = intersect_chunks(old_chunks=old, new_chunks=new)
    assert answer == cross
def test_intersect_nan():
    old_chunks = ((float('nan'), float('nan')), (8,))
    new_chunks = ((float('nan'), float('nan')), (4, 4))
    result = list(intersect_chunks(old_chunks, new_chunks))
    expected = [(((0, slice(0, None, None)), (0, slice(0, 4, None))),),
                (((0, slice(0, None, None)), (0, slice(4, 8, None))),),
                (((1, slice(0, None, None)), (0, slice(0, 4, None))),),
                (((1, slice(0, None, None)), (0, slice(4, 8, None))),)]
    assert result == expected
def test_intersect_2():
    """Convert 1-D chunks"""
    old = ((20, 20, 20, 20, 20),)
    new = ((58, 4, 20, 18),)
    answer = ((((0, slice(0, 20, None)),), ((1, slice(0, 20, None)),),
               ((2, slice(0, 18, None)),)),
              (((2, slice(18, 20, None)),), ((3, slice(0, 2, None)),)),
              (((3, slice(2, 20, None)),), ((4, slice(0, 2, None)),)),
              (((4, slice(2, 20, None)),),))
    cross = intersect_chunks(old_chunks=old, new_chunks=new)
    assert answer == cross
def test_intersect_nan_long():
    old_chunks = (tuple([float('nan')] * 4), (10,))
    new_chunks = (tuple([float('nan')] * 4), (5, 5))
    result = list(intersect_chunks(old_chunks, new_chunks))
    expected = [(((0, slice(0, None, None)), (0, slice(0, 5, None))),),
                (((0, slice(0, None, None)), (0, slice(5, 10, None))),),
                (((1, slice(0, None, None)), (0, slice(0, 5, None))),),
                (((1, slice(0, None, None)), (0, slice(5, 10, None))),),
                (((2, slice(0, None, None)), (0, slice(0, 5, None))),),
                (((2, slice(0, None, None)), (0, slice(5, 10, None))),),
                (((3, slice(0, None, None)), (0, slice(0, 5, None))),),
                (((3, slice(0, None, None)), (0, slice(5, 10, None))),)]
    assert result == expected
def test_intersect_chunks_with_nonzero():
    from dask.array.rechunk import intersect_chunks

    old = ((4, 4), (2,))
    new = ((8,), (1, 1))
    result = list(intersect_chunks(old, new))
    expected = [
        (((0, slice(0, 4, None)), (0, slice(0, 1, None))),
         ((1, slice(0, 4, None)), (0, slice(0, 1, None)))),
        (((0, slice(0, 4, None)), (0, slice(1, 2, None))),
         ((1, slice(0, 4, None)), (0, slice(1, 2, None)))),
    ]
    assert result == expected
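# For orientation: each element yielded by intersect_chunks corresponds to
# one new chunk, given as the (old block index, slice) pieces that assemble
# it. A minimal sketch, using the same dask.array.rechunk helper the tests
# above exercise:
from dask.array.rechunk import intersect_chunks

old = ((10, 10),)   # two old blocks of length 10
new = ((5, 15),)    # rechunk into blocks of length 5 and 15
for pieces in intersect_chunks(old, new):
    print(pieces)
# (((0, slice(0, 5, None)),),)                              <- new chunk 0
# (((0, slice(5, 10, None)),), ((1, slice(0, 10, None)),))  <- new chunk 1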
def __init__(self, store, chunk_info):
    self.store = store
    darray = {}
    has_arrays = []
    for array, info in chunk_info.items():
        array_name = store.join(info['prefix'], array)
        chunk_args = (array_name, info['chunks'], info['dtype'])
        darray[array] = store.get_dask_array(*chunk_args)
        # Find all missing chunks in array and convert to 'data_lost' flags
        has_arrays.append((store.has_array(array_name, info['chunks'], info['dtype']),
                           info['chunks']))
    vis = darray['correlator_data']
    base_name = chunk_info['correlator_data']['prefix']
    flags_raw_name = store.join(chunk_info['flags']['prefix'], 'flags_raw')
    # Combine original flags with data_lost indicating where values were lost from
    # other arrays.
    lost = defaultdict(list)  # Maps chunk index to list of index expressions to mark as lost
    for has_array, chunks in has_arrays:
        # array may have fewer dimensions than flags
        # (specifically, for weights_channel).
        if has_array.ndim < darray['flags'].ndim:
            chunks += tuple((x,) for x in darray['flags'].shape[has_array.ndim:])
        intersections = intersect_chunks(darray['flags'].chunks, chunks)
        for has, pieces in zip(has_array.flat, intersections):
            if not has:
                for piece in pieces:
                    chunk_idx, slices = zip(*piece)
                    lost[chunk_idx].append(slices)
    flags = da.map_blocks(_apply_data_lost, darray['flags'], dtype=np.uint8,
                          name=flags_raw_name, lost=lost)
    # Combine low-resolution weights and high-resolution weights_channel
    weights = darray['weights'] * darray['weights_channel'][..., np.newaxis]
    VisFlagsWeights.__init__(self, vis, flags, weights, base_name)
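# _apply_data_lost is not defined in this snippet. A minimal sketch of what
# it could look like, assuming DATA_LOST is the module's lost-data flag bit
# and that da.map_blocks injects the documented block_id keyword so each
# block can look up its own entry in the `lost` mapping built above:
def _apply_data_lost(flags, block_id=None, lost=None):
    # Hypothetical helper: OR DATA_LOST into the regions of this flags
    # block whose source data went missing in another array.
    slices_list = lost.get(block_id, []) if lost else []
    if not slices_list:
        return flags
    flags = flags.copy()  # never mutate the input block in place
    for slices in slices_list:
        flags[slices] |= DATA_LOST
    return flags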
def __init__(self, store, chunk_info, corrprods):
    self.store = store
    self.vis_prefix = chunk_info['correlator_data']['prefix']
    darray = {}
    for array, info in chunk_info.items():
        array_name = store.join(info['prefix'], array)
        chunk_args = (array_name, info['chunks'], info['dtype'])
        errors = DATA_LOST if array == 'flags' else 'none'
        darray[array] = store.get_dask_array(*chunk_args, errors=errors)
    flags_orig_name = darray['flags'].name
    flags_raw_name = store.join(chunk_info['flags']['prefix'], 'flags_raw')
    # Combine original flags with data_lost indicating where values were lost from
    # other arrays.
    lost_map = np.empty([len(c) for c in darray['flags'].chunks], dtype='O')
    for index in np.ndindex(lost_map.shape):
        lost_map[index] = []
    for array_name, array in darray.items():
        if array_name == 'flags':
            continue
        # Source keys may appear multiple times in the array, so to save
        # memory we can pre-create the objects for the keys and reuse them
        # (idea borrowed from dask.array.rechunk).
        src_keys = np.empty([len(c) for c in array.chunks], dtype='O')
        for index in np.ndindex(src_keys.shape):
            src_keys[index] = (array.name,) + index
        # array may have fewer dimensions than flags
        # (specifically, for weights_channel).
        chunks = array.chunks
        if array.ndim < darray['flags'].ndim:
            chunks += tuple((x,) for x in darray['flags'].shape[array.ndim:])
        intersections = intersect_chunks(darray['flags'].chunks, chunks)
        for src_key, pieces in zip(src_keys.flat, intersections):
            for piece in pieces:
                dst_index, slices = zip(*piece)
                # If src_key is missing, then the parts of dst_index
                # indicated by slices must be flagged.
                # TODO: fast path for when slices covers the whole chunk?
                lost_map[dst_index].extend([src_key, slices])
    dsk = {(flags_raw_name,) + key: (_apply_data_lost, (flags_orig_name,) + key, value)
           for key, value in np.ndenumerate(lost_map)}
    dsk = HighLevelGraph.from_collections(flags_raw_name, dsk,
                                          dependencies=list(darray.values()))
    flags = da.Array(dsk, flags_raw_name, chunks=darray['flags'].chunks,
                     shape=darray['flags'].shape, dtype=darray['flags'].dtype)
    darray['flags'] = flags
    # Turn missing blocks in the other arrays into zeros to make them
    # valid dask arrays.
    for array_name, array in darray.items():
        if array_name == 'flags':
            continue
        new_name = 'filled-' + array.name
        indices = itertools.product(*(range(len(c)) for c in array.chunks))
        dsk = {(new_name,) + index: (_default_zero, (array.name,) + index, shape, array.dtype)
               for index, shape in zip(indices, itertools.product(*array.chunks))}
        dsk = HighLevelGraph.from_collections(new_name, dsk, dependencies=[array])
        darray[array_name] = da.Array(dsk, new_name, chunks=array.chunks,
                                      shape=array.shape, dtype=array.dtype)
    vis = darray['correlator_data']
    # Combine low-resolution weights and high-resolution weights_channel
    weights = darray['weights'] * darray['weights_channel'][..., np.newaxis]
    # Scale weights according to power
    if corrprods is not None:
        assert len(corrprods) == vis.shape[2]
        # Ensure that we have only a single chunk on the baseline axis.
        if len(vis.chunks[2]) > 1:
            vis = vis.rechunk({2: vis.shape[2]})
        if len(weights.chunks[2]) > 1:
            weights = weights.rechunk({2: weights.shape[2]})
        auto_indices, index1, index2 = corrprod_to_autocorr(corrprods)
        weights = da.blockwise(weight_power_scale, 'ijk', vis, 'ijk', weights, 'ijk',
                               dtype=np.float32, auto_indices=auto_indices,
                               index1=index1, index2=index2)
    VisFlagsWeights.__init__(self, vis, flags, weights, self.vis_prefix)
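# _default_zero is likewise not shown. A minimal sketch under the assumption
# that get_dask_array(..., errors='none') yields None for chunks the store
# could not read, so the graph above must replace them with zero blocks of
# the right shape and dtype:
def _default_zero(chunk, shape, dtype):
    # Hypothetical helper: substitute a zero block for a missing chunk.
    return chunk if chunk is not None else np.zeros(shape, dtype)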
def __init__(self, store, chunk_info, corrprods=None,
             stored_weights_are_scaled=True, van_vleck='off', index=()):
    self.store = store
    self.vis_prefix = chunk_info['correlator_data']['prefix']
    darray = {}
    for array, info in chunk_info.items():
        array_name = store.join(info['prefix'], array)
        chunk_args = (array_name, info['chunks'], info['dtype'])
        errors = DATA_LOST if array == 'flags' else 'placeholder'
        darray[array] = store.get_dask_array(*chunk_args, index=index, errors=errors)
    flags_orig_name = darray['flags'].name
    flags_raw_name = store.join(chunk_info['flags']['prefix'], 'flags_raw')
    # Combine original flags with data_lost indicating where values were lost from
    # other arrays.
    lost_map = np.empty([len(c) for c in darray['flags'].chunks], dtype='O')
    for index in np.ndindex(lost_map.shape):
        lost_map[index] = []
    for array_name, array in darray.items():
        if array_name == 'flags':
            continue
        # Source keys may appear multiple times in the array, so to save
        # memory we can pre-create the objects for the keys and reuse them
        # (idea borrowed from dask.array.rechunk).
        src_keys = np.empty([len(c) for c in array.chunks], dtype='O')
        for index in np.ndindex(src_keys.shape):
            src_keys[index] = (array.name,) + index
        # array may have fewer dimensions than flags
        # (specifically, for weights_channel).
        chunks = array.chunks
        if array.ndim < darray['flags'].ndim:
            chunks += tuple((x,) for x in darray['flags'].shape[array.ndim:])
        intersections = intersect_chunks(darray['flags'].chunks, chunks)
        for src_key, pieces in zip(src_keys.flat, intersections):
            for piece in pieces:
                dst_index, slices = zip(*piece)
                # If src_key is missing, then the parts of dst_index
                # indicated by slices must be flagged.
                # TODO: fast path for when slices covers the whole chunk?
                lost_map[dst_index].extend([src_key, slices])
    dsk = {(flags_raw_name,) + key: (_apply_data_lost, (flags_orig_name,) + key, value)
           for key, value in np.ndenumerate(lost_map)}
    dsk = HighLevelGraph.from_collections(flags_raw_name, dsk,
                                          dependencies=list(darray.values()))
    flags = da.Array(dsk, flags_raw_name, chunks=darray['flags'].chunks,
                     shape=darray['flags'].shape, dtype=darray['flags'].dtype)
    darray['flags'] = flags
    # Turn missing blocks in the other arrays into zeros to make them
    # valid dask arrays.
    for array_name, array in darray.items():
        if array_name == 'flags':
            continue
        new_name = 'filled-' + array.name
        indices = itertools.product(*(range(len(c)) for c in array.chunks))
        dsk = {(new_name,) + index: (_default_zero, (array.name,) + index)
               for index in indices}
        dsk = HighLevelGraph.from_collections(new_name, dsk, dependencies=[array])
        darray[array_name] = da.Array(dsk, new_name, chunks=array.chunks,
                                      shape=array.shape, dtype=array.dtype)
    # Optionally correct visibilities for quantisation effects
    vis = darray['correlator_data']
    if van_vleck == 'autocorr':
        vis = correct_autocorr_quantisation(vis, corrprods)
    elif van_vleck != 'off':
        raise ValueError("The van_vleck parameter should be one of ['off', 'autocorr'], "
                         f"got '{van_vleck}' instead")
    # Combine low-resolution weights and high-resolution weights_channel
    stored_weights = darray['weights'] * darray['weights_channel'][..., np.newaxis]
    # Scale weights according to power (or remove scaling if already applied)
    if corrprods is not None:
        if stored_weights_are_scaled:
            weights = stored_weights
            unscaled_weights = _scale_weights(vis, stored_weights, corrprods, divide=False)
        else:
            weights = _scale_weights(vis, stored_weights, corrprods, divide=True)
            unscaled_weights = stored_weights
    else:
        if not stored_weights_are_scaled:
            raise ValueError('Stored weights are unscaled but no corrprods are provided')
        weights = stored_weights
        # Don't bother with unscaled weights (it's optional)
        unscaled_weights = None
    VisFlagsWeights.__init__(self, vis, flags, weights, unscaled_weights, self.vis_prefix)
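# _scale_weights is not shown in this version. A hedged sketch, reusing the
# blockwise pattern from the previous version above; the `divide` keyword on
# weight_power_scale is an assumption, intended to either apply the power
# scaling or remove an already-applied one:
def _scale_weights(vis, weights, corrprods, divide):
    # Hypothetical helper: scale (or unscale) weights by autocorrelation power.
    # A single chunk on the baseline axis lets each block pair every baseline
    # with its two autocorrelations.
    if len(vis.chunks[2]) > 1:
        vis = vis.rechunk({2: vis.shape[2]})
    if len(weights.chunks[2]) > 1:
        weights = weights.rechunk({2: weights.shape[2]})
    auto_indices, index1, index2 = corrprod_to_autocorr(corrprods)
    return da.blockwise(weight_power_scale, 'ijk', vis, 'ijk', weights, 'ijk',
                        dtype=np.float32, auto_indices=auto_indices,
                        index1=index1, index2=index2, divide=divide)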