Esempio n. 1
0
def decode_from_dataset(estimator,
                        problem_names,
                        return_beams=False,
                        beam_size=1,
                        max_predictions=-1,
                        decode_to_file=None,
                        save_images=False,
                        identity_output=False):
    """Run local inference on the INFER data of each problem and log results.

    For every problem in `problem_names`, builds an INFER input function,
    iterates `estimator.predict` over it, logs each decoded prediction via
    `_decode_from_dataset_log_results`, and optionally writes the decoded
    outputs and targets to files.

    Args:
      estimator: estimator exposing `hparams`, `model_dir` and `predict`.
      problem_names: sequence of problem names; the position of each name is
        used as its index into `hparams.problems`.
      return_beams: if True, split each output into `beam_size` pieces along
        axis 0 and log/write every beam separately.
      beam_size: number of beams to split outputs into when `return_beams`.
      max_predictions: maximum number of predictions decoded per problem;
        a negative value means no limit.
      decode_to_file: optional base path. When set, decoded outputs go to
        "<base>.outputs.<problem>" and decoded targets to
        "<base>.targets.<problem>" (both truncated on open).
      save_images: forwarded to `_decode_from_dataset_log_results`.
      identity_output: forwarded to `_decode_from_dataset_log_results`.
    """
    import itertools  # function-scope import keeps this block self-contained

    tf.logging.info("Performing local inference from dataset for %s.",
                    str(problem_names))
    hparams = estimator.hparams

    for problem_idx, problem_name in enumerate(problem_names):
        # Build the inference input function
        infer_problems_data = data_reader.get_data_filepatterns(
            problem_name, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)

        infer_input_fn = input_fn_builder.build_input_fn(
            mode=tf.contrib.learn.ModeKeys.INFER,
            hparams=hparams,
            data_file_patterns=infer_problems_data,
            num_datashards=devices.data_parallelism().n,
            fixed_problem=problem_idx)

        # Get the predictions as an iterable
        predictions = estimator.predict(input_fn=infer_input_fn,
                                        as_iterable=True)
        if max_predictions >= 0:
            # Decode exactly `max_predictions` items; islice does not pull an
            # extra item from the prediction iterator once the limit is hit.
            predictions = itertools.islice(predictions, max_predictions)

        problem_hparams = hparams.problems[problem_idx]
        inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
        targets_vocab = problem_hparams.vocabulary["targets"]

        output_file = None
        target_file = None
        # Explicit counter: stays defined (0) even if no predictions arrive.
        num_predictions = 0
        try:
            # Prepare output file writers if decode_to_file passed
            if decode_to_file:
                output_file = tf.gfile.Open(
                    decode_to_file + ".outputs." + problem_name, "w")
                target_file = tf.gfile.Open(
                    decode_to_file + ".targets." + problem_name, "w")

            for prediction_idx, prediction in enumerate(predictions):
                inputs = prediction["inputs"]
                targets = prediction["targets"]
                outputs = prediction["outputs"]

                # Log predictions; each logged result is unpacked below as a
                # (decoded_output, decoded_target) pair.
                decoded_outputs = []
                if return_beams:
                    output_beams = np.split(outputs, beam_size, axis=0)
                    for i, beam in enumerate(output_beams):
                        tf.logging.info("BEAM %d:" % i)
                        decoded = _decode_from_dataset_log_results(
                            inputs, targets, beam, problem_name, prediction_idx,
                            inputs_vocab, targets_vocab, save_images,
                            estimator.model_dir, identity_output)
                        decoded_outputs.append(decoded)
                else:
                    decoded = _decode_from_dataset_log_results(
                        inputs, targets, outputs, problem_name, prediction_idx,
                        inputs_vocab, targets_vocab, save_images,
                        estimator.model_dir, identity_output)
                    decoded_outputs.append(decoded)

                # Write out predictions if decode_to_file passed
                if decode_to_file:
                    for decoded_output, decoded_target in decoded_outputs:
                        output_file.write(str(decoded_output) + "\n")
                        target_file.write(str(decoded_target) + "\n")

                num_predictions += 1
        finally:
            # Close the writers even if decoding fails part-way through.
            for handle in (output_file, target_file):
                if handle is not None:
                    handle.close()

        tf.logging.info("Completed inference on %d samples." % num_predictions)
Esempio n. 2
0
def decode_from_dataset(estimator):
    """Run local inference for every problem in FLAGS.problems and log results.

    Problems are taken from `FLAGS.problems` split on "-"; the i-th name is
    paired with `hparams.problems[i]`. Decoding is controlled by the
    `FLAGS.decode_*` flags and `FLAGS.identity_output`. When
    `FLAGS.decode_to_file` is set, each decoded output/target line is
    appended to "<base>.outputs.<problem>" / "<base>.targets.<problem>".

    Args:
      estimator: estimator exposing `hparams`, `model_dir` and `predict`.
    """
    hparams = estimator.hparams
    for i, problem in enumerate(FLAGS.problems.split("-")):
        inputs_vocab = hparams.problems[i].vocabulary.get("inputs", None)
        targets_vocab = hparams.problems[i].vocabulary["targets"]
        tf.logging.info("Performing local inference.")
        infer_problems_data = data_reader.get_data_filepatterns(
            FLAGS.problems, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)

        infer_input_fn = input_fn_builder.build_input_fn(
            mode=tf.contrib.learn.ModeKeys.INFER,
            hparams=hparams,
            data_file_patterns=infer_problems_data,
            num_datashards=devices.data_parallelism().n,
            fixed_problem=i)

        # Vocabularies are bound as default arguments so each problem's
        # log_fn keeps its own vocab rather than late-binding the loop variable.
        def log_fn(inputs,
                   targets,
                   outputs,
                   problem,
                   j,
                   inputs_vocab=inputs_vocab,
                   targets_vocab=targets_vocab):
            """Log inference results."""
            if "image" in problem and FLAGS.decode_save_images:
                save_path = os.path.join(estimator.model_dir,
                                         "%s_prediction_%d.jpg" % (problem, j))
                show_and_save_image(inputs / 255., save_path)
            elif inputs_vocab:
                decoded_inputs = inputs_vocab.decode(
                    _save_until_eos(inputs.flatten()))
                tf.logging.info("Inference results INPUT: %s" % decoded_inputs)

            if FLAGS.identity_output:
                decoded_outputs = " ".join(map(str, outputs.flatten()))
                decoded_targets = " ".join(map(str, targets.flatten()))
            else:
                decoded_outputs = targets_vocab.decode(
                    _save_until_eos(outputs.flatten()))
                decoded_targets = targets_vocab.decode(
                    _save_until_eos(targets.flatten()))

            tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
            tf.logging.info("Inference results TARGET: %s" % decoded_targets)
            if FLAGS.decode_to_file:
                # Append one line per call; the `with` blocks close (and so
                # flush) each handle instead of leaking it on every call.
                output_filepath = FLAGS.decode_to_file + ".outputs." + problem
                with tf.gfile.Open(output_filepath, "a") as output_file:
                    output_file.write(decoded_outputs + "\n")
                target_filepath = FLAGS.decode_to_file + ".targets." + problem
                with tf.gfile.Open(target_filepath, "a") as target_file:
                    target_file.write(decoded_targets + "\n")

        result_iter = estimator.predict(input_fn=infer_input_fn,
                                        as_iterable=True)
        count = 0
        for result in result_iter:
            # predictions from the test input. We use it to log inputs and decodes.
            inputs = result["inputs"]
            targets = result["targets"]
            outputs = result["outputs"]
            if FLAGS.decode_return_beams:
                output_beams = np.split(outputs,
                                        FLAGS.decode_beam_size,
                                        axis=0)
                for k, beam in enumerate(output_beams):
                    tf.logging.info("BEAM %d:" % k)
                    log_fn(inputs, targets, beam, problem, count)
            else:
                log_fn(inputs, targets, outputs, problem, count)

            count += 1
            if FLAGS.decode_num_samples != -1 and count >= FLAGS.decode_num_samples:
                break
        tf.logging.info("Completed inference on %d samples." % count)
Esempio n. 3
0
def generate_test_data(problem, tmp_dir):
    """Generate `problem`'s data into `tmp_dir` and return its TRAIN patterns.

    `tmp_dir` is passed as both arguments of `problem.generate_data`. The
    first file pattern returned by `data_reader.get_data_filepatterns` must
    match at least one file (checked with an assert); all patterns are then
    returned.
    """
    problem.generate_data(tmp_dir, tmp_dir)
    train_mode = tf.estimator.ModeKeys.TRAIN
    patterns = data_reader.get_data_filepatterns(problem.name, tmp_dir,
                                                 train_mode)
    # Sanity check that generation actually produced files.
    assert tf.gfile.Glob(patterns[0])
    return patterns
Esempio n. 4
0
def get_data_filepatterns(data_dir, mode):
  """Return data file patterns in `data_dir` for all problems in FLAGS.problems.

  Thin wrapper that supplies `FLAGS.problems` to
  `data_reader.get_data_filepatterns`.
  """
  problems = FLAGS.problems
  return data_reader.get_data_filepatterns(problems, data_dir, mode)
Esempio n. 5
0
def decode_from_dataset(estimator,
                        problem_names,
                        decode_hp,
                        decode_to_file=None):
    """Run local inference on the PREDICT data of each problem and log results.

    For every problem in `problem_names`, builds a PREDICT input function,
    iterates `estimator.predict` over it, logs each decoded prediction via
    `log_decode_results`, and optionally writes decoded outputs and targets
    to files.

    Args:
      estimator: estimator exposing `params`, `model_dir` and `predict`.
      problem_names: sequence of problem names; the position of each name is
        used as its index into `hparams.problems`.
      decode_hp: decoding hyperparameters; this function reads
        `return_beams`, `beam_size`, `num_samples` (negative means no limit),
        `save_images` and `identity_output`.
      decode_to_file: optional base path. When set, decoded outputs are
        written to the path built by `_decode_filename`, and decoded targets
        to the same path with its last "."-separated component replaced by
        "targets" (both files truncated on open).
    """
    import itertools  # function-scope import keeps this block self-contained

    tf.logging.info("Performing local inference from dataset for %s.",
                    str(problem_names))
    hparams = estimator.params

    for problem_idx, problem_name in enumerate(problem_names):
        # Build the inference input function
        infer_problems_data = data_reader.get_data_filepatterns(
            problem_name, hparams.data_dir, tf.estimator.ModeKeys.PREDICT)

        infer_input_fn = input_fn_builder.build_input_fn(
            mode=tf.estimator.ModeKeys.PREDICT,
            hparams=hparams,
            data_file_patterns=infer_problems_data,
            num_datashards=devices.data_parallelism().n,
            fixed_problem=problem_idx)

        # Get the predictions as an iterable
        predictions = estimator.predict(infer_input_fn)
        if decode_hp.num_samples >= 0:
            # Decode exactly `num_samples` items; islice does not pull an
            # extra item from the prediction iterator once the limit is hit.
            predictions = itertools.islice(predictions, decode_hp.num_samples)

        problem_hparams = hparams.problems[problem_idx]
        inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
        targets_vocab = problem_hparams.vocabulary["targets"]

        output_file = None
        target_file = None
        # Explicit counter: stays defined (0) even if no predictions arrive.
        num_predictions = 0
        try:
            # Prepare output file writers if decode_to_file passed
            if decode_to_file:
                output_filepath = _decode_filename(decode_to_file, problem_name,
                                                   decode_hp)
                parts = output_filepath.split(".")
                parts[-1] = "targets"
                target_filepath = ".".join(parts)

                output_file = tf.gfile.Open(output_filepath, "w")
                target_file = tf.gfile.Open(target_filepath, "w")

            for prediction_idx, prediction in enumerate(predictions):
                inputs = prediction["inputs"]
                targets = prediction["targets"]
                outputs = prediction["outputs"]

                # Log predictions; each logged result is unpacked below as a
                # (decoded_output, decoded_target) pair.
                decoded_outputs = []
                if decode_hp.return_beams:
                    output_beams = np.split(outputs, decode_hp.beam_size, axis=0)
                    for i, beam in enumerate(output_beams):
                        tf.logging.info("BEAM %d:" % i)
                        decoded = log_decode_results(
                            inputs,
                            beam,
                            problem_name,
                            prediction_idx,
                            inputs_vocab,
                            targets_vocab,
                            save_images=decode_hp.save_images,
                            model_dir=estimator.model_dir,
                            identity_output=decode_hp.identity_output,
                            targets=targets)
                        decoded_outputs.append(decoded)
                else:
                    decoded = log_decode_results(
                        inputs,
                        outputs,
                        problem_name,
                        prediction_idx,
                        inputs_vocab,
                        targets_vocab,
                        save_images=decode_hp.save_images,
                        model_dir=estimator.model_dir,
                        identity_output=decode_hp.identity_output,
                        targets=targets)
                    decoded_outputs.append(decoded)

                # Write out predictions if decode_to_file passed
                if decode_to_file:
                    for decoded_output, decoded_target in decoded_outputs:
                        output_file.write(str(decoded_output) + "\n")
                        target_file.write(str(decoded_target) + "\n")

                num_predictions += 1
        finally:
            # Close the writers even if decoding fails part-way through.
            for handle in (output_file, target_file):
                if handle is not None:
                    handle.close()

        tf.logging.info("Completed inference on %d samples." % num_predictions)