Exemplo n.º 1
0
    def build_tensorflow_model(self, batch_size):
        """Build the GAN graph: placeholders, generator and discriminator.

        The discriminator is applied twice, once to the fed reference images
        and once to the generator output; the second call passes
        ``reuse=True``. Afterwards the generator/discriminator trainable
        variables are collected, a Saver over them is created, and the losses
        are built via ``calculate_losses``.

        Args:
            batch_size: Leading dimension recorded in ``model_input_shape`` and
                ``model_output_shape``. The placeholders themselves keep a
                variable (None) batch dimension.
        """
        # Shapes kept for reference at inference time. The output shape
        # currently mirrors the input shape.
        self.model_input_shape = (batch_size,) + tuple(self.input_shape)
        self.model_output_shape = (batch_size,) + tuple(self.input_shape)

        per_example_dims = list(self.model_input_shape[1:])
        self.latent = tf.placeholder(tf.float32, [None, self.latent_size])
        self.reference_images = tf.placeholder(
            tf.float32, [None] + per_example_dims)

        self.synthetic_images = generator(
            self, self.latent, depth=self.depth, name='generator')

        # The discriminator is built one depth level above the generator.
        real_outputs = discriminator(self, self.reference_images,
                                     depth=self.depth + 1,
                                     name='discriminator')
        self.discriminator_real, self.discriminator_real_logits = real_outputs

        fake_outputs = discriminator(self, self.synthetic_images,
                                     depth=self.depth + 1,
                                     name='discriminator',
                                     reuse=True)
        self.discriminator_fake, self.discriminator_fake_logits = fake_outputs

        trainable = tf.trainable_variables()

        def vars_matching(fragment):
            # Trainable variables whose name contains ``fragment``.
            return [v for v in trainable if fragment in v.name]

        self.d_vars = vars_matching('discriminator')
        self.g_vars = vars_matching('generator')
        self.saver = tf.train.Saver(self.g_vars + self.d_vars)

        self.calculate_losses()

        if self.hyperverbose:
            self.model_summary()
Exemplo n.º 2
0
    def build_tensorflow_model(self, batch_size):
        """Build the progressive-growing GAN graph for the current stage.

        This method creates:

        * the fade-in variable ``alpha_transition`` and its assign op;
        * the latent and image placeholders;
        * the generator;
        * a discriminator applied to the reference and generated images (the
          second call passes ``reuse=True``);
        * the downscaled ``input_volumes`` and the blended ``real_images``.

        Trainable variables are then partitioned and savers created by
        ``_build_progressive_savers``. Finally, the losses are built by
        ``calculate_losses``.

        Args:
            batch_size: Leading dimension recorded in ``model_input_shape`` /
                ``model_output_shape``. It is also part of the full shape of
                the ``raw_volumes`` placeholder. The other placeholders keep a
                variable (None) batch dimension.
        """
        # Set input/output shapes for reference during inference.
        self.model_input_shape = tuple([batch_size] + list(self.input_shape))
        self.model_output_shape = tuple([batch_size] + list(self.input_shape))

        # Fade-in coefficient. Running alpha_transition_assign with step_pl
        # fed sets it to step / (num_epochs * training_steps_per_epoch).
        self.alpha_transition = tf.Variable(initial_value=0.0,
                                            trainable=False,
                                            name='alpha_transition')
        self.step_pl = tf.placeholder(tf.float32, shape=None)
        self.alpha_transition_assign = self.alpha_transition.assign(
            self.step_pl / (self.num_epochs * self.training_steps_per_epoch))

        self.latent = tf.placeholder(tf.float32, [None, self.latent_size])
        self.synthetic_images = generator(
            self,
            self.latent,
            depth=self.progressive_depth,
            transition=self.transition,
            alpha_transition=self.alpha_transition,
            name='generator')

        # Derived parameters. Reference images at this stage are output_size
        # per spatial side. This is the only reference-image placeholder; an
        # extra input-shaped placeholder that was immediately overwritten
        # (and never consumed) is no longer created.
        self.output_size = pow(2, self.progressive_depth + 2)
        self.zoom_level = self.progressive_depth + 1
        self.reference_images = tf.placeholder(
            tf.float32,
            [None] + [self.output_size] * self.dim + [self.channels])

        # Factor between the largest power of two not exceeding the input
        # size and output_size. downscale2d presumably shrinks by this
        # factor. np.log2 is a dedicated base-2 log; math.log(x, 2) divides
        # two natural logs and is not guaranteed exact before floor().
        max_downscale = np.floor(np.log2(self.model_input_shape[1]))
        downscale_factor = 2**max_downscale / self.output_size
        self.raw_volumes = tf.placeholder(tf.float32, self.model_input_shape)
        self.input_volumes = downscale2d(self.raw_volumes, downscale_factor)

        # Blend the full-resolution references with a 2x down/up-sampled
        # copy, weighted by alpha_transition.
        self.low_images = upscale2d(downscale2d(self.reference_images, 2), 2)
        self.real_images = self.alpha_transition * self.reference_images + (
            1 - self.alpha_transition) * self.low_images

        # NOTE(review): real_images is not consumed in this method. The
        # discriminator below is fed reference_images directly. Confirm
        # whether the blended tensor should be the real input during a
        # transition.
        self.discriminator_real, self.discriminator_real_logits = discriminator(
            self,
            self.reference_images,
            depth=self.progressive_depth,
            name='discriminator',
            transition=self.transition,
            alpha_transition=self.alpha_transition)
        self.discriminator_fake, self.discriminator_fake_logits = discriminator(
            self,
            self.synthetic_images,
            depth=self.progressive_depth,
            name='discriminator',
            transition=self.transition,
            alpha_transition=self.alpha_transition,
            reuse=True)

        self._build_progressive_savers()

        self.calculate_losses()

        if self.hyperverbose:
            self.model_summary()

    def _build_progressive_savers(self):
        """Partition the trainable variables by name and create the savers.

        Sets these attributes:

        * ``d_vars`` / ``g_vars``: variables whose names contain
          'discriminator' / 'generator'.
        * ``d_vars_n`` / ``g_vars_n``: the subsets matching
          'discriminator_n' / 'generator_n'.
        * ``d_vars_n_2`` / ``g_vars_n_2``: the RGB-conversion subsets
          matching '*_y_rgb_conv'.
        * The ``*_read`` / ``*_rgb`` lists: the subsets of the above whose
          names do not contain ``str(self.output_size)``. Presumably these
          are the layers that already existed at the previous stage.

        Creates ``saver`` over ``d_vars + g_vars`` and ``r_saver`` over the
        ``*_n_read`` lists. ``rgb_saver`` is created only when the RGB subsets
        are non-empty; otherwise it is left unset.
        """
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        # NOTE(review): this is a substring test. For example, an output_size
        # of 8 also excludes any name containing '128'. Confirm against the
        # layer naming used by generator()/discriminator().
        size_tag = str(self.output_size)

        def without_current_size(variables):
            # Drop variables whose name mentions the current resolution.
            return [var for var in variables if size_tag not in var.name]

        # Variables kept across growth stages.
        self.d_vars_n = [var for var in self.d_vars
                         if 'discriminator_n' in var.name]
        self.g_vars_n = [var for var in self.g_vars
                         if 'generator_n' in var.name]
        self.d_vars_n_read = without_current_size(self.d_vars_n)
        self.g_vars_n_read = without_current_size(self.g_vars_n)

        # RGB conversion layers.
        self.d_vars_n_2 = [var for var in self.d_vars
                           if 'discriminator_y_rgb_conv' in var.name]
        self.g_vars_n_2 = [var for var in self.g_vars
                           if 'generator_y_rgb_conv' in var.name]
        self.d_vars_n_2_rgb = without_current_size(self.d_vars_n_2)
        self.g_vars_n_2_rgb = without_current_size(self.g_vars_n_2)

        self.saver = tf.train.Saver(self.d_vars + self.g_vars)
        self.r_saver = tf.train.Saver(self.d_vars_n_read + self.g_vars_n_read)
        rgb_vars = self.d_vars_n_2_rgb + self.g_vars_n_2_rgb
        if rgb_vars:
            self.rgb_saver = tf.train.Saver(rgb_vars)
Exemplo n.º 3
0
    def build_tensorflow_model(self, batch_size):

        """Build the GAN graph plus a UNet-activation objective on its output.

        Builds, in order:

        * latent/image placeholders;
        * the generator, and a discriminator applied to the reference and
          generated images (the second call uses reuse=True);
        * a pixel-wise MSE ``basic_loss``;
        * Wasserstein-style ``D_loss`` / ``G_loss`` with a gradient penalty;
        * a Keras UNet stacked on ``synthetic_images``, with weights loaded
          from 'retinal_seg_weights.h5';
        * an ``activation_loss`` on one channel of the tensor returned by
          ``self.grab_tensor(self.activated_tensor_name)``;
        * a Saver over the generator/discriminator variables;
        * several optimizers.

        Args:
            batch_size: Leading dimension recorded in ``model_input_shape`` and
                ``model_output_shape``. The placeholders keep a variable (None)
                batch dimension.

        NOTE(review): this method contains Python 2 ``print`` statements, so it
        only parses under Python 2.
        """

        # Set input/output shapes for reference during inference.
        # NOTE(review): the output shape currently mirrors the input shape.
        self.model_input_shape = tuple([batch_size] + list(self.input_shape))
        self.model_output_shape = tuple([batch_size] + list(self.input_shape))

        self.latent = tf.placeholder(tf.float32, [None, self.latent_size])
        self.reference_images = tf.placeholder(tf.float32, [None] + list(self.model_input_shape)[1:])

        self.synthetic_images = generator(self, self.latent, depth=self.depth, name='generator')

        # This discriminator() returns a 4-tuple. Only the last element (used
        # below as the critic logits) is kept. The second call reuses weights.
        _, _, _, self.discriminator_real_logits = discriminator(self, self.reference_images, depth=self.depth + 1, name='discriminator')
        _, _, _, self.discriminator_fake_logits = discriminator(self, self.synthetic_images, depth=self.depth + 1, name='discriminator', reuse=True)

        # Pixel-wise mean squared error between the fed reference batch and the
        # generated batch. The two tensors must have matching shapes.
        self.basic_loss = tf.reduce_mean(tf.square(self.reference_images - self.synthetic_images))

        # Wasserstein critic/generator losses computed on the raw logits.
        self.D_loss = tf.reduce_mean(self.discriminator_fake_logits) - tf.reduce_mean(self.discriminator_real_logits)
        self.G_loss = -tf.reduce_mean(self.discriminator_fake_logits)

        # Gradient penalty as in WGAN-GP (Gulrajani et al., 2017). Evaluate the
        # critic at random points on the line between each reference/generated
        # pair, using one uniform alpha in [0, 1) per example.
        self.differences = self.synthetic_images - self.reference_images
        self.alpha = tf.random_uniform(shape=[tf.shape(self.differences)[0], 1, 1, 1], minval=0., maxval=1.)
        interpolates = self.reference_images + (self.alpha * self.differences)
        _, _, _, discri_logits = discriminator(self, interpolates, reuse=True, depth=self.depth + 1, name='discriminator')
        gradients = tf.gradients(discri_logits, [interpolates])[0]

        # Per-example L2 norm of the critic gradient. The penalty is the mean
        # squared deviation of that norm from 1. This assumes rank-4 image
        # tensors, because the alpha shape above and the reduction axes
        # [1, 2, 3] are hard-coded. (reduction_indices is the deprecated TF1
        # alias of axis.)
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2, 3]))
        self.gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
        tf.summary.scalar("gp_loss", self.gradient_penalty)

        # Keep the un-penalized critic loss in D_origin_loss. Then add
        # 10 * gradient penalty and 0.001 * mean(real_logits ** 2). Because +=
        # on a tensor rebinds D_loss to a new tensor, D_origin_loss is unchanged.
        self.D_origin_loss = self.D_loss
        self.D_loss += 10 * self.gradient_penalty
        self.D_loss += 0.001 * tf.reduce_mean(tf.square(self.discriminator_real_logits - 0.0))

        # Disabled experiments, kept for reference:
        # vgg_model = tf.keras.applications.VGG19(include_top=False,
        #                                     weights='imagenet',
        #                                     input_tensor=self.synthetic_images,
        #                                     input_shape=(64, 64, 3),
        #                                     pooling=None,
        #                                     classes=1000)
        # print(vgg_model)

        # self.load_reference_model()

        # Wrap the generator output as a Keras Input so that the UNet below is
        # built directly on top of synthetic_images.
        input_tensor = keras.layers.Input(tensor=self.synthetic_images, shape=self.input_shape)

        model_parameters = {'input_shape': self.input_shape,
                    'downsize_filters_factor': 1,
                    'pool_size': (2, 2), 
                    'kernel_size': (3, 3), 
                    'dropout': 0, 
                    'batch_norm': True, 
                    'initial_learning_rate': 0.00001, 
                    'output_type': 'binary_label',
                    'num_outputs': 1, 
                    'activation': 'relu',
                    'padding': 'same', 
                    'implementation': 'keras',
                    'depth': 3,
                    'max_filter': 128,
                    'stride_size': (1, 1),
                    'input_tensor': input_tensor}

        # Build the UNet on the generator output and load pretrained weights.
        # NOTE(review): the weights path is relative, so it is presumably
        # resolved against the working directory. unet_model is not referenced
        # again in this method; its layers appear to be reached later through
        # grab_tensor.
        unet_output = UNet(**model_parameters)
        unet_model = keras.models.Model(input_tensor, unet_output.output_layer)
        unet_model.load_weights('retinal_seg_weights.h5')

        if self.hyperverbose:
            self.model_summary()

        # self.find_layers(['sampling'])

        # Fetch the tensor whose activation is maximized. grab_tensor
        # presumably looks it up by name. Keep only channel filter_num,
        # re-stacked so that a trailing size-1 axis remains. The print
        # statements are debug output.
        self.activated_tensor = self.grab_tensor(self.activated_tensor_name)
        print self.activated_tensor
        self.activated_tensor = tf.stack([self.activated_tensor[..., self.filter_num]], axis=-1)
        print self.activated_tensor
        # self.input_tensor = self.grab_tensor(self.input_tensor_name)

        # Negated mean activation: minimizing this maximizes the selected
        # channel's mean activation.
        self.activation_loss = -1 * tf.reduce_mean(self.activated_tensor)
        # Gradient of the activation loss with respect to the generated images.
        # NOTE: the attribute name is misspelled ('activaton_graidents').
        # Renaming it could break other code that reads it.
        self.activaton_graidents = tf.gradients(self.activation_loss, self.synthetic_images)
        print self.activaton_graidents

        # Split the trainable variables into discriminator/generator sets by
        # name substring.
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]

        # Create save/load operation
        self.saver = tf.train.Saver(self.g_vars + self.d_vars)

        # NOTE(review): the activation term is weighted by 0.0, so this equals
        # G_loss in value and opti_G effectively trains on G_loss alone.
        # Confirm whether that is intended.
        self.G_activation_loss = self.G_loss + .000 * self.activation_loss

        # Optimizers. opti_D always uses tf.train.AdamOptimizer and minimizes
        # over d_vars. The generator-side optimizers use
        # tensorflow_optimizer_dict[self.optimizer] and minimize over g_vars.
        # Because they are passed beta1/beta2, the selected entry must accept
        # those keyword arguments.
        self.opti_D = tf.train.AdamOptimizer(learning_rate=self.initial_learning_rate, beta1=0.0, beta2=0.99).minimize(
            self.D_loss, var_list=self.d_vars)
        self.opti_G = self.tensorflow_optimizer_dict[self.optimizer](learning_rate=self.initial_learning_rate, beta1=0.0, beta2=0.99).minimize(self.G_activation_loss, var_list=self.g_vars)

        # Generator-only objective: activation maximization plus the pixel-wise
        # MSE, each with weight 1.
        self.combined_loss = 1 * self.activation_loss + 1 * self.basic_loss

        self.combined_optimizer = self.tensorflow_optimizer_dict[self.optimizer](learning_rate=self.initial_learning_rate, beta1=0.0, beta2=0.99).minimize(self.combined_loss, var_list=self.g_vars)

        # Generator trained on the pixel-wise MSE alone.
        self.basic_optimizer = self.tensorflow_optimizer_dict[self.optimizer](learning_rate=self.initial_learning_rate, beta1=0.0, beta2=0.99).minimize(self.basic_loss, var_list=self.g_vars)

        # Generator trained on the activation loss alone.
        self.activation_optimizer = self.tensorflow_optimizer_dict[self.optimizer](learning_rate=self.initial_learning_rate, beta1=0.0, beta2=0.99).minimize(self.activation_loss, var_list=self.g_vars)