예제 #1
0
파일: models.py 프로젝트: xulunk/TomograPy
def mask_object(cube, decimate=False, remove_nan=False, **kwargs):
    """
    Build the object-space masking operator for ``cube``.

    The mask is only computed when at least one of the ``obj_rmin`` /
    ``obj_rmax`` keyword arguments is given (and not None).

    Parameters
    ----------
    cube : object-space cube; its ``dtype`` is used for the operator.
    decimate : bool
        Build the operator with ``lo.decimate`` instead of ``lo.ndmask``.
    remove_nan : bool
        Also selects ``lo.decimate`` (see the comment in the body).
    **kwargs :
        Forwarded to ``solar.define_map_mask``; ``obj_rmin`` and
        ``obj_rmax`` are also read here to decide whether to mask.

    Returns
    -------
    Mo : masking linear operator, or None when no radius bound is given.
    obj_mask : mask array from ``solar.define_map_mask``, or None when no
        radius bound is given.
    """
    obj_rmin = kwargs.get('obj_rmin', None)
    obj_rmax = kwargs.get('obj_rmax', None)
    if obj_rmin is None and obj_rmax is None:
        # Nothing to mask: return explicit Nones instead of falling through
        # to a return of never-assigned names.
        return None, None
    obj_mask = solar.define_map_mask(cube, **kwargs)
    # decimate is mandatory to remove nan because NaN * 0 = NaN
    if decimate or remove_nan:
        Mo = lo.decimate(obj_mask, dtype=cube.dtype)
    else:
        Mo = lo.ndmask(obj_mask, dtype=cube.dtype)
    return Mo, obj_mask
예제 #2
0
파일: models.py 프로젝트: xulunk/TomograPy
def stsrt(data, cube, **kwargs):
    """
    Smooth Temporal Solar Rotational Tomography.

    Assumes data is sorted by observation time 'DATE_OBS'.

    Parameters
    ----------
    data : data cube whose ``header`` is an iterable of per-image headers,
        each with a 'DATE_OBS' entry.
    cube : object-space cube with a ``header``; it is repeated along a new
        trailing axis, once per temporal group.
    **kwargs :
        dt_min : interval passed to ``solar.temporal_groups_indexes``.
            Default: one more than the largest gap between consecutive
            observation times, i.e. every image is separated.
        obj_rmin, obj_rmax : object masking is applied when either is given.
        decimate, remove_nan : select ``lo.decimate`` instead of
            ``lo.ndmask`` for the object mask operator.
        height_prior : forwarded to ``smoothness_prior`` (default False).
        All keywords are also forwarded to ``solar.define_data_mask`` and,
        when masking, to ``mask_object``.

    Returns
    -------
    P : The projector with masking
    D : Smoothness priors
    obj_mask : object mask array repeated once per temporal group, or None
        when neither obj_rmin nor obj_rmax is given
    data_mask : data mask array
    """
    # Parse kwargs.
    obj_rmin = kwargs.get('obj_rmin', None)
    obj_rmax = kwargs.get('obj_rmax', None)
    # mask data
    data_mask = solar.define_data_mask(data, **kwargs)
    # define temporal groups
    times = [solar.convert_time(h['DATE_OBS']) for h in data.header]
    # If no interval is given, separate every image. The default is only
    # computed when needed: np.max of the empty diff of a single-image
    # sequence raises, which must not happen when dt_min is supplied.
    if 'dt_min' in kwargs:
        dt_min = kwargs['dt_min']
    else:
        dt_min = np.max(np.diff(times)) + 1
    ind = solar.temporal_groups_indexes(data, dt_min)
    n = len(ind)
    # define new 4D cube: one copy of the 3D cube per temporal group
    cube4 = cube[..., np.newaxis].repeat(n, axis=-1)
    cube4.header = copy.copy(cube.header)
    cube4.header['NAXIS'] = 4
    cube4.header['NAXIS4'] = cube4.shape[3]
    # define 4d model
    # XXX assumes all groups have same number of elements.
    # Floor division keeps the per-group image count an integer under
    # Python 3 (true division would hand siddon4d_lo a float).
    ng = data.shape[-1] // n
    P = siddon4d_lo(data.header, cube4.header, ng=ng, mask=data_mask, obstacle="sun")
    # priors
    D = smoothness_prior(cube4, kwargs.get("height_prior", False))
    # mask object
    if obj_rmin is not None or obj_rmax is not None:
        # Only the 3D mask array is needed; the operator is rebuilt below
        # on the time-repeated 4D mask.
        _, obj_mask = mask_object(cube, **kwargs)
        obj_mask = obj_mask[..., np.newaxis].repeat(n, axis=-1)
        # Match the cube dtype, as mask_object does for the 3D operator.
        if kwargs.get("decimate", False) or kwargs.get("remove_nan", False):
            Mo = lo.decimate(obj_mask, dtype=cube.dtype)
        else:
            Mo = lo.ndmask(obj_mask, dtype=cube.dtype)
        P = P * Mo.T
        D = [Di * Mo.T for Di in D]
    else:
        obj_mask = None
    return P, D, obj_mask, data_mask