Exemplo n.º 1
0
def main(unused_argv):
    """Reconstruct every image in ``FLAGS.test_folder`` and report PSNR.

    Each image is converted to grayscale, normalised with ``cfgs.mean_value``
    and ``cfgs.scale``, reconstructed by ``overlap_inference`` using
    ``FLAGS.stride``, de-normalised, written to ``FLAGS.reconstruction_folder``
    and compared against the grayscale input with ``psnr``. Per-image and mean
    PSNR are printed.

    Args:
        unused_argv: Ignored (presumably the argv list forwarded by the app
            runner -- TODO confirm).

    Raises:
        ValueError: If neither ``FLAGS.model_fname`` nor ``FLAGS.snapshot_dir``
            is set, or if ``FLAGS.snapshot_dir`` contains no checkpoint.
    """
    # load test images
    test_list = list_image(FLAGS.test_folder)

    # Resolve the checkpoint first. An explicit exception is used instead of
    # ``assert`` so the check is not stripped under ``python -O``.
    if FLAGS.snapshot_dir == "" and FLAGS.model_fname == "":
        raise ValueError('No pretrained model specified')
    if FLAGS.model_fname != "":
        snapshot_fname = FLAGS.model_fname
    else:
        snapshot_fname = tf.train.latest_checkpoint(FLAGS.snapshot_dir)
        # latest_checkpoint returns None when no checkpoint is found; fail here
        # with a clear message instead of inside model.restore(None).
        if snapshot_fname is None:
            raise ValueError('No checkpoint found in %s' % FLAGS.snapshot_dir)

    if not test_list:
        # Nothing to evaluate; also avoids a ZeroDivisionError in the mean.
        print('No test images found in %s' % FLAGS.test_folder)
        return

    # load model
    model = Autoencoder(cfgs.patch_size * cfgs.patch_size, cfgs, log_dir=None)
    model.restore(snapshot_fname)
    print('Restored from %s' % snapshot_fname)

    stride = FLAGS.stride
    sum_psnr = 0.0
    for img_fname in test_list:
        orig_img = load_image('%s/%s' % (FLAGS.test_folder, img_fname))
        # pre-process image: grayscale -> float32 -> subtract mean -> scale
        gray_img = toGrayscale(orig_img)
        img = gray_img.astype(np.float32)
        img -= cfgs.mean_value
        img *= cfgs.scale
        # make measurement and reconstruct image
        recon_img = overlap_inference(model,
                                      img,
                                      bs=cfgs.batch_size,
                                      stride=stride)
        # undo the normalisation (inverse order of the pre-processing)
        recon_img /= cfgs.scale
        recon_img += cfgs.mean_value
        # Save the reconstruction. Clip before the uint8 cast: out-of-range
        # floats would otherwise wrap around instead of saturating at 0/255.
        cv.imwrite(
            '%s/%sOI_%d_%s' %
            (FLAGS.reconstruction_folder, FLAGS.prefix, stride, img_fname),
            np.clip(recon_img, 0, 255).astype(np.uint8))
        # PSNR is measured on the unclipped float reconstruction, as before.
        psnr_ = psnr(gray_img.astype(np.float32), recon_img)
        print('Image %s, psnr: %f' % (img_fname, psnr_))
        sum_psnr += psnr_
    mean_psnr = sum_psnr / len(test_list)

    print('---------------------------')
    print('Mean PSNR: %f' % mean_psnr)
Exemplo n.º 2
0
def main(unused_argv):
    """Train the Autoencoder on ``FLAGS.db_fname`` for ``FLAGS.n_epochs`` epochs.

    Each iteration fetches a training batch, runs one ``model.partial_fit``
    step and periodically logs to ``<FLAGS.log_dir>/text_log.txt``. At every
    epoch boundary it:

    * computes a validation loss with ``val``;
    * snapshots the model every ``FLAGS.save_every_n_epochs`` epochs;
    * optionally decays the learning rate (``cfgs.lr_update`` is ``'val'`` or
      ``'step'``).

    Training stops once ``FLAGS.n_epochs`` epochs have completed, and a final
    snapshot is always written.

    Args:
        unused_argv: Ignored (presumably the argv list forwarded by the app
            runner -- TODO confirm).

    Raises:
        ValueError: If ``--output_dir`` is not set.
    """
    import os  # local import: the file's own import block is not visible here

    # Explicit exception instead of ``assert`` so it survives ``python -O``.
    if not FLAGS.output_dir:
        raise ValueError("--output_dir is required")
    # Create training directory.
    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    dl = DataLoader(FLAGS.db_fname, mean=cfgs.mean_value, scale=cfgs.scale, n_vals=FLAGS.n_vals)
    dl.prepare()

    x_dim = dl.get_data_dim()
    model = Autoencoder(x_dim, cfgs, log_dir=FLAGS.log_dir)
    model.quantize_weights()

    # os.path.join works whether or not log_dir ends with a separator.
    txt_log_fname = os.path.join(FLAGS.log_dir, 'text_log.txt')
    val_losses = []
    # ``with`` guarantees the log file is closed even if training raises.
    with open(txt_log_fname, 'w') as log_fout:
        if FLAGS.pretrained_fname:
            log_train(log_fout, 'Resume from %s' % (FLAGS.pretrained_fname))
            try:
                model.restore(FLAGS.pretrained_fname)
            except Exception as err:
                # Best-effort resume: record why and continue training from the
                # current weights (no longer swallows KeyboardInterrupt etc.).
                log_train(log_fout, 'Cannot restore from %s' % (FLAGS.pretrained_fname))
                log_train(log_fout, '%s' % (err,))

        lr = cfgs.initial_lr
        epoch_counter = 0
        ite = 0
        while True:
            start = time.time()
            x, epoch_end = dl.next_batch(cfgs.batch_size, 'train')
            load_data_time = time.time() - start
            if epoch_end:
                epoch_counter += 1

            # epoch_end implies do_log, so a summary is requested at every
            # epoch boundary.
            do_log = (ite % FLAGS.log_every_n_steps == 0) or epoch_end
            do_snapshot = epoch_end and epoch_counter % FLAGS.save_every_n_epochs == 0

            # train one step
            start = time.time()
            loss, _, summary, ite = model.partial_fit(x, lr, do_log)
            one_iter_time = time.time() - start

            # writing outs
            if do_log:
                log_train(log_fout, 'Iteration %d, (lr=%f) training loss  : %f' % (ite, lr, loss))
                if FLAGS.log_time:
                    log_train(log_fout, 'Iteration %d, data loading: %f(s) ; one iteration: %f(s)'
                              % (ite, load_data_time, one_iter_time))
                # Logged once here; the old second model.log(summary) at epoch
                # end re-logged this same summary.
                model.log(summary)

            if not epoch_end:
                continue

            # ---- end-of-epoch work ----
            val_loss = val(model, dl)
            val_losses.append(val_loss)
            log_train(log_fout, '----------------------------------------------------')
            if ite == 0:
                log_train(log_fout, 'Initial validation loss: %f' % (val_loss))
            else:
                log_train(log_fout, 'Epoch %d, validation loss: %f' % (epoch_counter, val_loss))
            log_train(log_fout, '----------------------------------------------------')

            if do_snapshot:
                log_train(log_fout, 'Snapshotting')
                model.save(output_dir)

            if cfgs.lr_update == 'val' and len(val_losses) >= 5 and val_loss >= max(val_losses[-5:-1]):
                # NOTE(review): decays only when the new loss is >= the WORST of
                # the previous four; a "no improvement" rule would compare
                # against min(...). Kept as-is -- confirm intent.
                lr = lr * cfgs.lr_decay_factor
                log_train(log_fout, 'Decay learning rate to %f' % lr)
            elif cfgs.lr_update == 'step' and epoch_counter % cfgs.num_epochs_per_decay == 0:
                lr = lr * cfgs.lr_decay_factor
                log_train(log_fout, 'Decay learning rate to %f' % lr)

            # ``>=`` (not ``==``) so a non-positive n_epochs cannot loop forever.
            if epoch_counter >= FLAGS.n_epochs:
                if not do_snapshot:
                    log_train(log_fout, 'Final snapshotting')
                    model.save(output_dir)
                break