# Common imports used by the listings below (each script in the repository also imports its own
# Encoder / Decoder / Generator / Discriminator / Classifier, Parameter, MNISTLoader and
# save_decode_image_array helpers, which are not reproduced here).
import os
import time
from functools import partial

import numpy as np
import tensorflow as tf
from sklearn import metrics


def train():
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()
    discriminator = Discriminator(param).model()

    # 2. Set optimizers
    opt_ae = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_ae, beta_1=0.5, beta_2=0.999, epsilon=0.01)
    opt_gen = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_gen, beta_1=0.5, beta_2=0.999, epsilon=0.01)
    opt_dis = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_dis, beta_1=0.5, beta_2=0.999, epsilon=0.01)

    # 3. Set trainable variables
    var_ae = encoder.trainable_variables + decoder.trainable_variables
    var_gen = encoder.trainable_variables
    var_dis = discriminator.trainable_variables

    # 4. Load data
    data_loader = MNISTLoader(one_hot=True)
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True).shuffle(
        buffer_size=data_loader.num_train, reshuffle_each_iteration=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 5. Define loss
    # 6. Etc.
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    check_point_prefix = os.path.join(check_point_dir, 'aae')
    enc_name = 'aae_enc'
    dec_name = 'aae_dec'
    dis_name = 'aae_dis'
    graph = 'aae'

    check_point = tf.train.Checkpoint(opt_gen=opt_gen, opt_dis=opt_dis, opt_ae=opt_ae,
                                      encoder=encoder, decoder=decoder, discriminator=discriminator)
    ckpt_manager = tf.train.CheckpointManager(check_point, check_point_dir, max_to_keep=5,
                                              checkpoint_name=check_point_prefix)

    # 7. Define train / validation step ###############################################################################
    def training_gen_step(_x, _y):
        with tf.GradientTape() as _gen_tape:
            _z_tilde = encoder(_x, training=True)
            _z_tilde_input = tf.concat([_y, _z_tilde], axis=-1, name='z_tilde_input')
            _dis_fake = discriminator(_z_tilde_input, training=True)
            _loss_gen = -tf.reduce_mean(_dis_fake)

        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        opt_gen.apply_gradients(zip(_grad_gen, var_gen))

    def training_step(_x, _y):
        with tf.GradientTape() as _ae_tape, tf.GradientTape() as _gen_tape, tf.GradientTape() as _dis_tape:
            _z = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0,
                                  stddev=param.prior_noise_std, dtype=tf.float32)
            _z_tilde = encoder(_x, training=True)

            _z_input = tf.concat([_y, _z], axis=-1, name='z_input')
            _z_tilde_input = tf.concat([_y, _z_tilde], axis=-1, name='z_tilde_input')

            _x_bar = decoder(_z_tilde_input, training=True)
            _dis_real = discriminator(_z_input, training=True)
            _dis_fake = discriminator(_z_tilde_input, training=True)

            _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=True), _z_input, _z_tilde_input)

            _loss_gen = -tf.reduce_mean(_dis_fake)
            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
            _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        _grad_ae = _ae_tape.gradient(_loss_ae, var_ae)
        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        _grad_dis = _dis_tape.gradient(_loss_dis, var_dis)

        opt_ae.apply_gradients(zip(_grad_ae, var_ae))
        opt_gen.apply_gradients(zip(_grad_gen, var_gen))
        opt_dis.apply_gradients(zip(_grad_dis, var_dis))

    def validation_step(_x, _y):
        _z = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0,
                              stddev=param.prior_noise_std, dtype=tf.float32)
        _z_tilde = encoder(_x, training=False)

        _z_input = tf.concat([_y, _z], axis=-1, name='z_input')
        _z_tilde_input = tf.concat([_y, _z_tilde], axis=-1, name='z_tilde_input')

        _x_bar = decoder(_z_tilde_input, training=False)
        _x_tilde = decoder(_z_input, training=False)
        _dis_real = discriminator(_z_input, training=False)
        _dis_fake = discriminator(_z_tilde_input, training=False)

        _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False), _z_input, _z_tilde_input)

        _loss_gen = -tf.reduce_mean(_dis_fake)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        return _x_tilde, _x_bar, _loss_dis.numpy(), _loss_gen.numpy(), _loss_ae.numpy(), (
            -_fake_loss.numpy() - _real_loss.numpy())
    ####################################################################################################################

    # 8. Train
    start_time = time.time()
    for epoch in range(0, param.max_epoch):
        # 8-1. Train AAE
        num_train = 0
        for x_train, y_train in train_set:
            if num_train % 2 == 0:
                training_step(x_train, tf.cast(y_train, dtype=tf.float32))
            else:
                training_gen_step(x_train, tf.cast(y_train, dtype=tf.float32))
                training_gen_step(x_train, tf.cast(y_train, dtype=tf.float32))
            num_train += 1

        # 8-2. Validation
        num_valid = 0
        val_loss_dis, val_loss_gen, val_loss_ae, val_was_x = [], [], [], []
        for x_valid, y_valid in test_set:
            x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x = validation_step(
                x_valid, tf.cast(y_valid, dtype=tf.float32))
            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            val_loss_ae.append(loss_ae)
            val_was_x.append(was_x)
            num_valid += 1
            if num_valid > param.valid_step:
                break

        # 8-3. Report in training
        elapsed_time = (time.time() - start_time) / 60.
        _val_loss_ae = np.mean(np.reshape(val_loss_ae, (-1)))
        _val_loss_dis = np.mean(np.reshape(val_loss_dis, (-1)))
        _val_loss_gen = np.mean(np.reshape(val_loss_gen, (-1)))
        _val_was_x = np.mean(np.reshape(val_was_x, (-1)))
        print("[Epoch: {:04d}] {:.01f}m.\tdis: {:.6f}\tgen: {:.6f}\tae: {:.6f}\tw_x: {:.6f}".format(
            epoch, elapsed_time, _val_loss_dis, _val_loss_gen, _val_loss_ae, _val_was_x))

        if epoch % param.save_frequency == 0 and epoch > 1:
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(graph_path, '{}_original-{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_bar.numpy(),
                                    path=os.path.join(graph_path, '{}_decode-{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_tilde.numpy(),
                                    path=os.path.join(graph_path, '{}_generated-{:04d}.png'.format(graph, epoch)))
            ckpt_manager.save(checkpoint_number=epoch)

    save_message = "\tSave model: End of training"
    encoder.save_weights(os.path.join(model_path, enc_name))
    decoder.save_weights(os.path.join(model_path, dec_name))
    discriminator.save_weights(os.path.join(model_path, dis_name))

    # 8-4. Report
    print("[Epoch: {:04d}] {:.01f} min.".format(param.max_epoch, elapsed_time))
    print(save_message)
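# The training and validation steps above call gradient_penalty(critic, real, fake), which is not
# defined in these listings. The following is a minimal sketch of a standard WGAN-GP penalty with
# that signature; the interpolation scheme and per-sample norm handling are assumptions, not the
# repository's exact implementation.
def gradient_penalty(f, real, fake):
    # Per-sample interpolation weights, broadcast over all non-batch axes.
    shape = [tf.shape(real)[0]] + [1] * (len(real.shape) - 1)
    alpha = tf.random.uniform(shape=shape, minval=0.0, maxval=1.0)
    inter = real + alpha * (fake - real)

    with tf.GradientTape() as tape:
        tape.watch(inter)
        pred = f(inter)
    grad = tape.gradient(pred, inter)

    # L2 norm of the gradient per sample, penalized toward 1.
    norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
    return tf.reduce_mean(tf.square(norm - 1.0))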
def train(w_gp=False):
    param = Parameter()

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()

    # 2. Set optimizers
    opt_gen = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_gen)
    opt_dis = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_dis)

    # 3. Set trainable variables
    var_gen = generator.trainable_variables
    var_dis = discriminator.trainable_variables

    # 4. Load data
    data_loader = MNISTLoader(one_hot=False)
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True).shuffle(
        buffer_size=data_loader.num_train, reshuffle_each_iteration=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 5. Define loss
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # 6. Etc.
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    if w_gp is True:
        check_point_prefix = os.path.join(check_point_dir, 'gan_w_gp')
        gen_name = 'gan_w_gp'
        dis_name = 'dis_w_gp'
        graph = 'gan_w_gp'
    else:
        check_point_prefix = os.path.join(check_point_dir, 'gan')
        gen_name = 'gan'
        dis_name = 'dis'
        graph = 'gan'

    check_point = tf.train.Checkpoint(opt_gen=opt_gen, opt_dis=opt_dis,
                                      generator=generator, discriminator=discriminator)
    ckpt_manager = tf.train.CheckpointManager(check_point, check_point_dir, max_to_keep=5,
                                              checkpoint_name=check_point_prefix)

    # 7. Define train / validation step ###############################################################################
    def training_step(_x):
        with tf.GradientTape() as _gen_tape, tf.GradientTape() as _dis_tape:
            _noise = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0,
                                      stddev=param.prior_noise_std, dtype=tf.float32)
            _x_tilde = generator(_noise, training=True)
            _dis_real = discriminator(_x, training=True)
            _dis_fake = discriminator(_x_tilde, training=True)

            if w_gp is True:
                _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
                _gp = gradient_penalty(partial(discriminator, training=True), _x, _x_tilde)
                _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
                _loss_gen = -tf.reduce_mean(_dis_fake)
            else:
                _loss_dis = cross_entropy(tf.ones_like(_dis_real), _dis_real) + cross_entropy(
                    tf.zeros_like(_dis_fake), _dis_fake)
                _loss_gen = cross_entropy(tf.ones_like(_dis_fake), _dis_fake)

        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        _grad_dis = _dis_tape.gradient(_loss_dis, var_dis)

        opt_gen.apply_gradients(zip(_grad_gen, var_gen))
        opt_dis.apply_gradients(zip(_grad_dis, var_dis))

    def validation_step(_x):
        _noise = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0,
                                  stddev=param.prior_noise_std, dtype=tf.float32)
        _x_tilde = generator(_noise, training=False)
        _dis_real = discriminator(_x, training=False)
        _dis_fake = discriminator(_x_tilde, training=False)

        if w_gp is True:
            _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=False), _x, _x_tilde)
            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
            _loss_gen = -tf.reduce_mean(_dis_fake)
        else:
            _loss_dis = cross_entropy(tf.ones_like(_dis_real), _dis_real) + cross_entropy(
                tf.zeros_like(_dis_fake), _dis_fake)
            _loss_gen = cross_entropy(tf.ones_like(_dis_fake), _dis_fake)

        return _x_tilde, _loss_dis.numpy(), _loss_gen.numpy()
    ####################################################################################################################

    # 8. Train
    start_time = time.time()
    for epoch in range(0, param.max_epoch):
        # 8-1. Train GANs
        for x_train, _ in train_set:
            training_step(x_train)

        # 8-2. Validation
        num_valid = 0
        val_loss_dis, val_loss_gen = [], []
        for x_valid, _ in test_set:
            if num_valid == param.valid_step:
                break
            x_tilde, loss_dis, loss_gen = validation_step(x_valid)
            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            num_valid += 1

        if epoch % param.save_frequency == 0 and epoch > 1:
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(graph_path, '{}_original-{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_tilde.numpy(),
                                    path=os.path.join(graph_path, '{}_generated-{:04d}.png'.format(graph, epoch)))
            ckpt_manager.save(checkpoint_number=epoch)

        # 8-3. Report in training
        elapsed_time = (time.time() - start_time) / 60.
        _val_loss_dis = np.mean(np.reshape(val_loss_dis, (-1)))
        _val_loss_gen = np.mean(np.reshape(val_loss_gen, (-1)))
        print("[Epoch: {:04d}] {:.01f} min.\t loss dis: {:.6f}\t loss gen: {:.6f}".format(
            epoch, elapsed_time, _val_loss_dis, _val_loss_gen))

    save_message = "\tSave model: End of training"
    generator.save_weights(os.path.join(model_path, gen_name))
    discriminator.save_weights(os.path.join(model_path, dis_name))

    # 8-4. Report
    print("[Epoch: {:04d}] {:.01f} min.".format(param.max_epoch, elapsed_time))
    print(save_message)
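# Every script above and below reads its hyper-parameters from a Parameter object that is not shown
# in these listings. The stand-in below lists exactly the fields the scripts access; every value is
# a placeholder assumption for MNIST-scale experiments, not the repository's actual configuration.
class Parameter:
    """Hypothetical hyper-parameter container with the fields the train/inference scripts read."""

    def __init__(self):
        self.cur_dir = os.getcwd()
        self.batch_size = 256
        self.max_epoch = 100
        self.valid_step = 10            # number of validation batches used per epoch
        self.save_frequency = 10        # epochs between image dumps / checkpoints
        self.num_early_stopping = 10    # patience for the auto-encoder trainer
        self.input_dim = (28, 28, 1)    # MNIST image shape
        self.latent_dim = 100
        self.num_class = 10
        self.prior_noise_std = 1.0      # std of the Gaussian prior / GAN input noise
        self.white_noise_std = 0.3      # std of the denoising-AE input noise
        self.w_gp_lambda = 10.0         # WGAN-GP penalty weight
        self.learning_rate_ae = 1e-4
        self.learning_rate_gen = 1e-4
        self.learning_rate_dis = 1e-4
        self.learning_rate_cla = 1e-4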
def inference():
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()
    discriminator = Discriminator(param).model()

    # 2. Load data
    data_loader = MNISTLoader(one_hot=True)
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 3. Etc.
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    # 4. Load model
    enc_name = 'aae_enc'
    dec_name = 'aae_dec'
    dis_name = 'aae_dis'
    graph = 'aae'
    encoder.load_weights(os.path.join(model_path, enc_name))
    discriminator.load_weights(os.path.join(model_path, dis_name))
    decoder.load_weights(os.path.join(model_path, dec_name))

    # 5. Define loss
    # 6. Define testing step ##########################################################################################
    def testing_step(_x, _y):
        _z = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0,
                              stddev=param.prior_noise_std, dtype=tf.float32)
        _z_tilde = encoder(_x, training=False)

        _z_input = tf.concat([_y, _z], axis=-1, name='z_input')
        _z_tilde_input = tf.concat([_y, _z_tilde], axis=-1, name='z_tilde_input')

        _x_bar = decoder(_z_tilde_input, training=False)
        _x_tilde = decoder(_z_input, training=False)
        _dis_real = discriminator(_z_input, training=False)
        _dis_fake = discriminator(_z_tilde_input, training=False)

        _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False), _z_input, _z_tilde_input)

        _loss_gen = -tf.reduce_mean(_dis_fake)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        return _x_tilde, _x_bar, _loss_dis.numpy(), _loss_gen.numpy(), _loss_ae.numpy(), (
            -_fake_loss.numpy() - _real_loss.numpy())
    ####################################################################################################################

    # 7. Inference
    train_loss_dis, train_loss_gen, train_loss_ae, train_was_x = [], [], [], []
    for x_train, y_train in train_set:
        y = tf.cast(y_train, dtype=tf.float32)
        x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x = testing_step(x_train, y)
        train_loss_dis.append(loss_dis)
        train_loss_gen.append(loss_gen)
        train_loss_ae.append(loss_ae)
        train_was_x.append(was_x)

    num_test = 0
    val_loss_dis, val_loss_gen, val_loss_ae, val_was_x = [], [], [], []
    test_loss_dis, test_loss_gen, test_loss_ae, test_was_x = [], [], [], []
    for x_test, y_test in test_set:
        y = tf.cast(y_test, dtype=tf.float32)
        x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x = testing_step(x_test, y)
        if num_test <= param.valid_step:
            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            val_loss_ae.append(loss_ae)
            val_was_x.append(was_x)
        else:
            test_loss_dis.append(loss_dis)
            test_loss_gen.append(loss_gen)
            test_loss_ae.append(loss_ae)
            test_was_x.append(was_x)
        num_test += 1

    # Class-conditional samples: one image grid per digit class.
    _z = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0,
                          stddev=param.prior_noise_std, dtype=tf.float32)
    for class_idx in range(0, param.num_class):
        _indices = np.ones(param.batch_size, dtype=np.int32) * class_idx  # tf.one_hot expects integer indices
        _y = tf.one_hot(indices=_indices, depth=param.num_class, dtype=tf.float32)
        _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')
        _x_tilde = decoder(_gen_input, training=False)
        save_decode_image_array(_x_tilde.numpy(),
                                path=os.path.join(graph_path, '{}_c{}_generated.png'.format(graph, class_idx)))

    # 8. Report
    train_loss_dis = np.mean(np.reshape(train_loss_dis, (-1)))
    train_loss_gen = np.mean(np.reshape(train_loss_gen, (-1)))
    train_loss_ae = np.mean(np.reshape(train_loss_ae, (-1)))
    train_was_x = np.mean(np.reshape(train_was_x, (-1)))

    val_loss_dis = np.mean(np.reshape(val_loss_dis, (-1)))
    val_loss_gen = np.mean(np.reshape(val_loss_gen, (-1)))
    val_loss_ae = np.mean(np.reshape(val_loss_ae, (-1)))
    val_was_x = np.mean(np.reshape(val_was_x, (-1)))

    test_loss_dis = np.mean(np.reshape(test_loss_dis, (-1)))
    test_loss_gen = np.mean(np.reshape(test_loss_gen, (-1)))
    test_loss_ae = np.mean(np.reshape(test_loss_ae, (-1)))
    test_was_x = np.mean(np.reshape(test_was_x, (-1)))

    print("[Loss dis] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_loss_dis, val_loss_dis, test_loss_dis))
    print("[Loss gen] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_loss_gen, val_loss_gen, test_loss_gen))
    print("[Loss ae] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_loss_ae, val_loss_ae, test_loss_ae))
    print("[Was X] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_was_x, val_was_x, test_was_x))

    # 9. Draw some samples
    save_decode_image_array(x_test.numpy(), path=os.path.join(graph_path, '{}_original.png'.format(graph)))
    save_decode_image_array(x_bar.numpy(), path=os.path.join(graph_path, '{}_decoded.png'.format(graph)))
    save_decode_image_array(x_tilde.numpy(), path=os.path.join(graph_path, '{}_generated.png'.format(graph)))
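# The data pipeline in every listing comes from an MNISTLoader class that only appears through its
# train, test and num_train members and a one_hot flag. A minimal tf.data-based sketch with that
# interface could look like the following; normalization, dtypes and the [0, 1] scaling are
# assumptions, not the repository's implementation.
class MNISTLoader:
    """Hypothetical loader exposing .train / .test tf.data.Dataset objects and .num_train."""

    def __init__(self, one_hot=False, num_class=10):
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        # Scale images to [0, 1] and add a channel axis: (N, 28, 28, 1).
        x_train = x_train.astype(np.float32)[..., np.newaxis] / 255.0
        x_test = x_test.astype(np.float32)[..., np.newaxis] / 255.0
        if one_hot:
            y_train = tf.one_hot(y_train, depth=num_class)
            y_test = tf.one_hot(y_test, depth=num_class)
        self.num_train = x_train.shape[0]
        self.train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        self.test = tf.data.Dataset.from_tensor_slices((x_test, y_test))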
def inference(denoise=True):
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()

    # 2. Load data
    data_loader = MNISTLoader()
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 3. Etc.
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    # 4. Load model
    if denoise is True:
        graph = 'denoising_ae'
        enc_name = 'encoder_denoise'
        dec_name = 'decoder_denoise'
    else:
        graph = 'ae'
        enc_name = 'encoder'
        dec_name = 'decoder'
    encoder.load_weights(os.path.join(model_path, enc_name))
    decoder.load_weights(os.path.join(model_path, dec_name))

    # 5. Define loss
    # 6. Define test step #############################################################################################
    def testing_step(_x):
        if denoise is True:
            _noise = tf.random.normal(shape=(param.batch_size,) + param.input_dim, mean=0.0,
                                      stddev=param.white_noise_std, dtype=tf.float32)
            _x_test = _x + _noise
        else:
            _x_test = _x

        _z_tilde = encoder(_x_test, training=False)
        _x_bar = decoder(_z_tilde, training=False)
        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))  # pixel-wise L1 loss

        return _x_test, _x_bar, _loss_ae.numpy()
    ####################################################################################################################

    # 7. Inference
    train_mse = []
    for x_train, _ in train_set:
        x_noise, x_bar, loss_ae = testing_step(x_train)
        train_mse.append(loss_ae)

    num_test = 0
    valid_mse = []
    test_mse = []
    for x_test, _ in test_set:
        x_noise, x_bar, loss_ae = testing_step(x_test)
        if num_test <= param.valid_step:
            valid_mse.append(loss_ae)
        else:
            test_mse.append(loss_ae)
        num_test += 1

    # 8. Report
    print("[Loss] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        np.mean(train_mse), np.mean(valid_mse), np.mean(test_mse)))

    # 9. Draw some samples
    if denoise is True:
        save_decode_image_array(x_noise.numpy(), path=os.path.join(graph_path, '{}_noise.png'.format(graph)))
    save_decode_image_array(x_test.numpy(), path=os.path.join(graph_path, '{}_original.png'.format(graph)))
    save_decode_image_array(x_bar.numpy(), path=os.path.join(graph_path, '{}_decoded.png'.format(graph)))
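# save_decode_image_array is used throughout these scripts but never defined in the listings. A
# simple matplotlib sketch that writes a square grid of the first images in a batch is shown below;
# the grid size and styling are assumptions.
import matplotlib.pyplot as plt


def save_decode_image_array(x, path, rows=5, cols=5):
    """Hypothetical helper: save a rows x cols grid of grayscale images from a batch array."""
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx < x.shape[0]:
            ax.imshow(np.squeeze(x[idx]), cmap='gray')
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)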
def train():
    param = Parameter()

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()
    classifier = Classifier(param).model()

    # 2. Set optimizers
    opt_gen = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_gen, beta_1=0.5, beta_2=0.999)
    opt_dis = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_dis, beta_1=0.5, beta_2=0.999)
    opt_cla = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_cla, beta_1=0.5, beta_2=0.999)

    # 3. Set trainable variables
    var_gen = generator.trainable_variables
    var_dis = discriminator.trainable_variables
    var_cla = generator.trainable_variables + classifier.trainable_variables + discriminator.trainable_variables[:-2]

    # 4. Load data
    data_loader = MNISTLoader(one_hot=True)
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True).shuffle(
        buffer_size=data_loader.num_train, reshuffle_each_iteration=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 5. Define loss
    # 6. Etc.
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    check_point_prefix = os.path.join(check_point_dir, 'acgans')
    gen_name = 'acgans_gen'
    dis_name = 'acgans_dis'
    cla_name = 'acgans_cla'
    graph = 'acgans'

    check_point = tf.train.Checkpoint(opt_gen=opt_gen, opt_dis=opt_dis, generator=generator,
                                      discriminator=discriminator, classifier=classifier)
    ckpt_manager = tf.train.CheckpointManager(check_point, check_point_dir, max_to_keep=5,
                                              checkpoint_name=check_point_prefix)

    # 7. Define train / validation step ###############################################################################
    def training_step(_x, _y):
        with tf.GradientTape() as _gen_tape, tf.GradientTape() as _dis_tape, tf.GradientTape() as _cla_tape:
            _z = tf.random.uniform(shape=(param.batch_size, param.latent_dim), minval=-1.0, maxval=1.0,
                                   dtype=tf.float32)
            _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')

            _x_tilde = generator(_gen_input, training=True)
            _cla_real_logits, _dis_real = discriminator(_x, training=True)
            _cla_fake_logits, _dis_fake = discriminator(_x_tilde, training=True)
            _cla_real = classifier(_cla_real_logits, training=True)
            _cla_fake = classifier(_cla_fake_logits, training=True)

            _loss_cla_real = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.argmax(_y, axis=1), logits=_cla_real))
            _loss_cla_fake = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.argmax(_y, axis=1), logits=_cla_fake))

            _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=True), _x, _x_tilde)

            _loss_cla = _loss_cla_real + _loss_cla_fake
            _loss_gen = -tf.reduce_mean(_dis_fake)
            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda

        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        _grad_dis = _dis_tape.gradient(_loss_dis, var_dis)
        _grad_cla = _cla_tape.gradient(_loss_cla, var_cla)

        opt_dis.apply_gradients(zip(_grad_dis, var_dis))
        opt_gen.apply_gradients(zip(_grad_gen, var_gen))
        opt_cla.apply_gradients(zip(_grad_cla, var_cla))

    def validation_step(_x, _y):
        _z = tf.random.uniform(shape=(param.batch_size, param.latent_dim), minval=-1.0, maxval=1.0, dtype=tf.float32)
        _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')

        _x_tilde = generator(_gen_input, training=False)
        _cla_real_logits, _dis_real = discriminator(_x, training=False)
        _cla_fake_logits, _dis_fake = discriminator(_x_tilde, training=False)
        _cla_real = classifier(_cla_real_logits, training=False)
        _cla_fake = classifier(_cla_fake_logits, training=False)

        _loss_gen = -tf.reduce_mean(_dis_fake)
        _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False), _x, _x_tilde)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda

        _loss_cla_real = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(_y, axis=1), logits=_cla_real))
        _loss_cla_fake = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(_y, axis=1), logits=_cla_fake))

        # Return the dis loss before the gen loss so the order matches the unpacking at the call site.
        return (_x_tilde, _loss_dis.numpy(), _loss_gen.numpy(), _loss_cla_real.numpy(),
                _loss_cla_fake.numpy(), -_fake_loss.numpy() - _real_loss.numpy())
    ####################################################################################################################

    # 8. Train
    start_time = time.time()
    for epoch in range(0, param.max_epoch):
        # 8-1. Train ACGANs
        num_train = 0
        for x_train, y_train in train_set:
            training_step(x_train, tf.cast(y_train, dtype=tf.float32))
            num_train += 1

        # 8-2. Validation
        num_valid = 0
        val_loss_dis, val_loss_gen, val_loss_cla_real, val_loss_cla_fake, val_was_x = [], [], [], [], []
        for x_valid, y_valid in test_set:
            x_tilde, loss_dis, loss_gen, loss_cla_real, loss_cla_fake, was_x = validation_step(
                x_valid, tf.cast(y_valid, dtype=tf.float32))
            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            val_loss_cla_real.append(loss_cla_real)
            val_loss_cla_fake.append(loss_cla_fake)
            val_was_x.append(was_x)
            num_valid += 1
            if num_valid > param.valid_step:
                break

        # 8-3. Report in training
        elapsed_time = (time.time() - start_time) / 60.
        _val_loss_cla_real = np.mean(np.reshape(val_loss_cla_real, (-1)))
        _val_loss_cla_fake = np.mean(np.reshape(val_loss_cla_fake, (-1)))
        _val_loss_dis = np.mean(np.reshape(val_loss_dis, (-1)))
        _val_loss_gen = np.mean(np.reshape(val_loss_gen, (-1)))
        _val_was_x = np.mean(np.reshape(val_was_x, (-1)))
        print("[{:04d}] {:.01f} m.\tdis: {:.6f}\tgen: {:.6f}\tcla_r: {:.6f}\tcla_f: {:.6f}\t w_x: {:.4f}".format(
            epoch, elapsed_time, _val_loss_dis, _val_loss_gen, _val_loss_cla_real, _val_loss_cla_fake, _val_was_x))

        if epoch % param.save_frequency == 0 and epoch > 1:
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(graph_path, '{}_original-{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_tilde.numpy(),
                                    path=os.path.join(graph_path, '{}_generated-{:04d}.png'.format(graph, epoch)))
            ckpt_manager.save(checkpoint_number=epoch)

    save_message = "\tSave model: End of training"
    generator.save_weights(os.path.join(model_path, gen_name))
    discriminator.save_weights(os.path.join(model_path, dis_name))
    classifier.save_weights(os.path.join(model_path, cla_name))

    print("[Epoch: {:04d}] {:.01f} min.".format(param.max_epoch, elapsed_time))
    print(save_message)
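# The ACGAN steps above unpack two outputs from the discriminator (classifier features and a critic
# score), and var_cla drops the last two discriminator variables, i.e. the weights and bias of the
# critic head. The actual Discriminator/Classifier models are not shown in these listings; the Keras
# sketch below, with assumed layer sizes, is only meant to illustrate that two-headed layout.
def build_acgan_discriminator_sketch(input_dim=(28, 28, 1), feature_dim=128):
    """Hypothetical two-headed critic: shared trunk -> (features for the classifier, scalar critic score)."""
    x = tf.keras.Input(shape=input_dim)
    h = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(x)
    h = tf.keras.layers.LeakyReLU(0.2)(h)
    h = tf.keras.layers.Conv2D(128, 4, strides=2, padding='same')(h)
    h = tf.keras.layers.LeakyReLU(0.2)(h)
    h = tf.keras.layers.Flatten()(h)
    features = tf.keras.layers.Dense(feature_dim)(h)   # fed to the external Classifier model
    critic = tf.keras.layers.Dense(1)(features)        # the last Dense layer = the 2 variables dropped from var_cla
    return tf.keras.Model(inputs=x, outputs=[features, critic])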
def train(denoise=False):
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()

    # 2. Set optimizers
    opt_ae = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_ae)

    # 3. Set trainable variables
    var_ae = encoder.trainable_variables + decoder.trainable_variables

    # 4. Load data
    data_loader = MNISTLoader()
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True).shuffle(
        buffer_size=data_loader.num_train, reshuffle_each_iteration=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 5. Define loss
    # 6. Etc.
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)

    minimum_mse = 100.
    num_effective_epoch = 0

    if denoise is True:
        check_point_prefix = os.path.join(check_point_dir, 'ae_denoise')
        graph = 'ae_denoise'
        enc_name = 'encoder_denoise'
        dec_name = 'decoder_denoise'
    else:
        check_point_prefix = os.path.join(check_point_dir, 'ae')
        graph = 'ae'
        enc_name = 'encoder'
        dec_name = 'decoder'

    check_point = tf.train.Checkpoint(opt_ae=opt_ae, encoder=encoder, decoder=decoder)
    ckpt_manager = tf.train.CheckpointManager(check_point, check_point_dir, max_to_keep=5,
                                              checkpoint_name=check_point_prefix)

    # 7. Define train / validation step ###############################################################################
    def training_step(_x):
        if denoise is True:
            _noise = tf.random.normal(shape=(param.batch_size,) + param.input_dim, mean=0.0,
                                      stddev=param.white_noise_std, dtype=tf.float32)
            _x_train = _x + _noise
        else:
            _x_train = _x

        with tf.GradientTape() as _ae_tape:
            _z_tilde = encoder(_x_train, training=True)
            _x_bar = decoder(_z_tilde, training=True)
            _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))  # pixel-wise loss

        _grad_ae = _ae_tape.gradient(_loss_ae, var_ae)
        opt_ae.apply_gradients(zip(_grad_ae, var_ae))

    def validation_step(_x):
        if denoise is True:
            _noise = tf.random.normal(shape=(param.batch_size,) + param.input_dim, mean=0.0,
                                      stddev=param.white_noise_std, dtype=tf.float32)
            _x_valid = _x + _noise
        else:
            _x_valid = _x

        _z_tilde = encoder(_x_valid, training=False)
        _x_bar = decoder(_z_tilde, training=False)
        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))  # pixel-wise loss

        return _x_valid, _x_bar, _loss_ae.numpy()
    ####################################################################################################################

    # 8. Train
    start_time = time.time()
    for epoch in range(0, param.max_epoch):
        # 8-1. Train auto encoder
        for x_train, _ in train_set:
            training_step(x_train)

        # 8-2. Validation
        num_valid = 0
        mse_valid = []
        for x_valid, _ in test_set:
            x_noise, x_bar, loss_ae = validation_step(x_valid)
            mse_valid.append(loss_ae)
            if num_valid == param.valid_step:
                break
            num_valid += 1
        valid_loss = np.mean(mse_valid)

        save_message = ''
        if minimum_mse > valid_loss:
            num_effective_epoch = 0
            minimum_mse = valid_loss
            save_message = "\tSave model: detecting lowest L1: {:.6f} at epoch {:04d}".format(minimum_mse, epoch)
            encoder.save_weights(os.path.join(model_path, enc_name))
            decoder.save_weights(os.path.join(model_path, dec_name))

        elapsed_time = (time.time() - start_time) / 60.

        # 8-3. Report
        print("[Epoch: {:04d}] {:.01f} min. ae loss: {:.6f} Effective: {}".format(
            epoch, elapsed_time, valid_loss, (num_effective_epoch == 0)))
        print("{}".format(save_message))

        if epoch % param.save_frequency == 0 and epoch > 1:
            if denoise is True:
                save_decode_image_array(x_noise.numpy(),
                                        path=os.path.join(graph_path, '{}_noise_{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(graph_path, '{}_original_{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_bar.numpy(),
                                    path=os.path.join(graph_path, '{}_decoded_{:04d}.png'.format(graph, epoch)))
            ckpt_manager.save(checkpoint_number=epoch)

        if num_effective_epoch >= param.num_early_stopping:
            print("\t Early stopping at epoch {:04d}!".format(epoch))
            break

        num_effective_epoch += 1

    print("\t Stopping at epoch {:04d}!".format(epoch))
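# For completeness, a typical entry point for this auto-encoder module might look as follows,
# assuming the train() and inference() functions above live in the same file; whether the
# repository actually exposes such a __main__ block is an assumption.
if __name__ == '__main__':
    # Train and evaluate the plain auto-encoder, then the denoising variant.
    train(denoise=False)
    inference(denoise=False)
    train(denoise=True)
    inference(denoise=True)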
def inference(w_gp=False):
    param = Parameter()

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()

    # 2. Load data
    data_loader = MNISTLoader(one_hot=False)
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 3. Etc.
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    # 4. Load model
    if w_gp is True:
        gen_name = 'gan_w_gp'
        dis_name = 'dis_w_gp'
        graph = 'gan_w_gp'
    else:
        gen_name = 'gan'
        dis_name = 'dis'
        graph = 'gan'
    generator.load_weights(os.path.join(model_path, gen_name))
    discriminator.load_weights(os.path.join(model_path, dis_name))

    # 5. Define loss
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # 6. Inference
    train_dis_loss, train_gen_loss = [], []
    for x_train, _ in train_set:
        # noise = tf.random.uniform(shape=(param.batch_size, param.latent_dim), minval=-1, maxval=1, dtype=tf.float32)
        # NOTE: the sampling std is hard-coded here; training samples the noise with param.prior_noise_std.
        noise = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0, stddev=0.3, dtype=tf.float32)
        x_tilde = generator(noise, training=False)
        dis_real = discriminator(x_train, training=False)
        dis_fake = discriminator(x_tilde, training=False)

        if w_gp is False:
            loss_dis = cross_entropy(tf.ones_like(dis_real), dis_real) + cross_entropy(
                tf.zeros_like(dis_fake), dis_fake)
            loss_gen = cross_entropy(tf.ones_like(dis_fake), dis_fake)
        else:
            real_loss, fake_loss = -tf.reduce_mean(dis_real), tf.reduce_mean(dis_fake)
            gp = gradient_penalty(partial(discriminator, training=False), x_train, x_tilde)
            loss_dis = (real_loss + fake_loss) + gp * param.w_gp_lambda
            loss_gen = -tf.reduce_mean(dis_fake)

        train_dis_loss.append(loss_dis.numpy())
        train_gen_loss.append(loss_gen.numpy())

    num_test = 0
    valid_dis_loss, valid_gen_loss = [], []
    test_dis_loss, test_gen_loss = [], []
    for x_test, _ in test_set:
        # noise = tf.random.uniform(shape=(param.batch_size, param.latent_dim), minval=-1, maxval=1, dtype=tf.float32)
        noise = tf.random.normal(shape=(param.batch_size, param.latent_dim), mean=0.0, stddev=0.3, dtype=tf.float32)
        x_tilde = generator(noise, training=False)
        dis_real = discriminator(x_test, training=False)
        dis_fake = discriminator(x_tilde, training=False)

        if w_gp is False:
            loss_dis = cross_entropy(tf.ones_like(dis_real), dis_real) + cross_entropy(
                tf.zeros_like(dis_fake), dis_fake)
            loss_gen = cross_entropy(tf.ones_like(dis_fake), dis_fake)
        else:
            real_loss, fake_loss = -tf.reduce_mean(dis_real), tf.reduce_mean(dis_fake)
            gp = gradient_penalty(partial(discriminator, training=False), x_test, x_tilde)
            loss_dis = (real_loss + fake_loss) + gp * param.w_gp_lambda
            loss_gen = -tf.reduce_mean(dis_fake)

        if num_test <= param.valid_step:
            valid_dis_loss.append(loss_dis.numpy())
            valid_gen_loss.append(loss_gen.numpy())
        else:
            test_dis_loss.append(loss_dis.numpy())
            test_gen_loss.append(loss_gen.numpy())
        num_test += 1

    # 7. Report
    train_dis_loss = np.mean(np.reshape(train_dis_loss, (-1)))
    valid_dis_loss = np.mean(np.reshape(valid_dis_loss, (-1)))
    test_dis_loss = np.mean(np.reshape(test_dis_loss, (-1)))
    train_gen_loss = np.mean(np.reshape(train_gen_loss, (-1)))
    valid_gen_loss = np.mean(np.reshape(valid_gen_loss, (-1)))
    test_gen_loss = np.mean(np.reshape(test_gen_loss, (-1)))
    print("[Loss dis] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_dis_loss, valid_dis_loss, test_dis_loss))
    print("[Loss gen] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_gen_loss, valid_gen_loss, test_gen_loss))

    # 8. Draw some samples
    save_decode_image_array(x_test.numpy(), path=os.path.join(graph_path, '{}_original.png'.format(graph)))
    save_decode_image_array(x_tilde.numpy(), path=os.path.join(graph_path, '{}_generated.png'.format(graph)))
def inference():
    param = Parameter()

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()
    classifier = Classifier(param).model()

    # 2. Load data
    data_loader = MNISTLoader(one_hot=True)
    train_set = data_loader.train.batch(batch_size=param.batch_size, drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 3. Etc.
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)
    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    # 4. Load model
    gen_name = 'acgans_gen'
    dis_name = 'acgans_dis'
    cla_name = 'acgans_cla'
    graph = 'acgans'
    generator.load_weights(os.path.join(model_path, gen_name))
    discriminator.load_weights(os.path.join(model_path, dis_name))
    classifier.load_weights(os.path.join(model_path, cla_name))

    # 5. Define loss
    # 6. Define testing step ##########################################################################################
    def testing_step(_x, _y):
        _z = tf.random.uniform(shape=(param.batch_size, param.latent_dim), minval=-1.0, maxval=1.0, dtype=tf.float32)
        _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')

        _x_tilde = generator(_gen_input, training=False)
        _cla_real_logits, _dis_real = discriminator(_x, training=False)
        _cla_fake_logits, _dis_fake = discriminator(_x_tilde, training=False)
        _cla_real = classifier(_cla_real_logits, training=False)
        _cla_fake = classifier(_cla_fake_logits, training=False)

        _loss_gen = -tf.reduce_mean(_dis_fake)
        _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False), _x, _x_tilde)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda

        _loss_cla_real = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(_y, axis=1), logits=_cla_real))
        _loss_cla_fake = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(_y, axis=1), logits=_cla_fake))
        _loss_cla = _loss_cla_real + _loss_cla_fake

        # Return the dis loss before the gen loss so the order matches the unpacking at the call sites.
        return (_x_tilde, _cla_real, _loss_dis.numpy(), _loss_gen.numpy(), _loss_cla_real.numpy(),
                _loss_cla_fake.numpy(), -_fake_loss.numpy() - _real_loss.numpy())
    ####################################################################################################################

    # 7. Inference
    train_real_label, train_prediction = [], []
    train_loss_dis, train_loss_gen, train_loss_cla_real, train_loss_cla_fake, train_was_x = [], [], [], [], []
    for x_train, y_train in train_set:
        y = tf.cast(y_train, dtype=tf.float32)
        x_tilde, prediction, loss_dis, loss_gen, loss_cla_real, loss_cla_fake, was_x = testing_step(x_train, y)
        _pred_y = tf.argmax(prediction, axis=1)

        train_loss_dis.append(loss_dis)
        train_loss_gen.append(loss_gen)
        train_loss_cla_real.append(loss_cla_real)
        train_loss_cla_fake.append(loss_cla_fake)
        train_was_x.append(was_x)
        train_real_label.append(tf.argmax(y_train, axis=1).numpy())
        train_prediction.append(_pred_y)

    num_test = 0
    valid_real_label, valid_prediction = [], []
    test_real_label, test_prediction = [], []
    val_loss_dis, val_loss_gen, val_loss_cla_real, val_loss_cla_fake, val_was_x = [], [], [], [], []
    test_loss_dis, test_loss_gen, test_loss_cla_real, test_loss_cla_fake, test_was_x = [], [], [], [], []
    for x_test, y_test in test_set:
        y = tf.cast(y_test, dtype=tf.float32)
        x_tilde, prediction, loss_dis, loss_gen, loss_cla_real, loss_cla_fake, was_x = testing_step(x_test, y)
        _pred_y = tf.argmax(prediction, axis=1)

        if num_test <= param.valid_step:
            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            val_loss_cla_real.append(loss_cla_real)
            val_loss_cla_fake.append(loss_cla_fake)
            val_was_x.append(was_x)
            valid_real_label.append(tf.argmax(y_test, axis=1).numpy())
            valid_prediction.append(_pred_y)
        else:
            test_loss_dis.append(loss_dis)
            test_loss_gen.append(loss_gen)
            test_loss_cla_real.append(loss_cla_real)
            test_loss_cla_fake.append(loss_cla_fake)
            test_was_x.append(was_x)
            test_real_label.append(tf.argmax(y_test, axis=1).numpy())
            test_prediction.append(_pred_y)
        num_test += 1

    # Class-conditional samples: one image grid per digit class.
    for class_idx in range(0, param.num_class):
        _indices = np.ones(param.batch_size, dtype=np.int32) * class_idx  # tf.one_hot expects integer indices
        _y = tf.one_hot(indices=_indices, depth=param.num_class, dtype=tf.float32)
        _z = tf.random.uniform(shape=(param.batch_size, param.latent_dim), minval=-1.0, maxval=1.0, dtype=tf.float32)
        _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')
        _x_tilde = generator(_gen_input, training=False)
        save_decode_image_array(_x_tilde.numpy(),
                                path=os.path.join(graph_path, '{}_c{}_generated.png'.format(graph, class_idx)))

    # 8. Report
    train_loss_dis = np.mean(np.reshape(train_loss_dis, (-1)))
    train_loss_gen = np.mean(np.reshape(train_loss_gen, (-1)))
    train_loss_cla_real = np.mean(np.reshape(train_loss_cla_real, (-1)))
    train_loss_cla_fake = np.mean(np.reshape(train_loss_cla_fake, (-1)))
    train_was_x = np.mean(np.reshape(train_was_x, (-1)))

    val_loss_dis = np.mean(np.reshape(val_loss_dis, (-1)))
    val_loss_gen = np.mean(np.reshape(val_loss_gen, (-1)))
    val_loss_cla_real = np.mean(np.reshape(val_loss_cla_real, (-1)))
    val_loss_cla_fake = np.mean(np.reshape(val_loss_cla_fake, (-1)))
    val_was_x = np.mean(np.reshape(val_was_x, (-1)))

    test_loss_dis = np.mean(np.reshape(test_loss_dis, (-1)))
    test_loss_gen = np.mean(np.reshape(test_loss_gen, (-1)))
    test_loss_cla_real = np.mean(np.reshape(test_loss_cla_real, (-1)))
    test_loss_cla_fake = np.mean(np.reshape(test_loss_cla_fake, (-1)))
    test_was_x = np.mean(np.reshape(test_was_x, (-1)))

    train_real_label, train_prediction = np.reshape(train_real_label, (-1)), np.reshape(train_prediction, (-1))
    valid_real_label, valid_prediction = np.reshape(valid_real_label, (-1)), np.reshape(valid_prediction, (-1))
    test_real_label, test_prediction = np.reshape(test_real_label, (-1)), np.reshape(test_prediction, (-1))

    train_acc = metrics.accuracy_score(train_real_label, train_prediction)
    valid_acc = metrics.accuracy_score(valid_real_label, valid_prediction)
    test_acc = metrics.accuracy_score(test_real_label, test_prediction)

    print("[Loss cla fake] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_loss_cla_fake, val_loss_cla_fake, test_loss_cla_fake))
    print("[Loss cla real] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_loss_cla_real, val_loss_cla_real, test_loss_cla_real))
    print("[Loss dis] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_loss_dis, val_loss_dis, test_loss_dis))
    print("[Loss gen] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_loss_gen, val_loss_gen, test_loss_gen))
    print("[Was X] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        train_was_x, val_was_x, test_was_x))
    print("[Accuracy] Train: {:.05f}\t Validation: {:.05f}\t Test: {:.05f}".format(
        train_acc, valid_acc, test_acc))

    # 9. Draw some samples
    save_decode_image_array(x_test.numpy(), path=os.path.join(graph_path, '{}_original.png'.format(graph)))
    save_decode_image_array(x_tilde.numpy(), path=os.path.join(graph_path, '{}_generated.png'.format(graph)))
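# Since the report above already uses sklearn's metrics.accuracy_score on the auxiliary-classifier
# predictions, a per-class breakdown is a natural follow-up. The helper below is a hypothetical
# addition (not part of the original script) that prints a confusion matrix and per-class
# precision/recall from the same label/prediction arrays.
def report_classifier_breakdown(real_label, prediction, num_class=10):
    """Hypothetical follow-up report: confusion matrix and per-class metrics for the auxiliary classifier."""
    labels = np.arange(num_class)
    print(metrics.confusion_matrix(real_label, prediction, labels=labels))
    print(metrics.classification_report(real_label, prediction, labels=labels, digits=4))


# Example usage inside inference(), after the accuracy report:
#     report_classifier_breakdown(test_real_label, test_prediction, param.num_class)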