Code example #1
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  source_vocab_table=src_vocab_table,
                                  target_vocab_table=tgt_vocab_table,
                                  scope=scope,
                                  extra_args=extra_args)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
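The TrainModel returned here is not defined in the snippet; in the TensorFlow NMT reference code these result holders are plain namedtuples whose fields match the keyword arguments of the return statement. A minimal sketch (the EvalModel and InferModel used by the later examples follow the same pattern, with fields matching their own return statements):

import collections

# Immutable holder for everything a training session needs: the graph to
# create the session in, the model ops, the data iterator, and the
# skip-count feed used to fast-forward the input file.
TrainModel = collections.namedtuple(
    "TrainModel", ("graph", "model", "iterator", "skip_count_placeholder"))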
Code example #2
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    lbl_vocab_file = hparams.lbl_vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table, lbl_vocab_table = \
          vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file,
                                          lbl_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        reverse_lbl_vocab_table = lookup_ops.index_to_string_table_from_file(
            lbl_vocab_file, default_value=vocab_utils.UNK)

        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        lbl_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
        lbl_dataset = tf.data.TextLineDataset(lbl_file_placeholder)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            lbl_dataset,
            src_vocab_table,
            tgt_vocab_table,
            lbl_vocab_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            label_vocab_table=lbl_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            reverse_target_intent_vocab_table=reverse_lbl_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     lbl_file_placeholder=lbl_file_placeholder,
                     iterator=iterator)
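A hypothetical driver for this factory, assuming the iterator exposes an initializer op as in the TensorFlow NMT code; model_creator, hparams, and the file paths are stand-ins from the surrounding project:

import tensorflow as tf

eval_model = create_eval_model(model_creator, hparams)
with tf.Session(graph=eval_model.graph) as sess:
    # Lookup tables must be initialized before the first lookup.
    sess.run(tf.tables_initializer())
    sess.run(eval_model.iterator.initializer,
             feed_dict={eval_model.src_file_placeholder: "dev.src",
                        eval_model.tgt_file_placeholder: "dev.tgt",
                        eval_model.lbl_file_placeholder: "dev.lbl"})
    # Restore a checkpoint, then run evaluation batches until the iterator
    # is exhausted (tf.errors.OutOfRangeError).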
Code example #3
File: read_tf_records.py (Project: liuyongjie985/WAP)
def main(_):
    ano_data_set = os.path.join(cfg.data_set, cfg.ano_data_set)
    vocab_file = os.path.join(ano_data_set, cfg.tgt_vocab_file)

    with tf.Graph().as_default():
        vocab_size, vocab_file = vocab_utils.check_vocab(vocab_file,
                                                         out_dir=cfg.out_dir,
                                                         sos=cfg.sos,
                                                         eos=cfg.eos,
                                                         unk=cfg.unk)

        tgt_vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        reverse_tgt_vocab_table = vocab_utils.index_to_string_table_from_file(
            vocab_file, default_value=cfg.unk)

        tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(cfg.sos)),
                             tf.int32)
        tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(cfg.eos)),
                             tf.int32)
        init_iter, batch_input = get_iterator(cfg.vaild_tf_filename,
                                              tgt_vocab_table, tgt_sos_id,
                                              tgt_eos_id)
        lookUpTgt = reverse_tgt_vocab_table.lookup(
            tf.to_int64(batch_input.target_output))
        sess = tf.Session()
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run(init_iter)  # initialize the dataset iterator
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        step = 0
        try:
            while not coord.should_stop():
                src, tgt_output, src_seq_len, tgt_seq_len = sess.run(
                    [batch_input.source, lookUpTgt,
                     batch_input.source_sequence_length,
                     batch_input.target_sequence_length])
                if np.isnan(np.max(src)) or np.isnan(np.min(src)):
                    print('got a NaN')
                    exit(1)
                if np.any(np.less(src, 0.)):
                    print('got a negative value')
                    exit(1)
                print('run one')
                step += 1
        except tf.errors.OutOfRangeError:
            # End of the dataset: every batch passed the checks.
            print('check finished')
        except KeyboardInterrupt:
            print('interrupt')
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()
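The batch_input tuple unpacked above (and constructed explicitly in code example #9) is again a plain namedtuple; a minimal sketch, with the fields these snippets actually touch:

import collections

# Field names inferred from the attribute accesses in code examples #3 and #9.
BatchedInput = collections.namedtuple(
    "BatchedInput",
    ("source", "target_input", "target_output",
     "source_sequence_length", "target_sequence_length"))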
Code example #4
def create_train_model(model_creator, get_iterator, hparams, scope=None):
    """Create the training graph, model and iterator"""

    # Build the file names by joining the train prefix with the language suffixes.
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    vocab_file = hparams.vocab_file
    # Define the graph
    graph = tf.Graph()

    with graph.as_default():
        vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        # Create datasets from file
        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
        # Number of leading dataset elements to skip, fed when the iterator is initialized.
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
        # Iterator
        iterator = get_iterator(
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            vocab_table=vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_reverse=hparams.src_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder
        )
        # Model. We don't pass an ids_to_words table because it isn't needed for training.
        model = model_creator(
            hparams=hparams,
            mode=tf.contrib.learn.ModeKeys.TRAIN,
            iterator=iterator,
            vocab_table=vocab_table,
            scope=scope
        )

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder
    )
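A hypothetical training driver for this factory; attribute names like model.update and model.train_loss follow the TensorFlow NMT convention and are assumptions here:

import tensorflow as tf

train_model = create_train_model(model_creator, get_iterator, hparams)
with tf.Session(graph=train_model.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(tf.global_variables_initializer())
    # Start from the beginning of the file; a restarted job could feed a
    # positive skip count to resume mid-epoch.
    sess.run(train_model.iterator.initializer,
             feed_dict={train_model.skip_count_placeholder: 0})
    while True:
        try:
            _, loss = sess.run([train_model.model.update,
                                train_model.model.train_loss])
        except tf.errors.OutOfRangeError:
            break  # one epoch consumed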
Code example #5
def create_eval_model(model_creator, get_iterator, hparams, scope=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    vocab_file = hparams.vocab_file
    # Define the graph
    graph = tf.Graph()

    with graph.as_default():
        vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        # Create placeholders for the file location
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        # Create the datasets from file
        src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
        # Create the iterator for the dataset. We do not use skip_count here because we evaluate on the full file.
        iterator = get_iterator(
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            vocab_table=vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_reverse=hparams.src_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer,
        )
        # Create a simple model
        model = model_creator(
            hparams=hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            vocab_table=vocab_table,
            scope=scope
        )

    return EvalModel(
        graph=graph,
        model=model,
        src_file_placeholder=src_file_placeholder,
        tgt_file_placeholder=tgt_file_placeholder,
        iterator=iterator
    )
Code example #6
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    lbl_vocab_file = hparams.lbl_vocab_file

    with graph.as_default(), tf.container(scope or "infer"):
        src_vocab_table, tgt_vocab_table, lbl_vocab_table = \
          vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file,
                                          lbl_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        reverse_lbl_vocab_table = lookup_ops.index_to_string_table_from_file(
            lbl_vocab_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            label_vocab_table=lbl_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            reverse_target_intent_vocab_table=reverse_lbl_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
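A hypothetical inference driver: raw sentences are fed through src_placeholder, so the same graph can be re-initialized per request. The decode call named in the comment follows the NMT convention and is an assumption:

import tensorflow as tf

infer_model = create_infer_model(model_creator, hparams)
with tf.Session(graph=infer_model.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(infer_model.iterator.initializer,
             feed_dict={infer_model.src_placeholder: ["what time is it"],
                        infer_model.batch_size_placeholder: 1})
    # After restoring a trained checkpoint, decode until the iterator is
    # exhausted, e.g. infer_model.model.decode(sess).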
Code example #7
def create_infer_model(model_creator,
                       get_infer_iterator,
                       hparams,
                       verbose=True,
                       scope=None):
    """Create the inference model"""
    graph = tf.Graph()
    vocab_file = hparams.vocab_file

    with graph.as_default():
        # Create the lookup tables
        vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        ids_to_words = lookup_ops.index_to_string_table_from_file(
            vocabulary_file=vocab_file, default_value=vocab_utils.UNK)
        # Define data placeholders
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
        # Create the dataset and iterator
        src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
            src_placeholder)
        iterator = get_infer_iterator(dataset=src_dataset,
                                      vocab_table=vocab_table,
                                      batch_size=batch_size_placeholder,
                                      src_reverse=hparams.src_reverse,
                                      eos=hparams.eos,
                                      src_max_len=hparams.src_max_len_infer)
        # Create the model
        model = model_creator(hparams=hparams,
                              iterator=iterator,
                              mode=tf.contrib.learn.ModeKeys.INFER,
                              vocab_table=vocab_table,
                              verbose=verbose,
                              ids_to_words=ids_to_words,
                              scope=scope)

    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
Code example #8
def create_eval_model(model_creator, hparams, scope=None, single_cell_fn=None):
  """Create train graph, model, src/tgt file holders, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  graph = tf.Graph()

  with graph.as_default():
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
    iterator = iterator_utils.get_iterator(
        src_dataset,
        tgt_dataset,
        src_vocab_table,
        tgt_vocab_table,
        hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        source_reverse=hparams.source_reverse,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer)
    model = model_creator(
        hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        source_vocab_table=src_vocab_table,
        target_vocab_table=tgt_vocab_table,
        scope=scope,
        single_cell_fn=single_cell_fn)
  return EvalModel(
      graph=graph,
      model=model,
      src_file_placeholder=src_file_placeholder,
      tgt_file_placeholder=tgt_file_placeholder,
      iterator=iterator)
Code example #9
    def __init__(self, is_training=True, checkPoint_path=None):
        self.graph = tf.Graph()
        self.is_training = is_training
        with self.graph.as_default():
            ano_data_set = os.path.join(cfg.data_set, cfg.ano_data_set)
            vocab_file = os.path.join(ano_data_set, cfg.tgt_vocab_file)
            vocab_size, vocab_file = vocab_utils.check_vocab(
                vocab_file,
                out_dir=cfg.out_dir,
                sos=cfg.sos,
                eos=cfg.eos,
                unk=cfg.unk)

            self.tgt_vocab_table = vocab_utils.create_vocab_tables(vocab_file)
            self.reverse_tgt_vocab_table = vocab_utils.index_to_string_table_from_file(
                vocab_file, default_value=cfg.unk)

            self.tgt_sos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(cfg.sos)), tf.int32)
            self.tgt_eos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(cfg.eos)), tf.int32)

            if is_training:
                # train_src_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.train_src_dataset))
                # train_tgt_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.train_tgt_dataset))
                self.init_iter_train, self.iterator_train = get_iterator(
                    cfg.train_tf_filename,
                    self.tgt_vocab_table,
                    self.tgt_sos_id,
                    self.tgt_eos_id,
                    augment=True)

                # vaild_src_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.vaild_src_dataset))
                # vaild_tgt_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.vaild_tgt_dataset))
                self.init_iter_vaild, self.iterator_vaild = get_iterator(
                    cfg.vaild_tf_filename, self.tgt_vocab_table,
                    self.tgt_sos_id, self.tgt_eos_id)

            else:
                self.source = tf.placeholder(tf.float32, (None, None),
                                             name='source')
                batch_source = tf.expand_dims(tf.expand_dims(self.source,
                                                             axis=0),
                                              axis=-1)
                iterator_source = normalize_input_img(batch_source)
                # tf.constant() cannot wrap a Tensor; tf.shape() already
                # returns int32, so take the width (time) dimension directly.
                self.source_sequence_length = tf.shape(iterator_source)[2]
                self.iterator = BatchedInput(
                    source=iterator_source,
                    target_input=None,
                    target_output=None,
                    source_sequence_length=self.source_sequence_length,
                    target_sequence_length=None)

            self.featureCNN = FeatureCNN()
            self.gru_att_cov = GRU_Att_Cov(vocab_size)  # vocabulary size

            if is_training:
                if cfg.outer_batch_size:
                    outer_loss = 0
                    with tf.variable_scope('outer_batch_size') as scope:
                        for i in range(cfg.outer_batch_size):
                            if i > 0:
                                scope.reuse_variables()
                            self.cnn_out_train = self.featureCNN(
                                self.iterator_train.source, True, False)
                            self.logits_train, _, self.attn_dists_train = self.gru_att_cov(
                                self.cnn_out_train, self.iterator_train, True,
                                self.tgt_sos_id)
                            outer_loss += self._loss(self.logits_train,
                                                     self.iterator_train)

                    self.loss_train = outer_loss / cfg.outer_batch_size
                else:
                    self.cnn_out_train = self.featureCNN(
                        self.iterator_train.source, True, False)
                    self.logits_train, _, self.attn_dists_train = self.gru_att_cov(
                        self.cnn_out_train, self.iterator_train, True,
                        self.tgt_sos_id)
                    self.loss_train = self._loss(self.logits_train,
                                                 self.iterator_train)

                self.global_step = tf.Variable(0,
                                               name='global_step',
                                               trainable=False)
                self.learning_rate = tf.train.exponential_decay(
                    cfg.startLr, self.global_step, cfg.decay_steps,
                    cfg.decay_rate)
                optimizer = tf.train.AdadeltaOptimizer(self.learning_rate)
                self.train_op = optimizer.minimize(
                    self.loss_train, global_step=self.global_step)

                self.cnn_out_vaild = self.featureCNN(
                    self.iterator_vaild.source, True)
                self.logits_vaild, _, _ = self.gru_att_cov(
                    self.cnn_out_vaild, self.iterator_vaild, True,
                    self.tgt_sos_id)
                self.loss_vaild = self._loss(self.logits_vaild,
                                             self.iterator_vaild)

                self.cnn_out_vaild_infer = self.featureCNN(
                    self.iterator_vaild.source, False)
                _, self.infer_indes_vaild, self.infer_attn_dists_vaild = self.gru_att_cov(
                    self.cnn_out_vaild_infer, self.iterator_vaild, False,
                    self.tgt_sos_id)
                self.infer_lookUpTgt_vaild = self.reverse_tgt_vocab_table.lookup(
                    tf.to_int64(self.infer_indes_vaild))

                self.accuracy_vaild = self._acc(
                    self.infer_indes_vaild, self.iterator_vaild.target_output)
                self.train_lookUpTgt_vaild = self.reverse_tgt_vocab_table.lookup(
                    tf.to_int64(self.iterator_vaild.target_output))

                self.train_summary, self.vaild_summary = self._summary()
            else:
                self.cnn_out = self.featureCNN(self.iterator.source,
                                               is_training)
                _, self.infer_indes, self.infer_attn_dists = self.gru_att_cov(
                    self.cnn_out, self.iterator, False, self.tgt_sos_id)
                self.infer_lookUpTgt = self.reverse_tgt_vocab_table.lookup(
                    tf.to_int64(self.infer_indes))

            self.init = [
                tf.global_variables_initializer(),
                tf.tables_initializer()
            ]
            self.saver = tf.train.Saver()
            self.sess = tf.Session(config=tf.ConfigProto(
                log_device_placement=True))
            if not is_training:
                self.sess.run(self.init)
                self.saver.restore(self.sess, checkPoint_path)
Code example #10
def create_train_model(
    model_creator, hparams, scope=None,
    single_cell_fn=None, model_device_fn=None):
  """Create train graph, model, and iterator."""
  src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
  tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  graph = tf.Graph()

  with graph.as_default():
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)

    src_dataset = tf.contrib.data.TextLineDataset(src_file)
    tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    
    if hparams.curriculum == 'none':
      iterator = iterator_utils.get_iterator(
          src_dataset,
          tgt_dataset,
          src_vocab_table,
          tgt_vocab_table,
          batch_size=hparams.batch_size,
          sos=hparams.sos,
          eos=hparams.eos,
          source_reverse=hparams.source_reverse,
          random_seed=hparams.random_seed,
          num_buckets=hparams.num_buckets,
          src_max_len=hparams.src_max_len,
          tgt_max_len=hparams.tgt_max_len,
          skip_count=skip_count_placeholder)
    else:
      iterator = iterator_utils.get_feedable_iterator(
          hparams,
          src_dataset,
          tgt_dataset,
          src_vocab_table,
          tgt_vocab_table,
          batch_size=hparams.batch_size,
          sos=hparams.sos,
          eos=hparams.eos,
          source_reverse=hparams.source_reverse,
          random_seed=hparams.random_seed,
          num_buckets=hparams.num_buckets,
          src_max_len=hparams.src_max_len,
          tgt_max_len=hparams.tgt_max_len,
          skip_count=skip_count_placeholder)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    with tf.device(model_device_fn):
      model = model_creator(
          hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          scope=scope,
          single_cell_fn=single_cell_fn)

  return TrainModel(
      graph=graph,
      model=model,
      iterator=iterator,
      skip_count_placeholder=skip_count_placeholder)
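Following the comment in the function body, a minimal sketch of wiring model_device_fn for a parameter-server setup (the task count is illustrative):

import tensorflow as tf

# Place variables on parameter servers; compute ops stay on the worker.
device_fn = tf.train.replica_device_setter(ps_tasks=2)
train_model = create_train_model(
    model_creator, hparams, model_device_fn=device_fn)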