Example #1

import logging
import os
import queue
import socket
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

# Project-specific helpers referenced below (ImageLoader, PatchSampler,
# MiniBatchSampler, PatchStatsCalculator, ResultLogger, NoiseFlow,
# sidd_filenames_que_inst, check_download_sidd, add_logging_level,
# calc_train_test_stats, get_its, init_params, hps_logger,
# restore_last_model, get_optimizer, test_multithread, sample_multithread,
# train_multithread, sample_sidd_tf, print_train_test_stats) are assumed to
# come from the Noise Flow codebase; their exact module paths are not shown
# in this example.

def get_image_ques(hps, requeue, n_thr_im, im_qsz=16):
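    """Build the training and testing image queues, each fed by a
    multi-threaded ImageLoader over SIDD scene-instance file names."""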
    tr_fns, hps.n_tr_inst = sidd_filenames_que_inst(hps.sidd_path, 'train', hps.start_tr_im_idx, hps.end_tr_im_idx,
                                                    hps.camera, hps.iso)
    ts_fns, hps.n_ts_inst = sidd_filenames_que_inst(hps.sidd_path, 'test', hps.start_ts_im_idx, hps.end_ts_im_idx,
                                                    hps.camera, hps.iso)
    # image loaders
    tr_image_loader = ImageLoader(tr_fns, max_queue_size=im_qsz, n_threads=n_thr_im, requeue=requeue)
    ts_image_loader = ImageLoader(ts_fns, max_queue_size=im_qsz, n_threads=n_thr_im, requeue=requeue)
    tr_im_que = tr_image_loader.get_queue()
    ts_im_que = ts_image_loader.get_queue()

    return tr_im_que, ts_im_que


def main(hps):
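    """End-to-end driver: set up the SIDD data queues, build the NoiseFlow
    graph, then alternate testing, sampling, and training each epoch."""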

    # Download SIDD_Medium_Raw?
    check_download_sidd()

    total_time = time.time()
    host = socket.gethostname()
    tf.set_random_seed(hps.seed)
    np.random.seed(hps.seed)

    # set up a custom logger
    add_logging_level('TRACE', 100)
    logging.getLogger(__name__).setLevel("TRACE")
    logging.basicConfig(level=logging.TRACE)

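    # number of discretization bins per dimension (2^n_bits_x)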
    hps.n_bins = 2.**hps.n_bits_x

    logging.trace('SIDD path = %s' % hps.sidd_path)

    # prepare data file names
    tr_fns, hps.n_tr_inst = sidd_filenames_que_inst(hps.sidd_path, 'train',
                                                    hps.start_tr_im_idx,
                                                    hps.end_tr_im_idx,
                                                    hps.camera, hps.iso)
    logging.trace('# training scene instances (cam = %s, iso = %s) = %d' %
                  (str(hps.camera), str(hps.iso), hps.n_tr_inst))
    ts_fns, hps.n_ts_inst = sidd_filenames_que_inst(hps.sidd_path, 'test',
                                                    hps.start_ts_im_idx,
                                                    hps.end_ts_im_idx,
                                                    hps.camera, hps.iso)
    logging.trace('# testing scene instances (cam = %s, iso = %s) = %d' %
                  (str(hps.camera), str(hps.iso), hps.n_ts_inst))

    # training/testing data stats
    calc_train_test_stats(hps)

    # output log dir
    logdir = os.path.abspath(
        os.path.join('experiments', hps.problem, hps.logdir)) + '/'
    os.makedirs(logdir, exist_ok=True)
    hps.logdirname = hps.logdir
    hps.logdir = logdir

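    # minibatch iterations per epoch for training and testing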
    train_its, test_its = get_its(hps.n_batch_train, hps.n_batch_test,
                                  hps.n_train, hps.n_test)
    hps.train_its = train_its
    hps.test_its = test_its

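    # NHWC minibatch shape: square patches with 4 channels (packed raw Bayer)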
    x_shape = [None, hps.patch_height, hps.patch_height, 4]
    hps.x_shape = x_shape
    hps.n_dims = np.prod(x_shape[1:])

    # calculate data stats and baselines
    logging.trace('calculating data stats and baselines...')
    hps.calc_pat_stats_and_baselines_only = True
    pat_stats, nll_gauss, _, nll_sdn, _, tr_batch_sampler, ts_batch_sampler = initialize_data_stats_queues_baselines_histograms(
        hps, logdir)
    hps.nll_gauss = nll_gauss
    hps.nll_sdn = nll_sdn

    # prepare the data queues
    hps.mb_requeue = True  # requeue minibatches for future epochs
    logging.trace('preparing data queues...')
    hps.calc_pat_stats_and_baselines_only = False
    tr_im_que, ts_im_que, tr_pat_que, ts_pat_que, tr_batch_que, ts_batch_que = \
        initialize_data_stats_queues_baselines_histograms(hps, logdir, tr_batch_sampler=tr_batch_sampler, ts_batch_sampler=ts_batch_sampler)
    # hps.save_batches = True

    print_train_test_stats(hps)

    input_shape = x_shape

    # Build noise flow graph
    logging.trace('Building NoiseFlow...')
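    # graph inputs: noisy/clean patch pairs plus the conditioning signals
    # (camera NLF parameters nlf0/nlf1, ISO level, camera ID) and the
    # learning rate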
    is_training = tf.placeholder(tf.bool, name='is_training')
    x = tf.placeholder(tf.float32, x_shape, name='noise_image')
    y = tf.placeholder(tf.float32, x_shape, name='clean_image')
    nlf0 = tf.placeholder(tf.float32, [None], name='nlf0')
    nlf1 = tf.placeholder(tf.float32, [None], name='nlf1')
    iso = tf.placeholder(tf.float32, [None], name='iso')
    cam = tf.placeholder(tf.float32, [None], name='cam')
    lr = tf.placeholder(tf.float32, None, name='learning_rate')

    # initialization of signal, gain, and camera parameters
    if hps.sidd_cond == 'mix':
        init_params(hps)

    # NoiseFlow model
    nf = NoiseFlow(input_shape[1:], is_training, hps)
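    # loss_val: negative log-likelihood of the noise; sd_z: std. dev. of the
    # latent base variables (logged below as a sanity check)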
    loss_val, sd_z = nf.loss(x, y, nlf0=nlf0, nlf1=nlf1, iso=iso, cam=cam)

    # save variable names and number of parameters
    vs = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    vars_files = os.path.join(hps.logdir, 'model_vars.txt')
    with open(vars_files, 'w') as vf:
        vf.write(str(vs))
    hps.num_params = int(
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))
    logging.trace('number of parameters = %d' % hps.num_params)
    hps_logger(logdir + 'hps.txt', hps, nf.get_layer_names(), hps.num_params)

    # create session
    sess = tf.Session()

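    # bookkeeping: examples processed, cumulative training time, best test loss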
    n_processed = 0
    train_time = 0.0
    test_loss_best = np.inf

    # create a saver.
    saver = tf.train.Saver(max_to_keep=0)  # keep all models

    # checkpoint directory
    ckpt_dir = os.path.join(hps.logdir, 'ckpt')
    ckpt_path = os.path.join(ckpt_dir, 'model.ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)

    # sampling temperature (default = 1.0)
    if hps.temp is None:
        hps.temp = 1.0

    # setup the output log
    train_logger = test_logger = None
    log_columns = ['epoch', 'NLL']
    # NLL: negative log likelihood
    # NLL_G: for Gaussian baseline
    # NLL_SDN: for camera NLF baseline
    # sdz: standard deviation of the base measure (sanity check)
    log_columns = log_columns + ['NLL_G', 'NLL_SDN', 'sdz']
    if hps.do_sample:
        log_columns.append('sample_time')
    else:
        train_logger = ResultLogger(logdir + 'train.txt',
                                    log_columns + ['train_time'],
                                    hps.continue_training)
        test_logger = ResultLogger(logdir + 'test.txt', log_columns + ['msg'],
                                   hps.continue_training)
    sample_logger = ResultLogger(
        logdir + 'sample.txt',
        log_columns + ['KLD_G', 'KLD_NLF', 'KLD_NF', 'KLD_R'],
        hps.continue_training)

    tcurr = time.time()
    train_results = []
    test_results = []
    sample_results = []

    # continue training?
    start_epoch = 1
    logging.trace('continue_training = ' + str(hps.continue_training))
    if hps.continue_training:
        sess.run(tf.global_variables_initializer())
        last_epoch = restore_last_model(ckpt_dir, sess, saver)
        start_epoch = 1 + last_epoch
        # tf.get_collection does not raise; it returns an empty list when the
        # collection is absent, so check the result instead of catching
        train_ops = tf.get_collection('train_op')
        if train_ops:
            train_op = train_ops[0]
        else:
            logging.trace(
                'could not restore optimizer state, preparing a new optimizer')
            train_op = get_optimizer(hps, lr, loss_val)
    else:
        logging.trace('preparing optimizer')
        train_op = get_optimizer(hps, lr, loss_val)
        logging.trace('initializing variables')
        sess.run(tf.global_variables_initializer())

    _lr = hps.lr
    _nlf0 = None
    _nlf1 = None
    t_train = t_test = t_sample = dsample = is_best = sd_z_tr = sd_z_ts = 0
    kldiv3 = None

    # Epochs
    logging.trace('Starting training/testing/samplings.')
    logging.trace('Logging to ' + logdir)

    for epoch in range(start_epoch, hps.epochs + 1):

        # Testing
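        # runs every epoch for the first 10, every 10th until epoch 100,
        # then every hps.epochs_full_valid epochs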
        if (not hps.do_sample) and \
                (epoch < 10 or (epoch < 100 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0):
            t = time.time()
            test_epoch_loss = []

            # multi-thread testing (faster)
            test_epoch_loss_que = queue.Queue()
            sd_z_que_ts = queue.Queue()
            sd_z_ts = 0

            test_multithread(sess,
                             ts_batch_que,
                             loss_val,
                             sd_z,
                             x,
                             y,
                             nlf0,
                             nlf1,
                             iso,
                             cam,
                             is_training,
                             test_epoch_loss_que,
                             sd_z_que_ts,
                             test_its,
                             nthr=hps.n_train_threads,
                             requeue=not hps.mb_requeue)

            assert test_epoch_loss_que.qsize() == test_its
            for tt in range(test_its):
                test_epoch_loss.append(test_epoch_loss_que.get())
                sd_z_ts += sd_z_que_ts.get()
            sd_z_ts /= test_its

            mean_test_loss = np.mean(test_epoch_loss)
            test_results.append(mean_test_loss)

            # Save checkpoint
            saver.save(sess, ckpt_path, global_step=epoch)

            # best model?
            if test_results[-1] < test_loss_best:
                test_loss_best = test_results[-1]
                saver.save(sess, ckpt_path + '.best')
                is_best = 1
            else:
                is_best = 0

            # log
            log_dict = {
                'epoch': epoch,
                'NLL': test_results[-1],
                'NLL_G': nll_gauss,
                'NLL_SDN': nll_sdn,
                'sdz': sd_z_ts,
                'msg': is_best
            }
            test_logger.log(log_dict)

            t_test = time.time() - t

        # End testing if & loop

        # Sampling (optional)
        do_sampling = True  # set to False to skip sampling
        # sample every epoch for the first 10, every 10th until epoch 100,
        # then every other full-validation epoch (the parentheses matter:
        # `%` binds tighter than `*`)
        if do_sampling and (
                epoch < 10 or
                (epoch < 100 and epoch % 10 == 0) or  # (is_best == 1) or
                epoch % (hps.epochs_full_valid * 2) == 0):
            for temp in [1.0]:  # using only default temperature
                t_sample = time.time()
                hps.temp = float(temp)
                sample_epoch_loss = []

                # multi-thread sampling (faster)
                sample_epoch_loss_que = queue.Queue()
                sd_z_que_sam = queue.Queue()
                kldiv_que = queue.Queue()
                sd_z_sam = 0.0
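                # kldiv3 accumulates the four marginal KL divergences logged
                # below (KLD_G, KLD_NLF, KLD_NF, KLD_R)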
                kldiv3 = np.zeros(4)

                is_cond = hps.sidd_cond != 'uncond'

                # build the sampling op: synthesize noise conditioned on the
                # clean image (and camera parameters when conditional)
                x_sample = sample_sidd_tf(sess, nf, is_training, hps.temp, y,
                                          nlf0, nlf1, iso, cam, is_cond)

                sample_multithread(sess,
                                   ts_batch_que,
                                   loss_val,
                                   sd_z,
                                   x,
                                   x_sample,
                                   y,
                                   nlf0,
                                   nlf1,
                                   iso,
                                   cam,
                                   is_training,
                                   sample_epoch_loss_que,
                                   sd_z_que_sam,
                                   kldiv_que,
                                   test_its,
                                   nthr=hps.n_train_threads,
                                   requeue=not hps.mb_requeue,
                                   sc_sd=pat_stats['sc_in_sd'],
                                   epoch=epoch)

                # assert sample_epoch_loss_que.qsize() == test_its
                nqs = sample_epoch_loss_que.qsize()
                for tt in range(nqs):
                    sample_epoch_loss.append(sample_epoch_loss_que.get())
                    sd_z_sam += sd_z_que_sam.get()
                    kldiv3 += kldiv_que.get()
                sd_z_sam /= nqs
                kldiv3 /= nqs

                mean_sample_loss = np.mean(sample_epoch_loss)
                sample_results.append(mean_sample_loss)

                t_sample = time.time() - t_sample

                # log
                log_dict = {
                    'epoch': epoch,
                    'NLL': sample_results[-1],
                    'NLL_G': nll_gauss,
                    'NLL_SDN': nll_sdn,
                    'sdz': sd_z_sam,
                    'sample_time': t_sample,
                    'KLD_G': kldiv3[0],
                    'KLD_NLF': kldiv3[1],
                    'KLD_NF': kldiv3[2],
                    'KLD_R': kldiv3[3]
                }
                sample_logger.log(log_dict)

        # Training loop
        t_curr = 0
        if not hps.do_sample:
            t = time.time()
            train_epoch_loss = []

            # multi-thread training (faster)
            train_epoch_loss_que = queue.Queue()
            sd_z_que_tr = queue.Queue()
            n_processed_que = queue.Queue()
            sd_z_tr = 0

            train_multithread(sess,
                              tr_batch_que,
                              loss_val,
                              sd_z,
                              train_op,
                              x,
                              y,
                              nlf0,
                              nlf1,
                              iso,
                              cam,
                              lr,
                              is_training,
                              _lr,
                              n_processed_que,
                              train_epoch_loss_que,
                              sd_z_que_tr,
                              train_its,
                              nthr=hps.n_train_threads,
                              requeue=not hps.mb_requeue)

            assert train_epoch_loss_que.qsize() == train_its
            for tt in range(train_its):
                train_epoch_loss.append(train_epoch_loss_que.get())
                n_processed += n_processed_que.get()
                sd_z_tr += sd_z_que_tr.get()
            sd_z_tr /= train_its

            t_curr = time.time() - tcurr
            tcurr = time.time()

            mean_train_loss = np.mean(train_epoch_loss)
            train_results.append(mean_train_loss)

            t_train = time.time() - t
            train_time += t_train

            train_logger.log({
                'epoch': epoch,
                'train_time': int(train_time),
                'NLL': train_results[-1],
                'NLL_G': nll_gauss,
                'NLL_SDN': nll_sdn,
                'sdz': sd_z_tr
            })
        # End training

        # print results of train/test/sample
        tr_l = train_results[-1] if train_results else 0
        ts_l = test_results[-1] if test_results else 0
        sam_l = sample_results[-1] if sample_results else 0
        if epoch < 10 or (epoch < 100 and epoch % 10 == 0) or \
                epoch % hps.epochs_full_valid == 0:
            # E: epoch
            # tr, ts, tsm, tv: time of training, testing, sampling, visualization
            # T: total time
            # tL, sL, smL: loss of training, testing, sampling
            # SDr, SDs: std. dev. of base measure in training and testing
            # B: 1 if best model, 0 otherwise
            print('%s %s %s E=%d tr=%.1f ts=%.1f tsm=%.1f tv=%.1f T=%.1f '
                  'tL=%5.1f sL=%5.1f smL=%5.1f SDr=%.1f SDs=%.1f B=%d' %
                  (str(datetime.now())[11:16], host, hps.logdirname, epoch,
                   t_train, t_test, t_sample, dsample, t_curr, tr_l, ts_l,
                   sam_l, sd_z_tr, sd_z_ts, is_best),
                  end='')
            if kldiv3 is not None:
                print(' ', end='')
                # marginal KL divergence of noise samples from: Gaussian, camera-NLF, and NoiseFlow, respectively
                print(','.join('{0:.3f}'.format(kk) for kk in kldiv3), end='')
            print('', flush=True)

    total_time = time.time() - total_time
    logging.trace('Total time = %f' % total_time)
    with open(os.path.join(logdir, 'total_time.txt'), 'w') as f:
        f.write('total_time (s) = %f' % total_time)
    logging.trace("Finished!")
def initialize_data_stats_queues_baselines_histograms(hps,
                                                      logdir,
                                                      tr_batch_sampler=None,
                                                      ts_batch_sampler=None):
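    """Set up the image, patch, and minibatch queues. When
    hps.calc_pat_stats_and_baselines_only is set, compute the patch
    statistics and the Gaussian / camera-NLF baselines and return those
    (plus the batch samplers); otherwise return the data queues."""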
    # use 4 or 8 threads for faster loading
    n_thr_im = 8
    n_thr_pt = 1  # use 1 to prevent shuffling
    n_thr_mb = 1  # use 1 to prevent shuffling
    n_thr_psc = 4

    im_qsz = 4
    pat_qsz = 300
    mb_qsz_tr = hps.train_its + 1
    mb_qsz_ts = hps.test_its + 1

    requeue = True  # True: keep re-adding the same data to the queues for future epochs

    tr_fns, hps.n_tr_inst = sidd_filenames_que_inst(hps.sidd_path, 'train',
                                                    hps.start_tr_im_idx,
                                                    hps.end_tr_im_idx,
                                                    hps.camera, hps.iso)
    ts_fns, hps.n_ts_inst = sidd_filenames_que_inst(hps.sidd_path, 'test',
                                                    hps.start_ts_im_idx,
                                                    hps.end_ts_im_idx,
                                                    hps.camera, hps.iso)

    # image loaders
    tr_im_que, ts_im_que = get_image_ques(hps,
                                          requeue=requeue,
                                          n_thr_im=n_thr_im,
                                          im_qsz=im_qsz)

    # patch samplers
    tr_patch_sampler = PatchSampler(tr_im_que,
                                    patch_height=hps.patch_height,
                                    sampling=hps.patch_sampling,
                                    max_queue_size=pat_qsz,
                                    n_threads=n_thr_pt,
                                    n_reuse_image=hps.n_reuse_image,
                                    n_pat_per_im=hps.n_patches_per_image,
                                    shuffle=hps.shuffle_patches)
    ts_patch_sampler = PatchSampler(ts_im_que,
                                    patch_height=hps.patch_height,
                                    sampling='uniform',
                                    max_queue_size=pat_qsz,
                                    n_threads=n_thr_pt,
                                    n_reuse_image=hps.n_reuse_image,
                                    n_pat_per_im=hps.n_patches_per_image,
                                    shuffle=hps.shuffle_patches)
    tr_pat_que = tr_patch_sampler.get_queue()
    ts_pat_que = ts_patch_sampler.get_queue()

    # patch stats and baselines: load precomputed values (second pass)
    if hps.calc_pat_stats_and_baselines_only:
        pat_stats = None
    else:
        pat_stats_calculator = PatchStatsCalculator(None,
                                                    hps.patch_height,
                                                    n_channels=4,
                                                    save_dir=logdir,
                                                    file_postfix='',
                                                    n_threads=n_thr_psc,
                                                    hps=hps)
        pat_stats = pat_stats_calculator.load_pat_stats()
        nll_gauss, bpd_gauss = pat_stats_calculator.load_gauss_baseline()
        nll_sdn, bpd_sdn = pat_stats_calculator.load_sdn_baseline()

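    # minibatch samplers: created on the first pass and reused when passed
    # back in on the second call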
    if tr_batch_sampler is None:
        tr_batch_sampler = MiniBatchSampler(tr_pat_que,
                                            minibatch_size=hps.n_batch_train,
                                            max_queue_size=mb_qsz_tr,
                                            n_threads=n_thr_mb,
                                            pat_stats=pat_stats)
        ts_batch_sampler = MiniBatchSampler(ts_pat_que,
                                            minibatch_size=hps.n_batch_test,
                                            max_queue_size=mb_qsz_ts,
                                            n_threads=n_thr_mb,
                                            pat_stats=pat_stats)
    tr_batch_que = tr_batch_sampler.get_queue()
    ts_batch_que = ts_batch_sampler.get_queue()

    # patch stats and baselines: compute from the queues (first pass)
    if hps.calc_pat_stats_and_baselines_only:
        pat_stats_calculator = PatchStatsCalculator(tr_batch_que,
                                                    hps.patch_height,
                                                    n_channels=4,
                                                    save_dir=logdir,
                                                    file_postfix='',
                                                    n_threads=n_thr_psc,
                                                    hps=hps)
        pat_stats = pat_stats_calculator.calc_stats()
        nll_gauss, bpd_gauss, nll_sdn, bpd_sdn = pat_stats_calculator.calc_baselines(
            ts_batch_que)
        logging.trace('initialize_data_stats_queues_baselines_histograms: done')

    if hps.calc_pat_stats_and_baselines_only:
        return pat_stats, nll_gauss, bpd_gauss, nll_sdn, bpd_sdn, tr_batch_sampler, ts_batch_sampler
    else:
        return tr_im_que, ts_im_que, tr_pat_que, ts_pat_que, tr_batch_que, ts_batch_que