Ejemplo n.º 1
0
def interp(table, coord, input):
    '''
    Sample `input` at the rounded coordinates and scale by table weights.

    table - interpolation table, forwarded to `linear_interp`
    coord - coordinates; the size of the last axis is taken as the number
            of image dimensions
    input - tensor whose trailing axes are the image axes,
            e.g. [num_channel, nx, ny]

    Returns the gathered samples with the channel axes first, followed by
    the leading axes of `coord`, multiplied by the interpolation weights.
    '''
    coord_shape = util.get_shape(coord)
    in_shape = util.get_shape(input)

    n_img = coord_shape[-1]          # image dimensions
    n_chan = len(in_shape) - n_img   # channel dimensions
    n_k = len(coord_shape) - 1       # k-space dimensions

    img_shape = in_shape[-n_img:]
    half_sizes = [size // 2 for size in img_shape]

    with tf.name_scope('get_indices'):
        # Nearest grid point, shifted by the image centre and wrapped
        # into the image extent.
        nearest = tf.cast(tf.round(coord), 'int64')
        idx = mod(nearest + half_sizes, img_shape)

    # Image axes go first so that gather_nd indexes them.
    image_first = list(range(n_chan, n_chan + n_img)) + list(range(n_chan))
    input = tf.transpose(input, perm=image_first)

    gathered = tf.gather_nd(input, idx)

    # Put the channel axes back in front of the k-space axes.
    channel_first = list(range(n_k, n_k + n_chan)) + list(range(n_k))
    output = tf.transpose(gathered, perm=channel_first)

    with tf.name_scope('get_weights'):
        offset = abs(tf.round(coord) - coord) * 2.0
        per_axis = linear_interp(table, offset)
        weight = tf.reduce_prod(per_axis, axis=-1)
        output *= weight

    return output
Ejemplo n.º 2
0
    def check_codomain(self, output):
        '''
        Validate `output` against this object's dtype and output shape.

        Raises ValueError when `output.dtype.base_dtype` differs from
        `self.dtype`, or when `util.get_shape(output)` differs from
        `self.oshape`.
        '''
        got_dtype = output.dtype
        if got_dtype.base_dtype != self.dtype:
            message = 'output dtype mismatch, for {}, got {}'
            raise ValueError(message.format(self, got_dtype))

        got_shape = util.get_shape(output)
        if got_shape != self.oshape:
            message = 'output shape mismatch, for {}, got {}'
            raise ValueError(message.format(self, got_shape))
Ejemplo n.º 3
0
    def check_domain(self, input):
        '''
        Validate `input` against this object's dtype and input shape.

        Raises ValueError when `input.dtype.base_dtype` differs from
        `self.dtype`, or when `util.get_shape(input)` differs from
        `self.ishape`.
        '''
        got_dtype = input.dtype
        if got_dtype.base_dtype != self.dtype:
            message = 'input data type mismatch, for {}, got {}'
            raise ValueError(message.format(self, got_dtype))

        got_shape = util.get_shape(input)
        if got_shape != self.ishape:
            message = 'input shape mismatch, for {}, got {}'
            raise ValueError(message.format(self, got_shape))
Ejemplo n.º 4
0
    def __init__(self, ishape, table, coord,
                 shift=None, dtype=tf.complex64):
        '''
        Store the table, coordinates and shift, and derive the shapes.

        ishape - input shape; its trailing axes are the image axes
        table - interpolation table, converted to a tensor
        coord - coordinates, converted to a tensor; the size of the last
                axis is taken as the number of image dimensions
        shift - stored as given on `self.shift`
        dtype - forwarded to the base-class constructor

        The output shape is the leading (non-image) part of `ishape`
        followed by the shape of `coord` without its last axis.
        '''
        coord_tensor = tf.convert_to_tensor(coord)
        table_tensor = tf.convert_to_tensor(table)

        coord_shape = util.get_shape(coord_tensor)
        n_img = coord_shape[-1]
        in_shape = list(ishape)

        self.table = table_tensor
        self.coord = coord_tensor
        self.shift = shift
        self.img_shape = in_shape[-n_img:]

        out_shape = in_shape[:-n_img] + coord_shape[:-1]
        super().__init__(out_shape, in_shape, dtype)
Ejemplo n.º 5
0
    def __init__(self, oshape, ishape, mat, dtype=tf.complex64):
        '''
        Store `mat` as a tensor and check it against the given shapes.

        oshape - output shape (any sequence; copied into a list)
        ishape - input shape (any sequence; copied into a list)
        mat - matrix, converted to a tensor and kept on `self.mat`
        dtype - forwarded to the base-class constructor

        The three shapes are left-padded with ones to a common rank. Every
        axis except the last two must then be compatible between input,
        output and matrix; otherwise ValueError is raised.

        Sets `osum_axes` / `isum_axes` to the padded axes where the output /
        input shape is 1, and `perm` to the permutation that swaps the last
        two of `len(oshape)` axes.
        '''
        # Accept tuples as well as lists: the padding below concatenates
        # with a list, which would raise TypeError for a tuple.
        oshape = list(oshape)
        ishape = list(ishape)

        self.mat = tf.convert_to_tensor(mat)

        mshape = util.get_shape(self.mat)
        max_ndim = max(len(oshape), len(ishape), len(mshape))

        oshape_full = [1] * (max_ndim - len(oshape)) + oshape
        ishape_full = [1] * (max_ndim - len(ishape)) + ishape
        mshape_full = [1] * (max_ndim - len(mshape)) + mshape

        # Check dimension valid.
        for i, o, m in zip(ishape_full[:-2], oshape_full[:-2], mshape_full[:-2]):

            if not ((i == m and o == m) or
                    (i == m and o == 1) or
                    (i == 1 and o == m) or
                    (i == 1 and o == 1) or
                    (i == o and m == 1)):
                # Report the tensor's shape: the raw `mat` argument may be
                # a plain nested list without a `.shape` attribute.
                raise ValueError('Invalid dimensions: {}, {}, {}'.
                                 format(oshape, ishape, mshape))

        self.osum_axes = [i for i in range(max_ndim)
                          if oshape_full[i] == 1]
        self.isum_axes = [i for i in range(max_ndim)
                          if ishape_full[i] == 1]
        ndim = len(oshape)
        self.perm = list(range(ndim - 2)) + [ndim - 1, ndim - 2]

        super().__init__(oshape, ishape, dtype)
Ejemplo n.º 6
0
def interpH(oshape, table, coord, input):
    '''
    Weight `input` with the table and scatter it onto a grid of `oshape`.

    oshape - shape of the result; its trailing axes are the image axes
    table - interpolation table, forwarded to `linear_interp`
    coord - coordinates; the size of the last axis is taken as the number
            of image dimensions
    input - values to scatter

    Returns a tensor of shape `oshape`.
    '''
    coord_shape = util.get_shape(coord)
    n_img = coord_shape[-1]
    n_batch = len(oshape) - n_img
    n_k = len(coord_shape) - 1

    img_shape = oshape[-n_img:]
    batch_shape = oshape[:-n_img]
    half_sizes = [size // 2 for size in img_shape]

    # Nearest grid point, shifted by the image centre and wrapped into
    # the image extent.
    nearest = tf.cast(tf.round(coord), 'int64')
    idx = mod(nearest + half_sizes, img_shape)

    offset = abs(tf.round(coord) - coord) * 2.0
    per_axis = linear_interp(table, offset)
    weight = tf.reduce_prod(per_axis, axis=-1)
    input *= weight

    # k-space axes first so they line up with the leading axes of idx.
    k_first = list(range(n_batch, n_batch + n_k)) + list(range(n_batch))
    input = tf.transpose(input, perm=k_first)
    scattered = tf.scatter_nd(idx, input, img_shape + batch_shape)

    # Move the batch axes back in front of the image axes.
    batch_first = list(range(n_img, n_img + n_batch)) + list(range(n_img))
    return tf.transpose(scattered, perm=batch_first)
Ejemplo n.º 7
0
 def getNumInstances(self, infile, time_context=100, step=None):
     """
     For a single .data file computes the number of examples of size \"time_context\" that can be created

     infile - path of the .data file; the shape is read from the path
              obtained by replacing '.data' with '.shape'
     time_context - number of frames per example
     step - hop between consecutive examples; when omitted, self.step is used

     Returns at least 1.
     """
     # An explicit `step` used to be ignored in favour of self.step;
     # honour it, and keep self.step as the fallback for callers that
     # do not pass one.
     if step is None:
         step = self.step
     shape = util.get_shape(os.path.join(infile.replace('.data', '.shape')))
     length_file = float(shape[0])
     return np.maximum(
         1, int(np.ceil((length_file - time_context) / step)))
Ejemplo n.º 8
0
    def __init__(self, oshape, ishape, mult, dtype=tf.complex64):
        '''
        Store `mult` as a tensor of type `dtype` and find the broadcast axes.

        oshape - output shape
        ishape - input shape
        mult - multiplier, converted to a tensor and cast to `dtype` when
               its base dtype differs
        dtype - forwarded to the base-class constructor

        After left-padding all three shapes with ones to a common rank,
        `osum_axis` lists the axes where the output is 1 but the input or
        the multiplier is larger, and `isum_axis` lists the axes where the
        input is 1 but the output or the multiplier is larger.
        '''
        mult_tensor = tf.convert_to_tensor(mult)
        if mult_tensor.dtype.base_dtype != dtype:
            mult_tensor = tf.cast(mult_tensor, dtype)
        self.mult = mult_tensor
        super().__init__(oshape, ishape, dtype)

        mshape = util.get_shape(self.mult)
        rank = max(len(oshape), len(ishape), len(mshape))

        def _left_pad(shape):
            # Prepend ones so that every shape has `rank` entries.
            return [1] * (rank - len(shape)) + list(shape)

        out_full = _left_pad(oshape)
        in_full = _left_pad(ishape)
        mult_full = _left_pad(mshape)

        self.osum_axis = [
            axis for axis in range(rank)
            if out_full[axis] == 1 and (in_full[axis] > 1 or
                                        mult_full[axis] > 1)
        ]
        self.isum_axis = [
            axis for axis in range(rank)
            if in_full[axis] == 1 and (out_full[axis] > 1 or
                                       mult_full[axis] > 1)
        ]
Ejemplo n.º 9
0
def right_sweep(mps, data, context):
    """
    Update every site of `mps` except the rightmost, moving left to right.

    For each site a fresh local tensor is built from the effective data and
    factored with direction 'R'. The first factor replaces the site and the
    second is contracted into the next site. The cached `combs_l` (when
    present) and `combs_r` entries for the site are cleared, and
    `context['step']` is incremented.

    Returns the (mutated) `mps` and the latest `context`.
    """
    for site in range(len(mps) - 1):
        effective_data, context = prepare_effective_data(
            mps, data, site, context)

        kept, carried = factor_local_tensor(
            fresh_local_tensor(effective_data),
            shape=get_shape(mps, site),
            direction='R',
        )

        neighbour = site + 1
        mps[site] = kept
        mps[neighbour] = np.tensordot(carried, mps[neighbour], axes=[-1, 0])

        # combs_l is not there on the first sweep
        if context.get('combs_l'):
            context['combs_l'][site] = None
        context['combs_r'][site] = None
        context['step'] += 1

    return mps, context
Ejemplo n.º 10
0
def left_sweep(mps, data, context):
    """
    Update every site of `mps` except the leftmost, moving right to left.

    For each site a fresh local tensor is built from the effective data and
    factored with direction 'L'. The first factor is multiplied into the
    previous site and the second replaces the site. The cached `combs_l` and
    `combs_r` entries for the site are cleared, and `context['step']` is
    incremented.

    Returns the (mutated) `mps` and the latest `context`.
    """
    # Visit sites len(mps) - 1 down to 1.
    for site in range(len(mps) - 1, 0, -1):
        effective_data, context = prepare_effective_data(
            mps, data, site, context)

        carried, kept = factor_local_tensor(
            fresh_local_tensor(effective_data),
            shape=get_shape(mps, site),
            direction='L',
        )

        neighbour = site - 1
        mps[neighbour] = np.matmul(mps[neighbour], carried)
        mps[site] = kept

        for cache_key in ('combs_l', 'combs_r'):
            context[cache_key][site] = None
        context['step'] += 1

    return mps, context
Ejemplo n.º 11
0
def NUFFT(ishape, coord, n=128, shifts=None, dtype=tf.complex64):
    '''
    Assemble a NUFFT operator as a sum of shifted interpolation chains.

    ishape : [batch, nz, ny, nx] or [batch, ny, nx]
    coord : tensor of shape [..., ndim]
    n : number of samples in the Kaiser table (a trailing 0 is appended)
    shifts : iterable of per-axis shifts; defaults to every combination
             of 0 and 0.5 over the image axes
    dtype : forwarded to every constituent operator

    Fixed settings: width = 2.0, oversamp = 2.

    Each shift contributes Interp * FFT * LinearPhase * KaiserApodize; the
    contributions are combined with AddN.
    '''
    coord = tf.convert_to_tensor(coord)
    ishape = list(ishape)
    ndim = util.get_shape(coord)[-1]
    beta = np.pi * (1.5 ** 2 - 0.8) ** 0.5
    assert len(ishape) == ndim + 1

    # Apodization
    apodize = KaiserApodize(ishape, beta, ndim, dtype=dtype)

    # Interpolation table
    table_values = np.concatenate([kb(np.arange(n) / n, 2.0, beta), [0]])
    kaiser_table = tf.constant(table_values,
                               dtype='float32',
                               name='kaiser_table')

    # FFT
    fft = FFT(ishape, ndim=ndim, dtype=dtype)

    if shifts is None:
        shifts = list(product([0, 0.5], repeat=ndim))

    img_shape = ishape[-ndim:]
    branches = []
    for shift in shifts:
        # Shift expressed as a fraction of each image axis length.
        freq_shift = [s / size for s, size in zip(shift, img_shape)]

        phase = LinearPhase(ishape, freq_shift, dtype=dtype)
        sampler = Interp(ishape, kaiser_table, coord, shift=shift, dtype=dtype)

        branches.append(sampler * fft * phase * apodize)

    operator = AddN(branches)
    operator.add_name_scope('NUFFT')

    return operator
Ejemplo n.º 12
0
    def __init__(self, ishape, filt, mode='full', dtype=tf.complex64):
        '''
        Store the filter and derive the output shape for the given mode.

        ishape - input shape; the first and last entries are kept apart
                 from the axes in between, which are paired with the
                 leading axes of the filter shape
        filt - filter, converted to a tensor
        mode - 'full' gives output sizes i + f - 1; any other value gives
               i - f + 1 and also records the 'full' shape on `self.zshape`
        dtype - forwarded to the base-class constructor

        `mode_adj` is 'valid' for mode 'full' and 'full' otherwise.
        `filt_adj` is the conjugate of the filter with its last two axes
        swapped.
        '''
        self.filt = tf.convert_to_tensor(filt)
        self.mode = mode
        self.ndim = len(ishape) - 2
        fshape = util.get_shape(self.filt)

        head = [ishape[0]]
        tail = [fshape[-1]]
        size_pairs = list(zip(ishape[1:-1], fshape[:-2]))
        full_shape = head + [i + f - 1 for i, f in size_pairs] + tail

        if mode == 'full':
            oshape = full_shape
            self.mode_adj = 'valid'
        else:
            oshape = head + [i - f + 1 for i, f in size_pairs] + tail
            self.zshape = full_shape
            self.mode_adj = 'full'

        # Swap the last two filter axes for the adjoint filter.
        self.perm = list(range(self.ndim)) + [self.ndim + 1, self.ndim]
        self.filt_adj = tf.conj(tf.transpose(self.filt, self.perm))

        super().__init__(oshape, ishape, dtype)
Ejemplo n.º 13
0
 def getFeatureSize(self, infile):
     """
     For a single .data file return the number of features, e.g. number of spectrogram bins.

     The value is the second entry of the shape read from the path obtained
     by replacing '.data' with '.shape' in `infile`.
     """
     shape_path = os.path.join(infile.replace('.data', '.shape'))
     return util.get_shape(shape_path)[1]