def __build_context(self, params, encoder_results, keep_prob, device):
    """Build the context RNN that runs over the per-turn encoder states.

    The second element (state) of every ``encoder_results`` pair is stacked
    along a new leading axis, so the context RNN runs time-major with one
    step per turn.

    Args:
      params: hyper-parameters; reads ``context_type`` ("uni" or "bi"),
        ``cell_type`` and ``decoder_hidden_units``.
      encoder_results: sequence of ``(outputs, state)`` pairs, one per turn;
        only the states are used.
      keep_prob: input keep probability passed to the RNN cell(s).
      device: device the context cells and RNN ops are placed on.

    Returns:
      ``(context_outputs, context_state)``. For "uni" these come straight
      from ``tf.nn.dynamic_rnn``. For "bi", ``context_outputs`` is the
      ``(fw, bw)`` outputs pair from ``tf.nn.bidirectional_dynamic_rnn`` and
      ``context_state`` is the forward and backward final states
      concatenated along axis 1.

    Raises:
      ValueError: if ``params.context_type`` is neither "uni" nor "bi".
    """
    with variable_scope.variable_scope("context") as scope:
        # Every example contributes exactly ``num_turns`` context steps.
        context_seq_length = tf.fill([self.batch_size], self.num_turns)
        with tf.device(device):
            context_inputs = tf.stack(
                [state for _, state in encoder_results], axis=0)
            if params.context_type == "uni":
                cell = rnn_factory.create_cell(params.cell_type,
                                               params.decoder_hidden_units,
                                               num_layers=1,
                                               input_keep_prob=keep_prob,
                                               devices=[device])
                context_outputs, context_state = tf.nn.dynamic_rnn(
                    cell,
                    inputs=context_inputs,
                    sequence_length=context_seq_length,
                    time_major=True,
                    dtype=scope.dtype,
                    swap_memory=True)
                return context_outputs, context_state
            elif params.context_type == "bi":
                fw_cell = rnn_factory.create_cell(
                    params.cell_type,
                    params.decoder_hidden_units,
                    num_layers=1,
                    input_keep_prob=keep_prob,
                    devices=[device])
                bw_cell = rnn_factory.create_cell(
                    params.cell_type,
                    params.decoder_hidden_units,
                    num_layers=1,
                    input_keep_prob=keep_prob,
                    devices=[device])
                context_outputs, context_state = tf.nn.bidirectional_dynamic_rnn(
                    fw_cell,
                    bw_cell,
                    context_inputs,
                    sequence_length=context_seq_length,
                    time_major=True,
                    dtype=scope.dtype,
                    swap_memory=True)
                fw_state, bw_state = context_state
                # Join the two directions' final states on axis 1.
                context_state = tf.concat([fw_state, bw_state], axis=1)
                return context_outputs, context_state
            else:
                # Report (and read) the option that was actually dispatched on.
                raise ValueError(
                    "Unknown context type: %s" % params.context_type)
def __build_decoder_cell(self, params, mode, encoder_outputs, encoder_state,
                         keep_prob):
    """Build the decoder cell, wrapping it with attention when available.

    In inference mode with beam search (``params.beam_width > 0``) the
    encoder state, encoder outputs and source lengths are tiled
    ``beam_width`` times with ``tf.contrib.seq2seq.tile_batch``.

    Args:
      params: hyper-parameters; reads ``cell_type``, ``hidden_units``,
        ``beam_width`` and ``attention_type``.
      mode: a ``tf.contrib.learn.ModeKeys`` value.
      encoder_outputs: encoder outputs, used as the attention memory.
      encoder_state: encoder final state, used as the initial cell state.
      keep_prob: input keep probability passed to the cell.

    Returns:
      ``(cell, decoder_initial_state)``. If
      ``attention_helper.create_attention_mechanism`` raises ``ValueError``,
      the plain cell and the (possibly tiled) encoder state are returned.
      Otherwise ``cell`` is an ``AttentionWrapper`` named
      "vanilla_attention" and the state is its zero state with
      ``cell_state`` replaced by the encoder state.
    """
    cell = rnn_factory.create_cell(params.cell_type, params.hidden_units,
                                   self.num_layers,
                                   input_keep_prob=keep_prob,
                                   devices=self.round_robin.assign(
                                       self.num_layers))
    if mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
        # Beam search: replicate every batch entry ``beam_width`` times.
        batch_size = self.batch_size * params.beam_width
        decoder_initial_state = tf.contrib.seq2seq.tile_batch(
            encoder_state, multiplier=params.beam_width)
        memory = tf.contrib.seq2seq.tile_batch(
            encoder_outputs, multiplier=params.beam_width)
        source_sequence_length = tf.contrib.seq2seq.tile_batch(
            self.iterator.source_sequence_lengths,
            multiplier=params.beam_width)
    else:
        batch_size = self.batch_size
        decoder_initial_state = encoder_state
        memory = encoder_outputs
        source_sequence_length = self.iterator.source_sequence_lengths

    # Only the mechanism factory is guarded. A ValueError from it means
    # "build without attention" (presumably an unsupported or disabled
    # attention_type; confirm in attention_helper). A ValueError raised while
    # wrapping the cell indicates a real bug and must not be swallowed.
    try:
        attn_mechanism = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, memory,
            source_sequence_length)
    except ValueError:
        return cell, decoder_initial_state

    # NOTE(review): this reads ``self.mode`` while the tiling above reads the
    # ``mode`` argument; they presumably agree. Confirm with callers.
    alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER
                         and params.beam_width == 0)
    cell = tf.contrib.seq2seq.AttentionWrapper(
        cell,
        attention_mechanism=attn_mechanism,
        attention_layer_size=params.hidden_units,
        alignment_history=alignment_history,
        output_attention=True,
        name="vanilla_attention")
    decoder_initial_state = cell.zero_state(
        batch_size, self.dtype).clone(cell_state=decoder_initial_state)
    return cell, decoder_initial_state
def __build_context(self, params, encoder_results, keep_prob, device):
    """Encode the sequence of turn-level encoder states with a context RNN.

    Args:
      params: reads ``context_direction`` (``'backward'`` feeds the turns
        last-to-first), ``cell_type`` and ``hidden_units``.
      encoder_results: sequence of ``(outputs, state)`` pairs, one per turn;
        only the second element of each pair is used.
      keep_prob: input keep probability passed to the cell.
      device: device for the context ops and the cell.

    Returns:
      ``(context_outputs, context_state)`` from a time-major
      ``tf.nn.dynamic_rnn`` in which every sequence has ``num_turns`` steps.
    """
    with variable_scope.variable_scope("context"):
        with tf.device(device):
            turn_lengths = tf.fill([self.batch_size], self.num_turns)

            if params.context_direction == 'backward':
                ordered_turns = reversed(encoder_results)
            else:
                ordered_turns = encoder_results
            # Stack the per-turn states along a new leading (time) axis.
            stacked_states = tf.stack(
                [turn_state for _, turn_state in ordered_turns], axis=0)

            context_cell = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                num_layers=1,
                input_keep_prob=keep_prob,
                devices=[device])
            outputs, final_state = tf.nn.dynamic_rnn(
                context_cell,
                inputs=stacked_states,
                sequence_length=turn_lengths,
                time_major=True,
                dtype=self.dtype,
                swap_memory=True)
            return outputs, final_state
def __build_decoder_cell(self, params, context_outputs, context_state,
                         input_keep_prob, device):
    """Build a decoder cell with joint context and topic attention.

    Two attention mechanisms are attached through one AttentionWrapper. One
    attends over the context RNN outputs. The other attends over the
    topic-word embeddings, each concatenated with the context final state.

    Args:
      params: hyper-parameters; reads ``cell_type``, ``hidden_units``,
        ``beam_width`` and ``attention_type``.
      context_outputs: time-major context RNN outputs; transposed to
        batch-major below before being used as attention memory.
      context_state: context RNN final state; repeated across topic
        positions and used as the decoder's initial cell state.
      input_keep_prob: input keep probability passed to the cell.
      device: device the base cell is placed on.

    Returns:
      ``(cell, decoder_initial_state)`` where ``cell`` is the
      ``AttentionWrapper`` named "joint_attention" and the state is its zero
      state with ``cell_state`` replaced by ``context_state`` (tiled
      ``beam_width`` times for beam search).
    """
    cell = rnn_factory.create_cell(params.cell_type, params.hidden_units,
                                   num_layers=1,
                                   input_keep_prob=input_keep_prob,
                                   devices=[device])
    topical_embeddings = tf.nn.embedding_lookup(self.embeddings,
                                                self.iterator.topic)
    max_topic_length = tf.reduce_max(self.iterator.topic_sequence_length)
    # Repeat the context state once per topic position and append it to each
    # topic word embedding (concat on axis 2).
    # NOTE(review): expand_dims assumes context_state is a single tensor, not
    # a state tuple. Confirm for the configured cell_type.
    expanded_context_state = tf.tile(tf.expand_dims(context_state, axis=1),
                                     [1, max_topic_length, 1])
    topical_embeddings = tf.concat(
        [expanded_context_state, topical_embeddings], axis=2)
    # Every example has exactly ``num_turns`` context steps.
    context_sequence_length = tf.fill([self.batch_size], self.num_turns)
    # Swap the first two axes so the context outputs become batch-major
    # attention memory.
    batch_majored_context_outputs = tf.transpose(context_outputs, [1, 0, 2])
    if self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
        # Beam search: replicate every per-example tensor beam_width times.
        batch_size = self.batch_size * params.beam_width
        decoder_initial_state = tf.contrib.seq2seq.tile_batch(
            context_state, multiplier=params.beam_width)
        memory = tf.contrib.seq2seq.tile_batch(
            batch_majored_context_outputs, multiplier=params.beam_width)
        topical_embeddings = tf.contrib.seq2seq.tile_batch(
            topical_embeddings, multiplier=params.beam_width)
        context_sequence_length = tf.contrib.seq2seq.tile_batch(
            context_sequence_length, multiplier=params.beam_width)
        topic_sequence_length = tf.contrib.seq2seq.tile_batch(
            self.iterator.topic_sequence_length, multiplier=params.beam_width)
    else:
        batch_size = self.batch_size
        decoder_initial_state = context_state
        memory = batch_majored_context_outputs
        topic_sequence_length = self.iterator.topic_sequence_length
    # Mechanism over the context RNN outputs.
    context_attention = attention_helper.create_attention_mechanism(
        params.attention_type, params.hidden_units, memory,
        context_sequence_length)
    # Mechanism over the (context-state-augmented) topic embeddings.
    topical_attention = attention_helper.create_attention_mechanism(
        params.attention_type, params.hidden_units, topical_embeddings,
        topic_sequence_length)
    # Alignment history is recorded only for inference without beam search.
    alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width == 0
    cell = tf.contrib.seq2seq.AttentionWrapper(
        cell,
        attention_mechanism=(context_attention, topical_attention),
        attention_layer_size=(params.hidden_units, params.hidden_units),
        alignment_history=alignment_history,
        output_attention=True,
        name="joint_attention")
    decoder_initial_state = cell.zero_state(
        batch_size, self.dtype).clone(cell_state=decoder_initial_state)
    return cell, decoder_initial_state
def __build_encoder(self, params, keep_prob, device):
    """Encode every dialogue turn with a turn-level RNN.

    The cells are created once and reused for all turns. Unless
    ``params.disable_encoder_var_sharing`` is set, all turns run in the
    single variable scope "encoder" with reuse enabled after turn 0.
    Otherwise each turn gets its own scope "encoder<t>".

    Args:
      params: reads ``encoder_type`` ("uni" or "bi"), ``cell_type``,
        ``hidden_units`` and ``disable_encoder_var_sharing``.
      keep_prob: input keep probability passed to the cells.
      device: device passed to the cell factory. The per-turn ops are
        placed on devices from ``self.round_robin.assign(self.num_turns)``.

    Returns:
      A list with one ``(outputs, state)`` pair per turn. For "bi",
      ``outputs`` is the ``(fw, bw)`` pair and ``state`` is the forward and
      backward final states concatenated along axis 1.

    Raises:
      ValueError: if ``params.encoder_type`` is neither "uni" nor "bi".
    """
    def make_cell():
        # Single-layer cell with the shared encoder hyper-parameters.
        return rnn_factory.create_cell(
            params.cell_type,
            params.hidden_units,
            num_layers=1,
            input_keep_prob=keep_prob,
            devices=[device])

    bidirectional = params.encoder_type == "bi"
    if params.encoder_type == "uni":
        log.print_out(" build unidirectional encoder")
        uni_cell = make_cell()
    elif bidirectional:
        log.print_out(" build bidirectional encoder")
        fw_cell = make_cell()
        bw_cell = make_cell()
    else:
        raise ValueError("Unknown encoder type: '%s'" % params.encoder_type)

    turn_devices = self.round_robin.assign(self.num_turns)
    turn_results = []
    for turn in range(self.num_turns):
        share_variables = not params.disable_encoder_var_sharing
        scope_name = "encoder" if share_variables else "encoder%d" % turn
        with variable_scope.variable_scope(scope_name) as scope:
            # With sharing on, later turns reuse the first turn's variables.
            if share_variables and turn > 0:
                scope.reuse_variables()
            with tf.device(turn_devices[turn]):
                embedded = tf.nn.embedding_lookup(
                    params=self.embeddings, ids=self.iterator.sources[turn])
                if bidirectional:
                    outputs, (fw_state, bw_state) = (
                        tf.nn.bidirectional_dynamic_rnn(
                            fw_cell,
                            bw_cell,
                            inputs=embedded,
                            dtype=self.dtype,
                            sequence_length=self.iterator.
                            source_sequence_lengths[turn],
                            swap_memory=True))
                    state = tf.concat([fw_state, bw_state], axis=1)
                else:
                    outputs, state = tf.nn.dynamic_rnn(
                        uni_cell,
                        inputs=embedded,
                        sequence_length=self.iterator.
                        source_sequence_lengths[turn],
                        dtype=self.dtype,
                        swap_memory=True,
                        scope=scope)
                turn_results.append((outputs, state))
    return turn_results
def __build_decoder_cell(self, params, encoder_outputs, encoder_state,
                         keep_prob):
    """Build a decoder cell with joint message and topic attention.

    Attention runs over (a) the encoder outputs and (b) the topic-word
    embeddings. Each topic embedding is concatenated with the flattened
    encoder final state.

    Args:
      params: hyper-parameters; reads ``cell_type``, ``hidden_units``,
        ``beam_width`` and ``attention_type``.
      encoder_outputs: encoder outputs, or a tuple of per-direction outputs
        such as the ``(fw, bw)`` pair from ``tf.nn.bidirectional_dynamic_rnn``.
      encoder_state: encoder final state, or a tuple of per-layer /
        per-direction states; used as the decoder's initial cell state.
      keep_prob: input keep probability passed to the cell.

    Returns:
      ``(cell, decoder_initial_state)``: the ``AttentionWrapper`` named
      "joint_attention" and its zero state with ``cell_state`` replaced by
      ``encoder_state`` (tiled ``beam_width`` times for beam search).
    """
    cell = rnn_factory.create_cell(params.cell_type, params.hidden_units,
                                   self.num_layers,
                                   input_keep_prob=keep_prob,
                                   devices=self.round_robin.assign(
                                       self.num_layers))
    topical_embeddings = tf.nn.embedding_lookup(self.embeddings,
                                                self.iterator.topic)
    max_topic_length = tf.reduce_max(self.iterator.topic_sequence_length)

    # Flatten a tuple state into one vector per example so it can be appended
    # to every topic embedding.
    if isinstance(encoder_state, tuple):
        aggregated_state = tf.concat(list(encoder_state), axis=1)
    else:
        aggregated_state = encoder_state

    if isinstance(encoder_outputs, tuple):
        # Per-direction outputs are assumed batch-major [batch, time, units]
        # (the tf.nn.bidirectional_dynamic_rnn default). They must be joined
        # on the feature axis. Joining on axis 1 (time) would double the memory
        # length, and because the attention mask uses the source lengths, it
        # would hide every backward output from attention.
        # NOTE: this changes the memory depth, so the attention memory layer
        # has a different shape than checkpoints trained with the axis-1 join.
        encoder_outputs = tf.concat(list(encoder_outputs), axis=-1)

    # Repeat the state once per topic position and append it to each topic
    # word embedding.
    expanded_encoder_state = tf.tile(
        tf.expand_dims(aggregated_state, axis=1), [1, max_topic_length, 1])
    topical_embeddings = tf.concat(
        [expanded_encoder_state, topical_embeddings], axis=2)

    if self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
        # Beam search: replicate every per-example tensor beam_width times.
        batch_size = self.batch_size * params.beam_width
        if isinstance(encoder_state, tuple):
            decoder_initial_state = tuple(
                tf.contrib.seq2seq.tile_batch(state,
                                              multiplier=params.beam_width)
                for state in encoder_state)
        else:
            decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=params.beam_width)
        memory = tf.contrib.seq2seq.tile_batch(
            encoder_outputs, multiplier=params.beam_width)
        topical_embeddings = tf.contrib.seq2seq.tile_batch(
            topical_embeddings, multiplier=params.beam_width)
        source_sequence_length = tf.contrib.seq2seq.tile_batch(
            self.iterator.source_sequence_lengths,
            multiplier=params.beam_width)
        topic_sequence_length = tf.contrib.seq2seq.tile_batch(
            self.iterator.topic_sequence_length,
            multiplier=params.beam_width)
    else:
        batch_size = self.batch_size
        decoder_initial_state = encoder_state
        memory = encoder_outputs
        source_sequence_length = self.iterator.source_sequence_lengths
        topic_sequence_length = self.iterator.topic_sequence_length

    message_attention = attention_helper.create_attention_mechanism(
        params.attention_type, params.hidden_units, memory,
        source_sequence_length)
    topical_attention = attention_helper.create_attention_mechanism(
        params.attention_type, params.hidden_units, topical_embeddings,
        topic_sequence_length)
    # Alignment history is recorded only for inference without beam search.
    alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER
                         and params.beam_width == 0)
    cell = tf.contrib.seq2seq.AttentionWrapper(
        cell,
        attention_mechanism=(message_attention, topical_attention),
        attention_layer_size=(params.hidden_units, params.hidden_units),
        alignment_history=alignment_history,
        output_attention=True,
        name="joint_attention")
    decoder_initial_state = cell.zero_state(
        batch_size, self.dtype).clone(cell_state=decoder_initial_state)
    return cell, decoder_initial_state
def __build_encoder(self, params, keep_prob):
    """Build the single-sequence encoder over ``self.iterator.sources``.

    Args:
      params: reads ``encoder_type`` ("uni" or "bi"), ``num_layers``,
        ``cell_type`` and ``hidden_units``.
      keep_prob: input keep probability passed to the cells.

    Returns:
      ``(encoder_outputs, encoder_state)``. For "uni", straight from
      ``tf.nn.dynamic_rnn`` over ``self.num_layers`` layers. For "bi",
      ``encoder_outputs`` is the ``(fw, bw)`` pair. ``encoder_state`` is the
      raw ``(fw, bw)`` state pair when each direction has exactly one layer.
      Otherwise it is a tuple interleaving forward and backward states layer
      by layer.

    Raises:
      ValueError: if ``params.encoder_type`` is neither "uni" nor "bi".
    """
    with variable_scope.variable_scope("encoder"):
        source = self.iterator
        embedded = tf.nn.embedding_lookup(
            params=self.embeddings, ids=source.sources)

        kind = params.encoder_type
        if kind not in ("uni", "bi"):
            raise ValueError("Unknown encoder type: %s" % params.encoder_type)

        if kind == "uni":
            # NOTE(review): logs params.num_layers but builds self.num_layers
            # layers; confirm the two always agree.
            log.print_out(
                " build unidirectional encoder num_layers = %d" %
                params.num_layers)
            uni_cell = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                self.num_layers,
                input_keep_prob=keep_prob,
                devices=self.round_robin.assign(self.num_layers))
            outputs, state = tf.nn.dynamic_rnn(
                uni_cell,
                inputs=embedded,
                sequence_length=source.source_sequence_lengths,
                dtype=self.dtype,
                swap_memory=True)
            return outputs, state

        # Bidirectional: split the layer budget between the two directions.
        per_direction = int(params.num_layers / 2)
        log.print_out(" build bidirectional encoder num_layers = %d" %
                      params.num_layers)
        fw_cell = rnn_factory.create_cell(
            params.cell_type,
            params.hidden_units,
            per_direction,
            input_keep_prob=keep_prob,
            devices=self.round_robin.assign(per_direction))
        bw_cell = rnn_factory.create_cell(
            params.cell_type,
            params.hidden_units,
            per_direction,
            input_keep_prob=keep_prob,
            devices=self.round_robin.assign(
                per_direction,
                self.device_manager.num_available_gpus() - 1))
        outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
            fw_cell,
            bw_cell,
            embedded,
            dtype=self.dtype,
            sequence_length=source.source_sequence_lengths,
            swap_memory=True)
        if per_direction == 1:
            return outputs, bi_state

        fw_states, bw_states = bi_state
        # Interleave: fw layer 0, bw layer 0, fw layer 1, bw layer 1, ...
        interleaved = tuple(
            layer_state
            for layer_id in range(per_direction)
            for layer_state in (fw_states[layer_id], bw_states[layer_id]))
        return outputs, interleaved
def __build_decoder(self, params, mode, context_outputs, context_state,
                    input_keep_prob, device):
    """Build the response decoder on top of the context RNN.

    When ``mode`` is not INFER, ``iterator.target_input`` is decoded with a
    TrainingHelper, or a ScheduledEmbeddingTrainingHelper when
    ``self.sampling_probability`` is non-zero. Inference uses beam search
    when ``params.beam_width > 0`` and greedy decoding otherwise.

    Args:
      params: reads ``cell_type``, ``decoder_hidden_units``, ``beam_width``
        and ``length_penalty_weight``.
      mode: a ``tf.contrib.learn.ModeKeys`` value.
      context_outputs: context RNN outputs; unused, kept so the signature is
        unchanged.
      context_state: context RNN final state; seeds the decoder.
      input_keep_prob: input keep probability passed to the decoder cell.
      device: device the decoder cell and ops are placed on.

    Returns:
      ``(logits, sample_id, final_decoder_state)``. With beam search,
      ``logits`` is ``tf.no_op()`` and ``sample_id`` is the beam
      ``predicted_ids``.
    """
    iterator = self.iterator
    decoder_cell = rnn_factory.create_cell(params.cell_type,
                                           params.decoder_hidden_units,
                                           num_layers=1,
                                           input_keep_prob=input_keep_prob,
                                           devices=[device])
    with variable_scope.variable_scope("decoder") as scope:
        with tf.device(device):
            # Seed the decoder with the context RNN's final *state* rather
            # than its last output. For a time-major context RNN whose
            # sequences all run num_turns steps, the two coincide only when a
            # cell's state equals its output (e.g. GRU). For cells like LSTM,
            # only the state has the structure the decoder cell expects.
            # Assumes the context RNN uses the same cell_type and
            # decoder_hidden_units as this cell.
            initial_state = context_state
            if mode != tf.contrib.learn.ModeKeys.INFER:
                # decoder_emp_inp: [max_time, batch_size, num_units]
                decoder_emb_inp = tf.nn.embedding_lookup(
                    self.embeddings, iterator.target_input)

                # Helper
                if self.sampling_probability == 0.0:
                    helper = tf.contrib.seq2seq.TrainingHelper(
                        decoder_emb_inp, iterator.target_sequence_length)
                else:
                    helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                        decoder_emb_inp, iterator.target_sequence_length,
                        self.embeddings, self.sampling_probability)

                # Decoder
                my_decoder = tf.contrib.seq2seq.BasicDecoder(
                    decoder_cell, helper, initial_state)

                # Dynamic decoding
                outputs, final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder, swap_memory=True, scope=scope)

                sample_id = outputs.sample_id

                # Note: there's a subtle difference here between train and inference.
                # We could have set output_layer when create my_decoder
                # and shared more code between train and inference.
                # We chose to apply the output_layer to all timesteps for speed:
                # 10% improvements for small models & 20% for larger ones.
                # If memory is a concern, we should apply output_layer per timestep.
                logits = self.output_layer(outputs.rnn_output)

            ### Inference
            else:
                beam_width = params.beam_width
                start_tokens = tf.fill([self.batch_size], vocab.SOS_ID)
                end_token = vocab.EOS_ID

                maximum_iterations = self._get_decoder_max_iterations(params)

                if beam_width > 0:
                    # Every beam starts from its own copy of the context state.
                    initial_state = tf.contrib.seq2seq.tile_batch(
                        context_state, multiplier=params.beam_width)

                    my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=self.embeddings,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=initial_state,
                        beam_width=beam_width,
                        output_layer=self.output_layer,
                        length_penalty_weight=params.length_penalty_weight)
                else:
                    # Helper
                    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                        self.embeddings, start_tokens, end_token)

                    # Decoder
                    my_decoder = tf.contrib.seq2seq.BasicDecoder(
                        decoder_cell,
                        helper,
                        initial_state,
                        output_layer=self.output_layer  # applied per timestep
                    )

                # Dynamic decoding
                outputs, final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    maximum_iterations=maximum_iterations,
                    swap_memory=True,
                    scope=scope)

                if beam_width > 0:
                    logits = tf.no_op()
                    sample_id = outputs.predicted_ids
                else:
                    logits = outputs.rnn_output
                    sample_id = outputs.sample_id

    return logits, sample_id, final_decoder_state