Example #1
0
    def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None,
                 dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8,
                 xseql=cache_len_default, ahsize=None, norm_output=True,
                 bindDecoderEmb=False, forbidden_index=None):
        """Build the encoder and decoder sub-modules of the model.

        Args:
            isize: forwarded as the first argument to Encoder and Decoder.
            snwd: forwarded to the Encoder (presumably the source
                vocabulary size -- confirm against Encoder).
            tnwd: forwarded to the Decoder (presumably the target
                vocabulary size -- confirm against Decoder).
            num_layer: split by parse_double_value_tuple into the encoder
                and the decoder layer counts.
            fhsize, dropout, attn_drop, num_head, xseql, ahsize,
            norm_output: forwarded unchanged to both sub-modules.
            global_emb: when truthy, the encoder embedding weight
                (self.enc.wemb.weight) is handed to the Decoder.
            bindDecoderEmb, forbidden_index: forwarded to the Decoder only.
        """

        super(NMT, self).__init__()

        n_enc_layer, n_dec_layer = parse_double_value_tuple(num_layer)

        self.enc = Encoder(isize, snwd, n_enc_layer, fhsize, dropout,
                           attn_drop, num_head, xseql, ahsize, norm_output)

        # Hand the encoder embedding matrix to the decoder only on request.
        if global_emb:
            shared_emb_w = self.enc.wemb.weight
        else:
            shared_emb_w = None

        # NOTE: an RNMT decoder variant would take the same arguments
        # without fhsize.
        self.dec = Decoder(isize, tnwd, n_dec_layer, fhsize, dropout,
                           attn_drop, shared_emb_w, num_head, xseql, ahsize,
                           norm_output, bindDecoderEmb, forbidden_index)

        if rel_pos_enabled:
            share_rel_pos_cache(self)
Example #2
0
    def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0,
                 attn_drop=0.0, num_head=8, xseql=cache_len_default,
                 ahsize=None, norm_output=True, global_emb=False, **kwargs):
        """Create the two encoders held by this module.

        Args:
            isize: forwarded as the first argument to both encoders.
            nwd: split by parse_double_value_tuple into two values; the
                first is given to src_enc, the second to tgt_enc.
            num_layer, fhsize, dropout, attn_drop, num_head, xseql,
            ahsize, norm_output: forwarded unchanged to both encoders.
            global_emb: when truthy, the embedding weight of src_enc
                (src_enc.wemb.weight) is passed on to tgt_enc.
            **kwargs: forwarded unchanged to both encoders.
        """

        super(Encoder, self).__init__()

        src_nwd, tgt_nwd = parse_double_value_tuple(nwd)

        self.src_enc = EncoderBase(isize, src_nwd, num_layer, fhsize,
                                   dropout, attn_drop, num_head, xseql,
                                   ahsize, norm_output, **kwargs)

        # Optionally reuse the first encoder's embedding matrix.
        if global_emb:
            shared_emb_w = self.src_enc.wemb.weight
        else:
            shared_emb_w = None

        self.tgt_enc = MSEncoder(isize, tgt_nwd, num_layer, fhsize, dropout,
                                 attn_drop, num_head, xseql, ahsize,
                                 norm_output, shared_emb_w, **kwargs)
Example #3
0
    def __init__(self,
                 isize,
                 snwd,
                 tnwd,
                 num_layer,
                 fhsize=None,
                 dropout=0.0,
                 attn_drop=0.0,
                 global_emb=False,
                 num_head=8,
                 xseql=cache_len_default,
                 ahsize=None,
                 norm_output=True,
                 bindDecoderEmb=True,
                 forbidden_index=None,
                 num_layer_ana=None):
        """Initialise the parent model, then swap in a Decoder that also
        receives ``num_layer_ana``.

        Args:
            isize, snwd, tnwd, num_layer: forwarded positionally to the
                parent constructor; isize, tnwd and the decoder part of
                num_layer are also given to the Decoder built here.
            fhsize, dropout, attn_drop, global_emb, num_head, xseql,
            ahsize, norm_output, bindDecoderEmb, forbidden_index:
                forwarded to the parent constructor by keyword and reused
                for the Decoder built here.
            num_layer_ana: passed as the last positional argument of the
                Decoder. When it is an explicit value <= 0 the encoder is
                dropped (self.enc is set to None). None leaves the encoder
                in place.
        """

        super(NMT, self).__init__(isize,
                                  snwd,
                                  tnwd,
                                  num_layer,
                                  fhsize=fhsize,
                                  dropout=dropout,
                                  attn_drop=attn_drop,
                                  global_emb=global_emb,
                                  num_head=num_head,
                                  xseql=xseql,
                                  ahsize=ahsize,
                                  norm_output=norm_output,
                                  bindDecoderEmb=bindDecoderEmb,
                                  forbidden_index=forbidden_index)

        emb_w = self.enc.wemb.weight if global_emb else None

        _, dec_layer = parse_double_value_tuple(num_layer)

        # Overwrite the decoder created by the parent constructor with one
        # that additionally takes num_layer_ana.
        self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop,
                           emb_w, num_head, xseql, ahsize, norm_output,
                           bindDecoderEmb, forbidden_index, num_layer_ana)

        # The default None cannot be ordered against an int on Python 3
        # (TypeError), so only an explicit non-positive value removes the
        # encoder.
        if num_layer_ana is not None and num_layer_ana <= 0:
            self.enc = None

        if rel_pos_enabled:
            share_rel_pos_cache(self)
Example #4
0
from random import shuffle, random

from tqdm import tqdm

import h5py

import cnfg.dynb as cnfg
from cnfg.ihyp import *

from transformer.NMT import NMT

# Script-level constants (their consumers are outside this excerpt).
log_dyn_p = 1.0
max_his = 9
log_dynb = True

# Values taken from the dynamic-batching configuration module.
update_angle = cnfg.update_angle
# cnfg.nlayer is split into separate encoder / decoder layer counts.
enc_layer, dec_layer = parse_double_value_tuple(cnfg.nlayer)


def select_function(modin, select_index):
    """Return the parameters of one layer of ``modin``.

    The layers in ``modin.enc.nets`` and ``modin.dec.nets`` are treated as
    one sequence (encoder layers first, decoder layers after), and
    ``select_index`` indexes into that combined sequence.

    Args:
        modin: object exposing ``enc.nets`` and ``dec.nets`` iterables.
        select_index: index into the combined layer list.

    Returns:
        Whatever ``parameters()`` of the selected layer returns.
    """

    all_layers = [*modin.enc.nets, *modin.dec.nets]
    chosen_layer = all_layers[select_index]

    return chosen_layer.parameters()


grad_mon = GradientMonitor(enc_layer + dec_layer,
                           select_function,
                           module=None,
                           angle_alpha=cnfg.dyn_tol_alpha,
                           num_tol_amin=cnfg.dyn_tol_amin,
                           num_his_record=cnfg.num_dynb_his,
Example #5
0
    def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None,
                 dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8,
                 xseql=cache_len_default, ahsize=None, norm_output=True,
                 bindDecoderEmb=True, forbidden_index=None, ntask=None,
                 **kwargs):
        """Initialise the parent model, then install an Encoder and a
        Decoder that both receive ``ntask``.

        Args:
            isize, snwd, tnwd: forwarded to the parent constructor; isize
                is also given to both sub-modules, snwd to the Encoder and
                tnwd to the Decoder.
            num_layer: split by parse_double_value_tuple into the encoder
                and decoder layer counts.
            fhsize, dropout, attn_drop, num_head, xseql, ahsize,
            norm_output: forwarded to the parent and to both sub-modules.
            global_emb: forwarded to the parent; when truthy the encoder's
                wemb.weight and task_emb.weight are handed to the Decoder.
            bindDecoderEmb: forwarded to the parent and to the Decoder
                (as ``bindemb``).
            forbidden_index: given only to the Decoder built here; the
                parent constructor receives None instead.
            ntask: forwarded to both the Encoder and the Decoder.
            **kwargs: accepted but not used by this constructor.
        """

        enc_layer, dec_layer = parse_double_value_tuple(num_layer)
        layer_spec = (enc_layer, dec_layer)

        # The parent is always called with forbidden_index=None; the real
        # value goes to the Decoder constructed below.
        super(NMT, self).__init__(
            isize, snwd, tnwd, layer_spec,
            fhsize=fhsize, dropout=dropout, attn_drop=attn_drop,
            global_emb=global_emb, num_head=num_head, xseql=xseql,
            ahsize=ahsize, norm_output=norm_output,
            bindDecoderEmb=bindDecoderEmb, forbidden_index=None)

        self.enc = Encoder(
            isize, snwd, enc_layer,
            fhsize=fhsize, dropout=dropout, attn_drop=attn_drop,
            num_head=num_head, xseql=xseql, ahsize=ahsize,
            norm_output=norm_output, ntask=ntask)

        # Word and task embedding weights are shared together or not at all.
        shared_emb_w, shared_task_emb_w = (
            (self.enc.wemb.weight, self.enc.task_emb.weight)
            if global_emb else (None, None))

        self.dec = Decoder(
            isize, tnwd, dec_layer,
            fhsize=fhsize, dropout=dropout, attn_drop=attn_drop,
            emb_w=shared_emb_w, num_head=num_head, xseql=xseql,
            ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb,
            forbidden_index=forbidden_index, ntask=ntask,
            task_emb_w=shared_task_emb_w)

        if rel_pos_enabled:
            share_rel_pos_cache(self)
Example #6
0
# Re-export of `inf` under the *_default naming used in this module.
inf_default = inf

# Epsilon defaults. The specific values are routed through parse_none
# together with the generic ieps_default (presumably so that a None entry
# falls back to ieps_default -- confirm against parse_none).
ieps_default = 1e-9
ieps_ln_default = 1e-6
ieps_adam_default = 1e-9
ieps_ln_default = parse_none(ieps_ln_default, ieps_default)
ieps_adam_default = parse_none(ieps_adam_default, ieps_default)
ieps_noise_default = ieps_ln_default

adam_betas_default = (0.9, 0.98)

# Split the relative-position setting into encoder / decoder values.
(use_k_relative_position_encoder,
 use_k_relative_position_decoder) = parse_double_value_tuple(
     use_k_relative_position)
# Enabled as soon as either side uses a positive value.
rel_pos_enabled = max(use_k_relative_position_encoder,
                      use_k_relative_position_decoder) > 0
(disable_std_pemb_encoder,
 disable_std_pemb_decoder) = parse_double_value_tuple(disable_std_pemb)

# Keyword arguments for HDF5 data files: empty when no compression is set.
if hdf5_data_compression is None:
    h5datawargs = {}
else:
    h5datawargs = {
        "compression": hdf5_data_compression,
        "compression_opts": hdf5_data_compression_level,
        "shuffle": True,
    }

# Same scheme for HDF5 model files.
if hdf5_model_compression is None:
    h5modelwargs = {}
else:
    h5modelwargs = {
        "compression": hdf5_model_compression,
        "compression_opts": hdf5_model_compression_level,
        "shuffle": True,
    }