Example No. 1
    def __init__(self, data_dir, lr):
        super().__init__()

        self.data_dir = data_dir
        self.lr = lr

        # Encoder/decoder hyperparameters come from FLAGS
        transformer_encoder = Encoder(input_dim=FLAGS.input_dim,
                                      hid_dim=FLAGS.hid_dim,
                                      n_layers=FLAGS.enc_layers,
                                      n_heads=FLAGS.enc_heads,
                                      pf_dim=FLAGS.enc_pf_dim,
                                      dropout=FLAGS.enc_dropout,
                                      max_length=100)
        transformer_decoder = Decoder(output_dim=FLAGS.output_dim,
                                      hid_dim=FLAGS.hid_dim,
                                      n_layers=FLAGS.dec_layers,
                                      n_heads=FLAGS.dec_heads,
                                      pf_dim=FLAGS.dec_pf_dim,
                                      dropout=FLAGS.dec_dropout,
                                      max_length=100)

        self.model = Seq2Seq(
            encoder=transformer_encoder,
            decoder=transformer_decoder,
            src_pad_idx=FLAGS.src_pad_idx,
            trg_pad_idx=FLAGS.trg_pad_idx,
        )
        self.model.apply(self.initialize_weights)

        # Padding positions in the target are excluded from the loss
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=FLAGS.trg_pad_idx)
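
The initialize_weights hook applied to self.model above is not part of the snippet; a minimal sketch of a common implementation for this tutorial-style Transformer (an assumption, not taken from the original code) is:

    def initialize_weights(self, m):
        # Hypothetical implementation (assumption, not from the snippet):
        # Xavier-uniform init for every parameter matrix with more than one dimension.
        if hasattr(m, 'weight') and m.weight is not None and m.weight.dim() > 1:
            torch.nn.init.xavier_uniform_(m.weight.data)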
Example No. 2
import torch

def test_decoder():
    batch_size = 8
    sequence_length = 16
    hidden_dim = 8
    num_heads = 4
    self_attention_dropout_prob = 0.3
    dec_enc_attention_dropout_prob = 0.3
    feed_forward_dropout_prob = 0.3
    layernorm_epsilon = 1e-6

    decoder = Decoder(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        self_attention_dropout_prob=self_attention_dropout_prob,
        dec_enc_attention_dropout_prob=dec_enc_attention_dropout_prob,
        feed_forward_dropout_prob=feed_forward_dropout_prob,
        layernorm_epsilon=layernorm_epsilon,
    )

    inputs = torch.rand(batch_size, sequence_length, hidden_dim)
    enc_outputs = torch.rand(batch_size, sequence_length, hidden_dim)
    # Causal mask: ones strictly above the diagonal mark future positions
    self_attention_mask = (
        1 - torch.triu(torch.ones(sequence_length, sequence_length)).long().T
    )
    self_attention_mask = self_attention_mask.unsqueeze(0)  # add batch dimension

    # Same pattern reused for the decoder-encoder attention mask
    dec_enc_attention_mask = (
        1 - torch.triu(torch.ones(sequence_length, sequence_length)).long().T
    )
    dec_enc_attention_mask = dec_enc_attention_mask.unsqueeze(0)

    outputs = decoder(inputs, enc_outputs, self_attention_mask, dec_enc_attention_mask)
    assert outputs.size() == (batch_size, sequence_length, hidden_dim)
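
As a side note, the expression 1 - torch.triu(torch.ones(S, S)).long().T used above leaves ones strictly above the diagonal, i.e. it marks future positions. An equivalent and arguably more direct construction (assuming the Decoder under test treats 1 as a masked position) would be:

import torch

seq_len = 16
# Equivalent to the expression used in test_decoder(): ones strictly above
# the diagonal mark future positions.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.long), diagonal=1)
causal_mask = causal_mask.unsqueeze(0)  # add a broadcastable batch dimension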
Example No. 3
  def __init__(self,
               token_emb,
               dim=2560,
               decoder_depth=13,
               max_seq_len=128,
               head_num=32,
               dropout=0.1):
    super().__init__()
    self.token_emb = token_emb
    self.position_emb = PositionalEmbedding(dim, max_seq_len)

    # Stack of decoder blocks (decoder-only architecture)
    self.decoders = nn.ModuleList([
      Decoder(d_model=dim, head_num=head_num, dropout=dropout)
      for _ in range(decoder_depth)
    ])
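
The snippet only shows the constructor; a matching forward pass is not included. A rough sketch, under the assumptions that PositionalEmbedding returns an encoding that can be added to the token embeddings and that each Decoder block takes the hidden states plus an attention mask, could look like:

  def forward(self, x, mask=None):
    # Hypothetical forward pass (an assumption, not part of the snippet)
    h = self.token_emb(x) + self.position_emb(x)
    for decoder in self.decoders:
      h = decoder(h, mask)
    return h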
Example No. 4
def get_tutorial_transformer():
    """Build the tutorial-style Transformer Seq2Seq model from FLAGS."""
    transformer_encoder = Encoder(input_dim=FLAGS.input_dim,
                                  hid_dim=FLAGS.d_model,
                                  n_layers=FLAGS.num_layers,
                                  n_heads=FLAGS.n_heads,
                                  pf_dim=FLAGS.d_ff,
                                  dropout=FLAGS.dropout,
                                  max_length=100)
    transformer_decoder = Decoder(output_dim=FLAGS.output_dim,
                                  hid_dim=FLAGS.d_model,
                                  n_layers=FLAGS.num_layers,
                                  n_heads=FLAGS.n_heads,
                                  pf_dim=FLAGS.d_ff,
                                  dropout=FLAGS.dropout,
                                  max_length=100)

    return Seq2Seq(
        encoder=transformer_encoder,
        decoder=transformer_decoder,
        src_pad_idx=FLAGS.src_pad_idx,
        trg_pad_idx=FLAGS.trg_pad_idx,
    )
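
For get_tutorial_transformer() to run, a FLAGS object with the fields read above has to exist; in the original project these presumably come from absl.flags or a similar command-line parser. A minimal stand-in for experimentation (all values below are illustrative assumptions) could be:

from types import SimpleNamespace

# Hypothetical FLAGS stand-in; field names are taken from the snippet above,
# the values are arbitrary placeholders.
FLAGS = SimpleNamespace(
    input_dim=4000, output_dim=4000,
    d_model=256, num_layers=3, n_heads=8, d_ff=512,
    dropout=0.1,
    src_pad_idx=1, trg_pad_idx=1,
)

model = get_tutorial_transformer()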
Example No. 5
    TXT_DIM = len(dictionary['txt_s2i'])
    print('(dimension)')
    print(' MR_DIM     : '+str(MR_DIM))
    print(' TXT_DIM    : '+str(TXT_DIM))
    HID_DIM = 256
    ENC_LAYERS = 3
    DEC_LAYERS = 3
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 512
    DEC_PF_DIM = 512
    CLIP = 1

    # (A) load NLU/NLG models
    nlu_encoder = Encoder(TXT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, args.do, device, dictionary['max_txt_length'])
    nlu_decoder = Decoder(MR_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, args.do, device, dictionary['max_mr_length'])
    model_nlu = Seq2Seq_nlu(nlu_encoder, nlu_decoder, dictionary['mr_pad_idx'], dictionary['txt_pad_idx'], device).to(device)

    nlg_encoder = Encoder(MR_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, args.do, device, dictionary['max_mr_length'])
    nlg_decoder = Decoder(TXT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, args.do, device, dictionary['max_txt_length'])
    model_nlg = Seq2Seq_nlg_single(nlg_encoder, nlg_decoder, dictionary['mr_pad_idx'], dictionary['txt_pad_idx'], device).to(device)

    optimizer_nlu = torch.optim.Adam(model_nlu.parameters(), lr=args.lr)
    optimizer_nlg = torch.optim.Adam(model_nlg.parameters(), lr=args.lr)

    criterion_nlu = nn.CrossEntropyLoss(ignore_index=dictionary['mr_pad_idx'])
    criterion_nlg = nn.CrossEntropyLoss(ignore_index=dictionary['txt_pad_idx'])

    # Restore pretrained NLU/NLG checkpoints
    checkpoint_nlu = torch.load(args.init_dir_NLU.rstrip('/')+'/'+args.init_model_NLU)
    checkpoint_nlg = torch.load(args.init_dir_NLG.rstrip('/')+'/'+args.init_model_NLG)
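
The fragment ends after the checkpoints are loaded. A minimal sketch of how criterion_nlu and CLIP are typically used in a training step (assuming model_nlu returns a (logits, attention) pair with logits of shape [batch, trg_len - 1, MR_DIM], and that src and trg are padded index tensors from a data loader; none of this is shown in the snippet) would be:

    # Hypothetical training step; shapes and the (logits, attention) return
    # value are assumptions, not taken from the snippet.
    output, _ = model_nlu(src, trg[:, :-1])
    output = output.contiguous().view(-1, output.shape[-1])
    target = trg[:, 1:].contiguous().view(-1)      # targets shifted by one token
    loss = criterion_nlu(output, target)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model_nlu.parameters(), CLIP)
    optimizer_nlu.step()
    optimizer_nlu.zero_grad()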