Example #1
    def _input_fn(params):
        """Input function."""
        del params

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        else:
            if hparams.mode == "translate":
                src_file = hparams.translate_file + ".tok"
                tgt_file = hparams.translate_file + ".tok"
            else:
                src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
                tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        src_vocab_file = hparams.src_vocab_file
        tgt_vocab_file = hparams.tgt_vocab_file
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            # Run one epoch and stop if running train_and_eval.
            if hparams.mode == "train_and_eval":
                # In this mode the input pipeline is restarted every epoch, so choose a
                # different random_seed.
                num_repeat = 1
                random_seed = hparams.random_seed + int(time.time()) % 100
            else:
                num_repeat = 8
                random_seed = hparams.random_seed
            return iterator_utils.get_iterator(
                src_dataset,
                tgt_dataset,
                src_vocab_table,
                tgt_vocab_table,
                batch_size=hparams.batch_size,
                sos=hparams.sos,
                eos=hparams.eos,
                random_seed=random_seed,
                num_buckets=hparams.num_buckets,
                src_max_len=hparams.src_max_len,
                tgt_max_len=hparams.tgt_max_len,
                output_buffer_size=None,
                skip_count=None,
                num_shards=1,  # flags.num_workers
                shard_index=0,  # flags.jobid
                reshuffle_each_iteration=True,
                use_char_encode=hparams.use_char_encode,
                num_repeat=num_repeat,
                filter_oversized_sequences=True
            )  # need to update get_effective_train_epoch_size() if this flag flips.
        else:
            return iterator_utils.get_infer_iterator(
                src_dataset,
                src_vocab_table,
                batch_size=hparams.infer_batch_size,
                eos=hparams.eos,
                src_max_len=hparams.src_max_len,
                use_char_encode=hparams.use_char_encode)
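A note on how an `_input_fn` closure like the one above is typically consumed: it is handed to an Estimator, which calls it with `params` and trains on whatever dataset or iterator it returns. The sketch below is a hypothetical, self-contained illustration of that wiring; the toy `model_fn`, the file name, and the hyperparameter values are assumptions, not part of the example.

# Hypothetical sketch of wiring an input_fn to a TF 1.x Estimator.
import tensorflow as tf

def make_input_fn(src_file, batch_size):
    def _input_fn(params):
        del params  # a TPUEstimator would supply the per-core batch size here
        dataset = tf.data.TextLineDataset(src_file)
        # Toy feature: number of whitespace-separated tokens per line.
        dataset = dataset.map(
            lambda line: {"length": tf.size(tf.string_split([line]).values)})
        return dataset.repeat().batch(batch_size)
    return _input_fn

def model_fn(features, labels, mode, params):
    del labels, params
    bias = tf.get_variable("bias", [], dtype=tf.float32)
    loss = tf.reduce_mean(tf.cast(features["length"], tf.float32) + bias)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

# estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/toy_nmt")
# estimator.train(input_fn=make_input_fn("train.en", 32), max_steps=10)  # needs train.en to exist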
Example #2
def _createModel(mode, hparam, modelFunc=None):
    '''
    Based on hparam (train_src, train_tgt), build the datasets and return a
    TrainModel. TrainModel is a wrapper around everything related to the
    network's input.
    :param mode:
    :param hparam:
    :param modelFunc:
    :return:
    '''
    def _get_config_proto():
        conf = tf.ConfigProto(allow_soft_placement=True,
                              log_device_placement=False)
        return conf

    graph = tf.Graph()
    with graph.as_default():
        skipCount = None
        if mode == 'train':
            src_dataset = tf.data.TextLineDataset(hparam.train_src)
            tgt_dataset = tf.data.TextLineDataset(hparam.train_tgt)
            skipCount = tf.placeholder(tf.int64, shape=())
        elif mode == 'eval':
            src_dataset = tf.data.TextLineDataset(hparam.dev_src)
            tgt_dataset = tf.data.TextLineDataset(hparam.dev_tgt)
        else:
            raise ValueError('_createTrainModel.mode must be train or eval')

        batch_input = get_iterator(src_dataset, tgt_dataset, skipCount, hparam)
        sess = tf.Session(config=_get_config_proto())
        model = modelFunc(batch_input, mode, hparam)

        if mode == 'train':
            return TrainModel(model, sess, graph, batch_input)
        elif mode == 'eval':
            return EvalModel(model, sess, graph, batch_input)
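The TrainModel / EvalModel objects returned above are just lightweight containers around the graph, session, model, and batched input. A plausible definition is sketched below; it is hypothetical and the original project may define them differently.

# Hypothetical container definitions for the wrappers used above.
import collections

TrainModel = collections.namedtuple("TrainModel", ("model", "sess", "graph", "batch_input"))
EvalModel = collections.namedtuple("EvalModel", ("model", "sess", "graph", "batch_input"))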
Example #3
    def testGraphVariablesDevMode(self):

        with self.graph.as_default():

            iterator, total_num = get_iterator('DEV',
                                               filesobj=files,
                                               buffer_size=hparams.buffer_size,
                                               num_epochs=hparams.num_epochs,
                                               batch_size=hparams.batch_size,
                                               debug_mode=True)

            _ = BuildEvalModel(hparams,
                               iterator,
                               tf.get_default_graph())

            var_names = [var.name for var in tf.trainable_variables()]

        self.assertAllEqual(sorted(var_names), sorted(list(expected_variables.keys())),
                            'variables are not compatible dev mode')

        with self.graph.as_default():

            for var in tf.trainable_variables():

                self.assertAllEqual(tuple(var.get_shape().as_list()),
                                    expected_variables[var.name],
                                    'missed shapes at {} dev mode'.format(var.name))
Example #4
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)

        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  source_vocab_table=src_vocab_table,
                                  target_vocab_table=tgt_vocab_table,
                                  scope=scope,
                                  extra_args=extra_args)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
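The skip_count_placeholder returned with the TrainModel above is typically fed when (re)initializing the iterator, so a run resumed mid-epoch can skip the examples it already consumed. A hedged usage sketch follows; everything except the fields returned above is an assumption.

# Illustrative resume logic using the TrainModel fields returned above.
train_model = create_train_model(model_creator, hparams)
with train_model.graph.as_default():
    init_ops = [tf.tables_initializer(), tf.global_variables_initializer()]
sess = tf.Session(graph=train_model.graph)
sess.run(init_ops)
sess.run(train_model.iterator.initializer,
         feed_dict={train_model.skip_count_placeholder: 0})  # skip N already-consumed examples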
Example #5
    def _input_fn(params):
        """Input function."""
        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        else:
            src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
            tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        src_vocab_file = hparams.src_vocab_file
        tgt_vocab_file = hparams.tgt_vocab_file
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            if "context" in params:
                batch_size = params["batch_size"]
                num_hosts = params["context"].num_hosts
                # TODO(dehao): update to use current_host once available in API.
                current_host = params["context"].current_input_fn_deployment(
                )[1]
            else:
                num_hosts = 1
                current_host = 0
                batch_size = hparams.batch_size
            mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE,
                                  value=batch_size)
            mlperf_log.gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN,
                                  value=hparams.src_max_len)
            return iterator_utils.get_iterator(
                src_dataset,
                tgt_dataset,
                src_vocab_table,
                tgt_vocab_table,
                batch_size=batch_size,
                sos=hparams.sos,
                eos=hparams.eos,
                random_seed=hparams.random_seed,
                num_buckets=hparams.num_buckets,
                src_max_len=hparams.src_max_len,
                tgt_max_len=hparams.tgt_max_len,
                output_buffer_size=None,
                skip_count=None,
                num_shards=num_hosts,
                shard_index=current_host,
                reshuffle_each_iteration=True,
                use_char_encode=hparams.use_char_encode,
                filter_oversized_sequences=True)
        else:
            return iterator_utils.get_infer_iterator(
                src_dataset,
                src_vocab_table,
                batch_size=hparams.infer_batch_size,
                eos=hparams.eos,
                src_max_len=hparams.src_max_len,
                use_char_encode=hparams.use_char_encode)
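In the TPU path above, num_shards / shard_index split the input across hosts so that each host reads a disjoint slice of the data. Conceptually this maps onto tf.data's shard transformation; the snippet below only illustrates that idea and is not the project's iterator_utils code (file name and host values are made up).

import tensorflow as tf

num_hosts, current_host = 4, 1                    # hypothetical values
dataset = tf.data.TextLineDataset("train.en")     # hypothetical file
dataset = dataset.shard(num_hosts, current_host)  # each host keeps every num_hosts-th line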
Example #6
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Create graph, model and iterator for training."""
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        data_dataset = tf.data.TextLineDataset(hparams.train_data)
        kb_dataset = tf.data.TextLineDataset(hparams.train_kb)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)
        # this is the actual train_iterator
        train_iterator = iterator_utils.get_iterator(
            data_dataset,
            kb_dataset,
            vocab_table,
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # this is the placeholder iterator. One can use this placeholder iterator
        # to switch between training and evaluation.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_iterator.output_types, train_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(iterator)
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(hparams,
                                  iterator=batched_iterator,
                                  handle=handle,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  vocab_table=vocab_table,
                                  scope=scope,
                                  extra_args=extra_args,
                                  reverse_vocab_table=reverse_vocab_table)
    return TrainModel(graph=graph,
                      model=model,
                      placeholder_iterator=iterator,
                      train_iterator=train_iterator,
                      placeholder_handle=handle,
                      skip_count_placeholder=skip_count_placeholder)
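The "placeholder iterator" built with tf.data.Iterator.from_string_handle above is fed a concrete iterator's string handle at run time; that is what lets one graph switch between training and evaluation inputs. A hedged sketch of that hand-off follows (the session setup and the commented train_op are illustrative, not from the original code).

# Illustrative use of the feedable-iterator handle returned in TrainModel above.
train_model = create_train_model(model_creator, hparams)
with tf.Session(graph=train_model.graph) as sess:
    train_handle = sess.run(train_model.train_iterator.string_handle())
    sess.run(train_model.train_iterator.initializer,
             feed_dict={train_model.skip_count_placeholder: 0})
    # Any op that reads from the placeholder iterator now needs the handle fed:
    # sess.run(train_op, feed_dict={train_model.placeholder_handle: train_handle})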
Example #7
    def testGetIteratorWithShard(self):
        tf.set_random_seed(1)
        tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
            tf.constant(["a", "b", "c", "eos", "sos"]))
        src_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["c c a", "f e a g", "d", "c a"]))
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["a b", "c c", "", "b c"]))
        hparams = tf.contrib.training.HParams(random_seed=3,
                                              num_buckets=5,
                                              eos="eos",
                                              sos="sos")
        batch_size = 2
        src_max_len = 3
        dataset = iterator_utils.get_iterator(src_dataset=src_dataset,
                                              tgt_dataset=tgt_dataset,
                                              src_vocab_table=src_vocab_table,
                                              tgt_vocab_table=tgt_vocab_table,
                                              batch_size=batch_size,
                                              sos=hparams.sos,
                                              eos=hparams.eos,
                                              random_seed=hparams.random_seed,
                                              num_buckets=hparams.num_buckets,
                                              src_max_len=src_max_len,
                                              num_shards=2,
                                              shard_index=1,
                                              reshuffle_each_iteration=False)
        table_initializer = tf.tables_initializer()
        iterator = dataset.make_initializable_iterator()
        get_next = iterator.get_next()
        with self.test_session() as sess:
            sess.run(table_initializer)
            sess.run(iterator.initializer)
            features = sess.run(get_next)
            self.assertAllEqual(
                [
                    [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
                    [2, 0, 3]
                ],  # c a eos -- eos is padding
                features["source"])
            self.assertAllEqual([3, 2], features["source_sequence_length"])
            self.assertAllEqual(
                [
                    [4, 2, 2],  # sos c c
                    [4, 1, 2]
                ],  # sos b c
                features["target_input"])
            self.assertAllEqual(
                [
                    [2, 2, 3],  # c c eos
                    [1, 2, 3]
                ],  # b c eos
                features["target_output"])
            self.assertAllEqual([3, 3], features["target_sequence_length"])

            with self.assertRaisesOpError("End of sequence"):
                sess.run(get_next)
Example #8
    def testTrainInput(self):

        try1 = BatchedInput(source=np.array([[3, 4, 5, 10, 5, 5, 5, 5, 5, 5],
                                             [6, 7, 8, 6, 9, 11, 2, 2, 2, 2]]),
                            target_in=np.array([[1, 3, 4, 5, 10, 2, 2],
                                                [1, 6, 11, 7, 8, 9, 12]]),
                            target_out=np.array([[3, 4, 5, 10, 2, 2, 2],
                                                 [6, 11, 7, 8, 9, 12, 2]]),
                            source_size=np.array([10, 6]),
                            target_size=np.array([5, 7]),
                            initializer=None,
                            sos_token_id=1,
                            eos_token_id=2)

        try2 = BatchedInput(source=np.array([[11, 10, 4, 8]]),
                            target_in=np.array([[1, 10, 11, 12, 7, 8]]),
                            target_out=np.array([[10, 11, 12, 7, 8, 2]]),
                            source_size=np.array([4]),
                            target_size=np.array([6]),
                            initializer=None,
                            sos_token_id=1,
                            eos_token_id=2)

        with self.graph.as_default():

            iterator, total_num = iterator_utils.get_iterator(
                'TRAIN',
                filesobj=self.files,
                buffer_size=None,
                num_epochs=1,
                batch_size=2,
                debug_mode=True)

            table_init_op = tf.tables_initializer()

        self.sess.run(table_init_op)

        self.sess.run(iterator.initializer)

        res1 = self.sess.run([
            iterator.source, iterator.target_in, iterator.target_out,
            iterator.source_size, iterator.target_size, iterator.sos_token_id,
            iterator.eos_token_id
        ])

        self._check_all_equal(res1, try1, 1)

        res2 = self.sess.run([
            iterator.source, iterator.target_in, iterator.target_out,
            iterator.source_size, iterator.target_size, iterator.sos_token_id,
            iterator.eos_token_id
        ])

        self._check_all_equal(res2, try2, 2)

        self.assertAllEqual(total_num, 3, 'mismatch in total num somehow')
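The ragged-to-dense padding visible in the expected arrays above (id 2 filling the tails of shorter rows) is what tf.data's padded_batch produces. The snippet below is a minimal, hypothetical illustration of that mechanism, not the project's iterator_utils; the EOS_ID value is an assumption taken from the arrays above.

import tensorflow as tf

EOS_ID = 2  # assumption: id 2 doubles as the padding token, as in the arrays above

def gen():
    yield [3, 4, 5, 10]
    yield [6, 11, 7, 8, 9, 12]

dataset = tf.data.Dataset.from_generator(gen, tf.int32, tf.TensorShape([None]))
batched = dataset.padded_batch(2, padded_shapes=tf.TensorShape([None]),
                               padding_values=EOS_ID)
iterator = batched.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(iterator.get_next()))  # shorter rows are padded with 2 up to the longest in the batch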
Example #9
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    speaker_file = hparams.speaker_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)

        spkr_table, _ = vocab_utils.create_vocab_tables(speaker_file, "", True)

        src_spk_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_spk_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_spkr_dataset = tf.data.TextLineDataset(src_spk_file_placeholder)
        tgt_spkr_dataset = tf.data.TextLineDataset(tgt_spk_file_placeholder)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            src_spkr_dataset,
            tgt_spkr_dataset,
            spkr_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(hparams,
                              iterator=iterator,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              source_vocab_table=src_vocab_table,
                              target_vocab_table=tgt_vocab_table,
                              speaker_table=spkr_table,
                              scope=scope,
                              extra_args=extra_args)
    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     src_spk_file_placeholder=src_spk_file_placeholder,
                     tgt_spk_file_placeholder=tgt_spk_file_placeholder,
                     iterator=iterator)
Example #10
    def testGetIterator(self):
        tf.set_random_seed(1)
        tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
            tf.constant(["a", "b", "c", "eos", "sos"]))
        src_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["f e a g", "c c a", "d", "c a"]))
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["c c", "a b", "", "b c"]))
        hparams = tf.contrib.training.HParams(random_seed=3,
                                              num_buckets=1,
                                              eos="eos",
                                              sos="sos")
        batch_size = 2
        src_max_len = 5
        dataset = iterator_utils.get_iterator(src_dataset=src_dataset,
                                              tgt_dataset=tgt_dataset,
                                              src_vocab_table=src_vocab_table,
                                              tgt_vocab_table=tgt_vocab_table,
                                              batch_size=batch_size,
                                              global_batch_size=batch_size,
                                              sos=hparams.sos,
                                              eos=hparams.eos,
                                              random_seed=hparams.random_seed,
                                              num_buckets=hparams.num_buckets,
                                              src_max_len=src_max_len,
                                              reshuffle_each_iteration=False)
        table_initializer = tf.tables_initializer()
        iterator = dataset.make_initializable_iterator()
        get_next = iterator.get_next()
        with self.test_session() as sess:
            sess.run(table_initializer)
            sess.run(iterator.initializer)
            features = sess.run(get_next)
            self.assertAllEqual(
                [
                    [4, 2, 0, 3, 3],  # c a eos -- eos is padding
                    [4, 2, 2, 0, 3]
                ],  # c c a
                features["source"])
            self.assertAllEqual([4, 5], features["source_sequence_length"])
            self.assertAllEqual(
                [
                    [4, 1, 2],  # sos b c
                    [4, 0, 1]
                ],  # sos a b
                features["target_input"])
            self.assertAllEqual(
                [
                    [1, 2, 3],  # b c eos
                    [0, 1, 3]
                ],  # a b eos
                features["target_output"])
            self.assertAllEqual([3, 3], features["target_sequence_length"])
Example #11
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       single_cell_fn=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)

    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default():

        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)

        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            tgt_vocab_table,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder)

        # Note: One can set model_device_fn to `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        with tf.device(model_helper.get_device_str(hparams.base_gpu)):
            # model_creator: the model class
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  target_vocab_table=tgt_vocab_table,
                                  scope=scope,
                                  single_cell_fn=single_cell_fn)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
Example #12
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    vocab_file = hparams.vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        data_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        kb_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        data_dataset = tf.data.TextLineDataset(data_file_placeholder)
        kb_dataset = tf.data.TextLineDataset(kb_file_placeholder)
        # this is the actual eval iterator
        eval_iterator = iterator_utils.get_iterator(
            data_dataset,
            kb_dataset,
            vocab_table,
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len)
        # this is the placeholder iterator
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, eval_iterator.output_types, eval_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(iterator)

        model = model_creator(hparams,
                              iterator=batched_iterator,
                              handle=handle,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              vocab_table=vocab_table,
                              scope=scope,
                              extra_args=extra_args)

    return EvalModel(graph=graph,
                     model=model,
                     placeholder_iterator=iterator,
                     placeholder_handle=handle,
                     eval_iterator=eval_iterator,
                     data_file_placeholder=data_file_placeholder,
                     kb_file_placeholder=kb_file_placeholder)
Example #13
def create_train_model(
    model_creator, hparams):

    """Create train graph, model, and iterator."""

    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container("train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.TRAIN,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table)

    return TrainModel(
                graph=graph,
                model=model,
                iterator=iterator)
Example #14
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
  """Create train graph, model, src/tgt file holders, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "eval"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
    iterator = iterator_utils.get_iterator(
        src_dataset,
        tgt_dataset,
        src_vocab_table,
        tgt_vocab_table,
        hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer,
        use_char_encode=hparams.use_char_encode)
    model = model_creator(
        hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        source_vocab_table=src_vocab_table,
        target_vocab_table=tgt_vocab_table,
        reverse_target_vocab_table=reverse_tgt_vocab_table,
        scope=scope,
        extra_args=extra_args)
  return EvalModel(
      graph=graph,
      model=model,
      src_file_placeholder=src_file_placeholder,
      tgt_file_placeholder=tgt_file_placeholder,
      iterator=iterator)
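Because the eval graph above builds its TextLineDatasets from string placeholders, the concrete dev/test file names are only supplied when the iterator is initialized. A hedged sketch of driving it follows; the file names and session handling are illustrative, only the EvalModel fields come from the example.

# Illustrative eval setup using the placeholders returned in EvalModel above.
eval_model = create_eval_model(model_creator, hparams)
with eval_model.graph.as_default():
    init_tables = tf.tables_initializer()
eval_sess = tf.Session(graph=eval_model.graph)
eval_sess.run(init_tables)
eval_sess.run(eval_model.iterator.initializer,
              feed_dict={eval_model.src_file_placeholder: "dev.en",   # hypothetical paths
                         eval_model.tgt_file_placeholder: "dev.vi"})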
Example #15
def create_train_model(model_creator, hparams):
    """Create train graph, model, and iterator."""
    src_file = hparams.src_file
    tgt_file = hparams.tgt_file
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container("train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder)

        with tf.device("/cpu:0"):
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.estimator.ModeKeys.TRAIN,
                                  source_vocab_table=src_vocab_table,
                                  target_vocab_table=tgt_vocab_table)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
Example #16
def prepare_dataset(flags):
    """Generate the preprocessed dataset."""
    src_file = "%s.%s" % (flags.data_dir + flags.train_prefix, flags.src)
    tgt_file = "%s.%s" % (flags.data_dir + flags.train_prefix, flags.tgt)
    vocab_file = flags.data_dir + flags.vocab_prefix
    _, vocab_file = vocab_utils.check_vocab(vocab_file, flags.out_dir)
    out_file = flags.out_dir + "preprocessed_dataset"
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        vocab_file)
    src_dataset = tf.data.TextLineDataset(src_file)
    tgt_dataset = tf.data.TextLineDataset(tgt_file)
    iterator = iterator_utils.get_iterator(
        src_dataset,
        tgt_dataset,
        src_vocab_table,
        tgt_vocab_table,
        batch_size=1,
        global_batch_size=1,
        sos=vocab_utils.SOS,
        eos=vocab_utils.EOS,
        random_seed=1,
        num_buckets=flags.num_buckets,
        src_max_len=flags.src_max_len,
        tgt_max_len=flags.tgt_max_len,
        filter_oversized_sequences=True,
        return_raw=True).make_initializable_iterator()

    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(iterator.initializer)
        try:
            i = 0
            while True:
                with open(out_file + "_%d" % i, "wb") as f:
                    i += 1
                    for _ in range(100):
                        for j in sess.run(iterator.get_next()):
                            tf.logging.info(j)
                            f.write(bytearray(j))
        except tf.errors.OutOfRangeError:
            pass
Example #17
    def test_get_iterator(self):
        tf.set_random_seed(1)
        src_vocab_table = tgt_vocab_table = lookup_ops.index_table_from_tensor(
            tf.constant(["a", "b", "c", "eos", "sos"]))
        src_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["f e a g", "c c a", "d", "c a"]))
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["c c", "a b", "", "b c"]))
        hparams = tf.contrib.training.HParams(random_seed=3,
                                              num_buckets=1,
                                              eos="eos",
                                              sos="sos")
        batch_size = 2
        src_max_len = 3
        iterator = iterator_utils.get_iterator(src_dataset=src_dataset,
                                               tgt_dataset=tgt_dataset,
                                               src_vocab_table=src_vocab_table,
                                               tgt_vocab_table=tgt_vocab_table,
                                               batch_size=batch_size,
                                               sos=hparams.sos,
                                               eos=hparams.eos,
                                               random_seed=hparams.random_seed,
                                               num_buckets=hparams.num_buckets,
                                               src_max_len=src_max_len,
                                               reshuffle_each_iteration=False,
                                               delimiter=" ")
        table_initializer = tf.tables_initializer()
        source = iterator.source
        src_seq_len = iterator.source_sequence_length
        self.assertEqual([None, None], source.shape.as_list())
        with self.test_session() as sess:
            sess.run(table_initializer)
            sess.run(iterator.initializer)
            (source_v, src_seq_len_v) = sess.run((source, src_seq_len))
            print(source_v)
            print(src_seq_len_v)

            (source_v, src_seq_len_v) = sess.run((source, src_seq_len))
            print(source_v)
            print(src_seq_len_v)
Example #18
def create_test_iterator(hparams, mode):
    """Create test iterator."""
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant([hparams.eos, "a", "b", "c", "d"]))
    tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
    tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
    if mode == tf.contrib.learn.ModeKeys.INFER:
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
            tgt_vocab_mapping)

    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["a a b b c", "a b b"]))

    if mode != tf.contrib.learn.ModeKeys.INFER:
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["a b c b c", "a b c b"]))
        return (
            # TODO: change to accommodate new inputs
            iterator_utils.get_iterator(src_dataset=src_dataset,
                                        tgt_dataset=tgt_dataset,
                                        src_vocab_table=src_vocab_table,
                                        tgt_vocab_table=tgt_vocab_table,
                                        batch_size=hparams.batch_size,
                                        sos=hparams.sos,
                                        eos=hparams.eos,
                                        random_seed=hparams.random_seed,
                                        num_buckets=hparams.num_buckets),
            src_vocab_table,
            tgt_vocab_table)
    else:
        return (
            # TODO: change to accommodate new inputs
            iterator_utils.get_infer_iterator(src_dataset=src_dataset,
                                              src_vocab_table=src_vocab_table,
                                              eos=hparams.eos,
                                              batch_size=hparams.batch_size),
            src_vocab_table,
            tgt_vocab_table,
            reverse_tgt_vocab_table)
Example #19
def mytest_iterator():
    src_dataset = tf.data.TextLineDataset(hparam.train_src)
    tgt_dataset = tf.data.TextLineDataset(hparam.train_tgt)
    src, tgt_in, tgt_out, src_seq_len, tgt_seq_len, initializer, _ = get_iterator(
        src_dataset, tgt_dataset, hparam)

    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(initializer)

        for i in range(1):
            try:
                _src, _tgt_in, _tgt_out, _src_seq_len, _tgt_seq_len = sess.run(
                    [src, tgt_in, tgt_out, src_seq_len, tgt_seq_len])
                print('src', _src)
                print('tgt_in', _tgt_in)
                print('tgt_out', _tgt_out)
                print('src_seq_len', _src_seq_len)
                print('tgt_seq_len', _tgt_seq_len)
            except tf.errors.OutOfRangeError:
                print('xxxxxxxxxxxxxxx')
                sess.run(initializer)
Example #20
    def setUp(self):

        super(ModelTest, self).setUp()

        self.graph = tf.Graph()

        self.session = tf.Session(graph=self.graph)

        with self.graph.as_default():

            self.iterator, _ = iterator_utils.get_iterator(
                'TRAIN',
                filesobj=TRAIN_FILES,
                buffer_size=TRAIN_HPARAMS.buffer_size,
                num_epochs=TRAIN_HPARAMS.num_epochs,
                batch_size=TRAIN_HPARAMS.batch_size,
                debug_mode=True)

            self.model = AttentionModel(TRAIN_HPARAMS, self.iterator, 'TRAIN')

            self.table_init_op = tf.tables_initializer()

            self.vars_init_op = tf.global_variables_initializer()
Example #21
def create_eval_model(model_creator, hparams, scope=None, single_cell_fn=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()

    with graph.as_default():

        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)

        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            tgt_vocab_table,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len)

        model = model_creator(hparams,
                              iterator=iterator,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              target_vocab_table=tgt_vocab_table,
                              scope=scope,
                              single_cell_fn=single_cell_fn)

    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     iterator=iterator)
Example #22
if __name__ == "__main__":
    src_dataset_name = "tst2012.en"
    tgt_dataset_name = "tst2012.vi"
    src_vocab_name = "vocab.en"
    tgt_vocab_name = "vocab.vi"

    src_dataset_path = os.path.join(BASE, src_dataset_name)
    tgt_dataset_path = os.path.join(BASE, tgt_dataset_name)
    src_vocab_path = os.path.join(BASE, src_vocab_name)
    tgt_vocab_path = os.path.join(BASE, tgt_vocab_name)

    batch_size = 128
    num_buckets = 5

    iterator = get_iterator(src_dataset_path=src_dataset_path, tgt_dataset_path=tgt_dataset_path,
                            src_vocab_path=src_vocab_path, tgt_vocab_path=tgt_vocab_path,
                            batch_size=batch_size, num_buckets=num_buckets, is_shuffle=False,
                            src_max_len=None, tgt_max_len=None)

    model = EvalModel(iterator=iterator)

    session = tf.Session()
    session.run(tf.tables_initializer())
    session.run(tf.global_variables_initializer())
    session.run(iterator.initializer)

    for i in range(1000):
        print(i)
        eval_loss, predict_count, batch_size = model.eval(eval_sess=session)
    # print(eval_loss)
    # print(predict_count)
    # print(batch_size)
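If the evaluation dataset is not repeated, the fixed-count loop above will eventually raise tf.errors.OutOfRangeError once the eval set is exhausted. A guarded variant, using the same names as above, is the more defensive pattern; this is a suggested alternative, not part of the original snippet.

while True:
    try:
        eval_loss, predict_count, batch_size = model.eval(eval_sess=session)
    except tf.errors.OutOfRangeError:
        break  # eval set exhausted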
Example #23
    def testGetIteratorWithSkipCount(self):
        vocab_table = lookup_ops.index_table_from_tensor(
            tf.constant(["a", "b", "c", "eos", "sos"]))
        src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
            tf.constant(["c c a", "c a", "d", "f e a g"]))
        tgt_dataset = tf.contrib.data.Dataset.from_tensor_slices(
            tf.constant(["a b", "b c", "", "c c"]))
        hparams = tf.contrib.training.HParams(
            random_seed=3,
            num_buckets=5,
            source_reverse=False,
            eos="eos",
            sos="sos")
        batch_size = 2
        src_max_len = 3
        skip_count = tf.placeholder(shape=(), dtype=tf.int64)
        iterator = iterator_utils.get_iterator(
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            vocab_table=vocab_table,
            batch_size=batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=src_max_len,
            skip_count=skip_count)
        table_initializer = tf.tables_initializer()
        source = iterator.source
        target_input = iterator.target_input
        target_output = iterator.target_output
        src_seq_len = iterator.source_sequence_length
        tgt_seq_len = iterator.target_sequence_length
        self.assertEqual([None, None], source.shape.as_list())
        self.assertEqual([None, None], target_input.shape.as_list())
        self.assertEqual([None, None], target_output.shape.as_list())
        self.assertEqual([None], src_seq_len.shape.as_list())
        self.assertEqual([None], tgt_seq_len.shape.as_list())
        with self.test_session() as sess:
            sess.run(table_initializer)
            sess.run(iterator.initializer, feed_dict={skip_count: 3})

            (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
                sess.run((source, src_seq_len, target_input, target_output,
                          tgt_seq_len)))
            self.assertAllEqual(
                [[-1, -1, 0]],  # "f" == unknown, "e" == unknown, a
                source_v)
            self.assertAllEqual([3], src_len_v)
            self.assertAllEqual(
                [[4, 2, 2]],  # sos c c
                target_input_v)
            self.assertAllEqual(
                [[2, 2, 3]],  # c c eos
                target_output_v)
            self.assertAllEqual([3], tgt_len_v)

            with self.assertRaisesOpError("End of sequence"):
                sess.run(source)

            # Re-init iterator with skip_count=0.
            sess.run(iterator.initializer, feed_dict={skip_count: 0})

            (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
                sess.run((source, src_seq_len, target_input, target_output,
                          tgt_seq_len)))
            self.assertAllEqual(
                [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
                 [2, 0, 3]],  # c a eos -- eos is padding
                source_v)
            self.assertAllEqual([3, 2], src_len_v)
            self.assertAllEqual(
                [[4, 2, 2],  # sos c c
                 [4, 1, 2]],  # sos b c
                target_input_v)
            self.assertAllEqual(
                [[2, 2, 3],  # c c eos
                 [1, 2, 3]],  # b c eos
                target_output_v)
            self.assertAllEqual([3, 3], tgt_len_v)

            (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
                sess.run((source, src_seq_len, target_input, target_output,
                          tgt_seq_len)))
            self.assertAllEqual(
                [[2, 2, 0]],  # c c a
                source_v)
            self.assertAllEqual([3], src_len_v)
            self.assertAllEqual(
                [[4, 0, 1]],  # sos a b
                target_input_v)
            self.assertAllEqual(
                [[0, 1, 3]],  # a b eos
                target_output_v)
            self.assertAllEqual([3], tgt_len_v)

            with self.assertRaisesOpError("End of sequence"):
                sess.run(source)
Example #24
    def testTrainInput(self):

        try1 = BatchedInput(
            src_chars=np.array([[[41, 53, 52, 42, 59, 41, 47, 43, 52, 42, 53],
                                 [52, 53, 57, 53, 58, 56, 53, 57, 2, 2, 2],
                                 [51, 47, 57, 51, 53, 57, 2, 2, 2, 2, 2],
                                 [15, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],
                                [[52, 53, 57, 58, 56, 53, 57, 2, 2, 2, 2],
                                 [57, 53, 51, 53, 57, 2, 2, 2, 2, 2, 2],
                                 [59, 52, 53, 2, 2, 2, 2, 2, 2, 2, 2],
                                 [15, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]]),
            trg_chars_in=np.array([[[1, 40, 54, 45, 58, 45, 50, 43, 2, 2],
                                    [1, 51, 57, 54, 55, 41, 48, 58, 41, 55],
                                    [1, 15, 2, 2, 2, 2, 2, 2, 2, 2],
                                    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],
                                   [[1, 59, 41, 2, 2, 2, 2, 2, 2, 2],
                                    [1, 37, 54, 41, 2, 2, 2, 2, 2, 2],
                                    [1, 51, 50, 41, 2, 2, 2, 2, 2, 2],
                                    [1, 15, 2, 2, 2, 2, 2, 2, 2, 2]]]),
            trg_chars_out=np.array([[[40, 54, 45, 58, 45, 50, 43, 2, 2, 2],
                                     [51, 57, 54, 55, 41, 48, 58, 41, 55, 2],
                                     [15, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                                     [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],
                                    [[59, 41, 2, 2, 2, 2, 2, 2, 2, 2],
                                     [37, 54, 41, 2, 2, 2, 2, 2, 2, 2],
                                     [51, 50, 41, 2, 2, 2, 2, 2, 2, 2],
                                     [15, 2, 2, 2, 2, 2, 2, 2, 2, 2]]]),
            trg_chars_lens=np.array([[8, 10, 2, 0], [3, 4, 4, 2]]),
            trg_words_in=np.array([[1, 100, 110, 9, 2], [1, 98, 348, 81, 9]]),
            trg_words_out=np.array([[100, 110, 9, 2, 2], [98, 348, 81, 9, 2]]),
            src_size=np.array([4, 4]),
            trg_size=np.array([4, 5]),
            initializer=None,
            sos_token_id=1,
            eos_token_id=2)

        try2 = BatchedInput(src_chars=np.array([[[46, 59, 40, 53, 2, 2],
                                                 [51, 59, 41, 46, 53, 2],
                                                 [42, 47, 50, 53, 45, 53],
                                                 [15, 2, 2, 2, 2, 2]]]),
                            trg_chars_in=np.array(
                                [[[1, 56, 44, 41, 54, 41, 2],
                                  [1, 59, 37, 55, 2, 2, 2],
                                  [1, 37, 2, 2, 2, 2, 2],
                                  [1, 59, 44, 51, 48, 41, 2],
                                  [1, 48, 51, 56, 2, 2, 2],
                                  [1, 51, 42, 2, 2, 2, 2],
                                  [1, 40, 45, 37, 48, 51, 43],
                                  [1, 15, 2, 2, 2, 2, 2]]]),
                            trg_chars_out=np.array(
                                [[[56, 44, 41, 54, 41, 2, 2],
                                  [59, 37, 55, 2, 2, 2, 2],
                                  [37, 2, 2, 2, 2, 2, 2],
                                  [59, 44, 51, 48, 41, 2, 2],
                                  [48, 51, 56, 2, 2, 2, 2],
                                  [51, 42, 2, 2, 2, 2, 2],
                                  [40, 45, 37, 48, 51, 43, 2],
                                  [15, 2, 2, 2, 2, 2, 2]]]),
                            trg_chars_lens=np.array([[6, 4, 2, 6, 4, 3, 7,
                                                      2]]),
                            trg_words_in=np.array(
                                [[1, 121, 122, 14, 836, 416, 37, 1471, 9]]),
                            trg_words_out=np.array(
                                [[121, 122, 14, 836, 416, 37, 1471, 9, 2]]),
                            src_size=np.array([4]),
                            trg_size=np.array([9]),
                            initializer=None,
                            sos_token_id=1,
                            eos_token_id=2)

        with self.graph.as_default():

            iterator, total_num = iterator_utils.get_iterator(
                'TRAIN',
                filesobj=self.files,
                buffer_size=None,
                num_epochs=1,
                batch_size=2,
                debug_mode=True)

            table_init_op = tf.tables_initializer()

        self.sess.run(table_init_op)

        self.sess.run(iterator.initializer)

        res1 = self.sess.run([
            iterator.src_chars, iterator.trg_chars_in, iterator.trg_chars_out,
            iterator.trg_chars_lens, iterator.trg_words_in,
            iterator.trg_words_out, iterator.src_size, iterator.trg_size,
            iterator.sos_token_id, iterator.eos_token_id
        ])

        self._check_all_equal(res1, try1, 1)

        res2 = self.sess.run([
            iterator.src_chars, iterator.trg_chars_in, iterator.trg_chars_out,
            iterator.trg_chars_lens, iterator.trg_words_in,
            iterator.trg_words_out, iterator.src_size, iterator.trg_size,
            iterator.sos_token_id, iterator.eos_token_id
        ])

        self._check_all_equal(res2, try2, 2)

        self.assertAllEqual(total_num, 3, 'mismatch in total num somehow')
Example #25
#graph = tf.Graph()
scope="train"
#with graph.as_default(), tf.container(scope):
src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file, args.share_vocab)
src_dataset = tf.data.TextLineDataset(src_file)
tgt_dataset = tf.data.TextLineDataset(tgt_file)
skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

iterator = iterator_utils.get_iterator(
				src_dataset,
				tgt_dataset,
				src_vocab_table,
				tgt_vocab_table,
				batch_size=args.batch_size,
				sos=args.sos,
				eos=args.eos,
				random_seed=args.random_seed,
				num_buckets=args.num_buckets,
				src_max_len=args.src_max_len,
				tgt_max_len=args.tgt_max_len,
				skip_count=skip_count_placeholder,
				num_shards=args.num_workers,
				shard_index=args.jobid)

# get source
source = iterator.source
if args.time_major:
	source = tf.transpose(source)

embedding_encoder, embedding_decoder = create_emb_for_encoder_and_decoder(share_vocab=args.share_vocab,
																																					src_vocab_size=src_vocab_size,
Example #26
def run():

    train_hparams = misc_utils.get_train_hparams()

    dev_hparams = misc_utils.get_dev_hparams()

    test_hparams = misc_utils.get_test_hparams()

    train_graph = tf.Graph()

    dev_graph = tf.Graph()

    test_graph = tf.Graph()

    with train_graph.as_default():

        train_iterator, train_total_num = iterator_utils.get_iterator(
            regime=train_hparams.regime,
            filesobj=train_hparams.filesobj,
            buffer_size=train_hparams.buffer_size,
            num_epochs=train_hparams.num_epochs,
            batch_size=train_hparams.batch_size)

        train_model = BuildTrainModel(hparams=train_hparams,
                                      iterator=train_iterator,
                                      graph=train_graph)

        train_vars_init_op = tf.global_variables_initializer()

        train_tables_init_op = tf.tables_initializer()

        tf.get_default_graph().finalize()

    with dev_graph.as_default():

        dev_iterator, dev_total_num = iterator_utils.get_iterator(
            dev_hparams.regime,
            filesobj=dev_hparams.filesobj,
            buffer_size=dev_hparams.buffer_size,
            num_epochs=dev_hparams.num_epochs,
            batch_size=dev_hparams.batch_size)

        dev_model = BuildEvalModel(hparams=dev_hparams,
                                   iterator=dev_iterator,
                                   graph=dev_graph)

        dev_tables_init_op = tf.tables_initializer()

        tf.get_default_graph().finalize()

    with test_graph.as_default():

        test_iterator, test_total_num = iterator_utils.get_iterator(
            test_hparams.regime,
            filesobj=test_hparams.filesobj,
            buffer_size=test_hparams.buffer_size,
            num_epochs=test_hparams.num_epochs,
            batch_size=test_hparams.batch_size)

        test_model = BuildInferModel(
            hparams=test_hparams,
            iterator=test_iterator,
            graph=test_graph,
            infer_file_path=test_hparams.translation_file_path,
            infer_chars_file_path=test_hparams.char_translation_file_path)

        test_tables_init_op = tf.tables_initializer()

        tf.get_default_graph().finalize()

    train_session = tf.Session(graph=train_graph)

    dev_session = tf.Session(graph=dev_graph)

    test_session = tf.Session(graph=test_graph)

    train_steps = misc_utils.count_num_steps(train_hparams.num_epochs,
                                             train_total_num,
                                             train_hparams.batch_size)

    eval_steps = misc_utils.count_num_steps(1, dev_total_num,
                                            dev_hparams.batch_size)

    num_test_steps = misc_utils.count_num_steps(1, test_total_num,
                                                test_hparams.batch_size)

    eval_count = dev_hparams.num_steps_to_eval

    train_session.run(train_tables_init_op)

    train_session.run(train_iterator.initializer)

    train_session.run(train_vars_init_op)

    with tqdm(total=train_steps) as prog:

        for step in range(train_steps):

            train_model.train(train_session)

            if step % eval_count == 0:

                dev_loss = misc_utils.eval_once(train_model, dev_model,
                                                train_session, dev_session,
                                                step, train_hparams.chkpts_dir,
                                                dev_iterator, eval_steps,
                                                dev_tables_init_op)

                print('dev loss at step {} = {}'.format(step, dev_loss))

            prog.update(1)

    misc_utils.write_translations(train_model, train_session,
                                  train_hparams.chkpts_dir, step, test_model,
                                  test_session, test_tables_init_op,
                                  test_iterator, num_test_steps)
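Example #26 keeps train, dev, and test in separate graphs and sessions, so weights have to move between them through checkpoints; misc_utils.eval_once likely does this internally. The sketch below is a hedged illustration of that hand-off: the Saver objects and checkpoint path are assumptions, and the Savers would have to be created before the graphs above are finalized.

with train_graph.as_default():
    train_saver = tf.train.Saver()
with dev_graph.as_default():
    dev_saver = tf.train.Saver()

# Train session writes a checkpoint; eval session restores it into its own graph.
ckpt_path = train_saver.save(train_session, "/tmp/chkpts/model.ckpt", global_step=step)
dev_saver.restore(dev_session, ckpt_path)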
Example #27
def train():
	"""Train a translation model."""
	create_new_model = params['create_new_model']
	out_dir = params['out_dir']
	model_creator = nmt_model.Model # Create model graph
	summary_name = "train_log"

	# Set up the session and initialize the input data iterators
	src_file = params['src_data_file']
	tgt_file = params['tgt_data_file']
	dev_src_file = params['dev_src_file']
	dev_tgt_file = params['dev_tgt_file']
	test_src_file = params['test_src_file']
	test_tgt_file = params['test_tgt_file']


	char_vocab_file = params['enc_char_map_path']
	src_vocab_file = params['src_vocab_file']
	tgt_vocab_file = params['tgt_vocab_file']
	if src_vocab_file == '' or tgt_vocab_file == '':
		raise ValueError("src_vocab_file or tgt_vocab_file not given in params.")

	graph = tf.Graph()

	# Log and output files
	log_file = os.path.join(out_dir, "log_%d" % time.time())
	log_f = tf.gfile.GFile(log_file, mode="a")
	utils.print_out("# log_file=%s" % log_file, log_f)


	# Model run params
	num_epochs = params['num_epochs']
	batch_size = params['batch_size']
	steps_per_stats = params['steps_per_stats']

	utils.print_out("# Epochs=%s, Batch Size=%s, Steps_per_Stats=%s" % (num_epochs, batch_size, steps_per_stats), None)

	with graph.as_default():
		src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file, params['share_vocab'])
		char_vocab_table = vocab_utils.get_char_table(char_vocab_file)
		reverse_target_table = lookup_ops.index_to_string_table_from_file(tgt_vocab_file, default_value=params['unk'])

		src_dataset = tf.data.TextLineDataset(src_file)
		tgt_dataset = tf.data.TextLineDataset(tgt_file)

		batched_iter = iterator_utils.get_iterator(src_dataset,
											   tgt_dataset,
											   char_vocab_table,
											   src_vocab_table,
											   tgt_vocab_table,
											   batch_size=batch_size,
											   sos=params['sos'],
											   eos=params['eos'],
											   char_pad = params['char_pad'],
											   num_buckets=params['num_buckets'],
											   num_epochs = params['num_epochs'],
											   src_max_len=params['src_max_len'],
											   tgt_max_len=params['tgt_max_len'],
											   src_char_max_len = params['char_max_len']
											   )

		# Summary writer
		summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),graph)


		# Preload validation data for decoding.
		dev_src_dataset = tf.data.TextLineDataset(dev_src_file)
		dev_tgt_dataset = tf.data.TextLineDataset(dev_tgt_file)
		dev_batched_iterator = iterator_utils.get_iterator(dev_src_dataset,
														   dev_tgt_dataset,
														   char_vocab_table,
														   src_vocab_table,
														   tgt_vocab_table,
														   batch_size=batch_size,
														   sos=params['sos'],
														   eos=params['eos'],
														   char_pad = params['char_pad'],
														   num_buckets=params['num_buckets'],
														   num_epochs = params['num_epochs'],
														   src_max_len=params['src_max_len'],
														   tgt_max_len=params['tgt_max_len'],
														   src_char_max_len = params['char_max_len']
														   )

		# Preload test data for decoding.
		test_src_dataset = tf.data.TextLineDataset(test_src_file)
		test_tgt_dataset = tf.data.TextLineDataset(test_tgt_file)
		test_batched_iterator = iterator_utils.get_iterator(test_src_dataset,
														   test_tgt_dataset,
														   char_vocab_table,
														   src_vocab_table,
														   tgt_vocab_table,
														   batch_size=batch_size,
														   sos=params['sos'],
														   eos=params['eos'],
														   char_pad = params['char_pad'],
														   num_buckets=params['num_buckets'],
														   num_epochs = params['num_epochs'],
														   src_max_len=params['src_max_len'],
														   tgt_max_len=params['tgt_max_len'],
														   src_char_max_len = params['char_max_len']
														   )

		config_proto = utils.get_config_proto(log_device_placement=params['log_device_placement'])
		sess = tf.Session(config=config_proto)


		with sess.as_default():
			

			train_model = model_creator(mode = params['mode'],
										train_iterator = batched_iter,
										val_iterator = dev_batched_iterator,
										char_vocab_table = char_vocab_table,
										source_vocab_table=src_vocab_table,
										target_vocab_table=tgt_vocab_table,
										reverse_target_table = reverse_target_table)

			loaded_train_model, global_step = create_or_load_model(train_model, params['out_dir'],session=sess,name="train",
																	log_f = log_f, create=create_new_model)
			
			sess.run([batched_iter.initializer,dev_batched_iterator.initializer, test_batched_iterator.initializer])


			start_train_time = time.time()
			utils.print_out("# Start step %d, lr %g, %s" %(global_step, loaded_train_model.learning_rate.eval(session=sess), time.ctime()), log_f)
			
			# Reset statistics
			stats = init_stats()

			steps_per_epoch = int(np.ceil(utils.get_file_row_size(src_file) / batch_size))
			utils.print_out("Total steps per epoch: %d" % steps_per_epoch)

			def train_step(model, sess):	
				return model.train(sess)
			def dev_step(model, sess):
				total_steps = int(np.ceil(utils.get_file_row_size(dev_src_file) / batch_size ))
				total_dev_loss = 0.0
				total_accuracy = 0.0
				for _ in range(total_steps):
					dev_result_step = model.dev(sess)
					dev_softmax_scores, dev_loss, tgt_output_ids,_,_,_,_ = dev_result_step
					total_dev_loss += dev_loss * params['batch_size']
					total_accuracy += evaluation_utils._accuracy(dev_softmax_scores, tgt_output_ids,  None, None)
				return (total_dev_loss/total_steps, total_accuracy/total_steps)


			for epoch_step in range(num_epochs): 
				for curr_step in range(steps_per_epoch):
					start_time = time.time()
					step_result = train_step(loaded_train_model, sess)
					global_step = update_stats(stats, summary_writer, start_time, step_result)

					# Logging Step
					if(curr_step % params['steps_per_stats'] == 0):
						check_stats(stats, global_step, steps_per_stats, log_f)



					# Evaluation
					if(curr_step % params['steps_per_devRun'] == 0):
						dev_step_loss, dev_step_acc = dev_step(loaded_train_model, sess)
						utils.print_out("Dev Step total loss, Accuracy: %f, %f" % (dev_step_loss, dev_step_acc), log_f)

				utils.print_out("# Finished an epoch, epoch completed %d" % epoch_step)
				loaded_train_model.saver.save(sess,  os.path.join(out_dir, "translate.ckpt"), global_step=global_step)
				dev_step_loss, dev_step_acc = dev_step(loaded_train_model, sess)
				utils.print_out("End of epoch dev loss, accuracy: %f, %f" % (dev_step_loss, dev_step_acc), log_f)


			utils.print_time("# Done training!", start_train_time)
			summary_writer.close()
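train() above pulls its whole configuration from a module-level params dict. A hedged sketch of the keys it reads is shown below; the key names come from the code above, but every value is a placeholder for illustration, since the project's real defaults are defined elsewhere:

params = {
    'create_new_model': True,
    'out_dir': './out',
    'src_data_file': 'train.src', 'tgt_data_file': 'train.tgt',
    'dev_src_file': 'dev.src', 'dev_tgt_file': 'dev.tgt',
    'test_src_file': 'test.src', 'test_tgt_file': 'test.tgt',
    'enc_char_map_path': 'char_vocab.txt',
    'src_vocab_file': 'vocab.src', 'tgt_vocab_file': 'vocab.tgt',
    'share_vocab': False,
    'mode': 'train',
    'unk': '<unk>', 'sos': '<s>', 'eos': '</s>', 'char_pad': '<pad>',
    'num_buckets': 5,
    'num_epochs': 10, 'batch_size': 128,
    'src_max_len': 50, 'tgt_max_len': 50, 'char_max_len': 20,
    'steps_per_stats': 100, 'steps_per_devRun': 1000,
    'log_device_placement': False,
}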
Example #28
def self_play_iterator_creator(hparams, num_workers, jobid):
    """create a self play iterator. There are iterators that will be created here.
  A supervised training iterator used for supervised learning. A full text
  iterator and structured iterator used for reinforcement learning self play.
  Full text iterators feeds data from text files while structured iterators
  are initialized directly from objects. The former one is used for traiing.
  The later one is used for self play dialogue generation to eliminate the
  need of serializing them into actual text
  files.
  """
    vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
    data_dataset = tf.data.TextLineDataset(hparams.train_data)
    kb_dataset = tf.data.TextLineDataset(hparams.train_kb)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    # this is the actual iterator for supervised training
    train_iterator = iterator_utils.get_iterator(
        data_dataset,
        kb_dataset,
        vocab_table,
        batch_size=hparams.batch_size,
        t1=hparams.t1,
        t2=hparams.t2,
        eod=hparams.eod,
        len_action=hparams.len_action,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        max_dialogue_len=hparams.max_dialogue_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # this is the actual iterator for self_play_fulltext_iterator
    data_placeholder = tf.placeholder(shape=[None],
                                      dtype=tf.string,
                                      name="src_ph")
    kb_placeholder = tf.placeholder(shape=[None],
                                    dtype=tf.string,
                                    name="kb_ph")
    batch_size_placeholder = tf.placeholder(shape=[],
                                            dtype=tf.int64,
                                            name="bs_ph")

    dataset_data = tf.data.Dataset.from_tensor_slices(data_placeholder)
    dataset_kb = tf.data.Dataset.from_tensor_slices(kb_placeholder)

    self_play_fulltext_iterator = iterator_utils.get_infer_iterator(
        dataset_data,
        dataset_kb,
        vocab_table,
        batch_size=batch_size_placeholder,
        eod=hparams.eod,
        len_action=hparams.len_action,
        self_play=True)

    # this is the actual iterator for self_play_structured_iterator
    self_play_structured_iterator = tf.data.Iterator.from_structure(
        self_play_fulltext_iterator.output_types,
        self_play_fulltext_iterator.output_shapes)
    iterators = [
        train_iterator, self_play_fulltext_iterator,
        self_play_structured_iterator
    ]

    # this is the list of placeholders
    placeholders = [
        data_placeholder, kb_placeholder, batch_size_placeholder,
        skip_count_placeholder
    ]
    return iterators, placeholders
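The structured iterator above is only a shell built with tf.data.Iterator.from_structure; at self-play time it has to be pointed at a concrete dataset with make_initializer. A minimal, generic sketch of that mechanism follows, using a toy dataset rather than the project's dialogue data:

import tensorflow as tf

toy_dataset = tf.data.Dataset.from_tensor_slices(tf.constant([1, 2, 3]))

# Build an iterator from the dataset's structure only; it is not bound to data yet.
iterator = tf.data.Iterator.from_structure(toy_dataset.output_types,
                                           toy_dataset.output_shapes)
next_element = iterator.get_next()

# Later, bind (or re-bind) the iterator to an actual dataset.
init_op = iterator.make_initializer(toy_dataset)

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(next_element))  # -> 1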
Example #29
  def testGetIteratorWithSkipCount(self):
    tf.set_random_seed(1)
    tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c a", "c c a", "d", "f e a g"]))
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["b c", "a b", "", "c c"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        num_buckets=5,
        eos="eos",
        sos="sos")
    batch_size = 2
    src_max_len = 3
    skip_count = tf.placeholder(shape=(), dtype=tf.int64)
    dataset = iterator_utils.get_iterator(
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=src_max_len,
        skip_count=skip_count,
        reshuffle_each_iteration=False)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer, feed_dict={skip_count: 1})
      features = sess.run(get_next)

      self.assertAllEqual(
          [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
           [2, 2, 0]],  # c c a
          features["source"])
      self.assertAllEqual([3, 3], features["source_sequence_length"])
      self.assertAllEqual(
          [[4, 2, 2],    # sos c c
           [4, 0, 1]],   # sos a b
          features["target_input"])
      self.assertAllEqual(
          [[2, 2, 3],   # c c eos
           [0, 1, 3]],  # a b eos
          features["target_output"])
      self.assertAllEqual([3, 3], features["target_sequence_length"])

      # Re-init iterator with skip_count=0.
      sess.run(iterator.initializer, feed_dict={skip_count: 0})

      features = sess.run(get_next)

      self.assertAllEqual(
          [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
           [2, 2, 0]],   # c c a
          features["source"])
      self.assertAllEqual([3, 3], features["source_sequence_length"])
      self.assertAllEqual(
          [[4, 2, 2],   # sos c c
           [4, 0, 1]],  # sos a b
          features["target_input"])
      self.assertAllEqual(
          [[2, 2, 3],   # c c eos
           [0, 1, 3]],  # a b eos
          features["target_output"])
      self.assertAllEqual([3, 3], features["target_sequence_length"])
Example #30
    def testGetIterator(self):
        tf.set_random_seed(1)
        tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
            tf.constant(["a", "b", "c", "eos", "sos"]))
        src_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["f e a g", "c c a", "d", "c a"]))
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["c c", "a b", "", "b c"]))
        hparams = tf.contrib.training.HParams(random_seed=3,
                                              num_buckets=5,
                                              eos="eos",
                                              sos="sos")
        batch_size = 2
        src_max_len = 3
        iterator = iterator_utils.get_iterator(src_dataset=src_dataset,
                                               tgt_dataset=tgt_dataset,
                                               src_vocab_table=src_vocab_table,
                                               tgt_vocab_table=tgt_vocab_table,
                                               batch_size=batch_size,
                                               sos=hparams.sos,
                                               eos=hparams.eos,
                                               random_seed=hparams.random_seed,
                                               num_buckets=hparams.num_buckets,
                                               src_max_len=src_max_len,
                                               reshuffle_each_iteration=False)
        table_initializer = tf.tables_initializer()
        source = iterator.source
        target_input = iterator.target_input
        target_output = iterator.target_output
        src_seq_len = iterator.source_sequence_length
        tgt_seq_len = iterator.target_sequence_length
        self.assertEqual([None, None], source.shape.as_list())
        self.assertEqual([None, None], target_input.shape.as_list())
        self.assertEqual([None, None], target_output.shape.as_list())
        self.assertEqual([None], src_seq_len.shape.as_list())
        self.assertEqual([None], tgt_seq_len.shape.as_list())
        with self.test_session() as sess:
            sess.run(table_initializer)
            sess.run(iterator.initializer)

            (source_v, src_len_v, target_input_v, target_output_v,
             tgt_len_v) = (sess.run((source, src_seq_len, target_input,
                                     target_output, tgt_seq_len)))
            self.assertAllEqual(
                [
                    [2, 0, 3],    # c a eos -- eos is padding
                    [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
                ],
                source_v)
            self.assertAllEqual([2, 3], src_len_v)
            self.assertAllEqual(
                [
                    [4, 1, 2],  # sos b c
                    [4, 2, 2],  # sos c c
                ],
                target_input_v)
            self.assertAllEqual(
                [
                    [1, 2, 3],  # b c eos
                    [2, 2, 3],  # c c eos
                ],
                target_output_v)
            self.assertAllEqual([3, 3], tgt_len_v)

            (source_v, src_len_v, target_input_v, target_output_v,
             tgt_len_v) = (sess.run((source, src_seq_len, target_input,
                                     target_output, tgt_seq_len)))
            self.assertAllEqual(
                [[2, 2, 0]],  # c c a
                source_v)
            self.assertAllEqual([3], src_len_v)
            self.assertAllEqual(
                [[4, 0, 1]],  # sos a b
                target_input_v)
            self.assertAllEqual(
                [[0, 1, 3]],  # a b eos
                target_output_v)
            self.assertAllEqual([3], tgt_len_v)

            with self.assertRaisesOpError("End of sequence"):
                sess.run(source)