Example 1
    def _build_optimizer(self, name, loss, vars, step=None, step_update=None):
        if step is None and hasattr(self, 'global_step'):
            step = self.global_step

        if name == '' or name == 'optimizer':
            (train_op, learning_rate_var, step_var) = get_optimizer_by_config(
                self.config['optimizer'], self.config['optimizer params'],
                loss, vars, step, step_update)
        else:
            (train_op, learning_rate_var, step_var) = get_optimizer_by_config(
                self.config[name + ' optimizer'],
                self.config[name + ' optimizer params'], loss, vars, step,
                step_update)

        return train_op, learning_rate_var, step_var
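
get_optimizer_by_config itself is not shown in these snippets; from the way it is called, it takes an optimizer name and a params dict and returns a (train_op, learning_rate_variable, step_variable) triple. A minimal sketch of that contract, assuming an Adam-style optimizer and illustrative config keys ('lr', 'beta1') that are not necessarily the repo's real schema:

import tensorflow as tf

def get_optimizer_by_config_sketch(name, params, loss, var_list,
                                   global_step=None, step_update=None):
    # Illustrative only: 'lr' and 'beta1' are assumed keys, not the repo's actual schema.
    if global_step is None:
        global_step = tf.Variable(0, trainable=False, dtype=tf.int64, name='global_step')
    learning_rate = tf.Variable(params.get('lr', 1e-3), trainable=False, name='learning_rate')
    if name.lower() == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate, beta1=params.get('beta1', 0.9))
    else:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    minimize_op = optimizer.minimize(loss, var_list=var_list)
    # The step variable only advances when an explicit update op is supplied, which is how
    # later examples share one step between two optimizers while letting only one advance it.
    train_op = minimize_op if step_update is None else tf.group(minimize_op, step_update)
    return train_op, learning_rate, global_step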
Example 2
    def build_model(self):

        self.config['classifier params']['name'] = 'classifier'
        self.config['classifier params']["output dims"] = self.nb_classes

        self.classifier = get_classifier(self.config['classifier'],
                                         self.config['classifier params'],
                                         self.is_training)

        # for training
        self.x = tf.placeholder(tf.float32,
                                shape=[
                                    None,
                                ] + self.input_shape,
                                name='x_input')
        self.mask = tf.placeholder(tf.float32,
                                   shape=[
                                       None,
                                   ] + self.mask_shape,
                                   name='mask')

        self.logits, self.end_points = self.classifier.features(self.x)

        self.loss = get_loss('segmentation', self.config['segmentation loss'],
                             {
                                 'logits': self.logits,
                                 'mask': self.mask
                             })

        self.train_miou = get_metric(
            'segmentation', 'miou', {
                'logits': self.logits,
                'mask': self.mask,
                'nb_classes': self.nb_classes
            })

        # for testing
        self.test_x = tf.placeholder(
            tf.float32,
            shape=[None, None, None, self.input_shape[-1]],
            name='test_x_input')
        self.test_logits = self.classifier(self.test_x)
        self.test_y = tf.nn.softmax(self.test_logits)

        (self.train_op, self.learning_rate,
         self.global_step) = get_optimizer_by_config(
             self.config['optimizer'],
             self.config['optimizer params'],
             target=self.loss,
             variables=self.classifier.vars)

        # model saver
        self.saver = tf.train.Saver(self.classifier.store_vars + [
            self.global_step,
        ])
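
get_metric('segmentation', 'miou', ...) above builds a training-time mean IoU. A rough sketch of one way to compute mean IoU from logits and a one-hot mask via a confusion matrix; the repo's metric may stream the statistic or handle absent classes differently:

import tensorflow as tf

def miou_sketch(logits, mask, nb_classes):
    # logits, mask: [batch, height, width, nb_classes]; mask is one-hot ground truth
    pred = tf.reshape(tf.argmax(logits, axis=-1), [-1])
    gt = tf.reshape(tf.argmax(mask, axis=-1), [-1])
    cm = tf.confusion_matrix(gt, pred, num_classes=nb_classes, dtype=tf.float32)
    intersection = tf.diag_part(cm)
    union = tf.reduce_sum(cm, axis=0) + tf.reduce_sum(cm, axis=1) - intersection
    iou = intersection / tf.maximum(union, 1.0)  # guard against empty classes
    return tf.reduce_mean(iou)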
Example 3
	def build_model(self):
		# network config
		self.config['discriminator params']['name'] = 'Discriminator'
		self.config['generator params']['name'] = 'Generator'
		self.discriminator = self._build_discriminator('discriminator')
		self.generator = self._build_generator('generator')

		# build model
		self.x_real = tf.placeholder(tf.float32, shape=[None, ] + list(self.input_shape), name='x_input')
		self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

		self.x_fake = self.generator(self.z)
		self.dis_real = self.discriminator(self.x_real)
		self.dis_fake = self.discriminator(self.x_fake)

		# loss config
		self.d_loss = get_loss('adversarial down', 'cross entropy', {'dis_real' : self.dis_real, 'dis_fake' : self.dis_fake})
		self.g_loss = get_loss('adversarial up', 'cross entropy', {'dis_fake' : self.dis_fake})

		# optimizer config
		self.global_step, self.global_step_update = get_global_step()

		# optimizer of discriminator configured without global step update
		# so we can keep the learning rate of discriminator the same as generator
		(self.d_train_op, 
			self.d_learning_rate, 
				self.d_global_step) = get_optimizer_by_config(self.config['discriminator optimizer'],
																self.config['discriminator optimizer params'],
																self.d_loss, self.discriminator.vars,
																self.global_step)
		(self.g_train_op, 
			self.g_learning_rate, 
				self.g_global_step) = get_optimizer_by_config(self.config['generator optimizer'],
																self.config['generator optimizer params'],
																self.g_loss, self.generator.vars,
																self.global_step, self.global_step_update)

		# model saver
		self.saver = tf.train.Saver(self.discriminator.store_vars 
									+ self.generator.store_vars
									+ [self.global_step])
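
get_global_step only appears here through the (step, step_update) pair it returns; as the comment above notes, the discriminator optimizer reuses the step without the update so both networks follow the same learning-rate schedule. A plausible sketch of such a helper (illustrative, not the repo's code):

import tensorflow as tf

def get_global_step_sketch(name='global_step'):
    # one shared, non-trainable counter plus an explicit increment op
    step = tf.Variable(0, trainable=False, dtype=tf.int64, name=name)
    step_update = tf.assign_add(step, 1, name=name + '_update')
    return step, step_update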
Example 4
File: vae.py Project: xclmj/VAE-GAN
    def build_model(self):

        self.x_real = tf.placeholder(tf.float32,
                                     shape=[
                                         None,
                                     ] + list(self.input_shape),
                                     name='x_input')

        self.config['encoder params']['name'] = 'encoder'
        self.config['decoder params']['name'] = 'decoder'
        self.encoder = self._build_encoder('encoder')
        self.decoder = self._build_decoder('decoder')

        # build encoder
        self.mean_z, self.log_var_z = self.encoder(self.x_real)

        # sample z from mean_z and log_var_z
        sample_z = self.draw_sample(self.mean_z, self.log_var_z)

        # build decoder
        self.x_decode = self.decoder(sample_z)

        # build test decoder
        self.z_test = tf.placeholder(tf.float32,
                                     shape=[None, self.z_dim],
                                     name='z_test')
        self.x_test = self.decoder(self.z_test)

        # loss function
        self.kl_loss = (get_loss('kl', self.config['kl loss'], {
            'mean': self.mean_z,
            'log_var': self.log_var_z
        }) * self.config.get('kl loss prod', 1.0))

        self.recon_loss = (
            get_loss('reconstruction', self.config['reconstruction loss'], {
                'x': self.x_real,
                'y': self.x_decode
            }) * self.config.get('reconstruction loss prod', 1.0))

        self.loss = self.kl_loss + self.recon_loss

        # optimizer configure
        self.train_op, self.learning_rate, self.global_step = get_optimizer_by_config(
            self.config['optimizer'], self.config['optimizer params'],
            self.loss, self.vars)

        # model saver
        self.saver = tf.train.Saver(self.encoder.store_vars +
                                    self.decoder.store_vars + [
                                        self.global_step,
                                    ])
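
draw_sample(mean_z, log_var_z) is not shown in this snippet. The standard reparameterization trick, which keeps sampling differentiable with respect to the encoder outputs, would look roughly like this (the repo's version may differ in details):

import tensorflow as tf

def draw_sample_sketch(mean, log_var):
    # z = mean + sigma * eps, with eps ~ N(0, I); gradients flow through mean and log_var
    eps = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0)
    return mean + tf.exp(0.5 * log_var) * eps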
Example 5
	def build_model_m1(self):

		self.xu = tf.placeholder(tf.float32, shape=[None,] + self.input_shape, name='xu_input')

		###########################################################################
		# network define
		# 
		# x_encoder : x -> hx
		self.config['x encoder params']['name'] = 'EncoderHX_X'
		self.config['x encoder params']["output dims"] = self.hx_dim
		self.x_encoder = get_encoder(self.config['x encoder'], 
									self.config['x encoder params'], self.is_training)
		# decoder : hx -> x
		self.config['hx decoder params']['name'] = 'DecoderX_HX'
		# if self.config
		# self.config['hx decoder params']["output dims"] = int(np.product(self.input_shape))
		self.hx_decoder = get_decoder(self.config['hx decoder'], self.config['hx decoder params'], self.is_training)

		###########################################################################
		# for unsupervised training:
		# 
		# xu --> mean_hxu, log_var_hxu ==> kl loss
		#					|
		# 			   sample_hxu --> xu_decode ==> reconstruction loss
		mean_hxu, log_var_hxu = self.x_encoder(self.xu)
		sample_hxu = self.draw_sample(mean_hxu, log_var_hxu)
		xu_decode = self.hx_decoder(sample_hxu)

		self.m1_loss_kl_z = (get_loss('kl', 'gaussian', {'mean' : mean_hxu, 'log_var' : log_var_hxu})
								* self.m1_loss_weights.get('kl z loss weight', 1.0))
		self.m1_loss_recon = (get_loss('reconstruction', 'mse', {'x' : self.xu, 'y' : xu_decode})
								* self.m1_loss_weights.get('reconstruction loss weight', 1.0))
		self.m1_loss = self.m1_loss_kl_z + self.m1_loss_recon


		###########################################################################
		# optimizer configure
		self.m1_global_step, m1_global_step_update = get_global_step('m1_step')

		(self.m1_train_op, 
			self.m1_learning_rate, 
				_) = get_optimizer_by_config(self.config['m1 optimizer'], self.config['m1 optimizer params'], 
													self.m1_loss, self.m1_vars, self.m1_global_step, m1_global_step_update)
Example 6
    def build_model(self):

        self.xu = tf.placeholder(tf.float32,
                                 shape=[
                                     None,
                                 ] + self.input_shape,
                                 name='xu_input')
        self.xl = tf.placeholder(tf.float32,
                                 shape=[
                                     None,
                                 ] + self.input_shape,
                                 name='xl_input')
        self.yl = tf.placeholder(tf.float32,
                                 shape=[None, self.nb_classes],
                                 name='yl_input')

        ###########################################################################
        # network define
        #
        # x_encoder : x -> hx
        self.config['x encoder params']['name'] = 'EncoderHX_X'
        self.config['x encoder params']["output dims"] = self.hx_dim
        self.x_encoder = get_encoder(self.config['x encoder'],
                                     self.config['x encoder params'],
                                     self.is_training)
        # hx_y_encoder : [hx, y] -> hz
        self.config['hx y encoder params']['name'] = 'EncoderHZ_HXY'
        self.config['hx y encoder params']["output dims"] = self.hz_dim
        self.hx_y_encoder = get_encoder(self.config['hx y encoder'],
                                        self.config['hx y encoder params'],
                                        self.is_training)
        # hz_y_decoder : [hz, y] -> x_decode
        self.config['hz y decoder params']['name'] = 'DecoderX_HZY'
        self.config['hz y decoder params']["output dims"] = int(
            np.product(self.input_shape))
        self.hz_y_decoder = get_decoder(self.config['hz y decoder'],
                                        self.config['hz y decoder params'],
                                        self.is_training)
        # x_classifier : hx -> ylogits
        self.config['x classifier params']['name'] = 'ClassifierX'
        self.config['x classifier params']["output dims"] = self.nb_classes
        self.x_classifier = get_classifier(self.config['x classifier'],
                                           self.config['x classifier params'],
                                           self.is_training)

        ###########################################################################
        # for supervised training:
        #
        # xl --> mean_hxl, log_var_hxl
        #		  		  |
        #			 sample_hxl --> yllogits ==> classification loss
        #				  |
        # 			[sample_hxl, yl] --> mean_hzl, log_var_hzl ==> kl loss
        #				          |               |
        # 	  			        [yl,	   	   sample_hzl] --> xl_decode ==> reconstruction loss
        #

        hxl = self.x_encoder(self.xl)
        mean_hzl, log_var_hzl = self.hx_y_encoder(
            tf.concat([hxl, self.yl], axis=1))
        sample_hzl = self.draw_sample(mean_hzl, log_var_hzl)
        decode_xl = self.hz_y_decoder(tf.concat([sample_hzl, self.yl], axis=1))
        # decode_xl = self.hx_decoder(decode_hxl)

        yllogits = self.x_classifier(self.xl)

        self.su_loss_kl_z = (get_loss('kl', 'gaussian', {
            'mean': mean_hzl,
            'log_var': log_var_hzl,
        }) * self.loss_weights.get('kl z loss weight', 1.0))
        self.su_loss_recon = (get_loss('reconstruction', 'mse', {
            'x': self.xl,
            'y': decode_xl
        }) * self.loss_weights.get('reconstruction loss weight', 1.0))
        self.su_loss_cls = (get_loss('classification', 'cross entropy', {
            'logits': yllogits,
            'labels': self.yl
        }) * self.loss_weights.get('classification loss weight', 1.0))

        self.su_loss_reg = (
            get_loss('regularization', 'l2',
                     {'var_list': self.x_classifier.vars}) *
            self.loss_weights.get('regularization loss weight', 0.0001))

        self.su_loss = ((self.su_loss_kl_z + self.su_loss_recon +
                         self.su_loss_cls + self.su_loss_reg) *
                        self.loss_weights.get('supervised loss weight', 1.0))

        ###########################################################################
        # for unsupervised training:
        #
        # xu --> mean_hxu, log_var_hxu
        #                |
        #             sample_hxu --> yulogits --> yuprobs
        # 				  |
        #   		 [sample_hxu,    y0] --> mean_hzu0, log_var_hzu0 ==> kl_loss * yuprobs[0]
        # 				  |			  |					|
        #				  |			[y0,           sample_hzu0] --> decode_hxu0 ==> reconstruction loss * yuprobs[0]
        #				  |
        #   	     [sample_hxu,    y1] --> mean_hzu1, log_var_hzu1 ==> kl_loss * yuprobs[1]
        #				  |			  |			        |
        #				  |			[y1,           sample_hzu1] --> decode_hxu1 ==> reconstruction loss * yuprobs[1]
        #		.......
        #
        hxu = self.x_encoder(self.xu)
        yulogits = self.x_classifier(self.xu)
        yuprobs = tf.nn.softmax(yulogits)

        unsu_loss_kl_z_list = []
        unsu_loss_recon_list = []

        for i in range(self.nb_classes):
            yu_fake = tf.ones([
                tf.shape(self.xu)[0],
            ], dtype=tf.int32) * i
            yu_fake = tf.one_hot(yu_fake, depth=self.nb_classes)

            mean_hzu, log_var_hzu = self.hx_y_encoder(
                tf.concat([hxu, yu_fake], axis=1))
            sample_hzu = self.draw_sample(mean_hzu, log_var_hzu)
            decode_xu = self.hz_y_decoder(
                tf.concat([sample_hzu, yu_fake], axis=1))
            # decode_xu = self.hx_decoder(decode_hxu)

            unsu_loss_kl_z_list.append(
                get_loss(
                    'kl', 'gaussian', {
                        'mean': mean_hzu,
                        'log_var': log_var_hzu,
                        'instance_weight': yuprobs[:, i]
                    }))

            unsu_loss_recon_list.append(
                get_loss('reconstruction', 'mse', {
                    'x': self.xu,
                    'y': decode_xu,
                    'instance_weight': yuprobs[:, i]
                }))

        self.unsu_loss_kl_y = (
            get_loss('kl', 'bernoulli', {'probs': yuprobs}) *
            self.loss_weights.get('kl y loss weight', 1.0))
        self.unsu_loss_kl_z = (tf.reduce_sum(unsu_loss_kl_z_list) *
                               self.loss_weights.get('kl z loss weight', 1.0))
        self.unsu_loss_recon = (
            tf.reduce_sum(unsu_loss_recon_list) *
            self.loss_weights.get('reconstruction loss weight', 1.0))

        self.unsu_loss_reg = (
            get_loss('regularization', 'l2',
                     {'var_list': self.x_classifier.vars}) *
            self.loss_weights.get('regularization loss weight', 0.0001))

        self.unsu_loss = (
            (self.unsu_loss_kl_z + self.unsu_loss_recon + self.unsu_loss_kl_y +
             self.unsu_loss_reg) *
            self.loss_weights.get('unsupervised loss weight', 1.0))

        self.xt = tf.placeholder(tf.float32,
                                 shape=[
                                     None,
                                 ] + self.input_shape,
                                 name='xt_input')

        ###########################################################################
        # for test models
        #
        # xt --> mean_hxt, log_var_hxt
        #               |
        #             sample_hxt --> ytlogits --> ytprobs
        # 			   |			    			 |
        #		     [sample_hxt,    			  ytprobs] --> mean_hzt, log_var_hzt
        #
        hxt = self.x_encoder(self.xt)
        ytlogits = self.x_classifier(self.xt)
        self.ytprobs = tf.nn.softmax(ytlogits)
        self.mean_hzt, self.log_var_hzt = self.hx_y_encoder(
            tf.concat([hxt, self.ytprobs], axis=1))

        ###########################################################################
        # optimizer configure

        self.global_step, self.global_step_update = get_global_step()

        (self.supervised_train_op, self.supervised_learning_rate,
         _) = get_optimizer_by_config(self.config['optimizer'],
                                      self.config['optimizer params'],
                                      self.su_loss, self.vars, self.global_step,
                                      self.global_step_update)
        (self.unsupervised_train_op, self.unsupervised_learning_rate,
         _) = get_optimizer_by_config(self.config['optimizer'],
                                      self.config['optimizer params'],
                                      self.unsu_loss, self.vars, self.global_step,
                                      self.global_step_update)

        ###########################################################################
        # model saver
        self.saver = tf.train.Saver(self.vars + [
            self.global_step,
        ])
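
The unsupervised branch above marginalizes over all candidate labels: for each class i it evaluates the KL and reconstruction terms and weights them per example by the classifier probability yuprobs[:, i]. A hedged sketch of such an instance-weighted MSE; the handling of 'instance_weight' inside get_loss is assumed, not taken from the repo:

import tensorflow as tf

def weighted_mse_sketch(x, y, instance_weight):
    # per-example mean squared error over all non-batch axes, then weighted by q(y_i | x)
    axes = list(range(1, x.shape.ndims))
    per_example = tf.reduce_mean(tf.square(x - y), axis=axes)
    return tf.reduce_mean(instance_weight * per_example)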
Example 7
    def build_model(self):
        # network config
        self.config['discriminator params']['name'] = 'Discriminator'
        self.config['generator params']['name'] = 'Generator'
        self.discriminator = self._build_discriminator('discriminator')
        self.generator = self._build_generator('generator')

        # build model
        self.x_real = tf.placeholder(tf.float32,
                                     shape=[
                                         None,
                                     ] + list(self.input_shape),
                                     name='x_input')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

        self.x_fake = self.generator(self.z)
        self.dis_real, self.dis_real_end_points = self.discriminator.features(
            self.x_real)
        self.dis_fake, self.dis_fake_end_points = self.discriminator.features(
            self.x_fake)

        # loss config
        self.d_loss_adv = get_loss('adversarial down', 'cross entropy', {
            'dis_real': self.dis_real,
            'dis_fake': self.dis_fake
        })
        if self.use_feature_matching_loss:
            if self.feature_matching_end_points is None:
                self.feature_matching_end_points = [
                    k for k in self.dis_real_end_points.keys() if 'conv' in k
                ]
                print('feature matching end points : ',
                      self.feature_matching_end_points)
            self.d_loss_fm = get_loss(
                'feature matching', 'l2', {
                    'fx': self.dis_real_end_points,
                    'fy': self.dis_fake_end_points,
                    'fnames': self.feature_matching_end_points
                })
            self.d_loss_fm *= self.config.get('feature matching loss weight',
                                              0.01)
            self.d_loss = self.d_loss_adv + self.d_loss_fm
        else:
            self.d_loss = self.d_loss_adv

        self.g_loss = get_loss('adversarial up', 'cross entropy',
                               {'dis_fake': self.dis_fake})

        # optimizer config
        self.global_step, self.global_step_update = get_global_step()

        # optimizer of discriminator configured without global step update
        # so we can keep the learning rate of discriminator the same as generator
        (self.d_train_op, self.d_learning_rate,
         self.d_global_step) = get_optimizer_by_config(
             self.config['discriminator optimizer'],
             self.config['discriminator optimizer params'], self.d_loss,
             self.discriminator.vars, self.global_step)
        (self.g_train_op, self.g_learning_rate,
         self.g_global_step) = get_optimizer_by_config(
             self.config['generator optimizer'],
             self.config['generator optimizer params'], self.g_loss,
             self.generator.vars, self.global_step, self.global_step_update)

        # model saver
        self.saver = tf.train.Saver(self.discriminator.store_vars +
                                    self.generator.store_vars +
                                    [self.global_step])
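
The feature-matching term compares intermediate discriminator activations on real and fake batches. The usual formulation is an L2 distance between batch-mean activations at the selected end points; a hedged sketch (get_loss('feature matching', 'l2', ...) may be implemented differently):

import tensorflow as tf

def feature_matching_sketch(fx, fy, fnames):
    # fx, fy: dicts of end-point tensors for real and fake inputs; fnames: layers to match
    terms = []
    for name in fnames:
        real_mean = tf.reduce_mean(fx[name], axis=0)
        fake_mean = tf.reduce_mean(fy[name], axis=0)
        terms.append(tf.reduce_mean(tf.square(real_mean - fake_mean)))
    return tf.add_n(terms)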
Example 8
    def build_model(self):
        # network config
        self.config['discriminator params']['name'] = 'Discriminator'
        self.discriminator = get_discriminator(
            self.config['discriminator'], self.config['discriminator params'],
            self.is_training)

        self.config['generator params']['name'] = 'Generator'
        self.generator = get_generator(self.config['generator'],
                                       self.config['generator params'],
                                       self.is_training)

        self.config['classifier params']['name'] = 'Classifier'
        self.classifier = get_classifier(self.config['classifier'],
                                         self.config['classifier params'],
                                         self.is_training)

        # build model
        self.x_real = tf.placeholder(tf.float32,
                                     shape=[
                                         None,
                                     ] + list(self.input_shape),
                                     name='x_input')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')
        self.c = tf.placeholder(tf.float32, shape=[None, self.c_dim], name='c')

        self.x_fake = self.generator(self.z)

        dis_real = self.discriminator(self.x_real)
        dis_fake = self.discriminator(self.x_fake)

        # loss config
        epsilon = tf.random_uniform(shape=[tf.shape(self.x_real)[0], 1, 1, 1],
                                    minval=0.0,
                                    maxval=1.0)
        x_hat = (epsilon * self.x_real) + ((1 - epsilon) * self.x_fake)
        dis_hat = self.discriminator(x_hat)

        self.d_loss_adv = (get_loss('adversarial down', 'wassterstein', {
            'dis_real': dis_real,
            'dis_fake': dis_fake
        }) * self.config.get('adversarial loss weight', 1.0))

        self.d_loss_gp = (get_loss('gradient penalty', 'l2', {
            'x': x_hat,
            'y': dis_hat
        }) * self.config.get('gradient penalty loss weight', 10.0))

        self.d_loss = self.d_loss_gp + self.d_loss_adv
        self.g_loss = get_loss('adversarial up', 'wassterstein',
                               {'dis_fake': dis_fake})

        # optimizer config
        self.global_step, self.global_step_update = get_global_step()

        # optimizer of discriminator
        # configured with global step and without global step update
        # so we can keep the learning rate of discriminator the same as generator
        (self.d_train_op, self.d_learning_rate,
         self.d_global_step) = get_optimizer_by_config(
             self.config['discriminator optimizer'],
             self.config['discriminator optimizer params'], self.d_loss,
             self.discriminator.vars, self.global_step)

        (self.g_train_op, self.g_learning_rate,
         self.g_global_step) = get_optimizer_by_config(
             self.config['generator optimizer'],
             self.config['generator optimizer params'], self.g_loss,
             self.generator.vars, self.global_step, self.global_step_update)

        # model saver
        self.saver = tf.train.Saver(self.discriminator.store_vars +
                                    self.generator.store_vars +
                                    [self.global_step])
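
The x_hat interpolation above is the WGAN-GP construction: the penalty pushes the critic's gradient norm at points between real and fake samples toward 1. A hedged sketch of what get_loss('gradient penalty', 'l2', {'x': x_hat, 'y': dis_hat}) presumably computes:

import tensorflow as tf

def gradient_penalty_sketch(x_hat, dis_hat):
    # gradient of the critic output with respect to the interpolated input
    grads = tf.gradients(dis_hat, [x_hat])[0]
    axes = list(range(1, x_hat.shape.ndims))
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=axes) + 1e-12)
    return tf.reduce_mean(tf.square(norm - 1.0))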
Example 9
    def build_model(self):
        # network config
        self.config['discriminator params']['name'] = 'Discriminator'
        self.config['discriminator params'][
            "output dims"] = self.nb_classes + 1
        self.config['generator params']['name'] = 'Generator'
        self.discriminator = self._build_discriminator('discriminator')
        self.generator = self._build_generator('generator')

        # build model
        self.x_real = tf.placeholder(tf.float32,
                                     shape=[
                                         None,
                                     ] + list(self.input_shape),
                                     name='x_real')
        self.label_real = tf.placeholder(tf.float32,
                                         shape=[None, self.nb_classes],
                                         name='label_real')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

        self.x_fake = self.generator(self.z)
        self.dis_real, self.dis_real_end_points = self.discriminator.features(
            self.x_real)
        self.dis_fake, self.dis_fake_end_points = self.discriminator.features(
            self.x_fake)

        self.prob_real = tf.nn.softmax(self.dis_real)

        # self.d_loss_feature_matching = get_loss('feature matching', 'l2',
        # 										{'fx': self.dis_real_end_points, 'fy': self.dis_fake_end_points, 'fnames' : self.feature_matching_end_points})
        # self.d_loss_feature_matching *= self.config.get('feature matching loss weight', 0.01)

        # supervised loss config
        self.d_su_loss_adv = get_loss(
            'adversarial down', 'supervised cross entropy', {
                'dis_real': self.dis_real,
                'dis_fake': self.dis_fake,
                'label': self.label_real
            })
        self.d_su_loss_adv *= self.config.get('adversarial loss weight', 1.0)

        # self.d_su_loss = self.d_su_loss_adv + self.d_loss_feature_matching
        self.d_su_loss = self.d_su_loss_adv
        # self.g_su_loss = get_loss('adversarial up', 'supervised cross entropy', {'dis_fake' : self.dis_fake, 'label': self.label_real})

        # unsupervised loss config
        self.d_unsu_loss_adv = get_loss('adversarial down',
                                        'unsupervised cross entropy', {
                                            'dis_real': self.dis_real,
                                            'dis_fake': self.dis_fake
                                        })
        self.d_unsu_loss_adv *= self.config.get('adversarial loss weight', 1.0)
        # self.d_unsu_loss = self.d_unsu_loss_adv + self.d_loss_feature_matching
        self.d_unsu_loss = self.d_unsu_loss_adv

        self.g_unsu_loss = get_loss('adversarial up',
                                    'unsupervised cross entropy',
                                    {'dis_fake': self.dis_fake})

        # optimizer config
        self.global_step, self.global_step_update = get_global_step()

        # optimizer of discriminator configured without global step update
        # so we can keep the learning rate of discriminator the same as generator
        (self.d_su_train_op, self.d_su_learning_rate,
         self.d_su_global_step) = get_optimizer_by_config(
             self.config['discriminator optimizer'],
             self.config['discriminator optimizer params'], self.d_su_loss,
             self.discriminator.vars, self.global_step,
             self.global_step_update)
        # (self.g_su_train_op,
        # 	self.g_su_learning_rate,
        # 		self.g_su_global_step) = get_optimizer_by_config(self.config['generator optimizer'],
        # 														self.config['generator optimizer params'],
        # 														self.g_su_loss, self.generator.vars,
        # 														self.global_step, self.global_step_update)

        (self.d_unsu_train_op, self.d_unsu_learning_rate,
         self.d_unsu_global_step) = get_optimizer_by_config(
             self.config['discriminator optimizer'],
             self.config['discriminator optimizer params'], self.d_unsu_loss,
             self.discriminator.vars, self.global_step)
        (self.g_unsu_train_op, self.g_unsu_learning_rate,
         self.g_unsu_global_step) = get_optimizer_by_config(
             self.config['generator optimizer'],
             self.config['generator optimizer params'], self.g_unsu_loss,
             self.generator.vars, self.global_step, self.global_step_update)

        # model saver
        self.saver = tf.train.Saver(self.discriminator.store_vars +
                                    self.generator.store_vars +
                                    [self.global_step])
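
The discriminator here has nb_classes + 1 outputs, i.e. the Salimans-style semi-supervised GAN in which the extra logit stands for "fake". One common way to write the supervised and unsupervised discriminator terms is sketched below; this is not necessarily how the repo's 'supervised cross entropy' and 'unsupervised cross entropy' losses are formulated:

import tensorflow as tf

def ssgan_d_losses_sketch(dis_real, dis_fake, label_real, nb_classes):
    # supervised: ordinary cross entropy on labeled real data over the K real-class logits
    su_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=label_real, logits=dis_real[:, :nb_classes]))
    # unsupervised: the softmax mass on the last (K+1-th) logit is read as p(fake | x)
    p_fake_real = tf.nn.softmax(dis_real)[:, -1]
    p_fake_fake = tf.nn.softmax(dis_fake)[:, -1]
    unsu_loss = (-tf.reduce_mean(tf.log(1.0 - p_fake_real + 1e-8))
                 - tf.reduce_mean(tf.log(p_fake_fake + 1e-8)))
    return su_loss, unsu_loss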
Example 10
    def build_model(self):
        # network config
        self.config['discriminator params']['name'] = 'Discriminator'
        self.config['generator params']['name'] = 'Generator'

        self.discriminator = self._build_discriminator('discriminator')
        self.generator = self._build_generator('generator')

        # build model
        self.x_real = tf.placeholder(tf.float32,
                                     shape=[
                                         None,
                                     ] + list(self.input_shape),
                                     name='x_input')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

        self.x_fake = self.generator(self.z)
        self.dis_real, self.dis_real_end_points = self.discriminator.features(
            self.x_real)
        self.dis_fake, self.dis_fake_end_points = self.discriminator.features(
            self.x_fake)

        # loss config
        x_dims = len(self.input_shape)
        if x_dims == 1:
            epsilon = tf.random_uniform(shape=[tf.shape(self.x_real)[0], 1],
                                        minval=0.0,
                                        maxval=1.0)
        elif x_dims == 3:
            epsilon = tf.random_uniform(
                shape=[tf.shape(self.x_real)[0], 1, 1, 1],
                minval=0.0,
                maxval=1.0)
        else:
            raise NotImplementedError
        x_hat = (epsilon * self.x_real) + ((1 - epsilon) * self.x_fake)
        dis_hat = self.discriminator(x_hat)

        self.d_loss_list = []
        self.d_loss_adv = (
            get_loss('adversarial down', 'wassterstein', {
                'dis_real': self.dis_real,
                'dis_fake': self.dis_fake
            }) * self.config.get('adversarial loss weight', 1.0))
        self.d_loss_list.append(self.d_loss_adv)

        if self.use_feature_matching_loss:
            if self.feature_matching_end_points is None:
                self.feature_matching_end_points = [
                    k for k in self.dis_real_end_points.keys() if 'conv' in k
                ]
                print('feature matching end points : ',
                      self.feature_matching_end_points)
            self.d_loss_fm = get_loss(
                'feature matching', 'l2', {
                    'fx': self.dis_real_end_points,
                    'fy': self.dis_fake_end_points,
                    'fnames': self.feature_matching_end_points
                })
            self.d_loss_fm *= self.config.get('feature matching loss weight',
                                              0.01)
            self.d_loss_list.append(self.d_loss_fm)

        if self.use_gradient_penalty:
            self.d_loss_gp = (get_loss('gradient penalty', 'l2', {
                'x': x_hat,
                'y': dis_hat
            }) * self.config.get('gradient penalty loss weight', 10.0))
            self.d_loss_list.append(self.d_loss_gp)

        self.d_loss = tf.reduce_sum(self.d_loss_list)
        self.g_loss = get_loss('adversarial up', 'wassterstein',
                               {'dis_fake': self.dis_fake})

        # optimizer config
        self.global_step, self.global_step_update = get_global_step()

        if not self.use_gradient_penalty:
            self.clip_discriminator = [
                tf.assign(
                    var,
                    tf.clip_by_value(var, self.weight_clip_bound[0],
                                     self.weight_clip_bound[1]))
                for var in self.discriminator.vars
            ]

        # optimizer of discriminator configured without global step update
        # so we can keep the learning rate of discriminator the same as generator
        (self.d_train_op, self.d_learning_rate,
         self.d_global_step) = get_optimizer_by_config(
             self.config['discriminator optimizer'],
             self.config['discriminator optimizer params'], self.d_loss,
             self.discriminator.vars, self.global_step)
        (self.g_train_op, self.g_learning_rate,
         self.g_global_step) = get_optimizer_by_config(
             self.config['generator optimizer'],
             self.config['generator optimizer params'], self.g_loss,
             self.generator.vars, self.global_step, self.global_step_update)

        # model saver
        self.saver = tf.train.Saver(self.discriminator.store_vars +
                                    self.generator.store_vars +
                                    [self.global_step])
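
When gradient penalty is disabled, the clip_discriminator ops above implement the original WGAN weight clipping and must be run explicitly after each discriminator update. A hedged usage sketch (the session, model, and batch names are illustrative):

def discriminator_step_sketch(sess, model, x_batch, z_batch):
    # one critic update; run the clip ops only when the gradient penalty is not used
    ops = [model.d_train_op]
    if not model.use_gradient_penalty:
        ops.append(model.clip_discriminator)
    sess.run(ops, feed_dict={model.x_real: x_batch, model.z: z_batch})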
Example 11
	def build_model(self):

		self.img_u = tf.placeholder(tf.float32, shape=[None,] + self.input_shape, name='image_unlabelled_input')
		self.img_l = tf.placeholder(tf.float32, shape=[None,] + self.input_shape, name='image_labelled_input')
		self.mask_l = tf.placeholder(tf.float32, shape=[None,] + self.mask_size + [self.nb_classes], name='mask_input')

		###########################################################################
		# network define
		# 
		self.config['classifier params']['name'] = 'Segmentation'
		self.config['classifier params']["output dims"] = self.hx_dim
		self.seg_classifier = get_classifier(self.config['classifier'], 
									self.config['classifier params'], self.is_training)


		self.config['discriminator params']['name'] = 'Discriminator'
		self.config['discriminator params']["output dims"] = 1
		self.config['discriminator params']['output_activation'] = 'none'
		self.discriminator = get_discriminator(self.config['discriminator'], 
									self.config['discriminator params'], self.is_training)

		###########################################################################
		# for supervised training:
		self.mask_generated = self.seg_classifier(self.img_l)

		real_concated = tf.concat([self.img_l, self.mask_l], axis=-1)
		fake_concated = tf.concat([self.img_l, self.mask_generated], axis=-1)

		dis_real_concated = self.discriminator(real_concated)
		dis_fake_concated = self.discriminator(fake_concated)

		epsilon = tf.random_uniform(shape=[tf.shape(self.img_l)[0], 1, 1, 1], minval=0.0, maxval=1.0)
		mask_hat = epsilon * self.mask_l + (1 - epsilon) * self.mask_generated
		concat_hat = tf.concat([self.img_l, mask_hat], axis=-1)

		dis_hat_concated = self.discriminator(concat_hat)


		self.d_su_loss_adv = (get_loss('adversarial down', 'wassterstein', {'dis_real' : dis_real_concated, 'dis_fake' : dis_fake_concated})
								* self.config.get('adversarial loss weight', 1.0))
		self.d_su_loss_gp = (get_loss('gradient penalty', 'l2', {'x' : concat_hat, 'y' : dis_hat_concated})
								* self.config.get('gradient penalty loss weight', 1.0))
		self.d_su_loss = self.d_su_loss_adv + self.d_su_loss_gp

		self.g_su_loss_adv = (get_loss('adversarial up', 'wassterstein', {'dis_fake' : dis_fake_concated})
								* self.config.get('adversarial loss weight', 1.0))

		self.g_su_loss_cls = (get_loss('segmentation', 'l2', {'predict' : self.mask_generated, 'mask':self.mask_l})
								* self.config.get('segmentation loss weight', 1.0))

		self.g_su_loss = self.g_su_loss_adv + self.g_su_loss_cls


		###########################################################################
		# optimizer configure
		(self.d_su_train_op,
			self.d_su_learning_rate,
				self.d_su_global_step) = get_optimizer_by_config(self.config['supervised optimizer'], self.config['supervised optimizer params'],
												self.d_su_loss, self.discriminator.vars, global_step_name='d_global_step_su')

		(self.g_su_train_op,
			self.g_su_learning_rate,
				self.g_su_global_step) = get_optimizer_by_config(self.config['supervised optimizer'], self.config['supervised optimizer params'],
												self.g_su_loss, self.seg_classifier.vars, global_step_name='g_global_step_su')

		###########################################################################
		# # for test models
		# # 
		# # xt --> mean_hxt, log_var_hxt
		# #               |
		# #             sample_hxt --> ytlogits --> ytprobs
		# # 			   |			    			 |
		# #		     [sample_hxt,    			  ytprobs] --> mean_hzt, log_var_hzt
		# #
		# mean_hxt, log_var_hxt = self.x_encoder(self.xt)
		# sample_hxt = self.draw_sample(mean_hxt, log_var_hxt)
		# ytlogits = self.hx_classifier(sample_hxt)
		# # test sample class probabilities
		# self.ytprobs = tf.nn.softmax(ytlogits)
		# # test sample hidden variable distribution
		# self.mean_hzt, self.log_var_hzt = self.hx_y_encoder(tf.concat([sample_hxt, self.ytprobs], axis=1))


		###########################################################################
		# model saver
		self.saver = tf.train.Saver(self.store_vars + [self.d_su_global_step, self.g_su_global_step])