# Shared imports used by the standalone snippets below (the class methods additionally
# rely on their defining modules, e.g. fairseq's `utils` and `MultiheadAttention`).
import math
import os

import torch
import torch.nn.functional as F

import entmax
from entmax import sparsemax, entmax15, entmax_bisect


def attention(query, key, value, params, mask=None, dropout=None, alpha=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    if params.attn_type == 'softmax':
        p_attn = F.softmax(scores, dim=-1)
    elif params.attn_type == 'sparsemax':
        p_attn = sparsemax(scores, dim=-1)
    elif params.attn_type == 'entmax15':
        p_attn = entmax15(scores, dim=-1)
    elif params.attn_type == 'entmax':
        # use the functional bisection form; `EntmaxBisect(...)` would construct a module
        # rather than apply the mapping to `scores`
        p_attn = entmax_bisect(scores, alpha, n_iter=25)
    else:
        raise ValueError("Unknown attn_type: %s" % params.attn_type)
    if dropout is not None:
        p_attn = dropout(p_attn)
    p_attn = p_attn.to(torch.float32)
    return torch.matmul(p_attn, value), scores, p_attn
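# Hedged usage sketch (not part of the original snippet): calling `attention` above with
# random tensors and a minimal stand-in for `params`. `SimpleNamespace` is used here only
# as a hypothetical container for the `attn_type` field the function expects.
def _demo_attention():
    from types import SimpleNamespace
    torch.manual_seed(0)
    q = torch.randn(2, 4, 10, 16)   # [batch, heads, seq_len, d_k]
    k = torch.randn(2, 4, 10, 16)
    v = torch.randn(2, 4, 10, 16)
    params = SimpleNamespace(attn_type='entmax15')
    out, scores, p_attn = attention(q, k, v, params)
    # entmax15 zeroes out low-scoring keys, so rows of p_attn are typically sparse
    print(out.shape, (p_attn == 0).float().mean().item())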
def forward(self, query, key, mask=None):
    # query and key are two copies of the sentence representation H
    # query: [nbatches, seq_len, d_model]
    # key:   [nbatches, seq_len, d_model]
    # mask:  [nbatches, seq_len] padding mask (broadcast to [nbatches, 1, 1, seq_len] below)
    nbatches = query.size(0)

    # 1) Do all the linear projections in batch from d_model => h x d_k
    query = self.w_q(query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
    key = self.w_k(key).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
    # [nbatches, h, seq_len, d_k]

    # 2) Apply attention on all the projected vectors in batch.
    # Compute 'Scaled Dot Product Attention'
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
    # [nbatches, h, seq_len, seq_len]
    if mask is not None:
        key_padding_mask = mask.unsqueeze(1).unsqueeze(2)
        scores = scores.masked_fill(key_padding_mask == 0, float("-inf"))
    p_attn = entmax15(scores, dim=-1)  # [nbatches, h, seq_len, seq_len]
    if self.dropout is not None:
        p_attn = self.dropout(p_attn)

    # 3) Average the attention maps over heads.
    p_attn = torch.sum(p_attn, dim=1) / self.h
    return p_attn  # [nbatches, seq_len, seq_len]
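# Hedged sketch (added for illustration, not from the original code): how the padding
# mask in the forward above is expected to broadcast. Assumes `mask` is a
# [nbatches, seq_len] padding mask with 0 at padded positions; unsqueeze(1).unsqueeze(2)
# turns it into [nbatches, 1, 1, seq_len], which broadcasts against scores of shape
# [nbatches, h, seq_len, seq_len] and masks whole key columns before entmax15.
def _demo_key_padding_mask(nbatches=2, h=4, seq_len=5):
    scores = torch.randn(nbatches, h, seq_len, seq_len)
    mask = torch.ones(nbatches, seq_len)
    mask[:, -2:] = 0                                     # pretend the last two tokens are padding
    key_padding_mask = mask.unsqueeze(1).unsqueeze(2)    # [nbatches, 1, 1, seq_len]
    scores = scores.masked_fill(key_padding_mask == 0, float("-inf"))
    p_attn = entmax15(scores, dim=-1)
    print((p_attn[..., -2:] == 0).all().item())          # padded keys receive zero weight
    return p_attn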
def forward(self, scores: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    """Map a score vector to a probability distribution halfway between softmax and sparsemax.

    Args:
        scores (torch.Tensor): (Batch x Sequence Length) Attention scores
            (also referred to as weights)
        mask (torch.BoolTensor): (Batch x Sequence Length) Specifies which
            indices are just padding

    Returns:
        torch.Tensor: Distribution halfway between softmax and sparsemax
    """
    masked_scores = replace_masked_values(scores, mask, -float("inf"))
    return entmax15(masked_scores, dim=-1)
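# Hedged helper sketch: `replace_masked_values` above is assumed to behave like the
# AllenNLP utility of the same name, where `mask` is True for real tokens and False for
# padding. A minimal stand-in under that assumption (flip the mask if the original
# helper uses the opposite convention):
def replace_masked_values(tensor: torch.Tensor, mask: torch.BoolTensor, replace_with: float) -> torch.Tensor:
    # keep values where mask is True, overwrite padded positions with `replace_with`
    return tensor.masked_fill(~mask, replace_with)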
def sparse_attention(query, key, value, alpha, mask=None, dropout=None):
    "Use sparse activation function"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    if alpha == 2:
        p_attn = entmax.sparsemax(scores, -1)
    elif alpha == 1.5:
        p_attn = entmax.entmax15(scores, -1)
    else:
        raise NotImplementedError
    if dropout is not None:
        p_attn = dropout(p_attn)
    # return torch.matmul(p_attn, value), scores.squeeze(1).squeeze(1)
    return torch.matmul(p_attn, value), p_attn
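# Hedged usage sketch (illustrative only): alpha=1.5 selects entmax15 and alpha=2 selects
# sparsemax in `sparse_attention` above; both return rows that sum to one but usually
# contain exact zeros.
def _demo_sparse_attention():
    q = torch.randn(2, 4, 8, 16)    # [batch, heads, seq_len, d_k]
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    out, p_attn = sparse_attention(q, k, v, alpha=1.5)
    print(out.shape, p_attn.sum(-1).mean().item(), (p_attn == 0).float().mean().item())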
def nce_loss_entmax(positive, negatives, temperature):
    """
    :param positive: b * k
    :param negatives: b * k * num_negatives
    :return:
    """
    negatives_and_positive = torch.cat([positive.unsqueeze(2), negatives], dim=2)
    entmax_probs = entmax15(negatives_and_positive, dim=2)
    loss_batch = torch.log(entmax_probs[:, :, 0] + 1e-8)
    # sum over k, mean over batches
    loss = -loss_batch.sum(1).mean(0)
    # loss = -torch.mean(loss_batch)
    return loss
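# Hedged usage sketch matching the shapes in the docstring above. Note that `temperature`
# is accepted but unused as written, and the 1e-8 guard matters because entmax15 can
# assign exactly zero probability to the positive sample.
def _demo_nce_loss_entmax(b=4, k=3, num_negatives=5):
    positive = torch.randn(b, k)
    negatives = torch.randn(b, k, num_negatives)
    loss = nce_loss_entmax(positive, negatives, temperature=1.0)
    print(loss.item())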
def forward(self, char_encoder_result, tag_encoder_result, true_output_seq=None):
    char_encoding, (char_hn, char_cn) = char_encoder_result
    batch_size = len(char_encoding)
    tag_encoding, (tag_hn, tag_cn) = tag_encoder_result
    char_encoding = torch.transpose(char_encoding, 0, 1)  # move seq_len dimension to the front
    tag_encoding = torch.transpose(tag_encoding, 0, 1)
    current_input = torch.zeros((batch_size, self.n_chars), device=char_encoding.device)
    last_cell_state = (
        # torch.cat((char_hn[0], tag_hn[0]), dim=-1),
        # torch.cat((char_cn[0], tag_cn[0]), dim=-1)
        torch.cat((char_hn[0], char_hn[1]), dim=-1),
        torch.cat((char_cn[0], char_cn[1]), dim=-1)
    )

    def time_step_fn(input_1, state_0):
        h1, c1 = self.lstm_cell(input_1, state_0)
        query = torch.unsqueeze(h1, dim=0)  # use cell output as query
        char_attention, _ = self.char_attention(query=query, key=char_encoding, value=char_encoding)
        tag_attention, _ = self.tag_attention(query=query, key=tag_encoding, value=tag_encoding)
        aggregated_attention = torch.cat([char_attention, tag_attention], dim=-1).squeeze(0)
        aggregated_attention = torch.relu(aggregated_attention)
        output = self.output_layer(aggregated_attention)  # relu instead?
        return output, (h1, c1)

    top = [[current_input, last_cell_state, [], 0]]  # beam search candidates; last entry is log probability
    teacher_forcing = true_output_seq is not None

    for time_step in range(len(char_encoding)):
        time_step_leaders = []
        for candidate in top:
            next_input, current_cell_state, current_output_seq, sequence_probability = candidate
            candidate_output, candidate_next_state = time_step_fn(next_input, current_cell_state)
            if teacher_forcing:
                # teacher forcing; in this scenario, top only has 1 item
                top = [[None, candidate_next_state, current_output_seq + [candidate_output], 1]]
                top[0][0] = true_output_seq[:, time_step, :]
                continue
            else:
                probabilities = entmax15(candidate_output, dim=-1)
                tk = torch.topk(probabilities, self.beam_size, dim=-1)
                top_indices = tk.indices[0]
                top_probs = tk.values[0]
                for i in range(self.beam_size):
                    time_step_leaders.append(
                        [top_indices[i], top_probs[i], candidate_next_state, current_output_seq,
                         sequence_probability + torch.log(top_probs[i])]
                    )

        if not teacher_forcing:
            new_top = []
            time_step_leaders.sort(key=lambda x: x[4])
            beam_size = self.beam_size
            if time_step == self.beam_size - 1:
                beam_size = 1
            for leader in time_step_leaders[-beam_size:]:
                leader_index, leader_prob, leader_next_state, leader_current_output_seq, probability = leader
                one_hot = torch.nn.functional.one_hot(leader_index, num_classes=self.n_chars)
                one_hot = torch.unsqueeze(one_hot, dim=0).float()  # add batch dimension
                new_top.append([one_hot, leader_next_state, leader_current_output_seq + [one_hot], probability])
            top = new_top

    return_sequence = top[0][2]
    return_sequence = torch.stack(return_sequence)
    return torch.transpose(return_sequence, 0, 1)
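# Hedged note (illustration, not from the original code): entmax15 can return exact
# zeros, so `torch.log(top_probs[i])` in the beam search above may be -inf when the beam
# size exceeds the support of the distribution. One hypothetical guard is to clamp
# before taking the log:
def _safe_log_prob(prob: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # zero-probability beams then score log(eps) instead of -inf and stay comparable
    return torch.log(prob.clamp_min(eps))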
def eval_semisuper_vae(vae, classifier, loader_unlabeled, super_loss,
                       loader_labeled=[None], train=False, optimizer=None,
                       topk=0, grad_estimator=bs_lib.reinforce,
                       grad_estimator_kwargs={'grad_estimator_kwargs': None},
                       n_samples=1, train_labeled_only=False, epoch=0,
                       baseline_optimizer=None, normalizer='softmax'):
    if train:
        assert optimizer is not None
        vae.train()
        classifier.train()
    else:
        vae.eval()
        classifier.eval()

    sum_loss = 0.0
    num_images = 0.0
    total_nz = 0.0

    for labeled_data, unlabeled_data in zip(cycle(loader_labeled), loader_unlabeled):
        unlabeled_image = unlabeled_data['image'].to(device)

        if labeled_data is not None:
            labeled_image = labeled_data['image'].to(device)
            true_labels = labeled_data['label'].to(device)

            # get loss on labeled images
            supervised_loss = get_supervised_loss(
                vae, classifier, labeled_image, true_labels, super_loss).sum()

            num_labeled = len(loader_labeled.sampler)
            num_labeled_batch = labeled_image.shape[0]
        else:
            supervised_loss = 0.0
            num_labeled = 0.0
            num_labeled_batch = 1.0

        # run through classifier
        scores = classifier.forward(unlabeled_image)
        if normalizer == 'softmax':
            class_weights = torch.softmax(scores, dim=-1)
        elif normalizer == 'entmax15':
            class_weights = entmax15(scores, dim=-1)
        elif normalizer == 'sparsemax':
            class_weights = sparsemax(scores, dim=-1)
        else:
            raise NameError("%s is not a valid normalizer!" % (normalizer, ))

        # get a mask of nonzeros
        nz = (class_weights > 0).to(class_weights.device)

        if train:
            train_labeled_only_bool = 1.
            if train_labeled_only:
                n_samples = 0
                train_labeled_only_bool = 0.

            # flush gradients
            optimizer.zero_grad()

            # get unlabeled pseudoloss: here we use our
            # Rao-Blackwellization or some other gradient estimator
            f_z = lambda z: vae_utils.get_loss_from_one_hot_label(
                vae, unlabeled_image, z)

            unlabeled_ps_loss = 0.0
            for i in range(n_samples):
                unlabeled_ps_loss_ = rb_lib.get_raoblackwell_ps_loss(
                    f_z, class_weights, topk=topk, epoch=epoch,
                    data=unlabeled_image,
                    grad_estimator=grad_estimator,
                    grad_estimator_kwargs=grad_estimator_kwargs)
                unlabeled_ps_loss += unlabeled_ps_loss_

            unlabeled_ps_loss = unlabeled_ps_loss / max(n_samples, 1)

            kl_q = torch.sum(class_weights[nz] * torch.log(class_weights[nz]))

            total_ps_loss = \
                (unlabeled_ps_loss + kl_q) * train_labeled_only_bool * \
                len(loader_unlabeled.sampler) / unlabeled_image.shape[0] + \
                supervised_loss * num_labeled / labeled_image.shape[0]

            # backprop gradients from pseudo loss
            total_ps_loss.backward(retain_graph=True)
            optimizer.step()

            if baseline_optimizer is not None:
                # for RELAX: as it trains to minimize a control variate
                # flush gradients
                optimizer.zero_grad()
                baseline_optimizer.zero_grad()
                loss_grads = grad(total_ps_loss, classifier.parameters(), create_graph=True)
                gn2 = sum([grd.norm()**2 for grd in loss_grads])
                gn2.backward()
                baseline_optimizer.step()

        # loss at MAP value of z
        loss = vae_utils.get_labeled_loss(
            vae, unlabeled_image, torch.argmax(scores, dim=1)).detach().sum()

        sum_loss += loss
        num_images += unlabeled_image.shape[0]
        total_nz += nz.sum().item()

    return sum_loss / num_images, total_nz / num_images
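# Hedged sketch (illustrative only): the normalizer switch above in isolation. The sparse
# normalizers give each example a small support, which the function above tracks through
# the nonzero mask `nz` and the running `total_nz` count.
def _demo_normalizers(n_classes=10):
    scores = torch.randn(3, n_classes)
    for name, fn in [('softmax', lambda s: torch.softmax(s, dim=-1)),
                     ('entmax15', lambda s: entmax15(s, dim=-1)),
                     ('sparsemax', lambda s: sparsemax(s, dim=-1))]:
        class_weights = fn(scores)
        nz = (class_weights > 0)
        print(name, 'mean support size:', nz.sum(-1).float().mean().item())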
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    need_weights: bool = True,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    before_softmax: bool = False,
    need_head_weights: bool = False,
    prune_attn_mask=None,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Args:
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.
        prune_attn_mask (Tensor, optional): mask of shape
            `(1, self.num_heads, 1024, 1024)`.
    """
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if (
        self.enable_torch_version
        and not self.onnx_trace
        and incremental_state is None
        and not static_kv
    ):
        assert key is not None and value is not None
        return F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training,
            key_padding_mask,
            need_weights,
            attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    q = (
        q.contiguous()
        .view(tgt_len, bsz * self.num_heads, self.head_dim)
        .transpose(0, 1)
    )
    if k is not None:
        k = (
            k.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
    if v is not None:
        v = (
            v.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)

    assert k is not None
    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                ],
                dim=1,
            )

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if prune_attn_mask is not None:
        if incremental_state is None:  # train
            prune_attn_mask = prune_attn_mask.to(torch.bool)[:, :, 0:tgt_len, 0:src_len]
        else:  # generation
            prune_attn_mask = prune_attn_mask.to(torch.bool)[:, :, src_len + 1:src_len + 2, 0:src_len]
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        # prune_attn_mask is 1 where we want to mask
        attn_weights = attn_weights.masked_fill(prune_attn_mask, -32768)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    if self.USE_ENTMAX:
        attn_weights_float = entmax15(attn_weights, dim=-1)
    else:
        attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(
        attn_weights_float.type_as(attn_weights),
        p=self.dropout,
        training=self.training,
    )
    # keep track of attention pattern for pruning experiments
    self.attn_probs = attn_probs.view(bsz, self.num_heads, tgt_len, src_len)

    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)

    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights
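# Hedged sketch (not from fairseq): what toggling USE_ENTMAX in the forward above
# changes, in isolation. entmax15 still returns a probability distribution per row, but
# with exact zeros, whereas softmax is dense everywhere.
def _demo_entmax_vs_softmax():
    attn_weights = torch.randn(2 * 4, 5, 7)   # (bsz * num_heads, tgt_len, src_len)
    dense = torch.softmax(attn_weights, dim=-1)
    sparse = entmax15(attn_weights, dim=-1)
    print('softmax zero fraction:', (dense == 0).float().mean().item(),
          'entmax15 zero fraction:', (sparse == 0).float().mean().item())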
def log_entmax15(*args, **kwargs):
    return torch.log(entmax15(*args, **kwargs))
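# Hedged usage sketch: `log_entmax15` can be used where log_softmax would be, e.g. with
# an NLL-style criterion, with the caveat that classes outside the entmax support get
# -inf log-probability.
def _demo_log_entmax15():
    logits = torch.randn(3, 10)
    log_probs = log_entmax15(logits, dim=-1)
    print(log_probs.shape, 'contains -inf:', torch.isinf(log_probs).any().item())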
def forward(
    self,
    query,
    key,
    value,
    key_padding_mask=None,
    incremental_state=None,
    need_weights=True,
    static_kv=False,
    attn_mask=None,
    before_softmax=False,
    need_head_weights=False,
):
    """Input shape: Time x Batch x Channel

    Args:
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.
    """
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            self.bias_k, self.bias_v, self.add_zero_attn, self.dropout,
            self.out_proj.weight, self.out_proj.bias, self.training,
            key_padding_mask, need_weights, attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight)

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if 'prev_key' in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat([
                key_padding_mask,
                key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
            ], dim=1)

    q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if 'prev_key' in saved_state:
            prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                k = torch.cat((prev_key, k), dim=1)
        if 'prev_value' in saved_state:
            prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                v = torch.cat((prev_value, v), dim=1)
        key_padding_mask = self._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=saved_state.get('prev_key_padding_mask', None),
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state['prev_key_padding_mask'] = key_padding_mask

        self._set_input_buffer(incremental_state, saved_state)

    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat([
                key_padding_mask,
                torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)
            ], dim=1)

    if not bmm_fp16_support:
        q = q.float()
        k = k.float()
        v = v.float()
    attn_weights = torch.bmm(q, k.transpose(1, 2))
    if not bmm_fp16_support:
        attn_weights = attn_weights.type_as(query)
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    # 1) Decide how many keys to keep per query for the top-k sparsification.
    if not self.cur_san_active:
        self.div = 0
    if self.div > 0:
        top_k = int(torch.ceil(torch.Tensor([src_len / self.div])))
        if top_k < self.lb:
            top_k = self.lb
        if top_k > src_len:
            top_k = src_len
    else:
        top_k = -self.div
        if top_k > src_len:
            top_k = src_len

    # 2) Normalize the attention weights with entmax or (top-k) softmax.
    # print('attn_weights ', attn_weights.size())
    if self.entmax:
        from entmax import sparsemax, entmax15, entmax_bisect
        # assign attn_weights_float in every branch so the code below can reuse it
        if self.entmax == 1:
            attn_weights_float = sparsemax(attn_weights.float(), dim=-1).type_as(attn_weights)
        elif self.entmax == 2:
            attn_weights_float = entmax15(attn_weights.float(), dim=-1).type_as(attn_weights)
        elif self.entmax == 3:
            attn_weights_float = entmax_bisect(attn_weights.float(), dim=-1).type_as(attn_weights)
    else:
        if self.div:
            vk, _ = torch.topk(attn_weights, top_k)
            # print(value)
            tk = vk[:, :, -1].unsqueeze(2).expand_as(attn_weights)
            mask_k = torch.lt(attn_weights, tk)
            attn_weights = attn_weights.masked_fill(mask_k, float('-inf')).type_as(attn_weights)
        attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                           p=self.dropout,
                           training=self.training)

    if not bmm_fp16_support:
        attn_probs = attn_probs.float()  # bsz * self.num_heads, tgt_len, src_len
    attn = torch.bmm(attn_probs, v)
    if not bmm_fp16_support:
        attn = attn.type_as(query)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)

    if need_weights:
        attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)
    else:
        attn_weights = None

    return attn, attn_weights
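# Hedged sketch (illustrative only): the top-k fallback branch of the forward above in
# isolation. Scores below the k-th largest per row are set to -inf so softmax spreads
# mass over at most k keys, a fixed-budget alternative to the adaptive sparsity of
# entmax/sparsemax.
def _demo_topk_softmax(top_k=3):
    attn_weights = torch.randn(8, 5, 12)                      # (bsz * heads, tgt_len, src_len)
    vk, _ = torch.topk(attn_weights, top_k)
    tk = vk[:, :, -1].unsqueeze(2).expand_as(attn_weights)    # k-th largest score per row
    mask_k = torch.lt(attn_weights, tk)
    probs = torch.softmax(attn_weights.masked_fill(mask_k, float('-inf')), dim=-1)
    print('max nonzeros per row:', (probs > 0).sum(-1).max().item())
    return probs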
def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature,
            enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
            coverage, step):
    if not self.training and step == 0:
        h_decoder, c_decoder = s_t_1
        s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
                             c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
        c_t, _, coverage_next = self.attention_network(
            s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage)
        coverage = coverage_next

    y_t_1_embd = self.embedding(y_t_1)
    x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
    lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)

    h_decoder, c_decoder = s_t
    s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
                         c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
    c_t, attn_dist, coverage_next = self.attention_network(
        s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage)

    if self.training or step > 0:
        coverage = coverage_next

    p_gen = None
    if config.pointer_gen:
        p_gen_input = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)
        p_gen = self.p_gen_linear(p_gen_input)
        p_gen = torch.sigmoid(p_gen)

    output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1)  # B x hidden_dim * 3
    output = self.out1(output)  # B x hidden_dim
    output = self.out2(output)  # B x vocab_size

    T = config.temperature
    if config.tsallis_alpha == 1:
        vocab_dist = F.softmax(output / T, dim=1)
    elif config.tsallis_alpha == 1.5:
        vocab_dist = entmax15(output / T, dim=1)
    elif config.tsallis_alpha == 2:
        vocab_dist = sparsemax(output / T, dim=1)
    # debug('vocab_dist', vocab_dist.size())

    if config.DEBUG and config.REC_ENTROPY:
        def get_entropy(t):
            return -torch.sum(torch.log(t) * t, dim=1)
        # vocab_dist_entropy = get_entropy(vocab_dist + config.eps)
        # debug('vocab_dist_entropy', vocab_dist_entropy)
        with open(os.path.join(config.log_root, 'vocab_dist_entropy/last_run.csv'), 'a') as f:
            f.write(','.join([str(i) for i in vocab_dist.cpu().detach().numpy()]) + '\n')
            # if step == 0:
            #     f.write(str(self.batch_cnt) + '\n')
            #     self.batch_cnt += 1
            # f.write(','.join([str(i) for i in vocab_dist_entropy.cpu().detach().numpy().round(4)]) + '\n')
            # if config.entmax_select:
            #     f.write(','.join([str(i) for i in p_soft.cpu().detach().numpy().round(4)]) + '\n')

    if config.pointer_gen:
        vocab_dist_ = p_gen * vocab_dist
        attn_dist_ = (1 - p_gen) * attn_dist
        if extra_zeros is not None:
            vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)
        final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
        # debug("extra_zeros", extra_zeros)
        # debug("vocab_dist_", vocab_dist_)
        # debug("attn_dist", attn_dist)
        # debug("enc_batch_extend_vocab", enc_batch_extend_vocab)
        # debug("final_dist", final_dist.size())
    else:
        final_dist = vocab_dist

    tau = None
    if config.adaptive_sparsemax:
        eps = torch.DoubleTensor([config.eps]).cuda(0)
        activation = torch.sigmoid
        tau = (1 - eps) * activation(self.p_sparse_linear(torch.cat((c_t, s_t_hat, x), 1)))
        debug('tau + eps', tau + eps)
        final_dist = sparsemax(final_dist / (tau + eps), dim=-1)
    elif config.use_top_p:
        final_dist = top_p(final_dist, config.top_p)
    # with open('tau.txt', 'a') as f:
    #     f.write(','.join([str(i) for i in tau.cpu().detach().numpy().round(4)]) + '\n')
    # debug("top", final_dist.topk(10, -1))
    # debug("entropy", vocab_dist_entropy)

    return final_dist, s_t, c_t, attn_dist, p_gen, coverage, tau
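# Hedged sketch (illustrative only): the output-layer normalization choices above in
# isolation. config.tsallis_alpha selects softmax (1), entmax15 (1.5) or sparsemax (2);
# larger alpha and lower temperature T both shrink the support of the vocabulary
# distribution.
def _demo_output_normalizers(vocab_size=50, T=1.0):
    output = torch.randn(2, vocab_size)
    for alpha, fn in [(1.0, lambda o: F.softmax(o / T, dim=1)),
                      (1.5, lambda o: entmax15(o / T, dim=1)),
                      (2.0, lambda o: sparsemax(o / T, dim=1))]:
        vocab_dist = fn(output)
        print('alpha', alpha, 'support size:', (vocab_dist > 0).sum(1).float().mean().item())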