def build_modules(self):

    e_length = expected_length(self.layers, self.death_rate)

    if self.reversible:
        print("* Reversible Transformer Decoder with Absolute Attention with %.2f expected layers" % e_length)
    else:
        print("* Transformer Decoder with Absolute Attention with %.2f expected layers" % e_length)

    for _l in range(self.layers):
        # linearly decay the death rate with depth
        death_r = (_l + 1.0) / self.layers * self.death_rate

        if not self.reversible:
            # block = DecoderLayer(self.n_heads, self.model_size,
            #                      self.dropout, self.inner_size, self.attn_dropout,
            #                      variational=self.variational_dropout, death_rate=death_r)
            block = DecoderLayer(self.opt, death_rate=death_r)
        else:
            # reversible layers use the same linearly decayed death rate
            block = ReversibleTransformerDecoderLayer(self.opt, death_rate=death_r)

        self.layer_modules.append(block)
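
# For reference: `expected_length` is assumed here to return the expected number of
# surviving layers under stochastic depth, i.e. the sum of per-layer survival
# probabilities with the same linearly decayed death rate used above. A minimal
# sketch under that assumption (not necessarily the exact helper in this repo):
def _expected_length_sketch(num_layers, death_rate):
    e_length = 0.0
    for l in range(num_layers):
        survival = 1.0 - (l + 1.0) / num_layers * death_rate
        e_length += survival
    return e_length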
def __init__(self, opt, dicts, positional_encoder):

    super(ParallelTransformerDecoder, self).__init__()

    self.model_size = opt.model_size
    self.n_heads = opt.n_heads
    self.inner_size = opt.inner_size
    self.layers = opt.layers
    self.dropout = opt.dropout
    self.word_dropout = opt.word_dropout
    self.attn_dropout = opt.attn_dropout
    self.emb_dropout = opt.emb_dropout
    self.time = opt.time

    if hasattr(opt, 'grow_dropout'):
        self.grow_dropout = opt.grow_dropout

    if opt.time == 'positional_encoding':
        self.time_transformer = positional_encoder
    elif opt.time == 'gru':
        self.time_transformer = nn.GRU(self.model_size, self.model_size, 1, batch_first=True)
    elif opt.time == 'lstm':
        self.time_transformer = nn.LSTM(self.model_size, self.model_size, 1, batch_first=True)

    #~ self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d', static=False)
    self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d',
                                              static=onmt.constants.static)

    self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')

    self.word_lut = nn.Embedding(dicts.size(), self.model_size, padding_idx=onmt.constants.PAD)

    self.positional_encoder = positional_encoder

    self.layer_modules = nn.ModuleList([
        DecoderLayer(self.n_heads, self.model_size, self.dropout,
                     self.inner_size, self.attn_dropout)
        for _ in range(self.layers)
    ])

    len_max = self.positional_encoder.len_max
    mask = torch.ByteTensor(np.triu(np.ones((len_max, len_max)), k=1).astype('uint8'))
    self.register_buffer('mask', mask)
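
# Illustration only (hypothetical helper, not a method of the class above): the
# `mask` buffer registered in __init__ is strictly upper-triangular, so slicing it
# to the current target length yields the causal self-attention mask; entries equal
# to 1 mark future positions that must not be attended to.
def _causal_mask_sketch(len_max, len_tgt):
    import numpy as np
    import torch
    mask = torch.ByteTensor(np.triu(np.ones((len_max, len_max)), k=1).astype('uint8'))
    return mask[:len_tgt, :len_tgt]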
def build_modules(self):

    e_length = expected_length(self.layers, self.death_rate)
    print("* Transformer Decoder with Absolute Attention with %.2f expected layers" % e_length)

    self.layer_modules = nn.ModuleList()

    for l in range(self.layers):
        # linearly decay the death rate with depth
        death_r = (l + 1.0) / self.layers * self.death_rate

        block = DecoderLayer(self.n_heads, self.model_size, self.dropout,
                             self.inner_size, self.attn_dropout,
                             ignore_source=True,
                             variational=self.variational_dropout, death_rate=death_r)

        self.layer_modules.append(block)
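
# A minimal sketch of how a per-layer death rate is typically consumed at training
# time (stochastic depth / LayerDrop). This illustrates the convention only; the
# actual gating presumably lives inside DecoderLayer, and the helper below is
# hypothetical.
def _stochastic_depth_residual_sketch(sublayer, x, death_rate, training):
    import torch
    if training:
        if torch.rand(1).item() < death_rate:
            # the whole residual branch is skipped for this batch
            return x
        # rescale the surviving branch so expectations match inference
        return x + sublayer(x) / (1.0 - death_rate)
    # at inference every layer is kept, unscaled
    return x + sublayer(x)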
def add_layers(self, n_new_layer):

    self.new_modules = list()
    self.layers += n_new_layer

    for i in range(n_new_layer):
        layer = DecoderLayer(self.n_heads, self.model_size, self.dropout,
                             self.inner_size, self.attn_dropout)

        # the first new layer reuses, as its preprocessing, the weights of the
        # current final postprocessing layer
        if i == 0:
            # layer.preprocess_attn = self.postprocess_layer
            layer.preprocess_attn.load_state_dict(self.postprocess_layer.state_dict())

            #~ layer.preprocess_attn.layer_norm.function.weight.requires_grad = False
            #~ layer.preprocess_attn.layer_norm.function.bias.requires_grad = False

            # replace the last postprocessing layer with a new one
            #~ if hasattr(layer.postprocess_attn, 'k'):
            #~     layer.postprocess_attn.k.data.fill_(0.01)
            self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')

        self.layer_modules.append(layer)