Example #1
    def get_distillation_loss(self,
                              target_label,
                              target_lengths,
                              old_indices,
                              ref_output,
                              new_task_num,
                              distill_New=True,
                              sv_ratio=1,
                              do_label=None,
                              da_label=None,
                              sv_label=None):
        """
        This func calculate  distillation for all utterences in a batch
        @param target_labels: ground truth words (batch_size, max_sent_len, vocab_size)
        @param target_lengths: list of sentence lengths (va)
        @param old_indices_mask: a mask to indicate whether voc in a sentence is in old_indices, we only ditill them
        @param old_indices: voc indices that are not in this domain, we only compute kl-div w.r.t. distribution on them
        @param word_dists: unnormalized word output proba distribution, shape = (num_distillation, max_sent_len, vocab_size)
        @param new_task_num: number of sample from new task
        @param distill_New: whether to distill new utterances in the current domain
        @return distillation loss
        """

        length = Variable(torch.LongTensor(target_lengths)).cuda()
        batch_size = len(target_lengths)

        # Compute cross entropy loss
        if distill_New:
            loss_ce = masked_cross_entropy(
                self.output_prob.contiguous(),  # -> batch x seq
                target_label.contiguous(),  # -> batch x seq
                length)
        else:
            loss_ce = masked_cross_entropy(
                self.output_prob[:new_task_num].contiguous(),  # -> batch x seq
                target_label[:new_task_num].contiguous(),  # -> batch x seq
                length[:new_task_num])

        if distill_New:
            # Distill all sentences in the batch
            distillation_loss = masked_kl_divergence(self.output_prob[:, :, old_indices].contiguous(), \
                                                     ref_output[:, :, old_indices].contiguous(), \
                                                     length, self.args.T)
        else:
            # Distill exemplars only
            if new_task_num < batch_size:
                distillation_loss = masked_kl_divergence(self.output_prob[new_task_num:].contiguous(), \
                                                         ref_output[new_task_num:].contiguous(), \
                                                         length[new_task_num:], self.args.T)
            else:
                distillation_loss = torch.tensor(0)

        if self.args.adaptive:
            self.loss = loss_ce + self.args._lambda * sv_ratio * distillation_loss
        else:
            self.loss = loss_ce + self.args._lambda * distillation_loss

        return self.loss
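
The masked_kl_divergence used above is defined elsewhere in the code base. Below is a minimal sketch of a temperature-scaled, length-masked KL term matching the call sites in these examples; the function body, the clamp epsilon and the masking convention are assumptions for illustration, not the project's actual implementation.

import torch
import torch.nn.functional as F

def masked_kl_divergence(logits, ref_logits, lengths, T=1.0):
    # Assumed sketch, not the original implementation.
    # logits, ref_logits: (batch, max_len, vocab) unnormalized scores of the
    # current (student) model and the frozen reference (teacher) model.
    # lengths: LongTensor of true sentence lengths; T: softmax temperature.
    student_log_p = F.log_softmax(logits / T, dim=-1)
    teacher_p = F.softmax(ref_logits / T, dim=-1)

    # Per-position KL(teacher || student), summed over the vocabulary axis.
    kl = (teacher_p * (teacher_p.clamp_min(1e-10).log() - student_log_p)).sum(-1)

    # Zero out padded positions and average over the real tokens.
    positions = torch.arange(logits.size(1), device=logits.device).unsqueeze(0)
    mask = (positions < lengths.to(logits.device).unsqueeze(1)).float()
    return (kl * mask).sum() / mask.sum()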
Example #2
    def get_loss(self,
                 target_label,
                 target_lengths,
                 do_label=None,
                 da_label=None,
                 sv_label=None,
                 return_vec=False):
        """
        Compute loss = cross_entropy_loss + kl_annealing loss + do,da,sv loss \n
        @param target_label(torch.Tensor) \n
        @param target_lengths(list): list of length of target sentence \n
        @param do_label(torch.Tensor): domain labels \n
        @param da_label(torch.Tensor): dialogue act labels \n
        @param sv_label(torch.Tensor): slot value labels \n
        @param return_vec(boolean): whether return per word loss in cross entropy loss
        """

        length = Variable(torch.LongTensor(target_lengths)).cuda()
        self.loss = masked_cross_entropy(
            self.output_prob.contiguous(),  # -> batch x seq
            target_label.contiguous(),  # -> batch x seq
            length,
            return_vec=return_vec)

        return self.loss
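
masked_cross_entropy is likewise defined outside these snippets. The sketch below shows one way a length-masked cross entropy matching these call sites could be written; the handling of list-typed lengths and the return_vec convention are assumptions for illustration.

import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, targets, lengths, return_vec=False):
    # Assumed sketch, not the original implementation.
    # logits:  (batch, max_len, vocab) unnormalized scores
    # targets: (batch, max_len) gold token ids
    # lengths: list or LongTensor of true lengths; later positions are padding
    if not torch.is_tensor(lengths):
        lengths = torch.as_tensor(lengths)
    lengths = lengths.to(logits.device)

    batch_size, max_len, vocab_size = logits.size()

    # Per-token negative log-likelihood, no reduction yet.
    nll = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1),
                          reduction='none').view(batch_size, max_len)

    # Zero out the padded positions.
    positions = torch.arange(max_len, device=logits.device).unsqueeze(0)
    mask = (positions < lengths.unsqueeze(1)).float()
    nll = nll * mask

    if return_vec:
        # Per-word losses with padding zeroed; the exact return_vec
        # convention in the original code base is an assumption here.
        return nll
    # Scalar loss averaged over the real (unpadded) tokens.
    return nll.sum() / mask.sum()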
Example #3
File: main.py  Project: lujiaying/tqa
def train(input_batches, input_lengths, target_batches, target_lengths,
          encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    batch_size = input_batches.size(1)

    # Zero gradients of both optimizers
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss = 0  # Added onto for each word

    # Run words through encoder
    encoder_outputs, encoder_hidden = encoder(input_batches, input_lengths,
                                              None)

    # Prepare input and output variables
    decoder_input = Variable(torch.LongTensor([SOS_ID] * batch_size))
    decoder_hidden = encoder_hidden[:decoder.n_layers]  # Use last (forward) hidden state from encoder

    max_target_length = max(target_lengths)
    all_decoder_outputs = Variable(
        torch.zeros(max_target_length, batch_size, decoder.output_size))

    # Move new Variables to CUDA
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        all_decoder_outputs = all_decoder_outputs.cuda()

    # Run through decoder one time step at a time
    for t in range(max_target_length):
        decoder_output, decoder_hidden, decoder_attn = decoder(
            decoder_input, decoder_hidden, encoder_outputs)

        all_decoder_outputs[t] = decoder_output
        decoder_input = target_batches[t]  # Next input is current target

    # Loss calculation and backpropagation
    loss = masked_cross_entropy(
        all_decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
        target_batches.transpose(0, 1).contiguous(),  # -> batch x seq
        target_lengths)
    loss.backward()

    # Clip gradient norms
    ec = torch.nn.utils.clip_grad_norm(encoder.parameters(), clip)
    dc = torch.nn.utils.clip_grad_norm(decoder.parameters(), clip)

    # Update parameters with optimizers
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.data[0], ec, dc
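
As a usage sketch, a driver loop could call train once per batch roughly as follows; random_batch, n_epochs, print_every, batch_size and the encoder/decoder setup are hypothetical stand-ins, not names from the original project.

# Hypothetical driver loop around train(); every name below except train
# itself is a placeholder for the project's actual batching and setup code.
for epoch in range(1, n_epochs + 1):
    input_batches, input_lengths, target_batches, target_lengths = random_batch(batch_size)

    loss, ec, dc = train(input_batches, input_lengths,
                         target_batches, target_lengths,
                         encoder, decoder,
                         encoder_optimizer, decoder_optimizer, criterion)

    if epoch % print_every == 0:
        print('epoch %d  loss %.4f  |enc grad| %.2f  |dec grad| %.2f' % (epoch, loss, ec, dc))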
Example #4
    def _diag_fisher(self):

        self.model.train()
        precision_matrices = {}
        for n, p in copy.deepcopy(self.params).items():
            p.data.zero_()
            precision_matrices[n] = Variable(p.data)

        self.dataset.batch_size = 1  # set batch_size to 1 in ewc

        for i in range(len(self.dataset.data['train'])):
            self.model.zero_grad()
            input_var, label_var, feats_var, lengths, refs, featStrs, sv_indexes, _, do_label, da_label, sv_label = self.dataset.next_batch(
                "train")

            # feedforward and calculate loss
            if self.model.model_type == "lm":
                decoded_words, _ = self.model(input_var, self.dataset,
                                              feats_var)
            else:
                self.model.set_prior(False)
                target_var = input_var.clone()
                decoded_words, _ = self.model(input_var,
                                              input_lengths=lengths,
                                              target_seq=target_var,
                                              target_lengths=lengths,
                                              conds_seq=feats_var,
                                              dataset=self.dataset)

            length = Variable(torch.LongTensor(lengths)).cuda()

            # empirical Fisher if we provide ground truth label
            loss = masked_cross_entropy(
                self.model.output_prob.contiguous(),  # -> batch x seq
                label_var.contiguous(),  # -> batch x seq
                length)

            loss.backward()

            for n, p in self.model.named_parameters():

                # Skip layers that are not trained
                if p.grad is None:
                    continue
                precision_matrices[n].data += p.grad.data**2 / len(
                    self.dataset.data['train'])

        precision_matrices = {n: p for n, p in precision_matrices.items()}
        return precision_matrices
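
The diagonal Fisher returned by _diag_fisher is normally consumed by an EWC-style quadratic penalty on the next task. A minimal sketch of that penalty is given below; old_params (the parameters saved right after the previous task) and the importance weight are assumptions, not attributes shown in the original code.

import torch

def ewc_penalty(model, precision_matrices, old_params, importance=1.0):
    # Quadratic EWC penalty: importance * sum_i F_i * (theta_i - theta*_i)^2,
    # where F_i is the diagonal Fisher estimate from _diag_fisher and theta*
    # are the parameters stored after training on the previous task.
    penalty = 0
    for n, p in model.named_parameters():
        if n not in precision_matrices:
            continue
        penalty = penalty + (precision_matrices[n] * (p - old_params[n]) ** 2).sum()
    return importance * penalty

The regularized objective on the new task would then be the task loss plus this penalty.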
Example #5
    def get_loss(self,
                 target_label,
                 target_lengths,
                 do_label,
                 da_label,
                 sv_label,
                 full_kl_step=5000,
                 return_vec=False):
        """
		Compute loss = cross_entropy_loss + kl_annealing loss + do,da,sv loss \n
		@param target_label(torch.Tensor) \n
		@param target_lengths(list): list of length of target sentence \n
		@param do_label(torch.Tensor): domain labels \n
		@param da_label(torch.Tensor): dialogue act labels \n
		@param sv_label(torch.Tensor): slot value labels \n
		@param full_kl_step(int): kl annealing parameter \n
		@param return_vec(boolean): whether return per word loss in cross entropy loss 
		"""

        length = Variable(torch.LongTensor(target_lengths)).cuda()

        # Compute cross entropy loss
        rc_loss = masked_cross_entropy(
            self.output_prob.contiguous(),  # -> batch x seq
            target_label.contiguous(),  # -> batch x seq
            length,
            return_vec=return_vec)

        # Compute kl annealing loss
        kl_weight = min(self.global_t / full_kl_step, 1.0)
        kl_loss = torch.mean(self.gaussian_kld())

        # Compute domain, dialogue act, slot value loss
        do_loss = self.criterion['xent'](self.do_output, do_label)
        da_loss = self.criterion['multilabel'](self.da_output, da_label)
        sv_loss = self.criterion['multilabel'](self.sv_output, sv_label)

        self.loss = rc_loss + kl_weight * kl_loss + do_loss + da_loss + sv_loss

        return self.loss
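
get_loss relies on self.gaussian_kld(), which is not included in these examples; for a conditional-VAE style generator it is typically the KL divergence between the recognition (posterior) Gaussian and the prior Gaussian. The sketch below assumes that reading, and the mu/logvar argument names are invented for illustration.

import torch

def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    # Assumed sketch: KL( N(recog_mu, recog_var) || N(prior_mu, prior_var) )
    # for diagonal Gaussians, summed over the latent dimensions. It returns
    # one value per batch element so the caller can take torch.mean(...).
    kld = 0.5 * torch.sum(
        prior_logvar - recog_logvar
        + (recog_logvar.exp() + (recog_mu - prior_mu) ** 2) / prior_logvar.exp()
        - 1.0,
        dim=-1)
    return kld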
Example #6
    def get_loss(self, target_label, target_lengths):
        self.loss = masked_cross_entropy(
            self.output_prob.contiguous(),  # -> batch x seq
            target_label.contiguous(),  # -> batch x seq
            target_lengths)
        return self.loss
Example #7
    def get_distillation_loss(self,
                              target_label,
                              target_lengths,
                              old_indices_mask,
                              old_indices,
                              ref_output,
                              new_task_num,
                              distill_New=True,
                              sv_ratio=1,
                              do_label=None,
                              da_label=None,
                              sv_label=None,
                              full_kl_step=5000):
        """
		This func calculate  distillation for all utterences in a batch \n
		@param target_labels: ground truth words (batch_size, max_sent_len, vocab_size) \n
		@param target_lengths: list of sentence lengths (va) \n
		@param old_indices_mask: a mask to indicate whether voc in a sentence is in old_indices, we only ditill them \n
		@param old_indices: voc indices that are not in this domain, we only compute kl-div w.r.t. distribution on them \n
		@param word_dists: unnormalized word output proba distribution, shape = (num_distillation, max_sent_len, vocab_size) \n
		@param new_task_num: number of sample from new task \n
		@param distill_New: whether to distill new utterances in the current domain \n
		@return distillation loss 
		"""

        length = Variable(torch.LongTensor(target_lengths)).cuda()
        batch_size = len(target_lengths)

        # Compute cross entropy loss
        if distill_New:
            loss_ce = masked_cross_entropy(
                self.output_prob.contiguous(),  # -> batch x seq
                target_label.contiguous(),  # -> batch x seq
                length)
        else:
            loss_ce = masked_cross_entropy(
                self.output_prob[:new_task_num].contiguous(),  # -> batch x seq
                target_label[:new_task_num].contiguous(),  # -> batch x seq
                length[:new_task_num])

        # Compute kl annealing loss
        kl_weight = min(self.global_t / full_kl_step, 1.0)
        kl_loss = torch.mean(self.gaussian_kld())

        # Compute domain, dialogue act and slot value loss
        do_loss = self.criterion['xent'](self.do_output, do_label)
        da_loss = self.criterion['multilabel'](self.da_output, da_label)
        sv_loss = self.criterion['multilabel'](self.sv_output, sv_label)

        if distill_New:
            # Distill all sentences
            distillation_loss = masked_kl_divergence(
                self.output_prob[:, :, old_indices].contiguous(),
                ref_output[:, :, old_indices].contiguous(),
                length, self.args.T)
        else:
            # Distill exemplars only
            if new_task_num < batch_size:
                distillation_loss = masked_kl_divergence(
                    self.output_prob[new_task_num:].contiguous(),
                    ref_output[new_task_num:].contiguous(),
                    length[new_task_num:], self.args.T)
            else:
                distillation_loss = torch.tensor(0)

        # Calculate loss in adaptive way or not
        if self.args.adaptive:
            self.loss = loss_ce + self.args._lambda * sv_ratio * distillation_loss + kl_weight * kl_loss + do_loss + da_loss + sv_loss
        else:
            self.loss = loss_ce + self.args._lambda * distillation_loss + kl_weight * kl_loss + do_loss + da_loss + sv_loss

        return self.loss
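
As a usage sketch, a continual-learning step that mixes new-task utterances with replayed exemplars could combine a frozen reference model with get_distillation_loss roughly as follows, mirroring the forward-call pattern of Example #4. Every name here (ref_model, optimizer, the batch layout with new-task samples first) is an assumption for illustration.

# Hypothetical training step: the first new_task_num rows of the batch come
# from the current domain, the remaining rows are exemplars from old domains.
optimizer.zero_grad()

with torch.no_grad():
    ref_model(input_var, input_lengths=lengths, target_seq=target_var,
              target_lengths=lengths, conds_seq=feats_var, dataset=dataset)
    ref_output = ref_model.output_prob.detach()

model(input_var, input_lengths=lengths, target_seq=target_var,
      target_lengths=lengths, conds_seq=feats_var, dataset=dataset)

loss = model.get_distillation_loss(label_var, lengths,
                                   old_indices_mask, old_indices,
                                   ref_output, new_task_num,
                                   distill_New=False,
                                   do_label=do_label, da_label=da_label,
                                   sv_label=sv_label)
loss.backward()
optimizer.step()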