def loss_compute(self, x: torch.Tensor, y: torch.Tensor, norm: int):
    x = self.model.generator(x)
    # label-smoothed loss, normalized by the number of target tokens in the batch
    loss = self.labelsmooth(x.contiguous().view(-1, x.size(-1)),
                            y.contiguous().view(-1)) / norm
    loss.backward()
    if self.model_opt is not None:
        self.model_opt.step()
        self.model_opt.optimizer.zero_grad()
    return loss.item() * norm
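# Hedged usage sketch (not from the original source): a minimal epoch loop driving
# loss_compute. Assumes `compute` is an instance of the class that owns loss_compute and
# that `data_iter` is a hypothetical iterator yielding (decoder_output, target_ids,
# num_target_tokens) per batch. Since loss_compute returns loss.item() * norm, summing over
# batches and dividing by the total token count gives the average per-token loss.
def run_epoch_sketch(data_iter, compute):
    total_loss, total_tokens = 0.0, 0
    for out, y, ntokens in data_iter:
        total_loss += compute.loss_compute(out, y, ntokens)
        total_tokens += ntokens
    return total_loss / total_tokens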
def get_token(h: torch.tensor, x: torch.tensor, token: int): """ Get specific token embedding (e.g. [CLS]) """ emb_size = h.shape[-1] token_h = h.view(-1, emb_size) flat = x.contiguous().view(-1) # get contextualized embedding of given token token_h = token_h[flat == token, :] return token_h
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
    """
    :param ctx: autograd context; ctx.for_backwards holds the (idx, N) saved in forward
    :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
    :return: (B, C, N) gradient of the features, and None for the non-differentiable idx input
    """
    idx, N = ctx.for_backwards
    grad_features = ppp_ops.group_points_grad_cuda(grad_out.contiguous(), idx, N)
    return grad_features, None
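# Hedged reference sketch (assumption, not the original CUDA kernel): a pure-PyTorch
# equivalent of ppp_ops.group_points_grad_cuda, which scatter-adds the grouped gradients
# (B, C, npoint, nsample) back onto the N original points addressed by idx (B, npoint, nsample).
def group_points_grad_reference(grad_out: torch.Tensor, idx: torch.Tensor, N: int) -> torch.Tensor:
    B, C, npoint, nsample = grad_out.shape
    grad_features = grad_out.new_zeros(B, C, N)
    # broadcast idx over the channel dimension so it addresses every feature channel
    flat_idx = idx.long().view(B, 1, npoint * nsample).expand(B, C, -1).contiguous()
    grad_features.scatter_add_(2, flat_idx, grad_out.reshape(B, C, npoint * nsample))
    return grad_features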
def valid_loss_compute(self, x: torch.Tensor, y: torch.Tensor, norm: int):
    # same as loss_compute, but without the backward pass / optimizer step (validation only)
    x = self.model.generator(x)
    loss = self.labelsmooth(x.contiguous().view(-1, x.size(-1)),
                            y.contiguous().view(-1)) / norm
    return loss.item() * norm
def forward(self,
            query: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            need_weights: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    :param query: [sequence_length, batch_size, embed_dim]
    :param padding_mask: [batch_size, sequence_len]
    :param attention_mask: [batch_size, sequence_len, sequence_len]
    :param need_weights: return the per-head attention scores for visualization
    :return: attention output [sequence_length, batch_size, embed_dim] and, if need_weights,
             attention scores [batch_size, num_heads, sequence_len, sequence_len] (else None)
    """
    sequence_len, batch_size, embed_dim = query.size()
    assert embed_dim == self.embed_dim

    # single input projection producing query, key and value in one matmul
    query, key, value = self.in_projection(query).chunk(3, dim=-1)
    query *= self.scaling

    # [batch_size * self.num_heads, sequence_len, self.head_dim]
    query = query.contiguous().view(sequence_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
    key = key.contiguous().view(sequence_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
    value = value.contiguous().view(sequence_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)

    # [batch_size * self.num_heads, sequence_len, sequence_len]
    attention_scores = torch.bmm(query, key.transpose(1, 2))

    # [batch_size, self.num_heads, sequence_len, sequence_len]
    attention_scores = self.split_heads(attention_scores, batch_size, sequence_len)

    # fp16 compatibility
    parameters_type = next(self.parameters()).dtype

    if attention_mask is not None:
        assert attention_mask.size(1) == sequence_len
        assert attention_mask.size(2) == sequence_len
        # [batch_size, 1, sequence_len, sequence_len]
        attention_mask = attention_mask.unsqueeze(1)
        attention_mask = attention_mask.to(dtype=parameters_type)
        attention_scores += attention_mask

    if padding_mask is not None:
        assert padding_mask.size(0) == batch_size
        assert padding_mask.size(1) == sequence_len
        # padding_mask = [batch_size, sequence_len]
        attention_scores = attention_scores.masked_fill(
            padding_mask.unsqueeze(1).unsqueeze(2),
            float(-10000.),
        )

    # [batch_size * self.num_heads, sequence_len, sequence_len]
    attention_scores = self.join_heads(attention_scores, batch_size, sequence_len)

    # compute the softmax in fp32 when running in fp16, then cast back so the
    # bmm with `value` below stays in the original dtype
    if attention_scores.dtype == torch.float16:
        tensor_type = torch.float32
    else:
        tensor_type = attention_scores.dtype
    attention_scores = F.softmax(attention_scores, dim=-1, dtype=tensor_type).type_as(value)
    attention_scores = self.dropout(attention_scores)

    # attention_scores = [batch_size * self.num_heads, sequence_len, sequence_len]
    # value            = [batch_size * self.num_heads, sequence_len, self.head_dim]
    # [batch_size * self.num_heads, sequence_len, self.head_dim]
    attention_output = torch.bmm(attention_scores, value)

    # [sequence_len, batch_size, embed_dim]
    attention_output = attention_output.transpose(0, 1).contiguous().view(
        sequence_len, batch_size, embed_dim)
    attention_output = self.out_projection(attention_output)

    # for visualizing attention scores
    if need_weights:
        # [batch_size, self.num_heads, sequence_len, sequence_len]
        attention_scores = self.split_heads(attention_scores, batch_size, sequence_len)
    else:
        attention_scores = None

    return attention_output, attention_scores
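# Hedged sketch (assumption, not the original class): given the shapes annotated above,
# split_heads/join_heads most likely just move between the fused (batch * heads) layout
# required by torch.bmm and an explicit per-head layout used for masking and visualization.
# The originals are presumably methods that read num_heads from self; it is an explicit
# argument here only to keep the sketch self-contained.
def split_heads_sketch(scores: torch.Tensor, batch_size: int, sequence_len: int, num_heads: int) -> torch.Tensor:
    # (batch_size * num_heads, seq_len, seq_len) -> (batch_size, num_heads, seq_len, seq_len)
    return scores.view(batch_size, num_heads, sequence_len, sequence_len)

def join_heads_sketch(scores: torch.Tensor, batch_size: int, sequence_len: int, num_heads: int) -> torch.Tensor:
    # (batch_size, num_heads, seq_len, seq_len) -> (batch_size * num_heads, seq_len, seq_len)
    return scores.view(batch_size * num_heads, sequence_len, sequence_len)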