def interp(table, coord, input):
    '''Sample ``input`` at ``coord`` by nearest-grid lookup times table weights.

    Each coordinate is rounded to the nearest grid point (offset by the grid
    centre and wrapped by ``mod``), the value at that point is gathered, and
    it is scaled by the product over dimensions of
    ``linear_interp(table, 2 * |round(coord) - coord|)``.

    Args:
        table: interpolation table handed to ``linear_interp``.
        coord: tensor whose last axis has length ``ndim`` (one entry per image
            dimension); its leading axes are the sample (k-space) axes.
        input: tensor whose trailing ``ndim`` axes are the image axes and whose
            leading axes are batch/channel axes, e.g. ``[num_channel, nx, ny]``.

    Returns:
        Tensor of shape ``[*input.shape[:-ndim], *coord.shape[:-1]]``.
    '''
    coord_shape = util.get_shape(coord)
    in_shape = util.get_shape(input)
    ndim = coord_shape[-1]              # image dimensions
    n_batch = len(in_shape) - ndim      # leading channel/batch axes
    n_k = len(coord_shape) - 1          # sample axes of coord
    img_shape = in_shape[-ndim:]
    center = [size // 2 for size in img_shape]

    with tf.name_scope('get_indices'):
        idx = mod(tf.cast(tf.round(coord), 'int64') + center, img_shape)
        # Image axes first so gather_nd can index them with `idx`.
        grid_first = list(range(n_batch, n_batch + ndim)) + list(range(n_batch))
        transposed = tf.transpose(input, perm=grid_first)
        gathered = tf.gather_nd(transposed, idx)
        # gather_nd yields [k..., batch...]; move batch axes back to the front.
        batch_first = list(range(n_k, n_k + n_batch)) + list(range(n_k))
        output = tf.transpose(gathered, perm=batch_first)

    with tf.name_scope('get_weights'):
        # Distance to the nearest grid point, scaled from [0, 0.5] to [0, 1].
        diff = abs(tf.round(coord) - coord) * 2.0
        weight = tf.reduce_prod(linear_interp(table, diff), axis=-1)
        output *= weight
    return output
def check_codomain(self, output):
    '''Ensure ``output`` lies in this operator's codomain.

    The dtype is checked first (``output.dtype.base_dtype`` against
    ``self.dtype``), then the shape (``util.get_shape(output)`` against
    ``self.oshape``).

    Raises:
        ValueError: on a dtype or shape mismatch.
    '''
    got_dtype = output.dtype
    if got_dtype.base_dtype != self.dtype:
        raise ValueError('output dtype mismatch, for {}, got {}'
                         .format(self, got_dtype))
    got_shape = util.get_shape(output)
    if got_shape != self.oshape:
        raise ValueError('output shape mismatch, for {}, got {}'
                         .format(self, got_shape))
def check_domain(self, input):
    '''Ensure ``input`` lies in this operator's domain.

    The dtype is checked first (``input.dtype.base_dtype`` against
    ``self.dtype``), then the shape (``util.get_shape(input)`` against
    ``self.ishape``).

    Raises:
        ValueError: on a dtype or shape mismatch.
    '''
    got_dtype = input.dtype
    if got_dtype.base_dtype != self.dtype:
        raise ValueError('input data type mismatch, for {}, got {}'
                         .format(self, got_dtype))
    got_shape = util.get_shape(input)
    if got_shape != self.ishape:
        raise ValueError('input shape mismatch, for {}, got {}'
                         .format(self, got_shape))
def __init__(self, ishape, table, coord, shift=None, dtype=tf.complex64):
    '''Configure an interpolation operator onto the points ``coord``.

    ``ndim`` is the length of the last axis of ``coord``. The trailing
    ``ndim`` axes of ``ishape`` are the image axes (stored as
    ``self.img_shape``). The leading axes are batch axes, carried over to the
    output, whose shape is ``ishape[:-ndim] + coord.shape[:-1]``.

    Args:
        ishape: input shape (converted to a list).
        table: interpolation table (converted to a tensor, kept on ``self.table``).
        coord: sample coordinates (converted to a tensor, kept on ``self.coord``).
        shift: stored unchanged on ``self.shift``.
        dtype: operator dtype.
    '''
    coord = tf.convert_to_tensor(coord)
    table = tf.convert_to_tensor(table)
    coord_shape = util.get_shape(coord)
    ndim = coord_shape[-1]
    ishape = list(ishape)

    self.table, self.coord, self.shift = table, coord, shift
    self.img_shape = ishape[-ndim:]

    # Batch axes of the input followed by the sample axes of coord.
    oshape = ishape[:-ndim] + coord_shape[:-1]
    super().__init__(oshape, ishape, dtype)
def __init__(self, oshape, ishape, mat, dtype=tf.complex64):
    '''Set up a (batched) matrix-multiplication operator by ``mat``.

    ``oshape``, ``ishape`` and the shape of ``mat`` are left-padded with 1s
    to a common rank. The last two axes are treated as the matrix axes and
    are not checked here. Every leading axis, with sizes (i, o, m) for
    input/output/matrix, must match one of these patterns:
    ``i == m == o``, ``i == m and o == 1``, ``i == 1 and o == m``,
    ``i == o == 1``, or ``i == o and m == 1``.

    Args:
        oshape: output shape (list or tuple).
        ishape: input shape (list or tuple).
        mat: matrix data; anything ``tf.convert_to_tensor`` accepts
            (Python list, NumPy array, tensor).
        dtype: operator dtype.

    Raises:
        ValueError: if a leading axis matches none of the allowed patterns.
    '''
    self.mat = tf.convert_to_tensor(mat)
    mshape = util.get_shape(self.mat)

    # Work on list copies so tuple shapes can be padded with `+`.
    oshape_list = list(oshape)
    ishape_list = list(ishape)
    mshape_list = list(mshape)
    max_ndim = max(len(oshape_list), len(ishape_list), len(mshape_list))
    oshape_full = [1] * (max_ndim - len(oshape_list)) + oshape_list
    ishape_full = [1] * (max_ndim - len(ishape_list)) + ishape_list
    mshape_full = [1] * (max_ndim - len(mshape_list)) + mshape_list

    # Check dimension valid (leading, non-matrix axes only).
    for i, o, m in zip(ishape_full[:-2], oshape_full[:-2], mshape_full[:-2]):
        if not ((i == m and o == m) or
                (i == m and o == 1) or
                (i == 1 and o == m) or
                (i == 1 and o == 1) or
                (i == o and m == 1)):
            # Report the converted tensor's shape: the raw `mat` argument may be
            # a plain list without a `.shape` attribute.
            raise ValueError('Invalid dimensions: {}, {}, {}'.
                             format(oshape, ishape, mshape))

    # Axes where the padded output / input has size 1 (presumably reduced over
    # in the forward / adjoint pass -- confirm against the apply methods).
    self.osum_axes = [i for i in range(max_ndim) if oshape_full[i] == 1]
    self.isum_axes = [i for i in range(max_ndim) if ishape_full[i] == 1]

    # Swaps the last two axes of a rank-len(oshape) tensor.
    # NOTE(review): if this is applied to `mat`, its rank must equal len(oshape).
    ndim = len(oshape)
    self.perm = list(range(ndim - 2)) + [ndim - 1, ndim - 2]
    super().__init__(oshape, ishape, dtype)
def interpH(oshape, table, coord, input):
    '''Weight samples and scatter-add them onto an image grid.

    Scatter counterpart of ``interp``. The same grid indices are computed
    (rounded ``coord`` offset by the grid centre and wrapped by ``mod``), and
    the same table weights are applied to ``input``. The weighted samples are
    then accumulated with ``tf.scatter_nd``, which sums contributions that
    land on the same index.

    Args:
        oshape: output shape; its trailing ``ndim`` axes are the image axes,
            where ``ndim`` is the length of the last axis of ``coord``.
        table: interpolation table handed to ``linear_interp``.
        coord: sample coordinates, last axis of length ``ndim``.
        input: samples shaped ``[*oshape[:-ndim], *coord.shape[:-1]]``.

    Returns:
        Tensor of shape ``oshape``.
    '''
    coord_shape = util.get_shape(coord)
    ndim = coord_shape[-1]
    n_batch = len(oshape) - ndim
    n_k = len(coord_shape) - 1
    img_shape = oshape[-ndim:]
    center = [size // 2 for size in img_shape]

    idx = mod(tf.cast(tf.round(coord), 'int64') + center, img_shape)
    # Distance to the nearest grid point, scaled from [0, 0.5] to [0, 1].
    diff = abs(tf.round(coord) - coord) * 2.0
    weight = tf.reduce_prod(linear_interp(table, diff), axis=-1)
    input *= weight

    # Sample axes first so they line up with the leading axes of `idx`.
    samples_first = list(range(n_batch, n_batch + n_k)) + list(range(n_batch))
    input = tf.transpose(input, perm=samples_first)
    grid = tf.scatter_nd(idx, input, img_shape + oshape[:-ndim])

    # scatter_nd produced [img..., batch...]; restore [batch..., img...].
    batch_first = list(range(ndim, ndim + n_batch)) + list(range(ndim))
    return tf.transpose(grid, perm=batch_first)
def getNumInstances(self, infile, time_context=100, step=None):
    """Return how many examples of size ``time_context`` one .data file yields.

    The file length is the first entry of ``util.get_shape`` applied to the
    companion shape path, i.e. ``infile`` with ``'.data'`` replaced by
    ``'.shape'``.

    Args:
        infile: path of a ``.data`` file.
        time_context: example length, in the same units as the file length.
        step: hop between consecutive examples. When omitted, ``self.step``
            is used. An explicitly passed value is honoured, whereas the
            original code ignored this argument.

    Returns:
        ``max(1, ceil((length - time_context) / step))``, so a file shorter
        than ``time_context`` still yields one example.
    """
    if step is None:
        step = self.step
    shape = util.get_shape(infile.replace('.data', '.shape'))
    length_file = float(shape[0])
    # NOTE(review): ceil(...) is one less than floor(...) + 1 when
    # (length - time_context) is an exact multiple of step -- confirm intended.
    return np.maximum(
        1, int(np.ceil((length_file - time_context) / step)))
def __init__(self, oshape, ishape, mult, dtype=tf.complex64):
    '''Set up an elementwise-multiplication operator by ``mult``.

    ``mult`` is converted to a tensor and cast to ``dtype`` if its base dtype
    differs. After the base initialiser runs, ``oshape``, ``ishape`` and the
    shape of ``mult`` are left-padded with 1s to a common rank. Two lists are
    then recorded:

    * ``self.osum_axis`` -- axes where the padded output has size 1 but the
      input or ``mult`` is larger;
    * ``self.isum_axis`` -- axes where the padded input has size 1 but the
      output or ``mult`` is larger.

    These are presumably the axes reduced over in the forward / adjoint
    pass (confirm against the apply methods).
    '''
    mult = tf.convert_to_tensor(mult)
    if mult.dtype.base_dtype != dtype:
        mult = tf.cast(mult, dtype)
    self.mult = mult
    super().__init__(oshape, ishape, dtype)

    mshape = util.get_shape(self.mult)
    rank = max(len(oshape), len(ishape), len(mshape))

    def _pad(shape):
        # Left-pad with 1s so every shape has `rank` axes.
        return [1] * (rank - len(shape)) + list(shape)

    axes = list(enumerate(zip(_pad(oshape), _pad(ishape), _pad(mshape))))
    self.osum_axis = [ax for ax, (o, i, m) in axes
                      if o == 1 and (i > 1 or m > 1)]
    self.isum_axis = [ax for ax, (o, i, m) in axes
                      if i == 1 and (o > 1 or m > 1)]
def right_sweep(mps, data, context):
    """Sweep left-to-right over ``mps``, updating every site but the last.

    For each ``loc`` from 0 to ``len(mps) - 2`` the function does three things:

    1. It builds the effective data with ``prepare_effective_data``.
    2. It forms a fresh local tensor from that data.
    3. It factors the tensor with ``direction='R'``. The first factor
       replaces ``mps[loc]``. The second is contracted into ``mps[loc + 1]``
       via ``np.tensordot``, pairing its last axis with the first axis of
       ``mps[loc + 1]``.

    Args:
        mps: mutable, indexable sequence of site tensors; updated in place.
        data: passed through to ``prepare_effective_data``.
        context: dict of sweep state; replaced by whatever
            ``prepare_effective_data`` returns. Its ``'step'`` counter is
            incremented. When ``context.get('combs_l')`` is truthy, the
            ``'combs_l'`` and ``'combs_r'`` entries at each updated ``loc``
            are reset to None.

    Returns:
        ``(mps, context)``.
    """
    # do all but the rightmost site
    for loc in range(len(mps) - 1):
        effective_data, context = prepare_effective_data(mps, data, loc, context)
        local_tensor = fresh_local_tensor(effective_data)
        p, q = factor_local_tensor(local_tensor, shape=get_shape(mps, loc), direction='R')
        mps[loc] = p
        # contract the second factor into the next site
        mps[loc + 1] = np.tensordot(q, mps[loc + 1], axes=[-1, 0])
        if context.get('combs_l'):  # not there on the first sweep
            # presumably cached contractions involving `loc` are now stale
            context['combs_l'][loc] = None
            context['combs_r'][loc] = None
        # NOTE(review): confirm 'step' should advance once per site (as here)
        # rather than once per sweep.
        context['step'] += 1
    return mps, context
def left_sweep(mps, data, context):
    """Sweep right-to-left over ``mps``, updating every site but site 0.

    For each ``loc`` from ``len(mps) - 1`` down to 1 the function does three
    things:

    1. It builds the effective data with ``prepare_effective_data``.
    2. It forms a fresh local tensor from that data.
    3. It factors the tensor with ``direction='L'``. The first factor is
       multiplied (``np.matmul``) onto the right of ``mps[loc - 1]``. The
       second replaces ``mps[loc]``.

    Args:
        mps: mutable, indexable sequence of site tensors; updated in place.
        data: passed through to ``prepare_effective_data``.
        context: dict of sweep state; replaced by whatever
            ``prepare_effective_data`` returns. Its ``'step'`` counter is
            incremented. The ``'combs_l'`` and ``'combs_r'`` entries at each
            updated ``loc`` are reset to None.

    Returns:
        ``(mps, context)``.
    """
    last_index = len(mps) - 1
    for offset in range(len(mps) - 1):
        loc = last_index - offset  # last_index, last_index - 1, ..., 1
        effective_data, context = prepare_effective_data(mps, data, loc, context)
        local_tensor = fresh_local_tensor(effective_data)
        p, q = factor_local_tensor(local_tensor, shape=get_shape(mps, loc), direction='L')
        # fold the first factor into the site on the left
        mps[loc - 1] = np.matmul(mps[loc - 1], p)
        mps[loc] = q
        # NOTE(review): unguarded -- assumes 'combs_l'/'combs_r' already exist
        # in context (e.g. populated during an earlier sweep).
        context['combs_l'][loc] = None
        context['combs_r'][loc] = None
        # NOTE(review): confirm 'step' should advance once per site (as here)
        # rather than once per sweep.
        context['step'] += 1
    return mps, context
def NUFFT(ishape, coord, n=128, shifts=None, dtype=tf.complex64):
    '''Build a non-uniform FFT operator as a sum of shifted sub-operators.

    For every shift a term ``Interp * FFT * LinearPhase * KaiserApodize`` is
    built. The terms are combined with ``AddN``, and the result gets the name
    scope ``'NUFFT'``. The composition is presumably applied right-to-left,
    i.e. apodize first and interpolate last.

    Args:
        ishape: ``[batch, nz, ny, nx]`` or ``[batch, ny, nx]``; must have
            exactly ``ndim + 1`` entries, where ``ndim`` is the last axis
            length of ``coord`` (checked with ``assert``).
        coord: tensor of shape ``[..., ndim]``.
        n: number of Kaiser-Bessel table samples over ``[0, 1)``; a trailing
            0 is appended to the table.
        shifts: per-dimension shifts. Defaults to every combination of 0 and
            0.5 across the ``ndim`` dimensions.
        dtype: operator dtype.

    Kernel width 2.0 and oversampling 2 are fixed.
    '''
    coord = tf.convert_to_tensor(coord)
    ishape = list(ishape)
    ndim = util.get_shape(coord)[-1]
    # Equals pi * sqrt((width / oversamp)**2 * (oversamp - 0.5)**2 - 0.8)
    # for width = 2.0, oversamp = 2.
    beta = np.pi * (1.5 ** 2 - 0.8) ** 0.5
    assert len(ishape) == ndim + 1

    # Apodization
    apodize = KaiserApodize(ishape, beta, ndim, dtype=dtype)

    # Interpolation table: n samples of the kernel plus a trailing zero.
    table_values = np.concatenate([kb(np.arange(n) / n, 2.0, beta), [0]])
    kaiser_table = tf.constant(table_values, dtype='float32', name='kaiser_table')

    # FFT
    fft = FFT(ishape, ndim=ndim, dtype=dtype)

    terms = []
    if shifts is None:
        shifts = list(product([0, 0.5], repeat=ndim))
    grid_shape = ishape[-ndim:]
    for shift in shifts:
        # Shift expressed as a fraction of each grid dimension.
        freq_shift = [s / size for s, size in zip(shift, grid_shape)]
        phase = LinearPhase(ishape, freq_shift, dtype=dtype)
        interp_op = Interp(ishape, kaiser_table, coord, shift=shift, dtype=dtype)
        terms.append(interp_op * fft * phase * apodize)

    nufft = AddN(terms)
    nufft.add_name_scope('NUFFT')
    return nufft
def __init__(self, ishape, filt, mode='full', dtype=tf.complex64):
    '''Set up a convolution operator with filter ``filt`` (shape bookkeeping).

    The shapes are assumed to be laid out as follows:

    * ``ishape`` is ``[ishape[0], *spatial, ishape[-1]]``;
    * the filter shape is ``[*fspatial, fshape[-2], fshape[-1]]``;
    * the output shape is ``[ishape[0], *out_spatial, fshape[-1]]``.

    Each ``out_spatial`` entry is ``s + f - 1`` when ``mode == 'full'`` and
    ``s - f + 1`` otherwise. NOTE(review): any mode other than ``'full'`` is
    treated as ``'valid'`` here -- confirm no other modes are intended.

    In the non-full case ``self.zshape`` holds the full-mode output shape. It
    is not set in full mode. ``self.mode_adj`` is the opposite mode
    (``'valid'`` <-> ``'full'``). ``self.filt_adj`` is the conjugated filter
    with its last two axes swapped.
    '''
    self.filt = tf.convert_to_tensor(filt)
    self.mode = mode
    self.ndim = len(ishape) - 2  # number of spatial axes
    fshape = util.get_shape(self.filt)

    spatial_pairs = list(zip(ishape[1:-1], fshape[:-2]))
    full_shape = ([ishape[0]] +
                  [s + f - 1 for s, f in spatial_pairs] +
                  [fshape[-1]])
    if mode == 'full':
        oshape = full_shape
        self.mode_adj = 'valid'
    else:
        oshape = ([ishape[0]] +
                  [s - f + 1 for s, f in spatial_pairs] +
                  [fshape[-1]])
        self.zshape = full_shape
        self.mode_adj = 'full'

    n = self.ndim
    self.perm = [*range(n), n + 1, n]  # swap the last two filter axes
    self.filt_adj = tf.conj(tf.transpose(self.filt, self.perm))
    super().__init__(oshape, ishape, dtype)
def getFeatureSize(self, infile):
    """Return the feature count of one .data file (e.g. spectrogram bins).

    The value is the second entry of ``util.get_shape`` applied to the
    companion shape path, i.e. ``infile`` with ``'.data'`` replaced by
    ``'.shape'``.
    """
    shape_path = infile.replace('.data', '.shape')
    dims = util.get_shape(shape_path)
    return dims[1]