예제 #1
0
    def testConvertPredictionsToVideoSummaries(self):
        """Video decode hooks should yield tf.Summary.Value entries."""
        # Deterministic fake frames: 2 input frames, 5 output/target frames.
        rand = np.random.RandomState(0)
        frame_inputs = rand.randint(0, 255, (2, 32, 32, 3))
        frame_outputs = rand.randint(0, 255, (5, 32, 32, 3))
        frame_targets = rand.randint(0, 255, (5, 32, 32, 3))

        # Emulate a batch by repeating one prediction dict five times.
        single_prediction = {
            "inputs": frame_inputs,
            "outputs": frame_outputs,
            "targets": frame_targets,
        }
        predictions = [[single_prediction] * 5]
        decode_hparams = decoding.decode_hparams()

        hook_args = decoding.DecodeHookArgs(estimator=None,
                                            problem=None,
                                            output_dirs=None,
                                            hparams=decode_hparams,
                                            decode_hparams=decode_hparams,
                                            predictions=predictions)

        for value in video_utils.display_video_hooks(hook_args):
            self.assertTrue(isinstance(value, tf.Summary.Value))
예제 #2
0
 def testDecodeInMemoryTrue(self):
   """Summarizes video metrics with in-memory decoding switched on."""
   predictions, problem = self.get_predictions()
   hp = decoding.decode_hparams()
   hp.decode_in_memory = True
   hook_args = decoding.DecodeHookArgs(
       estimator=None, problem=problem, output_dirs=None,
       hparams=hp, decode_hparams=hp, predictions=predictions)
   video_utils.summarize_video_metrics(hook_args)
예제 #3
0
  def testConvertPredictionsToImageSummaries(self):
    """display_decoded_images toggles whether image summaries are emitted."""
    rand = np.random.RandomState(0)
    image = rand.randint(0, 255, (32, 32, 3))
    predictions = [[{"outputs": image, "inputs": image}] * 50]

    decode_hparams = decoding.decode_hparams()
    # With display on: 10 input + 10 output summaries expected; off: none.
    for display_flag, expected_count in ((True, 20), (False, 0)):
      decode_hparams.display_decoded_images = display_flag
      hook_args = decoding.DecodeHookArgs(
          estimator=None, problem=None, output_dirs=None,
          hparams=decode_hparams, decode_hparams=decode_hparams,
          predictions=predictions)
      summaries = image_utils.convert_predictions_to_image_summaries(
          hook_args)
      self.assertEqual(len(summaries), expected_count)
      if summaries:
        self.assertTrue(isinstance(summaries[0], tf.Summary.Value))
def decode_from_dataset(estimator,
                        problem_name,
                        hparams,
                        decode_hp,
                        decode_to_file=None,
                        dataset_split=None):
    """Perform decoding from dataset.

    Runs local inference with `estimator`. When `decode_to_file` (or
    `decode_hp.decode_to_file`) is set, each decoded input sequence is
    written to a ".inputs" file and the collected raw encoder outputs are
    saved to an ".enc_state" .npy file.

    Args:
      estimator: tf.estimator.Estimator used for prediction.
      problem_name: str, problem name used to build output filenames.
      hparams: model hparams; `hparams.problem` supplies the input_fn and
        `hparams.problem_hparams.vocabulary` the decoder vocabulary.
      decode_hp: decode hparams (batch_size, shards, num_samples, delimiter,
        decode_to_file, ...).
      decode_to_file: optional output file prefix; overrides
        `decode_hp.decode_to_file` when given.
      dataset_split: optional dataset split to read predictions from.
    """
    tf.logging.info("Performing local inference from dataset for %s.",
                    str(problem_name))

    shard = decode_hp.shard_id if decode_hp.shards > 1 else None

    output_dir = os.path.join(estimator.model_dir, "decode")
    tf.gfile.MakeDirs(output_dir)

    if decode_hp.batch_size:
        hparams.batch_size = decode_hp.batch_size
        hparams.use_fixed_batch_size = True

    dataset_kwargs = {
        "shard": shard,
        "dataset_split": dataset_split,
        "max_records": decode_hp.num_samples
    }

    problem = hparams.problem
    infer_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.PREDICT, hparams, dataset_kwargs=dataset_kwargs)

    predictions = estimator.predict(infer_input_fn)

    decode_to_file = decode_to_file or decode_hp.decode_to_file
    input_file = None
    encoder_state_file_path = None
    if decode_to_file:
        if decode_hp.shards > 1:
            decode_filename = decode_to_file + ("%.2d" % decode_hp.shard_id)
        else:
            decode_filename = decode_to_file

        output_filepath = decoding._decode_filename(decode_filename,
                                                    problem_name, decode_hp)
        # Derive sibling paths by swapping the extension.
        parts = output_filepath.split(".")
        parts[-1] = "inputs"
        input_filepath = ".".join(parts)
        parts[-1] = "enc_state"
        encoder_state_file_path = ".".join(parts)

        input_file = tf.gfile.Open(input_filepath, "w")

    problem_hparams = hparams.problem_hparams
    # Problems with no "inputs" vocabulary decode via the "targets" one.
    has_input = "inputs" in problem_hparams.vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = problem_hparams.vocabulary[inputs_vocab_key]

    encoder_outputs = []
    num_predictions = 0  # stays 0 when `predictions` is empty
    try:
        for num_predictions, prediction in enumerate(predictions, 1):
            inputs = prediction["inputs"]
            encoder_outputs.append(prediction["encoder_outputs"])
            decoded_input = inputs_vocab.decode(
                decoding._save_until_eos(inputs, False))

            # Bug fix: write each decoded input exactly once. The previous
            # version re-wrote the whole accumulated list on every iteration,
            # duplicating earlier entries in the file and doing O(n^2) work.
            if input_file is not None:
                input_file.write("{}:\t{}".format(
                    num_predictions - 1,
                    str(decoded_input) + decode_hp.delimiter))

            # num_samples < 0 means "decode everything".
            if 0 <= decode_hp.num_samples <= num_predictions:
                break
    finally:
        # Close the output file even if decoding raises.
        if input_file is not None:
            input_file.close()

    # Save the encoder states once after the loop instead of rewriting the
    # whole .npy file on every prediction as before; final content is the same.
    if encoder_state_file_path is not None:
        np.save(encoder_state_file_path, np.array(encoder_outputs))

    # NOTE(review): `decorun_postdecode_hooks` and the `output_dir=` keyword
    # look suspicious — elsewhere in this file DecodeHookArgs is built with
    # `output_dirs=` and a `predictions=` argument. Left as-is because the
    # project's `decoding` module is not visible here; confirm against it.
    decoding.decorun_postdecode_hooks(
        decoding.DecodeHookArgs(estimator=estimator,
                                problem=problem,
                                output_dir=output_dir,
                                hparams=hparams,
                                decode_hparams=decode_hp))

    tf.logging.info("Completed inference on %d samples." % num_predictions)