Example #1
  def build_inputs(self):
    """Builds the ops for reading input data.

    Outputs:
      self.encode_ids1
      self.encode_ids2
      self.encode_mask1
      self.encode_mask2
      self.label
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_ids1 = None
      encode_ids2 = None
      encode_mask1 = tf.placeholder(tf.int8, (None, None), name="encode_mask1")
      encode_mask2 = tf.placeholder(tf.int8, (None, None), name="encode_mask2")
      label = None

    elif self.mode == "test":
      encode_ids1 = None
      encode_ids2 = None
      encode_mask1 = tf.placeholder(tf.int8, (None, None), name="encode_mask1")
      encode_mask2 = tf.placeholder(tf.int8, (None, None), name="encode_mask2")
      label = None
      
    else:
      # Prefetch serialized tf.Example protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          self.config.input_file_pattern,
          shuffle=self.config.shuffle_input_data,
          capacity=self.config.input_queue_capacity,
          num_reader_threads=self.config.num_input_reader_threads)

      # Deserialize a batch.
      serialized = input_queue.dequeue_many(self.config.batch_size)
      s1, s2, label = input_ops.parse_example_batch(serialized)

      encode_ids1 = s1.ids
      encode_ids2 = s2.ids

      encode_mask1 = s1.mask
      encode_mask2 = s2.mask

    self.encode_ids1 = encode_ids1
    self.encode_ids2 = encode_ids2

    self.encode_mask1 = encode_mask1
    self.encode_mask2 = encode_mask2

    self.label = label
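
In "encode" and "test" mode only the mask placeholders exist, so a caller has to feed them through feed_dict. The stand-alone sketch below illustrates the mechanics with toy values; the placeholder names, shapes and dtype are copied from the example, while the session and the batch of one sentence pair are illustrative and not part of the original code.

import numpy as np
import tensorflow as tf

# Placeholders as created in the "encode"/"test" branch above.
encode_mask1 = tf.placeholder(tf.int8, (None, None), name="encode_mask1")
encode_mask2 = tf.placeholder(tf.int8, (None, None), name="encode_mask2")

with tf.Session() as sess:
  # A batch of one sentence pair: lengths 3 and 2, padded to a common length 3.
  mask1 = np.array([[1, 1, 1]], dtype=np.int8)
  mask2 = np.array([[1, 1, 0]], dtype=np.int8)
  m1, m2 = sess.run([encode_mask1, encode_mask2],
                    feed_dict={encode_mask1: mask1, encode_mask2: mask2})
  print(m1.shape, m2.shape)  # (1, 3) (1, 3)
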
Example #2
    def build_inputs(self):
        """Builds the ops for reading input data.

        Outputs:
          self.encode_ids
          self.decode_pre_ids
          self.decode_post_ids
          self.encode_mask
          self.decode_pre_mask
          self.decode_post_mask
        """
        if self.mode == "encode":
            # Word embeddings are fed from an external vocabulary which has possibly
            # been expanded (see vocabulary_expansion.py).
            encode_ids = None
            decode_pre_ids = None
            decode_post_ids = None
            encode_mask = tf.placeholder(tf.int8, (None, None),
                                         name="encode_mask")
            decode_pre_mask = None
            decode_post_mask = None
        else:
            # Prefetch serialized tf.Example protos.
            input_queue = input_ops.prefetch_input_data(
                self.reader,
                self.config.input_file_pattern,
                shuffle=self.config.shuffle_input_data,
                capacity=self.config.input_queue_capacity,
                num_reader_threads=self.config.num_input_reader_threads)

            # Deserialize a batch.
            serialized = input_queue.dequeue_many(self.config.batch_size)
            encode, decode_pre, decode_post = input_ops.parse_example_batch(
                serialized)

            encode_ids = encode.ids
            decode_pre_ids = decode_pre.ids
            decode_post_ids = decode_post.ids

            encode_mask = encode.mask
            decode_pre_mask = decode_pre.mask
            decode_post_mask = decode_post.mask

        self.encode_ids = encode_ids
        self.decode_pre_ids = decode_pre_ids
        self.decode_post_ids = decode_post_ids

        self.encode_mask = encode_mask
        self.decode_pre_mask = decode_pre_mask
        self.decode_post_mask = decode_post_mask
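
Both branches above ultimately expose (ids, mask) pairs; in training mode they come from input_ops.parse_example_batch, which is not shown here. A minimal sketch of what such a helper might look like, assuming each serialized tf.Example stores its sentences as variable-length int64 features named "encode", "decode_pre" and "decode_post" (the SentenceBatch namedtuple and the feature names are assumptions, not taken from the example):

import collections
import tensorflow as tf

# Hypothetical container mirroring the .ids / .mask access pattern above.
SentenceBatch = collections.namedtuple("SentenceBatch", ["ids", "mask"])


def parse_example_batch_sketch(serialized):
  """Parses serialized tf.Example protos into SentenceBatch tuples."""
  features = tf.parse_example(
      serialized,
      features={
          "encode": tf.VarLenFeature(dtype=tf.int64),
          "decode_pre": tf.VarLenFeature(dtype=tf.int64),
          "decode_post": tf.VarLenFeature(dtype=tf.int64),
      })

  def _sparse_to_batch(sparse):
    # Densify the ids and build a 0/1 mask over the non-padding positions.
    ids = tf.sparse_tensor_to_dense(sparse)
    mask = tf.sparse_to_dense(sparse.indices, sparse.dense_shape,
                              tf.ones_like(sparse.values, dtype=tf.int32))
    return SentenceBatch(ids=ids, mask=mask)

  return (_sparse_to_batch(features["encode"]),
          _sparse_to_batch(features["decode_pre"]),
          _sparse_to_batch(features["decode_post"]))
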
Example #3
  def build_inputs(self):
    """Builds the ops for reading input data.

    Outputs:
      self.encode_ids
      self.decode_pre_ids
      self.decode_post_ids
      self.encode_mask
      self.decode_pre_mask
      self.decode_post_mask
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_ids = None
      decode_pre_ids = None
      decode_post_ids = None
      encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
      decode_pre_mask = None
      decode_post_mask = None
    else:
      # Prefetch serialized tf.Example protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          self.config.input_file_pattern,
          shuffle=self.config.shuffle_input_data,
          capacity=self.config.input_queue_capacity,
          num_reader_threads=self.config.num_input_reader_threads)

      # Deserialize a batch.
      serialized = input_queue.dequeue_many(self.config.batch_size)
      encode, decode_pre, decode_post = input_ops.parse_example_batch(
          serialized)

      encode_ids = encode.ids
      decode_pre_ids = decode_pre.ids
      decode_post_ids = decode_post.ids

      encode_mask = encode.mask
      decode_pre_mask = decode_pre.mask
      decode_post_mask = decode_post.mask

    self.encode_ids = encode_ids
    self.decode_pre_ids = decode_pre_ids
    self.decode_post_ids = decode_post_ids

    self.encode_mask = encode_mask
    self.decode_pre_mask = decode_pre_mask
    self.decode_post_mask = decode_post_mask
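
The training branch of these examples reads from the queue built by input_ops.prefetch_input_data, so nothing is produced until the TF1 queue runners are started. A hypothetical driver sketch; the `model` object holding the build_inputs() above is an assumption and is not defined in the example itself:

import tensorflow as tf

# `model` is assumed to expose the build_inputs() method above, constructed
# in training mode (self.mode != "encode").
model.build_inputs()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # The prefetch queue only fills once the queue runners are started;
  # without this, dequeue_many blocks forever.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    ids, mask = sess.run([model.encode_ids, model.encode_mask])
    print(ids.shape, mask.shape)  # both (batch_size, max_sentence_length)
  finally:
    coord.request_stop()
    coord.join(threads)
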
Example #4
    def build_inputs(self):
        """Builds the ops for reading input data.

        Outputs:
          self.encode_ids
          self.decode_pre_ids
          self.decode_post_ids
          self.encode_mask
          self.decode_pre_mask
          self.decode_post_mask
        """
        if self.mode == "encode":
            # Word embeddings are fed from an external vocabulary which has possibly
            # been expanded (see vocabulary_expansion.py).
            encode_ids = None
            decode_pre_ids = None
            decode_post_ids = None
            encode_mask = tf.placeholder(tf.int32, (None, None),
                                         name="encode_mask")
            decode_pre_mask = None
            decode_post_mask = None
        elif self.mode == "decode":
            # Word embeddings are fed from an external vocabulary which has possibly
            # been expanded (see vocabulary_expansion.py).
            encode_ids = tf.placeholder(tf.int64, (None, None),
                                        name="encode_ids")
            decode_pre_ids = tf.placeholder(tf.int64, (None, None),
                                            name="decode_pre_ids")
            decode_post_ids = tf.placeholder(tf.int64, (None, None),
                                             name="decode_post_ids")
            encode_mask = tf.placeholder(tf.int32, (None, None),
                                         name="encode_mask")
            decode_pre_mask = tf.placeholder(tf.int32, (None, None),
                                             name="decode_pre_mask")
            decode_post_mask = tf.placeholder(tf.int32, (None, None),
                                              name="decode_post_mask")
        else:
            # Prefetch serialized tf.Example protos.
            input_queue = input_ops.prefetch_input_data(
                self.reader,
                self.config.input_file_pattern,
                shuffle=self.config.shuffle_input_data,
                capacity=self.config.input_queue_capacity,
                num_reader_threads=self.config.num_input_reader_threads)

            serialized = input_queue.dequeue()
            encode, decode_pre, decode_post = _parse_single_example(serialized)
            encode_mask = tf.ones_like(encode, dtype=tf.int32)

            # Ensure the minimum lengths of decode_pre and decode_post. We request an
            # extra length unit for decode_pre, because we will clip out the <EOS>
            # later (only if decode_strategy != "conditional").
            if self.config.decode_strategy == "conditional":
                decode_pre, decode_pre_mask = _pad_to_min_length(
                    decode_pre,
                    self.config.context_window + self.config.condition_length,
                    pad_from_front=True)
                decode_post, decode_post_mask = _pad_to_min_length(
                    decode_post,
                    self.config.context_window + self.config.condition_length,
                    pad_from_front=False)
            else:
                decode_pre, decode_pre_mask = _pad_to_min_length(
                    decode_pre,
                    self.config.context_window + 1,
                    pad_from_front=True)
                decode_post, decode_post_mask = _pad_to_min_length(
                    decode_post,
                    self.config.context_window,
                    pad_from_front=False)

            # Clip to the end of decode_pre and the beginning of decode_post. Also
            # ignore the <EOS> at the end of decode_pre (only if decode_strategy !=
            # "conditional").
            if self.config.decode_strategy == "conditional":
                decode_pre = decode_pre[-(self.config.context_window +
                                          self.config.condition_length):]
                decode_pre_mask = decode_pre_mask[-(
                    self.config.context_window +
                    self.config.condition_length):]
                decode_post = decode_post[0:self.config.context_window +
                                          self.config.condition_length]
                decode_post_mask = decode_post_mask[
                    0:self.config.context_window +
                    self.config.condition_length]
            else:
                decode_pre = tf.reverse(decode_pre,
                                        [0])[1:self.config.context_window + 1]
                decode_pre_mask = tf.reverse(
                    decode_pre_mask, [0])[1:self.config.context_window + 1]
                decode_post = decode_post[0:self.config.context_window]
                decode_post_mask = decode_post_mask[
                    0:self.config.context_window]

            (encode_ids, decode_pre_ids, decode_post_ids, encode_mask,
             decode_pre_mask, decode_post_mask) = tf.train.batch(
                 tensors=[
                     encode,
                     decode_pre,
                     decode_post,
                     encode_mask,
                     decode_pre_mask,
                     decode_post_mask,
                 ],
                 batch_size=self.config.batch_size,
                 num_threads=self.config.num_batching_threads,
                 capacity=5 * self.config.batch_size,
                 dynamic_pad=True)

        self.encode_ids = encode_ids
        self.decode_pre_ids = decode_pre_ids
        self.decode_post_ids = decode_post_ids

        self.encode_mask = encode_mask
        self.decode_pre_mask = decode_pre_mask
        self.decode_post_mask = decode_post_mask
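
Example #4 also depends on a _pad_to_min_length helper that is not shown. A minimal sketch of such a helper, assuming it takes a 1-D id tensor, pads with zeros, and returns the padded ids together with an int32 validity mask (the real signature and pad value may differ):

import tensorflow as tf


def _pad_to_min_length(ids, min_length, pad_from_front=False):
  """Sketch: zero-pad a 1-D id tensor to a length of at least `min_length`."""
  length = tf.shape(ids)[0]
  pad_amount = tf.maximum(0, min_length - length)
  # Pad at the front for decode_pre and at the back for decode_post.
  pad_front = pad_amount if pad_from_front else 0
  pad_back = 0 if pad_from_front else pad_amount
  paddings = tf.reshape(tf.stack([pad_front, pad_back]), [1, 2])
  padded_ids = tf.pad(ids, paddings)
  # The mask is 1 over the original tokens and 0 over the padding.
  mask = tf.pad(tf.ones_like(ids, dtype=tf.int32), paddings)
  return padded_ids, mask


# Usage: padding [5, 6, 7] to length 5 from the front gives
#   ids  -> [0, 0, 5, 6, 7]
#   mask -> [0, 0, 1, 1, 1]
padded, mask = _pad_to_min_length(tf.constant([5, 6, 7]), 5, pad_from_front=True)
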