def __init__(self, params, vocab):
        super(Seq2Seq, self).__init__()
        self.embedding_matrix = load_embedding_matrix()
        self.params = params
        self.vocab = vocab
        self.batch_size = params["batch_size"]
        self.enc_units = params["enc_units"]
        self.dec_units = params["dec_units"]
        self.attn_units = params["attn_units"]
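        # `params` must provide at least the keys read above; a minimal
        # sketch with illustrative (assumed) values:
        #   {"batch_size": 32, "enc_units": 256, "dec_units": 256,
        #    "attn_units": 128}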

        self.encoder = Encoder(embedding_matrix=self.embedding_matrix,
                               batch_size=self.batch_size,
                               enc_units=self.enc_units,
                               rnn_type='gru')

        # self.att_model = 'bahdanau'  # alternative: pass a name string and resolve it via select_att
        self.att_model = BahdanauAttention(att_hidden_size=self.dec_units,
                                           att_units=self.attn_units)

        self.decoder = Decoder(
            embedding_matrix=self.embedding_matrix,
            dec_units=self.dec_units,
            rnn_type='gru',
            att_model=self.att_model,
        )

    def select_att(self, att_model, **config):
        if att_model in [
                'bahdanau', 'bahdanau_attention', 'bahdanauattention',
                'additive', 'additive_attention', 'additiveattention'
        ]:
            try:
                hidden_size = config['att_hidden_size']
                units = config['att_units']
            except KeyError:
                raise KeyError(
                    'BahdanauAttention is missing parameters: att_hidden_size, '
                    'att_units! Please pass them when creating the Decoder!')
            return BahdanauAttention(hidden_size, units)
        elif att_model in [
                'luong', 'luong_attention', 'luongattention', 'multiplicative',
                'multiplicative_attention', 'multiplicativeattention'
        ]:
            try:
                hidden_size = config['att_hidden_size']
                attention_func = config['attention_func']
            except KeyError:
                raise KeyError(
                    'LuongAttention is missing parameters: att_hidden_size, '
                    'attention_func! Please pass them when creating the Decoder!')
            return LuongAttention(hidden_size, attention_func)
        else:
            raise ValueError('Unknown attention model: {}'.format(att_model))
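
if __name__ == '__main__':
    # Smoke test for the model pieces. A minimal sketch: assumes `params` and
    # `vocab` are built elsewhere in the repo, exactly as the constructor
    # above expects.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Seq2Seq(params, vocab).to(device)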
    # example input: a batch of dummy token ids
    example_input_batch = torch.ones(size=(params['batch_size'],
                                           params['input_sequence_len']),
                                     dtype=torch.int32,
                                     device=device)

    sample_output, sample_hidden = model.encoder(example_input_batch)
    # print the resulting shapes
    print(
        'Encoder output shape: (batch size, sequence length, units) {}'.format(
            sample_output.shape))
    print('Encoder Hidden state shape: (batch size, units) {}'.format(
        sample_hidden.shape))

    attention_layer = BahdanauAttention(params['dec_units'],
                                        params['attn_units']).to(device)
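    # BahdanauAttention scores each encoder output against the decoder hidden
    # state, returning a context vector and per-timestep attention weights.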
    context_vector, attention_weights = attention_layer(
        sample_hidden, sample_output)

    print("Attention context_vector shape: (batch size, units) {}".format(
        context_vector.shape))
    print(
        "Attention weights shape: (batch_size, sequence_length, 1) {}".format(
            attention_weights.shape))

    sample_decoder_output, _, _ = model.decoder(
        torch.ones(size=(params['batch_size'], 1),
                   dtype=torch.int32,
                   device=device), sample_hidden, sample_output)

    print('Decoder output shape: (batch_size, vocab size) {}'.format(
        sample_decoder_output.shape))

    # example input for a second, standalone smoke test that drives the
    # Encoder/Decoder components directly (assumes `batch_size`,
    # `input_sequence_len`, `units`, `embedding_matrix`, and a constructed
    # `encoder` are defined elsewhere)
    example_input_batch = torch.ones(size=(batch_size, input_sequence_len),
                                     dtype=torch.int32,
                                     device=device)
    # sample input
    # sample_hidden = encoder.initialize_hidden_state()
    # sample_hidden = torch.zeros(size=(batch_size, units))

    sample_output, sample_hidden = encoder(example_input_batch)
    # print the resulting shapes
    print('Encoder output shape: (batch size, sequence length, hidden_dim) {}'.
          format(sample_output.shape))
    print('Encoder Hidden state shape: (batch size, hidden_dim) {}'.format(
        sample_hidden.shape))

    attention_layer = BahdanauAttention(units, units).to(device)
    context_vector, attention_weights = attention_layer(
        sample_hidden, sample_output)

    print("Attention context_vector shape: (batch size, units) {}".format(
        context_vector.shape))
    print(
        "Attention weights shape: (batch_size, sequence_length, 1) {}".format(
            attention_weights.shape))

    decoder = Decoder(embedding_matrix=embedding_matrix,
                      dec_units=units,
                      att_model=attention_layer).to(device)
    sample_decoder_output, state, attention_weights = decoder(
        torch.ones(size=(batch_size, 1), dtype=torch.int32, device=device),
        sample_hidden, sample_output)
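
    # print the decoder output shape, mirroring the first smoke test above
    print('Decoder output shape: (batch_size, vocab size) {}'.format(
        sample_decoder_output.shape))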