def train_eval_and_decode(self):
  """Does eval and decode after training every eval_freq_in_steps."""
  eval_steps = self._hparams.eval_freq_in_steps
  packed_dataset = "_packed" in self._hparams.problem.name
  for i in range(0, self._train_spec.max_steps, eval_steps):
    if packed_dataset and i > 0:
      # Decoding below switched to the unpacked problem; switch back to the
      # "_packed" variant before resuming training.
      problem = registry.problem(self._hparams.problem.name + "_packed")
      p_hparams = problem.get_hparams(self._hparams)
      self._hparams.problem = problem
      self._hparams.problem_hparams = p_hparams
    self._estimator.train(
        self._train_spec.input_fn,
        steps=eval_steps,
        hooks=self._train_spec.hooks)
    self._set_eval_dir_name("eval")
    self._estimator.evaluate(
        self._eval_spec.input_fn,
        steps=self._eval_spec.steps,
        hooks=self._eval_spec.hooks,
        name="eval")
    if packed_dataset:
      # Decode on the unpacked problem: packed examples are for training only.
      problem = registry.problem(
          self._hparams.problem.name.replace("_packed", ""))
      p_hparams = problem.get_hparams(self._hparams)
      self._hparams.problem = problem
      self._hparams.problem_hparams = p_hparams
    self.decode(dataset_split=tf.estimator.ModeKeys.EVAL)
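# A minimal, framework-free sketch of the chunked loop above: instead of one
# long estimator.train() call, training advances eval_freq steps at a time,
# with an evaluation and a decode after each chunk. train_chunk / evaluate /
# decode are hypothetical stand-ins, not tensor2tensor APIs.
def train_eval_loop(max_steps, eval_freq, train_chunk, evaluate, decode):
  for _ in range(0, max_steps, eval_freq):
    train_chunk(steps=eval_freq)  # resumes from the latest checkpoint
    evaluate()                    # periodic eval on the dev split
    decode()                      # full decoding pass, e.g. for BLEU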
def generate_data():
  # Generate data if requested.
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  tf.gfile.MakeDirs(data_dir)
  tf.gfile.MakeDirs(tmp_dir)
  problem_name = FLAGS.problem
  tf.logging.info("Generating data for %s" % problem_name)
  registry.problem(problem_name).generate_data(data_dir, tmp_dir)
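# A flag-free equivalent of the helper above, for use from other Python code;
# the problem name and directories are illustrative.
registry.problem("translate_ende_wmt32k").generate_data(
    os.path.expanduser("~/t2t_data"), os.path.expanduser("~/t2t_tmp"))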
def generate_data_for_registered_problem(problem_name):
  """Generate data for a registered problem."""
  tf.logging.info("Generating data for %s.", problem_name)
  if FLAGS.num_shards:
    raise ValueError("--num_shards should not be set for registered Problem.")
  problem = registry.problem(problem_name)
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  if task_id is None and problem.multiprocess_generate:
    # Fan generation out over a pool of worker processes, one task_id each.
    if FLAGS.task_id_start != -1:
      assert FLAGS.task_id_end != -1
      task_id_start = FLAGS.task_id_start
      task_id_end = FLAGS.task_id_end
    else:
      task_id_start = 0
      task_id_end = problem.num_generate_tasks
    pool = multiprocessing.Pool(processes=FLAGS.num_concurrent_processes)
    problem.prepare_to_generate(data_dir, tmp_dir)
    args = [(problem_name, data_dir, tmp_dir, task_id)
            for task_id in range(task_id_start, task_id_end)]
    pool.map(generate_data_in_process, args)
  else:
    # Single-process generation (optionally restricted to one task_id).
    problem.generate_data(data_dir, tmp_dir, task_id)
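# A self-contained sketch of the multiprocessing fan-out used above: each
# worker receives one (problem_name, data_dir, tmp_dir, task_id) tuple and
# generates only its shard. _worker stands in for generate_data_in_process.
import multiprocessing

def _worker(arg):
  name, data_dir, tmp_dir, task_id = arg
  print("generating task %d of %s into %s" % (task_id, name, data_dir))

if __name__ == "__main__":
  args = [("my_problem", "/tmp/data", "/tmp/tmp", i) for i in range(4)]
  pool = multiprocessing.Pool(processes=2)
  pool.map(_worker, args)  # blocks until every task_id is done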
def score_file(filename):
  """Score each line in a file and return the scores."""
  # Prepare model.
  hparams = create_hparams()
  encoders = registry.problem(FLAGS.problem).feature_encoders(FLAGS.data_dir)
  has_inputs = "inputs" in encoders

  # Prepare features for feeding into the model.
  if has_inputs:
    inputs_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
    batch_inputs = tf.reshape(inputs_ph, [1, -1, 1, 1])  # Make it 4D.
  targets_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
  batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])  # Make it 4D.
  if has_inputs:
    features = {"inputs": batch_inputs, "targets": batch_targets}
  else:
    features = {"targets": batch_targets}

  # Prepare the model and the graph when model runs on features.
  model = registry.model(FLAGS.model)(hparams, tf.estimator.ModeKeys.EVAL)
  _, losses = model(features)
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # Load weights from checkpoint.
    if FLAGS.checkpoint_path is None:
      ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
      ckpt = ckpts.model_checkpoint_path
    else:
      ckpt = FLAGS.checkpoint_path
    saver.restore(sess, ckpt)
    # Run on each line.
    with tf.gfile.Open(filename) as f:
      lines = f.readlines()
    results = []
    for line in lines:
      tab_split = line.split("\t")
      if len(tab_split) > 2:
        raise ValueError("Each line must have at most one tab separator.")
      if len(tab_split) == 1:
        targets = tab_split[0].strip()
      else:
        targets = tab_split[1].strip()
        inputs = tab_split[0].strip()
      # Run encoders and append EOS symbol.
      targets_numpy = encoders["targets"].encode(
          targets) + [text_encoder.EOS_ID]
      if has_inputs:
        inputs_numpy = encoders["inputs"].encode(
            inputs) + [text_encoder.EOS_ID]
      # Prepare the feed.
      if has_inputs:
        feed = {inputs_ph: inputs_numpy, targets_ph: targets_numpy}
      else:
        feed = {targets_ph: targets_numpy}
      # Get the score.
      np_loss = sess.run(losses["training"], feed)
      results.append(np_loss)
  return results
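# A usage sketch for score_file, assuming the usual t2t_decoder-style flags
# (FLAGS.problem, FLAGS.model, FLAGS.data_dir, FLAGS.output_dir or
# FLAGS.checkpoint_path) are already parsed. Each input line is either
# "targets" alone or "inputs<TAB>targets".
with tf.gfile.Open("/tmp/to_score.txt", "w") as f:
  f.write("Hello world\tHallo Welt\n")
  f.write("Goodbye\tAuf Wiedersehen\n")
for line_no, loss in enumerate(score_file("/tmp/to_score.txt")):
  tf.logging.info("line %d: log-loss %.4f", line_no, loss)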
def add_problem_hparams(hparams, problem_name_or_instance):
  """Add problem hparams for the problems."""
  if isinstance(problem_name_or_instance, problem_lib.Problem):
    problem = problem_name_or_instance
  else:
    problem = registry.problem(problem_name_or_instance)
  p_hparams = problem.get_hparams(hparams)
  hparams.problem = problem
  hparams.problem_hparams = p_hparams
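# A usage sketch: the helper accepts either a registered problem name or an
# already-constructed Problem instance, and afterwards the hparams carry both
# the problem object and its per-problem hyperparameters. The problem name is
# illustrative, and create_hparams is reused from the scoring snippet above.
hparams = create_hparams()
add_problem_hparams(hparams, "translate_ende_wmt32k")
tf.logging.info("attached problem: %s", hparams.problem.name)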
def main(_):
  problem = registry.problem(FLAGS.problem)
  # We make the assumption that the problem is a subclass of Text2TextProblem.
  assert isinstance(problem, text_problems.Text2TextProblem)
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  tf.gfile.MakeDirs(data_dir)
  tf.gfile.MakeDirs(tmp_dir)
  tf.logging.info("Saving vocabulary to data_dir: %s" % data_dir)
  problem.get_or_create_vocab(data_dir, tmp_dir)
  tf.logging.info("Saved vocabulary file: " +
                  os.path.join(data_dir, problem.vocab_filename))
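# A programmatic sketch of what this entry point does, bypassing FLAGS; the
# problem name and directories are illustrative, and the problem must be a
# Text2TextProblem for get_or_create_vocab to exist.
prob = registry.problem("translate_ende_wmt32k")
encoder = prob.get_or_create_vocab(
    os.path.expanduser("~/t2t_data"), os.path.expanduser("~/t2t_tmp"))
tf.logging.info("vocab size: %d", encoder.vocab_size)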
def generate_data_in_process(arg):
  problem_name, data_dir, tmp_dir, task_id = arg
  problem = registry.problem(problem_name)
  problem.generate_data(data_dir, tmp_dir, task_id)
def problem(name):
  return registry.problem(name)