Example #1
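The load_noise_flow_model method of a Noise Flow wrapper class (see the full class in Example #3): it defines the TF1 placeholders for the noisy/clean images and the conditioning inputs, builds the NoiseFlow model and its sampling operation, then opens a session and restores the best checkpoint.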
    def load_noise_flow_model(self):

        if not hasattr(self.hps, 'x_shape'):
            setattr(self.hps, 'x_shape', self.x_shape)
        self.x = tf.placeholder(tf.float32, self.x_shape, name='noise_image')
        self.y = tf.placeholder(tf.float32, self.x_shape, name='clean_image')
        self.nlf0 = tf.placeholder(tf.float32, [None], name='nlf0')
        self.nlf1 = tf.placeholder(tf.float32, [None], name='nlf1')
        self.iso = tf.placeholder(tf.float32, [None], name='iso')
        self.cam = tf.placeholder(tf.float32, [None], name='cam')
        self.is_training = tf.placeholder(tf.bool, name='is_training')

        self.logger.info('Building Noise Flow')
        self.nf_model = NoiseFlow(self.x_shape[1:], self.is_training, self.hps)

        self.logger.info('Creating sampling operation')
        self.x_sample = self.sample_sidd_tf()

        self.logger.info('Creating saver')
        self.saver = tf.train.Saver()

        self.logger.info('Creating session')
        self.sess = tf.Session()

        self.logger.info('Initializing variables')
        self.sess.run(tf.global_variables_initializer())

        self.logger.info('Restoring best model')
        # last_epoch = restore_last_model(self.ckpt_dir, self.sess, self.saver)
        self.saver.restore(self.sess, self.model_checkpoint_path)
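These examples target the TensorFlow 1.x graph/session API (tf.placeholder, tf.Session, tf.train.Saver). A minimal sketch, assuming a TensorFlow 2.x installation, is to run them through the official v1 compatibility layer:

import tensorflow.compat.v1 as tf  # v1 compatibility shim shipped with TF 2.x

tf.disable_v2_behavior()  # disable eager execution so placeholders/sessions work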
Example #2
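An end-to-end driver: it prepares the SIDD training/testing queues and baselines, builds the NoiseFlow graph and optimizer, then loops over epochs with periodic testing, sampling, checkpointing, and logging.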
# Imports assumed by this example; project-specific helpers (NoiseFlow,
# check_download_sidd, sidd_filenames_que_inst, get_its, ResultLogger,
# add_logging_level, etc.) come from the Noise Flow codebase and are not
# reproduced here.
import logging
import os
import queue
import socket
import time
from datetime import datetime

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def main(hps):

    # Download SIDD_Medium_Raw?
    check_download_sidd()

    total_time = time.time()
    host = socket.gethostname()
    tf.set_random_seed(hps.seed)
    np.random.seed(hps.seed)

    # set up a custom logger
    add_logging_level('TRACE', 100)
    logging.getLogger(__name__).setLevel("TRACE")
    logging.basicConfig(level=logging.TRACE)

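    # number of quantization bins implied by using n_bits_x bits per value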
    hps.n_bins = 2.**hps.n_bits_x

    logging.trace('SIDD path = %s' % hps.sidd_path)

    # prepare data file names
    tr_fns, hps.n_tr_inst = sidd_filenames_que_inst(hps.sidd_path, 'train',
                                                    hps.start_tr_im_idx,
                                                    hps.end_tr_im_idx,
                                                    hps.camera, hps.iso)
    logging.trace('# training scene instances (cam = %s, iso = %s) = %d' %
                  (str(hps.camera), str(hps.iso), hps.n_tr_inst))
    ts_fns, hps.n_ts_inst = sidd_filenames_que_inst(hps.sidd_path, 'test',
                                                    hps.start_ts_im_idx,
                                                    hps.end_ts_im_idx,
                                                    hps.camera, hps.iso)
    logging.trace('# testing scene instances (cam = %s, iso = %s) = %d' %
                  (str(hps.camera), str(hps.iso), hps.n_ts_inst))

    # training/testing data stats
    calc_train_test_stats(hps)

    # output log dir
    logdir = os.path.abspath(
        os.path.join('experiments', hps.problem, hps.logdir)) + '/'
    if not os.path.exists(logdir):
        os.makedirs(logdir, exist_ok=True)
    hps.logdirname = hps.logdir
    hps.logdir = logdir

    train_its, test_its = get_its(hps.n_batch_train, hps.n_batch_test,
                                  hps.n_train, hps.n_test)
    hps.train_its = train_its
    hps.test_its = test_its

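    # square patches with 4 channels (packed raw Bayer), so patch_height is
    # used for both spatial dimensions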
    x_shape = [None, hps.patch_height, hps.patch_height, 4]
    hps.x_shape = x_shape
    hps.n_dims = np.prod(x_shape[1:])

    # calculate data stats and baselines
    logging.trace('calculating data stats and baselines...')
    hps.calc_pat_stats_and_baselines_only = True
    pat_stats, nll_gauss, _, nll_sdn, _, tr_batch_sampler, ts_batch_sampler = \
        initialize_data_stats_queues_baselines_histograms(hps, logdir)
    hps.nll_gauss = nll_gauss
    hps.nll_sdn = nll_sdn

    # prepare get data queues
    hps.mb_requeue = True  # requeue minibatches for future epochs
    logging.trace('preparing data queues...')
    hps.calc_pat_stats_and_baselines_only = False
    tr_im_que, ts_im_que, tr_pat_que, ts_pat_que, tr_batch_que, ts_batch_que = \
        initialize_data_stats_queues_baselines_histograms(
            hps, logdir,
            tr_batch_sampler=tr_batch_sampler,
            ts_batch_sampler=ts_batch_sampler)
    # hps.save_batches = True

    print_train_test_stats(hps)

    input_shape = x_shape

    # Build noise flow graph
    logging.trace('Building NoiseFlow...')
    is_training = tf.placeholder(tf.bool, name='is_training')
    x = tf.placeholder(tf.float32, x_shape, name='noise_image')
    y = tf.placeholder(tf.float32, x_shape, name='clean_image')
    nlf0 = tf.placeholder(tf.float32, [None], name='nlf0')
    nlf1 = tf.placeholder(tf.float32, [None], name='nlf1')
    iso = tf.placeholder(tf.float32, [None], name='iso')
    cam = tf.placeholder(tf.float32, [None], name='cam')
    lr = tf.placeholder(tf.float32, None, name='learning_rate')

    # initialization of signal, gain, and camera parameters
    if hps.sidd_cond == 'mix':
        init_params(hps)

    # NoiseFlow model
    nf = NoiseFlow(input_shape[1:], is_training, hps)
    loss_val, sd_z = nf.loss(x, y, nlf0=nlf0, nlf1=nlf1, iso=iso, cam=cam)

    # save variable names and number of parameters
    vs = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    vars_files = os.path.join(hps.logdir, 'model_vars.txt')
    with open(vars_files, 'w') as vf:
        vf.write(str(vs))
    hps.num_params = int(
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))
    logging.trace('number of parameters = %d' % hps.num_params)
    hps_logger(logdir + 'hps.txt', hps, nf.get_layer_names(), hps.num_params)

    # create session
    sess = tf.Session()

    n_processed = 0
    train_time = 0.0
    test_loss_best = np.inf

    # create a saver.
    saver = tf.train.Saver(max_to_keep=0)  # keep all models

    # checkpoint directory
    ckpt_dir = os.path.join(hps.logdir, 'ckpt')
    ckpt_path = os.path.join(ckpt_dir, 'model.ckpt')
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    # sampling temperature (default = 1.0)
    if hps.temp is None:
        hps.temp = 1.0

    # setup the output log
    train_logger = test_logger = None
    log_columns = ['epoch', 'NLL']
    # NLL: negative log likelihood
    # NLL_G: for Gaussian baseline
    # NLL_SDN: for camera NLF baseline
    # sdz: standard deviation of the base measure (sanity check)
    log_columns = log_columns + ['NLL_G', 'NLL_SDN', 'sdz']
    if hps.do_sample:
        log_columns.append('sample_time')
    else:
        train_logger = ResultLogger(logdir + 'train.txt',
                                    log_columns + ['train_time'],
                                    hps.continue_training)
        test_logger = ResultLogger(logdir + 'test.txt', log_columns + ['msg'],
                                   hps.continue_training)
    sample_logger = ResultLogger(
        logdir + 'sample.txt',
        log_columns + ['KLD_G', 'KLD_NLF', 'KLD_NF', 'KLD_R'],
        hps.continue_training)
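    # KLD_G, KLD_NLF, KLD_NF: marginal KL divergence of sampled noise from the
    # Gaussian baseline, the camera-NLF baseline, and NoiseFlow, respectively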

    tcurr = time.time()
    train_results = []
    test_results = []
    sample_results = []

    # continue training?
    start_epoch = 1
    logging.trace('continue_training = ' + str(hps.continue_training))
    if hps.continue_training:
        sess.run(tf.global_variables_initializer())
        last_epoch = restore_last_model(ckpt_dir, sess, saver)
        start_epoch = 1 + last_epoch
        # tf.get_collection returns a (possibly empty) list rather than
        # raising, so test its contents instead of catching an exception
        train_op_list = tf.get_collection('train_op')
        if train_op_list:
            train_op = train_op_list[0]
        else:
            logging.trace(
                'could not restore optimizer state, preparing a new optimizer')
            train_op = get_optimizer(hps, lr, loss_val)
    else:
        logging.trace('preparing optimizer')
        train_op = get_optimizer(hps, lr, loss_val)
        logging.trace('initializing variables')
        sess.run(tf.global_variables_initializer())

    _lr = hps.lr
    _nlf0 = None
    _nlf1 = None
    t_train = t_test = t_sample = dsample = is_best = sd_z_tr = sd_z_ts = 0
    kldiv3 = None

    # Epochs
    logging.trace('Starting training/testing/samplings.')
    logging.trace('Logging to ' + logdir)

    for epoch in range(start_epoch, hps.epochs + 1):

        # Testing
        if (not hps.do_sample) and \
                (epoch < 10 or (epoch < 100 and epoch % 10 == 0) or
                 epoch % hps.epochs_full_valid == 0):
            t = time.time()
            test_epoch_loss = []

            # multi-thread testing (faster)
            test_epoch_loss_que = queue.Queue()
            sd_z_que_ts = queue.Queue()
            sd_z_ts = 0

            test_multithread(sess,
                             ts_batch_que,
                             loss_val,
                             sd_z,
                             x,
                             y,
                             nlf0,
                             nlf1,
                             iso,
                             cam,
                             is_training,
                             test_epoch_loss_que,
                             sd_z_que_ts,
                             test_its,
                             nthr=hps.n_train_threads,
                             requeue=not hps.mb_requeue)

            assert test_epoch_loss_que.qsize() == test_its
            for tt in range(test_its):
                test_epoch_loss.append(test_epoch_loss_que.get())
                sd_z_ts += sd_z_que_ts.get()
            sd_z_ts /= test_its

            mean_test_loss = np.mean(test_epoch_loss)
            test_results.append(mean_test_loss)

            # Save checkpoint
            saver.save(sess, ckpt_path, global_step=epoch)

            # best model?
            if test_results[-1] < test_loss_best:
                test_loss_best = test_results[-1]
                saver.save(sess, ckpt_path + '.best')
                is_best = 1
            else:
                is_best = 0

            # log
            log_dict = {
                'epoch': epoch,
                'NLL': test_results[-1],
                'NLL_G': nll_gauss,
                'NLL_SDN': nll_sdn,
                'sdz': sd_z_ts,
                'msg': is_best
            }
            test_logger.log(log_dict)

            t_test = time.time() - t

        # End testing if & loop

        # Sampling (optional)
        do_sampling = True  # set to False to skip sampling
        # note: the original condition `epoch % hps.epochs_full_valid * 2 == 0.`
        # parses as `(epoch % hps.epochs_full_valid) * 2 == 0`; parentheses are
        # added here to match the apparent intent of sampling every other full
        # validation epoch
        if do_sampling and (epoch < 10 or
                            (epoch < 100 and epoch % 10 == 0) or  # (is_best == 1) or
                            epoch % (hps.epochs_full_valid * 2) == 0):
            for temp in [1.0]:  # using only default temperature
                t_sample = time.time()
                hps.temp = float(temp)
                sample_epoch_loss = []

                # multi-thread sampling (faster)
                sample_epoch_loss_que = queue.Queue()
                sd_z_que_sam = queue.Queue()
                kldiv_que = queue.Queue()
                sd_z_sam = 0.0
                # accumulates the four marginal KL divergences logged below
                # (KLD_G, KLD_NLF, KLD_NF, KLD_R)
                kldiv3 = np.zeros(4)

                is_cond = hps.sidd_cond != 'uncond'

                # sample (forward)
                x_sample = sample_sidd_tf(sess, nf, is_training, hps.temp, y,
                                          nlf0, nlf1, iso, cam, is_cond)

                sample_multithread(sess,
                                   ts_batch_que,
                                   loss_val,
                                   sd_z,
                                   x,
                                   x_sample,
                                   y,
                                   nlf0,
                                   nlf1,
                                   iso,
                                   cam,
                                   is_training,
                                   sample_epoch_loss_que,
                                   sd_z_que_sam,
                                   kldiv_que,
                                   test_its,
                                   nthr=hps.n_train_threads,
                                   requeue=not hps.mb_requeue,
                                   sc_sd=pat_stats['sc_in_sd'],
                                   epoch=epoch)

                # the queue may hold fewer than test_its items, so average over
                # however many results were actually produced
                nqs = sample_epoch_loss_que.qsize()
                for tt in range(nqs):
                    sample_epoch_loss.append(sample_epoch_loss_que.get())
                    sd_z_sam += sd_z_que_sam.get()
                    kldiv3 += kldiv_que.get()
                sd_z_sam /= nqs
                kldiv3 /= nqs

                mean_sample_loss = np.mean(sample_epoch_loss)
                sample_results.append(mean_sample_loss)

                t_sample = time.time() - t_sample

                # log
                log_dict = {
                    'epoch': epoch,
                    'NLL': sample_results[-1],
                    'NLL_G': nll_gauss,
                    'NLL_SDN': nll_sdn,
                    'sdz': sd_z_sam,
                    'sample_time': t_sample,
                    'KLD_G': kldiv3[0],
                    'KLD_NLF': kldiv3[1],
                    'KLD_NF': kldiv3[2],
                    'KLD_R': kldiv3[3]
                }
                sample_logger.log(log_dict)

        # Training loop
        t_curr = 0
        if not hps.do_sample:
            t = time.time()
            train_epoch_loss = []

            # multi-thread training (faster)
            train_epoch_loss_que = queue.Queue()
            sd_z_que_tr = queue.Queue()
            n_processed_que = queue.Queue()
            sd_z_tr = 0

            train_multithread(sess,
                              tr_batch_que,
                              loss_val,
                              sd_z,
                              train_op,
                              x,
                              y,
                              nlf0,
                              nlf1,
                              iso,
                              cam,
                              lr,
                              is_training,
                              _lr,
                              n_processed_que,
                              train_epoch_loss_que,
                              sd_z_que_tr,
                              train_its,
                              nthr=hps.n_train_threads,
                              requeue=not hps.mb_requeue)

            assert train_epoch_loss_que.qsize() == train_its
            for tt in range(train_its):
                train_epoch_loss.append(train_epoch_loss_que.get())
                n_processed += n_processed_que.get()
                sd_z_tr += sd_z_que_tr.get()
            sd_z_tr /= train_its

            t_curr = time.time() - tcurr
            tcurr = time.time()

            mean_train_loss = np.mean(train_epoch_loss)
            train_results.append(mean_train_loss)

            t_train = time.time() - t
            train_time += t_train

            train_logger.log({
                'epoch': epoch,
                'train_time': int(train_time),
                'NLL': train_results[-1],
                'NLL_G': nll_gauss,
                'NLL_SDN': nll_sdn,
                'sdz': sd_z_tr
            })
        # End training

        # print results of train/test/sample
        tr_l = train_results[-1] if train_results else 0
        ts_l = test_results[-1] if test_results else 0
        sam_l = sample_results[-1] if sample_results else 0
        if epoch < 10 or (epoch < 100 and epoch % 10 == 0) or \
                epoch % hps.epochs_full_valid == 0:
            # E: epoch
            # tr, ts, tsm, tv: time of training, testing, sampling, visualization
            # T: total time
            # tL, sL, smL: loss of training, testing, sampling
            # SDr, SDs: std. dev. of base measure in training and testing
            # B: 1 if best model, 0 otherwise
            print('%s %s %s E=%d tr=%.1f ts=%.1f tsm=%.1f tv=%.1f T=%.1f '
                  'tL=%5.1f sL=%5.1f smL=%5.1f SDr=%.1f SDs=%.1f B=%d' %
                  (str(datetime.now())[11:16], host, hps.logdirname, epoch,
                   t_train, t_test, t_sample, dsample, t_curr, tr_l, ts_l,
                   sam_l, sd_z_tr, sd_z_ts, is_best),
                  end='')
            if kldiv3 is not None:
                print(' ', end='')
                # marginal KL divergence of noise samples from: Gaussian, camera-NLF, and NoiseFlow, respectively
                print(','.join('{0:.3f}'.format(kk) for kk in kldiv3), end='')
            print('', flush=True)

    total_time = time.time() - total_time
    logging.trace('Total time = %f' % total_time)
    with open(os.path.join(logdir, 'total_time.txt'), 'w') as f:
        f.write('total_time (s) = %f' % total_time)
    logging.trace("Finished!")
Example #3
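A wrapper class that restores a trained Noise Flow model from disk (hyper-parameters from hps.txt plus the best checkpoint) and exposes sample_noise_nf for drawing conditional noise samples.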
# Imports assumed by this example; NoiseFlow comes from the Noise Flow codebase.
import csv
import logging
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API


class NoiseFlowWrapper:
    def __init__(self, path, patch_size=[512, 512]):
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self.nf_path = path
        self.nf_model = None
        self.sess = None
        self.saver = None

        self.x_shape = [None, patch_size[0] // 2, patch_size[1] // 2, 4]
        self.x = None
        self.y = None
        self.nlf0 = None
        self.nlf1 = None
        self.iso = None
        self.cam = None
        self.is_training = None
        self.x_sample = None

        self.is_cond = True
        self.temp = 1.0

        self.hps = self.hps_loader(os.path.join(self.nf_path, 'hps.txt'))
        self.ckpt_dir = os.path.join(self.nf_path, 'ckpt')
        self.model_checkpoint_path = os.path.join(self.ckpt_dir,
                                                  'model.ckpt.best')
        self.load_noise_flow_model()

    def load_noise_flow_model(self):

        if not hasattr(self.hps, 'x_shape'):
            setattr(self.hps, 'x_shape', self.x_shape)
        self.x = tf.placeholder(tf.float32, self.x_shape, name='noise_image')
        self.y = tf.placeholder(tf.float32, self.x_shape, name='clean_image')
        self.nlf0 = tf.placeholder(tf.float32, [None], name='nlf0')
        self.nlf1 = tf.placeholder(tf.float32, [None], name='nlf1')
        self.iso = tf.placeholder(tf.float32, [None], name='iso')
        self.cam = tf.placeholder(tf.float32, [None], name='cam')
        self.is_training = tf.placeholder(tf.bool, name='is_training')

        self.logger.info('Building Noise Flow')
        self.nf_model = NoiseFlow(self.x_shape[1:], self.is_training, self.hps)

        self.logger.info('Creating sampling operation')
        self.x_sample = self.sample_sidd_tf()

        self.logger.info('Creating saver')
        self.saver = tf.train.Saver()

        self.logger.info('Creating session')
        self.sess = tf.Session()

        self.logger.info('Initializing variables')
        self.sess.run(tf.global_variables_initializer())

        self.logger.info('Restoring best model')
        # last_epoch = restore_last_model(self.ckpt_dir, self.sess, self.saver)
        self.saver.restore(self.sess, self.model_checkpoint_path)

    def sample_noise_nf(self, batch_x, b1, b2, iso, cam):
        # draw a noise sample conditioned on the clean batch, the NLF
        # parameters (b1, b2), the ISO level, and the camera index
        # (a signal-dependent Gaussian baseline would instead be:
        #  sig = np.sqrt(b1 * batch_x + b2);
        #  noise = np.random.normal(0.0, sig, batch_x.shape))
        noise = self.sess.run(self.x_sample,
                              feed_dict={
                                  self.y: batch_x,
                                  self.nlf0: [b1],
                                  self.nlf1: [b2],
                                  self.iso: [iso],
                                  self.cam: [cam],
                                  self.is_training: True
                              })
        return noise

    def sample_sidd_tf(self):
        if self.is_cond:
            x_sample = self.nf_model.sample(self.y, self.temp, self.y,
                                            self.nlf0, self.nlf1, self.iso,
                                            self.cam)
        else:
            x_sample = self.nf_model.sample(self.y, self.temp)
        return x_sample

    def hps_loader(self, path):

        class Hps:
            pass

        hps = Hps()
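        # hps.txt is read back as comma-separated key,value rows (the format
        # written by hps_logger during training)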
        with open(path, 'r') as f:
            reader = csv.reader(f)
            for pair in reader:
                if len(pair) < 2:
                    continue
                val = pair[1]
                try:
                    val = int(val)
                except ValueError:
                    try:
                        val = float(val)
                    except ValueError:
                        if val == 'True':
                            val = True
                        elif val == 'False':
                            val = False
                setattr(hps, pair[0], val)
        # the number of camera parameters depends on the architecture tag
        if 'sdn5' in hps.arch:
            npcam = 3
        elif 'sdn6' in hps.arch:
            npcam = 1
        else:
            raise ValueError('cannot infer npcam from arch: %s' % hps.arch)
        c_i = 1.0
        beta1_i = -5.0 / c_i
        beta2_i = 0.0
        gain_params_i = np.full(5, -5.0 / c_i)
        cam_params_i = np.ones([npcam, 5])
        hps.param_inits = (c_i, beta1_i, beta2_i, gain_params_i, cam_params_i)
        return hps
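A hedged usage sketch for NoiseFlowWrapper, assuming a trained model directory laid out as the constructor expects (hps.txt plus ckpt/model.ckpt.best); the directory path, NLF parameters, ISO, and camera index below are placeholders.

# Hypothetical usage; all concrete values below are placeholders.
import numpy as np

nf = NoiseFlowWrapper('./noise_flow_model', patch_size=[256, 256])

# one clean patch, packed raw with 4 channels: for patch_size [256, 256] the
# wrapper expects inputs of shape [batch, 128, 128, 4] in [0, 1]
clean = np.random.uniform(0.0, 0.5, size=[1, 128, 128, 4]).astype(np.float32)

noise = nf.sample_noise_nf(clean, b1=1e-3, b2=1e-5, iso=100.0, cam=0.0)
noisy = np.clip(clean + noise, 0.0, 1.0)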