def _build_optimizer(self, name, loss, vars, step=None, step_update=None):
    """Create a training op from the optimizer configured under ``name``.

    Args:
        name: Config prefix. ``''`` or ``'optimizer'`` selects the
            ``'optimizer'`` / ``'optimizer params'`` config entries; any other
            value selects ``'<name> optimizer'`` / ``'<name> optimizer params'``.
        loss: Objective handed to ``get_optimizer_by_config``.
        vars: Variables handed to ``get_optimizer_by_config``.
        step: Optional step; when omitted, ``self.global_step`` is used if the
            instance has that attribute.
        step_update: Optional step-update op, forwarded unchanged.

    Returns:
        The ``(train_op, learning_rate_var, step_var)`` triple produced by
        ``get_optimizer_by_config``.

    Raises:
        KeyError: If the selected config entries are missing.
    """
    # Identity test: the correct way to detect "argument not supplied", and it
    # never dispatches to an overloaded ``__eq__`` on the step object.
    if step is None and hasattr(self, 'global_step'):
        step = self.global_step

    # Both '' and 'optimizer' map onto the un-prefixed config entries.
    prefix = '' if name in ('', 'optimizer') else name + ' '

    train_op, learning_rate_var, step_var = get_optimizer_by_config(
        self.config[prefix + 'optimizer'],
        self.config[prefix + 'optimizer params'],
        loss, vars, step, step_update)
    return train_op, learning_rate_var, step_var
def build_model(self):
    """Build the segmentation graph: classifier, loss, metric, optimizer, saver.

    Creates the training placeholders ``self.x`` / ``self.mask``, the training
    loss and mIoU metric, a separate test input ``self.test_x`` whose batch and
    spatial dimensions are left unspecified, the training op and the saver.
    """
    classifier_params = self.config['classifier params']
    classifier_params['name'] = 'classifier'
    classifier_params["output dims"] = self.nb_classes
    self.classifier = get_classifier(
        self.config['classifier'],
        self.config['classifier params'],
        self.is_training)

    # ---- training graph ----
    self.x = tf.placeholder(
        tf.float32, shape=[None, ] + self.input_shape, name='x_input')
    self.mask = tf.placeholder(
        tf.float32, shape=[None, ] + self.mask_shape, name='mask')

    self.logits, self.end_points = self.classifier.features(self.x)

    loss_inputs = {'logits': self.logits, 'mask': self.mask}
    self.loss = get_loss(
        'segmentation', self.config['segmentation loss'], loss_inputs)

    metric_inputs = {
        'logits': self.logits,
        'mask': self.mask,
        'nb_classes': self.nb_classes,
    }
    self.train_miou = get_metric('segmentation', 'miou', metric_inputs)

    # ---- test graph ----
    nb_channels = self.input_shape[-1]
    self.test_x = tf.placeholder(
        tf.float32,
        shape=[None, None, None, nb_channels],
        name='test_x_input')
    self.test_logits = self.classifier(self.test_x)
    self.test_y = tf.nn.softmax(self.test_logits)

    # ---- optimizer ----
    optimizer_outputs = get_optimizer_by_config(
        self.config['optimizer'],
        self.config['optimizer params'],
        target=self.loss,
        variables=self.classifier.vars)
    self.train_op, self.learning_rate, self.global_step = optimizer_outputs

    # ---- model saver ----
    saved_vars = self.classifier.store_vars + [self.global_step, ]
    self.saver = tf.train.Saver(saved_vars)
def build_model(self):
    """Build the GAN graph: networks, adversarial losses, optimizers, saver.

    The discriminator and generator optimizers share one global step; only the
    generator optimizer receives the step-update op.
    """
    # ---- networks ----
    self.config['discriminator params']['name'] = 'Discriminator'
    self.config['generator params']['name'] = 'Generator'
    self.discriminator = self._build_discriminator('discriminator')
    self.generator = self._build_generator('generator')

    # ---- inputs and forward passes ----
    self.x_real = tf.placeholder(
        tf.float32,
        shape=[None, ] + list(self.input_shape),
        name='x_input')
    self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

    self.x_fake = self.generator(self.z)
    self.dis_real = self.discriminator(self.x_real)
    self.dis_fake = self.discriminator(self.x_fake)

    # ---- losses ----
    d_loss_inputs = {'dis_real': self.dis_real, 'dis_fake': self.dis_fake}
    self.d_loss = get_loss('adversarial down', 'cross entropy', d_loss_inputs)

    g_loss_inputs = {'dis_fake': self.dis_fake}
    self.g_loss = get_loss('adversarial up', 'cross entropy', g_loss_inputs)

    # ---- optimizers ----
    self.global_step, self.global_step_update = get_global_step()

    # The discriminator optimizer gets the shared step but no step update, so
    # its learning rate stays the same as the generator's.
    d_outputs = get_optimizer_by_config(
        self.config['discriminator optimizer'],
        self.config['discriminator optimizer params'],
        self.d_loss,
        self.discriminator.vars,
        self.global_step)
    self.d_train_op, self.d_learning_rate, self.d_global_step = d_outputs

    g_outputs = get_optimizer_by_config(
        self.config['generator optimizer'],
        self.config['generator optimizer params'],
        self.g_loss,
        self.generator.vars,
        self.global_step,
        self.global_step_update)
    self.g_train_op, self.g_learning_rate, self.g_global_step = g_outputs

    # ---- model saver ----
    saved_vars = (self.discriminator.store_vars
                  + self.generator.store_vars
                  + [self.global_step])
    self.saver = tf.train.Saver(saved_vars)
def build_model(self):
    """Build the VAE graph: encoder, decoder, KL + reconstruction loss, saver.

    ``self.x_decode`` decodes a latent sample drawn from the encoder output;
    ``self.x_test`` decodes the externally fed latent ``self.z_test``.
    """
    self.x_real = tf.placeholder(
        tf.float32,
        shape=[None, ] + list(self.input_shape),
        name='x_input')

    # ---- networks ----
    self.config['encoder params']['name'] = 'encoder'
    self.config['decoder params']['name'] = 'decoder'
    self.encoder = self._build_encoder('encoder')
    self.decoder = self._build_decoder('decoder')

    # ---- training path: encode, sample, decode ----
    self.mean_z, self.log_var_z = self.encoder(self.x_real)
    latent_sample = self.draw_sample(self.mean_z, self.log_var_z)
    self.x_decode = self.decoder(latent_sample)

    # ---- test path: decode a latent code fed from outside ----
    self.z_test = tf.placeholder(
        tf.float32, shape=[None, self.z_dim], name='z_test')
    self.x_test = self.decoder(self.z_test)

    # ---- losses, each scaled by an optional config multiplier ----
    kl_inputs = {'mean': self.mean_z, 'log_var': self.log_var_z}
    kl_term = get_loss('kl', self.config['kl loss'], kl_inputs)
    self.kl_loss = kl_term * self.config.get('kl loss prod', 1.0)

    recon_inputs = {'x': self.x_real, 'y': self.x_decode}
    recon_term = get_loss(
        'reconstruction', self.config['reconstruction loss'], recon_inputs)
    self.recon_loss = recon_term * self.config.get(
        'reconstruction loss prod', 1.0)

    self.loss = self.kl_loss + self.recon_loss

    # ---- optimizer ----
    optimizer_outputs = get_optimizer_by_config(
        self.config['optimizer'],
        self.config['optimizer params'],
        self.loss,
        self.vars)
    self.train_op, self.learning_rate, self.global_step = optimizer_outputs

    # ---- model saver ----
    saved_vars = (self.encoder.store_vars
                  + self.decoder.store_vars
                  + [self.global_step, ])
    self.saver = tf.train.Saver(saved_vars)
def build_model_m1(self):
    """Build the unsupervised "M1" graph: x encoder, hx decoder, loss, train op.

    Data flow:
        xu --> mean_hxu, log_var_hxu        ==> KL loss
        sample_hxu --> xu_decode            ==> reconstruction loss
    """
    self.xu = tf.placeholder(
        tf.float32, shape=[None, ] + self.input_shape, name='xu_input')

    # ---- networks ----
    # x_encoder : x -> hx
    self.config['x encoder params']['name'] = 'EncoderHX_X'
    self.config['x encoder params']["output dims"] = self.hx_dim
    self.x_encoder = get_encoder(
        self.config['x encoder'],
        self.config['x encoder params'],
        self.is_training)

    # hx_decoder : hx -> x  (its "output dims" is left to the config)
    self.config['hx decoder params']['name'] = 'DecoderX_HX'
    self.hx_decoder = get_decoder(
        self.config['hx decoder'],
        self.config['hx decoder params'],
        self.is_training)

    # ---- unsupervised forward pass ----
    mean_hxu, log_var_hxu = self.x_encoder(self.xu)
    sample_hxu = self.draw_sample(mean_hxu, log_var_hxu)
    xu_decode = self.hx_decoder(sample_hxu)

    # ---- losses ----
    kl_term = get_loss(
        'kl', 'gaussian', {'mean': mean_hxu, 'log_var': log_var_hxu})
    self.m1_loss_kl_z = kl_term * self.m1_loss_weights.get(
        'kl z loss weight', 1.0)

    recon_term = get_loss(
        'reconstruction', 'mse', {'x': self.xu, 'y': xu_decode})
    self.m1_loss_recon = recon_term * self.m1_loss_weights.get(
        'reconstruction loss weight', 1.0)

    self.m1_loss = self.m1_loss_kl_z + self.m1_loss_recon

    # ---- optimizer ----
    self.m1_global_step, m1_global_step_update = get_global_step('m1_step')
    optimizer_outputs = get_optimizer_by_config(
        self.config['m1 optimizer'],
        self.config['m1 optimizer params'],
        self.m1_loss,
        self.m1_vars,
        self.m1_global_step,
        m1_global_step_update)
    self.m1_train_op, self.m1_learning_rate, _ = optimizer_outputs
def build_model(self):
    """Build the semi-supervised VAE graph.

    Creates:
      * placeholders ``self.xu`` (unlabelled), ``self.xl`` / ``self.yl``
        (labelled) and ``self.xt`` (test);
      * networks ``x_encoder``, ``hx_y_encoder``, ``hz_y_decoder`` and
        ``x_classifier``;
      * the supervised loss ``self.su_loss`` and unsupervised loss
        ``self.unsu_loss`` with their individual terms;
      * test outputs ``self.ytprobs``, ``self.mean_hzt``, ``self.log_var_hzt``;
      * one train op per loss, both using the same optimizer config and the
        same global step, and the model saver.

    Raises:
        KeyError: If a required config entry is missing.
    """
    # list() lets input_shape be a tuple as well as a list.
    input_shape = list(self.input_shape)

    self.xu = tf.placeholder(
        tf.float32, shape=[None, ] + input_shape, name='xu_input')
    self.xl = tf.placeholder(
        tf.float32, shape=[None, ] + input_shape, name='xl_input')
    self.yl = tf.placeholder(
        tf.float32, shape=[None, self.nb_classes], name='yl_input')

    ###########################################################################
    # network define
    #
    # x_encoder : x -> hx
    self.config['x encoder params']['name'] = 'EncoderHX_X'
    self.config['x encoder params']["output dims"] = self.hx_dim
    self.x_encoder = get_encoder(self.config['x encoder'],
                                 self.config['x encoder params'],
                                 self.is_training)

    # hx_y_encoder : [hx, y] -> hz
    self.config['hx y encoder params']['name'] = 'EncoderHZ_HXY'
    self.config['hx y encoder params']["output dims"] = self.hz_dim
    self.hx_y_encoder = get_encoder(self.config['hx y encoder'],
                                    self.config['hx y encoder params'],
                                    self.is_training)

    # hz_y_decoder : [hz, y] -> x_decode
    self.config['hz y decoder params']['name'] = 'DecoderX_HZY'
    # np.prod instead of np.product: the alias was removed in NumPy 2.0.
    self.config['hz y decoder params']["output dims"] = int(
        np.prod(self.input_shape))
    self.hz_y_decoder = get_decoder(self.config['hz y decoder'],
                                    self.config['hz y decoder params'],
                                    self.is_training)

    # x_classifier : x -> ylogits
    self.config['x classifier params']['name'] = 'ClassifierX'
    self.config['x classifier params']["output dims"] = self.nb_classes
    self.x_classifier = get_classifier(self.config['x classifier'],
                                       self.config['x classifier params'],
                                       self.is_training)

    ###########################################################################
    # for supervised training:
    #
    # xl --> hxl
    # xl --> yllogits                          ==> classification loss
    # [hxl, yl] --> mean_hzl, log_var_hzl      ==> kl loss
    # [sample_hzl, yl] --> decode_xl           ==> reconstruction loss
    #
    hxl = self.x_encoder(self.xl)
    mean_hzl, log_var_hzl = self.hx_y_encoder(
        tf.concat([hxl, self.yl], axis=1))
    sample_hzl = self.draw_sample(mean_hzl, log_var_hzl)
    decode_xl = self.hz_y_decoder(tf.concat([sample_hzl, self.yl], axis=1))
    yllogits = self.x_classifier(self.xl)

    self.su_loss_kl_z = (
        get_loss('kl', 'gaussian', {
            'mean': mean_hzl,
            'log_var': log_var_hzl,
        }) * self.loss_weights.get('kl z loss weight', 1.0))
    self.su_loss_recon = (
        get_loss('reconstruction', 'mse', {
            'x': self.xl,
            'y': decode_xl
        }) * self.loss_weights.get('reconstruction loss weight', 1.0))

    # The weight key used to be misspelled ('classiciation'); honour the
    # correct spelling first and fall back to the old one so existing configs
    # keep working.
    classification_weight = self.loss_weights.get(
        'classification loss weight',
        self.loss_weights.get('classiciation loss weight', 1.0))
    self.su_loss_cls = (
        get_loss('classification', 'cross entropy', {
            'logits': yllogits,
            'labels': self.yl
        }) * classification_weight)
    self.su_loss_reg = (
        get_loss('regularization', 'l2',
                 {'var_list': self.x_classifier.vars})
        * self.loss_weights.get('regularization loss weight', 0.0001))
    self.su_loss = (
        (self.su_loss_kl_z + self.su_loss_recon
         + self.su_loss_cls + self.su_loss_reg)
        * self.loss_weights.get('supervised loss weight', 1.0))

    ###########################################################################
    # for unsupervised training:
    #
    # xu --> hxu
    # xu --> yulogits --> yuprobs
    # for every class i, with y_i the one-hot label of class i:
    #   [hxu, y_i] --> mean_hzu, log_var_hzu   ==> kl loss    * yuprobs[:, i]
    #   [sample_hzu, y_i] --> decode_xu        ==> recon loss * yuprobs[:, i]
    #
    hxu = self.x_encoder(self.xu)
    yulogits = self.x_classifier(self.xu)
    yuprobs = tf.nn.softmax(yulogits)

    unsu_loss_kl_z_list = []
    unsu_loss_recon_list = []
    for i in range(self.nb_classes):
        # Pretend every unlabelled sample belongs to class i.
        yu_fake = tf.ones([tf.shape(self.xu)[0], ], dtype=tf.int32) * i
        yu_fake = tf.one_hot(yu_fake, depth=self.nb_classes)

        mean_hzu, log_var_hzu = self.hx_y_encoder(
            tf.concat([hxu, yu_fake], axis=1))
        sample_hzu = self.draw_sample(mean_hzu, log_var_hzu)
        decode_xu = self.hz_y_decoder(
            tf.concat([sample_hzu, yu_fake], axis=1))

        unsu_loss_kl_z_list.append(
            get_loss('kl', 'gaussian', {
                'mean': mean_hzu,
                'log_var': log_var_hzu,
                'instance_weight': yuprobs[:, i]
            }))
        unsu_loss_recon_list.append(
            get_loss('reconstruction', 'mse', {
                'x': self.xu,
                'y': decode_xu,
                'instance_weight': yuprobs[:, i]
            }))

    self.unsu_loss_kl_y = (
        get_loss('kl', 'bernoulli', {'probs': yuprobs})
        * self.loss_weights.get('kl y loss weight', 1.0))
    self.unsu_loss_kl_z = (
        tf.reduce_sum(unsu_loss_kl_z_list)
        * self.loss_weights.get('kl z loss weight', 1.0))
    self.unsu_loss_recon = (
        tf.reduce_sum(unsu_loss_recon_list)
        * self.loss_weights.get('reconstruction loss weight', 1.0))
    self.unsu_loss_reg = (
        get_loss('regularization', 'l2',
                 {'var_list': self.x_classifier.vars})
        * self.loss_weights.get('regularization loss weight', 0.0001))
    self.unsu_loss = (
        (self.unsu_loss_kl_z + self.unsu_loss_recon
         + self.unsu_loss_kl_y + self.unsu_loss_reg)
        * self.loss_weights.get('unsupervised loss weight', 1.0))

    ###########################################################################
    # for test models
    #
    # xt --> hxt
    # xt --> ytlogits --> ytprobs
    # [hxt, ytprobs] --> mean_hzt, log_var_hzt
    #
    self.xt = tf.placeholder(
        tf.float32, shape=[None, ] + input_shape, name='xt_input')
    hxt = self.x_encoder(self.xt)
    ytlogits = self.x_classifier(self.xt)
    self.ytprobs = tf.nn.softmax(ytlogits)
    self.mean_hzt, self.log_var_hzt = self.hx_y_encoder(
        tf.concat([hxt, self.ytprobs], axis=1))

    ###########################################################################
    # optimizer configure
    #
    # Keep the step on the instance: the saver below reads self.global_step,
    # which this method previously never assigned.
    self.global_step, global_step_update = get_global_step()

    (self.supervised_train_op,
     self.supervised_learning_rate,
     _) = get_optimizer_by_config(self.config['optimizer'],
                                  self.config['optimizer params'],
                                  self.su_loss, self.vars,
                                  self.global_step, global_step_update)
    # Same optimizer settings as the supervised op; the key was previously
    # misspelled 'optimizer parmas', which raised KeyError.
    (self.unsupervised_train_op,
     self.unsupervised_learning_rate,
     _) = get_optimizer_by_config(self.config['optimizer'],
                                  self.config['optimizer params'],
                                  self.unsu_loss, self.vars,
                                  self.global_step, global_step_update)

    ###########################################################################
    # model saver
    self.saver = tf.train.Saver(self.vars + [self.global_step, ])
def build_model(self):
    """Build the GAN graph with an optional feature-matching loss term.

    When ``self.use_feature_matching_loss`` is set, an l2 feature-matching term
    between the discriminator end points of real and fake inputs is added to
    the discriminator loss. If no end-point names were configured, every end
    point whose name contains ``'conv'`` is used.
    """
    # ---- networks ----
    self.config['discriminator params']['name'] = 'Discriminator'
    self.config['generator params']['name'] = 'Generator'
    self.discriminator = self._build_discriminator('discriminator')
    self.generator = self._build_generator('generator')

    # ---- inputs and forward passes ----
    self.x_real = tf.placeholder(
        tf.float32,
        shape=[None, ] + list(self.input_shape),
        name='x_input')
    self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')
    self.x_fake = self.generator(self.z)

    real_outputs = self.discriminator.features(self.x_real)
    self.dis_real, self.dis_real_end_points = real_outputs
    fake_outputs = self.discriminator.features(self.x_fake)
    self.dis_fake, self.dis_fake_end_points = fake_outputs

    # ---- discriminator loss ----
    self.d_loss_adv = get_loss(
        'adversarial down', 'cross entropy',
        {'dis_real': self.dis_real, 'dis_fake': self.dis_fake})

    if not self.use_feature_matching_loss:
        self.d_loss = self.d_loss_adv
    else:
        if self.feature_matching_end_points is None:
            # Default to every end point whose name mentions 'conv'.
            self.feature_matching_end_points = list(
                filter(lambda key: 'conv' in key,
                       self.dis_real_end_points.keys()))
            print('feature matching end points : ',
                  self.feature_matching_end_points)

        fm_term = get_loss(
            'feature matching', 'l2',
            {
                'fx': self.dis_real_end_points,
                'fy': self.dis_fake_end_points,
                'fnames': self.feature_matching_end_points,
            })
        self.d_loss_fm = fm_term * self.config.get(
            'feature matching loss weight', 0.01)
        self.d_loss = self.d_loss_adv + self.d_loss_fm

    # ---- generator loss ----
    self.g_loss = get_loss(
        'adversarial up', 'cross entropy', {'dis_fake': self.dis_fake})

    # ---- optimizers ----
    self.global_step, self.global_step_update = get_global_step()

    # The discriminator optimizer gets the shared step but no step update, so
    # its learning rate stays the same as the generator's.
    d_outputs = get_optimizer_by_config(
        self.config['discriminator optimizer'],
        self.config['discriminator optimizer params'],
        self.d_loss,
        self.discriminator.vars,
        self.global_step)
    self.d_train_op, self.d_learning_rate, self.d_global_step = d_outputs

    g_outputs = get_optimizer_by_config(
        self.config['generator optimizer'],
        self.config['generator optimizer params'],
        self.g_loss,
        self.generator.vars,
        self.global_step,
        self.global_step_update)
    self.g_train_op, self.g_learning_rate, self.g_global_step = g_outputs

    # ---- model saver ----
    saved_vars = (self.discriminator.store_vars
                  + self.generator.store_vars
                  + [self.global_step])
    self.saver = tf.train.Saver(saved_vars)
def build_model(self):
    """Build the gradient-penalty GAN graph.

    Creates the discriminator, generator and classifier networks, the
    placeholders ``self.x_real`` / ``self.z`` / ``self.c``, the discriminator
    loss (``'wassterstein'`` adversarial term plus an l2 gradient penalty on
    random interpolates of real and fake samples), the generator loss, both
    optimizers and the saver.

    The classifier and ``self.c`` are created here but not connected to any
    loss in this method.
    """
    # ---- networks ----
    self.config['discriminator params']['name'] = 'Discriminator'
    self.discriminator = get_discriminator(
        self.config['discriminator'],
        self.config['discriminator params'],
        self.is_training)

    self.config['generator params']['name'] = 'Generator'
    self.generator = get_generator(
        self.config['generator'],
        self.config['generator params'],
        self.is_training)

    self.config['classifier params']['name'] = 'Classifier'
    self.classifier = get_classifier(
        self.config['classifier'],
        self.config['classifier params'],
        self.is_training)

    # ---- inputs and forward passes ----
    self.x_real = tf.placeholder(
        tf.float32,
        shape=[None, ] + list(self.input_shape),
        name='x_input')
    self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')
    # Named 'c', not 'z': the duplicate name was a copy-paste slip that left
    # this placeholder with an auto-uniquified name in the graph.
    self.c = tf.placeholder(tf.float32, shape=[None, self.c_dim], name='c')

    self.x_fake = self.generator(self.z)
    dis_real = self.discriminator(self.x_real)
    dis_fake = self.discriminator(self.x_fake)

    # ---- losses ----
    # One mixing coefficient per sample; the [batch, 1, 1, 1] shape presumes
    # 4-D inputs (TODO confirm against input_shape).
    epsilon = tf.random_uniform(
        shape=[tf.shape(self.x_real)[0], 1, 1, 1], minval=0.0, maxval=1.0)
    x_hat = (epsilon * self.x_real) + ((1 - epsilon) * self.x_fake)
    dis_hat = self.discriminator(x_hat)

    self.d_loss_adv = (
        get_loss('adversarial down', 'wassterstein', {
            'dis_real': dis_real,
            'dis_fake': dis_fake
        }) * self.config.get('adversarial loss weight', 1.0))
    self.d_loss_gp = (
        get_loss('gradient penalty', 'l2', {
            'x': x_hat,
            'y': dis_hat
        }) * self.config.get('gradient penalty loss weight', 10.0))
    self.d_loss = self.d_loss_gp + self.d_loss_adv

    self.g_loss = get_loss(
        'adversarial up', 'wassterstein', {'dis_fake': dis_fake})

    # ---- optimizers ----
    self.global_step, self.global_step_update = get_global_step()

    # The discriminator optimizer gets the shared step but no step update, so
    # its learning rate stays the same as the generator's.
    (self.d_train_op,
     self.d_learning_rate,
     self.d_global_step) = get_optimizer_by_config(
         self.config['discriminator optimizer'],
         self.config['discriminator optimizer params'],
         self.d_loss, self.discriminator.vars,
         self.global_step)

    (self.g_train_op,
     self.g_learning_rate,
     self.g_global_step) = get_optimizer_by_config(
         self.config['generator optimizer'],
         self.config['generator optimizer params'],
         self.g_loss, self.generator.vars,
         self.global_step, self.global_step_update)

    # ---- model saver ----
    self.saver = tf.train.Saver(self.discriminator.store_vars
                                + self.generator.store_vars
                                + [self.global_step])
def build_model(self):
    """Build the semi-supervised GAN graph.

    The discriminator outputs ``nb_classes + 1`` values. A supervised and an
    unsupervised discriminator loss are built on the same forward passes; the
    generator has only an unsupervised loss. Three train ops are created, all
    on one shared global step.
    """
    # ---- networks ----
    self.config['discriminator params']['name'] = 'Discriminator'
    self.config['discriminator params']["output dims"] = self.nb_classes + 1
    self.config['generator params']['name'] = 'Generator'
    self.discriminator = self._build_discriminator('discriminator')
    self.generator = self._build_generator('generator')

    # ---- inputs and forward passes ----
    self.x_real = tf.placeholder(
        tf.float32,
        shape=[None, ] + list(self.input_shape),
        name='x_real')
    self.label_real = tf.placeholder(
        tf.float32, shape=[None, self.nb_classes], name='label_real')
    self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')

    self.x_fake = self.generator(self.z)
    real_outputs = self.discriminator.features(self.x_real)
    self.dis_real, self.dis_real_end_points = real_outputs
    fake_outputs = self.discriminator.features(self.x_fake)
    self.dis_fake, self.dis_fake_end_points = fake_outputs
    self.prob_real = tf.nn.softmax(self.dis_real)

    # NOTE: a feature-matching term is not part of either discriminator loss
    # here; each is its adversarial term only.

    # ---- supervised discriminator loss ----
    su_term = get_loss(
        'adversarial down', 'supervised cross entropy',
        {
            'dis_real': self.dis_real,
            'dis_fake': self.dis_fake,
            'label': self.label_real,
        })
    self.d_su_loss_adv = su_term * self.config.get(
        'adversarial loss weight', 1.0)
    self.d_su_loss = self.d_su_loss_adv

    # ---- unsupervised losses ----
    unsu_term = get_loss(
        'adversarial down', 'unsupervised cross entropy',
        {'dis_real': self.dis_real, 'dis_fake': self.dis_fake})
    self.d_unsu_loss_adv = unsu_term * self.config.get(
        'adversarial loss weight', 1.0)
    self.d_unsu_loss = self.d_unsu_loss_adv

    self.g_unsu_loss = get_loss(
        'adversarial up', 'unsupervised cross entropy',
        {'dis_fake': self.dis_fake})

    # ---- optimizers ----
    self.global_step, self.global_step_update = get_global_step()

    d_optimizer = self.config['discriminator optimizer']
    d_optimizer_params = self.config['discriminator optimizer params']

    # Supervised discriminator op: receives the step-update op.
    d_su_outputs = get_optimizer_by_config(
        d_optimizer,
        d_optimizer_params,
        self.d_su_loss,
        self.discriminator.vars,
        self.global_step,
        self.global_step_update)
    (self.d_su_train_op,
     self.d_su_learning_rate,
     self.d_su_global_step) = d_su_outputs

    # Unsupervised discriminator op: shared step, no step update, so its
    # learning rate stays the same as the generator's.
    d_unsu_outputs = get_optimizer_by_config(
        self.config['discriminator optimizer'],
        self.config['discriminator optimizer params'],
        self.d_unsu_loss,
        self.discriminator.vars,
        self.global_step)
    (self.d_unsu_train_op,
     self.d_unsu_learning_rate,
     self.d_unsu_global_step) = d_unsu_outputs

    # Unsupervised generator op: receives the step-update op.
    g_unsu_outputs = get_optimizer_by_config(
        self.config['generator optimizer'],
        self.config['generator optimizer params'],
        self.g_unsu_loss,
        self.generator.vars,
        self.global_step,
        self.global_step_update)
    (self.g_unsu_train_op,
     self.g_unsu_learning_rate,
     self.g_unsu_global_step) = g_unsu_outputs

    # ---- model saver ----
    saved_vars = (self.discriminator.store_vars
                  + self.generator.store_vars
                  + [self.global_step])
    self.saver = tf.train.Saver(saved_vars)
def build_model(self):
    """Build the GAN graph using the 'wassterstein' adversarial loss.

    The discriminator loss is the sum of:
      * the adversarial term (always);
      * an l2 feature-matching term when ``self.use_feature_matching_loss``;
      * an l2 gradient-penalty term when ``self.use_gradient_penalty``.

    Without the gradient penalty, ``self.clip_discriminator`` is created: a
    list of ops that clip every discriminator variable into
    ``self.weight_clip_bound``.

    Raises:
        NotImplementedError: If ``self.input_shape`` has neither 1 nor 3
            dimensions.
    """
    # ---- networks ----
    self.config['discriminator params']['name'] = 'Discriminator'
    self.config['generator params']['name'] = 'Generator'
    self.discriminator = self._build_discriminator('discriminator')
    self.generator = self._build_generator('generator')

    # ---- inputs and forward passes ----
    self.x_real = tf.placeholder(
        tf.float32,
        shape=[None, ] + list(self.input_shape),
        name='x_input')
    self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name='z')
    self.x_fake = self.generator(self.z)
    self.dis_real, self.dis_real_end_points = self.discriminator.features(
        self.x_real)
    self.dis_fake, self.dis_fake_end_points = self.discriminator.features(
        self.x_fake)

    # ---- random interpolates between real and fake samples ----
    # One mixing coefficient per sample, shaped to broadcast over the
    # remaining input dimensions.
    x_dims = len(self.input_shape)
    if x_dims == 1:
        epsilon_shape = [tf.shape(self.x_real)[0], 1]
    elif x_dims == 3:
        epsilon_shape = [tf.shape(self.x_real)[0], 1, 1, 1]
    else:
        raise NotImplementedError(
            'input_shape must have 1 or 3 dimensions, got %d' % x_dims)
    epsilon = tf.random_uniform(shape=epsilon_shape, minval=0.0, maxval=1.0)
    x_hat = (epsilon * self.x_real) + ((1 - epsilon) * self.x_fake)
    dis_hat = self.discriminator(x_hat)

    # ---- discriminator loss ----
    self.d_loss_list = []

    self.d_loss_adv = (
        get_loss('adversarial down', 'wassterstein', {
            'dis_real': self.dis_real,
            'dis_fake': self.dis_fake
        }) * self.config.get('adversarial loss weight', 1.0))
    self.d_loss_list.append(self.d_loss_adv)

    if self.use_feature_matching_loss:
        if self.feature_matching_end_points is None:
            # Default to every end point whose name mentions 'conv'.
            self.feature_matching_end_points = [
                k for k in self.dis_real_end_points.keys() if 'conv' in k
            ]
            print('feature matching end points : ',
                  self.feature_matching_end_points)
        self.d_loss_fm = get_loss(
            'feature matching', 'l2', {
                'fx': self.dis_real_end_points,
                'fy': self.dis_fake_end_points,
                'fnames': self.feature_matching_end_points
            })
        self.d_loss_fm *= self.config.get('feature matching loss weight',
                                          0.01)
        self.d_loss_list.append(self.d_loss_fm)

    if self.use_gradient_penalty:
        self.d_loss_gp = (
            get_loss('gradient penalty', 'l2', {
                'x': x_hat,
                'y': dis_hat
            }) * self.config.get('gradient penalty loss weight', 10.0))
        self.d_loss_list.append(self.d_loss_gp)

    self.d_loss = tf.reduce_sum(self.d_loss_list)

    # ---- generator loss ----
    self.g_loss = get_loss('adversarial up', 'wassterstein',
                           {'dis_fake': self.dis_fake})

    # ---- optimizers ----
    self.global_step, self.global_step_update = get_global_step()

    if not self.use_gradient_penalty:
        clip_low, clip_high = (self.weight_clip_bound[0],
                               self.weight_clip_bound[1])
        # tf.assign needs the target variable as well as the new value; the
        # earlier single-argument call raised TypeError.
        self.clip_discriminator = [
            tf.assign(var, tf.clip_by_value(var, clip_low, clip_high))
            for var in self.discriminator.vars
        ]

    # The discriminator optimizer gets the shared step but no step update, so
    # its learning rate stays the same as the generator's.
    (self.d_train_op,
     self.d_learning_rate,
     self.d_global_step) = get_optimizer_by_config(
         self.config['discriminator optimizer'],
         self.config['discriminator optimizer params'],
         self.d_loss, self.discriminator.vars,
         self.global_step)

    (self.g_train_op,
     self.g_learning_rate,
     self.g_global_step) = get_optimizer_by_config(
         self.config['generator optimizer'],
         self.config['generator optimizer params'],
         self.g_loss, self.generator.vars,
         self.global_step, self.global_step_update)

    # ---- model saver ----
    self.saver = tf.train.Saver(self.discriminator.store_vars
                                + self.generator.store_vars
                                + [self.global_step])
def build_model(self):
    """Build the adversarial segmentation graph (supervised part only).

    The segmentation classifier takes the generator role: it maps labelled
    images to masks. The discriminator scores the channel-wise concatenation
    of an image with either the ground-truth mask or the generated mask.

    Creates the placeholders ``self.img_u`` / ``self.img_l`` / ``self.mask_l``,
    the discriminator loss (``'wassterstein'`` adversarial term plus l2
    gradient penalty), the generator loss (adversarial term plus l2
    segmentation term), one train op per network and the saver.
    ``self.img_u`` is created but not used by any loss in this method.
    """
    self.img_u = tf.placeholder(tf.float32,
                                shape=[None, ] + self.input_shape,
                                name='image_unlabelled_input')
    self.img_l = tf.placeholder(tf.float32,
                                shape=[None, ] + self.input_shape,
                                name='image_labelled_input')
    self.mask_l = tf.placeholder(
        tf.float32,
        shape=[None, ] + self.mask_size + [self.nb_classes],
        name='mask_input')

    ###########################################################################
    # network define
    #
    self.config['classifier params']['name'] = 'Segmentation'
    self.config['classifier params']["output dims"] = self.hx_dim
    self.seg_classifier = get_classifier(self.config['classifier'],
                                         self.config['classifier params'],
                                         self.is_training)

    # The discriminator gets its own name: it previously reused
    # 'Segmentation', the name already given to the classifier above.
    self.config['discriminator params']['name'] = 'Discriminator'
    self.config['discriminator params']["output dims"] = 1
    self.config['discriminator params']['output_activation'] = 'none'
    self.discriminator = get_discriminator(
        self.config['discriminator'],
        self.config['discriminator params'],
        self.is_training)

    ###########################################################################
    # for supervised training:
    #
    self.mask_generated = self.seg_classifier(self.img_l)

    # tf.concat is the TensorFlow API; tf.concatenate does not exist.
    real_concated = tf.concat([self.img_l, self.mask_l], axis=-1)
    fake_concated = tf.concat([self.img_l, self.mask_generated], axis=-1)
    dis_real_concated = self.discriminator(real_concated)
    dis_fake_concated = self.discriminator(fake_concated)

    # Random per-sample mix of the real and generated masks, used as the
    # point at which the gradient penalty is evaluated.
    epsilon = tf.random_uniform(shape=[tf.shape(self.img_l)[0], 1, 1, 1],
                                minval=0.0, maxval=1.0)
    mask_hat = epsilon * self.mask_l + (1 - epsilon) * self.mask_generated
    concat_hat = tf.concat([self.img_l, mask_hat], axis=-1)
    dis_hat_concated = self.discriminator(concat_hat)

    self.d_su_loss_adv = (
        get_loss('adversarial down', 'wassterstein',
                 {'dis_real': dis_real_concated,
                  'dis_fake': dis_fake_concated})
        * self.config.get('adversarial loss weight', 1.0))
    self.d_su_loss_gp = (
        get_loss('gradient penalty', 'l2',
                 {'x': concat_hat, 'y': dis_hat_concated})
        * self.config.get('gradient penalty loss weight', 1.0))
    self.d_su_loss = self.d_su_loss_adv + self.d_su_loss_gp

    self.g_su_loss_adv = (
        get_loss('adversarial up', 'wassterstein',
                 {'dis_fake': dis_fake_concated})
        * self.config.get('adversarial loss weight', 1.0))
    self.g_su_loss_cls = (
        get_loss('segmentation', 'l2',
                 {'predict': self.mask_generated, 'mask': self.mask_l})
        * self.config.get('segmentation loss weight', 1.0))
    self.g_su_loss = self.g_su_loss_adv + self.g_su_loss_cls

    ###########################################################################
    # optimizer configure
    #
    (self.d_su_train_op,
     self.d_su_learning_rate,
     self.d_su_global_step) = get_optimizer_by_config(
         self.config['supervised optimizer'],
         self.config['supervised optimizer params'],
         self.d_su_loss, self.discriminator.vars,
         global_step_name='d_global_step_su')

    # The generator loss is a function of the segmentation classifier, which
    # is the only generator-side network this method builds, so its variables
    # are the ones trained here (previously the undefined self.generator).
    (self.g_su_train_op,
     self.g_su_learning_rate,
     self.g_su_global_step) = get_optimizer_by_config(
         self.config['supervised optimizer'],
         self.config['supervised optimizer params'],
         self.g_su_loss, self.seg_classifier.vars,
         global_step_name='g_global_step_su')

    # TODO: no test-time graph is built yet.

    ###########################################################################
    # model saver
    self.saver = tf.train.Saver(
        self.store_vars + [self.d_su_global_step, self.g_su_global_step])