def call(self, inputs):
    """Warp every input image with several affine transforms and aggregate them.

    Args:
        inputs: ``[images, transforms]`` or ``[images, transforms, masks]``.
            ``images`` is presumably ``(batch, h, w, c)`` (TODO confirm).
            Each image is repeated ``self.number_of_transforms`` times.
            ``transforms`` is divided by ``self.affine_mul`` and reshaped to
            ``(-1, 8)``, so it must hold 8 parameters per (image, transform) pair.
            ``masks`` is transposed with ``[0, 2, 3, 1]``. It is therefore assumed
            to be ``(batch, number_of_transforms, h', w')`` and is resized to
            ``image_size[:2]`` with nearest-neighbour sampling.

    Returns:
        The per-transform warped images, of shape
        ``(batch, h, w, number_of_transforms, c)`` before aggregation, combined
        according to ``self.aggregation_fn``:

        * ``'none'``: transforms are stacked into the channel axis,
          ``(batch, h, w, c * number_of_transforms)``.
        * ``'max'``: maximum over the transform axis.
        * ``'avg'``: mean over the transforms. With a mask, this is the masked
          mean (the sum divided by the mask count, with 0/0 mapped to 0).
          Without a mask, it is a plain mean over all transforms.
    """
    number_of_transforms = self.number_of_transforms
    # Accept tuples as well as lists; the shapes below are built by list concatenation.
    image_size = list(self.image_size)
    images = inputs[0]

    # Repeat each image number_of_transforms times consecutively, giving a
    # batch of size batch * number_of_transforms.
    expanded_tensor = ktf.expand_dims(images, -1)
    tiled_tensor = ktf.tile(expanded_tensor,
                            multiples=[1, number_of_transforms, 1, 1, 1])
    repeated_tensor = ktf.reshape(
        tiled_tensor,
        ktf.shape(images) * np.array([number_of_transforms, 1, 1, 1]))

    affine_transforms = inputs[1] / self.affine_mul
    affine_transforms = ktf.reshape(affine_transforms, (-1, 8))
    transformed = tf_affine_transform(repeated_tensor, affine_transforms)

    res = ktf.reshape(transformed, [-1, number_of_transforms] + image_size)
    # (batch, T, h, w, c) -> (batch, h, w, T, c)
    res = ktf.transpose(res, [0, 2, 3, 1, 4])

    mask = None
    if len(inputs) == 3:
        # (batch, T, h', w') -> (batch, h', w', T), then resize to the image grid.
        mask = ktf.transpose(inputs[2], [0, 2, 3, 1])
        mask = ktf.image.resize_images(
            mask, image_size[:2],
            method=ktf.image.ResizeMethod.NEAREST_NEIGHBOR)
        res = res * ktf.expand_dims(mask, axis=-1)

    if self.aggregation_fn == 'none':
        res = ktf.reshape(
            res,
            [-1] + image_size[:2] + [image_size[2] * number_of_transforms])
    elif self.aggregation_fn == 'max':
        res = ktf.reduce_max(res, reduction_indices=[-2])
    elif self.aggregation_fn == 'avg':
        if mask is None:
            # Without a mask every transform contributes to every pixel.
            res = ktf.reduce_mean(res, reduction_indices=[-2])
        else:
            counts = ktf.reduce_sum(mask, reduction_indices=[-1])
            counts = ktf.expand_dims(counts, axis=-1)
            res = ktf.reduce_sum(res, reduction_indices=[-2])
            res /= counts
            # Pixels covered by no transform give 0/0 = NaN; report them as 0.
            res = ktf.where(ktf.is_nan(res), ktf.zeros_like(res), res)
    # NOTE(review): any other aggregation_fn value returns the un-aggregated
    # 5-D tensor; presumably the value is validated elsewhere, so confirm.
    return res
def check_angles(x, rotation_guess):
    """Flip angles by 180 degrees when they point away from a guessed rotation.

    Both inputs are flattened to a column of angles. Each angle in ``x`` is
    wrapped with ``angle_mod``. The angle is then compared with its guess via
    the dot product of their (cos, sin) unit vectors. A negative dot product
    means the two differ by more than 90 degrees, so the angle is replaced by
    ``angle_mod(angle - 180)``.

    Args:
        x: angles, reshaped to ``(-1, 1)``. Degrees are presumed, since
            ``radians`` is applied (TODO confirm the unit).
        rotation_guess: either a single angle, which is broadcast against every
            element of ``x``, or one angle per element of ``x``.

    Returns:
        A ``(N, 1)`` tensor of wrapped and possibly flipped angles.
    """
    x = angle_mod(tf.reshape(x, (-1, 1)))
    angle_rad = radians(x)
    angle_vec = tf.concat([tf.cos(angle_rad), tf.sin(angle_rad)], axis=-1)

    guess_rad = radians(tf.reshape(rotation_guess, (-1, 1)))
    guess_vec = tf.concat([tf.cos(guess_rad), tf.sin(guess_rad)], axis=-1)

    # Row-wise dot product, giving shape (N, 1) so it matches x for tf.where.
    # A single guess (1, 2) broadcasts over all N rows. The previous all-pairs
    # matmul produced (N, M), which only matched x when M == 1.
    agreement = tf.expand_dims(tf.reduce_sum(angle_vec * guess_vec, axis=-1), -1)
    return tf.where(agreement < 0, angle_mod(x - 180), x)
def _upsampled_registration(target_image, src_image, upsample_factor):
    """Estimate the sub-pixel translation of ``src_image`` relative to ``target_image``.

    The structure appears to follow the single-step upsampled-DFT registration
    used by scikit-image's ``register_translation``. A coarse integer peak of the
    FFT cross-correlation is refined by an upsampled DFT in a small window
    around it.

    Args:
        target_image: batch of reference images. Only the first three
            dimensions are kept, presumably ``(batch, h, w)`` with a trailing
            singleton channel dropped (TODO confirm).
        src_image: batch of images to register, reshaped the same way.
        upsample_factor: sub-pixel precision; shifts are resolved to
            ``1 / upsample_factor`` pixels.

    Returns:
        The estimated shifts, presumably one ``(row, col)`` pair per batch
        element, depending on what ``find_maxima`` returns (TODO confirm).
    """
    factor = tf.constant(upsample_factor, tf.float32)

    # Keep only the leading three dimensions of each batch.
    target_image = tf.reshape(target_image, tf.shape(target_image)[:3])
    src_image = tf.reshape(src_image, tf.shape(src_image)[:3])

    src_freq = fft2d(src_image)
    target_freq = fft2d(target_image)

    # Spatial (h, w) size of the spectra as float32, repeated once per batch row.
    batch_size = tf.shape(target_freq)[0]
    spatial_shape = tf.cast(tf.reshape(tf.shape(src_freq)[1:3], (1, 2)),
                            tf.float32)
    spatial_shape = tf.tile(spatial_shape, (batch_size, 1))

    # Coarse, integer-pixel estimate from the cross-power spectrum.
    image_product = src_freq * tf.conj(target_freq)
    coarse_peak = find_maxima(tf.abs(tf.spectral.ifft2d(image_product)))
    midpoints = fix(spatial_shape / 2.)
    # Peaks past the midpoint wrap around and correspond to negative shifts.
    shifts = tf.where(coarse_peak > midpoints,
                      coarse_peak - spatial_shape,
                      coarse_peak)
    shifts = tf.round(shifts * factor) / factor

    # Refine within a window of about 1.5 pixels, upsampled by `factor`.
    region_size = tf.ceil(factor * 1.5)
    dftshift = fix(region_size / 2.0)
    normalization = tf.cast(tf.size(src_freq[0]), tf.float32) * factor ** 2
    region_offset = dftshift - shifts * factor

    upsampled = _upsampled_dft(tf.conj(image_product), region_size, factor,
                               region_offset)
    fine_cc = tf.conj(upsampled) / tf.cast(normalization, tf.complex64)
    fine_peak = find_maxima(tf.abs(fine_cc)) - dftshift
    return shifts + fine_peak / factor
def create_sparse(labels):
    """Convert a dense ``labels`` tensor into a ``tf.SparseTensor``.

    Every entry not equal to 0 becomes a stored value at its index. Zeros are
    treated as empty. The dense shape is taken from ``labels`` as int64.
    """
    nonzero_positions = tf.where(tf.not_equal(labels, 0))
    return tf.SparseTensor(
        indices=nonzero_positions,
        values=tf.gather_nd(labels, nonzero_positions),
        dense_shape=tf.shape(labels, out_type=tf.int64),
    )
def _lr_decay_for_network(schedule, number_of_iters, drop_at=None):
    """Build the learning-rate multiplier callable for one network.

    Args:
        schedule: ``args.lr_decay_schedule``. One of ``None``, ``'linear'``,
            ``'half-linear'``, ``'linear-end'`` or ``'dropat<N>'``.
        number_of_iters: total number of iterations for this network (float).
        drop_at: iteration at which a ``'dropat'`` schedule scales the rate by
            0.1. This is only used for ``'dropat'`` schedules.

    Returns:
        A callable mapping the iteration counter to a scalar multiplier.

    Raises:
        ValueError: if ``schedule`` is not a recognised schedule name.
    """
    if schedule is None:
        return lambda step: 1.

    if schedule == 'linear':
        # Decays linearly from 1 to 0 over the whole run.
        return lambda step: K.maximum(
            0., 1. - K.cast(step, 'float32') / number_of_iters)

    if schedule == 'half-linear':
        # Linear decay for the first half of training, then held at 0.5.
        return lambda step: ktf.where(
            K.less(step, K.cast(number_of_iters / 2, 'int64')),
            ktf.maximum(
                0., 1. - (K.cast(step, 'float32') / number_of_iters)),
            0.5)

    if schedule == 'linear-end':
        # Constant 1 until `decay_at` of training, then linear decay to 0.
        decay_at = 0.828
        iters_until_decay = number_of_iters * decay_at
        iters_after_decay = number_of_iters * (1 - decay_at)
        return lambda step: ktf.where(
            K.greater(step, K.cast(iters_until_decay, 'int64')),
            ktf.maximum(
                0., 1. - (K.cast(step, 'float32') - iters_until_decay)
                / iters_after_decay),
            1.)

    if schedule.startswith("dropat"):
        # Linear decay, additionally scaled by 0.1 from `drop_at` onwards.
        return lambda step: (
            ktf.where(K.less(step, drop_at), 1., 0.1)
            * K.maximum(0., 1. - K.cast(step, 'float32') / number_of_iters))

    raise ValueError("Unknown lr_decay_schedule: %r" % (schedule,))


def get_lr_decay_schedule(args):
    """Return ``(generator_schedule, discriminator_schedule)`` LR-multiplier callables.

    Each callable maps an iteration counter to a scalar by which the base
    learning rate is multiplied. The generator is assumed to run
    ``1000 * args.number_of_epochs`` iterations. The discriminator runs
    ``args.training_ratio`` times as many.

    Args:
        args: namespace with ``number_of_epochs``, ``training_ratio`` and
            ``lr_decay_schedule``. The schedule is ``None``, ``'linear'``,
            ``'half-linear'``, ``'linear-end'`` or ``'dropat<N>'``, where ``N``
            is the epoch (of 1000 generator iterations) at which the rate drops
            by 10x.

    Raises:
        ValueError: if ``args.lr_decay_schedule`` is not a recognised schedule.
            This used to be an ``assert``, which is stripped under ``python -O``.
    """
    schedule = args.lr_decay_schedule
    number_of_iters_generator = 1000. * args.number_of_epochs
    number_of_iters_discriminator = (
        1000. * args.number_of_epochs * args.training_ratio)

    drop_at_generator = drop_at_discriminator = None
    if schedule is not None and schedule.startswith("dropat"):
        drop_at = int(schedule.replace('dropat', ''))
        drop_at_generator = drop_at * 1000
        drop_at_discriminator = drop_at * 1000 * args.training_ratio
        print("Drop at generator %s" % drop_at_generator)

    lr_decay_schedule_generator = _lr_decay_for_network(
        schedule, number_of_iters_generator, drop_at_generator)
    lr_decay_schedule_discriminator = _lr_decay_for_network(
        schedule, number_of_iters_discriminator, drop_at_discriminator)
    return lr_decay_schedule_generator, lr_decay_schedule_discriminator
def fix(x):
    """Round ``x`` toward zero elementwise, like ``numpy.fix``.

    Negative values are rounded up with ``ceil``. Everything else (including
    -0.0 and NaN) goes through ``floor``.
    """
    return tf.where(x < 0, tf.ceil(x), tf.floor(x))
def angle_mod(x):
    """Wrap angles into ``[0, 360)`` elementwise.

    Subtracting the truncated multiple of 360 leaves a remainder in
    ``(-360, 360)``. Negative remainders are then shifted up by one full turn.
    """
    remainder = x - 360. * fix(x / 360.)
    return tf.where(remainder < 0, remainder + 360, remainder)