コード例 #1
0
ファイル: aae.py プロジェクト: yanzhicong/VAE-GAN
	def build_model(self):
		"""Build the adversarial auto-encoder graph.

		Builds the discriminator, encoder and decoder networks, the input
		placeholders, the forward graph, the adversarial and reconstruction
		losses, three train functions and the model saver. When
		``self.has_label`` is true, the label placeholders are concatenated to
		the latent codes before they are scored by the discriminator.
		"""
		use_label = self.has_label

		# --- networks -------------------------------------------------------
		discriminator_params = {
			'name': 'Discriminator',
			"output dims": 1,
			'output_activation': 'none',
		}
		self.discriminator = self._build_discriminator('discriminator', params=discriminator_params)
		self.encoder = self._build_encoder('encoder', params={'name': 'Encoder', "output dims": self.z_dim})
		self.decoder = self._build_decoder('decoder', params={'name': 'Decoder'})

		# --- placeholders ---------------------------------------------------
		self.img = tf.placeholder(tf.float32, shape=[None] + list(self.input_shape), name='img')
		if use_label:
			self.label = tf.placeholder(tf.float32, shape=[None, self.nb_classes], name='label')
		self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')
		if use_label:
			self.z_label = tf.placeholder(tf.float32, shape=[None, self.nb_classes], name='z_label')

		# --- forward graph --------------------------------------------------
		# auto-encoding path: img -> z_sample -> img_recon
		self.z_sample = self.encoder(self.img)
		self.img_recon = self.decoder(self.z_sample)

		# the discriminator scores codes fed through self.z ("real") against
		# encoder outputs ("fake"), each optionally joined with its label
		real_code = tf.concat([self.z, self.z_label], axis=1) if use_label else self.z
		self.dis_real = self.discriminator(real_code)
		fake_code = tf.concat([self.z_sample, self.label], axis=1) if use_label else self.z_sample
		self.dis_fake = self.discriminator(fake_code)

		# generation path: decode latent codes fed directly through self.z
		self.img_generate = self.decoder(self.z)

		# --- losses ---------------------------------------------------------
		self.loss_adv_down = get_loss('adversarial down', 'cross entropy',
			{'dis_real': self.dis_real, 'dis_fake': self.dis_fake})
		self.loss_adv_up = get_loss('adversarial up', 'cross entropy', {'dis_fake': self.dis_fake})
		self.loss_recon = get_loss('reconstruction', 'l2', {'x': self.img, 'y': self.img_recon})

		# --- optimisation ---------------------------------------------------
		self.global_step, self.global_step_update = self._build_step_var('global_step')
		# All three train functions share one step variable.
		# NOTE(review): self.global_step_update is not handed to any train
		# function in this method - confirm the step is advanced elsewhere.
		train_kwargs = {'step': self.global_step, 'build_summary': self.has_summary}
		self.train_auto_encoder, _ = self._build_train_function(
			'auto-encoder', self.loss_recon, self.encoder.vars + self.decoder.vars, **train_kwargs)
		self.train_discriminator, _ = self._build_train_function(
			'discriminator', self.loss_adv_down, self.discriminator.vars, **train_kwargs)
		self.train_encoder, _ = self._build_train_function(
			'encoder', self.loss_adv_up, self.encoder.vars, **train_kwargs)

		# --- saver ----------------------------------------------------------
		stored_vars = self.discriminator.store_vars + self.encoder.store_vars + self.decoder.store_vars
		self.saver = tf.train.Saver(stored_vars + [self.global_step])
コード例 #2
0
ファイル: vae.py プロジェクト: xclmj/VAE-GAN
    def build_model(self):
        """Build the VAE graph.

        Builds the ``x_input`` placeholder, the encoder (x -> mean/log-variance
        of z), a sampled latent code, the decoder reconstruction, a separate
        generation path from the ``z_test`` placeholder, the weighted
        KL + reconstruction loss, the optimizer and the saver.
        """
        x_shape = [None] + list(self.input_shape)
        self.x_real = tf.placeholder(tf.float32, shape=x_shape, name='x_input')

        # network names are written into the config before the networks are built
        self.config['encoder params']['name'] = 'encoder'
        self.config['decoder params']['name'] = 'decoder'
        self.encoder = self._build_encoder('encoder')
        self.decoder = self._build_decoder('decoder')

        # x -> (mean_z, log_var_z) -> sampled z -> reconstruction
        self.mean_z, self.log_var_z = self.encoder(self.x_real)
        z_drawn = self.draw_sample(self.mean_z, self.log_var_z)
        self.x_decode = self.decoder(z_drawn)

        # generation path fed directly with latent codes
        self.z_test = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z_test')
        self.x_test = self.decoder(self.z_test)

        # losses; each term is scaled by an optional config multiplier (default 1.0)
        kl_inputs = {'mean': self.mean_z, 'log_var': self.log_var_z}
        self.kl_loss = get_loss('kl', self.config['kl loss'], kl_inputs) \
            * self.config.get('kl loss prod', 1.0)

        recon_inputs = {'x': self.x_real, 'y': self.x_decode}
        self.recon_loss = get_loss('reconstruction', self.config['reconstruction loss'], recon_inputs) \
            * self.config.get('reconstruction loss prod', 1.0)

        self.loss = self.kl_loss + self.recon_loss

        # optimizer: returns (train_op, learning_rate, global_step)
        optimizer_outputs = get_optimizer_by_config(
            self.config['optimizer'], self.config['optimizer params'],
            self.loss, self.vars)
        self.train_op, self.learning_rate, self.global_step = optimizer_outputs

        saved_vars = self.encoder.store_vars + self.decoder.store_vars + [self.global_step]
        self.saver = tf.train.Saver(saved_vars)
コード例 #3
0
	def build_model(self):
		"""Build the classifier graph.

		Builds the classifier network (``output dims`` = ``nb_classes``), the
		``x_input`` / ``label`` placeholders, the classification loss, a top-1
		accuracy summary, softmax probabilities for testing, the train
		function(s) and the saver. An extra finetune train function that only
		updates ``self.classifier.top_vars`` is built when
		``self.finetune_steps > 0``.
		"""
		self.classifier = self._build_classifier('classifier', params={
			'name' : 'classifier',
			'output dims' : self.nb_classes
		})

		# for training
		# list() so a tuple input_shape works too (list + tuple raises TypeError)
		self.x = tf.placeholder(tf.float32, shape=[None,] + list(self.input_shape), name='x_input')
		self.label = tf.placeholder(tf.float32, shape=[None, self.nb_classes], name='label')

		self.logits, self.end_points = self.classifier.features(self.x)

		self.loss = get_loss('classification', self.config['classification loss'], 
						{'logits' : self.logits, 'labels' : self.label})
		# scalar summary of the top-1 training accuracy (attached to the train functions below)
		self.train_acc = tf.summary.scalar('train acc', get_metric('accuracy', 'top1', 
						{'logits': self.logits, 'labels':self.label}))

		# for testing
		self.probs = tf.nn.softmax(self.logits)

		self.global_step, self.global_step_update = self._build_step_var('global_step')

		# optional finetune stage restricted to the classifier's top variables
		if self.finetune_steps > 0:
			self.finetune_classifier, _ = self._build_train_function('finetune', self.loss, self.classifier.top_vars, 
						step=self.global_step, step_update=self.global_step_update, 
						build_summary=self.has_summary, sum_list=[self.train_acc,])

		self.train_classifier, _ = self._build_train_function('optimizer', self.loss, self.classifier.vars, 
						step=self.global_step, step_update=self.global_step_update, 
						build_summary=self.has_summary, sum_list=[self.train_acc,])
		# model saver
		self.saver = tf.train.Saver(self.classifier.store_vars + [self.global_step,])
コード例 #4
0
ファイル: semi_dgm.py プロジェクト: yanzhicong/VAE-GAN
	def build_model_m1(self):
		"""Build the M1 part of the model: an unsupervised VAE on ``xu``.

		Builds the ``x_encoder`` (x -> mean / log-variance of hx) and
		``hx_decoder`` (hx -> x) networks, the ``xu_input`` placeholder, the
		weighted KL + reconstruction loss ``self.m1_loss``, and its train op
		(``self.m1_train_op``, ``self.m1_learning_rate``) driven by the
		``m1_step`` global step.
		"""
		# list() so a tuple input_shape works too (list + tuple raises TypeError)
		self.xu = tf.placeholder(tf.float32, shape=[None,] + list(self.input_shape), name='xu_input')

		###########################################################################
		# network define
		#
		# x_encoder : x -> (mean_hx, log_var_hx)
		self.config['x encoder params']['name'] = 'EncoderHX_X'
		self.config['x encoder params']["output dims"] = self.hx_dim
		self.x_encoder = get_encoder(self.config['x encoder'], 
									self.config['x encoder params'], self.is_training)
		# hx_decoder : hx -> x
		# NOTE: "output dims" is not set here, so the decoder uses whatever its
		# config provides.
		self.config['hx decoder params']['name'] = 'DecoderX_HX'
		self.hx_decoder = get_decoder(self.config['hx decoder'], self.config['hx decoder params'], self.is_training)

		###########################################################################
		# unsupervised training:
		#
		# xu --> mean_hxu, log_var_hxu ==> kl loss
		#            |
		#        sample_hxu --> xu_decode ==> reconstruction loss
		mean_hxu, log_var_hxu = self.x_encoder(self.xu)
		sample_hxu = self.draw_sample(mean_hxu, log_var_hxu)
		xu_decode = self.hx_decoder(sample_hxu)

		# each term is scaled by an optional weight (default 1.0)
		self.m1_loss_kl_z = (get_loss('kl', 'gaussian', {'mean' : mean_hxu, 'log_var' : log_var_hxu})
								* self.m1_loss_weights.get('kl z loss weight', 1.0))
		self.m1_loss_recon = (get_loss('reconstruction', 'mse', {'x' : self.xu, 'y' : xu_decode})
								* self.m1_loss_weights.get('reconstruction loss weight', 1.0))
		self.m1_loss = self.m1_loss_kl_z + self.m1_loss_recon

		###########################################################################
		# optimizer configure
		self.m1_global_step, m1_global_step_update = get_global_step('m1_step')

		(self.m1_train_op, 
			self.m1_learning_rate, 
				_) = get_optimizer_by_config(self.config['m1 optimizer'], self.config['m1 optimizer params'], 
													self.m1_loss, self.m1_vars, self.m1_global_step, m1_global_step_update)
コード例 #5
0
ファイル: segmentation.py プロジェクト: yanzhicong/VAE-GAN
    def build_model(self):
        """Build the segmentation graph.

        Builds the classifier network (``output dims`` = ``nb_classes``), the
        training placeholders ``x_input`` / ``mask``, the segmentation loss, a
        training mIoU metric, a test path on ``test_x_input`` whose leading
        three dimensions are unspecified, the optimizer and the saver.
        """
        self.config['classifier params']['name'] = 'classifier'
        self.config['classifier params']["output dims"] = self.nb_classes

        self.classifier = get_classifier(self.config['classifier'],
                                         self.config['classifier params'],
                                         self.is_training)

        # for training
        # list() so input_shape / mask_shape may be tuples as well as lists
        # (list + tuple raises TypeError)
        self.x = tf.placeholder(tf.float32,
                                shape=[None] + list(self.input_shape),
                                name='x_input')
        self.mask = tf.placeholder(tf.float32,
                                   shape=[None] + list(self.mask_shape),
                                   name='mask')

        self.logits, self.end_points = self.classifier.features(self.x)

        self.loss = get_loss('segmentation', self.config['segmentation loss'],
                             {
                                 'logits': self.logits,
                                 'mask': self.mask
                             })

        self.train_miou = get_metric(
            'segmentation', 'miou', {
                'logits': self.logits,
                'mask': self.mask,
                'nb_classes': self.nb_classes
            })

        # for testing: only the last dimension is fixed, so inputs of other
        # sizes can be fed (presumably NHWC images - confirm)
        self.test_x = tf.placeholder(
            tf.float32,
            shape=[None, None, None, self.input_shape[-1]],
            name='test_x_input')
        self.test_logits = self.classifier(self.test_x)
        self.test_y = tf.nn.softmax(self.test_logits)

        (self.train_op, self.learning_rate,
         self.global_step) = get_optimizer_by_config(
             self.config['optimizer'],
             self.config['optimizer params'],
             target=self.loss,
             variables=self.classifier.vars)

        # model saver
        self.saver = tf.train.Saver(self.classifier.store_vars + [
            self.global_step,
        ])
コード例 #6
0
ファイル: began.py プロジェクト: yanzhicong/VAE-GAN
	def build_model(self):
		"""Build the GAN graph.

		Builds the generator and discriminator, the ``x_input`` / ``z``
		placeholders, cross-entropy adversarial losses for both players, one
		optimizer per player and the model saver.
		"""
		cfg = self.config
		cfg['discriminator params']['name'] = 'Discriminator'
		cfg['generator params']['name'] = 'Generator'
		self.discriminator = self._build_discriminator('discriminator')
		self.generator = self._build_generator('generator')

		# inputs
		image_shape = [None] + list(self.input_shape)
		self.x_real = tf.placeholder(tf.float32, shape=image_shape, name='x_input')
		self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

		# forward graph: z -> x_fake; the discriminator scores real and fake images
		self.x_fake = self.generator(self.z)
		self.dis_real = self.discriminator(self.x_real)
		self.dis_fake = self.discriminator(self.x_fake)

		# losses
		d_inputs = {'dis_real' : self.dis_real, 'dis_fake' : self.dis_fake}
		self.d_loss = get_loss('adversarial down', 'cross entropy', d_inputs)
		self.g_loss = get_loss('adversarial up', 'cross entropy', {'dis_fake' : self.dis_fake})

		# optimizers share one global step
		self.global_step, self.global_step_update = get_global_step()

		# The discriminator optimizer receives the global step but not its update
		# op, so only the generator step advances it and both learning rates
		# follow the same schedule.
		d_outputs = get_optimizer_by_config(cfg['discriminator optimizer'],
											cfg['discriminator optimizer params'],
											self.d_loss, self.discriminator.vars,
											self.global_step)
		self.d_train_op, self.d_learning_rate, self.d_global_step = d_outputs

		g_outputs = get_optimizer_by_config(cfg['generator optimizer'],
											cfg['generator optimizer params'],
											self.g_loss, self.generator.vars,
											self.global_step, self.global_step_update)
		self.g_train_op, self.g_learning_rate, self.g_global_step = g_outputs

		# model saver
		saved_vars = self.discriminator.store_vars + self.generator.store_vars
		self.saver = tf.train.Saver(saved_vars + [self.global_step])
コード例 #7
0
    def build_model(self):
        """Build the semi-supervised generative model graph.

        Networks:
          * x_encoder    : x -> hx
          * hx_y_encoder : [hx, y] -> (mean_hz, log_var_hz)
          * hz_y_decoder : [hz, y] -> x_decode (output dims = prod(input_shape))
          * x_classifier : x -> class logits

        Builds the supervised loss on labelled data (``xl``, ``yl``), the
        unsupervised loss on unlabelled data (``xu``), which sums over every
        class with each term weighted by the classifier's predicted
        probability, and a test path on ``xt``. It also builds two train ops
        that share ``self.global_step`` and the model saver.
        """
        # list() so a tuple input_shape works too (list + tuple raises TypeError)
        x_shape = [None] + list(self.input_shape)
        self.xu = tf.placeholder(tf.float32, shape=x_shape, name='xu_input')
        self.xl = tf.placeholder(tf.float32, shape=x_shape, name='xl_input')
        self.yl = tf.placeholder(tf.float32,
                                 shape=[None, self.nb_classes],
                                 name='yl_input')

        ###########################################################################
        # network define
        #
        # x_encoder : x -> hx
        self.config['x encoder params']['name'] = 'EncoderHX_X'
        self.config['x encoder params']["output dims"] = self.hx_dim
        self.x_encoder = get_encoder(self.config['x encoder'],
                                     self.config['x encoder params'],
                                     self.is_training)
        # hx_y_encoder : [hx, y] -> (mean_hz, log_var_hz)
        self.config['hx y encoder params']['name'] = 'EncoderHZ_HXY'
        self.config['hx y encoder params']["output dims"] = self.hz_dim
        self.hx_y_encoder = get_encoder(self.config['hx y encoder'],
                                        self.config['hx y encoder params'],
                                        self.is_training)
        # hz_y_decoder : [hz, y] -> x_decode
        # np.prod: np.product is deprecated and removed in NumPy 2.0
        self.config['hz y decoder params']['name'] = 'DecoderX_HZY'
        self.config['hz y decoder params']["output dims"] = int(
            np.prod(self.input_shape))
        self.hz_y_decoder = get_decoder(self.config['hz y decoder'],
                                        self.config['hz y decoder params'],
                                        self.is_training)
        # x_classifier : x -> ylogits (applied to the raw input, not to hx)
        self.config['x classifier params']['name'] = 'ClassifierX'
        self.config['x classifier params']["output dims"] = self.nb_classes
        self.x_classifier = get_classifier(self.config['x classifier'],
                                           self.config['x classifier params'],
                                           self.is_training)

        ###########################################################################
        # supervised training (labelled xl, yl):
        #
        #   xl --x_encoder--> hxl
        #   [hxl, yl] --hx_y_encoder--> mean_hzl, log_var_hzl ==> kl loss
        #   sample_hzl = draw_sample(mean_hzl, log_var_hzl)
        #   [sample_hzl, yl] --hz_y_decoder--> decode_xl ==> reconstruction loss
        #   xl --x_classifier--> yllogits ==> classification loss
        #
        hxl = self.x_encoder(self.xl)
        mean_hzl, log_var_hzl = self.hx_y_encoder(
            tf.concat([hxl, self.yl], axis=1))
        sample_hzl = self.draw_sample(mean_hzl, log_var_hzl)
        decode_xl = self.hz_y_decoder(tf.concat([sample_hzl, self.yl], axis=1))

        yllogits = self.x_classifier(self.xl)

        self.su_loss_kl_z = (get_loss('kl', 'gaussian', {
            'mean': mean_hzl,
            'log_var': log_var_hzl,
        }) * self.loss_weights.get('kl z loss weight', 1.0))
        self.su_loss_recon = (get_loss('reconstruction', 'mse', {
            'x': self.xl,
            'y': decode_xl
        }) * self.loss_weights.get('reconstruction loss weight', 1.0))
        # Read the correctly spelled key first, falling back to the historical
        # misspelling so existing configs keep working.
        cls_weight = self.loss_weights.get(
            'classification loss weight',
            self.loss_weights.get('classiciation loss weight', 1.0))
        self.su_loss_cls = (get_loss('classification', 'cross entropy', {
            'logits': yllogits,
            'labels': self.yl
        }) * cls_weight)

        self.su_loss_reg = (
            get_loss('regularization', 'l2',
                     {'var_list': self.x_classifier.vars}) *
            self.loss_weights.get('regularization loss weight', 0.0001))

        self.su_loss = ((self.su_loss_kl_z + self.su_loss_recon +
                         self.su_loss_cls + self.su_loss_reg) *
                        self.loss_weights.get('supervised loss weight', 1.0))

        ###########################################################################
        # unsupervised training (unlabelled xu):
        #
        #   xu --x_encoder--> hxu
        #   xu --x_classifier--> yulogits --softmax--> yuprobs
        #   for every class i, with y_i = one_hot(i):
        #     [hxu, y_i] --hx_y_encoder--> mean_hzu, log_var_hzu
        #                 ==> kl loss weighted per instance by yuprobs[:, i]
        #     [sample_hzu, y_i] --hz_y_decoder--> decode_xu
        #                 ==> reconstruction loss weighted per instance by yuprobs[:, i]
        #   plus a bernoulli KL term on yuprobs
        #
        hxu = self.x_encoder(self.xu)
        yulogits = self.x_classifier(self.xu)
        yuprobs = tf.nn.softmax(yulogits)

        unsu_loss_kl_z_list = []
        unsu_loss_recon_list = []

        for i in range(self.nb_classes):
            # the same class label i for every sample in the batch
            yu_fake = tf.ones([
                tf.shape(self.xu)[0],
            ], dtype=tf.int32) * i
            yu_fake = tf.one_hot(yu_fake, depth=self.nb_classes)

            mean_hzu, log_var_hzu = self.hx_y_encoder(
                tf.concat([hxu, yu_fake], axis=1))
            sample_hzu = self.draw_sample(mean_hzu, log_var_hzu)
            decode_xu = self.hz_y_decoder(
                tf.concat([sample_hzu, yu_fake], axis=1))

            unsu_loss_kl_z_list.append(
                get_loss(
                    'kl', 'gaussian', {
                        'mean': mean_hzu,
                        'log_var': log_var_hzu,
                        'instance_weight': yuprobs[:, i]
                    }))

            unsu_loss_recon_list.append(
                get_loss('reconstruction', 'mse', {
                    'x': self.xu,
                    'y': decode_xu,
                    'instance_weight': yuprobs[:, i]
                }))

        self.unsu_loss_kl_y = (
            get_loss('kl', 'bernoulli', {'probs': yuprobs}) *
            self.loss_weights.get('kl y loss weight', 1.0))
        self.unsu_loss_kl_z = (tf.reduce_sum(unsu_loss_kl_z_list) *
                               self.loss_weights.get('kl z loss weight', 1.0))
        self.unsu_loss_recon = (
            tf.reduce_sum(unsu_loss_recon_list) *
            self.loss_weights.get('reconstruction loss weight', 1.0))

        self.unsu_loss_reg = (
            get_loss('regularization', 'l2',
                     {'var_list': self.x_classifier.vars}) *
            self.loss_weights.get('regularization loss weight', 0.0001))

        self.unsu_loss = (
            (self.unsu_loss_kl_z + self.unsu_loss_recon + self.unsu_loss_kl_y +
             self.unsu_loss_reg) *
            self.loss_weights.get('unsupervised loss weight', 1.0))

        self.xt = tf.placeholder(tf.float32, shape=x_shape, name='xt_input')

        ###########################################################################
        # test path:
        #
        #   xt --x_encoder--> hxt
        #   xt --x_classifier--> ytlogits --softmax--> ytprobs
        #   [hxt, ytprobs] --hx_y_encoder--> mean_hzt, log_var_hzt
        #
        hxt = self.x_encoder(self.xt)
        ytlogits = self.x_classifier(self.xt)
        self.ytprobs = tf.nn.softmax(ytlogits)
        self.mean_hzt, self.log_var_hzt = self.hx_y_encoder(
            tf.concat([hxt, self.ytprobs], axis=1))

        ###########################################################################
        # optimizer configure
        #
        # Stored on self because the saver below needs self.global_step.
        self.global_step, self.global_step_update = get_global_step()

        (self.supervised_train_op, self.supervised_learning_rate,
         _) = get_optimizer_by_config(self.config['optimizer'],
                                      self.config['optimizer params'],
                                      self.su_loss, self.vars,
                                      self.global_step,
                                      self.global_step_update)
        (self.unsupervised_train_op, self.unsupervised_learning_rate,
         _) = get_optimizer_by_config(self.config['optimizer'],
                                      self.config['optimizer params'],
                                      self.unsu_loss, self.vars,
                                      self.global_step,
                                      self.global_step_update)

        ###########################################################################
        # model saver
        self.saver = tf.train.Saver(self.vars + [
            self.global_step,
        ])
コード例 #8
0
    def build_model(self):
        """Build the GAN graph with an optional feature-matching term.

        Builds the generator and discriminator, the ``x_input`` / ``z``
        placeholders, cross-entropy adversarial losses, and, when
        ``self.use_feature_matching_loss`` is set, an L2 feature-matching loss
        between discriminator end points on real and fake inputs. That term is
        added to the discriminator loss. Also builds one optimizer per player
        and the saver.
        """
        self.config['discriminator params']['name'] = 'Discriminator'
        self.config['generator params']['name'] = 'Generator'
        self.discriminator = self._build_discriminator('discriminator')
        self.generator = self._build_generator('generator')

        # inputs
        image_shape = [None] + list(self.input_shape)
        self.x_real = tf.placeholder(tf.float32, shape=image_shape, name='x_input')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

        # forward graph; features() also returns intermediate end points
        self.x_fake = self.generator(self.z)
        real_outputs = self.discriminator.features(self.x_real)
        self.dis_real, self.dis_real_end_points = real_outputs
        fake_outputs = self.discriminator.features(self.x_fake)
        self.dis_fake, self.dis_fake_end_points = fake_outputs

        # discriminator loss
        adv_inputs = {'dis_real': self.dis_real, 'dis_fake': self.dis_fake}
        self.d_loss_adv = get_loss('adversarial down', 'cross entropy', adv_inputs)
        self.d_loss = self.d_loss_adv

        if self.use_feature_matching_loss:
            # default: match every end point whose name contains 'conv'
            if self.feature_matching_end_points is None:
                self.feature_matching_end_points = [
                    name for name in self.dis_real_end_points if 'conv' in name
                ]
                print('feature matching end points : ',
                      self.feature_matching_end_points)
            fm_inputs = {
                'fx': self.dis_real_end_points,
                'fy': self.dis_fake_end_points,
                'fnames': self.feature_matching_end_points,
            }
            # NOTE(review): the feature-matching term is added to the
            # discriminator loss here, not the generator loss - confirm intended.
            self.d_loss_fm = get_loss('feature matching', 'l2', fm_inputs) \
                * self.config.get('feature matching loss weight', 0.01)
            self.d_loss = self.d_loss_adv + self.d_loss_fm

        # generator loss
        self.g_loss = get_loss('adversarial up', 'cross entropy',
                               {'dis_fake': self.dis_fake})

        # optimizers share one global step
        self.global_step, self.global_step_update = get_global_step()

        # The discriminator optimizer gets the global step but not its update
        # op, so only the generator step advances it and both learning rates
        # follow the same schedule.
        d_outputs = get_optimizer_by_config(
            self.config['discriminator optimizer'],
            self.config['discriminator optimizer params'],
            self.d_loss, self.discriminator.vars, self.global_step)
        self.d_train_op, self.d_learning_rate, self.d_global_step = d_outputs

        g_outputs = get_optimizer_by_config(
            self.config['generator optimizer'],
            self.config['generator optimizer params'],
            self.g_loss, self.generator.vars,
            self.global_step, self.global_step_update)
        self.g_train_op, self.g_learning_rate, self.g_global_step = g_outputs

        # model saver
        saved_vars = self.discriminator.store_vars + self.generator.store_vars
        self.saver = tf.train.Saver(saved_vars + [self.global_step])
コード例 #9
0
    def build_model(self):
        """Build the conditional VAE graph.

        Inputs are flat: ``x_input`` is ``[None, prod(input_shape)]`` and
        ``y_input`` is ``[None, nb_classes]``.

        The x encoder maps ``concat([x, y])`` to ``(z_mean, z_log_var)``. The
        y encoder maps ``y`` to ``z_mean_y``, which is subtracted from
        ``z_mean`` inside the KL term. The decoder reconstructs x from a
        sampled z.

        Also builds a generation path from the ``z_test`` placeholder, the
        optimizer (``self.train_update`` runs it together with the global-step
        update) and the saver.
        """
        # flattened length of one input sample
        flat_dim = int(np.prod(self.input_shape))
        # NOTE(review): only set here when no other code has provided it.
        if getattr(self, 'encoder_input_shape', None) is None:
            self.encoder_input_shape = flat_dim

        self.x_real = tf.placeholder(tf.float32,
                                     shape=[None, flat_dim],
                                     name='x_input')
        self.y_real = tf.placeholder(tf.float32,
                                     shape=[None, self.nb_classes],
                                     name='y_input')

        # x encoder : [x, y] -> (z_mean, z_log_var)
        self.config['encoder params']['name'] = 'EncoderX'
        self.config['encoder params']["output dims"] = self.z_dim
        self.encoder = get_encoder(self.config['x encoder'],
                                   self.config['encoder params'],
                                   self.is_training)
        # the rest of the graph refers to the x encoder as self.x_encoder;
        # self.encoder is kept for callers that read it
        self.x_encoder = self.encoder

        # y encoder : y -> z_mean_y
        # NOTE(review): built from 'y encoder' / 'y encoder params' unless the
        # class already provides one - confirm these config keys exist.
        if getattr(self, 'y_encoder', None) is None:
            self.y_encoder = get_encoder(self.config['y encoder'],
                                         self.config['y encoder params'],
                                         self.is_training)

        # decoder : z -> flattened x
        self.config['decoder params']['name'] = 'Decoder'
        self.config['decoder params']["output dims"] = self.encoder_input_shape
        self.decoder = get_decoder(self.config['decoder'],
                                   self.config['decoder params'],
                                   self.is_training)

        # build encoders; x and y are both [batch, features], so join on axis 1
        self.z_mean, self.z_log_var = self.x_encoder(
            tf.concat([self.x_real, self.y_real], axis=1))
        self.z_mean_y = self.y_encoder(self.y_real)

        # sample z from z_mean and z_log_var
        self.z_sample = self.draw_sample(self.z_mean, self.z_log_var)

        # build decoder
        self.x_decode = self.decoder(self.z_sample)

        # build test decoder (second call, without reuse= like the other
        # decoder calls in this file)
        self.z_test = tf.placeholder(tf.float32,
                                     shape=[None, self.z_dim],
                                     name='z_test')
        self.x_test = self.decoder(self.z_test)

        # loss function; KL inputs use the 'mean' / 'log_var' keys expected by
        # get_loss('kl', ...) elsewhere in this file
        self.kl_loss = (get_loss(
            'kl', self.config['kl loss'], {
                'mean': (self.z_mean - self.z_mean_y),
                'log_var': self.z_log_var
            }) * self.config.get('kl loss prod', 1.0))
        self.xent_loss = (
            get_loss('reconstruction', self.config['reconstruction loss'], {
                'x': self.x_real,
                'y': self.x_decode
            }) * self.config.get('reconstruction loss prod', 1.0))
        self.loss = self.kl_loss + self.xent_loss

        # optimizer configure
        self.global_step, self.global_step_update = get_global_step()
        train_vars = (self.decoder.vars + self.x_encoder.vars +
                      self.y_encoder.vars)
        optimizer_params = {}
        if 'lr' in self.config:
            self.learning_rate = get_learning_rate(self.config['lr_scheme'],
                                                   float(self.config['lr']),
                                                   self.global_step,
                                                   self.config['lr_params'])
            optimizer_params['learning_rate'] = self.learning_rate
        self.optimizer = get_optimizer(self.config['optimizer'],
                                       optimizer_params, self.loss,
                                       train_vars)

        self.train_update = tf.group(self.optimizer, self.global_step_update)

        # model saver: Saver takes a single var_list (its second positional
        # parameter is `reshape`), so all variables go into one list
        self.saver = tf.train.Saver(self.x_encoder.vars +
                                    self.y_encoder.vars +
                                    self.decoder.vars +
                                    [self.global_step])
コード例 #10
0
    def build_model(self):
        """Build the WGAN-GP graph.

        Builds discriminator, generator and classifier networks and the
        ``x_input`` / ``z`` / ``c`` placeholders. The losses are the
        Wasserstein adversarial losses plus a gradient penalty evaluated on
        random interpolates between real and generated samples. Also builds one
        optimizer per player sharing a single global step, and the saver.
        """
        # network config
        self.config['discriminator params']['name'] = 'Discriminator'
        self.discriminator = get_discriminator(
            self.config['discriminator'], self.config['discriminator params'],
            self.is_training)

        self.config['generator params']['name'] = 'Generator'
        self.generator = get_generator(self.config['generator'],
                                       self.config['generator params'],
                                       self.is_training)

        # NOTE(review): the classifier is built but not used in this method.
        self.config['classifier params']['name'] = 'Classifier'
        self.classifier = get_classifier(self.config['classifier'],
                                         self.config['classifier params'],
                                         self.is_training)

        # build model
        self.x_real = tf.placeholder(tf.float32,
                                     shape=[
                                         None,
                                     ] + list(self.input_shape),
                                     name='x_input')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')
        # named 'c' so it does not collide with the 'z' placeholder
        # (a duplicate name would be silently renamed by TensorFlow)
        self.c = tf.placeholder(tf.float32, shape=[None, self.c_dim], name='c')

        self.x_fake = self.generator(self.z)

        dis_real = self.discriminator(self.x_real)
        dis_fake = self.discriminator(self.x_fake)

        # gradient penalty: one uniform interpolation coefficient per sample,
        # shaped [batch, 1, ..., 1] so it broadcasts over every remaining input
        # dimension whatever the rank of input_shape
        batch_size = tf.shape(self.x_real)[0]
        epsilon = tf.random_uniform(shape=[batch_size] + [1] * len(self.input_shape),
                                    minval=0.0,
                                    maxval=1.0)
        x_hat = (epsilon * self.x_real) + ((1 - epsilon) * self.x_fake)
        dis_hat = self.discriminator(x_hat)

        # 'wassterstein' is the loss registry key as spelled by get_loss
        self.d_loss_adv = (get_loss('adversarial down', 'wassterstein', {
            'dis_real': dis_real,
            'dis_fake': dis_fake
        }) * self.config.get('adversarial loss weight', 1.0))

        self.d_loss_gp = (get_loss('gradient penalty', 'l2', {
            'x': x_hat,
            'y': dis_hat
        }) * self.config.get('gradient penalty loss weight', 10.0))

        self.d_loss = self.d_loss_gp + self.d_loss_adv
        self.g_loss = get_loss('adversarial up', 'wassterstein',
                               {'dis_fake': dis_fake})

        # optimizer config
        self.global_step, self.global_step_update = get_global_step()

        # The discriminator optimizer is configured with the global step but
        # without its update op, so only the generator step advances it and
        # both learning rates follow the same schedule.
        (self.d_train_op, self.d_learning_rate,
         self.d_global_step) = get_optimizer_by_config(
             self.config['discriminator optimizer'],
             self.config['discriminator optimizer params'], self.d_loss,
             self.discriminator.vars, self.global_step)

        (self.g_train_op, self.g_learning_rate,
         self.g_global_step) = get_optimizer_by_config(
             self.config['generator optimizer'],
             self.config['generator optimizer params'], self.g_loss,
             self.generator.vars, self.global_step, self.global_step_update)

        # model saver
        self.saver = tf.train.Saver(self.discriminator.store_vars +
                                    self.generator.store_vars +
                                    [self.global_step])
コード例 #11
0
ファイル: improved_gan.py プロジェクト: yanzhicong/VAE-GAN
    def build_model(self):
        """Build the Improved-GAN style semi-supervised graph.

        The discriminator is configured with ``nb_classes + 1`` outputs
        (presumably the ``nb_classes`` real classes plus one "fake" class --
        confirm against the 'supervised/unsupervised cross entropy' losses).

        Creates the placeholders ``x_real``, ``label_real`` and ``z``, the
        discriminator/generator losses, three train ops

        * ``d_su_train_op``   -- discriminator, supervised loss (uses labels)
        * ``d_unsu_train_op`` -- discriminator, unsupervised loss
        * ``g_unsu_train_op`` -- generator, unsupervised loss

        and a ``tf.train.Saver`` over both networks plus the global step.
        """
        # ---- networks -------------------------------------------------------
        dis_params = self.config['discriminator params']
        dis_params['name'] = 'Discriminator'
        dis_params["output dims"] = self.nb_classes + 1
        self.config['generator params']['name'] = 'Generator'
        self.discriminator = self._build_discriminator('discriminator')
        self.generator = self._build_generator('generator')

        # ---- inputs ---------------------------------------------------------
        image_shape = [None] + list(self.input_shape)
        self.x_real = tf.placeholder(tf.float32, shape=image_shape,
                                     name='x_real')
        self.label_real = tf.placeholder(tf.float32,
                                         shape=[None, self.nb_classes],
                                         name='label_real')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

        # ---- forward pass ---------------------------------------------------
        self.x_fake = self.generator(self.z)
        real_out = self.discriminator.features(self.x_real)
        self.dis_real, self.dis_real_end_points = real_out
        fake_out = self.discriminator.features(self.x_fake)
        self.dis_fake, self.dis_fake_end_points = fake_out

        self.prob_real = tf.nn.softmax(self.dis_real)

        # ---- losses ---------------------------------------------------------
        # No feature-matching term is added: each discriminator loss is just
        # its weighted adversarial term.
        adv_weight = self.config.get('adversarial loss weight', 1.0)

        su_inputs = {
            'dis_real': self.dis_real,
            'dis_fake': self.dis_fake,
            'label': self.label_real,
        }
        su_adv = get_loss('adversarial down', 'supervised cross entropy',
                          su_inputs)
        self.d_su_loss_adv = su_adv * adv_weight
        self.d_su_loss = self.d_su_loss_adv

        unsu_inputs = {'dis_real': self.dis_real, 'dis_fake': self.dis_fake}
        unsu_adv = get_loss('adversarial down', 'unsupervised cross entropy',
                            unsu_inputs)
        self.d_unsu_loss_adv = unsu_adv * adv_weight
        self.d_unsu_loss = self.d_unsu_loss_adv

        self.g_unsu_loss = get_loss('adversarial up',
                                    'unsupervised cross entropy',
                                    {'dis_fake': self.dis_fake})

        # ---- optimizers -----------------------------------------------------
        self.global_step, self.global_step_update = get_global_step()

        d_optimizer = self.config['discriminator optimizer']
        d_optimizer_params = self.config['discriminator optimizer params']

        # NOTE(review): this supervised discriminator op is passed
        # global_step_update (like the generator op below), whereas the
        # unsupervised discriminator op is not -- confirm the supervised op is
        # meant to advance the shared global step.
        d_su_result = get_optimizer_by_config(
            d_optimizer, d_optimizer_params, self.d_su_loss,
            self.discriminator.vars, self.global_step,
            self.global_step_update)
        (self.d_su_train_op, self.d_su_learning_rate,
         self.d_su_global_step) = d_su_result

        # Shares global_step but is not given its update op; the intent is to
        # keep the discriminator's learning rate in step with the generator's.
        d_unsu_result = get_optimizer_by_config(
            d_optimizer, d_optimizer_params, self.d_unsu_loss,
            self.discriminator.vars, self.global_step)
        (self.d_unsu_train_op, self.d_unsu_learning_rate,
         self.d_unsu_global_step) = d_unsu_result

        g_optimizer = self.config['generator optimizer']
        g_optimizer_params = self.config['generator optimizer params']
        g_unsu_result = get_optimizer_by_config(
            g_optimizer, g_optimizer_params, self.g_unsu_loss,
            self.generator.vars, self.global_step, self.global_step_update)
        (self.g_unsu_train_op, self.g_unsu_learning_rate,
         self.g_unsu_global_step) = g_unsu_result

        # ---- saver ----------------------------------------------------------
        saved_vars = (self.discriminator.store_vars +
                      self.generator.store_vars +
                      [self.global_step])
        self.saver = tf.train.Saver(saved_vars)
コード例 #12
0
ファイル: wgan_gp.py プロジェクト: yanzhicong/VAE-GAN
    def build_model(self):
        """Build the WGAN graph (weight-clipping or gradient-penalty variant).

        Creates the generator and discriminator networks, the placeholders
        ``x_real`` and ``z``, the critic loss ``d_loss`` (Wasserstein term,
        plus optional feature-matching and gradient-penalty terms), the
        generator loss ``g_loss``, one train op per network, and a
        ``tf.train.Saver`` over both networks' stored variables and the
        global step.

        When ``self.use_gradient_penalty`` is false, ``self.clip_discriminator``
        is also built: a list of assign ops that clip every discriminator
        variable into ``[weight_clip_bound[0], weight_clip_bound[1]]``.

        Raises:
            NotImplementedError: if ``self.input_shape`` is empty.
        """
        # network config
        self.config['discriminator params']['name'] = 'Discriminator'
        self.config['generator params']['name'] = 'Generator'

        self.discriminator = self._build_discriminator('discriminator')
        self.generator = self._build_generator('generator')

        # build model
        self.x_real = tf.placeholder(tf.float32,
                                     shape=[None] + list(self.input_shape),
                                     name='x_input')
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

        self.x_fake = self.generator(self.z)
        self.dis_real, self.dis_real_end_points = self.discriminator.features(
            self.x_real)
        self.dis_fake, self.dis_fake_end_points = self.discriminator.features(
            self.x_fake)

        # Random interpolates between real and fake samples for the gradient
        # penalty. epsilon has shape [batch, 1, ..., 1] (one draw per sample),
        # so it broadcasts over every non-batch axis for any input rank.
        x_dims = len(self.input_shape)
        if x_dims < 1:
            raise NotImplementedError
        eplison = tf.random_uniform(
            shape=[tf.shape(self.x_real)[0]] + [1] * x_dims,
            minval=0.0,
            maxval=1.0)
        x_hat = (eplison * self.x_real) + ((1 - eplison) * self.x_fake)
        dis_hat = self.discriminator(x_hat)

        # loss config: the critic loss is the sum of every enabled term
        self.d_loss_list = []
        self.d_loss_adv = (
            get_loss('adversarial down', 'wassterstein', {
                'dis_real': self.dis_real,
                'dis_fake': self.dis_fake
            }) * self.config.get('adversarial loss weight', 1.0))
        self.d_loss_list.append(self.d_loss_adv)

        if self.use_feature_matching_loss:
            if self.feature_matching_end_points is None:
                # default: match every discriminator end point whose name
                # contains 'conv'
                self.feature_matching_end_points = [
                    k for k in self.dis_real_end_points.keys() if 'conv' in k
                ]
                print('feature matching end points : ',
                      self.feature_matching_end_points)
            self.d_loss_fm = get_loss(
                'feature matching', 'l2', {
                    'fx': self.dis_real_end_points,
                    'fy': self.dis_fake_end_points,
                    'fnames': self.feature_matching_end_points
                })
            self.d_loss_fm *= self.config.get('feature matching loss weight',
                                              0.01)
            self.d_loss_list.append(self.d_loss_fm)

        if self.use_gradient_penalty:
            self.d_loss_gp = (get_loss('gradient penalty', 'l2', {
                'x': x_hat,
                'y': dis_hat
            }) * self.config.get('gradient penalty loss weight', 10.0))
            self.d_loss_list.append(self.d_loss_gp)

        self.d_loss = tf.reduce_sum(self.d_loss_list)
        self.g_loss = get_loss('adversarial up', 'wassterstein',
                               {'dis_fake': self.dis_fake})

        # optimizer config
        self.global_step, self.global_step_update = get_global_step()

        if not self.use_gradient_penalty:
            # Without a gradient penalty, the critic is kept bounded by
            # clipping its weights; tf.assign needs the target variable as its
            # first argument.
            clip_low = self.weight_clip_bound[0]
            clip_high = self.weight_clip_bound[1]
            self.clip_discriminator = [
                tf.assign(var, tf.clip_by_value(var, clip_low, clip_high))
                for var in self.discriminator.vars
            ]

        # optimizer of discriminator configured without global step update
        # so we can keep the learning rate of discriminator the same as generator
        (self.d_train_op, self.d_learning_rate,
         self.d_global_step) = get_optimizer_by_config(
             self.config['discriminator optimizer'],
             self.config['discriminator optimizer params'], self.d_loss,
             self.discriminator.vars, self.global_step)
        (self.g_train_op, self.g_learning_rate,
         self.g_global_step) = get_optimizer_by_config(
             self.config['generator optimizer'],
             self.config['generator optimizer params'], self.g_loss,
             self.generator.vars, self.global_step, self.global_step_update)

        # model saver
        self.saver = tf.train.Saver(self.discriminator.store_vars +
                                    self.generator.store_vars +
                                    [self.global_step])
コード例 #13
0
ファイル: semiseg.py プロジェクト: xclmj/VAE-GAN
	def build_model(self):
		"""Build the adversarial semi-supervised segmentation graph.

		Creates placeholders for unlabelled images, labelled images and their
		masks; a segmentation network; and a discriminator that scores
		(image, mask) pairs concatenated along the last axis. The critic loss is
		a Wasserstein term plus a gradient penalty on interpolated masks. The
		segmentation network's loss is an adversarial term plus an l2 term
		against ``mask_l``. One train op is built per network, plus a saver.
		"""
		self.img_u = tf.placeholder(tf.float32, shape=[None,] + self.input_shape, name='image_unlabelled_input')
		self.img_l = tf.placeholder(tf.float32, shape=[None,] + self.input_shape, name='image_labelled_input')
		self.mask_l = tf.placeholder(tf.float32, shape=[None,] + self.mask_size + [self.nb_classes], name='mask_input')

		###########################################################################
		# network define
		#
		self.config['classifier params']['name'] = 'Segmentation'
		# NOTE(review): mask_generated is concatenated with img_l and compared
		# (l2) against mask_l, which has nb_classes channels, so this presumably
		# requires hx_dim == nb_classes -- confirm.
		self.config['classifier params']["output dims"] = self.hx_dim
		self.seg_classifier = get_classifier(self.config['classifier'],
			self.config['classifier params'], self.is_training)

		# The discriminator gets its own name, distinct from the segmentation
		# network's 'Segmentation'; presumably the name becomes the variable
		# scope, so sharing it would mix the two networks' variables.
		self.config['discriminator params']['name'] = 'Discriminator'
		self.config['discriminator params']["output dims"] = 1
		self.config['discriminator params']['output_activation'] = 'none'
		self.discriminator = get_discriminator(self.config['discriminator'],
			self.config['discriminator params'], self.is_training)

		###########################################################################
		# for supervised training:
		self.mask_generated = self.seg_classifier(self.img_l)

		# the discriminator sees the image and a mask stacked on the last axis
		real_concated = tf.concat([self.img_l, self.mask_l], axis=-1)
		fake_concated = tf.concat([self.img_l, self.mask_generated], axis=-1)

		dis_real_concated = self.discriminator(real_concated)
		dis_fake_concated = self.discriminator(fake_concated)

		# per-sample random interpolation between the real and generated masks
		# for the gradient penalty; epsilon is [batch, 1, 1, 1]
		eplison = tf.random_uniform(shape=[tf.shape(self.img_l)[0], 1, 1, 1], minval=0.0, maxval=1.0)
		mask_hat = eplison * self.mask_l + (1 - eplison) * self.mask_generated
		concat_hat = tf.concat([self.img_l, mask_hat], axis=-1)

		dis_hat_concated = self.discriminator(concat_hat)

		self.d_su_loss_adv = (get_loss('adversarial down', 'wassterstein', {'dis_real' : dis_real_concated, 'dis_fake' : dis_fake_concated})
			* self.config.get('adversarial loss weight', 1.0))
		self.d_su_loss_gp = (get_loss('gradient penalty', 'l2', {'x' : concat_hat, 'y' : dis_hat_concated})
			* self.config.get('gradient penalty loss weight', 1.0))
		self.d_su_loss = self.d_su_loss_adv + self.d_su_loss_gp

		self.g_su_loss_adv = (get_loss('adversarial up', 'wassterstein', {'dis_fake' : dis_fake_concated})
			* self.config.get('adversarial loss weight', 1.0))

		self.g_su_loss_cls = (get_loss('segmentation', 'l2', {'predict' : self.mask_generated, 'mask':self.mask_l})
			* self.config.get('segmentation loss weight', 1.0))

		self.g_su_loss = self.g_su_loss_adv + self.g_su_loss_cls

		###########################################################################
		# optimizer configure
		(self.d_su_train_op,
			self.d_su_learning_rate,
				self.d_su_global_step) = get_optimizer_by_config(self.config['supervised optimizer'], self.config['supervised optimizer params'],
					self.d_su_loss, self.discriminator.vars, global_step_name='d_global_step_su')

		# The segmentation network plays the generator role: g_su_loss depends
		# on it through mask_generated. build_model never creates a
		# self.generator.
		(self.g_su_train_op,
			self.g_su_learning_rate,
				self.g_su_global_step) = get_optimizer_by_config(self.config['supervised optimizer'], self.config['supervised optimizer params'],
					self.g_su_loss, self.seg_classifier.vars, global_step_name='g_global_step_su')

		###########################################################################
		# model saver
		self.saver = tf.train.Saver(self.store_vars + [self.d_su_global_step, self.g_su_global_step])
コード例 #14
0
ファイル: aae_ssl.py プロジェクト: yanzhicong/VAE-GAN
    def build_model(self):
        """Build the semi-supervised adversarial auto-encoder (AAE) graph.

        The encoder output is split into a continuous code ``img_z`` (first
        ``z_dim`` columns) and class logits ``img_logits`` (remaining
        columns). The decoder reconstructs the image from
        ``[img_z, softmax(img_logits)]``. Two discriminators regularise the
        code: ``z_discriminator`` compares ``img_z`` with ``real_z``, and
        ``y_discriminator`` compares the class probabilities with ``real_y``.

        ``self.gan_type`` selects the adversarial losses:

        * ``'wgan'``: Wasserstein loss plus a gradient penalty on random
          interpolates.
        * ``'dcgan'``: cross-entropy.

        Train ops are built for the reconstruction, regularisation and
        classification phases, plus a saver.

        Raises:
            ValueError: if ``self.gan_type`` is neither ``'wgan'`` nor
                ``'dcgan'``.
        """
        # Only these two formulations are implemented below; fail fast rather
        # than with an AttributeError on an undefined loss further down.
        if self.gan_type not in ('wgan', 'dcgan'):
            raise ValueError(
                "unsupported gan_type {!r}: expected 'wgan' or 'dcgan'".format(
                    self.gan_type))
        use_wgan = self.gan_type == 'wgan'

        # network config
        self.z_discriminator = self._build_discriminator(
            'z discriminator', params={'name': 'Z_Discriminator'})
        self.y_discriminator = self._build_discriminator(
            'y discriminator', params={'name': 'Y_Discriminator'})
        self.encoder = self._build_encoder('encoder',
                                           params={
                                               'name': 'Encoder',
                                               "output dims":
                                               self.z_dim + self.nb_classes
                                           })
        self.decoder = self._build_decoder('decoder',
                                           params={'name': 'Decoder'})

        # build model
        self.img = tf.placeholder(tf.float32,
                                  shape=[None] + list(self.input_shape),
                                  name='img')
        self.label = tf.placeholder(tf.float32,
                                    shape=[None, self.nb_classes],
                                    name='label')

        self.real_z = tf.placeholder(tf.float32,
                                     shape=[None, self.z_dim],
                                     name='real_z')
        self.real_y = tf.placeholder(tf.float32,
                                     shape=[None, self.nb_classes],
                                     name='real_y')

        # split the encoder output into [z code | class logits]
        self.img_encode = self.encoder(self.img)

        self.img_z = self.img_encode[:, :self.z_dim]
        self.img_logits = self.img_encode[:, self.z_dim:]
        self.img_y = tf.nn.softmax(self.img_logits)

        self.img_recon = self.decoder(
            tf.concat([self.img_z, self.img_y], axis=1))

        self.dis_z_real = self.z_discriminator(self.real_z)
        self.dis_z_fake = self.z_discriminator(self.img_z)

        if use_wgan:
            # per-sample interpolation between real and encoded z
            eplison = tf.random_uniform(shape=[tf.shape(self.real_z)[0], 1],
                                        minval=0.0,
                                        maxval=1.0)
            self.hat_z = (eplison * self.real_z) + ((1 - eplison) * self.img_z)
            self.dis_z_hat = self.z_discriminator(self.hat_z)

        self.dis_y_real = self.y_discriminator(self.real_y)
        self.dis_y_fake = self.y_discriminator(self.img_y)

        if use_wgan:
            # per-sample interpolation between real and predicted y
            eplison2 = tf.random_uniform(shape=[tf.shape(self.real_y)[0], 1],
                                         minval=0.0,
                                         maxval=1.0)
            self.hat_y = (eplison2 * self.real_y) + (
                (1 - eplison2) * self.img_y)
            self.dis_y_hat = self.y_discriminator(self.hat_y)

        # generate image from z
        self.img_generate = self.decoder(
            tf.concat([self.real_z, self.real_y], axis=1))

        # loss config
        # reconstruction phase
        self.loss_recon = get_loss('reconstruction', 'l2', {
            'x': self.img,
            'y': self.img_recon
        })

        # regulation phase
        if use_wgan:
            self.loss_z_adv_down = get_loss('adversarial down', 'wassterstein',
                                            {
                                                'dis_real': self.dis_z_real,
                                                'dis_fake': self.dis_z_fake
                                            })
            self.loss_z_gp = get_loss('gradient penalty', 'l2', {
                'x': self.hat_z,
                'y': self.dis_z_hat
            })
            self.loss_z_adv_up = get_loss('adversarial up', 'wassterstein',
                                          {'dis_fake': self.dis_z_fake})

            self.loss_y_adv_down = get_loss('adversarial down', 'wassterstein',
                                            {
                                                'dis_real': self.dis_y_real,
                                                'dis_fake': self.dis_y_fake
                                            })
            self.loss_y_gp = get_loss('gradient penalty', 'l2', {
                'x': self.hat_y,
                'y': self.dis_y_hat
            })
            self.loss_y_adv_up = get_loss('adversarial up', 'wassterstein',
                                          {'dis_fake': self.dis_y_fake})
        else:  # 'dcgan'
            self.loss_z_adv_down = get_loss('adversarial down',
                                            'cross entropy', {
                                                'dis_real': self.dis_z_real,
                                                'dis_fake': self.dis_z_fake
                                            })
            self.loss_z_adv_up = get_loss('adversarial up', 'cross entropy',
                                          {'dis_fake': self.dis_z_fake})

            self.loss_y_adv_down = get_loss('adversarial down',
                                            'cross entropy', {
                                                'dis_real': self.dis_y_real,
                                                'dis_fake': self.dis_y_fake
                                            })
            self.loss_y_adv_up = get_loss('adversarial up', 'cross entropy',
                                          {'dis_fake': self.dis_y_fake})

        # semi-supervised classification phase
        self.loss_cla = get_loss('classification', 'cross entropy', {
            'logits': self.img_logits,
            'labels': self.label
        })

        self.ae_loss = self.loss_recon
        if use_wgan:
            self.dz_loss = self.loss_z_adv_down + self.loss_z_gp
            self.dy_loss = self.loss_y_adv_down + self.loss_y_gp
        else:
            self.dz_loss = self.loss_z_adv_down
            self.dy_loss = self.loss_y_adv_down
        self.ez_loss = self.loss_z_adv_up
        self.ey_loss = self.loss_y_adv_up
        self.e_loss = self.loss_cla

        # optimizer config
        self.global_step, self.global_step_update = get_global_step()

        # reconstruction phase
        (self.ae_train_op, self.ae_learning_rate,
         self.ae_step) = self._build_optimizer(
             'auto-encoder', self.loss_recon,
             self.encoder.vars + self.decoder.vars)

        # regulation phase
        (self.dz_train_op, self.dz_learning_rate,
         self.dz_step) = self._build_optimizer('discriminator', self.dz_loss,
                                               self.z_discriminator.vars)

        (self.dy_train_op, self.dy_learning_rate,
         self.dy_step) = self._build_optimizer('discriminator', self.dy_loss,
                                               self.y_discriminator.vars)

        (self.ez_train_op, self.ez_learning_rate,
         self.ez_step) = self._build_optimizer('encoder', self.ez_loss,
                                               self.encoder.vars)

        (self.ey_train_op, self.ey_learning_rate,
         self.ey_step) = self._build_optimizer('encoder', self.ey_loss,
                                               self.encoder.vars)

        # classification phase
        (self.e_train_op, self.e_learning_rate,
         self.e_step) = self._build_optimizer('classifier', self.e_loss,
                                              self.encoder.vars)

        # model saver
        self.saver = tf.train.Saver(self.z_discriminator.store_vars +
                                    self.y_discriminator.store_vars +
                                    self.encoder.store_vars +
                                    self.decoder.store_vars +
                                    [self.global_step])
コード例 #15
0
ファイル: attention_mil.py プロジェクト: yanzhicong/VAE-GAN
    def build_model(self):
        """Build the multiple-instance-learning (MIL) bag classifier graph.

        A bag of instances ``x_bag`` goes through a per-instance feature
        extractor. The features are pooled into one bag feature of shape
        ``[1, feature_dim]``; ``self.mil_pooling`` selects max pooling, mean
        pooling, or an attention-weighted sum. The bag feature is then
        classified. The bag label is a single ``nb_classes`` vector. The loss
        is 'binary entropy' on the logits, and ``probs`` is their sigmoid.

        Raises:
            ValueError: if ``self.mil_pooling`` is not 'maxpooling',
                'avgpooling' or 'attention'.
        """
        if self.mil_pooling not in ('maxpooling', 'avgpooling', 'attention'):
            raise ValueError(
                "unsupported mil_pooling {!r}: expected 'maxpooling', "
                "'avgpooling' or 'attention'".format(self.mil_pooling))

        self.feature_ext_net = self._build_classifier('feature_ext',
                                                      params={
                                                          'name':
                                                          'feature_ext',
                                                          "output dims":
                                                          self.z_dims
                                                      })

        if self.mil_pooling == 'attention':
            # copy so the shared config dict is not mutated
            self.attention_net_params = self.config[
                'attention_net params'].copy()
            self.attention_net_params.update({
                'name': 'attention_net',
                'output dims': 1
            })
            self.attention_net = AttentionNet(self.attention_net_params,
                                              self.is_training)

        self.classifier = self._build_classifier('classifier',
                                                 params={
                                                     'name': 'classifier',
                                                     "output dims":
                                                     self.nb_classes
                                                 })

        #
        # Build model
        #
        # 1. inputs: one bag of instances and its single label vector
        self.x_bag = tf.placeholder(tf.float32,
                                    shape=[None] + self.input_shape,
                                    name='x_bag')
        self.label = tf.placeholder(tf.float32,
                                    shape=[self.nb_classes],
                                    name='label')

        # 2. feature extraction (one row per instance)
        self.features, self.fea_ext_net_endpoints = self.feature_ext_net.features(
            self.x_bag)

        # 3. mil pooling over the instance axis (axis 0)
        if self.mil_pooling == 'maxpooling':
            bag_feature = tf.reduce_max(self.features, axis=0)
        elif self.mil_pooling == 'avgpooling':
            bag_feature = tf.reduce_mean(self.features, axis=0)
        else:  # 'attention'
            self.instance_weight, self.attention_net_endpoints = self.attention_net.features(
                self.features)
            bag_feature = tf.reduce_sum(self.features * self.instance_weight,
                                        axis=0)
        self.bag_feature = tf.reshape(bag_feature, [1, -1])

        # 4. classify
        self.logits, self.classifier_endpoints = self.classifier.features(
            self.bag_feature)
        self.probs = tf.nn.sigmoid(self.logits)
        self.bag_label = tf.reshape(self.label, [1, -1])

        # 5. loss and metric
        self.entropy_loss = get_loss('classification', 'binary entropy', {
            'logits': self.logits,
            'labels': self.bag_label
        })
        self.loss = self.entropy_loss

        self.train_acc = get_metric('accuracy', 'multi-class acc2', {
            'probs': self.probs,
            'labels': self.bag_label
        })

        # build optimizer
        self.global_step, self.global_step_update = self._build_step_var(
            'global_step')

        sum_list = []
        if self.has_summary:
            sum_list.append(tf.summary.scalar('train acc', self.train_acc))

        # Summaries follow self.has_summary, as in the other models. Each train
        # function gets its own copy of sum_list so anything one call appends
        # to it cannot leak into the other.
        train_function_args = {
            'step': self.global_step,
            'step_update': self.global_step_update,
            'build_summary': self.has_summary,
        }

        if self.finetune_steps > 0:
            self.finetune_classifier, _ = self._build_train_function(
                'finetune', self.loss, self.finetune_vars,
                sum_list=list(sum_list), **train_function_args)

        self.train_classifier, _ = self._build_train_function(
            'optimizer', self.loss, self.vars,
            sum_list=list(sum_list), **train_function_args)

        self.saver = tf.train.Saver(self.store_vars + [self.global_step])