import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head."""

    def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s  # feature scale
        self.m = m  # angular margin (radians)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input, label):
        # cos(theta) between normalized embeddings and normalized class weights;
        # the weight is cast to half to match half-precision inputs.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight.half()))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)).half()
        # cos(theta + m), falling back to a linear penalty when theta + m > pi
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Convert the labels to one-hot so the margin is applied only to the
        # ground-truth class: output_i = phi_i if one_hot_i == 1 else cosine_i.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
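# --- Usage sketch (not part of the original code; added for illustration) -----
# A minimal example of driving the ArcFace head: random half-precision
# embeddings stand in for a backbone's output, the head returns margin-scaled
# logits, and a plain cross-entropy loss is applied. The sizes (512-d
# embeddings, 1000 classes, batch of 8) are illustrative assumptions.
def _arcface_usage_example():
    """Hypothetical helper, not in the original file; requires a CUDA device."""
    head = ArcFace(in_features=512, out_features=1000).cuda()
    embeddings = torch.randn(8, 512, device='cuda').half()  # stand-in for backbone features
    labels = torch.randint(0, 1000, (8,), device='cuda')
    logits = head(embeddings, labels)       # margin applied only at the true-class entries
    loss = F.cross_entropy(logits.float(), labels)
    loss.backward()
    return loss.item()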
# ReLUDropout, AGELU, SiLU, variational_dropout, onmt and the fused MLP kernels
# referenced below are provided by the surrounding onmt code base; their imports
# are kept out of this excerpt.
from torch.cuda.amp import autocast


class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward neural network (Transformer position-wise FFN)."""

    def __init__(self, model_size, inner_size, dropout=0., variational=False, activation='relu',
                 glu=False, weight_drop=0.0, dropout_residual=False, res_dropout=0.0):
        super().__init__()
        self.model_size = model_size
        self.inner_size = inner_size
        self.dropout = dropout
        self.bias = True
        self.variational = variational
        self.activation = activation
        self.glu = glu
        self.weight_drop = weight_drop
        self.autograd = False
        self.fused_dropout_add = False
        self.dropout_residual = dropout_residual
        self.res_dropout = res_dropout

        if self.activation == 'relu':
            if self.glu:
                self.act = nn.ReLU(inplace=True)
            else:
                # ReLUDropout fuses the activation with dropout
                self.act = ReLUDropout(p=self.dropout, variational=self.variational, batch_first=False)
        elif self.activation == 'gelu':
            self.act = nn.GELU()
        elif self.activation == 'agelu':
            self.act = AGELU()
        elif self.activation in ['silu', 'swish']:
            self.act = SiLU()
        elif self.activation in ['sigmoid']:
            if self.glu:
                self.act = nn.functional.glu
            else:
                print("Sigmoid activation function is recommended to be used with -glu")
                raise NotImplementedError

        self.in_proj_weight = Parameter(torch.Tensor(inner_size * (2 if glu else 1), model_size))
        self.out_proj_weight = Parameter(torch.Tensor(model_size, inner_size))

        self.in_proj_bias = Parameter(torch.Tensor(inner_size * (2 if glu else 1)))
        self.out_proj_bias = Parameter(torch.Tensor(model_size))

        self.reset_parameters()
        self.fused = False

        # At the moment the fused MLP is supported for ReLU, SiLU, Swish, GELU and AGELU (approximated GELU)
        if not self.glu and \
                self.activation in ['relu', 'silu', 'swish', 'gelu', 'agelu'] and not self.variational:
            if self.activation == 'relu':
                from onmt.modules.mlp.mlp import mlp_relu_function
                if mlp_relu_function is not None:
                    self.fused_function = mlp_relu_function
                    self.fused = True
            elif self.activation in ['silu', 'swish']:
                from onmt.modules.mlp.mlp import mlp_silu_function
                if mlp_silu_function is not None:
                    self.fused_function = mlp_silu_function
                    self.fused = True
            elif self.activation == 'gelu':
                if self.dropout_residual:
                    from onmt.modules.mlp.mlp import mlp_gelu_dropout_add_function
                    if mlp_gelu_dropout_add_function is not None:
                        self.fused_function = mlp_gelu_dropout_add_function
                        self.fused = True
                        self.fused_dropout_add = True

                if not self.fused:
                    from onmt.modules.mlp.mlp import mlp_gelu_function
                    if mlp_gelu_function is not None:
                        self.fused_function = mlp_gelu_function
                        self.fused = True
            elif self.activation == 'agelu':
                from onmt.modules.mlp.mlp import mlp_agelu_function
                if mlp_agelu_function is not None:
                    self.fused_function = mlp_agelu_function
                    self.fused = True

    def reset_parameters(self, init='normal'):
        if init == 'normal':
            std_ = math.sqrt(2.0 / (self.model_size + self.inner_size))
            nn.init.normal_(self.in_proj_weight, 0.0, std_)
            nn.init.normal_(self.out_proj_weight, 0.0, std_)
        else:
            std_ = math.sqrt(6.0 / (self.model_size + self.inner_size))
            nn.init.uniform_(self.in_proj_weight, -std_, std_)
            nn.init.uniform_(self.out_proj_weight, -std_, std_)

        nn.init.constant_(self.in_proj_bias, 0.0)
        nn.init.constant_(self.out_proj_bias, 0.0)

    def convert_autograd(self):
        """Convert the raw weight/bias parameters into nn.Linear modules."""
        if self.autograd:
            return

        with torch.no_grad():
            self.autograd = True
            self.linear_in = torch.nn.Linear(self.model_size, self.inner_size)
            self.linear_out = torch.nn.Linear(self.inner_size, self.model_size)

            self.linear_in.weight.copy_(self.in_proj_weight)
            self.linear_in.bias.copy_(self.in_proj_bias)
            self.linear_out.weight.copy_(self.out_proj_weight)
            self.linear_out.bias.copy_(self.out_proj_bias)

            del self.in_proj_weight
            del self.in_proj_bias
            del self.out_proj_weight
            del self.out_proj_bias

    def forward(self, input, *args, **kwargs):
        if self.fused and input.is_cuda and not self.autograd:
            # if autocast is enabled: manually cast the function args into half
            # (for some reason custom_fwd(...) doesn't work)
            with autocast(enabled=False):
                weights = [self.in_proj_weight.half(), self.out_proj_weight.half()]
                biases = [self.in_proj_bias.half(), self.out_proj_bias.half()]

                seq_len, bsz, hidden_size = input.size(0), input.size(1), input.size(2)
                dropout = self.dropout if self.training else 0.0

                if self.fused_dropout_add:
                    res_dropout = self.res_dropout if self.training else 0.0
                    hidden = self.fused_function(dropout, res_dropout,
                                                 input.half().view(seq_len * bsz, -1),
                                                 *weights, *biases).type_as(input)
                else:
                    recompute = onmt.constants.recompute
                    hidden = self.fused_function(dropout, recompute,
                                                 input.half().view(seq_len * bsz, -1),
                                                 *weights, *biases).type_as(input)

                hidden = hidden.view(seq_len, bsz, hidden_size)

            # verification code (only with dropout = 0.0)
            # with torch.no_grad():
            #     hidden_ = F.linear(self.act(F.linear(input, self.in_proj_weight, self.in_proj_bias)),
            #                        self.out_proj_weight, self.out_proj_bias).type_as(hidden)
            #
            #     if self.fused_dropout_add:
            #         hidden_.add_(input)
            #
            #     comp = torch.allclose(hidden, hidden_, rtol=1e-02, atol=1e-03)
            #     if not comp:
            #         print("Warning! The fused function doesn't match the PyTorch function.")
            #         print(hidden - hidden_)
        else:
            if self.autograd:
                hidden = self.linear_in(input)
            else:
                hidden = F.linear(input, self.in_proj_weight, self.in_proj_bias)

            if self.glu and self.activation != 'sigmoid':
                hidden, gate = hidden.chunk(2, dim=-1)
                hidden = self.act(hidden) * gate
            else:
                # either a plain activation (no GLU) or, for sigmoid + glu, the GLU function itself
                hidden = self.act(hidden)

            # ReLU without GLU already applies dropout inside ReLUDropout
            if not (not self.glu and self.activation == 'relu'):
                if self.variational:
                    hidden = variational_dropout(hidden, p=self.dropout, training=self.training,
                                                 inplace=self.activation in ['silu', 'relu', 'swish', 'gelu'])
                else:
                    hidden = F.dropout(hidden, p=self.dropout, training=self.training,
                                       inplace=self.activation in ['silu', 'relu', 'swish', 'gelu'])

            if self.autograd:
                hidden = self.linear_out(hidden)
            else:
                hidden = F.linear(hidden, self.out_proj_weight, self.out_proj_bias)

        if self.dropout_residual:
            # the fused dropout-add path has already added the residual
            if not self.fused_dropout_add:
                if not self.variational:
                    hidden = F.dropout(hidden, p=self.res_dropout, training=self.training) + input
                else:
                    hidden = variational_dropout(hidden, p=self.dropout, training=self.training) + input

        return hidden
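# --- Usage sketch (not part of the original code; added for illustration) -----
# A minimal example exercising the non-fused GLU path, which only needs plain
# PyTorch (the fused kernels and ReLUDropout live elsewhere in the onmt code
# base). The sizes follow the common Transformer convention (model_size=512,
# inner_size=2048) and the (seq_len, batch, model_size) layout assumed above.
def _ffn_usage_example():
    """Hypothetical helper, not in the original file."""
    ffn = PositionWiseFeedForward(model_size=512, inner_size=2048,
                                  dropout=0.1, activation='relu', glu=True)
    x = torch.randn(20, 4, 512)   # (seq_len, batch, model_size)
    y = ffn(x)                    # output keeps the input shape
    assert y.shape == x.shape
    return y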