Example #1

# Context encoder for image inpainting with an optional adversarial loss,
# written against TensorFlow 1.x graph mode.
import os

import numpy as np
import tensorflow as tf
from termcolor import cprint
from tqdm import tqdm

# Project-local helpers; the module names below are assumed from usage.
import params
from batch_loader import BatchLoader
from utils import (activation_function, avg_pool_2x2, bias_variable, conv2d,
                   create_dir, uconv2d, variable_summaries, weight_variable)

class ContextEncoder_adv(object):
    def __init__(self,
                 batch_size=params.BATCH_SIZE,
                 nb_epochs=params.NB_EPOCHS,
                 mask=None,
                 experiment_path=params.EXPERIMENT_PATH,
                 use_adversarial_loss=params.USE_ADVERSARIAL_LOSS,
                 lambda_decay=params.LAMBDA_DECAY,
                 lambda_adversarial=params.LAMBDA_ADVERSARIAL,
                 patience=params.PATIENCE,
                 discr_whole_image=params.DISCR_WHOLE_IMAGE,
                 discr_loss_limit=params.DISCR_LOSS_LIMIT,
                 use_dropout=params.USE_DROPOUT):
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs
        self.experiment_path = experiment_path
        self.save_path = os.path.join(self.experiment_path, "model/")
        self.save_best_path = os.path.join(self.experiment_path, "best_model/")
        self.logs_path = os.path.join(self.experiment_path, "logs")
        create_dir(self.save_path)
        create_dir(self.save_best_path)
        create_dir(self.logs_path)

        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.phase = tf.placeholder(tf.bool, name='phase')
        self.patience = patience

        # parameters for adversarial loss
        self.use_adversarial_loss = use_adversarial_loss
        self.lambda_adversarial = lambda_adversarial
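        # with decay, the adversarial weight follows
        # 1 - 0.1 * 0.5 ** (step // 10000): it starts at 0.9 and approaches 1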
        if lambda_decay:
            self.lambda_adversarial = 1 - tf.train.exponential_decay(
                .1, self.global_step, 10000, .5, staircase=True)
        self.discr_whole_image = discr_whole_image
        self.discr_loss_limit = discr_loss_limit
        self.num_discr_trained = tf.Variable(tf.constant(0, dtype=tf.int32),
                                             trainable=False)
        # a Python `+=` on a tf.Variable only builds a new tensor, so keep an
        # explicit assign_add op around to update the counter in the session
        self._incr_num_discr_trained = tf.assign_add(self.num_discr_trained, 1)

        self.use_dropout = use_dropout

        self.batch_loader = BatchLoader(self.batch_size)

        if mask is not None:
            self.np_mask = mask
        else:
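            # default mask: hide the central 32x32 patch of the 64x64 input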
            self.np_mask = np.zeros((1, 64, 64, 1))
            self.np_mask[:, 16:48, 16:48, :] = 1

        self._sess = tf.Session()
        tf.summary.scalar("lambda_adversarial", self.lambda_adversarial)
        tf.summary.scalar("num discr trained", self.num_discr_trained)

    def build_model(self):
        # x : input
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, 64, 64, 3])
        # scale pixel values from [0, 255] to [-1, 1] to match the tanh output
        self.x_float = self.x / 255 * 2 - 1

        self.mask = tf.placeholder(tf.float32, shape=[1, 64, 64, 1])
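        # zero out the region to inpaint; the encoder only sees the context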
        self.x_masked = self.x_float * (1 - self.mask)

        self._encode()
        self._channel_wise()
        self._decode()
        self._generate_image()
        self._build_reconstruction_loss()

        # adversarial loss
        self._init_discriminator_variables()
        self._adversarial_loss()

        self._optimize()
        self.merged_summary = tf.summary.merge_all()

    def _encode(self):
        with tf.name_scope("encode"):
            with tf.name_scope('weights'):
                self._W_conv1 = weight_variable([5, 5, 3, 128])
                self._W_conv2 = weight_variable([5, 5, 128, 256])
                self._W_conv3 = weight_variable([5, 5, 256, 512])
                self._W_conv4 = weight_variable([5, 5, 512, 512])
                self._W_conv5 = weight_variable([3, 3, 512, 512])
                variable_summaries(self._W_conv1)
                variable_summaries(self._W_conv2)
                variable_summaries(self._W_conv3)
                variable_summaries(self._W_conv4)
                variable_summaries(self._W_conv5)
            with tf.name_scope('biases'):
                self._b_conv1 = bias_variable([128])
                self._b_conv2 = bias_variable([256])
                self._b_conv3 = bias_variable([512])
                self._b_conv4 = bias_variable([512])
                self._b_conv5 = bias_variable([512])
                variable_summaries(self._b_conv1)
                variable_summaries(self._b_conv2)
                variable_summaries(self._b_conv3)
                variable_summaries(self._b_conv4)
                variable_summaries(self._b_conv5)

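            # Five conv + avg-pool stages reduce the 64x64x3 input to a
            # 4x4x512 bottleneck; the comments track each stage's input shape.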
            # 64 64 3
            self.h_conv1 = activation_function(
                conv2d(self.x_masked,
                       self._W_conv1,
                       stride=1,
                       is_training=self.phase) + self._b_conv1)
            self.h_pool1 = avg_pool_2x2(self.h_conv1)

            # 32 32 128
            self.h_conv2 = activation_function(
                conv2d(self.h_pool1,
                       self._W_conv2,
                       stride=1,
                       is_training=self.phase) + self._b_conv2)
            self.h_pool2 = avg_pool_2x2(self.h_conv2)

            # 16 16 256
            self.h_conv3 = activation_function(
                conv2d(self.h_pool2,
                       self._W_conv3,
                       stride=1,
                       is_training=self.phase) + self._b_conv3)
            self.h_pool3 = avg_pool_2x2(self.h_conv3)

            # 8 8 512
            self.h_conv4 = activation_function(
                conv2d(self.h_pool3,
                       self._W_conv4,
                       stride=1,
                       is_training=self.phase) + self._b_conv4)
            self.h_pool4 = avg_pool_2x2(self.h_conv4)

            # 4 4 512
            self.h_conv5 = activation_function(
                conv2d(self.h_pool4,
                       self._W_conv5,
                       stride=1,
                       is_training=self.phase) + self._b_conv5)

            # 4 4 512
            if self.use_dropout:
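                # keep_prob is 0.5 while training and 1.0 at evaluation,
                # switched by the `phase` placeholder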
                keep_prob = tf.cond(self.phase, lambda: tf.constant(.5),
                                    lambda: tf.constant(1.))
                self.h_conv5_drop = tf.nn.dropout(self.h_conv5, keep_prob)
            else:
                self.h_conv5_drop = self.h_conv5

    def _channel_wise(self):
        with tf.name_scope('channel_wise'):
            with tf.name_scope('weights'):
                self._W_fc1 = weight_variable([512, 4 * 4, 4 * 4])
                variable_summaries(self._W_fc1)
            with tf.name_scope('biases'):
                self._b_fc1 = bias_variable([512])
                variable_summaries(self._b_fc1)
            # Channel-wise fully connected layer: an independent 16x16 dense
            # map per channel.  Move channels to the leading axis so tf.matmul
            # batches over them; a bare reshape would scramble the data.
            self.h_conv5_flat_img = tf.reshape(
                tf.transpose(self.h_conv5_drop, [3, 0, 1, 2]),
                [512, self.batch_size, 4 * 4])
            self.h_fc1 = activation_function(
                tf.transpose(tf.matmul(self.h_conv5_flat_img, self._W_fc1),
                             [1, 2, 0]) + self._b_fc1)

            self.h_fc1_img = tf.reshape(self.h_fc1,
                                        [self.batch_size, 4, 4, 512])

    def _decode(self):
        with tf.name_scope('decode'):
            with tf.name_scope('weights'):
                self._W_uconv1 = weight_variable([5, 5, 512, 512])
                self._W_uconv2 = weight_variable([5, 5, 256, 512])
                self._W_uconv3 = weight_variable([5, 5, 128, 256])
            with tf.name_scope('biases'):
                self._b_uconv1 = bias_variable([512])
                self._b_uconv2 = bias_variable([256])
                self._b_uconv3 = bias_variable([128])
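            # three stride-2 transposed convolutions upsample the 4x4x512
            # bottleneck back to 32x32x128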
            self.h_uconv1 = activation_function(
                uconv2d(self.h_fc1_img,
                        self._W_uconv1,
                        output_shape=[self.batch_size, 8, 8, 512],
                        stride=2,
                        is_training=self.phase) + self._b_uconv1)

            # 8 8 512
            self.h_uconv2 = activation_function(
                uconv2d(self.h_uconv1,
                        self._W_uconv2,
                        output_shape=[self.batch_size, 16, 16, 256],
                        stride=2,
                        is_training=self.phase) + self._b_uconv2)

            # 16 16 256
            self.h_uconv3 = activation_function(
                uconv2d(self.h_uconv2,
                        self._W_uconv3,
                        output_shape=[self.batch_size, 32, 32, 128],
                        stride=2,
                        is_training=self.phase) + self._b_uconv3)

            # 32 32 128
            if self.use_dropout:
                keep_prob = tf.cond(self.phase, lambda: tf.constant(.5),
                                    lambda: tf.constant(1.))
                self.h_uconv3_drop = tf.nn.dropout(self.h_uconv3, keep_prob)
            else:
                self.h_uconv3_drop = self.h_uconv3

    def _generate_image(self):
        with tf.name_scope('generated_image'):
            self._W_uconv4 = weight_variable([5, 5, 3, 128])
            self._b_uconv4 = bias_variable([3])
            self.y = tf.nn.tanh(
                uconv2d(self.h_uconv3_drop,
                        self._W_uconv4,
                        output_shape=[self.batch_size, 32, 32, 3],
                        stride=1,
                        is_training=self.phase) + self._b_uconv4)
            # 32 32 3
            self.y_padded = tf.pad(self.y,
                                   [[0, 0], [16, 16], [16, 16], [0, 0]])
            self.generated_image = self.y_padded + self.x_masked
            # 64 64 3
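            # the 32x32 prediction, zero-padded to 64x64, fills the hole left
            # in the masked context image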
            tf.summary.image("original_image", self.x_float, max_outputs=12)
            tf.summary.image("generated_image",
                             self.generated_image,
                             max_outputs=12)

    def _build_reconstruction_loss(self):
        with tf.name_scope('reconstruction_loss'):
            # L2 loss restricted to the masked (inpainted) region
            self._reconstruction_loss = tf.nn.l2_loss(
                self.mask * (self.x_float - self.y_padded)) / self.batch_size
            tf.summary.scalar('reconstruction_loss', self._reconstruction_loss)

    def _init_discriminator_variables(self):
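        # Create the weights once, so _discriminator_encoder can be applied
        # to both real and generated images with the same shared parameters.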
        with tf.name_scope('discriminator'):
            with tf.name_scope('weights'):
                self._W_discr1 = weight_variable([5, 5, 3, 128])
                self._W_discr2 = weight_variable([5, 5, 128, 256])
                self._W_discr3 = weight_variable([5, 5, 256, 512])
                if self.discr_whole_image:
                    self._W_discr4 = weight_variable([5, 5, 512, 512])
                self._W_dfc = weight_variable([4 * 4 * 512, 1])

            with tf.name_scope('biases'):
                self._b_discr1 = bias_variable([128])
                self._b_discr2 = bias_variable([256])
                self._b_discr3 = bias_variable([512])
                if self.discr_whole_image:
                    self._b_discr4 = bias_variable([512])
                self._b_dfc = bias_variable([1])

    def _discriminator_encoder(self, image):
        with tf.name_scope('discriminator_encoder'):
            # image is 32 32 3 OR 64 64 3 (if whole image)
            h_d1 = activation_function(
                conv2d(image, self._W_discr1, stride=1, is_training=self.phase)
                + self._b_discr1)
            h_dpool1 = avg_pool_2x2(h_d1)

            # 16 16 128 OR 32 32 128 (if whole image)
            h_d2 = activation_function(
                conv2d(
                    h_dpool1, self._W_discr2, stride=1, is_training=self.phase)
                + self._b_discr2)
            h_dpool2 = avg_pool_2x2(h_d2)

            # 8 8 256 OR 16 16 256 (if whole image)
            h_d3 = activation_function(
                conv2d(
                    h_dpool2, self._W_discr3, stride=1, is_training=self.phase)
                + self._b_discr3)
            h_dpool3 = avg_pool_2x2(h_d3)

            if self.discr_whole_image:
                # 8 8 512 (if whole image)
                h_d4 = activation_function(
                    conv2d(h_dpool3,
                           self._W_discr4,
                           stride=1,
                           is_training=self.phase) + self._b_discr4)
                h_dfinal = avg_pool_2x2(h_d4)
            else:
                h_dfinal = h_dpool3

            # 4 4 512
            h_dfinal_flat = tf.reshape(h_dfinal,
                                       [self.batch_size, 4 * 4 * 512])
            if self.use_dropout:
                keep_prob = tf.cond(self.phase, lambda: tf.constant(.5),
                                    lambda: tf.constant(1.))
                h_dfinal_drop = tf.nn.dropout(h_dfinal_flat, keep_prob)
            else:
                h_dfinal_drop = h_dfinal_flat
            discr = tf.matmul(h_dfinal_drop, self._W_dfc) + self._b_dfc
            return discr

    def _adversarial_loss(self):
        with tf.name_scope('adversarial_loss'):
            self._discr_variables = [
                v for v in tf.trainable_variables()
                if v.name.startswith('discriminator')
            ]
            self._gen_variables = [
                v for v in tf.trainable_variables()
                if not v.name.startswith('discriminator')
            ]
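            # tf.name_scope prefixes tf.Variable names, so the
            # 'discriminator' prefix cleanly separates the two variable sets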
            print(len(self._discr_variables), "DISCR VARIABLES ",
                  [v.name for v in self._discr_variables])
            print(len(self._gen_variables), "GEN VARIABLES",
                  [v.name for v in self._gen_variables])

            if self.discr_whole_image:
                # D(real img)
                real_discr = self._discriminator_encoder(self.x_float)
                # D(G(img))
                fake_discr = self._discriminator_encoder(self.generated_image)

            else:
                # discriminate the center of the image only
                self.real_img = tf.slice(self.x_float, [0, 16, 16, 0],
                                         [self.batch_size, 32, 32, 3])
                # D(real img)
                real_discr = self._discriminator_encoder(self.real_img)
                # D(G(img))
                fake_discr = self._discriminator_encoder(self.y)

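            # one-sided label smoothing: real labels are 0.9 instead of 1.0,
            # which keeps the discriminator from becoming over-confident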
            real_discr_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=real_discr, labels=.9 * tf.ones_like(real_discr)))
            fake_discr_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=fake_discr, labels=tf.zeros_like(fake_discr)))

            self._discr_adversarial_loss = (real_discr_loss +
                                            fake_discr_loss) / 2
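            # non-saturating generator loss: the generator is pushed to make
            # the discriminator output "real" (1) on generated images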
            self._gen_adversarial_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=fake_discr, labels=tf.ones_like(fake_discr)))

            # Discriminator loss
            self._discr_loss = self._discr_adversarial_loss
            # Generator loss (combination of reconstruction and adversarial loss)
            self._gen_loss = self.lambda_adversarial * self._gen_adversarial_loss + \
                             (1 - self.lambda_adversarial) * self._reconstruction_loss

            tf.summary.scalar("discr loss", self._discr_loss)
            tf.summary.scalar("gen full loss (adversarial and reconstruction)",
                              self._gen_loss)
            tf.summary.scalar("gen adversarial loss",
                              self._gen_adversarial_loss)

    def _optimize(self):
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.variable_scope("optimizer"):
            optimizer = tf.train.AdamOptimizer()

            with tf.control_dependencies(update_ops):
                # Context-encoder
                grads = optimizer.compute_gradients(self._reconstruction_loss)
                self.train_fn = optimizer.apply_gradients(
                    grads, global_step=self.global_step)

                # Context-encoder with adversarial loss
                grads_discr = optimizer.compute_gradients(
                    loss=self._discr_loss, var_list=self._discr_variables)
                grads_gen = optimizer.compute_gradients(
                    loss=self._gen_loss, var_list=self._gen_variables)
                # only the generator step advances global_step, so that one
                # adversarial iteration does not increment the counter twice
                self.train_discr = optimizer.apply_gradients(grads_discr)
                self.train_gen = optimizer.apply_gradients(
                    grads_gen, global_step=self.global_step)

    def _compute_val_loss(self):
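        # Average the reconstruction loss over half of the validation set;
        # the summary string of the last batch is returned for the writer.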
        n_val_batches = self.batch_loader.n_valid_batches // 2
        val_loss = 0
        for _ in tqdm(range(n_val_batches)):
            batch = self.batch_loader.load_batch(train=False)
            loss, summary_str = self._sess.run(
                [self._reconstruction_loss, self.merged_summary],
                feed_dict={
                    self.x: batch,
                    self.mask: self.np_mask,
                    self.phase: 0
                })
            val_loss += loss
        val_loss /= n_val_batches
        return val_loss, summary_str

    def _restore(self):
        """
        Retrieve last model saved if possible
        Create a main Saver object
        Create a SummaryWriter object
        Init variables
        :param save_name: string (default : model)
            Name of the model
        :return:
        """
        saver = tf.train.Saver(max_to_keep=2)
        # Try to restore an old model
        last_saved_model = tf.train.latest_checkpoint(self.save_path)

        self._sess.run(tf.global_variables_initializer())

        train_writer = tf.summary.FileWriter(
            os.path.join(self.logs_path, "train"),
            graph=self._sess.graph,
            flush_secs=20)
        val_writer = tf.summary.FileWriter(
            os.path.join(self.logs_path, "val"),
            graph=self._sess.graph,
            flush_secs=20)

        if last_saved_model is not None:
            saver.restore(self._sess, last_saved_model)
            print("[*] Restoring model {}".format(last_saved_model))
        else:
            print("[*] New model created")
        return saver, train_writer, val_writer

    def train(self):
        """
        Train the model
        :return:
        """
        # Retrieve a model or create a new
        saver, train_writer, val_writer = self._restore()

        epoch = 0
        n_train_batches = self.batch_loader.n_train_batches

        # Retrieve current global step
        last_step = self._sess.run(self.global_step)
        epoch += last_step // n_train_batches
        last_iter = last_step - n_train_batches * epoch
        print("last iter {}".format(last_iter))
        print("last step {}".format(last_step))
        print("epoch {}".format(epoch))
        # Iterate over epochs

        skipped_to_last_iter = False
        patience_count = 0
        best_val_loss = 1e10

        while epoch < self.nb_epochs:

            for i in tqdm(range(n_train_batches)):
                # when resuming mid-epoch, fast-forward to the saved iteration
                if i < last_iter and not skipped_to_last_iter:
                    continue
                skipped_to_last_iter = True
                batch = self.batch_loader.load_batch(train=True)

                if self.use_adversarial_loss:
                    # no discr_loss_limit
                    if self.discr_loss_limit >= 1:
                        _ = self._sess.run(self.train_discr,
                                           feed_dict={
                                               self.x: batch,
                                               self.mask: self.np_mask,
                                               self.phase: 1
                                           })
                    # there is a discr_loss_limit: train the discriminator
                    # only while its loss is still above that limit
                    else:
                        discr_loss = self._sess.run(
                            self._discr_loss,
                            feed_dict={
                                self.x: batch,
                                self.mask: self.np_mask,
                                self.phase: 1
                            })
                        if discr_loss >= self.discr_loss_limit:
                            # run the assign_add op; `+=` on the tf.Variable
                            # would never update it inside the session
                            self._sess.run(self._incr_num_discr_trained)
                            _ = self._sess.run(self.train_discr,
                                               feed_dict={
                                                   self.x: batch,
                                                   self.mask: self.np_mask,
                                                   self.phase: 1
                                               })

                ops = ([self.train_gen, self.global_step]
                       if self.use_adversarial_loss else
                       [self.train_fn, self.global_step])
                if i % 200 == 0:
                    ops.append(self.merged_summary)
                output = self._sess.run(ops,
                                        feed_dict={
                                            self.x: batch,
                                            self.mask: self.np_mask,
                                            self.phase: 1
                                        })

                if i % 200 == 0:
                    # print("nb of black and white images so far : {}".format(self.nb_bw_img))
                    train_writer.add_summary(output[-1], global_step=output[1])

            saver.save(self._sess,
                       global_step=output[1],
                       save_path=self.save_path)
            val_loss, summary_str = self._compute_val_loss()
            val_writer.add_summary(summary_str, global_step=output[1])
            val_writer.add_summary(tf.Summary(value=[
                tf.Summary.Value(tag="val_loss", simple_value=val_loss),
            ]),
                                   global_step=output[1])
            cprint("Epoch {}".format(epoch), color="yellow")

            # early stopping on the validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_count = 0
            else:
                patience_count += 1
                if patience_count >= self.patience:
                    break

            epoch += 1

        cprint("Training done.", "green", attrs=["bold"])

        train_writer.flush()
        val_writer.flush()
        train_writer.close()
        val_writer.close()
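
# Minimal usage sketch, assuming params.py defines the constants referenced
# above and BatchLoader serves 64x64x3 batches with pixel values in [0, 255]:
if __name__ == "__main__":
    model = ContextEncoder_adv()
    model.build_model()
    model.train()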