def build_inputs(self):
    """Builds the ops for reading input data.

    In "encode" and "test" mode no id tensors are built; only two feedable
    int8 mask placeholders of shape (None, None) are created. In any other
    mode, serialized tf.Example protos are prefetched from
    self.config.input_file_pattern, a batch of self.config.batch_size is
    dequeued, and input_ops.parse_example_batch splits it into two sentence
    inputs and a label.

    Outputs:
      self.encode_ids1: ids of the first sentence (s1.ids), or None in
        "encode"/"test" mode.
      self.encode_ids2: ids of the second sentence (s2.ids), or None in
        "encode"/"test" mode.
      self.encode_mask1: mask of the first sentence (s1.mask), or an int8
        placeholder named "encode_mask1" in "encode"/"test" mode.
      self.encode_mask2: mask of the second sentence (s2.mask), or an int8
        placeholder named "encode_mask2" in "encode"/"test" mode.
      self.label: label batch returned by parse_example_batch, or None in
        "encode"/"test" mode.
    """
    if self.mode in ("encode", "test"):
        # Word embeddings are fed from an external vocabulary which has possibly
        # been expanded (see vocabulary_expansion.py), so only the masks need
        # placeholders here. Both modes build identical inputs.
        encode_ids1 = None
        encode_ids2 = None
        encode_mask1 = tf.placeholder(tf.int8, (None, None), name="encode_mask1")
        encode_mask2 = tf.placeholder(tf.int8, (None, None), name="encode_mask2")
        label = None
    else:
        # Prefetch serialized tf.Example protos.
        input_queue = input_ops.prefetch_input_data(
            self.reader,
            self.config.input_file_pattern,
            shuffle=self.config.shuffle_input_data,
            capacity=self.config.input_queue_capacity,
            num_reader_threads=self.config.num_input_reader_threads)

        # Deserialize a batch.
        serialized = input_queue.dequeue_many(self.config.batch_size)
        s1, s2, label = input_ops.parse_example_batch(serialized)

        encode_ids1 = s1.ids
        encode_ids2 = s2.ids
        encode_mask1 = s1.mask
        encode_mask2 = s2.mask

    self.encode_ids1 = encode_ids1
    self.encode_ids2 = encode_ids2
    self.encode_mask1 = encode_mask1
    self.encode_mask2 = encode_mask2
    self.label = label
def build_inputs(self):
    """Create the input tensors consumed by the rest of the model.

    When self.mode is "encode", no id tensors or decoder masks are built
    (they are left as None); only the encoder mask is exposed as an int8
    placeholder named "encode_mask" with shape (None, None). In every other
    mode, serialized tf.Example protos are prefetched from
    self.config.input_file_pattern, self.config.batch_size of them are
    dequeued, and input_ops.parse_example_batch splits the batch into three
    parts: encode, decode_pre and decode_post.

    Outputs:
      self.encode_ids, self.decode_pre_ids, self.decode_post_ids
      self.encode_mask, self.decode_pre_mask, self.decode_post_mask
    """
    if self.mode != "encode":
        config = self.config
        # Queue of serialized tf.Example protos read in the background.
        example_queue = input_ops.prefetch_input_data(
            self.reader,
            config.input_file_pattern,
            shuffle=config.shuffle_input_data,
            capacity=config.input_queue_capacity,
            num_reader_threads=config.num_input_reader_threads)
        serialized_batch = example_queue.dequeue_many(config.batch_size)
        enc, pre, post = input_ops.parse_example_batch(serialized_batch)
        all_ids = (enc.ids, pre.ids, post.ids)
        all_masks = (enc.mask, pre.mask, post.mask)
    else:
        # Word embeddings are supplied from an external vocabulary that may
        # have been expanded (see vocabulary_expansion.py); only the encoder
        # mask has to be fed.
        fed_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
        all_ids = (None, None, None)
        all_masks = (fed_mask, None, None)

    self.encode_ids, self.decode_pre_ids, self.decode_post_ids = all_ids
    self.encode_mask, self.decode_pre_mask, self.decode_post_mask = all_masks