Example #1
from dask.array.rechunk import intersect_chunks

def test_intersect_nan_single():
    old_chunks = ((float('nan'),), (10,))
    new_chunks = ((float('nan'),), (5, 5))

    result = list(intersect_chunks(old_chunks, new_chunks))
    expected = [(((0, slice(0, None, None)), (0, slice(0, 5, None))),),
                (((0, slice(0, None, None)), (0, slice(5, 10, None))),)]
    assert result == expected
Example #2
from dask.array.rechunk import intersect_chunks

def test_intersect_1():
    """Convert 1-D chunks."""
    old = ((10, 10, 10, 10, 10), )
    new = ((25, 5, 20), )
    answer = ((((0, slice(0, 10, None)), ),
               ((1, slice(0, 10, None)), ),
               ((2, slice(0, 5, None)), )),
              (((2, slice(5, 10, None)), ), ),
              (((3, slice(0, 10, None)), ), ((4, slice(0, 10, None)), )))
    cross = intersect_chunks(old_chunks=old, new_chunks=new)
    assert answer == cross
Example #3
from dask.array.rechunk import intersect_chunks

def test_intersect_nan():
    old_chunks = ((float('nan'), float('nan')), (8, ))
    new_chunks = ((float('nan'), float('nan')), (4, 4))

    result = list(intersect_chunks(old_chunks, new_chunks))
    expected = [(((0, slice(0, None, None)), (0, slice(0, 4, None))), ),
                (((0, slice(0, None, None)), (0, slice(4, 8, None))), ),
                (((1, slice(0, None, None)), (0, slice(0, 4, None))), ),
                (((1, slice(0, None, None)), (0, slice(4, 8, None))), )]
    assert result == expected
Example #4
from dask.array.rechunk import intersect_chunks

def test_intersect_2():
    """Convert 1-D chunks."""
    old = ((20, 20, 20, 20, 20), )
    new = ((58, 4, 20, 18), )
    answer = ((((0, slice(0, 20, None)), ),
               ((1, slice(0, 20, None)), ),
               ((2, slice(0, 18, None)), )),
              (((2, slice(18, 20, None)), ), ((3, slice(0, 2, None)), )),
              (((3, slice(2, 20, None)), ), ((4, slice(0, 2, None)), )),
              (((4, slice(2, 20, None)), ), ))
    cross = intersect_chunks(old_chunks=old, new_chunks=new)
    assert answer == cross
Example #5
from dask.array.rechunk import intersect_chunks

def test_intersect_nan_long():
    old_chunks = (tuple([float('nan')] * 4), (10, ))
    new_chunks = (tuple([float('nan')] * 4), (5, 5))
    result = list(intersect_chunks(old_chunks, new_chunks))
    expected = [(((0, slice(0, None, None)), (0, slice(0, 5, None))), ),
                (((0, slice(0, None, None)), (0, slice(5, 10, None))), ),
                (((1, slice(0, None, None)), (0, slice(0, 5, None))), ),
                (((1, slice(0, None, None)), (0, slice(5, 10, None))), ),
                (((2, slice(0, None, None)), (0, slice(0, 5, None))), ),
                (((2, slice(0, None, None)), (0, slice(5, 10, None))), ),
                (((3, slice(0, None, None)), (0, slice(0, 5, None))), ),
                (((3, slice(0, None, None)), (0, slice(5, 10, None))), )]
    assert result == expected
Example #6
def test_intersect_chunks_with_nonzero():
    from dask.array.rechunk import intersect_chunks

    old = ((4, 4), (2,))
    new = ((8,), (1, 1))
    result = list(intersect_chunks(old, new))
    expected = [
        (
            ((0, slice(0, 4, None)), (0, slice(0, 1, None))),
            ((1, slice(0, 4, None)), (0, slice(0, 1, None))),
        ),
        (
            ((0, slice(0, 4, None)), (0, slice(1, 2, None))),
            ((1, slice(0, 4, None)), (0, slice(1, 2, None))),
        ),
    ]
    assert result == expected
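The tests above pin down the exact data structure: for each new chunk, `intersect_chunks` yields a tuple of pieces, and each piece pairs an old chunk index with a slice for every dimension. As a quick orientation, here is a minimal sketch (assuming a recent dask, where the function is importable exactly as Example #6 does) that prints those pieces for a simple 1-D rechunk:

from dask.array.rechunk import intersect_chunks

old = ((10, 10), )
new = ((5, 15), )
for new_block in intersect_chunks(old, new):
    # new_block is the tuple of pieces for one new chunk; each piece
    # holds one (old block index, slice) pair per dimension.
    print(new_block)
# (((0, slice(0, 5, None)),),)
# (((0, slice(5, 10, None)),), ((1, slice(0, 10, None)),))

The first new chunk (size 5) is cut entirely from old block 0, while the second (size 15) must be assembled from the tail of old block 0 plus all of old block 1.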
Example #7
    def __init__(self, store, chunk_info):
        self.store = store
        darray = {}
        has_arrays = []
        for array, info in chunk_info.items():
            array_name = store.join(info['prefix'], array)
            chunk_args = (array_name, info['chunks'], info['dtype'])
            darray[array] = store.get_dask_array(*chunk_args)
            # Find all missing chunks in array and convert to 'data_lost' flags
            has_arrays.append((store.has_array(array_name, info['chunks'],
                                               info['dtype']), info['chunks']))
        vis = darray['correlator_data']
        base_name = chunk_info['correlator_data']['prefix']
        flags_raw_name = store.join(chunk_info['flags']['prefix'], 'flags_raw')
        # Combine original flags with data_lost indicating where values were
        # lost from other arrays.
        # Maps chunk index to a list of index expressions to mark as lost.
        lost = defaultdict(list)
        for has_array, chunks in has_arrays:
            # array may have fewer dimensions than flags
            # (specifically, for weights_channel).
            if has_array.ndim < darray['flags'].ndim:
                chunks += tuple(
                    (x, ) for x in darray['flags'].shape[has_array.ndim:])
            intersections = intersect_chunks(darray['flags'].chunks, chunks)
            for has, pieces in zip(has_array.flat, intersections):
                if not has:
                    for piece in pieces:
                        chunk_idx, slices = zip(*piece)
                        lost[chunk_idx].append(slices)
        flags = da.map_blocks(_apply_data_lost,
                              darray['flags'],
                              dtype=np.uint8,
                              name=flags_raw_name,
                              lost=lost)
        # Combine low-resolution weights and high-resolution weights_channel
        weights = darray['weights'] * darray['weights_channel'][...,
                                                                np.newaxis]
        VisFlagsWeights.__init__(self, vis, flags, weights, base_name)
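The `chunk_idx, slices = zip(*piece)` line above transposes one piece from per-dimension (index, slice) pairs into a chunk-index tuple and a slice tuple. A standalone illustration of just that idiom, using a made-up two-dimensional piece:

# Illustrative only: one piece as produced by intersect_chunks for a
# two-dimensional array.
piece = ((0, slice(0, 4, None)), (1, slice(0, 1, None)))
chunk_idx, slices = zip(*piece)
assert chunk_idx == (0, 1)
assert slices == (slice(0, 4, None), slice(0, 1, None))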
Example #8
    def __init__(self, store, chunk_info, corrprods):
        self.store = store
        self.vis_prefix = chunk_info['correlator_data']['prefix']
        darray = {}
        for array, info in chunk_info.items():
            array_name = store.join(info['prefix'], array)
            chunk_args = (array_name, info['chunks'], info['dtype'])
            errors = DATA_LOST if array == 'flags' else 'none'
            darray[array] = store.get_dask_array(*chunk_args, errors=errors)
        flags_orig_name = darray['flags'].name
        flags_raw_name = store.join(chunk_info['flags']['prefix'], 'flags_raw')
        # Combine original flags with data_lost indicating where values were lost from
        # other arrays.
        lost_map = np.empty([len(c) for c in darray['flags'].chunks],
                            dtype="O")
        for index in np.ndindex(lost_map.shape):
            lost_map[index] = []
        for array_name, array in darray.items():
            if array_name == 'flags':
                continue
            # Source keys may appear multiple times in the array, so to save
            # memory we can pre-create the objects for the keys and reuse them
            # (idea borrowed from dask.array.rechunk).
            src_keys = np.empty([len(c) for c in array.chunks], dtype="O")
            for index in np.ndindex(src_keys.shape):
                src_keys[index] = (array.name, ) + index
            # array may have fewer dimensions than flags
            # (specifically, for weights_channel).
            chunks = array.chunks
            if array.ndim < darray['flags'].ndim:
                chunks += tuple(
                    (x, ) for x in darray['flags'].shape[array.ndim:])
            intersections = intersect_chunks(darray['flags'].chunks, chunks)
            for src_key, pieces in zip(src_keys.flat, intersections):
                for piece in pieces:
                    dst_index, slices = zip(*piece)
                    # if src_key is missing, then the parts of dst_index
                    # indicated by slices must be flagged.
                    # TODO: fast path for when slices covers the whole chunk?
                    lost_map[dst_index].extend([src_key, slices])
        dsk = {(flags_raw_name, ) + key:
               (_apply_data_lost, (flags_orig_name, ) + key, value)
               for key, value in np.ndenumerate(lost_map)}
        dsk = HighLevelGraph.from_collections(flags_raw_name,
                                              dsk,
                                              dependencies=list(
                                                  darray.values()))
        flags = da.Array(dsk,
                         flags_raw_name,
                         chunks=darray['flags'].chunks,
                         shape=darray['flags'].shape,
                         dtype=darray['flags'].dtype)
        darray['flags'] = flags

        # Turn missing blocks in the other arrays into zeros to make them
        # valid dask arrays.
        for array_name, array in darray.items():
            if array_name == 'flags':
                continue
            new_name = 'filled-' + array.name
            indices = itertools.product(*(range(len(c)) for c in array.chunks))
            dsk = {(new_name, ) + index:
                   (_default_zero, (array.name, ) + index, shape, array.dtype)
                   for index, shape in zip(indices,
                                           itertools.product(*array.chunks))}
            dsk = HighLevelGraph.from_collections(new_name,
                                                  dsk,
                                                  dependencies=[array])
            darray[array_name] = da.Array(dsk,
                                          new_name,
                                          chunks=array.chunks,
                                          shape=array.shape,
                                          dtype=array.dtype)

        vis = darray['correlator_data']
        # Combine low-resolution weights and high-resolution weights_channel
        weights = darray['weights'] * darray['weights_channel'][...,
                                                                np.newaxis]
        # Scale weights according to power
        if corrprods is not None:
            assert len(corrprods) == vis.shape[2]
            # Ensure that we have only a single chunk on the baseline axis.
            if len(vis.chunks[2]) > 1:
                vis = vis.rechunk({2: vis.shape[2]})
            if len(weights.chunks[2]) > 1:
                weights = weights.rechunk({2: weights.shape[2]})
            auto_indices, index1, index2 = corrprod_to_autocorr(corrprods)
            weights = da.blockwise(weight_power_scale,
                                   'ijk',
                                   vis,
                                   'ijk',
                                   weights,
                                   'ijk',
                                   dtype=np.float32,
                                   auto_indices=auto_indices,
                                   index1=index1,
                                   index2=index2)

        VisFlagsWeights.__init__(self, vis, flags, weights, self.vis_prefix)
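`_default_zero` is internal to this codebase and its implementation is not shown here. Given how it is invoked above, with the source chunk key, the expected shape, and the dtype as arguments, a plausible minimal sketch (assuming a chunk that failed to load resolves to None under errors='none') might look like:

import numpy as np

def _default_zero(chunk, shape, dtype):
    # Hypothetical sketch only: pass real chunk data through and
    # substitute zeros for chunks that failed to load.
    if chunk is None:
        return np.zeros(shape, dtype)
    return chunk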
Example #9
    def __init__(self,
                 store,
                 chunk_info,
                 corrprods=None,
                 stored_weights_are_scaled=True,
                 van_vleck='off',
                 index=()):
        self.store = store
        self.vis_prefix = chunk_info['correlator_data']['prefix']
        darray = {}
        for array, info in chunk_info.items():
            array_name = store.join(info['prefix'], array)
            chunk_args = (array_name, info['chunks'], info['dtype'])
            errors = DATA_LOST if array == 'flags' else 'placeholder'
            darray[array] = store.get_dask_array(*chunk_args,
                                                 index=index,
                                                 errors=errors)
        flags_orig_name = darray['flags'].name
        flags_raw_name = store.join(chunk_info['flags']['prefix'], 'flags_raw')
        # Combine original flags with data_lost indicating where values were lost from
        # other arrays.
        lost_map = np.empty([len(c) for c in darray['flags'].chunks],
                            dtype="O")
        for index in np.ndindex(lost_map.shape):
            lost_map[index] = []
        for array_name, array in darray.items():
            if array_name == 'flags':
                continue
            # Source keys may appear multiple times in the array, so to save
            # memory we can pre-create the objects for the keys and reuse them
            # (idea borrowed from dask.array.rechunk).
            src_keys = np.empty([len(c) for c in array.chunks], dtype="O")
            for index in np.ndindex(src_keys.shape):
                src_keys[index] = (array.name, ) + index
            # array may have fewer dimensions than flags
            # (specifically, for weights_channel).
            chunks = array.chunks
            if array.ndim < darray['flags'].ndim:
                chunks += tuple(
                    (x, ) for x in darray['flags'].shape[array.ndim:])
            intersections = intersect_chunks(darray['flags'].chunks, chunks)
            for src_key, pieces in zip(src_keys.flat, intersections):
                for piece in pieces:
                    dst_index, slices = zip(*piece)
                    # if src_key is missing, then the parts of dst_index
                    # indicated by slices must be flagged.
                    # TODO: fast path for when slices covers the whole chunk?
                    lost_map[dst_index].extend([src_key, slices])
        dsk = {(flags_raw_name, ) + key:
               (_apply_data_lost, (flags_orig_name, ) + key, value)
               for key, value in np.ndenumerate(lost_map)}
        dsk = HighLevelGraph.from_collections(flags_raw_name,
                                              dsk,
                                              dependencies=list(
                                                  darray.values()))
        flags = da.Array(dsk,
                         flags_raw_name,
                         chunks=darray['flags'].chunks,
                         shape=darray['flags'].shape,
                         dtype=darray['flags'].dtype)
        darray['flags'] = flags

        # Turn missing blocks in the other arrays into zeros to make them
        # valid dask arrays.
        for array_name, array in darray.items():
            if array_name == 'flags':
                continue
            new_name = 'filled-' + array.name
            indices = itertools.product(*(range(len(c)) for c in array.chunks))
            dsk = {(new_name, ) + index:
                   (_default_zero, (array.name, ) + index)
                   for index in indices}
            dsk = HighLevelGraph.from_collections(new_name,
                                                  dsk,
                                                  dependencies=[array])
            darray[array_name] = da.Array(dsk,
                                          new_name,
                                          chunks=array.chunks,
                                          shape=array.shape,
                                          dtype=array.dtype)

        # Optionally correct visibilities for quantisation effects
        vis = darray['correlator_data']
        if van_vleck == 'autocorr':
            vis = correct_autocorr_quantisation(vis, corrprods)
        elif van_vleck != 'off':
            raise ValueError(
                "The van_vleck parameter should be one of ['off', 'autocorr'], "
                f"got '{van_vleck}' instead")

        # Combine low-resolution weights and high-resolution weights_channel
        stored_weights = darray['weights'] * darray['weights_channel'][
            ..., np.newaxis]
        # Scale weights according to power (or remove scaling if already applied)
        if corrprods is not None:
            if stored_weights_are_scaled:
                weights = stored_weights
                unscaled_weights = _scale_weights(vis,
                                                  stored_weights,
                                                  corrprods,
                                                  divide=False)
            else:
                weights = _scale_weights(vis,
                                         stored_weights,
                                         corrprods,
                                         divide=True)
                unscaled_weights = stored_weights
        else:
            if not stored_weights_are_scaled:
                raise ValueError(
                    'Stored weights are unscaled but no corrprods are provided'
                )
            weights = stored_weights
            # Don't bother with unscaled weights (it's optional)
            unscaled_weights = None
        VisFlagsWeights.__init__(self, vis, flags, weights, unscaled_weights,
                                 self.vis_prefix)
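The weight combination in the last two examples relies on NumPy broadcasting: `weights` carries a trailing baseline axis that `weights_channel` lacks, so the latter gains one via np.newaxis before the multiply. A small self-contained sketch of the same shape arithmetic, with hypothetical (time, channel, baseline) shapes chosen purely for illustration:

import numpy as np

weights = np.ones((2, 4, 3), dtype=np.float32)            # (time, channel, baseline)
weights_channel = np.full((2, 4), 0.5, dtype=np.float32)  # (time, channel)

# [..., np.newaxis] appends a length-1 baseline axis that broadcasts
# against the full-resolution weights, as in the __init__ methods above.
combined = weights * weights_channel[..., np.newaxis]
assert combined.shape == (2, 4, 3)
assert np.allclose(combined, 0.5)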