def setup_network(self):
    # Setup character embedding
    embedded_encoder_input, embedded_decoder_input, embed_func = \
        self.setup_character_embedding()

    # Output projection
    with tf.variable_scope('alphabet_projection') as scope:
        self.projection_W, self.projection_b = intialize_projections(
            # We use a bidirectional encoder + attention
            input_size=4 * self.config.num_units,
            output_size=self.config.alphabet_size,
            scope=scope)

    # Define alphabet projection function
    def project_func(output):
        return projection(output, W=self.projection_W, b=self.projection_b)

    # Encoder
    with tf.variable_scope('encoder') as scope:
        enc_outputs, enc_final_state = self.encoder.encode(
            inputs=embedded_encoder_input,
            seq_lengths=self.encoder_sequence_length,
            # enc_word_indices=self.enc_word_indices,
            # word_seq_lengths=self.word_seq_lengths,
            # max_words=self.config.max_words,
            scope=scope)

    # Set decoder initial state and encoder outputs based on the binary
    # mode input value:
    # - If `self.is_lm_mode=0`, use the passed initial state from the encoder
    # - If `self.is_lm_mode=1`, use the zero vector
    self.enc_outputs, self.enc_final_state = select_decoder_inputs(
        is_lm_mode=self.is_lm_mode,
        enc_outputs=enc_outputs,
        initial_state=enc_final_state,
    )

    # Pack state to tensor
    self.enc_final_state_tensor = pack_state_tuple(self.enc_final_state)

    # Initialize decoder attention function using encoder outputs
    self.decoder.initialize_attention_func(
        input_size=embedded_decoder_input.get_shape().as_list()[-1],
        attention_states=self.enc_outputs)

    # Define initial attention tensor
    self.initial_attention = self.decoder.attention_func(self.enc_final_state)

    # Define decoder
    with tf.variable_scope('decoder'):
        dec_outputs, dec_final_state = self.decoder.decode(
            inputs=embedded_decoder_input,
            initial_state=self.enc_final_state,
            seq_length=self.decoder_sequence_length,
            embed_func=embed_func,
            project_func=project_func,
        )

    # Project output to alphabet size and reshape
    dec_outputs = tf.reshape(dec_outputs, [-1, 4 * self.config.num_units])
    dec_outputs = projection(dec_outputs,
                             W=self.projection_W,
                             b=self.projection_b)
    dec_outputs = tf.reshape(dec_outputs, [
        -1, self.config.max_dec_seq_length + 1, self.config.alphabet_size
    ])

    if self.prediction_mode:
        dec_outputs = self.decoder_logits

    # Define loss
    self.setup_losses(dec_outputs=dec_outputs,
                      target_chars=self.target_chars,
                      decoder_sequence_length=self.decoder_sequence_length)

    if self.prediction_mode:
        # Look up inputs
        decoder_inputs_embedded = tf.nn.embedding_lookup(
            self.embedding_matrix, self.decoder_inputs, name='decoder_input')
        is_lm_mode_tensor = tf.to_float(
            tf.expand_dims(self.is_lm_mode, axis=1))
        decoder_inputs = tf.concat(
            [decoder_inputs_embedded, is_lm_mode_tensor], axis=1)

        # Unpack state
        initial_state = unpack_state_tensor(self.decoder_state)

        with tf.variable_scope('decoder', reuse=True):
            decoder_output, decoder_final_state, self.decoder_new_attention = \
                self.decoder.predict(inputs=decoder_inputs,
                                     initial_state=initial_state,
                                     attention_states=self.decoder_attention)

        # Project output to alphabet size
        self.decoder_output = projection(decoder_output,
                                         W=self.projection_W,
                                         b=self.projection_b,
                                         name='decoder_output')

        # Compute decayed probabilities
        self.decoder_probs_decayed = compute_decayed_probs(
            logits=self.decoder_output,
            decay_parameter_ph=self.probs_decay_parameter)

        # Pack state to tensor
        self.decoder_final_state = pack_state_tuple(
            decoder_final_state, name='decoder_final_state')
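# NOTE: `intialize_projections` and `projection` are helpers defined elsewhere
# in the codebase. A minimal sketch consistent with how they are called in the
# methods above and below (the exact initializers and variable names are
# assumptions of this sketch, not the original implementation):
def intialize_projections(input_size, output_size, scope=None):
    # Create projection weights W [input_size, output_size] and bias b [output_size]
    with tf.variable_scope(scope or 'projection'):
        W = tf.get_variable('W', shape=[input_size, output_size],
                            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable('b', shape=[output_size],
                            initializer=tf.zeros_initializer())
    return W, b


def projection(x, W=None, b=None, input_size=None, output_size=None, name=None):
    # Affine projection x @ W + b; creates W and b on the fly when only the
    # sizes are given (as in the question-classification layer)
    if W is None or b is None:
        W, b = intialize_projections(input_size=input_size,
                                     output_size=output_size)
    return tf.add(tf.matmul(x, W), b, name=name)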
def predict(self, session, lm_predict_func=None, **kwargs):
    assert self.prediction_mode

    def decode_func(inputs, state):
        output, probs, state = session.run(
            fetches=[
                self.decoder_output, self.decoder_probs,
                self.decoder_final_state
            ],
            feed_dict={
                self.decoder_inputs: inputs,
                self.decoder_state: state
            })
        if lm_predict_func is not None:
            # Language-model predictions are computed here but not combined
            # with the decoder output in this snippet
            lm_output, lm_probs, lm_state = lm_predict_func(inputs, state)
        return {'output': output, 'probs': probs, 'state': state}

    def loss_func(logits, targets, input_length):
        return session.run(
            fetches=[self.mean_loss_batch, self.mean_prob_x_batch],
            feed_dict={
                self.decoder_logits: logits,
                self.target_chars: targets,
                self.decoder_sequence_length: input_length
            })

    # Construct vector of <GO_ID> tokens as initial input
    batch_size = kwargs['enc_input'].shape[0]
    dec_target = kwargs['dec_target']
    dec_input_length = kwargs['dec_input_length']
    initial_inputs = np.full(shape=(batch_size, ),
                             fill_value=self.alphabet.GO_ID,
                             dtype=np.float32)
    max_iterations = self.config.max_dec_seq_length + 1

    # Define initial state
    initial_state = self.decoder.cell.zero_state(batch_size, dtype=tf.float32)
    initial_state = pack_state_tuple(initial_state)
    initial_state = session.run(initial_state)

    extra_features = {}

    # Initialize predictor
    if self.sample_type == 'beam':
        predictor = BeamSearchPredictor(batch_size=batch_size,
                                        max_length=max_iterations,
                                        alphabet=self.alphabet,
                                        decode_func=decode_func,
                                        loss_func=loss_func,
                                        beam_size=self.beam_size)
    elif self.sample_type == 'sample':
        predictor = SamplingPredictor(batch_size=batch_size,
                                      max_length=max_iterations,
                                      alphabet=self.alphabet,
                                      decode_func=decode_func,
                                      loss_func=loss_func,
                                      num_samples=self.beam_size)
    else:
        raise KeyError('Invalid sample_type provided!')

    # Predict sequence candidates
    final_candidates, final_logits, loss_candidates, prob_x_candidates = \
        predictor.predict_sequences(initial_state=initial_state,
                                    target=dec_target,
                                    input_length=dec_input_length,
                                    features=extra_features)

    # Remove predictions after the `<EOS>` id
    for i, j, k in zip(*np.where(final_candidates == self.alphabet.EOS_ID)):
        final_candidates[i, j, k + 1:] = 0

    return {
        'candidates': final_candidates,
        'loss_candidates': loss_candidates,
        'prob_x_candidates': prob_x_candidates
    }
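# NOTE: `pack_state_tuple` / `unpack_state_tensor` convert between the RNN's
# nested LSTM state tuples and a single tensor that can be run through
# `session.run` and fed back via a placeholder, as in `predict` above. A
# minimal sketch for a multi-layer LSTM state; the [batch, 2 * num_layers,
# num_units] stacking layout is an assumption of this sketch:
def pack_state_tuple(state, name=None):
    # Stack every (c, h) pair of every layer into one tensor
    parts = [t for layer in state for t in (layer.c, layer.h)]
    return tf.stack(parts, axis=1, name=name)


def unpack_state_tensor(state_tensor):
    # Inverse of `pack_state_tuple`: rebuild the tuple of LSTMStateTuples
    parts = tf.unstack(state_tensor, axis=1)
    return tuple(
        tf.contrib.rnn.LSTMStateTuple(c=parts[i], h=parts[i + 1])
        for i in range(0, len(parts), 2))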
def setup_network(self):
    # Setup character embedding (defines `self.embedding_matrix`)
    with tf.device('/cpu:0'), tf.variable_scope(name_or_scope='embedding'):
        self.embedding_matrix = tf.get_variable(
            shape=[self.config.alphabet_size, self.config.embedding_size],
            initializer=tf.contrib.layers.xavier_initializer(),
            name='W')

        # Gather slices from `params` according to `indices`
        embedded_decoder_input = tf.nn.embedding_lookup(
            self.embedding_matrix, self.decoder_input_chars, name='dec_input')

    def embed_func(input_chars):
        return tf.gather(self.embedding_matrix, input_chars)

    # Output projection
    with tf.variable_scope('alphabet_projection') as scope:
        self.projection_W, self.projection_b = intialize_projections(
            input_size=self.config.num_units,
            output_size=self.config.alphabet_size,
            scope=scope)

    # Define alphabet projection function
    def project_func(output):
        return projection(output, W=self.projection_W, b=self.projection_b)

    # Define initial state as zero states
    self.enc_final_state = self.decoder.cell.zero_state(
        batch_size=tf.shape(embedded_decoder_input)[0], dtype=tf.float32)

    # Define decoder
    with tf.variable_scope('decoder'):
        dec_outputs, dec_final_state = self.decoder.decode(
            inputs=embedded_decoder_input,
            initial_state=self.enc_final_state,
            seq_length=self.decoder_sequence_length,
            embed_func=embed_func,
            project_func=project_func)

    # Project output to alphabet size and reshape
    dec_outputs = tf.reshape(dec_outputs, [-1, self.config.num_units])
    dec_outputs = projection(dec_outputs,
                             W=self.projection_W,
                             b=self.projection_b)
    dec_outputs = tf.reshape(dec_outputs, [
        -1, self.config.max_dec_seq_length + 1, self.config.alphabet_size
    ])

    # self.packed_dec_final_state = pack_state_tuple(dec_final_state)

    if self.prediction_mode:
        dec_outputs = self.decoder_logits

    # Define loss
    self.setup_losses(dec_outputs=dec_outputs,
                      target_chars=self.target_chars,
                      decoder_sequence_length=self.decoder_sequence_length)

    if self.prediction_mode:
        # Pack state to tensor
        self.enc_final_state_tensor = pack_state_tuple(self.enc_final_state)

        # Look up inputs
        decoder_inputs_embedded = tf.nn.embedding_lookup(
            self.embedding_matrix, self.decoder_inputs, name='decoder_input')

        # Unpack state
        initial_state = unpack_state_tensor(self.decoder_state)

        with tf.variable_scope('decoder', reuse=True):
            decoder_output, decoder_final_state = self.decoder.predict(
                inputs=decoder_inputs_embedded, initial_state=initial_state)

        # Project output to alphabet size
        self.decoder_output = projection(decoder_output,
                                         W=self.projection_W,
                                         b=self.projection_b,
                                         name='decoder_output')
        self.decoder_probs = tf.nn.softmax(self.decoder_output,
                                           name='decoder_probs')

        self.probs_decay_parameter = tf.placeholder(
            tf.float64, shape=(), name='probs_decay_parameter')
        self.decoder_probs_decayed = tf.pow(
            tf.cast(self.decoder_probs, tf.float64),
            self.probs_decay_parameter)
        decoder_probs_sum = tf.expand_dims(
            tf.reduce_sum(self.decoder_probs_decayed, axis=1), axis=1)
        decoder_probs_sum = tf.tile(decoder_probs_sum,
                                    [1, self.config.alphabet_size])
        self.decoder_probs_decayed = self.decoder_probs_decayed / decoder_probs_sum

        # Pack state to tensor
        self.decoder_final_state = pack_state_tuple(
            decoder_final_state, name='decoder_final_state')
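# NOTE: `compute_decayed_probs`, used by the attention variants of
# `setup_network`, mirrors the inline probability-decay code above. A sketch
# consistent with that code; broadcasting replaces the explicit
# expand_dims/tile of the inline version:
def compute_decayed_probs(logits, decay_parameter_ph):
    # Softmax over the alphabet, raised to the decay power and renormalized
    probs = tf.nn.softmax(logits)
    decayed = tf.pow(tf.cast(probs, tf.float64), decay_parameter_ph)
    decayed_sum = tf.reduce_sum(decayed, axis=1, keep_dims=True)
    return decayed / decayed_sum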
def setup_network(self):
    # Setup character embedding
    embedded_encoder_input, embedded_decoder_input, embed_func = \
        self.setup_character_embedding()

    # Output projection
    with tf.variable_scope('alphabet_projection') as scope:
        self.projection_W, self.projection_b = intialize_projections(
            input_size=4 * self.config.num_units,
            output_size=self.config.alphabet_size,
            scope=scope)

    # Define alphabet projection function
    def project_func(output):
        return projection(output, W=self.projection_W, b=self.projection_b)

    # Encoder
    with tf.variable_scope('encoder') as scope:
        # Normalize batch
        embedded_encoder_input = tf.layers.batch_normalization(
            inputs=embedded_encoder_input,
            center=True,
            scale=True,
            # training=not self.prediction_mode,
            # This should always be True: in both training and inference the
            # entire question text is available.
            training=True,
            trainable=True,
        )
        enc_outputs, enc_final_state = self.encoder.encode(
            inputs=embedded_encoder_input,
            seq_lengths=self.encoder_sequence_length,
            enc_word_indices=self.enc_word_indices,
            word_seq_lengths=self.word_seq_lengths,
            max_words=self.config.max_words,
            scope=scope)

    # Predict question categories
    with tf.variable_scope('question') as scope:
        # Convert StateTuple to vector
        state_vector = tf.concat(flatten(enc_final_state),
                                 axis=1,
                                 name='combined-state-vec')

        # Add dense layer
        W, b = intialize_projections(
            input_size=4 * self.config.num_units * self.config.num_cells,
            output_size=128)
        layer = tf.nn.relu(tf.matmul(state_vector, W) + b)

        if self.add_dropout:
            layer = tf.nn.dropout(x=layer, keep_prob=self.keep_prob_ph)

        # Compute L2 weight-decay penalty
        W_penalty = tf.contrib.layers.apply_regularization(
            regularizer=tf.contrib.layers.l2_regularizer(
                scale=self.config.W_lambda),
            weights_list=[W])

        class_logits = projection(x=layer,
                                  input_size=128,
                                  output_size=self.config.num_classes)

    # Set decoder initial state and encoder outputs based on the binary
    # mode input value:
    # - If `self.is_lm_mode=0`, use the passed initial state from the encoder
    # - If `self.is_lm_mode=1`, use the zero vector
    self.enc_outputs, enc_final_state = select_decoder_inputs(
        is_lm_mode=self.is_lm_mode,
        enc_outputs=enc_outputs,
        initial_state=enc_final_state,
    )

    # If an observation has a class -> pass the true class as a one-hot
    # encoded vector to the decoder input.
    # If an observation doesn't have a class -> pass the predicted class
    # probabilities for the given observation to the decoder input.
    class_is_known = tf.greater_equal(self.class_idx, 0)

    # Create one-hot-encoded vectors
    class_one_hot = tf.one_hot(indices=self.class_idx,
                               depth=self.config.num_classes,
                               on_value=1.0,
                               off_value=0.0,
                               axis=-1,
                               dtype=tf.float32,
                               name='class-one-hot-encoded')

    # Compute class probabilities
    class_probs = tf.nn.softmax(class_logits)

    # Select what to pass on
    self.class_info_vec = tf.where(condition=class_is_known,
                                   x=class_one_hot,
                                   y=class_probs)

    # Concatenate class info vector with decoder input
    _class_info_vec = tf.expand_dims(self.class_info_vec, axis=1)
    _class_info_vec = tf.tile(
        _class_info_vec,
        multiples=[1, self.config.max_dec_seq_length + 1, 1])
    decoder_input = tf.concat([embedded_decoder_input, _class_info_vec],
                              axis=2)

    # Pack state to tensor
    self.enc_final_state_tensor = pack_state_tuple(enc_final_state)

    # Initialize decoder attention function using encoder outputs
    self.decoder.initialize_attention_func(
        input_size=decoder_input.get_shape().as_list()[-1],
        attention_states=self.enc_outputs)

    # Define decoder
    with tf.variable_scope('decoder'):
        dec_outputs, dec_final_state = self.decoder.decode(
            inputs=decoder_input,
            initial_state=enc_final_state,
            seq_length=self.decoder_sequence_length,
            embed_func=embed_func,
            project_func=project_func)

    # Project output to alphabet size and reshape
    dec_outputs = tf.reshape(dec_outputs, [-1, 4 * self.config.num_units])
    dec_outputs = projection(dec_outputs,
                             W=self.projection_W,
                             b=self.projection_b)
    dec_outputs = tf.reshape(dec_outputs, [
        -1, self.config.max_dec_seq_length + 1, self.config.alphabet_size
    ])

    if self.prediction_mode:
        dec_outputs = self.decoder_logits

    # Define loss
    self.setup_losses(dec_outputs=dec_outputs,
                      target_chars=self.target_chars,
                      decoder_sequence_length=self.decoder_sequence_length,
                      class_probs=class_probs,
                      class_idx=self.class_idx,
                      class_is_known=class_is_known,
                      class_one_hot=class_one_hot,
                      W_penalty=W_penalty)

    if self.prediction_mode:
        # Define initial attention tensor
        self.initial_attention = self.decoder.attention_func(enc_final_state)

        # Look up inputs
        decoder_inputs_embedded = tf.nn.embedding_lookup(
            self.embedding_matrix, self.decoder_inputs, name='decoder_input')
        is_lm_mode_tensor = tf.to_float(
            tf.expand_dims(self.is_lm_mode, axis=1))
        decoder_inputs = tf.concat(
            [decoder_inputs_embedded, is_lm_mode_tensor], axis=1)

        # Concatenate class info vector
        decoder_inputs = tf.concat([decoder_inputs, self.class_info_vec],
                                   axis=1)

        # Unpack state
        initial_state = unpack_state_tensor(self.decoder_state)

        with tf.variable_scope('decoder', reuse=True):
            decoder_output, decoder_final_state, self.decoder_new_attention = \
                self.decoder.predict(inputs=decoder_inputs,
                                     initial_state=initial_state,
                                     attention_states=self.decoder_attention)

        # Project output to alphabet size
        self.decoder_output = projection(decoder_output,
                                         W=self.projection_W,
                                         b=self.projection_b,
                                         name='decoder_output')

        # Compute decayed probabilities
        self.decoder_probs_decayed = compute_decayed_probs(
            logits=self.decoder_output,
            decay_parameter_ph=self.probs_decay_parameter)

        # Pack state to tensor
        self.decoder_final_state = pack_state_tuple(
            decoder_final_state, name='decoder_final_state')
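# NOTE: `select_decoder_inputs` switches each observation between
# encoder-conditioned mode (`is_lm_mode=0`) and language-model mode
# (`is_lm_mode=1`). Only the initial-state behaviour is documented in the
# comments above; zeroing the encoder outputs in LM mode is an assumption of
# this sketch, not the original implementation:
def select_decoder_inputs(is_lm_mode, enc_outputs, initial_state):
    # is_lm_mode: int tensor of shape [batch_size] with values 0 or 1
    keep = 1.0 - tf.to_float(is_lm_mode)  # 1.0 when the encoder state is used

    # Zero out the encoder outputs for LM-mode observations (assumption)
    enc_outputs = enc_outputs * tf.reshape(keep, [-1, 1, 1])

    # Use the zero vector as initial state for LM-mode observations
    keep_state = tf.reshape(keep, [-1, 1])
    initial_state = tuple(
        tf.contrib.rnn.LSTMStateTuple(c=layer.c * keep_state,
                                      h=layer.h * keep_state)
        for layer in initial_state)
    return enc_outputs, initial_state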