def _dump_guide_convs(sess, graph, scope, conv1_path, conv2_path):
    """Dump a two-conv NN guide from `scope` as raw float32 binaries.

    Folds the conv1 batch-norm statistics into its weights/bias so the
    exported guide needs no normalization at inference time, then writes
    conv1 (transposed, bias appended as last row) to `conv1_path` and
    conv2 (weights with bias appended) to `conv2_path`.
    """
    conv1_w = graph.get_tensor_by_name('{}/conv1/weights:0'.format(scope))
    conv1_b = graph.get_tensor_by_name(
        '{}/conv1/BatchNorm/beta:0'.format(scope))
    conv1_mu = graph.get_tensor_by_name(
        '{}/conv1/BatchNorm/moving_mean:0'.format(scope))
    conv1_sigma = graph.get_tensor_by_name(
        '{}/conv1/BatchNorm/moving_variance:0'.format(scope))
    conv1_eps = graph.get_tensor_by_name(
        '{}/conv1/BatchNorm/batchnorm/add/y:0'.format(scope))
    conv2_w = graph.get_tensor_by_name('{}/conv2/weights:0'.format(scope))
    conv2_b = graph.get_tensor_by_name('{}/conv2/biases:0'.format(scope))

    (conv1w_, conv1b_, conv1mu_, conv1sigma_, conv1eps_, conv2w_,
     conv2b_) = sess.run([
         conv1_w, conv1_b, conv1_mu, conv1_sigma, conv1_eps, conv2_w, conv2_b
     ])

    # Fold batch-norm: y = (x*w - mu)/sqrt(var + eps) + beta
    #                    = x*(w/s) + (beta - mu/s)   with s = sqrt(var + eps)
    conv1b_ -= conv1mu_ / np.sqrt(conv1sigma_ + conv1eps_)
    conv1w_ = conv1w_ / np.sqrt(conv1sigma_ + conv1eps_)

    conv1w_ = np.squeeze(conv1w_.astype(np.float32))
    conv1b_ = np.squeeze(conv1b_.astype(np.float32))[np.newaxis, :]

    conv2w_ = np.squeeze(conv2w_.astype(np.float32))
    conv2b_ = np.squeeze(conv2b_.astype(np.float32))

    save(np.vstack([conv1w_, conv1b_]).T, conv1_path)
    save(np.append(conv2w_, conv2b_), conv2_path)


def main(args):
    """Freeze a trained HDRNet checkpoint into a deployable graph.

    Loads the latest checkpoint from ``args.checkpoint_dir``, rebuilds the
    evaluation graph, freezes it to ``frozen_graph.pb`` in the same
    directory and, for models with a learned guide, dumps the guide
    parameters as raw float32 ``.bin`` files for on-device use.
    """
    # Read model parameters from the trained graph's collections.
    checkpoint_path = tf.train.latest_checkpoint(args.checkpoint_dir)
    if checkpoint_path is None:
        log.error('Could not find a checkpoint in {}'.format(
            args.checkpoint_dir))
        return
    metapath = ".".join([checkpoint_path, "meta"])
    log.info("Loading {}".format(metapath))
    tf.train.import_meta_graph(metapath)
    with tf.Session() as sess:
        model_params = utils.get_model_params(sess)

    if not hasattr(models, model_params['model_name']):
        log.error("Model {} does not exist".format(model_params['model_name']))
        return
    mdl = getattr(models, model_params['model_name'])

    # Instantiate a fresh evaluation graph (drop the training graph).
    tf.reset_default_graph()
    sz = model_params['net_input_size']

    log.info("Model {}".format(model_params['model_name']))

    # Identify the input and output tensors that delimit the part of the
    # graph to freeze.
    fullres_input = tf.placeholder(tf.float32, (1, None, None, 3),
                                   name='fullres_input')
    input_tensor = tf.placeholder(tf.float32, (1, sz, sz, 3),
                                  name='lowres_input')
    with tf.variable_scope('inference'):
        prediction = mdl.inference(input_tensor,
                                   fullres_input,
                                   model_params,
                                   is_training=False)

    # Export the guide as a flat tensor (separate sub-graphs for Android
    # deployment). Both model families previously ran byte-identical code in
    # an if/else on model_name, so the dead branching was removed.
    # Alternatives kept for reference:
    #   - export 'packed_coefficients' (reshaped/transposed) instead of the
    #     guide, or
    #   - export the whole network for cloud deployment by casting the
    #     clipped prediction to uint8.
    output_tensor = tf.get_collection('guide')[0]
    output_tensor = tf.reshape(output_tensor, [-1], name='guide')
    # BUG FIX: the format string was missing its '{}' placeholder, so the
    # shape was never actually logged.
    log.info("Output shape {}".format(output_tensor.get_shape()))

    saver = tf.train.Saver()

    log.info("Restoring weights from {}".format(checkpoint_path))
    test_graph_name = "test_graph.pbtxt"
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        tf.train.write_graph(sess.graph, args.checkpoint_dir, test_graph_name)

        input_graph_path = os.path.join(args.checkpoint_dir, test_graph_name)
        output_graph_path = os.path.join(args.checkpoint_dir,
                                         "frozen_graph.pb")
        input_saver_def_path = ""
        input_binary = False  # test_graph.pbtxt is a text protobuf
        output_node_names = output_tensor.name.split(":")[0]
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        clear_devices = False

        log.info("Freezing to {}".format(output_graph_path))
        freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                                  input_binary, checkpoint_path,
                                  output_node_names, restore_op_name,
                                  filename_tensor_name, output_graph_path,
                                  clear_devices, "")
        log.info('input tensor: {} {}'.format(input_tensor.name,
                                              input_tensor.shape))
        log.info('output tensor: {} {}'.format(output_tensor.name,
                                               output_tensor.shape))

        # Dump guide parameters for the on-device implementations.
        if model_params['model_name'] == 'HDRNetCurves':
            g = tf.get_default_graph()
            ccm = g.get_tensor_by_name('inference/guide/ccm:0')
            ccm_bias = g.get_tensor_by_name('inference/guide/ccm_bias:0')
            shifts = g.get_tensor_by_name('inference/guide/shifts:0')
            slopes = g.get_tensor_by_name('inference/guide/slopes:0')
            mixing_weights = g.get_tensor_by_name(
                'inference/guide/channel_mixing/weights:0')
            mixing_bias = g.get_tensor_by_name(
                'inference/guide/channel_mixing/biases:0')

            ccm_, ccm_bias_, shifts_, slopes_, mixing_weights_, mixing_bias_ = sess.run(
                [ccm, ccm_bias, shifts, slopes, mixing_weights, mixing_bias])
            shifts_ = np.squeeze(shifts_).astype(np.float32)
            slopes_ = np.squeeze(slopes_).astype(np.float32)
            mix_matrix_dump = np.append(np.squeeze(mixing_weights_),
                                        mixing_bias_[0]).astype(np.float32)
            # Append the bias as a fourth row -> 3x4 color matrix.
            ccm34_ = np.vstack((ccm_, ccm_bias_[np.newaxis, :]))

            save(ccm34_.T,
                 os.path.join(args.checkpoint_dir, 'guide_ccm_f32_3x4.bin'))
            save(
                shifts_.T,
                os.path.join(args.checkpoint_dir, 'guide_shifts_f32_16x3.bin'))
            save(
                slopes_.T,
                os.path.join(args.checkpoint_dir, 'guide_slopes_f32_16x3.bin'))
            save(
                mix_matrix_dump,
                os.path.join(args.checkpoint_dir,
                             'guide_mix_matrix_f32_1x4.bin'))

        elif model_params['model_name'] == "HDRNetGaussianPyrNN":
            g = tf.get_default_graph()
            for lvl in range(3):
                _dump_guide_convs(
                    sess, g, 'inference/guide/level_{}'.format(lvl),
                    os.path.join(args.checkpoint_dir,
                                 'guide_level{}_conv1.bin'.format(lvl)),
                    os.path.join(args.checkpoint_dir,
                                 'guide_level{}_conv2.bin'.format(lvl)))

        # BUG FIX: was `in "HDRNetPointwiseNNGuide"`, a substring test that
        # would match any model name that is a substring of the literal.
        elif model_params['model_name'] == "HDRNetPointwiseNNGuide":
            _dump_guide_convs(
                sess, tf.get_default_graph(), 'inference/guide',
                os.path.join(args.checkpoint_dir, 'guide_conv1.bin'),
                os.path.join(args.checkpoint_dir, 'guide_conv2.bin'))
# ==== Example 2 ====
def main(args):
    """Run a trained HDRNet model on a list of input images.

    Restores the latest checkpoint from ``args.checkpoint_dir``, rebuilds
    the inference graph, and writes one processed PNG per input file to
    ``args.output``. With ``args.debug`` set, intermediate tensors
    (bilateral coefficients, multiscale features, fullres features, guide)
    are also dumped as images.
    """
    setproctitle.setproctitle('hdrnet_run')

    inputs = get_input_list(args.input)

    # -------- Load params ----------------------------------------------------
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        checkpoint_path = tf.train.latest_checkpoint(args.checkpoint_dir)
        if checkpoint_path is None:
            log.error('Could not find a checkpoint in {}'.format(
                args.checkpoint_dir))
            return

        metapath = ".".join([checkpoint_path, "meta"])
        log.info('Loading graph from {}'.format(metapath))
        tf.train.import_meta_graph(metapath)

        model_params = utils.get_model_params(sess)

    # -------- Setup graph ----------------------------------------------------
    if not hasattr(models, model_params['model_name']):
        # BUG FIX: was `params.model_name` — `params` is undefined here, so
        # this error path itself raised a NameError.
        log.error("Model {} does not exist".format(
            model_params['model_name']))
        return
    mdl = getattr(models, model_params['model_name'])

    tf.reset_default_graph()
    net_shape = model_params['net_input_size']
    t_fullres_input = tf.placeholder(tf.float32, (1, None, None, 3))
    t_lowres_input = tf.placeholder(tf.float32, (1, net_shape, net_shape, 3))

    with tf.variable_scope('inference'):
        prediction = mdl.inference(t_lowres_input,
                                   t_fullres_input,
                                   model_params,
                                   is_training=False)
    # Clip to [0, 1] and quantize to 8-bit for saving.
    output = tf.cast(255.0 * tf.squeeze(tf.clip_by_value(prediction, 0, 1)),
                     tf.uint8)
    saver = tf.train.Saver()

    if args.debug:
        # Normalize each debug tensor to [0, 1] around its max magnitude and
        # flatten it into a 2-D image layout for saving.
        coeffs = tf.get_collection('bilateral_coefficients')[0]
        if len(coeffs.get_shape().as_list()) == 6:
            bs, gh, gw, gd, no, ni = coeffs.get_shape().as_list()
            coeffs = tf.transpose(coeffs, [0, 3, 1, 4, 5, 2])
            coeffs = tf.reshape(coeffs, [bs, gh * gd, gw * ni * no, 1])
            coeffs = tf.squeeze(coeffs)
            m = tf.reduce_max(tf.abs(coeffs))
            coeffs = tf.clip_by_value((coeffs + m) / (2 * m), 0, 1)

        ms = tf.get_collection('multiscale')
        if len(ms) > 0:
            for i, m in enumerate(ms):
                maxi = tf.reduce_max(tf.abs(m))
                m = tf.clip_by_value((m + maxi) / (2 * maxi), 0, 1)
                sz = tf.shape(m)
                m = tf.transpose(m, [0, 1, 3, 2])
                m = tf.reshape(m, [sz[0], sz[1], sz[2] * sz[3]])
                ms[i] = tf.squeeze(m)

        fr = tf.get_collection('fullres_features')
        if len(fr) > 0:
            for i, m in enumerate(fr):
                maxi = tf.reduce_max(tf.abs(m))
                m = tf.clip_by_value((m + maxi) / (2 * maxi), 0, 1)
                sz = tf.shape(m)
                m = tf.transpose(m, [0, 1, 3, 2])
                m = tf.reshape(m, [sz[0], sz[1], sz[2] * sz[3]])
                fr[i] = tf.squeeze(m)

        guide = tf.get_collection('guide')
        if len(guide) > 0:
            for i, g in enumerate(guide):
                maxi = tf.reduce_max(tf.abs(g))
                g = tf.clip_by_value((g + maxi) / (2 * maxi), 0, 1)
                guide[i] = tf.squeeze(g)

    with tf.Session(config=config) as sess:
        log.info('Restoring weights from {}'.format(checkpoint_path))
        saver.restore(sess, checkpoint_path)

        for idx, input_path in enumerate(inputs):
            if args.limit is not None and idx >= args.limit:
                log.info("Stopping at limit {}".format(args.limit))
                break

            log.info("Processing {}".format(input_path))
            im_input = cv2.imread(input_path,
                                  -1)  # -1 means read as is, no conversions.
            if im_input.shape[2] == 4:
                log.info("Input {} has 4 channels, dropping alpha".format(
                    input_path))
                im_input = im_input[:, :, :3]

            im_input = np.flip(im_input,
                               2)  # OpenCV reads BGR, convert back to RGB.

            log.info("Max level: {}".format(np.amax(im_input[:, :, 0])))
            log.info("Max level: {}".format(np.amax(im_input[:, :, 1])))
            log.info("Max level: {}".format(np.amax(im_input[:, :, 2])))

            # HACK for HDR+.
            # NOTE(review): both branches currently call img_as_float; the
            # commented 32767-scaling alternatives were abandoned but the
            # branch and log message are kept for traceability.
            if im_input.dtype == np.uint16 and args.hdrp:
                log.info(
                    "Using HDR+ hack for uint16 input. Assuming input white level is 32767."
                )
                # im_input = im_input / 32767.0
                # im_input = im_input / 32767.0 /2
                # im_input = im_input / (1.0*2**16)
                im_input = skimage.img_as_float(im_input)
            else:
                im_input = skimage.img_as_float(im_input)

            # Make or Load lowres image
            if args.lowres_input is None:
                lowres_input = skimage.transform.resize(im_input,
                                                        [net_shape, net_shape],
                                                        order=0)
            else:
                # BUG FIX: was `raise NotImplemented`, which raises a
                # TypeError in Python 3 (NotImplemented is not an exception).
                raise NotImplementedError(
                    'Loading a precomputed lowres input is not implemented')

            fname = os.path.splitext(os.path.basename(input_path))[0]
            output_path = os.path.join(args.output, fname + ".png")
            basedir = os.path.dirname(output_path)

            # Add the batch dimension expected by the placeholders.
            im_input = im_input[np.newaxis, :, :, :]
            lowres_input = lowres_input[np.newaxis, :, :, :]

            feed_dict = {
                t_fullres_input: im_input,
                t_lowres_input: lowres_input
            }

            out_ = sess.run(output, feed_dict=feed_dict)

            if not os.path.exists(basedir):
                os.makedirs(basedir)

            skimage.io.imsave(output_path, out_)

            if args.debug:
                output_path = os.path.join(args.output, fname + "_input.png")
                skimage.io.imsave(output_path, np.squeeze(im_input))

                coeffs_ = sess.run(coeffs, feed_dict=feed_dict)
                output_path = os.path.join(args.output, fname + "_coeffs.png")
                skimage.io.imsave(output_path, coeffs_)
                if len(ms) > 0:
                    ms_ = sess.run(ms, feed_dict=feed_dict)
                    for i, m in enumerate(ms_):
                        output_path = os.path.join(
                            args.output, fname + "_ms_{}.png".format(i))
                        skimage.io.imsave(output_path, m)

                if len(fr) > 0:
                    fr_ = sess.run(fr, feed_dict=feed_dict)
                    for i, m in enumerate(fr_):
                        output_path = os.path.join(
                            args.output, fname + "_fr_{}.png".format(i))
                        skimage.io.imsave(output_path, m)

                if len(guide) > 0:
                    guide_ = sess.run(guide, feed_dict=feed_dict)
                    for i, g in enumerate(guide_):
                        output_path = os.path.join(
                            args.output, fname + "_guide_{}.png".format(i))
                        skimage.io.imsave(output_path, g)
# ==== Example 3 (file: run.py, project: hexfaker/hdrnet) ====
def main(args):
    """Run a parameterized HDRNet model over a CSV-listed evaluation set.

    Each input row provides (input image, ground-truth image, parameter
    vector). The model output, a copy of the normalized input, and the
    ground truth are written next to each other in ``args.output``, and the
    mean L1 loss over the set is printed at the end.
    """
    setproctitle.setproctitle('hdrnet_run')

    param_order, inputs = get_input_list_csv(args.input)

    # -------- Load params ----------------------------------------------------
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # args.checkpoint_path may name either a checkpoint file or a
        # directory containing checkpoints.
        checkpoint_path = args.checkpoint_path
        if os.path.isdir(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        if checkpoint_path is None:
            log.error('Could not find a checkpoint in {}'.format(args.checkpoint_path))
            return

        metapath = ".".join([checkpoint_path, "meta"])
        log.info('Loading graph from {}'.format(metapath))
        tf.train.import_meta_graph(metapath)

        model_params = utils.get_model_params(sess)

    # Checkpoint stores the name as bytes; decode it to a str.
    model_params['model_name'] = dec(model_params['model_name'])

    # -------- Setup graph ----------------------------------------------------
    print(model_params['model_name'])
    if not hasattr(models, model_params['model_name']):
        # BUG FIX: was `model_params.model_name` — attribute access on a
        # dict, which would raise AttributeError on this error path.
        log.error("Model {} does not exist".format(
            model_params['model_name']))
        return
    mdl = getattr(models, model_params['model_name'])

    tf.reset_default_graph()
    net_shape = model_params['net_input_size']
    t_fullres_input = tf.placeholder(tf.float32, (1, None, None, 3))
    t_params_input = tf.placeholder(tf.float32, (1, model_params['lr_params']))

    with tf.variable_scope('inference'):
        prediction = mdl.inference(
            tf_resize(t_fullres_input, [net_shape, net_shape]),
            t_fullres_input, t_params_input, model_params, is_training=False)
    output = tf.squeeze(tf.clip_by_value(prediction, 0, 1))
    saver = tf.train.Saver()

    image_stats, param_stats = load_stats(args.stats)
    print(param_stats)
    print(image_stats)

    if args.debug:
        # Normalize each debug tensor to [0, 1] around its max magnitude and
        # flatten it into a 2-D image layout for saving.
        coeffs = tf.get_collection('bilateral_coefficients')[0]
        if len(coeffs.get_shape().as_list()) == 6:
            bs, gh, gw, gd, no, ni = coeffs.get_shape().as_list()
            coeffs = tf.transpose(coeffs, [0, 3, 1, 4, 5, 2])
            coeffs = tf.reshape(coeffs, [bs, gh * gd, gw * ni * no, 1])
            coeffs = tf.squeeze(coeffs)
            m = tf.reduce_max(tf.abs(coeffs))
            coeffs = tf.clip_by_value((coeffs + m) / (2 * m), 0, 1)

        ms = tf.get_collection('multiscale')
        if len(ms) > 0:
            for i, m in enumerate(ms):
                maxi = tf.reduce_max(tf.abs(m))
                m = tf.clip_by_value((m + maxi) / (2 * maxi), 0, 1)
                sz = tf.shape(m)
                m = tf.transpose(m, [0, 1, 3, 2])
                m = tf.reshape(m, [sz[0], sz[1], sz[2] * sz[3]])
                ms[i] = tf.squeeze(m)

        fr = tf.get_collection('fullres_features')
        if len(fr) > 0:
            for i, m in enumerate(fr):
                maxi = tf.reduce_max(tf.abs(m))
                m = tf.clip_by_value((m + maxi) / (2 * maxi), 0, 1)
                sz = tf.shape(m)
                m = tf.transpose(m, [0, 1, 3, 2])
                m = tf.reshape(m, [sz[0], sz[1], sz[2] * sz[3]])
                fr[i] = tf.squeeze(m)

        guide = tf.get_collection('guide')
        if len(guide) > 0:
            for i, g in enumerate(guide):
                maxi = tf.reduce_max(tf.abs(g))
                g = tf.clip_by_value((g + maxi) / (2 * maxi), 0, 1)
                guide[i] = tf.squeeze(g)

    with tf.Session(config=config) as sess:
        log.info('Restoring weights from {}'.format(checkpoint_path))
        saver.restore(sess, checkpoint_path)

        loss = []
        for idx, (input_path, gt_path, params) in enumerate(tqdm.tqdm(inputs)):
            if args.limit is not None and idx >= args.limit:
                log.info("Stopping at limit {}".format(args.limit))
                break

            input_image = load_image(input_path, args.eval_resolution)
            gt_image = load_image(gt_path, args.eval_resolution)

            params = normalize_params(np.array(params), param_stats, param_order)
            norm_input_image = normalize_image(input_image, image_stats)

            basedir = args.output
            prefix = os.path.splitext(os.path.basename(input_path))[0]
            output_path = os.path.join(basedir, prefix + "_out.jpg")
            gt_copy_path = os.path.join(basedir, prefix + "_gt.jpg")
            input_copy_path = os.path.join(basedir, prefix + "_1n.jpg")  # Not typo. ordering

            # Add the batch dimension expected by the placeholders.
            norm_input_image = norm_input_image[np.newaxis, :, :, :]
            params = np.array(params)[np.newaxis, :]

            feed_dict = {
                t_fullres_input: norm_input_image,
                t_params_input: params
            }

            out_image = sess.run(output, feed_dict=feed_dict)

            if not os.path.exists(basedir):
                os.makedirs(basedir)

            # L1 loss between (unnormalized) ground truth and model output.
            loss.append(np.mean(np.abs(gt_image - out_image)))

            skimage.io.imsave(output_path, save_img(out_image))
            skimage.io.imsave(input_copy_path, save_img(input_image))
            skimage.io.imsave(gt_copy_path, save_img(gt_image))

            if args.debug:
                output_path = os.path.join(args.output, prefix + "_input.png")
                skimage.io.imsave(output_path, np.squeeze(norm_input_image))

                coeffs_ = sess.run(coeffs, feed_dict=feed_dict)
                output_path = os.path.join(args.output, prefix + "_coeffs.png")
                skimage.io.imsave(output_path, coeffs_)
                if len(ms) > 0:
                    ms_ = sess.run(ms, feed_dict=feed_dict)
                    for i, m in enumerate(ms_):
                        output_path = os.path.join(args.output, prefix + "_ms_{}.png".format(i))
                        skimage.io.imsave(output_path, m)

                if len(fr) > 0:
                    fr_ = sess.run(fr, feed_dict=feed_dict)
                    for i, m in enumerate(fr_):
                        output_path = os.path.join(args.output, prefix + "_fr_{}.png".format(i))
                        skimage.io.imsave(output_path, m)

                if len(guide) > 0:
                    guide_ = sess.run(guide, feed_dict=feed_dict)
                    for i, g in enumerate(guide_):
                        output_path = os.path.join(args.output, prefix + "_guide_{}.png".format(i))
                        skimage.io.imsave(output_path, g)

    print("Loss: " + str(np.mean(loss)))
# ==== Example 4 ====
def main(_):
    """Export a trained HDRNet model as a TensorFlow SavedModel.

    Builds a serving graph that accepts serialized tf.Example protos with an
    'image/encoded' feature for both the full-resolution and low-resolution
    inputs, runs inference, and writes a SavedModel with a
    'predict_images' signature under ``FLAGS.output_dir/FLAGS.model_version``.
    """
    with tf.Graph().as_default():
        # Inject placeholder into the graph
        serialized_tf_example = tf.placeholder(tf.string, name='input_image')
        serialized_low_example = tf.placeholder(tf.string, name='low_image')
        #serialized_shape = tf.placeholder(tf.string, name='shape_image')
        feature_configs = {
            'image/encoded': tf.FixedLenFeature(shape=[], dtype=tf.string)
        }
        tf_example = tf.parse_example(serialized_tf_example, feature_configs)
        tf_low_example = tf.parse_example(serialized_low_example,
                                          feature_configs)
        #tf_low_shape = tf.parse_example(serialized_shape, feature_configs)

        jpegs = tf_example['image/encoded']
        low_jpegs = tf_low_example['image/encoded']
        #shape_jpegs = tf_low_shape['image/encoded']

        full_images = tf.map_fn(preprocess_image, jpegs, dtype=tf.float32)
        low_images = tf.map_fn(preprocess_low_image,
                               low_jpegs,
                               dtype=tf.float32)
        #full_images = tf.squeeze(full_images, [0])
        #low_images = tf.squeeze(low_images, [0])

        # now the image shape is (1, ?, ?, 3)

        # Create model
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        # ROBUSTNESS FIX: latest_checkpoint returns None when no checkpoint
        # exists; without this guard the ".".join below raises an opaque
        # TypeError.
        if checkpoint_path is None:
            raise ValueError('Could not find a checkpoint in {}'.format(
                FLAGS.checkpoint_dir))

        metapath = ".".join([checkpoint_path, "meta"])
        tf.train.import_meta_graph(metapath)
        with tf.Session() as sess:
            model_params = utils.get_model_params(sess)
        mdl = getattr(models, model_params['model_name'])

        with tf.variable_scope('inference'):
            prediction = mdl.inference(low_images,
                                       full_images,
                                       model_params,
                                       is_training=False)
        # Clip to [0, 1] and quantize to 8-bit for serving.
        output = tf.cast(
            255.0 * tf.squeeze(tf.clip_by_value(prediction, 0, 1)), tf.uint8)
        #output_img = tf.image.encode_png(tf.image.convert_image_dtype(output[0], dtype=tf.uint8))

        # Create saver to restore from checkpoints
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Restore the model from last checkpoints
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)

            # (re-)create export directory
            export_path = os.path.join(
                tf.compat.as_bytes(FLAGS.output_dir),
                tf.compat.as_bytes(str(FLAGS.model_version)))
            if os.path.exists(export_path):
                shutil.rmtree(export_path)

            # create model builder
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)

            # create tensors info
            predict_tensor_inputs_info = tf.saved_model.utils.build_tensor_info(
                jpegs)
            predict_tensor_low_info = tf.saved_model.utils.build_tensor_info(
                low_jpegs)
            #predict_tensor_shape_info = tf.saved_model.utils.build_tensor_info(shape_jpegs)
            predict_tensor_scores_info = tf.saved_model.utils.build_tensor_info(
                output)

            # build prediction signature
            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        'images': predict_tensor_inputs_info,
                        'low': predict_tensor_low_info
                    },
                    #'shape': predict_tensor_shape_info},
                    outputs={'result': predict_tensor_scores_info},
                    method_name=tf.saved_model.signature_constants.
                    PREDICT_METHOD_NAME))

            # save the model
            #legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={'predict_images': prediction_signature})
            #legacy_init_op=legacy_init_op)

            builder.save()

    print("Successfully exported hdr model version '{}' into '{}'".format(
        FLAGS.model_version, FLAGS.output_dir))
# ==== Example 5 ====
def main(args):
  """Freeze a trained HDRNet checkpoint, exporting the bilateral coefficients.

  Loads the latest checkpoint from ``args.checkpoint_dir``, rebuilds the
  evaluation graph (low-res input only; the same tensor is fed as both the
  low-res and full-res input), freezes it to ``frozen_graph.pb`` and, for
  models with a learned guide, dumps the guide parameters as raw float32
  ``.bin`` files.
  """
  # Read model parameters from the trained graph's collections.
  checkpoint_path = tf.train.latest_checkpoint(args.checkpoint_dir)
  if checkpoint_path is None:
    log.error('Could not find a checkpoint in {}'.format(args.checkpoint_dir))
    return
  metapath = ".".join([checkpoint_path, "meta"])
  log.info("Loading {}".format(metapath))
  tf.train.import_meta_graph(metapath)
  with tf.Session() as sess:
    model_params = utils.get_model_params(sess)

  if not hasattr(models, model_params['model_name']):
    log.error("Model {} does not exist".format(model_params['model_name']))
    return
  mdl = getattr(models, model_params['model_name'])

  # Instantiate a fresh evaluation graph (drop the training graph).
  tf.reset_default_graph()
  sz = model_params['net_input_size']

  log.info("Model {}".format(model_params['model_name']))

  input_tensor = tf.placeholder(tf.float32, [1, sz, sz, 3], name='lowres_input')
  with tf.variable_scope('inference'):
    prediction = mdl.inference(input_tensor, input_tensor, model_params, is_training=False)

  # Export the packed bilateral coefficients, transposed to
  # (out, depth, h, w, in) order. Both model families previously ran
  # byte-identical code in an if/else on model_name, so the dead branching
  # was removed.
  output_tensor = tf.get_collection('packed_coefficients')[0]
  output_tensor = tf.transpose(tf.squeeze(output_tensor), [3, 2, 0, 1, 4], name="output_coefficients")
  # BUG FIX: the format string was missing its '{}' placeholder, so the
  # shape was never actually logged.
  log.info("Output shape {}".format(output_tensor.get_shape()))

  saver = tf.train.Saver()

  log.info("Restoring weights from {}".format(checkpoint_path))
  test_graph_name = "test_graph.pbtxt"
  with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    tf.train.write_graph(sess.graph, args.checkpoint_dir, test_graph_name)

    input_graph_path = os.path.join(args.checkpoint_dir, test_graph_name)
    output_graph_path = os.path.join(args.checkpoint_dir, "frozen_graph.pb")
    input_saver_def_path = ""
    input_binary = False  # test_graph.pbtxt is a text protobuf
    output_node_names = output_tensor.name.split(":")[0]
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False

    log.info("Freezing to {}".format(output_graph_path))
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_graph_path, clear_devices, "")
    log.info('input tensor: {} {}'.format(input_tensor.name, input_tensor.shape))
    log.info('output tensor: {} {}'.format(output_tensor.name, output_tensor.shape))

    # Dump guide parameters for the on-device implementations.
    if model_params['model_name'] == 'HDRNetCurves':
      g = tf.get_default_graph()
      ccm = g.get_tensor_by_name('inference/guide/ccm:0')
      ccm_bias = g.get_tensor_by_name('inference/guide/ccm_bias:0')
      shifts = g.get_tensor_by_name('inference/guide/shifts:0')
      slopes = g.get_tensor_by_name('inference/guide/slopes:0')
      mixing_weights = g.get_tensor_by_name('inference/guide/channel_mixing/weights:0')
      mixing_bias = g.get_tensor_by_name('inference/guide/channel_mixing/biases:0')

      ccm_, ccm_bias_, shifts_, slopes_, mixing_weights_, mixing_bias_ = sess.run(
              [ccm, ccm_bias, shifts, slopes, mixing_weights, mixing_bias])
      shifts_ = np.squeeze(shifts_).astype(np.float32)
      slopes_ = np.squeeze(slopes_).astype(np.float32)
      mix_matrix_dump = np.append(np.squeeze(mixing_weights_), mixing_bias_[0]).astype(np.float32)
      # Append the bias as a fourth row -> 3x4 color matrix.
      ccm34_ = np.vstack((ccm_, ccm_bias_[np.newaxis, :]))

      save(ccm34_.T, os.path.join(args.checkpoint_dir, 'guide_ccm_f32_3x4.bin'))
      save(shifts_.T, os.path.join(args.checkpoint_dir, 'guide_shifts_f32_16x3.bin'))
      save(slopes_.T, os.path.join(args.checkpoint_dir, 'guide_slopes_f32_16x3.bin'))
      save(mix_matrix_dump, os.path.join(args.checkpoint_dir, 'guide_mix_matrix_f32_1x4.bin'))

    elif model_params['model_name'] == "HDRNetGaussianPyrNN":
      g = tf.get_default_graph()
      for lvl in range(3):
        conv1_w = g.get_tensor_by_name('inference/guide/level_{}/conv1/weights:0'.format(lvl))
        conv1_b = g.get_tensor_by_name('inference/guide/level_{}/conv1/BatchNorm/beta:0'.format(lvl))
        conv1_mu = g.get_tensor_by_name('inference/guide/level_{}/conv1/BatchNorm/moving_mean:0'.format(lvl))
        conv1_sigma = g.get_tensor_by_name('inference/guide/level_{}/conv1/BatchNorm/moving_variance:0'.format(lvl))
        conv1_eps = g.get_tensor_by_name('inference/guide/level_{}/conv1/BatchNorm/batchnorm/add/y:0'.format(lvl))
        conv2_w = g.get_tensor_by_name('inference/guide/level_{}/conv2/weights:0'.format(lvl))
        conv2_b = g.get_tensor_by_name('inference/guide/level_{}/conv2/biases:0'.format(lvl))

        conv1w_, conv1b_, conv1mu_, conv1sigma_, conv1eps_, conv2w_, conv2b_ = sess.run(
            [conv1_w, conv1_b, conv1_mu, conv1_sigma, conv1_eps, conv2_w, conv2_b])

        # Fold the batch-norm statistics into conv1's weights and bias.
        conv1b_ -= conv1mu_/np.sqrt((conv1sigma_+conv1eps_))
        conv1w_ = conv1w_/np.sqrt((conv1sigma_+conv1eps_))

        conv1w_ = np.squeeze(conv1w_.astype(np.float32))
        conv1b_ = np.squeeze(conv1b_.astype(np.float32))
        conv1b_ = conv1b_[np.newaxis, :]

        conv2w_ = np.squeeze(conv2w_.astype(np.float32))
        conv2b_ = np.squeeze(conv2b_.astype(np.float32))

        conv2 = np.append(conv2w_, conv2b_)
        conv1 = np.vstack([conv1w_, conv1b_])

        save(conv1.T, os.path.join(args.checkpoint_dir, 'guide_level{}_conv1.bin'.format(lvl)))
        save(conv2, os.path.join(args.checkpoint_dir, 'guide_level{}_conv2.bin'.format(lvl)))

    # BUG FIX: was `in "HDRNetPointwiseNNGuide"`, a substring test that would
    # match any model name that is a substring of the literal.
    elif model_params['model_name'] == "HDRNetPointwiseNNGuide":
      g = tf.get_default_graph()
      conv1_w = g.get_tensor_by_name('inference/guide/conv1/weights:0')
      conv1_b = g.get_tensor_by_name('inference/guide/conv1/BatchNorm/beta:0')
      conv1_mu = g.get_tensor_by_name('inference/guide/conv1/BatchNorm/moving_mean:0')
      conv1_sigma = g.get_tensor_by_name('inference/guide/conv1/BatchNorm/moving_variance:0')
      conv1_eps = g.get_tensor_by_name('inference/guide/conv1/BatchNorm/batchnorm/add/y:0')
      conv2_w = g.get_tensor_by_name('inference/guide/conv2/weights:0')
      conv2_b = g.get_tensor_by_name('inference/guide/conv2/biases:0')

      conv1w_, conv1b_, conv1mu_, conv1sigma_, conv1eps_, conv2w_, conv2b_ = sess.run(
          [conv1_w, conv1_b, conv1_mu, conv1_sigma, conv1_eps, conv2_w, conv2_b])

      # Fold the batch-norm statistics into conv1's weights and bias.
      conv1b_ -= conv1mu_/np.sqrt((conv1sigma_+conv1eps_))
      conv1w_ = conv1w_/np.sqrt((conv1sigma_+conv1eps_))

      conv1w_ = np.squeeze(conv1w_.astype(np.float32))
      conv1b_ = np.squeeze(conv1b_.astype(np.float32))
      conv1b_ = conv1b_[np.newaxis, :]

      conv2w_ = np.squeeze(conv2w_.astype(np.float32))
      conv2b_ = np.squeeze(conv2b_.astype(np.float32))

      conv2 = np.append(conv2w_, conv2b_)
      conv1 = np.vstack([conv1w_, conv1b_])

      save(conv1.T, os.path.join(args.checkpoint_dir, 'guide_conv1.bin'))
      save(conv2, os.path.join(args.checkpoint_dir, 'guide_conv2.bin'))
# ==== Example 6 (file: run.py, project: KeyKy/hdrnet) ====
def main(args):
  """Run a trained HDRNet model over a list of images and save the results.

  Restores the latest checkpoint from ``args.checkpoint_dir``, rebuilds the
  inference graph for the stored model, then enhances every image listed by
  ``args.input`` and writes 8-bit PNG outputs (plus optional debug
  visualizations) under ``args.output``.

  Args:
    args: parsed command-line namespace. Fields used: input, checkpoint_dir,
      output, lowres_input, limit, debug, hdrp.
  """
  setproctitle.setproctitle('hdrnet_run')

  inputs = get_input_list(args.input)

  # -------- Load params ----------------------------------------------------
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    checkpoint_path = tf.train.latest_checkpoint(args.checkpoint_dir)
    if checkpoint_path is None:
      log.error('Could not find a checkpoint in {}'.format(args.checkpoint_dir))
      return

    metapath = ".".join([checkpoint_path, "meta"])
    log.info('Loading graph from {}'.format(metapath))
    tf.train.import_meta_graph(metapath)

    # Model hyper-parameters were serialized into the graph at training time.
    model_params = utils.get_model_params(sess)

  # -------- Setup graph ----------------------------------------------------
  if not hasattr(models, model_params['model_name']):
    # Fixed: previously formatted with the undefined name `params`, which
    # raised a NameError instead of logging this message.
    log.error("Model {} does not exist".format(model_params['model_name']))
    return
  mdl = getattr(models, model_params['model_name'])

  tf.reset_default_graph()
  net_shape = model_params['net_input_size']
  t_fullres_input = tf.placeholder(tf.float32, (1, None, None, 3))
  t_lowres_input = tf.placeholder(tf.float32, (1, net_shape, net_shape, 3))

  with tf.variable_scope('inference'):
    prediction = mdl.inference(
        t_lowres_input, t_fullres_input, model_params, is_training=False)
  # Clamp to [0, 1] and quantize to 8 bits for writing to disk.
  output = tf.cast(255.0*tf.squeeze(tf.clip_by_value(prediction, 0, 1)), tf.uint8)
  saver = tf.train.Saver()

  if args.debug:
    coeffs = tf.get_collection('bilateral_coefficients')[0]
    if len(coeffs.get_shape().as_list()) == 6:
      # Unroll the 6-D coefficient grid into a flat 2-D image so the
      # coefficients can be inspected as a picture.
      bs, gh, gw, gd, no, ni = coeffs.get_shape().as_list()
      coeffs = tf.transpose(coeffs, [0, 3, 1, 4, 5, 2])
      coeffs = tf.reshape(coeffs, [bs, gh*gd, gw*ni*no, 1])
      coeffs = tf.squeeze(coeffs)
      m = tf.reduce_max(tf.abs(coeffs))
      coeffs = tf.clip_by_value((coeffs+m)/(2*m), 0, 1)

    def _flatten_for_viz(tensors):
      # Normalize each tensor to [0, 1] around its max magnitude and fold
      # the channel axis into the width axis, rewriting the list in place.
      for i, t in enumerate(tensors):
        maxi = tf.reduce_max(tf.abs(t))
        t = tf.clip_by_value((t+maxi)/(2*maxi), 0, 1)
        sz = tf.shape(t)
        t = tf.transpose(t, [0, 1, 3, 2])
        t = tf.reshape(t, [sz[0], sz[1], sz[2]*sz[3]])
        tensors[i] = tf.squeeze(t)

    ms = tf.get_collection('multiscale')
    if len(ms) > 0:
      _flatten_for_viz(ms)

    fr = tf.get_collection('fullres_features')
    if len(fr) > 0:
      _flatten_for_viz(fr)

    guide = tf.get_collection('guide')
    if len(guide) > 0:
      for i, g in enumerate(guide):
        maxi = tf.reduce_max(tf.abs(g))
        g = tf.clip_by_value((g+maxi)/(2*maxi), 0, 1)
        guide[i] = tf.squeeze(g)

  with tf.Session(config=config) as sess:
    log.info('Restoring weights from {}'.format(checkpoint_path))
    saver.restore(sess, checkpoint_path)

    for idx, input_path in enumerate(inputs):
      if args.limit is not None and idx >= args.limit:
        log.info("Stopping at limit {}".format(args.limit))
        break

      log.info("Processing {}".format(input_path))
      im_input = cv2.imread(input_path, -1)  # -1 means read as is, no conversions.
      if im_input.shape[2] == 4:
        log.info("Input {} has 4 channels, dropping alpha".format(input_path))
        im_input = im_input[:, :, :3]

      im_input = np.flip(im_input, 2)  # OpenCV reads BGR, convert back to RGB.

      log.info("Max level: {}".format(np.amax(im_input[:, :, 0])))
      log.info("Max level: {}".format(np.amax(im_input[:, :, 1])))
      log.info("Max level: {}".format(np.amax(im_input[:, :, 2])))

      # HACK for HDR+.
      if im_input.dtype == np.uint16 and args.hdrp:
        log.info("Using HDR+ hack for uint16 input. Assuming input white level is 32767.")
        # NOTE(review): this branch currently performs the same conversion as
        # the regular path below; the 32767 white-level rescale is disabled.
        im_input = skimage.img_as_float(im_input)
      else:
        im_input = skimage.img_as_float(im_input)

      # Make (nearest-neighbor downsample) or load the lowres grid input.
      if args.lowres_input is None:
        lowres_input = skimage.transform.resize(
            im_input, [net_shape, net_shape], order=0)
      else:
        # Fixed: `raise NotImplemented` raised a TypeError (NotImplemented is
        # not an exception class); NotImplementedError is the intended signal.
        raise NotImplementedError(
            'Loading a custom lowres input is not implemented')

      fname = os.path.splitext(os.path.basename(input_path))[0]
      output_path = os.path.join(args.output, fname+".png")
      basedir = os.path.dirname(output_path)

      # Add the batch dimension expected by the placeholders.
      im_input = im_input[np.newaxis, :, :, :]
      lowres_input = lowres_input[np.newaxis, :, :, :]

      feed_dict = {
          t_fullres_input: im_input,
          t_lowres_input: lowres_input
      }

      out_ = sess.run(output, feed_dict=feed_dict)

      if not os.path.exists(basedir):
        os.makedirs(basedir)

      skimage.io.imsave(output_path, out_)

      if args.debug:
        output_path = os.path.join(args.output, fname+"_input.png")
        skimage.io.imsave(output_path, np.squeeze(im_input))

        coeffs_ = sess.run(coeffs, feed_dict=feed_dict)
        output_path = os.path.join(args.output, fname+"_coeffs.png")
        skimage.io.imsave(output_path, coeffs_)
        if len(ms) > 0:
          ms_ = sess.run(ms, feed_dict=feed_dict)
          for i, m in enumerate(ms_):
            output_path = os.path.join(args.output, fname+"_ms_{}.png".format(i))
            skimage.io.imsave(output_path, m)

        if len(fr) > 0:
          fr_ = sess.run(fr, feed_dict=feed_dict)
          for i, m in enumerate(fr_):
            output_path = os.path.join(args.output, fname+"_fr_{}.png".format(i))
            skimage.io.imsave(output_path, m)

        if len(guide) > 0:
          guide_ = sess.run(guide, feed_dict=feed_dict)
          for i, g in enumerate(guide_):
            output_path = os.path.join(args.output, fname+"_guide_{}.png".format(i))
            skimage.io.imsave(output_path, g)