示例#1
0
文件: aligner.py 项目: stjordanis/TTS
 def forward(
         self,
         queries: torch.tensor,
         keys: torch.tensor,
         mask: torch.tensor = None,
         attn_prior: torch.tensor = None
 ) -> Tuple[torch.tensor, torch.tensor]:
     """Forward pass of the aligner encoder.

     Scores every (decoder frame, encoder token) pair by the negative squared
     L2 distance between the projected query and key vectors, scaled by
     ``self.temperature``.

     Shapes:
         - queries: :math:`[B, C, T_de]`
         - keys: :math:`[B, C_emb, T_en]`
         - mask: must broadcast to :math:`[B, 1, T_de, T_en]` after
           ``unsqueeze(2)``, e.g. :math:`[B, 1, T_en]` (TODO confirm against
           the caller). Positions where it is 0/False are set to ``-inf``.
         - attn_prior: presumably :math:`[B, T_de, T_en]`; it is unsqueezed at
           dim 1 and its log is added to the log-softmaxed scores.
     Output:
         attn (torch.tensor): :math:`[B, 1, T_de, T_en]` soft attention mask
             (``self.softmax`` applied to ``attn_logp``).
         attn_logp (torch.tensor): :math:`[B, 1, T_de, T_en]` scores. They are
             log probabilities only when ``attn_prior`` is given; otherwise
             they are the raw, un-normalized negative distances.
     """
     key_out = self.key_layer(keys)
     query_out = self.query_layer(queries)
     # [B, C, T_de, 1] - [B, C, 1, T_en] -> [B, C, T_de, T_en]
     attn_factor = (query_out[:, :, :, None] - key_out[:, :, None])**2
     attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)
     if attn_prior is not None:
         # Combine in log space: log_softmax(scores) + log(prior); the epsilon
         # avoids log(0) for zero-probability prior entries.
         attn_logp = self.log_softmax(attn_logp) + torch.log(
             attn_prior[:, None] + 1e-8)
     if mask is not None:
         # Fill through .data so the in-place write is not recorded by autograd.
         attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2),
                                     -float("inf"))
     attn = self.softmax(attn_logp)
     return attn, attn_logp
示例#2
0
    def _evaluate_multi_classification_with_rejection(self, h: torch.tensor,
                                                      t: torch.tensor,
                                                      r_binary: torch.tensor):
        """
        evaluate result of multi classification. 

        Args:
            h (B): prediction which indicates class index from 0 to #class-1
            t (B): labels which indicates true label form 0 to #class-1
            r_binary (B): labels which indicates 'accept:1' and 'reject:0'
        Return:
            OrderedDict: 'acc'/'raw acc'
        """
        assert h.size(0) == t.size(0) == r_binary.size(0) > 0
        assert len(h.size()) == len(t.size()) == len(r_binary.size()) == 1

        # raw accuracy
        eval_dict = self._evaluate_multi_classification(h, t)
        eval_dict['raw accuracy'] = eval_dict['accuracy']
        del eval_dict['accuracy']

        h_rjc = torch.masked_select(h, r_binary.bool())
        t_rjc = torch.masked_select(t, r_binary.bool())

        t = float(
            torch.where(h_rjc == t_rjc, torch.ones_like(h_rjc),
                        torch.zeros_like(h_rjc)).sum())
        f = float(
            torch.where(h_rjc != t_rjc, torch.ones_like(h_rjc),
                        torch.zeros_like(h_rjc)).sum())
        # accuracy
        acc = float(t / (t + f + 1e-12))
        eval_dict['accuracy'] = acc

        return eval_dict
示例#3
0
    def step(self, word: torch.tensor, last_hidden_state: tuple,
             domains_embedding: torch.tensor, domains_mask: torch.tensor):
        '''
        Perform one step of the domain-attentive LSTM cell.

        @param word of shape (batch_size, word_embedding_dim): w_t in paper

        @param last_hidden_state: None (no previous state; a zero hidden vector is used for attention) or a tuple ((batch_size, hidden_size), (batch_size, hidden_size)) holding h_{t-1} and c_{t-1}

        @param domains_embedding of shape (batch_size, max_domains_num, domain_embedding_dim): input question domains embedding, u in paper

        @param domains_mask of shape (batch_size, max_domains_num): truthy entries (padding) receive zero attention weight

        @return whatever self.attention_LSTM returns for (x, last_hidden_state); presumably (h_t, c_t) for an nn.LSTMCell — TODO confirm
        '''
        batch_size = word.size(0)
        # (batch, word_dim) -> (batch, 1, word_dim), broadcast over every domain slot.
        word = torch.unsqueeze(word, dim=1)
        expand_word = word.expand(batch_size, self.max_domains_num,
                                  word.size(2))
        if last_hidden_state is None:
            # Allocate directly on the target device instead of CPU + copy.
            expanded_hidden_state = torch.zeros(
                (batch_size, self.max_domains_num, self.hidden_size),
                dtype=torch.float, device=self.device)
        else:
            prev_hidden = last_hidden_state[0]
            expanded_hidden_state = prev_hidden.unsqueeze(1).expand(
                prev_hidden.size(0), self.max_domains_num, prev_hidden.size(1))
        attention_mat = torch.cat(
            [domains_embedding, expand_word, expanded_hidden_state], dim=2)

        w_proj = torch.matmul(attention_mat, self.W_ac)
        v_proj = torch.matmul(torch.tanh(w_proj), self.V_ac).squeeze(2)

        # Masked (padding) domains get -inf so softmax assigns them zero weight.
        v_proj.data.masked_fill_(domains_mask.bool(), float('-inf'))

        # shape (batch_size, max_domains_num)
        alpha = nn.functional.softmax(v_proj, dim=1)

        # shape (batch_size, domain_embedding_dim)
        u_t = torch.bmm(alpha.unsqueeze(1), domains_embedding).squeeze(1)
        x = torch.cat([word.squeeze(1), u_t], dim=1)

        # Call the module itself (not .forward) so registered hooks still run.
        hidden_state = self.attention_LSTM(x, last_hidden_state)

        return hidden_state
示例#4
0
    def _calculate_loss(
        self,
        scores: torch.tensor,
        sentences: List[Sentence],
        mask: torch.tensor,
        return_arc_rel=False,
    ) -> float:
        """Masked mean of ``self.criterion`` against per-token system-prediction targets.

        Args:
            scores: model scores passed as the first argument to ``self.criterion``.
            sentences: the batch. Targets are read from its
                ``<tag_type>_system_preds`` attribute if present, otherwise
                stacked from that attribute on every sentence.
            mask: token mask; only masked-in positions contribute to the loss.
            return_arc_rel: unused; kept for signature compatibility.

        Returns:
            The scalar loss tensor (the ``float`` annotation is historical):
            the sum of the masked element-wise criterion values divided by the
            number of masked-in tokens.

        Raises:
            NotImplementedError: if ``self.binary`` is set; this variant has no
                binary-mode implementation.
        """
        if self.binary:
            raise NotImplementedError(
                '_calculate_loss does not support binary mode')

        attr = self.tag_type + '_system_preds'
        # The system preds represent whether each tag is correct. A batched
        # container may already carry them; otherwise stack per-sentence tensors.
        if hasattr(sentences, attr):
            system_preds = getattr(sentences, attr).to(flair.device).long()
        else:
            system_preds = torch.stack(
                [getattr(sentence, attr).to(flair.device)
                 for sentence in sentences], 0).long()

        mask = mask.bool()
        # Criterion is presumably element-wise (reduction='none'); zero out
        # padded positions before averaging over real tokens.
        loss = self.criterion(scores,
                              system_preds.float()) * mask.unsqueeze(-1)
        loss = loss.sum() / mask.sum()
        return loss
	def _calculate_loss(
		self, arc_scores: torch.tensor, rel_scores: torch.tensor, sentences: List[Sentence], mask: torch.tensor, return_arc_rel = False,
	) -> float:
		"""Compute the dependency-parsing loss from arc and relation scores.

		Two modes, selected by ``self.binary``:

		* binary: every (i, j) pair is scored independently. ``self.arc_criterion``
		  is averaged over masked-in pairs whose first index i is not 0 (the
		  root). ``self.rel_criterion`` is averaged over those pairs whose gold
		  relation id is > 0.
		* otherwise: relation scores are picked at each token's gold head. The
		  arc loss comes from the module-level ``crf(...)`` helper when
		  ``self.use_crf`` is set, else from ``self.arc_criterion`` over the
		  masked-in tokens.

		Gold tags are read from ``<tag_type>_arc_tags`` / ``<tag_type>_rel_tags``,
		either on the batch object itself or stacked from every sentence.

		Args:
			arc_scores: arc scores (presumably [B, T, T] - TODO confirm).
			rel_scores: relation scores; reshaped to (-1, self.tagset_size) in
				binary mode.
			sentences: the batch.
			mask: token mask; position 0 is treated as the root.
			return_arc_rel: if True, return ``(arc_loss, rel_loss)`` instead of
				the combined loss.

		Returns:
			``2 * ((1 - self.interpolation) * arc_loss + self.interpolation * rel_loss)``,
			or ``(arc_loss, rel_loss)`` when ``return_arc_rel`` is True. In binary
			mode ``rel_loss`` is the int 0 when the batch has no gold relations.

		Side effects:
			In the non-binary branch ``mask[:,0] = 0`` writes into the caller's
			tensor (no clone, unlike the binary branch). ``self.arcs`` and
			``self.rels`` are set to the gold tensors.
		"""
		if self.binary:
			# Exclude pairs whose first index is the root (position 0); pairs
			# touching a masked-out token are excluded through `mask`.
			root_mask = mask.clone()
			root_mask[:,0] = 0
			binary_mask = root_mask.unsqueeze(-1) * mask.unsqueeze(-2)
			# Gold tags live either on the batch object or on each sentence.
			if hasattr(sentences,self.tag_type+'_arc_tags'):
				arc_mat=getattr(sentences,self.tag_type+'_arc_tags').to(flair.device).float()
			else:
				arc_mat=torch.stack([getattr(sentence,self.tag_type+'_arc_tags').to(flair.device) for sentence in sentences],0).float()
			if hasattr(sentences,self.tag_type+'_rel_tags'):
				rel_mat=getattr(sentences,self.tag_type+'_rel_tags').to(flair.device).long()
			else:
				rel_mat=torch.stack([getattr(sentence,self.tag_type+'_rel_tags').to(flair.device) for sentence in sentences],0).long()

			# Both criteria are presumably element-wise (reduction='none' - TODO
			# confirm); the masked averaging happens below.
			arc_loss = self.arc_criterion(arc_scores, arc_mat)
			rel_loss = self.rel_criterion(rel_scores.reshape(-1,self.tagset_size), rel_mat.reshape(-1))
			arc_loss = (arc_loss*binary_mask).sum()/binary_mask.sum()

			# Relation loss only counts pairs that carry a gold relation (id > 0).
			rel_mask = (rel_mat>0)*binary_mask
			num_rels=rel_mask.sum()
			if num_rels>0:
				rel_loss = (rel_loss*rel_mask.view(-1)).sum()/num_rels
			else:
				# No gold relations in this batch: rel_loss becomes the int 0.
				rel_loss = 0
		else:
			# Gold heads (arcs) and relation ids, from the batch or per sentence.
			if hasattr(sentences,self.tag_type+'_arc_tags'):
				arcs=getattr(sentences,self.tag_type+'_arc_tags').to(flair.device).long()
			else:
				arcs=torch.stack([getattr(sentence,self.tag_type+'_arc_tags').to(flair.device) for sentence in sentences],0).long()
			if hasattr(sentences,self.tag_type+'_rel_tags'):
				rels=getattr(sentences,self.tag_type+'_rel_tags').to(flair.device).long()
			else:
				rels=torch.stack([getattr(sentence,self.tag_type+'_rel_tags').to(flair.device) for sentence in sentences],0).long()
			self.arcs=arcs
			self.rels=rels
			# NOTE(review): in-place write into the caller's mask (no clone, unlike
			# the binary branch), so the caller sees the root column zeroed.
			# Callers may rely on this; do not "fix" without checking them.
			mask[:,0] = 0
			mask = mask.bool()
			gold_arcs = arcs[mask]
			# For each real token keep only the relation scores at its gold head.
			rel_scores, rels = rel_scores[mask], rels[mask]
			rel_scores = rel_scores[torch.arange(len(gold_arcs)), gold_arcs]
			if self.use_crf:
				# Arc loss from the crf(...) helper (defined elsewhere), normalized
				# by the number of non-root tokens.
				arc_loss, arc_probs = crf(arc_scores, mask, arcs)
				arc_loss = arc_loss/mask.sum()
				rel_loss = self.rel_criterion(rel_scores, rels)
			else:
				# Head selection treated as per-token classification over the
				# candidate heads (presumably the last dim of arc_scores).
				arc_scores, arcs = arc_scores[mask], arcs[mask]
				arc_loss = self.arc_criterion(arc_scores, arcs)

				rel_loss = self.rel_criterion(rel_scores, rels)
		if return_arc_rel:
			return (arc_loss,rel_loss)
		# Interpolated sum; the factor 2 makes interpolation=0.5 equal arc_loss + rel_loss.
		loss = 2 * ((1-self.interpolation) * arc_loss + self.interpolation * rel_loss)

		return loss
示例#6
0
def mask_input_embeddings(input_embeddings: torch.tensor,
                          special_embeddings_mask: torch.tensor,
                          device,
                          sentence_mask_probability = 0.15):
  """
  Randomly mask real sentence embeddings for masked-sentence training.

  Every real (non-special) sentence is masked independently with probability
  ``sentence_mask_probability`` (15% by default). A masked sentence embedding is
  replaced with a standard-normal random tensor, and its original embedding is
  copied into ``label_embeddings`` at the same position. That way the ground
  truth for the loss sits exactly where the masked input is.

  If the Bernoulli draw masks nothing in the whole batch, one random real
  sentence per document is masked instead. Without that, the loss would have
  no targets. Documents with no real sentence are skipped by this fallback.

  Args:
    input_embeddings: sentence embeddings of shape
      [batch_size, max_doc_length, embedding_size]. Documents are already
      padded to the length of the longest document in the batch.
    special_embeddings_mask: tensor of shape [batch_size, max_doc_length]
      holding 0 where there is a real sentence and 1 where there is a special
      token embedding (CLS, SEP, PAD). Special positions are never masked.
    device: device on which the masking probabilities are drawn.
    sentence_mask_probability: probability of masking each real sentence.

  Returns:
    A 4-tuple ``(input_embeddings, masked_input_embeddings, label_embeddings,
    label_mask)``:
      input_embeddings: the input tensor itself, unmodified.
      masked_input_embeddings: copy of the input holding a random tensor
        wherever a sentence embedding was masked.
      label_embeddings: same shape as the input, all zeros except at masked
        positions, which hold the original input embedding.
      label_mask: torch.BoolTensor [batch_size, max_doc_length], True at
        masked positions (on the device of special_embeddings_mask).
  """
  probability_matrix = torch.full(special_embeddings_mask.shape,
                                  sentence_mask_probability, device=device)
  probability_matrix.masked_fill_(special_embeddings_mask.bool(), value=0.0)
  masked_indices = torch.bernoulli(probability_matrix).bool()

  # Nothing masked at all would leave the loss function without targets, so
  # force one random real sentence per document. Skip documents with no real
  # sentence: random.choice raises IndexError on an empty list.
  if not masked_indices.any():
    for document, doc_masked in zip(special_embeddings_mask, masked_indices):
      real_indexes = [i for i, x in enumerate(document.tolist()) if x == 0]
      if real_indexes:
        doc_masked[random.choice(real_indexes)] = True

  # Vectorized replacement of the masked positions (boolean-mask indexing
  # selects a [num_masked, embedding_size] slice).
  positions = masked_indices.to(input_embeddings.device)
  masked_input_embeddings = input_embeddings.clone()
  label_embeddings = torch.zeros_like(input_embeddings)
  label_embeddings[positions] = input_embeddings[positions]
  masked_input_embeddings[positions] = torch.randn_like(input_embeddings[positions])
  label_mask = masked_indices.to(special_embeddings_mask.device)

  return (input_embeddings, masked_input_embeddings, label_embeddings, label_mask)
示例#7
0
    def step(self, Ybar_t: torch.tensor,
             dec_state: Tuple[torch.tensor, torch.tensor],
             enc_hiddens: torch.tensor,
             enc_hiddens_proj: torch.tensor,
             enc_masks: torch.tensor) -> Tuple[Tuple, torch.tensor, torch.tensor]:
        """Run one decoder time step: LSTM update, attention, and output projection.

        Args:
            Ybar_t: decoder input ``[Y_t; o_prev]`` of shape (b, e + h_e), where
                b = batch size and e = embedding size.
            dec_state: previous ``(hidden, cell)`` of the decoder, each of shape
                (b, h_d) with h_d = hidden_size_dec.
            enc_hiddens: encoder hidden states, shape (b, src_len, h_e * 2).
            enc_hiddens_proj: encoder hidden states projected from h_e * 2 to h,
                shape (b, src_len, h).
            enc_masks: sentence masks of shape (b, src_len) whose truthy entries
                are excluded from attention, or ``None`` to attend everywhere.

        Returns:
            dec_state: the new ``(hidden, cell)`` exactly as returned by
                ``self.decoder``.
            combined_output: ``dropout(tanh(projection([hidden; a_t])))`` of
                shape (b, h), where a_t is the attention-weighted sum of
                ``enc_hiddens``.
            e_t: attention scores of shape (b, src_len), with ``-inf`` at
                masked positions. Returned for sanity checks only.
        """
        # 1) Advance the LSTM cell.
        new_state = self.decoder(Ybar_t, dec_state)
        hidden, _cell = new_state

        # 2) Score every source position against the new hidden state.
        scores = self.attention_function(hidden, enc_hiddens_proj)
        if enc_masks is not None:
            # Masked (padding) positions must get zero attention weight; the
            # fill goes through .data, outside autograd's view.
            scores.data.masked_fill_(enc_masks.bool(), -float('inf'))

        # 3) Attention weights -> context vector a_t.
        weights = torch.nn.functional.softmax(scores, 1).unsqueeze(1)  # (b, 1, src_len)
        context = torch.bmm(weights, enc_hiddens).squeeze(1)           # (b, 2h)

        # 4) Combine hidden state and context, project, squash, drop out.
        joined = torch.cat((hidden, context), 1)                        # (b, 3h)
        projected = self.combined_output_projection(joined)             # (b, h)
        combined_output = self.dropout(torch.tanh(projected))

        return new_state, combined_output, scores