Ejemplo n.º 1
0
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Build the training graph with its model and input iterators.

    Everything is created inside a fresh `tf.Graph`, within the resource
    container named by `scope` (or "train" when `scope` is falsy).

    Args:
        model_creator: Callable invoked with `hparams` plus the keyword
            arguments `iterator`, `handle`, `mode`, `vocab_table`, `scope`,
            `extra_args` and `reverse_vocab_table`; its result is the model.
        hparams: Hyper-parameter object. Attributes read here: `vocab_file`,
            `train_data`, `train_kb`, `batch_size`, `t1`, `t2`, `eod`,
            `len_action`, `random_seed`, `num_buckets`, `max_dialogue_len`.
        scope: Optional name used for the container and forwarded to
            `model_creator`.
        num_workers: Forwarded to the iterator builder as `num_shards`.
        jobid: Forwarded to the iterator builder as `shard_index`.
        extra_args: Optional object. When truthy, its `model_device_fn`
            attribute is handed to `tf.device` around model construction.

    Returns:
        A `TrainModel` carrying the graph, the model, the handle-driven
        placeholder iterator and its handle placeholder, the concrete
        training iterator, and the skip-count placeholder.
    """
    train_graph = tf.Graph()
    container_name = scope or "train"

    with train_graph.as_default(), tf.container(container_name):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        dialogue_lines = tf.data.TextLineDataset(hparams.train_data)
        kb_lines = tf.data.TextLineDataset(hparams.train_kb)
        skip_count_ph = tf.placeholder(shape=(), dtype=tf.int64)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # The concrete iterator that reads the training files.
        iterator_options = dict(
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len,
            skip_count=skip_count_ph,
            num_shards=num_workers,
            shard_index=jobid,
        )
        train_iterator = iterator_utils.get_iterator(
            dialogue_lines, kb_lines, vocab_table, **iterator_options)

        # A handle-driven iterator with the same structure as the training
        # iterator. Feeding a different string handle lets one switch
        # between training and evaluation inputs.
        handle_ph = tf.placeholder(tf.string, shape=[])
        feedable_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            train_iterator.output_types,
            train_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(
            feedable_iterator)

        # Device placement is only customised when extra_args is supplied.
        device_fn = extra_args.model_device_fn if extra_args else None
        with tf.device(device_fn):
            model = model_creator(
                hparams,
                iterator=batched_iterator,
                handle=handle_ph,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                vocab_table=vocab_table,
                scope=scope,
                extra_args=extra_args,
                reverse_vocab_table=reverse_vocab_table)

    return TrainModel(
        graph=train_graph,
        model=model,
        placeholder_iterator=feedable_iterator,
        train_iterator=train_iterator,
        placeholder_handle=handle_ph,
        skip_count_placeholder=skip_count_ph)
Ejemplo n.º 2
0
def create_selfplay_model(model_creator,
                          is_mutable,
                          num_workers,
                          jobid,
                          hparams,
                          scope=None,
                          extra_args=None):
    """Build the self-play graph with its model and input iterators.

    Everything is created inside a fresh `tf.Graph`, within the resource
    container named by `scope` (or "selfplay" when `scope` is falsy).

    Args:
        model_creator: Callable invoked with `hparams` plus the keyword
            arguments `iterator`, `handle`, `mode`, `vocab_table`,
            `reverse_vocab_table`, `scope` and `extra_args`; its result is
            the model.
        is_mutable: When truthy the model is built with
            `dialogue_utils.mode_self_play_mutable`, otherwise with
            `dialogue_utils.mode_self_play_immutable`.
        num_workers: Forwarded to `self_play_iterator_creator`.
        jobid: Forwarded to `self_play_iterator_creator`.
        hparams: Hyper-parameter object; `vocab_file` is read here and the
            whole object is forwarded to the iterator and model creators.
        scope: Optional name used for the container and forwarded to
            `model_creator`.
        extra_args: Optional object. When truthy, its `model_device_fn`
            attribute is handed to `tf.device` around model construction.

    Returns:
        A `SelfplayModel` carrying the graph, the model, the handle-driven
        placeholder iterator and its handle placeholder, the three concrete
        iterators (train, full-text self-play, structured self-play), and
        the data, kb, skip-count and batch-size placeholders.
    """
    selfplay_graph = tf.Graph()
    with selfplay_graph.as_default(), tf.container(scope or "selfplay"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # The creator returns a 3-tuple of iterators and a 4-tuple of
        # placeholders; unpack both in one step.
        ((train_iterator, fulltext_iterator, structured_iterator),
         (data_ph, kb_ph, batch_size_ph, skip_count_ph)) = (
             self_play_iterator_creator(hparams, num_workers, jobid))

        # Handle-driven iterator sharing the training iterator's structure.
        handle_ph = tf.placeholder(tf.string, shape=[])
        feedable_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            train_iterator.output_types,
            train_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(
            feedable_iterator)

        # Device placement is only customised when extra_args is supplied.
        device_fn = extra_args.model_device_fn if extra_args else None
        with tf.device(device_fn):
            mutable_mode = dialogue_utils.mode_self_play_mutable
            immutable_mode = dialogue_utils.mode_self_play_immutable
            selfplay_mode = mutable_mode if is_mutable else immutable_mode
            model = model_creator(
                hparams,
                iterator=batched_iterator,
                handle=handle_ph,
                mode=selfplay_mode,
                vocab_table=vocab_table,
                reverse_vocab_table=reverse_vocab_table,
                scope=scope,
                extra_args=extra_args)

    return SelfplayModel(
        graph=selfplay_graph,
        model=model,
        placeholder_iterator=feedable_iterator,
        placeholder_handle=handle_ph,
        train_iterator=train_iterator,
        self_play_ft_iterator=fulltext_iterator,
        self_play_st_iterator=structured_iterator,
        data_placeholder=data_ph,
        kb_placeholder=kb_ph,
        skip_count_placeholder=skip_count_ph,
        batch_size_placeholder=batch_size_ph)
Ejemplo n.º 3
0
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Build the inference graph with its model and input iterators.

    Everything is created inside a fresh `tf.Graph`, within the resource
    container named by `scope` (or "infer" when `scope` is falsy). Input
    comes from string placeholders rather than from files.

    Args:
        model_creator: Callable invoked with `hparams` plus the keyword
            arguments `iterator`, `handle`, `mode`, `vocab_table`,
            `reverse_vocab_table`, `scope` and `extra_args`; its result is
            the model.
        hparams: Hyper-parameter object. Attributes read here: `vocab_file`,
            `eod`, `len_action`.
        scope: Optional name used for the container and forwarded to
            `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `InferModel` carrying the graph, the model, the handle-driven
        placeholder iterator and its handle placeholder, the concrete
        inference iterator, and the source-data, kb and batch-size
        placeholders.
    """
    infer_graph = tf.Graph()
    container_name = scope or "infer"

    with infer_graph.as_default(), tf.container(container_name):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # Feedable inputs: two 1-D string tensors and a scalar batch size.
        src_ph = tf.placeholder(shape=[None], dtype=tf.string, name="src_ph")
        kb_ph = tf.placeholder(shape=[None], dtype=tf.string, name="kb_ph")
        batch_size_ph = tf.placeholder(shape=[], dtype=tf.int64, name="bs_ph")

        src_dataset = tf.data.Dataset.from_tensor_slices(src_ph)
        kb_dataset = tf.data.Dataset.from_tensor_slices(kb_ph)

        # The concrete iterator that consumes the fed tensors.
        infer_iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            kb_dataset,
            vocab_table,
            batch_size=batch_size_ph,
            eod=hparams.eod,
            len_action=hparams.len_action)

        # Handle-driven iterator sharing the inference iterator's structure.
        handle_ph = tf.placeholder(tf.string, shape=[])
        feedable_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            infer_iterator.output_types,
            infer_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(
            feedable_iterator)

        model = model_creator(
            hparams,
            iterator=batched_iterator,
            handle=handle_ph,
            mode=tf.contrib.learn.ModeKeys.INFER,
            vocab_table=vocab_table,
            reverse_vocab_table=reverse_vocab_table,
            scope=scope,
            extra_args=extra_args)

    return InferModel(
        graph=infer_graph,
        model=model,
        placeholder_iterator=feedable_iterator,
        placeholder_handle=handle_ph,
        infer_iterator=infer_iterator,
        data_src_placeholder=src_ph,
        kb_placeholder=kb_ph,
        batch_size_placeholder=batch_size_ph)
Ejemplo n.º 4
0
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Build the evaluation graph with its model and input iterators.

    Everything is created inside a fresh `tf.Graph`, within the resource
    container named by `scope` (or "eval" when `scope` is falsy). The data
    and kb file names are supplied at run time through string placeholders.

    Args:
        model_creator: Callable invoked with `hparams` plus the keyword
            arguments `iterator`, `handle`, `mode`, `vocab_table`, `scope`
            and `extra_args`; its result is the model. No reverse vocab
            table is passed in this mode.
        hparams: Hyper-parameter object. Attributes read here: `vocab_file`,
            `batch_size`, `t1`, `t2`, `eod`, `len_action`, `random_seed`,
            `num_buckets`, `max_dialogue_len`.
        scope: Optional name used for the container and forwarded to
            `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `EvalModel` carrying the graph, the model, the handle-driven
        placeholder iterator and its handle placeholder, the concrete
        evaluation iterator, and the data-file and kb-file placeholders.
    """
    vocab_path = hparams.vocab_file
    eval_graph = tf.Graph()
    container_name = scope or "eval"

    with eval_graph.as_default(), tf.container(container_name):
        vocab_table = vocab_utils.create_vocab_tables(vocab_path)

        # Scalar string placeholders naming the files to evaluate on.
        data_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        kb_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        dialogue_lines = tf.data.TextLineDataset(data_file_ph)
        kb_lines = tf.data.TextLineDataset(kb_file_ph)

        # The concrete iterator that reads the evaluation files.
        iterator_options = dict(
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len,
        )
        eval_iterator = iterator_utils.get_iterator(
            dialogue_lines, kb_lines, vocab_table, **iterator_options)

        # Handle-driven iterator sharing the evaluation iterator's structure.
        handle_ph = tf.placeholder(tf.string, shape=[])
        feedable_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            eval_iterator.output_types,
            eval_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(
            feedable_iterator)

        model = model_creator(
            hparams,
            iterator=batched_iterator,
            handle=handle_ph,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            vocab_table=vocab_table,
            scope=scope,
            extra_args=extra_args)

    return EvalModel(
        graph=eval_graph,
        model=model,
        placeholder_iterator=feedable_iterator,
        placeholder_handle=handle_ph,
        eval_iterator=eval_iterator,
        data_file_placeholder=data_file_ph,
        kb_file_placeholder=kb_file_ph)