Esempio n. 1
0
    def __init__(
        self,
        isize,
        snwd,
        tnwd,
        num_layer,
        fhsize=None,
        dropout=0.0,
        attn_drop=0.0,
        global_emb=False,
        num_head=8,
        xseql=cache_len_default,
        ahsize=None,
        norm_output=True,
        bindDecoderEmb=False,
        forbidden_index=None,
    ):
        """Assemble the model from one Encoder (``self.enc``) and one Decoder (``self.dec``).

        ``num_layer`` is split by ``parse_double_value_tuple`` into an encoder
        and a decoder layer count. ``snwd`` is handed to the Encoder and
        ``tnwd`` to the Decoder; ``isize``, ``fhsize``, ``dropout``,
        ``attn_drop``, ``num_head``, ``xseql``, ``ahsize`` and ``norm_output``
        are handed to both. ``bindDecoderEmb`` and ``forbidden_index`` go to
        the Decoder only.

        When ``global_emb`` is truthy, the encoder's ``wemb.weight`` is passed
        to the Decoder as its embedding weight; otherwise ``None`` is passed.
        """

        super(NMT, self).__init__()

        n_enc_layer, n_dec_layer = parse_double_value_tuple(num_layer)

        # Trailing positional arguments that both sub-modules receive, in order.
        shared_tail = (num_head, xseql, ahsize, norm_output)

        self.enc = Encoder(isize, snwd, n_enc_layer, fhsize, dropout,
                           attn_drop, *shared_tail)

        # The encoder must already exist: its embedding weight may be reused.
        if global_emb:
            shared_emb_w = self.enc.wemb.weight
        else:
            shared_emb_w = None

        # NOTE: the original source kept a commented-out RNMT variant of this
        # call that is identical except that it omits the fhsize argument.
        self.dec = Decoder(isize, tnwd, n_dec_layer, fhsize, dropout,
                           attn_drop, shared_emb_w, *shared_tail,
                           bindDecoderEmb, forbidden_index)

        if rel_pos_enabled:
            share_rel_pos_cache(self)
Esempio n. 2
0
    def __init__(self,
                 isize,
                 snwd,
                 tnwd,
                 num_layer,
                 fhsize=None,
                 dropout=0.0,
                 attn_drop=0.0,
                 global_emb=False,
                 num_head=8,
                 xseql=cache_len_default,
                 ahsize=None,
                 norm_output=True,
                 bindDecoderEmb=True,
                 forbidden_index=None,
                 num_layer_ana=None):
        """Build the parent model, then swap in a Decoder that also receives ``num_layer_ana``.

        Every argument except ``num_layer_ana`` is forwarded unchanged to the
        parent constructor. Afterwards ``self.dec`` is replaced by a new
        Decoder built from the same arguments plus ``num_layer_ana`` as the
        final positional argument.

        When ``global_emb`` is truthy, the encoder's ``wemb.weight`` is passed
        to the new Decoder as its embedding weight; otherwise ``None`` is
        passed.

        ``num_layer_ana``: when it is a number less than or equal to zero the
        encoder is discarded (``self.enc`` is set to ``None``). ``None`` (the
        default) leaves the encoder in place.
        """

        super(NMT, self).__init__(isize,
                                  snwd,
                                  tnwd,
                                  num_layer,
                                  fhsize=fhsize,
                                  dropout=dropout,
                                  attn_drop=attn_drop,
                                  global_emb=global_emb,
                                  num_head=num_head,
                                  xseql=xseql,
                                  ahsize=ahsize,
                                  norm_output=norm_output,
                                  bindDecoderEmb=bindDecoderEmb,
                                  forbidden_index=forbidden_index)

        # Read the shared embedding weight before the encoder can be dropped below.
        emb_w = self.enc.wemb.weight if global_emb else None

        # Only the decoder layer count is needed here; the encoder was built by the parent.
        _, dec_layer = parse_double_value_tuple(num_layer)

        self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop,
                           emb_w, num_head, xseql, ahsize, norm_output,
                           bindDecoderEmb, forbidden_index, num_layer_ana)

        # Guard against the default: ``None <= 0`` raises TypeError on Python 3,
        # which made the constructor unusable without an explicit num_layer_ana.
        if num_layer_ana is not None and num_layer_ana <= 0:
            self.enc = None

        if rel_pos_enabled:
            share_rel_pos_cache(self)
Esempio n. 3
0
    def __init__(
        self,
        isize,
        snwd,
        tnwd,
        num_layer,
        fhsize=None,
        dropout=0.0,
        attn_drop=0.0,
        global_emb=False,
        num_head=8,
        xseql=cache_len_default,
        ahsize=None,
        norm_output=True,
        bindDecoderEmb=True,
        forbidden_index=None,
        ntask=None,
        **kwargs
    ):
        """Build the parent model, then replace its encoder and decoder with ``ntask``-aware ones.

        ``num_layer`` is split by ``parse_double_value_tuple`` into encoder and
        decoder layer counts, which are passed on to the parent as a tuple.
        The parent is always given ``forbidden_index=None``; the caller's
        ``forbidden_index`` is only handed to the Decoder created here.

        ``self.enc`` and ``self.dec`` are then rebuilt with ``ntask`` as an
        extra keyword argument. When ``global_emb`` is truthy the new
        encoder's ``wemb.weight`` and ``task_emb.weight`` are passed to the
        Decoder as ``emb_w`` and ``task_emb_w``; otherwise both are ``None``.

        Extra keyword arguments (``**kwargs``) are accepted but not used by
        this constructor.
        """

        enc_layer, dec_layer = parse_double_value_tuple(num_layer)
        layer_pair = (enc_layer, dec_layer)

        # Keyword options that the parent, the Encoder and the Decoder all receive.
        common_opts = {
            "fhsize": fhsize,
            "dropout": dropout,
            "attn_drop": attn_drop,
            "num_head": num_head,
            "xseql": xseql,
            "ahsize": ahsize,
            "norm_output": norm_output,
        }

        super(NMT, self).__init__(
            isize,
            snwd,
            tnwd,
            layer_pair,
            global_emb=global_emb,
            bindDecoderEmb=bindDecoderEmb,
            forbidden_index=None,
            **common_opts
        )

        self.enc = Encoder(isize, snwd, enc_layer, ntask=ntask, **common_opts)

        # Both shared weights come from the encoder that was just rebuilt.
        word_emb_w, task_emb_w = (
            (self.enc.wemb.weight, self.enc.task_emb.weight)
            if global_emb
            else (None, None)
        )

        self.dec = Decoder(
            isize,
            tnwd,
            dec_layer,
            emb_w=word_emb_w,
            bindemb=bindDecoderEmb,
            forbidden_index=forbidden_index,
            ntask=ntask,
            task_emb_w=task_emb_w,
            **common_opts
        )

        if rel_pos_enabled:
            share_rel_pos_cache(self)