Example #1
0
        def sample_image(decode_hook_args):
            """Converts decoded predictions into summaries.

            On the first call, loads the normalization statistics ``mean.npz``
            and ``stdev.npz`` from ``hparams.data_dir`` and caches them on the
            enclosing ``self`` as ``mean_npz`` / ``stdev_npz``.

            Args:
              decode_hook_args: object exposing ``hparams`` and
                ``predictions``; ``predictions[0]`` is iterated as a sequence
                of dicts with ``'outputs'`` and ``'targets'`` entries.

            Returns:
              A list of summary values built by ``svg_utils``. In
              ``hparams.just_render`` mode, this is two image summaries
              (``'rendered_outputs'``, ``'rendered_targets'``) per prediction.
              Otherwise it is two text summaries (``'img/sampled'``,
              ``'img/og'``) for each prediction whose targets do not have a
              leading dimension of 1.
            """
            hparams = decode_hook_args.hparams

            if not hasattr(self, 'mean_npz'):
                mean_filename = os.path.join(hparams.data_dir, 'mean.npz')
                stdev_filename = os.path.join(hparams.data_dir, 'stdev.npz')
                # np.load requires file objects opened in binary mode; text
                # mode ('r') hands it decoded str data instead of bytes.
                with tf.gfile.open(mean_filename, 'rb') as f:
                    mean_npz = np.load(f)
                with tf.gfile.open(stdev_filename, 'rb') as f:
                    stdev_npz = np.load(f)
                # Publish both statistics together. If only mean_npz were
                # cached (e.g. the stdev load raised), the hasattr() check
                # above would skip reloading and later reads of
                # self.stdev_npz would fail.
                self.mean_npz = mean_npz
                self.stdev_npz = stdev_npz

            values = []
            for pred_dict in decode_hook_args.predictions[0]:
                if hparams.just_render:
                    # vae mode, outputs is image, just do image summary and continue
                    values.append(
                        svg_utils.make_image_summary(pred_dict['outputs'],
                                                     'rendered_outputs'))
                    values.append(
                        svg_utils.make_image_summary(pred_dict['targets'],
                                                     'rendered_targets'))
                    continue

                # Skip predictions whose targets have a leading dimension of
                # 1. NOTE(review): presumably these are degenerate sequences
                # with nothing to render; confirm.
                if common_layers.shape_list(pred_dict['targets'])[0] == 1:
                    continue

                # undo normalize (via gaussian): x * stdev + mean
                denorm_outputs = (pred_dict['outputs'] *
                                  self.stdev_npz) + self.mean_npz
                denorm_targets = (pred_dict['targets'] *
                                  self.stdev_npz) + self.mean_npz

                # simple cmds are 10 dim (4 one-hot, 6 args).
                # Convert to full SVG spec dimensionality so we can convert it to text.
                denorm_outputs = svg_utils.make_simple_cmds_long(
                    denorm_outputs)
                denorm_targets = svg_utils.make_simple_cmds_long(
                    denorm_targets)

                # sampled text summary
                # NOTE(review): 30 is presumably the per-step width produced by
                # make_simple_cmds_long; confirm.
                output_svg = to_img([np.reshape(denorm_outputs, [-1, 30])])
                values.append(
                    svg_utils.make_text_summary_value(output_svg,
                                                      'img/sampled'))

                # original text summary
                target_svg = to_img([np.reshape(denorm_targets, [-1, 30])])
                values.append(
                    svg_utils.make_text_summary_value(target_svg, 'img/og'))

            return values
Example #2
0
def render(tensor, data_dir):
    """Turns normalized SVG-decoder output into an HTML ``<svg>`` string.

    Args:
      tensor: decoder output in normalized space. Axes 0 and 2 are squeezed
        away below, so both must have size 1. Presumably the shape is
        (1, steps, 1, dims); TODO confirm against callers.
      data_dir: directory passed to ``get_means_stdevs`` to obtain the mean
        and standard deviation used to undo the normalization.

    Returns:
      The markup produced by ``svg_utils.vector_to_svg``, run through
      ``postprocess``, with every ``'256px'`` replaced by ``'50px'``.
    """
    means, stdevs = get_means_stdevs(data_dir)
    # Invert the gaussian normalization: x * stdev + mean.
    denormalized = tensor * stdevs + means

    # Expand the compact command encoding to full SVG spec dimensionality,
    # then drop the two singleton axes before converting to markup.
    long_cmds = svg_utils.make_simple_cmds_long(denormalized)
    squeezed = tf.squeeze(long_cmds, [0, 2])
    svg_markup = svg_utils.vector_to_svg(
        squeezed.numpy(), stop_at_eos=True, categorical=True)

    # Cosmetic clean-up, then shrink the rendered size for display.
    return postprocess(svg_markup).replace('256px', '50px')