Example #1
    def _build_decoder(
        self,
        model,
        step_model,
        model_params,
        scope,
        previous_tokens,
        timestep,
        fake_seq_lengths,
    ):
        attention_type = model_params['attention']
        assert attention_type in ['none', 'regular']
        use_attention = (attention_type != 'none')

        with core.NameScope(scope):
            encoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.source_vocab_size,
                embedding_size=model_params['encoder_embedding_size'],
                name='encoder_embeddings',
                freeze_embeddings=False,
            )

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=model_params['encoder_type'],
            num_decoder_layers=len(model_params['decoder_layer_configs']),
            inputs=self.encoder_inputs,
            input_lengths=self.encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=encoder_embeddings,
            embedding_size=model_params['encoder_embedding_size'],
            use_attention=use_attention,
            num_gpus=0,
            forward_only=True,
            scope=scope,
        )
        with core.NameScope(scope):
            if use_attention:
                # [max_source_length, beam_size, encoder_output_dim]
                encoder_outputs = model.net.Tile(
                    encoder_outputs,
                    'encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            if weighted_encoder_outputs is not None:
                weighted_encoder_outputs = model.net.Tile(
                    weighted_encoder_outputs,
                    'weighted_encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            decoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.target_vocab_size,
                embedding_size=model_params['decoder_embedding_size'],
                name='decoder_embeddings',
                freeze_embeddings=False,
            )
            embedded_tokens_t_prev = step_model.net.Gather(
                [decoder_embeddings, previous_tokens],
                'embedded_tokens_t_prev',
            )

        decoder_cells = []
        decoder_units_per_layer = []
        for i, layer_config in enumerate(
                model_params['decoder_layer_configs']):
            num_units = layer_config['num_units']
            decoder_units_per_layer.append(num_units)
            if i == 0:
                input_size = model_params['decoder_embedding_size']
            else:
                input_size = (
                    model_params['decoder_layer_configs'][i - 1]['num_units']
                )

            cell = rnn_cell.LSTMCell(
                name=seq2seq_util.get_layer_scope(scope, 'decoder', i),
                forward_only=True,
                input_size=input_size,
                hidden_size=num_units,
                forget_bias=0.0,
                memory_optimization=False,
            )
            decoder_cells.append(cell)

        with core.NameScope(scope):
            if final_encoder_hidden_states is not None:
                for i in range(len(final_encoder_hidden_states)):
                    if final_encoder_hidden_states[i] is not None:
                        final_encoder_hidden_states[i] = model.net.Tile(
                            final_encoder_hidden_states[i],
                            'final_encoder_hidden_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            if final_encoder_cell_states is not None:
                for i in range(len(final_encoder_cell_states)):
                    if final_encoder_cell_states[i] is not None:
                        final_encoder_cell_states[i] = model.net.Tile(
                            final_encoder_cell_states[i],
                            'final_encoder_cell_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            initial_states = \
                seq2seq_util.build_initial_rnn_decoder_states(
                    model=model,
                    encoder_units_per_layer=encoder_units_per_layer,
                    decoder_units_per_layer=decoder_units_per_layer,
                    final_encoder_hidden_states=final_encoder_hidden_states,
                    final_encoder_cell_states=final_encoder_cell_states,
                    use_attention=use_attention,
                )

        attention_decoder = seq2seq_util.LSTMWithAttentionDecoder(
            encoder_outputs=encoder_outputs,
            encoder_output_dim=encoder_units_per_layer[-1],
            encoder_lengths=None,
            vocab_size=self.target_vocab_size,
            attention_type=attention_type,
            embedding_size=model_params['decoder_embedding_size'],
            decoder_num_units=decoder_units_per_layer[-1],
            decoder_cells=decoder_cells,
            weighted_encoder_outputs=weighted_encoder_outputs,
            name=scope,
        )
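        # The step net receives each previous recurrent state as a '*_prev'
        # external input; BeamSearchForwardOnly wires these to the state
        # links declared below.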
        states_prev = step_model.net.AddExternalInputs(*[
            '{}/{}_prev'.format(scope, s)
            for s in attention_decoder.get_state_names()
        ])
        decoder_outputs, states = attention_decoder.apply(
            model=step_model,
            input_t=embedded_tokens_t_prev,
            seq_lengths=fake_seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

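        # Link every recurrent state: the step net reads the previous value
        # at offset 0 and writes the new value at offset 1, each with a
        # one-timestep window.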
        state_configs = [
            BeamSearchForwardOnly.StateConfig(
                initial_value=initial_state,
                state_prev_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state_prev,
                    offset=0,
                    window=1,
                ),
                state_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state,
                    offset=1,
                    window=1,
                ),
            )
            for initial_state, state_prev, state in zip(
                initial_states,
                states_prev,
                states,
            )
        ]

        with core.NameScope(scope):
            decoder_outputs_flattened, _ = step_model.net.Reshape(
                [decoder_outputs],
                [
                    'decoder_outputs_flattened',
                    'decoder_outputs_and_contexts_combination_old_shape',
                ],
                shape=[-1, attention_decoder.get_output_dim()],
            )
            output_logits = seq2seq_util.output_projection(
                model=step_model,
                decoder_outputs=decoder_outputs_flattened,
                decoder_output_size=attention_decoder.get_output_dim(),
                target_vocab_size=self.target_vocab_size,
                decoder_softmax_size=model_params['decoder_softmax_size'],
            )
            # [1, beam_size, target_vocab_size]
            output_probs = step_model.net.Softmax(
                output_logits,
                'output_probs',
            )
            output_log_probs = step_model.net.Log(
                output_probs,
                'output_log_probs',
            )
            if use_attention:
                attention_weights = attention_decoder.get_attention_weights()
            else:
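                # No attention: emit all-zero weights shaped like the
                # transposed, beam-tiled encoder inputs as a placeholder.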
                attention_weights = step_model.net.ConstantFill(
                    [self.encoder_inputs],
                    'zero_attention_weights_tmp_1',
                    value=0.0,
                )
                attention_weights = step_model.net.Transpose(
                    attention_weights,
                    'zero_attention_weights_tmp_2',
                )
                attention_weights = step_model.net.Tile(
                    attention_weights,
                    'zero_attention_weights_tmp',
                    tiles=self.beam_size,
                    axis=0,
                )

        return (
            state_configs,
            output_log_probs,
            attention_weights,
        )
Example #2
    def model_build_fun(self, model, forward_only=False, loss_scale=None):
        encoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_inputs',
        )
        encoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_lengths',
        )
        decoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_inputs',
        )
        decoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_lengths',
        )
        targets = model.net.AddExternalInput(
            workspace.GetNameScope() + 'targets',
        )
        target_weights = model.net.AddExternalInput(
            workspace.GetNameScope() + 'target_weights',
        )
        attention_type = self.model_params['attention']
        assert attention_type in ['none', 'regular', 'dot']

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=self.encoder_params,
            num_decoder_layers=len(self.model_params['decoder_layer_configs']),
            inputs=encoder_inputs,
            input_lengths=encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=self.encoder_embeddings,
            embedding_size=self.model_params['encoder_embedding_size'],
            use_attention=(attention_type != 'none'),
            num_gpus=self.num_gpus,
        )

        (
            decoder_outputs,
            decoder_output_size,
        ) = seq2seq_util.build_embedding_decoder(
            model,
            decoder_layer_configs=self.model_params['decoder_layer_configs'],
            inputs=decoder_inputs,
            input_lengths=decoder_lengths,
            encoder_lengths=encoder_lengths,
            encoder_outputs=encoder_outputs,
            weighted_encoder_outputs=weighted_encoder_outputs,
            final_encoder_hidden_states=final_encoder_hidden_states,
            final_encoder_cell_states=final_encoder_cell_states,
            encoder_units_per_layer=encoder_units_per_layer,
            vocab_size=self.target_vocab_size,
            embeddings=self.decoder_embeddings,
            embedding_size=self.model_params['decoder_embedding_size'],
            attention_type=attention_type,
            forward_only=False,
            num_gpus=self.num_gpus,
        )

        output_logits = seq2seq_util.output_projection(
            model=model,
            decoder_outputs=decoder_outputs,
            decoder_output_size=decoder_output_size,
            target_vocab_size=self.target_vocab_size,
            decoder_softmax_size=self.model_params['decoder_softmax_size'],
        )
        targets, _ = model.net.Reshape(
            [targets],
            ['targets', 'targets_old_shape'],
            shape=[-1],
        )
        target_weights, _ = model.net.Reshape(
            [target_weights],
            ['target_weights', 'target_weights_old_shape'],
            shape=[-1],
        )
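        # SoftmaxWithLoss yields the weight-averaged loss per word; multiply
        # by the number of words to recover the total loss, then normalize
        # by batch size.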
        _, loss_per_word = model.net.SoftmaxWithLoss(
            [output_logits, targets, target_weights],
            ['OutputProbs_INVALID', 'loss_per_word'],
            only_loss=True,
        )

        num_words = model.net.SumElements(
            [target_weights],
            'num_words',
        )
        total_loss_scalar = model.net.Mul(
            [loss_per_word, num_words],
            'total_loss_scalar',
        )
        total_loss_scalar_weighted = model.net.Scale(
            [total_loss_scalar],
            'total_loss_scalar_weighted',
            scale=1.0 / self.batch_size,
        )
        return [total_loss_scalar_weighted]
Example #3
    def model_build_fun(self, model, forward_only=False, loss_scale=None):
        encoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_inputs',
        )
        encoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_lengths',
        )
        decoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_inputs',
        )
        decoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_lengths',
        )
        targets = model.net.AddExternalInput(
            workspace.GetNameScope() + 'targets',
        )
        target_weights = model.net.AddExternalInput(
            workspace.GetNameScope() + 'target_weights',
        )
        attention_type = self.model_params['attention']
        assert attention_type in ['none', 'regular']

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_state,
            final_encoder_cell_state,
            encoder_output_dim,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=self.encoder_params,
            inputs=encoder_inputs,
            input_lengths=encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=self.encoder_embeddings,
            embedding_size=self.model_params['encoder_embedding_size'],
            use_attention=(attention_type != 'none'),
            num_gpus=self.num_gpus,
        )

        assert len(self.model_params['decoder_layer_configs']) == 1
        decoder_num_units = (
            self.model_params['decoder_layer_configs'][0]['num_units'])
        initial_states = seq2seq_util.build_initial_rnn_decoder_states(
            model=model,
            encoder_num_units=encoder_output_dim,
            decoder_num_units=decoder_num_units,
            final_encoder_hidden_state=final_encoder_hidden_state,
            final_encoder_cell_state=final_encoder_cell_state,
            use_attention=(attention_type != 'none'),
        )

        if self.num_gpus == 0:
            embedded_decoder_inputs = model.net.Gather(
                [self.decoder_embeddings, decoder_inputs],
                ['embedded_decoder_inputs'],
            )
        else:
            with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
                embedded_decoder_inputs_cpu = model.net.Gather(
                    [self.decoder_embeddings, decoder_inputs],
                    ['embedded_decoder_inputs_cpu'],
                )
            embedded_decoder_inputs = model.CopyCPUToGPU(
                embedded_decoder_inputs_cpu,
                'embedded_decoder_inputs',
            )

        # seq_len x batch_size x decoder_embedding_size
        if attention_type == 'none':
            decoder_outputs, _, _, _ = rnn_cell.LSTM(
                model=model,
                input_blob=embedded_decoder_inputs,
                seq_lengths=decoder_lengths,
                initial_states=initial_states,
                dim_in=self.model_params['decoder_embedding_size'],
                dim_out=decoder_num_units,
                scope='decoder',
                outputs_with_grads=[0],
            )
            decoder_output_size = decoder_num_units
        else:
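            # With attention, backprop flows through both the hidden states
            # (output 0) and the attention-weighted contexts (output 4).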
            (decoder_outputs, _, _, _, attention_weighted_encoder_contexts,
             _) = rnn_cell.LSTMWithAttention(
                 model=model,
                 decoder_inputs=embedded_decoder_inputs,
                 decoder_input_lengths=decoder_lengths,
                 initial_decoder_hidden_state=initial_states[0],
                 initial_decoder_cell_state=initial_states[1],
                 initial_attention_weighted_encoder_context=initial_states[2],
                 encoder_output_dim=encoder_output_dim,
                 encoder_outputs=encoder_outputs,
                 decoder_input_dim=self.model_params['decoder_embedding_size'],
                 decoder_state_dim=decoder_num_units,
                 scope='decoder',
                 outputs_with_grads=[0, 4],
             )
            decoder_outputs, _ = model.net.Concat(
                [decoder_outputs, attention_weighted_encoder_contexts],
                [
                    'states_and_context_combination',
                    '_states_and_context_combination_concat_dims',
                ],
                axis=2,
            )
            decoder_output_size = decoder_num_units + encoder_output_dim

        # we do softmax over the whole sequence
        # (max_length in the batch * batch_size) x decoder embedding size
        # -1 because we don't know max_length yet
        decoder_outputs_flattened, _ = model.net.Reshape(
            [decoder_outputs],
            [
                'decoder_outputs_flattened',
                'decoder_outputs_and_contexts_combination_old_shape',
            ],
            shape=[-1, decoder_output_size],
        )
        output_logits = seq2seq_util.output_projection(
            model=model,
            decoder_outputs=decoder_outputs_flattened,
            decoder_output_size=decoder_output_size,
            target_vocab_size=self.target_vocab_size,
            decoder_softmax_size=self.model_params['decoder_softmax_size'],
        )
        targets, _ = model.net.Reshape(
            [targets],
            ['targets', 'targets_old_shape'],
            shape=[-1],
        )
        target_weights, _ = model.net.Reshape(
            [target_weights],
            ['target_weights', 'target_weights_old_shape'],
            shape=[-1],
        )
        output_probs = model.net.Softmax(
            [output_logits],
            ['output_probs'],
            engine=('CUDNN' if self.num_gpus > 0 else None),
        )
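        # Per-token cross-entropy, weighted by target_weights, summed to a
        # scalar and normalized by batch size.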
        label_cross_entropy = model.net.LabelCrossEntropy(
            [output_probs, targets],
            ['label_cross_entropy'],
        )
        weighted_label_cross_entropy = model.net.Mul(
            [label_cross_entropy, target_weights],
            'weighted_label_cross_entropy',
        )
        total_loss_scalar = model.net.SumElements(
            [weighted_label_cross_entropy],
            'total_loss_scalar',
        )
        total_loss_scalar_weighted = model.net.Scale(
            [total_loss_scalar],
            'total_loss_scalar_weighted',
            scale=1.0 / self.batch_size,
        )
        return [total_loss_scalar_weighted]
Example #4
    def _build_decoder(
        self,
        model,
        step_model,
        model_params,
        scope,
        previous_tokens,
        timestep,
        fake_seq_lengths,
    ):
        attention_type = model_params['attention']
        assert attention_type in ['none', 'regular']
        use_attention = (attention_type != 'none')

        with core.NameScope(scope):
            encoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.source_vocab_size,
                embedding_size=model_params['encoder_embedding_size'],
                name='encoder_embeddings',
                freeze_embeddings=False,
            )

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=model_params['encoder_type'],
            num_decoder_layers=len(model_params['decoder_layer_configs']),
            inputs=self.encoder_inputs,
            input_lengths=self.encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=encoder_embeddings,
            embedding_size=model_params['encoder_embedding_size'],
            use_attention=use_attention,
            num_gpus=0,
            forward_only=True,
            scope=scope,
        )
        with core.NameScope(scope):
            if use_attention:
                # [max_source_length, beam_size, encoder_output_dim]
                encoder_outputs = model.net.Tile(
                    encoder_outputs,
                    'encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            if weighted_encoder_outputs is not None:
                weighted_encoder_outputs = model.net.Tile(
                    weighted_encoder_outputs,
                    'weighted_encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            decoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.target_vocab_size,
                embedding_size=model_params['decoder_embedding_size'],
                name='decoder_embeddings',
                freeze_embeddings=False,
            )
            embedded_tokens_t_prev = step_model.net.Gather(
                [decoder_embeddings, previous_tokens],
                'embedded_tokens_t_prev',
            )

        decoder_cells = []
        decoder_units_per_layer = []
        for i, layer_config in enumerate(
                model_params['decoder_layer_configs']):
            num_units = layer_config['num_units']
            decoder_units_per_layer.append(num_units)
            if i == 0:
                input_size = model_params['decoder_embedding_size']
            else:
                input_size = (
                    model_params['decoder_layer_configs'][i - 1]['num_units'])

            cell = rnn_cell.LSTMCell(
                name=seq2seq_util.get_layer_scope(scope, 'decoder', i),
                forward_only=True,
                input_size=input_size,
                hidden_size=num_units,
                forget_bias=0.0,
                memory_optimization=False,
            )
            decoder_cells.append(cell)

        with core.NameScope(scope):
            if final_encoder_hidden_states is not None:
                for i in range(len(final_encoder_hidden_states)):
                    if final_encoder_hidden_states[i] is not None:
                        final_encoder_hidden_states[i] = model.net.Tile(
                            final_encoder_hidden_states[i],
                            'final_encoder_hidden_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            if final_encoder_cell_states is not None:
                for i in range(len(final_encoder_cell_states)):
                    if final_encoder_cell_states[i] is not None:
                        final_encoder_cell_states[i] = model.net.Tile(
                            final_encoder_cell_states[i],
                            'final_encoder_cell_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            initial_states = \
                seq2seq_util.build_initial_rnn_decoder_states(
                    model=model,
                    encoder_units_per_layer=encoder_units_per_layer,
                    decoder_units_per_layer=decoder_units_per_layer,
                    final_encoder_hidden_states=final_encoder_hidden_states,
                    final_encoder_cell_states=final_encoder_cell_states,
                    use_attention=use_attention,
                )

        attention_decoder = seq2seq_util.LSTMWithAttentionDecoder(
            encoder_outputs=encoder_outputs,
            encoder_output_dim=encoder_units_per_layer[-1],
            encoder_lengths=None,
            vocab_size=self.target_vocab_size,
            attention_type=attention_type,
            embedding_size=model_params['decoder_embedding_size'],
            decoder_num_units=decoder_units_per_layer[-1],
            decoder_cells=decoder_cells,
            weighted_encoder_outputs=weighted_encoder_outputs,
            name=scope,
        )
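        # Previous recurrent states enter the step net as '*_prev' external
        # inputs, which the beam search state links below wire up.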
        states_prev = step_model.net.AddExternalInputs(*[
            '{}/{}_prev'.format(scope, s)
            for s in attention_decoder.get_state_names()
        ])
        decoder_outputs, states = attention_decoder.apply(
            model=step_model,
            input_t=embedded_tokens_t_prev,
            seq_lengths=fake_seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

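        # Each state pairs a read link at offset 0 with a write link at
        # offset 1, both one timestep wide.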
        state_configs = [
            BeamSearchForwardOnly.StateConfig(
                initial_value=initial_state,
                state_prev_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state_prev,
                    offset=0,
                    window=1,
                ),
                state_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state,
                    offset=1,
                    window=1,
                ),
            ) for initial_state, state_prev, state in zip(
                initial_states,
                states_prev,
                states,
            )
        ]

        with core.NameScope(scope):
            decoder_outputs_flattened, _ = step_model.net.Reshape(
                [decoder_outputs],
                [
                    'decoder_outputs_flattened',
                    'decoder_outputs_and_contexts_combination_old_shape',
                ],
                shape=[-1, attention_decoder.get_output_dim()],
            )
            output_logits = seq2seq_util.output_projection(
                model=step_model,
                decoder_outputs=decoder_outputs_flattened,
                decoder_output_size=attention_decoder.get_output_dim(),
                target_vocab_size=self.target_vocab_size,
                decoder_softmax_size=model_params['decoder_softmax_size'],
            )
            # [1, beam_size, target_vocab_size]
            output_probs = step_model.net.Softmax(
                output_logits,
                'output_probs',
            )
            output_log_probs = step_model.net.Log(
                output_probs,
                'output_log_probs',
            )
            if use_attention:
                attention_weights = attention_decoder.get_attention_weights()
            else:
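                # No attention: placeholder all-zero weights derived from the
                # transposed, beam-tiled encoder inputs.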
                attention_weights = step_model.net.ConstantFill(
                    [self.encoder_inputs],
                    'zero_attention_weights_tmp_1',
                    value=0.0,
                )
                attention_weights = step_model.net.Transpose(
                    attention_weights,
                    'zero_attention_weights_tmp_2',
                )
                attention_weights = step_model.net.Tile(
                    attention_weights,
                    'zero_attention_weights_tmp',
                    tiles=self.beam_size,
                    axis=0,
                )

        return (
            state_configs,
            output_log_probs,
            attention_weights,
        )
Example #5
    def model_build_fun(self, model, forward_only=False, loss_scale=None):
        encoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_inputs',
        )
        encoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_lengths',
        )
        decoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_inputs',
        )
        decoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_lengths',
        )
        targets = model.net.AddExternalInput(
            workspace.GetNameScope() + 'targets',
        )
        target_weights = model.net.AddExternalInput(
            workspace.GetNameScope() + 'target_weights',
        )
        attention_type = self.model_params['attention']
        assert attention_type in ['none', 'regular', 'dot']

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=self.encoder_params,
            num_decoder_layers=len(self.model_params['decoder_layer_configs']),
            inputs=encoder_inputs,
            input_lengths=encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=self.encoder_embeddings,
            embedding_size=self.model_params['encoder_embedding_size'],
            use_attention=(attention_type != 'none'),
            num_gpus=self.num_gpus,
        )

        (
            decoder_outputs,
            decoder_output_size,
        ) = seq2seq_util.build_embedding_decoder(
            model,
            decoder_layer_configs=self.model_params['decoder_layer_configs'],
            inputs=decoder_inputs,
            input_lengths=decoder_lengths,
            encoder_lengths=encoder_lengths,
            encoder_outputs=encoder_outputs,
            weighted_encoder_outputs=weighted_encoder_outputs,
            final_encoder_hidden_states=final_encoder_hidden_states,
            final_encoder_cell_states=final_encoder_cell_states,
            encoder_units_per_layer=encoder_units_per_layer,
            vocab_size=self.target_vocab_size,
            embeddings=self.decoder_embeddings,
            embedding_size=self.model_params['decoder_embedding_size'],
            attention_type=attention_type,
            forward_only=False,
            num_gpus=self.num_gpus,
        )

        output_logits = seq2seq_util.output_projection(
            model=model,
            decoder_outputs=decoder_outputs,
            decoder_output_size=decoder_output_size,
            target_vocab_size=self.target_vocab_size,
            decoder_softmax_size=self.model_params['decoder_softmax_size'],
        )
        targets, _ = model.net.Reshape(
            [targets],
            ['targets', 'targets_old_shape'],
            shape=[-1],
        )
        target_weights, _ = model.net.Reshape(
            [target_weights],
            ['target_weights', 'target_weights_old_shape'],
            shape=[-1],
        )
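        # loss_per_word is the weight-averaged cross-entropy; scale it back
        # up by the word count, then normalize by batch size.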
        _, loss_per_word = model.net.SoftmaxWithLoss(
            [output_logits, targets, target_weights],
            ['OutputProbs_INVALID', 'loss_per_word'],
            only_loss=True,
        )

        num_words = model.net.SumElements(
            [target_weights],
            'num_words',
        )
        total_loss_scalar = model.net.Mul(
            [loss_per_word, num_words],
            'total_loss_scalar',
        )
        total_loss_scalar_weighted = model.net.Scale(
            [total_loss_scalar],
            'total_loss_scalar_weighted',
            scale=1.0 / self.batch_size,
        )
        return [total_loss_scalar_weighted]
Example #6
    def _build_decoder(
        self,
        model,
        step_model,
        model_params,
        scope,
        previous_tokens,
        timestep,
        fake_seq_lengths,
    ):
        attention_type = model_params['attention']
        assert attention_type in ['none', 'regular']
        use_attention = (attention_type != 'none')

        with core.NameScope(scope):
            encoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.source_vocab_size,
                embedding_size=model_params['encoder_embedding_size'],
                name='encoder_embeddings',
                freeze_embeddings=False,
            )

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_state,
            final_encoder_cell_state,
            encoder_output_dim,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=model_params['encoder_type'],
            inputs=self.encoder_inputs,
            input_lengths=self.encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=encoder_embeddings,
            embedding_size=model_params['encoder_embedding_size'],
            use_attention=use_attention,
            num_gpus=0,
            scope=scope,
        )
        with core.NameScope(scope):
            # [max_source_length, beam_size, encoder_output_dim]
            encoder_outputs = model.net.Tile(
                encoder_outputs,
                'encoder_outputs_tiled',
                tiles=self.beam_size,
                axis=1,
            )
            if weighted_encoder_outputs is not None:
                weighted_encoder_outputs = model.net.Tile(
                    weighted_encoder_outputs,
                    'weighted_encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            decoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.target_vocab_size,
                embedding_size=model_params['decoder_embedding_size'],
                name='decoder_embeddings',
                freeze_embeddings=False,
            )
            embedded_tokens_t_prev = step_model.net.Gather(
                [decoder_embeddings, previous_tokens],
                'embedded_tokens_t_prev',
            )

        decoder_num_units = (
            model_params['decoder_layer_configs'][0]['num_units']
        )

        with core.NameScope(scope):
            if not use_attention and final_encoder_hidden_state is not None:
                final_encoder_hidden_state = model.net.Tile(
                    final_encoder_hidden_state,
                    'final_encoder_hidden_state_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )
            if not use_attention and final_encoder_cell_state is not None:
                final_encoder_cell_state = model.net.Tile(
                    final_encoder_cell_state,
                    'final_encoder_cell_state_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )
            initial_states = seq2seq_util.build_initial_rnn_decoder_states(
                model=model,
                encoder_num_units=encoder_output_dim,
                decoder_num_units=decoder_num_units,
                final_encoder_hidden_state=final_encoder_hidden_state,
                final_encoder_cell_state=final_encoder_cell_state,
                use_attention=use_attention,
            )

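        # Single-layer decoder: with attention, an LSTMWithAttentionCell
        # whose output concatenates hidden state and context; otherwise a
        # plain LSTMCell.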
        if use_attention:
            decoder_cell = rnn_cell.LSTMWithAttentionCell(
                encoder_output_dim=encoder_output_dim,
                encoder_outputs=encoder_outputs,
                decoder_input_dim=model_params['decoder_embedding_size'],
                decoder_state_dim=decoder_num_units,
                name=self.scope(scope, 'decoder'),
                attention_type=attention.AttentionType.Regular,
                weighted_encoder_outputs=weighted_encoder_outputs,
                forget_bias=0.0,
                lstm_memory_optimization=False,
                attention_memory_optimization=True,
            )
            decoder_output_dim = decoder_num_units + encoder_output_dim
        else:
            decoder_cell = rnn_cell.LSTMCell(
                name=self.scope(scope, 'decoder'),
                input_size=model_params['decoder_embedding_size'],
                hidden_size=decoder_num_units,
                forget_bias=0.0,
                memory_optimization=False,
            )
            decoder_output_dim = decoder_num_units

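        # Expose the previous-step recurrent states as external inputs of
        # the single-step net.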
        states_prev = step_model.net.AddExternalInputs(*[
            s + '_prev' for s in decoder_cell.get_state_names()
        ])
        _, states = decoder_cell.apply(
            model=step_model,
            input_t=embedded_tokens_t_prev,
            seq_lengths=fake_seq_lengths,
            states=states_prev,
            timestep=timestep,
        )
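        # The attention decoder's step output is the hidden state (states[0])
        # concatenated with the attention context (states[2]).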
        if use_attention:
            with core.NameScope(scope or ''):
                decoder_outputs, _ = step_model.net.Concat(
                    [states[0], states[2]],
                    [
                        'states_and_context_combination',
                        '_states_and_context_combination_concat_dims',
                    ],
                    axis=2,
                )
        else:
            decoder_outputs = states[0]

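        # Read each state at offset 0, write its successor at offset 1,
        # one timestep at a time.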
        state_configs = [
            BeamSearchForwardOnly.StateConfig(
                initial_value=initial_state,
                state_prev_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state_prev,
                    offset=0,
                    window=1,
                ),
                state_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state,
                    offset=1,
                    window=1,
                ),
            )
            for initial_state, state_prev, state in zip(
                initial_states,
                states_prev,
                states,
            )
        ]

        with core.NameScope(scope):
            decoder_outputs_flattened, _ = step_model.net.Reshape(
                [decoder_outputs],
                [
                    'decoder_outputs_flattened',
                    'decoder_outputs_and_contexts_combination_old_shape',
                ],
                shape=[-1, decoder_output_dim],
            )
            output_logits = seq2seq_util.output_projection(
                model=step_model,
                decoder_outputs=decoder_outputs_flattened,
                decoder_output_size=decoder_output_dim,
                target_vocab_size=self.target_vocab_size,
                decoder_softmax_size=model_params['decoder_softmax_size'],
            )
            # [1, beam_size, target_vocab_size]
            output_probs = step_model.net.Softmax(
                output_logits,
                'output_probs',
            )
            output_log_probs = step_model.net.Log(
                output_probs,
                'output_log_probs',
            )
            if use_attention:
                attention_weights = decoder_cell.get_attention_weights()
            else:
                attention_weights = step_model.net.ConstantFill(
                    [self.encoder_inputs],
                    'zero_attention_weights_tmp_1',
                    value=0.0,
                )
                attention_weights = step_model.net.Transpose(
                    attention_weights,
                    'zero_attention_weights_tmp_2',
                )
                attention_weights = step_model.net.Tile(
                    attention_weights,
                    'zero_attention_weights_tmp',
                    tiles=self.beam_size,
                    axis=0,
                )

        return (
            state_configs,
            output_log_probs,
            attention_weights,
        )