def decode_from_dataset(estimator,
                        problem_names,
                        return_beams=False,
                        beam_size=1,
                        max_predictions=-1,
                        decode_to_file=None,
                        save_images=False,
                        identity_output=False):
  tf.logging.info("Performing local inference from dataset for %s.",
                  str(problem_names))
  hparams = estimator.hparams

  for problem_idx, problem_name in enumerate(problem_names):
    # Build the inference input function
    infer_problems_data = data_reader.get_data_filepatterns(
        problem_name, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)
    infer_input_fn = input_fn_builder.build_input_fn(
        mode=tf.contrib.learn.ModeKeys.INFER,
        hparams=hparams,
        data_file_patterns=infer_problems_data,
        num_datashards=devices.data_parallelism().n,
        fixed_problem=problem_idx)

    # Get the predictions as an iterable
    predictions = estimator.predict(input_fn=infer_input_fn, as_iterable=True)

    # Prepare output file writers if decode_to_file passed
    if decode_to_file:
      output_filepath = decode_to_file + ".outputs." + problem_name
      target_filepath = decode_to_file + ".targets." + problem_name

      output_file = tf.gfile.Open(output_filepath, "w")
      target_file = tf.gfile.Open(target_filepath, "w")

    problem_hparams = hparams.problems[problem_idx]
    inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
    targets_vocab = problem_hparams.vocabulary["targets"]
    for num_predictions, prediction in enumerate(predictions):
      inputs = prediction["inputs"]
      targets = prediction["targets"]
      outputs = prediction["outputs"]

      # Log predictions
      decoded_outputs = []
      if return_beams:
        output_beams = np.split(outputs, beam_size, axis=0)
        for i, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % i)
          decoded = _decode_from_dataset_log_results(
              inputs, targets, beam, problem_name, num_predictions,
              inputs_vocab, targets_vocab, save_images, estimator.model_dir,
              identity_output)
          decoded_outputs.append(decoded)
      else:
        decoded = _decode_from_dataset_log_results(
            inputs, targets, outputs, problem_name, num_predictions,
            inputs_vocab, targets_vocab, save_images, estimator.model_dir,
            identity_output)
        decoded_outputs.append(decoded)

      # Write out predictions if decode_to_file passed
      if decode_to_file:
        for decoded_output, decoded_target in decoded_outputs:
          output_file.write(str(decoded_output) + "\n")
          target_file.write(str(decoded_target) + "\n")

      if max_predictions >= 0 and num_predictions >= max_predictions:
        break

    if decode_to_file:
      output_file.close()
      target_file.close()

    tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable

def decode_from_dataset(estimator):
  hparams = estimator.hparams
  for i, problem in enumerate(FLAGS.problems.split("-")):
    inputs_vocab = hparams.problems[i].vocabulary.get("inputs", None)
    targets_vocab = hparams.problems[i].vocabulary["targets"]
    tf.logging.info("Performing local inference.")
    infer_problems_data = data_reader.get_data_filepatterns(
        FLAGS.problems, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)
    infer_input_fn = input_fn_builder.build_input_fn(
        mode=tf.contrib.learn.ModeKeys.INFER,
        hparams=hparams,
        data_file_patterns=infer_problems_data,
        num_datashards=devices.data_parallelism().n,
        fixed_problem=i)

    def log_fn(inputs,
               targets,
               outputs,
               problem,
               j,
               inputs_vocab=inputs_vocab,
               targets_vocab=targets_vocab):
      """Log inference results."""
      if "image" in problem and FLAGS.decode_save_images:
        save_path = os.path.join(estimator.model_dir,
                                 "%s_prediction_%d.jpg" % (problem, j))
        show_and_save_image(inputs / 255., save_path)
      elif inputs_vocab:
        decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten()))
        tf.logging.info("Inference results INPUT: %s" % decoded_inputs)

      if FLAGS.identity_output:
        decoded_outputs = " ".join(map(str, outputs.flatten()))
        decoded_targets = " ".join(map(str, targets.flatten()))
      else:
        decoded_outputs = targets_vocab.decode(
            _save_until_eos(outputs.flatten()))
        decoded_targets = targets_vocab.decode(
            _save_until_eos(targets.flatten()))

      tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
      tf.logging.info("Inference results TARGET: %s" % decoded_targets)
      if FLAGS.decode_to_file:
        output_filepath = FLAGS.decode_to_file + ".outputs." + problem
        output_file = tf.gfile.Open(output_filepath, "a")
        output_file.write(decoded_outputs + "\n")
        target_filepath = FLAGS.decode_to_file + ".targets." + problem
        target_file = tf.gfile.Open(target_filepath, "a")
        target_file.write(decoded_targets + "\n")

    result_iter = estimator.predict(input_fn=infer_input_fn, as_iterable=True)
    count = 0
    for result in result_iter:
      # predictions from the test input. We use it to log inputs and decodes.
      inputs = result["inputs"]
      targets = result["targets"]
      outputs = result["outputs"]
      if FLAGS.decode_return_beams:
        output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
        for k, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % k)
          log_fn(inputs, targets, beam, problem, count)
      else:
        log_fn(inputs, targets, outputs, problem, count)

      count += 1
      if FLAGS.decode_num_samples != -1 and count >= FLAGS.decode_num_samples:
        break
    tf.logging.info("Completed inference on %d samples." % count)

def generate_test_data(problem, tmp_dir):
  """Generates data for `problem` into `tmp_dir` and returns file patterns."""
  problem.generate_data(tmp_dir, tmp_dir)
  filepatterns = data_reader.get_data_filepatterns(
      problem.name, tmp_dir, tf.estimator.ModeKeys.TRAIN)
  assert tf.gfile.Glob(filepatterns[0])
  return filepatterns

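# A minimal usage sketch (illustrative, not part of the original code): assumes a
# problem registered with the tensor2tensor registry; the problem name and the
# directory are placeholders.
def _example_generate_test_data():
  from tensor2tensor.utils import registry
  problem = registry.problem("algorithmic_reverse_binary40")
  return generate_test_data(problem, tmp_dir="/tmp/t2t_test_data")
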
def get_data_filepatterns(data_dir, mode):
  """Returns data file patterns for the problems listed in FLAGS.problems."""
  return data_reader.get_data_filepatterns(FLAGS.problems, data_dir, mode)

def decode_from_dataset(estimator,
                        problem_names,
                        decode_hp,
                        decode_to_file=None):
  tf.logging.info("Performing local inference from dataset for %s.",
                  str(problem_names))
  hparams = estimator.params

  for problem_idx, problem_name in enumerate(problem_names):
    # Build the inference input function
    infer_problems_data = data_reader.get_data_filepatterns(
        problem_name, hparams.data_dir, tf.estimator.ModeKeys.PREDICT)

    infer_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.PREDICT,
        hparams=hparams,
        data_file_patterns=infer_problems_data,
        num_datashards=devices.data_parallelism().n,
        fixed_problem=problem_idx)

    # Get the predictions as an iterable
    predictions = estimator.predict(infer_input_fn)

    # Prepare output file writers if decode_to_file passed
    if decode_to_file:
      output_filepath = _decode_filename(decode_to_file, problem_name,
                                         decode_hp)
      parts = output_filepath.split(".")
      parts[-1] = "targets"
      target_filepath = ".".join(parts)

      output_file = tf.gfile.Open(output_filepath, "w")
      target_file = tf.gfile.Open(target_filepath, "w")

    problem_hparams = hparams.problems[problem_idx]
    inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
    targets_vocab = problem_hparams.vocabulary["targets"]
    for num_predictions, prediction in enumerate(predictions):
      inputs = prediction["inputs"]
      targets = prediction["targets"]
      outputs = prediction["outputs"]

      # Log predictions
      decoded_outputs = []
      if decode_hp.return_beams:
        output_beams = np.split(outputs, decode_hp.beam_size, axis=0)
        for i, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % i)
          decoded = log_decode_results(
              inputs,
              beam,
              problem_name,
              num_predictions,
              inputs_vocab,
              targets_vocab,
              save_images=decode_hp.save_images,
              model_dir=estimator.model_dir,
              identity_output=decode_hp.identity_output,
              targets=targets)
          decoded_outputs.append(decoded)
      else:
        decoded = log_decode_results(
            inputs,
            outputs,
            problem_name,
            num_predictions,
            inputs_vocab,
            targets_vocab,
            save_images=decode_hp.save_images,
            model_dir=estimator.model_dir,
            identity_output=decode_hp.identity_output,
            targets=targets)
        decoded_outputs.append(decoded)

      # Write out predictions if decode_to_file passed
      if decode_to_file:
        for decoded_output, decoded_target in decoded_outputs:
          output_file.write(str(decoded_output) + "\n")
          target_file.write(str(decoded_target) + "\n")

      if (decode_hp.num_samples >= 0 and
          num_predictions >= decode_hp.num_samples):
        break

    if decode_to_file:
      output_file.close()
      target_file.close()

    tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable

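# A minimal usage sketch (not part of the original code): assumes a tf.estimator
# Estimator built elsewhere and a decode-hparams object exposing the fields read
# above (return_beams, beam_size, save_images, identity_output, num_samples); the
# problem name is illustrative. Passing decode_to_file would additionally write
# the decoded outputs and targets to files.
def _example_decode_from_dataset_usage(estimator):
  decode_hp = tf.contrib.training.HParams(
      return_beams=False,
      beam_size=1,
      save_images=False,
      identity_output=False,
      num_samples=10)
  decode_from_dataset(
      estimator,
      problem_names=["translate_ende_wmt32k"],
      decode_hp=decode_hp)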