コード例 #1
0
ファイル: video_utils.py プロジェクト: xiching/tensor2tensor
def display_video_hooks(hook_args):
    """Hooks to display videos at decode time.

    For every decode in ``hook_args.predictions`` two GIF summaries are built:

    * ``decode_<i>/input``: the input frames followed by the target frames.
    * ``decode_<i>/output``: the input frames followed by the model outputs.

    Both concatenations are along axis 1 (presumably the time axis of a
    (batch, time, ...) array -- TODO confirm against stack_data_given_key).

    Args:
      hook_args: object exposing a ``predictions`` attribute; each element is
        passed to ``video_metrics.stack_data_given_key`` with the keys
        "targets", "outputs" and "inputs".

    Returns:
      list of summary values produced by ``common_video.py_gif_summary``,
      ordered as (input, output) for each decode in turn.
    """
    predictions = hook_args.predictions
    max_outputs = 10
    fps = 10

    all_summaries = []
    for decode_ind, decode in enumerate(predictions):
        target_videos = video_metrics.stack_data_given_key(decode, "targets")
        output_videos = video_metrics.stack_data_given_key(decode, "outputs")
        input_videos = video_metrics.stack_data_given_key(decode, "inputs")
        target_videos = np.asarray(target_videos, dtype=np.uint8)
        output_videos = np.asarray(output_videos, dtype=np.uint8)
        input_videos = np.asarray(input_videos, dtype=np.uint8)

        # Build both panels from the *original* input frames. Overwriting
        # input_videos with the input+target concatenation would make the
        # output panel show input + target + output instead of input + output.
        all_input = np.concatenate((input_videos, target_videos), axis=1)
        all_output = np.concatenate((input_videos, output_videos), axis=1)

        input_summ_vals, _ = common_video.py_gif_summary(
            "decode_%d/input" % decode_ind,
            all_input,
            max_outputs=max_outputs,
            fps=fps,
            return_summary_value=True)
        output_summ_vals, _ = common_video.py_gif_summary(
            "decode_%d/output" % decode_ind,
            all_output,
            max_outputs=max_outputs,
            fps=fps,
            return_summary_value=True)
        all_summaries.extend(input_summ_vals)
        all_summaries.extend(output_summ_vals)
    return all_summaries
コード例 #2
0
def display_video_hooks(hook_args):
    """Builds GIF and per-sample image summaries for each decode.

    For every decode in ``hook_args.predictions``:

    * inputs are framed with a "blue" border and targets/outputs with a "red"
      one, via ``create_border`` and ``decode_hparams.border_percent``;
    * two GIF summaries are emitted at ``decode_hparams.frames_per_second``:
      ``decode_<i>/input`` (inputs then targets) and ``decode_<i>/output``
      (inputs then outputs);
    * for up to 10 samples, one image summary is emitted. In it each panel's
      frames are stacked vertically, and the inputs+targets panel sits to the
      left of the inputs+outputs panel.

    Each sample is unpacked as (t, h, w, c) below, so the stacked arrays are
    expected to be 5-D.

    Args:
      hook_args: object exposing ``predictions`` and ``decode_hparams``.

    Returns:
      list of summary values. For each decode they are ordered as: input GIF
      values, output GIF values, then the per-sample images.
    """
    predictions = hook_args.predictions
    hparams = hook_args.decode_hparams
    fps = hparams.frames_per_second
    border_percent = hparams.border_percent

    summaries = []
    for decode_idx, decode in enumerate(predictions):
        raw = {}
        for key in ("targets", "outputs", "inputs"):
            raw[key] = np.asarray(
                video_metrics.stack_data_given_key(decode, key),
                dtype=np.uint8)

        framed_inputs = create_border(
            raw["inputs"], color="blue", border_percent=border_percent)
        framed_targets = create_border(
            raw["targets"], color="red", border_percent=border_percent)
        framed_outputs = create_border(
            raw["outputs"], color="red", border_percent=border_percent)

        # Both panels start with the same (bordered) input frames.
        with_targets = np.concatenate((framed_inputs, framed_targets), axis=1)
        with_outputs = np.concatenate((framed_inputs, framed_outputs), axis=1)

        # Video gifs.
        for tag_fmt, video in (("decode_%d/input", with_targets),
                               ("decode_%d/output", with_outputs)):
            gif_values, _ = common_video.py_gif_summary(
                tag_fmt % decode_idx,
                video,
                max_outputs=10,
                fps=fps,
                return_summary_value=True)
            summaries.extend(gif_values)

        # Frame-by-frame images for the first 10 samples.
        pairs = zip(with_targets[:10], with_outputs[:10])
        for sample_idx, (left, right) in enumerate(pairs):
            num_frames, height, width, channels = left.shape
            # Tile frames vertically. The right panel is reshaped with the
            # left panel's dimensions, which assumes both have equal length.
            strip_shape = (num_frames * height, width, channels)
            side_by_side = np.concatenate(
                (np.reshape(left, strip_shape), np.reshape(right, strip_shape)),
                axis=1)
            summaries.append(image_utils.image_to_tf_summary_value(
                side_by_side,
                tag="input/output/decode_%d_sample_%d" % (decode_idx,
                                                          sample_idx)))
    return summaries
コード例 #3
0
def display_video_hooks(hook_args):
    """Builds video summaries for best-per-metric samples and early decodes.

    Two groups of summaries are produced:

    1. For every metric returned by
       ``video_metrics.compute_video_metrics_from_predictions``, up to
       ``decode_hparams.max_display_outputs`` samples are collected. Sample
       ``s`` is taken from the decode index that the metric reports for it.
       The samples are rendered with ``convert_videos_to_summaries`` under
       the metric's name as tag.
    2. The first ``decode_hparams.max_display_decodes`` decodes are rendered
       under tags ``decode_<i>``. ``display_ground_truth`` is set only for
       decode 0.

    Args:
      hook_args: object exposing ``predictions`` and ``decode_hparams``.

    Returns:
      list of summaries: all per-metric summaries first, then the per-decode
      ones.
    """
    predictions = hook_args.predictions
    hparams = hook_args.decode_hparams
    max_outputs = hparams.max_display_outputs
    max_decodes = hparams.max_display_decodes

    # Metric computation runs under a fresh default graph, so any TF ops it
    # builds stay out of the caller's default graph.
    with tf.Graph().as_default():
        _, best_decodes = video_metrics.compute_video_metrics_from_predictions(
            predictions, decode_hparams=hparams)

    summaries = []
    gather_keys = ("inputs", "outputs", "targets")

    # Decodes corresponding to the best/worst value of each metric.
    for metric, decode_inds in best_decodes.items():
        gathered = {key: [] for key in gather_keys}
        for sample_ind, decode_ind in enumerate(decode_inds[:max_outputs]):
            chosen = predictions[decode_ind][sample_ind]
            for key in gather_keys:
                gathered[key].append(chosen[key])
        summaries.extend(convert_videos_to_summaries(
            np.array(gathered["inputs"], dtype=np.uint8),
            np.array(gathered["outputs"], dtype=np.uint8),
            np.array(gathered["targets"], dtype=np.uint8),
            tag=metric,
            decode_hparams=hparams))

    # The first max_decodes decodes, shown in full.
    for decode_ind, decode in enumerate(predictions[:max_decodes]):
        target_videos, output_videos, input_videos = [
            np.asarray(video_metrics.stack_data_given_key(decode, key),
                       dtype=np.uint8)
            for key in ("targets", "outputs", "inputs")
        ]
        summaries.extend(convert_videos_to_summaries(
            input_videos,
            output_videos,
            target_videos,
            tag="decode_%d" % decode_ind,
            decode_hparams=hparams,
            display_ground_truth=(decode_ind == 0)))
    return summaries