Example 1
    def __call__(self, outputs, output_symbols, targets):
        '''
        Inputs:
            outputs: (seq_len, batch_size, label_size) per-step log-probabilities
            output_symbols: (seq_len, batch_size) indices of the output symbols
                (sampled from the policy)
            targets: (batch_size, label_size) multi-hot target labels
        '''

        outputs = torch.stack(outputs)
        targets = targets.to(outputs.device)

        output_symbols = torch.stack(output_symbols).squeeze(2)
        seq_len, batch_size, label_size = outputs.shape

        target_time = torch.zeros((seq_len, batch_size),
                                  dtype=torch.long,
                                  device=outputs.device)
        mask = torch.ones((seq_len, batch_size),
                          dtype=torch.float32,
                          device=outputs.device)

        for i in range(seq_len):
            # greedily pick the remaining target with the highest probability
            target_time[i] = (torch.exp(outputs[i]) * targets).topk(
                1, dim=1)[1].squeeze(1)
            targets = targets - utils.to_one_hot(target_time[i], label_size)

            # check whether all targets have been successfully predicted
            is_end_batch = torch.sum(targets, dim=1).eq(0)
            targets[:, self.eos_id] += is_end_batch.float()

            # mask out steps after EOS has been emitted
            if i > 0:
                eos_batches = target_time[i - 1, :].data.eq(self.eos_id)
                eos_batches = eos_batches.float()
                mask[i, :] = (1 - eos_batches) * mask[i - 1, :]
        # criterion expects inputs of shape (batch_size, label_size, seq_len)
        losses = self.criterion(outputs.permute(1, 2, 0),
                                target_time.transpose(0, 1)) * mask
        loss = torch.sum(losses) / torch.sum(mask)
        return loss
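
Example 1 depends on a `utils.to_one_hot` helper that is not shown on this page. A minimal sketch of what such a helper presumably looks like, under the assumption that it maps an index tensor to a float one-hot tensor with a trailing class dimension (the name and signature are inferred from the call sites, not taken from the repository):

import torch

# Hypothetical stand-in for utils.to_one_hot; the real helper is not
# shown in these examples.
def to_one_hot(indices, num_classes):
    # indices: LongTensor of any shape -> float tensor with an extra
    # trailing dimension of size num_classes, one-hot along that axis.
    one_hot = torch.zeros(*indices.shape, num_classes,
                          dtype=torch.float32, device=indices.device)
    return one_hot.scatter_(-1, indices.unsqueeze(-1), 1.0)

With this convention, `to_one_hot(target_time[i], label_size)` returns a `(batch_size, label_size)` tensor, matching the subtraction from `targets` above.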
Example 2
    def __call__(self, outputs, output_symbols, targets):
        '''
        Inputs:
            outputs: (seq_len, batch_size, label_size) per-step log-probabilities
            output_symbols: (seq_len, batch_size) indices of the output symbols
                (sampled from the policy)
            targets: (batch_size, label_size) multi-hot target labels
        '''
        # Implementation details:
        # - directly minimize the specific score
        # - give the SOS token a low score
        outputs = torch.stack(outputs)
        targets = targets.to(outputs.device)

        output_symbols = torch.stack(output_symbols).squeeze(2)
        seq_len, batch_size, label_size = outputs.shape

        outputs_one_hot = utils.to_one_hot(output_symbols,
                                           label_size).to(outputs.device)
        q_values = torch.zeros(outputs.shape,
                               dtype=torch.float32,
                               device=outputs.device)

        mask = torch.ones((seq_len, batch_size),
                          dtype=torch.float32,
                          device=outputs.device)

        q_values[0, :, :] = -1 + targets
        for i in range(1, seq_len):
            is_correct = targets * outputs_one_hot[
                i - 1, :, :]  # batch_size * label_size
            targets = targets - is_correct
            q_values[i, :, :] = q_values[i - 1, :, :] - is_correct + torch.sum(
                is_correct, dim=1).unsqueeze(1) - 1

            # check whether all targets have been successfully predicted
            is_end_batch = torch.sum(targets, dim=1).eq(0)
            q_values[i, :, self.eos_id] += is_end_batch.float()

            # mask out steps after EOS has been emitted
            eos_batches = output_symbols[i - 1, :].data.eq(self.eos_id)
            eos_batches = eos_batches.float()
            mask[i, :] = (1 - eos_batches) * mask[i - 1, :]

        optimal_policy = torch.softmax(q_values / self.temperature, dim=2)

        # KL divergence between the optimal policy and the model's
        # log-softmax outputs, KL(p* || p) = E_{p*}[log p* - log p]
        # (computed as a mean over labels, a constant factor off the sum)
        losses = torch.mean(optimal_policy *
                            (torch.log(optimal_policy + 1e-8) - outputs),
                            dim=2) * mask
        loss = torch.sum(losses) / torch.sum(mask)
        return loss
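
To make the Q-value recursion above concrete, here is a tiny hand-checkable trace with assumed values (one sequence, 4 labels, EOS id 3, target multiset {0, 1}); the numbers are illustrative only:

import torch

eos_id, label_size = 3, 4
targets = torch.tensor([[1., 1., 0., 0.]])   # target multiset {0, 1}
q = -1 + targets                              # step 0 -> [0, 0, -1, -1]

for symbol in (0, 1):                         # suppose the policy samples 0, then 1
    one_hot = torch.zeros(1, label_size)
    one_hot[0, symbol] = 1.0
    is_correct = targets * one_hot
    targets = targets - is_correct
    q = q - is_correct + is_correct.sum(dim=1, keepdim=True) - 1
    if targets.sum(dim=1).eq(0).item():       # all targets consumed:
        q[0, eos_id] += 1.0                   # EOS becomes the optimal action

print(q)  # tensor([[-1., -1., -1., 0.]])

Once both targets are consumed, EOS is the unique argmax of `q`, so `softmax(q / temperature)` concentrates on EOS, which is exactly the behavior the KL loss distills into the model.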
Example 3
    def __call__(self, outputs, output_symbols, targets):
        outputs = torch.stack(outputs)
        targets = targets.to(outputs.device)

        output_symbols = torch.stack(output_symbols).squeeze(2)
        seq_len, batch_size, label_size = outputs.shape

        outputs_one_hot = utils.to_one_hot(output_symbols,
                                           label_size).to(outputs.device)
        target_each_time = torch.zeros(outputs.shape,
                                       dtype=torch.float32,
                                       device=outputs.device)

        mask = torch.ones((seq_len, batch_size),
                          dtype=torch.float32,
                          device=outputs.device)

        target_each_time[0, :, :] = targets
        for i in range(1, seq_len):
            is_correct = targets * outputs_one_hot[
                i - 1, :, :]  # batch_size * label_size
            targets = targets - is_correct

            target_each_time[i, :, :] = targets

            # check whether all targets have been successfully predicted
            is_end_batch = torch.sum(targets, dim=1).eq(0)
            target_each_time[i, :, self.eos_id] += is_end_batch.float()

            # mask out steps after EOS has been emitted
            eos_batches = output_symbols[i - 1, :].data.eq(self.eos_id)
            eos_batches = eos_batches.float()
            mask[i, :] = (1 - eos_batches) * mask[i - 1, :]

        prob_outputs = torch.exp(outputs)
        # renormalize the model's probabilities over the non-target labels
        new_probs = prob_outputs * (1 - target_each_time)
        new_probs = new_probs / torch.sum(new_probs, dim=-1).unsqueeze(-1)
        # sum p*log(p) is the negative entropy; minimizing it flattens the
        # distribution over the remaining wrong labels
        neg_entropy = torch.sum(new_probs * torch.log(new_probs + 1e-8),
                                dim=-1) * mask
        loss = torch.sum(neg_entropy) / torch.sum(mask)

        return loss
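
The quantity summed here is the negative entropy of the model's distribution restricted to the non-target labels, so minimizing the loss spreads the leftover probability mass evenly and no single wrong label is preferred. A standalone sketch with assumed numbers (one step, one sequence, 4 labels, labels 0 and 1 still in the target set):

import torch

log_p = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))  # model log-probs
targets_now = torch.tensor([[1., 1., 0., 0.]])             # labels 0, 1 still targets

p = torch.exp(log_p)
q = p * (1 - targets_now)              # zero out the remaining targets
q = q / q.sum(dim=-1, keepdim=True)    # renormalize over non-target labels
neg_entropy = (q * torch.log(q + 1e-8)).sum(dim=-1)
print(neg_entropy)  # lowest when the non-target mass is spread uniformly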
Example 4
    def forward(self,
                inputs=None,
                encoder_hidden=None,
                encoder_outputs=None,
                dataset=None,
                teacher_forcing_ratio=0,
                candidates=None,
                logit_output=None):
        ret_dict = dict()
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()
        inputs, batch_size, max_length = self._validate_args(
            inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio,
            candidates)
        decoder_hidden = self._init_state(encoder_hidden)
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        decoder_outputs = []
        sequence_symbols = []
        lengths = np.array([max_length] * batch_size)

        def post_decode(step_output, step_symbols, step_attn):
            decoder_outputs.append(step_output)
            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

            sequence_symbols.append(step_symbols)

            eos_batches = step_symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > di) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)

        # Manual unrolling is used to support random teacher forcing.
        # If teacher_forcing_ratio were fixed at 0 or 1 instead of a
        # probability, the unrolling could be done in-graph.
        if use_teacher_forcing and self.loss_type == 'vanilla':
            # vanilla RNN only: feed the whole target sequence at once
            decoder_input = inputs[:, :-1]
            context, decoder_hidden, attn = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_output, symbols = self.decoder(context)
            decoder_output = decoder_output.log()

            for di in range(decoder_output.size(1)):
                step_output = decoder_output[:, di, :]
                step_symbols = symbols[:, di]
                if attn is not None:
                    step_attn = attn[:, di, :]
                else:
                    step_attn = None
                post_decode(step_output, step_symbols, step_attn)
        else:
            decoder_input = inputs[:, 0].unsqueeze(1)
            mask = torch.zeros((batch_size, self.output_size),
                               dtype=torch.float32).to(inputs.device)

            for di in range(max_length):
                context, decoder_hidden, attn = self.forward_step(
                    decoder_input, decoder_hidden, encoder_outputs)
                if 'candidates' not in self.decoder.sampling_type:
                    decoder_output, symbols = self.decoder(
                        context, mask, logit_output=logit_output)
                else:
                    if use_teacher_forcing and self.loss_type == 'order_free':
                        # Order Free RNN
                        ori = self.decoder.sampling_type
                        self.set_sampling_type('max_from_candidates')
                        decoder_output, symbols = self.decoder(
                            context, mask, candidates)
                        self.set_sampling_type(ori)
                    else:
                        # Order Free + SS / vanilla + SS / OCD
                        decoder_output, symbols = self.decoder(
                            context, mask, candidates)

                    candidates -= utils.to_one_hot(
                        symbols.squeeze(2).squeeze(1),
                        self.output_size).float()
                    is_eos_batch = torch.sum(candidates, dim=1).eq(0)
                    candidates[:, self.eos_id] = is_eos_batch.float()

                decoder_output = decoder_output.log()

                step_output = decoder_output.squeeze(1)
                step_symbols = symbols.squeeze(1)
                post_decode(step_output, step_symbols, attn)
                decoder_input = step_symbols
                if self.add_mask:
                    # mask is one for symbols that have already been predicted;
                    # an error here usually means the loss went NaN upstream
                    try:
                        mask += utils.to_one_hot(step_symbols.squeeze(1),
                                                 self.output_size).float()
                    except Exception:
                        print(torch.max(mask))

        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, decoder_hidden, ret_dict
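
The `lengths` bookkeeping inside `post_decode` records, for each batch element, the step at which EOS is first emitted. A standalone sketch of that update rule with assumed values (batch of 3, `max_length` of 5, EOS id 2, decode step `di` = 1):

import numpy as np
import torch

eos_id, max_length, di = 2, 5, 1
lengths = np.array([max_length] * 3)
step_symbols = torch.tensor([2, 4, 2])           # sequences 0 and 2 emit EOS

eos_batches = step_symbols.eq(eos_id).cpu().view(-1).numpy()
update_idx = ((lengths > di) & eos_batches) != 0
lengths[update_idx] = di + 1                     # len(sequence_symbols) == di + 1
print(lengths)                                   # [2 5 2]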
    def forward(self,
                inputs=None,
                encoder_hidden=None,
                encoder_outputs=None,
                dataset=None,
                teacher_forcing_ratio=0,
                retain_output_probs=True,
                candidates=None,
                logit_output=None):
        """
        Forward rnn for MAX_LENGTH steps.  Look at :func:`seq2seq.models.DecoderRNN.DecoderRNN.forward_rnn` for details.
        """

        inputs, batch_size, max_length = self.rnn._validate_args(
            inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio)

        self.pos_index = Variable(
            torch.LongTensor(range(batch_size)) * self.k).view(-1, 1)

        # Inflate the initial hidden states to be of size: b*k x h
        encoder_hidden = self.rnn._init_state(encoder_hidden)
        if encoder_hidden is None:
            hidden = None
        else:
            if isinstance(encoder_hidden, tuple):
                n_layer_bidirection = encoder_hidden[0].size(
                    0)  # n_layers * num_directions
                hidden = tuple([
                    _inflate(h, self.k, 2).view(n_layer_bidirection,
                                                batch_size * self.k, -1)
                    for h in encoder_hidden
                ])
            else:
                # TODO: double-check the _inflate dimension here
                hidden = _inflate(encoder_hidden, self.k,
                                  2).view(1, batch_size * self.k, -1)
        # ... same idea for encoder_outputs and decoder_outputs
        if self.rnn.use_attention:
            _, encoder_length, encoder_output_size = encoder_outputs.shape
            inflated_encoder_outputs = _inflate(encoder_outputs, self.k,
                                                1).view(
                                                    batch_size * self.k,
                                                    encoder_length,
                                                    encoder_output_size)
        else:
            inflated_encoder_outputs = None
        # logit output
        if logit_output is not None:
            label_size = logit_output.shape[-1]
            logit_output = _inflate(logit_output, self.k,
                                    1).view(batch_size * self.k, label_size)

        # Initialize the scores; for the first step,
        # ignore the inflated copies to avoid duplicate entries in the top k
        sequence_scores = torch.zeros((batch_size * self.k, 1),
                                      dtype=torch.float32)
        sequence_scores.fill_(-1000)
        sequence_scores.index_fill_(
            0, torch.LongTensor([i * self.k for i in range(0, batch_size)]),
            0.0)
        sequence_scores = Variable(sequence_scores)

        # Initialize the input vector
        input_var = Variable(
            torch.transpose(
                torch.LongTensor([[self.SOS] * batch_size * self.k]), 0, 1))

        # Initialize mask
        mask = torch.zeros((batch_size * self.k, self.V), dtype=torch.float32)

        # Initialize lengths
        lengths = torch.ones((batch_size * self.k, 1), dtype=torch.float32)

        # Initialize eos
        eos_indices = input_var.data.eq(self.EOS)
        eos_score = sequence_scores * eos_indices.float()  #bk*1

        # Assign all vars to CUDA if available
        if CUDA:
            self.pos_index = self.pos_index.cuda()
            input_var = input_var.cuda()
            sequence_scores = sequence_scores.cuda()
            mask = mask.cuda()
            lengths = lengths.cuda()
            eos_indices = eos_indices.cuda()
            eos_score = eos_score.cuda()

        # Store decisions for backtracking
        stored_outputs = list()
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()
        stored_hidden = list()

        for t in range(max_length):
            # Run the RNN one step forward
            context, hidden, attn = self.rnn.forward_step(
                input_var, hidden, inflated_encoder_outputs)
            softmax_output, _ = self.rnn.decoder(context,
                                                 logit_output=logit_output)

            log_softmax_output = softmax_output.log().squeeze(1)  #bk * v
            # If doing local backprop (e.g. supervised training), retain the output layer
            if retain_output_probs:
                stored_outputs.append(log_softmax_output)  #bk * v

            # To get the full sequence scores for the new candidates, add the local scores for t_i to the predecessor scores for t_(i-1)
            sequence_scores = _inflate(sequence_scores, self.V, 1)  #bk*V
            sequence_scores += log_softmax_output + mask

            # a terminated sequence can only keep emitting the EOS token
            eos_mask = eos_indices.squeeze().float()
            sequence_scores[:, self.EOS] = \
                sequence_scores[:, self.EOS] * (1 - eos_mask) + eos_score.squeeze() * eos_mask  # (bk,)

            # Calculate new score
            if self.beam_score_type == 'sum':
                scores, candidates = sequence_scores.view(batch_size, -1).topk(
                    self.k)  # b* kV
                input_var = (candidates % self.V).view(batch_size * self.k, 1)
                # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
                sequence_scores = scores.view(batch_size * self.k, 1)
            elif self.beam_score_type == 'mean':
                # Mean of scores in each time step
                scores, candidates = (sequence_scores / lengths).view(
                    batch_size, -1).topk(self.k, dim=1)  # b* kV
                input_var = (candidates % self.V).view(batch_size * self.k, 1)
                # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
                sequence_scores = scores.view(batch_size * self.k, 1) * lengths
            # Update fields for next timestep
            predecessors = (candidates // self.V +
                            self.pos_index.expand_as(candidates)).view(
                                batch_size * self.k, 1)  # (b*k, 1)

            # Update mask
            mask = mask[predecessors.squeeze(), :] - utils.to_one_hot(
                input_var.squeeze(), self.V).float() * INF
            mask[:, self.EOS] = 0

            if isinstance(hidden, tuple):
                hidden = tuple([
                    h.index_select(1, predecessors.squeeze()) for h in hidden
                ])
            else:
                hidden = hidden.index_select(1, predecessors.squeeze())
            # Update sequence scores and erase scores for end-of-sentence symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())

            eos_indices = input_var.data.eq(self.EOS)  # bk* 1
            eos_score = sequence_scores * eos_indices.float()
            # Update lengths
            if t < max_length - 1:
                sequence_scores.data.masked_fill_(eos_indices, -1000)
                lengths = lengths[predecessors.squeeze(), 0].view(
                    batch_size * self.k, 1) + (1 - eos_indices.float())

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)
            stored_hidden.append(hidden)
        # Do backtracking to return the optimal values
        t = max_length - 1
        outputs = []
        output_symbols = []
        step_scores = []
        now_indexes = torch.arange(batch_size * self.k)
        while t >= 0:
            t_predecessors = stored_predecessors[t].squeeze()

            prev_indexes = now_indexes
            now_indexes = t_predecessors[now_indexes]
            current_symbol = stored_emitted_symbols[t][prev_indexes, 0].view(
                batch_size, self.k)
            current_output = stored_outputs[t][now_indexes].view(
                batch_size, self.k, self.V)
            # current_score[i][j] = current_output[i][j][current_symbol[i][j]]
            current_score = current_output.gather(
                2, current_symbol.unsqueeze(2)).view(batch_size, self.k)
            # record the back tracked results
            step_scores.append(current_score)
            outputs.append(current_output)
            output_symbols.append(current_symbol.unsqueeze(2))
            t -= 1

        outputs.reverse()         # seq_len x (batch_size, k, V)
        output_symbols.reverse()  # seq_len x (batch_size, k, 1)
        step_scores.reverse()

        # Build return objects
        decoder_outputs = [step[:, 0, :] for step in outputs]
        decoder_hidden = None

        metadata = {}
        metadata['output'] = outputs  # seq_len x (batch_size, k, V)
        if self.beam_score_type == 'sum':
            metadata['topk_score'] = sequence_scores.view(
                batch_size, self.k)  # (batch_size, k)
        elif self.beam_score_type == 'mean':
            metadata['topk_score'] = (sequence_scores / lengths).view(
                batch_size, self.k)  # (batch_size, k)
        metadata['topk_sequence'] = output_symbols  # seq_len x (batch_size, k, 1)
        metadata['topk_length'] = lengths.view(
            batch_size, self.k)  # (batch_size, k)
        metadata['step_score'] = step_scores  # seq_len x (batch_size, k)
        metadata['sequence'] = [seq[:, 0]
                                for seq in output_symbols]  # seq_len x (batch_size, 1)
        return decoder_outputs, decoder_hidden, metadata
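
This beam search relies on an `_inflate` helper that is not shown on this page; from the call sites it appears to repeat a tensor a given number of times along one dimension. A hypothetical sketch (the name is kept, the implementation is assumed):

import torch

def _inflate(tensor, times, dim):
    # repeat `tensor` `times` times along dimension `dim`
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return tensor.repeat(*repeat_dims)

scores = torch.tensor([[0.0], [1.0]])  # (batch_size, 1)
print(_inflate(scores, 3, 1))          # (batch_size, 3): each row repeated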
    def forward(self,
                inputs=None,
                encoder_hidden=None,
                encoder_outputs=None,
                dataset=None,
                teacher_forcing_ratio=0,
                candidates=None,
                logit_output=None):
        ret_dict = dict()
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()
        ori_inputs = inputs
        inputs, batch_size, max_length = self._validate_args(
            inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio,
            candidates)
        decoder_hidden = self._init_state(encoder_hidden)

        decoder_outputs = []
        sequence_symbols = []
        lengths = np.array([max_length] * batch_size)

        def post_decode(step_output, step_symbols, step_attn):
            decoder_outputs.append(step_output)
            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

            sequence_symbols.append(step_symbols)

            eos_batches = step_symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > di) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)

        decoder_input = inputs[:, 0].unsqueeze(1)
        mask = torch.zeros((batch_size, self.output_size),
                           dtype=torch.float32).to(inputs.device)

        for di in range(max_length - 1):
            context, decoder_hidden, attn = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_output, symbols = self.decoder(context,
                                                   mask,
                                                   candidates,
                                                   logit_output=logit_output)
            decoder_output = decoder_output.log()

            if teacher_forcing_ratio < 1.0:
                ran = torch.rand(symbols.shape).to(symbols.device)
                is_ss = ran.gt(teacher_forcing_ratio).float()
                if ori_inputs is not None:
                    # vanilla + SS
                    corrects = inputs[:, di + 1].unsqueeze(1).unsqueeze(2)
                else:
                    # order free + SS
                    corrects = symbols
                # sample a symbol from the model's output distribution
                ori = self.decoder.sampling_type
                self.set_sampling_type('sample')
                _, sample_symbols = self.decoder(context, mask, candidates)
                self.set_sampling_type(ori)
                step_symbols = (
                    is_ss * sample_symbols.float() +
                    (1 - is_ss) * corrects.float()).squeeze(1).long()
            else:
                if ori_inputs is not None:
                    step_symbols = inputs[:, di + 1].unsqueeze(1)
                else:
                    step_symbols = symbols.squeeze(1)

            if 'candidates' in self.decoder.sampling_type:
                candidates -= utils.to_one_hot(
                    symbols.squeeze(2).squeeze(1), self.output_size).float()
                is_eos_batch = torch.sum(candidates, dim=1).eq(0)
                candidates[:, self.eos_id] = is_eos_batch.float()

            step_output = decoder_output.squeeze(1)
            post_decode(step_output, step_symbols, attn)
            decoder_input = step_symbols
            if self.add_mask:
                # mask is one for symbols that have already been predicted;
                # an error here usually means the loss went NaN upstream
                mask[range(batch_size), step_symbols.squeeze(1)] = 1
                mask[:, self.eos_id] = 0

        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, decoder_hidden, ret_dict
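
The scheduled-sampling mix above chooses, per batch element, between the ground-truth symbol and the model's own sample. A standalone sketch of that step with assumed values (batch of 4, `teacher_forcing_ratio` of 0.5):

import torch

teacher_forcing_ratio = 0.5
corrects = torch.tensor([[1], [2], [3], [4]])        # ground-truth symbols
sample_symbols = torch.tensor([[9], [9], [9], [9]])  # model's own samples

ran = torch.rand(sample_symbols.shape)
is_ss = ran.gt(teacher_forcing_ratio).float()        # 1 -> take the sample
step_symbols = (is_ss * sample_symbols.float() +
                (1 - is_ss) * corrects.float()).long()
print(step_symbols)  # per position: sample with prob 1 - ratio, else truth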