def _build_decoder(
    self,
    model,
    step_model,
    model_params,
    scope,
    previous_tokens,
    timestep,
    fake_seq_lengths,
):
    """Build the per-step decoder net used for beam-search inference.

    Runs the multi-layer encoder on `model`, then adds a single decoder step
    to `step_model` that consumes `previous_tokens` at `timestep`. Returns
    `(state_configs, output_log_probs, attention_weights)` for
    BeamSearchForwardOnly.
    """
    attention_type = model_params['attention']
    assert attention_type in ['none', 'regular']
    use_attention = (attention_type != 'none')

    with core.NameScope(scope):
        encoder_embeddings = seq2seq_util.build_embeddings(
            model=model,
            vocab_size=self.source_vocab_size,
            embedding_size=model_params['encoder_embedding_size'],
            name='encoder_embeddings',
            freeze_embeddings=False,
        )

    (
        encoder_outputs,
        weighted_encoder_outputs,
        final_encoder_hidden_states,
        final_encoder_cell_states,
        encoder_units_per_layer,
    ) = seq2seq_util.build_embedding_encoder(
        model=model,
        encoder_params=model_params['encoder_type'],
        num_decoder_layers=len(model_params['decoder_layer_configs']),
        inputs=self.encoder_inputs,
        input_lengths=self.encoder_lengths,
        vocab_size=self.source_vocab_size,
        embeddings=encoder_embeddings,
        embedding_size=model_params['encoder_embedding_size'],
        use_attention=use_attention,
        num_gpus=0,
        forward_only=True,
        scope=scope,
    )

    with core.NameScope(scope):
        # Tile encoder-side blobs along axis 1 so every beam hypothesis
        # attends over the same source encoding.
        if use_attention:
            # [max_source_length, beam_size, encoder_output_dim]
            encoder_outputs = model.net.Tile(
                encoder_outputs,
                'encoder_outputs_tiled',
                tiles=self.beam_size,
                axis=1,
            )

        if weighted_encoder_outputs is not None:
            weighted_encoder_outputs = model.net.Tile(
                weighted_encoder_outputs,
                'weighted_encoder_outputs_tiled',
                tiles=self.beam_size,
                axis=1,
            )

        decoder_embeddings = seq2seq_util.build_embeddings(
            model=model,
            vocab_size=self.target_vocab_size,
            embedding_size=model_params['decoder_embedding_size'],
            name='decoder_embeddings',
            freeze_embeddings=False,
        )
        embedded_tokens_t_prev = step_model.net.Gather(
            [decoder_embeddings, previous_tokens],
            'embedded_tokens_t_prev',
        )

    # One LSTM cell per configured decoder layer; each layer's input size is
    # the previous layer's number of units.
    decoder_cells = []
    decoder_units_per_layer = []
    for i, layer_config in enumerate(model_params['decoder_layer_configs']):
        num_units = layer_config['num_units']
        decoder_units_per_layer.append(num_units)
        if i == 0:
            input_size = model_params['decoder_embedding_size']
        else:
            input_size = (
                model_params['decoder_layer_configs'][i - 1]['num_units']
            )

        cell = rnn_cell.LSTMCell(
            name=seq2seq_util.get_layer_scope(scope, 'decoder', i),
            forward_only=True,
            input_size=input_size,
            hidden_size=num_units,
            forget_bias=0.0,
            memory_optimization=False,
        )
        decoder_cells.append(cell)

    with core.NameScope(scope):
        if final_encoder_hidden_states is not None:
            for i in range(len(final_encoder_hidden_states)):
                if final_encoder_hidden_states[i] is not None:
                    final_encoder_hidden_states[i] = model.net.Tile(
                        final_encoder_hidden_states[i],
                        'final_encoder_hidden_tiled_{}'.format(i),
                        tiles=self.beam_size,
                        axis=1,
                    )
        if final_encoder_cell_states is not None:
            for i in range(len(final_encoder_cell_states)):
                if final_encoder_cell_states[i] is not None:
                    final_encoder_cell_states[i] = model.net.Tile(
                        final_encoder_cell_states[i],
                        'final_encoder_cell_tiled_{}'.format(i),
                        tiles=self.beam_size,
                        axis=1,
                    )
        initial_states = seq2seq_util.build_initial_rnn_decoder_states(
            model=model,
            encoder_units_per_layer=encoder_units_per_layer,
            decoder_units_per_layer=decoder_units_per_layer,
            final_encoder_hidden_states=final_encoder_hidden_states,
            final_encoder_cell_states=final_encoder_cell_states,
            use_attention=use_attention,
        )

    attention_decoder = seq2seq_util.LSTMWithAttentionDecoder(
        encoder_outputs=encoder_outputs,
        encoder_output_dim=encoder_units_per_layer[-1],
        encoder_lengths=None,
        vocab_size=self.target_vocab_size,
        attention_type=attention_type,
        embedding_size=model_params['decoder_embedding_size'],
        decoder_num_units=decoder_units_per_layer[-1],
        decoder_cells=decoder_cells,
        weighted_encoder_outputs=weighted_encoder_outputs,
        name=scope,
    )
    states_prev = step_model.net.AddExternalInputs(*[
        '{}/{}_prev'.format(scope, s)
        for s in attention_decoder.get_state_names()
    ])
    decoder_outputs, states = attention_decoder.apply(
        model=step_model,
        input_t=embedded_tokens_t_prev,
        seq_lengths=fake_seq_lengths,
        states=states_prev,
        timestep=timestep,
    )
    # Link every recurrent state to its previous-step alias so the beam
    # search can carry decoder state from one step to the next.
    state_configs = [
        BeamSearchForwardOnly.StateConfig(
            initial_value=initial_state,
            state_prev_link=BeamSearchForwardOnly.LinkConfig(
                blob=state_prev,
                offset=0,
                window=1,
            ),
            state_link=BeamSearchForwardOnly.LinkConfig(
                blob=state,
                offset=1,
                window=1,
            ),
        )
        for initial_state, state_prev, state in zip(
            initial_states,
            states_prev,
            states,
        )
    ]

    with core.NameScope(scope):
        # Project the decoder output to vocabulary logits, then take
        # log-probabilities for beam-search scoring.
        decoder_outputs_flattened, _ = step_model.net.Reshape(
            [decoder_outputs],
            [
                'decoder_outputs_flattened',
                'decoder_outputs_and_contexts_combination_old_shape',
            ],
            shape=[-1, attention_decoder.get_output_dim()],
        )
        output_logits = seq2seq_util.output_projection(
            model=step_model,
            decoder_outputs=decoder_outputs_flattened,
            decoder_output_size=attention_decoder.get_output_dim(),
            target_vocab_size=self.target_vocab_size,
            decoder_softmax_size=model_params['decoder_softmax_size'],
        )
        # [1, beam_size, target_vocab_size]
        output_probs = step_model.net.Softmax(
            output_logits,
            'output_probs',
        )
        output_log_probs = step_model.net.Log(
            output_probs,
            'output_log_probs',
        )
        if use_attention:
            attention_weights = attention_decoder.get_attention_weights()
        else:
            # Without attention, substitute zero attention weights of the
            # expected shape.
            attention_weights = step_model.net.ConstantFill(
                [self.encoder_inputs],
                'zero_attention_weights_tmp_1',
                value=0.0,
            )
            attention_weights = step_model.net.Transpose(
                attention_weights,
                'zero_attention_weights_tmp_2',
            )
            attention_weights = step_model.net.Tile(
                attention_weights,
                'zero_attention_weights_tmp',
                tiles=self.beam_size,
                axis=0,
            )
    return (
        state_configs,
        output_log_probs,
        attention_weights,
    )
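
# A minimal sketch of the `model_params` dict that `_build_decoder` reads above.
# Only keys actually referenced by the method are listed; every value is a
# hypothetical placeholder rather than a recommended configuration, and the
# name `_EXAMPLE_MODEL_PARAMS` is introduced here purely for illustration.
_EXAMPLE_MODEL_PARAMS = {
    'attention': 'regular',  # 'none' disables the attention path
    'encoder_type': {},  # encoder-specific config, passed through to build_embedding_encoder
    'encoder_embedding_size': 256,
    'decoder_embedding_size': 256,
    'decoder_layer_configs': [  # one entry per decoder LSTM layer
        {'num_units': 512},
        {'num_units': 512},
    ],
    'decoder_softmax_size': None,  # forwarded to seq2seq_util.output_projection
}
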
def _build_decoder(
    self,
    model,
    step_model,
    model_params,
    scope,
    previous_tokens,
    timestep,
    fake_seq_lengths,
):
    """Build the per-step decoder net used for beam-search inference.

    Encodes the source on `model`, then adds a single decoder step (one LSTM
    layer, with optional attention) to `step_model`. Returns
    `(state_configs, output_log_probs, attention_weights)` for
    BeamSearchForwardOnly.
    """
    attention_type = model_params['attention']
    assert attention_type in ['none', 'regular']
    use_attention = (attention_type != 'none')

    with core.NameScope(scope):
        encoder_embeddings = seq2seq_util.build_embeddings(
            model=model,
            vocab_size=self.source_vocab_size,
            embedding_size=model_params['encoder_embedding_size'],
            name='encoder_embeddings',
            freeze_embeddings=False,
        )

    (
        encoder_outputs,
        weighted_encoder_outputs,
        final_encoder_hidden_state,
        final_encoder_cell_state,
        encoder_output_dim,
    ) = seq2seq_util.build_embedding_encoder(
        model=model,
        encoder_params=model_params['encoder_type'],
        inputs=self.encoder_inputs,
        input_lengths=self.encoder_lengths,
        vocab_size=self.source_vocab_size,
        embeddings=encoder_embeddings,
        embedding_size=model_params['encoder_embedding_size'],
        use_attention=use_attention,
        num_gpus=0,
        scope=scope,
    )

    with core.NameScope(scope):
        # [max_source_length, beam_size, encoder_output_dim]
        encoder_outputs = model.net.Tile(
            encoder_outputs,
            'encoder_outputs_tiled',
            tiles=self.beam_size,
            axis=1,
        )

        if weighted_encoder_outputs is not None:
            weighted_encoder_outputs = model.net.Tile(
                weighted_encoder_outputs,
                'weighted_encoder_outputs_tiled',
                tiles=self.beam_size,
                axis=1,
            )

        decoder_embeddings = seq2seq_util.build_embeddings(
            model=model,
            vocab_size=self.target_vocab_size,
            embedding_size=model_params['decoder_embedding_size'],
            name='decoder_embeddings',
            freeze_embeddings=False,
        )
        embedded_tokens_t_prev = step_model.net.Gather(
            [decoder_embeddings, previous_tokens],
            'embedded_tokens_t_prev',
        )

    decoder_num_units = (
        model_params['decoder_layer_configs'][0]['num_units']
    )

    with core.NameScope(scope):
        if not use_attention and final_encoder_hidden_state is not None:
            final_encoder_hidden_state = model.net.Tile(
                final_encoder_hidden_state,
                'final_encoder_hidden_state_tiled',
                tiles=self.beam_size,
                axis=1,
            )
        if not use_attention and final_encoder_cell_state is not None:
            final_encoder_cell_state = model.net.Tile(
                final_encoder_cell_state,
                'final_encoder_cell_state_tiled',
                tiles=self.beam_size,
                axis=1,
            )
        initial_states = seq2seq_util.build_initial_rnn_decoder_states(
            model=model,
            encoder_num_units=encoder_output_dim,
            decoder_num_units=decoder_num_units,
            final_encoder_hidden_state=final_encoder_hidden_state,
            final_encoder_cell_state=final_encoder_cell_state,
            use_attention=use_attention,
        )

    # With attention, the step output is the hidden state concatenated with
    # the attention context; without it, just the hidden state.
    if use_attention:
        decoder_cell = rnn_cell.LSTMWithAttentionCell(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            decoder_input_dim=model_params['decoder_embedding_size'],
            decoder_state_dim=decoder_num_units,
            name=self.scope(scope, 'decoder'),
            attention_type=attention.AttentionType.Regular,
            weighted_encoder_outputs=weighted_encoder_outputs,
            forget_bias=0.0,
            lstm_memory_optimization=False,
            attention_memory_optimization=True,
        )
        decoder_output_dim = decoder_num_units + encoder_output_dim
    else:
        decoder_cell = rnn_cell.LSTMCell(
            name=self.scope(scope, 'decoder'),
            input_size=model_params['decoder_embedding_size'],
            hidden_size=decoder_num_units,
            forget_bias=0.0,
            memory_optimization=False,
        )
        decoder_output_dim = decoder_num_units

    states_prev = step_model.net.AddExternalInputs(*[
        s + '_prev' for s in decoder_cell.get_state_names()
    ])
    _, states = decoder_cell.apply(
        model=step_model,
        input_t=embedded_tokens_t_prev,
        seq_lengths=fake_seq_lengths,
        states=states_prev,
        timestep=timestep,
    )
    if use_attention:
        with core.NameScope(scope or ''):
            decoder_outputs, _ = step_model.net.Concat(
                [states[0], states[2]],
                [
                    'states_and_context_combination',
                    '_states_and_context_combination_concat_dims',
                ],
                axis=2,
            )
    else:
        decoder_outputs = states[0]

    # Link every recurrent state to its previous-step alias so the beam
    # search can carry decoder state from one step to the next.
    state_configs = [
        BeamSearchForwardOnly.StateConfig(
            initial_value=initial_state,
            state_prev_link=BeamSearchForwardOnly.LinkConfig(
                blob=state_prev,
                offset=0,
                window=1,
            ),
            state_link=BeamSearchForwardOnly.LinkConfig(
                blob=state,
                offset=1,
                window=1,
            ),
        )
        for initial_state, state_prev, state in zip(
            initial_states,
            states_prev,
            states,
        )
    ]

    with core.NameScope(scope):
        decoder_outputs_flattened, _ = step_model.net.Reshape(
            [decoder_outputs],
            [
                'decoder_outputs_flattened',
                'decoder_outputs_and_contexts_combination_old_shape',
            ],
            shape=[-1, decoder_output_dim],
        )
        output_logits = seq2seq_util.output_projection(
            model=step_model,
            decoder_outputs=decoder_outputs_flattened,
            decoder_output_size=decoder_output_dim,
            target_vocab_size=self.target_vocab_size,
            decoder_softmax_size=model_params['decoder_softmax_size'],
        )
        # [1, beam_size, target_vocab_size]
        output_probs = step_model.net.Softmax(
            output_logits,
            'output_probs',
        )
        output_log_probs = step_model.net.Log(
            output_probs,
            'output_log_probs',
        )
        if use_attention:
            attention_weights = decoder_cell.get_attention_weights()
        else:
            # Without attention, substitute zero attention weights of the
            # expected shape.
            attention_weights = step_model.net.ConstantFill(
                [self.encoder_inputs],
                'zero_attention_weights_tmp_1',
                value=0.0,
            )
            attention_weights = step_model.net.Transpose(
                attention_weights,
                'zero_attention_weights_tmp_2',
            )
            attention_weights = step_model.net.Tile(
                attention_weights,
                'zero_attention_weights_tmp',
                tiles=self.beam_size,
                axis=0,
            )
    return (
        state_configs,
        output_log_probs,
        attention_weights,
    )
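
# For reference, one `state_configs` entry written out by hand. The list
# comprehensions above build one of these per recurrent decoder state; the
# string blob names below are hypothetical stand-ins (the real code passes the
# blob references created above). `offset=0, window=1` aliases the value the
# previous step wrote, while `offset=1, window=1` aliases the slot the current
# step writes into, which is presumably how BeamSearchForwardOnly threads
# decoder state across unrolled steps.
_example_state_config = BeamSearchForwardOnly.StateConfig(
    initial_value='decoder/initial_hidden_state',  # hypothetical blob name
    state_prev_link=BeamSearchForwardOnly.LinkConfig(
        blob='decoder/hidden_state_prev',  # hypothetical blob name
        offset=0,
        window=1,
    ),
    state_link=BeamSearchForwardOnly.LinkConfig(
        blob='decoder/hidden_state',  # hypothetical blob name
        offset=1,
        window=1,
    ),
)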