def forward(self, input, *args):
    if self.optimized == 2 or not input.is_cuda:
        hidden = F.linear(input, self.in_proj_weight, self.in_proj_bias)
        # hidden = F.relu(hidden, inplace=True)
        hidden = self.function(hidden)
        if self.variational:
            hidden = variational_dropout(hidden, p=self.dropout, training=self.training)
        else:
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        hidden = F.linear(hidden, self.out_proj_weight, self.out_proj_bias)
    else:
        # The Apex MLP does not support dropout, so dropconnect is used instead;
        # theoretically they should yield similar results.
        weights = [F.dropout(self.in_proj_weight, p=self.dropout, training=self.training),
                   self.out_proj_weight]
        biases = [F.dropout(self.in_proj_bias, p=self.dropout, training=self.training),
                  self.out_proj_bias]
        seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
        hidden = self.fast_mlp_func(True, 1, input.view(seq_len * bsz, -1), *weights, *biases)
        hidden = hidden.view(seq_len, bsz, hidden_size)

    return hidden

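# Note: `variational_dropout` is referenced throughout these modules but not defined in this
# snippet. A minimal sketch of what such a helper typically does, assuming time-first input of
# shape (seq_len, bsz, hidden): one mask is sampled per (batch, feature) position and shared
# across all time steps, unlike F.dropout, which samples an independent mask per element.
# This is an illustrative assumption, not necessarily the repository's actual implementation.
def variational_dropout_sketch(x, p, training=True, inplace=False):
    if not training or p == 0.0:
        return x
    # dropout mask shared over the time dimension (dim 0), rescaled to keep expectation
    mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - p).div_(1 - p)
    return x.mul_(mask) if inplace else x * mask
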
def forward(self, input, pos, key_padding_mask=None, attn_mask=None,
            incremental=False, incremental_cache=None, cleaning=False):
    q = self.layer_norm(input)
    attn, coverage = self.attn(q, pos, key_padding_mask=key_padding_mask, attn_mask=attn_mask,
                               incremental=incremental, incremental_cache=incremental_cache)

    if not self.variational:
        o = F.dropout(attn, p=self.residual_dropout, training=self.training, inplace=False)
    else:
        o = variational_dropout(attn, p=self.residual_dropout, inplace=False, training=self.training)

    if cleaning:
        del q, attn

    return o, coverage

def forward(self, input, indices=None):
    len_x, bsz = input.size(0), input.size(1)
    ensemble = self.r_in.size(0)

    if self.training:
        # Training: assign one ensemble member per batch element (round-robin over members).
        with torch.no_grad():
            indices = torch.arange(0, bsz, device=input.device, dtype=torch.long)
            indices = torch.remainder(indices, ensemble)

        r_in = torch.index_select(self.r_in, 0, indices)
        s_in = torch.index_select(self.s_in, 0, indices)
        r_out = torch.index_select(self.r_out, 0, indices)
        s_out = torch.index_select(self.s_out, 0, indices)

        input = torch.mul(input, r_in)
        input = F.linear(input, self.in_proj_weight, self.in_proj_bias)
        input = torch.mul(input, s_in)
        input = F.relu(input)

        if self.variational:
            input = variational_dropout(input, p=self.dropout, training=self.training)
        else:
            input = F.dropout(input, p=self.dropout, training=self.training)

        input = torch.mul(input, r_out)
        input = F.linear(input, self.out_proj_weight, self.out_proj_bias)
        input = torch.mul(input, s_out)

        return input
    else:
        # Inference: run every ensemble member on a replicated input and average the outputs.
        input = input.repeat(1, ensemble, 1).view(len_x, ensemble, bsz, input.size(-1))

        input = torch.mul(input, self.r_in.unsqueeze(1))
        input = F.linear(input, self.in_proj_weight, self.in_proj_bias)
        input = torch.mul(input, self.s_in.unsqueeze(1))
        input = F.relu(input)

        input = torch.mul(input, self.r_out.unsqueeze(1))
        input = F.linear(input, self.out_proj_weight, self.out_proj_bias)
        input = torch.mul(input, self.s_out.unsqueeze(1))

        input = torch.mean(input, dim=1)

        return input

def forward(self, input, factor):
    factor = self.factor_map(factor).squeeze()

    # Mix the factored weight banks with the factor vector to obtain concrete projections.
    in_proj_weight = torch.mv(self.in_proj_weight.view(-1, self.factor_size), factor) \
        .view(self.in_proj_weight.size(0), self.in_proj_weight.size(1))
    out_proj_weight = torch.mv(self.out_proj_weight.view(-1, self.factor_size), factor) \
        .view(self.out_proj_weight.size(0), self.out_proj_weight.size(1))
    in_proj_bias = torch.mv(self.in_proj_bias, factor)
    out_proj_bias = torch.mv(self.out_proj_bias, factor)

    if self.optimized == 2 or not input.is_cuda:
        hidden = F.linear(input, in_proj_weight, in_proj_bias)
        hidden = torch.relu(hidden)
        if self.variational:
            hidden = variational_dropout(hidden, p=self.dropout, training=self.training)
        else:
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        hidden = F.linear(hidden, out_proj_weight, out_proj_bias)
    else:
        # Weight dropout (dropconnect) is applied here instead of activation dropout,
        # because the Apex MLP does not support dropout.
        weights = [F.dropout(in_proj_weight, p=self.dropout, training=self.training),
                   F.dropout(out_proj_weight, p=self.dropout, training=self.training)]
        biases = [F.dropout(in_proj_bias, p=self.dropout, training=self.training),
                  F.dropout(out_proj_bias, p=self.dropout, training=self.training)]
        seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
        hidden = self.fast_mlp_func(True, 1, input.view(seq_len * bsz, -1), *weights, *biases)
        hidden = hidden.view(seq_len, bsz, hidden_size)

    return hidden

def forward(self, input, sample=False, calculate_log_probs=False):
    calculate_log_probs = calculate_log_probs or self.training
    sample = sample or self.training

    # Sample the weights from the variational posterior distribution q(w)
    sampled_weights, log_variational_posterior = self.weight.sample(sample, calculate_log_probs)
    in_proj_weight, out_proj_weight, in_proj_bias, out_proj_bias = \
        unflatten(sampled_weights, self.indices, self.shapes)

    if self.optimized == 2 or not input.is_cuda:
        hidden = F.linear(input, in_proj_weight, in_proj_bias)
        hidden = F.relu(hidden, inplace=True)
        if self.variational:
            hidden = variational_dropout(hidden, p=self.dropout, training=self.training)
        else:
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        hidden = F.linear(hidden, out_proj_weight, out_proj_bias)
    else:
        # The Apex MLP does not support dropout; the sampled weights and biases
        # are passed in directly (weight sampling already acts as regularization).
        weights = [in_proj_weight, out_proj_weight]
        biases = [in_proj_bias, out_proj_bias]
        seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
        # True = bias, 1 = relu
        hidden = self.fast_mlp_func(True, 1, input.view(seq_len * bsz, -1), *weights, *biases)
        hidden = hidden.view(seq_len, bsz, hidden_size)

    if calculate_log_probs:
        # Terms needed for the KL divergence between the prior and the variational posterior
        self.log_variational_posterior = log_variational_posterior
        self.log_prior = self.weight_prior.log_prob(sampled_weights)

    return hidden

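# How the quantities stored above are typically consumed: in Bayes-by-Backprop training, the
# loss adds a Monte Carlo estimate of the KL term, log q(w) - log p(w), evaluated at the
# sampled weights. A hedged sketch of that outer loop, with hypothetical names (`model`,
# `nll_loss`, `num_batches`) that are not part of this snippet:
#
#     output = model(batch)                                   # the forward() above runs inside
#     kl = model.log_variational_posterior - model.log_prior  # Monte Carlo KL estimate
#     loss = nll_loss(output, target) + kl / num_batches      # KL usually spread over batches
#     loss.backward()
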
def forward(self, input, cleaning=False):
    x_norm = self.layer_norm(input)
    x_ff = self.feedforward(x_norm)

    if not self.variational:
        o = F.dropout(x_ff, p=self.residual_dropout, training=self.training, inplace=False)
    else:
        o = variational_dropout(x_ff, p=self.residual_dropout, inplace=False, training=self.training)

    if cleaning:
        del x_norm, x_ff

    return o

def forward(self, input, context, pos_emb, mask_tgt, mask_src,
            src_lang=None, tgt_lang=None,
            incremental=False, incremental_cache=None, reuse_source=True, mems=None):
    """
    Self-attention layer
    layernorm > attn > dropout > residual
    """
    if incremental and incremental_cache is None:
        incremental_cache = dict()

    coin = True
    if self.training and self.death_rate > 0:
        coin = (torch.rand(1)[0].item() >= self.death_rate)

    if coin:
        # input and context should be time-first
        if mems is not None and mems.size(0) > 0:
            mems = self.preprocess_attn(mems)
        else:
            mems = None

        if self.macaron:
            out = self.mcr_feedforward(self.preprocess_mcr_ffn(input), src_lang)

            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            if not self.variational:
                out = F.dropout(out, p=self.dropout, training=self.training)
            else:
                out = variational_dropout(out, p=self.dropout, training=self.training)

            input = input + self.ffn_scale * out

        query = self.preprocess_attn(input)

        if self.mfw:
            out, _ = self.multihead_tgt(query, pos_emb, tgt_lang, None, mask_tgt, mems=mems,
                                        incremental=incremental, incremental_cache=incremental_cache)
        else:
            out, _ = self.multihead_tgt(query, pos_emb, None, mask_tgt, mems=mems,
                                        incremental=incremental, incremental_cache=incremental_cache)

        # rescaling before residual
        if self.training and self.death_rate > 0:
            out = out / (1 - self.death_rate)

        input = self.postprocess_attn(out, input)

        """
        Context attention layer
        layernorm > attn > dropout > residual
        """
        if not self.ignore_source:
            query = self.preprocess_src_attn(input)
            incremental_source = incremental and reuse_source

            if self.mfw:
                out, coverage = self.multihead_src(query, context, context, src_lang, tgt_lang, mask_src,
                                                   incremental=incremental_source,
                                                   incremental_cache=incremental_cache)
            else:
                out, coverage = self.multihead_src(query, context, context, mask_src,
                                                   incremental=incremental_source,
                                                   incremental_cache=incremental_cache)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_src_attn(out, input)
        else:
            coverage = None

        """
        Feed-forward layer
        layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input), tgt_lang)

        # rescaling before residual
        if self.training and self.death_rate > 0:
            out = out / (1 - self.death_rate)

        if not self.variational:
            out = F.dropout(out, p=self.dropout, training=self.training)
        else:
            out = variational_dropout(out, p=self.dropout, training=self.training)

        input = input + self.ffn_scale * out
    else:
        coverage = None

    return input, coverage, incremental_cache

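# The `coin` / `out / (1 - death_rate)` pattern used above is stochastic depth (layer drop):
# during training each residual block is skipped with probability `death_rate`, and when it is
# kept its output is rescaled so the expected residual contribution matches inference, where
# the block always runs. A minimal, generic sketch of the idea, not tied to these classes:
def stochastic_depth_residual(x, sublayer, death_rate, training):
    if training and death_rate > 0:
        if torch.rand(1).item() < death_rate:
            return x                                  # skip the sub-layer entirely
        return x + sublayer(x) / (1 - death_rate)     # rescale to keep E[output] consistent
    return x + sublayer(x)                            # inference: always apply, no rescaling
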
def forward(self, input, *args, **kwargs):
    if self.fused and input.is_cuda and not self.autograd:
        # if autocast is enabled: manually cast the function args into half precision,
        # because for some reason custom_fwd(...) doesn't work here
        with autocast(enabled=False):
            weights = [self.in_proj_weight.half(), self.out_proj_weight.half()]
            biases = [self.in_proj_bias.half(), self.out_proj_bias.half()]

            seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
            dropout = self.dropout if self.training else 0.0

            if self.fused_dropout_add:
                res_dropout = self.res_dropout if self.training else 0.0
                hidden = self.fused_function(dropout, res_dropout, input.half().view(seq_len * bsz, -1),
                                             *weights, *biases).type_as(input)
            else:
                recompute = onmt.constants.recompute
                hidden = self.fused_function(dropout, recompute, input.half().view(seq_len * bsz, -1),
                                             *weights, *biases).type_as(input)

            hidden = hidden.view(seq_len, bsz, hidden_size)

        # verification code (only valid with dropout = 0.0)
        # with torch.no_grad():
        #     hidden_ = F.linear(self.act(F.linear(input, self.in_proj_weight, self.in_proj_bias)),
        #                        self.out_proj_weight, self.out_proj_bias).type_as(hidden)
        #     if self.fused_dropout_add:
        #         hidden_.add_(input)
        #     comp = torch.allclose(hidden, hidden_, rtol=1e-02, atol=1e-03)
        #     if not comp:
        #         print("Warning! The fused function doesn't match the PyTorch function.")
        #         print(hidden - hidden_)
    else:
        if self.autograd:
            hidden = self.linear_in(input)
        else:
            hidden = F.linear(input, self.in_proj_weight, self.in_proj_bias)

        if self.glu and self.activation != 'sigmoid':
            # GLU: split into a value half and a gate half, then gate the activated value
            hidden, gate = hidden.chunk(2, dim=-1)
            hidden = self.act(hidden) * gate
        else:
            # GLU function
            hidden = self.act(hidden)

        if not (not self.glu and self.activation == 'relu'):
            if self.variational:
                hidden = variational_dropout(hidden, p=self.dropout, training=self.training,
                                             inplace=self.activation in ['silu', 'relu', 'swish', 'gelu'])
            else:
                hidden = F.dropout(hidden, p=self.dropout, training=self.training,
                                   inplace=self.activation in ['silu', 'relu', 'swish', 'gelu'])

        if self.autograd:
            hidden = self.linear_out(hidden)
        else:
            hidden = F.linear(hidden, self.out_proj_weight, self.out_proj_bias)

        if self.dropout_residual:
            if not self.fused_dropout_add:
                if not self.variational:
                    hidden = F.dropout(hidden, p=self.res_dropout, training=self.training) + input
                else:
                    hidden = variational_dropout(hidden, p=self.res_dropout, training=self.training) + input

    return hidden

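# For reference, the GLU branch above implements a gated feed-forward: the input projection
# produces twice the inner width, which is split into a value half and a gate half. A minimal
# standalone sketch of that computation (names and the silu activation are illustrative, not
# taken from this module; torch and F = torch.nn.functional are assumed imported as elsewhere):
def glu_ffn_sketch(x, w_in, b_in, w_out, b_out, act=torch.nn.functional.silu):
    h = F.linear(x, w_in, b_in)          # (..., 2 * inner_size)
    value, gate = h.chunk(2, dim=-1)     # split into value and gate halves
    return F.linear(act(value) * gate, w_out, b_out)
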
def forward(self, input, pos_emb, attn_mask, incremental=False, incremental_cache=None,
            mems=None, src_lang=None):
    assert incremental is False
    assert incremental_cache is None

    coin = True
    if self.training and self.death_rate > 0:
        coin = (torch.rand(1)[0].item() >= self.death_rate)
        ffn_scale = self.ffn_scale / (1 - self.death_rate)
    else:
        ffn_scale = self.ffn_scale

    if coin:
        # macaron feed-forward
        out = self.mcr_feedforward(self.preprocess_mcr_ffn(input), src_lang)
        out = out * ffn_scale

        if not self.variational:
            out = F.dropout(out, p=self.dropout, training=self.training)
        else:
            out = variational_dropout(out, p=self.dropout, training=self.training)

        input = input + out

        # attention
        attn_input = self.preprocess_attn(input)
        out, _ = self.attn(attn_input, pos_emb, attn_mask, None)

        if self.training and self.death_rate > 0:
            out = out / (1 - self.death_rate)

        input = self.postprocess_attn(out, input)

        # convolution
        conv_input = self.preprocess_conv(input)
        out = self.conv(conv_input)

        if self.training and self.death_rate > 0:
            out = out / (1 - self.death_rate)

        input = self.postprocess_conv(out, input)

        # last feed-forward
        out = self.feedforward(self.preprocess_ffn(input), src_lang)
        out = out * ffn_scale

        if not self.variational:
            out = F.dropout(out, p=self.dropout, training=self.training)
        else:
            out = variational_dropout(out, p=self.dropout, training=self.training)

        input = input + out

    return input

def forward(self, input, pos_emb, attn_mask, incremental=False, incremental_cache=None,
            mems=None, src_lang=None):
    if incremental and incremental_cache is None:
        incremental_cache = dict()

    coin = True
    if self.training and self.death_rate > 0:
        coin = (torch.rand(1)[0].item() >= self.death_rate)

    if coin:
        if self.macaron:
            out = self.mcr_feedforward(self.preprocess_mcr_ffn(input), src_lang)

            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            if not self.variational:
                out = F.dropout(out, p=self.dropout, training=self.training)
            else:
                out = variational_dropout(out, p=self.dropout, training=self.training)

            input = input + self.ffn_scale * out

        query = self.preprocess_attn(input)

        if self.mfw:
            out, _ = self.multihead(query, pos_emb, src_lang, attn_mask, None, mems=mems,
                                    incremental=incremental, incremental_cache=incremental_cache)
        else:
            out, _ = self.multihead(query, pos_emb, attn_mask, None, mems=mems,
                                    incremental=incremental, incremental_cache=incremental_cache)

        # rescaling before residual
        if self.training and self.death_rate > 0:
            out = out / (1 - self.death_rate)

        input = self.postprocess_attn(out, input)

        """
        Feed-forward layer
        layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input), src_lang)

        # rescaling before residual
        if self.training and self.death_rate > 0:
            out = out / (1 - self.death_rate)

        if not self.variational:
            out = F.dropout(out, p=self.dropout, training=self.training)
        else:
            out = variational_dropout(out, p=self.dropout, training=self.training)

        input = input + self.ffn_scale * out

    if incremental:
        return input, incremental_cache

    return input