示例#1
0
    def train(self):
        lr = self.opts.lr
        self.sess.run(self.init)
        train_set = Dataset(self.opts)
        train_size = train_set.__len__()
        for epoch in range(1, self.opts.num_epochs):
            batch_num = 0
            for batch_begin, batch_end in zip(range(0, train_size, self.opts.batch_size), \
    range(self.opts.batch_size, train_size, self.opts.batch_size)):
                begin_time = time.time()
                input_ptv, input_oct, gt_img = train_set.load_batch(
                    batch_begin, batch_end)
                feed_dict = {
                    self.true_images: gt_img,
                    self.input_ptv: input_ptv,
                    self.lr: lr,
                    self.input_oct: input_oct
                }
                _, loss, summary = self.sess.run(
                    [self.optimizer, self.loss, self.summaries],
                    feed_dict=feed_dict)

                batch_num += 1
                self.writer.add_summary(
                    summary,
                    epoch * (train_size / self.opts.batch_size) + batch_num)
                if batch_num % self.opts.display == 0:
                    rem_time = (time.time() - begin_time) * (
                        self.opts.num_epochs - epoch) * (train_size /
                                                         self.opts.batch_size)
                    log = '-' * 20
                    log += ' Epoch: {}/{}|'.format(epoch, self.opts.num_epochs)
                    log += ' Batch Number: {}/{}|'.format(
                        batch_num, train_size / self.opts.batch_size)
                    log += ' Batch Time: {}\n'.format(time.time() - begin_time)
                    log += ' Remaining Time: {:0>8}\n'.format(
                        datetime.timedelta(seconds=rem_time))
                    log += ' lr: {} loss: {}\n'.format(lr, loss)
                    print(log)
                # if epoch % self.opts.lr_decay == 0 and batch_num == 1:
                #     lr *= self.opts.lr_decay_factor
                if epoch % self.opts.ckpt_frq == 0 and batch_num == 1:
                    self.saver.save(
                        self.sess,
                        "ckpt/" + "{}_{}_{}".format(epoch, lr, loss))
示例#2
0
class Model(object):
   """Defines the base class for all models
   """

   def __init__(self, opts, is_training):
      """Initialize the model by creating various parts of the graph

      Args:
         opts: All the hyper-parameters of the network
      """
      self.opts = opts
      self.h = opts.h
      self.w = opts.w
      self.c = opts.c
      self.train_mode = is_training
      self.sess = tf.Session()
      self.build_graph()


   def build_graph(self):
      """Generate various parts of the graph
      """
      sys.stdout.write(' - Building various parts of the graph...\n')
      self.non_lin = {'relu' : lambda x: relu(x, name='relu'),
                      'lrelu': lambda x: lrelu(x, name='lrelu'),
                      'tanh' : lambda x: tanh(x, name='tanh')
                      }
      self.allocate_placeholders()

      # Common discriminator
      self.D, self.D_logits = self.discriminator(self.target_images, self.opts.d_kernels,
                                                 self.opts.d_layers, non_lin=self.opts.d_nonlin,
                                                 norm=self.opts.d_norm, use_sigmoid=self.opts.d_sigmoid,
                                                 reuse=False)

      # Generators and Encoders
      if self.opts.model == 'cvae-gan':
         self.E_mean, self.E_std  = self.encoder(self.target_images, self.opts.e_layers,
                                                 self.opts.e_kernels, self.opts.e_nonlin,
                                                 norm=self.opts.e_norm, reuse=False,
                                                 num_blocks=self.opts.e_blocks)
         self.assign_gen_code()
         self.G_cvae  = self.generator(self.input_images, self.gen_input_noise, self.opts.g_layers,
                                  self.opts.g_kernels, self.opts.g_nonlin,
                                  norm=self.opts.g_norm)
      elif self.opts.model == 'clr-gan':
         self.assign_gen_code()
         self.G_clr  = self.generator(self.input_images, self.gen_input_noise, self.opts.g_layers,
                                  self.opts.g_kernels, self.opts.g_nonlin,
                                  norm=self.opts.g_norm)
         self.E_mean, self.E_std  = self.encoder(self.G_clr, self.opts.e_layers, self.opts.e_kernels,
                                                 self.opts.e_nonlin, norm=self.opts.e_norm,
                                                 num_blocks=self.opts.e_blocks, reuse=False)
      elif self.opts.model == 'bicycle':
         # cVAE-GAN graph
         print ' - Generating cVAE-GAN graph...'
         self.E_mean_1, self.E_std_1  = self.encoder(self.target_images, self.opts.e_layers,
                                                     self.opts.e_kernels, self.opts.e_nonlin,
                                                     norm=self.opts.e_norm, reuse=False,
                                                     num_blocks=self.opts.e_blocks)
         with tf.variable_scope('encoded_noise'):
            self.encoded_noise = self.E_mean_1 + self.code * self.E_std_1
         self.G_cvae = self.generator(self.input_images, self.encoded_noise, self.opts.g_layers,
                                      self.opts.g_kernels, self.opts.g_nonlin,
                                      norm=self.opts.g_norm)

         # cLR-GAN graph
         print ' - Generating cLR-GAN graph...'
         self.G_clr = self.generator(self.input_images, self.code, self.opts.g_layers,
                                     self.opts.g_kernels, self.opts.g_nonlin,
                                     norm=self.opts.g_norm, reuse=True)
         self.E_mean_2, self.E_std_2  = self.encoder(self.G_clr, self.opts.e_layers,
                                                     self.opts.e_kernels, self.opts.e_nonlin,
                                                     norm=self.opts.e_norm, reuse=True,
                                                     num_blocks=self.opts.e_blocks)

         # Discriminators
         self.D_cvae, self.D_cvae_logits_ = self.discriminator(self.G_cvae, self.opts.d_kernels,
                                                 self.opts.d_layers, non_lin=self.opts.d_nonlin,
                                                 norm=self.opts.d_norm, use_sigmoid=self.opts.d_sigmoid,
                                                 reuse=True)
         self.D_clr, self.D_clr_logits_   = self.discriminator(self.G_clr, self.opts.d_kernels,
                                                 self.opts.d_layers, non_lin=self.opts.d_nonlin,
                                                 norm=self.opts.d_norm, use_sigmoid=self.opts.d_sigmoid,
                                                 reuse=True)

      if self.opts.model == 'clr-gan':
         self.D_, self.D_logits_ = self.discriminator(self.G_clr, self.opts.d_kernels, self.opts.d_layers,
                                                      non_lin=self.opts.d_nonlin, norm=self.opts.d_norm,
                                                      use_sigmoid=self.opts.d_sigmoid,
                                                      reuse=True)

      if self.opts.model == 'cvae-gan':
         self.D_, self.D_logits_ = self.discriminator(self.G_cvae, self.opts.d_kernels, self.opts.d_layers,
                                                      non_lin=self.opts.d_nonlin, norm=self.opts.d_norm,
                                                      use_sigmoid=self.opts.d_sigmoid,
                                                      reuse=True)

      self.variables = tf.trainable_variables()
      self.d_vars = [var for var in self.variables if 'discriminator' in var.name]
      self.ge_vars = [var for var in self.variables if 'generator' or 'encoder' in var.name]
      self.model_loss()
      self.D_opt = tf.train.RMSPropOptimizer(self.opts.base_lr).minimize(self.d_loss, var_list=self.d_vars)
      clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in self.d_vars]
      self.GE_opt = tf.train.RMSPropOptimizer(self.opts.base_lr).minimize(self.g_loss, var_list=self.ge_vars)
      self.summaries()
      self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)


   def allocate_placeholders(self):
      """Allocate placeholders of the graph
      """
      sys.stdout.write(' - Allocating placholders...\n')
      self.images_A = tf.placeholder(tf.float32, [None, self.h, self.w, self.c], name="images_A")
      self.images_B = tf.placeholder(tf.float32, [None, self.h, self.w, self.c], name="images_B")
      self.code = tf.placeholder(tf.float32, [None, self.opts.code_len], name="code")
      self.is_training = tf.placeholder(tf.bool, name='is_training')
      self.lr = tf.placeholder(tf.float32, [], name="lr")
      if self.opts.direction == 'a2b':
         self.input_images  = self.images_A
         self.target_images = self.images_B
      elif self.opts.direction == 'b2a':
         self.input_images  = self.images_B
         self.target_images = self.images_A
      else:
         raise ValueError("There is no such image transition type")


   def assign_gen_code(self):
      """Assigns the noise for the generator
         This is dynamic: during Train mode, noise is the encoded vector
      """
      def train_mode():
         """Noise to be set during training mode"""
         input_noise = None
         if self.opts.model == 'cvae-gan' or self.opts.model == 'bicycle':
            input_noise = self.E_mean + self.code * self.E_std
         elif self.opts.model == 'clr-gan':
            input_noise = self.code
         else:
            raise ValueError("No such type of model exists !")
         return input_noise

      def test_mode():
         """Noise to be set during test mode"""
         return self.code

      with tf.variable_scope('Noise'):
         self.gen_input_noise = tf.cond(tf.equal(self.is_training, tf.constant(True)),
                                        true_fn=train_mode,
                                        false_fn=test_mode,
                                        name='Noise')
         assert self.gen_input_noise is not None, "Generator input noise is not fed"


   def summaries(self):
      """Adds all the necessary summaries
      """
      images_A = tf.summary.image('images_A', self.images_A, max_outputs=10)
      images_B = tf.summary.image('images_B', self.images_B, max_outputs=10)
      if self.opts.model == 'bicycle':
         gen_images_cvae = tf.summary.image('Gen_images_cVAE', self.G_cvae, max_outputs=10)
         gen_images_clr  = tf.summary.image('Gen_images_cLR', self.G_clr, max_outputs=10)
         self.gen_images = tf.summary.merge([gen_images_clr, gen_images_cvae])
      elif self.opts.model == 'cvae-gan':
         self.gen_images = tf.summary.image('Gen_images', self.G_cvae, max_outputs=10)
      elif self.opts.model == 'clr-gan':
         self.gen_images = tf.summary.image('Gen_images', self.G_clr, max_outputs=10)

      # Loss
      z_summary = tf.summary.histogram('z', self.code)
      d_loss_fake = tf.summary.scalar('D_loss_fake', self.loss['D_fake_loss'])
      d_loss_real = tf.summary.scalar('D_loss_real', self.loss['D_real_loss'])
      d_loss = tf.summary.scalar('D_loss', self.d_loss)
      g_loss = tf.summary.scalar('G_loss', self.g_loss)
      lr = tf.summary.scalar('learning_rate', self.lr)
      self.d_summaries = tf.summary.merge([d_loss_fake, d_loss_real, z_summary, d_loss])
      if self.opts.model == 'bicycle':
         self.g_summaries = tf.summary.merge([g_loss, lr])
      else:
         self.g_summaries = tf.summary.merge([g_loss, images_A, images_B]+[self.gen_images])

      if not self.opts.full_summaries:
         self.act_sparsity = tf.summary.merge(tf.get_collection('hist_spar'))

      try:
        self.act_sparsity = tf.summary.merge(tf.get_collection('hist_spar'))
      except:
        pass

   def get_learning_factor(self, epoch):
      """Gets the factor to multiply the learning rate with
      
      Args:
        epoch: epoch number 
      """
      return 1.0 - max(0, epoch-self.opts.niter) / float(self.opts.niter_decay+1)


   def encoder(self, image, num_layers=3, kernels=64, non_lin='lrelu', norm=None,
               reuse=False, num_blocks=4):
      """Encoder which generates the latent code

      Args:
         image     : Image which is to be encoded
         num_layers: Non linearity to the intermediate layers of the network
         kernels   : Number of filters for the first layer of the network
         non_lin   : Type of non-linearity activation
         norm      : Should use batch normalization
         reuse     : Should reuse the variables?
         num_blocks: The number of residual blocks

      Returns:
         The encoded latent code
      """
      self.e_layers = {}
      with tf.variable_scope('encoder'):
         if self.opts.e_type == "normal":
            return self.normal_encoder(image, num_layers=num_layers, output_neurons=8,
               kernels=kernels, non_lin=non_lin, norm=norm, reuse=reuse)
         elif self.opts.e_type == "residual":
            return self.resnet_encoder(image, num_layers, output_neurons=8,
               kernels=kernels, non_lin=non_lin, num_blocks=num_blocks, reuse=reuse)
         else:
            raise ValueError("No such type of encoder exists!")

   def normal_encoder(self, image, num_layers=4, output_neurons=1, kernels=64, non_lin='lrelu',
                      norm=None, reuse=False):
      """Few convolutional layers followed by downsampling layers
      """
      k, s = 4, 2
      try:
         self.e_layers['conv0'] = conv2d(image, ksize=k, out_channels=kernels*1, stride=s, name='conv0',
            non_lin=self.non_lin[non_lin], reuse=reuse)
      except KeyError:
         raise KeyError("No such non-linearity is available!")
      for idx in range(1, num_layers):
         input_layer = self.e_layers['conv{}'.format(idx-1)]
         factor = min(2**idx, 4)
         if not norm:
            self.e_layers['conv{}'.format(idx)] = conv2d(input_layer, ksize=k,
               out_channels=kernels*factor, stride=s, name='conv{}'.format(idx),
               non_lin=self.non_lin[non_lin], reuse=reuse)
         else:
            self.e_layers['conv{}'.format(idx)] = conv_bn_lrelu(input_layer, ksize=k,
               out_channels=kernels*factor, is_training=self.is_training, stride=s,
               name='conv{}'.format(idx), reuse=reuse)
         if not self.opts.full_summaries:
            activation_summary(self.e_layers['conv{}'.format(idx)])

      self.e_layers['pool'] = average_pool(self.e_layers['conv{}'.format(num_layers-1)],
         ksize=8, stride=8, name='pool')
      if not self.opts.full_summaries:
         activation_summary(self.e_layers['pool'])

      units = int(np.prod(self.e_layers['pool'].get_shape().as_list()[1:]))
      reshape_layer = tf.reshape(self.e_layers['pool'], [-1, units])
      self.e_layers['full_mean'] = fully_connected(reshape_layer, output_neurons, name='full_mean',
                                                   reuse=reuse)
      # This layers predicts the `log(var)`, to get the std,
      # std = exp(0.5 * log(var))
      self.e_layers['full_logvar'] = fully_connected(reshape_layer, output_neurons, name='full_logvar',
                                                  reuse=reuse)
      if not self.opts.full_summaries:
         activation_summary(self.e_layers['full_mean'])
         activation_summary(self.e_layers['full_logvar'])

      return self.e_layers['full_mean'], tf.exp(0.5 * self.e_layers['full_logvar'])

   def resnet_encoder(self, image, num_layers=4, num_blocks=4, output_neurons=1,
                      kernels=64, non_lin='relu', norm=None, reuse=False):
      """Residual Network with several residual blocks
      """
      self.e_layers['conv0'] = conv2d(image, ksize=4, out_channels=kernels*1, stride=2, name='conv0',
        non_lin=self.non_lin[non_lin], reuse=reuse)

      input_layer = self.e_layers['conv0']
      input_channels = self.e_layers['conv0'].get_shape().as_list()[-1]

      # Add residual blocks
      for idx in xrange(1, num_blocks):
        factor = min(idx+1, 4)
        self.e_layers['block_{}'.format(idx)] = residual_block_v2(input_layer,
             out_channels=[input_channels, kernels*factor], is_training=self.is_training,
             name='block_{}'.format(idx), reuse=reuse)
        input_layer = self.e_layers['block_{}'.format(idx)]
        input_channels = self.e_layers['block_{}'.format(idx)].get_shape().as_list()[-1]

      self.e_layers['pool'] = average_pool(self.e_layers['block_{}'.format(num_blocks-1)],
         ksize=8, stride=8, name='pool')
      if not self.opts.full_summaries:
         activation_summary(self.e_layers['pool'])

      units = int(np.prod(self.e_layers['pool'].get_shape().as_list()[1:]))
      reshape_layer = tf.reshape(self.e_layers['pool'], [-1, units])
      self.e_layers['full_mean'] = fully_connected(reshape_layer, output_neurons, name='full_mean',
                                                   reuse=reuse)
      self.e_layers['full_logvar'] = fully_connected(reshape_layer, output_neurons, name='full_logvar',
                                                  reuse=reuse)
      if not self.opts.full_summaries:
         activation_summary(self.e_layers['full_mean'])
         activation_summary(self.e_layers['full_logvar'])

      return self.e_layers['full_mean'], tf.exp(0.5 * self.e_layers['full_logvar'])


   def generator(self, image, z, layers=3, kernels=64, non_lin='relu', norm=None,
                 reuse=False):
      """Generator graph of GAN

      Args:
         image  : Conditioned image on which the generator generates the image 
         z      : Latent space code (or noise when sampling the images)
         layers : The number of layers either in downsampling / upsampling 
         kernels: Number of kernels to the first layer of the network
         non_lin: Non linearity to be used
         norm   : Whether to use batch normalization layer
         reuse  : Whether to reuse the variables created for generator graph

      Returns:
         Generated image
      """
      self.g_layers = {}
      with tf.variable_scope('generator'):
         if self.opts.where_add == "input":
            return self.generator_input(image, z, layers, kernels, non_lin, norm,
                                        reuse)
         elif self.opts.where_add == "all":
            return self.generator_all(image, z, layers, kernels, non_lin, norm,
                                      reuse)
         else:
            raise ValueError("No such type of generator exists!")


   def generator_input(self, image, z, layers=3, kernels=32, non_lin='lrelu', norm=None,
                       reuse=False):
      """Generator graph where noise is concatenated to the first layer
      """

      with tf.name_scope('replication'):
         tiled_z = tf.tile(z, [1, self.w*self.h], name='tiling')
         reshaped = tf.reshape(tiled_z, [-1, self.h, self.w, self.opts.code_len], name='reshape')
         in_layer = tf.concat([image, reshaped], axis=3, name='concat')
      k, s = 4, 2
      factor = 1

      with tf.variable_scope('down_1'):
        conv1 = conv2d(in_layer, ksize=3, out_channels=32, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv1 = conv2d(conv1,    ksize=3, out_channels=32, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)
        pool1 = max_pool(conv1, kernel=2, stride=2, name='pool1')
      
      with tf.variable_scope('down_2'):
        conv2 = conv2d(pool1, ksize=3, out_channels=64, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv2 = conv2d(conv2, ksize=3, out_channels=64, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)
        pool2 = max_pool(conv2, kernel=2, stride=2, name='pool1')

      with tf.variable_scope('down_3'):
        conv3 = conv2d(pool2, ksize=3, out_channels=128, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv3 = conv2d(conv3, ksize=3, out_channels=128, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)
        pool3 = max_pool(conv3, kernel=2, stride=2, name='pool1')

      with tf.variable_scope('down_4'):
        conv4 = conv2d(pool3, ksize=3, out_channels=256, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv4 = conv2d(conv4, ksize=3, out_channels=256, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)
        pool4 = max_pool(conv4, kernel=2, stride=2, name='pool1')

      with tf.variable_scope('down_5'):
        conv5  = conv2d(pool4, ksize=3, out_channels=512, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv5 = conv2d(conv5,  ksize=3, out_channels=512, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)

      with tf.variable_scope('up_1'):
        dcnv1 = deconv(conv5, ksize=3, out_channels=512, stride=2, name='dconv1', out_shape=32, non_lin=self.non_lin[non_lin],
                       batch_size=self.opts.batch_size, reuse=reuse)
        up1   = concatenate(dcnv1, conv4, axis=3)
        conv6 = conv2d(up1,   ksize=3, out_channels=256, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv6 = conv2d(conv6, ksize=3, out_channels=256, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)

      with tf.variable_scope('up_2'):
        dcnv2 = deconv(conv6, ksize=3, out_channels=256, stride=2, name='dconv1', out_shape=64, non_lin=self.non_lin[non_lin],
                       batch_size=self.opts.batch_size, reuse=reuse)
        up2   = concatenate(dcnv2, conv3, axis=3)
        conv7 = conv2d(up2,   ksize=3, out_channels=128, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv7 = conv2d(conv7, ksize=3, out_channels=128, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)

      with tf.variable_scope('up_3'):
        dcnv3 = deconv(conv7, ksize=3, out_channels=128, stride=2, name='dconv1', out_shape=128, non_lin=self.non_lin[non_lin],
                       batch_size=self.opts.batch_size, reuse=reuse)
        up2   = concatenate(dcnv3, conv2, axis=3)
        conv8 = conv2d(up2,   ksize=3, out_channels=64, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv8 = conv2d(conv8, ksize=3, out_channels=64, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)

      with tf.variable_scope('up_4'):
        dcnv4 = deconv(conv8, ksize=3, out_channels=64, stride=2, name='dconv1', out_shape=256, non_lin=self.non_lin[non_lin],
                       batch_size=self.opts.batch_size, reuse=reuse)
        up3   = concatenate(dcnv4, conv1, axis=3)
        conv9 = conv2d(up3,   ksize=3, out_channels=32, stride=1, name='conv1', non_lin=self.non_lin[non_lin], reuse=reuse)
        conv9 = conv2d(conv9, ksize=3, out_channels=32, stride=1, name='conv2', non_lin=self.non_lin[non_lin], reuse=reuse)

      with tf.variable_scope('up_5'):
        output = conv2d(conv9, ksize=3, out_channels=3, stride=1, name='conv1', non_lin=self.non_lin['tanh'], reuse=reuse)

      return output


   def generator_all(self, image, z, layers=3, kernels=64, non_lin='lrelu', norm=None,
                     reuse=False):
      """Generator graph where noise is to all the layers
      """
      raise NotImplementedError("Not Implemented")


   def discriminator(self, image, kernels=64, num_layers=3, norm_layer=None, non_lin='lrelu', 
                     use_sigmoid=False, reuse=False, norm=None):
      """Discriminator graph of GAN
      The discriminator is a PatchGAN discriminator which consists of two 
         discriminators for two different scales i.e, 70x70 and 140x140
      Authors claim not conditioning the discriminator yields better results
         and hence not conditioning the discriminator with the input image
      Authors also claim that using two discriminators for cVAE-GAN and cLR-GAN
         yields better results, here we share the weights for both of them

      Args:
         image      : Input image to the discriminator
         kernels    : Number of kernels for the first layer of the network
         num_layers : Total number of layers
         norm_layer : Type of normalization layer {batch/instance}
         non_lin    : Type of non-linearity of the network
         use_sigmoid: Use Sigmoid layer before the final layer?
         reuse      : Flag to check whether to reuse the variables created for the
                     discriminator graph
         norm       : Whether to use batch normalization layer

      Returns:
         Whether or not the input image is real or fake
      """
      self.d_layers = {}
      with tf.variable_scope('discriminator'):
         if not self.opts.d_usemulti:
            return self.discriminator_patch(image, kernels, num_layers, norm_layer, non_lin, 
               use_sigmoid, reuse, norm)
         else:
            raise NotImplementedError("Multiple discriminators is not implemented")


   def discriminator_patch(self, image, kernels, num_layers, norm_layer, non_lin,
                           use_sigmoid=False, reuse=False, norm=None):
      """PatchGAN discriminator
      """
      k, s = 4, 2
      self.d_layers['conv0'] = conv2d(image, ksize=k, out_channels=kernels*1, stride=s, name='conv0',
            non_lin=self.non_lin[non_lin], reuse=reuse)
      for idx in range(1, num_layers):
         input_layer = self.d_layers['conv{}'.format(idx-1)]
         factor = min(2**idx, 8)
         if not norm:
            self.d_layers['conv{}'.format(idx)] = conv2d(input_layer, ksize=k,
               out_channels=kernels*factor, stride=s, name='conv{}'.format(idx),
               non_lin=self.non_lin[non_lin], reuse=reuse)
         else:
            self.d_layers['conv{}'.format(idx)] = conv_bn_lrelu(input_layer, ksize=k,
               out_channels=kernels*factor, is_training=self.is_training, stride=s,
               name='conv{}'.format(idx), reuse=reuse)

      input_layer = self.d_layers['conv{}'.format(num_layers-1)]
      factor = min(2**num_layers, 8)
      if not norm:
         self.d_layers['conv{}'.format(num_layers)] = conv2d(input_layer, ksize=k, out_channels=
            kernels*factor, stride=s, name='conv{}'.format(num_layers), non_lin=self.non_lin[non_lin],
            reuse=reuse)
      else:
         self.d_layers['conv{}'.format(num_layers)] = conv_bn_lrelu(input_layer, ksize=k,
            out_channels=kernels*factor, is_training=self.is_training, stride=s,
            name='conv{}'.format(num_layers), reuse=reuse)

      input_layer = self.d_layers['conv{}'.format(num_layers)]
      self.d_layers['conv{}'.format(num_layers+1)] = conv2d(input_layer, ksize=k, out_channels=1, 
         stride=s, name='conv{}'.format(num_layers+1), reuse=reuse)

      logits = self.d_layers['conv{}'.format(num_layers+1)]
      return sigmoid(logits), logits


   def model_loss(self):
      """Implements the loss graph
         All the loss values are stored in the dictionary `self.loss`
      """

      def cVAE_GAN_loss(true_logit, fake_logit, E_mean, E_std, z1, z2):
         """Computes cVAE-GAN loss

         Args:
            true_logit: Output of discriminator for true image
            fake_logit: Output of discriminator for fake image
            E_mean    : Mean predicted by encoder
            E_std     : Std predicted by encoder
            z1        : -
            z2        : -
         """
         with tf.variable_scope('cVAE_GAN_loss'):
            gan_loss(true_logit=true_logit,
                     fake_logit=fake_logit,
                     model='cVAE')
            with tf.variable_scope('KL_loss'):
               self.loss['KL']  = self.opts.lambda_kl * kl_divergence(E_mean, E_std)
            with tf.variable_scope('L1_VAE_loss'):
               self.loss['L1_VAE']  = self.opts.lambda_img * l1_loss(z1, z2)

      def cLR_GAN_loss(true_logit, fake_logit, E, skip=False):
         """Computes cLR-GAN loss

         Args:
            true_logit: Output of discriminator for true image
            fake_logit: Output of discriminator for fake image
            E         : Mean predicted by encoder
            skip      : Whether to skip the G_loss
         """
         with tf.variable_scope('cLR_GAN_loss'):
            gan_loss(true_logit=true_logit,
                     fake_logit=fake_logit,
                     model='cLR',
                     skip_d_real_loss=skip)
            with tf.variable_scope('L1_latent_loss'):
               self.loss['L1_latent']  = self.opts.lambda_latent * l1_loss(E, self.code)

      def gan_loss(true_logit, fake_logit, model='cLR', skip_d_real_loss=False):
         """Implements the GAN loss

         Args:
            true_logit      : Output of discriminator for true image
            fake_logit      : Output of discriminator for fake image
            model           : Name of the model to compute loss for
            skip_d_real_loss: Whether to skip G_loss, should be skipped the second time
                              while training bicycleGAN model
         """
         if len(true_logit.get_shape().as_list()) != 2:
            true_logit = tf.reduce_mean(tf.reshape(true_logit, [self.opts.batch_size, -1]), axis=1)
            fake_logit = tf.reduce_mean(tf.reshape(fake_logit, [self.opts.batch_size, -1]), axis=1)
         with tf.variable_scope('D_fake_loss'):
            self.loss['D_{}_fake_loss'.format(model)] = tf.reduce_mean(fake_logit) - tf.reduce_mean(true_logit)
         with tf.variable_scope('D_real_loss'):
            if not skip_d_real_loss:
               self.loss['D_{}_real_loss'.format(model)] = tf.constant(0.)
            else:
               self.loss['D_{}_real_loss'.format(model)] = 0.
         with tf.variable_scope('G_loss'):
            self.loss['G_{}_loss'.format(model)] = -tf.reduce_mean(fake_logit)

         self.loss['D_{}_loss'.format(model)] = self.loss['D_{}_fake_loss'.format(model)] + \
                                                self.loss['D_{}_real_loss'.format(model)]

      def l1_loss(z1, z2):
         """Implements L1 loss graph
         
         Args:
            z1: Image in case of cVAE-GAN
                Vector in case of cLR-GAN
            z2: Image in case of cVAE-GAN
                Vector in case of cLR-GAN

         Returns:
            L1 loss
         """
         return tf.reduce_mean(tf.abs(z1-z2))

      def kl_divergence(p1_mean, p1_std):
         """Apply KL divergence
            The second distribution is assumed to be unit Gaussian distribution
         
         Args:
            p1_mean: Mean of 1st probability distribution
            p1_std : Std of 1st probability distribution

         Returns:
            KL Divergence between the given distributions
         """
         divergence = 0.5 * tf.reduce_sum(tf.square(p1_mean)+tf.square(p1_std)- \
                                          1.0 * tf.log(tf.square(p1_std))-1, axis=1)
         return tf.reduce_mean(divergence, axis=0)

      with tf.variable_scope('loss'):
         self.loss = {}
         if self.opts.model == 'cvae-gan':
            cVAE_GAN_loss(self.D_logits, self.D_logits_, self.E_mean,
                          self.E_std, self.target_images, self.G_cvae)
            self.d_loss = self.loss['D_cVAE_loss']
            self.g_loss = self.loss['KL'] +\
                          self.loss['L1_VAE'] +\
                          self.loss['G_cVAE_loss']
            self.loss['D_fake_loss'] = self.loss['D_cVAE_fake_loss']
            self.loss['D_real_loss'] = self.loss['D_cVAE_real_loss']
         elif self.opts.model == 'clr-gan':
            cLR_GAN_loss(self.D_logits, self.D_logits_, self.E_mean)
            self.d_loss = self.loss['D_cLR_loss']
            self.g_loss = self.loss['L1_latent'] +\
                          self.loss['G_cLR_loss']
            self.loss['D_fake_loss'] = self.loss['D_cLR_fake_loss']
            self.loss['D_real_loss'] = self.loss['D_cLR_real_loss']
         elif self.opts.model == 'bicycle':
            with tf.variable_scope('Bicycle_GAN_loss'):
               cVAE_GAN_loss(self.D_logits, self.D_cvae_logits_, self.E_mean_1,
                             self.E_std_1, self.target_images, self.G_cvae)
               cLR_GAN_loss(self.D_logits, self.D_clr_logits_, self.E_mean_2, skip=True)
               self.d_loss = self.loss['D_cLR_loss'] +\
                             self.loss['D_cVAE_loss']
               self.g_loss = self.loss['KL'] +\
                             self.loss['L1_VAE'] +\
                             self.loss['L1_latent'] + \
                             self.loss['G_cLR_loss'] +\
                             self.loss['G_cVAE_loss']
               self.loss['D_fake_loss'] = self.loss['D_cLR_fake_loss'] +\
                                          self.loss['D_cVAE_fake_loss']
               self.loss['D_real_loss'] = self.loss['D_cVAE_real_loss']
         else:
            raise ValueError("\"{}\" type of architecture doesn't exist for loss !".format(self.opts.model))

   def train(self):
      """Train the network
      """
      self.test_graph()
      self.data = Dataset(self.opts, load=True)

      if self.opts.resume:
        try:
          self.saver.restore(self.sess, self.opts.ckpt)
          print ' - Successfully restored the checkpoint: {}'.format(self.opts.ckpt)
        except:
          raise ValueError(" - Cannot restore the checkpoint file: {}".format(self.opts.ckpt))
      else:
        self.init = tf.global_variables_initializer()
        self.sess.run(self.init)

      formatter =  "{} Elapsed Time: {}  Epoch: [{:2d}/{:2d}]  Batch: [{:4d}/{:4d}]  LR: {:.5f}  "
      formatter += "D_fake_loss: {:.5f}  D_real_loss: {:.5f}  D_loss: {:.5f}  G_loss: {:.5f}"

      if self.opts.noise_type == "gauss":
        runtime_z = gaussian_noise([self.opts.sample_num, self.opts.code_len])
      elif self.opts.noise_type == "uniform":
        runtime_z = uniform_noise([self.opts.sample_num, self.opts.code_len])
      else:
        raise ValueError("No such type of noise is present !")

      start_time = datetime.now()
      print ' - Training the network...\n'
      for epoch in xrange(0, self.opts.niter+self.opts.niter_decay+1):
         batch_num = 0
         lr_factor = self.get_learning_factor(epoch)
         for batch_begin, batch_end in zip(xrange(0, self.data.train_size(),
            self.opts.batch_size), xrange(self.opts.batch_size, self.data.train_size()+1,
            self.opts.batch_size)):

            iteration = epoch * (self.data.train_size()/self.opts.batch_size) + batch_num
            images_A, images_B = self.data.load_batch(batch_begin, batch_end)

            if self.opts.noise_type == "gauss":
              code = gaussian_noise([self.opts.sample_num, self.opts.code_len])
            elif self.opts.noise_type == "uniform":
              code = uniform_noise([self.opts.sample_num, self.opts.code_len])
            else:
              raise ValueError("No such type of noise is present !")

            # Update Discriminator
            feed_dict = {
              self.images_A: images_A,
              self.images_B: images_B,
              self.code: code,
              self.is_training: True,
              self.lr: self.opts.base_lr*lr_factor
            }
            _, d_loss, d_summaries, d_fake, d_real = self.sess.run(
                    [self.D_opt, self.d_loss, self.d_summaries,
                     self.loss['D_fake_loss'],
                     self.loss['D_real_loss']
                     ],
                    feed_dict=feed_dict)
            self.writer.add_summary(d_summaries, iteration)

            # Update Generator and Encoder
            feed_dict = {
              self.images_A: images_A,
              self.images_B: images_B,
              self.code: code,
              self.is_training: True,
              self.lr: self.opts.base_lr*lr_factor
              }

            for i in xrange(self.opts.g_update):
              _, g_summaries, g_loss = self.sess.run(
                      [self.GE_opt, self.g_summaries,
                       self.g_loss], feed_dict=feed_dict)
              self.writer.add_summary(g_summaries, iteration)

            elapsed_time = datetime.now() - start_time
            curr_time = datetime.fromtimestamp(int(time.time())).strftime('%d-%m-%Y %H:%M:%S')
            print formatter.format(curr_time, elapsed_time, epoch, self.opts.niter+self.opts.niter_decay+1,
                                   batch_num+1, self.data.train_size()/self.opts.batch_size,
                                   self.opts.base_lr, d_fake, d_real, d_fake + d_real,
                                   g_loss)

            # Sample the images
            if np.mod(iteration, self.opts.gen_frq) == 0:
               print ' - [Sampling the images...]'
               feed_dict = {
                  self.images_A: images_A,
                  self.images_B: images_B,
                  self.code: runtime_z,
                  self.is_training: False,
                  self.lr: self.opts.base_lr*lr_factor
               }
               if self.opts.model == 'bicycle':
                  images_cvae = self.G_cvae.eval(session=self.sess, feed_dict=feed_dict)
                  images_clr  = self.G_clr.eval(session=self.sess, feed_dict=feed_dict)

                  utils.imwrite(os.path.join(
                          self.opts.sample_dir, 'iter_{}_cLR'.format(iteration)),
                          images_clr, inv_normalize=True)
                  utils.imwrite(os.path.join(
                          self.opts.sample_dir, 'iter_{}_cVAE'.format(iteration)),
                          images_cvae, inv_normalize=True)
               elif self.opts.model == 'cvae-gan':
                  images = self.G_cvae.eval(session=self.sess, feed_dict=feed_dict)
               elif self.opts.model == 'clr-gan':
                  images = self.G_clr.eval(session=self.sess, feed_dict=feed_dict)
               else:
                  raise ValueError("No such type of model exists")

               if self.opts.model != 'bicycle':
                  utils.imwrite(os.path.join(
                                self.opts.sample_dir, 'iter_{}'.format(iteration)),
                                images, inv_normalize=True)

            # Validate the model
            if np.mod(iteration, self.opts.gen_frq*1) == 0:
              self.validate(iteration)

            batch_num += 1

         if np.mod(epoch, self.opts.ckpt_frq) == 0:
            self.checkpoint(epoch)
      self.sess.close()


   def validate(self, iteration):
      """Validates"""
      print ' - Validating the model at iteration: {}'.format(iteration)

      images_A, images_B = self.data.load_val_batch()
      for i in xrange(3):
        print ' - Validating with latent vector #{}'.format(i)
        if self.opts.noise_type == "gauss":
          sample_z = gaussian_noise([self.opts.sample_num, self.opts.code_len])
        elif self.opts.noise_type == "uniform":
          sample_z = uniform_noise([self.opts.sample_num, self.opts.code_len])
        else:
          raise ValueError("No such type of noise is present !")

        feed_dict = {
          self.images_A: images_A,
          self.images_B: images_B,
          self.code: sample_z,
          self.is_training: False,
        }


        gen_image_summaries = self.gen_images.eval(session=self.sess, feed_dict=feed_dict)
        self.writer.add_summary(gen_image_summaries, iteration)

        
        if self.opts.model == 'bicycle' or self.opts.model == 'clr-gan':
          images_clr  = self.G_clr.eval(session=self.sess, feed_dict=feed_dict)
          utils.imwrite(os.path.join(
                  self.opts.sample_dir, 'VAL_{}_cLR'.format(iteration)),
                  images_clr, inv_normalize=True)
        if self.opts.model == 'bicycle' or self.opts.model == 'cvae-gan':
          images_cvae = self.G_cvae.eval(session=self.sess, feed_dict=feed_dict)
          utils.imwrite(os.path.join(
                  self.opts.sample_dir, 'VAL_{}_cVAE'.format(iteration)),
                  images_cvae, inv_normalize=True)
        utils.imwrite(os.path.join(
                self.opts.sample_dir, 'VAL_{}_ground_truth_A'.format(iteration)),
                images_A, inv_normalize=True)
        utils.imwrite(os.path.join(
                self.opts.sample_dir, 'VAL_{}_ground_truth_B'.format(iteration)),
                images_B, inv_normalize=True)






   def checkpoint(self, epoch):
      """Creates a checkpoint at the given epoch

      Args:
         epoch: epoch number of the training process
      """
      self.saver.save(self.sess, os.path.join(self.opts.summary_dir, "model_{}.ckpt").format(epoch))


   def test_graph(self):
      """Generate the graph and check if the connections are correct
      """
      sys.stdout.write(' - Generating the test graph...\n')
      self.writer = tf.summary.FileWriter(logdir=self.opts.summary_dir,
                                          graph=self.sess.graph)


   def test(self, source):
      """Test the model

      Args:
         source: Input to the model, either single image or directory containing images

      Returns:
         The generated image conditioned on the input image
      """
      split_len = 600 if self.opts.dataset == 'maps' else 256

      img = utils.normalize_images(utils.imread(source))
      img_A = img[:, :split_len, :]
      img_B = img[:, split_len:, :]
      img_A = np.expand_dims(img_A, 0)
      img_B = np.expand_dims(img_B, 0)
      if self.opts.direction == 'b2a':
        input_images = img_B
        target_images = img_A
      else:
        input_images = img_A
        target_images = img_B

      self.saver.restore(self.sess, self.opts.ckpt)
      utils.imwrite(os.path.join(
              self.opts.target_dir, 'target_image'),
              target_images[0], inv_normalize=True)
      utils.imwrite(os.path.join(
              self.opts.target_dir, 'conditional_image'),
              input_images[0], inv_normalize=True)

      print ' - Sampling generator images for different random initial noise'
      for idx in xrange(self.opts.sample_num):
        print 'Sampling #', idx
        if self.opts.noise_type == "gauss":
          code = gaussian_noise([1, self.opts.code_len])
        else:
          code = uniform_noise([1, self.opts.code_len])

        feed_dict = {
          self.is_training: False,
          self.images_A: img_A,
          self.images_B: img_B,
          self.code: code
        }
        if self.opts.model == 'bicycle':
          images = self.G_cvae.eval(session=self.sess, feed_dict=feed_dict)
          utils.imwrite(os.path.join(
                  self.opts.target_dir, 'test_cvae{}'.format(idx)),
                  images, inv_normalize=True)
          images = self.G_clr.eval(session=self.sess, feed_dict=feed_dict)
          utils.imwrite(os.path.join(
                  self.opts.target_dir, 'test_clr{}'.format(idx)),
                  images, inv_normalize=True)
        else:
          raise ValueError("Testing only possible for bicycleGAN")
示例#3
0
    def train(self):
        utils = Dataset(self.opts)
        lr = self.opts.base_lr
        self.sess.run(self.init)
        for iteration in xrange(1, self.opts.MAX_iterations):
            batch_num = 0
            for batch_begin, batch_end in zip(xrange(0, self.opts.train_size, self.opts.batch_size), \
             xrange(self.opts.batch_size, self.opts.train_size, self.opts.batch_size)):
                begin_time = time.time()
                batch_imgs = utils.load_batch(batch_begin, batch_end)
                feed_dict = {self.images: batch_imgs, self.lr: lr}
                _, l1, l2, summary = self.sess.run(
                    [self.optimizer, self.l1, self.l2, self.summaries],
                    feed_dict=feed_dict)

                batch_num += 1
                self.writer.add_summary(
                    summary,
                    iteration * (self.opts.train_size / self.opts.batch_size) +
                    batch_num)
                if batch_num % self.opts.display == 0:
                    rem_time = (time.time() -
                                begin_time) * self.opts.MAX_iterations * (
                                    self.opts.train_size /
                                    self.opts.batch_size)
                    log = '-' * 20
                    log += '\nIteration: {}/{}|'.format(
                        iteration, self.opts.MAX_iterations)
                    log += ' Batch Number: {}/{}|'.format(
                        batch_num, self.opts.train_size / self.opts.batch_size)
                    log += ' Batch Time: {}\n'.format(time.time() - begin_time)
                    log += ' Remaining Time: {:0>8}\n'.format(
                        datetime.timedelta(seconds=rem_time))
                    log += ' Learning Rate: {}\n'.format(lr)
                    log += ' Encoder Loss: {}\n'.format(l1)
                    log += ' Decoder Loss: {}\n'.format(l2)
                    print log
                if iteration % self.opts.lr_decay == 0 and batch_num == 1:
                    lr *= self.opts.lr_decay_factor
                if iteration % self.opts.ckpt_frq == 0 and batch_num == 1:
                    self.saver.save(
                        self.sess,
                        os.path.join(self.opts.root_dir, self.opts.ckpt_dir,
                                     "{}".format(iteration)))
                if iteration % self.opts.generate_frq == 0 and batch_num == 1:
                    generate_imgs = utils.test_images
                    imgs = self.sess.run(self.generated_imgs,
                                         feed_dict={
                                             self.images: generate_imgs,
                                             self.lr: lr
                                         })
                    if self.opts.dataset == "CIFAR":
                        imgs = np.reshape(
                            imgs, (self.opts.test_size, 3, 32, 32)).transpose(
                                0, 2, 3, 1)
                    else:
                        imgs = np.reshape(imgs, (self.opts.test_size, 28, 28))
                    tf.summary.image('Generated image', imgs[0])
                    utils.save_batch_images(
                        imgs, [self.opts.grid_h, self.opts.grid_w],
                        str(iteration) + ".jpg", True)
示例#4
0
    def train(self):
        code = np.random.uniform(
            low=-1.0,
            high=1.0,
            size=[self.opts.batch_size, self.opts.code_len]).astype(np.float32)
        utils = Dataset(self.opts)
        D_lr = self.opts.D_base_lr
        G_lr = self.opts.G_base_lr
        self.sess.run(self.init)
        for iteration in xrange(1, self.opts.MAX_iterations):
            batch_num = 0
            for batch_begin, batch_end in zip(xrange(0, self.opts.train_size, self.opts.batch_size), \
             xrange(self.opts.batch_size, self.opts.train_size, self.opts.batch_size)):
                begin_time = time.time()
                if self.opts.use_labels:
                    batch_imgs, batch_labels = utils.load_batch(
                        batch_begin, batch_end)
                else:
                    batch_imgs = utils.load_batch(batch_begin, batch_end)
                noise = np.random.uniform(
                    low=-1.0,
                    high=1.0,
                    size=[self.opts.batch_size,
                          self.opts.code_len]).astype(np.float32)

                # Real data
                if self.opts.use_labels:
                    feed_dict = {
                        self.images: batch_imgs,
                        self.D_lr: D_lr,
                        self.G_lr: G_lr,
                        self.code: noise,
                        self.labels: batch_labels
                    }
                else:
                    feed_dict = {
                        self.images: batch_imgs,
                        self.D_lr: D_lr,
                        self.G_lr: G_lr,
                        self.code: noise
                    }
                _, D_loss = self.sess.run([self.D_optimizer, self.d_loss],
                                          feed_dict=feed_dict)

                # Fake data
                if self.opts.use_labels:
                    feed_dict = {
                        self.images: batch_imgs,
                        self.D_lr: D_lr,
                        self.G_lr: G_lr,
                        self.code: noise,
                        self.labels: batch_labels
                    }
                else:
                    feed_dict = {
                        self.images: batch_imgs,
                        self.D_lr: D_lr,
                        self.G_lr: G_lr,
                        self.code: noise
                    }
                _, G_loss, summary = self.sess.run(
                    [self.G_optimizer, self.g_loss, self.summaries],
                    feed_dict=feed_dict)

                batch_num += 1
                self.writer.add_summary(
                    summary,
                    iteration * (self.opts.train_size / self.opts.batch_size) +
                    batch_num)
                if batch_num % self.opts.display == 0:
                    rem_time = (time.time() - begin_time) * (
                        self.opts.MAX_iterations - iteration) * (
                            self.opts.train_size / self.opts.batch_size)
                    log = '-' * 20
                    log += '\nIteration: {}/{}|'.format(
                        iteration, self.opts.MAX_iterations)
                    log += ' Batch Number: {}/{}|'.format(
                        batch_num, self.opts.train_size / self.opts.batch_size)
                    log += ' Batch Time: {}\n'.format(time.time() - begin_time)
                    log += ' Remaining Time: {:0>8}\n'.format(
                        datetime.timedelta(seconds=rem_time))
                    log += ' D_lr: {} D_loss: {}\n'.format(D_lr, D_loss)
                    log += ' G_lr: {} G_loss: {}\n'.format(G_lr, G_loss)
                    print log
                if iteration % self.opts.lr_decay == 0 and batch_num == 1:
                    D_lr *= self.opts.lr_decay_factor
                    G_lr *= self.opts.lr_decay_factor
                if iteration % self.opts.ckpt_frq == 0 and batch_num == 1:
                    self.saver.save(
                        self.sess, self.opts.root_dir +
                        self.opts.ckpt_dir + "{}_{}_{}_{}".format(
                            iteration, D_lr, G_lr, D_loss + G_loss))
                if iteration % self.opts.generate_frq == 0 and batch_num == 1:
                    feed_dict = {self.code: code}
                    self.is_training = False
                    imgs = self.sess.run(self.generated_imgs,
                                         feed_dict=feed_dict)
                    if self.opts.dataset == "CIFAR":
                        imgs = np.reshape(
                            imgs, (self.opts.test_size, 3, 32, 32)).transpose(
                                0, 2, 3, 1)
                    else:
                        imgs = np.reshape(imgs, (self.opts.test_size, 28, 28))
                    utils.save_batch_images(
                        imgs, [self.opts.grid_h, self.opts.grid_w],
                        str(iteration) + ".jpg", True)
                    self.is_training = True