def build_model(self):
    """Assemble the GAN graph: inputs, losses, optimizers and summaries.

    Real images come from a ``tf.data`` pipeline when ``self.custom_dataset``
    is truthy, otherwise from a ``real_images`` placeholder. The method then
    wires the discriminator and generator together, adds a gradient penalty
    for ``wgan*``/``dragan`` loss types, creates one Adam training op per
    network, an inference-mode generator output (``self.fake_images``) and
    scalar loss summaries.
    """
    # ---- Graph input: real images ----
    if self.custom_dataset:
        image_data = ImageData(self.img_size, self.c_dim)
        dataset = tf.data.Dataset.from_tensor_slices(self.data)

        target_device = '/gpu:0'
        dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
        dataset = dataset.apply(
            map_and_batch(image_data.image_processing,
                          self.batch_size,
                          num_parallel_batches=8,
                          drop_remainder=True))
        dataset = dataset.apply(
            prefetch_to_device(target_device, self.batch_size))

        iterator = dataset.make_one_shot_iterator()
        self.inputs = iterator.get_next()
    else:
        image_shape = [
            self.batch_size, self.img_size, self.img_size, self.c_dim
        ]
        self.inputs = tf.placeholder(tf.float32, image_shape,
                                     name='real_images')

    # ---- Graph input: noise ----
    self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim],
                            name='z')

    # ---- Losses ----
    # Discriminator on real images, then on generated images (shared weights).
    real_logits = self.discriminator(self.inputs)
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    uses_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    if uses_penalty:
        penalty = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        penalty = 0

    self.d_loss = discriminator_loss(
        self.gan_type, real=real_logits, fake=fake_logits) + penalty
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    # ---- Training ops ----
    # Variables are assigned to a network by a substring of their name.
    trainable = tf.trainable_variables()
    d_vars = [v for v in trainable if 'discriminator' in v.name]
    g_vars = [v for v in trainable if 'generator' in v.name]

    d_adam = tf.train.AdamOptimizer(self.d_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

    g_adam = tf.train.AdamOptimizer(self.g_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

    # ---- Testing: generator in inference mode, reusing trained weights ----
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- Summaries ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def _input_fn(channel):
    """Build the input dataset for one SageMaker PipeMode channel.

    Args:
        channel: Channel identifier handed to ``PipeModeDataset``.

    Returns:
        A dataset of batches shaped as ``({'data': ...}, labels)``, where
        ``data`` is the raw byte string decoded as ``tf.float64``.
    """
    feature_spec = {
        'data': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.int64),
    }

    def parse(record):
        # One serialized example -> (features dict, label).
        example = tf.parse_single_example(record, feature_spec)
        decoded = tf.decode_raw(example['data'], tf.float64)
        return ({'data': decoded}, example['labels'])

    dataset = PipeModeDataset(channel)
    # Repeat only when more than one pass over the channel is requested.
    if EPOCHS > 1:
        dataset = dataset.repeat(EPOCHS)
    dataset = dataset.prefetch(PREFETCH_SIZE)

    parse_and_batch = map_and_batch(
        parse,
        batch_size=BATCH_SIZE,
        num_parallel_batches=NUM_PARALLEL_BATCHES)
    return dataset.apply(parse_and_batch)
def input_fn(params):
    """Return the batched TFRecord dataset for training or evaluation.

    Args:
        params: Mapping; its ``"batch_size"`` entry is used only when the
            enclosing ``use_tpu`` flag is set, otherwise ``bsz`` is used.

    Returns:
        A dataset of decoded records batched with ``map_and_batch``.
    """
    batch_size = params["batch_size"] if use_tpu else bsz

    dataset = tf.data.TFRecordDataset(input_file)
    # Training repeats indefinitely and shuffles; evaluation keeps file order.
    if is_training:
        dataset = dataset.repeat().shuffle(buffer_size=100)

    decode_and_batch = contrib_data.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder)
    return dataset.apply(decode_and_batch)
def input_fn():
    """Return an endlessly repeating dataset of decoded micro-batches.

    Reads ``input_file`` as TFRecords, shuffles first when ``is_training`` is
    set, decodes and batches with ``drop_remainder=True``, and finally
    repeats. The repeat is applied in both training and evaluation.
    """
    micro_batch_size = opts["micro_batch_size"]

    dataset = tf.data.TFRecordDataset(input_file)

    # Shuffling happens before batching and only while training.
    if is_training:
        dataset = dataset.shuffle(buffer_size=opts['buffer_size'])

    decode_and_batch = map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=micro_batch_size,
        drop_remainder=True)
    dataset = dataset.apply(decode_and_batch)

    # NOTE: the original author was unsure whether repeat() belongs after
    # batching; the position is kept unchanged here.
    return dataset.repeat()
def _input_fn(channel):
    """Create the batched dataset that reads a SageMaker PipeMode channel.

    Args:
        channel: Channel identifier passed straight to ``PipeModeDataset``.

    Returns:
        A dataset yielding ``({"data": ...}, labels)`` batches; ``data`` is
        the serialized byte string reinterpreted as ``tf.float64`` values.
    """
    features = dict(
        data=tf.FixedLenFeature([], tf.string),
        labels=tf.FixedLenFeature([], tf.int64),
    )

    def parse(record):
        # Deserialize a single example and split it into inputs and label.
        fields = tf.parse_single_example(record, features)
        inputs = {"data": tf.decode_raw(fields["data"], tf.float64)}
        return (inputs, fields["labels"])

    source = PipeModeDataset(channel)
    if EPOCHS > 1:
        # Multiple passes requested: repeat the channel EPOCHS times.
        source = source.repeat(EPOCHS)
    prefetched = source.prefetch(PREFETCH_SIZE)

    return prefetched.apply(
        map_and_batch(parse,
                      batch_size=BATCH_SIZE,
                      num_parallel_batches=NUM_PARALLEL_BATCHES))
def _input_fn():
    """Build the benchmark input dataset described by ``config``.

    Reads ``config.channel`` through ``PipeModeDataset(benchmark=True)``,
    optionally repeats and prefetches, then parses and batches the records.

    Returns:
        A dataset of ``({'data': ...}, labels)`` batches; ``data`` holds the
        raw bytes decoded as ``tf.float64``.
    """
    feature_spec = {
        'data': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.int64),
    }

    def parse(record):
        example = tf.parse_single_example(record, feature_spec)
        decoded = tf.decode_raw(example['data'], tf.float64)
        return ({'data': decoded}, example['labels'])

    dataset = PipeModeDataset(config.channel, benchmark=True)

    # Both stages are optional and controlled by the configuration.
    if config.epochs > 1:
        dataset = dataset.repeat(config.epochs)
    if config.prefetch_size > 0:
        dataset = dataset.prefetch(config.prefetch_size)

    parse_and_batch = map_and_batch(
        parse,
        batch_size=config.batch_size,
        num_parallel_batches=config.parallel_transform_calls)
    return dataset.apply(parse_and_batch)
def input_fn(params):
    """Return the batched dataset read from the files matching ``input_file``.

    Args:
        params: Mapping that must contain ``"batch_size"``.

    Returns:
        A dataset of decoded records batched with ``map_and_batch``.
    """
    batch_size = params["batch_size"]

    # List files in a fixed order; any reordering comes from the sloppy
    # interleave (training only) and the shuffle below.
    files = tf.data.Dataset.list_files(input_file, shuffle=False)

    open_records = functools.partial(
        tf.data.TFRecordDataset, compression_type=FLAGS.compression_type)
    dataset = files.apply(
        contrib_data.parallel_interleave(open_records,
                                         cycle_length=32,
                                         sloppy=is_training))

    if is_training:
        dataset = dataset.repeat().shuffle(buffer_size=100)

    decode_and_batch = contrib_data.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder)
    return dataset.apply(decode_and_batch)
def build_model(self):
    """Build the multi-device training graph or the single-image test graph.

    Train phase (``self.phase == 'train'``):
      * creates the ``learning_rate``, ``manual_d_loss`` and
        ``fool_discriminator_counter`` placeholders;
      * builds shuffled, repeated and batched pipelines for both domains;
      * for every device returned by ``self.get_available_gpus()`` builds the
        generators, discriminators and their losses, sharing variables after
        the first device;
      * averages the per-device losses, creates the Adam training ops and the
        TensorBoard summaries.

    Any other phase: creates ``test_domain_A`` / ``test_domain_B``
    placeholders and the corresponding translated outputs.

    Fixes relative to the previous version:
      * the debug loop over ``G_vars`` referenced an undefined name
        (``elemnt`` vs ``element``) and raised ``NameError``;
      * the "fool discriminator" predicate used Python ``==`` on a tensor,
        which in TF1 graph mode is an identity test and always yields
        ``False``; it now uses ``tf.equal``;
      * the counter placeholder was re-created for every device, leaving the
        earlier ones impossible to feed through
        ``self.fool_discriminator_counter``; it is now created once.
    """
    if self.phase == 'train':
        self.lr = tf.placeholder(tf.float32, name='learning_rate')
        self.manual_d_loss = tf.placeholder(tf.float32, name='manual_d_loss')

        # ---- Input images ----
        Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                     self.augment_flag)

        # Per-device total losses; averaged after the device loop.
        self.Discriminator_loss_total = []
        self.Generator_loss_total = []

        # TODO(jhhuang): we should change how we load data here
        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        # TODO(jhhuang): change GPU devices (prefetch_to_device is disabled)
        trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        # One counter shared by all devices. Every 100th value the real and
        # fake arguments of the discriminator loss are swapped (see below).
        self.fool_discriminator_counter = tf.placeholder(
            tf.int32, name="fool_discriminator_counter")
        fool_dis = tf.equal(self.fool_discriminator_counter % 100, 0)

        # TODO(jhhuang): revisit the for-loop
        reuse_vars = False
        available_gpus = self.get_available_gpus()
        print(str(available_gpus))

        for gpu in available_gpus:
            print("Current GPU: " + str(gpu))
            with tf.device(self.assign_to_device(gpu, ps_device='/cpu:0')):
                # Each device pulls its own batch from the shared iterators.
                domain_A = trainA_iterator.get_next()
                domain_B = trainB_iterator.get_next()

                # ---- Generators and discriminators ----
                x_ab, cam_ab = self.generate_a2b(domain_A, reuse=reuse_vars)
                x_ba, cam_ba = self.generate_b2a(domain_B, reuse=reuse_vars)

                # Cycle: translated images go back through the other generator.
                x_aba, _ = self.generate_b2a(x_ab, reuse=True)
                x_bab, _ = self.generate_a2b(x_ba, reuse=True)

                # Identity: each generator applied to its own target domain.
                x_aa, cam_aa = self.generate_b2a(domain_A, reuse=True)
                x_bb, cam_bb = self.generate_a2b(domain_B, reuse=True)

                (real_A_logit, real_A_cam_logit, real_B_logit,
                 real_B_cam_logit) = self.discriminate_real(
                     domain_A, domain_B, reuse=reuse_vars)
                (fake_A_logit, fake_A_cam_logit, fake_B_logit,
                 fake_B_cam_logit) = self.discriminate_fake(x_ba, x_ab)

                # NOTE(review): in graph mode these calls only create print
                # ops; nothing is printed unless the returned ops are run.
                tf.print("real_A_logit: ", real_A_logit)
                tf.print("real_A_cam_logit: ", real_A_cam_logit)
                tf.print("real_B_logit: ", real_B_logit)
                tf.print("real_B_cam_logit: ", real_B_cam_logit)
                tf.print("fake_A_logit: ", fake_A_logit)
                tf.print("fake_A_cam_logit: ", fake_A_cam_logit)
                tf.print("fake_B_logit: ", fake_B_logit)
                tf.print("fake_B_cam_logit: ", fake_B_cam_logit)

                # ---- Losses ----
                if 'wgan' in self.gan_type or self.gan_type == 'dragan':
                    GP_A, GP_CAM_A = self.gradient_panalty(
                        real=domain_A, fake=x_ba, scope="discriminator_A")
                    GP_B, GP_CAM_B = self.gradient_panalty(
                        real=domain_B, fake=x_ab, scope="discriminator_B")
                else:
                    GP_A, GP_CAM_A = 0, 0
                    GP_B, GP_CAM_B = 0, 0

                G_ad_loss_A = (
                    generator_loss(self.gan_type, fake_A_logit) +
                    generator_loss(self.gan_type, fake_A_cam_logit))
                G_ad_loss_B = (
                    generator_loss(self.gan_type, fake_B_logit) +
                    generator_loss(self.gan_type, fake_B_cam_logit))

                # When fool_dis is true the real/fake arguments are swapped;
                # otherwise the regular discriminator loss is used.
                D_ad_loss_A = tf.cond(
                    fool_dis,
                    lambda: (discriminator_loss(
                        self.gan_type, fake_A_logit, real_A_logit) +
                             discriminator_loss(
                                 self.gan_type, fake_A_cam_logit,
                                 real_A_cam_logit) + GP_A + GP_CAM_A),
                    lambda: (discriminator_loss(
                        self.gan_type, real_A_logit, fake_A_logit) +
                             discriminator_loss(
                                 self.gan_type, real_A_cam_logit,
                                 fake_A_cam_logit) + GP_A + GP_CAM_A))
                D_ad_loss_B = tf.cond(
                    fool_dis,
                    lambda: (discriminator_loss(
                        self.gan_type, fake_B_logit, real_B_logit) +
                             discriminator_loss(
                                 self.gan_type, fake_B_cam_logit,
                                 real_B_cam_logit) + GP_B + GP_CAM_B),
                    lambda: (discriminator_loss(
                        self.gan_type, real_B_logit, fake_B_logit) +
                             discriminator_loss(
                                 self.gan_type, real_B_cam_logit,
                                 fake_B_cam_logit) + GP_B + GP_CAM_B))

                reconstruction_A = L1_loss(x_aba, domain_A)
                reconstruction_B = L1_loss(x_bab, domain_B)

                identity_A = L1_loss(x_aa, domain_A)
                identity_B = L1_loss(x_bb, domain_B)

                cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
                cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

                Generator_A_gan = self.adv_weight * G_ad_loss_A
                Generator_A_cycle = self.cycle_weight * reconstruction_B
                Generator_A_identity = self.identity_weight * identity_A
                Generator_A_cam = self.cam_weight * cam_A

                Generator_B_gan = self.adv_weight * G_ad_loss_B
                Generator_B_cycle = self.cycle_weight * reconstruction_A
                Generator_B_identity = self.identity_weight * identity_B
                Generator_B_cam = self.cam_weight * cam_B

                Generator_A_loss = (Generator_A_gan + Generator_A_cycle +
                                    Generator_A_identity + Generator_A_cam)
                Generator_B_loss = (Generator_B_gan + Generator_B_cycle +
                                    Generator_B_identity + Generator_B_cam)

                Discriminator_A_loss = self.adv_weight * D_ad_loss_A
                Discriminator_B_loss = self.adv_weight * D_ad_loss_B

                self.Generator_loss_total.append(
                    Generator_A_loss + Generator_B_loss +
                    regularization_loss('generator'))
                self.Discriminator_loss_total.append(
                    Discriminator_A_loss + Discriminator_B_loss +
                    regularization_loss('discriminator') + self.manual_d_loss)

                # Every device after the first reuses the same variables.
                reuse_vars = True

        # ---- Training ----
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        # Average the total losses over all devices.
        self.Generator_loss = tf.reduce_mean(self.Generator_loss_total)
        self.Discriminator_loss = tf.reduce_mean(
            self.Discriminator_loss_total)

        print("G_vars : ")
        for element in G_vars:
            print("element from G_vars : " + str(element))

        self.G_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss,
                                                      var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(
                self.Discriminator_loss, var_list=D_vars)

        # ---- Summaries ----
        # NOTE: the per-term summaries below use the tensors left over from
        # the last loop iteration, i.e. the last device only.
        self.all_G_loss = tf.summary.scalar("Generator_loss",
                                            self.Generator_loss)
        self.all_D_loss = tf.summary.scalar("Discriminator_loss",
                                            self.Discriminator_loss)

        self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
        self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
        self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
        self.G_A_identity = tf.summary.scalar("G_A_identity",
                                              Generator_A_identity)
        self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)

        self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
        self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
        self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
        self.G_B_identity = tf.summary.scalar("G_B_identity",
                                              Generator_B_identity)
        self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)

        self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
        self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

        # Histogram plus min/max/mean for every trainable "rho" variable.
        self.rho_var = []
        for var in tf.trainable_variables():
            if 'rho' in var.name:
                self.rho_var.append(tf.summary.histogram(var.name, var))
                self.rho_var.append(
                    tf.summary.scalar(var.name + "_min", tf.reduce_min(var)))
                self.rho_var.append(
                    tf.summary.scalar(var.name + "_max", tf.reduce_max(var)))
                self.rho_var.append(
                    tf.summary.scalar(var.name + "_mean",
                                      tf.reduce_mean(var)))

        g_summary_list = [
            self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity,
            self.G_A_cam, self.G_B_loss, self.G_B_gan, self.G_B_cycle,
            self.G_B_identity, self.G_B_cam, self.all_G_loss
        ]
        g_summary_list.extend(self.rho_var)
        d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

        self.G_loss = tf.summary.merge(g_summary_list)
        self.D_loss = tf.summary.merge(d_summary_list)

    else:
        # ---- Test graph: one image per domain ----
        self.test_domain_A = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='test_domain_A')
        self.test_domain_B = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
def main():
    """Train the three-resolution generator/discriminator model, or sample.

    Command-line flags:
        -s/--sample PATH      Sample from the song at PATH instead of
                              training (empty string, the default, trains).
        -c/--concat           Forwarded to ``unpack_sample`` when sampling
        --no-concat           (default: disabled).
        -r/--record           Write a full run trace and Chrome timeline at
        --no-record           every checkpoint step (default: disabled).

    Side effects: creates ``Checkpoints_v1``, ``Logs_v1``, ``Samples_v1`` and
    ``Timelines_v1``; restores the latest checkpoint when one exists. In
    sampling mode it writes
    ``Samples_v1/sample_<name>/<name>.npy`` and exits the process.

    Fixes relative to the previous version:
      * the sample directory check tested the literal, unformatted template
        ``'Samples_v1/sample_%s'``, so ``os.mkdir`` was always called and
        failed when the directory already existed;
      * the bare ``except:`` around the input file load no longer catches
        ``KeyboardInterrupt``/``SystemExit``;
      * output directories are created with ``exist_ok=True`` instead of a
        check-then-create sequence.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-s',
        '--sample',
        type=str,
        default='',
        help='Samples based on input song. Empty string means training.')
    parser.add_argument('-c',
                        '--concat',
                        dest='concat',
                        action='store_true',
                        help='Enable concatenation.')
    parser.add_argument('--no-concat',
                        dest='concat',
                        action='store_false',
                        help='Disable concatenation.')
    parser.set_defaults(concat=False)
    parser.add_argument('-r',
                        '--record',
                        dest='record',
                        action='store_true',
                        help='Enable recording.')
    parser.add_argument('--no-record',
                        dest='record',
                        action='store_false',
                        help='Disable recording.')
    # Warning: Windows kills python if enabled.
    parser.set_defaults(record=False)
    args = parser.parse_args()
    sampling = args.sample != ''

    for directory in ('Checkpoints_v1', 'Logs_v1', 'Samples_v1',
                      'Timelines_v1'):
        os.makedirs(directory, exist_ok=True)

    # ---- Input pipeline ----
    filename = 'Dataset/cond_dataset.tfrecord'
    dataset = tf.data.TFRecordDataset(filename, num_parallel_reads=8)

    def _parse(example_proto):
        # Locals are named so they do not shadow the ``data`` module that the
        # enclosing function uses for the dataset transformations.
        feature = {
            'label': tf.FixedLenFeature([], tf.int64),
            'data': tf.FixedLenFeature([], tf.string)
        }
        parsed = tf.parse_single_example(example_proto, feature)
        song = tf.decode_raw(parsed['data'], tf.uint8)
        label = tf.cast(parsed['label'], tf.uint8)
        # Both fields are bit-packed; expand them to one value per bit.
        song = tf.py_func(func=np.unpackbits, inp=[song], Tout=tf.uint8)
        label = tf.py_func(func=np.unpackbits, inp=[label], Tout=tf.uint8)
        song = tf.cast(song, tf.float32)
        label = tf.cast(label, tf.float32)
        song = tf.reshape(song, [CHANNEL_NUM, CLASS_NUM, INPUT_LENGTH])
        # A single uint8 unpacks to 8 bits; keep the first CHANNEL_NUM.
        label.set_shape([8])
        label = label[:CHANNEL_NUM]
        label = tf.expand_dims(tf.expand_dims(label, axis=-1), axis=-1)
        # Map {0, 1} to {-1, 1}.
        song = song * 2 - 1
        return {'data': song, 'label': label}

    dataset = dataset.apply(data.shuffle_and_repeat(buffer_size=16384))
    dataset = dataset.apply(
        data.map_and_batch(_parse,
                           batch_size=BATCH_SIZE,
                           num_parallel_batches=16,
                           drop_remainder=True))
    dataset = dataset.prefetch(32)
    dataset = dataset.apply(data.prefetch_to_device('/gpu:0'))
    iterator = dataset.make_one_shot_iterator()
    next_data = iterator.get_next()
    real_input_3 = next_data['data']
    label = next_data['label']

    input_noise = tf.placeholder(dtype=tf.float32,
                                 shape=[None, NOISE_LENGTH],
                                 name='input_noise')

    # Two downsampled copies of the real input, each rescaled to [-1, 1].
    real_input_2 = tf.layers.average_pooling2d(inputs=real_input_3,
                                               pool_size=2,
                                               strides=2,
                                               padding='same',
                                               data_format='channels_first',
                                               name='real_input_2')
    real_input_2 -= tf.reduce_min(real_input_2)
    real_input_2 /= (tf.reduce_max(real_input_2) + EPSILON)
    real_input_2 = 2 * real_input_2 - 1

    real_input_1 = tf.layers.average_pooling2d(inputs=real_input_2,
                                               pool_size=2,
                                               strides=2,
                                               padding='same',
                                               data_format='channels_first',
                                               name='real_input_1')
    real_input_1 -= tf.reduce_min(real_input_1)
    real_input_1 /= (tf.reduce_max(real_input_1) + EPSILON)
    real_input_1 = 2 * real_input_1 - 1

    train = tf.placeholder(dtype=tf.bool, name='train')

    # ---- Encoder ----
    encode = encoder(inputs=real_input_3,
                     update_collection=SPECTRAL_UPDATE_OPS,
                     train=train)
    tf.summary.histogram('encode', encode)
    print('Encoder set')

    # Image summaries of the first batch element, one per channel and scale.
    real_input_3_image = tf.expand_dims(real_input_3[:1],
                                        axis=-1,
                                        name='real_input_3_image')
    real_input_2_image = tf.expand_dims(real_input_2[:1],
                                        axis=-1,
                                        name='real_input_2_image')
    real_input_1_image = tf.expand_dims(real_input_1[:1],
                                        axis=-1,
                                        name='real_input_1_image')
    for i in range(CHANNEL_NUM):
        tf.summary.image('real_input_1_%d' % i, real_input_1_image[:, i])
        tf.summary.image('real_input_2_%d' % i, real_input_2_image[:, i])
        tf.summary.image('real_input_3_%d' % i, real_input_3_image[:, i])

    # ---- Generators (each stage feeds the next) ----
    shared_output = shared_gen(noise=input_noise,
                               label=label,
                               update_collection=SPECTRAL_UPDATE_OPS,
                               train=train)
    gen1 = generator1(inputs=shared_output,
                      label=label,
                      update_collection=SPECTRAL_UPDATE_OPS,
                      train=train)
    output_gen1 = process(gen1, 1, train, SPECTRAL_UPDATE_OPS) * label
    gen2 = generator2(inputs=gen1,
                      label=label,
                      update_collection=SPECTRAL_UPDATE_OPS,
                      train=train)
    output_gen2 = process(gen2, 2, train, SPECTRAL_UPDATE_OPS) * label
    gen3 = generator3(inputs=gen2,
                      label=label,
                      update_collection=SPECTRAL_UPDATE_OPS,
                      train=train)
    output_gen3 = process(gen3, 3, train, SPECTRAL_UPDATE_OPS) * label
    print('Generators set')

    # ---- Discriminators (one per scale, on real and generated input) ----
    dis1_real = discriminator1(inputs=real_input_1,
                               encode=encode,
                               update_collection=SPECTRAL_UPDATE_OPS)
    dis1_gen = discriminator1(inputs=output_gen1,
                              encode=encode,
                              update_collection=NO_OPS)
    dis2_real = discriminator2(inputs=real_input_2,
                               encode=encode,
                               update_collection=SPECTRAL_UPDATE_OPS)
    dis2_gen = discriminator2(inputs=output_gen2,
                              encode=encode,
                              update_collection=NO_OPS)
    dis3_real = discriminator3(inputs=real_input_3,
                               encode=encode,
                               update_collection=SPECTRAL_UPDATE_OPS)
    dis3_gen = discriminator3(inputs=output_gen3,
                              encode=encode,
                              update_collection=NO_OPS)
    print('Discriminators set')

    # ---- Losses ----
    loss_dis1 = tf.reduce_mean(dis1_gen - dis1_real) + lipschitz_penalty(
        real=real_input_1,
        gen=output_gen1,
        encode=encode,
        discriminator=discriminator1) + DRIFT * tf.reduce_mean(
            tf.square(dis1_real))
    loss_gen1 = -tf.reduce_mean(dis1_gen)
    loss_dis2 = tf.reduce_mean(dis2_gen - dis2_real) + lipschitz_penalty(
        real=real_input_2,
        gen=output_gen2,
        encode=encode,
        discriminator=discriminator2) + DRIFT * tf.reduce_mean(
            tf.square(dis2_real))
    loss_gen2 = -tf.reduce_mean(dis2_gen)
    loss_dis3 = tf.reduce_mean(dis3_gen - dis3_real) + lipschitz_penalty(
        real=real_input_3,
        gen=output_gen3,
        encode=encode,
        discriminator=discriminator3) + DRIFT * tf.reduce_mean(
            tf.square(dis3_real))
    loss_gen3 = -tf.reduce_mean(dis3_gen)
    loss_gen = tf.add_n([loss_gen1, loss_gen2, loss_gen3]) / 3

    tf.summary.scalar('loss_dis1', loss_dis1)
    tf.summary.scalar('loss_gen1', loss_gen1)
    tf.summary.scalar('dis1_real', tf.reduce_mean(dis1_real))
    tf.summary.scalar('loss_dis2', loss_dis2)
    tf.summary.scalar('loss_gen2', loss_gen2)
    tf.summary.scalar('dis2_real', tf.reduce_mean(dis2_real))
    tf.summary.scalar('loss_dis3', loss_dis3)
    tf.summary.scalar('loss_gen3', loss_gen3)
    tf.summary.scalar('dis3_real', tf.reduce_mean(dis3_real))
    tf.summary.scalar('loss_gen', loss_gen)
    print('Losses set')

    # ---- Variable groups ----
    # The encoder is trained together with every discriminator, not with the
    # generators.
    gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope='Shared_generator')
    for i in range(CHANNEL_NUM):
        gen_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Generator1_%d' % i)
        gen_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Generator2_%d' % i)
        gen_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Generator3_%d' % i)
    dis1_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Discriminator1')
    dis1_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='Encoder')
    dis2_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Discriminator2')
    dis2_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='Encoder')
    dis3_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Discriminator3')
    dis3_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='Encoder')

    # ---- Update-op groups (same scoping as the variables) ----
    gen_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                             scope='Shared_generator')
    for i in range(CHANNEL_NUM):
        gen_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope='Generator1_%d' % i)
        gen_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope='Generator2_%d' % i)
        gen_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope='Generator3_%d' % i)
    dis1_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                              scope='Discriminator1')
    dis1_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='Encoder')
    dis2_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                              scope='Discriminator2')
    dis2_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='Encoder')
    dis3_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                              scope='Discriminator3')
    dis3_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='Encoder')
    spectral_norm_update_ops = tf.get_collection(SPECTRAL_UPDATE_OPS)

    # ---- Optimizers ----
    with tf.name_scope('optimizers'):
        with tf.control_dependencies(dis1_extra_update_ops):
            dis1_train = tf.train.AdamOptimizer(learning_rate=0.0004,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    loss=loss_dis1,
                                                    var_list=dis1_var,
                                                    name='dis1_train')
        print('dis1_train setup')
        with tf.control_dependencies(dis2_extra_update_ops):
            dis2_train = tf.train.AdamOptimizer(learning_rate=0.0004,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    loss=loss_dis2,
                                                    var_list=dis2_var,
                                                    name='dis2_train')
        print('dis2_train setup')
        with tf.control_dependencies(dis3_extra_update_ops):
            dis3_train = tf.train.AdamOptimizer(learning_rate=0.0004,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    loss=loss_dis3,
                                                    var_list=dis3_var,
                                                    name='dis3_train')
        print('dis3_train setup')
        with tf.control_dependencies(gen_extra_update_ops):
            gen_train = tf.train.AdamOptimizer(learning_rate=0.0001,
                                               beta1=0.5,
                                               beta2=0.9).minimize(
                                                   loss=loss_gen,
                                                   var_list=gen_var,
                                                   name='gen_train')
        print('gen_train setup')
    print('Optimizers set')

    gpu_options = tf.GPUOptions(allow_growth=True, allocator_type='BFC')
    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options=gpu_options)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        if tf.train.latest_checkpoint('Checkpoints_v1') is not None:
            print('Restoring...')
            saver.restore(sess, tf.train.latest_checkpoint('Checkpoints_v1'))
        feed_dict = {input_noise: None, train: True}
        print('preparing complete')

        # ---- Sampling mode: generate once from the given song and exit ----
        if sampling:
            feed_dict = {input_noise: None, real_input_3: None, train: False}
            path = args.sample
            try:
                feed_dict[real_input_3] = roll(path)[:BATCH_SIZE]
            except Exception:
                print('Error while opening file.')
                exit()
            feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
            samples = sess.run(output_gen3, feed_dict=feed_dict)

            # Keep only the last path component as the sample name.
            path = path.split('/')[-1]
            sample_dir = 'Samples_v1/sample_%s' % path
            os.makedirs(sample_dir, exist_ok=True)
            np.save(file=sample_dir + '/%s' % path, arr=samples)
            unpack_sample(name=sample_dir + '/%s.npy' % path,
                          concat=args.concat)
            exit()

        # ---- Training mode ----
        writer = tf.summary.FileWriter('Logs_v1', sess.graph)
        run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
        epoch_num = 100001
        for train_count in tqdm(range(epoch_num), smoothing=0.7):
            # Discriminator steps (all three scales in one run call).
            for _ in range(TRAIN_RATIO_DIS):
                feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
                feed_dict[train] = True
                *_, loss_val_dis1, loss_val_dis2, loss_val_dis3 = sess.run(
                    [
                        dis1_train, dis2_train, dis3_train, loss_dis1,
                        loss_dis2, loss_dis3
                    ],
                    feed_dict=feed_dict,
                    options=run_options)
            # Generator steps; the summary of the last one is logged.
            for _ in range(TRAIN_RATIO_GEN):
                feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
                feed_dict[train] = True
                summary, _, loss_val_gen = sess.run(
                    [merged, gen_train, loss_gen],
                    feed_dict=feed_dict,
                    options=run_options)
            sess.run(spectral_norm_update_ops)
            writer.add_summary(summary, train_count)

            tqdm.write('%06d' % train_count, end=' ')
            tqdm.write('Discriminator1 loss : %.7f' % loss_val_dis1, end=' ')
            tqdm.write('Discriminator2 loss : %.7f' % loss_val_dis2, end=' ')
            tqdm.write('Discriminator3 loss : %.7f' % loss_val_dis3, end=' ')
            tqdm.write('Generator loss : %.7f' % loss_val_gen)

            # Every 1000 steps: save a sample and a checkpoint.
            if train_count % 1000 == 0:
                feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
                feed_dict[train] = False
                samples = sess.run(output_gen3, feed_dict=feed_dict)
                np.save(file='Samples_v1/song_%06d' % train_count,
                        arr=samples)
                unpack_sample('Samples_v1/song_%06d' % train_count)
                save_path = saver.save(
                    sess, 'Checkpoints_v1/song_%06d' % train_count + '.ckpt')
                tqdm.write('Model Saved: %s' % save_path)

                # Optional full trace of one training step plus a timeline.
                if args.record:
                    trace_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)  # pylint: disable=E1101
                    run_metadata = tf.RunMetadata()
                    sess.run([dis1_train, dis2_train, dis3_train, gen_train],
                             feed_dict=feed_dict,
                             options=trace_options,
                             run_metadata=run_metadata)
                    writer.add_run_metadata(run_metadata,
                                            'run_%d' % train_count)
                    tl = timeline.Timeline(run_metadata.step_stats)  # pylint: disable=E1101
                    ctf = tl.generate_chrome_trace_format()
                    with open('Timelines_v1/timeline_%d.json' % train_count,
                              'w') as f:
                        f.write(ctf)
        writer.close()
def build_model(self):
    """Build the train or test graph for the label/style-conditioned GAN.

    Loads and preprocesses the dataset (a dedicated loader is used for
    ``celebA``/``celebA-HQ``), then:

    * train phase: builds the input pipeline, generator and discriminator,
      their losses, Adam training ops, the list of images generated for each
      fixed label (``self.x_fake_list``) and the loss summaries;
    * any other phase: builds the placeholders and the generator outputs
      used for testing.
    """
    is_celeba = self.dataset_name in ('celebA-HQ', 'celebA')

    def batched_iterator(slices, shuffle_buffer):
        # Shuffle/repeat -> process and batch -> prefetch to the first GPU.
        pipeline = slices.apply(shuffle_and_repeat(shuffle_buffer))
        pipeline = pipeline.apply(
            map_and_batch(img_class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        pipeline = pipeline.apply(prefetch_to_device('/gpu:0', None))
        return pipeline.make_one_shot_iterator()

    label_fix_onehot_list = []

    # ---- Input images ----
    if is_celeba:
        img_class = ImageData_celebA(self.img_height, self.img_width,
                                     self.img_ch, self.dataset_path,
                                     self.label_list, self.augment_flag)
        img_class.preprocess(self.phase)
    else:
        img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                               self.dataset_path, self.label_list,
                               self.augment_flag)
        img_class.preprocess()

        # Insert a batch axis and repeat each fixed label batch_size times.
        label_fix_onehot_list = img_class.label_onehot_list
        label_fix_onehot_list = tf.tile(
            tf.expand_dims(label_fix_onehot_list, axis=1),
            [1, self.batch_size, 1])

    dataset_num = len(img_class.image)
    print("Dataset number : ", dataset_num)

    if self.phase == 'train':
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        if is_celeba:
            train_slices = (img_class.image, img_class.label,
                            img_class.train_label_onehot_list)
        else:
            train_slices = (img_class.image, img_class.label)
        img_and_label = tf.data.Dataset.from_tensor_slices(train_slices)
        img_and_label_iterator = batched_iterator(img_and_label, dataset_num)

        if is_celeba:
            (self.x_real, label_org,
             label_fix_onehot_list) = img_and_label_iterator.get_next()
            # Target domain labels.
            label_trg = tf.random_shuffle(label_org)
            label_fix_onehot_list = tf.transpose(label_fix_onehot_list,
                                                 perm=[1, 0, 2])
        else:
            self.x_real, label_org = img_and_label_iterator.get_next()
            # Target domain labels.
            label_trg = tf.random_shuffle(label_org)

        # ---- Generator and discriminator ----
        style_shape = [self.batch_size, self.style_dim]

        fake_style_code = tf.random_normal(shape=style_shape)
        x_fake = self.generator(self.x_real, label_trg, fake_style_code)

        # Reconstruction: translate the fake image back to the source label.
        recon_style_code = tf.random_normal(shape=style_shape)
        x_recon = self.generator(x_fake, label_org, recon_style_code,
                                 reuse=True)

        real_logit, real_cls, _ = self.discriminator(self.x_real)
        fake_logit, fake_cls, fake_noise = self.discriminator(x_fake,
                                                              reuse=True)

        # ---- Losses ----
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
        else:
            GP = 0

        g_adv_loss = self.adv_weight * generator_loss(self.gan_type,
                                                      fake_logit)
        g_cls_loss = self.cls_weight * classification_loss(logit=fake_cls,
                                                           label=label_trg)
        g_rec_loss = self.rec_weight * L1_loss(self.x_real, x_recon)
        g_noise_loss = self.noise_weight * L1_loss(fake_style_code,
                                                   fake_noise)

        d_adv_loss = self.adv_weight * discriminator_loss(
            self.gan_type, real_logit, fake_logit) + GP
        d_cls_loss = self.cls_weight * classification_loss(logit=real_cls,
                                                           label=label_org)
        d_noise_loss = self.noise_weight * L1_loss(fake_style_code,
                                                   fake_noise)

        self.d_loss = d_adv_loss + d_cls_loss + d_noise_loss
        self.g_loss = g_adv_loss + g_cls_loss + g_rec_loss + g_noise_loss

        # ---- Result images ----
        # One entry per style: the real batch rendered with every fixed label
        # under a fresh random style code. The procedure is the same for all
        # datasets.
        self.x_fake_list = []
        for _ in range(self.num_style):
            random_style_code = tf.random_normal(shape=style_shape)
            per_label_images = tf.map_fn(
                lambda c: self.generator(
                    self.x_real, c, random_style_code, reuse=True),
                label_fix_onehot_list,
                dtype=tf.float32)
            self.x_fake_list.append(per_label_images)

        # ---- Training ops ----
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.g_optimizer = g_adam.minimize(self.g_loss, var_list=G_vars)
        d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.d_optimizer = d_adam.minimize(self.d_loss, var_list=D_vars)

        # ---- Summaries ----
        self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
        self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
        self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)
        self.g_noise_loss = tf.summary.scalar("g_noise_loss", g_noise_loss)

        self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
        self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)
        self.d_noise_loss = tf.summary.scalar("d_noise_loss", d_noise_loss)

        self.g_summary_loss = tf.summary.merge([
            self.Generator_loss, self.g_adv_loss, self.g_cls_loss,
            self.g_rec_loss, self.g_noise_loss
        ])
        self.d_summary_loss = tf.summary.merge([
            self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss,
            self.d_noise_loss
        ])

    else:
        # ---- Test graph ----
        single_image_shape = [
            1, self.img_height, self.img_width, self.img_ch
        ]

        if is_celeba:
            img_and_label = tf.data.Dataset.from_tensor_slices(
                (img_class.test_image, img_class.test_label,
                 img_class.test_label_onehot_list))
            dataset_num = len(img_class.test_image)
            img_and_label_iterator = batched_iterator(img_and_label,
                                                      dataset_num)

            (self.x_test, _, self.test_label_fix_onehot_list
             ) = img_and_label_iterator.get_next()

            self.test_img_placeholder = tf.placeholder(
                tf.float32, single_image_shape)
            self.test_label_fix_placeholder = tf.placeholder(
                tf.float32, [self.c_dim, 1, self.c_dim])

            # Custom image
            self.custom_image = tf.placeholder(tf.float32,
                                               single_image_shape,
                                               name='custom_image')
            # [c_dim, bs, c_dim]
            custom_label_fix_onehot_list = tf.transpose(
                np.expand_dims(label2onehot(self.label_list), axis=0),
                perm=[1, 0, 2])

            # Test image: one style code shared by both outputs below.
            test_random_style_code = tf.random_normal(
                shape=[1, self.style_dim])

            self.x_test_fake_list = tf.map_fn(
                lambda c: self.generator(self.test_img_placeholder, c,
                                         test_random_style_code),
                self.test_label_fix_placeholder,
                dtype=tf.float32)

            self.custom_fake_image = tf.map_fn(
                lambda c: self.generator(
                    self.custom_image, c, test_random_style_code, reuse=True),
                custom_label_fix_onehot_list,
                dtype=tf.float32)
        else:
            # Custom image
            self.custom_image = tf.placeholder(tf.float32,
                                               single_image_shape,
                                               name='custom_image')
            # [c_dim, bs, c_dim]
            custom_label_fix_onehot_list = tf.transpose(
                np.expand_dims(label2onehot(self.label_list), axis=0),
                perm=[1, 0, 2])

            test_random_style_code = tf.random_normal(
                shape=[1, self.style_dim])

            self.custom_fake_image = tf.map_fn(
                lambda c: self.generator(self.custom_image, c,
                                         test_random_style_code),
                custom_label_fix_onehot_list,
                dtype=tf.float32)
def build_model(self):
    """Build the training graph: inputs, networks, losses, optimizers, summaries.

    Everything is stored on ``self``; nothing is returned. The method creates
    the learning-rate placeholder, one GPU-prefetched input pipeline per
    domain (A and B), the generator/discriminator outputs, the weighted
    losses, one Adam optimizer per network and the scalar summaries.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # ---- Input images ----
    preprocessor = ImageData(self.img_size, self.img_ch, self.augment_flag)

    trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

    gpu_device = '/gpu:0'

    def batched_on_gpu(dataset):
        # shuffle + repeat -> preprocess + batch -> prefetch onto the GPU
        dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
        dataset = dataset.apply(
            map_and_batch(preprocessor.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        return dataset.apply(
            prefetch_to_device(gpu_device, self.batch_size))

    # Both pipelines are built before either iterator, as in the
    # original statement order.
    trainA = batched_on_gpu(trainA)
    trainB = batched_on_gpu(trainB)

    iterator_a = trainA.make_one_shot_iterator()
    iterator_b = trainB.make_one_shot_iterator()

    # Three separate get_next() ops on the A iterator, one on the B iterator.
    self.identity_A = iterator_a.get_next()
    self.shape_A = iterator_a.get_next()
    self.other_A = iterator_a.get_next()
    self.shape_B = iterator_b.get_next()

    # Single-image placeholders used by the test-time generator below.
    single_image = [1, self.img_size, self.img_size, self.img_ch]
    self.test_identity_A = tf.placeholder(tf.float32,
                                          single_image,
                                          name='test_identity_A')
    self.test_shape_B = tf.placeholder(tf.float32,
                                       single_image,
                                       name='test_shape_B')

    # ---- Generator / discriminator ----
    self.fake_same = self.generator(x_identity=self.identity_A,
                                    x_shape=self.shape_A)
    self.fake_diff = self.generator(x_identity=self.identity_A,
                                    x_shape=self.shape_B,
                                    reuse=True)
    # Feed the generated image back in, once as the shape input and once
    # as the identity input.
    regen_as_shape = self.generator(x_identity=self.shape_B,
                                    x_shape=self.fake_diff,
                                    reuse=True)
    regen_as_identity = self.generator(x_identity=self.fake_diff,
                                       x_shape=self.shape_B,
                                       reuse=True)

    real_logit = self.discriminator(x_identity=self.identity_A,
                                    x=self.other_A)
    fake_logit = self.discriminator(x_identity=self.identity_A,
                                    x=self.fake_diff,
                                    reuse=True)

    # ---- Losses ----
    # The adversarial generator term carries a hard-coded factor of 64.
    g_identity_loss = self.adv_weight * generator_loss(
        self.gan_type, fake_logit) * 64
    g_shape_loss_same = self.L1_weight * L1_loss(self.fake_same,
                                                 self.shape_A)
    g_shape_loss_diff_shape = (
        self.L1_weight * L1_loss(regen_as_shape, self.shape_B) +
        self.L1_weight * L1_loss(self.fake_diff, self.shape_B))
    g_shape_loss_diff_identity = self.L1_weight * L1_loss(
        regen_as_identity, self.fake_diff)

    self.Generator_loss = (g_identity_loss + g_shape_loss_same +
                           g_shape_loss_diff_shape +
                           g_shape_loss_diff_identity)
    self.Discriminator_loss = self.adv_weight * discriminator_loss(
        self.gan_type, real=real_logit, fake=fake_logit)

    # ---- Result image ----
    self.test_fake = self.generator(x_identity=self.test_identity_A,
                                    x_shape=self.test_shape_B,
                                    reuse=True)

    # ---- Training ops ----
    # Variables are assigned to a network by substring match on their name.
    trainable = tf.trainable_variables()
    generator_vars = [v for v in trainable if 'generator' in v.name]
    discriminator_vars = [v for v in trainable if 'discriminator' in v.name]

    g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.G_optim = g_adam.minimize(self.Generator_loss,
                                   var_list=generator_vars)
    d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.D_optim = d_adam.minimize(self.Discriminator_loss,
                                   var_list=discriminator_vars)

    # ---- Summaries ----
    self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
    self.D_loss = tf.summary.scalar("Discriminator_loss",
                                    self.Discriminator_loss)

    self.G_identity = tf.summary.scalar("G_identity", g_identity_loss)
    self.G_shape_loss_same = tf.summary.scalar("G_shape_loss_same",
                                               g_shape_loss_same)
    self.G_shape_loss_diff_shape = tf.summary.scalar(
        "G_shape_loss_diff_shape", g_shape_loss_diff_shape)
    self.G_shape_loss_diff_identity = tf.summary.scalar(
        "G_shape_loss_diff_identity", g_shape_loss_diff_identity)

    generator_summaries = [
        self.G_loss,
        self.G_identity,
        self.G_shape_loss_same,
        self.G_shape_loss_diff_shape,
        self.G_shape_loss_diff_identity,
    ]
    self.G_loss_merge = tf.summary.merge(generator_summaries)
    self.D_loss_merge = tf.summary.merge([self.D_loss])
def get_squad_dataset(opts, is_training):
    """Return a batched SQuAD ``tf.data`` dataset, building the TFRecord cache if absent.

    The cache file lives in ``opts['tfrecord_dir']`` and its name encodes the
    sequence length, doc stride, max query length, SQuAD version and (for
    eval) the global batch size. When the file does not exist the raw
    examples are tokenized and written to it, and the number of converted
    features is recorded in a sibling ``.metadata`` file if that file is not
    already present.

    Args:
        opts: option mapping. Reads 'seq_length', 'micro_batch_size',
            'tfrecord_dir', 'replicas', 'gradient_accumulation_count',
            'version_2_with_negative', 'doc_stride', 'max_query_length',
            'train_file' or 'predict_file', 'generated_data', and when
            training 'distributed_worker_count' (plus
            'distributed_worker_index' if that count exceeds 1). Building the
            cache additionally reads 'vocab_file' and 'do_lower_case'.
            Side effect: ``opts['global_batch_size']`` is (re)written.
        is_training: selects the train or the eval file, feature spec and
            pipeline (sharding/shuffling/repeating only happen in training).

    Returns:
        The dataset after ``map_and_batch`` with ``micro_batch_size`` and
        ``drop_remainder=True``.

    Raises:
        Whatever the example reader, tokenizer, feature conversion or file
        system raise. If feature conversion fails, the partially written
        cache file is removed before the exception propagates.
    """
    seq_length = opts['seq_length']

    name_to_features = {
        "unique_ids": tf.FixedLenFeature([], tf.int64),
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
    }
    if is_training:
        name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
        name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)

    micro_batch_size = opts['micro_batch_size']
    tfrecord_dir = opts['tfrecord_dir']
    opts['global_batch_size'] = (opts['replicas'] *
                                 opts['micro_batch_size'] *
                                 opts['gradient_accumulation_count'])

    squad_tag = "SQuAD20" if opts["version_2_with_negative"] else "SQuAD11"
    name_prefix = (f"{opts['seq_length']}_{opts['doc_stride']}"
                   f"_{opts['max_query_length']}")
    base_name_train = f"{name_prefix}_{squad_tag}"
    # The eval cache is padded to the global batch size, so that size is
    # part of its name.
    base_name_eval = f"{name_prefix}_{opts['global_batch_size']}_{squad_tag}"

    if is_training:
        filename = os.path.join(tfrecord_dir,
                                base_name_train + "_train.tfrecord")
        input_file = opts['train_file']
    else:
        filename = os.path.join(tfrecord_dir,
                                base_name_eval + "_eval.tfrecord")
        input_file = opts['predict_file']

    if not os.path.exists(filename):
        tf.logging.info(f'Preprocessing SQuAD input file: {filename}')

        cache_path = os.path.dirname(filename)
        # An empty dirname means "current directory": nothing to create.
        if cache_path and not os.path.exists(cache_path):
            tf.logging.info(f'Creating SQuAD cache in {cache_path}')
            # exist_ok guards against another process creating it meanwhile.
            os.makedirs(cache_path, exist_ok=True)

        # Do everything that can fail cheaply *before* the writer is
        # created, so a bad input or vocab path never leaves a cache file
        # behind.
        examples = read_squad_examples(input_file=input_file,
                                       opts=opts,
                                       is_training=is_training)
        tokenizer = tokenization.FullTokenizer(
            vocab_file=opts['vocab_file'],
            do_lower_case=opts['do_lower_case'])

        if is_training:
            # Don't need to pad for repeated dataset
            padding_to = 1
        else:
            # Padding for pipeline depth
            padding_to = (opts['replicas'] * opts['micro_batch_size'] *
                          opts['gradient_accumulation_count'])

        writer = FeatureWriter(filename=filename, is_training=is_training)

        def write_feature(feature):
            # Stream each feature straight to disk; nothing is kept in
            # memory because no caller reads the features back.
            writer.process_feature(feature)

        conversion_done = False
        try:
            num_of_features = convert_examples_to_features(
                examples=examples,
                tokenizer=tokenizer,
                max_seq_length=opts['seq_length'],
                doc_stride=opts['doc_stride'],
                max_query_length=opts['max_query_length'],
                is_training=is_training,
                output_fn=write_feature,
                padding_to=padding_to)
            conversion_done = True
        finally:
            writer.close()
            # A truncated cache would be picked up silently by the
            # `os.path.exists(filename)` check on the next call.
            if not conversion_done and os.path.exists(filename):
                os.remove(filename)

        split = "train" if is_training else "eval"
        base_name = base_name_train if is_training else base_name_eval
        metadatafile = os.path.join(tfrecord_dir,
                                    split + "_" + base_name + ".metadata")
        if not os.path.exists(metadatafile):
            tf.logging.info(
                f'Logging converted no. of SQuAD {split} features in {metadatafile}'
            )
            with open(metadatafile, 'w') as f:
                f.write(str(num_of_features) + '\n')

    d = tf.data.TFRecordDataset(filename)
    if is_training:
        if opts['distributed_worker_count'] > 1:
            d = d.shard(num_shards=opts['distributed_worker_count'],
                        index=opts['distributed_worker_index'])
        d = d.shuffle(buffer_size=100000, reshuffle_each_iteration=False)
        d = d.repeat()
    if opts['generated_data']:
        d = d.repeat()

    d = d.apply(
        map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=micro_batch_size,
            drop_remainder=True))
    return d
def build_model(self, A):
    """Build the training/test graph of the warp-based attribute model.

    Stores on ``self`` the learning-rate placeholder, the real/test input
    tensors, the custom-image placeholder, both losses, the generated image
    lists, the two Adam train ops and the merged summaries. Returns nothing.

    Args:
        A: tensor that is warped with the generator's forward flow and then
            with the reconstruction flow; the squared difference between the
            twice-warped result and ``A`` is the reconstruction loss.
            NOTE(review): its expected shape is not visible here; the mean
            over axes [1, 2, 3] suggests a 4-D tensor - confirm with caller.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # ---- Input images ----
    data = ImageData(load_size=self.img_size,
                     channels=self.img_ch,
                     data_path=self.dataset_path,
                     selected_attrs=self.selected_attrs,
                     augment_flag=self.augment_flag)
    data.preprocess()

    train_dataset_num = len(data.train_dataset)
    test_dataset_num = len(data.test_dataset)

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (data.train_dataset,
         data.train_dataset_label,
         data.train_dataset_fix_label))
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (data.test_dataset,
         data.test_dataset_label,
         data.test_dataset_fix_label))

    gpu_device = '/gpu:0'

    def batched_on_gpu(dataset, num_items):
        # shuffle + repeat -> preprocess + batch -> prefetch onto the GPU
        dataset = dataset.apply(shuffle_and_repeat(num_items))
        dataset = dataset.apply(
            map_and_batch(data.image_processing,
                          self.batch_size,
                          num_parallel_batches=8,
                          drop_remainder=True))
        return dataset.apply(
            prefetch_to_device(gpu_device, self.batch_size))

    train_dataset = batched_on_gpu(train_dataset, train_dataset_num)
    test_dataset = batched_on_gpu(test_dataset, test_dataset_num)

    train_iterator = train_dataset.make_one_shot_iterator()
    test_iterator = test_dataset.make_one_shot_iterator()

    # Input image / original domain labels / per-attribute fixed labels
    self.x_real, label_org, label_fix_list = train_iterator.get_next()
    # Target domain labels. Not consumed below; kept so the graph is unchanged.
    label_trg = tf.random_shuffle(label_org)
    label_fix_list = tf.transpose(label_fix_list, perm=[1, 0, 2])

    self.x_test, test_label_org, test_label_fix_list = test_iterator.get_next()
    test_label_fix_list = tf.transpose(test_label_fix_list, perm=[1, 0, 2])

    # Custom image
    self.custom_image = tf.placeholder(
        tf.float32,
        [1, self.img_size, self.img_size, self.img_ch],
        name='custom_image')
    custom_label_fix_list = tf.transpose(
        create_labels(self.custom_label, self.selected_attrs),
        perm=[1, 0, 2])

    # ---- Generator / discriminator ----
    # Binary label transformation: categorical samples in {0, 1, 2} with
    # probabilities (0.25, 0.5, 0.25), shifted by -1 to {-1, 0, 1}.
    label_shift = tf_contrib.distributions.Categorical(probs=[0.25, 0.5, 0.25])
    hat_c = tf.cast(label_shift.sample([self.batch_size, self.c_dim]) - 1,
                    dtype='float32')

    x_fake, w_fake = self.generator(self.x_real, hat_c)  # real a
    x_recon, w_recon = self.generator(x_fake, -hat_c, reuse=True)  # real b

    real_logit, real_cls = self.discriminator(self.x_real)
    fake_logit, fake_cls = self.discriminator(x_fake, reuse=True)

    # Warp cycle: apply the forward flow to A, then the reverse flow.
    A_fake = tf_contrib.image.dense_image_warp(A, w_fake)
    A_cycle = tf_contrib.image.dense_image_warp(A_fake, w_recon)

    # ---- Losses ----
    needs_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    if needs_penalty:
        GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
    else:
        GP = 0

    g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit)
    g_cls_loss = binary_label_loss(hat_c, fake_cls, self.c_dim)

    warp_cycle_loss = tf.reduce_mean((A_cycle - A) ** 2, axis=[1, 2, 3])
    g_rec_loss = tf.reduce_mean(warp_cycle_loss)

    smooth_loss = tf.reduce_mean(total_variation(w_fake))

    d_adv_loss = discriminator_loss(loss_func=self.gan_type,
                                    real=real_logit,
                                    fake=fake_logit)
    d_cls_loss = classification_loss(logit=real_cls, label=label_org)

    self.d_loss = (self.adv_weight * d_adv_loss +
                   self.gp_weight * GP +
                   self.cls_weight * d_cls_loss)
    self.g_loss = (self.adv_weight * g_adv_loss +
                   self.cls_weight * g_cls_loss +
                   self.rec_weight * g_rec_loss +
                   self.smooth_weight * smooth_loss)

    # ---- Result images ----
    # The generator returns (image, flow); only the image is kept here.
    self.x_fake_list = tf.map_fn(
        lambda x: self.generator(self.x_real, x, reuse=True)[0],
        label_fix_list,
        dtype=tf.float32)

    # ---- Test images ----
    self.x_test_fake_list = tf.map_fn(
        lambda x: self.generator(self.x_test, x, reuse=True)[0],
        test_label_fix_list,
        dtype=tf.float32)
    self.custom_fake_image = tf.map_fn(
        lambda x: self.generator(self.custom_image, x, reuse=True)[0],
        custom_label_fix_list,
        dtype=tf.float32)

    # ---- Training ops ----
    trainable = tf.trainable_variables()
    generator_vars = [v for v in trainable if 'generator' in v.name]
    discriminator_vars = [v for v in trainable if 'discriminator' in v.name]

    g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.g_optimizer = g_adam.minimize(self.g_loss, var_list=generator_vars)
    d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.d_optimizer = d_adam.minimize(self.d_loss,
                                       var_list=discriminator_vars)

    # ---- Summaries ----
    self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss)
    self.Discriminator_loss = tf.summary.scalar("Discriminator_loss",
                                                self.d_loss)

    self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
    self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
    self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)

    self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
    self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)

    self.g_summary_loss = tf.summary.merge([
        self.Generator_loss,
        self.g_adv_loss,
        self.g_cls_loss,
        self.g_rec_loss,
    ])
    self.d_summary_loss = tf.summary.merge([
        self.Discriminator_loss,
        self.d_adv_loss,
        self.d_cls_loss,
    ])
def build_model(self):
    """Build the training/test graph of the label-conditioned translation model.

    Stores on ``self`` the learning-rate placeholder, input tensors and
    labels, the custom-image placeholder, both losses, the generated image
    lists, the two Adam train ops and the merged summaries. Returns nothing.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # ---- Input images ----
    data = ImageData(load_size=self.img_size,
                     channels=self.img_ch,
                     data_path=self.dataset_path,
                     dataset_name=self.dataset_name,
                     selected_attrs=self.selected_attrs,
                     augment_flag=self.augment_flag)
    data.preprocess()

    train_dataset_num = len(data.train_dataset)
    test_dataset_num = len(data.test_dataset)

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (data.train_dataset,
         data.train_dataset_label,
         data.train_dataset_fix_label))
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (data.test_dataset,
         data.test_dataset_label,
         data.test_dataset_fix_label))

    gpu_device = '/gpu:0'

    def batched_on_gpu(dataset, num_items):
        # shuffle + repeat -> preprocess + batch -> prefetch onto the GPU
        dataset = dataset.apply(shuffle_and_repeat(num_items))
        dataset = dataset.apply(
            map_and_batch(data.image_processing,
                          self.batch_size,
                          num_parallel_batches=8,
                          drop_remainder=True))
        return dataset.apply(
            prefetch_to_device(gpu_device, self.batch_size))

    train_dataset = batched_on_gpu(train_dataset, train_dataset_num)
    test_dataset = batched_on_gpu(test_dataset, test_dataset_num)

    train_iterator = train_dataset.make_one_shot_iterator()
    test_iterator = test_dataset.make_one_shot_iterator()

    # Input image / original domain labels / per-attribute fixed labels
    self.x_real, label_org, label_fix_list = train_iterator.get_next()
    # Target domain labels: the original labels, shuffled.
    label_trg = tf.random_shuffle(label_org)
    label_fix_list = tf.transpose(label_fix_list, perm=[1, 0, 2])

    # Keep the label tensors reachable from outside the method.
    self.label_org = label_org
    self.label_trg = label_trg
    self.label_fix_list = label_fix_list

    self.x_test, test_label_org, test_label_fix_list = test_iterator.get_next()
    test_label_fix_list = tf.transpose(test_label_fix_list, perm=[1, 0, 2])

    # Custom image
    self.custom_image = tf.placeholder(
        tf.float32,
        [1, self.img_size, self.img_size, self.img_ch],
        name='custom_image')
    custom_label_fix_list = tf.transpose(
        create_labels(self.custom_label, self.selected_attrs,
                      self.dataset_name),
        perm=[1, 0, 2])

    # ---- Generator / discriminator ----
    x_fake = self.generator(self.x_real, label_trg)  # real a
    x_recon = self.generator(x_fake, label_org, reuse=True)  # real b

    real_logit, real_cls = self.discriminator(self.x_real)
    fake_logit, fake_cls = self.discriminator(x_fake, reuse=True)

    # ---- Losses ----
    needs_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    if needs_penalty:
        GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
    else:
        GP = 0

    g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit)
    g_cls_loss = classification_loss(logit=fake_cls, label=label_trg)
    g_rec_loss = L1_loss(self.x_real, x_recon)

    # The gradient penalty is folded into the adversarial term, so it is
    # scaled by adv_weight together with it.
    d_adv_loss = discriminator_loss(loss_func=self.gan_type,
                                    real=real_logit,
                                    fake=fake_logit) + GP
    d_cls_loss = classification_loss(logit=real_cls, label=label_org)

    self.d_loss = (self.adv_weight * d_adv_loss +
                   self.cls_weight * d_cls_loss)
    self.g_loss = (self.adv_weight * g_adv_loss +
                   self.cls_weight * g_cls_loss +
                   self.rec_weight * g_rec_loss)

    # ---- Result images ----
    self.x_fake_list = tf.map_fn(
        lambda x: self.generator(self.x_real, x, reuse=True),
        label_fix_list,
        dtype=tf.float32)

    # ---- Test images ----
    self.x_test_fake_list = tf.map_fn(
        lambda x: self.generator(self.x_test, x, reuse=True),
        test_label_fix_list,
        dtype=tf.float32)
    self.custom_fake_image = tf.map_fn(
        lambda x: self.generator(self.custom_image, x, reuse=True),
        custom_label_fix_list,
        dtype=tf.float32)

    # ---- Training ops ----
    trainable = tf.trainable_variables()
    generator_vars = [v for v in trainable if 'generator' in v.name]
    discriminator_vars = [v for v in trainable if 'discriminator' in v.name]

    g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.g_optimizer = g_adam.minimize(self.g_loss, var_list=generator_vars)
    d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.d_optimizer = d_adam.minimize(self.d_loss,
                                       var_list=discriminator_vars)

    # ---- Summaries ----
    self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss)
    self.Discriminator_loss = tf.summary.scalar("Discriminator_loss",
                                                self.d_loss)

    self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
    self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
    self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)

    self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
    self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)

    self.g_summary_loss = tf.summary.merge([
        self.Generator_loss,
        self.g_adv_loss,
        self.g_cls_loss,
        self.g_rec_loss,
    ])
    self.d_summary_loss = tf.summary.merge([
        self.Discriminator_loss,
        self.d_adv_loss,
        self.d_cls_loss,
    ])
def build_model(self):
    """Build either the multi-GPU training graph or the single-image test graph.

    With ``self.phase == 'train'`` the method builds one input pipeline,
    splits each batch across ``self.gpu_num`` towers, averages the per-tower
    losses, creates RMSProp train ops (the generator op is followed by an
    exponential-moving-average update of its variables) and the summaries.
    Otherwise it builds placeholders and the generator output for testing.
    Results are stored on ``self``; nothing is returned.

    NOTE(review): ``self.lr`` is read but not created here - presumably it
    is set up elsewhere in the class; confirm.
    """
    if self.phase == 'train':
        # ---- Input images ----
        data = Image_data(self.img_height, self.img_width, self.img_ch,
                          self.dataset_path, self.augment_flag)
        data.preprocess()

        self.dataset_num = len(data.image_list)

        img_and_class = tf.data.Dataset.from_tensor_slices(
            (data.image_list, data.class_list))

        gpu_device = '/gpu:0'
        img_and_class = img_and_class.apply(
            shuffle_and_repeat(self.dataset_num))
        # One batch holds the samples for every tower.
        img_and_class = img_and_class.apply(
            map_and_batch(data.image_processing,
                          batch_size=self.batch_size * self.gpu_num,
                          num_parallel_batches=16,
                          drop_remainder=True))
        img_and_class = img_and_class.apply(
            prefetch_to_device(gpu_device, None))

        batch_iterator = img_and_class.make_one_shot_iterator()

        # Two separate get_next() ops: one for content, one for style.
        content_batch, content_labels = batch_iterator.get_next()
        style_batch, style_labels = batch_iterator.get_next()

        # Split every tensor into one slice per GPU.
        self.content_img = tf.split(content_batch,
                                    num_or_size_splits=self.gpu_num)
        self.content_class = tf.split(content_labels,
                                      num_or_size_splits=self.gpu_num)
        self.style_img = tf.split(style_batch,
                                  num_or_size_splits=self.gpu_num)
        self.style_class = tf.split(style_labels,
                                    num_or_size_splits=self.gpu_num)

        self.fake_img = []

        d_adv_losses = []
        g_adv_losses = []
        g_recon_losses = []
        g_feature_losses = []

        def spatial_mean(feature_map):
            # Mean over axis 2, then over axis 1.
            return tf.reduce_mean(tf.reduce_mean(feature_map, axis=2),
                                  axis=1)

        for gpu_id in range(self.gpu_num):
            tower_device = tf.DeviceSpec(device_type="GPU",
                                         device_index=gpu_id)
            # Variables are created by tower 0 and reused by the others.
            with tf.device(tower_device), \
                    tf.variable_scope(tf.get_variable_scope(),
                                      reuse=(gpu_id > 0)):
                content_x = self.content_img[gpu_id]
                content_y = self.content_class[gpu_id]
                style_x = self.style_img[gpu_id]
                style_y = self.style_class[gpu_id]

                # ---- Generator / discriminator ----
                content_code = self.content_encoder(content_x)
                style_class_code = self.class_encoder(style_x)
                content_class_code = self.class_encoder(content_x)

                fake_img = self.generator(content_code, style_class_code)
                recon_img = self.generator(content_code, content_class_code)

                real_logit, style_feature_map = self.discriminator(
                    style_x, style_y)
                fake_logit, fake_feature_map = self.discriminator(
                    fake_img, style_y)

                recon_logit, recon_feature_map = self.discriminator(
                    recon_img, content_y)
                _, content_feature_map = self.discriminator(
                    content_x, content_y)

                # ---- Losses ----
                d_adv_loss = self.adv_weight * discriminator_loss(
                    self.gan_type, real_logit, fake_logit, style_x)
                g_adv_loss = 0.5 * self.adv_weight * (
                    generator_loss(self.gan_type, fake_logit) +
                    generator_loss(self.gan_type, recon_logit))

                g_recon_loss = self.recon_weight * L1_loss(
                    content_x, recon_img)

                content_feature_map = spatial_mean(content_feature_map)
                recon_feature_map = spatial_mean(recon_feature_map)
                fake_feature_map = spatial_mean(fake_feature_map)
                style_feature_map = spatial_mean(style_feature_map)

                g_feature_loss = self.feature_weight * (
                    L1_loss(recon_feature_map, content_feature_map) +
                    L1_loss(fake_feature_map, style_feature_map))

                d_adv_losses.append(d_adv_loss)
                g_adv_losses.append(g_adv_loss)
                g_recon_losses.append(g_recon_loss)
                g_feature_losses.append(g_feature_loss)

                self.fake_img.append(fake_img)

        self.g_loss = (tf.reduce_mean(g_adv_losses) +
                       tf.reduce_mean(g_recon_losses) +
                       tf.reduce_mean(g_feature_losses) +
                       regularization_loss('encoder') +
                       regularization_loss('generator'))

        self.d_loss = (tf.reduce_mean(d_adv_losses) +
                       regularization_loss('discriminator'))

        # ---- Training ops ----
        trainable = tf.trainable_variables()
        G_vars = [
            v for v in trainable
            if 'encoder' in v.name or 'generator' in v.name
        ]
        D_vars = [v for v in trainable if 'discriminator' in v.name]

        # Only the non-single-GPU case asks for gradient colocation; the
        # single-GPU case passes no extra keyword at all.
        if self.gpu_num == 1:
            minimize_extra = {}
        else:
            minimize_extra = {'colocate_gradients_with_ops': True}

        # Pytorch : decay=0.99, epsilon=1e-8
        g_rmsprop = tf.train.RMSPropOptimizer(self.lr,
                                              decay=0.99,
                                              epsilon=1e-8)
        prev_G_optim = g_rmsprop.minimize(self.g_loss,
                                          var_list=G_vars,
                                          **minimize_extra)
        d_rmsprop = tf.train.RMSPropOptimizer(self.lr,
                                              decay=0.99,
                                              epsilon=1e-8)
        self.D_optim = d_rmsprop.minimize(self.d_loss,
                                          var_list=D_vars,
                                          **minimize_extra)

        # The generator train op is the EMA update, gated on the RMSProp
        # step having run.
        self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)
        with tf.control_dependencies([prev_G_optim]):
            self.G_optim = self.ema.apply(G_vars)

        # ---- Summaries ----
        self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.summary_g_adv_loss = tf.summary.scalar(
            "g_adv_loss", tf.reduce_mean(g_adv_losses))
        self.summary_g_recon_loss = tf.summary.scalar(
            "g_recon_loss", tf.reduce_mean(g_recon_losses))
        self.summary_g_feature_loss = tf.summary.scalar(
            "g_feature_loss", tf.reduce_mean(g_feature_losses))

        g_summary_list = [
            self.summary_g_loss,
            self.summary_g_adv_loss,
            self.summary_g_recon_loss,
            self.summary_g_feature_loss,
        ]
        d_summary_list = [self.summary_d_loss]

        self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
        self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

    else:
        # ---- Test ----
        self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)

        image_dims = [self.img_height, self.img_width, self.img_ch]
        self.test_content_img = tf.placeholder(tf.float32, [1] + image_dims)
        self.test_class_img = tf.placeholder(tf.float32,
                                             [self.K] + image_dims)

        test_content_code = self.content_encoder(self.test_content_img)
        # Average the class codes of the K class images into a single code.
        test_style_class_code = tf.reduce_mean(
            self.class_encoder(self.test_class_img),
            axis=0,
            keepdims=True)

        self.test_fake_img = self.generator(test_content_code,
                                            test_style_class_code)
def get_pretraining_dataset(opts, data_type, is_training=True, num_cpu_threads=4, use_static_mask=False):
    """Build the batched, repeating pre-training dataset from TFRecord files.

    Args:
        opts: option mapping. Reads 'train_file' (training) or 'test_file'
            (otherwise) as a comma-separated list of glob patterns, plus
            'micro_batch_size', 'seq_length' and 'max_predictions_per_seq'.
            In training it also reads 'distributed_worker_count' and, when
            that exceeds 1, 'distributed_worker_index' and 'seed'.
        data_type: forwarded unchanged to ``_decode_record``.
        is_training: training reads files through a sloppy parallel
            interleave and shuffles; otherwise files are read in order.
        num_cpu_threads: upper bound for the interleave cycle length and
            the ``num_parallel_batches`` of ``map_and_batch``.
        use_static_mask: selects the feature spec for records whose masked
            tokens were moved to the first ``max_predictions_per_seq``
            positions, instead of the default spec.

    Returns:
        The dataset after ``map_and_batch`` with ``micro_batch_size`` and
        ``drop_remainder=True``.

    Raises:
        ValueError: if none of the glob patterns matches a file.
    """
    file_patterns = opts['train_file'] if is_training else opts['test_file']

    micro_batch_size = opts['micro_batch_size']
    max_seq_length = opts['seq_length']
    max_predictions_per_seq = opts['max_predictions_per_seq']

    input_files = []
    for input_pattern in file_patterns.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    # Fail here with a readable message rather than deep inside TensorFlow
    # (a zero cycle length or an empty file list further down).
    if not input_files:
        raise ValueError(
            "No input files matched the pattern(s): %r" % file_patterns)

    tf.logging.info("*** Input Files ***")
    for matched_file in input_files:
        tf.logging.info(" %s" % matched_file)

    if use_static_mask:
        # The masked tokens have been re-arranged to always be at the first
        # 'max_predictions_per_seq' positions.
        name_to_features = {
            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_position": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "mask_padding_index": tf.FixedLenFeature([1], tf.int64),
            "seq_padding_index": tf.FixedLenFeature([1], tf.int64),
            "masked_labels": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_weights": tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
            "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
        }
    else:
        # Default, the tokens have not been re-arranged.
        name_to_features = {
            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "masked_lm_positions": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_weights": tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
            "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
        }

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if is_training:
        d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
        # Repeating at the file level already makes the stream endless.
        d = d.repeat()

        # `cycle_length` is the number of parallel files that get read.
        cycle_length = min(num_cpu_threads, len(input_files))

        # `sloppy` mode means that the interleaving is not exact. This adds
        # even more randomness to the training pipeline.
        d = d.apply(parallel_interleave(tf.data.TFRecordDataset,
                                        sloppy=is_training,
                                        cycle_length=cycle_length))

        # `buffer_size` should be set big enough to keep data shuffle sufficiently.
        if opts['distributed_worker_count'] > 1:
            d = d.shard(num_shards=opts['distributed_worker_count'],
                        index=opts['distributed_worker_index'])
            d = d.shuffle(buffer_size=1000, seed=opts['seed'])
        else:
            d = d.shuffle(buffer_size=1000)
    else:
        d = tf.data.TFRecordDataset(input_files)
        d = d.repeat()

    d = d.apply(map_and_batch(
        lambda record: _decode_record(record, name_to_features, data_type),
        batch_size=micro_batch_size,
        num_parallel_batches=num_cpu_threads,
        drop_remainder=True))
    return d
def build_model(self):
    """Build the GAN graph: inputs, losses, per-network Adam ops, summaries.

    Real images come from a GPU-prefetched ``tf.data`` pipeline when
    ``self.custom_dataset`` is truthy and from a placeholder otherwise.
    Results are stored on ``self``; nothing is returned.
    """
    # ---- Graph input: real images ----
    if self.custom_dataset:
        loader = ImageData(self.img_size, self.c_dim)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.apply(shuffle_and_repeat(self.dataset_num))
        inputs = inputs.apply(
            map_and_batch(loader.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        inputs = inputs.apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()
        self.inputs = inputs_iterator.get_next()
    else:
        real_shape = [self.batch_size, self.img_size, self.img_size,
                      self.c_dim]
        self.inputs = tf.placeholder(tf.float32, real_shape,
                                     name='real_images')

    # ---- Graph input: noise ----
    self.z = tf.placeholder(tf.float32,
                            [self.batch_size, 1, 1, self.z_dim],
                            name='z')

    # ---- Losses ----
    # Discriminator output for real images
    real_logits = self.discriminator(self.inputs)

    # Discriminator output for generated images
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    # A gradient penalty is only added for wgan-family and dragan losses.
    needs_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    if needs_penalty:
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    self.d_loss = discriminator_loss(self.gan_type,
                                     real=real_logits,
                                     fake=fake_logits) + GP
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    # ---- Training ops ----
    # Variables are assigned to a network by substring match on their name.
    trainable = tf.trainable_variables()
    discriminator_vars = [v for v in trainable if 'discriminator' in v.name]
    generator_vars = [v for v in trainable if 'generator' in v.name]

    d_adam = tf.train.AdamOptimizer(self.d_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.d_optim = d_adam.minimize(self.d_loss, var_list=discriminator_vars)
    g_adam = tf.train.AdamOptimizer(self.g_learning_rate,
                                    beta1=self.beta1,
                                    beta2=self.beta2)
    self.g_optim = g_adam.minimize(self.g_loss, var_list=generator_vars)

    # ---- Testing ----
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- Summaries ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
# Report a missing export file before trying to open it.
# NOTE(review): presumably fail_for_missing_file() aborts the run - confirm.
if not export_file:
    fail_for_missing_file()

with open(export_file) as f:
    export_json = load(f)

legend = export_json['legend']
tfrecord_paths = export_json['tfrecord_paths']

image_dim = 512
# 20% of the listed paths, rounded up.
# NOTE(review): this is a count of *paths*, while skip()/take() below count
# dataset *elements* - the two only agree if each file holds one record;
# confirm.
test_set_size = math.ceil(0.20 * len(tfrecord_paths))

# Training stream: skip the leading test elements, parse, shuffle/repeat
# with a buffer of 50, then resize and batch by 8.
training_dataset = tf.data.TFRecordDataset(tfrecord_paths)
training_dataset = training_dataset.skip(test_set_size)
training_dataset = training_dataset.map(_parse_tfrecord)
training_dataset = training_dataset.apply(shuffle_and_repeat(50))
training_dataset = training_dataset.apply(
    map_and_batch(_resize(image_dim), 8))

# Test stream: the leading elements, parsed, resized and emitted as one
# batch of `test_set_size`.
test_dataset = tf.data.TFRecordDataset(tfrecord_paths)
test_dataset = test_dataset.take(test_set_size)
test_dataset = test_dataset.map(_parse_tfrecord)
test_dataset = test_dataset.apply(
    map_and_batch(_resize(image_dim), test_set_size))

training_iterator = training_dataset.make_one_shot_iterator()
test_iterator = test_dataset.make_initializable_iterator()

# A string-handle placeholder lets one `get_next()` op be fed from either
# iterator; the structure is taken from the training dataset.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle,
    training_dataset.output_types,
    training_dataset.output_shapes)
images, labels = iterator.get_next()

number_of_classes = len(legend) + 1  # +1 to include background class
weight_decay = 0.0005
dropout_keep_prob = tf.placeholder(tf.float32, shape=[])
def build_model(self):
    """Build the graph in which the discriminator scores encoder-feature differences.

    Real and generated images both go through the encoder; the discriminator
    is applied to (real - fake) and to (fake - real) encodings, and those two
    scores feed both losses. Results are stored on ``self``; nothing is
    returned.
    """
    # some parameters
    batch = self.batch_size

    # ---- Graph input: real images ----
    if self.custom_dataset:
        loader = ImageData(self.output_height, self.c_dim)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.apply(shuffle_and_repeat(self.dataset_num))
        inputs = inputs.apply(
            map_and_batch(loader.image_processing,
                          self.batch_size,
                          num_parallel_batches=8,
                          drop_remainder=True))
        inputs = inputs.apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()
        self.inputs = inputs_iterator.get_next()
    else:
        # output_height is used for both spatial dimensions.
        real_shape = [
            self.batch_size, self.output_height, self.output_height,
            self.c_dim
        ]
        self.inputs = tf.placeholder(tf.float32, real_shape,
                                     name='real_images')

    # ---- Graph input: noise ----
    self.z = tf.placeholder(tf.float32, [batch, self.z_dim], name='z')

    # ---- Losses ----
    x_fake = self.generator(self.z, is_training=True, reuse=False)

    real_code = self.encoder(self.inputs,
                             is_training=True,
                             reuse=False,
                             sn=True)
    fake_code = self.encoder(x_fake, is_training=True, reuse=True, sn=True)

    real_minus_fake = tf.subtract(real_code, fake_code)
    fake_minus_real = tf.subtract(fake_code, real_code)

    real_fake_score = self.discriminator(real_minus_fake,
                                         reuse=False,
                                         sn=True)
    fake_real_score = self.discriminator(fake_minus_real,
                                         reuse=True,
                                         sn=True)

    # Loss for the discriminator
    self.d_loss = discriminator_loss(self.loss_type,
                                     real=real_fake_score,
                                     fake=fake_real_score)

    # Loss for the generator
    self.g_loss = generator_loss(self.loss_type,
                                 real=real_fake_score,
                                 fake=fake_real_score)

    # ---- Training ops ----
    # The encoder's variables are trained together with the discriminator's.
    trainable = tf.trainable_variables()
    discriminator_vars = [
        v for v in trainable
        if 'discriminator' in v.name or 'encoder' in v.name
    ]
    generator_vars = [v for v in trainable if 'generator' in v.name]

    # Both train ops are gated on the ops in the UPDATE_OPS collection.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_adam = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1)
        self.d_optim = d_adam.minimize(self.d_loss,
                                       var_list=discriminator_vars)
        # The generator uses five times the base learning rate.
        g_adam = tf.train.AdamOptimizer(self.learning_rate*5,
                                        beta1=self.beta1)
        self.g_optim = g_adam.minimize(self.g_loss, var_list=generator_vars)

    # ---- Testing ----
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- Summaries ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def build_model(self):
    """Build the two-domain translation graph, its losses and optimizers.

    Results are stored on ``self``:

    * inputs: ``lr``, ``domain_A``, ``domain_B``;
    * losses: ``Generator_loss``, ``Discriminator_loss``,
      ``Discriminator_content_loss``;
    * train ops: ``G_optim``, ``D_optim``, ``D_content_optim``;
    * summaries: ``G_loss``, ``D_loss``, ``D_content_loss`` and the individual
      scalar summaries;
    * images: ``fake_A``, ``fake_B``, ``real_A``, ``real_B``;
    * inference: ``test_image``, ``test_fake_A``, ``test_fake_B``,
      ``content_image``, ``attribute_image``, ``guide_fake_A``,
      ``guide_fake_B``.
    """
    self.lr = tf.placeholder(tf.float32, name='lr')

    # ---- Input images ----
    Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                 self.augment_flag)

    trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

    gpu_device = '/gpu:0'
    trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
        map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                      num_parallel_batches=16, drop_remainder=True)).apply(
        prefetch_to_device(gpu_device, self.batch_size))
    trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
        map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                      num_parallel_batches=16, drop_remainder=True)).apply(
        prefetch_to_device(gpu_device, self.batch_size))

    trainA_iterator = trainA.make_one_shot_iterator()
    trainB_iterator = trainB.make_one_shot_iterator()

    self.domain_A = trainA_iterator.get_next()
    self.domain_B = trainB_iterator.get_next()

    # ---- Encoder, generator, discriminator ----
    random_z = tf.random_normal(shape=[self.batch_size, self.n_z], mean=0.0,
                                stddev=1.0, dtype=tf.float32)

    # encode
    content_a, attribute_a, mean_a, logvar_a = self.Encoder_A(self.domain_A)
    content_b, attribute_b, mean_b, logvar_b = self.Encoder_B(self.domain_B)

    # decode (fake, identity, random)
    fake_a = self.Decoder_A(content_B=content_b, attribute_A=attribute_a)
    fake_b = self.Decoder_B(content_A=content_a, attribute_B=attribute_b)

    recon_a = self.Decoder_A(content_B=content_a, attribute_A=attribute_a,
                             reuse=True)
    recon_b = self.Decoder_B(content_A=content_b, attribute_B=attribute_b,
                             reuse=True)

    random_fake_a = self.Decoder_A(content_B=content_b, attribute_A=random_z,
                                   reuse=True)
    random_fake_b = self.Decoder_B(content_A=content_a, attribute_B=random_z,
                                   reuse=True)

    # encode & decode again for cycle-consistency
    content_fake_a, attribute_fake_a, _, _ = self.Encoder_A(fake_a,
                                                            reuse=True)
    content_fake_b, attribute_fake_b, _, _ = self.Encoder_B(fake_b,
                                                            reuse=True)

    cycle_a = self.Decoder_A(content_B=content_fake_b,
                             attribute_A=attribute_fake_a, reuse=True)
    cycle_b = self.Decoder_B(content_A=content_fake_a,
                             attribute_B=attribute_fake_b, reuse=True)

    # for latent regression
    _, attribute_fake_random_a, _, _ = self.Encoder_A(
        random_fake_a, random_fake=True, reuse=True)
    _, attribute_fake_random_b, _, _ = self.Encoder_B(
        random_fake_b, random_fake=True, reuse=True)

    # discriminate
    real_A_logit, real_B_logit = self.discriminate_real(self.domain_A,
                                                        self.domain_B)
    fake_A_logit, fake_B_logit = self.discriminate_fake(fake_a, fake_b)
    random_fake_A_logit, random_fake_B_logit = self.discriminate_fake(
        random_fake_a, random_fake_b)
    content_A_logit, content_B_logit = self.discriminate_content(content_a,
                                                                 content_b)

    # ---- Losses ----
    g_adv_loss_a = generator_loss(self.gan_type, fake_A_logit) + \
        generator_loss(self.gan_type, random_fake_A_logit)
    g_adv_loss_b = generator_loss(self.gan_type, fake_B_logit) + \
        generator_loss(self.gan_type, random_fake_B_logit)

    g_con_loss_a = generator_loss(self.gan_type, content_A_logit,
                                  content=True)
    g_con_loss_b = generator_loss(self.gan_type, content_B_logit,
                                  content=True)

    g_cyc_loss_a = L1_loss(cycle_a, self.domain_A)
    g_cyc_loss_b = L1_loss(cycle_b, self.domain_B)

    g_rec_loss_a = L1_loss(recon_a, self.domain_A)
    g_rec_loss_b = L1_loss(recon_b, self.domain_B)

    g_latent_loss_a = L1_loss(attribute_fake_random_a, random_z)
    g_latent_loss_b = L1_loss(attribute_fake_random_b, random_z)

    if self.concat:
        g_kl_loss_a = kl_loss(mean_a, logvar_a) + l2_regularize(content_a)
        g_kl_loss_b = kl_loss(mean_b, logvar_b) + l2_regularize(content_b)
    else:
        g_kl_loss_a = l2_regularize(attribute_a) + l2_regularize(content_a)
        g_kl_loss_b = l2_regularize(attribute_b) + l2_regularize(content_b)

    d_adv_loss_a = discriminator_loss(self.gan_type, real_A_logit,
                                      fake_A_logit)
    d_adv_loss_b = discriminator_loss(self.gan_type, real_B_logit,
                                      fake_B_logit)

    d_con_loss = discriminator_loss(self.gan_type, content_A_logit,
                                    content_B_logit)

    # NOTE: the A-side cycle term is built from g_cyc_loss_b and the B-side
    # term from g_cyc_loss_a. Both use the same weight and are summed into
    # Generator_loss, so the total is unaffected; only the per-side summaries
    # report the other domain's value.
    Generator_A_domain_loss = self.domain_adv_w * g_adv_loss_a
    Generator_A_content_loss = self.content_adv_w * g_con_loss_a
    Generator_A_cycle_loss = self.cycle_w * g_cyc_loss_b
    Generator_A_recon_loss = self.recon_w * g_rec_loss_a
    Generator_A_latent_loss = self.latent_w * g_latent_loss_a
    Generator_A_kl_loss = self.kl_w * g_kl_loss_a

    Generator_A_loss = Generator_A_domain_loss + \
        Generator_A_content_loss + \
        Generator_A_cycle_loss + \
        Generator_A_recon_loss + \
        Generator_A_latent_loss + \
        Generator_A_kl_loss

    Generator_B_domain_loss = self.domain_adv_w * g_adv_loss_b
    Generator_B_content_loss = self.content_adv_w * g_con_loss_b
    Generator_B_cycle_loss = self.cycle_w * g_cyc_loss_a
    Generator_B_recon_loss = self.recon_w * g_rec_loss_b
    Generator_B_latent_loss = self.latent_w * g_latent_loss_b
    Generator_B_kl_loss = self.kl_w * g_kl_loss_b

    Generator_B_loss = Generator_B_domain_loss + \
        Generator_B_content_loss + \
        Generator_B_cycle_loss + \
        Generator_B_recon_loss + \
        Generator_B_latent_loss + \
        Generator_B_kl_loss

    Discriminator_A_loss = self.domain_adv_w * d_adv_loss_a
    Discriminator_B_loss = self.domain_adv_w * d_adv_loss_b
    Discriminator_content_loss = self.content_adv_w * d_con_loss

    self.Generator_loss = Generator_A_loss + Generator_B_loss
    self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss
    self.Discriminator_content_loss = Discriminator_content_loss

    # ---- Training ----
    t_vars = tf.trainable_variables()
    # Encoder variables are trained with the generator group.
    # NOTE(review): assumes the encoder variable scopes contain 'encoder' --
    # confirm against the Encoder_A / Encoder_B definitions.
    G_vars = [
        var for var in t_vars
        if 'encoder' in var.name or 'generator' in var.name
    ]
    D_vars = [
        var for var in t_vars
        if 'discriminator' in var.name and 'content' not in var.name
    ]
    D_content_vars = [
        var for var in t_vars if 'content_discriminator' in var.name
    ]

    # The content discriminator is the only group trained with gradients
    # clipped by global norm.
    grads, _ = tf.clip_by_global_norm(
        tf.gradients(self.Discriminator_content_loss, D_content_vars),
        clip_norm=5)

    self.G_optim = tf.train.AdamOptimizer(
        self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss,
                                                  var_list=G_vars)
    self.D_optim = tf.train.AdamOptimizer(
        self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss,
                                                  var_list=D_vars)
    self.D_content_optim = tf.train.AdamOptimizer(
        self.lr, beta1=0.5, beta2=0.999).apply_gradients(
            zip(grads, D_content_vars))

    # ---- Summary ----
    self.lr_write = tf.summary.scalar("learning_rate", self.lr)

    self.all_G_loss = tf.summary.scalar("Generator_loss",
                                        self.Generator_loss)
    self.all_D_loss = tf.summary.scalar("Discriminator_loss",
                                        self.Discriminator_loss)

    self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
    self.G_A_domain_loss = tf.summary.scalar("G_A_domain_loss",
                                             Generator_A_domain_loss)
    self.G_A_content_loss = tf.summary.scalar("G_A_content_loss",
                                              Generator_A_content_loss)
    self.G_A_cycle_loss = tf.summary.scalar("G_A_cycle_loss",
                                            Generator_A_cycle_loss)
    self.G_A_recon_loss = tf.summary.scalar("G_A_recon_loss",
                                            Generator_A_recon_loss)
    self.G_A_latent_loss = tf.summary.scalar("G_A_latent_loss",
                                             Generator_A_latent_loss)
    self.G_A_kl_loss = tf.summary.scalar("G_A_kl_loss", Generator_A_kl_loss)

    self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
    self.G_B_domain_loss = tf.summary.scalar("G_B_domain_loss",
                                             Generator_B_domain_loss)
    self.G_B_content_loss = tf.summary.scalar("G_B_content_loss",
                                              Generator_B_content_loss)
    self.G_B_cycle_loss = tf.summary.scalar("G_B_cycle_loss",
                                            Generator_B_cycle_loss)
    self.G_B_recon_loss = tf.summary.scalar("G_B_recon_loss",
                                            Generator_B_recon_loss)
    self.G_B_latent_loss = tf.summary.scalar("G_B_latent_loss",
                                             Generator_B_latent_loss)
    self.G_B_kl_loss = tf.summary.scalar("G_B_kl_loss", Generator_B_kl_loss)

    self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
    self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

    self.G_loss = tf.summary.merge([
        self.G_A_loss, self.G_A_domain_loss, self.G_A_content_loss,
        self.G_A_cycle_loss, self.G_A_recon_loss, self.G_A_latent_loss,
        self.G_A_kl_loss,
        self.G_B_loss, self.G_B_domain_loss, self.G_B_content_loss,
        self.G_B_cycle_loss, self.G_B_recon_loss, self.G_B_latent_loss,
        self.G_B_kl_loss,
        self.all_G_loss
    ])

    self.D_loss = tf.summary.merge(
        [self.D_A_loss, self.D_B_loss, self.all_D_loss])
    self.D_content_loss = tf.summary.scalar("Discriminator_content_loss",
                                            self.Discriminator_content_loss)

    # ---- Images ----
    self.fake_A = fake_a
    self.fake_B = fake_b

    self.real_A = self.domain_A
    self.real_B = self.domain_B

    # ---- Test: translate one image with a randomly sampled attribute ----
    self.test_image = tf.placeholder(
        tf.float32, [1, self.img_size, self.img_size, self.img_ch],
        name='test_image')
    self.test_random_z = tf.random_normal(shape=[1, self.n_z], mean=0.0,
                                          stddev=1.0, dtype=tf.float32)

    test_content_a, test_attribute_a, _, _ = self.Encoder_A(
        self.test_image, is_training=False, reuse=True)
    test_content_b, test_attribute_b, _, _ = self.Encoder_B(
        self.test_image, is_training=False, reuse=True)

    self.test_fake_A = self.Decoder_A(content_B=test_content_b,
                                      attribute_A=self.test_random_z,
                                      reuse=True)
    self.test_fake_B = self.Decoder_B(content_A=test_content_a,
                                      attribute_B=self.test_random_z,
                                      reuse=True)

    # ---- Guided image translation ----
    self.content_image = tf.placeholder(
        tf.float32, [1, self.img_size, self.img_size, self.img_ch],
        name='content_image')
    self.attribute_image = tf.placeholder(
        tf.float32, [1, self.img_size, self.img_size, self.img_ch],
        name='guide_attribute_image')

    guide_content_A, guide_attribute_A, _, _ = self.Encoder_A(
        self.content_image, is_training=False, reuse=True)
    guide_content_B, guide_attribute_B, _, _ = self.Encoder_B(
        self.attribute_image, is_training=False, reuse=True)

    self.guide_fake_A = self.Decoder_A(content_B=guide_content_B,
                                       attribute_A=guide_attribute_A,
                                       reuse=True)
    self.guide_fake_B = self.Decoder_B(content_A=guide_content_A,
                                       attribute_B=guide_attribute_B,
                                       reuse=True)
def _initialize_tf_dataset(file_names,
                           augment,
                           over_sample,
                           shuffle,
                           target_size,
                           normalize,
                           nb_workers=8,
                           batch_size=64,
                           shuffle_buffer_size=3000,
                           input_type="img",
                           nr_epochs=1):
    """Build a batched ``tf.data`` pipeline over TFRecord files.

    Each record is parsed for the ``label``, ``shape`` and ``image`` features.
    The image bytes are decoded as uint8 and reshaped, then optionally
    augmented and normalized.

    Args:
        file_names: TFRecord file name(s) passed to ``TFRecordDataset``.
        augment: when true, apply random flips, rotation, padding + random
            crop, and a zoom via crop-or-pad followed by a bilinear resize.
        over_sample: accepted for interface compatibility; not used here.
        shuffle: when true, shuffle with ``shuffle_buffer_size`` and repeat
            ``nr_epochs`` times. When false the dataset is neither shuffled
            nor repeated.
        target_size: image size as a sequence; converted to a list because it
            is concatenated with ``[3]``.
        normalize: when true, subtract the hard-coded three-value mean and
            divide by the hard-coded three-value std.
        nb_workers: number of parallel batches and the prefetch buffer size.
        batch_size: elements per batch; the final short batch is dropped.
        shuffle_buffer_size: shuffle buffer size, used only when ``shuffle``.
        input_type: ``".jpeg"`` reshapes the image to ``target_size + [3]``;
            any other value reshapes to ``target_size``.
        nr_epochs: repeat count, used only when ``shuffle``.

    Returns:
        A dataset yielding ``({'shape': ..., 'image': ...}, label)`` batches.
    """
    import tensorflow as tf
    from tensorflow.contrib.data import map_and_batch
    from tensorflow.contrib.data import shuffle_and_repeat

    # target_size is concatenated with [3] below, so it has to be a list.
    if not isinstance(target_size, list):
        target_size = list(target_size)

    with tf.name_scope('input_pipeline'):
        dataset = tf.data.TFRecordDataset(file_names)
        if shuffle:
            dataset = dataset.apply(
                shuffle_and_repeat(shuffle_buffer_size, nr_epochs))

        def _decode_and_augment_image(example_proto):
            """Parse one serialized example into (features dict, label)."""
            keys_to_features = {
                'label': tf.FixedLenFeature([], tf.int64),
                'shape': tf.FixedLenFeature([], tf.string),
                'image': tf.FixedLenFeature([], tf.string),
            }
            tfrecord_features = tf.parse_single_example(
                example_proto, keys_to_features)

            image = tf.decode_raw(tfrecord_features['image'], tf.uint8)
            shape = tf.decode_raw(tfrecord_features['shape'], tf.int64)
            if input_type == ".jpeg":
                image = tf.reshape(image, target_size + [3])
            else:
                image = tf.reshape(image, target_size)
            label = tfrecord_features['label']

            if augment:
                image = tf.image.random_flip_left_right(image)
                image = tf.image.random_flip_up_down(image)

                degrees = tf.random_uniform((), minval=-180, maxval=180)
                image = tf.contrib.image.rotate(image, degrees)

                # Pad by a random amount on each spatial side, then crop back
                # to the target size at a random offset.
                width_shift = tf.random_uniform((), minval=0, maxval=0.05)
                height_shift = tf.random_uniform((), minval=0, maxval=0.05)

                horizontal_pad = tf.cast(
                    tf.ceil(width_shift * target_size[0]), tf.int32)
                vertical_pad = tf.cast(
                    tf.ceil(height_shift * target_size[1]), tf.int32)

                # Rows of the (3, 2) padding matrix are (before, after) pairs
                # for the three image axes; the last axis is not padded.
                padding = tf.stack([
                    horizontal_pad, horizontal_pad,
                    vertical_pad, vertical_pad,
                    tf.constant(0), tf.constant(0)
                ])
                padding = tf.reshape(padding, (3, 2))

                image = tf.pad(image, padding)
                image = tf.random_crop(image, target_size + [3])

                # Crop or pad to a randomly scaled square, then resize back.
                zoom = tf.random_uniform((), minval=-0.1, maxval=0.1)
                new_dim = tf.cast(tf.ceil((1 - zoom) * target_size[0]),
                                  dtype=tf.int32)

                image = tf.image.resize_image_with_crop_or_pad(
                    image, new_dim, new_dim)
                image = tf.image.resize_images(
                    image, target_size,
                    method=tf.image.ResizeMethod.BILINEAR)

            if normalize:
                # The constants are expanded to shape (1, 1, 3) so they
                # broadcast against the image's last axis.
                std = tf.constant(
                    np.array([70.53946096, 51.71475228, 43.03428563]),
                    dtype=tf.float32)
                std = tf.expand_dims(tf.expand_dims(std, axis=0), axis=0)

                mean = tf.constant(
                    np.array([108.64628601, 75.86886597, 54.34005736]),
                    dtype=tf.float32)
                mean = tf.expand_dims(tf.expand_dims(mean, axis=0), axis=0)

                image = (tf.cast(image, dtype=tf.float32) - mean) / std

            label = tf.reshape(label, [1])
            if input_type == ".jpeg":
                image = tf.reshape(image, target_size + [3])
            else:
                image = tf.reshape(image, target_size)

            return {'shape': shape, 'image': image}, label

        dataset = dataset \
            .apply(map_and_batch(_decode_and_augment_image,
                                 batch_size=batch_size,
                                 num_parallel_batches=nb_workers,
                                 drop_remainder=True)) \
            .prefetch(nb_workers)

        return dataset
def build_model(self):
    """Build either the training graph or the test graph, chosen by phase.

    When ``self.phase == 'train'`` this stores ``lr``, the generator and
    discriminator losses, ``real_img`` / ``fake_img``, both optimizers and
    the merged summaries on ``self``. Otherwise it stores only
    ``test_real_img`` and ``test_fake_img``. ``dataset_num`` is set in both
    cases.
    """
    # ---- Input images ----
    img_data_class = Image_data(self.img_height, self.img_width, self.img_ch,
                                self.dataset_path, self.augment_flag)
    (train_captions, train_images, test_captions, test_images,
     idx_to_word, word_to_idx) = img_data_class.preprocess()
    # Sizes recorded by the original author for their dataset:
    #   train_captions: (8855, 10, 18), test_captions: (2933, 10, 18)
    #   train_images: (8855,), test_images: (2933,)
    #   idx_to_word : 5450 5450

    def _next_batch(images, captions):
        """Return the next-element tensors of a shuffled, batched pipeline."""
        pairs = tf.data.Dataset.from_tensor_slices((images, captions))
        pairs = pairs.apply(shuffle_and_repeat(self.dataset_num))
        pairs = pairs.apply(
            map_and_batch(img_data_class.image_processing,
                          batch_size=self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        pairs = pairs.apply(prefetch_to_device('/gpu:0', None))
        return pairs.make_one_shot_iterator().get_next()

    def _pick_sentence(captions):
        """Select one random sentence index (0..9) along axis 1."""
        sentence_index = tf.random_uniform(shape=[], minval=0, maxval=10,
                                           dtype=tf.int32)
        return tf.gather(captions, sentence_index, axis=1)

    if self.phase == 'train':
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        self.dataset_num = len(train_images)
        real_img_256, caption = _next_batch(train_images, train_captions)
        caption = _pick_sentence(caption)

        word_emb, sent_emb, mask = self.rnn_encoder(
            caption, n_words=len(idx_to_word), embed_dim=self.embed_dim,
            drop_rate=0.5, n_hidden=128, n_layers=1, bidirectional=True,
            rnn_type='lstm')

        noise = tf.random_normal(shape=[self.batch_size, self.z_dim],
                                 mean=0.0, stddev=1.0)
        fake_imgs, _, mu, logvar = self.generator(noise, sent_emb, word_emb,
                                                  mask)

        real_img_64 = resize(real_img_256, target_size=[64, 64])
        real_img_128 = resize(real_img_256, target_size=[128, 128])
        fake_img_64, fake_img_128, fake_img_256 = fake_imgs[0:3]

        real_pyramid = [real_img_64, real_img_128, real_img_256]
        uncond_real_logits, cond_real_logits = self.discriminator(
            real_pyramid, sent_emb)

        # Mismatched pairs: image k is scored against the sentence embedding
        # of sample k + 1.
        last = self.batch_size - 1
        _, cond_wrong_logits = self.discriminator(
            [real_img_64[:last], real_img_128[:last], real_img_256[:last]],
            sent_emb[1:self.batch_size])

        uncond_fake_logits, cond_fake_logits = self.discriminator(
            [fake_img_64, fake_img_128, fake_img_256], sent_emb)

        # Accumulate the adversarial terms over the three resolutions.
        g_adv_total, d_adv_total = 0, 0
        for i in range(3):
            g_uncond = generator_loss(self.gan_type, uncond_fake_logits[i])
            g_cond = generator_loss(self.gan_type, cond_fake_logits[i])
            g_adv_total += self.adv_weight * (g_uncond + g_cond)

            uncond_real_loss, uncond_fake_loss = discriminator_loss(
                self.gan_type, uncond_real_logits[i], uncond_fake_logits[i])
            cond_real_loss, cond_fake_loss = discriminator_loss(
                self.gan_type, cond_real_logits[i], cond_fake_logits[i])
            _, cond_wrong_loss = discriminator_loss(
                self.gan_type, None, cond_wrong_logits[i])

            real_term = (uncond_real_loss + cond_real_loss) / 2
            fake_term = (uncond_fake_loss + cond_fake_loss
                         + cond_wrong_loss) / 3
            d_adv_total += self.adv_weight * (real_term + fake_term)

        self.g_adv_loss = g_adv_total
        self.d_adv_loss = d_adv_total

        self.g_kl_loss = self.kl_weight * kl_loss(mu, logvar)

        self.g_loss = self.g_adv_loss + self.g_kl_loss
        self.d_loss = self.d_adv_loss

        self.real_img = real_img_256
        self.fake_img = fake_img_256

        # ---- Training ----
        trainable = tf.trainable_variables()
        G_vars = [v for v in trainable if 'generator' in v.name]
        D_vars = [v for v in trainable if 'discriminator' in v.name]

        g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.g_optim = g_adam.minimize(self.g_loss, var_list=G_vars)
        d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.d_optim = d_adam.minimize(self.d_loss, var_list=D_vars)

        # ---- Summary ----
        self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss",
                                                    self.g_adv_loss)
        self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss",
                                                   self.g_kl_loss)

        self.summary_d_adv_loss = tf.summary.scalar("d_adv_loss",
                                                    self.d_adv_loss)

        self.summary_merge_g_loss = tf.summary.merge([
            self.summary_g_loss,
            self.summary_g_adv_loss,
            self.summary_g_kl_loss,
        ])
        self.summary_merge_d_loss = tf.summary.merge([
            self.summary_d_loss,
            self.summary_d_adv_loss,
        ])

    else:
        # ---- Test ----
        self.dataset_num = len(test_captions)
        real_img_256, caption = _next_batch(test_images, test_captions)
        caption = _pick_sentence(caption)

        word_emb, sent_emb, mask = self.rnn_encoder(
            caption, n_words=len(idx_to_word), embed_dim=self.embed_dim,
            drop_rate=0.5, n_hidden=128, n_layers=1, bidirectional=True,
            rnn_type='lstm', is_training=False)

        noise = tf.random_normal(shape=[self.batch_size, self.z_dim],
                                 mean=0.0, stddev=1.0)
        fake_imgs, _, _, _ = self.generator(noise, sent_emb, word_emb, mask,
                                            is_training=False)

        self.test_real_img = real_img_256
        self.test_fake_img = fake_imgs[2]
def get_batch(datasets,
              preprocess_name,
              is_training,
              batch_size,
              num_gpu=1,
              seed=None):
    """Create a one-shot iterator over batched (image, one-hot label) pairs.

    The whole pipeline is built under the ``/cpu:0`` device scope.

    Args:
        datasets: object providing ``num_class``, ``source`` (record files),
            ``feature`` (parse spec), ``decoder`` and ``description``.
        preprocess_name: name passed to ``get_preprocess_fn``.
        is_training: when true, shuffle files and records; in both modes the
            record stream is repeated.
        batch_size: elements per batch; also the first prefetch buffer size.
        num_gpu: accepted but not used in this function.
        seed: random seed forwarded to the shuffle operations.

    Returns:
        The one-shot iterator of the final dataset.
    """
    with tf.device('/cpu:0'):
        class_count = datasets.num_class
        record_files = datasets.source
        feature_spec = datasets.feature
        decode_image = datasets.decoder
        # Looked up as in the original code; the value is not used below.
        dataset_name = datasets.description['name']
        preprocess = get_preprocess_fn(preprocess_name)

        def parse_record(record):
            """Turn one serialized record into (image, one-hot label)."""
            parsed = tf.parse_single_example(record, feature_spec)
            image = decode_image(parsed['image/encoded'])
            # Perform additional preprocessing on the parsed data.
            image = preprocess(image, datasets, is_training=is_training)
            label = tf.one_hot(parsed['image/class/label'], class_count)
            return image, label

        files = tf.data.Dataset.from_tensor_slices(record_files)
        if is_training:
            # Shuffle the input files.
            files = files.shuffle(len(record_files), seed=seed,
                                  reshuffle_each_iteration=True)

        # Read up to 10 files in parallel, or fewer when there are fewer
        # files than that.
        parallel_readers = min(10, len(record_files))
        records = files.apply(
            data.parallel_interleave(tf.data.TFRecordDataset,
                                     cycle_length=parallel_readers))

        # Prefetch a batch worth of records to smooth out file loading while
        # shuffling and processing.
        records = records.prefetch(buffer_size=batch_size)

        if is_training:
            records = records.apply(
                data.shuffle_and_repeat(buffer_size=10000, seed=seed))
        else:
            records = records.repeat()

        # Parse the raw records into images and labels. The original author
        # notes that num_parallel_batches > 1 gave no throughput gain in
        # their testing.
        batches = records.apply(
            data.map_and_batch(map_func=parse_record,
                               batch_size=batch_size,
                               num_parallel_batches=1))

        # Final prefetch so the processing above runs in the background,
        # outside the critical training path.
        batches = batches.prefetch(buffer_size=32)
        return batches.make_one_shot_iterator()
merged_summary_op = tf.summary.merge_all()
saver = tf.train.Saver(max_to_keep=3)

dataset = tf.data.TFRecordDataset('Dataset/dataset_1.tfrecord')


def _parse(example_proto):
    """Decode one serialized example into a float tensor of -1 / +1 values.

    The 'roll' feature holds packed bits: it is decoded as uint8, unpacked
    with ``np.unpackbits``, cast to float32, reshaped to ``[CLASS_NUM, 600]``
    and mapped from {0, 1} to {-1, 1}.
    """
    feature = {'roll': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_single_example(example_proto, feature)
    # Named 'roll' rather than 'data' so the module-level 'data' (used below
    # for shuffle_and_repeat / map_and_batch) is not shadowed.
    roll = tf.decode_raw(parsed['roll'], tf.uint8)
    roll = tf.py_func(func=np.unpackbits, inp=[roll], Tout=tf.uint8)
    roll = tf.cast(roll, tf.float32)
    roll = tf.reshape(roll, [CLASS_NUM, 600])
    roll = roll * 2 - 1
    return roll


dataset = dataset.apply(data.shuffle_and_repeat(buffer_size=60000, count=3))
dataset = dataset.apply(
    data.map_and_batch(_parse, batch_size=batch_size, num_parallel_batches=2))
iterator = dataset.prefetch(batch_size).make_one_shot_iterator()
real_input_next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if os.path.exists(model_path):
        # Saver.restore() returns None, so report the path that was passed in
        # rather than its return value.
        saver.restore(sess, model_path)
        print("Model restored from file: %s" % model_path)

    for i in range(total_batch):
        tfdata = sess.run(real_input_next_element)
        reshape_tfdata = tfdata.reshape([-1, CLASS_NUM, 600])
        # Trim the last axis to INPUT_LENGTH, then flatten each sample.
        cuted_tfdata = reshape_tfdata[:, :, :INPUT_LENGTH]
        reshape_cuted_tfdata = cuted_tfdata.reshape(
            [-1, CLASS_NUM*INPUT_LENGTH])
def build_model(self):
    """Build per-resolution training graphs, or the sampling graph.

    In the ``'train'`` phase one set of losses, optimizers, summaries and
    fake-image tensors is created for every resolution from
    ``self.start_res`` onwards and stored in dictionaries keyed by
    resolution. In any other phase only ``self.fake_images`` is created.
    """
    if self.phase != 'train':
        # ---- Testing ----
        test_z = tf.random_normal(shape=[self.batch_size, self.z_dim])
        alpha = tf.constant(0.0, dtype=tf.float32, shape=[])
        self.fake_images = self.generator(test_z, alpha=alpha,
                                          target_img_size=self.img_size,
                                          is_training=False)
        return

    # ---- Graph ----
    self.d_loss_per_res = {}
    self.g_loss_per_res = {}
    self.generator_optim = {}
    self.discriminator_optim = {}
    self.alpha_summary_per_res = {}
    self.d_summary_per_res = {}
    self.g_summary_per_res = {}
    self.train_fake_images = {}

    first = self.resolutions.index(self.start_res)
    for res in self.resolutions[first:]:
        tower_g_losses = []
        tower_d_losses = []
        tower_fake_images = []

        batch_size = self.batch_sizes.get(res, self.batch_size_base)
        global_step = tf.get_variable(
            'global_step_{}'.format(res),
            shape=[],
            dtype=tf.float32,
            initializer=tf.initializers.zeros(),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)
        alpha_const, zero_constant = get_alpha_const(
            self.iteration // 2, batch_size * self.gpu_num, global_step)

        # Smooth transition variable: starts at one when this resolution is
        # trained with a transition, at zero otherwise.
        do_train_trans = self.train_with_trans[res]
        if do_train_trans:
            alpha_init = tf.initializers.ones()
        else:
            alpha_init = tf.initializers.zeros()

        alpha = tf.get_variable(
            'alpha_{}'.format(res),
            shape=[],
            dtype=tf.float32,
            initializer=alpha_init,
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)

        alpha_assign_op = tf.assign(
            alpha, alpha_const if do_train_trans else zero_constant)

        # Ops created for every tower run after alpha has been assigned.
        with tf.control_dependencies([alpha_assign_op]):
            for gpu_id in range(self.gpu_num):
                tower_device = tf.DeviceSpec(device_type="GPU",
                                             device_index=gpu_id)
                # Variables are created by the first tower and reused by the
                # rest.
                with tf.device(tower_device), \
                        tf.variable_scope(tf.get_variable_scope(),
                                          reuse=(gpu_id > 0)):
                    # images
                    gpu_device = '/gpu:{}'.format(gpu_id)
                    image_class = ImageData(res)
                    inputs = tf.data.Dataset.from_tensor_slices(self.dataset)
                    inputs = inputs.apply(
                        shuffle_and_repeat(self.dataset_num))
                    inputs = inputs.apply(
                        map_and_batch(image_class.image_processing,
                                      batch_size,
                                      num_parallel_batches=16,
                                      drop_remainder=True))
                    # buffer_size=None lets the prefetch pick its own buffer
                    # size.
                    inputs = inputs.apply(
                        prefetch_to_device(gpu_device, None))

                    real_img = inputs.make_one_shot_iterator().get_next()
                    z = tf.random_normal(shape=[batch_size, self.z_dim])

                    fake_img = self.generator(z, alpha, res)
                    real_img = smooth_crossfade(real_img, alpha)

                    real_logit = self.discriminator(real_img, alpha, res)
                    fake_logit = self.discriminator(fake_img, alpha, res)

                    # compute loss
                    tower_d_loss, tower_g_loss = compute_loss(
                        real_img, real_logit, fake_logit)

                    tower_d_losses.append(tower_d_loss)
                    tower_g_losses.append(tower_g_loss)
                    tower_fake_images.append(fake_img)

        print("Create graph for {} resolution".format(res))

        # prepare appropriate training vars
        d_vars, g_vars = filter_trainable_variables(res)

        # Average the per-tower losses.
        d_loss = tf.reduce_mean(tower_d_losses)
        g_loss = tf.reduce_mean(tower_g_losses)

        d_lr = self.d_learning_rates.get(res, self.learning_rate_base)
        g_lr = self.g_learning_rates.get(res, self.learning_rate_base)

        # Colocate gradients with their ops only when several GPUs are used.
        colocate_grad = self.gpu_num != 1

        d_adam = tf.train.AdamOptimizer(d_lr, beta1=0.0, beta2=0.99,
                                        epsilon=1e-8)
        d_optim = d_adam.minimize(
            d_loss,
            var_list=d_vars,
            colocate_gradients_with_ops=colocate_grad)

        # Only the generator step advances global_step.
        g_adam = tf.train.AdamOptimizer(g_lr, beta1=0.0, beta2=0.99,
                                        epsilon=1e-8)
        g_optim = g_adam.minimize(
            g_loss,
            var_list=g_vars,
            global_step=global_step,
            colocate_gradients_with_ops=colocate_grad)

        self.discriminator_optim[res] = d_optim
        self.generator_optim[res] = g_optim

        self.d_loss_per_res[res] = d_loss
        self.g_loss_per_res[res] = g_loss

        self.train_fake_images[res] = tf.concat(tower_fake_images, axis=0)

        # ---- Summary ----
        self.alpha_summary_per_res[res] = tf.summary.scalar(
            "alpha_{}".format(res), alpha)

        self.d_summary_per_res[res] = tf.summary.scalar(
            "d_loss_{}".format(res), self.d_loss_per_res[res])
        self.g_summary_per_res[res] = tf.summary.scalar(
            "g_loss_{}".format(res), self.g_loss_per_res[res])
def build_model(self):
    """Build the segmentation-map-to-image graph, losses and optimizers.

    Results are stored on ``self``: the ``lr`` placeholder, dataset sizes,
    the input tensors (``real_x``, ``real_x_segmap``, ``real_x_segmap_onehot``
    and their test counterparts), ``g_loss`` / ``d_loss``, generated images
    (``fake_x``, ``random_fake_x``), the test placeholders and outputs,
    ``G_optim`` / ``D_optim`` and the merged summaries ``G_loss`` /
    ``D_loss``.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # ---- Input images ----
    img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                           self.segmap_ch, self.dataset_path,
                           self.augment_flag)
    img_class.preprocess()

    self.dataset_num = len(img_class.image)
    self.test_dataset_num = len(img_class.segmap_test)

    train_pairs = tf.data.Dataset.from_tensor_slices(
        (img_class.image, img_class.segmap))
    test_maps = tf.data.Dataset.from_tensor_slices(img_class.segmap_test)

    gpu_device = '/gpu:0'

    train_pairs = train_pairs.apply(shuffle_and_repeat(self.dataset_num))
    train_pairs = train_pairs.apply(
        map_and_batch(img_class.image_processing, self.batch_size,
                      num_parallel_batches=16, drop_remainder=True))
    train_pairs = train_pairs.apply(
        prefetch_to_device(gpu_device, self.batch_size))

    # The test pipeline's shuffle buffer is sized by the training set count,
    # as in the original code.
    test_maps = test_maps.apply(shuffle_and_repeat(self.dataset_num))
    test_maps = test_maps.apply(
        map_and_batch(img_class.test_image_processing,
                      batch_size=self.batch_size,
                      num_parallel_batches=16, drop_remainder=True))
    test_maps = test_maps.apply(
        prefetch_to_device(gpu_device, self.batch_size))

    train_iterator = train_pairs.make_one_shot_iterator()
    test_iterator = test_maps.make_one_shot_iterator()

    (self.real_x,
     self.real_x_segmap,
     self.real_x_segmap_onehot) = train_iterator.get_next()
    (self.real_x_segmap_test,
     self.real_x_segmap_test_onehot) = test_iterator.get_next()

    # ---- Generator, discriminator ----
    fake_x, x_mean, x_var = self.image_translate(
        segmap_img=self.real_x_segmap_onehot, x_img=self.real_x)
    real_logit, fake_logit = self.image_discriminate(
        segmap_img=self.real_x_segmap_onehot, real_img=self.real_x,
        fake_img=fake_x)

    # A gradient penalty is only added for the wgan variants and dragan.
    needs_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    if needs_penalty:
        GP = self.gradient_penalty(real=self.real_x,
                                   segmap=self.real_x_segmap_onehot,
                                   fake=fake_x)
    else:
        GP = 0

    # ---- Losses ----
    g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit)
    g_kl_loss = self.kl_weight * kl_loss(x_mean, x_var)
    g_vgg_loss = self.vgg_weight * VGGLoss()(self.real_x, fake_x)
    g_feature_loss = self.feature_weight * feature_loss(real_logit,
                                                        fake_logit)
    g_reg_loss = (regularization_loss('generator')
                  + regularization_loss('encoder'))

    d_adv_loss = self.adv_weight * (
        discriminator_loss(self.gan_type, real_logit, fake_logit) + GP)
    d_reg_loss = regularization_loss('discriminator')

    self.g_loss = (g_adv_loss + g_kl_loss + g_vgg_loss + g_feature_loss
                   + g_reg_loss)
    self.d_loss = d_adv_loss + d_reg_loss

    # ---- Result images ----
    self.fake_x = fake_x
    self.random_fake_x, _, _ = self.image_translate(
        segmap_img=self.real_x_segmap_onehot, random_style=True, reuse=True)

    # ---- Test ----
    segmap_depth = len(img_class.color_value_dict)
    self.test_segmap_image = tf.placeholder(
        tf.float32, [1, self.img_height, self.img_width, segmap_depth])
    self.random_test_fake_x, _, _ = self.image_translate(
        segmap_img=self.test_segmap_image, random_style=True, reuse=True)

    self.test_guide_image = tf.placeholder(
        tf.float32, [1, self.img_height, self.img_width, self.img_ch])
    self.guide_test_fake_x, _, _ = self.image_translate(
        segmap_img=self.test_segmap_image, x_img=self.test_guide_image,
        reuse=True)

    # ---- Training ----
    G_vars = []
    D_vars = []
    for var in tf.trainable_variables():
        if 'encoder' in var.name or 'generator' in var.name:
            G_vars.append(var)
        if 'discriminator' in var.name:
            D_vars.append(var)

    if self.TTUR:
        # Fixed betas; the generator gets half and the discriminator twice
        # the base learning rate.
        beta1, beta2 = 0.0, 0.9
        g_lr, d_lr = self.lr / 2, self.lr * 2
    else:
        beta1, beta2 = self.beta1, self.beta2
        g_lr, d_lr = self.lr, self.lr

    g_adam = tf.train.AdamOptimizer(g_lr, beta1=beta1, beta2=beta2)
    self.G_optim = g_adam.minimize(self.g_loss, var_list=G_vars)
    d_adam = tf.train.AdamOptimizer(d_lr, beta1=beta1, beta2=beta2)
    self.D_optim = d_adam.minimize(self.d_loss, var_list=D_vars)

    # ---- Summary ----
    self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
    self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

    self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
    self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss", g_kl_loss)
    self.summary_g_vgg_loss = tf.summary.scalar("g_vgg_loss", g_vgg_loss)
    self.summary_g_feature_loss = tf.summary.scalar("g_feature_loss",
                                                    g_feature_loss)

    self.G_loss = tf.summary.merge([
        self.summary_g_loss,
        self.summary_g_adv_loss,
        self.summary_g_kl_loss,
        self.summary_g_vgg_loss,
        self.summary_g_feature_loss,
    ])
    self.D_loss = tf.summary.merge([self.summary_d_loss])
def build_model(self):
    """Build the TensorFlow graph for the current phase.

    When ``self.phase == 'train'`` this creates the learning-rate
    placeholder, one input pipeline per domain, every generator and
    discriminator pass, the loss terms and the two Adam train ops.
    For any other phase only two single-image placeholders and their
    translations are created.

    Attributes set in the train phase: ``lr``, ``domain_A``, ``domain_B``,
    ``Generator_loss``, ``Discriminator_loss``, ``fake_A``, ``fake_B``,
    ``real_A``, ``real_B``, ``imgba``, ``imgab``, ``G_optim``, ``D_optim``.
    Attributes set otherwise: ``test_domain_A``, ``test_domain_B``,
    ``test_fake_B``, ``test_fake_A``.
    """
    if self.phase == 'train':
        # Learning rate (step size); fed at run time.
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        """ Input Image"""
        Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                     self.augment_flag)

        # Slice the path lists so that every path is one dataset element.
        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        gpu_device = '/gpu:0'
        # Per-domain pipeline:
        #   * shuffle_and_repeat(100): shuffle with a 100-element buffer and
        #     repeat.
        #   * map_and_batch(fn, batch_size, num_parallel_batches,
        #     drop_remainder): apply `fn` to every element, combine
        #     `batch_size` consecutive results into one batch, build 16
        #     batches in parallel, and drop a final batch that is smaller
        #     than `batch_size`.
        #   * prefetch_to_device: stage the batches on the GPU.
        # Original author's note (unverified): each batch is a randomly
        # augmented 32*64*64*3 array.
        trainA = trainA.apply(shuffle_and_repeat(100)).apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True)).apply(
                              prefetch_to_device(gpu_device, None))
        trainB = trainB.apply(shuffle_and_repeat(100)).apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True)).apply(
                              prefetch_to_device(gpu_device, None))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()

        """ Define Generator, Discriminator """
        # Original author's notes (unverified): domain_A holds cartoon
        # images and domain_B real-person photos; the generators down-sample
        # and then up-sample their input; generate_a2b and generate_b2a use
        # two separate parameter sets and can be viewed as inverse
        # transforms of each other.
        x_ab, cam_ab = self.generate_a2b(self.domain_A)  # real a
        x_ba, cam_ba = self.generate_b2a(self.domain_B)  # real b

        # Cycle passes: feed each translation through the opposite
        # generator, reusing the variables created above.
        x_aba, _ = self.generate_b2a(x_ab, reuse=True)  # real b
        x_bab, _ = self.generate_a2b(x_ba, reuse=True)  # real a

        # Identity passes: each generator receives an image that is already
        # in its output domain.  Original author's note: this keeps the
        # colour regions unchanged by the translation.
        x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True)  # fake b
        x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True)  # fake a

        # Discriminator outputs for the real batches of both domains ...
        real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = \
            self.discriminate_real(self.domain_A, self.domain_B)
        # ... and for the generated batches.  Original author's note
        # (unverified): the outputs are a 32*8*8*1 convolutional map and a
        # (32, 2) concatenation of pooled results.
        fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = \
            self.discriminate_fake(x_ba, x_ab)

        """ Define Loss """
        # A gradient penalty is only added for WGAN-style and DRAGAN losses.
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A,
                                                   fake=x_ba,
                                                   scope="discriminator_A")
            GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B,
                                                   fake=x_ab,
                                                   scope="discriminator_B")
        else:
            GP_A, GP_CAM_A = 0, 0
            GP_B, GP_CAM_B = 0, 0

        # 1. Adversarial loss.  Original author's note (unverified): a
        #    least-squares loss is used.
        ''' 一、对抗损失(T) 用的最小二乘损失 '''
        # Generator side: one term on the logit and one on the CAM logit.
        G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) +
                       generator_loss(self.gan_type, fake_A_cam_logit))
        G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) +
                       generator_loss(self.gan_type, fake_B_cam_logit))

        # Discriminator side: real-vs-fake terms on both outputs plus the
        # gradient penalties (zero unless enabled above).
        D_ad_loss_A = (
            discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) +
            discriminator_loss(self.gan_type, real_A_cam_logit,
                               fake_A_cam_logit) + GP_A + GP_CAM_A)
        D_ad_loss_B = (
            discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) +
            discriminator_loss(self.gan_type, real_B_cam_logit,
                               fake_B_cam_logit) + GP_B + GP_CAM_B)

        # 2. Cycle loss: an image translated to the other domain and back
        #    is compared with the original image.
        ''' 二、Cycle 损失(T) the image should be successfully translated back to the original domain '''
        reconstruction_A = L1_loss(x_aba, self.domain_A)  # reconstruction
        reconstruction_B = L1_loss(x_bab, self.domain_B)  # reconstruction

        # 3. Identity loss.  Original author's note: makes sure identity
        #    information is not lost when translating between A and B.
        ''' 三、Identity 损失 确保在A B相互变化时,身份信息不丢失 '''
        identity_A = L1_loss(x_aa, self.domain_A)
        identity_B = L1_loss(x_bb, self.domain_B)

        # 4. CAM loss.  Original author's note: the auxiliary classifier
        #    tells G and D where to focus the transformation; the heat map
        #    should respond for A -> B and stay silent for A -> A.
        ''' 四、CAM 损失 利用辅助分类器,使得G和D知道在哪里进行强化变换 在A变B时,热力图应该有明显显示 在A变A时,热力图应该没有显示 '''
        cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
        cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

        # Weighted generator terms.  Open question left by the original
        # author: how were the initial weights chosen?  (Their inline notes
        # give 1 / 10 / 10 / 1000 for adv / cycle / identity / cam.)
        Generator_A_gan = self.adv_weight * G_ad_loss_A
        # reconstruction_B compares
        # generate_a2b(generate_b2a(domain_B)) with domain_B.
        Generator_A_cycle = self.cycle_weight * reconstruction_B
        Generator_A_identity = self.identity_weight * identity_A
        Generator_A_cam = self.cam_weight * cam_A

        Generator_B_gan = self.adv_weight * G_ad_loss_B
        Generator_B_cycle = self.cycle_weight * reconstruction_A
        Generator_B_identity = self.identity_weight * identity_B
        Generator_B_cam = self.cam_weight * cam_B

        Generator_A_loss = (Generator_A_gan + Generator_A_cycle +
                            Generator_A_identity + Generator_A_cam)
        Generator_B_loss = (Generator_B_gan + Generator_B_cycle +
                            Generator_B_identity + Generator_B_cam)

        Discriminator_A_loss = self.adv_weight * D_ad_loss_A
        Discriminator_B_loss = self.adv_weight * D_ad_loss_B

        # Total losses, including the regularisation term of each scope.
        self.Generator_loss = Generator_A_loss + Generator_B_loss + \
            regularization_loss('generator')
        self.Discriminator_loss = Discriminator_A_loss + \
            Discriminator_B_loss + regularization_loss('discriminator')

        """ Result Image """
        # Generated images (kept for saving) ...
        self.fake_A = x_ba
        self.fake_B = x_ab

        # ... and the real input batches.
        self.real_A = self.domain_A
        self.real_B = self.domain_B

        # NOTE(review): the original assigned the undefined names `imgba`
        # and `imgab` here, which raises NameError while the graph is being
        # built.  They are bound to the translated images so that both
        # attributes still exist for any caller that reads them -- confirm
        # which tensors were actually intended.
        self.imgba = x_ba
        self.imgab = x_ab

        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        # var_list restricts each optimizer to its own set of variables.
        self.G_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss,
                                                      var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(
                self.Discriminator_loss, var_list=D_vars)

        """" Summary """
        ''' 画图 '''
        # NOTE(review): every summary op (the per-term scalar summaries, the
        # 'rho' histograms and the merged `self.G_loss` / `self.D_loss`) was
        # commented out in this variant, so this method does not define
        # `self.G_loss` or `self.D_loss`.

    else:
        """ Test """
        self.test_domain_A = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='test_domain_A')
        self.test_domain_B = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
def build_model(self):
    """Assemble either the test translators or the full training graph.

    Outside the ``'train'`` phase only ``test_domain_A`` / ``test_domain_B``
    placeholders and ``test_fake_B`` / ``test_fake_A`` are created.  In the
    train phase the method wires up both input pipelines, all generator and
    discriminator passes, the weighted losses, two Adam train ops
    (``G_optim`` / ``D_optim``) and the merged summaries ``G_loss`` /
    ``D_loss``.
    """
    if self.phase != 'train':
        """ Test """
        test_shape = [1, self.img_size, self.img_size, self.img_ch]
        self.test_domain_A = tf.placeholder(tf.float32, test_shape,
                                            name='test_domain_A')
        self.test_domain_B = tf.placeholder(tf.float32, test_shape,
                                            name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
        return

    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    """ Input Image"""
    loader = ImageData(self.img_size, self.img_ch, self.augment_flag)

    ds_a = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    ds_b = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

    gpu_device = '/gpu:0'

    def _feed(ds):
        # shuffle + repeat, then process and batch, then prefetch to the GPU
        ds = ds.apply(shuffle_and_repeat(self.dataset_num))
        ds = ds.apply(
            map_and_batch(loader.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        return ds.apply(prefetch_to_device(gpu_device, None))

    ds_a = _feed(ds_a)
    ds_b = _feed(ds_b)

    iter_a = ds_a.make_one_shot_iterator()
    iter_b = ds_b.make_one_shot_iterator()

    self.domain_A = iter_a.get_next()
    self.domain_B = iter_b.get_next()

    """ Define Generator, Discriminator """
    # direct translations (these calls create the generator variables)
    a2b, heat_ab = self.generate_a2b(self.domain_A)  # real a
    b2a, heat_ba = self.generate_b2a(self.domain_B)  # real b

    # round trips through the opposite generator
    a2b2a, _ = self.generate_b2a(a2b, reuse=True)  # real b
    b2a2b, _ = self.generate_a2b(b2a, reuse=True)  # real a

    # each generator applied to an image of its own output domain
    a2a, heat_aa = self.generate_b2a(self.domain_A, reuse=True)  # fake b
    b2b, heat_bb = self.generate_a2b(self.domain_B, reuse=True)  # fake a

    (real_a_logit, real_a_cam_logit,
     real_b_logit, real_b_cam_logit) = self.discriminate_real(
         self.domain_A, self.domain_B)
    (fake_a_logit, fake_a_cam_logit,
     fake_b_logit, fake_b_cam_logit) = self.discriminate_fake(b2a, a2b)

    """ Define Loss """
    needs_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    if not needs_penalty:
        gp_a, gp_cam_a = 0, 0
        gp_b, gp_cam_b = 0, 0
    else:
        gp_a, gp_cam_a = self.gradient_panalty(real=self.domain_A,
                                               fake=b2a,
                                               scope="discriminator_A")
        gp_b, gp_cam_b = self.gradient_panalty(real=self.domain_B,
                                               fake=a2b,
                                               scope="discriminator_B")

    # adversarial terms, generator side
    adv_g_a = (generator_loss(self.gan_type, fake_a_logit) +
               generator_loss(self.gan_type, fake_a_cam_logit))
    adv_g_b = (generator_loss(self.gan_type, fake_b_logit) +
               generator_loss(self.gan_type, fake_b_cam_logit))

    # adversarial terms, discriminator side (plus gradient penalties)
    adv_d_a = (discriminator_loss(self.gan_type, real_a_logit, fake_a_logit) +
               discriminator_loss(self.gan_type, real_a_cam_logit,
                                  fake_a_cam_logit) + gp_a + gp_cam_a)
    adv_d_b = (discriminator_loss(self.gan_type, real_b_logit, fake_b_logit) +
               discriminator_loss(self.gan_type, real_b_cam_logit,
                                  fake_b_cam_logit) + gp_b + gp_cam_b)

    cycle_a = L1_loss(a2b2a, self.domain_A)  # reconstruction
    cycle_b = L1_loss(b2a2b, self.domain_B)  # reconstruction

    same_a = L1_loss(a2a, self.domain_A)
    same_b = L1_loss(b2b, self.domain_B)

    attn_a = cam_loss(source=heat_ba, non_source=heat_aa)
    attn_b = cam_loss(source=heat_ab, non_source=heat_bb)

    # weighted generator terms; note the A side uses the B reconstruction
    g_a_gan = self.adv_weight * adv_g_a
    g_a_cycle = self.cycle_weight * cycle_b
    g_a_identity = self.identity_weight * same_a
    g_a_cam = self.cam_weight * attn_a

    g_b_gan = self.adv_weight * adv_g_b
    g_b_cycle = self.cycle_weight * cycle_a
    g_b_identity = self.identity_weight * same_b
    g_b_cam = self.cam_weight * attn_b

    g_a_total = g_a_gan + g_a_cycle + g_a_identity + g_a_cam
    g_b_total = g_b_gan + g_b_cycle + g_b_identity + g_b_cam

    d_a_total = self.adv_weight * adv_d_a
    d_b_total = self.adv_weight * adv_d_b

    self.Generator_loss = g_a_total + g_b_total + regularization_loss(
        'generator')
    self.Discriminator_loss = d_a_total + d_b_total + regularization_loss(
        'discriminator')

    """ Result Image """
    self.fake_A = b2a
    self.fake_B = a2b

    self.real_A = self.domain_A
    self.real_B = self.domain_B

    """ Training """
    trainable = tf.trainable_variables()
    G_vars = [v for v in trainable if 'generator' in v.name]
    D_vars = [v for v in trainable if 'discriminator' in v.name]

    g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.G_optim = g_adam.minimize(self.Generator_loss, var_list=G_vars)
    d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
    self.D_optim = d_adam.minimize(self.Discriminator_loss, var_list=D_vars)

    """" Summary """
    # (attribute, summary tag, tensor) in creation order
    scalar_specs = (
        ('all_G_loss', "Generator_loss", self.Generator_loss),
        ('all_D_loss', "Discriminator_loss", self.Discriminator_loss),
        ('G_A_loss', "G_A_loss", g_a_total),
        ('G_A_gan', "G_A_gan", g_a_gan),
        ('G_A_cycle', "G_A_cycle", g_a_cycle),
        ('G_A_identity', "G_A_identity", g_a_identity),
        ('G_A_cam', "G_A_cam", g_a_cam),
        ('G_B_loss', "G_B_loss", g_b_total),
        ('G_B_gan', "G_B_gan", g_b_gan),
        ('G_B_cycle', "G_B_cycle", g_b_cycle),
        ('G_B_identity', "G_B_identity", g_b_identity),
        ('G_B_cam', "G_B_cam", g_b_cam),
        ('D_A_loss', "D_A_loss", d_a_total),
        ('D_B_loss', "D_B_loss", d_b_total),
    )
    for attr, tag, tensor in scalar_specs:
        setattr(self, attr, tf.summary.scalar(tag, tensor))

    # histogram plus min / max / mean for every trainable 'rho' variable
    self.rho_var = []
    for var in tf.trainable_variables():
        if 'rho' not in var.name:
            continue
        self.rho_var.extend([
            tf.summary.histogram(var.name, var),
            tf.summary.scalar(var.name + "_min", tf.reduce_min(var)),
            tf.summary.scalar(var.name + "_max", tf.reduce_max(var)),
            tf.summary.scalar(var.name + "_mean", tf.reduce_mean(var)),
        ])

    g_summary_list = [
        self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity,
        self.G_A_cam, self.G_B_loss, self.G_B_gan, self.G_B_cycle,
        self.G_B_identity, self.G_B_cam, self.all_G_loss
    ] + self.rho_var
    d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

    self.G_loss = tf.summary.merge(g_summary_list)
    self.D_loss = tf.summary.merge(d_summary_list)
def _initialize_tf_dataset(file_names,
                           augment,
                           over_sample,
                           shuffle,
                           target_size,
                           normalize,
                           nb_workers=8,
                           batch_size=64,
                           shuffle_buffer_size=3000,
                           input_type="img",
                           nr_epochs=1):
    """Create a batched ``tf.data`` input pipeline from TFRecord files.

    Every record is parsed into ``label`` (int64 scalar), ``shape`` (raw
    int64 bytes) and ``image`` (raw uint8 bytes).  The image is reshaped,
    optionally augmented and normalised, and the records are batched with
    ``map_and_batch`` (incomplete final batches are dropped).

    Args:
        file_names: Path(s) handed to ``tf.data.TFRecordDataset``.
        augment: If truthy, apply random flips, rotation, shift and zoom.
        over_sample: Accepted for interface compatibility; not used here.
        shuffle: If truthy, shuffle with ``shuffle_buffer_size`` and repeat
            the data ``nr_epochs`` times.  Otherwise the dataset is neither
            shuffled nor repeated.
        target_size: Sequence describing the image size; converted to a
            ``list`` when it is not one already.
        normalize: If truthy, subtract the per-channel mean and divide by
            the per-channel standard deviation hard-coded below.
        nb_workers: Number of batches built in parallel and size of the
            prefetch buffer.
        batch_size: Number of records per batch.
        shuffle_buffer_size: Buffer size used when ``shuffle`` is truthy.
        input_type: ``".jpeg"`` reshapes images to ``target_size + [3]``;
            any other value reshapes them to ``target_size``.
        nr_epochs: Repeat count; only applied when ``shuffle`` is truthy.

    Returns:
        A dataset whose elements are
        ``({'shape': shape, 'image': image}, label)`` batches.
    """
    import math

    import tensorflow as tf
    from tensorflow.contrib.data import map_and_batch
    from tensorflow.contrib.data import shuffle_and_repeat

    if not isinstance(target_size, list):
        target_size = list(target_size)

    # Static shape the raw bytes are reshaped to, before and after the
    # optional augmentation.
    if input_type == ".jpeg":
        image_shape = target_size + [3]
    else:
        image_shape = target_size

    with tf.name_scope('input_pipeline'):
        dataset = tf.data.TFRecordDataset(file_names)
        if shuffle:
            dataset = dataset.apply(
                shuffle_and_repeat(shuffle_buffer_size, nr_epochs))

        def _decode_and_augment_image(example_proto):
            """Parse one serialized example into (features, label)."""
            keys_to_features = {
                'label': tf.FixedLenFeature([], tf.int64),
                'shape': tf.FixedLenFeature([], tf.string),
                'image': tf.FixedLenFeature([], tf.string),
            }
            tfrecord_features = tf.parse_single_example(
                example_proto, keys_to_features)

            image = tf.decode_raw(tfrecord_features['image'], tf.uint8)
            shape = tf.decode_raw(tfrecord_features['shape'], tf.int64)
            image = tf.reshape(image, image_shape)
            label = tfrecord_features['label']

            if augment:
                # NOTE(review): this branch pads with a (3, 2) padding,
                # crops to `target_size + [3]` and resizes to `target_size`,
                # i.e. it expects a rank-3, 3-channel image and a 2-element
                # `target_size`.  That matches the '.jpeg' reshape above;
                # confirm it is never used with another `input_type`.
                image = tf.image.random_flip_left_right(image)
                image = tf.image.random_flip_up_down(image)

                # tf.contrib.image.rotate expects the angle in radians, so
                # the sampled angle (in degrees) is converted first.
                angle_deg = tf.random_uniform((), minval=-180, maxval=180)
                image = tf.contrib.image.rotate(
                    image, angle_deg * (math.pi / 180.0))

                # Random shift: zero-pad by up to 5 % of the size on both
                # sides of the first two axes, then crop back at a random
                # offset.
                width_shift = tf.random_uniform((), minval=0, maxval=0.05)
                height_shift = tf.random_uniform((), minval=0, maxval=0.05)
                horizontal_pad = tf.cast(
                    tf.ceil(width_shift * target_size[0]), tf.int32)
                vertical_pad = tf.cast(
                    tf.ceil(height_shift * target_size[1]), tf.int32)
                padding = tf.stack([
                    horizontal_pad, horizontal_pad, vertical_pad,
                    vertical_pad,
                    tf.constant(0),
                    tf.constant(0)
                ])
                # One (before, after) row per image axis; channels unpadded.
                padding = tf.reshape(padding, (3, 2))
                image = tf.pad(image, padding)
                image = tf.random_crop(image, target_size + [3])

                # Random zoom of +-10 %: crop or pad to a square of side
                # `new_dim` (derived from target_size[0] only), then resize
                # back to `target_size`.
                zoom = tf.random_uniform((), minval=-0.1, maxval=0.1)
                new_dim = tf.cast(tf.ceil((1 - zoom) * target_size[0]),
                                  dtype=tf.int32)
                image = tf.image.resize_image_with_crop_or_pad(
                    image, new_dim, new_dim)
                image = tf.image.resize_images(
                    image,
                    target_size,
                    method=tf.image.ResizeMethod.BILINEAR)

            if normalize:
                # Per-channel statistics, shaped [1, 1, 3] so they broadcast
                # over the first two image axes.
                std = tf.constant([70.53946096, 51.71475228, 43.03428563],
                                  dtype=tf.float32,
                                  shape=[1, 1, 3])
                mean = tf.constant([108.64628601, 75.86886597, 54.34005736],
                                   dtype=tf.float32,
                                   shape=[1, 1, 3])
                image = (tf.cast(image, dtype=tf.float32) - mean) / std

            label = tf.reshape(label, [1])
            # Re-assert the static shape lost by the augmentation ops.
            image = tf.reshape(image, image_shape)

            return {'shape': shape, 'image': image}, label

        dataset = dataset.apply(
            map_and_batch(_decode_and_augment_image,
                          batch_size=batch_size,
                          num_parallel_batches=nb_workers,
                          drop_remainder=True)).prefetch(nb_workers)

    return dataset
def build_model(self):
    """Build the graph for the phase stored in ``self.phase``.

    * ``'train'``: input pipeline, per-GPU / per-sample generator and
      discriminator passes, losses, Adam train ops, EMA update ops,
      summaries and a list of sample translations (``self.x_fake_list``).
    * ``'refer_test'``: placeholders ``custom_image`` / ``refer_image`` and
      ``refer_fake_image``, styled by the style encoder of the reference.
    * anything else: placeholder ``custom_image`` and ``custom_fake_image``,
      styled by the mapping network applied to a random code.

    ``self.ema`` is created in every phase.
    """
    self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)

    if self.phase == 'train' :
        """ Input Image"""
        img_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.label_list, self.augment_flag)
        img_class.preprocess()

        dataset_num = len(img_class.image)
        print("Dataset number : ", dataset_num)

        # Both values are fed at run time.
        self.lr = tf.placeholder(tf.float32, name='learning_rate')
        self.ds_weight_placeholder = tf.placeholder(tf.float32, name='ds_weight')

        img_and_label = tf.data.Dataset.from_tensor_slices((img_class.image, img_class.label))

        gpu_device = '/gpu:0'
        # One pipeline batch holds batch_size * gpu_num samples; it is split
        # per GPU below.
        img_and_label = img_and_label.apply(shuffle_and_repeat(dataset_num)).apply(
            map_and_batch(img_class.image_processing, self.batch_size * self.gpu_num, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))

        img_and_label_iterator = img_and_label.make_one_shot_iterator()

        self.x_real, label_org = img_and_label_iterator.get_next() # [bs, 256, 256, 3], [bs, 1]
        # label_trg = tf.random_shuffle(label_org)  # Target domain labels
        # Target labels are drawn uniformly from [0, c_dim), with the same
        # shape as the original labels.
        label_trg = tf.random_uniform(shape=tf.shape(label_org), minval=0, maxval=self.c_dim, dtype=tf.int32) # Target domain labels

        """ split """
        # Split the batch along axis 0 into one chunk per GPU.
        x_real_gpu_split = tf.split(self.x_real, num_or_size_splits=self.gpu_num, axis=0)
        label_org_gpu_split = tf.split(label_org, num_or_size_splits=self.gpu_num, axis=0)
        label_trg_gpu_split = tf.split(label_trg, num_or_size_splits=self.gpu_num, axis=0)

        # Per-GPU loss values, averaged after the device loop.
        g_adv_loss_per_gpu = []
        g_sty_recon_loss_per_gpu = []
        g_sty_diverse_loss_per_gpu = []
        g_cyc_loss_per_gpu = []
        g_loss_per_gpu = []

        d_adv_loss_per_gpu = []
        d_loss_per_gpu = []

        for gpu_id in range(self.gpu_num):
            with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
                # Variables are reused on every GPU after the first one.
                with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):

                    # Split this GPU's chunk again into single samples.
                    x_real_split = tf.split(x_real_gpu_split[gpu_id], num_or_size_splits=self.batch_size, axis=0)
                    label_org_split = tf.split(label_org_gpu_split[gpu_id], num_or_size_splits=self.batch_size, axis=0)
                    label_trg_split = tf.split(label_trg_gpu_split[gpu_id], num_or_size_splits=self.batch_size, axis=0)

                    # Accumulators: set on the first sample, concatenated
                    # along axis 0 for every later sample.
                    g_adv_loss = None
                    g_sty_recon_loss = None
                    g_sty_diverse_loss = None
                    g_cyc_loss = None

                    d_adv_loss = None
                    d_simple_gp = None
                    d_gp = None

                    for each_bs in range(self.batch_size) :
                        """ Define Generator, Discriminator """
                        x_real_each = x_real_split[each_bs] # [1, 256, 256, 3]
                        label_org_each = tf.squeeze(label_org_split[each_bs], axis=[0, 1]) # [1, 1] -> []
                        label_trg_each = tf.squeeze(label_trg_split[each_bs], axis=[0, 1])

                        # Three independent latent codes: one for the
                        # adversarial pass, two for style diversification.
                        random_style_code = tf.random_normal(shape=[1, self.style_dim])
                        random_style_code_1 = tf.random_normal(shape=[1, self.style_dim])
                        random_style_code_2 = tf.random_normal(shape=[1, self.style_dim])

                        # tf.gather picks the entry indexed by the target
                        # label from the mapping-network output.
                        # NOTE(review): presumably one style per domain
                        # along axis 0 -- confirm against mapping_network.
                        random_style = tf.gather(self.mapping_network(random_style_code), label_trg_each)
                        random_style_1 = tf.gather(self.mapping_network(random_style_code_1), label_trg_each)
                        random_style_2 = tf.gather(self.mapping_network(random_style_code_2), label_trg_each)

                        x_fake = self.generator(x_real_each, random_style) # for adversarial objective
                        x_fake_1 = self.generator(x_real_each, random_style_1) # for style diversification
                        x_fake_2 = self.generator(x_real_each, random_style_2) # for style diversification

                        x_real_each_style = tf.gather(self.style_encoder(x_real_each), label_org_each) # for cycle consistency
                        x_fake_style = tf.gather(self.style_encoder(x_fake), label_trg_each) # for style reconstruction

                        x_cycle = self.generator(x_fake, x_real_each_style) # for cycle consistency

                        # Discriminator outputs, indexed by the original
                        # label for the real image and by the target label
                        # for the fake one.
                        real_logit = tf.gather(self.discriminator(x_real_each), label_org_each)
                        fake_logit = tf.gather(self.discriminator(x_fake), label_trg_each)

                        """ Define loss """
                        # Gradient penalty only for WGAN-style / DRAGAN
                        # losses; otherwise a constant zero of shape [1].
                        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
                            GP = self.gradient_panalty(real=x_real_each, fake=x_fake, real_label=label_org_each)
                        else:
                            GP = tf.constant([0], tf.float32)

                        # NOTE(review): the concatenations below require the
                        # loss helpers to return tensors of rank >= 1 --
                        # confirm against their definitions.
                        if each_bs == 0 :
                            g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit)
                            g_sty_recon_loss = self.sty_weight * L1_loss(random_style, x_fake_style)
                            g_sty_diverse_loss = self.ds_weight_placeholder * L1_loss(x_fake_1, x_fake_2)
                            g_cyc_loss = self.cyc_weight * L1_loss(x_real_each, x_cycle)

                            d_adv_loss = self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit)
                            d_simple_gp = self.adv_weight * simple_gp(real_logit, fake_logit, x_real_each, x_fake, r1_gamma=self.r1_weight, r2_gamma=0.0)
                            d_gp = self.adv_weight * GP

                        else :
                            g_adv_loss = tf.concat([g_adv_loss, self.adv_weight * generator_loss(self.gan_type, fake_logit)], axis=0)
                            g_sty_recon_loss = tf.concat([g_sty_recon_loss, self.sty_weight * L1_loss(random_style, x_fake_style)], axis=0)
                            g_sty_diverse_loss = tf.concat([g_sty_diverse_loss, self.ds_weight_placeholder * L1_loss(x_fake_1, x_fake_2)], axis=0)
                            g_cyc_loss = tf.concat([g_cyc_loss, self.cyc_weight * L1_loss(x_real_each, x_cycle)], axis=0)

                            d_adv_loss = tf.concat([d_adv_loss, self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit)], axis=0)
                            d_simple_gp = tf.concat([d_simple_gp, self.adv_weight * simple_gp(real_logit, fake_logit, x_real_each, x_fake, r1_gamma=self.r1_weight, r2_gamma=0.0)], axis=0)
                            d_gp = tf.concat([d_gp, self.adv_weight * GP], axis=0)

                    # Average the accumulated per-sample terms of this GPU.
                    g_adv_loss = tf.reduce_mean(g_adv_loss)
                    g_sty_recon_loss = tf.reduce_mean(g_sty_recon_loss)
                    g_sty_diverse_loss = tf.reduce_mean(g_sty_diverse_loss)
                    g_cyc_loss = tf.reduce_mean(g_cyc_loss)

                    d_adv_loss = tf.reduce_mean(d_adv_loss)
                    # The simple penalty is summed over axes 1-3 per sample
                    # and then averaged.
                    d_simple_gp = tf.reduce_mean(tf.reduce_sum(d_simple_gp, axis=[1, 2, 3]))
                    d_gp = tf.reduce_mean(d_gp)

                    # The diversity term is subtracted, so minimising g_loss
                    # increases the distance between x_fake_1 and x_fake_2.
                    g_loss = g_adv_loss + g_sty_recon_loss - g_sty_diverse_loss + g_cyc_loss
                    d_loss = d_adv_loss + d_simple_gp + d_gp

                    g_adv_loss_per_gpu.append(g_adv_loss)
                    g_sty_recon_loss_per_gpu.append(g_sty_recon_loss)
                    g_sty_diverse_loss_per_gpu.append(g_sty_diverse_loss)
                    g_cyc_loss_per_gpu.append(g_cyc_loss)

                    d_adv_loss_per_gpu.append(d_adv_loss)

                    g_loss_per_gpu.append(g_loss)
                    d_loss_per_gpu.append(d_loss)

        # Average every term over the GPUs.
        g_adv_loss = tf.reduce_mean(g_adv_loss_per_gpu)
        g_sty_recon_loss = tf.reduce_mean(g_sty_recon_loss_per_gpu)
        g_sty_diverse_loss = tf.reduce_mean(g_sty_diverse_loss_per_gpu)
        g_cyc_loss = tf.reduce_mean(g_cyc_loss_per_gpu)
        self.g_loss = tf.reduce_mean(g_loss_per_gpu)

        d_adv_loss = tf.reduce_mean(d_adv_loss_per_gpu)
        self.d_loss = tf.reduce_mean(d_loss_per_gpu)

        """ Training """
        # Variables are grouped by substrings of their names.
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        E_vars = [var for var in t_vars if 'encoder' in var.name]
        F_vars = [var for var in t_vars if 'mapping' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        # All generator-side groups minimise g_loss; the 'mapping' group
        # uses a learning rate scaled by 0.01.  The multi-GPU branch differs
        # only by colocate_gradients_with_ops=True.
        if self.gpu_num == 1 :
            prev_g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=G_vars)
            prev_e_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=E_vars)
            prev_f_optimizer = tf.train.AdamOptimizer(self.lr * 0.01, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=F_vars)

            self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.d_loss, var_list=D_vars)

        else :
            prev_g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=G_vars, colocate_gradients_with_ops=True)
            prev_e_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=E_vars, colocate_gradients_with_ops=True)
            prev_f_optimizer = tf.train.AdamOptimizer(self.lr * 0.01, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=F_vars, colocate_gradients_with_ops=True)

            self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.d_loss, var_list=D_vars, colocate_gradients_with_ops=True)

        # The EMA update ops are created under a control dependency on the
        # three generator-side train ops, so running g/e/f_optimizer runs
        # those train ops first.
        with tf.control_dependencies([prev_g_optimizer, prev_e_optimizer, prev_f_optimizer]):
            self.g_optimizer = self.ema.apply(G_vars)
            self.e_optimizer = self.ema.apply(E_vars)
            self.f_optimizer = self.ema.apply(F_vars)

        """" Summary """
        # From here on the self.* attributes hold summary ops; the loss
        # tensors themselves stay in self.g_loss / self.d_loss and locals.
        self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
        self.g_sty_recon_loss = tf.summary.scalar("g_sty_recon_loss", g_sty_recon_loss)
        self.g_sty_diverse_loss = tf.summary.scalar("g_sty_diverse_loss", g_sty_diverse_loss)
        self.g_cyc_loss = tf.summary.scalar("g_cyc_loss", g_cyc_loss)

        self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)

        g_summary_list = [self.Generator_loss, self.g_adv_loss, self.g_sty_recon_loss, self.g_sty_diverse_loss, self.g_cyc_loss]
        d_summary_list = [self.Discriminator_loss, self.d_adv_loss]

        self.g_summary_loss = tf.summary.merge(g_summary_list)
        self.d_summary_loss = tf.summary.merge(d_summary_list)

        """ Result Image """
        def return_g_images(generator, image, code):
            """Run ``generator`` on ``image`` with style ``code``."""
            x = generator(image, code)
            return x

        self.x_fake_list = []
        # Only the first image of the batch is used for the samples.
        first_x_real = tf.expand_dims(self.x_real[0], axis=0)

        # One index per domain: 0 .. c_dim - 1.
        label_fix_list = tf.constant([idx for idx in range(self.c_dim)])

        # For each of num_style random codes, translate the first image
        # once per domain index.
        for _ in range(self.num_style):
            random_style_code = tf.truncated_normal(shape=[1, self.style_dim])
            self.x_fake_list.append(tf.map_fn(
                lambda c: return_g_images(self.generator,
                                          first_x_real,
                                          tf.gather(self.mapping_network(random_style_code), c)),
                label_fix_list, dtype=tf.float32))

    elif self.phase == 'refer_test':
        """ Test """

        def return_g_images(generator, image, code):
            """Run ``generator`` on ``image`` with style ``code``."""
            x = generator(image, code)
            return x

        self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image')
        self.refer_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='refer_image')

        label_fix_list = tf.constant([idx for idx in range(self.c_dim)])

        # Style comes from the style encoder applied to the reference image,
        # gathered once per domain index.
        self.refer_fake_image = tf.map_fn(
            lambda c : return_g_images(self.generator,
                                       self.custom_image,
                                       tf.gather(self.style_encoder(self.refer_image), c)),
            label_fix_list, dtype=tf.float32)

    else :
        """ Test """

        def return_g_images(generator, image, code):
            """Run ``generator`` on ``image`` with style ``code``."""
            x = generator(image, code)
            return x

        self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image')
        label_fix_list = tf.constant([idx for idx in range(self.c_dim)])

        # Style comes from the mapping network applied to one random code,
        # gathered once per domain index.
        random_style_code = tf.truncated_normal(shape=[1, self.style_dim])
        self.custom_fake_image = tf.map_fn(
            lambda c : return_g_images(self.generator,
                                       self.custom_image,
                                       tf.gather(self.mapping_network(random_style_code), c)),
            label_fix_list, dtype=tf.float32)
def build_model(self):
    """Build the training and test graph.

    Creates three input pipelines (``real_A``, ``real_B``,
    ``real_B_smooth``), a single-image test placeholder, the generator and
    discriminator passes, the VGG / adversarial losses, three Adam train
    ops (``init_optim`` on the VGG loss only, ``G_optim``, ``D_optim``) and
    the merged summaries.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    """ Input Image"""
    loader = ImageData(self.img_size, self.img_ch, self.augment_flag)

    ds_a = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
    ds_b = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
    ds_b_smooth = tf.data.Dataset.from_tensor_slices(
        self.trainB_smooth_dataset)

    gpu_device = '/gpu:0'

    def _to_batches(ds):
        # shuffle + repeat, then process and batch, then prefetch to the GPU
        ds = ds.apply(shuffle_and_repeat(self.dataset_num))
        ds = ds.apply(
            map_and_batch(loader.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        return ds.apply(prefetch_to_device(gpu_device, self.batch_size))

    ds_a = _to_batches(ds_a)
    ds_b = _to_batches(ds_b)
    ds_b_smooth = _to_batches(ds_b_smooth)

    iter_a = ds_a.make_one_shot_iterator()
    iter_b = ds_b.make_one_shot_iterator()
    iter_b_smooth = ds_b_smooth.make_one_shot_iterator()

    self.real_A = iter_a.get_next()
    self.real_B = iter_b.get_next()
    self.real_B_smooth = iter_b_smooth.get_next()

    self.test_real_A = tf.placeholder(
        tf.float32, [1, self.img_size, self.img_size, self.img_ch],
        name='test_real_A')

    """ Define Generator, Discriminator """
    self.fake_B = self.generator(self.real_A)

    # first call creates the discriminator variables, later calls reuse them
    logit_real = self.discriminator(self.real_B)
    logit_fake = self.discriminator(self.fake_B, reuse=True)
    logit_smooth = self.discriminator(self.real_B_smooth, reuse=True)

    """ Define Loss """
    # penalty is enabled when gan_type mentions 'gp', 'lp' or 'dragan'
    wants_penalty = any(tag in self.gan_type
                        for tag in ('gp', 'lp', 'dragan'))
    if wants_penalty:
        GP = self.gradient_panalty(real=self.real_B, fake=self.fake_B) + \
            self.gradient_panalty(self.real_B, fake=self.real_B_smooth)
    else:
        GP = 0.0

    content_term = self.vgg_weight * vgg_loss(self.real_A, self.fake_B)
    adv_g_term = self.adv_weight * generator_loss(self.gan_type, logit_fake)
    adv_d_term = self.adv_weight * discriminator_loss(
        self.gan_type, logit_real, logit_fake, logit_smooth) + GP

    self.Vgg_loss = content_term
    self.Generator_loss = adv_g_term + content_term
    self.Discriminator_loss = adv_d_term

    """ Result Image """
    self.test_fake_B = self.generator(self.test_real_A, reuse=True)

    """ Training """
    trainable = tf.trainable_variables()
    G_vars = [v for v in trainable if 'generator' in v.name]
    D_vars = [v for v in trainable if 'discriminator' in v.name]

    def _adam():
        # every train op gets its own optimizer with identical settings
        return tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)

    self.init_optim = _adam().minimize(self.Vgg_loss, var_list=G_vars)
    self.G_optim = _adam().minimize(self.Generator_loss, var_list=G_vars)
    self.D_optim = _adam().minimize(self.Discriminator_loss, var_list=D_vars)

    """" Summary """
    scalar = tf.summary.scalar
    self.G_loss = scalar("Generator_loss", self.Generator_loss)
    self.D_loss = scalar("Discriminator_loss", self.Discriminator_loss)

    self.G_gan = scalar("G_gan", adv_g_term)
    self.G_vgg = scalar("G_vgg", content_term)

    merge = tf.summary.merge
    self.V_loss_merge = merge([self.G_vgg])
    self.G_loss_merge = merge([self.G_loss, self.G_gan, self.G_vgg])
    self.D_loss_merge = merge([self.D_loss])
def build_model(self):
    """Build the GAN graph: inputs, losses, optimizers, sampler, summaries.

    Real images come either from a tf.data pipeline (when
    ``self.custom_dataset`` is truthy) or from a ``real_images``
    placeholder. Noise is fed through the ``z`` placeholder of shape
    ``[batch_size, 1, 1, z_dim]``. ``self.Ra`` is forwarded unchanged to
    both loss functions.
    """
    # ---- Graph input: images ----
    if not self.custom_dataset:
        image_shape = [self.batch_size, self.img_size, self.img_size, self.c_dim]
        self.inputs = tf.placeholder(tf.float32, image_shape, name='real_images')
    else:
        image_data = ImageData(self.img_size, self.c_dim)
        dataset = tf.data.Dataset.from_tensor_slices(self.data)

        device = '/gpu:0'
        # shuffle + repeat -> map + batch -> prefetch onto the GPU
        dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
        dataset = dataset.apply(
            map_and_batch(image_data.image_processing, self.batch_size,
                          num_parallel_batches=16, drop_remainder=True))
        dataset = dataset.apply(prefetch_to_device(device, self.batch_size))

        iterator = dataset.make_one_shot_iterator()
        self.inputs = iterator.get_next()

    # ---- Graph input: noise ----
    noise_shape = [self.batch_size, 1, 1, self.z_dim]
    self.z = tf.placeholder(tf.float32, noise_shape, name='z')

    # ---- Losses ----
    # Discriminator on real images, then on generated ones (shared weights).
    real_logits = self.discriminator(self.inputs)

    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    needs_penalty = any(tag in self.gan_type for tag in ('gp', 'lp', 'dragan'))
    if needs_penalty:
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    d_adv = discriminator_loss(self.Ra, self.gan_type,
                               real=real_logits, fake=fake_logits)
    self.d_loss = d_adv + GP

    self.g_loss = generator_loss(self.Ra, self.gan_type,
                                 real=real_logits, fake=fake_logits)

    # ---- Training ops ----
    # Split trainable variables by the network name found in var.name.
    trainable = tf.trainable_variables()
    d_vars = [v for v in trainable if 'discriminator' in v.name]
    g_vars = [v for v in trainable if 'generator' in v.name]

    d_adam = tf.train.AdamOptimizer(self.d_learning_rate,
                                    beta1=self.beta1, beta2=self.beta2)
    self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

    g_adam = tf.train.AdamOptimizer(self.g_learning_rate,
                                    beta1=self.beta1, beta2=self.beta2)
    self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

    # ---- Testing ----
    # Sampler: same generator weights, built with is_training=False.
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- Summaries ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def build_model(self):
    """Build the two-stage (64px -> 256px) conditional GAN graph.

    When ``self.phase == 'train'`` this creates the input pipeline, both
    generators and discriminators, adversarial + KL losses, one Adam
    training op per network and the merged scalar summaries. In any other
    phase it creates the input pipeline and the two generators with
    ``is_training=False`` and exposes ``test_fake_img``,
    ``test_real_img`` and ``test_fake_img_64``.

    The test pipeline is batched with ``self.batch_size`` (it used to be
    a hard-coded 32) so that the embedding batch always has the same
    batch dimension as the noise tensor, which is
    ``[self.batch_size, self.z_dim]`` in both phases.
    """

    def build_inputs(img_data_class):
        """Create the image/embedding pipeline and return the noise tensor.

        Sets ``self.dataset_num``, ``self.real_img_256`` and
        ``self.embedding`` as side effects.
        """
        img_data_class.preprocess()
        self.dataset_num = len(img_data_class.image_list)

        img_and_embedding = tf.data.Dataset.from_tensor_slices(
            (img_data_class.image_list, img_data_class.embedding))

        gpu_device = '/gpu:0'
        # shuffle + repeat -> map + batch -> prefetch onto the GPU.
        # drop_remainder=True keeps the batch dimension static, so it has
        # to agree with the noise batch dimension below.
        img_and_embedding = img_and_embedding.apply(
            shuffle_and_repeat(self.dataset_num))
        img_and_embedding = img_and_embedding.apply(
            map_and_batch(img_data_class.image_processing,
                          batch_size=self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        img_and_embedding = img_and_embedding.apply(
            prefetch_to_device(gpu_device, None))

        iterator = img_and_embedding.make_one_shot_iterator()
        self.real_img_256, all_embeddings = iterator.get_next()

        # Draw one random index in [0, 10) along axis 1 of the embeddings
        # (presumably one of 10 sentence embeddings per image - confirm
        # against Image_data).
        sentence_index = tf.random.uniform(shape=[], minval=0, maxval=10,
                                           dtype=tf.int32)
        self.embedding = tf.gather(all_embeddings, indices=sentence_index,
                                   axis=1)  # [bs, 1024]

        return tf.random_normal(shape=[self.batch_size, self.z_dim])

    if self.phase == 'train':
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # ---- Input images ----
        img_data_class = Image_data(self.img_height, self.img_width,
                                    self.img_ch, self.dataset_path,
                                    self.augment_flag)
        noise = build_inputs(img_data_class)

        # ---- Generators ----
        self.fake_img_64, mu_64, logvar_64 = self.generator_1(
            self.embedding, noise)
        self.fake_img_256, mu_256, logvar_256 = self.generator_2(
            self.fake_img_64, self.embedding)

        # The stage-1 discriminator sees the real image resized to 64x64.
        self.real_img_64 = tf.image.resize_bilinear(self.real_img_256,
                                                    size=[64, 64])

        self.real_img = [self.real_img_64, self.real_img_256]
        self.fake_img = [self.fake_img_64, self.fake_img_256]

        # ---- Discriminators ----
        real_logit_64 = self.discriminator_1(self.real_img_64, mu_64)
        fake_logit_64 = self.discriminator_1(self.fake_img_64, mu_64)

        real_logit_256 = self.discriminator_2(self.real_img_256, mu_256)
        fake_logit_256 = self.discriminator_2(self.fake_img_256, mu_256)

        # ---- Losses, stage 1 (64px) ----
        g_adv_loss_64 = generator_loss(self.gan_type,
                                       fake_logit_64) * self.adv_weight
        g_kl_loss_64 = kl_loss(mu_64, logvar_64) * self.kl_weight
        d_adv_loss_64 = discriminator_loss(
            self.gan_type, real_logit_64, fake_logit_64) * self.adv_weight

        g_loss_64 = g_adv_loss_64 + g_kl_loss_64
        d_loss_64 = d_adv_loss_64

        # ---- Losses, stage 2 (256px) ----
        g_adv_loss_256 = generator_loss(self.gan_type,
                                        fake_logit_256) * self.adv_weight
        g_kl_loss_256 = kl_loss(mu_256, logvar_256) * self.kl_weight
        d_adv_loss_256 = discriminator_loss(
            self.gan_type, real_logit_256, fake_logit_256) * self.adv_weight

        g_loss_256 = g_adv_loss_256 + g_kl_loss_256
        d_loss_256 = d_adv_loss_256

        self.g_loss = [g_loss_64, g_loss_256]
        self.d_loss = [d_loss_64, d_loss_256]

        # ---- Training ops ----
        # Each network gets its own Adam optimizer over the variables
        # whose name contains that network's scope name.
        t_vars = tf.trainable_variables()
        G1_vars = [var for var in t_vars if 'generator_1' in var.name]
        G2_vars = [var for var in t_vars if 'generator_2' in var.name]
        D1_vars = [var for var in t_vars if 'discriminator_1' in var.name]
        D2_vars = [var for var in t_vars if 'discriminator_2' in var.name]

        g1_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_64,
                                                      var_list=G1_vars)
        g2_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_256,
                                                      var_list=G2_vars)

        d1_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_64,
                                                      var_list=D1_vars)
        d2_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_256,
                                                      var_list=D2_vars)

        self.g_optim = [g1_optim, g2_optim]
        self.d_optim = [d1_optim, d2_optim]

        # ---- Summaries ----
        self.summary_g_loss_64 = tf.summary.scalar("g_loss_64", g_loss_64)
        self.summary_g_loss_256 = tf.summary.scalar("g_loss_256", g_loss_256)
        self.summary_d_loss_64 = tf.summary.scalar("d_loss_64", d_loss_64)
        self.summary_d_loss_256 = tf.summary.scalar("d_loss_256", d_loss_256)

        self.summary_g_adv_loss_64 = tf.summary.scalar(
            "g_adv_loss_64", g_adv_loss_64)
        self.summary_g_adv_loss_256 = tf.summary.scalar(
            "g_adv_loss_256", g_adv_loss_256)
        self.summary_g_kl_loss_64 = tf.summary.scalar(
            "g_kl_loss_64", g_kl_loss_64)
        self.summary_g_kl_loss_256 = tf.summary.scalar(
            "g_kl_loss_256", g_kl_loss_256)

        self.summary_d_adv_loss_64 = tf.summary.scalar(
            "d_adv_loss_64", d_adv_loss_64)
        self.summary_d_adv_loss_256 = tf.summary.scalar(
            "d_adv_loss_256", d_adv_loss_256)

        g_summary_list = [
            self.summary_g_loss_64, self.summary_g_loss_256,
            self.summary_g_adv_loss_64, self.summary_g_adv_loss_256,
            self.summary_g_kl_loss_64, self.summary_g_kl_loss_256
        ]

        d_summary_list = [
            self.summary_d_loss_64, self.summary_d_loss_256,
            self.summary_d_adv_loss_64, self.summary_d_adv_loss_256
        ]

        self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
        self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

    else:
        # ---- Test ----
        # Augmentation is always off at test time.
        img_data_class = Image_data(self.img_height, self.img_width,
                                    self.img_ch, self.dataset_path,
                                    augment_flag=False)
        noise = build_inputs(img_data_class)

        # Only the generated images are needed here; the other two
        # generator outputs are discarded.
        self.fake_img_64, _, _ = self.generator_1(
            self.embedding, noise, is_training=False)
        self.fake_img_256, _, _ = self.generator_2(
            self.fake_img_64, self.embedding, is_training=False)

        self.test_fake_img = self.fake_img_256
        self.test_real_img = self.real_img_256
        self.test_fake_img_64 = self.fake_img_64