import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class BahdanauAttention(nn.Module):
    """
    Bahdanau Attention (https://arxiv.org/abs/1409.0473)
    Implementation is very similar to tf.contrib.seq2seq.BahdanauAttention

    Args:
        query_size (int): feature dimension for query
        key_size (int): feature dimension for keys
        num_units (int): internal feature dimension
        normalize (bool): whether to normalize energy term. Default: `False`
        init_weight (float): range for uniform initializer used to initialize
            Linear key and query transform layers and linear_att vector.
            Default: 0.1
    """

    def __init__(
        self,
        query_size,
        key_size,
        num_units,
        normalize=False,
        init_weight=0.1,
        fusion=True,
    ):
        super(BahdanauAttention, self).__init__()

        self.normalize = normalize
        self.num_units = num_units

        self.linear_q = nn.Linear(query_size, num_units, bias=False)
        self.linear_k = nn.Linear(key_size, num_units, bias=False)
        nn.init.uniform_(self.linear_q.weight.data, -init_weight, init_weight)
        nn.init.uniform_(self.linear_k.weight.data, -init_weight, init_weight)

        self.linear_att = Parameter(torch.Tensor(num_units))

        self.mask = None

        if self.normalize:
            self.normalize_scalar = Parameter(torch.Tensor(1))
            self.normalize_bias = Parameter(torch.Tensor(num_units))
        else:
            self.register_parameter("normalize_scalar", None)
            self.register_parameter("normalize_bias", None)

        self.fusion = fusion
        self.reset_parameters(init_weight)

    def reset_parameters(self, init_weight):
        """
        Sets initial random values for trainable parameters.

        Args:
            init_weight (float): range for the uniform initializer
        """
        stdv = 1.0 / math.sqrt(self.num_units)
        self.linear_att.data.uniform_(-init_weight, init_weight)

        if self.normalize:
            self.normalize_scalar.data.fill_(stdv)
            self.normalize_bias.data.zero_()

    def set_mask(self, context_len, context):
        """
        Sets self.mask which is applied before softmax:
        ones for inactive context fields, zeros for active context fields.

        Args:
            context_len (`obj`:torch.Tensor): lengths of the context sequences (b)
            context (`obj`:torch.Tensor): (t_k x b x n)
        """
        max_len = context.size(0)
        indices = torch.arange(0, max_len, dtype=torch.int64,
                               device=context.device)
        self.mask = indices >= (context_len.unsqueeze(1))

    def calc_score(self, att_query, att_keys):
        """
        Calculate Bahdanau score.

        Args:
            att_query (`obj`:torch.Tensor): (b x t_q x n)
            att_keys (`obj`:torch.Tensor): (b x t_k x n)

        Returns:
            scores (`obj`:torch.Tensor): (b x t_q x t_k)
        """
        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys

        if self.normalize:
            sum_qk = sum_qk + self.normalize_bias
            linear_att = self.linear_att / self.linear_att.norm()
            linear_att = linear_att * self.normalize_scalar
        else:
            linear_att = self.linear_att

        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def forward(self, query, keys):
        """
        Args:
            query (`obj`:torch.Tensor): (t_q x b x n)
            keys (`obj`:torch.Tensor): (t_k x b x n)

        Returns:
            (context, scores_normalized)
            context: (t_q x b x n)
            scores_normalized: (t_q x b x t_k)
        """
        # first dim of keys and query has to be 'batch', it's needed for bmm
        keys = keys.transpose(0, 1)
        if query.dim() == 3:
            query = query.transpose(0, 1)

        if query.dim() == 2:
            single_query = True
            query = query.unsqueeze(1)
        else:
            single_query = False

        b = query.size(0)
        t_k = keys.size(1)
        t_q = query.size(1)

        # FC layers to transform query and key
        processed_query = self.linear_q(query)
        processed_key = self.linear_k(keys)

        # scores: (b x t_q x t_k)
        if self.fusion:
            # fused path uses normalize_scalar and normalize_bias, so it
            # assumes normalize=True; fused_calc_score is expected to be
            # defined elsewhere (e.g. a fused kernel)
            linear_att = self.linear_att / self.linear_att.norm()
            linear_att = linear_att * self.normalize_scalar
            scores = fused_calc_score(processed_query, processed_key,
                                      self.normalize_bias, linear_att)
        else:
            scores = self.calc_score(processed_query, processed_key)

        if self.mask is not None:
            mask = self.mask.unsqueeze(1).expand(b, t_q, t_k)
            # I can't use -INF because of overflow check in pytorch
            scores.data.masked_fill_(mask, -65504.0)

        # Normalize the scores, softmax over t_k
        scores_normalized = F.softmax(scores, dim=-1)

        # Calculate the weighted average of the attention inputs according to
        # the scores
        # context: (b x t_q x n)
        context = torch.bmm(scores_normalized, keys)

        if single_query:
            context = context.squeeze(1)
            scores_normalized = scores_normalized.squeeze(1)
        else:
            context = context.transpose(0, 1)
            scores_normalized = scores_normalized.transpose(0, 1)

        return context, scores_normalized
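# A minimal usage sketch for the attention module above (a hedged example,
# not part of the original code). It assumes the non-fused path
# (fusion=False) so the external fused_calc_score helper is not needed, and
# uses the (time x batch x features) layout from the docstrings; all sizes
# are illustrative.
def _bahdanau_attention_example():
    t_q, t_k, b, n = 4, 7, 2, 16
    attn = BahdanauAttention(query_size=n, key_size=n, num_units=n,
                             normalize=True, fusion=False)
    query = torch.randn(t_q, b, n)
    keys = torch.randn(t_k, b, n)
    # mask padded positions of the two context sequences (lengths 7 and 5)
    attn.set_mask(torch.tensor([7, 5]), keys)
    context, scores = attn(query, keys)
    # context: (t_q x b x n), scores: (t_q x b x t_k)
    return context.shape, scores.shape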
class BahdanauAttention(nn.Module):
    """
    Bahdanau Attention (https://arxiv.org/abs/1409.0473)
    Implementation is very similar to tf.contrib.seq2seq.BahdanauAttention
    """

    def __init__(self, query_size, key_size, num_units, normalize=False,
                 batch_first=False, init_weight=0.1):
        """
        Constructor for the BahdanauAttention.

        :param query_size: feature dimension for query
        :param key_size: feature dimension for keys
        :param num_units: internal feature dimension
        :param normalize: whether to normalize energy term
        :param batch_first: if True batch size is the 1st dimension, if False
            the sequence is first and batch size is second
        :param init_weight: range for uniform initializer used to initialize
            Linear key and query transform layers and linear_att vector
        """
        super(BahdanauAttention, self).__init__()

        self.normalize = normalize
        self.batch_first = batch_first
        self.num_units = num_units

        self.linear_q = nn.Linear(query_size, num_units, bias=False)
        self.linear_k = nn.Linear(key_size, num_units, bias=False)
        nn.init.uniform_(self.linear_q.weight.data, -init_weight, init_weight)
        nn.init.uniform_(self.linear_k.weight.data, -init_weight, init_weight)

        self.linear_att = Parameter(torch.Tensor(num_units))

        self.mask = None

        if self.normalize:
            self.normalize_scalar = Parameter(torch.Tensor(1))
            self.normalize_bias = Parameter(torch.Tensor(num_units))
        else:
            self.register_parameter('normalize_scalar', None)
            self.register_parameter('normalize_bias', None)

        self.reset_parameters(init_weight)

    def reset_parameters(self, init_weight):
        """
        Sets initial random values for trainable parameters.
        """
        stdv = 1. / math.sqrt(self.num_units)
        self.linear_att.data.uniform_(-init_weight, init_weight)

        if self.normalize:
            self.normalize_scalar.data.fill_(stdv)
            self.normalize_bias.data.zero_()

    def set_mask(self, context_len, context):
        """
        Sets self.mask which is applied before softmax:
        ones for inactive context fields, zeros for active context fields.

        :param context_len: b
        :param context: if batch_first: (b x t_k x n) else: (t_k x b x n)

        self.mask: (b x t_k)
        """
        if self.batch_first:
            max_len = context.size(1)
        else:
            max_len = context.size(0)

        indices = torch.arange(0, max_len, dtype=torch.int64,
                               device=context.device)
        self.mask = indices >= (context_len.unsqueeze(1))

    def calc_score(self, att_query, att_keys):
        """
        Calculate Bahdanau score.

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        :returns: b x t_q x t_k scores
        """
        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys

        if self.normalize:
            sum_qk = sum_qk + self.normalize_bias
            linear_att = self.linear_att / self.linear_att.norm()
            linear_att = linear_att * self.normalize_scalar
        else:
            linear_att = self.linear_att

        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def forward(self, query, keys):
        """
        :param query: if batch_first: (b x t_q x n) else: (t_q x b x n)
        :param keys: if batch_first: (b x t_k x n) else: (t_k x b x n)

        :returns: (context, scores_normalized)
            context: if batch_first: (b x t_q x n) else (t_q x b x n)
            scores_normalized: if batch_first (b x t_q x t_k) else (t_q x b x t_k)
        """
        # first dim of keys and query has to be 'batch', it's needed for bmm
        if not self.batch_first:
            keys = keys.transpose(0, 1)
            if query.dim() == 3:
                query = query.transpose(0, 1)

        if query.dim() == 2:
            single_query = True
            query = query.unsqueeze(1)
        else:
            single_query = False

        b = query.size(0)
        t_k = keys.size(1)
        t_q = query.size(1)

        # FC layers to transform query and key
        processed_query = self.linear_q(query)
        processed_key = self.linear_k(keys)

        # scores: (b x t_q x t_k)
        scores = self.calc_score(processed_query, processed_key)

        if self.mask is not None:
            mask = self.mask.unsqueeze(1).expand(b, t_q, t_k)
            # I can't use -INF because of overflow check in pytorch
            scores.masked_fill_(mask, -65504.0)

        # Normalize the scores, softmax over t_k
        scores_normalized = F.softmax(scores, dim=-1)

        # Calculate the weighted average of the attention inputs according to
        # the scores
        # context: (b x t_q x n)
        context = torch.bmm(scores_normalized, keys)

        if single_query:
            context = context.squeeze(1)
            scores_normalized = scores_normalized.squeeze(1)
        elif not self.batch_first:
            context = context.transpose(0, 1)
            scores_normalized = scores_normalized.transpose(0, 1)

        return context, scores_normalized
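# A hedged usage sketch for the batch_first variant above (illustrative
# sizes, not part of the original code): with batch_first=True both inputs
# and both outputs use the (batch x time x features) layout.
def _bahdanau_attention_batch_first_example():
    b, t_q, t_k, n = 2, 4, 7, 16
    attn = BahdanauAttention(query_size=n, key_size=n, num_units=n,
                             normalize=False, batch_first=True)
    query = torch.randn(b, t_q, n)
    keys = torch.randn(b, t_k, n)
    attn.set_mask(torch.tensor([7, 5]), keys)
    context, scores = attn(query, keys)
    # context: (b x t_q x n), scores: (b x t_q x t_k)
    return context.shape, scores.shape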
class Loss(nn.Module):
    def __init__(self, args):
        super(Loss, self).__init__()
        self.CMPM = args.CMPM
        self.CMPC = args.CMPC
        self.epsilon = args.epsilon
        self.num_classes = args.num_classes

        if args.resume:
            checkpoint = torch.load(args.model_path)
            self.W = Parameter(checkpoint['W'])
            print('=========> Loading in parameter W from pretrained models')
        else:
            self.W = Parameter(torch.randn(args.feature_size, args.num_classes))
            self.init_weight()

    def init_weight(self):
        nn.init.xavier_uniform_(self.W.data, gain=1)

    def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels):
        """
        Cross-Modal Projection Classification loss (CMPC)
        :param image_embeddings: Tensor with dtype torch.float32
        :param text_embeddings: Tensor with dtype torch.float32
        :param labels: Tensor with dtype torch.int32
        :return:
        """
        criterion = nn.CrossEntropyLoss(reduction='mean')
        self.W_norm = self.W / self.W.norm(dim=0)
        # labels_onehot = one_hot_coding(labels, self.num_classes).float()

        image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

        # project each embedding onto the matched embedding from the other modality
        image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm
        text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm

        image_logits = torch.matmul(image_proj_text, self.W_norm)
        text_logits = torch.matmul(text_proj_image, self.W_norm)

        # labels_one_hot = one_hot_coding(labels, num_classes)
        '''
        ipt_loss = criterion(input=image_logits, target=labels)
        tpi_loss = criterion(input=text_logits, target=labels)
        cmpc_loss = ipt_loss + tpi_loss
        '''
        cmpc_loss = criterion(image_logits, labels) + criterion(text_logits, labels)
        # cmpc_loss = - (F.log_softmax(image_logits, dim=1) + F.log_softmax(text_logits, dim=1)) * labels_onehot
        # cmpc_loss = torch.mean(torch.sum(cmpc_loss, dim=1))

        # classification accuracy for observation
        image_pred = torch.argmax(image_logits, dim=1)
        text_pred = torch.argmax(text_logits, dim=1)

        image_precision = torch.mean((image_pred == labels).float())
        text_precision = torch.mean((text_pred == labels).float())

        return cmpc_loss, image_precision, text_precision

    def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels):
        """
        Cross-Modal Projection Matching loss (CMPM)
        :param image_embeddings: Tensor with dtype torch.float32
        :param text_embeddings: Tensor with dtype torch.float32
        :param labels: Tensor with dtype torch.int32
        :return:
            i2t_loss: cmpm loss for image projected to text
            t2i_loss: cmpm loss for text projected to image
            pos_avg_sim: average cosine-similarity for positive pairs
            neg_avg_sim: average cosine-similarity for negative pairs
        """
        batch_size = image_embeddings.shape[0]
        labels_reshape = torch.reshape(labels, (batch_size, 1))
        labels_dist = labels_reshape - labels_reshape.t()
        labels_mask = (labels_dist == 0)

        image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
        image_proj_text = torch.matmul(image_embeddings, text_norm.t())
        text_proj_image = torch.matmul(text_embeddings, image_norm.t())

        # normalize the true matching distribution
        labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1)

        i2t_pred = F.softmax(image_proj_text, dim=1)
        # i2t_loss = i2t_pred * torch.log((i2t_pred + self.epsilon) / (labels_mask_norm + self.epsilon))
        i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + self.epsilon))

        t2i_pred = F.softmax(text_proj_image, dim=1)
        # t2i_loss = t2i_pred * torch.log((t2i_pred + self.epsilon) / (labels_mask_norm + self.epsilon))
        t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + self.epsilon))

        cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))

        sim_cos = torch.matmul(image_norm, text_norm.t())

        pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask))
        neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0))

        return cmpm_loss, pos_avg_sim, neg_avg_sim

    def forward(self, image_embeddings, text_embeddings, labels):
        cmpm_loss = 0.0
        cmpc_loss = 0.0
        image_precision = 0.0
        text_precision = 0.0
        neg_avg_sim = 0.0
        pos_avg_sim = 0.0

        if self.CMPM:
            cmpm_loss, pos_avg_sim, neg_avg_sim = self.compute_cmpm_loss(
                image_embeddings, text_embeddings, labels)
        if self.CMPC:
            cmpc_loss, image_precision, text_precision = self.compute_cmpc_loss(
                image_embeddings, text_embeddings, labels)

        loss = cmpm_loss + cmpc_loss

        return cmpm_loss, cmpc_loss, loss, image_precision, text_precision, pos_avg_sim, neg_avg_sim
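# A hedged end-to-end sketch of the joint CMPM/CMPC loss above (illustrative
# values, not part of the original code). The args object only mimics the
# fields the constructor reads; labels are int64 here because
# nn.CrossEntropyLoss expects long targets.
def _cmpm_cmpc_loss_example():
    from types import SimpleNamespace
    args = SimpleNamespace(CMPM=True, CMPC=True, epsilon=1e-8,
                           num_classes=10, feature_size=32,
                           resume=False, model_path=None)
    loss_fn = Loss(args)
    image_embeddings = torch.randn(4, 32)
    text_embeddings = torch.randn(4, 32)
    labels = torch.tensor([0, 1, 0, 2], dtype=torch.int64)
    cmpm_loss, cmpc_loss, loss, image_prec, text_prec, pos_sim, neg_sim = \
        loss_fn(image_embeddings, text_embeddings, labels)
    return loss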