예제 #1
0
    def build_model(self):
        """Construct the complete ALAD graph.

        Builds, in order: input placeholders; the encoder / generator /
        three discriminators (xz, xx, zz) under the "ALAD" variable scope;
        the adversarial + cycle-consistency losses; Adam optimizers with
        exponential-moving-average (EMA) shadow variables for inference;
        the EMA-based test graph and anomaly scores; and TensorBoard
        summaries.

        NOTE(review): the spectral-norm flag is read from four different
        config paths below (config.trainer.do_spectral_norm,
        config.spectral_norm, config.do_spectral_norm,
        config.trainer.spectral_norm) — confirm they all exist and agree.
        """
        # Kernel initializer for the network layers, chosen from config.
        if self.config.trainer.init_type == "normal":
            self.init_kernel = tf.random_normal_initializer(mean=0.0,
                                                            stddev=0.02)
        elif self.config.trainer.init_type == "xavier":
            self.init_kernel = tf.contrib.layers.xavier_initializer(
                uniform=False, seed=None, dtype=tf.float32)
        # Placeholders
        self.is_training = tf.placeholder(tf.bool)
        self.image_tensor = tf.placeholder(tf.float32,
                                           shape=[None] +
                                           self.config.trainer.image_dims,
                                           name="x")
        self.noise_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.config.trainer.noise_dim],
            name="noise")
        # Target labels are fed in (rather than built with ones_like /
        # zeros_like) so the trainer can smooth or flip them externally.
        self.true_labels = tf.placeholder(dtype=tf.float32,
                                          shape=[None, 1],
                                          name="true_labels")
        self.generated_labels = tf.placeholder(dtype=tf.float32,
                                               shape=[None, 1],
                                               name="gen_labels")
        # Additive instance noise applied to real and generated images.
        self.real_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="real_noise")
        self.fake_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="fake_noise")
        # Building the Graph
        with tf.variable_scope("ALAD"):
            # Generated noise from the encoder
            with tf.variable_scope("Encoder_Model"):
                self.z_gen = self.encoder(
                    self.image_tensor,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)
            # Generated image and reconstructed image from the Generator
            with tf.variable_scope("Generator_Model"):
                self.img_gen = self.generator(
                    self.noise_tensor) + self.fake_noise
                self.rec_img = self.generator(self.z_gen)

            # Reconstructed noise of generated image from the encoder.
            # NOTE(review): reads self.config.spectral_norm here, but
            # self.config.trainer.do_spectral_norm above — confirm.
            with tf.variable_scope("Encoder_Model"):
                self.rec_z = self.encoder(
                    self.img_gen, do_spectral_norm=self.config.spectral_norm)

            # Discriminator results of (G(z),z) and (x, E(x))
            with tf.variable_scope("Discriminator_Model_XZ"):
                l_generator, inter_layer_rct_xz = self.discriminator_xz(
                    self.img_gen,
                    self.noise_tensor,
                    do_spectral_norm=self.config.spectral_norm)
                # NOTE(review): config.do_spectral_norm (no .trainer) here.
                l_encoder, inter_layer_inp_xz = self.discriminator_xz(
                    self.image_tensor + self.real_noise,
                    self.z_gen,
                    do_spectral_norm=self.config.do_spectral_norm,
                )

            # Discriminator results of (x, x) and (x, G(E(x))
            with tf.variable_scope("Discriminator_Model_XX"):
                x_logit_real, inter_layer_inp_xx = self.discriminator_xx(
                    self.image_tensor + self.real_noise,
                    self.image_tensor + self.real_noise,
                    do_spectral_norm=self.config.spectral_norm,
                )
                x_logit_fake, inter_layer_rct_xx = self.discriminator_xx(
                    self.image_tensor + self.real_noise,
                    self.rec_img,
                    do_spectral_norm=self.config.spectral_norm,
                )
            # Discriminator results of (z, z) and (z, E(G(z))
            with tf.variable_scope("Discriminator_Model_ZZ"):
                z_logit_real, _ = self.discriminator_zz(
                    self.noise_tensor,
                    self.noise_tensor,
                    do_spectral_norm=self.config.spectral_norm)
                z_logit_fake, _ = self.discriminator_zz(
                    self.noise_tensor,
                    self.rec_z,
                    do_spectral_norm=self.config.spectral_norm)
        ########################################################################
        # LOSS FUNCTIONS
        ########################################################################
        with tf.name_scope("Loss_Functions"):
            # discriminator xz

            # Discriminator should classify encoder pair as real
            loss_dis_enc = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.true_labels, logits=l_encoder))
            # Discriminator should classify generator pair as fake
            loss_dis_gen = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.generated_labels, logits=l_generator))
            self.dis_loss_xz = loss_dis_gen + loss_dis_enc

            # discriminator xx: (x, x) is real, (x, G(E(x))) is fake
            x_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=x_logit_real, labels=tf.ones_like(x_logit_real))
            x_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=x_logit_fake, labels=tf.zeros_like(x_logit_fake))
            self.dis_loss_xx = tf.reduce_mean(x_real_dis + x_fake_dis)
            # discriminator zz: (z, z) is real, (z, E(G(z))) is fake
            z_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=z_logit_real, labels=tf.ones_like(z_logit_real))
            z_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=z_logit_fake, labels=tf.zeros_like(z_logit_fake))
            self.dis_loss_zz = tf.reduce_mean(z_real_dis + z_fake_dis)
            # Compute the whole discriminator loss; the zz term is optional.
            self.loss_discriminator = (self.dis_loss_xz + self.dis_loss_xx +
                                       self.dis_loss_zz
                                       if self.config.trainer.allow_zz else
                                       self.dis_loss_xz + self.dis_loss_xx)
            # generator and encoder targets (optionally label-flipped)
            if self.config.trainer.flip_labels:
                labels_gen = tf.zeros_like(l_generator)
                labels_enc = tf.ones_like(l_encoder)
            else:
                labels_gen = tf.ones_like(l_generator)
                labels_enc = tf.zeros_like(l_encoder)

            gen_loss_xz = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_gen,
                                                        logits=l_generator))
            enc_loss_xz = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_enc,
                                                        logits=l_encoder))

            # Cycle-consistency terms: targets are flipped relative to the
            # discriminator losses above (non-saturating GAN objective).
            x_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=x_logit_real, labels=tf.zeros_like(x_logit_real))
            x_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=x_logit_fake, labels=tf.ones_like(x_logit_fake))
            z_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=z_logit_real, labels=tf.zeros_like(z_logit_real))
            z_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=z_logit_fake, labels=tf.ones_like(z_logit_fake))

            cost_x = tf.reduce_mean(x_real_gen + x_fake_gen)
            cost_z = tf.reduce_mean(z_real_gen + z_fake_gen)

            cycle_consistency_loss = cost_x + cost_z if self.config.trainer.allow_zz else cost_x
            self.loss_generator = gen_loss_xz + cycle_consistency_loss
            self.loss_encoder = enc_loss_xz + cycle_consistency_loss

        ########################################################################
        # OPTIMIZATION
        ########################################################################
        with tf.name_scope("Optimizers"):

            # control op dependencies for batch norm and trainable variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.dxzvars = [
                v for v in all_variables
                if v.name.startswith("ALAD/Discriminator_Model_XZ")
            ]
            self.dxxvars = [
                v for v in all_variables
                if v.name.startswith("ALAD/Discriminator_Model_XX")
            ]
            self.dzzvars = [
                v for v in all_variables
                if v.name.startswith("ALAD/Discriminator_Model_ZZ")
            ]
            self.gvars = [
                v for v in all_variables
                if v.name.startswith("ALAD/Generator_Model")
            ]
            self.evars = [
                v for v in all_variables
                if v.name.startswith("ALAD/Encoder_Model")
            ]

            # Per-subnetwork UPDATE_OPS (e.g. batch-norm moving averages),
            # attached as control dependencies of the matching minimize op.
            self.update_ops_gen = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="ALAD/Generator_Model")
            self.update_ops_enc = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                    scope="ALAD/Encoder_Model")
            self.update_ops_dis_xz = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="ALAD/Discriminator_Model_XZ")
            self.update_ops_dis_xx = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="ALAD/Discriminator_Model_XX")
            self.update_ops_dis_zz = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="ALAD/Discriminator_Model_ZZ")
            self.disc_optimizer = tf.train.AdamOptimizer(
                learning_rate=self.config.trainer.discriminator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.gen_optimizer = tf.train.AdamOptimizer(
                learning_rate=self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Encoder shares the generator learning rate.
            self.enc_optimizer = tf.train.AdamOptimizer(
                learning_rate=self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )

            # Only the generator op advances the global step.
            with tf.control_dependencies(self.update_ops_gen):
                self.gen_op = self.gen_optimizer.minimize(
                    self.loss_generator,
                    global_step=self.global_step_tensor,
                    var_list=self.gvars)
            with tf.control_dependencies(self.update_ops_enc):
                self.enc_op = self.enc_optimizer.minimize(self.loss_encoder,
                                                          var_list=self.evars)

            with tf.control_dependencies(self.update_ops_dis_xz):
                self.dis_op_xz = self.disc_optimizer.minimize(
                    self.dis_loss_xz, var_list=self.dxzvars)

            with tf.control_dependencies(self.update_ops_dis_xx):
                self.dis_op_xx = self.disc_optimizer.minimize(
                    self.dis_loss_xx, var_list=self.dxxvars)

            with tf.control_dependencies(self.update_ops_dis_zz):
                self.dis_op_zz = self.disc_optimizer.minimize(
                    self.dis_loss_zz, var_list=self.dzzvars)

            # Exponential Moving Average for inference
            def train_op_with_ema_dependency(vars, op):
                # Wrap a minimize op so that running the returned train op
                # also updates the EMA shadows of `vars`.
                ema = tf.train.ExponentialMovingAverage(
                    decay=self.config.trainer.ema_decay)
                maintain_averages_op = ema.apply(vars)
                with tf.control_dependencies([op]):
                    train_op = tf.group(maintain_averages_op)
                return train_op, ema

            self.train_gen_op, self.gen_ema = train_op_with_ema_dependency(
                self.gvars, self.gen_op)
            self.train_enc_op, self.enc_ema = train_op_with_ema_dependency(
                self.evars, self.enc_op)

            self.train_dis_op_xz, self.xz_ema = train_op_with_ema_dependency(
                self.dxzvars, self.dis_op_xz)
            self.train_dis_op_xx, self.xx_ema = train_op_with_ema_dependency(
                self.dxxvars, self.dis_op_xx)
            self.train_dis_op_zz, self.zz_ema = train_op_with_ema_dependency(
                self.dzzvars, self.dis_op_zz)

        # Test-time graph: reuse the training scopes but read variables
        # through their EMA shadows via the custom getter.
        with tf.variable_scope("ALAD"):
            with tf.variable_scope("Encoder_Model"):
                self.z_gen_ema = self.encoder(
                    self.image_tensor,
                    getter=sn.get_getter(self.enc_ema),
                    do_spectral_norm=self.config.trainer.spectral_norm,
                )

            with tf.variable_scope("Generator_Model"):
                self.rec_x_ema = self.generator(self.z_gen_ema,
                                                getter=sn.get_getter(
                                                    self.gen_ema))
                self.x_gen_ema = self.generator(self.noise_tensor,
                                                getter=sn.get_getter(
                                                    self.gen_ema))
            with tf.variable_scope("Discriminator_Model_XX"):
                l_encoder_emaxx, inter_layer_inp_emaxx = self.discriminator_xx(
                    self.image_tensor,
                    self.image_tensor,
                    getter=sn.get_getter(self.xx_ema),
                    do_spectral_norm=self.config.trainer.spectral_norm,
                )
                l_generator_emaxx, inter_layer_rct_emaxx = self.discriminator_xx(
                    self.image_tensor,
                    self.rec_x_ema,
                    getter=sn.get_getter(self.xx_ema),
                    do_spectral_norm=self.config.trainer.spectral_norm,
                )

        # Anomaly scores (per-sample, hence the squeeze to drop the
        # trailing singleton dimension).
        with tf.name_scope("Testing"):
            with tf.variable_scope("Scores"):
                # Cross-entropy score of the (x, G(E(x))) discriminator.
                score_ch = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(l_generator_emaxx),
                    logits=l_generator_emaxx)
                self.score_ch = tf.squeeze(score_ch)

                # L1 reconstruction error.
                rec = self.image_tensor - self.rec_x_ema
                rec = tf.layers.Flatten()(rec)
                score_l1 = tf.norm(rec,
                                   ord=1,
                                   axis=1,
                                   keepdims=False,
                                   name="d_loss")
                self.score_l1 = tf.squeeze(score_l1)

                # L2 reconstruction error.
                rec = self.image_tensor - self.rec_x_ema
                rec = tf.layers.Flatten()(rec)
                score_l2 = tf.norm(rec,
                                   ord=2,
                                   axis=1,
                                   keepdims=False,
                                   name="d_loss")
                self.score_l2 = tf.squeeze(score_l2)

                # Feature-matching score on the xx-discriminator's
                # intermediate layer, with configurable norm degree.
                inter_layer_inp, inter_layer_rct = (inter_layer_inp_emaxx,
                                                    inter_layer_rct_emaxx)
                fm = inter_layer_inp - inter_layer_rct
                fm = tf.layers.Flatten()(fm)
                score_fm = tf.norm(fm,
                                   ord=self.config.trainer.degree,
                                   axis=1,
                                   keepdims=False,
                                   name="d_loss")
                self.score_fm = tf.squeeze(score_fm)

        # Validation metric for early stopping (mean feature-matching score).
        if self.config.trainer.enable_early_stop:
            self.rec_error_valid = tf.reduce_mean(score_fm)
        ########################################################################
        # TENSORBOARD
        ########################################################################
        if self.config.log.enable_summary:

            with tf.name_scope("train_summary"):

                with tf.name_scope("dis_summary"):
                    tf.summary.scalar("loss_discriminator",
                                      self.loss_discriminator, ["dis"])
                    tf.summary.scalar("loss_dis_encoder", loss_dis_enc,
                                      ["dis"])
                    tf.summary.scalar("loss_dis_gen", loss_dis_gen, ["dis"])
                    tf.summary.scalar("loss_dis_xz", self.dis_loss_xz, ["dis"])
                    tf.summary.scalar("loss_dis_xx", self.dis_loss_xx, ["dis"])
                    if self.config.trainer.allow_zz:
                        tf.summary.scalar("loss_dis_zz", self.dis_loss_zz,
                                          ["dis"])

                with tf.name_scope("gen_summary"):

                    tf.summary.scalar("loss_generator", self.loss_generator,
                                      ["gen"])
                    tf.summary.scalar("loss_encoder", self.loss_encoder,
                                      ["gen"])
                    tf.summary.scalar("loss_encgen_dxx", cost_x, ["gen"])
                    if self.config.trainer.allow_zz:
                        tf.summary.scalar("loss_encgen_dzz", cost_z, ["gen"])

                with tf.name_scope("img_summary"):
                    # Latent-space heatmap is rendered externally and fed
                    # in as a fixed-size RGB image.
                    heatmap_pl_latent = tf.placeholder(
                        tf.float32,
                        shape=(1, 480, 640, 3),
                        name="heatmap_pl_latent")
                    self.sum_op_latent = tf.summary.image(
                        "heatmap_latent", heatmap_pl_latent)

                with tf.name_scope("image_summary"):
                    tf.summary.image("rec_img", self.rec_img, 3, ["image"])
                    tf.summary.image("input_image", self.image_tensor, 3,
                                     ["image"])
                    tf.summary.image("image_gen", self.img_gen, 3, ["image"])
                    tf.summary.image("input_image", self.image_tensor, 3,
                                     ["image_2"])
                    tf.summary.image("rec_x_ema", self.rec_x_ema, 3,
                                     ["image_2"])

        if self.config.trainer.enable_early_stop:
            with tf.name_scope("validation_summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

        # Merge summaries per collection key used above.
        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op = tf.summary.merge([self.sum_op_dis, self.sum_op_gen])
        self.sum_op_im = tf.summary.merge_all("image")
        self.sum_op_im_test = tf.summary.merge_all("image_2")
        self.sum_op_valid = tf.summary.merge_all("v")
예제 #2
0
    def build_model(self):
        """Construct the EncEBGAN training and testing graphs.

        Builds, in order: input placeholders; generator, (energy-based)
        autoencoder-discriminator and encoder networks under the
        "EncEBGAN" scope; the EBGAN margin / pull-away / encoder losses;
        Adam optimizers with EMA shadow variables; an EMA-based test
        graph; image- and embedding-space anomaly scores; and
        TensorBoard summaries.
        """
        # Initializations
        # Kernel initialization for the convolutions
        if self.config.trainer.init_type == "normal":
            self.init_kernel = tf.random_normal_initializer(mean=0.0,
                                                            stddev=0.02)
        elif self.config.trainer.init_type == "xavier":
            self.init_kernel = tf.contrib.layers.xavier_initializer(
                uniform=False, seed=None, dtype=tf.float32)
        # Placeholders
        self.is_training_gen = tf.placeholder(tf.bool)
        self.is_training_dis = tf.placeholder(tf.bool)
        self.is_training_enc = tf.placeholder(tf.bool)
        self.image_input = tf.placeholder(tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")
        self.noise_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.config.trainer.noise_dim],
            name="noise")
        # Build Training Graph
        self.logger.info("Building training graph...")
        with tf.variable_scope("EncEBGAN"):
            with tf.variable_scope("Generator_Model"):
                self.image_gen = self.generator(self.noise_tensor)

            # Discriminator is an autoencoder: returns (embedding, decoded).
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_real, self.decoded_real = self.discriminator(
                    self.image_input,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)
                self.embedding_fake, self.decoded_fake = self.discriminator(
                    self.image_gen,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)
            with tf.variable_scope("Encoder_Model"):
                self.image_encoded = self.encoder(self.image_input)

            with tf.variable_scope("Generator_Model"):
                self.image_gen_enc = self.generator(self.image_encoded)

            with tf.variable_scope("Discriminator_Model"):
                self.embedding_enc_fake, self.decoded_enc_fake = self.discriminator(
                    self.image_gen_enc,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)
                self.embedding_enc_real, self.decoded_enc_real = self.discriminator(
                    self.image_input,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)
        # Loss functions
        with tf.name_scope("Loss_Functions"):
            with tf.name_scope("Generator_Discriminator"):
                # Discriminator Loss: reconstruction energy of real and
                # generated images.
                # NOTE(review): in "norm" mode the mse_loss result is wrapped
                # in reduce_mean, in "mse" mode it is not — confirm mse_loss
                # already reduces in that mode.
                if self.config.trainer.mse_mode == "norm":
                    self.disc_loss_real = tf.reduce_mean(
                        self.mse_loss(self.decoded_real,
                                      self.image_input,
                                      mode="norm"))
                    self.disc_loss_fake = tf.reduce_mean(
                        self.mse_loss(self.decoded_fake,
                                      self.image_gen,
                                      mode="norm"))
                elif self.config.trainer.mse_mode == "mse":
                    self.disc_loss_real = self.mse_loss(self.decoded_real,
                                                        self.image_input,
                                                        mode="mse")
                    self.disc_loss_fake = self.mse_loss(self.decoded_fake,
                                                        self.image_gen,
                                                        mode="mse")
                # EBGAN margin loss: hinge on fake energy + real energy.
                self.loss_discriminator = (tf.math.maximum(
                    self.config.trainer.disc_margin - self.disc_loss_fake, 0) +
                                           self.disc_loss_real)
                # Generator Loss (optionally with pull-away regularizer)
                pt_loss = 0
                if self.config.trainer.pullaway:
                    pt_loss = self.pullaway_loss(self.embedding_fake)
                self.loss_generator = self.disc_loss_fake + self.config.trainer.pt_weight * pt_loss

            with tf.name_scope("Encoder"):
                # Encoder loss: image reconstruction + embedding matching.
                if self.config.trainer.mse_mode == "norm":
                    self.loss_enc_rec = tf.reduce_mean(
                        self.mse_loss(self.image_gen_enc,
                                      self.image_input,
                                      mode="norm"))
                    self.loss_enc_f = tf.reduce_mean(
                        self.mse_loss(self.embedding_enc_real,
                                      self.embedding_enc_fake,
                                      mode="norm"))
                elif self.config.trainer.mse_mode == "mse":
                    self.loss_enc_rec = tf.reduce_mean(
                        self.mse_loss(self.image_gen_enc,
                                      self.image_input,
                                      mode="mse"))
                    self.loss_enc_f = tf.reduce_mean(
                        self.mse_loss(self.embedding_enc_real,
                                      self.embedding_enc_fake,
                                      mode="mse"))
                self.loss_encoder = (
                    self.loss_enc_rec +
                    self.config.trainer.encoder_f_factor * self.loss_enc_f)

        # Optimizers
        with tf.name_scope("Optimizers"):
            self.generator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr_gen,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.encoder_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr_enc,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr_dis,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables
            self.generator_vars = [
                v for v in all_variables
                if v.name.startswith("EncEBGAN/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables
                if v.name.startswith("EncEBGAN/Discriminator_Model")
            ]
            # Encoder Network Variables
            self.encoder_vars = [
                v for v in all_variables
                if v.name.startswith("EncEBGAN/Encoder_Model")
            ]
            # Generator Network Operations (batch-norm updates etc.)
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="EncEBGAN/Generator_Model")
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="EncEBGAN/Discriminator_Model")
            self.enc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="EncEBGAN/Encoder_Model")
            # NOTE(review): both gen_op and enc_op pass global_step_tensor,
            # so the global step advances twice per full train cycle —
            # confirm intended.
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.loss_generator,
                    var_list=self.generator_vars,
                    global_step=self.global_step_tensor,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars)
            with tf.control_dependencies(self.enc_update_ops):
                self.enc_op = self.encoder_optimizer.minimize(
                    self.loss_encoder,
                    var_list=self.encoder_vars,
                    global_step=self.global_step_tensor,
                )
            # Exponential Moving Average for Estimation
            self.dis_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(
                self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)

            self.enc_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_enc = self.enc_ema.apply(self.encoder_vars)

            # Each train op runs the minimize op, then its EMA update.
            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

            with tf.control_dependencies([self.enc_op]):
                self.train_enc_op = tf.group(maintain_averages_op_enc)

        # Build Test Graph: reuse training scopes, reading variables through
        # their EMA shadows via the custom getter.
        self.logger.info("Building Testing Graph...")
        with tf.variable_scope("EncEBGAN"):
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_q_ema, self.decoded_q_ema = self.discriminator(
                    self.image_input,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
            with tf.variable_scope("Generator_Model"):
                self.image_gen_ema = self.generator(self.embedding_q_ema,
                                                    getter=get_getter(
                                                        self.gen_ema))
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_rec_ema, self.decoded_rec_ema = self.discriminator(
                    self.image_gen_ema,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
            with tf.variable_scope("Encoder_Model"):
                self.image_encoded_ema = self.encoder(self.image_input,
                                                      getter=get_getter(
                                                          self.enc_ema))

            with tf.variable_scope("Generator_Model"):
                self.image_gen_enc_ema = self.generator(self.image_encoded_ema,
                                                        getter=get_getter(
                                                            self.gen_ema))
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_enc_fake_ema, self.decoded_enc_fake_ema = self.discriminator(
                    self.image_gen_enc_ema,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
                self.embedding_enc_real_ema, self.decoded_enc_real_ema = self.discriminator(
                    self.image_input,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
        # Anomaly scores (per-sample; squeeze drops the singleton dim).
        with tf.name_scope("Testing"):
            with tf.name_scope("Image_Based"):
                delta = self.image_input - self.image_gen_enc_ema
                # Negated residual, exposed for visualization downstream.
                self.rec_residual = -delta
                delta_flat = tf.layers.Flatten()(delta)
                # NOTE(review): named *_l1 but computed with ord=2 —
                # confirm whether ord=1 was intended.
                img_score_l1 = tf.norm(delta_flat,
                                       ord=2,
                                       axis=1,
                                       keepdims=False,
                                       name="img_loss__1")
                self.img_score_l1 = tf.squeeze(img_score_l1)

                delta = self.decoded_enc_fake_ema - self.decoded_enc_real_ema
                delta_flat = tf.layers.Flatten()(delta)
                img_score_l2 = tf.norm(delta_flat,
                                       ord=2,
                                       axis=1,
                                       keepdims=False,
                                       name="img_loss__2")
                self.img_score_l2 = tf.squeeze(img_score_l2)
                # Weighted combination of the two image-space scores.
                self.score_comb = (
                    (1 - self.config.trainer.feature_match_weight) *
                    self.img_score_l1 +
                    self.config.trainer.feature_match_weight *
                    self.img_score_l2)
            with tf.name_scope("Noise_Based"):
                delta = self.image_encoded_ema - self.embedding_enc_fake_ema
                delta_flat = tf.layers.Flatten()(delta)
                z_score_l1 = tf.norm(delta_flat,
                                     ord=2,
                                     axis=1,
                                     keepdims=False,
                                     name="z_loss_1")
                self.z_score_l1 = tf.squeeze(z_score_l1)

                delta = self.embedding_enc_real_ema - self.embedding_enc_fake_ema
                delta_flat = tf.layers.Flatten()(delta)
                z_score_l2 = tf.norm(delta_flat,
                                     ord=2,
                                     axis=1,
                                     keepdims=False,
                                     name="z_loss_2")
                self.z_score_l2 = tf.squeeze(z_score_l2)

                # Weighted combination of the two embedding-space scores.
                self.score_comb_2 = (
                    (1 - self.config.trainer.feature_match_weight) *
                    self.z_score_l1 +
                    self.config.trainer.feature_match_weight * self.z_score_l2)

        # Tensorboard
        if self.config.log.enable_summary:
            with tf.name_scope("train_summary"):
                with tf.name_scope("dis_summary"):
                    tf.summary.scalar("loss_disc", self.loss_discriminator,
                                      ["dis"])
                    tf.summary.scalar("loss_disc_real", self.disc_loss_real,
                                      ["dis"])
                    tf.summary.scalar("loss_disc_fake", self.disc_loss_fake,
                                      ["dis"])
                with tf.name_scope("gen_summary"):
                    tf.summary.scalar("loss_generator", self.loss_generator,
                                      ["gen"])
                with tf.name_scope("enc_summary"):
                    tf.summary.scalar("loss_encoder", self.loss_encoder,
                                      ["enc"])
                with tf.name_scope("img_summary"):
                    tf.summary.image("input_image", self.image_input, 1,
                                     ["img_1"])
                    tf.summary.image("reconstructed", self.image_gen, 1,
                                     ["img_1"])
                    tf.summary.image("input_enc", self.image_input, 1,
                                     ["img_2"])
                    tf.summary.image("reconstructed", self.image_gen_enc, 1,
                                     ["img_2"])
                    tf.summary.image("input_image", self.image_input, 1,
                                     ["test"])
                    tf.summary.image("reconstructed", self.image_gen_enc_ema,
                                     1, ["test"])
                    tf.summary.image("residual", self.rec_residual, 1,
                                     ["test"])

        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op_enc = tf.summary.merge_all("enc")
        self.sum_op_im_1 = tf.summary.merge_all("img_1")
        self.sum_op_im_2 = tf.summary.merge_all("img_2")
        self.sum_op_im_test = tf.summary.merge_all("test")
        self.sum_op = tf.summary.merge([self.sum_op_dis, self.sum_op_gen])
Example #3
0
    def build_model(self):
        """Build the FAnogan training and testing graphs.

        Creates the input placeholders, the generator/discriminator/encoder
        training sub-graphs, the losses for the configured mode ("standard",
        "wgan" or "wgan_gp"), the per-network optimizers with EMA-tracked
        weights, an EMA-based testing graph, the anomaly scores (izi_f / ziz)
        and the Tensorboard summary ops.
        """
        # Kernel initialization for the convolutions
        if self.config.trainer.init_type == "normal":
            self.init_kernel = tf.random_normal_initializer(mean=0.0, stddev=0.02)
        elif self.config.trainer.init_type == "xavier":
            self.init_kernel = tf.contrib.layers.xavier_initializer(
                uniform=False, seed=None, dtype=tf.float32
            )
        # Placeholders
        self.is_training_gen = tf.placeholder(tf.bool)
        self.is_training_dis = tf.placeholder(tf.bool)
        self.is_training_enc = tf.placeholder(tf.bool)
        self.image_input = tf.placeholder(
            tf.float32, shape=[None] + self.config.trainer.image_dims, name="x"
        )
        self.noise_tensor = tf.placeholder(
            tf.float32, shape=[None, self.config.trainer.noise_dim], name="noise"
        )
        self.true_labels = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="true_labels")
        self.generated_labels = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="gen_labels")
        # Additive input/output noise used for discriminator regularization.
        self.real_noise = tf.placeholder(
            dtype=tf.float32, shape=[None] + self.config.trainer.image_dims, name="real_noise"
        )
        self.fake_noise = tf.placeholder(
            dtype=tf.float32, shape=[None] + self.config.trainer.image_dims, name="fake_noise"
        )
        self.logger.info("Building training graph...")
        with tf.variable_scope("FAnogan"):
            # Generator and Discriminator Training
            with tf.variable_scope("Generator_Model"):
                self.image_gen = self.generator(self.noise_tensor) + self.fake_noise
            with tf.variable_scope("Discriminator_Model"):
                self.disc_real, self.disc_f_real = self.discriminator(
                    self.image_input + self.real_noise
                )
                self.disc_fake, self.disc_f_fake = self.discriminator(self.image_gen)
            # Encoder Training

            with tf.variable_scope("Encoder_Model"):
                # ZIZ Architecture: encode the generated image back to noise space
                self.encoded_gen_noise = self.encoder(self.image_gen)
                # IZI Architecture: encode the real input image
                self.encoded_img = self.encoder(self.image_input)
            with tf.variable_scope("Generator_Model"):
                self.gen_enc_img = self.generator(self.encoded_img)
            with tf.variable_scope("Discriminator_Model"):
                # IZI Training
                self.disc_real_izi, self.disc_f_real_izi = self.discriminator(self.image_input)
                self.disc_fake_izi, self.disc_f_fake_izi = self.discriminator(self.gen_enc_img)

        with tf.name_scope("Loss_Funcions"):
            with tf.name_scope("Encoder"):
                # ziz: noise -> image -> noise reconstruction error,
                # izi: image -> noise -> image reconstruction error,
                # izi_f: izi plus a discriminator-feature matching term.
                if self.config.trainer.encoder_training_mode == "ziz":
                    self.loss_encoder = tf.reduce_mean(
                        self.mse_loss(
                            self.encoded_gen_noise,
                            self.noise_tensor,
                            mode=self.config.trainer.encoder_loss_mode,
                        )
                        * (1.0 / self.config.trainer.noise_dim)
                    )
                elif self.config.trainer.encoder_training_mode == "izi":
                    self.izi_reconstruction = self.mse_loss(
                        self.image_input,
                        self.gen_enc_img,
                        mode=self.config.trainer.encoder_loss_mode,
                    ) * (
                        1.0
                        / (self.config.data_loader.image_size * self.config.data_loader.image_size)
                    )
                    self.loss_encoder = tf.reduce_mean(self.izi_reconstruction)
                elif self.config.trainer.encoder_training_mode == "izi_f":
                    self.izi_reconstruction = self.mse_loss(
                        self.image_input,
                        self.gen_enc_img,
                        mode=self.config.trainer.encoder_loss_mode,
                    ) * (
                        1.0
                        / (self.config.data_loader.image_size * self.config.data_loader.image_size)
                    )
                    self.izi_disc = self.mse_loss(
                        self.disc_f_real_izi,
                        self.disc_f_fake_izi,
                        mode=self.config.trainer.encoder_loss_mode,
                    ) * (
                        1.0
                        * self.config.trainer.kappa_weight_factor
                        / self.config.trainer.feature_layer_dim
                    )
                    self.loss_encoder = tf.reduce_mean(self.izi_reconstruction + self.izi_disc)
            with tf.name_scope("Discriminator_Generator"):
                if self.config.trainer.mode == "standard":
                    self.loss_disc_real = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=self.true_labels, logits=self.disc_real
                        )
                    )
                    self.loss_disc_fake = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=self.generated_labels, logits=self.disc_fake
                        )
                    )
                    self.loss_discriminator = self.loss_disc_real + self.loss_disc_fake
                    # Flip the weigths for the encoder and generator
                    if self.config.trainer.flip_labels:
                        labels_gen = tf.zeros_like(self.disc_fake)
                    else:
                        labels_gen = tf.ones_like(self.disc_fake)
                    # Generator: cross entropy plus feature matching on the
                    # discriminator's intermediate features.
                    self.loss_generator_ce = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=labels_gen, logits=self.disc_fake
                        )
                    )
                    delta = self.disc_f_fake - self.disc_f_real
                    delta = tf.layers.Flatten()(delta)
                    self.loss_generator_fm = tf.reduce_mean(
                        tf.norm(delta, ord=2, axis=1, keepdims=False)
                    )
                    self.loss_generator = (
                        self.loss_generator_ce
                        + self.config.trainer.feature_match_weight * self.loss_generator_fm
                    )
                elif self.config.trainer.mode == "wgan":
                    self.loss_d_fake = -tf.reduce_mean(self.disc_fake)
                    self.loss_d_real = -tf.reduce_mean(self.disc_real)
                    # Critic loss: E[D(fake)] - E[D(real)]
                    self.loss_discriminator = -self.loss_d_fake + self.loss_d_real
                    self.loss_generator = -tf.reduce_mean(self.disc_fake)

                # Weight Clipping and Encoder Part
                elif self.config.trainer.mode == "wgan_gp":
                    self.loss_generator = -tf.reduce_mean(self.disc_fake)
                    self.loss_d_fake = -tf.reduce_mean(self.disc_fake)
                    self.loss_d_real = -tf.reduce_mean(self.disc_real)
                    # Critic loss: E[D(fake)] - E[D(real)], consistent with the
                    # "wgan" branch above.  (Previously the real term was added
                    # with the wrong sign: -loss_d_fake - loss_d_real expands to
                    # E[D(fake)] + E[D(real)].)
                    self.loss_discriminator = -self.loss_d_fake + self.loss_d_real
                    # Gradient penalty on random interpolates between real and
                    # generated images.
                    alpha_x = tf.random_uniform(
                        shape=[self.config.data_loader.batch_size] + self.config.trainer.image_dims,
                        minval=0.0,
                        maxval=1.0,
                    )
                    differences_x = self.image_gen - self.image_input
                    interpolates_x = self.image_input + (alpha_x * differences_x)
                    gradients = tf.gradients(self.discriminator(interpolates_x), [interpolates_x])[
                        0
                    ]
                    # NOTE(review): reduction_indices=[1] sums over a single
                    # image axis only; the usual WGAN-GP penalty takes the norm
                    # of the per-sample gradient over all non-batch axes
                    # ([1, 2, 3] for images) — confirm intended.
                    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
                    gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2)
                    self.loss_discriminator += self.config.trainer.wgan_gp_lambda * gradient_penalty
        with tf.name_scope("Optimizations"):
            if self.config.trainer.mode == "standard":
                # Build the optimizers
                self.generator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.standard_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
                self.discriminator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.standard_lr_disc,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
                self.encoder_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.standard_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
            elif self.config.trainer.mode == "wgan":
                # Build the optimizers
                self.generator_optimizer = tf.train.RMSPropOptimizer(self.config.trainer.wgan_lr)
                self.discriminator_optimizer = tf.train.RMSPropOptimizer(
                    self.config.trainer.standard_lr_disc
                )
                self.encoder_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.wgan_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
            elif self.config.trainer.mode == "wgan_gp":
                # Build the optimizers (WGAN-GP convention: beta1=0, beta2=0.9)
                self.generator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.wgan_gp_lr, beta1=0.0, beta2=0.9
                )
                self.discriminator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.standard_lr_disc, beta1=0.0, beta2=0.9
                )
                self.encoder_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.wgan_gp_lr, beta1=0.0, beta2=0.9
                )
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables
            self.generator_vars = [
                v for v in all_variables if v.name.startswith("FAnogan/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables if v.name.startswith("FAnogan/Discriminator_Model")
            ]
            if self.config.trainer.mode == "wgan":
                # Enforce the Lipschitz constraint by clipping critic weights.
                clip_ops = []
                for var in self.discriminator_vars:
                    clip_bounds = [-0.01, 0.01]
                    clip_ops.append(
                        tf.assign(var, tf.clip_by_value(var, clip_bounds[0], clip_bounds[1]))
                    )
                self.clip_disc_weights = tf.group(*clip_ops)
            # Encoder Network Variables
            self.encoder_vars = [
                v for v in all_variables if v.name.startswith("FAnogan/Encoder_Model")
            ]
            # Create Training Operations
            # Generator Network Operations
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="FAnogan/Generator_Model"
            )
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="FAnogan/Discriminator_Model"
            )
            # Encoder Network Operations
            self.enc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="FAnogan/Encoder_Model"
            )
            # Initialization of Optimizers
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.loss_generator,
                    var_list=self.generator_vars,
                    global_step=self.global_step_tensor,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars
                )
            with tf.control_dependencies(self.enc_update_ops):
                self.enc_op = self.encoder_optimizer.minimize(
                    self.loss_encoder, var_list=self.encoder_vars
                )
            # Exponential Moving Average for Estimation
            self.dis_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)

            self.enc_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            maintain_averages_op_enc = self.enc_ema.apply(self.encoder_vars)

            # Each train op runs the optimizer step, then refreshes the EMA.
            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

            with tf.control_dependencies([self.enc_op]):
                self.train_enc_op = tf.group(maintain_averages_op_enc)

        self.logger.info("Building Testing Graph...")

        # Testing graph re-uses the trained networks through the EMA getters.
        with tf.variable_scope("FAnogan"):
            # Generator and Discriminator Training
            with tf.variable_scope("Generator_Model"):
                self.image_gen_ema = self.generator(
                    self.noise_tensor, getter=get_getter(self.gen_ema)
                )
            with tf.variable_scope("Discriminator_Model"):
                self.disc_real_ema, self.disc_f_real_ema = self.discriminator(
                    self.image_input, getter=get_getter(self.dis_ema)
                )
                self.disc_fake_ema, self.disc_f_fake_ema = self.discriminator(
                    self.image_gen_ema, getter=get_getter(self.dis_ema)
                )
            # Encoder Training

            with tf.variable_scope("Encoder_Model"):
                # ZIZ Architecture
                self.encoded_gen_noise_ema = self.encoder(
                    self.image_gen_ema, getter=get_getter(self.enc_ema)
                )
                # IZI Architecture
                self.encoded_img_ema = self.encoder(
                    self.image_input, getter=get_getter(self.enc_ema)
                )
            with tf.variable_scope("Generator_Model"):
                self.gen_enc_img_ema = self.generator(
                    self.encoded_img_ema, getter=get_getter(self.gen_ema)
                )
            with tf.variable_scope("Discriminator_Model"):
                # IZI Training
                self.disc_real_izi_ema, self.disc_f_real_izi_ema = self.discriminator(
                    self.image_input, getter=get_getter(self.dis_ema)
                )
                self.disc_fake_izi_ema, self.disc_f_fake_izi_ema = self.discriminator(
                    self.gen_enc_img_ema, getter=get_getter(self.dis_ema)
                )

        with tf.name_scope("Testing"):
            with tf.name_scope("izi_f_loss"):
                # Anomaly score = image reconstruction error + weighted
                # discriminator feature-matching error.
                self.score_reconstruction = self.mse_loss(
                    self.image_input, self.gen_enc_img_ema
                ) * (
                    1.0 / (self.config.data_loader.image_size * self.config.data_loader.image_size)
                )
                self.score_disc = self.mse_loss(
                    self.disc_f_real_izi_ema, self.disc_f_fake_izi_ema
                ) * (
                    1.0
                    * self.config.trainer.kappa_weight_factor
                    / self.config.trainer.feature_layer_dim
                )
                self.izi_f_score = self.score_reconstruction + self.score_disc
            with tf.name_scope("ziz_loss"):
                self.score_reconstruction = self.mse_loss(
                    self.image_input, self.gen_enc_img_ema
                ) * (
                    1.0 / (self.config.data_loader.image_size * self.config.data_loader.image_size)
                )
                self.ziz_score = self.score_reconstruction

        if self.config.trainer.enable_early_stop:
            self.rec_error_valid = tf.reduce_mean(self.izi_f_score)

        if self.config.log.enable_summary:
            with tf.name_scope("Summary"):
                with tf.name_scope("Disc_Summary"):
                    tf.summary.scalar("loss_discriminator", self.loss_discriminator, ["dis"])
                    if self.config.trainer.mode == "standard":
                        tf.summary.scalar("loss_dis_real", self.loss_disc_real, ["dis"])
                        tf.summary.scalar("loss_dis_fake", self.loss_disc_fake, ["dis"])
                with tf.name_scope("Gen_Summary"):
                    tf.summary.scalar("loss_generator", self.loss_generator, ["gen"])
                    if self.config.trainer.mode == "standard":
                        tf.summary.scalar("loss_generator_ce", self.loss_generator_ce, ["gen"])
                        tf.summary.scalar("loss_generator_fm", self.loss_generator_fm, ["gen"])
                    tf.summary.scalar("loss_encoder", self.loss_encoder, ["enc"])
                with tf.name_scope("Image_Summary"):
                    tf.summary.image("reconstruct", self.image_gen, 3, ["image_1"])
                    tf.summary.image("input_images", self.image_input, 3, ["image_1"])
                    tf.summary.image("gen_enc_img", self.gen_enc_img, 3, ["image_2"])
                    tf.summary.image("input_image_2", self.image_input, 3, ["image_2"])
        if self.config.trainer.enable_early_stop:
            with tf.name_scope("validation_summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op_enc = tf.summary.merge_all("enc")
        self.sum_op_im_1 = tf.summary.merge_all("image_1")
        self.sum_op_im_2 = tf.summary.merge_all("image_2")
        self.sum_op_valid = tf.summary.merge_all("v")
Example #4
0
    def build_model(self):
        """Build the GANomaly training and testing graphs.

        Creates the placeholders, the generator (encoder-decoder-encoder) and
        discriminator training sub-graphs, the adversarial / contextual /
        encoder losses, the Adam optimizers with EMA-tracked weights, the
        EMA-based testing graph producing the anomaly score, and the
        Tensorboard summary ops.
        """
        # Kernel initialization for the convolutions
        self.init_kernel = tf.random_normal_initializer(mean=0.0, stddev=0.02)
        # Placeholders
        self.is_training = tf.placeholder(tf.bool)
        self.image_input = tf.placeholder(tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")

        self.true_labels = tf.placeholder(dtype=tf.float32,
                                          shape=[None, 1],
                                          name="true_labels")
        self.generated_labels = tf.placeholder(dtype=tf.float32,
                                               shape=[None, 1],
                                               name="gen_labels")

        self.logger.info("Building training graph...")

        with tf.variable_scope("GANomaly"):
            with tf.variable_scope("Generator_Model"):
                # noise_gen = G_E(x), img_rec = G(x), noise_rec = E(G(x))
                self.noise_gen, self.img_rec, self.noise_rec = self.generator(
                    self.image_input)
            with tf.variable_scope("Discriminator_Model"):
                l_real, inter_layer_inp = self.discriminator(self.image_input)
                l_fake, inter_layer_rct = self.discriminator(self.img_rec)

        with tf.name_scope("Loss_Functions"):
            # Discriminator
            self.loss_dis_real = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.true_labels, logits=l_real))
            self.loss_dis_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.generated_labels, logits=l_fake))
            # Feature matching part
            fm = inter_layer_inp - inter_layer_rct
            fm = tf.layers.Flatten()(fm)
            self.feature_match = tf.reduce_mean(
                tf.norm(fm, ord=2, axis=1, keepdims=False))
            # Add the feature-matching term only in "fm" mode.
            self.loss_discriminator = (self.loss_dis_fake +
                                       self.loss_dis_real + self.feature_match
                                       if self.config.trainer.loss_method
                                       == "fm" else self.loss_dis_fake +
                                       self.loss_dis_real)
            # Generator
            # Adversarial Loss
            if self.config.trainer.flip_labels:
                labels = tf.zeros_like(l_fake)
            else:
                labels = tf.ones_like(l_fake)
            self.gen_loss_ce = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                        logits=l_fake))
            # Contextual Loss: L1 between input and reconstruction
            l1_norm = self.image_input - self.img_rec
            l1_norm = tf.layers.Flatten()(l1_norm)
            self.gen_loss_con = tf.reduce_mean(
                tf.norm(l1_norm, ord=1, axis=1, keepdims=False))
            # Encoder Loss: L2 between the two latent encodings
            l2_norm = self.noise_gen - self.noise_rec
            l2_norm = tf.layers.Flatten()(l2_norm)
            self.gen_loss_enc = tf.reduce_mean(
                tf.norm(l2_norm, ord=2, axis=1, keepdims=False))

            self.gen_loss_total = (
                self.config.trainer.weight_adv * self.gen_loss_ce +
                self.config.trainer.weight_cont * self.gen_loss_con +
                self.config.trainer.weight_enc * self.gen_loss_enc)

        with tf.name_scope("Optimizers"):
            # Build the optimizers
            self.generator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.discriminator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables
            self.generator_vars = [
                v for v in all_variables
                if v.name.startswith("GANomaly/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables
                if v.name.startswith("GANomaly/Discriminator_Model")
            ]
            # Create Training Operations
            # Generator Network Operations
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="GANomaly/Generator_Model")
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="GANomaly/Discriminator_Model")
            # Initialization of Optimizers
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.gen_loss_total,
                    global_step=self.global_step_tensor,
                    var_list=self.generator_vars,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars)

            # Exponential Moving Average for Estimation
            self.dis_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(
                self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)

            # Each train op runs the optimizer step, then refreshes the EMA.
            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

        self.logger.info("Building Testing Graph...")
        # Testing graph re-uses the trained weights via the EMA getters; the
        # scope names must exactly match the training scopes above.
        with tf.variable_scope("GANomaly"):
            with tf.variable_scope("Generator_Model"):
                self.noise_gen_ema, self.img_rec_ema, self.noise_rec_ema = self.generator(
                    self.image_input, getter=get_getter(self.gen_ema))
            # Fixed scope name: was "Discriminator_model" (lowercase "m"),
            # which did not match the training scope "Discriminator_Model" and
            # so created fresh, untrained variables instead of re-using the
            # EMA-tracked ones.
            with tf.variable_scope("Discriminator_Model"):
                self.l_real_ema, self.inter_layer_inp_ema = self.discriminator(
                    self.image_input, getter=get_getter(self.dis_ema))
                self.l_fake_ema, self.inter_layer_rct_ema = self.discriminator(
                    self.img_rec_ema, getter=get_getter(self.dis_ema))

        with tf.name_scope("Testing"):
            with tf.variable_scope("Reconstruction_Loss"):
                # | G_E(x) - E(G(x))|1
                # Difference between the noise generated from the input image and reconstructed noise
                delta = self.noise_gen_ema - self.noise_rec_ema
                delta = tf.layers.Flatten()(delta)
                self.score = tf.norm(delta, ord=1, axis=1, keepdims=False)

        if self.config.trainer.enable_early_stop:
            self.rec_error_valid = tf.reduce_mean(self.score)

        if self.config.log.enable_summary:
            with tf.name_scope("summary"):
                with tf.name_scope("disc_summary"):
                    tf.summary.scalar("loss_discriminator_total",
                                      self.loss_discriminator, ["dis"])
                    tf.summary.scalar("loss_dis_real", self.loss_dis_real,
                                      ["dis"])
                    tf.summary.scalar("loss_dis_fake", self.loss_dis_fake,
                                      ["dis"])
                    # Guard tightened to match the loss construction above
                    # (was a bare truthiness check on loss_method).
                    if self.config.trainer.loss_method == "fm":
                        tf.summary.scalar("loss_dis_fm", self.feature_match,
                                          ["dis"])
                with tf.name_scope("gen_summary"):
                    tf.summary.scalar("loss_generator_total",
                                      self.gen_loss_total, ["gen"])
                    tf.summary.scalar("loss_gen_adv", self.gen_loss_ce,
                                      ["gen"])
                    tf.summary.scalar("loss_gen_con", self.gen_loss_con,
                                      ["gen"])
                    tf.summary.scalar("loss_gen_enc", self.gen_loss_enc,
                                      ["gen"])
                with tf.name_scope("image_summary"):
                    tf.summary.image("reconstruct", self.img_rec, 3, ["image"])
                    tf.summary.image("input_images", self.image_input, 3,
                                     ["image"])
        if self.config.trainer.enable_early_stop:
            with tf.name_scope("validation_summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op_im = tf.summary.merge_all("image")
        self.sum_op_valid = tf.summary.merge_all("v")
Example #5
0
    def build_model(self):
        """Build the full ANOGAN TensorFlow graph.

        Constructs, in order: input placeholders, the generator/discriminator
        forward passes, GAN losses, Adam optimizers with EMA-shadowed weights,
        the test-time latent-inversion variables and scores, and (optionally)
        the TensorBoard summary ops. All ops are added to the default graph;
        nothing is executed here.
        """
        # Placeholders
        # Kernel initializer for all conv/dense layers, selected by config.
        if self.config.trainer.init_type == "normal":
            self.init_kernel = tf.random_normal_initializer(mean=0.0,
                                                            stddev=0.02)
        elif self.config.trainer.init_type == "xavier":
            self.init_kernel = tf.contrib.layers.xavier_initializer(
                uniform=False, seed=None, dtype=tf.float32)
        # Toggles train-mode behavior (e.g. batch norm) at session run time.
        self.is_training = tf.placeholder(tf.bool)
        # Input images, shape [batch] + image_dims.
        self.image_input = tf.placeholder(tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")
        # Latent noise fed to the generator, shape [batch, noise_dim].
        self.noise_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.config.trainer.noise_dim],
            name="noise")
        # Placeholders for the true and fake labels
        # (fed as soft/flipped labels by the trainer).
        self.true_labels = tf.placeholder(dtype=tf.float32, shape=[None, 1])
        self.generated_labels = tf.placeholder(dtype=tf.float32,
                                               shape=[None, 1])
        # Additive instance noise applied to real / generated images
        # (a common GAN stabilization trick; fed by the trainer).
        self.real_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="real_noise")
        self.fake_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="fake_noise")
        # Building the Graph
        self.logger.info("Building Graph")
        with tf.variable_scope("ANOGAN"):
            # Generator: G(z) plus instance noise.
            with tf.variable_scope("Generator_Model"):
                self.img_gen = self.generator(
                    self.noise_tensor) + self.fake_noise
            # Discriminator
            # Applied to noised real images and to generated images.
            # NOTE(review): both calls share one variable scope — presumably
            # self.discriminator handles variable reuse internally; confirm.
            with tf.variable_scope("Discriminator_Model"):
                disc_real, inter_layer_real = self.discriminator(
                    self.image_input + self.real_noise)
                disc_fake, inter_layer_fake = self.discriminator(self.img_gen)

        # Losses of the training of Generator and Discriminator

        with tf.variable_scope("Loss_Functions"):
            # Standard (non-saturating) GAN losses via sigmoid cross-entropy
            # on the discriminator logits.
            with tf.name_scope("Discriminator_Loss"):
                self.disc_loss_real = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=self.true_labels, logits=disc_real))
                self.disc_loss_fake = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=self.generated_labels, logits=disc_fake))
                self.total_disc_loss = self.disc_loss_real + self.disc_loss_fake
            with tf.name_scope("Generator_Loss"):
                # Generator target labels; optionally flipped via config.
                if self.config.trainer.flip_labels:
                    labels = tf.zeros_like(disc_fake)
                else:
                    labels = tf.ones_like(disc_fake)

                self.gen_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                            logits=disc_fake))
        # Build the Optimizers
        with tf.variable_scope("Optimizers"):
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables
            self.generator_vars = [
                v for v in all_variables
                if v.name.startswith("ANOGAN/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables
                if v.name.startswith("ANOGAN/Discriminator_Model")
            ]
            # Create Training Operations
            # Generator Network Operations
            # (UPDATE_OPS collects e.g. batch-norm moving-average updates.)
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="ANOGAN/Generator_Model")
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="ANOGAN/Discriminator_Model")
            # Initialization of Optimizers
            self.generator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.discriminator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Run the collected update ops before each optimizer step.
            with tf.control_dependencies(self.gen_update_ops):
                self.train_gen = self.generator_optimizer.minimize(
                    self.gen_loss,
                    global_step=self.global_step_tensor,
                    var_list=self.generator_vars)
            with tf.control_dependencies(self.disc_update_ops):
                self.train_disc = self.discriminator_optimizer.minimize(
                    self.total_disc_loss, var_list=self.discriminator_vars)

            def train_op_with_ema_dependency(vars, op):
                # Wrap a train op so each step also refreshes the EMA shadow
                # copies of `vars`; the EMA weights are used at test time.
                ema = tf.train.ExponentialMovingAverage(
                    decay=self.config.trainer.ema_decay)
                maintain_averages_op = ema.apply(vars)
                with tf.control_dependencies([op]):
                    train_op = tf.group(maintain_averages_op)
                return train_op, ema

            self.train_gen_op, self.gen_ema = train_op_with_ema_dependency(
                self.generator_vars, self.train_gen)
            self.train_dis_op, self.dis_ema = train_op_with_ema_dependency(
                self.discriminator_vars, self.train_disc)
        # Test-time latent variable: optimized per test batch to invert the
        # generator (find z such that G(z) matches the input image).
        with tf.variable_scope("Latent_variable"):
            self.z_optim = tf.get_variable(
                name="z_optim",
                shape=[
                    self.config.data_loader.test_batch,
                    self.config.trainer.noise_dim
                ],
                initializer=tf.truncated_normal_initializer(),
            )
            # Initializer op kept so z can be re-randomized per test batch.
            reinit_z = self.z_optim.initializer

        # Re-build the networks with EMA-averaged weights (via custom getter)
        # for the test-time passes.
        with tf.variable_scope("ANOGAN"):
            with tf.variable_scope("Generator_Model"):
                self.x_gen_ema = self.generator(self.noise_tensor,
                                                getter=sn.get_getter(
                                                    self.gen_ema))
                self.rec_gen_ema = self.generator(self.z_optim,
                                                  getter=sn.get_getter(
                                                      self.gen_ema))
            # Pass real and fake images into discriminator separately
            with tf.variable_scope("Discriminator_Model"):
                real_d_ema, inter_layer_real_ema = self.discriminator(
                    self.image_input, getter=sn.get_getter(self.dis_ema))
                fake_d_ema, inter_layer_fake_ema = self.discriminator(
                    self.rec_gen_ema, getter=sn.get_getter(self.dis_ema))

        with tf.variable_scope("Testing"):
            # Per-sample L1 and L2 reconstruction residuals between the input
            # and its latent-inversion reconstruction G(z_optim).
            with tf.variable_scope("Reconstruction_Loss"):
                delta = self.image_input - self.rec_gen_ema
                delta_flat = tf.layers.Flatten()(delta)
                # NOTE(review): both norms use name="epsilon"; TF uniquifies
                # duplicate op names, but distinct names would be clearer.
                self.reconstruction_score_1 = tf.norm(
                    delta_flat,
                    ord=1,
                    axis=1,
                    keepdims=False,
                    name="epsilon",
                )
                self.reconstruction_score_2 = tf.norm(
                    delta_flat,
                    ord=2,
                    axis=1,
                    keepdims=False,
                    name="epsilon",
                )

            with tf.variable_scope("Discriminator_Scores"):
                # Anomaly score from the discriminator: either cross-entropy
                # on the fake logits, or feature matching ("fm") on the
                # intermediate-layer difference.
                # NOTE(review): dis_score is unbound if loss_method is
                # neither "c_entropy" nor "fm" — would raise NameError.
                if self.config.trainer.loss_method == "c_entropy":
                    dis_score = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=tf.ones_like(fake_d_ema), logits=fake_d_ema)
                elif self.config.trainer.loss_method == "fm":
                    fm = inter_layer_real_ema - inter_layer_fake_ema
                    fm = tf.layers.Flatten()(fm)
                    dis_score = tf.norm(fm,
                                        ord=self.config.trainer.degree,
                                        axis=1,
                                        keepdims=False,
                                        name="d_loss")

                self.dis_score = tf.squeeze(dis_score)

            with tf.variable_scope("Score"):
                # Final anomaly scores: weighted blend of reconstruction
                # residual (L1 or L2) and discriminator score.
                self.loss_invert_1 = (
                    self.config.trainer.weight * self.reconstruction_score_1 +
                    (1 - self.config.trainer.weight) * self.dis_score)
                self.loss_invert_2 = (
                    self.config.trainer.weight * self.reconstruction_score_2 +
                    (1 - self.config.trainer.weight) * self.dis_score)

        # Validation metric used for early stopping / monitoring.
        self.rec_error_valid = tf.reduce_mean(self.gen_loss)

        # Test-time optimizer state: step counter and re-init ops so the
        # z-inversion can be restarted fresh for every test batch.
        with tf.variable_scope("Test_Learning_Rate"):
            step_lr = tf.Variable(0, trainable=False)
            learning_rate_invert = 0.001
            reinit_lr = tf.variables_initializer(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope="Test_Learning_Rate"))

        with tf.name_scope("Test_Optimizer"):
            # Optimizes only z_optim to minimize the blended anomaly score.
            self.invert_op = tf.train.AdamOptimizer(
                learning_rate_invert).minimize(self.loss_invert_1,
                                               global_step=step_lr,
                                               var_list=[self.z_optim],
                                               name="optimizer")
            reinit_optim = tf.variables_initializer(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope="Test_Optimizer"))

        # Grouped re-initialization of the whole test-time inversion state.
        self.reinit_test_graph_op = [reinit_z, reinit_lr, reinit_optim]

        with tf.name_scope("Scores"):
            self.list_scores = self.loss_invert_1

        # TensorBoard summaries, grouped by collection key so they can be
        # fetched selectively ("dis", "gen", "image", "image_2", "v").
        if self.config.log.enable_summary:
            with tf.name_scope("Training_Summary"):
                with tf.name_scope("Dis_Summary"):
                    tf.summary.scalar("Real_Discriminator_Loss",
                                      self.disc_loss_real, ["dis"])
                    tf.summary.scalar("Fake_Discriminator_Loss",
                                      self.disc_loss_fake, ["dis"])
                    tf.summary.scalar("Discriminator_Loss",
                                      self.total_disc_loss, ["dis"])

                with tf.name_scope("Gen_Summary"):
                    tf.summary.scalar("Loss_Generator", self.gen_loss, ["gen"])

            with tf.name_scope("Img_Summary"):
                heatmap_pl_latent = tf.placeholder(tf.float32,
                                                   shape=(1, 480, 640, 3),
                                                   name="heatmap_pl_latent")
                sum_op_latent = tf.summary.image("heatmap_latent",
                                                 heatmap_pl_latent)

            with tf.name_scope("Validation_Summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

            with tf.name_scope("image_summary"):
                tf.summary.image("reconstruct", self.img_gen, 1, ["image"])
                tf.summary.image("input_images", self.image_input, 1,
                                 ["image"])
                tf.summary.image("reconstruct", self.rec_gen_ema, 1,
                                 ["image_2"])
                tf.summary.image("input_images", self.image_input, 1,
                                 ["image_2"])

            # Merge each collection into a single fetchable summary op.
            self.sum_op_dis = tf.summary.merge_all("dis")
            self.sum_op_gen = tf.summary.merge_all("gen")
            self.sum_op = tf.summary.merge([self.sum_op_dis, self.sum_op_gen])
            self.sum_op_im = tf.summary.merge_all("image")
            self.sum_op_im_test = tf.summary.merge_all("image_2")
            self.sum_op_valid = tf.summary.merge_all("v")
    def build_model(self):
        ############################################################################################
        # INIT
        ############################################################################################
        # Kernel initialization for the convolutions
        if self.config.trainer.init_type == "normal":
            self.init_kernel = tf.random_normal_initializer(mean=0.0, stddev=0.02)
        elif self.config.trainer.init_type == "xavier":
            self.init_kernel = tf.contrib.layers.xavier_initializer(
                uniform=False, seed=None, dtype=tf.float32
            )
        # Placeholders
        self.is_training_gen = tf.placeholder(tf.bool)
        self.is_training_dis = tf.placeholder(tf.bool)
        self.is_training_enc_g = tf.placeholder(tf.bool)
        self.is_training_enc_r = tf.placeholder(tf.bool)
        self.image_input = tf.placeholder(
            tf.float32, shape=[None] + self.config.trainer.image_dims, name="x"
        )
        self.noise_tensor = tf.placeholder(
            tf.float32, shape=[None, self.config.trainer.noise_dim], name="noise"
        )
        self.denoiser_noise = tf.placeholder(
            tf.float32, shape=[None] + self.config.trainer.image_dims, name="noise"
        )
        ############################################################################################
        # MODEL
        ############################################################################################
        self.logger.info("Building training graph...")
        with tf.variable_scope("SENCEBGAN_Denoiser"):
            # First training part
            # G(z) ==> x'
            with tf.variable_scope("Generator_Model"):
                self.image_gen = self.generator(self.noise_tensor)
            # Discriminator outputs
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_real, self.decoded_real = self.discriminator(
                    self.image_input, do_spectral_norm=self.config.trainer.do_spectral_norm
                )
                self.embedding_fake, self.decoded_fake = self.discriminator(
                    self.image_gen, do_spectral_norm=self.config.trainer.do_spectral_norm
                )
            # Second training part
            # E(x) ==> z'
            with tf.variable_scope("Encoder_G_Model"):
                self.image_encoded = self.encoder_g(self.image_input)
            # G(z') ==> G(E(x)) ==> x''
            with tf.variable_scope("Generator_Model"):
                self.image_gen_enc = self.generator(self.image_encoded)
            # Discriminator outputs
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_enc_fake, self.decoded_enc_fake = self.discriminator(
                    self.image_gen_enc, do_spectral_norm=self.config.trainer.do_spectral_norm
                )
                self.embedding_enc_real, self.decoded_enc_real = self.discriminator(
                    self.image_input, do_spectral_norm=self.config.trainer.do_spectral_norm
                )
            with tf.variable_scope("Discriminator_Model_XX"):
                self.im_logit_real, self.im_f_real = self.discriminator_xx(
                    self.image_input,
                    self.image_input,
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
                self.im_logit_fake, self.im_f_fake = self.discriminator_xx(
                    self.image_input,
                    self.image_gen_enc,
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
            # Third training part
            with tf.variable_scope("Encoder_G_Model"):
                self.image_encoded_r = self.encoder_g(self.image_input)

            with tf.variable_scope("Generator_Model"):
                self.image_gen_enc_r = self.generator(self.image_encoded_r)

            with tf.variable_scope("Denoiser_Model"):
                self.mask, self.output = self.denoiser(self.image_gen_enc_r + self.denoiser_noise)

        ############################################################################################
        # LOSS FUNCTIONS
        ############################################################################################
        with tf.name_scope("Loss_Functions"):
            with tf.name_scope("Generator_Discriminator"):
                # Discriminator Loss
                if self.config.trainer.mse_mode == "norm":
                    self.disc_loss_real = tf.reduce_mean(
                        self.mse_loss(
                            self.decoded_real,
                            self.image_input,
                            mode="norm",
                            order=self.config.trainer.order,
                        )
                    )
                    self.disc_loss_fake = tf.reduce_mean(
                        self.mse_loss(
                            self.decoded_fake,
                            self.image_gen,
                            mode="norm",
                            order=self.config.trainer.order,
                        )
                    )
                elif self.config.trainer.mse_mode == "mse":
                    self.disc_loss_real = self.mse_loss(
                        self.decoded_real,
                        self.image_input,
                        mode="mse",
                        order=self.config.trainer.order,
                    )
                    self.disc_loss_fake = self.mse_loss(
                        self.decoded_fake,
                        self.image_gen,
                        mode="mse",
                        order=self.config.trainer.order,
                    )
                self.loss_discriminator = (
                    tf.math.maximum(self.config.trainer.disc_margin - self.disc_loss_fake, 0)
                    + self.disc_loss_real
                )
                # Generator Loss
                pt_loss = 0
                if self.config.trainer.pullaway:
                    pt_loss = self.pullaway_loss(self.embedding_fake)
                self.loss_generator = self.disc_loss_fake + self.config.trainer.pt_weight * pt_loss
                # New addition to enforce visual similarity
                delta_noise = self.embedding_real - self.embedding_fake
                delta_flat = tf.layers.Flatten()(delta_noise)
                loss_noise_gen = tf.reduce_mean(
                    tf.norm(delta_flat, ord=2, axis=1, keepdims=False)
                    )
                self.loss_generator += (0.1 * loss_noise_gen)

            with tf.name_scope("Encoder_G"):
                if self.config.trainer.mse_mode == "norm":
                    self.loss_enc_rec = tf.reduce_mean(
                        self.mse_loss(
                            self.image_gen_enc,
                            self.image_input,
                            mode="norm",
                            order=self.config.trainer.order,
                        )
                    )
                    self.loss_enc_f = tf.reduce_mean(
                        self.mse_loss(
                            self.embedding_enc_real,
                            self.embedding_enc_fake,
                            mode="norm",
                            order=self.config.trainer.order,
                        )
                    )
                elif self.config.trainer.mse_mode == "mse":
                    self.loss_enc_rec = tf.reduce_mean(
                        self.mse_loss(
                            self.image_gen_enc,
                            self.image_input,
                            mode="mse",
                            order=self.config.trainer.order,
                        )
                    )
                    self.loss_enc_f = tf.reduce_mean(
                        self.mse_loss(
                            self.embedding_enc_real,
                            self.embedding_enc_fake,
                            mode="mse",
                            order=self.config.trainer.order,
                        )
                    )
                self.loss_encoder_g = (
                    self.loss_enc_rec + self.config.trainer.encoder_f_factor * self.loss_enc_f
                )
                if self.config.trainer.enable_disc_xx:
                    self.enc_xx_real = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.im_logit_real, labels=tf.zeros_like(self.im_logit_real)
                    )
                    self.enc_xx_fake = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.im_logit_fake, labels=tf.ones_like(self.im_logit_fake)
                    )
                    self.enc_loss_xx = tf.reduce_mean(self.enc_xx_real + self.enc_xx_fake)
                    self.loss_encoder_g += self.enc_loss_xx

            with tf.name_scope("Denoiser"):
                # This part was normally the rec_image which was input but trying sth new.
                delta_den = self.output - self.image_input
                delta_den = tf.layers.Flatten()(delta_den)
                self.den_loss = tf.reduce_mean(
                    tf.norm(
                        delta_den,
                        ord=2,
                        axis=1,
                        keepdims=False,
                    )
                )

            if self.config.trainer.enable_disc_xx:
                with tf.name_scope("Discriminator_XX"):
                    self.loss_xx_real = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.im_logit_real, labels=tf.ones_like(self.im_logit_real)
                    )
                    self.loss_xx_fake = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.im_logit_fake, labels=tf.zeros_like(self.im_logit_fake)
                    )
                    self.dis_loss_xx = tf.reduce_mean(self.loss_xx_real + self.loss_xx_fake)
            if self.config.trainer.enable_disc_zz:
                with tf.name_scope("Discriminator_ZZ"):
                    self.loss_zz_real = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.z_logit_real, labels=tf.ones_like(self.z_logit_real)
                    )
                    self.loss_zz_fake = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.z_logit_fake, labels=tf.zeros_like(self.z_logit_fake)
                    )
                    self.dis_loss_zz = tf.reduce_mean(self.loss_zz_real + self.loss_zz_fake)

        ############################################################################################
        # OPTIMIZERS
        ############################################################################################
        with tf.name_scope("Optimizers"):
            self.generator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr_gen,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.encoder_g_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr_enc,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.denoiser_optimizer = tf.train.AdamOptimizer(
                1e-4,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr_dis,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables
            self.generator_vars = [
                v for v in all_variables if v.name.startswith("SENCEBGAN_Denoiser/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables if v.name.startswith("SENCEBGAN_Denoiser/Discriminator_Model")
            ]
            # Discriminator Network Variables
            self.encoder_g_vars = [
                v for v in all_variables if v.name.startswith("SENCEBGAN_Denoiser/Encoder_G_Model")
            ]
            self.denoiser_vars = [
                v for v in all_variables if v.name.startswith("SENCEBGAN_Denoiser/Denoiser_Model")
            ]
            self.dxxvars = [
                v for v in all_variables if v.name.startswith("SENCEBGAN_Denoiser/Discriminator_Model_XX")
            ]
            self.dzzvars = [
                v for v in all_variables if v.name.startswith("SENCEBGAN_Denoiser/Discriminator_Model_ZZ")
            ]
            # Generator Network Operations
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="SENCEBGAN_Denoiser/Generator_Model"
            )
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="SENCEBGAN_Denoiser/Discriminator_Model"
            )
            self.encg_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="SENCEBGAN_Denoiser/Encoder_G_Model"
            )

            self.den_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="SENCEBGAN_Denoiser/Denoiser_Model"
            )
            self.update_ops_dis_xx = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="SENCEBGAN_Denoiser/Discriminator_Model_XX"
            )
            self.update_ops_dis_zz = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="SENCEBGAN_Denoiser/Discriminator_Model_ZZ"
            )
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.loss_generator,
                    var_list=self.generator_vars,
                    global_step=self.global_step_tensor,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars
                )
            with tf.control_dependencies(self.encg_update_ops):
                self.encg_op = self.encoder_g_optimizer.minimize(
                    self.loss_encoder_g,
                    var_list=self.encoder_g_vars,
                    global_step=self.global_step_tensor,
                )
            with tf.control_dependencies(self.den_update_ops):
                self.den_op = self.denoiser_optimizer.minimize(
                    self.den_loss,
                    var_list=self.denoiser_vars,
                    global_step=self.global_step_tensor,
                )
            if self.config.trainer.enable_disc_xx:
                with tf.control_dependencies(self.update_ops_dis_xx):
                    self.disc_op_xx = self.discriminator_optimizer.minimize(
                        self.dis_loss_xx, var_list=self.dxxvars
                    )
            if self.config.trainer.enable_disc_zz:
                with tf.control_dependencies(self.update_ops_dis_zz):
                    self.disc_op_zz = self.discriminator_optimizer.minimize(
                        self.dis_loss_zz, var_list=self.dzzvars
                    )
            # if self.config.trainer.extra_gan_training:
            #     with tf.control_dependencies(self.gen_update_ops):
            #         self.gen_2_op = self.generator_optimizer.minimize(
            #             self.loss_generator_2, var_list=self.generator_vars
            #         )

            # Exponential Moving Average for Estimation
            self.dis_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)
            # if self.config.trainer.extra_gan_training:
            #     self.gen_2_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            #     maintain_averages_op_gen_2 = self.gen_2_ema.apply(self.generator_vars)

            self.encg_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            maintain_averages_op_encg = self.encg_ema.apply(self.encoder_g_vars)

            self.den_ema = tf.train.ExponentialMovingAverage(decay=self.config.trainer.ema_decay)
            maintain_averages_op_den = self.den_ema.apply(self.denoiser_vars)

            if self.config.trainer.enable_disc_xx:
                self.dis_xx_ema = tf.train.ExponentialMovingAverage(
                    decay=self.config.trainer.ema_decay
                )
                maintain_averages_op_dis_xx = self.dis_xx_ema.apply(self.dxxvars)

            if self.config.trainer.enable_disc_zz:
                self.dis_zz_ema = tf.train.ExponentialMovingAverage(
                    decay=self.config.trainer.ema_decay
                )
                maintain_averages_op_dis_zz = self.dis_zz_ema.apply(self.dzzvars)

            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

            # if self.config.trainer.extra_gan_training:
            #     with tf.control_dependencies([self.gen_2_op]):
            #         self.train_gen_op_2 = tf.group(maintain_averages_op_gen_2)

            with tf.control_dependencies([self.encg_op]):
                self.train_enc_g_op = tf.group(maintain_averages_op_encg)

            with tf.control_dependencies([self.den_op]):
                self.train_den_op = tf.group(maintain_averages_op_den)

            if self.config.trainer.enable_disc_xx:
                with tf.control_dependencies([self.disc_op_xx]):
                    self.train_dis_op_xx = tf.group(maintain_averages_op_dis_xx)

            if self.config.trainer.enable_disc_zz:
                with tf.control_dependencies([self.disc_op_zz]):
                    self.train_dis_op_zz = tf.group(maintain_averages_op_dis_zz)

        ############################################################################################
        # TESTING
        ############################################################################################
        self.logger.info("Building Testing Graph...")
        with tf.variable_scope("SENCEBGAN_Denoiser"):
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_q_ema, self.decoded_q_ema = self.discriminator(
                    self.image_input,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
            with tf.variable_scope("Generator_Model"):
                self.image_gen_ema = self.generator(
                    self.embedding_q_ema, getter=get_getter(self.gen_ema)
                )
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_rec_ema, self.decoded_rec_ema = self.discriminator(
                    self.image_gen_ema,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
            # Second Training Part
            with tf.variable_scope("Encoder_G_Model"):
                self.image_encoded_ema = self.encoder_g(
                    self.image_input, getter=get_getter(self.encg_ema)
                )

            with tf.variable_scope("Generator_Model"):
                self.image_gen_enc_ema = self.generator(
                    self.image_encoded_ema, getter=get_getter(self.gen_ema)
                )
            with tf.variable_scope("Discriminator_Model"):
                self.embedding_enc_fake_ema, self.decoded_enc_fake_ema = self.discriminator(
                    self.image_gen_enc_ema,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
                self.embedding_enc_real_ema, self.decoded_enc_real_ema = self.discriminator(
                    self.image_input,
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
            if self.config.trainer.enable_disc_xx:
                with tf.variable_scope("Discriminator_Model_XX"):
                    self.im_logit_real_ema, self.im_f_real_ema = self.discriminator_xx(
                        self.image_input,
                        self.image_input,
                        getter=get_getter(self.dis_xx_ema),
                        do_spectral_norm=self.config.trainer.do_spectral_norm,
                    )
                    self.im_logit_fake_ema, self.im_f_fake_ema = self.discriminator_xx(
                        self.image_input,
                        self.image_gen_enc_ema,
                        getter=get_getter(self.dis_xx_ema),
                        do_spectral_norm=self.config.trainer.do_spectral_norm,
                    )
            # Third training part
            with tf.variable_scope("Encoder_G_Model"):
                self.image_encoded_r_ema = self.encoder_g(self.image_input)

            with tf.variable_scope("Generator_Model"):
                self.image_gen_enc_r_ema = self.generator(self.image_encoded_r_ema)

            with tf.variable_scope("Denoiser_Model"):
                self.output_ema, self.mask_ema = self.denoiser(self.image_gen_enc_r_ema, getter=get_getter(self.den_ema))

            

        with tf.name_scope("Testing"):
            with tf.name_scope("Image_Based"):
                delta = self.image_input - self.image_gen_enc_ema
                self.mask = -delta
                delta_flat = tf.layers.Flatten()(delta)
                img_score_l1 = tf.norm(
                    delta_flat, ord=1, axis=1, keepdims=False, name="img_loss__1"
                )
                self.img_score_l1 = tf.squeeze(img_score_l1)

                delta = self.embedding_enc_fake_ema - self.embedding_enc_real_ema
                delta_flat = tf.layers.Flatten()(delta)
                img_score_l2 = tf.norm(
                    delta_flat, ord=1, axis=1, keepdims=False, name="img_loss__2"
                )
                self.img_score_l2 = tf.squeeze(img_score_l2)
                self.score_comb = (
                    (1 - self.config.trainer.feature_match_weight) * self.img_score_l1
                    + self.config.trainer.feature_match_weight * self.img_score_l2
                )
                with tf.variable_scope("Pipeline_Loss_1"):
                    delta_pipe = self.output_ema - self.image_input
                    delta_pipe = tf.layers.Flatten()(delta_pipe)
                    self.pipe_score = tf.norm(delta_pipe, ord=1,axis=1,keepdims=False)
                with tf.variable_scope("Pipeline_Loss_2"):
                    delta_pipe = self.output_ema - self.image_input
                    delta_pipe = tf.layers.Flatten()(delta_pipe)
                    self.pipe_score_2 = tf.norm(delta_pipe, ord=2,axis=1,keepdims=False)
            with tf.name_scope("Noise_Based"):
                with tf.variable_scope("Mask_1"):
                    delta_mask = (self.image_input - self.mask_ema) 
                    delta_mask = tf.layers.Flatten()(delta_mask)
                    self.mask_score_1 = tf.norm(delta_mask, ord=1,axis=1,keepdims=False)
                with tf.variable_scope("Mask_2"):
                    delta_mask_2 = (self.image_input - self.mask_ema) 
                    delta_mask_2 = tf.layers.Flatten()(delta_mask_2)
                    self.mask_score_2 = tf.norm(delta_mask_2, ord=2,axis=1,keepdims=False)

        ############################################################################################
        # TENSORBOARD
        ############################################################################################
        if self.config.log.enable_summary:
            with tf.name_scope("train_summary"):
                with tf.name_scope("dis_summary"):
                    tf.summary.scalar("loss_disc", self.loss_discriminator, ["dis"])
                    tf.summary.scalar("loss_disc_real", self.disc_loss_real, ["dis"])
                    tf.summary.scalar("loss_disc_fake", self.disc_loss_fake, ["dis"])
                    if self.config.trainer.enable_disc_xx:
                        tf.summary.scalar("loss_dis_xx", self.dis_loss_xx, ["enc_g"])
                    if self.config.trainer.enable_disc_zz:
                        tf.summary.scalar("loss_dis_zz", self.dis_loss_zz, ["den"])
                with tf.name_scope("gen_summary"):
                    tf.summary.scalar("loss_generator", self.loss_generator, ["gen"])
                with tf.name_scope("enc_summary"):
                    tf.summary.scalar("loss_encoder_g", self.loss_encoder_g, ["enc_g"])
                    tf.summary.scalar("loss_den", self.den_loss, ["den"])
                with tf.name_scope("img_summary"):
                    tf.summary.image("input_image", self.image_input, 1, ["img_1"])
                    tf.summary.image("reconstructed", self.image_gen, 1, ["img_1"])
                    tf.summary.image("input_enc", self.image_input, 1, ["img_2"])
                    tf.summary.image("reconstructed", self.image_gen_enc, 1, ["img_2"])
                    tf.summary.image("input_image",self.image_input,1,["test"])
                    tf.summary.image("reconstructed", self.image_gen_enc_r_ema,1,["test"])
                    tf.summary.image("mask", self.mask_ema, 1, ["test"])
                    tf.summary.image("output", self.output_ema, 1, ["test"])


            self.sum_op_dis = tf.summary.merge_all("dis")
            self.sum_op_gen = tf.summary.merge_all("gen")
            self.sum_op_enc_g = tf.summary.merge_all("enc_g")
            self.sum_op_den = tf.summary.merge_all("den")
            self.sum_op_im_1 = tf.summary.merge_all("img_1")
            self.sum_op_im_2 = tf.summary.merge_all("img_2")
            self.sum_op_test = tf.summary.merge_all("test")
            self.sum_op = tf.summary.merge([self.sum_op_dis, self.sum_op_gen])
# Example #7 (0)
    def build_model(self):
        """Build the Skip-GANomaly graph (TF1 style).

        Constructs, in order: input placeholders, the training graph
        (generator reconstruction + discriminator on real/fake pairs),
        discriminator and generator losses, Adam optimizers with
        exponential-moving-average (EMA) shadow variables, an EMA-based
        testing graph, per-sample anomaly scores, and TensorBoard
        summaries.
        """
        # Place holders
        self.img_size = self.config.data_loader.image_size
        self.is_training = tf.placeholder(tf.bool)
        # Input images: [batch] + image_dims (presumably [H, W, C] —
        # TODO confirm against the data loader).
        self.image_input = tf.placeholder(dtype=tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")
        # Weight initializer selected from config.
        # NOTE(review): if init_type is neither "normal" nor "xavier",
        # self.init_kernel is never assigned here — confirm downstream
        # code tolerates that.
        if self.config.trainer.init_type == "normal":
            self.init_kernel = tf.random_normal_initializer(mean=0.0,
                                                            stddev=0.02)
        elif self.config.trainer.init_type == "xavier":
            self.init_kernel = tf.contrib.layers.xavier_initializer(
                uniform=False, seed=None, dtype=tf.float32)
        # Discriminator targets for real/generated batches; the actual
        # 0/1 (or smoothed/flipped) values are supplied at feed time.
        self.true_labels = tf.placeholder(dtype=tf.float32,
                                          shape=[None, 1],
                                          name="true_labels")
        self.generated_labels = tf.placeholder(dtype=tf.float32,
                                               shape=[None, 1],
                                               name="gen_labels")
        # Additive instance noise applied to the discriminator inputs.
        self.real_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="real_noise")
        self.fake_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="fake_noise")
        #######################################################################
        # GRAPH
        ########################################################################
        self.logger.info("Building Training Graph")

        with tf.variable_scope("Skip_GANomaly"):
            with tf.variable_scope("Generator_Model"):
                # Reconstruction of the input, perturbed with fake noise.
                self.img_rec = self.generator(self.image_input)
                self.img_rec += self.fake_noise
            with tf.variable_scope("Discriminator_Model"):
                # Real branch sees the noisy input, fake branch sees the
                # reconstruction; each returns (logits, intermediate
                # feature layer).
                self.disc_real, self.inter_layer_real = self.discriminator(
                    self.image_input + self.real_noise)
                self.disc_fake, self.inter_layer_fake = self.discriminator(
                    self.img_rec)
        ########################################################################
        # METRICS
        ########################################################################
        with tf.variable_scope("Loss_Functions"):
            with tf.variable_scope("Discriminator_Loss"):
                # According to the paper we invert the values for the normal/fake.
                # So normal images should be labeled
                # as zeros
                self.loss_dis_real = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=self.true_labels, logits=self.disc_real))
                self.loss_dis_fake = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=self.generated_labels, logits=self.disc_fake))
                # Adversarial Loss Part for Discriminator
                self.loss_discriminator = self.loss_dis_real + self.loss_dis_fake

            with tf.variable_scope("Generator_Loss"):
                # Adversarial Loss: generator tries to make reconstructions
                # look real to the discriminator (target inverted when
                # flip_labels is set).
                if self.config.trainer.flip_labels:
                    labels = tf.zeros_like(self.disc_fake)
                else:
                    labels = tf.ones_like(self.disc_fake)
                self.gen_adv_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=self.disc_fake, labels=labels))
                # Contextual Loss: L1 distance between input and its
                # reconstruction (norm taken along axis 1 of the raw
                # un-flattened tensor, then averaged).
                context_layers = self.image_input - self.img_rec
                self.contextual_loss = tf.reduce_mean(
                    tf.norm(context_layers,
                            ord=1,
                            axis=1,
                            keepdims=False,
                            name="Contextual_Loss"))
                # Latent Loss: L2 distance between the discriminator's
                # intermediate features for real vs. reconstructed input
                # (feature matching).
                layer_diff = self.inter_layer_real - self.inter_layer_fake
                self.latent_loss = tf.reduce_mean(
                    tf.norm(layer_diff,
                            ord=2,
                            axis=1,
                            keepdims=False,
                            name="Latent_Loss"))
                # Weighted sum of the three generator objectives.
                self.gen_loss_total = (
                    self.config.trainer.weight_adv * self.gen_adv_loss +
                    self.config.trainer.weight_cont * self.contextual_loss +
                    self.config.trainer.weight_lat * self.latent_loss)

        ########################################################################
        # OPTIMIZATION
        ########################################################################
        # Build the Optimizers
        with tf.name_scope("Optimization"):
            self.generator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.discriminator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables (selected by variable-scope prefix)
            self.generator_vars = [
                v for v in all_variables
                if v.name.startswith("Skip_GANomaly/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables
                if v.name.startswith("Skip_GANomaly/Discriminator_Model")
            ]
            # Generator Network Operations (UPDATE_OPS, e.g. batch-norm
            # moving statistics, run before each optimizer step below)
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="Skip_GANomaly/Generator_Model")
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS,
                scope="Skip_GANomaly/Discriminator_Model")
            # Initialization of Optimizers
            # Only the generator step advances global_step_tensor; the
            # discriminator step passes no global_step.
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.gen_loss_total,
                    global_step=self.global_step_tensor,
                    var_list=self.generator_vars,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars)
            # Exponential Moving Average for Estimation
            # EMA shadow variables are read by the testing graph below
            # through get_getter(...) for smoother inference weights.
            self.dis_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(
                self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)

            # Public train ops: optimizer step first, then EMA update.
            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

        ########################################################################
        # TESTING
        ########################################################################
        self.logger.info("Building Testing Graph...")

        # Re-enter the model scope so the same weights are used, read
        # through their EMA shadows (presumably the sub-networks reuse
        # variables internally, e.g. via AUTO_REUSE — TODO confirm).
        with tf.variable_scope("Skip_GANomaly"):
            with tf.variable_scope("Generator_Model"):
                self.img_rec_ema = self.generator(self.image_input,
                                                  getter=get_getter(
                                                      self.gen_ema))
            with tf.variable_scope("Discriminator_Model"):
                self.disc_real_ema, self.inter_layer_real_ema = self.discriminator(
                    self.image_input, getter=get_getter(self.dis_ema))
                self.disc_fake_ema, self.inter_layer_fake_ema = self.discriminator(
                    self.img_rec_ema, getter=get_getter(self.dis_ema))

        with tf.name_scope("Testing"):
            with tf.variable_scope("Reconstruction_Loss"):
                # Contextual Loss: per-sample L1/L2 reconstruction error
                # over the flattened image.
                context_layers = self.image_input - self.img_rec_ema
                context_layers = tf.layers.Flatten()(context_layers)
                contextual_loss_ema_1 = tf.norm(context_layers,
                                                ord=1,
                                                axis=1,
                                                keepdims=False,
                                                name="Contextual_Loss")
                self.contextual_loss_ema_1 = tf.squeeze(contextual_loss_ema_1)

                # Same op name as above; TF uniquifies duplicate names
                # automatically.
                contextual_loss_ema_2 = tf.norm(context_layers,
                                                ord=2,
                                                axis=1,
                                                keepdims=False,
                                                name="Contextual_Loss")
                self.contextual_loss_ema_2 = tf.squeeze(contextual_loss_ema_2)

            with tf.variable_scope("Latent_Loss"):
                # Latent Loss
                # NOTE(review): axis=None makes tf.norm reduce over the
                # WHOLE flattened batch, yielding a single scalar that is
                # broadcast into every per-sample anomaly score below —
                # unlike the per-sample (axis=1) contextual losses above.
                # Confirm this is intended rather than axis=1.
                layer_diff = self.inter_layer_real_ema - self.inter_layer_fake_ema
                layer_diff = tf.layers.Flatten()(layer_diff)
                latent_loss_ema = tf.norm(layer_diff,
                                          ord=2,
                                          axis=None,
                                          keepdims=False,
                                          name="Latent_Loss")
                self.latent_loss_ema = tf.squeeze(latent_loss_ema)

            # Convex combinations of reconstruction and latent scores
            # (weight in [0, 1] presumably — TODO confirm in config).
            self.anomaly_score_1 = (
                self.config.trainer.weight * self.contextual_loss_ema_1 +
                (1 - self.config.trainer.weight) * self.latent_loss_ema)
            self.anomaly_score_2 = (
                self.config.trainer.weight * self.contextual_loss_ema_2 +
                (1 - self.config.trainer.weight) * self.latent_loss_ema)

        if self.config.trainer.enable_early_stop:
            # Validation metric monitored for early stopping.
            self.rec_error_valid = tf.reduce_mean(self.latent_loss_ema)

        ########################################################################
        # TENSORBOARD
        ########################################################################
        if self.config.log.enable_summary:
            with tf.name_scope("summary"):
                with tf.name_scope("disc_summary"):
                    tf.summary.scalar("loss_discriminator_total",
                                      self.loss_discriminator, ["dis"])
                    tf.summary.scalar("loss_dis_real", self.loss_dis_real,
                                      ["dis"])
                    tf.summary.scalar("loss_dis_fake", self.loss_dis_fake,
                                      ["dis"])
                with tf.name_scope("gen_summary"):
                    tf.summary.scalar("loss_generator_total",
                                      self.gen_loss_total, ["gen"])
                    tf.summary.scalar("loss_gen_adv", self.gen_adv_loss,
                                      ["gen"])
                    tf.summary.scalar("loss_gen_con", self.contextual_loss,
                                      ["gen"])
                    tf.summary.scalar("loss_gen_enc", self.latent_loss,
                                      ["gen"])
                with tf.name_scope("image_summary"):
                    tf.summary.image("reconstruct", self.img_rec, 1, ["image"])
                    tf.summary.image("input_images", self.image_input, 1,
                                     ["image"])
                    tf.summary.image("reconstruct", self.img_rec_ema, 1,
                                     ["image_2"])
                    tf.summary.image("input_images", self.image_input, 1,
                                     ["image_2"])

        if self.config.trainer.enable_early_stop:
            with tf.name_scope("validation_summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

        # Merge the per-collection summary ops.
        # NOTE(review): tf.summary.merge_all returns None for an empty
        # collection (e.g. "v" when early stopping is disabled, or all
        # of them when enable_summary is False) — confirm callers guard
        # against None before session.run.
        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op_im = tf.summary.merge_all("image")
        self.sum_op_im_test = tf.summary.merge_all("image_2")
        self.sum_op_valid = tf.summary.merge_all("v")
# Example #8 (0)
    def build_model(self):
        # Kernel initialization for the convolutions
        self.init_kernel = tf.random_normal_initializer(mean=0.0, stddev=0.02)
        # Placeholders
        self.is_training = tf.placeholder(tf.bool)
        self.image_input = tf.placeholder(tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")
        self.noise_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.config.trainer.noise_dim],
            name="noise")
        self.true_labels = tf.placeholder(dtype=tf.float32,
                                          shape=[None, 1],
                                          name="true_labels")
        self.generated_labels = tf.placeholder(dtype=tf.float32,
                                               shape=[None, 1],
                                               name="gen_labels")
        self.real_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="real_noise")
        self.fake_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="fake_noise")

        self.logger.info("Building training graph...")
        with tf.variable_scope("BIGAN"):
            with tf.variable_scope("Encoder_Model"):
                self.noise_gen = self.encoder(
                    self.image_input,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)

            with tf.variable_scope("Generator_Model"):
                self.image_gen = self.generator(
                    self.noise_tensor) + self.fake_noise
                self.reconstructed = self.generator(self.noise_gen)

            with tf.variable_scope("Discriminator_Model"):
                # E(x) and x --> This being real is the output of discriminator
                l_encoder, inter_layer_inp = self.discriminator(
                    self.noise_gen,
                    self.image_input + self.real_noise,
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
                # z and G(z)
                l_generator, inter_layer_rct = self.discriminator(
                    self.noise_tensor,
                    self.image_gen,
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )

        # Loss Function Implementations
        with tf.name_scope("Loss_Functions"):
            # Discriminator
            # Discriminator sees the encoder result as true because it discriminates E(x), x as the real pair
            if self.config.trainer.mode == "standard":
                self.loss_dis_enc = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=self.true_labels, logits=l_encoder))
                self.loss_dis_gen = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=self.generated_labels, logits=l_generator))
                self.loss_discriminator = self.loss_dis_enc + self.loss_dis_gen
                # Flip the weigths for the encoder and generator
                if self.config.trainer.flip_labels:
                    labels_gen = tf.zeros_like(l_generator)
                    labels_enc = tf.ones_like(l_encoder)
                else:
                    labels_gen = tf.ones_like(l_generator)
                    labels_enc = tf.zeros_like(l_encoder)
                # Generator
                # Generator is considered as the true ones here because it tries to fool discriminator
                self.loss_generator_ce = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=labels_gen, logits=l_generator))
                delta = inter_layer_inp - inter_layer_rct
                delta = tf.layers.Flatten()(delta)
                self.loss_generator_fm = tf.reduce_mean(
                    tf.norm(delta, ord=2, axis=1, keepdims=False))
                self.loss_generator = (
                    self.loss_generator_ce +
                    self.config.trainer.feature_match_weight *
                    self.loss_generator_fm)
                # Encoder
                # Encoder is considered as the fake one because it tries to fool the discriminator also
                self.loss_encoder = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_enc,
                                                            logits=l_encoder))

            elif self.config.trainer.mode == "wgan":
                self.loss_generator = -tf.reduce_mean(l_generator)
                self.loss_encoder = -tf.reduce_mean(l_encoder)
                self.loss_discriminator = tf.reduce_mean(
                    l_generator) - tf.reduce_mean(l_encoder)
            elif self.config.trainer.mode == "wgan-gp":
                self.loss_generator = -tf.reduce_mean(l_generator)
                self.loss_encoder = -tf.reduce_mean(l_encoder)
                self.loss_discriminator = tf.reduce_mean(
                    l_generator) - tf.reduce_mean(l_encoder)

                alpha_x = tf.random_uniform(
                    shape=[self.config.data_loader.batch_size] +
                    self.config.trainer.image_dims,
                    minval=0.0,
                    maxval=1.0,
                )
                alpha_z = tf.random_uniform(
                    shape=[
                        self.config.data_loader.batch_size,
                        self.config.trainer.noise_dim
                    ],
                    minval=0.0,
                    maxval=1.0,
                )
                differences_x = self.image_gen - self.image_input
                interpolates_x = self.image_input + (alpha_x * differences_x)
                differences_z = self.noise_gen - self.noise_tensor
                interpolates_z = self.noise_tensor + (alpha_z * differences_z)
                gradients = tf.gradients(
                    self.discriminator(interpolates_z, interpolates_x),
                    [interpolates_z, interpolates_x],
                )[0]
                slopes = tf.sqrt(
                    tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
                gradient_penalty = tf.reduce_mean((slopes - 1.0)**2)
                self.loss_discriminator += self.config.trainer.wgan_gp_lambda * gradient_penalty

        # Optimizer Implementations
        with tf.name_scope("Optimizers"):
            if self.config.trainer.mode == "standard":
                # Build the optimizers
                self.generator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.standard_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
                self.discriminator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.standard_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
                self.encoder_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.standard_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
            elif self.config.trainer.mode == "wgan":
                # Build the optimizers
                self.generator_optimizer = tf.train.RMSPropOptimizer(
                    self.config.trainer.wgan_lr)
                self.discriminator_optimizer = tf.train.RMSPropOptimizer(
                    self.config.trainer.wgan_lr)
                self.encoder_optimizer = tf.train.RMSPropOptimizer(
                    self.config.trainer.wgan_lr)
            elif self.config.trainer.mode == "wgan-gp":
                # Build the optimizers
                self.generator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.wgan_gp_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
                self.discriminator_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.wgan_gp_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )
                self.encoder_optimizer = tf.train.AdamOptimizer(
                    self.config.trainer.wgan_gp_lr,
                    beta1=self.config.trainer.optimizer_adam_beta1,
                    beta2=self.config.trainer.optimizer_adam_beta2,
                )

            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables
            self.generator_vars = [
                v for v in all_variables
                if v.name.startswith("BIGAN/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables
                if v.name.startswith("BIGAN/Discriminator_Model")
            ]
            if self.config.trainer.mode == "wgan":
                clip_ops = []
                for var in self.discriminator_vars:
                    clip_bounds = [-0.01, 0.01]
                    clip_ops.append(
                        tf.assign(
                            var,
                            tf.clip_by_value(var, clip_bounds[0],
                                             clip_bounds[1])))
                self.clip_disc_weights = tf.group(*clip_ops)
            # Encoder Network Variables
            self.encoder_vars = [
                v for v in all_variables
                if v.name.startswith("BIGAN/Encoder_Model")
            ]
            # Create Training Operations
            # Generator Network Operations
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="BIGAN/Generator_Model")
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="BIGAN/Discriminator_Model")
            # Encoder Network Operations
            self.enc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="BIGAN/Encoder_Model")
            # Initialization of Optimizers
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.loss_generator,
                    var_list=self.generator_vars,
                    global_step=self.global_step_tensor,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars)
            with tf.control_dependencies(self.enc_update_ops):
                self.enc_op = self.encoder_optimizer.minimize(
                    self.loss_encoder, var_list=self.encoder_vars)
            # Exponential Moving Average for Estimation
            self.dis_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(
                self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)

            self.enc_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_enc = self.enc_ema.apply(self.encoder_vars)

            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

            with tf.control_dependencies([self.enc_op]):
                self.train_enc_op = tf.group(maintain_averages_op_enc)

        self.logger.info("Building Testing Graph...")

        with tf.variable_scope("BIGAN"):
            with tf.variable_scope("Encoder_Model"):
                self.noise_gen_ema = self.encoder(
                    self.image_input,
                    getter=get_getter(self.enc_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
            with tf.variable_scope("Generator_Model"):
                self.reconstruct_ema = self.generator(self.noise_gen_ema,
                                                      getter=get_getter(
                                                          self.gen_ema))
            with tf.variable_scope("Discriminator_Model"):
                self.l_encoder_ema, self.inter_layer_inp_ema = self.discriminator(
                    self.noise_gen_ema,  # E(x)
                    self.image_input,  # x
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )
                self.l_generator_ema, self.inter_layer_rct_ema = self.discriminator(
                    self.noise_gen_ema,  # E(x)
                    self.reconstruct_ema,  # G(E(x))
                    getter=get_getter(self.dis_ema),
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                )

        with tf.name_scope("Testing"):
            with tf.variable_scope("Reconstruction_Loss"):
                # LG(x) = ||x - G(E(x))||_1
                delta = self.image_input - self.reconstruct_ema
                delta_flat = tf.layers.Flatten()(delta)
                self.gen_score = tf.norm(
                    delta_flat,
                    ord=self.config.trainer.degree,
                    axis=1,
                    keepdims=False,
                    name="epsilon",
                )
            with tf.variable_scope("Discriminator_Loss"):
                if self.config.trainer.loss_method == "cross_e":
                    self.dis_score = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=tf.ones_like(self.l_encoder_ema),
                        logits=self.l_encoder_ema)
                elif self.config.trainer.loss_method == "fm":
                    fm = self.inter_layer_inp_ema - self.inter_layer_rct_ema
                    fm = tf.layers.Flatten()(fm)
                    self.dis_score = tf.norm(fm,
                                             ord=self.config.trainer.degree,
                                             axis=1,
                                             keepdims=False,
                                             name="d_loss")
                self.dis_score = tf.squeeze(self.dis_score)
            with tf.variable_scope("Score"):
                self.list_scores = (
                    1 - self.config.trainer.weight
                ) * self.dis_score + self.config.trainer.weight * self.gen_score

        if self.config.trainer.enable_early_stop:
            self.rec_error_valid = tf.reduce_mean(self.list_scores)

        if self.config.log.enable_summary:
            with tf.name_scope("Summary"):
                with tf.name_scope("Disc_Summary"):
                    tf.summary.scalar("loss_discriminator",
                                      self.loss_discriminator, ["dis"])
                    if self.config.trainer.mode == "standard":
                        tf.summary.scalar("loss_dis_encoder",
                                          self.loss_dis_enc, ["dis"])
                        tf.summary.scalar("loss_dis_gen", self.loss_dis_gen,
                                          ["dis"])
                with tf.name_scope("Gen_Summary"):
                    tf.summary.scalar("loss_generator", self.loss_generator,
                                      ["gen"])
                    if self.config.trainer.mode == "standard":
                        tf.summary.scalar("loss_generator_ce",
                                          self.loss_generator_ce, ["gen"])
                        tf.summary.scalar("loss_generator_fm",
                                          self.loss_generator_fm, ["gen"])
                    tf.summary.scalar("loss_encoder", self.loss_encoder,
                                      ["gen"])
                with tf.name_scope("Image_Summary"):
                    tf.summary.image("reconstruct", self.reconstructed, 3,
                                     ["image"])
                    tf.summary.image("input_images", self.image_input, 3,
                                     ["image"])
        if self.config.trainer.enable_early_stop:
            with tf.name_scope("validation_summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op_im = tf.summary.merge_all("image")
        self.sum_op_valid = tf.summary.merge_all("v")
# ---- Example #9 (scraped-example separator; original marker: "예제 #9", vote count 0) ----
    def build_model(self):
        """Build the FenceGAN training and testing graphs.

        Constructs, in order: input placeholders, the generator/discriminator
        training graph under variable scope ``FenceGAN``, the (currently empty)
        loss section, Adam optimizers with EMA shadow variables, an EMA-based
        testing graph, per-sample anomaly scores, and TensorBoard summaries.

        NOTE(review): the ``Loss_Functions`` section below is only ``pass``,
        yet ``self.loss_generator`` and ``self.loss_discriminator`` (and the
        summary-only ``self.disc_loss_real`` / ``self.disc_loss_fake``) are
        consumed later in this method — as written, calling this method will
        fail with an AttributeError until the losses are implemented.
        """
        # Initializations
        # Kernel initialization for the convolutions
        self.init_kernel = tf.contrib.layers.xavier_initializer(
            uniform=False, seed=None, dtype=tf.float32)
        # Placeholders
        # is_training: toggles train/inference behavior (e.g. batch norm) in the sub-networks.
        self.is_training = tf.placeholder(tf.bool)
        # x: batch of input images, shape [None] + config image_dims.
        self.image_input = tf.placeholder(tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")
        # noise: latent vectors z fed to the generator, shape [None, noise_dim].
        self.noise_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.config.trainer.noise_dim],
            name="noise")
        # Build Training Graph
        self.logger.info("Building training graph...")
        with tf.variable_scope("FenceGAN"):
            # G(z): generated images from the latent noise.
            with tf.variable_scope("Generator_Model"):
                self.image_gen = self.generator(self.noise_tensor)

            # D(x) on real images and D(G(z)) on generated images; each call
            # also returns an intermediate feature layer (f_layer_*).
            with tf.variable_scope("Discriminator_Model"):
                self.disc_real, self.f_layer_real = self.discriminator(
                    self.image_input,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)
                self.disc_fake, self.f_layer_fake = self.discriminator(
                    self.image_gen,
                    do_spectral_norm=self.config.trainer.do_spectral_norm)
        # Loss functions
        # NOTE(review): unimplemented — see docstring. The attributes consumed
        # below (loss_generator, loss_discriminator, disc_loss_real,
        # disc_loss_fake) are expected to be defined here.
        with tf.name_scope("Loss_Functions"):
            # Discriminator Loss
            # Generator Loss
            pass

        # Optimizers
        with tf.name_scope("Optimizers"):
            # Separate Adam optimizers for G and D, sharing the same base
            # learning rate and beta parameters from the config.
            self.generator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.standard_lr,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables (selected by variable-scope prefix).
            self.generator_vars = [
                v for v in all_variables
                if v.name.startswith("FenceGAN/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables
                if v.name.startswith("FenceGAN/Discriminator_Model")
            ]
            # Generator Network Operations (UPDATE_OPS, e.g. batch-norm moving stats).
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="FenceGAN/Generator_Model")
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="FenceGAN/Discriminator_Model")
            # Each minimize step runs only after its network's update ops;
            # only the generator step advances the global step counter.
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.loss_generator,
                    var_list=self.generator_vars,
                    global_step=self.global_step_tensor,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars)
            # Exponential Moving Average for Estimation: shadow copies of the
            # D and G weights, updated after every corresponding train step
            # and read back at test time via get_getter(...).
            self.dis_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(
                self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)

            # train_*_op = optimizer step followed by the EMA update.
            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

        # Build Test Graph
        # Re-enters the "FenceGAN" scope so the EMA custom getters resolve the
        # shadow variables of the training networks (variable reuse is
        # presumably handled inside generator/discriminator — TODO confirm).
        self.logger.info("Building Testing Graph...")
        with tf.variable_scope("FenceGAN"):
            with tf.variable_scope("Generator_Model"):
                self.image_gen_ema = self.generator(self.noise_tensor,
                                                    getter=get_getter(
                                                        self.gen_ema))

            with tf.variable_scope("Discriminator_Model"):
                self.disc_real_ema, self.f_layer_real_ema = self.discriminator(
                    self.image_input,
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                    getter=get_getter(self.dis_ema),
                )
                # NOTE(review): this feeds self.image_gen (the TRAINING
                # generator output), not self.image_gen_ema — confirm whether
                # the EMA generator output was intended here.
                self.disc_fake_ema, self.f_layer_fake_ema = self.discriminator(
                    self.image_gen,
                    do_spectral_norm=self.config.trainer.do_spectral_norm,
                    getter=get_getter(self.dis_ema),
                )
        with tf.name_scope("Testing"):
            # Per-sample anomaly scores (squeezed to drop singleton dims).
            with tf.name_scope("Image_Based"):
                # L1 distance between input and EMA-generated image.
                delta = self.image_input - self.image_gen_ema
                delta_flat = tf.layers.Flatten()(delta)
                img_score_l1 = tf.norm(delta_flat,
                                       ord=1,
                                       axis=1,
                                       keepdims=False,
                                       name="img_loss__1")
                self.img_score_l1 = tf.squeeze(img_score_l1)

                # L2 distance over the same pixel residual.
                delta = self.image_input - self.image_gen_ema
                delta_flat = tf.layers.Flatten()(delta)
                img_score_l2 = tf.norm(delta_flat,
                                       ord=2,
                                       axis=1,
                                       keepdims=False,
                                       name="img_loss__2")
                self.img_score_l2 = tf.squeeze(img_score_l2)
            with tf.name_scope("Noise_Based"):
                # NOTE(review): self.embedding_rec_ema / self.embedding_q_ema
                # are never assigned in this method — they must be produced
                # elsewhere (or this section is dead); verify before use.
                delta = self.embedding_rec_ema - self.embedding_q_ema
                delta_flat = tf.layers.Flatten()(delta)
                z_score_l1 = tf.norm(delta_flat,
                                     ord=1,
                                     axis=1,
                                     keepdims=False,
                                     name="z_loss_1")
                self.z_score_l1 = tf.squeeze(z_score_l1)

                delta = self.embedding_rec_ema - self.embedding_q_ema
                delta_flat = tf.layers.Flatten()(delta)
                z_score_l2 = tf.norm(delta_flat,
                                     ord=2,
                                     axis=1,
                                     keepdims=False,
                                     name="z_loss_2")
                self.z_score_l2 = tf.squeeze(z_score_l2)

        # Tensorboard
        # Summaries are tagged into collections ("dis", "gen", "img") and
        # merged per collection below.
        if self.config.log.enable_summary:
            with tf.name_scope("train_summary"):
                with tf.name_scope("dis_summary"):
                    tf.summary.scalar("loss_disc", self.loss_discriminator,
                                      ["dis"])
                    tf.summary.scalar("loss_disc_real", self.disc_loss_real,
                                      ["dis"])
                    tf.summary.scalar("loss_disc_fake", self.disc_loss_fake,
                                      ["dis"])
                with tf.name_scope("gen_summary"):
                    tf.summary.scalar("loss_generator", self.loss_generator,
                                      ["gen"])
                with tf.name_scope("img_summary"):
                    tf.summary.image("input_image", self.image_input, 3,
                                     ["img"])
                    tf.summary.image("reconstructed", self.image_gen, 3,
                                     ["img"])

        # Merged summary ops; merge_all returns None for a collection if
        # enable_summary was False (no summaries were registered).
        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op = tf.summary.merge([self.sum_op_dis, self.sum_op_gen])
        self.sum_op_im = tf.summary.merge_all("img")
# ---- Example #10 (scraped-example separator; original marker: "예제 #10", vote count 0) ----
    def build_model(self):
        """Build the BIGAN training and testing graphs.

        Training graph: encoder E(x), generator G(z) (with additive noise
        placeholders on both real and generated images), and a discriminator
        over (latent, image) pairs. Sigmoid cross-entropy losses train D, G
        and E adversarially. EMA shadow weights of all three networks are
        maintained and used to build the test-time graph, from which a
        per-sample anomaly score combines a reconstruction distance and a
        discriminator-based score.
        """
        # Kernel initialization for the convolutions
        self.init_kernel = tf.random_normal_initializer(mean=0.0, stddev=0.02)
        # Placeholders
        # is_training: toggles train/inference behavior in the sub-networks.
        self.is_training = tf.placeholder(tf.bool)
        # x: batch of input images, shape [None] + config image_dims.
        self.image_input = tf.placeholder(tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")
        # noise: latent vectors z for the generator.
        self.noise_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.config.trainer.noise_dim],
            name="noise")
        # Discriminator target labels are fed in, not hard-coded, so the
        # caller can use label smoothing/flipping on the "real"/"fake" sides.
        self.true_labels = tf.placeholder(dtype=tf.float32,
                                          shape=[None, 1],
                                          name="true_labels")
        self.generated_labels = tf.placeholder(dtype=tf.float32,
                                               shape=[None, 1],
                                               name="gen_labels")
        # Image-shaped noise added to real and generated images below
        # (presumably instance noise for GAN stabilization — TODO confirm
        # against the feeding code).
        self.real_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="real_noise")
        self.fake_noise = tf.placeholder(dtype=tf.float32,
                                         shape=[None] +
                                         self.config.trainer.image_dims,
                                         name="fake_noise")

        self.logger.info("Building training graph...")
        with tf.variable_scope("BIGAN"):
            # E(x): latent code inferred from the input image.
            with tf.variable_scope("Encoder_Model"):
                self.noise_gen = self.encoder(self.image_input)

            # G(z) + fake_noise, and the reconstruction G(E(x)).
            with tf.variable_scope("Generator_Model"):
                self.image_gen = self.generator(
                    self.noise_tensor) + self.fake_noise
                self.reconstructed = self.generator(self.noise_gen)

            with tf.variable_scope("Discriminator_Model"):
                # E(x) and x --> This being real is the output of discriminator
                l_encoder, inter_layer_inp = self.discriminator(
                    self.noise_gen, self.image_input + self.real_noise)
                # z and G(z)
                l_generator, inter_layer_rct = self.discriminator(
                    self.noise_tensor, self.image_gen)

        # Loss Function Implementations
        with tf.name_scope("Loss_Functions"):
            # Discriminator
            # Discriminator sees the encoder result as true because it discriminates E(x), x as the real pair
            self.loss_dis_enc = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.true_labels, logits=l_encoder))
            self.loss_dis_gen = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.generated_labels, logits=l_generator))
            self.loss_discriminator = self.loss_dis_enc + self.loss_dis_gen

            # G and E use targets opposite to D's; flip_labels inverts which
            # side is "real" for the adversarial G/E losses.
            if self.config.trainer.flip_labels:
                labels_gen = tf.zeros_like(l_generator)
                labels_enc = tf.ones_like(l_encoder)
            else:
                labels_gen = tf.ones_like(l_generator)
                labels_enc = tf.zeros_like(l_encoder)
            # Generator
            # Generator is considered as the true ones here because it tries to fool discriminator
            self.loss_generator = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_gen,
                                                        logits=l_generator))
            # Encoder
            # Encoder is considered as the fake one because it tries to fool the discriminator also
            self.loss_encoder = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_enc,
                                                        logits=l_encoder))
        # Optimizer Implementations
        with tf.name_scope("Optimizers"):
            # Build the optimizers
            # G and E share the generator learning rate; D has its own.
            self.generator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.discriminator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.encoder_optimizer = tf.train.AdamOptimizer(
                self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Generator Network Variables (selected by variable-scope prefix).
            self.generator_vars = [
                v for v in all_variables
                if v.name.startswith("BIGAN/Generator_Model")
            ]
            # Discriminator Network Variables
            self.discriminator_vars = [
                v for v in all_variables
                if v.name.startswith("BIGAN/Discriminator_Model")
            ]
            # Encoder Network Variables
            self.encoder_vars = [
                v for v in all_variables
                if v.name.startswith("BIGAN/Encoder_Model")
            ]
            # Create Training Operations
            # Generator Network Operations (UPDATE_OPS, e.g. batch-norm moving stats).
            self.gen_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="BIGAN/Generator_Model")
            # Discriminator Network Operations
            self.disc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="BIGAN/Discriminator_Model")
            # Encoder Network Operations
            self.enc_update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="BIGAN/Encoder_Model")
            # Initialization of Optimizers
            # Each minimize step runs after its network's update ops; only the
            # generator step advances the global step counter.
            with tf.control_dependencies(self.gen_update_ops):
                self.gen_op = self.generator_optimizer.minimize(
                    self.loss_generator,
                    var_list=self.generator_vars,
                    global_step=self.global_step_tensor,
                )
            with tf.control_dependencies(self.disc_update_ops):
                self.disc_op = self.discriminator_optimizer.minimize(
                    self.loss_discriminator, var_list=self.discriminator_vars)
            with tf.control_dependencies(self.enc_update_ops):
                self.enc_op = self.encoder_optimizer.minimize(
                    self.loss_encoder, var_list=self.encoder_vars)
            # Exponential Moving Average for Estimation: shadow copies of the
            # D/G/E weights, updated after each train step and read back in
            # the testing graph via get_getter(...).
            self.dis_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_dis = self.dis_ema.apply(
                self.discriminator_vars)

            self.gen_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_gen = self.gen_ema.apply(self.generator_vars)

            self.enc_ema = tf.train.ExponentialMovingAverage(
                decay=self.config.trainer.ema_decay)
            maintain_averages_op_enc = self.enc_ema.apply(self.encoder_vars)

            # train_*_op = optimizer step followed by the EMA update.
            with tf.control_dependencies([self.disc_op]):
                self.train_dis_op = tf.group(maintain_averages_op_dis)

            with tf.control_dependencies([self.gen_op]):
                self.train_gen_op = tf.group(maintain_averages_op_gen)

            with tf.control_dependencies([self.enc_op]):
                self.train_enc_op = tf.group(maintain_averages_op_enc)

        self.logger.info("Building Testing Graph...")

        # Re-enters the "BIGAN" scope so the EMA custom getters resolve the
        # shadow variables of the training networks (variable reuse is
        # presumably handled inside the sub-network builders — TODO confirm).
        with tf.variable_scope("BIGAN"):
            with tf.variable_scope("Encoder_Model"):
                self.noise_gen_ema = self.encoder(self.image_input,
                                                  getter=get_getter(
                                                      self.enc_ema))
            with tf.variable_scope("Generator_Model"):
                self.reconstruct_ema = self.generator(self.noise_gen_ema,
                                                      getter=get_getter(
                                                          self.gen_ema))
            with tf.variable_scope("Discriminator_Model"):
                # D on the real pair (E(x), x) and on the reconstruction pair
                # (E(x), G(E(x))); intermediate layers kept for feature matching.
                self.l_encoder_ema, self.inter_layer_inp_ema = self.discriminator(
                    self.noise_gen_ema,  # E(x)
                    self.image_input,  # x
                    getter=get_getter(self.dis_ema),
                )
                self.l_generator_ema, self.inter_layer_rct_ema = self.discriminator(
                    self.noise_gen_ema,  # E(x)
                    self.reconstruct_ema,  # G(E(x))
                    getter=get_getter(self.dis_ema),
                )

        with tf.name_scope("Testing"):
            with tf.variable_scope("Reconstruction_Loss"):
                # LG(x) = ||x - G(E(x))||_1
                # (norm order actually comes from config.trainer.degree)
                delta = self.image_input - self.reconstruct_ema
                delta_flat = tf.layers.Flatten()(delta)
                self.gen_score = tf.norm(
                    delta_flat,
                    ord=self.config.trainer.degree,
                    axis=1,
                    keepdims=False,
                    name="epsilon",
                )
            with tf.variable_scope("Discriminator_Loss"):
                # Discriminator-based score: either cross-entropy of D on the
                # real pair ("cross_e") or a feature-matching distance ("fm").
                if self.config.trainer.loss_method == "cross_e":
                    self.dis_score = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=tf.ones_like(self.l_encoder_ema),
                        logits=self.l_encoder_ema)
                elif self.config.trainer.loss_method == "fm":
                    fm = self.inter_layer_inp_ema - self.inter_layer_rct_ema
                    fm = tf.layers.Flatten()(fm)
                    self.dis_score = tf.norm(fm,
                                             ord=self.config.trainer.degree,
                                             axis=1,
                                             keepdims=False,
                                             name="d_loss")
                # NOTE(review): any other loss_method leaves dis_score unset
                # and the next line raises AttributeError — verify config.
                self.dis_score = tf.squeeze(self.dis_score)
            with tf.variable_scope("Score"):
                # Final anomaly score: convex combination weighted by
                # config.trainer.weight (weight on the reconstruction term).
                self.list_scores = (
                    1 - self.config.trainer.weight
                ) * self.dis_score + self.config.trainer.weight * self.gen_score

        # Mean test score used as the early-stopping validation metric.
        if self.config.trainer.enable_early_stop:
            self.rec_error_valid = tf.reduce_mean(self.list_scores)

        # Summaries are tagged into collections ("dis", "gen", "image", "v")
        # and merged per collection below.
        if self.config.log.enable_summary:
            with tf.name_scope("Summary"):
                with tf.name_scope("Disc_Summary"):
                    tf.summary.scalar("loss_discriminator",
                                      self.loss_discriminator, ["dis"])
                    tf.summary.scalar("loss_dis_encoder", self.loss_dis_enc,
                                      ["dis"])
                    tf.summary.scalar("loss_dis_gen", self.loss_dis_gen,
                                      ["dis"])
                with tf.name_scope("Gen_Summary"):
                    tf.summary.scalar("loss_generator", self.loss_generator,
                                      ["gen"])
                    tf.summary.scalar("loss_encoder", self.loss_encoder,
                                      ["gen"])
                with tf.name_scope("Image_Summary"):
                    tf.summary.image("reconstruct", self.reconstructed, 3,
                                     ["image"])
                    tf.summary.image("input_images", self.image_input, 3,
                                     ["image"])
        if self.config.trainer.enable_early_stop:
            with tf.name_scope("validation_summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

        # Merged summary ops; merge_all returns None for a collection with no
        # registered summaries (e.g. when enable_summary is False).
        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op_im = tf.summary.merge_all("image")
        self.sum_op_valid = tf.summary.merge_all("v")
# ---- Example #11 (scraped-example separator; original marker: "예제 #11", vote count 0) ----
    def build_model(self):
        # Kernel initialization for the convolutions
        self.init_kernel = tf.random_normal_initializer(mean=0.0, stddev=0.02)
        # Placeholders
        self.is_training = tf.placeholder(tf.bool)
        self.image_input = tf.placeholder(tf.float32,
                                          shape=[None] +
                                          self.config.trainer.image_dims,
                                          name="x")

        self.true_labels = tf.placeholder(dtype=tf.float32,
                                          shape=[None, 1],
                                          name="true_labels")
        self.generated_labels = tf.placeholder(dtype=tf.float32,
                                               shape=[None, 1],
                                               name="gen_labels")

        self.logger.info("Building training graph...")

        with tf.variable_scope("Mark1"):
            with tf.variable_scope("Generator_Model"):
                self.noise_gen, self.img_rec, self.noise_rec = self.generator(
                    self.image_input)
            # Discriminator results of (G(z),z) and (x, E(x))
            with tf.variable_scope("Discriminator_Model_XZ"):
                l_generator, inter_layer_rct_xz = self.discriminator_xz(
                    self.img_rec,
                    self.noise_gen,
                    do_spectral_norm=self.config.spectral_norm)
                l_encoder, inter_layer_inp_xz = self.discriminator_xz(
                    self.image_input,
                    self.noise_gen,
                    do_spectral_norm=self.config.do_spectral_norm)
            # Discrimeinator results of (x, x) and (x, G(E(x))
            with tf.variable_scope("Discriminator_Model_XX"):
                x_logit_real, inter_layer_inp_xx = self.discriminator_xx(
                    self.image_input,
                    self.image_input,
                    do_spectral_norm=self.config.spectral_norm)
                x_logit_fake, inter_layer_rct_xx = self.discriminator_xx(
                    self.image_input,
                    self.img_rec,
                    do_spectral_norm=self.config.spectral_norm)
            # Discriminator results of (z, z) and (z, E(G(z))
            with tf.variable_scope("Discriminator_Model_ZZ"):
                z_logit_real, _ = self.discriminator_zz(
                    self.noise_gen,
                    self.noise_gen,
                    do_spectral_norm=self.config.spectral_norm)
                z_logit_fake, _ = self.discriminator_zz(
                    self.noise_gen,
                    self.noise_rec,
                    do_spectral_norm=self.config.spectral_norm)

        with tf.name_scope("Loss_Functions"):
            # discriminator xz

            # Discriminator should classify encoder pair as real
            loss_dis_enc = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.true_labels, logits=l_encoder))
            # Discriminator should classify generator pair as fake
            loss_dis_gen = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.generated_labels, logits=l_generator))
            self.dis_loss_xz = loss_dis_gen + loss_dis_enc

            # discriminator xx
            x_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=x_logit_real, labels=self.true_labels)
            x_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=x_logit_fake, labels=self.generated_labels)
            self.dis_loss_xx = tf.reduce_mean(x_real_dis + x_fake_dis)
            # discriminator zz
            z_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=z_logit_real, labels=self.true_labels)
            z_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=z_logit_fake, labels=self.generated_labels)
            self.dis_loss_zz = tf.reduce_mean(z_real_dis + z_fake_dis)
            # Feature matching part
            fm = inter_layer_inp_xx - inter_layer_rct_xx
            fm = tf.layers.Flatten()(fm)
            self.feature_match = tf.reduce_mean(
                tf.norm(fm, ord=2, axis=1, keepdims=False))
            # Compute the whole discriminator loss
            self.loss_discriminator = (
                self.dis_loss_xz + self.dis_loss_xx + self.dis_loss_zz
                if self.config.trainer.allow_zz else self.dis_loss_xz +
                self.dis_loss_xx + self.feature_match if
                self.config.trainer.loss_method == "fm" else self.dis_loss_xz +
                self.dis_loss_xx)
            # Generator
            # Adversarial Loss
            if self.config.trainer.flip_labels:
                labels_gen = tf.zeros_like(l_generator)
                labels_enc = tf.ones_like(l_encoder)
            else:
                labels_gen = tf.ones_like(l_generator)
                labels_enc = tf.zeros_like(l_encoder)
            self.gen_loss_enc = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_enc,
                                                        logits=l_encoder))
            self.gen_loss_gen = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_gen,
                                                        logits=l_generator))
            # Contextual Loss
            l1_norm = self.image_input - self.img_rec
            l1_norm = tf.layers.Flatten()(l1_norm)
            self.gen_loss_con = tf.reduce_mean(
                tf.norm(l1_norm, ord=1, axis=1, keepdims=False))
            # Encoder Loss
            l2_norm = self.noise_gen - self.noise_rec
            l2_norm = tf.layers.Flatten()(l2_norm)
            self.gen_loss_enc = tf.reduce_mean(
                tf.norm(l2_norm, ord=2, axis=1, keepdims=False))

            self.gen_loss_ce = self.gen_loss_enc + self.gen_loss_gen
            self.gen_loss_total = (
                self.config.trainer.weight_adv * self.gen_loss_ce +
                self.config.trainer.weight_cont * self.gen_loss_con +
                self.config.trainer.weight_enc * self.gen_loss_enc)

        with tf.name_scope("Optimizers"):
            # Collect all the variables
            all_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # Partition trainable variables by sub-network scope so each
            # optimizer step only updates its own model's weights.
            # Generator Network Variables
            self.gvars = [
                v for v in all_variables
                if v.name.startswith("Mark1/Generator_Model")
            ]
            self.dxzvars = [
                v for v in all_variables
                if v.name.startswith("Mark1/Discriminator_Model_XZ")
            ]
            self.dxxvars = [
                v for v in all_variables
                if v.name.startswith("Mark1/Discriminator_Model_XX")
            ]
            self.dzzvars = [
                v for v in all_variables
                if v.name.startswith("Mark1/Discriminator_Model_ZZ")
            ]
            # Create Training Operations
            # Per-scope UPDATE_OPS (e.g. batch-norm moving stats); these are
            # attached to the corresponding minimize op via the
            # control_dependencies blocks below.
            # Generator Network Operations
            self.update_ops_gen = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="Mark1/Generator_Model")
            self.update_ops_enc = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="Mark1/Encoder_Model")
            # NOTE(review): update_ops_enc is collected but never attached to
            # any train op in this section — confirm the encoder's update ops
            # are handled elsewhere, or this collection is dead.
            self.update_ops_dis_xz = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="Mark1/Discriminator_Model_XZ")
            self.update_ops_dis_xx = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="Mark1/Discriminator_Model_XX")
            self.update_ops_dis_zz = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope="Mark1/Discriminator_Model_ZZ")
            # One Adam optimizer shared by the three discriminators and a
            # separate one for the generator; hyper-parameters from config.
            self.disc_optimizer = tf.train.AdamOptimizer(
                learning_rate=self.config.trainer.discriminator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            self.gen_optimizer = tf.train.AdamOptimizer(
                learning_rate=self.config.trainer.generator_l_rate,
                beta1=self.config.trainer.optimizer_adam_beta1,
                beta2=self.config.trainer.optimizer_adam_beta2,
            )
            # Initialization of Optimizers
            # Only the generator step advances the global step counter. The
            # dis_loss_* tensors are defined earlier in this method, outside
            # this excerpt.
            with tf.control_dependencies(self.update_ops_gen):
                self.gen_op = self.gen_optimizer.minimize(
                    self.gen_loss_total,
                    global_step=self.global_step_tensor,
                    var_list=self.gvars)

            with tf.control_dependencies(self.update_ops_dis_xz):
                self.dis_op_xz = self.disc_optimizer.minimize(
                    self.dis_loss_xz, var_list=self.dxzvars)

            with tf.control_dependencies(self.update_ops_dis_xx):
                self.dis_op_xx = self.disc_optimizer.minimize(
                    self.dis_loss_xx, var_list=self.dxxvars)

            with tf.control_dependencies(self.update_ops_dis_zz):
                self.dis_op_zz = self.disc_optimizer.minimize(
                    self.dis_loss_zz, var_list=self.dzzvars)

            # Exponential Moving Average for inference
            def train_op_with_ema_dependency(var_list, minimize_op):
                """Pair *minimize_op* with an EMA update over *var_list*.

                Returns ``(train_op, ema)``: running ``train_op`` executes
                the optimizer step together with the moving-average
                maintenance ops; ``ema`` is kept so the averaged weights can
                be read back at inference time.
                """
                averager = tf.train.ExponentialMovingAverage(
                    decay=self.config.trainer.ema_decay)
                # Build the EMA maintenance op first (outside the control
                # dependency), then group it under a dependency on the
                # optimizer step so one fetch triggers both.
                ema_update = averager.apply(var_list)
                with tf.control_dependencies([minimize_op]):
                    combined = tf.group(ema_update)
                return combined, averager

            # Wrap each optimizer step so running it also maintains an EMA of
            # that sub-network's weights (consumed by the testing graph).
            self.train_gen_op, self.gen_ema = train_op_with_ema_dependency(
                self.gvars, self.gen_op)

            self.train_dis_op_xz, self.xz_ema = train_op_with_ema_dependency(
                self.dxzvars, self.dis_op_xz)
            self.train_dis_op_xx, self.xx_ema = train_op_with_ema_dependency(
                self.dxxvars, self.dis_op_xx)
            self.train_dis_op_zz, self.zz_ema = train_op_with_ema_dependency(
                self.dzzvars, self.dis_op_zz)

        self.logger.info("Building Testing Graph...")
        # Rebuild the generator under the same variable scope with a custom
        # getter so inference reads EMA-averaged weights. get_getter is
        # defined elsewhere in the file — presumably it swaps each variable
        # for its EMA shadow; confirm there.
        with tf.variable_scope("Mark1"):
            with tf.variable_scope("Generator_Model"):
                self.noise_gen_ema, self.img_rec_ema, self.noise_rec_ema = self.generator(
                    self.image_input, getter=get_getter(self.gen_ema))

        with tf.name_scope("Testing"):
            with tf.variable_scope("Reconstruction_Loss"):
                # | G_E(x) - E(G(x))|1
                # Difference between the noise generated from the input image and reconstructed noise
                # Per-sample L1 distance in latent space, used as the score.
                delta = self.noise_gen_ema - self.noise_rec_ema
                delta = tf.layers.Flatten()(delta)
                self.score = tf.norm(delta, ord=1, axis=1, keepdims=False)

        # Mean score over the validation batch drives early stopping.
        if self.config.trainer.enable_early_stop:
            self.rec_error_valid = tf.reduce_mean(self.score)

        if self.config.log.enable_summary:
            # Summaries are tagged into named collections ("dis", "gen",
            # "image", "v") and merged per-collection at the bottom, so each
            # training phase fetches only its own summary op.
            with tf.name_scope("summary"):
                with tf.name_scope("disc_summary"):
                    tf.summary.scalar("loss_discriminator",
                                      self.loss_discriminator, ["dis"])
                    # NOTE(review): loss_dis_enc / loss_dis_gen are bare
                    # locals — they must be defined earlier in this method
                    # (outside this excerpt) for this to run.
                    tf.summary.scalar("loss_dis_encoder", loss_dis_enc,
                                      ["dis"])
                    tf.summary.scalar("loss_dis_gen", loss_dis_gen, ["dis"])
                    tf.summary.scalar("loss_dis_xz", self.dis_loss_xz, ["dis"])
                    tf.summary.scalar("loss_dis_xx", self.dis_loss_xx, ["dis"])
                    if self.config.trainer.allow_zz:
                        tf.summary.scalar("loss_dis_zz", self.dis_loss_zz,
                                          ["dis"])
                    if self.config.trainer.loss_method:
                        tf.summary.scalar("loss_dis_fm", self.feature_match,
                                          ["dis"])
                with tf.name_scope("gen_summary"):
                    tf.summary.scalar("loss_generator_total",
                                      self.gen_loss_total, ["gen"])
                    tf.summary.scalar("loss_gen_adv", self.gen_loss_ce,
                                      ["gen"])
                    tf.summary.scalar("loss_gen_con", self.gen_loss_con,
                                      ["gen"])
                    tf.summary.scalar("loss_gen_enc", self.gen_loss_enc,
                                      ["gen"])
                with tf.name_scope("image_summary"):
                    # Log 3 example reconstructions alongside their inputs.
                    tf.summary.image("reconstruct", self.img_rec, 3, ["image"])
                    tf.summary.image("input_images", self.image_input, 3,
                                     ["image"])
        if self.config.trainer.enable_early_stop:
            with tf.name_scope("validation_summary"):
                tf.summary.scalar("valid", self.rec_error_valid, ["v"])

        # One merged summary op per collection.
        self.sum_op_dis = tf.summary.merge_all("dis")
        self.sum_op_gen = tf.summary.merge_all("gen")
        self.sum_op_im = tf.summary.merge_all("image")
        self.sum_op_valid = tf.summary.merge_all("v")