def main():
    # set flags
    FLAGS = set_flags()

    print('obs path :', FLAGS.data_url)
    mox.file.copy_parallel(FLAGS.data_url, FLAGS.native_data)
    print("mox copy files finished")

    # set logger
    logflag = set_logger(FLAGS)
    log(logflag, 'Training script start', 'info')

    # load data
    log(logflag, 'Data process : Data loading start', 'info')
    HR_train, LR_train = load_npz_data(FLAGS)
    log(logflag,
        'Data process : Loading existing data is completed. {} images are loaded'.format(len(HR_train)),
        'info')

    # pre-train generator with pixel-wise loss and save the trained model
    if FLAGS.pretrain_generator:
        train_pretrain_generator(FLAGS, LR_train, HR_train, logflag)
        tf.reset_default_graph()
        gc.collect()
        return
    else:
        log(logflag,
            'Pre-train : Pre-train is skipped and an existing trained model will be used',
            'info')

    LR_data = tf.placeholder(tf.float32,
                             shape=[None, FLAGS.LR_image_size, FLAGS.LR_image_size, FLAGS.channel],
                             name='LR_input')
    HR_data = tf.placeholder(tf.float32,
                             shape=[None, FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel],
                             name='HR_input')

    # build Generator and Discriminator
    network = Network(FLAGS, LR_data, HR_data)
    gen_out = network.generator()
    dis_out_real, dis_out_fake = network.discriminator(gen_out)

    # build loss function
    loss = Loss()
    gen_loss, dis_loss = loss.gan_loss(FLAGS, HR_data, gen_out, dis_out_real, dis_out_fake)

    # define optimizers
    global_iter = tf.Variable(0, trainable=False)
    dis_var, dis_optimizer, gen_var, gen_optimizer = Optimizer().gan_optimizer(
        FLAGS, global_iter, dis_loss, gen_loss)

    # build summary writer
    tr_summary = tf.summary.merge(loss.add_summary_writer())

    num_train_data = len(LR_train)
    num_batch_in_train = int(math.floor(num_train_data / FLAGS.batch_size))
    num_epoch = int(math.ceil(FLAGS.num_iter / num_batch_in_train))

    HR_train, LR_train = normalize_images(HR_train, LR_train)

    fetches = {
        'dis_optimizer': dis_optimizer,
        'gen_optimizer': gen_optimizer,
        'dis_loss': dis_loss,
        'gen_loss': gen_loss,
        'gen_HR': gen_out,
        'summary': tr_summary
    }

    gc.collect()

    # configure the session for the Ascend NPU backend
    config = tf.ConfigProto()
    custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
    custom_op.name = "NpuOptimizer"
    custom_op.parameter_map["use_off_line"].b = True
    config.graph_options.rewrite_options.remapping = RewriterConfig.OFF

    # Start Session
    with tf.Session(config=config) as sess:
        log(logflag, 'Training ESRGAN starts', 'info')

        sess.run(tf.global_variables_initializer())
        sess.run(global_iter.initializer)

        writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph)

        pre_saver = tf.train.Saver(var_list=tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'))
        pre_saver.restore(sess, tf.train.latest_checkpoint(FLAGS.host_pre_train))

        if FLAGS.perceptual_loss == 'VGG19':
            sess.run(load_vgg19_weight(FLAGS))
            log(logflag, 'Pre-trained VGG19 weights are loaded', 'info')

        saver = tf.train.Saver(max_to_keep=10)

        for epoch in range(num_epoch):
            log(logflag, 'ESRGAN Epoch: {0}'.format(epoch), 'info')
            HR_train, LR_train = shuffle(HR_train, LR_train, random_state=222)

            for iteration in range(num_batch_in_train):
                current_iter = tf.train.global_step(sess, global_iter)

                if current_iter > FLAGS.num_iter:
                    break

                feed_dict = {
                    HR_data: HR_train[iteration * FLAGS.batch_size:
                                      iteration * FLAGS.batch_size + FLAGS.batch_size],
                    LR_data: LR_train[iteration * FLAGS.batch_size:
                                      iteration * FLAGS.batch_size + FLAGS.batch_size]
                }

                # update weights of G/D
                result = sess.run(fetches=fetches, feed_dict=feed_dict)

                # save summary every n iter
                if current_iter % FLAGS.train_summary_save_freq == 0:
                    writer.add_summary(result['summary'], global_step=current_iter)

                # save samples every n iter
                if current_iter % FLAGS.train_sample_save_freq == 0:
                    log(logflag,
                        'ESRGAN iteration : {0}, gen_loss : {1}, dis_loss : {2}'.format(
                            current_iter, result['gen_loss'], result['dis_loss']),
                        'info')
                    save_image(FLAGS, result['gen_HR'], 'train', current_iter, save_max_num=5)

                # save checkpoint every n iter
                if current_iter % FLAGS.train_ckpt_save_freq == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'gen'),
                               global_step=current_iter)

        writer.close()
        mox.file.copy_parallel(FLAGS.checkpoint_dir, "s3://esrgan-div2k/data/checkpoint")

        log(logflag, 'Training ESRGAN end', 'info')
        log(logflag, 'Training script end', 'info')
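# NOT part of the original scripts: a minimal sketch of what the `set_logger`
# and `log` helpers could look like, inferred only from how they are called
# above (`logflag = set_logger(FLAGS)` and `log(logflag, message, 'info')`).
# The attribute `FLAGS.log_file` is a hypothetical assumption.
import logging


def set_logger(FLAGS):
    """Configure the standard logging module; return a flag that enables logging (assumed contract)."""
    logging.basicConfig(filename=getattr(FLAGS, 'log_file', 'train.log'),
                        level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
    return True


def log(logflag, message, level='info'):
    """Forward a message at the given level when logging is enabled, else print it."""
    if logflag:
        getattr(logging, level)(message)
    else:
        print(message)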
def main():
    # set flags
    FLAGS = set_flags()

    # make dirs
    target_dirs = [
        FLAGS.HR_data_dir, FLAGS.LR_data_dir, FLAGS.npz_data_dir,
        FLAGS.train_result_dir, FLAGS.pre_train_checkpoint_dir,
        FLAGS.checkpoint_dir, FLAGS.logdir
    ]
    create_dirs(target_dirs)

    # set logger
    logflag = set_logger(FLAGS)
    log(logflag, 'Training script start', 'info')

    # load data
    if FLAGS.save_data:
        log(logflag, 'Data process : Data processing start', 'info')
        HR_train, LR_train = load_and_save_data(FLAGS, logflag)
        log(logflag, 'Data process : Data loading and data processing are completed', 'info')
    else:
        log(logflag, 'Data process : Data loading start', 'info')
        HR_train, LR_train = load_npz_data(FLAGS)
        log(logflag,
            'Data process : Loading existing data is completed. {} images are loaded'.format(len(HR_train)),
            'info')

    # pre-train generator with pixel-wise loss and save the trained model
    if FLAGS.pretrain_generator:
        train_pretrain_generator(FLAGS, LR_train, HR_train, logflag)
        tf.reset_default_graph()
        gc.collect()
    else:
        log(logflag,
            'Pre-train : Pre-train is skipped and an existing trained model will be used',
            'info')

    LR_data = tf.placeholder(tf.float32,
                             shape=[None, FLAGS.LR_image_size, FLAGS.LR_image_size, FLAGS.channel],
                             name='LR_input')
    HR_data = tf.placeholder(tf.float32,
                             shape=[None, FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel],
                             name='HR_input')

    # build Generator and Discriminator
    network = Network(FLAGS, LR_data, HR_data)
    gen_out = network.generator()
    dis_out_real, dis_out_fake = network.discriminator(gen_out)

    # build loss function
    loss = Loss()
    gen_loss, dis_loss = loss.gan_loss(FLAGS, HR_data, gen_out, dis_out_real, dis_out_fake)

    # define optimizers
    global_iter = tf.Variable(0, trainable=False)
    dis_var, dis_optimizer, gen_var, gen_optimizer = Optimizer().gan_optimizer(
        FLAGS, global_iter, dis_loss, gen_loss)

    # build summary writer
    tr_summary = tf.summary.merge(loss.add_summary_writer())

    num_train_data = len(HR_train)
    num_batch_in_train = int(math.floor(num_train_data / FLAGS.batch_size))
    num_epoch = int(math.ceil(FLAGS.num_iter / num_batch_in_train))

    HR_train, LR_train = normalize_images(HR_train, LR_train)

    fetches = {
        'dis_optimizer': dis_optimizer,
        'gen_optimizer': gen_optimizer,
        'dis_loss': dis_loss,
        'gen_loss': gen_loss,
        'gen_HR': gen_out,
        'summary': tr_summary
    }

    gc.collect()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True, visible_device_list=FLAGS.gpu_dev_num))

    # Start Session
    with tf.Session(config=config) as sess:
        log(logflag, 'Training ESRGAN starts', 'info')

        sess.run(tf.global_variables_initializer())
        sess.run(global_iter.initializer)

        writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph)

        pre_saver = tf.train.Saver(var_list=tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'))
        pre_saver.restore(sess,
                          tf.train.latest_checkpoint(FLAGS.pre_train_checkpoint_dir))

        if FLAGS.perceptual_loss == 'VGG19':
            sess.run(load_vgg19_weight(FLAGS))

        saver = tf.train.Saver(max_to_keep=10)

        for epoch in range(num_epoch):
            log(logflag, 'ESRGAN Epoch: {0}'.format(epoch), 'info')
            HR_train, LR_train = shuffle(HR_train, LR_train, random_state=222)

            for iteration in range(num_batch_in_train):
                current_iter = tf.train.global_step(sess, global_iter)

                if current_iter > FLAGS.num_iter:
                    break

                feed_dict = {
                    HR_data: HR_train[iteration * FLAGS.batch_size:
                                      iteration * FLAGS.batch_size + FLAGS.batch_size],
                    LR_data: LR_train[iteration * FLAGS.batch_size:
                                      iteration * FLAGS.batch_size + FLAGS.batch_size]
                }

                # update weights of G/D
                result = sess.run(fetches=fetches, feed_dict=feed_dict)

                # save summary every n iter
                if current_iter % FLAGS.train_summary_save_freq == 0:
                    writer.add_summary(result['summary'], global_step=current_iter)

                # save samples every n iter
                if current_iter % FLAGS.train_sample_save_freq == 0:
                    log(logflag,
                        'ESRGAN iteration : {0}, gen_loss : {1}, dis_loss : {2}'.format(
                            current_iter, result['gen_loss'], result['dis_loss']),
                        'info')
                    save_image(FLAGS, result['gen_HR'], 'train', current_iter, save_max_num=5)

                # save checkpoint every n iter
                if current_iter % FLAGS.train_ckpt_save_freq == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'gen'),
                               global_step=current_iter)

        writer.close()

        log(logflag, 'Training ESRGAN end', 'info')
        log(logflag, 'Training script end', 'info')
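# NOT part of the original scripts: a minimal sketch of `normalize_images`,
# which both training loops call before feeding batches. It assumes the inputs
# are uint8 images mapped to float32 in [0, 1]; the original utility may use a
# different range (e.g. [-1, 1]), so treat the scaling here as an assumption.
import numpy as np


def normalize_images(*image_arrays):
    """Cast each array of uint8 images to float32 in [0, 1] (assumed range)."""
    return tuple(np.asarray(arr, dtype=np.float32) / 255.0 for arr in image_arrays)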
def train_pretrain_generator(FLAGS, LR_train, HR_train, logflag):
    """pre-train deep network as initialization weights of ESRGAN Generator"""
    log(logflag, 'Pre-train : Process start', 'info')

    LR_data = tf.placeholder(tf.float32,
                             shape=[None, FLAGS.LR_image_size, FLAGS.LR_image_size, FLAGS.channel],
                             name='LR_input')
    HR_data = tf.placeholder(tf.float32,
                             shape=[None, FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel],
                             name='HR_input')

    # build Generator
    network = Network(FLAGS, LR_data)
    pre_gen_out = network.generator()

    # build loss function
    loss = Loss()
    pre_gen_loss = loss.pretrain_loss(pre_gen_out, HR_data)

    # build optimizer
    global_iter = tf.Variable(0, trainable=False)
    pre_gen_var, pre_gen_optimizer = Optimizer().pretrain_optimizer(
        FLAGS, global_iter, pre_gen_loss)

    # build summary writer
    pre_summary = tf.summary.merge(loss.add_summary_writer())

    num_train_data = len(HR_train)
    num_batch_in_train = int(math.floor(num_train_data / FLAGS.batch_size))
    num_epoch = int(math.ceil(FLAGS.num_iter / num_batch_in_train))

    HR_train, LR_train = normalize_images(HR_train, LR_train)

    fetches = {
        'pre_gen_loss': pre_gen_loss,
        'pre_gen_optimizer': pre_gen_optimizer,
        'gen_HR': pre_gen_out,
        'summary': pre_summary
    }

    gc.collect()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True, visible_device_list=FLAGS.gpu_dev_num))

    saver = tf.train.Saver(max_to_keep=10)

    # Start session
    with tf.Session(config=config) as sess:
        log(logflag, 'Pre-train : Training starts', 'info')

        sess.run(tf.global_variables_initializer())
        sess.run(global_iter.initializer)
        sess.run(scale_initialization(pre_gen_var, FLAGS))

        writer = tf.summary.FileWriter(FLAGS.logdir,
                                       graph=sess.graph,
                                       filename_suffix='pre-train')

        for epoch in range(num_epoch):
            log(logflag, 'Pre-train Epoch: {0}'.format(epoch), 'info')
            HR_train, LR_train = shuffle(HR_train, LR_train, random_state=222)

            for iteration in range(num_batch_in_train):
                current_iter = tf.train.global_step(sess, global_iter)

                if current_iter > FLAGS.num_iter:
                    break

                feed_dict = {
                    HR_data: HR_train[iteration * FLAGS.batch_size:
                                      iteration * FLAGS.batch_size + FLAGS.batch_size],
                    LR_data: LR_train[iteration * FLAGS.batch_size:
                                      iteration * FLAGS.batch_size + FLAGS.batch_size]
                }

                # update weights
                result = sess.run(fetches=fetches, feed_dict=feed_dict)

                # save summary every n iter
                if current_iter % FLAGS.train_summary_save_freq == 0:
                    writer.add_summary(result['summary'], global_step=current_iter)

                # save samples every n iter
                if current_iter % FLAGS.train_sample_save_freq == 0:
                    log(logflag,
                        'Pre-train iteration : {0}, pixel-wise_loss : {1}'.format(
                            current_iter, result['pre_gen_loss']),
                        'info')
                    save_image(FLAGS, result['gen_HR'], 'pre-train', current_iter, save_max_num=5)

                # save checkpoint
                if current_iter % FLAGS.train_ckpt_save_freq == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.pre_train_checkpoint_dir, 'pre_gen'),
                               global_step=current_iter)

        writer.close()

    log(logflag, 'Pre-train : Process end', 'info')
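# NOT part of the original scripts: a minimal sketch of `scale_initialization`,
# assuming it implements the "smaller initialization" trick from the ESRGAN
# paper, i.e. shrinking the freshly initialized generator weights by a small
# constant so the deep residual network trains stably. The flag name
# `FLAGS.weight_initialize_scale` and the default 0.1 are hypothetical.
import tensorflow as tf


def scale_initialization(var_list, FLAGS):
    """Return assign ops that scale each generator variable by a small constant."""
    scale = getattr(FLAGS, 'weight_initialize_scale', 0.1)
    return [tf.assign(var, var * scale) for var in var_list]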
def main():
    # Prepare args
    args = parse_args()
    num_labeled_train = args.num_labeled_train
    num_test = args.num_test
    ramp_up_period = args.ramp_up_period
    ramp_down_period = args.ramp_down_period
    num_class = args.num_class
    num_epoch = args.num_epoch
    batch_size = args.batch_size
    weight_max = args.weight_max
    learning_rate = args.learning_rate
    alpha = args.alpha
    weight_norm_flag = args.weight_norm_flag
    augmentation_flag = args.augmentation_flag
    whitening_flag = args.whitening_flag
    trans_range = args.trans_range

    # Data Preparation
    train_x, train_y, test_x, test_y = load_data(args.data_path)
    ret_dic = split_supervised_train(train_x, train_y, num_labeled_train)
    ret_dic['test_x'] = test_x
    ret_dic['test_y'] = test_y
    ret_dic = make_train_test_dataset(ret_dic, num_class)

    unsupervised_target = ret_dic['unsupervised_target']
    supervised_label = ret_dic['supervised_label']
    supervised_flag = ret_dic['train_sup_flag']
    unsupervised_weight = ret_dic['unsupervised_weight']
    test_y = ret_dic['test_y']

    train_x, test_x = normalize_images(ret_dic['train_x'], ret_dic['test_x'])

    # pre-process
    if whitening_flag:
        train_x, test_x = whiten_zca(train_x, test_x)

    if augmentation_flag:
        train_x = np.pad(train_x, ((0, 0), (trans_range, trans_range),
                                   (trans_range, trans_range), (0, 0)), 'reflect')

    # make the whole data and labels for training
    # x = [train_x, supervised_label, supervised_flag, unsupervised_weight]
    y = np.concatenate((unsupervised_target, supervised_label, supervised_flag,
                        unsupervised_weight), axis=1)

    num_train_data = train_x.shape[0]

    # Build Model
    if weight_norm_flag:
        from lib.model_WN import build_model
        from lib.weight_norm import AdamWithWeightnorm
        optimizer = AdamWithWeightnorm(lr=learning_rate, beta_1=0.9, beta_2=0.999)
    else:
        from lib.model_BN import build_model
        optimizer = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999)

    model = build_model(num_class=num_class)
    model.compile(optimizer=optimizer, loss=semi_supervised_loss(num_class))
    model.metrics_tensors += model.outputs
    model.summary()

    # prepare weights and arrays for updates
    gen_weight = ramp_up_weight(ramp_up_period,
                                weight_max * (num_labeled_train / num_train_data))
    gen_lr_weight = ramp_down_weight(ramp_down_period)
    idx_list = [v for v in range(num_train_data)]
    ensemble_prediction = np.zeros((num_train_data, num_class))
    cur_pred = np.zeros((num_train_data, num_class))

    # Training
    for epoch in range(num_epoch):
        print('epoch: ', epoch)
        idx_list = shuffle(idx_list)

        if epoch > num_epoch - ramp_down_period:
            weight_down = next(gen_lr_weight)
            K.set_value(model.optimizer.lr, weight_down * learning_rate)
            K.set_value(model.optimizer.beta_1, 0.4 * weight_down + 0.5)

        ave_loss = 0
        for i in range(0, num_train_data, batch_size):
            target_idx = idx_list[i:i + batch_size]

            if augmentation_flag:
                x1 = data_augmentation_tempen(train_x[target_idx], trans_range)
            else:
                x1 = train_x[target_idx]

            x2 = supervised_label[target_idx]
            x3 = supervised_flag[target_idx]
            x4 = unsupervised_weight[target_idx]
            y_t = y[target_idx]

            x_t = [x1, x2, x3, x4]
            tr_loss, output = model.train_on_batch(x=x_t, y=y_t)
            cur_pred[idx_list[i:i + batch_size]] = output[:, 0:num_class]
            ave_loss += tr_loss

        print('Training Loss: ', (ave_loss * batch_size) / num_train_data, flush=True)

        # Update phase
        next_weight = next(gen_weight)
        y, unsupervised_weight = update_weight(y, unsupervised_weight, next_weight)
        ensemble_prediction, y = update_unsupervised_target(ensemble_prediction, y,
                                                            num_class, alpha,
                                                            cur_pred, epoch)

        # Evaluation
        if epoch % 5 == 0:
            print('Evaluate epoch : ', epoch, flush=True)
            evaluate(model, num_class, num_test, test_x, test_y)
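# NOT part of the original scripts: a minimal sketch of the `ramp_up_weight`
# and `ramp_down_weight` generators consumed via next() in the loop above,
# assuming they follow the Gaussian ramp-up exp(-5 * (1 - t)^2) and ramp-down
# exp(-12.5 * t^2) schedules from Laine & Aila's temporal ensembling paper,
# which this training loop appears to implement. The exact curves in the
# original code may differ.
import numpy as np


def ramp_up_weight(ramp_up_period, weight_max):
    """Yield the unsupervised-loss weight, ramped up over the first epochs."""
    epoch = 0
    while True:
        if epoch < ramp_up_period:
            t = epoch / ramp_up_period
            yield weight_max * float(np.exp(-5.0 * (1.0 - t) ** 2))
        else:
            yield weight_max
        epoch += 1


def ramp_down_weight(ramp_down_period):
    """Yield a multiplier that decays from ~1 toward 0 over the final epochs."""
    epoch = 0
    while True:
        t = epoch / ramp_down_period
        yield float(np.exp(-12.5 * t ** 2))
        epoch += 1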