Ejemplo n.º 1
0
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       graph=None,
                       extra_args=None,
                       trie=None,
                       use_placeholders=False):
  """Build the training graph, its input iterator, and the model.

  Args:
    model_creator: Callable that builds the model. It is invoked with keyword
      arguments only (hparams, iterator, mode, the three vocab tables, scope,
      extra_args, trie).
    hparams: Hyperparameter object supplying the vocab files, the training
      data locations and the iterator settings.
    scope: Forwarded to `model_creator`; also used as the container name,
      falling back to "train" when falsy.
    num_workers: Forwarded to the iterator as `num_shards`.
    jobid: Forwarded to the iterator as `shard_index`.
    graph: Graph to build into. A new `tf.Graph` is created when falsy.
    extra_args: Optional object. When truthy, its `model_device_fn` is handed
      to `tf.device` around model creation. Always forwarded to
      `model_creator`.
    trie: Forwarded unchanged to `model_creator`.
    use_placeholders: When true, source/target sentences (and annotations, if
      `hparams.use_rl` is set) are fed through string placeholders instead of
      being read from the `hparams.train_prefix` files.

  Returns:
    A `TrainModel` with the graph, model, iterator, the skip-count
    placeholder and the src/tgt/annot placeholders. Any placeholder that was
    not created is None.
  """
  src_vocab_path = hparams.src_vocab_file
  tgt_vocab_path = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_table, tgt_table = vocab_utils.create_vocab_tables(
        src_vocab_path, tgt_vocab_path, hparams.share_vocab)
    reverse_tgt_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_path, default_value=vocab_utils.UNK)

    # Everything optional starts out as None and is only filled in by the
    # branch that actually needs it.
    src_placeholder = tgt_placeholder = annot_placeholder = None
    annot_data = ctx_data = None

    if not use_placeholders:
      # File mode: every input is a line-oriented text file.
      prefix = hparams.train_prefix
      src_data = tf.data.TextLineDataset("%s.%s" % (prefix, hparams.src))
      tgt_data = tf.data.TextLineDataset("%s.%s" % (prefix, hparams.tgt))

      if hparams.train_annotations is not None:
        annot_data = tf.data.TextLineDataset(hparams.train_annotations)

      if hparams.ctx is not None:
        ctx_data = tf.data.TextLineDataset("%s.%s" % (prefix, hparams.ctx))
    else:
      # Placeholder mode: the caller feeds 1-D string tensors at run time.
      src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      src_data = tf.data.Dataset.from_tensor_slices(src_placeholder)

      tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      tgt_data = tf.data.Dataset.from_tensor_slices(tgt_placeholder)

      if hparams.use_rl:
        annot_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        annot_data = tf.data.Dataset.from_tensor_slices(annot_placeholder)

    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        # Input datasets.
        src_dataset=src_data,
        tgt_dataset=tgt_data,
        ctx_dataset=ctx_data,
        annot_dataset=annot_data,
        # Vocabulary lookups.
        src_vocab_table=src_table,
        tgt_vocab_table=tgt_table,
        # Batching and sequence settings taken from hparams.
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        # Skipping and sharding.
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    device_fn = extra_args.model_device_fn if extra_args else None
    with tf.device(device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_table,
          target_vocab_table=tgt_table,
          reverse_target_vocab_table=reverse_tgt_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModel(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      annot_placeholder=annot_placeholder,
      skip_count_placeholder=skip_count_placeholder)
Ejemplo n.º 2
0
def create_eval_model(model_creator,
                      hparams,
                      scope=None,
                      graph=None,
                      extra_args=None,
                      trie=None):
  """Build the evaluation graph, file-name placeholders, iterator and model.

  The evaluation data is read from text files whose names are fed at run time
  through scalar string placeholders.

  Args:
    model_creator: Callable that builds the model. It is invoked with keyword
      arguments only (hparams, iterator, mode, the three vocab tables, scope,
      extra_args, trie).
    hparams: Hyperparameter object supplying the vocab files and iterator
      settings. A context-file placeholder is created only when `hparams.ctx`
      is not None, and an annotation-file placeholder only when
      `hparams.dev_annotations` is not None.
    scope: Forwarded to `model_creator`; also used as the container name,
      falling back to "eval" when falsy.
    graph: Graph to build into. A new `tf.Graph` is created when falsy.
    extra_args: Forwarded unchanged to `model_creator`.
    trie: Forwarded unchanged to `model_creator`.

  Returns:
    An `EvalModel` with the graph, model, iterator and the src/tgt/ctx/annot
    file placeholders. The ctx and annot placeholders are None when not
    created.
  """
  src_vocab_path = hparams.src_vocab_file
  tgt_vocab_path = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "eval"):
    src_table, tgt_table = vocab_utils.create_vocab_tables(
        src_vocab_path, tgt_vocab_path, hparams.share_vocab)
    reverse_tgt_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_path, default_value=vocab_utils.UNK)

    # Scalar string placeholders that carry the input file names. The
    # optional ones are only created when the matching hparam is set.
    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    ctx_file_placeholder = (
        tf.placeholder(shape=(), dtype=tf.string)
        if hparams.ctx is not None else None)
    annot_file_placeholder = (
        tf.placeholder(shape=(), dtype=tf.string)
        if hparams.dev_annotations is not None else None)

    # One line-oriented dataset per placeholder that exists.
    src_data = tf.data.TextLineDataset(src_file_placeholder)
    tgt_data = tf.data.TextLineDataset(tgt_file_placeholder)
    ctx_data = (
        None if ctx_file_placeholder is None
        else tf.data.TextLineDataset(ctx_file_placeholder))
    annot_data = (
        None if annot_file_placeholder is None
        else tf.data.TextLineDataset(annot_file_placeholder))

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        # Input datasets.
        src_dataset=src_data,
        tgt_dataset=tgt_data,
        ctx_dataset=ctx_data,
        annot_dataset=annot_data,
        # Vocabulary lookups.
        src_vocab_table=src_table,
        tgt_vocab_table=tgt_table,
        # Batching and sequence settings; lengths use the *_infer limits.
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer)

    model = model_creator(
        hparams=hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        source_vocab_table=src_table,
        target_vocab_table=tgt_table,
        reverse_target_vocab_table=reverse_tgt_table,
        scope=scope,
        extra_args=extra_args,
        trie=trie)

  return EvalModel(
      graph=graph,
      model=model,
      src_file_placeholder=src_file_placeholder,
      tgt_file_placeholder=tgt_file_placeholder,
      ctx_file_placeholder=ctx_file_placeholder,
      annot_file_placeholder=annot_file_placeholder,
      iterator=iterator)
Ejemplo n.º 3
0
def create_infer_model(model_creator,
                       hparams,
                       scope=None,
                       graph=None,
                       extra_args=None,
                       trie=None):
  """Build the inference graph, feed placeholders, iterator and model.

  All inputs are fed at run time through placeholders; nothing is read from
  files here.

  Args:
    model_creator: Callable that builds the model. It is invoked with keyword
      arguments only (hparams, iterator, mode, the three vocab tables, scope,
      extra_args, trie).
    hparams: Hyperparameter object supplying the vocab files and iterator
      settings. A context placeholder is created only when `hparams.ctx` is
      not None, an annotation placeholder only when `hparams.use_rl` is
      truthy, and a trie-exclude placeholder only when `hparams.infer_mode`
      starts with "trie_".
    scope: Forwarded to `model_creator`; also used as the container name,
      falling back to "infer" when falsy.
    graph: Graph to build into. A new `tf.Graph` is created when falsy.
    extra_args: Forwarded unchanged to `model_creator`.
    trie: Forwarded unchanged to `model_creator`.

  Returns:
    An `InferModel` with the graph, model, iterator, the batch-size
    placeholder and the src/ctx/annot/trie-exclude placeholders. Optional
    placeholders that were not created are None.
  """
  src_vocab_path = hparams.src_vocab_file
  tgt_vocab_path = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "infer"):
    src_table, tgt_table = vocab_utils.create_vocab_tables(
        src_vocab_path, tgt_vocab_path, hparams.share_vocab)
    reverse_tgt_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_path, default_value=vocab_utils.UNK)

    # Feed placeholders: 1-D string tensors for the inputs, plus a scalar
    # int64 for the batch size.
    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    ctx_placeholder = (
        tf.placeholder(shape=[None], dtype=tf.string)
        if hparams.ctx is not None else None)
    annot_placeholder = (
        tf.placeholder(shape=[None], dtype=tf.string)
        if hparams.use_rl else None)
    batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

    # One dataset per placeholder that exists.
    src_data = tf.data.Dataset.from_tensor_slices(src_placeholder)
    ctx_data = (
        None if ctx_placeholder is None
        else tf.data.Dataset.from_tensor_slices(ctx_placeholder))
    annot_data = (
        None if annot_placeholder is None
        else tf.data.Dataset.from_tensor_slices(annot_placeholder))

    # The trie-exclude input exists only for the "trie_*" inference modes.
    trie_exclude_placeholder = None
    trie_exclude_data = None
    if hparams.infer_mode.startswith("trie_"):
      trie_exclude_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      trie_exclude_data = tf.data.Dataset.from_tensor_slices(
          trie_exclude_placeholder)

    iterator = iterator_utils.get_infer_iterator(
        hparams=hparams,
        # Input datasets.
        src_dataset=src_data,
        ctx_dataset=ctx_data,
        annot_dataset=annot_data,
        trie_exclude_dataset=trie_exclude_data,
        # Vocabulary lookups.
        src_vocab_table=src_table,
        tgt_vocab_table=tgt_table,
        # Batching and sequence settings.
        batch_size=batch_size_placeholder,
        eos=hparams.eos,
        src_max_len=hparams.src_max_len_infer)

    model = model_creator(
        hparams=hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.INFER,
        source_vocab_table=src_table,
        target_vocab_table=tgt_table,
        reverse_target_vocab_table=reverse_tgt_table,
        scope=scope,
        extra_args=extra_args,
        trie=trie)

  return InferModel(
      graph=graph,
      model=model,
      src_placeholder=src_placeholder,
      annot_placeholder=annot_placeholder,
      trie_exclude_placeholder=trie_exclude_placeholder,
      ctx_placeholder=ctx_placeholder,
      batch_size_placeholder=batch_size_placeholder,
      iterator=iterator)
Ejemplo n.º 4
0
def create_train_model_for_server(model_creator,
                                  hparams,
                                  scope=None,
                                  num_workers=1,
                                  jobid=0,
                                  graph=None,
                                  extra_args=None,
                                  trie=None):
  """Create graph, model, and iterator when running the NMT in server mode.

  This is different from the standard training model, because the input arrives
  via RPC and thus has to be fed using placeholders.

  Args:
    model_creator: Callable that builds the model. It is invoked with keyword
      arguments only (hparams, iterator, mode, the three vocab tables, scope,
      extra_args, trie).
    hparams: Hyperparameter object supplying the vocab files and iterator
      settings. `hparams.num_buckets` must be 1. A context placeholder is
      created only when `hparams.ctx` is not None.
    scope: Forwarded to `model_creator`; also used as the container name,
      falling back to "train" when falsy.
    num_workers: Forwarded to the iterator as `num_shards`.
    jobid: Forwarded to the iterator as `shard_index`.
    graph: Graph to build into. A new `tf.Graph` is created when falsy.
    extra_args: Optional object. When truthy, its `model_device_fn` is handed
      to `tf.device` around model creation. Always forwarded to
      `model_creator`.
    trie: Forwarded unchanged to `model_creator`.

  Returns:
    A `TrainModelForServer` with the graph, model, iterator and the
    src/tgt/wgt, batch-size and skip-count placeholders.

  Raises:
    AssertionError: If `hparams.num_buckets` is not 1.
  """
  # Raised explicitly instead of via an `assert` statement so the check is
  # not stripped when Python runs with -O. The exception type and message are
  # the same as before, so existing callers are unaffected.
  if hparams.num_buckets != 1:
    raise AssertionError("No bucketing when in server mode.")

  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    # Source and target sentences are fed as 1-D string tensors.
    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)

    tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_placeholder)

    # A 1-D float32 tensor handed to the iterator as `wgt_dataset`.
    wgt_placeholder = tf.placeholder(shape=[None], dtype=tf.float32)
    wgt_dataset = tf.data.Dataset.from_tensor_slices(wgt_placeholder)

    # NOTE(review): ctx_placeholder is not among the fields passed to
    # TrainModelForServer below, so callers cannot reach it through the
    # returned value. Confirm how the context input is meant to be fed when
    # hparams.ctx is set.
    ctx_placeholder = None
    ctx_dataset = None
    if hparams.ctx is not None:
      ctx_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      ctx_dataset = tf.data.Dataset.from_tensor_slices(ctx_placeholder)

    # Unlike the file-based training model, the batch size is fed at run time.
    batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        wgt_dataset=wgt_dataset,
        ctx_dataset=ctx_dataset,
        annot_dataset=None,
        batch_size=batch_size_placeholder,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          reverse_target_vocab_table=reverse_tgt_vocab_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModelForServer(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      wgt_placeholder=wgt_placeholder,
      batch_size_placeholder=batch_size_placeholder,
      skip_count_placeholder=skip_count_placeholder)