def train(args, data, params):
    train = data['train']
    valid = data['valid']
    learning_rate = args.learning_rate

    with tf.Graph().as_default():
        input_ph = tf.placeholder(tf.int32, shape=[args.batch_size, params['gram_size']])
        targ_ph = tf.placeholder(tf.int32, shape=[args.batch_size])
        learning_rate_ph = tf.placeholder(tf.float32, shape=[])

        if args.w2v:
            with h5py.File(args.w2v, 'r') as datafile:
                embeds = datafile['w2v'][:]
            scores, normalize_op, vars = ops.model(input_ph, params, embeds)
        else:
            scores, normalize_op, vars = ops.model(input_ph, params)
        loss = ops.loss(scores, targ_ph)
        train_op, print_op = ops.train(loss, learning_rate_ph, args)

        #sess = tf.Session(config=tf.ConfigProto(inter_op_parallelism_threads=NUM_THREADS,
        #                                        intra_op_parallelism_threads=NUM_THREADS))
        sess = tf.Session()
        init = tf.initialize_all_variables()  # initialize variables before they can be used
        saver = tf.train.Saver()
        sess.run(init)
        if args.modelfile:
            saver.restore(sess, args.modelfile)
            print "Model restored from %s" % args.modelfile

        valid_loss = 0.
        for i in xrange(valid.nbatches):
            valid_feed_dict = get_feed_dict(valid, i, input_ph, targ_ph, learning_rate_ph)
            batch_loss = sess.run([loss], feed_dict=valid_feed_dict)[0]
            valid_loss += batch_loss
        last_valid = valid_loss
        print 'Initial valid loss: %.3f' % math.exp(valid_loss / valid.nbatches)

        for epoch in xrange(args.nepochs):
            print "Training epoch %d with learning rate %.3f" % (epoch + 1, learning_rate)
            vals = sess.run(vars)
            start_time = time.time()
            train_loss = 0.
            valid_loss = 0.

            for i in xrange(train.nbatches):
                train_feed_dict = get_feed_dict(train, i, input_ph, targ_ph,
                                                learning_rate_ph, learning_rate)
                #grads = sess.run(print_op, feed_dict=train_feed_dict)
                _, batch_loss = sess.run([train_op, loss], feed_dict=train_feed_dict)
                train_loss += batch_loss

            for i in xrange(valid.nbatches):
                valid_feed_dict = get_feed_dict(valid, i, input_ph, targ_ph, learning_rate_ph)
                batch_loss = sess.run([loss], feed_dict=valid_feed_dict)[0]
                valid_loss += batch_loss

            if args.normalize:
                _ = sess.run(normalize_op)

            duration = time.time() - start_time
            print "\tloss = %.3f, valid ppl = %.3f, %.3f s" % \
                (math.exp(train_loss / train.nbatches),
                 math.exp(valid_loss / valid.nbatches), duration)

            if last_valid < valid_loss:
                learning_rate /= 2.
            elif args.outfile:
                saver.save(sess, args.outfile)
            if epoch >= args.decay_after:
                learning_rate /= 1.2
            last_valid = valid_loss

        return sess.run([normalize_op])[0]  # return final normalized embeddings
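

# The training loop above builds its feed dicts through a `get_feed_dict` helper that is
# not defined in this file. The sketch below is a minimal, hypothetical version for
# illustration only: it assumes each data split exposes pre-batched `inputs`/`targets`
# arrays of shape [nbatches, batch_size, ...], which is an assumption about the data
# layout, not the repository's actual implementation.
def get_feed_dict_sketch(split, i, input_ph, targ_ph, learning_rate_ph, learning_rate=0.0):
    # Hypothetical: bind the i-th pre-batched block of contexts and targets, plus the
    # current learning rate, to the graph placeholders.
    return {
        input_ph: split.inputs[i],        # [batch_size, gram_size] int32 context ids
        targ_ph: split.targets[i],        # [batch_size] int32 target word ids
        learning_rate_ph: learning_rate,  # unused by evaluation-only sess.run calls
    }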
def main():
    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(extension_module,
                                device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("#GPU Count: ", comm.n_procs)

    data_iterator_train = jsi_iterator(conf.batch_size, conf, train=True)
    if conf.scaling_factor == 1:
        d_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)
        l_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)
    else:
        # Integer division keeps the variable shape integral.
        d_t = nn.Variable((conf.batch_size, 160 // conf.scaling_factor,
                           160 // conf.scaling_factor, 3), need_grad=True)
        l_t = nn.Variable((conf.batch_size, 160, 160, 3), need_grad=True)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    monitor_path = './nnmonitor' + \
        str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
    monitor = Monitor(monitor_path)
    jsi_monitor = setup_monitor(conf, monitor)

    # Build the generator (JSI-net) and restore its pre-trained weights.
    with nn.parameter_scope("jsinet"):
        nn.load_parameters(conf.pre_trained_model)
        net = model(d_t, conf.scaling_factor)
    net.pred.persistent = True

    rec_loss = F.mean(F.squared_error(net.pred, l_t))
    rec_loss.persistent = True
    g_final_loss = rec_loss

    if conf.jsigan:
        # GAN stage: add adversarial and feature-matching terms to the generator loss.
        net_gan = gan_model(l_t, net.pred, conf)
        d_final_fm_loss = net_gan.d_adv_loss
        d_final_fm_loss.persistent = True
        d_final_detail_loss = net_gan.d_detail_adv_loss
        d_final_detail_loss.persistent = True
        g_final_loss = conf.rec_lambda * rec_loss + conf.adv_lambda * (
            net_gan.g_adv_loss + net_gan.g_detail_adv_loss
        ) + conf.fm_lambda * (net_gan.fm_loss + net_gan.fm_detail_loss)
        g_final_loss.persistent = True

    max_iter = data_iterator_train._size // conf.batch_size
    if comm.rank == 0:
        print("max_iter", data_iterator_train._size, max_iter)

    iteration = 0
    if not conf.jsigan:
        start_epoch = 0
        end_epoch = conf.adv_weight_point
        lr = conf.learning_rate * comm.n_procs
    else:
        start_epoch = conf.adv_weight_point
        end_epoch = conf.epoch
        lr = conf.learning_rate * comm.n_procs
    w_d = conf.weight_decay * comm.n_procs

    # Set generator parameters
    with nn.parameter_scope("jsinet"):
        solver_jsinet = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_jsinet.set_parameters(nn.get_parameters())

    if conf.jsigan:
        solver_disc_fm = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_disc_detail = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        with nn.parameter_scope("Discriminator_FM"):
            solver_disc_fm.set_parameters(nn.get_parameters())
        with nn.parameter_scope("Discriminator_Detail"):
            solver_disc_detail.set_parameters(nn.get_parameters())

    for epoch in range(start_epoch, end_epoch):
        for index in range(max_iter):
            d_t.d, l_t.d = data_iterator_train.next()
            if not conf.jsigan:
                # JSI-net -> Generator
                lr_stair_decay_points = [200, 225]
                lr_net = get_learning_rate(lr, iteration,
                                           lr_stair_decay_points,
                                           conf.lr_decreasing_factor)
                g_final_loss.forward(clear_no_need_grad=True)
                solver_jsinet.zero_grad()
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_net)
                solver_jsinet.update()
            else:
                # GAN part (discriminator + generator)
                # Linearly decay the GAN learning rate after the decay point.
                lr_gan = lr if epoch < conf.gan_lr_linear_decay_point \
                    else lr * (end_epoch - epoch) / (end_epoch - conf.gan_lr_linear_decay_point)
                lr_gan = lr_gan * conf.gan_ratio
                net.pred.need_grad = False

                # Discriminator_FM
                solver_disc_fm.zero_grad()
                d_final_fm_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_fm_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_fm_loss.backward(clear_buffer=True)
                solver_disc_fm.set_learning_rate(lr_gan)
                solver_disc_fm.weight_decay(w_d)
                solver_disc_fm.update()

                # Discriminator_Detail
                solver_disc_detail.zero_grad()
                d_final_detail_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_detail_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_detail_loss.backward(clear_buffer=True)
                solver_disc_detail.set_learning_rate(lr_gan)
                solver_disc_detail.weight_decay(w_d)
                solver_disc_detail.update()

                # Generator
                net.pred.need_grad = True
                solver_jsinet.zero_grad()
                g_final_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_gan)
                solver_jsinet.update()

            iteration += 1
            if comm.rank == 0:
                train_psnr = compute_psnr(net.pred.d, l_t.d, 1.)
                jsi_monitor['psnr'].add(iteration, train_psnr)
                jsi_monitor['rec_loss'].add(iteration, rec_loss.d.copy())
                jsi_monitor['time'].add(iteration)

            if comm.rank == 0:
                if conf.jsigan:
                    jsi_monitor['g_final_loss'].add(iteration, g_final_loss.d.copy())
                    jsi_monitor['g_adv_loss'].add(iteration, net_gan.g_adv_loss.d.copy())
                    jsi_monitor['g_detail_adv_loss'].add(
                        iteration, net_gan.g_detail_adv_loss.d.copy())
                    jsi_monitor['d_final_fm_loss'].add(
                        iteration, d_final_fm_loss.d.copy())
                    jsi_monitor['d_final_detail_loss'].add(
                        iteration, d_final_detail_loss.d.copy())
                    jsi_monitor['fm_loss'].add(iteration, net_gan.fm_loss.d.copy())
                    jsi_monitor['fm_detail_loss'].add(
                        iteration, net_gan.fm_detail_loss.d.copy())
                    jsi_monitor['lr'].add(iteration, lr_gan)

        if comm.rank == 0:
            if not os.path.exists(conf.output_dir):
                os.makedirs(conf.output_dir)
            with nn.parameter_scope("jsinet"):
                nn.save_parameters(
                    os.path.join(conf.output_dir, "model_param_%04d.h5" % epoch))
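

# `get_learning_rate`, used in the JSI-net branch of main(), is defined elsewhere in the
# repository. The sketch below only illustrates the stair-decay idea suggested by its
# call site (decay points [200, 225] and a decreasing factor); the step/epoch units and
# the exact schedule are assumptions, not the repository's implementation.
def get_learning_rate_sketch(base_lr, current_step, decay_points, decreasing_factor):
    # Shrink the learning rate once for every decay point already passed.
    lr = base_lr
    for point in decay_points:
        if current_step >= point:
            lr *= decreasing_factor
    return lr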
def inference():
    """
    Inference function to generate high-resolution HDR images.
    """
    conf = get_config()
    ctx = get_extension_context(conf.nnabla_context.context,
                                device_id=conf.nnabla_context.device_id)
    nn.set_default_context(ctx)
    data, target = read_mat_file(conf.data.lr_sdr_test, conf.data.hr_hdr_test,
                                 conf.data.d_name_test, conf.data.l_name_test,
                                 train=False)
    if not os.path.exists(conf.test_img_dir):
        os.makedirs(conf.test_img_dir)

    data_sz = data.shape
    target_sz = target.shape
    PATCH_BOUNDARY = 10  # set patch boundary to reduce edge effect around patch edges
    test_loss_PSNR_list_for_epoch = []
    inf_time = []
    start_time = time.time()
    test_pred_full = np.zeros((target_sz[1], target_sz[2], target_sz[3]))

    print("Loading pre-trained model.........", conf.pre_trained_model)
    nn.load_parameters(conf.pre_trained_model)

    for index in range(data_sz[0]):
        ###======== Divide Into Patches ========###
        for p in range(conf.test_patch ** 2):
            pH = p // conf.test_patch
            pW = p % conf.test_patch
            sH = data_sz[1] // conf.test_patch
            sW = data_sz[2] // conf.test_patch
            H_low_ind, H_high_ind, W_low_ind, W_high_ind = \
                get_hw_boundary(PATCH_BOUNDARY, data_sz[1], data_sz[2],
                                pH, sH, pW, sW)
            data_test_p = nn.Variable.from_numpy_array(
                data.d[index, H_low_ind:H_high_ind, W_low_ind:W_high_ind, :])
            data_test_sz = data_test_p.shape
            data_test_p = F.reshape(
                data_test_p,
                (1, data_test_sz[0], data_test_sz[1], data_test_sz[2]))
            st = time.time()
            net = model(data_test_p, conf.scaling_factor)
            net.pred.forward()
            test_pred_temp = net.pred.d
            inf_time.append(time.time() - st)

            test_pred_t = trim_patch_boundary(test_pred_temp, PATCH_BOUNDARY,
                                              data_sz[1], data_sz[2],
                                              pH, sH, pW, sW,
                                              conf.scaling_factor)
            #pred_sz = test_pred_t.shape
            test_pred_t = np.squeeze(test_pred_t)
            test_pred_full[pH * sH * conf.scaling_factor:(pH + 1) * sH * conf.scaling_factor,
                           pW * sW * conf.scaling_factor:(pW + 1) * sW * conf.scaling_factor,
                           :] = test_pred_t

        ###======== Compute PSNR & Print Results ========###
        test_GT = np.squeeze(target.d[index, :, :, :])
        test_PSNR = compute_psnr(test_pred_full, test_GT, 1.)
        test_loss_PSNR_list_for_epoch.append(test_PSNR)
        print(" <Test> [%4d/%4d]-th images, time: %4.4f (minutes), test_PSNR: %.8f [dB]"
              % (int(index), int(data_sz[0]), (time.time() - start_time) / 60, test_PSNR))
        if conf.save_images:  # comment out for faster testing
            save_results_yuv(test_pred_full, index, conf.test_img_dir)

    test_PSNR_per_epoch = np.mean(test_loss_PSNR_list_for_epoch)
    print("######### Average Test PSNR: %.8f [dB] #########" % test_PSNR_per_epoch)
    print("######### Estimated Inference Time (per 4K frame): %.8f [s] #########"
          % (np.mean(inf_time) * conf.test_patch * conf.test_patch))
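

# `compute_psnr` is imported from a utility module that is not shown in this file.
# For reference, a minimal sketch consistent with the `(pred, target, max_value)` call
# sites above is given below; it uses the standard PSNR formula and is an assumed
# illustration, not the repository's exact implementation.
import numpy as np


def compute_psnr_sketch(pred, target, max_value=1.0):
    # Mean squared error over all pixels/channels, then PSNR = 10 * log10(max^2 / MSE).
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10((max_value ** 2) / mse)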