def get_image_ques(hps, requeue, n_thr_im, im_qsz=16):
    tr_fns, hps.n_tr_inst = sidd_filenames_que_inst(hps.sidd_path, 'train', hps.start_tr_im_idx,
                                                    hps.end_tr_im_idx, hps.camera, hps.iso)
    ts_fns, hps.n_ts_inst = sidd_filenames_que_inst(hps.sidd_path, 'test', hps.start_ts_im_idx,
                                                    hps.end_ts_im_idx, hps.camera, hps.iso)
    # image loaders
    tr_image_loader = ImageLoader(tr_fns, max_queue_size=im_qsz, n_threads=n_thr_im, requeue=requeue)
    ts_image_loader = ImageLoader(ts_fns, max_queue_size=im_qsz, n_threads=n_thr_im, requeue=requeue)
    tr_im_que = tr_image_loader.get_queue()
    ts_im_que = ts_image_loader.get_queue()
    return tr_im_que, ts_im_que
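
# Rough picture of the data pipeline assumed by the functions in this file (reconstructed
# from how the queues are wired up in main() and
# initialize_data_stats_queues_baselines_histograms() below):
#
#   SIDD file names --> ImageLoader (image queue)
#                   --> PatchSampler (patch queue)
#                   --> MiniBatchSampler (minibatch queue)
#                   --> training / testing / sampling threads
#
# The pipeline is built twice: once with hps.calc_pat_stats_and_baselines_only = True to
# compute patch statistics and the Gaussian / camera-NLF baselines, and once more to
# provide the minibatch queues consumed during training.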
def main(hps):
    # Download SIDD_Medium_Raw?
    check_download_sidd()

    total_time = time.time()
    host = socket.gethostname()
    tf.set_random_seed(hps.seed)
    np.random.seed(hps.seed)

    # set up a custom logger
    add_logging_level('TRACE', 100)
    logging.getLogger(__name__).setLevel("TRACE")
    logging.basicConfig(level=logging.TRACE)

    hps.n_bins = 2. ** hps.n_bits_x

    logging.trace('SIDD path = %s' % hps.sidd_path)

    # prepare data file names
    tr_fns, hps.n_tr_inst = sidd_filenames_que_inst(hps.sidd_path, 'train', hps.start_tr_im_idx,
                                                    hps.end_tr_im_idx, hps.camera, hps.iso)
    logging.trace('# training scene instances (cam = %s, iso = %s) = %d' %
                  (str(hps.camera), str(hps.iso), hps.n_tr_inst))
    ts_fns, hps.n_ts_inst = sidd_filenames_que_inst(hps.sidd_path, 'test', hps.start_ts_im_idx,
                                                    hps.end_ts_im_idx, hps.camera, hps.iso)
    logging.trace('# testing scene instances (cam = %s, iso = %s) = %d' %
                  (str(hps.camera), str(hps.iso), hps.n_ts_inst))

    # training/testing data stats
    calc_train_test_stats(hps)

    # output log dir
    logdir = os.path.abspath(os.path.join('experiments', hps.problem, hps.logdir)) + '/'
    if not os.path.exists(logdir):
        os.makedirs(logdir, exist_ok=True)
    hps.logdirname = hps.logdir
    hps.logdir = logdir

    train_its, test_its = get_its(hps.n_batch_train, hps.n_batch_test, hps.n_train, hps.n_test)
    hps.train_its = train_its
    hps.test_its = test_its

    # square patches with 4 Bayer channels
    x_shape = [None, hps.patch_height, hps.patch_height, 4]
    hps.x_shape = x_shape
    hps.n_dims = np.prod(x_shape[1:])

    # calculate data stats and baselines
    logging.trace('calculating data stats and baselines...')
    hps.calc_pat_stats_and_baselines_only = True
    pat_stats, nll_gauss, _, nll_sdn, _, tr_batch_sampler, ts_batch_sampler = \
        initialize_data_stats_queues_baselines_histograms(hps, logdir)
    hps.nll_gauss = nll_gauss
    hps.nll_sdn = nll_sdn

    # prepare data queues
    hps.mb_requeue = True  # requeue minibatches for future epochs
    logging.trace('preparing data queues...')
    hps.calc_pat_stats_and_baselines_only = False
    tr_im_que, ts_im_que, tr_pat_que, ts_pat_que, tr_batch_que, ts_batch_que = \
        initialize_data_stats_queues_baselines_histograms(hps, logdir,
                                                          tr_batch_sampler=tr_batch_sampler,
                                                          ts_batch_sampler=ts_batch_sampler)
    # hps.save_batches = True
    print_train_test_stats(hps)

    input_shape = x_shape

    # Build the NoiseFlow graph
    logging.trace('Building NoiseFlow...')
    # placeholders: noise/clean patches, NLF (noise level function) parameters, and ISO/camera conditioning
    is_training = tf.placeholder(tf.bool, name='is_training')
    x = tf.placeholder(tf.float32, x_shape, name='noise_image')
    y = tf.placeholder(tf.float32, x_shape, name='clean_image')
    nlf0 = tf.placeholder(tf.float32, [None], name='nlf0')
    nlf1 = tf.placeholder(tf.float32, [None], name='nlf1')
    iso = tf.placeholder(tf.float32, [None], name='iso')
    cam = tf.placeholder(tf.float32, [None], name='cam')
    lr = tf.placeholder(tf.float32, None, name='learning_rate')

    # initialization of signal, gain, and camera parameters
    if hps.sidd_cond == 'mix':
        init_params(hps)

    # NoiseFlow model
    nf = NoiseFlow(input_shape[1:], is_training, hps)
    loss_val, sd_z = nf.loss(x, y, nlf0=nlf0, nlf1=nlf1, iso=iso, cam=cam)

    # save variable names and number of parameters
    vs = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    vars_files = os.path.join(hps.logdir, 'model_vars.txt')
    with open(vars_files, 'w') as vf:
        vf.write(str(vs))
    hps.num_params = int(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))
    logging.trace('number of parameters = %d' % hps.num_params)
    hps_logger(logdir + 'hps.txt', hps, nf.get_layer_names(), hps.num_params)

    # create session
    sess = tf.Session()

    n_processed = 0
    train_time = 0.0
    test_loss_best = np.inf

    # create a saver
    saver = tf.train.Saver(max_to_keep=0)  # keep all models

    # checkpoint directory
    ckpt_dir = os.path.join(hps.logdir, 'ckpt')
    ckpt_path = os.path.join(ckpt_dir, 'model.ckpt')
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    # sampling temperature (default = 1.0)
    if hps.temp is None:
        hps.temp = 1.0

    # set up the output log
    train_logger = test_logger = None
    log_columns = ['epoch', 'NLL']  # NLL: negative log likelihood
    # NLL_G: Gaussian baseline
    # NLL_SDN: camera-NLF baseline
    # sdz: standard deviation of the base measure (sanity check)
    log_columns = log_columns + ['NLL_G', 'NLL_SDN', 'sdz']
    if hps.do_sample:
        log_columns.append('sample_time')
    else:
        train_logger = ResultLogger(logdir + 'train.txt', log_columns + ['train_time'], hps.continue_training)
        test_logger = ResultLogger(logdir + 'test.txt', log_columns + ['msg'], hps.continue_training)
    sample_logger = ResultLogger(logdir + 'sample.txt',
                                 log_columns + ['KLD_G', 'KLD_NLF', 'KLD_NF', 'KLD_R'],
                                 hps.continue_training)

    tcurr = time.time()
    train_results = []
    test_results = []
    sample_results = []

    # continue training?
    start_epoch = 1
    logging.trace('continue_training = ' + str(hps.continue_training))
    if hps.continue_training:
        sess.run(tf.global_variables_initializer())
        last_epoch = restore_last_model(ckpt_dir, sess, saver)
        start_epoch = 1 + last_epoch
        # noinspection PyBroadException
        try:
            train_op = tf.get_collection('train_op')  # [0]
        except:
            logging.trace('could not restore optimizer state, preparing a new optimizer')
            train_op = get_optimizer(hps, lr, loss_val)
    else:
        logging.trace('preparing optimizer')
        train_op = get_optimizer(hps, lr, loss_val)
        logging.trace('initializing variables')
        sess.run(tf.global_variables_initializer())

    _lr = hps.lr
    _nlf0 = None
    _nlf1 = None
    t_train = t_test = t_sample = dsample = is_best = sd_z_tr = sd_z_ts = 0
    kldiv3 = None

    # Epochs
    logging.trace('Starting training/testing/sampling.')
    logging.trace('Logging to ' + logdir)
    for epoch in range(start_epoch, hps.epochs + 1):

        # Testing
        if (not hps.do_sample) and \
                (epoch < 10 or (epoch < 100 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0.):
            t = time.time()
            test_epoch_loss = []
            # multi-threaded testing (faster)
            test_epoch_loss_que = queue.Queue()
            sd_z_que_ts = queue.Queue()
            sd_z_ts = 0
            test_multithread(sess, ts_batch_que, loss_val, sd_z, x, y, nlf0, nlf1, iso, cam, is_training,
                             test_epoch_loss_que, sd_z_que_ts, test_its, nthr=hps.n_train_threads,
                             requeue=not hps.mb_requeue)
            assert test_epoch_loss_que.qsize() == test_its
            for tt in range(test_its):
                test_epoch_loss.append(test_epoch_loss_que.get())
                sd_z_ts += sd_z_que_ts.get()
            sd_z_ts /= test_its
            mean_test_loss = np.mean(test_epoch_loss)
            test_results.append(mean_test_loss)

            # Save checkpoint
            saver.save(sess, ckpt_path, global_step=epoch)

            # best model?
            if test_results[-1] < test_loss_best:
                test_loss_best = test_results[-1]
                saver.save(sess, ckpt_path + '.best')
                is_best = 1
            else:
                is_best = 0

            # log
            log_dict = {'epoch': epoch, 'NLL': test_results[-1], 'NLL_G': nll_gauss, 'NLL_SDN': nll_sdn,
                        'sdz': sd_z_ts, 'msg': is_best}
            test_logger.log(log_dict)
            t_test = time.time() - t
        # End of testing

        # Sampling (optional)
        do_sampling = True  # set to True to perform sampling
        if do_sampling and ((epoch < 10 or (epoch < 100 and epoch % 10 == 0) or
                             # (is_best == 1) or
                             epoch % hps.epochs_full_valid * 2 == 0.)):
            for temp in [1.0]:  # using only the default temperature
                t_sample = time.time()
                hps.temp = float(temp)
                sample_epoch_loss = []
                # multi-threaded sampling (faster)
                sample_epoch_loss_que = queue.Queue()
                sd_z_que_sam = queue.Queue()
                kldiv_que = queue.Queue()
                sd_z_sam = 0.0
                kldiv1 = np.ndarray([4])
                kldiv1[:] = 0.0
                kldiv3 = np.zeros(4)
                is_cond = hps.sidd_cond != 'uncond'
                # sample (forward)
                x_sample = sample_sidd_tf(sess, nf, is_training, hps.temp, y, nlf0, nlf1, iso, cam, is_cond)
                sample_multithread(sess, ts_batch_que, loss_val, sd_z, x, x_sample, y, nlf0, nlf1, iso, cam,
                                   is_training, sample_epoch_loss_que, sd_z_que_sam, kldiv_que, test_its,
                                   nthr=hps.n_train_threads, requeue=not hps.mb_requeue,
                                   sc_sd=pat_stats['sc_in_sd'], epoch=epoch)
                # assert sample_epoch_loss_que.qsize() == test_its
                nqs = sample_epoch_loss_que.qsize()
                for tt in range(nqs):
                    sample_epoch_loss.append(sample_epoch_loss_que.get())
                    sd_z_sam += sd_z_que_sam.get()
                    kldiv3 += kldiv_que.get()
                sd_z_sam /= nqs
                kldiv3 /= np.repeat(nqs, len(kldiv3))
                mean_sample_loss = np.mean(sample_epoch_loss)
                sample_results.append(mean_sample_loss)
                t_sample = time.time() - t_sample

                # log
                log_dict = {'epoch': epoch, 'NLL': sample_results[-1], 'NLL_G': nll_gauss, 'NLL_SDN': nll_sdn,
                            'sdz': sd_z_sam, 'sample_time': t_sample, 'KLD_G': kldiv3[0], 'KLD_NLF': kldiv3[1],
                            'KLD_NF': kldiv3[2], 'KLD_R': kldiv3[3]}
                sample_logger.log(log_dict)

        # Training loop
        t_curr = 0
        if not hps.do_sample:
            t = time.time()
            train_epoch_loss = []
            # multi-threaded training (faster)
            train_epoch_loss_que = queue.Queue()
            sd_z_que_tr = queue.Queue()
            n_processed_que = queue.Queue()
            sd_z_tr = 0
            train_multithread(sess, tr_batch_que, loss_val, sd_z, train_op, x, y, nlf0, nlf1, iso, cam, lr,
                              is_training, _lr, n_processed_que, train_epoch_loss_que, sd_z_que_tr, train_its,
                              nthr=hps.n_train_threads, requeue=not hps.mb_requeue)
            assert train_epoch_loss_que.qsize() == train_its
            for tt in range(train_its):
                train_epoch_loss.append(train_epoch_loss_que.get())
                n_processed += n_processed_que.get()
                sd_z_tr += sd_z_que_tr.get()
            sd_z_tr /= train_its
            t_curr = time.time() - tcurr
            tcurr = time.time()
            mean_train_loss = np.mean(train_epoch_loss)
            train_results.append(mean_train_loss)
            t_train = time.time() - t
            train_time += t_train
            train_logger.log({'epoch': epoch, 'train_time': int(train_time), 'NLL': train_results[-1],
                              'NLL_G': nll_gauss, 'NLL_SDN': nll_sdn, 'sdz': sd_z_tr})
        # End of training

        # print results of train/test/sample
        tr_l = train_results[-1] if len(train_results) > 0 else 0
        ts_l = test_results[-1] if len(test_results) > 0 else 0
        sam_l = sample_results[-1] if len(sample_results) > 0 else 0
        if epoch < 10 or (epoch < 100 and epoch % 10 == 0) or \
                epoch % hps.epochs_full_valid == 0.:
            # E: epoch
            # tr, ts, tsm, tv: time of training, testing, sampling, visualization
            # T: total time
            # tL, sL, smL: loss of training, testing, sampling
            # SDr, SDs: std. dev. of the base measure in training and testing
            # B: 1 if best model so far, 0 otherwise
            print('%s %s %s E=%d tr=%.1f ts=%.1f tsm=%.1f tv=%.1f T=%.1f '
                  'tL=%5.1f sL=%5.1f smL=%5.1f SDr=%.1f SDs=%.1f B=%d' %
                  (str(datetime.now())[11:16], host, hps.logdirname, epoch, t_train, t_test, t_sample,
                   dsample, t_curr, tr_l, ts_l, sam_l, sd_z_tr, sd_z_ts, is_best), end='')
            if kldiv3 is not None:
                print(' ', end='')
                # marginal KL divergence of noise samples from: Gaussian, camera-NLF, and NoiseFlow, respectively
                print(','.join('{0:.3f}'.format(kk) for kk in kldiv3), end='')
            print('', flush=True)

    total_time = time.time() - total_time
    logging.trace('Total time = %f' % total_time)
    with open(path.join(logdir, 'total_time.txt'), 'w') as f:
        f.write('total_time (s) = %f' % total_time)
    logging.trace("Finished!")
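
# Note: main() expects `hps` to be a hyperparameter namespace (e.g. as produced by the
# experiment's argument parser) that already provides, among others: sidd_path, camera, iso,
# the start/end image indices for train and test, n_batch_train, n_batch_test, n_train, n_test,
# patch_height, patch_sampling, n_bits_x, sidd_cond, lr, epochs, epochs_full_valid,
# n_train_threads, do_sample, continue_training, temp, seed, problem, and logdir.
# This list is reconstructed from the attribute accesses above and is not exhaustive.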
def initialize_data_stats_queues_baselines_histograms(hps, logdir, tr_batch_sampler=None, ts_batch_sampler=None):
    # use 4 or 8 threads for faster image loading
    n_thr_im = 8
    n_thr_pt = 1  # use 1 to prevent shuffling
    n_thr_mb = 1  # use 1 to prevent shuffling
    n_thr_psc = 4
    im_qsz = 4
    pat_qsz = 300
    mb_qsz_tr = hps.train_its + 1
    mb_qsz_ts = hps.test_its + 1
    requeue = True  # True: keep re-adding the same data to the queues for future epochs

    tr_fns, hps.n_tr_inst = sidd_filenames_que_inst(hps.sidd_path, 'train', hps.start_tr_im_idx,
                                                    hps.end_tr_im_idx, hps.camera, hps.iso)
    ts_fns, hps.n_ts_inst = sidd_filenames_que_inst(hps.sidd_path, 'test', hps.start_ts_im_idx,
                                                    hps.end_ts_im_idx, hps.camera, hps.iso)

    # image loaders
    tr_im_que, ts_im_que = get_image_ques(hps, requeue=requeue, n_thr_im=n_thr_im, im_qsz=im_qsz)

    # patch samplers
    tr_patch_sampler = PatchSampler(tr_im_que, patch_height=hps.patch_height, sampling=hps.patch_sampling,
                                    max_queue_size=pat_qsz, n_threads=n_thr_pt,
                                    n_reuse_image=hps.n_reuse_image, n_pat_per_im=hps.n_patches_per_image,
                                    shuffle=hps.shuffle_patches)
    ts_patch_sampler = PatchSampler(ts_im_que, patch_height=hps.patch_height, sampling='uniform',
                                    max_queue_size=pat_qsz, n_threads=n_thr_pt,
                                    n_reuse_image=hps.n_reuse_image, n_pat_per_im=hps.n_patches_per_image,
                                    shuffle=hps.shuffle_patches)
    tr_pat_que = tr_patch_sampler.get_queue()
    ts_pat_que = ts_patch_sampler.get_queue()

    # patch stats and baselines (load precomputed values when not in stats-only mode)
    if hps.calc_pat_stats_and_baselines_only:
        pat_stats = None
    else:
        pat_stats_calculator = PatchStatsCalculator(None, hps.patch_height, n_channels=4, save_dir=logdir,
                                                    file_postfix='', n_threads=n_thr_psc, hps=hps)
        pat_stats = pat_stats_calculator.load_pat_stats()
        nll_gauss, bpd_gauss = pat_stats_calculator.load_gauss_baseline()
        nll_sdn, bpd_sdn = pat_stats_calculator.load_sdn_baseline()

    # minibatch samplers
    if tr_batch_sampler is None:
        tr_batch_sampler = MiniBatchSampler(tr_pat_que, minibatch_size=hps.n_batch_train,
                                            max_queue_size=mb_qsz_tr, n_threads=n_thr_mb, pat_stats=pat_stats)
        ts_batch_sampler = MiniBatchSampler(ts_pat_que, minibatch_size=hps.n_batch_test,
                                            max_queue_size=mb_qsz_ts, n_threads=n_thr_mb, pat_stats=pat_stats)
    tr_batch_que = tr_batch_sampler.get_queue()
    ts_batch_que = ts_batch_sampler.get_queue()

    # patch stats and baselines (compute them in stats-only mode)
    if hps.calc_pat_stats_and_baselines_only:
        pat_stats_calculator = PatchStatsCalculator(tr_batch_que, hps.patch_height, n_channels=4,
                                                    save_dir=logdir, file_postfix='', n_threads=n_thr_psc,
                                                    hps=hps)
        pat_stats = pat_stats_calculator.calc_stats()
        nll_gauss, bpd_gauss, nll_sdn, bpd_sdn = pat_stats_calculator.calc_baselines(ts_batch_que)

    logging.trace('initialize_data_queues_and_baselines: done')
    if hps.calc_pat_stats_and_baselines_only:
        return pat_stats, nll_gauss, bpd_gauss, nll_sdn, bpd_sdn, tr_batch_sampler, ts_batch_sampler
    else:
        return tr_im_que, ts_im_que, tr_pat_que, ts_pat_que, tr_batch_que, ts_batch_que
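
# Illustrative sketch of the two-pass call pattern used in main() above (shown here only as a
# reminder; the names are placeholders, not an additional API):
#
#   hps.calc_pat_stats_and_baselines_only = True
#   pat_stats, nll_gauss, _, nll_sdn, _, tr_bs, ts_bs = \
#       initialize_data_stats_queues_baselines_histograms(hps, logdir)
#
#   hps.calc_pat_stats_and_baselines_only = False
#   tr_im_que, ts_im_que, tr_pat_que, ts_pat_que, tr_batch_que, ts_batch_que = \
#       initialize_data_stats_queues_baselines_histograms(hps, logdir,
#                                                         tr_batch_sampler=tr_bs,
#                                                         ts_batch_sampler=ts_bs)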