示例#1
0
    def call(self, inputs):
        """Apply ``self.number_of_transforms`` affine warps to a batch of images.

        Args:
            inputs: list of 2 or 3 tensors:
                inputs[0]: image batch to transform.
                inputs[1]: affine parameters, pre-multiplied by
                    ``self.affine_mul`` (divided back out before use).
                inputs[2]: optional mask selecting which transforms
                    contribute at each spatial position.

        Returns:
            Transformed images, combined over the transform axis according to
            ``self.aggregation_fn``: 'none' stacks transforms into channels,
            'max' takes the per-position maximum, 'avg' takes the mask-weighted
            mean.
        """
        # NOTE(review): removed leftover debug print of `inputs`.
        # Replicate each input image once per transform so every copy can be
        # warped by its own affine matrix.
        expanded_tensor = ktf.expand_dims(inputs[0], -1)
        multiples = [1, self.number_of_transforms, 1, 1, 1]
        tiled_tensor = ktf.tile(expanded_tensor, multiples=multiples)
        repeated_tensor = ktf.reshape(tiled_tensor, ktf.shape(inputs[0]) * np.array([self.number_of_transforms, 1, 1, 1]))

        # Parameters are stored scaled; undo the scaling before warping.
        affine_transforms = inputs[1] / self.affine_mul

        affine_transforms = ktf.reshape(affine_transforms, (-1, 8))
        transformed = tf_affine_transform(repeated_tensor, affine_transforms)
        res = ktf.reshape(transformed, [-1, self.number_of_transforms] + self.image_size)
        res = ktf.transpose(res, [0, 2, 3, 1, 4])

        # Optionally zero out contributions using the provided mask.
        if len(inputs) == 3:
            mask = ktf.transpose(inputs[2], [0, 2, 3, 1])
            mask = ktf.image.resize_images(mask, self.image_size[:2], method=ktf.image.ResizeMethod.NEAREST_NEIGHBOR)
            res = res * ktf.expand_dims(mask, axis=-1)

        if self.aggregation_fn == 'none':
            res = ktf.reshape(res, [-1] + self.image_size[:2] + [self.image_size[2] * self.number_of_transforms])
        elif self.aggregation_fn == 'max':
            res = ktf.reduce_max(res, reduction_indices=[-2])
        elif self.aggregation_fn == 'avg':
            # NOTE(review): 'avg' reads `mask`, which is only bound when
            # len(inputs) == 3 -- confirm callers always pass the mask here.
            counts = ktf.reduce_sum(mask, reduction_indices=[-1])
            counts = ktf.expand_dims(counts, axis=-1)
            res = ktf.reduce_sum(res, reduction_indices=[-2])
            res /= counts
            # Positions with zero mask count produce 0/0 = NaN; map them to 0.
            res = ktf.where(ktf.is_nan(res), ktf.zeros_like(res), res)
        return res
示例#2
0
def check_angles(x, rotation_guess):
    """Flip angles by 180 degrees when they point away from the guess.

    Each angle in ``x`` (degrees) is compared, as a unit vector, against the
    unit vector(s) of ``rotation_guess``; a negative dot product (more than
    90 degrees apart) flips the angle. Results are wrapped via angle_mod.
    """
    angles = angle_mod(tf.reshape(x, (-1, 1)))

    # Unit vectors (cos, sin) for the candidate angles.
    angles_rad = radians(angles)
    unit_angles = tf.concat([tf.cos(angles_rad), tf.sin(angles_rad)], axis=-1)

    # Unit vectors for the reference guess.
    guess_rad = radians(tf.reshape(rotation_guess, (-1, 1)))
    unit_guess = tf.concat([tf.cos(guess_rad), tf.sin(guess_rad)], axis=-1)

    # Negative alignment means the angle is on the wrong side; rotate by 180.
    alignment = tf.matmul(unit_angles, unit_guess, transpose_b=True)
    return tf.where(alignment < 0, angle_mod(angles - 180), angles)
示例#3
0
def _upsampled_registration(target_image, src_image, upsample_factor):
    """Estimate the sub-pixel shift between src_image and target_image.

    Cross-correlates the two images in the Fourier domain to find a coarse
    integer shift, then refines it to 1/upsample_factor of a pixel by
    evaluating an upsampled DFT in a small window around the coarse maximum
    (single-step DFT upsampling, as in scikit-image's registration).

    Args:
        target_image: reference image batch; only the first three dims are
            kept (assumes trailing singleton channel -- TODO confirm).
        src_image: moving image batch, same layout as target_image.
        upsample_factor: Python number; precision of the estimated shift.

    Returns:
        Per-image shift tensor (batch, 2).
    """

    upsample_factor = tf.constant(upsample_factor, tf.float32)

    # Drop the trailing dimension so the FFT sees (batch, rows, cols).
    target_shape = tf.shape(target_image)
    target_image = tf.reshape(target_image, target_shape[:3])
    src_shape = tf.shape(src_image)
    src_image = tf.reshape(src_image, src_shape[:3])

    # fft2d is an external helper -- presumably a batched 2-D FFT; verify.
    src_freq = fft2d(src_image)
    target_freq = fft2d(target_image)

    # Spatial (rows, cols) size, replicated per batch element.
    shape = tf.reshape(tf.shape(src_freq)[1:3], (1, 2))
    shape = tf.cast(shape, tf.float32)
    shape = tf.tile(shape, (tf.shape(target_freq)[0], 1))
    # Cross-power spectrum; its inverse FFT is the cross-correlation.
    image_product = src_freq * tf.conj(target_freq)
    cross_correlation = tf.spectral.ifft2d(image_product)

    # Coarse (integer-pixel) peak of the correlation surface.
    maxima = find_maxima(tf.abs(cross_correlation))
    midpoints = fix(tf.cast(shape, tf.float32) / 2.)

    shifts = maxima
    # Peaks past the midpoint correspond to negative shifts (FFT wrap-around).
    shifts = tf.where(shifts > midpoints, shifts - shape, shifts)
    # Snap the coarse shift to the upsampled grid.
    shifts = tf.round(shifts * upsample_factor) / upsample_factor

    # Size of the refinement window in upsampled pixels (1.5px neighborhood).
    upsampled_region_size = tf.ceil(upsample_factor * 1.5)
    dftshift = fix(upsampled_region_size / 2.0)
    normalization = tf.cast(tf.size(src_freq[0]), tf.float32)
    normalization *= upsample_factor**2
    # Center the upsampled window on the current shift estimate.
    sample_region_offset = dftshift - shifts * upsample_factor

    # Upsampled correlation around the coarse peak via matrix-multiply DFT.
    data = tf.conj(image_product)
    upsampled_dft = _upsampled_dft(data, upsampled_region_size,
                                   upsample_factor, sample_region_offset)

    cross_correlation = tf.conj(upsampled_dft)
    cross_correlation /= tf.cast(normalization, tf.complex64)
    cross_correlation = tf.abs(cross_correlation)

    # Refine: peak offset within the window, converted back to pixels.
    maxima = find_maxima(cross_correlation)
    maxima = maxima - dftshift
    shifts = shifts + maxima / upsample_factor

    return shifts
示例#4
0
def create_sparse(labels):
    """Convert a dense label tensor to a SparseTensor, dropping zero entries."""
    nonzero_indices = tf.where(tf.not_equal(labels, 0))
    nonzero_values = tf.gather_nd(labels, nonzero_indices)
    full_shape = tf.shape(labels, out_type=tf.int64)
    return tf.SparseTensor(nonzero_indices, nonzero_values, dense_shape=full_shape)
示例#5
0
File: run.py  Project: zzy950117/wc-gan
def get_lr_decay_schedule(args):
    """Build learning-rate decay schedules for the generator and discriminator.

    Args:
        args: parsed CLI arguments providing ``number_of_epochs``,
            ``training_ratio`` and ``lr_decay_schedule`` (one of None,
            'linear', 'half-linear', 'linear-end', or 'dropat<N>').

    Returns:
        Tuple ``(generator_schedule, discriminator_schedule)`` of functions
        mapping an iteration counter to a multiplicative LR factor.

    Raises:
        ValueError: if ``args.lr_decay_schedule`` is not recognized.
    """
    # One "epoch" is hard-coded as 1000 generator iterations here.
    number_of_iters_generator = 1000. * args.number_of_epochs
    number_of_iters_discriminator = 1000. * args.number_of_epochs * args.training_ratio

    if args.lr_decay_schedule is None:
        # Constant learning rate.
        lr_decay_schedule_generator = lambda it: 1.
        lr_decay_schedule_discriminator = lambda it: 1.
    elif args.lr_decay_schedule == 'linear':
        # Linear decay from 1 to 0 over the full run.
        lr_decay_schedule_generator = lambda it: K.maximum(
            0., 1. - K.cast(it, 'float32') / number_of_iters_generator)
        lr_decay_schedule_discriminator = lambda it: K.maximum(
            0., 1. - K.cast(it, 'float32') / number_of_iters_discriminator)
    elif args.lr_decay_schedule == 'half-linear':
        # Linear decay for the first half, then held at 0.5.
        lr_decay_schedule_generator = lambda it: ktf.where(
            K.less(it, K.cast(number_of_iters_generator / 2, 'int64')),
            ktf.maximum(
                0., 1. -
                (K.cast(it, 'float32') / number_of_iters_generator)), 0.5)
        lr_decay_schedule_discriminator = lambda it: ktf.where(
            K.less(it, K.cast(number_of_iters_discriminator / 2, 'int64')),
            ktf.maximum(
                0., 1. - (K.cast(it, 'float32') /
                          number_of_iters_discriminator)), 0.5)
    elif args.lr_decay_schedule == 'linear-end':
        # Constant until `decay_at` fraction of training, then linear to 0.
        decay_at = 0.828

        number_of_iters_until_decay_generator = number_of_iters_generator * decay_at
        number_of_iters_until_decay_discriminator = number_of_iters_discriminator * decay_at

        number_of_iters_after_decay_generator = number_of_iters_generator * (
            1 - decay_at)
        number_of_iters_after_decay_discriminator = number_of_iters_discriminator * (
            1 - decay_at)

        lr_decay_schedule_generator = lambda it: ktf.where(
            K.greater(it,
                      K.cast(number_of_iters_until_decay_generator, 'int64')),
            ktf.maximum(
                0., 1. - (K.cast(it, 'float32') -
                          number_of_iters_until_decay_generator) /
                number_of_iters_after_decay_generator), 1)
        lr_decay_schedule_discriminator = lambda it: ktf.where(
            K.greater(
                it, K.cast(number_of_iters_until_decay_discriminator, 'int64'
                           )),
            ktf.maximum(
                0., 1. - (K.cast(it, 'float32') -
                          number_of_iters_until_decay_discriminator) /
                number_of_iters_after_decay_discriminator), 1)
    elif args.lr_decay_schedule.startswith("dropat"):
        # Linear decay multiplied by a 10x drop at iteration dropat<N>*1000.
        drop_at = int(args.lr_decay_schedule.replace('dropat', ''))
        drop_at_generator = drop_at * 1000
        drop_at_discriminator = drop_at * 1000 * args.training_ratio
        print("Drop at generator %s" % drop_at_generator)
        lr_decay_schedule_generator = lambda it: (ktf.where(
            K.less(it, drop_at_generator), 1., 0.1) * K.maximum(
                0., 1. - K.cast(it, 'float32') / number_of_iters_generator))
        lr_decay_schedule_discriminator = lambda it: (ktf.where(
            K.less(it, drop_at_discriminator), 1., 0.1) * K.maximum(
                0., 1. - K.cast(it, 'float32') /
                number_of_iters_discriminator))
    else:
        # Was `assert False`: stripped under `python -O`, which would fall
        # through and raise UnboundLocalError at the return. Raise explicitly.
        raise ValueError(
            "Unknown lr_decay_schedule: %r" % (args.lr_decay_schedule,))

    return lr_decay_schedule_generator, lr_decay_schedule_discriminator
示例#6
0
def fix(x):
    """Round toward zero (MATLAB-style fix): floor for x >= 0, ceil for x < 0."""
    return tf.where(x < 0, tf.ceil(x), tf.floor(x))
示例#7
0
def angle_mod(x):
    """Wrap angles in degrees into the range [0, 360)."""
    # Remove whole revolutions, truncating toward zero via fix().
    wrapped = x - 360. * fix(x / 360.)
    # Truncation leaves negatives in (-360, 0); shift them into range.
    return tf.where(wrapped < 0, wrapped + 360, wrapped)